Skip to content

Commit 47ad9d8

Browse files
committed
[BugFix] Fix GRPO tests and runs
ghstack-source-id: a5b7078 Pull-Request: #3213
1 parent e7ec9c3 commit 47ad9d8

File tree

18 files changed

+527
-717
lines changed

18 files changed

+527
-717
lines changed

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,11 @@ log
189189
Roms
190190

191191
scratch/*
192+
193+
# Large directories from git history that should not be committed
194+
dev/
195+
main/
196+
*.html
197+
198+
# Additional cache directories
199+
.ruff_cache/

sota-implementations/expert-iteration/expert-iteration-async.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import hydra
1414

15-
from torchrl import torchrl_logger
15+
from torchrl import merge_ray_runtime_env, torchrl_logger
1616
from torchrl.data.llm.history import History
1717
from torchrl.record.loggers.wandb import WandbLogger
1818
from torchrl.weight_update.llm import get_model_metadata
@@ -397,19 +397,9 @@ def main(cfg):
397397
if not k.startswith("_")
398398
}
399399

400-
# Add computed GPU configuration
400+
# Add computed GPU configuration and merge with default runtime_env
401401
ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
402-
# Ensure runtime_env and env_vars exist
403-
if "runtime_env" not in ray_init_config:
404-
ray_init_config["runtime_env"] = {}
405-
if not isinstance(ray_init_config["runtime_env"], dict):
406-
ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"])
407-
if "env_vars" not in ray_init_config["runtime_env"]:
408-
ray_init_config["runtime_env"]["env_vars"] = {}
409-
if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict):
410-
ray_init_config["runtime_env"]["env_vars"] = dict(
411-
ray_init_config["runtime_env"]["env_vars"]
412-
)
402+
ray_init_config = merge_ray_runtime_env(ray_init_config)
413403
torchrl_logger.info(f"Ray init config: {ray_init_config=}")
414404
ray.init(**ray_init_config)
415405

sota-implementations/expert-iteration/expert-iteration-sync.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import hydra
1414

15-
from torchrl import torchrl_logger
15+
from torchrl import merge_ray_runtime_env, torchrl_logger
1616
from torchrl.data.llm.history import History
1717
from torchrl.record.loggers.wandb import WandbLogger
1818
from torchrl.weight_update.llm import get_model_metadata
@@ -398,19 +398,9 @@ def main(cfg):
398398
if not k.startswith("_")
399399
}
400400

401-
# Add computed GPU configuration
401+
# Add computed GPU configuration and merge with default runtime_env
402402
ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
403-
# Ensure runtime_env and env_vars exist
404-
if "runtime_env" not in ray_init_config:
405-
ray_init_config["runtime_env"] = {}
406-
if not isinstance(ray_init_config["runtime_env"], dict):
407-
ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"])
408-
if "env_vars" not in ray_init_config["runtime_env"]:
409-
ray_init_config["runtime_env"]["env_vars"] = {}
410-
if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict):
411-
ray_init_config["runtime_env"]["env_vars"] = dict(
412-
ray_init_config["runtime_env"]["env_vars"]
413-
)
403+
ray_init_config = merge_ray_runtime_env(ray_init_config)
414404
torchrl_logger.info(f"Ray init config: {ray_init_config=}")
415405
ray.init(**ray_init_config)
416406

sota-implementations/grpo/grpo-async.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import hydra
1515

16-
from torchrl import torchrl_logger
16+
from torchrl import merge_ray_runtime_env, torchrl_logger
1717
from torchrl.data.llm.history import History
1818
from torchrl.record.loggers.wandb import WandbLogger
1919
from torchrl.weight_update.llm import get_model_metadata
@@ -319,19 +319,9 @@ def main(cfg):
319319
if not k.startswith("_")
320320
}
321321

322-
# Add computed GPU configuration
322+
# Add computed GPU configuration and merge with default runtime_env
323323
ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
324-
# Ensure runtime_env and env_vars exist
325-
if "runtime_env" not in ray_init_config:
326-
ray_init_config["runtime_env"] = {}
327-
if not isinstance(ray_init_config["runtime_env"], dict):
328-
ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"])
329-
if "env_vars" not in ray_init_config["runtime_env"]:
330-
ray_init_config["runtime_env"]["env_vars"] = {}
331-
if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict):
332-
ray_init_config["runtime_env"]["env_vars"] = dict(
333-
ray_init_config["runtime_env"]["env_vars"]
334-
)
324+
ray_init_config = merge_ray_runtime_env(ray_init_config)
335325
torchrl_logger.info(f"Ray init config: {ray_init_config=}")
336326
ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
337327
if ray_managed_externally:

sota-implementations/grpo/grpo-sync.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import hydra
1414

15-
from torchrl import torchrl_logger
15+
from torchrl import merge_ray_runtime_env, torchrl_logger
1616
from torchrl.data.llm.history import History
1717
from torchrl.record.loggers.wandb import WandbLogger
1818
from torchrl.weight_update.llm import get_model_metadata
@@ -319,19 +319,9 @@ def main(cfg):
319319
if not k.startswith("_")
320320
}
321321

322-
# Add computed GPU configuration
322+
# Add computed GPU configuration and merge with default runtime_env
323323
ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
324-
# Ensure runtime_env and env_vars exist
325-
if "runtime_env" not in ray_init_config:
326-
ray_init_config["runtime_env"] = {}
327-
if not isinstance(ray_init_config["runtime_env"], dict):
328-
ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"])
329-
if "env_vars" not in ray_init_config["runtime_env"]:
330-
ray_init_config["runtime_env"]["env_vars"] = {}
331-
if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict):
332-
ray_init_config["runtime_env"]["env_vars"] = dict(
333-
ray_init_config["runtime_env"]["env_vars"]
334-
)
324+
ray_init_config = merge_ray_runtime_env(ray_init_config)
335325
torchrl_logger.info(f"Ray init config: {ray_init_config=}")
336326
ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
337327
if ray_managed_externally:
Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
torch==2.7.0
2-
transformers==4.52.4
3-
peft==0.15.2
4-
bitsandbytes==0.46.0
5-
datasets==3.6.0
6-
wandb==0.19.11
7-
hydra-core==1.3.2
8-
ray==2.46.0
9-
tqdm==4.67.1
10-
tensordict==0.9.0
11-
vllm==0.9.0.1
12-
accelerate==1.7.0
13-
xformers==0.0.30
1+
vllm==0.11.0
2+
peft
3+
bitsandbytes
4+
datasets
5+
wandb
6+
hydra-core
7+
ray
8+
tqdm
9+
tensordict
10+
accelerate
11+
xformers
Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
torch==2.7.0
2-
transformers==4.52.4
3-
peft==0.15.2
4-
bitsandbytes==0.46.0
5-
datasets==3.6.0
6-
wandb==0.19.11
7-
hydra-core==1.3.2
8-
ray==2.46.0
9-
tqdm==4.67.1
10-
tensordict==0.9.0
11-
vllm==0.9.0.1
12-
accelerate==1.7.0
13-
xformers==0.0.30
14-
nltk==3.9.1
15-
langdetect==1.0.9
16-
immutabledict==4.2.1
1+
vllm==0.11.0
2+
torch
3+
transformers
4+
peft
5+
bitsandbytes
6+
datasets
7+
wandb
8+
hydra-core
9+
ray
10+
tqdm
11+
tensordict
12+
accelerate
13+
xformers
14+
nltk
15+
langdetect
16+
immutabledict

test/llm/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch
1414
from tensordict import lazy_stack, set_list_to_stack, TensorDict
1515

16-
from torchrl import torchrl_logger
16+
from torchrl import logger as torchrl_logger
1717

1818
from torchrl.data import (
1919
History,

0 commit comments

Comments
 (0)