diff --git a/docs/source/api/bluesearch.k8s.embeddings.rst b/docs/source/api/bluesearch.k8s.embeddings.rst
new file mode 100644
index 000000000..4e73edd0d
--- /dev/null
+++ b/docs/source/api/bluesearch.k8s.embeddings.rst
@@ -0,0 +1,7 @@
+bluesearch.k8s.embeddings module
+================================
+
+.. automodule:: bluesearch.k8s.embeddings
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api/bluesearch.k8s.rst b/docs/source/api/bluesearch.k8s.rst
index a76de5009..207599b60 100644
--- a/docs/source/api/bluesearch.k8s.rst
+++ b/docs/source/api/bluesearch.k8s.rst
@@ -9,6 +9,7 @@ Submodules
 
    bluesearch.k8s.connect
    bluesearch.k8s.create_indices
+   bluesearch.k8s.embeddings
 
 Module contents
 ---------------
diff --git a/notebooks/check_paragrapha_size.ipynb b/notebooks/check_paragrapha_size.ipynb
new file mode 100644
index 000000000..37a690f96
--- /dev/null
+++ b/notebooks/check_paragrapha_size.ipynb
@@ -0,0 +1,231 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Connect to ES"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from bluesearch.k8s.connect import connect"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "client = connect()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Tokenize all the paragraphs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import tqdm\n",
+    "from elasticsearch.helpers import scan"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoTokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = AutoTokenizer.from_pretrained(\"sentence-transformers/multi-qa-MiniLM-L6-cos-v1\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lens = []\n",
+    "progress = tqdm.tqdm(position=0, unit=\" Docs\", desc=\"Scanning paragraphs\")\n",
+    "body = {\"query\": {\"match_all\": {}}}\n",
+    "for hit in scan(client, query=body, index=\"paragraphs\"):\n",
+    "    tokens = tokenizer.tokenize(hit['_source']['text'])\n",
+    "    lens.append(len(tokens))\n",
+    "    progress.update(1)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Plot results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "import seaborn as sns\n",
+    "sns.set()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.boxplot(lens)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.boxplot(lens)\n",
+    "plt.ylim([0, 512])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.hist(lens)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.hist(lens, bins=100)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.hist(lens, bins=100)\n",
+    "plt.xlim([0, 512])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
"source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lens=np.array(lens)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(lens[np.array(lens)>512]) / len(lens) * 100" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get biggest paragraphs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "paragraphs = []\n", + "progress = tqdm.tqdm(position=0, unit=\" Docs\", desc=\"Scanning paragraphs\")\n", + "body = {\"query\":{\"match_all\":{}}}\n", + "for hit in scan(client, query=body, index=\"paragraphs\"):\n", + " emb = tokenizer.tokenize(hit['_source']['text'])\n", + " hit['_source']['tokenizer'] = ', '.join(emb)\n", + " progress.update(1)\n", + " if len(emb) > 1000:\n", + " paragraphs.append(hit['_source'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "paragraphs" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.5 ('py10')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "e14b248c68ef27f7e40aef879e7b97aaa0976632ef81142793ba6d8efee923a4" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index aecd7cbe0..8c6f90b88 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ # Required to encrypt mysql password; >= 3.2 to fix RSA decryption vulnerability "cryptography>=3.2", "defusedxml", - "elasticsearch>=8", + "elasticsearch==8.3.3", "google-cloud-storage", "h5py", "ipython", diff --git a/src/bluesearch/embedding_models.py b/src/bluesearch/embedding_models.py index 8ef3555f8..7f5d0f69f 100644 --- a/src/bluesearch/embedding_models.py +++ b/src/bluesearch/embedding_models.py @@ -22,6 +22,7 @@ import pathlib import pickle # nosec from abc import ABC, abstractmethod +from typing import Any import numpy as np import sentence_transformers @@ -38,10 +39,10 @@ class EmbeddingModel(ABC): @property @abstractmethod - def dim(self): + def dim(self) -> int: """Return dimension of the embedding.""" - def preprocess(self, raw_sentence): + def preprocess(self, raw_sentence: str) -> str: """Preprocess the sentence (Tokenization, ...) if needed by the model. This is a default implementation that perform no preprocessing. @@ -49,7 +50,7 @@ def preprocess(self, raw_sentence): Parameters ---------- - raw_sentence : str + raw_sentence Raw sentence to embed. Returns @@ -59,14 +60,14 @@ def preprocess(self, raw_sentence): """ return raw_sentence - def preprocess_many(self, raw_sentences): + def preprocess_many(self, raw_sentences: list[str]) -> list[str]: """Preprocess multiple sentences. This is a default implementation and can be overridden by children classes. Parameters ---------- - raw_sentences : list of str + raw_sentences List of str representing raw sentences that we want to embed. 
 
         Returns
@@ -77,21 +78,21 @@ def preprocess_many(self, raw_sentences):
         return [self.preprocess(sentence) for sentence in raw_sentences]
 
     @abstractmethod
-    def embed(self, preprocessed_sentence):
+    def embed(self, preprocessed_sentence: str) -> np.ndarray[Any, Any]:
         """Compute the sentences embeddings for a given sentence.
 
         Parameters
         ----------
-        preprocessed_sentence : str
+        preprocessed_sentence
             Preprocessed sentence to embed.
 
         Returns
         -------
-        embedding : numpy.array
+        embedding
             One dimensional vector representing the embedding of the given
             sentence.
         """
 
-    def embed_many(self, preprocessed_sentences):
+    def embed_many(self, preprocessed_sentences: list[str]) -> np.ndarray[Any, Any]:
         """Compute sentence embeddings for all provided sentences.
 
         This is a default implementation. Children classes can implement more
@@ -99,12 +100,12 @@ def embed_many(self, preprocessed_sentences):
 
         Parameters
         ----------
-        preprocessed_sentences : list of str
+        preprocessed_sentences
             List of preprocessed sentences.
 
         Returns
         -------
-        embeddings : np.ndarray
+        embeddings
             2D numpy array with shape `(len(preprocessed_sentences), self.dim)`.
             Each row is an embedding of a sentence in `preprocessed_sentences`.
         """
@@ -116,7 +117,7 @@ class SentTransformer(EmbeddingModel):
 
     Parameters
     ----------
-    model_name_or_path : pathlib.Path or str
+    model_name_or_path
         The name or the path of the Transformer model to load.
 
     References
@@ -124,43 +125,53 @@ class SentTransformer(EmbeddingModel):
     https://github.com/UKPLab/sentence-transformers
     """
 
-    def __init__(self, model_name_or_path, device=None):
+    def __init__(
+        self, model_name_or_path: pathlib.Path | str, device: str | None = None
+    ):
         self.senttransf_model = sentence_transformers.SentenceTransformer(
             str(model_name_or_path), device=device
         )
 
     @property
-    def dim(self):
+    def dim(self) -> int:
         """Return dimension of the embedding."""
-        return 768
+        return self.senttransf_model.get_sentence_embedding_dimension()
 
-    def embed(self, preprocessed_sentence):
+    @property
+    def normalized(self) -> bool:
+        """Return True if the model has a normalization module."""
+        for _, module in self.senttransf_model._modules.items():
+            if str(module) == "Normalize()":
+                return True
+        return False
+
+    def embed(self, preprocessed_sentence: str) -> np.ndarray[Any, Any]:
         """Compute the sentences embeddings for a given sentence.
 
         Parameters
         ----------
-        preprocessed_sentence : str
+        preprocessed_sentence
             Preprocessed sentence to embed.
 
         Returns
         -------
-        embedding : numpy.array
+        embedding
             Embedding of the given sentence of shape (768,).
         """
         return self.embed_many([preprocessed_sentence]).squeeze()
 
-    def embed_many(self, preprocessed_sentences):
+    def embed_many(self, preprocessed_sentences: list[str]) -> np.ndarray[Any, Any]:
         """Compute sentence embeddings for multiple sentences.
 
         Parameters
         ----------
-        preprocessed_sentences : list of str
+        preprocessed_sentences
             Preprocessed sentences to embed.
 
         Returns
         -------
-        embedding : numpy.array
+        embedding
             Embedding of the specified sentences of shape
             `(len(preprocessed_sentences), 768)`.
         """
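
A quick sketch of how the two new SentTransformer properties behave; the MiniLM checkpoint name is the one exercised by the tests at the bottom of this diff, and running this downloads the model:

import numpy as np

from bluesearch.embedding_models import SentTransformer

model = SentTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")

# `dim` is now read from the loaded checkpoint instead of being hard-coded to 768.
print(model.dim)  # 384 for this MiniLM checkpoint

# `normalized` reports whether the pipeline ends with a Normalize() module,
# i.e. whether embeddings already come out L2-normalized.
print(model.normalized)  # True for the cos-v1 checkpoint
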
""" - def __init__(self, checkpoint_path): + def __init__(self, checkpoint_path: pathlib.Path | str): self.checkpoint_path = pathlib.Path(checkpoint_path) with self.checkpoint_path.open("rb") as f: self.model = pickle.load(f) # nosec @property - def dim(self): + def dim(self) -> int: """Return dimension of the embedding. Returns ------- - dim : int + dim The dimension of the embedding. """ if hasattr(self.model, "n_features"): # e.g. HashingVectorizer @@ -201,35 +212,35 @@ def dim(self): f"{type(self.model)} could not be computed." ) - def embed(self, preprocessed_sentence): + def embed(self, preprocessed_sentence: str) -> np.ndarray[Any, Any]: """Embed one given sentence. Parameters ---------- - preprocessed_sentence : str + preprocessed_sentence Preprocessed sentence to embed. Can by obtained using the `preprocess` or `preprocess_many` methods. Returns ------- - embedding : numpy.ndarray + embedding Array of shape `(dim,)` with the sentence embedding. """ embedding = self.embed_many([preprocessed_sentence]) return embedding.squeeze() - def embed_many(self, preprocessed_sentences): + def embed_many(self, preprocessed_sentences: list[str]) -> np.ndarray[Any, Any]: """Compute sentence embeddings for multiple sentences. Parameters ---------- - preprocessed_sentences : iterable of str + preprocessed_sentences Preprocessed sentences to embed. Can by obtained using the `preprocess` or `preprocess_many` methods. Returns ------- - embeddings : numpy.ndarray + embeddings Array of shape `(len(preprocessed_sentences), dim)` with the sentence embeddings. """ @@ -237,7 +248,12 @@ def embed_many(self, preprocessed_sentences): return embeddings -def compute_database_embeddings(connection, model, indices, batch_size=10): +def compute_database_embeddings( + connection: sqlalchemy.engine.Engine, + model: EmbeddingModel, + indices: np.ndarray[Any, Any], + batch_size: int = 10, +) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]: """Compute sentences embeddings. The embeddings are computed for a given model and a given database @@ -245,14 +261,14 @@ def compute_database_embeddings(connection, model, indices, batch_size=10): Parameters ---------- - connection : sqlalchemy.engine.Engine + connection Connection to the database. - model : EmbeddingModel + model Instance of the EmbeddingModel of choice. - indices : np.ndarray + indices 1D array storing the sentence_ids for which we want to perform the embedding. - batch_size : int + batch_size Number of sentences to preprocess and embed at the same time. Should lead to major speedups. Note that the last batch will have a length of `n_sentences % batch_size` (unless it is 0). Note that some models @@ -261,10 +277,10 @@ def compute_database_embeddings(connection, model, indices, batch_size=10): Returns ------- - final_embeddings : np.array + final_embeddings 2D numpy array with all sentences embeddings for the given models. Its shape is `(len(retrieved_indices), dim)`. - retrieved_indices : np.ndarray + retrieved_indices 1D array of sentence_ids that we managed to embed. Note that the order corresponds exactly to the rows in `final_embeddings`. """ @@ -331,23 +347,30 @@ def get_embedding_model( Returns ------- - sentence_embedding_model : EmbeddingModel + sentence_embedding_model The sentence embedding model instance. """ - configs = { - # Transformer models. 
- "SentTransformer": lambda: SentTransformer(checkpoint_path, device), - "BioBERT NLI+STS": lambda: SentTransformer( - "clagator/biobert_v1.1_pubmed_nli_sts", device - ), - "SBioBERT": lambda: SentTransformer("gsarti/biobert-nli", device), - "SBERT": lambda: SentTransformer("bert-base-nli-mean-tokens", device), - # Scikit-learn models. - "SklearnVectorizer": lambda: SklearnVectorizer(checkpoint_path), - } - if model_name_or_class not in configs: - raise ValueError(f"Unknown model name or class: {model_name_or_class}") - return configs[model_name_or_class]() + if model_name_or_class in ["SentTransformer", "SklearnVectorizer"]: + if checkpoint_path is not None: + if model_name_or_class == "SentTransformer": + return SentTransformer(checkpoint_path, device) + elif model_name_or_class == "SklearnVectorizer": + return SklearnVectorizer(checkpoint_path) + else: + raise ValueError( + f"Something went wrong, model {model_name_or_class} not " + f"implemented." + ) + else: + raise ValueError("Checkpoint path must be provided for this model.") + elif model_name_or_class == "BioBERT NLI+STS": + return SentTransformer("clagator/biobert_v1.1_pubmed_nli_sts", device) + elif model_name_or_class == "SBioBERT": + return SentTransformer("gsarti/biobert-nli", device) + elif model_name_or_class == "SBERT": + return SentTransformer("bert-base-nli-mean-tokens", device) + else: + raise ValueError("Unknown model name or class.") class MPEmbedder: @@ -355,48 +378,48 @@ class MPEmbedder: Parameters ---------- - database_url : str + database_url URL of the database. - model_name_or_class : str + model_name_or_class The name or class of the model for which to compute the embeddings. - indices : np.ndarray + indices 1D array storing the sentence_ids for which we want to compute the embedding. - h5_path_output : pathlib.Path + h5_path_output Path to where the output h5 file will be lying. - batch_size_inference : int + batch_size_inference Number of sentences to preprocess and embed at the same time. Should lead to major speedups. Note that the last batch will have a length of `n_sentences % batch_size` (unless it is 0). Note that some models (SBioBERT) might perform padding to the longest sentence in the batch and bigger batch size might not lead to a speedup. - batch_size_transfer : int + batch_size_transfer Batch size to be used for transfering data from the temporary h5 files to the final h5 file. - n_processes : int + n_processes Number of processes to use. Note that each process gets `len(indices) / n_processes` sentences to embed. - checkpoint_path : pathlib.Path or None + checkpoint_path If 'model_name_or_class' is the class, the path of the model to load. Otherwise, this argument is ignored. - gpus : None or list + gpus If not specified, all processes will be using CPU. If not None, then it needs to be a list of length `n_processes` where each element represents the GPU id (integer) to be used. None elements will be interpreted as CPU. - delete_temp : bool + delete_temp If True, the temporary h5 files are deleted after the final h5 is created. Disabling this flag is useful for testing and debugging purposes. temp_folder : None or pathlib.Path If None, then all temporary h5 files stored into the same folder as the output h5 file. Otherwise they are stored in the specified folder. - h5_dataset_name : str or None + h5_dataset_name The name of the dataset in the H5 file. Otherwise, the value of 'model_name_or_class' is used. start_method : str, {"fork", "forkserver", "spawn"} Start method for multiprocessing. 
@@ -355,48 +378,48 @@
 
     Parameters
     ----------
-    database_url : str
+    database_url
         URL of the database.
-    model_name_or_class : str
+    model_name_or_class
         The name or class of the model for which to compute the embeddings.
-    indices : np.ndarray
+    indices
         1D array storing the sentence_ids for which we want to compute the
         embedding.
-    h5_path_output : pathlib.Path
+    h5_path_output
         Path to where the output h5 file will be lying.
-    batch_size_inference : int
+    batch_size_inference
         Number of sentences to preprocess and embed at the same time. Should
         lead to major speedups. Note that the last batch will have a length of
         `n_sentences % batch_size` (unless it is 0). Note that some models
         (SBioBERT) might perform padding to the longest sentence in the batch
         and bigger batch size might not lead to a speedup.
-    batch_size_transfer : int
+    batch_size_transfer
         Batch size to be used for transfering data from the temporary h5
         files to the final h5 file.
-    n_processes : int
+    n_processes
         Number of processes to use. Note that each process gets
         `len(indices) / n_processes` sentences to embed.
-    checkpoint_path : pathlib.Path or None
+    checkpoint_path
         If 'model_name_or_class' is the class, the path of the model to load.
         Otherwise, this argument is ignored.
-    gpus : None or list
+    gpus
         If not specified, all processes will be using CPU. If not None, then
         it needs to be a list of length `n_processes` where each element
         represents the GPU id (integer) to be used. None elements will be
         interpreted as CPU.
-    delete_temp : bool
+    delete_temp
         If True, the temporary h5 files are deleted after the final h5 is
         created. Disabling this flag is useful for testing and debugging
         purposes.
     temp_folder : None or pathlib.Path
         If None, then all temporary h5 files stored into the same folder as
         the output h5 file. Otherwise they are stored in the specified folder.
-    h5_dataset_name : str or None
+    h5_dataset_name
         The name of the dataset in the H5 file. Otherwise, the value of
         'model_name_or_class' is used.
     start_method : str, {"fork", "forkserver", "spawn"}
         Start method for multiprocessing.
         Note that using "fork" might lead to problems when doing GPU
         inference.
-    preinitialize : bool
+    preinitialize
         If True we instantiate the model before running multiprocessing in
         order to download any checkpoints. Once instantiated, the model will
         be deleted.
     """
 
     def __init__(
         self,
-        database_url,
-        model_name_or_class,
-        indices,
-        h5_path_output,
-        batch_size_inference=16,
-        batch_size_transfer=1000,
-        n_processes=2,
-        checkpoint_path=None,
-        gpus=None,
-        delete_temp=True,
-        temp_folder=None,
-        h5_dataset_name=None,
-        start_method="forkserver",
-        preinitialize=True,
+        database_url: str,
+        model_name_or_class: str,
+        indices: np.ndarray[Any, Any],
+        h5_path_output: pathlib.Path,
+        checkpoint_path: pathlib.Path | str | None = None,
+        batch_size_inference: int = 16,
+        batch_size_transfer: int = 1000,
+        n_processes: int = 2,
+        gpus: list[Any] | None = None,
+        delete_temp: bool = True,
+        temp_folder: pathlib.Path | None = None,
+        h5_dataset_name: str | None = None,
+        start_method: str = "forkserver",
+        preinitialize: bool = True,
     ):
         self.database_url = database_url
         self.model_name_or_class = model_name_or_class
@@ -445,7 +468,7 @@ def __init__(
 
         self.gpus = gpus
 
-    def do_embedding(self):
+    def do_embedding(self) -> None:
         """Do the parallelized embedding."""
         if self.preinitialize:
             self.logger.info("Preinitializing model (download of checkpoints)")
@@ -507,37 +530,37 @@ def do_embedding(self):
 
     @staticmethod
     def run_embedding_worker(
-        database_url,
-        model_name_or_class,
-        indices,
-        temp_h5_path,
-        batch_size,
-        checkpoint_path,
-        gpu,
-        h5_dataset_name,
-    ):
+        database_url: str,
+        model_name_or_class: str,
+        indices: np.ndarray[Any, Any],
+        temp_h5_path: pathlib.Path,
+        batch_size: int,
+        checkpoint_path: pathlib.Path | None = None,
+        gpu: int | None = None,
+        h5_dataset_name: str | None = None,
+    ) -> None:
         """Run per worker function.
 
         Parameters
         ----------
-        database_url : str
+        database_url
             URL of the database.
-        model_name_or_class : str
+        model_name_or_class
             The name or class of the model for which to compute the
             embeddings.
-        indices : np.ndarray
+        indices
            1D array of sentences ids indices representing what the worker
            needs to embed.
-        temp_h5_path : pathlib.Path
+        temp_h5_path
            Path to where we store the temporary h5 file.
-        batch_size : int
+        batch_size
            Number of sentences in the batch.
-        checkpoint_path : pathlib.Path or None
+        checkpoint_path
            If 'model_name_or_class' is the class, the path of the model to
            load. Otherwise, this argument is ignored.
-        gpu : int or None
+        gpu
            If None, we are going to use a CPU. Otherwise, we use a GPU with
            the specified id.
-        h5_dataset_name : str or None
+        h5_dataset_name
            The name of the dataset in the H5 file.
         """
         current_process = mp.current_process()
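
The new k8s embeddings module below is driven like this in practice; a sketch reusing connect() from the notebook above, with all arguments taken from the function's own signature:

from bluesearch.k8s.connect import connect
from bluesearch.k8s.embeddings import embed

client = connect()

# Embed every paragraph that does not yet have an "embedding" field, using
# the "minilm" model served behind Seldon; force=True would instead
# recompute the embeddings of all paragraphs.
embed(
    client,
    index="paragraphs",
    embedding_method="seldon",
    model_name="minilm",
    force=False,
)
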
diff --git a/src/bluesearch/k8s/embeddings.py b/src/bluesearch/k8s/embeddings.py
new file mode 100644
index 000000000..96abadc70
--- /dev/null
+++ b/src/bluesearch/k8s/embeddings.py
@@ -0,0 +1,243 @@
+# Blue Brain Search is a text mining toolbox focused on scientific use cases.
+#
+# Copyright (C) 2020 Blue Brain Project, EPFL.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""Embed the paragraphs in the database."""
+from __future__ import annotations
+
+import functools
+import logging
+import os
+from typing import Any
+
+import elasticsearch
+import numpy as np
+import requests
+import tqdm
+from dotenv import load_dotenv
+from elasticsearch.helpers import scan
+
+from bluesearch.embedding_models import SentTransformer
+
+load_dotenv()
+
+logger = logging.getLogger(__name__)
+
+
+def embed(
+    client: elasticsearch.Elasticsearch,
+    index: str = "paragraphs",
+    embedding_method: str = "seldon",
+    model_name: str = "minilm",
+    namespace: str = "seldon",
+    polling: str = "mean",
+    force: bool = False,
+) -> None:
+    """Embed the paragraphs in the database.
+
+    Parameters
+    ----------
+    client
+        Elasticsearch client.
+    index
+        Name of the ES index.
+    embedding_method
+        Method to use to embed the paragraphs.
+    model_name
+        Name of the model to use for the embedding.
+    namespace
+        Namespace of the Seldon deployment.
+    polling
+        Aggregation applied to the token embeddings ("mean" or "max").
+    force
+        If True, recompute the embeddings of all paragraphs, not only of
+        those that do not have one yet.
+    """
+    if embedding_method == "seldon":
+        embed_fn = functools.partial(
+            embed_seldon, namespace=namespace, model_name=model_name, polling=polling
+        )
+    elif embedding_method == "bentoml":
+        embed_fn = functools.partial(embed_bentoml, model_name=model_name)
+    elif embedding_method == "local":
+        embed_fn = functools.partial(
+            embed_locally,
+            model_name=model_name,
+        )
+    else:
+        raise ValueError(f"Unknown embedding method: {embedding_method}")
+
+    # get the paragraphs to embed
+    if force:
+        query: dict[str, Any] = {"match_all": {}}
+    else:
+        query = {"bool": {"must_not": {"exists": {"field": "embedding"}}}}
+    paragraph_count = client.count(index=index, query=query)["count"]
+    logger.info("There are %s paragraphs to embed", paragraph_count)
+
+    # create embeddings for all matching documents and update them
+    progress = tqdm.tqdm(
+        total=paragraph_count,
+        position=0,
+        unit=" Paragraphs",
+        desc="Updating embeddings",
+    )
+    for hit in scan(client, query={"query": query}, index=index):
+        emb = embed_fn(hit["_source"]["text"])
+        client.update(index=index, doc={"embedding": emb}, id=hit["_id"])
+        progress.update(1)
+
+
+def embed_locally(
+    text: str, model_name: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+) -> list[float]:
+    """Embed a single text locally.
+
+    Parameters
+    ----------
+    text
+        Text to embed.
+    model_name
+        Name of the model to use for the embedding.
+
+    Returns
+    -------
+    embedding
+        Embedding of the text.
+    """
+    model = SentTransformer(model_name)
+    emb = model.embed(text)
+    if not model.normalized:
+        emb /= np.linalg.norm(emb)
+    return emb.tolist()
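
A quick sanity check of the local path, assuming the default MiniLM checkpoint; its dimension of 384 and its built-in normalization are asserted by the tests at the end of this diff:

import numpy as np

from bluesearch.k8s.embeddings import embed_locally

emb = embed_locally("The hippocampus is involved in memory consolidation.")
assert len(emb) == 384                       # MiniLM output dimension
assert np.isclose(np.linalg.norm(emb), 1.0)  # result is L2-normalized

Note that embed_locally instantiates a fresh SentTransformer on every call; for bulk embedding, the Seldon and BentoML paths below avoid that per-call cost.
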
+ """ + url = ( + "http://" + + os.environ["SELDON_URL"] + + "/seldon/" + + namespace + + "/" + + model_name + + "/v2/models/transformer/infer" + ) + + # create payload + response = requests.post( + url, + json={ + "inputs": [ + { + "name": "args", + "shape": [1], + "datatype": "BYTES", + "data": text, + "parameters": {"content_type": "str"}, + } + ] + }, + ) + + if not response.status_code == 200: + raise ValueError("Error in the request") + + # convert the response to a numpy array + tensor = response.json()["outputs"][0]["data"][0] + tensor = tensor[3:-3].split("], [") + tensor = np.vstack([np.array(t.split(", "), dtype=np.float32) for t in tensor]) + + # apply the polling method + if polling: + if polling == "max": + tensor = np.max(tensor, axis=0) + elif polling == "mean": + tensor = np.mean(tensor, axis=0) + + # normalize the embedding + tensor /= np.linalg.norm(tensor) + + return tensor.tolist() + + +def embed_bentoml( + text: str, model_name: str = "minilm", polling: str = "mean" +) -> list[float]: + """Embed the paragraphs in the database using BentoML. + + Parameters + ---------- + text + Text to embed. + model_name + Name of the BentoML deployment. + + Returns + ------- + embedding + Embedding of the text. + """ + url = "http://" + os.environ["BENTOML_EMBEDDING_URL"] + "/" + model_name + + # create payload + response = requests.post( + url, + headers={"accept": "application/json", "Content-Type": "text/plain"}, + data=text, + ) + + if not response.status_code == 200: + raise ValueError("Error in the request") + + # convert the response to a numpy array + tensor = response.json() + tensor = np.vstack(tensor[0]) + + # apply the polling method + if polling: + if polling == "max": + tensor = np.max(tensor, axis=0) + elif polling == "mean": + tensor = np.mean(tensor, axis=0) + + # normalize the embedding + tensor /= np.linalg.norm(tensor) + + return tensor diff --git a/tests/conftest.py b/tests/conftest.py index 43d0f97ba..a3cac711a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -518,5 +518,5 @@ def get_es_client(monkeypatch): if client is not None: for index in client.indices.get_alias().keys(): - if index in ["articles", "paragraphs", "test_index"]: + if index in ["test_articles", "test_paragraphs", "test_index"]: remove_index(client, index) diff --git a/tests/unit/k8s/test_add_embeedings.py b/tests/unit/k8s/test_add_embeedings.py new file mode 100644 index 000000000..879305f18 --- /dev/null +++ b/tests/unit/k8s/test_add_embeedings.py @@ -0,0 +1,60 @@ +import os + +import numpy as np +import pytest + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_add_embeddings_locally(get_es_client): + from bluesearch.k8s.create_indices import ( + MAPPINGS_PARAGRAPHS, + SETTINGS, + add_index, + remove_index, + ) + from bluesearch.k8s.embeddings import embed + + client = get_es_client + if client is None: + pytest.skip("Elastic search is not available") + + add_index(client, "test_paragraphs", SETTINGS, MAPPINGS_PARAGRAPHS) + + docs = { + "1": {"text": "some test text"}, + "2": {"text": "some other test text"}, + "3": {"text": "some final test text"}, + } + + for doc_id, doc in docs.items(): + client.create(index="test_paragraphs", id=doc_id, document=doc) + client.indices.refresh(index="test_paragraphs") + + query = {"bool": {"must_not": {"exists": {"field": "embedding"}}}} + paragraph_count = client.count(index="test_paragraphs", query=query) + assert paragraph_count["count"] == 3 + + embed( + client, + index="test_paragraphs", + embedding_method="local", + 
diff --git a/tests/conftest.py b/tests/conftest.py
index 43d0f97ba..a3cac711a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -518,5 +518,5 @@ def get_es_client(monkeypatch):
 
     if client is not None:
         for index in client.indices.get_alias().keys():
-            if index in ["articles", "paragraphs", "test_index"]:
+            if index in ["test_articles", "test_paragraphs", "test_index"]:
                 remove_index(client, index)
diff --git a/tests/unit/k8s/test_add_embeedings.py b/tests/unit/k8s/test_add_embeedings.py
new file mode 100644
index 000000000..879305f18
--- /dev/null
+++ b/tests/unit/k8s/test_add_embeedings.py
@@ -0,0 +1,60 @@
+import os
+
+import numpy as np
+import pytest
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+def test_add_embeddings_locally(get_es_client):
+    from bluesearch.k8s.create_indices import (
+        MAPPINGS_PARAGRAPHS,
+        SETTINGS,
+        add_index,
+        remove_index,
+    )
+    from bluesearch.k8s.embeddings import embed
+
+    client = get_es_client
+    if client is None:
+        pytest.skip("Elasticsearch is not available")
+
+    add_index(client, "test_paragraphs", SETTINGS, MAPPINGS_PARAGRAPHS)
+
+    docs = {
+        "1": {"text": "some test text"},
+        "2": {"text": "some other test text"},
+        "3": {"text": "some final test text"},
+    }
+
+    for doc_id, doc in docs.items():
+        client.create(index="test_paragraphs", id=doc_id, document=doc)
+    client.indices.refresh(index="test_paragraphs")
+
+    query = {"bool": {"must_not": {"exists": {"field": "embedding"}}}}
+    paragraph_count = client.count(index="test_paragraphs", query=query)
+    assert paragraph_count["count"] == 3
+
+    embed(
+        client,
+        index="test_paragraphs",
+        embedding_method="local",
+        model_name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
+    )
+    client.indices.refresh(index="test_paragraphs")
+
+    query = {"bool": {"must_not": {"exists": {"field": "embedding"}}}}
+    paragraph_count = client.count(index="test_paragraphs", query=query)["count"]
+    assert paragraph_count == 0
+
+    remove_index(client, "test_paragraphs")
+
+
+def test_embedding_bentoml():
+    if os.environ.get("BENTOML_EMBEDDING_URL") is None:
+        pytest.skip("BENTOML_EMBEDDING_URL is not set")
+
+    from bluesearch.k8s.embeddings import embed_bentoml, embed_locally
+
+    emb_local = embed_locally("some test text")
+    emb_bentoml = embed_bentoml("some test text")
+    assert np.allclose(emb_local, emb_bentoml, rtol=1e-04, atol=1e-07)
diff --git a/tests/unit/test_embedding_models.py b/tests/unit/test_embedding_models.py
index 1556dacb4..6367eae46 100644
--- a/tests/unit/test_embedding_models.py
+++ b/tests/unit/test_embedding_models.py
@@ -84,8 +84,6 @@ def test_senttransf_embedding(self, monkeypatch, n_sentences):
         embed_method = getattr(sbert, "embed" if n_sentences == 1 else "embed_many")
 
         # Assertions
-        assert sbert.dim == 768
-
         preprocessed_sentence = preprocess_method(dummy_sentence)
         assert preprocessed_sentence == dummy_sentence
 
@@ -282,16 +280,16 @@ def test_invalid_key(self):
             get_embedding_model("wrong_model_name")
 
     @pytest.mark.parametrize(
-        "name, underlying_class",
+        "name, underlying_class, checkpoint",
         [
-            ("BioBERT NLI+STS", "SentTransformer"),
-            ("SentTransformer", "SentTransformer"),
-            ("SklearnVectorizer", "SklearnVectorizer"),
-            ("SBioBERT", "SentTransformer"),
-            ("SBERT", "SentTransformer"),
+            ("BioBERT NLI+STS", "SentTransformer", None),
+            ("SentTransformer", "SentTransformer", "fake_model_name"),
+            ("SklearnVectorizer", "SklearnVectorizer", "fake_checkpoint"),
+            ("SBioBERT", "SentTransformer", None),
+            ("SBERT", "SentTransformer", None),
         ],
     )
-    def test_returns_instance(self, monkeypatch, name, underlying_class):
+    def test_returns_instance(self, monkeypatch, name, underlying_class, checkpoint):
         fake_instance = Mock()
         fake_class = Mock(return_value=fake_instance)
 
@@ -299,7 +297,7 @@ def test_returns_instance(self, monkeypatch, name, underlying_class):
             f"bluesearch.embedding_models.{underlying_class}", fake_class
         )
 
-        returned_instance = get_embedding_model(name)
+        returned_instance = get_embedding_model(name, checkpoint)
 
         assert returned_instance is fake_instance
 
@@ -409,3 +407,27 @@ def test_do_embedding(self, monkeypatch, tmp_path, n_processes):
 
         args, _ = fake_h5.concatenate.call_args
         assert len(args[2]) == n_processes
+
+
+@pytest.mark.parametrize(
+    ("model_name", "expected_dim"),
+    [
+        ("sentence-transformers/multi-qa-MiniLM-L6-cos-v1", 384),
+        ("sentence-transformers/multi-qa-mpnet-base-dot-v1", 768),
+    ],
+)
+def test_embedding_size(model_name, expected_dim):
+    model = SentTransformer(model_name)
+    assert model.dim == expected_dim
+
+
+@pytest.mark.parametrize(
+    ("model_name", "is_normalized"),
+    [
+        ("sentence-transformers/multi-qa-MiniLM-L6-cos-v1", True),
+        ("sentence-transformers/multi-qa-mpnet-base-dot-v1", False),
+    ],
+)
+def test_model_is_normalized(model_name, is_normalized):
+    model = SentTransformer(model_name)
+    assert model.normalized is is_normalized
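
The two new parametrized tests pin down why embed_locally renormalizes conditionally: the dot-v1 checkpoint ships without a Normalize() module, so its raw vectors are not unit-length. A sketch of that interaction (downloading the model is required):

import numpy as np

from bluesearch.embedding_models import SentTransformer

model = SentTransformer("sentence-transformers/multi-qa-mpnet-base-dot-v1")
emb = model.embed("some test text")

print(model.normalized)     # False: no Normalize() module in the pipeline
print(np.linalg.norm(emb))  # generally != 1.0 for this checkpoint

# This is exactly the case where embed_locally divides by the norm, so that
# both checkpoints end up producing unit-length embeddings.
emb /= np.linalg.norm(emb)
print(np.linalg.norm(emb))  # 1.0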