diff --git a/go/ai/tools.go b/go/ai/tools.go index 2281499315..c280f95e2e 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -166,6 +166,20 @@ func NewToolWithInputSchema[Out any](name, description string, inputSchema map[s return &tool{Action: toolAction} } +// ToolSchema is a struct that contains the input and output schemas for a tool. +type ToolSchema struct { + Input map[string]any + Output map[string]any +} + +// NewToolWithOutputSchema creates a new [Tool] with a custom output schema. It can be passed directly to [Generate]. +func NewToolWithSchema[In, Out any](name, description string, schema ToolSchema, fn ToolFunc[In, Out]) Tool { + metadata, wrappedFn := implementTool(name, description, fn) + metadata["dynamic"] = true + toolAction := core.NewStructuredAction(name, api.ActionTypeTool, metadata, schema.Input, schema.Output, wrappedFn) + return &tool{Action: toolAction} +} + // implementTool creates the metadata and wrapped function common to both DefineTool and NewTool. func implementTool[In, Out any](name, description string, fn ToolFunc[In, Out]) (map[string]any, func(context.Context, In) (Out, error)) { metadata := map[string]any{ diff --git a/go/core/action.go b/go/core/action.go index 757f9aed6a..cfe64f0456 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -62,7 +62,24 @@ func NewAction[In, Out any]( inputSchema map[string]any, fn Func[In, Out], ) *ActionDef[In, Out, struct{}] { - return newAction(name, atype, metadata, inputSchema, + return newAction(name, atype, metadata, inputSchema, nil, + func(ctx context.Context, in In, cb noStream) (Out, error) { + return fn(ctx, in) + }) +} + +// NewStructuredAction creates a new non-streaming [Action] without registering it. +// It can be used to create a tool with a custom input and output schema. +// If either inputSchema or outputSchema are nil, they are inferred from the function's input or output api. +func NewStructuredAction[In, Out any]( + name string, + atype api.ActionType, + metadata map[string]any, + inputSchema map[string]any, + outputSchema map[string]any, + fn Func[In, Out], +) *ActionDef[In, Out, struct{}] { + return newAction(name, atype, metadata, inputSchema, outputSchema, func(ctx context.Context, in In, cb noStream) (Out, error) { return fn(ctx, in) }) @@ -77,7 +94,7 @@ func NewStreamingAction[In, Out, Stream any]( inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], ) *ActionDef[In, Out, Stream] { - return newAction(name, atype, metadata, inputSchema, fn) + return newAction(name, atype, metadata, inputSchema, nil, fn) } // DefineAction creates a new non-streaming Action and registers it. @@ -118,7 +135,7 @@ func defineAction[In, Out, Stream any]( inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], ) *ActionDef[In, Out, Stream] { - a := newAction(name, atype, metadata, inputSchema, fn) + a := newAction(name, atype, metadata, inputSchema, nil, fn) provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) r.RegisterAction(key, a) @@ -133,6 +150,7 @@ func newAction[In, Out, Stream any]( atype api.ActionType, metadata map[string]any, inputSchema map[string]any, + outputSchema map[string]any, fn StreamingFunc[In, Out, Stream], ) *ActionDef[In, Out, Stream] { if inputSchema == nil { @@ -142,10 +160,11 @@ func newAction[In, Out, Stream any]( } } - var o Out - var outputSchema map[string]any - if reflect.ValueOf(o).Kind() != reflect.Invalid { - outputSchema = InferSchemaMap(o) + if outputSchema == nil { + var o Out + if reflect.ValueOf(o).Kind() != reflect.Invalid { + outputSchema = InferSchemaMap(o) + } } var description string diff --git a/go/go.mod b/go/go.mod index cbde46d01c..059e69b3cd 100644 --- a/go/go.mod +++ b/go/go.mod @@ -27,7 +27,7 @@ require ( github.com/jackc/pgx/v5 v5.7.5 github.com/jba/slog v0.2.0 github.com/lib/pq v1.10.9 - github.com/mark3labs/mcp-go v0.29.0 + github.com/mark3labs/mcp-go v0.42.0 github.com/pgvector/pgvector-go v0.3.0 github.com/stretchr/testify v1.10.0 github.com/weaviate/weaviate v1.30.0 diff --git a/go/go.sum b/go/go.sum index 7100070ee1..9f16a28678 100644 --- a/go/go.sum +++ b/go/go.sum @@ -291,6 +291,8 @@ github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4 github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.29.0 h1:sH1NBcumKskhxqYzhXfGc201D7P76TVXiT0fGVhabeI= github.com/mark3labs/mcp-go v0.29.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.42.0 h1:gk/8nYJh8t3yroCAOBhNbYsM9TCKvkM13I5t5Hfu6Ls= +github.com/mark3labs/mcp-go v0.42.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= diff --git a/go/internal/base/validation.go b/go/internal/base/validation.go index e363b95523..7a251ca1b9 100644 --- a/go/internal/base/validation.go +++ b/go/internal/base/validation.go @@ -21,12 +21,16 @@ import ( "fmt" "strings" + "github.com/mark3labs/mcp-go/mcp" "github.com/xeipuuv/gojsonschema" ) // ValidateValue will validate any value against the expected schema. // It will return an error if it doesn't match the schema, otherwise it will return nil. func ValidateValue(data any, schema map[string]any) error { + if callToolResult, ok := data.(*mcp.CallToolResult); ok { + data = callToolResult.StructuredContent + } if schema == nil { return nil } diff --git a/go/plugins/mcp/tools.go b/go/plugins/mcp/tools.go index a3fb914413..a3f68138bb 100644 --- a/go/plugins/mcp/tools.go +++ b/go/plugins/mcp/tools.go @@ -74,6 +74,23 @@ func (c *GenkitMCPClient) getInputSchema(mcpTool mcp.Tool) (map[string]any, erro return out, nil } +// getOutputSchema returns the MCP output schema as a generic map for Genkit +func (c *GenkitMCPClient) getOutputSchema(mcpTool mcp.Tool) (map[string]any, error) { + var out map[string]any + schemaBytes, err := json.Marshal(mcpTool.OutputSchema) + if err != nil { + return nil, fmt.Errorf("failed to marshal MCP output schema for tool %s: %w", mcpTool.Name, err) + } + if err := json.Unmarshal(schemaBytes, &out); err != nil { + // Fall back to empty map if unmarshalling fails + out = map[string]any{} + } + if out == nil { + out = map[string]any{} + } + return out, nil +} + // createTool converts a single MCP tool to a Genkit tool func (c *GenkitMCPClient) createTool(mcpTool mcp.Tool) (ai.Tool, error) { // Use namespaced tool name @@ -84,8 +101,22 @@ func (c *GenkitMCPClient) createTool(mcpTool mcp.Tool) (ai.Tool, error) { if err != nil { return nil, fmt.Errorf("failed to get input schema for tool %s: %w", mcpTool.Name, err) } + outputSchema, err := c.getOutputSchema(mcpTool) + if err != nil { + return nil, fmt.Errorf("failed to get output schema for tool %s: %w", mcpTool.Name, err) + } var tool ai.Tool - if len(inputSchema) > 0 { + if len(inputSchema) > 0 && len(outputSchema) > 0 { + tool = ai.NewToolWithSchema( + namespacedToolName, + mcpTool.Description, + ai.ToolSchema{ + Input: inputSchema, + Output: outputSchema, + }, + toolFunc, + ) + } else if len(inputSchema) > 0 { tool = ai.NewToolWithInputSchema( namespacedToolName, mcpTool.Description, diff --git a/go/plugins/mcp/tools_test.go b/go/plugins/mcp/tools_test.go index 8d27470e06..a384ffa851 100644 --- a/go/plugins/mcp/tools_test.go +++ b/go/plugins/mcp/tools_test.go @@ -15,10 +15,12 @@ package mcp import ( + "context" "encoding/json" "testing" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" ) func asMap(t *testing.T, v any, label string) map[string]any { @@ -161,3 +163,111 @@ func TestPrepareToolArguments(t *testing.T) { t.Fatalf("expected error for nil args with required field") } } + +// TestToolOutputSchema tests that both input and output schemas are correctly retrieved +// from the MCP server. +func TestToolOutputSchema(t *testing.T) { + // Start a test MCP server with a tool that has an input and output schema. + type InputSchema struct { + City string + } + type OutputSchema struct { + Weather string + Temperature int + } + mcpServer := server.NewMCPServer("test", "1.0.0", + server.WithToolCapabilities(true), + ) + mcpServer.AddTool( + mcp.NewTool("getWeather", + mcp.WithInputSchema[InputSchema](), + mcp.WithOutputSchema[OutputSchema](), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultStructured( + OutputSchema{Weather: "Sunny, 25°C", Temperature: 25}, + "{\"weather\": \"Sunny, 25°C\", \"temperature\": 25}", + ), nil + }, + ) + // Start the stdio server + sseServer := server.NewTestServer(mcpServer) + defer sseServer.Close() + client, err := NewGenkitMCPClient(MCPClientOptions{ + Name: "test", + SSE: &SSEConfig{ + BaseURL: sseServer.URL + "/sse", + }, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Disconnect() + // Retrieve tools from the MCP server + tools, err := client.GetActiveTools(context.Background(), nil) + if err != nil { + t.Fatalf("GetActiveTools error: %v", err) + } + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + for _, tool := range tools { + if tool.Name() != "test_getWeather" { + t.Fatalf("unexpected tool: %s", tool.Name()) + } + inputSchema := tool.Definition().InputSchema + assertSchemaProperty(t, inputSchema, "City", "string") + + outputSchema := tool.Definition().OutputSchema + assertSchemaProperty(t, outputSchema, "Weather", "string") + assertSchemaProperty(t, outputSchema, "Temperature", "integer") + + result, err := tool.RunRaw(t.Context(), InputSchema{ + City: "Paris", + }) + if err != nil { + t.Fatalf("RunRaw error: %v", err) + } + if result == nil { + t.Fatalf("RunRaw result is nil") + } + toolResult := ParseMapToStruct[mcp.CallToolResult](t, result) + toolResultOutput := ParseMapToStruct[OutputSchema](t, toolResult.StructuredContent) + if toolResultOutput.Weather != "Sunny, 25°C" { + t.Fatalf("unexpected weather: %s", toolResultOutput.Weather) + } + if toolResultOutput.Temperature != 25 { + t.Fatalf("unexpected temperature: %d", toolResultOutput.Temperature) + } + } +} + +func ParseMapToStruct[T any](t *testing.T, v any) T { + t.Helper() + var result T + jsonBytes, err := json.Marshal(v) + if err != nil { + t.Fatalf("failed to marshal map to JSON: %v", err) + } + err = json.Unmarshal(jsonBytes, &result) + if err != nil { + t.Fatalf("failed to unmarshal JSON to struct: %v", err) + } + return result +} + +// assertSchemaProperty asserts that a property in a schema is present and of the expected type. +func assertSchemaProperty(t *testing.T, schema map[string]any, propName string, propType string) { + t.Helper() + if schema == nil { + t.Fatalf("schema is nil") + } + if props, ok := schema["properties"].(map[string]any); !ok { + t.Fatalf("schema properties is nil") + } else if propValue, ok := props[propName].(map[string]any); !ok { + t.Fatalf("schema property %s is nil. schema: %v", propName, schema) + } else if propValue["type"] != propType { + t.Fatalf("schema property %s type is %s, expected %s", + propName, propValue["type"], propType) + } +}