From 03bae9aada269392aef262f2f49e581f533c2c4d Mon Sep 17 00:00:00 2001 From: tripathiji1312 Date: Wed, 24 Sep 2025 01:08:22 +0530 Subject: [PATCH] Add special case handling for L4 GPU in get_machine_spec and implement corresponding unit test --- xmanager/cloud/vertex.py | 6 +++++- xmanager/cloud/vertex_test.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/xmanager/cloud/vertex.py b/xmanager/cloud/vertex.py index f93bf73..3bdb2db 100644 --- a/xmanager/cloud/vertex.py +++ b/xmanager/cloud/vertex.py @@ -298,7 +298,11 @@ def get_machine_spec(job: xm.Job) -> Dict[str, Any]: for resource, value in requirements.task_requirements.items(): accelerator_type = None if resource in xm.GpuType: - accelerator_type = 'NVIDIA_TESLA_' + str(resource).upper() + # TODO(b/289373107): Remove this special case. + if resource == xm.GpuType.L4_24TH: + accelerator_type = 'NVIDIA_L4' + else: + accelerator_type = 'NVIDIA_TESLA_' + str(resource).upper() elif resource in xm.TpuType: accelerator_type = _CLOUD_TPU_ACCELERATOR_TYPES[resource] if accelerator_type: diff --git a/xmanager/cloud/vertex_test.py b/xmanager/cloud/vertex_test.py index c2c1573..2461c78 100644 --- a/xmanager/cloud/vertex_test.py +++ b/xmanager/cloud/vertex_test.py @@ -155,6 +155,24 @@ def test_get_machine_spec_a100(self): }, ) + def test_get_machine_spec_l4(self): + job = xm.Job( + executable=local_executables.GoogleContainerRegistryImage('name', ''), + executor=local_executors.Vertex( + requirements=xm.JobRequirements(l4_24th=2) + ), + args={}, + ) + machine_spec = vertex.get_machine_spec(job) + self.assertDictEqual( + machine_spec, + { + 'machine_type': 'g2-standard-4', + 'accelerator_type': vertex.aip_v1.AcceleratorType.NVIDIA_L4, + 'accelerator_count': 2, + }, + ) + def test_get_machine_spec_tpu(self): job = xm.Job( executable=local_executables.GoogleContainerRegistryImage('name', ''),