
Commit c48cdc2

[Fix] Fix the import bug in the training script (#139)

* fix train import bug
* updated to latest doc; fix meaningless sys path append in train.py; fix habitat config dataset path
* update import in train.py

1 parent 5adf875 commit c48cdc2

File tree

3 files changed (+56, -67 lines)


scripts/eval/configs/habitat_r2r_pix.yaml

Lines changed: 1 addition & 1 deletion

@@ -79,5 +79,5 @@ habitat:
   dataset:
     type: R2RVLN-v1
     split: val_seen
-    scenes_dir: data/scene_datasets/
+    scenes_dir: data/scene_data/mp3d_ce
     data_path: data/datasets/vln/mp3d/r2r/v1/{split}/{split}.json.gz
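
The diff points scenes_dir at the new data/scene_data/mp3d_ce layout. As a hedged sanity check (a hypothetical helper, not part of this commit), the path can be verified before launching evaluation:

    # Hypothetical pre-flight check, not part of the commit: fail fast if the
    # renamed scene directory from the updated config is missing.
    from pathlib import Path

    scenes_dir = Path("data/scene_data/mp3d_ce")
    if not scenes_dir.is_dir():
        raise FileNotFoundError(f"scenes_dir not found: {scenes_dir}")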

scripts/eval/configs/vln_r2r.yaml

Lines changed: 3 additions & 3 deletions

@@ -69,9 +69,9 @@ habitat:
     look_down:
       type: LookDownAction
       agent_index: 0
-
+
   dataset:
     type: R2RVLN-v1
     split: val_seen
-    scenes_dir: data/scene_data/
-    data_path: data/vln_ce/raw_data/r2r/{split}/{split}.json.gz
+    scenes_dir: data/scene_data/mp3d_ce
+    data_path: data/vln_ce/raw_data/r2r/{split}/{split}.json.gz
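
Both eval configs now agree on data/scene_data/mp3d_ce. The {split} token in data_path is a placeholder that Habitat substitutes with the configured split name; a minimal sketch of that expansion, assuming plain str.format semantics (not code from the commit):

    # Sketch of how the {split} placeholder resolves, mirroring Habitat's
    # convention of substituting the configured split name.
    split = "val_seen"
    data_path = "data/vln_ce/raw_data/r2r/{split}/{split}.json.gz"
    print(data_path.format(split=split))
    # data/vln_ce/raw_data/r2r/val_seen/val_seen.json.gz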

scripts/train/train.py

Lines changed: 52 additions & 63 deletions

@@ -1,42 +1,34 @@
 import os
 import sys
 
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+sys.path.append('./src/diffusion-policy')
 import logging
+import sys
+from datetime import datetime
 from pathlib import Path
-import torch.distributed as dist
 
 import torch
+import torch.distributed as dist
 import tyro
 from pydantic import BaseModel
 from transformers import TrainerCallback, TrainingArguments
 
 from internnav.dataset.cma_lerobot_dataset import CMALerobotDataset, cma_collate_fn
-from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
 from internnav.dataset.navdp_dataset_lerobot import NavDP_Base_Datset, navdp_collate_fn
-from internnav.model import (
-    CMAModelConfig,
-    CMANet,
-    RDPModelConfig,
-    RDPNet,
-    Seq2SeqModelConfig,
-    Seq2SeqNet,
-    NavDPNet,
-    NavDPModelConfig,
-)
+from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
+from internnav.model import get_config, get_policy
 from internnav.model.utils.logger import MyLogger
 from internnav.model.utils.utils import load_dataset
-from internnav.trainer import CMATrainer, RDPTrainer, NavDPTrainer
+from internnav.trainer import CMATrainer, NavDPTrainer, RDPTrainer
 from scripts.train.configs import (
     cma_exp_cfg,
     cma_plus_exp_cfg,
+    navdp_exp_cfg,
     rdp_exp_cfg,
     seq2seq_exp_cfg,
     seq2seq_plus_exp_cfg,
-    navdp_exp_cfg,
 )
-import sys
-from datetime import datetime
+
 
 class TrainCfg(BaseModel):
     """Training configuration class"""
@@ -68,16 +60,16 @@ def on_save(self, args, state, control, **kwargs):
 
 
 def _make_dir(config):
-    config.tensorboard_dir = config.tensorboard_dir % config.name
+    config.tensorboard_dir = config.tensorboard_dir % config.name
     config.checkpoint_folder = config.checkpoint_folder % config.name
     config.log_dir = config.log_dir % config.name
     config.output_dir = config.output_dir % config.name
     if not os.path.exists(config.tensorboard_dir):
-        os.makedirs(config.tensorboard_dir,exist_ok=True)
+        os.makedirs(config.tensorboard_dir, exist_ok=True)
     if not os.path.exists(config.checkpoint_folder):
-        os.makedirs(config.checkpoint_folder,exist_ok=True)
+        os.makedirs(config.checkpoint_folder, exist_ok=True)
     if not os.path.exists(config.log_dir):
-        os.makedirs(config.log_dir,exist_ok=True)
+        os.makedirs(config.log_dir, exist_ok=True)
 
 
 def main(config, model_class, model_config_class):
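
A side note on the reformatted _make_dir: since os.makedirs(..., exist_ok=True) already tolerates existing directories, the os.path.exists guards are redundant. A possible simplification (not part of this commit) would be:

    # Possible simplification, not in this commit: exist_ok=True makes the
    # os.path.exists checks unnecessary.
    for d in (config.tensorboard_dir, config.checkpoint_folder, config.log_dir):
        os.makedirs(d, exist_ok=True)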
@@ -101,28 +93,23 @@ def main(config, model_class, model_config_class):
     local_rank = int(os.getenv('LOCAL_RANK', '0'))
     world_size = int(os.getenv('WORLD_SIZE', '1'))
     rank = int(os.getenv('RANK', '0'))
-
+
     # Set CUDA device for each process
     device_id = local_rank
     torch.cuda.set_device(device_id)
     device = torch.device(f'cuda:{device_id}')
     print(f"World size: {world_size}, Local rank: {local_rank}, Global rank: {rank}")
-
+
     # Initialize distributed training environment
     if world_size > 1:
         try:
-            dist.init_process_group(
-                backend='nccl',
-                init_method='env://',
-                world_size=world_size,
-                rank=rank
-            )
+            dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
             print("Distributed initialization SUCCESS")
         except Exception as e:
             print(f"Distributed initialization FAILED: {str(e)}")
             world_size = 1
 
-    print("="*50)
+    print("=" * 50)
     print("After distributed init:")
     print(f"LOCAL_RANK: {local_rank}")
     print(f"WORLD_SIZE: {world_size}")
@@ -150,46 +137,45 @@ def main(config, model_class, model_config_class):
         if buffer.device != device:
             print(f"Buffer {name} is on wrong device {buffer.device}, should be moved to {device}")
             buffer.data = buffer.data.to(device)
-
+
     # If distributed training, wrap the model with DDP
     if world_size > 1:
         model = torch.nn.parallel.DistributedDataParallel(
-            model,
-            device_ids=[local_rank],
-            output_device=local_rank,
-            find_unused_parameters=True
+            model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True
         )
     # ------------ load logger ------------
     train_logger_filename = os.path.join(config.log_dir, 'train.log')
     if dist.is_initialized() and dist.get_rank() == 0:
         train_logger = MyLogger(
-            name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s', filename=train_logger_filename
+            name='train',
+            level=logging.INFO,
+            format_str='%(asctime)-15s %(message)s',
+            filename=train_logger_filename,
         )
     else:
         # Other processes use console logging
-        train_logger = MyLogger(
-            name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s'
-        )
+        train_logger = MyLogger(name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s')
     transformers_logger = logging.getLogger("transformers")
     if transformers_logger.hasHandlers():
         transformers_logger.handlers = []
     if config.model_name == "navdp" and local_rank in [0, -1]:  # Only main process or non-distributed
         transformers_logger.addHandler(train_logger.handlers[0])
         transformers_logger.setLevel(logging.INFO)
 
-
     # ------------ load dataset ------------
     if config.model_name == "navdp":
-        train_dataset_data = NavDP_Base_Datset(config.il.root_dir,
-                                               config.il.dataset_navdp,
-                                               config.il.memory_size,
-                                               config.il.predict_size,
-                                               config.il.batch_size,
-                                               config.il.image_size,
-                                               config.il.scene_scale,
-                                               preload = config.il.preload,
-                                               random_digit = config.il.random_digit,
-                                               prior_sample = config.il.prior_sample)
+        train_dataset_data = NavDP_Base_Datset(
+            config.il.root_dir,
+            config.il.dataset_navdp,
+            config.il.memory_size,
+            config.il.predict_size,
+            config.il.batch_size,
+            config.il.image_size,
+            config.il.scene_scale,
+            preload=config.il.preload,
+            random_digit=config.il.random_digit,
+            prior_sample=config.il.prior_sample,
+        )
     else:
         if '3dgs' in config.il.lmdb_features_dir or '3dgs' in config.il.lmdb_features_dir:
             dataset_root_dir = config.il.dataset_six_floor_root_dir
@@ -223,7 +209,7 @@ def main(config, model_class, model_config_class):
            config,
            config.il.lerobot_features_dir,
            dataset_data=train_dataset_data,
-            batch_size=config.il.batch_size,
+            batch_size=config.il.batch_size,
         )
         collate_fn = rdp_collate_fn(global_batch_size=global_batch_size)
     elif config.model_name == 'navdp':
@@ -238,7 +224,7 @@ def main(config, model_class, model_config_class):
         remove_unused_columns=False,
         deepspeed='',
         gradient_checkpointing=False,
-        bf16=False,#fp16=False,
+        bf16=False,  # fp16=False,
         tf32=False,
         per_device_train_batch_size=config.il.batch_size,
         gradient_accumulation_steps=1,
@@ -249,7 +235,7 @@ def main(config, model_class, model_config_class):
         lr_scheduler_type='cosine',
         logging_steps=10.0,
         num_train_epochs=config.il.epochs,
-        save_strategy='epoch',# no
+        save_strategy='epoch',  # no
         save_steps=config.il.save_interval_epochs,
         save_total_limit=8,
         report_to=config.il.report_to,
@@ -260,7 +246,7 @@ def main(config, model_class, model_config_class):
         torch_compile_mode=None,
         dataloader_drop_last=True,
         disable_tqdm=True,
-        log_level="info"
+        log_level="info",
     )
 
     # Create the trainer
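
The three hunks above are formatting-only (comment spacing and a trailing comma). For orientation, the fields being set are standard transformers.TrainingArguments; a minimal standalone sketch with illustrative values in place of the script's config object:

    # Standalone TrainingArguments sketch using the same fields the script
    # configures (values illustrative; the script reads them from config.il).
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="checkpoints/demo",
        bf16=False,
        tf32=False,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=1,
        lr_scheduler_type="cosine",
        num_train_epochs=10,
        save_strategy="epoch",
        save_total_limit=8,
        dataloader_drop_last=True,
        disable_tqdm=True,
        log_level="info",
    )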
@@ -279,14 +265,15 @@ def main(config, model_class, model_config_class):
                 handler.flush()
     except Exception as e:
         import traceback
+
         print(f"Unhandled exception: {str(e)}")
         print("Stack trace:")
         traceback.print_exc()
-
+
         # If distributed environment, ensure all processes exit
         if dist.is_initialized():
             dist.destroy_process_group()
-
+
         raise
 
 
@@ -304,18 +291,20 @@ def main(config, model_class, model_config_class):
 
     # Select configuration based on model_name
     supported_cfg = {
-        'seq2seq': [seq2seq_exp_cfg, Seq2SeqNet, Seq2SeqModelConfig],
-        'seq2seq_plus': [seq2seq_plus_exp_cfg, Seq2SeqNet, Seq2SeqModelConfig],
-        'cma': [cma_exp_cfg, CMANet, CMAModelConfig],
-        'cma_plus': [cma_plus_exp_cfg, CMANet, CMAModelConfig],
-        'rdp': [rdp_exp_cfg, RDPNet, RDPModelConfig],
-        'navdp': [navdp_exp_cfg, NavDPNet, NavDPModelConfig],
+        'seq2seq': [seq2seq_exp_cfg, "Seq2Seq_Policy"],
+        'seq2seq_plus': [seq2seq_plus_exp_cfg, 'Seq2Seq_Policy'],
+        'cma': [cma_exp_cfg, "CMA_Policy"],
+        'cma_plus': [cma_plus_exp_cfg, "CMA_Policy"],
+        'rdp': [rdp_exp_cfg, "RDP_Policy"],
+        'navdp': [navdp_exp_cfg, "NavDP_Policy"],
     }
 
     if config.model_name not in supported_cfg:
         raise ValueError(f'Invalid model name: {config.model_name}. Supported models are: {list(supported_cfg.keys())}')
 
-    exp_cfg, model_class, model_config_class = supported_cfg[config.model_name]
+    exp_cfg, policy_name = supported_cfg[config.model_name]
+    model_class, model_config_class = get_policy(policy_name), get_config(policy_name)
+
     exp_cfg.name = config.name
     exp_cfg.num_gpus = len(exp_cfg.torch_gpu_ids)
     exp_cfg.world_size = exp_cfg.num_gpus
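
With the registry in place, the dispatch table maps each model name to its experiment config plus a policy-name string, and the classes are resolved at runtime. A short usage sketch of the new lookup path (mirrors the updated code; 'navdp' is one of the supported keys above):

    # Usage sketch of the new dispatch: resolve classes from the policy name.
    from internnav.model import get_config, get_policy

    exp_cfg, policy_name = supported_cfg['navdp']  # -> navdp_exp_cfg, "NavDP_Policy"
    model_class = get_policy(policy_name)          # policy/network class
    model_config_class = get_config(policy_name)   # matching config class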
