
feat: token merging for image classification #537

Open
rensortino wants to merge 4 commits into PrunaAI:main from rensortino:feat/token-merging

Conversation

@rensortino

Description

This PR introduces the Token Merging (ToMe) algorithm for HuggingFace Vision Transformer models. Token Merging progressively merges similar tokens between the attention and MLP stages of each transformer block, significantly reducing the number of tokens and speeding up inference with minimal quality loss.

With google/vit-base-patch16-224, the measured speedup is over 2x at r=8.

Key Changes:

Token Merging Algorithm:

  • Implements the ToMe algorithm, adapted from the facebook/ToMe paper
  • Custom ViT module classes (ToMeViTLayer, ToMeViTSelfAttention) that extend HuggingFace transformers
  • Supports proportional attention weighting based on merged token sizes
  • Bipartite soft matching for intelligent token pair selection
  • Configurable token reduction schedule with per-layer control
  • Model wrapper for state management across forward passes

Testing Infrastructure:

  • Added ViT model fixtures for comprehensive testing
  • Token Merging test class with validation scenarios

Related Issue

Fixes #399

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

  • Token Merging algorithm tested with HuggingFace ViT models
  • Test fixtures added for google/vit-base-patch16-224 model family
  • Integration tests verify proper token reduction and attention output handling
  • Validated compatibility with existing Pruna pipeline

Implementation Details

Token Merging Core Features:

  1. Bipartite Soft Matching: Intelligently selects which token pairs to merge based on key similarity (see the sketch after this list)
  2. Proportional Attention: Adjusts attention weights by the log of merged token sizes
  3. Configurable Reduction Schedule:
    • Constant r across all layers
    • Per-layer list specification
    • Inflection-based schedules (increasing/decreasing/constant)
  4. Class Swapping Pattern: Dynamically replaces HF module classes at runtime to inject ToMe behavior
  5. Metric Storage: Uses the mean of the attention keys as the similarity metric for matching
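
A minimal sketch of the bipartite soft matching step, following the public facebook/ToMe reference implementation; tensor names are illustrative, and the class-token protection a full implementation needs is omitted for brevity:

import torch

def bipartite_soft_matching(metric: torch.Tensor, r: int):
    # metric: [B, N, C], e.g. the mean of the attention keys per token
    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[:, ::2, :], metric[:, 1::2, :]   # alternating bipartite split
        scores = a @ b.transpose(-1, -2)               # cosine similarity A -> B

        node_max, node_idx = scores.max(dim=-1)        # best B match per A token
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        src_idx = edge_idx[:, :r, :]                   # r most similar A tokens: merged
        unm_idx = edge_idx[:, r:, :]                   # remaining A tokens: kept
        dst_idx = node_idx[..., None].gather(1, src_idx)

    def merge(x: torch.Tensor) -> torch.Tensor:
        src, dst = x[:, ::2, :], x[:, 1::2, :]
        n, c = src.shape[1], src.shape[2]
        unm = src.gather(1, unm_idx.expand(-1, n - r, c))
        src = src.gather(1, src_idx.expand(-1, r, c))
        # fold each merged A token into its matched B token by averaging
        dst = dst.scatter_reduce(1, dst_idx.expand(-1, r, c), src, reduce="mean")
        return torch.cat([unm, dst], dim=1)            # N - r tokens remain

    return merge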

Hyperparameters (usage sketch below):

  • r (int, 0-128): Number of tokens to merge per layer (default: 16)
  • trace_source (bool): Track merge provenance for visualization
  • prop_attn (bool): Enable proportional attention weighting (default: True)
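
A hypothetical usage sketch; the SmashConfig keys below ("token_merging", "token_merging_r", "token_merging_prop_attn") are assumptions based on the hyperparameter names above and are not confirmed by the PR:

from transformers import ViTForImageClassification
from pruna import SmashConfig, smash

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

smash_config = SmashConfig()
smash_config["token_merging"] = True            # assumed registration key
smash_config["token_merging_r"] = 8             # tokens merged per layer
smash_config["token_merging_prop_attn"] = True  # proportional attention on

smashed_model = smash(model=model, smash_config=smash_config)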

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

Design Decisions:

  1. Module-level class definitions: ToMeViTLayer and related classes are defined at module level (not inside methods) to ensure they are picklable for distributed training and model serialization.

  2. Eager attention enforcement: The ToMeViTSelfAttention class uses eager attention computation to inject the proportional attention bias between the QK matmul and softmax operations (sketched after this list).

  3. Shared mutable state: All ToMe modules share a single tome_info dict for efficient state management across layers.
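
A minimal sketch of decision 2, assuming a shared size tensor of shape [B, N, 1] (held in the tome_info dict) that tracks how many original tokens each merged token represents; names are illustrative rather than the PR's exact code:

import torch

def eager_attention_with_size_bias(q, k, size):
    # q, k: [B, heads, N, head_dim]; size: [B, N, 1]
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale       # QK matmul
    # Proportional attention: bias each key column by log(size), so a token
    # standing for s original tokens attracts attention as if s softmax
    # entries were still present.
    attn = attn + size.log()[:, None, None, :, 0]  # broadcast as [B, 1, 1, N]
    return attn.softmax(dim=-1)                    # softmax after the bias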

Future Enhancements:

  • Extension to other transformer architectures (Flux, SAM, etc.)
  • Support for custom attention mechanisms

References:

  • Bolya et al., "Token Merging: Your ViT But Faster", ICLR 2023, https://arxiv.org/abs/2210.09461
  • Reference implementation: https://github.com/facebookresearch/ToMe

@cursor (cursor bot) left a comment


Cursor Bugbot has reviewed your changes and found 4 potential issues.


attn_weights = attn_weights + self._tome_info["size"].log()[:, None, None, :, 0]

if head_mask is not None:
    attn_weights = attn_weights + head_mask

head_mask applied additively instead of multiplicatively

Medium Severity

The head_mask is applied additively before softmax (attn_weights = attn_weights + head_mask), but the original HuggingFace ViTSelfAttention applies it multiplicatively after softmax (attention_probs = attention_probs * head_mask). Since head_mask is a binary tensor (0 to prune, 1 to keep), the additive pre-softmax application has essentially no masking effect — heads that are supposed to be pruned will still contribute to attention, silently producing incorrect results when head masking is used.
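
A sketch of the suggested fix: keep the ToMe log-size bias before the softmax, but restore HuggingFace's multiplicative post-softmax head_mask (variable names follow the diff above):

attn_weights = attn_weights + self._tome_info["size"].log()[:, None, None, :, 0]

attention_probs = attn_weights.softmax(dim=-1)

if head_mask is not None:
    # head_mask is binary (0 = prune, 1 = keep): multiplying after softmax
    # zeroes out pruned heads instead of merely shifting their logits
    attention_probs = attention_probs * head_mask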


pruna_logger.warning("Transformers library not found. Token merging will not be applied.")
return False

return any(isinstance(m, ViTLayer) for m in model.model.modules())

model_check_fn crashes for non-pipeline model inputs

Medium Severity

model_check_fn unconditionally accesses model.model.modules(), which assumes the input is always a pipeline. However, _apply conditionally unwraps pipelines via isinstance(model, ImageClassificationPipeline). If a raw ViTForImageClassification model is passed, model_check_fn will raise an AttributeError because non-pipeline models don't have a .model attribute. The check needs the same conditional unwrapping logic used in _apply.
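
A sketch of the suggested guard, mirroring the conditional unwrapping used in _apply; the surrounding function shape is an assumption reconstructed from the snippet above:

from transformers import ImageClassificationPipeline
from transformers.models.vit.modeling_vit import ViTLayer

def model_check_fn(model) -> bool:
    # unwrap only when the input actually is a pipeline
    inner = model.model if isinstance(model, ImageClassificationPipeline) else model
    return any(isinstance(m, ViTLayer) for m in inner.modules())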



min_val = int(r * (1.0 - inflect))
max_val = 2 * r - min_val
step = (max_val - min_val) / (num_layers - 1)

Division by zero in _parse_r for single layer

Low Severity

_parse_r computes step = (max_val - min_val) / (num_layers - 1), which raises a ZeroDivisionError when num_layers is 1. Even when inflect is 0 (making the numerator 0), Python's 0 / 0 still raises this error. While unlikely for standard ViT models, any model with a single transformer layer would crash here.
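
A sketch of the guard, using the variable names from the snippet above; the enclosing function signature is assumed:

def parse_r_schedule(num_layers: int, r: int, inflect: float = 0.0) -> list[int]:
    min_val = int(r * (1.0 - inflect))
    max_val = 2 * r - min_val
    if num_layers == 1:
        return [r]  # nothing to interpolate across a single layer
    step = (max_val - min_val) / (num_layers - 1)
    return [int(min_val + i * step) for i in range(num_layers)]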


output = model(self.input_image)
pred_labels = [model.config.id2label[p] for p in output[0].topk(5).indices[0].tolist()]
print("Output: ", pred_labels)
print("Original: ", self.original_pred)

Debug print statements left in test code

Low Severity

Several print() statements (lines 39–40, 45–46) appear to be leftover debugging output. No other tester in the tests/algorithms/testers directory uses print(), and these add noise to test output without serving a testing purpose. They should be removed or replaced with proper logging.


@sdiazlor requested a review from llcnt on February 17, 2026 at 14:38