diff --git a/code/decision_transformer/models/test_perceiver.py b/code/decision_transformer/models/test_perceiver.py
index 25878e6..c32d9d9 100644
--- a/code/decision_transformer/models/test_perceiver.py
+++ b/code/decision_transformer/models/test_perceiver.py
@@ -13,6 +13,8 @@
     "ViT-L-14", pretrained="openai")
 vision_encoder = vision_encoder.visual
 
+vision_encoder.output_tokens = True
+
 vis_dim=open_clip.get_model_config("ViT-L-14")["vision_cfg"]["width"]
 
 perceiver = PerceiverResampler(dim=vis_dim)
@@ -21,13 +23,26 @@
 batch = 5
 num_images = 27
 channels = 3
-height = 512
-width = 512
+height = 224
+width = 224
 input_data = torch.randn((batch, num_images, 1, channels, height, width))
 # vision_x (torch.Tensor): Vision input
 # shape (B, T_img, F, C, H, W) with F=1
 
+# Disabled manual sanity check against a single real image.
+if False:
+    from PIL import Image
+
+    image = image_processor(Image.open("dog.jpg")).unsqueeze(0)
+    print(image.shape)
+
+    with torch.no_grad(), torch.cuda.amp.autocast():
+        # vision_encoder is the .visual trunk, which has no encode_image();
+        # with output_tokens=True, calling it directly returns (pooled, tokens).
+        pooled, tokens = vision_encoder(image)
+        print(pooled.shape, tokens.shape)
+
 def encode_vision_x(vision_x: torch.Tensor):
     """
@@ -46,14 +61,18 @@ def encode_vision_x(vision_x: torch.Tensor):
     assert F == 1, "Only single frame supported"
 
     vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
+    print(vision_x.shape)  # flattened to ((b*T*F), c, h, w)
     with torch.no_grad():
-        vision_x = vision_encoder(vision_x)[1]
-
+        pooled, tokens = vision_encoder(vision_x)  # output_tokens=True -> (pooled, tokens)
+        # NOTE: the penultimate layer's tokens (-2) might be preferable here.
+        print(tokens.shape)  # ((b*T*F), num_patches, vis_dim)
 
-    vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
+    vision_x = rearrange(tokens, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
+    print(vision_x.shape)  # back to the original (b, T, F, v, d) layout
 
     vision_x = perceiver(vision_x)
+    print(vision_x.shape)  # fixed number of latent tokens per frame
 
     # for layer in lang_encoder._get_decoder_layers():
     #     layer.condition_vis_x(vision_x)
 
-encode_vision_x(input_data)
\ No newline at end of file
+encode_vision_x(input_data)
diff --git a/code/decision_transformer/models/vision_encoders.py b/code/decision_transformer/models/vision_encoders.py
index 36c6cf1..aa836a8 100644
--- a/code/decision_transformer/models/vision_encoders.py
+++ b/code/decision_transformer/models/vision_encoders.py
@@ -37,6 +37,7 @@ def __init__(self):
         self.vision_encoder, _, self.image_processor = open_clip.create_model_and_transforms(
             "ViT-L-14", pretrained="openai")
         self.vision_encoder = self.vision_encoder.visual
+        self.vision_encoder.output_tokens = True
         self.vis_dim = open_clip.get_model_config("ViT-L-14")["vision_cfg"]["width"]
 
         self.perceiver = PerceiverResampler(dim=self.vis_dim)
@@ -51,4 +52,4 @@ def forward(self, vision_x):
         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
 
         vision_x = self.perceiver(vision_x)
-        return vision_x
\ No newline at end of file
+        return vision_x
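
Side note: a quick standalone check of what output_tokens=True changes. This is a minimal sketch, not part of the patch; the "ViT-L-14"/"openai" weights come from the diff above, while the dummy batch of two 224x224 images is illustrative:

    import torch
    import open_clip

    model, _, _ = open_clip.create_model_and_transforms("ViT-L-14", pretrained="openai")
    visual = model.visual
    visual.output_tokens = True  # forward now returns (pooled, tokens) instead of just pooled

    with torch.no_grad():
        pooled, tokens = visual(torch.randn(2, 3, 224, 224))  # dummy batch, illustrative only

    # pooled is the projected CLIP embedding; tokens are the per-patch features,
    # whose last dim should match vision_cfg["width"] -- the vis_dim the
    # PerceiverResampler is built with.
    vis_dim = open_clip.get_model_config("ViT-L-14")["vision_cfg"]["width"]
    assert tokens.shape[-1] == vis_dim

This is also why vis_dim is read from vision_cfg["width"] rather than the model's output embed dim: the patch tokens come back before the output projection, so their width (1024 for ViT-L-14) is what the resampler has to consume.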