From 0576cd36adc037ba42efac927083fadd380d293a Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Tue, 24 Jun 2025 12:13:18 -0300 Subject: [PATCH 01/15] add session with prompts --- server/session.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/server/session.go b/server/session.go index a79da22ca..d28c821f0 100644 --- a/server/session.go +++ b/server/session.go @@ -39,6 +39,14 @@ type SessionWithTools interface { SetSessionTools(tools map[string]ServerTool) } +type SessionWithPrompts interface { + ClientSession + // GetPrompts returns the prompts specific to this session, if any + GetPrompts() map[string]string + // SetPrompts sets prompts specific to this session + SetPrompts(prompts map[string]string) +} + // SessionWithClientInfo is an extension of ClientSession that can store client info type SessionWithClientInfo interface { ClientSession From 5831267a4afb4614576f81d371774dfab4082e23 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Tue, 24 Jun 2025 13:37:14 -0300 Subject: [PATCH 02/15] on session prompt list msg set the session prompts --- server/server.go | 28 ++++++++++++++++++++++++++++ server/session.go | 4 ++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/server/server.go b/server/server.go index 46e6d9c57..ca1ca60ce 100644 --- a/server/server.go +++ b/server/server.go @@ -832,6 +832,34 @@ func (s *MCPServer) handleListPrompts( } s.promptsMu.RUnlock() + // Check if there are session-specific prompts + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithPrompts, ok := session.(SessionWithPrompts); ok { + if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil { + // Override or add session-specific prompts + // We need to create a map first to merge the prompts properly + promptMap := make(map[string]mcp.Prompt) + + // Add global prompts first + for _, prompt := range prompts { + promptMap[prompt.Name] = prompt + } + + // Then override with session-specific tools + for name, serverPrompt := range sessionPrompts { + promptMap[name] = serverPrompt.Prompt + } + + // Convert back to slice + prompts = make([]mcp.Prompt, 0, len(promptMap)) + for _, prompt := range promptMap { + prompts = append(prompts, prompt) + } + } + } + } + // sort prompts by name sort.Slice(prompts, func(i, j int) bool { return prompts[i].Name < prompts[j].Name diff --git a/server/session.go b/server/session.go index d28c821f0..da267682b 100644 --- a/server/session.go +++ b/server/session.go @@ -42,9 +42,9 @@ type SessionWithTools interface { type SessionWithPrompts interface { ClientSession // GetPrompts returns the prompts specific to this session, if any - GetPrompts() map[string]string + GetSessionPrompts() map[string]ServerPrompt // SetPrompts sets prompts specific to this session - SetPrompts(prompts map[string]string) + SetSessionPrompts(prompts map[string]ServerPrompt) } // SessionWithClientInfo is an extension of ClientSession that can store client info From fa887f659bbe32d375ef2126e39cc41ab44e8ab6 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Thu, 26 Jun 2025 12:44:26 -0300 Subject: [PATCH 03/15] add new error message --- server/errors.go | 1 + server/session_test.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/server/errors.go b/server/errors.go index ecbe91e5f..5edf318e4 100644 --- a/server/errors.go +++ b/server/errors.go @@ -17,6 +17,7 @@ var ( ErrSessionExists = errors.New("session already exists") ErrSessionNotInitialized = errors.New("session not properly initialized") ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + ErrSessionDoesNotSupportPrompts = errors.New("session does not support per-session prompts") ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level") // Notification-related errors diff --git a/server/session_test.go b/server/session_test.go index 3067f4e9c..2c3de7060 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -137,7 +137,7 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement f.clientInfo.Store(clientInfo) } -// sessionTestClientWithTools implements the SessionWithLogging interface for testing +// sessionTestClientWithLogging implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string notificationChannel chan mcp.JSONRPCNotification From 9e18df398f806ee849516f424af95577bbcc7dfd Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Thu, 26 Jun 2025 12:45:04 -0300 Subject: [PATCH 04/15] AddSessionPrompts method on Session --- server/session.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/server/session.go b/server/session.go index da267682b..364475a3b 100644 --- a/server/session.go +++ b/server/session.go @@ -386,3 +386,56 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error return nil } + +func (s *MCPServer) AddSessionPrompts(sessionID string, prompts ...ServerPrompt) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithPrompts) + if !ok { + return ErrSessionDoesNotSupportPrompts + } + + s.implicitlyRegisterPromptCapabilities() + + // Get existing prompts (this should return a thread-safe copy) + sessionPrompts := session.GetSessionPrompts() + + // Create a new map to avoid concurrent modification issues + newSessionPrompts := make(map[string]ServerPrompt, len(sessionPrompts)+len(prompts)) + + // Copy existing prompts + for k, v := range sessionPrompts { + newSessionPrompts[k] = v + } + + // Add new prompts + for _, prompt := range prompts { + newSessionPrompts[prompt.Prompt.Name] = prompt + } + + // Set the prompts (this should be thread-safe) + session.SetSessionPrompts(newSessionPrompts) + + if session.Initialized() && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/prompts/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The prompts were successfully added, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/prompts/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding prompts: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} From da8be7d43e0fc64b6c3b211fbb5e7352bea151e5 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Thu, 26 Jun 2025 12:46:47 -0300 Subject: [PATCH 05/15] start with test, add sessionTestClientWithPrompts struct --- server/session_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/server/session_test.go b/server/session_test.go index 2c3de7060..04a4c2960 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -172,12 +172,58 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +// sessionTestClientWithPrompts implements the SessionWithPrompts interface for testing +type sessionTestClientWithPrompts struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool + sessionPrompts sync.Map +} + +func (f *sessionTestClientWithPrompts) SessionID() string { + return f.sessionID +} + +func (f *sessionTestClientWithPrompts) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *sessionTestClientWithPrompts) Initialize() { + f.initialized = true +} + +func (f *sessionTestClientWithPrompts) Initialized() bool { + return f.initialized +} + +func (f *sessionTestClientWithPrompts) GetSessionPrompts() map[string]ServerPrompt { + prompts := make(map[string]ServerPrompt) + f.sessionPrompts.Range(func(key, value any) bool { + if prompt, ok := value.(ServerPrompt); ok { + prompts[key.(string)] = prompt + } + return true + }) + return prompts +} + +func (f *sessionTestClientWithPrompts) SetSessionPrompts(prompts map[string]ServerPrompt) { + // Clear existing prompts + f.sessionPrompts.Clear() + + // Set new prompts + for name, prompt := range prompts { + f.sessionPrompts.Store(name, prompt) + } +} + // Verify that all implementations satisfy their respective interfaces var ( _ ClientSession = (*sessionTestClient)(nil) _ SessionWithTools = (*sessionTestClientWithTools)(nil) _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) _ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil) + _ SessionWithPrompts = (*sessionTestClientWithPrompts)(nil) ) func TestSessionWithTools_Integration(t *testing.T) { From 0fbd1d873de0b0764d7cb1a04f575f3271e5f832 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Thu, 26 Jun 2025 12:51:37 -0300 Subject: [PATCH 06/15] add first test AddSessionPrompts --- server/session_test.go | 75 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 14 deletions(-) diff --git a/server/session_test.go b/server/session_test.go index 04a4c2960..a0bf8696a 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -177,7 +177,8 @@ type sessionTestClientWithPrompts struct { sessionID string notificationChannel chan mcp.JSONRPCNotification initialized bool - sessionPrompts sync.Map + sessionPrompts map[string]ServerPrompt + mu sync.RWMutex // Mutex to protect concurrent access to sessionPrompts } func (f *sessionTestClientWithPrompts) SessionID() string { @@ -197,24 +198,35 @@ func (f *sessionTestClientWithPrompts) Initialized() bool { } func (f *sessionTestClientWithPrompts) GetSessionPrompts() map[string]ServerPrompt { - prompts := make(map[string]ServerPrompt) - f.sessionPrompts.Range(func(key, value any) bool { - if prompt, ok := value.(ServerPrompt); ok { - prompts[key.(string)] = prompt - } - return true - }) - return prompts + f.mu.RLock() + defer f.mu.RUnlock() + + // Return a copy of the map to prevent concurrent modification + if f.sessionPrompts == nil { + return nil + } + + promptsCopy := make(map[string]ServerPrompt, len(f.sessionPrompts)) + for k, v := range f.sessionPrompts { + promptsCopy[k] = v + } + return promptsCopy } func (f *sessionTestClientWithPrompts) SetSessionPrompts(prompts map[string]ServerPrompt) { - // Clear existing prompts - f.sessionPrompts.Clear() + f.mu.Lock() + defer f.mu.Unlock() + + if prompts == nil { + f.sessionPrompts = nil + return + } - // Set new prompts - for name, prompt := range prompts { - f.sessionPrompts.Store(name, prompt) + promptsCopy := make(map[string]ServerPrompt, len(prompts)) + for k, v := range prompts { + promptsCopy[k] = v } + f.sessionPrompts = promptsCopy } // Verify that all implementations satisfy their respective interfaces @@ -381,6 +393,41 @@ func TestMCPServer_AddSessionTools(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "session-tool") } +func TestMCPServer_AddSessionPrompts(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithPrompts{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific prompts + err = server.AddSessionPrompts(session.SessionID(), + ServerPrompt{Prompt: mcp.NewPrompt("session-prompt")}, + ) + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/prompts/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify prompt was added to session + assert.Len(t, session.GetSessionPrompts(), 1) + assert.Contains(t, session.GetSessionPrompts(), "session-prompt") +} + func TestMCPServer_AddSessionTool(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) ctx := context.Background() From 5f64c16b564040a8b77789276206612a62025213 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Thu, 26 Jun 2025 12:54:50 -0300 Subject: [PATCH 07/15] add new test AddSessionPrompt --- server/session.go | 4 ++++ server/session_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/server/session.go b/server/session.go index 364475a3b..9c58c5873 100644 --- a/server/session.go +++ b/server/session.go @@ -439,3 +439,7 @@ func (s *MCPServer) AddSessionPrompts(sessionID string, prompts ...ServerPrompt) return nil } + +func (s *MCPServer) AddSessionPrompt(sessionID string, prompt mcp.Prompt, handler PromptHandlerFunc) error { + return s.AddSessionPrompts(sessionID, ServerPrompt{Prompt: prompt, Handler: handler}) +} diff --git a/server/session_test.go b/server/session_test.go index a0bf8696a..f77ed1110 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -467,6 +467,50 @@ func TestMCPServer_AddSessionTool(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "session-tool-helper") } +func TestMCPServer_AddSessionPrompt(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithPrompts{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific tool using the new helper method + err = server.AddSessionPrompt( + session.SessionID(), + mcp.NewPrompt("session-prompt-helper"), + func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return mcp.NewGetPromptResult("helper result", []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Text: "helper result"}, + }, + }), nil + }, + ) + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/prompts/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify tool was added to session + assert.Len(t, session.GetSessionPrompts(), 1) + assert.Contains(t, session.GetSessionPrompts(), "session-prompt-helper") +} + func TestMCPServer_AddSessionToolsUninitialized(t *testing.T) { // This test verifies that adding tools to an uninitialized session works correctly. // From 9a1f38243869397d0e96f96ac00dee090e980c83 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Thu, 26 Jun 2025 13:06:46 -0300 Subject: [PATCH 08/15] add test GetSessionPrompt --- server/server.go | 25 +++++++++++++-- server/session_test.go | 72 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/server/server.go b/server/server.go index ca1ca60ce..3a78b84b9 100644 --- a/server/server.go +++ b/server/server.go @@ -891,9 +891,28 @@ func (s *MCPServer) handleGetPrompt( id any, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, *requestError) { - s.promptsMu.RLock() - handler, ok := s.promptHandlers[request.Params.Name] - s.promptsMu.RUnlock() + // First check session-specific prompts + var handler PromptHandlerFunc + var ok bool + + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithPrompts, typeAssertOk := session.(SessionWithPrompts); typeAssertOk { + if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil { + if serverPrompt, sessionOk := sessionPrompts[request.Params.Name]; sessionOk { + handler = serverPrompt.Handler + ok = true + } + } + } + } + + // If not found in session prompts, check global prompts + if !ok { + s.promptsMu.RLock() + handler, ok = s.promptHandlers[request.Params.Name] + s.promptsMu.RUnlock() + } if !ok { return nil, &requestError{ diff --git a/server/session_test.go b/server/session_test.go index f77ed1110..a245de056 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -751,6 +751,78 @@ func TestMCPServer_CallSessionTool(t *testing.T) { } } +func TestMCPServer_GetSessionPrompt(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) + + // Add global prompt + server.AddPrompt(mcp.NewPrompt("test_prompt"), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return mcp.NewGetPromptResult("global result", []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Text: "global result"}, + }, + }), nil + }) + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithPrompts{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Add session-specific prompt with the same name to override the global prompt + err = server.AddSessionPrompt( + session.SessionID(), + mcp.NewPrompt("test_prompt"), + func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return mcp.NewGetPromptResult("session result", []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Text: "session result"}, + }, + }), nil + }, + ) + require.NoError(t, err) + + // Get the prompt using session context + sessionCtx := server.WithContext(context.Background(), session) + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/get", + "params": map[string]any{ + "name": "test_prompt", + }, + } + requestBytes, err := json.Marshal(toolRequest) + if err != nil { + t.Fatalf("Failed to marshal prompt request: %v", err) + } + + response := server.HandleMessage(sessionCtx, requestBytes) + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + getPromptResult, ok := resp.Result.(mcp.GetPromptResult) + assert.True(t, ok) + + // Since we specify a prompt with the same name for current session, the expected text should be "session result" + if textContent, ok := getPromptResult.Messages[0].Content.(mcp.TextContent); ok { + if textContent.Text != "session result" { + t.Errorf("Expected result 'session result', got %q", textContent.Text) + } + } else { + t.Error("Expected TextContent") + } +} + func TestMCPServer_DeleteSessionTools(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) ctx := context.Background() From a988d3d89a37d562f98d18b405a44021961c936f Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Thu, 26 Jun 2025 20:58:25 -0300 Subject: [PATCH 09/15] add test AddSessionPromptsUninitialized --- server/session_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/server/session_test.go b/server/session_test.go index a245de056..330b19f7a 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -602,6 +602,97 @@ func TestMCPServer_AddSessionToolsUninitialized(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "initialized-tool") } +func TestMCPServer_AddSessionPromptsUninitialized(t *testing.T) { + // This test verifies that adding prompts to an uninitialized session works correctly. + // + // This scenario can occur when prompts are added during the session registration hook, + // before the session is fully initialized. In this case, we should: + // 1. Successfully add the prompts to the session + // 2. Not attempt to send a notification (since the session isn't ready) + // 3. Have the prompts available once the session is initialized + // 4. Not trigger any error hooks when adding prompts to uninitialized sessions + + // Set up error hook to track if it's called + errorChan := make(chan error) + hooks := &Hooks{} + hooks.AddOnError( + func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + errorChan <- err + }, + ) + + server := NewMCPServer("test-server", "1.0.0", + WithPromptCapabilities(true), + WithHooks(hooks), + ) + ctx := context.Background() + + // Create an uninitialized session + sessionChan := make(chan mcp.JSONRPCNotification, 1) + session := &sessionTestClientWithPrompts{ + sessionID: "uninitialized-session", + notificationChannel: sessionChan, + initialized: false, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific tools to the uninitialized session + err = server.AddSessionPrompts(session.SessionID(), + ServerPrompt{Prompt: mcp.NewPrompt("uninitialized-prompt")}, + ) + require.NoError(t, err) + + // Verify no errors + select { + case err := <-errorChan: + t.Error("Expected no errors, but OnError called with: ", err) + case <-time.After(25 * time.Millisecond): // no errors + } + + // Verify no notification was sent (channel should be empty) + select { + case <-sessionChan: + t.Error("Expected no notification to be sent for uninitialized session") + default: // no notifications + } + + // Verify prompt was added to session + assert.Len(t, session.GetSessionPrompts(), 1) + assert.Contains(t, session.GetSessionPrompts(), "uninitialized-prompt") + + // Initialize the session + session.Initialize() + + // Now verify that subsequent tool additions will send notifications + err = server.AddSessionPrompts(session.SessionID(), + ServerPrompt{Prompt: mcp.NewPrompt("initialized-prompt")}, + ) + require.NoError(t, err) + + // Verify no errors + select { + case err := <-errorChan: + t.Error("Expected no errors, but OnError called with:", err) + case <-time.After(200 * time.Millisecond): // No errors + } + + // Verify notification was sent for the initialized session + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/prompts/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Timeout waiting for expected notifications/prompts/list_changed notification") + } + + // Verify both tools are available + assert.Len(t, session.GetSessionPrompts(), 2) + assert.Contains(t, session.GetSessionPrompts(), "uninitialized-prompt") + assert.Contains(t, session.GetSessionPrompts(), "initialized-prompt") +} + func TestMCPServer_DeleteSessionToolsUninitialized(t *testing.T) { // This test verifies that deleting tools from an uninitialized session works correctly. // From 3302f374c3c58d7bf8ca6c426e30049d6d805635 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Thu, 26 Jun 2025 21:00:04 -0300 Subject: [PATCH 10/15] add log --- go.mod | 3 +++ go.sum | 8 ++++++++ server/server.go | 3 +++ 3 files changed, 14 insertions(+) diff --git a/go.mod b/go.mod index 9b9fe2d48..016dfed83 100644 --- a/go.mod +++ b/go.mod @@ -9,8 +9,11 @@ require ( github.com/yosida95/uritemplate/v3 v3.0.2 ) +require golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/sirupsen/logrus v1.9.3 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 31ed86d18..21d9fd363 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -14,13 +15,20 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/server/server.go b/server/server.go index 3a78b84b9..796aa48bc 100644 --- a/server/server.go +++ b/server/server.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/mark3labs/mcp-go/mcp" + "github.com/sirupsen/logrus" ) // resourceEntry holds both a resource and its handler @@ -895,6 +896,8 @@ func (s *MCPServer) handleGetPrompt( var handler PromptHandlerFunc var ok bool + logrus.Infof("handleGetPrompt: %s", request.Params.Name) + session := ClientSessionFromContext(ctx) if session != nil { if sessionWithPrompts, typeAssertOk := session.(SessionWithPrompts); typeAssertOk { From 0e24d31a8c9aa270ba9f8e9e2fd4876110321d9d Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Fri, 27 Jun 2025 10:12:54 -0300 Subject: [PATCH 11/15] add log --- server/server.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index 796aa48bc..21148e5c1 100644 --- a/server/server.go +++ b/server/server.go @@ -896,13 +896,17 @@ func (s *MCPServer) handleGetPrompt( var handler PromptHandlerFunc var ok bool - logrus.Infof("handleGetPrompt: %s", request.Params.Name) + logrus.Infof("[handleGetPrompt]: %s", request.Params.Name) session := ClientSessionFromContext(ctx) if session != nil { + logrus.Infof("[handleGetPrompt] session: %s", session.SessionID()) if sessionWithPrompts, typeAssertOk := session.(SessionWithPrompts); typeAssertOk { + logrus.Info("[handleGetPrompt] SessionWithPrompts: ok") if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil { + logrus.Info("[handleGetPrompt] GetSessionPrompts: ok") if serverPrompt, sessionOk := sessionPrompts[request.Params.Name]; sessionOk { + logrus.Info("[handleGetPrompt] handler: ok") handler = serverPrompt.Handler ok = true } From 178bd07211819dbc913900754f006842f5766f6f Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Fri, 27 Jun 2025 10:33:25 -0300 Subject: [PATCH 12/15] add SessionWithPromps on sse server --- server/sse.go | 22 ++++++++++ server/sse_test.go | 106 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/server/sse.go b/server/sse.go index 416995730..aacf0cb5a 100644 --- a/server/sse.go +++ b/server/sse.go @@ -29,6 +29,7 @@ type sseSession struct { initialized atomic.Bool loggingLevel atomic.Value tools sync.Map // stores session-specific tools + prompts sync.Map // stores session-specific prompts clientInfo atomic.Value // stores session-specific client info } @@ -74,6 +75,17 @@ func (s *sseSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (s *sseSession) GetSessionPrompts() map[string]ServerPrompt { + prompts := make(map[string]ServerPrompt) + s.prompts.Range(func(key, value any) bool { + if prompt, ok := value.(ServerPrompt); ok { + prompts[key.(string)] = prompt + } + return true + }) + return prompts +} + func (s *sseSession) GetSessionTools() map[string]ServerTool { tools := make(map[string]ServerTool) s.tools.Range(func(key, value any) bool { @@ -85,6 +97,16 @@ func (s *sseSession) GetSessionTools() map[string]ServerTool { return tools } +func (s *sseSession) SetSessionPrompts(prompts map[string]ServerPrompt) { + // Clear existing prompts + s.prompts.Clear() + + // Set new prompts + for name, prompt := range prompts { + s.prompts.Store(name, prompt) + } +} + func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { // Clear existing tools s.tools.Clear() diff --git a/server/sse_test.go b/server/sse_test.go index 96912be49..5370e141d 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1141,6 +1141,112 @@ func TestSSEServer(t *testing.T) { } }) + t.Run("TestSessionWithPrompts", func(t *testing.T) { + // Create hooks to track sessions + hooks := &Hooks{} + var registeredSession *sseSession + hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) { + if s, ok := session.(*sseSession); ok { + registeredSession = s + } + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // Connect to SSE endpoint + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event to ensure session is established + _, err = readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + // Verify we got a session + if registeredSession == nil { + t.Fatal("Session was not registered via hook") + } + + // Test setting and getting prompts + prompts := map[string]ServerPrompt{ + "test_prompt": { + Prompt: mcp.Prompt{ + Name: "test_prompt", + Description: "A test prompt", + }, + Handler: func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return mcp.NewGetPromptResult("test", []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Text: "test"}, + }, + }), nil + }, + }, + } + + // Test SetSessionPrompts + registeredSession.SetSessionPrompts(prompts) + + // Test GetSessionPrompts + retrievedPrompts := registeredSession.GetSessionPrompts() + if len(retrievedPrompts) != 1 { + t.Errorf("Expected 1 prompt, got %d", len(retrievedPrompts)) + } + if prompt, exists := retrievedPrompts["test_prompt"]; !exists { + t.Error("Expected test_prompt to exist") + } else if prompt.Prompt.Name != "test_prompt" { + t.Errorf("Expected prompt name test_prompt, got %s", prompt.Prompt.Name) + } + + // Test concurrent access + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(2) + go func(i int) { + defer wg.Done() + prompts := map[string]ServerPrompt{ + fmt.Sprintf("prompt_%d", i): { + Prompt: mcp.Prompt{ + Name: fmt.Sprintf("prompt_%d", i), + Description: fmt.Sprintf("Prompt %d", i), + }, + }, + } + registeredSession.SetSessionPrompts(prompts) + }(i) + go func() { + defer wg.Done() + _ = registeredSession.GetSessionTools() + }() + } + wg.Wait() + + // Verify we can still get and set tools after concurrent access + finalPrompts := map[string]ServerPrompt{ + "final_prompt": { + Prompt: mcp.Prompt{ + Name: "final_prompt", + Description: "Final Prompt", + }, + }, + } + registeredSession.SetSessionPrompts(finalPrompts) + retrievedPrompts = registeredSession.GetSessionPrompts() + if len(retrievedPrompts) != 1 { + t.Errorf("Expected 1 prompt, got %d", len(retrievedPrompts)) + } + if _, exists := retrievedPrompts["final_prompt"]; !exists { + t.Error("Expected final_prompt to exist") + } + }) + t.Run("SessionWithTools implementation", func(t *testing.T) { // Create hooks to track sessions hooks := &Hooks{} From 50f4681a66319ad2c4de98100200d37e4efac0e3 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Sat, 28 Jun 2025 10:27:40 -0300 Subject: [PATCH 13/15] remove logrus and edit readme --- README.md | 1 + go.mod | 1 - go.sum | 2 -- server/server.go | 7 ------- 4 files changed, 1 insertion(+), 10 deletions(-) diff --git a/README.md b/README.md index a35a3ebe0..0c056883d 100644 --- a/README.md +++ b/README.md @@ -547,6 +547,7 @@ MCP-Go provides a robust session management system that allows you to: - Register and track client sessions - Send notifications to specific clients - Provide per-session tool customization +- Provide per-session prompt customization
Show Session Management Examples diff --git a/go.mod b/go.mod index 016dfed83..ea8668167 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,5 @@ require golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/sirupsen/logrus v1.9.3 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 21d9fd363..f3fb05fdf 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/server/server.go b/server/server.go index 21148e5c1..3a78b84b9 100644 --- a/server/server.go +++ b/server/server.go @@ -11,7 +11,6 @@ import ( "sync" "github.com/mark3labs/mcp-go/mcp" - "github.com/sirupsen/logrus" ) // resourceEntry holds both a resource and its handler @@ -896,17 +895,11 @@ func (s *MCPServer) handleGetPrompt( var handler PromptHandlerFunc var ok bool - logrus.Infof("[handleGetPrompt]: %s", request.Params.Name) - session := ClientSessionFromContext(ctx) if session != nil { - logrus.Infof("[handleGetPrompt] session: %s", session.SessionID()) if sessionWithPrompts, typeAssertOk := session.(SessionWithPrompts); typeAssertOk { - logrus.Info("[handleGetPrompt] SessionWithPrompts: ok") if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil { - logrus.Info("[handleGetPrompt] GetSessionPrompts: ok") if serverPrompt, sessionOk := sessionPrompts[request.Params.Name]; sessionOk { - logrus.Info("[handleGetPrompt] handler: ok") handler = serverPrompt.Handler ok = true } From 5349f38bb4bb9f8e56c965e5080e7d20cca372c9 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Sat, 28 Jun 2025 10:36:08 -0300 Subject: [PATCH 14/15] remove sys --- go.mod | 2 -- go.sum | 6 ------ 2 files changed, 8 deletions(-) diff --git a/go.mod b/go.mod index ea8668167..9b9fe2d48 100644 --- a/go.mod +++ b/go.mod @@ -9,8 +9,6 @@ require ( github.com/yosida95/uritemplate/v3 v3.0.2 ) -require golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect - require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index f3fb05fdf..31ed86d18 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,3 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -17,16 +16,11 @@ github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZV github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 9339179c91d23014fbcc8646ae02ca4ba42f5cd1 Mon Sep 17 00:00:00 2001 From: Sebastian Ripari Date: Sat, 28 Jun 2025 10:41:19 -0300 Subject: [PATCH 15/15] fix test --- server/sse_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/sse_test.go b/server/sse_test.go index 5370e141d..7ccef4164 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1223,7 +1223,7 @@ func TestSSEServer(t *testing.T) { }(i) go func() { defer wg.Done() - _ = registeredSession.GetSessionTools() + _ = registeredSession.GetSessionPrompts() }() } wg.Wait()