Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions internal/adapter/claude/handler_helpers_misc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package claude

import (
"fmt"
"strings"
)

func hasSystemMessage(messages []any) bool {
for _, m := range messages {
msg, ok := m.(map[string]any)
if ok && msg["role"] == "system" {
return true
}
}
return false
}

func extractClaudeToolNames(tools []any) []string {
out := make([]string, 0, len(tools))
for _, t := range tools {
m, ok := t.(map[string]any)
if !ok {
continue
}
name, _, _ := extractClaudeToolMeta(m)
if name != "" {
out = append(out, name)
}
}
return out
}

func extractClaudeToolMeta(m map[string]any) (string, string, any) {
name, _ := m["name"].(string)
desc, _ := m["description"].(string)
schemaObj := m["input_schema"]
if schemaObj == nil {
schemaObj = m["parameters"]
}

if fn, ok := m["function"].(map[string]any); ok {
if strings.TrimSpace(name) == "" {
name, _ = fn["name"].(string)
}
if strings.TrimSpace(desc) == "" {
desc, _ = fn["description"].(string)
}
if schemaObj == nil {
if v, ok := fn["input_schema"]; ok {
schemaObj = v
}
}
if schemaObj == nil {
if v, ok := fn["parameters"]; ok {
schemaObj = v
}
}
}
return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj
}

func toMessageMaps(v any) []map[string]any {
arr, ok := v.([]any)
if !ok {
return nil
}
out := make([]map[string]any, 0, len(arr))
for _, item := range arr {
if m, ok := item.(map[string]any); ok {
out = append(out, m)
}
}
return out
}

func extractMessageContent(v any) string {
switch x := v.(type) {
case string:
return x
case []any:
parts := make([]string, 0, len(x))
for _, it := range x {
parts = append(parts, fmt.Sprintf("%v", it))
}
return strings.Join(parts, "\n")
default:
return fmt.Sprintf("%v", x)
}
}

func cloneMap(in map[string]any) map[string]any {
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
41 changes: 41 additions & 0 deletions internal/adapter/claude/handler_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,47 @@ func TestNormalizeClaudeMessagesToolResultNonTextPayloadStringified(t *testing.T
}
}

func TestNormalizeClaudeMessagesBackfillsToolResultCallIDByName(t *testing.T) {
msgs := []any{
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{
"type": "tool_use",
"name": "search_web",
"input": map[string]any{"query": "latest"},
},
},
},
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "tool_result",
"name": "search_web",
"content": "ok",
},
},
},
}

got := normalizeClaudeMessages(msgs)
if len(got) != 2 {
t.Fatalf("expected 2 messages, got %#v", got)
}
assistant, _ := got[0].(map[string]any)
tc, _ := assistant["tool_calls"].([]any)
call, _ := tc[0].(map[string]any)
callID, _ := call["id"].(string)
if !strings.HasPrefix(callID, "call_claude_") {
t.Fatalf("expected generated call id, got %#v", call)
}
toolMsg, _ := got[1].(map[string]any)
if toolMsg["tool_call_id"] != callID {
t.Fatalf("expected tool_result to reuse generated id, got %#v", toolMsg)
}
}

// ─── buildClaudeToolPrompt ───────────────────────────────────────────

