MMCoT-Prune: Efficient Multimodal Reasoning through Cross-Modal Attention-Guided Chain-of-Thought Pruning
TL;DR: We introduce the first framework for reasoning-level pruning in multimodal chain-of-thought, achieving 22-28% reduction in reasoning steps with minimal accuracy degradation through cross-modal attention analysis.
- [2026-02] Paper submitted to NeurIPS 2026
- [2026-02] Code and models released
- [2026-02] Initial project release
Multimodal Chain-of-Thought (MCoT) reasoning has shown remarkable capabilities but generates lengthy reasoning chains with substantial computational overhead. MMCoT-Prune addresses this challenge by:
- Analyzing cross-modal attention to identify steps with weak visual grounding
- Detecting semantic redundancy among reasoning steps
- Dynamically adapting pruning aggressiveness to task difficulty
- Achieving 22-28% efficiency gains with competitive accuracy
| Dataset | Baseline Acc. | MMCoT-Prune Acc. | Pruning Rate | Token Savings |
|---|---|---|---|---|
| ScienceQA | 86.4% | 84.9% (-1.5%) | 28.1% | 46 tok/q |
| A-OKVQA | 68.3% | 64.7% (-3.6%) | 23.1% | 42 tok/q |
| VSR | 78.1% | 77.1% (-1.0%) | 25.9% | 44 tok/q |
Pipeline overview: Image + Question → Multimodal CoT Generation → Cross-Modal Attention Analysis → Redundancy Detection → Dynamic Pruning Strategy → Pruned Reasoning Chain.
# Clone repository
git clone https://github.com/anonymous/mmcot-prune.git
cd mmcot-prune
# Create environment
conda create -n mmcot-prune python=3.9
conda activate mmcot-prune
# Install dependencies
pip install -r requirements.txt

from src.mmcot_pruning import MMCoTPruner, ReasoningStep
import numpy as np
# Initialize pruner
pruner = MMCoTPruner(
entropy_threshold=2.5,
similarity_threshold=0.85,
base_prune_threshold=0.4,
min_steps=3
)
# Create reasoning steps (with attention weights and embeddings)
reasoning_chain = [
ReasoningStep(
text="Step 1: Analyze the image...",
step_id=0,
attention_weights=np.array([...]), # Cross-modal attention
semantic_embedding=np.array([...]) # Text embedding
),
# ... more steps
]
# Prune the chain
pruned_chain, decisions = pruner.prune_reasoning_chain(
reasoning_chain,
task_difficulty=0.5
)
# Compute efficiency metrics
metrics = pruner.compute_efficiency_metrics(reasoning_chain, pruned_chain)
print(f"Pruning rate: {metrics['pruning_rate']:.1%}")
print(f"Token reduction: {metrics['token_reduction']}")# Run benchmark evaluation
cd experiments
python benchmark_evaluation.py
# Generate figures
python generate_figures.py
# Results saved to experiments/benchmark_results.json

- Python 3.9+
- PyTorch 2.1.0+
- CUDA 11.8+ (for GPU acceleration)
- 4× NVIDIA A100 GPUs (40GB) for full experiments
- Note: Can run on smaller setups with reduced batch sizes
- Download Datasets
# ScienceQA
wget https://scienceqa.github.io/data/scienceqa.zip
unzip scienceqa.zip -d data/
# A-OKVQA
wget https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz
tar -xzf aokvqa_v1p0.tar.gz -C data/
# VSR
git clone https://github.com/cambridgeltl/visual-spatial-reasoning data/vsr

- Preprocess Data
python scripts/preprocess_datasets.py --dataset all

- Run Experiments
# Main experiments (Table 1 in paper)
python experiments/run_main_experiments.py
# Ablation studies (Table 2-3 in paper)
python experiments/run_ablations.py
# Generate all figures
python experiments/generate_figures.py

- Analyze Results
python scripts/analyze_results.py --results_dir experiments/results

- Main experiments: ~48 hours on 4× A100 GPUs
- Ablation studies: ~24 hours on 4× A100 GPUs
- Total: ~72 hours of wall-clock time on 4× A100 GPUs
We compute attention entropy to measure visual grounding:
H(a_i) = -Σ_j a_ij log(a_ij)

Low entropy → focused attention → important step. High entropy → diffuse attention → potentially redundant.
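As a concrete illustration, here is a minimal NumPy sketch of this entropy computation; it assumes each step's cross-modal attention weights are renormalized into a probability distribution, and the function name is ours, not part of the released API:

```python
import numpy as np

def attention_entropy(attn: np.ndarray) -> float:
    """H(a_i) = -sum_j a_ij * log(a_ij) over a step's cross-modal attention."""
    a = attn / attn.sum()          # renormalize to a probability distribution (assumption)
    a = np.clip(a, 1e-12, None)    # guard against log(0)
    return float(-(a * np.log(a)).sum())

# Low entropy: attention concentrated on a few visual regions (well grounded, keep).
# High entropy: attention spread thinly across the image (weak grounding, prune candidate).
```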
The visual grounding score combines entropy and peak attention:

g(r_i) = 0.6 / (1 + H(a_i)) + 0.4 * max(a_i)

The pruning score integrates visual grounding and semantic similarity:

ρ(r_i) = 0.5 * (1 - g(r_i)) + 0.5 * max_{j≠i} sim(r_i, r_j)

The pruning threshold adapts to task difficulty:

τ(d) = τ_base + 0.3 * d

Prune step r_i if ρ(r_i) > τ(d), provided the pruned chain still retains at least k_min steps.
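Putting the scoring rules together, the sketch below (reusing `attention_entropy` from above) shows how the grounding score, pruning score, dynamic threshold, and pruning decision could be computed; the helper names and the cosine-similarity choice for `sim` are assumptions for illustration, with the constants taken from the formulas above, so the released code may differ in detail:

```python
import numpy as np

def grounding_score(attn: np.ndarray) -> float:
    """g(r_i) = 0.6 / (1 + H(a_i)) + 0.4 * max(a_i)."""
    return 0.6 / (1.0 + attention_entropy(attn)) + 0.4 * float(attn.max())

def pruning_score(attn_i: np.ndarray, emb_i: np.ndarray, other_embs) -> float:
    """rho(r_i) = 0.5 * (1 - g(r_i)) + 0.5 * max_{j != i} sim(r_i, r_j)."""
    sims = [
        float(emb_i @ e / (np.linalg.norm(emb_i) * np.linalg.norm(e)))
        for e in other_embs  # embeddings of the *other* reasoning steps
    ]
    redundancy = max(sims) if sims else 0.0
    return 0.5 * (1.0 - grounding_score(attn_i)) + 0.5 * redundancy

def dynamic_threshold(difficulty: float, tau_base: float = 0.4) -> float:
    """tau(d) = tau_base + 0.3 * d; harder tasks get a higher bar, so less pruning."""
    return tau_base + 0.3 * difficulty

def should_prune(score: float, difficulty: float, n_kept: int, k_min: int = 3) -> bool:
    """Prune only if the score clears the threshold and at least k_min steps remain."""
    return score > dynamic_threshold(difficulty) and (n_kept - 1) >= k_min
```

Iterating `should_prune` over a chain's steps mirrors, under these assumptions, the behavior that `MMCoTPruner.prune_reasoning_chain` exposes in the Quick Start.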
MMCoT-Prune achieves substantial efficiency gains across all benchmarks:
- ScienceQA: 28.1% pruning, 84.9% accuracy (vs 86.4% baseline)
- A-OKVQA: 23.1% pruning, 64.7% accuracy (vs 68.3% baseline)
- VSR: 25.9% pruning, 77.1% accuracy (vs 78.1% baseline)
Key findings:
- Combining attention and similarity outperforms either alone
- Dynamic thresholding improves accuracy by 1.8% vs static
- Conservative settings (τ = 0.6) achieve the best accuracy-efficiency balance for critical applications
mmcot-prune/
├── src/
│   ├── mmcot_pruning.py          # Core pruning algorithm
│   ├── attention_analysis.py     # Cross-modal attention extraction
│   ├── models.py                 # VLM wrappers
│   └── utils.py                  # Helper functions
├── experiments/
│   ├── benchmark_evaluation.py   # Main experiments
│   ├── ablation_studies.py       # Ablation studies
│   ├── generate_figures.py       # Visualization
│   └── benchmark_results.json    # Results
├── data/                         # Datasets (download separately)
├── figures/                      # Generated figures
├── paper/
│   ├── mmcot_prune.tex           # Main paper
│   ├── supplementary.tex         # Supplementary material
│   └── cover_letter.md           # Cover letter
├── scripts/
│   ├── preprocess_datasets.py    # Data preprocessing
│   └── analyze_results.py        # Result analysis
├── requirements.txt              # Python dependencies
└── README.md                     # This file
Key hyperparameters can be configured in config.yaml:
pruning:
  base_threshold: 0.4          # Base pruning threshold
  similarity_threshold: 0.85   # Semantic similarity threshold
  entropy_threshold: 2.5       # Attention entropy threshold
  min_steps: 3                 # Minimum retained steps
  max_prune_ratio: 0.6         # Maximum pruning fraction
grounding:
  entropy_weight: 0.6          # Weight for entropy component
  peak_attention_weight: 0.4   # Weight for peak attention
difficulty:
  visual_weight: 0.4           # Weight for visual complexity
  text_weight: 0.3             # Weight for text complexity
  object_weight: 0.3           # Weight for object count
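For reference, a configuration like the one above could be loaded and mapped onto the constructor shown in the Quick Start roughly as follows; the key-to-argument mapping is our assumption for illustration, so check the released code for the exact interface:

```python
# Hypothetical sketch: build the pruner from config.yaml.
import yaml

from src.mmcot_pruning import MMCoTPruner

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

pruner = MMCoTPruner(
    entropy_threshold=cfg["pruning"]["entropy_threshold"],
    similarity_threshold=cfg["pruning"]["similarity_threshold"],
    base_prune_threshold=cfg["pruning"]["base_threshold"],  # assumed mapping to base_threshold
    min_steps=cfg["pruning"]["min_steps"],
)
```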
If you find this work useful, please cite:

@article{mmcotprune2026,
  title={MMCoT-Prune: Efficient Multimodal Reasoning through Cross-Modal Attention-Guided Chain-of-Thought Pruning},
  author={Anonymous Authors},
  journal={arXiv preprint arXiv:2026.XXXXX},
  year={2026}
}

We welcome contributions! Please see CONTRIBUTING.md for guidelines.
This project is licensed under the MIT License - see LICENSE for details.
This work builds upon:
- Multimodal-CoT for baseline implementation
- LLaVA for vision-language modeling
- ScienceQA, A-OKVQA, VSR for datasets
We thank the authors for making their code and data publicly available.
For questions or collaboration:
- Email: [email protected]
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Text-only CoT Pruning: ThinkPrune, LCPO
- Visual Token Pruning: ConsensusDrop, Attention Debiasing
- Multimodal CoT: Multimodal-CoT, MCOUT
Status: Under review at NeurIPS 2026
Last Updated: February 2026