From 12526d49c212358559750b1b28cbbeeb72334315 Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Fri, 3 Oct 2025 20:18:09 -0700 Subject: [PATCH] (torchx/specs) Allow roles to specify their own workspaces (#1139) Summary: So far you can only specify the workspace in the runner's API: ``` runner.dryrun(appdef, cfg, workspace=...) ``` in which case the workspace applies to ONLY `role[0]`. This behavior was intentional since multi-role use cases of TorchX typically had a single "main" role that the application owner actually owned and the other roles were prepackaged apps (not part of your project). This is no longer the case with applications such as reinforcement learning, where the project encompasses multiple applications (e.g. trainer, generator, etc.); therefore we need a more flexible way to specify a workspace per Role. For backwards compatibility (BC), I'm maintaining the following behavior: 1. If `workspace` is specified as a runner argument then it takes precedence over `role[0].workspace` 2. Non-zero roles (e.g. `role[1], role[2], ...`) are unaffected by the workspace argument. That is, their workspace attributes (e.g. `role[1].workspace`) are respected as is. 3. "disabling" workspace (e.g. passing `workspace=None` from the runner argument) can still build a workspace if the role's workspace attribute is not `None`. NOTE: we need to do further optimization for cases where multiple roles have the same "image" and "workspace". In this case we only need to build the image+workspace once. But as it stands we end up building a separate ephemeral image per role (even if the ephemeral image is the SAME across all the roles). This isn't an issue practically since image builders like Docker are content-addressed and cache layers. 
Reviewed By: AbishekS Differential Revision: D83793199 --- torchx/runner/api.py | 56 +++++++++++------- torchx/runner/test/api_test.py | 98 +++++++++++++++++++++++-------- torchx/specs/__init__.py | 3 + torchx/specs/api.py | 83 +++++++++++++++++++++++++- torchx/specs/test/api_test.py | 79 ++++++++++++++++++++++++- torchx/workspace/api.py | 67 +-------------------- torchx/workspace/test/api_test.py | 71 +++------------------- 7 files changed, 279 insertions(+), 178 deletions(-) diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 08d732238..bf448d774 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -426,26 +426,42 @@ def dryrun( sched._pre_build_validate(app, scheduler, resolved_cfg) - if workspace and isinstance(sched, WorkspaceMixin): - role = app.roles[0] - old_img = role.image - - logger.info(f"Checking for changes in workspace `{workspace}`...") - logger.info( - 'To disable workspaces pass: --workspace="" from CLI or workspace=None programmatically.' - ) - sched.build_workspace_and_update_role2(role, workspace, resolved_cfg) - - if old_img != role.image: - logger.info( - f"Built new image `{role.image}` based on original image `{old_img}`" - f" and changes in workspace `{workspace}` for role[0]={role.name}." - ) - else: - logger.info( - f"Reusing original image `{old_img}` for role[0]={role.name}." - " Either a patch was built or no changes to workspace was detected." - ) + if isinstance(sched, WorkspaceMixin): + for i, role in enumerate(app.roles): + role_workspace = role.workspace + + if i == 0 and workspace: + # NOTE: torchx originally took workspace as a runner arg and only applied the workspace to role[0] + # later, torchx added support for the workspace attr in Role + # for BC, give precedence to the workspace argument over the workspace attr for role[0] + if role_workspace: + logger.info( + f"Using workspace={workspace} over role[{i}].workspace={role_workspace} for role[{i}]={role.name}." 
+ " To use the role's workspace attr pass: --workspace='' from CLI or workspace=None programmatically." # noqa: B950 + ) + role_workspace = workspace + + if role_workspace: + old_img = role.image + logger.info( + f"Checking for changes in workspace `{role_workspace}` for role[{i}]={role.name}..." + ) + # TODO kiuk@ once we deprecate the `workspace` argument in runner APIs we can simplify the signature of + # build_workspace_and_update_role2() to just taking the role and resolved_cfg + sched.build_workspace_and_update_role2( + role, role_workspace, resolved_cfg + ) + + if old_img != role.image: + logger.info( + f"Built new image `{role.image}` based on original image `{old_img}`" + f" and changes in workspace `{role_workspace}` for role[{i}]={role.name}." + ) + else: + logger.info( + f"Reusing original image `{old_img}` for role[{i}]={role.name}." + " Either a patch was built or no changes to workspace was detected." + ) sched._validate(app, scheduler, resolved_cfg) dryrun_info = sched.submit_dryrun(app, resolved_cfg) diff --git a/torchx/runner/test/api_test.py b/torchx/runner/test/api_test.py index 16181e728..ec0db6daa 100644 --- a/torchx/runner/test/api_test.py +++ b/torchx/runner/test/api_test.py @@ -20,16 +20,18 @@ create_scheduler, LocalDirectoryImageProvider, ) -from torchx.specs import AppDryRunInfo, CfgVal -from torchx.specs.api import ( +from torchx.specs import ( AppDef, + AppDryRunInfo, AppHandle, AppState, + CfgVal, parse_app_handle, Resource, Role, runopts, UnknownAppException, + Workspace, ) from torchx.specs.finder import ComponentNotFoundException from torchx.test.fixtures import TestWithTmpDir @@ -400,6 +402,16 @@ def build_workspace_and_update_role( ) -> None: if self.build_new_img: role.image = f"{role.image}_new" + role.env["SRC_WORKSPACE"] = workspace + + def create_role(image: str, workspace: str | None = None) -> Role: + return Role( + name="noop", + image=image, + resource=resource.SMALL, + entrypoint="/bin/true", + 
workspace=Workspace.from_str(workspace), + ) with Runner( name=SESSION_NAME, @@ -411,33 +423,71 @@ def build_workspace_and_update_role( "builds-img": lambda name, **kwargs: TestScheduler(build_new_img=True), }, ) as runner: + app = AppDef( + "ignored", + roles=[create_role(image="foo"), create_role(image="bar")], + ) + roles = runner.dryrun( + app, "no-build-img", workspace="//workspace" + ).request.roles + self.assertEqual("foo", roles[0].image) + self.assertEqual("bar", roles[1].image) + + roles = runner.dryrun( + app, "builds-img", workspace="//workspace" + ).request.roles + + # workspace is attached to role[0] when role[0].workspace is `None` + self.assertEqual("foo_new", roles[0].image) + self.assertEqual("bar", roles[1].image) + + # now run with role[0] having workspace attribute defined app = AppDef( "ignored", roles=[ - Role( - name="sleep", - image="foo", - resource=resource.SMALL, - entrypoint="sleep", - args=["1"], - ), - Role( - name="sleep", - image="bar", - resource=resource.SMALL, - entrypoint="sleep", - args=["1"], - ), + create_role(image="foo", workspace="//should_be_overriden"), + create_role(image="bar"), + ], + ) + roles = runner.dryrun( + app, "builds-img", workspace="//workspace" + ).request.roles + # workspace argument should override role[0].workspace attribute + self.assertEqual("foo_new", roles[0].image) + self.assertEqual("//workspace", roles[0].env["SRC_WORKSPACE"]) + self.assertEqual("bar", roles[1].image) + + # now run with both role[0] and role[1] having workspace attr + app = AppDef( + "ignored", + roles=[ + create_role(image="foo", workspace="//foo"), + create_role(image="bar", workspace="//bar"), + ], + ) + roles = runner.dryrun( + app, "builds-img", workspace="//workspace" + ).request.roles + + # workspace argument should override role[0].workspace attribute + self.assertEqual("foo_new", roles[0].image) + self.assertEqual("//workspace", roles[0].env["SRC_WORKSPACE"]) + self.assertEqual("bar_new", roles[1].image) + 
self.assertEqual("//bar", roles[1].env["SRC_WORKSPACE"]) + + # now run with both role[0] and role[1] having workspace attr but no workspace arg + app = AppDef( + "ignored", + roles=[ + create_role(image="foo", workspace="//foo"), + create_role(image="bar", workspace="//bar"), ], ) - dryruninfo = runner.dryrun(app, "no-build-img", workspace="//workspace") - self.assertEqual("foo", dryruninfo.request.roles[0].image) - self.assertEqual("bar", dryruninfo.request.roles[1].image) - - dryruninfo = runner.dryrun(app, "builds-img", workspace="//workspace") - # workspace is attached to role[0] by default - self.assertEqual("foo_new", dryruninfo.request.roles[0].image) - self.assertEqual("bar", dryruninfo.request.roles[1].image) + roles = runner.dryrun(app, "builds-img", workspace=None).request.roles + self.assertEqual("foo_new", roles[0].image) + self.assertEqual("//foo", roles[0].env["SRC_WORKSPACE"]) + self.assertEqual("bar_new", roles[1].image) + self.assertEqual("//bar", roles[1].env["SRC_WORKSPACE"]) def test_describe(self, _) -> None: with self.get_runner() as runner: diff --git a/torchx/specs/__init__.py b/torchx/specs/__init__.py index c31eb9365..beb9f847c 100644 --- a/torchx/specs/__init__.py +++ b/torchx/specs/__init__.py @@ -45,6 +45,7 @@ UnknownAppException, UnknownSchedulerException, VolumeMount, + Workspace, ) from torchx.specs.builders import make_app_handle, materialize_appdef, parse_mounts @@ -236,4 +237,6 @@ def gpu_x_1() -> Dict[str, Resource]: "torchx_run_args_from_json", "TorchXRunArgs", "ALL", + "TORCHX_HOME", + "Workspace", ] diff --git a/torchx/specs/api.py b/torchx/specs/api.py index f50bc4619..3ec7a11da 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -350,6 +350,78 @@ class DeviceMount: permissions: str = "rwm" +@dataclass +class Workspace: + """ + Specifies a local "workspace" (a set of directories). Workspaces are ad-hoc built + into an (usually ephemeral) image. 
This effectively mirrors the local code changes + at job submission time. + + For example: + + 1. ``projects={"~/github/torch": "torch"}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/torch/**`` + 2. ``projects={"~/github/torch": ""}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/**`` + + The exact location of ``$REMOTE_WORKSPACE_ROOT`` is implementation dependent and varies between + different implementations of :py:class:`~torchx.workspace.api.WorkspaceMixin`. + Check the scheduler documentation for details on which workspace it supports. + + Note: ``projects`` maps the location of the local project to a sub-directory in the remote workspace root directory. + Typically the local project location is a directory path (e.g. ``/home/foo/github/torch``). + + + Attributes: + projects: mapping of local project to the sub-dir in the remote workspace dir. + """ + + projects: dict[str, str] + + def __bool__(self) -> bool: + """False if no projects mapping. Lets us use workspace object in an if-statement""" + return bool(self.projects) + + def is_unmapped_single_project(self) -> bool: + """ + Returns ``True`` if this workspace only has 1 project + and its target mapping is an empty string. + """ + return len(self.projects) == 1 and not next(iter(self.projects.values())) + + @staticmethod + def from_str(workspace: str | None) -> "Workspace": + import yaml + + if not workspace: + return Workspace({}) + + projects = yaml.safe_load(workspace) + if isinstance(projects, str): # single project workspace + projects = {projects: ""} + else: # multi-project workspace + # Replace None mappings with "" (empty string) + projects = {k: ("" if v is None else v) for k, v in projects.items()} + + return Workspace(projects) + + def __str__(self) -> str: + """ + Returns a string representation of the Workspace by concatenating + the project mappings using ';' as a delimiter and ':' between key and value. 
+ If the single-project workspace with no target mapping, then simply + returns the src (local project dir) + + NOTE: meant to be used for logging purposes not serde. + Therefore not symmetric with :py:func:`Workspace.from_str`. + + """ + if self.is_unmapped_single_project(): + return next(iter(self.projects)) + else: + return ";".join( + k if not v else f"{k}:{v}" for k, v in self.projects.items() + ) + + @dataclass class Role: """ @@ -402,6 +474,10 @@ class Role: metadata: Free form information that is associated with the role, for example scheduler specific data. The key should follow the pattern: ``$scheduler.$key`` mounts: a list of mounts on the machine + workspace: local project directories to be mirrored on the remote job. + NOTE: The workspace argument provided to the :py:class:`~torchx.runner.api.Runner` APIs + only takes effect on ``appdef.role[0]`` and overrides this attribute. + """ name: str @@ -417,9 +493,10 @@ class Role: resource: Resource = field(default_factory=_null_resource) port_map: Dict[str, int] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict) - mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field( - default_factory=list - ) + mounts: List[BindMount | VolumeMount | DeviceMount] = field(default_factory=list) + workspace: Workspace | None = None + + # DEPRECATED DO NOT SET, WILL BE REMOVED SOON overrides: Dict[str, Any] = field(default_factory=dict) # pyre-ignore diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index c99d6f700..6bbacd5ee 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -19,8 +19,7 @@ from unittest import mock from unittest.mock import MagicMock -import torchx.specs.named_resources_aws as named_resources_aws -from torchx.specs import named_resources, resource +from torchx.specs import named_resources, named_resources_aws, resource from torchx.specs.api import ( _TERMINAL_STATES, AppDef, @@ -44,6 +43,7 @@ runopt, runopts, 
TORCHX_HOME, + Workspace, ) @@ -74,6 +74,81 @@ def test_TORCHX_HOME_override(self) -> None: self.assertTrue(conda_pack_out.is_dir()) +class WorkspaceTest(unittest.TestCase): + + def test_bool(self) -> None: + self.assertFalse(Workspace(projects={})) + self.assertFalse(Workspace.from_str("")) + + self.assertTrue(Workspace(projects={"/home/foo/bar": ""})) + self.assertTrue(Workspace.from_str("/home/foo/bar")) + + def test_to_string_single_project_workspace(self) -> None: + self.assertEqual( + "/home/foo/bar", + str(Workspace(projects={"/home/foo/bar": ""})), + ) + + def test_to_string_multi_project_workspace(self) -> None: + workspace = Workspace( + projects={ + "/home/foo/workspace/myproj": "", + "/home/foo/github/torch": "torch", + } + ) + + self.assertEqual( + "/home/foo/workspace/myproj;/home/foo/github/torch:torch", + str(workspace), + ) + + def test_is_unmapped_single_project_workspace(self) -> None: + self.assertTrue( + Workspace(projects={"/home/foo/bar": ""}).is_unmapped_single_project() + ) + + self.assertFalse( + Workspace(projects={"/home/foo/bar": "baz"}).is_unmapped_single_project() + ) + + self.assertFalse( + Workspace( + projects={"/home/foo/bar": "", "/home/foo/torch": ""} + ).is_unmapped_single_project() + ) + + self.assertFalse( + Workspace( + projects={"/home/foo/bar": "", "/home/foo/torch": "pytorch"} + ).is_unmapped_single_project() + ) + + def test_from_str_single_project(self) -> None: + self.assertDictEqual( + {"/home/foo/bar": ""}, + Workspace.from_str("/home/foo/bar").projects, + ) + + self.assertDictEqual( + {"/home/foo/bar": "baz"}, + Workspace.from_str("/home/foo/bar: baz").projects, + ) + + def test_from_str_multi_project(self) -> None: + self.assertDictEqual( + { + "/home/foo/bar": "", + "/home/foo/third-party/verl": "verl", + }, + Workspace.from_str( + """# +/home/foo/bar: +/home/foo/third-party/verl: verl +""" + ).projects, + ) + + class AppDryRunInfoTest(unittest.TestCase): def test_repr(self) -> None: request_mock = MagicMock() 
diff --git a/torchx/workspace/api.py b/torchx/workspace/api.py index 3a5ffbb5e..98b7c949d 100644 --- a/torchx/workspace/api.py +++ b/torchx/workspace/api.py @@ -26,7 +26,7 @@ Union, ) -from torchx.specs import AppDef, CfgVal, Role, runopts +from torchx.specs import AppDef, CfgVal, Role, runopts, Workspace if TYPE_CHECKING: from fsspec import AbstractFileSystem @@ -88,71 +88,6 @@ def build_workspace(self, sync: bool = True) -> PkgInfo[PackageType]: pass -@dataclass -class Workspace: - """ - Specifies a local "workspace" (a set of directories). Workspaces are ad-hoc built - into an (usually ephemeral) image. This effectively mirrors the local code changes - at job submission time. - - For example: - - 1. ``projects={"~/github/torch": "torch"}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/torch/**`` - 2. ``projects={"~/github/torch": ""}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/**`` - - The exact location of ``$REMOTE_WORKSPACE_ROOT`` is implementation dependent and varies between - different implementations of :py:class:`~torchx.workspace.api.WorkspaceMixin`. - Check the scheduler documentation for details on which workspace it supports. - - Note: ``projects`` maps the location of the local project to a sub-directory in the remote workspace root directory. - Typically the local project location is a directory path (e.g. ``/home/foo/github/torch``). - - - Attributes: - projects: mapping of local project to the sub-dir in the remote workspace dir. - """ - - projects: dict[str, str] - - def is_unmapped_single_project(self) -> bool: - """ - Returns ``True`` if this workspace only has 1 project - and its target mapping is an empty string. 
- """ - return len(self.projects) == 1 and not next(iter(self.projects.values())) - - @staticmethod - def from_str(workspace: str) -> "Workspace": - import yaml - - projects = yaml.safe_load(workspace) - if isinstance(projects, str): # single project workspace - projects = {projects: ""} - else: # multi-project workspace - # Replace None mappings with "" (empty string) - projects = {k: ("" if v is None else v) for k, v in projects.items()} - - return Workspace(projects) - - def __str__(self) -> str: - """ - Returns a string representation of the Workspace by concatenating - the project mappings using ';' as a delimiter and ':' between key and value. - If the single-project workspace with no target mapping, then simply - returns the src (local project dir) - - NOTE: meant to be used for logging purposes not serde. - Therefore not symmetric with :py:func:`Workspace.from_str`. - - """ - if self.is_unmapped_single_project(): - return next(iter(self.projects)) - else: - return ";".join( - k if not v else f"{k}:{v}" for k, v in self.projects.items() - ) - - class WorkspaceMixin(abc.ABC, Generic[T]): """ Note: (Prototype) this interface may change without notice! 
diff --git a/torchx/workspace/test/api_test.py b/torchx/workspace/test/api_test.py index 352d48530..2daf51f5e 100644 --- a/torchx/workspace/test/api_test.py +++ b/torchx/workspace/test/api_test.py @@ -34,70 +34,15 @@ def build_workspace_and_update_role( class WorkspaceTest(TestWithTmpDir): - def test_to_string_single_project_workspace(self) -> None: - self.assertEqual( - "/home/foo/bar", - str(Workspace(projects={"/home/foo/bar": ""})), - ) - - def test_to_string_multi_project_workspace(self) -> None: - workspace = Workspace( - projects={ - "/home/foo/workspace/myproj": "", - "/home/foo/github/torch": "torch", - } - ) - - self.assertEqual( - "/home/foo/workspace/myproj;/home/foo/github/torch:torch", - str(workspace), - ) - - def test_is_unmapped_single_project_workspace(self) -> None: - self.assertTrue( - Workspace(projects={"/home/foo/bar": ""}).is_unmapped_single_project() - ) - - self.assertFalse( - Workspace(projects={"/home/foo/bar": "baz"}).is_unmapped_single_project() - ) - - self.assertFalse( - Workspace( - projects={"/home/foo/bar": "", "/home/foo/torch": ""} - ).is_unmapped_single_project() - ) - - self.assertFalse( - Workspace( - projects={"/home/foo/bar": "", "/home/foo/torch": "pytorch"} - ).is_unmapped_single_project() - ) - - def test_from_str_single_project(self) -> None: - self.assertDictEqual( - {"/home/foo/bar": ""}, - Workspace.from_str("/home/foo/bar").projects, - ) - - self.assertDictEqual( - {"/home/foo/bar": "baz"}, - Workspace.from_str("/home/foo/bar: baz").projects, - ) + def build_workspace_and_update_role( + self, role: Role, workspace: str, cfg: Mapping[str, CfgVal] + ) -> None: + role.image = "bar" + role.metadata["workspace"] = workspace - def test_from_str_multi_project(self) -> None: - self.assertDictEqual( - { - "/home/foo/bar": "", - "/home/foo/third-party/verl": "verl", - }, - Workspace.from_str( - """# -/home/foo/bar: -/home/foo/third-party/verl: verl -""" - ).projects, - ) + if not workspace.startswith("//"): + # to validate 
the merged workspace dir copy its content to the tmpdir + shutil.copytree(workspace, self.tmpdir) def test_build_and_update_role2_str_workspace(self) -> None: proj = self.tmpdir / "github" / "torch"