Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion python/gigl/common/services/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name

DEFAULT_PIPELINE_TIMEOUT_S: Final[int] = 60 * 60 * 36 # 36 hours
DEFAULT_CUSTOM_JOB_TIMEOUT_S: Final[int] = 60 * 60 * 24 # 24 hours
# Sentinel default for VertexAiJobConfig.boot_disk_type, meaning "caller did not choose a disk type".
# VertexAiJobConfig.__post_init__ replaces it with a concrete disk type based on the machine type.
BOOT_DISK_PLACEHOLDER: Final[str] = "DISK_TYPE_UNSPECIFIED"


@dataclass
Expand All @@ -98,7 +99,7 @@ class VertexAiJobConfig:
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED"
accelerator_count: int = 0
replica_count: int = 1
boot_disk_type: str = "pd-ssd" # Persistent Disk SSD
boot_disk_type: str = BOOT_DISK_PLACEHOLDER # Placeholder; __post_init__ resolves it to a concrete disk type
boot_disk_size_gb: int = 100 # Default disk size in GB
labels: Optional[dict[str, str]] = None
timeout_s: Optional[
Expand All @@ -107,6 +108,15 @@ class VertexAiJobConfig:
enable_web_access: bool = True
scheduling_strategy: Optional[aiplatform.gapic.Scheduling.Strategy] = None

def __post_init__(self) -> None:
    """Resolve ``boot_disk_type`` when the caller left it at the placeholder.

    If ``boot_disk_type`` equals ``BOOT_DISK_PLACEHOLDER``, it is replaced with
    ``"hyperdisk-balanced"`` for machine types starting with ``"g4-"`` and with
    ``"pd-ssd"`` for every other machine type. Any other value is treated as
    an explicit choice and left untouched.
    """
    # Compare by value, not identity: an equal placeholder string that is a
    # different object (e.g. parsed from a config file) must still be resolved.
    if self.boot_disk_type != BOOT_DISK_PLACEHOLDER:
        return

    if self.machine_type.startswith("g4-"):
        logger.info("No boot disk type set, and g4 machine detected, using hyperdisk-balanced")
        self.boot_disk_type = "hyperdisk-balanced"  # g4 machines require use of hyperdisk-balanced
    else:
        logger.info("No boot disk type set, using pd-ssd")
        self.boot_disk_type = "pd-ssd"


class VertexAIService:
"""
Expand Down
Empty file.
36 changes: 36 additions & 0 deletions python/tests/unit/common/services/vertex_ai_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

from parameterized import param, parameterized

from gigl.common.services.vertex_ai import VertexAiJobConfig


class VertexAIServiceTest(unittest.TestCase):
    """Checks which boot disk type VertexAiJobConfig settles on when none is given."""

    @parameterized.expand(
        [
            param(
                "g4 machine ; should default to hyperdisk-balanced",
                machine_type="g4-standard-8",
                expected_boot_disk_type="hyperdisk-balanced",
            ),
            param(
                "n1 machine ; should default to pd-ssd",
                machine_type="n1-standard-4",
                expected_boot_disk_type="pd-ssd",
            ),
        ]
    )
    def test_default_boot_disk_for_machine(
        self, _, machine_type, expected_boot_disk_type
    ):
        # boot_disk_type is deliberately omitted so the config has to pick
        # its own default for the given machine type.
        config_kwargs = {
            "job_name": "job_name",
            "container_uri": "container_uri",
            "command": ["command"],
            "machine_type": machine_type,
        }
        resolved_disk_type = VertexAiJobConfig(**config_kwargs).boot_disk_type
        self.assertEqual(resolved_disk_type, expected_boot_disk_type)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()