2 changes: 1 addition & 1 deletion dattri/algorithm/block_projected_if/core/compressor.py
@@ -273,7 +273,7 @@ def setup_model_compressors( # noqa: PLR0912, PLR0914, PLR0915 - Complex setup
inputs = {k: v.to(device) for k, v in sample_inputs.items()}
model(**inputs)
else:
inputs = sample_inputs[0].to(device)
inputs = sample_inputs.to(device)
Contributor

Could you add a unit test related to this bug fix, so we can make sure future changes won't reintroduce it?
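
A minimal regression test along these lines could cover the tensor-input path; the import path follows the file touched here, but the exact `setup_model_compressors` signature is an assumption and would need adjusting to the real API:

```python
import torch
from torch import nn

# Import path taken from this diff; the call signature below is an assumption.
from dattri.algorithm.block_projected_if.core.compressor import setup_model_compressors


def test_setup_compressors_with_plain_tensor_inputs():
    """Regression sketch: sample_inputs passed as a single tensor, not a tuple.

    The old code indexed `sample_inputs[0]`, which silently dropped the batch
    dimension during the capture forward pass.
    """
    model = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))
    sample_inputs = torch.randn(3, 1, 2, 2)  # a batch of 3 samples

    # Hypothetical call; adjust the arguments to the actual API.
    setup_model_compressors(model, sample_inputs=sample_inputs, device="cpu")
```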

Contributor

In addition, I would recommend separating this bug fix from the other change to maintain a clean change log.

model(inputs)

# First, capture inputs and outputs for each layer
48 changes: 46 additions & 2 deletions dattri/algorithm/trak.py
@@ -15,6 +15,7 @@
import torch
from torch.func import vmap
from tqdm import tqdm
import torch.nn.functional as F
Contributor

Is F used?


from dattri.func.projection import random_project
from dattri.func.utils import _unflatten_params
@@ -36,7 +37,7 @@ class TRAKAttributor(BaseAttributor):
def __init__(
self,
task: AttributionTask,
correct_probability_func: Callable,
correct_probability_func: Optional[Callable] = None,
projector_kwargs: Optional[Dict[str, Any]] = None,
layer_name: Optional[Union[str, List[str]]] = None,
device: str = "cpu",
@@ -74,6 +75,42 @@ def m(params, image_label_pair):
Added as `regularization * I`, where `I` is the identity matrix.
Default is 0.0.
"""

if correct_probability_func is None:
    if getattr(task, 'task_type', None) in ['image_classification', 'text_classification']:
        def default_m(params, data):
            # Split the batch element into model inputs and the label.
            if isinstance(data, dict):
                if 'label' in data:
                    label = data['label']
                elif 'labels' in data:
                    label = data['labels']
                else:
                    raise ValueError("Dictionary data must contain 'label' or 'labels'")
                inputs = {k: v for k, v in data.items() if k not in ['label', 'labels']}
            elif isinstance(data, (tuple, list)):
                inputs, label = data[0], data[1]
            else:
                # Fall back to generic unpacking of a (inputs, label) pair.
                inputs, label = data

            # vmap passes unbatched elements, so restore the batch dimension
            # before calling the model.
            if isinstance(inputs, dict):
                inputs_vmap = {'input_ids': inputs['input_ids'].unsqueeze(0)}
                yhat = torch.func.functional_call(self.task.get_model(), params, kwargs=inputs_vmap)
            else:
                yhat = torch.func.functional_call(self.task.get_model(), params, inputs.unsqueeze(0))

            if hasattr(yhat, 'logits'):
                yhat = yhat.logits

            # Probability of the correct class, i.e. exp(-CE loss).
            loss_val = torch.nn.functional.cross_entropy(yhat, label.unsqueeze(0))
            return torch.exp(-loss_val)

        selected_prob_func = default_m
    else:
        raise ValueError(
            f"Unsupported task type: {getattr(task, 'task_type', None)}, "
            "please provide loss_func and correct_probability_func.",
        )
else:
selected_prob_func = correct_probability_func

self.task = task
self.norm_scaler = (
sum(
@@ -91,7 +128,7 @@ def m(params, image_label_pair):
self.grad_target_func = self.task.get_grad_target_func(in_dims=(None, 0))
self.grad_loss_func = self.task.get_grad_loss_func(in_dims=(None, 0))
self.correct_probability_func = vmap(
correct_probability_func,
selected_prob_func,
in_dims=(None, 0),
randomness="different",
)
@@ -143,6 +180,11 @@ def cache(
train_batch_data = tuple(
data.to(self.device) for data in train_data
)
elif isinstance(train_data, dict):
train_batch_data = {
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in train_data.items()
}
else:
train_batch_data = train_data

@@ -307,6 +349,8 @@ def attribute( # noqa: PLR0912, PLR0914, PLR0915
# TODO: reorganize the data pre-grad processing.
if isinstance(test_data, (tuple, list)):
test_batch_data = tuple(data.to(self.device) for data in test_data)
elif isinstance(test_data, dict):
test_batch_data = {
    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
    for k, v in test_data.items()
}
else:
test_batch_data = test_data
grad_t = self.grad_target_func(parameters, test_batch_data)
78 changes: 66 additions & 12 deletions dattri/task.py
@@ -42,20 +42,32 @@ class AttributionTask:

def __init__(
self,
loss_func: Callable,
model: nn.Module,
checkpoints: Union[
str,
List[str],
List[Dict[str, torch.Tensor]],
Dict[str, torch.Tensor],
],
loss_func: Optional[Callable] = None,
task_type: Optional[str] = None,
target_func: Optional[Callable] = None,
checkpoints_load_func: Optional[Callable] = None,
) -> None:
"""Initialize the AttributionTask.

Args:
model (nn.Module): The model that the target function is based on.
To be more specific, this is the `model` used in the target
function. Since only the computation graph of the model is
used, the model does not need to be loaded with trained
parameters.
checkpoints:
(Union[str, List[str], List[Dict[str, torch.Tensor]],
Dict[str, torch.Tensor]]): The checkpoints
of the model; both a state_dict dictionary and a path to
a checkpoint file are supported. If an ensemble is needed, a list of
checkpoints is also supported.
loss_func (Callable): The loss function of the model training.
The function can be quite flexible in terms of what is calculated,
but it should take the parameters and the data as input. Other than
@@ -71,17 +83,13 @@ def f(params, data):
return loss(yhat, label)
```.
This examples calculates the CE loss of the model on the data.
model (nn.Module): The model that the target function is based on.
To be more specific, the model is the `model` used in the target
function. Since only the computation graph of the model will be
used, so it is allowed that this model is not loaded with a trained
parameters.
checkpoints:
(Union[str, List[str], List[Dict[str, torch.Tensor]],
Dict[str, torch.Tensor]]): The checkpoints
of the model, both dictionary of the state_dict and the path to
the checkpoint are supported. If ensemble is needed, a list of
checkpoint is also supported.
task_type (Optional[str]): The type of the task for which attribution is computed.
When this parameter is given, you don't need to provide a `loss_func` manually;
an appropriate cross-entropy-based loss is selected automatically (see the
example below). Supported task types include:
- 'image_classification': For image classification tasks.
- 'text_classification': For text classification tasks.
- Other task types can be added as needed, with specific handling implemented for each.
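A minimal construction sketch (`model` and the checkpoint path are placeholders):
```python
task = AttributionTask(
    model=model,
    checkpoints="path/to/checkpoint.pt",
    task_type="image_classification",
)
```.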
target_func (Callable): The target function to be attributed.
This input is optional, if not provided, the target function will
be the same as the loss function. The function can be quite flexible
@@ -110,6 +118,51 @@ def checkpoints_load_func(model, checkpoint):
```.
"""
self.model = model

if loss_func is None:
    if task_type == 'image_classification':
        def trak_log_odds_loss(params, data):
            if isinstance(data, (tuple, list)):
                inputs, label = data[0], data[1]
            else:
                inputs, label = data

            yhat = torch.func.functional_call(model, params, inputs.unsqueeze(0))
            if hasattr(yhat, 'logits'):
                yhat = yhat.logits
            # TRAK-style margin: log(p / (1 - p)) with p = exp(-CE loss),
            # i.e. the log-odds of the correct class.
            loss_val = torch.nn.functional.cross_entropy(yhat, label.unsqueeze(0))
            logp = -loss_val
            eps = 1e-10
            return logp - torch.log(1.0 - torch.exp(logp) + eps)

        loss_func = trak_log_odds_loss
    elif task_type == 'text_classification':
        def trak_log_odds_loss(params, data):
            if isinstance(data, dict):
                if 'label' in data:
                    label = data['label']
                elif 'labels' in data:
                    label = data['labels']
                else:
                    raise ValueError("Dictionary data must contain 'label' or 'labels'")
                inputs = {k: v for k, v in data.items() if k not in ['label', 'labels']}
            else:
                inputs, label = data

            if isinstance(inputs, dict):
                inputs_vmap = {'input_ids': inputs['input_ids'].unsqueeze(0)}
                yhat = torch.func.functional_call(model, params, kwargs=inputs_vmap)
            else:
                yhat = torch.func.functional_call(model, params, inputs.unsqueeze(0))
            if hasattr(yhat, 'logits'):
                yhat = yhat.logits
            # Same TRAK-style log-odds of the correct class as above.
            loss_val = torch.nn.functional.cross_entropy(yhat, label.unsqueeze(0))
            logp = -loss_val
            eps = 1e-10
            return logp - torch.log(1.0 - torch.exp(logp) + eps)

        loss_func = trak_log_odds_loss
    else:
        raise ValueError(f"Unsupported task type: {task_type}, please provide loss_func.")

if target_func is None:
target_func = loss_func

@@ -131,6 +184,7 @@ def checkpoints_load_func(model, checkpoint):
else:
self.checkpoints = checkpoints

self.task_type = task_type
# current_checkpoint_idx is used to state
# which checkpoint is currently loaded.
self.current_checkpoint_idx = None