-
Notifications
You must be signed in to change notification settings - Fork 30
Add CLIP-JAX ADE20K implementation, training, and tests #83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Thakor-Yashpal
wants to merge
10
commits into
jax-ml:main
Choose a base branch
from
Thakor-Yashpal:clip-jax-ade20k
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
c4eb01a
Add CLIP-JAX ADE20K implementation, training, and tests
Thakor-Yashpal 5e34692
I am working on the new model, similar to the previous one, but remov…
Thakor-Yashpal d8dbeba
Add initial CLIP (ITACLIP) model implementation
Thakor-Yashpal 4fded20
Create README.md
Thakor-Yashpal f437948
Update README.md
Thakor-Yashpal a7a325b
Update README.md
Thakor-Yashpal 98556c2
x
Thakor-Yashpal 2cf8d6e
Delete bonsai/models/clip/tests/__init__.py
Thakor-Yashpal 97cc48d
changes
Thakor-Yashpal 2d280e9
Merge branch 'clip-jax-ade20k' of https://github.com/Thakor-Yashpal/b…
Thakor-Yashpal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # ITA-CLIP — CLIP-style model (JAX / Flax) | ||
|
|
||
| This directory contains a compact CLIP-like implementation (ITA-CLIP) in JAX/Flax, | ||
| intended for zero-shot image classification, prompt-guided heatmaps, and image-text embedding experiments. | ||
|
|
||
| ## Paper (reference) | ||
|
|
||
| - Radford et al., *Learning Transferable Visual Models From Natural Language Supervision* (OpenAI CLIP) | ||
| Local copy used during development: `/mnt/data/2103.00020v1.pdf` | ||
|
|
||
| --- | ||
|
|
||
| ## Tested on | ||
|
|
||
| | Model Name | Config | CPU | GPU (single) | GPU (multi) | TPU | | ||
| | :--- | :---: | :---: | :---: | :---: | :---: | | ||
| | ITA-CLIP (TinyViT + TinyText) | ✅ Compact research config | ✅ Runs (CPU) | ❔ Needs check (CUDA JAX) | ❔ Needs check | ❔ Needs check | | ||
|
|
||
| > Notes: This implementation uses a compact TinyViT and small text-transformer to make local testing and CI-friendly smoke tests possible. For large-scale ViT-B/32 or ViT-L/14 variants, add config presets and provide pretrained weights. | ||
|
|
||
| --- | ||
|
|
||
| ### Running this model (quick smoke test) | ||
|
|
||
| Run a forward pass / smoke test: | ||
|
|
||
| ```bash | ||
| python3 -m bonsai.models.clip.tests.run_model |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from .modeling import CLIPModel, clip_contrastive_loss | ||
| from .params import CLIPConfig | ||
| from .tokenizer import load_tokenizer, simple_whitespace_tokenizer | ||
|
|
||
| __all__ = ["CLIPModel", "clip_contrastive_loss", "CLIPConfig", "load_tokenizer", "simple_whitespace_tokenizer"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| from typing import Any | ||
| import jax | ||
| import jax.numpy as jnp | ||
| import flax.linen as nn | ||
| from flax.linen import initializers | ||
| from .params import CLIPConfig | ||
|
|
||
| def _get_dtype(cfg: CLIPConfig): | ||
| return jnp.float32 if cfg.dtype == "float32" else jnp.float16 | ||
|
|
||
| class MLPBlock(nn.Module): | ||
| mlp_dim: int | ||
| out_dim: int | ||
| act = nn.gelu | ||
| dtype = jnp.float32 | ||
|
|
||
| @nn.compact | ||
| def __call__(self, x): | ||
| x = nn.Dense(self.mlp_dim, dtype=self.dtype)(x) | ||
| x = self.act(x) | ||
| x = nn.Dense(self.out_dim, dtype=self.dtype)(x) | ||
| return x | ||
|
|
||
| class AddPositionEmbs(nn.Module): | ||
| max_len: int | ||
| emb_dim: int | ||
| dtype = jnp.float32 | ||
|
|
||
| def setup(self): | ||
| self.pos_emb = self.param("pos_emb", initializers.normal(0.02), (1, self.max_len, self.emb_dim)) | ||
|
|
||
| def __call__(self, x): | ||
| return x + self.pos_emb | ||
|
|
||
| class TransformerEncoderBlock(nn.Module): | ||
| num_heads: int | ||
| mlp_dim: int | ||
| dtype = jnp.float32 | ||
|
|
||
| @nn.compact | ||
| def __call__(self, x, deterministic=True): | ||
| y = nn.LayerNorm(dtype=self.dtype)(x) | ||
| y = nn.SelfAttention(num_heads=self.num_heads, dtype=self.dtype, deterministic=deterministic)(y) | ||
| x = x + y | ||
| y = nn.LayerNorm(dtype=self.dtype)(x) | ||
| y = MLPBlock(self.mlp_dim, x.shape[-1], dtype=self.dtype)(y) | ||
| return x + y | ||
|
|
||
| class SimplePatchEmbed(nn.Module): | ||
| patch_size: int | ||
| emb_dim: int | ||
| dtype = jnp.float32 | ||
|
|
||
| @nn.compact | ||
| def __call__(self, x): | ||
| ps = self.patch_size | ||
| x = nn.Conv(self.emb_dim, (ps,ps), strides=(ps,ps), padding='VALID', dtype=self.dtype)(x) | ||
| b,h,w,c = x.shape | ||
| return jnp.reshape(x, (b, h*w, c)) | ||
|
|
||
| class ImageEncoderViT(nn.Module): | ||
| cfg: CLIPConfig | ||
| dtype = jnp.float32 | ||
|
|
||
| @nn.compact | ||
| def __call__(self, images, deterministic=True): | ||
| cfg = self.cfg | ||
| x = SimplePatchEmbed(cfg.patch_size, cfg.image_embed_dim, dtype=self.dtype)(images) | ||
| cls = self.param('cls', initializers.zeros, (1,1,cfg.image_embed_dim)) | ||
| cls_b = jnp.tile(cls, (x.shape[0],1,1)) | ||
| x = jnp.concatenate([cls_b, x], axis=1) | ||
| x = AddPositionEmbs(x.shape[1], cfg.image_embed_dim, dtype=self.dtype)(x) | ||
| for _ in range(cfg.vit_num_layers): | ||
| x = TransformerEncoderBlock(cfg.vit_num_heads, cfg.vit_mlp_dim, dtype=self.dtype)(x, deterministic=deterministic) | ||
| cls_out = x[:,0] | ||
| cls_out = nn.LayerNorm(dtype=self.dtype)(cls_out) | ||
| img_feat = nn.Dense(cfg.image_embed_dim, dtype=self.dtype)(cls_out) | ||
| return img_feat | ||
|
|
||
| # small ResNet-like encoder (kept light) | ||
| class ResNetStem(nn.Module): | ||
| out_ch: int | ||
| dtype = jnp.float32 | ||
|
|
||
| @nn.compact | ||
| def __call__(self, x): | ||
| x = nn.Conv(self.out_ch, (7,7), strides=(2,2), padding='SAME', use_bias=False, dtype=self.dtype)(x) | ||
| x = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(x) | ||
| x = nn.relu(x) | ||
| x = nn.max_pool(x, (3,3), strides=(2,2), padding='SAME') | ||
| return x | ||
|
|
||
| class ResidualBlock(nn.Module): | ||
| out_ch: int | ||
| strides: tuple = (1,1) | ||
| dtype = jnp.float32 | ||
|
|
||
| @nn.compact | ||
| def __call__(self, x): | ||
| residual = x | ||
| y = nn.Conv(self.out_ch, (3,3), strides=self.strides, padding='SAME', use_bias=False, dtype=self.dtype)(x) | ||
| y = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(y) | ||
| y = nn.relu(y) | ||
| y = nn.Conv(self.out_ch, (3,3), padding='SAME', use_bias=False, dtype=self.dtype)(y) | ||
| y = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(y) | ||
| if residual.shape[-1] != self.out_ch or self.strides != (1,1): | ||
| residual = nn.Conv(self.out_ch, (1,1), strides=self.strides, padding='SAME', use_bias=False, dtype=self.dtype)(residual) | ||
| residual = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(residual) | ||
| return nn.relu(residual + y) | ||
|
|
||
| class ImageEncoderResNet(nn.Module): | ||
| cfg: CLIPConfig | ||
| dtype = jnp.float32 | ||
|
|
||
| @nn.compact | ||
| def __call__(self, images, deterministic=True): | ||
| cfg = self.cfg | ||
| x = ResNetStem(cfg.resnet_stem_channels, dtype=self.dtype)(images) | ||
| for ch, repeats in zip(cfg.resnet_block_channels, cfg.resnet_block_repeats): | ||
| for i in range(repeats): | ||
| strides = (2,2) if i == 0 else (1,1) | ||
| x = ResidualBlock(ch, strides=strides, dtype=self.dtype)(x) | ||
| x = x.mean(axis=(1,2)) | ||
| x = nn.LayerNorm(dtype=self.dtype)(x) | ||
| img_feat = nn.Dense(cfg.image_embed_dim, dtype=self.dtype)(x) | ||
| return img_feat | ||
|
|
||
| class TextEncoder(nn.Module): | ||
| cfg: CLIPConfig | ||
| dtype = jnp.float32 | ||
|
|
||
| @nn.compact | ||
| def __call__(self, token_ids, deterministic=True): | ||
| cfg = self.cfg | ||
| tok_emb = nn.Embed(num_embeddings=cfg.text_vocab_size, features=cfg.text_embed_dim, dtype=self.dtype)(token_ids) | ||
| tok_emb = AddPositionEmbs(tok_emb.shape[1], cfg.text_embed_dim, dtype=self.dtype)(tok_emb) | ||
| x = tok_emb | ||
| for _ in range(cfg.text_num_layers): | ||
| x = TransformerEncoderBlock(cfg.text_num_heads, cfg.text_mlp_dim, dtype=self.dtype)(x, deterministic=deterministic) | ||
| eos_feat = x[:, -1, :] | ||
| eos_feat = nn.LayerNorm(dtype=self.dtype)(eos_feat) | ||
| txt_feat = nn.Dense(cfg.text_embed_dim, dtype=self.dtype)(eos_feat) | ||
| return txt_feat | ||
|
|
||
| class CLIPModel(nn.Module): | ||
| cfg: CLIPConfig | ||
| dtype = jnp.float32 | ||
|
|
||
| def setup(self): | ||
| self.cfg.apply_model_size_presets() | ||
| self._dtype = _get_dtype(self.cfg) | ||
| if self.cfg.encoder_type == 'vit': | ||
| self.image_encoder = ImageEncoderViT(self.cfg, dtype=self._dtype) | ||
| else: | ||
| self.image_encoder = ImageEncoderResNet(self.cfg, dtype=self._dtype) | ||
| self.text_encoder = TextEncoder(self.cfg, dtype=self._dtype) | ||
| self.img_proj = nn.Dense(self.cfg.proj_dim, dtype=self._dtype, use_bias=False) | ||
| self.txt_proj = nn.Dense(self.cfg.proj_dim, dtype=self._dtype, use_bias=False) | ||
| self.logit_scale = self.param('logit_scale', lambda rng, shape: jnp.array(1.0), ()) | ||
|
|
||
| def encode_image(self, images, deterministic=True): | ||
| feats = self.image_encoder(images, deterministic=deterministic) | ||
| proj = self.img_proj(feats) | ||
| proj = proj / (jnp.linalg.norm(proj, axis=-1, keepdims=True) + 1e-10) | ||
| return proj | ||
|
|
||
| def encode_text(self, token_ids, deterministic=True): | ||
| feats = self.text_encoder(token_ids, deterministic=deterministic) | ||
| proj = self.txt_proj(feats) | ||
| proj = proj / (jnp.linalg.norm(proj, axis=-1, keepdims=True) + 1e-10) | ||
| return proj | ||
|
|
||
| def __call__(self, images, token_ids, deterministic=True): | ||
| i_e = self.encode_image(images, deterministic=deterministic) | ||
| t_e = self.encode_text(token_ids, deterministic=deterministic) | ||
| scale = jnp.exp(self.logit_scale) | ||
| logits = jnp.matmul(i_e, t_e.T) * scale | ||
| return logits, i_e, t_e, scale | ||
|
|
||
| def clip_contrastive_loss(logits: jnp.ndarray): | ||
| n = logits.shape[0] | ||
| labels = jnp.arange(n) | ||
| loss_i = jnp.mean(nn.softmax_cross_entropy(logits=logits, labels=jax.nn.one_hot(labels, n), axis=1)) | ||
| loss_t = jnp.mean(nn.softmax_cross_entropy(logits=logits.T, labels=jax.nn.one_hot(labels, n), axis=1)) | ||
| return 0.5 * (loss_i + loss_t) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| from dataclasses import dataclass | ||
| from typing import Literal | ||
|
|
||
| @dataclass | ||
| class CLIPConfig: | ||
| image_size: int = 224 | ||
| encoder_type: Literal["vit", "resnet"] = "vit" | ||
| model_size: Literal["ViT-B/32", "ViT-L/14"] = "ViT-B/32" | ||
| dtype: str = "float32" | ||
|
|
||
| patch_size: int = 32 | ||
| image_embed_dim: int = 768 | ||
| vit_num_layers: int = 12 | ||
| vit_num_heads: int = 12 | ||
| vit_mlp_dim: int = 3072 | ||
|
|
||
| resnet_stem_channels: int = 64 | ||
| resnet_block_channels: tuple = (64, 128, 256, 512) | ||
| resnet_block_repeats: tuple = (3, 4, 6, 3) | ||
|
|
||
| # text encoder | ||
| text_embed_dim: int = 512 | ||
| text_vocab_size: int = 49408 | ||
| text_max_length: int = 77 | ||
| text_num_layers: int = 12 | ||
| text_num_heads: int = 8 | ||
| text_mlp_dim: int = 2048 | ||
|
|
||
| proj_dim: int = 512 | ||
|
|
||
| def apply_model_size_presets(self): | ||
| if self.model_size == "ViT-B/32": | ||
| self.patch_size = 32 | ||
| self.image_embed_dim = 768 | ||
| self.vit_num_layers = 12 | ||
| self.vit_num_heads = 12 | ||
| self.vit_mlp_dim = 3072 | ||
| self.text_embed_dim = 512 | ||
| self.proj_dim = 512 | ||
| elif self.model_size == "ViT-L/14": | ||
| self.patch_size = 14 | ||
| self.image_embed_dim = 1024 | ||
| self.vit_num_layers = 24 | ||
| self.vit_num_heads = 16 | ||
| self.vit_mlp_dim = 4096 | ||
| self.text_embed_dim = 1024 | ||
| self.proj_dim = 1024 | ||
| else: | ||
| raise ValueError("Unknown model_size: " + str(self.model_size)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For users interested in inference, could you add functionality to transfer parameters from a pretrained model? |
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have the config in this file for consistency with the rest of the repo?