Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)

## [0.6.0]
### Added
- Commands for `assistants list` and `assistants clean`
- Automatic cleanup of any extra assistants in user account when initiating chat
- Added docstrings for undocumented functions

### Changed
- Updated CLI chat command to `vs chat`
- Refactored CLI into separate modules
- Removed assistant ID persistence from the settings file; the assistant is now only retrieved from the API

## [0.5.1]
### Fixed
Expand Down
6 changes: 6 additions & 0 deletions src/vecsync/chat/clients/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class Assistant(BaseModel):
    """Lightweight record identifying an OpenAI assistant.

    Shared client-side representation used when listing or deleting
    assistants, independent of the full OpenAI SDK response object.
    """

    # Unique assistant identifier assigned by the OpenAI API.
    id: str
    # Human-readable assistant name. Vecsync-managed assistants are
    # expected to carry a "vecsync-" prefix — confirm against the client code.
    name: str
256 changes: 219 additions & 37 deletions src/vecsync/chat/clients/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from queue import Empty, Queue

from openai import AssistantEventHandler, OpenAI
from termcolor import cprint

from vecsync.chat.clients.base import Assistant
from vecsync.chat.formatter import ConsoleFormatter, GradioFormatter
from vecsync.settings import SettingExists, SettingMissing, Settings
from vecsync.store.openai import OpenAiVectorStore
Expand All @@ -10,6 +12,21 @@
# TODO: This class will likely be refactored into common class across other client types. However
# since we only have OpenAI at the moment, we'll keep it here for now.
class OpenAIHandler(AssistantEventHandler):
"""Handler for OpenAI API events.

This class is used to handle streaming events from the OpenAI API. Internally, it puts all streaming
chunks into a Queue which allows for the streaming to be consumed in real time by other functions.

Parameters
----------
files : dict[str, str]
A dictionary of file IDs and their corresponding names. This is used to format the citations
in the response.
formatter : ConsoleFormatter | GradioFormatter
The formatter to use for formatting the output of the response. This can be either a
ConsoleFormatter or GradioFormatter.
"""

