diff --git a/rips/rustchain-core/api/rpc.py b/rips/rustchain-core/api/rpc.py index 08b7c8563..49b38811b 100644 --- a/rips/rustchain-core/api/rpc.py +++ b/rips/rustchain-core/api/rpc.py @@ -13,6 +13,8 @@ """ import json +import hmac +import os import time from dataclasses import dataclass from typing import Dict, Any, Optional, Callable @@ -255,6 +257,16 @@ class ApiRequestHandler(BaseHTTPRequestHandler): """HTTP request handler for API""" api: RustChainApi = None # Set by server + STATE_CHANGING_PATHS = { + "/api/mine", + "/api/governance/create", + "/api/governance/vote", + } + STATE_CHANGING_RPC_METHODS = { + "submitProof", + "createProposal", + "vote", + } def do_GET(self): """Handle GET requests""" @@ -262,7 +274,7 @@ def do_GET(self): path = parsed.path params = {k: v[0] for k, v in parse_qs(parsed.query).items()} - response = self._route_request(path, params) + response = self._route_request(path, params, "GET") self._send_response(response) def do_POST(self): @@ -276,11 +288,41 @@ def do_POST(self): params = {} parsed = urlparse(self.path) - response = self._route_request(parsed.path, params) + response = self._route_request(parsed.path, params, "POST") self._send_response(response) - def _route_request(self, path: str, params: Dict[str, Any]) -> ApiResponse: + def do_OPTIONS(self): + """Handle CORS preflight requests.""" + if not self._is_origin_allowed(): + self.send_response(403) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(ApiResponse( + success=False, + error="Origin not allowed", + ).to_json().encode()) + return + + self.send_response(204) + self._send_cors_headers() + self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + self.send_header("Access-Control-Allow-Headers", "Content-Type, X-CSRF-Token") + self.send_header("Access-Control-Max-Age", "600") + self.end_headers() + + def _route_request(self, path: str, params: Dict[str, Any], http_method: str) -> ApiResponse: """Route request to appropriate handler""" + if self._is_state_changing_request(path, params): + csrf_response = self._validate_csrf_token() + if csrf_response: + return csrf_response + + if http_method != "POST": + return ApiResponse( + success=False, + error="State-changing endpoints require POST", + ) + # REST endpoints routes = { "/api/stats": ("getStats", {}), @@ -334,10 +376,49 @@ def _send_response(self, response: ApiResponse): """Send HTTP response""" self.send_response(200 if response.success else 400) self.send_header("Content-Type", "application/json") - self.send_header("Access-Control-Allow-Origin", "*") + self._send_cors_headers() self.end_headers() self.wfile.write(response.to_json().encode()) + def _is_state_changing_request(self, path: str, params: Dict[str, Any]) -> bool: + """Return whether a request can mutate node state.""" + if path in self.STATE_CHANGING_PATHS: + return True + return ( + path == "/rpc" + and params.get("method", "") in self.STATE_CHANGING_RPC_METHODS + ) + + def _validate_csrf_token(self) -> Optional[ApiResponse]: + """Require an explicit CSRF token for state-changing operations.""" + expected = os.environ.get("RUSTCHAIN_API_CSRF_TOKEN", "") + if not expected: + return ApiResponse(success=False, error="CSRF token is not configured") + + supplied = self.headers.get("X-CSRF-Token", "") + if not hmac.compare_digest(supplied, expected): + return ApiResponse(success=False, error="Invalid CSRF token") + return None + + def _is_origin_allowed(self) -> bool: + """Return whether the request Origin is configured for CORS access.""" + origin = self.headers.get("Origin", "") + if not origin: + return True + return origin in self._allowed_cors_origins() + + def _send_cors_headers(self): + """Emit CORS headers only for exact configured origins.""" + origin = self.headers.get("Origin", "") + if origin and origin in self._allowed_cors_origins(): + self.send_header("Access-Control-Allow-Origin", origin) + self.send_header("Vary", "Origin") + + @staticmethod + def _allowed_cors_origins() -> set: + raw = os.environ.get("RUSTCHAIN_API_ALLOWED_ORIGINS", "") + return {origin.strip() for origin in raw.split(",") if origin.strip()} + def log_message(self, format, *args): """Suppress default logging""" pass diff --git a/tests/test_rustchain_core_api_cors_csrf.py b/tests/test_rustchain_core_api_cors_csrf.py new file mode 100644 index 000000000..f08cd4993 --- /dev/null +++ b/tests/test_rustchain_core_api_cors_csrf.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: MIT +""" +Regression tests for RustChain core API CORS and CSRF handling. + +The API module is loaded directly because the package path contains a hyphen. +""" + +import http.client +import importlib.util +import json +import os +import sys +import threading +from http.server import HTTPServer + + +def _load_rpc_namespace(): + rpc_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "rips", + "rustchain-core", + "api", + "rpc.py", + ) + module_name = "rustchain_core_rpc_api_under_test" + spec = importlib.util.spec_from_file_location(module_name, rpc_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module.__dict__ + + +RPC = _load_rpc_namespace() +ApiRequestHandler = RPC["ApiRequestHandler"] +RustChainApi = RPC["RustChainApi"] +MockNode = RPC["MockNode"] + + +class _ApiServerFixture: + def __enter__(self): + self.node = MockNode() + ApiRequestHandler.api = RustChainApi(self.node) + self.server = HTTPServer(("127.0.0.1", 0), ApiRequestHandler) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + self.url = f"http://127.0.0.1:{self.server.server_port}" + return self + + def __exit__(self, exc_type, exc, tb): + self.server.shutdown() + self.thread.join(timeout=5) + + def request(self, path, method="GET", body=None, headers=None): + connection = http.client.HTTPConnection( + "127.0.0.1", + self.server.server_port, + timeout=5, + ) + connection.request(method, path, body=body, headers=headers or {}) + response = connection.getresponse() + raw_body = response.read() + headers = dict(response.headers) + connection.close() + return response.status, headers, raw_body + + +def _json_body(raw_body): + return json.loads(raw_body.decode()) + + +def test_unconfigured_cors_does_not_emit_wildcard_origin(monkeypatch): + monkeypatch.delenv("RUSTCHAIN_API_ALLOWED_ORIGINS", raising=False) + + with _ApiServerFixture() as server: + status, headers, raw_body = server.request( + "/api/stats", + headers={"Origin": "https://evil.example"}, + ) + + assert status == 200 + assert _json_body(raw_body)["success"] is True + assert headers.get("Access-Control-Allow-Origin") is None + + +def test_configured_cors_echoes_only_exact_allowed_origin(monkeypatch): + monkeypatch.setenv("RUSTCHAIN_API_ALLOWED_ORIGINS", "https://app.rustchain.org") + + with _ApiServerFixture() as server: + status, headers, _ = server.request( + "/api/stats", + headers={"Origin": "https://app.rustchain.org"}, + ) + + assert status == 200 + assert headers["Access-Control-Allow-Origin"] == "https://app.rustchain.org" + assert headers["Access-Control-Allow-Origin"] != "*" + + +def test_state_changing_post_fails_closed_without_configured_csrf(monkeypatch): + monkeypatch.delenv("RUSTCHAIN_API_CSRF_TOKEN", raising=False) + + with _ApiServerFixture() as server: + status, _, raw_body = server.request( + "/api/mine", + method="POST", + body=b'{"wallet": "alice"}', + headers={"Content-Type": "application/json"}, + ) + + body = _json_body(raw_body) + assert status == 400 + assert body["success"] is False + assert body["error"] == "CSRF token is not configured" + + +def test_state_changing_rpc_requires_matching_csrf_token(monkeypatch): + monkeypatch.setenv("RUSTCHAIN_API_CSRF_TOKEN", "secret-token") + payload = json.dumps({ + "method": "submitProof", + "params": {"wallet": "alice"}, + }).encode() + + with _ApiServerFixture() as server: + bad_status, _, bad_body = server.request( + "/rpc", + method="POST", + body=payload, + headers={"Content-Type": "application/json", "X-CSRF-Token": "wrong"}, + ) + ok_status, _, ok_body = server.request( + "/rpc", + method="POST", + body=payload, + headers={"Content-Type": "application/json", "X-CSRF-Token": "secret-token"}, + ) + + assert bad_status == 400 + assert _json_body(bad_body)["error"] == "Invalid CSRF token" + assert ok_status == 200 + assert _json_body(ok_body)["success"] is True