Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions acme/utils/lp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import functools
import inspect
import os
import signal
import sys
import time
from typing import Any, Callable, Optional
Expand Down Expand Up @@ -227,3 +228,117 @@ def make_xm_docker_resources(program,
python_path=python_path)

return xm_resources


class LaunchpadProgramStopper:
  """Context manager that stops a Launchpad program on SIGTERM or SIGINT.

  While the context is active, a handler is installed for SIGTERM and SIGINT.
  When either signal arrives the handler calls lp.stop() and then defers to
  whatever disposition was in place before the context was entered, so the
  process still reacts to the signal the way it would have without this
  context manager (e.g. it terminates, or raises KeyboardInterrupt).

  This is intended for running Acme distributed experiments under external
  schedulers, such as Ray Tune's ASHA scheduler, which may terminate trials
  early.

  Python only allows signal handlers to be installed from the main thread. If
  the context is entered from another thread, a warning is logged and the
  context manager does nothing.

  Example usage with Ray Tune:

    def train_function(config):
      experiment = build_experiment_config(config)
      program = experiments.make_distributed_experiment(
          experiment=experiment,
          num_actors=1,
      )
      with LaunchpadProgramStopper():
        lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING)

    tuner = tune.Tuner(
        train_function,
        tune_config=tune.TuneConfig(
            scheduler=ASHAScheduler(...),
        ),
    )
  """

  def __init__(self):
    # Dispositions that were active before __enter__; restored by __exit__.
    self._original_sigterm_handler = None
    self._original_sigint_handler = None
    # True only between a successful __enter__ and the matching __exit__.
    self._handlers_installed = False

  @staticmethod
  def _restorable(handler):
    """Returns a value that signal.signal() accepts for restoring `handler`."""
    # signal.signal() returns None when the previous handler was not installed
    # from Python. None cannot be passed back to signal.signal(), so fall back
    # to the default disposition instead of raising TypeError.
    return signal.SIG_DFL if handler is None else handler

  def _original_handler_for(self, signum):
    """Returns the disposition saved for `signum`, or None if none was saved."""
    if signum == signal.SIGTERM:
      return self._original_sigterm_handler
    if signum == signal.SIGINT:
      return self._original_sigint_handler
    return None

  def _signal_handler(self, signum, frame):
    """Stops the Launchpad program, then defers to the original disposition.

    Args:
      signum: Number of the signal that was received.
      frame: Stack frame that was interrupted by the signal; forwarded to the
        original handler when that handler is a Python callable.
    """
    logging.info(
        'LaunchpadProgramStopper: Received signal %d, stopping program...',
        signum)
    try:
      # Avoid importing Launchpad until it is actually used. The import lives
      # inside the try block so that an import failure cannot prevent the
      # original disposition from running below.
      import launchpad as lp  # pylint: disable=g-import-not-at-top
      lp.stop()
    except Exception as e:  # pylint: disable=broad-except
      logging.warning('LaunchpadProgramStopper: Error stopping program: %s', e)

    original = self._original_handler_for(signum)
    if original == signal.SIG_DFL:
      # The default disposition is not a Python callable. Reinstate it and
      # deliver the signal to this process again so the default action runs.
      signal.signal(signum, signal.SIG_DFL)
      os.kill(os.getpid(), signum)
    elif callable(original):
      original(signum, frame)
    # Otherwise the signal was ignored (SIG_IGN) or the previous handler is
    # unknown (None): there is nothing to chain to.

  def __enter__(self):
    """Installs the stop handler for SIGTERM and SIGINT and returns self."""
    try:
      previous_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
    except ValueError as e:
      # signal.signal() raises ValueError outside the main thread. Degrade to
      # a no-op rather than failing the launch that this context wraps.
      logging.warning(
          'LaunchpadProgramStopper: Could not install signal handlers (%s); '
          'termination signals will not stop the Launchpad program.', e)
      return self
    self._original_sigterm_handler = previous_sigterm
    self._original_sigint_handler = signal.signal(
        signal.SIGINT, self._signal_handler)
    self._handlers_installed = True
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Restores the saved dispositions; never suppresses exceptions."""
    if self._handlers_installed:
      self._handlers_installed = False
      signal.signal(
          signal.SIGTERM, self._restorable(self._original_sigterm_handler))
      signal.signal(
          signal.SIGINT, self._restorable(self._original_sigint_handler))
    return False


def launch_with_termination_handler(
    program,
    launch_type: str = 'local_multi_processing',
    **launch_kwargs,
):
  """Runs lp.launch() inside a LaunchpadProgramStopper context.

  The launch call is wrapped in LaunchpadProgramStopper so that its SIGTERM
  and SIGINT handling is active for the duration of lp.launch(). This is aimed
  at Acme distributed experiments driven by external schedulers such as Ray
  Tune's ASHA scheduler.

  Args:
    program: The Launchpad program to launch.
    launch_type: Launch type forwarded to lp.launch() (e.g.
      'local_multi_processing' or 'local_multi_threading'). Defaults to
      'local_multi_processing'.
    **launch_kwargs: Extra keyword arguments forwarded unchanged to
      lp.launch().

  Returns:
    Whatever lp.launch() returns.

  Example usage with Ray Tune:

    def train_function(config):
      experiment = build_experiment_config(config)
      program = experiments.make_distributed_experiment(
          experiment=experiment,
          num_actors=1,
      )
      launch_with_termination_handler(program)

    tuner = tune.Tuner(
        train_function,
        tune_config=tune.TuneConfig(
            scheduler=ASHAScheduler(...),
        ),
    )
  """
  # Deferred import: Launchpad is only loaded once a launch is requested.
  import launchpad as lp  # pylint: disable=g-import-not-at-top

  stopper = LaunchpadProgramStopper()
  with stopper:
    launch_result = lp.launch(
        program, launch_type=launch_type, **launch_kwargs)
  return launch_result