11import os
22import sys
33
4- sys .path .append (os . path . abspath ( os . path . join ( os . path . dirname ( __file__ ), '../../' )) )
4+ sys .path .append ('./src/diffusion-policy' )
55import logging
6+ import sys
7+ from datetime import datetime
68from pathlib import Path
7- import torch .distributed as dist
89
910import torch
11+ import torch .distributed as dist
1012import tyro
1113from pydantic import BaseModel
1214from transformers import TrainerCallback , TrainingArguments
1315
1416from internnav .dataset .cma_lerobot_dataset import CMALerobotDataset , cma_collate_fn
15- from internnav .dataset .rdp_lerobot_dataset import RDP_LerobotDataset , rdp_collate_fn
1617from internnav .dataset .navdp_dataset_lerobot import NavDP_Base_Datset , navdp_collate_fn
17- from internnav .model import (
18- CMAModelConfig ,
19- CMANet ,
20- RDPModelConfig ,
21- RDPNet ,
22- Seq2SeqModelConfig ,
23- Seq2SeqNet ,
24- NavDPNet ,
25- NavDPModelConfig ,
26- )
18+ from internnav .dataset .rdp_lerobot_dataset import RDP_LerobotDataset , rdp_collate_fn
19+ from internnav .model import get_config , get_policy
2720from internnav .model .utils .logger import MyLogger
2821from internnav .model .utils .utils import load_dataset
29- from internnav .trainer import CMATrainer , RDPTrainer , NavDPTrainer
22+ from internnav .trainer import CMATrainer , NavDPTrainer , RDPTrainer
3023from scripts .train .configs import (
3124 cma_exp_cfg ,
3225 cma_plus_exp_cfg ,
26+ navdp_exp_cfg ,
3327 rdp_exp_cfg ,
3428 seq2seq_exp_cfg ,
3529 seq2seq_plus_exp_cfg ,
36- navdp_exp_cfg ,
3730)
38- import sys
39- from datetime import datetime
31+
4032
4133class TrainCfg (BaseModel ):
4234 """Training configuration class"""
@@ -68,16 +60,16 @@ def on_save(self, args, state, control, **kwargs):
6860
6961
7062def _make_dir (config ):
71- config .tensorboard_dir = config .tensorboard_dir % config .name
63+ config .tensorboard_dir = config .tensorboard_dir % config .name
7264 config .checkpoint_folder = config .checkpoint_folder % config .name
7365 config .log_dir = config .log_dir % config .name
7466 config .output_dir = config .output_dir % config .name
7567 if not os .path .exists (config .tensorboard_dir ):
76- os .makedirs (config .tensorboard_dir ,exist_ok = True )
68+ os .makedirs (config .tensorboard_dir , exist_ok = True )
7769 if not os .path .exists (config .checkpoint_folder ):
78- os .makedirs (config .checkpoint_folder ,exist_ok = True )
70+ os .makedirs (config .checkpoint_folder , exist_ok = True )
7971 if not os .path .exists (config .log_dir ):
80- os .makedirs (config .log_dir ,exist_ok = True )
72+ os .makedirs (config .log_dir , exist_ok = True )
8173
8274
8375def main (config , model_class , model_config_class ):
@@ -101,28 +93,23 @@ def main(config, model_class, model_config_class):
10193 local_rank = int (os .getenv ('LOCAL_RANK' , '0' ))
10294 world_size = int (os .getenv ('WORLD_SIZE' , '1' ))
10395 rank = int (os .getenv ('RANK' , '0' ))
104-
96+
10597 # Set CUDA device for each process
10698 device_id = local_rank
10799 torch .cuda .set_device (device_id )
108100 device = torch .device (f'cuda:{ device_id } ' )
109101 print (f"World size: { world_size } , Local rank: { local_rank } , Global rank: { rank } " )
110-
102+
111103 # Initialize distributed training environment
112104 if world_size > 1 :
113105 try :
114- dist .init_process_group (
115- backend = 'nccl' ,
116- init_method = 'env://' ,
117- world_size = world_size ,
118- rank = rank
119- )
106+ dist .init_process_group (backend = 'nccl' , init_method = 'env://' , world_size = world_size , rank = rank )
120107 print ("Distributed initialization SUCCESS" )
121108 except Exception as e :
122109 print (f"Distributed initialization FAILED: { str (e )} " )
123110 world_size = 1
124111
125- print ("=" * 50 )
112+ print ("=" * 50 )
126113 print ("After distributed init:" )
127114 print (f"LOCAL_RANK: { local_rank } " )
128115 print (f"WORLD_SIZE: { world_size } " )
@@ -150,46 +137,45 @@ def main(config, model_class, model_config_class):
150137 if buffer .device != device :
151138 print (f"Buffer { name } is on wrong device { buffer .device } , should be moved to { device } " )
152139 buffer .data = buffer .data .to (device )
153-
140+
154141 # If distributed training, wrap the model with DDP
155142 if world_size > 1 :
156143 model = torch .nn .parallel .DistributedDataParallel (
157- model ,
158- device_ids = [local_rank ],
159- output_device = local_rank ,
160- find_unused_parameters = True
144+ model , device_ids = [local_rank ], output_device = local_rank , find_unused_parameters = True
161145 )
162146 # ------------ load logger ------------
163147 train_logger_filename = os .path .join (config .log_dir , 'train.log' )
164148 if dist .is_initialized () and dist .get_rank () == 0 :
165149 train_logger = MyLogger (
166- name = 'train' , level = logging .INFO , format_str = '%(asctime)-15s %(message)s' , filename = train_logger_filename
150+ name = 'train' ,
151+ level = logging .INFO ,
152+ format_str = '%(asctime)-15s %(message)s' ,
153+ filename = train_logger_filename ,
167154 )
168155 else :
169156 # Other processes use console logging
170- train_logger = MyLogger (
171- name = 'train' , level = logging .INFO , format_str = '%(asctime)-15s %(message)s'
172- )
157+ train_logger = MyLogger (name = 'train' , level = logging .INFO , format_str = '%(asctime)-15s %(message)s' )
173158 transformers_logger = logging .getLogger ("transformers" )
174159 if transformers_logger .hasHandlers ():
175160 transformers_logger .handlers = []
176161 if config .model_name == "navdp" and local_rank in [0 , - 1 ]: # Only main process or non-distributed
177162 transformers_logger .addHandler (train_logger .handlers [0 ])
178163 transformers_logger .setLevel (logging .INFO )
179164
180-
181165 # ------------ load dataset ------------
182166 if config .model_name == "navdp" :
183- train_dataset_data = NavDP_Base_Datset (config .il .root_dir ,
184- config .il .dataset_navdp ,
185- config .il .memory_size ,
186- config .il .predict_size ,
187- config .il .batch_size ,
188- config .il .image_size ,
189- config .il .scene_scale ,
190- preload = config .il .preload ,
191- random_digit = config .il .random_digit ,
192- prior_sample = config .il .prior_sample )
167+ train_dataset_data = NavDP_Base_Datset (
168+ config .il .root_dir ,
169+ config .il .dataset_navdp ,
170+ config .il .memory_size ,
171+ config .il .predict_size ,
172+ config .il .batch_size ,
173+ config .il .image_size ,
174+ config .il .scene_scale ,
175+ preload = config .il .preload ,
176+ random_digit = config .il .random_digit ,
177+ prior_sample = config .il .prior_sample ,
178+ )
193179 else :
194180 if '3dgs' in config .il .lmdb_features_dir or '3dgs' in config .il .lmdb_features_dir :
195181 dataset_root_dir = config .il .dataset_six_floor_root_dir
@@ -223,7 +209,7 @@ def main(config, model_class, model_config_class):
223209 config ,
224210 config .il .lerobot_features_dir ,
225211 dataset_data = train_dataset_data ,
226- batch_size = config .il .batch_size ,
212+ batch_size = config .il .batch_size ,
227213 )
228214 collate_fn = rdp_collate_fn (global_batch_size = global_batch_size )
229215 elif config .model_name == 'navdp' :
@@ -238,7 +224,7 @@ def main(config, model_class, model_config_class):
238224 remove_unused_columns = False ,
239225 deepspeed = '' ,
240226 gradient_checkpointing = False ,
241- bf16 = False ,# fp16=False,
227+ bf16 = False , # fp16=False,
242228 tf32 = False ,
243229 per_device_train_batch_size = config .il .batch_size ,
244230 gradient_accumulation_steps = 1 ,
@@ -249,7 +235,7 @@ def main(config, model_class, model_config_class):
249235 lr_scheduler_type = 'cosine' ,
250236 logging_steps = 10.0 ,
251237 num_train_epochs = config .il .epochs ,
252- save_strategy = 'epoch' ,# no
238+ save_strategy = 'epoch' , # no
253239 save_steps = config .il .save_interval_epochs ,
254240 save_total_limit = 8 ,
255241 report_to = config .il .report_to ,
@@ -260,7 +246,7 @@ def main(config, model_class, model_config_class):
260246 torch_compile_mode = None ,
261247 dataloader_drop_last = True ,
262248 disable_tqdm = True ,
263- log_level = "info"
249+ log_level = "info" ,
264250 )
265251
266252 # Create the trainer
@@ -279,14 +265,15 @@ def main(config, model_class, model_config_class):
279265 handler .flush ()
280266 except Exception as e :
281267 import traceback
268+
282269 print (f"Unhandled exception: { str (e )} " )
283270 print ("Stack trace:" )
284271 traceback .print_exc ()
285-
272+
286273 # If distributed environment, ensure all processes exit
287274 if dist .is_initialized ():
288275 dist .destroy_process_group ()
289-
276+
290277 raise
291278
292279
@@ -304,18 +291,20 @@ def main(config, model_class, model_config_class):
304291
305292 # Select configuration based on model_name
306293 supported_cfg = {
307- 'seq2seq' : [seq2seq_exp_cfg , Seq2SeqNet , Seq2SeqModelConfig ],
308- 'seq2seq_plus' : [seq2seq_plus_exp_cfg , Seq2SeqNet , Seq2SeqModelConfig ],
309- 'cma' : [cma_exp_cfg , CMANet , CMAModelConfig ],
310- 'cma_plus' : [cma_plus_exp_cfg , CMANet , CMAModelConfig ],
311- 'rdp' : [rdp_exp_cfg , RDPNet , RDPModelConfig ],
312- 'navdp' : [navdp_exp_cfg , NavDPNet , NavDPModelConfig ],
294+ 'seq2seq' : [seq2seq_exp_cfg , "Seq2Seq_Policy" ],
295+ 'seq2seq_plus' : [seq2seq_plus_exp_cfg , 'Seq2Seq_Policy' ],
296+ 'cma' : [cma_exp_cfg , "CMA_Policy" ],
297+ 'cma_plus' : [cma_plus_exp_cfg , "CMA_Policy" ],
298+ 'rdp' : [rdp_exp_cfg , "RDP_Policy" ],
299+ 'navdp' : [navdp_exp_cfg , "NavDP_Policy" ],
313300 }
314301
315302 if config .model_name not in supported_cfg :
316303 raise ValueError (f'Invalid model name: { config .model_name } . Supported models are: { list (supported_cfg .keys ())} ' )
317304
318- exp_cfg , model_class , model_config_class = supported_cfg [config .model_name ]
305+ exp_cfg , policy_name = supported_cfg [config .model_name ]
306+ model_class , model_config_class = get_policy (policy_name ), get_config (policy_name )
307+
319308 exp_cfg .name = config .name
320309 exp_cfg .num_gpus = len (exp_cfg .torch_gpu_ids )
321310 exp_cfg .world_size = exp_cfg .num_gpus
0 commit comments