From 0a579e3ce8d947bd76f9ef0e86feb1d74e97c8ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96zhan=20=C3=96zen?= Date: Tue, 26 Aug 2025 15:44:53 +0200 Subject: [PATCH 1/2] Adds early stopping support with stoppers --- scripts/reinforcement_learning/ray/tuner.py | 35 ++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/scripts/reinforcement_learning/ray/tuner.py b/scripts/reinforcement_learning/ray/tuner.py index c9d5d6e20b9..2c05579b4d4 100644 --- a/scripts/reinforcement_learning/ray/tuner.py +++ b/scripts/reinforcement_learning/ray/tuner.py @@ -14,6 +14,7 @@ from ray import air, tune from ray.tune.search.optuna import OptunaSearch from ray.tune.search.repeater import Repeater +from ray.tune.stopper import CombinedStopper """ This script breaks down an aggregate tuning job, as defined by a hyperparameter sweep configuration, @@ -203,13 +204,18 @@ def stop_all(self): return False -def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None: +def invoke_tuning_run( + cfg: dict, + args: argparse.Namespace, + stopper: tune.Stopper | None = None, +) -> None: """Invoke an Isaac-Ray tuning run. Log either to a local directory or to MLFlow. Args: cfg: Configuration dictionary extracted from job setup args: Command-line arguments related to tuning. + stopper: Custom stopper, optional. 
""" # Allow for early exit os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1" @@ -237,6 +243,12 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None: ) repeat_search = Repeater(searcher, repeat=args.repeat_run_count) + # Configure the stoppers + stoppers: CombinedStopper = CombinedStopper(*[ + LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS), + *([stopper] if stopper is not None else []), + ]) + if args.run_mode == "local": # Standard config, to file run_config = air.RunConfig( storage_path="/tmp/ray", @@ -246,7 +258,7 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None: checkpoint_frequency=0, # Disable periodic checkpointing checkpoint_at_end=False, # Disable final checkpoint ), - stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS), + stop=stoppers, ) elif args.run_mode == "remote": # MLFlow, to MLFlow server @@ -262,7 +274,7 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None: storage_path="/tmp/ray", callbacks=[mlflow_callback], checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=0, checkpoint_at_end=False), - stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS), + stop=stoppers, ) else: raise ValueError("Unrecognized run mode.") @@ -399,6 +411,12 @@ def __init__(self, cfg: dict): default=MAX_LOG_EXTRACTION_ERRORS, help="Max number number of LogExtractionError failures before we abort the whole tuning run.", ) + parser.add_argument( + "--stopper", + type=str, + default=None, + help="A stop criteria in the cfg_file, must be a tune.Stopper instance.", + ) args = parser.parse_args() PROCESS_RESPONSE_TIMEOUT = args.process_response_timeout @@ -457,7 +475,16 @@ def __init__(self, cfg: dict): print(f"[INFO]: Successfully instantiated class '{class_name}' from {file_path}") cfg = instance.cfg print(f"[INFO]: Grabbed the following hyperparameter sweep config: \n {cfg}") - invoke_tuning_run(cfg, args) + # Load optional stopper config + stopper = None + 
+        if args.stopper:
+            stopper = getattr(module, args.stopper, None)  # None falls through to the TypeError below
+            if isinstance(stopper, type) and issubclass(stopper, tune.Stopper):
+                stopper = stopper()
+            else:
+                raise TypeError(f"[ERROR]: Stopper '{args.stopper}' missing or not a tune.Stopper subclass: {type(stopper)}")
+            print(f"[INFO]: Loaded custom stop criteria from '{args.stopper}'")
+        invoke_tuning_run(cfg, args, stopper=stopper)
     else:
         raise AttributeError(f"[ERROR]:Class '{class_name}' not found in {file_path}")

From 4305f8506ff86c891f45d75f67cf3797f107d9bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=96zhan=20=C3=96zen?=
Date: Tue, 26 Aug 2025 15:45:09 +0200
Subject: [PATCH 2/2] Adds early stopper example for cartpole

---
 .../vision_cartpole_cfg.py | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/scripts/reinforcement_learning/ray/hyperparameter_tuning/vision_cartpole_cfg.py b/scripts/reinforcement_learning/ray/hyperparameter_tuning/vision_cartpole_cfg.py
index 0a6889d075b..b8a9b9cc433 100644
--- a/scripts/reinforcement_learning/ray/hyperparameter_tuning/vision_cartpole_cfg.py
+++ b/scripts/reinforcement_learning/ray/hyperparameter_tuning/vision_cartpole_cfg.py
@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 import pathlib
 import sys
+from typing import Any
 
 # Allow for import of items from the ray workflow.
CUR_DIR = pathlib.Path(__file__).parent @@ -12,6 +13,7 @@ import util import vision_cfg from ray import tune +from ray.tune.stopper import Stopper class CartpoleRGBNoTuneJobCfg(vision_cfg.CameraJobCfg): @@ -47,3 +49,21 @@ def __init__(self, cfg: dict = {}): cfg = util.populate_isaac_ray_cfg_args(cfg) cfg["runner_args"]["--task"] = tune.choice(["Isaac-Cartpole-RGB-TheiaTiny-v0"]) super().__init__(cfg) + + +class CartpoleEarlyStopper(Stopper): + def __init__(self): + self._bad_trials = set() + + def __call__(self, trial_id: str, result: dict[str, Any]) -> bool: + iter = result.get("training_iteration", 0) + out_of_bounds = result.get("Episode/Episode_Termination/cart_out_of_bounds") + + # Mark the trial for stopping if conditions are met + if 20 <= iter and out_of_bounds is not None and out_of_bounds > 0.85: + self._bad_trials.add(trial_id) + + return trial_id in self._bad_trials + + def stop_all(self) -> bool: + return False # only stop individual trials