From e7730dc417a5e47396d039dc8886dc1b8cf48663 Mon Sep 17 00:00:00 2001 From: Basil Sunny Date: Thu, 15 Feb 2024 12:01:00 +0530 Subject: [PATCH] feat: add universal sentence encoder embedding function --- chromadb/utils/embedding_functions.py | 34 +++++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 3f0a1ce043b..737e25d378e 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -743,9 +743,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): - def __init__( - self, api_key: str = "", api_url = "https://infer.roboflow.com" - ) -> None: + def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None: """ Create a RoboflowEmbeddingFunction. @@ -757,7 +755,7 @@ def __init__( api_key = os.environ.get("ROBOFLOW_API_KEY") self._api_url = api_url - self._api_key = api_key + self._api_key = api_key try: self._PILImage = importlib.import_module("PIL.Image") @@ -789,10 +787,10 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: json=infer_clip_payload, ) - result = res.json()['embeddings'] + result = res.json()["embeddings"] embeddings.append(result[0]) - + elif is_document(item): infer_clip_payload = { "text": input, @@ -803,13 +801,13 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: json=infer_clip_payload, ) - result = res.json()['embeddings'] + result = res.json()["embeddings"] embeddings.append(result[0]) return embeddings - + class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): def __init__( self, @@ -900,6 +898,22 @@ def __call__(self, input: Documents) -> Embeddings: ) +class UniversalSentenceEncoderEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__( + self, model_name: str = "https://tfhub.dev/google/universal-sentence-encoder/4" + ): + try: + import tensorflow_hub as hub + except ImportError: + raise ValueError( + "The tensorflow_hub python package is not installed. Please install it with `pip install tensorflow_hub`" + ) + self._model = hub.load(model_name) + + def __call__(self, input: Documents) -> Embeddings: + return cast(Embeddings, self._model(input).numpy().tolist()) + + def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore try: from langchain_core.embeddings import Embeddings as LangchainEmbeddings @@ -962,7 +976,7 @@ def __call__(self, input: Documents) -> Embeddings: # type: ignore return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) - + class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): """ This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). @@ -1018,7 +1032,7 @@ def __call__(self, input: Documents) -> Embeddings: ], ) - + # List of all classes in this module _classes = [ name