11#!/usr/bin/env python
2- # Copyright (c) 2024 Oracle and/or its affiliates.
2+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
5+ import json
56import os
7+ import traceback
68from dataclasses import fields
79from typing import Dict , Union
810
2224from ads .aqua .constants import UNKNOWN
2325from ads .common import oci_client as oc
2426from ads .common .auth import default_signer
25- from ads .common .utils import extract_region
27+ from ads .common .utils import extract_region , is_path_exists
2628from ads .config import (
2729 AQUA_TELEMETRY_BUCKET ,
2830 AQUA_TELEMETRY_BUCKET_NS ,
@@ -135,6 +137,8 @@ def create_model_version_set(
135137 description : str = None ,
136138 compartment_id : str = None ,
137139 project_id : str = None ,
140+ freeform_tags : dict = None ,
141+ defined_tags : dict = None ,
138142 ** kwargs ,
139143 ) -> tuple :
140144 """Creates ModelVersionSet from given ID or Name.
@@ -153,7 +157,10 @@ def create_model_version_set(
153157 Project OCID.
154158 tag: (str, optional)
155159 calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION
156-
160+ freeform_tags: (dict, optional)
161+ Freeform tags for the model version set
162+ defined_tags: (dict, optional)
163+ Defined tags for the model version set
157164 Returns
158165 -------
159166 tuple: (model_version_set_id, model_version_set_name)
@@ -182,13 +189,15 @@ def create_model_version_set(
182189 mvs_freeform_tags = {
183190 tag : tag ,
184191 }
192+ mvs_freeform_tags = {** mvs_freeform_tags , ** (freeform_tags or {})}
185193 model_version_set = (
186194 ModelVersionSet ()
187195 .with_compartment_id (compartment_id )
188196 .with_project_id (project_id )
189197 .with_name (model_version_set_name )
190198 .with_description (description )
191199 .with_freeform_tags (** mvs_freeform_tags )
200+ .with_defined_tags (** (defined_tags or {}))
192201 # TODO: decide what parameters will be needed
193202 # when refactor eval to use this method, we need to pass tag here.
194203 .create (** kwargs )
@@ -288,33 +297,46 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
288297 raise AquaRuntimeError (f"Target model { oci_model .id } is not Aqua model." )
289298
290299 config = {}
291- artifact_path = get_artifact_path (oci_model .custom_metadata_list )
300+ # if the current model has a service model tag, then
301+ if Tags .AQUA_SERVICE_MODEL_TAG in oci_model .freeform_tags :
302+ base_model_ocid = oci_model .freeform_tags [Tags .AQUA_SERVICE_MODEL_TAG ]
303+ logger .info (
304+ f"Base model found for the model: { oci_model .id } . "
305+ f"Loading { config_file_name } for base model { base_model_ocid } ."
306+ )
307+ base_model = self .ds_client .get_model (base_model_ocid ).data
308+ artifact_path = get_artifact_path (base_model .custom_metadata_list )
309+ else :
310+ logger .info (f"Loading { config_file_name } for model { oci_model .id } ..." )
311+ artifact_path = get_artifact_path (oci_model .custom_metadata_list )
312+
292313 if not artifact_path :
293- logger .error (
314+ logger .debug (
294315 f"Failed to get artifact path from custom metadata for the model: { model_id } "
295316 )
296317 return config
297318
298- try :
299- config_path = f"{ os .path .dirname (artifact_path )} /config/"
300- config = load_config (
301- config_path ,
302- config_file_name = config_file_name ,
303- )
304- except Exception :
305- # todo: temp fix for issue related to config load for byom models, update logic to choose the right path
319+ config_path = f"{ os .path .dirname (artifact_path )} /config/"
320+ if not is_path_exists (config_path ):
321+ config_path = f"{ artifact_path .rstrip ('/' )} /config/"
322+
323+ config_file_path = f"{ config_path } { config_file_name } "
324+ if is_path_exists (config_file_path ):
306325 try :
307- config_path = f"{ artifact_path .rstrip ('/' )} /config/"
308326 config = load_config (
309327 config_path ,
310328 config_file_name = config_file_name ,
311329 )
312330 except Exception :
313- pass
331+ logger .debug (
332+ f"Error loading the { config_file_name } at path { config_path } .\n "
333+ f"{ traceback .format_exc ()} "
334+ )
314335
315336 if not config :
316- logger .error (
317- f"{ config_file_name } is not available for the model: { model_id } . Check if the custom metadata has the artifact path set."
337+ logger .debug (
338+ f"{ config_file_name } is not available for the model: { model_id } . "
339+ f"Check if the custom metadata has the artifact path set."
318340 )
319341 return config
320342
@@ -340,7 +362,9 @@ def build_cli(self) -> str:
340362 """
341363 cmd = f"ads aqua { self ._command } "
342364 params = [
343- f"--{ field .name } { getattr (self ,field .name )} "
365+ f"--{ field .name } { json .dumps (getattr (self , field .name ))} "
366+ if isinstance (getattr (self , field .name ), dict )
367+ else f"--{ field .name } { getattr (self , field .name )} "
344368 for field in fields (self .__class__ )
345369 if getattr (self , field .name ) is not None
346370 ]
0 commit comments