diff --git a/packages/testing/src/consensus_testing/__init__.py b/packages/testing/src/consensus_testing/__init__.py index 8f476035..0108c98a 100644 --- a/packages/testing/src/consensus_testing/__init__.py +++ b/packages/testing/src/consensus_testing/__init__.py @@ -7,7 +7,6 @@ from .test_fixtures import ( ApiEndpointTest, BaseConsensusFixture, - DiscoveryCryptoTest, ForkChoiceTest, GossipsubHandlerTest, JustifiabilityTest, @@ -44,7 +43,6 @@ GossipsubHandlerTestFiller = Type[GossipsubHandlerTest] ApiEndpointTestFiller = Type[ApiEndpointTest] SlotClockTestFiller = Type[SlotClockTest] -DiscoveryCryptoTestFiller = Type[DiscoveryCryptoTest] JustifiabilityTestFiller = Type[JustifiabilityTest] PoseidonPermutationTestFiller = Type[PoseidonPermutationTest] SyncTestFiller = Type[SyncTest] @@ -69,7 +67,6 @@ "GossipsubHandlerTest", "ApiEndpointTest", "SlotClockTest", - "DiscoveryCryptoTest", "JustifiabilityTest", "PoseidonPermutationTest", "SyncTest", @@ -93,7 +90,6 @@ "GossipsubHandlerTestFiller", "ApiEndpointTestFiller", "SlotClockTestFiller", - "DiscoveryCryptoTestFiller", "JustifiabilityTestFiller", "PoseidonPermutationTestFiller", "SyncTestFiller", diff --git a/packages/testing/src/consensus_testing/test_fixtures/__init__.py b/packages/testing/src/consensus_testing/test_fixtures/__init__.py index 9bd03d63..452617fc 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/__init__.py +++ b/packages/testing/src/consensus_testing/test_fixtures/__init__.py @@ -2,7 +2,6 @@ from .api_endpoint import ApiEndpointTest from .base import BaseConsensusFixture -from .discovery_crypto import DiscoveryCryptoTest from .fork_choice import ForkChoiceTest from .gossipsub_handler import GossipsubHandlerTest from .justifiability import JustifiabilityTest @@ -24,7 +23,6 @@ "GossipsubHandlerTest", "ApiEndpointTest", "SlotClockTest", - "DiscoveryCryptoTest", "JustifiabilityTest", "PoseidonPermutationTest", "SyncTest", diff --git a/packages/testing/src/consensus_testing/test_fixtures/discovery_crypto.py b/packages/testing/src/consensus_testing/test_fixtures/discovery_crypto.py deleted file mode 100644 index 363fb22a..00000000 --- a/packages/testing/src/consensus_testing/test_fixtures/discovery_crypto.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Discovery v5 cryptographic primitive test fixture. - -Generates JSON test vectors for the cryptographic operations used in -Discovery v5 peer discovery. All clients must produce identical outputs -for ECDH, key derivation, signing, and encryption to interoperate. -""" - -from typing import Any, ClassVar - -from lean_spec.subspecs.networking.discovery.crypto import ( - aes_gcm_decrypt, - aes_gcm_encrypt, - ecdh_agree, - sign_id_nonce, - verify_id_nonce_signature, -) -from lean_spec.subspecs.networking.discovery.keys import compute_node_id, derive_keys -from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes33, Bytes64, Bytes65 - -from .base import BaseConsensusFixture - - -def _to_hex(data: bytes) -> str: - return "0x" + data.hex() - - -def _from_hex(hex_str: str) -> bytes: - return bytes.fromhex(hex_str.removeprefix("0x")) - - -class DiscoveryCryptoTest(BaseConsensusFixture): - """Fixture for Discovery v5 cryptographic conformance. - - Tests deterministic crypto operations: ECDH, HKDF key derivation, - ID nonce signing/verification, AES-GCM, and node ID computation. - - JSON output: operation, input, output. - """ - - format_name: ClassVar[str] = "discovery_crypto" - description: ClassVar[str] = "Tests Discovery v5 cryptographic primitives" - - operation: str - """Crypto operation: ecdh, key_derivation, id_nonce_sign, - id_nonce_verify, aes_gcm_encrypt, or node_id.""" - - input: dict[str, Any] - """Operation-specific input parameters (hex-encoded bytes).""" - - output: dict[str, Any] = {} - """Computed output. Filled by make_fixture.""" - - def make_fixture(self) -> "DiscoveryCryptoTest": - """Dispatch to the operation handler and produce computed output.""" - match self.operation: - case "ecdh": - output = self._make_ecdh() - case "key_derivation": - output = self._make_key_derivation() - case "id_nonce_sign": - output = self._make_id_nonce_sign() - case "id_nonce_verify": - output = self._make_id_nonce_verify() - case "aes_gcm_encrypt": - output = self._make_aes_gcm_encrypt() - case "node_id": - output = self._make_node_id() - case _: - raise ValueError(f"Unknown operation: {self.operation}") - return self.model_copy(update={"output": output}) - - def _make_ecdh(self) -> dict[str, Any]: - """Compute ECDH shared secret from private key and public key.""" - private_key = Bytes32(_from_hex(self.input["privateKey"])) - public_key_raw = _from_hex(self.input["publicKey"]) - public_key: Bytes33 | Bytes65 - if len(public_key_raw) == 33: - public_key = Bytes33(public_key_raw) - else: - public_key = Bytes65(public_key_raw) - - shared_secret = ecdh_agree(private_key, public_key) - return {"sharedSecret": _to_hex(shared_secret)} - - def _make_key_derivation(self) -> dict[str, Any]: - """Derive session keys via HKDF from shared secret and IDs.""" - secret = Bytes33(_from_hex(self.input["sharedSecret"])) - initiator_id = Bytes32(_from_hex(self.input["initiatorId"])) - recipient_id = Bytes32(_from_hex(self.input["recipientId"])) - challenge_data = _from_hex(self.input["challengeData"]) - - initiator_key, recipient_key = derive_keys( - secret, initiator_id, recipient_id, challenge_data - ) - return { - "initiatorKey": _to_hex(initiator_key), - "recipientKey": _to_hex(recipient_key), - } - - def _make_id_nonce_sign(self) -> dict[str, Any]: - """Sign an ID nonce for handshake authentication.""" - private_key = Bytes32(_from_hex(self.input["privateKey"])) - challenge_data = _from_hex(self.input["challengeData"]) - ephemeral_pubkey = Bytes33(_from_hex(self.input["ephemeralPubkey"])) - dest_node_id = Bytes32(_from_hex(self.input["destNodeId"])) - - signature = sign_id_nonce(private_key, challenge_data, ephemeral_pubkey, dest_node_id) - - # Cross-check: signature must verify against the signer's public key. - from cryptography.hazmat.primitives import serialization - from cryptography.hazmat.primitives.asymmetric import ec - - priv = ec.derive_private_key(int.from_bytes(private_key, "big"), ec.SECP256K1()) - pubkey = Bytes33( - priv.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, - ) - ) - assert verify_id_nonce_signature( - signature, challenge_data, ephemeral_pubkey, dest_node_id, pubkey - ), "Signature failed self-verification" - - return {"signature": _to_hex(signature)} - - def _make_id_nonce_verify(self) -> dict[str, Any]: - """Verify an ID nonce signature.""" - signature = Bytes64(_from_hex(self.input["signature"])) - challenge_data = _from_hex(self.input["challengeData"]) - ephemeral_pubkey = Bytes33(_from_hex(self.input["ephemeralPubkey"])) - dest_node_id = Bytes32(_from_hex(self.input["destNodeId"])) - public_key = Bytes33(_from_hex(self.input["publicKey"])) - - valid = verify_id_nonce_signature( - signature, challenge_data, ephemeral_pubkey, dest_node_id, public_key - ) - return {"valid": valid} - - def _make_aes_gcm_encrypt(self) -> dict[str, Any]: - """Encrypt with AES-128-GCM, verify roundtrip via decrypt.""" - key = Bytes16(_from_hex(self.input["key"])) - nonce = Bytes12(_from_hex(self.input["nonce"])) - plaintext = _from_hex(self.input["plaintext"]) - aad = _from_hex(self.input["aad"]) - - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) - - # Roundtrip: decrypt must recover original plaintext. - decrypted = aes_gcm_decrypt(key, nonce, ciphertext, aad) - assert decrypted == plaintext, "AES-GCM roundtrip produced different bytes" - - return {"ciphertext": _to_hex(ciphertext)} - - def _make_node_id(self) -> dict[str, Any]: - """Compute node ID from public key via keccak256.""" - public_key_raw = _from_hex(self.input["publicKey"]) - public_key: Bytes33 | Bytes65 - if len(public_key_raw) == 33: - public_key = Bytes33(public_key_raw) - else: - public_key = Bytes65(public_key_raw) - - node_id = compute_node_id(public_key) - return {"nodeId": _to_hex(node_id)} diff --git a/packages/testing/src/consensus_testing/test_fixtures/networking_codec.py b/packages/testing/src/consensus_testing/test_fixtures/networking_codec.py index a213fa49..cc2a0c85 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/networking_codec.py +++ b/packages/testing/src/consensus_testing/test_fixtures/networking_codec.py @@ -3,34 +3,6 @@ from typing import Any, ClassVar from lean_spec.snappy import compress, decompress, frame_compress, frame_decompress -from lean_spec.subspecs.networking.discovery.codec import decode_message, encode_message -from lean_spec.subspecs.networking.discovery.messages import ( - Distance, - FindNode, - IdNonce, - IPv4, - IPv6, - Nodes, - Nonce, - PacketFlag, - Ping, - Pong, - Port, - RequestId, - TalkReq, - TalkResp, -) -from lean_spec.subspecs.networking.discovery.packet import ( - decode_handshake_authdata, - decode_message_authdata, - decode_packet_header, - decode_whoareyou_authdata, - encode_handshake_authdata, - encode_message_authdata, - encode_packet, - encode_whoareyou_authdata, -) -from lean_spec.subspecs.networking.discovery.routing import log2_distance, xor_distance from lean_spec.subspecs.networking.enr.enr import ENR from lean_spec.subspecs.networking.gossipsub.message import GossipsubMessage from lean_spec.subspecs.networking.gossipsub.rpc import ( @@ -53,9 +25,8 @@ encode_request, ) from lean_spec.subspecs.networking.transport.peer_id import KeyType, PeerId, PublicKeyProto -from lean_spec.subspecs.networking.types import NodeId, SeqNumber from lean_spec.subspecs.networking.varint import decode_varint, encode_varint -from lean_spec.types import Bytes16, Bytes33, Bytes64, SubnetId +from lean_spec.types import SubnetId from .base import BaseConsensusFixture @@ -118,18 +89,10 @@ def make_fixture(self) -> "NetworkingCodecTest": output = self._make_enr() case "peer_id": output = self._make_peer_id() - case "discv5_message": - output = self._make_discv5_message() - case "discv5_packet": - output = self._make_discv5_packet() case "snappy_block": output = self._make_snappy_block() case "snappy_frame": output = self._make_snappy_frame() - case "xor_distance": - output = self._make_xor_distance() - case "log2_distance": - output = self._make_log2_distance() case "decode_failure": output = self._make_decode_failure() case _: @@ -146,8 +109,7 @@ def _make_decode_failure(self) -> dict[str, Any]: The input record carries two fields: - `decoder`: name of the target decoder (`varint`, `snappy_frame`, - `gossipsub_rpc`, `reqresp_request`, `reqresp_response`, - `discv5_message`, `enr`). + `gossipsub_rpc`, `reqresp_request`, `reqresp_response`, `enr`). - `bytes`: hex-encoded malformed input. Returns: @@ -173,11 +135,6 @@ def _make_decode_failure(self) -> dict[str, Any]: "gossipsub_rpc": RPC.decode, "reqresp_request": decode_request, "reqresp_response": ResponseCode.decode, - "discv5_message": decode_message, - "discv5_packet": lambda raw: decode_packet_header( - NodeId(_from_hex(self.input.get("localNodeId", "0x" + "00" * 32))), - raw, - ), "enr": ENR.from_rlp, } if decoder_name not in decoders: @@ -357,20 +314,6 @@ def _make_snappy_frame(self) -> dict[str, Any]: "uncompressedLength": len(data), } - def _make_xor_distance(self) -> dict[str, Any]: - """Compute XOR distance between two node IDs.""" - node_a = NodeId(_from_hex(self.input["nodeA"])) - node_b = NodeId(_from_hex(self.input["nodeB"])) - distance = xor_distance(node_a, node_b) - return {"distance": hex(distance)} - - def _make_log2_distance(self) -> dict[str, Any]: - """Compute log2 of XOR distance for k-bucket assignment.""" - node_a = NodeId(_from_hex(self.input["nodeA"])) - node_b = NodeId(_from_hex(self.input["nodeB"])) - distance = int(log2_distance(node_a, node_b)) - return {"distance": distance} - def _make_enr(self) -> dict[str, Any]: """Parse an ENR string, re-serialize, assert roundtrip, extract properties.""" enr_string = self.input["enrString"] @@ -449,149 +392,6 @@ def _make_peer_id(self) -> dict[str, Any]: "peerId": peer_id_str, } - def _make_discv5_message(self) -> dict[str, Any]: - """Encode a discv5 message as type byte + RLP, decode it back, assert roundtrip.""" - msg = _build_discv5_message(self.input) - encoded = encode_message(msg) - - # Decode and re-encode must produce identical bytes. - re_encoded = encode_message(decode_message(encoded)) - assert encoded == re_encoded, "Discv5 message roundtrip produced different bytes" - - return {"encoded": _to_hex(encoded)} - - def _make_discv5_packet(self) -> dict[str, Any]: - """Encode a Discovery v5 packet and roundtrip-decode the header. - - Input keys (all hex unless noted): - - - `packetType`: "message", "whoareyou", or "handshake". - - `destNodeId`: 32-byte destination node ID. Masking key - derives from its first 16 bytes; clients also use this as the - local node id when decoding. - - `nonce`: 12-byte message nonce. - - `maskingIv`: 16-byte header-masking IV, supplied explicitly - so the produced bytes are deterministic. - - `message`: message payload (empty for WHOAREYOU, otherwise - the already-encrypted ciphertext bytes). - - `encryptionKey`: 16-byte AES-GCM key for non-WHOAREYOU. - - Packet-type-specific input keys: - - - message: `srcId`. - - whoareyou: `idNonce`, `enrSeq` (uint64 integer). - - handshake: `srcId`, `idSignature`, `ephPubkey`, optional - `record` (RLP-encoded ENR). - - Output: - - - `encoded`: full packet hex. - - `flag`: numeric flag (0/1/2) recovered via decode. - - `authdataSize`: size of authdata in bytes. - """ - packet_type = self.input["packetType"] - dest_node_id = NodeId(_from_hex(self.input["destNodeId"])) - nonce = Nonce(_from_hex(self.input["nonce"])) - masking_iv = Bytes16(_from_hex(self.input["maskingIv"])) - message_bytes = _from_hex(self.input.get("message", "0x")) - - if packet_type == "message": - flag = PacketFlag.MESSAGE - authdata = encode_message_authdata(NodeId(_from_hex(self.input["srcId"]))) - encryption_key: Bytes16 | None = Bytes16(_from_hex(self.input["encryptionKey"])) - elif packet_type == "whoareyou": - flag = PacketFlag.WHOAREYOU - authdata = encode_whoareyou_authdata( - IdNonce(_from_hex(self.input["idNonce"])), - SeqNumber(int(self.input["enrSeq"])), - ) - encryption_key = None - elif packet_type == "handshake": - flag = PacketFlag.HANDSHAKE - record = _from_hex(self.input["record"]) if self.input.get("record") else None - authdata = encode_handshake_authdata( - NodeId(_from_hex(self.input["srcId"])), - Bytes64(_from_hex(self.input["idSignature"])), - Bytes33(_from_hex(self.input["ephPubkey"])), - record=record, - ) - encryption_key = Bytes16(_from_hex(self.input["encryptionKey"])) - else: - raise ValueError(f"Unknown discv5 packet type: {packet_type!r}") - - encoded = encode_packet( - dest_node_id=dest_node_id, - flag=flag, - nonce=nonce, - authdata=authdata, - message=message_bytes, - encryption_key=encryption_key, - masking_iv=masking_iv, - ) - - # Roundtrip: decode header and assert shape matches. - header, _msg, _ad = decode_packet_header(dest_node_id, encoded) - assert header.flag == flag, "Packet flag roundtrip mismatch" - assert header.nonce == nonce, "Packet nonce roundtrip mismatch" - assert header.authdata == authdata, "Packet authdata roundtrip mismatch" - - # Exercise per-type authdata decode for extra shape coverage. - if flag == PacketFlag.MESSAGE: - decode_message_authdata(header.authdata) - elif flag == PacketFlag.WHOAREYOU: - decode_whoareyou_authdata(header.authdata) - elif flag == PacketFlag.HANDSHAKE: - decode_handshake_authdata(header.authdata) - - return { - "encoded": _to_hex(encoded), - "flag": int(flag), - "authdataSize": len(authdata), - } - - -def _build_discv5_message( - d: dict[str, Any], -) -> Ping | Pong | FindNode | Nodes | TalkReq | TalkResp: - """Build a discv5 message dataclass from a JSON-friendly dict.""" - request_id = RequestId(data=_from_hex(d["requestId"])) - match d["type"]: - case "ping": - return Ping(request_id=request_id, enr_seq=SeqNumber(d["enrSeq"])) - case "pong": - ip_bytes = _from_hex(d["recipientIp"]) - ip = IPv4(ip_bytes) if len(ip_bytes) == 4 else IPv6(ip_bytes) - return Pong( - request_id=request_id, - enr_seq=SeqNumber(d["enrSeq"]), - recipient_ip=ip, - recipient_port=Port(d["recipientPort"]), - ) - case "findnode": - return FindNode( - request_id=request_id, - distances=[Distance(x) for x in d["distances"]], - ) - case "nodes": - return Nodes( - request_id=request_id, - total=d["total"], - enrs=[_from_hex(e) for e in d.get("enrs", [])], - ) - case "talkreq": - return TalkReq( - request_id=request_id, - protocol=_from_hex(d["protocol"]), - request=_from_hex(d["request"]), - ) - case "talkresp": - return TalkResp( - request_id=request_id, - response=_from_hex(d["response"]), - ) - case _: - raise ValueError(f"Unknown discv5 message type: {d['type']}") - def _build_rpc(d: dict[str, Any]) -> RPC: """Build an RPC from a JSON-friendly dict.""" diff --git a/src/lean_spec/subspecs/networking/discovery/__init__.py b/src/lean_spec/subspecs/networking/discovery/__init__.py deleted file mode 100644 index d7a2d9cd..00000000 --- a/src/lean_spec/subspecs/networking/discovery/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Discovery v5 Protocol Specification - -Node Discovery Protocol v5.1 for finding peers in Ethereum networks. - -The module provides: -- Wire protocol encoding/decoding -- Cryptographic primitives (AES-CTR/GCM, secp256k1 ECDH) -- Session and handshake management -- UDP transport layer -- High-level discovery service - -References: - - https://github.com/ethereum/devp2p/blob/master/discv5/discv5.md - - https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md - - https://github.com/ethereum/devp2p/blob/master/discv5/discv5-theory.md -""" diff --git a/src/lean_spec/subspecs/networking/discovery/codec.py b/src/lean_spec/subspecs/networking/discovery/codec.py deleted file mode 100644 index 1d39b0d9..00000000 --- a/src/lean_spec/subspecs/networking/discovery/codec.py +++ /dev/null @@ -1,320 +0,0 @@ -""" -Message codec for Discovery v5. - -Protocol messages are encoded as:: - - message-pt = message-type || message-data - message-data = RLP([field1, field2, ...]) - -Message types: -- PING (0x01): [request-id, enr-seq] -- PONG (0x02): [request-id, enr-seq, recipient-ip, recipient-port] -- FINDNODE (0x03): [request-id, [distances...]] -- NODES (0x04): [request-id, total, [ENRs...]] -- TALKREQ (0x05): [request-id, protocol, request] -- TALKRESP (0x06): [request-id, response] - -References: -- https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md#protocol-messages -""" - -from __future__ import annotations - -from lean_spec.subspecs.networking.types import SeqNumber -from lean_spec.types import ( - RLPDecodingError, - RLPItem, - Uint8, - Uint64, - decode_rlp, - decode_rlp_list, - encode_rlp, -) - -from .messages import ( - Distance, - FindNode, - IPv4, - IPv6, - MessageType, - Nodes, - Ping, - Pong, - Port, - RequestId, - TalkReq, - TalkResp, -) - -type DiscoveryMessage = Ping | Pong | FindNode | Nodes | TalkReq | TalkResp -"""Union of all Discovery v5 protocol messages.""" - - -class MessageEncodingError(Exception): - """Error encoding a Discovery v5 message.""" - - -class MessageDecodingError(Exception): - """Error decoding a Discovery v5 message.""" - - -def encode_message(msg: DiscoveryMessage) -> bytes: - """ - Encode a protocol message to bytes. - - Format: message-type (1 byte) || RLP(message-data) - - Args: - msg: Protocol message to encode. - - Returns: - Encoded message bytes. - """ - match msg: - case Ping(): - return _encode_ping(msg) - case Pong(): - return _encode_pong(msg) - case FindNode(): - return _encode_findnode(msg) - case Nodes(): - return _encode_nodes(msg) - case TalkReq(): - return _encode_talkreq(msg) - case TalkResp(): - return _encode_talkresp(msg) - case _: - raise MessageEncodingError(f"Unknown message type: {type(msg).__name__}") - - -def decode_message(data: bytes) -> DiscoveryMessage: - """ - Decode a protocol message from bytes. - - Args: - data: Encoded message bytes. - - Returns: - Decoded protocol message. - - Raises: - MessageDecodingError: If message is malformed or unknown type. - """ - if len(data) < 2: - raise MessageDecodingError("Message too short") - - msg_type = data[0] - payload = data[1:] - - try: - match msg_type: - case MessageType.PING: - return _decode_ping(payload) - case MessageType.PONG: - return _decode_pong(payload) - case MessageType.FINDNODE: - return _decode_findnode(payload) - case MessageType.NODES: - return _decode_nodes(payload) - case MessageType.TALKREQ: - return _decode_talkreq(payload) - case MessageType.TALKRESP: - return _decode_talkresp(payload) - case _: - raise MessageDecodingError(f"Unknown message type: {msg_type:#x}") - except RLPDecodingError as e: - raise MessageDecodingError(f"Invalid RLP: {e}") from e - except (IndexError, ValueError) as e: - raise MessageDecodingError(f"Invalid message format: {e}") from e - - -def _encode_request_id(request_id: RequestId) -> bytes: - """Encode request ID to minimal bytes.""" - data = bytes(request_id) - return data.lstrip(b"\x00") or b"\x00" - - -def _decode_request_id(data: bytes) -> RequestId: - """Decode request ID from bytes.""" - if len(data) > 8: - raise ValueError(f"Request ID too long: {len(data)} > 8") - return RequestId(data=data) - - -def _encode_uint64(value: Uint64) -> bytes: - """Encode Uint64 to minimal big-endian bytes.""" - if int(value) == 0: - return b"" - return int(value).to_bytes((int(value).bit_length() + 7) // 8, "big") - - -def _decode_uint64(data: bytes) -> Uint64: - """Decode Uint64 from big-endian bytes.""" - if len(data) == 0: - return Uint64(0) - return Uint64(int.from_bytes(data, "big")) - - -def _encode_ping(msg: Ping) -> bytes: - """Encode PING message.""" - items = [ - _encode_request_id(msg.request_id), - _encode_uint64(msg.enr_seq), - ] - return bytes([MessageType.PING]) + encode_rlp(items) - - -def _decode_ping(payload: bytes) -> Ping: - """Decode PING message.""" - items = decode_rlp_list(payload) - if len(items) != 2: - raise MessageDecodingError("PING requires 2 elements") - - return Ping( - request_id=_decode_request_id(items[0]), - enr_seq=SeqNumber(_decode_uint64(items[1])), - ) - - -def _encode_pong(msg: Pong) -> bytes: - """Encode PONG message.""" - items = [ - _encode_request_id(msg.request_id), - _encode_uint64(msg.enr_seq), - msg.recipient_ip, - int(msg.recipient_port).to_bytes(2, "big") if int(msg.recipient_port) > 0 else b"", - ] - return bytes([MessageType.PONG]) + encode_rlp(items) - - -def _decode_pong(payload: bytes) -> Pong: - """Decode PONG message.""" - items = decode_rlp_list(payload) - if len(items) != 4: - raise MessageDecodingError("PONG requires 4 elements") - - port = int.from_bytes(items[3], "big") if items[3] else 0 - ip_bytes = items[2] - recipient_ip = IPv4(ip_bytes) if len(ip_bytes) == IPv4.LENGTH else IPv6(ip_bytes) - - return Pong( - request_id=_decode_request_id(items[0]), - enr_seq=SeqNumber(_decode_uint64(items[1])), - recipient_ip=recipient_ip, - recipient_port=Port(port), - ) - - -def _encode_findnode(msg: FindNode) -> bytes: - """Encode FINDNODE message.""" - distance_items = [_encode_uint64(Uint64(int(d))) for d in msg.distances] - items = [ - _encode_request_id(msg.request_id), - distance_items, - ] - return bytes([MessageType.FINDNODE]) + encode_rlp(items) - - -def _decode_findnode(payload: bytes) -> FindNode: - """Decode FINDNODE message.""" - items = decode_rlp(payload) - if not isinstance(items, list) or len(items) != 2: - raise MessageDecodingError("FINDNODE requires 2 elements") - - request_id_raw = items[0] - if not isinstance(request_id_raw, bytes): - raise MessageDecodingError("FINDNODE request-id must be bytes") - - distances_raw = items[1] - if not isinstance(distances_raw, list): - raise MessageDecodingError("FINDNODE distances must be a list") - - distances = [Distance(int.from_bytes(d, "big") if d else 0) for d in distances_raw] - - return FindNode( - request_id=_decode_request_id(request_id_raw), - distances=distances, - ) - - -def _encode_nodes(msg: Nodes) -> bytes: - """Encode NODES message.""" - enrs: list[RLPItem] = list(msg.enrs) - items: list[RLPItem] = [ - _encode_request_id(msg.request_id), - bytes([int(msg.total)]) if int(msg.total) > 0 else b"", - enrs, - ] - return bytes([MessageType.NODES]) + encode_rlp(items) - - -def _decode_nodes(payload: bytes) -> Nodes: - """Decode NODES message.""" - items = decode_rlp(payload) - if not isinstance(items, list) or len(items) != 3: - raise MessageDecodingError("NODES requires 3 elements") - - request_id_raw = items[0] - if not isinstance(request_id_raw, bytes): - raise MessageDecodingError("NODES request-id must be bytes") - - total_raw = items[1] - if not isinstance(total_raw, bytes): - raise MessageDecodingError("NODES total must be bytes") - total = total_raw[0] if total_raw else 0 - - enrs_raw = items[2] - if not isinstance(enrs_raw, list): - raise MessageDecodingError("NODES enrs must be a list") - - enrs = [e if isinstance(e, bytes) else b"" for e in enrs_raw] - - return Nodes( - request_id=_decode_request_id(request_id_raw), - total=Uint8(total), - enrs=enrs, - ) - - -def _encode_talkreq(msg: TalkReq) -> bytes: - """Encode TALKREQ message.""" - items = [ - _encode_request_id(msg.request_id), - msg.protocol, - msg.request, - ] - return bytes([MessageType.TALKREQ]) + encode_rlp(items) - - -def _decode_talkreq(payload: bytes) -> TalkReq: - """Decode TALKREQ message.""" - items = decode_rlp_list(payload) - if len(items) != 3: - raise MessageDecodingError("TALKREQ requires 3 elements") - - return TalkReq( - request_id=_decode_request_id(items[0]), - protocol=items[1], - request=items[2], - ) - - -def _encode_talkresp(msg: TalkResp) -> bytes: - """Encode TALKRESP message.""" - items = [ - _encode_request_id(msg.request_id), - msg.response, - ] - return bytes([MessageType.TALKRESP]) + encode_rlp(items) - - -def _decode_talkresp(payload: bytes) -> TalkResp: - """Decode TALKRESP message.""" - items = decode_rlp_list(payload) - if len(items) != 2: - raise MessageDecodingError("TALKRESP requires 2 elements") - - return TalkResp( - request_id=_decode_request_id(items[0]), - response=items[1], - ) diff --git a/src/lean_spec/subspecs/networking/discovery/config.py b/src/lean_spec/subspecs/networking/discovery/config.py deleted file mode 100644 index 745eb27f..00000000 --- a/src/lean_spec/subspecs/networking/discovery/config.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Discovery v5 Configuration - -Protocol constants and configuration for Node Discovery Protocol v5.1. - -References: - - https://github.com/ethereum/devp2p/blob/master/discv5/discv5-theory.md -""" - -from __future__ import annotations - -from typing import Final - -from lean_spec.subspecs.networking.types import Port -from lean_spec.types import StrictBaseModel - -K_BUCKET_SIZE: Final = 16 -"""Nodes per k-bucket. Standard Kademlia value balancing table size and lookup efficiency.""" - -ALPHA: Final = 3 -"""Concurrent queries during lookup. Balances speed against network load.""" - -BUCKET_COUNT: Final = 256 -"""Total k-buckets. One per bit of the 256-bit node ID space.""" - -REQUEST_TIMEOUT_SECS: Final = 0.5 -"""Single request timeout. Spec recommends 500ms for request/response.""" - -HANDSHAKE_TIMEOUT_SECS: Final = 1.0 -"""Handshake completion timeout. Spec recommends 1s for full handshake.""" - -MAX_NODES_RESPONSE: Final = 16 -"""Max ENRs per NODES message. Keeps responses under 1280 byte UDP limit.""" - -BOND_EXPIRY_SECS: Final = 86400 -"""Liveness revalidation interval. 24 hours before re-checking a node.""" - -MAX_PACKET_SIZE: Final = 1280 -"""Maximum UDP packet size in bytes.""" - -MIN_PACKET_SIZE: Final = 63 -"""Minimum valid packet size in bytes.""" - -DEFAULT_PORT: Final = Port(0) -"""Default port value for optional port parameters.""" - - -class DiscoveryConfig(StrictBaseModel): - """Runtime configuration for Discovery v5.""" - - k_bucket_size: int = K_BUCKET_SIZE - """Maximum nodes stored per k-bucket in the routing table.""" - - alpha: int = ALPHA - """Number of concurrent FINDNODE queries during lookup.""" - - request_timeout_secs: float = REQUEST_TIMEOUT_SECS - """Timeout for a single request/response exchange.""" - - handshake_timeout_secs: float = HANDSHAKE_TIMEOUT_SECS - """Timeout for completing the full handshake sequence.""" - - max_nodes_response: int = MAX_NODES_RESPONSE - """Maximum ENR records returned in a single NODES response.""" - - bond_expiry_secs: int = BOND_EXPIRY_SECS - """Seconds before a bonded node requires liveness revalidation.""" diff --git a/src/lean_spec/subspecs/networking/discovery/crypto.py b/src/lean_spec/subspecs/networking/discovery/crypto.py deleted file mode 100644 index 900828a4..00000000 --- a/src/lean_spec/subspecs/networking/discovery/crypto.py +++ /dev/null @@ -1,406 +0,0 @@ -""" -Cryptographic primitives for Discovery v5. - -Discovery v5 uses: -- AES-128-CTR for header masking -- AES-128-GCM for message encryption -- secp256k1 ECDH for key agreement (NOT X25519 like Noise) -- SHA256 for hashing and key derivation - -Wire format notes: -- Header masking key: first 16 bytes of destination node ID -- Header masking IV: random 16 bytes included in packet -- Message encryption uses 12-byte nonce (from packet header) -- GCM tag is 16 bytes, appended to ciphertext - -References: -- https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md -""" - -from __future__ import annotations - -import hashlib -from typing import Final - -from cryptography.exceptions import InvalidSignature -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric.utils import ( - Prehashed, - decode_dss_signature, - encode_dss_signature, -) -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.primitives.ciphers.aead import AESGCM - -from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes33, Bytes64, Bytes65 - -COMPRESSED_PUBKEY_SIZE: Final = 33 -"""Compressed secp256k1 public key: 0x02/0x03 + 32-byte x coordinate.""" - -UNCOMPRESSED_PUBKEY_SIZE: Final = 65 -"""Uncompressed secp256k1 public key: 0x04 + 32-byte x + 32-byte y.""" - -AES_KEY_SIZE: Final = 16 -"""AES-128 key size in bytes.""" - -GCM_NONCE_SIZE: Final = 12 -"""AES-GCM nonce size in bytes.""" - -GCM_TAG_SIZE: Final = 16 -"""AES-GCM authentication tag size in bytes.""" - -CTR_IV_SIZE: Final = 16 -"""AES-CTR initialization vector size in bytes.""" - -ID_SIGNATURE_SIZE: Final = 64 -"""secp256k1 signature size (r || s, each 32 bytes).""" - -ID_SIGNATURE_DOMAIN: Final = b"discovery v5 identity proof" -"""Domain separator for ID nonce signatures. Prevents cross-protocol reuse.""" - -_P: Final = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F -"""secp256k1 field prime.""" - -_N: Final = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 -"""secp256k1 curve order.""" - -_Gx: Final = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798 -"""secp256k1 generator x-coordinate.""" - -_Gy: Final = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8 -"""secp256k1 generator y-coordinate.""" - - -def _modinv(a: int, m: int) -> int: - """Compute modular inverse using Fermat's little theorem (m must be prime).""" - return pow(a, m - 2, m) - - -def _point_add(p1: tuple[int, int] | None, p2: tuple[int, int] | None) -> tuple[int, int] | None: - """Add two secp256k1 curve points.""" - if p1 is None: - return p2 - if p2 is None: - return p1 - - x1, y1 = p1 - x2, y2 = p2 - - if x1 == x2 and y1 != y2: - return None - - if x1 == x2: - # Point doubling. - lam = (3 * x1 * x1 * _modinv(2 * y1, _P)) % _P - else: - lam = ((y2 - y1) * _modinv(x2 - x1, _P)) % _P - - x3 = (lam * lam - x1 - x2) % _P - y3 = (lam * (x1 - x3) - y1) % _P - return (x3, y3) - - -def _point_mul(k: int, point: tuple[int, int] | None) -> tuple[int, int] | None: - """Scalar multiplication using double-and-add.""" - result = None - addend = point - while k: - if k & 1: - result = _point_add(result, addend) - addend = _point_add(addend, addend) - k >>= 1 - return result - - -def _decompress_pubkey(data: bytes) -> tuple[int, int]: - """Parse a compressed or uncompressed secp256k1 public key to (x, y).""" - if len(data) == UNCOMPRESSED_PUBKEY_SIZE and data[0] == 0x04: - x = int.from_bytes(data[1:33], "big") - y = int.from_bytes(data[33:65], "big") - return (x, y) - - if len(data) == COMPRESSED_PUBKEY_SIZE and data[0] in (0x02, 0x03): - x = int.from_bytes(data[1:], "big") - # Solve y^2 = x^3 + 7 (mod p). - y_sq = (pow(x, 3, _P) + 7) % _P - y = pow(y_sq, (_P + 1) // 4, _P) - # Choose the correct parity. - if (y & 1) != (data[0] & 1): - y = _P - y - return (x, y) - - raise ValueError(f"Invalid public key encoding: length={len(data)}") - - -def _compress_point(point: tuple[int, int]) -> Bytes33: - """Encode a curve point as 33-byte compressed format.""" - x, y = point - prefix = 0x02 if y % 2 == 0 else 0x03 - return Bytes33(bytes([prefix]) + x.to_bytes(32, "big")) - - -def aes_ctr_encrypt(key: Bytes16, iv: Bytes16, plaintext: bytes) -> bytes: - """ - Encrypt using AES-128-CTR. - - Used for header masking in Discovery v5 packets. - The masking key is derived from the destination node ID. - - Args: - key: 16-byte AES key (dest_node_id[:16]). - iv: 16-byte initialization vector (masking-iv from packet). - plaintext: Data to encrypt (packet header). - - Returns: - Ciphertext of same length as plaintext. - """ - cipher = Cipher(algorithms.AES(key), modes.CTR(iv)) - encryptor = cipher.encryptor() - return encryptor.update(plaintext) + encryptor.finalize() - - -def aes_ctr_decrypt(key: Bytes16, iv: Bytes16, ciphertext: bytes) -> bytes: - """ - Decrypt using AES-128-CTR. - - CTR mode is symmetric - encryption and decryption are identical operations. - - Args: - key: 16-byte AES key. - iv: 16-byte initialization vector. - ciphertext: Data to decrypt. - - Returns: - Decrypted plaintext. - """ - return aes_ctr_encrypt(key, iv, ciphertext) - - -def aes_gcm_encrypt(key: Bytes16, nonce: Bytes12, plaintext: bytes, aad: bytes) -> bytes: - """ - Encrypt using AES-128-GCM. - - Used for message encryption in Discovery v5. - The authentication tag is appended to the ciphertext. - - Args: - key: 16-byte AES key (session encryption key). - nonce: 12-byte nonce (from packet header). - plaintext: Message data to encrypt. - aad: Additional authenticated data (packet header). - - Returns: - Ciphertext with 16-byte authentication tag appended. - """ - aesgcm = AESGCM(key) - return aesgcm.encrypt(nonce, plaintext, aad) - - -def aes_gcm_decrypt(key: Bytes16, nonce: Bytes12, ciphertext: bytes, aad: bytes) -> bytes: - """ - Decrypt using AES-128-GCM. - - Verifies the authentication tag and decrypts if valid. - - Args: - key: 16-byte AES key. - nonce: 12-byte nonce. - ciphertext: Ciphertext with 16-byte auth tag. - aad: Additional authenticated data. - - Returns: - Decrypted plaintext. - - Raises: - cryptography.exceptions.InvalidTag: If authentication fails. - """ - aesgcm = AESGCM(key) - return aesgcm.decrypt(nonce, ciphertext, aad) - - -def ecdh_agree(private_key_bytes: Bytes32, public_key_bytes: Bytes33 | Bytes65) -> Bytes33: - """ - Perform secp256k1 ECDH key agreement. - - Both parties compute the same shared secret from their private key - and the other party's public key. - - Per Discovery v5 spec, the shared secret is the 33-byte compressed - point resulting from scalar multiplication of the private key with - the public key. - - Args: - private_key_bytes: 32-byte secp256k1 private key scalar. - public_key_bytes: 33-byte compressed or 65-byte uncompressed public key. - - Returns: - 33-byte shared secret (compressed point from ECDH). - """ - scalar = int.from_bytes(private_key_bytes, "big") - point = _decompress_pubkey(public_key_bytes) - result = _point_mul(scalar, point) - - if result is None: - raise ValueError("ECDH produced point at infinity") - - return _compress_point(result) - - -def generate_secp256k1_keypair() -> tuple[Bytes32, Bytes33]: - """ - Generate a new secp256k1 keypair. - - Used to create ephemeral keys for ECDH during handshake. - - Returns: - Tuple of (private_key_bytes, compressed_public_key_bytes). - - private_key: 32-byte scalar - - public_key: 33-byte compressed format - """ - private_key = ec.generate_private_key(ec.SECP256K1()) - - private_bytes = private_key.private_numbers().private_value.to_bytes(32, "big") - public_bytes = private_key.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, - ) - - return Bytes32(private_bytes), Bytes33(public_bytes) - - -def pubkey_to_uncompressed(public_key_bytes: Bytes33 | Bytes65) -> Bytes65: - """ - Convert any secp256k1 public key to uncompressed format. - - Args: - public_key_bytes: 33-byte compressed or 65-byte uncompressed public key. - - Returns: - 65-byte uncompressed public key (0x04 || x || y). - """ - if len(public_key_bytes) == UNCOMPRESSED_PUBKEY_SIZE: - return Bytes65(public_key_bytes) - - public_key = ec.EllipticCurvePublicKey.from_encoded_point( - ec.SECP256K1(), - public_key_bytes, - ) - return Bytes65( - public_key.public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.UncompressedPoint, - ) - ) - - -def sign_id_nonce( - private_key_bytes: Bytes32, - challenge_data: bytes, - ephemeral_pubkey: Bytes33, - dest_node_id: Bytes32, -) -> Bytes64: - """ - Sign for handshake authentication. - - The signature proves identity ownership without revealing the private key. - - Per Discovery v5 spec: - id-signature-input = "discovery v5 identity proof" || challenge-data || - ephemeral-pubkey || node-id-B - id-signature = sign(sha256(id-signature-input)) - - Args: - private_key_bytes: 32-byte secp256k1 private key. - challenge_data: Full WHOAREYOU challenge data (masking-iv || static-header || authdata). - ephemeral_pubkey: 33-byte compressed ephemeral public key. - dest_node_id: 32-byte node ID of the WHOAREYOU sender (node-id-B). - - Returns: - 64-byte signature (r || s, each 32 bytes). - """ - # The signing input binds several values together per the spec: - # - # - Domain separator prevents cross-protocol signature reuse - # - challenge_data provides freshness (full WHOAREYOU packet data) - # - ephemeral_pubkey binds to this specific handshake - # - dest_node_id (node-id-B) binds to the specific challenger - # - # Using the full challenge_data (not just id_nonce) ensures the signature - # is bound to the exact WHOAREYOU packet received, preventing replay attacks. - signing_input = ID_SIGNATURE_DOMAIN + challenge_data + ephemeral_pubkey + dest_node_id - - digest = hashlib.sha256(signing_input).digest() - - # Sign the pre-hashed digest. - # - # We use Prehashed because we've already computed SHA256. - # The library expects the 32-byte digest directly. - private_key = ec.derive_private_key( - int.from_bytes(private_key_bytes, "big"), - ec.SECP256K1(), - ) - - der_signature = private_key.sign( - digest, ec.ECDSA(Prehashed(hashes.SHA256()), deterministic_signing=True) - ) - - # Convert DER-encoded signature to fixed-size r||s format. - # - # ECDSA signatures in DER are variable length. - # Discovery v5 uses fixed 64-byte r||s for consistency. - r, s = decode_dss_signature(der_signature) - return Bytes64(r.to_bytes(32, "big") + s.to_bytes(32, "big")) - - -def verify_id_nonce_signature( - signature: Bytes64, - challenge_data: bytes, - ephemeral_pubkey: Bytes33, - dest_node_id: Bytes32, - public_key_bytes: Bytes33, -) -> bool: - """ - Verify an ID nonce signature. - - Verifies that the signature was created by the holder of the private key - corresponding to the given public key. - - Per Discovery v5 spec: - id-signature-input = "discovery v5 identity proof" || challenge-data || - ephemeral-pubkey || node-id-B - Verify: signature matches sha256(id-signature-input) - - Args: - signature: 64-byte signature (r || s). - challenge_data: Full WHOAREYOU challenge data (masking-iv || static-header || authdata). - ephemeral_pubkey: 33-byte compressed ephemeral public key. - dest_node_id: 32-byte node ID of the WHOAREYOU sender (node-id-B). - public_key_bytes: 33-byte compressed public key of the signer. - - Returns: - True if signature is valid, False otherwise. - """ - # Build the signing input per spec: - # domain-separator || challenge-data || ephemeral-pubkey || node-id-B - input_data = ID_SIGNATURE_DOMAIN + challenge_data + ephemeral_pubkey + dest_node_id - - # Pre-hash with SHA256 since ECDSA verification expects a fixed-size digest. - digest = hashlib.sha256(input_data).digest() - - # The cryptography library expects DER-encoded signatures, not raw r||s. - r = int.from_bytes(signature[:32], "big") - s = int.from_bytes(signature[32:], "big") - der_signature = encode_dss_signature(r, s) - - # Return False on failure rather than raising, since invalid signatures - # are expected during normal protocol operation (e.g., stale handshakes). - try: - public_key = ec.EllipticCurvePublicKey.from_encoded_point( - ec.SECP256K1(), - public_key_bytes, - ) - public_key.verify(der_signature, digest, ec.ECDSA(Prehashed(hashes.SHA256()))) - return True - except (InvalidSignature, ValueError): - return False diff --git a/src/lean_spec/subspecs/networking/discovery/handshake.py b/src/lean_spec/subspecs/networking/discovery/handshake.py deleted file mode 100644 index a94da7f9..00000000 --- a/src/lean_spec/subspecs/networking/discovery/handshake.py +++ /dev/null @@ -1,560 +0,0 @@ -""" -Handshake state machine for Discovery v5. - -The Discovery v5 handshake establishes shared session keys through ECDH. - -Handshake Flow: -1. A sends MESSAGE to B (encrypted with old/no session) -2. B can't decrypt, sends WHOAREYOU with id-nonce challenge -3. A responds with HANDSHAKE containing: - - Ephemeral public key for ECDH - - Signature proving ownership of node ID - - Optionally, A's ENR if B requested it -4. Both derive session keys from ECDH shared secret -5. Session established, further messages use derived keys - -State Machine: -- IDLE: No handshake in progress -- SENT_ORDINARY: Sent MESSAGE, awaiting potential WHOAREYOU -- SENT_WHOAREYOU: Sent WHOAREYOU, awaiting HANDSHAKE -- COMPLETED: Handshake finished, session established - -References: -- https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md#handshake -""" - -from __future__ import annotations - -import time -from dataclasses import dataclass, field -from enum import Enum, auto -from threading import Lock -from typing import Final - -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import NodeId, Port, SeqNumber -from lean_spec.types import Bytes16, Bytes32, Bytes33 - -from .config import DEFAULT_PORT, HANDSHAKE_TIMEOUT_SECS -from .crypto import ( - generate_secp256k1_keypair, - sign_id_nonce, - verify_id_nonce_signature, -) -from .keys import derive_keys_from_pubkey -from .messages import IdNonce, Nonce, PacketFlag -from .packet import ( - HandshakeAuthdata, - WhoAreYouAuthdata, - encode_handshake_authdata, - encode_static_header, - encode_whoareyou_authdata, -) -from .session import Session, SessionCache - -MAX_PENDING_HANDSHAKES: Final = 100 -"""Hard cap on concurrent pending handshakes to prevent resource exhaustion.""" - -MAX_ENR_CACHE: Final = 1000 -"""Maximum number of cached ENRs.""" - - -class HandshakeState(Enum): - """Handshake state machine states.""" - - IDLE = auto() - """No handshake in progress.""" - - SENT_ORDINARY = auto() - """Sent an ordinary MESSAGE, awaiting potential WHOAREYOU.""" - - SENT_WHOAREYOU = auto() - """Sent WHOAREYOU challenge, awaiting HANDSHAKE response.""" - - COMPLETED = auto() - """Handshake completed, session established.""" - - -@dataclass(slots=True) -class PendingHandshake: - """Tracks an in-progress handshake with a peer.""" - - state: HandshakeState - """Current state of this handshake.""" - - remote_node_id: NodeId - """32-byte node ID of the remote peer.""" - - id_nonce: IdNonce | None = None - """16-byte challenge nonce (set when WHOAREYOU sent/received).""" - - challenge_data: bytes | None = None - """Full WHOAREYOU packet data for key derivation (masking-iv || static-header || authdata).""" - - ephemeral_privkey: Bytes32 | None = None - """32-byte ephemeral private key (set when we send HANDSHAKE).""" - - challenge_nonce: Nonce | None = None - """12-byte nonce from the packet that triggered WHOAREYOU.""" - - remote_enr_seq: SeqNumber = SeqNumber(0) - """ENR seq we sent in WHOAREYOU. If 0, remote MUST include their ENR in HANDSHAKE.""" - - started_at: float = field(default_factory=time.time) - """Timestamp when handshake started.""" - - def is_expired(self, timeout_secs: float = HANDSHAKE_TIMEOUT_SECS) -> bool: - """Check if handshake has timed out.""" - return time.time() - self.started_at > timeout_secs - - -@dataclass(frozen=True, slots=True) -class HandshakeResult: - """Result of a completed handshake.""" - - session: Session - """Established session with derived keys.""" - - remote_enr: bytes | None - """Remote's ENR if included in handshake.""" - - -class HandshakeError(Exception): - """Error during handshake.""" - - -class HandshakeManager: - """ - Manages WHOAREYOU/HANDSHAKE exchanges. - - Thread-safe manager for concurrent handshakes with multiple peers. - Integrates with SessionCache to store completed sessions. - - Args: - local_node_id: Our 32-byte node ID. - local_private_key: Our 32-byte secp256k1 private key. - local_enr_rlp: Our RLP-encoded ENR. - local_enr_seq: Our current ENR sequence number. - session_cache: Session cache for storing completed sessions. - timeout_secs: Handshake timeout. - """ - - def __init__( - self, - local_node_id: NodeId, - local_private_key: Bytes32, - local_enr_rlp: bytes, - local_enr_seq: SeqNumber, - session_cache: SessionCache, - timeout_secs: float = HANDSHAKE_TIMEOUT_SECS, - ): - """Initialize handshake manager.""" - self._local_node_id = local_node_id - self._local_private_key = local_private_key - self._local_enr_rlp = local_enr_rlp - self._local_enr_seq = local_enr_seq - self._session_cache = session_cache - self._timeout_secs = timeout_secs - - self._pending: dict[NodeId, PendingHandshake] = {} - - # Cache of ENRs for nodes we may handshake with. - # - # Handshake verification requires the remote's public key. - # The key comes from their ENR, which may arrive before the handshake - # (via NODES responses) or within the handshake itself. - # This cache stores pre-known ENRs for lookup during verification. - self._enr_cache: dict[NodeId, ENR] = {} - - self._lock = Lock() - - def start_handshake(self, remote_node_id: NodeId) -> PendingHandshake: - """ - Start tracking a new handshake as initiator. - - Called when we send a MESSAGE to a node with no session. - We expect to receive a WHOAREYOU in response. - - Args: - remote_node_id: 32-byte node ID of the remote peer. - - Returns: - PendingHandshake in SENT_ORDINARY state. - """ - with self._lock: - # Reject new handshakes when at capacity to prevent resource exhaustion. - if len(self._pending) >= MAX_PENDING_HANDSHAKES and remote_node_id not in self._pending: - self.cleanup_expired() - if len(self._pending) >= MAX_PENDING_HANDSHAKES: - raise HandshakeError("Too many pending handshakes") - - pending = PendingHandshake( - state=HandshakeState.SENT_ORDINARY, - remote_node_id=remote_node_id, - ) - self._pending[remote_node_id] = pending - return pending - - def create_whoareyou( - self, - remote_node_id: NodeId, - request_nonce: Nonce, - remote_enr_seq: SeqNumber, - masking_iv: Bytes16, - ) -> tuple[bytes, bytes, Nonce, bytes]: - """ - Create a WHOAREYOU packet in response to an undecryptable message. - - Called when we receive a MESSAGE we can't decrypt. - - Args: - remote_node_id: 32-byte node ID of the sender. - request_nonce: 12-byte nonce from the failed MESSAGE packet. - remote_enr_seq: Our last known ENR seq for the remote (0 if unknown). - masking_iv: 16-byte masking IV that will be used for the WHOAREYOU packet. - - Returns: - Tuple of (id_nonce, authdata, nonce, challenge_data). - - id_nonce: 16-byte challenge nonce - - authdata: Encoded WHOAREYOU authdata - - nonce: The request_nonce to use in the packet header - - challenge_data: Full data for key derivation (masking-iv || static-header || authdata) - """ - id_nonce = IdNonce.generate() - authdata = encode_whoareyou_authdata(id_nonce, remote_enr_seq) - - # Build challenge_data per spec: masking-iv || static-header || authdata. - # - # This data becomes the HKDF salt for session key derivation. - # Both sides must use identical challenge_data to derive matching keys. - static_header = encode_static_header(PacketFlag.WHOAREYOU, request_nonce, len(authdata)) - challenge_data = bytes(masking_iv) + static_header + authdata - - with self._lock: - pending = PendingHandshake( - state=HandshakeState.SENT_WHOAREYOU, - remote_node_id=remote_node_id, - id_nonce=id_nonce, - challenge_data=challenge_data, - challenge_nonce=request_nonce, - remote_enr_seq=remote_enr_seq, - ) - self._pending[remote_node_id] = pending - - return bytes(id_nonce), authdata, request_nonce, challenge_data - - def create_handshake_response( - self, - remote_node_id: NodeId, - whoareyou: WhoAreYouAuthdata, - remote_pubkey: Bytes33, - challenge_data: bytes, - remote_ip: str = "", - remote_port: Port = DEFAULT_PORT, - ) -> tuple[bytes, Bytes16, Bytes16]: - """ - Create a HANDSHAKE packet in response to WHOAREYOU. - - Called when we receive a WHOAREYOU for a message we sent. - - Args: - remote_node_id: 32-byte node ID of the challenger. - whoareyou: Decoded WHOAREYOU authdata. - remote_pubkey: Remote's 33-byte compressed public key. - challenge_data: Full WHOAREYOU data for key derivation - (masking-iv || static-header || authdata from received packet). - remote_ip: Remote peer's IP address for session keying. - remote_port: Remote peer's UDP port for session keying. - - Returns: - Tuple of (authdata, send_key, recv_key). - - authdata: Encoded HANDSHAKE authdata - - send_key: 16-byte key for sending to this peer - - recv_key: 16-byte key for receiving from this peer - """ - # Generate ephemeral keypair for ECDH. - eph_privkey, eph_pubkey = generate_secp256k1_keypair() - - # Sign to prove our identity. - # - # Per spec, the signature input includes the full challenge_data (not just id_nonce) - # to bind the signature to this specific WHOAREYOU exchange. - id_signature = sign_id_nonce( - self._local_private_key, - challenge_data, - eph_pubkey, - remote_node_id, - ) - - # Include our ENR if the remote's known seq is stale. - record = None - if whoareyou.enr_seq < self._local_enr_seq: - record = self._local_enr_rlp - - # Build authdata. - authdata = encode_handshake_authdata( - src_id=self._local_node_id, - id_signature=id_signature, - eph_pubkey=eph_pubkey, - record=record, - ) - - # Derive session keys using full challenge_data as HKDF salt. - # - # The challenge_data binds keys to this specific WHOAREYOU exchange. - # Both sides must use identical challenge_data to derive matching keys. - send_key, recv_key = derive_keys_from_pubkey( - local_private_key=eph_privkey, - remote_public_key=remote_pubkey, - local_node_id=self._local_node_id, - remote_node_id=remote_node_id, - challenge_data=challenge_data, - is_initiator=True, - ) - - # Store session keyed by (node_id, ip, port). - self._session_cache.create( - node_id=remote_node_id, - send_key=send_key, - recv_key=recv_key, - is_initiator=True, - ip=remote_ip, - port=remote_port, - ) - - # Clean up pending handshake. - with self._lock: - self._pending.pop(remote_node_id, None) - - return authdata, send_key, recv_key - - def handle_handshake( - self, - remote_node_id: NodeId, - handshake: HandshakeAuthdata, - remote_ip: str = "", - remote_port: Port = DEFAULT_PORT, - ) -> HandshakeResult: - """ - Process a received HANDSHAKE packet. - - Called when we receive a HANDSHAKE in response to our WHOAREYOU. - - Args: - remote_node_id: 32-byte node ID from packet source. - handshake: Decoded HANDSHAKE authdata. - remote_ip: Remote peer's IP address for session keying. - remote_port: Remote peer's UDP port for session keying. - - Returns: - HandshakeResult with established session. - - Raises: - HandshakeError: If handshake verification fails. - """ - with self._lock: - pending = self._pending.get(remote_node_id) - if pending is None: - raise HandshakeError(f"No pending handshake for {remote_node_id.hex()}") - - if pending.state != HandshakeState.SENT_WHOAREYOU: - raise HandshakeError(f"Unexpected handshake state: {pending.state}") - - if pending.id_nonce is None: - raise HandshakeError("Missing id_nonce in pending handshake") - - if pending.challenge_data is None: - raise HandshakeError("Missing challenge_data in pending handshake") - - challenge_data = pending.challenge_data - remote_enr_seq = pending.remote_enr_seq - - # Verify the source ID matches. - if handshake.src_id != remote_node_id: - raise HandshakeError( - f"Source ID mismatch: expected {remote_node_id.hex()}, got {handshake.src_id.hex()}" - ) - - # If we sent enr_seq=0, we signaled that we don't know the remote's ENR. - # Per spec, the remote MUST include their ENR in the HANDSHAKE response - # so we can verify their identity. - if remote_enr_seq == SeqNumber(0) and handshake.record is None: - raise HandshakeError( - f"ENR required in HANDSHAKE from unknown node {remote_node_id.hex()[:16]}" - ) - - # Verify signature - we need the remote's static public key. - # This typically comes from their ENR which may be in the handshake record. - remote_pubkey = self._get_remote_pubkey(remote_node_id, handshake.record) - if remote_pubkey is None: - raise HandshakeError(f"Unknown public key for {remote_node_id.hex()}") - - # Verify the ID signature. - # - # The signature was computed over challenge_data (not just id_nonce), - # and includes our node_id as the WHOAREYOU sender (node-id-B). - if not verify_id_nonce_signature( - signature=handshake.id_signature, - challenge_data=challenge_data, - ephemeral_pubkey=handshake.eph_pubkey, - dest_node_id=self._local_node_id, - public_key_bytes=remote_pubkey, - ): - raise HandshakeError("Invalid ID signature") - - # Derive session keys using stored challenge_data as HKDF salt. - # - # The challenge_data was saved when we sent WHOAREYOU. - # Using the same data ensures both sides derive identical keys. - send_key, recv_key = derive_keys_from_pubkey( - local_private_key=self._local_private_key, - remote_public_key=handshake.eph_pubkey, - local_node_id=self._local_node_id, - remote_node_id=remote_node_id, - challenge_data=challenge_data, - is_initiator=False, - ) - - # Create session keyed by (node_id, ip, port). - session = self._session_cache.create( - node_id=remote_node_id, - send_key=send_key, - recv_key=recv_key, - is_initiator=False, - ip=remote_ip, - port=remote_port, - ) - - # Clean up pending handshake. - with self._lock: - self._pending.pop(remote_node_id, None) - - return HandshakeResult( - session=session, - remote_enr=handshake.record, - ) - - def get_pending(self, remote_node_id: NodeId) -> PendingHandshake | None: - """Get pending handshake for a node.""" - with self._lock: - pending = self._pending.get(remote_node_id) - if pending is not None and pending.is_expired(self._timeout_secs): - del self._pending[remote_node_id] - return None - return pending - - def cancel_handshake(self, remote_node_id: NodeId) -> bool: - """Cancel a pending handshake.""" - with self._lock: - if remote_node_id in self._pending: - del self._pending[remote_node_id] - return True - return False - - def cleanup_expired(self) -> int: - """Remove expired pending handshakes.""" - with self._lock: - expired = [ - node_id - for node_id, pending in self._pending.items() - if pending.is_expired(self._timeout_secs) - ] - for node_id in expired: - del self._pending[node_id] - return len(expired) - - def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> Bytes33 | None: - """ - Retrieve the remote node's static public key for signature verification. - - The handshake completes with a signature check. - We need the remote's public key to verify their id-nonce signature. - This key may come from two sources: - - 1. The handshake packet itself (if remote included their ENR) - 2. Our ENR cache (populated from prior NODES responses) - - Args: - node_id: 32-byte remote node ID. - enr_record: Optional RLP-encoded ENR from handshake. - - Returns: - 33-byte compressed secp256k1 public key, or None if unavailable. - """ - # Prefer the ENR from the handshake packet. - # - # The remote may include their ENR when responding to our challenge. - # This is the freshest source and takes precedence. - if enr_record is not None: - try: - enr = self._parse_enr_rlp(enr_record) - if enr is not None and enr.public_key is not None: - # Verify ENR ownership matches the claimed node ID. - # - # The node ID is keccak256(pubkey), so we recompute it - # to ensure the ENR belongs to who we think sent it. - computed_id = enr.compute_node_id() - if computed_id is not None and computed_id == node_id: - return Bytes33(enr.public_key) - except (ValueError, KeyError, IndexError): - pass - - # Fall back to our ENR cache. - # - # We may have received this node's ENR earlier via NODES responses. - # Use it if the handshake packet did not include an ENR. - cached_enr = self._enr_cache.get(node_id) - if cached_enr is not None and cached_enr.public_key is not None: - return Bytes33(cached_enr.public_key) - - return None - - def _parse_enr_rlp(self, enr_rlp: bytes) -> ENR | None: - """ - Decode an RLP-encoded ENR into a structured record. - - Delegates to ENR.from_rlp which handles full validation - including key sorting, size limits, and node ID computation. - - Args: - enr_rlp: RLP-encoded ENR bytes. - - Returns: - Parsed ENR with computed node ID, or None if malformed. - """ - try: - return ENR.from_rlp(enr_rlp) - except ValueError: - return None - - def register_enr(self, node_id: NodeId, enr: ENR) -> None: - """ - Cache an ENR for future handshake verification. - - When we learn about a node (via NODES responses or other means), - we cache its ENR here. Later, if that node initiates a handshake, - we can verify their identity without requiring them to include - their ENR in the handshake packet. - - Args: - node_id: 32-byte node ID (keccak256 of public key). - enr: The node's ENR containing their public key. - """ - # Evict oldest entry when at capacity. - if len(self._enr_cache) >= MAX_ENR_CACHE and node_id not in self._enr_cache: - oldest_key = next(iter(self._enr_cache)) - del self._enr_cache[oldest_key] - - self._enr_cache[node_id] = enr - - def get_cached_enr(self, node_id: NodeId) -> ENR | None: - """ - Retrieve a previously cached ENR. - - Args: - node_id: 32-byte node ID to look up. - - Returns: - The cached ENR, or None if not in cache. - """ - return self._enr_cache.get(node_id) diff --git a/src/lean_spec/subspecs/networking/discovery/keys.py b/src/lean_spec/subspecs/networking/discovery/keys.py deleted file mode 100644 index 60bbda8f..00000000 --- a/src/lean_spec/subspecs/networking/discovery/keys.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -Key derivation for Discovery v5. - -Discovery v5 derives session keys using HKDF-SHA256: -- Extract phase: HMAC-SHA256(salt=challenge_data, ikm=shared-secret) -- Expand phase: HMAC-SHA256(prk, info || 0x01) - -The challenge_data is the concatenation of: -- masking-iv (16 bytes) from the WHOAREYOU packet -- static-header (23 bytes) - unmasked -- authdata (24 bytes for WHOAREYOU) - -Using the full WHOAREYOU packet data as salt binds session keys to: -- The specific challenge (prevents replay across sessions) -- The packet structure (prevents malformed packet attacks) - -The derived keys are: -- initiator_key: Used by the handshake initiator to encrypt messages -- recipient_key: Used by the handshake recipient to encrypt messages - -References: -- https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md#session-keys -- RFC 5869 (HKDF) -""" - -from __future__ import annotations - -import hashlib -import hmac -from typing import Final - -from Crypto.Hash import keccak - -from lean_spec.types import Bytes16, Bytes32, Bytes33, Bytes65 - -from .crypto import ecdh_agree, pubkey_to_uncompressed - -DISCV5_KEY_AGREEMENT_INFO: Final = b"discovery v5 key agreement" -"""Info string used in HKDF expansion for Discovery v5 key derivation.""" - -SESSION_KEY_SIZE: Final = 16 -"""Size of each session key in bytes (AES-128).""" - - -def derive_keys( - secret: Bytes33, - initiator_id: Bytes32, - recipient_id: Bytes32, - challenge_data: bytes, -) -> tuple[Bytes16, Bytes16]: - """ - Derive session keys per Discovery v5 specification. - - Both parties derive the same pair of keys from: - - The ECDH shared secret - - Both node IDs (determines key direction) - - The challenge_data from WHOAREYOU (prevents replay attacks) - - Key derivation: - info = "discovery v5 key agreement" || initiator_id || recipient_id - prk = HKDF-Extract(salt=challenge_data, ikm=secret) - keys = HKDF-Expand(prk, info, 32) - initiator_key = keys[:16] - recipient_key = keys[16:32] - - Args: - secret: 33-byte ECDH shared secret (compressed point). - initiator_id: 32-byte node ID of the handshake initiator. - recipient_id: 32-byte node ID of the handshake recipient. - challenge_data: WHOAREYOU packet data (masking-iv || static-header || authdata). - This is 63 bytes: 16 (iv) + 23 (static header) + 24 (authdata). - - Returns: - Tuple of (initiator_key, recipient_key), each 16 bytes. - - The initiator uses initiator_key to encrypt and recipient_key to decrypt. - The recipient uses recipient_key to encrypt and initiator_key to decrypt. - """ - # HKDF-Extract: PRK = HMAC-SHA256(salt, IKM). - # - # Using challenge_data as salt binds session keys to the specific WHOAREYOU. - # challenge_data = masking-iv || static-header || authdata - # This includes the random id-nonce within authdata, providing replay protection. - # The full packet structure prevents malformed packet attacks. - prk = hmac.new(challenge_data, secret, hashlib.sha256).digest() - - # Include both node IDs in the info string. - # - # This binds keys to the specific communicating parties. - # Prevents key confusion attacks where an attacker substitutes - # their own node ID after observing a handshake. - info = DISCV5_KEY_AGREEMENT_INFO + initiator_id + recipient_id - - # HKDF-Expand produces deterministic output from PRK. - # - # We need 32 bytes (two 16-byte AES keys). - # SHA-256 outputs 32 bytes, so one round suffices. - t1 = hmac.new(prk, info + b"\x01", hashlib.sha256).digest() - - initiator_key = Bytes16(t1[:SESSION_KEY_SIZE]) - recipient_key = Bytes16(t1[SESSION_KEY_SIZE : SESSION_KEY_SIZE * 2]) - - return initiator_key, recipient_key - - -def derive_keys_from_pubkey( - local_private_key: Bytes32, - remote_public_key: Bytes33 | Bytes65, - local_node_id: Bytes32, - remote_node_id: Bytes32, - challenge_data: bytes, - is_initiator: bool, -) -> tuple[Bytes16, Bytes16]: - """ - Derive session keys from ECDH with automatic key ordering. - - Convenience function that performs ECDH and derives keys with - proper initiator/recipient ordering. - - Args: - local_private_key: Our 32-byte secp256k1 private key. - remote_public_key: Peer's compressed (33-byte) or uncompressed (65-byte) public key. - local_node_id: Our 32-byte node ID. - remote_node_id: Peer's 32-byte node ID. - challenge_data: WHOAREYOU packet data (masking-iv || static-header || authdata). - is_initiator: True if we initiated the handshake. - - Returns: - Tuple of (send_key, recv_key) for this party. - - send_key: Use to encrypt outgoing messages. - - recv_key: Use to decrypt incoming messages. - """ - # Compute shared secret. - secret = ecdh_agree(local_private_key, remote_public_key) - - # Determine key ordering based on who initiated. - if is_initiator: - initiator_key, recipient_key = derive_keys( - secret, local_node_id, remote_node_id, challenge_data - ) - # We are initiator: use initiator_key to send, recipient_key to receive. - return initiator_key, recipient_key - else: - initiator_key, recipient_key = derive_keys( - secret, remote_node_id, local_node_id, challenge_data - ) - # We are recipient: use recipient_key to send, initiator_key to receive. - return recipient_key, initiator_key - - -def compute_node_id(public_key_bytes: Bytes33 | Bytes65) -> Bytes32: - """ - Compute node ID from public key. - - Per Discovery v5 / EIP-778 "v4" identity scheme: - node_id = keccak256(uncompressed_pubkey[1:]) - - The hash is computed over the 64-byte x||y coordinates, - excluding the 0x04 prefix byte. - - Args: - public_key_bytes: Compressed (33 bytes) or uncompressed (65 bytes) public key. - - Returns: - 32-byte node ID. - """ - # Ensure uncompressed format. - uncompressed = pubkey_to_uncompressed(public_key_bytes) - - # Hash the 64-byte x||y (excluding 0x04 prefix). - k = keccak.new(digest_bits=256) - k.update(uncompressed[1:]) - return Bytes32(k.digest()) diff --git a/src/lean_spec/subspecs/networking/discovery/messages.py b/src/lean_spec/subspecs/networking/discovery/messages.py deleted file mode 100644 index 9a76a724..00000000 --- a/src/lean_spec/subspecs/networking/discovery/messages.py +++ /dev/null @@ -1,279 +0,0 @@ -""" -Discovery v5 Protocol Messages - -Wire protocol messages for Node Discovery Protocol v5.1. - -Packet Structure: - packet = masking-iv || masked-header || message - -Message Encoding: - message-pt = message-type || message-data - message-data = [request-id, ...] (RLP encoded) - -References: - - https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md -""" - -from __future__ import annotations - -import os -from enum import IntEnum -from typing import ClassVar, Final, Self - -from lean_spec.subspecs.networking.types import Port, SeqNumber -from lean_spec.types import BaseByteList, BaseBytes, StrictBaseModel, Uint8, Uint16 - -PROTOCOL_ID: Final[bytes] = b"discv5" -"""Protocol identifier in packet header. 6 bytes.""" - -PROTOCOL_VERSION: Final[int] = 0x0001 -"""Current protocol version (v5.1).""" - -MAX_REQUEST_ID_LENGTH: Final[int] = 8 -"""Maximum length of request-id in bytes.""" - - -class RequestId(BaseByteList): - """ - Request identifier for matching requests with responses. - - Variable length up to 8 bytes. Assigned by the requester and echoed - in responses. Selection of values is implementation-defined. - """ - - LIMIT: ClassVar[int] = MAX_REQUEST_ID_LENGTH - - @classmethod - def generate(cls) -> Self: - """Generate a random request ID.""" - return cls(data=os.urandom(8)) - - -class IPv4(BaseBytes): - """IPv4 address as 4 bytes.""" - - LENGTH: ClassVar[int] = 4 - - -class IPv6(BaseBytes): - """IPv6 address as 16 bytes.""" - - LENGTH: ClassVar[int] = 16 - - -class IdNonce(BaseBytes): - """ - Identity nonce for WHOAREYOU packets. - - 128-bit random value used in the identity verification procedure. - """ - - LENGTH: ClassVar[int] = 16 - - @classmethod - def generate(cls) -> Self: - """Generate a random 16-byte identity challenge nonce.""" - return cls(os.urandom(16)) - - -class Nonce(BaseBytes): - """ - Message nonce for packet encryption. - - 96-bit value. Must be unique for every message packet. - """ - - LENGTH: ClassVar[int] = 12 - - @classmethod - def generate(cls) -> Self: - """Generate a random 12-byte message nonce.""" - return cls(os.urandom(cls.LENGTH)) - - -class Distance(Uint16): - """Log2 distance (0-256). Distance 0 returns the node's own ENR.""" - - -class PacketFlag(IntEnum): - """ - Packet type identifier in the protocol header. - - Determines the encoding of the authdata section. - """ - - MESSAGE = 0 - """Ordinary message packet. authdata = src-id (32 bytes).""" - - WHOAREYOU = 1 - """Challenge packet. authdata = id-nonce || enr-seq (24 bytes).""" - - HANDSHAKE = 2 - """Handshake message packet. authdata = variable size.""" - - -class MessageType(IntEnum): - """ - Message type identifiers in the encrypted message payload. - - Encoded as the first byte of message-pt before RLP message-data. - """ - - PING = 0x01 - """Liveness check. message-data = [request-id, enr-seq].""" - - PONG = 0x02 - """Response to PING. message-data = [request-id, enr-seq, ip, port].""" - - FINDNODE = 0x03 - """Query nodes. message-data = [request-id, [distances...]].""" - - NODES = 0x04 - """Response with ENRs. message-data = [request-id, total, [ENRs...]].""" - - TALKREQ = 0x05 - """App protocol request. message-data = [request-id, protocol, request].""" - - TALKRESP = 0x06 - """App protocol response. message-data = [request-id, response].""" - - # Topic advertisement messages (not finalized in spec) - REGTOPIC = 0x07 - """Topic registration request (experimental).""" - - TICKET = 0x08 - """Ticket response for topic registration (experimental).""" - - REGCONFIRMATION = 0x09 - """Topic registration confirmation (experimental).""" - - TOPICQUERY = 0x0A - """Topic query request (experimental).""" - - -class Ping(StrictBaseModel): - """ - PING request (0x01) - Liveness check. - - Verifies a node is online and informs it of the sender's ENR sequence number. - The recipient compares enr_seq to decide if it needs the sender's latest record. - - Wire format: - message-data = [request-id, enr-seq] - """ - - request_id: RequestId - """Unique identifier for request/response matching.""" - - enr_seq: SeqNumber - """Sender's ENR sequence number.""" - - -class Pong(StrictBaseModel): - """ - PONG response (0x02) - Reply to PING. - - Confirms liveness and reports the sender's observed external endpoint. - Used for NAT detection and ENR endpoint verification. - - Wire format: - message-data = [request-id, enr-seq, recipient-ip, recipient-port] - """ - - request_id: RequestId - """Echoed from the PING request.""" - - enr_seq: SeqNumber - """Responder's ENR sequence number.""" - - recipient_ip: IPv4 | IPv6 - """Sender's IP as seen by responder. 4 bytes (IPv4) or 16 bytes (IPv6).""" - - recipient_port: Port - """Sender's UDP port as seen by responder.""" - - -class FindNode(StrictBaseModel): - """ - FINDNODE request (0x03) - Query nodes at distances. - - Requests nodes from the recipient's routing table at specified log2 distances. - The recommended result limit is 16 nodes per query. - - Wire format: - message-data = [request-id, [distance₁, distance₂, ...]] - - Distance semantics: - - Distance 0: Returns the recipient's own ENR - - Distance 1-256: Returns nodes at that log2 distance from recipient - """ - - request_id: RequestId - """Unique identifier for request/response matching.""" - - distances: list[Distance] - """Log2 distances to query. Each value in range 0-256.""" - - -class Nodes(StrictBaseModel): - """ - NODES response (0x04) - ENR records from routing table. - - Response to FINDNODE or TOPICQUERY. May be split across multiple messages - to stay within the 1280 byte UDP packet limit. - - Wire format: - message-data = [request-id, total, [ENR₁, ENR₂, ...]] - - Recipients should verify returned nodes match the requested distances. - """ - - request_id: RequestId - """Echoed from the FINDNODE request.""" - - total: Uint8 - """Total NODES messages for this request. Enables reassembly.""" - - enrs: list[bytes] - """RLP-encoded ENR records. Max 300 bytes each per EIP-778.""" - - -class TalkReq(StrictBaseModel): - """ - TALKREQ request (0x05) - Application protocol negotiation. - - Enables higher-layer protocols to communicate through Discovery v5. - Used by Ethereum for subnet discovery (eth2) and Portal Network. - - Wire format: - message-data = [request-id, protocol, request] - - The recipient must respond with TALKRESP. If the protocol is unknown, - the response must contain empty data. - """ - - request_id: RequestId - """Unique identifier for request/response matching.""" - - protocol: bytes - """Protocol identifier (e.g., b"eth2", b"portal").""" - - request: bytes - """Protocol-specific request payload.""" - - -class TalkResp(StrictBaseModel): - """ - TALKRESP response (0x06) - Reply to TALKREQ. - - Empty response indicates the protocol is unknown to the recipient. - - Wire format: - message-data = [request-id, response] - """ - - request_id: RequestId - """Echoed from the TALKREQ request.""" - - response: bytes - """Protocol-specific response. Empty if protocol unknown.""" diff --git a/src/lean_spec/subspecs/networking/discovery/packet.py b/src/lean_spec/subspecs/networking/discovery/packet.py deleted file mode 100644 index 3f898d5d..00000000 --- a/src/lean_spec/subspecs/networking/discovery/packet.py +++ /dev/null @@ -1,373 +0,0 @@ -""" -Packet encoding/decoding for Discovery v5. - -Discovery v5 packet structure:: - - packet = masking-iv || masked-header || message - masking-iv = random 16 bytes - masked-header = aes-ctr(key=dest-id[:16], iv=masking-iv, header) - header = static-header || authdata - -Static header (23 bytes):: - - static-header = protocol-id || version || flag || nonce || authdata-size - protocol-id = "discv5" - version = 0x0001 - flag = 0/1/2 (message/whoareyou/handshake) - nonce = 12 bytes - authdata-size = 2 bytes (big-endian) - -Authdata varies by packet type: - -- MESSAGE (flag=0): src-id (32 bytes) -- WHOAREYOU (flag=1): id-nonce (16 bytes) || enr-seq (8 bytes) -- HANDSHAKE (flag=2): variable size with ephemeral key and signature - -References: -- https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md#packet-encoding -""" - -from __future__ import annotations - -import os -import struct -from dataclasses import dataclass -from typing import Final - -from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes12, Bytes16, Bytes33, Bytes64 - -from .config import MAX_PACKET_SIZE, MIN_PACKET_SIZE -from .crypto import ( - AES_KEY_SIZE, - CTR_IV_SIZE, - aes_ctr_decrypt, - aes_ctr_encrypt, - aes_gcm_decrypt, - aes_gcm_encrypt, -) -from .messages import PROTOCOL_ID, PROTOCOL_VERSION, IdNonce, Nonce, PacketFlag - -STATIC_HEADER_SIZE: Final = 23 -"""Size of the static header in bytes: 6 + 2 + 1 + 12 + 2.""" - -MESSAGE_AUTHDATA_SIZE: Final = 32 -"""Authdata size for MESSAGE packets: src-id (32 bytes).""" - -WHOAREYOU_AUTHDATA_SIZE: Final = 24 -"""Authdata size for WHOAREYOU packets: id-nonce (16) + enr-seq (8).""" - -HANDSHAKE_HEADER_SIZE: Final = 34 -"""Fixed portion of handshake authdata: src-id (32) + sig-size (1) + eph-key-size (1).""" - - -@dataclass(frozen=True, slots=True) -class PacketHeader: - """Decoded packet header.""" - - flag: PacketFlag - """Packet type: message, whoareyou, or handshake.""" - - nonce: Nonce - """12-byte message nonce.""" - - authdata: bytes - """Variable-length authentication data.""" - - -@dataclass(frozen=True, slots=True) -class MessageAuthdata: - """Authdata for MESSAGE packets (flag=0).""" - - src_id: NodeId - """Sender's 32-byte node ID.""" - - -@dataclass(frozen=True, slots=True) -class WhoAreYouAuthdata: - """Authdata for WHOAREYOU packets (flag=1).""" - - id_nonce: IdNonce - """16-byte identity challenge nonce.""" - - enr_seq: SeqNumber - """Sender's last known ENR sequence for the target. 0 if unknown.""" - - -@dataclass(frozen=True, slots=True) -class HandshakeAuthdata: - """Authdata for HANDSHAKE packets (flag=2).""" - - src_id: NodeId - """Sender's 32-byte node ID.""" - - sig_size: int - """Size of the ID signature. 64 for v4 identity scheme.""" - - eph_key_size: int - """Size of ephemeral public key. 33 for compressed secp256k1.""" - - id_signature: Bytes64 - """ID nonce signature proving identity ownership.""" - - eph_pubkey: Bytes33 - """Ephemeral public key for ECDH.""" - - record: bytes | None - """RLP-encoded ENR, included if recipient's enr_seq was stale.""" - - -def encode_packet( - dest_node_id: NodeId, - flag: PacketFlag, - nonce: Nonce, - authdata: bytes, - message: bytes, - encryption_key: Bytes16 | None = None, - masking_iv: Bytes16 | None = None, -) -> bytes: - """ - Encode a Discovery v5 packet. - - Args: - dest_node_id: 32-byte destination node ID (for header masking). - flag: Packet type flag. - nonce: 12-byte message nonce. - authdata: Authentication data (varies by packet type). - message: Message payload (plaintext for WHOAREYOU, encrypted otherwise). - encryption_key: 16-byte key for message encryption (None for WHOAREYOU). - masking_iv: Optional 16-byte IV for header masking. Random if not provided. - Must be provided for WHOAREYOU to match the IV used in challenge_data. - - Returns: - Complete encoded packet ready for UDP transmission. - """ - if masking_iv is None: - # Fresh random IV for header masking. - # - # Using dest_node_id as the masking key is deterministic, - # so the IV MUST be random to prevent ciphertext patterns. - # Without randomness, identical packets would produce - # identical masked headers, enabling traffic analysis. - masking_iv = Bytes16(os.urandom(CTR_IV_SIZE)) - - static_header = encode_static_header(flag, nonce, len(authdata)) - header = static_header + authdata - - # Header masking hides protocol metadata from observers. - # - # The masking key is derived from the destination node ID. - # Only the intended recipient can unmask the header. - # This provides privacy without requiring key exchange. - masking_key = Bytes16(dest_node_id[:AES_KEY_SIZE]) - masked_header = aes_ctr_encrypt(masking_key, masking_iv, header) - - if flag == PacketFlag.WHOAREYOU: - # WHOAREYOU has no message payload. - encrypted_message = message - else: - if encryption_key is None: - raise ValueError("Encryption key required for non-WHOAREYOU packets") - - # Per spec: message-ad = masking-iv || header (plaintext). - # - # The AAD binds the plaintext header to the encrypted message. - # The recipient reconstructs this from the decoded header. - message_ad = bytes(masking_iv) + header - encrypted_message = aes_gcm_encrypt(encryption_key, Bytes12(nonce), message, message_ad) - - # Assemble packet. - packet = bytes(masking_iv) + masked_header + encrypted_message - - if len(packet) > MAX_PACKET_SIZE: - raise ValueError(f"Packet exceeds max size: {len(packet)} > {MAX_PACKET_SIZE}") - - return packet - - -def decode_packet_header(local_node_id: NodeId, data: bytes) -> tuple[PacketHeader, bytes, bytes]: - """ - Decode and unmask a Discovery v5 packet header. - - Args: - local_node_id: Our 32-byte node ID (for header unmasking). - data: Raw packet bytes. - - Returns: - Tuple of (header, message_bytes, message_ad). - message_ad is masking-iv || plaintext header, used as AAD for decryption. - - Raises: - ValueError: If packet is malformed. - """ - if len(data) < MIN_PACKET_SIZE: - raise ValueError(f"Packet too small: {len(data)} < {MIN_PACKET_SIZE}") - - # Extract masking IV. - masking_iv = Bytes16(data[:CTR_IV_SIZE]) - - # Unmask the static header to learn the authdata size, then unmask the rest. - # - # AES-CTR is a stream cipher: decrypting the first N bytes produces the same - # output regardless of how many bytes follow. We exploit this by first - # decrypting just the 23-byte static header to read authdata_size, then - # decrypting the full header (static + authdata) in a single pass. - # The second call recomputes the keystream from offset 0, so both passes - # produce identical plaintext for the overlapping bytes. - masking_key = Bytes16(local_node_id[:AES_KEY_SIZE]) - masked_data = data[CTR_IV_SIZE:] - - # First pass: decrypt static header to learn authdata size. - static_header = aes_ctr_decrypt(masking_key, masking_iv, masked_data[:STATIC_HEADER_SIZE]) - - protocol_id = static_header[:6] - if protocol_id != PROTOCOL_ID: - raise ValueError(f"Invalid protocol ID: {protocol_id!r}") - - version = struct.unpack(">H", static_header[6:8])[0] - if version != PROTOCOL_VERSION: - raise ValueError(f"Unsupported protocol version: {version}") - - flag = PacketFlag(static_header[8]) - nonce = Nonce(static_header[9:21]) - authdata_size = struct.unpack(">H", static_header[21:23])[0] - - header_end = CTR_IV_SIZE + STATIC_HEADER_SIZE + authdata_size - if len(data) < header_end: - raise ValueError(f"Packet truncated: need {header_end}, have {len(data)}") - - # Second pass: decrypt full header (static + authdata) from offset 0. - full_header = aes_ctr_decrypt( - masking_key, masking_iv, masked_data[: STATIC_HEADER_SIZE + authdata_size] - ) - authdata = full_header[STATIC_HEADER_SIZE:] - - # Message bytes are everything after the header. - message_bytes = data[header_end:] - - # Per spec: message-ad = masking-iv || header (plaintext). - message_ad = bytes(masking_iv) + full_header - - return PacketHeader(flag=flag, nonce=nonce, authdata=authdata), message_bytes, message_ad - - -def decode_message_authdata(authdata: bytes) -> MessageAuthdata: - """Decode MESSAGE packet authdata.""" - if len(authdata) != MESSAGE_AUTHDATA_SIZE: - raise ValueError(f"Invalid MESSAGE authdata size: {len(authdata)}") - return MessageAuthdata(src_id=NodeId(authdata)) - - -def decode_whoareyou_authdata(authdata: bytes) -> WhoAreYouAuthdata: - """Decode WHOAREYOU packet authdata.""" - if len(authdata) != WHOAREYOU_AUTHDATA_SIZE: - raise ValueError(f"Invalid WHOAREYOU authdata size: {len(authdata)}") - - id_nonce = IdNonce(authdata[:16]) - enr_seq = SeqNumber(struct.unpack(">Q", authdata[16:24])[0]) - - return WhoAreYouAuthdata(id_nonce=id_nonce, enr_seq=enr_seq) - - -def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: - """Decode HANDSHAKE packet authdata.""" - # Fixed header: src-id (32 bytes) + sig-size (1 byte) + eph-key-size (1 byte) = 34 bytes. - if len(authdata) < HANDSHAKE_HEADER_SIZE: - raise ValueError(f"Handshake authdata too small: {len(authdata)}") - - src_id = NodeId(authdata[:32]) - sig_size = authdata[32] - eph_key_size = authdata[33] - - # Variable fields follow the fixed header: signature + ephemeral key + optional ENR. - expected_min = HANDSHAKE_HEADER_SIZE + sig_size + eph_key_size - if len(authdata) < expected_min: - raise ValueError(f"Handshake authdata truncated: {len(authdata)} < {expected_min}") - - offset = HANDSHAKE_HEADER_SIZE - id_signature = Bytes64(authdata[offset : offset + sig_size]) - offset += sig_size - - eph_pubkey = Bytes33(authdata[offset : offset + eph_key_size]) - offset += eph_key_size - - # Remaining bytes are the RLP-encoded ENR, included when the recipient's - # known enr_seq was stale (signaled by WHOAREYOU.enr_seq < sender's seq). - record = authdata[offset:] if offset < len(authdata) else None - - return HandshakeAuthdata( - src_id=src_id, - sig_size=sig_size, - eph_key_size=eph_key_size, - id_signature=id_signature, - eph_pubkey=eph_pubkey, - record=record, - ) - - -def decrypt_message( - encryption_key: Bytes16, - nonce: Nonce, - ciphertext: bytes, - message_ad: bytes, -) -> bytes: - """ - Decrypt an encrypted message payload. - - Args: - encryption_key: 16-byte session key. - nonce: 12-byte nonce from packet header. - ciphertext: Encrypted message with GCM tag. - message_ad: Additional authenticated data (masking-iv || plaintext header). - - Returns: - Decrypted message plaintext. - """ - return aes_gcm_decrypt(encryption_key, Bytes12(nonce), ciphertext, message_ad) - - -def encode_message_authdata(src_id: NodeId) -> bytes: - """Encode MESSAGE packet authdata.""" - return bytes(src_id) - - -def encode_whoareyou_authdata(id_nonce: IdNonce, enr_seq: SeqNumber) -> bytes: - """Encode WHOAREYOU packet authdata.""" - return id_nonce + struct.pack(">Q", enr_seq) - - -def encode_handshake_authdata( - src_id: NodeId, - id_signature: Bytes64, - eph_pubkey: Bytes33, - record: bytes | None = None, -) -> bytes: - """ - Encode HANDSHAKE packet authdata. - - Args: - src_id: 32-byte source node ID. - id_signature: 64-byte ID nonce signature. - eph_pubkey: 33-byte compressed ephemeral public key. - record: Optional RLP-encoded ENR. - - Returns: - Encoded authdata bytes. - """ - authdata = src_id + bytes([len(id_signature), len(eph_pubkey)]) + id_signature + eph_pubkey - - if record is not None: - authdata += record - - return authdata - - -def encode_static_header(flag: PacketFlag, nonce: Nonce, authdata_size: int) -> bytes: - """Encode the 23-byte static header.""" - return ( - PROTOCOL_ID - + struct.pack(">H", PROTOCOL_VERSION) - + bytes([flag]) - + nonce - + struct.pack(">H", authdata_size) - ) diff --git a/src/lean_spec/subspecs/networking/discovery/routing.py b/src/lean_spec/subspecs/networking/discovery/routing.py deleted file mode 100644 index 9638e33f..00000000 --- a/src/lean_spec/subspecs/networking/discovery/routing.py +++ /dev/null @@ -1,409 +0,0 @@ -""" -Discovery v5 Routing Table - -Kademlia-style routing table for Node Discovery Protocol v5.1. - -Node Table Structure - -Nodes keep information about other nodes in their neighborhood. Neighbor nodes -are stored in a routing table consisting of 'k-buckets'. For each 0 <= i < 256, -every node keeps a k-bucket for nodes of logdistance(self, n) == i. - -The protocol uses k = 16, meaning every k-bucket contains up to 16 node entries. -Entries are sorted by time last seen: least-recently seen at head, most-recently -seen at tail. - -Distance Metric - -The 'distance' between two node IDs is the bitwise XOR of the IDs, interpreted -as a big-endian number: - - distance(n1, n2) = n1 XOR n2 - -The logarithmic distance (length of differing suffix in bits) is used for -bucket assignment: - - logdistance(n1, n2) = log2(distance(n1, n2)) - -Bucket Eviction Policy - -When a new node N1 is encountered, it can be inserted into the corresponding -bucket. - -- If the bucket contains less than k entries, N1 is simply added. - -- If the bucket already contains k entries, the liveness of the least recently seen -node N2 must be revalidated. If no reply is received from N2, it is considered dead, -removed, and N1 added to the front of the bucket. - -Liveness Verification - -Implementations should perform liveness checks asynchronously and occasionally -verify that a random node in a random bucket is live by sending PING. When -responding to FINDNODE, implementations must avoid relaying any nodes whose -liveness has not been verified. - -References: - - https://github.com/ethereum/devp2p/blob/master/discv5/discv5-theory.md - - Maymounkov & Mazieres, "Kademlia: A Peer-to-peer Information System", 2002 -""" - -from __future__ import annotations - -from collections.abc import Iterator -from dataclasses import dataclass, field - -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import ForkDigest, NodeId, SeqNumber - -from .config import BUCKET_COUNT, K_BUCKET_SIZE -from .messages import Distance - - -def xor_distance(a: NodeId, b: NodeId) -> int: - """ - Compute XOR distance between two node IDs. - - XOR distance is the fundamental metric in Kademlia networks: - - distance(n1, n2) = n1 XOR n2 - - Properties: - - Symmetric: d(a, b) == d(b, a) - - Identity: d(a, a) == 0 - - Triangle inequality: d(a, c) <= d(a, b) XOR d(b, c) - - Args: - a: First 32-byte node ID. - b: Second 32-byte node ID. - - Returns: - XOR of the two IDs as a big-endian integer (0 to 2^256 - 1). - """ - return int.from_bytes(a, "big") ^ int.from_bytes(b, "big") - - -def log2_distance(a: NodeId, b: NodeId) -> Distance: - """ - Compute log2 of XOR distance between two node IDs. - - Determines which k-bucket a node belongs to: - - logdistance(n1, n2) = log2(distance(n1, n2)) - - Equivalent to the bit position of the highest differing bit. - Used for bucket assignment in the routing table. - - Args: - a: First 32-byte node ID. - b: Second 32-byte node ID. - - Returns: - Log2 distance (0-256). Returns 0 for identical IDs. - """ - distance = xor_distance(a, b) - if distance == 0: - return Distance(0) - return Distance(distance.bit_length()) - - -@dataclass(slots=True) -class NodeEntry: - """ - Entry in the routing table representing a discovered node. - - Tracks node identity and liveness information for routing decisions. - Nodes should only be relayed in FINDNODE responses if verified is True. - """ - - node_id: NodeId - """32-byte node identifier derived from keccak256(pubkey).""" - - enr_seq: SeqNumber = field(default_factory=lambda: SeqNumber(0)) - """Last known ENR sequence number. Used to detect stale records.""" - - last_seen: float = 0.0 - """Unix timestamp of last successful contact.""" - - endpoint: str | None = None - """Network endpoint in 'ip:port' format.""" - - verified: bool = False - """True if node has responded to at least one PING. Required for relay.""" - - enr: ENR | None = None - """Full ENR record. Contains fork data for compatibility checks.""" - - -@dataclass(slots=True) -class KBucket: - """ - K-bucket holding nodes at a specific log2 distance range. - - Implements Kademlia's bucket semantics per the Discovery v5 spec: - - Fixed capacity of k = 16 nodes - - Least-recently seen at head, most-recently seen at tail - - New nodes added to tail, eviction candidates at head - - Eviction Policy - - When full, ping the head node (least-recently seen). - - If it responds, keep it and discard the new node. - - If it fails, evict it and add the new node. - - Replacement Cache - - Implementations should maintain a 'replacement cache' alongside each bucket. - This cache holds recently-seen nodes which would fall into the corresponding - bucket but cannot become a member because it is at capacity. Once a bucket - member becomes unresponsive, a replacement can be chosen from the cache. - """ - - nodes: list[NodeEntry] = field(default_factory=list) - """Ordered list of node entries. Head = oldest, tail = newest.""" - - @property - def is_full(self) -> bool: - """True if bucket has reached k = 16 capacity.""" - return len(self.nodes) >= K_BUCKET_SIZE - - @property - def is_empty(self) -> bool: - """True if bucket contains no nodes.""" - return len(self.nodes) == 0 - - def __len__(self) -> int: - """Number of nodes in this bucket.""" - return len(self.nodes) - - def __iter__(self) -> Iterator[NodeEntry]: - """Iterate over nodes from oldest to newest.""" - return iter(self.nodes) - - def contains(self, node_id: NodeId) -> bool: - """Check if node ID exists in this bucket.""" - return any(entry.node_id == node_id for entry in self.nodes) - - def get(self, node_id: NodeId) -> NodeEntry | None: - """Retrieve node entry by ID. Returns None if not found.""" - for entry in self.nodes: - if entry.node_id == node_id: - return entry - return None - - def add(self, entry: NodeEntry) -> bool: - """ - Add or update a node in the bucket. - - - If the node exists, moves it to the tail (most recent). - - If the bucket is full, returns False without adding. - - Note: Caller should implement eviction by pinging the head node - when this returns False. - - Args: - entry: Node entry to add. - - Returns: - - True if node was added or updated, - - False if bucket is full. - """ - for i, existing in enumerate(self.nodes): - if existing.node_id == entry.node_id: - self.nodes.pop(i) - self.nodes.append(entry) - return True - - if self.is_full: - return False - - self.nodes.append(entry) - return True - - def remove(self, node_id: NodeId) -> bool: - """ - Remove a node from the bucket. - - Args: - node_id: 32-byte node ID to remove. - - Returns: - - True if node was removed, - - False if not found. - """ - for i, entry in enumerate(self.nodes): - if entry.node_id == node_id: - self.nodes.pop(i) - return True - return False - - def head(self) -> NodeEntry | None: - """Get least-recently seen node (eviction candidate).""" - return self.nodes[0] if self.nodes else None - - def tail(self) -> NodeEntry | None: - """Get most-recently seen node.""" - return self.nodes[-1] if self.nodes else None - - -@dataclass(slots=True) -class RoutingTable: - """ - Kademlia routing table for Discovery v5. - - Organizes nodes into 256 k-buckets by XOR distance. - Bucket i contains nodes with log2(distance) == i + 1. - - Fork Filtering - - When local_fork_digest is set: - - - Only peers with matching fork_digest are accepted - - Prevents storing peers on incompatible forks - - Requires eth2 ENR data to be present - - Lookup Algorithm - - Locates the k closest nodes to a target ID: - - 1. Pick alpha (3) closest nodes from local table - 2. Send FINDNODE to each - 3. Add responses to routing table - 4. Repeat with next closest unqueried nodes - 5. Stop when k closest have been queried - - Table Maintenance - - - Track close neighbors - - Regularly refresh stale buckets - - Perform lookup for least-recently-refreshed bucket - """ - - local_id: NodeId - """This node's 32-byte identifier derived from keccak256(pubkey).""" - - buckets: list[KBucket] = field(default_factory=lambda: [KBucket() for _ in range(BUCKET_COUNT)]) - """256 k-buckets indexed by log2 distance minus one.""" - - local_fork_digest: ForkDigest | None = None - """Our fork_digest for filtering incompatible peers. None disables filtering.""" - - def bucket_index(self, node_id: NodeId) -> int: - """ - Get bucket index for a node ID. - - Bucket i contains nodes with log2(distance) == i + 1. - - Args: - node_id: 32-byte node ID to look up. - - Returns: - Bucket index (0-255). - """ - distance = log2_distance(self.local_id, node_id) - return max(0, int(distance) - 1) - - def get_bucket(self, node_id: NodeId) -> KBucket: - """Get the k-bucket containing nodes at this distance.""" - return self.buckets[self.bucket_index(node_id)] - - def is_fork_compatible(self, entry: NodeEntry) -> bool: - """ - Check if a node entry is fork-compatible. - - If local_fork_digest is set, the entry must have an ENR with - eth2 data containing the same fork_digest. - - Args: - entry: Node entry to check. - - Returns: - - True if compatible or filtering disabled, - - False if fork_digest mismatch or missing eth2 data. - """ - if self.local_fork_digest is None: - return True - - if entry.enr is None: - return False - - eth2_data = entry.enr.eth2_data - if eth2_data is None: - return False - - return eth2_data.fork_digest == self.local_fork_digest - - def add(self, entry: NodeEntry) -> bool: - """ - Add a node to the routing table. - - Rejects nodes that are on incompatible forks when fork filtering - is enabled (local_fork_digest is set). - - Args: - entry: Node entry to add. - - Returns: - - True if added/updated, - - False if bucket full, adding self, or fork incompatible. - """ - if entry.node_id == self.local_id: - return False - - if not self.is_fork_compatible(entry): - return False - - return self.get_bucket(entry.node_id).add(entry) - - def remove(self, node_id: NodeId) -> bool: - """Remove a node from the routing table.""" - return self.get_bucket(node_id).remove(node_id) - - def get(self, node_id: NodeId) -> NodeEntry | None: - """Get a node entry by ID. Returns None if not found.""" - return self.get_bucket(node_id).get(node_id) - - def contains(self, node_id: NodeId) -> bool: - """Check if a node ID exists in the routing table.""" - return self.get(node_id) is not None - - def node_count(self) -> int: - """Total number of nodes across all buckets.""" - return sum(len(bucket) for bucket in self.buckets) - - def closest_nodes(self, target: NodeId, count: int) -> list[NodeEntry]: - """ - Find the closest nodes to a target ID. - - Used during Kademlia lookup to iteratively approach the target. - The lookup initiator picks alpha closest nodes and sends FINDNODE - requests, progressively querying closer nodes. - - Args: - target: Target 32-byte node ID. - count: Maximum nodes to return (typically k = 16). - - Returns: - Nodes sorted by XOR distance to target, closest first. - """ - all_nodes = [entry for bucket in self.buckets for entry in bucket] - all_nodes.sort(key=lambda e: xor_distance(e.node_id, target)) - return all_nodes[:count] - - def nodes_at_distance(self, distance: Distance) -> list[NodeEntry]: - """ - Get all nodes at a specific log2 distance. - - Used to respond to FINDNODE requests. The recipient returns nodes - from its routing table at the requested distance. - - Args: - distance: Log2 distance (1-256). Distance 0 returns own ENR. - - Returns: - List of nodes at the specified distance. - """ - dist_int = int(distance) - if dist_int < 1 or dist_int > BUCKET_COUNT: - return [] - return list(self.buckets[dist_int - 1]) diff --git a/src/lean_spec/subspecs/networking/discovery/service.py b/src/lean_spec/subspecs/networking/discovery/service.py deleted file mode 100644 index 762d8fb9..00000000 --- a/src/lean_spec/subspecs/networking/discovery/service.py +++ /dev/null @@ -1,704 +0,0 @@ -""" -Discovery v5 service. - -Main entry point for peer discovery over UDP. - -Service Responsibilities: -- Bootstrap from known bootnodes -- Maintain routing table with discovered peers -- Perform periodic lookups to find new peers -- Handle incoming discovery requests -- Provide peers to higher-layer protocols - -Lookup Algorithm: -1. Start with alpha closest nodes from routing table -2. Send FINDNODE to each, collecting responses -3. Add newly discovered nodes to routing table -4. Repeat with next closest unqueried nodes -5. Stop when k closest nodes have been queried - -References: -- https://github.com/ethereum/devp2p/blob/master/discv5/discv5-theory.md -""" - -from __future__ import annotations - -import asyncio -import ipaddress -import logging -import os -import random -from collections.abc import Callable -from dataclasses import dataclass -from typing import Final - -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import NodeId, Port, SeqNumber -from lean_spec.types import Bytes32, Bytes33, Uint8 - -from .codec import DiscoveryMessage -from .config import ALPHA, K_BUCKET_SIZE, DiscoveryConfig -from .keys import compute_node_id -from .messages import Distance, FindNode, IPv4, IPv6, Nodes, Ping, Pong, TalkReq, TalkResp -from .routing import NodeEntry, RoutingTable, log2_distance, xor_distance -from .session import BondCache -from .transport import DiscoveryTransport - -logger = logging.getLogger(__name__) - -LOOKUP_PARALLELISM: Final = ALPHA -"""Number of concurrent FINDNODE queries during lookup.""" - -REFRESH_INTERVAL_SECS: Final = 3600 -"""Interval between routing table refresh lookups (1 hour).""" - -REVALIDATION_INTERVAL_SECS: Final = 300 -"""Interval between node liveness revalidation (5 minutes).""" - - -@dataclass(slots=True) -class LookupResult: - """Result of a node lookup operation.""" - - target: NodeId - """Target node ID that was searched for.""" - - nodes: list[NodeEntry] - """Nodes found, sorted by distance to target.""" - - queried: int - """Number of nodes queried during lookup.""" - - -class DiscoveryService: - """ - Main Discovery v5 service. - - Provides high-level peer discovery functionality: - - Lookup nodes close to a target ID - - Get random peers from routing table - - Perform periodic table refresh and node revalidation - - Background tasks handle table refresh, liveness checks, and session cleanup. - - Args: - local_enr: Our ENR. - private_key: Our 32-byte secp256k1 private key. - config: Optional protocol configuration. - bootnodes: Initial nodes to connect to. - """ - - def __init__( - self, - local_enr: ENR, - private_key: Bytes32, - config: DiscoveryConfig | None = None, - bootnodes: list[ENR] | None = None, - ): - """Initialize discovery service.""" - self._local_enr = local_enr - self._private_key = private_key - self._config = config or DiscoveryConfig() - self._bootnodes = bootnodes or [] - - # Compute our node ID from public key. - if local_enr.public_key is None: - raise ValueError("Local ENR must have a public key") - self._local_node_id = NodeId(compute_node_id(Bytes33(local_enr.public_key))) - - # Initialize routing table. - self._routing_table = RoutingTable(local_id=NodeId(self._local_node_id)) - - # Initialize transport. - self._transport = DiscoveryTransport( - local_node_id=self._local_node_id, - local_private_key=private_key, - local_enr=local_enr, - config=self._config, - ) - - # Bond tracking. - self._bond_cache = BondCache() - - # Background tasks. - self._tasks: list[asyncio.Task] = [] - self._running = False - - # TALKREQ handlers by protocol. - self._talk_handlers: dict[bytes, Callable[[bytes, bytes], bytes]] = {} - - # Set up message handler. - self._transport.set_message_handler(self._handle_message) - - async def start(self, host: str = "0.0.0.0", port: int = 9000) -> None: - """ - Start the discovery service. - - Args: - host: IP address to bind to. - port: UDP port to bind to. - """ - if self._running: - return - - # Start transport. - await self._transport.start(host, port) - self._running = True - - # Bootstrap from bootnodes. - await self._bootstrap() - - # Start background tasks. - self._tasks.append(asyncio.create_task(self._refresh_loop())) - self._tasks.append(asyncio.create_task(self._revalidation_loop())) - self._tasks.append(asyncio.create_task(self._cleanup_loop())) - - logger.info( - "Discovery service started on %s:%d with node ID %s", - host, - port, - self._local_node_id.hex()[:16], - ) - - async def stop(self) -> None: - """Stop the discovery service.""" - if not self._running: - return - - self._running = False - - # Cancel background tasks. - for task in self._tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self._tasks.clear() - - # Stop transport. - await self._transport.stop() - - logger.info("Discovery service stopped") - - async def find_node(self, target: NodeId) -> LookupResult: - """ - Perform a Kademlia lookup for a target node ID. - - Iteratively queries nodes progressively closer to the target. - - Args: - target: 32-byte target node ID. - - Returns: - LookupResult with found nodes sorted by distance. - """ - if len(target) != 32: - raise ValueError(f"Target must be 32 bytes, got {len(target)}") - - # Start with closest known nodes. - closest = self._routing_table.closest_nodes(target, K_BUCKET_SIZE) - if not closest: - return LookupResult(target=target, nodes=[], queried=0) - - queried: set[NodeId] = set() - seen: dict[NodeId, NodeEntry] = {entry.node_id: entry for entry in closest} - - while True: - # Find unqueried nodes closest to target. - candidates = sorted( - [e for e in seen.values() if e.node_id not in queried], - key=lambda e: xor_distance(e.node_id, target), - )[:LOOKUP_PARALLELISM] - - if not candidates: - break - - # Query candidates in parallel. - tasks = [] - for entry in candidates: - queried.add(entry.node_id) - addr = self._transport.get_node_address(entry.node_id) - if addr is not None: - tasks.append(self._query_node(entry.node_id, addr, target)) - - if tasks: - results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, tuple): - enr_list, queried_id, distances = result - for enr_bytes in enr_list: - self._process_discovered_enr(enr_bytes, seen, queried_id, distances) - - # Sort by distance to target. - result_nodes = sorted( - seen.values(), - key=lambda e: xor_distance(e.node_id, target), - )[:K_BUCKET_SIZE] - - return LookupResult( - target=target, - nodes=result_nodes, - queried=len(queried), - ) - - def get_random_nodes(self, count: int = K_BUCKET_SIZE) -> list[NodeEntry]: - """ - Get random nodes from the routing table. - - Useful for providing peers to connection manager. - - Args: - count: Maximum nodes to return. - - Returns: - List of random node entries. - """ - all_nodes = [] - for bucket in self._routing_table.buckets: - all_nodes.extend(bucket.nodes) - - if len(all_nodes) <= count: - return all_nodes - - return random.sample(all_nodes, count) - - def get_nodes_at_distance(self, distance: int) -> list[NodeEntry]: - """ - Get nodes at a specific log2 distance. - - Args: - distance: Log2 distance (1-256). - - Returns: - Nodes at that distance from our node ID. - """ - return self._routing_table.nodes_at_distance(Distance(distance)) - - def node_count(self) -> int: - """Return total number of nodes in routing table.""" - return self._routing_table.node_count() - - def register_talk_handler( - self, - protocol: bytes, - handler: Callable[[bytes, bytes], bytes], - ) -> None: - """ - Register a handler for TALKREQ messages. - - Args: - protocol: Protocol identifier (e.g., b"eth2"). - handler: Function(node_id, request) -> response. - """ - self._talk_handlers[protocol] = handler - - async def send_talk_request( - self, - node_id: NodeId, - protocol: bytes, - request: bytes, - ) -> bytes | None: - """ - Send a TALKREQ to a node. - - Args: - node_id: 32-byte destination node ID. - protocol: Protocol identifier. - request: Protocol-specific request. - - Returns: - Response payload or None on timeout. - """ - addr = self._transport.get_node_address(node_id) - if addr is None: - return None - - return await self._transport.send_talkreq(node_id, addr, protocol, request) - - async def _bootstrap(self) -> None: - """Bootstrap from bootnodes.""" - for enr in self._bootnodes: - try: - node_id = enr.compute_node_id() - if node_id is None: - continue - - # Register address and ENR. - # - # The transport needs the ENR to complete handshakes. - # When we PING and receive WHOAREYOU, the transport looks up - # the remote's public key from its ENR cache. - if enr.ip4 and enr.udp_port: - addr = (enr.ip4, int(enr.udp_port)) - self._transport.register_node_address(node_id, addr) - self._transport.register_enr(node_id, enr) - - # Add to routing table. - entry = self._enr_to_entry(enr) - self._routing_table.add(entry) - - # Ping to establish bond. - asyncio.create_task(self._ping_node(node_id, addr)) - - except Exception as e: - logger.debug("Failed to add bootnode: %s", e) - - async def _query_node( - self, - node_id: NodeId, - addr: tuple[str, int], - target: NodeId, - ) -> tuple[list[bytes], NodeId, list[int]]: - """Query a node for nodes close to target. - - Returns: - Tuple of (enr_bytes_list, queried_node_id, requested_distances). - """ - distance = int(log2_distance(node_id, target)) - distances = [distance] if distance > 0 else [1, 2, 3] - - enrs = await self._transport.send_findnode(node_id, addr, distances) - return enrs, node_id, distances - - async def _ping_node(self, node_id: NodeId, addr: tuple[str, int]) -> bool: - """Ping a node and update bond status.""" - pong = await self._transport.send_ping(node_id, addr) - if pong is not None: - self._bond_cache.add_bond(node_id) - return True - return False - - def _handle_message( - self, - remote_node_id: NodeId, - message: DiscoveryMessage, - addr: tuple[str, int], - ) -> None: - """Dispatch an incoming message to async processing.""" - asyncio.create_task(self._process_message(remote_node_id, message, addr)) - - async def _process_message( - self, - remote_node_id: NodeId, - message: DiscoveryMessage, - addr: tuple[str, int], - ) -> None: - """Route a decoded message to its type-specific handler.""" - # Update node address. - self._transport.register_node_address(remote_node_id, addr) - - match message: - case Ping(): - await self._handle_ping(remote_node_id, message, addr) - case FindNode(): - await self._handle_findnode(remote_node_id, message, addr) - case TalkReq(): - await self._handle_talkreq(remote_node_id, message, addr) - - async def _handle_ping( - self, - remote_node_id: NodeId, - ping: Ping, - addr: tuple[str, int], - ) -> None: - """ - Respond to a PING with a PONG message. - - PING serves two purposes in Discovery v5: - - 1. Liveness check - verifies the node is reachable - 2. ENR exchange - allows nodes to learn each other's current ENR sequence - - The PONG response includes: - - - Our ENR sequence (so they can request updated ENR if needed) - - Recipient endpoint (so they learn their external IP/port) - """ - # Build PONG with our ENR sequence and their observed endpoint. - # - # The recipient_ip/port tells the sender what address we see them as. - # This helps nodes behind NAT discover their public endpoint. - # - # Per spec, recipient_ip is raw bytes: 4 bytes for IPv4, 16 for IPv6. - recipient_ip = self._encode_ip_address(addr[0]) - pong = Pong( - request_id=ping.request_id, - enr_seq=SeqNumber(self._local_enr.seq), - recipient_ip=recipient_ip, - recipient_port=Port(addr[1]), - ) - - # Send the response using the established session. - sent = await self._transport.send_response(remote_node_id, addr, pong) - - if sent: - # Successful PONG establishes mutual liveness. - # - # The remote proved they can reach us (by sending PING). - # Our successful response proves we can reach them. - # Mark them as bonded to allow future FINDNODE queries. - self._bond_cache.add_bond(remote_node_id) - - logger.debug("Received PING from %s, sent PONG: %s", remote_node_id.hex()[:16], sent) - - async def _handle_findnode( - self, - remote_node_id: NodeId, - findnode: FindNode, - addr: tuple[str, int], - ) -> None: - """ - Respond to a FINDNODE with a NODES message. - - FINDNODE is the core lookup operation in Kademlia. - The requester specifies log2 distances, and we return nodes - from those buckets in our routing table. - - Security: Only bonded nodes can query our routing table. - This prevents amplification attacks where an attacker uses - us to flood a victim with NODES responses. - """ - # Require prior bonding before sharing routing table. - # - # Bonding means we have exchanged PING/PONG. - # This prevents using our node as a reflector for amplification attacks. - if not self._bond_cache.is_bonded(remote_node_id): - logger.debug("FINDNODE from unbonded node %s", remote_node_id.hex()[:16]) - return - - # Collect ENRs from requested distance buckets. - # - # Distance 0 is special: it means "return your own ENR". - # Distances 1-256 correspond to routing table buckets. - enrs: list[bytes] = [] - for distance in findnode.distances: - if int(distance) == 0: - enrs.append(self._local_enr.to_rlp()) - else: - for entry in self._routing_table.nodes_at_distance(distance): - if entry.enr is not None: - enrs.append(entry.enr.to_rlp()) - - # Limit response size to prevent oversized packets. - enrs = enrs[: self._config.max_nodes_response] - - # Build NODES response. - # - # The 'total' field indicates how many NODES messages to expect. - # For simplicity, we send all results in one message. - # Production implementations may split across multiple messages. - nodes = Nodes( - request_id=findnode.request_id, - total=Uint8(1), - enrs=enrs, - ) - - sent = await self._transport.send_response(remote_node_id, addr, nodes) - logger.debug( - "Received FINDNODE from %s for distances %s, sent %d ENRs: %s", - remote_node_id.hex()[:16], - [int(d) for d in findnode.distances], - len(enrs), - sent, - ) - - async def _handle_talkreq( - self, - remote_node_id: NodeId, - talkreq: TalkReq, - addr: tuple[str, int], - ) -> None: - """ - Handle a TALKREQ by delegating to the registered protocol handler. - - TALKREQ enables application-specific protocols over Discovery v5. - The protocol field identifies which handler should process the request. - - Common protocols built on TALKREQ: - - - Portal Network (state, history, beacon) - - Light client sync - - Custom peer-to-peer applications - - Unknown protocols receive an empty response (not an error). - This allows graceful handling when protocols are not supported. - """ - # Look up the handler for this protocol. - handler = self._talk_handlers.get(talkreq.protocol) - - # Dispatch to handler or return empty response. - # - # Empty response for unknown protocols is per spec. - # This avoids revealing which protocols we support - # while still allowing the requester to complete their flow. - response_data = b"" - if handler is not None: - try: - response_data = handler(remote_node_id, talkreq.request) - except Exception as e: - logger.debug("TALKREQ handler error: %s", e) - - # Build and send TALKRESP. - talkresp = TalkResp( - request_id=talkreq.request_id, - response=response_data, - ) - - sent = await self._transport.send_response(remote_node_id, addr, talkresp) - logger.debug("Received TALKREQ for protocol %s, sent response: %s", talkreq.protocol, sent) - - async def _refresh_loop(self) -> None: - """Periodically refresh routing table.""" - while self._running: - await asyncio.sleep(REFRESH_INTERVAL_SECS) - try: - # Perform lookup for random target. - target = NodeId(os.urandom(32)) - await self.find_node(target) - except Exception as e: - logger.debug("Refresh failed: %s", e) - - async def _revalidation_loop(self) -> None: - """Periodically revalidate nodes.""" - while self._running: - await asyncio.sleep(REVALIDATION_INTERVAL_SECS) - try: - # Pick a random node to revalidate. - all_nodes = self.get_random_nodes(1) - if all_nodes: - entry = all_nodes[0] - addr = self._transport.get_node_address(entry.node_id) - if addr is not None: - success = await self._ping_node(entry.node_id, addr) - if not success: - self._routing_table.remove(entry.node_id) - except Exception as e: - logger.debug("Revalidation failed: %s", e) - - async def _cleanup_loop(self) -> None: - """Periodically clean up expired state.""" - while self._running: - await asyncio.sleep(60) - self._bond_cache.cleanup_expired() - - def _encode_ip_address(self, ip_str: str) -> IPv4 | IPv6: - """ - Encode an IP address string to raw bytes. - - Per Discovery v5 spec, IP addresses in PONG are raw bytes: - - IPv4: 4 bytes - - IPv6: 16 bytes - - Args: - ip_str: IP address as dotted string (IPv4) or colon-separated hex (IPv6). - - Returns: - Raw bytes representation of the IP address. - """ - packed = ipaddress.ip_address(ip_str).packed - if len(packed) == 4: - return IPv4(packed) - return IPv6(packed) - - def _enr_to_entry(self, enr: ENR) -> NodeEntry: - """Convert an ENR to a NodeEntry.""" - node_id = enr.compute_node_id() - if node_id is None: - raise ValueError("ENR has no valid node ID") - - endpoint = None - if enr.ip4 and enr.udp_port: - endpoint = f"{enr.ip4}:{enr.udp_port}" - - return NodeEntry( - node_id=node_id, - enr_seq=SeqNumber(enr.seq), - endpoint=endpoint, - enr=enr, - ) - - def _process_discovered_enr( - self, - enr_bytes: bytes, - seen: dict[NodeId, NodeEntry], - queried_node_id: NodeId | None = None, - requested_distances: list[int] | None = None, - ) -> None: - """ - Parse and process a discovered ENR from NODES response. - - Parses the RLP-encoded ENR, validates it, and adds to: - - The routing table (for future lookups) - - The seen dict (for current lookup tracking) - - The transport ENR cache (for handshake verification) - - The address registry (for UDP communication) - - Verifies that returned nodes match the requested distances when provided. - This prevents routing table poisoning from malicious peers. - - Args: - enr_bytes: RLP-encoded ENR bytes from NODES response. - seen: Dict tracking nodes seen during current lookup. - queried_node_id: Node ID of the peer that returned this ENR. - requested_distances: Distances requested in the FINDNODE query. - """ - try: - # Parse ENR from RLP. - enr = ENR.from_rlp(enr_bytes) - - # Validate the ENR has required fields. - if not enr.is_valid(): - logger.debug("Invalid ENR: missing required fields") - return - - node_id = enr.compute_node_id() - if node_id is None: - logger.debug("ENR has no valid node ID") - return - - # Verify the returned node matches the requested distances. - # - # Per spec, recipients should verify returned nodes match requested - # distances. This prevents routing table poisoning from malicious peers. - if queried_node_id is not None and requested_distances is not None: - enr_dist = int(log2_distance(node_id, queried_node_id)) - if enr_dist not in requested_distances: - logger.debug( - "Dropping ENR: distance %d not in requested %s", - enr_dist, - requested_distances, - ) - return - - # Skip if this is our own ENR. - if node_id == self._local_node_id: - return - - # Skip if already seen in this lookup. - if node_id in seen: - return - - # Create routing table entry. - entry = self._enr_to_entry(enr) - - # Add to seen dict for lookup tracking. - seen[node_id] = entry - - # Add to routing table for future lookups. - self._routing_table.add(entry) - - # Cache ENR for handshake verification. - self._transport.register_enr(node_id, enr) - - # Register address for communication. - if enr.ip4 and enr.udp_port: - addr = (enr.ip4, int(enr.udp_port)) - self._transport.register_node_address(node_id, addr) - - logger.debug("Discovered node %s via NODES", node_id.hex()[:16]) - - except ValueError as e: - logger.debug("Failed to parse ENR: %s", e) - except Exception as e: - logger.debug("Error processing discovered ENR: %s", e) diff --git a/src/lean_spec/subspecs/networking/discovery/session.py b/src/lean_spec/subspecs/networking/discovery/session.py deleted file mode 100644 index 357a524a..00000000 --- a/src/lean_spec/subspecs/networking/discovery/session.py +++ /dev/null @@ -1,311 +0,0 @@ -""" -Session management for Discovery v5. - -A session represents an established cryptographic channel with a peer. -Sessions are created after successful WHOAREYOU/HANDSHAKE exchange. - -Session Lifecycle: -1. Initiator sends encrypted MESSAGE to recipient -2. Recipient can't decrypt (no session), sends WHOAREYOU -3. Initiator responds with HANDSHAKE containing ECDH ephemeral key -4. Both parties derive shared keys, session is established -5. Subsequent messages use session keys - -Sessions expire after a timeout and must be re-established. -The spec recommends 24 hours, but implementations often use shorter durations. - -References: -- https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md#session-cache -""" - -from __future__ import annotations - -import time -from dataclasses import dataclass, field -from threading import Lock -from typing import Final, NamedTuple - -from lean_spec.subspecs.networking.types import NodeId, Port -from lean_spec.types import Bytes16 - -from .config import BOND_EXPIRY_SECS, DEFAULT_PORT - -DEFAULT_SESSION_TIMEOUT_SECS: Final = 86400 -"""Default session timeout (24 hours).""" - -MAX_SESSIONS: Final = 1000 -"""Maximum number of cached sessions to prevent memory exhaustion.""" - - -@dataclass(slots=True) -class Session: - """ - Active session with a peer. - - Stores the symmetric keys derived during handshake. - Keys are directional: we use different keys for send vs receive. - """ - - node_id: NodeId - """Peer's 32-byte node ID.""" - - send_key: Bytes16 - """16-byte key for encrypting messages to this peer.""" - - recv_key: Bytes16 - """16-byte key for decrypting messages from this peer.""" - - created_at: float - """Unix timestamp when session was established.""" - - last_seen: float - """Unix timestamp of last successful message exchange.""" - - is_initiator: bool - """True if we initiated the handshake.""" - - def is_expired(self, timeout_secs: float = DEFAULT_SESSION_TIMEOUT_SECS) -> bool: - """Check if session has expired.""" - return time.time() - self.created_at > timeout_secs - - def touch(self) -> None: - """Update last_seen timestamp.""" - self.last_seen = time.time() - - -class SessionKey(NamedTuple): - """Session cache key: (node_id, ip, port). - - Per spec, sessions are tied to a specific UDP endpoint. - This prevents session confusion if a node changes IP or port. - """ - - node_id: NodeId - ip: str - port: Port - - -@dataclass -class SessionCache: - """ - Cache of active sessions with peers. - - Thread-safe session storage with automatic expiration cleanup. - Sessions are keyed by (node_id, ip, port) per spec requirement - that sessions are tied to a specific UDP endpoint. - """ - - sessions: dict[SessionKey, Session] = field(default_factory=dict) - """(node_id, ip, port) -> Session mapping.""" - - timeout_secs: float = DEFAULT_SESSION_TIMEOUT_SECS - """Session expiration timeout.""" - - max_sessions: int = MAX_SESSIONS - """Maximum cached sessions.""" - - _lock: Lock = field(default_factory=Lock) - """Thread safety lock.""" - - def get(self, node_id: NodeId, ip: str = "", port: Port = DEFAULT_PORT) -> Session | None: - """ - Get an active session for a node at a specific endpoint. - - Returns None if no session exists or if it has expired. - - Args: - node_id: 32-byte peer node ID. - ip: Peer IP address. - port: Peer UDP port. - - Returns: - Active session or None. - """ - key = SessionKey(node_id, ip, port) - with self._lock: - session = self.sessions.get(key) - if session is None: - return None - - if session.is_expired(self.timeout_secs): - del self.sessions[key] - return None - - return session - - def create( - self, - node_id: NodeId, - send_key: Bytes16, - recv_key: Bytes16, - is_initiator: bool, - ip: str = "", - port: Port = DEFAULT_PORT, - ) -> Session: - """ - Create and store a new session. - - If a session already exists for this endpoint, it is replaced. - If the cache is full, the oldest session is evicted. - - Args: - node_id: 32-byte peer node ID. - send_key: 16-byte encryption key for outgoing messages. - recv_key: 16-byte decryption key for incoming messages. - is_initiator: True if we initiated the handshake. - ip: Peer IP address. - port: Peer UDP port. - - Returns: - The newly created session. - """ - key = SessionKey(node_id, ip, port) - now = time.time() - session = Session( - node_id=node_id, - send_key=send_key, - recv_key=recv_key, - created_at=now, - last_seen=now, - is_initiator=is_initiator, - ) - - with self._lock: - # Evict oldest if at capacity. - if len(self.sessions) >= self.max_sessions and key not in self.sessions: - self._evict_oldest() - - self.sessions[key] = session - - return session - - def remove(self, node_id: NodeId, ip: str = "", port: Port = DEFAULT_PORT) -> bool: - """ - Remove a session. - - Args: - node_id: 32-byte peer node ID. - ip: Peer IP address. - port: Peer UDP port. - - Returns: - True if session was removed, False if not found. - """ - key = SessionKey(node_id, ip, port) - with self._lock: - if key in self.sessions: - del self.sessions[key] - return True - return False - - def touch(self, node_id: NodeId, ip: str = "", port: Port = DEFAULT_PORT) -> bool: - """ - Update the last_seen timestamp for a session. - - Holds the lock across lookup and mutation to prevent a concurrent - thread from evicting the session between the two operations. - - Args: - node_id: 32-byte peer node ID. - ip: Peer IP address. - port: Peer UDP port. - - Returns: - True if session was updated, False if not found. - """ - key = SessionKey(node_id, ip, port) - with self._lock: - session = self.sessions.get(key) - if session is not None and not session.is_expired(self.timeout_secs): - session.touch() - return True - return False - - def cleanup_expired(self) -> int: - """ - Remove all expired sessions. - - Returns: - Number of sessions removed. - """ - with self._lock: - expired = [ - key - for key, session in self.sessions.items() - if session.is_expired(self.timeout_secs) - ] - for key in expired: - del self.sessions[key] - return len(expired) - - def count(self) -> int: - """Return number of active sessions.""" - with self._lock: - return len(self.sessions) - - def _evict_oldest(self) -> None: - """Evict the least recently used session. Must be called with lock held.""" - if not self.sessions: - return - - oldest_key = min(self.sessions, key=lambda k: self.sessions[k].last_seen) - del self.sessions[oldest_key] - - -@dataclass(slots=True) -class BondCache: - """ - Cache tracking which nodes we have successfully bonded with. - - A node is considered "bonded" after a successful PING/PONG exchange. - Bonded nodes can be included in FINDNODE responses. - - This is separate from sessions because a bond can persist - even if the session expires. - """ - - bonds: dict[NodeId, float] = field(default_factory=dict) - """Node ID -> timestamp of last successful PONG.""" - - expiry_secs: float = BOND_EXPIRY_SECS - """Bond expiration timeout (default 24 hours).""" - - _lock: Lock = field(default_factory=Lock) - """Thread safety lock.""" - - def is_bonded(self, node_id: NodeId) -> bool: - """Check if we have a valid bond with a node.""" - with self._lock: - timestamp = self.bonds.get(node_id) - if timestamp is None: - return False - if time.time() - timestamp > self.expiry_secs: - del self.bonds[node_id] - return False - return True - - def add_bond(self, node_id: NodeId) -> None: - """Record a successful bond with a node.""" - with self._lock: - self.bonds[node_id] = time.time() - - def remove_bond(self, node_id: NodeId) -> bool: - """Remove a bond.""" - with self._lock: - if node_id in self.bonds: - del self.bonds[node_id] - return True - return False - - def cleanup_expired(self) -> int: - """Remove expired bonds.""" - now = time.time() - with self._lock: - expired = [ - node_id - for node_id, timestamp in self.bonds.items() - if now - timestamp > self.expiry_secs - ] - for node_id in expired: - del self.bonds[node_id] - return len(expired) diff --git a/src/lean_spec/subspecs/networking/discovery/transport.py b/src/lean_spec/subspecs/networking/discovery/transport.py deleted file mode 100644 index b36951f8..00000000 --- a/src/lean_spec/subspecs/networking/discovery/transport.py +++ /dev/null @@ -1,903 +0,0 @@ -""" -UDP transport for Discovery v5. - -Provides async UDP send/receive with packet encoding/decoding. - -Transport Responsibilities: -- Bind to UDP socket -- Send/receive raw packets -- Route incoming packets to appropriate handlers -- Manage pending requests and timeouts - -References: -- https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire.md -""" - -from __future__ import annotations - -import asyncio -import logging -import os -from collections.abc import Callable -from dataclasses import dataclass - -from cryptography.exceptions import InvalidTag - -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes16, Bytes32, Bytes33 - -from .codec import ( - DiscoveryMessage, - MessageDecodingError, - decode_message, - encode_message, -) -from .config import DiscoveryConfig -from .handshake import HandshakeError, HandshakeManager -from .messages import ( - Distance, - FindNode, - Nodes, - Nonce, - PacketFlag, - Ping, - Pong, - Port, - RequestId, - TalkReq, - TalkResp, -) -from .packet import ( - PacketHeader, - decode_handshake_authdata, - decode_message_authdata, - decode_packet_header, - decode_whoareyou_authdata, - decrypt_message, - encode_message_authdata, - encode_packet, - encode_static_header, -) -from .session import SessionCache - -logger = logging.getLogger(__name__) - - -@dataclass(slots=True) -class PendingRequest: - """Tracks a pending request awaiting response.""" - - request_id: RequestId - """Request ID for matching responses.""" - - dest_node_id: NodeId - """Destination node ID.""" - - sent_at: float - """Timestamp when request was sent.""" - - nonce: Nonce - """Packet nonce (needed for WHOAREYOU handling).""" - - message: DiscoveryMessage - """Original message (for retransmission after handshake).""" - - future: asyncio.Future[DiscoveryMessage | None] - """Future to complete when response arrives.""" - - -@dataclass(slots=True) -class PendingMultiRequest: - """Tracks a pending request that may receive multiple responses. - - Used for FINDNODE which can return multiple NODES messages split - across UDP packets when results exceed MTU. - """ - - request_id: RequestId - """Request ID for matching responses.""" - - dest_node_id: NodeId - """Destination node ID.""" - - sent_at: float - """Timestamp when request was sent.""" - - nonce: Nonce - """Packet nonce (needed for WHOAREYOU handling).""" - - message: DiscoveryMessage - """Original message (for retransmission after handshake).""" - - response_queue: asyncio.Queue[DiscoveryMessage] - """Queue to collect multiple responses.""" - - expected_total: int | None - """Expected number of responses (from first NODES.total field).""" - - received_count: int - """Number of responses received so far.""" - - -class DiscoveryProtocol(asyncio.DatagramProtocol): - """ - Async UDP protocol handler for Discovery v5. - - Args: - transport_handler: Parent transport for packet handling. - """ - - def __init__(self, transport_handler: DiscoveryTransport) -> None: - """Initialize protocol handler.""" - self._handler = transport_handler - self._transport: asyncio.DatagramTransport | None = None - - def connection_made(self, transport: asyncio.transports.BaseTransport) -> None: - """Called when UDP socket is ready.""" - assert isinstance(transport, asyncio.DatagramTransport) - self._transport = transport - - def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: - """Called when a UDP packet is received.""" - asyncio.create_task(self._handler._handle_packet(data, addr)) - - def error_received(self, exc: Exception) -> None: - """Called when a send/receive error occurs.""" - logger.warning("UDP error: %s", exc) - - def connection_lost(self, exc: Exception | None) -> None: - """Called when the socket is closed.""" - if exc is not None: - logger.warning("UDP connection lost: %s", exc) - - -class DiscoveryTransport: - """ - UDP transport for Discovery v5. - - Handles all wire protocol operations: - - Packet encoding/decoding - - Session management - - Handshake orchestration - - Request/response matching - - Args: - local_node_id: Our 32-byte node ID. - local_private_key: Our 32-byte secp256k1 private key. - local_enr: Our ENR. - config: Optional protocol configuration. - """ - - def __init__( - self, - local_node_id: NodeId, - local_private_key: Bytes32, - local_enr: ENR, - config: DiscoveryConfig | None = None, - ): - """Initialize discovery transport.""" - self._local_node_id = local_node_id - self._local_private_key = local_private_key - self._local_enr = local_enr - self._config = config or DiscoveryConfig() - - self._session_cache = SessionCache() - self._handshake_manager = HandshakeManager( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr_rlp=local_enr.to_rlp(), - local_enr_seq=SeqNumber(local_enr.seq), - session_cache=self._session_cache, - ) - - self._protocol: DiscoveryProtocol | None = None - self._transport: asyncio.DatagramTransport | None = None - self._pending_requests: dict[RequestId, PendingRequest] = {} - self._pending_multi_requests: dict[RequestId, PendingMultiRequest] = {} - self._node_addresses: dict[NodeId, tuple[str, int]] = {} - - self._message_handler: ( - Callable[[NodeId, DiscoveryMessage, tuple[str, int]], None] | None - ) = None - - self._running = False - - async def start(self, host: str = "0.0.0.0", port: int = 9000) -> None: - """ - Start listening for UDP packets. - - Args: - host: IP address to bind to. - port: UDP port to bind to. - """ - if self._running: - return - - loop = asyncio.get_running_loop() - - transport, protocol = await loop.create_datagram_endpoint( - lambda: DiscoveryProtocol(self), - local_addr=(host, port), - ) - assert isinstance(transport, asyncio.DatagramTransport) - assert isinstance(protocol, DiscoveryProtocol) - self._transport = transport - self._protocol = protocol - - self._running = True - logger.info("Discovery transport started on %s:%d", host, port) - - async def stop(self) -> None: - """Stop the transport.""" - if not self._running: - return - - self._running = False - - if self._transport is not None: - self._transport.close() - self._transport = None - - # Cancel pending requests. - for pending in self._pending_requests.values(): - if not pending.future.done(): - pending.future.cancel() - self._pending_requests.clear() - - logger.info("Discovery transport stopped") - - def set_message_handler( - self, - handler: Callable[[NodeId, DiscoveryMessage, tuple[str, int]], None], - ) -> None: - """Set handler for incoming messages.""" - self._message_handler = handler - - def register_node_address(self, node_id: NodeId, address: tuple[str, int]) -> None: - """Register a node's UDP address.""" - self._node_addresses[node_id] = address - - def get_node_address(self, node_id: NodeId) -> tuple[str, int] | None: - """Get a node's registered UDP address.""" - return self._node_addresses.get(node_id) - - def register_enr(self, node_id: NodeId, enr: ENR) -> None: - """ - Cache an ENR for future handshake completion. - - The ENR contains the node's public key, which is essential for: - - - ECDH key derivation during session establishment - - Verifying id-nonce signatures in handshake responses - - The handshake manager is the single owner of the ENR cache. - - Args: - node_id: 32-byte node ID (keccak256 of public key). - enr: The node's ENR. - """ - self._handshake_manager.register_enr(node_id, enr) - - def get_enr(self, node_id: NodeId) -> ENR | None: - """ - Retrieve a cached ENR by node ID. - - Args: - node_id: 32-byte node ID to look up. - - Returns: - The cached ENR, or None if unknown. - """ - return self._handshake_manager.get_cached_enr(node_id) - - async def send_ping(self, dest_node_id: NodeId, dest_addr: tuple[str, int]) -> Pong | None: - """ - Send a PING and wait for PONG. - - Args: - dest_node_id: 32-byte destination node ID. - dest_addr: (ip, port) tuple. - - Returns: - PONG response or None on timeout. - """ - request_id = RequestId.generate() - ping = Ping( - request_id=request_id, - enr_seq=SeqNumber(self._local_enr.seq), - ) - - response = await self._send_request(dest_node_id, dest_addr, ping) - if isinstance(response, Pong): - return response - return None - - async def send_findnode( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - distances: list[int], - ) -> list[bytes]: - """ - Send FINDNODE and collect all NODES responses. - - Per spec, FINDNODE responses may be split across multiple NODES messages - when results exceed UDP MTU. The `total` field indicates how many messages - to expect. We collect all messages until `total` is reached or timeout. - - Args: - dest_node_id: 32-byte destination node ID. - dest_addr: (ip, port) tuple. - distances: List of log2 distances to query. - - Returns: - List of RLP-encoded ENRs from all NODES responses. - """ - request_id = RequestId.generate() - findnode = FindNode( - request_id=request_id, - distances=[Distance(d) for d in distances], - ) - - # Use multi-response collection for FINDNODE. - responses = await self._send_multi_response_request(dest_node_id, dest_addr, findnode) - - # Collect all ENRs from all NODES responses. - all_enrs: list[bytes] = [] - for response in responses: - if isinstance(response, Nodes): - all_enrs.extend(response.enrs) - - return all_enrs - - async def _send_multi_response_request( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - message: DiscoveryMessage, - ) -> list[DiscoveryMessage]: - """ - Send a request that may receive multiple responses. - - Used for FINDNODE which can return multiple NODES messages. - Collects responses until the expected total is reached or timeout. - - Args: - dest_node_id: 32-byte destination node ID. - dest_addr: (ip, port) tuple. - message: Request message to send. - - Returns: - List of response messages (may be empty on timeout/error). - """ - if self._transport is None: - raise RuntimeError("Transport not started") - - # Register address for responses. - self._node_addresses[dest_node_id] = dest_addr - - # Build and send packet. - nonce = Nonce.generate() - message_bytes = encode_message(message) - packet = self._build_message_packet(dest_node_id, dest_addr, nonce, message_bytes) - - # Create collector for multiple responses. - loop = asyncio.get_running_loop() - - # Use a queue to collect multiple responses. - response_queue: asyncio.Queue[DiscoveryMessage] = asyncio.Queue() - pending = PendingMultiRequest( - request_id=message.request_id, - dest_node_id=dest_node_id, - sent_at=loop.time(), - nonce=nonce, - message=message, - response_queue=response_queue, - expected_total=None, - received_count=0, - ) - self._pending_multi_requests[message.request_id] = pending - - # Send packet. - self._transport.sendto(packet, dest_addr) - - # Collect responses until total reached or timeout. - responses: list[DiscoveryMessage] = [] - deadline = loop.time() + self._config.request_timeout_secs - - try: - while True: - remaining = deadline - loop.time() - if remaining <= 0: - break - - try: - response = await asyncio.wait_for(response_queue.get(), timeout=remaining) - responses.append(response) - - # Update expected total from first NODES response. - if isinstance(response, Nodes): - if pending.expected_total is None: - pending.expected_total = int(response.total) - pending.received_count += 1 - - # Check if we've received all expected messages. - if ( - pending.expected_total is not None - and pending.received_count >= pending.expected_total - ): - break - - except asyncio.TimeoutError: - break - - finally: - self._pending_multi_requests.pop(message.request_id, None) - - return responses - - async def send_talkreq( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - protocol: bytes, - request: bytes, - ) -> bytes | None: - """ - Send TALKREQ and wait for TALKRESP. - - Args: - dest_node_id: 32-byte destination node ID. - dest_addr: (ip, port) tuple. - protocol: Protocol identifier. - request: Protocol-specific request payload. - - Returns: - Response payload or None on timeout/error. - """ - request_id = RequestId.generate() - talkreq = TalkReq( - request_id=request_id, - protocol=protocol, - request=request, - ) - - response = await self._send_request(dest_node_id, dest_addr, talkreq) - if isinstance(response, TalkResp): - return response.response - return None - - async def _send_request( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - message: DiscoveryMessage, - ) -> DiscoveryMessage | None: - """ - Send a request and wait for response. - - Handles session establishment if needed. - """ - if self._transport is None: - raise RuntimeError("Transport not started") - - # Register address for responses. - self._node_addresses[dest_node_id] = dest_addr - - # Build and send packet. - nonce = Nonce.generate() - message_bytes = encode_message(message) - packet = self._build_message_packet(dest_node_id, dest_addr, nonce, message_bytes) - - # Create pending request. - loop = asyncio.get_running_loop() - future: asyncio.Future[DiscoveryMessage | None] = loop.create_future() - request_id = message.request_id - pending = PendingRequest( - request_id=request_id, - dest_node_id=dest_node_id, - sent_at=loop.time(), - nonce=nonce, - message=message, - future=future, - ) - self._pending_requests[request_id] = pending - - # Send packet. - self._transport.sendto(packet, dest_addr) - - # Wait for response with timeout. - try: - return await asyncio.wait_for( - future, - timeout=self._config.request_timeout_secs, - ) - except asyncio.TimeoutError: - return None - finally: - self._pending_requests.pop(request_id, None) - - def _build_message_packet( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - nonce: Nonce, - message_bytes: bytes, - ) -> bytes: - """ - Build a MESSAGE packet, using session key if available or a dummy key - to trigger handshake. - - Args: - dest_node_id: 32-byte destination node ID. - dest_addr: (ip, port) tuple. - nonce: 12-byte message nonce. - message_bytes: Encoded message payload. - - Returns: - Encoded packet bytes. - """ - ip, port = dest_addr - session = self._session_cache.get(dest_node_id, ip, Port(port)) - - authdata = encode_message_authdata(self._local_node_id) - - if session is not None: - return encode_packet( - dest_node_id=dest_node_id, - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=message_bytes, - encryption_key=session.send_key, - ) - - # Deliberate decryption failure triggers handshake. - # - # Discovery v5's handshake is initiated by failure: - # - # 1. We send a MESSAGE with random encryption key - # 2. Recipient cannot decrypt (they don't have the key) - # 3. Recipient responds with WHOAREYOU challenge - # 4. We complete handshake with HANDSHAKE packet - # - # This approach avoids the need for session negotiation - # before sending the first message. - self._handshake_manager.start_handshake(dest_node_id) - dummy_key = Bytes16(os.urandom(16)) - return encode_packet( - dest_node_id=dest_node_id, - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=message_bytes, - encryption_key=dummy_key, - ) - - async def _handle_packet(self, data: bytes, addr: tuple[str, int]) -> None: - """ - Decode and dispatch a received UDP packet. - - Unmasking the header reveals the packet type (MESSAGE, WHOAREYOU, - or HANDSHAKE) and routes to the appropriate handler. - """ - try: - # Decode packet header. - header, message_bytes, message_ad = decode_packet_header(self._local_node_id, data) - - if header.flag == PacketFlag.WHOAREYOU: - await self._handle_whoareyou(header, message_bytes, addr, data) - elif header.flag == PacketFlag.HANDSHAKE: - await self._handle_handshake(header, message_bytes, addr, message_ad) - else: - await self._handle_message(header, message_bytes, addr, message_ad) - - except (ValueError, MessageDecodingError, HandshakeError) as e: - logger.debug("Error handling packet from %s: %s", addr, e) - - async def _handle_whoareyou( - self, - header: PacketHeader, - message_bytes: bytes, - addr: tuple[str, int], - raw_packet: bytes, - ) -> None: - """ - Respond to a WHOAREYOU challenge with a HANDSHAKE packet. - - WHOAREYOU is the recipient's way of saying "I cannot decrypt your message." - We must prove our identity and establish session keys before communication. - - The response flow: - - 1. Find which pending request triggered this challenge - 2. Extract challenge_data from the WHOAREYOU for key derivation - 3. Look up the remote's public key from our ENR cache - 4. Generate ephemeral keypair for ECDH - 5. Sign the challenge nonce to prove identity - 6. Derive session keys and send HANDSHAKE with original message - """ - whoareyou = decode_whoareyou_authdata(header.authdata) - - # Match WHOAREYOU to our pending request via nonce. - # - # The WHOAREYOU contains our original packet's nonce. - # This links the challenge to the specific request that failed. - pending = None - for p in self._pending_requests.values(): - if p.nonce == header.nonce: - pending = p - break - - if pending is None: - logger.debug("No pending request for WHOAREYOU nonce") - return - - remote_node_id = pending.dest_node_id - - # Extract challenge_data for key derivation. - # - # Per spec: challenge_data = masking-iv || static-header || authdata - # This is the first 63 bytes of the WHOAREYOU packet: - # - masking-iv: 16 bytes - # - static-header: 23 bytes (protocol-id + version + flag + nonce + authdata-size) - # - authdata: 24 bytes (id-nonce 16 + enr-seq 8) - # - # We use the unmasked header, which we can reconstruct from the decoded values. - masking_iv = raw_packet[:16] - static_header = encode_static_header( - PacketFlag.WHOAREYOU, header.nonce, len(header.authdata) - ) - challenge_data = masking_iv + static_header + header.authdata - - # Retrieve the remote's public key for ECDH. - # - # Session key derivation requires ECDH between our ephemeral private key - # and the remote's static public key. Without their ENR, we cannot proceed. - remote_enr = self._handshake_manager.get_cached_enr(remote_node_id) - if remote_enr is None or remote_enr.public_key is None: - logger.debug("No ENR for %s, cannot complete handshake", remote_node_id.hex()[:16]) - return - - remote_pubkey = Bytes33(remote_enr.public_key) - - # Build and send the HANDSHAKE response. - try: - ip, port = addr - authdata, send_key, recv_key = self._handshake_manager.create_handshake_response( - remote_node_id=remote_node_id, - whoareyou=whoareyou, - remote_pubkey=remote_pubkey, - challenge_data=challenge_data, - remote_ip=ip, - remote_port=Port(port), - ) - - # Re-send the original message, now encrypted with the new session key. - # - # The HANDSHAKE packet includes both the authentication data - # and our original message (encrypted). This completes the - # handshake and delivers the message in one round trip. - message_bytes = encode_message(pending.message) - nonce = Nonce.generate() - - packet = encode_packet( - dest_node_id=remote_node_id, - flag=PacketFlag.HANDSHAKE, - nonce=nonce, - authdata=authdata, - message=message_bytes, - encryption_key=send_key, - ) - - if self._transport is not None: - self._transport.sendto(packet, addr) - logger.debug("Sent HANDSHAKE to %s", remote_node_id.hex()[:16]) - - except (HandshakeError, ValueError) as e: - logger.debug("Failed to create handshake response: %s", e) - - async def _handle_handshake( - self, - header: PacketHeader, - message_bytes: bytes, - addr: tuple[str, int], - message_ad: bytes, - ) -> None: - """ - Complete a handshake initiated by our WHOAREYOU. - - Verifies the remote's identity signature, derives session keys - via ECDH, and decrypts the included message payload. - """ - handshake_authdata = decode_handshake_authdata(header.authdata) - remote_node_id = handshake_authdata.src_id - - try: - ip, port = addr - result = self._handshake_manager.handle_handshake( - remote_node_id, handshake_authdata, remote_ip=ip, remote_port=Port(port) - ) - logger.debug("Handshake completed with %s", remote_node_id.hex()[:16]) - - # Decrypt the included message. - if len(message_bytes) > 0: - plaintext = decrypt_message( - encryption_key=result.session.recv_key, - nonce=header.nonce, - ciphertext=message_bytes, - message_ad=message_ad, - ) - - message = decode_message(plaintext) - await self._handle_decoded_message(remote_node_id, message, addr) - - except (HandshakeError, ValueError) as e: - logger.debug("Handshake failed: %s", e) - - async def _handle_message( - self, - header: PacketHeader, - message_bytes: bytes, - addr: tuple[str, int], - message_ad: bytes, - ) -> None: - """ - Decrypt and process an ordinary MESSAGE packet using session keys. - - If no session exists or decryption fails, sends WHOAREYOU - to initiate a handshake with the sender. - """ - message_authdata = decode_message_authdata(header.authdata) - remote_node_id = message_authdata.src_id - - # Get session keyed by (node_id, ip, port). - ip, port = addr - session = self._session_cache.get(remote_node_id, ip, Port(port)) - if session is None: - # Can't decrypt - send WHOAREYOU. - await self._send_whoareyou(remote_node_id, header.nonce, addr) - return - - try: - plaintext = decrypt_message( - encryption_key=session.recv_key, - nonce=header.nonce, - ciphertext=message_bytes, - message_ad=message_ad, - ) - - message = decode_message(plaintext) - await self._handle_decoded_message(remote_node_id, message, addr) - - except (InvalidTag, ValueError, MessageDecodingError) as e: - # Decryption failed - send WHOAREYOU. - logger.debug("Decryption failed, sending WHOAREYOU: %s", e) - await self._send_whoareyou(remote_node_id, header.nonce, addr) - - async def _handle_decoded_message( - self, - remote_node_id: NodeId, - message: DiscoveryMessage, - addr: tuple[str, int], - ) -> None: - """Process a successfully decoded message.""" - # Update session activity. - ip, port = addr - self._session_cache.touch(remote_node_id, ip, Port(port)) - - # Check if this is a response to a pending request. - request_id = message.request_id - - # Check for multi-response requests first (e.g., FINDNODE -> NODES). - multi_pending = self._pending_multi_requests.get(request_id) - if multi_pending is not None: - await multi_pending.response_queue.put(message) - return - - # Check for single-response requests. - pending = self._pending_requests.get(request_id) - if pending is not None and not pending.future.done(): - pending.future.set_result(message) - return - - # Otherwise, pass to message handler. - if self._message_handler is not None: - self._message_handler(remote_node_id, message, addr) - - async def _send_whoareyou( - self, - remote_node_id: NodeId, - request_nonce: Nonce, - addr: tuple[str, int], - ) -> None: - """Send a WHOAREYOU packet.""" - if self._transport is None: - return - - # Look up the cached ENR sequence for this node. - # - # If we know their ENR, send the current seq so they can skip - # including the full ENR in the handshake response, saving bandwidth. - # Fall back to 0 if unknown, which forces the remote to include their ENR. - cached_enr = self._handshake_manager.get_cached_enr(remote_node_id) - remote_enr_seq = SeqNumber(cached_enr.seq) if cached_enr is not None else SeqNumber(0) - - # Generate masking IV for the WHOAREYOU packet. - # - # This IV is part of the challenge_data used for key derivation. - # Both sides must use identical challenge_data to derive matching keys. - masking_iv = Bytes16(os.urandom(16)) - - id_nonce, authdata, nonce, challenge_data = self._handshake_manager.create_whoareyou( - remote_node_id=remote_node_id, - request_nonce=request_nonce, - remote_enr_seq=remote_enr_seq, - masking_iv=masking_iv, - ) - - packet = encode_packet( - dest_node_id=remote_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=nonce, - authdata=authdata, - message=b"", - encryption_key=None, - masking_iv=masking_iv, - ) - - self._transport.sendto(packet, addr) - logger.debug("Sent WHOAREYOU to %s", remote_node_id.hex()[:16]) - - async def send_response( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - message: DiscoveryMessage, - ) -> bool: - """ - Send a response message using an existing session. - - Response messages (PONG, NODES, TALKRESP) reply to incoming requests. - Unlike requests, responses do not trigger handshakes if no session exists. - The session must have been established by the original request flow. - - Args: - dest_node_id: 32-byte destination node ID. - dest_addr: (ip, port) tuple. - message: Response message to send. - - Returns: - True if sent successfully. - False if transport not running or no session exists. - """ - if self._transport is None: - return False - - # Responses require an existing session. - # - # The requester initiated the handshake. - # By the time we respond, session keys must exist. - ip, port = dest_addr - session = self._session_cache.get(dest_node_id, ip, Port(port)) - if session is None: - logger.debug("No session for response to %s", dest_node_id.hex()[:16]) - return False - - # Encode and encrypt the response. - nonce = Nonce.generate() - message_bytes = encode_message(message) - authdata = encode_message_authdata(self._local_node_id) - - packet = encode_packet( - dest_node_id=dest_node_id, - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=message_bytes, - encryption_key=session.send_key, - ) - - self._transport.sendto(packet, dest_addr) - return True diff --git a/src/lean_spec/subspecs/networking/peer.py b/src/lean_spec/subspecs/networking/peer.py index f6d11394..0a47cff7 100644 --- a/src/lean_spec/subspecs/networking/peer.py +++ b/src/lean_spec/subspecs/networking/peer.py @@ -18,9 +18,9 @@ class PeerInfo: Tracks identity, connection state, and cached protocol data. - The enr and status fields cache fork data from discovery and handshake: + The enr and status fields cache fork data from peer configuration and handshake: - - enr: Populated from discovery, contains eth2 fork_digest + - enr: Populated from bootnode/peer configuration, contains eth2 fork_digest - status: Populated after Status handshake, contains finalized/head checkpoints These cached values enable fork compatibility checks at multiple layers. @@ -42,7 +42,7 @@ class PeerInfo: """Unix timestamp of last successful interaction.""" enr: ENR | None = None - """Cached ENR from discovery. Contains eth2 fork_digest for compatibility checks.""" + """Cached ENR from peer configuration. Contains eth2 fork_digest for compatibility checks.""" status: Status | None = None """Cached Status from handshake. Contains finalized/head checkpoints.""" diff --git a/src/lean_spec/subspecs/networking/types.py b/src/lean_spec/subspecs/networking/types.py index 88522b9a..2ddf7fba 100644 --- a/src/lean_spec/subspecs/networking/types.py +++ b/src/lean_spec/subspecs/networking/types.py @@ -19,7 +19,7 @@ class DomainType(Bytes4): class NodeId(Bytes32): - """32-byte node identifier for Discovery v5, derived from `keccak256(pubkey)`.""" + """32-byte ENR node identifier, derived from `keccak256(pubkey)`.""" class ForkDigest(Bytes4): diff --git a/tests/consensus/lstar/networking/test_discovery_crypto.py b/tests/consensus/lstar/networking/test_discovery_crypto.py deleted file mode 100644 index 01f5530e..00000000 --- a/tests/consensus/lstar/networking/test_discovery_crypto.py +++ /dev/null @@ -1,219 +0,0 @@ -"""Test vectors for Discovery v5 cryptographic primitives. - -These vectors use official devp2p specification key material so that -client teams can verify their ECDH, HKDF, signing, and encryption -implementations against a known reference. - -Reference: - https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire-test-vectors.md -""" - -import pytest -from consensus_testing import DiscoveryCryptoTestFiller - -pytestmark = pytest.mark.valid_until("Lstar") - -# Official devp2p spec key material. -NODE_A_PRIVKEY = "0xeef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f" -NODE_A_PUBKEY = "0x0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" -NODE_A_ID = "0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb" - -NODE_B_PRIVKEY = "0x66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628" -NODE_B_PUBKEY = "0x0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91" -NODE_B_ID = "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9" - -SPEC_EPHEMERAL_KEY = "0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" -SPEC_EPHEMERAL_PUBKEY = "0x039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231" - -SPEC_CHALLENGE_DATA = ( - "0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c" - "00180102030405060708090a0b0c0d0e0f100000000000000000" -) - -# Pre-computed signature for NODE_A signing with SPEC_CHALLENGE_DATA. -SPEC_SIGNATURE = ( - "0xe622b72727fd64529187b6e4f7241caadac052f840122474d91de305f7bb5cc4" - "5d0457b5f6ba72c8bbc82480dea2e2cd3eabca6984f6bd7bd3b54a80fe749fa6" -) - - -# --- ECDH --- - - -def test_ecdh_node_a_to_b(discovery_crypto: DiscoveryCryptoTestFiller) -> None: - """ECDH(A_priv, B_pub) produces deterministic shared secret.""" - discovery_crypto( - operation="ecdh", - input={"privateKey": NODE_A_PRIVKEY, "publicKey": NODE_B_PUBKEY}, - ) - - -def test_ecdh_node_b_to_a(discovery_crypto: DiscoveryCryptoTestFiller) -> None: - """ECDH(B_priv, A_pub) must equal ECDH(A_priv, B_pub) (symmetry).""" - discovery_crypto( - operation="ecdh", - input={"privateKey": NODE_B_PRIVKEY, "publicKey": NODE_A_PUBKEY}, - ) - - -def test_ecdh_ephemeral_to_b(discovery_crypto: DiscoveryCryptoTestFiller) -> None: - """ECDH with spec ephemeral key and Node B. Used in handshake key derivation.""" - discovery_crypto( - operation="ecdh", - input={"privateKey": SPEC_EPHEMERAL_KEY, "publicKey": NODE_B_PUBKEY}, - ) - - -# --- Key derivation --- - - -def test_key_derivation_spec_vector( - discovery_crypto: DiscoveryCryptoTestFiller, -) -> None: - """HKDF key derivation with official spec challenge data.""" - # Shared secret from ECDH(ephemeral, B_pub). - shared_secret = "0x022c82f214eb37159111712add00040fcdf73fd4d7d0b7c0f980da4d099aa59ba4" - discovery_crypto( - operation="key_derivation", - input={ - "sharedSecret": shared_secret, - "initiatorId": NODE_A_ID, - "recipientId": NODE_B_ID, - "challengeData": SPEC_CHALLENGE_DATA, - }, - ) - - -# --- ID nonce signing --- - - -def test_id_nonce_sign_spec_vector( - discovery_crypto: DiscoveryCryptoTestFiller, -) -> None: - """Sign ID nonce with Node A's key and spec challenge data.""" - discovery_crypto( - operation="id_nonce_sign", - input={ - "privateKey": NODE_A_PRIVKEY, - "challengeData": SPEC_CHALLENGE_DATA, - "ephemeralPubkey": SPEC_EPHEMERAL_PUBKEY, - "destNodeId": NODE_B_ID, - }, - ) - - -# --- ID nonce verification --- - - -def test_id_nonce_verify_valid(discovery_crypto: DiscoveryCryptoTestFiller) -> None: - """Verify valid signature from Node A.""" - discovery_crypto( - operation="id_nonce_verify", - input={ - "signature": SPEC_SIGNATURE, - "challengeData": SPEC_CHALLENGE_DATA, - "ephemeralPubkey": SPEC_EPHEMERAL_PUBKEY, - "destNodeId": NODE_B_ID, - "publicKey": NODE_A_PUBKEY, - }, - ) - - -def test_id_nonce_verify_wrong_pubkey( - discovery_crypto: DiscoveryCryptoTestFiller, -) -> None: - """Verification fails with wrong public key (Node B instead of A).""" - discovery_crypto( - operation="id_nonce_verify", - input={ - "signature": SPEC_SIGNATURE, - "challengeData": SPEC_CHALLENGE_DATA, - "ephemeralPubkey": SPEC_EPHEMERAL_PUBKEY, - "destNodeId": NODE_B_ID, - "publicKey": NODE_B_PUBKEY, - }, - ) - - -def test_id_nonce_verify_wrong_challenge( - discovery_crypto: DiscoveryCryptoTestFiller, -) -> None: - """Verification fails with different challenge data.""" - wrong_challenge = "0x" + "ff" * 63 - discovery_crypto( - operation="id_nonce_verify", - input={ - "signature": SPEC_SIGNATURE, - "challengeData": wrong_challenge, - "ephemeralPubkey": SPEC_EPHEMERAL_PUBKEY, - "destNodeId": NODE_B_ID, - "publicKey": NODE_A_PUBKEY, - }, - ) - - -# --- AES-GCM --- - - -def test_aes_gcm_encrypt_spec_ping( - discovery_crypto: DiscoveryCryptoTestFiller, -) -> None: - """AES-GCM encrypt a PING message with spec key and nonce.""" - discovery_crypto( - operation="aes_gcm_encrypt", - input={ - "key": "0x9f2d77db7004bf8a1a85107ac686990b", - "nonce": "0x27b5af763c446acd2749fe8e", - "plaintext": "0x01c20101", - "aad": "0x", - }, - ) - - -def test_aes_gcm_encrypt_empty_plaintext( - discovery_crypto: DiscoveryCryptoTestFiller, -) -> None: - """AES-GCM with empty plaintext. Output is just the 16-byte auth tag.""" - discovery_crypto( - operation="aes_gcm_encrypt", - input={ - "key": "0x9f2d77db7004bf8a1a85107ac686990b", - "nonce": "0x27b5af763c446acd2749fe8e", - "plaintext": "0x", - "aad": "0x", - }, - ) - - -def test_aes_gcm_encrypt_with_aad( - discovery_crypto: DiscoveryCryptoTestFiller, -) -> None: - """AES-GCM with additional authenticated data (packet header).""" - discovery_crypto( - operation="aes_gcm_encrypt", - input={ - "key": "0x9f2d77db7004bf8a1a85107ac686990b", - "nonce": "0x27b5af763c446acd2749fe8e", - "plaintext": "0x01c20101", - "aad": "0x0102030405060708090a0b0c0d0e0f10", - }, - ) - - -# --- Node ID --- - - -def test_node_id_from_node_a(discovery_crypto: DiscoveryCryptoTestFiller) -> None: - """Compute Node A's ID from its compressed public key.""" - discovery_crypto( - operation="node_id", - input={"publicKey": NODE_A_PUBKEY}, - ) - - -def test_node_id_from_node_b(discovery_crypto: DiscoveryCryptoTestFiller) -> None: - """Compute Node B's ID from its compressed public key.""" - discovery_crypto( - operation="node_id", - input={"publicKey": NODE_B_PUBKEY}, - ) diff --git a/tests/consensus/lstar/networking/test_discovery_routing.py b/tests/consensus/lstar/networking/test_discovery_routing.py deleted file mode 100644 index 996a68fd..00000000 --- a/tests/consensus/lstar/networking/test_discovery_routing.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Test vectors for Discovery v5 distance computations. - -XOR distance and log2 distance are the foundation of Kademlia routing. -Every client must compute identical distances for correct peer discovery -and k-bucket assignment. -""" - -import pytest -from consensus_testing import NetworkingCodecTestFiller - -pytestmark = pytest.mark.valid_until("Lstar") - -# Official devp2p spec node IDs. -NODE_A = "0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb" -NODE_B = "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9" - -ZERO = "0x" + "00" * 32 -MAX = "0x" + "ff" * 32 -ONE = "0x" + "00" * 31 + "01" -TWO = "0x" + "00" * 31 + "02" -HIGH_BIT = "0x80" + "00" * 31 -LOW_BYTE_80 = "0x" + "00" * 31 + "80" -BUCKET_9 = "0x" + "00" * 30 + "0100" - - -# --- XOR distance --- - - -def test_xor_distance_self(networking_codec: NetworkingCodecTestFiller) -> None: - """XOR distance to self is always zero (identity property).""" - networking_codec(codec_name="xor_distance", input={"nodeA": NODE_A, "nodeB": NODE_A}) - - -def test_xor_distance_symmetric(networking_codec: NetworkingCodecTestFiller) -> None: - """d(A, B) == d(B, A) (symmetry property).""" - networking_codec(codec_name="xor_distance", input={"nodeA": NODE_A, "nodeB": NODE_B}) - - -def test_xor_distance_symmetric_reverse( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """Reverse order produces identical distance.""" - networking_codec(codec_name="xor_distance", input={"nodeA": NODE_B, "nodeB": NODE_A}) - - -def test_xor_distance_max(networking_codec: NetworkingCodecTestFiller) -> None: - """XOR of zero and all-ones produces maximum distance (2^256 - 1).""" - networking_codec(codec_name="xor_distance", input={"nodeA": ZERO, "nodeB": MAX}) - - -def test_xor_distance_adjacent(networking_codec: NetworkingCodecTestFiller) -> None: - """XOR of 0x01 and 0x02 is 0x03 (lowest bits).""" - networking_codec(codec_name="xor_distance", input={"nodeA": ONE, "nodeB": TWO}) - - -def test_xor_distance_high_bit(networking_codec: NetworkingCodecTestFiller) -> None: - """Single high bit difference. Distance = 2^255.""" - networking_codec(codec_name="xor_distance", input={"nodeA": HIGH_BIT, "nodeB": ZERO}) - - -# --- Log2 distance (k-bucket assignment) --- - - -def test_log2_distance_self(networking_codec: NetworkingCodecTestFiller) -> None: - """Log2 distance to self is 0.""" - networking_codec(codec_name="log2_distance", input={"nodeA": NODE_A, "nodeB": NODE_A}) - - -def test_log2_distance_spec_nodes( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """Log2 distance between official spec nodes A and B is 253.""" - networking_codec(codec_name="log2_distance", input={"nodeA": NODE_A, "nodeB": NODE_B}) - - -def test_log2_distance_max(networking_codec: NetworkingCodecTestFiller) -> None: - """Maximum log2 distance is 256 (all bits differ).""" - networking_codec(codec_name="log2_distance", input={"nodeA": ZERO, "nodeB": MAX}) - - -def test_log2_distance_bucket_1(networking_codec: NetworkingCodecTestFiller) -> None: - """Single lowest bit difference lands in bucket 1.""" - networking_codec(codec_name="log2_distance", input={"nodeA": ZERO, "nodeB": ONE}) - - -def test_log2_distance_bucket_2(networking_codec: NetworkingCodecTestFiller) -> None: - """Bit 1 set lands in bucket 2.""" - networking_codec(codec_name="log2_distance", input={"nodeA": ZERO, "nodeB": TWO}) - - -def test_log2_distance_bucket_8(networking_codec: NetworkingCodecTestFiller) -> None: - """Byte boundary: 0x80 in last byte lands in bucket 8.""" - networking_codec(codec_name="log2_distance", input={"nodeA": ZERO, "nodeB": LOW_BYTE_80}) - - -def test_log2_distance_bucket_9(networking_codec: NetworkingCodecTestFiller) -> None: - """0x0100 in last two bytes lands in bucket 9.""" - networking_codec(codec_name="log2_distance", input={"nodeA": ZERO, "nodeB": BUCKET_9}) - - -def test_log2_distance_bucket_256( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """Highest bit only (0x80 in first byte) lands in bucket 256.""" - networking_codec(codec_name="log2_distance", input={"nodeA": ZERO, "nodeB": HIGH_BIT}) diff --git a/tests/consensus/lstar/networking/test_discv5_message_rejections.py b/tests/consensus/lstar/networking/test_discv5_message_rejections.py deleted file mode 100644 index b71a3db5..00000000 --- a/tests/consensus/lstar/networking/test_discv5_message_rejections.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Discv5 message decoder: malformed-input rejection vectors.""" - -import pytest -from consensus_testing import NetworkingCodecTestFiller - -from lean_spec.subspecs.networking.discovery.codec import MessageDecodingError - -pytestmark = pytest.mark.valid_until("Lstar") - - -def test_discv5_message_decode_rejects_empty_input( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """Zero bytes is too short to carry a discv5 message type byte. - - Every discv5 message begins with a single type byte followed by an - RLP-encoded payload. Empty input cannot satisfy even the type byte - so the decoder must abort immediately. - """ - networking_codec( - codec_name="decode_failure", - input={"decoder": "discv5_message", "bytes": "0x"}, - expect_exception=MessageDecodingError, - ) - - -def test_discv5_message_decode_rejects_unknown_type_byte( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """A type byte outside the assigned set is rejected. - - Only 0x01 through 0x06 are valid discv5 message types. A byte of 0xff - followed by an empty RLP list carries no meaningful message and must - be rejected rather than silently ignored. - """ - networking_codec( - codec_name="decode_failure", - input={"decoder": "discv5_message", "bytes": "0xffc0"}, - expect_exception=MessageDecodingError, - ) - - -def test_discv5_ping_decode_rejects_wrong_element_count( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """A PING payload with the wrong number of RLP elements is rejected. - - PING requires exactly two fields: request_id and enr_seq. The payload - 0x01 (ping type) followed by 0xc0 (empty RLP list) carries zero - elements, so the decoder must refuse to build a PING record. - """ - networking_codec( - codec_name="decode_failure", - input={"decoder": "discv5_message", "bytes": "0x01c0"}, - expect_exception=MessageDecodingError, - ) diff --git a/tests/consensus/lstar/networking/test_discv5_messages.py b/tests/consensus/lstar/networking/test_discv5_messages.py deleted file mode 100644 index 075a0e95..00000000 --- a/tests/consensus/lstar/networking/test_discv5_messages.py +++ /dev/null @@ -1,187 +0,0 @@ -"""Test vectors for Discovery v5 message RLP encoding.""" - -import pytest -from consensus_testing import NetworkingCodecTestFiller - -pytestmark = pytest.mark.valid_until("Lstar") - -IPV4_LOCALHOST = "0x7f000001" -"""127.0.0.1 as 4 raw bytes.""" - -IPV6_LOOPBACK = "0x00000000000000000000000000000001" -"""::1 as 16 raw bytes.""" - - -# --- PING --- - - -def test_ping_typical(networking_codec: NetworkingCodecTestFiller) -> None: - """PING with request_id=0x01 and enr_seq=1. Matches devp2p spec plaintext 01c20101.""" - networking_codec( - codec_name="discv5_message", - input={"type": "ping", "requestId": "0x01", "enrSeq": 1}, - ) - - -def test_ping_leading_zeros_stripped(networking_codec: NetworkingCodecTestFiller) -> None: - """PING with 4-byte request_id containing leading zeros. Stripped to minimal encoding.""" - networking_codec( - codec_name="discv5_message", - input={"type": "ping", "requestId": "0x00000001", "enrSeq": 1}, - ) - - -def test_ping_seq_zero(networking_codec: NetworkingCodecTestFiller) -> None: - """PING with enr_seq=0. Zero encodes as RLP empty bytes, not 0x00.""" - networking_codec( - codec_name="discv5_message", - input={"type": "ping", "requestId": "0x01", "enrSeq": 0}, - ) - - -# --- PONG --- - - -def test_pong_ipv4(networking_codec: NetworkingCodecTestFiller) -> None: - """PONG with IPv4 127.0.0.1 and port 30303.""" - networking_codec( - codec_name="discv5_message", - input={ - "type": "pong", - "requestId": "0x01", - "enrSeq": 1, - "recipientIp": IPV4_LOCALHOST, - "recipientPort": 30303, - }, - ) - - -def test_pong_ipv6(networking_codec: NetworkingCodecTestFiller) -> None: - """PONG with IPv6 loopback (::1). 16-byte IP discriminates from IPv4.""" - networking_codec( - codec_name="discv5_message", - input={ - "type": "pong", - "requestId": "0x01", - "enrSeq": 1, - "recipientIp": IPV6_LOOPBACK, - "recipientPort": 9000, - }, - ) - - -def test_pong_port_zero(networking_codec: NetworkingCodecTestFiller) -> None: - """PONG with port=0. Zero port encodes as RLP empty bytes.""" - networking_codec( - codec_name="discv5_message", - input={ - "type": "pong", - "requestId": "0x01", - "enrSeq": 1, - "recipientIp": IPV4_LOCALHOST, - "recipientPort": 0, - }, - ) - - -# --- FINDNODE --- - - -def test_findnode_single_distance(networking_codec: NetworkingCodecTestFiller) -> None: - """FINDNODE requesting distance 256 (2-byte big-endian in nested RLP list).""" - networking_codec( - codec_name="discv5_message", - input={"type": "findnode", "requestId": "0x01", "distances": [256]}, - ) - - -def test_findnode_mixed_distances(networking_codec: NetworkingCodecTestFiller) -> None: - """FINDNODE with distances [0, 1, 256]. Tests zero and multi-byte encoding.""" - networking_codec( - codec_name="discv5_message", - input={"type": "findnode", "requestId": "0x01", "distances": [0, 1, 256]}, - ) - - -def test_findnode_empty(networking_codec: NetworkingCodecTestFiller) -> None: - """FINDNODE with empty distance list.""" - networking_codec( - codec_name="discv5_message", - input={"type": "findnode", "requestId": "0x01", "distances": []}, - ) - - -# --- NODES --- - - -def test_nodes_single_enr(networking_codec: NetworkingCodecTestFiller) -> None: - """NODES response with total=1 and one ENR (raw RLP bytes).""" - networking_codec( - codec_name="discv5_message", - input={ - "type": "nodes", - "requestId": "0x01", - "total": 1, - "enrs": ["0xdeadbeef"], - }, - ) - - -def test_nodes_empty(networking_codec: NetworkingCodecTestFiller) -> None: - """NODES response with total=0 and empty ENR list.""" - networking_codec( - codec_name="discv5_message", - input={"type": "nodes", "requestId": "0x01", "total": 0, "enrs": []}, - ) - - -# --- TALKREQ --- - - -def test_talkreq_typical(networking_codec: NetworkingCodecTestFiller) -> None: - """TALKREQ with protocol identifier and request payload.""" - networking_codec( - codec_name="discv5_message", - input={ - "type": "talkreq", - "requestId": "0x01", - "protocol": "0x" + b"discv5-test".hex(), - "request": "0xdeadbeef", - }, - ) - - -def test_talkreq_empty_payload(networking_codec: NetworkingCodecTestFiller) -> None: - """TALKREQ with empty request payload.""" - networking_codec( - codec_name="discv5_message", - input={ - "type": "talkreq", - "requestId": "0x01", - "protocol": "0x" + b"discv5-test".hex(), - "request": "0x", - }, - ) - - -# --- TALKRESP --- - - -def test_talkresp_typical(networking_codec: NetworkingCodecTestFiller) -> None: - """TALKRESP with response payload.""" - networking_codec( - codec_name="discv5_message", - input={ - "type": "talkresp", - "requestId": "0x01", - "response": "0xcafebabe", - }, - ) - - -def test_talkresp_empty(networking_codec: NetworkingCodecTestFiller) -> None: - """TALKRESP with empty response.""" - networking_codec( - codec_name="discv5_message", - input={"type": "talkresp", "requestId": "0x01", "response": "0x"}, - ) diff --git a/tests/consensus/lstar/networking/test_discv5_packet.py b/tests/consensus/lstar/networking/test_discv5_packet.py deleted file mode 100644 index 2a553f71..00000000 --- a/tests/consensus/lstar/networking/test_discv5_packet.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Discovery v5 packet framing: known-answer and rejection vectors. - -Pins the exact wire-level encoding produced by the spec for each packet -type clients must emit. Uses deterministic masking IVs so the resulting -bytes are stable, and roundtrip-decodes the header to assert shape -preservation. Rejection vectors cover the public size and protocol-id -checks. -""" - -import pytest -from consensus_testing import NetworkingCodecTestFiller - -pytestmark = pytest.mark.valid_until("Lstar") - - -DEST_NODE_ID = "0x" + "11" * 32 -"""Fixed destination node ID. Its first 16 bytes act as the masking key.""" - -NONCE = "0x" + "22" * 12 -"""Fixed 12-byte message nonce.""" - -MASKING_IV = "0x" + "33" * 16 -"""Fixed 16-byte header-masking IV for deterministic packet bytes.""" - -ENCRYPTION_KEY = "0x" + "44" * 16 -"""Fixed 16-byte AES-GCM encryption key for non-WHOAREYOU packets.""" - -SRC_NODE_ID = "0x" + "55" * 32 -"""Fixed source node ID used inside MESSAGE and HANDSHAKE authdata.""" - -ID_NONCE = "0x" + "66" * 16 -"""Fixed 16-byte ID nonce used inside WHOAREYOU authdata.""" - -ID_SIGNATURE = "0x" + "77" * 64 -"""Fixed 64-byte ID-nonce signature used inside HANDSHAKE authdata.""" - -EPH_PUBKEY = "0x02" + "88" * 32 -"""Fixed 33-byte compressed ephemeral public key (prefix byte then 32 bytes).""" - -# An encrypted message payload the AES-GCM layer will accept; exact -# ciphertext bytes do not matter for packet-shape assertions. -MESSAGE_PAYLOAD = "0x" + "ab" * 16 - - -def test_discv5_packet_message(networking_codec: NetworkingCodecTestFiller) -> None: - """MESSAGE packet (flag=0) with known inputs pins the full wire bytes. - - Encodes a MESSAGE packet with a fixed destination node id, nonce, - masking IV, encryption key, and source id. The fixture asserts - the encoded packet round-trips through decode_packet_header with - flag, nonce, and authdata preserved. - """ - networking_codec( - codec_name="discv5_packet", - input={ - "packetType": "message", - "destNodeId": DEST_NODE_ID, - "nonce": NONCE, - "maskingIv": MASKING_IV, - "encryptionKey": ENCRYPTION_KEY, - "srcId": SRC_NODE_ID, - "message": MESSAGE_PAYLOAD, - }, - ) - - -def test_discv5_packet_whoareyou(networking_codec: NetworkingCodecTestFiller) -> None: - """WHOAREYOU packet (flag=1) with known inputs pins the full wire bytes. - - WHOAREYOU carries no encrypted payload. The authdata is the 16-byte - ID nonce concatenated with the 8-byte ENR sequence number. - """ - networking_codec( - codec_name="discv5_packet", - input={ - "packetType": "whoareyou", - "destNodeId": DEST_NODE_ID, - "nonce": NONCE, - "maskingIv": MASKING_IV, - "idNonce": ID_NONCE, - "enrSeq": 42, - }, - ) - - -def test_discv5_packet_handshake_without_record( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """HANDSHAKE packet (flag=2) without an embedded ENR record. - - Exercises the fixed-size authdata path where the record field is - absent. Pins the encoded length and the roundtrip decode that reads - sig-size / eph-key-size from the authdata header. - """ - networking_codec( - codec_name="discv5_packet", - input={ - "packetType": "handshake", - "destNodeId": DEST_NODE_ID, - "nonce": NONCE, - "maskingIv": MASKING_IV, - "encryptionKey": ENCRYPTION_KEY, - "srcId": SRC_NODE_ID, - "idSignature": ID_SIGNATURE, - "ephPubkey": EPH_PUBKEY, - "message": MESSAGE_PAYLOAD, - }, - ) - - -def test_discv5_packet_decode_rejects_too_small( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """Decoding a packet below MIN_PACKET_SIZE must raise ValueError. - - MIN_PACKET_SIZE is 63 bytes. A 10-byte blob cannot carry even the - masking IV plus a static header, so the decoder must reject before - any unmask work runs. - """ - networking_codec( - codec_name="decode_failure", - input={ - "decoder": "discv5_packet", - "localNodeId": DEST_NODE_ID, - "bytes": "0x" + "00" * 10, - }, - expect_exception=ValueError, - ) - - -def test_discv5_packet_decode_rejects_wrong_protocol_id( - networking_codec: NetworkingCodecTestFiller, -) -> None: - """Decoding with the wrong local node id (unmask key) surfaces the protocol-id check. - - The masking key is derived from the local node id. Decoding with a - different node id than the packet was encoded for produces garbage - for the static header; the decoder rejects because the protocol id - bytes no longer read as 'discv5'. - """ - # 63 bytes all zero; after unmask with any key, the first 6 bytes - # produce pseudo-random output that will not spell "discv5". - networking_codec( - codec_name="decode_failure", - input={ - "decoder": "discv5_packet", - "localNodeId": DEST_NODE_ID, - "bytes": "0x" + "00" * 63, - }, - expect_exception=ValueError, - ) diff --git a/tests/lean_spec/helpers/__init__.py b/tests/lean_spec/helpers/__init__.py index 371c51bf..bad508ec 100644 --- a/tests/lean_spec/helpers/__init__.py +++ b/tests/lean_spec/helpers/__init__.py @@ -13,7 +13,6 @@ make_attestation_data_simple, make_block, make_bytes32, - make_challenge_data, make_checkpoint, make_empty_block_body, make_genesis_block, @@ -54,7 +53,6 @@ "make_attestation_data_simple", "make_block", "make_bytes32", - "make_challenge_data", "make_checkpoint", "make_empty_block_body", "make_genesis_block", diff --git a/tests/lean_spec/helpers/builders.py b/tests/lean_spec/helpers/builders.py index b876239a..f8c168f5 100644 --- a/tests/lean_spec/helpers/builders.py +++ b/tests/lean_spec/helpers/builders.py @@ -288,20 +288,6 @@ def make_test_block(slot: int = 1, seed: int = 0) -> SignedBlock: ) -def make_challenge_data(id_nonce: bytes = bytes(16), *, nonce: bytes = bytes(12)) -> bytes: - """Build mock Discovery v5 challenge_data for testing. - - Format: masking-iv (16) + static-header (23) + authdata (24) = 63 bytes. - The authdata contains the id_nonce (16) + enr_seq (8). - """ - masking_iv = bytes(16) - # static-header: protocol-id (6) + version (2) + flag (1) + nonce (12) + authdata-size (2) - static_header = b"discv5" + b"\x00\x01\x01" + nonce + b"\x00\x18" - # authdata: id-nonce (16) + enr-seq (8) - authdata = id_nonce + bytes(8) - return masking_iv + static_header + authdata - - _DEFAULT_VALIDATOR_ID = ValidatorIndex(0) _DEFAULT_ATTESTATION_SLOT = Slot(1) diff --git a/tests/lean_spec/subspecs/networking/discovery/__init__.py b/tests/lean_spec/subspecs/networking/discovery/__init__.py deleted file mode 100644 index f857259f..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for Discovery v5 protocol.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/conftest.py b/tests/lean_spec/subspecs/networking/discovery/conftest.py deleted file mode 100644 index 61ca52d5..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/conftest.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Shared pytest fixtures for Discovery v5 tests.""" - -from __future__ import annotations - -import pytest - -from lean_spec.subspecs.networking.discovery.messages import IdNonce -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.enr.keys import EnrKey -from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes32, Bytes33, Bytes64 - -# From devp2p test vectors -NODE_A_PRIVKEY = Bytes32( - bytes.fromhex("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") -) -NODE_A_ID = NodeId( - bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") -) -NODE_B_PRIVKEY = Bytes32( - bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") -) -NODE_B_ID = NodeId( - bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") -) -NODE_B_PUBKEY = Bytes33( - bytes.fromhex("0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91") -) - -# Spec id-nonce used in WHOAREYOU test vectors. -SPEC_ID_NONCE = IdNonce(bytes.fromhex("0102030405060708090a0b0c0d0e0f10")) - - -@pytest.fixture -def local_private_key() -> Bytes32: - """Node B's private key from devp2p test vectors.""" - return NODE_B_PRIVKEY - - -@pytest.fixture -def local_node_id() -> NodeId: - """Node B's ID from devp2p test vectors.""" - return NodeId(NODE_B_ID) - - -@pytest.fixture -def remote_node_id() -> NodeId: - """Node A's ID from devp2p test vectors.""" - return NodeId(NODE_A_ID) - - -@pytest.fixture -def local_enr() -> ENR: - """Minimal local ENR for testing.""" - return ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): NODE_B_PUBKEY, - EnrKey("ip"): bytes([127, 0, 0, 1]), - EnrKey("udp"): (9000).to_bytes(2, "big"), - }, - ) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_codec.py b/tests/lean_spec/subspecs/networking/discovery/test_codec.py deleted file mode 100644 index 37028f3a..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_codec.py +++ /dev/null @@ -1,498 +0,0 @@ -"""Tests for Discovery v5 message codec.""" - -import pytest - -from lean_spec.subspecs.networking.discovery.codec import ( - MessageDecodingError, - MessageEncodingError, - _decode_request_id, - decode_message, - encode_message, -) -from lean_spec.subspecs.networking.discovery.messages import ( - Distance, - FindNode, - IPv4, - IPv6, - MessageType, - Nodes, - Ping, - Pong, - Port, - RequestId, - TalkReq, - TalkResp, -) -from lean_spec.subspecs.networking.types import SeqNumber -from lean_spec.types import Uint8 - - -class TestPingCodec: - """Tests for PING message encoding/decoding.""" - - def test_encode_decode_roundtrip(self): - """Test that PING encodes and decodes correctly.""" - ping = Ping( - request_id=RequestId(data=b"\x01\x02\x03"), - enr_seq=SeqNumber(42), - ) - - encoded = encode_message(ping) - decoded = decode_message(encoded) - - assert isinstance(decoded, Ping) - assert bytes(decoded.request_id) == bytes(ping.request_id) - assert decoded.enr_seq == ping.enr_seq - - def test_encode_starts_with_message_type(self): - """Test that encoded PING starts with message type byte.""" - ping = Ping( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(0), - ) - - encoded = encode_message(ping) - - assert encoded[0] == MessageType.PING - - def test_zero_enr_seq(self): - """Test PING with zero ENR sequence.""" - ping = Ping( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(0), - ) - - encoded = encode_message(ping) - decoded = decode_message(encoded) - - assert isinstance(decoded, Ping) - assert decoded.enr_seq == SeqNumber(0) - - def test_large_enr_seq(self): - """Test PING with large ENR sequence.""" - ping = Ping( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(2**63 - 1), - ) - - encoded = encode_message(ping) - decoded = decode_message(encoded) - - assert isinstance(decoded, Ping) - assert decoded.enr_seq == ping.enr_seq - - -class TestPongCodec: - """Tests for PONG message encoding/decoding.""" - - def test_encode_decode_roundtrip(self): - """Test that PONG encodes and decodes correctly.""" - pong = Pong( - request_id=RequestId(data=b"\x01\x02\x03"), - enr_seq=SeqNumber(42), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), # 127.0.0.1 - recipient_port=Port(9000), - ) - - encoded = encode_message(pong) - decoded = decode_message(encoded) - - assert isinstance(decoded, Pong) - assert bytes(decoded.request_id) == bytes(pong.request_id) - assert decoded.enr_seq == pong.enr_seq - assert decoded.recipient_ip == pong.recipient_ip - assert decoded.recipient_port == pong.recipient_port - - def test_ipv6_address(self): - """Test PONG with IPv6 address.""" - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv6(bytes(16)), # ::0 - recipient_port=Port(9000), - ) - - encoded = encode_message(pong) - decoded = decode_message(encoded) - - assert isinstance(decoded, Pong) - assert decoded.recipient_ip == IPv6(bytes(16)) - - -class TestFindNodeCodec: - """Tests for FINDNODE message encoding/decoding.""" - - def test_encode_decode_roundtrip(self): - """Test that FINDNODE encodes and decodes correctly.""" - findnode = FindNode( - request_id=RequestId(data=b"\x01\x02\x03"), - distances=[Distance(1), Distance(2), Distance(3)], - ) - - encoded = encode_message(findnode) - decoded = decode_message(encoded) - - assert isinstance(decoded, FindNode) - assert bytes(decoded.request_id) == bytes(findnode.request_id) - assert decoded.distances == findnode.distances - - def test_empty_distances(self): - """Test FINDNODE with empty distances list.""" - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[], - ) - - encoded = encode_message(findnode) - decoded = decode_message(encoded) - - assert isinstance(decoded, FindNode) - assert decoded.distances == [] - - def test_distance_zero(self): - """Test FINDNODE with distance 0 (request own ENR).""" - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(0)], - ) - - encoded = encode_message(findnode) - decoded = decode_message(encoded) - - assert isinstance(decoded, FindNode) - assert decoded.distances == [Distance(0)] - - def test_distance_256(self): - """Test FINDNODE with maximum distance 256.""" - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(256)], - ) - - encoded = encode_message(findnode) - decoded = decode_message(encoded) - - assert isinstance(decoded, FindNode) - assert decoded.distances == [Distance(256)] - - -class TestNodesCodec: - """Tests for NODES message encoding/decoding.""" - - def test_encode_decode_roundtrip(self): - """Test that NODES encodes and decodes correctly.""" - nodes = Nodes( - request_id=RequestId(data=b"\x01\x02\x03"), - total=Uint8(2), - enrs=[b"enr1", b"enr2"], - ) - - encoded = encode_message(nodes) - decoded = decode_message(encoded) - - assert isinstance(decoded, Nodes) - assert bytes(decoded.request_id) == bytes(nodes.request_id) - assert decoded.total == nodes.total - assert decoded.enrs == nodes.enrs - - def test_empty_enrs(self): - """Test NODES with empty ENRs list.""" - nodes = Nodes( - request_id=RequestId(data=b"\x01"), - total=Uint8(1), - enrs=[], - ) - - encoded = encode_message(nodes) - decoded = decode_message(encoded) - - assert isinstance(decoded, Nodes) - assert decoded.enrs == [] - - def test_zero_total(self): - """Test NODES with total=0.""" - nodes = Nodes( - request_id=RequestId(data=b"\x01"), - total=Uint8(0), - enrs=[], - ) - - encoded = encode_message(nodes) - decoded = decode_message(encoded) - - assert isinstance(decoded, Nodes) - assert decoded.total == Uint8(0) - - -class TestTalkReqCodec: - """Tests for TALKREQ message encoding/decoding.""" - - def test_encode_decode_roundtrip(self): - """Test that TALKREQ encodes and decodes correctly.""" - talkreq = TalkReq( - request_id=RequestId(data=b"\x01\x02\x03"), - protocol=b"eth2", - request=b"hello", - ) - - encoded = encode_message(talkreq) - decoded = decode_message(encoded) - - assert isinstance(decoded, TalkReq) - assert bytes(decoded.request_id) == bytes(talkreq.request_id) - assert decoded.protocol == talkreq.protocol - assert decoded.request == talkreq.request - - def test_empty_request(self): - """Test TALKREQ with empty request.""" - talkreq = TalkReq( - request_id=RequestId(data=b"\x01"), - protocol=b"test", - request=b"", - ) - - encoded = encode_message(talkreq) - decoded = decode_message(encoded) - - assert isinstance(decoded, TalkReq) - assert decoded.request == b"" - - -class TestTalkRespCodec: - """Tests for TALKRESP message encoding/decoding.""" - - def test_encode_decode_roundtrip(self): - """Test that TALKRESP encodes and decodes correctly.""" - talkresp = TalkResp( - request_id=RequestId(data=b"\x01\x02\x03"), - response=b"world", - ) - - encoded = encode_message(talkresp) - decoded = decode_message(encoded) - - assert isinstance(decoded, TalkResp) - assert bytes(decoded.request_id) == bytes(talkresp.request_id) - assert decoded.response == talkresp.response - - def test_empty_response(self): - """Test TALKRESP with empty response (protocol unknown).""" - talkresp = TalkResp( - request_id=RequestId(data=b"\x01"), - response=b"", - ) - - encoded = encode_message(talkresp) - decoded = decode_message(encoded) - - assert isinstance(decoded, TalkResp) - assert decoded.response == b"" - - -class TestDecodingErrors: - """Tests for message decoding error handling.""" - - def test_empty_data_raises(self): - """Test that empty data raises MessageDecodingError.""" - with pytest.raises(MessageDecodingError, match="Message too short"): - decode_message(b"") - - def test_single_byte_raises(self): - """Test that single byte raises MessageDecodingError.""" - with pytest.raises(MessageDecodingError, match="Message too short"): - decode_message(b"\x01") - - def test_unknown_message_type_raises(self): - """Test that unknown message type raises MessageDecodingError.""" - with pytest.raises(MessageDecodingError, match="Unknown message type"): - decode_message(b"\xff\xc0") # Unknown type + empty RLP list - - def test_invalid_rlp_raises(self): - """Test that invalid RLP raises MessageDecodingError.""" - with pytest.raises(MessageDecodingError): - decode_message(b"\x01\xff\xff") # PING type + invalid RLP - - -class TestEncodingErrors: - """Tests for message encoding error handling.""" - - def test_encode_unknown_type_raises(self): - """Encoding an unsupported message type raises MessageEncodingError.""" - with pytest.raises(MessageEncodingError, match="Unknown message type"): - encode_message("not_a_message") # type: ignore[arg-type] - - -class TestRequestIdDecoding: - """Tests for request ID decoding edge cases.""" - - def test_request_id_too_long_raises(self): - """Request ID longer than 8 bytes raises ValueError.""" - with pytest.raises(ValueError, match="Request ID too long"): - _decode_request_id(bytes(9)) - - -class TestRequestIdGeneration: - """Tests for request ID generation.""" - - def test_generates_8_byte_id(self): - """Test that generated request ID is 8 bytes.""" - request_id = RequestId.generate() - assert len(request_id) == 8 - - def test_generates_different_ids(self): - """Test that each generation produces a different ID.""" - id1 = RequestId.generate() - id2 = RequestId.generate() - assert id1 != id2 - - -class TestAddressEncoding: - """IPv4 and IPv6 address handling in PONG messages.""" - - def test_pong_ipv4_4_bytes(self): - """PONG encodes IPv4 as 4 bytes.""" - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), # 127.0.0.1 - recipient_port=Port(9000), - ) - - assert len(pong.recipient_ip) == 4 - assert pong.recipient_ip == IPv4(b"\x7f\x00\x00\x01") - - # Encode and decode roundtrip. - encoded = encode_message(pong) - decoded = decode_message(encoded) - - assert isinstance(decoded, Pong) - assert decoded.recipient_ip == IPv4(b"\x7f\x00\x00\x01") - - def test_pong_ipv6_16_bytes(self): - """PONG encodes IPv6 as 16 bytes.""" - # IPv6 loopback ::1 - ipv6_loopback = IPv6(bytes(15) + b"\x01") - - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=ipv6_loopback, - recipient_port=Port(9000), - ) - - assert len(pong.recipient_ip) == 16 - - # Encode and decode roundtrip. - encoded = encode_message(pong) - decoded = decode_message(encoded) - - assert isinstance(decoded, Pong) - assert decoded.recipient_ip == ipv6_loopback - - def test_pong_common_ipv4_addresses(self): - """Common IPv4 addresses encode correctly.""" - test_addresses = [ - (IPv4(b"\x00\x00\x00\x00"), "0.0.0.0"), - (IPv4(b"\x7f\x00\x00\x01"), "127.0.0.1"), - (IPv4(b"\xc0\xa8\x01\x01"), "192.168.1.1"), - (IPv4(b"\xff\xff\xff\xff"), "255.255.255.255"), - ] - - for ip_bytes, _ in test_addresses: - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=ip_bytes, - recipient_port=Port(9000), - ) - - encoded = encode_message(pong) - decoded = decode_message(encoded) - - assert isinstance(decoded, Pong) - assert decoded.recipient_ip == ip_bytes - - def test_pong_common_ipv6_addresses(self): - """Common IPv6 addresses encode correctly.""" - # ::1 (loopback) - ipv6_loopback = IPv6(bytes(15) + b"\x01") - - # fe80::1 (link-local) - ipv6_link_local = IPv6(b"\xfe\x80" + bytes(13) + b"\x01") - - test_addresses = [ - IPv6(bytes(16)), # :: - ipv6_loopback, # ::1 - ipv6_link_local, # fe80::1 - ] - - for ip_bytes in test_addresses: - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=ip_bytes, - recipient_port=Port(9000), - ) - - encoded = encode_message(pong) - decoded = decode_message(encoded) - - assert isinstance(decoded, Pong) - assert decoded.recipient_ip == ip_bytes - - -class TestPortEncoding: - """Port encoding in PONG messages.""" - - def test_pong_port_common_values(self): - """Common port values encode correctly.""" - test_ports = [ - 80, # HTTP - 443, # HTTPS - 8545, # Ethereum RPC - 9000, # Discovery default - 30303, # devp2p default - 65535, # Maximum port - ] - - for port_value in test_ports: - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(port_value), - ) - - encoded = encode_message(pong) - decoded = decode_message(encoded) - - assert isinstance(decoded, Pong) - assert int(decoded.recipient_port) == port_value - - def test_pong_port_boundary_values(self): - """Port boundary values encode correctly.""" - # Minimum port (0 is valid in UDP) - pong_min = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(0), - ) - - encoded = encode_message(pong_min) - decoded = decode_message(encoded) - assert isinstance(decoded, Pong) - assert int(decoded.recipient_port) == 0 - - # Maximum port - pong_max = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(65535), - ) - - encoded = encode_message(pong_max) - decoded = decode_message(encoded) - assert isinstance(decoded, Pong) - assert int(decoded.recipient_port) == 65535 diff --git a/tests/lean_spec/subspecs/networking/discovery/test_config.py b/tests/lean_spec/subspecs/networking/discovery/test_config.py deleted file mode 100644 index 4ed11c44..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_config.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Tests for Discovery v5 configuration.""" - -import pytest -from pydantic import ValidationError - -from lean_spec.subspecs.networking.discovery.config import ( - ALPHA, - BOND_EXPIRY_SECS, - HANDSHAKE_TIMEOUT_SECS, - K_BUCKET_SIZE, - MAX_NODES_RESPONSE, - REQUEST_TIMEOUT_SECS, - DiscoveryConfig, -) - - -class TestDiscoveryConfig: - """Tests for DiscoveryConfig Pydantic model.""" - - def test_defaults_match_module_constants(self): - """Default config values match the module-level constants.""" - config = DiscoveryConfig() - - assert config.k_bucket_size == K_BUCKET_SIZE - assert config.alpha == ALPHA - assert config.request_timeout_secs == REQUEST_TIMEOUT_SECS - assert config.handshake_timeout_secs == HANDSHAKE_TIMEOUT_SECS - assert config.max_nodes_response == MAX_NODES_RESPONSE - assert config.bond_expiry_secs == BOND_EXPIRY_SECS - - def test_custom_values_accepted(self): - """Custom values override defaults.""" - config = DiscoveryConfig( - k_bucket_size=32, - alpha=5, - request_timeout_secs=2.0, - handshake_timeout_secs=5.0, - max_nodes_response=8, - bond_expiry_secs=3600, - ) - - assert config.k_bucket_size == 32 - assert config.alpha == 5 - assert config.request_timeout_secs == 2.0 - assert config.handshake_timeout_secs == 5.0 - assert config.max_nodes_response == 8 - assert config.bond_expiry_secs == 3600 - - def test_strict_model_rejects_extra_fields(self): - """DiscoveryConfig rejects unknown fields (strict mode).""" - with pytest.raises(ValidationError): - DiscoveryConfig(unknown_field="oops") # type: ignore[call-arg] diff --git a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py deleted file mode 100644 index e7f5cb6b..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py +++ /dev/null @@ -1,335 +0,0 @@ -"""Tests for Discovery v5 cryptographic primitives.""" - -import pytest -from cryptography.exceptions import InvalidTag - -from lean_spec.subspecs.networking.discovery.crypto import ( - _decompress_pubkey, - aes_ctr_decrypt, - aes_ctr_encrypt, - aes_gcm_decrypt, - aes_gcm_encrypt, - ecdh_agree, - generate_secp256k1_keypair, - pubkey_to_uncompressed, - sign_id_nonce, - verify_id_nonce_signature, -) -from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes64 -from tests.lean_spec.helpers import make_challenge_data - - -class TestAesCtr: - """Tests for AES-CTR encryption/decryption.""" - - def test_encrypt_decrypt_roundtrip(self): - """Test that decryption reverses encryption.""" - key = Bytes16.zero() - iv = Bytes16.zero() - plaintext = b"Hello, Discovery v5!" - - ciphertext = aes_ctr_encrypt(key, iv, plaintext) - decrypted = aes_ctr_decrypt(key, iv, ciphertext) - - assert decrypted == plaintext - - def test_encryption_produces_different_output(self): - """Test that encryption actually transforms the data.""" - key = Bytes16(bytes.fromhex("00" * 16)) - iv = Bytes16(bytes.fromhex("00" * 16)) - plaintext = b"test data" - - ciphertext = aes_ctr_encrypt(key, iv, plaintext) - - assert ciphertext != plaintext - - def test_different_ivs_produce_different_ciphertext(self): - """Test that different IVs produce different ciphertext.""" - key = Bytes16.zero() - plaintext = b"same data" - - iv1 = Bytes16(bytes.fromhex("00" * 16)) - iv2 = Bytes16(bytes.fromhex("01" + "00" * 15)) - - ct1 = aes_ctr_encrypt(key, iv1, plaintext) - ct2 = aes_ctr_encrypt(key, iv2, plaintext) - - assert ct1 != ct2 - - -class TestAesGcm: - """Tests for AES-GCM encryption/decryption.""" - - def test_encrypt_decrypt_roundtrip(self): - """Test that decryption reverses encryption.""" - key = Bytes16.zero() - nonce = Bytes12.zero() - plaintext = b"Hello, Discovery v5!" - aad = b"additional data" - - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) - decrypted = aes_gcm_decrypt(key, nonce, ciphertext, aad) - - assert decrypted == plaintext - - def test_ciphertext_includes_auth_tag(self): - """Test that ciphertext is longer than plaintext (includes 16-byte tag).""" - key = Bytes16.zero() - nonce = Bytes12.zero() - plaintext = b"test" - aad = b"" - - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) - - assert len(ciphertext) == len(plaintext) + 16 - - def test_wrong_aad_fails_decryption(self): - """Test that wrong AAD causes authentication failure.""" - key = Bytes16.zero() - nonce = Bytes12.zero() - plaintext = b"secret" - aad = b"correct aad" - - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) - - with pytest.raises(InvalidTag): - aes_gcm_decrypt(key, nonce, ciphertext, b"wrong aad") - - -class TestEcdh: - """Tests for secp256k1 ECDH key agreement.""" - - def test_ecdh_symmetric(self): - """Test that ECDH produces the same shared secret for both parties.""" - priv_a, pub_a = generate_secp256k1_keypair() - priv_b, pub_b = generate_secp256k1_keypair() - - secret_ab = ecdh_agree(priv_a, pub_b) - secret_ba = ecdh_agree(priv_b, pub_a) - - assert secret_ab == secret_ba - - def test_ecdh_produces_33_byte_secret(self): - """Test that ECDH produces a 33-byte compressed point shared secret.""" - priv_a, pub_a = generate_secp256k1_keypair() - priv_b, pub_b = generate_secp256k1_keypair() - - secret = ecdh_agree(priv_a, pub_b) - - assert len(secret) == 33 - - def test_different_keypairs_produce_different_secrets(self): - """Test that different keypairs produce different shared secrets.""" - priv_a, pub_a = generate_secp256k1_keypair() - priv_b, pub_b = generate_secp256k1_keypair() - priv_c, pub_c = generate_secp256k1_keypair() - - secret_ab = ecdh_agree(priv_a, pub_b) - secret_ac = ecdh_agree(priv_a, pub_c) - - assert secret_ab != secret_ac - - -class TestKeypairGeneration: - """Tests for secp256k1 keypair generation.""" - - def test_generates_32_byte_private_key(self): - """Test that generated private key is 32 bytes.""" - priv, pub = generate_secp256k1_keypair() - assert len(priv) == 32 - - def test_generates_33_byte_compressed_public_key(self): - """Test that generated public key is 33 bytes (compressed).""" - priv, pub = generate_secp256k1_keypair() - assert len(pub) == 33 - - def test_public_key_starts_with_02_or_03(self): - """Test that compressed public key has correct prefix.""" - priv, pub = generate_secp256k1_keypair() - assert pub[0] in (0x02, 0x03) - - def test_generates_different_keys_each_time(self): - """Test that each generation produces different keys.""" - priv1, pub1 = generate_secp256k1_keypair() - priv2, pub2 = generate_secp256k1_keypair() - - assert priv1 != priv2 - assert pub1 != pub2 - - -class TestPubkeyConversion: - """Tests for public key format conversion.""" - - def test_uncompressed_is_65_bytes(self): - """Test that uncompressed format is 65 bytes.""" - _, compressed = generate_secp256k1_keypair() - uncompressed = pubkey_to_uncompressed(compressed) - - assert len(uncompressed) == 65 - - def test_uncompressed_starts_with_04(self): - """Test that uncompressed format has 0x04 prefix.""" - _, compressed = generate_secp256k1_keypair() - uncompressed = pubkey_to_uncompressed(compressed) - - assert uncompressed[0] == 0x04 - - def test_passthrough_for_65_byte_key(self): - """Test that 65-byte uncompressed key passes through unchanged.""" - _, compressed = generate_secp256k1_keypair() - uncompressed = pubkey_to_uncompressed(compressed) - - # Passing an already-uncompressed key returns the same bytes. - result = pubkey_to_uncompressed(uncompressed) - assert result == uncompressed - assert len(result) == 65 - - -class TestIdNonceSignature: - """Tests for ID nonce signing and verification.""" - - def test_sign_and_verify(self): - """Test that signature verifies correctly.""" - priv, pub = generate_secp256k1_keypair() - challenge_data = make_challenge_data() - dest_node_id = Bytes32.zero() - - # Need a valid ephemeral pubkey. - _, eph_pub = generate_secp256k1_keypair() - - signature = sign_id_nonce(priv, challenge_data, eph_pub, dest_node_id) - - assert verify_id_nonce_signature(signature, challenge_data, eph_pub, dest_node_id, pub) - - def test_signature_is_64_bytes(self): - """Test that signature is 64 bytes (r || s).""" - priv, _ = generate_secp256k1_keypair() - _, eph_pub = generate_secp256k1_keypair() - challenge_data = make_challenge_data() - dest_node_id = Bytes32.zero() - - signature = sign_id_nonce(priv, challenge_data, eph_pub, dest_node_id) - - assert len(signature) == 64 - - def test_wrong_pubkey_fails_verification(self): - """Test that verification fails with wrong public key.""" - priv, _ = generate_secp256k1_keypair() - _, wrong_pub = generate_secp256k1_keypair() - _, eph_pub = generate_secp256k1_keypair() - challenge_data = make_challenge_data() - dest_node_id = Bytes32.zero() - - signature = sign_id_nonce(priv, challenge_data, eph_pub, dest_node_id) - - result = verify_id_nonce_signature( - signature, challenge_data, eph_pub, dest_node_id, wrong_pub - ) - assert not result - - def test_wrong_challenge_data_fails_verification(self): - """Test that verification fails with wrong challenge data.""" - priv, pub = generate_secp256k1_keypair() - _, eph_pub = generate_secp256k1_keypair() - challenge_data = make_challenge_data() - wrong_challenge_data = make_challenge_data(bytes.fromhex("01" + "00" * 15)) - dest_node_id = Bytes32.zero() - - signature = sign_id_nonce(priv, challenge_data, eph_pub, dest_node_id) - - assert not verify_id_nonce_signature( - signature, wrong_challenge_data, eph_pub, dest_node_id, pub - ) - - -class TestEcdhNegativeCases: - """Negative tests for ECDH key agreement.""" - - def test_zero_private_key_rejected(self): - """ECDH rejects an all-zero private key (point at infinity).""" - _, pub = generate_secp256k1_keypair() - with pytest.raises(ValueError, match="point at infinity"): - ecdh_agree(Bytes32(bytes(32)), pub) - - def test_invalid_private_key_too_short(self): - """ECDH rejects private key shorter than 32 bytes.""" - _, pub = generate_secp256k1_keypair() - with pytest.raises((ValueError, TypeError)): - ecdh_agree(bytes(16), pub) # type: ignore[arg-type] - - -class TestSignIdNonceNegativeCases: - """Negative tests for ID nonce signing.""" - - def test_zero_private_key_rejected(self): - """Signing rejects an all-zero private key.""" - _, eph_pub = generate_secp256k1_keypair() - with pytest.raises((ValueError, Exception)): - sign_id_nonce( - Bytes32(bytes(32)), - make_challenge_data(), - eph_pub, - Bytes32.zero(), - ) - - -class TestVerifyIdNonceNegativeCases: - """Negative tests for ID nonce signature verification.""" - - def test_truncated_signature(self): - """Verification rejects signatures shorter than 64 bytes.""" - _, pub = generate_secp256k1_keypair() - _, eph_pub = generate_secp256k1_keypair() - - result = verify_id_nonce_signature( - Bytes64(bytes(63) + b"\x00"), # 64 bytes but content is garbage - make_challenge_data(), - eph_pub, - Bytes32.zero(), - pub, - ) - assert not result - - def test_wrong_length_node_id(self): - """Verification rejects non-32-byte node ID.""" - _, pub = generate_secp256k1_keypair() - _, eph_pub = generate_secp256k1_keypair() - - result = verify_id_nonce_signature( - Bytes64(bytes(64)), - make_challenge_data(), - eph_pub, - Bytes32(bytes(16) + bytes(16)), # 32 bytes, but let's test wrong content - pub, - ) - assert not result - - -class TestDecompressPubkeyNegativeCases: - """Negative tests for public key decompression.""" - - def test_invalid_prefix_byte(self): - """Decompression rejects keys with invalid prefix.""" - # 33 bytes but prefix is 0x05 (not 0x02 or 0x03) - bad_key = bytes([0x05]) + bytes(32) - with pytest.raises(ValueError, match="Invalid public key encoding"): - _decompress_pubkey(bad_key) - - def test_wrong_length(self): - """Decompression rejects keys with invalid length.""" - with pytest.raises(ValueError, match="Invalid public key encoding"): - _decompress_pubkey(bytes(20)) - - -class TestAesGcmNegativeCases: - """Additional negative tests for AES-GCM.""" - - def test_decrypt_with_wrong_key(self): - """AES-GCM decryption fails with wrong key.""" - key = Bytes16.zero() - wrong_key = Bytes16(bytes([0xFF] * 16)) - nonce = Bytes12.zero() - - ciphertext = aes_gcm_encrypt(key, nonce, b"secret", b"aad") - with pytest.raises(InvalidTag): - aes_gcm_decrypt(wrong_key, nonce, ciphertext, b"aad") diff --git a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py deleted file mode 100644 index c431f353..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py +++ /dev/null @@ -1,672 +0,0 @@ -"""Tests for Discovery v5 handshake state machine.""" - -import time - -import pytest - -from lean_spec.subspecs.networking.discovery.crypto import ( - generate_secp256k1_keypair, - sign_id_nonce, -) -from lean_spec.subspecs.networking.discovery.handshake import ( - HandshakeError, - HandshakeManager, - HandshakeResult, - HandshakeState, - PendingHandshake, -) -from lean_spec.subspecs.networking.discovery.keys import compute_node_id -from lean_spec.subspecs.networking.discovery.messages import IdNonce, Nonce -from lean_spec.subspecs.networking.discovery.packet import ( - HandshakeAuthdata, - WhoAreYouAuthdata, - decode_handshake_authdata, - decode_whoareyou_authdata, - encode_handshake_authdata, -) -from lean_spec.subspecs.networking.discovery.session import SessionCache -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.enr.keys import EnrKey -from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes16, Bytes32, Bytes33, Bytes64 -from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY - - -@pytest.fixture -def local_keypair(): - """Generate a local keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - -@pytest.fixture -def remote_keypair(): - """Generate a remote keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - -@pytest.fixture -def session_cache(): - """Create a session cache.""" - return SessionCache() - - -@pytest.fixture -def manager(local_keypair, session_cache): - """Create a handshake manager.""" - priv, pub, node_id = local_keypair - - return HandshakeManager( - local_node_id=node_id, - local_private_key=priv, - local_enr_rlp=b"mock_enr", - local_enr_seq=SeqNumber(1), - session_cache=session_cache, - ) - - -class TestPendingHandshake: - """Tests for PendingHandshake dataclass.""" - - def test_create_pending_handshake(self): - """Test creating a pending handshake.""" - pending = PendingHandshake( - state=HandshakeState.IDLE, - remote_node_id=NodeId(bytes(32)), - ) - - assert pending.state == HandshakeState.IDLE - assert pending.id_nonce is None - assert pending.ephemeral_privkey is None - - def test_is_expired_false_for_new(self): - """Test that new handshake is not expired.""" - pending = PendingHandshake( - state=HandshakeState.IDLE, - remote_node_id=NodeId(bytes(32)), - ) - - assert not pending.is_expired(timeout_secs=1.0) - - def test_is_expired_true_for_old(self): - """Test that old handshake is expired.""" - pending = PendingHandshake( - state=HandshakeState.IDLE, - remote_node_id=NodeId(bytes(32)), - started_at=time.time() - 10, - ) - - assert pending.is_expired(timeout_secs=1.0) - - -class TestHandshakeManager: - """Tests for HandshakeManager.""" - - def test_start_handshake(self, manager): - """Test starting a handshake as initiator.""" - remote_node_id = bytes(32) - - pending = manager.start_handshake(remote_node_id) - - assert pending.state == HandshakeState.SENT_ORDINARY - assert pending.remote_node_id == remote_node_id - - def test_get_pending(self, manager): - """Test getting a pending handshake.""" - remote_node_id = bytes(32) - - manager.start_handshake(remote_node_id) - pending = manager.get_pending(remote_node_id) - - assert pending is not None - assert pending.remote_node_id == remote_node_id - - def test_get_pending_nonexistent(self, manager): - """Test getting nonexistent pending handshake.""" - pending = manager.get_pending(bytes(32)) - assert pending is None - - def test_cancel_handshake(self, manager): - """Test canceling a handshake.""" - remote_node_id = bytes(32) - - manager.start_handshake(remote_node_id) - assert manager.cancel_handshake(remote_node_id) - assert manager.get_pending(remote_node_id) is None - - def test_cancel_nonexistent_returns_false(self, manager): - """Test that canceling nonexistent handshake returns False.""" - assert not manager.cancel_handshake(bytes(32)) - - def test_create_whoareyou(self, manager): - """Test creating a WHOAREYOU challenge.""" - remote_node_id = bytes(32) - request_nonce = Nonce(bytes(12)) - remote_enr_seq = SeqNumber(0) - masking_iv = Bytes16(bytes(16)) - - id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( - remote_node_id, request_nonce, remote_enr_seq, masking_iv - ) - - assert len(id_nonce) == 16 - assert nonce == request_nonce - - # Verify authdata decodes correctly - decoded = decode_whoareyou_authdata(authdata) - assert bytes(decoded.id_nonce) == id_nonce - assert decoded.enr_seq == remote_enr_seq - - # Verify challenge_data structure: masking-iv || static-header || authdata - # masking-iv (16) + static-header (23) + authdata (24) = 63 bytes - assert len(challenge_data) == 63 - assert challenge_data[:16] == bytes(masking_iv) - assert challenge_data[39:] == authdata # 16 + 23 = 39 - - # Check pending state - pending = manager.get_pending(remote_node_id) - assert pending is not None - assert pending.state == HandshakeState.SENT_WHOAREYOU - assert bytes(pending.id_nonce) == id_nonce - assert pending.challenge_data == challenge_data - - def test_cleanup_expired(self, manager): - """Test cleanup of expired handshakes.""" - remote1 = bytes.fromhex("01" + "00" * 31) - remote2 = bytes.fromhex("02" + "00" * 31) - - # Create handshakes with short timeout - manager._timeout_secs = 0.001 - manager.start_handshake(remote1) - manager.start_handshake(remote2) - - time.sleep(0.01) - removed = manager.cleanup_expired() - - assert removed == 2 - assert manager.get_pending(remote1) is None - assert manager.get_pending(remote2) is None - - -class TestHandshakeState: - """Tests for HandshakeState enum.""" - - def test_states_exist(self): - """Test that all expected states exist.""" - assert HandshakeState.IDLE - assert HandshakeState.SENT_ORDINARY - assert HandshakeState.SENT_WHOAREYOU - assert HandshakeState.COMPLETED - - def test_states_are_distinct(self): - """Test that states are distinct.""" - states = [ - HandshakeState.IDLE, - HandshakeState.SENT_ORDINARY, - HandshakeState.SENT_WHOAREYOU, - HandshakeState.COMPLETED, - ] - - assert len(set(states)) == 4 - - -class TestHandshakeStateTransitions: - """Verify all state machine transitions.""" - - def test_idle_to_sent_ordinary_on_start_handshake(self, manager): - """Starting a handshake transitions to SENT_ORDINARY state. - - When initiating contact with a node that has no session, - we send a MESSAGE that will trigger WHOAREYOU. - """ - remote_node_id = bytes(32) - - pending = manager.start_handshake(remote_node_id) - - assert pending.state == HandshakeState.SENT_ORDINARY - assert pending.remote_node_id == remote_node_id - - def test_sent_ordinary_state_has_no_challenge_data(self, manager): - """In SENT_ORDINARY state, challenge data is not yet available. - - Challenge data only becomes available after receiving WHOAREYOU. - """ - remote_node_id = bytes(32) - - pending = manager.start_handshake(remote_node_id) - - assert pending.state == HandshakeState.SENT_ORDINARY - assert pending.id_nonce is None - assert pending.challenge_data is None - assert pending.ephemeral_privkey is None - - def test_create_whoareyou_transitions_to_sent_whoareyou(self, manager): - """Creating WHOAREYOU transitions to SENT_WHOAREYOU state. - - When we receive an undecryptable MESSAGE, we respond with WHOAREYOU. - """ - remote_node_id = bytes(32) - request_nonce = Nonce(bytes(12)) - remote_enr_seq = SeqNumber(0) - masking_iv = Bytes16(bytes(16)) - - id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( - remote_node_id, request_nonce, remote_enr_seq, masking_iv - ) - - pending = manager.get_pending(remote_node_id) - - assert pending is not None - assert pending.state == HandshakeState.SENT_WHOAREYOU - assert bytes(pending.id_nonce) == id_nonce - assert pending.challenge_data == challenge_data - - def test_sent_whoareyou_state_has_challenge_data(self, manager): - """In SENT_WHOAREYOU state, all challenge data is stored.""" - remote_node_id = bytes(32) - request_nonce = Nonce(bytes(12)) - remote_enr_seq = SeqNumber(5) - masking_iv = Bytes16(bytes(16)) - - manager.create_whoareyou(remote_node_id, request_nonce, remote_enr_seq, masking_iv) - - pending = manager.get_pending(remote_node_id) - - assert pending.id_nonce is not None - assert pending.challenge_data is not None - assert pending.challenge_nonce == request_nonce - assert pending.remote_enr_seq == remote_enr_seq - - def test_handshake_overwrites_previous_pending(self, manager): - """Starting new handshake overwrites any previous pending state.""" - remote_node_id = bytes(32) - - # Start first handshake. - pending1 = manager.start_handshake(remote_node_id) - timestamp1 = pending1.started_at - - # Wait a bit and start another. - time.sleep(0.01) - pending2 = manager.start_handshake(remote_node_id) - - # Should have new pending with later timestamp. - assert pending2.started_at > timestamp1 - assert manager.get_pending(remote_node_id) is pending2 - - -class TestHandshakeValidation: - """Handshake security validation tests.""" - - def test_handle_handshake_requires_pending_state(self, manager, remote_keypair): - """Handshake fails if no pending state exists for the remote.""" - remote_priv, remote_pub, remote_node_id = remote_keypair - - # Create fake handshake authdata. - fake_authdata = HandshakeAuthdata( - src_id=NodeId(remote_node_id), - sig_size=64, - eph_key_size=33, - id_signature=Bytes64(bytes(64)), - eph_pubkey=Bytes33(bytes(33)), - record=None, - ) - - # Should fail because no WHOAREYOU was sent. - with pytest.raises(HandshakeError, match="No pending handshake"): - manager.handle_handshake(NodeId(remote_node_id), fake_authdata) - - def test_handle_handshake_requires_sent_whoareyou_state(self, manager, remote_keypair): - """Handshake fails if not in SENT_WHOAREYOU state.""" - remote_priv, remote_pub, remote_node_id = remote_keypair - - # Start handshake (puts in SENT_ORDINARY state). - manager.start_handshake(NodeId(remote_node_id)) - - fake_authdata = HandshakeAuthdata( - src_id=NodeId(remote_node_id), - sig_size=64, - eph_key_size=33, - id_signature=Bytes64(bytes(64)), - eph_pubkey=Bytes33(bytes(33)), - record=None, - ) - - # Should fail because we're in SENT_ORDINARY, not SENT_WHOAREYOU. - with pytest.raises(HandshakeError, match="Unexpected handshake state"): - manager.handle_handshake(NodeId(remote_node_id), fake_authdata) - - def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair): - """Handshake fails if src_id doesn't match expected remote.""" - remote_priv, remote_pub, remote_node_id = remote_keypair - - # Set up WHOAREYOU state. - manager.create_whoareyou( - NodeId(remote_node_id), - Nonce(bytes(12)), - SeqNumber(0), - Bytes16(bytes(16)), - ) - - # Create authdata with different src_id. - wrong_src_id = NodeId(bytes([0xFF] * 32)) - fake_authdata = HandshakeAuthdata( - src_id=wrong_src_id, - sig_size=64, - eph_key_size=33, - id_signature=Bytes64(bytes(64)), - eph_pubkey=Bytes33(bytes(33)), - record=None, - ) - - # Should fail due to source ID mismatch. - with pytest.raises(HandshakeError, match="Source ID mismatch"): - manager.handle_handshake(NodeId(remote_node_id), fake_authdata) - - def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypair): - """Handshake fails if enr_seq=0 and no ENR included. - - When we don't know the remote's ENR (signaled by enr_seq=0 in WHOAREYOU), - the remote MUST include their ENR in the HANDSHAKE response. - """ - remote_priv, remote_pub, remote_node_id = remote_keypair - - # Set up WHOAREYOU with enr_seq=0 (unknown). - manager.create_whoareyou( - NodeId(remote_node_id), - Nonce(bytes(12)), - SeqNumber(0), # enr_seq = 0 means we don't know remote's ENR - Bytes16(bytes(16)), - ) - - # Create authdata without ENR record. - fake_authdata = HandshakeAuthdata( - src_id=NodeId(remote_node_id), - sig_size=64, - eph_key_size=33, - id_signature=Bytes64(bytes(64)), - eph_pubkey=Bytes33(bytes(33)), - record=None, # No ENR included. - ) - - # Should fail because ENR is required. - with pytest.raises(HandshakeError, match="ENR required"): - manager.handle_handshake(NodeId(remote_node_id), fake_authdata) - - def test_successful_handshake_with_signature_verification( - self, manager, remote_keypair, session_cache - ): - """Full handshake succeeds when signature is valid. - - Exercises the complete WHOAREYOU -> HANDSHAKE -> session flow. - """ - remote_priv, remote_pub, remote_node_id = remote_keypair - - # Node A (manager) creates WHOAREYOU for remote. - masking_iv = Bytes16(bytes(16)) - id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( - NodeId(remote_node_id), Nonce(bytes(12)), SeqNumber(0), masking_iv - ) - - # Remote creates handshake response. - eph_priv, eph_pub = generate_secp256k1_keypair() - local_node_id = manager._local_node_id - - # Remote signs the id_nonce proving ownership. - id_signature = sign_id_nonce( - Bytes32(remote_priv), - challenge_data, - Bytes33(eph_pub), - Bytes32(local_node_id), - ) - - # Remote includes their ENR since enr_seq=0. - remote_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4", EnrKey("secp256k1"): bytes(remote_pub)}, - ) - - authdata_bytes = encode_handshake_authdata( - src_id=NodeId(remote_node_id), - id_signature=id_signature, - eph_pubkey=eph_pub, - record=remote_enr.to_rlp(), - ) - - handshake = decode_handshake_authdata(authdata_bytes) - - # Manager processes handshake - should succeed. - result = manager.handle_handshake(NodeId(remote_node_id), handshake) - - assert result is not None - assert isinstance(result, HandshakeResult) - assert result.session is not None - assert len(result.session.send_key) == 16 - assert len(result.session.recv_key) == 16 - - def test_handle_handshake_rejects_invalid_signature( - self, manager, remote_keypair, session_cache - ): - """Handshake fails when signature is invalid.""" - remote_priv, remote_pub, remote_node_id = remote_keypair - - # Set up WHOAREYOU state. - masking_iv = Bytes16(bytes(16)) - manager.create_whoareyou(NodeId(remote_node_id), Nonce(bytes(12)), SeqNumber(0), masking_iv) - - # Generate ephemeral key. - _eph_priv, eph_pub = generate_secp256k1_keypair() - - # Create authdata with INVALID signature (all-zero 64 bytes). - remote_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4", EnrKey("secp256k1"): bytes(remote_pub)}, - ) - - authdata_bytes = encode_handshake_authdata( - src_id=NodeId(remote_node_id), - id_signature=Bytes64(bytes(64)), # Wrong signature. - eph_pubkey=eph_pub, - record=remote_enr.to_rlp(), - ) - - handshake = decode_handshake_authdata(authdata_bytes) - - with pytest.raises(HandshakeError, match="Invalid ID signature"): - manager.handle_handshake(NodeId(remote_node_id), handshake) - - -class TestHandshakeConcurrency: - """Concurrent handshake handling tests.""" - - def test_multiple_handshakes_independent(self, manager): - """Handshakes to different peers don't interfere.""" - remote1 = bytes.fromhex("01" + "00" * 31) - remote2 = bytes.fromhex("02" + "00" * 31) - remote3 = bytes.fromhex("03" + "00" * 31) - - # Start handshakes with different remotes. - manager.start_handshake(remote1) - manager.start_handshake(remote2) - - # Create WHOAREYOU for third remote. - manager.create_whoareyou(remote3, Nonce(bytes(12)), SeqNumber(0), Bytes16(bytes(16))) - - # All should have independent state. - assert manager.get_pending(remote1).state == HandshakeState.SENT_ORDINARY - assert manager.get_pending(remote2).state == HandshakeState.SENT_ORDINARY - assert manager.get_pending(remote3).state == HandshakeState.SENT_WHOAREYOU - - def test_cancel_one_handshake_preserves_others(self, manager): - """Canceling one handshake doesn't affect others.""" - remote1 = bytes.fromhex("01" + "00" * 31) - remote2 = bytes.fromhex("02" + "00" * 31) - - manager.start_handshake(remote1) - manager.start_handshake(remote2) - - # Cancel first. - result = manager.cancel_handshake(remote1) - assert result is True - - # First should be gone, second should remain. - assert manager.get_pending(remote1) is None - assert manager.get_pending(remote2) is not None - assert manager.get_pending(remote2).state == HandshakeState.SENT_ORDINARY - - def test_expired_handshake_cleanup_selective(self, manager): - """Cleanup only removes expired handshakes.""" - remote1 = bytes.fromhex("01" + "00" * 31) - remote2 = bytes.fromhex("02" + "00" * 31) - - # Set short timeout. - manager._timeout_secs = 0.01 - - # Start first handshake. - manager.start_handshake(remote1) - - # Wait for expiry. - time.sleep(0.02) - - # Start second handshake (not expired yet). - manager.start_handshake(remote2) - - # Cleanup should remove only expired. - removed = manager.cleanup_expired() - assert removed == 1 - - assert manager.get_pending(remote1) is None - assert manager.get_pending(remote2) is not None - - def test_get_pending_returns_none_for_expired(self, manager): - """Getting an expired pending handshake returns None and cleans up.""" - remote = bytes.fromhex("01" + "00" * 31) - - manager._timeout_secs = 0.01 - manager.start_handshake(remote) - - time.sleep(0.02) - - # Should return None because expired. - pending = manager.get_pending(remote) - assert pending is None - - def test_id_nonce_uniqueness_across_challenges(self, manager): - """Each WHOAREYOU challenge has a unique id_nonce.""" - remote1 = bytes.fromhex("01" + "00" * 31) - remote2 = bytes.fromhex("02" + "00" * 31) - - nonce = Nonce(bytes(12)) - iv = Bytes16(bytes(16)) - id_nonce1, _, _, _ = manager.create_whoareyou(remote1, nonce, SeqNumber(0), iv) - id_nonce2, _, _, _ = manager.create_whoareyou(remote2, nonce, SeqNumber(0), iv) - - # Each challenge should have unique id_nonce. - assert id_nonce1 != id_nonce2 - - -class TestHandshakeEnrInclusion: - """Tests for ENR inclusion/exclusion in HANDSHAKE responses.""" - - def test_enr_included_when_remote_seq_is_stale(self, local_keypair, remote_keypair): - """HANDSHAKE includes our ENR when remote's known seq is lower than ours. - - When the WHOAREYOU's enr_seq < our local_enr_seq, the remote has a - stale copy of our ENR. We include our current ENR so they can update. - """ - local_priv, local_pub, local_node_id = local_keypair - remote_priv, remote_pub, remote_node_id = remote_keypair - - session_cache = SessionCache() - manager = HandshakeManager( - local_node_id=local_node_id, - local_private_key=local_priv, - local_enr_rlp=b"mock_enr_data", - local_enr_seq=SeqNumber(5), - session_cache=session_cache, - ) - - # Remote creates WHOAREYOU with enr_seq=0 (stale). - whoareyou = WhoAreYouAuthdata( - id_nonce=IdNonce(bytes(16)), - enr_seq=SeqNumber(0), - ) - - challenge_data = bytes(63) - authdata, _, _ = manager.create_handshake_response( - remote_node_id=NodeId(remote_node_id), - whoareyou=whoareyou, - remote_pubkey=remote_pub, - challenge_data=challenge_data, - ) - - # Decode authdata and verify ENR is present. - decoded = decode_handshake_authdata(authdata) - assert decoded.record is not None - - def test_enr_excluded_when_remote_seq_is_current(self, local_keypair, remote_keypair): - """HANDSHAKE excludes our ENR when remote's known seq >= ours. - - When the remote already has our current ENR, sending it again - wastes bandwidth. The handshake packet should omit the record. - """ - local_priv, local_pub, local_node_id = local_keypair - remote_priv, remote_pub, remote_node_id = remote_keypair - - session_cache = SessionCache() - manager = HandshakeManager( - local_node_id=local_node_id, - local_private_key=local_priv, - local_enr_rlp=b"mock_enr_data", - local_enr_seq=SeqNumber(5), - session_cache=session_cache, - ) - - # Remote creates WHOAREYOU with enr_seq=5 (current). - whoareyou = WhoAreYouAuthdata( - id_nonce=IdNonce(bytes(16)), - enr_seq=SeqNumber(5), - ) - - challenge_data = bytes(63) - authdata, _, _ = manager.create_handshake_response( - remote_node_id=NodeId(remote_node_id), - whoareyou=whoareyou, - remote_pubkey=remote_pub, - challenge_data=challenge_data, - ) - - # Decode authdata and verify ENR is absent. - decoded = decode_handshake_authdata(authdata) - assert decoded.record is None - - -class TestHandshakeENRCache: - """Tests for ENR caching in handshake manager.""" - - def test_register_enr_stores_in_cache(self, manager): - """Registered ENRs are retrievable from cache.""" - remote_node_id = compute_node_id(NODE_B_PUBKEY) - - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): NODE_B_PUBKEY, - }, - ) - - manager.register_enr(remote_node_id, enr) - - cached = manager.get_cached_enr(remote_node_id) - assert cached is enr - - def test_get_cached_enr_returns_none_for_unknown(self, manager): - """Getting uncached ENR returns None.""" - unknown_id = bytes(32) - assert manager.get_cached_enr(unknown_id) is None diff --git a/tests/lean_spec/subspecs/networking/discovery/test_integration.py b/tests/lean_spec/subspecs/networking/discovery/test_integration.py deleted file mode 100644 index 66a13cb5..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_integration.py +++ /dev/null @@ -1,375 +0,0 @@ -""" -Integration tests for Discovery v5. - -Tests full protocol flows between components. -""" - -from __future__ import annotations - -import time - -import pytest - -from lean_spec.subspecs.networking.discovery.codec import ( - decode_message, - encode_message, -) -from lean_spec.subspecs.networking.discovery.crypto import ( - aes_gcm_decrypt, - generate_secp256k1_keypair, -) -from lean_spec.subspecs.networking.discovery.handshake import HandshakeManager -from lean_spec.subspecs.networking.discovery.keys import compute_node_id, derive_keys_from_pubkey -from lean_spec.subspecs.networking.discovery.messages import ( - Nonce, - PacketFlag, - Ping, - RequestId, -) -from lean_spec.subspecs.networking.discovery.packet import ( - decode_handshake_authdata, - decode_message_authdata, - decode_packet_header, - decode_whoareyou_authdata, - encode_message_authdata, - encode_packet, -) -from lean_spec.subspecs.networking.discovery.routing import ( - NodeEntry, - RoutingTable, -) -from lean_spec.subspecs.networking.discovery.session import Session, SessionCache -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.enr.keys import EnrKey -from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes12, Bytes16, Bytes64 - - -@pytest.fixture -def node_a_keys(): - """Node A's keypair.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return {"private_key": priv, "public_key": pub, "node_id": NodeId(node_id)} - - -@pytest.fixture -def node_b_keys(): - """Node B's keypair.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return {"private_key": priv, "public_key": pub, "node_id": NodeId(node_id)} - - -class TestEncryptedPacketRoundtrip: - """Test encrypted packet encoding/decoding.""" - - def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): - """MESSAGE packet encrypts and decrypts correctly.""" - # Build mock challenge_data for key derivation. - # Format: masking-iv (16) + static-header (23) + authdata (24) = 63 bytes. - masking_iv = Bytes16(bytes(16)) - static_header = b"discv5" + b"\x00\x01\x01" + bytes(12) + b"\x00\x18" - authdata = bytes(24) - challenge_data = masking_iv + static_header + authdata - - # Create session keys (derived from ECDH). - # Node A is initiator. - send_key, recv_key = derive_keys_from_pubkey( - local_private_key=node_a_keys["private_key"], - remote_public_key=node_b_keys["public_key"], - local_node_id=node_a_keys["node_id"], - remote_node_id=node_b_keys["node_id"], - challenge_data=challenge_data, - is_initiator=True, - ) - - # Create a PING message. - ping = Ping( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - ) - message_bytes = encode_message(ping) - - # Create authdata. - authdata = encode_message_authdata(node_a_keys["node_id"]) - nonce = Nonce.generate() - - # Encode packet. - packet = encode_packet( - dest_node_id=node_b_keys["node_id"], - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=message_bytes, - encryption_key=send_key, - ) - - # Decode header. - header, ciphertext, message_ad = decode_packet_header(node_b_keys["node_id"], packet) - - assert header.flag == PacketFlag.MESSAGE - - # Decode authdata. - decoded_authdata = decode_message_authdata(header.authdata) - assert decoded_authdata.src_id == node_a_keys["node_id"] - - # Node B derives keys as recipient (using same challenge_data). - b_send_key, b_recv_key = derive_keys_from_pubkey( - local_private_key=node_b_keys["private_key"], - remote_public_key=node_a_keys["public_key"], - local_node_id=node_b_keys["node_id"], - remote_node_id=node_a_keys["node_id"], - challenge_data=challenge_data, - is_initiator=False, - ) - - # Node B uses recv_key to decrypt (which equals Node A's send_key). - # message_ad = masking-iv || plaintext header (per spec). - plaintext = aes_gcm_decrypt( - Bytes16(b_recv_key), Bytes12(header.nonce), ciphertext, message_ad - ) - - # Decode message. - decoded_ping = decode_message(plaintext) - assert isinstance(decoded_ping, Ping) - assert int(decoded_ping.enr_seq) == 1 - - -class TestSessionEstablishment: - """Test session key establishment flow.""" - - def test_session_cache_operations(self, node_a_keys, node_b_keys): - """Session cache stores and retrieves sessions.""" - cache = SessionCache() - - now = time.time() - session = Session( - node_id=node_b_keys["node_id"], - send_key=Bytes16(bytes(16)), - recv_key=Bytes16(bytes(16)), - created_at=now, - last_seen=now, - is_initiator=True, - ) - - cache.create( - node_id=session.node_id, - send_key=session.send_key, - recv_key=session.recv_key, - is_initiator=session.is_initiator, - ) - - retrieved = cache.get(node_b_keys["node_id"]) - assert retrieved is not None - assert retrieved.node_id == node_b_keys["node_id"] - - # Session cache eviction is tested in test_session.py TestSessionCache.test_eviction_when_full - - -class TestRoutingTableIntegration: - """Test routing table with node entries.""" - - def test_add_and_lookup_nodes(self, node_a_keys): - """Add nodes and perform lookup.""" - table = RoutingTable(local_id=NodeId(node_a_keys["node_id"])) - - # Add several nodes. - node_ids = [] - for i in range(20): - node_id = NodeId(bytes([i * 10]) + bytes(31)) - entry = NodeEntry( - node_id=node_id, - enr_seq=SeqNumber(1), - verified=True, - ) - table.add(entry) - node_ids.append(node_id) - - assert table.node_count() == 20 - - # Lookup closest to a target. - target = NodeId(bytes(32)) - closest = table.closest_nodes(target, 16) - - assert len(closest) == 16 - - def test_bucket_distribution(self, node_a_keys): - """Nodes distribute across buckets by distance.""" - table = RoutingTable(local_id=NodeId(node_a_keys["node_id"])) - - # Add nodes with varying first bytes. - for i in range(256): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id) - table.add(entry) - - # Count non-empty buckets. - non_empty = sum(1 for b in table.buckets if not b.is_empty) - - # Should have nodes in multiple buckets. - assert non_empty > 1 - - -class TestHandshakeManagerIntegration: - """Test handshake manager flows.""" - - def test_whoareyou_generation(self, node_a_keys, node_b_keys): - """WHOAREYOU challenge generation.""" - cache = SessionCache() - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4"}, - ) - - manager = HandshakeManager( - local_node_id=node_a_keys["node_id"], - local_private_key=node_a_keys["private_key"], - local_enr_rlp=enr.to_rlp(), - local_enr_seq=SeqNumber(1), - session_cache=cache, - ) - - # Create WHOAREYOU. - request_nonce = Nonce(bytes(12)) - masking_iv = Bytes16(bytes(16)) - id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( - remote_node_id=node_b_keys["node_id"], - request_nonce=request_nonce, - remote_enr_seq=SeqNumber(0), - masking_iv=masking_iv, - ) - - assert len(id_nonce) == 16 - assert len(authdata) == 24 - - # Decode authdata. - decoded = decode_whoareyou_authdata(authdata) - assert bytes(decoded.id_nonce) == id_nonce - - def test_start_and_cancel_handshake(self, node_a_keys, node_b_keys): - """Handshake can be started and cancelled.""" - cache = SessionCache() - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4"}, - ) - - manager = HandshakeManager( - local_node_id=node_a_keys["node_id"], - local_private_key=node_a_keys["private_key"], - local_enr_rlp=enr.to_rlp(), - local_enr_seq=SeqNumber(1), - session_cache=cache, - ) - - # Start handshake. - pending = manager.start_handshake(node_b_keys["node_id"]) - assert pending is not None - assert pending.remote_node_id == node_b_keys["node_id"] - - # Get pending. - retrieved = manager.get_pending(node_b_keys["node_id"]) - assert retrieved is pending - - # Cancel. - result = manager.cancel_handshake(node_b_keys["node_id"]) - assert result is True - - # Should be gone. - assert manager.get_pending(node_b_keys["node_id"]) is None - - -class TestFullHandshakeFlow: - """Test complete handshake between two nodes.""" - - def test_handshake_key_agreement(self, node_a_keys, node_b_keys): - """ - Full handshake establishes compatible session keys. - - 1. Node A sends MESSAGE (no session) - 2. Node B can't decrypt, sends WHOAREYOU - 3. Node A responds with HANDSHAKE - 4. Both derive same session keys - """ - cache_a = SessionCache() - cache_b = SessionCache() - - enr_a = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4", EnrKey("secp256k1"): node_a_keys["public_key"]}, - ) - enr_b = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4", EnrKey("secp256k1"): node_b_keys["public_key"]}, - ) - - manager_a = HandshakeManager( - local_node_id=node_a_keys["node_id"], - local_private_key=node_a_keys["private_key"], - local_enr_rlp=enr_a.to_rlp(), - local_enr_seq=SeqNumber(1), - session_cache=cache_a, - ) - - manager_b = HandshakeManager( - local_node_id=node_b_keys["node_id"], - local_private_key=node_b_keys["private_key"], - local_enr_rlp=enr_b.to_rlp(), - local_enr_seq=SeqNumber(1), - session_cache=cache_b, - ) - - # Step 1: Node A starts handshake. - manager_a.start_handshake(node_b_keys["node_id"]) - - # Step 2: Node B creates WHOAREYOU. - request_nonce = Nonce(bytes(12)) - masking_iv = Bytes16(bytes(16)) - id_nonce, whoareyou_authdata, _, challenge_data = manager_b.create_whoareyou( - remote_node_id=node_a_keys["node_id"], - request_nonce=request_nonce, - remote_enr_seq=SeqNumber(0), - masking_iv=masking_iv, - ) - - # Decode WHOAREYOU for Node A to use. - whoareyou = decode_whoareyou_authdata(whoareyou_authdata) - - # Step 3: Node A creates HANDSHAKE response. - # This requires Node A to have Node B's public key and the challenge_data. - handshake_authdata, send_key, recv_key = manager_a.create_handshake_response( - remote_node_id=node_b_keys["node_id"], - whoareyou=whoareyou, - remote_pubkey=node_b_keys["public_key"], - challenge_data=challenge_data, - ) - - # Verify keys were derived. - assert len(send_key) == 16 - assert len(recv_key) == 16 - - # Decode handshake authdata. - handshake = decode_handshake_authdata(handshake_authdata) - assert handshake.src_id == node_a_keys["node_id"] - - # Step 4: Node B processes HANDSHAKE. - result = manager_b.handle_handshake( - remote_node_id=node_a_keys["node_id"], - handshake=handshake, - ) - - # Handshake completed successfully - session was established. - assert result is not None - assert result.session is not None - assert len(result.session.send_key) == 16 - assert len(result.session.recv_key) == 16 - - # Cross-key verification: A's send_key must equal B's recv_key and vice versa. - # This confirms both sides derived compatible session keys from the handshake. - assert send_key == result.session.recv_key - assert recv_key == result.session.send_key diff --git a/tests/lean_spec/subspecs/networking/discovery/test_keys.py b/tests/lean_spec/subspecs/networking/discovery/test_keys.py deleted file mode 100644 index 038fd31a..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_keys.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Tests for Discovery v5 key derivation.""" - -from lean_spec.subspecs.networking.discovery.crypto import ( - generate_secp256k1_keypair, - pubkey_to_uncompressed, -) -from lean_spec.subspecs.networking.discovery.keys import ( - compute_node_id, - derive_keys, - derive_keys_from_pubkey, -) -from lean_spec.types import Bytes32, Bytes33 -from tests.lean_spec.helpers import make_challenge_data - - -class TestDeriveKeys: - """Tests for session key derivation.""" - - def test_derives_two_16_byte_keys(self): - """Test that key derivation produces two 16-byte keys.""" - secret = Bytes33(bytes(33)) - initiator_id = Bytes32(bytes(32)) - recipient_id = Bytes32(bytes(32)) - challenge_data = make_challenge_data() - - init_key, recv_key = derive_keys(secret, initiator_id, recipient_id, challenge_data) - - assert len(init_key) == 16 - assert len(recv_key) == 16 - - def test_different_secrets_produce_different_keys(self): - """Test that different secrets produce different keys.""" - secret1 = Bytes33(bytes.fromhex("00" * 33)) - secret2 = Bytes33(bytes.fromhex("01" + "00" * 32)) - initiator_id = Bytes32(bytes(32)) - recipient_id = Bytes32(bytes(32)) - challenge_data = make_challenge_data() - - keys1 = derive_keys(secret1, initiator_id, recipient_id, challenge_data) - keys2 = derive_keys(secret2, initiator_id, recipient_id, challenge_data) - - assert keys1 != keys2 - - def test_different_node_ids_produce_different_keys(self): - """Test that different node IDs produce different keys.""" - secret = Bytes33(bytes(33)) - initiator_id1 = Bytes32(bytes.fromhex("00" * 32)) - initiator_id2 = Bytes32(bytes.fromhex("01" + "00" * 31)) - recipient_id = Bytes32(bytes(32)) - challenge_data = make_challenge_data() - - keys1 = derive_keys(secret, initiator_id1, recipient_id, challenge_data) - keys2 = derive_keys(secret, initiator_id2, recipient_id, challenge_data) - - assert keys1 != keys2 - - def test_different_challenge_data_produce_different_keys(self): - """Test that different challenge data produces different keys.""" - secret = Bytes33(bytes(33)) - initiator_id = Bytes32(bytes(32)) - recipient_id = Bytes32(bytes(32)) - challenge_data1 = make_challenge_data(bytes.fromhex("00" * 16)) - challenge_data2 = make_challenge_data(bytes.fromhex("01" + "00" * 15)) - - keys1 = derive_keys(secret, initiator_id, recipient_id, challenge_data1) - keys2 = derive_keys(secret, initiator_id, recipient_id, challenge_data2) - - assert keys1 != keys2 - - def test_order_matters(self): - """Test that initiator and recipient order matters.""" - secret = Bytes33(bytes(33)) - node_a = Bytes32(bytes.fromhex("aa" * 32)) - node_b = Bytes32(bytes.fromhex("bb" * 32)) - challenge_data = make_challenge_data() - - keys_ab = derive_keys(secret, node_a, node_b, challenge_data) - keys_ba = derive_keys(secret, node_b, node_a, challenge_data) - - assert keys_ab != keys_ba - - -class TestDeriveKeysFromPubkey: - """Tests for key derivation from ECDH.""" - - def test_initiator_and_recipient_derive_compatible_keys(self): - """Test that both parties derive compatible keys.""" - priv_a, pub_a = generate_secp256k1_keypair() - priv_b, pub_b = generate_secp256k1_keypair() - node_a = compute_node_id(pub_a) - node_b = compute_node_id(pub_b) - challenge_data = make_challenge_data() - - # A initiates to B - send_a, recv_a = derive_keys_from_pubkey( - priv_a, pub_b, node_a, node_b, challenge_data, is_initiator=True - ) - - # B responds to A - send_b, recv_b = derive_keys_from_pubkey( - priv_b, pub_a, node_b, node_a, challenge_data, is_initiator=False - ) - - # A's send key should be B's recv key and vice versa - assert send_a == recv_b - assert recv_a == send_b - - -class TestComputeNodeId: - """Tests for node ID computation.""" - - def test_computes_32_byte_node_id(self): - """Test that node ID is 32 bytes.""" - _, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - - assert len(node_id) == 32 - - def test_same_pubkey_produces_same_node_id(self): - """Test that same public key produces same node ID.""" - _, pub = generate_secp256k1_keypair() - - id1 = compute_node_id(pub) - id2 = compute_node_id(pub) - - assert id1 == id2 - - def test_different_pubkeys_produce_different_node_ids(self): - """Test that different public keys produce different node IDs.""" - _, pub1 = generate_secp256k1_keypair() - _, pub2 = generate_secp256k1_keypair() - - id1 = compute_node_id(pub1) - id2 = compute_node_id(pub2) - - assert id1 != id2 - - def test_accepts_compressed_pubkey(self): - """Test that compressed public key format is accepted.""" - _, pub = generate_secp256k1_keypair() - assert len(pub) == 33 - - node_id = compute_node_id(pub) - assert len(node_id) == 32 - - def test_accepts_uncompressed_pubkey(self): - """Test that uncompressed public key format is accepted.""" - - _, compressed = generate_secp256k1_keypair() - uncompressed = pubkey_to_uncompressed(compressed) - assert len(uncompressed) == 65 - - node_id = compute_node_id(uncompressed) - assert len(node_id) == 32 - - def test_compressed_and_uncompressed_produce_same_id(self): - """Test that both formats produce the same node ID.""" - - _, compressed = generate_secp256k1_keypair() - uncompressed = pubkey_to_uncompressed(compressed) - - id_compressed = compute_node_id(compressed) - id_uncompressed = compute_node_id(uncompressed) - - assert id_compressed == id_uncompressed diff --git a/tests/lean_spec/subspecs/networking/discovery/test_messages.py b/tests/lean_spec/subspecs/networking/discovery/test_messages.py deleted file mode 100644 index ac1440b7..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_messages.py +++ /dev/null @@ -1,348 +0,0 @@ -""" -Tests for Discovery v5 protocol messages, types, and constants. - -Validates that protocol constants, message types, custom types, and -configuration match the Discovery v5 specification. -""" - -from __future__ import annotations - -from lean_spec.subspecs.networking.discovery.config import ( - ALPHA, - BOND_EXPIRY_SECS, - BUCKET_COUNT, - HANDSHAKE_TIMEOUT_SECS, - K_BUCKET_SIZE, - MAX_NODES_RESPONSE, - REQUEST_TIMEOUT_SECS, - DiscoveryConfig, -) -from lean_spec.subspecs.networking.discovery.messages import ( - MAX_REQUEST_ID_LENGTH, - PROTOCOL_ID, - PROTOCOL_VERSION, - Distance, - FindNode, - IdNonce, - IPv4, - IPv6, - MessageType, - Nodes, - Nonce, - PacketFlag, - Ping, - Pong, - Port, - RequestId, - TalkReq, - TalkResp, -) -from lean_spec.subspecs.networking.discovery.packet import WhoAreYouAuthdata -from lean_spec.subspecs.networking.types import SeqNumber -from lean_spec.types import Uint8, Uint16, Uint64 -from tests.lean_spec.subspecs.networking.discovery.conftest import SPEC_ID_NONCE - - -class TestProtocolConstants: - """Verify protocol constants match Discovery v5 specification.""" - - def test_protocol_id(self): - assert PROTOCOL_ID == b"discv5" - assert len(PROTOCOL_ID) == 6 - - def test_protocol_version(self): - assert PROTOCOL_VERSION == 0x0001 - - def test_max_request_id_length(self): - assert MAX_REQUEST_ID_LENGTH == 8 - - def test_k_bucket_size(self): - assert K_BUCKET_SIZE == 16 - - def test_alpha_concurrency(self): - assert ALPHA == 3 - - def test_bucket_count(self): - assert BUCKET_COUNT == 256 - - def test_request_timeout(self): - assert REQUEST_TIMEOUT_SECS == 0.5 - - def test_handshake_timeout(self): - assert HANDSHAKE_TIMEOUT_SECS == 1.0 - - def test_max_nodes_response(self): - assert MAX_NODES_RESPONSE == 16 - - def test_bond_expiry(self): - assert BOND_EXPIRY_SECS == 86400 - - -class TestCustomTypes: - """Tests for custom Discovery v5 types.""" - - def test_request_id_limit(self): - req_id = RequestId(data=b"\x01\x02\x03\x04\x05\x06\x07\x08") - assert len(req_id.data) == 8 - - def test_request_id_variable_length(self): - req_id = RequestId(data=b"\x01") - assert len(req_id.data) == 1 - - def test_ipv4_length(self): - ip = IPv4(b"\xc0\xa8\x01\x01") - assert len(ip) == 4 - - def test_ipv6_length(self): - ip = IPv6(b"\x00" * 15 + b"\x01") - assert len(ip) == 16 - - def test_id_nonce_length(self): - nonce = IdNonce(b"\x01" * 16) - assert len(nonce) == 16 - - def test_nonce_length(self): - nonce = Nonce(b"\x01" * 12) - assert len(nonce) == 12 - - def test_distance_type(self): - d = Distance(256) - assert isinstance(d, Uint16) - - def test_port_type(self): - p = Port(30303) - assert isinstance(p, Uint16) - - def test_enr_seq_type(self): - seq = SeqNumber(42) - assert isinstance(seq, Uint64) - - -class TestPacketFlag: - """Tests for packet type flags.""" - - def test_message_flag(self): - assert PacketFlag.MESSAGE == 0 - - def test_whoareyou_flag(self): - assert PacketFlag.WHOAREYOU == 1 - - def test_handshake_flag(self): - assert PacketFlag.HANDSHAKE == 2 - - -class TestMessageTypes: - """Verify message type codes match wire protocol spec.""" - - def test_ping_type(self): - assert MessageType.PING == 0x01 - - def test_pong_type(self): - assert MessageType.PONG == 0x02 - - def test_findnode_type(self): - assert MessageType.FINDNODE == 0x03 - - def test_nodes_type(self): - assert MessageType.NODES == 0x04 - - def test_talkreq_type(self): - assert MessageType.TALKREQ == 0x05 - - def test_talkresp_type(self): - assert MessageType.TALKRESP == 0x06 - - def test_experimental_types(self): - assert MessageType.REGTOPIC == 0x07 - assert MessageType.TICKET == 0x08 - assert MessageType.REGCONFIRMATION == 0x09 - assert MessageType.TOPICQUERY == 0x0A - - -class TestDiscoveryConfig: - """Tests for DiscoveryConfig.""" - - def test_default_values(self): - config = DiscoveryConfig() - - assert config.k_bucket_size == K_BUCKET_SIZE - assert config.alpha == ALPHA - assert config.request_timeout_secs == REQUEST_TIMEOUT_SECS - assert config.handshake_timeout_secs == HANDSHAKE_TIMEOUT_SECS - assert config.max_nodes_response == MAX_NODES_RESPONSE - assert config.bond_expiry_secs == BOND_EXPIRY_SECS - - def test_custom_values(self): - config = DiscoveryConfig( - k_bucket_size=8, - alpha=5, - request_timeout_secs=2.0, - ) - assert config.k_bucket_size == 8 - assert config.alpha == 5 - assert config.request_timeout_secs == 2.0 - - -class TestPing: - """Tests for PING message.""" - - def test_creation_with_types(self): - ping = Ping( - request_id=RequestId(data=b"\x00\x00\x00\x01"), - enr_seq=SeqNumber(2), - ) - - assert ping.request_id.data == b"\x00\x00\x00\x01" - assert ping.enr_seq == SeqNumber(2) - - def test_max_request_id_length(self): - ping = Ping( - request_id=RequestId(data=b"\x01\x02\x03\x04\x05\x06\x07\x08"), - enr_seq=SeqNumber(1), - ) - assert len(ping.request_id.data) == 8 - - -class TestPong: - """Tests for PONG message.""" - - def test_creation_ipv4(self): - pong = Pong( - request_id=RequestId(data=b"\x00\x00\x00\x01"), - enr_seq=SeqNumber(42), - recipient_ip=IPv4(b"\xc0\xa8\x01\x01"), - recipient_port=Port(9000), - ) - - assert pong.enr_seq == SeqNumber(42) - assert len(pong.recipient_ip) == 4 - assert pong.recipient_port == Port(9000) - - def test_creation_ipv6(self): - ipv6 = IPv6(b"\x00" * 15 + b"\x01") - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=ipv6, - recipient_port=Port(30303), - ) - - assert len(pong.recipient_ip) == 16 - - -class TestFindNode: - """Tests for FINDNODE message.""" - - def test_single_distance(self): - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(256)], - ) - assert findnode.distances == [Distance(256)] - - def test_multiple_distances(self): - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(0), Distance(1), Distance(255), Distance(256)], - ) - assert Distance(0) in findnode.distances - assert Distance(256) in findnode.distances - - def test_distance_zero_returns_self(self): - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(0)], - ) - assert findnode.distances == [Distance(0)] - - -class TestNodes: - """Tests for NODES message.""" - - def test_single_response(self): - nodes = Nodes( - request_id=RequestId(data=b"\x01"), - total=Uint8(1), - enrs=[b"enr:-example"], - ) - assert nodes.total == Uint8(1) - assert len(nodes.enrs) == 1 - - def test_multiple_responses(self): - nodes = Nodes( - request_id=RequestId(data=b"\x01"), - total=Uint8(3), - enrs=[b"enr1", b"enr2"], - ) - assert nodes.total == Uint8(3) - assert len(nodes.enrs) == 2 - - -class TestTalkReq: - """Tests for TALKREQ message.""" - - def test_creation(self): - req = TalkReq( - request_id=RequestId(data=b"\x01"), - protocol=b"portal", - request=b"payload", - ) - assert req.protocol == b"portal" - assert req.request == b"payload" - - -class TestTalkResp: - """Tests for TALKRESP message.""" - - def test_creation(self): - resp = TalkResp( - request_id=RequestId(data=b"\x01"), - response=b"response_data", - ) - assert resp.response == b"response_data" - - def test_empty_response_unknown_protocol(self): - resp = TalkResp( - request_id=RequestId(data=b"\x01"), - response=b"", - ) - assert resp.response == b"" - - -class TestWhoAreYouAuthdataConstruction: - """Tests for WHOAREYOU authdata construction.""" - - def test_creation(self): - authdata = WhoAreYouAuthdata( - id_nonce=IdNonce(b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10"), - enr_seq=SeqNumber(0), - ) - assert len(authdata.id_nonce) == 16 - assert authdata.enr_seq == SeqNumber(0) - - -class TestMessageConstructionFromTestVectors: - """Test message construction using official Discovery v5 test vector inputs.""" - - PING_REQUEST_ID = bytes.fromhex("00000001") - PING_ENR_SEQ = 2 - - def test_ping_message_construction(self): - ping = Ping( - request_id=RequestId(data=self.PING_REQUEST_ID), - enr_seq=SeqNumber(self.PING_ENR_SEQ), - ) - assert ping.request_id.data == self.PING_REQUEST_ID - assert ping.enr_seq == SeqNumber(2) - - def test_whoareyou_authdata_construction(self): - authdata = WhoAreYouAuthdata( - id_nonce=IdNonce(SPEC_ID_NONCE), - enr_seq=SeqNumber(0), - ) - assert authdata.id_nonce == IdNonce(SPEC_ID_NONCE) - assert authdata.enr_seq == SeqNumber(0) - - def test_plaintext_message_type(self): - plaintext = bytes.fromhex("01c20101") - assert plaintext[0] == MessageType.PING diff --git a/tests/lean_spec/subspecs/networking/discovery/test_packet.py b/tests/lean_spec/subspecs/networking/discovery/test_packet.py deleted file mode 100644 index a4eebce6..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_packet.py +++ /dev/null @@ -1,414 +0,0 @@ -"""Tests for Discovery v5 packet encoding/decoding.""" - -import pytest - -from lean_spec.subspecs.networking.discovery.config import MAX_PACKET_SIZE, MIN_PACKET_SIZE -from lean_spec.subspecs.networking.discovery.crypto import aes_ctr_encrypt -from lean_spec.subspecs.networking.discovery.messages import IdNonce, Nonce, PacketFlag -from lean_spec.subspecs.networking.discovery.packet import ( - HANDSHAKE_HEADER_SIZE, - MESSAGE_AUTHDATA_SIZE, - STATIC_HEADER_SIZE, - WHOAREYOU_AUTHDATA_SIZE, - decode_handshake_authdata, - decode_message_authdata, - decode_packet_header, - decode_whoareyou_authdata, - encode_handshake_authdata, - encode_message_authdata, - encode_packet, - encode_whoareyou_authdata, -) -from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes16, Bytes33, Bytes64 - - -class TestNonceGeneration: - """Tests for nonce generation.""" - - def test_generate_nonce_is_12_bytes(self): - """Test that generated nonce is 12 bytes.""" - nonce = Nonce.generate() - assert len(nonce) == 12 - - def test_generate_id_nonce_is_16_bytes(self): - """Test that generated ID nonce is 16 bytes.""" - id_nonce = IdNonce.generate() - assert len(id_nonce) == 16 - - def test_generates_different_nonces(self): - """Test that each generation produces different nonces.""" - nonce1 = Nonce.generate() - nonce2 = Nonce.generate() - assert nonce1 != nonce2 - - -class TestMessageAuthdata: - """Tests for MESSAGE packet authdata.""" - - def test_encode_message_authdata(self): - """Test MESSAGE authdata encoding.""" - src_id = NodeId(bytes(32)) - authdata = encode_message_authdata(src_id) - - assert len(authdata) == MESSAGE_AUTHDATA_SIZE - assert authdata == bytes(src_id) - - def test_decode_message_authdata(self): - """Test MESSAGE authdata decoding.""" - src_id = NodeId(bytes.fromhex("aa" * 32)) - authdata = encode_message_authdata(src_id) - decoded = decode_message_authdata(authdata) - - assert decoded.src_id == src_id - - def test_invalid_size_raises(self): - """Test that invalid authdata size raises ValueError.""" - with pytest.raises(ValueError, match="Invalid MESSAGE authdata size"): - decode_message_authdata(bytes(31)) - - -class TestWhoAreYouAuthdata: - """Tests for WHOAREYOU packet authdata.""" - - def test_encode_whoareyou_authdata(self): - """Test WHOAREYOU authdata encoding.""" - id_nonce = IdNonce(bytes(16)) - enr_seq = SeqNumber(42) - - authdata = encode_whoareyou_authdata(id_nonce, enr_seq) - - assert len(authdata) == WHOAREYOU_AUTHDATA_SIZE - - def test_decode_whoareyou_authdata(self): - """Test WHOAREYOU authdata decoding.""" - id_nonce = IdNonce(bytes.fromhex("aa" * 16)) - enr_seq = SeqNumber(12345) - - authdata = encode_whoareyou_authdata(id_nonce, enr_seq) - decoded = decode_whoareyou_authdata(authdata) - - assert decoded.id_nonce == id_nonce - assert decoded.enr_seq == enr_seq - - def test_roundtrip(self): - """Test encoding then decoding preserves values.""" - id_nonce = IdNonce(bytes.fromhex("01" * 16)) - enr_seq = SeqNumber(2**63 - 1) # Max uint64 - - authdata = encode_whoareyou_authdata(id_nonce, enr_seq) - decoded = decode_whoareyou_authdata(authdata) - - assert decoded.id_nonce == id_nonce - assert decoded.enr_seq == enr_seq - - def test_invalid_size_raises(self): - """Test that invalid authdata size raises ValueError.""" - with pytest.raises(ValueError, match="Invalid WHOAREYOU authdata size"): - decode_whoareyou_authdata(bytes(23)) - - -class TestHandshakeAuthdata: - """Tests for HANDSHAKE packet authdata.""" - - def test_encode_handshake_authdata(self): - """Test HANDSHAKE authdata encoding.""" - src_id = NodeId(bytes(32)) - id_signature = Bytes64(bytes(64)) - eph_pubkey = Bytes33(bytes([0x02]) + bytes(32)) - - authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey) - - # 32 (src_id) + 1 (sig_size) + 1 (eph_key_size) + 64 (sig) + 33 (eph) - expected_size = HANDSHAKE_HEADER_SIZE + 64 + 33 - assert len(authdata) == expected_size - - def test_decode_handshake_authdata(self): - """Test HANDSHAKE authdata decoding.""" - src_id = NodeId(bytes.fromhex("aa" * 32)) - id_signature = Bytes64(bytes.fromhex("bb" * 64)) - eph_pubkey = Bytes33(bytes([0x02]) + bytes.fromhex("cc" * 32)) - - authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey) - decoded = decode_handshake_authdata(authdata) - - assert decoded.src_id == src_id - assert decoded.sig_size == 64 - assert decoded.eph_key_size == 33 - assert decoded.id_signature == id_signature - assert decoded.eph_pubkey == eph_pubkey - assert decoded.record is None - - def test_with_enr_record(self): - """Test HANDSHAKE authdata with ENR record.""" - src_id = NodeId(bytes(32)) - id_signature = Bytes64(bytes(64)) - eph_pubkey = Bytes33(bytes([0x02]) + bytes(32)) - record = b"enr:-IS4QHCYrY..." # Mock ENR - - authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey, record) - decoded = decode_handshake_authdata(authdata) - - assert decoded.record == record - - -class TestPacketEncoding: - """Tests for full packet encoding/decoding.""" - - def test_encode_message_packet(self): - """Test MESSAGE packet encoding.""" - dest_node_id = NodeId(bytes(32)) - src_node_id = NodeId(bytes(32)) - nonce = Nonce(bytes(12)) - authdata = encode_message_authdata(src_node_id) - message = b"encrypted message" - encryption_key = Bytes16(bytes(16)) - - packet = encode_packet( - dest_node_id=dest_node_id, - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=message, - encryption_key=encryption_key, - ) - - # Packet should contain: masking_iv (16) + masked_header + encrypted_message - assert len(packet) > 16 + STATIC_HEADER_SIZE + len(authdata) - - def test_encode_whoareyou_packet(self): - """Test WHOAREYOU packet encoding.""" - dest_node_id = NodeId(bytes(32)) - nonce = Nonce(bytes(12)) - id_nonce = IdNonce(bytes(16)) - authdata = encode_whoareyou_authdata(id_nonce, SeqNumber(0)) - - packet = encode_packet( - dest_node_id=dest_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=nonce, - authdata=authdata, - message=b"", - encryption_key=None, - ) - - # WHOAREYOU has no message content - expected_size = 16 + STATIC_HEADER_SIZE + WHOAREYOU_AUTHDATA_SIZE - assert len(packet) == expected_size - - def test_decode_packet_header(self): - """Test packet header decoding.""" - local_node_id = NodeId(bytes(32)) - nonce = Nonce(bytes(12)) - authdata = encode_whoareyou_authdata(IdNonce(bytes(16)), SeqNumber(42)) - - packet = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=nonce, - authdata=authdata, - message=b"", - encryption_key=None, - ) - - header, message_bytes, _message_ad = decode_packet_header(local_node_id, packet) - - assert header.flag == PacketFlag.WHOAREYOU - assert header.nonce == nonce - assert header.authdata == authdata - assert message_bytes == b"" - - -class TestConstants: - """Tests for packet format constants.""" - - def test_static_header_size(self): - """Test static header size constant.""" - # protocol_id (6) + version (2) + flag (1) + nonce (12) + authdata_size (2) - assert STATIC_HEADER_SIZE == 23 - - def test_message_authdata_size(self): - """Test MESSAGE authdata size constant.""" - # src_id (32) - assert MESSAGE_AUTHDATA_SIZE == 32 - - def test_whoareyou_authdata_size(self): - """Test WHOAREYOU authdata size constant.""" - # id_nonce (16) + enr_seq (8) - assert WHOAREYOU_AUTHDATA_SIZE == 24 - - def test_handshake_header_size(self): - """Test HANDSHAKE header size constant.""" - # src_id (32) + sig_size (1) + eph_key_size (1) - assert HANDSHAKE_HEADER_SIZE == 34 - - -class TestPacketSizeLimits: - """Packet size boundary validation. - - Per spec: - - MIN_PACKET_SIZE = 63 bytes (masking-iv + min header) - - MAX_PACKET_SIZE = 1280 bytes (IPv6 MTU) - """ - - def test_min_packet_size_constant(self): - """MIN_PACKET_SIZE matches spec minimum.""" - # masking-iv (16) + static-header (23) + min authdata (24 for WHOAREYOU) - assert MIN_PACKET_SIZE == 63 - - def test_max_packet_size_constant(self): - """MAX_PACKET_SIZE matches IPv6 MTU.""" - # IPv6 minimum MTU = 1280 bytes - assert MAX_PACKET_SIZE == 1280 - - def test_reject_undersized_packet(self): - """Packets smaller than MIN_PACKET_SIZE are rejected.""" - local_node_id = NodeId(bytes(32)) - - # Packet that's too small. - undersized_packet = bytes(MIN_PACKET_SIZE - 1) - - with pytest.raises(ValueError, match="too small"): - decode_packet_header(local_node_id, undersized_packet) - - def test_minimum_valid_packet_structure(self): - """Minimum valid packet has correct structure.""" - # WHOAREYOU is the smallest packet type: - # masking-iv (16) + static-header (23) + authdata (24) = 63 bytes - expected_min = 16 + STATIC_HEADER_SIZE + WHOAREYOU_AUTHDATA_SIZE - assert expected_min == MIN_PACKET_SIZE - - def test_encode_packet_enforces_max_size(self): - """encode_packet raises error if packet exceeds max size.""" - src_id = NodeId(bytes(32)) - dest_id = NodeId(bytes(32)) - nonce = Nonce(bytes(12)) - encryption_key = Bytes16(bytes(16)) - - # Create authdata. - authdata = encode_message_authdata(src_id) - - # Try to create a packet with message that would exceed max size. - # Need message large enough that total > 1280 - # Overhead: masking-iv(16) + static(23) + authdata(32) + tag(16) = 87 - # So message > 1193 should trigger error. - large_message = bytes(1300) - - with pytest.raises(ValueError, match="exceeds max size"): - encode_packet( - dest_node_id=dest_id, - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=large_message, - encryption_key=encryption_key, - ) - - def test_truncated_static_header_rejected(self): - """Incomplete static header is rejected.""" - local_node_id = NodeId(bytes(32)) - - # Packet with only masking-iv and partial static header. - # masking-iv (16) + partial static header (10 bytes) = 26 bytes - truncated_packet = bytes(26) - - with pytest.raises(ValueError, match="too small"): - decode_packet_header(local_node_id, truncated_packet) - - def test_truncated_authdata_rejected(self): - """Packet with incomplete authdata is rejected.""" - local_node_id = NodeId(bytes(32)) - masking_iv = bytes(16) - - # Build a valid static header but with claimed authdata larger than packet. - # static-header: protocol-id (6) + version (2) + flag (1) + nonce (12) + authdata-size (2) - # Claim authdata size of 100 bytes. - static_header = b"discv5" + b"\x00\x01\x00" + bytes(12) + b"\x00\x64" # 0x64 = 100 - - # Encrypt/mask the header. - masking_key = local_node_id[:16] - masked_header = aes_ctr_encrypt(Bytes16(masking_key), Bytes16(masking_iv), static_header) - - # Create packet: masking-iv + masked-header + only 10 bytes (not 100). - # This will be rejected because total size < MIN_PACKET_SIZE (63 bytes) - incomplete_packet = masking_iv + masked_header + bytes(10) - - with pytest.raises(ValueError, match="too small"): - decode_packet_header(local_node_id, incomplete_packet) - - -class TestEncodePacketEdgeCases: - """Edge case tests for packet encoding.""" - - def test_message_flag_without_encryption_key_raises(self): - """MESSAGE packets require an encryption key.""" - with pytest.raises(ValueError, match="Encryption key required"): - encode_packet( - dest_node_id=NodeId(bytes(32)), - flag=PacketFlag.MESSAGE, - nonce=Nonce(bytes(12)), - authdata=encode_message_authdata(NodeId(bytes(32))), - message=b"\x01\xc2\x01\x01", - encryption_key=None, - ) - - def test_handshake_flag_without_encryption_key_raises(self): - """HANDSHAKE packets require an encryption key.""" - authdata = encode_handshake_authdata( - src_id=NodeId(bytes(32)), - id_signature=Bytes64(bytes(64)), - eph_pubkey=Bytes33(bytes([0x02]) + bytes(32)), - ) - - with pytest.raises(ValueError, match="Encryption key required"): - encode_packet( - dest_node_id=NodeId(bytes(32)), - flag=PacketFlag.HANDSHAKE, - nonce=Nonce(bytes(12)), - authdata=authdata, - message=b"\x01\xc2\x01\x01", - encryption_key=None, - ) - - -class TestPacketProtocolValidation: - """Protocol ID and version validation in packet decoding.""" - - def test_invalid_protocol_id_rejected(self): - """Packet with wrong protocol ID is rejected.""" - local_node_id = NodeId(bytes(32)) - masking_iv = bytes(16) - - # Build header with wrong protocol ID but correct structure. - # static-header: protocol-id (6) + version (2) + flag (1) + nonce (12) + authdata-size (2) - wrong_protocol_header = b"WRONG!" + b"\x00\x01\x01" + bytes(12) + b"\x00\x18" - - # Mask the entire content (header + authdata). - # Authdata for WHOAREYOU = 24 bytes. - full_content = wrong_protocol_header + bytes(24) - masking_key = local_node_id[:16] - masked_content = aes_ctr_encrypt(Bytes16(masking_key), Bytes16(masking_iv), full_content) - - # Packet = masking-iv + masked-content - packet = masking_iv + masked_content - - with pytest.raises(ValueError, match="Invalid protocol ID"): - decode_packet_header(local_node_id, packet) - - def test_invalid_protocol_version_rejected(self): - """Packet with unsupported protocol version is rejected.""" - local_node_id = NodeId(bytes(32)) - masking_iv = bytes(16) - - # Build header with wrong version (0x0099 instead of 0x0001). - wrong_version_header = b"discv5" + b"\x00\x99\x01" + bytes(12) + b"\x00\x18" - - # Full masked content. - full_content = wrong_version_header + bytes(24) - masking_key = local_node_id[:16] - masked_content = aes_ctr_encrypt(Bytes16(masking_key), Bytes16(masking_iv), full_content) - - packet = masking_iv + masked_content - - with pytest.raises(ValueError, match="Unsupported protocol version"): - decode_packet_header(local_node_id, packet) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_routing.py b/tests/lean_spec/subspecs/networking/discovery/test_routing.py deleted file mode 100644 index 8a15caf2..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_routing.py +++ /dev/null @@ -1,691 +0,0 @@ -""" -Tests for Discovery v5 routing table. - -Tests the RoutingTable and KBucket classes. -""" - -from __future__ import annotations - -import pytest - -from lean_spec.subspecs.networking.discovery.config import BUCKET_COUNT, K_BUCKET_SIZE -from lean_spec.subspecs.networking.discovery.messages import Distance -from lean_spec.subspecs.networking.discovery.routing import ( - KBucket, - NodeEntry, - RoutingTable, - log2_distance, - xor_distance, -) -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH -from lean_spec.subspecs.networking.enr.keys import EnrKey -from lean_spec.subspecs.networking.types import ForkDigest, NodeId, SeqNumber -from lean_spec.types import Bytes64 - - -class TestXorDistance: - """Tests for XOR distance calculation.""" - - def test_distance_to_self_is_zero(self): - """XOR distance between same IDs is 0.""" - node_id = NodeId(bytes(32)) - assert xor_distance(node_id, node_id) == 0 - - def test_distance_is_symmetric(self, local_node_id, remote_node_id): - """XOR distance is symmetric: d(a,b) == d(b,a).""" - d1 = xor_distance(local_node_id, remote_node_id) - d2 = xor_distance(remote_node_id, local_node_id) - assert d1 == d2 - - def test_distance_is_positive(self, local_node_id, remote_node_id): - """XOR distance between different IDs is positive.""" - d = xor_distance(local_node_id, remote_node_id) - assert d > 0 - - def test_distance_max_for_inverted_ids(self): - """Max distance occurs when IDs are bitwise inverted.""" - zeros = NodeId(bytes(32)) - ones = NodeId(bytes([0xFF] * 32)) - d = xor_distance(zeros, ones) - assert d == 2**256 - 1 - - -class TestLog2Distance: - """Tests for log2 distance calculation.""" - - def test_log2_distance_self_is_zero(self): - """Log2 distance to self is 0.""" - node_id = NodeId(bytes(32)) - assert int(log2_distance(node_id, node_id)) == 0 - - def test_log2_distance_low_byte_diff(self): - """ - Difference in first byte (big-endian). - - bytes([1]) + bytes(31) differs in byte 0 bit 0 (LSB of first byte). - XOR = 0x0100...00 = 2^248 - log2(2^248) = 248, but bit_length() returns 249. - """ - a = NodeId(bytes(32)) - b = NodeId(bytes([1]) + bytes(31)) - # The XOR has the high bit at position 248, so bit_length is 249. - assert int(log2_distance(a, b)) == 249 - - def test_log2_distance_high_bit_first_byte(self): - """ - High bit of first byte differs. - - bytes([0x80]) + bytes(31) = 0x80000...00 - XOR distance = 2^255, bit_length = 256. - """ - a = NodeId(bytes(32)) - b = NodeId(bytes([0x80]) + bytes(31)) - assert int(log2_distance(a, b)) == 256 - - def test_log2_distance_max(self): - """Max distance for completely different IDs.""" - zeros = NodeId(bytes(32)) - ones = NodeId(bytes([0xFF] * 32)) - d = log2_distance(zeros, ones) - assert int(d) == 256 - - -class TestNodeEntry: - """Tests for NodeEntry dataclass.""" - - def test_create_minimal_entry(self): - """NodeEntry with minimum required fields.""" - node_id = NodeId(bytes(32)) - entry = NodeEntry(node_id=node_id) - - assert entry.node_id == node_id - assert int(entry.enr_seq) == 0 - assert entry.last_seen == 0.0 - assert entry.endpoint is None - assert entry.verified is False - assert entry.enr is None - - def test_create_full_entry(self): - """NodeEntry with all fields.""" - node_id = NodeId(bytes(32)) - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4"}, - ) - - entry = NodeEntry( - node_id=node_id, - enr_seq=SeqNumber(42), - last_seen=123.456, - endpoint="192.168.1.1:30303", - verified=True, - enr=enr, - ) - - assert int(entry.enr_seq) == 42 - assert entry.last_seen == 123.456 - assert entry.endpoint == "192.168.1.1:30303" - assert entry.verified is True - assert entry.enr is enr - - -class TestKBucket: - """Tests for KBucket class.""" - - def test_empty_bucket(self): - """New bucket is empty.""" - bucket = KBucket() - - assert bucket.is_empty - assert not bucket.is_full - assert len(bucket) == 0 - - def test_add_to_bucket(self): - """Adding entry to bucket increases count.""" - bucket = KBucket() - entry = NodeEntry(node_id=NodeId(bytes(32))) - - result = bucket.add(entry) - - assert result is True - assert len(bucket) == 1 - assert not bucket.is_empty - - def test_add_multiple_entries(self): - """Multiple entries can be added.""" - bucket = KBucket() - - for i in range(5): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id) - bucket.add(entry) - - assert len(bucket) == 5 - - def test_bucket_full(self): - """Bucket becomes full at K_BUCKET_SIZE.""" - bucket = KBucket() - - for i in range(K_BUCKET_SIZE): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id) - bucket.add(entry) - - assert bucket.is_full - assert len(bucket) == K_BUCKET_SIZE - - def test_add_to_full_bucket_returns_false(self): - """Adding to full bucket returns False.""" - bucket = KBucket() - - for i in range(K_BUCKET_SIZE): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id) - bucket.add(entry) - - # Try to add one more. - new_entry = NodeEntry(node_id=NodeId(bytes([0xFF]) + bytes(31))) - result = bucket.add(new_entry) - - assert result is False - assert len(bucket) == K_BUCKET_SIZE - - def test_update_existing_moves_to_tail(self): - """Updating existing entry moves it to tail.""" - bucket = KBucket() - - # Add three entries. - for i in range(3): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id) - bucket.add(entry) - - # Update the first entry. - first_id = NodeId(bytes([0]) + bytes(31)) - updated_entry = NodeEntry(node_id=first_id, enr_seq=SeqNumber(999)) - result = bucket.add(updated_entry) - - assert result is True - assert len(bucket) == 3 - # Entry should be at tail (most recent). - tail = bucket.tail() - assert tail is not None - assert tail.node_id == first_id - - def test_contains(self): - """Check if bucket contains a node ID.""" - bucket = KBucket() - node_id = NodeId(bytes(32)) - entry = NodeEntry(node_id=node_id) - bucket.add(entry) - - assert bucket.contains(node_id) - assert not bucket.contains(NodeId(bytes([1]) + bytes(31))) - - def test_get_entry(self): - """Retrieve entry by node ID.""" - bucket = KBucket() - node_id = NodeId(bytes(32)) - entry = NodeEntry(node_id=node_id, enr_seq=SeqNumber(42)) - bucket.add(entry) - - retrieved = bucket.get(node_id) - assert retrieved is not None - assert int(retrieved.enr_seq) == 42 - - def test_get_missing_returns_none(self): - """Getting missing node returns None.""" - bucket = KBucket() - result = bucket.get(NodeId(bytes(32))) - assert result is None - - def test_remove_entry(self): - """Remove entry from bucket.""" - bucket = KBucket() - node_id = NodeId(bytes(32)) - entry = NodeEntry(node_id=node_id) - bucket.add(entry) - - result = bucket.remove(node_id) - - assert result is True - assert len(bucket) == 0 - assert not bucket.contains(node_id) - - def test_remove_missing_returns_false(self): - """Removing missing node returns False.""" - bucket = KBucket() - result = bucket.remove(NodeId(bytes(32))) - assert result is False - - def test_head_and_tail(self): - """Head is oldest, tail is newest.""" - bucket = KBucket() - - for i in range(3): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id) - bucket.add(entry) - - head = bucket.head() - tail = bucket.tail() - assert head is not None - assert tail is not None - assert head.node_id == NodeId(bytes([0]) + bytes(31)) - assert tail.node_id == NodeId(bytes([2]) + bytes(31)) - - def test_head_of_empty_is_none(self): - """Head of empty bucket is None.""" - bucket = KBucket() - assert bucket.head() is None - - def test_tail_of_empty_is_none(self): - """Tail of empty bucket is None.""" - bucket = KBucket() - assert bucket.tail() is None - - def test_iteration(self): - """Bucket is iterable.""" - bucket = KBucket() - - for i in range(3): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id) - bucket.add(entry) - - entries = list(bucket) - assert len(entries) == 3 - - -class TestRoutingTable: - """Tests for RoutingTable class.""" - - def test_create_table(self, local_node_id): - """Create routing table with local ID.""" - table = RoutingTable(local_id=local_node_id) - - assert table.local_id == local_node_id - assert len(table.buckets) == BUCKET_COUNT - assert table.node_count() == 0 - - def test_add_node(self, local_node_id, remote_node_id): - """Add node to routing table.""" - table = RoutingTable(local_id=local_node_id) - entry = NodeEntry(node_id=remote_node_id) - - result = table.add(entry) - - assert result is True - assert table.node_count() == 1 - - def test_add_self_returns_false(self, local_node_id): - """Cannot add self to routing table.""" - table = RoutingTable(local_id=local_node_id) - entry = NodeEntry(node_id=local_node_id) - - result = table.add(entry) - - assert result is False - assert table.node_count() == 0 - - def test_get_node(self, local_node_id, remote_node_id): - """Retrieve node from routing table.""" - table = RoutingTable(local_id=local_node_id) - entry = NodeEntry(node_id=remote_node_id, enr_seq=SeqNumber(42)) - table.add(entry) - - retrieved = table.get(remote_node_id) - assert retrieved is not None - assert int(retrieved.enr_seq) == 42 - - def test_get_missing_returns_none(self, local_node_id, remote_node_id): - """Getting missing node returns None.""" - table = RoutingTable(local_id=local_node_id) - result = table.get(remote_node_id) - assert result is None - - def test_contains(self, local_node_id, remote_node_id): - """Check if node exists in table.""" - table = RoutingTable(local_id=local_node_id) - entry = NodeEntry(node_id=remote_node_id) - table.add(entry) - - assert table.contains(remote_node_id) - assert not table.contains(local_node_id) - - def test_remove_node(self, local_node_id, remote_node_id): - """Remove node from routing table.""" - table = RoutingTable(local_id=local_node_id) - entry = NodeEntry(node_id=remote_node_id) - table.add(entry) - - result = table.remove(remote_node_id) - - assert result is True - assert table.node_count() == 0 - - def test_bucket_index(self, local_node_id, remote_node_id): - """Bucket index is based on log2 distance.""" - table = RoutingTable(local_id=local_node_id) - - idx = table.bucket_index(remote_node_id) - - # Bucket index should be in valid range. - assert 0 <= idx < BUCKET_COUNT - - def test_get_bucket(self, local_node_id, remote_node_id): - """Get bucket for a node ID.""" - table = RoutingTable(local_id=local_node_id) - - bucket = table.get_bucket(remote_node_id) - - assert isinstance(bucket, KBucket) - - def test_closest_nodes_empty_table(self, local_node_id, remote_node_id): - """Closest nodes on empty table returns empty list.""" - table = RoutingTable(local_id=local_node_id) - - closest = table.closest_nodes(remote_node_id, 16) - - assert closest == [] - - def test_closest_nodes_returns_sorted(self, local_node_id): - """Closest nodes are sorted by distance.""" - table = RoutingTable(local_id=local_node_id) - - # Add some nodes. - for i in range(10): - node_id = NodeId(bytes([i * 10]) + bytes(31)) - entry = NodeEntry(node_id=node_id) - table.add(entry) - - target = NodeId(bytes(32)) - closest = table.closest_nodes(target, 5) - - assert len(closest) == 5 - - # Verify sorted by distance. - for i in range(len(closest) - 1): - d1 = xor_distance(closest[i].node_id, target) - d2 = xor_distance(closest[i + 1].node_id, target) - assert d1 <= d2 - - def test_nodes_at_distance(self, local_node_id): - """Get nodes at specific distance.""" - table = RoutingTable(local_id=local_node_id) - - nodes = table.nodes_at_distance(Distance(128)) - - assert isinstance(nodes, list) - - def test_nodes_at_invalid_distance(self, local_node_id): - """Invalid distances return empty list.""" - table = RoutingTable(local_id=local_node_id) - - # Distance 0 returns own ENR, but routing table doesn't store self. - nodes = table.nodes_at_distance(Distance(0)) - assert nodes == [] - - # Distance > 256 is invalid. - nodes = table.nodes_at_distance(Distance(300)) - assert nodes == [] - - -class TestForkCompatibility: - """Tests for fork filtering in routing table.""" - - def test_no_fork_filter_accepts_all(self, local_node_id, remote_node_id): - """Without fork filter, all nodes are accepted.""" - table = RoutingTable(local_id=local_node_id, local_fork_digest=None) - entry = NodeEntry(node_id=remote_node_id) - - assert table.is_fork_compatible(entry) - - def test_fork_filter_rejects_without_enr(self, local_node_id, remote_node_id): - """With fork filter, nodes without ENR are rejected.""" - table = RoutingTable( - local_id=local_node_id, - local_fork_digest=ForkDigest(bytes(4)), - ) - entry = NodeEntry(node_id=remote_node_id, enr=None) - - assert not table.is_fork_compatible(entry) - - def test_fork_filter_rejects_without_eth2_data(self, local_node_id, remote_node_id): - """Nodes without eth2 data are rejected when filtering.""" - table = RoutingTable( - local_id=local_node_id, - local_fork_digest=ForkDigest(bytes(4)), - ) - - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4"}, - ) - entry = NodeEntry(node_id=remote_node_id, enr=enr) - - assert not table.is_fork_compatible(entry) - - def test_fork_filter_rejects_mismatched_fork(self, local_node_id, remote_node_id): - """Node with different fork_digest is rejected.""" - - local_fork = ForkDigest(bytes.fromhex("12345678")) - table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) - - # Build eth2 bytes with a different fork digest. - remote_digest = bytes.fromhex("deadbeef") - eth2_bytes = remote_digest + remote_digest + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("eth2"): eth2_bytes, EnrKey("id"): b"v4"}, - ) - entry = NodeEntry(node_id=remote_node_id, enr=enr) - - assert not table.add(entry) - assert not table.contains(remote_node_id) - - def test_fork_filter_accepts_matching_fork(self, local_node_id, remote_node_id): - """Node with matching fork_digest is accepted.""" - - local_fork = ForkDigest(bytes.fromhex("12345678")) - table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) - - # Build eth2 bytes with the same fork digest. - eth2_bytes = ( - bytes.fromhex("12345678") - + bytes.fromhex("12345678") - + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") - ) - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("eth2"): eth2_bytes, EnrKey("id"): b"v4"}, - ) - entry = NodeEntry(node_id=remote_node_id, enr=enr) - - assert table.add(entry) - assert table.contains(remote_node_id) - - def test_is_fork_compatible_method(self, local_node_id): - """Verify is_fork_compatible for compatible, incompatible, and no-ENR entries.""" - - local_fork = ForkDigest(bytes.fromhex("12345678")) - table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) - - # Compatible entry. - eth2_match = ( - bytes.fromhex("12345678") - + bytes.fromhex("12345678") - + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") - ) - compatible_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("eth2"): eth2_match, EnrKey("id"): b"v4"}, - ) - compatible_entry = NodeEntry(node_id=NodeId(b"\x01" * 32), enr=compatible_enr) - assert table.is_fork_compatible(compatible_entry) - - # Incompatible entry (different fork). - eth2_mismatch = ( - bytes.fromhex("deadbeef") - + bytes.fromhex("deadbeef") - + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") - ) - incompatible_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("eth2"): eth2_mismatch, EnrKey("id"): b"v4"}, - ) - incompatible_entry = NodeEntry(node_id=NodeId(b"\x02" * 32), enr=incompatible_enr) - assert not table.is_fork_compatible(incompatible_entry) - - # Entry without ENR. - no_enr_entry = NodeEntry(node_id=NodeId(b"\x03" * 32)) - assert not table.is_fork_compatible(no_enr_entry) - - -class TestIPDensityTracking: - """Tests for tracking IP address density. - - Anti-eclipse protection limits nodes per IP subnet. - This prevents attackers from filling the table with nodes - all controlled from the same network. - """ - - @pytest.fixture - def local_node_id(self): - """Create local node ID.""" - return NodeId(bytes(32)) - - def test_node_entry_with_endpoint(self, local_node_id): - """NodeEntry can store endpoint information.""" - remote_id = NodeId(bytes([0x80]) + bytes(31)) - - entry = NodeEntry( - node_id=remote_id, - endpoint="192.168.1.1:9000", - ) - - assert entry.endpoint == "192.168.1.1:9000" - assert entry.node_id == remote_id - - def test_node_entry_without_endpoint(self, local_node_id): - """NodeEntry works without endpoint.""" - remote_id = NodeId(bytes([0x80]) + bytes(31)) - - entry = NodeEntry( - node_id=remote_id, - ) - - assert entry.endpoint is None - - def test_nodes_from_same_subnet_distinct(self, local_node_id): - """Nodes from same /24 subnet are distinct but related.""" - table = RoutingTable(local_id=local_node_id) - - # Create nodes from same /24 subnet. - entries = [] - for i in range(5): - # All in 192.168.1.x/24 - node_id = NodeId(bytes([0x80 + i]) + bytes(31)) - entry = NodeEntry( - node_id=node_id, - endpoint=f"192.168.1.{i + 1}:9000", - ) - entries.append(entry) - table.add(entry) - - # All should be in the table (assuming bucket has space). - count = table.node_count() - assert count == 5 - - def test_nodes_from_different_subnets_independent(self, local_node_id): - """Nodes from different /24 subnets are independent.""" - table = RoutingTable(local_id=local_node_id) - - subnets = [ - "192.168.1.1:9000", - "192.168.2.1:9000", - "10.0.0.1:9000", - "172.16.0.1:9000", - ] - - for i, subnet in enumerate(subnets): - node_id = NodeId(bytes([0x80 + i]) + bytes(31)) - entry = NodeEntry( - node_id=node_id, - endpoint=subnet, - ) - table.add(entry) - - # All should be added. - assert table.node_count() == 4 - - def test_ipv6_subnet_tracking(self, local_node_id): - """IPv6 addresses can be tracked.""" - table = RoutingTable(local_id=local_node_id) - - # IPv6 addresses. - ipv6_addresses = [ - "[::1]:9000", - "[fe80::1]:9000", - "[2001:db8::1]:9000", - ] - - for i, addr in enumerate(ipv6_addresses): - node_id = NodeId(bytes([0x80 + i]) + bytes(31)) - entry = NodeEntry( - node_id=node_id, - endpoint=addr, - ) - table.add(entry) - - # All should be tracked. - assert table.node_count() == 3 - - -class TestRoutingTableNodeDiversity: - """Tests for ensuring node diversity in routing table.""" - - @pytest.fixture - def local_node_id(self): - """Create local node ID.""" - return NodeId(bytes(32)) - - def test_bucket_accepts_diverse_nodes(self, local_node_id): - """Buckets accept nodes from different networks.""" - table = RoutingTable(local_id=local_node_id) - - # Add nodes at same distance but different IPs. - for i in range(5): - node_id = NodeId(bytes([0x80, i]) + bytes(30)) - entry = NodeEntry( - node_id=node_id, - endpoint=f"10.{i}.0.1:9000", - ) - table.add(entry) - - # All should be added to table. - assert table.node_count() == 5 - - def test_table_tracks_all_subnets(self, local_node_id): - """Table tracks nodes across all subnets.""" - table = RoutingTable(local_id=local_node_id) - - # Add nodes to different buckets and subnets. - for bucket_prefix in range(3): - for subnet in range(3): - node_id = NodeId(bytes([1 << (7 - bucket_prefix), subnet]) + bytes(30)) - entry = NodeEntry( - node_id=node_id, - endpoint=f"192.168.{bucket_prefix}.{subnet + 1}:9000", - ) - table.add(entry) - - # All 9 nodes should be added. - assert table.node_count() == 9 diff --git a/tests/lean_spec/subspecs/networking/discovery/test_service.py b/tests/lean_spec/subspecs/networking/discovery/test_service.py deleted file mode 100644 index 97dd32cc..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_service.py +++ /dev/null @@ -1,1659 +0,0 @@ -""" -Tests for Discovery v5 service layer. - -Tests the DiscoveryService class. -""" - -from __future__ import annotations - -import asyncio -from collections.abc import Callable -from dataclasses import dataclass -from unittest.mock import AsyncMock, patch - -import pytest - -from lean_spec.subspecs.networking.discovery.codec import DiscoveryMessage -from lean_spec.subspecs.networking.discovery.config import DiscoveryConfig -from lean_spec.subspecs.networking.discovery.messages import ( - Distance, - FindNode, - IPv4, - Nodes, - Ping, - Pong, - RequestId, - TalkReq, - TalkResp, -) -from lean_spec.subspecs.networking.discovery.routing import NodeEntry -from lean_spec.subspecs.networking.discovery.service import ( - DiscoveryService, - LookupResult, -) -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.enr.keys import EnrKey -from lean_spec.subspecs.networking.types import NodeId, Port, SeqNumber -from lean_spec.types import Bytes32, Bytes64 - - -@dataclass -class SentResponse: - """A single recorded send_response call.""" - - node_id: NodeId - addr: tuple[str, int] - message: DiscoveryMessage - - -@dataclass -class SentFindNode: - """A single recorded send_findnode call.""" - - node_id: NodeId - addr: tuple[str, int] - distances: list[int] - - -@dataclass -class SentPing: - """A single recorded send_ping call.""" - - node_id: NodeId - addr: tuple[str, int] - - -@dataclass -class SentTalkReq: - """A single recorded send_talkreq call.""" - - node_id: NodeId - addr: tuple[str, int] - protocol: bytes - request: bytes - - -class FakeDiscoveryTransport: - """In-memory replacement for DiscoveryTransport. - - Records all outbound calls and returns configurable canned responses. - Synchronous bookkeeping (address registry, ENR cache, message handler) - works identically to the real transport. - """ - - def __init__(self) -> None: - """Initialize with empty state and default responses.""" - # Bookkeeping (same as real transport). - self._node_addresses: dict[NodeId, tuple[str, int]] = {} - self._enr_cache: dict[NodeId, ENR] = {} - self._message_handler: ( - Callable[[NodeId, DiscoveryMessage, tuple[str, int]], None] | None - ) = None - - # Lifecycle tracking. - self.started: bool = False - self.stopped: bool = False - self.start_count: int = 0 - - # Recorded outbound calls. - self.sent_responses: list[SentResponse] = [] - self.sent_findnodes: list[SentFindNode] = [] - self.sent_pings: list[SentPing] = [] - self.sent_talkreqs: list[SentTalkReq] = [] - - # Configurable canned return values. - self.send_response_return: bool = True - self.send_ping_return: Pong | None = None - self.send_findnode_return: list[bytes] = [] - self.send_findnode_side_effect: Callable[..., list[bytes]] | None = None - self.send_talkreq_return: bytes | None = None - - # -- Lifecycle -- - - async def start(self, host: str = "0.0.0.0", port: int = 9000) -> None: - """Record start call.""" - self.started = True - self.start_count += 1 - - async def stop(self) -> None: - """Record stop call.""" - self.stopped = True - - # -- Synchronous bookkeeping (identical to real transport) -- - - def set_message_handler( - self, - handler: Callable[[NodeId, DiscoveryMessage, tuple[str, int]], None], - ) -> None: - """Set handler for incoming messages.""" - self._message_handler = handler - - def register_node_address(self, node_id: NodeId, address: tuple[str, int]) -> None: - """Register a node's UDP address.""" - self._node_addresses[node_id] = address - - def get_node_address(self, node_id: NodeId) -> tuple[str, int] | None: - """Get a node's registered UDP address.""" - return self._node_addresses.get(node_id) - - def register_enr(self, node_id: NodeId, enr: ENR) -> None: - """Cache an ENR.""" - self._enr_cache[node_id] = enr - - # -- Async send methods that record calls and return canned values -- - - async def send_response( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - message: DiscoveryMessage, - ) -> bool: - """Record the response and return canned success/failure.""" - self.sent_responses.append(SentResponse(dest_node_id, dest_addr, message)) - return self.send_response_return - - async def send_ping(self, dest_node_id: NodeId, dest_addr: tuple[str, int]) -> Pong | None: - """Record the ping and return canned Pong or None.""" - self.sent_pings.append(SentPing(dest_node_id, dest_addr)) - return self.send_ping_return - - async def send_findnode( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - distances: list[int], - ) -> list[bytes]: - """Record the findnode and return canned ENR list.""" - self.sent_findnodes.append(SentFindNode(dest_node_id, dest_addr, distances)) - if self.send_findnode_side_effect is not None: - return self.send_findnode_side_effect(dest_node_id, dest_addr, distances) - return self.send_findnode_return - - async def send_talkreq( - self, - dest_node_id: NodeId, - dest_addr: tuple[str, int], - protocol: bytes, - request: bytes, - ) -> bytes | None: - """Record the talkreq and return canned response.""" - self.sent_talkreqs.append(SentTalkReq(dest_node_id, dest_addr, protocol, request)) - return self.send_talkreq_return - - -def _make_service( - local_enr: ENR, - local_private_key: Bytes32, - config: DiscoveryConfig | None = None, - bootnodes: list[ENR] | None = None, -) -> tuple[DiscoveryService, FakeDiscoveryTransport]: - """Create a DiscoveryService with a FakeDiscoveryTransport injected.""" - service = DiscoveryService( - local_enr=local_enr, - private_key=local_private_key, - config=config, - bootnodes=bootnodes, - ) - fake = FakeDiscoveryTransport() - service._transport = fake # type: ignore[assignment] - # Re-register the message handler on the new transport. - fake.set_message_handler(service._handle_message) - return service, fake - - -class TestDiscoveryServiceInit: - """Tests for DiscoveryService initialization.""" - - def test_init_creates_required_components(self, local_enr, local_private_key): - """Service initializes all required components.""" - service = DiscoveryService( - local_enr=local_enr, - private_key=local_private_key, - ) - - assert service._local_enr is local_enr - assert service._private_key == local_private_key - assert service._routing_table is not None - assert service._transport is not None - assert service._bond_cache is not None - assert not service._running - - def test_init_with_custom_config(self, local_enr, local_private_key): - """Service accepts custom configuration.""" - config = DiscoveryConfig(request_timeout_secs=30.0) - - service = DiscoveryService( - local_enr=local_enr, - private_key=local_private_key, - config=config, - ) - - assert service._config.request_timeout_secs == 30.0 - - def test_init_with_bootnodes(self, local_enr, local_private_key): - """Service accepts bootnodes list.""" - bootnode = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): bytes.fromhex( - "02a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" - ), - }, - ) - - service = DiscoveryService( - local_enr=local_enr, - private_key=local_private_key, - bootnodes=[bootnode], - ) - - assert len(service._bootnodes) == 1 - - def test_init_requires_public_key_in_enr(self, local_private_key): - """Service requires ENR to have a public key.""" - enr_without_pubkey = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4"}, - ) - - with pytest.raises(ValueError, match="must have a public key"): - DiscoveryService( - local_enr=enr_without_pubkey, - private_key=local_private_key, - ) - - -class TestDiscoveryServiceTalkHandlers: - """Tests for TALK protocol handlers.""" - - def test_register_talk_handler(self, local_enr, local_private_key): - """TALK handlers can be registered.""" - service, _ = _make_service(local_enr, local_private_key) - - def handler(nid: bytes, data: bytes) -> bytes: - return b"response" - - service.register_talk_handler(b"test", handler) - - assert service._talk_handlers[b"test"] is handler - - def test_multiple_talk_handlers(self, local_enr, local_private_key): - """Multiple TALK handlers for different protocols.""" - service, _ = _make_service(local_enr, local_private_key) - - def handler1(nid: bytes, data: bytes) -> bytes: - return b"response1" - - def handler2(nid: bytes, data: bytes) -> bytes: - return b"response2" - - service.register_talk_handler(b"proto1", handler1) - service.register_talk_handler(b"proto2", handler2) - - assert service._talk_handlers[b"proto1"] is handler1 - assert service._talk_handlers[b"proto2"] is handler2 - - -class TestDiscoveryServiceNodeOperations: - """Tests for node operations.""" - - def test_get_random_nodes_empty_table(self, local_enr, local_private_key): - """Get random nodes from empty table returns empty list.""" - service, _ = _make_service(local_enr, local_private_key) - - nodes = service.get_random_nodes(10) - assert nodes == [] - - def test_get_random_nodes_with_entries(self, local_enr, local_private_key): - """Get random nodes returns up to requested count.""" - service, _ = _make_service(local_enr, local_private_key) - - # Add some nodes to routing table. - for i in range(5): - node_id = bytes([i]) + bytes(31) - entry = NodeEntry(node_id=NodeId(node_id), enr_seq=SeqNumber(1)) - service._routing_table.add(entry) - - nodes = service.get_random_nodes(3) - assert len(nodes) <= 3 - - def test_get_nodes_at_distance(self, local_enr, local_private_key): - """Get nodes at specific distance.""" - service, _ = _make_service(local_enr, local_private_key) - - nodes = service.get_nodes_at_distance(128) - assert isinstance(nodes, list) - - def test_node_count_empty_table(self, local_enr, local_private_key): - """Node count for empty table is zero.""" - service, _ = _make_service(local_enr, local_private_key) - - assert service.node_count() == 0 - - -class TestLookupResult: - """Tests for LookupResult dataclass.""" - - def test_create_lookup_result(self): - """LookupResult stores all fields.""" - target = NodeId(bytes(32)) - nodes = [NodeEntry(node_id=NodeId(bytes(32)), enr_seq=SeqNumber(1))] - - result = LookupResult(target=target, nodes=nodes, queried=5) - - assert result.target == target - assert result.nodes == nodes - assert result.queried == 5 - - def test_empty_lookup_result(self): - """LookupResult can have empty nodes list.""" - result = LookupResult(target=NodeId(bytes(32)), nodes=[], queried=0) - - assert result.nodes == [] - assert result.queried == 0 - - -class TestDiscoveryServiceLifecycle: - """Tests for service lifecycle.""" - - @pytest.mark.anyio - async def test_start_sets_running_flag(self, local_enr, local_private_key): - """Starting service sets running flag.""" - service, fake = _make_service(local_enr, local_private_key) - - await service.start("127.0.0.1", 9000) - - assert service._running - assert fake.started - - # Clean up. - await service.stop() - - @pytest.mark.anyio - async def test_start_is_idempotent(self, local_enr, local_private_key): - """Starting already-running service does nothing.""" - service, fake = _make_service(local_enr, local_private_key) - - await service.start("127.0.0.1", 9000) - await service.start("127.0.0.1", 9000) - - assert fake.start_count == 1 - - await service.stop() - - @pytest.mark.anyio - async def test_stop_clears_running_flag(self, local_enr, local_private_key): - """Stopping service clears running flag.""" - service, fake = _make_service(local_enr, local_private_key) - - await service.start("127.0.0.1", 9000) - await service.stop() - - assert not service._running - assert fake.stopped - - @pytest.mark.anyio - async def test_stop_is_idempotent(self, local_enr, local_private_key): - """Stopping already-stopped service does nothing.""" - service, fake = _make_service(local_enr, local_private_key) - - # Stop without starting. - await service.stop() - await service.stop() - - # Should not call transport.stop if not running. - assert not fake.stopped - - -class TestFindNode: - """Tests for find_node lookup operation.""" - - @pytest.mark.anyio - async def test_find_node_invalid_target_length(self, local_enr, local_private_key): - """find_node rejects targets that aren't 32 bytes.""" - service, _ = _make_service(local_enr, local_private_key) - - with pytest.raises(ValueError, match="32 bytes"): - await service.find_node(b"too short") # type: ignore[arg-type] - - @pytest.mark.anyio - async def test_find_node_empty_table(self, local_enr, local_private_key): - """find_node with empty routing table returns empty result.""" - service, _ = _make_service(local_enr, local_private_key) - - result = await service.find_node(NodeId(bytes(32))) - - assert result.target == NodeId(bytes(32)) - assert result.nodes == [] - assert result.queried == 0 - - @pytest.mark.anyio - async def test_find_node_with_responses(self, local_enr, local_private_key, remote_node_id): - """find_node queries candidates and processes responses.""" - service, fake = _make_service(local_enr, local_private_key) - - entry = NodeEntry(node_id=remote_node_id, enr_seq=SeqNumber(1)) - service._routing_table.add(entry) - fake.register_node_address(remote_node_id, ("192.168.1.1", 30303)) - - target = NodeId(bytes(32)) - - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - discovered_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - EnrKey("ip"): bytes([10, 0, 0, 2]), - EnrKey("udp"): (9001).to_bytes(2, "big"), - }, - ) - - fake.send_findnode_return = [discovered_enr.to_rlp()] - result = await service.find_node(target) - - assert result.queried >= 1 - assert result.target == target - assert len(result.nodes) >= 1 - - @pytest.mark.anyio - async def test_find_node_iterative_deepening(self, local_enr, local_private_key): - """find_node iteratively queries closer nodes.""" - service, fake = _make_service(local_enr, local_private_key) - - for i in range(5): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id, enr_seq=SeqNumber(1)) - service._routing_table.add(entry) - fake.register_node_address(node_id, (f"192.168.1.{i + 1}", 30303)) - - target = NodeId(bytes(32)) - - def mock_findnode(node_id, addr, distances): - new_pubkey = bytes.fromhex( - "02a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" - ) - new_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): new_pubkey, - EnrKey("ip"): bytes([10, 0, 0, 50]), - EnrKey("udp"): (9050).to_bytes(2, "big"), - }, - ) - return [new_enr.to_rlp()] - - fake.send_findnode_side_effect = mock_findnode - result = await service.find_node(target) - - assert result.queried > 0 - assert result.target == target - - @pytest.mark.anyio - async def test_find_node_handles_exceptions_in_query(self, local_enr, local_private_key): - """find_node handles exceptions from send_findnode gracefully.""" - service, fake = _make_service(local_enr, local_private_key) - - # Add a node so the lookup has candidates. - node_id = NodeId(bytes([1]) + bytes(31)) - entry = NodeEntry(node_id=node_id, enr_seq=SeqNumber(1)) - service._routing_table.add(entry) - fake.register_node_address(node_id, ("192.168.1.1", 30303)) - - target = NodeId(bytes(32)) - - # Transport raises on send_findnode -- _query_node propagates the exception, - # and find_node's gather(return_exceptions=True) catches it. - def raise_error(*args, **kwargs): - raise RuntimeError("network error") - - fake.send_findnode_side_effect = raise_error - result = await service.find_node(target) - - assert result.target == target - assert isinstance(result.nodes, list) - - -class TestServiceIPAddressEncoding: - """Tests for IP address encoding in service layer.""" - - def test_encode_ipv4_loopback(self, local_enr, local_private_key): - """Encode IPv4 loopback address.""" - service, _ = _make_service(local_enr, local_private_key) - - encoded = service._encode_ip_address("127.0.0.1") - - assert bytes(encoded) == b"\x7f\x00\x00\x01" - assert len(encoded) == 4 - - def test_encode_ipv4_common_addresses(self, local_enr, local_private_key): - """Encode common IPv4 addresses.""" - service, _ = _make_service(local_enr, local_private_key) - - test_cases = [ - ("0.0.0.0", b"\x00\x00\x00\x00"), - ("192.168.1.1", b"\xc0\xa8\x01\x01"), - ("10.0.0.1", b"\x0a\x00\x00\x01"), - ("255.255.255.255", b"\xff\xff\xff\xff"), - ] - - for ip_str, expected_bytes in test_cases: - encoded = service._encode_ip_address(ip_str) - assert bytes(encoded) == expected_bytes - assert len(encoded) == 4 - - def test_encode_ipv6_loopback(self, local_enr, local_private_key): - """Encode IPv6 loopback address.""" - service, _ = _make_service(local_enr, local_private_key) - - encoded = service._encode_ip_address("::1") - - expected = bytes(15) + b"\x01" - assert bytes(encoded) == expected - assert len(encoded) == 16 - - def test_encode_ipv6_common_addresses(self, local_enr, local_private_key): - """Encode common IPv6 addresses.""" - service, _ = _make_service(local_enr, local_private_key) - - # :: (all zeros) - encoded_zeros = service._encode_ip_address("::") - assert bytes(encoded_zeros) == bytes(16) - - # ::1 (loopback) - encoded_loopback = service._encode_ip_address("::1") - assert bytes(encoded_loopback) == bytes(15) + b"\x01" - - -class TestBootstrap: - """Bootnode initialization tests.""" - - def test_service_accepts_bootnodes(self, local_enr, local_private_key): - """Service accepts bootnodes in constructor.""" - bootnode = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): bytes.fromhex( - "02a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" - ), - EnrKey("ip"): bytes([192, 168, 1, 1]), - EnrKey("udp"): (30303).to_bytes(2, "big"), - }, - ) - - service = DiscoveryService( - local_enr=local_enr, - private_key=local_private_key, - bootnodes=[bootnode], - ) - - assert len(service._bootnodes) == 1 - - def test_service_accepts_multiple_bootnodes(self, local_enr, local_private_key): - """Service accepts multiple bootnodes.""" - bootnodes = [] - for i in range(5): - bootnode = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(i + 1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): bytes.fromhex( - "02a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" - ), - EnrKey("ip"): bytes([192, 168, 1, i + 1]), - EnrKey("udp"): (30303 + i).to_bytes(2, "big"), - }, - ) - bootnodes.append(bootnode) - - service = DiscoveryService( - local_enr=local_enr, - private_key=local_private_key, - bootnodes=bootnodes, - ) - - assert len(service._bootnodes) == 5 - - def test_service_handles_empty_bootnodes(self, local_enr, local_private_key): - """Service handles empty bootnodes list.""" - service = DiscoveryService( - local_enr=local_enr, - private_key=local_private_key, - bootnodes=[], - ) - - assert len(service._bootnodes) == 0 - - def test_service_handles_none_bootnodes(self, local_enr, local_private_key): - """Service handles None bootnodes.""" - service = DiscoveryService( - local_enr=local_enr, - private_key=local_private_key, - bootnodes=None, - ) - - assert len(service._bootnodes) == 0 - - -class TestHandlePing: - """Tests for _handle_ping message handler.""" - - @pytest.mark.anyio - async def test_handle_ping_sends_pong(self, local_enr, local_private_key, remote_node_id): - """PING triggers a PONG response.""" - service, fake = _make_service(local_enr, local_private_key) - - ping = Ping(request_id=RequestId(data=b"\x01\x02"), enr_seq=SeqNumber(1)) - addr = ("192.168.1.1", 30303) - - await service._handle_ping(remote_node_id, ping, addr) - - assert len(fake.sent_responses) == 1 - sent = fake.sent_responses[0] - assert isinstance(sent.message, Pong) - assert bytes(sent.message.request_id) == b"\x01\x02" - - @pytest.mark.anyio - async def test_handle_ping_establishes_bond(self, local_enr, local_private_key, remote_node_id): - """Successful PONG response establishes bond.""" - service, fake = _make_service(local_enr, local_private_key) - fake.send_response_return = True - - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) - addr = ("192.168.1.1", 30303) - - await service._handle_ping(remote_node_id, ping, addr) - - assert service._bond_cache.is_bonded(remote_node_id) - - @pytest.mark.anyio - async def test_handle_ping_no_bond_when_send_fails( - self, local_enr, local_private_key, remote_node_id - ): - """No bond established when PONG send fails.""" - service, fake = _make_service(local_enr, local_private_key) - fake.send_response_return = False - - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) - addr = ("192.168.1.1", 30303) - - await service._handle_ping(remote_node_id, ping, addr) - - assert not service._bond_cache.is_bonded(remote_node_id) - - @pytest.mark.anyio - async def test_handle_ping_pong_includes_recipient_endpoint( - self, local_enr, local_private_key, remote_node_id - ): - """PONG includes the sender's observed IP and port.""" - service, fake = _make_service(local_enr, local_private_key) - - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) - addr = ("10.0.0.5", 9001) - - await service._handle_ping(remote_node_id, ping, addr) - - sent_pong = fake.sent_responses[0].message - assert isinstance(sent_pong, Pong) - assert int(sent_pong.recipient_port) == 9001 - - -class TestHandleFindNode: - """Tests for _handle_findnode message handler.""" - - @pytest.mark.anyio - async def test_findnode_from_unbonded_node_ignored( - self, local_enr, local_private_key, remote_node_id - ): - """FINDNODE from unbonded node is silently ignored.""" - service, fake = _make_service(local_enr, local_private_key) - - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(1)], - ) - addr = ("192.168.1.1", 30303) - - await service._handle_findnode(remote_node_id, findnode, addr) - - assert fake.sent_responses == [] - - @pytest.mark.anyio - async def test_findnode_from_bonded_node_sends_nodes( - self, local_enr, local_private_key, remote_node_id - ): - """FINDNODE from bonded node sends NODES response.""" - service, fake = _make_service(local_enr, local_private_key) - - # Establish bond first. - service._bond_cache.add_bond(remote_node_id) - - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(128)], - ) - addr = ("192.168.1.1", 30303) - - await service._handle_findnode(remote_node_id, findnode, addr) - - assert len(fake.sent_responses) == 1 - sent_msg = fake.sent_responses[0].message - assert isinstance(sent_msg, Nodes) - assert bytes(sent_msg.request_id) == b"\x01" - - @pytest.mark.anyio - async def test_findnode_distance_zero_returns_local_enr( - self, local_enr, local_private_key, remote_node_id - ): - """FINDNODE with distance=0 returns our own ENR.""" - service, fake = _make_service(local_enr, local_private_key) - - service._bond_cache.add_bond(remote_node_id) - - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(0)], - ) - addr = ("192.168.1.1", 30303) - - await service._handle_findnode(remote_node_id, findnode, addr) - - sent_msg = fake.sent_responses[0].message - assert isinstance(sent_msg, Nodes) - # Distance 0 means our own ENR, so there should be at least 1 ENR. - assert len(sent_msg.enrs) >= 1 - - @pytest.mark.anyio - async def test_findnode_returns_nodes_from_bucket( - self, local_enr, local_private_key, remote_node_id - ): - """FINDNODE returns nodes from routing table buckets.""" - service, fake = _make_service(local_enr, local_private_key) - - service._bond_cache.add_bond(remote_node_id) - - entry = NodeEntry(node_id=NodeId(bytes([1]) + bytes(31)), enr_seq=SeqNumber(1)) - service._routing_table.add(entry) - - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(255)], - ) - addr = ("192.168.1.1", 30303) - - await service._handle_findnode(remote_node_id, findnode, addr) - - assert len(fake.sent_responses) == 1 - sent_msg = fake.sent_responses[0].message - assert isinstance(sent_msg, Nodes) - - @pytest.mark.anyio - async def test_findnode_response_capped_at_max_nodes( - self, local_enr, local_private_key, remote_node_id - ): - """FINDNODE response is capped at max_nodes_response.""" - service, fake = _make_service( - local_enr, local_private_key, config=DiscoveryConfig(max_nodes_response=3) - ) - - service._bond_cache.add_bond(remote_node_id) - - for i in range(10): - entry = NodeEntry( - node_id=NodeId(bytes([i]) + bytes(31)), - enr_seq=SeqNumber(1), - ) - service._routing_table.add(entry) - - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(255)], - ) - addr = ("192.168.1.1", 30303) - - await service._handle_findnode(remote_node_id, findnode, addr) - - assert len(fake.sent_responses) == 1 - sent_msg = fake.sent_responses[0].message - assert isinstance(sent_msg, Nodes) - assert len(sent_msg.enrs) <= 3 - - @pytest.mark.anyio - async def test_findnode_returns_enrs_with_entries( - self, local_enr, local_private_key, remote_node_id - ): - """FINDNODE returns ENRs from routing table when entries have ENRs.""" - service, fake = _make_service(local_enr, local_private_key) - - service._bond_cache.add_bond(remote_node_id) - - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - EnrKey("ip"): bytes([10, 0, 0, 1]), - EnrKey("udp"): (9000).to_bytes(2, "big"), - }, - ) - local_int = int.from_bytes(service._local_node_id, "big") - target_int = local_int ^ 1 - entry = NodeEntry( - node_id=NodeId(target_int.to_bytes(32, "big")), - enr_seq=SeqNumber(1), - enr=enr, - ) - service._routing_table.add(entry) - - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(1)], - ) - addr = ("192.168.1.1", 30303) - - await service._handle_findnode(remote_node_id, findnode, addr) - - assert len(fake.sent_responses) == 1 - sent_msg = fake.sent_responses[0].message - assert isinstance(sent_msg, Nodes) - assert len(sent_msg.enrs) >= 1 - - @pytest.mark.anyio - async def test_process_message_findnode_routes_to_handler( - self, local_enr, local_private_key, remote_node_id - ): - """FindNode messages are dispatched to _handle_findnode.""" - service, _ = _make_service(local_enr, local_private_key) - - service._bond_cache.add_bond(remote_node_id) - - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(128)], - ) - addr = ("192.168.1.1", 30303) - - await service._process_message(remote_node_id, findnode, addr) - - @pytest.mark.anyio - async def test_process_message_talkreq_routes_to_handler( - self, local_enr, local_private_key, remote_node_id - ): - """TalkReq messages are dispatched to _handle_talkreq.""" - service, _ = _make_service(local_enr, local_private_key) - - talkreq = TalkReq( - request_id=RequestId(data=b"\x01"), - protocol=b"eth2", - request=b"data", - ) - addr = ("192.168.1.1", 30303) - - await service._process_message(remote_node_id, talkreq, addr) - - -class TestHandleTalkReq: - """Tests for _handle_talkreq message handler.""" - - @pytest.mark.anyio - async def test_talkreq_unknown_protocol_sends_empty_response( - self, local_enr, local_private_key, remote_node_id - ): - """TALKREQ for unknown protocol sends empty TALKRESP.""" - service, fake = _make_service(local_enr, local_private_key) - - talkreq = TalkReq( - request_id=RequestId(data=b"\x01"), - protocol=b"unknown", - request=b"data", - ) - addr = ("192.168.1.1", 30303) - - await service._handle_talkreq(remote_node_id, talkreq, addr) - - assert len(fake.sent_responses) == 1 - sent_msg = fake.sent_responses[0].message - assert isinstance(sent_msg, TalkResp) - assert sent_msg.response == b"" - - @pytest.mark.anyio - async def test_talkreq_dispatches_to_registered_handler( - self, local_enr, local_private_key, remote_node_id - ): - """TALKREQ dispatches to the registered protocol handler.""" - service, fake = _make_service(local_enr, local_private_key) - - calls: list[tuple[bytes, bytes]] = [] - - def handler(nid: bytes, data: bytes) -> bytes: - calls.append((nid, data)) - return b"handler-response" - - service.register_talk_handler(b"eth2", handler) - - talkreq = TalkReq( - request_id=RequestId(data=b"\x01"), - protocol=b"eth2", - request=b"request-data", - ) - addr = ("192.168.1.1", 30303) - - await service._handle_talkreq(remote_node_id, talkreq, addr) - - assert calls == [(remote_node_id, b"request-data")] - sent_msg = fake.sent_responses[0].message - assert isinstance(sent_msg, TalkResp) - assert sent_msg.response == b"handler-response" - - @pytest.mark.anyio - async def test_talkreq_handler_exception_sends_empty_response( - self, local_enr, local_private_key, remote_node_id - ): - """TALKREQ handler that raises sends empty response.""" - service, fake = _make_service(local_enr, local_private_key) - - def handler(nid: bytes, data: bytes) -> bytes: - raise RuntimeError("handler error") - - service.register_talk_handler(b"eth2", handler) - - talkreq = TalkReq( - request_id=RequestId(data=b"\x01"), - protocol=b"eth2", - request=b"request-data", - ) - addr = ("192.168.1.1", 30303) - - await service._handle_talkreq(remote_node_id, talkreq, addr) - - sent_msg = fake.sent_responses[0].message - assert isinstance(sent_msg, TalkResp) - assert sent_msg.response == b"" - - -class TestSendTalkRequest: - """Tests for send_talk_request method.""" - - @pytest.mark.anyio - async def test_send_talk_request_returns_none_for_unknown_node( - self, local_enr, local_private_key - ): - """send_talk_request returns None when node address is unknown.""" - service, _ = _make_service(local_enr, local_private_key) - - unknown_id = NodeId(bytes(32)) - result = await service.send_talk_request(unknown_id, b"eth2", b"request") - - assert result is None - - @pytest.mark.anyio - async def test_send_talk_request_delegates_to_transport( - self, local_enr, local_private_key, remote_node_id - ): - """send_talk_request delegates to transport when address is known.""" - service, fake = _make_service(local_enr, local_private_key) - fake.register_node_address(remote_node_id, ("192.168.1.1", 30303)) - fake.send_talkreq_return = b"response" - - result = await service.send_talk_request(remote_node_id, b"eth2", b"request") - - assert result == b"response" - assert len(fake.sent_talkreqs) == 1 - assert fake.sent_talkreqs[0] == SentTalkReq( - remote_node_id, ("192.168.1.1", 30303), b"eth2", b"request" - ) - - @pytest.mark.anyio - async def test_send_talk_request_timeout(self, local_enr, local_private_key, remote_node_id): - """send_talk_request returns None on timeout.""" - service, fake = _make_service(local_enr, local_private_key) - fake.register_node_address(remote_node_id, ("192.168.1.1", 30303)) - fake.send_talkreq_return = None - - result = await service.send_talk_request(remote_node_id, b"eth2", b"request") - - assert result is None - - -class TestBootstrapFlow: - """Tests for _bootstrap method.""" - - @pytest.mark.anyio - async def test_bootstrap_registers_bootnode_addresses(self, local_enr, local_private_key): - """Bootstrap registers bootnode addresses and ENRs.""" - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - bootnode = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - EnrKey("ip"): bytes([192, 168, 1, 1]), - EnrKey("udp"): (30303).to_bytes(2, "big"), - }, - ) - - service, _ = _make_service(local_enr, local_private_key, bootnodes=[bootnode]) - - await service._bootstrap() - - assert service.node_count() >= 1 - - @pytest.mark.anyio - async def test_bootstrap_skips_bootnodes_without_ip(self, local_enr, local_private_key): - """Bootstrap skips bootnodes that lack IP/port.""" - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - bootnode = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - }, - ) - - service, fake = _make_service(local_enr, local_private_key, bootnodes=[bootnode]) - - await service._bootstrap() - - assert fake.sent_pings == [] - - @pytest.mark.anyio - async def test_bootstrap_handles_exception(self, local_enr, local_private_key): - """_bootstrap handles exceptions gracefully.""" - bootnode = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - }, - ) - - service, _ = _make_service(local_enr, local_private_key, bootnodes=[bootnode]) - - await service._bootstrap() - - @pytest.mark.anyio - async def test_bootstrap_handles_enr_to_entry_exception(self, local_enr, local_private_key): - """_bootstrap handles exceptions from _enr_to_entry.""" - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - bootnode = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - EnrKey("ip"): bytes([192, 168, 1, 1]), - EnrKey("udp"): (30303).to_bytes(2, "big"), - }, - ) - - service, _ = _make_service(local_enr, local_private_key, bootnodes=[bootnode]) - - with patch.object(service, "_enr_to_entry", side_effect=RuntimeError("test error")): - await service._bootstrap() - - -class TestProcessDiscoveredEnr: - """Tests for _process_discovered_enr method.""" - - def test_invalid_enr_bytes_are_skipped(self, local_enr, local_private_key): - """Invalid RLP bytes are silently skipped.""" - service, _ = _make_service(local_enr, local_private_key) - - seen: dict[NodeId, NodeEntry] = {} - # Should not raise. - service._process_discovered_enr(b"\xff\xff\xff", seen) - - def test_enr_with_wrong_distance_is_dropped(self, local_enr, local_private_key): - """ENR that doesn't match requested distances is dropped.""" - service, _ = _make_service(local_enr, local_private_key) - - # Create a valid ENR. - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): bytes.fromhex( - "02a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" - ), - EnrKey("ip"): bytes([10, 0, 0, 1]), - EnrKey("udp"): (9000).to_bytes(2, "big"), - }, - ) - enr_bytes = enr.to_rlp() - - queried_id = NodeId(bytes(32)) - seen: dict[NodeId, NodeEntry] = {} - - # Request only distance 1 — the actual distance is unlikely to be 1. - service._process_discovered_enr(enr_bytes, seen, queried_id, [1]) - - # ENR should not be added since distance doesn't match. - assert len(seen) == 0 - - def test_valid_enr_added_to_seen_and_routing_table(self, local_enr, local_private_key): - """Valid ENR is added to seen dict and routing table.""" - service, _ = _make_service(local_enr, local_private_key) - - # Create a valid ENR with distance from queried node. - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - EnrKey("ip"): bytes([10, 0, 0, 1]), - EnrKey("udp"): (9000).to_bytes(2, "big"), - }, - ) - enr_bytes = enr.to_rlp() - - # Use local node as queried node so distance will be at a valid range. - seen: dict[NodeId, NodeEntry] = {} - service._process_discovered_enr(enr_bytes, seen) - - # ENR should be added to seen. - assert len(seen) == 1 - node_id = next(iter(seen.keys())) - assert seen[node_id].enr is not None - - # Should also be in routing table. - assert service.node_count() == 1 - - def test_enr_without_node_id_is_skipped(self, local_enr, local_private_key): - """ENR without valid node ID is skipped.""" - service, _ = _make_service(local_enr, local_private_key) - - # ENR with no secp256k1 key (no node ID). - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - }, - ) - enr_bytes = enr.to_rlp() - - seen: dict[NodeId, NodeEntry] = {} - service._process_discovered_enr(enr_bytes, seen) - - assert len(seen) == 0 - - def test_own_enr_is_skipped(self, local_enr, local_private_key): - """Processing our own ENR is skipped.""" - service, _ = _make_service(local_enr, local_private_key) - - seen: dict[NodeId, NodeEntry] = {} - service._process_discovered_enr(local_enr.to_rlp(), seen) - - # Should not add our own ENR. - assert len(seen) == 0 - - def test_already_seen_enr_is_skipped(self, local_enr, local_private_key): - """Duplicate ENR within same lookup is skipped.""" - service, _ = _make_service(local_enr, local_private_key) - - # Create a valid ENR. - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - EnrKey("ip"): bytes([10, 0, 0, 1]), - EnrKey("udp"): (9000).to_bytes(2, "big"), - }, - ) - enr_bytes = enr.to_rlp() - - seen: dict[NodeId, NodeEntry] = {} - - # Process same ENR twice. - service._process_discovered_enr(enr_bytes, seen) - service._process_discovered_enr(enr_bytes, seen) - - # Should only be added once. - assert len(seen) == 1 - - def test_process_discovered_enr_catches_generic_exception(self, local_enr, local_private_key): - """Generic exceptions in ENR processing are caught and logged.""" - service, _ = _make_service(local_enr, local_private_key) - - seen: dict[NodeId, NodeEntry] = {} - - # Pass malformed data that might cause unexpected errors. - with patch("lean_spec.subspecs.networking.enr.ENR.from_rlp") as mock_from_rlp: - mock_from_rlp.side_effect = RuntimeError("unexpected error") - # Should not raise. - service._process_discovered_enr(b"\x00", seen) - - -class TestQueryNode: - """Tests for _query_node method.""" - - @pytest.mark.anyio - async def test_query_node_with_positive_distance( - self, local_enr, local_private_key, remote_node_id - ): - """_query_node sends FINDNODE with correct distance.""" - service, fake = _make_service(local_enr, local_private_key) - - addr = ("192.168.1.1", 30303) - target = NodeId(bytes(32)) - - await service._query_node(remote_node_id, addr, target) - - assert len(fake.sent_findnodes) == 1 - sent = fake.sent_findnodes[0] - assert sent.node_id == remote_node_id - assert sent.addr == addr - # Should have at least one distance. - assert isinstance(sent.distances, list) - assert len(sent.distances) >= 1 - - @pytest.mark.anyio - async def test_query_node_returns_tuple(self, local_enr, local_private_key, remote_node_id): - """_query_node returns (enr_list, node_id, distances) tuple.""" - service, fake = _make_service(local_enr, local_private_key) - - addr = ("192.168.1.1", 30303) - target = NodeId(bytes(32)) - - enrs = [b"\x00", b"\x01"] - fake.send_findnode_return = enrs - result = await service._query_node(remote_node_id, addr, target) - - assert isinstance(result, tuple) - assert len(result) == 3 - enr_list, returned_id, distances = result - assert enr_list == enrs - assert returned_id == remote_node_id - - -class TestPingNode: - """Tests for _ping_node method.""" - - @pytest.mark.anyio - async def test_ping_node_success_returns_true( - self, local_enr, local_private_key, remote_node_id - ): - """Successful ping returns True and adds bond.""" - service, fake = _make_service(local_enr, local_private_key) - - addr = ("192.168.1.1", 30303) - - fake.send_ping_return = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(30303), - ) - result = await service._ping_node(remote_node_id, addr) - - assert result is True - assert service._bond_cache.is_bonded(remote_node_id) - - @pytest.mark.anyio - async def test_ping_node_no_response_returns_false( - self, local_enr, local_private_key, remote_node_id - ): - """Failed ping returns False and no bond added.""" - service, fake = _make_service(local_enr, local_private_key) - - addr = ("192.168.1.1", 30303) - - fake.send_ping_return = None - result = await service._ping_node(remote_node_id, addr) - - assert result is False - assert not service._bond_cache.is_bonded(remote_node_id) - - -class TestProcessMessage: - """Tests for _process_message method.""" - - @pytest.mark.anyio - async def test_process_message_ping_routes_to_handler( - self, local_enr, local_private_key, remote_node_id - ): - """Ping messages are dispatched to _handle_ping.""" - service, _ = _make_service(local_enr, local_private_key) - - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) - addr = ("192.168.1.1", 30303) - - await service._process_message(remote_node_id, ping, addr) - - assert service._bond_cache.is_bonded(remote_node_id) - - @pytest.mark.anyio - async def test_process_message_updates_node_address( - self, local_enr, local_private_key, remote_node_id - ): - """_process_message updates node address in transport.""" - service, fake = _make_service(local_enr, local_private_key) - - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) - addr = ("192.168.1.1", 30303) - - await service._process_message(remote_node_id, ping, addr) - - registered_addr = fake.get_node_address(remote_node_id) - assert registered_addr == addr - - @pytest.mark.anyio - async def test_process_message_findnode_routes_to_handler( - self, local_enr, local_private_key, remote_node_id - ): - """FindNode messages are dispatched to _handle_findnode.""" - service, _ = _make_service(local_enr, local_private_key) - - service._bond_cache.add_bond(remote_node_id) - - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(128)], - ) - addr = ("192.168.1.1", 30303) - - await service._process_message(remote_node_id, findnode, addr) - - @pytest.mark.anyio - async def test_process_message_talkreq_routes_to_handler( - self, local_enr, local_private_key, remote_node_id - ): - """TalkReq messages are dispatched to _handle_talkreq.""" - service, _ = _make_service(local_enr, local_private_key) - - talkreq = TalkReq( - request_id=RequestId(data=b"\x01"), - protocol=b"eth2", - request=b"data", - ) - addr = ("192.168.1.1", 30303) - - await service._process_message(remote_node_id, talkreq, addr) - - -class TestHandleMessage: - """Tests for _handle_message method.""" - - @pytest.mark.anyio - async def test_handle_message_creates_task(self, local_enr, local_private_key): - """_handle_message creates async task for processing.""" - service, _ = _make_service(local_enr, local_private_key) - - ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) - addr = ("192.168.1.1", 30303) - remote_id = NodeId(bytes(32)) - - # Just verify it doesn't raise - task creation is async. - service._handle_message(remote_id, ping, addr) - - -class TestFindNodeLookup: - """Tests for find_node iterative Kademlia lookup.""" - - @pytest.mark.anyio - async def test_find_node_with_responses(self, local_enr, local_private_key, remote_node_id): - """find_node queries candidates and processes responses.""" - service, fake = _make_service(local_enr, local_private_key) - - # Add a node to routing table. - entry = NodeEntry(node_id=remote_node_id, enr_seq=SeqNumber(1)) - service._routing_table.add(entry) - fake.register_node_address(remote_node_id, ("192.168.1.1", 30303)) - - target = NodeId(bytes(32)) - - # Configure fake to return ENRs. - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - discovered_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - EnrKey("ip"): bytes([10, 0, 0, 2]), - EnrKey("udp"): (9001).to_bytes(2, "big"), - }, - ) - - fake.send_findnode_return = [discovered_enr.to_rlp()] - result = await service.find_node(target) - - assert result.queried >= 1 - assert result.target == target - assert len(result.nodes) >= 1 - - @pytest.mark.anyio - async def test_find_node_iterative_deepening(self, local_enr, local_private_key): - """find_node iteratively queries closer nodes.""" - service, fake = _make_service(local_enr, local_private_key) - - # Create multiple nodes at varying distances. - for i in range(5): - node_id = NodeId(bytes([i]) + bytes(31)) - entry = NodeEntry(node_id=node_id, enr_seq=SeqNumber(1)) - service._routing_table.add(entry) - fake.register_node_address(node_id, (f"192.168.1.{i + 1}", 30303)) - - target = NodeId(bytes(32)) - - # Return new nodes in each response. - def mock_findnode(node_id, addr, distances): - new_pubkey = bytes.fromhex( - "02a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" - ) - new_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): new_pubkey, - EnrKey("ip"): bytes([10, 0, 0, 50]), - EnrKey("udp"): (9050).to_bytes(2, "big"), - }, - ) - return [new_enr.to_rlp()] - - fake.send_findnode_side_effect = mock_findnode - result = await service.find_node(target) - - # Should have queried nodes. - assert result.queried > 0 - assert result.target == target - - @pytest.mark.anyio - async def test_find_node_handles_exceptions_in_query(self, local_enr, local_private_key): - """find_node handles exceptions from queries gracefully.""" - service, fake = _make_service(local_enr, local_private_key) - - # Add a node so the lookup has candidates. - node_id = NodeId(bytes([1]) + bytes(31)) - entry = NodeEntry(node_id=node_id, enr_seq=SeqNumber(1)) - service._routing_table.add(entry) - fake.register_node_address(node_id, ("192.168.1.1", 30303)) - - target = NodeId(bytes(32)) - - # Transport raises on send_findnode -- _query_node propagates the exception, - # and find_node's gather(return_exceptions=True) catches it. - def raise_error(*args, **kwargs): - raise RuntimeError("network error") - - fake.send_findnode_side_effect = raise_error - result = await service.find_node(target) - - assert result.target == target - assert isinstance(result.nodes, list) - - -class TestEnrToEntry: - """Tests for _enr_to_entry method.""" - - def test_enr_to_entry_with_endpoint(self, local_enr, local_private_key): - """_enr_to_entry creates entry with endpoint.""" - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - EnrKey("ip"): bytes([192, 168, 1, 1]), - EnrKey("udp"): (30303).to_bytes(2, "big"), - }, - ) - - service, _ = _make_service(local_enr, local_private_key) - - entry = service._enr_to_entry(enr) - - assert entry.enr_seq == SeqNumber(1) - assert entry.enr is enr - assert entry.endpoint == "192.168.1.1:30303" - - def test_enr_to_entry_without_ip(self, local_enr, local_private_key): - """_enr_to_entry handles ENR without IP.""" - node_a_pubkey = bytes.fromhex( - "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" - ) - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): node_a_pubkey, - }, - ) - - service, _ = _make_service(local_enr, local_private_key) - - entry = service._enr_to_entry(enr) - - assert entry.enr_seq == SeqNumber(1) - assert entry.endpoint is None - - def test_enr_to_entry_raises_without_node_id(self, local_enr, local_private_key): - """_enr_to_entry raises when ENR has no node ID.""" - enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - }, - ) - - service, _ = _make_service(local_enr, local_private_key) - - with pytest.raises(ValueError, match="no valid node ID"): - service._enr_to_entry(enr) - - -class TestBackgroundLoops: - """Tests for background maintenance loops.""" - - @pytest.mark.anyio - async def test_refresh_loop_performs_lookup(self, local_enr, local_private_key): - """_refresh_loop performs periodic lookups.""" - service, _ = _make_service(local_enr, local_private_key) - - # Start service. - await service.start("127.0.0.1", 9000) - - try: - # Wait a short time for the refresh loop to potentially run. - # Since interval is 1 hour, it won't run naturally. - # Instead, verify the loop exists and can handle errors. - assert service._running - - # Let the event loop process. - await asyncio.sleep(0.01) - finally: - await service.stop() - - @pytest.mark.anyio - async def test_revalidation_loop_handles_empty_table(self, local_enr, local_private_key): - """_revalidation_loop handles empty routing table.""" - service, _ = _make_service(local_enr, local_private_key) - - # Start service. - await service.start("127.0.0.1", 9000) - - try: - # Verify empty table doesn't cause issues. - assert service.node_count() == 0 - await asyncio.sleep(0.01) - finally: - await service.stop() - - @pytest.mark.anyio - async def test_cleanup_loop_calls_bond_cache(self, local_enr, local_private_key): - """_cleanup_loop calls cleanup_expired on bond cache.""" - service, _ = _make_service(local_enr, local_private_key) - - # Start service. - await service.start("127.0.0.1", 9000) - - try: - # Verify bond cache exists. - assert service._bond_cache is not None - await asyncio.sleep(0.01) - finally: - await service.stop() - - @pytest.mark.anyio - async def test_background_loops_handle_exceptions(self, local_enr, local_private_key): - """Background loops catch and log exceptions.""" - service, _ = _make_service(local_enr, local_private_key) - - # Mock find_node to raise. - with patch.object( - service, "find_node", new=AsyncMock(side_effect=RuntimeError("test error")) - ): - await service.start("127.0.0.1", 9000) - - # Service should still be running. - assert service._running - - # Let loop attempt. - await asyncio.sleep(0.01) - - # Service should still be running (exception caught). - assert service._running - - await service.stop() diff --git a/tests/lean_spec/subspecs/networking/discovery/test_session.py b/tests/lean_spec/subspecs/networking/discovery/test_session.py deleted file mode 100644 index b000a0f5..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_session.py +++ /dev/null @@ -1,261 +0,0 @@ -"""Tests for Discovery v5 session management.""" - -import time - -from lean_spec.subspecs.networking.discovery.session import ( - BondCache, - Session, - SessionCache, -) -from lean_spec.subspecs.networking.types import NodeId, Port -from lean_spec.types import Bytes16 - -ZERO_KEY = Bytes16(bytes(16)) - - -class TestSession: - """Tests for Session dataclass.""" - - def test_create_session(self): - """Test session creation.""" - session = Session( - node_id=NodeId(bytes(32)), - send_key=ZERO_KEY, - recv_key=ZERO_KEY, - created_at=time.time(), - last_seen=time.time(), - is_initiator=True, - ) - - assert len(session.node_id) == 32 - assert len(session.send_key) == 16 - assert len(session.recv_key) == 16 - - def test_is_expired_false_for_new_session(self): - """Test that new session is not expired.""" - session = Session( - node_id=NodeId(bytes(32)), - send_key=ZERO_KEY, - recv_key=ZERO_KEY, - created_at=time.time(), - last_seen=time.time(), - is_initiator=True, - ) - - assert not session.is_expired(timeout_secs=3600) - - def test_is_expired_true_for_old_session(self): - """Test that old session is expired.""" - session = Session( - node_id=NodeId(bytes(32)), - send_key=ZERO_KEY, - recv_key=ZERO_KEY, - created_at=time.time() - 7200, # 2 hours ago - last_seen=time.time() - 7200, - is_initiator=True, - ) - - assert session.is_expired(timeout_secs=3600) - - def test_touch_updates_last_seen(self): - """Test that touch updates last_seen timestamp.""" - session = Session( - node_id=NodeId(bytes(32)), - send_key=ZERO_KEY, - recv_key=ZERO_KEY, - created_at=time.time() - 100, - last_seen=time.time() - 100, - is_initiator=True, - ) - - old_last_seen = session.last_seen - session.touch() - - assert session.last_seen > old_last_seen - - -class TestSessionCache: - """Tests for SessionCache.""" - - def test_create_and_get_session(self): - """Test creating and retrieving a session.""" - cache = SessionCache() - node_id = NodeId(bytes.fromhex("aa" * 32)) - send_key = ZERO_KEY - recv_key = ZERO_KEY - - session = cache.create(node_id, send_key, recv_key, is_initiator=True) - - retrieved = cache.get(node_id) - assert retrieved is session - - def test_get_nonexistent_returns_none(self): - """Test that getting nonexistent session returns None.""" - cache = SessionCache() - node_id = NodeId(bytes(32)) - - assert cache.get(node_id) is None - - def test_get_expired_returns_none(self): - """Test that getting expired session returns None and removes it.""" - cache = SessionCache(timeout_secs=0.001) - node_id = NodeId(bytes(32)) - - cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) - time.sleep(0.01) - - assert cache.get(node_id) is None - assert cache.count() == 0 - - def test_remove_session(self): - """Test removing a session.""" - cache = SessionCache() - node_id = NodeId(bytes(32)) - - cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) - assert cache.remove(node_id) - assert cache.get(node_id) is None - - def test_remove_nonexistent_returns_false(self): - """Test that removing nonexistent session returns False.""" - cache = SessionCache() - assert not cache.remove(NodeId(bytes(32))) - - def test_touch_updates_session(self): - """Test that touch updates session timestamp.""" - cache = SessionCache() - node_id = NodeId(bytes(32)) - - cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) - assert cache.touch(node_id) - - def test_touch_nonexistent_returns_false(self): - """Test that touching nonexistent session returns False.""" - cache = SessionCache() - assert not cache.touch(NodeId(bytes(32))) - - def test_count(self): - """Test session count.""" - cache = SessionCache() - - assert cache.count() == 0 - - cache.create(NodeId(bytes.fromhex("aa" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) - assert cache.count() == 1 - - cache.create(NodeId(bytes.fromhex("bb" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) - assert cache.count() == 2 - - def test_cleanup_expired(self): - """Test expired session cleanup.""" - cache = SessionCache(timeout_secs=0.001) - - cache.create(NodeId(bytes.fromhex("aa" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) - cache.create(NodeId(bytes.fromhex("bb" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) - time.sleep(0.01) - - removed = cache.cleanup_expired() - assert removed == 2 - assert cache.count() == 0 - - def test_eviction_when_full(self): - """Test that oldest session is evicted when cache is full.""" - cache = SessionCache(max_sessions=2) - - node1 = NodeId(bytes.fromhex("01" + "00" * 31)) - node2 = NodeId(bytes.fromhex("02" + "00" * 31)) - node3 = NodeId(bytes.fromhex("03" + "00" * 31)) - - cache.create(node1, ZERO_KEY, ZERO_KEY, is_initiator=True) - time.sleep(0.01) # Ensure different timestamps - cache.create(node2, ZERO_KEY, ZERO_KEY, is_initiator=True) - - assert cache.count() == 2 - - # Adding third should evict first - cache.create(node3, ZERO_KEY, ZERO_KEY, is_initiator=True) - - assert cache.count() == 2 - assert cache.get(node1) is None # Oldest should be evicted - assert cache.get(node2) is not None - assert cache.get(node3) is not None - - def test_endpoint_keying_separates_sessions(self): - """Same node_id at different ip:port has separate sessions. - - Per spec, sessions are tied to a specific UDP endpoint. - This prevents session confusion if a node changes IP or port. - """ - cache = SessionCache() - node_id = NodeId(bytes.fromhex("aa" * 32)) - send_key_1 = Bytes16(bytes([0x01] * 16)) - send_key_2 = Bytes16(bytes([0x02] * 16)) - - # Create sessions for same node at different endpoints. - cache.create( - node_id, send_key_1, ZERO_KEY, is_initiator=True, ip="10.0.0.1", port=Port(9000) - ) - cache.create( - node_id, send_key_2, ZERO_KEY, is_initiator=True, ip="10.0.0.2", port=Port(9000) - ) - - # Each endpoint retrieves its own session. - session_1 = cache.get(node_id, "10.0.0.1", Port(9000)) - session_2 = cache.get(node_id, "10.0.0.2", Port(9000)) - - assert session_1 is not None - assert session_2 is not None - assert session_1.send_key == send_key_1 - assert session_2.send_key == send_key_2 - - # Different port for same IP is also separate. - assert cache.get(node_id, "10.0.0.1", Port(9001)) is None - - -class TestBondCache: - """Tests for BondCache.""" - - def test_add_and_check_bond(self): - """Test adding and checking bond.""" - cache = BondCache() - node_id = NodeId(bytes(32)) - - assert not cache.is_bonded(node_id) - - cache.add_bond(node_id) - assert cache.is_bonded(node_id) - - def test_expired_bond(self): - """Test that expired bond returns False.""" - cache = BondCache(expiry_secs=0.001) - node_id = NodeId(bytes(32)) - - cache.add_bond(node_id) - time.sleep(0.01) - - assert not cache.is_bonded(node_id) - - def test_remove_bond(self): - """Test removing a bond.""" - cache = BondCache() - node_id = NodeId(bytes(32)) - - cache.add_bond(node_id) - assert cache.remove_bond(node_id) - assert not cache.is_bonded(node_id) - - def test_remove_nonexistent_returns_false(self): - """Test that removing nonexistent bond returns False.""" - cache = BondCache() - assert not cache.remove_bond(NodeId(bytes(32))) - - def test_cleanup_expired(self): - """Test expired bond cleanup.""" - cache = BondCache(expiry_secs=0.001) - - cache.add_bond(NodeId(bytes.fromhex("aa" * 32))) - cache.add_bond(NodeId(bytes.fromhex("bb" * 32))) - time.sleep(0.01) - - removed = cache.cleanup_expired() - assert removed == 2 diff --git a/tests/lean_spec/subspecs/networking/discovery/test_transport.py b/tests/lean_spec/subspecs/networking/discovery/test_transport.py deleted file mode 100644 index ee1d5fe0..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_transport.py +++ /dev/null @@ -1,2149 +0,0 @@ -""" -Tests for Discovery v5 UDP transport layer. - -Tests the DiscoveryTransport and DiscoveryProtocol classes. -""" - -from __future__ import annotations - -import asyncio -from collections.abc import AsyncIterator -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from cryptography.exceptions import InvalidTag - -from lean_spec.subspecs.networking.discovery.config import DiscoveryConfig -from lean_spec.subspecs.networking.discovery.handshake import HandshakeError, HandshakeResult -from lean_spec.subspecs.networking.discovery.messages import ( - Distance, - FindNode, - IPv4, - Nodes, - Nonce, - PacketFlag, - Ping, - Pong, - Port, - RequestId, - TalkResp, -) -from lean_spec.subspecs.networking.discovery.packet import ( - HandshakeAuthdata, - PacketHeader, - decode_packet_header, - encode_message_authdata, - encode_packet, -) -from lean_spec.subspecs.networking.discovery.session import Session -from lean_spec.subspecs.networking.discovery.transport import ( - DiscoveryProtocol, - DiscoveryTransport, - PendingMultiRequest, - PendingRequest, -) -from lean_spec.subspecs.networking.enr import ENR -from lean_spec.subspecs.networking.enr.keys import EnrKey -from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes16, Bytes32, Bytes33, Bytes64, Uint8 -from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY - - -@pytest.fixture -async def started_transport( - local_node_id: NodeId, - local_private_key: Bytes32, - local_enr: ENR, -) -> AsyncIterator[tuple[DiscoveryTransport, MagicMock]]: - """Start a DiscoveryTransport with a mocked UDP socket. - - Yields (transport, mock_udp) where mock_udp is the DatagramTransport mock. - """ - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - yield transport, mock_udp - await transport.stop() - - -class TestDiscoveryProtocol: - """Tests for DiscoveryProtocol async UDP handler.""" - - def test_connection_made_stores_transport(self): - """Protocol stores transport reference on connection.""" - mock_handler = MagicMock() - protocol = DiscoveryProtocol(mock_handler) - - mock_transport = MagicMock(spec=asyncio.DatagramTransport) - protocol.connection_made(mock_transport) - - assert protocol._transport is mock_transport - - @pytest.mark.anyio - async def test_datagram_received_dispatches_to_handler(self): - """Received datagrams are dispatched to the handler.""" - mock_handler = MagicMock() - mock_handler._handle_packet = AsyncMock() - - protocol = DiscoveryProtocol(mock_handler) - - data = b"test packet data" - addr = ("127.0.0.1", 9000) - - protocol.datagram_received(data, addr) - - # Give the task a chance to run. - await asyncio.sleep(0.01) - - mock_handler._handle_packet.assert_called_once_with(data, addr) - - def test_error_received_logs_warning(self): - """UDP errors are logged.""" - mock_handler = MagicMock() - protocol = DiscoveryProtocol(mock_handler) - - # Should not raise. - protocol.error_received(Exception("test error")) - - def test_connection_lost_handles_none_exc(self): - """Connection lost with no exception is handled.""" - mock_handler = MagicMock() - protocol = DiscoveryProtocol(mock_handler) - - # Should not raise. - protocol.connection_lost(None) - - def test_connection_lost_handles_exception(self): - """Connection lost with exception is handled.""" - mock_handler = MagicMock() - protocol = DiscoveryProtocol(mock_handler) - - # Should not raise. - protocol.connection_lost(Exception("connection error")) - - -class TestDiscoveryTransport: - """Tests for DiscoveryTransport.""" - - def test_init_creates_required_components(self, local_node_id, local_private_key, local_enr): - """Transport initializes all required components.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - assert transport._local_node_id == local_node_id - assert transport._local_private_key == local_private_key - assert transport._local_enr == local_enr - assert transport._session_cache is not None - assert transport._handshake_manager is not None - assert not transport._running - - def test_init_with_custom_config(self, local_node_id, local_private_key, local_enr): - """Transport accepts custom configuration.""" - config = DiscoveryConfig(request_timeout_secs=30.0) - - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - assert transport._config.request_timeout_secs == 30.0 - - def test_register_node_address( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """Node addresses can be registered and retrieved.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - addr = ("192.168.1.1", 30303) - transport.register_node_address(remote_node_id, addr) - - assert transport.get_node_address(remote_node_id) == addr - - def test_get_node_address_returns_none_for_unknown( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """Getting unknown node address returns None.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - assert transport.get_node_address(remote_node_id) is None - - def test_set_message_handler(self, local_node_id, local_private_key, local_enr): - """Message handler can be set.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - handler = MagicMock() - transport.set_message_handler(handler) - - assert transport._message_handler is handler - - def test_register_enr(self, local_node_id, local_private_key, local_enr, remote_node_id): - """ENRs can be registered and retrieved.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - remote_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={EnrKey("id"): b"v4"}, - ) - - transport.register_enr(remote_node_id, remote_enr) - - assert transport.get_enr(remote_node_id) is remote_enr - - @pytest.mark.anyio - async def test_start_creates_udp_endpoint(self, local_node_id, local_private_key, local_enr): - """Starting transport creates UDP endpoint.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - # Mock the event loop's create_datagram_endpoint. - mock_transport_obj = MagicMock(spec=asyncio.DatagramTransport) - mock_protocol_obj = MagicMock(spec=DiscoveryProtocol) - - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_transport_obj, mock_protocol_obj)), - ): - await transport.start("127.0.0.1", 9000) - - assert transport._running - - # Clean up. - await transport.stop() - - @pytest.mark.anyio - async def test_start_is_idempotent(self, local_node_id, local_private_key, local_enr): - """Starting an already-started transport does nothing.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - mock_transport_obj = MagicMock(spec=asyncio.DatagramTransport) - mock_protocol_obj = MagicMock(spec=DiscoveryProtocol) - - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_transport_obj, mock_protocol_obj)), - ) as mock_create: - await transport.start("127.0.0.1", 9000) - await transport.start("127.0.0.1", 9000) - - # Should only be called once. - assert mock_create.call_count == 1 - - await transport.stop() - - @pytest.mark.anyio - async def test_stop_closes_transport( - self, started_transport: tuple[DiscoveryTransport, MagicMock] - ): - """Stopping transport closes UDP socket.""" - transport, mock_udp = started_transport - - await transport.stop() - - assert not transport._running - mock_udp.close.assert_called_once() - - @pytest.mark.anyio - async def test_stop_cancels_pending_requests( - self, started_transport: tuple[DiscoveryTransport, MagicMock] - ): - """Stopping transport cancels all pending requests.""" - transport, _ = started_transport - - # Add a pending request. - loop = asyncio.get_running_loop() - future: asyncio.Future = loop.create_future() - request_id = RequestId(data=b"\x01\x02\x03\x04") - pending = PendingRequest( - request_id=request_id, - dest_node_id=NodeId(bytes(32)), - sent_at=loop.time(), - nonce=Nonce(bytes(12)), - message=Ping(request_id=request_id, enr_seq=SeqNumber(1)), - future=future, - ) - transport._pending_requests[pending.request_id] = pending - - await transport.stop() - - assert future.cancelled() - assert len(transport._pending_requests) == 0 - - @pytest.mark.anyio - async def test_stop_is_idempotent(self, local_node_id, local_private_key, local_enr): - """Stopping an already-stopped transport does nothing.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - # Stop without starting should not raise. - await transport.stop() - await transport.stop() - - -class TestSendResponse: - """Tests for sending response messages.""" - - @pytest.mark.anyio - async def test_send_response_without_session_returns_false( - self, - started_transport: tuple[DiscoveryTransport, MagicMock], - remote_node_id: NodeId, - ): - """Sending response without session fails gracefully.""" - transport, _ = started_transport - - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(9000), - ) - - result = await transport.send_response(remote_node_id, ("192.168.1.1", 30303), pong) - - assert result is False - - @pytest.mark.anyio - async def test_send_response_without_transport_returns_false( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """Sending response without starting transport fails.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(9000), - ) - - result = await transport.send_response(remote_node_id, ("192.168.1.1", 30303), pong) - - assert result is False - - -class TestMultiPacketNodesCollection: - """FINDNODE response collection with total > 1. - - When results exceed UDP MTU, NODES responses are split across - multiple packets. The `total` field indicates expected count. - """ - - def test_pending_multi_request_queue_usage(self): - """Response queue collects multiple messages.""" - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - async def test_queue(): - queue: asyncio.Queue = asyncio.Queue() - - request_id = RequestId(data=b"\x01\x02\x03\x04") - pending = PendingMultiRequest( - request_id=request_id, - dest_node_id=NodeId(bytes(32)), - sent_at=123.456, - nonce=Nonce(bytes(12)), - message=FindNode(request_id=request_id, distances=[Distance(256)]), - response_queue=queue, - expected_total=3, - received_count=0, - ) - - # Simulate receiving 3 messages. - ping1 = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) - ping2 = Ping(request_id=RequestId(data=b"\x02"), enr_seq=SeqNumber(2)) - ping3 = Ping(request_id=RequestId(data=b"\x03"), enr_seq=SeqNumber(3)) - await pending.response_queue.put(ping1) - await pending.response_queue.put(ping2) - await pending.response_queue.put(ping3) - - # Queue should have all messages. - assert pending.response_queue.qsize() == 3 - - # Retrieve messages. - msg1 = await pending.response_queue.get() - msg2 = await pending.response_queue.get() - msg3 = await pending.response_queue.get() - - assert msg1 is ping1 - assert msg2 is ping2 - assert msg3 is ping3 - - loop.run_until_complete(test_queue()) - loop.close() - - -class TestNodesResponseAccumulation: - """Tests for accumulating ENRs from multiple NODES responses.""" - - def test_empty_nodes_response_handling(self): - """NODES with total=0 indicates no results.""" - - nodes = Nodes( - request_id=RequestId(data=b"\x01"), - total=Uint8(0), - enrs=[], - ) - - assert int(nodes.total) == 0 - assert nodes.enrs == [] - - def test_single_nodes_response_collection(self): - """Single NODES response with total=1.""" - - nodes = Nodes( - request_id=RequestId(data=b"\x01"), - total=Uint8(1), - enrs=[b"enr1", b"enr2"], - ) - - assert int(nodes.total) == 1 - assert len(nodes.enrs) == 2 - - def test_multiple_nodes_responses_expected(self): - """Multiple NODES messages share same request_id.""" - - request_id = RequestId(data=b"\x01\x02\x03\x04") - - nodes1 = Nodes( - request_id=request_id, - total=Uint8(3), - enrs=[b"enr1", b"enr2"], - ) - - nodes2 = Nodes( - request_id=request_id, - total=Uint8(3), - enrs=[b"enr3", b"enr4"], - ) - - nodes3 = Nodes( - request_id=request_id, - total=Uint8(3), - enrs=[b"enr5"], - ) - - # All messages share same request_id. - assert bytes(nodes1.request_id) == bytes(nodes2.request_id) == bytes(nodes3.request_id) - - # Each has same total. - assert int(nodes1.total) == int(nodes2.total) == int(nodes3.total) == 3 - - # Accumulate all ENRs. - all_enrs = nodes1.enrs + nodes2.enrs + nodes3.enrs - assert len(all_enrs) == 5 - - -class TestRequestResponseCorrelation: - """Request ID matching and timeout handling tests.""" - - def test_request_id_bytes_for_dict_lookup(self): - """Request ID bytes work as dict key for lookup.""" - pending_requests: dict[bytes, PendingRequest] = {} - - loop = asyncio.new_event_loop() - - request_id_1 = b"\x01\x02\x03\x04" - request_id_2 = b"\x05\x06\x07\x08" - - future1: asyncio.Future = loop.create_future() - future2: asyncio.Future = loop.create_future() - - message1 = Ping(request_id=RequestId(data=request_id_1), enr_seq=SeqNumber(1)) - message2 = Ping(request_id=RequestId(data=request_id_2), enr_seq=SeqNumber(2)) - - pending1 = PendingRequest( - request_id=RequestId(data=request_id_1), - dest_node_id=NodeId(bytes(32)), - sent_at=loop.time(), - nonce=Nonce(bytes(12)), - message=message1, - future=future1, - ) - - pending2 = PendingRequest( - request_id=RequestId(data=request_id_2), - dest_node_id=NodeId(bytes(32)), - sent_at=loop.time(), - nonce=Nonce(bytes(12)), - message=message2, - future=future2, - ) - - # Store in dict. - pending_requests[request_id_1] = pending1 - pending_requests[request_id_2] = pending2 - - # Lookup by request_id. - assert pending_requests.get(request_id_1) is pending1 - assert pending_requests.get(request_id_2) is pending2 - assert pending_requests.get(b"\xff\xff\xff\xff") is None - - loop.close() - - -class TestPendingRequestsManagement: - """Tests for pending requests dict management.""" - - @pytest.mark.anyio - async def test_pending_requests_dict_initialized_empty( - self, local_node_id, local_private_key, local_enr - ): - """Transport starts with empty pending requests.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - assert len(transport._pending_requests) == 0 - assert len(transport._pending_multi_requests) == 0 - - @pytest.mark.anyio - async def test_pending_requests_cleared_on_stop( - self, started_transport: tuple[DiscoveryTransport, MagicMock] - ): - """Stop clears all pending requests.""" - transport, _ = started_transport - - # Add some pending requests. - loop = asyncio.get_running_loop() - for i in range(3): - future: asyncio.Future = loop.create_future() - request_id = RequestId(data=bytes([i])) - pending = PendingRequest( - request_id=request_id, - dest_node_id=NodeId(bytes(32)), - sent_at=loop.time(), - nonce=Nonce(bytes(12)), - message=Ping(request_id=request_id, enr_seq=SeqNumber(i)), - future=future, - ) - transport._pending_requests[pending.request_id] = pending - - assert len(transport._pending_requests) == 3 - - await transport.stop() - - # All should be cleared. - assert len(transport._pending_requests) == 0 - - @pytest.mark.anyio - async def test_pending_request_futures_cancelled_on_stop( - self, started_transport: tuple[DiscoveryTransport, MagicMock] - ): - """Stop cancels all pending request futures.""" - transport, _ = started_transport - - loop = asyncio.get_running_loop() - futures = [] - for i in range(3): - future: asyncio.Future = loop.create_future() - futures.append(future) - request_id = RequestId(data=bytes([i])) - pending = PendingRequest( - request_id=request_id, - dest_node_id=NodeId(bytes(32)), - sent_at=loop.time(), - nonce=Nonce(bytes(12)), - message=Ping(request_id=request_id, enr_seq=SeqNumber(i)), - future=future, - ) - transport._pending_requests[pending.request_id] = pending - - await transport.stop() - - # All futures should be cancelled. - for future in futures: - assert future.cancelled() - - -class TestSendPing: - """Tests for send_ping method.""" - - @pytest.mark.anyio - async def test_send_ping_requires_started_transport( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_ping raises if transport not started.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - with pytest.raises(RuntimeError, match="Transport not started"): - await transport.send_ping(remote_node_id, ("192.168.1.1", 30303)) - - @pytest.mark.anyio - async def test_send_ping_returns_none_on_timeout( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_ping returns None when no response arrives before timeout.""" - config = DiscoveryConfig(request_timeout_secs=0.05) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - result = await transport.send_ping(remote_node_id, ("192.168.1.1", 30303)) - - assert result is None - mock_udp.sendto.assert_called_once() - - await transport.stop() - - @pytest.mark.anyio - async def test_send_ping_sends_packet_to_correct_address( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_ping sends a packet to the specified address.""" - config = DiscoveryConfig(request_timeout_secs=0.05) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - dest_addr = ("192.168.1.1", 30303) - await transport.send_ping(remote_node_id, dest_addr) - - # Verify the packet was sent to the correct address. - args = mock_udp.sendto.call_args - assert args[0][1] == dest_addr - - await transport.stop() - - @pytest.mark.anyio - async def test_send_ping_registers_node_address( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_ping registers the destination address for future use.""" - config = DiscoveryConfig(request_timeout_secs=0.05) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - dest_addr = ("192.168.1.1", 30303) - await transport.send_ping(remote_node_id, dest_addr) - - assert transport.get_node_address(remote_node_id) == dest_addr - - await transport.stop() - - -class TestSendFindNode: - """Tests for send_findnode method.""" - - @pytest.mark.anyio - async def test_send_findnode_requires_started_transport( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_findnode raises if transport not started.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - with pytest.raises(RuntimeError, match="Transport not started"): - await transport.send_findnode(remote_node_id, ("192.168.1.1", 30303), [1, 2]) - - @pytest.mark.anyio - async def test_send_findnode_returns_empty_on_timeout( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_findnode returns empty list when no response arrives.""" - config = DiscoveryConfig(request_timeout_secs=0.05) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - result = await transport.send_findnode(remote_node_id, ("192.168.1.1", 30303), [1, 2, 3]) - - assert result == [] - mock_udp.sendto.assert_called_once() - - await transport.stop() - - @pytest.mark.anyio - async def test_send_findnode_sends_packet_to_correct_address( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_findnode sends a packet to the specified address.""" - config = DiscoveryConfig(request_timeout_secs=0.05) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - dest_addr = ("10.0.0.1", 9001) - await transport.send_findnode(remote_node_id, dest_addr, [256]) - - args = mock_udp.sendto.call_args - assert args[0][1] == dest_addr - - await transport.stop() - - -class TestSendTalkReq: - """Tests for send_talkreq method.""" - - @pytest.mark.anyio - async def test_send_talkreq_requires_started_transport( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_talkreq raises if transport not started.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - with pytest.raises(RuntimeError, match="Transport not started"): - await transport.send_talkreq( - remote_node_id, ("192.168.1.1", 30303), b"eth2", b"request" - ) - - @pytest.mark.anyio - async def test_send_talkreq_returns_none_on_timeout( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_talkreq returns None when no response arrives.""" - config = DiscoveryConfig(request_timeout_secs=0.05) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - result = await transport.send_talkreq( - remote_node_id, ("192.168.1.1", 30303), b"eth2", b"request" - ) - - assert result is None - mock_udp.sendto.assert_called_once() - - await transport.stop() - - -class TestHandleDecodedMessage: - """Tests for _handle_decoded_message dispatch.""" - - @pytest.mark.anyio - async def test_response_completes_pending_request_future( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """A decoded response message completes the matching pending request future.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - loop = asyncio.get_running_loop() - future: asyncio.Future[Pong | None] = loop.create_future() - request_id = RequestId(data=b"\x01\x02\x03\x04") - - pending = PendingRequest( - request_id=request_id, - dest_node_id=remote_node_id, - sent_at=loop.time(), - nonce=Nonce(bytes(12)), - message=Ping(request_id=request_id, enr_seq=SeqNumber(1)), - future=future, - ) - transport._pending_requests[request_id] = pending - - pong = Pong( - request_id=request_id, - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(9000), - ) - - await transport._handle_decoded_message(remote_node_id, pong, ("192.168.1.1", 30303)) - - assert future.done() - assert await future is pong - - @pytest.mark.anyio - async def test_response_enqueued_for_multi_request( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """A decoded NODES message is enqueued for pending multi-request.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - queue: asyncio.Queue = asyncio.Queue() - - multi_pending = PendingMultiRequest( - request_id=request_id, - dest_node_id=remote_node_id, - sent_at=0.0, - nonce=Nonce(bytes(12)), - message=FindNode(request_id=request_id, distances=[Distance(256)]), - response_queue=queue, - expected_total=None, - received_count=0, - ) - transport._pending_multi_requests[request_id] = multi_pending - - nodes = Nodes( - request_id=request_id, - total=Uint8(1), - enrs=[b"enr1"], - ) - - await transport._handle_decoded_message(remote_node_id, nodes, ("192.168.1.1", 30303)) - - assert queue.qsize() == 1 - assert await queue.get() is nodes - - @pytest.mark.anyio - async def test_unmatched_message_dispatched_to_handler( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """A message with no matching pending request goes to the message handler.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - handler = MagicMock() - transport.set_message_handler(handler) - - ping = Ping( - request_id=RequestId(data=b"\xff\xff"), - enr_seq=SeqNumber(1), - ) - - await transport._handle_decoded_message(remote_node_id, ping, ("192.168.1.1", 30303)) - - handler.assert_called_once_with(remote_node_id, ping, ("192.168.1.1", 30303)) - - @pytest.mark.anyio - async def test_unmatched_message_without_handler_is_silent( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """A message with no handler and no pending request is silently dropped.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - ping = Ping( - request_id=RequestId(data=b"\xff\xff"), - enr_seq=SeqNumber(1), - ) - - # Should not raise. - await transport._handle_decoded_message(remote_node_id, ping, ("192.168.1.1", 30303)) - - @pytest.mark.anyio - async def test_decoded_message_touches_session( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """Processing a decoded message calls touch on the session cache.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - with patch.object(transport._session_cache, "touch") as mock_touch: - ping = Ping( - request_id=RequestId(data=b"\xff"), - enr_seq=SeqNumber(1), - ) - await transport._handle_decoded_message(remote_node_id, ping, ("192.168.1.1", 30303)) - - mock_touch.assert_called_once_with(remote_node_id, "192.168.1.1", Port(30303)) - - -class TestHandlePacketDispatch: - """Tests for _handle_packet routing logic.""" - - @pytest.mark.anyio - async def test_invalid_packet_is_silently_dropped( - self, local_node_id, local_private_key, local_enr - ): - """Malformed packets are dropped without raising.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - # Garbage data that can't be decoded. - await transport._handle_packet(b"\x00" * 10, ("192.168.1.1", 30303)) - - @pytest.mark.anyio - async def test_short_packet_is_silently_dropped( - self, local_node_id, local_private_key, local_enr - ): - """Packets shorter than minimum size are dropped.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - await transport._handle_packet(b"", ("192.168.1.1", 30303)) - - -class TestHandleMessage: - """Tests for _handle_message (ordinary MESSAGE packets).""" - - @pytest.mark.anyio - async def test_message_without_session_sends_whoareyou( - self, local_node_id, local_private_key, local_enr - ): - """MESSAGE from unknown sender triggers WHOAREYOU.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - src_id = NodeId(bytes(range(32))) - authdata = encode_message_authdata(src_id) - - header = PacketHeader( - flag=PacketFlag.MESSAGE, - nonce=Nonce(bytes(12)), - authdata=authdata, - ) - - with patch.object(transport, "_send_whoareyou", new=AsyncMock()) as mock_whoareyou: - await transport._handle_message(header, b"\x00" * 32, ("192.168.1.1", 30303), b"ad") - - mock_whoareyou.assert_called_once() - - -class TestSendWhoareyou: - """Tests for _send_whoareyou method.""" - - @pytest.mark.anyio - async def test_send_whoareyou_without_transport_is_noop( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """_send_whoareyou does nothing if transport not started.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - # Should not raise. - await transport._send_whoareyou(remote_node_id, Nonce(bytes(12)), ("192.168.1.1", 30303)) - - @pytest.mark.anyio - async def test_send_whoareyou_sends_packet( - self, - started_transport: tuple[DiscoveryTransport, MagicMock], - remote_node_id: NodeId, - ): - """_send_whoareyou sends a WHOAREYOU packet via UDP.""" - transport, mock_udp = started_transport - - await transport._send_whoareyou(remote_node_id, Nonce(bytes(12)), ("192.168.1.1", 30303)) - - mock_udp.sendto.assert_called_once() - args = mock_udp.sendto.call_args - assert args[0][1] == ("192.168.1.1", 30303) - - @pytest.mark.anyio - async def test_send_whoareyou_uses_cached_enr_seq( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """_send_whoareyou uses cached ENR seq instead of hardcoded 0.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - # Register a remote ENR with seq=42. - remote_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(42), - pairs={EnrKey("id"): b"v4"}, - ) - transport.register_enr(remote_node_id, remote_enr) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - with patch.object( - transport._handshake_manager, - "create_whoareyou", - wraps=transport._handshake_manager.create_whoareyou, - ) as mock_create: - await transport._send_whoareyou( - remote_node_id, Nonce(bytes(12)), ("192.168.1.1", 30303) - ) - - # Verify enr_seq=42 was passed, not 0. - call_kwargs = mock_create.call_args - assert call_kwargs[1]["remote_enr_seq"] == SeqNumber(42) - - await transport.stop() - - -class TestSendPingNonPong: - """Tests for send_ping returning None on non-Pong responses.""" - - @pytest.mark.anyio - async def test_send_ping_returns_none_on_non_pong( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_ping returns None when response is not a Pong.""" - config = DiscoveryConfig(request_timeout_secs=0.1) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - nodes = Nodes( - request_id=request_id, - total=Uint8(1), - enrs=[], - ) - - with patch.object(transport, "_send_request", new=AsyncMock(return_value=nodes)): - result = await transport.send_ping(remote_node_id, ("192.168.1.1", 30303)) - - assert result is None - - await transport.stop() - - @pytest.mark.anyio - async def test_send_ping_returns_pong_on_pong_response( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_ping returns Pong response when received.""" - config = DiscoveryConfig(request_timeout_secs=0.1) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - pong = Pong( - request_id=request_id, - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(9000), - ) - - with patch.object(transport, "_send_request", new=AsyncMock(return_value=pong)): - result = await transport.send_ping(remote_node_id, ("192.168.1.1", 30303)) - - assert result == pong - - await transport.stop() - - -class TestSendFindNodeNonNodes: - """Tests for send_findnode handling non-Nodes responses.""" - - @pytest.mark.anyio - async def test_send_findnode_ignores_non_nodes_responses( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_findnode ignores responses that are not NODES.""" - config = DiscoveryConfig(request_timeout_secs=0.1) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - pong = Pong( - request_id=request_id, - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(9000), - ) - - with patch.object( - transport, "_send_multi_response_request", new=AsyncMock(return_value=[pong]) - ): - result = await transport.send_findnode(remote_node_id, ("192.168.1.1", 30303), [256]) - - assert result == [] - - await transport.stop() - - @pytest.mark.anyio - async def test_send_findnode_extracts_enrs_from_nodes( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_findnode extracts ENRs from NODES responses.""" - config = DiscoveryConfig(request_timeout_secs=0.1) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - nodes = Nodes( - request_id=request_id, - total=Uint8(1), - enrs=[b"enr1", b"enr2"], - ) - - with patch.object( - transport, "_send_multi_response_request", new=AsyncMock(return_value=[nodes]) - ): - result = await transport.send_findnode(remote_node_id, ("192.168.1.1", 30303), [256]) - - assert result == [b"enr1", b"enr2"] - - await transport.stop() - - -class TestMultiResponseTimeout: - """Tests for multi-response collection timeout handling.""" - - @pytest.mark.anyio - async def test_multi_response_deadline_elapsed( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """Multi-response collection exits when deadline has passed.""" - config = DiscoveryConfig(request_timeout_secs=0.0) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - mock_protocol = MagicMock(spec=DiscoveryProtocol) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, mock_protocol)), - ): - await transport.start("127.0.0.1", 9000) - - with patch.object( - transport, "_send_multi_response_request", new=AsyncMock(return_value=[]) - ): - result = await transport.send_findnode(remote_node_id, ("192.168.1.1", 30303), [256]) - - assert result == [] - - await transport.stop() - - @pytest.mark.anyio - async def test_multi_response_nodes_handling( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """Multi-response collection handles NODES responses correctly.""" - config = DiscoveryConfig(request_timeout_secs=5.0) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - mock_protocol = MagicMock(spec=DiscoveryProtocol) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, mock_protocol)), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - nodes1 = Nodes(request_id=request_id, total=Uint8(2), enrs=[b"enr1"]) - nodes2 = Nodes(request_id=request_id, total=Uint8(2), enrs=[b"enr2"]) - - with patch.object( - transport, - "_send_multi_response_request", - new=AsyncMock(return_value=[nodes1, nodes2]), - ): - result = await transport.send_findnode(remote_node_id, ("192.168.1.1", 30303), [256]) - - assert result == [b"enr1", b"enr2"] - - await transport.stop() - - -class TestSendMultiResponseRequest: - """Tests for _send_multi_response_request directly.""" - - @pytest.mark.anyio - async def test_send_multi_response_request_timeout_zero( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """_send_multi_response_request exits immediately with timeout=0.""" - config = DiscoveryConfig(request_timeout_secs=0.0) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - mock_protocol = MagicMock(spec=DiscoveryProtocol) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, mock_protocol)), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - findnode = FindNode(request_id=request_id, distances=[Distance(256)]) - - result = await transport._send_multi_response_request( - remote_node_id, ("192.168.1.1", 30303), findnode - ) - - assert result == [] - - await transport.stop() - - @pytest.mark.anyio - async def test_send_multi_response_request_collects_responses( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """_send_multi_response_request collects multiple NODES responses.""" - config = DiscoveryConfig(request_timeout_secs=1.0) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - mock_protocol = MagicMock(spec=DiscoveryProtocol) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, mock_protocol)), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - findnode = FindNode(request_id=request_id, distances=[Distance(256)]) - - async def feed_responses(): - await asyncio.sleep(0.01) - nodes1 = Nodes(request_id=request_id, total=Uint8(2), enrs=[b"enr1"]) - nodes2 = Nodes(request_id=request_id, total=Uint8(2), enrs=[b"enr2"]) - transport._pending_multi_requests[request_id].response_queue.put_nowait(nodes1) - await asyncio.sleep(0.01) - transport._pending_multi_requests[request_id].response_queue.put_nowait(nodes2) - - task = asyncio.create_task( - transport._send_multi_response_request(remote_node_id, ("192.168.1.1", 30303), findnode) - ) - feed_task = asyncio.create_task(feed_responses()) - - result = await task - await feed_task - - assert result == [ - Nodes(request_id=request_id, total=Uint8(2), enrs=[b"enr1"]), - Nodes(request_id=request_id, total=Uint8(2), enrs=[b"enr2"]), - ] - - await transport.stop() - - -class TestSendTalkReqNonTalkResp: - """Tests for send_talkreq returning None on non-TalkResp responses.""" - - @pytest.mark.anyio - async def test_send_talkreq_returns_none_on_non_talkresp( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_talkreq returns None when response is not a TalkResp.""" - config = DiscoveryConfig(request_timeout_secs=0.1) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - pong = Pong( - request_id=request_id, - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(9000), - ) - - with patch.object(transport, "_send_request", new=AsyncMock(return_value=pong)): - result = await transport.send_talkreq( - remote_node_id, ("192.168.1.1", 30303), b"eth2", b"request" - ) - - assert result is None - - await transport.stop() - - @pytest.mark.anyio - async def test_send_talkreq_returns_response_on_talkresp( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_talkreq returns response when TalkResp is received.""" - config = DiscoveryConfig(request_timeout_secs=0.1) - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - config=config, - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - request_id = RequestId(data=b"\x01\x02\x03\x04") - response = b"eth2 response data" - talkresp = TalkResp( - request_id=request_id, - response=response, - ) - - with patch.object(transport, "_send_request", new=AsyncMock(return_value=talkresp)): - result = await transport.send_talkreq( - remote_node_id, ("192.168.1.1", 30303), b"eth2", b"request" - ) - - assert result == response - - await transport.stop() - - -class TestBuildMessagePacketDummyKey: - """Tests for _build_message_packet without existing session.""" - - @pytest.mark.anyio - async def test_build_message_packet_uses_dummy_key_without_session( - self, - started_transport: tuple[DiscoveryTransport, MagicMock], - remote_node_id: NodeId, - ): - """_build_message_packet uses dummy key when no session exists.""" - transport, _ = started_transport - - with patch.object( - transport._handshake_manager, - "start_handshake", - ) as mock_start: - packet = transport._build_message_packet( - remote_node_id, - ("192.168.1.1", 30303), - Nonce(bytes(12)), - b"test message", - ) - - mock_start.assert_called_once_with(remote_node_id) - assert packet is not None - - @pytest.mark.anyio - async def test_build_message_packet_uses_session_key_with_session( - self, - started_transport: tuple[DiscoveryTransport, MagicMock], - remote_node_id: NodeId, - ): - """_build_message_packet uses session key when session exists.""" - transport, _ = started_transport - - transport._session_cache.create( - remote_node_id, - send_key=Bytes16(bytes(16)), - recv_key=Bytes16(bytes(range(16))), - is_initiator=True, - ip="192.168.1.1", - port=Port(30303), - ) - - with patch.object( - transport._handshake_manager, - "start_handshake", - ) as mock_start: - packet = transport._build_message_packet( - remote_node_id, - ("192.168.1.1", 30303), - Nonce(bytes(12)), - b"test message", - ) - - mock_start.assert_not_called() - assert packet is not None - - -class TestHandlePacketRouting: - """Tests for _handle_packet routing to WHOAREYOU and HANDSHAKE handlers.""" - - @pytest.mark.anyio - async def test_handle_packet_routes_whoareyou( - self, local_node_id, local_private_key, local_enr - ): - """WHOAREYOU packets are routed to _handle_whoareyou.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - packet_data = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=Nonce(bytes(12)), - authdata=bytes(24), - message=bytes(32), - ) - - with patch.object(transport, "_handle_whoareyou", new=AsyncMock()) as mock_handler: - await transport._handle_packet(packet_data, ("192.168.1.1", 30303)) - mock_handler.assert_called_once() - - @pytest.mark.anyio - async def test_handle_packet_routes_handshake( - self, local_node_id, local_private_key, local_enr - ): - """HANDSHAKE packets are routed to _handle_handshake.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - packet_data = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.HANDSHAKE, - nonce=Nonce(bytes(12)), - authdata=bytes(65), - message=b"encrypted", - encryption_key=Bytes16(bytes(16)), - ) - - with patch.object(transport, "_handle_handshake", new=AsyncMock()) as mock_handler: - await transport._handle_packet(packet_data, ("192.168.1.1", 30303)) - mock_handler.assert_called_once() - - @pytest.mark.anyio - async def test_handle_packet_routes_message(self, local_node_id, local_private_key, local_enr): - """MESSAGE packets are routed to _handle_message.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - packet_data = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.MESSAGE, - nonce=Nonce(bytes(12)), - authdata=encode_message_authdata(NodeId(bytes(range(32)))), - message=b"encrypted", - encryption_key=Bytes16(bytes(16)), - ) - - with patch.object(transport, "_handle_message", new=AsyncMock()) as mock_handler: - await transport._handle_packet(packet_data, ("192.168.1.1", 30303)) - mock_handler.assert_called_once() - - -class TestHandleWhoareyou: - """Tests for _handle_whoareyou edge cases.""" - - @pytest.mark.anyio - async def test_handle_whoareyou_no_matching_request( - self, local_node_id, local_private_key, local_enr - ): - """_handle_whoareyou returns when no pending request matches nonce.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - id_nonce = bytes(16) - authdata = id_nonce + (1).to_bytes(8, "big") - packet_data = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=Nonce(bytes(12)), - authdata=authdata, - message=bytes(32), - ) - - with patch.object(transport._handshake_manager, "start_handshake"): - await transport._handle_packet(packet_data, ("192.168.1.1", 30303)) - - @pytest.mark.anyio - async def test_handle_whoareyou_no_cached_enr( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """_handle_whoareyou returns when no cached ENR for remote.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - loop = asyncio.get_running_loop() - nonce = Nonce(bytes(12)) - future: asyncio.Future = loop.create_future() - request_id = RequestId(data=b"\x01") - pending = PendingRequest( - request_id=request_id, - dest_node_id=remote_node_id, - sent_at=loop.time(), - nonce=nonce, - message=Ping(request_id=request_id, enr_seq=SeqNumber(1)), - future=future, - ) - transport._pending_requests[pending.request_id] = pending - - id_nonce = bytes(16) - authdata = id_nonce + (1).to_bytes(8, "big") - raw_packet = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=Nonce(bytes(12)), - authdata=authdata, - message=bytes(32), - ) - - header, _, _ = decode_packet_header(local_node_id, raw_packet) - - with patch.object( - transport._handshake_manager, - "get_cached_enr", - return_value=None, - ): - await transport._handle_whoareyou(header, bytes(24), ("192.168.1.1", 30303), raw_packet) - - @pytest.mark.anyio - async def test_handle_whoareyou_matching_request_sends_handshake( - self, - started_transport: tuple[DiscoveryTransport, MagicMock], - local_node_id: NodeId, - remote_node_id: NodeId, - ): - """_handle_whoareyou sends HANDSHAKE when pending request matches.""" - transport, mock_udp = started_transport - - nonce_bytes = bytes(12) - loop = asyncio.get_running_loop() - nonce = Nonce(nonce_bytes) - future: asyncio.Future = loop.create_future() - pending = PendingRequest( - request_id=RequestId(data=b"\x01"), - dest_node_id=remote_node_id, - sent_at=loop.time(), - nonce=nonce, - message=Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)), - future=future, - ) - transport._pending_requests[pending.request_id] = pending - - remote_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): NODE_B_PUBKEY, - }, - ) - - authdata = bytes(34) - mock_response_authdata = authdata - mock_send_key = Bytes16(bytes(16)) - mock_recv_key = Bytes16(bytes(range(16))) - - with ( - patch.object( - transport._handshake_manager, - "get_cached_enr", - return_value=remote_enr, - ), - patch.object( - transport._handshake_manager, - "create_handshake_response", - return_value=(mock_response_authdata, mock_send_key, mock_recv_key), - ), - ): - id_nonce = bytes(16) - whoareyou_authdata = id_nonce + (1).to_bytes(8, "big") - raw_packet = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=nonce, - authdata=whoareyou_authdata, - message=bytes(32), - ) - - header, _, _ = decode_packet_header(local_node_id, raw_packet) - await transport._handle_whoareyou(header, bytes(24), ("192.168.1.1", 30303), raw_packet) - - mock_udp.sendto.assert_called() - - @pytest.mark.anyio - async def test_handle_whoareyou_handshake_error( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """_handle_whoareyou handles HandshakeError gracefully.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - loop = asyncio.get_running_loop() - nonce = Nonce(bytes(12)) - future: asyncio.Future = loop.create_future() - request_id = RequestId(data=b"\x01") - pending = PendingRequest( - request_id=request_id, - dest_node_id=remote_node_id, - sent_at=loop.time(), - nonce=nonce, - message=Ping(request_id=request_id, enr_seq=SeqNumber(1)), - future=future, - ) - transport._pending_requests[pending.request_id] = pending - - remote_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): NODE_B_PUBKEY, - }, - ) - - id_nonce = bytes(16) - authdata = id_nonce + (1).to_bytes(8, "big") - raw_packet = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=Nonce(bytes(12)), - authdata=authdata, - message=bytes(32), - ) - - header, message_bytes, _ = decode_packet_header(local_node_id, raw_packet) - - with ( - patch.object( - transport._handshake_manager, - "get_cached_enr", - return_value=remote_enr, - ), - patch.object( - transport._handshake_manager, - "create_handshake_response", - side_effect=HandshakeError("test error"), - ), - ): - await transport._handle_whoareyou( - header, message_bytes, ("192.168.1.1", 30303), raw_packet - ) - - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - loop = asyncio.get_running_loop() - nonce = Nonce(bytes(12)) - future: asyncio.Future = loop.create_future() - request_id = RequestId(data=b"\x01") - pending = PendingRequest( - request_id=request_id, - dest_node_id=remote_node_id, - sent_at=loop.time(), - nonce=nonce, - message=Ping(request_id=request_id, enr_seq=SeqNumber(1)), - future=future, - ) - transport._pending_requests[pending.request_id] = pending - - remote_enr = ENR( - signature=Bytes64(bytes(64)), - seq=SeqNumber(1), - pairs={ - EnrKey("id"): b"v4", - EnrKey("secp256k1"): NODE_B_PUBKEY, - }, - ) - - id_nonce = bytes(16) - authdata = id_nonce + (1).to_bytes(8, "big") - raw_packet = encode_packet( - dest_node_id=local_node_id, - flag=PacketFlag.WHOAREYOU, - nonce=Nonce(bytes(12)), - authdata=authdata, - message=bytes(32), - ) - - header, message_bytes, _ = decode_packet_header(local_node_id, raw_packet) - - with ( - patch.object( - transport._handshake_manager, - "get_cached_enr", - return_value=remote_enr, - ), - patch.object( - transport._handshake_manager, - "create_handshake_response", - side_effect=HandshakeError("test error"), - ), - ): - await transport._handle_whoareyou( - header, message_bytes, ("192.168.1.1", 30303), raw_packet - ) - - -class TestHandleHandshake: - """Tests for _handle_handshake.""" - - @pytest.mark.anyio - async def test_handle_handshake_completes_and_dispatches( - self, local_node_id, local_private_key, local_enr - ): - """_handle_handshake completes handshake and dispatches message.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - remote_node_id = NodeId(bytes(range(32))) - session = Session( - node_id=remote_node_id, - recv_key=Bytes16(bytes(range(16))), - send_key=Bytes16(bytes(range(16))), - created_at=0.0, - last_seen=0.0, - is_initiator=False, - ) - result = HandshakeResult(session=session, remote_enr=None) - - mock_authdata = HandshakeAuthdata( - src_id=remote_node_id, - sig_size=64, - eph_key_size=33, - id_signature=Bytes64(bytes(64)), - eph_pubkey=Bytes33(bytes(range(33))), - record=None, - ) - - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(9000), - ) - - with ( - patch( - "lean_spec.subspecs.networking.discovery.transport.decode_handshake_authdata", - return_value=mock_authdata, - ), - patch.object( - transport._handshake_manager, - "handle_handshake", - return_value=result, - ), - patch( - "lean_spec.subspecs.networking.discovery.transport.decrypt_message", - return_value=b"decrypted", - ), - patch( - "lean_spec.subspecs.networking.discovery.transport.decode_message", - return_value=pong, - ), - patch.object( - transport, - "_handle_decoded_message", - new=AsyncMock(), - ) as mock_dispatch, - ): - header = MagicMock() - header.authdata = bytes(65) - header.nonce = Nonce(bytes(12)) - - await transport._handle_handshake(header, b"encrypted", ("192.168.1.1", 30303), b"ad") - - mock_dispatch.assert_called_once() - - @pytest.mark.anyio - async def test_handle_handshake_empty_message_skips_decryption( - self, local_node_id, local_private_key, local_enr - ): - """_handle_handshake skips decryption when message_bytes is empty.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - remote_node_id = NodeId(bytes(range(32))) - session = Session( - node_id=remote_node_id, - recv_key=Bytes16(bytes(range(16))), - send_key=Bytes16(bytes(range(16))), - created_at=0.0, - last_seen=0.0, - is_initiator=False, - ) - result = HandshakeResult(session=session, remote_enr=None) - - mock_authdata = HandshakeAuthdata( - src_id=remote_node_id, - sig_size=64, - eph_key_size=33, - id_signature=Bytes64(bytes(64)), - eph_pubkey=Bytes33(bytes(range(33))), - record=None, - ) - - with ( - patch( - "lean_spec.subspecs.networking.discovery.transport.decode_handshake_authdata", - return_value=mock_authdata, - ), - patch.object( - transport._handshake_manager, - "handle_handshake", - return_value=result, - ), - patch( - "lean_spec.subspecs.networking.discovery.transport.decrypt_message", - ) as mock_decrypt, - patch.object( - transport, - "_handle_decoded_message", - new=AsyncMock(), - ) as mock_dispatch, - ): - header = MagicMock() - header.authdata = bytes(65) - header.nonce = Nonce(bytes(12)) - - await transport._handle_handshake(header, b"", ("192.168.1.1", 30303), b"ad") - - mock_decrypt.assert_not_called() - mock_dispatch.assert_not_called() - - @pytest.mark.anyio - async def test_handle_handshake_error(self, local_node_id, local_private_key, local_enr): - """_handle_handshake handles errors gracefully.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - remote_node_id = NodeId(bytes(range(32))) - mock_authdata = HandshakeAuthdata( - src_id=remote_node_id, - sig_size=64, - eph_key_size=33, - id_signature=Bytes64(bytes(64)), - eph_pubkey=Bytes33(bytes(range(33))), - record=None, - ) - - with ( - patch( - "lean_spec.subspecs.networking.discovery.transport.decode_handshake_authdata", - return_value=mock_authdata, - ), - patch.object( - transport._handshake_manager, - "handle_handshake", - side_effect=HandshakeError("test error"), - ), - ): - header = MagicMock() - header.authdata = bytes(65) - - await transport._handle_handshake(header, b"", ("192.168.1.1", 30303), b"") - - -class TestHandleMessageDecryption: - """Tests for _handle_message decryption failure path.""" - - @pytest.mark.anyio - async def test_handle_message_decryption_failure_sends_whoareyou( - self, local_node_id, local_private_key, local_enr - ): - """Decryption failure triggers WHOAREYOU response.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - src_id = NodeId(bytes(range(32))) - transport._session_cache.create( - src_id, - send_key=Bytes16(bytes(16)), - recv_key=Bytes16(bytes(range(16))), - is_initiator=False, - ip="192.168.1.1", - port=Port(30303), - ) - - authdata = encode_message_authdata(src_id) - - header = PacketHeader( - flag=PacketFlag.MESSAGE, - nonce=Nonce(bytes(12)), - authdata=authdata, - ) - - with ( - patch( - "lean_spec.subspecs.networking.discovery.transport.decrypt_message", - side_effect=InvalidTag(), - ), - patch.object( - transport, - "_send_whoareyou", - new=AsyncMock(), - ) as mock_whoareyou, - ): - await transport._handle_message(header, b"encrypted", ("192.168.1.1", 30303), b"ad") - - mock_whoareyou.assert_called_once() - - -class TestSendResponseWithSession: - """Tests for send_response with existing session.""" - - @pytest.mark.anyio - async def test_send_response_with_session( - self, local_node_id, local_private_key, local_enr, remote_node_id - ): - """send_response encrypts and sends when session exists.""" - transport = DiscoveryTransport( - local_node_id=local_node_id, - local_private_key=local_private_key, - local_enr=local_enr, - ) - - transport._session_cache.create( - remote_node_id, - send_key=Bytes16(bytes(16)), - recv_key=Bytes16(bytes(range(16))), - is_initiator=False, - ip="192.168.1.1", - port=Port(30303), - ) - - mock_udp = MagicMock(spec=asyncio.DatagramTransport) - with patch.object( - asyncio.get_event_loop(), - "create_datagram_endpoint", - new=AsyncMock(return_value=(mock_udp, MagicMock(spec=DiscoveryProtocol))), - ): - await transport.start("127.0.0.1", 9000) - - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), - recipient_port=Port(9000), - ) - - result = await transport.send_response(remote_node_id, ("192.168.1.1", 30303), pong) - - assert result is True - mock_udp.sendto.assert_called_once() - - await transport.stop() diff --git a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py deleted file mode 100644 index 5c9cd9e9..00000000 --- a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py +++ /dev/null @@ -1,842 +0,0 @@ -""" -Official Discovery v5 Test Vectors - -Test vectors from the devp2p specification for spec compliance verification. - -Reference: - https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire-test-vectors.md -""" - -from __future__ import annotations - -import pytest -from cryptography.exceptions import InvalidTag -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import ec - -from lean_spec.subspecs.networking.discovery.codec import decode_message, encode_message -from lean_spec.subspecs.networking.discovery.crypto import ( - aes_gcm_decrypt, - aes_gcm_encrypt, - ecdh_agree, - sign_id_nonce, - verify_id_nonce_signature, -) -from lean_spec.subspecs.networking.discovery.keys import ( - compute_node_id, - derive_keys, -) -from lean_spec.subspecs.networking.discovery.messages import ( - Distance, - FindNode, - IdNonce, - IPv4, - MessageType, - Nodes, - Nonce, - PacketFlag, - Ping, - Pong, - Port, - RequestId, -) -from lean_spec.subspecs.networking.discovery.packet import ( - HANDSHAKE_HEADER_SIZE, - STATIC_HEADER_SIZE, - WHOAREYOU_AUTHDATA_SIZE, - decode_handshake_authdata, - decode_message_authdata, - decode_packet_header, - decode_whoareyou_authdata, - decrypt_message, - encode_handshake_authdata, - encode_message_authdata, - encode_packet, - encode_whoareyou_authdata, -) -from lean_spec.subspecs.networking.discovery.routing import log2_distance, xor_distance -from lean_spec.subspecs.networking.types import SeqNumber -from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes33, Bytes64, Uint8 -from tests.lean_spec.helpers import make_challenge_data -from tests.lean_spec.subspecs.networking.discovery.conftest import ( - NODE_A_ID, - NODE_A_PRIVKEY, - NODE_B_ID, - NODE_B_PRIVKEY, - NODE_B_PUBKEY, - SPEC_ID_NONCE, -) - -# Spec test vector values for ECDH and key derivation. -SPEC_NONCE = Nonce(bytes.fromhex("0102030405060708090a0b0c")) -SPEC_CHALLENGE_DATA = bytes.fromhex( - "000000000000000000000000000000006469736376350001010102030405060708090a0b0c" - "00180102030405060708090a0b0c0d0e0f100000000000000000" -) - -# Spec ephemeral keypair for ECDH / ID nonce signing. -SPEC_EPHEMERAL_KEY = Bytes32( - bytes.fromhex("fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736") -) -SPEC_EPHEMERAL_PUBKEY = Bytes33( - bytes.fromhex("039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231") -) - -# Derived session keys from spec HKDF test vector. -SPEC_INITIATOR_KEY = Bytes16(bytes.fromhex("dccc82d81bd610f4f76d3ebe97a40571")) -SPEC_RECIPIENT_KEY = Bytes16(bytes.fromhex("ac74bb8773749920b0d3a8881c173ec5")) - -# AES-GCM test vector values. -SPEC_AES_KEY = Bytes16(bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b")) -SPEC_AES_NONCE = Bytes12(bytes.fromhex("27b5af763c446acd2749fe8e")) - -# PING message plaintext (type 0x01, RLP [1]). -SPEC_PING_PLAINTEXT = bytes.fromhex("01c20101") - - -class TestOfficialNodeIdVectors: - """Verify node ID computation matches official test vectors.""" - - def test_node_b_id_from_privkey(self): - """ - Node B's ID is keccak256 of uncompressed public key. - - We derive the public key from the private key since the spec - provides the private key for Node B. - """ - # Derive public key from private key. - private_key = ec.derive_private_key( - int.from_bytes(NODE_B_PRIVKEY, "big"), - ec.SECP256K1(), - ) - pubkey_bytes = Bytes33( - private_key.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, - ) - ) - - computed = compute_node_id(pubkey_bytes) - assert computed == NODE_B_ID - - -class TestOfficialNodeIdAndKeyVectors: - """Verify both node IDs and bidirectional ECDH from spec key material.""" - - def test_node_a_id_from_privkey(self): - """Node A's ID from its private key matches the spec vector.""" - private_key = ec.derive_private_key( - int.from_bytes(NODE_A_PRIVKEY, "big"), - ec.SECP256K1(), - ) - pubkey_bytes = Bytes33( - private_key.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, - ) - ) - computed = compute_node_id(pubkey_bytes) - assert computed == NODE_A_ID - - def test_bidirectional_ecdh(self): - """ECDH(A_priv, B_pub) == ECDH(B_priv, A_pub). - - Derives Node A's public key from its private key and verifies - that both sides compute the same shared secret. - """ - # Derive Node A's public key from its private key. - a_privkey = ec.derive_private_key( - int.from_bytes(NODE_A_PRIVKEY, "big"), - ec.SECP256K1(), - ) - a_pubkey_bytes = Bytes33( - a_privkey.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, - ) - ) - - shared_ab = ecdh_agree(NODE_A_PRIVKEY, NODE_B_PUBKEY) - shared_ba = ecdh_agree(NODE_B_PRIVKEY, a_pubkey_bytes) - - assert shared_ab == shared_ba - - -class TestOfficialCryptoVectors: - """Cryptographic operation test vectors from devp2p spec.""" - - def test_ecdh_shared_secret(self): - """ - ECDH between Node A's private key and Node B's public key. - - Per spec, the shared secret is the 33-byte compressed point. - """ - expected_shared = bytes.fromhex( - "033b11a2a1f214567e1537ce5e509ffd9b21373247f2a3ff6841f4976f53165e7e" - ) - - shared = ecdh_agree(SPEC_EPHEMERAL_KEY, SPEC_EPHEMERAL_PUBKEY) - - assert bytes(shared) == expected_shared - - def test_key_derivation_hkdf(self): - """ - Key derivation using HKDF-SHA256. - - Derives initiator_key and recipient_key from ECDH shared secret. - Uses exact spec challenge_data (with nonce 0102030405060708090a0b0c). - """ - # Compute ECDH shared secret. - shared_secret = ecdh_agree(SPEC_EPHEMERAL_KEY, NODE_B_PUBKEY) - - # Derive keys using exact spec challenge_data. - initiator_key, recipient_key = derive_keys( - secret=shared_secret, - initiator_id=NODE_A_ID, - recipient_id=NODE_B_ID, - challenge_data=SPEC_CHALLENGE_DATA, - ) - - assert initiator_key == SPEC_INITIATOR_KEY - assert recipient_key == SPEC_RECIPIENT_KEY - - def test_id_nonce_signature(self): - """ - ID nonce signature proves node identity ownership. - - Per spec: - id-signature-input = "discovery v5 identity proof" || challenge-data || - ephemeral-pubkey || node-id-B - signature = sign(sha256(id-signature-input)) - - Uses exact spec challenge_data and verifies byte-exact signature output. - """ - # Sign using exact spec challenge_data. - signature = sign_id_nonce( - private_key_bytes=SPEC_EPHEMERAL_KEY, - challenge_data=SPEC_CHALLENGE_DATA, - ephemeral_pubkey=SPEC_EPHEMERAL_PUBKEY, - dest_node_id=NODE_B_ID, - ) - - expected_sig = bytes.fromhex( - "94852a1e2318c4e5e9d422c98eaf19d1d90d876b29cd06ca7cb7546d0fff7b48" - "4fe86c09a064fe72bdbef73ba8e9c34df0cd2b53e9d65528c2c7f336d5dfc6e6" - ) - assert bytes(signature) == expected_sig - - # Also verify the signature. - private_key = ec.derive_private_key( - int.from_bytes(SPEC_EPHEMERAL_KEY, "big"), - ec.SECP256K1(), - ) - pubkey_bytes = private_key.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, - ) - - assert verify_id_nonce_signature( - signature=Bytes64(signature), - challenge_data=SPEC_CHALLENGE_DATA, - ephemeral_pubkey=SPEC_EPHEMERAL_PUBKEY, - dest_node_id=NODE_B_ID, - public_key_bytes=Bytes33(pubkey_bytes), - ) - - def test_id_nonce_signature_different_challenge_data(self): - """Different challenge_data produces different signatures.""" - challenge_data1 = make_challenge_data(bytes(16)) - challenge_data2 = make_challenge_data(bytes([1]) + bytes(15)) - - sig1 = sign_id_nonce( - NODE_B_PRIVKEY, - challenge_data1, - SPEC_EPHEMERAL_PUBKEY, - NODE_A_ID, - ) - sig2 = sign_id_nonce( - NODE_B_PRIVKEY, - challenge_data2, - SPEC_EPHEMERAL_PUBKEY, - NODE_A_ID, - ) - - assert sig1 != sig2 - - def test_aes_gcm_encryption(self): - """ - AES-128-GCM message encryption. - - The 16-byte authentication tag is appended to ciphertext. - """ - aad = bytes.fromhex("93a7400fa0d6a694ebc24d5cf570f65d04215b6ac00757875e3f3a5f42107903") - expected_ciphertext = bytes.fromhex("a5d12a2d94b8ccb3ba55558229867dc13bfa3648") - - # Encrypt. - ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, SPEC_PING_PLAINTEXT, aad) - - assert ciphertext == expected_ciphertext - - # Verify decryption works. - decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) - assert decrypted == SPEC_PING_PLAINTEXT - - -class TestOfficialPacketVectors: - """Decode exact packet bytes from the devp2p spec test vectors. - - These tests verify interoperability by decoding the spec's exact hex packets. - """ - - def test_decode_spec_ping_packet(self): - """Decode the exact Ping packet from the spec test vectors. - - Verifies header fields and decrypts the message payload. - """ - packet_hex = ( - "00000000000000000000000000000000088b3d4342774649325f313964a39e55" - "ea96c005ad52be8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" - "4c4f53245d08dab84102ed931f66d1492acb308fa1c6715b9d139b81acbdcc" - ) - packet = bytes.fromhex(packet_hex) - - header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) - - assert header.flag == PacketFlag.MESSAGE - decoded_authdata = decode_message_authdata(header.authdata) - assert decoded_authdata.src_id == NODE_A_ID - - # Decrypt using the spec's read-key (all zeros for this test vector). - read_key = Bytes16(bytes(16)) - plaintext = decrypt_message(read_key, header.nonce, ciphertext, message_ad) - - # PING with request-id=0x00000001 (4 bytes) and enr-seq=2. - decoded = decode_message(plaintext) - assert isinstance(decoded, Ping) - assert int(decoded.enr_seq) == 2 - - def test_decode_spec_whoareyou_packet(self): - """Decode the exact WHOAREYOU packet from the spec test vectors. - - Verifies id-nonce and enr-seq match expected values. - Per spec, the WHOAREYOU dest-node-id is Node B's ID. - """ - packet_hex = ( - "00000000000000000000000000000000088b3d434277464933a1ccc59f5967ad" - "1d6035f15e528627dde75cd68292f9e6c27d6b66c8100a873fcbaed4e16b8d" - ) - packet = bytes.fromhex(packet_hex) - - header, _message, _message_ad = decode_packet_header(NODE_B_ID, packet) - - assert header.flag == PacketFlag.WHOAREYOU - decoded_authdata = decode_whoareyou_authdata(header.authdata) - assert decoded_authdata.id_nonce == SPEC_ID_NONCE - assert int(decoded_authdata.enr_seq) == 0 - - def test_decode_spec_handshake_packet(self): - """Decode the exact Handshake packet (no ENR) from the spec test vectors. - - Verifies authdata fields (src-id, signature size, key size). - """ - packet_hex = ( - "00000000000000000000000000000000088b3d4342774649305f313964a39e55" - "ea96c005ad521d8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" - "4c4f53245d08da4bb252012b2cba3f4f374a90a75cff91f142fa9be3e0a5f3ef" - "268ccb9065aeecfd67a999e7fdc137e062b2ec4a0eb92947f0d9a74bfbf44dfb" - "a776b21301f8b65efd5796706adff216ab862a9186875f9494150c4ae06fa4d1" - "f0396c93f215fa4ef524f1eadf5f0f4126b79336671cbcf7a885b1f8bd2a5d83" - "9cf8" - ) - packet = bytes.fromhex(packet_hex) - - header, _ciphertext, _message_ad = decode_packet_header(NODE_B_ID, packet) - - assert header.flag == PacketFlag.HANDSHAKE - decoded_authdata = decode_handshake_authdata(header.authdata) - assert decoded_authdata.src_id == NODE_A_ID - assert decoded_authdata.sig_size == 64 - assert decoded_authdata.eph_key_size == 33 - - def test_decode_spec_handshake_with_enr_packet(self): - """Decode the exact Handshake-with-ENR packet from the spec test vectors. - - Verifies authdata fields and presence of embedded ENR record. - """ - packet_hex = ( - "00000000000000000000000000000000088b3d4342774649305f313964a39e55" - "ea96c005ad539c8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" - "4c4f53245d08da4bb23698868350aaad22e3ab8dd034f548a1c43cd246be9856" - "2fafa0a1fa86d8e7a3b95ae78cc2b988ded6a5b59eb83ad58097252188b902b2" - "1481e30e5e285f19735796706adff216ab862a9186875f9494150c4ae06fa4d1" - "f0396c93f215fa4ef524e0ed04c3c21e39b1868e1ca8105e585ec17315e755e6" - "cfc4dd6cb7fd8e1a1f55e49b4b5eb024221482105346f3c82b15fdaae36a3bb1" - "2a494683b4a3c7f2ae41306252fed84785e2bbff3b022812d0882f06978df84a" - "80d443972213342d04b9048fc3b1d5fcb1df0f822152eced6da4d3f6df27e70e" - "4539717307a0208cd208d65093ccab5aa596a34d7511401987662d8cf62b1394" - "71" - ) - packet = bytes.fromhex(packet_hex) - - header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) - - assert header.flag == PacketFlag.HANDSHAKE - decoded_authdata = decode_handshake_authdata(header.authdata) - assert decoded_authdata.src_id == NODE_A_ID - assert decoded_authdata.sig_size == 64 - assert decoded_authdata.eph_key_size == 33 - - # This packet includes an ENR record (unlike the no-ENR handshake). - assert decoded_authdata.record is not None - assert len(decoded_authdata.record) > 0 - - # Decrypt the message using the spec's read-key. - read_key = Bytes16(bytes.fromhex("53b1c075f41876423154e157470c2f48")) - plaintext = decrypt_message(read_key, header.nonce, ciphertext, message_ad) - - # PING with request-id=0x00000001 and enr-seq=1. - decoded = decode_message(plaintext) - assert isinstance(decoded, Ping) - assert int(decoded.enr_seq) == 1 - - -class TestPacketEncodingRoundtrip: - """Test full packet encoding/decoding roundtrips.""" - - def test_message_packet_roundtrip(self): - """MESSAGE packet encodes and decodes correctly.""" - nonce = Nonce(bytes(12)) # 12-byte nonce. - encryption_key = Bytes16(bytes(16)) # 16-byte key. - message = b"\x01\xc2\x01\x01" # PING message. - - authdata = encode_message_authdata(NODE_A_ID) - - packet = encode_packet( - dest_node_id=NODE_B_ID, - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=message, - encryption_key=encryption_key, - ) - - # Decode header. - header, ciphertext, _message_ad = decode_packet_header(NODE_B_ID, packet) - - assert header.flag == PacketFlag.MESSAGE - assert len(header.authdata) == 32 - - decoded_authdata = decode_message_authdata(header.authdata) - assert decoded_authdata.src_id == NODE_A_ID - - def test_whoareyou_packet_roundtrip(self): - """WHOAREYOU packet encodes and decodes correctly.""" - nonce = Nonce(bytes.fromhex("0102030405060708090a0b0c")) - id_nonce = IdNonce(bytes.fromhex("0102030405060708090a0b0c0d0e0f10")) - enr_seq = SeqNumber(0) - - authdata = encode_whoareyou_authdata(id_nonce, enr_seq) - - packet = encode_packet( - dest_node_id=NODE_B_ID, - flag=PacketFlag.WHOAREYOU, - nonce=nonce, - authdata=authdata, - message=b"", # WHOAREYOU has no message. - encryption_key=None, - ) - - # Decode header. - header, message, _message_ad = decode_packet_header(NODE_B_ID, packet) - - assert header.flag == PacketFlag.WHOAREYOU - assert header.nonce == nonce - - decoded_authdata = decode_whoareyou_authdata(header.authdata) - assert decoded_authdata.id_nonce == id_nonce - assert decoded_authdata.enr_seq == enr_seq - - def test_handshake_packet_roundtrip(self): - """HANDSHAKE packet encodes and decodes correctly.""" - nonce = Nonce(bytes(12)) - message = b"\x01\xc2\x01\x01" # PING message. - - id_signature = Bytes64(bytes(64)) - eph_pubkey = Bytes33( - bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") - ) - - authdata = encode_handshake_authdata( - src_id=NODE_A_ID, - id_signature=id_signature, - eph_pubkey=eph_pubkey, - record=None, - ) - - packet = encode_packet( - dest_node_id=NODE_B_ID, - flag=PacketFlag.HANDSHAKE, - nonce=nonce, - authdata=authdata, - message=message, - encryption_key=SPEC_INITIATOR_KEY, - ) - - # Decode header. - header, ciphertext, _message_ad = decode_packet_header(NODE_B_ID, packet) - - assert header.flag == PacketFlag.HANDSHAKE - - decoded_authdata = decode_handshake_authdata(header.authdata) - assert decoded_authdata.src_id == NODE_A_ID - assert decoded_authdata.eph_pubkey == eph_pubkey - - -class TestOfficialPacketEncoding: - """Byte-exact packet encoding from devp2p spec wire test vectors. - - These tests verify that our packet encoding produces correct structure - and can interoperate with other implementations. - - Reference: - https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire-test-vectors.md - """ - - def test_official_ping_message_rlp_encoding(self): - """PING message RLP encodes to exact spec format. - - PING format: [request-id, enr-seq] - Message type byte 0x01 prepended. - """ - # PING with request ID [0x00, 0x00, 0x00, 0x01] and enr_seq = 1 - ping = Ping( - request_id=RequestId(data=b"\x00\x00\x00\x01"), - enr_seq=SeqNumber(1), - ) - - encoded = encode_message(ping) - - # First byte should be message type PING (0x01). - assert encoded[0] == MessageType.PING - # Rest is RLP-encoded [request-id, enr-seq]. - # request-id: 84 00000001 (4-byte string) - # enr-seq: 01 (single byte) - assert len(encoded) > 1 - - def test_official_pong_message_rlp_encoding(self): - """PONG message RLP encodes to exact spec format. - - PONG format: [request-id, enr-seq, recipient-ip, recipient-port] - Message type byte 0x02 prepended. - """ - pong = Pong( - request_id=RequestId(data=b"\x00\x00\x00\x01"), - enr_seq=SeqNumber(1), - recipient_ip=IPv4(b"\x7f\x00\x00\x01"), # 127.0.0.1 - recipient_port=Port(30303), - ) - - encoded = encode_message(pong) - - # First byte should be message type PONG (0x02). - assert encoded[0] == MessageType.PONG - assert len(encoded) > 1 - - def test_official_findnode_message_rlp_encoding(self): - """FINDNODE message RLP encodes to exact spec format. - - FINDNODE format: [request-id, [distances...]] - Message type byte 0x03 prepended. - """ - findnode = FindNode( - request_id=RequestId(data=b"\x00\x00\x00\x01"), - distances=[Distance(256), Distance(255)], - ) - - encoded = encode_message(findnode) - - # First byte should be message type FINDNODE (0x03). - assert encoded[0] == MessageType.FINDNODE - assert len(encoded) > 1 - - def test_official_nodes_message_rlp_encoding(self): - """NODES message RLP encodes to exact spec format. - - NODES format: [request-id, total, [enrs...]] - Message type byte 0x04 prepended. - """ - nodes = Nodes( - request_id=RequestId(data=b"\x00\x00\x00\x01"), - total=Uint8(1), - enrs=[b"enr:-test-data"], - ) - - encoded = encode_message(nodes) - - # First byte should be message type NODES (0x04). - assert encoded[0] == MessageType.NODES - assert len(encoded) > 1 - - def test_message_packet_header_structure(self): - """MESSAGE packet header follows spec structure. - - Structure: - - masking-iv: 16 bytes - - masked-header: variable (static-header + authdata) - - message ciphertext: variable - - Static header (23 bytes): - - protocol-id: "discv5" (6 bytes) - - version: 0x0001 (2 bytes) - - flag: 0x00 for MESSAGE (1 byte) - - nonce: 12 bytes - - authdata-size: 2 bytes - """ - nonce = Nonce(bytes(12)) - encryption_key = Bytes16(bytes(16)) - message = b"\x01\xc2\x01\x01" - - authdata = encode_message_authdata(NODE_A_ID) - - packet = encode_packet( - dest_node_id=NODE_B_ID, - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=message, - encryption_key=encryption_key, - ) - - # Minimum packet size: 16 (masking-iv) + 23 (static) + 32 (authdata) + 16 (tag) - assert len(packet) >= 16 + STATIC_HEADER_SIZE + 32 + 16 - - def test_whoareyou_packet_header_structure(self): - """WHOAREYOU packet header follows spec structure. - - WHOAREYOU has: - - flag: 0x01 - - authdata: id-nonce (16) + enr-seq (8) = 24 bytes - - no message payload - """ - nonce = Nonce(bytes(12)) - id_nonce = IdNonce(bytes(16)) - enr_seq = SeqNumber(0) - - authdata = encode_whoareyou_authdata(id_nonce, enr_seq) - - packet = encode_packet( - dest_node_id=NODE_B_ID, - flag=PacketFlag.WHOAREYOU, - nonce=nonce, - authdata=authdata, - message=b"", - encryption_key=None, - ) - - # WHOAREYOU packet size: 16 (masking-iv) + 23 (static) + 24 (authdata) - expected_size = 16 + STATIC_HEADER_SIZE + WHOAREYOU_AUTHDATA_SIZE - assert len(packet) == expected_size - - def test_handshake_packet_header_structure(self): - """HANDSHAKE packet header follows spec structure. - - HANDSHAKE has: - - flag: 0x02 - - authdata: src-id (32) + sig-size (1) + eph-key-size (1) + sig + eph-key + [record] - - encrypted message - """ - nonce = Nonce(bytes(12)) - encryption_key = Bytes16(bytes(16)) - message = b"\x01\xc2\x01\x01" - - id_signature = Bytes64(bytes(64)) - eph_pubkey = Bytes33( - bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") - ) - - authdata = encode_handshake_authdata( - src_id=NODE_A_ID, - id_signature=id_signature, - eph_pubkey=eph_pubkey, - record=None, - ) - - packet = encode_packet( - dest_node_id=NODE_B_ID, - flag=PacketFlag.HANDSHAKE, - nonce=nonce, - authdata=authdata, - message=message, - encryption_key=encryption_key, - ) - - # Minimum: 16 (iv) + 23 (static) + 34 (handshake header) + 64 (sig) + 33 (key) + 16 (tag) - min_size = 16 + STATIC_HEADER_SIZE + HANDSHAKE_HEADER_SIZE + 64 + 33 + 16 - assert len(packet) >= min_size - - -class TestOfficialKeyDerivation: - """Key derivation with exact spec inputs/outputs. - - HKDF parameters per spec: - - Hash: SHA256 - - salt: challenge-data (63 bytes) - - IKM: secret || initiator-id || recipient-id - - info: "discovery v5 key agreement" - - L: 32 bytes (2 x 16-byte keys) - """ - - def test_challenge_data_format(self): - """challenge_data follows spec format: masking-iv || static-header || authdata.""" - id_nonce = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") - - challenge_data = make_challenge_data(id_nonce) - - # Total size: 16 (masking-iv) + 23 (static-header) + 24 (authdata) = 63 bytes. - assert len(challenge_data) == 63 - - # First 16 bytes: masking-iv (all zeros in our test helper). - assert challenge_data[:16] == bytes(16) - - # Bytes 16-22: protocol-id "discv5". - assert challenge_data[16:22] == b"discv5" - - # Bytes 22-24: version 0x0001. - assert challenge_data[22:24] == b"\x00\x01" - - # Byte 24: flag 0x01 (WHOAREYOU). - assert challenge_data[24] == 0x01 - - # Bytes 25-37: nonce (12 bytes, all zeros in test helper). - assert challenge_data[25:37] == bytes(12) - - # Bytes 37-39: authdata-size (24 = 0x0018). - assert challenge_data[37:39] == b"\x00\x18" - - # Bytes 39-55: id-nonce (16 bytes). - assert challenge_data[39:55] == id_nonce - - # Bytes 55-63: enr-seq (8 bytes, all zeros). - assert challenge_data[55:63] == bytes(8) - - -class TestAESCryptoEdgeCases: - """Additional AES-GCM test cases beyond spec vectors.""" - - def test_aes_gcm_empty_plaintext(self): - """AES-GCM handles empty plaintext correctly.""" - aad = bytes(32) - plaintext = b"" - - ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) - - # Empty plaintext should produce just the 16-byte auth tag. - assert len(ciphertext) == 16 - - # Decryption should recover empty plaintext. - decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) - assert decrypted == b"" - - def test_aes_gcm_large_plaintext(self): - """AES-GCM handles large plaintext correctly.""" - aad = bytes(32) - plaintext = bytes(1024) # 1KB of zeros. - - ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) - - # Ciphertext = plaintext length + 16-byte tag. - assert len(ciphertext) == len(plaintext) + 16 - - # Decryption should recover original plaintext. - decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) - assert decrypted == plaintext - - def test_aes_gcm_tampered_ciphertext_fails(self): - """AES-GCM decryption fails with tampered ciphertext.""" - aad = bytes(32) - plaintext = b"secret message" - - ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) - - # Tamper with ciphertext by flipping a bit. - tampered = bytearray(ciphertext) - tampered[0] ^= 0x01 - tampered = bytes(tampered) - - # Decryption of tampered ciphertext should fail with InvalidTag. - with pytest.raises(InvalidTag): - aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, tampered, aad) - - -class TestSpecPacketPayloadDecryption: - """Verify message payload decryption using correct AAD (masking-iv || plaintext header).""" - - def test_message_packet_encrypt_decrypt_roundtrip(self): - """Encrypt a message in a packet and decrypt using message_ad from decode.""" - nonce = Nonce(bytes(12)) - - authdata = encode_message_authdata(NODE_A_ID) - - packet = encode_packet( - dest_node_id=NODE_B_ID, - flag=PacketFlag.MESSAGE, - nonce=nonce, - authdata=authdata, - message=SPEC_PING_PLAINTEXT, - encryption_key=SPEC_INITIATOR_KEY, - ) - - # Decode header - returns message_ad for AAD. - header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) - - # Decrypt using message_ad as AAD. - decrypted = decrypt_message(SPEC_INITIATOR_KEY, header.nonce, ciphertext, message_ad) - assert decrypted == SPEC_PING_PLAINTEXT - - def test_handshake_packet_encrypt_decrypt_roundtrip(self): - """Handshake packet encrypts and decrypts using correct AAD.""" - nonce = Nonce(bytes(12)) - - id_signature = Bytes64(bytes(64)) - eph_pubkey = Bytes33( - bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") - ) - - authdata = encode_handshake_authdata( - src_id=NODE_A_ID, - id_signature=id_signature, - eph_pubkey=eph_pubkey, - record=None, - ) - - packet = encode_packet( - dest_node_id=NODE_B_ID, - flag=PacketFlag.HANDSHAKE, - nonce=nonce, - authdata=authdata, - message=SPEC_PING_PLAINTEXT, - encryption_key=SPEC_INITIATOR_KEY, - ) - - # Decode header - returns message_ad for AAD. - header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) - - # Decrypt using message_ad as AAD. - decrypted = decrypt_message(SPEC_INITIATOR_KEY, header.nonce, ciphertext, message_ad) - assert decrypted == SPEC_PING_PLAINTEXT - - -class TestRoutingWithTestVectorNodeIds: - """Tests using official test vector node IDs with routing functions.""" - - def test_xor_distance_is_symmetric(self): - """XOR distance between test vector nodes is symmetric and non-zero.""" - distance = xor_distance(NODE_A_ID, NODE_B_ID) - assert distance > 0 - assert xor_distance(NODE_A_ID, NODE_B_ID) == xor_distance(NODE_B_ID, NODE_A_ID) - - def test_log2_distance_is_high(self): - """Log2 distance between test vector nodes is high (differ in high bits).""" - - log_dist = log2_distance(NODE_A_ID, NODE_B_ID) - assert log_dist > Distance(200)