diff --git a/node/airdrop_v2.py b/node/airdrop_v2.py index 80ed1bcec..1fdc8424a 100644 --- a/node/airdrop_v2.py +++ b/node/airdrop_v2.py @@ -34,6 +34,7 @@ import hmac import json import logging +import math import os import re import sqlite3 @@ -1248,16 +1249,58 @@ def require_admin_key(): return jsonify({"ok": False, "error": "unauthorized"}), 401 return None + def text_field(data: Dict[str, Any], field: str) -> Tuple[Optional[str], Optional[Tuple[Any, int]]]: + value = data.get(field, "") + if not isinstance(value, str): + return None, ( + jsonify({"ok": False, "error": "invalid_field_type", "field": field}), + 400, + ) + return value.strip(), None + + def wrtc_amount_field(data: Dict[str, Any]) -> Tuple[Optional[int], Optional[Tuple[Any, int]]]: + value = data.get("amount_wrtc", 0) + if isinstance(value, bool): + return None, ( + jsonify({"ok": False, "error": "invalid_amount_wrtc"}), + 400, + ) + try: + amount_wrtc = float(value) + except (TypeError, ValueError): + return None, ( + jsonify({"ok": False, "error": "invalid_amount_wrtc"}), + 400, + ) + if not math.isfinite(amount_wrtc) or amount_wrtc <= 0: + return None, ( + jsonify({"ok": False, "error": "invalid_amount_wrtc"}), + 400, + ) + amount_uwrtc = int(amount_wrtc * 1_000_000) + if amount_uwrtc <= 0: + return None, ( + jsonify({"ok": False, "error": "invalid_amount_wrtc"}), + 400, + ) + return amount_uwrtc, None + @app.route("/api/airdrop/eligibility", methods=["POST"]) def check_airdrop_eligibility(): """Check airdrop eligibility.""" data = request.get_json(silent=True) - if not data: + if not isinstance(data, dict) or not data: return jsonify({"ok": False, "error": "invalid_json"}), 400 - github_username = data.get("github_username", "").strip() - wallet_address = data.get("wallet_address", "").strip() - chain = data.get("chain", "").strip() + github_username, error = text_field(data, "github_username") + if error: + return error + wallet_address, error = text_field(data, "wallet_address") + if error: + return error + chain, error = text_field(data, "chain") + if error: + return error github_token = data.get("github_token") # SECURITY: skip_antisybil must NEVER be settable from API requests. # It exists only for internal testing via direct Python calls. @@ -1279,13 +1322,21 @@ def check_airdrop_eligibility(): def claim_airdrop(): """Submit airdrop claim.""" data = request.get_json(silent=True) - if not data: + if not isinstance(data, dict) or not data: return jsonify({"ok": False, "error": "invalid_json"}), 400 - github_username = data.get("github_username", "").strip() - wallet_address = data.get("wallet_address", "").strip() - chain = data.get("chain", "").strip() - tier = data.get("tier", "").strip() + github_username, error = text_field(data, "github_username") + if error: + return error + wallet_address, error = text_field(data, "wallet_address") + if error: + return error + chain, error = text_field(data, "chain") + if error: + return error + tier, error = text_field(data, "tier") + if error: + return error github_token = data.get("github_token") if not all([github_username, wallet_address, chain, tier]): @@ -1326,14 +1377,21 @@ def get_airdrop_stats(): def create_bridge_lock(): """Create bridge lock.""" data = request.get_json(silent=True) - if not data: + if not isinstance(data, dict) or not data: return jsonify({"ok": False, "error": "invalid_json"}), 400 - from_address = data.get("from_address", "").strip() - to_address = data.get("to_address", "").strip() - from_chain = data.get("from_chain", "").strip() - to_chain = data.get("to_chain", "").strip() - amount_wrtc = data.get("amount_wrtc", 0) + from_address, error = text_field(data, "from_address") + if error: + return error + to_address, error = text_field(data, "to_address") + if error: + return error + from_chain, error = text_field(data, "from_chain") + if error: + return error + to_chain, error = text_field(data, "to_chain") + if error: + return error if not all([from_address, to_address, from_chain, to_chain]): return ( @@ -1347,7 +1405,9 @@ def create_bridge_lock(): 400, ) - amount_uwrtc = int(float(amount_wrtc) * 1_000_000) + amount_uwrtc, error = wrtc_amount_field(data) + if error: + return error success, message, lock = airdrop.create_bridge_lock( from_address, to_address, from_chain, to_chain, amount_uwrtc diff --git a/node/test_airdrop_v2_route_validation.py b/node/test_airdrop_v2_route_validation.py new file mode 100644 index 000000000..cc4af4e55 --- /dev/null +++ b/node/test_airdrop_v2_route_validation.py @@ -0,0 +1,101 @@ +import os +import tempfile +import unittest + +from flask import Flask + +from airdrop_v2 import AirdropV2, init_airdrop_routes + + +class TestAirdropV2RouteValidation(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.tmp.close() + self.airdrop = AirdropV2(db_path=self.tmp.name) + app = Flask(__name__) + init_airdrop_routes(app, self.airdrop, self.tmp.name) + app.config["TESTING"] = False + self.client = app.test_client() + + def tearDown(self): + self.client = None + os.unlink(self.tmp.name) + + def test_public_routes_reject_non_string_fields_without_500(self): + cases = [ + ( + "eligibility_github_list", + "/api/airdrop/eligibility", + {"github_username": [], "wallet_address": "x" * 32, "chain": "solana"}, + ), + ( + "eligibility_wallet_dict", + "/api/airdrop/eligibility", + {"github_username": "octocat", "wallet_address": {}, "chain": "solana"}, + ), + ( + "eligibility_chain_list", + "/api/airdrop/eligibility", + {"github_username": "octocat", "wallet_address": "x" * 32, "chain": []}, + ), + ( + "claim_tier_dict", + "/api/airdrop/claim", + { + "github_username": "octocat", + "wallet_address": "x" * 32, + "chain": "solana", + "tier": {}, + }, + ), + ( + "bridge_from_list", + "/api/bridge/lock", + { + "from_address": [], + "to_address": "dest_wallet_12345", + "from_chain": "solana", + "to_chain": "base", + "amount_wrtc": 1, + }, + ), + ] + + for name, path, payload in cases: + with self.subTest(name=name): + response = self.client.post(path, json=payload) + self.assertEqual(response.status_code, 400) + self.assertNotIn("Internal Server Error", response.get_data(as_text=True)) + + def test_public_routes_reject_array_json_bodies_without_500(self): + cases = [ + "/api/airdrop/eligibility", + "/api/airdrop/claim", + "/api/bridge/lock", + ] + + for path in cases: + with self.subTest(path=path): + response = self.client.post(path, json=["not", "an", "object"]) + self.assertEqual(response.status_code, 400) + self.assertNotIn("Internal Server Error", response.get_data(as_text=True)) + + def test_bridge_lock_rejects_invalid_amount_without_500(self): + base_payload = { + "from_address": "source_wallet_12345", + "to_address": "dest_wallet_12345", + "from_chain": "solana", + "to_chain": "base", + } + for amount_wrtc in ("not-a-number", "nan", "inf", {}, [], True): + with self.subTest(amount_wrtc=amount_wrtc): + response = self.client.post( + "/api/bridge/lock", + json={**base_payload, "amount_wrtc": amount_wrtc}, + ) + self.assertEqual(response.status_code, 400) + self.assertNotIn("Internal Server Error", response.get_data(as_text=True)) + + +if __name__ == "__main__": + unittest.main()