@@ -0,0 +1,55 @@
base_llm: Qwen/Qwen3-4B-Instruct-2507
base_model: null
attachment_token: <|reserved_special_token_0|>
tokenizer_type: qwen3
token_size: 2560

loaders:
  - loader_type: raw-image
    modality_type: image

modalities:
  - model_type: meditron_biomedclip
    clip_name: michel-ducartier/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
    hidden_size: 2560
    trust_remote_code: true

training_mode: ALIGNMENT

datasets:
  # Non formatted datasets
  # ----------------------
  - packed_path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/llava_pretrain_cleaned
  - packed_path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/pixmo_anything
  - packed_path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/pixmo_cap
  - packed_path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/medtrinity_conversations_1_formatted_alignment/

training_args:
  output_dir: /capstor/store/cscs/swissai/a127/meditron/models/multimeditron/single_clip/two_phase/MultiMeditron-Qwen-4B-Alignment-Generalist-Delimiter
  dataloader_num_workers: 16 # > 0 not supported for IterableDataset, cf. https://github.com/huggingface/datasets/issues/5984
  dataloader_prefetch_factor: 4
  remove_unused_columns: false
  ddp_find_unused_parameters: false # Test to reduce memory
  learning_rate: 1.0e-4
  bf16: true
  per_device_train_batch_size: 4 # note that training_args.n_gpu and training_args.train_batch_size show faulty values
  # with deepspeed -> use deepspeed_plugin instead (besides training_args.distributed_state.num_processes == WORLD_SIZE)
  gradient_accumulation_steps: 8
  num_train_epochs: 1
  gradient_checkpointing: true
  gradient_checkpointing_kwargs:
    use_reentrant: true
  save_strategy: epoch
  # save_steps: 0.25
  max_grad_norm: 1.0
  run_name: MultiMeditron-Qwen-4B-Alignment-Generalist-Delimiter # Set the name for the run
  deepspeed: ./config/deepspeed.json
  accelerator_config:
    dispatch_batches: false
  lr_scheduler_type: "cosine_with_min_lr"
  lr_scheduler_kwargs:
    min_lr: 3.0e-5
  report_to: wandb
  logging_steps: 1
  weight_decay: 0.01
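A quick way to sanity-check this config and the global batch size it implies, as a minimal sketch (the file name config/alignment_biomedclip.yaml and the 4-GPU world size are assumptions, not part of this PR):

import yaml  # PyYAML

with open("config/alignment_biomedclip.yaml") as f:  # hypothetical path for the YAML above
    cfg = yaml.safe_load(f)

train = cfg["training_args"]
world_size = 4  # assumption: set to the actual WORLD_SIZE of the run
effective_batch = (
    train["per_device_train_batch_size"]    # 4
    * train["gradient_accumulation_steps"]  # 8
    * world_size
)
print(effective_batch)  # 4 * 8 * 4 = 128 samples per optimizer step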

4 changes: 4 additions & 0 deletions src/multimeditron/model/modalities/__init__.py
@@ -2,6 +2,7 @@
from multimeditron.model.modalities.image_modality_moe import MOEImageConfig, MOEImageModality, MOEImageProcessor
from multimeditron.model.modalities.image_modality_moe_pep import MOEImageConfigPEP, MOEImageModalityPEP, MOEImageProcessorPEP
from multimeditron.model.modalities.image_modality import ImageConfig, ImageModality, ImageProcessor
from multimeditron.model.modalities.image_modality_biomed import BioMedCLIPImageConfig, BioMedCLIPImageModality, BioMedCLIPImageProcessor

__all__ = [
"BaseModality",
@@ -17,4 +18,7 @@
"ImageConfig",
"ImageModality",
"ImageProcessor",
"BioMedCLIPImageConfig",
"BioMedCLIPImageModality",
"BioMedCLIPImageProcessor",
]
5 changes: 0 additions & 5 deletions src/multimeditron/model/modalities/image_modality.py
@@ -136,10 +136,6 @@ def forward(self, inputs) -> torch.FloatTensor:

        return projected

    @classmethod
    def from_dict(cls, config_args, **kwargs):
        return ImageConfig.from_dict(config_args, **kwargs)

    def freeze_modality_embedder(self):
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False
@@ -152,4 +148,3 @@ def unfreeze_projection(self):
        for parameters in self.projector.parameters():
            parameters.requires_grad = True


135 changes: 135 additions & 0 deletions src/multimeditron/model/modalities/image_modality_biomed.py
@@ -0,0 +1,135 @@
from transformers import AutoConfig, AutoImageProcessor, AutoModel
from multimeditron.model.modalities.base import BaseModalityConfig

import torch
from typing import Dict, Any
from PIL import Image

from multimeditron.model.constants import (
    NUM_EMBEDDINGS_KEY,
    MODALITY_VALUE_KEY,
    POSITION_IDS_KEY,
)
from multimeditron.model.modalities.base import BaseModalityProcessor, BaseModality, AutoModality
from multimeditron.model.projectors.mlp import MLPProjector


class BioMedCLIPImageConfig(BaseModalityConfig):
"""
Image modality config for OpenCLIP-based models (e.g. BiomedCLIP)
"""

def __init__(
self,
hidden_size: int = 4096,
clip_name: str = "",
# Set this to true (trust)
trust_remote_code: bool = False,
projection_type: str = "mlp",
**kwargs,
):
super().__init__(
modality_type="image",
hidden_size=hidden_size,
kwargs=kwargs,
)

self.clip_name = clip_name
self.projection_type = projection_type
self.trust_remote_code = trust_remote_code


class BioMedCLIPImageProcessor(BaseModalityProcessor):
"""
Image processor using OpenCLIP preprocessing (BiomedCLIP-compatible)
"""

def __init__(self, config: BioMedCLIPImageConfig):
super().__init__(config)
assert config.clip_name is not None

self.preprocess = AutoImageProcessor.from_pretrained(
config.clip_name,
trust_remote_code=config.trust_remote_code
)
feature_extractor_config = AutoConfig.from_pretrained(
config.clip_name,
trust_remote_code=config.trust_remote_code
)

vision_cfg = feature_extractor_config.vision_cfg
self._num_patches_per_entry = (vision_cfg["image_size"] // vision_cfg["patch_size"]) ** 2

def process(self, modality: Dict[str, Any]) -> Dict[str, Any]:
processed = modality.copy()
image: Image.Image = modality[MODALITY_VALUE_KEY]

pixel_values = self.preprocess(image)
processed[MODALITY_VALUE_KEY] = pixel_values
processed[NUM_EMBEDDINGS_KEY] = self._num_patches_per_entry

return processed


@AutoModality.register("meditron_biomedclip")
class BioMedCLIPImageModality(BaseModality):
"""
Image modality backed by BiomedCLIP (OpenCLIP).
"""

config_class = BioMedCLIPImageConfig
preprocessor_class = BioMedCLIPImageProcessor

def __init__(self, config: BioMedCLIPImageConfig):
super().__init__(config)

assert config.clip_name is not None

self.feature_extractor = AutoModel.from_pretrained(
config.clip_name,
trust_remote_code=config.trust_remote_code
)

remote_config = AutoConfig.from_pretrained(
config.clip_name,
trust_remote_code=config.trust_remote_code
)
self.embedding_size = remote_config.vision_cfg["width"]

self.projector = MLPProjector(
self.embedding_size,
config.hidden_size,
dtype=self.dtype,
)


    def forward(self, inputs) -> torch.FloatTensor:
        """
        inputs: list[Tensor], each of shape (3, 224, 224)
        """
        x = torch.stack(inputs, dim=0).to(self.device)

        # OpenCLIP ViT intermediate output: (B, D, P, P), where
        # D is the embedding dimension and P is the patch-grid side length
        res = self.feature_extractor.forward_intermediates(vision_inputs=x, normalize_intermediates=True)
        features = res["image_intermediates"][-1]
        features = features.flatten(start_dim=-2, end_dim=-1)  # (B, D, P*P)
        features = features.transpose(-1, -2)                  # (B, P*P, D)

        projected = self.projector(features)

        return projected

    def freeze_modality_embedder(self):
        for p in self.feature_extractor.parameters():
            p.requires_grad = False

    def unfreeze_modality_embedder(self):
        for p in self.feature_extractor.parameters():
            p.requires_grad = True

    def unfreeze_projection(self):
        for p in self.projector.parameters():
            p.requires_grad = True
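A standalone shape check for the flatten/transpose in forward above, as a sketch with a dummy tensor (B=2, D=768, P=14 are illustrative values for a ViT-B/16 at 224x224, not taken from this diff):

import torch

features = torch.randn(2, 768, 14, 14)                 # (B, D, P, P), as produced by the vision tower
features = features.flatten(start_dim=-2, end_dim=-1)  # (B, D, P*P)
features = features.transpose(-1, -2)                  # (B, P*P, D): one embedding per patch
assert features.shape == (2, 14 * 14, 768)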

7 changes: 4 additions & 3 deletions src/multimeditron/model/prompt_tokenizers.py
@@ -89,6 +89,7 @@ def tokenize_samples(

        tokenized = tokenized_conversations + tokenized_texts

        processed_labels = torch.where(tokenized[0]["labels"] == IGNORE_TOKEN_INDEX, 0, tokenized[0]["labels"])
        padded_tokenized = self.pad_tokenized(tokenized)

        return self.update_with_token_range(padded_tokenized, samples)
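For context on the new processed_labels line, torch.where here zeroes out positions whose label equals the ignore index. A standalone sketch (the value of IGNORE_TOKEN_INDEX is assumed to be -100, the usual Hugging Face convention; confirm against multimeditron.model.constants):

import torch

IGNORE_TOKEN_INDEX = -100  # assumption; see multimeditron.model.constants
labels = torch.tensor([12, IGNORE_TOKEN_INDEX, 7, IGNORE_TOKEN_INDEX])
processed_labels = torch.where(labels == IGNORE_TOKEN_INDEX, 0, labels)
print(processed_labels)  # tensor([12,  0,  7,  0])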
@@ -168,7 +169,7 @@ def _tokenize_conversation(
            add_generation_prompt=add_generation_prompt,
            enable_thinking=False,
        )

        input_ids, attention_mask = self.expand_attachment_input_tokens(
            token_ids=outputs["input_ids"].flatten(),
            attention_mask=outputs["attention_mask"].flatten(),
@@ -334,8 +335,8 @@ def expand_attachment_input_tokens(
        assert len(attention_mask) == len(token_ids)

        # First, take all the text until the first modality (excluded)
-       expanded_token_ids = [token_ids[: modalities_indices[0] - 1]]
-       expanded_attention_mask = [attention_mask[: modalities_indices[0] - 1]]
+       expanded_token_ids = [token_ids[: modalities_indices[0]]]
+       expanded_attention_mask = [attention_mask[: modalities_indices[0]]]

        # Add the first modality
        num_embeddings = self.get_num_embeddings(modalities_for_message[0])
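The slicing change in this hunk replaces [: idx - 1] with [: idx]: the new form keeps every text token before the attachment placeholder at index idx, while the old form also dropped the last text token. A minimal sketch (the token values and placeholder index are made up for illustration):

import torch

token_ids = torch.tensor([101, 42, 7, 999, 13])  # 999 stands in for the attachment token at index 3
idx = 3

old_prefix = token_ids[: idx - 1]  # tensor([101, 42])     -> loses the text token 7
new_prefix = token_ids[: idx]      # tensor([101, 42, 7])   -> keeps all text before the placeholder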