diff --git a/torchx/runner/__init__.py b/torchx/runner/__init__.py index 2c95eca4a..3f8429f45 100644 --- a/torchx/runner/__init__.py +++ b/torchx/runner/__init__.py @@ -7,4 +7,9 @@ # pyre-strict -from torchx.runner.api import get_runner, Runner # noqa: F401 F403 +from torchx.runner.api import ( # noqa: F401 F403 + ComponentRunner, + get_component_runner, + get_runner, + Runner, +) diff --git a/torchx/runner/api.py b/torchx/runner/api.py index fe2b9ca4b..e67f9db59 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -175,6 +175,10 @@ def run_component( parent_run_id: Optional[str] = None, ) -> AppHandle: """ + + .. warning:: This method will be deprecated in the future. It has been + moved to ``run_component`` from ``ComponentRunner`` which provides the same functionality. + Runs a component. ``component`` has the following resolution order(high to low): @@ -205,6 +209,11 @@ def run_component( ComponentValidationException: if component is invalid. ComponentNotFoundException: if the ``component_path`` is failed to resolve. """ + warnings.warn( + "Runner's run_component will be deprecated in the future. Use the ComponentRunner's run_component instead. ", + PendingDeprecationWarning, + stacklevel=2, + ) with log_event("run_component") as ctx: dryrun_info = self.dryrun_component( @@ -235,9 +244,18 @@ def dryrun_component( parent_run_id: Optional[str] = None, ) -> AppDryRunInfo: """ + .. warning:: This method will be deprecated in the future. It has been + moved to ``dryrun_component`` from ``ComponentRunner`` which provides the same functionality. + Dryrun version of :py:func:`run_component`. Will not actually run the component, but just returns what "would" have run. """ + + warnings.warn( + "Runner's dryrun_component will be deprecated in the future. Use the ComponentRunner's dryrun_component instead. ", + PendingDeprecationWarning, + stacklevel=2, + ) component_def = get_component(component) args_from_cli = component_args if isinstance(component_args, list) else [] args_from_json = component_args if isinstance(component_args, dict) else {} @@ -831,3 +849,174 @@ def get_runner( return Runner( name, scheduler_factories, component_defaults, scheduler_params=scheduler_params ) + + +class ComponentRunner: + """ + TorchX component runner. This class is a wrapper around the Runner class + that provides a higher level API for running components. + """ + + def __init__( + self, + runner: Runner, + component_defaults: Optional[Dict[str, Dict[str, str]]] = None, + ) -> None: + """ + Creates a new component runner instance. + + Args: + runner: runner instance to use for running components + component_defaults: defaults to use for the component runs + """ + # pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type + self._scheduler_instances: Dict[str, Scheduler] = {} + self._apps: Dict[AppHandle, AppDef] = {} + self._runner = runner + + # component_name -> map of component_fn_param_name -> user-specified default val encoded as str + self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {} + + def __enter__(self) -> "Self": + return self + + def __exit__(self, *args: Any) -> bool: + self.runner.close() + return False + + @property + def runner(self) -> Runner: + return self._runner + + def run_component( + self, + component: str, + component_args: Union[list[str], dict[str, Any]], + scheduler: str, + cfg: Optional[Mapping[str, CfgVal]] = None, + workspace: Optional[Union[Workspace, str]] = None, + parent_run_id: Optional[str] = None, + ) -> AppHandle: + """ + Runs a component. + + ``component`` has the following resolution order(high to low): + * User-registered components. Users can register components via + https://packaging.python.org/specifications/entry-points/. Method looks for + entrypoints in the group ``torchx.components``. + * Builtin components relative to `torchx.components`. The path to the component should + be module name relative to `torchx.components` and function name in a format: + ``$module.$function``. + * File-based components in format: ``$FILE_PATH:FUNCTION_NAME``. Both relative and + absolute paths supported. + + Usage: + + .. code-block:: python + + # resolved to torchx.components.distributed.ddp() + component_runner.run_component("distributed.ddp", ...) + + # resolved to my_component() function in ~/home/components.py + component_runner.run_component("~/home/components.py:my_component", ...) + + + Returns: + An application handle that is used to call other action APIs on the app + + Raises: + ComponentValidationException: if component is invalid. + ComponentNotFoundException: if the ``component_path`` is failed to resolve. + """ + + with log_event("run_component") as ctx: + dryrun_info = self.dryrun_component( + component, + component_args, + scheduler, + cfg=cfg, + workspace=workspace, + parent_run_id=parent_run_id, + ) + + handle = self._runner.schedule(dryrun_info) + app = none_throws(dryrun_info._app) + + ctx._torchx_event.workspace = str(workspace) + ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler) + ctx._torchx_event.app_image = app.roles[0].image + ctx._torchx_event.app_id = parse_app_handle(handle)[2] + ctx._torchx_event.app_metadata = app.metadata + return handle + + def dryrun_component( + self, + component: str, + component_args: Union[list[str], dict[str, Any]], + scheduler: str, + cfg: Optional[Mapping[str, CfgVal]] = None, + workspace: Optional[Union[Workspace, str]] = None, + parent_run_id: Optional[str] = None, + ) -> AppDryRunInfo: + """ + Dryrun version of :py:func:`run_component`. Will not actually run the + component, but just returns what "would" have run. + """ + component_def = get_component(component) + args_from_cli = component_args if isinstance(component_args, list) else [] + args_from_json = component_args if isinstance(component_args, dict) else {} + app = materialize_appdef( + component_def.fn, + args_from_cli, + self._component_defaults.get(component, None), + args_from_json, + ) + return self._runner.dryrun( + app, + scheduler, + cfg=cfg, + workspace=workspace, + parent_run_id=parent_run_id, + ) + + +def get_component_runner( + name: Optional[str] = None, + component_defaults: Optional[Dict[str, Dict[str, str]]] = None, + **scheduler_params: Any, +) -> ComponentRunner: + """ + Convenience method to construct and get a ComponentRunner object. Usage: + + .. code-block:: python + + with get_component_runner() as component_runner: + app_handle = component_runner.run_component(name=name, component_defaults=component_defaults, + scheduler="kubernetes", runcfg) + + Alternatively, + + .. code-block:: python + + component_runner = get_component_runner() + try: + app_handle = component_runner.run_component(name=name, component_defaults=component_defaults, + scheduler="kubernetes", runcfg) + finally: + component_runner.close() + + Args: + name: human readable name that will be included as part of all launched + jobs. + component_defaults: a map of component_name to map of component_fn_param_name + to user-specified default val encoded as str. + scheduler_params: extra arguments that will be passed to the constructor + of all available schedulers. + + + """ + runner = get_runner(name, **scheduler_params) + return ComponentRunner( + runner, + component_defaults, + ) diff --git a/torchx/runner/test/api_test.py b/torchx/runner/test/api_test.py index 118057eb8..e60020c46 100644 --- a/torchx/runner/test/api_test.py +++ b/torchx/runner/test/api_test.py @@ -14,6 +14,7 @@ from unittest.mock import MagicMock, patch from torchx.runner import get_runner, Runner +from torchx.runner.api import ComponentRunner from torchx.schedulers import SchedulerFactory from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Scheduler from torchx.schedulers.local_scheduler import ( @@ -789,3 +790,128 @@ def test_cfg_from_str(self, _) -> None: "enable=True,disable=False,complex_list=v1;v2;v3", ), ) + + +class ComponentRunnerTest(TestWithTmpDir): + def setUp(self) -> None: + super().setUp() + self.mock_scheduler = MagicMock() + self.scheduler_factories = {"local": lambda name, **kwargs: self.mock_scheduler} + self.runner = Runner( + name=SESSION_NAME, scheduler_factories=self.scheduler_factories + ) + + def tearDown(self) -> None: + self.runner.close() + super().tearDown() + + def test_component_runner_init(self) -> None: + component_defaults = {"test.component": {"arg1": "value1"}} + component_runner = ComponentRunner( + runner=self.runner, component_defaults=component_defaults + ) + + self.assertEqual(component_runner._runner, self.runner) + self.assertEqual(component_runner._component_defaults, component_defaults) + + def test_component_runner_context_manager(self) -> None: + with ComponentRunner(runner=self.runner) as component_runner: + self.assertIsInstance(component_runner, ComponentRunner) + self.assertEqual(component_runner._runner, self.runner) + + @patch("torchx.runner.api.get_component") + @patch("torchx.runner.api.materialize_appdef") + def test_component_runner_dryrun_component( + self, materialize_mock: MagicMock, get_component_mock: MagicMock + ) -> None: + mock_component_def = MagicMock() + mock_component_def.fn = MagicMock() + get_component_mock.return_value = mock_component_def + + mock_app = AppDef( + "test_app", + roles=[ + Role( + name="test_role", + image="test_image", + resource=resource.SMALL, + entrypoint="echo", + args=["hello"], + ) + ], + ) + materialize_mock.return_value = mock_app + + component_defaults = {"test.component": {"arg1": "default_value"}} + component_runner = ComponentRunner( + runner=self.runner, component_defaults=component_defaults + ) + + mock_dryrun_info = AppDryRunInfo(mock_app, lambda x: x) + with patch.object(self.runner, "dryrun", return_value=mock_dryrun_info): + component_args = ["--arg1", "value1"] + result = component_runner.dryrun_component( + component="test.component", + component_args=component_args, + scheduler="local", + cfg={"key": "value"}, + ) + + get_component_mock.assert_called_once_with("test.component") + + materialize_mock.assert_called_once_with( + mock_component_def.fn, + ["--arg1", "value1"], + {"arg1": "default_value"}, + {}, + ) + + self.assertEqual(result, mock_dryrun_info) + + @patch("torchx.runner.api.log_event") + @patch("torchx.runner.api.get_component") + @patch("torchx.runner.api.materialize_appdef") + def test_component_runner_run_component( + self, + materialize_mock: MagicMock, + get_component_mock: MagicMock, + log_event_mock: MagicMock, + ) -> None: + mock_component_def = MagicMock() + mock_component_def.fn = MagicMock() + get_component_mock.return_value = mock_component_def + + mock_app = AppDef( + "test_app", + roles=[ + Role( + name="test_role", + image="test_image", + resource=resource.SMALL, + entrypoint="echo", + args=["hello"], + ) + ], + ) + materialize_mock.return_value = mock_app + + component_runner = ComponentRunner(runner=self.runner) + + mock_dryrun_info = AppDryRunInfo(mock_app, lambda x: x) + mock_dryrun_info._scheduler = "local" + mock_dryrun_info._app = mock_app + mock_app_handle = "local://test_session/test_app_id" + + mock_schedule = MagicMock(return_value=mock_app_handle) + with patch.object( + self.runner, "dryrun", return_value=mock_dryrun_info + ), patch.object(self.runner, "schedule", mock_schedule): + result = component_runner.run_component( + component="test.component", + component_args=["--arg1", "value1"], + scheduler="local", + ) + + self.assertEqual(result, mock_app_handle) + + mock_schedule.assert_called_once_with(mock_dryrun_info)