From ad5c30932a0f58ead05ec22d1e92eb995837a957 Mon Sep 17 00:00:00 2001 From: YunKui Lu Date: Tue, 15 Jul 2025 00:56:02 +0800 Subject: [PATCH] feat: MCP tools support ToolContext - Added `excludedToolContextKeys` to `McpClientCommonProperties.ToolCallback`. This will let us control which `ToolContext` fields should not be sent to the server. By default, `excludedToolContextKeys` is set to `TOOL_CALL_HISTORY`. - Added `AbstractMcpToolCallback` to hold common code for `McpToolCallback` - Updated `AsyncMcpToolCallback` and `SyncMcpToolCallback` to pass `ToolContext` via `_meta` using key `ai.springframework.org/tool_context` - Updated `McpToolUtils` to help server tools extract `ToolContext` from MCP Client requests - Updated `AsyncMcpToolCallbackProvider` and `SyncMcpToolCallbackProvider` to use the Builder design pattern and add the `excludedToolContextKeys` field - Updated the corresponding test cases Signed-off-by: YunKui Lu --- .../McpToolCallbackAutoConfiguration.java | 16 ++- .../properties/McpClientCommonProperties.java | 19 +++ ...itional-spring-configuration-metadata.json | 6 + .../McpClientCommonPropertiesTests.java | 51 +++++++- .../ai/mcp/AbstractMcpToolCallback.java | 114 ++++++++++++++++++ .../ai/mcp/AsyncMcpToolCallback.java | 112 +++++++++++++---- .../ai/mcp/AsyncMcpToolCallbackProvider.java | 78 +++++++++++- .../springframework/ai/mcp/McpToolUtils.java | 67 +++++++--- .../ai/mcp/SyncMcpToolCallback.java | 84 +++++++++++-- .../ai/mcp/SyncMcpToolCallbackProvider.java | 77 +++++++++++- .../ai/mcp/AsyncMcpToolCallbackTest.java | 26 +++- .../ai/mcp/SyncMcpToolCallbackTests.java | 21 +++- .../ai/mcp/ToolUtilsTests.java | 57 ++++++++- 13 files changed, 654 insertions(+), 74 deletions(-) create mode 100644 auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/additional-spring-configuration-metadata.json create mode 100644 mcp/common/src/main/java/org/springframework/ai/mcp/AbstractMcpToolCallback.java diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java index a477af8a47a..226701eb73f 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java @@ -51,16 +51,24 @@ public class McpToolCallbackAutoConfiguration { @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider> syncMcpClients) { + public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider> syncMcpClients, + McpClientCommonProperties commonProperties) { List mcpClients = syncMcpClients.stream().flatMap(List::stream).toList(); - return new SyncMcpToolCallbackProvider(mcpClients); + return SyncMcpToolCallbackProvider.builder() + .addMcpClients(mcpClients) + .addExcludedToolContextKeys(commonProperties.getToolcallback().getExcludedToolContextKeys()) + .build(); } @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider> mcpClientsProvider) { + public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider> mcpClientsProvider, + McpClientCommonProperties commonProperties) { List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); - return new AsyncMcpToolCallbackProvider(mcpClients); + return AsyncMcpToolCallbackProvider.builder() + .addMcpClients(mcpClients) + .addExcludedToolContextKeys(commonProperties.getToolcallback().getExcludedToolContextKeys()) + .build(); } public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpClientCommonProperties.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpClientCommonProperties.java index fcc534080aa..d63c776f800 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpClientCommonProperties.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpClientCommonProperties.java @@ -17,9 +17,12 @@ package org.springframework.ai.mcp.client.common.autoconfigure.properties; import java.time.Duration; +import java.util.Set; import org.springframework.boot.context.properties.ConfigurationProperties; +import static org.springframework.ai.chat.model.ToolContext.TOOL_CALL_HISTORY; + /** * Common Configuration properties for the Model Context Protocol (MCP) clients shared for * all transport types. @@ -190,6 +193,14 @@ public static class Toolcallback { */ private boolean enabled = true; + /** + * The keys that will not be sent to the MCP Server inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + // Remember to update META-INF/additional-spring-configuration-metadata.json if + // you change this default values. + private Set excludedToolContextKeys = Set.of(TOOL_CALL_HISTORY); + public void setEnabled(boolean enabled) { this.enabled = enabled; } @@ -198,6 +209,14 @@ public boolean isEnabled() { return this.enabled; } + public Set getExcludedToolContextKeys() { + return excludedToolContextKeys; + } + + public void setExcludedToolContextKeys(Set excludedToolContextKeys) { + this.excludedToolContextKeys = excludedToolContextKeys; + } + } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/additional-spring-configuration-metadata.json b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/additional-spring-configuration-metadata.json new file mode 100644 index 00000000000..d73598661b8 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/additional-spring-configuration-metadata.json @@ -0,0 +1,6 @@ +{"properties": [ + { + "name": "spring.ai.mcp.client.toolcallback.excluded-tool-context-keys", + "defaultValue": ["TOOL_CALL_HISTORY"] + } +]} \ No newline at end of file diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpClientCommonPropertiesTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpClientCommonPropertiesTests.java index 18eb85e2c3f..a38caf8989b 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpClientCommonPropertiesTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpClientCommonPropertiesTests.java @@ -17,6 +17,7 @@ package org.springframework.ai.mcp.client.common.autoconfigure.properties; import java.time.Duration; +import java.util.Set; import org.junit.jupiter.api.Test; @@ -47,6 +48,9 @@ void defaultValues() { assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(20)); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.SYNC); assertThat(properties.isRootChangeNotification()).isTrue(); + assertThat(properties.getToolcallback().isEnabled()).isTrue(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyElementsOf(Set.of("TOOL_CALL_HISTORY")); }); } @@ -56,7 +60,9 @@ void customValues() { .withPropertyValues("spring.ai.mcp.client.enabled=false", "spring.ai.mcp.client.name=custom-client", "spring.ai.mcp.client.version=2.0.0", "spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.request-timeout=30s", "spring.ai.mcp.client.type=ASYNC", - "spring.ai.mcp.client.root-change-notification=false") + "spring.ai.mcp.client.root-change-notification=false", + "spring.ai.mcp.client.toolcallback.enabled=false", + "spring.ai.mcp.client.toolcallback.excluded-tool-context-keys=foo,bar") .run(context -> { McpClientCommonProperties properties = context.getBean(McpClientCommonProperties.class); assertThat(properties.isEnabled()).isFalse(); @@ -66,6 +72,9 @@ void customValues() { assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(30)); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.ASYNC); assertThat(properties.isRootChangeNotification()).isFalse(); + assertThat(properties.getToolcallback().isEnabled()).isFalse(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("foo", "bar")); }); } @@ -101,6 +110,14 @@ void setterGetterMethods() { // Test rootChangeNotification property properties.setRootChangeNotification(false); assertThat(properties.isRootChangeNotification()).isFalse(); + + // Test toolcallback property + properties.getToolcallback().setEnabled(false); + assertThat(properties.getToolcallback().isEnabled()).isFalse(); + + properties.getToolcallback().setExcludedToolContextKeys(Set.of("foo", "bar")); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("foo", "bar")); } @Test @@ -125,7 +142,9 @@ void propertiesFileBinding() { .withPropertyValues("spring.ai.mcp.client.enabled=false", "spring.ai.mcp.client.name=test-mcp-client", "spring.ai.mcp.client.version=0.5.0", "spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.request-timeout=45s", "spring.ai.mcp.client.type=ASYNC", - "spring.ai.mcp.client.root-change-notification=false") + "spring.ai.mcp.client.root-change-notification=false", + "spring.ai.mcp.client.toolcallback.enabled=false", + "spring.ai.mcp.client.toolcallback.excluded-tool-context-keys=foo,bar") .run(context -> { McpClientCommonProperties properties = context.getBean(McpClientCommonProperties.class); assertThat(properties.isEnabled()).isFalse(); @@ -135,6 +154,9 @@ void propertiesFileBinding() { assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(45)); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.ASYNC); assertThat(properties.isRootChangeNotification()).isFalse(); + assertThat(properties.getToolcallback().isEnabled()).isFalse(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("foo", "bar")); }); } @@ -165,7 +187,9 @@ void yamlConfigurationBinding() { .withPropertyValues("spring.ai.mcp.client.enabled=false", "spring.ai.mcp.client.name=test-mcp-client-yaml", "spring.ai.mcp.client.version=0.6.0", "spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.request-timeout=60s", "spring.ai.mcp.client.type=ASYNC", - "spring.ai.mcp.client.root-change-notification=false") + "spring.ai.mcp.client.root-change-notification=false", + "spring.ai.mcp.client.toolcallback.enabled=false", + "spring.ai.mcp.client.toolcallback.excluded-tool-context-keys=foo,bar") .run(context -> { McpClientCommonProperties properties = context.getBean(McpClientCommonProperties.class); assertThat(properties.isEnabled()).isFalse(); @@ -175,6 +199,9 @@ void yamlConfigurationBinding() { assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(60)); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.ASYNC); assertThat(properties.isRootChangeNotification()).isFalse(); + assertThat(properties.getToolcallback().isEnabled()).isFalse(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("foo", "bar")); }); } @@ -201,6 +228,9 @@ void disabledProperties() { assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(20)); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.SYNC); assertThat(properties.isRootChangeNotification()).isTrue(); + assertThat(properties.getToolcallback().isEnabled()).isTrue(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("TOOL_CALL_HISTORY")); }); } @@ -216,6 +246,9 @@ void notInitializedProperties() { assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(20)); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.SYNC); assertThat(properties.isRootChangeNotification()).isTrue(); + assertThat(properties.getToolcallback().isEnabled()).isTrue(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("TOOL_CALL_HISTORY")); }); } @@ -231,6 +264,9 @@ void rootChangeNotificationDisabled() { assertThat(properties.isInitialized()).isTrue(); assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(20)); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.SYNC); + assertThat(properties.getToolcallback().isEnabled()).isTrue(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("TOOL_CALL_HISTORY")); }); } @@ -246,6 +282,9 @@ void customRequestTimeout() { assertThat(properties.isInitialized()).isTrue(); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.SYNC); assertThat(properties.isRootChangeNotification()).isTrue(); + assertThat(properties.getToolcallback().isEnabled()).isTrue(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("TOOL_CALL_HISTORY")); }); } @@ -261,6 +300,9 @@ void asyncClientType() { assertThat(properties.isInitialized()).isTrue(); assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(20)); assertThat(properties.isRootChangeNotification()).isTrue(); + assertThat(properties.getToolcallback().isEnabled()).isTrue(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("TOOL_CALL_HISTORY")); }); } @@ -278,6 +320,9 @@ void customNameAndVersion() { assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(20)); assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.SYNC); assertThat(properties.isRootChangeNotification()).isTrue(); + assertThat(properties.getToolcallback().isEnabled()).isTrue(); + assertThat(properties.getToolcallback().getExcludedToolContextKeys()) + .containsExactlyInAnyOrderElementsOf(Set.of("TOOL_CALL_HISTORY")); }); } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AbstractMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AbstractMcpToolCallback.java new file mode 100644 index 00000000000..1207170daf3 --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AbstractMcpToolCallback.java @@ -0,0 +1,114 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import static io.modelcontextprotocol.spec.McpSchema.Tool; + +/** + * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool + * interface with asynchronous execution support. + *

+ * This class acts as a bridge between the Model Context Protocol (MCP) and Spring AI's + * tool system, allowing MCP tools to be used seamlessly within Spring AI applications. + * It: + *

    + *
  • Converts MCP tool definitions to Spring AI tool definitions
  • + *
  • Handles the asynchronous execution of tool calls through the MCP client
  • + *
  • Manages JSON serialization/deserialization of tool inputs and outputs
  • + *
+ *

+ * + * @author YunKui Lu + * @see ToolCallback + * @see AsyncMcpToolCallback + * @see SyncMcpToolCallback + * @see Tool + */ +public abstract class AbstractMcpToolCallback implements ToolCallback { + + public static final String DEFAULT_MCP_META_TOOL_CONTEXT_KEY = McpToolUtils.DEFAULT_MCP_META_TOOL_CONTEXT_KEY; + + protected final Tool tool; + + /** + * the keys that will not be sent to the MCP Server inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + protected final Set excludedToolContextKeys; + + /** + * Creates a new {@code AbstractMcpToolCallback} instance. + * @param tool the MCP tool definition to adapt + * @param excludedToolContextKeys the keys that will not be sent to the MCP Server + * inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + protected AbstractMcpToolCallback(Tool tool, Set excludedToolContextKeys) { + Assert.notNull(tool, "tool cannot be null"); + Assert.notNull(excludedToolContextKeys, "excludedToolContextKeys cannot be null"); + + this.tool = tool; + this.excludedToolContextKeys = excludedToolContextKeys; + } + + /** + * Executes the tool with the provided input. + *

+ * This method: + *

    + *
  1. Converts the JSON input string to a map of arguments
  2. + *
  3. Calls the tool through the MCP client
  4. + *
  5. Converts the tool's response content to a JSON string
  6. + *
+ * @param functionInput the tool input as a JSON string + * @return the tool's response as a JSON string + */ + @Override + public String call(String functionInput) { + return this.call(functionInput, null); + } + + /** + * Converts the tool context to a mcp meta map + * @param toolContext the context for tool execution in a function calling scenario + * @return the mcp meta map + */ + protected Map getAdditionalToolContextToMeta(@Nullable ToolContext toolContext) { + if (toolContext == null || toolContext.getContext().isEmpty()) { + return Map.of(); + } + + Map meta = new HashMap<>(toolContext.getContext().size() - excludedToolContextKeys.size()); + for (var toolContextEntry : toolContext.getContext().entrySet()) { + if (excludedToolContextKeys.contains(toolContextEntry.getKey())) { + continue; + } + meta.put(toolContextEntry.getKey(), toolContextEntry.getValue()); + } + return meta; + } + +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 5f8da416109..b8de059fdc8 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -16,10 +16,15 @@ package org.springframework.ai.mcp; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; -import java.util.Map; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; @@ -28,6 +33,8 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool @@ -55,24 +62,42 @@ * } * * @author Christian Tzolov + * @author YunKui Lu * @see ToolCallback * @see McpAsyncClient * @see Tool */ -public class AsyncMcpToolCallback implements ToolCallback { +public class AsyncMcpToolCallback extends AbstractMcpToolCallback { private final McpAsyncClient asyncMcpClient; - private final Tool tool; - /** * Creates a new {@code AsyncMcpToolCallback} instance. * @param mcpClient the MCP client to use for tool execution * @param tool the MCP tool definition to adapt */ public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool) { + this(mcpClient, tool, Set.of()); + } + + /** + * Creates a new {@code AsyncMcpToolCallback} instance. + * @param mcpClient the MCP client to use for tool execution + * @param tool the MCP tool definition to adapt + * @param excludedToolContextKeys the keys that will not be sent to the MCP Server + * inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + private AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool, Set excludedToolContextKeys) { + super(tool, excludedToolContextKeys); + + Assert.notNull(mcpClient, "mcpClient cannot be null"); + this.asyncMcpClient = mcpClient; - this.tool = tool; + } + + public static Builder builder() { + return new Builder(); } /** @@ -105,29 +130,72 @@ public ToolDefinition getToolDefinition() { *
  • Converts the tool's response content to a JSON string
  • * * @param functionInput the tool input as a JSON string + * @param toolContext the context for tool execution in a function calling scenario * @return the tool's response as a JSON string */ @Override - public String call(String functionInput) { + public String call(String functionInput, @Nullable ToolContext toolContext) { Map arguments = ModelOptionsUtils.jsonToMap(functionInput); - // Note that we use the original tool name here, not the adapted one from - // getToolDefinition - return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).onErrorMap(exception -> { - // If the tool throws an error during execution - throw new ToolExecutionException(this.getToolDefinition(), exception); - }).map(response -> { - if (response.isError() != null && response.isError()) { - throw new ToolExecutionException(this.getToolDefinition(), - new IllegalStateException("Error calling tool: " + response.content())); - } - return ModelOptionsUtils.toJsonString(response.content()); - }).contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())).block(); + + Map meta = new HashMap<>(); + if (toolContext != null && !toolContext.getContext().isEmpty()) { + meta.put(DEFAULT_MCP_META_TOOL_CONTEXT_KEY, super.getAdditionalToolContextToMeta(toolContext)); + } + + return Objects + // Note that we use the original tool name here, not the adapted one from + // getToolDefinition + .requireNonNull(this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments, meta)) + .onErrorMap(exception -> { + // If the tool throws an error during execution + throw new ToolExecutionException(this.getToolDefinition(), exception); + }) + .map(response -> { + if (response.isError() != null && response.isError()) { + throw new ToolExecutionException(this.getToolDefinition(), + new IllegalStateException("Error calling tool: " + response.content())); + } + return ModelOptionsUtils.toJsonString(response.content()); + }) + .contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())) + .block()); } - @Override - public String call(String toolArguments, ToolContext toolContext) { - // ToolContext is not supported by the MCP tools - return this.call(toolArguments); + public static class Builder { + + private McpAsyncClient asyncMcpClient; + + private Tool tool; + + private Set excludedToolContextKeys = new HashSet<>(); + + private Builder() { + } + + public Builder asyncMcpClient(McpAsyncClient asyncMcpClient) { + this.asyncMcpClient = asyncMcpClient; + return this; + } + + public Builder tool(Tool tool) { + this.tool = tool; + return this; + } + + public Builder addExcludedToolContextKeys(Set excludedToolContextKeys) { + this.excludedToolContextKeys.addAll(excludedToolContextKeys); + return this; + } + + public Builder addExcludedToolContextKey(String excludedToolContextKey) { + this.excludedToolContextKeys.add(excludedToolContextKey); + return this; + } + + public AsyncMcpToolCallback build() { + return new AsyncMcpToolCallback(this.asyncMcpClient, this.tool, this.excludedToolContextKeys); + } + } } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index 3525b9593e3..d4eab9edcee 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -17,7 +17,9 @@ package org.springframework.ai.mcp; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpAsyncClient; @@ -67,6 +69,7 @@ * } * * @author Christian Tzolov + * @author YunKui Lu * @since 1.0.0 * @see ToolCallbackProvider * @see AsyncMcpToolCallback @@ -74,6 +77,12 @@ */ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { + /** + * The keys that will not be sent to the MCP Server inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + private final Set excludedToolContextKeys; + private final List mcpClients; private final BiPredicate toolFilter; @@ -85,10 +94,27 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param toolFilter a filter to apply to each discovered tool */ public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients) { + this(toolFilter, mcpClients, Set.of()); + } + + /** + * Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP + * clients. + * @param mcpClients the list of MCP clients to use for discovering tools + * @param toolFilter a filter to apply to each discovered tool + * @param excludedToolContextKeys the keys that will not be sent to the MCP Server + * inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + protected AsyncMcpToolCallbackProvider(BiPredicate toolFilter, + List mcpClients, Set excludedToolContextKeys) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); + Assert.notNull(excludedToolContextKeys, "Excluded tool context keys must not be null"); + this.mcpClients = mcpClients; this.toolFilter = toolFilter; + this.excludedToolContextKeys = excludedToolContextKeys; } /** @@ -148,7 +174,11 @@ public ToolCallback[] getToolCallbacks() { .map(response -> response.tools() .stream() .filter(tool -> this.toolFilter.test(mcpClient, tool)) - .map(tool -> new AsyncMcpToolCallback(mcpClient, tool)) + .map(tool -> AsyncMcpToolCallback.builder() + .asyncMcpClient(mcpClient) + .tool(tool) + .addExcludedToolContextKeys(excludedToolContextKeys) + .build()) .toArray(ToolCallback[]::new)) .block(); @@ -203,4 +233,50 @@ public static Flux asyncToolCallbacks(List mcpClie return Flux.fromArray(new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks()); } + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private List mcpClients = new ArrayList<>(); + + private BiPredicate toolFilter = (mcpClient, tool) -> true; + + private Set excludedToolContextKeys = new HashSet<>(); + + private Builder() { + } + + public Builder addMcpClient(McpAsyncClient mcpClient) { + this.mcpClients.add(mcpClient); + return this; + } + + public Builder addMcpClients(List mcpClients) { + this.mcpClients.addAll(mcpClients); + return this; + } + + public Builder addExcludedToolContextKey(String excludedToolContextKey) { + this.excludedToolContextKeys.add(excludedToolContextKey); + return this; + } + + public Builder addExcludedToolContextKeys(Set excludedToolContextKeys) { + this.excludedToolContextKeys.addAll(excludedToolContextKeys); + return this; + } + + public Builder toolFilter(BiPredicate toolFilter) { + this.toolFilter = toolFilter; + return this; + } + + public AsyncMcpToolCallbackProvider build() { + return new AsyncMcpToolCallbackProvider(this.toolFilter, this.mcpClients, excludedToolContextKeys); + } + + } + } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index 2f8f366d076..2afc563dbe3 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -16,6 +16,7 @@ package org.springframework.ai.mcp; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -67,6 +68,10 @@ public final class McpToolUtils { */ public static final String TOOL_CONTEXT_MCP_EXCHANGE_KEY = "exchange"; + public static final String TOOL_CONTEXT_MCP_META_KEY = "_meta"; + + public static final String DEFAULT_MCP_META_TOOL_CONTEXT_KEY = "ai.springframework.org/tool_context"; + private McpToolUtils() { } @@ -169,22 +174,42 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To var tool = new McpSchema.Tool(toolCallback.getToolDefinition().name(), toolCallback.getToolDefinition().description(), toolCallback.getToolDefinition().inputSchema()); - - return new McpServerFeatures.SyncToolSpecification(tool, (exchange, request) -> { - try { - String callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request), - new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchange))); - if (mimeType != null && mimeType.toString().startsWith("image")) { - return new McpSchema.CallToolResult(List - .of(new McpSchema.ImageContent(List.of(Role.ASSISTANT), null, callResult, mimeType.toString())), - false); + // @formatter:off + return McpServerFeatures.SyncToolSpecification.builder() // @formatter:on + .tool(tool) + .callHandler((exchange, request) -> { + try { + Map context = new HashMap<>(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchange)); + + if (!CollectionUtils.isEmpty(request.meta())) { + // Put the `_meta` field from `CallToolRequest` into + // `ToolContext`, so it can be used to access fields outside the + // `ai.springframework.org/*` keys. + context.put(TOOL_CONTEXT_MCP_META_KEY, request.meta()); + + if (request.meta().containsKey(DEFAULT_MCP_META_TOOL_CONTEXT_KEY)) { + // Get the McpClient tool context from the + // `ai.springframework.org/tool_context` key in the `_meta`. + Map toolContext = ModelOptionsUtils + .objectToMap(request.meta().get(DEFAULT_MCP_META_TOOL_CONTEXT_KEY)); + context.putAll(toolContext); + } + + } + + String callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request.arguments()), + new ToolContext(context)); + if (mimeType != null && mimeType.toString().startsWith("image")) { + return new McpSchema.CallToolResult(List.of(new McpSchema.ImageContent(List.of(Role.ASSISTANT), + null, callResult, mimeType.toString())), false); + } + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false); } - return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false); - } - catch (Exception e) { - return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true); - } - }); + catch (Exception e) { + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true); + } + }) + .build(); } /** @@ -286,11 +311,13 @@ public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification( MimeType mimeType) { McpServerFeatures.SyncToolSpecification syncToolSpecification = toSyncToolSpecification(toolCallback, mimeType); - - return new AsyncToolSpecification(syncToolSpecification.tool(), - (exchange, map) -> Mono - .fromCallable(() -> syncToolSpecification.call().apply(new McpSyncServerExchange(exchange), map)) - .subscribeOn(Schedulers.boundedElastic())); + return AsyncToolSpecification.builder() + .tool(syncToolSpecification.tool()) + .callHandler((exchange, request) -> Mono + .fromCallable( + () -> syncToolSpecification.callHandler().apply(new McpSyncServerExchange(exchange), request)) + .subscribeOn(Schedulers.boundedElastic())) + .build(); } /** diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index fc61d801df1..56adbae7706 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -16,11 +16,15 @@ package org.springframework.ai.mcp; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Tool; -import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,6 +34,8 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool @@ -57,27 +63,40 @@ * } * * @author Christian Tzolov + * @author YunKui Lu * @see ToolCallback * @see McpSyncClient * @see Tool */ -public class SyncMcpToolCallback implements ToolCallback { +public class SyncMcpToolCallback extends AbstractMcpToolCallback { private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolCallback.class); private final McpSyncClient mcpClient; - private final Tool tool; - /** * Creates a new {@code SyncMcpToolCallback} instance. * @param mcpClient the MCP client to use for tool execution * @param tool the MCP tool definition to adapt */ public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) { - this.mcpClient = mcpClient; - this.tool = tool; + this(mcpClient, tool, Set.of()); + } + /** + * Creates a new {@code SyncMcpToolCallback} instance. + * @param mcpClient the MCP client to use for tool execution + * @param tool the MCP tool definition to adapt + * @param excludedToolContextKeys the keys that will not be sent to the MCP Server + * inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + private SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool, Set excludedToolContextKeys) { + super(tool, excludedToolContextKeys); + + Assert.notNull(mcpClient, "mcpClient cannot be null"); + + this.mcpClient = mcpClient; } /** @@ -110,17 +129,23 @@ public ToolDefinition getToolDefinition() { *
  • Converts the tool's response content to a JSON string
  • * * @param functionInput the tool input as a JSON string + * @param toolContext the context for tool execution in a function calling scenario * @return the tool's response as a JSON string */ @Override - public String call(String functionInput) { + public String call(String functionInput, @Nullable ToolContext toolContext) { Map arguments = ModelOptionsUtils.jsonToMap(functionInput); CallToolResult response; try { + Map meta = new HashMap<>(); + if (toolContext != null && !toolContext.getContext().isEmpty()) { + meta.put(DEFAULT_MCP_META_TOOL_CONTEXT_KEY, super.getAdditionalToolContextToMeta(toolContext)); + } + // Note that we use the original tool name here, not the adapted one from // getToolDefinition - response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)); + response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments, meta)); } catch (Exception ex) { logger.error("Exception while tool calling: ", ex); @@ -135,10 +160,45 @@ public String call(String functionInput) { return ModelOptionsUtils.toJsonString(response.content()); } - @Override - public String call(String toolArguments, ToolContext toolContext) { - // ToolContext is not supported by the MCP tools - return this.call(toolArguments); + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private McpSyncClient syncMcpClient; + + private Tool tool; + + private Set excludedToolContextKeys = new HashSet<>(); + + private Builder() { + } + + public Builder syncMcpClient(McpSyncClient syncMcpClient) { + this.syncMcpClient = syncMcpClient; + return this; + } + + public Builder tool(Tool tool) { + this.tool = tool; + return this; + } + + public Builder addExcludedToolContextKeys(Set excludedToolContextKeys) { + this.excludedToolContextKeys.addAll(excludedToolContextKeys); + return this; + } + + public Builder addExcludedToolContextKeys(String excludedToolContextKey) { + this.excludedToolContextKeys.add(excludedToolContextKey); + return this; + } + + public SyncMcpToolCallback build() { + return new SyncMcpToolCallback(this.syncMcpClient, this.tool, this.excludedToolContextKeys); + } + } } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java index 7d0aa4276a1..70e562fb0a7 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java @@ -16,7 +16,10 @@ package org.springframework.ai.mcp; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.function.BiPredicate; import io.modelcontextprotocol.client.McpSyncClient; @@ -62,6 +65,7 @@ * } * * @author Christian Tzolov + * @author YunKui Lu * @see ToolCallbackProvider * @see SyncMcpToolCallback * @see McpSyncClient @@ -74,6 +78,12 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { private final BiPredicate toolFilter; + /** + * The keys that will not be sent to the MCP Server inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + private final Set excludedToolContextKeys; + /** * Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP * clients. @@ -81,10 +91,25 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider { * @param toolFilter a filter to apply to each discovered tool */ public SyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients) { + this(toolFilter, mcpClients, Set.of()); + } + + /** + * Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP + * clients. + * @param mcpClients the list of MCP clients to use for discovering tools + * @param toolFilter a filter to apply to each discovered tool + * @param excludedToolContextKeys the keys that will not be sent to the MCP Server + * inside the `_meta` field of + * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} + */ + public SyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients, + Set excludedToolContextKeys) { Assert.notNull(mcpClients, "MCP clients must not be null"); Assert.notNull(toolFilter, "Tool filter must not be null"); this.mcpClients = mcpClients; this.toolFilter = toolFilter; + this.excludedToolContextKeys = excludedToolContextKeys; } /** @@ -115,6 +140,10 @@ public SyncMcpToolCallbackProvider(McpSyncClient... mcpClients) { this(List.of(mcpClients)); } + public static Builder builder() { + return new Builder(); + } + /** * Discovers and returns all available tools from all connected MCP servers. *

    @@ -134,7 +163,11 @@ public ToolCallback[] getToolCallbacks() { .tools() .stream() .filter(tool -> this.toolFilter.test(mcpClient, tool)) - .map(tool -> new SyncMcpToolCallback(mcpClient, tool))) + .map(tool -> SyncMcpToolCallback.builder() + .syncMcpClient(mcpClient) + .tool(tool) + .addExcludedToolContextKeys(excludedToolContextKeys) + .build())) .toArray(ToolCallback[]::new); validateToolCallbacks(array); return array; @@ -178,4 +211,46 @@ public static List syncToolCallbacks(List mcpClient return List.of((new SyncMcpToolCallbackProvider(mcpClients).getToolCallbacks())); } + public static class Builder { + + private BiPredicate toolFilter = (mcpClient, tool) -> true; + + private List mcpClients = new ArrayList<>(); + + private Set excludedToolContextKeys = new HashSet<>(); + + private Builder() { + } + + public Builder toolFilter(BiPredicate toolFilter) { + this.toolFilter = toolFilter; + return this; + } + + public Builder addMcpClient(McpSyncClient mcpClient) { + this.mcpClients.add(mcpClient); + return this; + } + + public Builder addMcpClients(List mcpClients) { + this.mcpClients.addAll(mcpClients); + return this; + } + + public Builder addExcludedToolContextKeys(String excludedToolContextKey) { + this.excludedToolContextKeys.add(excludedToolContextKey); + return this; + } + + public Builder addExcludedToolContextKeys(Set excludedToolContextKeys) { + this.excludedToolContextKeys.addAll(excludedToolContextKeys); + return this; + } + + public SyncMcpToolCallbackProvider build() { + return new SyncMcpToolCallbackProvider(toolFilter, mcpClients, excludedToolContextKeys); + } + + } + } diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java index abf2c395ed9..e7b1beae13c 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java @@ -1,5 +1,7 @@ package org.springframework.ai.mcp; +import java.util.Map; + import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.spec.McpSchema; import org.junit.jupiter.api.Test; @@ -8,9 +10,15 @@ import org.mockito.junit.jupiter.MockitoExtension; import reactor.core.publisher.Mono; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.execution.ToolExecutionException; + +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -34,7 +42,7 @@ void callShouldThrowOnError() { assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class) .cause() .isInstanceOf(IllegalStateException.class) - .hasMessage("Error calling tool: [TextContent[annotations=null, text=Some error data]]"); + .hasMessage("Error calling tool: [TextContent[annotations=null, text=Some error data, meta=null]]"); } @Test @@ -51,4 +59,20 @@ void callShouldWrapReactiveErrors() { .hasMessage("Testing tool error"); } + @Test + void callWithToolContext() { + when(this.tool.name()).thenReturn("testTool"); + McpSchema.CallToolResult callResult = mock(McpSchema.CallToolResult.class); + when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(Mono.just(callResult)); + + var callback = new AsyncMcpToolCallback(this.mcpClient, this.tool); + + String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar"))); + + assertThat(response).isNotNull(); + verify(this.mcpClient).callTool(argThat(callToolRequest -> callToolRequest.name().equals("testTool") + && callToolRequest.arguments().equals(Map.of("param", "value")) + && callToolRequest.meta().get("ai.springframework.org/tool_context").equals(Map.of("foo", "bar")))); + } + } diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java index 99a901553ad..265bbb2247f 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java @@ -16,11 +16,11 @@ package org.springframework.ai.mcp; -import io.modelcontextprotocol.spec.McpSchema; import java.util.List; import java.util.Map; import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; @@ -31,13 +31,14 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.ai.chat.model.ToolContext; -import org.springframework.ai.content.Content; import org.springframework.ai.tool.execution.ToolExecutionException; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -84,7 +85,7 @@ void callShouldHandleJsonInputAndOutput() { } @Test - void callShouldIgnoreToolContext() { + void callWithToolContext() { // when(mcpClient.getClientInfo()).thenReturn(new Implementation("testClient", // "1.0.0")); @@ -92,11 +93,19 @@ void callShouldIgnoreToolContext() { CallToolResult callResult = mock(CallToolResult.class); when(this.mcpClient.callTool(any(CallToolRequest.class))).thenReturn(callResult); - SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool); + SyncMcpToolCallback callback = SyncMcpToolCallback.builder() + .syncMcpClient(this.mcpClient) + .tool(this.tool) + .addExcludedToolContextKeys("foo1") + .build(); - String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar"))); + String response = callback.call("{\"param\":\"value\"}", + new ToolContext(Map.of("foo1", "bar1", "foo2", "bar2"))); assertThat(response).isNotNull(); + verify(this.mcpClient).callTool(argThat(callToolRequest -> callToolRequest.name().equals("testTool") + && callToolRequest.arguments().equals(Map.of("param", "value")) + && callToolRequest.meta().get("ai.springframework.org/tool_context").equals(Map.of("foo2", "bar2")))); } @Test @@ -114,7 +123,7 @@ void callShouldThrowOnError() { assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class) .cause() .isInstanceOf(IllegalStateException.class) - .hasMessage("Error calling tool: [TextContent[annotations=null, text=Some error data]]"); + .hasMessage("Error calling tool: [TextContent[annotations=null, text=Some error data, meta=null]]"); } @Test diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java index 2bcbe305c5d..a0e49cba3c0 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java @@ -26,6 +26,7 @@ import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; import io.modelcontextprotocol.server.McpServerFeatures.SyncToolSpecification; import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; @@ -42,7 +43,9 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.assertArg; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; class ToolUtilsTests { @@ -111,8 +114,20 @@ void toSyncToolSpecificationShouldConvertSingleCallback() { assertThat(toolSpecification).isNotNull(); assertThat(toolSpecification.tool().name()).isEqualTo("test"); - CallToolResult result = toolSpecification.call().apply(mock(McpSyncServerExchange.class), Map.of()); + Map meta = Map.of("ai.springframework.org/tool_context", Map.of("foo", "bar"), "key1", + "value1"); + CallToolRequest callToolRequest = new CallToolRequest("test", Map.of("param", "value"), meta); + + CallToolResult result = toolSpecification.callHandler() + .apply(mock(McpSyncServerExchange.class), callToolRequest); TextContent content = (TextContent) result.content().get(0); + + verify(callback).call(anyString(), assertArg(toolContext -> { + Map context = toolContext.getContext(); + assertThat(context).containsEntry("foo", "bar"); + assertThat(context).containsEntry("_meta", meta); + })); + assertThat(content.text()).isEqualTo("success"); assertThat(result.isError()).isFalse(); } @@ -124,10 +139,23 @@ void toSyncToolSpecificationShouldHandleError() { SyncToolSpecification toolSpecification = McpToolUtils.toSyncToolSpecification(callback); assertThat(toolSpecification).isNotNull(); - CallToolResult result = toolSpecification.call().apply(mock(McpSyncServerExchange.class), Map.of()); + + Map meta = Map.of("ai.springframework.org/tool_context", Map.of("foo", "bar"), "key1", + "value1"); + CallToolRequest callToolRequest = new CallToolRequest("test", Map.of("param", "value"), meta); + CallToolResult result = toolSpecification.callHandler() + .apply(mock(McpSyncServerExchange.class), callToolRequest); TextContent content = (TextContent) result.content().get(0); + + verify(callback).call(anyString(), assertArg(toolContext -> { + Map context = toolContext.getContext(); + assertThat(context).containsEntry("foo", "bar"); + assertThat(context).containsEntry("_meta", meta); + })); + assertThat(content.text()).isEqualTo("error"); assertThat(result.isError()).isTrue(); + } @Test @@ -152,13 +180,23 @@ void toAsyncToolSpecificationShouldConvertSingleCallback() { assertThat(toolSpecification).isNotNull(); assertThat(toolSpecification.tool().name()).isEqualTo("test"); - StepVerifier.create(toolSpecification.call().apply(mock(McpAsyncServerExchange.class), Map.of())) + Map meta = Map.of("ai.springframework.org/tool_context", Map.of("foo", "bar"), "key1", + "value1"); + CallToolRequest callToolRequest = new CallToolRequest("test", Map.of("param", "value"), meta); + + StepVerifier.create(toolSpecification.callHandler().apply(mock(McpAsyncServerExchange.class), callToolRequest)) .assertNext(result -> { TextContent content = (TextContent) result.content().get(0); assertThat(content.text()).isEqualTo("success"); assertThat(result.isError()).isFalse(); }) .verifyComplete(); + + verify(callback).call(anyString(), assertArg(toolContext -> { + Map context = toolContext.getContext(); + assertThat(context).containsEntry("foo", "bar"); + assertThat(context).containsEntry("_meta", meta); + })); } @Test @@ -168,13 +206,24 @@ void toAsyncToolSpecificationShouldHandleError() { AsyncToolSpecification toolSpecification = McpToolUtils.toAsyncToolSpecification(callback); assertThat(toolSpecification).isNotNull(); - StepVerifier.create(toolSpecification.call().apply(mock(McpAsyncServerExchange.class), Map.of())) + + Map meta = Map.of("ai.springframework.org/tool_context", Map.of("foo", "bar"), "key1", + "value1"); + CallToolRequest callToolRequest = new CallToolRequest("test", Map.of("param", "value"), meta); + + StepVerifier.create(toolSpecification.callHandler().apply(mock(McpAsyncServerExchange.class), callToolRequest)) .assertNext(result -> { TextContent content = (TextContent) result.content().get(0); assertThat(content.text()).isEqualTo("error"); assertThat(result.isError()).isTrue(); }) .verifyComplete(); + + verify(callback).call(anyString(), assertArg(toolContext -> { + Map context = toolContext.getContext(); + assertThat(context).containsEntry("foo", "bar"); + assertThat(context).containsEntry("_meta", meta); + })); } @Test