diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index ae093316f..601131a29 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -16,6 +16,8 @@ import java.util.function.Consumer; import java.util.function.Function; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport.Builder; +import io.modelcontextprotocol.client.transport.ResponseSubscribers.SseResponseEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; @@ -116,6 +118,8 @@ public class HttpClientSseClientTransport implements McpClientTransport { */ private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; + private final AtomicReference> connectionClosedHandler = new AtomicReference<>(); + /** * Creates a new transport instance with custom HTTP client builder, object mapper, * and headers. @@ -129,7 +133,8 @@ public class HttpClientSseClientTransport implements McpClientTransport { * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null */ HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) { + String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, + Consumer connectionClosedHandler) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); @@ -142,6 +147,7 @@ public class HttpClientSseClientTransport implements McpClientTransport { this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.httpRequestCustomizer = httpRequestCustomizer; + this.connectionClosedHandler.set(connectionClosedHandler); } @Override @@ -149,6 +155,20 @@ public List protocolVersions() { return List.of(ProtocolVersions.MCP_2024_11_05); } + @Override + public void setConnectionClosedHandler(Consumer closedHandler) { + logger.debug("Connection closed handler registered"); + connectionClosedHandler.set(closedHandler); + } + + private void handleConnectionClosed() { + logger.debug("Handling connection closed"); + Consumer handler = this.connectionClosedHandler.get(); + if (handler != null) { + handler.accept(null); + } + } + /** * Creates a new builder for {@link HttpClientSseClientTransport}. * @param baseUri the base URI of the MCP server @@ -177,6 +197,8 @@ public static class Builder { private Duration connectTimeout = Duration.ofSeconds(10); + private Consumer connectionClosedHandler = null; + /** * Creates a new builder instance. */ @@ -320,6 +342,17 @@ public Builder connectTimeout(Duration connectTimeout) { return this; } + /** + * Set the connection closed handler. + * @param connectionClosedHandler the connection closed handler + * @return this builder + */ + public Builder connectionClosedHandler(Consumer connectionClosedHandler) { + Assert.notNull(connectionClosedHandler, "connectionClosedHandler must not be null"); + this.connectionClosedHandler = connectionClosedHandler; + return this; + } + /** * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance @@ -327,7 +360,8 @@ public Builder connectTimeout(Duration connectTimeout) { public HttpClientSseClientTransport build() { HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint, - jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, httpRequestCustomizer); + jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, httpRequestCustomizer, + connectionClosedHandler); } } @@ -352,9 +386,7 @@ public Mono connect(Function, Mono> h .exceptionallyCompose(e -> { sseSink.error(e); return CompletableFuture.failedFuture(e); - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { + })).map(responseEvent -> (SseResponseEvent) responseEvent).flatMap(responseEvent -> { if (isClosing) { return Mono.empty(); } @@ -388,26 +420,21 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { sink.error(new McpTransportException("Error processing SSE event", e)); } } - return Flux.error( - new RuntimeException("Failed to send message: " + responseEvent)); + return Flux.error(new RuntimeException("Failed to send message: " + responseEvent)); - }) - .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) - .onErrorComplete(t -> { + }).flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))).onErrorComplete(t -> { if (!isClosing) { logger.warn("SSE stream observed an error", t); sink.error(t); } return true; - }) - .doFinally(s -> { + }).doFinally(s -> { Disposable ref = this.sseSubscription.getAndSet(null); if (ref != null && !ref.isDisposed()) { ref.dispose(); } - }) - .contextWrite(sink.contextView()) - .subscribe(); + handleConnectionClosed(); + }).contextWrite(sink.contextView()).subscribe(); this.sseSubscription.set(connection); })); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index cd8fa171f..1ba463437 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -4,32 +4,13 @@ package io.modelcontextprotocol.client.transport; -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.net.http.HttpResponse.BodyHandler; -import java.time.Duration; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.CompletionException; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; - -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.modelcontextprotocol.json.TypeRef; -import io.modelcontextprotocol.json.McpJsonMapper; - +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; -import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.ClosedMcpTransportSession; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.DefaultMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportStream; import io.modelcontextprotocol.spec.HttpHeaders; @@ -42,6 +23,9 @@ import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; @@ -49,6 +33,20 @@ import reactor.util.function.Tuple2; import reactor.util.function.Tuples; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + /** * An implementation of the Streamable HTTP protocol as defined by the * 2025-03-26 version of the MCP specification. @@ -125,9 +123,12 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final AtomicReference> exceptionHandler = new AtomicReference<>(); + private final AtomicReference> connectionClosedHandler = new AtomicReference<>(); + private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, - boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) { + boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, + Consumer connectionClosedHandler) { this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; @@ -137,6 +138,7 @@ private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient h this.openConnectionOnStartup = openConnectionOnStartup; this.activeSession.set(createTransportSession()); this.httpRequestCustomizer = httpRequestCustomizer; + this.connectionClosedHandler.set(connectionClosedHandler); } @Override @@ -202,6 +204,12 @@ public void setExceptionHandler(Consumer handler) { this.exceptionHandler.set(handler); } + @Override + public void setConnectionClosedHandler(Consumer closedHandler) { + logger.debug("Connection closed handler registered"); + this.connectionClosedHandler.set(closedHandler); + } + private void handleException(Throwable t) { logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); if (t instanceof McpTransportSessionNotFoundException) { @@ -215,6 +223,14 @@ private void handleException(Throwable t) { } } + private void handleConnectionClosed() { + logger.debug("Handling connection closed for session {}", sessionIdOrPlaceholder(this.activeSession.get())); + Consumer handler = this.connectionClosedHandler.get(); + if (handler != null) { + handler.accept(null); + } + } + @Override public Mono closeGracefully() { return Mono.defer(() -> { @@ -365,6 +381,7 @@ else if (statusCode == BAD_REQUEST) { if (ref != null) { transportSession.removeConnection(ref); } + this.handleConnectionClosed(); })) .contextWrite(ctx) .subscribe(); @@ -624,6 +641,8 @@ public static class Builder { private Duration connectTimeout = Duration.ofSeconds(10); + private Consumer connectionClosedHandler = null; + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -772,6 +791,17 @@ public Builder connectTimeout(Duration connectTimeout) { return this; } + /** + * Set the connection closed handler. + * @param connectionClosedHandler the connection closed handler + * @return this builder + */ + public Builder connectionClosedHandler(Consumer connectionClosedHandler) { + Assert.notNull(connectionClosedHandler, "connectionClosedHandler must not be null"); + this.connectionClosedHandler = connectionClosedHandler; + return this; + } + /** * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using * the current builder configuration. @@ -781,7 +811,7 @@ public HttpClientStreamableHttpTransport build() { HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup, - httpRequestCustomizer); + httpRequestCustomizer, connectionClosedHandler); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 22aec831b..495c59310 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -38,4 +38,12 @@ public interface McpClientTransport extends McpTransport { default void setExceptionHandler(Consumer handler) { } + /** + * Sets the handler for the transport closed event. + * @param closedHandler Allows reacting to transport closed event by the higher layers + */ + default void setConnectionClosedHandler(Consumer closedHandler) { + + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index c5c365798..816d4eb9c 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -4,37 +4,36 @@ package io.modelcontextprotocol.client.transport; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; - import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.mockito.ArgumentCaptor; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.util.UriComponentsBuilder; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.util.UriComponentsBuilder; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; @@ -78,7 +77,8 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo public TestHttpClientSseClientTransport(final String baseUri) { super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", JSON_MAPPER, - McpAsyncHttpClientRequestCustomizer.NOOP); + McpAsyncHttpClientRequestCustomizer.NOOP, v -> { + }); } public int getInboundMessageCount() { @@ -130,7 +130,8 @@ void testErrorOnBogusMessage() { StepVerifier.create(transport.sendMessage(bogusMessage)) .verifyErrorMessage( - "Sending message failed with a non-OK HTTP code: 400 - Invalid message: {\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"); + "Sending message failed with a non-OK HTTP code: 400 - Invalid message: {\"id\":\"test-id\"," + + "\"params\":{\"key\":\"value\"}}"); } @Test @@ -477,4 +478,14 @@ void testAsyncRequestCustomizer() { customizedTransport.closeGracefully().block(); } + @Test + void testTransportConnectionClosedHandler() { + AtomicReference closedHandlerCalled = new AtomicReference<>(false); + transport.setConnectionClosedHandler(v -> closedHandlerCalled.set(true)); + // transport close simulate the behavior of disconnection + transport.closeGracefully().block(); + // Verify the closed handler was called + Assertions.assertTrue(closedHandlerCalled.get()); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportConnectionClosedHandlingTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportConnectionClosedHandlingTest.java new file mode 100644 index 000000000..97071540d --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportConnectionClosedHandlingTest.java @@ -0,0 +1,135 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import com.sun.net.httpserver.HttpServer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @author taobaorun + */ +public class HttpClientStreamableHttpTransportConnectionClosedHandlingTest { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private AtomicReference serverResponseStatus = new AtomicReference<>(200); + + private AtomicReference currentServerSessionId = new AtomicReference<>(null); + + private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + + private McpClientTransport transport; + + private McpTransportContext context = McpTransportContext + .create(Map.of("test-transport-context-key", "some-value")); + + @BeforeEach + void startServer() throws IOException { + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Configure the /mcp endpoint with dynamic response + server.createContext("/mcp", httpExchange -> { + if ("DELETE".equals(httpExchange.getRequestMethod())) { + httpExchange.sendResponseHeaders(200, 0); + } + else if ("POST".equals(httpExchange.getRequestMethod())) { + // Capture session ID from request if present + String requestSessionId = httpExchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + lastReceivedSessionId.set(requestSessionId); + + int status = serverResponseStatus.get(); + + // Set response headers + httpExchange.getResponseHeaders().set("Content-Type", "application/json"); + + // Add session ID to response if configured + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + httpExchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + + // Send response based on configured status + if (status == 200) { + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + httpExchange.sendResponseHeaders(200, response.length()); + httpExchange.getResponseBody().write(response.getBytes()); + } + else { + httpExchange.sendResponseHeaders(status, 0); + } + } + httpExchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = HttpClientStreamableHttpTransport.builder(HOST).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + } + + @Test + void testTransportConnectionClosedHandler() throws URISyntaxException { + AtomicReference closedHandlerCalled = new AtomicReference<>(false); + var transport = HttpClientStreamableHttpTransport.builder(HOST) + .connectionClosedHandler(v -> closedHandlerCalled.set(true)) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_06_18, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier + .create(t.sendMessage(testMessage).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); + + // close transport + transport.closeGracefully().subscribe(); + assertTrue(closedHandlerCalled.get()); + + }); + } + + void withTransport(HttpClientStreamableHttpTransport transport, Consumer c) { + try { + c.accept(transport); + } + finally { + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + } + +}