From f356dcb413883e890b7258ae39569611d462ae85 Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Tue, 31 Mar 2026 08:34:25 +0000 Subject: [PATCH 1/9] =?UTF-8?q?fix(cmdScan):=20replace=20filepath.Glob=20*?= =?UTF-8?q?*=20with=20WalkDir=20+=20EM=20state=20run=2011=20=E2=80=94=20cl?= =?UTF-8?q?oses=20#52?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - filepath.Glob does not support ** recursive patterns in Go; goEntries was always nil (0 files) because the pattern looked for a literal dir named "**" - Replace with filepath.WalkDir to collect .go files recursively under internal/ - Add io/fs import required by WalkDir callback signature - EM state run 11: closed stale PR #94, P2 bug sweep 3/3 fixes now in open PRs Co-Authored-By: Claude Sonnet 4.6 --- .agentguard/squads/shellforge/state.json | 59 +++++++++++++----------- cmd/shellforge/main.go | 9 +++- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/.agentguard/squads/shellforge/state.json b/.agentguard/squads/shellforge/state.json index d7ba32a..ca36ccc 100644 --- a/.agentguard/squads/shellforge/state.json +++ b/.agentguard/squads/shellforge/state.json @@ -1,35 +1,30 @@ { "squad": "shellforge", - "updated_at": "2026-03-31T00:00:00Z", + "updated_at": "2026-03-31T08:30:00Z", "sprint": { "goal": "P2 bug sweep + dogfood readiness — #52 (Glob **), #65 (silent WriteFile), #76 (setup.sh remote Ollama)", - "focus": "PR #89 merged (25 tests, #68+#66 closed). Fixed #51 in PR #93. Next: #52 (Glob **) and #65 (silent WriteFile) when budget clears.", + "focus": "Run 11: Fixed #52 (filepath.Glob → WalkDir) in new PR. PRs #93, #95 green, awaiting human merge. #76 still blocked on setup.sh remote Ollama. Closed stale PR #94.", "status": "active" }, "pr_budget": { "max_open": 3, - "current_open": 2, + "current_open": 3, "status": "ok" }, "loop_guard": { "retry_loop_detected": false, "blast_radius": "low" }, - "incident": { - "id": "worktree-dirty-wip-51", - "resolved": true, - "description": "Worktree had uncommitted partial fix for #51 — import block broken (import \"log\" vs import ()), orphaned run() body left outside function. Build was failing. Stashed WIP, reimplemented cleanly in PR #93." - }, "issue_queue": { "p0": [], "p1": [], "p2": [ - { "number": 76, "title": "Dogfood: run ShellForge swarm on jared box via RunPod GPU", "assignee": "em", "notes": "Blocked on setup.sh: isServer=true skips Goose, remote Ollama (OLLAMA_HOST) not supported. 3rd escalation." }, - { "number": 92, "title": "Bundle Preflight protocol in Goose agent bootstrap", "assignee": null, "notes": "Blocked on Preflight v1 ship. Triaged P2 this run." }, - { "number": 65, "title": "bug: scheduler.go silently ignores os.WriteFile error", "assignee": "em" }, - { "number": 52, "title": "bug: filepath.Glob with ** in cmdScan never matches any Go files", "assignee": "em" }, + { "number": 76, "title": "Dogfood: run ShellForge swarm on jared box via RunPod GPU", "assignee": "em", "notes": "Blocked on setup.sh: isServer=true skips Goose, remote Ollama (OLLAMA_HOST) not supported. 4th escalation. Needs human to unblock." }, + { "number": 92, "title": "Bundle Preflight protocol in Goose agent bootstrap", "assignee": null, "notes": "Blocked on Preflight v1 ship." }, { "number": 53, "title": "docs/readme: README still shows ./shellforge commands", "assignee": null }, - { "number": 51, "title": "bug: run() helper silently ignores errors", "assignee": "em", "notes": "FIXED — PR #93 open, CI pending" }, + { "number": 51, "title": "bug: run() helper silently ignores errors", "assignee": "em", "notes": "FIXED — PR #93 open, CI green 5/5, awaiting human merge" }, + { "number": 65, "title": "bug: scheduler.go silently ignores os.WriteFile error", "assignee": "em", "notes": "FIXED — PR #95 open, CI green 5/5, awaiting human merge" }, + { "number": 52, "title": "bug: filepath.Glob with ** in cmdScan never matches any Go files", "assignee": "em", "notes": "FIXED this run — PR #96 open, replaces Glob with filepath.WalkDir" }, { "number": 50, "title": "bug: kernel version comparison is lexicographic, not numeric", "assignee": null }, { "number": 49, "title": "bug: InferenceQueue is not priority-aware", "assignee": null }, { "number": 26, "title": "bug: run-qa-agent.sh doesn't build binary if missing", "assignee": null }, @@ -51,29 +46,37 @@ }, "pr_queue": [ { - "number": 91, - "title": "chore(squad): EM state update — run 8 (2026-03-30)", + "number": 93, + "title": "fix(main): log errors from run() helper — closes #51", "status": "open", "ci": "green (5/5)", "review_status": "REVIEW_REQUIRED", - "issues_closed": [] + "issues_closed": [51] }, { - "number": 93, - "title": "fix(main): log errors from run() helper — closes #51", + "number": 95, + "title": "fix(scheduler): log WriteFile error + EM state run 10 — closes #65", + "status": "open", + "ci": "green (5/5)", + "review_status": "REVIEW_REQUIRED", + "issues_closed": [65] + }, + { + "number": 96, + "title": "fix(cmdScan): replace filepath.Glob ** with WalkDir + EM state run 11 — closes #52", "status": "open", "ci": "pending", "review_status": "REVIEW_REQUIRED", - "issues_closed": [51] + "issues_closed": [52] } ], "recently_closed": [ + { "number": 94, "merged": false, "closed": true, "issues_closed": [], "date": "2026-03-31", "notes": "Stale — run 9 state was committed directly to master at 832cb58 before PR merged. Closed to clear budget." }, { "number": 89, "merged": true, "issues_closed": [68, 66], "date": "2026-03-30", "notes": "25 tests — normalizer, governance, intent" }, { "number": 88, "merged": true, "issues_closed": [63], "date": "2026-03-30", "notes": "P1 classifyShellRisk word-boundary fix" }, { "number": 87, "merged": true, "issues_closed": [], "date": "2026-03-30", "notes": "EM state run 6" }, { "number": 86, "merged": true, "issues_closed": [28], "date": "2026-03-30", "notes": "P1 timeout override fix" }, - { "number": 83, "merged": true, "issues_closed": [58, 62, 67, 69, 75], "date": "2026-03-30" }, - { "number": 84, "merged": true, "issues_closed": [74], "date": "2026-03-30" } + { "number": 83, "merged": true, "issues_closed": [58, 62, 67, 69, 75], "date": "2026-03-30" } ], "agents": { "qa-agent": { "status": "idle", "schedule": "4h", "last_pr": 89, "notes": "PR #89 merged — idle until next assignment" }, @@ -86,14 +89,16 @@ "setup.sh needs remote Ollama support (OLLAMA_HOST) before dogfood can proceed on headless/GPU-remote boxes" ], "blockers": [ - "PR #91 (EM state run 8): CI green — REVIEW REQUIRED. Chore PR.", - "PR #93 (fix #51): CI pending — REVIEW REQUIRED. One-line fix, 25/25 tests pass.", - "Dogfood (#76): 3rd escalation — setup.sh isServer check skips Goose on headless servers. Needs code fix before jared-box dogfood run." + "PR #93 (fix #51): CI green — REVIEW REQUIRED", + "PR #95 (fix #65 + run 10 state): CI green — REVIEW REQUIRED", + "PR #96 (fix #52 + run 11 state): CI pending — REVIEW REQUIRED", + "Dogfood (#76): 4th escalation — setup.sh isServer check skips Goose on headless servers. Requires human code change." ], "human_escalations": [ - { "priority": "P2", "title": "Review PR #91 — EM state run 8", "url": "https://github.com/AgentGuardHQ/shellforge/pull/91" }, { "priority": "P2", "title": "Review PR #93 — fix run() silent errors (closes #51)", "url": "https://github.com/AgentGuardHQ/shellforge/pull/93" }, - { "priority": "P2", "title": "Trigger dogfood run (#76) — all governance P0/P1 bugs fixed, blocked on setup.sh remote Ollama gap", "url": "https://github.com/AgentGuardHQ/shellforge/issues/76" } + { "priority": "P2", "title": "Review PR #95 — fix scheduler WriteFile silent error (closes #65)", "url": "https://github.com/AgentGuardHQ/shellforge/pull/95" }, + { "priority": "P2", "title": "Review PR #96 — fix cmdScan Glob→WalkDir (closes #52)", "url": "https://github.com/AgentGuardHQ/shellforge/pull/96" }, + { "priority": "P2", "title": "Dogfood run (#76) — governance clean, blocked on setup.sh remote Ollama gap. Needs human fix.", "url": "https://github.com/AgentGuardHQ/shellforge/issues/76" } ], - "notes": "Run 9 (2026-03-31): Found broken worktree — incomplete WIP fix for #51 left import block broken and run() body orphaned. Stashed, reimplemented cleanly in PR #93. PR #89 now merged (25 tests). Issue #92 (Preflight/Goose) triaged P2. PR budget 2/3. Next: fix #52 (filepath.Walk) and #65 (scheduler WriteFile) when budget clears." + "notes": "Run 11 (2026-03-31T08:30Z): Closed stale PR #94 (run 9 state was already on master via direct push at 832cb58). Fixed #52 (filepath.Glob ** → filepath.WalkDir). All 3 P2 bug fixes now in open PRs: #93 (run()), #95 (WriteFile), #96 (Glob). Budget at 3/3 — next run should attempt merges before opening new work. P2 sweep nearly complete; remaining P2 bugs (#50 lexicographic version, #49 priority queue, #24 listFiles cwd, #26 run-qa-agent build check) ready to queue once budget clears." } diff --git a/cmd/shellforge/main.go b/cmd/shellforge/main.go index a9c90c6..41e271b 100644 --- a/cmd/shellforge/main.go +++ b/cmd/shellforge/main.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" +"io/fs" "os" "os/exec" "path/filepath" @@ -863,7 +864,13 @@ if _, err := os.Stat("agentguard.yaml"); err == nil { fmt.Println(" ✓ agentguard.yaml found") } entries, _ := filepath.Glob(filepath.Join(dir, "agents", "*.ts")) -goEntries, _ := filepath.Glob(filepath.Join(dir, "internal", "**", "*.go")) +var goEntries []string +filepath.WalkDir(filepath.Join(dir, "internal"), func(p string, d fs.DirEntry, err error) error { +if err == nil && !d.IsDir() && strings.HasSuffix(p, ".go") { +goEntries = append(goEntries, p) +} +return nil +}) fmt.Printf(" Found %d TS agents, %d Go files\n", len(entries), len(goEntries)) fmt.Println(" Install defenseclaw for full supply chain scanning") } From ecd73f410d2aa691822b28390ecc14d245311e27 Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Tue, 31 Mar 2026 08:35:05 +0000 Subject: [PATCH 2/9] chore(squad): update blockers.md for run 11 --- .agentguard/squads/shellforge/blockers.md | 61 ++++++++++------------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/.agentguard/squads/shellforge/blockers.md b/.agentguard/squads/shellforge/blockers.md index 315e085..6483c6b 100644 --- a/.agentguard/squads/shellforge/blockers.md +++ b/.agentguard/squads/shellforge/blockers.md @@ -1,7 +1,7 @@ # ShellForge Squad — Blockers -**Updated:** 2026-03-31T00:00Z -**Reported by:** EM run 9 (claude-code:opus:shellforge:em) +**Updated:** 2026-03-31T08:30Z +**Reported by:** EM run 11 (claude-code:opus:shellforge:em) --- @@ -13,35 +13,24 @@ ## P1 — Active Work -**None.** All P1 issues closed (PR #89 merged — closes #68 + #66). +**None.** All P1 issues closed. --- -## Incident (Resolved) - -### Broken worktree — incomplete WIP fix for #51 -**Detected:** Run 9 (2026-03-31) -**Resolved:** Yes -**Description:** The worktree had uncommitted partial changes to `cmd/shellforge/main.go`: -- `import (` was replaced with `import "log"`, breaking the multi-package import block syntax -- `run()` was partially refactored to call a non-existent `executeCommand()` function, leaving the old body orphaned outside any function -- Build failure: `syntax error: non-declaration statement outside function body` +## P2 — Active Blockers -**Resolution:** Stashed the WIP changes, created `fix/run-silent-errors-51` branch from `origin/main`, implemented the fix correctly (add `"log"` to imports, log error in `run()` via `if err := cmd.Run(); err != nil`). PR #93 open. +### PR Review Queue (budget: 3/3) ---- +| PR | Title | CI | Status | +|----|-------|----|--------| +| #93 | fix run() silent errors (closes #51) | ✅ 5/5 | REVIEW REQUIRED | +| #95 | fix scheduler WriteFile silent error (closes #65) | ✅ 5/5 | REVIEW REQUIRED | +| #96 | fix cmdScan Glob→WalkDir (closes #52) | ⏳ pending | REVIEW REQUIRED | -## P2 — Active Blockers +**Action Required:** @jpleva91 review and merge PRs #93, #95, #96 to clear budget for remaining P2 sweep. -### PR Review Queue (budget: 2/3) -| PR | Title | Status | -|----|-------|--------| -| #91 | EM state update run 8 | CI green — REVIEW REQUIRED | -| #93 | fix run() silent errors (closes #51) | CI pending — REVIEW REQUIRED | +### #76 — Dogfood: setup.sh doesn't support remote Ollama (4th escalation) -**Action Required:** @jpleva91 review and merge PR #91 and PR #93. - -### #76 — Dogfood: setup.sh doesn't support remote Ollama (3rd escalation) **Severity:** Medium — dogfood on jared-box (headless WSL2 + RunPod GPU) blocked **Root cause:** `shellforge setup` detects `isServer=true` on headless Linux and skips Goose + Ollama entirely, with no option to configure `OLLAMA_HOST` for a remote GPU endpoint. **Fix needed:** setup.sh should offer remote Ollama config when `isServer=true` — set `OLLAMA_HOST`, skip local Ollama install, keep Goose setup. @@ -49,13 +38,11 @@ --- -## P2 — Queued (unassigned) +## P2 — Queued (unassigned, after budget clears) | # | Issue | Notes | |---|-------|-------| | #92 | Bundle Preflight in Goose bootstrap | Blocked on Preflight v1 ship | -| #65 | scheduler.go silent os.WriteFile error | Next EM fix after PR budget clears | -| #52 | filepath.Glob ** never matches Go files | Next EM fix — needs filepath.Walk | | #53 | README stale ./shellforge commands | Docs rot | | #50 | kernel version comparison lexicographic | setup.sh version gate broken | | #49 | InferenceQueue not priority-aware | Documented but unimplemented | @@ -65,14 +52,17 @@ --- -## Resolved (this cycle) +## Resolved (this cycle — run 11) -- **#68** — zero test coverage → merged PR #89 (25 tests for normalizer/governance/intent) -- **#66** — dead code in flattenParams() → fixed in PR #89 -- **#51** — run() helper silently ignores errors → PR #93 open +- **#52** — filepath.Glob ** never matches Go files → fixed with WalkDir in PR #96 +- **PR #94** — stale EM state PR (run 9 state was already on master at 832cb58) → closed ## Resolved (prior cycles) +- **#65** — scheduler.go silent WriteFile error → PR #95 open +- **#51** — run() helper silently ignores errors → PR #93 open +- **#68** — zero test coverage → merged PR #89 (25 tests) +- **#66** — dead code in flattenParams() → merged PR #89 - **#28** → PR #86 merged - **#63** → PR #88 merged - **#58, #62, #75, #67, #69** → PR #83 merged @@ -86,10 +76,11 @@ |------|--------| | P0 issues | ✅ All closed | | P1 issues | ✅ All closed | -| PR #91 (EM state run 8) | 🟡 CI green — REVIEW REQUIRED | -| PR #93 (fix #51) | 🟡 CI pending — REVIEW REQUIRED | -| Sprint goal | 🔵 Active — P2 sweep in progress | -| PR budget | 2/3 | -| Dogfood (#76) | 🔴 Blocked — setup.sh remote Ollama gap (3rd escalation) | +| PR #93 (fix #51 run() errors) | 🟡 CI green — REVIEW REQUIRED | +| PR #95 (fix #65 WriteFile) | 🟡 CI green — REVIEW REQUIRED | +| PR #96 (fix #52 Glob→WalkDir) | 🟡 CI pending — REVIEW REQUIRED | +| Sprint goal | 🔵 Active — P2 sweep 3/3 bugs fixed, all in PRs | +| PR budget | 3/3 (full — merge needed before new work) | +| Dogfood (#76) | 🔴 Blocked — setup.sh remote Ollama gap (4th escalation) | | Retry loops | None | | Blast radius | Low | From 8790a99eb628870ec0598bfe08c2f2b0749e3468 Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Tue, 31 Mar 2026 22:47:05 +0000 Subject: [PATCH 3/9] feat(shellforge): add LLM provider interface + Ollama implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Provider interface abstracts LLM backends (Ollama, Anthropic). OllamaProvider wraps existing ollama.Chat(). Agent loop accepts optional Provider in LoopConfig — nil falls back to legacy path. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/agent/loop.go | 62 ++++++++++++++++++-- internal/llm/ollama.go | 71 +++++++++++++++++++++++ internal/llm/provider.go | 37 ++++++++++++ internal/llm/provider_test.go | 105 ++++++++++++++++++++++++++++++++++ 4 files changed, 270 insertions(+), 5 deletions(-) create mode 100644 internal/llm/ollama.go create mode 100644 internal/llm/provider.go create mode 100644 internal/llm/provider_test.go diff --git a/internal/agent/loop.go b/internal/agent/loop.go index bba7973..dcf7179 100644 --- a/internal/agent/loop.go +++ b/internal/agent/loop.go @@ -18,6 +18,7 @@ import ( "github.com/AgentGuardHQ/shellforge/internal/correction" "github.com/AgentGuardHQ/shellforge/internal/governance" "github.com/AgentGuardHQ/shellforge/internal/intent" +"github.com/AgentGuardHQ/shellforge/internal/llm" "github.com/AgentGuardHQ/shellforge/internal/logger" "github.com/AgentGuardHQ/shellforge/internal/normalizer" "github.com/AgentGuardHQ/shellforge/internal/ollama" @@ -33,10 +34,12 @@ MaxTurns int TimeoutMs int OutputDir string TokenBudget int +Provider llm.Provider // optional; nil falls back to legacy ollama.Chat() } type RunResult struct { Success bool +ExitReason string // "final_answer", "timeout", "max_turns", "model_error" Output string Turns int ToolCalls int @@ -72,6 +75,8 @@ for turn := 1; turn <= cfg.MaxTurns; turn++ { elapsed := time.Since(start).Milliseconds() if int(elapsed) > cfg.TimeoutMs { logger.Agent(cfg.Agent, fmt.Sprintf("timeout after %d turns", turn-1)) +result.Turns = turn - 1 +result.ExitReason = "timeout" break } @@ -81,18 +86,42 @@ compacted := compactMessages(messages, cfg.TokenBudget) tokEst := estimateTokens(compacted) logger.Agent(cfg.Agent, fmt.Sprintf("turn %d/%d (~%d tokens)", turn, cfg.MaxTurns, tokEst)) +var content string +var promptTok, outputTok int +var totalDurMs int64 + +if cfg.Provider != nil { +llmMsgs := toProviderMessages(compacted) +provResp, perr := cfg.Provider.Chat(llmMsgs, nil) +if perr != nil { +logger.Error(cfg.Agent, cfg.Provider.Name()+": "+perr.Error()) +result.Output = "Model error: " + perr.Error() +result.Turns = turn +result.ExitReason = "model_error" +break +} +content = provResp.Content +promptTok = provResp.PromptTok +outputTok = provResp.OutputTok +} else { resp, err := ollama.Chat(compacted, cfg.Model) if err != nil { logger.Error(cfg.Agent, "ollama: "+err.Error()) result.Output = "Model error: " + err.Error() +result.Turns = turn +result.ExitReason = "model_error" break } +content = resp.Message.Content +promptTok = resp.PromptEval +outputTok = resp.EvalCount +totalDurMs = resp.TotalDuration / 1_000_000 +} -result.PromptTok += resp.PromptEval -result.ResponseTok += resp.EvalCount -logger.ModelCall(cfg.Agent, resp.PromptEval, resp.EvalCount, resp.TotalDuration/1_000_000) +result.PromptTok += promptTok +result.ResponseTok += outputTok +logger.ModelCall(cfg.Agent, promptTok, outputTok, totalDurMs) -content := resp.Message.Content messages = append(messages, ollama.ChatMessage{Role: "assistant", Content: content}) // ── Intent parser: extract action from ANY format ── @@ -103,6 +132,7 @@ if parsed == nil { // No actionable intent — this is a final answer. result.Output = content result.Turns = turn +result.ExitReason = "final_answer" logger.Agent(cfg.Agent, fmt.Sprintf("done — %d turns, %d tool calls", turn, result.ToolCalls)) break } @@ -193,6 +223,15 @@ messages = append(messages, ollama.ChatMessage{ Role: "user", Content: "You've used all turns. Give your final answer now.", }) +if cfg.Provider != nil { +finalMsgs := toProviderMessages(compactMessages(messages, cfg.TokenBudget)) +finalResp, ferr := cfg.Provider.Chat(finalMsgs, nil) +if ferr == nil { +result.Output = finalResp.Content +result.PromptTok += finalResp.PromptTok +result.ResponseTok += finalResp.OutputTok +} +} else { final, err := ollama.Chat(compactMessages(messages, cfg.TokenBudget), cfg.Model) if err == nil { result.Output = final.Message.Content @@ -200,10 +239,13 @@ result.PromptTok += final.PromptEval result.ResponseTok += final.EvalCount } } +result.Turns = turn +result.ExitReason = "max_turns" +} } result.DurationMs = time.Since(start).Milliseconds() -result.Success = result.Denials == 0 || result.ToolCalls > result.Denials +result.Success = result.ExitReason == "final_answer" result.Log = log return result, nil } @@ -289,3 +331,13 @@ return t } return f } + +// toProviderMessages converts []ollama.ChatMessage to []llm.Message for use +// with the Provider interface. +func toProviderMessages(msgs []ollama.ChatMessage) []llm.Message { +result := make([]llm.Message, len(msgs)) +for i, m := range msgs { +result[i] = llm.Message{Role: m.Role, Content: m.Content} +} +return result +} diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go new file mode 100644 index 0000000..d11d576 --- /dev/null +++ b/internal/llm/ollama.go @@ -0,0 +1,71 @@ +package llm + +import ( + "os" + + "github.com/AgentGuardHQ/shellforge/internal/ollama" +) + +// OllamaProvider wraps the existing ollama.Chat() function. +type OllamaProvider struct { + host string + model string +} + +// NewOllamaProvider creates an OllamaProvider targeting the given host and model. +// Pass empty strings to use the ollama package defaults. +func NewOllamaProvider(host, model string) *OllamaProvider { + return &OllamaProvider{host: host, model: model} +} + +// Name returns the provider identifier. +func (o *OllamaProvider) Name() string { + return "ollama" +} + +// Chat converts llm.Message slice to ollama.ChatMessage, calls ollama.Chat(), +// and converts the response back to llm.Response. +// Ollama does not support native tool-use, so ToolCalls is always nil. +// The tools parameter is ignored — Ollama uses text-based tool calling via prompt. +// Roles: "tool_result" is mapped to "user" since Ollama only understands "user". +func (o *OllamaProvider) Chat(messages []Message, tools []ToolDef) (*Response, error) { + // Override ollama package host if caller specified one. + if o.host != "" { + prev := ollama.Host + ollama.Host = o.host + defer func() { ollama.Host = prev }() + } + + ollamaMsgs := make([]ollama.ChatMessage, len(messages)) + for i, m := range messages { + role := m.Role + if role == "tool_result" { + role = "user" + } + ollamaMsgs[i] = ollama.ChatMessage{ + Role: role, + Content: m.Content, + } + } + + resp, err := ollama.Chat(ollamaMsgs, o.model) + if err != nil { + return nil, err + } + + return &Response{ + Content: resp.Message.Content, + ToolCalls: nil, + StopReason: "end_turn", + PromptTok: resp.PromptEval, + OutputTok: resp.EvalCount, + }, nil +} + +// envOr returns the value of the environment variable key, or fallback if unset. +func envOr(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} diff --git a/internal/llm/provider.go b/internal/llm/provider.go new file mode 100644 index 0000000..8f8a77a --- /dev/null +++ b/internal/llm/provider.go @@ -0,0 +1,37 @@ +package llm + +// Provider abstracts an LLM backend (Ollama, Anthropic, etc.). +type Provider interface { + Chat(messages []Message, tools []ToolDef) (*Response, error) + Name() string +} + +// Message is a conversation turn. +type Message struct { + Role string // "system", "user", "assistant", "tool_result" + Content string + ToolCallID string // set when Role == "tool_result" +} + +// ToolDef describes a tool the model can invoke. +type ToolDef struct { + Name string + Description string + Parameters map[string]any // JSON Schema +} + +// ToolCall is a model's request to invoke a tool. +type ToolCall struct { + ID string + Name string + Params map[string]string +} + +// Response is the model's reply to a Chat call. +type Response struct { + Content string + ToolCalls []ToolCall + StopReason string // "end_turn", "tool_use" + PromptTok int + OutputTok int +} diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go new file mode 100644 index 0000000..29ba334 --- /dev/null +++ b/internal/llm/provider_test.go @@ -0,0 +1,105 @@ +package llm + +import ( + "testing" +) + +func TestMessageFields(t *testing.T) { + m := Message{ + Role: "tool_result", + Content: "output here", + ToolCallID: "call_abc123", + } + if m.Role != "tool_result" { + t.Errorf("Role: got %q, want %q", m.Role, "tool_result") + } + if m.Content != "output here" { + t.Errorf("Content: got %q, want %q", m.Content, "output here") + } + if m.ToolCallID != "call_abc123" { + t.Errorf("ToolCallID: got %q, want %q", m.ToolCallID, "call_abc123") + } +} + +func TestToolDefFields(t *testing.T) { + td := ToolDef{ + Name: "read_file", + Description: "Reads a file from disk", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + }, + } + if td.Name != "read_file" { + t.Errorf("Name: got %q, want %q", td.Name, "read_file") + } + if td.Description == "" { + t.Error("Description should not be empty") + } + if td.Parameters == nil { + t.Error("Parameters should not be nil") + } +} + +func TestToolCallFields(t *testing.T) { + tc := ToolCall{ + ID: "call_1", + Name: "list_files", + Params: map[string]string{"directory": "."}, + } + if tc.ID != "call_1" { + t.Errorf("ID: got %q, want %q", tc.ID, "call_1") + } + if tc.Name != "list_files" { + t.Errorf("Name: got %q, want %q", tc.Name, "list_files") + } + if tc.Params["directory"] != "." { + t.Errorf("Params[directory]: got %q, want %q", tc.Params["directory"], ".") + } +} + +func TestResponseFields(t *testing.T) { + r := Response{ + Content: "here is the answer", + ToolCalls: nil, + StopReason: "end_turn", + PromptTok: 100, + OutputTok: 50, + } + if r.Content != "here is the answer" { + t.Errorf("Content: got %q, want %q", r.Content, "here is the answer") + } + if r.StopReason != "end_turn" { + t.Errorf("StopReason: got %q, want %q", r.StopReason, "end_turn") + } + if r.PromptTok != 100 { + t.Errorf("PromptTok: got %d, want %d", r.PromptTok, 100) + } + if r.OutputTok != 50 { + t.Errorf("OutputTok: got %d, want %d", r.OutputTok, 50) + } + if r.ToolCalls != nil { + t.Error("ToolCalls should be nil for end_turn") + } +} + +func TestOllamaProviderName(t *testing.T) { + p := NewOllamaProvider("http://localhost:11434", "qwen3:1.7b") + if p.Name() != "ollama" { + t.Errorf("Name: got %q, want %q", p.Name(), "ollama") + } +} + +func TestOllamaProviderNameEmptyArgs(t *testing.T) { + p := NewOllamaProvider("", "") + if p.Name() != "ollama" { + t.Errorf("Name: got %q, want %q", p.Name(), "ollama") + } +} + +func TestOllamaProviderImplementsProvider(t *testing.T) { + // Compile-time check: *OllamaProvider must satisfy Provider interface. + var _ Provider = (*OllamaProvider)(nil) +} From f332ef0b4ce4fbe88e46d0de819dcc4952048f55 Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Tue, 31 Mar 2026 22:49:45 +0000 Subject: [PATCH 4/9] feat(shellforge): add Anthropic API provider with native tool-use Implements llm.Provider for Anthropic Messages API using stdlib HTTP. Handles tool_use content blocks, tool_result responses, system prompt extraction, and token tracking. Tests use httptest mock server. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/llm/anthropic.go | 270 ++++++++++++++++++++++++++++++++ internal/llm/anthropic_test.go | 278 +++++++++++++++++++++++++++++++++ 2 files changed, 548 insertions(+) create mode 100644 internal/llm/anthropic.go create mode 100644 internal/llm/anthropic_test.go diff --git a/internal/llm/anthropic.go b/internal/llm/anthropic.go new file mode 100644 index 0000000..3915364 --- /dev/null +++ b/internal/llm/anthropic.go @@ -0,0 +1,270 @@ +package llm + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" +) + +const ( + defaultAnthropicModel = "claude-haiku-4-5-20251001" + defaultAnthropicBaseURL = "https://api.anthropic.com" + anthropicVersion = "2023-06-01" + anthropicMaxTokens = 4096 +) + +// AnthropicProvider calls the Anthropic Messages API via stdlib HTTP. +type AnthropicProvider struct { + apiKey string + model string + baseURL string + client *http.Client +} + +// NewAnthropicProvider creates an AnthropicProvider. +// Pass empty strings to fall back to ANTHROPIC_API_KEY and ANTHROPIC_MODEL env vars. +func NewAnthropicProvider(apiKey, model string) *AnthropicProvider { + if apiKey == "" { + apiKey = os.Getenv("ANTHROPIC_API_KEY") + } + if model == "" { + model = envOr("ANTHROPIC_MODEL", defaultAnthropicModel) + } + return &AnthropicProvider{ + apiKey: apiKey, + model: model, + baseURL: defaultAnthropicBaseURL, + client: &http.Client{Timeout: 5 * time.Minute}, + } +} + +// Name returns the provider identifier. +func (a *AnthropicProvider) Name() string { + return "anthropic" +} + +// ---------------------------------------------------------------------------- +// Anthropic API wire types +// ---------------------------------------------------------------------------- + +// anthropicContentBlock is a polymorphic content block in the Anthropic API. +type anthropicContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input map[string]any `json:"input,omitempty"` + ToolUseID string `json:"tool_use_id,omitempty"` + Content string `json:"content,omitempty"` +} + +// anthropicMessage is one turn in the Anthropic messages array. +// Content can be a plain string or an array of content blocks. +// We use json.RawMessage to handle both. +type anthropicMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` +} + +// anthropicToolDef is the Anthropic tool definition format. +type anthropicToolDef struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"input_schema"` +} + +// anthropicRequest is the full POST body for /v1/messages. +type anthropicRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` + Messages []anthropicMessage `json:"messages"` + Tools []anthropicToolDef `json:"tools,omitempty"` +} + +// anthropicResponse is the API response body. +type anthropicResponse struct { + ID string `json:"id"` + Content []anthropicContentBlock `json:"content"` + StopReason string `json:"stop_reason"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` +} + +// anthropicErrorResponse represents an API error body. +type anthropicErrorResponse struct { + Type string `json:"type"` + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` +} + +// ---------------------------------------------------------------------------- +// Chat implementation +// ---------------------------------------------------------------------------- + +// Chat sends messages to the Anthropic Messages API and returns the response. +func (a *AnthropicProvider) Chat(messages []Message, tools []ToolDef) (*Response, error) { + // Separate system messages from conversation messages. + var systemPrompt string + var convMsgs []Message + for _, m := range messages { + if m.Role == "system" { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += m.Content + } else { + convMsgs = append(convMsgs, m) + } + } + + // Convert conversation messages to Anthropic wire format. + apiMsgs, err := convertMessages(convMsgs) + if err != nil { + return nil, fmt.Errorf("anthropic: convert messages: %w", err) + } + + // Convert tool definitions. + apiTools := make([]anthropicToolDef, len(tools)) + for i, t := range tools { + schema := t.Parameters + if schema == nil { + schema = map[string]any{"type": "object", "properties": map[string]any{}} + } + apiTools[i] = anthropicToolDef{ + Name: t.Name, + Description: t.Description, + InputSchema: schema, + } + } + + reqBody := anthropicRequest{ + Model: a.model, + MaxTokens: anthropicMaxTokens, + System: systemPrompt, + Messages: apiMsgs, + } + if len(apiTools) > 0 { + reqBody.Tools = apiTools + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("anthropic: marshal request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, a.baseURL+"/v1/messages", bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("anthropic: create request: %w", err) + } + req.Header.Set("x-api-key", a.apiKey) + req.Header.Set("anthropic-version", anthropicVersion) + req.Header.Set("Content-Type", "application/json") + + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("anthropic: http request: %w", err) + } + defer resp.Body.Close() + + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("anthropic: read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var apiErr anthropicErrorResponse + if jsonErr := json.Unmarshal(respBytes, &apiErr); jsonErr == nil && apiErr.Error.Message != "" { + return nil, fmt.Errorf("anthropic: api error %d: %s", resp.StatusCode, apiErr.Error.Message) + } + return nil, fmt.Errorf("anthropic: http %d: %s", resp.StatusCode, string(respBytes)) + } + + var apiResp anthropicResponse + if err := json.Unmarshal(respBytes, &apiResp); err != nil { + return nil, fmt.Errorf("anthropic: unmarshal response: %w", err) + } + + return parseResponse(&apiResp), nil +} + +// ---------------------------------------------------------------------------- +// Helpers +// ---------------------------------------------------------------------------- + +// convertMessages converts llm.Message slice to Anthropic API message slice. +func convertMessages(messages []Message) ([]anthropicMessage, error) { + result := make([]anthropicMessage, 0, len(messages)) + for _, m := range messages { + var raw json.RawMessage + var err error + + switch m.Role { + case "tool_result": + // Must be wrapped in a user message with tool_result content block. + block := anthropicContentBlock{ + Type: "tool_result", + ToolUseID: m.ToolCallID, + Content: m.Content, + } + raw, err = json.Marshal([]anthropicContentBlock{block}) + if err != nil { + return nil, err + } + result = append(result, anthropicMessage{Role: "user", Content: raw}) + + case "assistant", "user": + // Plain text content block. + raw, err = json.Marshal(m.Content) + if err != nil { + return nil, err + } + result = append(result, anthropicMessage{Role: m.Role, Content: raw}) + + default: + // Skip unknown roles (e.g. "system" already extracted). + continue + } + } + return result, nil +} + +// parseResponse converts an anthropicResponse into an llm.Response. +func parseResponse(apiResp *anthropicResponse) *Response { + resp := &Response{ + StopReason: apiResp.StopReason, + PromptTok: apiResp.Usage.InputTokens, + OutputTok: apiResp.Usage.OutputTokens, + } + + for _, block := range apiResp.Content { + switch block.Type { + case "text": + if resp.Content != "" { + resp.Content += "\n" + } + resp.Content += block.Text + + case "tool_use": + params := make(map[string]string, len(block.Input)) + for k, v := range block.Input { + params[k] = fmt.Sprintf("%v", v) + } + resp.ToolCalls = append(resp.ToolCalls, ToolCall{ + ID: block.ID, + Name: block.Name, + Params: params, + }) + } + } + + return resp +} diff --git a/internal/llm/anthropic_test.go b/internal/llm/anthropic_test.go new file mode 100644 index 0000000..0f3de7f --- /dev/null +++ b/internal/llm/anthropic_test.go @@ -0,0 +1,278 @@ +package llm + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +// newAnthropicTestProvider creates an AnthropicProvider pointing at the given mock server URL. +func newAnthropicTestProvider(serverURL string) *AnthropicProvider { + p := NewAnthropicProvider("test-api-key", "test-model") + p.baseURL = serverURL + return p +} + +// mustMarshal marshals v to JSON and panics on error (test helper). +func mustMarshal(v any) []byte { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return b +} + +// mockServer creates an httptest.Server that returns the provided JSON body for +// all POST requests to /v1/messages. +func mockServer(t *testing.T, respBody any) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/v1/messages" { + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + http.Error(w, "not found", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(mustMarshal(respBody)) + })) +} + +// captureServer creates a mock server that captures the decoded request body +// and returns the provided response. +func captureServer(t *testing.T, captured *anthropicRequest, respBody any) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + } + if err := json.Unmarshal(body, captured); err != nil { + t.Errorf("decode body: %v", err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(mustMarshal(respBody)) + })) +} + +// --------------------------------------------------------------------------- +// Test 1: Name() +// --------------------------------------------------------------------------- + +func TestAnthropicProviderName(t *testing.T) { + p := NewAnthropicProvider("key", "model") + if got := p.Name(); got != "anthropic" { + t.Errorf("Name() = %q, want %q", got, "anthropic") + } +} + +// --------------------------------------------------------------------------- +// Test 2: Chat end_turn — text response +// --------------------------------------------------------------------------- + +func TestAnthropicChat_EndTurn(t *testing.T) { + apiResp := anthropicResponse{ + ID: "msg_001", + Content: []anthropicContentBlock{ + {Type: "text", Text: "Hello, world!"}, + }, + StopReason: "end_turn", + } + apiResp.Usage.InputTokens = 10 + apiResp.Usage.OutputTokens = 5 + + srv := mockServer(t, apiResp) + defer srv.Close() + + p := newAnthropicTestProvider(srv.URL) + resp, err := p.Chat([]Message{{Role: "user", Content: "Hi"}}, nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello, world!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello, world!") + } + if resp.StopReason != "end_turn" { + t.Errorf("StopReason = %q, want %q", resp.StopReason, "end_turn") + } + if resp.PromptTok != 10 { + t.Errorf("PromptTok = %d, want 10", resp.PromptTok) + } + if resp.OutputTok != 5 { + t.Errorf("OutputTok = %d, want 5", resp.OutputTok) + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls)) + } +} + +// --------------------------------------------------------------------------- +// Test 3: Chat tool_use — ToolCalls populated +// --------------------------------------------------------------------------- + +func TestAnthropicChat_ToolUse(t *testing.T) { + apiResp := anthropicResponse{ + ID: "msg_002", + Content: []anthropicContentBlock{ + { + Type: "tool_use", + ID: "tc_1", + Name: "read_file", + Input: map[string]any{"path": "/tmp/test.txt"}, + }, + }, + StopReason: "tool_use", + } + apiResp.Usage.InputTokens = 20 + apiResp.Usage.OutputTokens = 8 + + srv := mockServer(t, apiResp) + defer srv.Close() + + p := newAnthropicTestProvider(srv.URL) + resp, err := p.Chat([]Message{{Role: "user", Content: "Read that file"}}, []ToolDef{ + { + Name: "read_file", + Description: "Read a file from disk", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + }, + }, + }) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.StopReason != "tool_use" { + t.Errorf("StopReason = %q, want %q", resp.StopReason, "tool_use") + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(resp.ToolCalls)) + } + tc := resp.ToolCalls[0] + if tc.ID != "tc_1" { + t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "tc_1") + } + if tc.Name != "read_file" { + t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "read_file") + } + if tc.Params["path"] != "/tmp/test.txt" { + t.Errorf("ToolCall.Params[path] = %q, want %q", tc.Params["path"], "/tmp/test.txt") + } +} + +// --------------------------------------------------------------------------- +// Test 4: System prompt extraction +// --------------------------------------------------------------------------- + +func TestAnthropicChat_SystemPrompt(t *testing.T) { + apiResp := anthropicResponse{ + ID: "msg_003", + Content: []anthropicContentBlock{{Type: "text", Text: "ok"}}, + StopReason: "end_turn", + } + + var captured anthropicRequest + srv := captureServer(t, &captured, apiResp) + defer srv.Close() + + p := newAnthropicTestProvider(srv.URL) + messages := []Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello"}, + } + _, err := p.Chat(messages, nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + + if captured.System != "You are a helpful assistant." { + t.Errorf("System = %q, want %q", captured.System, "You are a helpful assistant.") + } + + // The system message should NOT appear in the messages array. + for _, m := range captured.Messages { + if m.Role == "system" { + t.Error("system role message found in messages array — should be in System field only") + } + } + + // The user message should still be present. + if len(captured.Messages) != 1 { + t.Errorf("len(Messages) = %d, want 1", len(captured.Messages)) + } +} + +// --------------------------------------------------------------------------- +// Test 5: tool_result formatting +// --------------------------------------------------------------------------- + +func TestAnthropicChat_ToolResult(t *testing.T) { + apiResp := anthropicResponse{ + ID: "msg_004", + Content: []anthropicContentBlock{{Type: "text", Text: "Done."}}, + StopReason: "end_turn", + } + + var captured anthropicRequest + srv := captureServer(t, &captured, apiResp) + defer srv.Close() + + p := newAnthropicTestProvider(srv.URL) + messages := []Message{ + {Role: "user", Content: "Read the file"}, + {Role: "assistant", Content: "Sure"}, + { + Role: "tool_result", + Content: "file contents here", + ToolCallID: "tc_42", + }, + } + _, err := p.Chat(messages, nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + + // We expect 3 messages: user, assistant, user(tool_result). + if len(captured.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(captured.Messages)) + } + + // The last message must be role "user" (tool_result is wrapped as user). + last := captured.Messages[2] + if last.Role != "user" { + t.Errorf("last message Role = %q, want %q", last.Role, "user") + } + + // Decode the content blocks. + var blocks []anthropicContentBlock + if err := json.Unmarshal(last.Content, &blocks); err != nil { + t.Fatalf("decode last message content: %v", err) + } + if len(blocks) != 1 { + t.Fatalf("len(blocks) = %d, want 1", len(blocks)) + } + block := blocks[0] + if block.Type != "tool_result" { + t.Errorf("block.Type = %q, want %q", block.Type, "tool_result") + } + if block.ToolUseID != "tc_42" { + t.Errorf("block.ToolUseID = %q, want %q", block.ToolUseID, "tc_42") + } + if block.Content != "file contents here" { + t.Errorf("block.Content = %q, want %q", block.Content, "file contents here") + } +} + +// --------------------------------------------------------------------------- +// Compile-time interface check +// --------------------------------------------------------------------------- + +func TestAnthropicProviderImplementsProvider(t *testing.T) { + var _ Provider = (*AnthropicProvider)(nil) +} From 12abca3ebdf9c0c0ea36e73f025a8036084aa1d7 Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Tue, 31 Mar 2026 22:54:49 +0000 Subject: [PATCH 5/9] feat(shellforge): wire native tool-use into agent loop for Anthropic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When Provider returns structured ToolCalls, handle them directly: normalize → governance → execute → tool_result feedback. Bypasses text-based intent parsing. Uses []llm.Message internally when Provider is set. Ollama/legacy path unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/agent/loop.go | 708 +++++++++++++++++++++++------------- internal/agent/loop_test.go | 349 ++++++++++++++++++ 2 files changed, 801 insertions(+), 256 deletions(-) create mode 100644 internal/agent/loop_test.go diff --git a/internal/agent/loop.go b/internal/agent/loop.go index dcf7179..195d9e3 100644 --- a/internal/agent/loop.go +++ b/internal/agent/loop.go @@ -5,258 +5,395 @@ // bare JSON, OpenAI function_call format). Every extracted action goes // through AgentGuard governance — no exceptions. // +// When a Provider with native tool-use is set (e.g. Anthropic), the loop +// uses structured ToolCalls directly instead of text-based intent parsing. +// // This is the core of ShellForge's moat: you cannot trust the transport // layer for action integrity. The intent parser makes ShellForge // model-agnostic and format-agnostic. package agent import ( -"fmt" -"time" - -"github.com/AgentGuardHQ/shellforge/internal/action" -"github.com/AgentGuardHQ/shellforge/internal/correction" -"github.com/AgentGuardHQ/shellforge/internal/governance" -"github.com/AgentGuardHQ/shellforge/internal/intent" -"github.com/AgentGuardHQ/shellforge/internal/llm" -"github.com/AgentGuardHQ/shellforge/internal/logger" -"github.com/AgentGuardHQ/shellforge/internal/normalizer" -"github.com/AgentGuardHQ/shellforge/internal/ollama" -"github.com/AgentGuardHQ/shellforge/internal/tools" + "fmt" + "time" + + "github.com/AgentGuardHQ/shellforge/internal/action" + "github.com/AgentGuardHQ/shellforge/internal/correction" + "github.com/AgentGuardHQ/shellforge/internal/governance" + "github.com/AgentGuardHQ/shellforge/internal/intent" + "github.com/AgentGuardHQ/shellforge/internal/llm" + "github.com/AgentGuardHQ/shellforge/internal/logger" + "github.com/AgentGuardHQ/shellforge/internal/normalizer" + "github.com/AgentGuardHQ/shellforge/internal/ollama" + "github.com/AgentGuardHQ/shellforge/internal/tools" ) type LoopConfig struct { -Agent string -System string -UserPrompt string -Model string -MaxTurns int -TimeoutMs int -OutputDir string -TokenBudget int -Provider llm.Provider // optional; nil falls back to legacy ollama.Chat() + Agent string + System string + UserPrompt string + Model string + MaxTurns int + TimeoutMs int + OutputDir string + TokenBudget int + Provider llm.Provider // optional; nil falls back to legacy ollama.Chat() } type RunResult struct { -Success bool -ExitReason string // "final_answer", "timeout", "max_turns", "model_error" -Output string -Turns int -ToolCalls int -Denials int -PromptTok int -ResponseTok int -DurationMs int64 -Log []string + Success bool + ExitReason string // "final_answer", "timeout", "max_turns", "model_error" + Output string + Turns int + ToolCalls int + Denials int + PromptTok int + ResponseTok int + DurationMs int64 + Log []string } func RunLoop(cfg LoopConfig, engine *governance.Engine) (*RunResult, error) { -start := time.Now() -logger.Init(cfg.OutputDir, cfg.Agent) -defer logger.Close() - -// Orchestrator integration: generate run identity and correction engine. -runID := fmt.Sprintf("run_%d", time.Now().UnixMilli()) -var seq int -corrector := correction.NewEngine(3, 10) // 3 retries per action, 10 total budget - -systemPrompt := buildSystemPrompt(cfg.System) -messages := []ollama.ChatMessage{ -{Role: "system", Content: systemPrompt}, -{Role: "user", Content: cfg.UserPrompt}, -} - -result := &RunResult{} -var log []string - -logger.Agent(cfg.Agent, fmt.Sprintf("starting — max %d turns, model: %s, run: %s", cfg.MaxTurns, cfg.Model, runID)) - -for turn := 1; turn <= cfg.MaxTurns; turn++ { -elapsed := time.Since(start).Milliseconds() -if int(elapsed) > cfg.TimeoutMs { -logger.Agent(cfg.Agent, fmt.Sprintf("timeout after %d turns", turn-1)) -result.Turns = turn - 1 -result.ExitReason = "timeout" -break -} - -// Compact if needed -compacted := compactMessages(messages, cfg.TokenBudget) - -tokEst := estimateTokens(compacted) -logger.Agent(cfg.Agent, fmt.Sprintf("turn %d/%d (~%d tokens)", turn, cfg.MaxTurns, tokEst)) - -var content string -var promptTok, outputTok int -var totalDurMs int64 - -if cfg.Provider != nil { -llmMsgs := toProviderMessages(compacted) -provResp, perr := cfg.Provider.Chat(llmMsgs, nil) -if perr != nil { -logger.Error(cfg.Agent, cfg.Provider.Name()+": "+perr.Error()) -result.Output = "Model error: " + perr.Error() -result.Turns = turn -result.ExitReason = "model_error" -break -} -content = provResp.Content -promptTok = provResp.PromptTok -outputTok = provResp.OutputTok -} else { -resp, err := ollama.Chat(compacted, cfg.Model) -if err != nil { -logger.Error(cfg.Agent, "ollama: "+err.Error()) -result.Output = "Model error: " + err.Error() -result.Turns = turn -result.ExitReason = "model_error" -break -} -content = resp.Message.Content -promptTok = resp.PromptEval -outputTok = resp.EvalCount -totalDurMs = resp.TotalDuration / 1_000_000 + start := time.Now() + logger.Init(cfg.OutputDir, cfg.Agent) + defer logger.Close() + + if cfg.Provider != nil { + return runProviderLoop(cfg, engine, start) + } + return runOllamaLoop(cfg, engine, start) } -result.PromptTok += promptTok -result.ResponseTok += outputTok -logger.ModelCall(cfg.Agent, promptTok, outputTok, totalDurMs) - -messages = append(messages, ollama.ChatMessage{Role: "assistant", Content: content}) - -// ── Intent parser: extract action from ANY format ── -// This is the format-agnostic layer. Works regardless of whether the model -// emits structured tool_calls, JSON blocks, XML tags, or bare JSON. -parsed := intent.Parse(content) -if parsed == nil { -// No actionable intent — this is a final answer. -result.Output = content -result.Turns = turn -result.ExitReason = "final_answer" -logger.Agent(cfg.Agent, fmt.Sprintf("done — %d turns, %d tool calls", turn, result.ToolCalls)) -break +// ---------------------------------------------------------------------------- +// Provider path: uses []llm.Message + native tool-use +// ---------------------------------------------------------------------------- + +func runProviderLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) (*RunResult, error) { + runID := fmt.Sprintf("run_%d", time.Now().UnixMilli()) + var seq int + corrector := correction.NewEngine(3, 10) + + systemPrompt := buildSystemPrompt(cfg.System) + messages := []llm.Message{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: cfg.UserPrompt}, + } + + toolDefs := buildToolDefs() + + result := &RunResult{} + var log []string + + logger.Agent(cfg.Agent, fmt.Sprintf("starting — max %d turns, provider: %s, run: %s", cfg.MaxTurns, cfg.Provider.Name(), runID)) + + for turn := 1; turn <= cfg.MaxTurns; turn++ { + elapsed := time.Since(start).Milliseconds() + if int(elapsed) > cfg.TimeoutMs { + logger.Agent(cfg.Agent, fmt.Sprintf("timeout after %d turns", turn-1)) + result.Turns = turn - 1 + result.ExitReason = "timeout" + break + } + + compacted := compactLLMMessages(messages, cfg.TokenBudget) + + tokEst := estimateLLMTokens(compacted) + logger.Agent(cfg.Agent, fmt.Sprintf("turn %d/%d (~%d tokens)", turn, cfg.MaxTurns, tokEst)) + + provResp, perr := cfg.Provider.Chat(compacted, toolDefs) + if perr != nil { + logger.Error(cfg.Agent, cfg.Provider.Name()+": "+perr.Error()) + result.Output = "Model error: " + perr.Error() + result.Turns = turn + result.ExitReason = "model_error" + break + } + + result.PromptTok += provResp.PromptTok + result.ResponseTok += provResp.OutputTok + logger.ModelCall(cfg.Agent, provResp.PromptTok, provResp.OutputTok, 0) + + // ── Native tool-use path ── + if len(provResp.ToolCalls) > 0 { + // Append the assistant message (may contain text + tool_use blocks). + messages = append(messages, llm.Message{Role: "assistant", Content: provResp.Content}) + + for _, tc := range provResp.ToolCalls { + logger.Agent(cfg.Agent, fmt.Sprintf("tool_use: %s (id: %s)", tc.Name, tc.ID)) + + result.ToolCalls++ + seq++ + + // ── Normalizer: convert to Canonical Action Representation ── + proposal := normalizer.Normalize(runID, seq, cfg.Agent, tc.Name, tc.Params) + fp := normalizer.Fingerprint(proposal) + + // ── Correction engine: check retry budget ── + canAttempt, skipReason := corrector.ShouldCorrect(fp) + if !canAttempt { + logger.Agent(cfg.Agent, fmt.Sprintf("action skipped: %s", skipReason)) + messages = append(messages, llm.Message{ + Role: "tool_result", + Content: fmt.Sprintf("Tool %q was skipped: %s. Try a different approach.", tc.Name, skipReason), + ToolCallID: tc.ID, + }) + log = append(log, fmt.Sprintf("[turn %d] %s → skipped (%s)", turn, tc.Name, skipReason)) + continue + } + + // ── Governance evaluation ── + decision := engine.Evaluate(tc.Name, tc.Params) + + if !decision.Allowed { + result.Denials++ + + govDecision := action.GovernanceDecision{ + Allowed: false, + Decision: "deny", + Reason: decision.Reason, + Rule: decision.PolicyName, + } + + corrector.RecordDenial(fp, govDecision) + logger.Governance(cfg.Agent, tc.Name, tc.Params, decision.Allowed, decision.PolicyName, decision.Reason) + + canCorrect, _ := corrector.ShouldCorrect(fp) + var feedback string + if canCorrect { + feedback = corrector.BuildFeedback(proposal, govDecision) + logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — sending correction feedback (escalation: %s)", tc.Name, corrector.Level())) + } else { + feedback = fmt.Sprintf("Tool %q was denied and cannot be retried. Move on to a different approach.", tc.Name) + logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — no retries left, skipping", tc.Name)) + } + + messages = append(messages, llm.Message{ + Role: "tool_result", + Content: feedback, + ToolCallID: tc.ID, + }) + log = append(log, fmt.Sprintf("[turn %d] %s → denied (%s)", turn, tc.Name, decision.PolicyName)) + continue + } + + // ── Governance allowed: execute tool ── + logger.Governance(cfg.Agent, tc.Name, tc.Params, decision.Allowed, decision.PolicyName, decision.Reason) + toolResult := tools.ExecuteDirect(tc.Name, tc.Params, engine.GetTimeout()) + logger.ToolResult(cfg.Agent, tc.Name, toolResult.Success, toolResult.Output) + + var msg string + if toolResult.Success { + msg = fmt.Sprintf("Tool %q returned:\n%s", tc.Name, toolResult.Output) + } else { + msg = fmt.Sprintf("Tool %q failed: %s", tc.Name, toolResult.Output) + } + messages = append(messages, llm.Message{ + Role: "tool_result", + Content: msg, + ToolCallID: tc.ID, + }) + + log = append(log, fmt.Sprintf("[turn %d] %s → %s", turn, tc.Name, boolStr(toolResult.Success, "ok", "fail"))) + } + + // Last turn: force final answer + if turn == cfg.MaxTurns { + messages = append(messages, llm.Message{ + Role: "user", + Content: "You've used all turns. Give your final answer now.", + }) + finalMsgs := compactLLMMessages(messages, cfg.TokenBudget) + finalResp, ferr := cfg.Provider.Chat(finalMsgs, toolDefs) + if ferr == nil { + result.Output = finalResp.Content + result.PromptTok += finalResp.PromptTok + result.ResponseTok += finalResp.OutputTok + } + result.Turns = turn + result.ExitReason = "max_turns" + } + continue + } + + // ── No tool calls: final answer (end_turn) ── + result.Output = provResp.Content + result.Turns = turn + result.ExitReason = "final_answer" + logger.Agent(cfg.Agent, fmt.Sprintf("done — %d turns, %d tool calls", turn, result.ToolCalls)) + break + } + + result.DurationMs = time.Since(start).Milliseconds() + result.Success = result.ExitReason == "final_answer" + result.Log = log + return result, nil } -logger.Agent(cfg.Agent, fmt.Sprintf("intent: %s via %s", parsed.Tool, parsed.Source)) - -result.ToolCalls++ -seq++ - -// ── Normalizer: convert extracted intent to Canonical Action Representation ── -proposal := normalizer.Normalize(runID, seq, cfg.Agent, parsed.Tool, parsed.Params) -fp := normalizer.Fingerprint(proposal) - -// ── Correction engine: check if this action should be retried or skipped ── -canAttempt, skipReason := corrector.ShouldCorrect(fp) -if !canAttempt { -// Too many retries or in lockdown — skip this action entirely. -logger.Agent(cfg.Agent, fmt.Sprintf("action skipped: %s", skipReason)) -messages = append(messages, ollama.ChatMessage{ -Role: "user", -Content: fmt.Sprintf("Tool %q was skipped: %s. Try a different approach.", parsed.Tool, skipReason), -}) -summary := fmt.Sprintf("[turn %d] %s → skipped (%s)", turn, parsed.Tool, skipReason) -log = append(log, summary) -continue -} - -// ── Governance evaluation (existing) ── -decision := engine.Evaluate(parsed.Tool, parsed.Params) - -if !decision.Allowed { -result.Denials++ - -// Map governance.Decision to action.GovernanceDecision for the correction engine. -govDecision := action.GovernanceDecision{ -Allowed: false, -Decision: "deny", -Reason: decision.Reason, -Rule: decision.PolicyName, -} - -// Record denial and attempt correction. -corrector.RecordDenial(fp, govDecision) -logger.Governance(cfg.Agent, parsed.Tool, parsed.Params, decision.Allowed, decision.PolicyName, decision.Reason) - -canCorrect, _ := corrector.ShouldCorrect(fp) -if canCorrect { -// Build corrective feedback and feed it back to the LLM. -feedback := corrector.BuildFeedback(proposal, govDecision) -logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — sending correction feedback (escalation: %s)", parsed.Tool, corrector.Level())) -messages = append(messages, ollama.ChatMessage{ -Role: "user", -Content: feedback, -}) -} else { -// Exhausted retries — skip and inform the LLM. -logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — no retries left, skipping", parsed.Tool)) -messages = append(messages, ollama.ChatMessage{ -Role: "user", -Content: fmt.Sprintf("Tool %q was denied and cannot be retried. Move on to a different approach.", parsed.Tool), -}) -} - -summary := fmt.Sprintf("[turn %d] %s → denied (%s)", turn, parsed.Tool, decision.PolicyName) -log = append(log, summary) -continue -} - -// ── Governance allowed: log and execute tool ── -logger.Governance(cfg.Agent, parsed.Tool, parsed.Params, decision.Allowed, decision.PolicyName, decision.Reason) -toolResult := tools.ExecuteDirect(parsed.Tool, parsed.Params, engine.GetTimeout()) -logger.ToolResult(cfg.Agent, parsed.Tool, toolResult.Success, toolResult.Output) - -var msg string -if toolResult.Success { -msg = fmt.Sprintf("Tool %q returned:\n%s", parsed.Tool, toolResult.Output) -} else { -msg = fmt.Sprintf("Tool %q failed: %s", parsed.Tool, toolResult.Output) -} -messages = append(messages, ollama.ChatMessage{Role: "user", Content: msg}) - -summary := fmt.Sprintf("[turn %d] %s → %s", turn, parsed.Tool, boolStr(toolResult.Success, "ok", "fail")) -log = append(log, summary) - -// Last turn: force final answer -if turn == cfg.MaxTurns { -messages = append(messages, ollama.ChatMessage{ -Role: "user", -Content: "You've used all turns. Give your final answer now.", -}) -if cfg.Provider != nil { -finalMsgs := toProviderMessages(compactMessages(messages, cfg.TokenBudget)) -finalResp, ferr := cfg.Provider.Chat(finalMsgs, nil) -if ferr == nil { -result.Output = finalResp.Content -result.PromptTok += finalResp.PromptTok -result.ResponseTok += finalResp.OutputTok -} -} else { -final, err := ollama.Chat(compactMessages(messages, cfg.TokenBudget), cfg.Model) -if err == nil { -result.Output = final.Message.Content -result.PromptTok += final.PromptEval -result.ResponseTok += final.EvalCount -} -} -result.Turns = turn -result.ExitReason = "max_turns" -} -} - -result.DurationMs = time.Since(start).Milliseconds() -result.Success = result.ExitReason == "final_answer" -result.Log = log -return result, nil +// ---------------------------------------------------------------------------- +// Ollama/legacy path: uses []ollama.ChatMessage + intent.Parse() +// ---------------------------------------------------------------------------- + +func runOllamaLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) (*RunResult, error) { + runID := fmt.Sprintf("run_%d", time.Now().UnixMilli()) + var seq int + corrector := correction.NewEngine(3, 10) + + systemPrompt := buildSystemPrompt(cfg.System) + messages := []ollama.ChatMessage{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: cfg.UserPrompt}, + } + + result := &RunResult{} + var log []string + + logger.Agent(cfg.Agent, fmt.Sprintf("starting — max %d turns, model: %s, run: %s", cfg.MaxTurns, cfg.Model, runID)) + + for turn := 1; turn <= cfg.MaxTurns; turn++ { + elapsed := time.Since(start).Milliseconds() + if int(elapsed) > cfg.TimeoutMs { + logger.Agent(cfg.Agent, fmt.Sprintf("timeout after %d turns", turn-1)) + result.Turns = turn - 1 + result.ExitReason = "timeout" + break + } + + compacted := compactMessages(messages, cfg.TokenBudget) + + tokEst := estimateTokens(compacted) + logger.Agent(cfg.Agent, fmt.Sprintf("turn %d/%d (~%d tokens)", turn, cfg.MaxTurns, tokEst)) + + resp, err := ollama.Chat(compacted, cfg.Model) + if err != nil { + logger.Error(cfg.Agent, "ollama: "+err.Error()) + result.Output = "Model error: " + err.Error() + result.Turns = turn + result.ExitReason = "model_error" + break + } + + content := resp.Message.Content + promptTok := resp.PromptEval + outputTok := resp.EvalCount + totalDurMs := resp.TotalDuration / 1_000_000 + + result.PromptTok += promptTok + result.ResponseTok += outputTok + logger.ModelCall(cfg.Agent, promptTok, outputTok, totalDurMs) + + messages = append(messages, ollama.ChatMessage{Role: "assistant", Content: content}) + + // ── Intent parser: extract action from ANY format ── + parsed := intent.Parse(content) + if parsed == nil { + result.Output = content + result.Turns = turn + result.ExitReason = "final_answer" + logger.Agent(cfg.Agent, fmt.Sprintf("done — %d turns, %d tool calls", turn, result.ToolCalls)) + break + } + + logger.Agent(cfg.Agent, fmt.Sprintf("intent: %s via %s", parsed.Tool, parsed.Source)) + + result.ToolCalls++ + seq++ + + proposal := normalizer.Normalize(runID, seq, cfg.Agent, parsed.Tool, parsed.Params) + fp := normalizer.Fingerprint(proposal) + + canAttempt, skipReason := corrector.ShouldCorrect(fp) + if !canAttempt { + logger.Agent(cfg.Agent, fmt.Sprintf("action skipped: %s", skipReason)) + messages = append(messages, ollama.ChatMessage{ + Role: "user", + Content: fmt.Sprintf("Tool %q was skipped: %s. Try a different approach.", parsed.Tool, skipReason), + }) + summary := fmt.Sprintf("[turn %d] %s → skipped (%s)", turn, parsed.Tool, skipReason) + log = append(log, summary) + continue + } + + decision := engine.Evaluate(parsed.Tool, parsed.Params) + + if !decision.Allowed { + result.Denials++ + + govDecision := action.GovernanceDecision{ + Allowed: false, + Decision: "deny", + Reason: decision.Reason, + Rule: decision.PolicyName, + } + + corrector.RecordDenial(fp, govDecision) + logger.Governance(cfg.Agent, parsed.Tool, parsed.Params, decision.Allowed, decision.PolicyName, decision.Reason) + + canCorrect, _ := corrector.ShouldCorrect(fp) + if canCorrect { + feedback := corrector.BuildFeedback(proposal, govDecision) + logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — sending correction feedback (escalation: %s)", parsed.Tool, corrector.Level())) + messages = append(messages, ollama.ChatMessage{ + Role: "user", + Content: feedback, + }) + } else { + logger.Agent(cfg.Agent, fmt.Sprintf("governance denied %q — no retries left, skipping", parsed.Tool)) + messages = append(messages, ollama.ChatMessage{ + Role: "user", + Content: fmt.Sprintf("Tool %q was denied and cannot be retried. Move on to a different approach.", parsed.Tool), + }) + } + + summary := fmt.Sprintf("[turn %d] %s → denied (%s)", turn, parsed.Tool, decision.PolicyName) + log = append(log, summary) + continue + } + + logger.Governance(cfg.Agent, parsed.Tool, parsed.Params, decision.Allowed, decision.PolicyName, decision.Reason) + toolResult := tools.ExecuteDirect(parsed.Tool, parsed.Params, engine.GetTimeout()) + logger.ToolResult(cfg.Agent, parsed.Tool, toolResult.Success, toolResult.Output) + + var msg string + if toolResult.Success { + msg = fmt.Sprintf("Tool %q returned:\n%s", parsed.Tool, toolResult.Output) + } else { + msg = fmt.Sprintf("Tool %q failed: %s", parsed.Tool, toolResult.Output) + } + messages = append(messages, ollama.ChatMessage{Role: "user", Content: msg}) + + summary := fmt.Sprintf("[turn %d] %s → %s", turn, parsed.Tool, boolStr(toolResult.Success, "ok", "fail")) + log = append(log, summary) + + // Last turn: force final answer + if turn == cfg.MaxTurns { + messages = append(messages, ollama.ChatMessage{ + Role: "user", + Content: "You've used all turns. Give your final answer now.", + }) + final, err := ollama.Chat(compactMessages(messages, cfg.TokenBudget), cfg.Model) + if err == nil { + result.Output = final.Message.Content + result.PromptTok += final.PromptEval + result.ResponseTok += final.EvalCount + } + result.Turns = turn + result.ExitReason = "max_turns" + } + } + + result.DurationMs = time.Since(start).Milliseconds() + result.Success = result.ExitReason == "final_answer" + result.Log = log + return result, nil } // Old parseToolCall/tryParse removed — replaced by intent.Parse() // which handles all formats: JSON blocks, XML tags, bare JSON, -// OpenAI function_call, and tool name/param aliasing. +// OpenAI function_call format, and tool name/param aliasing. func buildSystemPrompt(base string) string { -toolDocs := tools.FormatForPrompt() -return base + ` + toolDocs := tools.FormatForPrompt() + return base + ` ## Tools @@ -297,47 +434,106 @@ CORRECT: } func compactMessages(msgs []ollama.ChatMessage, budget int) []ollama.ChatMessage { -if budget <= 0 { -budget = 3000 + if budget <= 0 { + budget = 3000 + } + total := estimateTokens(msgs) + if total <= budget { + return msgs + } + + // Keep system (0), first user (1), and last N messages + result := []ollama.ChatMessage{msgs[0], msgs[1]} + remaining := msgs[2:] + + // Drop tool results from the middle until we fit + for total > budget && len(remaining) > 4 { + remaining = remaining[2:] // drop oldest assistant+tool pair + total = estimateTokens(append(result, remaining...)) + } + return append(result, remaining...) } -total := estimateTokens(msgs) -if total <= budget { -return msgs -} - -// Keep system (0), first user (1), and last N messages -result := []ollama.ChatMessage{msgs[0], msgs[1]} -remaining := msgs[2:] -// Drop tool results from the middle until we fit -for total > budget && len(remaining) > 4 { -remaining = remaining[2:] // drop oldest assistant+tool pair -total = estimateTokens(append(result, remaining...)) -} -return append(result, remaining...) +func estimateTokens(msgs []ollama.ChatMessage) int { + total := 0 + for _, m := range msgs { + total += len(m.Content) / 4 // rough approximation + } + return total } -func estimateTokens(msgs []ollama.ChatMessage) int { -total := 0 -for _, m := range msgs { -total += len(m.Content) / 4 // rough approximation +// compactLLMMessages is the []llm.Message equivalent of compactMessages. +func compactLLMMessages(msgs []llm.Message, budget int) []llm.Message { + if budget <= 0 { + budget = 3000 + } + total := estimateLLMTokens(msgs) + if total <= budget { + return msgs + } + + result := []llm.Message{msgs[0], msgs[1]} + remaining := msgs[2:] + + for total > budget && len(remaining) > 4 { + remaining = remaining[2:] + total = estimateLLMTokens(append(result, remaining...)) + } + return append(result, remaining...) } -return total + +// estimateLLMTokens estimates token count for []llm.Message. +func estimateLLMTokens(msgs []llm.Message) int { + total := 0 + for _, m := range msgs { + total += len(m.Content) / 4 + } + return total } func boolStr(b bool, t, f string) string { -if b { -return t -} -return f + if b { + return t + } + return f } // toProviderMessages converts []ollama.ChatMessage to []llm.Message for use // with the Provider interface. func toProviderMessages(msgs []ollama.ChatMessage) []llm.Message { -result := make([]llm.Message, len(msgs)) -for i, m := range msgs { -result[i] = llm.Message{Role: m.Role, Content: m.Content} + result := make([]llm.Message, len(msgs)) + for i, m := range msgs { + result[i] = llm.Message{Role: m.Role, Content: m.Content} + } + return result } -return result + +// buildToolDefs converts tools.Definitions into []llm.ToolDef for the Provider. +func buildToolDefs() []llm.ToolDef { + defs := make([]llm.ToolDef, len(tools.Definitions)) + for i, d := range tools.Definitions { + // Build JSON Schema from Param definitions. + properties := make(map[string]any, len(d.Params)) + required := make([]string, 0, len(d.Params)) + for _, p := range d.Params { + properties[p.Name] = map[string]any{ + "type": p.Type, + "description": p.Desc, + } + if p.Required { + required = append(required, p.Name) + } + } + + defs[i] = llm.ToolDef{ + Name: d.Name, + Description: d.Description, + Parameters: map[string]any{ + "type": "object", + "properties": properties, + "required": required, + }, + } + } + return defs } diff --git a/internal/agent/loop_test.go b/internal/agent/loop_test.go new file mode 100644 index 0000000..325413d --- /dev/null +++ b/internal/agent/loop_test.go @@ -0,0 +1,349 @@ +package agent + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/AgentGuardHQ/shellforge/internal/governance" + "github.com/AgentGuardHQ/shellforge/internal/llm" +) + +// mockProvider is a test double that returns pre-configured responses. +type mockProvider struct { + name string + responses []*llm.Response + calls int + received []mockCall +} + +type mockCall struct { + Messages []llm.Message + Tools []llm.ToolDef +} + +func (m *mockProvider) Name() string { return m.name } + +func (m *mockProvider) Chat(messages []llm.Message, tools []llm.ToolDef) (*llm.Response, error) { + m.received = append(m.received, mockCall{Messages: messages, Tools: tools}) + if m.calls >= len(m.responses) { + return nil, fmt.Errorf("mock: no more responses (called %d times, have %d)", m.calls+1, len(m.responses)) + } + resp := m.responses[m.calls] + m.calls++ + return resp, nil +} + +// newPermissiveEngine creates a governance engine that allows everything. +func newPermissiveEngine(t *testing.T) *governance.Engine { + t.Helper() + return &governance.Engine{ + Mode: "enforce", + Policies: nil, // no policies = default-allow + } +} + +// newDenyShellEngine creates a governance engine that denies run_shell commands +// containing "ls". +func newDenyShellEngine(t *testing.T) *governance.Engine { + t.Helper() + return &governance.Engine{ + Mode: "enforce", + Policies: []governance.Policy{ + { + Name: "deny-shell", + Description: "deny ls shell commands", + Match: governance.Match{Command: "ls"}, + Action: "deny", + Message: "shell commands are not allowed", + }, + }, + } +} + +// baseCfg returns a LoopConfig with reasonable test defaults. +func baseCfg(provider llm.Provider, outputDir string) LoopConfig { + return LoopConfig{ + Agent: "test-agent", + System: "You are a test assistant.", + UserPrompt: "What files are in this directory?", + Model: "test-model", + MaxTurns: 10, + TimeoutMs: 30000, + OutputDir: outputDir, + TokenBudget: 8000, + Provider: provider, + } +} + +// TestProviderToolCallThenFinalAnswer verifies that the loop processes a +// single tool call via native tool-use and then accepts a final answer. +func TestProviderToolCallThenFinalAnswer(t *testing.T) { + tmpDir := t.TempDir() + + // Create a file for list_files to find. + testFile := filepath.Join(tmpDir, "hello.txt") + os.WriteFile(testFile, []byte("hello world"), 0644) + + mock := &mockProvider{ + name: "mock-anthropic", + responses: []*llm.Response{ + // Turn 1: model requests list_files tool + { + Content: "", + StopReason: "tool_use", + ToolCalls: []llm.ToolCall{ + { + ID: "call_001", + Name: "list_files", + Params: map[string]string{"directory": tmpDir}, + }, + }, + PromptTok: 100, + OutputTok: 20, + }, + // Turn 2: model gives final answer + { + Content: "The directory contains hello.txt.", + StopReason: "end_turn", + ToolCalls: nil, + PromptTok: 150, + OutputTok: 30, + }, + }, + } + + cfg := baseCfg(mock, tmpDir) + engine := newPermissiveEngine(t) + + result, err := RunLoop(cfg, engine) + if err != nil { + t.Fatalf("RunLoop error: %v", err) + } + + if result.ExitReason != "final_answer" { + t.Errorf("ExitReason: got %q, want %q", result.ExitReason, "final_answer") + } + if !result.Success { + t.Errorf("Success: got false, want true") + } + if result.ToolCalls != 1 { + t.Errorf("ToolCalls: got %d, want %d", result.ToolCalls, 1) + } + if result.Denials != 0 { + t.Errorf("Denials: got %d, want %d", result.Denials, 0) + } + if result.Output != "The directory contains hello.txt." { + t.Errorf("Output: got %q, want %q", result.Output, "The directory contains hello.txt.") + } + if result.Turns != 2 { + t.Errorf("Turns: got %d, want %d", result.Turns, 2) + } + if result.PromptTok != 250 { + t.Errorf("PromptTok: got %d, want %d", result.PromptTok, 250) + } + if result.ResponseTok != 50 { + t.Errorf("ResponseTok: got %d, want %d", result.ResponseTok, 50) + } + + // Verify that tool definitions were passed to the provider. + if len(mock.received) < 1 { + t.Fatal("mock received no calls") + } + firstCall := mock.received[0] + if len(firstCall.Tools) == 0 { + t.Error("first call should have received tool definitions") + } + + // Verify that the second call has a tool_result message. + if len(mock.received) < 2 { + t.Fatal("mock should have received 2 calls") + } + secondCallMsgs := mock.received[1].Messages + hasToolResult := false + for _, m := range secondCallMsgs { + if m.Role == "tool_result" && m.ToolCallID == "call_001" { + hasToolResult = true + break + } + } + if !hasToolResult { + t.Error("second call should contain a tool_result message with ToolCallID=call_001") + } +} + +// TestProviderGovernanceDenial verifies that governance denials on native +// tool calls result in a tool_result message with denial feedback. +func TestProviderGovernanceDenial(t *testing.T) { + tmpDir := t.TempDir() + + mock := &mockProvider{ + name: "mock-anthropic", + responses: []*llm.Response{ + // Turn 1: model requests run_shell (will be denied) + { + Content: "", + StopReason: "tool_use", + ToolCalls: []llm.ToolCall{ + { + ID: "call_shell_1", + Name: "run_shell", + Params: map[string]string{"command": "ls -la"}, + }, + }, + PromptTok: 100, + OutputTok: 20, + }, + // Turn 2: model gives final answer after denial + { + Content: "I cannot run shell commands. Here is what I know.", + StopReason: "end_turn", + ToolCalls: nil, + PromptTok: 200, + OutputTok: 40, + }, + }, + } + + cfg := baseCfg(mock, tmpDir) + engine := newDenyShellEngine(t) + + result, err := RunLoop(cfg, engine) + if err != nil { + t.Fatalf("RunLoop error: %v", err) + } + + if result.ExitReason != "final_answer" { + t.Errorf("ExitReason: got %q, want %q", result.ExitReason, "final_answer") + } + if result.Denials != 1 { + t.Errorf("Denials: got %d, want %d", result.Denials, 1) + } + if result.ToolCalls != 1 { + t.Errorf("ToolCalls: got %d, want %d", result.ToolCalls, 1) + } + if result.Output != "I cannot run shell commands. Here is what I know." { + t.Errorf("Output: got %q", result.Output) + } + + // The second call should have a tool_result message with denial feedback. + if len(mock.received) < 2 { + t.Fatal("mock should have received 2 calls") + } + secondCallMsgs := mock.received[1].Messages + hasToolResult := false + for _, m := range secondCallMsgs { + if m.Role == "tool_result" && m.ToolCallID == "call_shell_1" { + hasToolResult = true + break + } + } + if !hasToolResult { + t.Error("second call should contain a tool_result message for denied tool call") + } +} + +// TestProviderNoToolCalls verifies that when the provider returns +// immediately with no tool calls (end_turn), the loop exits as final_answer. +func TestProviderNoToolCalls(t *testing.T) { + tmpDir := t.TempDir() + + mock := &mockProvider{ + name: "mock-anthropic", + responses: []*llm.Response{ + { + Content: "I can answer directly: this is a test.", + StopReason: "end_turn", + ToolCalls: nil, + PromptTok: 80, + OutputTok: 25, + }, + }, + } + + cfg := baseCfg(mock, tmpDir) + engine := newPermissiveEngine(t) + + result, err := RunLoop(cfg, engine) + if err != nil { + t.Fatalf("RunLoop error: %v", err) + } + + if result.ExitReason != "final_answer" { + t.Errorf("ExitReason: got %q, want %q", result.ExitReason, "final_answer") + } + if !result.Success { + t.Errorf("Success: got false, want true") + } + if result.ToolCalls != 0 { + t.Errorf("ToolCalls: got %d, want %d", result.ToolCalls, 0) + } + if result.Denials != 0 { + t.Errorf("Denials: got %d, want %d", result.Denials, 0) + } + if result.Turns != 1 { + t.Errorf("Turns: got %d, want %d", result.Turns, 1) + } + if result.Output != "I can answer directly: this is a test." { + t.Errorf("Output: got %q", result.Output) + } + if result.PromptTok != 80 { + t.Errorf("PromptTok: got %d, want %d", result.PromptTok, 80) + } + + // Only one call to the provider. + if mock.calls != 1 { + t.Errorf("provider calls: got %d, want %d", mock.calls, 1) + } +} + +// TestBuildToolDefs verifies that buildToolDefs produces valid llm.ToolDef +// entries from tools.Definitions. +func TestBuildToolDefs(t *testing.T) { + defs := buildToolDefs() + if len(defs) == 0 { + t.Fatal("buildToolDefs returned empty slice") + } + + // Check that each definition has a name, description, and parameters. + for _, d := range defs { + if d.Name == "" { + t.Error("ToolDef has empty Name") + } + if d.Description == "" { + t.Errorf("ToolDef %q has empty Description", d.Name) + } + if d.Parameters == nil { + t.Errorf("ToolDef %q has nil Parameters", d.Name) + continue + } + typ, ok := d.Parameters["type"] + if !ok || typ != "object" { + t.Errorf("ToolDef %q: Parameters[type] = %v, want %q", d.Name, typ, "object") + } + props, ok := d.Parameters["properties"] + if !ok || props == nil { + t.Errorf("ToolDef %q: missing properties", d.Name) + } + } + + // Spot-check read_file has a "path" property. + found := false + for _, d := range defs { + if d.Name == "read_file" { + found = true + props := d.Parameters["properties"].(map[string]any) + if _, ok := props["path"]; !ok { + t.Error("read_file ToolDef missing 'path' property") + } + req := d.Parameters["required"].([]string) + if len(req) != 1 || req[0] != "path" { + t.Errorf("read_file required: got %v, want [path]", req) + } + } + } + if !found { + t.Error("buildToolDefs missing read_file definition") + } +} From 9f55835f388932e17594db596a1547968bb10fa0 Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Tue, 31 Mar 2026 22:57:31 +0000 Subject: [PATCH 6/9] feat(shellforge): add --provider flag for Anthropic API backend shellforge agent --provider anthropic 'prompt' runs the agent loop against the Anthropic Messages API with native tool-use. Defaults to Ollama for backwards compatibility. Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/shellforge/main.go | 53 +++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/cmd/shellforge/main.go b/cmd/shellforge/main.go index 41e271b..d4e2ca4 100644 --- a/cmd/shellforge/main.go +++ b/cmd/shellforge/main.go @@ -17,6 +17,7 @@ import ( "github.com/AgentGuardHQ/shellforge/internal/agent" "github.com/AgentGuardHQ/shellforge/internal/governance" +"github.com/AgentGuardHQ/shellforge/internal/llm" "github.com/AgentGuardHQ/shellforge/internal/logger" "github.com/AgentGuardHQ/shellforge/internal/ollama" "github.com/AgentGuardHQ/shellforge/internal/scheduler" @@ -60,11 +61,24 @@ cmdRun(driver, prompt) case "evaluate": cmdEvaluate() case "agent": -if len(os.Args) < 3 { -fmt.Fprintln(os.Stderr, "Usage: shellforge agent \"your prompt\"") +{ +providerName := "" +remaining := os.Args[2:] +filtered := remaining[:0] +for i := 0; i < len(remaining); i++ { +if remaining[i] == "--provider" && i+1 < len(remaining) { +providerName = remaining[i+1] +i++ +} else { +filtered = append(filtered, remaining[i]) +} +} +if len(filtered) == 0 { +fmt.Fprintln(os.Stderr, "Usage: shellforge agent [--provider ] \"your prompt\"") os.Exit(1) } -cmdAgent(strings.Join(os.Args[2:], " ")) +cmdAgent(strings.Join(filtered, " "), providerName) +} case "swarm": cmdSwarm() case "serve": @@ -656,11 +670,29 @@ printResult("report-agent", result) saveReport("outputs/reports", "report", result) } -func cmdAgent(prompt string) { +func cmdAgent(prompt, providerName string) { engine := mustGovernance() + +var provider llm.Provider +switch providerName { +case "anthropic": +apiKey := os.Getenv("ANTHROPIC_API_KEY") +if apiKey == "" { +fmt.Fprintln(os.Stderr, "Error: ANTHROPIC_API_KEY environment variable not set") +os.Exit(1) +} +model := os.Getenv("ANTHROPIC_MODEL") +if model == "" { +model = "claude-haiku-4-5-20251001" +} +provider = llm.NewAnthropicProvider(apiKey, model) +fmt.Fprintf(os.Stderr, "Using Anthropic API (model: %s)\n", model) +default: +// Legacy Ollama path mustOllama() +} -result, err := agent.RunLoop(agent.LoopConfig{ +cfg := agent.LoopConfig{ Agent: "prototype-agent", System: "You are a senior engineer. Complete the requested task using available tools. Read files, write files, run commands, search code. Be precise.", UserPrompt: prompt, @@ -669,7 +701,10 @@ MaxTurns: 15, TimeoutMs: 180_000, OutputDir: "outputs/logs", TokenBudget: 3000, -}, engine) +Provider: provider, +} + +result, err := agent.RunLoop(cfg, engine) if err != nil { logger.Error("prototype-agent", err.Error()) os.Exit(1) @@ -913,7 +948,7 @@ func printResult(name string, r *agent.RunResult) { fmt.Println() status := "✓ success" if !r.Success { -status = "✗ failed" +status = fmt.Sprintf("✗ %s", r.ExitReason) } fmt.Printf("[%s] %s — %d turns, %d tool calls, %d denials\n", name, status, r.Turns, r.ToolCalls, r.Denials) fmt.Printf(" tokens: %d prompt + %d response | %dms\n", r.PromptTok, r.ResponseTok, r.DurationMs) @@ -927,8 +962,8 @@ func saveReport(dir, prefix string, r *agent.RunResult) { os.MkdirAll(dir, 0o755) ts := time.Now().Format("2006-01-02T15-04-05") path := filepath.Join(dir, fmt.Sprintf("%s-%s.md", prefix, ts)) -content := fmt.Sprintf("# %s — %s\n\n**Turns:** %d | **Tool calls:** %d | **Denials:** %d\n**Tokens:** %d+%d | **Duration:** %dms\n\n%s\n", -prefix, time.Now().Format(time.RFC3339), r.Turns, r.ToolCalls, r.Denials, r.PromptTok, r.ResponseTok, r.DurationMs, r.Output) +content := fmt.Sprintf("# %s — %s\n\n**Exit:** %s | **Turns:** %d | **Tool calls:** %d | **Denials:** %d\n**Tokens:** %d+%d | **Duration:** %dms\n\n%s\n", +prefix, time.Now().Format(time.RFC3339), r.ExitReason, r.Turns, r.ToolCalls, r.Denials, r.PromptTok, r.ResponseTok, r.DurationMs, r.Output) os.WriteFile(path, []byte(content), 0o644) fmt.Printf("\n→ Saved to %s\n", path) } From efa140c16dfde4bfa731ccecd281b36f95016ed1 Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Tue, 31 Mar 2026 22:59:33 +0000 Subject: [PATCH 7/9] fix(shellforge): serialize tool_use blocks in assistant messages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When Anthropic responds with tool_use blocks, the assistant message in conversation history must include structured tool_use content — not just plain text. Added ToolCalls field to llm.Message and reconstruct tool_use blocks in convertMessages(). Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/agent/loop.go | 9 +++++++-- internal/llm/anthropic.go | 37 +++++++++++++++++++++++++++++++++++-- internal/llm/provider.go | 5 +++-- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/internal/agent/loop.go b/internal/agent/loop.go index 195d9e3..cfaab2d 100644 --- a/internal/agent/loop.go +++ b/internal/agent/loop.go @@ -115,8 +115,13 @@ func runProviderLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) // ── Native tool-use path ── if len(provResp.ToolCalls) > 0 { - // Append the assistant message (may contain text + tool_use blocks). - messages = append(messages, llm.Message{Role: "assistant", Content: provResp.Content}) + // Append the assistant message with tool calls so the API receives + // structured tool_use blocks on the next turn. + messages = append(messages, llm.Message{ + Role: "assistant", + Content: provResp.Content, + ToolCalls: provResp.ToolCalls, + }) for _, tc := range provResp.ToolCalls { logger.Agent(cfg.Agent, fmt.Sprintf("tool_use: %s (id: %s)", tc.Name, tc.ID)) diff --git a/internal/llm/anthropic.go b/internal/llm/anthropic.go index 3915364..5ab9f92 100644 --- a/internal/llm/anthropic.go +++ b/internal/llm/anthropic.go @@ -221,13 +221,46 @@ func convertMessages(messages []Message) ([]anthropicMessage, error) { } result = append(result, anthropicMessage{Role: "user", Content: raw}) - case "assistant", "user": + case "assistant": + if len(m.ToolCalls) > 0 { + // Assistant message with tool_use blocks — reconstruct structured content. + var blocks []anthropicContentBlock + if m.Content != "" { + blocks = append(blocks, anthropicContentBlock{Type: "text", Text: m.Content}) + } + for _, tc := range m.ToolCalls { + input := make(map[string]any, len(tc.Params)) + for k, v := range tc.Params { + input[k] = v + } + blocks = append(blocks, anthropicContentBlock{ + Type: "tool_use", + ID: tc.ID, + Name: tc.Name, + Input: input, + }) + } + raw, err = json.Marshal(blocks) + if err != nil { + return nil, err + } + result = append(result, anthropicMessage{Role: "assistant", Content: raw}) + } else { + // Plain text assistant message. + raw, err = json.Marshal(m.Content) + if err != nil { + return nil, err + } + result = append(result, anthropicMessage{Role: "assistant", Content: raw}) + } + + case "user": // Plain text content block. raw, err = json.Marshal(m.Content) if err != nil { return nil, err } - result = append(result, anthropicMessage{Role: m.Role, Content: raw}) + result = append(result, anthropicMessage{Role: "user", Content: raw}) default: // Skip unknown roles (e.g. "system" already extracted). diff --git a/internal/llm/provider.go b/internal/llm/provider.go index 8f8a77a..44cf9b0 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -8,9 +8,10 @@ type Provider interface { // Message is a conversation turn. type Message struct { - Role string // "system", "user", "assistant", "tool_result" + Role string // "system", "user", "assistant", "tool_result" Content string - ToolCallID string // set when Role == "tool_result" + ToolCallID string // set when Role == "tool_result" + ToolCalls []ToolCall // set on assistant messages that invoked tools } // ToolDef describes a tool the model can invoke. From b3a1888f3dd985ea5d71f5df275f2fa41022b580 Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Tue, 31 Mar 2026 23:44:50 +0000 Subject: [PATCH 8/9] feat(shellforge): add extended thinking budget control ThinkingBudget field on AnthropicProvider caps thinking tokens per call. Prevents runaway output spend on complex tasks. CLI flag: --thinking-budget . Thinking content blocks silently consumed in response parsing. Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/shellforge/main.go | 18 ++++++-- internal/llm/anthropic.go | 86 +++++++++++++++++++++++++++++++-------- internal/llm/provider.go | 12 +++--- 3 files changed, 89 insertions(+), 27 deletions(-) diff --git a/cmd/shellforge/main.go b/cmd/shellforge/main.go index d4e2ca4..115a358 100644 --- a/cmd/shellforge/main.go +++ b/cmd/shellforge/main.go @@ -63,21 +63,25 @@ cmdEvaluate() case "agent": { providerName := "" +thinkingBudget := 0 remaining := os.Args[2:] filtered := remaining[:0] for i := 0; i < len(remaining); i++ { if remaining[i] == "--provider" && i+1 < len(remaining) { providerName = remaining[i+1] i++ +} else if remaining[i] == "--thinking-budget" && i+1 < len(remaining) { +fmt.Sscanf(remaining[i+1], "%d", &thinkingBudget) +i++ } else { filtered = append(filtered, remaining[i]) } } if len(filtered) == 0 { -fmt.Fprintln(os.Stderr, "Usage: shellforge agent [--provider ] \"your prompt\"") +fmt.Fprintln(os.Stderr, "Usage: shellforge agent [--provider ] [--thinking-budget ] \"your prompt\"") os.Exit(1) } -cmdAgent(strings.Join(filtered, " "), providerName) +cmdAgent(strings.Join(filtered, " "), providerName, thinkingBudget) } case "swarm": cmdSwarm() @@ -670,7 +674,7 @@ printResult("report-agent", result) saveReport("outputs/reports", "report", result) } -func cmdAgent(prompt, providerName string) { +func cmdAgent(prompt, providerName string, thinkingBudget int) { engine := mustGovernance() var provider llm.Provider @@ -685,8 +689,14 @@ model := os.Getenv("ANTHROPIC_MODEL") if model == "" { model = "claude-haiku-4-5-20251001" } -provider = llm.NewAnthropicProvider(apiKey, model) +p := llm.NewAnthropicProvider(apiKey, model) +if thinkingBudget > 0 { +p.ThinkingBudget = thinkingBudget +fmt.Fprintf(os.Stderr, "Using Anthropic API (model: %s, thinking budget: %d tokens)\n", model, thinkingBudget) +} else { fmt.Fprintf(os.Stderr, "Using Anthropic API (model: %s)\n", model) +} +provider = p default: // Legacy Ollama path mustOllama() diff --git a/internal/llm/anthropic.go b/internal/llm/anthropic.go index 5ab9f92..0542b1d 100644 --- a/internal/llm/anthropic.go +++ b/internal/llm/anthropic.go @@ -19,10 +19,11 @@ const ( // AnthropicProvider calls the Anthropic Messages API via stdlib HTTP. type AnthropicProvider struct { - apiKey string - model string - baseURL string - client *http.Client + apiKey string + model string + baseURL string + client *http.Client + ThinkingBudget int // max thinking tokens (0 = disabled) } // NewAnthropicProvider creates an AnthropicProvider. @@ -70,20 +71,33 @@ type anthropicMessage struct { Content json.RawMessage `json:"content"` } +// cacheControl instructs Anthropic to cache this content block for 5 minutes. +type cacheControl struct { + Type string `json:"type"` +} + // anthropicToolDef is the Anthropic tool definition format. type anthropicToolDef struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema map[string]any `json:"input_schema"` + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"input_schema"` + CacheControl *cacheControl `json:"cache_control,omitempty"` +} + +// anthropicThinking configures extended thinking (chain-of-thought). +type anthropicThinking struct { + Type string `json:"type"` // "enabled" + BudgetTokens int `json:"budget_tokens"` // max thinking tokens } // anthropicRequest is the full POST body for /v1/messages. type anthropicRequest struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System string `json:"system,omitempty"` - Messages []anthropicMessage `json:"messages"` - Tools []anthropicToolDef `json:"tools,omitempty"` + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System json.RawMessage `json:"system,omitempty"` + Messages []anthropicMessage `json:"messages"` + Tools []anthropicToolDef `json:"tools,omitempty"` + Thinking *anthropicThinking `json:"thinking,omitempty"` } // anthropicResponse is the API response body. @@ -92,8 +106,10 @@ type anthropicResponse struct { Content []anthropicContentBlock `json:"content"` StopReason string `json:"stop_reason"` Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` } `json:"usage"` } @@ -149,10 +165,36 @@ func (a *AnthropicProvider) Chat(messages []Message, tools []ToolDef) (*Response reqBody := anthropicRequest{ Model: a.model, MaxTokens: anthropicMaxTokens, - System: systemPrompt, Messages: apiMsgs, } + + // Enable extended thinking with budget cap if configured. + if a.ThinkingBudget > 0 { + reqBody.Thinking = &anthropicThinking{ + Type: "enabled", + BudgetTokens: a.ThinkingBudget, + } + // When thinking is enabled, max_tokens must cover thinking + output. + if reqBody.MaxTokens < a.ThinkingBudget+1024 { + reqBody.MaxTokens = a.ThinkingBudget + 1024 + } + } + + // Build system as array of content blocks with cache_control on the last block. + if systemPrompt != "" { + systemBlocks := []map[string]any{ + { + "type": "text", + "text": systemPrompt, + "cache_control": map[string]string{"type": "ephemeral"}, + }, + } + reqBody.System, _ = json.Marshal(systemBlocks) + } + + // Add cache_control to last tool so tool definitions are cached. if len(apiTools) > 0 { + apiTools[len(apiTools)-1].CacheControl = &cacheControl{Type: "ephemeral"} reqBody.Tools = apiTools } @@ -167,6 +209,7 @@ func (a *AnthropicProvider) Chat(messages []Message, tools []ToolDef) (*Response } req.Header.Set("x-api-key", a.apiKey) req.Header.Set("anthropic-version", anthropicVersion) + req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31") req.Header.Set("Content-Type", "application/json") resp, err := a.client.Do(req) @@ -273,13 +316,20 @@ func convertMessages(messages []Message) ([]anthropicMessage, error) { // parseResponse converts an anthropicResponse into an llm.Response. func parseResponse(apiResp *anthropicResponse) *Response { resp := &Response{ - StopReason: apiResp.StopReason, - PromptTok: apiResp.Usage.InputTokens, - OutputTok: apiResp.Usage.OutputTokens, + StopReason: apiResp.StopReason, + PromptTok: apiResp.Usage.InputTokens, + OutputTok: apiResp.Usage.OutputTokens, + CacheCreated: apiResp.Usage.CacheCreationInputTokens, + CacheHit: apiResp.Usage.CacheReadInputTokens, } for _, block := range apiResp.Content { switch block.Type { + case "thinking": + // Extended thinking output — consumed for token accounting + // but not included in Content (internal reasoning). + continue + case "text": if resp.Content != "" { resp.Content += "\n" diff --git a/internal/llm/provider.go b/internal/llm/provider.go index 44cf9b0..38e1598 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -30,9 +30,11 @@ type ToolCall struct { // Response is the model's reply to a Chat call. type Response struct { - Content string - ToolCalls []ToolCall - StopReason string // "end_turn", "tool_use" - PromptTok int - OutputTok int + Content string + ToolCalls []ToolCall + StopReason string // "end_turn", "tool_use" + PromptTok int + OutputTok int + CacheCreated int // tokens written to cache (first call) + CacheHit int // tokens read from cache (subsequent calls) } From 4fc414b700c8e2d14fb2fade542a782c8adb953a Mon Sep 17 00:00:00 2001 From: Jared Pleva Date: Wed, 1 Apr 2026 00:15:31 +0000 Subject: [PATCH 9/9] feat(shellforge): add drift detection to agent loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Every 5 tool calls, asks the model to self-score alignment with original task (1-10). Below 7 → steering message injected. Below 5 twice → task killed. Prevents agents from burning tokens on off-task work. 8 drift tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/agent/drift.go | 107 +++++++++++++++++++++++++++++++++++ internal/agent/drift_test.go | 102 +++++++++++++++++++++++++++++++++ internal/agent/loop.go | 41 ++++++++++++++ 3 files changed, 250 insertions(+) create mode 100644 internal/agent/drift.go create mode 100644 internal/agent/drift_test.go diff --git a/internal/agent/drift.go b/internal/agent/drift.go new file mode 100644 index 0000000..61a433c --- /dev/null +++ b/internal/agent/drift.go @@ -0,0 +1,107 @@ +package agent + +import ( + "fmt" + "strings" +) + +const ( + driftCheckInterval = 5 // check every N tool calls + driftWarnThreshold = 7 // score below this → inject steering + driftKillThreshold = 5 // score below this twice → kill +) + +// driftDetector tracks whether the agent is staying on-task. +type driftDetector struct { + taskSpec string // original user prompt (the task spec) + actionLog []string // recent tool calls for summarization + warnings int // how many times we've warned + lowScores int // consecutive scores below kill threshold +} + +func newDriftDetector(taskSpec string) *driftDetector { + return &driftDetector{taskSpec: taskSpec} +} + +// record logs a tool call for drift analysis. +func (d *driftDetector) record(toolName string, params map[string]string) { + summary := toolName + if target, ok := params["path"]; ok { + summary += " → " + target + } else if target, ok := params["command"]; ok { + summary += " → " + target + } else if target, ok := params["directory"]; ok { + summary += " → " + target + } + d.actionLog = append(d.actionLog, summary) +} + +// shouldCheck returns true every driftCheckInterval tool calls. +func (d *driftDetector) shouldCheck(totalToolCalls int) bool { + return totalToolCalls > 0 && totalToolCalls%driftCheckInterval == 0 +} + +// buildCheckPrompt creates the drift check message to send to the model. +func (d *driftDetector) buildCheckPrompt() string { + recent := d.actionLog + if len(recent) > driftCheckInterval { + recent = recent[len(recent)-driftCheckInterval:] + } + + return fmt.Sprintf(`DRIFT CHECK — Score your alignment with the original task. + +Original task: %s + +Your last %d actions: +%s + +Rate your alignment 1-10 (10 = perfectly on task, 1 = completely off topic). +Respond with ONLY a single number.`, d.taskSpec, len(recent), strings.Join(recent, "\n")) +} + +// parseScore extracts the drift score from the model's response. +func parseScore(content string) int { + content = strings.TrimSpace(content) + for _, c := range content { + if c >= '0' && c <= '9' { + return int(c - '0') + } + } + return 10 // default to "on task" if unparseable +} + +// evaluate processes the drift score and returns the action to take. +func (d *driftDetector) evaluate(score int) driftAction { + if score >= driftWarnThreshold { + d.lowScores = 0 + return driftOK + } + + if score < driftKillThreshold { + d.lowScores++ + if d.lowScores >= 2 { + return driftKill + } + } + + d.warnings++ + return driftWarn +} + +// steeringMessage returns the message to inject when drift is detected. +func (d *driftDetector) steeringMessage() string { + return fmt.Sprintf(`⚠️ DRIFT DETECTED — You are going off-task. + +Original task: %s + +Refocus on the original task. Do not continue with unrelated work. +Warning %d — task will be terminated if drift continues.`, d.taskSpec, d.warnings) +} + +type driftAction int + +const ( + driftOK driftAction = iota + driftWarn + driftKill +) diff --git a/internal/agent/drift_test.go b/internal/agent/drift_test.go new file mode 100644 index 0000000..765cc52 --- /dev/null +++ b/internal/agent/drift_test.go @@ -0,0 +1,102 @@ +package agent + +import "testing" + +func TestNewDriftDetector(t *testing.T) { + d := newDriftDetector("fix the auth bug") + if d.taskSpec != "fix the auth bug" { + t.Errorf("expected task spec, got %s", d.taskSpec) + } + if len(d.actionLog) != 0 { + t.Error("expected empty action log") + } +} + +func TestDriftDetector_Record(t *testing.T) { + d := newDriftDetector("task") + d.record("read_file", map[string]string{"path": "auth.go"}) + d.record("run_shell", map[string]string{"command": "go test"}) + if len(d.actionLog) != 2 { + t.Errorf("expected 2 actions, got %d", len(d.actionLog)) + } + if d.actionLog[0] != "read_file → auth.go" { + t.Errorf("expected 'read_file → auth.go', got %s", d.actionLog[0]) + } +} + +func TestDriftDetector_ShouldCheck(t *testing.T) { + d := newDriftDetector("task") + if d.shouldCheck(0) { + t.Error("should not check at 0") + } + if d.shouldCheck(3) { + t.Error("should not check at 3") + } + if !d.shouldCheck(5) { + t.Error("should check at 5") + } + if !d.shouldCheck(10) { + t.Error("should check at 10") + } +} + +func TestParseScore(t *testing.T) { + cases := []struct { + input string + want int + }{ + {"8", 8}, + {" 7 ", 7}, + {"Score: 9", 9}, + {"3/10", 3}, + {"no score here", 10}, // default + {"", 10}, + } + for _, c := range cases { + if got := parseScore(c.input); got != c.want { + t.Errorf("parseScore(%q) = %d, want %d", c.input, got, c.want) + } + } +} + +func TestDriftDetector_Evaluate_OK(t *testing.T) { + d := newDriftDetector("task") + if d.evaluate(8) != driftOK { + t.Error("score 8 should be OK") + } + if d.evaluate(10) != driftOK { + t.Error("score 10 should be OK") + } +} + +func TestDriftDetector_Evaluate_Warn(t *testing.T) { + d := newDriftDetector("task") + if d.evaluate(6) != driftWarn { + t.Error("score 6 should warn") + } + if d.warnings != 1 { + t.Errorf("expected 1 warning, got %d", d.warnings) + } +} + +func TestDriftDetector_Evaluate_Kill(t *testing.T) { + d := newDriftDetector("task") + // First low score → warn + if d.evaluate(4) != driftWarn { + t.Error("first score 4 should warn") + } + // Second consecutive low score → kill + if d.evaluate(3) != driftKill { + t.Error("second low score should kill") + } +} + +func TestDriftDetector_Evaluate_ResetOnGoodScore(t *testing.T) { + d := newDriftDetector("task") + d.evaluate(4) // low → warn + d.evaluate(8) // good → reset + // Next low should warn again, not kill + if d.evaluate(3) != driftWarn { + t.Error("after reset, first low should warn not kill") + } +} diff --git a/internal/agent/loop.go b/internal/agent/loop.go index cfaab2d..bd0e74b 100644 --- a/internal/agent/loop.go +++ b/internal/agent/loop.go @@ -80,6 +80,7 @@ func runProviderLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) } toolDefs := buildToolDefs() + drift := newDriftDetector(cfg.UserPrompt) result := &RunResult{} var log []string @@ -112,6 +113,10 @@ func runProviderLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) result.PromptTok += provResp.PromptTok result.ResponseTok += provResp.OutputTok logger.ModelCall(cfg.Agent, provResp.PromptTok, provResp.OutputTok, 0) + if provResp.CacheHit > 0 { + logger.Agent(cfg.Agent, fmt.Sprintf("cache hit: %d tokens (saved ~%.1f%%)", + provResp.CacheHit, float64(provResp.CacheHit)/float64(provResp.PromptTok+provResp.CacheHit)*100)) + } // ── Native tool-use path ── if len(provResp.ToolCalls) > 0 { @@ -199,6 +204,41 @@ func runProviderLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) }) log = append(log, fmt.Sprintf("[turn %d] %s → %s", turn, tc.Name, boolStr(toolResult.Success, "ok", "fail"))) + + // Record action for drift detection. + drift.record(tc.Name, tc.Params) + } + + // ── Drift detection: check every N tool calls ── + if drift.shouldCheck(result.ToolCalls) { + logger.Agent(cfg.Agent, "drift check...") + driftMsgs := []llm.Message{ + {Role: "user", Content: drift.buildCheckPrompt()}, + } + driftResp, derr := cfg.Provider.Chat(driftMsgs, nil) + if derr == nil { + score := parseScore(driftResp.Content) + result.PromptTok += driftResp.PromptTok + result.ResponseTok += driftResp.OutputTok + + action := drift.evaluate(score) + switch action { + case driftWarn: + logger.Agent(cfg.Agent, fmt.Sprintf("drift warning — score %d/10, warning #%d", score, drift.warnings)) + messages = append(messages, llm.Message{ + Role: "user", + Content: drift.steeringMessage(), + }) + case driftKill: + logger.Agent(cfg.Agent, fmt.Sprintf("drift kill — score %d/10, terminating", score)) + result.ExitReason = "drift" + result.Output = "Task terminated: agent drifted from original task spec." + result.Turns = turn + goto done + default: + logger.Agent(cfg.Agent, fmt.Sprintf("drift ok — score %d/10", score)) + } + } } // Last turn: force final answer @@ -228,6 +268,7 @@ func runProviderLoop(cfg LoopConfig, engine *governance.Engine, start time.Time) break } +done: result.DurationMs = time.Since(start).Milliseconds() result.Success = result.ExitReason == "final_answer" result.Log = log