Skip to content

[AQUA] Integrate aqua to use model group #1214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: feature/model_group
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ads/aqua/model/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class ModelCustomMetadataFields(ExtendedEnum):
DEPLOYMENT_CONTAINER_URI = "deployment-container-uri"
MULTIMODEL_GROUP_COUNT = "model_group_count"
MULTIMODEL_METADATA = "multi_model_metadata"
MODEL_GROUP_CONFIG = "OCI_MODEL_GROUP_CUSTOM_METADATA"


class ModelTask(ExtendedEnum):
Expand Down
15 changes: 15 additions & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,18 @@ class ModelFileDescription(Serializable):
class Config:
alias_generator = to_camel
extra = "allow"


class MemberModel(Serializable):
    """Represents one member model entry of a model group.

    Unknown extra fields are accepted (``extra = "allow"``).

    Attributes:
        model_id (str): The id of member model. Required.
        inference_key (str): The inference key of member model.
            Defaults to ``None`` when not supplied.
    """

    # Required: no default is provided for the member model id.
    model_id: str = Field(..., description="The id of member model.")
    # NOTE(review): annotated as `str` while the default is None —
    # consider `Optional[str]` once the typing import is confirmed.
    inference_key: str = Field(
        default=None,
        description="The inference key of member model.",
    )

    class Config:
        extra = "allow"