def __init__(self, files: dict[str, str], formatter: ConsoleFormatter | GradioFormatter):
super().__init__()
self.files = files
Expand All @@ -19,6 +36,7 @@ def __init__(self, files: dict[str, str], formatter: ConsoleFormatter | GradioFo
self.formatter = formatter

def on_message_delta(self, delta, snapshot):
# Handle the response chunk
delta_annotations = {}
text_chunks = []

Expand All @@ -43,12 +61,27 @@ def on_message_delta(self, delta, snapshot):
self.queue.put("".join(text_chunks))

def on_message_done(self, message):
    """Event-handler callback invoked when the assistant message is complete.

    Formats the citation annotations collected during streaming into a
    references section, enqueues it so consumers receive it after the body
    chunks, and clears the active flag so queue consumers can drain and stop.

    Parameters
    ----------
    message
        The completed message object from the OpenAI API (unused here;
        all needed state was accumulated from the deltas).
    """
    # Append citations at the end of the response
    text = self.formatter.get_references(self.annotations, self.files)
    self.queue.put(text)
    # Signals consumers that no further chunks will be produced.
    self.active = False

def consume_queue(self, timeout: float = 1.0):
"""Pulls from handler.queue in real time, calls write_fn(text)."""
"""Consume chunks from the queue.

Parameters
----------
timeout : float
The timeout seconds for the queue. This is used to prevent blocking if there are no chunks
available in the queue.

Yields
------
str
The chunks of text from the queue. This will yield until the queue is empty or the
active flag is set to False.
"""

while self.active or not self.queue.empty():
try:
chunk = self.queue.get(timeout=timeout)
Expand All @@ -60,39 +93,126 @@ def consume_queue(self, timeout: float = 1.0):


class OpenAIClient:
def __init__(self, store_name: str, new_conversation: bool = False):
"""OpenAI client for interacting with the OpenAI API.

This client is used to send messages to the OpenAI API and receive responses.

Parameters
----------
store_name : str
The name of the vector store to use for this client. The named assistant will
be created in the form of "vecsync-{store_name}".

"""

def __init__(self, store_name: str):
self.client = OpenAI()
self.vector_store = OpenAiVectorStore(store_name)
self.store_name = store_name
self.assistant_name = f"vecsync-{store_name}"
self.connected = False

def connect(self):
"""Connect to the OpenAI API and load the assistant and thread.

There are four independent entities in the OpenAI API:
1. Files: User files are uploaded to OpenAI and exist as an artifact which can be used in
several places. The file references are loaded here for translating citation references.
2. Vector Store: The vector store is a collection of files which are used for RAG search
by the assistant. A vector store can be assigned to multiple assistants.
3. Assistant: The assistant is the entity which is used to interact with the OpenAI API.
We currently only support one assistant per OpenAI account.
4. Thread: The thread is the conversation history which is used to store messages between the
user and assistant. Threads are created whenever a new assistant is created, the user
deletes their settings file, or the user runs the application on a different machine.
"""
# Connect to the OpenAI vector store
self.vector_store = OpenAiVectorStore(self.store_name)
self.vector_store.get()

self.assistant_name = f"vecsync-{self.vector_store.store.name}"
self.assistant_id = self._get_or_create_assistant()

self.thread_id = None if new_conversation else self._get_thread_id()
# Load the assistant and thread
self.assistant_id = self._get_assistant_id()
self.thread_id = self._get_thread_id()

# Load the files in the vector store
self.files = {f.id: f.name for f in self.vector_store.get_files()}

def _get_thread_id(self) -> str | None:
self.connected = True

def disconnect(self):
    """Clear all OpenAI client state.

    Resets the cached assistant ID, thread ID, file map, and vector store
    reference and marks the client as disconnected. A subsequent call to
    connect() rebuilds all of this state.
    """
    self.assistant_id = None
    self.thread_id = None
    self.files = None
    self.vector_store = None
    self.connected = False

def _get_thread_id(self) -> str:
"""Locates or creates the thread ID

Thread IDs are stored locally in the user settings file. The ID is loaded from settings and
is created if it doesn't exist.

Returns
-------
str
The thread ID for the current conversation.
"""
settings = Settings()

# TODO: Ideally we would grab the thread ID from OpenAI but there doesn't seem to be
# a way to do that. So we are storing it in the settings file for now.
match settings["openai_thread_id"]:
case SettingMissing():
return None
return self._create_thread()
case SettingExists() as x:
print(f"✅ Thread found: {x.value}")
return x.value

def _get_or_create_assistant(self):
settings = Settings()
def _get_assistant_id(self) -> str:
"""Locates or creates the assistant ID

match settings["openai_assistant_id"]:
case SettingExists() as x:
print(f"✅ Assistant found: {x.value}")
return x.value
case _:
return self._create_assistant()
Assistant IDs are stored in the OpenAI account. The ID is loaded from the account and is created
if it doesn't exist. There should only be one assistant per account at any point in time. This step
performs a cleanup check if multiple assistants are found.

Returns
-------
str
The assistant ID for the current conversation.
"""
# Check if the assistant already exists
existing_assistants = self.list_assistants()
count_assistants = len(existing_assistants)

if count_assistants > 1:
# We only allow for one assistant per account at this time
# This state shouldn't happen, but if it does, we need to remove the extras
# to keep things clean

count_extra = count_assistants - 1
cprint(f"⚠️ Multiple vecsync assistants found in account. Cleaning up {count_extra} extras.", "yellow")

for assistant in existing_assistants[1:]:
self.delete_assistant(assistant.id)

if count_assistants > 0:
id = existing_assistants[0].id
print(f"✅ Assistant found remotely: {id}")
return id
else:
return self._create_assistant()

def _create_assistant(self) -> str:
"""Creates a new assistant in the OpenAI account.

The assistant is created with the name "vecsync-{store_name}" and is attached to the
user's vector store.

Returns
-------
str
The assistant ID for the current conversation.
"""

instructions = """You are a helpful research assistant that can search through a large number
of journals and papers to help answer the user questions. You have been given a file store which contains
the relevant documents the user is referencing. These documents should be your primary source of information.
Expand All @@ -114,15 +234,45 @@ def _create_assistant(self) -> str:

settings = Settings()
del settings["openai_thread_id"]
del settings["openai_assistant_id"]

print(f"🖥️ Assistant created: {assistant.name}")
print(f"🔗 Assistant URL: https://platform.openai.com/assistants/{assistant.id}")
settings["openai_assistant_id"] = assistant.id
return assistant.id

def _create_thread(self) -> str:
    """Creates a new thread in the OpenAI account.

    The new thread ID is stored in the local settings file since OpenAI doesn't provide a way to
    remotely access the thread ID.

    Returns
    -------
    str
        The thread ID for the current conversation.
    """

    thread = self.client.beta.threads.create()
    print(f"💬 Conversation started: {thread.id}")
    # Persist locally: the thread cannot be re-discovered through the API later,
    # so the settings file is the only durable record of this conversation.
    settings = Settings()
    settings["openai_thread_id"] = thread.id
    return thread.id

def load_history(self) -> list[dict[str, str]]:
"""Fetch all prior messages in this thread"""
"""Fetch all prior messages in this thread

This method loads the conversation history from the OpenAI API. The messages are sorted
by their creation time.

Returns
-------
list[dict[str, str]]
A list of dictionaries containing the role and content of each message.
The role is either "user" or "assistant".
"""

if not self.connected:
self.connect()

history = []
if self.thread_id is not None:
resp = self.client.beta.threads.messages.list(thread_id=self.thread_id)
Expand All @@ -138,31 +288,63 @@ def load_history(self) -> list[dict[str, str]]:

return history

def initialize_chat(self):
if self.thread_id is None:
thread = self.client.beta.threads.create()
self.thread_id = thread.id
print(f"💬 Conversation started: {self.thread_id}")
settings = Settings()
settings["openai_thread_id"] = self.thread_id
def send_message(self, prompt: str):
"""Send a message to the OpenAI thread.

def _run_stream(self, handler: OpenAIHandler):
with self.client.beta.threads.runs.stream(
thread_id=self.thread_id,
assistant_id=self.assistant_id,
event_handler=handler,
) as stream:
stream.until_done()
Parameters
----------
prompt : str
The message to send to the OpenAI thread.
"""

def send_message(self, prompt: str):
self.initialize_chat()
if not self.connected:
self.connect()

return self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=prompt)

def stream_response(self, thread_id: str, assistant_id: str, handler):
    """Generate a thread run and stream the response.

    Blocks the calling thread until the run completes; streaming events are
    delivered to ``handler`` as they arrive (e.g. an OpenAIHandler, which
    buffers chunks in a queue for real-time consumption elsewhere).

    Parameters
    ----------
    thread_id : str
        The ID of the thread to stream the response from.
    assistant_id : str
        The ID of the assistant to stream the response from.
    handler : AssistantEventHandler
        The event handler to use for processing the response.
    """

    # The context manager manages the run lifecycle; until_done() blocks
    # until the run reaches a terminal state.
    with self.client.beta.threads.runs.stream(
        thread_id=thread_id,
        assistant_id=assistant_id,
        event_handler=handler,
    ) as stream:
        stream.until_done()

def list_assistants(self) -> list[Assistant]:
    """List all vecsync assistants in the OpenAI account.

    This only returns vecsync assistants which are prefixed with "vecsync-". There should
    only be one assistant per account, but this method is here to help with cleanup if
    multiple assistants are created.

    Returns
    -------
    list[Assistant]
        A list of Assistant objects, one per vecsync-prefixed assistant.
    """
    results = []

    for assistant in self.client.beta.assistants.list():
        # Bug fix: the OpenAI API allows an assistant's name to be null, so
        # guard before calling startswith() to avoid an AttributeError when
        # the account contains unnamed (non-vecsync) assistants.
        if assistant.name and assistant.name.startswith("vecsync-"):
            results.append(Assistant(id=assistant.id, name=assistant.name))

    return results

def delete_assistant(self, assistant_id: str):
    """Delete an assistant from the OpenAI account.

    Parameters
    ----------
    assistant_id : str
        The ID of the assistant to delete.
    """
    self.client.beta.assistants.delete(assistant_id)
    # NOTE(review): disconnect() runs unconditionally, even when deleting an
    # assistant other than the connected one — presumably to invalidate any
    # cached IDs that may reference the deleted assistant. Confirm intent.
    self.disconnect()
Loading