Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Map;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.Tool;
Expand All @@ -33,6 +34,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

Expand All @@ -45,15 +47,20 @@
*
* @author Christian Tzolov
* @author YunKui Lu
* @author Sun Yuhan
*/
public class AsyncMcpToolCallback implements ToolCallback {

private static final Logger logger = LoggerFactory.getLogger(AsyncMcpToolCallback.class);

private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

private final McpAsyncClient mcpClient;

private final Tool tool;

private final ToolMetadata toolMetadata;

private final String prefixedToolName;

private final ToolContextToMcpMetaConverter toolContextToMcpMetaConverter;
Expand Down Expand Up @@ -88,6 +95,14 @@ private AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool, String prefixe
this.tool = tool;
this.prefixedToolName = prefixedToolName;
this.toolContextToMcpMetaConverter = toolContextToMcpMetaConverter;
McpSchema.ToolAnnotations annotations = tool.annotations();
Boolean returnDirect = (annotations != null) ? annotations.returnDirect() : null;
if (returnDirect != null) {
this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build();
}
else {
this.toolMetadata = DEFAULT_TOOL_METADATA;
}
}

@Override
Expand Down Expand Up @@ -149,6 +164,11 @@ public String call(String toolCallInput, @Nullable ToolContext toolContext) {
return ModelOptionsUtils.toJsonString(response.content());
}

@Override
public ToolMetadata getToolMetadata() {
return this.toolMetadata;
}

/**
* Creates a builder for constructing AsyncMcpToolCallback instances.
* @return a new builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
Expand All @@ -63,6 +64,7 @@
* </ul>
*
* @author Christian Tzolov
* @author Sun Yuhan
*/
public final class McpToolUtils {

Expand Down Expand Up @@ -228,12 +230,18 @@ public static McpStatelessServerFeatures.SyncToolSpecification toStatelessSyncTo

private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCallback toolCallback,
MimeType mimeType) {
boolean returnDirect = Optional.ofNullable(toolCallback.getToolMetadata())
.map(ToolMetadata::returnDirect)
.orElse(false);
McpSchema.ToolAnnotations toolAnnotations = new McpSchema.ToolAnnotations(null, null, null, null, null,
returnDirect);

var tool = McpSchema.Tool.builder()
.name(toolCallback.getToolDefinition().name())
.description(toolCallback.getToolDefinition().description())
.inputSchema(ModelOptionsUtils.jsonToObject(toolCallback.getToolDefinition().inputSchema(),
McpSchema.JsonSchema.class))
.annotations(toolAnnotations)
.build();

return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Map;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.Tool;
Expand All @@ -32,6 +33,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

Expand All @@ -41,16 +43,21 @@
*
* @author Christian Tzolov
* @author YunKui Lu
* @author Sun Yuhan
* @since 1.0.0
*/
public class SyncMcpToolCallback implements ToolCallback {

private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolCallback.class);

private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

private final McpSyncClient mcpClient;

private final Tool tool;

private final ToolMetadata toolMetadata;

private final String prefixedToolName;

private final ToolContextToMcpMetaConverter toolContextToMcpMetaConverter;
Expand Down Expand Up @@ -85,6 +92,14 @@ private SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool, String prefixedT
this.tool = tool;
this.prefixedToolName = prefixedToolName;
this.toolContextToMcpMetaConverter = toolContextToMcpMetaConverter;
McpSchema.ToolAnnotations annotations = tool.annotations();
Boolean returnDirect = (annotations != null) ? annotations.returnDirect() : null;
if (returnDirect != null) {
this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build();
}
else {
this.toolMetadata = DEFAULT_TOOL_METADATA;
}
}

@Override
Expand Down Expand Up @@ -149,6 +164,11 @@ public String call(String toolCallInput, @Nullable ToolContext toolContext) {
return ModelOptionsUtils.toJsonString(response.content());
}

@Override
public ToolMetadata getToolMetadata() {
return this.toolMetadata;
}

/**
* Creates a builder for constructing {@code SyncMcpToolCallback} instances.
* @return a new builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ void getToolDefinitionShouldReturnCorrectDefinition() {
when(this.tool.description()).thenReturn("Test tool description");
var jsonSchema = mock(McpSchema.JsonSchema.class);
when(this.tool.inputSchema()).thenReturn(jsonSchema);
var toolAnnotations = new McpSchema.ToolAnnotations(null, false, false, false, false, true);
when(this.tool.annotations()).thenReturn(toolAnnotations);

// Act
var callback = AsyncMcpToolCallback.builder()
Expand All @@ -213,11 +215,13 @@ void getToolDefinitionShouldReturnCorrectDefinition() {
.build();

ToolDefinition definition = callback.getToolDefinition();
var toolMetadata = callback.getToolMetadata();

// Assert
assertThat(definition.name()).isEqualTo("prefix_testTool");
assertThat(definition.description()).isEqualTo("Test tool description");
assertThat(definition.inputSchema()).isNotNull();
assertThat(toolMetadata.returnDirect()).isEqualTo(true);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class SyncMcpToolCallbackTests {
@Test
void getToolDefinitionShouldReturnCorrectDefinition() {
var clientInfo = new Implementation("testClient", "1.0.0");

when(this.tool.name()).thenReturn("testTool");
when(this.tool.description()).thenReturn("Test tool description");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.metadata.ToolMetadata;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -314,8 +315,10 @@ private ToolCallback createMockToolCallback(String name, String result) {
.description("Test tool")
.inputSchema("{}")
.build();
ToolMetadata metadata = ToolMetadata.builder().build();
when(callback.getToolDefinition()).thenReturn(definition);
when(callback.call(anyString(), any())).thenReturn(result);
when(callback.getToolMetadata()).thenReturn(metadata);
return callback;
}

Expand All @@ -326,8 +329,10 @@ private ToolCallback createMockToolCallback(String name, RuntimeException error)
.description("Test tool")
.inputSchema("{}")
.build();
ToolMetadata metadata = ToolMetadata.builder().build();
when(callback.getToolDefinition()).thenReturn(definition);
when(callback.call(anyString(), any())).thenThrow(error);
when(callback.getToolMetadata()).thenReturn(metadata);
return callback;
}

Expand Down