1
1
package io .modelcontextprotocol .server .transport ;
2
2
3
+ import java .awt .PageAttributes ;
4
+
5
+ import com .fasterxml .jackson .core .JsonProcessingException ;
3
6
import com .fasterxml .jackson .core .type .TypeReference ;
4
7
import com .fasterxml .jackson .databind .ObjectMapper ;
8
+
5
9
import io .modelcontextprotocol .spec .DefaultMcpTransportContext ;
6
10
import io .modelcontextprotocol .spec .McpError ;
7
11
import io .modelcontextprotocol .spec .McpSchema ;
10
14
import io .modelcontextprotocol .spec .McpStreamableServerTransportProvider ;
11
15
import io .modelcontextprotocol .spec .McpTransportContext ;
12
16
import io .modelcontextprotocol .util .Assert ;
17
+
13
18
import org .slf4j .Logger ;
14
19
import org .slf4j .LoggerFactory ;
15
20
import org .springframework .http .HttpStatus ;
19
24
import org .springframework .web .reactive .function .server .RouterFunctions ;
20
25
import org .springframework .web .reactive .function .server .ServerRequest ;
21
26
import org .springframework .web .reactive .function .server .ServerResponse ;
27
+
22
28
import reactor .core .Disposable ;
23
29
import reactor .core .Exceptions ;
24
30
import reactor .core .publisher .Flux ;
25
31
import reactor .core .publisher .FluxSink ;
26
32
import reactor .core .publisher .Mono ;
27
33
28
34
import java .io .IOException ;
35
+ import java .util .ArrayList ;
36
+ import java .util .List ;
29
37
import java .util .concurrent .ConcurrentHashMap ;
30
38
import java .util .function .Function ;
31
39
40
+ /**
41
+ * Server-side implementation of the Model Context Protocol (MCP) streamable transport
42
+ * layer using HTTP with Server-Sent Events (SSE) through Spring WebFlux.
43
+ *
44
+ * <p>
45
+ *
46
+ * @author Dariusz Jędrzejczyk
47
+ * @author Zachary German
48
+ * @see McpStreamableServerTransportProvider
49
+ * @see RouterFunction
50
+ */
32
51
public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider {
33
52
34
53
private static final Logger logger = LoggerFactory .getLogger (WebFluxStreamableServerTransportProvider .class );
@@ -37,6 +56,12 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe
37
56
38
57
public static final String DEFAULT_BASE_URL = "" ;
39
58
59
+ private static final String MCP_SESSION_ID = "mcp-session-id" ;
60
+
61
+ private static final String LAST_EVENT_ID = "Last-Event-ID" ;
62
+
63
+ private static final String ACCEPT = "Accept" ;
64
+
40
65
private final ObjectMapper objectMapper ;
41
66
42
67
private final String baseUrl ;
@@ -195,21 +220,40 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
195
220
McpTransportContext transportContext = this .contextExtractor .apply (request );
196
221
197
222
return Mono .defer (() -> {
198
- if (!request .headers ().asHttpHeaders ().containsKey ("mcp-session-id" )) {
199
- return ServerResponse .badRequest ().build (); // TODO: say we need a session
200
- // id
223
+ List <String > badRequestErrors = new ArrayList <>();
224
+
225
+ String accept = request .headers ().asHttpHeaders ().getFirst (ACCEPT );
226
+ if (accept == null || !accept .contains (MediaType .TEXT_EVENT_STREAM_VALUE )) {
227
+ badRequestErrors .add ("text/event-stream required in Accept header" );
201
228
}
202
229
203
- String sessionId = request .headers ().asHttpHeaders ().getFirst ("mcp-session-id" );
230
+ String sessionId = request .headers ().asHttpHeaders ().getFirst (MCP_SESSION_ID );
231
+
232
+ if (sessionId == null || sessionId .isBlank ()) {
233
+ badRequestErrors .add ("Session ID required in mcp-session-id header" );
234
+ }
235
+
236
+ if (!badRequestErrors .isEmpty ()) {
237
+ String combinedMessage = String .join ("; " , badRequestErrors );
238
+ try {
239
+ String errorJson = objectMapper .writeValueAsString (new McpError (combinedMessage ));
240
+ return ServerResponse .badRequest ().bodyValue (errorJson );
241
+ }
242
+ catch (JsonProcessingException e ) {
243
+ logger .debug ("Failed to serialize McpError: {}" , e );
244
+ return ServerResponse .status (HttpStatus .INTERNAL_SERVER_ERROR )
245
+ .bodyValue ("Failed to serialize error message." );
246
+ }
247
+ }
204
248
205
249
McpStreamableServerSession session = this .sessions .get (sessionId );
206
250
207
251
if (session == null ) {
208
252
return ServerResponse .notFound ().build ();
209
253
}
210
254
211
- if (request .headers ().asHttpHeaders ().containsKey ("mcp-last-id" )) {
212
- String lastId = request .headers ().asHttpHeaders ().getFirst ("mcp-last-id" );
255
+ if (request .headers ().asHttpHeaders ().containsKey (LAST_EVENT_ID )) {
256
+ String lastId = request .headers ().asHttpHeaders ().getFirst (LAST_EVENT_ID );
213
257
return ServerResponse .ok ()
214
258
.contentType (MediaType .TEXT_EVENT_STREAM )
215
259
.body (session .replay (lastId ), ServerSentEvent .class );
@@ -252,9 +296,31 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
252
296
253
297
return request .bodyToMono (String .class ).<ServerResponse >flatMap (body -> {
254
298
try {
299
+ List <String > badRequestErrors = new ArrayList <>();
300
+
301
+ String accept = request .headers ().asHttpHeaders ().getFirst (ACCEPT );
302
+ if (accept == null || !accept .contains (MediaType .TEXT_EVENT_STREAM_VALUE )) {
303
+ badRequestErrors .add ("text/event-stream required in Accept header" );
304
+ }
305
+ if (accept == null || !accept .contains (MediaType .APPLICATION_JSON_VALUE )) {
306
+ badRequestErrors .add ("application/json required in Accept header" );
307
+ }
308
+
255
309
McpSchema .JSONRPCMessage message = McpSchema .deserializeJsonRpcMessage (objectMapper , body );
256
310
if (message instanceof McpSchema .JSONRPCRequest jsonrpcRequest
257
311
&& jsonrpcRequest .method ().equals (McpSchema .METHOD_INITIALIZE )) {
312
+ if (!badRequestErrors .isEmpty ()) {
313
+ String combinedMessage = String .join ("; " , badRequestErrors );
314
+ try {
315
+ String errorJson = objectMapper .writeValueAsString (new McpError (combinedMessage ));
316
+ return ServerResponse .badRequest ().bodyValue (errorJson );
317
+ }
318
+ catch (JsonProcessingException e ) {
319
+ logger .debug ("Failed to serialize McpError: {}" , e );
320
+ return ServerResponse .status (HttpStatus .INTERNAL_SERVER_ERROR )
321
+ .bodyValue ("Failed to serialize error message." );
322
+ }
323
+ }
258
324
McpSchema .InitializeRequest initializeRequest = objectMapper .convertValue (jsonrpcRequest .params (),
259
325
new TypeReference <McpSchema .InitializeRequest >() {
260
326
});
@@ -274,15 +340,29 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
274
340
})
275
341
.flatMap (initResult -> ServerResponse .ok ()
276
342
.contentType (MediaType .APPLICATION_JSON )
277
- .header ("mcp-session-id" , init .session ().getId ())
343
+ .header (MCP_SESSION_ID , init .session ().getId ())
278
344
.bodyValue (initResult ));
279
345
}
280
346
281
- if (!request .headers ().asHttpHeaders ().containsKey ("mcp-session-id" )) {
282
- return ServerResponse .badRequest ().bodyValue (new McpError ("Session ID missing" ));
347
+ String sessionId = request .headers ().asHttpHeaders ().getFirst (MCP_SESSION_ID );
348
+
349
+ if (sessionId == null || sessionId .isBlank ()) {
350
+ badRequestErrors .add ("Session ID required in mcp-session-id header" );
351
+ }
352
+
353
+ if (!badRequestErrors .isEmpty ()) {
354
+ String combinedMessage = String .join ("; " , badRequestErrors );
355
+ try {
356
+ String errorJson = objectMapper .writeValueAsString (new McpError (combinedMessage ));
357
+ return ServerResponse .badRequest ().bodyValue (errorJson );
358
+ }
359
+ catch (JsonProcessingException e ) {
360
+ logger .debug ("Failed to serialize McpError: {}" , e );
361
+ return ServerResponse .status (HttpStatus .INTERNAL_SERVER_ERROR )
362
+ .bodyValue ("Failed to serialize error message." );
363
+ }
283
364
}
284
365
285
- String sessionId = request .headers ().asHttpHeaders ().getFirst ("mcp-session-id" );
286
366
McpStreamableServerSession session = sessions .get (sessionId );
287
367
288
368
if (session == null ) {
@@ -330,7 +410,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
330
410
McpTransportContext transportContext = this .contextExtractor .apply (request );
331
411
332
412
return Mono .defer (() -> {
333
- if (!request .headers ().asHttpHeaders ().containsKey ("mcp-session-id" )) {
413
+ if (!request .headers ().asHttpHeaders ().containsKey (MCP_SESSION_ID )) {
334
414
return ServerResponse .badRequest ().build (); // TODO: say we need a session
335
415
// id
336
416
}
@@ -340,7 +420,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
340
420
return ServerResponse .status (HttpStatus .METHOD_NOT_ALLOWED ).build ();
341
421
}
342
422
343
- String sessionId = request .headers ().asHttpHeaders ().getFirst ("mcp-session-id" );
423
+ String sessionId = request .headers ().asHttpHeaders ().getFirst (MCP_SESSION_ID );
344
424
345
425
McpStreamableServerSession session = this .sessions .get (sessionId );
346
426
0 commit comments