Skip to content

Commit c508051

Browse files
committed
add mvp for server side tasks
1 parent 149d390 commit c508051

File tree

5 files changed

+266
-91
lines changed

5 files changed

+266
-91
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Experimental handlers for the low-level MCP server.
2+
3+
WARNING: These APIs are experimental and may change without notice.
4+
"""
5+
6+
import logging
7+
from collections.abc import Awaitable, Callable
8+
9+
from mcp.server.lowlevel.func_inspection import create_call_wrapper
10+
from mcp.types import (
11+
CancelTaskRequest,
12+
CancelTaskResult,
13+
GetTaskPayloadRequest,
14+
GetTaskPayloadResult,
15+
GetTaskRequest,
16+
GetTaskResult,
17+
ListTasksRequest,
18+
ListTasksResult,
19+
ServerCapabilities,
20+
ServerResult,
21+
ServerTasksCapability,
22+
ServerTasksRequestsCapability,
23+
TasksCancelCapability,
24+
TasksListCapability,
25+
TasksToolsCapability,
26+
)
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class ExperimentalHandlers:
32+
"""Experimental request/notification handlers.
33+
34+
WARNING: These APIs are experimental and may change without notice.
35+
"""
36+
37+
def __init__(
38+
self,
39+
request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]],
40+
notification_handlers: dict[type, Callable[..., Awaitable[None]]],
41+
):
42+
self._request_handlers = request_handlers
43+
self._notification_handlers = notification_handlers
44+
45+
def update_capabilities(self, capabilities: ServerCapabilities) -> None:
46+
capabilities.tasks = ServerTasksCapability()
47+
if ListTasksRequest in self._request_handlers:
48+
capabilities.tasks.list = TasksListCapability()
49+
if CancelTaskRequest in self._request_handlers:
50+
capabilities.tasks.cancel = TasksCancelCapability()
51+
52+
capabilities.tasks.requests = ServerTasksRequestsCapability(
53+
tools=TasksToolsCapability()
54+
) # assuming always supported for now
55+
56+
def list_tasks(
57+
self,
58+
) -> Callable[
59+
[Callable[[ListTasksRequest], Awaitable[ListTasksResult]]],
60+
Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
61+
]:
62+
"""Register a handler for listing tasks.
63+
64+
WARNING: This API is experimental and may change without notice.
65+
"""
66+
67+
def decorator(
68+
func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
69+
) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]:
70+
logger.debug("Registering handler for ListTasksRequest")
71+
wrapper = create_call_wrapper(func, ListTasksRequest)
72+
73+
async def handler(req: ListTasksRequest):
74+
result = await wrapper(req)
75+
return ServerResult(result)
76+
77+
self._request_handlers[ListTasksRequest] = handler
78+
return func
79+
80+
return decorator
81+
82+
def get_task(self):
83+
"""Register a handler for getting task status.
84+
85+
WARNING: This API is experimental and may change without notice.
86+
"""
87+
88+
def decorator(func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]]):
89+
logger.debug("Registering handler for GetTaskRequest")
90+
wrapper = create_call_wrapper(func, GetTaskRequest)
91+
92+
async def handler(req: GetTaskRequest):
93+
result = await wrapper(req)
94+
return ServerResult(result)
95+
96+
self._request_handlers[GetTaskRequest] = handler
97+
return func
98+
99+
return decorator
100+
101+
def get_task_result(self):
102+
"""Register a handler for getting task results/payload.
103+
104+
WARNING: This API is experimental and may change without notice.
105+
"""
106+
107+
def decorator(func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]):
108+
logger.debug("Registering handler for GetTaskPayloadRequest")
109+
wrapper = create_call_wrapper(func, GetTaskPayloadRequest)
110+
111+
async def handler(req: GetTaskPayloadRequest):
112+
result = await wrapper(req)
113+
return ServerResult(result)
114+
115+
self._request_handlers[GetTaskPayloadRequest] = handler
116+
return func
117+
118+
return decorator
119+
120+
def cancel_task(self):
121+
"""Register a handler for cancelling tasks.
122+
123+
WARNING: This API is experimental and may change without notice.
124+
"""
125+
126+
def decorator(func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]):
127+
logger.debug("Registering handler for CancelTaskRequest")
128+
wrapper = create_call_wrapper(func, CancelTaskRequest)
129+
130+
async def handler(req: CancelTaskRequest):
131+
result = await wrapper(req)
132+
return ServerResult(result)
133+
134+
self._request_handlers[CancelTaskRequest] = handler
135+
return func
136+
137+
return decorator

src/mcp/server/lowlevel/server.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ async def main():
8282
from typing_extensions import TypeVar
8383

8484
import mcp.types as types
85+
from mcp.server.lowlevel.experimental import ExperimentalHandlers
8586
from mcp.server.lowlevel.func_inspection import create_call_wrapper
8687
from mcp.server.lowlevel.helper_types import ReadResourceContents
8788
from mcp.server.models import InitializationOptions
8889
from mcp.server.session import ServerSession
89-
from mcp.shared.context import RequestContext
90+
from mcp.shared.context import Experimental, RequestContext
9091
from mcp.shared.exceptions import McpError
9192
from mcp.shared.message import ServerMessageMetadata, SessionMessage
9293
from mcp.shared.session import RequestResponder
@@ -154,6 +155,7 @@ def __init__(
154155
}
155156
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
156157
self._tool_cache: dict[str, types.Tool] = {}
158+
self._experimental_handlers: ExperimentalHandlers | None = None
157159
logger.debug("Initializing server %r", name)
158160

159161
def create_initialization_options(
@@ -219,14 +221,17 @@ def get_capabilities(
219221
if types.CompleteRequest in self.request_handlers:
220222
completions_capability = types.CompletionsCapability()
221223

222-
return types.ServerCapabilities(
224+
capabilities = types.ServerCapabilities(
223225
prompts=prompts_capability,
224226
resources=resources_capability,
225227
tools=tools_capability,
226228
logging=logging_capability,
227229
experimental=experimental_capabilities,
228230
completions=completions_capability,
229231
)
232+
if self._experimental_handlers:
233+
self._experimental_handlers.update_capabilities(capabilities)
234+
return capabilities
230235

231236
@property
232237
def request_context(
@@ -235,6 +240,18 @@ def request_context(
235240
"""If called outside of a request context, this will raise a LookupError."""
236241
return request_ctx.get()
237242

243+
@property
244+
def experimental(self) -> ExperimentalHandlers:
245+
"""Experimental APIs for tasks and other features.
246+
247+
WARNING: These APIs are experimental and may change without notice.
248+
"""
249+
250+
# We create this inline so we only add these capabilities _if_ they're actually used
251+
if self._experimental_handlers is None:
252+
self._experimental_handlers = ExperimentalHandlers(self.request_handlers, self.notification_handlers)
253+
return self._experimental_handlers
254+
238255
def list_prompts(self):
239256
def decorator(
240257
func: Callable[[], Awaitable[list[types.Prompt]]]
@@ -666,13 +683,14 @@ async def _handle_message(
666683
async def _handle_request(
667684
self,
668685
message: RequestResponder[types.ClientRequest, types.ServerResult],
669-
req: Any,
686+
req: types.ClientRequestType,
670687
session: ServerSession,
671688
lifespan_context: LifespanResultT,
672689
raise_exceptions: bool,
673690
):
674691
logger.info("Processing request of type %s", type(req).__name__)
675-
if handler := self.request_handlers.get(type(req)): # type: ignore
692+
693+
if handler := self.request_handlers.get(type(req)):
676694
logger.debug("Dispatching request of type %s", type(req).__name__)
677695

678696
token = None
@@ -692,6 +710,7 @@ async def _handle_request(
692710
message.request_meta,
693711
session,
694712
lifespan_context,
713+
Experimental(task_metadata=message.request_params.task if message.request_params else None),
695714
request=request_data,
696715
)
697716
)

src/mcp/shared/context.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,27 @@
44
from typing_extensions import TypeVar
55

66
from mcp.shared.session import BaseSession
7-
from mcp.types import RequestId, RequestParams
7+
from mcp.types import RequestId, RequestParams, TaskMetadata
88

99
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
1010
LifespanContextT = TypeVar("LifespanContextT")
1111
RequestT = TypeVar("RequestT", default=Any)
1212

1313

14+
@dataclass
15+
class Experimental:
16+
task_metadata: TaskMetadata | None = None
17+
18+
@property
19+
def is_task(self) -> bool:
20+
return self.task_metadata is not None
21+
22+
1423
@dataclass
1524
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
1625
request_id: RequestId
1726
meta: RequestParams.Meta | None
1827
session: SessionT
1928
lifespan_context: LifespanContextT
29+
experimental: Experimental = Experimental()
2030
request: RequestT | None = None

src/mcp/shared/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,11 @@ def __init__(
8181
]""",
8282
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
8383
message_metadata: MessageMetadata = None,
84+
request_params: RequestParams | None = None,
8485
) -> None:
8586
self.request_id = request_id
8687
self.request_meta = request_meta
88+
self.request_params = request_params
8789
self.request = request
8890
self.message_metadata = message_metadata
8991
self._session = session
@@ -353,6 +355,7 @@ async def _receive_loop(self) -> None:
353355
session=self,
354356
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
355357
message_metadata=message.metadata,
358+
request_params=validated_request.root.params,
356359
)
357360
self._in_flight[responder.request_id] = responder
358361
await self._received_request(responder)

0 commit comments

Comments
 (0)