feat: token merging for image classification #537
rensortino wants to merge 4 commits into PrunaAI:main from
Conversation
Cursor Bugbot has reviewed your changes and found 4 potential issues.
```python
attn_weights = attn_weights + self._tome_info["size"].log()[:, None, None, :, 0]

if head_mask is not None:
    attn_weights = attn_weights + head_mask
```
head_mask applied additively instead of multiplicatively
Medium Severity
The head_mask is applied additively before softmax (attn_weights = attn_weights + head_mask), but the original HuggingFace ViTSelfAttention applies it multiplicatively after softmax (attention_probs = attention_probs * head_mask). Since head_mask is a binary tensor (0 to prune, 1 to keep), the additive pre-softmax application has essentially no masking effect — heads that are supposed to be pruned will still contribute to attention, silently producing incorrect results when head masking is used.
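A sketch of the conventional fix, mirroring the upstream `ViTSelfAttention` behavior (multiplicative masking after softmax); variable names follow the snippet above:

```python
# Keep the proportional-attention size bias before softmax ...
attn_weights = attn_weights + self._tome_info["size"].log()[:, None, None, :, 0]

attn_probs = attn_weights.softmax(dim=-1)

# ... but apply the binary head_mask multiplicatively after softmax, as
# upstream ViTSelfAttention does, so heads marked 0 contribute nothing.
if head_mask is not None:
    attn_probs = attn_probs * head_mask
```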
```python
pruna_logger.warning("Transformers library not found. Token merging will not be applied.")
return False

return any(isinstance(m, ViTLayer) for m in model.model.modules())
```
model_check_fn crashes for non-pipeline model inputs
Medium Severity
model_check_fn unconditionally accesses model.model.modules(), which assumes the input is always a pipeline. However, _apply conditionally unwraps pipelines via isinstance(model, ImageClassificationPipeline). If a raw ViTForImageClassification model is passed, model_check_fn will raise an AttributeError because non-pipeline models don't have a .model attribute. The check needs the same conditional unwrapping logic used in _apply.
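A sketch of the fix, reusing the unwrapping from `_apply` (import paths are assumptions based on the snippet above):

```python
from transformers import ImageClassificationPipeline
from transformers.models.vit.modeling_vit import ViTLayer

# Unwrap pipelines the same way _apply does before inspecting modules,
# so raw ViTForImageClassification inputs no longer raise AttributeError.
target = model.model if isinstance(model, ImageClassificationPipeline) else model
return any(isinstance(m, ViTLayer) for m in target.modules())
```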
Additional Locations (1)
```python
min_val = int(r * (1.0 - inflect))
max_val = 2 * r - min_val
step = (max_val - min_val) / (num_layers - 1)
```
Division by zero in _parse_r for single layer
Low Severity
_parse_r computes step = (max_val - min_val) / (num_layers - 1), which raises a ZeroDivisionError when num_layers is 1. Even when inflect is 0 (making the numerator 0), Python's 0 / 0 still raises this error. While unlikely for standard ViT models, any model with a single transformer layer would crash here.
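A sketch of a guard for the degenerate case (names follow the snippet above):

```python
min_val = int(r * (1.0 - inflect))
max_val = 2 * r - min_val
# Guard the single-layer case: with num_layers == 1 there is nothing to
# interpolate, so fall back to a zero step instead of dividing by zero.
step = (max_val - min_val) / (num_layers - 1) if num_layers > 1 else 0.0
```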
```python
output = model(self.input_image)
pred_labels = [model.config.id2label[p] for p in output[0].topk(5).indices[0].tolist()]
print("Output: ", pred_labels)
print("Original: ", self.original_pred)
```
Debug print statements left in test code
Low Severity
Several print() statements (lines 39–40, 45–46) appear to be leftover debugging output. No other tester in the tests/algorithms/testers directory uses print(), and they add noise to test output without serving a testing purpose. They should be removed or replaced with proper logging.
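A possible cleanup, assuming the tester has access to `pruna_logger` (otherwise the two lines can simply be dropped):

```python
output = model(self.input_image)
pred_labels = [model.config.id2label[p] for p in output[0].topk(5).indices[0].tolist()]
# Use debug-level logging instead of print so test output stays quiet by default.
pruna_logger.debug("Output: %s", pred_labels)
pruna_logger.debug("Original: %s", self.original_pred)
```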


Description
This PR introduces the Token Merging (ToMe) algorithm for HuggingFace Vision Transformer models. Token Merging progressively merges similar tokens between the attention and MLP stages of each transformer block, significantly reducing the number of tokens and speeding up inference with minimal quality loss.
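For intuition, here is a minimal sketch of one merge step (bipartite soft matching, as in the ToMe paper). It is a simplification rather than the code added in this PR: it merges by unweighted mean and does not protect the CLS token.

```python
import torch
import torch.nn.functional as F


def merge_tokens(x: torch.Tensor, r: int) -> torch.Tensor:
    """Merge the r most similar token pairs in x of shape (B, N, C); assumes 0 < r <= N // 2."""
    # Bipartite split: alternate tokens go to sets A and B.
    a, b = x[:, ::2, :], x[:, 1::2, :]

    # Cosine similarity between every token in A and every token in B.
    scores = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).transpose(-1, -2)

    # Best match in B for each token in A.
    node_max, node_idx = scores.max(dim=-1)

    # Pick the r A-tokens with the highest similarity to merge away.
    edge_idx = node_max.argsort(dim=-1, descending=True)
    src_idx, unm_idx = edge_idx[:, :r], edge_idx[:, r:]
    dst_idx = node_idx.gather(1, src_idx)

    c = x.shape[-1]
    unm = a.gather(1, unm_idx.unsqueeze(-1).expand(-1, -1, c))
    src = a.gather(1, src_idx.unsqueeze(-1).expand(-1, -1, c))

    # Merge each source token into its destination by (unweighted) mean;
    # the full algorithm weights by token size and protects the CLS token.
    b = b.scatter_reduce(1, dst_idx.unsqueeze(-1).expand(-1, -1, c), src,
                         reduce="mean", include_self=True)
    return torch.cat([unm, b], dim=1)
```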
Using model `google/vit-base-patch16-224`, speedup is over 2x with `r=8`.
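For context, a hypothetical usage sketch; the config keys (`token_merging`, `token_merging_r`, `token_merging_prop_attn`) are illustrative and may not match the names registered in this PR:

```python
from transformers import ViTForImageClassification

from pruna import SmashConfig, smash

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Illustrative config keys; the registered names may differ.
smash_config = SmashConfig()
smash_config["token_merging"] = True
smash_config["token_merging_r"] = 8
smash_config["token_merging_prop_attn"] = True

smashed_model = smash(model=model, smash_config=smash_config)
```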
Key Changes:

- Token Merging Algorithm: new ToMe modules (`ToMeViTLayer`, `ToMeViTSelfAttention`) that extend HuggingFace transformers
- Testing Infrastructure:
Related Issue
Fixes #399
Type of Change
How Has This Been Tested?
- `google/vit-base-patch16-224` model family

Implementation Details
Token Merging Core Features:

- Merge count `r` applied across all layers

Hyperparameters:

- `r` (int, 0-128): Number of tokens to merge per layer (default: 16)
- `trace_source` (bool): Track merge provenance for visualization
- `prop_attn` (bool): Enable proportional attention weighting (default: True)

Checklist
Additional Notes
Design Decisions:

- Module-level class definitions: `ToMeViTLayer` and related classes are defined at module level (not inside methods) to ensure they are picklable for distributed training and model serialization.
- Eager attention enforcement: the `ToMeViTSelfAttention` class uses eager attention computation to inject the proportional attention bias between the QK matmul and softmax operations (see the sketch after this list).
- Shared mutable state: all ToMe modules share a single `tome_info` dict for efficient state management across layers.
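A minimal sketch of the eager path and why it is needed (shapes and names are assumptions; fused SDPA kernels do not expose the pre-softmax logits):

```python
import torch


def eager_attention_with_size_bias(q, k, v, size, scale):
    # q, k, v: (B, heads, N, head_dim); size: (B, N, 1) counts of how many
    # original tokens each merged token represents.
    attn_weights = (q @ k.transpose(-1, -2)) * scale
    # Proportional attention: bias each key by log(size) before softmax,
    # which is only possible when the raw logits are materialized.
    attn_weights = attn_weights + size.log()[:, None, None, :, 0]
    return attn_weights.softmax(dim=-1) @ v
```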
Future Enhancements:
References: