diff --git a/mart/attack/composer.py b/mart/attack/composer.py index ddfdc45b..7661b9f3 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -11,28 +11,14 @@ import torch +__all__ = ["Composer"] -class Composer(abc.ABC): - def __call__( - self, - perturbation: torch.Tensor | tuple, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - **kwargs, - ) -> torch.Tensor | tuple: - if isinstance(perturbation, tuple): - input_adv = tuple( - self.compose(perturbation_i, input=input_i, target=target_i) - for perturbation_i, input_i, target_i in zip(perturbation, input, target) - ) - else: - input_adv = self.compose(perturbation, input=input, target=target) - return input_adv +class Method(abc.ABC): + """Composition method base class.""" @abc.abstractmethod - def compose( + def __call__( self, perturbation: torch.Tensor, *, @@ -42,17 +28,17 @@ def compose( raise NotImplementedError -class Additive(Composer): +class Additive(Method): """We assume an adversary adds perturbation to the input.""" - def compose(self, perturbation, *, input, target): + def __call__(self, perturbation, *, input, target): return input + perturbation -class Overlay(Composer): +class Overlay(Method): """We assume an adversary overlays a patch to the input.""" - def compose(self, perturbation, *, input, target): + def __call__(self, perturbation, *, input, target): # True is mutable, False is immutable. mask = target["perturbable_mask"] @@ -61,3 +47,52 @@ def compose(self, perturbation, *, input, target): mask = mask.to(input) return input * (1 - mask) + perturbation * mask + + +class Composer: + """A modality-aware composer. + + Example usage: `Composer(rgb=Overlay(), depth=Additive())`. Non-modality composer can be + defined as `Composer(method=Additive())`. + """ + + def __init__(self, **modality_methods): + self.modality_methods = modality_methods + + def __call__( + self, + perturbation: torch.Tensor | dict[str, torch.Tensor] | tuple | list, + *, + input: torch.Tensor | dict[str, torch.Tensor] | tuple | list, + target: torch.Tensor | dict[str, Any] | tuple | list, + modality: str = "method", + ) -> torch.Tensor: + """Recursively compose output from perturbation and input.""" + assert type(perturbation) == type(input) + + if isinstance(perturbation, torch.Tensor): + # Finally we can compose output with tensors. + method = self.modality_methods[modality] + output = method(perturbation, input=input, target=target) + return output + elif isinstance(perturbation, dict): + # The dict input has modalities specified in keys, passing them recursively. + # For non-modality input that does not have dict, modality="method" by default. + output = {} + for modality in perturbation.keys(): + output[modality] = self( + perturbation[modality], input=input[modality], target=target, modality=modality + ) + return output + elif isinstance(perturbation, (list, tuple)): + # We assume a modality-dictionary only contains tensors, but not list/tuple. + assert modality == "method" + # The list or tuple input is a collection of sub-input and sub-target. + output = [] + for pert_i, input_i, target_i in zip(perturbation, input, target): + output.append(self(pert_i, input=input_i, target=target_i, modality=modality)) + if isinstance(perturbation, tuple): + output = tuple(output) + return output + else: + raise ValueError(f"Unsupported data type of perturbation: {type(perturbation)}.") diff --git a/mart/configs/attack/composer/additive.yaml b/mart/configs/attack/composer/additive.yaml index c12f939e..ce31089a 100644 --- a/mart/configs/attack/composer/additive.yaml +++ b/mart/configs/attack/composer/additive.yaml @@ -1 +1,3 @@ -_target_: mart.attack.composer.Additive +_target_: mart.attack.Composer +method: + _target_: mart.attack.composer.Additive diff --git a/mart/configs/attack/composer/overlay.yaml b/mart/configs/attack/composer/overlay.yaml index 469f7245..fceaafc2 100644 --- a/mart/configs/attack/composer/overlay.yaml +++ b/mart/configs/attack/composer/overlay.yaml @@ -1 +1,3 @@ -_target_: mart.attack.composer.Overlay +_target_: mart.attack.Composer +method: + _target_: mart.attack.composer.Overlay diff --git a/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml b/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml new file mode 100644 index 00000000..50a182a9 --- /dev/null +++ b/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml @@ -0,0 +1,5 @@ +_target_: mart.attack.Composer +rgb: + _target_: mart.attack.composer.Overlay +depth: + _target_: mart.attack.composer.Additive diff --git a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml index a2bb039e..ba281992 100644 --- a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml @@ -7,7 +7,7 @@ defaults: - callbacks: [progress_bar, image_visualizer] - objective: zero_ap - gain: rcnn_training_loss - - composer: overlay + - composer: rgb_overlay_depth_additive - enforcer: default - enforcer/constraints@enforcer.rgb: [mask, pixel_range] diff --git a/tests/test_composer.py b/tests/test_composer.py index 06f215e6..315726d9 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -6,22 +6,87 @@ import torch -from mart.attack.composer import Additive, Overlay +from mart.attack.composer import Additive, Composer, Overlay -def test_additive_composer_forward(input_data, target_data, perturbation): - composer = Additive() +def test_additive_method_forward(input_data, target_data, perturbation): + method = Additive() - output = composer(perturbation, input=input_data, target=target_data) + output = method(perturbation, input=input_data, target=target_data) expected_output = input_data + perturbation torch.testing.assert_close(output, expected_output, equal_nan=True) -def test_overlay_composer_forward(input_data, target_data, perturbation): - composer = Overlay() +def test_overlay_method_forward(input_data, target_data, perturbation): + method = Overlay() - output = composer(perturbation, input=input_data, target=target_data) + output = method(perturbation, input=input_data, target=target_data) mask = target_data["perturbable_mask"] mask = mask.to(input_data) expected_output = input_data * (1 - mask) + perturbation torch.testing.assert_close(output, expected_output, equal_nan=True) + + +def test_non_modality_composer_forward(input_data, target_data, perturbation): + composer = Composer(method=Additive()) + + # tensor input + pert = perturbation + input = input_data + target = target_data + output = composer(pert, input=input, target=target) + expected_output = input_data + perturbation + torch.testing.assert_close(output, expected_output, equal_nan=True) + + # tuple of tensor input + pert = (perturbation,) + input = (input_data,) + target = (target_data,) + output = composer(pert, input=input, target=target) + expected_output = (input_data + perturbation,) + assert type(output) == type(expected_output) + torch.testing.assert_close(output[0], expected_output[0], equal_nan=True) + + +def test_modality_composer_forward(input_data, target_data, perturbation): + input = {"rgb": input_data, "depth": input_data} + target = target_data + pert = {"rgb": perturbation, "depth": perturbation} + + rgb_method = Additive() + depth_method = Overlay() + composer = Composer(rgb=rgb_method, depth=depth_method) + + # dict + output = composer(pert, input=input, target=target) + expected_output = { + "rgb": rgb_method(perturbation, input=input_data, target=target_data), + "depth": depth_method(perturbation, input=input_data, target=target_data), + } + assert type(output) == type(expected_output) + torch.testing.assert_close(output["rgb"], expected_output["rgb"], equal_nan=True) + torch.testing.assert_close(output["depth"], expected_output["depth"], equal_nan=True) + + # list of dict + output = composer([pert], input=[input], target=[target]) + expected_output = [ + { + "rgb": rgb_method(perturbation, input=input_data, target=target_data), + "depth": depth_method(perturbation, input=input_data, target=target_data), + } + ] + assert type(output) == type(expected_output) + torch.testing.assert_close(output[0]["rgb"], expected_output[0]["rgb"], equal_nan=True) + torch.testing.assert_close(output[0]["depth"], expected_output[0]["depth"], equal_nan=True) + + # tuple of dict + output = composer((pert,), input=(input,), target=(target,)) + expected_output = ( + { + "rgb": rgb_method(perturbation, input=input_data, target=target_data), + "depth": depth_method(perturbation, input=input_data, target=target_data), + }, + ) + assert type(output) == type(expected_output) + torch.testing.assert_close(output[0]["rgb"], expected_output[0]["rgb"], equal_nan=True) + torch.testing.assert_close(output[0]["depth"], expected_output[0]["depth"], equal_nan=True)