func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
Expand Down
130 changes: 28 additions & 102 deletions internal/adapter/claude/handler_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import (

func normalizeClaudeMessages(messages []any) []any {
out := make([]any, 0, len(messages))
state := &claudeToolCallState{
nameByID: map[string]string{},
lastIDByName: map[string]string{},
callIDSequence: 0,
}
for _, m := range messages {
msg, ok := m.(map[string]any)
if !ok {
Expand Down Expand Up @@ -44,7 +49,7 @@ func normalizeClaudeMessages(messages []any) []any {
case "tool_use":
if role == "assistant" {
flushText()
if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil {
if toolMsg := normalizeClaudeToolUseToAssistant(b, state); toolMsg != nil {
out = append(out, toolMsg)
}
continue
Expand All @@ -54,7 +59,7 @@ func normalizeClaudeMessages(messages []any) []any {
}
case "tool_result":
flushText()
if toolMsg := normalizeClaudeToolResultToToolMessage(b); toolMsg != nil {
if toolMsg := normalizeClaudeToolResultToToolMessage(b, state); toolMsg != nil {
out = append(out, toolMsg)
}
default:
Expand Down Expand Up @@ -119,21 +124,23 @@ func formatClaudeToolResultForPrompt(block map[string]any) string {
return string(b)
}

func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any {
func normalizeClaudeToolUseToAssistant(block map[string]any, state *claudeToolCallState) map[string]any {
if block == nil {
return nil
}
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
if name == "" {
return nil
}
callID := strings.TrimSpace(fmt.Sprintf("%v", block["id"]))
callID := safeStringValue(block["id"])
if callID == "" {
callID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
callID = safeStringValue(block["tool_use_id"])
}
if callID == "" {
callID = "call_claude"
callID = state.nextID()
}
state.nameByID[callID] = name
state.lastIDByName[strings.ToLower(name)] = callID
arguments := block["input"]
if arguments == nil {
arguments = map[string]any{}
Expand All @@ -159,24 +166,34 @@ func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any {
}
}

func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any {
func normalizeClaudeToolResultToToolMessage(block map[string]any, state *claudeToolCallState) map[string]any {
if block == nil {
return nil
}
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
name := safeStringValue(block["name"])
toolCallID := safeStringValue(block["tool_use_id"])
if toolCallID == "" {
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
toolCallID = safeStringValue(block["tool_call_id"])
}
if toolCallID == "" {
toolCallID = "call_claude"
if name != "" {
toolCallID = strings.TrimSpace(state.lastIDByName[strings.ToLower(name)])
}
}
if toolCallID == "" {
toolCallID = state.nextID()
}
out := map[string]any{
"role": "tool",
"tool_call_id": toolCallID,
"content": normalizeClaudeToolResultContent(block["content"]),
}
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
if name != "" {
out["name"] = name
state.nameByID[toolCallID] = name
state.lastIDByName[strings.ToLower(name)] = toolCallID
} else if inferred := strings.TrimSpace(state.nameByID[toolCallID]); inferred != "" {
out["name"] = inferred
}
return out
}
Expand Down Expand Up @@ -206,94 +223,3 @@ func formatClaudeBlockRaw(block map[string]any) string {
}
return string(b)
}

func hasSystemMessage(messages []any) bool {
for _, m := range messages {
msg, ok := m.(map[string]any)
if ok && msg["role"] == "system" {
return true
}
}
return false
}

func extractClaudeToolNames(tools []any) []string {
out := make([]string, 0, len(tools))
for _, t := range tools {
m, ok := t.(map[string]any)
if !ok {
continue
}
name, _, _ := extractClaudeToolMeta(m)
if name != "" {
out = append(out, name)
}
}
return out
}

func extractClaudeToolMeta(m map[string]any) (string, string, any) {
name, _ := m["name"].(string)
desc, _ := m["description"].(string)
schemaObj := m["input_schema"]
if schemaObj == nil {
schemaObj = m["parameters"]
}

if fn, ok := m["function"].(map[string]any); ok {
if strings.TrimSpace(name) == "" {
name, _ = fn["name"].(string)
}
if strings.TrimSpace(desc) == "" {
desc, _ = fn["description"].(string)
}
if schemaObj == nil {
if v, ok := fn["input_schema"]; ok {
schemaObj = v
}
}
if schemaObj == nil {
if v, ok := fn["parameters"]; ok {
schemaObj = v
}
}
}
return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj
}

func toMessageMaps(v any) []map[string]any {
arr, ok := v.([]any)
if !ok {
return nil
}
out := make([]map[string]any, 0, len(arr))
for _, item := range arr {
if m, ok := item.(map[string]any); ok {
out = append(out, m)
}
}
return out
}

func extractMessageContent(v any) string {
switch x := v.(type) {
case string:
return x
case []any:
parts := make([]string, 0, len(x))
for _, it := range x {
parts = append(parts, fmt.Sprintf("%v", it))
}
return strings.Join(parts, "\n")
default:
return fmt.Sprintf("%v", x)
}
}

func cloneMap(in map[string]any) map[string]any {
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
25 changes: 25 additions & 0 deletions internal/adapter/claude/tool_call_state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package claude

import (
"fmt"
"strings"
)

type claudeToolCallState struct {
nameByID map[string]string
lastIDByName map[string]string
callIDSequence int
}

func (s *claudeToolCallState) nextID() string {
s.callIDSequence++
return fmt.Sprintf("call_claude_%d", s.callIDSequence)
}

func safeStringValue(v any) string {
s, ok := v.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
Loading
Loading