Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion otc-bridge/otc_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,23 @@
logging.basicConfig(level=logging.INFO)

app = Flask(__name__, static_folder="static")
CORS(app)

DEFAULT_OTC_CORS_ORIGINS = ("https://bottube.ai", "https://rustchain.org")


def parse_cors_origins(raw_origins):
if raw_origins is None:
return list(DEFAULT_OTC_CORS_ORIGINS)

raw_items = [origin.strip() for origin in raw_origins.split(",") if origin.strip()]
origins = [origin for origin in raw_items if origin != "*"]
if len(origins) != len(raw_items):
log.warning("Ignoring wildcard CORS origin for OTC bridge")
return origins or list(DEFAULT_OTC_CORS_ORIGINS)


OTC_CORS_ORIGINS = parse_cors_origins(os.environ.get("OTC_CORS_ORIGINS"))
CORS(app, origins=OTC_CORS_ORIGINS)


# ---------------------------------------------------------------------------
Expand Down
28 changes: 26 additions & 2 deletions otc-bridge/test_otc_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,38 @@
"""
import json
import os
import sqlite3
import tempfile
import time
import unittest
from unittest.mock import patch, MagicMock

# Set test DB before importing
# Set an initial test DB before importing, then replace it per test.
_fd, TEST_DB = tempfile.mkstemp(suffix=".db")
os.close(_fd)
os.environ["OTC_DB_PATH"] = TEST_DB

import otc_bridge
from otc_bridge import app, init_db


class OTCBridgeTestCase(unittest.TestCase):
def setUp(self):
global TEST_DB
_fd, TEST_DB = tempfile.mkstemp(suffix=".db")
os.close(_fd)
otc_bridge.DB_PATH = TEST_DB
self.app = app.test_client()
self.app.testing = True
init_db()

def tearDown(self):
self.app = None
if os.path.exists(TEST_DB):
os.remove(TEST_DB)
try:
os.remove(TEST_DB)
except PermissionError:
pass

# ---------------------------------------------------------------
# Order Creation
Expand Down Expand Up @@ -288,9 +298,15 @@ def test_confirm_matched_order(self):
# Confirm settlement
with patch("requests.post") as mock_post:
mock_post.return_value = MagicMock(ok=True, text='{"ok":true}')
with sqlite3.connect(TEST_DB) as conn:
secret = conn.execute(
"SELECT htlc_secret FROM orders WHERE order_id = ?",
(order_id,),
).fetchone()[0]
r3 = self.app.post(f"/api/orders/{order_id}/confirm", json={
"wallet": "buyer1",
"quote_tx": "0xabc123def456",
"secret": secret,
})
data = r3.get_json()
self.assertTrue(data["ok"])
Expand Down Expand Up @@ -354,6 +370,14 @@ def test_frontend_served(self):
self.assertEqual(r.status_code, 200)
self.assertIn(b"RustChain OTC Bridge", r.data)

def test_cors_rejects_untrusted_origin(self):
r = self.app.get("/api/stats", headers={"Origin": "https://evil.example"})
self.assertNotEqual(r.headers.get("Access-Control-Allow-Origin"), "https://evil.example")

def test_cors_allows_configured_origin(self):
r = self.app.get("/api/stats", headers={"Origin": "https://rustchain.org"})
self.assertEqual(r.headers.get("Access-Control-Allow-Origin"), "https://rustchain.org")


if __name__ == "__main__":
unittest.main()
42 changes: 37 additions & 5 deletions tests/test_otc_bridge_query_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
from pathlib import Path


def load_otc_bridge(tmp_path):
if "flask_cors" not in sys.modules:
flask_cors = types.ModuleType("flask_cors")
flask_cors.CORS = lambda app: app
sys.modules["flask_cors"] = flask_cors
def load_otc_bridge(tmp_path, cors_origins=None):
flask_cors = sys.modules.get("flask_cors") or types.ModuleType("flask_cors")
flask_cors.CORS = lambda app, **kwargs: app
sys.modules["flask_cors"] = flask_cors

db_path = tmp_path / "otc_bridge.db"
os.environ["OTC_DB_PATH"] = str(db_path)
if cors_origins is None:
os.environ.pop("OTC_CORS_ORIGINS", None)
else:
os.environ["OTC_CORS_ORIGINS"] = cors_origins

module_path = Path(__file__).resolve().parents[1] / "otc-bridge" / "otc_bridge.py"
spec = importlib.util.spec_from_file_location("otc_bridge_under_test", module_path)
Expand All @@ -23,6 +26,35 @@ def load_otc_bridge(tmp_path):
return module


def test_cors_defaults_to_restricted_public_origins(tmp_path):
otc_bridge = load_otc_bridge(tmp_path)

assert otc_bridge.OTC_CORS_ORIGINS == ["https://bottube.ai", "https://rustchain.org"]
assert "*" not in otc_bridge.OTC_CORS_ORIGINS


def test_cors_env_ignores_wildcard_origin(tmp_path):
otc_bridge = load_otc_bridge(
tmp_path,
cors_origins="*, https://trusted.example, http://localhost:3000",
)

assert otc_bridge.OTC_CORS_ORIGINS == [
"https://trusted.example",
"http://localhost:3000",
]


def test_cors_env_ignores_all_wildcard_origins(tmp_path):
otc_bridge = load_otc_bridge(
tmp_path,
cors_origins="*, *, https://trusted.example",
)

assert otc_bridge.OTC_CORS_ORIGINS == ["https://trusted.example"]
assert "*" not in otc_bridge.OTC_CORS_ORIGINS


def test_orders_rejects_malformed_pagination(tmp_path):
otc_bridge = load_otc_bridge(tmp_path)

Expand Down