Improved the code in IF and TRAK to support passing in dict #225

Suliang-Jin wants to merge 6 commits into TRAIS-Lab:main
Conversation
@Suliang-Jin given this is a relatively large PR, could you follow the PR template to provide more detailed context about this PR?
README.md
Outdated
The following is an example to use `IFAttributorCG` and `AttributionTask` to apply data attribution to a PyTorch model.

Please reference [here](./docs/guide/README.md) for the guide on how to properly define train/test data for Attributor and loss/target function.
There is no such file as ./docs/guide/README.md?
I think I have created this README in this PR
.github/workflows/examples_test.yml
Outdated
```shell
python examples/brittleness/mnist_lr_brittleness.py --method cg --device cpu
python examples/data_cleaning/influence_function_data_cleaning.py --device cpu --train_size 1000 --val_size 100 --test_size 100 --remove_number 10
python examples/relatIF/influence_function_comparison.py --no_output
sed -i 's/range(1000)/range(100)/g' examples/lds_vs_gt/mnist.py
```
This change should belong in another PR?
Yes, sorry about it. I will fix it.
Summary

I'm sorry about the confusion in this PR. Please refer to the PR created on Nov 26.

What's Changed

Motivation

The original issue was raised in issue #165.

How It Works

On the support of Huggingface Transformers for the influence function (see the snippet below):

```python
train_batch_data = tuple(
    data.to(self.device).unsqueeze(0) for data in train_batch_data_
)
```

Testing

Related Issues

Fixes #165
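The batch handling discussed in this PR can be sketched as a single dispatch over the container type. This is a minimal standalone sketch, not the PR's exact code: the function name `normalize_batch` and the `device` parameter are illustrative stand-ins for the attributor's internal logic.

```python
import torch


def normalize_batch(batch, device="cpu"):
    """Move each tensor in a batch to `device` and add a leading sample dim."""
    if isinstance(batch, (tuple, list)):
        return tuple(t.to(device).unsqueeze(0) for t in batch)
    if isinstance(batch, dict):
        # Values are assumed to be tensors (see the review discussion below).
        return {k: v.to(device).unsqueeze(0) for k, v in batch.items()}
    msg = "We currently only support tuple, list, or dict batches."
    raise TypeError(msg)
```

The dict branch mirrors the tuple/list branch, so existing callers passing tuples or lists are unaffected.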
dattri/algorithm/base.py
Outdated
```python
    )
elif isinstance(train_batch_data_, dict):
    train_batch_data = {
        k: v.unsqueeze(0) for k, v in train_batch_data_.items()
```
We also assume the values in the dictionary to be tensors, right?
I think we should put the data to self.device here.
```python
        k: v.to(self.device) for k, v in full_data_.items()
    }
else:
    raise Exception("We currently only support the train/test data to be tuple, list or dict.")
```
No need to fix IFAttributor API, I will delete it in another PR.
dattri/model_util/retrain.py
Outdated
```python
"""
if seed is None:
    seed = random.getrandbits(64)
```
This is covered in another PR?
examples/lds_vs_gt/mnist.py
Outdated
```python
# Calculate and print LDS score
##############################
lds_score = lds(score, ground_truth)[0]
print("lds:", torch.mean(lds_score[~torch.isnan(lds_score)]))
```
This is covered in another PR?
docs/guide/README.md
Outdated
```
@@ -0,0 +1,120 @@
# User Guide
```
The documentation is clear; a high-level summary table at the top would be beneficial. It should list the supported data types and callable types for all methods in https://github.com/TRAIS-Lab/dattri?tab=readme-ov-file#supported-algorithms.
```python
    )
    logp = -outputs.loss
    return logp - torch.log(1 - torch.exp(logp))
```
Slightly different requirements should be applied to TRAK (Multi-class Margin) and TracIN (training loss and any target function).
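For illustration, a multi-class margin target in the style the comment refers to could look like the following. This is a hedged sketch, not dattri's API: the function name `margin_target` and the single-example signature are assumptions.

```python
import torch


def margin_target(logits, label):
    # Correct-class margin in log space: log p(y) - log(1 - p(y)),
    # computed from raw logits for numerical stability via log_softmax.
    logp = torch.log_softmax(logits, dim=-1)[label]
    return logp - torch.log1p(-torch.exp(logp))
```

The margin is positive when the correct class dominates and negative otherwise, which is the property TRAK-style targets rely on.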
```
@@ -0,0 +1,686 @@
#!/usr/bin/env python
```
We only need one script for IF
```
@@ -0,0 +1,742 @@
#!/usr/bin/env python
```
We only need one script for TRAK.
Force-pushed from 9471f17 to 3ae6be8
docs/guide/README.md
Outdated
```
@@ -0,0 +1,142 @@
# User Guide
```
Change the title to "Data Type Compatibility for Loss and Target Functions" and rename the file to data_compatibility.md.
docs/guide/README.md
Outdated
```markdown
| | [EK-FAC](https://arxiv.org/abs/2308.03296) | ✔️ | ✔️ | ❌ | [Code example](../../examples/brittleness/mnist_lr_brittleness.py) |
| | [RelatIF](https://arxiv.org/pdf/2003.11630) | ✔️ | ✔️ | ❌ | [Code example](../../examples/brittleness/mnist_lr_brittleness.py) |
| | [LoGra](https://arxiv.org/pdf/2405.13954) | ✔️ | ✔️ | ❌ | [Code example](../../examples/brittleness/mnist_lr_brittleness.py) |
| | [GraSS](https://arxiv.org/pdf/2505.18976) | ✔️ | ✔️ | ❌ | [Code example](../../examples/brittleness/mnist_lr_brittleness.py) |
```
I think for GraSS, LoGra, RelatIF, and EK-FAC, we don't have their examples in ../../examples/brittleness/mnist_lr_brittleness.py.
```python
    type=str,
    default="tuple",
    choices=["tuple", "list", "dict"]
)
```
We don't need to show-off what we have supported in the examples. Just choose the most convenient way and demonstrate it in the script.
```python
    type=str,
    default="tuple",
    choices=["tuple", "list", "dict"]
)
```

```python
score = attributor.attribute(train_dataloader, eval_dataloader)

torch.save(score, "score_IF.pt")
logger.info("Attribution scores saved to score_IF.pt")
```
How does IF perform on GPT-2 + wikitext setting?
```python
    type=str,
    default="tuple",
    help="What data structure to pass the training/test data for data attribution."
)
```
We could simply remove this argument, and assume the input structure is dict, as it is more natural for huggingface datasets.
```python
if args.data_structure == "dict":
    train_dataset = [{k: torch.tensor(v, dtype=torch.long) for k, v in d.items()} for d in train_dataset]
    eval_dataset = [{k: torch.tensor(v, dtype=torch.long) for k, v in d.items()} for d in eval_dataset]
```
Remove the conditional check
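With the conditional removed as suggested, the conversion could be done unconditionally. A minimal sketch, assuming each huggingface-style example is a dict of integer lists; the helper name `to_tensor_dicts` is illustrative.

```python
import torch


def to_tensor_dicts(dataset):
    # Unconditionally convert each example (a dict of integer lists) into a
    # dict of long tensors -- no data-structure flag needed.
    return [
        {k: torch.tensor(v, dtype=torch.long) for k, v in example.items()}
        for example in dataset
    ]
```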
```python
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.per_device_train_batch_size,
    sampler=train_sampler,
```

```python
    train_dataset,
    collate_fn=custom_collate_fn,
    batch_size=args.per_device_train_batch_size,
    sampler=train_sampler,
```
Just simply remove the list/tuple case
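The hunk above passes a `custom_collate_fn` whose body is not shown here. A plausible sketch of what such a collate function for dict examples could look like (an assumption, not the PR's exact code):

```python
import torch


def custom_collate_fn(examples):
    # Stack each field across the dict examples into one batched tensor.
    # Assumes every example has the same keys and per-key tensor shapes.
    return {
        key: torch.stack([ex[key] for ex in examples])
        for key in examples[0]
    }
```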
```python
    input_ids,
    kwargs={"attention_mask": attention_mask, "labels": labels},
)
return outputs.loss
```
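For context, the hunk above suggests a loss function of roughly this shape. A hedged sketch: the name `loss_func`, the exact signature, and the huggingface-style model interface (returning an object with a `.loss` attribute) are assumptions based on the diff.

```python
import torch


def loss_func(model, input_ids, attention_mask, labels):
    # The model is called positionally with input_ids and by keyword with
    # the rest; the scalar language-modeling loss is returned.
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    return outputs.loss
```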
Thanks @Suliang-Jin, I have walked through the changes and most of the parts LGTM. One additional thing is that we might also need a few new unit test cases with `dict` inputs.
Thanks @sx-liu! I will make the update in these two days :)
Description

I have changed some code in IF, TRAK, and TracIn so they now support passing in `dict` data. If the user still passes in a tuple or list, the behavior doesn't change.
I have also added some experiments on supporting IF with Huggingface transformers.
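As a minimal standalone illustration of why the dict path fits Huggingface data (this does not use the dattri API): a PyTorch `DataLoader` over dict examples already yields dict batches, since the default collate function stacks dict fields into batched tensors.

```python
import torch
from torch.utils.data import DataLoader

# Dict-style dataset: each example is a dict of tensors, as produced by a
# huggingface tokenizer. The field names here are illustrative.
dataset = [
    {"input_ids": torch.tensor([i, i + 1]), "labels": torch.tensor(i)}
    for i in range(4)
]
loader = DataLoader(dataset, batch_size=2)
batch = next(iter(loader))
# batch is a dict: batch["input_ids"] has shape (2, 2), batch["labels"] (2,)
```

Such dict batches are exactly what the patched attributors can now consume directly.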