diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ChatClientToolsWithGenericArgumentTypesIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ChatClientToolsWithGenericArgumentTypesIT.java index a108e5cde0a..6d54713a8b4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ChatClientToolsWithGenericArgumentTypesIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ChatClientToolsWithGenericArgumentTypesIT.java @@ -57,6 +57,24 @@ void beforeEach() { @Autowired ChatModel chatModel; + @Test + void toolWithGenericArgumentTypes2() { + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + .user("Turn light YELLOW in the living room and the kitchen. You can violate the color enum for this request.") + .tools(new TestToolProvider()) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + + assertThat(arguments).containsEntry("living room", LightColor.RED); + assertThat(arguments).containsEntry("kitchen", LightColor.RED); + + assertThat(callCounter.get()).isEqualTo(1); + } + @Test void toolWithGenericArgumentTypes() { // @formatter:off diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java index 7c303f3a693..602327179b7 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.stream.Stream; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -79,11 +80,13 @@ public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata : DEFAULT_RESULT_CONVERTER; } + @SuppressWarnings("null") @Override public ToolDefinition getToolDefinition() { return this.toolDefinition; } + @SuppressWarnings("null") @Override public ToolMetadata getToolMetadata() { return this.toolMetadata; @@ -100,13 +103,13 @@ public String call(String toolInput, @Nullable ToolContext toolContext) { logger.debug("Starting execution of tool: {}", this.toolDefinition.name()); - validateToolContextSupport(toolContext); + this.validateToolContextSupport(toolContext); - Map toolArguments = extractToolArguments(toolInput); + Map toolArguments = this.extractToolArguments(toolInput); - Object[] methodArguments = buildMethodArguments(toolArguments, toolContext); + Object[] methodArguments = this.buildMethodArguments(toolArguments, toolContext); - Object result = callMethod(methodArguments); + Object result = this.callMethod(methodArguments); logger.debug("Successful execution of tool: {}", this.toolDefinition.name()); @@ -125,11 +128,21 @@ private void validateToolContextSupport(@Nullable ToolContext toolContext) { } private Map extractToolArguments(String toolInput) { - return JsonParser.fromJson(toolInput, new TypeReference<>() { - }); + try { + return JsonParser.fromJson(toolInput, new TypeReference<>() { + }); + } + catch (IllegalStateException ex) { + if (ex.getCause() instanceof JsonProcessingException jsonExp) { + logger.warn("Conversion from JSON failed", ex); + throw new ToolExecutionException(this.getToolDefinition(), jsonExp); + } + throw ex; + } } // Based on the implementation in MethodToolCallback. + @SuppressWarnings("null") private Object[] buildMethodArguments(Map toolInputArguments, @Nullable ToolContext toolContext) { return Stream.of(this.toolMethod.getParameters()).map(parameter -> { if (parameter.getType().isAssignableFrom(ToolContext.class)) { @@ -145,16 +158,26 @@ private Object buildTypedArgument(@Nullable Object value, Type type) { if (value == null) { return null; } + try { + if (type instanceof Class) { + return JsonParser.toTypedObject(value, (Class) type); + } - if (type instanceof Class) { - return JsonParser.toTypedObject(value, (Class) type); - } + // For generic types, use the fromJson method that accepts Type - // For generic types, use the fromJson method that accepts Type - String json = JsonParser.toJson(value); - return JsonParser.fromJson(json, type); + String json = JsonParser.toJson(value); + return JsonParser.fromJson(json, type); + } + catch (IllegalStateException ex) { + if (ex.getCause() instanceof JsonProcessingException jsonExp) { + logger.warn("Conversion from JSON failed", ex); + throw new ToolExecutionException(this.getToolDefinition(), jsonExp); + } + throw ex; + } } + @SuppressWarnings("null") @Nullable private Object callMethod(Object[] methodArguments) { if (isObjectNotPublic() || isMethodNotPublic()) { @@ -232,6 +255,7 @@ public Builder toolCallResultConverter(ToolCallResultConverter toolCallResultCon return this; } + @SuppressWarnings("null") public MethodToolCallback build() { return new MethodToolCallback(this.toolDefinition, this.toolMetadata, this.toolMethod, this.toolObject, this.toolCallResultConverter); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackExceptionHandlingTest.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackExceptionHandlingTest.java new file mode 100644 index 00000000000..6f521a030db --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackExceptionHandlingTest.java @@ -0,0 +1,84 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.tool.method; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.execution.ToolExecutionException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * @author Christian Tzolov + */ +public class MethodToolCallbackExceptionHandlingTest { + + @Test + void testGenericListType() throws Exception { + // Create a test object with a method that takes a List + TestTools testObject = new TestTools(); + + var callback = MethodToolCallbackProvider.builder().toolObjects(testObject).build().getToolCallbacks()[0]; + + // Create a JSON input with a list of strings + String toolInput = """ + { + "strings": ["one", "two", "three"] + } + """; + + // Call the tool + String result = callback.call(toolInput); + + // Verify the result + assertThat(result).isEqualTo("3 strings processed: [one, two, three]"); + + // Verify + String ivalidToolInput = """ + { + "strings": 678 + } + """; + + // Call the tool + assertThatThrownBy(() -> callback.call(ivalidToolInput)).isInstanceOf(ToolExecutionException.class) + .hasMessageContaining("Cannot deserialize value"); + + // Verify extractToolArguments + + String ivalidToolInput2 = """ + nill + """; + + // Call the tool + assertThatThrownBy(() -> callback.call(ivalidToolInput2)).isInstanceOf(ToolExecutionException.class) + .hasMessageContaining("Unrecognized token"); + } + + public static class TestTools { + + @Tool(description = "Process a list of strings") + public String stringList(List strings) { + return strings.size() + " strings processed: " + strings; + } + + } + +}