18 changes: 18 additions & 0 deletions examples/models/mnist_net.py
@@ -31,3 +31,21 @@ def get_mnist_model(path):
net.eval()
net.load_state_dict(model_weigths)
return net


class MNISTNet_with_sphere_space(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(784, 200)
self.fc2 = torch.nn.Linear(200, 200)
self.fc3 = torch.nn.Linear(200, 10)
self.x_sphere_space = None

def forward(self, x):
x = x.flatten(1)
x = torch.relu(self.fc1(x))
        self.x_sphere_space = torch.nn.functional.normalize(self.fc2(x), dim=1, p=2)  # project the features onto the unit L2 sphere
return self.fc3(self.x_sphere_space)

def return_sphere_space(self):
return self.x_sphere_space
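A minimal sketch (random input, illustrative shapes only) of how the cached sphere-space embedding can be retrieved after a forward pass; rows of the returned tensor should have unit L2 norm.

```python
import torch
from models.mnist_net import MNISTNet_with_sphere_space

model = MNISTNet_with_sphere_space()
logits = model(torch.rand(4, 1, 28, 28))   # forward pass caches the normalized features
z = model.return_sphere_space()            # shape (4, 200)
print(logits.shape, z.norm(dim=1))         # norms should all be ~1.0
```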
110 changes: 110 additions & 0 deletions examples/robust_learning_against_backdoor.py
@@ -0,0 +1,110 @@
import torch, torchvision

from models.mnist_net import MNISTNet_with_sphere_space
from secmlt.adv.poisoning.backdoor import BackdoorDatasetPyTorch, BackdoorDatasetPyTorchWithCoverSample
from secmlt.metrics.classification import Accuracy, AttackSuccessRate
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from torch.utils.data import DataLoader, Subset

from secmlt.models.pytorch.robust_trainer import RobustPyTorchTrainer


def get_mnist_dataloaders(batch_size, target_label, portion, cover_portion, dataset_path, num_workers):
    def apply_patch(x: torch.Tensor) -> torch.Tensor:
        # stamp a white 4x4 trigger in the bottom-right corner of the image
        x[:, 0, 24:28, 24:28] = 1.0
        return x

training_dataset = torchvision.datasets.MNIST(
transform=torchvision.transforms.ToTensor(),
train=True,
root=dataset_path,
download=True,
)
tr_ds = BackdoorDatasetPyTorchWithCoverSample(
training_dataset,
data_manipulation_func=apply_patch,
trigger_label=target_label,
portion=portion,
cover_portion=cover_portion,
)

ts_ds = torchvision.datasets.MNIST(
transform=torchvision.transforms.ToTensor(),
train=False,
root=dataset_path,
download=True,
)
    # keep only test samples whose true label differs from the target, then apply the trigger to build the poisoned test set
filtered_indices = [i for i, (_, label) in enumerate(ts_ds) if label != target_label]
ts_ds_non_target = Subset(ts_ds, filtered_indices)
p_ts_ds = BackdoorDatasetPyTorch(ts_ds_non_target, data_manipulation_func=apply_patch)

tr_dl = DataLoader(tr_ds, batch_size, shuffle=True, num_workers=num_workers)
ts_dl = DataLoader(ts_ds, batch_size, shuffle=False, num_workers=num_workers)
p_ts_dl = DataLoader(p_ts_ds, batch_size, shuffle=False, num_workers=num_workers)

return tr_dl, ts_dl, p_ts_dl


def get_validation_data(tr_ds, num_sample, num_workers):
    # collect num_sample clean (non-poisoned, non-cover) samples per class for validation
    subset_indices = []
    for class_i in range(len(tr_ds.dataset.classes)):
        cnt = 0
        for i, (x, label) in enumerate(tr_ds.dataset):
            if label == class_i and i not in tr_ds.poisoned_indexes and i not in tr_ds.cover_indexes:
                subset_indices.append(i)
                cnt += 1
            if cnt == num_sample:
                break

    # all remaining training indexes, not used for validation
    subset_indices_left = [i for i in range(len(tr_ds)) if i not in subset_indices]

    val_ds = Subset(tr_ds, subset_indices)
    val_dl = DataLoader(dataset=val_ds, batch_size=len(val_ds), shuffle=True, num_workers=num_workers)

    return val_dl, subset_indices, subset_indices_left


def evaluate_model(model, ts_dl, p_ts_dl, target_label):
    # clean test accuracy (no trigger applied)
    acc = Accuracy()(model, ts_dl)
    print("acc: {:.3f}".format(acc.item()))

    # attack success rate on the poisoned (triggered) test set
    asr = AttackSuccessRate(y_target=target_label)(model, p_ts_dl)
    print("asr: {:.3f}".format(asr.item()))

return acc, asr


def main():
device = "cuda:0"
batch_size, target_label, portion, cover_portion, dataset_path, num_workers = 1024, 1, 0.01, 0.01, "example_data/datasets/", 0
# training, test, and p_test dataset
tr_dl, ts_dl, p_ts_dl = get_mnist_dataloaders(batch_size, target_label, portion, cover_portion, dataset_path, num_workers)
# validation dataloader
val_dl, _, _ = get_validation_data(tr_dl.dataset, 50, num_workers)
# define model, optimizer
model = MNISTNet_with_sphere_space()
model.to(device)
# define robust trainer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss = torch.nn.CrossEntropyLoss(reduction="none")
trainer = RobustPyTorchTrainer(validation_dataloader=val_dl, optimizer=optimizer, epochs=10, loss=loss)
# train the model with pgrl robust training and test the performance
model = trainer.train(model, tr_dl)
model = BasePytorchClassifier(model)

acc, asr = evaluate_model(model, ts_dl, p_ts_dl, target_label)


if __name__ == "__main__":
main()
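For reference, apply_patch stamps a white 4x4 square into the bottom-right corner of the single MNIST channel. A quick standalone check (nothing assumed beyond torch) that confirms exactly those 16 pixels change:

```python
import torch

def apply_patch(x: torch.Tensor) -> torch.Tensor:
    x[:, 0, 24:28, 24:28] = 1.0
    return x

x = torch.zeros(1, 1, 28, 28)
patched = apply_patch(x.clone())
changed = (patched != x).nonzero()      # indices of modified pixels, shape (16, 4)
print(changed[:, 2].unique().tolist())  # rows 24..27
print(changed[:, 3].unique().tolist())  # cols 24..27
print((patched != x).sum().item())      # 16 pixels set to 1.0
```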
78 changes: 78 additions & 0 deletions src/secmlt/adv/poisoning/backdoor.py
@@ -44,3 +44,81 @@ def __init__(
portion=portion,
poisoned_indexes=poisoned_indexes,
)


import random

import torch


class BackdoorDatasetPyTorchWithCoverSample(PoisoningDatasetPyTorch):
    """Backdoor dataset that also injects cover samples: inputs that receive the trigger but keep their original label."""

def __init__(
self,
dataset: Dataset,
data_manipulation_func: callable,
trigger_label: int = 0,
portion: float | None = None,
cover_portion: float = 0.0,
        poisoned_indexes: list[int] | torch.Tensor | None = None,
) -> None:
"""
Create the backdoored dataset.

Parameters
----------
dataset : torch.utils.data.Dataset
PyTorch dataset.
data_manipulation_func: callable
Function to manipulate the data and add the backdoor.
trigger_label : int, optional
Label to associate with the backdoored data (default 0).
        portion : float, optional
            Fraction of samples on which the backdoor will be injected (default None).
        cover_portion : float, optional
            Fraction of additional samples that receive the trigger but keep their original label (default 0.0).
        poisoned_indexes: list[int] | torch.Tensor
            Specific indexes of samples to perturb. Alternative to portion.
"""
super().__init__(
dataset=dataset,
data_manipulation_func=data_manipulation_func,
label_manipulation_func=lambda _: trigger_label,
portion=portion,
poisoned_indexes=poisoned_indexes,
)
        if self.poisoned_indexes is not None:
            # cover candidates are all samples that are not poisoned
            cover_indexes_candidate = [
                i for i in range(len(self.dataset)) if i not in self.poisoned_indexes
            ]
            self.cover_indexes = set(
                random.sample(cover_indexes_candidate, int(len(self.dataset) * cover_portion))
            )
        else:
            self.cover_indexes = set()
        # per-sample weights, initialized to 1.0
        self.weights = torch.ones(len(self.dataset))

def __getitem__(self, idx: int) -> tuple[torch.Tensor, int, float, int, bool]:
"""
Get item from the dataset.

Parameters
----------
idx : int
Index of the item to return

        Returns
        -------
        tuple[torch.Tensor, int, float, int, bool]
            The sample, its label, its per-sample weight, its original index, and a flag that is True if the sample was poisoned (relabeled).
"""
x, label = self.dataset[idx]
poison_flag = False
# poison portion of the data
if idx in self.poisoned_indexes:
x = self.data_manipulation_func(x=x.unsqueeze(0)).squeeze(0)
target_label = self.label_manipulation_func(label)
label = (
target_label
if isinstance(label, int)
else torch.Tensor(target_label).type(label.dtype)
)
poison_flag = True
        # cover samples: apply the trigger but keep the original label
        if idx in self.cover_indexes:
            x = self.data_manipulation_func(x=x.unsqueeze(0)).squeeze(0)

return x, label, self.weights[idx], idx, poison_flag
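A minimal usage sketch of the new dataset class (the MNIST path and the 1% poison/cover rates are just the values used in the example script): each item now carries five fields, which a training loop unpacks as sample, label, per-sample weight, original index, and poison flag.

```python
import torch, torchvision
from torch.utils.data import DataLoader
from secmlt.adv.poisoning.backdoor import BackdoorDatasetPyTorchWithCoverSample

def apply_patch(x: torch.Tensor) -> torch.Tensor:
    x[:, 0, 24:28, 24:28] = 1.0  # white 4x4 trigger in the bottom-right corner
    return x

mnist = torchvision.datasets.MNIST(
    root="example_data/datasets/", train=True, download=True,
    transform=torchvision.transforms.ToTensor(),
)
ds = BackdoorDatasetPyTorchWithCoverSample(
    mnist, data_manipulation_func=apply_patch,
    trigger_label=1, portion=0.01, cover_portion=0.01,
)
x, y, w, idx, poisoned = next(iter(DataLoader(ds, batch_size=8, shuffle=True)))
print(x.shape, y, w, idx, poisoned)  # weights start at 1.0; poisoned flags mark relabeled samples
```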