Skip to content

Commit 481f7ea

Browse files
committed
fix type hints without cast
1 parent 1bfc086 commit 481f7ea

File tree

10 files changed

+31
-25
lines changed

10 files changed

+31
-25
lines changed

src/mcp/client/session_group.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,10 @@ class _ComponentNames(BaseModel):
9696
_tools: dict[str, types.Tool]
9797

9898
# Client-server connection management.
99-
_sessions: dict["mcp.ClientTransportSession", _ComponentNames]
100-
_tool_to_session: dict[str, "mcp.ClientTransportSession"]
99+
_sessions: dict[mcp.ClientTransportSession, _ComponentNames]
100+
_tool_to_session: dict[str, mcp.ClientTransportSession]
101101
_exit_stack: contextlib.AsyncExitStack
102-
_session_exit_stacks: dict["mcp.ClientTransportSession", contextlib.AsyncExitStack]
102+
_session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack]
103103

104104
# Optional fn consuming (component_name, serverInfo) for custom names.
105105
# This is provide a means to mitigate naming conflicts across servers.
@@ -153,7 +153,7 @@ async def __aexit__(
153153
tg.start_soon(exit_stack.aclose)
154154

155155
@property
156-
def sessions(self) -> list["mcp.ClientTransportSession"]:
156+
def sessions(self) -> list[mcp.ClientTransportSession]:
157157
"""Returns the list of sessions being managed."""
158158
return list(self._sessions.keys()) # pragma: no cover
159159

@@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu
178178
session_tool_name = self.tools[name].name
179179
return await session.call_tool(session_tool_name, args)
180180

181-
async def disconnect_from_server(self, session: "mcp.ClientTransportSession") -> None:
181+
async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None:
182182
"""Disconnects from a single MCP server."""
183183

184184
session_known_for_components = session in self._sessions
@@ -216,23 +216,23 @@ async def disconnect_from_server(self, session: "mcp.ClientTransportSession") ->
216216
await session_stack_to_close.aclose() # pragma: no cover
217217

218218
async def connect_with_session(
219-
self, server_info: types.Implementation, session: "mcp.ClientTransportSession"
220-
) -> "mcp.ClientTransportSession":
219+
self, server_info: types.Implementation, session: mcp.ClientTransportSession
220+
) -> mcp.ClientTransportSession:
221221
"""Connects to a single MCP server."""
222222
await self._aggregate_components(server_info, session)
223223
return session
224224

225225
async def connect_to_server(
226226
self,
227227
server_params: ServerParameters,
228-
) -> "mcp.ClientTransportSession":
228+
) -> mcp.ClientTransportSession:
229229
"""Connects to a single MCP server."""
230230
server_info, session = await self._establish_session(server_params)
231231
return await self.connect_with_session(server_info, session)
232232

233233
async def _establish_session(
234234
self, server_params: ServerParameters
235-
) -> tuple[types.Implementation, "mcp.ClientTransportSession"]:
235+
) -> tuple[types.Implementation, mcp.ClientTransportSession]:
236236
"""Establish a client session to an MCP server."""
237237

238238
session_stack = contextlib.AsyncExitStack()
@@ -277,7 +277,7 @@ async def _establish_session(
277277
raise
278278

279279
async def _aggregate_components(
280-
self, server_info: types.Implementation, session: "mcp.ClientTransportSession"
280+
self, server_info: types.Implementation, session: mcp.ClientTransportSession
281281
) -> None:
282282
"""Aggregates prompts, resources, and tools from a given session."""
283283

src/mcp/shared/memory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import mcp.types as types
1616
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
17+
from mcp.client.session import ClientTransportSession
1718
from mcp.server import Server
1819
from mcp.server.fastmcp import FastMCP
1920
from mcp.shared.message import SessionMessage
@@ -57,7 +58,7 @@ async def create_connected_server_and_client_session(
5758
client_info: types.Implementation | None = None,
5859
raise_exceptions: bool = False,
5960
elicitation_callback: ElicitationFnT | None = None,
60-
) -> AsyncGenerator[ClientSession, None]:
61+
) -> AsyncGenerator[ClientTransportSession, None]:
6162
"""Creates a ClientSession that is connected to a running MCP server."""
6263

6364
# TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport",

tests/client/test_sampling_callback.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import cast
2-
31
import pytest
42

