diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index cceefccce..0efa5c8d7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -16,6 +16,7 @@ from mcp.shared.response_router import ResponseRouter from mcp.types import ( CONNECTION_CLOSED, + INTERNAL_ERROR, INVALID_PARAMS, CancelledNotification, ClientNotification, @@ -237,6 +238,34 @@ async def __aexit__( self._task_group.cancel_scope.cancel() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + @staticmethod + def _process_response( + response_or_error: JSONRPCResponse | JSONRPCError | None, + result_type: type[ReceiveResultT], + ) -> ReceiveResultT: + """ + Process a JSON-RPC response, validating and returning the result. + + Raises McpError if the response is an error or if response_or_error is None. + The None check is a defensive guard against anyio race conditions - see #1717. + """ + if response_or_error is None: + # Defensive check for anyio fail_after race condition (#1717). + # If anyio's CancelScope incorrectly suppresses an exception, + # the response variable may never be assigned. See: + # https://github.com/agronholm/anyio/issues/589 + raise McpError( + ErrorData( + code=INTERNAL_ERROR, + message="Internal error: no response received", + ) + ) + + if isinstance(response_or_error, JSONRPCError): + raise McpError(response_or_error.error) + + return result_type.model_validate(response_or_error.result) + async def send_request( self, request: SendRequestT, @@ -287,6 +316,10 @@ async def send_request( elif self._session_read_timeout_seconds is not None: # pragma: no cover timeout = self._session_read_timeout_seconds.total_seconds() + # Initialize to None as a defensive guard against anyio race conditions + # where fail_after may incorrectly suppress exceptions (#1717) + response_or_error: JSONRPCResponse | JSONRPCError | None = None + try: with anyio.fail_after(timeout): response_or_error = await response_stream_reader.receive() @@ -301,12 +334,22 @@ async def send_request( ), ) ) - - if isinstance(response_or_error, JSONRPCError): - raise McpError(response_or_error.error) + except anyio.EndOfStream: + raise McpError( + ErrorData( + code=CONNECTION_CLOSED, + message="Connection closed: stream ended unexpectedly", + ) + ) + except anyio.ClosedResourceError: + raise McpError( + ErrorData( + code=CONNECTION_CLOSED, + message="Connection closed", + ) + ) else: - return result_type.model_validate(response_or_error.result) - + return self._process_response(response_or_error, result_type) finally: self._response_streams.pop(request_id, None) self._progress_callbacks.pop(request_id, None) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 313ec9926..92af723d2 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,5 +1,6 @@ from collections.abc import AsyncGenerator from typing import Any +from unittest.mock import AsyncMock, patch import anyio import pytest @@ -9,12 +10,18 @@ from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session +from mcp.shared.session import BaseSession from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, CancelledNotification, CancelledNotificationParams, ClientNotification, ClientRequest, EmptyResult, + ErrorData, + JSONRPCError, + JSONRPCResponse, TextContent, ) @@ -168,3 +175,111 @@ async def mock_server(): await ev_closed.wait() with anyio.fail_after(1): await ev_response.wait() + + +class TestProcessResponse: + """Tests for BaseSession._process_response static method.""" + + def test_process_response_with_valid_response(self): + """Test that a valid JSONRPCResponse is processed correctly.""" + response = JSONRPCResponse( + jsonrpc="2.0", + id=1, + result={}, + ) + + result = BaseSession._process_response(response, EmptyResult) + + assert isinstance(result, EmptyResult) + + def test_process_response_with_error(self): + """Test that a JSONRPCError raises McpError.""" + error = JSONRPCError( + jsonrpc="2.0", + id=1, + error=ErrorData(code=-32600, message="Invalid request"), + ) + + with pytest.raises(McpError) as exc_info: + BaseSession._process_response(error, EmptyResult) + + assert exc_info.value.error.code == -32600 + assert exc_info.value.error.message == "Invalid request" + + def test_process_response_with_none(self): + """ + Test defensive check for anyio fail_after race condition (#1717). + + If anyio's CancelScope incorrectly suppresses an exception during + receive(), the response variable may never be assigned. This test + verifies we handle this gracefully instead of raising UnboundLocalError. + + See: https://github.com/agronholm/anyio/issues/589 + """ + with pytest.raises(McpError) as exc_info: + BaseSession._process_response(None, EmptyResult) + + assert exc_info.value.error.code == INTERNAL_ERROR + assert "no response received" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_send_request_handles_end_of_stream(): + """Test that EndOfStream from response stream raises McpError with CONNECTION_CLOSED.""" + + async with create_client_server_memory_streams() as (client_streams, _): + client_read, client_write = client_streams + + async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session: + # Mock create_memory_object_stream to return a stream that raises EndOfStream + mock_reader = AsyncMock() + mock_reader.receive = AsyncMock(side_effect=anyio.EndOfStream) + mock_reader.aclose = AsyncMock() + + mock_sender = AsyncMock() + mock_sender.aclose = AsyncMock() + + # The subscripted form returns a callable that returns the tuple + with patch("mcp.shared.session.anyio.create_memory_object_stream") as mock_create: + # pyright: ignore[reportUnknownLambdaType] + mock_create.__getitem__ = lambda _s, _k: lambda _z: (mock_sender, mock_reader) # type: ignore + + with pytest.raises(McpError) as exc_info: + await client_session.send_request( + ClientRequest(types.PingRequest()), + EmptyResult, + ) + + assert exc_info.value.error.code == CONNECTION_CLOSED + assert "stream ended unexpectedly" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_send_request_handles_closed_resource_error(): + """Test that ClosedResourceError from response stream raises McpError with CONNECTION_CLOSED.""" + + async with create_client_server_memory_streams() as (client_streams, _): + client_read, client_write = client_streams + + async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session: + # Mock create_memory_object_stream to return a stream that raises ClosedResourceError + mock_reader = AsyncMock() + mock_reader.receive = AsyncMock(side_effect=anyio.ClosedResourceError) + mock_reader.aclose = AsyncMock() + + mock_sender = AsyncMock() + mock_sender.aclose = AsyncMock() + + # The subscripted form returns a callable that returns the tuple + with patch("mcp.shared.session.anyio.create_memory_object_stream") as mock_create: + # pyright: ignore[reportUnknownLambdaType] + mock_create.__getitem__ = lambda _s, _k: lambda _z: (mock_sender, mock_reader) # type: ignore + + with pytest.raises(McpError) as exc_info: + await client_session.send_request( + ClientRequest(types.PingRequest()), + EmptyResult, + ) + + assert exc_info.value.error.code == CONNECTION_CLOSED + assert "Connection closed" in exc_info.value.error.message