Skip to content

Commit 5c09fa5

Browse files
author
Zachary German
committed
Added 'Accept' header validation and touched up 'Last-Event-ID' header
1 parent c23c5df commit 5c09fa5

File tree

3 files changed

+234
-43
lines changed

3 files changed

+234
-43
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package io.modelcontextprotocol.server.transport;
22

3+
import java.awt.PageAttributes;
4+
5+
import com.fasterxml.jackson.core.JsonProcessingException;
36
import com.fasterxml.jackson.core.type.TypeReference;
47
import com.fasterxml.jackson.databind.ObjectMapper;
8+
59
import io.modelcontextprotocol.spec.DefaultMcpTransportContext;
610
import io.modelcontextprotocol.spec.McpError;
711
import io.modelcontextprotocol.spec.McpSchema;
@@ -10,6 +14,7 @@
1014
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
1115
import io.modelcontextprotocol.spec.McpTransportContext;
1216
import io.modelcontextprotocol.util.Assert;
17+
1318
import org.slf4j.Logger;
1419
import org.slf4j.LoggerFactory;
1520
import org.springframework.http.HttpStatus;
@@ -19,16 +24,30 @@
1924
import org.springframework.web.reactive.function.server.RouterFunctions;
2025
import org.springframework.web.reactive.function.server.ServerRequest;
2126
import org.springframework.web.reactive.function.server.ServerResponse;
27+
2228
import reactor.core.Disposable;
2329
import reactor.core.Exceptions;
2430
import reactor.core.publisher.Flux;
2531
import reactor.core.publisher.FluxSink;
2632
import reactor.core.publisher.Mono;
2733

2834
import java.io.IOException;
35+
import java.util.ArrayList;
36+
import java.util.List;
2937
import java.util.concurrent.ConcurrentHashMap;
3038
import java.util.function.Function;
3139

40+
/**
41+
* Server-side implementation of the Model Context Protocol (MCP) streamable transport
42+
* layer using HTTP with Server-Sent Events (SSE) through Spring WebFlux.
43+
*
44+
* <p>
45+
*
46+
* @author Dariusz Jędrzejczyk
47+
* @author Zachary German
48+
* @see McpStreamableServerTransportProvider
49+
* @see RouterFunction
50+
*/
3251
public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider {
3352

3453
private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableServerTransportProvider.class);
@@ -37,6 +56,12 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe
3756

3857
public static final String DEFAULT_BASE_URL = "";
3958

59+
private static final String MCP_SESSION_ID = "mcp-session-id";
60+
61+
private static final String LAST_EVENT_ID = "Last-Event-ID";
62+
63+
private static final String ACCEPT = "Accept";
64+
4065
private final ObjectMapper objectMapper;
4166

4267
private final String baseUrl;
@@ -195,21 +220,40 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
195220
McpTransportContext transportContext = this.contextExtractor.apply(request);
196221

