From 7fd742f2df66e3d48e6e640db3a16ddac1d65283 Mon Sep 17 00:00:00 2001 From: Dongha Koo Date: Sun, 13 Jul 2025 16:10:36 +0900 Subject: [PATCH] GH-2947: Add @ToolParam annotation support for parameter name binding Closes #2947 * Implement @ToolParam to bind method parameters by custom name * Update MethodToolCallback to resolve parameter by annotation value * Add unit tests for generic types and annotation usage Signed-off-by: Dongha Koo --- .../ai/tool/annotation/ToolParam.java | 23 +++- .../ai/tool/method/MethodToolCallback.java | 37 +++++- .../MethodToolCallbackGenericTypesTest.java | 124 ++++++++++++++++++ 3 files changed, 181 insertions(+), 3 deletions(-) diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/annotation/ToolParam.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/annotation/ToolParam.java index 2414e4caecc..a4f3ed6d2cd 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/annotation/ToolParam.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/annotation/ToolParam.java @@ -23,9 +23,25 @@ import java.lang.annotation.Target; /** - * Marks a tool argument. + * Marks a tool argument for method-based tools. + *

+ * This annotation can be used to specify metadata for a tool parameter, including whether + * it is required, a description, or a custom name to bind to. + *

+ * When the parameter name cannot be inferred (e.g. compiled without `-parameters`), the + * {@code value} field can be used to manually specify the name that should match the key + * in the tool input map. + * + *

+ * Example:

+ * public String greet(
+ *     {@code @ToolParam(value = "user_name")} String name) {
+ *   return "Hello, " + name;
+ * }
+ * 
* * @author Thomas Vitale + * @author Dongha Koo * @since 1.0.0 */ @Target({ ElementType.PARAMETER, ElementType.FIELD, ElementType.ANNOTATION_TYPE }) @@ -43,4 +59,9 @@ */ String description() default ""; + /** + * The name of the parameter to bind to. + */ + String value() default ""; + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java index 7c303f3a693..97acc4a632d 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -29,12 +29,14 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; import org.springframework.ai.tool.execution.ToolCallResultConverter; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.ai.util.json.JsonParser; +import org.springframework.core.annotation.AnnotationUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -44,6 +46,7 @@ * A {@link ToolCallback} implementation to invoke methods as tools. * * @author Thomas Vitale + * @author Dongha Koo * @since 1.0.0 */ public final class MethodToolCallback implements ToolCallback { @@ -129,13 +132,43 @@ private Map extractToolArguments(String toolInput) { }); } - // Based on the implementation in MethodToolCallback. + /** + * Builds the array of arguments to be passed into the target method, based on the + * input tool arguments and method parameter metadata. + * + *

+ * This method handles special cases like: + *

    + *
  • {@link ToolContext} parameters are injected directly.
  • + *
  • When a {@link ToolParam} annotation is present on a parameter, its + * {@code value} is used to bind input keys to parameters — useful when method + * parameter names are not retained (e.g. missing {@code -parameters} during + * compilation).
  • + *
  • Otherwise, falls back to {@link java.lang.reflect.Parameter#getName()}.
  • + *
+ * + *

+ * Examples:

{@code
+	 * public String greet(@ToolParam("user_name") String name) {
+	 *     return "Hi, " + name;
+	 * }
+	 * }
If the tool input contains {"user_name": "Alice"}, the {@code name} + * parameter is populated with "Alice". + * @param toolInputArguments the parsed input map from JSON + * @param toolContext optional tool context, injected if required + * @return an array of method arguments to invoke the tool method with + */ private Object[] buildMethodArguments(Map toolInputArguments, @Nullable ToolContext toolContext) { return Stream.of(this.toolMethod.getParameters()).map(parameter -> { if (parameter.getType().isAssignableFrom(ToolContext.class)) { return toolContext; } - Object rawArgument = toolInputArguments.get(parameter.getName()); + + ToolParam toolParam = AnnotationUtils.getAnnotation(parameter, ToolParam.class); + String paramName = (toolParam != null && !toolParam.value().isEmpty()) ? toolParam.value() + : parameter.getName(); + + Object rawArgument = toolInputArguments.get(paramName); return buildTypedArgument(rawArgument, parameter.getParameterizedType()); }).toArray(); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java index 6e05fd80c59..8c55ed8d5e6 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; @@ -173,6 +174,117 @@ void testToolContextType() throws Exception { assertThat(result).isEqualTo("1 entries processed {foo=bar}"); } + @Test + void testToolParamAnnotationValueUsedAsBindingKey() throws Exception { + TestGenericClass testObject = new TestGenericClass(); + Method method = TestGenericClass.class.getMethod("greetWithAlias", String.class); + + ToolDefinition toolDefinition = DefaultToolDefinition.builder() + .name("greet") + .description("Greet a user with alias binding") + .inputSchema("{}") + .build(); + + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(toolDefinition) + .toolMethod(method) + .toolObject(testObject) + .build(); + + String toolInput = """ + { + "user_name": "Alice" + } + """; + + String result = callback.call(toolInput); + + assertThat(result).isEqualTo("\"Hello, Alice\""); + } + + @Test + void testToolParamEmptyValueUsesParameterName() throws Exception { + TestGenericClass testObject = new TestGenericClass(); + Method method = TestGenericClass.class.getMethod("greet", String.class); + + ToolDefinition toolDefinition = DefaultToolDefinition.builder() + .name("greet") + .description("Greet a user with implicit binding") + .inputSchema("{}") + .build(); + + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(toolDefinition) + .toolMethod(method) + .toolObject(testObject) + .build(); + + String toolInput = """ + { + "name": "Bob" + } + """; + + String result = callback.call(toolInput); + + assertThat(result).isEqualTo("\"Hello, Bob\""); + } + + @Test + void testToolParamMissingInputHandledAsNull() throws Exception { + TestGenericClass testObject = new TestGenericClass(); + Method method = TestGenericClass.class.getMethod("greetWithAlias", String.class); + + ToolDefinition toolDefinition = DefaultToolDefinition.builder() + .name("greet") + .description("Greet a user with missing input") + .inputSchema("{}") + .build(); + + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(toolDefinition) + .toolMethod(method) + .toolObject(testObject) + .build(); + + String toolInput = """ + {} + """; + + String result = callback.call(toolInput); + + assertThat(result).isEqualTo("\"Hello, null\""); + } + + @Test + void testMultipleToolParamsBinding() throws Exception { + TestGenericClass testObject = new TestGenericClass(); + Method method = TestGenericClass.class.getMethod("greetFullName", String.class, String.class); + + ToolDefinition toolDefinition = DefaultToolDefinition.builder() + .name("greetFullName") + .description("Greet a user by full name") + .inputSchema("{}") + .build(); + + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(toolDefinition) + .toolMethod(method) + .toolObject(testObject) + .build(); + + String toolInput = """ + { + "first": "Jane", + "last": "Doe" + } + """; + + String result = callback.call(toolInput); + + assertThat(result).isEqualTo("\"Hello, Jane Doe\""); + } + /** * Test class with methods that use generic types. */ @@ -195,6 +307,18 @@ public String processStringListInToolContext(ToolContext toolContext) { return context.size() + " entries processed " + context; } + public String greetWithAlias(@ToolParam("user_name") String name) { + return "Hello, " + name; + } + + public String greet(@ToolParam String name) { + return "Hello, " + name; + } + + public String greetFullName(@ToolParam("first") String first, @ToolParam("last") String last) { + return "Hello, " + first + " " + last; + } + } }