From 4cc98c236bd0397f0f79a69a17bf43d506cf8f4f Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Tue, 25 Feb 2025 21:32:03 +0000 Subject: [PATCH 01/12] Bump deps: torch v2 and pytorchvideo latest --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6f576290..154cdaa1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -torch==1.13.1 +torch>=2.6.0 torchvision # because torch version already specific, the right torchvision will be derived automatically torchaudio # because torch version already specific, the right torchaudio will be derived automatically -pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d +pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@6cdc929315aab1b5674b6dcf73b16ec99147735f timm==0.6.7 ftfy regex From 7d68fc308205d3d013721b91798c951ed59afe06 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Tue, 25 Feb 2025 21:51:37 +0000 Subject: [PATCH 02/12] Update deprecated modules --- imagebind/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imagebind/data.py b/imagebind/data.py index 6b774d60..56dd7253 100644 --- a/imagebind/data.py +++ b/imagebind/data.py @@ -17,7 +17,7 @@ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler from pytorchvideo.data.encoded_video import EncodedVideo from torchvision import transforms -from torchvision.transforms._transforms_video import NormalizeVideo +from torchvision.transforms import Normalize from imagebind.models.multimodal_preprocessors import SimpleTokenizer From 2997475598df09502a162fa6c41d85c4ec11220c Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Tue, 25 Feb 2025 22:01:53 +0000 Subject: [PATCH 03/12] Lower pytorch base version in requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 154cdaa1..4d49b77c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch>=2.6.0 +torch>=2.0.0 torchvision # because torch version already specific, the right torchvision will be derived automatically torchaudio # because torch version already specific, the right torchaudio will be derived automatically pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@6cdc929315aab1b5674b6dcf73b16ec99147735f From 80d074583850e3085a75aea85b29bd7d5c6bd33d Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Wed, 26 Feb 2025 00:18:03 +0200 Subject: [PATCH 04/12] Update usage instructions --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cacf3e90..b18ea4d8 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Emergent zero-shot classification performance. ## Usage -Install pytorch 1.13+ and other 3rd party dependencies. +Install pytorch 2.0+ and other 3rd party dependencies.
```shell conda create --name imagebind python=3.10 -y From ef9b462c5b975913c8687713d9225466d2ca8f5a Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 27 Feb 2025 00:05:32 +0000 Subject: [PATCH 05/12] relax timm and replace deprecated modules --- imagebind/models/multimodal_preprocessors.py | 2 +- imagebind/models/transformer.py | 2 +- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/imagebind/models/multimodal_preprocessors.py b/imagebind/models/multimodal_preprocessors.py index 3384b871..b7b531a6 100644 --- a/imagebind/models/multimodal_preprocessors.py +++ b/imagebind/models/multimodal_preprocessors.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn from iopath.common.file_io import g_pathmgr -from timm.models.layers import trunc_normal_ +from timm.layers import trunc_normal_ from imagebind.models.helpers import VerboseNNModule, cast_if_src_dtype diff --git a/imagebind/models/transformer.py b/imagebind/models/transformer.py index 6224faf8..621d2b6d 100644 --- a/imagebind/models/transformer.py +++ b/imagebind/models/transformer.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, trunc_normal_ +from timm.layers import DropPath, trunc_normal_ class Attention(nn.Module): diff --git a/requirements.txt b/requirements.txt index 4d49b77c..465581c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torch>=2.0.0 torchvision # because torch version already specific, the right torchvision will be derived automatically torchaudio # because torch version already specific, the right torchaudio will be derived automatically pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@6cdc929315aab1b5674b6dcf73b16ec99147735f -timm==0.6.7 +timm ftfy regex einops From 274d2c04b3093efc09324a648a638fed58e10f51 Mon Sep 17 00:00:00 2001 From: Ahmedsaed Date: Tue, 8 Apr 2025 22:27:41 +0200 Subject: [PATCH 06/12] load state dict with weights_only option --- imagebind/models/imagebind_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imagebind/models/imagebind_model.py b/imagebind/models/imagebind_model.py index c560945f..d40c04c2 100644 --- a/imagebind/models/imagebind_model.py +++ b/imagebind/models/imagebind_model.py @@ -501,6 +501,6 @@ def imagebind_huge(pretrained=False): progress=True, ) - model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth")) + model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth", weights_only=True)) return model From db143b21290bdcd10b36005e9445b60dc9fd2aa4 Mon Sep 17 00:00:00 2001 From: Ahmedsaed Date: Tue, 8 Apr 2025 23:26:09 +0200 Subject: [PATCH 07/12] Add model splitting and modular loading functionality for ImageBind --- .gitignore | 3 +- imagebind/model_splitter.py | 74 +++++++++++ imagebind/modular_imagebind.py | 229 +++++++++++++++++++++++++++++++++ 3 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 imagebind/model_splitter.py create mode 100644 imagebind/modular_imagebind.py diff --git a/.gitignore b/.gitignore index 90529b73..2b9ec12f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ build/ imagebind.egg-info .DS_Store -venv/ \ No newline at end of file +venv/ +.checkpoints \ No newline at end of file diff --git a/imagebind/model_splitter.py b/imagebind/model_splitter.py new file mode 100644 index 00000000..ba673929 --- /dev/null +++ b/imagebind/model_splitter.py @@ -0,0 +1,74 @@ +#!/usr/bin/env 
python3 +import os +import torch +from imagebind.models.imagebind_model import imagebind_huge, ModalityType +from collections import OrderedDict + + +def split_imagebind_model(pretrained=True, save_dir=".checkpoints/modality_specific"): + """ + Load the full ImageBind model, split it by modality, and save modality-specific weights. + + Args: + pretrained: Whether to load pretrained weights + save_dir: Directory to save modality-specific weights + """ + # Create save directory if it doesn't exist + os.makedirs(save_dir, exist_ok=True) + + # Load the full model with pretrained weights + print(f"Loading full ImageBind model with pretrained={pretrained}...") + full_model = imagebind_huge(pretrained=pretrained) + + # Define the modalities we want to split + modalities = [ + ModalityType.VISION, + ModalityType.TEXT, + ModalityType.AUDIO, + ModalityType.DEPTH, + ModalityType.THERMAL, + ModalityType.IMU, + ] + + for modality in modalities: + print(f"Creating weights for {modality} modality...") + + # Create a dictionary for modality-specific state dict + modality_state_dict = OrderedDict() + + # Get the full state dict + full_state_dict = full_model.state_dict() + + # Extract common parameters (not specific to any modality) + common_prefixes = [] + + # Extract modality-specific parameters + modality_prefixes = [ + f"modality_preprocessors.{modality}", + f"modality_trunks.{modality}", + f"modality_heads.{modality}", + f"modality_postprocessors.{modality}", + ] + + # Collect all parameters for this modality + for k, v in full_state_dict.items(): + # Check if this is a modality-specific parameter + is_modality_specific = any( + k.startswith(prefix) for prefix in modality_prefixes + ) + is_common = any(k.startswith(prefix) for prefix in common_prefixes) + + if is_modality_specific or is_common: + modality_state_dict[k] = v + + # Save modality-specific state dict + save_path = os.path.join(save_dir, f"imagebind_{modality}.pth") + torch.save(modality_state_dict, save_path) + print(f"Saved {modality} weights to {save_path}") + print(f"Number of parameters: {len(modality_state_dict)}") + + print("Finished splitting model.") + + +if __name__ == "__main__": + split_imagebind_model(pretrained=True) diff --git a/imagebind/modular_imagebind.py b/imagebind/modular_imagebind.py new file mode 100644 index 00000000..c1c4604a --- /dev/null +++ b/imagebind/modular_imagebind.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 +import os +import torch +from imagebind.models.imagebind_model import ImageBindModel, ModalityType + + +class ModularImageBind(ImageBindModel): + """ + An extension of ImageBindModel that allows loading specific modalities only. + """ + + def __init__( + self, modalities=None, weights_dir=".checkpoints/modality_specific", **kwargs + ): + """ + Initialize a modality-specific ImageBind model. 
+ + Args: + modalities: List of modalities to load (default: all modalities) + weights_dir: Directory containing modality-specific weights + **kwargs: Additional arguments to pass to ImageBindModel + """ + # Initialize with all modalities to create the architecture + super().__init__(**kwargs) + + # If no modalities specified, use all available + if modalities is None: + modalities = [ + ModalityType.VISION, + ModalityType.TEXT, + ModalityType.AUDIO, + ModalityType.DEPTH, + ModalityType.THERMAL, + ModalityType.IMU, + ] + + self.active_modalities = set(modalities) + + # Load weights for each modality + for modality in modalities: + self._load_modality_weights(modality, weights_dir) + + def _load_modality_weights(self, modality, weights_dir): + """ + Load weights for a specific modality. + + Args: + modality: Modality to load + weights_dir: Directory containing modality-specific weights + """ + weight_path = os.path.join(weights_dir, f"imagebind_{modality}.pth") + + if not os.path.exists(weight_path): + raise FileNotFoundError( + f"Weights for {modality} not found at {weight_path}" + ) + + # Load modality-specific weights + modality_state_dict = torch.load(weight_path, weights_only=True) + + # Create a temporary state dict for the current model state + current_state_dict = self.state_dict() + + # Update only the parameters for this modality + for k, v in modality_state_dict.items(): + if k in current_state_dict: + current_state_dict[k] = v + + # Load the updated state dict + self.load_state_dict(current_state_dict, strict=False) + + print(f"Loaded weights for {modality} modality") + + def forward(self, inputs): + """ + Forward pass for the model, using only active modalities. + + Args: + inputs: Dictionary of inputs for different modalities + + Returns: + Dictionary of outputs for the active modalities + """ + # Filter inputs to only use active modalities + filtered_inputs = { + k: v for k, v in inputs.items() if k in self.active_modalities + } + + # Call the parent's forward method with filtered inputs + return super().forward(filtered_inputs) + + +def load_modular_imagebind_huge( + modalities=None, weights_dir=".checkpoints/modality_specific" +): + """ + Helper function to load a modular ImageBind model with specific modalities. 
+ + Args: + modalities: List of modalities to load (default: all modalities) + weights_dir: Directory containing modality-specific weights + + Returns: + ModularImageBind model with requested modalities + """ + model = ModularImageBind( + modalities=modalities, + weights_dir=weights_dir, + vision_embed_dim=1280, + vision_num_blocks=32, + vision_num_heads=16, + text_embed_dim=1024, + text_num_blocks=24, + text_num_heads=16, + out_embed_dim=1024, + audio_drop_path=0.1, + imu_drop_path=0.7, + ) + return model + + +# Example usage: +if __name__ == "__main__": + from imagebind import data + + text_list = ["A dog.", "A car", "A bird"] + image_paths = [ + ".assets/dog_image.jpg", + ".assets/car_image.jpg", + ".assets/bird_image.jpg", + ] + audio_paths = [ + ".assets/dog_audio.wav", + ".assets/car_audio.wav", + ".assets/bird_audio.wav", + ] + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Example 1: Load only vision and text modalities + print("Loading Vision-Text model...") + model_vision_text = load_modular_imagebind_huge( + modalities=[ModalityType.VISION, ModalityType.TEXT] + ) + model_vision_text.to(device) + + inputs = { + ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device), + ModalityType.TEXT: data.load_and_transform_text(text_list, device), + } + + # Perform inference + with torch.no_grad(): + embeddings = model_vision_text(inputs) + + print( + "Vision x Text: ", + torch.softmax( + embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1 + ), + ) + + del model_vision_text + del inputs + del embeddings + + # Example 2: Load only audio modality + print("Loading Audio model...") + model_audio = load_modular_imagebind_huge(modalities=[ModalityType.AUDIO]) + model_audio.to(device) + + inputs = { + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), + } + + # Perform inference + with torch.no_grad(): + embeddings = model_audio(inputs) + + print( + "Audio: ", + torch.softmax( + embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.AUDIO].T, dim=-1 + ), + ) + + del model_audio + del inputs + del embeddings + + # Example 3: Create a multimodal model with vision, text, and audio + print("Loading Multimodal model...") + model_multimodal = load_modular_imagebind_huge( + modalities=[ModalityType.VISION, ModalityType.TEXT, ModalityType.AUDIO] + ) + model_multimodal.to(device) + + inputs = { + ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device), + ModalityType.TEXT: data.load_and_transform_text(text_list, device), + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), + } + + # Perform inference + with torch.no_grad(): + embeddings = model_multimodal(inputs) + + print( + "Vision x Text: ", + torch.softmax( + embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1 + ), + ) + print( + "Audio x Text: ", + torch.softmax( + embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1 + ), + ) + print( + "Vision x Audio: ", + torch.softmax( + embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1 + ), + ) + + del model_multimodal + del inputs + del embeddings From a1614d6645425346ceb012396a0a3d76738b301f Mon Sep 17 00:00:00 2001 From: Ahmedsaed Date: Thu, 10 Apr 2025 12:46:41 +0200 Subject: [PATCH 08/12] Refactor ModularImageBind to support specific modality loading and profile memory usage --- imagebind/modular_imagebind.py | 472 +++++++++++++++++++++++++++++++-- 1 file changed, 445 insertions(+), 27 deletions(-) diff 
--git a/imagebind/modular_imagebind.py b/imagebind/modular_imagebind.py index c1c4604a..cdd0e41c 100644 --- a/imagebind/modular_imagebind.py +++ b/imagebind/modular_imagebind.py @@ -1,7 +1,30 @@ #!/usr/bin/env python3 import os -import torch +from memory_profiler import profile from imagebind.models.imagebind_model import ImageBindModel, ModalityType +from functools import partial + +import torch +import torch.nn as nn + +from imagebind.models.helpers import ( + EinOpsRearrange, + LearnableLogitScaling, + Normalize, + SelectElement, + SelectEOSAndProject, +) +from imagebind.models.multimodal_preprocessors import ( + AudioPreprocessor, + IMUPreprocessor, + PadIm2Video, + PatchEmbedGeneric, + RGBDTPreprocessor, + SpatioTemporalPosEmbeddingHelper, + TextPreprocessor, + ThermalPreprocessor, +) +from imagebind.models.transformer import MultiheadAttention, SimpleTransformer class ModularImageBind(ImageBindModel): @@ -20,9 +43,6 @@ def __init__( weights_dir: Directory containing modality-specific weights **kwargs: Additional arguments to pass to ImageBindModel """ - # Initialize with all modalities to create the architecture - super().__init__(**kwargs) - # If no modalities specified, use all available if modalities is None: modalities = [ @@ -36,10 +56,363 @@ def __init__( self.active_modalities = set(modalities) + # Initialize with all modalities to create the architecture + super().__init__(**kwargs) + # Load weights for each modality for modality in modalities: self._load_modality_weights(modality, weights_dir) + def _create_modality_preprocessors( + self, + video_frames=2, + vision_embed_dim=1024, + kernel_size=(2, 14, 14), + text_embed_dim=768, + audio_embed_dim=768, + audio_kernel_size=16, + audio_stride=10, + audio_num_mel_bins=128, + audio_target_len=204, + depth_embed_dim=768, + depth_kernel_size=16, + thermal_embed_dim=768, + thermal_kernel_size=16, + imu_embed_dim=512, + ): + if ModalityType.VISION in self.active_modalities: + rgbt_stem = PatchEmbedGeneric( + proj_stem=[ + PadIm2Video(pad_type="repeat", ntimes=2), + nn.Conv3d( + in_channels=3, + kernel_size=kernel_size, + out_channels=vision_embed_dim, + stride=kernel_size, + bias=False, + ), + ] + ) + rgbt_preprocessor = RGBDTPreprocessor( + img_size=[3, video_frames, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=rgbt_stem, + depth_stem=None, + ) + + if ModalityType.TEXT in self.active_modalities: + text_preprocessor = TextPreprocessor( + context_length=77, + vocab_size=49408, + embed_dim=text_embed_dim, + causal_masking=True, + ) + + if ModalityType.AUDIO in self.active_modalities: + audio_stem = PatchEmbedGeneric( + proj_stem=[ + nn.Conv2d( + in_channels=1, + kernel_size=audio_kernel_size, + stride=audio_stride, + out_channels=audio_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim), + ) + audio_preprocessor = AudioPreprocessor( + img_size=[1, audio_num_mel_bins, audio_target_len], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + audio_stem=audio_stem, + ) + + if ModalityType.DEPTH in self.active_modalities: + depth_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=depth_kernel_size, + in_channels=1, + out_channels=depth_embed_dim, + stride=depth_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim), + ) + + depth_preprocessor = RGBDTPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + 
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=None, + depth_stem=depth_stem, + ) + + if ModalityType.THERMAL in self.active_modalities: + thermal_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=thermal_kernel_size, + in_channels=1, + out_channels=thermal_embed_dim, + stride=thermal_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim), + ) + thermal_preprocessor = ThermalPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + thermal_stem=thermal_stem, + ) + + if ModalityType.IMU in self.active_modalities: + imu_stem = PatchEmbedGeneric( + [ + nn.Linear( + in_features=48, + out_features=imu_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim), + ) + + imu_preprocessor = IMUPreprocessor( + img_size=[6, 2000], + num_cls_tokens=1, + kernel_size=8, + embed_dim=imu_embed_dim, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + imu_stem=imu_stem, + ) + + modality_preprocessors = {} + if ModalityType.VISION in self.active_modalities: + modality_preprocessors[ModalityType.VISION] = rgbt_preprocessor + if ModalityType.TEXT in self.active_modalities: + modality_preprocessors[ModalityType.TEXT] = text_preprocessor + if ModalityType.AUDIO in self.active_modalities: + modality_preprocessors[ModalityType.AUDIO] = audio_preprocessor + if ModalityType.DEPTH in self.active_modalities: + modality_preprocessors[ModalityType.DEPTH] = depth_preprocessor + if ModalityType.THERMAL in self.active_modalities: + modality_preprocessors[ModalityType.THERMAL] = thermal_preprocessor + if ModalityType.IMU in self.active_modalities: + modality_preprocessors[ModalityType.IMU] = imu_preprocessor + + return nn.ModuleDict(modality_preprocessors) + + def _create_modality_trunks( + self, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_drop_path=0.0, + depth_embed_dim=768, + depth_num_blocks=12, + depth_num_heads=12, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + def instantiate_trunk( + embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path + ): + return SimpleTransformer( + embed_dim=embed_dim, + num_blocks=num_blocks, + ffn_dropout_rate=0.0, + drop_path_rate=drop_path, + attn_target=partial( + MultiheadAttention, + embed_dim=embed_dim, + num_heads=num_heads, + bias=True, + add_bias_kv=add_bias_kv, + ), + pre_transformer_layer=nn.Sequential( + ( + nn.LayerNorm(embed_dim, eps=1e-6) + if pre_transformer_ln + else nn.Identity() + ), + EinOpsRearrange("b l d -> l b d"), + ), + post_transformer_layer=EinOpsRearrange("l b d -> b l d"), + ) + + modality_trunks = {} + + if ModalityType.VISION in self.active_modalities: + modality_trunks[ModalityType.VISION] = instantiate_trunk( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + pre_transformer_ln=True, + add_bias_kv=False, + drop_path=0.0, + ) + if ModalityType.TEXT in self.active_modalities: + modality_trunks[ModalityType.TEXT] = instantiate_trunk( + text_embed_dim, + text_num_blocks, + text_num_heads, + pre_transformer_ln=False, + add_bias_kv=False, + drop_path=0.0, + ) + if ModalityType.AUDIO 
in self.active_modalities: + modality_trunks[ModalityType.AUDIO] = instantiate_trunk( + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=audio_drop_path, + ) + if ModalityType.DEPTH in self.active_modalities: + modality_trunks[ModalityType.DEPTH] = instantiate_trunk( + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=depth_drop_path, + ) + if ModalityType.THERMAL in self.active_modalities: + modality_trunks[ModalityType.THERMAL] = instantiate_trunk( + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=thermal_drop_path, + ) + if ModalityType.IMU in self.active_modalities: + modality_trunks[ModalityType.IMU] = instantiate_trunk( + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=imu_drop_path, + ) + + return nn.ModuleDict(modality_trunks) + + def _create_modality_heads( + self, + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ): + modality_heads = {} + + if ModalityType.VISION in self.active_modalities: + modality_heads[ModalityType.VISION] = nn.Sequential( + nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(vision_embed_dim, out_embed_dim, bias=False), + ) + + if ModalityType.TEXT in self.active_modalities: + modality_heads[ModalityType.TEXT] = SelectEOSAndProject( + proj=nn.Sequential( + nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6), + nn.Linear(text_embed_dim, out_embed_dim, bias=False), + ) + ) + + if ModalityType.AUDIO in self.active_modalities: + modality_heads[ModalityType.AUDIO] = nn.Sequential( + nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(audio_embed_dim, out_embed_dim, bias=False), + ) + + if ModalityType.DEPTH in self.active_modalities: + modality_heads[ModalityType.DEPTH] = nn.Sequential( + nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(depth_embed_dim, out_embed_dim, bias=False), + ) + + if ModalityType.THERMAL in self.active_modalities: + modality_heads[ModalityType.THERMAL] = nn.Sequential( + nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(thermal_embed_dim, out_embed_dim, bias=False), + ) + + if ModalityType.IMU in self.active_modalities: + modality_heads[ModalityType.IMU] = nn.Sequential( + nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Dropout(p=0.5), + nn.Linear(imu_embed_dim, out_embed_dim, bias=False), + ) + + return nn.ModuleDict(modality_heads) + + def _create_modality_postprocessors(self, out_embed_dim): + modality_postprocessors = {} + + if ModalityType.VISION in self.active_modalities: + modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1) + + if ModalityType.TEXT in self.active_modalities: + modality_postprocessors[ModalityType.TEXT] = nn.Sequential( + Normalize(dim=-1), LearnableLogitScaling(learnable=True) + ) + + if ModalityType.AUDIO in self.active_modalities: + modality_postprocessors[ModalityType.AUDIO] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=20.0, learnable=False), + ) + + if ModalityType.DEPTH in self.active_modalities: + modality_postprocessors[ModalityType.DEPTH] = nn.Sequential( + Normalize(dim=-1), + 
LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + + if ModalityType.THERMAL in self.active_modalities: + modality_postprocessors[ModalityType.THERMAL] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=10.0, learnable=False), + ) + + if ModalityType.IMU in self.active_modalities: + modality_postprocessors[ModalityType.IMU] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + + return nn.ModuleDict(modality_postprocessors) + def _load_modality_weights(self, modality, weights_dir): """ Load weights for a specific modality. @@ -119,8 +492,8 @@ def load_modular_imagebind_huge( return model -# Example usage: -if __name__ == "__main__": +@profile +def vision_text_example(): from imagebind import data text_list = ["A dog.", "A car", "A bird"] @@ -129,11 +502,6 @@ def load_modular_imagebind_huge( ".assets/car_image.jpg", ".assets/bird_image.jpg", ] - audio_paths = [ - ".assets/dog_audio.wav", - ".assets/car_audio.wav", - ".assets/bird_audio.wav", - ] device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -160,9 +528,18 @@ def load_modular_imagebind_huge( ), ) - del model_vision_text - del inputs - del embeddings + +@profile +def audio_example(): + from imagebind import data + + audio_paths = [ + ".assets/dog_audio.wav", + ".assets/car_audio.wav", + ".assets/bird_audio.wav", + ] + + device = "cuda:0" if torch.cuda.is_available() else "cpu" # Example 2: Load only audio modality print("Loading Audio model...") @@ -184,9 +561,24 @@ def load_modular_imagebind_huge( ), ) - del model_audio - del inputs - del embeddings + +@profile +def multimodal_example(): + from imagebind import data + + text_list = ["A dog.", "A car", "A bird"] + image_paths = [ + ".assets/dog_image.jpg", + ".assets/car_image.jpg", + ".assets/bird_image.jpg", + ] + audio_paths = [ + ".assets/dog_audio.wav", + ".assets/car_audio.wav", + ".assets/bird_audio.wav", + ] + + device = "cuda:0" if torch.cuda.is_available() else "cpu" # Example 3: Create a multimodal model with vision, text, and audio print("Loading Multimodal model...") @@ -211,19 +603,45 @@ def load_modular_imagebind_huge( embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1 ), ) - print( - "Audio x Text: ", - torch.softmax( - embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1 - ), + + +@profile +def audio_thermal_example(): + from imagebind import data + + audio_paths = [ + ".assets/dog_audio.wav", + ".assets/car_audio.wav", + ".assets/bird_audio.wav", + ] + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Example 4: Create a multimodal model with audio and thermal + print("Loading Audio-Thermal model...") + model_audio_thermal = load_modular_imagebind_huge( + modalities=[ModalityType.AUDIO, ModalityType.THERMAL] ) + model_audio_thermal.to(device) + + inputs = { + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), + } + + # Perform inference + with torch.no_grad(): + embeddings = model_audio_thermal(inputs) + print( - "Vision x Audio: ", + "Audio x Thermal: ", torch.softmax( - embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1 + embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.AUDIO].T, dim=-1 ), ) - del model_multimodal - del inputs - del embeddings + +if __name__ == "__main__": + # vision_text_example() + # audio_example() + multimodal_example() + # audio_thermal_example() From 903be6bfba98a6385786624844b0e7f439f3520f Mon Sep 17 00:00:00 2001 From: 
Ahmedsaed Date: Thu, 10 Apr 2025 12:54:37 +0200 Subject: [PATCH 09/12] Cleanup and removing of memory profiling --- imagebind/modular_imagebind.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/imagebind/modular_imagebind.py b/imagebind/modular_imagebind.py index cdd0e41c..0e01626d 100644 --- a/imagebind/modular_imagebind.py +++ b/imagebind/modular_imagebind.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import os -from memory_profiler import profile from imagebind.models.imagebind_model import ImageBindModel, ModalityType from functools import partial @@ -442,11 +441,9 @@ def _load_modality_weights(self, modality, weights_dir): # Load the updated state dict self.load_state_dict(current_state_dict, strict=False) - print(f"Loaded weights for {modality} modality") - def forward(self, inputs): """ - Forward pass for the model, using only active modalities. + Forward pass through the model. Args: inputs: Dictionary of inputs for different modalities @@ -454,13 +451,15 @@ def forward(self, inputs): Returns: Dictionary of outputs for the active modalities """ - # Filter inputs to only use active modalities - filtered_inputs = { - k: v for k, v in inputs.items() if k in self.active_modalities - } + # Raise an error if input modalities are not in the active modalities + for modality in inputs.keys(): + if modality not in self.active_modalities: + raise ValueError( + f"Input modality {modality} not in active modalities: {self.active_modalities}" + ) # Call the parent's forward method with filtered inputs - return super().forward(filtered_inputs) + return super().forward(inputs) def load_modular_imagebind_huge( @@ -492,7 +491,6 @@ def load_modular_imagebind_huge( return model -@profile def vision_text_example(): from imagebind import data @@ -529,7 +527,6 @@ def vision_text_example(): ) -@profile def audio_example(): from imagebind import data @@ -562,7 +559,6 @@ def audio_example(): ) -@profile def multimodal_example(): from imagebind import data @@ -605,7 +601,6 @@ def multimodal_example(): ) -@profile def audio_thermal_example(): from imagebind import data @@ -641,7 +636,7 @@ def audio_thermal_example(): if __name__ == "__main__": - # vision_text_example() - # audio_example() + vision_text_example() + audio_example() multimodal_example() - # audio_thermal_example() + audio_thermal_example() From 0636ad499f9217981cad1e9a42d9935e593dc25d Mon Sep 17 00:00:00 2001 From: Ahmedsaed Date: Thu, 10 Apr 2025 12:58:02 +0200 Subject: [PATCH 10/12] move example related code to the __main__ part --- imagebind/modular_imagebind.py | 309 ++++++++++++++++++--------------- 1 file changed, 165 insertions(+), 144 deletions(-) diff --git a/imagebind/modular_imagebind.py b/imagebind/modular_imagebind.py index 0e01626d..0d51c186 100644 --- a/imagebind/modular_imagebind.py +++ b/imagebind/modular_imagebind.py @@ -491,151 +491,172 @@ def load_modular_imagebind_huge( return model -def vision_text_example(): - from imagebind import data - - text_list = ["A dog.", "A car", "A bird"] - image_paths = [ - ".assets/dog_image.jpg", - ".assets/car_image.jpg", - ".assets/bird_image.jpg", - ] - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - - # Example 1: Load only vision and text modalities - print("Loading Vision-Text model...") - model_vision_text = load_modular_imagebind_huge( - modalities=[ModalityType.VISION, ModalityType.TEXT] - ) - model_vision_text.to(device) - - inputs = { - ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device), - 
ModalityType.TEXT: data.load_and_transform_text(text_list, device), - } - - # Perform inference - with torch.no_grad(): - embeddings = model_vision_text(inputs) - - print( - "Vision x Text: ", - torch.softmax( - embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1 - ), - ) - - -def audio_example(): - from imagebind import data - - audio_paths = [ - ".assets/dog_audio.wav", - ".assets/car_audio.wav", - ".assets/bird_audio.wav", - ] - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - - # Example 2: Load only audio modality - print("Loading Audio model...") - model_audio = load_modular_imagebind_huge(modalities=[ModalityType.AUDIO]) - model_audio.to(device) - - inputs = { - ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), - } - - # Perform inference - with torch.no_grad(): - embeddings = model_audio(inputs) - - print( - "Audio: ", - torch.softmax( - embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.AUDIO].T, dim=-1 - ), - ) - - -def multimodal_example(): - from imagebind import data - - text_list = ["A dog.", "A car", "A bird"] - image_paths = [ - ".assets/dog_image.jpg", - ".assets/car_image.jpg", - ".assets/bird_image.jpg", - ] - audio_paths = [ - ".assets/dog_audio.wav", - ".assets/car_audio.wav", - ".assets/bird_audio.wav", - ] - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - - # Example 3: Create a multimodal model with vision, text, and audio - print("Loading Multimodal model...") - model_multimodal = load_modular_imagebind_huge( - modalities=[ModalityType.VISION, ModalityType.TEXT, ModalityType.AUDIO] - ) - model_multimodal.to(device) - - inputs = { - ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device), - ModalityType.TEXT: data.load_and_transform_text(text_list, device), - ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), - } - - # Perform inference - with torch.no_grad(): - embeddings = model_multimodal(inputs) - - print( - "Vision x Text: ", - torch.softmax( - embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1 - ), - ) - - -def audio_thermal_example(): - from imagebind import data - - audio_paths = [ - ".assets/dog_audio.wav", - ".assets/car_audio.wav", - ".assets/bird_audio.wav", - ] - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - - # Example 4: Create a multimodal model with audio and thermal - print("Loading Audio-Thermal model...") - model_audio_thermal = load_modular_imagebind_huge( - modalities=[ModalityType.AUDIO, ModalityType.THERMAL] - ) - model_audio_thermal.to(device) - - inputs = { - ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), - } - - # Perform inference - with torch.no_grad(): - embeddings = model_audio_thermal(inputs) - - print( - "Audio x Thermal: ", - torch.softmax( - embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.AUDIO].T, dim=-1 - ), - ) - - if __name__ == "__main__": + """Example usage of the ModularImageBind model with different modalities.""" + + def vision_text_example(): + from imagebind import data + + text_list = ["A dog.", "A car", "A bird"] + image_paths = [ + ".assets/dog_image.jpg", + ".assets/car_image.jpg", + ".assets/bird_image.jpg", + ] + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Example 1: Load only vision and text modalities + print("Loading Vision-Text model...") + model_vision_text = load_modular_imagebind_huge( + modalities=[ModalityType.VISION, ModalityType.TEXT] + ) + model_vision_text.to(device) + + 
inputs = { + ModalityType.VISION: data.load_and_transform_vision_data( + image_paths, device + ), + ModalityType.TEXT: data.load_and_transform_text(text_list, device), + } + + # Perform inference + with torch.no_grad(): + embeddings = model_vision_text(inputs) + + print( + "Vision x Text: ", + torch.softmax( + embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, + dim=-1, + ), + ) + + def audio_example(): + from imagebind import data + + audio_paths = [ + ".assets/dog_audio.wav", + ".assets/car_audio.wav", + ".assets/bird_audio.wav", + ] + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Example 2: Load only audio modality + print("Loading Audio model...") + model_audio = load_modular_imagebind_huge(modalities=[ModalityType.AUDIO]) + model_audio.to(device) + + inputs = { + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), + } + + # Perform inference + with torch.no_grad(): + embeddings = model_audio(inputs) + + print( + "Audio: ", + torch.softmax( + embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.AUDIO].T, + dim=-1, + ), + ) + + def multimodal_example(): + from imagebind import data + + text_list = ["A dog.", "A car", "A bird"] + image_paths = [ + ".assets/dog_image.jpg", + ".assets/car_image.jpg", + ".assets/bird_image.jpg", + ] + audio_paths = [ + ".assets/dog_audio.wav", + ".assets/car_audio.wav", + ".assets/bird_audio.wav", + ] + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Example 3: Create a multimodal model with vision, text, and audio + print("Loading Multimodal model...") + model_multimodal = load_modular_imagebind_huge( + modalities=[ModalityType.VISION, ModalityType.TEXT, ModalityType.AUDIO] + ) + model_multimodal.to(device) + + inputs = { + ModalityType.VISION: data.load_and_transform_vision_data( + image_paths, device + ), + ModalityType.TEXT: data.load_and_transform_text(text_list, device), + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), + } + + # Perform inference + with torch.no_grad(): + embeddings = model_multimodal(inputs) + + print( + "Vision x Text: ", + torch.softmax( + embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, + dim=-1, + ), + ) + + print( + "Vision x Audio: ", + torch.softmax( + embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, + dim=-1, + ), + ) + + print( + "Text x Audio: ", + torch.softmax( + embeddings[ModalityType.TEXT] @ embeddings[ModalityType.AUDIO].T, dim=-1 + ), + ) + + def audio_thermal_example(): + from imagebind import data + + audio_paths = [ + ".assets/dog_audio.wav", + ".assets/car_audio.wav", + ".assets/bird_audio.wav", + ] + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Example 4: Create a multimodal model with audio and thermal + print("Loading Audio-Thermal model...") + model_audio_thermal = load_modular_imagebind_huge( + modalities=[ModalityType.AUDIO, ModalityType.THERMAL] + ) + model_audio_thermal.to(device) + + inputs = { + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), + } + + # Perform inference + with torch.no_grad(): + embeddings = model_audio_thermal(inputs) + + print( + "Audio x Thermal: ", + torch.softmax( + embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.AUDIO].T, + dim=-1, + ), + ) + vision_text_example() audio_example() multimodal_example() From 07c0dd50473b2c5e7280ffbc98347593cd4522dc Mon Sep 17 00:00:00 2001 From: Ahmedsaed Date: Tue, 15 Apr 2025 01:58:14 +0200 Subject: [PATCH 11/12] Add memory usage reporting for 
ModularImageBind model in example usage --- imagebind/modular_imagebind.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/imagebind/modular_imagebind.py b/imagebind/modular_imagebind.py index 0d51c186..98874f26 100644 --- a/imagebind/modular_imagebind.py +++ b/imagebind/modular_imagebind.py @@ -494,6 +494,18 @@ def load_modular_imagebind_huge( if __name__ == "__main__": """Example usage of the ModularImageBind model with different modalities.""" + def memory_usage(model): + param_size = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + + buffer_size = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + + total_size = (param_size + buffer_size) / 1024**2 + print(f"Model size: {total_size:.2f} MB") + def vision_text_example(): from imagebind import data @@ -513,6 +525,8 @@ def vision_text_example(): ) model_vision_text.to(device) + memory_usage(model_vision_text) + inputs = { ModalityType.VISION: data.load_and_transform_vision_data( image_paths, device @@ -548,6 +562,8 @@ def audio_example(): model_audio = load_modular_imagebind_huge(modalities=[ModalityType.AUDIO]) model_audio.to(device) + memory_usage(model_audio) + inputs = { ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), } @@ -588,6 +604,8 @@ def multimodal_example(): ) model_multimodal.to(device) + memory_usage(model_multimodal) + inputs = { ModalityType.VISION: data.load_and_transform_vision_data( image_paths, device @@ -641,6 +659,8 @@ def audio_thermal_example(): ) model_audio_thermal.to(device) + memory_usage(model_audio_thermal) + inputs = { ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), } From 92759e92ef59f38d4506899e960fe2cb8ca2f7f1 Mon Sep 17 00:00:00 2001 From: Ahmedsaed Date: Tue, 15 Apr 2025 13:59:46 +0200 Subject: [PATCH 12/12] move modular imagebind implementation to models directory --- .../{modular_imagebind.py => models/modular_imagebind_model.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename imagebind/{modular_imagebind.py => models/modular_imagebind_model.py} (100%) diff --git a/imagebind/modular_imagebind.py b/imagebind/models/modular_imagebind_model.py similarity index 100% rename from imagebind/modular_imagebind.py rename to imagebind/models/modular_imagebind_model.py
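
For reference, here is how the pieces added in this series fit together end to end. This is a minimal sketch, not part of the patches themselves; it assumes the default `.checkpoints/modality_specific` layout from `split_imagebind_model`, the module paths as of PATCH 12 (`imagebind.model_splitter` and `imagebind.models.modular_imagebind_model`), and the sample assets already used in the in-file examples.

```python
# Hypothetical end-to-end sketch: split the pretrained checkpoint once,
# then load only the modalities you need.
import torch

from imagebind import data
from imagebind.model_splitter import split_imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.models.modular_imagebind_model import load_modular_imagebind_huge

# One-time step: loads the full imagebind_huge checkpoint and writes
# .checkpoints/modality_specific/imagebind_<modality>.pth per modality.
split_imagebind_model(pretrained=True)

# Load a vision+text-only model; weight files for the other modalities are never read.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = load_modular_imagebind_huge(modalities=[ModalityType.VISION, ModalityType.TEXT])
model.eval().to(device)

inputs = {
    ModalityType.VISION: data.load_and_transform_vision_data([".assets/dog_image.jpg"], device),
    ModalityType.TEXT: data.load_and_transform_text(["A dog."], device),
}
with torch.no_grad():
    embeddings = model(inputs)

# Similarity scores between the image and text embeddings.
print(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T)
```

Loading only the modalities you need keeps both checkpoint I/O and the in-memory model size down, which is what the `memory_usage` helper added in PATCH 11 reports.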