From d0f0f92670280358a88b7b81c19fbae3323476ff Mon Sep 17 00:00:00 2001 From: Jijun Leng <962285+jjleng@users.noreply.github.com> Date: Fri, 11 Apr 2025 23:55:40 -0700 Subject: [PATCH] feat: cap max prompt checkpoints to be 4 --- cp-agent/cp_agent/agents/coder/agent.py | 16 ++++++++++++++-- .../cp_agent/agents/coder/message_manager.py | 10 +++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/cp-agent/cp_agent/agents/coder/agent.py b/cp-agent/cp_agent/agents/coder/agent.py index dfa5094..8efe0b5 100644 --- a/cp-agent/cp_agent/agents/coder/agent.py +++ b/cp-agent/cp_agent/agents/coder/agent.py @@ -7,6 +7,7 @@ from uuid import UUID from zoneinfo import ZoneInfo +import litellm from litellm import CustomStreamWrapper, acompletion from loguru import logger @@ -468,9 +469,20 @@ async def _recursively_process_messages( yield TextEvent(text=error_msg) self.state_manager.task.timeout() + except litellm.exceptions.BadRequestError as e: + logger.exception(e) + error_msg = f"Bad request error: {e.message}" + logger.error(error_msg) + await self.message_manager.add_assistant_message(error_msg) + yield TextEvent(text=error_msg) + self.state_manager.task.fail() + except Exception as e: - logger.error(f"Stream processing error: {e}", exc_info=True) - error_msg = f"Error processing assistant response: {str(e)}" + logger.error( + f"Task {self.state_manager.task.id} failed: {e}", exc_info=True + ) + error_msg = f"Task failed: {str(e)}" + await self.message_manager.add_assistant_message(error_msg) yield TextEvent(text=error_msg) self.state_manager.task.fail() diff --git a/cp-agent/cp_agent/agents/coder/message_manager.py b/cp-agent/cp_agent/agents/coder/message_manager.py index aa56382..086eaf9 100644 --- a/cp-agent/cp_agent/agents/coder/message_manager.py +++ b/cp-agent/cp_agent/agents/coder/message_manager.py @@ -48,6 +48,8 @@ def __init__( self.chat_history: list[dict[str, Any]] = [] self.enable_prompt_cache = enable_prompt_cache + self.checkpoint_count = 0 + self.max_checkpoints = 4 async def compact_memory(self) -> None: """Manually trigger memory compaction.""" @@ -75,7 +77,8 @@ async def add_user_message( async def add_assistant_message(self, content: str) -> None: """Add assistant message to both API memory and chat history.""" - if self.enable_prompt_cache: + if self.enable_prompt_cache and self.checkpoint_count < self.max_checkpoints: + self.checkpoint_count += 1 message_content: list[MessagePart] = [create_text_block(content)] if not IS_BEDROCK: message_content = [create_text_block(content, "ephemeral")] @@ -87,6 +90,11 @@ async def add_assistant_message(self, content: str) -> None: self.memory.rpush("messages", dict(message)) + async def reset_checkpoints(self) -> None: + """Reset the checkpoint counter, typically after a new conversation starts.""" + self.checkpoint_count = 0 + logger.debug("Reset checkpoint counter") + async def add_memory_item( self, content: MessageContent, role: str = "user" ) -> None: