Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import util
import vision_cfg
from ray import tune
from ray.tune.progress_reporter import CLIReporter


class CartpoleRGBNoTuneJobCfg(vision_cfg.CameraJobCfg):
Expand Down Expand Up @@ -47,3 +48,18 @@ def __init__(self, cfg: dict = {}):
cfg = util.populate_isaac_ray_cfg_args(cfg)
cfg["runner_args"]["--task"] = tune.choice(["Isaac-Cartpole-RGB-TheiaTiny-v0"])
super().__init__(cfg)


class CustomCartpoleProgressReporter(CLIReporter):
    """CLI progress reporter preconfigured with Cartpole metric columns.

    Intended to be selected by name from this cfg file when launching a
    tuning run with a custom progress reporter.
    """

    def __init__(self):
        """Configure the parent ``CLIReporter`` with a fixed column layout."""
        # Result key -> label handed to CLIReporter for that column.
        columns = {
            "training_iteration": "iter",
            "time_total_s": "total time (s)",
            "Episode/Episode_Reward/alive": "alive",
            "Episode/Episode_Reward/cart_vel": "cart velocity",
            "rewards/time": "rewards/time",
        }
        # No metric/mode is set here; sort_by_metric is enabled without
        # naming one in this class.
        super().__init__(
            metric_columns=columns,
            max_report_frequency=5,
            sort_by_metric=True,
        )
38 changes: 36 additions & 2 deletions scripts/reinforcement_learning/ray/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import ray
import util
from ray import air, tune
from ray.tune.progress_reporter import ProgressReporter
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search.repeater import Repeater

Expand Down Expand Up @@ -203,13 +204,18 @@ def stop_all(self):
return False


def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
def invoke_tuning_run(
cfg: dict,
args: argparse.Namespace,
progress_reporter: ProgressReporter | None = None,
) -> None:
"""Invoke an Isaac-Ray tuning run.

Log either to a local directory or to MLFlow.
Args:
cfg: Configuration dictionary extracted from job setup
args: Command-line arguments related to tuning.
progress_reporter: Custom progress reporter. Defaults to CLIReporter or JupyterNotebookReporter if not provided.
"""
# Allow for early exit
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
Expand Down Expand Up @@ -237,6 +243,17 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
)
repeat_search = Repeater(searcher, repeat=args.repeat_run_count)

if progress_reporter is not None:
os.environ["RAY_AIR_NEW_OUTPUT"] = "0"
if (
getattr(progress_reporter, "_metric", None) is not None
or getattr(progress_reporter, "_mode", None) is not None
):
raise ValueError(
"Do not set <metric> or <mode> directly in the custom progress reporter class, "
"provide them as arguments to tuner.py instead."
)

if args.run_mode == "local": # Standard config, to file
run_config = air.RunConfig(
storage_path="/tmp/ray",
Expand All @@ -247,6 +264,7 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
checkpoint_at_end=False, # Disable final checkpoint
),
stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS),
progress_reporter=progress_reporter,
)

elif args.run_mode == "remote": # MLFlow, to MLFlow server
Expand All @@ -263,6 +281,7 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
callbacks=[mlflow_callback],
checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=0, checkpoint_at_end=False),
stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS),
progress_reporter=progress_reporter,
)
else:
raise ValueError("Unrecognized run mode.")
Expand Down Expand Up @@ -399,6 +418,12 @@ def __init__(self, cfg: dict):
default=MAX_LOG_EXTRACTION_ERRORS,
help="Max number number of LogExtractionError failures before we abort the whole tuning run.",
)
parser.add_argument(
"--progress_reporter",
type=str,
default=None,
help="A progress reporter in cfg_file, must be a ProgressReporter object.",
)

args = parser.parse_args()
PROCESS_RESPONSE_TIMEOUT = args.process_response_timeout
Expand Down Expand Up @@ -457,7 +482,16 @@ def __init__(self, cfg: dict):
print(f"[INFO]: Successfully instantiated class '{class_name}' from {file_path}")
cfg = instance.cfg
print(f"[INFO]: Grabbed the following hyperparameter sweep config: \n {cfg}")
invoke_tuning_run(cfg, args)
# Load the optional custom progress reporter named on the command line.
progress_reporter = None
if args.progress_reporter:
    # Fail loudly on an unknown name; silently falling back to the default
    # reporter would hide a typo in --progress_reporter from the user.
    if not hasattr(module, args.progress_reporter):
        raise AttributeError(
            f"[ERROR]: Progress reporter '{args.progress_reporter}' not found in {file_path}"
        )
    progress_reporter = getattr(module, args.progress_reporter)
    if isinstance(progress_reporter, type) and issubclass(progress_reporter, tune.ProgressReporter):
        # A reporter class was named: build it with its no-argument constructor.
        progress_reporter = progress_reporter()
    elif not isinstance(progress_reporter, tune.ProgressReporter):
        # Neither a ProgressReporter subclass nor a ready-made instance.
        raise TypeError(f"[ERROR]: {args.progress_reporter} is not a valid ProgressReporter.")
    print(f"[INFO]: Loaded custom progress reporter from '{args.progress_reporter}'")
invoke_tuning_run(cfg, args, progress_reporter=progress_reporter)

else:
raise AttributeError(f"[ERROR]:Class '{class_name}' not found in {file_path}")