diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index a34537a44..1656cdf2b 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -114,7 +114,7 @@ def attribute( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = False, + enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -744,7 +744,7 @@ def attribute_future( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = False, + enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> Future[TensorOrTupleOfTensorsGeneric]: r""" diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index b9630ab73..238f6958e 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -115,7 +115,7 @@ def attribute( # type: ignore feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = False, + enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -304,7 +304,7 @@ def attribute_future( feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = False, + enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> Future[TensorOrTupleOfTensorsGeneric]: """ diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index ccd752e64..3e5b9c2b6 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -267,7 +267,6 @@ def attribute( # type: ignore shift_counts=tuple(shift_counts), strides=strides, show_progress=show_progress, - enable_cross_tensor_attribution=True, ) def attribute_future(self) -> None: diff --git a/tests/attr/test_feature_permutation.py b/tests/attr/test_feature_permutation.py index 12f8d1739..329fea812 100644 --- a/tests/attr/test_feature_permutation.py +++ b/tests/attr/test_feature_permutation.py @@ -120,14 +120,14 @@ def forward_func(x: Tensor) -> Tensor: inp = torch.tensor([[1.0, 2.0]]) assertTensorAlmostEqual( self, - feature_importance.attribute(inp), + feature_importance.attribute(inp, enable_cross_tensor_attribution=False), torch.tensor([[0.0, 0.0]]), delta=0.0, ) feature_importance._min_examples_per_batch = 1 with self.assertRaises(AssertionError): - feature_importance.attribute(inp) + feature_importance.attribute(inp, enable_cross_tensor_attribution=False) def test_simple_input_with_min_examples_in_group(self) -> None: def forward_func(x: Tensor) -> Tensor: