From 4ac57babf54f748553a5a5be421b0ad9d2a058e2 Mon Sep 17 00:00:00 2001
From: George Murray
Date: Tue, 24 Oct 2023 00:21:44 -0700
Subject: [PATCH] WIP

---
 requirements.txt |   4 +-
 src/api.py       | 121 ++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 101 insertions(+), 24 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 2cdc98e..94da9d1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
-steamship==2.17.22
+steamship @ git+https://github.com/steamship-core/python-client@george/agents-refactor
 openai==0.27.8
-tenacity==8.2.2
\ No newline at end of file
+tenacity==8.2.2

diff --git a/src/api.py b/src/api.py
index ba49e0a..9964d97 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,14 +1,24 @@
 import json
 import logging
-from typing import Any, Dict, List, Optional, Type
-
 import openai
 from pydantic import Field
-
-
-from steamship import Steamship, Block, Tag, SteamshipError
+from steamship import Steamship, Block, Tag, SteamshipError, MimeTypes
+from steamship.agents.schema import Tool
+from steamship.agents.schema.functions import (
+    OpenAIFunction,
+    FunctionProperty,
+    JSONType,
+    FunctionParameters,
+)
 from steamship.data.tags.tag_constants import TagKind, RoleTag, TagValueKey, ChatTag
 from steamship.invocable import Config, InvocableResponse, InvocationContext
+from steamship.plugin.capabilities import (
+    RequestedCapabilities,
+    SystemPromptSupport,
+    SupportLevel,
+    ConversationSupport,
+    FunctionCallingSupport,
+)
 from steamship.plugin.generator import Generator
 from steamship.plugin.inputs.raw_block_and_tag_plugin_input import (
     RawBlockAndTagPluginInput,
 )
@@ -30,6 +40,7 @@
     before_sleep_log,
     wait_exponential_jitter,
 )
+from typing import Any, Dict, List, Optional, Type, Mapping

 VALID_MODELS_FOR_BILLING = [
     "gpt-4",
@@ -44,6 +55,37 @@
 ]


+SUPPORT_MAP = {
+    SystemPromptSupport: SupportLevel.NATIVE,
+    ConversationSupport: SupportLevel.NATIVE,
+    FunctionCallingSupport: SupportLevel.NATIVE
+}
+
+
+def tool_as_openai_function(tool: Tool) -> OpenAIFunction:
+    text_property_schema = FunctionProperty(
+        type=JSONType.string,
+        description="text prompt for a function.",
+    )
+
+    uuid_property_schema = FunctionProperty(
+        type=JSONType.string,
+        description="UUID for a Steamship Block. Used to refer to a non-textual input generated by another "
+        "function. Example: c2f6818c-233d-4426-9dc5-f3c28fa33068",
+    )
+
+    params = FunctionParameters(
+        properties={"text": text_property_schema, "uuid": uuid_property_schema},
+    )
+
+    return OpenAIFunction(
+        name=tool.name,
+        # TODO Alter this? How do we not break folks?
+        description=tool.human_description,
+        parameters=params,
+    )
+
+
 class GPT4Plugin(Generator):
     """
     Plugin for generating text using OpenAI's GPT-4 model.
@@ -152,6 +194,27 @@ def prepare_message(self, block: Block) -> Optional[Dict[str, str]]:
             if tag.kind == "name":
                 name = tag.name

+        if block.mime_type == MimeTypes.STEAMSHIP_PLUGIN_FUNCTION_CALL_INVOCATION:
+            invocation = FunctionCallingSupport.FunctionCallInvocation.from_block(block)
+            return {
+                "role": "assistant",  # This does not use our enums, because those are for our purposes, and this is what OpenAI wants.
+ "content": None, + "function_call": { + "name": invocation.tool_name, + "arguments": { + arg_name: value for arg_name, value in invocation.args + } + } + } + + if block.mime_type == MimeTypes.STEAMSHIP_PLUGIN_FUNCTION_CALL_RESULT: + call_result = FunctionCallingSupport.FunctionCallResult.from_block(block) + return { + "role": "function", + "name": call_result.tool_name, + "content": call_result.result + } + if role is None: role = self.config.default_role @@ -170,8 +233,9 @@ def prepare_message(self, block: Block) -> Optional[Dict[str, str]]: return {"role": role, "content": block.text} - def prepare_messages(self, blocks: List[Block]) -> List[Dict[str, str]]: + def prepare_messages(self, blocks: List[Block], ) -> List[Dict[str, str]]: messages = [] + # TODO (SHIP-854) coalesce this with conversation history if self.config.default_system_prompt != "": messages.append( {"role": RoleTag.SYSTEM, "content": self.config.default_system_prompt} @@ -187,7 +251,7 @@ def prepare_messages(self, blocks: List[Block]) -> List[Dict[str, str]]: return messages def generate_with_retry( - self, user: str, messages: List[Dict[str, str]], options: Dict + self, user: str, messages: List[Dict[str, str]], requested_capabilities: Optional[RequestedCapabilities], options: Dict ) -> (List[Block], List[UsageReport]): """Call the API to generate the next section of text.""" logging.info( @@ -195,7 +259,13 @@ def generate_with_retry( ) options = options or {} stopwords = options.get("stop", None) - functions = options.get("functions", None) + if requested_capabilities: + function_calling_support = requested_capabilities.get(FunctionCallingSupport) + if function_calling_support: + functions = [tool_as_openai_function(tool) for tool in function_calling_support.functions] + else: + # Legacy behavior + functions = options.get("functions", None) @retry( reraise=True, @@ -238,16 +308,25 @@ def _generate_with_retry() -> Any: ) # Fetch text from responses - generations = [] + generation_blocks = [] for choice in openai_result["choices"]: message = choice["message"] role = message["role"] + mime_type = None if function_call := message.get("function_call"): - content = json.dumps({"function_call": function_call}) + if requested_capabilities: + content = FunctionCallingSupport.FunctionCallInvocation( + tool_name=function_call["name"], + arguments=function_call["arguments"] + ).json() + mime_type = FunctionCallingSupport.FunctionCallInvocation.MIME_TYPE + else: + # Legacy behavior + content = json.dumps({"function_call": function_call}) else: content = message.get("content", "") - generations.append((content, role)) + generation_blocks.append(Block(text=content, role=role, mime_type=mime_type)) # for token usage tracking, we need to include not just the token usage, but also completion id # that will allow proper usage aggregration for n > 1 cases @@ -269,15 +348,7 @@ def _generate_with_retry() -> Any: ), ] - return [ - Block( - text=text, - tags=[ - Tag(kind=TagKind.ROLE, name=RoleTag(role)), - ], - ) - for text, role in generations - ], usage_reports + return generation_blocks, usage_reports @staticmethod def _flagged(messages: List[Dict[str, str]]) -> bool: @@ -293,7 +364,10 @@ def run( """Run the text generator against all the text, combined""" self.config.extend_with_dict(request.data.options, overwrite=True) - + requested_capabilities = RequestedCapabilities(SUPPORT_MAP) + capability_response = requested_capabilities.extract_from_blocks(request.data.blocks) + if not capability_response: + requested_capabilities = None 
         messages = self.prepare_messages(request.data.blocks)
         if self.config.moderate_output and self._flagged(messages):
             raise SteamshipError(
             )
         user_id = self.context.user_id if self.context is not None else "testing"
         generated_blocks, usage_reports = self.generate_with_retry(
-            messages=messages, user=user_id, options=request.data.options
+            messages=messages, user=user_id, options=request.data.options, requested_capabilities=requested_capabilities
         )
+        if capability_response:
+            # TODO append to file?
+            generated_blocks.append(capability_response.to_block())

         return InvocableResponse(
             data=RawBlockAndTagPluginOutput(
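
Note for reviewers (outside the patch): the sketch below shows the OpenAI-bound message shapes that the two new prepare_message branches produce. The invocation dict here is a hypothetical stand-in for FunctionCallingSupport.FunctionCallInvocation; its tool_name/args field names are assumptions about the george/agents-refactor schema, not confirmed API.

    import json

    # Hypothetical stand-in for FunctionCallingSupport.FunctionCallInvocation;
    # tool_name/args are assumptions about the george/agents-refactor schema.
    invocation = {"tool_name": "image_tool", "args": {"text": "a red barn"}}

    # Assistant message for a STEAMSHIP_PLUGIN_FUNCTION_CALL_INVOCATION block.
    # OpenAI's chat API expects function_call.arguments as a JSON-encoded string.
    assistant_message = {
        "role": "assistant",
        "content": None,
        "function_call": {
            "name": invocation["tool_name"],
            "arguments": json.dumps(invocation["args"]),
        },
    }

    # Function-result message for a STEAMSHIP_PLUGIN_FUNCTION_CALL_RESULT block.
    function_message = {
        "role": "function",
        "name": invocation["tool_name"],
        "content": "c2f6818c-233d-4426-9dc5-f3c28fa33068",  # e.g. a generated Block UUID
    }

    print(json.dumps([assistant_message, function_message], indent=2))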