|
4 | 4 |
|
5 | 5 | import json |
6 | 6 | import os |
| 7 | +import traceback |
7 | 8 | from dataclasses import fields |
8 | 9 | from typing import Dict, Union |
9 | 10 |
|
|
23 | 24 | from ads.aqua.constants import UNKNOWN |
24 | 25 | from ads.common import oci_client as oc |
25 | 26 | from ads.common.auth import default_signer |
26 | | -from ads.common.utils import extract_region |
| 27 | +from ads.common.utils import extract_region, is_path_exists |
27 | 28 | from ads.config import ( |
28 | 29 | AQUA_TELEMETRY_BUCKET, |
29 | 30 | AQUA_TELEMETRY_BUCKET_NS, |
@@ -296,33 +297,46 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict: |
296 | 297 | raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.") |
297 | 298 |
|
298 | 299 | config = {} |
299 | | - artifact_path = get_artifact_path(oci_model.custom_metadata_list) |
| 300 | + # if the current model has a service model tag, then |
| 301 | + if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags: |
| 302 | + base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG] |
| 303 | + logger.info( |
| 304 | + f"Base model found for the model: {oci_model.id}. " |
| 305 | + f"Loading {config_file_name} for base model {base_model_ocid}." |
| 306 | + ) |
| 307 | + base_model = self.ds_client.get_model(base_model_ocid).data |
| 308 | + artifact_path = get_artifact_path(base_model.custom_metadata_list) |
| 309 | + else: |
| 310 | + logger.info(f"Loading {config_file_name} for model {oci_model.id}...") |
| 311 | + artifact_path = get_artifact_path(oci_model.custom_metadata_list) |
| 312 | + |
300 | 313 | if not artifact_path: |
301 | 314 | logger.debug( |
302 | 315 | f"Failed to get artifact path from custom metadata for the model: {model_id}" |
303 | 316 | ) |
304 | 317 | return config |
305 | 318 |
|
306 | | - try: |
307 | | - config_path = f"{os.path.dirname(artifact_path)}/config/" |
308 | | - config = load_config( |
309 | | - config_path, |
310 | | - config_file_name=config_file_name, |
311 | | - ) |
312 | | - except Exception: |
313 | | - # todo: temp fix for issue related to config load for byom models, update logic to choose the right path |
| 319 | + config_path = f"{os.path.dirname(artifact_path)}/config/" |
| 320 | + if not is_path_exists(config_path): |
| 321 | + config_path = f"{artifact_path.rstrip('/')}/config/" |
| 322 | + |
| 323 | + config_file_path = f"{config_path}{config_file_name}" |
| 324 | + if is_path_exists(config_file_path): |
314 | 325 | try: |
315 | | - config_path = f"{artifact_path.rstrip('/')}/config/" |
316 | 326 | config = load_config( |
317 | 327 | config_path, |
318 | 328 | config_file_name=config_file_name, |
319 | 329 | ) |
320 | 330 | except Exception: |
321 | | - pass |
| 331 | + logger.debug( |
| 332 | + f"Error loading the {config_file_name} at path {config_path}.\n" |
| 333 | + f"{traceback.format_exc()}" |
| 334 | + ) |
322 | 335 |
|
323 | 336 | if not config: |
324 | | - logger.error( |
325 | | - f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set." |
| 337 | + logger.debug( |
| 338 | + f"{config_file_name} is not available for the model: {model_id}. " |
| 339 | + f"Check if the custom metadata has the artifact path set." |
326 | 340 | ) |
327 | 341 | return config |
328 | 342 |
|
|
0 commit comments