diff --git a/cookbook/sft/single_clip/two_phase_alignment/config_alignment_generalist_qwen.yaml b/cookbook/sft/single_clip/two_phase_alignment/config_alignment_generalist_qwen.yaml
new file mode 100644
index 0000000..30e3e9e
--- /dev/null
+++ b/cookbook/sft/single_clip/two_phase_alignment/config_alignment_generalist_qwen.yaml
@@ -0,0 +1,55 @@
+base_llm: Qwen/Qwen3-4B-Instruct-2507
+base_model: null
+attachment_token: <|reserved_special_token_0|>
+tokenizer_type: qwen3
+token_size: 2560
+
+loaders:
+  - loader_type: raw-image
+    modality_type: image
+
+modalities:
+  - model_type: meditron_biomedclip
+    clip_name: michel-ducartier/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
+    hidden_size: 2560
+    trust_remote_code: true
+
+training_mode: ALIGNMENT
+
+datasets:
+  # Non-formatted datasets
+  # ----------------------
+  - packed_path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/llava_pretrain_cleaned
+  - packed_path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/pixmo_anything
+  - packed_path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/pixmo_cap
+  - packed_path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/medtrinity_conversations_1_formatted_alignment/
+
+training_args:
+  output_dir: /capstor/store/cscs/swissai/a127/meditron/models/multimeditron/single_clip/two_phase/MultiMeditron-Qwen-4B-Alignment-Generalist-Delimiter
+  dataloader_num_workers: 16 # > 0 is not supported for IterableDataset, cf. https://github.com/huggingface/datasets/issues/5984
+  dataloader_prefetch_factor: 4
+  remove_unused_columns: false
+  ddp_find_unused_parameters: false # disabled as a test to reduce memory usage
+  learning_rate: 1.0e-4
+  bf16: true
+  per_device_train_batch_size: 4 # note that training_args.n_gpu and training_args.train_batch_size report incorrect values
+  # with DeepSpeed -> read them from deepspeed_plugin instead (training_args.distributed_state.num_processes == WORLD_SIZE remains correct)
+  gradient_accumulation_steps: 8
+  num_train_epochs: 1
+  gradient_checkpointing: true
+  gradient_checkpointing_kwargs:
+    use_reentrant: true
+  save_strategy: epoch
+  # save_steps: 0.25
+  max_grad_norm: 1.0
+  run_name: MultiMeditron-Qwen-4B-Alignment-Generalist-Delimiter # Set the name for the run
+  deepspeed: ./config/deepspeed.json
+  accelerator_config:
+    dispatch_batches: false
+  lr_scheduler_type: "cosine_with_min_lr"
+  lr_scheduler_kwargs:
+    min_lr: 3.0e-5
+  report_to: wandb
+  logging_steps: 1
+  weight_decay: 0.01
+
diff --git a/src/multimeditron/model/modalities/__init__.py b/src/multimeditron/model/modalities/__init__.py
index a1a087b..06db2c1 100644
--- a/src/multimeditron/model/modalities/__init__.py
+++ b/src/multimeditron/model/modalities/__init__.py
@@ -2,6 +2,7 @@
 from multimeditron.model.modalities.image_modality_moe import MOEImageConfig, MOEImageModality, MOEImageProcessor
 from multimeditron.model.modalities.image_modality_moe_pep import MOEImageConfigPEP, MOEImageModalityPEP, MOEImageProcessorPEP
 from multimeditron.model.modalities.image_modality import ImageConfig, ImageModality, ImageProcessor
+from multimeditron.model.modalities.image_modality_biomed import BioMedCLIPImageConfig, BioMedCLIPImageModality, BioMedCLIPImageProcessor
 
 __all__ = [
     "BaseModality",
@@ -17,4 +18,7 @@
     "ImageConfig",
     "ImageModality",
     "ImageProcessor",
+    "BioMedCLIPImageConfig",
+    "BioMedCLIPImageModality",
+    "BioMedCLIPImageProcessor",
 ]
diff --git a/src/multimeditron/model/modalities/image_modality.py b/src/multimeditron/model/modalities/image_modality.py
index 371bfc6..6703e03 100644
--- a/src/multimeditron/model/modalities/image_modality.py
+++ b/src/multimeditron/model/modalities/image_modality.py
@@ -136,10 +136,6 @@ def forward(self, inputs) -> torch.FloatTensor:
 
         return projected
 
-    @classmethod
-    def from_dict(cls, config_args, **kwargs):
-        return ImageConfig.from_dict(config_args, **kwargs)
-
     def freeze_modality_embedder(self):
         for parameters in self.feature_extractor.parameters():
             parameters.requires_grad = False
@@ -152,4 +148,3 @@ def unfreeze_projection(self):
         for parameters in self.projector.parameters():
             parameters.requires_grad = True
 
