|
8 | 8 | import shlex
|
9 | 9 | import threading
|
10 | 10 | from datetime import datetime, timedelta
|
11 |
| -from typing import Dict, List, Optional |
| 11 | +from typing import Dict, List, Optional, Union |
12 | 12 |
|
13 | 13 | from cachetools import TTLCache, cached
|
14 | 14 | from oci.data_science.models import ModelDeploymentShapeSummary
|
15 | 15 | from pydantic import ValidationError
|
| 16 | +from rich.table import Table |
16 | 17 |
|
17 | 18 | from ads.aqua.app import AquaApp, logger
|
18 | 19 | from ads.aqua.common.entities import (
|
|
63 | 64 | ModelDeploymentConfigSummary,
|
64 | 65 | MultiModelDeploymentConfigLoader,
|
65 | 66 | )
|
66 |
| -from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME |
| 67 | +from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME, SHAPE_MAP |
67 | 68 | from ads.aqua.modeldeployment.entities import (
|
68 | 69 | AquaDeployment,
|
69 | 70 | AquaDeploymentDetail,
|
70 | 71 | ConfigValidationError,
|
71 | 72 | CreateModelDeploymentDetails,
|
72 | 73 | )
|
73 | 74 | from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
|
| 75 | +from ads.aqua.shaperecommend.recommend import AquaShapeRecommend |
| 76 | +from ads.aqua.shaperecommend.shape_report import ShapeRecommendationReport |
74 | 77 | from ads.common.object_storage_details import ObjectStorageDetails
|
75 | 78 | from ads.common.utils import UNKNOWN, get_log_links
|
76 | 79 | from ads.common.work_request import DataScienceWorkRequest
|
@@ -1243,6 +1246,107 @@ def validate_deployment_params(
|
1243 | 1246 | )
|
1244 | 1247 | return {"valid": True}
|
1245 | 1248 |
|
| 1249 | + def valid_compute_shapes(self, **kwargs) -> List["ComputeShapeSummary"]: |
| 1250 | + """ |
| 1251 | + Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file. |
| 1252 | +
|
| 1253 | + Parameters |
| 1254 | + ---------- |
| 1255 | + file : str |
| 1256 | + Path to the JSON file containing shape data. |
| 1257 | +
|
| 1258 | + Returns |
| 1259 | + ------- |
| 1260 | + List[ComputeShapeSummary] |
| 1261 | + List of ComputeShapeSummary objects passing the checks. |
| 1262 | +
|
| 1263 | + Raises |
| 1264 | + ------ |
| 1265 | + ValueError |
| 1266 | + If the file cannot be opened, parsed, or the 'shapes' key is missing. |
| 1267 | + """ |
| 1268 | + compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID) |
| 1269 | + oci_shapes: list[ModelDeploymentShapeSummary] = self.list_resource( |
| 1270 | + self.ds_client.list_model_deployment_shapes, |
| 1271 | + compartment_id=compartment_id, |
| 1272 | + **kwargs, |
| 1273 | + ) |
| 1274 | + set_user_shapes = {shape.name: shape for shape in oci_shapes} |
| 1275 | + |
| 1276 | + gpu_shapes_metadata = load_gpu_shapes_index().shapes |
| 1277 | + |
| 1278 | + valid_shapes = [] |
| 1279 | + # only loops through GPU shapes, update later to include CPU shapes |
| 1280 | + for name, spec in gpu_shapes_metadata.items(): |
| 1281 | + if name in set_user_shapes: |
| 1282 | + oci_shape = set_user_shapes.get(name) |
| 1283 | + |
| 1284 | + compute_shape = ComputeShapeSummary( |
| 1285 | + available=True, |
| 1286 | + core_count= oci_shape.core_count, |
| 1287 | + memory_in_gbs= oci_shape.memory_in_gbs, |
| 1288 | + shape_series= SHAPE_MAP.get(oci_shape.shape_series, "GPU"), |
| 1289 | + name= oci_shape.name, |
| 1290 | + gpu_specs= spec |
| 1291 | + ) |
| 1292 | + else: |
| 1293 | + compute_shape = ComputeShapeSummary( |
| 1294 | + available=False, name=name, shape_series="GPU", gpu_specs=spec |
| 1295 | + ) |
| 1296 | + valid_shapes.append(compute_shape) |
| 1297 | + |
| 1298 | + valid_shapes.sort( |
| 1299 | + key=lambda shape: shape.gpu_specs.gpu_memory_in_gbs, reverse=True |
| 1300 | + ) |
| 1301 | + return valid_shapes |
| 1302 | + |
| 1303 | + |
| 1304 | + def recommend_shape( |
| 1305 | + self, **kwargs |
| 1306 | + ) -> Union[Table, ShapeRecommendationReport]: |
| 1307 | + """ |
| 1308 | + For the CLI (set generate_table = True), generates the table (in rich diff) with valid |
| 1309 | + GPU deployment shapes for the provided model and configuration. |
| 1310 | +
|
| 1311 | + For the API (set generate_table = False), generates the JSON with valid |
| 1312 | + GPU deployment shapes for the provided model and configuration. |
| 1313 | +
|
| 1314 | + Validates if recommendations are generated, calls method to construct the rich diff |
| 1315 | + table with the recommendation data. |
| 1316 | +
|
| 1317 | + Parameters |
| 1318 | + ---------- |
| 1319 | + model_ocid : str |
| 1320 | + OCID of the model to recommend feasible compute shapes. |
| 1321 | +
|
| 1322 | + Returns |
| 1323 | + ------- |
| 1324 | + Table (generate_table = True) |
| 1325 | + A table format for the recommendation report with compatible deployment shapes |
| 1326 | + or troubleshooting info citing the largest shapes if no shape is suitable. |
| 1327 | +
|
| 1328 | + ShapeRecommendationReport (generate_table = False) |
| 1329 | + A recommendation report with compatible deployment shapes, or troubleshooting info |
| 1330 | + citing the largest shapes if no shape is suitable. |
| 1331 | +
|
| 1332 | + Raises |
| 1333 | + ------ |
| 1334 | + AquaValueError |
| 1335 | + If model type is unsupported by tool (no recommendation report generated) |
| 1336 | + """ |
| 1337 | + # generate_table = kwargs.pop( |
| 1338 | + # "generate_table", True |
| 1339 | + # ) # Generate rich diff table by default |
| 1340 | + compartment_id = kwargs.get("compartment_id", COMPARTMENT_OCID) |
| 1341 | + |
| 1342 | + kwargs["shapes"] = self.valid_compute_shapes(compartment_id=compartment_id) |
| 1343 | + |
| 1344 | + shape_recommend = AquaShapeRecommend() |
| 1345 | + |
| 1346 | + shape_recommend_report = shape_recommend.which_shapes(**kwargs) |
| 1347 | + |
| 1348 | + return shape_recommend_report |
| 1349 | + |
1246 | 1350 | @telemetry(entry_point="plugin=deployment&action=list_shapes", name="aqua")
|
1247 | 1351 | @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
|
1248 | 1352 | def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
|
|
0 commit comments