@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause
import pathlib
import sys
from typing import Any

# Allow for import of items from the ray workflow.
CUR_DIR = pathlib.Path(__file__).parent
@@ -12,6 +13,7 @@
import util
import vision_cfg
from ray import tune
from ray.tune.stopper import Stopper


class CartpoleRGBNoTuneJobCfg(vision_cfg.CameraJobCfg):
@@ -47,3 +49,21 @@ def __init__(self, cfg: dict = {}):
cfg = util.populate_isaac_ray_cfg_args(cfg)
cfg["runner_args"]["--task"] = tune.choice(["Isaac-Cartpole-RGB-TheiaTiny-v0"])
super().__init__(cfg)


class CartpoleEarlyStopper(Stopper):
    """Stop individual cartpole trials that still terminate out of bounds after a warm-up period."""

    def __init__(self):
        self._bad_trials: set[str] = set()

    def __call__(self, trial_id: str, result: dict[str, Any]) -> bool:
        iteration = result.get("training_iteration", 0)
        out_of_bounds = result.get("Episode/Episode_Termination/cart_out_of_bounds")

        # Mark the trial for stopping once it has trained for at least 20 iterations
        # and its cart_out_of_bounds termination metric is still above 0.85.
        if iteration >= 20 and out_of_bounds is not None and out_of_bounds > 0.85:
            self._bad_trials.add(trial_id)

        return trial_id in self._bad_trials

    def stop_all(self) -> bool:
        return False  # only stop individual trials, never the whole run
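
A quick illustration (not part of the change) of how this stopper behaves across successive results for one trial; the result dictionaries below are made up for the example:

stopper = CartpoleEarlyStopper()
# Fewer than 20 training iterations: the trial is not flagged regardless of the metric.
assert stopper("trial_0", {"training_iteration": 5, "Episode/Episode_Termination/cart_out_of_bounds": 0.95}) is False
# After the warm-up, a cart_out_of_bounds value above 0.85 marks the trial for stopping.
assert stopper("trial_0", {"training_iteration": 25, "Episode/Episode_Termination/cart_out_of_bounds": 0.90}) is True
# Once marked, the trial stays stopped even if later results fall back below the threshold.
assert stopper("trial_0", {"training_iteration": 30, "Episode/Episode_Termination/cart_out_of_bounds": 0.10}) is True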
35 changes: 31 additions & 4 deletions scripts/reinforcement_learning/ray/tuner.py
@@ -14,6 +14,7 @@
from ray import air, tune
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search.repeater import Repeater
from ray.tune.stopper import CombinedStopper

"""
This script breaks down an aggregate tuning job, as defined by a hyperparameter sweep configuration,
@@ -203,13 +204,18 @@ def stop_all(self):
return False


def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
def invoke_tuning_run(
cfg: dict,
args: argparse.Namespace,
stopper: tune.Stopper | None = None,
) -> None:
"""Invoke an Isaac-Ray tuning run.

Log either to a local directory or to MLFlow.
Args:
cfg: Configuration dictionary extracted from job setup.
args: Command-line arguments related to tuning.
stopper: Optional custom stopper, combined with the default LogExtractionErrorStopper.
"""
# Allow for early exit
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
@@ -237,6 +243,12 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
)
repeat_search = Repeater(searcher, repeat=args.repeat_run_count)

# Combine the default log-extraction stopper with any custom stopper.
trial_stoppers = [LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS)]
if stopper is not None:
    trial_stoppers.append(stopper)
stoppers = CombinedStopper(*trial_stoppers)

if args.run_mode == "local": # Standard config, to file
run_config = air.RunConfig(
storage_path="/tmp/ray",
@@ -246,7 +258,7 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
checkpoint_frequency=0, # Disable periodic checkpointing
checkpoint_at_end=False, # Disable final checkpoint
),
stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS),
stop=stoppers,
)

elif args.run_mode == "remote": # MLFlow, to MLFlow server
@@ -262,7 +274,7 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
storage_path="/tmp/ray",
callbacks=[mlflow_callback],
checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=0, checkpoint_at_end=False),
stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS),
stop=stoppers,
)
else:
raise ValueError("Unrecognized run mode.")
@@ -399,6 +411,12 @@ def __init__(self, cfg: dict):
default=MAX_LOG_EXTRACTION_ERRORS,
help="Max number number of LogExtractionError failures before we abort the whole tuning run.",
)
parser.add_argument(
"--stopper",
type=str,
default=None,
help="A stop criteria in the cfg_file, must be a tune.Stopper instance.",
)

args = parser.parse_args()
PROCESS_RESPONSE_TIMEOUT = args.process_response_timeout
@@ -457,7 +475,16 @@ def __init__(self, cfg: dict):
print(f"[INFO]: Successfully instantiated class '{class_name}' from {file_path}")
cfg = instance.cfg
print(f"[INFO]: Grabbed the following hyperparameter sweep config: \n {cfg}")
invoke_tuning_run(cfg, args)
# Load the optional stop criterion from the same module as the sweep config.
stopper = None
if args.stopper:
    if not hasattr(module, args.stopper):
        raise AttributeError(f"[ERROR]: Stopper '{args.stopper}' not found in {file_path}")
    stopper = getattr(module, args.stopper)
    if isinstance(stopper, type) and issubclass(stopper, tune.Stopper):
        stopper = stopper()  # instantiate if a class was given
    elif not isinstance(stopper, tune.Stopper):
        raise TypeError(f"[ERROR]: Unsupported stop criterion type: {type(stopper)}")
    print(f"[INFO]: Loaded custom stop criterion '{args.stopper}' from {file_path}")
invoke_tuning_run(cfg, args, stopper=stopper)

else:
raise AttributeError(f"[ERROR]:Class '{class_name}' not found in {file_path}")