diff --git a/requirements.dev.txt b/requirements.dev.txt
index 549c8626..5322aeb8 100644
--- a/requirements.dev.txt
+++ b/requirements.dev.txt
@@ -51,6 +51,7 @@ sphinxcontrib-htmlhelp==2.0.0
 sphinxcontrib-jsmath==1.0.1
 sphinxcontrib-qthelp==1.0.3
 sphinxcontrib-serializinghtml==1.1.5
+sseclient-py==1.8.0
 toml==0.10.2
 tomli==2.0.0
 typing-extensions==4.2.0
diff --git a/src/steamship/agents/functional/functions_based.py b/src/steamship/agents/functional/functions_based.py
index 13ba79a1..a2253b6c 100644
--- a/src/steamship/agents/functional/functions_based.py
+++ b/src/steamship/agents/functional/functions_based.py
@@ -1,9 +1,11 @@
+import json
 from typing import List
 
-from steamship import Block
+from steamship import Block, MimeTypes, Tag
 from steamship.agents.functional.output_parser import FunctionsBasedOutputParser
-from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, Tool
-from steamship.data.tags.tag_constants import RoleTag
+from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, FinishAction, Tool
+from steamship.data.tags.tag_constants import RoleTag, TagKind, TagValueKey
+from steamship.data.tags.tag_utils import get_tag
 
 
 class FunctionsBasedAgent(ChatAgent):
@@ -25,13 +27,17 @@ def __init__(self, tools: List[Tool], llm: ChatLLM, **kwargs):
             output_parser=FunctionsBasedOutputParser(tools=tools), llm=llm, tools=tools, **kwargs
         )
 
+    def _get_or_create_system_message(self, context: AgentContext) -> Block:
+        sys_msg = context.chat_history.last_system_message
+        if sys_msg:
+            return sys_msg
+
+        return context.chat_history.append_system_message(text=self.PROMPT, mime_type=MimeTypes.TXT)
+
     def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]:
-        messages: List[Block] = []
+        messages: List[Block] = [self._get_or_create_system_message(context)]
 
         # get system message
-        system_message = Block(text=self.PROMPT)
-        system_message.set_chat_role(RoleTag.SYSTEM)
-        messages.append(system_message)
 
         messages_from_memory = []
         # get prior conversations
@@ -41,21 +47,15 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]:
                 .wait()
                 .to_ranked_blocks()
             )
 
-            # TODO(dougreid): we need a way to threshold message inclusion, especially for small contexts
-            # remove the actual prompt from the semantic search (it will be an exact match)
-            messages_from_memory = [
-                msg
-                for msg in messages_from_memory
-                if msg.id != context.chat_history.last_user_message.id
-            ]
-
         # get most recent context
         messages_from_memory.extend(context.chat_history.select_messages(self.message_selector))
 
         # de-dupe the messages from memory
-        ids = [context.chat_history.last_user_message.id]
+        ids = [
+            context.chat_history.last_user_message.id
+        ]  # filter out last user message, it is appended afterwards
         for msg in messages_from_memory:
             if msg.id not in ids:
                 messages.append(msg)
@@ -67,10 +67,8 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]:
 
         # this should happen BEFORE any agent/assistant messages related to tool selection
         messages.append(context.chat_history.last_user_message)
 
-        # get completed steps
-        actions = context.completed_steps
-        for action in actions:
-            messages.extend(action.to_chat_messages())
+        # get working history (completed actions)
+        messages.extend(self._function_calls_since_last_user_message(context))
 
         return messages
@@ -81,4 +79,68 @@ def next_action(self, context: AgentContext) -> Action:
 
         # Run the default LLM on those messages
         output_blocks = self.llm.chat(messages=messages, tools=self.tools)
-        return self.output_parser.parse(output_blocks[0].text, context)
+        future_action = self.output_parser.parse(output_blocks[0].text, context)
+        if not isinstance(future_action, FinishAction):
+            # record the LLM's function response in history
+            self._record_action_selection(future_action, context)
+        return future_action
+
+    def _function_calls_since_last_user_message(self, context: AgentContext) -> List[Block]:
+        function_calls = []
+        for block in context.chat_history.messages[::-1]:  # is this too inefficient at scale?
+            if block.chat_role == RoleTag.USER:
+                return reversed(function_calls)
+            if get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION):
+                function_calls.append(block)
+            elif get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION):
+                function_calls.append(block)
+        return reversed(function_calls)
+
+    def _to_openai_function_selection(self, action: Action) -> str:
+        fc = {"name": action.tool}
+        args = {}
+        for block in action.input:
+            for t in block.tags:
+                if t.kind == TagKind.FUNCTION_ARG:
+                    args[t.name] = block.as_llm_input(exclude_block_wrapper=True)
+
+        fc["arguments"] = json.dumps(args)  # the arguments must be a string value NOT a dict
+        return json.dumps(fc)
+
+    def _record_action_selection(self, action: Action, context: AgentContext):
+        tags = [
+            Tag(kind=TagKind.ROLE, name=RoleTag.ASSISTANT),
+            Tag(kind=TagKind.FUNCTION_SELECTION, name=action.tool),
+            Tag(
+                kind="request-id",
+                name=context.request_id,
+                value={TagValueKey.STRING_VALUE: context.request_id},
+            ),
+        ]
+        context.chat_history.file.append_block(
+            text=self._to_openai_function_selection(action), tags=tags, mime_type=MimeTypes.TXT
+        )
+
+    def record_action_run(self, action: Action, context: AgentContext):
+        super().record_action_run(action, context)
+
+        tags = [
+            Tag(
+                kind=TagKind.ROLE,
+                name=RoleTag.FUNCTION,
+                value={TagValueKey.STRING_VALUE: action.tool},
+            ),
+            Tag(
+                kind="request-id",
+                name=context.request_id,
+                value={TagValueKey.STRING_VALUE: context.request_id},
+            ),
+        ]
+        # TODO(dougreid): I'm not convinced this is correct for tools that return multiple values.
+        # It _feels_ like these should be named and inlined as a single message in history, etc.
+        for block in action.output:
+            context.chat_history.file.append_block(
+                text=block.as_llm_input(exclude_block_wrapper=True),
+                tags=tags,
+                mime_type=block.mime_type,
+            )
diff --git a/src/steamship/agents/functional/output_parser.py b/src/steamship/agents/functional/output_parser.py
index 9dbb8fa7..45c70375 100644
--- a/src/steamship/agents/functional/output_parser.py
+++ b/src/steamship/agents/functional/output_parser.py
@@ -4,9 +4,9 @@
 from json import JSONDecodeError
 from typing import Dict, List, Optional
 
-from steamship import Block, MimeTypes, Steamship
+from steamship import Block, MimeTypes, Steamship, Tag
 from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool
-from steamship.data.tags.tag_constants import RoleTag
+from steamship.data.tags.tag_constants import RoleTag, TagKind
 from steamship.utils.utils import is_valid_uuid4
 
 
@@ -43,16 +43,45 @@ def _extract_action_from_function_call(self, text: str, context: AgentContext) -
             try:
                 args = json.loads(arguments)
                 if text := args.get("text"):
-                    input_blocks.append(Block(text=text, mime_type=MimeTypes.TXT))
+                    input_blocks.append(
+                        Block(
+                            text=text,
+                            tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
+                            mime_type=MimeTypes.TXT,
+                        )
+                    )
                 elif uuid_arg := args.get("uuid"):
-                    input_blocks.append(Block.get(context.client, _id=uuid_arg))
+                    existing_block = Block.get(context.client, _id=uuid_arg)
+                    tag = Tag.create(
+                        existing_block.client,
+                        file_id=existing_block.file_id,
+                        block_id=existing_block.id,
+                        kind=TagKind.FUNCTION_ARG,
+                        name="uuid",
+                    )
+                    existing_block.tags.append(tag)
+                    input_blocks.append(existing_block)
             except json.decoder.JSONDecodeError:
                 if isinstance(arguments, str):
                     if is_valid_uuid4(arguments):
-                        input_blocks.append(Block.get(context.client, _id=uuid_arg))
+                        existing_block = Block.get(context.client, _id=arguments)
+                        tag = Tag.create(
+                            existing_block.client,
+                            file_id=existing_block.file_id,
+                            block_id=existing_block.id,
+                            kind=TagKind.FUNCTION_ARG,
+                            name="uuid",
+                        )
+                        existing_block.tags.append(tag)
+                        input_blocks.append(existing_block)
                     else:
-                        input_blocks.append(Block(text=arguments, mime_type=MimeTypes.TXT))
-
+                        input_blocks.append(
+                            Block(
+                                text=arguments,
+                                tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
+                                mime_type=MimeTypes.TXT,
+                            )
+                        )
         return Action(tool=tool.name, input=input_blocks, context=context)
 
     @staticmethod
@@ -112,4 +141,5 @@ def parse(self, text: str, context: AgentContext) -> Action:
         finish_blocks = FunctionsBasedOutputParser._blocks_from_text(context.client, text)
         for finish_block in finish_blocks:
             finish_block.set_chat_role(RoleTag.ASSISTANT)
+            finish_block.set_request_id(context.request_id)
         return FinishAction(output=finish_blocks, context=context)
diff --git a/src/steamship/agents/llms/openai.py b/src/steamship/agents/llms/openai.py
index 5ec142c8..35bdd5e3 100644
--- a/src/steamship/agents/llms/openai.py
+++ b/src/steamship/agents/llms/openai.py
@@ -58,6 +58,7 @@ def complete(self, prompt: str, stop: Optional[str] = None, **kwargs) -> List[Bl
         if "max_tokens" in kwargs:
             options["max_tokens"] = kwargs["max_tokens"]
 
+        # TODO(dougreid): do we care about streaming here? should we take a kwarg that is file_id ?
         action_task = self.generator.generate(text=prompt, options=options)
         action_task.wait()
         return action_task.output.blocks
@@ -84,12 +85,8 @@ def chat(self, messages: List[Block], tools: Optional[List[Tool]], **kwargs) ->
         Supported kwargs include:
         - `max_tokens` (controls the size of LLM responses)
         """
-
-        temp_file = File.create(
-            client=self.client,
-            blocks=messages,
-            tags=[Tag(kind=TagKind.GENERATION, name=GenerationTag.PROMPT_COMPLETION)],
-        )
+        if len(messages) <= 0:
+            return []
 
         options = {}
         if len(tools) > 0:
@@ -119,7 +116,31 @@ def chat(self, messages: List[Block], tools: Optional[List[Tool]], **kwargs) ->
 
         logging.info(f"OpenAI ChatComplete ({messages[-1].as_llm_input()})", extra=extra)
 
-        tool_selection_task = self.generator.generate(input_file_id=temp_file.id, options=options)
-        tool_selection_task.wait()
-
-        return tool_selection_task.output.blocks
+        # for streaming use cases, we want to always use the existing file
+        # the way to detect this would be if all messages were from the same file
+        if self._from_same_file(blocks=messages):
+            file_id = messages[0].file_id
+            block_indices = [b.index_in_file for b in messages]
+            generate_task = self.generator.generate(
+                input_file_id=file_id,
+                input_file_block_index_list=block_indices,
+                options=options,
+                append_output_to_file=True,
+            )
+        else:
+            tags = [Tag(kind=TagKind.GENERATION, name=GenerationTag.PROMPT_COMPLETION)]
+            temp_file = File.create(client=self.client, blocks=messages, tags=tags)
+            generate_task = self.generator.generate(input_file_id=temp_file.id, options=options)
+
+        generate_task.wait()
+
+        return generate_task.output.blocks
+
+    def _from_same_file(self, blocks: List[Block]) -> bool:
+        if len(blocks) <= 1:
+            return True
+        file_id = blocks[0].file_id
+        for b in blocks[1:]:
+            if b.file_id != file_id:
+                return False
+        return True
diff --git a/src/steamship/agents/schema/action.py b/src/steamship/agents/schema/action.py
index 929f5afd..e57ec6eb 100644
--- a/src/steamship/agents/schema/action.py
+++ b/src/steamship/agents/schema/action.py
@@ -1,10 +1,9 @@
 from typing import List, Optional
 
 from pydantic import BaseModel
+from pydantic.fields import Field
 
-from steamship import Block, Tag
-from steamship.data import TagKind
-from steamship.data.tags.tag_constants import RoleTag
+from steamship import Block
 
 
 class Action(BaseModel):
@@ -22,32 +21,12 @@ class Action(BaseModel):
     output: Optional[List[Block]]
     """Any direct output produced by the Tool."""
 
-    is_final: bool = False
+    is_final: bool = Field(default=False)
     """Whether this Action should be the final action performed in a reasoning loop.
 
     Setting this to True means that the executing Agent should halt any reasoning.
     """
 
-    def to_chat_messages(self) -> List[Block]:
-        tags = [
-            Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION),
-            Tag(kind="name", name=self.tool),
-        ]
-        blocks = []
-        for block in self.output:
-            # TODO(dougreid): should we revisit as_llm_input? we might need only the UUID...
-            blocks.append(
-                Block(
-                    text=block.as_llm_input(exclude_block_wrapper=True),
-                    tags=tags,
-                    mime_type=block.mime_type,
-                )
-            )
-
-        # TODO(dougreid): revisit when have multiple output functions.
-        # Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation.
-        return blocks
-
 
 class FinishAction(Action):
     """Represents a final selected action in an Agent Execution."""
diff --git a/src/steamship/agents/schema/agent.py b/src/steamship/agents/schema/agent.py
index 890b29dc..110e3b8a 100644
--- a/src/steamship/agents/schema/agent.py
+++ b/src/steamship/agents/schema/agent.py
@@ -11,7 +11,8 @@
 from steamship.agents.schema.message_selectors import MessageSelector, NoMessages
 from steamship.agents.schema.output_parser import OutputParser
 from steamship.agents.schema.tool import Tool
-from steamship.data.tags.tag_constants import RoleTag
+from steamship.data.tags.tag_constants import RoleTag, TagKind
+from steamship.data.tags.tag_utils import get_tag
 
 
 class Agent(BaseModel, ABC):
@@ -31,6 +32,11 @@ class Agent(BaseModel, ABC):
     def next_action(self, context: AgentContext) -> Action:
         pass
 
+    @abstractmethod
+    def record_action_run(self, action: Action, context: AgentContext):
+        # TODO(dougreid): should this method (or just bit) actually be on AgentContext?
+        context.completed_steps.append(action)
+
 
 class LLMAgent(Agent):
     """LLMAgents choose next actions for an AgentService based on interactions with an LLM."""
@@ -53,17 +59,16 @@ def messages_to_prompt_history(messages: List[Block]) -> str:
             # Internal Status Messages are not considered part of **prompt** history.
             # Their inclusion could lead to problematic LLM behavior, etc.
             # As such are explicitly skipped here:
-            # - DON'T RETURN AGENT MESSAGES
-            # - DON'T RETURN TOOL MESSAGES
-            # - DON'T RETURN LLM MESSAGES
+            # - DON'T RETURN STATUS MESSAGES
+            # - DON'T RETURN FUNCTION or FUNCTION_SELECTION MESSAGES
             if role == RoleTag.USER:
                 as_strings.append(f"User: {block.text}")
-            elif role == RoleTag.ASSISTANT:
+            elif role == RoleTag.ASSISTANT and (
+                get_tag(block.tags, TagKind.FUNCTION_SELECTION) is None
+            ):
                 as_strings.append(f"Assistant: {block.text}")
             elif role == RoleTag.SYSTEM:
                 as_strings.append(f"System: {block.text}")
-            elif role == RoleTag.FUNCTION:
-                as_strings.append(f"Function: {block.text}")
 
         return "\n".join(as_strings)
diff --git a/src/steamship/agents/schema/chathistory.py b/src/steamship/agents/schema/chathistory.py
index 43de4cf6..eb6610f1 100644
--- a/src/steamship/agents/schema/chathistory.py
+++ b/src/steamship/agents/schema/chathistory.py
@@ -281,6 +281,28 @@ def clear(self):
 
         self.refresh()
 
+    def append_status_message_with_role(
+        self,
+        text: str = None,
+        role: RoleTag = RoleTag.USER,
+        tags: List[Tag] = None,
+        content: Union[str, bytes] = None,
+        url: Optional[str] = None,
+        mime_type: Optional[MimeTypes] = None,
+    ) -> Block:
+        """Append a new block to this with content provided by the end-user."""
+        tags = tags or []
+        tags.append(
+            Tag(
+                kind=TagKind.STATUS_MESSAGE,
+                name=ChatTag.ROLE,
+                value={TagValueKey.STRING_VALUE: role},
+            )
+        )
+        return self.file.append_block(
+            text=text, tags=tags, content=content, url=url, mime_type=mime_type
+        )
+
     def append_agent_message(
         self,
         text: str = None,
@@ -290,7 +312,9 @@ def append_agent_message(
         mime_type: Optional[MimeTypes] = None,
     ) -> Block:
         """Append a new block to this with status update messages from the Agent."""
-        return self.append_message_with_role(text, RoleTag.AGENT, tags, content, url, mime_type)
+        return self.append_status_message_with_role(
+            text, RoleTag.AGENT, tags, content, url, mime_type
+        )
 
     def append_tool_message(
         self,
@@ -301,7 +325,9 @@ def append_tool_message(
         mime_type: Optional[MimeTypes] = None,
     ) -> Block:
         """Append a new block to this with status update messages from the Agent."""
-        return self.append_message_with_role(text, RoleTag.TOOL, tags, content, url, mime_type)
+        return self.append_status_message_with_role(
+            text, RoleTag.TOOL, tags, content, url, mime_type
+        )
 
     def append_llm_message(
         self,
@@ -312,7 +338,9 @@ def append_llm_message(
         mime_type: Optional[MimeTypes] = None,
     ) -> Block:
         """Append a new block to this with status update messages from the Agent."""
-        return self.append_message_with_role(text, RoleTag.LLM, tags, content, url, mime_type)
+        return self.append_status_message_with_role(
+            text, RoleTag.LLM, tags, content, url, mime_type
+        )
 
 
 class ChatHistoryLoggingHandler(StreamHandler):
@@ -324,10 +352,12 @@ class ChatHistoryLoggingHandler(StreamHandler):
     chat_history: ChatHistory
     log_level: any
     streaming_opts: StreamingOpts
+    request_id: str
 
     def __init__(
         self,
         chat_history: ChatHistory,
+        request_id: str,
         log_level: any = logging.INFO,
         streaming_opts: Optional[StreamingOpts] = None,
     ):
@@ -340,6 +370,7 @@ def __init__(
             self.streaming_opts = streaming_opts
         else:
             self.streaming_opts = StreamingOpts()
+        self.request_id = request_id
 
     def emit(self, record):
         if record.levelno < self.log_level:
@@ -364,16 +395,16 @@ def _append_message(self, message_dict: dict, author_kind: str):
         message = message_dict.get("message", None)
         message_type = message_dict.get(AgentLogging.MESSAGE_TYPE, AgentLogging.MESSAGE)
 
+        req_id_tag = Tag(
+            kind="request-id",
+            name=self.request_id,
+            value={TagValueKey.STRING_VALUE: self.request_id},
+        )
+
         if author_kind == AgentLogging.AGENT:
             return self.chat_history.append_agent_message(
                 text=message,
-                tags=[
-                    Tag(
-                        kind=TagKind.AGENT_STATUS_MESSAGE,
-                        name=message_type,
-                        value={TagValueKey.STRING_VALUE: message},
-                    ),
-                ],
+                tags=[req_id_tag],
                 mime_type=MimeTypes.TXT,
             )
         elif author_kind == AgentLogging.TOOL:
@@ -385,7 +416,8 @@ def _append_message(self, message_dict: dict, author_kind: str):
                         kind=TagKind.TOOL_STATUS_MESSAGE,
                         name=message_type,
                         value={TagValueKey.STRING_VALUE: message, "tool": tool_name},
-                    )
+                    ),
+                    req_id_tag,
                 ],
                 mime_type=MimeTypes.TXT,
             )
@@ -398,7 +430,8 @@ def _append_message(self, message_dict: dict, author_kind: str):
                         kind=TagKind.LLM_STATUS_MESSAGE,
                         name=message_type,
                         value={TagValueKey.STRING_VALUE: message, "llm": llm_name},
-                    )
+                    ),
+                    req_id_tag,
                 ],
                 mime_type=MimeTypes.TXT,
             )
diff --git a/src/steamship/agents/schema/context.py b/src/steamship/agents/schema/context.py
index dad86531..854325d3 100644
--- a/src/steamship/agents/schema/context.py
+++ b/src/steamship/agents/schema/context.py
@@ -1,4 +1,5 @@
 import logging
+import uuid
 from typing import Any, Callable, Dict, List, Optional
 
 from steamship import Block, Steamship, Tag
@@ -49,10 +50,14 @@ def id(self) -> str:
     """Caches all interations with LLMs within a Context.
     This provides a way to avoid duplicated calls to LLMs when within the same context."""
 
-    def __init__(self, streaming_opts: Optional[StreamingOpts] = None):
+    request_id: str
+    """Identifier for the current request being handled by this context."""
+
+    def __init__(self, request_id: str, streaming_opts: Optional[StreamingOpts] = None):
         self.metadata = {}
         self.completed_steps = []
         self.emit_funcs = []
+        self.request_id = request_id  # TODO: protect this?
         if streaming_opts is not None:
             self._streaming_opts = streaming_opts
         else:
@@ -67,14 +72,18 @@ def get_or_create(
         use_llm_cache: Optional[bool] = False,
         use_action_cache: Optional[bool] = False,
         streaming_opts: Optional[StreamingOpts] = None,
+        request_id: Optional[str] = None,
     ):
         from steamship.agents.schema.chathistory import ChatHistory
 
         if streaming_opts is None:
             streaming_opts = StreamingOpts()
 
+        if request_id is None:
+            request_id = str(uuid.uuid4())
+
         history = ChatHistory.get_or_create(client, context_keys, tags, searchable=searchable)
-        context = AgentContext(streaming_opts=streaming_opts)
+        context = AgentContext(streaming_opts=streaming_opts, request_id=request_id)
         context.chat_history = history
         context.client = client
@@ -97,7 +106,9 @@ def __enter__(self):
 
         if self._streaming_opts.stream_intermediate_events:
             self._chat_history_logger = ChatHistoryLoggingHandler(
-                chat_history=self.chat_history, streaming_opts=self._streaming_opts
+                chat_history=self.chat_history,
+                streaming_opts=self._streaming_opts,
+                request_id=self.request_id,
             )
             logger = logging.getLogger()
             logger.addHandler(self._chat_history_logger)
diff --git a/src/steamship/agents/schema/message_selectors.py b/src/steamship/agents/schema/message_selectors.py
index 6c5b0f26..dfd31b4b 100644
--- a/src/steamship/agents/schema/message_selectors.py
+++ b/src/steamship/agents/schema/message_selectors.py
@@ -5,7 +5,8 @@
 from pydantic.main import BaseModel
 
 from steamship import Block
-from steamship.data.tags.tag_constants import RoleTag
+from steamship.data.tags.tag_constants import RoleTag, TagKind
+from steamship.data.tags.tag_utils import get_tag
 
 
 class MessageSelector(BaseModel, ABC):
@@ -29,20 +30,34 @@ def is_assistant_message(block: Block) -> bool:
     return role == RoleTag.ASSISTANT
 
 
+def is_assistant_function_message(block: Block) -> bool:
+    is_function_selection = get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION)
+    return is_assistant_message(block) and is_function_selection
+
+
+def is_user_history_message(block: Block) -> bool:
+    return is_user_message(block) or (
+        is_assistant_message(block) and not is_assistant_function_message(block)
+    )
+
+
 class MessageWindowMessageSelector(MessageSelector):
     k: int
 
     def get_messages(self, messages: List[Block]) -> List[Block]:
         msgs = messages[:]
         msgs.pop()  # don't add the current prompt to the memory
-        if len(msgs) <= (self.k * 2):
-            return msgs
+        history_msgs = [
+            msg for msg in msgs if is_user_history_message(msg)
+        ]  # filter to only user history messages
+        if len(history_msgs) <= (self.k * 2):
+            return history_msgs
 
         selected_msgs = []
         limit = self.k * 2
-        scope = msgs[len(messages) - limit :]
+        scope = history_msgs[len(history_msgs) - limit :]
         for block in scope:
-            if is_user_message(block) or is_assistant_message(block):
+            if is_user_history_message(block):
                 selected_msgs.append(block)
 
         return selected_msgs
@@ -63,7 +78,10 @@ def get_messages(self, messages: List[Block]) -> List[Block]:
         msgs = messages[:]
         msgs.pop()  # don't add the current prompt to the memory
 
-        for block in reversed(msgs):
+        history_msgs = [
+            msg for msg in msgs if is_user_history_message(msg)
+        ]  # filter to only user history messages
+        for block in reversed(history_msgs):
             if block.chat_role != RoleTag.SYSTEM and current_tokens < self.max_tokens:
                 block_tokens = tokens(block)
                 if block_tokens + current_tokens < self.max_tokens:
diff --git a/src/steamship/agents/service/agent_service.py b/src/steamship/agents/service/agent_service.py
index 26c75410..46367d88 100644
--- a/src/steamship/agents/service/agent_service.py
+++ b/src/steamship/agents/service/agent_service.py
@@ -3,15 +3,24 @@
 from collections import defaultdict
 from typing import Dict, List, Optional
 
-from steamship import Block, SteamshipError, Task
+from pydantic.main import BaseModel
+
+from steamship import Block, File, SteamshipError, Task
 from steamship.agents.llms.openai import OpenAI
 from steamship.agents.logging import AgentLogging, StreamingOpts
 from steamship.agents.schema import Action, Agent, FinishAction
 from steamship.agents.schema.context import AgentContext, EmitFunc, Metadata
 from steamship.agents.utils import with_llm
+from steamship.data import TagKind
+from steamship.data.tags.tag_constants import ChatTag
 from steamship.invocable import PackageService, post
 
 
+class StreamingResponse(BaseModel):
+    task: Task
+    file: File
+
+
 def build_context_appending_emit_func(
     context: AgentContext, make_blocks_public: Optional[bool] = False
 ) -> EmitFunc:
@@ -116,6 +125,7 @@ def next_action(self, agent: Agent, input_blocks: List[Block], context: AgentCon
             },
         )
 
+        # save action selection to history...
         return action
 
     def run_action(self, agent: Agent, action: Action, context: AgentContext):
@@ -141,7 +151,7 @@ def run_action(self, agent: Agent, action: Action, context: AgentContext):
                 },
             )
             action.output = output_blocks
-            context.completed_steps.append(action)
+            agent.record_action_run(action, context)
             return
 
         tool = next((tool for tool in agent.tools if tool.name == action.tool), None)
@@ -182,7 +192,8 @@ def run_action(self, agent: Agent, action: Action, context: AgentContext):
         action.is_final = (
             tool.is_final
         )  # Permit the tool to decide if this action should halt the reasoning loop.
-        context.completed_steps.append(action)
+
+        agent.record_action_run(action, context)
 
         if context.action_cache and tool.cacheable:
             context.action_cache.update(key=action, value=action.output)
@@ -317,6 +328,7 @@ def build_default_context(self, context_id: Optional[str] = None, **kwargs) -> A
 
         context = AgentContext.get_or_create(
             client=self.client,
+            request_id=self.client.config.request_id,
            context_keys={"id": f"{context_id}"},
             use_llm_cache=use_llm_cache,
             use_action_cache=use_action_cache,
@@ -338,6 +350,33 @@ def build_default_context(self, context_id: Optional[str] = None, **kwargs) -> A
         context = with_llm(context=context, llm=llm)
         return context
 
+    @post("async_prompt")
+    def async_prompt(
+        self, prompt: Optional[str] = None, context_id: Optional[str] = None, **kwargs
+    ) -> StreamingResponse:
+        with self.build_default_context(context_id, **kwargs) as context:
+            ctx_id = context_id
+
+            # if no context ID is provided, we need to make sure that the streaming context ID
+            # is the same one as the non-streaming.
+            if not ctx_id:
+                ctx_file = context.chat_history.file
+                for tag in ctx_file.tags:
+                    if tag.kind == TagKind.CHAT and tag.name == ChatTag.CONTEXT_KEYS:
+                        if value := tag.value:
+                            ctx_id = value.get("id", None)
+
+            # if you can't find a consistent context_id, then there is no way to provide an accurate
+            # streaming endpoint.
+            if not ctx_id:
+                # TODO(dougreid): this points to a slight flaw in the context_keys vs. context_id
+                raise SteamshipError("Error setting up context: no id found for context.")
+
+            task = self.invoke_later(
+                "/prompt", arguments={"prompt": prompt, "context_id": ctx_id, **kwargs}
+            )
+            return StreamingResponse(task=task, file=context.chat_history.file)
+
     @post("prompt")
     def prompt(
         self, prompt: Optional[str] = None, context_id: Optional[str] = None, **kwargs
diff --git a/src/steamship/data/block.py b/src/steamship/data/block.py
index b227f82f..f04d0eff 100644
--- a/src/steamship/data/block.py
+++ b/src/steamship/data/block.py
@@ -287,6 +287,13 @@ def set_chat_id(self, chat_id: str):
             tag_kind=DocTag.CHAT, tag_name=ChatTag.CHAT_ID, string_value=chat_id
         )
 
+    def set_request_id(self, request_id: Optional[str]):
+        if not request_id or len(request_id.strip()) == 0:
+            return
+        return self._one_time_set_tag(
+            tag_kind="request-id", tag_name=request_id, string_value=request_id
+        )
+
     @property
     def thread_id(self) -> Optional[str]:
         return get_tag_value_key(
diff --git a/src/steamship/data/tags/tag_constants.py b/src/steamship/data/tags/tag_constants.py
index d4847bf9..381db0f1 100644
--- a/src/steamship/data/tags/tag_constants.py
+++ b/src/steamship/data/tags/tag_constants.py
@@ -30,9 +30,12 @@ class TagKind(str, Enum):
     CHAT = "chat"
     CHAT_HISTORY_CONTEXT = "chat-history-context"
     MESSAGE_ID = "message-id"
+    STATUS_MESSAGE = "status-message"
     AGENT_STATUS_MESSAGE = "agent-status-message"
     TOOL_STATUS_MESSAGE = "tool-status-message"
     LLM_STATUS_MESSAGE = "llm-status-message"
+    FUNCTION_ARG = "function-arg"
+    FUNCTION_SELECTION = "function-selection"
 
 
 class DocTag(str, Enum):
diff --git a/src/steamship/utils/repl.py b/src/steamship/utils/repl.py
index bcfd163c..1f1c8251 100644
--- a/src/steamship/utils/repl.py
+++ b/src/steamship/utils/repl.py
@@ -14,6 +14,8 @@
 from steamship.agents.logging import AgentLogging
 from steamship.agents.schema import AgentContext, Tool
 from steamship.agents.service.agent_service import AgentService
+from steamship.data import TagKind, TagValueKey
+from steamship.data.tags.tag_utils import get_tag
 from steamship.data.workspace import Workspace
 from steamship.invocable.dev_logging_handler import DevelopmentLoggingHandler
 
@@ -202,10 +204,23 @@ def print_history(self, client: Steamship, *args, **kwargs):
         history = agent_ctx.chat_history
         history.refresh()
         for block in history.messages:
+            chat_role = block.chat_role
+            status_msg = get_tag(block.tags, kind=TagKind.STATUS_MESSAGE)
+            if not chat_role and not status_msg:
+                continue
+
+            if chat_role:
+                prefix = f"[{chat_role}]"
+            else:
+                if value := status_msg.value:
+                    prefix = f"[{value.get(TagValueKey.STRING_VALUE)} status]"
+                else:
+                    prefix = "[status]"
+
             if block.is_text():
-                print(f"[{block.chat_role}] {block.text}")
+                print(f"{prefix} {block.text}")
             else:
-                print(f"[{block.chat_role}] {block.id} ({block.mime_type})")
+                print(f"{prefix} {block.id} ({block.mime_type})")
         print("\n------------------------------\n")
         exit(0)
diff --git a/tests/steamship_tests/agents/test_agent_service.py b/tests/steamship_tests/agents/test_agent_service.py
index ecb9ffec..927cb0d8 100644
--- a/tests/steamship_tests/agents/test_agent_service.py
+++ b/tests/steamship_tests/agents/test_agent_service.py
@@ -1,13 +1,17 @@
+import json
+import time
 from typing import Any, List, Union
 
 import pytest
+import requests
 from pydantic.fields import PrivateAttr
 
 from steamship_tests import SRC_PATH
 from steamship_tests.utils.deployables import deploy_package
-from steamship import Block, Steamship, SteamshipError, Task
+from steamship import Block, File, Steamship, SteamshipError, Task, TaskState
 from steamship.agents.functional import FunctionsBasedAgent
 from steamship.agents.llms.openai import ChatOpenAI
+from steamship.agents.logging import AgentLogging
 from steamship.agents.schema import Action, AgentContext, Tool
 from steamship.agents.service.agent_service import AgentService
 from steamship.data.tags.tag_constants import ChatTag, RoleTag, TagKind, TagValueKey
@@ -25,7 +29,6 @@ def _blocks_from_invoke(client: Steamship, potential_blocks) -> List[Block]:
 
 
 @pytest.mark.usefixtures("client")
 def test_example_with_caching_service(client: Steamship):
-
     # TODO(dougreid): replace the example agent with fake/free/fast tools to minimize test time / costs?
     example_caching_agent_path = (
@@ -85,7 +88,6 @@
 
 
 class FakeUncachableTool(Tool):
-
     name = "FakeUncacheableTool"
     human_description = "Fake tool"
     agent_description = "Ignored"
@@ -248,3 +250,169 @@ def test_context_logging_to_chat_history_everything(client: Steamship):
     assert not has_status_message(chat_history.messages, RoleTag.AGENT)
     assert not has_status_message(chat_history.messages, RoleTag.LLM)
     assert has_status_message(chat_history.messages, RoleTag.TOOL)
+
+
+@pytest.fixture()
+def close_clients():
+    to_close = []
+    yield to_close
+    for item in to_close:
+        item.close()
+
+
+@pytest.mark.usefixtures("client", "close_clients")
+def test_async_prompt(client: Steamship, close_clients):  # noqa: C901
+    example_agent_service_path = (
+        SRC_PATH / "steamship" / "agents" / "examples" / "example_assistant.py"
+    )
+    with deploy_package(client, example_agent_service_path, wait_for_init=True) as (
+        _,
+        _,
+        agent_service,
+    ):
+        context_id = "some_async_fun"
+        try:
+            streaming_resp = agent_service.invoke(
+                "async_prompt",
+                prompt="who is the current president of the United States?",
+                context_id=context_id,
+            )
+        except SteamshipError as error:
+            pytest.fail(f"failed request: {error}")
+
+        assert streaming_resp is not None
+        assert streaming_resp["file"] is not None
+        assert streaming_resp["task"] is not None
+
+        file = File(client=client, **(streaming_resp["file"]))
+        streaming_task = Task(client=client, **(streaming_resp["task"]))
+
+        original_len = len(file.blocks)
+
+        while streaming_task.state in [TaskState.waiting]:
+            # tight loop to check on waiting status of Task
+            # we only want to try to stream once a Task **starts**
+            time.sleep(0.1)
+            streaming_task.refresh()
+
+        assert streaming_task.state in [TaskState.running]
+
+        block_ids_seen = []
+        llm_prompt_event_count = 0
+        function_selection_event = False
+        tool_execution_event = False
+        function_complete_event = False
+        assistant_chat_response_event = False
+
+        # TODO: it is concerning that putting this block after a Task completes is the only way it seems to work
+        num_events = 0
+        # sse event format: {'blockCreated': {'blockId': '', 'createdAt': '2023-09-25T16:12:54Z'}}
+        for event in events_while_running(client, streaming_task, file):
+            # TODO: it seems like event ids aren't consistent
+            block_creation_event = json.loads(event.data)
+            block_created = block_creation_event["blockCreated"]
+            block_id = block_created["blockId"]
+            if block_id in block_ids_seen:
+                continue
+            block_ids_seen.append(block_id)
+            num_events += 1
+            block = Block.get(client=client, _id=block_id)
+            for t in block.tags:
+                match t.kind:
+                    case TagKind.LLM_STATUS_MESSAGE:
+                        if t.name == AgentLogging.PROMPT:
+                            llm_prompt_event_count += 1
+                    case TagKind.FUNCTION_SELECTION:
+                        if t.name == "SearchTool":
+                            function_selection_event = True
+                    case TagKind.TOOL_STATUS_MESSAGE:
+                        tool_execution_event = True
+                    case TagKind.ROLE:
+                        if t.name == RoleTag.FUNCTION:
+                            function_complete_event = True
+                    case TagKind.CHAT:
+                        if (
+                            t.name == ChatTag.ROLE
+                            and t.value.get(TagValueKey.STRING_VALUE, "") == RoleTag.ASSISTANT
+                        ):
+                            assistant_chat_response_event = True
+
+        file.refresh()
+        assert (
+            len(file.blocks) > original_len
+        ), "File should have increased in size during AgentService execution"
+
+        print(f"num events: {num_events}")
+        assert num_events > 0, "Events should have been streamed during execution"
+        assert llm_prompt_event_count == 2, (
+            "At least 2 llm prompts should have happened (first for tool selection, "
+            "second for generating final answer)"
+        )
+        assert function_selection_event is True, "SearchTool should have been selected"
+        assert tool_execution_event is True, "SearchTool should log a status message"
+        assert function_complete_event is True, "SearchTool should return a response"
+        assert (
+            assistant_chat_response_event is True
+        ), "Agent should have sent the assistant chat response"
+        # assert False
+
+
+def events_while_running(client: Steamship, task: Task, file: File):
+    req_id = task.request_id
+    while task.state in [TaskState.running]:
+        yields = 0
+        # NOTE: I'm convinced that this generator of generators approach is not correct, but...
+        event_gen = events_for_file(client, file.id, req_id)
+        try:
+            for event in event_gen:
+                yields += 1
+                yield event
+        except StopIteration:
+            # not ready to stream, or done streaming.
+            print("stop iteration")
+            pass
+        print(f"total yields: {yields}")
+        # This is not ideal, but it at least should make sure we get **all** events
+        task.refresh()
+    print(f"task is complete: {task.state}")
+
+
+def events_for_file(client: Steamship, file_id: str, req_id: str):
+    print("getting events for file.")
+    import sseclient
+
+    sse_source = f"{client.config.api_base}file/{file_id}/stream?tagKindFilter=request-id&tagNameFilter={req_id}&timeoutSeconds=30"
+    headers = {
+        "Accept": "text/event-stream",
+        "X-Workspace-Id": client.get_workspace().id,
+        "Authorization": f"Bearer {client.config.api_key.get_secret_value()}",
+    }
+
+    sse_response = requests.get(sse_source, stream=True, headers=headers, timeout=45)
+    sse_client = sseclient.SSEClient(sse_response)
+    yields = 0
+    try:
+        for event in sse_client.events():
+            yields += 1
+            print(f"--> yield: {yields}")
+            yield event
+    except requests.exceptions.ConnectionError as err:
+        if "Read timed out." in str(err):
+            print("-- timeout")
+            pass
+        else:
+            sse_client.close()
+            sse_response.close()
+            raise err
+    except StopIteration:
+        print("-- stop iteration")
+        sse_client.close()
+        sse_response.close()
+    except Exception as err:
+        sse_client.close()
+        sse_response.close()
+        raise err
+    else:
+        print("-- successful close of stream.")
+        sse_client.close()
+        sse_response.close()
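
Usage sketch (not part of the diff above): the snippet below shows roughly how a client might drive the new `async_prompt` endpoint and consume the block-creation SSE stream. It is assembled from the `test_async_prompt` / `events_for_file` helpers in the tests; the package and instance handles are hypothetical placeholders, and the stream URL, query parameters, and headers simply mirror what the test constructs, so treat it as illustrative rather than a documented API.

# Illustrative only -- handles are hypothetical; stream URL/headers mirror events_for_file() above.
import requests
import sseclient  # sseclient-py, added to requirements.dev.txt in this change

from steamship import File, Steamship, Task

instance = Steamship.use("my-agent-package", "my-instance")  # hypothetical handles
client = instance.client

# Kick off the async prompt; the response carries both the Task and the ChatHistory File.
resp = instance.invoke("async_prompt", prompt="hello", context_id="demo")
task = Task(client=client, **resp["task"])
file = File(client=client, **resp["file"])

# Same SSE endpoint and request-id filters used by events_for_file() in the tests.
url = (
    f"{client.config.api_base}file/{file.id}/stream"
    f"?tagKindFilter=request-id&tagNameFilter={task.request_id}&timeoutSeconds=30"
)
headers = {
    "Accept": "text/event-stream",
    "X-Workspace-Id": client.get_workspace().id,
    "Authorization": f"Bearer {client.config.api_key.get_secret_value()}",
}
with requests.get(url, stream=True, headers=headers, timeout=45) as sse_response:
    for event in sseclient.SSEClient(sse_response).events():
        # Each event looks like {"blockCreated": {"blockId": "...", "createdAt": "..."}}
        print(event.data)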