diff --git a/.agentguard/live-viewer.json b/.agentguard/live-viewer.json new file mode 100644 index 0000000..33910c0 --- /dev/null +++ b/.agentguard/live-viewer.json @@ -0,0 +1 @@ +{"port":36841,"pid":853811,"startedAt":1774998051503} \ No newline at end of file diff --git a/.agentguard/squads/shellforge/blockers.md b/.agentguard/squads/shellforge/blockers.md index 315e085..6483c6b 100644 --- a/.agentguard/squads/shellforge/blockers.md +++ b/.agentguard/squads/shellforge/blockers.md @@ -1,7 +1,7 @@ # ShellForge Squad — Blockers -**Updated:** 2026-03-31T00:00Z -**Reported by:** EM run 9 (claude-code:opus:shellforge:em) +**Updated:** 2026-03-31T08:30Z +**Reported by:** EM run 11 (claude-code:opus:shellforge:em) --- @@ -13,35 +13,24 @@ ## P1 — Active Work -**None.** All P1 issues closed (PR #89 merged — closes #68 + #66). +**None.** All P1 issues closed. --- -## Incident (Resolved) - -### Broken worktree — incomplete WIP fix for #51 -**Detected:** Run 9 (2026-03-31) -**Resolved:** Yes -**Description:** The worktree had uncommitted partial changes to `cmd/shellforge/main.go`: -- `import (` was replaced with `import "log"`, breaking the multi-package import block syntax -- `run()` was partially refactored to call a non-existent `executeCommand()` function, leaving the old body orphaned outside any function -- Build failure: `syntax error: non-declaration statement outside function body` +## P2 — Active Blockers -**Resolution:** Stashed the WIP changes, created `fix/run-silent-errors-51` branch from `origin/main`, implemented the fix correctly (add `"log"` to imports, log error in `run()` via `if err := cmd.Run(); err != nil`). PR #93 open. 
+### PR Review Queue (budget: 3/3) ---- +| PR | Title | CI | Status | +|----|-------|----|--------| +| #93 | fix run() silent errors (closes #51) | ✅ 5/5 | REVIEW REQUIRED | +| #95 | fix scheduler WriteFile silent error (closes #65) | ✅ 5/5 | REVIEW REQUIRED | +| #96 | fix cmdScan Glob→WalkDir (closes #52) | ⏳ pending | REVIEW REQUIRED | -## P2 — Active Blockers +**Action Required:** @jpleva91 review and merge PRs #93, #95, #96 to clear budget for remaining P2 sweep. -### PR Review Queue (budget: 2/3) -| PR | Title | Status | -|----|-------|--------| -| #91 | EM state update run 8 | CI green — REVIEW REQUIRED | -| #93 | fix run() silent errors (closes #51) | CI pending — REVIEW REQUIRED | +### #76 — Dogfood: setup.sh doesn't support remote Ollama (4th escalation) -**Action Required:** @jpleva91 review and merge PR #91 and PR #93. - -### #76 — Dogfood: setup.sh doesn't support remote Ollama (3rd escalation) **Severity:** Medium — dogfood on jared-box (headless WSL2 + RunPod GPU) blocked **Root cause:** `shellforge setup` detects `isServer=true` on headless Linux and skips Goose + Ollama entirely, with no option to configure `OLLAMA_HOST` for a remote GPU endpoint. **Fix needed:** setup.sh should offer remote Ollama config when `isServer=true` — set `OLLAMA_HOST`, skip local Ollama install, keep Goose setup. 
@@ -49,13 +38,11 @@ --- -## P2 — Queued (unassigned) +## P2 — Queued (unassigned, after budget clears) | # | Issue | Notes | |---|-------|-------| | #92 | Bundle Preflight in Goose bootstrap | Blocked on Preflight v1 ship | -| #65 | scheduler.go silent os.WriteFile error | Next EM fix after PR budget clears | -| #52 | filepath.Glob ** never matches Go files | Next EM fix — needs filepath.Walk | | #53 | README stale ./shellforge commands | Docs rot | | #50 | kernel version comparison lexicographic | setup.sh version gate broken | | #49 | InferenceQueue not priority-aware | Documented but unimplemented | @@ -65,14 +52,17 @@ --- -## Resolved (this cycle) +## Resolved (this cycle — run 11) -- **#68** — zero test coverage → merged PR #89 (25 tests for normalizer/governance/intent) -- **#66** — dead code in flattenParams() → fixed in PR #89 -- **#51** — run() helper silently ignores errors → PR #93 open +- **#52** — filepath.Glob ** never matches Go files → fixed with WalkDir in PR #96 +- **PR #94** — stale EM state PR (run 9 state was already on master at 832cb58) → closed ## Resolved (prior cycles) +- **#65** — scheduler.go silent WriteFile error → PR #95 open +- **#51** — run() helper silently ignores errors → PR #93 open +- **#68** — zero test coverage → merged PR #89 (25 tests) +- **#66** — dead code in flattenParams() → merged PR #89 - **#28** → PR #86 merged - **#63** → PR #88 merged - **#58, #62, #75, #67, #69** → PR #83 merged @@ -86,10 +76,11 @@ |------|--------| | P0 issues | ✅ All closed | | P1 issues | ✅ All closed | -| PR #91 (EM state run 8) | 🟡 CI green — REVIEW REQUIRED | -| PR #93 (fix #51) | 🟡 CI pending — REVIEW REQUIRED | -| Sprint goal | 🔵 Active — P2 sweep in progress | -| PR budget | 2/3 | -| Dogfood (#76) | 🔴 Blocked — setup.sh remote Ollama gap (3rd escalation) | +| PR #93 (fix #51 run() errors) | 🟡 CI green — REVIEW REQUIRED | +| PR #95 (fix #65 WriteFile) | 🟡 CI green — REVIEW REQUIRED | +| PR #96 (fix #52 Glob→WalkDir) | 🟡 CI pending — REVIEW 
REQUIRED | +| Sprint goal | 🔵 Active — P2 sweep 3/3 bugs fixed, all in PRs | +| PR budget | 3/3 (full — merge needed before new work) | +| Dogfood (#76) | 🔴 Blocked — setup.sh remote Ollama gap (4th escalation) | | Retry loops | None | | Blast radius | Low | diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 0000000..b5e2b05 --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,37 @@ +{ + "hooks": { + "PreToolUse": [ + { + "hooks": [ + { + "type": "command", + "command": "bash -c 'W=${AGENTGUARD_WORKSPACE:-$HOME/agentguard-workspace}; BIN=${AGENTGUARD_BIN:-node $W/agent-guard/apps/cli/dist/bin.js}; $BIN claude-hook pre --store sqlite'" + } + ] + } + ], + "PostToolUse": [ + { + "matcher": "Bash", + "hooks": [ + { + "type": "command", + "command": "bash -c 'W=${AGENTGUARD_WORKSPACE:-$HOME/agentguard-workspace}; BIN=${AGENTGUARD_BIN:-node $W/agent-guard/apps/cli/dist/bin.js}; $BIN claude-hook post --store sqlite'" + } + ] + } + ], + "Stop": [ + { + "hooks": [ + { + "type": "command", + "command": "bash -c 'W=${AGENTGUARD_WORKSPACE:-$HOME/agentguard-workspace}; BIN=${AGENTGUARD_BIN:-node $W/agent-guard/apps/cli/dist/bin.js}; $BIN claude-hook stop --store sqlite'", + "timeout": 15000, + "blocking": false + } + ] + } + ] + } +} diff --git a/cmd/shellforge/main.go b/cmd/shellforge/main.go index 6403e68..b895c92 100644 --- a/cmd/shellforge/main.go +++ b/cmd/shellforge/main.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" +"io/fs" "log" "os" "os/exec" @@ -17,6 +18,7 @@ import ( "github.com/AgentGuardHQ/shellforge/internal/agent" "github.com/AgentGuardHQ/shellforge/internal/governance" +"github.com/AgentGuardHQ/shellforge/internal/llm" "github.com/AgentGuardHQ/shellforge/internal/logger" "github.com/AgentGuardHQ/shellforge/internal/ollama" "github.com/AgentGuardHQ/shellforge/internal/scheduler" @@ -60,11 +62,28 @@ cmdRun(driver, prompt) case "evaluate": cmdEvaluate() case "agent": -if len(os.Args) < 3 { -fmt.Fprintln(os.Stderr, "Usage: 
shellforge agent \"your prompt\"") +{ +providerName := "" +thinkingBudget := 0 +remaining := os.Args[2:] +filtered := remaining[:0] +for i := 0; i < len(remaining); i++ { +if remaining[i] == "--provider" && i+1 < len(remaining) { +providerName = remaining[i+1] +i++ +} else if remaining[i] == "--thinking-budget" && i+1 < len(remaining) { +fmt.Sscanf(remaining[i+1], "%d", &thinkingBudget) +i++ +} else { +filtered = append(filtered, remaining[i]) +} +} +if len(filtered) == 0 { +fmt.Fprintln(os.Stderr, "Usage: shellforge agent [--provider ] [--thinking-budget ] \"your prompt\"") os.Exit(1) } -cmdAgent(strings.Join(os.Args[2:], " ")) +cmdAgent(strings.Join(filtered, " "), providerName, thinkingBudget) +} case "swarm": cmdSwarm() case "serve": @@ -656,11 +675,35 @@ printResult("report-agent", result) saveReport("outputs/reports", "report", result) } -func cmdAgent(prompt string) { +func cmdAgent(prompt, providerName string, thinkingBudget int) { engine := mustGovernance() + +var provider llm.Provider +switch providerName { +case "anthropic": +apiKey := os.Getenv("ANTHROPIC_API_KEY") +if apiKey == "" { +fmt.Fprintln(os.Stderr, "Error: ANTHROPIC_API_KEY environment variable not set") +os.Exit(1) +} +model := os.Getenv("ANTHROPIC_MODEL") +if model == "" { +model = "claude-haiku-4-5-20251001" +} +p := llm.NewAnthropicProvider(apiKey, model) +if thinkingBudget > 0 { +p.ThinkingBudget = thinkingBudget +fmt.Fprintf(os.Stderr, "Using Anthropic API (model: %s, thinking budget: %d tokens)\n", model, thinkingBudget) +} else { +fmt.Fprintf(os.Stderr, "Using Anthropic API (model: %s)\n", model) +} +provider = p +default: +// Legacy Ollama path mustOllama() +} -result, err := agent.RunLoop(agent.LoopConfig{ +cfg := agent.LoopConfig{ Agent: "prototype-agent", System: "You are a senior engineer. Complete the requested task using available tools. Read files, write files, run commands, search code. 
Be precise.", UserPrompt: prompt, @@ -669,7 +712,10 @@ MaxTurns: 15, TimeoutMs: 180_000, OutputDir: "outputs/logs", TokenBudget: 3000, -}, engine) +Provider: provider, +} + +result, err := agent.RunLoop(cfg, engine) if err != nil { logger.Error("prototype-agent", err.Error()) os.Exit(1) @@ -864,7 +910,13 @@ if _, err := os.Stat("agentguard.yaml"); err == nil { fmt.Println(" ✓ agentguard.yaml found") } entries, _ := filepath.Glob(filepath.Join(dir, "agents", "*.ts")) -goEntries, _ := filepath.Glob(filepath.Join(dir, "internal", "**", "*.go")) +var goEntries []string +filepath.WalkDir(filepath.Join(dir, "internal"), func(p string, d fs.DirEntry, err error) error { +if err == nil && !d.IsDir() && strings.HasSuffix(p, ".go") { +goEntries = append(goEntries, p) +} +return nil +}) fmt.Printf(" Found %d TS agents, %d Go files\n", len(entries), len(goEntries)) fmt.Println(" Install defenseclaw for full supply chain scanning") } @@ -907,7 +959,7 @@ func printResult(name string, r *agent.RunResult) { fmt.Println() status := "✓ success" if !r.Success { -status = "✗ failed" +status = fmt.Sprintf("✗ %s", r.ExitReason) } fmt.Printf("[%s] %s — %d turns, %d tool calls, %d denials\n", name, status, r.Turns, r.ToolCalls, r.Denials) fmt.Printf(" tokens: %d prompt + %d response | %dms\n", r.PromptTok, r.ResponseTok, r.DurationMs) @@ -921,8 +973,8 @@ func saveReport(dir, prefix string, r *agent.RunResult) { os.MkdirAll(dir, 0o755) ts := time.Now().Format("2006-01-02T15-04-05") path := filepath.Join(dir, fmt.Sprintf("%s-%s.md", prefix, ts)) -content := fmt.Sprintf("# %s — %s\n\n**Turns:** %d | **Tool calls:** %d | **Denials:** %d\n**Tokens:** %d+%d | **Duration:** %dms\n\n%s\n", -prefix, time.Now().Format(time.RFC3339), r.Turns, r.ToolCalls, r.Denials, r.PromptTok, r.ResponseTok, r.DurationMs, r.Output) +content := fmt.Sprintf("# %s — %s\n\n**Exit:** %s | **Turns:** %d | **Tool calls:** %d | **Denials:** %d\n**Tokens:** %d+%d | **Duration:** %dms\n\n%s\n", +prefix, 
time.Now().Format(time.RFC3339), r.ExitReason, r.Turns, r.ToolCalls, r.Denials, r.PromptTok, r.ResponseTok, r.DurationMs, r.Output) os.WriteFile(path, []byte(content), 0o644) fmt.Printf("\n→ Saved to %s\n", path) } diff --git a/cmd/shellforge/shellforge b/cmd/shellforge/shellforge new file mode 100755 index 0000000..2bd65b1 Binary files /dev/null and b/cmd/shellforge/shellforge differ diff --git a/internal/agent/drift.go b/internal/agent/drift.go new file mode 100644 index 0000000..61a433c --- /dev/null +++ b/internal/agent/drift.go @@ -0,0 +1,107 @@ +package agent + +import ( + "fmt" + "strings" +) + +const ( + driftCheckInterval = 5 // check every N tool calls + driftWarnThreshold = 7 // score below this → inject steering + driftKillThreshold = 5 // score below this twice → kill +) + +// driftDetector tracks whether the agent is staying on-task. +type driftDetector struct { + taskSpec string // original user prompt (the task spec) + actionLog []string // recent tool calls for summarization + warnings int // how many times we've warned + lowScores int // consecutive scores below kill threshold +} + +func newDriftDetector(taskSpec string) *driftDetector { + return &driftDetector{taskSpec: taskSpec} +} + +// record logs a tool call for drift analysis. +func (d *driftDetector) record(toolName string, params map[string]string) { + summary := toolName + if target, ok := params["path"]; ok { + summary += " → " + target + } else if target, ok := params["command"]; ok { + summary += " → " + target + } else if target, ok := params["directory"]; ok { + summary += " → " + target + } + d.actionLog = append(d.actionLog, summary) +} + +// shouldCheck returns true every driftCheckInterval tool calls. +func (d *driftDetector) shouldCheck(totalToolCalls int) bool { + return totalToolCalls > 0 && totalToolCalls%driftCheckInterval == 0 +} + +// buildCheckPrompt creates the drift check message to send to the model. 
+func (d *driftDetector) buildCheckPrompt() string { + recent := d.actionLog + if len(recent) > driftCheckInterval { + recent = recent[len(recent)-driftCheckInterval:] + } + + return fmt.Sprintf(`DRIFT CHECK — Score your alignment with the original task. + +Original task: %s + +Your last %d actions: +%s + +Rate your alignment 1-10 (10 = perfectly on task, 1 = completely off topic). +Respond with ONLY a single number.`, d.taskSpec, len(recent), strings.Join(recent, "\n")) +} + +// parseScore extracts the drift score from the model's response. +func parseScore(content string) int { + content = strings.TrimSpace(content) + for _, c := range content { + if c >= '0' && c <= '9' { + return int(c - '0') + } + } + return 10 // default to "on task" if unparseable +} + +// evaluate processes the drift score and returns the action to take. +func (d *driftDetector) evaluate(score int) driftAction { + if score >= driftWarnThreshold { + d.lowScores = 0 + return driftOK + } + + if score < driftKillThreshold { + d.lowScores++ + if d.lowScores >= 2 { + return driftKill + } + } + + d.warnings++ + return driftWarn +} + +// steeringMessage returns the message to inject when drift is detected. +func (d *driftDetector) steeringMessage() string { + return fmt.Sprintf(`⚠️ DRIFT DETECTED — You are going off-task. + +Original task: %s + +Refocus on the original task. Do not continue with unrelated work. 
+Warning %d — task will be terminated if drift continues.`, d.taskSpec, d.warnings) +} + +type driftAction int + +const ( + driftOK driftAction = iota + driftWarn + driftKill +) diff --git a/internal/agent/drift_test.go b/internal/agent/drift_test.go new file mode 100644 index 0000000..765cc52 --- /dev/null +++ b/internal/agent/drift_test.go @@ -0,0 +1,102 @@ +package agent + +import "testing" + +func TestNewDriftDetector(t *testing.T) { + d := newDriftDetector("fix the auth bug") + if d.taskSpec != "fix the auth bug" { + t.Errorf("expected task spec, got %s", d.taskSpec) + } + if len(d.actionLog) != 0 { + t.Error("expected empty action log") + } +} + +func TestDriftDetector_Record(t *testing.T) { + d := newDriftDetector("task") + d.record("read_file", map[string]string{"path": "auth.go"}) + d.record("run_shell", map[string]string{"command": "go test"}) + if len(d.actionLog) != 2 { + t.Errorf("expected 2 actions, got %d", len(d.actionLog)) + } + if d.actionLog[0] != "read_file → auth.go" { + t.Errorf("expected 'read_file → auth.go', got %s", d.actionLog[0]) + } +} + +func TestDriftDetector_ShouldCheck(t *testing.T) { + d := newDriftDetector("task") + if d.shouldCheck(0) { + t.Error("should not check at 0") + } + if d.shouldCheck(3) { + t.Error("should not check at 3") + } + if !d.shouldCheck(5) { + t.Error("should check at 5") + } + if !d.shouldCheck(10) { + t.Error("should check at 10") + } +} + +func TestParseScore(t *testing.T) { + cases := []struct { + input string + want int + }{ + {"8", 8}, + {" 7 ", 7}, + {"Score: 9", 9}, + {"3/10", 3}, + {"no score here", 10}, // default + {"", 10}, + } + for _, c := range cases { + if got := parseScore(c.input); got != c.want { + t.Errorf("parseScore(%q) = %d, want %d", c.input, got, c.want) + } + } +} + +func TestDriftDetector_Evaluate_OK(t *testing.T) { + d := newDriftDetector("task") + if d.evaluate(8) != driftOK { + t.Error("score 8 should be OK") + } + if d.evaluate(10) != driftOK { + t.Error("score 10 should be OK") 
+ } +} + +func TestDriftDetector_Evaluate_Warn(t *testing.T) { + d := newDriftDetector("task") + if d.evaluate(6) != driftWarn { + t.Error("score 6 should warn") + } + if d.warnings != 1 { + t.Errorf("expected 1 warning, got %d", d.warnings) + } +} + +func TestDriftDetector_Evaluate_Kill(t *testing.T) { + d := newDriftDetector("task") + // First low score → warn + if d.evaluate(4) != driftWarn { + t.Error("first score 4 should warn") + } + // Second consecutive low score → kill + if d.evaluate(3) != driftKill { + t.Error("second low score should kill") + } +} + +func TestDriftDetector_Evaluate_ResetOnGoodScore(t *testing.T) { + d := newDriftDetector("task") + d.evaluate(4) // low → warn + d.evaluate(8) // good → reset + // Next low should warn again, not kill + if d.evaluate(3) != driftWarn { + t.Error("after reset, first low should warn not kill") + } +} diff --git a/internal/agent/loop.go b/internal/agent/loop.go index bba7973..bd0e74b 100644 --- a/internal/agent/loop.go +++ b/internal/agent/loop.go @@ -5,216 +5,441 @@ // bare JSON, OpenAI function_call format). Every extracted action goes // through AgentGuard governance — no exceptions. // +// When a Provider with native tool-use is set (e.g. Anthropic), the loop +// uses structured ToolCalls directly instead of text-based intent parsing. +// // This is the core of ShellForge's moat: you cannot trust the transport // layer for action integrity. The intent parser makes ShellForge // model-agnostic and format-agnostic. 
package agent import ( -"fmt" -"time" - -"github.com/AgentGuardHQ/shellforge/internal/action" -"github.com/AgentGuardHQ/shellforge/internal/correction" -"github.com/AgentGuardHQ/shellforge/internal/governance" -"github.com/AgentGuardHQ/shellforge/internal/intent" -"github.com/AgentGuardHQ/shellforge/internal/logger" -"github.com/AgentGuardHQ/shellforge/internal/normalizer" -"github.com/AgentGuardHQ/shellforge/internal/ollama" -"github.com/AgentGuardHQ/shellforge/internal/tools" + "fmt" + "time" + + "github.com/AgentGuardHQ/shellforge/internal/action" + "github.com/AgentGuardHQ/shellforge/internal/correction" + "github.com/AgentGuardHQ/shellforge/internal/governance" + "github.com/AgentGuardHQ/shellforge/internal/intent" + "github.com/AgentGuardHQ/shellforge/internal/llm" + "github.com/AgentGuardHQ/shellforge/internal/logger" + "github.com/AgentGuardHQ/shellforge/internal/normalizer" + "github.com/AgentGuardHQ/shellforge/internal/ollama" + "github.com/AgentGuardHQ/shellforge/internal/tools" ) type LoopConfig struct { -Agent string -System string -UserPrompt string -Model string -MaxTurns int -TimeoutMs int -OutputDir string -TokenBudget int + Agent string + System string + UserPrompt string + Model string + MaxTurns int + TimeoutMs int + OutputDir string + TokenBudget int + Provider llm.Provider // optional; nil falls back to legacy ollama.Chat() } type RunResult struct { -Success bool -Output string -Turns int -ToolCalls int -Denials int -PromptTok int -ResponseTok int -DurationMs int64 -Log []string + Success bool + ExitReason string // "final_answer", "timeout", "max_turns", "model_error" + Output string + Turns int + ToolCalls int + Denials int + PromptTok int + ResponseTok int + DurationMs int64 + Log []string } func RunLoop(cfg LoopConfig, engine *governance.Engine) (*RunResult, error) { -start := time.Now() -logger.Init(cfg.OutputDir, cfg.Agent) -defer logger.Close() - -// Orchestrator integration: generate run identity and correction engine. 
-runID := fmt.Sprintf("run_%d", time.Now().UnixMilli()) -var seq int -corrector := correction.NewEngine(3, 10) // 3 retries per action, 10 total budget - -systemPrompt := buildSystemPrompt(cfg.System) -messages := []ollama.ChatMessage{ -{Role: "system", Content: systemPrompt}, -{Role: "user", Content: cfg.UserPrompt}, -} - -result := &RunResult{} -var log []string - -logger.Agent(cfg.Agent, fmt.Sprintf("starting — max %d turns, model: %s, run: %s", cfg.MaxTurns, cfg.Model, runID)) - -for turn := 1; turn <= cfg.MaxTurns; turn++ { -elapsed := time.Since(start).Milliseconds() -if int(elapsed) > cfg.TimeoutMs { -logger.Agent(cfg.Agent, fmt.Sprintf("timeout after %d turns", turn-1)) -break -} - -// Compact if needed -compacted := compactMessages(messages, cfg.TokenBudget) - -tokEst := estimateTokens(compacted) -logger.Agent(cfg.Agent, fmt.Sprintf("turn %d/%d (~%d tokens)", turn, cfg.MaxTurns, tokEst)) - -resp, err := ollama.Chat(compacted, cfg.Model) -if err != nil { -logger.Error(cfg.Agent, "ollama: "+err.Error()) -result.Output = "Model error: " + err.Error() -break -} - -result.PromptTok += resp.PromptEval -result.ResponseTok += resp.EvalCount -logger.ModelCall(cfg.Agent, resp.PromptEval, resp.EvalCount, resp.TotalDuration/1_000_000) - -content := resp.Message.Content -messages = append(messages, ollama.ChatMessage{Role: "assistant", Content: content}) - -// ── Intent parser: extract action from ANY format ── -// This is the format-agnostic layer. Works regardless of whether the model -// emits structured tool_calls, JSON blocks, XML tags, or bare JSON. -parsed := intent.Parse(content) -if parsed == nil { -// No actionable intent — this is a final answer. 
-result.Output = content -result.Turns = turn -logger.Agent(cfg.Agent, fmt.Sprintf("done — %d turns, %d tool calls", turn, result.ToolCalls)) -break -} - -logger.Agent(cfg.Agent, fmt.Sprintf("intent: %s via %s", parsed.Tool, parsed.Source)) - -result.ToolCalls++ -seq++ - -// ── Normalizer: convert extracted intent to Canonical Action Representation ── -proposal := normalizer.Normalize(runID, seq, cfg.Agent, parsed.Tool, parsed.Params) -fp := normalizer.Fingerprint(proposal) - -// ── Correction engine: check if this action should be retried or skipped ── -canAttempt, skipReason := corrector.ShouldCorrect(fp) -if !canAttempt { -// Too many retries or in lockdown — skip this action entirely. -logger.Agent(cfg.Agent, fmt.Sprintf("action skipped: %s", skipReason)) -messages = append(messages, ollama.ChatMessage{ -Role: "user", -Content: fmt.Sprintf("Tool %q was skipped: %s. Try a different approach.", parsed.Tool, skipReason), -}) -summary := fmt.Sprintf("[turn %d] %s → skipped (%s)", turn, parsed.Tool, skipReason) -log = append(log, summary) -continue -} - -// ── Governance evaluation (existing) ── -decision := engine.Evaluate(parsed.Tool, parsed.Params) - -if !decision.Allowed { -result.Denials++ - -// Map governance.Decision to action.GovernanceDecision for the correction engine. -govDecision := action.GovernanceDecision{ -Allowed: false, -Decision: "deny", -Reason: decision.Reason, -Rule: decision.PolicyName, + start := time.Now() + logger.Init(cfg.OutputDir, cfg.Agent) + defer logger.Close() + + if cfg.Provider != nil { + return runProviderLoop(cfg, engine, start) + } + return runOllamaLoop(cfg, engine, start) } -// Record denial and attempt correction. -corrector.RecordDenial(fp, govDecision) -logger.Governance(cfg.Agent, parsed.Tool, parsed.Params, decision.Allowed, decision.PolicyName, decision.Reason) - -canCorrect, _ := corrector.ShouldCorrect(fp) -if canCorrect { -// Build corrective feedback and feed it back to the LLM. 
-feedback := corrector.BuildFeedback(proposal, govDecision) -logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — sending correction feedback (escalation: %s)", parsed.Tool, corrector.Level())) -messages = append(messages, ollama.ChatMessage{ -Role: "user", -Content: feedback, -}) -} else { -// Exhausted retries — skip and inform the LLM. -logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — no retries left, skipping", parsed.Tool)) -messages = append(messages, ollama.ChatMessage{ -Role: "user", -Content: fmt.Sprintf("Tool %q was denied and cannot be retried. Move on to a different approach.", parsed.Tool), -}) +// ---------------------------------------------------------------------------- +// Provider path: uses []llm.Message + native tool-use +// ---------------------------------------------------------------------------- + +func runProviderLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) (*RunResult, error) { + runID := fmt.Sprintf("run_%d", time.Now().UnixMilli()) + var seq int + corrector := correction.NewEngine(3, 10) + + systemPrompt := buildSystemPrompt(cfg.System) + messages := []llm.Message{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: cfg.UserPrompt}, + } + + toolDefs := buildToolDefs() + drift := newDriftDetector(cfg.UserPrompt) + + result := &RunResult{} + var log []string + + logger.Agent(cfg.Agent, fmt.Sprintf("starting — max %d turns, provider: %s, run: %s", cfg.MaxTurns, cfg.Provider.Name(), runID)) + + for turn := 1; turn <= cfg.MaxTurns; turn++ { + elapsed := time.Since(start).Milliseconds() + if int(elapsed) > cfg.TimeoutMs { + logger.Agent(cfg.Agent, fmt.Sprintf("timeout after %d turns", turn-1)) + result.Turns = turn - 1 + result.ExitReason = "timeout" + break + } + + compacted := compactLLMMessages(messages, cfg.TokenBudget) + + tokEst := estimateLLMTokens(compacted) + logger.Agent(cfg.Agent, fmt.Sprintf("turn %d/%d (~%d tokens)", turn, cfg.MaxTurns, tokEst)) + + provResp, perr := 
cfg.Provider.Chat(compacted, toolDefs) + if perr != nil { + logger.Error(cfg.Agent, cfg.Provider.Name()+": "+perr.Error()) + result.Output = "Model error: " + perr.Error() + result.Turns = turn + result.ExitReason = "model_error" + break + } + + result.PromptTok += provResp.PromptTok + result.ResponseTok += provResp.OutputTok + logger.ModelCall(cfg.Agent, provResp.PromptTok, provResp.OutputTok, 0) + if provResp.CacheHit > 0 { + logger.Agent(cfg.Agent, fmt.Sprintf("cache hit: %d tokens (saved ~%.1f%%)", + provResp.CacheHit, float64(provResp.CacheHit)/float64(provResp.PromptTok+provResp.CacheHit)*100)) + } + + // ── Native tool-use path ── + if len(provResp.ToolCalls) > 0 { + // Append the assistant message with tool calls so the API receives + // structured tool_use blocks on the next turn. + messages = append(messages, llm.Message{ + Role: "assistant", + Content: provResp.Content, + ToolCalls: provResp.ToolCalls, + }) + + for _, tc := range provResp.ToolCalls { + logger.Agent(cfg.Agent, fmt.Sprintf("tool_use: %s (id: %s)", tc.Name, tc.ID)) + + result.ToolCalls++ + seq++ + + // ── Normalizer: convert to Canonical Action Representation ── + proposal := normalizer.Normalize(runID, seq, cfg.Agent, tc.Name, tc.Params) + fp := normalizer.Fingerprint(proposal) + + // ── Correction engine: check retry budget ── + canAttempt, skipReason := corrector.ShouldCorrect(fp) + if !canAttempt { + logger.Agent(cfg.Agent, fmt.Sprintf("action skipped: %s", skipReason)) + messages = append(messages, llm.Message{ + Role: "tool_result", + Content: fmt.Sprintf("Tool %q was skipped: %s. 
Try a different approach.", tc.Name, skipReason), + ToolCallID: tc.ID, + }) + log = append(log, fmt.Sprintf("[turn %d] %s → skipped (%s)", turn, tc.Name, skipReason)) + continue + } + + // ── Governance evaluation ── + decision := engine.Evaluate(tc.Name, tc.Params) + + if !decision.Allowed { + result.Denials++ + + govDecision := action.GovernanceDecision{ + Allowed: false, + Decision: "deny", + Reason: decision.Reason, + Rule: decision.PolicyName, + } + + corrector.RecordDenial(fp, govDecision) + logger.Governance(cfg.Agent, tc.Name, tc.Params, decision.Allowed, decision.PolicyName, decision.Reason) + + canCorrect, _ := corrector.ShouldCorrect(fp) + var feedback string + if canCorrect { + feedback = corrector.BuildFeedback(proposal, govDecision) + logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — sending correction feedback (escalation: %s)", tc.Name, corrector.Level())) + } else { + feedback = fmt.Sprintf("Tool %q was denied and cannot be retried. Move on to a different approach.", tc.Name) + logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — no retries left, skipping", tc.Name)) + } + + messages = append(messages, llm.Message{ + Role: "tool_result", + Content: feedback, + ToolCallID: tc.ID, + }) + log = append(log, fmt.Sprintf("[turn %d] %s → denied (%s)", turn, tc.Name, decision.PolicyName)) + continue + } + + // ── Governance allowed: execute tool ── + logger.Governance(cfg.Agent, tc.Name, tc.Params, decision.Allowed, decision.PolicyName, decision.Reason) + toolResult := tools.ExecuteDirect(tc.Name, tc.Params, engine.GetTimeout()) + logger.ToolResult(cfg.Agent, tc.Name, toolResult.Success, toolResult.Output) + + var msg string + if toolResult.Success { + msg = fmt.Sprintf("Tool %q returned:\n%s", tc.Name, toolResult.Output) + } else { + msg = fmt.Sprintf("Tool %q failed: %s", tc.Name, toolResult.Output) + } + messages = append(messages, llm.Message{ + Role: "tool_result", + Content: msg, + ToolCallID: tc.ID, + }) + + log = append(log, 
fmt.Sprintf("[turn %d] %s → %s", turn, tc.Name, boolStr(toolResult.Success, "ok", "fail"))) + + // Record action for drift detection. + drift.record(tc.Name, tc.Params) + } + + // ── Drift detection: check every N tool calls ── + if drift.shouldCheck(result.ToolCalls) { + logger.Agent(cfg.Agent, "drift check...") + driftMsgs := []llm.Message{ + {Role: "user", Content: drift.buildCheckPrompt()}, + } + driftResp, derr := cfg.Provider.Chat(driftMsgs, nil) + if derr == nil { + score := parseScore(driftResp.Content) + result.PromptTok += driftResp.PromptTok + result.ResponseTok += driftResp.OutputTok + + action := drift.evaluate(score) + switch action { + case driftWarn: + logger.Agent(cfg.Agent, fmt.Sprintf("drift warning — score %d/10, warning #%d", score, drift.warnings)) + messages = append(messages, llm.Message{ + Role: "user", + Content: drift.steeringMessage(), + }) + case driftKill: + logger.Agent(cfg.Agent, fmt.Sprintf("drift kill — score %d/10, terminating", score)) + result.ExitReason = "drift" + result.Output = "Task terminated: agent drifted from original task spec." + result.Turns = turn + goto done + default: + logger.Agent(cfg.Agent, fmt.Sprintf("drift ok — score %d/10", score)) + } + } + } + + // Last turn: force final answer + if turn == cfg.MaxTurns { + messages = append(messages, llm.Message{ + Role: "user", + Content: "You've used all turns. 
Give your final answer now.", + }) + finalMsgs := compactLLMMessages(messages, cfg.TokenBudget) + finalResp, ferr := cfg.Provider.Chat(finalMsgs, toolDefs) + if ferr == nil { + result.Output = finalResp.Content + result.PromptTok += finalResp.PromptTok + result.ResponseTok += finalResp.OutputTok + } + result.Turns = turn + result.ExitReason = "max_turns" + } + continue + } + + // ── No tool calls: final answer (end_turn) ── + result.Output = provResp.Content + result.Turns = turn + result.ExitReason = "final_answer" + logger.Agent(cfg.Agent, fmt.Sprintf("done — %d turns, %d tool calls", turn, result.ToolCalls)) + break + } + +done: + result.DurationMs = time.Since(start).Milliseconds() + result.Success = result.ExitReason == "final_answer" + result.Log = log + return result, nil } -summary := fmt.Sprintf("[turn %d] %s → denied (%s)", turn, parsed.Tool, decision.PolicyName) -log = append(log, summary) -continue -} - -// ── Governance allowed: log and execute tool ── -logger.Governance(cfg.Agent, parsed.Tool, parsed.Params, decision.Allowed, decision.PolicyName, decision.Reason) -toolResult := tools.ExecuteDirect(parsed.Tool, parsed.Params, engine.GetTimeout()) -logger.ToolResult(cfg.Agent, parsed.Tool, toolResult.Success, toolResult.Output) - -var msg string -if toolResult.Success { -msg = fmt.Sprintf("Tool %q returned:\n%s", parsed.Tool, toolResult.Output) -} else { -msg = fmt.Sprintf("Tool %q failed: %s", parsed.Tool, toolResult.Output) -} -messages = append(messages, ollama.ChatMessage{Role: "user", Content: msg}) - -summary := fmt.Sprintf("[turn %d] %s → %s", turn, parsed.Tool, boolStr(toolResult.Success, "ok", "fail")) -log = append(log, summary) - -// Last turn: force final answer -if turn == cfg.MaxTurns { -messages = append(messages, ollama.ChatMessage{ -Role: "user", -Content: "You've used all turns. 
// ----------------------------------------------------------------------------
// Ollama/legacy path: uses []ollama.ChatMessage + intent.Parse()
// ----------------------------------------------------------------------------

// runOllamaLoop drives the legacy text-parsing agent loop against an Ollama
// model. Each turn it: compacts the conversation to the token budget, calls
// ollama.Chat, extracts a tool intent from the raw completion via
// intent.Parse, routes the action through the correction engine and the
// governance engine, executes allowed tools, and feeds results back as user
// messages. Exits on final answer (no parseable intent), timeout, model
// error, or max turns. Success is defined as ExitReason == "final_answer".
func runOllamaLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) (*RunResult, error) {
    // Per-run identifier and monotonically increasing action sequence used
    // to build normalized proposals / fingerprints for the corrector.
    runID := fmt.Sprintf("run_%d", time.Now().UnixMilli())
    var seq int
    // NOTE(review): assumed NewEngine(3, 10) means 3 correction attempts per
    // fingerprint with a cap of 10 overall — confirm against the correction pkg.
    corrector := correction.NewEngine(3, 10)

    systemPrompt := buildSystemPrompt(cfg.System)
    messages := []ollama.ChatMessage{
        {Role: "system", Content: systemPrompt},
        {Role: "user", Content: cfg.UserPrompt},
    }

    result := &RunResult{}
    var log []string

    logger.Agent(cfg.Agent, fmt.Sprintf("starting — max %d turns, model: %s, run: %s", cfg.MaxTurns, cfg.Model, runID))

    for turn := 1; turn <= cfg.MaxTurns; turn++ {
        // Wall-clock timeout check at the top of every turn.
        elapsed := time.Since(start).Milliseconds()
        if int(elapsed) > cfg.TimeoutMs {
            logger.Agent(cfg.Agent, fmt.Sprintf("timeout after %d turns", turn-1))
            result.Turns = turn - 1
            result.ExitReason = "timeout"
            break
        }

        // Trim history to the token budget before each model call.
        compacted := compactMessages(messages, cfg.TokenBudget)

        tokEst := estimateTokens(compacted)
        logger.Agent(cfg.Agent, fmt.Sprintf("turn %d/%d (~%d tokens)", turn, cfg.MaxTurns, tokEst))

        resp, err := ollama.Chat(compacted, cfg.Model)
        if err != nil {
            // Model-level failure ends the run; the error text becomes Output.
            logger.Error(cfg.Agent, "ollama: "+err.Error())
            result.Output = "Model error: " + err.Error()
            result.Turns = turn
            result.ExitReason = "model_error"
            break
        }

        content := resp.Message.Content
        promptTok := resp.PromptEval
        outputTok := resp.EvalCount
        totalDurMs := resp.TotalDuration / 1_000_000 // Ollama reports nanoseconds

        result.PromptTok += promptTok
        result.ResponseTok += outputTok
        logger.ModelCall(cfg.Agent, promptTok, outputTok, totalDurMs)

        messages = append(messages, ollama.ChatMessage{Role: "assistant", Content: content})

        // ── Intent parser: extract action from ANY format ──
        // A nil parse means the model produced no tool intent: treat the
        // whole completion as the final answer.
        parsed := intent.Parse(content)
        if parsed == nil {
            result.Output = content
            result.Turns = turn
            result.ExitReason = "final_answer"
            logger.Agent(cfg.Agent, fmt.Sprintf("done — %d turns, %d tool calls", turn, result.ToolCalls))
            break
        }

        logger.Agent(cfg.Agent, fmt.Sprintf("intent: %s via %s", parsed.Tool, parsed.Source))

        // Counted as a tool call even if later skipped or denied.
        result.ToolCalls++
        seq++

        proposal := normalizer.Normalize(runID, seq, cfg.Agent, parsed.Tool, parsed.Params)
        fp := normalizer.Fingerprint(proposal)

        // Corrector may refuse a repeat of an already-exhausted action.
        canAttempt, skipReason := corrector.ShouldCorrect(fp)
        if !canAttempt {
            logger.Agent(cfg.Agent, fmt.Sprintf("action skipped: %s", skipReason))
            messages = append(messages, ollama.ChatMessage{
                Role:    "user",
                Content: fmt.Sprintf("Tool %q was skipped: %s. Try a different approach.", parsed.Tool, skipReason),
            })
            summary := fmt.Sprintf("[turn %d] %s → skipped (%s)", turn, parsed.Tool, skipReason)
            log = append(log, summary)
            continue
        }

        decision := engine.Evaluate(parsed.Tool, parsed.Params)

        if !decision.Allowed {
            result.Denials++

            govDecision := action.GovernanceDecision{
                Allowed:  false,
                Decision: "deny",
                Reason:   decision.Reason,
                Rule:     decision.PolicyName,
            }

            corrector.RecordDenial(fp, govDecision)
            logger.Governance(cfg.Agent, parsed.Tool, parsed.Params, decision.Allowed, decision.PolicyName, decision.Reason)

            // If retries remain, steer the model with correction feedback;
            // otherwise tell it to abandon this action entirely.
            canCorrect, _ := corrector.ShouldCorrect(fp)
            if canCorrect {
                feedback := corrector.BuildFeedback(proposal, govDecision)
                logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — sending correction feedback (escalation: %s)", parsed.Tool, corrector.Level()))
                messages = append(messages, ollama.ChatMessage{
                    Role:    "user",
                    Content: feedback,
                })
            } else {
                logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — no retries left, skipping", parsed.Tool))
                messages = append(messages, ollama.ChatMessage{
                    Role:    "user",
                    Content: fmt.Sprintf("Tool %q was denied and cannot be retried. Move on to a different approach.", parsed.Tool),
                })
            }

            summary := fmt.Sprintf("[turn %d] %s → denied (%s)", turn, parsed.Tool, decision.PolicyName)
            log = append(log, summary)
            continue
        }

        // ── Governance allowed: log and execute the tool ──
        logger.Governance(cfg.Agent, parsed.Tool, parsed.Params, decision.Allowed, decision.PolicyName, decision.Reason)
        toolResult := tools.ExecuteDirect(parsed.Tool, parsed.Params, engine.GetTimeout())
        logger.ToolResult(cfg.Agent, parsed.Tool, toolResult.Success, toolResult.Output)

        var msg string
        if toolResult.Success {
            msg = fmt.Sprintf("Tool %q returned:\n%s", parsed.Tool, toolResult.Output)
        } else {
            msg = fmt.Sprintf("Tool %q failed: %s", parsed.Tool, toolResult.Output)
        }
        messages = append(messages, ollama.ChatMessage{Role: "user", Content: msg})

        summary := fmt.Sprintf("[turn %d] %s → %s", turn, parsed.Tool, boolStr(toolResult.Success, "ok", "fail"))
        log = append(log, summary)

        // Last turn: force a final answer with one extra model call.
        if turn == cfg.MaxTurns {
            messages = append(messages, ollama.ChatMessage{
                Role:    "user",
                Content: "You've used all turns. Give your final answer now.",
            })
            final, err := ollama.Chat(compactMessages(messages, cfg.TokenBudget), cfg.Model)
            if err == nil {
                result.Output = final.Message.Content
                result.PromptTok += final.PromptEval
                result.ResponseTok += final.EvalCount
            }
            result.Turns = turn
            result.ExitReason = "max_turns"
        }
    }

    result.DurationMs = time.Since(start).Milliseconds()
    result.Success = result.ExitReason == "final_answer"
    result.Log = log
    return result, nil
}

// Old parseToolCall/tryParse removed — replaced by intent.Parse()
// which handles all formats: JSON blocks, XML tags, bare JSON,
// OpenAI function_call format, and tool name/param aliasing.
func buildSystemPrompt(base string) string { -toolDocs := tools.FormatForPrompt() -return base + ` + toolDocs := tools.FormatForPrompt() + return base + ` ## Tools @@ -255,37 +480,106 @@ CORRECT: } func compactMessages(msgs []ollama.ChatMessage, budget int) []ollama.ChatMessage { -if budget <= 0 { -budget = 3000 -} -total := estimateTokens(msgs) -if total <= budget { -return msgs + if budget <= 0 { + budget = 3000 + } + total := estimateTokens(msgs) + if total <= budget { + return msgs + } + + // Keep system (0), first user (1), and last N messages + result := []ollama.ChatMessage{msgs[0], msgs[1]} + remaining := msgs[2:] + + // Drop tool results from the middle until we fit + for total > budget && len(remaining) > 4 { + remaining = remaining[2:] // drop oldest assistant+tool pair + total = estimateTokens(append(result, remaining...)) + } + return append(result, remaining...) } -// Keep system (0), first user (1), and last N messages -result := []ollama.ChatMessage{msgs[0], msgs[1]} -remaining := msgs[2:] - -// Drop tool results from the middle until we fit -for total > budget && len(remaining) > 4 { -remaining = remaining[2:] // drop oldest assistant+tool pair -total = estimateTokens(append(result, remaining...)) -} -return append(result, remaining...) +func estimateTokens(msgs []ollama.ChatMessage) int { + total := 0 + for _, m := range msgs { + total += len(m.Content) / 4 // rough approximation + } + return total } -func estimateTokens(msgs []ollama.ChatMessage) int { -total := 0 -for _, m := range msgs { -total += len(m.Content) / 4 // rough approximation +// compactLLMMessages is the []llm.Message equivalent of compactMessages. 
+func compactLLMMessages(msgs []llm.Message, budget int) []llm.Message { + if budget <= 0 { + budget = 3000 + } + total := estimateLLMTokens(msgs) + if total <= budget { + return msgs + } + + result := []llm.Message{msgs[0], msgs[1]} + remaining := msgs[2:] + + for total > budget && len(remaining) > 4 { + remaining = remaining[2:] + total = estimateLLMTokens(append(result, remaining...)) + } + return append(result, remaining...) } -return total + +// estimateLLMTokens estimates token count for []llm.Message. +func estimateLLMTokens(msgs []llm.Message) int { + total := 0 + for _, m := range msgs { + total += len(m.Content) / 4 + } + return total } func boolStr(b bool, t, f string) string { -if b { -return t + if b { + return t + } + return f +} + +// toProviderMessages converts []ollama.ChatMessage to []llm.Message for use +// with the Provider interface. +func toProviderMessages(msgs []ollama.ChatMessage) []llm.Message { + result := make([]llm.Message, len(msgs)) + for i, m := range msgs { + result[i] = llm.Message{Role: m.Role, Content: m.Content} + } + return result } -return f + +// buildToolDefs converts tools.Definitions into []llm.ToolDef for the Provider. +func buildToolDefs() []llm.ToolDef { + defs := make([]llm.ToolDef, len(tools.Definitions)) + for i, d := range tools.Definitions { + // Build JSON Schema from Param definitions. 
+ properties := make(map[string]any, len(d.Params)) + required := make([]string, 0, len(d.Params)) + for _, p := range d.Params { + properties[p.Name] = map[string]any{ + "type": p.Type, + "description": p.Desc, + } + if p.Required { + required = append(required, p.Name) + } + } + + defs[i] = llm.ToolDef{ + Name: d.Name, + Description: d.Description, + Parameters: map[string]any{ + "type": "object", + "properties": properties, + "required": required, + }, + } + } + return defs } diff --git a/internal/agent/loop_test.go b/internal/agent/loop_test.go new file mode 100644 index 0000000..325413d --- /dev/null +++ b/internal/agent/loop_test.go @@ -0,0 +1,349 @@ +package agent + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/AgentGuardHQ/shellforge/internal/governance" + "github.com/AgentGuardHQ/shellforge/internal/llm" +) + +// mockProvider is a test double that returns pre-configured responses. +type mockProvider struct { + name string + responses []*llm.Response + calls int + received []mockCall +} + +type mockCall struct { + Messages []llm.Message + Tools []llm.ToolDef +} + +func (m *mockProvider) Name() string { return m.name } + +func (m *mockProvider) Chat(messages []llm.Message, tools []llm.ToolDef) (*llm.Response, error) { + m.received = append(m.received, mockCall{Messages: messages, Tools: tools}) + if m.calls >= len(m.responses) { + return nil, fmt.Errorf("mock: no more responses (called %d times, have %d)", m.calls+1, len(m.responses)) + } + resp := m.responses[m.calls] + m.calls++ + return resp, nil +} + +// newPermissiveEngine creates a governance engine that allows everything. +func newPermissiveEngine(t *testing.T) *governance.Engine { + t.Helper() + return &governance.Engine{ + Mode: "enforce", + Policies: nil, // no policies = default-allow + } +} + +// newDenyShellEngine creates a governance engine that denies run_shell commands +// containing "ls". 
// newDenyShellEngine creates a governance engine that denies run_shell commands
// containing "ls".
func newDenyShellEngine(t *testing.T) *governance.Engine {
    t.Helper()
    return &governance.Engine{
        Mode: "enforce",
        Policies: []governance.Policy{
            {
                Name:        "deny-shell",
                Description: "deny ls shell commands",
                Match:       governance.Match{Command: "ls"},
                Action:      "deny",
                Message:     "shell commands are not allowed",
            },
        },
    }
}

// baseCfg returns a LoopConfig with reasonable test defaults.
func baseCfg(provider llm.Provider, outputDir string) LoopConfig {
    return LoopConfig{
        Agent:       "test-agent",
        System:      "You are a test assistant.",
        UserPrompt:  "What files are in this directory?",
        Model:       "test-model",
        MaxTurns:    10,
        TimeoutMs:   30000,
        OutputDir:   outputDir,
        TokenBudget: 8000,
        Provider:    provider,
    }
}

// TestProviderToolCallThenFinalAnswer verifies that the loop processes a
// single tool call via native tool-use and then accepts a final answer.
func TestProviderToolCallThenFinalAnswer(t *testing.T) {
    tmpDir := t.TempDir()

    // Create a file for list_files to find.
    testFile := filepath.Join(tmpDir, "hello.txt")
    os.WriteFile(testFile, []byte("hello world"), 0644)

    mock := &mockProvider{
        name: "mock-anthropic",
        responses: []*llm.Response{
            // Turn 1: model requests list_files tool
            {
                Content:    "",
                StopReason: "tool_use",
                ToolCalls: []llm.ToolCall{
                    {
                        ID:     "call_001",
                        Name:   "list_files",
                        Params: map[string]string{"directory": tmpDir},
                    },
                },
                PromptTok: 100,
                OutputTok: 20,
            },
            // Turn 2: model gives final answer
            {
                Content:    "The directory contains hello.txt.",
                StopReason: "end_turn",
                ToolCalls:  nil,
                PromptTok:  150,
                OutputTok:  30,
            },
        },
    }

    cfg := baseCfg(mock, tmpDir)
    engine := newPermissiveEngine(t)

    result, err := RunLoop(cfg, engine)
    if err != nil {
        t.Fatalf("RunLoop error: %v", err)
    }

    if result.ExitReason != "final_answer" {
        t.Errorf("ExitReason: got %q, want %q", result.ExitReason, "final_answer")
    }
    if !result.Success {
        t.Errorf("Success: got false, want true")
    }
    if result.ToolCalls != 1 {
        t.Errorf("ToolCalls: got %d, want %d", result.ToolCalls, 1)
    }
    if result.Denials != 0 {
        t.Errorf("Denials: got %d, want %d", result.Denials, 0)
    }
    if result.Output != "The directory contains hello.txt." {
        t.Errorf("Output: got %q, want %q", result.Output, "The directory contains hello.txt.")
    }
    if result.Turns != 2 {
        t.Errorf("Turns: got %d, want %d", result.Turns, 2)
    }
    // Token totals are the sum over both mocked responses (100+150, 20+30).
    if result.PromptTok != 250 {
        t.Errorf("PromptTok: got %d, want %d", result.PromptTok, 250)
    }
    if result.ResponseTok != 50 {
        t.Errorf("ResponseTok: got %d, want %d", result.ResponseTok, 50)
    }

    // Verify that tool definitions were passed to the provider.
    if len(mock.received) < 1 {
        t.Fatal("mock received no calls")
    }
    firstCall := mock.received[0]
    if len(firstCall.Tools) == 0 {
        t.Error("first call should have received tool definitions")
    }

    // Verify that the second call has a tool_result message.
    if len(mock.received) < 2 {
        t.Fatal("mock should have received 2 calls")
    }
    secondCallMsgs := mock.received[1].Messages
    hasToolResult := false
    for _, m := range secondCallMsgs {
        if m.Role == "tool_result" && m.ToolCallID == "call_001" {
            hasToolResult = true
            break
        }
    }
    if !hasToolResult {
        t.Error("second call should contain a tool_result message with ToolCallID=call_001")
    }
}

// TestProviderGovernanceDenial verifies that governance denials on native
// tool calls result in a tool_result message with denial feedback.
func TestProviderGovernanceDenial(t *testing.T) {
    tmpDir := t.TempDir()

    mock := &mockProvider{
        name: "mock-anthropic",
        responses: []*llm.Response{
            // Turn 1: model requests run_shell (will be denied)
            {
                Content:    "",
                StopReason: "tool_use",
                ToolCalls: []llm.ToolCall{
                    {
                        ID:     "call_shell_1",
                        Name:   "run_shell",
                        Params: map[string]string{"command": "ls -la"},
                    },
                },
                PromptTok: 100,
                OutputTok: 20,
            },
            // Turn 2: model gives final answer after denial
            {
                Content:    "I cannot run shell commands. Here is what I know.",
                StopReason: "end_turn",
                ToolCalls:  nil,
                PromptTok:  200,
                OutputTok:  40,
            },
        },
    }

    cfg := baseCfg(mock, tmpDir)
    engine := newDenyShellEngine(t)

    result, err := RunLoop(cfg, engine)
    if err != nil {
        t.Fatalf("RunLoop error: %v", err)
    }

    if result.ExitReason != "final_answer" {
        t.Errorf("ExitReason: got %q, want %q", result.ExitReason, "final_answer")
    }
    if result.Denials != 1 {
        t.Errorf("Denials: got %d, want %d", result.Denials, 1)
    }
    // A denied call still counts as a tool call.
    if result.ToolCalls != 1 {
        t.Errorf("ToolCalls: got %d, want %d", result.ToolCalls, 1)
    }
    if result.Output != "I cannot run shell commands. Here is what I know." {
        t.Errorf("Output: got %q", result.Output)
    }

    // The second call should have a tool_result message with denial feedback.
    if len(mock.received) < 2 {
        t.Fatal("mock should have received 2 calls")
    }
    secondCallMsgs := mock.received[1].Messages
    hasToolResult := false
    for _, m := range secondCallMsgs {
        if m.Role == "tool_result" && m.ToolCallID == "call_shell_1" {
            hasToolResult = true
            break
        }
    }
    if !hasToolResult {
        t.Error("second call should contain a tool_result message for denied tool call")
    }
}

// TestProviderNoToolCalls verifies that when the provider returns
// immediately with no tool calls (end_turn), the loop exits as final_answer.
func TestProviderNoToolCalls(t *testing.T) {
    tmpDir := t.TempDir()

    mock := &mockProvider{
        name: "mock-anthropic",
        responses: []*llm.Response{
            {
                Content:    "I can answer directly: this is a test.",
                StopReason: "end_turn",
                ToolCalls:  nil,
                PromptTok:  80,
                OutputTok:  25,
            },
        },
    }

    cfg := baseCfg(mock, tmpDir)
    engine := newPermissiveEngine(t)

    result, err := RunLoop(cfg, engine)
    if err != nil {
        t.Fatalf("RunLoop error: %v", err)
    }

    if result.ExitReason != "final_answer" {
        t.Errorf("ExitReason: got %q, want %q", result.ExitReason, "final_answer")
    }
    if !result.Success {
        t.Errorf("Success: got false, want true")
    }
    if result.ToolCalls != 0 {
        t.Errorf("ToolCalls: got %d, want %d", result.ToolCalls, 0)
    }
    if result.Denials != 0 {
        t.Errorf("Denials: got %d, want %d", result.Denials, 0)
    }
    if result.Turns != 1 {
        t.Errorf("Turns: got %d, want %d", result.Turns, 1)
    }
    if result.Output != "I can answer directly: this is a test." {
        t.Errorf("Output: got %q", result.Output)
    }
    if result.PromptTok != 80 {
        t.Errorf("PromptTok: got %d, want %d", result.PromptTok, 80)
    }

    // Only one call to the provider.
    if mock.calls != 1 {
        t.Errorf("provider calls: got %d, want %d", mock.calls, 1)
    }
}

// TestBuildToolDefs verifies that buildToolDefs produces valid llm.ToolDef
// entries from tools.Definitions.
func TestBuildToolDefs(t *testing.T) {
    defs := buildToolDefs()
    if len(defs) == 0 {
        t.Fatal("buildToolDefs returned empty slice")
    }

    // Check that each definition has a name, description, and parameters.
    for _, d := range defs {
        if d.Name == "" {
            t.Error("ToolDef has empty Name")
        }
        if d.Description == "" {
            t.Errorf("ToolDef %q has empty Description", d.Name)
        }
        if d.Parameters == nil {
            t.Errorf("ToolDef %q has nil Parameters", d.Name)
            continue
        }
        typ, ok := d.Parameters["type"]
        if !ok || typ != "object" {
            t.Errorf("ToolDef %q: Parameters[type] = %v, want %q", d.Name, typ, "object")
        }
        props, ok := d.Parameters["properties"]
        if !ok || props == nil {
            t.Errorf("ToolDef %q: missing properties", d.Name)
        }
    }

    // Spot-check read_file has a "path" property.
    found := false
    for _, d := range defs {
        if d.Name == "read_file" {
            found = true
            props := d.Parameters["properties"].(map[string]any)
            if _, ok := props["path"]; !ok {
                t.Error("read_file ToolDef missing 'path' property")
            }
            req := d.Parameters["required"].([]string)
            if len(req) != 1 || req[0] != "path" {
                t.Errorf("read_file required: got %v, want [path]", req)
            }
        }
    }
    if !found {
        t.Error("buildToolDefs missing read_file definition")
    }
}

// Anthropic provider constants (internal/llm/anthropic.go).
const (
    // Default model when ANTHROPIC_MODEL is unset.
    defaultAnthropicModel   = "claude-haiku-4-5-20251001"
    defaultAnthropicBaseURL = "https://api.anthropic.com"
    // Required anthropic-version header value.
    anthropicVersion  = "2023-06-01"
    anthropicMaxTokens = 4096
)
+type AnthropicProvider struct { + apiKey string + model string + baseURL string + client *http.Client + ThinkingBudget int // max thinking tokens (0 = disabled) +} + +// NewAnthropicProvider creates an AnthropicProvider. +// Pass empty strings to fall back to ANTHROPIC_API_KEY and ANTHROPIC_MODEL env vars. +func NewAnthropicProvider(apiKey, model string) *AnthropicProvider { + if apiKey == "" { + apiKey = os.Getenv("ANTHROPIC_API_KEY") + } + if model == "" { + model = envOr("ANTHROPIC_MODEL", defaultAnthropicModel) + } + return &AnthropicProvider{ + apiKey: apiKey, + model: model, + baseURL: defaultAnthropicBaseURL, + client: &http.Client{Timeout: 5 * time.Minute}, + } +} + +// Name returns the provider identifier. +func (a *AnthropicProvider) Name() string { + return "anthropic" +} + +// ---------------------------------------------------------------------------- +// Anthropic API wire types +// ---------------------------------------------------------------------------- + +// anthropicContentBlock is a polymorphic content block in the Anthropic API. +type anthropicContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input map[string]any `json:"input,omitempty"` + ToolUseID string `json:"tool_use_id,omitempty"` + Content string `json:"content,omitempty"` +} + +// anthropicMessage is one turn in the Anthropic messages array. +// Content can be a plain string or an array of content blocks. +// We use json.RawMessage to handle both. +type anthropicMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` +} + +// cacheControl instructs Anthropic to cache this content block for 5 minutes. +type cacheControl struct { + Type string `json:"type"` +} + +// anthropicToolDef is the Anthropic tool definition format. 
+type anthropicToolDef struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"input_schema"` + CacheControl *cacheControl `json:"cache_control,omitempty"` +} + +// anthropicThinking configures extended thinking (chain-of-thought). +type anthropicThinking struct { + Type string `json:"type"` // "enabled" + BudgetTokens int `json:"budget_tokens"` // max thinking tokens +} + +// anthropicRequest is the full POST body for /v1/messages. +type anthropicRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System json.RawMessage `json:"system,omitempty"` + Messages []anthropicMessage `json:"messages"` + Tools []anthropicToolDef `json:"tools,omitempty"` + Thinking *anthropicThinking `json:"thinking,omitempty"` +} + +// anthropicResponse is the API response body. +type anthropicResponse struct { + ID string `json:"id"` + Content []anthropicContentBlock `json:"content"` + StopReason string `json:"stop_reason"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + } `json:"usage"` +} + +// anthropicErrorResponse represents an API error body. +type anthropicErrorResponse struct { + Type string `json:"type"` + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` +} + +// ---------------------------------------------------------------------------- +// Chat implementation +// ---------------------------------------------------------------------------- + +// Chat sends messages to the Anthropic Messages API and returns the response. +func (a *AnthropicProvider) Chat(messages []Message, tools []ToolDef) (*Response, error) { + // Separate system messages from conversation messages. 
+ var systemPrompt string + var convMsgs []Message + for _, m := range messages { + if m.Role == "system" { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += m.Content + } else { + convMsgs = append(convMsgs, m) + } + } + + // Convert conversation messages to Anthropic wire format. + apiMsgs, err := convertMessages(convMsgs) + if err != nil { + return nil, fmt.Errorf("anthropic: convert messages: %w", err) + } + + // Convert tool definitions. + apiTools := make([]anthropicToolDef, len(tools)) + for i, t := range tools { + schema := t.Parameters + if schema == nil { + schema = map[string]any{"type": "object", "properties": map[string]any{}} + } + apiTools[i] = anthropicToolDef{ + Name: t.Name, + Description: t.Description, + InputSchema: schema, + } + } + + reqBody := anthropicRequest{ + Model: a.model, + MaxTokens: anthropicMaxTokens, + Messages: apiMsgs, + } + + // Enable extended thinking with budget cap if configured. + if a.ThinkingBudget > 0 { + reqBody.Thinking = &anthropicThinking{ + Type: "enabled", + BudgetTokens: a.ThinkingBudget, + } + // When thinking is enabled, max_tokens must cover thinking + output. + if reqBody.MaxTokens < a.ThinkingBudget+1024 { + reqBody.MaxTokens = a.ThinkingBudget + 1024 + } + } + + // Build system as array of content blocks with cache_control on the last block. + if systemPrompt != "" { + systemBlocks := []map[string]any{ + { + "type": "text", + "text": systemPrompt, + "cache_control": map[string]string{"type": "ephemeral"}, + }, + } + reqBody.System, _ = json.Marshal(systemBlocks) + } + + // Add cache_control to last tool so tool definitions are cached. 
+ if len(apiTools) > 0 { + apiTools[len(apiTools)-1].CacheControl = &cacheControl{Type: "ephemeral"} + reqBody.Tools = apiTools + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("anthropic: marshal request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, a.baseURL+"/v1/messages", bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("anthropic: create request: %w", err) + } + req.Header.Set("x-api-key", a.apiKey) + req.Header.Set("anthropic-version", anthropicVersion) + req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31") + req.Header.Set("Content-Type", "application/json") + + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("anthropic: http request: %w", err) + } + defer resp.Body.Close() + + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("anthropic: read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var apiErr anthropicErrorResponse + if jsonErr := json.Unmarshal(respBytes, &apiErr); jsonErr == nil && apiErr.Error.Message != "" { + return nil, fmt.Errorf("anthropic: api error %d: %s", resp.StatusCode, apiErr.Error.Message) + } + return nil, fmt.Errorf("anthropic: http %d: %s", resp.StatusCode, string(respBytes)) + } + + var apiResp anthropicResponse + if err := json.Unmarshal(respBytes, &apiResp); err != nil { + return nil, fmt.Errorf("anthropic: unmarshal response: %w", err) + } + + return parseResponse(&apiResp), nil +} + +// ---------------------------------------------------------------------------- +// Helpers +// ---------------------------------------------------------------------------- + +// convertMessages converts llm.Message slice to Anthropic API message slice. 
+func convertMessages(messages []Message) ([]anthropicMessage, error) { + result := make([]anthropicMessage, 0, len(messages)) + for _, m := range messages { + var raw json.RawMessage + var err error + + switch m.Role { + case "tool_result": + // Must be wrapped in a user message with tool_result content block. + block := anthropicContentBlock{ + Type: "tool_result", + ToolUseID: m.ToolCallID, + Content: m.Content, + } + raw, err = json.Marshal([]anthropicContentBlock{block}) + if err != nil { + return nil, err + } + result = append(result, anthropicMessage{Role: "user", Content: raw}) + + case "assistant": + if len(m.ToolCalls) > 0 { + // Assistant message with tool_use blocks — reconstruct structured content. + var blocks []anthropicContentBlock + if m.Content != "" { + blocks = append(blocks, anthropicContentBlock{Type: "text", Text: m.Content}) + } + for _, tc := range m.ToolCalls { + input := make(map[string]any, len(tc.Params)) + for k, v := range tc.Params { + input[k] = v + } + blocks = append(blocks, anthropicContentBlock{ + Type: "tool_use", + ID: tc.ID, + Name: tc.Name, + Input: input, + }) + } + raw, err = json.Marshal(blocks) + if err != nil { + return nil, err + } + result = append(result, anthropicMessage{Role: "assistant", Content: raw}) + } else { + // Plain text assistant message. + raw, err = json.Marshal(m.Content) + if err != nil { + return nil, err + } + result = append(result, anthropicMessage{Role: "assistant", Content: raw}) + } + + case "user": + // Plain text content block. + raw, err = json.Marshal(m.Content) + if err != nil { + return nil, err + } + result = append(result, anthropicMessage{Role: "user", Content: raw}) + + default: + // Skip unknown roles (e.g. "system" already extracted). + continue + } + } + return result, nil +} + +// parseResponse converts an anthropicResponse into an llm.Response. 
// parseResponse converts an anthropicResponse into an llm.Response: text
// blocks are concatenated into Content (newline-joined), tool_use blocks
// become ToolCalls, and the usage counters are copied over verbatim.
func parseResponse(apiResp *anthropicResponse) *Response {
    resp := &Response{
        StopReason:   apiResp.StopReason,
        PromptTok:    apiResp.Usage.InputTokens,
        OutputTok:    apiResp.Usage.OutputTokens,
        CacheCreated: apiResp.Usage.CacheCreationInputTokens,
        CacheHit:     apiResp.Usage.CacheReadInputTokens,
    }

    for _, block := range apiResp.Content {
        switch block.Type {
        case "thinking":
            // Extended-thinking block — internal reasoning, deliberately
            // omitted from Content. Nothing else to do here: NOTE(review)
            // assumed its tokens are already included in Usage.OutputTokens
            // by the API — confirm against the Anthropic docs.
            continue

        case "text":
            // Multiple text blocks are joined with a newline.
            if resp.Content != "" {
                resp.Content += "\n"
            }
            resp.Content += block.Text

        case "tool_use":
            // Flatten the structured input map to map[string]string via %v.
            params := make(map[string]string, len(block.Input))
            for k, v := range block.Input {
                params[k] = fmt.Sprintf("%v", v)
            }
            resp.ToolCalls = append(resp.ToolCalls, ToolCall{
                ID:     block.ID,
                Name:   block.Name,
                Params: params,
            })
        }
    }

    return resp
}

// newAnthropicTestProvider creates an AnthropicProvider pointing at the given mock server URL.
func newAnthropicTestProvider(serverURL string) *AnthropicProvider {
    p := NewAnthropicProvider("test-api-key", "test-model")
    p.baseURL = serverURL
    return p
}

// mustMarshal marshals v to JSON and panics on error (test helper).
func mustMarshal(v any) []byte {
    b, err := json.Marshal(v)
    if err != nil {
        panic(err)
    }
    return b
}
// mockServer creates an httptest.Server that returns the provided JSON body for
// all POST requests to /v1/messages; any other request fails the test.
func mockServer(t *testing.T, respBody any) *httptest.Server {
    t.Helper()
    return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        if r.Method != http.MethodPost || r.URL.Path != "/v1/messages" {
            t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
            http.Error(w, "not found", http.StatusNotFound)
            return
        }
        w.Header().Set("Content-Type", "application/json")
        w.WriteHeader(http.StatusOK)
        w.Write(mustMarshal(respBody))
    }))
}

// captureServer creates a mock server that captures the decoded request body
// into *captured and returns the provided response.
func captureServer(t *testing.T, captured *anthropicRequest, respBody any) *httptest.Server {
    t.Helper()
    return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        body, err := io.ReadAll(r.Body)
        if err != nil {
            t.Errorf("read body: %v", err)
        }
        if err := json.Unmarshal(body, captured); err != nil {
            t.Errorf("decode body: %v", err)
        }
        w.Header().Set("Content-Type", "application/json")
        w.WriteHeader(http.StatusOK)
        w.Write(mustMarshal(respBody))
    }))
}

