Skip to content

Commit e99b854

Browse files
authored
(torchx/schedulers)(slurm) Add account to runopts
Differential Revision: D88080923 Pull Request resolved: #1169
1 parent 8dcad29 commit e99b854

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

torchx/schedulers/slurm_scheduler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def _should_use_gpus_per_node_from_version() -> bool:
135135
"comment",
136136
"mail-user",
137137
"mail-type",
138+
"account",
138139
}
139140
SBATCH_GROUP_OPTIONS = {
140141
"partition",
@@ -159,6 +160,7 @@ def _apply_app_id_env(s: str) -> str:
159160
SlurmOpts = TypedDict(
160161
"SlurmOpts",
161162
{
163+
"account": Optional[str],
162164
"partition": str,
163165
"time": str,
164166
"comment": Optional[str],
@@ -404,6 +406,12 @@ def __init__(self, session_name: str) -> None:
404406

405407
def _run_opts(self) -> runopts:
406408
opts = runopts()
409+
opts.add(
410+
"account",
411+
type_=str,
412+
help="The account to use for the slurm job.",
413+
default=None,
414+
)
407415
opts.add(
408416
"partition",
409417
type_=str,

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def test_submit_dryrun_tags(self, _) -> None:
159159
def test_submit_dryrun_job_role_arn(self) -> None:
160160
cfg = AWSBatchOpts({"queue": "ignored_in_test", "job_role_arn": "fizzbuzz"})
161161
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
162-
# pyre-ignore[16]
163162
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
164163
self.assertEqual(1, len(node_groups))
165164
self.assertEqual(cfg["job_role_arn"], node_groups[0]["container"]["jobRoleArn"])
@@ -169,7 +168,6 @@ def test_submit_dryrun_execution_role_arn(self) -> None:
169168
{"queue": "ignored_in_test", "execution_role_arn": "veryexecutive"}
170169
)
171170
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
172-
# pyre-ignore[16]
173171
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
174172
self.assertEqual(1, len(node_groups))
175173
self.assertEqual(
@@ -179,7 +177,6 @@ def test_submit_dryrun_execution_role_arn(self) -> None:
179177
def test_submit_dryrun_privileged(self) -> None:
180178
cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
181179
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
182-
# pyre-ignore[16]
183180
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
184181
self.assertEqual(1, len(node_groups))
185182
self.assertTrue(node_groups[0]["container"]["privileged"])
@@ -189,7 +186,6 @@ def test_submit_dryrun_instance_type_multinode(self) -> None:
189186
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
190187
app = _test_app(num_replicas=2, resource=resource)
191188
info = create_scheduler("test").submit_dryrun(app, cfg)
192-
# pyre-ignore[16]
193189
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
194190
self.assertEqual(1, len(node_groups))
195191
self.assertEqual(
@@ -202,7 +198,6 @@ def test_submit_dryrun_instance_type_singlenode(self) -> None:
202198
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
203199
app = _test_app(num_replicas=1, resource=resource)
204200
info = create_scheduler("test").submit_dryrun(app, cfg)
205-
# pyre-ignore[16]
206201
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
207202
self.assertEqual(1, len(node_groups))
208203
self.assertTrue("instanceType" in node_groups[0]["container"])
@@ -212,7 +207,6 @@ def test_submit_dryrun_no_instance_type_non_aws(self) -> None:
212207
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
213208
app = _test_app(num_replicas=2)
214209
info = create_scheduler("test").submit_dryrun(app, cfg)
215-
# pyre-ignore[16]
216210
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
217211
self.assertEqual(1, len(node_groups))
218212
self.assertTrue("instanceType" not in node_groups[0]["container"])

torchx/schedulers/test/slurm_scheduler_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,24 @@ def test_dryrun_comment(self, mock_version: MagicMock) -> None:
696696
info.request.cmd,
697697
)
698698

699+
@patch(
700+
"torchx.schedulers.slurm_scheduler.version",
701+
return_value=SLURM_VERSION_24_5,
702+
)
703+
def test_account(self, mock_version: MagicMock) -> None:
704+
scheduler = create_scheduler("foo")
705+
app = simple_app()
706+
info = scheduler.submit_dryrun(
707+
app,
708+
cfg={
709+
"account": "foobar",
710+
},
711+
)
712+
self.assertIn(
713+
"--account=foobar",
714+
info.request.cmd,
715+
)
716+
699717
@patch(
700718
"torchx.schedulers.slurm_scheduler.version",
701719
return_value=SLURM_VERSION_24_5,

0 commit comments

Comments
 (0)