Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 17 additions & 29 deletions cp-agent/cp_agent/agents/coder/message_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
Attachment,
Message,
MessageContent,
MessagePart,
create_message_content,
create_text_block,
)

IS_BEDROCK = False


class MessageManager:
"""Handles message storage, retrieval, and formatting for the coder agent."""
Expand Down Expand Up @@ -66,38 +70,22 @@ async def add_user_message(
content = create_message_content(text, attachments)

msg = Message(content=content, role="user")
self.memory.rpush("messages", {"role": "user", "content": content})

self.chat_history.append(
{
"role": "user",
"content": text,
"timestamp": msg.timestamp,
"attachments": [
{
"url": att.url,
"type": att.type,
"filename": att.filename,
"mime_type": att.mime_type,
"size": att.size,
}
for att in attachments
],
}
)
self.memory.rpush("messages", dict(msg))

async def add_assistant_message(self, content: str) -> None:
"""Add assistant message to both API memory and chat history."""
msg = Message(content=content, role="assistant")
self.memory.rpush("messages", {"role": "assistant", "content": content})

self.chat_history.append(
{
"role": "assistant",
"content": content,
"timestamp": msg.timestamp,
}
)

if self.enable_prompt_cache:
message_content: list[MessagePart] = [create_text_block(content)]
if not IS_BEDROCK:
message_content = [create_text_block(content, "ephemeral")]
else:
message_content.append({"cachePoint": {"type": "default"}})
Comment on lines +79 to +83
Copy link

Copilot AI Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The caching implementation for assistant messages uses an ephemeral cache_control via create_text_block in one branch and appends a separate cachePoint block in the other branch. Consider standardizing the caching representation to avoid potential inconsistencies in message format.

Suggested change
message_content: list[MessagePart] = [create_text_block(content)]
if not IS_BEDROCK:
message_content = [create_text_block(content, "ephemeral")]
else:
message_content.append({"cachePoint": {"type": "default"}})
cache_type = "ephemeral" if not IS_BEDROCK else "default"
message_content: list[MessagePart] = [
create_text_block(content, cache_type),
{"cachePoint": {"type": cache_type}},
]

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bedrock and other vendors use different formats, sigh

message = Message(content=message_content, role="assistant")
else:
message = Message(content=content, role="assistant")

self.memory.rpush("messages", dict(message))

async def add_memory_item(
self, content: MessageContent, role: str = "user"
Expand Down
2 changes: 1 addition & 1 deletion cp-agent/cp_agent/agents/searcher/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ async def search_from_message(self, message: MessageContent) -> list[str]:
# Extract and concatenate all text content
text_parts: list[str] = []
for part in message_parts:
if part["type"] == "text":
if part.get("type") == "text":
text_block = cast(TextBlock, part)
text_parts.append(text_block["text"])

Expand Down
16 changes: 2 additions & 14 deletions cp-agent/cp_agent/context/context_enricher.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,7 @@ async def _prime_content(
content = await content_generator()
content_data = create_text_block(f"<{tag}>{content}</{tag}>")
msg = Message(content=[content_data], role="user")
self.memory.rpush(
"messages",
{
"role": "user",
"content": msg.content,
},
)
self.memory.rpush("messages", dict(msg))

async def _generate_project_info(self) -> str:
"""Generate project information content."""
Expand Down Expand Up @@ -206,13 +200,7 @@ async def prime_supabase_general_instructions(
content = await general_instructions(self.supabase_util.supabase_project_id)
content_data = create_text_block(f"<{tag}>{content}</{tag}>")
msg = Message(content=[content_data], role="user")
self.memory.rpush(
"messages",
{
"role": "user",
"content": msg.content,
},
)
self.memory.rpush("messages", dict(msg))

async def prime_supabase_report(self, with_cache: bool) -> None:
"""Prime the Supabase report with optional caching.
Expand Down
31 changes: 24 additions & 7 deletions cp-agent/cp_agent/utils/message_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from base64 import b64encode
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime
from typing import Literal, Optional, TypedDict

Expand All @@ -12,27 +12,38 @@ class ImageUrl(TypedDict):
url: str


class CacheControl(TypedDict):
type: str


class CachePoint(TypedDict):
type: str


class ImageBlock(TypedDict):
type: str
image_url: ImageUrl


class TextBlock(TypedDict):
class TextBlock(TypedDict, total=False):
type: str
text: str
cache_control: Optional[CacheControl]


class CachePointBlock(TypedDict):
cachePoint: CachePoint

MessagePart = TextBlock | ImageBlock

MessagePart = TextBlock | ImageBlock | CachePointBlock
MessageContent = str | list[MessagePart]


@dataclass
class Message:
class Message(TypedDict):
"""Message in agent conversation with multimodal support."""

content: MessageContent
role: str # 'user' or 'assistant'
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())


@dataclass
Expand All @@ -47,8 +58,14 @@ class Attachment:
size: Optional[int] = None


def create_text_block(text: str) -> TextBlock:
def create_text_block(text: str, cache_control_type: Optional[str] = None) -> TextBlock:
"""Create a text block for message content."""
if cache_control_type:
return {
"type": "text",
"text": text,
"cache_control": {"type": cache_control_type},
}
return {"type": "text", "text": text}


Expand Down
90 changes: 49 additions & 41 deletions cp-agent/tests/test_context_enricher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,23 @@ def test_scan_messages_for_tag_string_content(
context_enricher: ContextEnricher,
) -> None:
"""Test scanning messages with string content."""
msg = Message(content="<tag>Some content", role="user")
context_enricher.memory.rpush("messages", msg.__dict__)
msg: Message = {"content": "<tag>Some content", "role": "user"}
context_enricher.memory.rpush("messages", msg)

assert context_enricher.scan_messages_for_tag("<tag>") is True
assert context_enricher.scan_messages_for_tag("<other>") is False


def test_scan_messages_for_tag_list_content(context_enricher: ContextEnricher) -> None:
"""Test scanning messages with list content."""
msg = Message(
content=[
msg: Message = {
"content": [
create_text_block("<tag>First block"),
create_text_block("Second block"),
],
role="user",
)
context_enricher.memory.rpush("messages", msg.__dict__)
"role": "user",
}
context_enricher.memory.rpush("messages", msg)

assert context_enricher.scan_messages_for_tag("<tag>") is True
assert context_enricher.scan_messages_for_tag("<other>") is False
Expand All @@ -49,35 +49,40 @@ def test_scan_messages_for_tag_with_whitespace(
context_enricher: ContextEnricher,
) -> None:
"""Test scanning messages with whitespace before tag."""
msg = Message(content=" <tag>Some content", role="user")
context_enricher.memory.rpush("messages", msg.__dict__)
msg: Message = {"content": " <tag>Some content", "role": "user"}
context_enricher.memory.rpush("messages", msg)

assert context_enricher.scan_messages_for_tag("<tag>") is True


def test_scan_messages_for_tag_mixed_content(context_enricher: ContextEnricher) -> None:
"""Test scanning messages with mixed content types."""
msg1 = Message(content="Regular message", role="user")
msg2 = Message(
content=[create_text_block("<tag>Block 1"), create_text_block("Block 2")],
role="user",
)
msg3 = Message(content="<other>Different tag", role="user")
messages = [
{"content": "Regular message", "role": "user"},
{
"content": [
create_text_block("<tag>Block 1"),
create_text_block("Block 2"),
],
"role": "user",
},
{"content": "<other>Different tag", "role": "user"},
]

for msg in [msg1, msg2, msg3]:
context_enricher.memory.rpush("messages", msg.__dict__)
for msg in messages:
context_enricher.memory.rpush("messages", msg)

assert context_enricher.scan_messages_for_tag("<tag>") is True
assert context_enricher.scan_messages_for_tag("<missing>") is False


def test_delete_memory_by_tag_string_content(context_enricher: ContextEnricher) -> None:
"""Test deleting messages with string content."""
msg1 = Message(content="<tag>Delete this", role="user")
msg2 = Message(content="Keep this", role="user")
msg1: Message = {"content": "<tag>Delete this", "role": "user"}
msg2: Message = {"content": "Keep this", "role": "user"}

for msg in [msg1, msg2]:
context_enricher.memory.rpush("messages", msg.__dict__)
context_enricher.memory.rpush("messages", msg)

assert context_enricher.delete_memory_by_tag("<tag>") is True
messages = context_enricher.memory.lrange("messages")
Expand All @@ -87,20 +92,23 @@ def test_delete_memory_by_tag_string_content(context_enricher: ContextEnricher)

def test_delete_memory_by_tag_list_content(context_enricher: ContextEnricher) -> None:
"""Test deleting messages with list content."""
msg1 = Message(
content=[
msg1: Message = {
"content": [
create_text_block("<tag>Delete block"),
create_text_block("Other block"),
],
role="user",
)
msg2 = Message(
content=[create_text_block("Keep block 1"), create_text_block("Keep block 2")],
role="user",
)
"role": "user",
}
msg2: Message = {
"content": [
create_text_block("Keep block 1"),
create_text_block("Keep block 2"),
],
"role": "user",
}

for msg in [msg1, msg2]:
context_enricher.memory.rpush("messages", msg.__dict__)
context_enricher.memory.rpush("messages", msg)

assert context_enricher.delete_memory_by_tag("<tag>") is True
messages = context_enricher.memory.lrange("messages")
Expand All @@ -113,8 +121,8 @@ def test_delete_memory_by_tag_list_content(context_enricher: ContextEnricher) ->

def test_delete_memory_by_tag_no_matches(context_enricher: ContextEnricher) -> None:
"""Test deleting messages when no matches exist."""
msg = Message(content="No tags here", role="user")
context_enricher.memory.rpush("messages", msg.__dict__)
msg: Message = {"content": "No tags here", "role": "user"}
context_enricher.memory.rpush("messages", msg)

assert context_enricher.delete_memory_by_tag("<tag>") is False
messages = context_enricher.memory.lrange("messages")
Expand All @@ -126,20 +134,20 @@ def test_delete_memory_by_tag_multiple_matches(
) -> None:
"""Test deleting multiple messages with the same tag."""
messages = [
Message(content="<tag>First tagged", role="user"),
Message(
content=[
{"content": "<tag>First tagged", "role": "user"},
{
"content": [
create_text_block("<tag>Second tagged"),
create_text_block("extra"),
],
role="user",
),
Message(content="Keep this", role="user"),
Message(content=[create_text_block("<tag>Third tagged")], role="user"),
"role": "user",
},
{"content": "Keep this", "role": "user"},
{"content": [create_text_block("<tag>Third tagged")], "role": "user"},
]

for msg in messages:
context_enricher.memory.rpush("messages", msg.__dict__)
context_enricher.memory.rpush("messages", msg)

assert context_enricher.delete_memory_by_tag("<tag>") is True
remaining_messages = context_enricher.memory.lrange("messages")
Expand All @@ -151,8 +159,8 @@ def test_delete_memory_by_tag_with_whitespace(
context_enricher: ContextEnricher,
) -> None:
"""Test deleting messages with whitespace before tag."""
msg = Message(content=" <tag>Indented content", role="user")
context_enricher.memory.rpush("messages", msg.__dict__)
msg: Message = {"content": " <tag>Indented content", "role": "user"}
context_enricher.memory.rpush("messages", msg)

assert context_enricher.delete_memory_by_tag("<tag>") is True
assert len(context_enricher.memory.lrange("messages")) == 0
Loading