Skip to content

Commit 96f5543

Browse files
committed
addressed comments
1 parent 300aa17 commit 96f5543

File tree

9 files changed

+217
-192
lines changed

9 files changed

+217
-192
lines changed

ads/aqua/cli.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from ads.aqua.finetuning import AquaFineTuningApp
1515
from ads.aqua.model import AquaModelApp
1616
from ads.aqua.modeldeployment import AquaDeploymentApp
17-
from ads.aqua.shaperecommend.recommend import AquaRecommendApp
1817
from ads.aqua.verify_policies import AquaVerifyPoliciesApp
1918
from ads.common.utils import LOG_LEVELS
2019

@@ -32,7 +31,6 @@ class AquaCommand:
3231
deployment = AquaDeploymentApp
3332
evaluation = AquaEvaluationApp
3433
verify_policies = AquaVerifyPoliciesApp
35-
recommend = AquaRecommendApp
3634

3735
def __init__(
3836
self,

ads/aqua/extension/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from ads.aqua.extension.evaluation_handler import __handlers__ as __eval_handlers__
1414
from ads.aqua.extension.finetune_handler import __handlers__ as __finetune_handlers__
1515
from ads.aqua.extension.model_handler import __handlers__ as __model_handlers__
16-
from ads.aqua.extension.recommend_handler import __handlers__ as __gpu_handlers__
1716
from ads.aqua.extension.ui_handler import __handlers__ as __ui_handlers__
1817
from ads.aqua.extension.ui_websocket_handler import __handlers__ as __ws_handlers__
1918

@@ -25,7 +24,6 @@
2524
+ __ui_handlers__
2625
+ __eval_handlers__
2726
+ __ws_handlers__
28-
+ __gpu_handlers__
2927
)
3028

3129

ads/aqua/extension/deployment_handler.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ def get(self, id: Union[str, List[str]] = None):
5757
return self.get_deployment_config(
5858
model_id=id.split(",") if "," in id else id
5959
)
60+
elif paths.startswith("aqua/deployments/recommend_shapes"):
61+
id = id or self.get_argument("model_id", default=None)
62+
if not id or not isinstance(id, str):
63+
raise HTTPError(
64+
400,
65+
f"Invalid request format for {self.request.path}. "
66+
"Expected a single model OCID",
67+
)
68+
id = id.replace(" ", "")
69+
return self.get_recommend_shape(model_id=id)
6070
elif paths.startswith("aqua/deployments/shapes"):
6171
return self.list_shapes()
6272
elif paths.startswith("aqua/deployments"):
@@ -161,6 +171,36 @@ def get_deployment_config(self, model_id: Union[str, List[str]]):
161171

162172
return self.finish(deployment_config)
163173

