Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion torchx/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,9 @@

# pyre-strict

from torchx.runner.api import get_runner, Runner # noqa: F401 F403
from torchx.runner.api import ( # noqa: F401 F403
ComponentRunner,
get_component_runner,
get_runner,
Runner,
)
189 changes: 189 additions & 0 deletions torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ def run_component(
parent_run_id: Optional[str] = None,
) -> AppHandle:
"""

.. warning:: This method will be deprecated in the future. It has been
moved to ``run_component`` from ``ComponentRunner`` which provides the same functionality.

Runs a component.

``component`` has the following resolution order(high to low):
Expand Down Expand Up @@ -205,6 +209,11 @@ def run_component(
ComponentValidationException: if component is invalid.
ComponentNotFoundException: if the ``component_path`` is failed to resolve.
"""
warnings.warn(
"Runner's run_component will be deprecated in the future. Use the ComponentRunner's run_component instead. ",
PendingDeprecationWarning,
stacklevel=2,
)

with log_event("run_component") as ctx:
dryrun_info = self.dryrun_component(
Expand Down Expand Up @@ -235,9 +244,18 @@ def dryrun_component(
parent_run_id: Optional[str] = None,
) -> AppDryRunInfo:
"""
.. warning:: This method will be deprecated in the future. It has been
moved to ``dryrun_component`` from ``ComponentRunner`` which provides the same functionality.

Dryrun version of :py:func:`run_component`. Will not actually run the
component, but just returns what "would" have run.
"""

warnings.warn(
"Runner's dryrun_component will be deprecated in the future. Use the ComponentRunner's dryrun_component instead. ",
PendingDeprecationWarning,
stacklevel=2,
)
component_def = get_component(component)
args_from_cli = component_args if isinstance(component_args, list) else []
args_from_json = component_args if isinstance(component_args, dict) else {}
Expand Down Expand Up @@ -831,3 +849,174 @@ def get_runner(
return Runner(
name, scheduler_factories, component_defaults, scheduler_params=scheduler_params
)


class ComponentRunner:
"""
TorchX component runner. This class is a wrapper around the Runner class
that provides a higher level API for running components.
"""

def __init__(
self,
runner: Runner,
component_defaults: Optional[Dict[str, Dict[str, str]]] = None,
) -> None:
"""
Creates a new component runner instance.

Args:
runner: runner instance to use for running components
component_defaults: defaults to use for the component runs
"""
# pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type
self._scheduler_instances: Dict[str, Scheduler] = {}
self._apps: Dict[AppHandle, AppDef] = {}
self._runner = runner

# component_name -> map of component_fn_param_name -> user-specified default val encoded as str
self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {}

def __enter__(self) -> "Self":
return self

def __exit__(self, *args: Any) -> bool:
self.runner.close()
return False

@property
def runner(self) -> Runner:
return self._runner

def run_component(
self,
component: str,
component_args: Union[list[str], dict[str, Any]],
scheduler: str,
cfg: Optional[Mapping[str, CfgVal]] = None,
workspace: Optional[Union[Workspace, str]] = None,
parent_run_id: Optional[str] = None,
) -> AppHandle:
"""
Runs a component.

``component`` has the following resolution order(high to low):
* User-registered components. Users can register components via
https://packaging.python.org/specifications/entry-points/. Method looks for
entrypoints in the group ``torchx.components``.
* Builtin components relative to `torchx.components`. The path to the component should
be module name relative to `torchx.components` and function name in a format:
``$module.$function``.
* File-based components in format: ``$FILE_PATH:FUNCTION_NAME``. Both relative and
absolute paths supported.

Usage:

.. code-block:: python

# resolved to torchx.components.distributed.ddp()
component_runner.run_component("distributed.ddp", ...)

# resolved to my_component() function in ~/home/components.py
component_runner.run_component("~/home/components.py:my_component", ...)


Returns:
An application handle that is used to call other action APIs on the app

Raises:
ComponentValidationException: if component is invalid.
ComponentNotFoundException: if the ``component_path`` is failed to resolve.
"""

with log_event("run_component") as ctx:
dryrun_info = self.dryrun_component(
component,
component_args,
scheduler,
cfg=cfg,
workspace=workspace,
parent_run_id=parent_run_id,
)

handle = self._runner.schedule(dryrun_info)
app = none_throws(dryrun_info._app)

ctx._torchx_event.workspace = str(workspace)
ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
ctx._torchx_event.app_image = app.roles[0].image
ctx._torchx_event.app_id = parse_app_handle(handle)[2]
ctx._torchx_event.app_metadata = app.metadata
return handle

def dryrun_component(
self,
component: str,
component_args: Union[list[str], dict[str, Any]],
scheduler: str,
cfg: Optional[Mapping[str, CfgVal]] = None,
workspace: Optional[Union[Workspace, str]] = None,
parent_run_id: Optional[str] = None,
) -> AppDryRunInfo:
"""
Dryrun version of :py:func:`run_component`. Will not actually run the
component, but just returns what "would" have run.
"""
component_def = get_component(component)
args_from_cli = component_args if isinstance(component_args, list) else []
args_from_json = component_args if isinstance(component_args, dict) else {}
app = materialize_appdef(
component_def.fn,
args_from_cli,
self._component_defaults.get(component, None),
args_from_json,
)
return self._runner.dryrun(
app,
scheduler,
cfg=cfg,
workspace=workspace,
parent_run_id=parent_run_id,
)


def get_component_runner(
name: Optional[str] = None,
component_defaults: Optional[Dict[str, Dict[str, str]]] = None,
**scheduler_params: Any,
) -> ComponentRunner:
"""
Convenience method to construct and get a ComponentRunner object. Usage:

.. code-block:: python

with get_component_runner() as component_runner:
app_handle = component_runner.run_component(name=name, component_defaults=component_defaults,
scheduler="kubernetes", runcfg)

Alternatively,

.. code-block:: python

component_runner = get_component_runner()
try:
app_handle = component_runner.run_component(name=name, component_defaults=component_defaults,
scheduler="kubernetes", runcfg)
finally:
component_runner.close()

Args:
name: human readable name that will be included as part of all launched
jobs.
component_defaults: a map of component_name to map of component_fn_param_name
to user-specified default val encoded as str.
scheduler_params: extra arguments that will be passed to the constructor
of all available schedulers.


"""
runner = get_runner(name, **scheduler_params)
return ComponentRunner(
runner,
component_defaults,
)
126 changes: 126 additions & 0 deletions torchx/runner/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from unittest.mock import MagicMock, patch

from torchx.runner import get_runner, Runner
from torchx.runner.api import ComponentRunner
from torchx.schedulers import SchedulerFactory
from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Scheduler
from torchx.schedulers.local_scheduler import (
Expand Down Expand Up @@ -789,3 +790,128 @@ def test_cfg_from_str(self, _) -> None:
"enable=True,disable=False,complex_list=v1;v2;v3",
),
)


class ComponentRunnerTest(TestWithTmpDir):
def setUp(self) -> None:
super().setUp()
self.mock_scheduler = MagicMock()
self.scheduler_factories = {"local": lambda name, **kwargs: self.mock_scheduler}
self.runner = Runner(
name=SESSION_NAME, scheduler_factories=self.scheduler_factories
)

def tearDown(self) -> None:
self.runner.close()
super().tearDown()

def test_component_runner_init(self) -> None:
component_defaults = {"test.component": {"arg1": "value1"}}
component_runner = ComponentRunner(
runner=self.runner, component_defaults=component_defaults
)

self.assertEqual(component_runner._runner, self.runner)
self.assertEqual(component_runner._component_defaults, component_defaults)

def test_component_runner_context_manager(self) -> None:
with ComponentRunner(runner=self.runner) as component_runner:
self.assertIsInstance(component_runner, ComponentRunner)
self.assertEqual(component_runner._runner, self.runner)

@patch("torchx.runner.api.get_component")
@patch("torchx.runner.api.materialize_appdef")
def test_component_runner_dryrun_component(
self, materialize_mock: MagicMock, get_component_mock: MagicMock
) -> None:
mock_component_def = MagicMock()
mock_component_def.fn = MagicMock()
get_component_mock.return_value = mock_component_def

mock_app = AppDef(
"test_app",
roles=[
Role(
name="test_role",
image="test_image",
resource=resource.SMALL,
entrypoint="echo",
args=["hello"],
)
],
)
materialize_mock.return_value = mock_app

component_defaults = {"test.component": {"arg1": "default_value"}}
component_runner = ComponentRunner(
runner=self.runner, component_defaults=component_defaults
)

mock_dryrun_info = AppDryRunInfo(mock_app, lambda x: x)
with patch.object(self.runner, "dryrun", return_value=mock_dryrun_info):
component_args = ["--arg1", "value1"]
result = component_runner.dryrun_component(
component="test.component",
component_args=component_args,
scheduler="local",
cfg={"key": "value"},
)

get_component_mock.assert_called_once_with("test.component")

materialize_mock.assert_called_once_with(
mock_component_def.fn,
["--arg1", "value1"],
{"arg1": "default_value"},
{},
)

self.assertEqual(result, mock_dryrun_info)

@patch("torchx.runner.api.log_event")
@patch("torchx.runner.api.get_component")
@patch("torchx.runner.api.materialize_appdef")
def test_component_runner_run_component(
self,
materialize_mock: MagicMock,
get_component_mock: MagicMock,
log_event_mock: MagicMock,
) -> None:
mock_component_def = MagicMock()
mock_component_def.fn = MagicMock()
get_component_mock.return_value = mock_component_def

mock_app = AppDef(
"test_app",
roles=[
Role(
name="test_role",
image="test_image",
resource=resource.SMALL,
entrypoint="echo",
args=["hello"],
)
],
)
materialize_mock.return_value = mock_app

component_runner = ComponentRunner(runner=self.runner)

mock_dryrun_info = AppDryRunInfo(mock_app, lambda x: x)
mock_dryrun_info._scheduler = "local"
mock_dryrun_info._app = mock_app
mock_app_handle = "local://test_session/test_app_id"

mock_schedule = MagicMock(return_value=mock_app_handle)
with patch.object(
self.runner, "dryrun", return_value=mock_dryrun_info
), patch.object(self.runner, "schedule", mock_schedule):
result = component_runner.run_component(
component="test.component",
component_args=["--arg1", "value1"],
scheduler="local",
)

self.assertEqual(result, mock_app_handle)

mock_schedule.assert_called_once_with(mock_dryrun_info)