Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions mart/attack/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,14 @@

import torch

__all__ = ["Composer"]

class Composer(abc.ABC):
def __call__(
self,
perturbation: torch.Tensor | tuple,
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
**kwargs,
) -> torch.Tensor | tuple:
if isinstance(perturbation, tuple):
input_adv = tuple(
self.compose(perturbation_i, input=input_i, target=target_i)
for perturbation_i, input_i, target_i in zip(perturbation, input, target)
)
else:
input_adv = self.compose(perturbation, input=input, target=target)

return input_adv
class Method(abc.ABC):
"""Composition method base class."""

@abc.abstractmethod
def compose(
def __call__(
self,
perturbation: torch.Tensor,
*,
Expand All @@ -42,17 +28,17 @@ def compose(
raise NotImplementedError


class Additive(Composer):
class Additive(Method):
"""We assume an adversary adds perturbation to the input."""

def compose(self, perturbation, *, input, target):
def __call__(self, perturbation, *, input, target):
return input + perturbation


class Overlay(Composer):
class Overlay(Method):
"""We assume an adversary overlays a patch to the input."""

def compose(self, perturbation, *, input, target):
def __call__(self, perturbation, *, input, target):
# True is mutable, False is immutable.
mask = target["perturbable_mask"]

Expand All @@ -61,3 +47,52 @@ def compose(self, perturbation, *, input, target):
mask = mask.to(input)

return input * (1 - mask) + perturbation * mask


class Composer:
"""A modality-aware composer.

Example usage: `Composer(rgb=Overlay(), depth=Additive())`. Non-modality composer can be
defined as `Composer(method=Additive())`.
"""

def __init__(self, **modality_methods):
self.modality_methods = modality_methods

def __call__(
self,
perturbation: torch.Tensor | dict[str, torch.Tensor] | tuple | list,
*,
input: torch.Tensor | dict[str, torch.Tensor] | tuple | list,
target: torch.Tensor | dict[str, Any] | tuple | list,
modality: str = "method",
) -> torch.Tensor:
"""Recursively compose output from perturbation and input."""
assert type(perturbation) == type(input)

if isinstance(perturbation, torch.Tensor):
# Finally we can compose output with tensors.
method = self.modality_methods[modality]
output = method(perturbation, input=input, target=target)
return output
elif isinstance(perturbation, dict):
# The dict input has modalities specified in keys, passing them recursively.
# For non-modality input that does not have dict, modality="method" by default.
output = {}
for modality in perturbation.keys():
output[modality] = self(
perturbation[modality], input=input[modality], target=target, modality=modality
)
return output
elif isinstance(perturbation, (list, tuple)):
# We assume a modality-dictionary only contains tensors, but not list/tuple.
assert modality == "method"
# The list or tuple input is a collection of sub-input and sub-target.
output = []
for pert_i, input_i, target_i in zip(perturbation, input, target):
output.append(self(pert_i, input=input_i, target=target_i, modality=modality))
if isinstance(perturbation, tuple):
output = tuple(output)
return output
else:
raise ValueError(f"Unsupported data type of perturbation: {type(perturbation)}.")
4 changes: 3 additions & 1 deletion mart/configs/attack/composer/additive.yaml
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
_target_: mart.attack.composer.Additive
_target_: mart.attack.Composer
method:
_target_: mart.attack.composer.Additive
4 changes: 3 additions & 1 deletion mart/configs/attack/composer/overlay.yaml
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
_target_: mart.attack.composer.Overlay
_target_: mart.attack.Composer
method:
_target_: mart.attack.composer.Overlay
5 changes: 5 additions & 0 deletions mart/configs/attack/composer/rgb_overlay_depth_additive.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_target_: mart.attack.Composer
rgb:
_target_: mart.attack.composer.Overlay
depth:
_target_: mart.attack.composer.Additive
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ defaults:
- callbacks: [progress_bar, image_visualizer]
- objective: zero_ap
- gain: rcnn_training_loss
- composer: overlay
- composer: rgb_overlay_depth_additive
- enforcer: default
- enforcer/constraints@enforcer.rgb: [mask, pixel_range]

Expand Down
79 changes: 72 additions & 7 deletions tests/test_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,87 @@

import torch

from mart.attack.composer import Additive, Overlay
from mart.attack.composer import Additive, Composer, Overlay


def test_additive_composer_forward(input_data, target_data, perturbation):
composer = Additive()
def test_additive_method_forward(input_data, target_data, perturbation):
method = Additive()

output = composer(perturbation, input=input_data, target=target_data)
output = method(perturbation, input=input_data, target=target_data)
expected_output = input_data + perturbation
torch.testing.assert_close(output, expected_output, equal_nan=True)


def test_overlay_composer_forward(input_data, target_data, perturbation):
composer = Overlay()
def test_overlay_method_forward(input_data, target_data, perturbation):
method = Overlay()

output = composer(perturbation, input=input_data, target=target_data)
output = method(perturbation, input=input_data, target=target_data)
mask = target_data["perturbable_mask"]
mask = mask.to(input_data)
expected_output = input_data * (1 - mask) + perturbation
torch.testing.assert_close(output, expected_output, equal_nan=True)


def test_non_modality_composer_forward(input_data, target_data, perturbation):
composer = Composer(method=Additive())

# tensor input
pert = perturbation
input = input_data
target = target_data
output = composer(pert, input=input, target=target)
expected_output = input_data + perturbation
torch.testing.assert_close(output, expected_output, equal_nan=True)

# tuple of tensor input
pert = (perturbation,)
input = (input_data,)
target = (target_data,)
output = composer(pert, input=input, target=target)
expected_output = (input_data + perturbation,)
assert type(output) == type(expected_output)
torch.testing.assert_close(output[0], expected_output[0], equal_nan=True)


def test_modality_composer_forward(input_data, target_data, perturbation):
input = {"rgb": input_data, "depth": input_data}
target = target_data
pert = {"rgb": perturbation, "depth": perturbation}

rgb_method = Additive()
depth_method = Overlay()
composer = Composer(rgb=rgb_method, depth=depth_method)

# dict
output = composer(pert, input=input, target=target)
expected_output = {
"rgb": rgb_method(perturbation, input=input_data, target=target_data),
"depth": depth_method(perturbation, input=input_data, target=target_data),
}
assert type(output) == type(expected_output)
torch.testing.assert_close(output["rgb"], expected_output["rgb"], equal_nan=True)
torch.testing.assert_close(output["depth"], expected_output["depth"], equal_nan=True)

# list of dict
output = composer([pert], input=[input], target=[target])
expected_output = [
{
"rgb": rgb_method(perturbation, input=input_data, target=target_data),
"depth": depth_method(perturbation, input=input_data, target=target_data),
}
]
assert type(output) == type(expected_output)
torch.testing.assert_close(output[0]["rgb"], expected_output[0]["rgb"], equal_nan=True)
torch.testing.assert_close(output[0]["depth"], expected_output[0]["depth"], equal_nan=True)

# tuple of dict
output = composer((pert,), input=(input,), target=(target,))
expected_output = (
{
"rgb": rgb_method(perturbation, input=input_data, target=target_data),
"depth": depth_method(perturbation, input=input_data, target=target_data),
},
)
assert type(output) == type(expected_output)
torch.testing.assert_close(output[0]["rgb"], expected_output[0]["rgb"], equal_nan=True)
torch.testing.assert_close(output[0]["depth"], expected_output[0]["depth"], equal_nan=True)