From 2cb06cee0f9e3b0945491680389a990ebe522e45 Mon Sep 17 00:00:00 2001 From: Alex Bucknall Date: Tue, 16 Sep 2025 13:31:11 +0100 Subject: [PATCH] feat: add session tracking to firmware and API tools --- blues-expert/lib/handlers.go | 9 + blues-expert/lib/session.go | 315 +++++++++++++++++++++++++++++++++++ blues-expert/main.go | 6 +- 3 files changed, 329 insertions(+), 1 deletion(-) create mode 100644 blues-expert/lib/session.go diff --git a/blues-expert/lib/handlers.go b/blues-expert/lib/handlers.go index 5bd5d2d..9a97854 100644 --- a/blues-expert/lib/handlers.go +++ b/blues-expert/lib/handlers.go @@ -41,6 +41,7 @@ var docs embed.FS // Firmware Tools func HandleFirmwareEntrypointTool(ctx context.Context, request *mcp.CallToolRequest, args FirmwareEntrypointArgs) (*mcp.CallToolResult, any, error) { + TrackSession(request, "firmware_entrypoint") if args.Sdk == "" { return &mcp.CallToolResult{ @@ -74,6 +75,8 @@ func HandleFirmwareEntrypointTool(ctx context.Context, request *mcp.CallToolRequ // Notecard API Tools func HandleAPIValidateTool(ctx context.Context, request *mcp.CallToolRequest, args RequestValidateArgs) (*mcp.CallToolResult, any, error) { + TrackSession(request, "api_validate") + var reqMap map[string]interface{} if err := json.Unmarshal([]byte(args.Request), &reqMap); err != nil { return &mcp.CallToolResult{ @@ -101,6 +104,8 @@ func HandleAPIValidateTool(ctx context.Context, request *mcp.CallToolRequest, ar } func HandleAPIDocsTool(ctx context.Context, request *mcp.CallToolRequest, args GetAPIsArgs) (*mcp.CallToolResult, any, error) { + TrackSession(request, "api_docs") + // Get API documentation apiCategory, err := GetNotecardAPIs(ctx, request, args.API) if err != nil { @@ -139,6 +144,8 @@ func HandleAPIDocsTool(ctx context.Context, request *mcp.CallToolRequest, args G // Blues Documentation Tools func HandleDocsSearchTool(ctx context.Context, request *mcp.CallToolRequest, args SearchArgs) (*mcp.CallToolResult, any, error) { + TrackSession(request, "docs_search") + // Call the search implementation from query.go result, err := SearchNotecardDocs(ctx, request, args.Query) if err != nil { @@ -154,6 +161,8 @@ func HandleDocsSearchTool(ctx context.Context, request *mcp.CallToolRequest, arg } func HandleDocsSearchExpertTool(ctx context.Context, request *mcp.CallToolRequest, args SearchExpertArgs) (*mcp.CallToolResult, any, error) { + TrackSession(request, "docs_search_expert") + // First, get the raw search results searchResult, err := SearchNotecardDocs(ctx, request, args.Query) if err != nil { diff --git a/blues-expert/lib/session.go b/blues-expert/lib/session.go new file mode 100644 index 0000000..2ca8ceb --- /dev/null +++ b/blues-expert/lib/session.go @@ -0,0 +1,315 @@ +package lib + +import ( + "encoding/json" + "log" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var globalSessionManager *SessionManager + +// RequestLog holds information about a specific request +type RequestLog struct { + Timestamp time.Time `json:"timestamp"` + ToolName string `json:"tool_name"` + Arguments interface{} `json:"arguments"` +} + +// SessionData holds session-specific data and state +type SessionData struct { + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + LastAccessed time.Time `json:"last_accessed"` + RequestCount int64 `json:"request_count"` + Metadata map[string]string `json:"metadata,omitempty"` + RequestLog []RequestLog `json:"request_log,omitempty"` +} + +// SessionManager manages client sessions for the MCP server +type SessionManager struct { + mu sync.RWMutex + sessions map[string]*SessionData +} + +// NewSessionManager creates a new session manager +func NewSessionManager() *SessionManager { + sm := &SessionManager{ + sessions: make(map[string]*SessionData), + } + + // Set global reference + globalSessionManager = sm + + // Start cleanup goroutine for expired sessions + go sm.cleanupExpiredSessions() + + return sm +} + +// GetSessionManager returns the global session manager +func GetSessionManager() *SessionManager { + return globalSessionManager +} + +// GetOrCreateSession retrieves an existing session or creates a new one +func (sm *SessionManager) GetOrCreateSession(sessionID string) *SessionData { + if sessionID == "" { + // Handle stateless sessions by returning a temporary session + return &SessionData{ + ID: "stateless", + CreatedAt: time.Now(), + LastAccessed: time.Now(), + RequestCount: 1, // This is the current request + Metadata: make(map[string]string), + RequestLog: make([]RequestLog, 0), + } + } + + sm.mu.Lock() + defer sm.mu.Unlock() + + session, exists := sm.sessions[sessionID] + if !exists { + session = &SessionData{ + ID: sessionID, + CreatedAt: time.Now(), + LastAccessed: time.Now(), + RequestCount: 0, + Metadata: make(map[string]string), + RequestLog: make([]RequestLog, 0), + } + sm.sessions[sessionID] = session + log.Printf("Session %s created", sessionID) + } else { + session.LastAccessed = time.Now() + } + + session.RequestCount++ + return session +} + +// GetSession retrieves an existing session +func (sm *SessionManager) GetSession(sessionID string) (*SessionData, bool) { + if sessionID == "" { + return nil, false + } + + sm.mu.RLock() + defer sm.mu.RUnlock() + + session, exists := sm.sessions[sessionID] + if exists { + // Update last accessed time (we need to lock for write) + sm.mu.RUnlock() + sm.mu.Lock() + session.LastAccessed = time.Now() + sm.mu.Unlock() + sm.mu.RLock() + } + return session, exists +} + +// RemoveSession removes a session from the manager +func (sm *SessionManager) RemoveSession(sessionID string) { + if sessionID == "" { + return + } + + sm.mu.Lock() + defer sm.mu.Unlock() + + if session, exists := sm.sessions[sessionID]; exists { + // Log session exit with summary statistics + log.Printf("Session %s exited after %d requests (duration: %v)", + sessionID, session.RequestCount, time.Since(session.CreatedAt).Truncate(time.Second)) + delete(sm.sessions, sessionID) + } +} + +// ListSessions returns all active sessions (for debugging/monitoring) +func (sm *SessionManager) ListSessions() map[string]*SessionData { + sm.mu.RLock() + defer sm.mu.RUnlock() + + // Return a copy to avoid race conditions + result := make(map[string]*SessionData, len(sm.sessions)) + for id, session := range sm.sessions { + // Create a copy of the session data + sessionCopy := *session + result[id] = &sessionCopy + } + return result +} + +// GetSessionCount returns the number of active sessions +func (sm *SessionManager) GetSessionCount() int { + sm.mu.RLock() + defer sm.mu.RUnlock() + return len(sm.sessions) +} + +// cleanupExpiredSessions periodically removes sessions that haven't been accessed recently +func (sm *SessionManager) cleanupExpiredSessions() { + ticker := time.NewTicker(10 * time.Minute) // Cleanup every 10 minutes + defer ticker.Stop() + + for range ticker.C { + sm.mu.Lock() + now := time.Now() + expiredSessions := make([]string, 0) + + // Find sessions that haven't been accessed in the last hour + for sessionID, session := range sm.sessions { + if now.Sub(session.LastAccessed) > time.Hour { + expiredSessions = append(expiredSessions, sessionID) + } + } + + // Remove expired sessions + for _, sessionID := range expiredSessions { + if session, exists := sm.sessions[sessionID]; exists { + log.Printf("Session %s expired after %d requests (duration: %v, idle: %v)", + sessionID, session.RequestCount, + time.Since(session.CreatedAt).Truncate(time.Second), + time.Since(session.LastAccessed).Truncate(time.Second)) + delete(sm.sessions, sessionID) + } + } + + sm.mu.Unlock() + + if len(expiredSessions) > 0 { + log.Printf("Cleaned up %d expired sessions", len(expiredSessions)) + } + } +} + +// GetSessionIDFromRequest extracts the session ID from an MCP request +func GetSessionIDFromRequest(request *mcp.CallToolRequest) string { + if request == nil || request.Session == nil { + return "" + } + return request.Session.ID() +} + +// LogSessionActivity logs session activity for monitoring +func LogSessionActivity(sessionID, toolName string, sessionData *SessionData) { + if sessionID == "" || sessionID == "stateless" { + log.Printf("Tool %s called (stateless session)", toolName) + } else { + log.Printf("Tool %s called by session %s (requests: %d)", + toolName, sessionID, sessionData.RequestCount) + } +} + +// LogSessionActivityWithArgs logs session activity including request arguments +func LogSessionActivityWithArgs(sessionID, toolName string, sessionData *SessionData, arguments interface{}) { + var argsStr string + if arguments != nil { + if argsBytes, err := json.Marshal(arguments); err == nil { + argsStr = string(argsBytes) + } else { + argsStr = "" + } + } else { + argsStr = "" + } + + if sessionID == "" || sessionID == "stateless" { + log.Printf("Tool %s called (stateless session) with args: %s", toolName, argsStr) + } else { + historyCount := len(sessionData.RequestLog) + totalRequests := sessionData.RequestCount + + // Show if we've truncated history + if totalRequests > int64(historyCount) && historyCount == 50 { + log.Printf("Tool %s called by session %s (total: %d requests, recent: %d stored) with args: %s", + toolName, sessionID, totalRequests, historyCount, argsStr) + } else { + log.Printf("Tool %s called by session %s (requests: %d) with args: %s", + toolName, sessionID, totalRequests, argsStr) + } + } +} + +// AddRequestToLog adds a request to the session's request log +func (sm *SessionManager) AddRequestToLog(sessionData *SessionData, toolName string, arguments interface{}) { + if sessionData.ID == "stateless" { + // Don't store logs for stateless sessions + return + } + + // Limit the number of logged requests per session (keep last 50) + const maxLogEntries = 50 + + requestLog := RequestLog{ + Timestamp: time.Now(), + ToolName: toolName, + Arguments: arguments, + } + + // We need to lock the session manager since we're modifying session data + sm.mu.Lock() + defer sm.mu.Unlock() + + // Find the session in our map (sessionData might be a copy) + if session, exists := sm.sessions[sessionData.ID]; exists { + session.RequestLog = append(session.RequestLog, requestLog) + + // Keep only the last maxLogEntries + if len(session.RequestLog) > maxLogEntries { + session.RequestLog = session.RequestLog[len(session.RequestLog)-maxLogEntries:] + } + } +} + +// TrackSession handles session tracking for a tool handler and returns the session data +func TrackSession(request *mcp.CallToolRequest, toolName string) *SessionData { + sessionID := GetSessionIDFromRequest(request) + sessionData := GetSessionManager().GetOrCreateSession(sessionID) + + // Capture the request arguments if available + var arguments interface{} + if request != nil && request.Params != nil { + arguments = request.Params.Arguments + GetSessionManager().AddRequestToLog(sessionData, toolName, arguments) + } + + // Log with arguments included + LogSessionActivityWithArgs(sessionID, toolName, sessionData, arguments) + + return sessionData +} + +// GetSessionRequestHistory returns the recent request history for a session +func (sm *SessionManager) GetSessionRequestHistory(sessionID string, limit int) []RequestLog { + if sessionID == "" || sessionID == "stateless" { + return []RequestLog{} + } + + sm.mu.RLock() + defer sm.mu.RUnlock() + + session, exists := sm.sessions[sessionID] + if !exists { + return []RequestLog{} + } + + // Return the last 'limit' entries + logLen := len(session.RequestLog) + if limit <= 0 || limit > logLen { + limit = logLen + } + + if limit == 0 { + return []RequestLog{} + } + + // Return a copy to avoid race conditions + result := make([]RequestLog, limit) + copy(result, session.RequestLog[logLen-limit:]) + return result +} diff --git a/blues-expert/main.go b/blues-expert/main.go index 8b648a4..f4c9c50 100644 --- a/blues-expert/main.go +++ b/blues-expert/main.go @@ -14,7 +14,8 @@ import ( ) var ( - envFilePath string + envFilePath string + sessionManager *lib.SessionManager ) func init() { @@ -33,6 +34,9 @@ func main() { } } + // Initialize session manager + sessionManager = lib.NewSessionManager() + // Create a new MCP server impl := &mcp.Implementation{Name: "Blues Expert MCP", Version: utils.Commit} opts := &mcp.ServerOptions{