ReTrace

This repository includes code for training classifiers under different unlearning settings, extracting unlearning traces, and reconstructing forgotten data with reinforcement learning and GANs.

Files

  • unlearning.py – Train classifiers:
    • Normal training
    • Exact unlearning
    • Approximate unlearning
  • trace_heatmap.py – Generate heatmaps of unlearning traces for forgotten classes.
  • trace_distribution.py – Analyze and output the distribution of unlearning traces.
  • RL_GAN.py – Use reinforcement learning with a pretrained GAN to reconstruct forgotten data.
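The README does not define an "unlearning trace". The sketch below assumes, for illustration only, that a trace is the per-class difference between the softmax outputs of $f^+$ and $f^-$ on the same input. Under that assumption, a heatmap row is the mean trace over inputs of one true class, and a distribution is a histogram of per-sample trace values. The function names are hypothetical and are not taken from trace_heatmap.py or trace_distribution.py.

```python
import math
from collections import defaultdict


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def trace(logits_plus, logits_minus):
    """Hypothetical trace: f+ probabilities minus f- probabilities, per class."""
    return [a - b for a, b in zip(softmax(logits_plus), softmax(logits_minus))]


def trace_heatmap(samples, num_classes):
    """Average trace per true label.

    samples: iterable of (label, logits_plus, logits_minus).
    Returns {true_label: [mean trace for each output class]}, one heatmap row per label.
    """
    sums = defaultdict(lambda: [0.0] * num_classes)
    counts = defaultdict(int)
    for label, lp, lm in samples:
        row = sums[label]
        for j, t in enumerate(trace(lp, lm)):
            row[j] += t
        counts[label] += 1
    return {c: [v / counts[c] for v in row] for c, row in sums.items()}


def histogram(values, num_bins, lo, hi):
    """Count values into equal-width bins on [lo, hi]; hi falls in the last bin."""
    width = (hi - lo) / num_bins
    counts = [0] * num_bins
    for v in values:
        if lo <= v <= hi:
            counts[min(int((v - lo) / width), num_bins - 1)] += 1
    return counts


# Toy logits for a 3-class problem: (true label, f+ logits, f- logits).
samples = [
    (0, [2.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
    (1, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
]
heat = trace_heatmap(samples, num_classes=3)
per_sample = [trace(lp, lm)[label] for label, lp, lm in samples]
print(histogram(per_sample, num_bins=4, lo=-1.0, hi=1.0))  # [0, 0, 2, 0]
```

Because both models output probability vectors, each trace row sums to roughly zero; a large positive entry on the true class marks inputs where $f^+$ was more confident than $f^-$.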

Usage

Train $f^+$ (original model):

python unlearning.py --task train_fplus --dataset cifar100

Train $f^-$ (unlearned model):

python unlearning.py --task train_fminus_exact --epochs 120 --forget_classes 0 --dataset cifar100
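Exact unlearning is commonly understood as retraining from scratch on the data that remains after the forget set is removed. Assuming train_fminus_exact follows that definition with --forget_classes 0, the sketch below shows the data split and the resulting contrast between $f^+$ and $f^-$. A nearest-centroid classifier on 1-D features stands in for the real CIFAR-100 network, and every function name here is hypothetical.

```python
def split_forget_retain(dataset, forget_classes):
    """Partition (feature, label) pairs into forget and retain subsets."""
    forget = set(forget_classes)
    forget_set = [(x, y) for x, y in dataset if y in forget]
    retain_set = [(x, y) for x, y in dataset if y not in forget]
    return forget_set, retain_set


def train_centroids(dataset):
    """Toy stand-in for training: store the mean feature of each class."""
    sums, counts = {}, {}
    for x, y in dataset:
        sums[y] = sums.get(y, 0.0) + x
        counts[y] = counts.get(y, 0) + 1
    return {y: sums[y] / counts[y] for y in sums}


def predict(centroids, x):
    """Return the class whose centroid is nearest to x."""
    return min(centroids, key=lambda y: abs(centroids[y] - x))


data = [(0.0, 0), (0.2, 0), (5.0, 1), (5.2, 1), (10.0, 2)]
f_plus = train_centroids(data)  # original model: trained on all classes
forget_set, retain_set = split_forget_retain(data, forget_classes=[0])
f_minus = train_centroids(retain_set)  # exact unlearning: retrain without class 0
print(predict(f_plus, 0.1), predict(f_minus, 0.1))  # 0 1
```

Retraining from scratch is expensive for a real network, which is the usual motivation for the approximate unlearning setting listed under Files.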

Run RL-GAN reconstruction:

python RL_GAN.py

Note: Pre-train the GAN and set the checkpoint path in RL_GAN.py before running.
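The README says only that RL_GAN.py uses reinforcement learning with a pretrained GAN. One common way to combine the two, assumed here and not confirmed by the README, is to keep the generator frozen and train a Gaussian policy over its latent space with REINFORCE, rewarding generated samples that a classifier scores highly for the forgotten class. In this sketch an identity map stands in for the generator and negative distance to a hidden target stands in for the classifier-based reward; all names and hyperparameters are hypothetical.

```python
import random


def toy_generator(z):
    """Stand-in for a frozen, pretrained GAN generator (identity map here)."""
    return list(z)


def toy_reward(x, target):
    """Stand-in reward: higher when x is closer to a hidden target point."""
    return -sum((xi - ti) ** 2 for xi, ti in zip(x, target))


def reinforce_latent_search(generator, reward_fn, dim, steps=1000, batch=16,
                            sigma=0.3, lr=0.05, seed=0):
    """REINFORCE with a Gaussian policy N(mu, sigma^2 I) over the latent space."""
    rng = random.Random(seed)
    mu = [0.0] * dim
    for _ in range(steps):
        # Sample latent codes z = mu + sigma * eps and score the generated outputs.
        eps_batch = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(batch)]
        rewards = [reward_fn(generator([m + sigma * e for m, e in zip(mu, eps)]))
                   for eps in eps_batch]
        baseline = sum(rewards) / batch  # batch-mean baseline reduces variance
        # d/d(mu) of log N(z; mu, sigma^2 I) is eps / sigma.
        grad = [sum((r - baseline) * eps[k] / sigma
                    for r, eps in zip(rewards, eps_batch)) / batch
                for k in range(dim)]
        mu = [m + lr * g for m, g in zip(mu, grad)]  # gradient ascent on reward
    return mu


target = [1.0, -2.0]
mu = reinforce_latent_search(toy_generator, lambda x: toy_reward(x, target), dim=2)
print(mu)
```

The policy gradient needs only reward values, not gradients through the generator or classifier, so the same loop works when the reward is non-differentiable.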