53
from mcp.client.session import ClientTransportSession
@@ -37,7 +35,8 @@ async def sampling_callback(
3735

3836
@server.tool("test_sampling")
3937
async def test_sampling_tool(message: str):
40-
session = cast(ServerSession, server.get_context().session)
38+
session = server.get_context().session
39+
assert isinstance(session, ServerSession)
4140
value = await session.create_message(
4241
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
4342
max_tokens=100,

tests/server/test_cancel_handling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mcp.server.lowlevel.server import Server
1010
from mcp.shared.exceptions import McpError
1111
from mcp.shared.memory import create_connected_server_and_client_session
12+
from mcp.client.session import ClientSession
1213
from mcp.types import (
1314
CallToolRequest,
1415
CallToolRequestParams,
@@ -56,6 +57,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[
5657

5758
async with create_connected_server_and_client_session(server) as client:
5859
# First request (will be cancelled)
60+
assert isinstance(client, ClientSession)
5961
async def first_request():
6062
try:
6163
await client.send_request(

tests/shared/test_memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pydantic import AnyUrl
33
from typing_extensions import AsyncGenerator
44

5-
from mcp.client.session import ClientSession
5+
from mcp.client.session import ClientSession, ClientTransportSession
66
from mcp.server import Server
77
from mcp.shared.memory import create_connected_server_and_client_session
88
from mcp.types import EmptyResult, Resource
@@ -28,7 +28,7 @@ async def handle_list_resources(): # pragma: no cover
2828
@pytest.fixture
2929
async def client_connected_to_server(
3030
mcp_server: Server,
31-
) -> AsyncGenerator[ClientSession, None]:
31+
) -> AsyncGenerator[ClientTransportSession, None]:
3232
async with create_connected_server_and_client_session(mcp_server) as client_session:
3333
yield client_session
3434

tests/shared/test_progress_notifications.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ async def handle_list_tools() -> list[types.Tool]:
370370
with patch("mcp.shared.session.logging.error", side_effect=mock_log_error):
371371
async with create_connected_server_and_client_session(server) as client_session:
372372
# Send a request with a failing progress callback
373+
assert isinstance(client_session, ClientSession)
373374
result = await client_session.send_request(
374375
types.ClientRequest(
375376
types.CallToolRequest(

tests/shared/test_session.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
import mcp.types as types
8-
from mcp.client.session import ClientSession
8+
from mcp.client.session import ClientSession, ClientTransportSession
99
from mcp.server.lowlevel.server import Server
1010
from mcp.shared.exceptions import McpError
1111
from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session
@@ -27,19 +27,20 @@ def mcp_server() -> Server:
2727
@pytest.fixture
2828
async def client_connected_to_server(
2929
mcp_server: Server,
30-
) -> AsyncGenerator[ClientSession, None]:
30+
) -> AsyncGenerator[ClientTransportSession, None]:
3131
async with create_connected_server_and_client_session(mcp_server) as client_session:
3232
yield client_session
3333

3434

3535
@pytest.mark.anyio
3636
async def test_in_flight_requests_cleared_after_completion(
37-
client_connected_to_server: ClientSession,
37+
client_connected_to_server: ClientTransportSession,
3838
):
3939
"""Verify that _in_flight is empty after all requests complete."""
4040
# Send a request and wait for response
4141
response = await client_connected_to_server.send_ping()
4242
assert isinstance(response, EmptyResult)
43+
assert isinstance(client_connected_to_server, ClientSession)
4344

4445
# Verify _in_flight is empty
4546
assert len(client_connected_to_server._in_flight) == 0
@@ -101,6 +102,7 @@ async def make_request(client_session: ClientSession):
101102

102103
async with create_connected_server_and_client_session(make_server()) as client_session:
103104
async with anyio.create_task_group() as tg:
105+
assert isinstance(client_session, ClientSession)
104106
tg.start_soon(make_request, client_session)
105107

106108
# Wait for the request to be in-flight

tests/shared/test_sse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from starlette.routing import Mount, Route
1818

1919
import mcp.types as types
20-
from mcp.client.session import ClientSession
20+
from mcp.client.session import ClientSession, ClientTransportSession
2121
from mcp.client.sse import sse_client
2222
from mcp.server import Server
2323
from mcp.server.sse import SseServerTransport
@@ -185,7 +185,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
185185

186186

187187
@pytest.fixture
188-
async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
188+
async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]:
189189
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
190190
async with ClientSession(*streams) as session:
191191
await session.initialize()

tests/shared/test_streamable_http.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import multiprocessing
99
import socket
1010
from collections.abc import Generator
11-
from typing import Any, cast
11+
from typing import Any
1212

1313
import anyio
1414
import httpx
@@ -199,7 +199,8 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
199199

200200
elif name == "test_sampling_tool":
201201
# Test sampling by requesting the client to sample a message
202-
session = cast(ServerSession, ctx.session)
202+
session = ctx.session
203+
assert isinstance(session, ServerSession)
203204
sampling_result = await session.create_message(
204205
messages=[
205206
types.SamplingMessage(

tests/shared/test_ws.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from starlette.routing import WebSocketRoute
1313
from starlette.websockets import WebSocket
1414

15-
from mcp.client.session import ClientSession
15+
from mcp.client.session import ClientSession, ClientTransportSession
1616
from mcp.client.websocket import websocket_client
1717
from mcp.server import Server
1818
from mcp.server.websocket import websocket_server
@@ -125,7 +125,7 @@ def server(server_port: int) -> Generator[None, None, None]:
125125

126126

127127
@pytest.fixture()
128-
async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
128+
async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]:
129129
"""Create and initialize a WebSocket client session"""
130130
async with websocket_client(server_url + "/ws") as streams:
131131
async with ClientSession(*streams) as session:

0 commit comments

Comments
 (0)