From fe8ec0b394207a90cad9fef7dc6bea6b1ca04e65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arnaud=20He=CC=81ritier?= Date: Thu, 26 Feb 2026 23:17:48 +0100 Subject: [PATCH] fix(#1863): preserve user messages in trimMessages to prevent session derailment trimMessages() removed the oldest conversation messages first without protecting user messages. In single-turn agentic loops (one user message followed by many tool calls), the user's original request was the first message trimmed, causing the model to lose context and stop working. User messages are now marked as protected and skipped during trimming. Only assistant and tool messages are eligible for removal. Fixes #1863 Assisted-By: cagent --- pkg/session/session.go | 32 +++-- pkg/session/session_history_test.go | 181 ++++++++++++++++++++++------ pkg/session/session_test.go | 4 +- 3 files changed, 172 insertions(+), 45 deletions(-) diff --git a/pkg/session/session.go b/pkg/session/session.go index bc2e52732..85b6304d9 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -691,7 +691,9 @@ func (s *Session) GetMessages(a *agent.Agent) []chat.Message { // trimMessages ensures we don't exceed the maximum number of messages while maintaining // consistency between assistant messages and their tool call results. -// System messages are always preserved and not counted against the limit. +// System messages and user messages are always preserved and not counted against the limit. +// User messages are protected from trimming to prevent the model from losing +// track of what was asked in long agentic loops. func trimMessages(messages []chat.Message, maxItems int) []chat.Message { // Separate system messages from conversation messages var systemMessages []chat.Message @@ -710,15 +712,27 @@ func trimMessages(messages []chat.Message, maxItems int) []chat.Message { return messages } + // Identify user message indices — these are protected from trimming + protected := make(map[int]bool) + for i, msg := range conversationMessages { + if msg.Role == chat.MessageRoleUser { + protected[i] = true + } + } + // Keep track of tool call IDs that need to be removed toolCallsToRemove := make(map[string]bool) // Calculate how many conversation messages we need to remove toRemove := len(conversationMessages) - maxItems - // Start from the beginning (oldest messages) - for i := range toRemove { - // If this is an assistant message with tool calls, mark them for removal + // Mark the oldest non-protected messages for removal + removed := make(map[int]bool) + for i := 0; i < len(conversationMessages) && len(removed) < toRemove; i++ { + if protected[i] { + continue + } + removed[i] = true if conversationMessages[i].Role == chat.MessageRoleAssistant { for _, toolCall := range conversationMessages[i].ToolCalls { toolCallsToRemove[toolCall.ID] = true @@ -732,11 +746,13 @@ func trimMessages(messages []chat.Message, maxItems int) []chat.Message { // Add all system messages first result = append(result, systemMessages...) - // Add the most recent conversation messages - for i := toRemove; i < len(conversationMessages); i++ { - msg := conversationMessages[i] + // Add protected and non-removed conversation messages + for i, msg := range conversationMessages { + if removed[i] { + continue + } - // Skip tool messages that correspond to removed assistant messages + // Skip orphaned tool results whose assistant message was removed if msg.Role == chat.MessageRoleTool && toolCallsToRemove[msg.ToolCallID] { continue } diff --git a/pkg/session/session_history_test.go b/pkg/session/session_history_test.go index cea57b928..a2bec576d 100644 --- a/pkg/session/session_history_test.go +++ b/pkg/session/session_history_test.go @@ -19,16 +19,20 @@ func TestSessionNumHistoryItems(t *testing.T) { expectedConversationMsgs int }{ { - name: "limit to 3 conversation messages", - numHistoryItems: 3, - messageCount: 10, - expectedConversationMsgs: 3, // Limited to 3 despite 20 total messages + name: "limit to 3 conversation messages — user messages protected", + numHistoryItems: 3, + messageCount: 10, + // 10 user (all protected) + 10 assistant. Need to remove 17, but only 10 removable. + // Result: 10 users + 0 assistants = 10 + expectedConversationMsgs: 10, }, { - name: "limit to 5 conversation messages", - numHistoryItems: 5, - messageCount: 8, - expectedConversationMsgs: 5, // Limited to 5 out of 16 total messages + name: "limit to 5 conversation messages — user messages protected", + numHistoryItems: 5, + messageCount: 8, + // 8 user (all protected) + 8 assistant. Need to remove 11, but only 8 removable. + // Result: 8 users + 0 assistants = 8 + expectedConversationMsgs: 8, }, { name: "fewer messages than limit", @@ -71,9 +75,8 @@ func TestSessionNumHistoryItems(t *testing.T) { // System messages should always be present (at least the instruction) assert.Positive(t, systemCount, "Should have system messages") - // Conversation messages should be limited - assert.LessOrEqual(t, conversationCount, tt.expectedConversationMsgs, - "Conversation messages should not exceed the configured limit") + assert.Equal(t, tt.expectedConversationMsgs, conversationCount, + "Conversation messages should match expected count") }) } } @@ -95,22 +98,20 @@ func TestTrimMessagesPreservesSystemMessages(t *testing.T) { // Count message types systemCount := 0 - conversationCount := 0 + userCount := 0 for _, msg := range trimmed { if msg.Role == chat.MessageRoleSystem { systemCount++ - } else { - conversationCount++ + } + if msg.Role == chat.MessageRoleUser { + userCount++ } } // All system messages should be preserved assert.Equal(t, 3, systemCount, "All system messages should be preserved") - assert.Equal(t, 1, conversationCount, "Should have exactly 1 conversation message") - - // The preserved conversation message should be the most recent - assert.Equal(t, "Assistant response 3", trimmed[len(trimmed)-1].Content, - "Should preserve the most recent conversation message") + // All user messages should be preserved even with maxItems=1 + assert.Equal(t, 3, userCount, "All user messages should be preserved") } func TestTrimMessagesConversationLimit(t *testing.T) { @@ -126,16 +127,22 @@ func TestTrimMessagesConversationLimit(t *testing.T) { {Role: chat.MessageRoleAssistant, Content: "Response 4"}, } + // 8 conversation messages: 4 user + 4 assistant + // User messages are always protected, so only assistant messages can be trimmed. testCases := []struct { limit int - expectedTotal int - expectedConversation int expectedSystem int + expectedUser int + expectedConversation int // total non-system }{ - {limit: 2, expectedTotal: 3, expectedConversation: 2, expectedSystem: 1}, - {limit: 4, expectedTotal: 5, expectedConversation: 4, expectedSystem: 1}, - {limit: 8, expectedTotal: 9, expectedConversation: 8, expectedSystem: 1}, - {limit: 100, expectedTotal: 9, expectedConversation: 8, expectedSystem: 1}, + // limit=2: need to remove 6 of 8, but 4 are protected users → only 4 assistants removable → remove 4 + {limit: 2, expectedSystem: 1, expectedUser: 4, expectedConversation: 4}, + // limit=4: need to remove 4 of 8, 4 are protected → remove all 4 assistants + {limit: 4, expectedSystem: 1, expectedUser: 4, expectedConversation: 4}, + // limit=8: no trimming needed (8 <= 8) + {limit: 8, expectedSystem: 1, expectedUser: 4, expectedConversation: 8}, + // limit=100: no trimming needed + {limit: 100, expectedSystem: 1, expectedUser: 4, expectedConversation: 8}, } for _, tc := range testCases { @@ -143,17 +150,22 @@ func TestTrimMessagesConversationLimit(t *testing.T) { trimmed := trimMessages(messages, tc.limit) systemCount := 0 + userCount := 0 conversationCount := 0 for _, msg := range trimmed { - if msg.Role == chat.MessageRoleSystem { + switch msg.Role { + case chat.MessageRoleSystem: systemCount++ - } else { + case chat.MessageRoleUser: + userCount++ + conversationCount++ + default: conversationCount++ } } - assert.Len(t, trimmed, tc.expectedTotal, "Total message count") assert.Equal(t, tc.expectedSystem, systemCount, "System message count") + assert.Equal(t, tc.expectedUser, userCount, "User messages should always be preserved") assert.Equal(t, tc.expectedConversation, conversationCount, "Conversation message count") }) } @@ -190,7 +202,7 @@ func TestTrimMessagesWithToolCallsPreservation(t *testing.T) { }, } - // Limit to 3 conversation messages (should keep the recent tool interaction) + // Limit to 3 conversation messages trimmed := trimMessages(messages, 3) toolCallIDs := make(map[string]bool) @@ -209,12 +221,113 @@ func TestTrimMessagesWithToolCallsPreservation(t *testing.T) { } } - // Should not have the old tool call - hasOldTool := false + // Both user messages should be preserved + userMessages := 0 for _, msg := range trimmed { - if msg.Role == chat.MessageRoleTool && msg.ToolCallID == "old_tool_1" { - hasOldTool = true + if msg.Role == chat.MessageRoleUser { + userMessages++ } } - assert.False(t, hasOldTool, "Should not have old tool results without their calls") + assert.Equal(t, 2, userMessages, "Both user messages should be preserved") +} + +func TestTrimMessagesPreservesUserMessagesInAgenticLoop(t *testing.T) { + // Simulate a single-turn agentic loop: one user message followed by many tool calls + messages := []chat.Message{ + {Role: chat.MessageRoleSystem, Content: "System prompt"}, + {Role: chat.MessageRoleUser, Content: "Analyze MR #123 and build an integration plan"}, + } + + for i := range 30 { + toolID := fmt.Sprintf("tool_%d", i) + messages = append(messages, chat.Message{ + Role: chat.MessageRoleAssistant, + Content: fmt.Sprintf("Calling tool %d", i), + ToolCalls: []tools.ToolCall{ + {ID: toolID, Function: tools.FunctionCall{Name: "shell"}}, + }, + }, chat.Message{ + Role: chat.MessageRoleTool, + Content: fmt.Sprintf("Tool result %d", i), + ToolCallID: toolID, + }) + } + + // 61 conversation messages (1 user + 30 assistant + 30 tool), limit to 30 + trimmed := trimMessages(messages, 30) + + // The user message must survive + var userMessages []string + for _, msg := range trimmed { + if msg.Role == chat.MessageRoleUser { + userMessages = append(userMessages, msg.Content) + } + } + + assert.Len(t, userMessages, 1, "User message must be preserved") + assert.Equal(t, "Analyze MR #123 and build an integration plan", userMessages[0]) + + // Tool call consistency: every tool result must have a matching assistant tool call + toolCallIDs := make(map[string]bool) + for _, msg := range trimmed { + if msg.Role == chat.MessageRoleAssistant { + for _, tc := range msg.ToolCalls { + toolCallIDs[tc.ID] = true + } + } + } + for _, msg := range trimmed { + if msg.Role == chat.MessageRoleTool { + assert.True(t, toolCallIDs[msg.ToolCallID], + "Tool result %s should have a corresponding assistant tool call", msg.ToolCallID) + } + } +} + +func TestTrimMessagesPreservesAllUserMessages(t *testing.T) { + // Multiple user messages interspersed with tool calls + messages := []chat.Message{ + {Role: chat.MessageRoleSystem, Content: "System prompt"}, + {Role: chat.MessageRoleUser, Content: "First request"}, + } + + for i := range 10 { + toolID := fmt.Sprintf("tool_%d", i) + messages = append(messages, chat.Message{ + Role: chat.MessageRoleAssistant, + ToolCalls: []tools.ToolCall{{ID: toolID}}, + }, chat.Message{ + Role: chat.MessageRoleTool, + Content: fmt.Sprintf("result %d", i), + ToolCallID: toolID, + }) + } + + messages = append(messages, chat.Message{Role: chat.MessageRoleUser, Content: "Follow-up request"}) + + for i := 10; i < 20; i++ { + toolID := fmt.Sprintf("tool_%d", i) + messages = append(messages, chat.Message{ + Role: chat.MessageRoleAssistant, + ToolCalls: []tools.ToolCall{{ID: toolID}}, + }, chat.Message{ + Role: chat.MessageRoleTool, + Content: fmt.Sprintf("result %d", i), + ToolCallID: toolID, + }) + } + + // 42 conversation messages (2 user + 20 assistant + 20 tool), limit to 10 + trimmed := trimMessages(messages, 10) + + var userContents []string + for _, msg := range trimmed { + if msg.Role == chat.MessageRoleUser { + userContents = append(userContents, msg.Content) + } + } + + assert.Len(t, userContents, 2, "Both user messages must be preserved") + assert.Equal(t, "First request", userContents[0]) + assert.Equal(t, "Follow-up request", userContents[1]) } diff --git a/pkg/session/session_test.go b/pkg/session/session_test.go index 1ce354488..630f94597 100644 --- a/pkg/session/session_test.go +++ b/pkg/session/session_test.go @@ -58,9 +58,7 @@ func TestTrimMessagesWithToolCalls(t *testing.T) { result := trimMessages(messages, maxItems) - // Should keep last 3 messages, but ensure tool call consistency - assert.Len(t, result, maxItems) - + // Both user messages are protected, so result includes them plus the most recent assistant/tool pair toolCalls := make(map[string]bool) for _, msg := range result { if msg.Role == chat.MessageRoleAssistant {