diff --git a/dattri/algorithm/block_projected_if/core/compressor.py b/dattri/algorithm/block_projected_if/core/compressor.py
index 64a9393e..eb8b2959 100644
--- a/dattri/algorithm/block_projected_if/core/compressor.py
+++ b/dattri/algorithm/block_projected_if/core/compressor.py
@@ -273,7 +273,7 @@ def setup_model_compressors(  # noqa: PLR0912, PLR0914, PLR0915 - Complex setup
         inputs = {k: v.to(device) for k, v in sample_inputs.items()}
         model(**inputs)
     else:
-        inputs = sample_inputs[0].to(device)
+        inputs = sample_inputs.to(device)
         model(inputs)
 
     # First, capture inputs and outputs for each layer
diff --git a/dattri/algorithm/trak.py b/dattri/algorithm/trak.py
index 91e37ddc..b1195308 100644
--- a/dattri/algorithm/trak.py
+++ b/dattri/algorithm/trak.py
@@ -15,6 +15,7 @@
 import torch
 from torch.func import vmap
 from tqdm import tqdm
+import torch.nn.functional as F
 
 from dattri.func.projection import random_project
 from dattri.func.utils import _unflatten_params
@@ -36,7 +37,7 @@ class TRAKAttributor(BaseAttributor):
     def __init__(
         self,
         task: AttributionTask,
-        correct_probability_func: Callable,
+        correct_probability_func: Optional[Callable] = None,
         projector_kwargs: Optional[Dict[str, Any]] = None,
         layer_name: Optional[Union[str, List[str]]] = None,
         device: str = "cpu",
@@ -74,6 +75,42 @@ def m(params, image_label_pair):
             Added as `regularization * I`, where `I` is the identity matrix.
             Default is 0.0.
         """
+
+        if correct_probability_func is None:
+            # Fall back to a default correct-probability function based on task_type.
+            if getattr(task, "task_type", None) in ["image_classification", "text_classification"]:
+                def default_m(params, data):
+                    if isinstance(data, dict):
+                        if "label" in data:
+                            label = data["label"]
+                        elif "labels" in data:
+                            label = data["labels"]
+                        else:
+                            raise ValueError("Dictionary data must contain 'label' or 'labels'")
+                        inputs = {k: v for k, v in data.items() if k not in ["label", "labels"]}
+                    elif isinstance(data, (tuple, list)):
+                        inputs, label = data[0], data[1]
+                    else:
+                        inputs, label = data
+
+                    if isinstance(inputs, dict):
+                        inputs_vmap = {"input_ids": inputs["input_ids"].unsqueeze(0)}
+                        yhat = torch.func.functional_call(self.task.get_model(), params, kwargs=inputs_vmap)
+                    else:
+                        yhat = torch.func.functional_call(self.task.get_model(), params, inputs.unsqueeze(0))
+
+                    if hasattr(yhat, "logits"):
+                        yhat = yhat.logits
+
+                    loss_val = F.cross_entropy(yhat, label.unsqueeze(0))
+                    return torch.exp(-loss_val)
+
+                selected_prob_func = default_m
+            else:
+                raise ValueError(f"Unsupported task type: {task.task_type}, please provide correct_probability_func.")
+        else:
+            selected_prob_func = correct_probability_func
+
         self.task = task
         self.norm_scaler = (
             sum(
@@ -91,7 +128,7 @@ def m(params, image_label_pair):
         self.grad_target_func = self.task.get_grad_target_func(in_dims=(None, 0))
         self.grad_loss_func = self.task.get_grad_loss_func(in_dims=(None, 0))
         self.correct_probability_func = vmap(
-            correct_probability_func,
+            selected_prob_func,
             in_dims=(None, 0),
             randomness="different",
         )
@@ -143,6 +180,11 @@ def cache(
                 train_batch_data = tuple(
                     data.to(self.device) for data in train_data
                 )
+            elif isinstance(train_data, dict):
+                train_batch_data = {
+                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+                    for k, v in train_data.items()
+                }
             else:
                 train_batch_data = train_data
 
@@ -307,6 +349,11 @@ def attribute(  # noqa: PLR0912, PLR0914, PLR0915
                 # TODO: reorganize the data pre-grad processing.
                 if isinstance(test_data, (tuple, list)):
                     test_batch_data = tuple(data.to(self.device) for data in test_data)
+                elif isinstance(test_data, dict):
+                    test_batch_data = {
+                        k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+                        for k, v in test_data.items()
+                    }
                 else:
                     test_batch_data = test_data
                 grad_t = self.grad_target_func(parameters, test_batch_data)
diff --git a/dattri/task.py b/dattri/task.py
index fc57f3cb..a7d7f151 100644
--- a/dattri/task.py
+++ b/dattri/task.py
@@ -42,7 +42,6 @@ class AttributionTask:
     def __init__(
         self,
-        loss_func: Callable,
         model: nn.Module,
         checkpoints: Union[
             str,
@@ -50,12 +49,25 @@ def __init__(
             List[str],
             List[Dict[str, torch.Tensor]],
             Dict[str, torch.Tensor],
         ],
+        loss_func: Optional[Callable] = None,
+        task_type: Optional[str] = None,
         target_func: Optional[Callable] = None,
         checkpoints_load_func: Optional[Callable] = None,
     ) -> None:
         """Initialize the AttributionTask.
 
         Args:
+            model (nn.Module): The model that the target function is based on.
+                More specifically, this is the `model` used in the target
+                function. Since only the computation graph of the model is
+                used, the model does not need to be loaded with trained
+                parameters.
+            checkpoints:
+                (Union[str, List[str], List[Dict[str, torch.Tensor]],
+                Dict[str, torch.Tensor]]): The checkpoints of the model; both a
+                state_dict dictionary and a path to a checkpoint file are
+                supported. If an ensemble is needed, a list of
+                checkpoints is also supported.
             loss_func (Callable): The loss function of the model training. The
                 function can be quite flexible in terms of what is calculated,
                 but it should take the parameters and the data as input. Other than
@@ -71,17 +83,13 @@ def f(params, data):
                     return loss(yhat, label)
                 ```.
                 This examples calculates the CE loss of the model on the data.
-            model (nn.Module): The model that the target function is based on.
-                To be more specific, the model is the `model` used in the target
-                function. Since only the computation graph of the model will be
-                used, so it is allowed that this model is not loaded with a trained
-                parameters.
-            checkpoints:
-                (Union[str, List[str], List[Dict[str, torch.Tensor]],
-                Dict[str, torch.Tensor]]): The checkpoints
-                of the model, both dictionary of the state_dict and the path to
-                the checkpoint are supported. If ensemble is needed, a list of
-                checkpoint is also supported.
+            task_type (str): The type of task for which attribution is being
+                computed. When this is given, you do not need to provide a
+                `loss_func` manually; an appropriate cross-entropy loss is
+                selected automatically. Supported task types include:
+                - 'image_classification': for image classification tasks.
+                - 'text_classification': for text classification tasks.
+                - Other task types can be added as needed, with dedicated handling for each.
             target_func (Callable): The target function to be attributed. This
                 input is optional, if not provided, the target function will be
                 the same as the loss function. The function can be quite flexible
@@ -110,6 +118,54 @@ def checkpoints_load_func(model, checkpoint):
                 ```.
""" self.model = model + + if loss_func is None: + if task_type == 'image_classification': + def trak_log_odds_loss(params, data): + if isinstance(data, (tuple, list)): + inputs, label = data[0], data[1] + else: + inputs, label = data + + yhat = torch.func.functional_call(model, params, inputs.unsqueeze(0)) + if hasattr(yhat, 'logits'): + yhat = yhat.logits + loss_val = torch.nn.functional.cross_entropy(yhat, label.unsqueeze(0)) + logp = -loss_val + eps = 1e-10 + return logp - torch.log(1.0 - torch.exp(logp) + eps) + + loss_func = trak_log_odds_loss + elif task_type == 'text_classification': + def trak_log_odds_loss(params, data): + if isinstance(data, dict): + if 'label' in data: + label = data['label'] + elif 'labels' in data: + label = data['labels'] + else: + raise ValueError("Dictionary data must contain 'label' or 'labels'") + inputs = {k: v for k, v in data.items() if k not in ['label', 'labels']} + else: + inputs, label = data + + if isinstance(inputs, dict): + inputs_vmap = {'input_ids': data['input_ids'].unsqueeze(0)} + yhat = torch.func.functional_call(model, params, kwargs=inputs_vmap) + if hasattr(yhat, 'logits'): + yhat = yhat.logits + loss_val = torch.nn.functional.cross_entropy(yhat, label.unsqueeze(0)) + logp = -loss_val + eps = 1e-10 + return logp - torch.log(1.0 - torch.exp(logp) + eps) + + loss_func = trak_log_odds_loss + else: + raise ValueError(f"Unsupported task type: {task_type}, please provide loss_func") + if target_func is None: target_func = loss_func @@ -131,6 +184,7 @@ def checkpoints_load_func(model, checkpoint): else: self.checkpoints = checkpoints + self.task_type=task_type # current_checkpoint_idx is used to state # which checkpoint is currently loaded. self.current_checkpoint_idx = None