From 58ad08888d385a1c442e06673591ac7a50b22f2d Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 29 Mar 2023 16:56:49 -0700 Subject: [PATCH 01/11] Add ModalityComposer. --- mart/attack/composer.py | 49 ++++++++++++++++++++++ mart/configs/attack/composer/modality.yaml | 1 + 2 files changed, 50 insertions(+) create mode 100644 mart/configs/attack/composer/modality.yaml diff --git a/mart/attack/composer.py b/mart/attack/composer.py index ddfdc45b..824fbbca 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -11,6 +11,8 @@ import torch +__all__ = ["ModalityComposer"] + class Composer(abc.ABC): def __call__( @@ -61,3 +63,50 @@ def compose(self, perturbation, *, input, target): mask = mask.to(input) return input * (1 - mask) + perturbation * mask + + +class ModalityComposer(Composer): + def __init__(self, **modality_composers): + self.modality_composers = modality_composers + + def __call__( + self, + perturbation: torch.Tensor | tuple, + *, + input: torch.Tensor | tuple, + target: torch.Tensor | dict[str, Any] | tuple, + modality: str = "default", + **kwargs, + ) -> torch.Tensor | tuple: + """Recursively compose output from perturbation and input.""" + if isinstance(perturbation, torch.Tensor): + # Finally we can compose output with tensors. + composer = self.modality_composers[modality] + output = composer(perturbation, input=input, target=target) + return output + elif isinstance(perturbation, dict): + # The dict input has modalities specified in keys, passing them recursively. + output = {} + for modality in perturbation.keys(): + output[modality] = self( + perturbation[modality], input=input[modality], target=target, modality=modality + ) + return output + elif isinstance(perturbation, (list, tuple)): + # The list or tuple input is a collection of sub-input and sub-target. + output = [] + for pert_i, input_i, target_i in zip(perturbation, input, target): + output.append(self(pert_i, input=input_i, target=target_i)) + if isinstance(perturbation, tuple): + output = tuple(output) + return output + + # We have to implement an abstract method... + def compose( + self, + perturbation: torch.Tensor, + *, + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], + ) -> torch.Tensor: + pass diff --git a/mart/configs/attack/composer/modality.yaml b/mart/configs/attack/composer/modality.yaml new file mode 100644 index 00000000..e84f946c --- /dev/null +++ b/mart/configs/attack/composer/modality.yaml @@ -0,0 +1 @@ +_target_: mart.attack.ModalityComposer From 8b54d1955f667e4469213308d7339c8115105267 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 29 Mar 2023 16:57:08 -0700 Subject: [PATCH 02/11] Add test. --- tests/test_composer.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/test_composer.py b/tests/test_composer.py index 06f215e6..bcc75fcf 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -6,7 +6,7 @@ import torch -from mart.attack.composer import Additive, Overlay +from mart.attack.composer import Additive, ModalityComposer, Overlay def test_additive_composer_forward(input_data, target_data, perturbation): @@ -25,3 +25,23 @@ def test_overlay_composer_forward(input_data, target_data, perturbation): mask = mask.to(input_data) expected_output = input_data * (1 - mask) + perturbation torch.testing.assert_close(output, expected_output, equal_nan=True) + + +def test_modality_composer_forward(input_data, target_data, perturbation): + input = [{"rgb": input_data, "depth": input_data}] + target = [target_data] + pert = [{"rgb": perturbation, "depth": perturbation}] + + rgb_composer = Additive() + depth_composer = Overlay() + modality_composer = ModalityComposer(rgb=rgb_composer, depth=depth_composer) + + output = modality_composer(pert, input=input, target=target) + expected_output = [ + { + "rgb": rgb_composer(perturbation, input=input_data, target=target_data), + "depth": depth_composer(perturbation, input=input_data, target=target_data), + } + ] + torch.testing.assert_close(output[0]["rgb"], expected_output[0]["rgb"], equal_nan=True) + torch.testing.assert_close(output[0]["depth"], expected_output[0]["depth"], equal_nan=True) From cb29cdedc026f8c637aa7489de32e329d4b4da16 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 11:01:46 -0700 Subject: [PATCH 03/11] Remove batch-aware. --- mart/attack/composer.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 824fbbca..1a865243 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -17,19 +17,13 @@ class Composer(abc.ABC): def __call__( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor, *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], **kwargs, - ) -> torch.Tensor | tuple: - if isinstance(perturbation, tuple): - input_adv = tuple( - self.compose(perturbation_i, input=input_i, target=target_i) - for perturbation_i, input_i, target_i in zip(perturbation, input, target) - ) - else: - input_adv = self.compose(perturbation, input=input, target=target) + ) -> torch.Tensor: + input_adv = self.compose(perturbation, input=input, target=target) return input_adv From 3a0fb5693005a83641ead6e3c783b10a92d41667 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 11:02:09 -0700 Subject: [PATCH 04/11] Revert "Remove batch-aware." This reverts commit cb29cdedc026f8c637aa7489de32e329d4b4da16. --- mart/attack/composer.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 1a865243..824fbbca 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -17,13 +17,19 @@ class Composer(abc.ABC): def __call__( self, - perturbation: torch.Tensor, + perturbation: torch.Tensor | tuple, *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], + input: torch.Tensor | tuple, + target: torch.Tensor | dict[str, Any] | tuple, **kwargs, - ) -> torch.Tensor: - input_adv = self.compose(perturbation, input=input, target=target) + ) -> torch.Tensor | tuple: + if isinstance(perturbation, tuple): + input_adv = tuple( + self.compose(perturbation_i, input=input_i, target=target_i) + for perturbation_i, input_i, target_i in zip(perturbation, input, target) + ) + else: + input_adv = self.compose(perturbation, input=input, target=target) return input_adv From b8e423483a9b237ea7c242fe86c3ba78c9d5b5db Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 11:13:45 -0700 Subject: [PATCH 05/11] Add tuple and dict test. --- tests/test_composer.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/test_composer.py b/tests/test_composer.py index bcc75fcf..4576a498 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -28,20 +28,44 @@ def test_overlay_composer_forward(input_data, target_data, perturbation): def test_modality_composer_forward(input_data, target_data, perturbation): - input = [{"rgb": input_data, "depth": input_data}] - target = [target_data] - pert = [{"rgb": perturbation, "depth": perturbation}] + input = {"rgb": input_data, "depth": input_data} + target = target_data + pert = {"rgb": perturbation, "depth": perturbation} rgb_composer = Additive() depth_composer = Overlay() modality_composer = ModalityComposer(rgb=rgb_composer, depth=depth_composer) + # dict output = modality_composer(pert, input=input, target=target) + expected_output = { + "rgb": rgb_composer(perturbation, input=input_data, target=target_data), + "depth": depth_composer(perturbation, input=input_data, target=target_data), + } + assert type(output) == type(expected_output) + torch.testing.assert_close(output["rgb"], expected_output["rgb"], equal_nan=True) + torch.testing.assert_close(output["depth"], expected_output["depth"], equal_nan=True) + + # list of dict + output = modality_composer([pert], input=[input], target=[target]) expected_output = [ { "rgb": rgb_composer(perturbation, input=input_data, target=target_data), "depth": depth_composer(perturbation, input=input_data, target=target_data), } ] + assert type(output) == type(expected_output) + torch.testing.assert_close(output[0]["rgb"], expected_output[0]["rgb"], equal_nan=True) + torch.testing.assert_close(output[0]["depth"], expected_output[0]["depth"], equal_nan=True) + + # tuple of dict + output = modality_composer((pert,), input=(input,), target=(target,)) + expected_output = ( + { + "rgb": rgb_composer(perturbation, input=input_data, target=target_data), + "depth": depth_composer(perturbation, input=input_data, target=target_data), + }, + ) + assert type(output) == type(expected_output) torch.testing.assert_close(output[0]["rgb"], expected_output[0]["rgb"], equal_nan=True) torch.testing.assert_close(output[0]["depth"], expected_output[0]["depth"], equal_nan=True) From 0d401536eb149eaca4d6787e8d183d0b2114ef18 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 11:16:41 -0700 Subject: [PATCH 06/11] Improve implementation of recursion. --- mart/attack/composer.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 824fbbca..27c115de 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -75,10 +75,23 @@ def __call__( *, input: torch.Tensor | tuple, target: torch.Tensor | dict[str, Any] | tuple, - modality: str = "default", **kwargs, ) -> torch.Tensor | tuple: + # Get rid of batch-aware in Composer.__call__(), because we have the recursive self.compose(). + input_adv = self.compose(perturbation, input=input, target=target) + return input_adv + + def compose( + self, + perturbation: torch.Tensor | dict[str, torch.Tensor] | tuple | list, + *, + input: torch.Tensor | dict[str, torch.Tensor] | tuple | list, + target: torch.Tensor | dict[str, Any] | tuple | list, + modality: str = "default", + ) -> torch.Tensor: """Recursively compose output from perturbation and input.""" + assert type(perturbation) == type(input) + if isinstance(perturbation, torch.Tensor): # Finally we can compose output with tensors. composer = self.modality_composers[modality] @@ -88,7 +101,7 @@ def __call__( # The dict input has modalities specified in keys, passing them recursively. output = {} for modality in perturbation.keys(): - output[modality] = self( + output[modality] = self.compose( perturbation[modality], input=input[modality], target=target, modality=modality ) return output @@ -96,17 +109,9 @@ def __call__( # The list or tuple input is a collection of sub-input and sub-target. output = [] for pert_i, input_i, target_i in zip(perturbation, input, target): - output.append(self(pert_i, input=input_i, target=target_i)) + output.append(self.compose(pert_i, input=input_i, target=target_i)) if isinstance(perturbation, tuple): output = tuple(output) return output - - # We have to implement an abstract method... - def compose( - self, - perturbation: torch.Tensor, - *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], - ) -> torch.Tensor: - pass + else: + raise ValueError(f"Unsupported data type of perturbation: {type(perturbation)}.") From 119a924ecb038aeb85edd13175734b2e1e4ed253 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 11:30:06 -0700 Subject: [PATCH 07/11] Complete __all__. --- mart/attack/composer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 27c115de..d842e418 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -11,7 +11,7 @@ import torch -__all__ = ["ModalityComposer"] +__all__ = ["Additive", "Overlay", "ModalityComposer"] class Composer(abc.ABC): From 261a18a1d958591a8d0a5a81ff1cef00e89739b9 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 11:45:28 -0700 Subject: [PATCH 08/11] Pass along modaity in tuple/list calls. --- mart/attack/composer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mart/attack/composer.py b/mart/attack/composer.py index d842e418..60053773 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -109,7 +109,9 @@ def compose( # The list or tuple input is a collection of sub-input and sub-target. output = [] for pert_i, input_i, target_i in zip(perturbation, input, target): - output.append(self.compose(pert_i, input=input_i, target=target_i)) + output.append( + self.compose(pert_i, input=input_i, target=target_i, modality=modality) + ) if isinstance(perturbation, tuple): output = tuple(output) return output From 6f3d2b5cc58f982229258a1c9fe1b1fc1d4c1ace Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 11:52:08 -0700 Subject: [PATCH 09/11] Comment. --- mart/attack/composer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 60053773..04f2ee44 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -66,6 +66,12 @@ def compose(self, perturbation, *, input, target): class ModalityComposer(Composer): + """A modality-aware composer. + + Example usage: `ModalityComposer(rgb=Overlay(), depth=Additive())`. Note that + `ModalityComposer(default=Additive())` is equivalent with `Additive()`. + """ + def __init__(self, **modality_composers): self.modality_composers = modality_composers @@ -77,7 +83,7 @@ def __call__( target: torch.Tensor | dict[str, Any] | tuple, **kwargs, ) -> torch.Tensor | tuple: - # Get rid of batch-aware in Composer.__call__(), because we have the recursive self.compose(). + # Bypass batch-aware in Composer.__call__(), because we have the recursive self.compose(). input_adv = self.compose(perturbation, input=input, target=target) return input_adv From ca68bfb910bedbf43a2b6f4e4322009b659c2c2f Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 12:45:10 -0700 Subject: [PATCH 10/11] Add example config. --- mart/configs/attack/composer/rgb_overlay_depth_additive.yaml | 5 +++++ mart/configs/attack/object_detection_rgb_mask_adversary.yaml | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 mart/configs/attack/composer/rgb_overlay_depth_additive.yaml diff --git a/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml b/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml new file mode 100644 index 00000000..84cc6ac4 --- /dev/null +++ b/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml @@ -0,0 +1,5 @@ +defaults: + - .@rgb: overlay + - .@depth: additive + +_target_: mart.attack.ModalityComposer diff --git a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml index a2bb039e..ba281992 100644 --- a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml @@ -7,7 +7,7 @@ defaults: - callbacks: [progress_bar, image_visualizer] - objective: zero_ap - gain: rcnn_training_loss - - composer: overlay + - composer: rgb_overlay_depth_additive - enforcer: default - enforcer/constraints@enforcer.rgb: [mask, pixel_range] From 6e530dd1d6414045b3c332d6f4bdc20ad036bbde Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 30 Mar 2023 15:46:41 -0700 Subject: [PATCH 11/11] Make a modality-aware Composer, just like Enforcer. --- mart/attack/composer.py | 69 ++++++------------- mart/configs/attack/composer/additive.yaml | 4 +- mart/configs/attack/composer/modality.yaml | 1 - mart/configs/attack/composer/overlay.yaml | 4 +- .../composer/rgb_overlay_depth_additive.yaml | 10 +-- tests/test_composer.py | 59 +++++++++++----- 6 files changed, 72 insertions(+), 75 deletions(-) delete mode 100644 mart/configs/attack/composer/modality.yaml diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 04f2ee44..7661b9f3 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -11,30 +11,14 @@ import torch -__all__ = ["Additive", "Overlay", "ModalityComposer"] +__all__ = ["Composer"] -class Composer(abc.ABC): - def __call__( - self, - perturbation: torch.Tensor | tuple, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - **kwargs, - ) -> torch.Tensor | tuple: - if isinstance(perturbation, tuple): - input_adv = tuple( - self.compose(perturbation_i, input=input_i, target=target_i) - for perturbation_i, input_i, target_i in zip(perturbation, input, target) - ) - else: - input_adv = self.compose(perturbation, input=input, target=target) - - return input_adv +class Method(abc.ABC): + """Composition method base class.""" @abc.abstractmethod - def compose( + def __call__( self, perturbation: torch.Tensor, *, @@ -44,17 +28,17 @@ def compose( raise NotImplementedError -class Additive(Composer): +class Additive(Method): """We assume an adversary adds perturbation to the input.""" - def compose(self, perturbation, *, input, target): + def __call__(self, perturbation, *, input, target): return input + perturbation -class Overlay(Composer): +class Overlay(Method): """We assume an adversary overlays a patch to the input.""" - def compose(self, perturbation, *, input, target): + def __call__(self, perturbation, *, input, target): # True is mutable, False is immutable. mask = target["perturbable_mask"] @@ -65,59 +49,48 @@ def compose(self, perturbation, *, input, target): return input * (1 - mask) + perturbation * mask -class ModalityComposer(Composer): +class Composer: """A modality-aware composer. - Example usage: `ModalityComposer(rgb=Overlay(), depth=Additive())`. Note that - `ModalityComposer(default=Additive())` is equivalent with `Additive()`. + Example usage: `Composer(rgb=Overlay(), depth=Additive())`. Non-modality composer can be + defined as `Composer(method=Additive())`. """ - def __init__(self, **modality_composers): - self.modality_composers = modality_composers + def __init__(self, **modality_methods): + self.modality_methods = modality_methods def __call__( - self, - perturbation: torch.Tensor | tuple, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - **kwargs, - ) -> torch.Tensor | tuple: - # Bypass batch-aware in Composer.__call__(), because we have the recursive self.compose(). - input_adv = self.compose(perturbation, input=input, target=target) - return input_adv - - def compose( self, perturbation: torch.Tensor | dict[str, torch.Tensor] | tuple | list, *, input: torch.Tensor | dict[str, torch.Tensor] | tuple | list, target: torch.Tensor | dict[str, Any] | tuple | list, - modality: str = "default", + modality: str = "method", ) -> torch.Tensor: """Recursively compose output from perturbation and input.""" assert type(perturbation) == type(input) if isinstance(perturbation, torch.Tensor): # Finally we can compose output with tensors. - composer = self.modality_composers[modality] - output = composer(perturbation, input=input, target=target) + method = self.modality_methods[modality] + output = method(perturbation, input=input, target=target) return output elif isinstance(perturbation, dict): # The dict input has modalities specified in keys, passing them recursively. + # For non-modality input that does not have dict, modality="method" by default. output = {} for modality in perturbation.keys(): - output[modality] = self.compose( + output[modality] = self( perturbation[modality], input=input[modality], target=target, modality=modality ) return output elif isinstance(perturbation, (list, tuple)): + # We assume a modality-dictionary only contains tensors, but not list/tuple. + assert modality == "method" # The list or tuple input is a collection of sub-input and sub-target. output = [] for pert_i, input_i, target_i in zip(perturbation, input, target): - output.append( - self.compose(pert_i, input=input_i, target=target_i, modality=modality) - ) + output.append(self(pert_i, input=input_i, target=target_i, modality=modality)) if isinstance(perturbation, tuple): output = tuple(output) return output diff --git a/mart/configs/attack/composer/additive.yaml b/mart/configs/attack/composer/additive.yaml index c12f939e..ce31089a 100644 --- a/mart/configs/attack/composer/additive.yaml +++ b/mart/configs/attack/composer/additive.yaml @@ -1 +1,3 @@ -_target_: mart.attack.composer.Additive +_target_: mart.attack.Composer +method: + _target_: mart.attack.composer.Additive diff --git a/mart/configs/attack/composer/modality.yaml b/mart/configs/attack/composer/modality.yaml deleted file mode 100644 index e84f946c..00000000 --- a/mart/configs/attack/composer/modality.yaml +++ /dev/null @@ -1 +0,0 @@ -_target_: mart.attack.ModalityComposer diff --git a/mart/configs/attack/composer/overlay.yaml b/mart/configs/attack/composer/overlay.yaml index 469f7245..fceaafc2 100644 --- a/mart/configs/attack/composer/overlay.yaml +++ b/mart/configs/attack/composer/overlay.yaml @@ -1 +1,3 @@ -_target_: mart.attack.composer.Overlay +_target_: mart.attack.Composer +method: + _target_: mart.attack.composer.Overlay diff --git a/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml b/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml index 84cc6ac4..50a182a9 100644 --- a/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml +++ b/mart/configs/attack/composer/rgb_overlay_depth_additive.yaml @@ -1,5 +1,5 @@ -defaults: - - .@rgb: overlay - - .@depth: additive - -_target_: mart.attack.ModalityComposer +_target_: mart.attack.Composer +rgb: + _target_: mart.attack.composer.Overlay +depth: + _target_: mart.attack.composer.Additive diff --git a/tests/test_composer.py b/tests/test_composer.py index 4576a498..315726d9 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -6,52 +6,73 @@ import torch -from mart.attack.composer import Additive, ModalityComposer, Overlay +from mart.attack.composer import Additive, Composer, Overlay -def test_additive_composer_forward(input_data, target_data, perturbation): - composer = Additive() +def test_additive_method_forward(input_data, target_data, perturbation): + method = Additive() - output = composer(perturbation, input=input_data, target=target_data) + output = method(perturbation, input=input_data, target=target_data) expected_output = input_data + perturbation torch.testing.assert_close(output, expected_output, equal_nan=True) -def test_overlay_composer_forward(input_data, target_data, perturbation): - composer = Overlay() +def test_overlay_method_forward(input_data, target_data, perturbation): + method = Overlay() - output = composer(perturbation, input=input_data, target=target_data) + output = method(perturbation, input=input_data, target=target_data) mask = target_data["perturbable_mask"] mask = mask.to(input_data) expected_output = input_data * (1 - mask) + perturbation torch.testing.assert_close(output, expected_output, equal_nan=True) +def test_non_modality_composer_forward(input_data, target_data, perturbation): + composer = Composer(method=Additive()) + + # tensor input + pert = perturbation + input = input_data + target = target_data + output = composer(pert, input=input, target=target) + expected_output = input_data + perturbation + torch.testing.assert_close(output, expected_output, equal_nan=True) + + # tuple of tensor input + pert = (perturbation,) + input = (input_data,) + target = (target_data,) + output = composer(pert, input=input, target=target) + expected_output = (input_data + perturbation,) + assert type(output) == type(expected_output) + torch.testing.assert_close(output[0], expected_output[0], equal_nan=True) + + def test_modality_composer_forward(input_data, target_data, perturbation): input = {"rgb": input_data, "depth": input_data} target = target_data pert = {"rgb": perturbation, "depth": perturbation} - rgb_composer = Additive() - depth_composer = Overlay() - modality_composer = ModalityComposer(rgb=rgb_composer, depth=depth_composer) + rgb_method = Additive() + depth_method = Overlay() + composer = Composer(rgb=rgb_method, depth=depth_method) # dict - output = modality_composer(pert, input=input, target=target) + output = composer(pert, input=input, target=target) expected_output = { - "rgb": rgb_composer(perturbation, input=input_data, target=target_data), - "depth": depth_composer(perturbation, input=input_data, target=target_data), + "rgb": rgb_method(perturbation, input=input_data, target=target_data), + "depth": depth_method(perturbation, input=input_data, target=target_data), } assert type(output) == type(expected_output) torch.testing.assert_close(output["rgb"], expected_output["rgb"], equal_nan=True) torch.testing.assert_close(output["depth"], expected_output["depth"], equal_nan=True) # list of dict - output = modality_composer([pert], input=[input], target=[target]) + output = composer([pert], input=[input], target=[target]) expected_output = [ { - "rgb": rgb_composer(perturbation, input=input_data, target=target_data), - "depth": depth_composer(perturbation, input=input_data, target=target_data), + "rgb": rgb_method(perturbation, input=input_data, target=target_data), + "depth": depth_method(perturbation, input=input_data, target=target_data), } ] assert type(output) == type(expected_output) @@ -59,11 +80,11 @@ def test_modality_composer_forward(input_data, target_data, perturbation): torch.testing.assert_close(output[0]["depth"], expected_output[0]["depth"], equal_nan=True) # tuple of dict - output = modality_composer((pert,), input=(input,), target=(target,)) + output = composer((pert,), input=(input,), target=(target,)) expected_output = ( { - "rgb": rgb_composer(perturbation, input=input_data, target=target_data), - "depth": depth_composer(perturbation, input=input_data, target=target_data), + "rgb": rgb_method(perturbation, input=input_data, target=target_data), + "depth": depth_method(perturbation, input=input_data, target=target_data), }, ) assert type(output) == type(expected_output)