diff --git a/otc-bridge/otc_bridge.py b/otc-bridge/otc_bridge.py index 6a3dcc379..dd66c9e3a 100644 --- a/otc-bridge/otc_bridge.py +++ b/otc-bridge/otc_bridge.py @@ -73,7 +73,23 @@ logging.basicConfig(level=logging.INFO) app = Flask(__name__, static_folder="static") -CORS(app) + +DEFAULT_OTC_CORS_ORIGINS = ("https://bottube.ai", "https://rustchain.org") + + +def parse_cors_origins(raw_origins): + if raw_origins is None: + return list(DEFAULT_OTC_CORS_ORIGINS) + + raw_items = [origin.strip() for origin in raw_origins.split(",") if origin.strip()] + origins = [origin for origin in raw_items if origin != "*"] + if len(origins) != len(raw_items): + log.warning("Ignoring wildcard CORS origin for OTC bridge") + return origins or list(DEFAULT_OTC_CORS_ORIGINS) + + +OTC_CORS_ORIGINS = parse_cors_origins(os.environ.get("OTC_CORS_ORIGINS")) +CORS(app, origins=OTC_CORS_ORIGINS) # --------------------------------------------------------------------------- diff --git a/otc-bridge/test_otc_bridge.py b/otc-bridge/test_otc_bridge.py index 9da25a8bd..a0dc38ec0 100644 --- a/otc-bridge/test_otc_bridge.py +++ b/otc-bridge/test_otc_bridge.py @@ -3,28 +3,38 @@ """ import json import os +import sqlite3 import tempfile import time import unittest from unittest.mock import patch, MagicMock -# Set test DB before importing +# Set an initial test DB before importing, then replace it per test. _fd, TEST_DB = tempfile.mkstemp(suffix=".db") os.close(_fd) os.environ["OTC_DB_PATH"] = TEST_DB +import otc_bridge from otc_bridge import app, init_db class OTCBridgeTestCase(unittest.TestCase): def setUp(self): + global TEST_DB + _fd, TEST_DB = tempfile.mkstemp(suffix=".db") + os.close(_fd) + otc_bridge.DB_PATH = TEST_DB self.app = app.test_client() self.app.testing = True init_db() def tearDown(self): + self.app = None if os.path.exists(TEST_DB): - os.remove(TEST_DB) + try: + os.remove(TEST_DB) + except PermissionError: + pass # --------------------------------------------------------------- # Order Creation @@ -288,9 +298,15 @@ def test_confirm_matched_order(self): # Confirm settlement with patch("requests.post") as mock_post: mock_post.return_value = MagicMock(ok=True, text='{"ok":true}') + with sqlite3.connect(TEST_DB) as conn: + secret = conn.execute( + "SELECT htlc_secret FROM orders WHERE order_id = ?", + (order_id,), + ).fetchone()[0] r3 = self.app.post(f"/api/orders/{order_id}/confirm", json={ "wallet": "buyer1", "quote_tx": "0xabc123def456", + "secret": secret, }) data = r3.get_json() self.assertTrue(data["ok"]) @@ -354,6 +370,14 @@ def test_frontend_served(self): self.assertEqual(r.status_code, 200) self.assertIn(b"RustChain OTC Bridge", r.data) + def test_cors_rejects_untrusted_origin(self): + r = self.app.get("/api/stats", headers={"Origin": "https://evil.example"}) + self.assertNotEqual(r.headers.get("Access-Control-Allow-Origin"), "https://evil.example") + + def test_cors_allows_configured_origin(self): + r = self.app.get("/api/stats", headers={"Origin": "https://rustchain.org"}) + self.assertEqual(r.headers.get("Access-Control-Allow-Origin"), "https://rustchain.org") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_otc_bridge_query_validation.py b/tests/test_otc_bridge_query_validation.py index 381995086..eb8097a35 100644 --- a/tests/test_otc_bridge_query_validation.py +++ b/tests/test_otc_bridge_query_validation.py @@ -5,14 +5,17 @@ from pathlib import Path -def load_otc_bridge(tmp_path): - if "flask_cors" not in sys.modules: - flask_cors = types.ModuleType("flask_cors") - flask_cors.CORS = lambda app: app - sys.modules["flask_cors"] = flask_cors +def load_otc_bridge(tmp_path, cors_origins=None): + flask_cors = sys.modules.get("flask_cors") or types.ModuleType("flask_cors") + flask_cors.CORS = lambda app, **kwargs: app + sys.modules["flask_cors"] = flask_cors db_path = tmp_path / "otc_bridge.db" os.environ["OTC_DB_PATH"] = str(db_path) + if cors_origins is None: + os.environ.pop("OTC_CORS_ORIGINS", None) + else: + os.environ["OTC_CORS_ORIGINS"] = cors_origins module_path = Path(__file__).resolve().parents[1] / "otc-bridge" / "otc_bridge.py" spec = importlib.util.spec_from_file_location("otc_bridge_under_test", module_path) @@ -23,6 +26,35 @@ def load_otc_bridge(tmp_path): return module +def test_cors_defaults_to_restricted_public_origins(tmp_path): + otc_bridge = load_otc_bridge(tmp_path) + + assert otc_bridge.OTC_CORS_ORIGINS == ["https://bottube.ai", "https://rustchain.org"] + assert "*" not in otc_bridge.OTC_CORS_ORIGINS + + +def test_cors_env_ignores_wildcard_origin(tmp_path): + otc_bridge = load_otc_bridge( + tmp_path, + cors_origins="*, https://trusted.example, http://localhost:3000", + ) + + assert otc_bridge.OTC_CORS_ORIGINS == [ + "https://trusted.example", + "http://localhost:3000", + ] + + +def test_cors_env_ignores_all_wildcard_origins(tmp_path): + otc_bridge = load_otc_bridge( + tmp_path, + cors_origins="*, *, https://trusted.example", + ) + + assert otc_bridge.OTC_CORS_ORIGINS == ["https://trusted.example"] + assert "*" not in otc_bridge.OTC_CORS_ORIGINS + + def test_orders_rejects_malformed_pagination(tmp_path): otc_bridge = load_otc_bridge(tmp_path)