
Commit 1f9fc7d

[Fix] Fix the no_grad bug for NavDP training & Support NavDP finetuning (#144)
* [FEAT] Add support for navdp finetuning
* [FIX] NavDP Training Gradient
* [FIX] Support NavDP finetune
* [FIX] update navdp training parameters
* [FIX] Support NavDP finetuning
* modify the class name for NavDP
* update class name
1 parent e0d265c commit 1f9fc7d

File tree: 5 files changed (+34, −39 lines)

internnav/dataset/navdp_dataset_lerobot.py

Lines changed: 8 additions & 1 deletion

@@ -41,6 +41,7 @@ def __init__(
         image_size=224,
         scene_data_scale=1.0,
         trajectory_data_scale=1.0,
+        pixel_channel=7,
         debug=False,
         preload=False,
         random_digit=False,
@@ -61,6 +62,7 @@ def __init__(
         self.trajectory_afford_path = []
         self.random_digit = random_digit
         self.prior_sample = prior_sample
+        self.pixel_channel = pixel_channel
         self.item_cnt = 0
         self.batch_size = batch_size
         self.batch_time_sum = 0.0
@@ -509,7 +511,12 @@ def __getitem__(self, index):
             camera_intrinsic,
             trajectory_base_extrinsic,
         )
-        pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1)
+        # pixel_channel == 7 means NavDP performs pixel navigation at an asynchronous pace:
+        # pixel_mask (1), the history image with the assigned pixel goal (3), and the current image (3).
+        # If pixel_channel == 4, the pixel goal is assigned on the current frame, so only
+        # the pixel_mask (1) and the current image (3) are needed.
+        if self.pixel_channel == 7:
+            pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1)

        pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0
        augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0
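As a side note on the layout described in the comments above, here is a minimal NumPy sketch of the two pixel-goal layouts; the helper name and array shapes are hypothetical, not the repository's code:

import numpy as np

def build_pixel_goal(pixel_mask, goal_frame, current_frame, pixel_channel=7):
    # pixel_mask:    (H, W, 1) mask marking the assigned pixel goal
    # goal_frame:    (H, W, 3) history frame on which the goal was assigned
    # current_frame: (H, W, 3) latest observation
    if pixel_channel == 7:
        # asynchronous pace: mask + goal history frame + current frame
        return np.concatenate((pixel_mask, goal_frame, current_frame), axis=-1)
    if pixel_channel == 4:
        # goal assigned on the current frame: mask + current frame only
        return np.concatenate((pixel_mask, current_frame), axis=-1)
    raise ValueError(f"unsupported pixel_channel: {pixel_channel}")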

internnav/model/basemodel/navdp/navdp_policy.py

Lines changed: 16 additions & 30 deletions

@@ -7,7 +7,13 @@

 from internnav.configs.model.base_encoders import ModelCfg
 from internnav.configs.trainer.exp import ExpCfg
-from internnav.model.encoder.navdp_backbone import *
+from internnav.model.encoder.navdp_backbone import (
+    ImageGoalBackbone,
+    LearnablePositionalEncoding,
+    PixelGoalBackbone,
+    RGBDBackbone,
+    SinusoidalPosEmb,
+)


 class NavDPModelConfig(PretrainedConfig):
@@ -51,9 +57,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         elif pretrained_model_name_or_path is None or len(pretrained_model_name_or_path) == 0:
             pass
         else:
-            incompatible_keys, _ = model.load_state_dict(
-                torch.load(pretrained_model_name_or_path)['state_dict'], strict=False
-            )
+            incompatible_keys, _ = model.load_state_dict(torch.load(pretrained_model_name_or_path), strict=False)
             if len(incompatible_keys) > 0:
                 print(f'Incompatible keys: {incompatible_keys}')

@@ -66,31 +70,32 @@ def __init__(self, config: NavDPModelConfig):
             self.model_config = ModelCfg(**config.model_cfg['model'])
         else:
             self.model_config = config
-
         self.config.model_cfg['il']
-
         self._device = torch.device(f"cuda:{config.model_cfg['local_rank']}")
         self.image_size = self.config.model_cfg['il']['image_size']
         self.memory_size = self.config.model_cfg['il']['memory_size']
         self.predict_size = self.config.model_cfg['il']['predict_size']
+        self.pixel_channel = self.config.model_cfg['il']['pixel_channel']
         self.temporal_depth = self.config.model_cfg['il']['temporal_depth']
         self.attention_heads = self.config.model_cfg['il']['heads']
         self.input_channels = self.config.model_cfg['il']['channels']
         self.dropout = self.config.model_cfg['il']['dropout']
         self.token_dim = self.config.model_cfg['il']['token_dim']
         self.scratch = self.config.model_cfg['il']['scratch']
         self.finetune = self.config.model_cfg['il']['finetune']
-        self.rgbd_encoder = NavDP_RGBD_Backbone(
+        self.rgbd_encoder = RGBDBackbone(
             self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device
         )
-        self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size, self.token_dim, device=self._device)
-        self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device)
+        self.pixel_encoder = PixelGoalBackbone(
+            self.image_size, self.token_dim, pixel_channel=self.pixel_channel, device=self._device
+        )
+        self.image_encoder = ImageGoalBackbone(self.image_size, self.token_dim, device=self._device)
         self.point_encoder = nn.Linear(3, self.token_dim)

         if not self.finetune:
-            for p in self.rgbd_encoder.parameters():
+            for p in self.rgbd_encoder.rgb_model.parameters():
                 p.requires_grad = False
-            self.rgbd_encoder.eval()
+            self.rgbd_encoder.rgb_model.eval()

         decoder_layer = nn.TransformerDecoderLayer(
             d_model=self.token_dim,
@@ -180,23 +185,6 @@ def predict_critic(self, predict_trajectory, rgbd_embed):
         return critic_output

     def forward(self, goal_point, goal_image, goal_pixel, input_images, input_depths, output_actions, augment_actions):
-        # """get device safely"""
-        # # get device safely
-        # try:
-        #     # try to get device through model parameters
-        #     device = next(self.parameters()).device
-        # except StopIteration:
-        #     # model has no parameters, use the default device
-        #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # # move all inputs to model device
-        # goal_point = goal_point.to(device)
-        # goal_image = goal_image.to(device)
-        # input_images = input_images.to(device)
-        # input_depths = input_depths.to(device)
-        # output_actions = output_actions.to(device)
-        # augment_actions = augment_actions.to(device)
-        # device = self._device
-        # print(f"self.parameters() is:{self.parameters()}")
         device = next(self.parameters()).device

         assert input_images.shape[1] == self.memory_size
@@ -325,7 +313,6 @@ def predict_pointgoal_batch_action_vel(self, goal_point, input_images, input_dep
             naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample

         critic_values = self.predict_critic(naction, rgbd_embed)
-        all_trajectory = torch.cumsum(naction / 4.0, dim=1)

         negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]]
         positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]]
@@ -344,7 +331,6 @@ def predict_nogoal_batch_action_vel(self, input_images, input_depths, sample_num
             naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample

         critic_values = self.predict_critic(naction, rgbd_embed)
-        all_trajectory = torch.cumsum(naction / 4.0, dim=1)

         negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]]
         positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]]
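The `if not self.finetune:` hunk above is the actual no_grad fix: only the pretrained `rgb_model` inside `RGBDBackbone` is frozen and put in eval mode, so the remaining encoder layers keep receiving gradients. A self-contained sketch of that freezing pattern, with a hypothetical `TinyEncoder` standing in for the real backbone:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    # hypothetical stand-in for RGBDBackbone: a pretrained trunk plus task-specific layers
    def __init__(self):
        super().__init__()
        self.rgb_model = nn.Linear(8, 8)  # pretrained trunk to keep frozen
        self.fusion = nn.Linear(8, 4)     # layers that must keep training

encoder = TinyEncoder()

# Freeze only the pretrained trunk; the fusion layers stay trainable.
for p in encoder.rgb_model.parameters():
    p.requires_grad = False
encoder.rgb_model.eval()

loss = encoder.fusion(encoder.rgb_model(torch.randn(2, 8))).sum()
loss.backward()
print(encoder.rgb_model.weight.grad)           # None: frozen
print(encoder.fusion.weight.grad is not None)  # True: still receives gradients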

internnav/model/encoder/navdp_backbone.py

Lines changed: 5 additions & 5 deletions

@@ -202,7 +202,7 @@ def forward(self, images, depths):
         return memory_token


-class NavDP_RGBD_Backbone(nn.Module):
+class RGBDBackbone(nn.Module):
     def __init__(
         self,
         image_size=224,
@@ -313,7 +313,7 @@ def _get_device(self):
         return torch.device("cuda" if torch.cuda.is_available() else "cpu")


-class NavDP_ImageGoal_Backbone(nn.Module):
+class ImageGoalBackbone(nn.Module):
     def __init__(self, image_size=224, embed_size=512, device='cuda:0'):
         super().__init__()
         if device is None:
@@ -376,8 +376,8 @@ def _get_device(self):
         return torch.device("cuda" if torch.cuda.is_available() else "cpu")


-class NavDP_PixelGoal_Backbone(nn.Module):
-    def __init__(self, image_size=224, embed_size=512, device='cuda:0'):
+class PixelGoalBackbone(nn.Module):
+    def __init__(self, image_size=224, embed_size=512, pixel_channel=7, device='cuda:0'):
         super().__init__()
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -392,7 +392,7 @@ def __init__(self, image_size=224, embed_size=512, device='cuda:0'):
         self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits'])
         self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float()
         self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(
-            in_channels=7,
+            in_channels=pixel_channel,
             out_channels=self.pixelgoal_encoder.patch_embed.proj.out_channels,
             kernel_size=self.pixelgoal_encoder.patch_embed.proj.kernel_size,
             stride=self.pixelgoal_encoder.patch_embed.proj.stride,
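The final hunk rebuilds the DepthAnythingV2 patch-embedding projection so its input channels match the configured pixel_channel. A generic sketch of that pattern; the helper below is illustrative, and note that the rebuilt conv starts from random weights, as in the diff:

import torch.nn as nn

def widen_patch_embed(patch_conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    # Rebuild a patch-embedding conv for a new input channel count while keeping
    # the output channels, kernel size, and stride, so the downstream ViT sees
    # identically shaped tokens. Weights are freshly initialized, not copied.
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=patch_conv.out_channels,
        kernel_size=patch_conv.kernel_size,
        stride=patch_conv.stride,
        padding=patch_conv.padding,
    )

# usage sketch: a 16x16 / stride-16 projection widened from 3 to 7 input channels
proj = nn.Conv2d(3, 384, kernel_size=16, stride=16)
proj = widen_patch_embed(proj, in_channels=7)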

scripts/train/configs/navdp.py

Lines changed: 2 additions & 1 deletion

@@ -29,7 +29,7 @@
     ),
     il=IlCfg(
         epochs=1000,
-        batch_size=16,
+        batch_size=32,
         lr=1e-4,
         num_workers=8,
         weight_decay=1e-4,  # TODO
@@ -57,6 +57,7 @@
         prior_sample=False,
         memory_size=8,
         predict_size=24,
+        pixel_channel=4,
         temporal_depth=16,
         heads=8,
         token_dim=384,
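For reference, the channel arithmetic behind the two supported pixel_channel values, matching the dataset comment above; this default config therefore uses the 4-channel, current-frame layout:

# channel accounting for the pixel-goal tensor
MASK, RGB = 1, 3
assert MASK + RGB + RGB == 7  # asynchronous: mask + goal history frame + current frame
assert MASK + RGB == 4        # goal assigned on the current frame: mask + current frame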

scripts/train/train.py

Lines changed: 3 additions & 2 deletions

@@ -77,12 +77,12 @@ def main(config, model_class, model_config_class):
     """Main training function."""
     _make_dir(config)

-    print(f"=== Start training ===")
+    print("=== Start training ===")
     print(f"Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
     print(f"PyTorch version: {torch.__version__}")
     print(f"CUDA available: {torch.cuda.is_available()}")
     print(f"CUDA device count: {torch.cuda.device_count()}")
-    print(f"Environment variables:")
+    print("Environment variables:")
     print(f" RANK: {os.getenv('RANK', 'Not set')}")
     print(f" LOCAL_RANK: {os.getenv('LOCAL_RANK', 'Not set')}")
     print(f" WORLD_SIZE: {os.getenv('WORLD_SIZE', 'Not set')}")
@@ -172,6 +172,7 @@ def main(config, model_class, model_config_class):
         config.il.batch_size,
         config.il.image_size,
         config.il.scene_scale,
+        pixel_channel=config.il.pixel_channel,
         preload=config.il.preload,
         random_digit=config.il.random_digit,
         prior_sample=config.il.prior_sample,