-
diff --git a/src/multimeditron/model/modalities/image_modality_biomed.py b/src/multimeditron/model/modalities/image_modality_biomed.py
new file mode 100644
index 0000000..222dc51
--- /dev/null
+++ b/src/multimeditron/model/modalities/image_modality_biomed.py
@@ -0,0 +1,135 @@
+from transformers import AutoConfig, AutoImageProcessor, AutoModel
+from multimeditron.model.modalities.base import BaseModalityConfig
+
+import torch
+from typing import Dict, Any
+from PIL import Image
+
+from multimeditron.model.constants import (
+    NUM_EMBEDDINGS_KEY,
+    MODALITY_VALUE_KEY,
+    POSITION_IDS_KEY,
+)
+from multimeditron.model.modalities.base import BaseModalityProcessor, BaseModality, AutoModality
+from multimeditron.model.projectors.mlp import MLPProjector
+
+
+class BioMedCLIPImageConfig(BaseModalityConfig):
+    """
+    Image modality config for OpenCLIP-based models (e.g. BiomedCLIP).
+    """
+
+    def __init__(
+        self,
+        hidden_size: int = 4096,
+        clip_name: str = "",
+        # Set to True to allow loading the checkpoint's remote code from the Hub
+        trust_remote_code: bool = False,
+        projection_type: str = "mlp",
+        **kwargs,
+    ):
+        super().__init__(
+            modality_type="image",
+            hidden_size=hidden_size,
+            kwargs=kwargs,
+        )
+
+        self.clip_name = clip_name
+        self.projection_type = projection_type
+        self.trust_remote_code = trust_remote_code
+
+
+class BioMedCLIPImageProcessor(BaseModalityProcessor):
+    """
+    Image processor using OpenCLIP preprocessing (BiomedCLIP-compatible).
+    """
+
+    def __init__(self, config: BioMedCLIPImageConfig):
+        super().__init__(config)
+        assert config.clip_name is not None
+
+        self.preprocess = AutoImageProcessor.from_pretrained(
+            config.clip_name,
+            trust_remote_code=config.trust_remote_code
+        )
+        feature_extractor_config = AutoConfig.from_pretrained(
+            config.clip_name,
+            trust_remote_code=config.trust_remote_code
+        )
+
+        vision_cfg = feature_extractor_config.vision_cfg
+        self._num_patches_per_entry = (vision_cfg["image_size"] // vision_cfg["patch_size"]) ** 2
+
+    def process(self, modality: Dict[str, Any]) -> Dict[str, Any]:
+        processed = modality.copy()
+        image: Image.Image = modality[MODALITY_VALUE_KEY]
+
+        pixel_values = self.preprocess(image)
+        processed[MODALITY_VALUE_KEY] = pixel_values
+        processed[NUM_EMBEDDINGS_KEY] = self._num_patches_per_entry
+
+        return processed
+
+
+@AutoModality.register("meditron_biomedclip")
+class BioMedCLIPImageModality(BaseModality):
+    """
+    Image modality backed by BiomedCLIP (OpenCLIP).
+    """
+
+    config_class = BioMedCLIPImageConfig
+    preprocessor_class = BioMedCLIPImageProcessor
+
+    def __init__(self, config: BioMedCLIPImageConfig):
+        super().__init__(config)
+
+        assert config.clip_name is not None
+
+        self.feature_extractor = AutoModel.from_pretrained(
+            config.clip_name,
+            trust_remote_code=config.trust_remote_code
+        )
+
+        remote_config = AutoConfig.from_pretrained(
+            config.clip_name,
+            trust_remote_code=config.trust_remote_code
+        )
+        self.embedding_size = remote_config.vision_cfg["width"]
+
+        self.projector = MLPProjector(
+            self.embedding_size,
+            config.hidden_size,
+            dtype=self.dtype,
+        )
+
+
+    def forward(self, inputs) -> torch.FloatTensor:
+        """
+        inputs: list of tensors, each of shape (3, 224, 224)
+        """
+        x = torch.stack(inputs, dim=0).to(self.device)
+
+        # OpenCLIP ViT intermediate output has shape (B, D, P, P), where D is the
+        # embedding dimension and P the patch-grid side length; flatten the grid
+        # and move the embedding dim last -> (B, P*P, D) before projecting.
+        res = self.feature_extractor.forward_intermediates(vision_inputs=x, normalize_intermediates=True)
+        features = res["image_intermediates"][-1]
+        features = features.flatten(start_dim=-2, end_dim=-1)
+        features = features.transpose(-1, -2)
+
+        projected = self.projector(features)
+
+        return projected
+
+    def freeze_modality_embedder(self):
+        for p in self.feature_extractor.parameters():
+            p.requires_grad = False
+
+    def unfreeze_modality_embedder(self):
+        for p in self.feature_extractor.parameters():
+            p.requires_grad = True
+
+    def unfreeze_projection(self):
+        for p in self.projector.parameters():
+            p.requires_grad = True
+
diff --git a/src/multimeditron/model/prompt_tokenizers.py b/src/multimeditron/model/prompt_tokenizers.py
index edc343b..18e7711 100644
--- a/src/multimeditron/model/prompt_tokenizers.py
+++ b/src/multimeditron/model/prompt_tokenizers.py
@@ -89,6 +89,7 @@ def tokenize_samples(
 
         tokenized = tokenized_conversations + tokenized_texts
 
+        processed_labels = torch.where(tokenized[0]["labels"] == IGNORE_TOKEN_INDEX, 0, tokenized[0]["labels"])
         padded_tokenized = self.pad_tokenized(tokenized)
 
         return self.update_with_token_range(padded_tokenized, samples)
@@ -168,7 +169,7 @@ def _tokenize_conversation(
             add_generation_prompt=add_generation_prompt,
             enable_thinking=False,
         )
-        
+
         input_ids, attention_mask = self.expand_attachment_input_tokens(
             token_ids=outputs["input_ids"].flatten(),
             attention_mask=outputs["attention_mask"].flatten(),
@@ -334,8 +335,8 @@
         assert len(attention_mask) == len(token_ids)
 
         # First, take all the text until the first modality (excluded)
-        expanded_token_ids = [token_ids[: modalities_indices[0] - 1]]
-        expanded_attention_mask = [attention_mask[: modalities_indices[0] - 1]]
+        expanded_token_ids = [token_ids[: modalities_indices[0]]]
+        expanded_attention_mask = [attention_mask[: modalities_indices[0]]]
 
         # Add the first modality
         num_embeddings = self.get_num_embeddings(modalities_for_message[0])