diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 40767f416..20a953dbb 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -4,16 +4,8 @@ package io.modelcontextprotocol.server.transport; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.PrintWriter; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.modelcontextprotocol.json.McpJsonMapper; - import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.server.McpStatelessServerHandler; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; @@ -25,8 +17,14 @@ import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; + /** * Implementation of an HttpServlet based {@link McpStatelessServerTransport}. * @@ -52,12 +50,15 @@ public class HttpServletStatelessServerTransport extends HttpServlet implements private final String mcpEndpoint; - private McpStatelessServerHandler mcpHandler; + private volatile McpStatelessServerHandler mcpHandler; - private McpTransportContextExtractor contextExtractor; + private final McpTransportContextExtractor contextExtractor; private volatile boolean isClosing = false; + private volatile GetHandler getHandler = (request, response) -> response + .sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + private HttpServletStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); @@ -74,11 +75,28 @@ public void setMcpHandler(McpStatelessServerHandler mcpHandler) { this.mcpHandler = mcpHandler; } + @Override + public McpStatelessServerHandler getMcpHandler() { + return mcpHandler; + } + @Override public Mono closeGracefully() { return Mono.fromRunnable(() -> this.isClosing = true); } + public void setGetHandler(GetHandler getHandler) { + Assert.notNull(getHandler, "getHandler must not be null"); + + this.getHandler = getHandler; + } + + public interface GetHandler { + + void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException; + + } + /** * Handles GET requests - returns 405 METHOD NOT ALLOWED as stateless transport * doesn't support GET requests. @@ -97,7 +115,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } - response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + getHandler.doGet(request, response); } /** diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java index d1c2e5206..b64a05c2a 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -4,15 +4,17 @@ package io.modelcontextprotocol.spec; -import java.util.List; - import io.modelcontextprotocol.server.McpStatelessServerHandler; import reactor.core.publisher.Mono; +import java.util.List; + public interface McpStatelessServerTransport { void setMcpHandler(McpStatelessServerHandler mcpHandler); + McpStatelessServerHandler getMcpHandler(); + /** * Immediately closes all the transports with connected clients and releases any * associated resources. diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java index 400be341e..75fe4950b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java @@ -4,8 +4,8 @@ package io.modelcontextprotocol.server.transport; -import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.server.McpStatelessServerHandler; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; @@ -40,12 +40,14 @@ public class WebFluxStatelessServerTransport implements McpStatelessServerTransp private final RouterFunction routerFunction; - private McpStatelessServerHandler mcpHandler; + private volatile McpStatelessServerHandler mcpHandler; - private McpTransportContextExtractor contextExtractor; + private final McpTransportContextExtractor contextExtractor; private volatile boolean isClosing = false; + private volatile GetHandler getHandler = (request) -> ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); @@ -66,6 +68,11 @@ public void setMcpHandler(McpStatelessServerHandler mcpHandler) { this.mcpHandler = mcpHandler; } + @Override + public McpStatelessServerHandler getMcpHandler() { + return mcpHandler; + } + @Override public Mono closeGracefully() { return Mono.fromRunnable(() -> this.isClosing = true); @@ -87,8 +94,20 @@ public RouterFunction getRouterFunction() { return this.routerFunction; } + public interface GetHandler { + + Mono doGet(ServerRequest request); + + } + + public void setGetHandler(GetHandler getHandler) { + Assert.notNull(getHandler, "getHandler must not be null"); + + this.getHandler = getHandler; + } + private Mono handleGet(ServerRequest request) { - return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + return getHandler.doGet(request); } private Mono handlePost(ServerRequest request) { diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java index 4223084ff..055885bb9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java @@ -44,12 +44,14 @@ public class WebMvcStatelessServerTransport implements McpStatelessServerTranspo private final RouterFunction routerFunction; - private McpStatelessServerHandler mcpHandler; + private volatile McpStatelessServerHandler mcpHandler; - private McpTransportContextExtractor contextExtractor; + private final McpTransportContextExtractor contextExtractor; private volatile boolean isClosing = false; + private volatile GetHandler getHandler = (request) -> ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + private WebMvcStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); @@ -70,6 +72,11 @@ public void setMcpHandler(McpStatelessServerHandler mcpHandler) { this.mcpHandler = mcpHandler; } + @Override + public McpStatelessServerHandler getMcpHandler() { + return mcpHandler; + } + @Override public Mono closeGracefully() { return Mono.fromRunnable(() -> this.isClosing = true); @@ -91,8 +98,20 @@ public RouterFunction getRouterFunction() { return this.routerFunction; } + public interface GetHandler { + + ServerResponse doGet(ServerRequest request); + + } + + public void setGetHandler(GetHandler getHandler) { + Assert.notNull(getHandler, "getHandler must not be null"); + + this.getHandler = getHandler; + } + private ServerResponse handleGet(ServerRequest request) { - return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + return getHandler.doGet(request); } private ServerResponse handlePost(ServerRequest request) {