From 00091a7d4962b8adc654704cb644ea43298dc5bc Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Wed, 25 Jun 2025 22:02:49 -0400 Subject: [PATCH 1/3] Integrated aqua with model group. --- ads/aqua/model/constants.py | 1 + ads/aqua/model/entities.py | 15 +++ ads/aqua/model/model.py | 107 +++++++++++-------- ads/aqua/modeldeployment/deployment.py | 56 +++++----- ads/model/model_metadata.py | 2 +- tests/unitary/with_extras/aqua/test_model.py | 42 +++++--- 6 files changed, 132 insertions(+), 91 deletions(-) diff --git a/ads/aqua/model/constants.py b/ads/aqua/model/constants.py index 194245fe4..ce3e3f51d 100644 --- a/ads/aqua/model/constants.py +++ b/ads/aqua/model/constants.py @@ -20,6 +20,7 @@ class ModelCustomMetadataFields(ExtendedEnum): DEPLOYMENT_CONTAINER_URI = "deployment-container-uri" MULTIMODEL_GROUP_COUNT = "model_group_count" MULTIMODEL_METADATA = "multi_model_metadata" + MODEL_GROUP_CONFIG = "OCI_MODEL_GROUP_CUSTOM_METADATA" class ModelTask(ExtendedEnum): diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py index 0bbcdfb0b..72c374f74 100644 --- a/ads/aqua/model/entities.py +++ b/ads/aqua/model/entities.py @@ -383,3 +383,18 @@ class ModelFileDescription(Serializable): class Config: alias_generator = to_camel extra = "allow" + + +class MemberModel(Serializable): + """Describes the member model of a model group. + + Attributes: + model_id (str): The id of member model. + inference_key (str): The inference key of member model. + """ + + model_id: str = Field(..., description="The id of member model.") + inference_key: str = Field(None, description="The inference key of member model.") + + class Config: + extra = "allow" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 2b5d7108f..925535139 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -79,7 +79,7 @@ AquaModelReadme, AquaModelSummary, ImportModelDetails, - ModelFileDescription, + MemberModel, ModelValidationResult, ) from ads.aqua.model.enums import MultiModelSupportedTaskType @@ -102,6 +102,7 @@ ) from ads.model import DataScienceModel from ads.model.common.utils import MetadataArtifactPathType +from ads.model.datascience_model_group import DataScienceModelGroup from ads.model.model_metadata import ( MetadataCustomCategory, ModelCustomMetadata, @@ -235,13 +236,15 @@ def create( def create_multi( self, models: List[AquaMultiModelRef], + create_deployment_details, + model_config_summary, project_id: Optional[str] = None, compartment_id: Optional[str] = None, freeform_tags: Optional[Dict] = None, defined_tags: Optional[Dict] = None, source_models: Optional[Dict[str, DataScienceModel]] = None, **kwargs, # noqa: ARG002 - ) -> DataScienceModel: + ) -> DataScienceModelGroup: """ Creates a multi-model grouping using the provided model list. @@ -249,6 +252,11 @@ def create_multi( ---------- models : List[AquaMultiModelRef] List of AquaMultiModelRef instances for creating a multi-model group. + create_deployment_details : CreateModelDeploymentDetails + An instance of CreateModelDeploymentDetails containing all required and optional + fields for creating a model deployment via Aqua. + model_config_summary : ModelConfigSummary + Summary Model Deployment configuration for the group of models. project_id : Optional[str] The project ID for the multi-model group. compartment_id : Optional[str] @@ -264,8 +272,8 @@ def create_multi( Returns ------- - DataScienceModel - Instance of DataScienceModel object. + DataScienceModelGroup + Instance of DataScienceModelGroup object. """ if not models: @@ -274,7 +282,6 @@ def create_multi( ) display_name_list = [] - model_file_description_list: List[ModelFileDescription] = [] model_custom_metadata = ModelCustomMetadata() service_inference_containers = ( @@ -337,11 +344,6 @@ def create_multi( "Please register the model with a file description." ) - # Track model file description in a validated structure - model_file_description_list.append( - ModelFileDescription(**model_file_description) - ) - # Ensure base model has a valid artifact if not source_model.artifact: logger.error( @@ -396,11 +398,6 @@ def create_multi( "Please register the model with a file description." ) - # Track model file description in a validated structure - model_file_description_list.append( - ModelFileDescription(**ft_model_file_description) - ) - # Extract fine-tuned model path _, fine_tune_path = extract_fine_tune_artifacts_path( fine_tune_source_model @@ -481,6 +478,22 @@ def create_multi( description="Number of models in the group.", category="Other", ) + model_custom_metadata.add( + key=ModelCustomMetadataFields.MODEL_GROUP_CONFIG, + value=self._build_model_group_config( + create_deployment_details=create_deployment_details, + model_config_summary=model_config_summary, + deployment_container=deployment_container, + ), + description="Configs required to deploy multi models.", + category="Other", + ) + model_custom_metadata.add( + key=ModelCustomMetadataFields.MULTIMODEL_METADATA, + value=json.dumps([model.model_dump() for model in models]), + description="Metadata to store user's multi model input.", + category="Other", + ) # Combine tags. The `Tags.AQUA_TAG` has been excluded, because we don't want to show # the models created for multi-model purpose in the AQUA models list. @@ -491,8 +504,8 @@ def create_multi( } # Create multi-model group - custom_model = ( - DataScienceModel() + custom_model_group = ( + DataScienceModelGroup() .with_compartment_id(compartment_id) .with_project_id(project_id) .with_display_name(model_group_display_name) @@ -500,37 +513,15 @@ def create_multi( .with_freeform_tags(**tags) .with_defined_tags(**(defined_tags or {})) .with_custom_metadata_list(model_custom_metadata) + .with_member_models( + [MemberModel(model_id=model.model_id).model_dump() for model in models] + ) ) - # Update multi model file description to attach artifacts - custom_model.with_model_file_description( - json_dict=ModelFileDescription( - models=[ - models - for model_file_description in model_file_description_list - for models in model_file_description.models - ] - ).model_dump(by_alias=True) - ) - - # Finalize creation - custom_model.create(model_by_reference=True) + custom_model_group.create() logger.info( - f"Aqua Model '{custom_model.id}' created with models: {', '.join(display_name_list)}." - ) - - # Create custom metadata for multi model metadata - custom_model.create_custom_metadata_artifact( - metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA, - artifact_path_or_content=json.dumps( - [model.model_dump() for model in models] - ).encode(), - path_type=MetadataArtifactPathType.CONTENT, - ) - - logger.debug( - f"Multi model metadata uploaded for Aqua model: {custom_model.id}." + f"Aqua Model Group'{custom_model_group.id}' created with models: {', '.join(display_name_list)}." ) # Track telemetry event @@ -540,7 +531,33 @@ def create_multi( detail=combined_models, ) - return custom_model + return custom_model_group + + def _build_model_group_config( + self, + create_deployment_details, + model_config_summary, + deployment_container: str, + ) -> str: + """Builds model group config required to deploy multi models.""" + container_type_key = ( + create_deployment_details.container_family or deployment_container + ) + container_config = self.get_container_config_item(container_type_key) + container_spec = container_config.spec if container_config else UNKNOWN + + container_params = container_spec.cli_param if container_spec else UNKNOWN + + from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig + + multi_model_config = ModelGroupConfig.from_create_model_deployment_details( + create_deployment_details, + model_config_summary, + container_type_key, + container_params, + ) + + return multi_model_config.model_dump_json() @telemetry(entry_point="plugin=model&action=get", name="aqua") def get(self, model_id: str) -> "AquaModel": diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 82402af4b..ee4b721a6 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -7,7 +7,7 @@ import shlex import threading from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from cachetools import TTLCache, cached from oci.data_science.models import ModelDeploymentShapeSummary @@ -40,7 +40,6 @@ AQUA_MODEL_TYPE_CUSTOM, AQUA_MODEL_TYPE_MULTI, AQUA_MODEL_TYPE_SERVICE, - AQUA_MULTI_MODEL_CONFIG, MODEL_BY_REFERENCE_OSS_PATH_KEY, MODEL_NAME_DELIMITER, UNKNOWN_DICT, @@ -65,7 +64,6 @@ ConfigValidationError, CreateModelDeploymentDetails, ) -from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig from ads.common.object_storage_details import ObjectStorageDetails from ads.common.utils import UNKNOWN, get_log_links from ads.common.work_request import DataScienceWorkRequest @@ -79,6 +77,7 @@ PROJECT_OCID, ) from ads.model.datascience_model import DataScienceModel +from ads.model.datascience_model_group import DataScienceModelGroup from ads.model.deployment import ( ModelDeployment, ModelDeploymentContainerRuntime, @@ -325,8 +324,10 @@ def create( f"Multi models ({source_model_ids}) provided. Delegating to multi model creation method." ) - aqua_model = model_app.create_multi( + aqua_model_group = model_app.create_multi( models=create_deployment_details.models, + create_deployment_details=create_deployment_details, + model_config_summary=model_config_summary, compartment_id=compartment_id, project_id=project_id, freeform_tags=freeform_tags, @@ -334,8 +335,7 @@ def create( source_models=source_models, ) return self._create_multi( - aqua_model=aqua_model, - model_config_summary=model_config_summary, + aqua_model_group=aqua_model_group, create_deployment_details=create_deployment_details, container_config=container_config, ) @@ -562,8 +562,7 @@ def _create( def _create_multi( self, - aqua_model: DataScienceModel, - model_config_summary: ModelDeploymentConfigSummary, + aqua_model_group: DataScienceModelGroup, create_deployment_details: CreateModelDeploymentDetails, container_config: AquaContainerConfig, ) -> AquaDeployment: @@ -571,15 +570,14 @@ def _create_multi( Parameters ---------- - model_config_summary : model_config_summary - Summary Model Deployment configuration for the group of models. - aqua_model : DataScienceModel - An instance of Aqua data science model. + aqua_model_group : DataScienceModelGroup + An instance of Aqua data science model group. create_deployment_details : CreateModelDeploymentDetails An instance of CreateModelDeploymentDetails containing all required and optional fields for creating a model deployment via Aqua. container_config: Dict Container config dictionary. + Returns ------- AquaDeployment @@ -589,23 +587,12 @@ def _create_multi( env_var = {**(create_deployment_details.env_var or UNKNOWN_DICT)} container_type_key = self._get_container_type_key( - model=aqua_model, + model=aqua_model_group, container_family=create_deployment_details.container_family, ) container_config = self.get_container_config_item(container_type_key) container_spec = container_config.spec if container_config else UNKNOWN - container_params = container_spec.cli_param if container_spec else UNKNOWN - - multi_model_config = ModelGroupConfig.from_create_model_deployment_details( - create_deployment_details, - model_config_summary, - container_type_key, - container_params, - ) - - env_var.update({AQUA_MULTI_MODEL_CONFIG: multi_model_config.model_dump_json()}) - env_vars = container_spec.env_vars if container_spec else [] for env in env_vars: if isinstance(env, dict): @@ -614,7 +601,7 @@ def _create_multi( if key not in env_var: env_var.update(env) - logger.info(f"Env vars used for deploying {aqua_model.id} : {env_var}.") + logger.info(f"Env vars used for deploying {aqua_model_group.id} : {env_var}.") container_image_uri = ( create_deployment_details.container_image_uri @@ -627,7 +614,7 @@ def _create_multi( container_spec.health_check_port if container_spec else None ) tags = { - Tags.AQUA_MODEL_ID_TAG: aqua_model.id, + Tags.AQUA_MODEL_ID_TAG: aqua_model_group.id, Tags.MULTIMODEL_TYPE_TAG: "true", Tags.AQUA_TAG: "active", **(create_deployment_details.freeform_tags or UNKNOWN_DICT), @@ -637,7 +624,7 @@ def _create_multi( aqua_deployment = self._create_deployment( create_deployment_details=create_deployment_details, - aqua_model_id=aqua_model.id, + aqua_model_id=aqua_model_group.id, model_name=model_name, model_type=AQUA_MODEL_TYPE_MULTI, container_image_uri=container_image_uri, @@ -794,7 +781,9 @@ def _create_deployment( ) @staticmethod - def _get_container_type_key(model: DataScienceModel, container_family: str) -> str: + def _get_container_type_key( + model: Union[DataScienceModel, DataScienceModelGroup], container_family: str + ) -> str: container_type_key = UNKNOWN if container_family: container_type_key = container_family @@ -970,7 +959,12 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": f"Invalid multi model deployment {model_deployment_id}." f"Make sure the {Tags.AQUA_MODEL_ID_TAG} tag is added to the deployment." ) - aqua_model = DataScienceModel.from_id(aqua_model_id) + + if "datasciencemodelgroup" in aqua_model_id: + aqua_model = DataScienceModelGroup.from_id(aqua_model_id) + else: + aqua_model = DataScienceModel.from_id(aqua_model_id) + custom_metadata_list = aqua_model.custom_metadata_list multi_model_metadata_value = custom_metadata_list.get( ModelCustomMetadataFields.MULTIMODEL_METADATA, @@ -984,7 +978,9 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": f"Ensure that the required custom metadata `{ModelCustomMetadataFields.MULTIMODEL_METADATA}` is added to the AQUA multi-model `{aqua_model.display_name}` ({aqua_model.id})." ) multi_model_metadata = json.loads( - aqua_model.dsc_model.get_custom_metadata_artifact( + multi_model_metadata_value + if isinstance(aqua_model, DataScienceModelGroup) + else aqua_model.dsc_model.get_custom_metadata_artifact( metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA ).decode("utf-8") ) diff --git a/ads/model/model_metadata.py b/ads/model/model_metadata.py index 6b73b17f5..f0428ec9c 100644 --- a/ads/model/model_metadata.py +++ b/ads/model/model_metadata.py @@ -37,7 +37,7 @@ logger = logging.getLogger("ADS") METADATA_SIZE_LIMIT = 32000 -METADATA_VALUE_LENGTH_LIMIT = 255 +METADATA_VALUE_LENGTH_LIMIT = 16000 METADATA_DESCRIPTION_LENGTH_LIMIT = 255 _METADATA_EMPTY_VALUE = "NA" CURRENT_WORKING_DIR = "." diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index 61d9b849d..5f96cb9a4 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -42,6 +42,7 @@ from ads.aqua.model.enums import MultiModelSupportedTaskType from ads.common.object_storage_details import ObjectStorageDetails from ads.model.datascience_model import DataScienceModel +from ads.model.datascience_model_group import DataScienceModelGroup from ads.model.model_metadata import ( ModelCustomMetadata, ModelProvenanceMetadata, @@ -457,14 +458,12 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create): ) assert model.provenance_metadata.training_id == "test_training_id" - @patch.object(DataScienceModel, "create_custom_metadata_artifact") - @patch.object(DataScienceModel, "create") + @patch.object(DataScienceModelGroup, "create") @patch.object(AquaApp, "get_container_config") def test_create_multimodel( self, mock_get_container_config, - mock_create, - mock_create_custom_metadata_artifact, + mock_create_group, ): mock_get_container_config.return_value = get_container_config() mock_model = MagicMock() @@ -482,6 +481,8 @@ def test_create_multimodel( ) mock_model.custom_metadata_list = custom_metadata_list + mock_create_deployment_details = MagicMock() + mock_model_config_summary = MagicMock() model_info_1 = AquaMultiModelRef( model_id="test_model_id_1", @@ -503,8 +504,10 @@ def test_create_multimodel( } with pytest.raises(AquaValueError): - model = self.app.create_multi( + model_group = self.app.create_multi( models=[model_info_1, model_info_2], + create_deployment_details=mock_create_deployment_details, + model_config_summary=mock_model_config_summary, project_id="test_project_id", compartment_id="test_compartment_id", source_models=model_details, @@ -513,8 +516,10 @@ def test_create_multimodel( mock_model.freeform_tags["aqua_service_model"] = TestDataset.SERVICE_MODEL_ID with pytest.raises(AquaValueError): - model = self.app.create_multi( + model_group = self.app.create_multi( models=[model_info_1, model_info_2], + create_deployment_details=mock_create_deployment_details, + model_config_summary=mock_model_config_summary, project_id="test_project_id", compartment_id="test_compartment_id", source_models=model_details, @@ -523,8 +528,10 @@ def test_create_multimodel( mock_model.freeform_tags["task"] = "text-generation" with pytest.raises(AquaValueError): - model = self.app.create_multi( + model_group = self.app.create_multi( models=[model_info_1, model_info_2], + create_deployment_details=mock_create_deployment_details, + model_config_summary=mock_model_config_summary, project_id="test_project_id", compartment_id="test_compartment_id", source_models=model_details, @@ -541,8 +548,10 @@ def test_create_multimodel( model_info_1.model_task = "invalid_task" with pytest.raises(AquaValueError): - model = self.app.create_multi( + model_group = self.app.create_multi( models=[model_info_1, model_info_2], + create_deployment_details=mock_create_deployment_details, + model_config_summary=mock_model_config_summary, project_id="test_project_id", compartment_id="test_compartment_id", source_models=model_details, @@ -552,8 +561,10 @@ def test_create_multimodel( model_info_1.model_task = None mock_model.freeform_tags["task"] = "unsupported_task" with pytest.raises(AquaValueError): - model = self.app.create_multi( + model_group = self.app.create_multi( models=[model_info_1, model_info_2], + create_deployment_details=mock_create_deployment_details, + model_config_summary=mock_model_config_summary, project_id="test_project_id", compartment_id="test_compartment_id", source_models=model_details, @@ -580,22 +591,23 @@ def test_create_multimodel( model_details[model_info_3.model_id] = mock_model # will create a multi-model group - model = self.app.create_multi( + model_group = self.app.create_multi( models=[model_info_1, model_info_2, model_info_3], + create_deployment_details=mock_create_deployment_details, + model_config_summary=mock_model_config_summary, project_id="test_project_id", compartment_id="test_compartment_id", source_models=model_details, ) - mock_create.assert_called_with(model_by_reference=True) + mock_create_group.assert_called() mock_model.compartment_id = TestDataset.SERVICE_COMPARTMENT_ID - mock_create.return_value = mock_model - assert model.freeform_tags == {"aqua_multimodel": "true"} - assert model.custom_metadata_list.get("model_group_count").value == "3" + assert model_group.freeform_tags == {"aqua_multimodel": "true"} + assert model_group.custom_metadata_list.get("model_group_count").value == "3" assert ( - model.custom_metadata_list.get("deployment-container").value + model_group.custom_metadata_list.get("deployment-container").value == "odsc-vllm-serving" ) From 46d23f0cc02bedcf91674205a58e55e9b96d6437 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Fri, 27 Jun 2025 18:08:07 -0400 Subject: [PATCH 2/3] Updated pr. --- ads/aqua/modeldeployment/deployment.py | 5 +- ads/aqua/modeldeployment/entities.py | 21 +- ads/model/deployment/model_deployment.py | 124 ++++--- .../deployment/model_deployment_runtime.py | 40 ++- .../test_model_deployment_v2.py | 339 +----------------- 5 files changed, 152 insertions(+), 377 deletions(-) diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index ee4b721a6..3e3f06788 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -719,11 +719,14 @@ def _create_deployment( .with_health_check_port(health_check_port) .with_env(env_var) .with_deployment_mode(ModelDeploymentMode.HTTPS) - .with_model_uri(aqua_model_id) .with_region(self.region) .with_overwrite_existing_artifact(True) .with_remove_existing_artifact(True) ) + if "datasciencemodelgroup" in aqua_model_id: + container_runtime.with_model_group_id(aqua_model_id) + else: + container_runtime.with_model_uri(aqua_model_id) if cmd_var: container_runtime.with_cmd(cmd_var) diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py index ebce26dc8..fde05e666 100644 --- a/ads/aqua/modeldeployment/entities.py +++ b/ads/aqua/modeldeployment/entities.py @@ -147,13 +147,25 @@ def from_oci_model_deployment( AquaDeployment: The instance of the Aqua model deployment. """ - instance_configuration = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration + model_deployment_configuration_details = ( + oci_model_deployment.model_deployment_configuration_details + ) + if model_deployment_configuration_details.deployment_type == "SINGLE_MODEL": + instance_configuration = model_deployment_configuration_details.model_configuration_details.instance_configuration + instance_count = model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count + model_id = model_deployment_configuration_details.model_configuration_details.model_id + else: + instance_configuration = model_deployment_configuration_details.infrastructure_configuration_details.instance_configuration + instance_count = model_deployment_configuration_details.infrastructure_configuration_details.scaling_policy.instance_count + model_id = model_deployment_configuration_details.model_group_configuration_details.model_group_id + instance_shape_config_details = ( instance_configuration.model_deployment_instance_shape_config_details ) - instance_count = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count - environment_variables = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.environment_variables - cmd = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.cmd + environment_variables = model_deployment_configuration_details.environment_configuration_details.environment_variables + cmd = ( + model_deployment_configuration_details.environment_configuration_details.cmd + ) shape_info = ShapeInfo( instance_shape=instance_configuration.instance_shape_name, instance_count=instance_count, @@ -168,7 +180,6 @@ def from_oci_model_deployment( else None ), ) - model_id = oci_model_deployment._model_deployment_configuration_details.model_configuration_details.model_id tags = {} tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT) tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT) diff --git a/ads/model/deployment/model_deployment.py b/ads/model/deployment/model_deployment.py index 56a70c112..21e083499 100644 --- a/ads/model/deployment/model_deployment.py +++ b/ads/model/deployment/model_deployment.py @@ -1,22 +1,27 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2021, 2023 Oracle and/or its affiliates. +# Copyright (c) 2021, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import collections import copy import datetime -import oci -import warnings import time -from typing import Dict, List, Union, Any +import warnings +from typing import Any, Dict, List, Union +import oci import oci.loggingsearch -from ads.common import auth as authutil import pandas as pd -from ads.model.serde.model_input import JsonModelInputSERDE +from oci.data_science.models import ( + CreateModelDeploymentDetails, + LogDetails, + UpdateModelDeploymentDetails, +) + +from ads.common import auth as authutil +from ads.common import utils as ads_utils from ads.common.oci_logging import ( LOG_INTERVAL, LOG_RECORDS_LIMIT, @@ -30,10 +35,10 @@ from ads.model.deployment.common.utils import send_request from ads.model.deployment.model_deployment_infrastructure import ( DEFAULT_BANDWIDTH_MBPS, + DEFAULT_MEMORY_IN_GBS, + DEFAULT_OCPUS, DEFAULT_REPLICA, DEFAULT_SHAPE_NAME, - DEFAULT_OCPUS, - DEFAULT_MEMORY_IN_GBS, MODEL_DEPLOYMENT_INFRASTRUCTURE_TYPE, ModelDeploymentInfrastructure, ) @@ -45,18 +50,14 @@ ModelDeploymentRuntimeType, OCIModelDeploymentRuntimeType, ) +from ads.model.serde.model_input import JsonModelInputSERDE from ads.model.service.oci_datascience_model_deployment import ( OCIDataScienceModelDeployment, ) -from ads.common import utils as ads_utils + from .common import utils from .common.utils import State from .model_deployment_properties import ModelDeploymentProperties -from oci.data_science.models import ( - LogDetails, - CreateModelDeploymentDetails, - UpdateModelDeploymentDetails, -) DEFAULT_WAIT_TIME = 1200 DEFAULT_POLL_INTERVAL = 10 @@ -964,7 +965,9 @@ def predict( except oci.exceptions.ServiceError as ex: # When bandwidth exceeds the allocated value, TooManyRequests error (429) will be raised by oci backend. if ex.status == 429: - bandwidth_mbps = self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS + bandwidth_mbps = ( + self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS + ) utils.get_logger().warning( f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps." "To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024." @@ -1518,9 +1521,9 @@ def _build_model_deployment_details(self) -> CreateModelDeploymentDetails: self.infrastructure.CONST_CATEGORY_LOG_DETAILS: self._build_category_log_details(), } - return OCIDataScienceModelDeployment( - **create_model_deployment_details - ).to_oci_model(CreateModelDeploymentDetails) + return CreateModelDeploymentDetails( + **ads_utils.batch_convert_case(create_model_deployment_details, "snake") + ) def _update_model_deployment_details( self, **kwargs @@ -1545,9 +1548,10 @@ def _update_model_deployment_details( self.infrastructure.CONST_MODEL_DEPLOYMENT_CONFIG_DETAILS: self._build_model_deployment_configuration_details(), self.infrastructure.CONST_CATEGORY_LOG_DETAILS: self._build_category_log_details(), } - return OCIDataScienceModelDeployment( - **update_model_deployment_details - ).to_oci_model(UpdateModelDeploymentDetails) + + return UpdateModelDeploymentDetails( + **ads_utils.batch_convert_case(update_model_deployment_details, "snake") + ) def _update_spec(self, **kwargs) -> "ModelDeployment": """Updates model deployment specs from kwargs. @@ -1644,22 +1648,22 @@ def _build_model_deployment_configuration_details(self) -> Dict: } if infrastructure.subnet_id: - instance_configuration[ - infrastructure.CONST_SUBNET_ID - ] = infrastructure.subnet_id + instance_configuration[infrastructure.CONST_SUBNET_ID] = ( + infrastructure.subnet_id + ) if infrastructure.private_endpoint_id: if not hasattr( oci.data_science.models.InstanceConfiguration, "private_endpoint_id" ): # TODO: add oci version with private endpoint support. - raise EnvironmentError( + raise OSError( "Private endpoint is not supported in the current OCI SDK installed." ) - instance_configuration[ - infrastructure.CONST_PRIVATE_ENDPOINT_ID - ] = infrastructure.private_endpoint_id + instance_configuration[infrastructure.CONST_PRIVATE_ENDPOINT_ID] = ( + infrastructure.private_endpoint_id + ) scaling_policy = { infrastructure.CONST_POLICY_TYPE: "FIXED_SIZE", @@ -1667,13 +1671,13 @@ def _build_model_deployment_configuration_details(self) -> Dict: or DEFAULT_REPLICA, } - if not runtime.model_uri: + if not (runtime.model_uri or runtime.model_group_id): raise ValueError( - "Missing parameter model uri. Try reruning it after model uri is configured." + "Missing parameter model uri and model group id. Try reruning it after model or model group is configured." ) model_id = runtime.model_uri - if not model_id.startswith("ocid"): + if model_id and not model_id.startswith("ocid"): from ads.model.datascience_model import DataScienceModel dsc_model = DataScienceModel( @@ -1704,7 +1708,7 @@ def _build_model_deployment_configuration_details(self) -> Dict: oci.data_science.models, "ModelDeploymentEnvironmentConfigurationDetails", ): - raise EnvironmentError( + raise OSError( "Environment variable hasn't been supported in the current OCI SDK installed." ) @@ -1720,9 +1724,9 @@ def _build_model_deployment_configuration_details(self) -> Dict: and runtime.inference_server.upper() == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON ): - environment_variables[ - "CONTAINER_TYPE" - ] = MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON + environment_variables["CONTAINER_TYPE"] = ( + MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON + ) runtime.set_spec(runtime.CONST_ENV, environment_variables) environment_configuration_details = { runtime.CONST_ENVIRONMENT_CONFIG_TYPE: runtime.environment_config_type, @@ -1734,7 +1738,7 @@ def _build_model_deployment_configuration_details(self) -> Dict: oci.data_science.models, "OcirModelDeploymentEnvironmentConfigurationDetails", ): - raise EnvironmentError( + raise OSError( "Container runtime hasn't been supported in the current OCI SDK installed." ) environment_configuration_details["image"] = runtime.image @@ -1742,9 +1746,9 @@ def _build_model_deployment_configuration_details(self) -> Dict: environment_configuration_details["cmd"] = runtime.cmd environment_configuration_details["entrypoint"] = runtime.entrypoint environment_configuration_details["serverPort"] = runtime.server_port - environment_configuration_details[ - "healthCheckPort" - ] = runtime.health_check_port + environment_configuration_details["healthCheckPort"] = ( + runtime.health_check_port + ) model_deployment_configuration_details = { infrastructure.CONST_DEPLOYMENT_TYPE: "SINGLE_MODEL", @@ -1752,9 +1756,27 @@ def _build_model_deployment_configuration_details(self) -> Dict: runtime.CONST_ENVIRONMENT_CONFIG_DETAILS: environment_configuration_details, } + if runtime.model_group_id: + model_deployment_configuration_details[ + infrastructure.CONST_DEPLOYMENT_TYPE + ] = "MODEL_GROUP" + model_deployment_configuration_details["modelGroupConfigurationDetails"] = { + runtime.CONST_MODEL_GROUP_ID: runtime.model_group_id + } + model_deployment_configuration_details[ + "infrastructureConfigurationDetails" + ] = { + "infrastructureType": "INSTANCE_POOL", + infrastructure.CONST_BANDWIDTH_MBPS: infrastructure.bandwidth_mbps + or DEFAULT_BANDWIDTH_MBPS, + infrastructure.CONST_INSTANCE_CONFIG: instance_configuration, + infrastructure.CONST_SCALING_POLICY: scaling_policy, + } + model_configuration_details.pop(runtime.CONST_MODEL_ID) + if runtime.deployment_mode == ModelDeploymentMode.STREAM: if not hasattr(oci.data_science.models, "StreamConfigurationDetails"): - raise EnvironmentError( + raise OSError( "Model deployment mode hasn't been supported in the current OCI SDK installed." ) model_deployment_configuration_details[ @@ -1786,9 +1808,13 @@ def _build_category_log_details(self) -> Dict: logs = {} if ( - self.infrastructure.access_log and - self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None) - and self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_ID, None) + self.infrastructure.access_log + and self.infrastructure.access_log.get( + self.infrastructure.CONST_LOG_GROUP_ID, None + ) + and self.infrastructure.access_log.get( + self.infrastructure.CONST_LOG_ID, None + ) ): logs[self.infrastructure.CONST_ACCESS] = { self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.access_log.get( @@ -1799,9 +1825,13 @@ def _build_category_log_details(self) -> Dict: ), } if ( - self.infrastructure.predict_log and - self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None) - and self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_ID, None) + self.infrastructure.predict_log + and self.infrastructure.predict_log.get( + self.infrastructure.CONST_LOG_GROUP_ID, None + ) + and self.infrastructure.predict_log.get( + self.infrastructure.CONST_LOG_ID, None + ) ): logs[self.infrastructure.CONST_PREDICT] = { self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.predict_log.get( diff --git a/ads/model/deployment/model_deployment_runtime.py b/ads/model/deployment/model_deployment_runtime.py index 26e31f9cd..adfa48d1d 100644 --- a/ads/model/deployment/model_deployment_runtime.py +++ b/ads/model/deployment/model_deployment_runtime.py @@ -1,11 +1,11 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2023 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/from typing import Dict from typing import Dict, List + from ads.jobs.builders.base import Builder MODEL_DEPLOYMENT_RUNTIME_KIND = "runtime" @@ -41,6 +41,8 @@ class ModelDeploymentRuntime(Builder): The output stream ids of model deployment. model_uri: str The model uri of model deployment. + model_group_id: str + The model group id of model deployment. bucket_uri: str The OCI Object Storage URI where large size model artifacts will be copied to. auth: Dict @@ -66,6 +68,8 @@ class ModelDeploymentRuntime(Builder): Sets the output stream ids of model deployment with_model_uri(model_uri) Sets the model uri of model deployment + with_model_group_id(model_group_id) + Sets the model group id of model deployment with_bucket_uri(bucket_uri) Sets the bucket uri when uploading large size model. with_auth(auth) @@ -82,6 +86,7 @@ class ModelDeploymentRuntime(Builder): CONST_MODEL_ID = "modelId" CONST_MODEL_URI = "modelUri" + CONST_MODEL_GROUP_ID = "modelGroupId" CONST_ENV = "env" CONST_ENVIRONMENT_VARIABLES = "environmentVariables" CONST_ENVIRONMENT_CONFIG_TYPE = "environmentConfigurationType" @@ -103,6 +108,7 @@ class ModelDeploymentRuntime(Builder): CONST_OUTPUT_STREAM_IDS: "output_stream_ids", CONST_DEPLOYMENT_MODE: "deployment_mode", CONST_MODEL_URI: "model_uri", + CONST_MODEL_GROUP_ID: "model_group_id", CONST_BUCKET_URI: "bucket_uri", CONST_AUTH: "auth", CONST_REGION: "region", @@ -120,6 +126,9 @@ class ModelDeploymentRuntime(Builder): MODEL_CONFIG_DETAILS_PATH = ( "model_deployment_configuration_details.model_configuration_details" ) + MODEL_GROUP_CONFIG_DETAILS_PATH = ( + "model_deployment_configuration_details.model_group_configuration_details" + ) payload_attribute_map = { CONST_ENV: f"{ENVIRONMENT_CONFIG_DETAILS_PATH}.environment_variables", @@ -127,6 +136,7 @@ class ModelDeploymentRuntime(Builder): CONST_OUTPUT_STREAM_IDS: f"{STREAM_CONFIG_DETAILS_PATH}.output_stream_ids", CONST_DEPLOYMENT_MODE: "deployment_mode", CONST_MODEL_URI: f"{MODEL_CONFIG_DETAILS_PATH}.model_id", + CONST_MODEL_GROUP_ID: f"{MODEL_GROUP_CONFIG_DETAILS_PATH}.model_group_id", } def __init__(self, spec: Dict = None, **kwargs) -> None: @@ -278,6 +288,32 @@ def with_model_uri(self, model_uri: str) -> "ModelDeploymentRuntime": """ return self.set_spec(self.CONST_MODEL_URI, model_uri) + @property + def model_group_id(self) -> str: + """The model group id of model deployment. + + Returns + ------- + str + The model group id of model deployment. + """ + return self.get_spec(self.CONST_MODEL_GROUP_ID, None) + + def with_model_group_id(self, model_group_id: str) -> "ModelDeploymentRuntime": + """Sets the model group id of model deployment. + + Parameters + ---------- + model_group_id: str + The model group id of model deployment. + + Returns + ------- + ModelDeploymentRuntime + The ModelDeploymentRuntime instance (self). + """ + return self.set_spec(self.CONST_MODEL_GROUP_ID, model_group_id) + @property def bucket_uri(self) -> str: """The bucket uri of model. diff --git a/tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py b/tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py index 589c58d70..f86f1816a 100644 --- a/tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py +++ b/tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*-- -# Copyright (c) 2023 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import copy @@ -589,151 +589,6 @@ def test_build_category_log_details(self): }, } - @patch.object(DataScienceModel, "create") - def test_build_model_deployment_details(self, mock_create): - dsc_model = MagicMock() - dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx" - mock_create.return_value = dsc_model - model_deployment = self.initialize_model_deployment() - create_model_deployment_details = ( - model_deployment._build_model_deployment_details() - ) - - mock_create.assert_called() - - assert isinstance( - create_model_deployment_details, - CreateModelDeploymentDetails, - ) - assert ( - create_model_deployment_details.display_name - == model_deployment.display_name - ) - assert ( - create_model_deployment_details.description == model_deployment.description - ) - assert ( - create_model_deployment_details.freeform_tags - == model_deployment.freeform_tags - ) - assert ( - create_model_deployment_details.defined_tags - == model_deployment.defined_tags - ) - - category_log_details = create_model_deployment_details.category_log_details - assert isinstance(category_log_details, CategoryLogDetails) - assert ( - category_log_details.access.log_id - == model_deployment.infrastructure.access_log["logId"] - ) - assert ( - category_log_details.access.log_group_id - == model_deployment.infrastructure.access_log["logGroupId"] - ) - assert ( - category_log_details.predict.log_id - == model_deployment.infrastructure.predict_log["logId"] - ) - assert ( - category_log_details.predict.log_group_id - == model_deployment.infrastructure.predict_log["logGroupId"] - ) - - model_deployment_configuration_details = ( - create_model_deployment_details.model_deployment_configuration_details - ) - assert isinstance( - model_deployment_configuration_details, - SingleModelDeploymentConfigurationDetails, - ) - assert model_deployment_configuration_details.deployment_type == "SINGLE_MODEL" - - environment_configuration_details = ( - model_deployment_configuration_details.environment_configuration_details - ) - assert isinstance( - environment_configuration_details, - OcirModelDeploymentEnvironmentConfigurationDetails, - ) - assert ( - environment_configuration_details.environment_configuration_type - == "OCIR_CONTAINER" - ) - assert ( - environment_configuration_details.environment_variables - == model_deployment.runtime.env - ) - assert environment_configuration_details.cmd == model_deployment.runtime.cmd - assert environment_configuration_details.image == model_deployment.runtime.image - assert ( - environment_configuration_details.image_digest - == model_deployment.runtime.image_digest - ) - assert ( - environment_configuration_details.entrypoint - == model_deployment.runtime.entrypoint - ) - assert ( - environment_configuration_details.server_port - == model_deployment.runtime.server_port - ) - assert ( - environment_configuration_details.health_check_port - == model_deployment.runtime.health_check_port - ) - - model_configuration_details = ( - model_deployment_configuration_details.model_configuration_details - ) - assert isinstance( - model_configuration_details, - ModelConfigurationDetails, - ) - assert ( - model_configuration_details.bandwidth_mbps - == model_deployment.infrastructure.bandwidth_mbps - ) - assert ( - model_configuration_details.model_id == model_deployment.runtime.model_uri - ) - - instance_configuration = model_configuration_details.instance_configuration - assert isinstance(instance_configuration, InstanceConfiguration) - assert ( - instance_configuration.instance_shape_name - == model_deployment.infrastructure.shape_name - ) - assert ( - instance_configuration.model_deployment_instance_shape_config_details.ocpus - == model_deployment.infrastructure.shape_config_details["ocpus"] - ) - assert ( - instance_configuration.model_deployment_instance_shape_config_details.memory_in_gbs - == model_deployment.infrastructure.shape_config_details["memoryInGBs"] - ) - - scaling_policy = model_configuration_details.scaling_policy - assert isinstance(scaling_policy, FixedSizeScalingPolicy) - assert scaling_policy.policy_type == "FIXED_SIZE" - assert scaling_policy.instance_count == model_deployment.infrastructure.replica - - # stream_configuration_details = ( - # model_deployment_configuration_details.stream_configuration_details - # ) - # assert isinstance( - # stream_configuration_details, - # StreamConfigurationDetails, - # ) - # assert ( - # stream_configuration_details.input_stream_ids - # == model_deployment.runtime.input_stream_ids - # ) - # assert ( - # stream_configuration_details.output_stream_ids - # == model_deployment.runtime.output_stream_ids - # ) - def test_update_from_oci_model(self): model_deployment = self.initialize_model_deployment() model_deployment_from_oci = model_deployment._update_from_oci_model( @@ -882,151 +737,6 @@ def test_model_deployment_from_dict(self): assert new_model_deployment.to_dict() == model_deployment.to_dict() - @patch.object(DataScienceModel, "create") - def test_update_model_deployment_details(self, mock_create): - dsc_model = MagicMock() - dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx" - mock_create.return_value = dsc_model - model_deployment = self.initialize_model_deployment() - update_model_deployment_details = ( - model_deployment._update_model_deployment_details() - ) - - mock_create.assert_called() - - assert isinstance( - update_model_deployment_details, - UpdateModelDeploymentDetails, - ) - assert ( - update_model_deployment_details.display_name - == model_deployment.display_name - ) - assert ( - update_model_deployment_details.description == model_deployment.description - ) - assert ( - update_model_deployment_details.freeform_tags - == model_deployment.freeform_tags - ) - assert ( - update_model_deployment_details.defined_tags - == model_deployment.defined_tags - ) - - category_log_details = update_model_deployment_details.category_log_details - assert isinstance(category_log_details, UpdateCategoryLogDetails) - assert ( - category_log_details.access.log_id - == model_deployment.infrastructure.access_log["logId"] - ) - assert ( - category_log_details.access.log_group_id - == model_deployment.infrastructure.access_log["logGroupId"] - ) - assert ( - category_log_details.predict.log_id - == model_deployment.infrastructure.predict_log["logId"] - ) - assert ( - category_log_details.predict.log_group_id - == model_deployment.infrastructure.predict_log["logGroupId"] - ) - - model_deployment_configuration_details = ( - update_model_deployment_details.model_deployment_configuration_details - ) - assert isinstance( - model_deployment_configuration_details, - UpdateSingleModelDeploymentConfigurationDetails, - ) - assert model_deployment_configuration_details.deployment_type == "SINGLE_MODEL" - - environment_configuration_details = ( - model_deployment_configuration_details.environment_configuration_details - ) - assert isinstance( - environment_configuration_details, - UpdateOcirModelDeploymentEnvironmentConfigurationDetails, - ) - assert ( - environment_configuration_details.environment_configuration_type - == "OCIR_CONTAINER" - ) - assert ( - environment_configuration_details.environment_variables - == model_deployment.runtime.env - ) - assert environment_configuration_details.cmd == model_deployment.runtime.cmd - assert environment_configuration_details.image == model_deployment.runtime.image - assert ( - environment_configuration_details.image_digest - == model_deployment.runtime.image_digest - ) - assert ( - environment_configuration_details.entrypoint - == model_deployment.runtime.entrypoint - ) - assert ( - environment_configuration_details.server_port - == model_deployment.runtime.server_port - ) - assert ( - environment_configuration_details.health_check_port - == model_deployment.runtime.health_check_port - ) - - model_configuration_details = ( - model_deployment_configuration_details.model_configuration_details - ) - assert isinstance( - model_configuration_details, - UpdateModelConfigurationDetails, - ) - assert ( - model_configuration_details.bandwidth_mbps - == model_deployment.infrastructure.bandwidth_mbps - ) - assert ( - model_configuration_details.model_id == model_deployment.runtime.model_uri - ) - - instance_configuration = model_configuration_details.instance_configuration - assert isinstance(instance_configuration, InstanceConfiguration) - assert ( - instance_configuration.instance_shape_name - == model_deployment.infrastructure.shape_name - ) - assert ( - instance_configuration.model_deployment_instance_shape_config_details.ocpus - == model_deployment.infrastructure.shape_config_details["ocpus"] - ) - assert ( - instance_configuration.model_deployment_instance_shape_config_details.memory_in_gbs - == model_deployment.infrastructure.shape_config_details["memoryInGBs"] - ) - - scaling_policy = model_configuration_details.scaling_policy - assert isinstance(scaling_policy, FixedSizeScalingPolicy) - assert scaling_policy.policy_type == "FIXED_SIZE" - assert scaling_policy.instance_count == model_deployment.infrastructure.replica - - # stream_configuration_details = ( - # model_deployment_configuration_details.stream_configuration_details - # ) - # assert isinstance( - # stream_configuration_details, - # UpdateStreamConfigurationDetails, - # ) - # assert ( - # stream_configuration_details.input_stream_ids - # == model_deployment.runtime.input_stream_ids - # ) - # assert ( - # stream_configuration_details.output_stream_ids - # == model_deployment.runtime.output_stream_ids - # ) - @patch.object( ModelDeploymentInfrastructure, "_load_default_properties", return_value={} ) @@ -1127,9 +837,7 @@ def test_from_ocid(self, mock_from_ocid): "create_model_deployment", ) @patch.object(DataScienceModel, "create") - def test_deploy( - self, mock_create, mock_create_model_deployment, mock_sync - ): + def test_deploy(self, mock_create, mock_create_model_deployment, mock_sync): dsc_model = MagicMock() dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx" mock_create.return_value = dsc_model @@ -1346,44 +1054,35 @@ def test_update_spec(self): model_deployment = self.initialize_model_deployment() model_deployment._update_spec( display_name="test_updated_name", - freeform_tags={"test_updated_key":"test_updated_value"}, - access_log={ - "log_id": "test_updated_access_log_id" - }, - predict_log={ - "log_group_id": "test_updated_predict_log_group_id" - }, - shape_config_details={ - "ocpus": 100, - "memoryInGBs": 200 - }, + freeform_tags={"test_updated_key": "test_updated_value"}, + access_log={"log_id": "test_updated_access_log_id"}, + predict_log={"log_group_id": "test_updated_predict_log_group_id"}, + shape_config_details={"ocpus": 100, "memoryInGBs": 200}, replica=20, image="test_updated_image", - env={ - "test_updated_env_key":"test_updated_env_value" - } + env={"test_updated_env_key": "test_updated_env_value"}, ) assert model_deployment.display_name == "test_updated_name" assert model_deployment.freeform_tags == { - "test_updated_key":"test_updated_value" + "test_updated_key": "test_updated_value" } assert model_deployment.infrastructure.access_log == { "logId": "test_updated_access_log_id", - "logGroupId": "fakeid.loggroup.oc1.iad.xxx" + "logGroupId": "fakeid.loggroup.oc1.iad.xxx", } assert model_deployment.infrastructure.predict_log == { "logId": "fakeid.log.oc1.iad.xxx", - "logGroupId": "test_updated_predict_log_group_id" + "logGroupId": "test_updated_predict_log_group_id", } assert model_deployment.infrastructure.shape_config_details == { "ocpus": 100, - "memoryInGBs": 200 + "memoryInGBs": 200, } assert model_deployment.infrastructure.replica == 20 assert model_deployment.runtime.image == "test_updated_image" assert model_deployment.runtime.env == { - "test_updated_env_key":"test_updated_env_value" + "test_updated_env_key": "test_updated_env_value" } @patch.object(OCIDataScienceMixin, "sync") @@ -1393,18 +1092,14 @@ def test_update_spec(self): ) @patch.object(DataScienceModel, "create") def test_model_deployment_with_large_size_artifact( - self, - mock_create, - mock_create_model_deployment, - mock_sync + self, mock_create, mock_create_model_deployment, mock_sync ): dsc_model = MagicMock() dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx" mock_create.return_value = dsc_model model_deployment = self.initialize_model_deployment() ( - model_deployment.runtime - .with_auth({"test_key":"test_value"}) + model_deployment.runtime.with_auth({"test_key": "test_value"}) .with_region("test_region") .with_overwrite_existing_artifact(True) .with_remove_existing_artifact(True) @@ -1425,18 +1120,18 @@ def test_model_deployment_with_large_size_artifact( mock_create_model_deployment.return_value = response model_deployment = self.initialize_model_deployment() model_deployment.set_spec(model_deployment.CONST_ID, "test_model_deployment_id") - + create_model_deployment_details = ( model_deployment._build_model_deployment_details() ) model_deployment.deploy(wait_for_completion=False) mock_create.assert_called_with( bucket_uri="test_bucket_uri", - auth={"test_key":"test_value"}, + auth={"test_key": "test_value"}, region="test_region", overwrite_existing_artifact=True, remove_existing_artifact=True, - timeout=100 + timeout=100, ) mock_create_model_deployment.assert_called_with(create_model_deployment_details) mock_sync.assert_called() From 6c4da24fbe7e0689e20cbdcf6c33a957e444b5e3 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Tue, 1 Jul 2025 12:52:34 -0400 Subject: [PATCH 3/3] Updated pr. --- ads/aqua/model/constants.py | 1 - ads/aqua/model/model.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ads/aqua/model/constants.py b/ads/aqua/model/constants.py index ce3e3f51d..194245fe4 100644 --- a/ads/aqua/model/constants.py +++ b/ads/aqua/model/constants.py @@ -20,7 +20,6 @@ class ModelCustomMetadataFields(ExtendedEnum): DEPLOYMENT_CONTAINER_URI = "deployment-container-uri" MULTIMODEL_GROUP_COUNT = "model_group_count" MULTIMODEL_METADATA = "multi_model_metadata" - MODEL_GROUP_CONFIG = "OCI_MODEL_GROUP_CUSTOM_METADATA" class ModelTask(ExtendedEnum): diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 925535139..959fcc8e2 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -52,6 +52,7 @@ AQUA_MODEL_ARTIFACT_FILE, AQUA_MODEL_TOKENIZER_CONFIG, AQUA_MODEL_TYPE_CUSTOM, + AQUA_MULTI_MODEL_CONFIG, HF_METADATA_FOLDER, LICENSE, MODEL_BY_REFERENCE_OSS_PATH_KEY, @@ -479,7 +480,7 @@ def create_multi( category="Other", ) model_custom_metadata.add( - key=ModelCustomMetadataFields.MODEL_GROUP_CONFIG, + key=AQUA_MULTI_MODEL_CONFIG, value=self._build_model_group_config( create_deployment_details=create_deployment_details, model_config_summary=model_config_summary,