174+
def get_recommend_shape(self, model_id: str):
175+
"""
176+
Retrieves the valid shape and deployment parameter configuration for one Aqua Model.
177+
178+
Parameters
179+
----------
180+
model_id : str
181+
A single model ID (str).
182+
183+
Returns
184+
-------
185+
None
186+
The function sends the ShapeRecommendReport (generate_table = False) or Rich Diff Table (generate_table = True)
187+
"""
188+
app = AquaDeploymentApp()
189+
190+
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
191+
192+
generate_table = (
193+
self.get_argument("generate_table", default="True").lower() == "true"
194+
)
195+
196+
recommend_report = app.recommend_shape(
197+
model_id=model_id,
198+
compartment_id=compartment_id,
199+
generate_table=generate_table,
200+
)
201+
202+
return self.finish(recommend_report)
203+
164204
def list_shapes(self):
165205
"""
166206
Lists the valid model deployment shapes.
@@ -408,6 +448,7 @@ def get(self, model_deployment_id):
408448
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
409449
("deployments/config/?([^/]*)", AquaDeploymentHandler),
410450
("deployments/shapes/?([^/]*)", AquaDeploymentHandler),
451+
("deployments/recommend_shapes/?([^/]*)", AquaDeploymentHandler),
411452
("deployments/?([^/]*)", AquaDeploymentHandler),
412453
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
413454
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),

ads/aqua/extension/recommend_handler.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

ads/aqua/modeldeployment/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@
1111

1212
DEFAULT_WAIT_TIME = 12000
1313
DEFAULT_POLL_INTERVAL = 10
14+
15+
SHAPE_MAP = {"NVIDIA_GPU": "GPU"}

ads/aqua/modeldeployment/deployment.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
import shlex
99
import threading
1010
from datetime import datetime, timedelta
11-
from typing import Dict, List, Optional
11+
from typing import Dict, List, Optional, Union
1212

1313
from cachetools import TTLCache, cached
1414
from oci.data_science.models import ModelDeploymentShapeSummary
1515
from pydantic import ValidationError
16+
from rich.table import Table
1617

1718
from ads.aqua.app import AquaApp, logger
1819
from ads.aqua.common.entities import (
@@ -63,14 +64,16 @@
6364
ModelDeploymentConfigSummary,
6465
MultiModelDeploymentConfigLoader,
6566
)
66-
from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME
67+
from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME, SHAPE_MAP
6768
from ads.aqua.modeldeployment.entities import (
6869
AquaDeployment,
6970
AquaDeploymentDetail,
7071
ConfigValidationError,
7172
CreateModelDeploymentDetails,
7273
)
7374
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
75+
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
76+
from ads.aqua.shaperecommend.shape_report import ShapeRecommendationReport
7477
from ads.common.object_storage_details import ObjectStorageDetails
7578
from ads.common.utils import UNKNOWN, get_log_links
7679
from ads.common.work_request import DataScienceWorkRequest
@@ -1243,6 +1246,107 @@ def validate_deployment_params(
12431246
)
12441247
return {"valid": True}
12451248

1249+
def valid_compute_shapes(self, **kwargs) -> List["ComputeShapeSummary"]:
1250+
"""
1251+
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
1252+
1253+
Parameters
1254+
----------
1255+
file : str
1256+
Path to the JSON file containing shape data.
1257+
1258+
Returns
1259+
-------
1260+
List[ComputeShapeSummary]
1261+
List of ComputeShapeSummary objects passing the checks.
1262+
1263+
Raises
1264+
------
1265+
ValueError
1266+
If the file cannot be opened, parsed, or the 'shapes' key is missing.
1267+
"""
1268+
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
1269+
oci_shapes: list[ModelDeploymentShapeSummary] = self.list_resource(
1270+
self.ds_client.list_model_deployment_shapes,
1271+
compartment_id=compartment_id,
1272+
**kwargs,
1273+
)
1274+
set_user_shapes = {shape.name: shape for shape in oci_shapes}
1275+
1276+
gpu_shapes_metadata = load_gpu_shapes_index().shapes
1277+
1278+
valid_shapes = []
1279+
# only loops through GPU shapes, update later to include CPU shapes
1280+
for name, spec in gpu_shapes_metadata.items():
1281+
if name in set_user_shapes:
1282+
oci_shape = set_user_shapes.get(name)
1283+
1284+
compute_shape = ComputeShapeSummary(
1285+
available=True,
1286+
core_count= oci_shape.core_count,
1287+
memory_in_gbs= oci_shape.memory_in_gbs,
1288+
shape_series= SHAPE_MAP.get(oci_shape.shape_series, "GPU"),
1289+
name= oci_shape.name,
1290+
gpu_specs= spec
1291+
)
1292+
else:
1293+
compute_shape = ComputeShapeSummary(
1294+
available=False, name=name, shape_series="GPU", gpu_specs=spec
1295+
)
1296+
valid_shapes.append(compute_shape)
1297+
1298+
valid_shapes.sort(
1299+
key=lambda shape: shape.gpu_specs.gpu_memory_in_gbs, reverse=True
1300+
)
1301+
return valid_shapes
1302+
1303+
1304+
def recommend_shape(
1305+
self, **kwargs
1306+
) -> Union[Table, ShapeRecommendationReport]:
1307+
"""
1308+
For the CLI (set generate_table = True), generates the table (in rich diff) with valid
1309+
GPU deployment shapes for the provided model and configuration.
1310+
1311+
For the API (set generate_table = False), generates the JSON with valid
1312+
GPU deployment shapes for the provided model and configuration.
1313+
1314+
Validates if recommendations are generated, calls method to construct the rich diff
1315+
table with the recommendation data.
1316+
1317+
Parameters
1318+
----------
1319+
model_ocid : str
1320+
OCID of the model to recommend feasible compute shapes.
1321+
1322+
Returns
1323+
-------
1324+
Table (generate_table = True)
1325+
A table format for the recommendation report with compatible deployment shapes
1326+
or troubleshooting info citing the largest shapes if no shape is suitable.
1327+
1328+
ShapeRecommendationReport (generate_table = False)
1329+
A recommendation report with compatible deployment shapes, or troubleshooting info
1330+
citing the largest shapes if no shape is suitable.
1331+
1332+
Raises
1333+
------
1334+
AquaValueError
1335+
If model type is unsupported by tool (no recommendation report generated)
1336+
"""
1337+
# generate_table = kwargs.pop(
1338+
# "generate_table", True
1339+
# ) # Generate rich diff table by default
1340+
compartment_id = kwargs.get("compartment_id", COMPARTMENT_OCID)
1341+
1342+
kwargs["shapes"] = self.valid_compute_shapes(compartment_id=compartment_id)
1343+
1344+
shape_recommend = AquaShapeRecommend()
1345+
1346+
shape_recommend_report = shape_recommend.which_shapes(**kwargs)
1347+
1348+
return shape_recommend_report
1349+
12461350
@telemetry(entry_point="plugin=deployment&action=list_shapes", name="aqua")
12471351
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
12481352
def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:

ads/aqua/shaperecommend/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
# Copyright (c) 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4-
from ads.aqua.shaperecommend.recommend import AquaRecommendApp
4+
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
55

6-
__all__ = ["AquaRecommendApp"]
6+
__all__ = ["AquaShapeRecommend"]

0 commit comments

Comments
 (0)