
Fix dimension mismatch and add auto-loss support to TRAK #246

Open
traderbxy wants to merge 2 commits into TRAIS-Lab:main from traderbxy:main

Conversation

@traderbxy

Description

Summary

This PR fixes a dimension mismatch issue and enhances the TRAK method's usability by adding automatic loss/probability function support for common task types.

Changes

  1. Fix dimension mismatch in sample_inputs

    • File: dattri/dattri/algorithm/block_projected_if/core/compressor.py
    • Issue: sample_inputs[0] had a dimension mismatch with model expectations, causing runtime errors.
    • Fix: Adjusted data preprocessing logic to ensure proper dimension alignment between sample_inputs and model input requirements.
  2. Enhanced TRAK method with automatic function support

    • Modified classes: TRAKAttributor and AttributionTask
    • New parameter: Added task_type parameter supporting 'image_classification' and 'text_classification' values
    • Improvements:
      • When task_type is specified, users no longer need to hand-write loss_func and correct_probability_func
      • Added built-in default functions for both classification tasks (see the sketch after this list)
      • Extended support for dictionary-formatted inputs in text classification scenarios
      • Maintains backward compatibility: manual function definitions still work when task_type is not specified
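
To make the new behavior concrete, here is a minimal sketch of what the built-in classification functions could look like. The function names, signatures, and the assumption that batches arrive as (inputs, labels) pairs are illustrative, not this PR's exact code; the target shown (log softmax margin of the correct class, plus the correct-class probability) is the standard TRAK formulation for classification.

```python
import torch


def default_classification_loss(model, batch):
    """Illustrative TRAK target: negative log softmax margin of the correct class."""
    inputs, labels = batch
    # Mirror the dictionary-input support described above (assumption).
    logits = model(**inputs) if isinstance(inputs, dict) else model(inputs)
    correct = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
    masked = logits.clone()
    masked.scatter_(1, labels.unsqueeze(1), float("-inf"))
    margin = correct - torch.logsumexp(masked, dim=1)
    return -margin.sum()


def default_correct_probability(model, batch):
    """Illustrative correct_probability_func: softmax probability of the label."""
    inputs, labels = batch
    logits = model(**inputs) if isinstance(inputs, dict) else model(inputs)
    probs = torch.softmax(logits, dim=1)
    return probs.gather(1, labels.unsqueeze(1)).squeeze(1)
```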

Benefits

  • Reduced complexity: Users can now apply TRAK to common tasks without implementing custom functions (see the usage sketch after this list)
  • Better UX: More consistent API with other attribution methods
  • Extended compatibility: Better support for text classification pipelines with dictionary inputs
  • Bug fix: Resolves dimension-related crashes in the block_projected_if algorithm
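
A hedged usage sketch of the before/after API. The import paths and constructor signatures follow dattri's public examples but are not taken from this PR's diff; the toy model, checkpoint list, and placeholder functions are assumptions. task_type is the parameter this PR adds.

```python
import torch
from dattri.task import AttributionTask
from dattri.algorithm.trak import TRAKAttributor

model = torch.nn.Linear(10, 2)        # toy model (placeholder)
checkpoints = [model.state_dict()]    # placeholder checkpoint list


def my_loss_func(params, data):       # hand-written loss (placeholder body)
    ...


def my_prob_func(params, data):       # hand-written probability (placeholder body)
    ...


# Before this PR: both functions must be supplied by hand.
task = AttributionTask(model=model, checkpoints=checkpoints,
                       loss_func=my_loss_func)
attributor = TRAKAttributor(task=task,
                            correct_probability_func=my_prob_func)

# With this PR: a task_type selects the built-in defaults instead.
task = AttributionTask(model=model, checkpoints=checkpoints,
                       task_type="image_classification")
attributor = TRAKAttributor(task=task)
```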

The hunk under review in dattri/dattri/algorithm/block_projected_if/core/compressor.py (the page flattened the diff; old/new markers reconstructed):

```diff
         model(**inputs)
     else:
-        inputs = sample_inputs[0].to(device)
+        inputs = sample_inputs.to(device)
```
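
For context, a reconstruction (not this PR's exact code) of how the fixed branch plausibly reads, assuming dictionary batches are unpacked as keyword arguments and plain tensor batches are moved to the device directly:

```python
# Reconstruction under stated assumptions; surrounding code is a guess.
if isinstance(sample_inputs, dict):
    # Dictionary batches (e.g. tokenizer output): move each value to the
    # device, then unpack as keyword arguments.
    inputs = {k: v.to(device) for k, v in sample_inputs.items()}
    model(**inputs)
else:
    # Plain tensor batches: move to the device and call the model directly.
    inputs = sample_inputs.to(device)
    model(inputs)
```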
Contributor:

Could you add a unit test related to this bug fix so that we could make sure future changes won't introduce such bugs again?
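
A hedged sketch of what such a regression test could look like. It exercises the tensor and dictionary input paths directly rather than through the compressor, since the compressor's entry point is not shown here; the toy models and shapes are placeholders.

```python
import torch


class DictInputModel(torch.nn.Module):
    """Toy model that accepts keyword (dictionary-style) inputs."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 2)

    def forward(self, input_ids):
        return self.linear(input_ids.float())


def test_tensor_and_dict_sample_inputs_forward_without_dim_errors():
    device = "cpu"

    # Tensor path: a (batch, features) tensor should pass through unchanged.
    tensor_model = torch.nn.Linear(10, 2)
    tensor_batch = torch.randn(4, 10).to(device)
    assert tensor_model(tensor_batch).shape == (4, 2)

    # Dict path: values are moved to the device and unpacked as kwargs.
    dict_model = DictInputModel()
    dict_batch = {"input_ids": torch.randint(0, 100, (4, 16))}
    moved = {k: v.to(device) for k, v in dict_batch.items()}
    assert dict_model(**moved).shape == (4, 2)
```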

Contributor:

In addition, I would recommend separating this bug fix from the other change to maintain a clean change log.

The import block under review:

```python
import torch
from torch.func import vmap
from tqdm import tqdm
import torch.nn.functional as F
```
Contributor:

Is F used?