197222
return Mono.defer(() -> {
198-
if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) {
199-
return ServerResponse.badRequest().build(); // TODO: say we need a session
200-
// id
223+
List<String> badRequestErrors = new ArrayList<>();
224+
225+
String accept = request.headers().asHttpHeaders().getFirst(ACCEPT);
226+
if (accept == null || !accept.contains(MediaType.TEXT_EVENT_STREAM_VALUE)) {
227+
badRequestErrors.add("text/event-stream required in Accept header");
201228
}
202229

203-
String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id");
230+
String sessionId = request.headers().asHttpHeaders().getFirst(MCP_SESSION_ID);
231+
232+
if (sessionId == null || sessionId.isBlank()) {
233+
badRequestErrors.add("Session ID required in mcp-session-id header");
234+
}
235+
236+
if (!badRequestErrors.isEmpty()) {
237+
String combinedMessage = String.join("; ", badRequestErrors);
238+
try {
239+
String errorJson = objectMapper.writeValueAsString(new McpError(combinedMessage));
240+
return ServerResponse.badRequest().bodyValue(errorJson);
241+
}
242+
catch (JsonProcessingException e) {
243+
logger.debug("Failed to serialize McpError: {}", e);
244+
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
245+
.bodyValue("Failed to serialize error message.");
246+
}
247+
}
204248

205249
McpStreamableServerSession session = this.sessions.get(sessionId);
206250

207251
if (session == null) {
208252
return ServerResponse.notFound().build();
209253
}
210254

211-
if (request.headers().asHttpHeaders().containsKey("mcp-last-id")) {
212-
String lastId = request.headers().asHttpHeaders().getFirst("mcp-last-id");
255+
if (request.headers().asHttpHeaders().containsKey(LAST_EVENT_ID)) {
256+
String lastId = request.headers().asHttpHeaders().getFirst(LAST_EVENT_ID);
213257
return ServerResponse.ok()
214258
.contentType(MediaType.TEXT_EVENT_STREAM)
215259
.body(session.replay(lastId), ServerSentEvent.class);
@@ -252,9 +296,31 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
252296

253297
return request.bodyToMono(String.class).<ServerResponse>flatMap(body -> {
254298
try {
299+
List<String> badRequestErrors = new ArrayList<>();
300+
301+
String accept = request.headers().asHttpHeaders().getFirst(ACCEPT);
302+
if (accept == null || !accept.contains(MediaType.TEXT_EVENT_STREAM_VALUE)) {
303+
badRequestErrors.add("text/event-stream required in Accept header");
304+
}
305+
if (accept == null || !accept.contains(MediaType.APPLICATION_JSON_VALUE)) {
306+
badRequestErrors.add("application/json required in Accept header");
307+
}
308+
255309
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
256310
if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest
257311
&& jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) {
312+
if (!badRequestErrors.isEmpty()) {
313+
String combinedMessage = String.join("; ", badRequestErrors);
314+
try {
315+
String errorJson = objectMapper.writeValueAsString(new McpError(combinedMessage));
316+
return ServerResponse.badRequest().bodyValue(errorJson);
317+
}
318+
catch (JsonProcessingException e) {
319+
logger.debug("Failed to serialize McpError: {}", e);
320+
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
321+
.bodyValue("Failed to serialize error message.");
322+
}
323+
}
258324
McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(),
259325
new TypeReference<McpSchema.InitializeRequest>() {
260326
});
@@ -274,15 +340,29 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
274340
})
275341
.flatMap(initResult -> ServerResponse.ok()
276342
.contentType(MediaType.APPLICATION_JSON)
277-
.header("mcp-session-id", init.session().getId())
343+
.header(MCP_SESSION_ID, init.session().getId())
278344
.bodyValue(initResult));
279345
}
280346

281-
if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) {
282-
return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing"));
347+
String sessionId = request.headers().asHttpHeaders().getFirst(MCP_SESSION_ID);
348+
349+
if (sessionId == null || sessionId.isBlank()) {
350+
badRequestErrors.add("Session ID required in mcp-session-id header");
351+
}
352+
353+
if (!badRequestErrors.isEmpty()) {
354+
String combinedMessage = String.join("; ", badRequestErrors);
355+
try {
356+
String errorJson = objectMapper.writeValueAsString(new McpError(combinedMessage));
357+
return ServerResponse.badRequest().bodyValue(errorJson);
358+
}
359+
catch (JsonProcessingException e) {
360+
logger.debug("Failed to serialize McpError: {}", e);
361+
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
362+
.bodyValue("Failed to serialize error message.");
363+
}
283364
}
284365

285-
String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id");
286366
McpStreamableServerSession session = sessions.get(sessionId);
287367

288368
if (session == null) {
@@ -330,7 +410,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
330410
McpTransportContext transportContext = this.contextExtractor.apply(request);
331411

332412
return Mono.defer(() -> {
333-
if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) {
413+
if (!request.headers().asHttpHeaders().containsKey(MCP_SESSION_ID)) {
334414
return ServerResponse.badRequest().build(); // TODO: say we need a session
335415
// id
336416
}
@@ -340,7 +420,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
340420
return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
341421
}
342422

343-
String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id");
423+
String sessionId = request.headers().asHttpHeaders().getFirst(MCP_SESSION_ID);
344424

345425
McpStreamableServerSession session = this.sessions.get(sessionId);
346426

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,26 @@
66

77
import java.io.IOException;
88
import java.time.Duration;
9+
import java.util.ArrayList;
10+
import java.util.List;
911
import java.util.concurrent.ConcurrentHashMap;
1012
import java.util.concurrent.locks.ReentrantLock;
1113
import java.util.function.Function;
1214

15+
import org.slf4j.Logger;
16+
import org.slf4j.LoggerFactory;
17+
import org.springframework.http.HttpStatus;
18+
import org.springframework.http.MediaType;
19+
import org.springframework.web.servlet.function.RouterFunction;
20+
import org.springframework.web.servlet.function.RouterFunctions;
21+
import org.springframework.web.servlet.function.ServerRequest;
22+
import org.springframework.web.servlet.function.ServerResponse;
23+
import org.springframework.web.servlet.function.ServerResponse.SseBuilder;
24+
25+
import com.fasterxml.jackson.core.JsonProcessingException;
1326
import com.fasterxml.jackson.core.type.TypeReference;
1427
import com.fasterxml.jackson.databind.ObjectMapper;
28+
1529
import io.modelcontextprotocol.spec.DefaultMcpTransportContext;
1630
import io.modelcontextprotocol.spec.McpError;
1731
import io.modelcontextprotocol.spec.McpSchema;
@@ -20,18 +34,8 @@
2034
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
2135
import io.modelcontextprotocol.spec.McpTransportContext;
2236
import io.modelcontextprotocol.util.Assert;
23-
import org.slf4j.Logger;
24-
import org.slf4j.LoggerFactory;
2537
import reactor.core.publisher.Mono;
2638

27-
import org.springframework.http.HttpStatus;
28-
import org.springframework.http.MediaType;
29-
import org.springframework.web.servlet.function.RouterFunction;
30-
import org.springframework.web.servlet.function.RouterFunctions;
31-
import org.springframework.web.servlet.function.ServerRequest;
32-
import org.springframework.web.servlet.function.ServerResponse;
33-
import org.springframework.web.servlet.function.ServerResponse.SseBuilder;
34-
3539
/**
3640
* Server-side implementation of the Model Context Protocol (MCP) streamable transport
3741
* layer using HTTP with Server-Sent Events (SSE) through Spring WebMVC. This
@@ -44,6 +48,7 @@
4448
*
4549
* @author Christian Tzolov
4650
* @author Dariusz Jędrzejczyk
51+
* @author Zachary German
4752
* @see McpStreamableServerTransportProvider
4853
* @see RouterFunction
4954
*/
@@ -69,7 +74,12 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer
6974
/**
7075
* Header name for the last message ID used in replay requests.
7176
*/
72-
private static final String MCP_LAST_ID = "mcp-last-id";
77+
private static final String LAST_EVENT_ID = "Last-Event-ID";
78+
79+
/**
80+
* Header name for the response media types accepted by the requester.
81+
*/
82+
private static final String ACCEPT = "Accept";
7383

7484
/**
7585
* Default base URL for the message endpoint.
@@ -216,11 +226,32 @@ private ServerResponse handleGet(ServerRequest request) {
216226

217227
McpTransportContext transportContext = this.contextExtractor.apply(request);
218228

219-
if (!request.headers().asHttpHeaders().containsKey(MCP_SESSION_ID)) {
220-
return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");
229+
List<String> badRequestErrors = new ArrayList<>();
230+
231+
String accept = request.headers().asHttpHeaders().getFirst(ACCEPT);
232+
if (accept == null || !accept.contains(MediaType.TEXT_EVENT_STREAM_VALUE)) {
233+
badRequestErrors.add("text/event-stream required in Accept header");
221234
}
222235

223236
String sessionId = request.headers().asHttpHeaders().getFirst(MCP_SESSION_ID);
237+
238+
if (sessionId == null || sessionId.isBlank()) {
239+
badRequestErrors.add("Session ID required in mcp-session-id header");
240+
}
241+
242+
if (!badRequestErrors.isEmpty()) {
243+
String combinedMessage = String.join("; ", badRequestErrors);
244+
try {
245+
String errorJson = objectMapper.writeValueAsString(new McpError(combinedMessage));
246+
return ServerResponse.badRequest().body(errorJson);
247+
}
248+
catch (JsonProcessingException e) {
249+
logger.debug("Failed to serialize McpError: {}", e);
250+
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
251+
.body("Failed to serialize error message.");
252+
}
253+
}
254+
224255
McpStreamableServerSession session = this.sessions.get(sessionId);
225256

226257
if (session == null) {
@@ -239,8 +270,8 @@ private ServerResponse handleGet(ServerRequest request) {
239270
sessionId, sseBuilder);
240271

241272
// Check if this is a replay request
242-
if (request.headers().asHttpHeaders().containsKey(MCP_LAST_ID)) {
243-
String lastId = request.headers().asHttpHeaders().getFirst(MCP_LAST_ID);
273+
if (request.headers().asHttpHeaders().containsKey(LAST_EVENT_ID)) {
274+
String lastId = request.headers().asHttpHeaders().getFirst(LAST_EVENT_ID);
244275

245276
try {
246277
session.replay(lastId)
@@ -294,12 +325,35 @@ private ServerResponse handlePost(ServerRequest request) {
294325
McpTransportContext transportContext = this.contextExtractor.apply(request);
295326

296327
try {
328+
List<String> badRequestErrors = new ArrayList<>();
329+
330+
String accept = request.headers().asHttpHeaders().getFirst(ACCEPT);
331+
if (accept == null || !accept.contains(MediaType.TEXT_EVENT_STREAM_VALUE)) {
332+
badRequestErrors.add("text/event-stream required in Accept header");
333+
}
334+
if (accept == null || !accept.contains(MediaType.APPLICATION_JSON_VALUE)) {
335+
badRequestErrors.add("application/json required in Accept header");
336+
}
337+
297338
String body = request.body(String.class);
298339
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
299340

300341
// Handle initialization request
301342
if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest
302343
&& jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) {
344+
if (!badRequestErrors.isEmpty()) {
345+
String combinedMessage = String.join("; ", badRequestErrors);
346+
try {
347+
String errorJson = objectMapper.writeValueAsString(new McpError(combinedMessage));
348+
return ServerResponse.badRequest().body(errorJson);
349+
}
350+
catch (JsonProcessingException e) {
351+
logger.debug("Failed to serialize McpError: {}", e);
352+
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
353+
.body("Failed to serialize error message.");
354+
}
355+
}
356+
303357
McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(),
304358
new TypeReference<McpSchema.InitializeRequest>() {
305359
});
@@ -323,11 +377,25 @@ private ServerResponse handlePost(ServerRequest request) {
323377
}
324378

325379
// Handle other messages that require a session
326-
if (!request.headers().asHttpHeaders().containsKey(MCP_SESSION_ID)) {
327-
return ServerResponse.badRequest().body(new McpError("Session ID missing"));
380+
String sessionId = request.headers().asHttpHeaders().getFirst(MCP_SESSION_ID);
381+
382+
if (sessionId == null || sessionId.isBlank()) {
383+
badRequestErrors.add("Session ID required in mcp-session-id header");
384+
}
385+
386+
if (!badRequestErrors.isEmpty()) {
387+
String combinedMessage = String.join("; ", badRequestErrors);
388+
try {
389+
String errorJson = objectMapper.writeValueAsString(new McpError(combinedMessage));
390+
return ServerResponse.badRequest().body(errorJson);
391+
}
392+
catch (JsonProcessingException e) {
393+
logger.debug("Failed to serialize McpError: {}", e);
394+
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
395+
.body("Failed to serialize error message.");
396+
}
328397
}
329398

330-
String sessionId = request.headers().asHttpHeaders().getFirst(MCP_SESSION_ID);
331399
McpStreamableServerSession session = this.sessions.get(sessionId);
332400

333401
if (session == null) {

0 commit comments

Comments
 (0)