diff --git a/CHANGELOG.md b/CHANGELOG.md index ccba4e6..dacd133 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [0.6.0] +### Added +- Commands for `assistants list` and `assistants clean` +- Automatic cleanup of any extra assistants in user account when initiating chat +- Added docstrings for undocumented functions + ### Changed - Updated CLI chat command to `vs chat` - Refactored CLI into separate modules +- Removed assistant ID persistence from settings file and only attempt to retrieve assistant from API ## [0.5.1] ### Fixed diff --git a/src/vecsync/chat/clients/base.py b/src/vecsync/chat/clients/base.py new file mode 100644 index 0000000..82aa12d --- /dev/null +++ b/src/vecsync/chat/clients/base.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class Assistant(BaseModel): + id: str + name: str diff --git a/src/vecsync/chat/clients/openai.py b/src/vecsync/chat/clients/openai.py index 3173aa1..fc14faf 100644 --- a/src/vecsync/chat/clients/openai.py +++ b/src/vecsync/chat/clients/openai.py @@ -1,7 +1,9 @@ from queue import Empty, Queue from openai import AssistantEventHandler, OpenAI +from termcolor import cprint +from vecsync.chat.clients.base import Assistant from vecsync.chat.formatter import ConsoleFormatter, GradioFormatter from vecsync.settings import SettingExists, SettingMissing, Settings from vecsync.store.openai import OpenAiVectorStore @@ -10,6 +12,21 @@ # TODO: This class will likely be refactored into common class across other client types. However # since we only have OpenAI at the moment, we'll keep it here for now. class OpenAIHandler(AssistantEventHandler): + """Handler for OpenAI API events. + + This class is used to handle streaming events from the OpenAI API. 
Internally, it puts all streaming + chunks into a Queue which allows for the streaming to be consumed in real time by other functions. + + Parameters + ---------- + files : dict[str, str] + A dictionary of file IDs and their corresponding names. This is used to format the citations + in the response. + formatter : ConsoleFormatter | GradioFormatter + The formatter to use for formatting the output of the response. This can be either a + ConsoleFormatter or GradioFormatter. + """ + def __init__(self, files: dict[str, str], formatter: ConsoleFormatter | GradioFormatter): super().__init__() self.files = files @@ -19,6 +36,7 @@ def __init__(self, files: dict[str, str], formatter: ConsoleFormatter | GradioFo self.formatter = formatter def on_message_delta(self, delta, snapshot): + # Handle the response chunk delta_annotations = {} text_chunks = [] @@ -43,12 +61,27 @@ def on_message_delta(self, delta, snapshot): self.queue.put("".join(text_chunks)) def on_message_done(self, message): + # Append citations at the end of the response text = self.formatter.get_references(self.annotations, self.files) self.queue.put(text) self.active = False def consume_queue(self, timeout: float = 1.0): - """Pulls from handler.queue in real time, calls write_fn(text).""" + """Consume chunks from the queue. + + Parameters + ---------- + timeout : float + The timeout seconds for the queue. This is used to prevent blocking if there are no chunks + available in the queue. + + Yields + ------ + str + The chunks of text from the queue. This will yield until the queue is empty or the + active flag is set to False. + """ + while self.active or not self.queue.empty(): try: chunk = self.queue.get(timeout=timeout) @@ -60,39 +93,126 @@ def consume_queue(self, timeout: float = 1.0): class OpenAIClient: - def __init__(self, store_name: str, new_conversation: bool = False): + """OpenAI client for interacting with the OpenAI API. 
+ + This client is used to send messages to the OpenAI API and receive responses. + + Parameters + ---------- + store_name : str + The name of the vector store to use for this client. The named assistant will + be created in the form of "vecsync-{store_name}". + + """ + + def __init__(self, store_name: str): self.client = OpenAI() - self.vector_store = OpenAiVectorStore(store_name) + self.store_name = store_name + self.assistant_name = f"vecsync-{store_name}" + self.connected = False + + def connect(self): + """Connect to the OpenAI API and load the assistant and thread. + + There are four independent entities in the OpenAI API: + 1. Files: User files are uploaded to OpenAI and exist as an artifact which can be used in + several places. The file references are loaded here for translating citation references. + 2. Vector Store: The vector store is a collection of files which are used for RAG search + by the assistant. A vector store can be assigned to multiple assistants. + 3. Assistant: The assistant is the entity which is used to interact with the OpenAI API. + We currently only support one assistant per OpenAI account. + 4. Thread: The thread is the conversation history which is used to store messages between the + user and assistant. Threads are created whenever a new assistant is created, the user + deletes their settings file, or the user runs the application on a different machine. 
+ """ + # Connect to the OpenAI vector store + self.vector_store = OpenAiVectorStore(self.store_name) self.vector_store.get() - self.assistant_name = f"vecsync-{self.vector_store.store.name}" - self.assistant_id = self._get_or_create_assistant() - - self.thread_id = None if new_conversation else self._get_thread_id() + # Load the assistant and thread + self.assistant_id = self._get_assistant_id() + self.thread_id = self._get_thread_id() + # Load the files in the vector store self.files = {f.id: f.name for f in self.vector_store.get_files()} - - def _get_thread_id(self) -> str | None: + self.connected = True + + def disconnect(self): + """Clear all OpenAI client state.""" + self.assistant_id = None + self.thread_id = None + self.files = None + self.vector_store = None + self.connected = False + + def _get_thread_id(self) -> str: + """Locates or creates the thread ID + + Thread IDs are stored locally in the user settings file. The ID is loaded from settings and + is created if it doesn't exist. + + Returns + ------- + str + The thread ID for the current conversation. + """ settings = Settings() + # TODO: Ideally we would grab the thread ID from OpenAI but there doesn't seem to be + # a way to do that. So we are storing it in the settings file for now. match settings["openai_thread_id"]: case SettingMissing(): - return None + return self._create_thread() case SettingExists() as x: print(f"✅ Thread found: {x.value}") return x.value - def _get_or_create_assistant(self): - settings = Settings() + def _get_assistant_id(self) -> str: + """Locates or creates the assistant ID - match settings["openai_assistant_id"]: - case SettingExists() as x: - print(f"✅ Assistant found: {x.value}") - return x.value - case _: - return self._create_assistant() + Assistant IDs are stored in the OpenAI account. The ID is loaded from the account and is created + if it doesn't exist. There should only be one assistant per account at any point in time. 
This step + performs a cleanup check if multiple assistants are found. + + Returns + ------- + str + The assistant ID for the current conversation. + """ + # Check if the assistant already exists + existing_assistants = self.list_assistants() + count_assistants = len(existing_assistants) + + if count_assistants > 1: + # We only allow for one assistant per account at this time + # This state shouldn't happen, but if it does, we need to remove the extras + # to keep things clean + + count_extra = count_assistants - 1 + cprint(f"⚠️ Multiple vecsync assistants found in account. Cleaning up {count_extra} extras.", "yellow") + + for assistant in existing_assistants[1:]: + self.delete_assistant(assistant.id) + + if count_assistants > 0: + id = existing_assistants[0].id + print(f"✅ Assistant found remotely: {id}") + return id + else: + return self._create_assistant() def _create_assistant(self) -> str: + """Creates a new assistant in the OpenAI account. + + The assistant is created with the name "vecsync-{store_name}" and is attached to the + user's vector store. + + Returns + ------- + str + The assistant ID for the current conversation. + """ + instructions = """You are a helpful research assistant that can search through a large number of journals and papers to help answer the user questions. You have been given a file store which contains the relevant documents the user is referencing. These documents should be your primary source of information. @@ -114,15 +234,45 @@ def _create_assistant(self) -> str: settings = Settings() del settings["openai_thread_id"] - del settings["openai_assistant_id"] print(f"🖥️ Assistant created: {assistant.name}") print(f"🔗 Assistant URL: https://platform.openai.com/assistants/{assistant.id}") - settings["openai_assistant_id"] = assistant.id return assistant.id + def _create_thread(self) -> str: + """Creates a new thread in the OpenAI account. 
+ + The new thread ID is stored in the local settings file since OpenAI doesn't provide a way to + remotely access the thread ID. + + Returns + ------- + str + The thread ID for the current conversation. + """ + + thread = self.client.beta.threads.create() + print(f"💬 Conversation started: {thread.id}") + settings = Settings() + settings["openai_thread_id"] = thread.id + return thread.id + def load_history(self) -> list[dict[str, str]]: - """Fetch all prior messages in this thread""" + """Fetch all prior messages in this thread + + This method loads the conversation history from the OpenAI API. The messages are sorted + by their creation time. + + Returns + ------- + list[dict[str, str]] + A list of dictionaries containing the role and content of each message. + The role is either "user" or "assistant". + """ + + if not self.connected: + self.connect() + history = [] if self.thread_id is not None: resp = self.client.beta.threads.messages.list(thread_id=self.thread_id) @@ -138,31 +288,63 @@ def load_history(self) -> list[dict[str, str]]: return history - def initialize_chat(self): - if self.thread_id is None: - thread = self.client.beta.threads.create() - self.thread_id = thread.id - print(f"💬 Conversation started: {self.thread_id}") - settings = Settings() - settings["openai_thread_id"] = self.thread_id + def send_message(self, prompt: str): + """Send a message to the OpenAI thread. - def _run_stream(self, handler: OpenAIHandler): - with self.client.beta.threads.runs.stream( - thread_id=self.thread_id, - assistant_id=self.assistant_id, - event_handler=handler, - ) as stream: - stream.until_done() + Parameters + ---------- + prompt : str + The message to send to the OpenAI thread. 
+ """ - def send_message(self, prompt: str): - self.initialize_chat() + if not self.connected: + self.connect() return self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=prompt) def stream_response(self, thread_id: str, assistant_id: str, handler): + """Generate a thread run and stream the response. + + Parameters + ---------- + thread_id : str + The ID of the thread to stream the response from. + assistant_id : str + The ID of the assistant to stream the response from. + handler : AssistantEventHandler + The event handler to use for processing the response. + """ + with self.client.beta.threads.runs.stream( thread_id=thread_id, assistant_id=assistant_id, event_handler=handler, ) as stream: stream.until_done() + + def list_assistants(self) -> list[Assistant]: + """List all vecsync assistants in the OpenAI account. + + This only returns vecsync assistants which are prefixed with "vecsync-". There should + only be one assistant per account, but this method is here to help with cleanup if + multiple assistants are created. + + Returns: + list[Assistant]: A list of Assistant objects. + """ + results = [] + + for assistant in self.client.beta.assistants.list(): + if assistant.name.startswith("vecsync-"): + results.append(Assistant(id=assistant.id, name=assistant.name)) + + return results + + def delete_assistant(self, assistant_id: str): + """Delete an assistant from the OpenAI account. + + Args: + assistant_id (str): The ID of the assistant to delete. + """ + self.client.beta.assistants.delete(assistant_id) + self.disconnect() diff --git a/src/vecsync/chat/interface.py b/src/vecsync/chat/interface.py index 5cd44ad..7e96ba3 100644 --- a/src/vecsync/chat/interface.py +++ b/src/vecsync/chat/interface.py @@ -8,6 +8,17 @@ class ConsoleInterface: + """Interact with the assistant via the console. + + This class allows for messages to be sent and received in a console environment. Multithreading + is used to allow for streaming responses. 
+ + Parameters + ---------- + client : OpenAIClient + The OpenAI client used to send and receive messages. + """ + def __init__(self, client: OpenAIClient): self.client = client self.executor = ThreadPoolExecutor(max_workers=1) @@ -25,6 +36,17 @@ def prompt(self, prompt_text: str): class GradioInterface: + """Interact with the assistant via the Gradio UI. + + This class allows for messages to be sent and received in a Gradio environment. Multithreading + is used to allow for streaming responses. The Gradio UI is launched locally. + + Parameters + ---------- + client : OpenAIClient + The OpenAI client used to send and receive messages. + """ + def __init__(self, client: OpenAIClient): self.client = client self.executor = ThreadPoolExecutor(max_workers=1) diff --git a/src/vecsync/cli.py b/src/vecsync/cli.py index a0c39c4..d0b4b47 100644 --- a/src/vecsync/cli.py +++ b/src/vecsync/cli.py @@ -85,15 +85,9 @@ def sync(source: str): @click.command("chat") -@click.option( - "--new-conversation", - "-n", - is_flag=True, - help="Force the assistant to create a new thread.", -) -def chat_assistant(new_conversation: bool): +def chat_assistant(): """Chat with the assistant.""" - client = OpenAIClient("test", new_conversation=new_conversation) + client = OpenAIClient("test") ui = ConsoleInterface(client) print('Type "exit" to quit at any time.') diff --git a/src/vecsync/cli/assistants.py b/src/vecsync/cli/assistants.py new file mode 100644 index 0000000..f37a392 --- /dev/null +++ b/src/vecsync/cli/assistants.py @@ -0,0 +1,59 @@ +import click +from termcolor import cprint + +from vecsync.chat.clients.openai import OpenAIClient +from vecsync.constants import DEFAULT_STORE_NAME + + +@click.command(name="list") +def list_assistants(): + """List all vecsync assistants in the OpenAI account.""" + client = OpenAIClient(store_name=DEFAULT_STORE_NAME) + assistants = client.list_assistants() + + if len(assistants) == 0: + cprint("No assistants found.", "green", attrs=["bold"]) + else: + 
cprint("Assistants in your OpenAI account:", "green", attrs=["bold"]) + for i, assistant in enumerate(assistants): + cprint(f" {i + 1}. Name: {assistant.name} ({assistant.id})", "yellow") + + +@click.command() +def clean(): + """Clean up vecsync assistants in the OpenAI account.""" + client = OpenAIClient(store_name=DEFAULT_STORE_NAME) + assistants = client.list_assistants() + + if len(assistants) == 0: + cprint("No deletable assistants found.", "green", attrs=["bold"]) + return + + cprint("Assistants in your OpenAI account:", "green", attrs=["bold"]) + for i, assistant in enumerate(assistants): + cprint(f" {i + 1}. Name: {assistant.name} ({assistant.id})", "yellow") + cprint("Would you like to delete the following assistants? [y/N] ", "red", end="") + + confirm = input().strip().lower() + + while confirm not in ["y", "n", ""]: + cprint("Please enter 'y' or 'n': ", "red", end="") + confirm = input().strip().lower() + + if confirm in ["", "n"]: + cprint("Aborting...", "green", attrs=["bold"]) + return + + for assistant in assistants: + cprint(f"Deleting assistant {assistant.name} ({assistant.id})...", "red", attrs=["bold"]) + client.delete_assistant(assistant.id) + + +@click.group(name="assistants") +def group(): + """Commands to manage assistants""" + pass + + +group.add_command(list_assistants) +group.add_command(clean) diff --git a/src/vecsync/cli/chat.py b/src/vecsync/cli/chat.py index bf9bd05..8da9c72 100644 --- a/src/vecsync/cli/chat.py +++ b/src/vecsync/cli/chat.py @@ -5,8 +5,10 @@ from vecsync.constants import DEFAULT_STORE_NAME -def start_console_chat(store_name: str, new_conversation: bool): - client = OpenAIClient(store_name=store_name, new_conversation=new_conversation) +def start_console_chat(store_name: str): + client = OpenAIClient(store_name=store_name) + client.connect() + ui = ConsoleInterface(client) print('Type "exit" to quit at any time.') @@ -18,29 +20,25 @@ def start_console_chat(store_name: str, new_conversation: bool): ui.prompt(prompt) -def 
start_ui_chat(store_name: str, new_conversation: bool): - client = OpenAIClient(store_name=store_name, new_conversation=new_conversation) +def start_ui_chat(store_name: str): + client = OpenAIClient(store_name=store_name) + client.connect() + ui = GradioInterface(client) ui.chat_interface() @click.command("chat") -@click.option( - "--new-conversation", - "-n", - is_flag=True, - help="Force the assistant to create a new thread.", -) @click.option( "--use-ui", "-u", is_flag=True, help="Spawn an interactive UI instead of a console interface.", ) -def chat(new_conversation: bool, use_ui: bool): +def chat(use_ui: bool): """Chat with the assistant.""" if use_ui: - start_ui_chat(DEFAULT_STORE_NAME, new_conversation) + start_ui_chat(DEFAULT_STORE_NAME) else: - start_console_chat(DEFAULT_STORE_NAME, new_conversation) + start_console_chat(DEFAULT_STORE_NAME) diff --git a/src/vecsync/cli/entry.py b/src/vecsync/cli/entry.py index 8d7a84f..172e2c5 100644 --- a/src/vecsync/cli/entry.py +++ b/src/vecsync/cli/entry.py @@ -1,5 +1,6 @@ import click +from vecsync.cli.assistants import group as assistants_group from vecsync.cli.chat import chat from vecsync.cli.settings import group as settings_group from vecsync.cli.store import group as store_group @@ -12,7 +13,7 @@ def cli(): pass -for group in [store_group, settings_group]: +for group in [assistants_group, store_group, settings_group]: cli.add_command(group) cli.add_command(sync) diff --git a/src/vecsync/cli/settings.py b/src/vecsync/cli/settings.py index a0885f2..0f66a98 100644 --- a/src/vecsync/cli/settings.py +++ b/src/vecsync/cli/settings.py @@ -12,7 +12,7 @@ def clear(): @click.command() -def info(): +def show(): """Get the location and data of the settings file.""" settings = Settings() data = settings.info() @@ -27,4 +27,4 @@ def group(): group.add_command(clear) -group.add_command(info) +group.add_command(show) diff --git a/src/vecsync/cli/store.py b/src/vecsync/cli/store.py index 76c7001..0958987 100644 --- 
a/src/vecsync/cli/store.py +++ b/src/vecsync/cli/store.py @@ -5,8 +5,8 @@ from vecsync.store.openai import OpenAiVectorStore -@click.command() -def list(): +@click.command(name="list") +def list_stores(): """List files in the remote vector store.""" store = OpenAiVectorStore(DEFAULT_STORE_NAME) files = store.get_files() @@ -31,5 +31,5 @@ def group(): pass -group.add_command(list) +group.add_command(list_stores) group.add_command(delete)