107 changes: 62 additions & 45 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
AquaModelReadme,
AquaModelSummary,
ImportModelDetails,
ModelFileDescription,
MemberModel,
ModelValidationResult,
)
from ads.aqua.model.enums import MultiModelSupportedTaskType
Expand All @@ -102,6 +102,7 @@
)
from ads.model import DataScienceModel
from ads.model.common.utils import MetadataArtifactPathType
from ads.model.datascience_model_group import DataScienceModelGroup
from ads.model.model_metadata import (
MetadataCustomCategory,
ModelCustomMetadata,
Expand Down Expand Up @@ -235,20 +236,27 @@ def create(
def create_multi(
self,
models: List[AquaMultiModelRef],
create_deployment_details,
model_config_summary,
project_id: Optional[str] = None,
compartment_id: Optional[str] = None,
freeform_tags: Optional[Dict] = None,
defined_tags: Optional[Dict] = None,
source_models: Optional[Dict[str, DataScienceModel]] = None,
**kwargs, # noqa: ARG002
) -> DataScienceModel:
) -> DataScienceModelGroup:
"""
Creates a multi-model grouping using the provided model list.

Parameters
----------
models : List[AquaMultiModelRef]
List of AquaMultiModelRef instances for creating a multi-model group.
create_deployment_details : CreateModelDeploymentDetails
An instance of CreateModelDeploymentDetails containing all required and optional
fields for creating a model deployment via Aqua.
model_config_summary : ModelConfigSummary
Summary Model Deployment configuration for the group of models.
project_id : Optional[str]
The project ID for the multi-model group.
compartment_id : Optional[str]
Expand All @@ -264,8 +272,8 @@ def create_multi(

Returns
-------
DataScienceModel
Instance of DataScienceModel object.
DataScienceModelGroup
Instance of DataScienceModelGroup object.
"""

if not models:
Expand All @@ -274,7 +282,6 @@ def create_multi(
)

display_name_list = []
model_file_description_list: List[ModelFileDescription] = []
model_custom_metadata = ModelCustomMetadata()

service_inference_containers = (
Expand Down Expand Up @@ -337,11 +344,6 @@ def create_multi(
"Please register the model with a file description."
)

# Track model file description in a validated structure
model_file_description_list.append(
ModelFileDescription(**model_file_description)
)

# Ensure base model has a valid artifact
if not source_model.artifact:
logger.error(
Expand Down Expand Up @@ -396,11 +398,6 @@ def create_multi(
"Please register the model with a file description."
)

# Track model file description in a validated structure
model_file_description_list.append(
ModelFileDescription(**ft_model_file_description)
)

# Extract fine-tuned model path
_, fine_tune_path = extract_fine_tune_artifacts_path(
fine_tune_source_model
Expand Down Expand Up @@ -481,6 +478,22 @@ def create_multi(
description="Number of models in the group.",
category="Other",
)
model_custom_metadata.add(
key=ModelCustomMetadataFields.MODEL_GROUP_CONFIG,
value=self._build_model_group_config(
create_deployment_details=create_deployment_details,
model_config_summary=model_config_summary,
deployment_container=deployment_container,
),
description="Configs required to deploy multi models.",
category="Other",
)
model_custom_metadata.add(
key=ModelCustomMetadataFields.MULTIMODEL_METADATA,
value=json.dumps([model.model_dump() for model in models]),
description="Metadata to store user's multi model input.",
category="Other",
)

# Combine tags. The `Tags.AQUA_TAG` has been excluded, because we don't want to show
# the models created for multi-model purpose in the AQUA models list.
Expand All @@ -491,46 +504,24 @@ def create_multi(
}

# Create multi-model group
custom_model = (
DataScienceModel()
custom_model_group = (
DataScienceModelGroup()
.with_compartment_id(compartment_id)
.with_project_id(project_id)
.with_display_name(model_group_display_name)
.with_description(model_group_description)
.with_freeform_tags(**tags)
.with_defined_tags(**(defined_tags or {}))
.with_custom_metadata_list(model_custom_metadata)
.with_member_models(
[MemberModel(model_id=model.model_id).model_dump() for model in models]
)
)

# Update multi model file description to attach artifacts
custom_model.with_model_file_description(
json_dict=ModelFileDescription(
models=[
models
for model_file_description in model_file_description_list
for models in model_file_description.models
]
).model_dump(by_alias=True)
)

# Finalize creation
custom_model.create(model_by_reference=True)
custom_model_group.create()

logger.info(
f"Aqua Model '{custom_model.id}' created with models: {', '.join(display_name_list)}."
)

# Create custom metadata for multi model metadata
custom_model.create_custom_metadata_artifact(
metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA,
artifact_path_or_content=json.dumps(
[model.model_dump() for model in models]
).encode(),
path_type=MetadataArtifactPathType.CONTENT,
)

logger.debug(
f"Multi model metadata uploaded for Aqua model: {custom_model.id}."
f"Aqua Model Group'{custom_model_group.id}' created with models: {', '.join(display_name_list)}."
)

# Track telemetry event
Expand All @@ -540,7 +531,33 @@ def create_multi(
detail=combined_models,
)

return custom_model
return custom_model_group

def _build_model_group_config(
    self,
    create_deployment_details,
    model_config_summary,
    deployment_container: str,
) -> str:
    """Builds the JSON model group config required to deploy multi models.

    Parameters
    ----------
    create_deployment_details
        Deployment details object. Its ``container_family`` attribute, when
        truthy, is used as the container key.
    model_config_summary
        Passed through unchanged to
        ``ModelGroupConfig.from_create_model_deployment_details``.
    deployment_container : str
        Container key used when ``container_family`` is falsy.

    Returns
    -------
    str
        The value returned by ``model_dump_json()`` on the built config.
    """
    # Prefer the explicitly requested container family; otherwise fall back
    # to the deployment container passed in by the caller.
    container_key = create_deployment_details.container_family
    if not container_key:
        container_key = deployment_container

    config_item = self.get_container_config_item(container_key)
    if config_item:
        spec = config_item.spec
    else:
        spec = UNKNOWN

    # CLI params are only read when a (truthy) spec is available.
    if spec:
        cli_params = spec.cli_param
    else:
        cli_params = UNKNOWN

    # Function-scope import kept in place.
    # NOTE(review): presumably avoids a circular import — confirm.
    from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig

    group_config = ModelGroupConfig.from_create_model_deployment_details(
        create_deployment_details,
        model_config_summary,
        container_key,
        cli_params,
    )
    return group_config.model_dump_json()

@telemetry(entry_point="plugin=model&action=get", name="aqua")
def get(self, model_id: str) -> "AquaModel":
Expand Down
Loading