Skip to content

Commit 300aa17

Browse files
committed
Adds shapes method to the OciDataScienceModelDeployment
1 parent ba605ee commit 300aa17

File tree

2 files changed

+95
-26
lines changed

2 files changed

+95
-26
lines changed

ads/aqua/common/utils.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,10 +1229,10 @@ def load_gpu_shapes_index(
12291229
auth: Optional[Dict[str, Any]] = None,
12301230
) -> GPUShapesIndex:
12311231
"""
1232-
Load the GPU shapes index, preferring the OS bucket copy over the local one.
1232+
Load the GPU shapes index, merging based on freshness.
12331233
1234-
Attempts to read `gpu_shapes_index.json` from OCI Object Storage first;
1235-
if that succeeds, those entries will override the local defaults.
1234+
Compares last-modified timestamps of local and remote files,
1235+
merging the shapes from the fresher file on top of the older one.
12361236
12371237
Parameters
12381238
----------
@@ -1253,7 +1253,9 @@ def load_gpu_shapes_index(
12531253
file_name = "gpu_shapes_index.json"
12541254

12551255
# Try remote load
1256-
remote_data: Dict[str, Any] = {}
1256+
local_data, remote_data = {}, {}
1257+
local_mtime, remote_mtime = None, None
1258+
12571259
if CONDA_BUCKET_NS:
12581260
try:
12591261
auth = auth or authutil.default_signer()
@@ -1263,8 +1265,24 @@ def load_gpu_shapes_index(
12631265
logger.debug(
12641266
"Loading GPU shapes index from Object Storage: %s", storage_path
12651267
)
1266-
with fsspec.open(storage_path, mode="r", **auth) as f:
1268+
1269+
fs = fsspec.filesystem("oci", **auth)
1270+
with fs.open(storage_path, mode="r") as f:
12671271
remote_data = json.load(f)
1272+
1273+
remote_info = fs.info(storage_path)
1274+
remote_mtime_str = remote_info.get("timeModified", None)
1275+
if remote_mtime_str:
1276+
# Convert OCI timestamp (e.g., 'Mon, 04 Aug 2025 06:37:13 GMT') to epoch time
1277+
remote_mtime = datetime.strptime(
1278+
remote_mtime_str, "%a, %d %b %Y %H:%M:%S %Z"
1279+
).timestamp()
1280+
1281+
logger.debug(
1282+
"Remote GPU shapes last-modified time: %s",
1283+
datetime.fromtimestamp(remote_mtime).strftime("%Y-%m-%d %H:%M:%S"),
1284+
)
1285+
12681286
logger.debug(
12691287
"Loaded %d shapes from Object Storage",
12701288
len(remote_data.get("shapes", {})),
@@ -1273,12 +1291,19 @@ def load_gpu_shapes_index(
12731291
logger.debug("Remote load failed (%s); falling back to local", ex)
12741292

12751293
# Load local copy
1276-
local_data: Dict[str, Any] = {}
12771294
local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name)
12781295
try:
12791296
logger.debug("Loading GPU shapes index from local file: %s", local_path)
12801297
with open(local_path) as f:
12811298
local_data = json.load(f)
1299+
1300+
local_mtime = os.path.getmtime(local_path)
1301+
1302+
logger.debug(
1303+
"Local GPU shapes last-modified time: %s",
1304+
datetime.fromtimestamp(local_mtime).strftime("%Y-%m-%d %H:%M:%S"),
1305+
)
1306+
12821307
logger.debug(
12831308
"Loaded %d shapes from local file", len(local_data.get("shapes", {}))
12841309
)
@@ -1288,7 +1313,24 @@ def load_gpu_shapes_index(
12881313
# Merge: remote shapes override local
12891314
local_shapes = local_data.get("shapes", {})
12901315
remote_shapes = remote_data.get("shapes", {})
1291-
merged_shapes = {**local_shapes, **remote_shapes}
1316+
merged_shapes = {}
1317+
1318+
if local_mtime and remote_mtime:
1319+
if remote_mtime >= local_mtime:
1320+
logger.debug("Remote data is fresher or equal; merging remote over local.")
1321+
merged_shapes = {**local_shapes, **remote_shapes}
1322+
else:
1323+
logger.debug("Local data is fresher; merging local over remote.")
1324+
merged_shapes = {**remote_shapes, **local_shapes}
1325+
elif remote_shapes:
1326+
logger.debug("Only remote shapes available.")
1327+
merged_shapes = remote_shapes
1328+
elif local_shapes:
1329+
logger.debug("Only local shapes available.")
1330+
merged_shapes = local_shapes
1331+
else:
1332+
logger.error("No GPU shapes data found in either source.")
1333+
merged_shapes = {}
12921334

12931335
return GPUShapesIndex(shapes=merged_shapes)
12941336

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

4-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
from functools import wraps
86
import logging
9-
from typing import Callable, List
10-
from ads.common.oci_datascience import OCIDataScienceMixin
11-
from ads.common.work_request import DataScienceWorkRequest
12-
from ads.config import PROJECT_OCID
13-
from ads.model.deployment.common.utils import OCIClientManager, State
14-
import oci
7+
from functools import wraps
8+
from typing import Callable, List, Optional
159

10+
import oci
1611
from oci.data_science.models import (
1712
CreateModelDeploymentDetails,
13+
ModelDeploymentShapeSummary,
1814
UpdateModelDeploymentDetails,
1915
)
2016

17+
from ads.common.oci_datascience import OCIDataScienceMixin
18+
from ads.common.work_request import DataScienceWorkRequest
19+
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
20+
from ads.model.deployment.common.utils import OCIClientManager, State
21+
2122
DEFAULT_WAIT_TIME = 1200
2223
DEFAULT_POLL_INTERVAL = 10
2324
ALLOWED_STATUS = [
@@ -185,14 +186,13 @@ def activate(
185186
self.id,
186187
)
187188

188-
189189
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
190190
if wait_for_completion:
191191
try:
192192
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
193193
progress_bar_description="Activating model deployment",
194-
max_wait_time=max_wait_time,
195-
poll_interval=poll_interval
194+
max_wait_time=max_wait_time,
195+
poll_interval=poll_interval,
196196
)
197197
except Exception as e:
198198
logger.error(
@@ -239,8 +239,8 @@ def create(
239239
try:
240240
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
241241
progress_bar_description="Creating model deployment",
242-
max_wait_time=max_wait_time,
243-
poll_interval=poll_interval
242+
max_wait_time=max_wait_time,
243+
poll_interval=poll_interval,
244244
)
245245
except Exception as e:
246246
logger.error("Error while trying to create model deployment: " + str(e))
@@ -290,8 +290,8 @@ def deactivate(
290290
try:
291291
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
292292
progress_bar_description="Deactivating model deployment",
293-
max_wait_time=max_wait_time,
294-
poll_interval=poll_interval
293+
max_wait_time=max_wait_time,
294+
poll_interval=poll_interval,
295295
)
296296
except Exception as e:
297297
logger.error(
@@ -351,14 +351,14 @@ def delete(
351351
response = self.client.delete_model_deployment(
352352
self.id,
353353
)
354-
354+
355355
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
356356
if wait_for_completion:
357357
try:
358358
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
359359
progress_bar_description="Deleting model deployment",
360-
max_wait_time=max_wait_time,
361-
poll_interval=poll_interval
360+
max_wait_time=max_wait_time,
361+
poll_interval=poll_interval,
362362
)
363363
except Exception as e:
364364
logger.error("Error while trying to delete model deployment: " + str(e))
@@ -493,3 +493,30 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment":
493493
An instance of `OCIDataScienceModelDeployment`.
494494
"""
495495
return super().from_ocid(model_deployment_id)
496+
497+
@classmethod
498+
def shapes(
499+
cls,
500+
compartment_id: Optional[str] = None,
501+
**kwargs,
502+
) -> List[ModelDeploymentShapeSummary]:
503+
"""
504+
Retrieves all available model deployment shapes in the given compartment.
505+
506+
This method uses OCI's pagination utility to fetch all pages of model
507+
deployment shape summaries available in the specified compartment.
508+
509+
Args:
510+
compartment_id (Optional[str]): The OCID of the compartment. If not provided,
511+
the default COMPARTMENT_ID extracted form env variables is used.
512+
**kwargs: Additional keyword arguments to pass to the list_model_deployments call.
513+
514+
Returns:
515+
List[ModelDeploymentShapeSummary]: A list of all model deployment shape summaries.
516+
"""
517+
client = cls().client
518+
compartment_id = compartment_id or COMPARTMENT_OCID
519+
520+
return oci.pagination.list_call_get_all_results(
521+
client.list_model_deployment_shapes, compartment_id, **kwargs
522+
).data

0 commit comments

Comments
 (0)