// ---------------------------------------------------------------------------
// Test 1: Name()
// ---------------------------------------------------------------------------

func TestAnthropicProviderName(t *testing.T) {
    p := NewAnthropicProvider("key", "model")
    if got := p.Name(); got != "anthropic" {
        t.Errorf("Name() = %q, want %q", got, "anthropic")
    }
}

// ---------------------------------------------------------------------------
// Test 2: Chat end_turn — text response
// ---------------------------------------------------------------------------

func TestAnthropicChat_EndTurn(t *testing.T) {
    apiResp := anthropicResponse{
        ID: "msg_001",
        Content: []anthropicContentBlock{
            {Type: "text", Text: "Hello, world!"},
        },
        StopReason: "end_turn",
    }
    apiResp.Usage.InputTokens = 10
    apiResp.Usage.OutputTokens = 5

    srv := mockServer(t, apiResp)
    defer srv.Close()

    p := newAnthropicTestProvider(srv.URL)
    resp, err := p.Chat([]Message{{Role: "user", Content: "Hi"}}, nil)
    if err != nil {
        t.Fatalf("Chat() error: %v", err)
    }
    if resp.Content != "Hello, world!" {
        t.Errorf("Content = %q, want %q", resp.Content, "Hello, world!")
    }
    if resp.StopReason != "end_turn" {
        t.Errorf("StopReason = %q, want %q", resp.StopReason, "end_turn")
    }
    if resp.PromptTok != 10 {
        t.Errorf("PromptTok = %d, want 10", resp.PromptTok)
    }
    if resp.OutputTok != 5 {
        t.Errorf("OutputTok = %d, want 5", resp.OutputTok)
    }
    if len(resp.ToolCalls) != 0 {
        t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls))
    }
}

