Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@

package io.modelcontextprotocol;

import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertWith;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.mock;

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
Expand All @@ -29,9 +21,6 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpServer;
Expand All @@ -56,12 +45,23 @@
import io.modelcontextprotocol.spec.McpSchema.Role;
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpSchema.TextContent;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.util.Utils;
import net.javacrumbs.jsonunit.core.Option;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertWith;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.mock;

public abstract class AbstractMcpClientServerIntegrationTests {

protected ConcurrentHashMap<String, McpClient.SyncSpec> clientBuilders = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -108,8 +108,8 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
.tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
.callHandler((exchange, request) -> {
exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block();
return Mono.just(mock(CallToolResult.class));
return exchange.createMessage(mock(McpSchema.CreateMessageRequest.class))
.then(Mono.just(mock(CallToolResult.class)));
})
.build();

Expand Down Expand Up @@ -1434,6 +1434,66 @@ void testStructuredOutputValidationSuccess(String clientType) {

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient", "webflux" })
void testStructuredOutputWithInHandlerError(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

// Create a tool with output schema
Map<String, Object> outputSchema = Map.of(
"type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation",
Map.of("type", "string"), "timestamp", Map.of("type", "string")),
"required", List.of("result", "operation"));

Tool calculatorTool = Tool.builder()
.name("calculator")
.description("Performs mathematical calculations")
.outputSchema(outputSchema)
.build();

// Handler that throws an exception to simulate an error
McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder()
.tool(calculatorTool)
.callHandler((exchange, request) -> {

return CallToolResult.builder()
.isError(true)
.content(List.of(new TextContent("Error calling tool: Simulated in-handler error")))
.build();
})
.build();

var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.tools(tool)
.build();

try (var mcpClient = clientBuilder.build()) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// Verify tool is listed with output schema
var toolsList = mcpClient.listTools();
assertThat(toolsList.tools()).hasSize(1);
assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator");
// Note: outputSchema might be null in sync server, but validation still works

// Call tool with valid structured output
CallToolResult response = mcpClient
.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")));

assertThat(response).isNotNull();
assertThat(response.isError()).isTrue();
assertThat(response.content()).isNotEmpty();
assertThat(response.content())
.containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error"));
assertThat(response.structuredContent()).isNull();
}
finally {
mcpServer.closeGracefully();
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testStructuredOutputValidationFailure(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@

package io.modelcontextprotocol;

import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.awaitility.Awaitility.await;

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
Expand All @@ -20,9 +14,6 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification;
import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification;
Expand All @@ -33,10 +24,19 @@
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpSchema.TextContent;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import net.javacrumbs.jsonunit.core.Option;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import reactor.core.publisher.Mono;

import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.awaitility.Awaitility.await;

public abstract class AbstractStatelessIntegrationTests {

protected ConcurrentHashMap<String, McpClient.SyncSpec> clientBuilders = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -350,6 +350,64 @@ void testStructuredOutputValidationSuccess(String clientType) {
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient", "webflux" })
void testStructuredOutputWithInHandlerError(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

// Create a tool with output schema
Map<String, Object> outputSchema = Map.of(
"type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation",
Map.of("type", "string"), "timestamp", Map.of("type", "string")),
"required", List.of("result", "operation"));

Tool calculatorTool = Tool.builder()
.name("calculator")
.description("Performs mathematical calculations")
.outputSchema(outputSchema)
.build();

// Handler that throws an exception to simulate an error
McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification
.builder()
.tool(calculatorTool)
.callHandler((exchange, request) -> CallToolResult.builder()
.isError(true)
.content(List.of(new TextContent("Error calling tool: Simulated in-handler error")))
.build())
.build();

var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.tools(tool)
.build();

try (var mcpClient = clientBuilder.build()) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// Verify tool is listed with output schema
var toolsList = mcpClient.listTools();
assertThat(toolsList.tools()).hasSize(1);
assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator");
// Note: outputSchema might be null in sync server, but validation still works

// Call tool with valid structured output
CallToolResult response = mcpClient
.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")));

assertThat(response).isNotNull();
assertThat(response.isError()).isTrue();
assertThat(response.content()).isNotEmpty();
assertThat(response.content())
.containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error"));
assertThat(response.structuredContent()).isNull();
}
finally {
mcpServer.closeGracefully();
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient", "webflux" })
void testStructuredOutputValidationFailure(String clientType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse;
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate;
import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest;
import io.modelcontextprotocol.spec.McpSchema.TextContent;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
Expand Down Expand Up @@ -376,6 +378,11 @@ public Mono<CallToolResult> apply(McpAsyncServerExchange exchange, McpSchema.Cal

return this.delegateCallToolResult.apply(exchange, request).map(result -> {

if (result.isError() != null && result.isError()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can simplify to Boolean.TRUE.equals(result.isError())

// If the tool call resulted in an error, skip further validation
return result;
}

if (outputSchema == null) {
if (result.structuredContent() != null) {
logger.warn(
Expand Down Expand Up @@ -507,11 +514,11 @@ private McpRequestHandler<CallToolResult> toolsCallRequestHandler() {
.findAny();

if (toolSpecification.isEmpty()) {
return Mono.error(new McpError("Tool not found: " + callToolRequest.name()));
return Mono.error(new McpError(new JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INVALID_PARAMS,
"Unknown tool: invalid_tool_name", "Tool not found: " + callToolRequest.name())));
}

return toolSpecification.map(tool -> Mono.defer(() -> tool.callHandler().apply(exchange, callToolRequest)))
.orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name())));
return toolSpecification.get().callHandler().apply(exchange, callToolRequest);
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse;
import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate;
import io.modelcontextprotocol.spec.McpSchema.TextContent;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
import io.modelcontextprotocol.util.Assert;
Expand Down Expand Up @@ -249,6 +251,11 @@ public Mono<CallToolResult> apply(McpTransportContext transportContext, McpSchem

return this.delegateHandler.apply(transportContext, request).map(result -> {

if (result.isError() != null && result.isError()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can simplify to Boolean.TRUE.equals(result.isError())

// If the tool call resulted in an error, skip further validation
return result;
}

if (outputSchema == null) {
if (result.structuredContent() != null) {
logger.warn(
Expand Down Expand Up @@ -375,11 +382,11 @@ private McpStatelessRequestHandler<CallToolResult> toolsCallRequestHandler() {
.findAny();

if (toolSpecification.isEmpty()) {
return Mono.error(new McpError("Tool not found: " + callToolRequest.name()));
return Mono.error(new McpError(new JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INVALID_PARAMS,
"Unknown tool: invalid_tool_name", "Tool not found: " + callToolRequest.name())));
}

return toolSpecification.map(tool -> tool.callHandler().apply(ctx, callToolRequest))
.orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name())));
return toolSpecification.get().callHandler().apply(ctx, callToolRequest);
};
}

Expand Down
16 changes: 16 additions & 0 deletions mcp/src/main/java/io/modelcontextprotocol/util/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.net.URI;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;

/**
* Miscellaneous utility methods.
Expand Down Expand Up @@ -107,4 +108,19 @@ private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) {
return endpointPath.startsWith(basePath);
}

/**
* Finds the root cause of the given throwable by traversing the cause chain.
* @param throwable The throwable to analyze
* @return The root cause throwable
* @throws NullPointerException if the provided throwable is null
*/
public static Throwable findRootCause(Throwable throwable) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need this any more

Objects.requireNonNull(throwable);
Throwable rootCause = throwable;
while (rootCause.getCause() != null && rootCause.getCause() != rootCause) {
rootCause = rootCause.getCause();
}
return rootCause;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
*/
@Timeout(15) // Giving extra time beyond the client timeout
@Timeout(25) // Giving extra time beyond the client timeout
class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests {

@Override
Expand All @@ -40,4 +40,9 @@ protected Duration getInitializationTimeout() {
return Duration.ofSeconds(20);
}

@Override
protected Duration getRequestTimeout() {
return Duration.ofSeconds(25);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
*/
@Timeout(15) // Giving extra time beyond the client timeout
@Timeout(25) // Giving extra time beyond the client timeout
class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests {

@Override
Expand Down Expand Up @@ -71,4 +71,9 @@ protected Duration getInitializationTimeout() {
return Duration.ofSeconds(10);
}

@Override
protected Duration getRequestTimeout() {
return Duration.ofSeconds(25);
}

}
Loading