diff --git a/node/bridge_api.py b/node/bridge_api.py index ee76ef205..7086ae46b 100644 --- a/node/bridge_api.py +++ b/node/bridge_api.py @@ -20,6 +20,7 @@ import hashlib import logging import os +import math from typing import Optional, Tuple, Dict, Any from decimal import Decimal from dataclasses import dataclass @@ -120,6 +121,8 @@ def validate_bridge_request(data: Optional[Dict]) -> ValidationResult: """Validate bridge transfer request payload.""" if not data: return ValidationResult(ok=False, error="Request body is required") + if not isinstance(data, dict): + return ValidationResult(ok=False, error="Request body must be a JSON object") # Required fields required = ["direction", "source_chain", "dest_chain", "source_address", "dest_address", "amount_rtc"] @@ -129,12 +132,20 @@ def validate_bridge_request(data: Optional[Dict]) -> ValidationResult: # Validate direction direction = data.get("direction") + if not isinstance(direction, str): + return ValidationResult(ok=False, error="direction must be a string") if direction not in ["deposit", "withdraw"]: return ValidationResult(ok=False, error=f"Invalid direction: {direction}. Must be 'deposit' or 'withdraw'") # Validate chains - source_chain = data.get("source_chain", "").lower() - dest_chain = data.get("dest_chain", "").lower() + source_chain_raw = data.get("source_chain", "") + dest_chain_raw = data.get("dest_chain", "") + if not isinstance(source_chain_raw, str): + return ValidationResult(ok=False, error="source_chain must be a string") + if not isinstance(dest_chain_raw, str): + return ValidationResult(ok=False, error="dest_chain must be a string") + source_chain = source_chain_raw.lower() + dest_chain = dest_chain_raw.lower() if source_chain not in VALID_CHAINS: return ValidationResult(ok=False, error=f"Invalid source_chain: {source_chain}") @@ -156,6 +167,10 @@ def validate_bridge_request(data: Optional[Dict]) -> ValidationResult: # Validate addresses source_address = data.get("source_address", "") dest_address = data.get("dest_address", "") + if not isinstance(source_address, str): + return ValidationResult(ok=False, error="source_address must be a string") + if not isinstance(dest_address, str): + return ValidationResult(ok=False, error="dest_address must be a string") if not source_address or len(source_address) < 10: return ValidationResult(ok=False, error="Invalid source_address (too short)") @@ -163,11 +178,16 @@ def validate_bridge_request(data: Optional[Dict]) -> ValidationResult: return ValidationResult(ok=False, error="Invalid dest_address (too short)") # Validate amount + amount_raw = data.get("amount_rtc", 0) + if isinstance(amount_raw, bool): + return ValidationResult(ok=False, error="amount_rtc must be a number") try: - amount_rtc = float(data.get("amount_rtc", 0)) + amount_rtc = float(amount_raw) except (TypeError, ValueError): return ValidationResult(ok=False, error="amount_rtc must be a number") + if not math.isfinite(amount_rtc): + return ValidationResult(ok=False, error="amount_rtc must be finite") if amount_rtc <= 0: return ValidationResult(ok=False, error="amount_rtc must be positive") if amount_rtc < BRIDGE_MIN_AMOUNT_RTC: @@ -175,11 +195,15 @@ def validate_bridge_request(data: Optional[Dict]) -> ValidationResult: # Validate bridge type (optional) bridge_type = data.get("bridge_type", "bottube") + if not isinstance(bridge_type, str): + return ValidationResult(ok=False, error="bridge_type must be a string") if bridge_type not in VALID_BRIDGE_TYPES: return ValidationResult(ok=False, error=f"Invalid bridge_type: {bridge_type}") # Validate memo (optional) memo = data.get("memo") + if memo is not None and not isinstance(memo, str): + return ValidationResult(ok=False, error="memo must be a string") if memo and len(memo) > 256: return ValidationResult(ok=False, error="Memo must be <= 256 characters") @@ -697,11 +721,12 @@ def initiate_bridge(): validation = validate_bridge_request(data) if not validation.ok: return jsonify({"error": validation.error}), 400 + details = validation.details or {} # Validate address formats for chain, addr in [ - (data["source_chain"], data["source_address"]), - (data["dest_chain"], data["dest_address"]) + (details["source_chain"], details["source_address"]), + (details["dest_chain"], details["dest_address"]) ]: valid, msg = validate_chain_address_format(chain, addr) if not valid: @@ -711,7 +736,7 @@ def initiate_bridge(): admin_key = request.headers.get("X-Admin-Key", "") expected_admin_key = os.environ.get("RC_ADMIN_KEY", "") admin_initiated = bool(expected_admin_key) and hmac.compare_digest(admin_key, expected_admin_key) - if data["direction"] == "deposit": + if details["direction"] == "deposit": # Deposits create balance locks by source_address; require operator # authorization until a wallet-owner signature flow exists. if not expected_admin_key: @@ -721,14 +746,14 @@ def initiate_bridge(): # Create bridge transfer req = BridgeTransferRequest( - direction=data["direction"], - source_chain=data["source_chain"], - dest_chain=data["dest_chain"], - source_address=data["source_address"], - dest_address=data["dest_address"], - amount_rtc=data["amount_rtc"], - memo=data.get("memo"), - bridge_type=data.get("bridge_type", "bottube") + direction=details["direction"], + source_chain=details["source_chain"], + dest_chain=details["dest_chain"], + source_address=details["source_address"], + dest_address=details["dest_address"], + amount_rtc=details["amount_rtc"], + memo=details.get("memo"), + bridge_type=details["bridge_type"] ) conn = sqlite3.connect(DB_PATH) diff --git a/node/test_bridge_initiate_type_validation.py b/node/test_bridge_initiate_type_validation.py new file mode 100644 index 000000000..93bcffdb4 --- /dev/null +++ b/node/test_bridge_initiate_type_validation.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import sqlite3 +import tempfile +import unittest + +from flask import Flask + +import bridge_api + + +class TestBridgeInitiateTypeValidation(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.tmp.close() + self.db_path = self.tmp.name + bridge_api.DB_PATH = self.db_path + conn = sqlite3.connect(self.db_path) + try: + bridge_api.init_bridge_schema(conn.cursor()) + conn.execute( + "CREATE TABLE IF NOT EXISTS balances (miner_id TEXT PRIMARY KEY, amount_i64 INTEGER DEFAULT 0)" + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS lock_ledger ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + bridge_transfer_id INTEGER, + miner_id TEXT, + amount_i64 INTEGER, + lock_type TEXT, + locked_at INTEGER, + unlock_at INTEGER, + status TEXT, + created_at INTEGER + ) + """ + ) + conn.commit() + finally: + conn.close() + + app = Flask(__name__) + bridge_api.register_bridge_routes(app) + app.config["TESTING"] = False + self.client = app.test_client() + + def tearDown(self): + self.client = None + os.unlink(self.db_path) + + def valid_payload(self): + return { + "direction": "withdraw", + "source_chain": "solana", + "dest_chain": "rustchain", + "source_address": "S" * 32, + "dest_address": "RTCdestination12345", + "amount_rtc": 1.0, + } + + def test_malformed_json_field_types_return_400_not_500(self): + cases = { + "source_chain_list": {"source_chain": []}, + "dest_chain_dict": {"dest_chain": {}}, + "source_address_list": {"source_address": ["x"] * 12}, + "dest_address_dict": {"dest_address": {"wallet": "RTCdestination12345"}}, + "amount_bool": {"amount_rtc": True}, + "bridge_type_list": {"bridge_type": []}, + "memo_dict": {"memo": {"note": "not a string"}}, + } + + for name, override in cases.items(): + with self.subTest(name=name): + payload = {**self.valid_payload(), **override} + response = self.client.post("/api/bridge/initiate", json=payload) + self.assertEqual(response.status_code, 400) + + def test_non_finite_amounts_return_400_not_500(self): + for amount_rtc in ("nan", "inf", "-inf"): + with self.subTest(amount_rtc=amount_rtc): + payload = {**self.valid_payload(), "amount_rtc": amount_rtc} + response = self.client.post("/api/bridge/initiate", json=payload) + self.assertEqual(response.status_code, 400) + + def test_mixed_case_chain_uses_normalized_value_for_address_validation(self): + payload = { + **self.valid_payload(), + "source_chain": "Base", + "source_address": "not-a-base-wallet", + } + + response = self.client.post("/api/bridge/initiate", json=payload) + + self.assertEqual(response.status_code, 400) + + def test_successful_mixed_case_chain_response_uses_normalized_values(self): + payload = { + **self.valid_payload(), + "source_chain": "Solana", + "dest_chain": "RustChain", + } + + response = self.client.post("/api/bridge/initiate", json=payload) + + self.assertEqual(response.status_code, 200) + body = response.get_json() + self.assertEqual(body["source_chain"], "solana") + self.assertEqual(body["dest_chain"], "rustchain") + + +if __name__ == "__main__": + unittest.main()