Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ void beforeEach() {
@Autowired
ChatModel chatModel;

@Test
void toolWithGenericArgumentTypes2() {
// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.user("Turn light YELLOW in the living room and the kitchen. You can violate the color enum for this request.")
.tools(new TestToolProvider())
.call()
.content();
// @formatter:on

logger.info("Response: {}", response);

assertThat(arguments).containsEntry("living room", LightColor.RED);
assertThat(arguments).containsEntry("kitchen", LightColor.RED);

assertThat(callCounter.get()).isEqualTo(1);
}

@Test
void toolWithGenericArgumentTypes() {
// @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Map;
import java.util.stream.Stream;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -79,11 +80,13 @@ public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata
: DEFAULT_RESULT_CONVERTER;
}

@SuppressWarnings("null")
@Override
public ToolDefinition getToolDefinition() {
return this.toolDefinition;
}

@SuppressWarnings("null")
@Override
public ToolMetadata getToolMetadata() {
return this.toolMetadata;
Expand All @@ -100,13 +103,13 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {

logger.debug("Starting execution of tool: {}", this.toolDefinition.name());

validateToolContextSupport(toolContext);
this.validateToolContextSupport(toolContext);

Map<String, Object> toolArguments = extractToolArguments(toolInput);
Map<String, Object> toolArguments = this.extractToolArguments(toolInput);

Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);
Object[] methodArguments = this.buildMethodArguments(toolArguments, toolContext);

Object result = callMethod(methodArguments);
Object result = this.callMethod(methodArguments);

logger.debug("Successful execution of tool: {}", this.toolDefinition.name());

Expand All @@ -125,11 +128,21 @@ private void validateToolContextSupport(@Nullable ToolContext toolContext) {
}

private Map<String, Object> extractToolArguments(String toolInput) {
return JsonParser.fromJson(toolInput, new TypeReference<>() {
});
try {
return JsonParser.fromJson(toolInput, new TypeReference<>() {
});
}
catch (IllegalStateException ex) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of checking just the IllegalStateException, we agreed to check all exception which is similar to how MCP toolcallback error handling is done. I will make the change and update the PR.

if (ex.getCause() instanceof JsonProcessingException jsonExp) {
logger.warn("Conversion from JSON failed", ex);
throw new ToolExecutionException(this.getToolDefinition(), jsonExp);
}
throw ex;
}
}

// Based on the implementation in MethodToolCallback.
@SuppressWarnings("null")
private Object[] buildMethodArguments(Map<String, Object> toolInputArguments, @Nullable ToolContext toolContext) {
return Stream.of(this.toolMethod.getParameters()).map(parameter -> {
if (parameter.getType().isAssignableFrom(ToolContext.class)) {
Expand All @@ -145,16 +158,26 @@ private Object buildTypedArgument(@Nullable Object value, Type type) {
if (value == null) {
return null;
}
try {
if (type instanceof Class<?>) {
return JsonParser.toTypedObject(value, (Class<?>) type);
}

if (type instanceof Class<?>) {
return JsonParser.toTypedObject(value, (Class<?>) type);
}
// For generic types, use the fromJson method that accepts Type

// For generic types, use the fromJson method that accepts Type
String json = JsonParser.toJson(value);
return JsonParser.fromJson(json, type);
String json = JsonParser.toJson(value);
return JsonParser.fromJson(json, type);
}
catch (IllegalStateException ex) {
if (ex.getCause() instanceof JsonProcessingException jsonExp) {
logger.warn("Conversion from JSON failed", ex);
throw new ToolExecutionException(this.getToolDefinition(), jsonExp);
}
throw ex;
}
}

@SuppressWarnings("null")
@Nullable
private Object callMethod(Object[] methodArguments) {
if (isObjectNotPublic() || isMethodNotPublic()) {
Expand Down Expand Up @@ -232,6 +255,7 @@ public Builder toolCallResultConverter(ToolCallResultConverter toolCallResultCon
return this;
}

@SuppressWarnings("null")
public MethodToolCallback build() {
return new MethodToolCallback(this.toolDefinition, this.toolMetadata, this.toolMethod, this.toolObject,
this.toolCallResultConverter);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2025 - 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.tool.method;

import java.util.List;

import org.junit.jupiter.api.Test;

import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.execution.ToolExecutionException;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* @author Christian Tzolov
*/
public class MethodToolCallbackExceptionHandlingTest {

@Test
void testGenericListType() throws Exception {
// Create a test object with a method that takes a List<String>
TestTools testObject = new TestTools();

var callback = MethodToolCallbackProvider.builder().toolObjects(testObject).build().getToolCallbacks()[0];

// Create a JSON input with a list of strings
String toolInput = """
{
"strings": ["one", "two", "three"]
}
""";

// Call the tool
String result = callback.call(toolInput);

// Verify the result
assertThat(result).isEqualTo("3 strings processed: [one, two, three]");

// Verify
String ivalidToolInput = """
{
"strings": 678
}
""";

// Call the tool
assertThatThrownBy(() -> callback.call(ivalidToolInput)).isInstanceOf(ToolExecutionException.class)
.hasMessageContaining("Cannot deserialize value");

// Verify extractToolArguments

String ivalidToolInput2 = """
nill
""";

// Call the tool
assertThatThrownBy(() -> callback.call(ivalidToolInput2)).isInstanceOf(ToolExecutionException.class)
.hasMessageContaining("Unrecognized token");
}

public static class TestTools {

@Tool(description = "Process a list of strings")
public String stringList(List<String> strings) {
return strings.size() + " strings processed: " + strings;
}

}

}