Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 85 additions & 4 deletions rips/rustchain-core/api/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"""

import json
import hmac
import os
import time
from dataclasses import dataclass
from typing import Dict, Any, Optional, Callable
Expand Down Expand Up @@ -255,14 +257,24 @@ class ApiRequestHandler(BaseHTTPRequestHandler):
"""HTTP request handler for API"""

api: RustChainApi = None # Set by server
STATE_CHANGING_PATHS = {
"/api/mine",
"/api/governance/create",
"/api/governance/vote",
}
STATE_CHANGING_RPC_METHODS = {
"submitProof",
"createProposal",
"vote",
}

def do_GET(self):
"""Handle GET requests"""
parsed = urlparse(self.path)
path = parsed.path
params = {k: v[0] for k, v in parse_qs(parsed.query).items()}

response = self._route_request(path, params)
response = self._route_request(path, params, "GET")
self._send_response(response)

def do_POST(self):
Expand All @@ -276,11 +288,41 @@ def do_POST(self):
params = {}

parsed = urlparse(self.path)
response = self._route_request(parsed.path, params)
response = self._route_request(parsed.path, params, "POST")
self._send_response(response)

def _route_request(self, path: str, params: Dict[str, Any]) -> ApiResponse:
def do_OPTIONS(self):
"""Handle CORS preflight requests."""
if not self._is_origin_allowed():
self.send_response(403)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(ApiResponse(
success=False,
error="Origin not allowed",
).to_json().encode())
return

self.send_response(204)
self._send_cors_headers()
self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
self.send_header("Access-Control-Allow-Headers", "Content-Type, X-CSRF-Token")
self.send_header("Access-Control-Max-Age", "600")
self.end_headers()

def _route_request(self, path: str, params: Dict[str, Any], http_method: str) -> ApiResponse:
"""Route request to appropriate handler"""
if self._is_state_changing_request(path, params):
csrf_response = self._validate_csrf_token()
if csrf_response:
return csrf_response

if http_method != "POST":
return ApiResponse(
success=False,
error="State-changing endpoints require POST",
)

# REST endpoints
routes = {
"/api/stats": ("getStats", {}),
Expand Down Expand Up @@ -334,10 +376,49 @@ def _send_response(self, response: ApiResponse):
"""Send HTTP response"""
self.send_response(200 if response.success else 400)
self.send_header("Content-Type", "application/json")
self.send_header("Access-Control-Allow-Origin", "*")
self._send_cors_headers()
self.end_headers()
self.wfile.write(response.to_json().encode())

def _is_state_changing_request(self, path: str, params: Dict[str, Any]) -> bool:
"""Return whether a request can mutate node state."""
if path in self.STATE_CHANGING_PATHS:
return True
return (
path == "/rpc"
and params.get("method", "") in self.STATE_CHANGING_RPC_METHODS
)

def _validate_csrf_token(self) -> Optional[ApiResponse]:
"""Require an explicit CSRF token for state-changing operations."""
expected = os.environ.get("RUSTCHAIN_API_CSRF_TOKEN", "")
if not expected:
return ApiResponse(success=False, error="CSRF token is not configured")

supplied = self.headers.get("X-CSRF-Token", "")
if not hmac.compare_digest(supplied, expected):
return ApiResponse(success=False, error="Invalid CSRF token")
return None

def _is_origin_allowed(self) -> bool:
"""Return whether the request Origin is configured for CORS access."""
origin = self.headers.get("Origin", "")
if not origin:
return True
return origin in self._allowed_cors_origins()

def _send_cors_headers(self):
"""Emit CORS headers only for exact configured origins."""
origin = self.headers.get("Origin", "")
if origin and origin in self._allowed_cors_origins():
self.send_header("Access-Control-Allow-Origin", origin)
self.send_header("Vary", "Origin")

@staticmethod
def _allowed_cors_origins() -> set:
raw = os.environ.get("RUSTCHAIN_API_ALLOWED_ORIGINS", "")
return {origin.strip() for origin in raw.split(",") if origin.strip()}

def log_message(self, format, *args):
"""Suppress default logging"""
pass
Expand Down
140 changes: 140 additions & 0 deletions tests/test_rustchain_core_api_cors_csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: MIT
"""
Regression tests for RustChain core API CORS and CSRF handling.

