Skip to content

Commit 8dcad29

Browse files
authored
(torchx/schedulers) Restore Scheduler iface generic types
Differential Revision: D87901987 Pull Request resolved: #1167
1 parent b0d8dbb commit 8dcad29

File tree

10 files changed

+21
-39
lines changed

10 files changed

+21
-39
lines changed

torchx/schedulers/api.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from dataclasses import dataclass, field
1212
from datetime import datetime
1313
from enum import Enum
14-
from typing import Generic, Iterable, List, Optional, TypeVar, Union
14+
from typing import Generic, Iterable, List, Optional, TypeVar
1515

1616
from torchx.specs import (
1717
AppDef,
18+
AppDryRunInfo,
1819
AppState,
1920
NONE,
2021
NULL_RESOURCE,
@@ -95,11 +96,9 @@ def __hash__(self) -> int:
9596

9697

9798
T = TypeVar("T")
98-
A = TypeVar("A")
99-
D = TypeVar("D")
10099

101100

102-
class Scheduler(abc.ABC, Generic[T, A, D]):
101+
class Scheduler(abc.ABC, Generic[T]):
103102
"""
104103
An interface abstracting functionalities of a scheduler.
105104
Implementers need only implement those methods annotated with
@@ -129,7 +128,7 @@ def close(self) -> None:
129128

130129
def submit(
131130
self,
132-
app: A,
131+
app: AppDef,
133132
cfg: T,
134133
workspace: str | Workspace | None = None,
135134
) -> str:
@@ -157,7 +156,7 @@ def submit(
157156
return self.schedule(dryrun_info)
158157

159158
@abc.abstractmethod
160-
def schedule(self, dryrun_info: D) -> str:
159+
def schedule(self, dryrun_info: AppDryRunInfo) -> str:
161160
"""
162161
Same as ``submit`` except that it takes an ``AppDryRunInfo``.
163162
Implementers are encouraged to implement this method rather than
@@ -173,7 +172,7 @@ def schedule(self, dryrun_info: D) -> str:
173172

174173
raise NotImplementedError()
175174

176-
def submit_dryrun(self, app: A, cfg: T) -> D:
175+
def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
177176
"""
178177
Rather than submitting the request to run the app, returns the
179178
request object that would have been submitted to the underlying
@@ -187,15 +186,15 @@ def submit_dryrun(self, app: A, cfg: T) -> D:
187186
# pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg
188187
dryrun_info = self._submit_dryrun(app, resolved_cfg)
189188

190-
if isinstance(app, AppDef):
191-
for role in app.roles:
192-
dryrun_info = role.pre_proc(self.backend, dryrun_info)
189+
for role in app.roles:
190+
dryrun_info = role.pre_proc(self.backend, dryrun_info)
191+
193192
dryrun_info._app = app
194193
dryrun_info._cfg = resolved_cfg
195194
return dryrun_info
196195

197196
@abc.abstractmethod
198-
def _submit_dryrun(self, app: A, cfg: T) -> D:
197+
def _submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
199198
raise NotImplementedError()
200199

201200
def run_opts(self) -> runopts:
@@ -394,15 +393,12 @@ def _pre_build_validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
394393
"""
395394
pass
396395

397-
def _validate(self, app: A, scheduler: str, cfg: T) -> None:
396+
def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
398397
"""
399398
Validates after workspace build whether application is consistent with the scheduler.
400399
401400
Raises error if application is not compatible with scheduler
402401
"""
403-
if not isinstance(app, AppDef):
404-
return
405-
406402
for role in app.roles:
407403
if role.resource == NULL_RESOURCE:
408404
raise ValueError(

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def wrapper() -> T:
381381

382382

383383
@_thread_local_cache
384-
def _local_session() -> "boto3.session.Session":
384+
def _local_session() -> "boto3.session.Session": # noqa: F821
385385
import boto3.session
386386

387387
return boto3.session.Session()
@@ -399,9 +399,7 @@ class AWSBatchOpts(TypedDict, total=False):
399399
ulimits: Optional[list[str]]
400400

401401

402-
class AWSBatchScheduler(
403-
DockerWorkspaceMixin, Scheduler[AWSBatchOpts, AppDef, AppDryRunInfo[BatchJob]]
404-
):
402+
class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
405403
"""
406404
AWSBatchScheduler is a TorchX scheduling interface to AWS Batch.
407405

torchx/schedulers/aws_sagemaker_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _merge_ordered(
157157

158158
class AWSSageMakerScheduler(
159159
DockerWorkspaceMixin,
160-
Scheduler[AWSSageMakerOpts, AppDef, AppDryRunInfo[AWSSageMakerJob]],
160+
Scheduler[AWSSageMakerOpts],
161161
):
162162
"""
163163
AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.

torchx/schedulers/docker_scheduler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,7 @@ class DockerOpts(TypedDict, total=False):
129129
privileged: bool
130130

131131

132-
class DockerScheduler(
133-
DockerWorkspaceMixin, Scheduler[DockerOpts, AppDef, AppDryRunInfo[DockerJob]]
134-
):
132+
class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
135133
"""
136134
DockerScheduler is a TorchX scheduling interface to Docker.
137135

torchx/schedulers/kubernetes_mcad_scheduler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -796,10 +796,7 @@ class KubernetesMCADOpts(TypedDict, total=False):
796796
network: Optional[str]
797797

798798

799-
class KubernetesMCADScheduler(
800-
DockerWorkspaceMixin,
801-
Scheduler[KubernetesMCADOpts, AppDef, AppDryRunInfo[KubernetesMCADJob]],
802-
):
799+
class KubernetesMCADScheduler(DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts]):
803800
"""
804801
KubernetesMCADScheduler is a TorchX scheduling interface to Kubernetes.
805802

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,10 +591,7 @@ class KubernetesOpts(TypedDict, total=False):
591591
validate_spec: Optional[bool]
592592

593593

594-
class KubernetesScheduler(
595-
DockerWorkspaceMixin,
596-
Scheduler[KubernetesOpts, AppDef, AppDryRunInfo[KubernetesJob]],
597-
):
594+
class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]):
598595
"""
599596
KubernetesScheduler is a TorchX scheduling interface to Kubernetes.
600597

torchx/schedulers/local_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def _register_termination_signals() -> None:
529529
signal.signal(signal.SIGINT, _terminate_process_handler)
530530

531531

532-
class LocalScheduler(Scheduler[LocalOpts, AppDef, AppDryRunInfo[PopenRequest]]):
532+
class LocalScheduler(Scheduler[LocalOpts]):
533533
"""
534534
Schedules on localhost. Containers are modeled as processes and
535535
certain properties of the container that are either not relevant

