diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 6d8e82f51..853aed2bf 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -31,6 +31,7 @@ import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; import io.modelcontextprotocol.spec.McpTransportSession; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportStream; @@ -70,6 +71,8 @@ */ public class WebClientStreamableHttpTransport implements McpClientTransport { + private static final String MISSING_SESSION_ID = "[missing_session_id]"; + private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; @@ -221,8 +224,13 @@ else if (isNotAllowed(response)) { return Flux.empty(); } else if (isNotFound(response)) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - return mcpSessionNotFoundError(sessionIdRepresentation); + if (transportSession.sessionId().isPresent()) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + return mcpSessionNotFoundError(sessionIdRepresentation); + } + else { + return this.extractError(response, MISSING_SESSION_ID); + } } else { return response.createError().doOnError(e -> { @@ -318,10 +326,10 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { } } else { - if (isNotFound(response)) { + if (isNotFound(response) && !sessionRepresentation.equals(MISSING_SESSION_ID)) { return mcpSessionNotFoundError(sessionRepresentation); } - return extractError(response, sessionRepresentation); + return this.extractError(response, sessionRepresentation); } }) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) @@ -362,10 +370,10 @@ private Flux extractError(ClientResponse response, Str McpSchema.JSONRPCResponse.class); jsonRpcError = jsonRpcResponse.error(); toPropagate = jsonRpcError != null ? new McpError(jsonRpcError) - : new McpError("Can't parse the jsonResponse " + jsonRpcResponse); + : new McpTransportException("Can't parse the jsonResponse " + jsonRpcResponse); } catch (IOException ex) { - toPropagate = new RuntimeException("Sending request failed", e); + toPropagate = new McpTransportException("Sending request failed, " + e.getMessage(), e); logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); } @@ -374,7 +382,11 @@ private Flux extractError(ClientResponse response, Str // invalidate the session // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { - return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + if (!sessionRepresentation.equals(MISSING_SESSION_ID)) { + return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + } + return Mono.error(new McpTransportException("Received 400 BAD REQUEST for session " + + sessionRepresentation + ". " + toPropagate.getMessage(), toPropagate)); } return Mono.error(toPropagate); }).flux(); @@ -403,7 +415,7 @@ private static boolean isEventStream(ClientResponse response) { } private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { - return transportSession.sessionId().orElse("[missing_session_id]"); + return transportSession.sessionId().orElse(MISSING_SESSION_ID); } private Flux directResponseFlux(McpSchema.JSONRPCMessage sentMessage, @@ -421,8 +433,7 @@ private Flux directResponseFlux(McpSchema.JSONRPCMessa } } catch (IOException e) { - // TODO: this should be a McpTransportError - s.error(e); + s.error(new McpTransportException(e)); } }).flatMapIterable(Function.identity()); } @@ -449,7 +460,7 @@ private Tuple2, Iterable> parse(Serve return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); } catch (IOException ioException) { - throw new McpError("Error parsing JSON-RPC message: " + event.data()); + throw new McpTransportException("Error parsing JSON-RPC message: " + event.data(), ioException); } } else { diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 75caebef0..51d21d18b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -14,7 +14,6 @@ import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.spec.ProtocolVersions; @@ -197,8 +196,6 @@ public List protocolVersions() { * @param handler a function that processes incoming JSON-RPC messages and returns * responses * @return a Mono that completes when the connection is fully established - * @throws McpError if there's an error processing SSE events or if an unrecognized - * event type is received */ @Override public Mono connect(Function, Mono> handler) { @@ -215,7 +212,7 @@ public Mono connect(Function, Mono> h else { // TODO: clarify with the spec if multiple events can be // received - s.error(new McpError("Failed to handle SSE endpoint event")); + s.error(new RuntimeException("Failed to handle SSE endpoint event")); } } else if (MESSAGE_EVENT_TYPE.equals(event.event())) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index a1f1a8947..6140fe489 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -8,6 +8,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; @@ -26,6 +27,7 @@ import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; +@Timeout(15) class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java index 302c58c5f..5516e55b7 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java @@ -8,6 +8,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; @@ -26,6 +27,7 @@ import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; +@Timeout(15) class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java index 616c6dcf8..9eba0e57c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -8,6 +8,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; @@ -26,6 +27,7 @@ import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; +@Timeout(15) class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java new file mode 100644 index 000000000..cdbb97e17 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java @@ -0,0 +1,404 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for error handling in WebClientStreamableHttpTransport. Addresses concurrency + * issues with proper Reactor patterns. + * + * @author Christian Tzolov + */ +@Timeout(15) +public class WebClientStreamableHttpTransportErrorHandlingTest { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private AtomicReference serverResponseStatus = new AtomicReference<>(200); + + private AtomicReference currentServerSessionId = new AtomicReference<>(null); + + private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + + private McpClientTransport transport; + + // Initialize latches for proper request synchronization + CountDownLatch firstRequestLatch; + + CountDownLatch secondRequestLatch; + + CountDownLatch getRequestLatch; + + @BeforeEach + void startServer() throws IOException { + + // Initialize latches for proper synchronization + firstRequestLatch = new CountDownLatch(1); + secondRequestLatch = new CountDownLatch(1); + getRequestLatch = new CountDownLatch(1); + + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Configure the /mcp endpoint with dynamic response + server.createContext("/mcp", exchange -> { + String method = exchange.getRequestMethod(); + + if ("GET".equals(method)) { + // This is the SSE connection attempt after session establishment + getRequestLatch.countDown(); + // Return 405 Method Not Allowed to indicate SSE not supported + exchange.sendResponseHeaders(405, 0); + exchange.close(); + return; + } + + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + lastReceivedSessionId.set(requestSessionId); + + int status = serverResponseStatus.get(); + + // Track which request this is + if (firstRequestLatch.getCount() > 0) { + // // First request - should have no session ID + firstRequestLatch.countDown(); + } + else if (secondRequestLatch.getCount() > 0) { + // Second request - should have session ID + secondRequestLatch.countDown(); + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + + // Don't include session ID in 404 and 400 responses - the implementation + // checks if the transport has a session stored locally + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null && status == 200) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + if (status == 200) { + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + else { + exchange.sendResponseHeaders(status, 0); + } + exchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 404 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test404WithoutSessionId() { + serverResponseStatus.set(404); + currentServerSessionId.set(null); // No session ID in response + + var testMessage = createTestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Not Found") && throwable.getMessage().contains("404") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(Duration.ofSeconds(5)); + } + + /** + * Test that 404 response WITH session ID throws McpTransportSessionNotFoundException + * Fixed version using proper async coordination + */ + @Test + void test404WithSessionId() throws InterruptedException { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-123"); + + // Set up exception handler to verify session invalidation + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Send first message to establish session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Wait for first request to complete + assertThat(firstRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Wait for the GET request (SSE connection attempt) to complete + assertThat(getRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Now return 404 for next request + serverResponseStatus.set(404); + + // Use delaySubscription to ensure session is fully processed before next + // request + StepVerifier.create(Mono.delay(Duration.ofMillis(200)).then(transport.sendMessage(testMessage))) + .expectError(McpTransportSessionNotFoundException.class) + .verify(Duration.ofSeconds(5)); + + // Wait for second request to be made + assertThat(secondRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Verify the second request included the session ID + assertThat(lastReceivedSessionId.get()).isEqualTo("test-session-123"); + + // Verify exception handler was called with SessionNotFoundException using + // timeout + verify(exceptionHandler, timeout(5000)).accept(any(McpTransportSessionNotFoundException.class)); + } + + /** + * Test that 400 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test400WithoutSessionId() { + serverResponseStatus.set(400); + currentServerSessionId.set(null); // No session ID + + var testMessage = createTestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Bad Request") && throwable.getMessage().contains("400") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(Duration.ofSeconds(5)); + } + + /** + * Test that 400 response WITH session ID throws McpTransportSessionNotFoundException + * Fixed version using proper async coordination + */ + @Test + void test400WithSessionId() throws InterruptedException { + + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-456"); + + // Set up exception handler + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Send first message to establish session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Wait for first request to complete + boolean firstCompleted = firstRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(firstCompleted).isTrue(); + + // Wait for the GET request (SSE connection attempt) to complete + boolean getCompleted = getRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(getCompleted).isTrue(); + + // Now return 400 for next request (simulating unknown session ID) + serverResponseStatus.set(400); + + // Use delaySubscription to ensure session is fully processed before next + // request + StepVerifier.create(Mono.delay(Duration.ofMillis(200)).then(transport.sendMessage(testMessage))) + .expectError(McpTransportSessionNotFoundException.class) + .verify(Duration.ofSeconds(5)); + + // Wait for second request to be made + boolean secondCompleted = secondRequestLatch.await(5, TimeUnit.SECONDS); + assertThat(secondCompleted).isTrue(); + + // Verify the second request included the session ID + assertThat(lastReceivedSessionId.get()).isEqualTo("test-session-456"); + + // Verify exception handler was called with timeout + verify(exceptionHandler, timeout(5000)).accept(any(McpTransportSessionNotFoundException.class)); + } + + /** + * Test session recovery after SessionNotFoundException Fixed version using reactive + * patterns and proper synchronization + */ + @Test + void testSessionRecoveryAfter404() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("session-1"); + + // Send initial message to establish session + var testMessage = createTestMessage(); + + // Use Mono.defer to ensure proper sequencing + Mono establishSession = transport.sendMessage(testMessage).then(Mono.defer(() -> { + // Simulate session loss - return 404 + serverResponseStatus.set(404); + return transport.sendMessage(testMessage).onErrorResume(McpTransportSessionNotFoundException.class, e -> { + // Expected error, continue with recovery + return Mono.empty(); + }); + })).then(Mono.defer(() -> { + // Now server is back with new session + serverResponseStatus.set(200); + currentServerSessionId.set("session-2"); + lastReceivedSessionId.set(null); // Reset to verify new session + + // Should be able to establish new session + return transport.sendMessage(testMessage); + })).then(Mono.defer(() -> { + // Verify no session ID was sent (since old session was invalidated) + assertThat(lastReceivedSessionId.get()).isNull(); + + // Next request should use the new session ID + return transport.sendMessage(testMessage); + })).doOnSuccess(v -> { + // Session ID should now be sent with requests + assertThat(lastReceivedSessionId.get()).isEqualTo("session-2"); + }); + + StepVerifier.create(establishSession).verifyComplete(); + } + + /** + * Test that reconnect (GET request) also properly handles 404/400 errors Fixed + * version with proper async handling + */ + @Test + void testReconnectErrorHandling() throws InterruptedException { + // Initialize latch for SSE connection + CountDownLatch sseConnectionLatch = new CountDownLatch(1); + + // Set up SSE endpoint for GET requests + server.createContext("/mcp-sse", exchange -> { + String method = exchange.getRequestMethod(); + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + if ("GET".equals(method)) { + sseConnectionLatch.countDown(); + int status = serverResponseStatus.get(); + + if (status == 404 && requestSessionId != null) { + // 404 with session ID - should trigger SessionNotFoundException + exchange.sendResponseHeaders(404, 0); + } + else if (status == 404) { + // 404 without session ID - should trigger McpTransportException + exchange.sendResponseHeaders(404, 0); + } + else { + // Normal SSE response + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, 0); + // Send a test SSE event + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + exchange.getResponseBody().write(sseData.getBytes()); + } + } + else { + // POST request handling + exchange.getResponseHeaders().set("Content-Type", "application/json"); + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + exchange.close(); + }); + + // Test with session ID - should get SessionNotFoundException + serverResponseStatus.set(200); + currentServerSessionId.set("sse-session-1"); + + var transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)) + .endpoint("/mcp-sse") + .openConnectionOnStartup(true) // This will trigger GET request on connect + .build(); + + // First connect successfully + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Wait for SSE connection to be established + boolean connected = sseConnectionLatch.await(5, TimeUnit.SECONDS); + assertThat(connected).isTrue(); + + // Send message to establish session + var testMessage = createTestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Clean up + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + private McpSchema.JSONRPCRequest createTestMessage() { + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0")); + return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", + initializeRequest); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 995cbd165..5d048353c 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -11,6 +11,7 @@ import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; @@ -29,6 +30,7 @@ import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import reactor.core.scheduler.Schedulers; +@Timeout(15) class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java index 802363d59..c7c1e710d 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -11,6 +11,7 @@ import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; @@ -29,6 +30,7 @@ import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; import reactor.core.scheduler.Schedulers; +@Timeout(15) class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java index 800065915..16012e7d9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -11,6 +11,7 @@ import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; @@ -29,6 +30,7 @@ import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; import reactor.core.scheduler.Schedulers; +@Timeout(15) class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 067fbac2c..ea3739da5 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -486,7 +486,8 @@ void testAddRoot() { void testAddRootWithNullValue() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.addRoot(null)) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Root must not be null")) .verify(); }); } @@ -505,7 +506,7 @@ void testRemoveRoot() { void testRemoveNonExistentRoot() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalStateException.class) .hasMessage("Root with uri 'nonexistent-uri' not found")) .verify(); }); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java index 2e0b51748..2cc1c5dba 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -289,9 +289,7 @@ public Mono withIntitialization(String actionName, Function this.initializationRef.get()) .timeout(this.initializationTimeout) .onErrorResume(ex -> { - logger.warn("Failed to initialize", ex); - return Mono.error( - new McpError("Client failed to initialize " + actionName + " due to: " + ex.getMessage())); + return Mono.error(new RuntimeException("Client failed to initialize " + actionName, ex)); }) .flatMap(operation); }); @@ -316,8 +314,10 @@ private Mono doInitialize(DefaultInitialization init initializeResult.instructions()); if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { - return Mono.error(new McpError( - "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); + return Mono.error(McpError.builder(-32602) + .message("Unsupported protocol version") + .data("Unsupported protocol version from the server: " + initializeResult.protocolVersion()) + .build()); } return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 0f2ee19fa..228313beb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -190,7 +190,8 @@ public class McpAsyncClient { // Sampling Handler if (this.clientCapabilities.sampling() != null) { if (features.samplingHandler() == null) { - throw new McpError("Sampling handler must not be null when client capabilities include sampling"); + throw new IllegalArgumentException( + "Sampling handler must not be null when client capabilities include sampling"); } this.samplingHandler = features.samplingHandler(); requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); @@ -199,7 +200,8 @@ public class McpAsyncClient { // Elicitation Handler if (this.clientCapabilities.elicitation() != null) { if (features.elicitationHandler() == null) { - throw new McpError("Elicitation handler must not be null when client capabilities include elicitation"); + throw new IllegalArgumentException( + "Elicitation handler must not be null when client capabilities include elicitation"); } this.elicitationHandler = features.elicitationHandler(); requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); @@ -413,15 +415,15 @@ public Mono ping() { public Mono addRoot(Root root) { if (root == null) { - return Mono.error(new McpError("Root must not be null")); + return Mono.error(new IllegalArgumentException("Root must not be null")); } if (this.clientCapabilities.roots() == null) { - return Mono.error(new McpError("Client must be configured with roots capabilities")); + return Mono.error(new IllegalStateException("Client must be configured with roots capabilities")); } if (this.roots.containsKey(root.uri())) { - return Mono.error(new McpError("Root with uri '" + root.uri() + "' already exists")); + return Mono.error(new IllegalStateException("Root with uri '" + root.uri() + "' already exists")); } this.roots.put(root.uri(), root); @@ -447,11 +449,11 @@ public Mono addRoot(Root root) { public Mono removeRoot(String rootUri) { if (rootUri == null) { - return Mono.error(new McpError("Root uri must not be null")); + return Mono.error(new IllegalArgumentException("Root uri must not be null")); } if (this.clientCapabilities.roots() == null) { - return Mono.error(new McpError("Client must be configured with roots capabilities")); + return Mono.error(new IllegalStateException("Client must be configured with roots capabilities")); } Root removed = this.roots.remove(rootUri); @@ -469,7 +471,7 @@ public Mono removeRoot(String rootUri) { } return Mono.empty(); } - return Mono.error(new McpError("Root with uri '" + rootUri + "' not found")); + return Mono.error(new IllegalStateException("Root with uri '" + rootUri + "' not found")); } /** @@ -540,7 +542,7 @@ private RequestHandler elicitationCreateHandler() { public Mono callTool(McpSchema.CallToolRequest callToolRequest) { return this.initializer.withIntitialization("calling tools", init -> { if (init.initializeResult().capabilities().tools() == null) { - return Mono.error(new McpError("Server does not provide tools capability")); + return Mono.error(new IllegalStateException("Server does not provide tools capability")); } return init.mcpSession() .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); @@ -569,7 +571,7 @@ public Mono listTools() { public Mono listTools(String cursor) { return this.initializer.withIntitialization("listing tools", init -> { if (init.initializeResult().capabilities().tools() == null) { - return Mono.error(new McpError("Server does not provide tools capability")); + return Mono.error(new IllegalStateException("Server does not provide tools capability")); } return init.mcpSession() .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), @@ -633,7 +635,7 @@ public Mono listResources() { public Mono listResources(String cursor) { return this.initializer.withIntitialization("listing resources", init -> { if (init.initializeResult().capabilities().resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), @@ -665,7 +667,7 @@ public Mono readResource(McpSchema.Resource resour public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { return this.initializer.withIntitialization("reading resources", init -> { if (init.initializeResult().capabilities().resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, READ_RESOURCE_RESULT_TYPE_REF); @@ -703,7 +705,7 @@ public Mono listResourceTemplates() { public Mono listResourceTemplates(String cursor) { return this.initializer.withIntitialization("listing resource templates", init -> { if (init.initializeResult().capabilities().resources() == null) { - return Mono.error(new McpError("Server does not provide the resources capability")); + return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new McpSchema.PaginatedRequest(cursor), @@ -863,7 +865,7 @@ private NotificationHandler asyncLoggingNotificationHandler( */ public Mono setLoggingLevel(LoggingLevel loggingLevel) { if (loggingLevel == null) { - return Mono.error(new McpError("Logging level must not be null")); + return Mono.error(new IllegalArgumentException("Logging level must not be null")); } return this.initializer.withIntitialization("setting logging level", init -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 473f71fbb..0f3511afb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -24,10 +24,10 @@ import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpTransportException; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import reactor.core.Disposable; @@ -431,7 +431,7 @@ public Mono connect(Function, Mono> h return Flux.empty(); // No further processing needed } else { - sink.error(new McpError("Failed to handle SSE endpoint event")); + sink.error(new RuntimeException("Failed to handle SSE endpoint event")); } } else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { @@ -446,8 +446,7 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { } } catch (IOException e) { - logger.error("Error processing SSE event", e); - sink.error(new McpError("Error processing SSE event")); + sink.error(new McpTransportException("Error processing SSE event", e)); } } return Flux.error( @@ -520,8 +519,7 @@ private Mono serializeMessage(final JSONRPCMessage message) { return Mono.just(objectMapper.writeValueAsString(message)); } catch (IOException e) { - // TODO: why McpError and not RuntimeException? - return Mono.error(new McpError("Failed to serialize message")); + return Mono.error(new McpTransportException("Failed to serialize message", e)); } }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index a9e5897b9..93c28422a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -30,8 +30,8 @@ import io.modelcontextprotocol.spec.DefaultMcpTransportStream; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; import io.modelcontextprotocol.spec.McpTransportSession; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportStream; @@ -288,9 +288,8 @@ private Mono reconnect(McpTransportStream stream) { } catch (IOException ioException) { - return Flux.error( - new McpError("Error parsing JSON-RPC message: " - + responseEvent.sseEvent().data())); + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); } } else { @@ -304,19 +303,39 @@ else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed return Flux.empty(); } else if (statusCode == NOT_FOUND) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and the response is 404, we consider it a + // session not found error. + logger.debug("Session not found for session ID: {}", + transportSession.sessionId().get()); + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Server Not Found. Status code:" + statusCode + + ", response-event:" + responseEvent)); } else if (statusCode == BAD_REQUEST) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and thre response is 404, we consider it a + // session not found error. + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Bad Request. Status code:" + statusCode + + ", response-event:" + responseEvent)); + } - return Flux.error(new McpError( + return Flux.error(new McpTransportException( "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); }).flatMap( @@ -468,8 +487,8 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); } catch (IOException ioException) { - return Flux.error( - new McpError("Error parsing JSON-RPC message: " + sseEvent.data())); + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); } }); } @@ -485,8 +504,8 @@ else if (contentType.contains(APPLICATION_JSON)) { return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); } catch (IOException e) { - // TODO: this should be a McpTransportError - return Mono.error(e); + return Mono.error(new McpTransportException( + "Error deserializing JSON-RPC message: " + responseEvent, e)); } } logger.warn("Unknown media type {} returned for POST in session {}", contentType, @@ -496,18 +515,32 @@ else if (contentType.contains(APPLICATION_JSON)) { new RuntimeException("Unknown media type returned: " + contentType)); } else if (statusCode == NOT_FOUND) { - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionRepresentation); - return Flux.error(exception); + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id and the + // response is 404, we consider it a session not found error. + logger.debug("Session not found for session ID: {}", transportSession.sessionId().get()); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Server Not Found. Status code:" + statusCode + ", response-event:" + responseEvent)); } - // Some implementations can return 400 when presented with a - // session id that it doesn't know about, so we will - // invalidate the session - // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 else if (statusCode == BAD_REQUEST) { - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionRepresentation); - return Flux.error(exception); + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id and the + // response is 404, we consider it a session not found error. + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Bad Request. Status code:" + statusCode + ", response-event:" + responseEvent)); } return Flux.error( diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java index 4d9bdea5d..296d1a17d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -15,7 +15,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpTransportException; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.FluxSink; @@ -178,8 +178,7 @@ else if (line.startsWith(":")) { } else { // If the response is not successful, emit an error - // TODO: This should be a McpTransportError - this.sink.error(new McpError( + this.sink.error(new McpTransportException( "Invalid SSE response. Status code: " + this.responseInfo.statusCode() + " Line: " + line)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java new file mode 100644 index 000000000..cfd3dae31 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java @@ -0,0 +1,38 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +/** + * Exception thrown when there is an issue with the transport layer of the Model Context + * Protocol (MCP). + * + *

+ * This exception is used to indicate errors that occur during communication between the + * MCP client and server, such as connection failures, protocol violations, or unexpected + * responses. + * + * @author Christian Tzolov + */ +public class McpTransportException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + public McpTransportException(String message) { + super(message); + } + + public McpTransportException(String message, Throwable cause) { + super(message, cause); + } + + public McpTransportException(Throwable cause) { + super(cause); + } + + public McpTransportException(String message, Throwable cause, boolean enableSuppression, + boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index e912e1dd6..3626d8ca0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -487,7 +487,8 @@ void testAddRoot() { void testAddRootWithNullValue() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.addRoot(null)) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Root must not be null")) .verify(); }); } @@ -506,7 +507,7 @@ void testRemoveRoot() { void testRemoveNonExistentRoot() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalStateException.class) .hasMessage("Root with uri 'nonexistent-uri' not found")) .verify(); }); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java index c8d691924..02021edbf 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java @@ -16,7 +16,6 @@ import org.mockito.MockitoAnnotations; import io.modelcontextprotocol.spec.McpClientSession; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import reactor.core.publisher.Mono; @@ -154,7 +153,7 @@ void shouldFailForUnsupportedProtocolVersion() { .thenReturn(Mono.just(unsupportedResult)); StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) - .expectError(McpError.class) + .expectError(RuntimeException.class) .verify(); verify(mockClientSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()); @@ -178,7 +177,7 @@ void shouldTimeoutOnSlowInitialization() { init -> Mono.just(init.initializeResult())), () -> virtualTimeScheduler, Long.MAX_VALUE) .expectSubscription() .expectNoEvent(INITIALIZE_TIMEOUT) - .expectError(McpError.class) + .expectError(RuntimeException.class) .verify(); } @@ -234,7 +233,7 @@ void shouldHandleInitializationFailure() { .thenReturn(Mono.error(new RuntimeException("Connection failed"))); StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) - .expectError(McpError.class) + .expectError(RuntimeException.class) .verify(); assertThat(initializer.isInitialized()).isFalse(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index b2fd7fb65..daa6b5e1e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -13,7 +13,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.MockMcpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; @@ -373,7 +372,7 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() { // Create client with sampling capability but null handler assertThatThrownBy( () -> McpClient.async(transport).capabilities(ClientCapabilities.builder().sampling().build()).build()) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } @@ -521,7 +520,7 @@ void testElicitationCreateRequestHandlingWithNullHandler() { // Create client with elicitation capability but null handler assertThatThrownBy(() -> McpClient.async(transport) .capabilities(ClientCapabilities.builder().elicitation().build()) - .build()).isInstanceOf(McpError.class) + .build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Elicitation handler must not be null when client capabilities include elicitation"); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index 36216988f..3feb1d05c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -113,7 +113,7 @@ void shouldFailForUnsupportedVersion() { new McpSchema.InitializeResult(unsupportedVersion, null, new McpSchema.Implementation("test-server", "1.0.0"), null), null)); - }).expectError(McpError.class).verify(); + }).expectError(RuntimeException.class).verify(); } finally { StepVerifier.create(client.closeGracefully()).verifyComplete(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java new file mode 100644 index 000000000..2b502a83b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -0,0 +1,345 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.test.StepVerifier; + +/** + * Tests for error handling changes in HttpClientStreamableHttpTransport. Specifically + * tests the distinction between session-related errors and general transport errors for + * 404 and 400 status codes. + * + * @author Christian Tzolov + */ +@Timeout(15) +public class HttpClientStreamableHttpTransportErrorHandlingTest { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String HOST = "http://localhost:" + PORT; + + private HttpServer server; + + private AtomicReference serverResponseStatus = new AtomicReference<>(200); + + private AtomicReference currentServerSessionId = new AtomicReference<>(null); + + private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + + private McpClientTransport transport; + + @BeforeEach + void startServer() throws IOException { + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Configure the /mcp endpoint with dynamic response + server.createContext("/mcp", httpExchange -> { + if ("DELETE".equals(httpExchange.getRequestMethod())) { + httpExchange.sendResponseHeaders(200, 0); + } + else { + // Capture session ID from request if present + String requestSessionId = httpExchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + lastReceivedSessionId.set(requestSessionId); + + int status = serverResponseStatus.get(); + + // Set response headers + httpExchange.getResponseHeaders().set("Content-Type", "application/json"); + + // Add session ID to response if configured + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + httpExchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + + // Send response based on configured status + if (status == 200) { + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + httpExchange.sendResponseHeaders(200, response.length()); + httpExchange.getResponseBody().write(response.getBytes()); + } + else { + httpExchange.sendResponseHeaders(status, 0); + } + } + httpExchange.close(); + }); + + server.setExecutor(null); + server.start(); + + transport = HttpClientStreamableHttpTransport.builder(HOST).build(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + } + + /** + * Test that 404 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test404WithoutSessionId() { + serverResponseStatus.set(404); + currentServerSessionId.set(null); // No session ID in response + + var testMessage = createTestRequestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Not Found") && throwable.getMessage().contains("404") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 404 response WITH session ID throws McpTransportSessionNotFoundException + */ + @Test + void test404WithSessionId() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-123"); + + // Set up exception handler to verify session invalidation + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // The session should now be established, next request will include session ID + // Now return 404 for next request + serverResponseStatus.set(404); + + // Send another message - should get SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Verify exception handler was called with SessionNotFoundException + verify(exceptionHandler).accept(any(McpTransportSessionNotFoundException.class)); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 400 response WITHOUT session ID throws McpTransportException (not + * SessionNotFoundException) + */ + @Test + void test400WithoutSessionId() { + serverResponseStatus.set(400); + currentServerSessionId.set(null); // No session ID + + var testMessage = createTestRequestMessage(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(throwable -> throwable instanceof McpTransportException + && throwable.getMessage().contains("Bad Request") && throwable.getMessage().contains("400") + && !(throwable instanceof McpTransportSessionNotFoundException)) + .verify(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that 400 response WITH session ID throws McpTransportSessionNotFoundException + * This handles the case mentioned in the code comment about some implementations + * returning 400 for unknown session IDs. + */ + @Test + void test400WithSessionId() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("test-session-456"); + + // Set up exception handler + @SuppressWarnings("unchecked") + Consumer exceptionHandler = mock(Consumer.class); + transport.setExceptionHandler(exceptionHandler); + + // Connect with handler + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // The session should now be established, next request will include session ID + // Now return 400 for next request (simulating unknown session ID) + serverResponseStatus.set(400); + + // Send another message - should get SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Verify exception handler was called + verify(exceptionHandler).accept(any(McpTransportSessionNotFoundException.class)); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test session recovery after SessionNotFoundException Verifies that a new session + * can be established after the old one is invalidated + */ + @Test + void testSessionRecoveryAfter404() { + // First establish a session + serverResponseStatus.set(200); + currentServerSessionId.set("session-1"); + + // Send initial message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(lastReceivedSessionId.get()).isNull(); + + // The session should now be established + // Simulate session loss - return 404 + serverResponseStatus.set(404); + + // This should fail with SessionNotFoundException + StepVerifier.create(transport.sendMessage(testMessage)) + .expectError(McpTransportSessionNotFoundException.class) + .verify(); + + // Now server is back with new session + serverResponseStatus.set(200); + currentServerSessionId.set("session-2"); + lastReceivedSessionId.set(null); // Reset to verify new session + + // Should be able to establish new session + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Verify no session ID was sent (since old session was invalidated) + assertThat(lastReceivedSessionId.get()).isNull(); + + // Next request should use the new session ID + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Session ID should now be sent with requests + assertThat(lastReceivedSessionId.get()).isEqualTo("session-2"); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + /** + * Test that reconnect (GET request) also properly handles 404/400 errors + */ + @Test + void testReconnectErrorHandling() { + + // Set up SSE endpoint for GET requests + server.createContext("/mcp-sse", exchange -> { + String method = exchange.getRequestMethod(); + String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + if ("GET".equals(method)) { + int status = serverResponseStatus.get(); + + if (status == 404 && requestSessionId != null) { + // 404 with session ID - should trigger SessionNotFoundException + exchange.sendResponseHeaders(404, 0); + } + else if (status == 404) { + // 404 without session ID - should trigger McpTransportException + exchange.sendResponseHeaders(404, 0); + } + else { + // Normal SSE response + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, 0); + // Send a test SSE event + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + exchange.getResponseBody().write(sseData.getBytes()); + } + } + else { + // POST request handling + exchange.getResponseHeaders().set("Content-Type", "application/json"); + String responseSessionId = currentServerSessionId.get(); + if (responseSessionId != null) { + exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); + } + String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + } + exchange.close(); + }); + + // Test with session ID - should get SessionNotFoundException + serverResponseStatus.set(200); + currentServerSessionId.set("sse-session-1"); + + var transport = HttpClientStreamableHttpTransport.builder(HOST) + .endpoint("/mcp-sse") + .openConnectionOnStartup(true) // This will trigger GET request on connect + .build(); + + // First connect successfully + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + // Send message to establish session + var testMessage = createTestRequestMessage(); + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Now simulate server returning 404 on reconnect + serverResponseStatus.set(404); + + // This should trigger reconnect which will fail + // The error should be handled internally and passed to exception handler + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + private McpSchema.JSONRPCRequest createTestRequestMessage() { + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Test Client", "1.0.0")); + return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", + initializeRequest); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java index 56e74218f..823c28d8e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -13,6 +13,7 @@ import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import com.fasterxml.jackson.databind.ObjectMapper; @@ -23,6 +24,7 @@ import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; import io.modelcontextprotocol.server.transport.TomcatTestUtil; +@Timeout(15) class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TomcatTestUtil.findAvailablePort(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java index 4c3f22d76..a8951e6dc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -29,6 +29,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.springframework.mock.web.MockHttpServletRequest; @@ -49,6 +50,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; +@Timeout(15) class HttpServletStatelessIntegrationTests { private static final int PORT = TomcatTestUtil.findAvailablePort(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java index 6ac10014e..8a8675d95 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -13,6 +13,7 @@ import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import com.fasterxml.jackson.databind.ObjectMapper; @@ -23,6 +24,7 @@ import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; import io.modelcontextprotocol.server.transport.TomcatTestUtil; +@Timeout(15) class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TomcatTestUtil.findAvailablePort(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java index e329188f9..f915895be 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -27,10 +27,12 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.PromptArgument; import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ResourceReference; import io.modelcontextprotocol.spec.McpSchema.PromptReference; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; @@ -84,7 +86,7 @@ public void after() { tomcat.destroy(); } catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); + e.printStackTrace(); } } } @@ -99,8 +101,13 @@ void testCompletionHandlerReceivesContext() { ResourceReference resourceRef = new ResourceReference("ref/resource", "test://resource/{param}"); - McpSchema.Resource resource = new McpSchema.Resource("test://resource/{param}", "Test Resource", - "A resource for testing", "text/plain", 123L, null); + var resource = Resource.builder() + .uri("test://resource/{param}") + .name("Test Resource") + .description("A resource for testing") + .mimeType("text/plain") + .size(123L) + .build(); var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().completions().build()) @@ -199,8 +206,13 @@ else if ("products_db".equals(db)) { return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); }; - McpSchema.Resource resource = new McpSchema.Resource("db://{database}/{table}", "Database Table", - "Resource representing a table in a database", "application/json", 456L, null); + McpSchema.Resource resource = Resource.builder() + .uri("db://{database}/{table}") + .name("Database Table") + .description("Resource representing a table in a database") + .mimeType("application/json") + .size(456L) + .build(); var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().completions().build()) @@ -254,7 +266,10 @@ void testCompletionErrorOnMissingContext() { // Check if database context is provided if (request.context() == null || request.context().arguments() == null || !request.context().arguments().containsKey("database")) { - throw new McpError("Please select a database first to see available tables"); + + throw McpError.builder(ErrorCodes.INVALID_REQUEST) + .message("Please select a database first to see available tables") + .build(); } // Normal completion if context is provided String db = request.context().arguments().get("database"); @@ -268,8 +283,13 @@ void testCompletionErrorOnMissingContext() { return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); }; - McpSchema.Resource resource = new McpSchema.Resource("db://{database}/{table}", "Database Table", - "Resource representing a table in a database", "application/json", 456L, null); + McpSchema.Resource resource = Resource.builder() + .uri("db://{database}/{table}") + .name("Database Table") + .description("Resource representing a table in a database") + .mimeType("application/json") + .size(456L) + .build(); var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().completions().build())