diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f88f3d9..f68f7ff 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,6 +25,8 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] python-version: ["3.10", "3.11", "3.12", "3.13"] runs-on: ${{ matrix.os }} + env: + OPENAI_API_KEY: dummy_value_for_tests # Dummy value steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/.gitignore b/.gitignore index 583cd7d..7a3c618 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ wheels/ # Development files .coverage +.coverage.* .env *.pdf diff --git a/CHANGELOG.md b/CHANGELOG.md index dacd133..1b42743 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,9 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [0.6.0] ### Added -- Commands for `assistants list` and `assistants clean` +- Added CLI commands for `assistants list` and `assistants clean` - Automatic cleanup of any extra assistants in user account when initiating chat - Added docstrings for undocumented functions +- Test case coverage for most OpenAI chat and vector store operations [#15](https://github.com/jbencina/vecsync/issues/15) ### Changed - Updated CLI chat command to `vs chat` diff --git a/src/vecsync/chat/clients/openai.py b/src/vecsync/chat/clients/openai.py index fc14faf..e736afa 100644 --- a/src/vecsync/chat/clients/openai.py +++ b/src/vecsync/chat/clients/openai.py @@ -63,7 +63,8 @@ def on_message_delta(self, delta, snapshot): def on_message_done(self, message): # Append citations at the end of the response text = self.formatter.get_references(self.annotations, self.files) - self.queue.put(text) + if len(text) > 0: + self.queue.put(text) self.active = False def consume_queue(self, timeout: float = 1.0): @@ -102,14 +103,17 @@ class OpenAIClient: store_name : str The name of the vector store to use for this client. The named assistant will be created in the form of "vecsync-{store_name}". - + settings_path : str | None + The path to the settings file. If None, the default settings file will be used. + This is used to store the thread ID for the current conversation. """ - def __init__(self, store_name: str): + def __init__(self, store_name: str, settings_path: str | None = None): self.client = OpenAI() self.store_name = store_name self.assistant_name = f"vecsync-{store_name}" self.connected = False + self.settings_path = settings_path def connect(self): """Connect to the OpenAI API and load the assistant and thread. @@ -156,7 +160,7 @@ def _get_thread_id(self) -> str: str The thread ID for the current conversation. """ - settings = Settings() + settings = Settings(path=self.settings_path) # TODO: Ideally we would grab the thread ID from OpenAI but there doesn't seem to be # a way to do that. So we are storing it in the settings file for now. @@ -232,7 +236,7 @@ def _create_assistant(self) -> str: model="gpt-4o-mini", ) - settings = Settings() + settings = Settings(path=self.settings_path) del settings["openai_thread_id"] print(f"🖥️ Assistant created: {assistant.name}") @@ -253,7 +257,7 @@ def _create_thread(self) -> str: thread = self.client.beta.threads.create() print(f"💬 Conversation started: {thread.id}") - settings = Settings() + settings = Settings(path=self.settings_path) settings["openai_thread_id"] = thread.id return thread.id diff --git a/src/vecsync/cli.py b/src/vecsync/cli.py index d0b4b47..e4573c2 100644 --- a/src/vecsync/cli.py +++ b/src/vecsync/cli.py @@ -77,7 +77,7 @@ def sync(source: str): f"Saved: {result.files_saved} | Deleted: {result.files_deleted} | Skipped: {result.files_skipped} ", "yellow", ) - cprint(f"Remote count: {result.updated_count}", "yellow") + cprint(f"Remote count: {result.remote_count}", "yellow") cprint(f"Duration: {result.duration:.2f} seconds", "yellow") diff --git a/src/vecsync/cli/sync.py b/src/vecsync/cli/sync.py index 8dfb7a2..dea583d 100644 --- a/src/vecsync/cli/sync.py +++ b/src/vecsync/cli/sync.py @@ -40,5 +40,5 @@ def sync(source: str): f"Saved: {result.files_saved} | Deleted: {result.files_deleted} | Skipped: {result.files_skipped} ", "yellow", ) - cprint(f"Remote count: {result.updated_count}", "yellow") + cprint(f"Remote count: {result.remote_count}", "yellow") cprint(f"Duration: {result.duration:.2f} seconds", "yellow") diff --git a/src/vecsync/settings.py b/src/vecsync/settings.py index d499d0f..d259424 100644 --- a/src/vecsync/settings.py +++ b/src/vecsync/settings.py @@ -21,7 +21,7 @@ class SettingData(BaseModel): class Settings: - def __init__(self, path: str | None = None): + def __init__(self, path: Path | None = None): self.file = path or Path(user_config_dir("vecsync")) / "settings.json" if not self.file.exists(): diff --git a/src/vecsync/store/openai.py b/src/vecsync/store/openai.py index 551d414..d5da208 100644 --- a/src/vecsync/store/openai.py +++ b/src/vecsync/store/openai.py @@ -13,7 +13,7 @@ class SyncOperationResult(BaseModel): files_saved: int files_deleted: int files_skipped: int - updated_count: int + remote_count: int duration: float @@ -88,8 +88,9 @@ def _delete_files(self, files_to_remove: list[str]) -> set[str]: removed_file_ids = [] for file_id in tqdm(files_to_remove): self.client.vector_stores.files.delete(vector_store_id=self.store.id, file_id=file_id) - self.client.files.delete(file_id=file_id) - removed_file_ids.append(file_id) + result = self.client.files.delete(file_id=file_id) + if result.deleted: + removed_file_ids.append(file_id) return set(removed_file_ids) @@ -147,6 +148,6 @@ def sync(self, files: list[Path]): files_saved=len(files_to_upload), files_deleted=len(files_to_remove), files_skipped=len(duplicate_file_names), - updated_count=len(existing_vector_file_ids | files_to_attach), + remote_count=len(existing_vector_file_ids | files_to_attach), duration=duration, ) diff --git a/tests/openai/conftest.py b/tests/openai/conftest.py new file mode 100644 index 0000000..6360fe9 --- /dev/null +++ b/tests/openai/conftest.py @@ -0,0 +1,303 @@ +import os +from datetime import datetime +from types import SimpleNamespace +from typing import Any + +import pytest +from pydantic import BaseModel + +import vecsync.chat.clients.openai as client_mod +from vecsync.chat.clients.openai import OpenAIClient, OpenAIHandler +from vecsync.chat.formatter import ConsoleFormatter +from vecsync.store.openai import OpenAiVectorStore + + +class MockAssistant(BaseModel): + id: str + name: str + + +class MockThread(BaseModel): + id: str + + +class MockMessageContentText(BaseModel): + value: str + + +class MockMessageContent(BaseModel): + type: str + text: MockMessageContentText + + +class MockMessageData(BaseModel): + content: list[MockMessageContent] + created_at: int # TODO: Check if this is really at both levels + role: str + + +class MockMessage(BaseModel): + data: MockMessageData + thread_id: str + created_at: int + + +class MockThreadMessageResponse(BaseModel): + thread_id: str + data: list[MockMessageData] + + +class MockVectorStore(BaseModel): + id: str + name: str + + +class MockFileUpload(BaseModel): + id: str + file: Any + + +class MockFile(BaseModel): + id: str + filename: str + + +class MockFileDeletedResult(BaseModel): + deleted: bool + + +class MockVectorStoreDeletedResult(BaseModel): + deleted: bool + + +class MockStreamResponseAnnotation(BaseModel): + type: str + text: str + + +class MockStreamResponseText(BaseModel): + value: str + annotations: list[MockStreamResponseAnnotation] + + +class MockStreamResponseContent(BaseModel): + type: str + text: MockStreamResponseText + + +class MockStreamResponse(BaseModel): + content: list[MockStreamResponseContent] + + +def mock_vector_store(): + vector_store = [] + file_store = [] + vector_file_store = [] + + def create_vector_store(name): + store = MockVectorStore(id=f"vector_store_{len(vector_store) + 1}", name=name) + vector_store.append(store) + return store + + def delete_vector_store(vector_store_id): + for store in vector_store: + if store.id == vector_store_id: + vector_store.remove(store) + return MockFileDeletedResult(deleted=True) + return MockFileDeletedResult(deleted=False) + + def list_vector_stores(): + return vector_store + + def list_files(): + return file_store + + def list_vector_store_files(vector_store_id): + return vector_file_store + + def delete_vector_store_file(vector_store_id, file_id): + for vector_file in vector_file_store: + if vector_file.id == file_id: + vector_file_store.remove(vector_file) + + def delete_file(file_id): + for file in file_store: + if file.id == file_id: + file_store.remove(file) + return MockFileDeletedResult(deleted=True) + return MockFileDeletedResult(deleted=False) + + def create_file(**kwargs): + base_name = os.path.basename(kwargs["file"].name) + file = MockFileUpload(id=f"file_{len(file_store) + 1}", file=kwargs["file"]) + file_store.append(MockFile(id=file.id, filename=base_name)) + return file + + def create_and_poll(vector_store_id, file_id): + for store in vector_store: + if store.id == vector_store_id: + vector_file = MockFile(id=file_id, filename=f"file_{file_id}") + vector_file_store.append(vector_file) + return vector_file + return None + + # attach methods + vs_files_ns = SimpleNamespace() + vs_files_ns.list = list_vector_store_files + vs_files_ns.delete = delete_vector_store_file + vs_files_ns.create_and_poll = create_and_poll + + stores_ns = SimpleNamespace() + stores_ns.create = create_vector_store + stores_ns.delete = delete_vector_store + stores_ns.list = list_vector_stores + stores_ns.files = vs_files_ns + + files_ns = SimpleNamespace() + files_ns.list = list_files + files_ns.delete = delete_file + files_ns.create = create_file + + # build your “client” + client = SimpleNamespace() + client.vector_stores = stores_ns + client.files = files_ns + + return client + + +def mock_client_backend(): + # our in‐memory store + assistant_store = [] + threads_store = [] + message_store = [] + + def create_assistant(**kwargs): + name = kwargs["name"] + assistant = MockAssistant(id=f"assistant_{name}_{len(assistant_store) + 1}", name=name) + assistant_store.append(assistant) + return assistant + + def list_assistants(): + return assistant_store + + def delete_assistant(assistant_id): + for assistant in assistant_store: + if assistant.id == assistant_id: + assistant_store.remove(assistant) + + def create_thread(**kwargs): + thread = MockThread(id=f"thread_{len(threads_store) + 1}") + threads_store.append(thread) + return thread + + def create_message(**kwargs): + created_at = int(datetime.now().timestamp()) + + message = MockMessage( + created_at=created_at, + data=MockMessageData( + created_at=created_at, + content=[MockMessageContent(type="text", text=MockMessageContentText(value=kwargs["content"]))], + role=kwargs["role"], + ), + thread_id=kwargs["thread_id"], + ) + message_store.append(message) + return message + + def list_messages(**kwargs): + thread_id = kwargs["thread_id"] + messages = [message.data for message in message_store if message.thread_id == thread_id] + return MockThreadMessageResponse(thread_id=thread_id, data=messages) + + def stream_response(**kwargs): + class StreamManager: + def __init__(self, handler): + self.handler = handler + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + return False + + def until_done(self): + text = """This is a test message from the assistant""" + for delta in text.split(): + message = MockStreamResponse( + content=[ + MockStreamResponseContent( + type="text", text=MockStreamResponseText(value=delta, annotations=[]) + ) + ] + ) + self.handler.on_message_delta(delta=message, snapshot=None) + + self.handler.on_message_done(message=None) + + return StreamManager(handler=kwargs["event_handler"]) + + # attach methods + assistants_ns = SimpleNamespace() + assistants_ns.create = create_assistant + assistants_ns.list = list_assistants + assistants_ns.delete = delete_assistant + + messages_ns = SimpleNamespace() + messages_ns.create = create_message + messages_ns.list = list_messages + + runs_ns = SimpleNamespace() + runs_ns.stream = stream_response + + threads_ns = SimpleNamespace() + threads_ns.create = create_thread + threads_ns.messages = messages_ns + threads_ns.runs = runs_ns + + # build your “client” + client = SimpleNamespace() + client.beta = SimpleNamespace() + client.beta.assistants = assistants_ns + client.beta.threads = threads_ns + + return client + + +@pytest.fixture +def mocked_vector_store(): + store = OpenAiVectorStore(name="test_store") + store.client = mock_vector_store() + store.create() + return store + + +@pytest.fixture +def mocked_client(tmp_path, mocked_vector_store, monkeypatch): + monkeypatch.setattr(client_mod, "OpenAiVectorStore", lambda store_name: mocked_vector_store) + + settings_path = tmp_path / "settings.json" + client = OpenAIClient(store_name="test_store", settings_path=settings_path) + client.client = mock_client_backend() + + return client + + +@pytest.fixture +def mocked_client_handler(): + return OpenAIHandler( + files={"file_1", "filename.txt"}, + formatter=ConsoleFormatter(), + ) + + +@pytest.fixture +def create_test_upload(tmp_path): + files = set() + for i in range(3): + file = tmp_path / f"test_file_{i}.txt" + files.add(file) + with open(file, "w") as f: + f.write(f"This is test file {i}") + return files diff --git a/tests/openai/test_openai_chat.py b/tests/openai/test_openai_chat.py new file mode 100644 index 0000000..1527de3 --- /dev/null +++ b/tests/openai/test_openai_chat.py @@ -0,0 +1,100 @@ +from vecsync.settings import Settings + + +def test_list_assistants(mocked_client): + mocked_client.client.beta.assistants.create(name="vecsync-1") + mocked_client.client.beta.assistants.create(name="vecsync-2") + mocked_client.client.beta.assistants.create(name="other-3") + + assistants = mocked_client.list_assistants() + + assert len(assistants) == 2 + + +def test_delete_assistant(mocked_client): + assistant = mocked_client.client.beta.assistants.create(name="vecsync-1") + mocked_client.client.beta.assistants.create(name="vecsync-2") + + mocked_client.delete_assistant(assistant.id) + + assert len(mocked_client.list_assistants()) == 1 + + +def test_create_thread(mocked_client): + thread_id = mocked_client._create_thread() + assert thread_id == "thread_1" + + +def test_get_thread_id_new(mocked_client): + thread_id = mocked_client._get_thread_id() + assert thread_id == "thread_1" + + +def test_get_thread_id_existing(mocked_client): + settings = Settings(path=mocked_client.settings_path) + settings["openai_thread_id"] = "thread_2" + + thread_id = mocked_client._get_thread_id() + assert thread_id == "thread_2" + + +def test_create_assistant(mocked_client, mocked_vector_store): + mocked_client.vector_store = mocked_vector_store + id = mocked_client._create_assistant() + assert id == "assistant_vecsync-test_store_1" + + +def test_get_assistant_id_new(mocked_client, mocked_vector_store): + mocked_client.vector_store = mocked_vector_store + assistant_id = mocked_client._get_assistant_id() + assert assistant_id == "assistant_vecsync-test_store_1" + + +def test_get_assistant_id_existing(mocked_client): + mocked_client.client.beta.assistants.create(name="vecsync-1") + assistant_id = mocked_client._get_assistant_id() + assert assistant_id == "assistant_vecsync-1_1" + + +def test_get_assistant_id_multiple(mocked_client): + mocked_client.client.beta.assistants.create(name="vecsync-1") + mocked_client.client.beta.assistants.create(name="vecsync-2") + assistant_id = mocked_client._get_assistant_id() + assert assistant_id == "assistant_vecsync-1_1" + assert len(mocked_client.client.beta.assistants.list()) == 1 + + +def test_load_history_none(mocked_client): + history = mocked_client.load_history() + assert history == [] + + +def test_load_history_valid(mocked_client): + mocked_client.send_message("Hello") + mocked_client.send_message("World") + mocked_client.client.beta.threads.messages.create( + thread_id=mocked_client.thread_id, role="assistant", content="Response" + ) + + history = mocked_client.load_history() + + assert len(history) == 3 + assert [x["role"] for x in history] == ["user", "user", "assistant"] + + +def test_message(mocked_client, mocked_client_handler): + mocked_client.stream_response(thread_id="", assistant_id="", handler=mocked_client_handler) + + items = [] + while not mocked_client_handler.queue.empty(): + items.append(mocked_client_handler.queue.get_nowait()) + + assert items == ["This", "is", "a", "test", "message", "from", "the", "assistant"] + + +def test_consume_queue(mocked_client, mocked_client_handler): + mocked_client.stream_response(thread_id="", assistant_id="", handler=mocked_client_handler) + + items = list(mocked_client_handler.consume_queue()) + + assert items == ["This", "is", "a", "test", "message", "from", "the", "assistant"] diff --git a/tests/openai/test_openai_store.py b/tests/openai/test_openai_store.py new file mode 100644 index 0000000..b96e936 --- /dev/null +++ b/tests/openai/test_openai_store.py @@ -0,0 +1,96 @@ +import pytest + + +def test_get_files_none(mocked_vector_store): + files = mocked_vector_store.get_files() + assert len(files) == 0 + + +def test_get_valid_store(mocked_vector_store): + store = mocked_vector_store.get() + assert store.name == "test_store" + assert store.id == "vector_store_1" + + +def test_get_invalid_store(mocked_vector_store): + mocked_vector_store.name = "invalid_store" + with pytest.raises(ValueError): + mocked_vector_store.get() + + +def test_get_files_empty(mocked_vector_store): + mocked_vector_store.get() + files = mocked_vector_store.get_files() + assert len(files) == 0 + + +def test_get_files_existing(mocked_vector_store, create_test_upload): + files_uploaded = mocked_vector_store._upload_files(create_test_upload) + + remote_files = mocked_vector_store.get_files() + + assert len(remote_files) == len(files_uploaded) == 3 + + +def test_delete_files(mocked_vector_store, create_test_upload): + files_uploaded = mocked_vector_store._upload_files(create_test_upload) + assert len(files_uploaded) == 3 + + removed_files = mocked_vector_store._delete_files(files_uploaded) + assert len(removed_files) == 3 + + remote_files = mocked_vector_store.get_files() + assert len(remote_files) == 0 + + +def test_delete_files_invalid(mocked_vector_store): + removed_files = mocked_vector_store._delete_files(["test"]) + assert len(removed_files) == 0 + + remote_files = mocked_vector_store.get_files() + assert len(remote_files) == 0 + + +def test_delete_store(mocked_vector_store): + mocked_vector_store.delete() + assert mocked_vector_store.store is None + + +def test_get(mocked_vector_store): + store = mocked_vector_store.get() + assert store.name == "test_store" + assert store.id == "vector_store_1" + + +def test_attach_files(mocked_vector_store, create_test_upload): + files_uploaded = mocked_vector_store._upload_files(create_test_upload) + assert len(files_uploaded) == 3 + + mocked_vector_store._attach_files(files_uploaded) + + remote_files = mocked_vector_store.get_files() + assert len(remote_files) == 3 + + +def test_sync_files(mocked_vector_store, create_test_upload): + result = mocked_vector_store.sync(create_test_upload) + + assert result.files_saved == 3 + assert result.files_deleted == 0 + assert result.files_skipped == 0 + assert result.remote_count == 3 + assert result.duration > 0 + + +def test_sync_files_with_existing_overlap(mocked_vector_store, create_test_upload): + files = list(create_test_upload) + + result1 = mocked_vector_store.sync(files[:2]) + assert result1.files_saved == 2 + + result2 = mocked_vector_store.sync(files) + assert result2.files_saved == 1 + assert result2.files_deleted == 0 + assert result2.files_skipped == 2 + assert result2.remote_count == 3 + assert result2.duration > 0 diff --git a/uv.lock b/uv.lock index 422d9cc..96655d1 100644 --- a/uv.lock +++ b/uv.lock @@ -278,7 +278,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload_time = "2025-05-10T17:42:51.123Z" } wheels = [