diff --git a/README.md b/README.md index d243594..401bf87 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,8 @@ a365 auth logout | `A365_CLIENT_ID` | `--client-id` | Entra app client ID (default: `aebc6443-996d-45c2-90f0-388ff96faa56`) | | `A365_TENANT_ID` | `--tenant-id` | Entra tenant ID (optional, defaults to `organizations`) | | `A365_ENDPOINT` | — | Override the agent365 base URL | +| `A365_MCP_RESPONSE_HEADER_TIMEOUT` | — | Override the MCP HTTP response-header timeout for every service, for example `180s` or `5m` (default: `60s`; Copilot: `5m`) | +| `A365_COPILOT_RESPONSE_HEADER_TIMEOUT` | — | Override the Copilot MCP response-header timeout (default: `5m`); takes precedence over `A365_MCP_RESPONSE_HEADER_TIMEOUT` for Copilot | ## Configuration diff --git a/docs/copilot.md b/docs/copilot.md index fc3df9b..2a314d2 100644 --- a/docs/copilot.md +++ b/docs/copilot.md @@ -35,3 +35,12 @@ a365 copilot chat "Can you give more detail on the second point?" \ # Output as JSON a365 copilot chat "Who shared files with me this week?" --output json ``` + + +## Timeout tuning + +Copilot requests can take longer than typical MCP tool calls. By default, `a365` waits up to `5m` for Copilot response headers before failing. Override this with `A365_COPILOT_RESPONSE_HEADER_TIMEOUT`, or use `A365_MCP_RESPONSE_HEADER_TIMEOUT` to change the timeout for every MCP service, including Copilot. When both are set, the Copilot-specific value wins for Copilot requests. Values use Go duration syntax (for example `90s` or `10m`); invalid or negative values are ignored and the default applies. + +```sh +A365_COPILOT_RESPONSE_HEADER_TIMEOUT=10m a365 copilot chat "Summarize recent project updates" +``` diff --git a/internal/commands/copilot/copilot.go b/internal/commands/copilot/copilot.go index 39dd121..c6772b0 100644 --- a/internal/commands/copilot/copilot.go +++ b/internal/commands/copilot/copilot.go @@ -15,7 +15,11 @@ import ( "github.com/sozercan/a365cli/internal/output" ) -const copilotChatTool = "copilot_chat" +const ( + copilotChatTool = "copilot_chat" + copilotServiceErrorMaxRetries = 1 + copilotServiceErrorRetryDelay = time.Second +) // CopilotCmd groups all Copilot subcommands. 
type CopilotCmd struct { @@ -26,6 +30,15 @@ func copilotEndpoint() string { return config.Endpoint("copilot") } +type copilotServiceError struct { + message string + retryable bool +} + +func (e *copilotServiceError) Error() string { + return e.message +} + // CopilotChatCmd searches internal M365 content using natural language. type CopilotChatCmd struct { Message string `arg:"" help:"Natural language question about your M365 content" optional:""` @@ -119,22 +132,42 @@ func callCopilot(ctx *commands.Context, message, conversationID string) (map[str args["conversationId"] = conversationID } - resp, err := client.CallTool(ctx.Ctx, copilotChatTool, args) - if err != nil { - return nil, "", fmt.Errorf("copilot chat: %w", err) - } + for attempt := 0; ; attempt++ { + resp, err := client.CallTool(ctx.Ctx, copilotChatTool, args) + if err != nil { + return nil, "", fmt.Errorf("copilot chat: %w", err) + } - data, err := output.ExtractContent(resp) - if err != nil { - return nil, "", err - } + data, err := output.ExtractContent(resp) + if err != nil { + return nil, "", err + } - nextConversationID := findConversationID(data) - if ctx.Output.Format != output.FormatJSON { - data = normalizeCopilotResponse(data, nextConversationID) - } + if svcErr := copilotServiceErrorFromData(data); svcErr != nil { + if svcErr.retryable && attempt < copilotServiceErrorMaxRetries { + if ctx.Verbose { + fmt.Fprintf(os.Stderr, "--- Copilot returned a retryable service error; retrying (attempt %d/%d) after %v\n%s\n", attempt+1, copilotServiceErrorMaxRetries, copilotServiceErrorRetryDelay, svcErr.Error()) + } + + select { + case <-ctx.Ctx.Done(): + return nil, "", ctx.Ctx.Err() + case <-time.After(copilotServiceErrorRetryDelay): + } + + continue + } + + return nil, "", fmt.Errorf("copilot chat: %w", svcErr) + } - return data, nextConversationID, nil + nextConversationID := findConversationID(data) + if ctx.Output.Format != output.FormatJSON { + data = normalizeCopilotResponse(data, 
nextConversationID) + } + + return data, nextConversationID, nil + } } func printCopilotResponse(ctx *commands.Context, data map[string]any) error { @@ -232,6 +265,55 @@ func cloneMap(data map[string]any) map[string]any { return cloned } +func copilotServiceErrorFromData(data map[string]any) *copilotServiceError { + _, message := extractPrimaryText(data) + message = sanitizeCopilotServiceMessage(message) + if message == "" { + return nil + } + + lower := strings.ToLower(message) + if !strings.HasPrefix(lower, "error executing tool:") { + return nil + } + + return &copilotServiceError{ + message: message, + retryable: isRetryableCopilotServiceError(message), + } +} + +func isRetryableCopilotServiceError(message string) bool { + lower := strings.ToLower(message) + return strings.Contains(lower, "timed out") || strings.Contains(lower, "timed-out") || strings.Contains(lower, "timeout") +} + +func sanitizeCopilotServiceMessage(message string) string { + message = strings.ReplaceAll(message, "\r\n", "\n") + message = strings.ReplaceAll(message, "\r", "\n") + + seen := map[string]struct{}{} + lines := strings.Split(message, "\n") + cleaned := make([]string, 0, len(lines)) + + for i, line := range lines { + line = strings.TrimSpace(line) + if i == 0 { + line = strings.TrimSpace(strings.TrimPrefix(line, "Error:")) + } + if line == "" { + continue + } + if _, ok := seen[line]; ok { + continue + } + seen[line] = struct{}{} + cleaned = append(cleaned, line) + } + + return strings.TrimSpace(strings.Join(cleaned, "\n")) +} + func normalizeCopilotResponse(data map[string]any, conversationID string) map[string]any { message := extractConversationMessage(data) if message == "" { diff --git a/internal/commands/copilot/copilot_test.go b/internal/commands/copilot/copilot_test.go index d5d827f..ba4a036 100644 --- a/internal/commands/copilot/copilot_test.go +++ b/internal/commands/copilot/copilot_test.go @@ -37,6 +37,118 @@ func TestCopilotChatCmd_Run(t *testing.T) { } } +func 
TestCallCopilot_RetriesRetryableServiceError(t *testing.T) { + var toolCalls int + server := newCopilotToolServer(t, [][]map[string]any{ + { + {"type": "text", "text": "Error: Error executing tool: Outgoing HTTP request timed-out.\r\nCorrelationId: retry-1, TimeStamp: 2025-12-17T17:58:01Z"}, + {"type": "text", "text": "CorrelationId: retry-1, TimeStamp: 2025-12-17T17:58:01Z"}, + }, + { + {"type": "text", "text": `{"message":"Recovered answer","conversationId":"conv-123"}`}, + }, + }, &toolCalls) + t.Cleanup(func() { server.Close() }) + t.Setenv("A365_ENDPOINT", server.URL+"/") + + ctx := &commands.Context{ + Ctx: context.Background(), + TokenProvider: func(context.Context) (string, error) { + return "test-token", nil + }, + Output: &output.Formatter{Format: output.FormatHuman, Writer: io.Discard}, + } + + data, conversationID, err := callCopilot(ctx, "Summarize my week", "") + if err != nil { + t.Fatalf("callCopilot() error: %v", err) + } + if toolCalls != 2 { + t.Fatalf("expected 2 Copilot tool calls after retry, got %d", toolCalls) + } + if data["message"] != "Recovered answer" { + t.Fatalf("expected recovered message, got %v", data["message"]) + } + if conversationID != "conv-123" { + t.Fatalf("expected conversation ID to round-trip, got %q", conversationID) + } +} + +func TestCallCopilot_ReturnsRetryableServiceErrorAfterExhaustion(t *testing.T) { + var toolCalls int + server := newCopilotToolServer(t, [][]map[string]any{ + { + {"type": "text", "text": "Error: Error executing tool: Outgoing HTTP request timed-out.\r\nCorrelationId: retry-1, TimeStamp: 2025-12-17T17:58:01Z"}, + {"type": "text", "text": "CorrelationId: retry-1, TimeStamp: 2025-12-17T17:58:01Z"}, + }, + { + {"type": "text", "text": "Error: Error executing tool: Outgoing HTTP request timed-out.\r\nCorrelationId: retry-2, TimeStamp: 2025-12-17T17:58:02Z"}, + {"type": "text", "text": "CorrelationId: retry-2, TimeStamp: 2025-12-17T17:58:02Z"}, + }, + }, &toolCalls) + t.Cleanup(func() { server.Close() 
}) + t.Setenv("A365_ENDPOINT", server.URL+"/") + + ctx := &commands.Context{ + Ctx: context.Background(), + TokenProvider: func(context.Context) (string, error) { + return "test-token", nil + }, + Output: &output.Formatter{Format: output.FormatHuman, Writer: io.Discard}, + } + + _, _, err := callCopilot(ctx, "Summarize my week", "") + if err == nil { + t.Fatal("expected retried timeout payload to surface as an error") + } + if toolCalls != 2 { + t.Fatalf("expected 2 Copilot tool calls before failing, got %d", toolCalls) + } + if !strings.Contains(err.Error(), "copilot chat: Error executing tool: Outgoing HTTP request timed-out.") { + t.Fatalf("expected timeout error message, got %v", err) + } + if strings.Count(err.Error(), "CorrelationId:") != 1 { + t.Fatalf("expected correlation metadata to be deduplicated, got %q", err.Error()) + } + if !strings.Contains(err.Error(), "retry-2") { + t.Fatalf("expected final retry correlation ID, got %v", err) + } +} + +func TestCallCopilot_ReturnsNonRetryableServiceErrorWithoutRetry(t *testing.T) { + var toolCalls int + server := newCopilotToolServer(t, [][]map[string]any{ + { + {"type": "text", "text": "Error: Error executing tool: upstream unavailable.\r\nCorrelationId: fail-1, TimeStamp: 2025-12-17T17:58:01Z"}, + {"type": "text", "text": "CorrelationId: fail-1, TimeStamp: 2025-12-17T17:58:01Z"}, + }, + { + {"type": "text", "text": `{"message":"unexpected retry"}`}, + }, + }, &toolCalls) + t.Cleanup(func() { server.Close() }) + t.Setenv("A365_ENDPOINT", server.URL+"/") + + ctx := &commands.Context{ + Ctx: context.Background(), + TokenProvider: func(context.Context) (string, error) { + return "test-token", nil + }, + Output: &output.Formatter{Format: output.FormatHuman, Writer: io.Discard}, + } + + _, _, err := callCopilot(ctx, "Summarize my week", "") + if err == nil { + t.Fatal("expected non-timeout tool failure to surface as an error") + } + if toolCalls != 1 { + t.Fatalf("expected non-retryable service error to avoid retry, 
got %d calls", toolCalls) + } + if !strings.Contains(err.Error(), "copilot chat: Error executing tool: upstream unavailable.") { + t.Fatalf("unexpected error: %v", err) + } +} + func TestPrintCopilotResponse_Human(t *testing.T) { var buf bytes.Buffer ctx := &commands.Context{ @@ -176,6 +288,58 @@ func TestNormalizeCopilotResponse(t *testing.T) { } } +func newCopilotToolServer(t *testing.T, toolResponses [][]map[string]any, toolCalls *int) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + var req struct { + ID int `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Mcp-Session-Id", "test-session-id") + + switch req.Method { + case "initialize": + io.WriteString(w, "event: message\ndata: "+testutil.MustJSON(map[string]any{ + "jsonrpc": "2.0", + "id": req.ID, + "result": map[string]any{ + "protocolVersion": "2024-11-05", + "serverInfo": map[string]any{"name": "test", "version": "1.0"}, + }, + })+"\n\n") + case "tools/call": + idx := *toolCalls + *toolCalls++ + if idx >= len(toolResponses) { + idx = len(toolResponses) - 1 + } + io.WriteString(w, "event: message\ndata: "+testutil.MustJSON(map[string]any{ + "jsonrpc": "2.0", + "id": req.ID, + "result": map[string]any{ + "content": toolResponses[idx], + }, + })+"\n\n") + default: + http.Error(w, "unknown method", http.StatusBadRequest) + } + })) +} + func TestRunInteractiveLoop_ReusesConversationID(t *testing.T) { var calls []map[string]any diff --git a/internal/config/config.go b/internal/config/config.go index d5d84c1..4a013cc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ 
-6,6 +6,7 @@ import ( "net/url" "os" "strings" + "time" ) const ( @@ -29,6 +30,12 @@ const ( // DefaultClientID is the default Entra app client ID (VS Code MCP extension). DefaultClientID = "aebc6443-996d-45c2-90f0-388ff96faa56" + + // DefaultMCPResponseHeaderTimeout is the default HTTP response-header timeout for MCP requests. + DefaultMCPResponseHeaderTimeout = 60 * time.Second + + // DefaultCopilotResponseHeaderTimeout is longer because Copilot requests can take longer to start streaming. + DefaultCopilotResponseHeaderTimeout = 5 * time.Minute ) // Servers maps friendly names to agent365 MCP server names. @@ -73,6 +80,33 @@ func BaseURL() string { return base } +// MCPResponseHeaderTimeout returns the HTTP response-header timeout for MCP requests. +// +// A365_MCP_RESPONSE_HEADER_TIMEOUT overrides the default for all services. +// A365_COPILOT_RESPONSE_HEADER_TIMEOUT overrides the Copilot service specifically. +func MCPResponseHeaderTimeout(service string) time.Duration { + timeout := DefaultMCPResponseHeaderTimeout + if strings.EqualFold(service, "copilot") { + timeout = DefaultCopilotResponseHeaderTimeout + } + + if v := strings.TrimSpace(os.Getenv("A365_MCP_RESPONSE_HEADER_TIMEOUT")); v != "" { + if d, err := time.ParseDuration(v); err == nil && d >= 0 { + timeout = d + } + } + + if strings.EqualFold(service, "copilot") { + if v := strings.TrimSpace(os.Getenv("A365_COPILOT_RESPONSE_HEADER_TIMEOUT")); v != "" { + if d, err := time.ParseDuration(v); err == nil && d >= 0 { + timeout = d + } + } + } + + return timeout +} + // ValidateEndpointURL rejects malformed endpoints and non-loopback plaintext HTTP. 
func ValidateEndpointURL(raw string) error { if raw == "" { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 1f283e4..ee0808d 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,9 @@ package config -import "testing" +import ( + "testing" + "time" +) func TestBaseURL_Default(t *testing.T) { t.Setenv("A365_ENDPOINT", "") @@ -105,3 +108,42 @@ func TestServers_HasExpectedKeys(t *testing.T) { } } } + +func TestMCPResponseHeaderTimeout(t *testing.T) { + t.Setenv("A365_MCP_RESPONSE_HEADER_TIMEOUT", "") + t.Setenv("A365_COPILOT_RESPONSE_HEADER_TIMEOUT", "") + + if got := MCPResponseHeaderTimeout(""); got != DefaultMCPResponseHeaderTimeout { + t.Fatalf(`MCPResponseHeaderTimeout("") = %v, want %v`, got, DefaultMCPResponseHeaderTimeout) + } + if got := MCPResponseHeaderTimeout("copilot"); got != DefaultCopilotResponseHeaderTimeout { + t.Fatalf(`MCPResponseHeaderTimeout("copilot") = %v, want %v`, got, DefaultCopilotResponseHeaderTimeout) + } + + t.Setenv("A365_MCP_RESPONSE_HEADER_TIMEOUT", "90s") + if got := MCPResponseHeaderTimeout(""); got != 90*time.Second { + t.Fatalf("global override = %v, want %v", got, 90*time.Second) + } + if got := MCPResponseHeaderTimeout("copilot"); got != 90*time.Second { + t.Fatalf("global override for copilot = %v, want %v", got, 90*time.Second) + } + + t.Setenv("A365_COPILOT_RESPONSE_HEADER_TIMEOUT", "3m") + if got := MCPResponseHeaderTimeout("copilot"); got != 3*time.Minute { + t.Fatalf("copilot override = %v, want %v", got, 3*time.Minute) + } + + t.Setenv("A365_MCP_RESPONSE_HEADER_TIMEOUT", "invalid") + t.Setenv("A365_COPILOT_RESPONSE_HEADER_TIMEOUT", "invalid") + if got := MCPResponseHeaderTimeout(""); got != DefaultMCPResponseHeaderTimeout { + t.Fatalf("invalid global override should fall back to default, got %v", got) + } + if got := MCPResponseHeaderTimeout("copilot"); got != DefaultCopilotResponseHeaderTimeout { + t.Fatalf("invalid copilot override should fall back 
to default, got %v", got) + } + + t.Setenv("A365_COPILOT_RESPONSE_HEADER_TIMEOUT", "0") + if got := MCPResponseHeaderTimeout("copilot"); got != 0 { + t.Fatalf("copilot zero override = %v, want 0", got) + } +} diff --git a/internal/mcp/client.go b/internal/mcp/client.go index a10538a..2a1ff64 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -26,26 +26,30 @@ type VerboseLogger func(format string, args ...any) // Client is a lightweight MCP JSON-RPC client that speaks HTTP+SSE to agent365. type Client struct { - endpoint string - tokenProvider TokenProvider - httpClient *http.Client - sessionID string - nextID atomic.Int64 - verbose VerboseLogger - maxRetries int - retryBaseDelay time.Duration + endpoint string + tokenProvider TokenProvider + httpClient *http.Client + responseHeaderTimeout time.Duration + sessionID string + nextID atomic.Int64 + verbose VerboseLogger + maxRetries int + retryBaseDelay time.Duration } // NewClient creates a new MCP client for the given endpoint. 
func NewClient(endpoint string, tokenProvider TokenProvider) *Client { + responseHeaderTimeout := config.MCPResponseHeaderTimeout(serviceNameFromEndpoint(endpoint)) + return &Client{ - endpoint: endpoint, - tokenProvider: tokenProvider, + endpoint: endpoint, + tokenProvider: tokenProvider, + responseHeaderTimeout: responseHeaderTimeout, httpClient: &http.Client{ Transport: &http.Transport{ DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext, TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 60 * time.Second, + ResponseHeaderTimeout: responseHeaderTimeout, }, }, maxRetries: 2, @@ -53,6 +57,27 @@ func NewClient(endpoint string, tokenProvider TokenProvider) *Client { } } +func serviceNameFromEndpoint(endpoint string) string { + trimmed := strings.TrimSuffix(endpoint, "/") + if trimmed == "" { + return "" + } + + lastSlash := strings.LastIndex(trimmed, "/") + serverName := trimmed + if lastSlash >= 0 { + serverName = trimmed[lastSlash+1:] + } + + for service, knownServerName := range config.Servers { + if knownServerName == serverName { + return service + } + } + + return "" +} + // SetVerbose enables verbose logging of MCP requests and responses. 
func (c *Client) SetVerbose(logger VerboseLogger) { c.verbose = logger @@ -217,7 +242,7 @@ func (c *Client) doRequest(ctx context.Context, rpcReq JSONRPCRequest) (*JSONRPC return nil, fmt.Errorf("marshal request: %w", err) } - c.logf(">>> MCP %s %s\n%s", rpcReq.Method, c.endpoint, string(body)) + c.logf(">>> MCP %s %s (response-header-timeout=%s)\n%s", rpcReq.Method, c.endpoint, c.responseHeaderTimeout, string(body)) if err := config.ValidateEndpointURL(c.endpoint); err != nil { return nil, fmt.Errorf("invalid endpoint %q: %w", c.endpoint, err) diff --git a/internal/mcp/client_test.go b/internal/mcp/client_test.go index 12dfcf6..f664beb 100644 --- a/internal/mcp/client_test.go +++ b/internal/mcp/client_test.go @@ -11,6 +11,8 @@ import ( "sync/atomic" "testing" "time" + + "github.com/sozercan/a365cli/internal/config" ) func TestParseSSE_ToolCall(t *testing.T) { @@ -689,3 +691,69 @@ func TestClient_ListToolsCached(t *testing.T) { t.Errorf("expected tools/list called only once (cached), got %d", c) } } + +func TestServiceNameFromEndpoint(t *testing.T) { + if got := serviceNameFromEndpoint(config.Endpoint("copilot")); got != "copilot" { + t.Fatalf("serviceNameFromEndpoint(copilot) = %q, want %q", got, "copilot") + } + if got := serviceNameFromEndpoint("https://example.com/custom/"); got != "" { + t.Fatalf("serviceNameFromEndpoint(custom) = %q, want empty", got) + } +} + +func TestNewClient_ResponseHeaderTimeout(t *testing.T) { + t.Setenv("A365_MCP_RESPONSE_HEADER_TIMEOUT", "") + t.Setenv("A365_COPILOT_RESPONSE_HEADER_TIMEOUT", "") + + generic := NewClient(config.Endpoint("teams"), nil) + if generic.responseHeaderTimeout != config.DefaultMCPResponseHeaderTimeout { + t.Fatalf("generic responseHeaderTimeout = %v, want %v", generic.responseHeaderTimeout, config.DefaultMCPResponseHeaderTimeout) + } + transport, ok := generic.httpClient.Transport.(*http.Transport) + if !ok { + t.Fatal("generic transport is not *http.Transport") + } + if transport.ResponseHeaderTimeout != 
config.DefaultMCPResponseHeaderTimeout { + t.Fatalf("generic transport ResponseHeaderTimeout = %v, want %v", transport.ResponseHeaderTimeout, config.DefaultMCPResponseHeaderTimeout) + } + + copilot := NewClient(config.Endpoint("copilot"), nil) + if copilot.responseHeaderTimeout != config.DefaultCopilotResponseHeaderTimeout { + t.Fatalf("copilot responseHeaderTimeout = %v, want %v", copilot.responseHeaderTimeout, config.DefaultCopilotResponseHeaderTimeout) + } + transport, ok = copilot.httpClient.Transport.(*http.Transport) + if !ok { + t.Fatal("copilot transport is not *http.Transport") + } + if transport.ResponseHeaderTimeout != config.DefaultCopilotResponseHeaderTimeout { + t.Fatalf("copilot transport ResponseHeaderTimeout = %v, want %v", transport.ResponseHeaderTimeout, config.DefaultCopilotResponseHeaderTimeout) + } +} + +func TestClient_CopilotUsesLongerResponseHeaderTimeout(t *testing.T) { + t.Setenv("A365_MCP_RESPONSE_HEADER_TIMEOUT", "20ms") + t.Setenv("A365_COPILOT_RESPONSE_HEADER_TIMEOUT", "500ms") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}`) + })) + defer server.Close() + + tokenProvider := func(ctx context.Context) (string, error) { + return "token", nil + } + + generic := NewClient(server.URL+"/"+config.Servers["teams"]+"/", tokenProvider) + _, err := generic.doRequest(context.Background(), JSONRPCRequest{JSONRPC: "2.0", ID: 1, Method: "tools/list"}) + if err == nil { + t.Fatal("expected generic MCP request to hit response-header timeout") + } + + copilot := NewClient(server.URL+"/"+config.Servers["copilot"]+"/", tokenProvider) + _, err = copilot.doRequest(context.Background(), JSONRPCRequest{JSONRPC: "2.0", ID: 1, Method: "tools/list"}) + if err != nil { + t.Fatalf("expected copilot request to succeed with longer timeout, got %v", err) + } +}