diff --git a/examples/models/mnist_net.py b/examples/models/mnist_net.py
index 2bc6e2f..bef0098 100644
--- a/examples/models/mnist_net.py
+++ b/examples/models/mnist_net.py
@@ -31,3 +31,21 @@ def get_mnist_model(path):
     net.eval()
     net.load_state_dict(model_weigths)
     return net
+
+
+class MNISTNet_with_sphere_space(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = torch.nn.Linear(784, 200)
+        self.fc2 = torch.nn.Linear(200, 200)
+        self.fc3 = torch.nn.Linear(200, 10)
+        self.x_sphere_space = None
+
+    def forward(self, x):
+        x = x.flatten(1)
+        x = torch.relu(self.fc1(x))
+        self.x_sphere_space = torch.nn.functional.normalize(self.fc2(x), dim=1, p=2)  # map features onto the unit sphere
+        return self.fc3(self.x_sphere_space)
+
+    def return_sphere_space(self):
+        return self.x_sphere_space
\ No newline at end of file
diff --git a/examples/robust_learning_against_backdoor.py b/examples/robust_learning_against_backdoor.py
new file mode 100644
index 0000000..cdde524
--- /dev/null
+++ b/examples/robust_learning_against_backdoor.py
@@ -0,0 +1,110 @@
+import torch
+import torchvision
+from models.mnist_net import MNISTNet_with_sphere_space
+from secmlt.adv.poisoning.backdoor import (
+    BackdoorDatasetPyTorch,
+    BackdoorDatasetPyTorchWithCoverSample,
+)
+from secmlt.metrics.classification import Accuracy, AttackSuccessRate
+from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
+from secmlt.models.pytorch.robust_trainer import RobustPyTorchTrainer
+from torch.utils.data import DataLoader, Subset
+
+
+def get_mnist_dataloaders(batch_size, target_label, portion, cover_portion, dataset_path, num_workers):
+    def apply_patch(x: torch.Tensor) -> torch.Tensor:
+        # white 4x4 trigger in the bottom-right corner
+        x[:, 0, 24:28, 24:28] = 1.0
+        return x
+
+    training_dataset = torchvision.datasets.MNIST(
+        transform=torchvision.transforms.ToTensor(),
+        train=True,
+        root=dataset_path,
+        download=True,
+    )
+    tr_ds = BackdoorDatasetPyTorchWithCoverSample(
+        training_dataset,
+        data_manipulation_func=apply_patch,
+        trigger_label=target_label,
+        portion=portion,
+        cover_portion=cover_portion,
+    )
+
+    ts_ds = torchvision.datasets.MNIST(
+        transform=torchvision.transforms.ToTensor(),
+        train=False,
+        root=dataset_path,
+        download=True,
+    )
+    # filter out the samples with the target label from the test dataset
+    filtered_indices = [i for i, (_, label) in enumerate(ts_ds) if label != target_label]
+    ts_ds_non_target = Subset(ts_ds, filtered_indices)
+    p_ts_ds = BackdoorDatasetPyTorch(ts_ds_non_target, data_manipulation_func=apply_patch)
+
+    tr_dl = DataLoader(tr_ds, batch_size, shuffle=True, num_workers=num_workers)
+    ts_dl = DataLoader(ts_ds, batch_size, shuffle=False, num_workers=num_workers)
+    p_ts_dl = DataLoader(p_ts_ds, batch_size, shuffle=False, num_workers=num_workers)
+
+    return tr_dl, ts_dl, p_ts_dl
+
+
+def get_validation_data(tr_ds, num_sample, num_workers):
+    # collect num_sample trusted samples per class, i.e. labeled data taken
+    # from benign samples that are neither poisoned nor cover samples
+    subset_indices = []
+    for class_i in range(len(tr_ds.dataset.classes)):
+        cnt = 0
+        for i, (_, label) in enumerate(tr_ds.dataset):
+            if label == class_i and i not in tr_ds.poisoned_indexes and i not in tr_ds.cover_indexes:
+                subset_indices.append(i)
+                cnt += 1
+            if cnt == num_sample:
+                break
+
+    chosen = set(subset_indices)
+    subset_indices_left = [i for i in range(len(tr_ds)) if i not in chosen]
+
+    val_ds = Subset(tr_ds, subset_indices)
+    val_dl = DataLoader(dataset=val_ds, batch_size=len(val_ds), shuffle=True, num_workers=num_workers)
+
+    return val_dl, subset_indices, subset_indices_left
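+
+
+# A possible sanity check for the split above (an illustrative sketch, not
+# called from main() below): the validation set should be class-balanced and
+# must not overlap with the poisoned or cover indexes.
+def check_validation_split(tr_ds, subset_indices, num_sample):
+    from collections import Counter
+
+    poisoned = {int(j) for j in tr_ds.poisoned_indexes}
+    cover = {int(j) for j in tr_ds.cover_indexes}
+    assert not poisoned & set(subset_indices)
+    assert not cover & set(subset_indices)
+    # every class should contribute exactly num_sample trusted examples
+    counts = Counter(int(tr_ds.dataset[i][1]) for i in subset_indices)
+    assert all(c == num_sample for c in counts.values())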
+
+
+def evaluate_model(model, ts_dl, p_ts_dl, target_label):
+    # accuracy on the clean test set (no backdoor)
+    acc = Accuracy()(model, ts_dl)
+    print("acc: {:.3f}".format(acc.item()))
+    # attack success rate on the triggered test set
+    asr = AttackSuccessRate(y_target=target_label)(model, p_ts_dl)
+    print("asr: {:.3f}".format(asr.item()))
+
+    return acc, asr
+
+
+def main():
+    device = "cuda:0"
+    batch_size, target_label, portion, cover_portion, dataset_path, num_workers = 1024, 1, 0.01, 0.01, "example_data/datasets/", 0
+    # training, clean test, and poisoned test dataloaders
+    tr_dl, ts_dl, p_ts_dl = get_mnist_dataloaders(batch_size, target_label, portion, cover_portion, dataset_path, num_workers)
+    # validation dataloader built from trusted samples
+    val_dl, _, _ = get_validation_data(tr_dl.dataset, 50, num_workers)
+    # define model and optimizer
+    model = MNISTNet_with_sphere_space()
+    model.to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    # reduction="none" keeps per-sample losses so they can be re-weighted
+    loss = torch.nn.CrossEntropyLoss(reduction="none")
+    trainer = RobustPyTorchTrainer(validation_dataloader=val_dl, optimizer=optimizer, epochs=10, loss=loss)
+    # train the model with prototype-guided robust learning and test the performance
+    model = trainer.train(model, tr_dl)
+    model = BasePytorchClassifier(model)
+
+    acc, asr = evaluate_model(model, ts_dl, p_ts_dl, target_label)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/secmlt/adv/poisoning/backdoor.py b/src/secmlt/adv/poisoning/backdoor.py
index b703452..8c56884 100644
--- a/src/secmlt/adv/poisoning/backdoor.py
+++ b/src/secmlt/adv/poisoning/backdoor.py
@@ -44,3 +44,81 @@ def __init__(
             portion=portion,
             poisoned_indexes=poisoned_indexes,
         )
+
+
+# NOTE: these imports would normally go at the top of the module.
+import random
+import torch
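+
+
+# Illustrative helper (a sketch, not used elsewhere in this patch): summarize
+# how a BackdoorDatasetPyTorchWithCoverSample, defined below, decomposes into
+# poisoned, cover, and clean samples.
+def summarize_backdoor_dataset(ds) -> dict:
+    n = len(ds.dataset)
+    n_poison = len(ds.poisoned_indexes)
+    n_cover = len(ds.cover_indexes)
+    return {"poisoned": n_poison, "cover": n_cover, "clean": n - n_poison - n_cover}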
+ """ + x, label = self.dataset[idx] + poison_flag = False + # poison portion of the data + if idx in self.poisoned_indexes: + x = self.data_manipulation_func(x=x.unsqueeze(0)).squeeze(0) + target_label = self.label_manipulation_func(label) + label = ( + target_label + if isinstance(label, int) + else torch.Tensor(target_label).type(label.dtype) + ) + poison_flag = True + if idx in self.cover_indexes: + x = self.data_manipulation_func(x=x.unsqueeze(0)).squeeze(0) + + return x, label, self.weights[idx], idx, poison_flag \ No newline at end of file diff --git a/src/secmlt/models/pytorch/robust_trainer.py b/src/secmlt/models/pytorch/robust_trainer.py new file mode 100644 index 0000000..c42cb62 --- /dev/null +++ b/src/secmlt/models/pytorch/robust_trainer.py @@ -0,0 +1,246 @@ +"""Core code for prototype-based robust training in PyTorch.""" +import numpy as np +import torch.nn +import umap +from secmlt.models.base_trainer import BaseTrainer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from scipy.stats import multivariate_normal +from sklearn.metrics import roc_auc_score + + +def cal_prototype(val_features, val_y): + anchor = [] + for c in range(len(val_y.unique())): + class_features = val_features[val_y == c] + prototype = torch.mean(class_features, dim=0) + anchor.append(prototype) + anchor = torch.stack(anchor) + anchor = torch.nn.functional.normalize(anchor, dim=1, p=2) + + return anchor + + +def get_pseudo_labels(x_features, prototypes): + similarity = torch.matmul(x_features, prototypes.t()) + pseudo_labels = torch.argmax(similarity, dim=1) + + return pseudo_labels + + + + +def calculate_auc(score_benign, gt_poison_labels): + gt_benign_labels = ~gt_poison_labels + auc = roc_auc_score(gt_benign_labels, score_benign) + return auc + + +def visualization_data(sphere_features_epoch_i_reduced, class_i, gt_poison_epoch_i): + # if np.unique(gt_poison_epoch_i).shape[0] <2: + # only draw the plot scale with one color + # if np.unique(gt_poison_epoch_i).shape[0] ==2 draw the plot scale with two colors + import matplotlib.pyplot as plt + plt.figure(figsize=(6, 6)) + if np.unique(gt_poison_epoch_i).shape[0] == 2: + plt.scatter(sphere_features_epoch_i_reduced[gt_poison_epoch_i == 0, 0], + sphere_features_epoch_i_reduced[gt_poison_epoch_i == 0, 1], c='b', label='benign', alpha=0.5) + plt.scatter(sphere_features_epoch_i_reduced[gt_poison_epoch_i == 1, 0], + sphere_features_epoch_i_reduced[gt_poison_epoch_i == 1, 1], c='r', label='poison', alpha=0.5) + plt.legend() + else: + plt.scatter(sphere_features_epoch_i_reduced[:, 0], + sphere_features_epoch_i_reduced[:, 1], c='b', alpha=0.5) + plt.title('Class {}'.format(class_i)) + plt.savefig('class_{}.png'.format(class_i)) + + +def update_weights(dataset, sphere_features_epoch, val_sphere_features, val_features_label, labels_epoch, gt_poison_epoch, indices_epoch, threshold_percent=0.98): + # convert to numpy + sphere_features_epoch = sphere_features_epoch.numpy() + labels_epoch = labels_epoch.numpy() + indices_epoch = indices_epoch.numpy() + gt_poison_epoch = gt_poison_epoch.numpy() + + # umap dimension reduction + umap_model_10d = umap.UMAP(n_components=2, n_jobs=-1, metric='cosine') # , random_state=42) + pdf_all_classes, weight_indices_l, gt_poison_flag_l = [], [], [] + for class_i in np.unique(labels_epoch): + val_sphere_features_i = val_sphere_features[val_features_label == class_i] # Anchor points + sphere_features_epoch_i = sphere_features_epoch[labels_epoch == class_i] # Untrusted samples + indices_epoch_i = 
+
+
+def calculate_auc(score_benign, gt_poison_labels):
+    """AUC of the benign-ness scores against the ground-truth benign labels."""
+    gt_benign_labels = ~gt_poison_labels
+    auc = roc_auc_score(gt_benign_labels, score_benign)
+    return auc
+
+
+def visualization_data(sphere_features_epoch_i_reduced, class_i, gt_poison_epoch_i):
+    """Scatter plot of the reduced features for one class.
+
+    If both benign and poison samples are present, they are drawn in two
+    colors; otherwise a single color is used.
+    """
+    import matplotlib.pyplot as plt
+
+    plt.figure(figsize=(6, 6))
+    if np.unique(gt_poison_epoch_i).shape[0] == 2:
+        plt.scatter(sphere_features_epoch_i_reduced[gt_poison_epoch_i == 0, 0],
+                    sphere_features_epoch_i_reduced[gt_poison_epoch_i == 0, 1], c='b', label='benign', alpha=0.5)
+        plt.scatter(sphere_features_epoch_i_reduced[gt_poison_epoch_i == 1, 0],
+                    sphere_features_epoch_i_reduced[gt_poison_epoch_i == 1, 1], c='r', label='poison', alpha=0.5)
+        plt.legend()
+    else:
+        plt.scatter(sphere_features_epoch_i_reduced[:, 0],
+                    sphere_features_epoch_i_reduced[:, 1], c='b', alpha=0.5)
+    plt.title('Class {}'.format(class_i))
+    plt.savefig('class_{}.png'.format(class_i))
+    plt.close()
+
+
+def update_weights(dataset, sphere_features_epoch, val_sphere_features, val_features_label,
+                   labels_epoch, gt_poison_epoch, indices_epoch, threshold_percent=0.98):
+    """Re-estimate per-sample weights from the distance to trusted validation features."""
+    # convert to numpy
+    sphere_features_epoch = sphere_features_epoch.numpy()
+    labels_epoch = labels_epoch.numpy()
+    indices_epoch = indices_epoch.numpy()
+    gt_poison_epoch = gt_poison_epoch.numpy()
+    val_sphere_features = val_sphere_features.numpy()
+    val_features_label = val_features_label.numpy()
+
+    # UMAP dimensionality reduction with the cosine metric, matching the sphere space
+    umap_model_2d = umap.UMAP(n_components=2, n_jobs=-1, metric='cosine')  # , random_state=42)
+    pdf_all_classes, weight_indices_l, gt_poison_flag_l = [], [], []
+    for class_i in np.unique(labels_epoch):
+        val_sphere_features_i = val_sphere_features[val_features_label == class_i]  # anchor points
+        sphere_features_epoch_i = sphere_features_epoch[labels_epoch == class_i]  # untrusted samples
+        indices_epoch_i = indices_epoch[labels_epoch == class_i]
+        gt_poison_epoch_i = gt_poison_epoch[labels_epoch == class_i]
+
+        concatenated_data = np.concatenate((val_sphere_features_i, sphere_features_epoch_i), axis=0)
+        reduced_data_2d = umap_model_2d.fit_transform(concatenated_data)
+        val_sphere_features_i_reduced = reduced_data_2d[:len(val_sphere_features_i)]
+        sphere_features_epoch_i_reduced = reduced_data_2d[len(val_sphere_features_i):]
+
+        # pdf of every untrusted sample under a unit Gaussian centered on each
+        # trusted point; keep, per sample, the maximum over the trusted points
+        pdf_class_i = []
+        for val_sphere_features_i_j in val_sphere_features_i_reduced:
+            mvn = multivariate_normal(mean=val_sphere_features_i_j, cov=np.eye(len(val_sphere_features_i_j)))
+            log_prob = mvn.logpdf(sphere_features_epoch_i_reduced)
+            pdf_class_i.append(np.exp(log_prob))
+        pdf_class_i = np.stack(pdf_class_i, axis=1)
+        pdf_class_i = np.max(pdf_class_i, axis=1)
+        pdf_all_classes.append(pdf_class_i)
+        weight_indices_l.append(indices_epoch_i)
+        gt_poison_flag_l.append(gt_poison_epoch_i)
+        # visualize the per-class data distribution
+        visualization_data(sphere_features_epoch_i_reduced, class_i, gt_poison_epoch_i)
+
+    pdf_all_classes = np.concatenate(pdf_all_classes)
+    weight_indices_l = np.concatenate(weight_indices_l)
+    # gt_poison_epochs can be fed to calculate_auc to monitor detection quality
+    gt_poison_epochs = np.concatenate(gt_poison_flag_l)
+
+    # get the threshold based on threshold_percent: samples whose pdf falls in
+    # the top threshold_percent keep weight 1, the remaining (lower-pdf)
+    # samples are down-weighted proportionally to their pdf
+    sorted_pdf = np.sort(pdf_all_classes)[::-1]
+    top_percent = int(threshold_percent * sorted_pdf.size)
+    threshold = max(sorted_pdf[min(top_percent, sorted_pdf.size) - 1], np.finfo(np.float64).tiny)
+    weights = np.minimum(pdf_all_classes / threshold, 1.0)
+    dataset.weights[torch.from_numpy(weight_indices_l)] = torch.from_numpy(weights).float()
+
+
+def calculate_tpr_fpr(pred_poison, gt_poison):
+    """True and false positive rate of the poison predictions."""
+    tp = np.sum((pred_poison == 1) & (gt_poison == 1))
+    fp = np.sum((pred_poison == 1) & (gt_poison == 0))
+    fn = np.sum((pred_poison == 0) & (gt_poison == 1))
+    tn = np.sum((pred_poison == 0) & (gt_poison == 0))
+    tpr = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
+    return tpr, fpr
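+
+
+# Sketch of the re-weighted objective that train() below implements: the loss
+# is expected to be built with reduction="none" so it yields one value per
+# sample, which is then scaled by that sample's weight. (Illustrative helper,
+# not called by the trainer.)
+def _weighted_loss_example(logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+    per_sample = torch.nn.CrossEntropyLoss(reduction="none")(logits, y)
+    return torch.mean(per_sample * w)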
+ """ + device = next(model.parameters()).device + model = model.train() + for e_id in range(self._epochs): + sphere_features_l, labels_l, indices_l = [], [], [] + val_sphere_features, val_features_label = None, None + pred_poison_l, gt_poison_l = [], [] + for _, batch in enumerate(dataloader): + x, y, w, i = batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3].to(device) + batch_val = self.get_batch_val(self._validation_dataloader) + val_x, val_y = batch_val[0].to(device), batch_val[1].to(device) + combined_x = torch.cat([x, val_x], dim=0) + self._optimizer.zero_grad() + combined_outputs = model(combined_x) + outputs = combined_outputs[:len(x)] + sphere_features = model.return_sphere_space() + sphere_features_l.append(sphere_features[:len(x)].detach().cpu()) + val_sphere_features = sphere_features[len(x):].detach().cpu() + val_features_label = val_y.detach().cpu() + indices_l.append(i.detach().cpu()) + labels_l.append(y.detach().cpu()) + # generate prototypes + prototypes = cal_prototype(sphere_features[len(x):], val_y) + # get psudeo labels based on closest prototype + pseudo_labels = get_pseudo_labels(sphere_features[:len(x)], prototypes) + consistency_flag = y == pseudo_labels + pred_poison_l.append((~consistency_flag).long().detach().cpu()) + gt_poison_l.append(batch[4].long().detach().cpu()) + # compute the Tpr and fpr between consistency and poison_flags batch[4] + loss = torch.mean(self._loss(outputs[consistency_flag], y[consistency_flag]) * w[consistency_flag].to(device)) + loss.backward() + self._optimizer.step() + if self._scheduler is not None: + self._scheduler.step() + pred_poison_l = torch.cat(pred_poison_l, dim=0) + gt_poison_l = torch.cat(gt_poison_l, dim=0) + tpr, fpr = calculate_tpr_fpr(pred_poison_l.cpu().numpy(), gt_poison_l.cpu().numpy()) + # print('tpr, fpr: {}, {}'.format(tpr, fpr)) + if e_id % 2 == 0: + # update the weights based on the features + sphere_features_epoch = torch.cat(sphere_features_l, dim=0) + labels_epoch = torch.cat(labels_l, dim=0) + indices_epoch = torch.cat(indices_l, dim=0) + # update weights in dataloader + update_weights(dataloader.dataset, sphere_features_epoch, val_sphere_features, val_features_label, labels_epoch, gt_poison_l, indices_epoch) + + return model + + def get_batch_val(self, val_dl): + # supervised learning + try: + batch_val = next(iter(val_dl)) + except StopIteration: + val_dl = DataLoader(val_dl.dataset, batch_size=len(val_dl.dataset), shuffle=True, + num_workers=val_dl.num_workers) + batch_val = next(iter(val_dl)) + while len(set(batch_val[1].tolist())) != len( + val_dl.dataset.dataset.dataset.classes): # ensure there is at least one element for each class + batch_val = next(iter(val_dl)) + + return batch_val \ No newline at end of file