From 828f0af8c52ede0f9dd87be92a20249d4ea1bfa8 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Fri, 18 Apr 2025 22:10:24 -0700 Subject: [PATCH 1/9] feat: add MCP support (cherry picked from commit 54b4c1fd08a2ba90401af8afe3406484cc20e922) --- go.mod | 4 +- go.sum | 8 +- pkg/cli/gptscript.go | 2 +- pkg/engine/engine.go | 26 ++ pkg/loader/loader.go | 54 ++- pkg/mcp/loader.go | 264 +++++++++++++ pkg/mcp/runner.go | 51 +++ pkg/tests/runner2_test.go | 352 ++++++++++++++++++ .../testdata/TestMCPLoad/call1-resp.golden | 9 + pkg/tests/testdata/TestMCPLoad/call1.golden | 3 + pkg/tests/testdata/TestMCPLoad/step1.golden | 6 + pkg/types/tool.go | 21 +- pkg/types/toolstring.go | 4 + 13 files changed, 784 insertions(+), 20 deletions(-) create mode 100644 pkg/mcp/loader.go create mode 100644 pkg/mcp/runner.go create mode 100644 pkg/tests/testdata/TestMCPLoad/call1-resp.golden create mode 100644 pkg/tests/testdata/TestMCPLoad/call1.golden create mode 100644 pkg/tests/testdata/TestMCPLoad/step1.golden diff --git a/go.mod b/go.mod index f803a3b9..35c9689e 100644 --- a/go.mod +++ b/go.mod @@ -18,10 +18,11 @@ require ( github.com/gptscript-ai/chat-completion-client v0.0.0-20250224164718-139cb4507b1d github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 - github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee + github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9 github.com/hexops/autogold/v2 v2.2.1 github.com/hexops/valast v1.4.4 github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 + github.com/mark3labs/mcp-go v0.21.1 github.com/mholt/archives v0.1.0 github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 @@ -122,6 +123,7 @@ require ( github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.5.4 // indirect github.com/yuin/goldmark-emoji v1.0.2 // indirect go4.org v0.0.0-20230225012048-214862532bf5 // indirect diff --git a/go.sum b/go.sum index 74341af5..95e6b1a7 100644 --- a/go.sum +++ b/go.sum @@ -203,8 +203,8 @@ github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7J github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI= github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q= -github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee h1:70PHW6Xw70yNNZ5aX936XqcMLwNmfMZpCV3FCOGKpxE= -github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee/go.mod h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY= +github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9 h1:wQC8sKyeGA50WnCEG+Jo5FNRIkuX3HX8d3ubyWCCoI8= +github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9/go.mod h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -270,6 +270,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.21.1 h1:7Ek6KPIIbMhEYHRiRIg6K6UAgNZCJaHKQp926MNr6V0= +github.com/mark3labs/mcp-go v0.21.1/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -406,6 +408,8 @@ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavM github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.3.7/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.5.4 h1:2uY/xC0roWy8IBEGLgB1ywIoEJFGmRrX21YQcvGZzjU= diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 4b0642d2..b5a823b2 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -215,7 +215,7 @@ func (r *GPTScript) listTools(ctx context.Context, gptScript *gptscript.GPTScrip // Don't print instructions tool.Instructions = "" - lines = append(lines, tool.String()) + lines = append(lines, tool.Print()) } fmt.Println(strings.Join(lines, "\n---\n")) return nil diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index abf45e8c..39357e9a 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/gptscript-ai/gptscript/pkg/counter" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) @@ -41,6 +42,11 @@ type Engine struct { RuntimeManager RuntimeManager Env []string Progress chan<- types.CompletionStatus + MCPRunner MCPRunner +} + +type MCPRunner interface { + Run(ctx context.Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) } type State struct { @@ -307,6 +313,21 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too return nil } +func (e *Engine) runMCPInvoke(ctx Context, tool types.Tool, input string) (*Return, error) { + runner := e.MCPRunner + if runner == nil { + runner = mcp.DefaultRunner + } + output, err := runner.Run(ctx.Ctx, e.Progress, tool, input) + if err != nil { + return nil, fmt.Errorf("failed to run MCP invoke: %w", err) + } + + return &Return{ + Result: &output, + }, nil +} + func (e *Engine) runCommandTools(ctx Context, tool types.Tool, input string) (*Return, error) { if tool.IsHTTP() { return e.runHTTP(ctx, tool, input) @@ -342,6 +363,10 @@ func (e *Engine) Start(ctx Context, input string) (ret *Return, err error) { } }() + if tool.IsMCPInvoke() { + return e.runMCPInvoke(ctx, tool, input) + } + if tool.IsCommand() { return e.runCommandTools(ctx, tool, input) } @@ -378,6 +403,7 @@ func addUpdateSystem(ctx Context, tool types.Tool, msgs []types.CompletionMessag instructions = append(instructions, context.Content) } + tool.Instructions = strings.TrimPrefix(tool.Instructions, types.PromptPrefix) if tool.Instructions != "" { instructions = append(instructions, tool.Instructions) } diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index e70827c6..2a6f2433 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -20,6 +20,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/hash" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/openapi" "github.com/gptscript-ai/gptscript/pkg/parser" "github.com/gptscript-ai/gptscript/pkg/system" @@ -155,7 +156,23 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T { return openAPIDocument } -func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) { +func processMCP(ctx context.Context, tool []types.Tool, mcpLoader MCPLoader) (result []types.Tool, _ error) { + for _, t := range tool { + if t.IsMCP() { + mcpTools, err := mcpLoader.Load(ctx, t) + if err != nil { + return nil, fmt.Errorf("error loading MCP tools: %w", err) + } + result = append(result, mcpTools...) + } else { + result = append(result, t) + } + } + + return result, nil +} + +func readTool(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) { data := base.Content var ( @@ -212,6 +229,11 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base return nil, fmt.Errorf("no tools found in %s", base) } + tools, err := processMCP(ctx, tools, mcp) + if err != nil { + return nil, err + } + var ( localTools = types.ToolSet{} targetTools []types.Tool @@ -279,17 +301,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base localTools[strings.ToLower(tool.Name)] = tool } - return linkAll(ctx, cache, prg, base, targetTools, localTools, defaultModel) + return linkAll(ctx, cache, mcp, prg, base, targetTools, localTools, defaultModel) } -func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) { +func linkAll(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) { localToolsMapping := make(map[string]string, len(tools)) for _, localTool := range localTools { localToolsMapping[strings.ToLower(localTool.Name)] = localTool.ID } for _, tool := range tools { - tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping, defaultModel) + tool, err := link(ctx, cache, mcp, prg, base, tool, localTools, localToolsMapping, defaultModel) if err != nil { return nil, err } @@ -298,7 +320,7 @@ func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base return } -func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel string) (types.Tool, error) { +func link(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel string) (types.Tool, error) { if existing, ok := prg.ToolSet[tool.ID]; ok { return existing, nil } @@ -323,7 +345,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so linkedTool = existing } else { var err error - linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping, defaultModel) + linkedTool, err = link(ctx, cache, mcp, prg, base, localTool, localTools, localToolsMapping, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err) } @@ -333,7 +355,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so toolNames[targetToolName] = struct{}{} } else { toolName, subTool := types.SplitToolRef(targetToolName) - resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool, defaultModel) + resolvedTools, err := resolve(ctx, cache, mcp, prg, base, toolName, subTool, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err) } @@ -373,7 +395,7 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. prg := types.Program{ ToolSet: types.ToolSet{}, } - tools, err := readTool(ctx, opt.Cache, &prg, &source{ + tools, err := readTool(ctx, opt.Cache, opt.MCPLoader, &prg, &source{ Content: []byte(content), Path: locationPath, Name: locationName, @@ -390,6 +412,11 @@ type Options struct { Cache *cache.Client Location string DefaultModel string + MCPLoader MCPLoader +} + +type MCPLoader interface { + Load(ctx context.Context, tool types.Tool) ([]types.Tool, error) } func complete(opts ...Options) (result Options) { @@ -397,6 +424,7 @@ func complete(opts ...Options) (result Options) { result.Cache = types.FirstSet(opt.Cache, result.Cache) result.Location = types.FirstSet(opt.Location, result.Location) result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) + result.MCPLoader = types.FirstSet(opt.MCPLoader, result.MCPLoader) } if result.Location == "" { @@ -407,6 +435,10 @@ func complete(opts ...Options) (result Options) { result.DefaultModel = builtin.GetDefaultModel() } + if result.MCPLoader == nil { + result.MCPLoader = mcp.DefaultLoader + } + return } @@ -430,7 +462,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty Name: name, ToolSet: types.ToolSet{}, } - tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName, opt.DefaultModel) + tools, err := resolve(ctx, opt.Cache, opt.MCPLoader, &prg, &source{}, name, subToolName, opt.DefaultModel) if err != nil { return types.Program{}, err } @@ -438,7 +470,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty return prg, nil } -func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { +func resolve(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { if subTool == "" { t, ok := builtin.DefaultModel(name, defaultModel) if ok { @@ -452,7 +484,7 @@ func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base return nil, err } - result, err := readTool(ctx, cache, prg, s, subTool, defaultModel) + result, err := readTool(ctx, cache, mcp, prg, s, subTool, defaultModel) if err != nil { return nil, err } diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go new file mode 100644 index 00000000..b4b33ba6 --- /dev/null +++ b/pkg/mcp/loader.go @@ -0,0 +1,264 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "slices" + "strings" + "sync" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/gptscript-ai/gptscript/pkg/hash" + "github.com/gptscript-ai/gptscript/pkg/types" + "github.com/gptscript-ai/gptscript/pkg/version" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" +) + +var ( + DefaultLoader = &Local{} + DefaultRunner = DefaultLoader +) + +type Local struct { + nextID int64 + lock sync.Mutex + sessions map[string]*Session +} + +type Session struct { + ID string + InitResult *mcp.InitializeResult + Client client.MCPClient + Config ServerConfig +} + +type Config struct { + MCPServers map[string]ServerConfig `json:"mcpServers"` +} + +type ServerConfig struct { + DisableInstruction bool `json:"disableInstruction"` + Command string `json:"command"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + Server string `json:"server"` + URL string `json:"url"` + BaseURL string `json:"baseURL,omitempty"` + Headers map[string]string `json:"headers"` +} + +func (s *ServerConfig) GetBaseURL() string { + if s.BaseURL != "" { + return s.BaseURL + } + if s.Server != "" { + return s.Server + } + return s.URL +} + +func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool, _ error) { + if !tool.IsMCP() { + return []types.Tool{tool}, nil + } + + _, configData, _ := strings.Cut(tool.Instructions, "\n") + var servers Config + + if err := json.Unmarshal([]byte(strings.TrimSpace(configData)), &servers); err != nil { + return nil, fmt.Errorf("failed to parse MCP configuration: %w\n%s", err, configData) + } + + if len(servers.MCPServers) == 0 { + // Try to load just one server + var server ServerConfig + if err := json.Unmarshal([]byte(strings.TrimSpace(configData)), &server); err != nil { + return nil, fmt.Errorf("failed to parse single MCP server configuration: %w\n%s", err, configData) + } + if server.Command == "" && server.URL == "" && server.Server == "" { + return nil, fmt.Errorf("no MCP server configuration found in tool instructions: %s", configData) + } + servers.MCPServers = map[string]ServerConfig{ + "default": server, + } + } + + if len(servers.MCPServers) > 1 { + return nil, fmt.Errorf("only a single MCP server definition is support") + } + + for _, server := range slices.Sorted(maps.Keys(servers.MCPServers)) { + session, err := l.loadSession(ctx, servers.MCPServers[server]) + if err != nil { + return nil, fmt.Errorf("failed to load MCP session for server %s: %w", server, err) + } + + return l.sessionToTools(ctx, session, tool.Name) + } + + // This should never happen, but just in case + return nil, fmt.Errorf("no MCP server configuration found in tool instructions: %s", configData) +} + +func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName string) ([]types.Tool, error) { + tools, err := session.Client.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to list tools: %w", err) + } + + toolDefs := []types.Tool{{ /* this is a placeholder for main tool */ }} + var toolNames []string + + for _, tool := range tools.Tools { + var schema openapi3.Schema + + schemaData, err := json.Marshal(tool.InputSchema) + if err != nil { + panic(err) + } + + if tool.Name == "" { + // I dunno, bad tool? + continue + } + + if err := json.Unmarshal(schemaData, &schema); err != nil { + return nil, fmt.Errorf("failed to unmarshal tool input schema: %w", err) + } + + annotations, err := json.Marshal(tool.Annotations) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool annotations: %w", err) + } + + toolDef := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: tool.Name, + Description: tool.Description, + Arguments: &schema, + }, + Instructions: types.MCPInvokePrefix + "." + tool.Name + " " + session.ID + " " + tool.Name, + }, + } + + if string(annotations) != "{}" { + toolDef.MetaData = map[string]string{ + "mcp-tool-annotations": string(annotations), + } + } + + if tool.Annotations.Title != "" && !slices.Contains(strings.Fields(tool.Annotations.Title), "as") { + toolDef.Name = tool.Annotations.Title + " as " + tool.Name + } + + toolDefs = append(toolDefs, toolDef) + toolNames = append(toolNames, tool.Name) + } + + main := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: toolName, + Description: session.InitResult.ServerInfo.Name, + Export: toolNames, + }, + MetaData: map[string]string{ + "bundle": "true", + }, + }, + } + + if session.InitResult.Instructions != "" { + data, _ := json.Marshal(map[string]any{ + "tools": toolNames, + "instructions": session.InitResult.Instructions, + }) + toolDefs = append(toolDefs, types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: session.ID, + Type: "context", + }, + Instructions: types.EchoPrefix + "\n" + `# START MCP SERVER INFO: ` + session.InitResult.ServerInfo.Name + "\n" + + `You have available the following tools from an MCP Server that has provided the following additional instructions` + "\n" + + string(data) + "\n" + + `# END MCP SERVER INFO` + "\n", + }, + }) + + main.ExportContext = append(main.ExportContext, session.ID) + } + + toolDefs[0] = main + return toolDefs, nil +} + +func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, error) { + id := hash.Digest(server) + l.lock.Lock() + existing, ok := l.sessions[id] + l.lock.Unlock() + if ok { + return existing, nil + } + + var ( + c client.MCPClient + err error + ) + + if server.Command != "" { + env := make([]string, 0, len(server.Env)) + for k, v := range server.Env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + c, err = client.NewStdioMCPClient(server.Command, env, server.Args...) + if err != nil { + return nil, fmt.Errorf("failed to create MCP stdio client: %w", err) + } + } else { + url := server.URL + if url == "" { + url = server.Server + } + c, err = client.NewSSEMCPClient(url, client.WithHeaders(server.Headers)) + if err != nil { + return nil, fmt.Errorf("failed to create MCP HTTP client: %w", err) + } + } + + var initRequest mcp.InitializeRequest + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: version.ProgramName, + Version: version.Get().String(), + } + + initResult, err := c.Initialize(ctx, initRequest) + if err != nil { + return nil, fmt.Errorf("failed to initialize MCP client: %w", err) + } + + result := &Session{ + ID: id, + InitResult: initResult, + Client: c, + Config: server, + } + + l.lock.Lock() + defer l.lock.Unlock() + + if existing, ok := l.sessions[id]; ok { + return existing, c.Close() + } + + if l.sessions == nil { + l.sessions = make(map[string]*Session) + } + l.sessions[id] = result + return result, nil +} diff --git a/pkg/mcp/runner.go b/pkg/mcp/runner.go new file mode 100644 index 00000000..b6d5f584 --- /dev/null +++ b/pkg/mcp/runner.go @@ -0,0 +1,51 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/gptscript-ai/gptscript/pkg/types" + "github.com/mark3labs/mcp-go/mcp" +) + +func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) { + fields := strings.Fields(tool.Instructions) + if len(fields) < 3 { + return "", fmt.Errorf("invalid mcp call, invalid number of fields in %s", tool.Instructions) + } + + id := fields[1] + toolName := fields[2] + arguments := map[string]any{} + + if input != "" { + if err := json.Unmarshal([]byte(input), &arguments); err != nil { + return "", fmt.Errorf("failed to unmarshal input: %w", err) + } + } + + l.lock.Lock() + session, ok := l.sessions[id] + l.lock.Unlock() + if !ok { + return "", fmt.Errorf("session not found for MCP server %s", id) + } + + request := mcp.CallToolRequest{} + request.Params.Name = toolName + request.Params.Arguments = arguments + + result, err := session.Client.CallTool(ctx, request) + if err != nil { + return "", fmt.Errorf("failed to call tool %s: %w", toolName, err) + } + + str, err := json.Marshal(result) + if err != nil { + return "", fmt.Errorf("failed to marshal result: %w", err) + } + + return string(str), nil +} diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go index f5de8e10..3c4264a4 100644 --- a/pkg/tests/runner2_test.go +++ b/pkg/tests/runner2_test.go @@ -8,6 +8,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/tests/tester" + "github.com/gptscript-ai/gptscript/pkg/types" "github.com/hexops/autogold/v2" "github.com/stretchr/testify/require" ) @@ -203,3 +204,354 @@ echo "${GPTSCRIPT_INPUT}" require.NoError(t, err) autogold.Expect(map[string]interface{}{"foo": "baz", "start": true}).Equal(t, data) } + +func TestMCPLoad(t *testing.T) { + r := tester.NewRunner(t) + prg, err := loader.ProgramFromSource(context.Background(), ` +name: mcp + +#!mcp + +{ + "mcpServers": { + "sqlite": { + "command": "docker", + "args": [ + "run", + "--rm", + "-i", + "-v", + "mcp-test:/mcp", + "mcp/sqlite@sha256:007ccae941a6f6db15b26ee41d92edda50ce157176d9273449e8b3f51d979c70", + "--db-path", + "/mcp/test.db" + ] + } + } +} +`, "") + require.NoError(t, err) + + autogold.Expect(types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "mcp", + Description: "sqlite", + ModelName: "gpt-4o", + Export: []string{ + "read_query", + "write_query", + "create_table", + "list_tables", + "describe_table", + "append_insight", + }, + }, + MetaData: map[string]string{"bundle": "true"}, + }, + ID: "inline:mcp", + ToolMapping: map[string][]types.ToolReference{ + "append_insight": {{ + Reference: "append_insight", + ToolID: "inline:append_insight", + }}, + "create_table": {{ + Reference: "create_table", + ToolID: "inline:create_table", + }}, + "describe_table": {{ + Reference: "describe_table", + ToolID: "inline:describe_table", + }}, + "list_tables": {{ + Reference: "list_tables", + ToolID: "inline:list_tables", + }}, + "read_query": {{ + Reference: "read_query", + ToolID: "inline:read_query", + }}, + "write_query": {{ + Reference: "write_query", + ToolID: "inline:write_query", + }}, + }, + LocalTools: map[string]string{ + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query", + }, + Source: types.ToolSource{Location: "inline"}, + WorkingDir: ".", + }).Equal(t, prg.ToolSet[prg.EntryToolID]) + autogold.Expect(7).Equal(t, len(prg.ToolSet[prg.EntryToolID].LocalTools)) + data, _ := json.MarshalIndent(prg.ToolSet, "", " ") + autogold.Expect(`{ + "inline:append_insight": { + "name": "append_insight", + "description": "Add a business insight to the memo", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "insight": { + "description": "Business insight discovered from data analysis", + "type": "string" + } + }, + "required": [ + "insight" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f append_insight", + "id": "inline:append_insight", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:create_table": { + "name": "create_table", + "description": "Create a new table in the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "CREATE TABLE SQL statement", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f create_table", + "id": "inline:create_table", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:describe_table": { + "name": "describe_table", + "description": "Get the schema information for a specific table", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "table_name": { + "description": "Name of the table to describe", + "type": "string" + } + }, + "required": [ + "table_name" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f describe_table", + "id": "inline:describe_table", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:list_tables": { + "name": "list_tables", + "description": "List all tables in the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f list_tables", + "id": "inline:list_tables", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:mcp": { + "name": "mcp", + "description": "sqlite", + "modelName": "gpt-4o", + "internalPrompt": null, + "export": [ + "read_query", + "write_query", + "create_table", + "list_tables", + "describe_table", + "append_insight" + ], + "metaData": { + "bundle": "true" + }, + "id": "inline:mcp", + "toolMapping": { + "append_insight": [ + { + "reference": "append_insight", + "toolID": "inline:append_insight" + } + ], + "create_table": [ + { + "reference": "create_table", + "toolID": "inline:create_table" + } + ], + "describe_table": [ + { + "reference": "describe_table", + "toolID": "inline:describe_table" + } + ], + "list_tables": [ + { + "reference": "list_tables", + "toolID": "inline:list_tables" + } + ], + "read_query": [ + { + "reference": "read_query", + "toolID": "inline:read_query" + } + ], + "write_query": [ + { + "reference": "write_query", + "toolID": "inline:write_query" + } + ] + }, + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:read_query": { + "name": "read_query", + "description": "Execute a SELECT query on the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "SELECT SQL query to execute", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f read_query", + "id": "inline:read_query", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:write_query": { + "name": "write_query", + "description": "Execute an INSERT, UPDATE, or DELETE query on the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "SQL query to execute", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f write_query", + "id": "inline:write_query", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + } +}`).Equal(t, string(data)) + + prg.EntryToolID = prg.ToolSet[prg.EntryToolID].LocalTools["read_query"] + resp, err := r.Chat(context.Background(), nil, prg, nil, `{"query": "SELECT 1"}`, runner.RunOptions{}) + r.AssertStep(t, resp, err) +} diff --git a/pkg/tests/testdata/TestMCPLoad/call1-resp.golden b/pkg/tests/testdata/TestMCPLoad/call1-resp.golden new file mode 100644 index 00000000..2861a036 --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/call1-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 1" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestMCPLoad/call1.golden b/pkg/tests/testdata/TestMCPLoad/call1.golden new file mode 100644 index 00000000..31048a88 --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/call1.golden @@ -0,0 +1,3 @@ +`{ + "model": "gpt-4o" +}` diff --git a/pkg/tests/testdata/TestMCPLoad/step1.golden b/pkg/tests/testdata/TestMCPLoad/step1.golden new file mode 100644 index 00000000..ae20c8ed --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/step1.golden @@ -0,0 +1,6 @@ +`{ + "done": true, + "content": "{\"content\":[{\"type\":\"text\",\"text\":\"[{'1': 1}]\"}]}", + "toolID": "", + "state": null +}` diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 3d48c6e1..6f16a7ed 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -16,11 +16,14 @@ import ( ) const ( - DaemonPrefix = "#!sys.daemon" - OpenAPIPrefix = "#!sys.openapi" - EchoPrefix = "#!sys.echo" - CallPrefix = "#!sys.call" - CommandPrefix = "#!" + DaemonPrefix = "#!sys.daemon" + OpenAPIPrefix = "#!sys.openapi" + EchoPrefix = "#!sys.echo" + CallPrefix = "#!sys.call" + MCPPrefix = "#!mcp" + MCPInvokePrefix = "#!sys.mcp.invoke" + CommandPrefix = "#!" + PromptPrefix = "!!" ) var ( @@ -876,6 +879,14 @@ func (t Tool) IsDaemon() bool { return strings.HasPrefix(t.Instructions, DaemonPrefix) } +func (t Tool) IsMCP() bool { + return strings.HasPrefix(t.Instructions, MCPPrefix) +} + +func (t Tool) IsMCPInvoke() bool { + return strings.HasPrefix(t.Instructions, MCPInvokePrefix) +} + func (t Tool) IsOpenAPI() bool { return strings.HasPrefix(t.Instructions, OpenAPIPrefix) } diff --git a/pkg/types/toolstring.go b/pkg/types/toolstring.go index b5e0d1d5..fe9d7dde 100644 --- a/pkg/types/toolstring.go +++ b/pkg/types/toolstring.go @@ -44,6 +44,10 @@ func ToDisplayText(tool Tool, input string) string { } func ToSysDisplayString(id string, args map[string]string) (string, error) { + if suffix, ok := strings.CutPrefix(id, "sys.mcp.invoke."); ok { + return fmt.Sprintf("Invoking MCP `%s`", suffix), nil + } + switch id { case "sys.append": return fmt.Sprintf("Appending to file `%s`", args["filename"]), nil From f186140b6fc470cc0fe90f60df9b258c60e5f566 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Wed, 30 Apr 2025 20:37:58 -0400 Subject: [PATCH 2/9] fix: stop spinning up multiple servers The gob encoding implementation doesn't handle map key ordering so the hash we used wasn't consistent. This change switches to using slices of strings for headers and env. Also, fix the linting errors. Signed-off-by: Donnie Adams --- pkg/loader/openapi_test.go | 24 ++++++++++++-------- pkg/mcp/loader.go | 45 ++++++++++++++++++++------------------ 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/pkg/loader/openapi_test.go b/pkg/loader/openapi_test.go index 423246d1..26561538 100644 --- a/pkg/loader/openapi_test.go +++ b/pkg/loader/openapi_test.go @@ -26,7 +26,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err, "failed to read openapi v3") require.Equal(t, 3, numOpenAPITools(prgv3.ToolSet), "expected 3 openapi tools") @@ -35,7 +35,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.json") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2json, &source{Content: datav2}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2json, &source{Content: datav2}, "", "") require.NoError(t, err, "failed to read openapi v2") require.Equal(t, 3, numOpenAPITools(prgv2json.ToolSet), "expected 3 openapi tools") @@ -44,7 +44,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav2, err = os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2yaml, &source{Content: datav2}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2yaml, &source{Content: datav2}, "", "") require.NoError(t, err, "failed to read openapi v2 (yaml)") require.Equal(t, 3, numOpenAPITools(prgv2yaml.ToolSet), "expected 3 openapi tools") @@ -57,7 +57,7 @@ func TestOpenAPIv3(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -69,7 +69,7 @@ func TestOpenAPIv3NoOperationIDs(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -81,7 +81,7 @@ func TestOpenAPIv2(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2, &source{Content: datav2}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) @@ -94,7 +94,7 @@ func TestOpenAPIv3Revamp(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -107,7 +107,7 @@ func TestOpenAPIv3NoOperationIDsRevamp(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -120,8 +120,14 @@ func TestOpenAPIv2Revamp(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2, &source{Content: datav2}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) } + +type fakeMCPLoader struct{} + +func (fakeMCPLoader) Load(context.Context, types.Tool) ([]types.Tool, error) { + return nil, nil +} diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go index b4b33ba6..897dd3e9 100644 --- a/pkg/mcp/loader.go +++ b/pkg/mcp/loader.go @@ -23,7 +23,6 @@ var ( ) type Local struct { - nextID int64 lock sync.Mutex sessions map[string]*Session } @@ -39,15 +38,17 @@ type Config struct { MCPServers map[string]ServerConfig `json:"mcpServers"` } +// ServerConfig represents an MCP server configuration for tools calls. +// It is important that this type doesn't have any maps. type ServerConfig struct { - DisableInstruction bool `json:"disableInstruction"` - Command string `json:"command"` - Args []string `json:"args"` - Env map[string]string `json:"env"` - Server string `json:"server"` - URL string `json:"url"` - BaseURL string `json:"baseURL,omitempty"` - Headers map[string]string `json:"headers"` + DisableInstruction bool `json:"disableInstruction"` + Command string `json:"command"` + Args []string `json:"args"` + Env []string `json:"env"` + Server string `json:"server"` + URL string `json:"url"` + BaseURL string `json:"baseURL,omitempty"` + Headers []string `json:"headers"` } func (s *ServerConfig) GetBaseURL() string { @@ -62,12 +63,12 @@ func (s *ServerConfig) GetBaseURL() string { func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool, _ error) { if !tool.IsMCP() { - return []types.Tool{tool}, nil + return nil, nil } _, configData, _ := strings.Cut(tool.Instructions, "\n") - var servers Config + var servers Config if err := json.Unmarshal([]byte(strings.TrimSpace(configData)), &servers); err != nil { return nil, fmt.Errorf("failed to parse MCP configuration: %w\n%s", err, configData) } @@ -87,10 +88,10 @@ func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool, } if len(servers.MCPServers) > 1 { - return nil, fmt.Errorf("only a single MCP server definition is support") + return nil, fmt.Errorf("only a single MCP server definition is supported") } - for _, server := range slices.Sorted(maps.Keys(servers.MCPServers)) { + for server := range maps.Keys(servers.MCPServers) { session, err := l.loadSession(ctx, servers.MCPServers[server]) if err != nil { return nil, fmt.Errorf("failed to load MCP session for server %s: %w", server, err) @@ -202,6 +203,7 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, l.lock.Lock() existing, ok := l.sessions[id] l.lock.Unlock() + if ok { return existing, nil } @@ -210,13 +212,8 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, c client.MCPClient err error ) - if server.Command != "" { - env := make([]string, 0, len(server.Env)) - for k, v := range server.Env { - env = append(env, fmt.Sprintf("%s=%s", k, v)) - } - c, err = client.NewStdioMCPClient(server.Command, env, server.Args...) + c, err = client.NewStdioMCPClient(server.Command, server.Env, server.Args...) if err != nil { return nil, fmt.Errorf("failed to create MCP stdio client: %w", err) } @@ -225,7 +222,13 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, if url == "" { url = server.Server } - c, err = client.NewSSEMCPClient(url, client.WithHeaders(server.Headers)) + + headers := make(map[string]string, len(server.Headers)) + for _, h := range server.Headers { + k, v, _ := strings.Cut(h, "=") + headers[k] = v + } + c, err = client.NewSSEMCPClient(url, client.WithHeaders(headers)) if err != nil { return nil, fmt.Errorf("failed to create MCP HTTP client: %w", err) } @@ -252,7 +255,7 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, l.lock.Lock() defer l.lock.Unlock() - if existing, ok := l.sessions[id]; ok { + if existing, ok = l.sessions[id]; ok { return existing, c.Close() } From a85e68db98ea88f033ae713e99193f124a2eb11c Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Thu, 1 May 2025 12:16:00 -0400 Subject: [PATCH 3/9] Fix tests Signed-off-by: Donnie Adams --- pkg/mcp/loader.go | 2 +- pkg/mcp/runner.go | 8 ++++++-- pkg/tests/runner2_test.go | 12 ++++++------ pkg/types/tool.go | 2 +- pkg/types/toolstring.go | 2 +- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go index 897dd3e9..bf377bc5 100644 --- a/pkg/mcp/loader.go +++ b/pkg/mcp/loader.go @@ -142,7 +142,7 @@ func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName s Description: tool.Description, Arguments: &schema, }, - Instructions: types.MCPInvokePrefix + "." + tool.Name + " " + session.ID + " " + tool.Name, + Instructions: types.MCPInvokePrefix + tool.Name + " " + session.ID, }, } diff --git a/pkg/mcp/runner.go b/pkg/mcp/runner.go index b6d5f584..448d58a7 100644 --- a/pkg/mcp/runner.go +++ b/pkg/mcp/runner.go @@ -12,12 +12,16 @@ import ( func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) { fields := strings.Fields(tool.Instructions) - if len(fields) < 3 { + if len(fields) < 2 { return "", fmt.Errorf("invalid mcp call, invalid number of fields in %s", tool.Instructions) } id := fields[1] - toolName := fields[2] + toolName, ok := strings.CutPrefix(fields[0], types.MCPInvokePrefix) + if !ok { + return "", fmt.Errorf("invalid mcp call, invalid tool name in %s", tool.Instructions) + } + arguments := map[string]any{} if input != "" { diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go index 3c4264a4..ba704142 100644 --- a/pkg/tests/runner2_test.go +++ b/pkg/tests/runner2_test.go @@ -308,7 +308,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f append_insight", + "instructions": "#!sys.mcp.invoke.append_insight e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", "id": "inline:append_insight", "localTools": { "append_insight": "inline:append_insight", @@ -341,7 +341,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f create_table", + "instructions": "#!sys.mcp.invoke.create_table e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", "id": "inline:create_table", "localTools": { "append_insight": "inline:append_insight", @@ -374,7 +374,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f describe_table", + "instructions": "#!sys.mcp.invoke.describe_table e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", "id": "inline:describe_table", "localTools": { "append_insight": "inline:append_insight", @@ -398,7 +398,7 @@ name: mcp "arguments": { "type": "object" }, - "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f list_tables", + "instructions": "#!sys.mcp.invoke.list_tables e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", "id": "inline:list_tables", "localTools": { "append_insight": "inline:append_insight", @@ -500,7 +500,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f read_query", + "instructions": "#!sys.mcp.invoke.read_query e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", "id": "inline:read_query", "localTools": { "append_insight": "inline:append_insight", @@ -533,7 +533,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f write_query", + "instructions": "#!sys.mcp.invoke.write_query e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", "id": "inline:write_query", "localTools": { "append_insight": "inline:append_insight", diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 6f16a7ed..10b47c77 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -21,7 +21,7 @@ const ( EchoPrefix = "#!sys.echo" CallPrefix = "#!sys.call" MCPPrefix = "#!mcp" - MCPInvokePrefix = "#!sys.mcp.invoke" + MCPInvokePrefix = "#!sys.mcp.invoke." CommandPrefix = "#!" PromptPrefix = "!!" ) diff --git a/pkg/types/toolstring.go b/pkg/types/toolstring.go index fe9d7dde..8d379f14 100644 --- a/pkg/types/toolstring.go +++ b/pkg/types/toolstring.go @@ -44,7 +44,7 @@ func ToDisplayText(tool Tool, input string) string { } func ToSysDisplayString(id string, args map[string]string) (string, error) { - if suffix, ok := strings.CutPrefix(id, "sys.mcp.invoke."); ok { + if suffix, ok := strings.CutPrefix(id, MCPInvokePrefix); ok { return fmt.Sprintf("Invoking MCP `%s`", suffix), nil } From 72a426dc1b1b9481920e106afd76dbe6b814b8e9 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 2 May 2025 09:44:12 -0400 Subject: [PATCH 4/9] Add scope MCP server config and close method to loader Signed-off-by: Donnie Adams --- pkg/loader/loader.go | 1 + pkg/loader/openapi_test.go | 4 ++++ pkg/mcp/loader.go | 20 ++++++++++++++++++++ 3 files changed, 25 insertions(+) diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 2a6f2433..626cc87f 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -417,6 +417,7 @@ type Options struct { type MCPLoader interface { Load(ctx context.Context, tool types.Tool) ([]types.Tool, error) + Close() error } func complete(opts ...Options) (result Options) { diff --git a/pkg/loader/openapi_test.go b/pkg/loader/openapi_test.go index 26561538..594d8cf7 100644 --- a/pkg/loader/openapi_test.go +++ b/pkg/loader/openapi_test.go @@ -131,3 +131,7 @@ type fakeMCPLoader struct{} func (fakeMCPLoader) Load(context.Context, types.Tool) ([]types.Tool, error) { return nil, nil } + +func (fakeMCPLoader) Close() error { + return nil +} diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go index bf377bc5..7e44ec2a 100644 --- a/pkg/mcp/loader.go +++ b/pkg/mcp/loader.go @@ -3,6 +3,7 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "maps" "slices" @@ -49,6 +50,7 @@ type ServerConfig struct { URL string `json:"url"` BaseURL string `json:"baseURL,omitempty"` Headers []string `json:"headers"` + Scope string `json:"scope"` } func (s *ServerConfig) GetBaseURL() string { @@ -104,6 +106,24 @@ func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool, return nil, fmt.Errorf("no MCP server configuration found in tool instructions: %s", configData) } +func (l *Local) Close() error { + if l == nil { + return nil + } + + l.lock.Lock() + defer l.lock.Unlock() + + var errs []error + for id, session := range l.sessions { + if err := session.Client.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close MCP client %s: %w", id, err)) + } + } + + return errors.Join(errs...) +} + func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName string) ([]types.Tool, error) { tools, err := session.Client.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { From acdf274d2229e4890aab1be2d25f7ab0945024cd Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 2 May 2025 10:04:28 -0400 Subject: [PATCH 5/9] Add MCP loader and runner options to the SDK server Signed-off-by: Donnie Adams --- pkg/engine/engine.go | 7 +------ pkg/runner/runner.go | 12 ++++++++++++ pkg/sdkserver/routes.go | 16 +++++++++++++--- pkg/sdkserver/run.go | 6 +++++- pkg/sdkserver/server.go | 8 ++++++++ pkg/tests/runner2_test.go | 12 ++++++------ 6 files changed, 45 insertions(+), 16 deletions(-) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 39357e9a..c7867512 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -11,7 +11,6 @@ import ( "sync" "github.com/gptscript-ai/gptscript/pkg/counter" - "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) @@ -314,11 +313,7 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too } func (e *Engine) runMCPInvoke(ctx Context, tool types.Tool, input string) (*Return, error) { - runner := e.MCPRunner - if runner == nil { - runner = mcp.DefaultRunner - } - output, err := runner.Run(ctx.Ctx, e.Progress, tool, input) + output, err := e.MCPRunner.Run(ctx.Ctx, e.Progress, tool, input) if err != nil { return nil, fmt.Errorf("failed to run MCP invoke: %w", err) } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 6d4e7598..200c453b 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -14,6 +14,7 @@ import ( context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/types" "golang.org/x/exp/maps" ) @@ -37,6 +38,7 @@ type Options struct { CredentialOverrides []string `usage:"-"` Sequential bool `usage:"-"` Authorizer AuthorizerFunc `usage:"-"` + MCPRunner engine.MCPRunner `usage:"-"` } type RunOptions struct { @@ -69,6 +71,9 @@ func Complete(opts ...Options) (result Options) { if opt.CredentialOverrides != nil { result.CredentialOverrides = append(result.CredentialOverrides, opt.CredentialOverrides...) } + if opt.MCPRunner != nil { + result.MCPRunner = opt.MCPRunner + } } return } @@ -87,6 +92,9 @@ func complete(opts ...Options) Options { if result.Authorizer == nil { result.Authorizer = DefaultAuthorizer } + if result.MCPRunner == nil { + result.MCPRunner = mcp.DefaultRunner + } return result } @@ -99,6 +107,7 @@ type Runner struct { credOverrides []string credStore credentials.CredentialStore sequential bool + mcpRunner engine.MCPRunner } func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) { @@ -113,6 +122,7 @@ func New(client engine.Model, credStore credentials.CredentialStore, opts ...Opt credStore: credStore, sequential: opt.Sequential, auth: opt.Authorizer, + mcpRunner: opt.MCPRunner, } if opt.StartPort != 0 { @@ -326,6 +336,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en e := engine.Engine{ Model: r.c, + MCPRunner: r.mcpRunner, RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, @@ -524,6 +535,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s e := engine.Engine{ Model: r.c, + MCPRunner: r.mcpRunner, RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 1a4e28ea..52a06994 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -29,6 +29,7 @@ type server struct { datasetTool, workspaceTool string serverToolsEnv []string client *gptscript.GPTScript + mcpLoader loader.MCPLoader events *broadcaster.Broadcaster[event] runtimeManager engine.RuntimeManager @@ -283,11 +284,20 @@ func (s *server) load(w http.ResponseWriter, r *http.Request) { } if reqObject.Content != "" { - prg, err = loader.ProgramFromSource(ctx, reqObject.Content, reqObject.SubTool, loader.Options{Cache: s.client.Cache}) + prg, err = loader.ProgramFromSource(ctx, reqObject.Content, reqObject.SubTool, loader.Options{ + Cache: s.client.Cache, + MCPLoader: s.mcpLoader, + }) } else if reqObject.File != "" { - prg, err = loader.Program(ctx, reqObject.File, reqObject.SubTool, loader.Options{Cache: s.client.Cache}) + prg, err = loader.Program(ctx, reqObject.File, reqObject.SubTool, loader.Options{ + Cache: s.client.Cache, + MCPLoader: s.mcpLoader, + }) } else { - prg, err = loader.ProgramFromSource(ctx, reqObject.ToolDefs.String(), reqObject.SubTool, loader.Options{Cache: s.client.Cache}) + prg, err = loader.ProgramFromSource(ctx, reqObject.ToolDefs.String(), reqObject.SubTool, loader.Options{ + Cache: s.client.Cache, + MCPLoader: s.mcpLoader, + }) } if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) diff --git a/pkg/sdkserver/run.go b/pkg/sdkserver/run.go index fda4a215..a2c0d505 100644 --- a/pkg/sdkserver/run.go +++ b/pkg/sdkserver/run.go @@ -36,7 +36,11 @@ func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, lo if defaultModel == "" { defaultModel = s.gptscriptOpts.OpenAI.DefaultModel } - prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{Cache: g.Cache, DefaultModel: defaultModel}) + prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{ + Cache: g.Cache, + DefaultModel: defaultModel, + MCPLoader: s.mcpLoader, + }) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) return diff --git a/pkg/sdkserver/server.go b/pkg/sdkserver/server.go index f15cc68f..52e9ec1c 100644 --- a/pkg/sdkserver/server.go +++ b/pkg/sdkserver/server.go @@ -16,6 +16,8 @@ import ( "github.com/google/uuid" "github.com/gptscript-ai/broadcaster" "github.com/gptscript-ai/gptscript/pkg/gptscript" + "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" "github.com/gptscript-ai/gptscript/pkg/runner" @@ -26,6 +28,7 @@ import ( type Options struct { gptscript.Options + MCPLoader loader.MCPLoader ListenAddress string DatasetTool, WorkspaceTool string ServerToolsEnv []string @@ -114,6 +117,7 @@ func run(ctx context.Context, listener net.Listener, opts Options) error { serverToolsEnv: opts.ServerToolsEnv, client: g, + mcpLoader: opts.MCPLoader, events: events, runtimeManager: runtimes.Default(opts.Cache.CacheDir, opts.SystemToolsDir), waitingToConfirm: make(map[string]chan runner.AuthorizerResponse), @@ -168,6 +172,7 @@ func complete(opts ...Options) Options { result.WorkspaceTool = types.FirstSet(opt.WorkspaceTool, result.WorkspaceTool) result.Debug = types.FirstSet(opt.Debug, result.Debug) result.DisableServerErrorLogging = types.FirstSet(opt.DisableServerErrorLogging, result.DisableServerErrorLogging) + result.MCPLoader = types.FirstSet(opt.MCPLoader, result.MCPLoader) } if result.ListenAddress == "" { @@ -183,6 +188,9 @@ func complete(opts ...Options) Options { if len(result.ServerToolsEnv) == 0 { result.ServerToolsEnv = os.Environ() } + if result.MCPLoader == nil { + result.MCPLoader = mcp.DefaultLoader + } return result } diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go index ba704142..3ac518f5 100644 --- a/pkg/tests/runner2_test.go +++ b/pkg/tests/runner2_test.go @@ -308,7 +308,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke.append_insight e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", + "instructions": "#!sys.mcp.invoke.append_insight 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", "id": "inline:append_insight", "localTools": { "append_insight": "inline:append_insight", @@ -341,7 +341,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke.create_table e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", + "instructions": "#!sys.mcp.invoke.create_table 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", "id": "inline:create_table", "localTools": { "append_insight": "inline:append_insight", @@ -374,7 +374,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke.describe_table e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", + "instructions": "#!sys.mcp.invoke.describe_table 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", "id": "inline:describe_table", "localTools": { "append_insight": "inline:append_insight", @@ -398,7 +398,7 @@ name: mcp "arguments": { "type": "object" }, - "instructions": "#!sys.mcp.invoke.list_tables e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", + "instructions": "#!sys.mcp.invoke.list_tables 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", "id": "inline:list_tables", "localTools": { "append_insight": "inline:append_insight", @@ -500,7 +500,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke.read_query e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", + "instructions": "#!sys.mcp.invoke.read_query 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", "id": "inline:read_query", "localTools": { "append_insight": "inline:append_insight", @@ -533,7 +533,7 @@ name: mcp ], "type": "object" }, - "instructions": "#!sys.mcp.invoke.write_query e057d98f5d43e56fda04eb3e7ea6120c93b5bcaf832090fca76e8d744e2de494", + "instructions": "#!sys.mcp.invoke.write_query 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", "id": "inline:write_query", "localTools": { "append_insight": "inline:append_insight", From 1a17136dda533eb82f7abcede94dadcab3600025 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 2 May 2025 14:26:19 -0400 Subject: [PATCH 6/9] fix SSE client start Signed-off-by: Donnie Adams --- pkg/mcp/loader.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go index 7e44ec2a..21f5f7e2 100644 --- a/pkg/mcp/loader.go +++ b/pkg/mcp/loader.go @@ -12,6 +12,7 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/gptscript-ai/gptscript/pkg/hash" + "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" "github.com/mark3labs/mcp-go/client" @@ -21,6 +22,8 @@ import ( var ( DefaultLoader = &Local{} DefaultRunner = DefaultLoader + + logger = mvl.Package() ) type Local struct { @@ -116,6 +119,7 @@ func (l *Local) Close() error { var errs []error for id, session := range l.sessions { + logger.Infof("closing MCP session %s", id) if err := session.Client.Close(); err != nil { errs = append(errs, fmt.Errorf("failed to close MCP client %s: %w", id, err)) } @@ -229,7 +233,7 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, } var ( - c client.MCPClient + c *client.Client err error ) if server.Command != "" { @@ -248,10 +252,16 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, k, v, _ := strings.Cut(h, "=") headers[k] = v } + c, err = client.NewSSEMCPClient(url, client.WithHeaders(headers)) if err != nil { return nil, fmt.Errorf("failed to create MCP HTTP client: %w", err) } + + // We expect the client to outlive this one request. + if err = c.Start(context.Background()); err != nil { + return nil, fmt.Errorf("failed to start MCP client: %w", err) + } } var initRequest mcp.InitializeRequest From e687b3c44f01f14491167a49ca4c1d3d818948a6 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 2 May 2025 14:28:55 -0400 Subject: [PATCH 7/9] Skip TestMCPLoad on Windows Signed-off-by: Donnie Adams --- pkg/tests/runner2_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go index 3ac518f5..c531c661 100644 --- a/pkg/tests/runner2_test.go +++ b/pkg/tests/runner2_test.go @@ -3,6 +3,7 @@ package tests import ( "context" "encoding/json" + "runtime" "testing" "github.com/gptscript-ai/gptscript/pkg/loader" @@ -206,6 +207,10 @@ echo "${GPTSCRIPT_INPUT}" } func TestMCPLoad(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + r := tester.NewRunner(t) prg, err := loader.ProgramFromSource(context.Background(), ` name: mcp From 2591aff6ee6c86edd0c125c215a2008d9796656c Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 2 May 2025 14:36:32 -0400 Subject: [PATCH 8/9] chore: bump go-mcp Signed-off-by: Donnie Adams --- go.mod | 3 ++- go.sum | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 35c9689e..fc68968e 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/hexops/autogold/v2 v2.2.1 github.com/hexops/valast v1.4.4 github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 - github.com/mark3labs/mcp-go v0.21.1 + github.com/mark3labs/mcp-go v0.25.0 github.com/mholt/archives v0.1.0 github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 @@ -114,6 +114,7 @@ require ( github.com/skeema/knownhosts v1.2.2 // indirect github.com/sorairolake/lzip-go v0.3.5 // indirect github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect github.com/therootcompany/xz v1.0.1 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index 95e6b1a7..7ce2cd38 100644 --- a/go.sum +++ b/go.sum @@ -270,8 +270,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.21.1 h1:7Ek6KPIIbMhEYHRiRIg6K6UAgNZCJaHKQp926MNr6V0= -github.com/mark3labs/mcp-go v0.21.1/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= +github.com/mark3labs/mcp-go v0.25.0 h1:UUpcMT3L5hIhuDy7aifj4Bphw4Pfx1Rf8mzMXDe8RQw= +github.com/mark3labs/mcp-go v0.25.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -363,6 +363,8 @@ github.com/sorairolake/lzip-go v0.3.5 h1:ms5Xri9o1JBIWvOFAorYtUNik6HI3HgBTkISiqu github.com/sorairolake/lzip-go v0.3.5/go.mod h1:N0KYq5iWrMXI0ZEXKXaS9hCyOjZUQdBDEIbXfoUwbdk= github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e h1:H+jDTUeF+SVd4ApwnSFoew8ZwGNRfgb9EsZc7LcocAg= github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e/go.mod h1:VsUklG6OQo7Ctunu0gS3AtEOCEc2kMB6r5rKzxAes58= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= From b1c1bc3808b262be343f1e1582d273b64db8f547 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 2 May 2025 15:28:02 -0400 Subject: [PATCH 9/9] Correct context management Signed-off-by: Donnie Adams --- pkg/mcp/loader.go | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go index 21f5f7e2..0eb713e5 100644 --- a/pkg/mcp/loader.go +++ b/pkg/mcp/loader.go @@ -27,8 +27,10 @@ var ( ) type Local struct { - lock sync.Mutex - sessions map[string]*Session + lock sync.Mutex + sessions map[string]*Session + sessionCtx context.Context + cancel context.CancelFunc } type Session struct { @@ -97,7 +99,7 @@ func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool, } for server := range maps.Keys(servers.MCPServers) { - session, err := l.loadSession(ctx, servers.MCPServers[server]) + session, err := l.loadSession(servers.MCPServers[server]) if err != nil { return nil, fmt.Errorf("failed to load MCP session for server %s: %w", server, err) } @@ -117,6 +119,15 @@ func (l *Local) Close() error { l.lock.Lock() defer l.lock.Unlock() + if l.sessionCtx == nil { + return nil + } + + defer func() { + l.cancel() + l.sessionCtx = nil + }() + var errs []error for id, session := range l.sessions { logger.Infof("closing MCP session %s", id) @@ -222,10 +233,14 @@ func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName s return toolDefs, nil } -func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, error) { +func (l *Local) loadSession(server ServerConfig) (*Session, error) { id := hash.Digest(server) l.lock.Lock() existing, ok := l.sessions[id] + if l.sessionCtx == nil { + l.sessionCtx, l.cancel = context.WithCancel(context.Background()) + } + ctx := l.sessionCtx l.lock.Unlock() if ok { @@ -259,7 +274,7 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, } // We expect the client to outlive this one request. - if err = c.Start(context.Background()); err != nil { + if err = c.Start(ctx); err != nil { return nil, fmt.Errorf("failed to start MCP client: %w", err) } }