MMCoT-Prune: Efficient Multimodal Reasoning through Cross-Modal Attention-Guided Chain-of-Thought Pruning

arXiv Python 3.9+ PyTorch License: MIT

TL;DR: We introduce the first framework for reasoning-level pruning in multimodal chain-of-thought, achieving 22-28% reduction in reasoning steps with minimal accuracy degradation through cross-modal attention analysis.

🔥 News

  • [2026-02] Paper submitted to NeurIPS 2026
  • [2026-02] Code and models released
  • [2026-02] Initial project release

📋 Overview

Multimodal Chain-of-Thought (MCoT) reasoning has shown remarkable capabilities but generates lengthy reasoning chains with substantial computational overhead. MMCoT-Prune addresses this challenge by:

  • 🎯 Analyzing cross-modal attention to identify steps with weak visual grounding
  • πŸ” Detecting semantic redundancy among reasoning steps
  • ⚑ Dynamically adapting pruning aggressiveness to task difficulty
  • πŸ“Š Achieving 22-28% efficiency gains with competitive accuracy

Key Results

Dataset     Baseline Acc.   MMCoT-Prune Acc.   Pruning Rate   Token Savings
ScienceQA   86.4%           84.9% (-1.5%)      28.1%          46 tok/q
A-OKVQA     68.3%           64.7% (-3.6%)      23.1%          42 tok/q
VSR         78.1%           77.1% (-1.0%)      25.9%          44 tok/q

πŸ—οΈ Architecture

┌─────────────┐      ┌──────────────┐      ┌─────────────┐
│   Image +   │─────▶│ Multimodal   │─────▶│  Cross-     │
│  Question   │      │ CoT Generate │      │  Modal      │
└─────────────┘      └──────────────┘      │  Attention  │
                                           │  Analysis   │
                                           └──────┬──────┘
                                                  │
                     ┌──────────────┐             │
                     │   Dynamic    │◀────────────┤
                     │   Pruning    │             │
                     │   Strategy   │             │
                     └──────┬───────┘             │
                            │                     │
                     ┌──────▼───────┐      ┌──────▼──────┐
                     │   Pruned     │      │ Redundancy  │
                     │   Reasoning  │◀─────│  Detection  │
                     │    Chain     │      └─────────────┘
                     └──────────────┘

🚀 Quick Start

Installation

# Clone repository
git clone https://github.com/anonymous/mmcot-prune.git
cd mmcot-prune

# Create environment
conda create -n mmcot-prune python=3.9
conda activate mmcot-prune

# Install dependencies
pip install -r requirements.txt

Basic Usage

from src.mmcot_pruning import MMCoTPruner, ReasoningStep
import numpy as np

# Initialize pruner
pruner = MMCoTPruner(
    entropy_threshold=2.5,
    similarity_threshold=0.85,
    base_prune_threshold=0.4,
    min_steps=3
)

# Create reasoning steps (with attention weights and embeddings)
reasoning_chain = [
    ReasoningStep(
        text="Step 1: Analyze the image...",
        step_id=0,
        attention_weights=np.array([...]),  # Cross-modal attention
        semantic_embedding=np.array([...])   # Text embedding
    ),
    # ... more steps
]

# Prune the chain
pruned_chain, decisions = pruner.prune_reasoning_chain(
    reasoning_chain,
    task_difficulty=0.5
)

# Compute efficiency metrics
metrics = pruner.compute_efficiency_metrics(reasoning_chain, pruned_chain)
print(f"Pruning rate: {metrics['pruning_rate']:.1%}")
print(f"Token reduction: {metrics['token_reduction']}")

Running Experiments

# Run benchmark evaluation
cd experiments
python benchmark_evaluation.py

# Generate figures
python generate_figures.py

# Results saved to experiments/benchmark_results.json

📊 Reproduction

Prerequisites

  • Python 3.9+
  • PyTorch 2.1.0+
  • CUDA 11.8+ (for GPU acceleration)
  • 4× NVIDIA A100 GPUs (40 GB) for full experiments
    • Note: Can run on smaller setups with reduced batch sizes

Step-by-Step Reproduction

  1. Download Datasets
# ScienceQA
wget https://scienceqa.github.io/data/scienceqa.zip
unzip scienceqa.zip -d data/

# A-OKVQA
wget https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz
tar -xzf aokvqa_v1p0.tar.gz -C data/

# VSR
git clone https://github.com/cambridgeltl/visual-spatial-reasoning data/vsr
  2. Preprocess Data
python scripts/preprocess_datasets.py --dataset all
  3. Run Experiments
# Main experiments (Table 1 in paper)
python experiments/run_main_experiments.py

# Ablation studies (Table 2-3 in paper)
python experiments/run_ablations.py

# Generate all figures
python experiments/generate_figures.py
  4. Analyze Results
python scripts/analyze_results.py --results_dir experiments/results

Expected Runtime

  • Main experiments: ~48 hours on 4× A100 GPUs
  • Ablation studies: ~24 hours on 4× A100 GPUs
  • Total: ~72 hours on 4× A100 GPUs (≈288 GPU-hours)

🔬 Method Details

Cross-Modal Attention Analysis

We compute attention entropy to measure visual grounding:

H(a_i) = -Σ_j a_ij log(a_ij)

Low entropy → focused attention → important step
High entropy → diffuse attention → potentially redundant

Visual Grounding Score

Combines entropy and peak attention:

g(r_i) = 0.6 / (1 + H(a_i)) + 0.4 * max(a_i)
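
The two quantities above can be computed directly with NumPy. This is a standalone sketch of the formulas as stated (the 0.6/0.4 weights mirror the grounding weights in config.yaml); the function names are illustrative and not the repository's API.

import numpy as np

def attention_entropy(attn, eps=1e-12):
    # H(a_i) = -sum_j a_ij * log(a_ij) over a normalized attention vector
    a = attn / attn.sum()
    return float(-np.sum(a * np.log(a + eps)))

def visual_grounding_score(attn, entropy_weight=0.6, peak_weight=0.4):
    # g(r_i) = 0.6 / (1 + H(a_i)) + 0.4 * max(a_i)
    a = attn / attn.sum()
    return entropy_weight / (1.0 + attention_entropy(a)) + peak_weight * float(a.max())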

Redundancy Score

Integrates visual grounding and semantic similarity:

ρ(r_i) = 0.5 * (1 - g(r_i)) + 0.5 * max_j sim(r_i, r_j)

Dynamic Pruning

Adapts threshold based on task difficulty:

τ(d) = τ_base + 0.3 * d

Prune if: ρ(r_i) > τ(d) and |pruned_steps| ≥ k_min
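
Putting the last two pieces together, here is a hedged sketch of the redundancy score and the dynamic pruning decision. Cosine similarity stands in for sim(·, ·), and the reading of |pruned_steps| ≥ k_min as "the chain must retain at least k_min steps" follows the min_steps description in the Configuration section; the names are illustrative, not the repository's API.

import numpy as np

def cosine_sim(u, v, eps=1e-12):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v) + eps))

def redundancy_score(grounding, embedding, other_embeddings):
    # rho(r_i) = 0.5 * (1 - g(r_i)) + 0.5 * max_j sim(r_i, r_j)
    max_sim = max((cosine_sim(embedding, e) for e in other_embeddings), default=0.0)
    return 0.5 * (1.0 - grounding) + 0.5 * max_sim

def should_prune(rho, task_difficulty, steps_currently_kept,
                 base_threshold=0.4, min_steps=3):
    # Prune only if rho(r_i) > tau(d) = tau_base + 0.3 * d
    # and removing the step still leaves at least min_steps steps in the chain.
    tau = base_threshold + 0.3 * task_difficulty
    return rho > tau and (steps_currently_kept - 1) >= min_steps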

📈 Results Summary

Main Results (Table 1)

MMCoT-Prune achieves substantial efficiency gains across all benchmarks:

  • ScienceQA: 28.1% pruning, 84.9% accuracy (vs 86.4% baseline)
  • A-OKVQA: 23.1% pruning, 64.7% accuracy (vs 68.3% baseline)
  • VSR: 25.9% pruning, 77.1% accuracy (vs 78.1% baseline)

Ablation Studies (Tables 2-3)

Key findings:

  • Combining attention and similarity outperforms either alone
  • Dynamic thresholding improves accuracy by 1.8% vs static
  • Conservative settings (τ = 0.6) achieve the best accuracy-efficiency balance for accuracy-critical applications

πŸ—‚οΈ Project Structure

mmcot-prune/
├── src/
│   ├── mmcot_pruning.py           # Core pruning algorithm
│   ├── attention_analysis.py      # Cross-modal attention extraction
│   ├── models.py                  # VLM wrappers
│   └── utils.py                   # Helper functions
├── experiments/
│   ├── benchmark_evaluation.py    # Main experiments
│   ├── ablation_studies.py        # Ablation studies
│   ├── generate_figures.py        # Visualization
│   └── benchmark_results.json     # Results
├── data/                          # Datasets (download separately)
├── figures/                       # Generated figures
├── paper/
│   ├── mmcot_prune.tex            # Main paper
│   ├── supplementary.tex          # Supplementary material
│   └── cover_letter.md            # Cover letter
├── scripts/
│   ├── preprocess_datasets.py     # Data preprocessing
│   └── analyze_results.py         # Result analysis
├── requirements.txt               # Python dependencies
└── README.md                      # This file

🔧 Configuration

Hyperparameters

Key hyperparameters can be configured in config.yaml:

pruning:
  base_threshold: 0.4        # Base pruning threshold
  similarity_threshold: 0.85 # Semantic similarity threshold
  entropy_threshold: 2.5     # Attention entropy threshold
  min_steps: 3               # Minimum retained steps
  max_prune_ratio: 0.6       # Maximum pruning fraction

grounding:
  entropy_weight: 0.6        # Weight for entropy component
  peak_attention_weight: 0.4 # Weight for peak attention

difficulty:
  visual_weight: 0.4         # Weight for visual complexity
  text_weight: 0.3           # Weight for text complexity
  object_weight: 0.3         # Weight for object count
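
A minimal sketch of loading these values and passing them to the pruner constructor shown in Basic Usage, assuming PyYAML is installed; the mapping from YAML keys to constructor arguments is inferred from this README rather than guaranteed by the codebase.

import yaml
from src.mmcot_pruning import MMCoTPruner

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)["pruning"]

pruner = MMCoTPruner(
    entropy_threshold=cfg["entropy_threshold"],
    similarity_threshold=cfg["similarity_threshold"],
    base_prune_threshold=cfg["base_threshold"],   # YAML key is "base_threshold"
    min_steps=cfg["min_steps"],
)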

πŸ“ Citation

If you find this work useful, please cite:

@article{mmcotprune2026,
  title={MMCoT-Prune: Efficient Multimodal Reasoning through Cross-Modal Attention-Guided Chain-of-Thought Pruning},
  author={Anonymous Authors},
  journal={arXiv preprint arXiv:2026.XXXXX},
  year={2026}
}

🤝 Contributing

We welcome contributions! Please see CONTRIBUTING.md for guidelines.

📄 License

This project is licensed under the MIT License - see LICENSE for details.

πŸ™ Acknowledgments

This work builds upon prior open-source code and datasets; we thank their authors for making them publicly available.

📧 Contact

For questions or collaboration:

🔗 Related Work


Status: 🚧 Under review at NeurIPS 2026

Last Updated: February 2026
