diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 22645d3ba..9ba81773a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -331,47 +331,61 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: await event_source.response.aclose() break + async def _send_error_response(self, ctx: RequestContext, error: Exception) -> None: + """Send an error response to the client.""" + error_data = ErrorData(code=32000, message=str(error)) + if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch + jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=ctx.session_message.message.root.id, error=error_data) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + await ctx.read_stream_writer.send(session_message) + async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" headers = self._prepare_headers() message = ctx.session_message.message is_initialization = self._is_initialization_request(message) - async with ctx.client.stream( - "POST", - self.url, - json=message.model_dump(by_alias=True, mode="json", exclude_none=True), - headers=headers, - ) as response: - if response.status_code == 202: - logger.debug("Received 202 Accepted") - return + try: + async with ctx.client.stream( + "POST", + self.url, + json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + headers=headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + return - if response.status_code == 404: # pragma: no branch - if isinstance(message.root, JSONRPCRequest): - await self._send_session_terminated_error( # pragma: no cover - ctx.read_stream_writer, # pragma: no cover - message.root.id, # pragma: no cover - ) # pragma: no cover - return # pragma: no cover + if response.status_code == 404: # pragma: no branch + if isinstance(message.root, JSONRPCRequest): + await self._send_session_terminated_error( # pragma: no cover + ctx.read_stream_writer, # pragma: no cover + message.root.id, # pragma: no cover + ) # pragma: no cover + return # pragma: no cover - response.raise_for_status() + response.raise_for_status() + if is_initialization: + self._maybe_extract_session_id_from_response(response) + + # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: + # The server MUST NOT send a response to notifications. + if isinstance(message.root, JSONRPCRequest): + content_type = response.headers.get(CONTENT_TYPE, "").lower() + if content_type.startswith(JSON): + await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) + elif content_type.startswith(SSE): + await self._handle_sse_response(response, ctx, is_initialization) + else: + await self._handle_unexpected_content_type( # pragma: no cover + content_type, # pragma: no cover + ctx.read_stream_writer, # pragma: no cover + ) # pragma: no cover + except Exception as exc: if is_initialization: - self._maybe_extract_session_id_from_response(response) - - # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: - # The server MUST NOT send a response to notifications. - if isinstance(message.root, JSONRPCRequest): - content_type = response.headers.get(CONTENT_TYPE, "").lower() - if content_type.startswith(JSON): - await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) - elif content_type.startswith(SSE): - await self._handle_sse_response(response, ctx, is_initialization) - else: - await self._handle_unexpected_content_type( # pragma: no cover - content_type, # pragma: no cover - ctx.read_stream_writer, # pragma: no cover - ) # pragma: no cover + raise exc + else: + await self._send_error_response(ctx, exc) async def _handle_json_response( self, @@ -406,7 +420,7 @@ async def _handle_sse_response( try: event_source = EventSource(response) - async for sse in event_source.aiter_sse(): # pragma: no branch + async for sse in event_source.aiter_sse(): # Track last event ID for potential reconnection if sse.id: last_event_id = sse.id @@ -426,13 +440,17 @@ async def _handle_sse_response( if is_complete: await response.aclose() return # Normal completion, no reconnect needed - except Exception as e: # pragma: no cover - logger.debug(f"SSE stream ended: {e}") - # Stream ended without response - reconnect if we received an event with ID - if last_event_id is not None: # pragma: no branch - logger.info("SSE stream disconnected, reconnecting...") - await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) + # Stream ended without response - try to reconnect if we have an event ID + if last_event_id is not None: + logger.info("SSE stream disconnected, reconnecting...") + await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) + else: + # No event ID received, can't reconnect - report error + raise Exception("SSE stream ended without completing") + except Exception as exc: + logger.exception("Error handling SSE response") + await self._send_error_response(ctx, exc) async def _handle_reconnection( self, @@ -441,11 +459,14 @@ async def _handle_reconnection( retry_interval_ms: int | None = None, attempt: int = 0, ) -> None: - """Reconnect with Last-Event-ID to resume stream after server disconnect.""" + """Reconnect with Last-Event-ID to resume stream after server disconnect. + + Raises: + Exception: If max reconnection attempts exceeded or reconnection fails. + """ # Bail if max retries exceeded - if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover - logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") - return + if attempt >= MAX_RECONNECTION_ATTEMPTS: + raise Exception(f"SSE stream reconnection failed after {MAX_RECONNECTION_ATTEMPTS} attempts") # Always wait - use server value or default delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS @@ -492,7 +513,7 @@ async def _handle_reconnection( # Stream ended again without response - reconnect again (reset attempt counter) logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) - except Exception as e: # pragma: no cover + except Exception as e: logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index e95c309fb..146bebfe5 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -4,6 +4,7 @@ Contains tests for both server and client sides of the StreamableHTTP transport. """ +import contextlib import json import multiprocessing import socket @@ -2393,3 +2394,269 @@ async def test_streamablehttp_client_deprecation_warning(basic_server: None, bas await session.initialize() tools = await session.list_tools() assert len(tools.tools) > 0 + + +@pytest.mark.anyio +async def test_sse_stream_ends_without_completing_no_event_id() -> None: + """Test that SSE stream ending without completing and no event ID sends error response.""" + from unittest.mock import MagicMock, patch + + from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport + from mcp.shared.message import SessionMessage + from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest + + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + + # Create a mock response that returns an empty SSE stream (no events) + mock_response = MagicMock() + + async def mock_aclose() -> None: + pass # pragma: no cover + + mock_response.aclose = mock_aclose + + # Create a mock EventSource that yields no events + async def empty_iter(): + return + yield # Make it an async generator that yields nothing + + mock_event_source = MagicMock() + mock_event_source.aiter_sse = empty_iter + + # Create streams for testing + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + # Create a request context + mock_client = MagicMock() + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test")) + session_message = SessionMessage(message=mock_message) + + ctx = RequestContext( + client=mock_client, + session_id="test-session", + session_message=session_message, + metadata=None, + read_stream_writer=write_stream, + ) + + try: + with patch("mcp.client.streamable_http.EventSource", return_value=mock_event_source): + await transport._handle_sse_response(mock_response, ctx, is_initialization=False) + + # Should have received an error response + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert isinstance(received.message.root, JSONRPCError) + assert "SSE stream ended without completing" in received.message.root.error.message + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_handle_post_request_non_init_error_sends_error_response() -> None: + """Test that non-initialization request errors send error response instead of raising.""" + from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport + from mcp.shared.message import SessionMessage + from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest + + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + + # Create streams for testing + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + # Create a non-initialization request + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list")) + session_message = SessionMessage(message=mock_message) + + # Create a mock client that raises an exception + mock_client = MagicMock() + + # Create an async context manager that raises + class FailingStream: + async def __aenter__(self) -> None: + raise httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=500)) + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + pass # pragma: no cover + + mock_client.stream = MagicMock(return_value=FailingStream()) + + ctx = RequestContext( + client=mock_client, + session_id="test-session", + session_message=session_message, + metadata=None, + read_stream_writer=write_stream, + ) + + try: + # This should NOT raise, but send an error response + await transport._handle_post_request(ctx) + + # Should have received an error response + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert isinstance(received.message.root, JSONRPCError) + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_handle_post_request_init_error_raises() -> None: + """Test that initialization request errors are raised, not sent as error response.""" + from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport + from mcp.shared.message import SessionMessage + from mcp.types import JSONRPCMessage, JSONRPCRequest + + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + + # Create streams for testing + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + # Create an initialization request + mock_message = JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id="init-1", + method="initialize", + params={ + "clientInfo": {"name": "test", "version": "1.0"}, + "protocolVersion": "2025-03-26", + "capabilities": {}, + }, + ) + ) + session_message = SessionMessage(message=mock_message) + + # Create a mock client that raises an exception + mock_client = MagicMock() + + class FailingStream: + async def __aenter__(self) -> None: + raise httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=500)) + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + pass # pragma: no cover + + mock_client.stream = MagicMock(return_value=FailingStream()) + + ctx = RequestContext( + client=mock_client, + session_id=None, + session_message=session_message, + metadata=None, + read_stream_writer=write_stream, + ) + + try: + # This SHOULD raise for initialization requests + with pytest.raises(httpx.HTTPStatusError): + await transport._handle_post_request(ctx) + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_handle_reconnection_max_attempts_exceeded() -> None: + """Test that _handle_reconnection raises when max attempts exceeded.""" + from mcp.client.streamable_http import ( + MAX_RECONNECTION_ATTEMPTS, + RequestContext, + StreamableHTTPTransport, + ) + from mcp.shared.message import SessionMessage + from mcp.types import JSONRPCMessage, JSONRPCRequest + + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + + # Create streams for testing + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + # Create a request context + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test")) + session_message = SessionMessage(message=mock_message) + + ctx = RequestContext( + client=MagicMock(), + session_id="test-session", + session_message=session_message, + metadata=None, + read_stream_writer=write_stream, + ) + + try: + # Call with attempt >= MAX_RECONNECTION_ATTEMPTS should raise + with pytest.raises(Exception, match="SSE stream reconnection failed"): + await transport._handle_reconnection( + ctx, + last_event_id="test-event-id", + retry_interval_ms=1, # Use 1ms to speed up test + attempt=MAX_RECONNECTION_ATTEMPTS, + ) + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_handle_reconnection_failure_retries() -> None: + """Test that _handle_reconnection retries on failure and eventually raises.""" + from collections.abc import AsyncGenerator + from unittest.mock import MagicMock, patch + + from mcp.client.streamable_http import ( + MAX_RECONNECTION_ATTEMPTS, + RequestContext, + StreamableHTTPTransport, + ) + from mcp.shared.message import SessionMessage + from mcp.types import JSONRPCMessage, JSONRPCRequest + + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + + # Create streams for testing + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + # Create a mock client + mock_client = MagicMock() + + # Create a request context + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test")) + session_message = SessionMessage(message=mock_message) + + ctx = RequestContext( + client=mock_client, + session_id="test-session", + session_message=session_message, + metadata=None, + read_stream_writer=write_stream, + ) + + # Track how many times aconnect_sse is called + call_count = 0 + + @contextlib.asynccontextmanager + async def failing_aconnect_sse(*args: Any, **kwargs: Any) -> AsyncGenerator[None, None]: + nonlocal call_count + call_count += 1 + raise httpx.HTTPStatusError("Connection failed", request=MagicMock(), response=MagicMock(status_code=503)) + yield # Make it an async generator + + try: + with patch("mcp.client.streamable_http.aconnect_sse", failing_aconnect_sse): + with pytest.raises(Exception, match="SSE stream reconnection failed"): + await transport._handle_reconnection( + ctx, + last_event_id="test-event-id", + retry_interval_ms=1, # Use 1ms to speed up test + attempt=0, + ) + + # Should have tried MAX_RECONNECTION_ATTEMPTS times + assert call_count == MAX_RECONNECTION_ATTEMPTS + finally: + await write_stream.aclose() + await read_stream.aclose()