|  | 
| 13 | 13 | 
 | 
| 14 | 14 | import hydra | 
| 15 | 15 | 
 | 
| 16 |  | -from torchrl import torchrl_logger | 
|  | 16 | +from torchrl import merge_ray_runtime_env, torchrl_logger | 
| 17 | 17 | from torchrl.data.llm.history import History | 
| 18 | 18 | from torchrl.record.loggers.wandb import WandbLogger | 
| 19 | 19 | from torchrl.weight_update.llm import get_model_metadata | 
| @@ -319,19 +319,9 @@ def main(cfg): | 
| 319 | 319 |             if not k.startswith("_") | 
| 320 | 320 |         } | 
| 321 | 321 | 
 | 
| 322 |  | -        # Add computed GPU configuration | 
|  | 322 | +        # Add computed GPU configuration and merge with default runtime_env | 
| 323 | 323 |         ray_init_config["num_gpus"] = device_config["ray_num_gpus"] | 
| 324 |  | -        # Ensure runtime_env and env_vars exist | 
| 325 |  | -        if "runtime_env" not in ray_init_config: | 
| 326 |  | -            ray_init_config["runtime_env"] = {} | 
| 327 |  | -        if not isinstance(ray_init_config["runtime_env"], dict): | 
| 328 |  | -            ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"]) | 
| 329 |  | -        if "env_vars" not in ray_init_config["runtime_env"]: | 
| 330 |  | -            ray_init_config["runtime_env"]["env_vars"] = {} | 
| 331 |  | -        if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict): | 
| 332 |  | -            ray_init_config["runtime_env"]["env_vars"] = dict( | 
| 333 |  | -                ray_init_config["runtime_env"]["env_vars"] | 
| 334 |  | -            ) | 
|  | 324 | +        ray_init_config = merge_ray_runtime_env(ray_init_config) | 
| 335 | 325 |         torchrl_logger.info(f"Ray init config: {ray_init_config=}") | 
| 336 | 326 |         ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY") | 
| 337 | 327 |         if ray_managed_externally: | 
|  | 
0 commit comments