torchx/schedulers/lsf_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def __repr__(self) -> str:
394394
{self.materialize()}"""
395395

396396

397-
class LsfScheduler(Scheduler[LsfOpts, AppDef, AppDryRunInfo]):
397+
class LsfScheduler(Scheduler[LsfOpts]):
398398
"""
399399
**Example: hello_world**
400400

torchx/schedulers/slurm_scheduler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,9 +335,7 @@ def __repr__(self) -> str:
335335
{self.materialize()}"""
336336

337337

338-
class SlurmScheduler(
339-
DirWorkspaceMixin, Scheduler[SlurmOpts, AppDef, AppDryRunInfo[SlurmBatchRequest]]
340-
):
338+
class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
341339
"""
342340
SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects
343341
that slurm CLI tools are locally installed and job accounting is enabled.

torchx/schedulers/test/api_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,10 @@
3535
from torchx.workspace.api import WorkspaceMixin
3636

3737
T = TypeVar("T")
38-
A = TypeVar("A")
39-
D = TypeVar("D")
4038

4139

4240
class SchedulerTest(unittest.TestCase):
43-
class MockScheduler(Scheduler[T, A, D], WorkspaceMixin[None]):
41+
class MockScheduler(Scheduler[T], WorkspaceMixin[None]):
4442
def __init__(self, session_name: str) -> None:
4543
super().__init__("mock", session_name)
4644

0 commit comments

Comments
 (0)