From dd4939d879cc287340c23c174cb38f0f2220e4c7 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 10:28:37 +0530 Subject: [PATCH 01/44] add transport abstraction --- src/mcp/client/transport_session.py | 133 ++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 src/mcp/client/transport_session.py diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py new file mode 100644 index 000000000..f1cd84f17 --- /dev/null +++ b/src/mcp/client/transport_session.py @@ -0,0 +1,133 @@ +from abc import ABC +from abc import abstractmethod +from datetime import timedelta + +from typing import Any + +from pydantic import AnyUrl + +from mcp import types +from mcp.shared.session import ProgressFnT + + +class TransportSession(ABC): + """Abstract base class for communication transports.""" + + @abstractmethod + async def initialize(self) -> types.InitializeResult: + """Send an initialize request.""" + ... + + @abstractmethod + async def send_ping(self): + ... + + @abstractmethod + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + ... + + @abstractmethod + async def set_logging_level( + self, + level: types.LoggingLevel, + ) -> types.EmptyResult: + """Send a logging/setLevel request.""" + ... + + @abstractmethod + async def list_resources( + self, + cursor: str | None = None, + ) -> types.ListResourcesResult: + """Send a resources/list request.""" + ... + + @abstractmethod + async def list_resource_templates( + self, + cursor: str | None = None, + ) -> types.ListResourceTemplatesResult: + """Send a resources/templates/list request.""" + ... + + @abstractmethod + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + """Send a resources/read request.""" + ... + + @abstractmethod + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/subscribe request.""" + ... + + @abstractmethod + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + ... + + @abstractmethod + async def call_tool( + self, + name: str, + arguments: Any | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + ) -> types.CallToolResult: + """Send a tools/call request with optional progress callback support.""" + ... + + @abstractmethod + async def _validate_tool_result( + self, + name: str, + result: types.CallToolResult, + ) -> None: + """Validate the structured content of a tool result against its output + schema.""" + ... + + @abstractmethod + async def list_prompts( + self, + cursor: str | None = None, + ) -> types.ListPromptsResult: + """Send a prompts/list request.""" + ... + + @abstractmethod + async def get_prompt( + self, + name: str, + arguments: dict[str, str] | None = None, + ) -> types.GetPromptResult: + """Send a prompts/get request.""" + ... + + @abstractmethod + async def complete( + self, + ref: types.ResourceTemplateReference | types.PromptReference, + argument: dict[str, str], + context_arguments: dict[str, str] | None = None, + ) -> types.CompleteResult: + """Send a completion/complete request.""" + ... + + @abstractmethod + async def list_tools( + self, + cursor: str | None = None, + ) -> types.ListToolsResult: + """Send a tools/list request.""" + ... + + @abstractmethod + async def send_roots_list_changed(self) -> None: + """Send a roots/list_changed notification.""" + ... \ No newline at end of file From 11d12494d89edc6ca2a1cb5e14b560c6fa0d2143 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 06:10:10 +0000 Subject: [PATCH 02/44] fix ruff --- src/mcp/client/transport_session.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index f1cd84f17..7575afa55 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from datetime import timedelta - from typing import Any from pydantic import AnyUrl From c8f3a42bf6f38ef7a4c215492709b23ea846d285 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 06:12:32 +0000 Subject: [PATCH 03/44] fix ruff format --- src/mcp/client/transport_session.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 7575afa55..5ce6fd34e 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -17,8 +17,7 @@ async def initialize(self) -> types.InitializeResult: ... @abstractmethod - async def send_ping(self): - ... + async def send_ping(self): ... @abstractmethod async def send_progress_notification( @@ -27,8 +26,7 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, - ) -> None: - ... + ) -> None: ... @abstractmethod async def set_logging_level( @@ -128,4 +126,4 @@ async def list_tools( @abstractmethod async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - ... \ No newline at end of file + ... From 03cc6c525aa6d0d94c3f0284fc80d8ec937263b9 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:00:03 +0530 Subject: [PATCH 04/44] add transport session for server --- src/mcp/server/transport_session.py | 113 ++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 src/mcp/server/transport_session.py diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py new file mode 100644 index 000000000..becc5a855 --- /dev/null +++ b/src/mcp/server/transport_session.py @@ -0,0 +1,113 @@ +"""Abstract base class for transport sessions.""" + +import abc +from typing import Any + +from anyio.streams.memory import MemoryObjectReceiveStream +from pydantic import AnyUrl + +import mcp_grpc.types as types +from mcp_grpc.server.session import ServerRequestResponder + + +class TransportSession(abc.ABC): + """Abstract base class for transport sessions.""" + + @property + @abc.abstractmethod + def client_params(self) -> types.InitializeRequestParams | None: + """Client initialization parameters.""" + raise NotImplementedError + + @abc.abstractmethod + def check_client_capability(self, capability: types.ClientCapabilities) -> bool: + """Check if the client supports a specific capability.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_log_message( + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send a log message notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_resource_updated(self, uri: AnyUrl) -> None: + """Send a resource updated notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def create_message( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResult: + """Send a sampling/create_message request.""" + raise NotImplementedError + + @abc.abstractmethod + async def list_roots(self) -> types.ListRootsResult: + """Send a roots/list request.""" + raise NotImplementedError + + @abc.abstractmethod + async def elicit( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send an elicitation/create request.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + related_request_id: str | None = None, + ) -> None: + """Send a progress notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_resource_list_changed(self) -> None: + """Send a resource list changed notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_list_changed(self) -> None: + """Send a tool list changed notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_prompt_list_changed(self) -> None: + """Send a prompt list changed notification.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def incoming_messages( + self, + ) -> MemoryObjectReceiveStream[ServerRequestResponder]: + """Incoming messages stream.""" + raise NotImplementedError From 1327a9cca39f41ec4d08d7f8672d54e28869fbad Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:01:29 +0530 Subject: [PATCH 05/44] clientsession and server session to implement abstract classes --- src/mcp/client/session.py | 3 +++ src/mcp/server/session.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3835a2a57..339c64abd 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -14,6 +14,8 @@ from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from src.mcp.client.transport_session import TransportSession + DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") logger = logging.getLogger("client") @@ -100,6 +102,7 @@ async def _default_logging_callback( class ClientSession( + TransportSession, BaseSession[ types.ClientRequest, types.ClientNotification, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index a1bfadc9f..3dc888843 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -54,6 +54,8 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from src.mcp.server.transport_session import TransportSession + class InitializationState(Enum): NotInitialized = 1 @@ -69,6 +71,7 @@ class InitializationState(Enum): class ServerSession( + TransportSession, BaseSession[ types.ServerRequest, types.ServerNotification, From 0018679c02095d981fa1c84c779633c3b9aa31e4 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:04:00 +0530 Subject: [PATCH 06/44] add raise not implemented --- src/mcp/client/transport_session.py | 38 ++++++++++++++++------------- src/mcp/server/transport_session.py | 4 +-- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 5ce6fd34e..a85b39718 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,5 +1,7 @@ -from abc import ABC, abstractmethod +from abc import ABC +from abc import abstractmethod from datetime import timedelta + from typing import Any from pydantic import AnyUrl @@ -14,10 +16,11 @@ class TransportSession(ABC): @abstractmethod async def initialize(self) -> types.InitializeResult: """Send an initialize request.""" - ... + raise NotImplementedError @abstractmethod - async def send_ping(self): ... + async def send_ping(self): + raise NotImplementedError @abstractmethod async def send_progress_notification( @@ -26,7 +29,8 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, - ) -> None: ... + ) -> None: + raise NotImplementedError @abstractmethod async def set_logging_level( @@ -34,7 +38,7 @@ async def set_logging_level( level: types.LoggingLevel, ) -> types.EmptyResult: """Send a logging/setLevel request.""" - ... + raise NotImplementedError @abstractmethod async def list_resources( @@ -42,7 +46,7 @@ async def list_resources( cursor: str | None = None, ) -> types.ListResourcesResult: """Send a resources/list request.""" - ... + raise NotImplementedError @abstractmethod async def list_resource_templates( @@ -50,22 +54,22 @@ async def list_resource_templates( cursor: str | None = None, ) -> types.ListResourceTemplatesResult: """Send a resources/templates/list request.""" - ... + raise NotImplementedError @abstractmethod async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: """Send a resources/read request.""" - ... + raise NotImplementedError @abstractmethod async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/subscribe request.""" - ... + raise NotImplementedError @abstractmethod async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/unsubscribe request.""" - ... + raise NotImplementedError @abstractmethod async def call_tool( @@ -76,7 +80,7 @@ async def call_tool( progress_callback: ProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" - ... + raise NotImplementedError @abstractmethod async def _validate_tool_result( @@ -86,7 +90,7 @@ async def _validate_tool_result( ) -> None: """Validate the structured content of a tool result against its output schema.""" - ... + raise NotImplementedError @abstractmethod async def list_prompts( @@ -94,7 +98,7 @@ async def list_prompts( cursor: str | None = None, ) -> types.ListPromptsResult: """Send a prompts/list request.""" - ... + raise NotImplementedError @abstractmethod async def get_prompt( @@ -103,7 +107,7 @@ async def get_prompt( arguments: dict[str, str] | None = None, ) -> types.GetPromptResult: """Send a prompts/get request.""" - ... + raise NotImplementedError @abstractmethod async def complete( @@ -113,7 +117,7 @@ async def complete( context_arguments: dict[str, str] | None = None, ) -> types.CompleteResult: """Send a completion/complete request.""" - ... + raise NotImplementedError @abstractmethod async def list_tools( @@ -121,9 +125,9 @@ async def list_tools( cursor: str | None = None, ) -> types.ListToolsResult: """Send a tools/list request.""" - ... + raise NotImplementedError @abstractmethod async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - ... + raise NotImplementedError \ No newline at end of file diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index becc5a855..bd0d592e5 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -6,8 +6,8 @@ from anyio.streams.memory import MemoryObjectReceiveStream from pydantic import AnyUrl -import mcp_grpc.types as types -from mcp_grpc.server.session import ServerRequestResponder +import mcp.types as types +from mcp.server.session import ServerRequestResponder class TransportSession(abc.ABC): From af7ff5a0e3bd088a0b0aba0608a774188872c24d Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:25:45 +0530 Subject: [PATCH 07/44] fix abstract server transport session --- src/mcp/server/transport_session.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index bd0d592e5..efc6aad68 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -7,8 +7,6 @@ from pydantic import AnyUrl import mcp.types as types -from mcp.server.session import ServerRequestResponder - class TransportSession(abc.ABC): """Abstract base class for transport sessions.""" @@ -103,11 +101,3 @@ async def send_tool_list_changed(self) -> None: async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" raise NotImplementedError - - @property - @abc.abstractmethod - def incoming_messages( - self, - ) -> MemoryObjectReceiveStream[ServerRequestResponder]: - """Incoming messages stream.""" - raise NotImplementedError From 7f468d0210782afabb29c90d004df3981c215956 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:26:36 +0530 Subject: [PATCH 08/44] removed unused import --- src/mcp/server/transport_session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index efc6aad68..f23a8361d 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -3,7 +3,6 @@ import abc from typing import Any -from anyio.streams.memory import MemoryObjectReceiveStream from pydantic import AnyUrl import mcp.types as types From e895d90b5101cc64fc611074c4bc77899a462230 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:35:24 +0530 Subject: [PATCH 09/44] fix type hints --- src/mcp/server/elicitation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index bba988f49..47be94b13 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerSession +from mcp.server.transport_session import TransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -74,7 +74,7 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: async def elicit_with_validation( - session: ServerSession, + session: TransportSession, message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, From d01e477a3b7b77c94f2f08c78596f8deba9b48c4 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:42:52 +0530 Subject: [PATCH 10/44] revert type hints --- src/mcp/server/elicitation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 47be94b13..bba988f49 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.transport_session import TransportSession +from mcp.server.session import ServerSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -74,7 +74,7 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: async def elicit_with_validation( - session: TransportSession, + session: ServerSession, message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, From 7bdafa384f796e16505104b089b2d5c5a636ec7b Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:44:58 +0530 Subject: [PATCH 11/44] fix import --- src/mcp/server/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 3dc888843..00355ae9e 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -54,7 +54,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from src.mcp.server.transport_session import TransportSession +from mcp.server.transport_session import TransportSession class InitializationState(Enum): From e9f63dd45f65624c72d000bbb91dc3bcd6790adb Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:46:44 +0530 Subject: [PATCH 12/44] fix import --- src/mcp/client/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 339c64abd..c058de172 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -14,7 +14,7 @@ from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from src.mcp.client.transport_session import TransportSession +from mcp.client.transport_session import TransportSession DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") From 5b156a16728cdccf22fa3991316fbd5127d636e3 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 07:18:36 +0000 Subject: [PATCH 13/44] fix ruff format --- src/mcp/client/session.py | 2 +- src/mcp/client/transport_session.py | 2 +- src/mcp/server/session.py | 2 +- src/mcp/server/transport_session.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c058de172..c07ca8c50 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -109,7 +109,7 @@ class ClientSession( types.ClientResult, types.ServerRequest, types.ServerNotification, - ] + ], ): def __init__( self, diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index a85b39718..8dbe1a82d 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -130,4 +130,4 @@ async def list_tools( @abstractmethod async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 00355ae9e..e50e7d004 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -78,7 +78,7 @@ class ServerSession( types.ServerResult, types.ClientRequest, types.ClientNotification, - ] + ], ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index f23a8361d..d0288a0f3 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -7,6 +7,7 @@ import mcp.types as types + class TransportSession(abc.ABC): """Abstract base class for transport sessions.""" From f26d861db283bfb4903df0cbf48a88ff8446576b Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:58:12 +0530 Subject: [PATCH 14/44] request context as optional param --- src/mcp/server/fastmcp/server.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 865b8e7e7..9871063c3 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -326,10 +326,18 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: request_context = None return Context(request_context=request_context, fastmcp=self) - async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: + async def call_tool( + self, name: str, arguments: dict[str, Any], + request_context: RequestContext | None = None + ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" - context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) + if request_context: + context = Context(request_context=request_context, fastmcp=self) + else: + context = self.get_context() + return await self._tool_manager.call_tool(name, arguments, + context=context, + convert_result=True) async def list_resources(self) -> list[MCPResource]: """List all available resources.""" From 3097cb3a3360ca5993c873b8b001099120868e28 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 07:28:45 +0000 Subject: [PATCH 15/44] fix format --- src/mcp/server/fastmcp/server.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 9871063c3..7da7ca43d 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -327,17 +327,14 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: return Context(request_context=request_context, fastmcp=self) async def call_tool( - self, name: str, arguments: dict[str, Any], - request_context: RequestContext | None = None + self, name: str, arguments: dict[str, Any], request_context: RequestContext | None = None ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" if request_context: context = Context(request_context=request_context, fastmcp=self) else: context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, - context=context, - convert_result=True) + return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) async def list_resources(self) -> list[MCPResource]: """List all available resources.""" From 9e8dca3a075e393c70af798c0fb13fcbdd493c30 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 08:08:19 +0000 Subject: [PATCH 16/44] ruff check --fix --- src/mcp/client/session.py | 3 +-- src/mcp/client/transport_session.py | 4 +--- src/mcp/server/session.py | 3 +-- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c07ca8c50..02646924b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,13 +9,12 @@ from typing_extensions import deprecated import mcp.types as types +from mcp.client.transport_session import TransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.client.transport_session import TransportSession - DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") logger = logging.getLogger("client") diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 8dbe1a82d..9f9f3f8c4 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from datetime import timedelta - from typing import Any from pydantic import AnyUrl diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index e50e7d004..99fdb8f3f 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions +from mcp.server.transport_session import TransportSession from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -54,8 +55,6 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.server.transport_session import TransportSession - class InitializationState(Enum): NotInitialized = 1 From 5b7b458f963d608252efaf13b9d39fcd6bb1824e Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 08:16:17 +0000 Subject: [PATCH 17/44] fix pyright --- src/mcp/client/transport_session.py | 2 +- src/mcp/server/fastmcp/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 9f9f3f8c4..71e69ee3e 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -17,7 +17,7 @@ async def initialize(self) -> types.InitializeResult: raise NotImplementedError @abstractmethod - async def send_ping(self): + async def send_ping(self) -> types.EmptyResult: raise NotImplementedError @abstractmethod diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 7da7ca43d..05bf7f3b7 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -327,7 +327,7 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: return Context(request_context=request_context, fastmcp=self) async def call_tool( - self, name: str, arguments: dict[str, Any], request_context: RequestContext | None = None + self, name: str, arguments: dict[str, Any], request_context: RequestContext[ServerSession, LifespanResultT, Request] | None = None ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" if request_context: From 8ca511ef8b55302411b3d0ef356ebc08789f5661 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 08:17:19 +0000 Subject: [PATCH 18/44] ruff fix --- src/mcp/server/fastmcp/server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 05bf7f3b7..c9883ca56 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -327,7 +327,10 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: return Context(request_context=request_context, fastmcp=self) async def call_tool( - self, name: str, arguments: dict[str, Any], request_context: RequestContext[ServerSession, LifespanResultT, Request] | None = None + self, + name: str, + arguments: dict[str, Any], + request_context: RequestContext[ServerSession, LifespanResultT, Request] | None = None, ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" if request_context: From 53e02fe7403c169c4453d4dc1f74f33a660f187f Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 16:18:15 +0530 Subject: [PATCH 19/44] removed fat abstract class --- src/mcp/server/transport_session.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index d0288a0f3..fcb4a21e8 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -11,17 +11,6 @@ class TransportSession(abc.ABC): """Abstract base class for transport sessions.""" - @property - @abc.abstractmethod - def client_params(self) -> types.InitializeRequestParams | None: - """Client initialization parameters.""" - raise NotImplementedError - - @abc.abstractmethod - def check_client_capability(self, capability: types.ClientCapabilities) -> bool: - """Check if the client supports a specific capability.""" - raise NotImplementedError - @abc.abstractmethod async def send_log_message( self, @@ -38,23 +27,6 @@ async def send_resource_updated(self, uri: AnyUrl) -> None: """Send a resource updated notification.""" raise NotImplementedError - @abc.abstractmethod - async def create_message( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - related_request_id: types.RequestId | None = None, - ) -> types.CreateMessageResult: - """Send a sampling/create_message request.""" - raise NotImplementedError - @abc.abstractmethod async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" From cf0f15243b785cc6c14e94e0c0c1af4634e1e7b8 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 16:27:28 +0530 Subject: [PATCH 20/44] removed client a thin interface --- src/mcp/client/transport_session.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 71e69ee3e..41150a039 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -80,16 +80,6 @@ async def call_tool( """Send a tools/call request with optional progress callback support.""" raise NotImplementedError - @abstractmethod - async def _validate_tool_result( - self, - name: str, - result: types.CallToolResult, - ) -> None: - """Validate the structured content of a tool result against its output - schema.""" - raise NotImplementedError - @abstractmethod async def list_prompts( self, From ccbdde86fa9042514d42707042d02afac79ba9fb Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 16:49:00 +0530 Subject: [PATCH 21/44] add description --- src/mcp/client/transport_session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 41150a039..615774989 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -18,6 +18,7 @@ async def initialize(self) -> types.InitializeResult: @abstractmethod async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" raise NotImplementedError @abstractmethod @@ -28,6 +29,7 @@ async def send_progress_notification( total: float | None = None, message: str | None = None, ) -> None: + """Send a progress notification.""" raise NotImplementedError @abstractmethod From 380710e49e07e40895ffb1f0cbae9af54a7d2bf5 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 17:10:29 +0530 Subject: [PATCH 22/44] revert context change in this pr --- src/mcp/server/fastmcp/server.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c9883ca56..865b8e7e7 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -326,17 +326,9 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: request_context = None return Context(request_context=request_context, fastmcp=self) - async def call_tool( - self, - name: str, - arguments: dict[str, Any], - request_context: RequestContext[ServerSession, LifespanResultT, Request] | None = None, - ) -> Sequence[ContentBlock] | dict[str, Any]: + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" - if request_context: - context = Context(request_context=request_context, fastmcp=self) - else: - context = self.get_context() + context = self.get_context() return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) async def list_resources(self) -> list[MCPResource]: From 3f977b380cdcfcc048f23981c0985142ff6741d2 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 11:51:36 +0530 Subject: [PATCH 23/44] rename classes --- src/mcp/__init__.py | 4 ++++ src/mcp/client/session.py | 4 ++-- src/mcp/client/transport_session.py | 2 +- src/mcp/server/session.py | 4 ++-- src/mcp/server/transport_session.py | 22 +++++++++++----------- 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index e93b95c90..ae74dfa32 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,4 +1,6 @@ from .client.session import ClientSession +from .client.transport_session import ClientTransportSession +from .server.transport_session import ServerTransportSession from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client from .server.session import ServerSession @@ -113,4 +115,6 @@ "stdio_server", "CompleteRequest", "JSONRPCResponse", + "ClientTransportSession", + "ServerTransportSession", ] diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 02646924b..4243fa999 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,7 +9,7 @@ from typing_extensions import deprecated import mcp.types as types -from mcp.client.transport_session import TransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -101,7 +101,7 @@ async def _default_logging_callback( class ClientSession( - TransportSession, + ClientTransportSession, BaseSession[ types.ClientRequest, types.ClientNotification, diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 615774989..6f6f52322 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -8,7 +8,7 @@ from mcp.shared.session import ProgressFnT -class TransportSession(ABC): +class ClientTransportSession(ABC): """Abstract base class for communication transports.""" @abstractmethod diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 99fdb8f3f..96f879034 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions -from mcp.server.transport_session import TransportSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -70,7 +70,7 @@ class InitializationState(Enum): class ServerSession( - TransportSession, + ServerTransportSession, BaseSession[ types.ServerRequest, types.ServerNotification, diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index fcb4a21e8..bf3f6a1d1 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -1,6 +1,6 @@ """Abstract base class for transport sessions.""" -import abc +from abc import ABC, abstractmethod from typing import Any from pydantic import AnyUrl @@ -8,10 +8,10 @@ import mcp.types as types -class TransportSession(abc.ABC): +class ServerTransportSession(ABC): """Abstract base class for transport sessions.""" - @abc.abstractmethod + @abstractmethod async def send_log_message( self, level: types.LoggingLevel, @@ -22,17 +22,17 @@ async def send_log_message( """Send a log message notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_resource_updated(self, uri: AnyUrl) -> None: """Send a resource updated notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def elicit( self, message: str, @@ -42,12 +42,12 @@ async def elicit( """Send an elicitation/create request.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_progress_notification( self, progress_token: str | int, @@ -59,17 +59,17 @@ async def send_progress_notification( """Send a progress notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" raise NotImplementedError From ec7b6d6a2592c243dff686d3bc27685f186759c8 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 06:23:59 +0000 Subject: [PATCH 24/44] ruff fix --- src/mcp/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index ae74dfa32..93ef8acdf 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,10 +1,10 @@ from .client.session import ClientSession -from .client.transport_session import ClientTransportSession -from .server.transport_session import ServerTransportSession from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client +from .client.transport_session import ClientTransportSession from .server.session import ServerSession from .server.stdio import stdio_server +from .server.transport_session import ServerTransportSession from .shared.exceptions import McpError from .types import ( CallToolRequest, From 0359aa899a2a89388e01816c4c7dd48c57ca196d Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Wed, 12 Nov 2025 14:18:15 +0530 Subject: [PATCH 25/44] merge main --- .../mcp_simple_auth_client/main.py | 4 ++-- .../simple-chatbot/mcp_simple_chatbot/main.py | 3 ++- .../snippets/clients/display_utilities.py | 5 ++-- examples/snippets/clients/stdio_client.py | 3 ++- src/mcp/client/session.py | 16 ++++++------- src/mcp/client/session_group.py | 24 ++++++++++--------- src/mcp/client/transport_session.py | 22 +++++++++++++++-- src/mcp/shared/context.py | 3 ++- tests/client/test_list_roots_callback.py | 4 ++-- tests/client/test_sampling_callback.py | 4 ++-- tests/client/test_session.py | 6 ++--- tests/server/fastmcp/test_elicitation.py | 18 +++++++------- tests/server/fastmcp/test_integration.py | 6 ++--- tests/shared/test_streamable_http.py | 4 ++-- 14 files changed, 74 insertions(+), 48 deletions(-) diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 5987a878e..6c7201e04 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -17,7 +17,7 @@ from urllib.parse import parse_qs, urlparse from mcp.client.auth import OAuthClientProvider, TokenStorage -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -153,7 +153,7 @@ class SimpleAuthClient: def __init__(self, server_url: str, transport_type: str = "streamable-http"): self.server_url = server_url self.transport_type = transport_type - self.session: ClientSession | None = None + self.session: ClientTransportSession | None = None async def connect(self): """Connect to the MCP server.""" diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 78a81a4d9..3a9d201b1 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -10,6 +10,7 @@ from dotenv import load_dotenv from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.transport_session import ClientTransportSession # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -67,7 +68,7 @@ def __init__(self, name: str, config: dict[str, Any]) -> None: self.name: str = name self.config: dict[str, Any] = config self.stdio_context: Any | None = None - self.session: ClientSession | None = None + self.session: ClientTransportSession | None = None self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.exit_stack: AsyncExitStack = AsyncExitStack() diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index 5f1d50510..5e1b203ee 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -8,6 +8,7 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection @@ -18,7 +19,7 @@ ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -30,7 +31,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index ac978035d..62fb0f4c4 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -9,6 +9,7 @@ from pydantic import AnyUrl from mcp import ClientSession, StdioServerParameters, types +from mcp.client.session import ClientTransportSession from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -22,7 +23,7 @@ # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 4243fa999..c3559b13a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -23,7 +23,7 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: ... # pragma: no branch @@ -31,15 +31,15 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any] - ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch + self, context: RequestContext["ClientTransportSession", Any] + ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch class LoggingFnT(Protocol): @@ -63,7 +63,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.ErrorData( @@ -73,7 +73,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( # pragma: no cover @@ -83,7 +83,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -510,7 +510,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: - ctx = RequestContext[ClientSession, Any]( + ctx = RequestContext[ClientTransportSession, Any]( request_id=responder.request_id, meta=responder.request_meta, session=self, diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 2c55bb775..9e95ed909 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -96,10 +96,10 @@ class _ComponentNames(BaseModel): _tools: dict[str, types.Tool] # Client-server connection management. - _sessions: dict[mcp.ClientSession, _ComponentNames] - _tool_to_session: dict[str, mcp.ClientSession] + _sessions: dict[mcp.ClientTransportSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientTransportSession] _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] + _session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -153,7 +153,7 @@ async def __aexit__( tg.start_soon(exit_stack.aclose) @property - def sessions(self) -> list[mcp.ClientSession]: + def sessions(self) -> list[mcp.ClientTransportSession]: """Returns the list of sessions being managed.""" return list(self._sessions.keys()) # pragma: no cover @@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) - async def disconnect_from_server(self, session: mcp.ClientSession) -> None: + async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None: """Disconnects from a single MCP server.""" session_known_for_components = session in self._sessions @@ -216,8 +216,8 @@ async def disconnect_from_server(self, session: mcp.ClientSession) -> None: await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( - self, server_info: types.Implementation, session: mcp.ClientSession - ) -> mcp.ClientSession: + self, server_info: types.Implementation, session: mcp.ClientTransportSession + ) -> mcp.ClientTransportSession: """Connects to a single MCP server.""" await self._aggregate_components(server_info, session) return session @@ -225,14 +225,14 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, - ) -> mcp.ClientSession: + ) -> mcp.ClientTransportSession: """Connects to a single MCP server.""" server_info, session = await self._establish_session(server_params) return await self.connect_with_session(server_info, session) async def _establish_session( self, server_params: ServerParameters - ) -> tuple[types.Implementation, mcp.ClientSession]: + ) -> tuple[types.Implementation, mcp.ClientTransportSession]: """Establish a client session to an MCP server.""" session_stack = contextlib.AsyncExitStack() @@ -276,7 +276,9 @@ async def _establish_session( await session_stack.aclose() raise - async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: + async def _aggregate_components( + self, server_info: types.Implementation, session: mcp.ClientTransportSession + ) -> None: """Aggregates prompts, resources, and tools from a given session.""" # Create a reverse index so we can find all prompts, resources, and @@ -289,7 +291,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, mcp.ClientSession] = {} + tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {} # Query the server for its prompts and aggregate to list. try: diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 6f6f52322..c51b059f6 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from datetime import timedelta -from typing import Any +from typing import Any, overload from pydantic import AnyUrl +from typing_extensions import deprecated from mcp import types from mcp.shared.session import ProgressFnT @@ -109,12 +110,29 @@ async def complete( """Send a completion/complete request.""" raise NotImplementedError + @overload + @deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead") + async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ... + + @overload + async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ... + + @overload + async def list_tools(self) -> types.ListToolsResult: ... + @abstractmethod async def list_tools( self, cursor: str | None = None, + *, + params: types.PaginatedRequestParams | None = None, ) -> types.ListToolsResult: - """Send a tools/list request.""" + """Send a tools/list request. + + Args: + cursor: Simple cursor string for pagination (deprecated, use params instead) + params: Full pagination parameters including cursor and any future fields + """ raise NotImplementedError @abstractmethod diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f3006e7d5..0fb12c649 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,10 +3,11 @@ from typing_extensions import TypeVar +from mcp.client.transport_session import ClientTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) +SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 0da0fff07..dc53eddbc 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,7 +1,7 @@ import pytest from pydantic import FileUrl -from mcp.client.session import ClientSession +from mcp.client.session import ClientTransportSession from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession from mcp.shared.context import RequestContext @@ -31,7 +31,7 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientTransportSession, None], ) -> ListRootsResult: return callback_return diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index a3f6affda..8cd2c7116 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,6 @@ import pytest -from mcp.client.session import ClientSession +from mcp.client.session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, @@ -27,7 +27,7 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientTransportSession, None], params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 8d0ef68a9..c327a806f 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -4,7 +4,7 @@ import pytest import mcp.types as types -from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession, ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @@ -427,7 +427,7 @@ async def test_client_capabilities_with_custom_callbacks(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( @@ -437,7 +437,7 @@ async def custom_sampling_callback( # pragma: no cover ) async def custom_list_roots_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 2c74d0e88..dd1ae72dc 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,7 +7,7 @@ import pytest from pydantic import BaseModel, Field -from mcp.client.session import ClientSession, ElicitationFnT +from mcp.client.session import ClientTransportSession, ElicitationFnT from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.context import RequestContext @@ -72,7 +72,7 @@ async def test_stdio_elicitation(): # Create a custom handler for elicitation requests async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) @@ -90,7 +90,7 @@ async def test_stdio_elicitation_decline(): mcp = FastMCP(name="StdioElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -129,7 +129,7 @@ class InvalidNestedSchema(BaseModel): # Dummy callback (won't be called due to validation failure) async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -189,7 +189,7 @@ async def optional_tool(ctx: Context[ServerSession, None]) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -208,7 +208,7 @@ async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pr return f"Validation failed: {str(e)}" async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -245,7 +245,9 @@ async def defaults_tool(ctx: Context[ServerSession, None]) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_schema_verify( + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams + ): # Verify the schema includes defaults schema = params.requestedSchema props = schema["properties"] @@ -266,7 +268,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession, None], p ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_override(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index b1cefca29..778b0bfd7 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -32,7 +32,7 @@ structured_output, tool_progress, ) -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.context import RequestContext @@ -212,7 +212,7 @@ def unpack_streams( # Callback functions for testing async def sampling_callback( - context: RequestContext[ClientSession, None], params: CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: CreateMessageRequestParams ) -> CreateMessageResult: """Sampling callback for tests.""" return CreateMessageResult( @@ -225,7 +225,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): +async def elicitation_callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 43b321d96..736e261cd 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -21,7 +21,7 @@ from starlette.routing import Mount import mcp.types as types -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( @@ -1233,7 +1233,7 @@ async def test_streamablehttp_server_sampling(basic_server: None, basic_server_u # Define sampling callback that returns a mock response async def sampling_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult: nonlocal sampling_callback_invoked, captured_message_params From b733fcfc8ea0df816ff5f82b3b9f4f24301f6363 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:16:07 +0000 Subject: [PATCH 26/44] fix type hints for serversession --- examples/snippets/servers/elicitation.py | 4 ++-- examples/snippets/servers/lifespan_example.py | 4 ++-- examples/snippets/servers/notifications.py | 4 ++-- examples/snippets/servers/tool_progress.py | 4 ++-- src/mcp/server/elicitation.py | 4 ++-- src/mcp/server/fastmcp/server.py | 4 ++-- src/mcp/server/lowlevel/server.py | 6 +++--- src/mcp/server/session.py | 2 +- src/mcp/shared/context.py | 6 +++++- tests/client/test_sampling_callback.py | 5 ++++- tests/shared/test_streamable_http.py | 6 ++++-- 11 files changed, 29 insertions(+), 20 deletions(-) diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 2c8a3b35a..049b42516 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -17,7 +17,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check.""" # Check if date is available if date == "2024-12-25": diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 62278b6aa..32b699730 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession # Mock database class for example @@ -51,7 +51,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 833bc8905..36d9712eb 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,11 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index 2ac458f6a..dddd8c9eb 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,11 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index bba988f49..65399e27c 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -74,7 +74,7 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: async def elicit_with_validation( - session: ServerSession, + session: ServerTransportSession, message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 865b8e7e7..03e223329 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -54,7 +54,7 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSession, ServerSessionT +from mcp.server.session import ServerSessionT, ServerTransportSession from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore @@ -315,7 +315,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: + def get_context(self) -> Context[ServerTransportSession, LifespanResultT, Request]: """ Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 49d289fb7..329cd1dd2 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,7 @@ async def main(): from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp.server.session import ServerSession, ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage @@ -102,7 +102,7 @@ async def main(): CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -231,7 +231,7 @@ def get_capabilities( @property def request_context( self, - ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: + ) -> RequestContext[ServerTransportSession, LifespanResultT, RequestT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 96f879034..9456ebf9f 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -62,7 +62,7 @@ class InitializationState(Enum): Initialized = 3 -ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") +ServerSessionT = TypeVar("ServerSessionT", bound="ServerTransportSession") ServerRequestResponder = ( RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 0fb12c649..094fbddf4 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -4,10 +4,14 @@ from typing_extensions import TypeVar from mcp.client.transport_session import ClientTransportSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession) +SessionT = TypeVar("SessionT", + bound=BaseSession[Any, Any, Any, Any, Any] | + ClientTransportSession | + ServerTransportSession) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 8cd2c7116..3fe50a132 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,8 @@ import pytest from mcp.client.session import ClientTransportSession +from mcp.server.session import ServerSession +from typing import cast from mcp.shared.context import RequestContext from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, @@ -34,7 +36,8 @@ async def sampling_callback( @server.tool("test_sampling") async def test_sampling_tool(message: str): - value = await server.get_context().session.create_message( + session = cast(ServerSession, server.get_context().session) + value = await session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 736e261cd..95bbd633e 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -8,7 +8,7 @@ import multiprocessing import socket from collections.abc import Generator -from typing import Any +from typing import Any, cast import anyio import httpx @@ -22,6 +22,7 @@ import mcp.types as types from mcp.client.session import ClientSession, ClientTransportSession +from mcp.server.session import ServerSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( @@ -198,7 +199,8 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] elif name == "test_sampling_tool": # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( + session = cast(ServerSession, ctx.session) + sampling_result = await session.create_message( messages=[ types.SamplingMessage( role="user", From cdc39f4edbe0af04ab84f338e8b807e87ac514dc Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:16:44 +0000 Subject: [PATCH 27/44] fix ruff --- src/mcp/server/lowlevel/server.py | 4 +++- src/mcp/shared/context.py | 7 +++---- tests/client/test_sampling_callback.py | 3 ++- tests/shared/test_streamable_http.py | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 329cd1dd2..b60e04974 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -102,7 +102,9 @@ async def main(): CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar( + "request_ctx" +) class NotificationOptions: diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 094fbddf4..63fafa241 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -8,10 +8,9 @@ from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -SessionT = TypeVar("SessionT", - bound=BaseSession[Any, Any, Any, Any, Any] | - ClientTransportSession | - ServerTransportSession) +SessionT = TypeVar( + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession +) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 3fe50a132..feed499af 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,8 +1,9 @@ +from typing import cast + import pytest from mcp.client.session import ClientTransportSession from mcp.server.session import ServerSession -from typing import cast from mcp.shared.context import RequestContext from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 95bbd633e..08968d6f7 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -22,9 +22,9 @@ import mcp.types as types from mcp.client.session import ClientSession, ClientTransportSession -from mcp.server.session import ServerSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server +from mcp.server.session import ServerSession from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, From 65a3b0f8aee6257cb70ddbef97873846890100a9 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:17:17 +0000 Subject: [PATCH 28/44] uv run scripts/update_readme_snippets.py --- README.md | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 5dbc4bd9d..6450adbe9 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession # Mock database class for example @@ -254,7 +254,7 @@ mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() @@ -326,13 +326,13 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -674,13 +674,13 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -798,7 +798,7 @@ Request additional information from users. This example shows an Elicitation dur from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -814,7 +814,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check.""" # Check if date is available if date == "2024-12-25": @@ -888,13 +888,13 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") @@ -2038,6 +2038,7 @@ import os from pydantic import AnyUrl from mcp import ClientSession, StdioServerParameters, types +from mcp.client.session import ClientTransportSession from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2051,7 +2052,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( @@ -2169,6 +2170,7 @@ import os from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection @@ -2179,7 +2181,7 @@ server_params = StdioServerParameters( ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -2191,7 +2193,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() From f34e8fe12c101de34ce67e316b9c9e490487c555 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:23:41 +0000 Subject: [PATCH 29/44] some fixes --- src/mcp/client/session_group.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9e95ed909..233c532ce 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -96,10 +96,10 @@ class _ComponentNames(BaseModel): _tools: dict[str, types.Tool] # Client-server connection management. - _sessions: dict[mcp.ClientTransportSession, _ComponentNames] - _tool_to_session: dict[str, mcp.ClientTransportSession] + _sessions: dict["mcp.ClientTransportSession", _ComponentNames] + _tool_to_session: dict[str, "mcp.ClientTransportSession"] _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack] + _session_exit_stacks: dict["mcp.ClientTransportSession", contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -153,7 +153,7 @@ async def __aexit__( tg.start_soon(exit_stack.aclose) @property - def sessions(self) -> list[mcp.ClientTransportSession]: + def sessions(self) -> list["mcp.ClientTransportSession"]: """Returns the list of sessions being managed.""" return list(self._sessions.keys()) # pragma: no cover @@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) - async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None: + async def disconnect_from_server(self, session: "mcp.ClientTransportSession") -> None: """Disconnects from a single MCP server.""" session_known_for_components = session in self._sessions @@ -216,8 +216,8 @@ async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> N await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( - self, server_info: types.Implementation, session: mcp.ClientTransportSession - ) -> mcp.ClientTransportSession: + self, server_info: types.Implementation, session: "mcp.ClientTransportSession" + ) -> "mcp.ClientTransportSession": """Connects to a single MCP server.""" await self._aggregate_components(server_info, session) return session @@ -225,14 +225,14 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, - ) -> mcp.ClientTransportSession: + ) -> "mcp.ClientTransportSession": """Connects to a single MCP server.""" server_info, session = await self._establish_session(server_params) return await self.connect_with_session(server_info, session) async def _establish_session( self, server_params: ServerParameters - ) -> tuple[types.Implementation, mcp.ClientTransportSession]: + ) -> tuple[types.Implementation, "mcp.ClientTransportSession"]: """Establish a client session to an MCP server.""" session_stack = contextlib.AsyncExitStack() @@ -277,7 +277,7 @@ async def _establish_session( raise async def _aggregate_components( - self, server_info: types.Implementation, session: mcp.ClientTransportSession + self, server_info: types.Implementation, session: "mcp.ClientTransportSession" ) -> None: """Aggregates prompts, resources, and tools from a given session.""" @@ -291,7 +291,7 @@ async def _aggregate_components( prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {} + tool_to_session_temp: dict[str, "mcp.ClientTransportSession"] = {} # Query the server for its prompts and aggregate to list. try: From 1bfc08696de26b14b83055eb855dab236fa211bd Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:24:59 +0000 Subject: [PATCH 30/44] fix ruff --- src/mcp/client/session_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 233c532ce..f3d351d31 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -291,7 +291,7 @@ async def _aggregate_components( prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, "mcp.ClientTransportSession"] = {} + tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {} # Query the server for its prompts and aggregate to list. try: From 481f7eabe10ef4b6cbedbb5745b160fe111e6ae7 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:49:55 +0000 Subject: [PATCH 31/44] fix type hints without cast --- src/mcp/client/session_group.py | 20 ++++++++++---------- src/mcp/shared/memory.py | 3 ++- tests/client/test_sampling_callback.py | 5 ++--- tests/server/test_cancel_handling.py | 2 ++ tests/shared/test_memory.py | 4 ++-- tests/shared/test_progress_notifications.py | 1 + tests/shared/test_session.py | 8 +++++--- tests/shared/test_sse.py | 4 ++-- tests/shared/test_streamable_http.py | 5 +++-- tests/shared/test_ws.py | 4 ++-- 10 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index f3d351d31..9e95ed909 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -96,10 +96,10 @@ class _ComponentNames(BaseModel): _tools: dict[str, types.Tool] # Client-server connection management. - _sessions: dict["mcp.ClientTransportSession", _ComponentNames] - _tool_to_session: dict[str, "mcp.ClientTransportSession"] + _sessions: dict[mcp.ClientTransportSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientTransportSession] _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict["mcp.ClientTransportSession", contextlib.AsyncExitStack] + _session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -153,7 +153,7 @@ async def __aexit__( tg.start_soon(exit_stack.aclose) @property - def sessions(self) -> list["mcp.ClientTransportSession"]: + def sessions(self) -> list[mcp.ClientTransportSession]: """Returns the list of sessions being managed.""" return list(self._sessions.keys()) # pragma: no cover @@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) - async def disconnect_from_server(self, session: "mcp.ClientTransportSession") -> None: + async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None: """Disconnects from a single MCP server.""" session_known_for_components = session in self._sessions @@ -216,8 +216,8 @@ async def disconnect_from_server(self, session: "mcp.ClientTransportSession") -> await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( - self, server_info: types.Implementation, session: "mcp.ClientTransportSession" - ) -> "mcp.ClientTransportSession": + self, server_info: types.Implementation, session: mcp.ClientTransportSession + ) -> mcp.ClientTransportSession: """Connects to a single MCP server.""" await self._aggregate_components(server_info, session) return session @@ -225,14 +225,14 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, - ) -> "mcp.ClientTransportSession": + ) -> mcp.ClientTransportSession: """Connects to a single MCP server.""" server_info, session = await self._establish_session(server_params) return await self.connect_with_session(server_info, session) async def _establish_session( self, server_params: ServerParameters - ) -> tuple[types.Implementation, "mcp.ClientTransportSession"]: + ) -> tuple[types.Implementation, mcp.ClientTransportSession]: """Establish a client session to an MCP server.""" session_stack = contextlib.AsyncExitStack() @@ -277,7 +277,7 @@ async def _establish_session( raise async def _aggregate_components( - self, server_info: types.Implementation, session: "mcp.ClientTransportSession" + self, server_info: types.Implementation, session: mcp.ClientTransportSession ) -> None: """Aggregates prompts, resources, and tools from a given session.""" diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 06d404e31..9f68a0c47 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -14,6 +14,7 @@ import mcp.types as types from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ClientTransportSession from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage @@ -57,7 +58,7 @@ async def create_connected_server_and_client_session( client_info: types.Implementation | None = None, raise_exceptions: bool = False, elicitation_callback: ElicitationFnT | None = None, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" # TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport", diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index feed499af..49138398c 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,5 +1,3 @@ -from typing import cast - import pytest from mcp.client.session import ClientTransportSession @@ -37,7 +35,8 @@ async def sampling_callback( @server.tool("test_sampling") async def test_sampling_tool(message: str): - session = cast(ServerSession, server.get_context().session) + session = server.get_context().session + assert isinstance(session, ServerSession) value = await session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 47c49bb62..3a0df20cc 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -9,6 +9,7 @@ from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session +from mcp.client.session import ClientSession from mcp.types import ( CallToolRequest, CallToolRequestParams, @@ -56,6 +57,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ async with create_connected_server_and_client_session(server) as client: # First request (will be cancelled) + assert isinstance(client, ClientSession) async def first_request(): try: await client.send_request( diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index ca4368e9f..4ebce3b15 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -2,7 +2,7 @@ from pydantic import AnyUrl from typing_extensions import AsyncGenerator -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.server import Server from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import EmptyResult, Resource @@ -28,7 +28,7 @@ async def handle_list_resources(): # pragma: no cover @pytest.fixture async def client_connected_to_server( mcp_server: Server, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: async with create_connected_server_and_client_session(mcp_server) as client_session: yield client_session diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 1552711d2..25afd7f32 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -370,6 +370,7 @@ async def handle_list_tools() -> list[types.Tool]: with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): async with create_connected_server_and_client_session(server) as client_session: # Send a request with a failing progress callback + assert isinstance(client_session, ClientSession) result = await client_session.send_request( types.ClientRequest( types.CallToolRequest( diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 313ec9926..47b5a02f6 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -5,7 +5,7 @@ import pytest import mcp.types as types -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session @@ -27,19 +27,20 @@ def mcp_server() -> Server: @pytest.fixture async def client_connected_to_server( mcp_server: Server, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: async with create_connected_server_and_client_session(mcp_server) as client_session: yield client_session @pytest.mark.anyio async def test_in_flight_requests_cleared_after_completion( - client_connected_to_server: ClientSession, + client_connected_to_server: ClientTransportSession, ): """Verify that _in_flight is empty after all requests complete.""" # Send a request and wait for response response = await client_connected_to_server.send_ping() assert isinstance(response, EmptyResult) + assert isinstance(client_connected_to_server, ClientSession) # Verify _in_flight is empty assert len(client_connected_to_server._in_flight) == 0 @@ -101,6 +102,7 @@ async def make_request(client_session: ClientSession): async with create_connected_server_and_client_session(make_server()) as client_session: async with anyio.create_task_group() as tg: + assert isinstance(client_session, ClientSession) tg.start_soon(make_request, client_session) # Wait for the request to be in-flight diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 28ac07d09..ba823ab6a 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -17,7 +17,7 @@ from starlette.routing import Mount, Route import mcp.types as types -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport @@ -185,7 +185,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 08968d6f7..603a4270a 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -8,7 +8,7 @@ import multiprocessing import socket from collections.abc import Generator -from typing import Any, cast +from typing import Any import anyio import httpx @@ -199,7 +199,8 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] elif name == "test_sampling_tool": # Test sampling by requesting the client to sample a message - session = cast(ServerSession, ctx.session) + session = ctx.session + assert isinstance(session, ServerSession) sampling_result = await session.create_message( messages=[ types.SamplingMessage( diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index f093cb492..1fac8696f 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -12,7 +12,7 @@ from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server @@ -125,7 +125,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]: """Create and initialize a WebSocket client session""" async with websocket_client(server_url + "/ws") as streams: async with ClientSession(*streams) as session: From 6b8f7374b71b037098864047885ec356a4e930d9 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:51:45 +0000 Subject: [PATCH 32/44] fix ruff --- src/mcp/shared/memory.py | 11 +++++++++-- tests/server/test_cancel_handling.py | 3 ++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 9f68a0c47..b8466fe91 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -13,8 +13,15 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types -from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT -from mcp.client.session import ClientTransportSession +from mcp.client.session import ( + ClientSession, + ClientTransportSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + SamplingFnT, +) from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 3a0df20cc..b1f825933 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,10 +6,10 @@ import pytest import mcp.types as types +from mcp.client.session import ClientSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session -from mcp.client.session import ClientSession from mcp.types import ( CallToolRequest, CallToolRequestParams, @@ -58,6 +58,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ async with create_connected_server_and_client_session(server) as client: # First request (will be cancelled) assert isinstance(client, ClientSession) + async def first_request(): try: await client.send_request( From 99856e813f2c2272a034cbbcfddb6da7e6612500 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:54:29 +0000 Subject: [PATCH 33/44] remove overload --- src/mcp/client/transport_session.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index c51b059f6..37578f211 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,11 +1,10 @@ from abc import ABC, abstractmethod from datetime import timedelta -from typing import Any, overload +from typing import Any from pydantic import AnyUrl -from typing_extensions import deprecated -from mcp import types +import mcp.types as types from mcp.shared.session import ProgressFnT @@ -110,15 +109,8 @@ async def complete( """Send a completion/complete request.""" raise NotImplementedError - @overload - @deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead") - async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ... - @overload - async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ... - - @overload - async def list_tools(self) -> types.ListToolsResult: ... + @abstractmethod async def list_tools( From ea8a33ca220397e88b8e6e11225c1c669c759fe2 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:56:27 +0000 Subject: [PATCH 34/44] revert client session group --- src/mcp/client/session_group.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9e95ed909..ecab5aecf 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -96,10 +96,10 @@ class _ComponentNames(BaseModel): _tools: dict[str, types.Tool] # Client-server connection management. - _sessions: dict[mcp.ClientTransportSession, _ComponentNames] - _tool_to_session: dict[str, mcp.ClientTransportSession] + _sessions: dict[mcp.ClientSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientSession] _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack] + _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -153,7 +153,7 @@ async def __aexit__( tg.start_soon(exit_stack.aclose) @property - def sessions(self) -> list[mcp.ClientTransportSession]: + def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" return list(self._sessions.keys()) # pragma: no cover @@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) - async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None: + async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" session_known_for_components = session in self._sessions @@ -216,8 +216,8 @@ async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> N await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( - self, server_info: types.Implementation, session: mcp.ClientTransportSession - ) -> mcp.ClientTransportSession: + self, server_info: types.Implementation, session: mcp.ClientSession + ) -> mcp.ClientSession: """Connects to a single MCP server.""" await self._aggregate_components(server_info, session) return session @@ -225,14 +225,14 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, - ) -> mcp.ClientTransportSession: + ) -> mcp.ClientSession: """Connects to a single MCP server.""" server_info, session = await self._establish_session(server_params) return await self.connect_with_session(server_info, session) async def _establish_session( self, server_params: ServerParameters - ) -> tuple[types.Implementation, mcp.ClientTransportSession]: + ) -> tuple[types.Implementation, mcp.ClientSession]: """Establish a client session to an MCP server.""" session_stack = contextlib.AsyncExitStack() @@ -277,7 +277,7 @@ async def _establish_session( raise async def _aggregate_components( - self, server_info: types.Implementation, session: mcp.ClientTransportSession + self, server_info: types.Implementation, session: mcp.ClientSession ) -> None: """Aggregates prompts, resources, and tools from a given session.""" @@ -291,7 +291,7 @@ async def _aggregate_components( prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {} + tool_to_session_temp: dict[str, mcp.ClientSession] = {} # Query the server for its prompts and aggregate to list. try: From 5bcfe6200e0314d9a8fcf7e1d7c8a50fcf9be5ce Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 09:30:51 +0000 Subject: [PATCH 35/44] fix ruff pyright --- src/mcp/client/session_group.py | 4 +--- src/mcp/client/transport_session.py | 3 --- src/mcp/shared/context.py | 10 ++++++---- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index ecab5aecf..2c55bb775 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -276,9 +276,7 @@ async def _establish_session( await session_stack.aclose() raise - async def _aggregate_components( - self, server_info: types.Implementation, session: mcp.ClientSession - ) -> None: + async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: """Aggregates prompts, resources, and tools from a given session.""" # Create a reverse index so we can find all prompts, resources, and diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 37578f211..07389d59a 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -109,9 +109,6 @@ async def complete( """Send a completion/complete request.""" raise NotImplementedError - - - @abstractmethod async def list_tools( self, diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 63fafa241..845cc50e2 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,15 +1,17 @@ from dataclasses import dataclass -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from typing_extensions import TypeVar -from mcp.client.transport_session import ClientTransportSession -from mcp.server.transport_session import ServerTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams +if TYPE_CHECKING: + from mcp.client.session import ClientTransportSession + from mcp.server.session import ServerTransportSession + SessionT = TypeVar( - "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" ) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) From af6be96916628e27d6e4661ca34e48749e51cfa6 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Wed, 12 Nov 2025 08:50:14 +0000 Subject: [PATCH 36/44] fix ruff --- src/mcp/client/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c3559b13a..0bd4e9608 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -39,7 +39,7 @@ async def __call__( class ListRootsFnT(Protocol): async def __call__( self, context: RequestContext["ClientTransportSession", Any] - ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch + ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch class LoggingFnT(Protocol): From 4377d41eb524bb6165b7bf1d9a6f9d2eb392fcd6 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:27:36 +0000 Subject: [PATCH 37/44] fix imports --- examples/snippets/clients/display_utilities.py | 3 +-- examples/snippets/clients/stdio_client.py | 3 +-- examples/snippets/servers/elicitation.py | 2 +- examples/snippets/servers/lifespan_example.py | 2 +- examples/snippets/servers/notifications.py | 2 +- examples/snippets/servers/tool_progress.py | 2 +- src/mcp/server/elicitation.py | 2 +- src/mcp/server/fastmcp/server.py | 3 ++- src/mcp/server/lowlevel/server.py | 3 ++- src/mcp/shared/context.py | 9 ++++----- src/mcp/shared/memory.py | 2 +- tests/client/test_list_roots_callback.py | 2 +- tests/client/test_sampling_callback.py | 2 +- tests/client/test_session.py | 3 ++- tests/server/fastmcp/test_elicitation.py | 3 ++- tests/server/fastmcp/test_integration.py | 3 ++- tests/shared/test_memory.py | 3 ++- tests/shared/test_session.py | 3 ++- tests/shared/test_sse.py | 3 ++- tests/shared/test_streamable_http.py | 3 ++- tests/shared/test_ws.py | 3 ++- 21 files changed, 34 insertions(+), 27 deletions(-) diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index 5e1b203ee..b8ad7dffc 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -6,9 +6,8 @@ import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index 62fb0f4c4..c72cc54f2 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -8,8 +8,7 @@ from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.session import ClientTransportSession +from mcp import ClientSession, StdioServerParameters, types, ClientTransportSession from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 049b42516..45f2cb68b 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 32b699730..46f01f427 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 36d9712eb..995ecd817 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index dddd8c9eb..a0f62fda6 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 65399e27c..b2f33ec7c 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 03e223329..840273e50 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -54,7 +54,8 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSessionT, ServerTransportSession +from mcp.server.session import ServerSessionT +from mcp.server.transport_session import ServerTransportSession from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b60e04974..85846afc6 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,8 @@ async def main(): from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession, ServerTransportSession +from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 845cc50e2..eaa3e2793 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,17 +1,16 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic +from typing import Any, Generic from typing_extensions import TypeVar from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -if TYPE_CHECKING: - from mcp.client.session import ClientTransportSession - from mcp.server.session import ServerTransportSession +from mcp.client.transport_session import ClientTransportSession +from mcp.server.transport_session import ServerTransportSession SessionT = TypeVar( - "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession ) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index b8466fe91..2d203d743 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -15,13 +15,13 @@ import mcp.types as types from mcp.client.session import ( ClientSession, - ClientTransportSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT, ) +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index dc53eddbc..5acb3b21a 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,7 +1,7 @@ import pytest from pydantic import FileUrl -from mcp.client.session import ClientTransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession from mcp.shared.context import RequestContext diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 49138398c..9fb6e29c7 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,6 @@ import pytest -from mcp.client.session import ClientTransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.memory import ( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c327a806f..bd51e4e10 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -4,7 +4,8 @@ import pytest import mcp.types as types -from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession, ClientTransportSession +from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index dd1ae72dc..77f97e677 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,7 +7,8 @@ import pytest from pydantic import BaseModel, Field -from mcp.client.session import ClientTransportSession, ElicitationFnT +from mcp.client.session import ElicitationFnT +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.context import RequestContext diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 778b0bfd7..99e8972a9 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -32,7 +32,8 @@ structured_output, tool_progress, ) -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.context import RequestContext diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index 4ebce3b15..56e0b98e7 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -2,7 +2,8 @@ from pydantic import AnyUrl from typing_extensions import AsyncGenerator -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import EmptyResult, Resource diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 47b5a02f6..a056f705b 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -5,7 +5,8 @@ import pytest import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index ba823ab6a..967925a11 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -17,7 +17,8 @@ from starlette.routing import Mount, Route import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 603a4270a..7aa768ae1 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -21,7 +21,8 @@ from starlette.routing import Mount import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.session import ServerSession diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 1fac8696f..107cd5589 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -12,7 +12,8 @@ from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server From f02873fa25e2dc96bfe44572e31a688167f202bc Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:29:29 +0000 Subject: [PATCH 38/44] fix ruff --- examples/snippets/clients/stdio_client.py | 2 +- src/mcp/server/fastmcp/server.py | 2 +- src/mcp/shared/context.py | 5 ++--- tests/server/fastmcp/test_integration.py | 2 +- tests/shared/test_sse.py | 2 +- tests/shared/test_streamable_http.py | 2 +- 6 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index c72cc54f2..90f9fdff9 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -8,7 +8,7 @@ from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types, ClientTransportSession +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 840273e50..cc05403dd 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -55,12 +55,12 @@ from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.session import ServerSessionT -from mcp.server.transport_session import ServerTransportSession from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index eaa3e2793..63fafa241 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,11 +3,10 @@ from typing_extensions import TypeVar -from mcp.shared.session import BaseSession -from mcp.types import RequestId, RequestParams - from mcp.client.transport_session import ClientTransportSession from mcp.server.transport_session import ServerTransportSession +from mcp.shared.session import BaseSession +from mcp.types import RequestId, RequestParams SessionT = TypeVar( "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 99e8972a9..d95d3a380 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -33,9 +33,9 @@ tool_progress, ) from mcp.client.session import ClientSession -from mcp.client.transport_session import ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 967925a11..0f850599a 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -18,8 +18,8 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.transport_session import ClientTransportSession from mcp.client.sse import sse_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7aa768ae1..be80e3820 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -22,8 +22,8 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.transport_session import ClientTransportSession from mcp.client.streamable_http import streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.session import ServerSession from mcp.server.streamable_http import ( From d4895a72e2d5515c3cb7b42b1c84fe16ba2fb4fc Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:30:54 +0000 Subject: [PATCH 39/44] fix circle --- src/mcp/shared/context.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 63fafa241..e38fc3d5c 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,8 +3,7 @@ from typing_extensions import TypeVar -from mcp.client.transport_session import ClientTransportSession -from mcp.server.transport_session import ServerTransportSession +from mcp import ServerTransportSession, ClientTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams From 3f0b620e1f3944d64b2785e194bab135193b48de Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:31:48 +0000 Subject: [PATCH 40/44] fix readme --- README.md | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 6450adbe9..5cbb6510f 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -326,7 +326,7 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @@ -674,7 +674,7 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @@ -798,7 +798,7 @@ Request additional information from users. This example shows an Elicitation dur from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -888,7 +888,7 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @@ -2037,8 +2037,7 @@ import os from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.session import ClientTransportSession +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2168,9 +2167,8 @@ cd to the `examples/snippets` directory and run: import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection From fc17b953ae64ee658af806f0261fa0ed6a48ab52 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:32:07 +0000 Subject: [PATCH 41/44] fix ruff check --- src/mcp/shared/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index e38fc3d5c..9ec3a2f17 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,7 +3,7 @@ from typing_extensions import TypeVar -from mcp import ServerTransportSession, ClientTransportSession +from mcp import ClientTransportSession, ServerTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams From 1d2b6262ad061650ff83e132e9dab4ddb66c6faa Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:33:23 +0000 Subject: [PATCH 42/44] fix circular import --- src/mcp/shared/context.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 9ec3a2f17..7267f4954 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,14 +1,16 @@ from dataclasses import dataclass -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from typing_extensions import TypeVar -from mcp import ClientTransportSession, ServerTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams +if TYPE_CHECKING: + from mcp import ClientTransportSession, ServerTransportSession + SessionT = TypeVar( - "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" ) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) From fd22fe2a62b9e320bff0e759a836be447b80af17 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:27:36 +0000 Subject: [PATCH 43/44] fix imports --- README.md | 16 +++++++--------- examples/snippets/clients/display_utilities.py | 3 +-- examples/snippets/clients/stdio_client.py | 3 +-- examples/snippets/servers/elicitation.py | 2 +- examples/snippets/servers/lifespan_example.py | 2 +- examples/snippets/servers/notifications.py | 2 +- examples/snippets/servers/tool_progress.py | 2 +- src/mcp/server/elicitation.py | 2 +- src/mcp/server/fastmcp/server.py | 3 ++- src/mcp/server/lowlevel/server.py | 3 ++- src/mcp/shared/context.py | 3 +-- src/mcp/shared/memory.py | 2 +- tests/client/test_list_roots_callback.py | 2 +- tests/client/test_sampling_callback.py | 2 +- tests/client/test_session.py | 3 ++- tests/server/fastmcp/test_elicitation.py | 3 ++- tests/server/fastmcp/test_integration.py | 3 ++- tests/shared/test_memory.py | 3 ++- tests/shared/test_session.py | 3 ++- tests/shared/test_sse.py | 3 ++- tests/shared/test_streamable_http.py | 3 ++- tests/shared/test_ws.py | 3 ++- 22 files changed, 38 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 6450adbe9..5cbb6510f 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -326,7 +326,7 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @@ -674,7 +674,7 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @@ -798,7 +798,7 @@ Request additional information from users. This example shows an Elicitation dur from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -888,7 +888,7 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @@ -2037,8 +2037,7 @@ import os from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.session import ClientTransportSession +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2168,9 +2167,8 @@ cd to the `examples/snippets` directory and run: import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index 5e1b203ee..b8ad7dffc 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -6,9 +6,8 @@ import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index 62fb0f4c4..90f9fdff9 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -8,8 +8,7 @@ from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.session import ClientTransportSession +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 049b42516..45f2cb68b 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 32b699730..46f01f427 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 36d9712eb..995ecd817 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index dddd8c9eb..a0f62fda6 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 65399e27c..b2f33ec7c 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 03e223329..cc05403dd 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -54,12 +54,13 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSessionT, ServerTransportSession +from mcp.server.session import ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b60e04974..85846afc6 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,8 @@ async def main(): from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession, ServerTransportSession +from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 845cc50e2..7267f4954 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -7,8 +7,7 @@ from mcp.types import RequestId, RequestParams if TYPE_CHECKING: - from mcp.client.session import ClientTransportSession - from mcp.server.session import ServerTransportSession + from mcp import ClientTransportSession, ServerTransportSession SessionT = TypeVar( "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index b8466fe91..2d203d743 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -15,13 +15,13 @@ import mcp.types as types from mcp.client.session import ( ClientSession, - ClientTransportSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT, ) +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index dc53eddbc..5acb3b21a 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,7 +1,7 @@ import pytest from pydantic import FileUrl -from mcp.client.session import ClientTransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession from mcp.shared.context import RequestContext diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 49138398c..9fb6e29c7 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,6 @@ import pytest -from mcp.client.session import ClientTransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.memory import ( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c327a806f..bd51e4e10 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -4,7 +4,8 @@ import pytest import mcp.types as types -from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession, ClientTransportSession +from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index dd1ae72dc..77f97e677 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,7 +7,8 @@ import pytest from pydantic import BaseModel, Field -from mcp.client.session import ClientTransportSession, ElicitationFnT +from mcp.client.session import ElicitationFnT +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.context import RequestContext diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 778b0bfd7..d95d3a380 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -32,9 +32,10 @@ structured_output, tool_progress, ) -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index 4ebce3b15..56e0b98e7 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -2,7 +2,8 @@ from pydantic import AnyUrl from typing_extensions import AsyncGenerator -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import EmptyResult, Resource diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 47b5a02f6..a056f705b 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -5,7 +5,8 @@ import pytest import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index ba823ab6a..0f850599a 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -17,8 +17,9 @@ from starlette.routing import Mount, Route import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession from mcp.client.sse import sse_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 603a4270a..be80e3820 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -21,8 +21,9 @@ from starlette.routing import Mount import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.session import ServerSession from mcp.server.streamable_http import ( diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 1fac8696f..107cd5589 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -12,7 +12,8 @@ from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server From f36e9399d45a431330a79d2b3b18c2dd6820e167 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 07:00:48 +0000 Subject: [PATCH 44/44] fix some more type hints --- tests/server/fastmcp/test_elicitation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 77f97e677..52e6799b7 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -10,7 +10,7 @@ from mcp.client.session import ElicitationFnT from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -25,7 +25,7 @@ def create_ask_user_tool(mcp: FastMCP): """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: + async def ask_user(prompt: str, ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) if result.action == "accept" and result.data: @@ -106,7 +106,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + async def tool(ctx: Context[ServerTransportSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" @@ -160,7 +160,7 @@ class OptionalSchema(BaseModel): subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") @mcp.tool(description="Tool with optional fields") - async def optional_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_tool(ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) if result.action == "accept" and result.data: @@ -201,7 +201,7 @@ class InvalidOptionalSchema(BaseModel): optional_list: list[str] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + async def invalid_optional_tool(ctx: Context[ServerTransportSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" @@ -234,7 +234,7 @@ class DefaultsSchema(BaseModel): email: str = Field(description="Email address (required)") @mcp.tool(description="Tool with default values") - async def defaults_tool(ctx: Context[ServerSession, None]) -> str: + async def defaults_tool(ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=DefaultsSchema) if result.action == "accept" and result.data: