77
88from internnav .configs .model .base_encoders import ModelCfg
99from internnav .configs .trainer .exp import ExpCfg
10- from internnav .model .encoder .navdp_backbone import *
10+ from internnav .model .encoder .navdp_backbone import (
11+ ImageGoalBackbone ,
12+ LearnablePositionalEncoding ,
13+ PixelGoalBackbone ,
14+ RGBDBackbone ,
15+ SinusoidalPosEmb ,
16+ )
1117
1218
1319class NavDPModelConfig (PretrainedConfig ):
@@ -51,9 +57,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
5157 elif pretrained_model_name_or_path is None or len (pretrained_model_name_or_path ) == 0 :
5258 pass
5359 else :
54- incompatible_keys , _ = model .load_state_dict (
55- torch .load (pretrained_model_name_or_path )['state_dict' ], strict = False
56- )
60+ incompatible_keys , _ = model .load_state_dict (torch .load (pretrained_model_name_or_path ), strict = False )
5761 if len (incompatible_keys ) > 0 :
5862 print (f'Incompatible keys: { incompatible_keys } ' )
5963
@@ -66,31 +70,32 @@ def __init__(self, config: NavDPModelConfig):
6670 self .model_config = ModelCfg (** config .model_cfg ['model' ])
6771 else :
6872 self .model_config = config
69-
7073 self .config .model_cfg ['il' ]
71-
7274 self ._device = torch .device (f"cuda:{ config .model_cfg ['local_rank' ]} " )
7375 self .image_size = self .config .model_cfg ['il' ]['image_size' ]
7476 self .memory_size = self .config .model_cfg ['il' ]['memory_size' ]
7577 self .predict_size = self .config .model_cfg ['il' ]['predict_size' ]
78+ self .pixel_channel = self .config .model_cfg ['il' ]['pixel_channel' ]
7679 self .temporal_depth = self .config .model_cfg ['il' ]['temporal_depth' ]
7780 self .attention_heads = self .config .model_cfg ['il' ]['heads' ]
7881 self .input_channels = self .config .model_cfg ['il' ]['channels' ]
7982 self .dropout = self .config .model_cfg ['il' ]['dropout' ]
8083 self .token_dim = self .config .model_cfg ['il' ]['token_dim' ]
8184 self .scratch = self .config .model_cfg ['il' ]['scratch' ]
8285 self .finetune = self .config .model_cfg ['il' ]['finetune' ]
83- self .rgbd_encoder = NavDP_RGBD_Backbone (
86+ self .rgbd_encoder = RGBDBackbone (
8487 self .image_size , self .token_dim , memory_size = self .memory_size , finetune = self .finetune , device = self ._device
8588 )
86- self .pixel_encoder = NavDP_PixelGoal_Backbone (self .image_size , self .token_dim , device = self ._device )
87- self .image_encoder = NavDP_ImageGoal_Backbone (self .image_size , self .token_dim , device = self ._device )
89+ self .pixel_encoder = PixelGoalBackbone (
90+ self .image_size , self .token_dim , pixel_channel = self .pixel_channel , device = self ._device
91+ )
92+ self .image_encoder = ImageGoalBackbone (self .image_size , self .token_dim , device = self ._device )
8893 self .point_encoder = nn .Linear (3 , self .token_dim )
8994
9095 if not self .finetune :
91- for p in self .rgbd_encoder .parameters ():
96+ for p in self .rgbd_encoder .rgb_model . parameters ():
9297 p .requires_grad = False
93- self .rgbd_encoder .eval ()
98+ self .rgbd_encoder .rgb_model . eval ()
9499
95100 decoder_layer = nn .TransformerDecoderLayer (
96101 d_model = self .token_dim ,
@@ -180,23 +185,6 @@ def predict_critic(self, predict_trajectory, rgbd_embed):
180185 return critic_output
181186
182187 def forward (self , goal_point , goal_image , goal_pixel , input_images , input_depths , output_actions , augment_actions ):
183- # """get device safely"""
184- # # get device safely
185- # try:
186- # # try to get device through model parameters
187- # device = next(self.parameters()).device
188- # except StopIteration:
189- # # model has no parameters, use the default device
190- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191- # # move all inputs to model device
192- # goal_point = goal_point.to(device)
193- # goal_image = goal_image.to(device)
194- # input_images = input_images.to(device)
195- # input_depths = input_depths.to(device)
196- # output_actions = output_actions.to(device)
197- # augment_actions = augment_actions.to(device)
198- # device = self._device
199- # print(f"self.parameters() is:{self.parameters()}")
200188 device = next (self .parameters ()).device
201189
202190 assert input_images .shape [1 ] == self .memory_size
@@ -325,7 +313,6 @@ def predict_pointgoal_batch_action_vel(self, goal_point, input_images, input_dep
325313 naction = self .noise_scheduler .step (model_output = noise_pred , timestep = k , sample = naction ).prev_sample
326314
327315 critic_values = self .predict_critic (naction , rgbd_embed )
328- all_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )
329316
330317 negative_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )[(critic_values ).argsort ()[0 :8 ]]
331318 positive_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )[(- critic_values ).argsort ()[0 :8 ]]
@@ -344,7 +331,6 @@ def predict_nogoal_batch_action_vel(self, input_images, input_depths, sample_num
344331 naction = self .noise_scheduler .step (model_output = noise_pred , timestep = k , sample = naction ).prev_sample
345332
346333 critic_values = self .predict_critic (naction , rgbd_embed )
347- all_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )
348334
349335 negative_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )[(critic_values ).argsort ()[0 :8 ]]
350336 positive_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )[(- critic_values ).argsort ()[0 :8 ]]
0 commit comments