From f66f64e3036bf5dcc7704ab35eec48abc142991d Mon Sep 17 00:00:00 2001 From: fzowl Date: Wed, 4 Feb 2026 02:22:06 +0100 Subject: [PATCH] voyage-4-nano support --- .github/workflows/ci.yaml | 83 +++++++++- pyproject.toml | 13 +- tests/test_client_local.py | 132 +++++++++++++++ tests/test_client_local_async.py | 101 ++++++++++++ voyageai/__init__.py | 14 ++ voyageai/_base.py | 15 +- voyageai/client.py | 63 ++++++- voyageai/client_async.py | 84 +++++++++- voyageai/local/__init__.py | 30 ++++ voyageai/local/model_registry.py | 154 ++++++++++++++++++ .../local/sentence_transformer_backend.py | 125 ++++++++++++++ 11 files changed, 799 insertions(+), 15 deletions(-) create mode 100644 tests/test_client_local.py create mode 100644 tests/test_client_local_async.py create mode 100644 voyageai/local/__init__.py create mode 100644 voyageai/local/model_registry.py create mode 100644 voyageai/local/sentence_transformer_backend.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 74c5858..4a71ac8 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -7,9 +7,79 @@ on: branches: [ "main" ] workflow_call: +env: + POETRY_HOME: "/opt/poetry" + jobs: - test: - name: Tests + # Unit tests without local dependencies (tests import error handling) + unit-tests-no-local-deps: + name: Unit Tests (no local deps) + runs-on: ubuntu-22.04 + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install poetry + run: | + python3 -m venv $POETRY_HOME + $POETRY_HOME/bin/pip install poetry==1.8.4 + $POETRY_HOME/bin/poetry --version + - name: Install ffmpeg 7 + run: | + sudo apt-get update + sudo apt-get install -y software-properties-common + sudo add-apt-repository -y ppa:ubuntuhandbook1/ffmpeg7 + sudo apt-get update + sudo apt-get install -y ffmpeg + ffmpeg -version + - name: Install package (without local extras) + run: $POETRY_HOME/bin/poetry install + - name: Run unit tests (no API key, no local deps) + run: $POETRY_HOME/bin/poetry run pytest tests/test_client_local.py tests/test_client_local_async.py -v + + # Local model tests with local dependencies but no API key + local-model-tests: + name: Local Model Tests + runs-on: ubuntu-22.04 + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install poetry + run: | + python3 -m venv $POETRY_HOME + $POETRY_HOME/bin/pip install poetry==1.8.4 + $POETRY_HOME/bin/poetry --version + - name: Install ffmpeg 7 + run: | + sudo apt-get update + sudo apt-get install -y software-properties-common + sudo add-apt-repository -y ppa:ubuntuhandbook1/ffmpeg7 + sudo apt-get update + sudo apt-get install -y ffmpeg + ffmpeg -version + - name: Install package with local extras + run: $POETRY_HOME/bin/poetry install --extras local + - name: Run local model tests (no API key) + run: $POETRY_HOME/bin/poetry run pytest tests/test_client_local.py tests/test_client_local_async.py -v + + # Full integration tests with API key and all dependencies + integration-tests: + name: Integration Tests runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -17,7 +87,6 @@ jobs: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] env: VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} - POETRY_HOME: "/opt/poetry" steps: - name: Checkout repository uses: actions/checkout@v4 @@ -38,7 +107,7 @@ jobs: sudo apt-get update sudo apt-get install -y ffmpeg ffmpeg -version - - name: Install package - run: $POETRY_HOME/bin/poetry install - - name: Run tests - run: $POETRY_HOME/bin/poetry run pytest \ No newline at end of file + - name: Install package with local extras + run: $POETRY_HOME/bin/poetry install --extras local + - name: Run all tests + run: $POETRY_HOME/bin/poetry run pytest diff --git a/pyproject.toml b/pyproject.toml index a7c3664..228857e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "voyageai" -version = "0.3.8" +version = "0.3.9" description = "" authors = ["Yujie Qian "] readme = "README.md" @@ -20,6 +20,11 @@ pydantic = ">=1.10.8" tokenizers = ">=0.14.0" langchain-text-splitters = ">=0.3.8" ffmpeg-python = "*" +sentence-transformers = {version = ">=3.0.0", optional = true} +torch = {version = ">=2.0.0", optional = true} + +[tool.poetry.extras] +local = ["sentence-transformers", "torch"] [tool.poetry.group.test.dependencies] pytest = "^7.4.2" @@ -40,6 +45,12 @@ quote-style = "double" indent-style = "space" skip-magic-trailing-comma = false +[tool.pytest.ini_options] +markers = [ + "integration: marks tests as integration tests (require external dependencies or API)", +] +asyncio_mode = "auto" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/tests/test_client_local.py b/tests/test_client_local.py new file mode 100644 index 0000000..abc3f50 --- /dev/null +++ b/tests/test_client_local.py @@ -0,0 +1,132 @@ +"""Tests for local model support in Client.""" + +import pytest + +# ruff: noqa: F401 + +# Check if real dependencies are available +try: + import sentence_transformers + import torch + + REAL_DEPS_AVAILABLE = True +except ImportError: + REAL_DEPS_AVAILABLE = False + + +class TestLocalModelSupport: + """Test local model detection and routing.""" + + @pytest.mark.skipif(REAL_DEPS_AVAILABLE, reason="Only run when deps not installed") + def test_import_error_when_deps_missing(self): + """Test helpful error message when sentence-transformers not installed.""" + from voyageai.local import _ensure_local_deps + + with pytest.raises(ImportError) as exc_info: + _ensure_local_deps() + + assert "pip install voyageai[local]" in str(exc_info.value) + + def test_has_local_constant(self): + """Test HAS_LOCAL constant reflects dependency availability.""" + import voyageai + + assert voyageai.HAS_LOCAL == REAL_DEPS_AVAILABLE + + +@pytest.mark.integration +class TestLocalModelIntegration: + """Integration tests for local models using the standard Client. + + Run with: pytest -m integration + """ + + @pytest.fixture + def check_deps(self): + """Skip if dependencies not installed.""" + if not REAL_DEPS_AVAILABLE: + pytest.skip("sentence-transformers or torch not installed") + + def test_seamless_local_embedding(self, check_deps): + """Test that Client.embed() seamlessly uses local model.""" + from voyageai import Client + + # No API key needed for local models + client = Client() + result = client.embed(["Hello, world!"], model="voyage-4-nano", input_type="document") + + assert len(result.embeddings) == 1 + assert len(result.embeddings[0]) == 2048 + assert result.total_tokens > 0 + + def test_all_dimensions(self, check_deps): + """Test all supported dimensions work.""" + from voyageai import Client + + client = Client() + + for dim in [256, 512, 1024, 2048]: + result = client.embed( + ["Test text"], model="voyage-4-nano", input_type="document", output_dimension=dim + ) + assert len(result.embeddings[0]) == dim, f"Expected {dim}, got {len(result.embeddings[0])}" + + def test_float32_dtype(self, check_deps): + """Test float32 output data type (default).""" + from voyageai import Client + + client = Client() + + result = client.embed(["Test"], model="voyage-4-nano", input_type="document", output_dtype="float32") + assert isinstance(result.embeddings[0][0], float) + + def test_query_vs_document_different(self, check_deps): + """Test query and document embeddings are different.""" + from voyageai import Client + + client = Client() + + query_result = client.embed(["What is machine learning?"], model="voyage-4-nano", input_type="query") + doc_result = client.embed(["What is machine learning?"], model="voyage-4-nano", input_type="document") + + # Embeddings should be different due to different prompts + assert query_result.embeddings[0] != doc_result.embeddings[0] + + def test_batch_embedding(self, check_deps): + """Test batch embedding works.""" + from voyageai import Client + + client = Client() + + texts = [ + "First document", + "Second document", + "Third document", + ] + result = client.embed(texts, model="voyage-4-nano", input_type="document") + + assert len(result.embeddings) == 3 + for emb in result.embeddings: + assert len(emb) == 2048 + + def test_invalid_dimension_raises_error(self, check_deps): + """Test invalid dimension raises ValueError.""" + from voyageai import Client + + client = Client() + + with pytest.raises(ValueError) as exc_info: + client.embed(["test"], model="voyage-4-nano", output_dimension=999) + + assert "Invalid output_dimension" in str(exc_info.value) + + def test_invalid_dtype_raises_error(self, check_deps): + """Test invalid dtype raises ValueError.""" + from voyageai import Client + + client = Client() + + with pytest.raises(ValueError) as exc_info: + client.embed(["test"], model="voyage-4-nano", output_dtype="invalid") + + assert "Invalid output_dtype" in str(exc_info.value) diff --git a/tests/test_client_local_async.py b/tests/test_client_local_async.py new file mode 100644 index 0000000..0f98b8f --- /dev/null +++ b/tests/test_client_local_async.py @@ -0,0 +1,101 @@ +"""Tests for async local model support in AsyncClient.""" + +import asyncio + +import pytest + +# ruff: noqa: F401 + +# Check if real dependencies are available +try: + import sentence_transformers + import torch + + REAL_DEPS_AVAILABLE = True +except ImportError: + REAL_DEPS_AVAILABLE = False + + +@pytest.mark.integration +class TestAsyncLocalModelIntegration: + """Integration tests for async local models using the standard AsyncClient. + + Run with: pytest -m integration + """ + + @pytest.fixture + def check_deps(self): + """Skip if dependencies not installed.""" + if not REAL_DEPS_AVAILABLE: + pytest.skip("sentence-transformers or torch not installed") + + @pytest.mark.asyncio + async def test_seamless_async_local_embedding(self, check_deps): + """Test that AsyncClient.embed() seamlessly uses local model.""" + from voyageai import AsyncClient + + client = AsyncClient() + result = await client.embed(["Hello, world!"], model="voyage-4-nano", input_type="document") + + assert len(result.embeddings) == 1 + assert len(result.embeddings[0]) == 2048 + assert result.total_tokens > 0 + + @pytest.mark.asyncio + async def test_concurrent_local_embeddings(self, check_deps): + """Test concurrent local embedding calls.""" + from voyageai import AsyncClient + + client = AsyncClient() + + texts = [ + ["First document"], + ["Second document"], + ["Third document"], + ] + + tasks = [client.embed(t, model="voyage-4-nano", input_type="document") for t in texts] + results = await asyncio.gather(*tasks) + + assert len(results) == 3 + for result in results: + assert len(result.embeddings) == 1 + assert len(result.embeddings[0]) == 2048 + + @pytest.mark.asyncio + async def test_async_all_dimensions(self, check_deps): + """Test all supported dimensions in async context.""" + from voyageai import AsyncClient + + client = AsyncClient() + + for dim in [256, 512, 1024, 2048]: + result = await client.embed( + ["Test text"], model="voyage-4-nano", input_type="document", output_dimension=dim + ) + assert len(result.embeddings[0]) == dim, f"Expected {dim}, got {len(result.embeddings[0])}" + + @pytest.mark.asyncio + async def test_async_query_vs_document(self, check_deps): + """Test query and document embeddings are different in async context.""" + from voyageai import AsyncClient + + client = AsyncClient() + + query_result = await client.embed(["What is AI?"], model="voyage-4-nano", input_type="query") + doc_result = await client.embed(["What is AI?"], model="voyage-4-nano", input_type="document") + + assert query_result.embeddings[0] != doc_result.embeddings[0] + + @pytest.mark.asyncio + async def test_async_batch_embedding(self, check_deps): + """Test batch embedding in async context.""" + from voyageai import AsyncClient + + client = AsyncClient() + + result = await client.embed( + ["First", "Second", "Third"], model="voyage-4-nano", input_type="document" + ) + assert len(result.embeddings) == 3 + diff --git a/voyageai/__init__.py b/voyageai/__init__.py index 452328e..2d6c525 100644 --- a/voyageai/__init__.py +++ b/voyageai/__init__.py @@ -36,6 +36,20 @@ ) from voyageai.version import VERSION + +def _is_local_available() -> bool: + """Check if sentence-transformers and torch are installed.""" + try: + import sentence_transformers # noqa: F401 + import torch # noqa: F401 + + return True + except ImportError: + return False + + +HAS_LOCAL = _is_local_available() + if TYPE_CHECKING: import requests from aiohttp import ClientSession diff --git a/voyageai/_base.py b/voyageai/_base.py index b3558f4..7edff6e 100644 --- a/voyageai/_base.py +++ b/voyageai/_base.py @@ -46,7 +46,7 @@ class _BaseClient(ABC): """Voyage AI Client Args: - api_key (str): Your API key. + api_key (str): Your API key (optional for local models). max_retries (int): Maximum number of retries if API call fails. timeout (float): Timeout in seconds. base_url (str): Base URL for the API endpoint. @@ -59,9 +59,18 @@ def __init__( timeout: Optional[float] = None, base_url: Optional[str] = None, ) -> None: - self.api_key = api_key or default_api_key() + # API key is optional - allow None for local-only usage + try: + self.api_key = api_key or default_api_key() + except voyageai.error.AuthenticationError: + # No API key available - that's OK for local models + self.api_key = None + self.max_retries = max_retries - base_url = base_url or get_default_base_url(self.api_key) + + # Only set base_url if we have an API key + if self.api_key: + base_url = base_url or get_default_base_url(self.api_key) self._params = { "api_key": self.api_key, diff --git a/voyageai/client.py b/voyageai/client.py index 851cd5d..5e849da 100644 --- a/voyageai/client.py +++ b/voyageai/client.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union from PIL.Image import Image from tenacity import ( @@ -13,6 +13,7 @@ import voyageai.error as error from voyageai._base import _BaseClient from voyageai.chunking import apply_chunking +from voyageai.local.model_registry import SUPPORTED_MODELS as LOCAL_MODELS from voyageai.object import ( ContextualizedEmbeddingsObject, EmbeddingsObject, @@ -22,12 +23,15 @@ from voyageai.object.multimodal_embeddings import MultimodalInputRequest from voyageai.video_utils import Video +if TYPE_CHECKING: + from voyageai.local.sentence_transformer_backend import SentenceTransformerBackend + class Client(_BaseClient): """Voyage AI Client Args: - api_key (str): Your API key. + api_key (str): Your API key (not required for local models). max_retries (int): Maximum number of retries if API call fails. timeout (float): Timeout in seconds. base_url (str): Base URL for the API endpoint. @@ -41,6 +45,7 @@ def __init__( base_url: Optional[str] = None, ) -> None: super().__init__(api_key, max_retries, timeout, base_url) + self._local_backends: Dict[str, "SentenceTransformerBackend"] = {} def _make_retry_controller(self) -> Retrying: return Retrying( @@ -54,6 +59,42 @@ def _make_retry_controller(self) -> Retrying: ), ) + def _get_local_backend(self, model: str) -> "SentenceTransformerBackend": + """Get or create a local backend for the given model.""" + if model not in self._local_backends: + from voyageai.local.sentence_transformer_backend import SentenceTransformerBackend + + self._local_backends[model] = SentenceTransformerBackend(model) + return self._local_backends[model] + + def _embed_local( + self, + texts: List[str], + model: str, + input_type: Optional[str] = None, + truncation: bool = True, + output_dtype: Optional[str] = None, + output_dimension: Optional[int] = None, + ) -> EmbeddingsObject: + """Generate embeddings using a local model.""" + backend = self._get_local_backend(model) + + embeddings_array = backend.encode( + texts=texts, + input_type=input_type, + output_dtype=output_dtype, + output_dimension=output_dimension, + truncation=truncation, + ) + + total_tokens = backend.count_tokens(texts) + + result = EmbeddingsObject() + result.embeddings = embeddings_array.tolist() + result.total_tokens = total_tokens + + return result + def embed( self, texts: List[str], @@ -72,6 +113,24 @@ def embed( "provided by Voyage AI." ) + # Check if this is a local model + if model in LOCAL_MODELS: + return self._embed_local( + texts=texts, + model=model, + input_type=input_type, + truncation=truncation, + output_dtype=output_dtype, + output_dimension=output_dimension, + ) + + # API models require an API key + if not self.api_key: + raise error.AuthenticationError( + "An API key is required for API-based models. " + "Set your API key via VOYAGE_API_KEY environment variable or pass it to Client(api_key=...)." + ) + response = None for attempt in self._make_retry_controller(): with attempt: diff --git a/voyageai/client_async.py b/voyageai/client_async.py index cebdaa3..57b8e7f 100644 --- a/voyageai/client_async.py +++ b/voyageai/client_async.py @@ -1,5 +1,6 @@ +import asyncio import warnings -from typing import Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union from PIL.Image import Image from tenacity import ( @@ -13,6 +14,7 @@ import voyageai.error as error from voyageai._base import _BaseClient from voyageai.chunking import apply_chunking +from voyageai.local.model_registry import SUPPORTED_MODELS as LOCAL_MODELS from voyageai.object import ( ContextualizedEmbeddingsObject, EmbeddingsObject, @@ -22,12 +24,15 @@ from voyageai.object.multimodal_embeddings import MultimodalInputRequest from voyageai.video_utils import Video +if TYPE_CHECKING: + from voyageai.local.sentence_transformer_backend import SentenceTransformerBackend + class AsyncClient(_BaseClient): """Voyage AI Async Client Args: - api_key (str): Your API key. + api_key (str): Your API key (not required for local models). max_retries (int): Maximum number of retries if API call fails. timeout (float): Timeout in seconds. base_url (str): Base URL for the API endpoint. @@ -41,6 +46,7 @@ def __init__( base_url: Optional[str] = None, ) -> None: super().__init__(api_key, max_retries, timeout, base_url) + self._local_backends: Dict[str, "SentenceTransformerBackend"] = {} def _make_retry_controller(self) -> AsyncRetrying: return AsyncRetrying( @@ -54,6 +60,62 @@ def _make_retry_controller(self) -> AsyncRetrying: ), ) + def _get_local_backend(self, model: str) -> "SentenceTransformerBackend": + """Get or create a local backend for the given model.""" + if model not in self._local_backends: + from voyageai.local.sentence_transformer_backend import SentenceTransformerBackend + + self._local_backends[model] = SentenceTransformerBackend(model) + return self._local_backends[model] + + def _embed_local_sync( + self, + texts: List[str], + model: str, + input_type: Optional[str] = None, + truncation: bool = True, + output_dtype: Optional[str] = None, + output_dimension: Optional[int] = None, + ) -> EmbeddingsObject: + """Generate embeddings using a local model (sync, for use with to_thread).""" + backend = self._get_local_backend(model) + + embeddings_array = backend.encode( + texts=texts, + input_type=input_type, + output_dtype=output_dtype, + output_dimension=output_dimension, + truncation=truncation, + ) + + total_tokens = backend.count_tokens(texts) + + result = EmbeddingsObject() + result.embeddings = embeddings_array.tolist() + result.total_tokens = total_tokens + + return result + + async def _embed_local( + self, + texts: List[str], + model: str, + input_type: Optional[str] = None, + truncation: bool = True, + output_dtype: Optional[str] = None, + output_dimension: Optional[int] = None, + ) -> EmbeddingsObject: + """Generate embeddings using a local model (async).""" + return await asyncio.to_thread( + self._embed_local_sync, + texts=texts, + model=model, + input_type=input_type, + truncation=truncation, + output_dtype=output_dtype, + output_dimension=output_dimension, + ) + async def embed( self, texts: List[str], @@ -72,6 +134,24 @@ async def embed( "provided by Voyage AI." ) + # Check if this is a local model + if model in LOCAL_MODELS: + return await self._embed_local( + texts=texts, + model=model, + input_type=input_type, + truncation=truncation, + output_dtype=output_dtype, + output_dimension=output_dimension, + ) + + # API models require an API key + if not self.api_key: + raise error.AuthenticationError( + "An API key is required for API-based models. " + "Set your API key via VOYAGE_API_KEY environment variable or pass it to AsyncClient(api_key=...)." + ) + response = None async for attempt in self._make_retry_controller(): with attempt: diff --git a/voyageai/local/__init__.py b/voyageai/local/__init__.py new file mode 100644 index 0000000..c8804b7 --- /dev/null +++ b/voyageai/local/__init__.py @@ -0,0 +1,30 @@ +"""Local model support for Voyage AI SDK. + +This module provides lazy imports to avoid loading torch/sentence-transformers +for API-only users. Install with: pip install voyageai[local] +""" + +# Lazy imports - set to None if not available +try: + import sentence_transformers as _sentence_transformers + import torch as _torch +except ImportError: # pragma: no cover - handled lazily in functions + _sentence_transformers = None # type: ignore[assignment] + _torch = None # type: ignore[assignment] + + +def _ensure_local_deps() -> tuple: + """Ensure local model dependencies are available. + + Returns: + Tuple of (sentence_transformers, torch) modules. + + Raises: + ImportError: If sentence-transformers or torch are not installed. + """ + if _sentence_transformers is None or _torch is None: + raise ImportError( + "The 'sentence-transformers' and 'torch' packages are required for local models. " + "Install them with: pip install voyageai[local]" + ) + return _sentence_transformers, _torch diff --git a/voyageai/local/model_registry.py b/voyageai/local/model_registry.py new file mode 100644 index 0000000..806e407 --- /dev/null +++ b/voyageai/local/model_registry.py @@ -0,0 +1,154 @@ +"""Model configuration and thread-safe caching for local models.""" + +import threading +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +@dataclass(frozen=True) +class LocalModelConfig: + """Configuration for a local embedding model.""" + + huggingface_id: str + max_tokens: int + default_dimension: int + supported_dimensions: tuple + supported_precisions: tuple + trust_remote_code: bool = True + + def validate_dimension(self, dimension: Optional[int]) -> int: + """Validate and return the dimension to use. + + Args: + dimension: Requested dimension, or None for default. + + Returns: + The dimension to use. + + Raises: + ValueError: If dimension is not supported. + """ + if dimension is None: + return self.default_dimension + if dimension not in self.supported_dimensions: + raise ValueError( + f"Invalid output_dimension {dimension}. " + f"Supported dimensions: {self.supported_dimensions}" + ) + return dimension + + def validate_precision(self, precision: Optional[str]) -> Optional[str]: + """Validate and return the precision to use. + + Args: + precision: Requested precision, or None for default (float32). + + Returns: + The precision to use. + + Raises: + ValueError: If precision is not supported. + """ + if precision is None: + return None + if precision not in self.supported_precisions: + raise ValueError( + f"Invalid output_dtype '{precision}'. " + f"Supported dtypes: {self.supported_precisions}" + ) + return precision + + +# Supported local models +SUPPORTED_MODELS: Dict[str, LocalModelConfig] = { + "voyage-4-nano": LocalModelConfig( + huggingface_id="voyageai/voyage-4-nano", + max_tokens=32768, + default_dimension=2048, + supported_dimensions=(2048, 1024, 512, 256), + supported_precisions=("float32", "int8", "uint8", "binary", "ubinary"), + trust_remote_code=True, + ), +} + + +def get_model_config(model: str) -> LocalModelConfig: + """Get configuration for a model. + + Args: + model: Model name. + + Returns: + LocalModelConfig for the model. + + Raises: + ValueError: If model is not supported. + """ + if model not in SUPPORTED_MODELS: + raise ValueError( + f"Unsupported local model '{model}'. " + f"Supported models: {list(SUPPORTED_MODELS.keys())}" + ) + return SUPPORTED_MODELS[model] + + +class ModelCache: + """Thread-safe singleton cache for loaded models. + + Avoids reloading models per call, which can be expensive. + """ + + _instance: Optional["ModelCache"] = None + _lock: threading.Lock = threading.Lock() + + def __new__(cls) -> "ModelCache": + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._cache: Dict[str, Any] = {} + cls._instance._cache_lock = threading.Lock() + return cls._instance + + def get(self, key: str) -> Optional[Any]: + """Get a cached model. + + Args: + key: Cache key (typically model name + device). + + Returns: + Cached model or None if not found. + """ + with self._cache_lock: + return self._cache.get(key) + + def set(self, key: str, model: Any) -> None: + """Cache a model. + + Args: + key: Cache key. + model: Model to cache. + """ + with self._cache_lock: + self._cache[key] = model + + def clear(self) -> None: + """Clear all cached models.""" + with self._cache_lock: + self._cache.clear() + + def get_or_load(self, key: str, loader: callable) -> Any: + """Get a cached model or load it if not cached. + + Args: + key: Cache key. + loader: Callable to load the model if not cached. + + Returns: + The cached or newly loaded model. + """ + model = self.get(key) + if model is None: + model = loader() + self.set(key, model) + return model diff --git a/voyageai/local/sentence_transformer_backend.py b/voyageai/local/sentence_transformer_backend.py new file mode 100644 index 0000000..d5c8d42 --- /dev/null +++ b/voyageai/local/sentence_transformer_backend.py @@ -0,0 +1,125 @@ +"""Sentence-transformers backend for local model inference.""" + +from typing import List, Optional + +import numpy as np + +from voyageai.local import _ensure_local_deps +from voyageai.local.model_registry import ModelCache, get_model_config + +# Mapping from SDK output_dtype to sentence-transformers precision +DTYPE_TO_PRECISION = { + "float32": "float32", + "float": "float32", + "int8": "int8", + "uint8": "uint8", + "binary": "binary", + "ubinary": "ubinary", +} + + +class SentenceTransformerBackend: + """Wrapper for sentence-transformers with SDK-compatible interface.""" + + def __init__( + self, + model_name: str, + device: Optional[str] = None, + ): + """Initialize the backend. + + Args: + model_name: Name of the model (e.g., "voyage-4-nano"). + device: Device to use ("cuda", "cpu", or None for auto-detect). + """ + sentence_transformers, torch = _ensure_local_deps() + + self.config = get_model_config(model_name) + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + # Cache key includes model and device + cache_key = f"{model_name}:{self.device}" + cache = ModelCache() + + def load_model(): + return sentence_transformers.SentenceTransformer( + self.config.huggingface_id, + trust_remote_code=self.config.trust_remote_code, + device=self.device, + ) + + self.model = cache.get_or_load(cache_key, load_model) + self._tokenizer = self.model.tokenizer + + def encode( + self, + texts: List[str], + input_type: Optional[str] = None, + output_dtype: Optional[str] = None, + output_dimension: Optional[int] = None, + truncation: bool = True, + ) -> np.ndarray: + """Encode texts into embeddings. + + Args: + texts: List of texts to encode. + input_type: "query", "document", or None. + output_dtype: Output data type (float32, int8, uint8, binary, ubinary). + output_dimension: Dimension to truncate embeddings to (MRL support). + truncation: Whether to truncate texts exceeding max tokens. + + Returns: + Numpy array of embeddings. + """ + # Validate and get dimension + dimension = self.config.validate_dimension(output_dimension) + + # Validate and map precision + self.config.validate_precision(output_dtype) + precision = DTYPE_TO_PRECISION.get(output_dtype) if output_dtype else None + + # Build encode kwargs + encode_kwargs = { + "truncate_dim": dimension if dimension != self.config.default_dimension else None, + } + if precision: + encode_kwargs["precision"] = precision + + # Route based on input_type + if input_type == "query": + # Use prompt-based encoding for queries + embeddings = self.model.encode( + texts, + prompt_name="query", + **encode_kwargs, + ) + elif input_type == "document": + # Use prompt-based encoding for documents + embeddings = self.model.encode( + texts, + prompt_name="document", + **encode_kwargs, + ) + else: + # Default encoding without prompts + embeddings = self.model.encode( + texts, + **encode_kwargs, + ) + + return embeddings + + def count_tokens(self, texts: List[str]) -> int: + """Count total tokens across all texts. + + Args: + texts: List of texts to count tokens for. + + Returns: + Total token count. + """ + total = 0 + for text in texts: + encoded = self._tokenizer.encode(text, add_special_tokens=True) + total += len(encoded) + return total