|
1 | 1 | #!/usr/bin/env python
|
2 |
| -# -*- coding: utf-8; -*- |
3 | 2 |
|
4 |
| -# Copyright (c) 2024 Oracle and/or its affiliates. |
| 3 | +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. |
5 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6 | 5 |
|
7 |
| -from functools import wraps |
8 | 6 | import logging
|
9 |
| -from typing import Callable, List |
10 |
| -from ads.common.oci_datascience import OCIDataScienceMixin |
11 |
| -from ads.common.work_request import DataScienceWorkRequest |
12 |
| -from ads.config import PROJECT_OCID |
13 |
| -from ads.model.deployment.common.utils import OCIClientManager, State |
14 |
| -import oci |
| 7 | +from functools import wraps |
| 8 | +from typing import Callable, List, Optional |
15 | 9 |
|
| 10 | +import oci |
16 | 11 | from oci.data_science.models import (
|
17 | 12 | CreateModelDeploymentDetails,
|
| 13 | + ModelDeploymentShapeSummary, |
18 | 14 | UpdateModelDeploymentDetails,
|
19 | 15 | )
|
20 | 16 |
|
| 17 | +from ads.common.oci_datascience import OCIDataScienceMixin |
| 18 | +from ads.common.work_request import DataScienceWorkRequest |
| 19 | +from ads.config import COMPARTMENT_OCID, PROJECT_OCID |
| 20 | +from ads.model.deployment.common.utils import OCIClientManager, State |
| 21 | + |
21 | 22 | DEFAULT_WAIT_TIME = 1200
|
22 | 23 | DEFAULT_POLL_INTERVAL = 10
|
23 | 24 | ALLOWED_STATUS = [
|
@@ -185,14 +186,13 @@ def activate(
|
185 | 186 | self.id,
|
186 | 187 | )
|
187 | 188 |
|
188 |
| - |
189 | 189 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
190 | 190 | if wait_for_completion:
|
191 | 191 | try:
|
192 | 192 | DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
|
193 | 193 | progress_bar_description="Activating model deployment",
|
194 |
| - max_wait_time=max_wait_time, |
195 |
| - poll_interval=poll_interval |
| 194 | + max_wait_time=max_wait_time, |
| 195 | + poll_interval=poll_interval, |
196 | 196 | )
|
197 | 197 | except Exception as e:
|
198 | 198 | logger.error(
|
@@ -239,8 +239,8 @@ def create(
|
239 | 239 | try:
|
240 | 240 | DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
|
241 | 241 | progress_bar_description="Creating model deployment",
|
242 |
| - max_wait_time=max_wait_time, |
243 |
| - poll_interval=poll_interval |
| 242 | + max_wait_time=max_wait_time, |
| 243 | + poll_interval=poll_interval, |
244 | 244 | )
|
245 | 245 | except Exception as e:
|
246 | 246 | logger.error("Error while trying to create model deployment: " + str(e))
|
@@ -290,8 +290,8 @@ def deactivate(
|
290 | 290 | try:
|
291 | 291 | DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
|
292 | 292 | progress_bar_description="Deactivating model deployment",
|
293 |
| - max_wait_time=max_wait_time, |
294 |
| - poll_interval=poll_interval |
| 293 | + max_wait_time=max_wait_time, |
| 294 | + poll_interval=poll_interval, |
295 | 295 | )
|
296 | 296 | except Exception as e:
|
297 | 297 | logger.error(
|
@@ -351,14 +351,14 @@ def delete(
|
351 | 351 | response = self.client.delete_model_deployment(
|
352 | 352 | self.id,
|
353 | 353 | )
|
354 |
| - |
| 354 | + |
355 | 355 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
356 | 356 | if wait_for_completion:
|
357 | 357 | try:
|
358 | 358 | DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
|
359 | 359 | progress_bar_description="Deleting model deployment",
|
360 |
| - max_wait_time=max_wait_time, |
361 |
| - poll_interval=poll_interval |
| 360 | + max_wait_time=max_wait_time, |
| 361 | + poll_interval=poll_interval, |
362 | 362 | )
|
363 | 363 | except Exception as e:
|
364 | 364 | logger.error("Error while trying to delete model deployment: " + str(e))
|
@@ -493,3 +493,30 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment":
|
493 | 493 | An instance of `OCIDataScienceModelDeployment`.
|
494 | 494 | """
|
495 | 495 | return super().from_ocid(model_deployment_id)
|
| 496 | + |
| 497 | + @classmethod |
| 498 | + def shapes( |
| 499 | + cls, |
| 500 | + compartment_id: Optional[str] = None, |
| 501 | + **kwargs, |
| 502 | + ) -> List[ModelDeploymentShapeSummary]: |
| 503 | + """ |
| 504 | + Retrieves all available model deployment shapes in the given compartment. |
| 505 | +
|
| 506 | + This method uses OCI's pagination utility to fetch all pages of model |
| 507 | + deployment shape summaries available in the specified compartment. |
| 508 | +
|
| 509 | + Args: |
| 510 | + compartment_id (Optional[str]): The OCID of the compartment. If not provided, |
| 511 | + the default COMPARTMENT_ID extracted form env variables is used. |
| 512 | + **kwargs: Additional keyword arguments to pass to the list_model_deployments call. |
| 513 | +
|
| 514 | + Returns: |
| 515 | + List[ModelDeploymentShapeSummary]: A list of all model deployment shape summaries. |
| 516 | + """ |
| 517 | + client = cls().client |
| 518 | + compartment_id = compartment_id or COMPARTMENT_OCID |
| 519 | + |
| 520 | + return oci.pagination.list_call_get_all_results( |
| 521 | + client.list_model_deployment_shapes, compartment_id, **kwargs |
| 522 | + ).data |
0 commit comments