Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,6 @@ jobs:
run: |
python main.py --help
python main.py --version

- name: Unit tests
run: python -m unittest discover -s tests
63 changes: 63 additions & 0 deletions tests/test_domain_fronter_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import base64
import sys
import unittest
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from relay.domain_fronter import DomainFronter


class DomainFronterPayloadTests(unittest.TestCase):
def make_fronter(self):
return DomainFronter({
"google_ip": "216.239.38.120",
"front_domain": "www.google.com",
"script_id": "dummy-deployment-id",
"auth_key": "test-secret",
})

def test_build_payload_strips_proxy_and_ip_leaking_headers(self):
payload = self.make_fronter()._build_payload(
"GET",
"https://example.com/",
{
"User-Agent": "unit-test",
"X-Forwarded-For": "198.51.100.10",
"X-Real-IP": "198.51.100.10",
"Via": "proxy",
"Proxy-Authorization": "secret",
},
b"",
)

self.assertEqual(payload["h"], {"User-Agent": "unit-test"})

def test_build_payload_does_not_readd_headers_when_all_are_stripped(self):
payload = self.make_fronter()._build_payload(
"GET",
"https://example.com/",
{
"X-Forwarded-For": "198.51.100.10",
"Forwarded": "for=198.51.100.10",
"Via": "proxy",
},
b"",
)

self.assertNotIn("h", payload)

def test_build_payload_base64_encodes_body_and_content_type(self):
payload = self.make_fronter()._build_payload(
"POST",
"https://example.com/api",
{"Content-Type": "application/json"},
b'{"ok":true}',
)

self.assertEqual(base64.b64decode(payload["b"]), b'{"ok":true}')
self.assertEqual(payload["ct"], "application/json")


if __name__ == "__main__":
unittest.main()
63 changes: 63 additions & 0 deletions tests/test_proxy_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import sys
import unittest
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from proxy.proxy_support import (
has_unsupported_transfer_encoding,
host_matches_rules,
inject_cors_headers,
load_host_rules,
parse_content_length,
)


class ProxySupportTests(unittest.TestCase):
def test_parse_content_length_matches_exact_header_name(self):
header_block = (
b"POST / HTTP/1.1\r\n"
b"X-Content-Length: 999\r\n"
b"Content-Length: 42\r\n"
b"\r\n"
)
self.assertEqual(parse_content_length(header_block), 42)

def test_transfer_encoding_only_identity_is_supported(self):
self.assertFalse(has_unsupported_transfer_encoding(
b"POST / HTTP/1.1\r\nTransfer-Encoding: identity\r\n\r\n"
))
self.assertTrue(has_unsupported_transfer_encoding(
b"POST / HTTP/1.1\r\nTransfer-Encoding: gzip, chunked\r\n\r\n"
))

def test_host_rules_support_exact_and_suffix_matches(self):
rules = load_host_rules(["localhost", ".local", "Example.COM."])
self.assertTrue(host_matches_rules("localhost", rules))
self.assertTrue(host_matches_rules("printer.local", rules))
self.assertTrue(host_matches_rules("example.com", rules))
self.assertFalse(host_matches_rules("notexample.com", rules))

def test_inject_cors_headers_replaces_existing_policy_and_keeps_body(self):
response = (
b"HTTP/1.1 200 OK\r\n"
b"Access-Control-Allow-Origin: https://old.example\r\n"
b"Content-Type: text/plain\r\n"
b"Content-Length: 5\r\n"
b"\r\n"
b"hello"
)

rewritten = inject_cors_headers(response, "https://app.example")
header_block, body = rewritten.split(b"\r\n\r\n", 1)

self.assertEqual(body, b"hello")
self.assertNotIn(b"https://old.example", header_block)
self.assertIn(
b"Access-Control-Allow-Origin: https://app.example", header_block
)
self.assertIn(b"Access-Control-Allow-Credentials: true", header_block)


if __name__ == "__main__":
unittest.main()
58 changes: 58 additions & 0 deletions tests/test_relay_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import base64
import sys
import unittest
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from relay.relay_response import parse_relay_json, split_raw_response, split_set_cookie


class RelayResponseTests(unittest.TestCase):
def test_split_set_cookie_preserves_expires_comma(self):
raw = (
"sid=abc; Expires=Wed, 21 Oct 2026 07:28:00 GMT; Path=/, "
"theme=dark; Path=/"
)

self.assertEqual(
split_set_cookie(raw),
[
"sid=abc; Expires=Wed, 21 Oct 2026 07:28:00 GMT; Path=/",
"theme=dark; Path=/",
],
)

def test_parse_relay_json_rebuilds_http_response_and_set_cookie_lines(self):
data = {
"s": 200,
"h": {
"Content-Type": "text/plain",
"Content-Encoding": "gzip",
"Set-Cookie": [
"a=1; Path=/, b=2; Path=/",
],
},
"b": base64.b64encode(b"hello").decode(),
}

raw = parse_relay_json(data, max_body_bytes=1024)
status, headers, body = split_raw_response(raw)

self.assertEqual(status, 200)
self.assertEqual(body, b"hello")
self.assertEqual(headers["content-type"], "text/plain")
self.assertEqual(headers["content-length"], "5")
self.assertNotIn("content-encoding", headers)
self.assertEqual(raw.count(b"Set-Cookie:"), 2)

def test_parse_relay_json_error_returns_gateway_response(self):
raw = parse_relay_json({"e": "unauthorized"}, max_body_bytes=1024)
status, headers, body = split_raw_response(raw)

self.assertEqual(status, 502)
self.assertIn(b"auth/permission error", body)


if __name__ == "__main__":
unittest.main()