Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpTransportException;
import io.modelcontextprotocol.spec.McpTransportSession;
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
import io.modelcontextprotocol.spec.McpTransportStream;
Expand Down Expand Up @@ -70,6 +71,8 @@
*/
public class WebClientStreamableHttpTransport implements McpClientTransport {

private static final String MISSING_SESSION_ID = "[missing_session_id]";

private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class);

private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26;
Expand Down Expand Up @@ -221,8 +224,13 @@ else if (isNotAllowed(response)) {
return Flux.empty();
}
else if (isNotFound(response)) {
String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
return mcpSessionNotFoundError(sessionIdRepresentation);
if (transportSession.sessionId().isPresent()) {
String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
return mcpSessionNotFoundError(sessionIdRepresentation);
}
else {
return this.extractError(response, MISSING_SESSION_ID);
}
}
else {
return response.<McpSchema.JSONRPCMessage>createError().doOnError(e -> {
Expand Down Expand Up @@ -318,10 +326,10 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
}
}
else {
if (isNotFound(response)) {
if (isNotFound(response) && !sessionRepresentation.equals(MISSING_SESSION_ID)) {
return mcpSessionNotFoundError(sessionRepresentation);
}
return extractError(response, sessionRepresentation);
return this.extractError(response, sessionRepresentation);
}
})
.flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage)))
Expand Down Expand Up @@ -362,10 +370,10 @@ private Flux<McpSchema.JSONRPCMessage> extractError(ClientResponse response, Str
McpSchema.JSONRPCResponse.class);
jsonRpcError = jsonRpcResponse.error();
toPropagate = jsonRpcError != null ? new McpError(jsonRpcError)
: new McpError("Can't parse the jsonResponse " + jsonRpcResponse);
: new McpTransportException("Can't parse the jsonResponse " + jsonRpcResponse);
}
catch (IOException ex) {
toPropagate = new RuntimeException("Sending request failed", e);
toPropagate = new McpTransportException("Sending request failed, " + e.getMessage(), e);
logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body);
}

Expand All @@ -374,7 +382,11 @@ private Flux<McpSchema.JSONRPCMessage> extractError(ClientResponse response, Str
// invalidate the session
// https://github.com/modelcontextprotocol/typescript-sdk/issues/389
if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) {
return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate));
if (!sessionRepresentation.equals(MISSING_SESSION_ID)) {
return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate));
}
return Mono.error(new McpTransportException("Received 400 BAD REQUEST for session "
+ sessionRepresentation + ". " + toPropagate.getMessage(), toPropagate));
}
return Mono.error(toPropagate);
}).flux();
Expand Down Expand Up @@ -403,7 +415,7 @@ private static boolean isEventStream(ClientResponse response) {
}

private static String sessionIdOrPlaceholder(McpTransportSession<?> transportSession) {
return transportSession.sessionId().orElse("[missing_session_id]");
return transportSession.sessionId().orElse(MISSING_SESSION_ID);
}

private Flux<McpSchema.JSONRPCMessage> directResponseFlux(McpSchema.JSONRPCMessage sentMessage,
Expand All @@ -421,8 +433,7 @@ private Flux<McpSchema.JSONRPCMessage> directResponseFlux(McpSchema.JSONRPCMessa
}
}
catch (IOException e) {
// TODO: this should be a McpTransportError
s.error(e);
s.error(new McpTransportException(e));
}
}).flatMapIterable(Function.identity());
}
Expand All @@ -449,7 +460,7 @@ private Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> parse(Serve
return Tuples.of(Optional.ofNullable(event.id()), List.of(message));
}
catch (IOException ioException) {
throw new McpError("Error parsing JSON-RPC message: " + event.data());
throw new McpTransportException("Error parsing JSON-RPC message: " + event.data(), ioException);
}
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import io.modelcontextprotocol.spec.HttpHeaders;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
import io.modelcontextprotocol.spec.ProtocolVersions;
Expand Down Expand Up @@ -197,8 +196,6 @@ public List<String> protocolVersions() {
* @param handler a function that processes incoming JSON-RPC messages and returns
* responses
* @return a Mono that completes when the connection is fully established
* @throws McpError if there's an error processing SSE events or if an unrecognized
* event type is received
*/
@Override
public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
Expand All @@ -215,7 +212,7 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
else {
// TODO: clarify with the spec if multiple events can be
// received
s.error(new McpError("Failed to handle SSE endpoint event"));
s.error(new RuntimeException("Failed to handle SSE endpoint event"));
}
}
else if (MESSAGE_EVENT_TYPE.equals(event.event())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
Expand All @@ -26,6 +27,7 @@
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

@Timeout(15)
class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests {

private static final int PORT = TestUtil.findAvailablePort();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
Expand All @@ -26,6 +27,7 @@
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

@Timeout(15)
class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests {

private static final int PORT = TestUtil.findAvailablePort();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
Expand All @@ -26,6 +27,7 @@
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

@Timeout(15)
class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests {

private static final int PORT = TestUtil.findAvailablePort();
Expand Down
Loading