// ---------------------------------------------------------------------------
// Test 3: Chat tool_use — ToolCalls populated
// ---------------------------------------------------------------------------

func TestAnthropicChat_ToolUse(t *testing.T) {
    apiResp := anthropicResponse{
        ID: "msg_002",
        Content: []anthropicContentBlock{
            {
                Type: "tool_use",
                ID:   "tc_1",
                Name: "read_file",
                Input: map[string]any{"path": "/tmp/test.txt"},
            },
        },
        StopReason: "tool_use",
    }
    apiResp.Usage.InputTokens = 20
    apiResp.Usage.OutputTokens = 8

    srv := mockServer(t, apiResp)
    defer srv.Close()

    p := newAnthropicTestProvider(srv.URL)
    resp, err := p.Chat([]Message{{Role: "user", Content: "Read that file"}}, []ToolDef{
        {
            Name:        "read_file",
            Description: "Read a file from disk",
            Parameters: map[string]any{
                "type": "object",
                "properties": map[string]any{
                    "path": map[string]any{"type": "string"},
                },
            },
        },
    })
    if err != nil {
        t.Fatalf("Chat() error: %v", err)
    }
    if resp.StopReason != "tool_use" {
        t.Errorf("StopReason = %q, want %q", resp.StopReason, "tool_use")
    }
    if len(resp.ToolCalls) != 1 {
        t.Fatalf("len(ToolCalls) = %d, want 1", len(resp.ToolCalls))
    }
    tc := resp.ToolCalls[0]
    if tc.ID != "tc_1" {
        t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "tc_1")
    }
    if tc.Name != "read_file" {
        t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "read_file")
    }
    if tc.Params["path"] != "/tmp/test.txt" {
        t.Errorf("ToolCall.Params[path] = %q, want %q", tc.Params["path"], "/tmp/test.txt")
    }
}
+ for _, m := range captured.Messages { + if m.Role == "system" { + t.Error("system role message found in messages array — should be in System field only") + } + } + + // The user message should still be present. + if len(captured.Messages) != 1 { + t.Errorf("len(Messages) = %d, want 1", len(captured.Messages)) + } +} + +// --------------------------------------------------------------------------- +// Test 5: tool_result formatting +// --------------------------------------------------------------------------- + +func TestAnthropicChat_ToolResult(t *testing.T) { + apiResp := anthropicResponse{ + ID: "msg_004", + Content: []anthropicContentBlock{{Type: "text", Text: "Done."}}, + StopReason: "end_turn", + } + + var captured anthropicRequest + srv := captureServer(t, &captured, apiResp) + defer srv.Close() + + p := newAnthropicTestProvider(srv.URL) + messages := []Message{ + {Role: "user", Content: "Read the file"}, + {Role: "assistant", Content: "Sure"}, + { + Role: "tool_result", + Content: "file contents here", + ToolCallID: "tc_42", + }, + } + _, err := p.Chat(messages, nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + + // We expect 3 messages: user, assistant, user(tool_result). + if len(captured.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(captured.Messages)) + } + + // The last message must be role "user" (tool_result is wrapped as user). + last := captured.Messages[2] + if last.Role != "user" { + t.Errorf("last message Role = %q, want %q", last.Role, "user") + } + + // Decode the content blocks. 
+ var blocks []anthropicContentBlock + if err := json.Unmarshal(last.Content, &blocks); err != nil { + t.Fatalf("decode last message content: %v", err) + } + if len(blocks) != 1 { + t.Fatalf("len(blocks) = %d, want 1", len(blocks)) + } + block := blocks[0] + if block.Type != "tool_result" { + t.Errorf("block.Type = %q, want %q", block.Type, "tool_result") + } + if block.ToolUseID != "tc_42" { + t.Errorf("block.ToolUseID = %q, want %q", block.ToolUseID, "tc_42") + } + if block.Content != "file contents here" { + t.Errorf("block.Content = %q, want %q", block.Content, "file contents here") + } +} + +// --------------------------------------------------------------------------- +// Test 6: Cache metrics parsed from usage response +// --------------------------------------------------------------------------- + +func TestAnthropicChat_CacheMetrics(t *testing.T) { + apiResp := anthropicResponse{ + ID: "msg_005", + Content: []anthropicContentBlock{ + {Type: "text", Text: "Cached response."}, + }, + StopReason: "end_turn", + } + apiResp.Usage.InputTokens = 50 + apiResp.Usage.OutputTokens = 10 + apiResp.Usage.CacheCreationInputTokens = 500 + apiResp.Usage.CacheReadInputTokens = 450 + + srv := mockServer(t, apiResp) + defer srv.Close() + + p := newAnthropicTestProvider(srv.URL) + resp, err := p.Chat([]Message{{Role: "user", Content: "Hello"}}, nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.CacheCreated != 500 { + t.Errorf("CacheCreated = %d, want 500", resp.CacheCreated) + } + if resp.CacheHit != 450 { + t.Errorf("CacheHit = %d, want 450", resp.CacheHit) + } + if resp.PromptTok != 50 { + t.Errorf("PromptTok = %d, want 50", resp.PromptTok) + } + if resp.OutputTok != 10 { + t.Errorf("OutputTok = %d, want 10", resp.OutputTok) + } +} + +// --------------------------------------------------------------------------- +// Compile-time interface check +// --------------------------------------------------------------------------- + +func 
TestAnthropicProviderImplementsProvider(t *testing.T) { + var _ Provider = (*AnthropicProvider)(nil) +} diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go new file mode 100644 index 0000000..d11d576 --- /dev/null +++ b/internal/llm/ollama.go @@ -0,0 +1,71 @@ +package llm + +import ( + "os" + + "github.com/AgentGuardHQ/shellforge/internal/ollama" +) + +// OllamaProvider wraps the existing ollama.Chat() function. +type OllamaProvider struct { + host string + model string +} + +// NewOllamaProvider creates an OllamaProvider targeting the given host and model. +// Pass empty strings to use the ollama package defaults. +func NewOllamaProvider(host, model string) *OllamaProvider { + return &OllamaProvider{host: host, model: model} +} + +// Name returns the provider identifier. +func (o *OllamaProvider) Name() string { + return "ollama" +} + +// Chat converts llm.Message slice to ollama.ChatMessage, calls ollama.Chat(), +// and converts the response back to llm.Response. +// Ollama does not support native tool-use, so ToolCalls is always nil. +// The tools parameter is ignored — Ollama uses text-based tool calling via prompt. +// Roles: "tool_result" is mapped to "user" since Ollama only understands "user". +func (o *OllamaProvider) Chat(messages []Message, tools []ToolDef) (*Response, error) { + // Override ollama package host if caller specified one. 
+ if o.host != "" { + prev := ollama.Host + ollama.Host = o.host + defer func() { ollama.Host = prev }() + } + + ollamaMsgs := make([]ollama.ChatMessage, len(messages)) + for i, m := range messages { + role := m.Role + if role == "tool_result" { + role = "user" + } + ollamaMsgs[i] = ollama.ChatMessage{ + Role: role, + Content: m.Content, + } + } + + resp, err := ollama.Chat(ollamaMsgs, o.model) + if err != nil { + return nil, err + } + + return &Response{ + Content: resp.Message.Content, + ToolCalls: nil, + StopReason: "end_turn", + PromptTok: resp.PromptEval, + OutputTok: resp.EvalCount, + }, nil +} + +// envOr returns the value of the environment variable key, or fallback if unset. +func envOr(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} diff --git a/internal/llm/provider.go b/internal/llm/provider.go new file mode 100644 index 0000000..38e1598 --- /dev/null +++ b/internal/llm/provider.go @@ -0,0 +1,40 @@ +package llm + +// Provider abstracts an LLM backend (Ollama, Anthropic, etc.). +type Provider interface { + Chat(messages []Message, tools []ToolDef) (*Response, error) + Name() string +} + +// Message is a conversation turn. +type Message struct { + Role string // "system", "user", "assistant", "tool_result" + Content string + ToolCallID string // set when Role == "tool_result" + ToolCalls []ToolCall // set on assistant messages that invoked tools +} + +// ToolDef describes a tool the model can invoke. +type ToolDef struct { + Name string + Description string + Parameters map[string]any // JSON Schema +} + +// ToolCall is a model's request to invoke a tool. +type ToolCall struct { + ID string + Name string + Params map[string]string +} + +// Response is the model's reply to a Chat call. 
+type Response struct { + Content string + ToolCalls []ToolCall + StopReason string // "end_turn", "tool_use" + PromptTok int + OutputTok int + CacheCreated int // tokens written to cache (first call) + CacheHit int // tokens read from cache (subsequent calls) +} diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go new file mode 100644 index 0000000..29ba334 --- /dev/null +++ b/internal/llm/provider_test.go @@ -0,0 +1,105 @@ +package llm + +import ( + "testing" +) + +func TestMessageFields(t *testing.T) { + m := Message{ + Role: "tool_result", + Content: "output here", + ToolCallID: "call_abc123", + } + if m.Role != "tool_result" { + t.Errorf("Role: got %q, want %q", m.Role, "tool_result") + } + if m.Content != "output here" { + t.Errorf("Content: got %q, want %q", m.Content, "output here") + } + if m.ToolCallID != "call_abc123" { + t.Errorf("ToolCallID: got %q, want %q", m.ToolCallID, "call_abc123") + } +} + +func TestToolDefFields(t *testing.T) { + td := ToolDef{ + Name: "read_file", + Description: "Reads a file from disk", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + }, + } + if td.Name != "read_file" { + t.Errorf("Name: got %q, want %q", td.Name, "read_file") + } + if td.Description == "" { + t.Error("Description should not be empty") + } + if td.Parameters == nil { + t.Error("Parameters should not be nil") + } +} + +func TestToolCallFields(t *testing.T) { + tc := ToolCall{ + ID: "call_1", + Name: "list_files", + Params: map[string]string{"directory": "."}, + } + if tc.ID != "call_1" { + t.Errorf("ID: got %q, want %q", tc.ID, "call_1") + } + if tc.Name != "list_files" { + t.Errorf("Name: got %q, want %q", tc.Name, "list_files") + } + if tc.Params["directory"] != "." 
{ + t.Errorf("Params[directory]: got %q, want %q", tc.Params["directory"], ".") + } +} + +func TestResponseFields(t *testing.T) { + r := Response{ + Content: "here is the answer", + ToolCalls: nil, + StopReason: "end_turn", + PromptTok: 100, + OutputTok: 50, + } + if r.Content != "here is the answer" { + t.Errorf("Content: got %q, want %q", r.Content, "here is the answer") + } + if r.StopReason != "end_turn" { + t.Errorf("StopReason: got %q, want %q", r.StopReason, "end_turn") + } + if r.PromptTok != 100 { + t.Errorf("PromptTok: got %d, want %d", r.PromptTok, 100) + } + if r.OutputTok != 50 { + t.Errorf("OutputTok: got %d, want %d", r.OutputTok, 50) + } + if r.ToolCalls != nil { + t.Error("ToolCalls should be nil for end_turn") + } +} + +func TestOllamaProviderName(t *testing.T) { + p := NewOllamaProvider("http://localhost:11434", "qwen3:1.7b") + if p.Name() != "ollama" { + t.Errorf("Name: got %q, want %q", p.Name(), "ollama") + } +} + +func TestOllamaProviderNameEmptyArgs(t *testing.T) { + p := NewOllamaProvider("", "") + if p.Name() != "ollama" { + t.Errorf("Name: got %q, want %q", p.Name(), "ollama") + } +} + +func TestOllamaProviderImplementsProvider(t *testing.T) { + // Compile-time check: *OllamaProvider must satisfy Provider interface. + var _ Provider = (*OllamaProvider)(nil) +} diff --git a/scripts/hook-resolve.sh b/scripts/hook-resolve.sh new file mode 100755 index 0000000..aeb019d --- /dev/null +++ b/scripts/hook-resolve.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# hook-resolve.sh — Universal AgentGuard binary resolver for all drivers. +# Ensures governance hooks + telemetry work in worktrees, local installs, and global installs. 
#
# Usage (from any hook config):
#   source scripts/hook-resolve.sh
#   eval "$AGENTGUARD_BIN claude-hook"   # or copilot-hook, codex-hook, gemini-hook
#
# Sets:
#   AGENTGUARD_BIN  — shell command prefix that works everywhere (may include cd)
#   _AG_MAIN_ROOT   — path to the main (non-worktree) checkout

# Resolve project root
_AG_PROJECT_ROOT="$(git rev-parse --show-toplevel 2>/dev/null || pwd)"

# Source persona env if available
if [ -f "$_AG_PROJECT_ROOT/.agentguard/persona.env" ]; then
  set -a; source "$_AG_PROJECT_ROOT/.agentguard/persona.env"; set +a
fi

# Source workspace .env for telemetry config (API key, cloud endpoint, tenant ID)
_AG_WS_ROOT="$HOME/agentguard-workspace"
if [ -f "$_AG_WS_ROOT/.env" ]; then
  set -a; source "$_AG_WS_ROOT/.env"; set +a
fi

# Find the main worktree root (where node_modules lives)
_AG_MAIN_ROOT="$(git worktree list --porcelain 2>/dev/null | sed -n '1s/^worktree //p')"
_AG_IN_WORKTREE=0
if [ -n "$_AG_MAIN_ROOT" ] && [ "$_AG_MAIN_ROOT" != "$_AG_PROJECT_ROOT" ]; then
  _AG_IN_WORKTREE=1
fi

# Resolve binary — effective priority: local dev > global install > main worktree fallback.
# The global install is assigned first only as a baseline; a local dev build,
# when present, overwrites it below.
AGENTGUARD_BIN=""

# Baseline: global install (npm install -g @red-codes/agentguard)
# Works in any directory — no worktree issues.
if command -v agentguard &>/dev/null; then
  AGENTGUARD_BIN="agentguard"
fi

# Preferred: local dev build (apps/cli/dist/bin.js in current or main worktree)
# ESM resolution requires CWD to be where node_modules exists.
# In worktrees, we MUST cd to the main root before running the binary.
# FIX: roots are single-quoted inside the command string so eval survives
# paths containing spaces (paths containing single quotes remain unsupported).
if [ -f "$_AG_PROJECT_ROOT/apps/cli/dist/bin.js" ]; then
  if [ "$_AG_IN_WORKTREE" -eq 1 ] && [ -n "$_AG_MAIN_ROOT" ]; then
    # Worktree: run from main root for ESM package resolution
    AGENTGUARD_BIN="cd '$_AG_MAIN_ROOT' && node apps/cli/dist/bin.js"
  else
    AGENTGUARD_BIN="node '$_AG_PROJECT_ROOT/apps/cli/dist/bin.js'"
  fi
elif [ "$_AG_IN_WORKTREE" -eq 1 ] && [ -n "$_AG_MAIN_ROOT" ] && [ -f "$_AG_MAIN_ROOT/apps/cli/dist/bin.js" ]; then
  # Worktree doesn't have the binary but main root does
  AGENTGUARD_BIN="cd '$_AG_MAIN_ROOT' && node apps/cli/dist/bin.js"
fi

# Probe: verify the resolved binary actually works
if [ -n "$AGENTGUARD_BIN" ]; then
  if ! eval "$AGENTGUARD_BIN --version" >/dev/null 2>&1; then
    # Binary fails — try main worktree as last resort
    if [ -n "$_AG_MAIN_ROOT" ] && [ -f "$_AG_MAIN_ROOT/apps/cli/dist/bin.js" ]; then
      AGENTGUARD_BIN="cd '$_AG_MAIN_ROOT' && node apps/cli/dist/bin.js"
      if ! eval "$AGENTGUARD_BIN --version" >/dev/null 2>&1; then
        AGENTGUARD_BIN=""  # give up, bootstrap exemption will handle it
      fi
    else
      AGENTGUARD_BIN=""
    fi
  fi
fi

export AGENTGUARD_BIN

# ===== scripts/run-dogfood.sh (new file, mode 100755) =====

#!/bin/bash
# Trigger RunPod GPU session for ShellForge dogfood run
set -euo pipefail
# FIX: run the package (`./cmd/shellforge`), not a single file — `go run` on
# main.go alone breaks as soon as the main package gains a second source file.
runpod run --image shellforge-v1 --gpus a10x --command "cd /home/jared/agentguard-workspace/shellforge && go run ./cmd/shellforge"