diff --git a/e2e/testdata/cassettes/TestA2AServer_MultiAgent.yaml b/e2e/testdata/cassettes/TestA2AServer_MultiAgent.yaml index 050dd02b2..e9dc79ed9 100644 --- a/e2e/testdata/cassettes/TestA2AServer_MultiAgent.yaml +++ b/e2e/testdata/cassettes/TestA2AServer_MultiAgent.yaml @@ -8,7 +8,7 @@ interactions: proto_minor: 1 content_length: 0 host: api.openai.com - body: '{"input":[{"content":[{"text":"You are a multi-agent system, make sure to answer the user query in the most helpful way possible. You have access to these sub-agents:\nName: web | Description: \n\nIMPORTANT: You can ONLY transfer tasks to the agents listed above using their ID. The valid agent names are: web. You MUST NOT attempt to transfer to any other agent IDs - doing so will cause system errors.\n\nIf you are the best to answer the question according to your description, you can answer it.\n\nIf another agent is better for answering the question according to its description, call `transfer_task` function to transfer the question to that agent using the agent''s ID. 
When transferring, do not generate any text other than the function call.\n\n","type":"input_text"}],"role":"system"},{"content":[{"text":"You are a knowledgeable assistant that helps users with various tasks.\nBe helpful, accurate, and concise in your responses.\n","type":"input_text"}],"role":"system"},{"content":"Say hello.","role":"user"}],"model":"gpt-5-mini","tools":[{"strict":true,"parameters":{"additionalProperties":false,"properties":{"agent":{"description":"The name of the agent to transfer the task to.","type":"string"},"expected_output":{"description":"The expected output from the member (optional).","type":"string"},"task":{"description":"A clear and concise description of the task the member should achieve.","type":"string"}},"required":["agent","expected_output","task"],"type":"object"},"name":"transfer_task","description":"Use this function to transfer a task to the selected team member.\n You must provide a clear and concise description of the task the member should achieve AND the expected output.","type":"function"}],"stream":true}' + body: '{"input":[{"content":[{"text":"You are a multi-agent system, make sure to answer the user query in the most helpful way possible. You have access to these sub-agents:\nName: web | Description: \n\nIMPORTANT: You can ONLY transfer tasks to the agents listed above using their ID. The valid agent names are: web. You MUST NOT attempt to transfer to any other agent IDs - doing so will cause system errors.\n\nIf you are the best to answer the question according to your description, you can answer it.\n\nIf another agent is better for answering the question according to its description, call `transfer_task` function to transfer the question to that agent using the agent''s ID. 
When transferring, do not generate any text other than the function call.\n\n","type":"input_text"}],"role":"system"},{"content":[{"text":"You are a knowledgeable assistant that helps users with various tasks.\nBe helpful, accurate, and concise in your responses.\n","type":"input_text"}],"role":"system"},{"content":"Say hello.","role":"user"}],"model":"gpt-5-mini","tools":[{"strict":true,"parameters":{"additionalProperties":false,"properties":{"agent":{"description":"The name of the agent to transfer the task to.","type":"string"},"expected_output":{"description":"The expected output from the member (optional).","type":"string"},"task":{"description":"A clear and concise description of the task the member should achieve.","type":"string"}},"required":["agent","expected_output","task"],"type":"object"},"name":"transfer_task","description":"Use this function to transfer a task to the selected team member.\n You must provide a clear and concise description of the task the member should achieve AND the expected output.","type":"function"},{"strict":true,"parameters":{"additionalProperties":false,"properties":{"agent":{"description":"The name of the sub-agent to run in the background.","type":"string"},"expected_output":{"description":"The expected output from the agent (optional).","type":["string","null"]},"task":{"description":"A clear and concise description of the task the agent should achieve.","type":"string"}},"required":["agent","expected_output","task"],"type":"object"},"name":"run_background_agent","description":"Start a sub-agent task in the background and return immediately with a task ID.\nUse this to dispatch work to multiple sub-agents concurrently. The sub-agent runs with all tools\npre-approved — use only with trusted sub-agents and well-scoped tasks. 
Check progress with\nview_background_agent and collect results once the task is complete.","type":"function"},{"strict":true,"parameters":{"additionalProperties":false,"properties":{},"required":[],"type":"object"},"name":"list_background_agents","description":"List all background agent tasks with their status and runtime.","type":"function"},{"strict":true,"parameters":{"additionalProperties":false,"properties":{"task_id":{"description":"The ID of the background agent task to view.","type":"string"}},"required":["task_id"],"type":"object"},"name":"view_background_agent","description":"View the output and status of a specific background agent task by task ID. Returns live buffered output if still running, or the final result if complete.","type":"function"},{"strict":true,"parameters":{"additionalProperties":false,"properties":{"task_id":{"description":"The ID of the background agent task to stop.","type":"string"}},"required":["task_id"],"type":"object"},"name":"stop_background_agent","description":"Stop a running background agent task by task ID.","type":"function"}],"stream":true}' url: https://api.openai.com/v1/responses method: POST response: diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 9d6c73d65..75b9011f8 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -263,7 +263,7 @@ func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice { // Get the current agent's default model reference currentAgentDefault := "" if r.modelSwitcherCfg.AgentDefaultModels != nil { - currentAgentDefault = r.modelSwitcherCfg.AgentDefaultModels[r.currentAgent] + currentAgentDefault = r.modelSwitcherCfg.AgentDefaultModels[r.CurrentAgentName()] } var choices []ModelChoice diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 5a0dd4dfd..1b6595b21 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -37,6 +37,7 @@ import ( "github.com/docker/cagent/pkg/telemetry" 
"github.com/docker/cagent/pkg/tools" "github.com/docker/cagent/pkg/tools/builtin" + agenttool "github.com/docker/cagent/pkg/tools/builtin/agent" mcptools "github.com/docker/cagent/pkg/tools/mcp" ) @@ -214,8 +215,12 @@ type LocalRuntime struct { fallbackCooldowns map[string]*fallbackCooldownState fallbackCooldownsMux sync.RWMutex + currentAgentMu sync.RWMutex + // onToolsChanged is called when an MCP toolset reports a tool list change. onToolsChanged func(Event) + + bgAgents *agenttool.Handler } type streamResult struct { @@ -302,6 +307,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { sessionStore: session.NewInMemorySessionStore(), fallbackCooldowns: make(map[string]*fallbackCooldownState), } + r.bgAgents = agenttool.NewHandler(r) for _, opt := range opts { opt(r) @@ -375,7 +381,7 @@ func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers map[str return } - agentName := r.currentAgent + agentName := r.CurrentAgentName() slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName) switch ragEvent.Type { @@ -436,9 +442,17 @@ func (r *LocalRuntime) InitializeRAG(ctx context.Context, events chan Event) { } func (r *LocalRuntime) CurrentAgentName() string { + r.currentAgentMu.RLock() + defer r.currentAgentMu.RUnlock() return r.currentAgent } +func (r *LocalRuntime) setCurrentAgent(name string) { + r.currentAgentMu.Lock() + defer r.currentAgentMu.Unlock() + r.currentAgent = name +} + func (r *LocalRuntime) CurrentAgentInfo(context.Context) CurrentAgentInfo { currentAgent := r.CurrentAgent() @@ -454,7 +468,7 @@ func (r *LocalRuntime) SetCurrentAgent(agentName string) error { if _, err := r.team.Agent(agentName); err != nil { return err } - r.currentAgent = agentName + r.setCurrentAgent(agentName) slog.Debug("Switched current agent", "agent", agentName) return nil } @@ -538,7 +552,7 @@ func (r *LocalRuntime) discoverMCPPrompts(ctx context.Context, toolset *mcptools // CurrentAgent returns 
the current agent func (r *LocalRuntime) CurrentAgent() *agent.Agent { // We validated already that the agent exists - current, _ := r.team.Agent(r.currentAgent) + current, _ := r.team.Agent(r.CurrentAgentName()) return current } @@ -624,7 +638,7 @@ func (r *LocalRuntime) getHooksExecutor(a *agent.Agent) *hooks.Executor { // executeOnUserInputHooks executes on-user-input hooks for the current agent func (r *LocalRuntime) executeOnUserInputHooks(ctx context.Context, sessionID, logContext string) { - a, _ := r.team.Agent(r.currentAgent) + a, _ := r.team.Agent(r.CurrentAgentName()) if a == nil { return } @@ -719,6 +733,7 @@ func (r *LocalRuntime) SessionStore() session.Store { // Close releases resources held by the runtime, including the session store. func (r *LocalRuntime) Close() error { + r.bgAgents.StopAll() if r.sessionStore != nil { return r.sessionStore.Close() } @@ -786,7 +801,7 @@ func (r *LocalRuntime) emitToolsChanged() { if err != nil { return } - r.onToolsChanged(ToolsetInfo(len(agentTools), false, r.currentAgent)) + r.onToolsChanged(ToolsetInfo(len(agentTools), false, r.CurrentAgentName())) } // EmitStartupInfo emits initial agent, team, and toolset information for immediate sidebar display. 
@@ -817,7 +832,7 @@ func (r *LocalRuntime) EmitStartupInfo(ctx context.Context, sess *session.Sessio if !send(AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())) { return } - if !send(TeamInfo(r.agentDetailsFromTeam(), r.currentAgent)) { + if !send(TeamInfo(r.agentDetailsFromTeam(), r.CurrentAgentName())) { return } @@ -836,7 +851,7 @@ func (r *LocalRuntime) EmitStartupInfo(ctx context.Context, sess *session.Sessio } usage := SessionUsage(sess, contextLimit) usage.Cost = sess.TotalCost() - send(NewTokenUsageEvent(sess.ID, r.currentAgent, usage)) + send(NewTokenUsageEvent(sess.ID, r.CurrentAgentName(), usage)) } // Emit agent warnings (if any) - these are quick @@ -856,12 +871,12 @@ func (r *LocalRuntime) emitToolsProgressively(ctx context.Context, a *agent.Agen // If no toolsets, emit final state immediately if totalToolsets == 0 { - send(ToolsetInfo(0, false, r.currentAgent)) + send(ToolsetInfo(0, false, r.CurrentAgentName())) return } // Emit initial loading state - if !send(ToolsetInfo(0, true, r.currentAgent)) { + if !send(ToolsetInfo(0, true, r.CurrentAgentName())) { return } @@ -895,13 +910,13 @@ func (r *LocalRuntime) emitToolsProgressively(ctx context.Context, a *agent.Agen totalTools += len(ts) // Emit progress update - still loading unless this is the last toolset - if !send(ToolsetInfo(totalTools, !isLast, r.currentAgent)) { + if !send(ToolsetInfo(totalTools, !isLast, r.CurrentAgentName())) { return } } // Emit final state (not loading) - send(ToolsetInfo(totalTools, false, r.currentAgent)) + send(ToolsetInfo(totalTools, false, r.CurrentAgentName())) } // registerDefaultTools registers the runtime-managed tool handlers. 
@@ -912,16 +927,40 @@ func (r *LocalRuntime) registerDefaultTools() { r.toolMap[builtin.ToolNameHandoff] = r.handleHandoff r.toolMap[builtin.ToolNameChangeModel] = r.handleChangeModel r.toolMap[builtin.ToolNameRevertModel] = r.handleRevertModel + r.toolMap[agenttool.ToolNameRunBackgroundAgent] = r.handleRunBackgroundAgent + r.toolMap[agenttool.ToolNameListBackgroundAgents] = r.handleListBackgroundAgents + r.toolMap[agenttool.ToolNameViewBackgroundAgent] = r.handleViewBackgroundAgent + r.toolMap[agenttool.ToolNameStopBackgroundAgent] = r.handleStopBackgroundAgent +} + +func (r *LocalRuntime) handleRunBackgroundAgent(ctx context.Context, sess *session.Session, tc tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) { + return r.bgAgents.HandleRun(ctx, sess, tc) +} + +func (r *LocalRuntime) handleListBackgroundAgents(ctx context.Context, sess *session.Session, tc tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) { + return r.bgAgents.HandleList(ctx, sess, tc) +} + +func (r *LocalRuntime) handleViewBackgroundAgent(ctx context.Context, sess *session.Session, tc tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) { + return r.bgAgents.HandleView(ctx, sess, tc) +} + +func (r *LocalRuntime) handleStopBackgroundAgent(ctx context.Context, sess *session.Session, tc tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) { + return r.bgAgents.HandleStop(ctx, sess, tc) } func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.Session, events chan Event) { // Clear the elicitation events channel before closing the events channel // to prevent a send-on-closed-channel panic in elicitationHandler. - r.clearElicitationEventsChannel() + // Skip for background sessions (ToolsApproved=true) — they never set the + // channel, so clearing it would null out the parent session's channel. 
+ if !sess.ToolsApproved { + r.clearElicitationEventsChannel() + } defer close(events) - events <- StreamStopped(sess.ID, r.currentAgent) + events <- StreamStopped(sess.ID, r.CurrentAgentName()) r.executeOnUserInputHooks(ctx, sess.ID, "stream stopped") @@ -930,30 +969,44 @@ func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.S // RunStream starts the agent's interaction loop and returns a channel of events func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-chan Event { - slog.Debug("Starting runtime stream", "agent", r.currentAgent, "session_id", sess.ID) + slog.Debug("Starting runtime stream", "agent", r.CurrentAgentName(), "session_id", sess.ID) events := make(chan Event, 128) go func() { - telemetry.RecordSessionStart(ctx, r.currentAgent, sess.ID) + telemetry.RecordSessionStart(ctx, r.CurrentAgentName(), sess.ID) ctx, sessionSpan := r.startSpan(ctx, "runtime.session", trace.WithAttributes( - attribute.String("agent", r.currentAgent), + attribute.String("agent", r.CurrentAgentName()), attribute.String("session.id", sess.ID), )) defer sessionSpan.End() - // Set the events channel for elicitation requests - r.setElicitationEventsChannel(events) + // Set the events channel for elicitation requests. + // Skip for background sessions (ToolsApproved=true): they have all tools + // pre-approved and will never trigger elicitation prompts. Setting the + // channel would overwrite the parent session's channel; clearing it at + // teardown would break any pending MCP auth flow in the parent. + if !sess.ToolsApproved { + r.setElicitationEventsChannel(events) + } - // Set elicitation handler on all MCP toolsets before getting tools - a := r.CurrentAgent() + // Resolve the agent for this session. When AgentName is set on the + // session (e.g., background agent tasks), use it directly to avoid + // racing on the shared currentAgent field. 
+ var a *agent.Agent + if sess.AgentName != "" { + a, _ = r.team.Agent(sess.AgentName) + } + if a == nil { + a = r.CurrentAgent() + } // Emit agent information for sidebar display // Use getEffectiveModelID to account for active fallback cooldowns events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage()) // Emit team information - events <- TeamInfo(r.agentDetailsFromTeam(), r.currentAgent) + events <- TeamInfo(r.agentDetailsFromTeam(), r.CurrentAgentName()) // Initialize RAG and forward events r.InitializeRAG(ctx, events) @@ -967,7 +1020,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c return } - events <- ToolsetInfo(len(agentTools), false, r.currentAgent) + events <- ToolsetInfo(len(agentTools), false, r.CurrentAgentName()) messages := sess.GetMessages(a) if sess.SendUserMessage { @@ -1001,7 +1054,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c // Emit updated tool count. After a ToolListChanged MCP notification // the cache is invalidated, so getTools above re-fetches from the // server and may return a different count. 
- events <- ToolsetInfo(len(agentTools), false, r.currentAgent) + events <- ToolsetInfo(len(agentTools), false, r.CurrentAgentName()) // Check iteration limit if runtimeMaxIterations > 0 && iteration >= runtimeMaxIterations { @@ -1203,7 +1256,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c usage := SessionUsage(sess, contextLimit) usage.LastMessage = msgUsage - events <- NewTokenUsageEvent(sess.ID, r.currentAgent, usage) + events <- NewTokenUsageEvent(sess.ID, r.CurrentAgentName(), usage) r.processToolCalls(ctx, sess, res.Calls, agentTools, events) @@ -1247,7 +1300,7 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Even for _, toolset := range a.ToolSets() { tools.ConfigureHandlers(toolset, r.elicitationHandler, - func() { events <- Authorization(tools.ElicitationActionAccept, r.currentAgent) }, + func() { events <- Authorization(tools.ElicitationActionAccept, r.CurrentAgentName()) }, r.managedOAuth, ) } @@ -1261,7 +1314,7 @@ func (r *LocalRuntime) emitAgentWarningsWithSend(a *agent.Agent, send func(Event } slog.Warn("Tool setup partially failed; continuing", "agent", a.Name(), "warnings", warnings) - send(Warning(formatToolWarning(a, warnings), r.currentAgent)) + send(Warning(formatToolWarning(a, warnings), r.CurrentAgentName())) } func (r *LocalRuntime) emitAgentWarnings(a *agent.Agent, events chan Event) { @@ -1273,7 +1326,7 @@ func (r *LocalRuntime) emitAgentWarnings(a *agent.Agent, events chan Event) { slog.Warn("Tool setup partially failed; continuing", "agent", a.Name(), "warnings", warnings) if events != nil { - events <- Warning(formatToolWarning(a, warnings), r.currentAgent) + events <- Warning(formatToolWarning(a, warnings), r.CurrentAgentName()) } } @@ -1287,7 +1340,7 @@ func formatToolWarning(a *agent.Agent, warnings []string) string { } func (r *LocalRuntime) Resume(_ context.Context, req ResumeRequest) { - slog.Debug("Resuming runtime", "agent", r.currentAgent, "type", req.Type, 
"reason", req.Reason) + slog.Debug("Resuming runtime", "agent", r.CurrentAgentName(), "type", req.Type, "reason", req.Reason) // Defensive validation: // @@ -1298,7 +1351,7 @@ func (r *LocalRuntime) Resume(_ context.Context, req ResumeRequest) { if !IsValidResumeType(req.Type) { slog.Warn( "Invalid resume type received; ignoring resume request", - "agent", r.currentAgent, + "agent", r.CurrentAgentName(), "confirmation_type", req.Type, "valid_types", ValidResumeTypes(), ) @@ -1312,11 +1365,11 @@ func (r *LocalRuntime) Resume(_ context.Context, req ResumeRequest) { // canceled, or shutting down). select { case r.resumeChan <- req: - slog.Debug("Resume signal sent", "agent", r.currentAgent) + slog.Debug("Resume signal sent", "agent", r.CurrentAgentName()) default: slog.Debug( "Resume channel not ready; resume signal dropped", - "agent", r.currentAgent, + "agent", r.CurrentAgentName(), "confirmation_type", req.Type, ) } @@ -1324,7 +1377,7 @@ func (r *LocalRuntime) Resume(_ context.Context, req ResumeRequest) { // ResumeElicitation sends an elicitation response back to a waiting elicitation request func (r *LocalRuntime) ResumeElicitation(ctx context.Context, action tools.ElicitationAction, content map[string]any) error { - slog.Debug("Resuming runtime with elicitation response", "agent", r.currentAgent, "action", action) + slog.Debug("Resuming runtime with elicitation response", "agent", r.CurrentAgentName(), "action", action) result := ElicitationResult{ Action: action, @@ -1916,6 +1969,85 @@ func (r *LocalRuntime) startSpan(ctx context.Context, name string, opts ...trace return r.tracer.Start(ctx, name, opts...) } +// CurrentAgentSubAgentNames implements agenttool.Runner. +func (r *LocalRuntime) CurrentAgentSubAgentNames() []string { + a := r.CurrentAgent() + if a == nil { + return nil + } + var names []string + for _, sa := range a.SubAgents() { + names = append(names, sa.Name()) + } + return names +} + +// RunAgent implements agenttool.Runner. 
It starts a sub-agent synchronously and +// blocks until completion or cancellation. +func (r *LocalRuntime) RunAgent(ctx context.Context, params agenttool.RunParams) *agenttool.RunResult { + child, err := r.team.Agent(params.AgentName) + if err != nil { + return &agenttool.RunResult{ErrMsg: fmt.Sprintf("agent %q not found: %s", params.AgentName, err)} + } + + systemMsg := "You are a member of a team of agents. Your goal is to complete the following task:" + systemMsg += fmt.Sprintf("\n\n\n%s\n", params.Task) + if params.ExpectedOutput != "" { + systemMsg += fmt.Sprintf("\n\n\n%s\n", params.ExpectedOutput) + } + + sess := params.ParentSession + + // Background tasks run with tools pre-approved because there is no user present + // to respond to interactive approval prompts during async execution. This is a + // deliberate design trade-off: the user implicitly authorises all tool calls made + // by the sub-agent when they approve run_background_agent. Callers should be aware + // that prompt injection in the sub-agent's context could exploit this gate-bypass. + // + // TODO: propagate the parent session's per-tool permission rules once the runtime + // supports per-session permission scoping rather than a single shared ToolsApproved flag. 
+ s := session.New( + session.WithSystemMessage(systemMsg), + session.WithImplicitUserMessage("Please proceed."), + session.WithMaxIterations(child.MaxIterations()), + session.WithTitle("Background agent task"), + session.WithToolsApproved(true), + session.WithThinking(sess.Thinking), + session.WithSendUserMessage(false), + session.WithParentID(sess.ID), + session.WithAgentName(params.AgentName), + ) + + var errMsg string + events := r.RunStream(ctx, s) + for event := range events { + if ctx.Err() != nil { + break + } + if choice, ok := event.(*AgentChoiceEvent); ok && choice.Content != "" { + if params.OnContent != nil { + params.OnContent(choice.Content) + } + } + if errEvt, ok := event.(*ErrorEvent); ok { + errMsg = errEvt.Error + break + } + } + // Drain remaining events so the RunStream goroutine can complete + // and close the channel without blocking on a full buffer. + for range events { + } + + if errMsg != "" { + return &agenttool.RunResult{ErrMsg: errMsg} + } + + result := s.GetLastAssistantMessageContent() + sess.AddSubSession(s) + return &agenttool.RunResult{Result: result} +} + func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, evts chan Event) (*tools.ToolCallResult, error) { var params struct { Agent string `json:"agent"` @@ -1955,14 +2087,14 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses slog.Debug("Transferring task to agent", "from_agent", a.Name(), "to_agent", params.Agent, "task", params.Task) - ca := r.currentAgent + ca := r.CurrentAgentName() // Emit agent switching start event evts <- AgentSwitching(true, ca, params.Agent) - r.currentAgent = params.Agent + r.setCurrentAgent(params.Agent) defer func() { - r.currentAgent = ca + r.setCurrentAgent(ca) // Emit agent switching end event evts <- AgentSwitching(false, params.Agent, ca) @@ -2033,7 +2165,7 @@ func (r *LocalRuntime) handleHandoff(_ context.Context, _ *session.Session, tool return nil, 
fmt.Errorf("invalid arguments: %w", err) } - ca := r.currentAgent + ca := r.CurrentAgentName() currentAgent, err := r.team.Agent(ca) if err != nil { return nil, fmt.Errorf("current agent not found: %w", err) @@ -2060,7 +2192,7 @@ func (r *LocalRuntime) handleHandoff(_ context.Context, _ *session.Session, tool return nil, err } - r.currentAgent = next.Name() + r.setCurrentAgent(next.Name()) handoffMessage := "The agent " + ca + " handed off the conversation to you. " + "Your available handoff agents and tools are specified in the system messages that follow. " + "Only use those capabilities - do not attempt to use tools or hand off to agents that you see " + @@ -2075,7 +2207,8 @@ func (r *LocalRuntime) handleHandoff(_ context.Context, _ *session.Session, tool // findModelPickerTool returns the ModelPickerTool from the current agent's // toolsets, or nil if the agent has no model_picker configured. func (r *LocalRuntime) findModelPickerTool() *builtin.ModelPickerTool { - a, err := r.team.Agent(r.currentAgent) + currentName := r.CurrentAgentName() + a, err := r.team.Agent(currentName) if err != nil { return nil } @@ -2123,21 +2256,22 @@ func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session // AgentInfo event so the UI reflects the change. An empty modelRef reverts to // the agent's default model. 
func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, events chan Event) (*tools.ToolCallResult, error) { - if err := r.SetAgentModel(ctx, r.currentAgent, modelRef); err != nil { + currentName := r.CurrentAgentName() + if err := r.SetAgentModel(ctx, currentName, modelRef); err != nil { return tools.ResultError(fmt.Sprintf("failed to set model: %v", err)), nil } - if a, err := r.team.Agent(r.currentAgent); err == nil { + if a, err := r.team.Agent(currentName); err == nil { events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage()) } else { - slog.Warn("Failed to retrieve agent after model change; UI may not reflect the update", "agent", r.currentAgent, "error", err) + slog.Warn("Failed to retrieve agent after model change; UI may not reflect the update", "agent", currentName, "error", err) } if modelRef == "" { - slog.Info("Model reverted via model_picker tool", "agent", r.currentAgent) + slog.Info("Model reverted via model_picker tool", "agent", currentName) return tools.ResultSuccess("Model reverted to the agent's default model"), nil } - slog.Info("Model changed via model_picker tool", "agent", r.currentAgent, "model", modelRef) + slog.Info("Model changed via model_picker tool", "agent", currentName, "model", modelRef) return tools.ResultSuccess(fmt.Sprintf("Model changed to %s", modelRef)), nil } @@ -2145,7 +2279,7 @@ func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, // The additionalPrompt parameter allows users to provide additional instructions // for the summarization (e.g., "focus on code changes" or "include action items"). 
func (r *LocalRuntime) Summarize(ctx context.Context, sess *session.Session, additionalPrompt string, events chan Event) { - r.sessionCompactor.Compact(ctx, sess, additionalPrompt, events, r.currentAgent) + r.sessionCompactor.Compact(ctx, sess, additionalPrompt, events, r.CurrentAgentName()) // Emit a TokenUsageEvent so the sidebar immediately reflects the // compaction: tokens drop to the summary size, context % drops, and @@ -2156,7 +2290,7 @@ func (r *LocalRuntime) Summarize(ctx context.Context, sess *session.Session, add if m, err := r.modelsStore.GetModel(ctx, modelID); err == nil && m != nil { contextLimit = int64(m.Limit.Context) } - events <- NewTokenUsageEvent(sess.ID, r.currentAgent, SessionUsage(sess, contextLimit)) + events <- NewTokenUsageEvent(sess.ID, r.CurrentAgentName(), SessionUsage(sess, contextLimit)) } // setElicitationEventsChannel sets the current events channel for elicitation requests @@ -2193,7 +2327,7 @@ func (r *LocalRuntime) elicitationHandler(ctx context.Context, req *mcp.ElicitPa slog.Debug("Elicitation request meta", "meta", req.Meta) // Send elicitation request event to the runtime's client - eventsChannel <- ElicitationRequest(req.Message, req.Mode, req.RequestedSchema, req.URL, req.ElicitationID, req.Meta, r.currentAgent) + eventsChannel <- ElicitationRequest(req.Message, req.Mode, req.RequestedSchema, req.URL, req.ElicitationID, req.Meta, r.CurrentAgentName()) r.elicitationEventsChannelMux.RUnlock() // Wait for response from the client diff --git a/pkg/session/session.go b/pkg/session/session.go index 037455c7d..bdbc56148 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -122,6 +122,12 @@ type Session struct { // BranchCreatedAt is the time when this branch session was created. BranchCreatedAt *time.Time `json:"branch_created_at,omitempty"` + // AgentName, when set, tells RunStream which agent to use for this session + // instead of reading from the shared runtime currentAgent field. 
This is + // required for background agent tasks where multiple sessions may run + // concurrently on different agents. + AgentName string `json:"-"` + // ParentID indicates this is a sub-session created by task transfer. // Sub-sessions are not persisted as standalone entries; they are embedded // within the parent session's Messages array. @@ -459,6 +465,15 @@ func WithPermissions(perms *PermissionsConfig) Opt { } } +// WithAgentName pins this session to a specific agent. When set, RunStream +// resolves the agent from the session rather than the shared runtime state, +// which is required for concurrent background agent tasks. +func WithAgentName(name string) Opt { + return func(s *Session) { + s.AgentName = name + } +} + // WithParentID marks this session as a sub-session of the given parent. // Sub-sessions are not persisted as standalone entries in the session store. func WithParentID(parentID string) Opt { diff --git a/pkg/session/store.go b/pkg/session/store.go index 33fed9e60..6cfd86d19 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -197,6 +197,7 @@ func (s *InMemorySessionStore) UpdateSession(_ context.Context, session *Session // Build a new session with the same metadata but a fresh mutex. // Messages are stored separately via AddMessage. + // MAINTENANCE: when adding new persisted fields to Session, add them here too. 
newSession := &Session{ ID: session.ID, Title: session.Title, diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 22773423a..ba53b5e5c 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -23,6 +23,7 @@ import ( "github.com/docker/cagent/pkg/team" "github.com/docker/cagent/pkg/tools" "github.com/docker/cagent/pkg/tools/builtin" + agenttool "github.com/docker/cagent/pkg/tools/builtin/agent" "github.com/docker/cagent/pkg/tools/codemode" ) @@ -463,7 +464,7 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri } if len(a.SubAgents) > 0 { - toolSets = append(toolSets, builtin.NewTransferTaskTool()) + toolSets = append(toolSets, builtin.NewTransferTaskTool(), agenttool.NewToolSet()) } if len(a.Handoffs) > 0 { toolSets = append(toolSets, builtin.NewHandoffTool()) diff --git a/pkg/tools/builtin/agent/agent.go b/pkg/tools/builtin/agent/agent.go new file mode 100644 index 000000000..bce074f01 --- /dev/null +++ b/pkg/tools/builtin/agent/agent.go @@ -0,0 +1,454 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + + "github.com/docker/cagent/pkg/concurrent" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/tools" +) + +const ( + ToolNameRunBackgroundAgent = "run_background_agent" + ToolNameListBackgroundAgents = "list_background_agents" + ToolNameViewBackgroundAgent = "view_background_agent" + ToolNameStopBackgroundAgent = "stop_background_agent" +) + +const ( + // maxConcurrentTasks is the maximum number of simultaneously running background agent tasks. + maxConcurrentTasks = 20 + // maxTotalTasks caps total stored tasks (running + completed) to prevent unbounded memory growth. + maxTotalTasks = 100 + // maxOutputBytes caps the live output buffer per task, mirroring the shell tool's limit. 
+ maxOutputBytes = 10 * 1024 * 1024 // 10 MB +) + +// RunBackgroundAgentArgs specifies the parameters for dispatching a sub-agent task asynchronously. +type RunBackgroundAgentArgs struct { + Agent string `json:"agent" jsonschema:"The name of the sub-agent to run in the background."` + Task string `json:"task" jsonschema:"A clear and concise description of the task the agent should achieve."` + ExpectedOutput string `json:"expected_output,omitempty" jsonschema:"The expected output from the agent (optional)."` +} + +// ViewBackgroundAgentArgs specifies the task ID to inspect. +type ViewBackgroundAgentArgs struct { + TaskID string `json:"task_id" jsonschema:"The ID of the background agent task to view."` +} + +// StopBackgroundAgentArgs specifies the task ID to cancel. +type StopBackgroundAgentArgs struct { + TaskID string `json:"task_id" jsonschema:"The ID of the background agent task to stop."` +} + +// RunParams holds the parameters for running a sub-agent. +type RunParams struct { + AgentName string + Task string + ExpectedOutput string + ParentSession *session.Session + OnContent func(content string) +} + +// RunResult holds the outcome of a sub-agent execution. +type RunResult struct { + Result string // final assistant message on completion + ErrMsg string // error detail if failed +} + +// Runner abstracts the runtime dependency for background agent execution. +type Runner interface { + // CurrentAgentSubAgentNames returns the names of the current agent's sub-agents. + CurrentAgentSubAgentNames() []string + // RunAgent starts a sub-agent and blocks until completion or cancellation. + RunAgent(ctx context.Context, params RunParams) *RunResult +} + +// taskStatus represents the lifecycle state of a background agent task. 
+type taskStatus int32 + +const ( + taskRunning taskStatus = iota + taskCompleted + taskStopped + taskFailed +) + +var taskStatusStrings = map[taskStatus]string{ + taskRunning: "running", + taskCompleted: "completed", + taskStopped: "stopped", + taskFailed: "failed", +} + +func statusToString(s taskStatus) string { + if str, ok := taskStatusStrings[s]; ok { + return str + } + return "unknown" +} + +// task tracks a single background sub-agent execution. +type task struct { + id string + agentName string + taskDesc string + + cancel context.CancelFunc + outputMu sync.RWMutex + output strings.Builder + outputBytes int + startTime time.Time + status atomic.Int32 + result string + errMsg string +} + +func (t *task) loadStatus() taskStatus { + return taskStatus(t.status.Load()) +} + +func (t *task) storeStatus(s taskStatus) { + t.status.Store(int32(s)) +} + +func (t *task) casStatus(old, next taskStatus) bool { + return t.status.CompareAndSwap(int32(old), int32(next)) +} + +// Handler owns all background agent tasks and provides tool handlers. +type Handler struct { + runner Runner + wg sync.WaitGroup + tasks *concurrent.Map[string, *task] +} + +// NewHandler creates a new Handler with the given Runner. 
+func NewHandler(runner Runner) *Handler { + return &Handler{ + runner: runner, + tasks: concurrent.NewMap[string, *task](), + } +} + +func newTaskID() string { + return fmt.Sprintf("agent_task_%s", uuid.New().String()) +} + +func (h *Handler) runningTaskCount() int { + var count int + h.tasks.Range(func(_ string, t *task) bool { + if t.loadStatus() == taskRunning { + count++ + } + return true + }) + return count +} + +func (h *Handler) totalTaskCount() int { + return h.tasks.Length() +} + +func (h *Handler) pruneCompleted() { + var toDelete []string + h.tasks.Range(func(id string, t *task) bool { + s := t.loadStatus() + if s == taskCompleted || s == taskStopped || s == taskFailed { + toDelete = append(toDelete, id) + } + return true + }) + for _, id := range toDelete { + h.tasks.Delete(id) + } +} + +// HandleRun starts a sub-agent task asynchronously and returns a task ID immediately. +func (h *Handler) HandleRun(ctx context.Context, sess *session.Session, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { + var params RunBackgroundAgentArgs + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) + } + + if strings.TrimSpace(params.Agent) == "" { + return tools.ResultError("agent name must not be empty"), nil + } + if strings.TrimSpace(params.Task) == "" { + return tools.ResultError("task must not be empty"), nil + } + + subAgentNames := h.runner.CurrentAgentSubAgentNames() + valid := false + for _, name := range subAgentNames { + if name == params.Agent { + valid = true + break + } + } + if !valid { + if len(subAgentNames) > 0 { + return tools.ResultError(fmt.Sprintf("agent %q is not in the sub-agents list. Available: %s", params.Agent, strings.Join(subAgentNames, ", "))), nil + } + return tools.ResultError(fmt.Sprintf("agent %q is not in the sub-agents list. This agent has no sub-agents configured.", params.Agent)), nil + } + + // Enforce concurrency cap. 
+ if h.runningTaskCount() >= maxConcurrentTasks { + return tools.ResultError(fmt.Sprintf("maximum concurrent background agent tasks (%d) reached; stop or wait for existing tasks to complete", maxConcurrentTasks)), nil + } + + // Enforce total cap, pruning finished tasks first. + if h.totalTaskCount() >= maxTotalTasks { + h.pruneCompleted() + if h.totalTaskCount() >= maxTotalTasks { + return tools.ResultError(fmt.Sprintf("maximum total background agent tasks (%d) reached; view and discard old tasks first", maxTotalTasks)), nil + } + } + + taskID := newTaskID() + + taskCtx, cancel := context.WithCancel(ctx) + + t := &task{ + id: taskID, + agentName: params.Agent, + taskDesc: params.Task, + cancel: cancel, + startTime: time.Now(), + } + t.storeStatus(taskRunning) + h.tasks.Store(taskID, t) + + h.wg.Add(1) + go func() { + defer h.wg.Done() + defer cancel() + + slog.Debug("Starting background agent task", "task_id", taskID, "agent", params.Agent) + + result := h.runner.RunAgent(taskCtx, RunParams{ + AgentName: params.Agent, + Task: params.Task, + ExpectedOutput: params.ExpectedOutput, + ParentSession: sess, + OnContent: func(content string) { + t.outputMu.Lock() + if t.outputBytes < maxOutputBytes { + n, _ := t.output.WriteString(content) + t.outputBytes += n + } + t.outputMu.Unlock() + }, + }) + + if result.ErrMsg != "" { + t.errMsg = result.ErrMsg + t.storeStatus(taskFailed) + slog.Debug("Background agent task failed", "task_id", taskID, "agent", params.Agent, "error", result.ErrMsg) + return + } + + if taskCtx.Err() != nil && t.loadStatus() == taskRunning { + t.storeStatus(taskStopped) + slog.Debug("Background agent task stopped", "task_id", taskID) + return + } + + // Write result before CAS so readers who observe taskCompleted + // always see the populated result field. 
+ t.result = result.Result + if t.casStatus(taskRunning, taskCompleted) { + slog.Debug("Background agent task completed", "task_id", taskID, "agent", params.Agent) + } + }() + + return tools.ResultSuccess(fmt.Sprintf("Background agent task started with ID: %s\nAgent: %s\nTask: %s", + taskID, params.Agent, params.Task)), nil +} + +// HandleList lists all background agent tasks. +func (h *Handler) HandleList(_ context.Context, _ *session.Session, _ tools.ToolCall) (*tools.ToolCallResult, error) { + var out strings.Builder + out.WriteString("Background Agent Tasks:\n\n") + + var count int + h.tasks.Range(func(_ string, t *task) bool { + count++ + status := t.loadStatus() + elapsed := time.Since(t.startTime).Round(time.Second) + fmt.Fprintf(&out, "ID: %s\n", t.id) + fmt.Fprintf(&out, " Agent: %s\n", t.agentName) + fmt.Fprintf(&out, " Status: %s\n", statusToString(status)) + fmt.Fprintf(&out, " Runtime: %s\n", elapsed) + out.WriteString("\n") + return true + }) + + if count == 0 { + out.WriteString("No background agent tasks found.\n") + } + + return tools.ResultSuccess(out.String()), nil +} + +// HandleView returns the output and status of a specific background agent task. 
+func (h *Handler) HandleView(_ context.Context, _ *session.Session, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { + var params ViewBackgroundAgentArgs + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) + } + + t, exists := h.tasks.Load(params.TaskID) + if !exists { + return tools.ResultError(fmt.Sprintf("task not found: %s", params.TaskID)), nil + } + + status := t.loadStatus() + elapsed := time.Since(t.startTime).Round(time.Second) + + var out strings.Builder + fmt.Fprintf(&out, "Task ID: %s\n", t.id) + fmt.Fprintf(&out, "Agent: %s\n", t.agentName) + fmt.Fprintf(&out, "Status: %s\n", statusToString(status)) + fmt.Fprintf(&out, "Runtime: %s\n", elapsed) + out.WriteString("\n--- Output ---\n") + + switch status { + case taskCompleted: + if t.result != "" { + out.WriteString(t.result) + } else { + out.WriteString("<no output>") + } + case taskFailed: + out.WriteString("<task failed>") + if t.errMsg != "" { + fmt.Fprintf(&out, "\nError: %s", t.errMsg) + } + case taskStopped: + out.WriteString("<task was stopped>") + default: + t.outputMu.RLock() + progress := t.output.String() + truncated := t.outputBytes >= maxOutputBytes + t.outputMu.RUnlock() + if progress != "" { + out.WriteString(progress) + if truncated { + out.WriteString("\n\n[output truncated at 10MB limit — still running...]") + } else { + out.WriteString("\n\n[still running...]") + } + } else { + out.WriteString("<no output yet>") + } + } + + return tools.ResultSuccess(out.String()), nil +} + +// HandleStop cancels a running background agent task. 
+func (h *Handler) HandleStop(_ context.Context, _ *session.Session, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { + var params StopBackgroundAgentArgs + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) + } + + t, exists := h.tasks.Load(params.TaskID) + if !exists { + return tools.ResultError(fmt.Sprintf("task not found: %s", params.TaskID)), nil + } + + if !t.casStatus(taskRunning, taskStopped) { + current := t.loadStatus() + return tools.ResultError(fmt.Sprintf("task %s is not running (status: %s)", params.TaskID, statusToString(current))), nil + } + + t.cancel() + + return tools.ResultSuccess(fmt.Sprintf("Background agent task %s stopped.", params.TaskID)), nil +} + +// StopAll cancels all running tasks and waits for their goroutines to exit. +// Called during runtime shutdown to ensure clean teardown. +func (h *Handler) StopAll() { + h.tasks.Range(func(_ string, t *task) bool { + if t.casStatus(taskRunning, taskStopped) { + t.cancel() + } + return true + }) + h.wg.Wait() +} + +// toolSet is a lightweight ToolSet that returns just the tool definitions +// without requiring a Runner. Used by teamloader to register tool schemas. +type toolSet struct{} + +// NewToolSet returns a ToolSet for registering background agent tool definitions. +// This does not require a Runner and is suitable for use in teamloader. +func NewToolSet() tools.ToolSet { + return &toolSet{} +} + +func (t *toolSet) Tools(ctx context.Context) ([]tools.Tool, error) { + return backgroundAgentTools() +} + +// Tools returns the four background agent tool definitions. 
+func (h *Handler) Tools(ctx context.Context) ([]tools.Tool, error) { + return backgroundAgentTools() +} + +func backgroundAgentTools() ([]tools.Tool, error) { + return []tools.Tool{ + { + Name: ToolNameRunBackgroundAgent, + Category: "transfer", + Description: `Start a sub-agent task in the background and return immediately with a task ID. +Use this to dispatch work to multiple sub-agents concurrently. The sub-agent runs with all tools +pre-approved — use only with trusted sub-agents and well-scoped tasks. Check progress with +view_background_agent and collect results once the task is complete.`, + Parameters: tools.MustSchemaFor[RunBackgroundAgentArgs](), + Annotations: tools.ToolAnnotations{Title: "Run Background Agent"}, + }, + { + Name: ToolNameListBackgroundAgents, + Category: "transfer", + Description: `List all background agent tasks with their status and runtime.`, + Annotations: tools.ToolAnnotations{ + Title: "List Background Agents", + ReadOnlyHint: true, + }, + }, + { + Name: ToolNameViewBackgroundAgent, + Category: "transfer", + Description: `View the output and status of a specific background agent task by task ID. 
Returns live buffered output if still running, or the final result if complete.`, + Parameters: tools.MustSchemaFor[ViewBackgroundAgentArgs](), + Annotations: tools.ToolAnnotations{ + Title: "View Background Agent", + ReadOnlyHint: true, + }, + }, + { + Name: ToolNameStopBackgroundAgent, + Category: "transfer", + Description: `Stop a running background agent task by task ID.`, + Parameters: tools.MustSchemaFor[StopBackgroundAgentArgs](), + Annotations: tools.ToolAnnotations{ + Title: "Stop Background Agent", + }, + }, + }, nil +} diff --git a/pkg/tools/builtin/agent/agent_test.go b/pkg/tools/builtin/agent/agent_test.go new file mode 100644 index 000000000..53b089d6a --- /dev/null +++ b/pkg/tools/builtin/agent/agent_test.go @@ -0,0 +1,559 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/concurrent" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/tools" +) + +// mockRunner implements Runner for testing. +type mockRunner struct { + subAgentNames []string + runResult *RunResult + runDelay time.Duration // optional delay to simulate work +} + +func (m *mockRunner) CurrentAgentSubAgentNames() []string { return m.subAgentNames } +func (m *mockRunner) RunAgent(ctx context.Context, params RunParams) *RunResult { + if m.runDelay > 0 { + select { + case <-time.After(m.runDelay): + case <-ctx.Done(): + return &RunResult{} + } + } + // Call OnContent if result has content, to simulate streaming. 
+ if m.runResult != nil && m.runResult.Result != "" && params.OnContent != nil { + params.OnContent(m.runResult.Result) + } + if m.runResult != nil { + return m.runResult + } + return &RunResult{} +} + +func newTestHandler() *Handler { + return &Handler{ + tasks: concurrent.NewMap[string, *task](), + } +} + +func newTestHandlerWithRunner(r Runner) *Handler { + return NewHandler(r) +} + +func insertTask(h *Handler, id, agentName string, status taskStatus) *task { + t := &task{ + id: id, + agentName: agentName, + taskDesc: "test task", + cancel: func() {}, + startTime: time.Now(), + } + t.status.Store(int32(status)) + h.tasks.Store(id, t) + return t +} + +func makeToolCall(t *testing.T, args any) tools.ToolCall { + t.Helper() + b, err := json.Marshal(args) + require.NoError(t, err) + return tools.ToolCall{Function: tools.FunctionCall{Arguments: string(b)}} +} + +// --- newTaskID --- + +func TestNewTaskID_IsUnique(t *testing.T) { + ids := make(map[string]struct{}) + for range 100 { + id := newTaskID() + assert.NotEmpty(t, id) + _, dup := ids[id] + assert.False(t, dup, "duplicate task ID: %s", id) + ids[id] = struct{}{} + } +} + +func TestNewTaskID_HasPrefix(t *testing.T) { + id := newTaskID() + assert.True(t, strings.HasPrefix(id, "agent_task_"), "ID should start with agent_task_ prefix, got: %s", id) +} + +// --- statusToString --- + +func TestStatusToString(t *testing.T) { + cases := []struct { + status taskStatus + expected string + }{ + {taskRunning, "running"}, + {taskCompleted, "completed"}, + {taskStopped, "stopped"}, + {taskFailed, "failed"}, + {99, "unknown"}, + } + for _, tc := range cases { + assert.Equal(t, tc.expected, statusToString(tc.status)) + } +} + +// --- runningTaskCount / totalTaskCount --- + +func TestTaskCounts(t *testing.T) { + h := newTestHandler() + assert.Equal(t, 0, h.runningTaskCount()) + assert.Equal(t, 0, h.totalTaskCount()) + + insertTask(h, "t1", "a", taskRunning) + insertTask(h, "t2", "b", taskRunning) + insertTask(h, "t3", "c", 
taskCompleted) + insertTask(h, "t4", "d", taskFailed) + + assert.Equal(t, 2, h.runningTaskCount()) + assert.Equal(t, 4, h.totalTaskCount()) +} + +// --- pruneCompleted --- + +func TestPruneCompleted(t *testing.T) { + h := newTestHandler() + insertTask(h, "run1", "a", taskRunning) + insertTask(h, "done1", "b", taskCompleted) + insertTask(h, "done2", "c", taskStopped) + insertTask(h, "fail1", "d", taskFailed) + + h.pruneCompleted() + + assert.Equal(t, 1, h.totalTaskCount()) + _, exists := h.tasks.Load("run1") + assert.True(t, exists, "running task should be kept") + _, exists = h.tasks.Load("done1") + assert.False(t, exists, "completed task should be pruned") +} + +// --- HandleList --- + +func TestHandleList_Empty(t *testing.T) { + h := newTestHandler() + result, err := h.HandleList(t.Context(), nil, tools.ToolCall{}) + require.NoError(t, err) + assert.Contains(t, result.Output, "No background agent tasks found") +} + +func TestHandleList_ShowsTasks(t *testing.T) { + h := newTestHandler() + insertTask(h, "t1", "researcher", taskRunning) + insertTask(h, "t2", "writer", taskCompleted) + + result, err := h.HandleList(t.Context(), nil, tools.ToolCall{}) + require.NoError(t, err) + assert.Contains(t, result.Output, "researcher") + assert.Contains(t, result.Output, "writer") + assert.Contains(t, result.Output, "running") + assert.Contains(t, result.Output, "completed") +} + +// --- HandleView --- + +func TestHandleView_NotFound(t *testing.T) { + h := newTestHandler() + tc := makeToolCall(t, ViewBackgroundAgentArgs{TaskID: "nonexistent"}) + result, err := h.HandleView(t.Context(), nil, tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "task not found") +} + +func TestHandleView_Completed(t *testing.T) { + h := newTestHandler() + tk := insertTask(h, "t1", "researcher", taskCompleted) + tk.result = "Here is my research." 
+ + tc := makeToolCall(t, ViewBackgroundAgentArgs{TaskID: "t1"}) + result, err := h.HandleView(t.Context(), nil, tc) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Contains(t, result.Output, "Here is my research.") + assert.Contains(t, result.Output, "completed") +} + +func TestHandleView_Failed(t *testing.T) { + h := newTestHandler() + tk := insertTask(h, "t1", "researcher", taskFailed) + tk.errMsg = "model unavailable" + + tc := makeToolCall(t, ViewBackgroundAgentArgs{TaskID: "t1"}) + result, err := h.HandleView(t.Context(), nil, tc) + require.NoError(t, err) + assert.Contains(t, result.Output, "task failed") + assert.Contains(t, result.Output, "model unavailable") +} + +func TestHandleView_Running_NoOutputYet(t *testing.T) { + h := newTestHandler() + insertTask(h, "t1", "researcher", taskRunning) + + tc := makeToolCall(t, ViewBackgroundAgentArgs{TaskID: "t1"}) + result, err := h.HandleView(t.Context(), nil, tc) + require.NoError(t, err) + assert.Contains(t, result.Output, "no output yet") +} + +func TestHandleView_Running_WithProgress(t *testing.T) { + h := newTestHandler() + tk := insertTask(h, "t1", "researcher", taskRunning) + tk.output.WriteString("Partial research so far...") + + tc := makeToolCall(t, ViewBackgroundAgentArgs{TaskID: "t1"}) + result, err := h.HandleView(t.Context(), nil, tc) + require.NoError(t, err) + assert.Contains(t, result.Output, "Partial research so far...") + assert.Contains(t, result.Output, "still running") +} + +func TestHandleView_Stopped(t *testing.T) { + h := newTestHandler() + insertTask(h, "t1", "researcher", taskStopped) + + tc := makeToolCall(t, ViewBackgroundAgentArgs{TaskID: "t1"}) + result, err := h.HandleView(t.Context(), nil, tc) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Contains(t, result.Output, "stopped") + assert.Contains(t, result.Output, "task was stopped") +} + +func TestHandleView_Completed_EmptyResult(t *testing.T) { + h := newTestHandler() + insertTask(h, "t1", 
"researcher", taskCompleted) + + tc := makeToolCall(t, ViewBackgroundAgentArgs{TaskID: "t1"}) + result, err := h.HandleView(t.Context(), nil, tc) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Contains(t, result.Output, "no output") +} + +func TestHandleView_OutputBufferTruncated(t *testing.T) { + h := newTestHandler() + tk := insertTask(h, "t1", "researcher", taskRunning) + tk.output.WriteString(strings.Repeat("x", maxOutputBytes)) + tk.outputBytes = maxOutputBytes + + tc := makeToolCall(t, ViewBackgroundAgentArgs{TaskID: "t1"}) + result, err := h.HandleView(t.Context(), nil, tc) + require.NoError(t, err) + assert.Contains(t, result.Output, "truncated", "should show truncation notice when buffer is full") + assert.Contains(t, result.Output, "still running") +} + +func TestHandleView_InvalidJSON(t *testing.T) { + h := newTestHandler() + bad := tools.ToolCall{Function: tools.FunctionCall{Arguments: "not-json"}} + _, err := h.HandleView(t.Context(), nil, bad) + require.Error(t, err, "invalid JSON should return an error") +} + +// --- HandleStop --- + +func TestHandleStop_NotFound(t *testing.T) { + h := newTestHandler() + tc := makeToolCall(t, StopBackgroundAgentArgs{TaskID: "ghost"}) + result, err := h.HandleStop(t.Context(), nil, tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "task not found") +} + +func TestHandleStop_AlreadyCompleted(t *testing.T) { + h := newTestHandler() + insertTask(h, "t1", "researcher", taskCompleted) + + tc := makeToolCall(t, StopBackgroundAgentArgs{TaskID: "t1"}) + result, err := h.HandleStop(t.Context(), nil, tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "not running") +} + +func TestHandleStop_Running(t *testing.T) { + h := newTestHandler() + cancelled := false + tk := insertTask(h, "t1", "researcher", taskRunning) + tk.cancel = func() { cancelled = true } + + tc := makeToolCall(t, StopBackgroundAgentArgs{TaskID: 
"t1"}) + result, err := h.HandleStop(t.Context(), nil, tc) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.True(t, cancelled) + assert.Equal(t, taskStopped, tk.loadStatus()) +} + +func TestHandleStop_InvalidJSON(t *testing.T) { + h := newTestHandler() + bad := tools.ToolCall{Function: tools.FunctionCall{Arguments: "not-json"}} + _, err := h.HandleStop(t.Context(), nil, bad) + require.Error(t, err, "invalid JSON should return an error") +} + +// --- StopAll waits for goroutines --- + +func TestStopAll_WaitsForGoroutines(t *testing.T) { + h := newTestHandler() + + var goroutineExited atomic.Bool + tk := insertTask(h, "t1", "researcher", taskRunning) + ctx, cancel := context.WithCancel(t.Context()) + tk.cancel = cancel + + h.wg.Add(1) + go func() { + defer h.wg.Done() + <-ctx.Done() + time.Sleep(10 * time.Millisecond) // simulate teardown work + goroutineExited.Store(true) + }() + + h.StopAll() + assert.True(t, goroutineExited.Load(), "StopAll should wait for goroutine to exit") +} + +// --- HandleRun: input validation --- + +func TestHandleRun_EmptyAgent(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{subAgentNames: []string{"sub"}}) + tc := makeToolCall(t, RunBackgroundAgentArgs{Agent: "", Task: "do something"}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "agent name must not be empty") +} + +func TestHandleRun_EmptyTask(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{subAgentNames: []string{"sub"}}) + tc := makeToolCall(t, RunBackgroundAgentArgs{Agent: "sub", Task: ""}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "task must not be empty") +} + +func TestHandleRun_InvalidSubAgent(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{subAgentNames: []string{"sub"}}) + tc := makeToolCall(t, 
RunBackgroundAgentArgs{Agent: "nonexistent", Task: "do something"}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "not in the sub-agents list") +} + +func TestHandleRun_NoSubAgents(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{subAgentNames: nil}) + tc := makeToolCall(t, RunBackgroundAgentArgs{Agent: "some-agent", Task: "do something"}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "no sub-agents configured") +} + +func TestHandleRun_ConcurrencyCapEnforced(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{subAgentNames: []string{"sub"}}) + + for i := range maxConcurrentTasks { + insertTask(h, "fake"+string(rune('a'+i)), "sub", taskRunning) + } + + tc := makeToolCall(t, RunBackgroundAgentArgs{Agent: "sub", Task: "do something"}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "maximum concurrent") +} + +func TestHandleRun_InvalidJSON(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{subAgentNames: []string{"sub"}}) + bad := tools.ToolCall{Function: tools.FunctionCall{Arguments: "not-json"}} + _, err := h.HandleRun(t.Context(), session.New(), bad) + require.Error(t, err, "invalid JSON should return an error") +} + +func TestHandleRun_StartsTask(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{ + subAgentNames: []string{"sub"}, + runResult: &RunResult{Result: "done"}, + }) + + tc := makeToolCall(t, RunBackgroundAgentArgs{Agent: "sub", Task: "write a poem"}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Contains(t, result.Output, "agent_task_") + assert.Contains(t, result.Output, "sub") + + h.wg.Wait() + + 
assert.Equal(t, 1, h.totalTaskCount()) + h.tasks.Range(func(_ string, tk *task) bool { + assert.Equal(t, taskCompleted, tk.loadStatus()) + return true + }) +} + +func TestHandleRun_ProviderError_TaskFails(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{ + subAgentNames: []string{"sub"}, + runResult: &RunResult{ErrMsg: "model unavailable"}, + }) + + tc := makeToolCall(t, RunBackgroundAgentArgs{Agent: "sub", Task: "do something"}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.False(t, result.IsError, "HandleRun should start successfully before provider error") + + h.wg.Wait() + + h.tasks.Range(func(_ string, tk *task) bool { + assert.Equal(t, taskFailed, tk.loadStatus(), "task should be marked failed on provider error") + assert.NotEmpty(t, tk.errMsg) + return true + }) +} + +func TestHandleRun_WithExpectedOutput(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{ + subAgentNames: []string{"sub"}, + runResult: &RunResult{Result: "result"}, + }) + + tc := makeToolCall(t, RunBackgroundAgentArgs{ + Agent: "sub", + Task: "summarize the document", + ExpectedOutput: "A one-paragraph summary", + }) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.False(t, result.IsError) + + h.wg.Wait() + + h.tasks.Range(func(_ string, tk *task) bool { + assert.Equal(t, taskCompleted, tk.loadStatus()) + return true + }) +} + +func TestHandleRun_TotalCapAutoPruneAdmits(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{ + subAgentNames: []string{"sub"}, + runResult: &RunResult{Result: "done"}, + }) + + for i := range maxTotalTasks { + insertTask(h, fmt.Sprintf("done%d", i), "sub", taskCompleted) + } + assert.Equal(t, maxTotalTasks, h.totalTaskCount()) + + tc := makeToolCall(t, RunBackgroundAgentArgs{Agent: "sub", Task: "do something"}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.False(t, result.IsError, "task should 
be admitted after auto-prune of completed tasks") + + h.wg.Wait() +} + +func TestHandleRun_TotalCapExhaustion_ConcurrencyCapFiresFirst(t *testing.T) { + h := newTestHandlerWithRunner(&mockRunner{subAgentNames: []string{"sub"}}) + + for i := range maxConcurrentTasks { + insertTask(h, fmt.Sprintf("run%d", i), "sub", taskRunning) + } + + tc := makeToolCall(t, RunBackgroundAgentArgs{Agent: "sub", Task: "do something"}) + result, err := h.HandleRun(t.Context(), session.New(), tc) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "maximum concurrent", + "concurrency cap should fire before total cap can be exhausted non-prunably") +} + +// --- Concurrent handler access (run with -race) --- + +func TestHandler_ConcurrentAccess(t *testing.T) { + h := newTestHandler() + + for i := range 10 { + tk := insertTask(h, fmt.Sprintf("task%d", i), "researcher", taskRunning) + tk.output.WriteString("some progress output") + tk.outputBytes = len("some progress output") + } + + viewTCs := make([]tools.ToolCall, 5) + for i := range 5 { + viewTCs[i] = makeToolCall(t, ViewBackgroundAgentArgs{TaskID: fmt.Sprintf("task%d", i%10)}) + } + stopTCs := make([]tools.ToolCall, 3) + for i := range 3 { + stopTCs[i] = makeToolCall(t, StopBackgroundAgentArgs{TaskID: fmt.Sprintf("task%d", i)}) + } + + var wg sync.WaitGroup + + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = h.HandleList(t.Context(), nil, tools.ToolCall{}) + }() + } + + for i := range 5 { + wg.Add(1) + go func(tc tools.ToolCall) { + defer wg.Done() + _, _ = h.HandleView(t.Context(), nil, tc) + }(viewTCs[i]) + } + + for i := range 3 { + wg.Add(1) + go func(tc tools.ToolCall) { + defer wg.Done() + _, _ = h.HandleStop(t.Context(), nil, tc) + }(stopTCs[i]) + } + + wg.Wait() + assert.LessOrEqual(t, h.runningTaskCount(), 10) +} + +// --- Tools --- + +func TestTools_ReturnsFourTools(t *testing.T) { + h := NewHandler(&mockRunner{}) + toolsList, err := h.Tools(t.Context()) + 
require.NoError(t, err) + assert.Len(t, toolsList, 4) + + names := make([]string, len(toolsList)) + for i, tl := range toolsList { + names[i] = tl.Name + } + assert.Contains(t, names, ToolNameRunBackgroundAgent) + assert.Contains(t, names, ToolNameListBackgroundAgents) + assert.Contains(t, names, ToolNameViewBackgroundAgent) + assert.Contains(t, names, ToolNameStopBackgroundAgent) +}