diff --git a/acme/utils/lp_utils.py b/acme/utils/lp_utils.py index 354c0b0d62..125b42e13e 100644 --- a/acme/utils/lp_utils.py +++ b/acme/utils/lp_utils.py @@ -18,6 +18,7 @@ import functools import inspect import os +import signal import sys import time from typing import Any, Callable, Optional @@ -227,3 +228,117 @@ def make_xm_docker_resources(program, python_path=python_path) return xm_resources + + +class LaunchpadProgramStopper: + """Context manager for gracefully stopping Launchpad programs on termination. + + This is useful when running Acme distributed experiments with external + schedulers like Ray Tune's ASHA scheduler, which may terminate trials early. + When the parent process receives a termination signal (SIGTERM or SIGINT), + this context manager ensures that all Launchpad processes are properly + stopped via lp.stop(). + + Example usage with Ray Tune: + + def train_function(config): + experiment = build_experiment_config(config) + program = experiments.make_distributed_experiment( + experiment=experiment, + num_actors=1, + ) + with LaunchpadProgramStopper(): + lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING) + + # Ray Tune will properly terminate all Launchpad processes + tuner = tune.Tuner( + train_function, + tune_config=tune.TuneConfig( + scheduler=ASHAScheduler(...), + ), + ) + """ + + def __init__(self): + self._original_sigterm_handler = None + self._original_sigint_handler = None + + def _signal_handler(self, signum, frame): + """Handle termination signals by stopping the Launchpad program.""" + del frame # Unused. + logging.info( + 'LaunchpadProgramStopper: Received signal %d, stopping program...', + signum) + # Avoid importing Launchpad until it is actually used. + import launchpad as lp # pylint: disable=g-import-not-at-top + try: + lp.stop() + except Exception as e: # pylint: disable=broad-except + logging.warning('LaunchpadProgramStopper: Error stopping program: %s', e) + # Re-raise the signal to allow the process to terminate. + if signum == signal.SIGTERM and self._original_sigterm_handler: + if callable(self._original_sigterm_handler): + self._original_sigterm_handler(signum, None) + elif signum == signal.SIGINT and self._original_sigint_handler: + if callable(self._original_sigint_handler): + self._original_sigint_handler(signum, None) + + def __enter__(self): + # Save original handlers. + self._original_sigterm_handler = signal.signal( + signal.SIGTERM, self._signal_handler) + self._original_sigint_handler = signal.signal( + signal.SIGINT, self._signal_handler) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Restore original handlers. + signal.signal(signal.SIGTERM, self._original_sigterm_handler) + signal.signal(signal.SIGINT, self._original_sigint_handler) + return False + + +def launch_with_termination_handler( + program, + launch_type: str = 'local_multi_processing', + **launch_kwargs, +): + """Launch a Launchpad program with proper termination signal handling. + + This function wraps lp.launch() with a signal handler that ensures all + Launchpad processes are properly terminated when the parent process receives + a SIGTERM or SIGINT signal. This is particularly useful when running Acme + distributed experiments with external schedulers like Ray Tune's ASHA + scheduler. + + Args: + program: The Launchpad program to launch. + launch_type: The type of launch (e.g., 'local_multi_processing', + 'local_multi_threading'). Defaults to 'local_multi_processing'. + **launch_kwargs: Additional keyword arguments to pass to lp.launch(). + + Returns: + The result of lp.launch(). + + Example usage with Ray Tune: + + def train_function(config): + experiment = build_experiment_config(config) + program = experiments.make_distributed_experiment( + experiment=experiment, + num_actors=1, + ) + launch_with_termination_handler(program) + + tuner = tune.Tuner( + train_function, + tune_config=tune.TuneConfig( + scheduler=ASHAScheduler(...), + ), + ) + """ + # Avoid importing Launchpad until it is actually used. + import launchpad as lp # pylint: disable=g-import-not-at-top + + with LaunchpadProgramStopper(): + return lp.launch(program, launch_type=launch_type, **launch_kwargs)