The API module is loaded directly because the package path contains a hyphen.
"""

import http.client
import importlib.util
import json
import os
import sys
import threading
from http.server import HTTPServer


def _load_rpc_namespace():
rpc_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"rips",
"rustchain-core",
"api",
"rpc.py",
)
module_name = "rustchain_core_rpc_api_under_test"
spec = importlib.util.spec_from_file_location(module_name, rpc_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module.__dict__


RPC = _load_rpc_namespace()
ApiRequestHandler = RPC["ApiRequestHandler"]
RustChainApi = RPC["RustChainApi"]
MockNode = RPC["MockNode"]


class _ApiServerFixture:
def __enter__(self):
self.node = MockNode()
ApiRequestHandler.api = RustChainApi(self.node)
self.server = HTTPServer(("127.0.0.1", 0), ApiRequestHandler)
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.thread.start()
self.url = f"http://127.0.0.1:{self.server.server_port}"
return self

def __exit__(self, exc_type, exc, tb):
self.server.shutdown()
self.thread.join(timeout=5)

def request(self, path, method="GET", body=None, headers=None):
connection = http.client.HTTPConnection(
"127.0.0.1",
self.server.server_port,
timeout=5,
)
connection.request(method, path, body=body, headers=headers or {})
response = connection.getresponse()
raw_body = response.read()
headers = dict(response.headers)
connection.close()
return response.status, headers, raw_body


def _json_body(raw_body):
return json.loads(raw_body.decode())


def test_unconfigured_cors_does_not_emit_wildcard_origin(monkeypatch):
monkeypatch.delenv("RUSTCHAIN_API_ALLOWED_ORIGINS", raising=False)

with _ApiServerFixture() as server:
status, headers, raw_body = server.request(
"/api/stats",
headers={"Origin": "https://evil.example"},
)

assert status == 200
assert _json_body(raw_body)["success"] is True
assert headers.get("Access-Control-Allow-Origin") is None


def test_configured_cors_echoes_only_exact_allowed_origin(monkeypatch):
monkeypatch.setenv("RUSTCHAIN_API_ALLOWED_ORIGINS", "https://app.rustchain.org")

with _ApiServerFixture() as server:
status, headers, _ = server.request(
"/api/stats",
headers={"Origin": "https://app.rustchain.org"},
)

assert status == 200
assert headers["Access-Control-Allow-Origin"] == "https://app.rustchain.org"
assert headers["Access-Control-Allow-Origin"] != "*"


def test_state_changing_post_fails_closed_without_configured_csrf(monkeypatch):
monkeypatch.delenv("RUSTCHAIN_API_CSRF_TOKEN", raising=False)

with _ApiServerFixture() as server:
status, _, raw_body = server.request(
"/api/mine",
method="POST",
body=b'{"wallet": "alice"}',
headers={"Content-Type": "application/json"},
)

body = _json_body(raw_body)
assert status == 400
assert body["success"] is False
assert body["error"] == "CSRF token is not configured"


def test_state_changing_rpc_requires_matching_csrf_token(monkeypatch):
monkeypatch.setenv("RUSTCHAIN_API_CSRF_TOKEN", "secret-token")
payload = json.dumps({
"method": "submitProof",
"params": {"wallet": "alice"},
}).encode()

with _ApiServerFixture() as server:
bad_status, _, bad_body = server.request(
"/rpc",
method="POST",
body=payload,
headers={"Content-Type": "application/json", "X-CSRF-Token": "wrong"},
)
ok_status, _, ok_body = server.request(
"/rpc",
method="POST",
body=payload,
headers={"Content-Type": "application/json", "X-CSRF-Token": "secret-token"},
)

assert bad_status == 400
assert _json_body(bad_body)["error"] == "Invalid CSRF token"
assert ok_status == 200
assert _json_body(ok_body)["success"] is True
Loading