Skip to content

GH-2947: Add @ToolParam annotation support for parameter name binding #3803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,25 @@
import java.lang.annotation.Target;

/**
* Marks a tool argument.
* Marks a tool argument for method-based tools.
* <p>
* This annotation can be used to specify metadata for a tool parameter, including whether
* it is required, a description, or a custom name to bind to.
* <p>
* When the parameter name cannot be inferred (e.g. compiled without `-parameters`), the
* {@code value} field can be used to manually specify the name that should match the key
* in the tool input map.
*
* <p>
* <b>Example:</b> <pre class="code">
* public String greet(
* {@code @ToolParam(value = "user_name")} String name) {
* return "Hello, " + name;
* }
* </pre>
*
* @author Thomas Vitale
* @author Dongha Koo
* @since 1.0.0
*/
@Target({ ElementType.PARAMETER, ElementType.FIELD, ElementType.ANNOTATION_TYPE })
Expand All @@ -43,4 +59,9 @@
*/
String description() default "";

/**
* The name of the parameter to bind to.
*/
String value() default "";

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
import org.springframework.ai.tool.execution.ToolCallResultConverter;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
Expand All @@ -44,6 +46,7 @@
* A {@link ToolCallback} implementation to invoke methods as tools.
*
* @author Thomas Vitale
* @author Dongha Koo
* @since 1.0.0
*/
public final class MethodToolCallback implements ToolCallback {
Expand Down Expand Up @@ -129,13 +132,43 @@ private Map<String, Object> extractToolArguments(String toolInput) {
});
}

// Based on the implementation in MethodToolCallback.
/**
* Builds the array of arguments to be passed into the target method, based on the
* input tool arguments and method parameter metadata.
*
* <p>
* This method handles special cases like:
* <ul>
* <li>{@link ToolContext} parameters are injected directly.</li>
* <li>When a {@link ToolParam} annotation is present on a parameter, its
* {@code value} is used to bind input keys to parameters — useful when method
* parameter names are not retained (e.g. missing {@code -parameters} during
* compilation).</li>
* <li>Otherwise, falls back to {@link java.lang.reflect.Parameter#getName()}.</li>
* </ul>
*
* <p>
* Examples: <pre>{@code
* public String greet(@ToolParam("user_name") String name) {
* return "Hi, " + name;
* }
* }</pre> If the tool input contains {"user_name": "Alice"}, the {@code name}
* parameter is populated with "Alice".
* @param toolInputArguments the parsed input map from JSON
* @param toolContext optional tool context, injected if required
* @return an array of method arguments to invoke the tool method with
*/
private Object[] buildMethodArguments(Map<String, Object> toolInputArguments, @Nullable ToolContext toolContext) {
return Stream.of(this.toolMethod.getParameters()).map(parameter -> {
if (parameter.getType().isAssignableFrom(ToolContext.class)) {
return toolContext;
}
Object rawArgument = toolInputArguments.get(parameter.getName());

ToolParam toolParam = AnnotationUtils.getAnnotation(parameter, ToolParam.class);
String paramName = (toolParam != null && !toolParam.value().isEmpty()) ? toolParam.value()
: parameter.getName();

Object rawArgument = toolInputArguments.get(paramName);
return buildTypedArgument(rawArgument, parameter.getParameterizedType());
}).toArray();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;

Expand Down Expand Up @@ -173,6 +174,117 @@ void testToolContextType() throws Exception {
assertThat(result).isEqualTo("1 entries processed {foo=bar}");
}

@Test
void testToolParamAnnotationValueUsedAsBindingKey() throws Exception {
TestGenericClass testObject = new TestGenericClass();
Method method = TestGenericClass.class.getMethod("greetWithAlias", String.class);

ToolDefinition toolDefinition = DefaultToolDefinition.builder()
.name("greet")
.description("Greet a user with alias binding")
.inputSchema("{}")
.build();

MethodToolCallback callback = MethodToolCallback.builder()
.toolDefinition(toolDefinition)
.toolMethod(method)
.toolObject(testObject)
.build();

String toolInput = """
{
"user_name": "Alice"
}
""";

String result = callback.call(toolInput);

assertThat(result).isEqualTo("\"Hello, Alice\"");
}

@Test
void testToolParamEmptyValueUsesParameterName() throws Exception {
TestGenericClass testObject = new TestGenericClass();
Method method = TestGenericClass.class.getMethod("greet", String.class);

ToolDefinition toolDefinition = DefaultToolDefinition.builder()
.name("greet")
.description("Greet a user with implicit binding")
.inputSchema("{}")
.build();

MethodToolCallback callback = MethodToolCallback.builder()
.toolDefinition(toolDefinition)
.toolMethod(method)
.toolObject(testObject)
.build();

String toolInput = """
{
"name": "Bob"
}
""";

String result = callback.call(toolInput);

assertThat(result).isEqualTo("\"Hello, Bob\"");
}

@Test
void testToolParamMissingInputHandledAsNull() throws Exception {
TestGenericClass testObject = new TestGenericClass();
Method method = TestGenericClass.class.getMethod("greetWithAlias", String.class);

ToolDefinition toolDefinition = DefaultToolDefinition.builder()
.name("greet")
.description("Greet a user with missing input")
.inputSchema("{}")
.build();

MethodToolCallback callback = MethodToolCallback.builder()
.toolDefinition(toolDefinition)
.toolMethod(method)
.toolObject(testObject)
.build();

String toolInput = """
{}
""";

String result = callback.call(toolInput);

assertThat(result).isEqualTo("\"Hello, null\"");
}

@Test
void testMultipleToolParamsBinding() throws Exception {
TestGenericClass testObject = new TestGenericClass();
Method method = TestGenericClass.class.getMethod("greetFullName", String.class, String.class);

ToolDefinition toolDefinition = DefaultToolDefinition.builder()
.name("greetFullName")
.description("Greet a user by full name")
.inputSchema("{}")
.build();

MethodToolCallback callback = MethodToolCallback.builder()
.toolDefinition(toolDefinition)
.toolMethod(method)
.toolObject(testObject)
.build();

String toolInput = """
{
"first": "Jane",
"last": "Doe"
}
""";

String result = callback.call(toolInput);

assertThat(result).isEqualTo("\"Hello, Jane Doe\"");
}

/**
* Test class with methods that use generic types.
*/
Expand All @@ -195,6 +307,18 @@ public String processStringListInToolContext(ToolContext toolContext) {
return context.size() + " entries processed " + context;
}

public String greetWithAlias(@ToolParam("user_name") String name) {
return "Hello, " + name;
}

public String greet(@ToolParam String name) {
return "Hello, " + name;
}

public String greetFullName(@ToolParam("first") String first, @ToolParam("last") String last) {
return "Hello, " + first + " " + last;
}

}

}