From eb2e1d40e75482b9ca275636c52bfe7d9eaa007a Mon Sep 17 00:00:00 2001
From: Andrew Bernat
Date: Sun, 9 Mar 2025 21:51:34 -0700
Subject: [PATCH] Add tests for function list_bedrock_models.

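Refactor list_bedrock_models() to take its Bedrock client as a
parameter instead of using the module-level bedrock_client directly,
so the function can be tested without real AWS credentials. A small
BedrockClientInterface ABC wraps the two boto3 calls the function
makes (list_inference_profiles and list_foundation_models);
production code passes BedrockClient(bedrock_client), and the new
tests substitute FakeBedrockClient or ThrowingBedrockClient.

The dependency-injection seam looks roughly like this (a sketch
using only names from this diff, not additional API):

    # Production wiring (bedrock.py): wrap the boto3 client.
    bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))

    # Test wiring (list_bedrock_models_test.py): use the in-memory fake.
    models = list_bedrock_models(FakeBedrockClient(
        model("model-id", "model-name", stream_supported=True)
    ))

The CI workflow drops the Node.js setup in favor of Python 3.12 and
runs the new suite with python -m unittest.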
"ACTIVE") -> dict: + if input_modalities is None: + input_modalities = ["TEXT"] + if output_modalities is None: + output_modalities = ["TEXT"] + if inference_types is None: + inference_types = ["ON_DEMAND"] + return { + "modelArn": "arn:model:" + model_id, + "modelId": model_id, + "modelName": model_name, + "providerName": "anthropic", + "inputModalities":input_modalities, + "outputModalities": output_modalities, + "responseStreamingSupported": stream_supported, + "customizationsSupported": ["FINE_TUNING"], + "inferenceTypesSupported": inference_types, + "modelLifecycle": { + "status": status + } + } + +def _filter_inference_profiles(inference_profiles: list[dict], profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"], max_results: int = 100): + return [p for p in inference_profiles if p.get("type") == profile_type][:max_results] + +def _filter_models( + models: list[dict], + provider_name: str | None, + customization_type: Literal["FINE_TUNING","CONTINUED_PRE_TRAINING","DISTILLATION"] | None, + output_modality: Literal["TEXT","IMAGE","EMBEDDING"] | None, + inference_type: Literal["ON_DEMAND","PROVISIONED"] | None): + return [m for m in models if + (provider_name is None or m.get("providerName") == provider_name) and + (output_modality is None or output_modality in m.get("outputModalities")) and + (customization_type is None or customization_type in m.get("customizationsSupported")) and + (inference_type is None or inference_type in m.get("inferenceTypesSupported")) + ] + +class ThrowingBedrockClient(BedrockClientInterface): + def list_inference_profiles(self, **kwargs) -> dict: + raise Exception("throwing bedrock client always throws exception") + def list_foundation_models(self, **kwargs) -> dict: + raise Exception("throwing bedrock client always throws exception") + +class FakeBedrockClient(BedrockClientInterface): + def __init__(self, *args): + self.inference_profiles = [p for p in args if p.get("inferenceProfileId", "") != ""] + self.models = [m for m in args if m.get("modelId", "") != ""] + + unexpected = [u for u in args if (u.get("modelId", "") == "" and u.get("inferenceProfileId", "") == "")] + if len(unexpected) > 0: + raise Exception("expected a model or a profile") + + def list_inference_profiles(self, **kwargs) -> dict: + return { + "inferenceProfileSummaries": _filter_inference_profiles( + self.inference_profiles, + profile_type=kwargs["typeEquals"], + max_results=kwargs.get("maxResults", 100) + ) + } + + def list_foundation_models(self, **kwargs) -> dict: + return { + "modelSummaries": _filter_models( + self.models, + provider_name=kwargs.get("byProvider", None), + customization_type=kwargs.get("byCustomizationType", None), + output_modality=kwargs.get("byOutputModality", None), + inference_type=kwargs.get("byInferenceType", None) + ) + } \ No newline at end of file --- .github/workflows/aws-genai-cicd-suite.yml | 17 +- src/api/models/bedrock.py | 31 +++- tests/__init__.py | 0 tests/list_bedrock_models_test.py | 179 +++++++++++++++++++++ 4 files changed, 213 insertions(+), 14 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/list_bedrock_models_test.py diff --git a/.github/workflows/aws-genai-cicd-suite.yml b/.github/workflows/aws-genai-cicd-suite.yml index b16c41b8..656154f3 100644 --- a/.github/workflows/aws-genai-cicd-suite.yml +++ b/.github/workflows/aws-genai-cicd-suite.yml @@ -25,16 +25,17 @@ jobs: - name: Checkout code uses: actions/checkout@v3 - - name: Set up Node.js - uses: actions/setup-node@v3 + - name: Set up Python + uses: 
+        uses: actions/setup-python@v2
         with:
-          node-version: '20'
+          python-version: 3.12 # Adjust the Python version as needed
 
-      - name: Install dependencies @actions/core and @actions/github
-        run: |
-          npm install @actions/core
-          npm install @actions/github
-        shell: bash
+      - name: Install dependencies
+        run: pip install -r requirements.txt
+
+      - name: Test
+        run: python -m unittest
+        working-directory: ./tests
 
       # check if required dependencies @actions/core and @actions/github are installed
       - name: Check if required dependencies are installed
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index be3fab28..39ed9ae6 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -3,7 +3,7 @@
 import logging
 import re
 import time
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import AsyncIterable, Iterable, Literal
 
 import boto3
@@ -73,8 +73,27 @@ def get_inference_region_prefix():
 
 ENCODER = tiktoken.get_encoding("cl100k_base")
 
+class BedrockClientInterface(ABC):
+    @abstractmethod
+    def list_inference_profiles(self, **kwargs) -> dict:
+        pass
 
-def list_bedrock_models() -> dict:
+    @abstractmethod
+    def list_foundation_models(self, **kwargs) -> dict:
+        pass
+
+class BedrockClient(BedrockClientInterface):
+    def __init__(self, client):
+        self.bedrock_client = client
+
+    def list_inference_profiles(self, **kwargs) -> dict:
+        return self.bedrock_client.list_inference_profiles(**kwargs)
+
+    def list_foundation_models(self, **kwargs) -> dict:
+        return self.bedrock_client.list_foundation_models(**kwargs)
+
+
+def list_bedrock_models(client: BedrockClientInterface) -> dict:
     """Automatically getting a list of supported models.
 
     Returns a model list combines:
@@ -86,11 +105,11 @@ def list_bedrock_models() -> dict:
     profile_list = []
     if ENABLE_CROSS_REGION_INFERENCE:
         # List system defined inference profile IDs
-        response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
+        response = client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
         profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]
 
     # List foundation models, only cares about text outputs here.
-    response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
+    response = client.list_foundation_models(byOutputModality="TEXT")
 
     for model in response["modelSummaries"]:
         model_id = model.get("modelId", "N/A")
@@ -123,14 +142,14 @@ def list_bedrock_models() -> dict:
 
 
 # Initialize the model list.
-bedrock_model_list = list_bedrock_models()
+bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))
 
 
 class BedrockModel(BaseChatModel):
     def list_models(self) -> list[str]:
         """Always refresh the latest model list"""
         global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
+        bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))
         return list(bedrock_model_list.keys())
 
     def validate(self, chat_request: ChatRequest):
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/list_bedrock_models_test.py b/tests/list_bedrock_models_test.py
new file mode 100644
index 00000000..262fe20f
--- /dev/null
+++ b/tests/list_bedrock_models_test.py
@@ -0,0 +1,179 @@
+from typing import Literal
+
+from src.api.models.bedrock import list_bedrock_models, BedrockClientInterface
+
+def test_default_model():
+    client = FakeBedrockClient(
+        inference_profile("p1-id", "p1", "SYSTEM_DEFINED"),
+        inference_profile("p2-id", "p2", "APPLICATION"),
+        inference_profile("p3-id", "p3", "SYSTEM_DEFINED"),
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "anthropic.claude-3-sonnet-20240229-v1:0": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def test_one_model():
+    client = FakeBedrockClient(
+        model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT", "IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def test_two_models():
+    client = FakeBedrockClient(
+        model("model-id-1", "model-name-1", stream_supported=True, input_modalities=["TEXT", "IMAGE"]),
+        model("model-id-2", "model-name-2", stream_supported=True, input_modalities=["IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id-1": {
+            "modalities": ["TEXT", "IMAGE"]
+        },
+        "model-id-2": {
+            "modalities": ["IMAGE"]
+        }
+    }
+
+def test_filter_models():
+    client = FakeBedrockClient(
+        model("model-id", "model-name-1", stream_supported=True, input_modalities=["TEXT"], status="LEGACY"),
+        model("model-id-no-stream", "model-name-2", stream_supported=False, input_modalities=["TEXT", "IMAGE"]),
+        model("model-id-not-active", "model-name-3", stream_supported=True, status="DISABLED"),
+        model("model-id-not-text-output", "model-name-4", stream_supported=True, output_modalities=["IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT"]
+        }
+    }
+
+def test_one_inference_profile():
+    client = FakeBedrockClient(
+        inference_profile("us.model-id", "p1", "SYSTEM_DEFINED"),
+        model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT"]
+        },
+        "us.model-id": {
+            "modalities": ["TEXT"]
+        }
+    }
+
+def test_default_model_on_throw():
+    client = ThrowingBedrockClient()
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "anthropic.claude-3-sonnet-20240229-v1:0": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def inference_profile(profile_id: str, name: str, profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"]):
+    return {
+        "inferenceProfileName": name,
+        "inferenceProfileId": profile_id,
+        "type": profile_type
+    }
+
+def model(
+        model_id: str,
+        model_name: str,
+        input_modalities: list[str] = None,
+        output_modalities: list[str] = None,
+        stream_supported: bool = False,
+        inference_types: list[str] = None,
+        status: str = "ACTIVE") -> dict:
"ACTIVE") -> dict: + if input_modalities is None: + input_modalities = ["TEXT"] + if output_modalities is None: + output_modalities = ["TEXT"] + if inference_types is None: + inference_types = ["ON_DEMAND"] + return { + "modelArn": "arn:model:" + model_id, + "modelId": model_id, + "modelName": model_name, + "providerName": "anthropic", + "inputModalities":input_modalities, + "outputModalities": output_modalities, + "responseStreamingSupported": stream_supported, + "customizationsSupported": ["FINE_TUNING"], + "inferenceTypesSupported": inference_types, + "modelLifecycle": { + "status": status + } + } + +def _filter_inference_profiles(inference_profiles: list[dict], profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"], max_results: int = 100): + return [p for p in inference_profiles if p.get("type") == profile_type][:max_results] + +def _filter_models( + models: list[dict], + provider_name: str | None, + customization_type: Literal["FINE_TUNING","CONTINUED_PRE_TRAINING","DISTILLATION"] | None, + output_modality: Literal["TEXT","IMAGE","EMBEDDING"] | None, + inference_type: Literal["ON_DEMAND","PROVISIONED"] | None): + return [m for m in models if + (provider_name is None or m.get("providerName") == provider_name) and + (output_modality is None or output_modality in m.get("outputModalities")) and + (customization_type is None or customization_type in m.get("customizationsSupported")) and + (inference_type is None or inference_type in m.get("inferenceTypesSupported")) + ] + +class ThrowingBedrockClient(BedrockClientInterface): + def list_inference_profiles(self, **kwargs) -> dict: + raise Exception("throwing bedrock client always throws exception") + def list_foundation_models(self, **kwargs) -> dict: + raise Exception("throwing bedrock client always throws exception") + +class FakeBedrockClient(BedrockClientInterface): + def __init__(self, *args): + self.inference_profiles = [p for p in args if p.get("inferenceProfileId", "") != ""] + self.models = [m for m in args if m.get("modelId", "") != ""] + + unexpected = [u for u in args if (u.get("modelId", "") == "" and u.get("inferenceProfileId", "") == "")] + if len(unexpected) > 0: + raise Exception("expected a model or a profile") + + def list_inference_profiles(self, **kwargs) -> dict: + return { + "inferenceProfileSummaries": _filter_inference_profiles( + self.inference_profiles, + profile_type=kwargs["typeEquals"], + max_results=kwargs.get("maxResults", 100) + ) + } + + def list_foundation_models(self, **kwargs) -> dict: + return { + "modelSummaries": _filter_models( + self.models, + provider_name=kwargs.get("byProvider", None), + customization_type=kwargs.get("byCustomizationType", None), + output_modality=kwargs.get("byOutputModality", None), + inference_type=kwargs.get("byInferenceType", None) + ) + } \ No newline at end of file