diff --git a/README.md b/README.md index f33be597..ec501582 100644 --- a/README.md +++ b/README.md @@ -363,7 +363,7 @@ Every array will produce the combinations of flat configurations when the method ## Description of embedding models config -`embedding_model` is an array containing the configuration for the embedding models to use. Embedding model `type` must be `azure` for Azure OpenAI models and `sentence-transformer` for HuggingFace sentence transformer models. +`embedding_model` is an array containing the configuration for the embedding models to use. Embedding model `type` must be `azure` for Azure OpenAI models, `sentence-transformer` for HuggingFace sentence transformer models and `custom-embedding` for custom embeddings deployed as Azure Online Endpoints. ### Azure OpenAI embedding model config @@ -408,6 +408,18 @@ When using the [newer embeddings models (v3)](https://openai.com/blog/new-embedd } ``` +### Custom embedding model + +```json +{ + "type": "custom-embedding", + "model_name": "the name of the Azure deployment of the custom embedding model", + "dimension": "the dimension of the custom embedding model. This field is not required" +} +``` + +The variables `azure_model_api_key` and `azure_model_api_endpoint` should also be set in the environment variables (.env file). + ## Query Expansion Giving an example of an hypothetical answer for the question in query, an hypothetical passage which holds an answer to the query, or generate few alternative related question might improve retrieval and thus get more accurate chunks of docs to pass into LLM context. 
diff --git a/rag_experiment_accelerator/config/environment.py b/rag_experiment_accelerator/config/environment.py index d895cdde..1cf00659 100644 --- a/rag_experiment_accelerator/config/environment.py +++ b/rag_experiment_accelerator/config/environment.py @@ -92,6 +92,8 @@ class Environment: azure_document_intelligence_endpoint: Optional[str] azure_document_intelligence_admin_key: Optional[str] azure_key_vault_endpoint: Optional[str] + azure_model_api_key: Optional[str] + azure_model_api_endpoint: Optional[str] @classmethod def _field_names(cls) -> list[str]: diff --git a/rag_experiment_accelerator/config/tests/test_environment.py b/rag_experiment_accelerator/config/tests/test_environment.py index bc75f278..519e81df 100644 --- a/rag_experiment_accelerator/config/tests/test_environment.py +++ b/rag_experiment_accelerator/config/tests/test_environment.py @@ -125,7 +125,9 @@ def test_to_keyvault(mock_init_keyvault): azure_language_service_key=None, azure_key_vault_endpoint="test_endpoint", azure_search_use_semantic_search="True", + azure_model_api_key="mock_key", + azure_model_api_endpoint="mock_endpoint", ) environment.to_keyvault() - assert mock_keyvault.set_secret.call_count == 17 + assert mock_keyvault.set_secret.call_count == 19 diff --git a/rag_experiment_accelerator/embedding/custom_embedding_model.py b/rag_experiment_accelerator/embedding/custom_embedding_model.py new file mode 100644 index 00000000..9e2f057d --- /dev/null +++ b/rag_experiment_accelerator/embedding/custom_embedding_model.py @@ -0,0 +1,124 @@ +import urllib.request +import json +import os +import ssl +from typing import Union + +from rag_experiment_accelerator.config.environment import Environment +from rag_experiment_accelerator.embedding.embedding_model import EmbeddingModel +from rag_experiment_accelerator.utils.logging import get_logger + +logger = get_logger(__name__) + + +class CustomEmbeddingModel(EmbeddingModel): + """ + A class representing a Custom Embedding Model deployed as an 
AzureML online endpoint. + + Args: + model_name (str): The name of the deployment. + environment (Environment): The initialized environment. + dimension (int, optional): The dimension of the embedding. Defaults to 1536. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, model_name: str, environment: Environment, dimension: int = 1536, **kwargs + ): + super().__init__(name=model_name, dimension=dimension, **kwargs) + self.environment = environment + pass + + def prepare_request(self, body: Union[dict, list]) -> tuple[dict, bytes]: + """ + Prepares the request to be sent to the AzureML online endpoint. + + Args: + body (Union[dict, list]): The input data. + + Returns: + tuple[dict, bytes]: The request headers and the encoded request body. + + """ + # replace the format based on the model input + data_format = { + "input": body, + } + + body = str.encode(json.dumps(data_format)) + + headers = { + "Content-Type": "application/json", + "Authorization": ("Bearer " + self.environment.azure_model_api_key), + "azureml-model-deployment": self.name, + } + + return headers, body + + def make_request(self, body: bytes, headers: dict) -> list[float]: + """ + Makes a request to the AzureML online endpoint. + + Args: + body (bytes): The request body. + headers (dict): The request headers. + + Returns: + list[float]: The response from the AzureML online endpoint. + + """ + try: + logger.info("Calling Custom Embedding Model API") + req = urllib.request.Request( + self.environment.azure_model_api_endpoint, body, headers + ) + response = urllib.request.urlopen(req) + logger.info("Custom Embedding Model response received") + data = json.loads(response.read()) + logger.info("Custom Embedding Model response parsed") + + return data + + except urllib.error.HTTPError as error: + logger.exception("The request failed with status code: " + str(error.code)) + raise + + def allowSelfSignedHttps(self, allowed: bool) -> None: + """ + Allows self-signed HTTPS requests. 
+ + Args: + allowed (bool): Whether to allow self-signed HTTPS requests. + + """ + + # bypass the server certificate verification on client side + if ( + allowed + and not os.environ.get("PYTHONHTTPSVERIFY", "") + and getattr(ssl, "_create_unverified_context", None) + ): + ssl._create_default_https_context = ssl._create_unverified_context + else: + ssl._create_default_https_context = ssl.create_default_context + + def generate_embedding(self, chunk: str) -> list[float]: + """ + Generates the embedding for a given chunk of text. + + Args: + chunk (str): The input text. + + Returns: + list[float]: The generated embedding. + + """ + self.allowSelfSignedHttps( + True + ) # this line is needed if you use self-signed certificate in your scoring service. + + headers, body = self.prepare_request(chunk) + + result = self.make_request(body, headers) + + return result diff --git a/rag_experiment_accelerator/embedding/factory.py b/rag_experiment_accelerator/embedding/factory.py index 92b17b18..96a52c8a 100644 --- a/rag_experiment_accelerator/embedding/factory.py +++ b/rag_experiment_accelerator/embedding/factory.py @@ -1,5 +1,8 @@ from rag_experiment_accelerator.embedding.aoai_embedding_model import AOAIEmbeddingModel from rag_experiment_accelerator.embedding.st_embedding_model import STEmbeddingModel +from rag_experiment_accelerator.embedding.custom_embedding_model import ( + CustomEmbeddingModel, +) def create_embedding_model(model_type: str, **kwargs): @@ -8,6 +11,8 @@ def create_embedding_model(model_type: str, **kwargs): return AOAIEmbeddingModel(**kwargs) case "sentence-transformer": return STEmbeddingModel(**kwargs) + case "custom-embedding": + return CustomEmbeddingModel(**kwargs) case _: raise ValueError( f"Invalid embedding type: {model_type}. 
Must be one of ['azure', 'sentence-transformer', 'custom-embedding']" diff --git a/rag_experiment_accelerator/embedding/tests/test_custom_embedding_model.py b/rag_experiment_accelerator/embedding/tests/test_custom_embedding_model.py new file mode 100644 index 00000000..33d49d54 --- /dev/null +++ b/rag_experiment_accelerator/embedding/tests/test_custom_embedding_model.py @@ -0,0 +1,98 @@ +from unittest.mock import patch, MagicMock +import json +import urllib +from rag_experiment_accelerator.embedding.custom_embedding_model import ( + CustomEmbeddingModel, +) +import ssl + + +def test_can_set_embedding_dimension(): + environment = MagicMock() + model = CustomEmbeddingModel("custom-embedding-model", environment, 123) + assert model.dimension == 123 + + +def test_prepare_request_success(): + environment = MagicMock() + environment.azure_model_api_key = "api_key" + model = CustomEmbeddingModel("custom-embedding-deployment", environment) + + body = {"text": "Hello world"} + headers, prepared_body = model.prepare_request(body) + + expected_headers = { + "Content-Type": "application/json", + "Authorization": "Bearer api_key", + "azureml-model-deployment": "custom-embedding-deployment", + } + expected_body = str.encode(json.dumps({"input": body})) + + assert headers == expected_headers + assert prepared_body == expected_body + + +@patch("urllib.request.urlopen") +def test_make_request_success(mock_urlopen): + environment = MagicMock() + environment.azure_model_api_endpoint = "http://fake-endpoint" + model = CustomEmbeddingModel("custom-embedding-model", environment) + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps([0.1, 0.2, 0.3]).encode("utf-8") + mock_urlopen.return_value = mock_response + + headers = {"Content-Type": "application/json"} + body = b'{"input": {"text": "Hello world"}}' + + result = model.make_request(body, headers) + assert result == [0.1, 0.2, 0.3] + + +@patch("urllib.request.urlopen") +def test_make_request_http_error(mock_urlopen): + 
environment = MagicMock() + environment.azure_model_api_endpoint = "http://fake-endpoint" + model = CustomEmbeddingModel("custom-embedding-model", environment) + + mock_urlopen.side_effect = urllib.error.HTTPError( + url=None, code=500, msg="Internal Server Error", hdrs=None, fp=None + ) + + headers = {"Content-Type": "application/json"} + body = b'{"input": {"text": "Hello world"}}' + + try: + model.make_request(body, headers) + raise AssertionError("make_request did not raise HTTPError") + except urllib.error.HTTPError as e: + assert e.code == 500 + + +@patch( + "rag_experiment_accelerator.embedding.custom_embedding_model.CustomEmbeddingModel.make_request" +) +def test_generate_embedding_success(mock_make_request): + environment = MagicMock() + model = CustomEmbeddingModel("custom-embedding-model", environment) + + mock_make_request.return_value = [0.1, 0.2, 0.3] + + result = model.generate_embedding("Hello world") + assert result == [0.1, 0.2, 0.3] + + +def test_allow_self_signed_http_true(): + environment = MagicMock() + model = CustomEmbeddingModel("custom-embedding-model", environment) + + model.allowSelfSignedHttps(True) + assert ssl._create_default_https_context == ssl._create_unverified_context + + +def test_allow_self_signed_http_false(): + environment = MagicMock() + model = CustomEmbeddingModel("custom-embedding-model", environment) + + model.allowSelfSignedHttps(False) + assert ssl._create_default_https_context == ssl.create_default_context