From 08c999225760e2cc38b59ef345707fe4d680e1d2 Mon Sep 17 00:00:00 2001 From: stav-af Date: Tue, 17 Jun 2025 16:15:38 +0200 Subject: [PATCH 01/21] Implement ReX --- captum/attr/_core/rex.py | 272 +++++++++++++++++++++++++++++++++++++++ tests/attr/test_rex.py | 77 +++++++++++ 2 files changed, 349 insertions(+) create mode 100644 captum/attr/_core/rex.py create mode 100644 tests/attr/test_rex.py diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py new file mode 100644 index 000000000..7046331ed --- /dev/null +++ b/captum/attr/_core/rex.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 + +# pyre-strict +import itertools +from typing import List +import torch +import math, heapq +from collections import deque +import random +from dataclasses import dataclass + +from captum.attr._utils.attribution import PerturbationAttribution +from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.common import _format_input_baseline, _validate_input +from captum.log.dummy_log import log_usage + +@dataclass +class Mutant: + partitions: List[List[int]] + data: List[int] + +def _occlude_data(data: torch.Tensor, partitions: List[List[int]], neutral) -> torch.Tensor: + mask = torch.ones_like(data, dtype=torch.bool) + for part in partitions: + mask[part] = False + + return torch.where(mask, data, neutral) + +def _powerset(iterable): + s = list(iterable) + return [list(combo) for r in range(len(s)+1) + for combo in itertools.combinations(s, r)] + + +def _partitions_combinations(partitions): + return list(filter( + lambda x: len(x) <= len(partitions), + _powerset(partitions))) + + +def _apply_responsibility(feature_importance, part, responsibility): + distributed_resp = responsibility / len(part) + for idx in part: + feature_importance[idx] = distributed_resp + + return feature_importance + +def _partitions_equal(partitions_a, partitions_b): + set_a = {tuple(p) if isinstance(p, list) else p for p in partitions_a} + set_b = {tuple(p) if isinstance(p, list) else p for p in partitions_b} + + return set_a == set_b + + +def _responsibility(subject_partition: List, consistent_set: List[List[int]]) -> float: + witnesses = [mut.partitions for mut in consistent_set if subject_partition not in mut.partitions] + + # a witness is valid if perturbing it results in a counterfactual + # dependence on the subject partition + + # hashset! + valid_witnesses = [] + for witness in witnesses: + counterfactual = [subject_partition] + witness + if not any(_partitions_equal(counterfactual, cst.partitions) for cst in consistent_set): + valid_witnesses.append(witness) + + if len(valid_witnesses) == 0: + return 0.0 + + min_mutant = min(valid_witnesses, key=len) + minpart = len(min_mutant) + + return 1.0 / (1.0 + float(minpart)) + + +class ReX(PerturbationAttribution): + """ + A perturbation-based approach to computing attribution, based on the + Halpern-Pearl definition of actual causality[1]. + + The approach works by + partitioning the input space, and masking each partition. Intuitively, if masking a + partition changes the prediction of the model, then that partition has + some responsibility (attribution > 0). Such partially masked partitions are called + mutants. The responsibility of a subject partition is defined as 1/(1+k) where + k is a minimum number of occluded partitions in a mutant which make forward_func's + output dependednt on the subject partition. + + Partitions with nonzero responsibility are recusrively re-partitioned and masked in a search. 
+ The algorithm runs multiple such searches, where each subsequent search uses the previously + computed attribution map as a heuristic for partitioning. + + + [1] - halpern 06 + [2] - rex paper + """ + def __init__(self, forward_func): + r""" + Args: + forward_func (Callable): The function to be explained. Must return + a scalar for which the equality operator is defined. + """ + PerturbationAttribution.__init__(self, forward_func) + + @log_usage(part_of_slo=True) + def attribute(self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = 0, + *, + search_depth: int = 10, + n_partitions: int = 8, + n_searches: int = 5) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + inputs: + An input or tuple of inputs whose corresponding output is to be explained. Each input + must be of the shape expected by the forward_func. Where multiple examples are + provided, they must be listed in a tuple. + + baselines: + A neutral values to be used as occlusion values. Where a scalar is provided, it is used + as the masking value at each index. Where a tensor is provided, values are masked at + corresponding indices. Where a tuple of tensors is provided, it must be of the same length + as inputs; then baseline and input tensors are matched element-wise and treated as before. + + search_depth (optional): + The maximum depth to which ReX will search. Where one is not provided, the default is 4 + + n_partitions (optional): + The number of partitions to be made out of the input at each search step. + This must be at most hte size of each input, and at least 1. + """ + inputs, baselines = _format_input_baseline(inputs, baselines) + _validate_input(inputs, baselines) + + self._n_partitions = n_partitions + self._search_depth = search_depth + self._n_searches = n_searches + + is_tuple_inputs = isinstance(inputs, tuple) + is_tuple_baselines = isinstance(baselines, tuple) + + attributions = [] + + # match inputs and baselines, explain + if is_tuple_inputs and is_tuple_baselines: + for input, baseline in zip(inputs, baselines): + attributions.append(self._explain(input, baseline)) + elif is_tuple_inputs and not is_tuple_baselines: + for input in inputs: + attributions.append(self._explain(input, baselines)) + else: + attributions.append(self._explain(inputs, baselines)) + + return tuple(attributions) + + def _flatten(self, data): + self._original_shape = data.size() + return data.reshape(-1) + + def _unflatten(self, data): + return data.reshape(self._original_shape) + + + def _fast_partition(self, responsibility: List[int], choices: List[int]) -> List[List[int]]: + population = choices.copy() + random.shuffle(population) + + weights = [responsibility[i] for i in population] + if sum(weights) == 0: weights = [1/len(population) for _ in population] + + remaining_weight = sum(weights) + target_weight = remaining_weight / self._n_partitions + + zip_sorted = zip(weights, population) + weight_sorted, pop_sorted = zip(*sorted(zip_sorted, key=lambda x: x[0], reverse=True)) + + partitions, lp, rp, cumsum = [], 0, 0, 0 + while rp < len(weight_sorted): + cumsum += weight_sorted[rp] + remaining_weight -= weight_sorted[rp] + if cumsum >= target_weight or rp == len(weight_sorted) - 1: + partitions.append(list(pop_sorted[lp:rp + 1])) + lp = rp + 1 + cumsum = 0 + + rp += 1 + + return partitions + + + def _partition(self, responsibility: List[float], choices: List[int]) -> List[List[int]]: + population = choices.copy() + random.shuffle(population) + + weights = [responsibility[i] for i in population] + if sum(weights) 
== 0: weights = [1 for _ in choices] + + target_weight = sum(weights) / self._n_partitions + partitions = [] + + curr_weight = 0.0 + curr_partition = [] + + while population: + choice = random.choices(population, weights, k=1)[0] + idx = population.index(choice) + + population.pop(idx) + + weights = [responsibility[i] for i in population] + if sum(weights) == 0: weights = [1 for _ in population] + + curr_partition.append(choice) + curr_weight += responsibility[choice] + + if curr_weight > target_weight: + partitions.append(curr_partition) + curr_partition, curr_weight = [], 0.0 + + if curr_partition: + partitions.append(curr_partition) + + return partitions + + + + def _explain(self, input, baseline): + initial_prediction = self.forward_func(input) + flattened_input = self._flatten(input) + + n_features = flattened_input.numel() + feature_attribution = [0.0 for _ in range(n_features)] + + for _ in range(self._n_searches): + # by definition, root partition contains all indices + part_q = deque() + part_q.append((list(range(n_features)), 1.0, 0)) + + while part_q: + indices, parent_resp, depth = part_q.popleft() + partitions = self._fast_partition(feature_attribution, indices) + + mutants = [] + for idx_combination in _partitions_combinations(partitions): + occluded_data = _occlude_data(flattened_input, idx_combination, baseline) + mut = Mutant(partitions=idx_combination, data=occluded_data) + mutants.append(mut) + + cst_set = list(filter( + lambda mut: self.forward_func(self._unflatten(mut.data)) == initial_prediction, + mutants + )) + + for part in partitions: + resp = _responsibility(part, cst_set) + # update should be conditional on entropy? + feature_attribution = _apply_responsibility(feature_attribution, part, resp * parent_resp) + if resp > 0 and \ + len(part) > 1 and \ + depth < self._search_depth: + part_q.append((part, resp, depth + 1)) + + return self._unflatten(torch.tensor(feature_attribution)) + + + def multiplies_by_inputs(self) -> bool: + return False + + def has_convergence_delta(self) -> bool: + return True \ No newline at end of file diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py new file mode 100644 index 000000000..ab4d5800e --- /dev/null +++ b/tests/attr/test_rex.py @@ -0,0 +1,77 @@ +from captum.attr._core.rex import * + +from captum.testing.helpers.basic import BaseTest +from parameterized import parameterized + +import torch + +class Test(BaseTest): + # rename for convenience + ts = torch.tensor + + @parameterized.expand([ + # inputs: baselines: + (ts([1,2,3]), ts([[2,3], [3,4]])), + ((ts([1]),ts([2]),ts([3])), (ts([1]),ts([2]))), + ((ts([1])), ()), + ((), ts([1])) + ]) + def test_input_baseline_mismatch_throws(self, input, baseline): + rex = ReX(lambda x: 1/0) # dummy forward, should be unreachable + with self.assertRaises(AssertionError): + rex.attribute(input, baseline) + + + @parameterized.expand([ + (ts([1,2,3]), 0), + (ts([[1,2,3], [4,5,6]]), 0), + (ts([1,2,3,4]), ts([0,0,0,0])), + (ts([[1, 2], [1,2]])), ts([0,0]), + (ts([[[1,2], [3,4]], [[5,6], [7,8]]]), 0), + ((ts([1,2]), ts([3,4]), ts([5,6])), (0, 0, 0)), + ((ts([1,2]), ts([3,4]), ts([5,6])), (ts([0,0]), ts([0,0]), ts([0,0]))), + ((ts([1,2]), ts([3,4])), (ts([0,0]), ts([0, 0]))), + ]) + def test_valid_input_baseline(self, input, baseline): + rex = ReX(lambda x: True) + + attributions = rex.attribute(input, baseline)[0] + if isinstance(input, tuple): input = input[0] + + # Forward_func returns a constant, no responsibility in input + self.assertFalse(torch.sum(attributions, dim=None)) + 
self.assertEqual(attributions.size(), input.size()) + + + @parameterized.expand([ + # input # selected_idx + (ts([1,2,3]), 0), + (ts([[1,2], [3,4]]), (0, 1)), + (ts([[[1, 2], [3, 4]], [[5,6], [7,8]]]), (0, 1, 0)) + ]) + def test_selector_function(self, input, idx): + rex = ReX(lambda x: x[idx]) + attributions = rex.attribute(input, 0)[0] + + self.assertTrue(attributions[idx] == 1) + + attributions[idx] = 0 + self.assertFalse(torch.sum(attributions, dim=None)) + + + @parameterized.expand([ + # input shape # important idx + ((12, 12, 12), (1,2,1)), + ((12, 12, 12, 6), (1,1,4,1)), + # ((1920, 1080), (2, 4)) # image-like + ]) + def test_selector_function_large_input(self, input_shape, idx): + rex = ReX(lambda x: x[idx]) + + input = torch.ones(*input_shape) + attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=1)[0] + + self.assertTrue(attributions[idx]) + attributions[idx] = 0 + self.assertLess(torch.sum(attributions, dim=None), 1) + From a2aebb2409d944d0aef18dccdc9734a93aa197b1 Mon Sep 17 00:00:00 2001 From: stav-af Date: Tue, 29 Jul 2025 17:14:01 +0100 Subject: [PATCH 02/21] Refactor responsibility check, use generator --- captum/attr/_core/rex.py | 41 +++++++++++++++++++--------------------- tests/attr/test_rex.py | 2 +- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 7046331ed..3ca5257cf 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -4,14 +4,13 @@ import itertools from typing import List import torch -import math, heapq from collections import deque import random from dataclasses import dataclass from captum.attr._utils.attribution import PerturbationAttribution from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric -from captum.attr._utils.common import _format_input_baseline, _validate_input +from captum.attr._utils.common import _format_input_baseline, _validate_input, _format_output from captum.log.dummy_log import log_usage @dataclass @@ -28,8 +27,8 @@ def _occlude_data(data: torch.Tensor, partitions: List[List[int]], neutral) -> t def _powerset(iterable): s = list(iterable) - return [list(combo) for r in range(len(s)+1) - for combo in itertools.combinations(s, r)] + return (list(combo) for r in range(len(s)+1) + for combo in itertools.combinations(s, r)) def _partitions_combinations(partitions): @@ -45,24 +44,21 @@ def _apply_responsibility(feature_importance, part, responsibility): return feature_importance -def _partitions_equal(partitions_a, partitions_b): - set_a = {tuple(p) if isinstance(p, list) else p for p in partitions_a} - set_b = {tuple(p) if isinstance(p, list) else p for p in partitions_b} - return set_a == set_b +def _part_to_set(partition): + return frozenset(frozenset(p) if isinstance(p, list) else p for p in partition) -def _responsibility(subject_partition: List, consistent_set: List[List[int]]) -> float: - witnesses = [mut.partitions for mut in consistent_set if subject_partition not in mut.partitions] - +def _responsibility(subject_partition: List, consistent_partitions: List[List[int]]) -> float: + witnesses = [mut.partitions for mut in consistent_partitions if subject_partition not in mut.partitions] + consistent_set = set(_part_to_set(part.partitions) for part in consistent_partitions) + # a witness is valid if perturbing it results in a counterfactual # dependence on the subject partition - - # hashset! 
valid_witnesses = [] for witness in witnesses: - counterfactual = [subject_partition] + witness - if not any(_partitions_equal(counterfactual, cst.partitions) for cst in consistent_set): + counterfactual = _part_to_set([subject_partition] + witness) + if not counterfactual in consistent_set: valid_witnesses.append(witness) if len(valid_witnesses) == 0: @@ -110,7 +106,8 @@ def attribute(self, *, search_depth: int = 10, n_partitions: int = 8, - n_searches: int = 5) -> TensorOrTupleOfTensorsGeneric: + n_searches: int = 5, + contiguous_partitions: bool = False) -> TensorOrTupleOfTensorsGeneric: r""" Args: inputs: @@ -138,22 +135,23 @@ def attribute(self, self._search_depth = search_depth self._n_searches = n_searches - is_tuple_inputs = isinstance(inputs, tuple) - is_tuple_baselines = isinstance(baselines, tuple) + is_input_tuple = isinstance(inputs, tuple) + is_baseline_tuple = isinstance(baselines, tuple) attributions = [] # match inputs and baselines, explain - if is_tuple_inputs and is_tuple_baselines: + if is_input_tuple and is_baseline_tuple: for input, baseline in zip(inputs, baselines): attributions.append(self._explain(input, baseline)) - elif is_tuple_inputs and not is_tuple_baselines: + elif is_input_tuple and not is_baseline_tuple: for input in inputs: attributions.append(self._explain(input, baselines)) else: attributions.append(self._explain(inputs, baselines)) - return tuple(attributions) + return _format_output(is_input_tuple, tuple(attributions)) + def _flatten(self, data): self._original_shape = data.size() @@ -255,7 +253,6 @@ def _explain(self, input, baseline): for part in partitions: resp = _responsibility(part, cst_set) - # update should be conditional on entropy? feature_attribution = _apply_responsibility(feature_attribution, part, resp * parent_resp) if resp > 0 and \ len(part) > 1 and \ diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index ab4d5800e..6b93cb415 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -63,7 +63,7 @@ def test_selector_function(self, input, idx): # input shape # important idx ((12, 12, 12), (1,2,1)), ((12, 12, 12, 6), (1,1,4,1)), - # ((1920, 1080), (2, 4)) # image-like + ((1920, 1080), (2, 4)) # image-like ]) def test_selector_function_large_input(self, input_shape, idx): rex = ReX(lambda x: x[idx]) From 8fd0158e169ce6cafc29f47e0822dda8f3ed6d4f Mon Sep 17 00:00:00 2001 From: stav-af Date: Wed, 30 Jul 2025 11:50:17 +0100 Subject: [PATCH 03/21] WIP - contiguous partition --- captum/attr/_core/rex.py | 29 ++++++++++++++++++++++++++++- tests/attr/test_rex.py | 28 ++++++++++++++-------------- 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 3ca5257cf..59783f537 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -188,6 +188,33 @@ def _fast_partition(self, responsibility: List[int], choices: List[int]) -> List return partitions + def _contiguous_partition(self, resposibility, choices, depth): + ndim = len(self._original_shape) + split_dim = depth % ndim + + # find max and min values for dimension we are splitting + dmax, dmin = 0, max(self._original_shape) + for idx in choices: + coords = torch.unravel_index(torch.tensor(idx), self._original_shape) + + dmax = max(dmax, coords[split_dim]) + dmin = min(dmin, coords[split_dim]) + + # cant split this axis + n_splits = min((dmax - dmin), self._n_partitions) + split_points = sorted(list(set(random.sample(range(dmin + 1, dmax), n_splits - 1)))) + + bins = [[] for _ in range(n_splits - 
1)] + for idx in choices: + for i, s in enumerate(split_points): + if s > torch.unravel_index(torch.tensor(idx), self._original_shape)[split_dim]: + bins[i].append(idx) + break + + print(bins) + return bins + + def _partition(self, responsibility: List[float], choices: List[int]) -> List[List[int]]: population = choices.copy() random.shuffle(population) @@ -238,7 +265,7 @@ def _explain(self, input, baseline): while part_q: indices, parent_resp, depth = part_q.popleft() - partitions = self._fast_partition(feature_attribution, indices) + partitions = self._contiguous_partition(feature_attribution, indices, depth) mutants = [] for idx_combination in _partitions_combinations(partitions): diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index 6b93cb415..fc90afab2 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -52,26 +52,26 @@ def test_valid_input_baseline(self, input, baseline): def test_selector_function(self, input, idx): rex = ReX(lambda x: x[idx]) attributions = rex.attribute(input, 0)[0] - + print(attributions) self.assertTrue(attributions[idx] == 1) attributions[idx] = 0 self.assertFalse(torch.sum(attributions, dim=None)) - @parameterized.expand([ - # input shape # important idx - ((12, 12, 12), (1,2,1)), - ((12, 12, 12, 6), (1,1,4,1)), - ((1920, 1080), (2, 4)) # image-like - ]) - def test_selector_function_large_input(self, input_shape, idx): - rex = ReX(lambda x: x[idx]) + # @parameterized.expand([ + # # input shape # important idx + # # ((12, 12, 12), (1,2,1)), + # # ((12, 12, 12, 6), (1,1,4,1)), + # # ((1920, 1080), (2, 4)) # image-like + # ]) + # def test_selector_function_large_input(self, input_shape, idx): + # rex = ReX(lambda x: x[idx]) - input = torch.ones(*input_shape) - attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=1)[0] + # input = torch.ones(*input_shape) + # attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=1)[0] - self.assertTrue(attributions[idx]) - attributions[idx] = 0 - self.assertLess(torch.sum(attributions, dim=None), 1) + # self.assertTrue(attributions[idx]) + # attributions[idx] = 0 + # self.assertLess(torch.sum(attributions, dim=None), 1) From bb851010b8676687a135977276212eaf856312b2 Mon Sep 17 00:00:00 2001 From: stav-af Date: Wed, 6 Aug 2025 14:32:15 +0100 Subject: [PATCH 04/21] Refactor for pytorch --- captum/attr/_core/rex.py | 290 +++++++++++++++++++++++++-------------- tests/attr/test_rex.py | 109 +++++++-------- 2 files changed, 243 insertions(+), 156 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 59783f537..591412bb5 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -6,6 +6,7 @@ import torch from collections import deque import random +import math from dataclasses import dataclass from captum.attr._utils.attribution import PerturbationAttribution @@ -13,36 +14,63 @@ from captum.attr._utils.common import _format_input_baseline, _validate_input, _format_output from captum.log.dummy_log import log_usage -@dataclass + +class Partition: + def __init__(self, borders: List[slice] = None, elements=None, size=None): + self.borders = borders + self.elements = elements + self.size = size + + self._mask = None + + def generate_mask(self, shape): + # function to generate a mask for a partition, polymorphic over + # splitting strategy + + if self._mask is None and self.elements is not None: + self._mask = torch.ones(shape, dtype=torch.bool) + self._mask[tuple(self.elements.T)] = False + + elif self._mask is 
None and self.borders is not None: + self._mask = torch.ones(shape, dtype=torch.bool) + + slices = list(slice(lo, hi) for (lo, hi) in self.borders) + self._mask[slices] = False + + return self._mask + + def __len__(self): + return self.size + +@dataclass(eq=False) class Mutant: partitions: List[List[int]] data: List[int] -def _occlude_data(data: torch.Tensor, partitions: List[List[int]], neutral) -> torch.Tensor: - mask = torch.ones_like(data, dtype=torch.bool) - for part in partitions: - mask[part] = False + # initialize a Mutant from some partitions + # eagerly create the underlying mutant data from partition masks + def __init__(self, data: torch.Tensor, partitions: List[Partition], neutral): + self.partitions = partitions - return torch.where(mask, data, neutral) + mask = torch.ones_like(data, dtype=torch.bool) + for part in partitions: mask &= part.generate_mask(mask.shape) -def _powerset(iterable): - s = list(iterable) - return (list(combo) for r in range(len(s)+1) - for combo in itertools.combinations(s, r)) + self.data = torch.where(mask, data, neutral) + def __len__(self): + return len(self.partitions) -def _partitions_combinations(partitions): - return list(filter( - lambda x: len(x) <= len(partitions), - _powerset(partitions))) + +def _powerset(s): + return (list(combo) for r in range(len(s)+1) + for combo in itertools.combinations(s, r)) -def _apply_responsibility(feature_importance, part, responsibility): - distributed_resp = responsibility / len(part) - for idx in part: - feature_importance[idx] = distributed_resp +def _apply_responsibility(fi, part, responsibility): + distributed = responsibility / len(part) + mask = part.generate_mask(fi.shape) - return feature_importance + return torch.where(mask, fi, (fi * distributed)) def _part_to_set(partition): @@ -70,6 +98,9 @@ def _responsibility(subject_partition: List, consistent_partitions: List[List[in return 1.0 / (1.0 + float(minpart)) +def _generate_indices(ts): + return torch.tensor(tuple(itertools.product(*(range(s) for s in ts.shape))), dtype=torch.long) + class ReX(PerturbationAttribution): """ A perturbation-based approach to computing attribution, based on the @@ -153,65 +184,155 @@ def attribute(self, return _format_output(is_input_tuple, tuple(attributions)) - def _flatten(self, data): - self._original_shape = data.size() - return data.reshape(-1) + def _explain(self, input, baseline): + self._original_shape = input.shape + self._size = input.numel() - def _unflatten(self, data): - return data.reshape(self._original_shape) + initial_prediction = self.forward_func(input) + feature_attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32) + initial_partition = Partition( + borders = list((0, top) for top in self._original_shape), + elements = _generate_indices(input), + size = self._size + ) + for _ in range(self._n_searches): + # by definition, root partition contains all indices + part_q = deque() + part_q.append(( + initial_partition, + 0 + )) - def _fast_partition(self, responsibility: List[int], choices: List[int]) -> List[List[int]]: - population = choices.copy() - random.shuffle(population) + while part_q: + prev_part, depth = part_q.popleft() + partitions = self._fast_partition(feature_attribution, prev_part) - weights = [responsibility[i] for i in population] - if sum(weights) == 0: weights = [1/len(population) for _ in population] + consistent_set = set() + for parts_combo in _powerset(partitions): + mut = Mutant(input, parts_combo, baseline) + if self.forward_func(mut.data) == 
initial_prediction: + consistent_set.add(mut) + + for part in partitions: + resp = _responsibility(part, consistent_set) + feature_attribution = _apply_responsibility(feature_attribution, part, resp) + + if resp > 0 and \ + len(part) > 1 and \ + depth < self._search_depth: + part_q.append((part, depth + 1)) + + + feature_attribution /= feature_attribution.abs().sum() + + return feature_attribution.clone().detach() + + + def _fast_partition(self, responsibility: torch.Tensor, part: Partition) -> List[Partition]: + perm = torch.randperm(len(part.elements)) - remaining_weight = sum(weights) - target_weight = remaining_weight / self._n_partitions + population = part.elements[perm] + weights = responsibility[tuple(population.T)] - zip_sorted = zip(weights, population) - weight_sorted, pop_sorted = zip(*sorted(zip_sorted, key=lambda x: x[0], reverse=True)) + #illustrate + + if torch.sum(weights, dim=None) == 0: weights = torch.ones_like(weights) / len(weights) + print(torch.sum(weights, dim=None)) + + remaining_weight = torch.sum(weights, dim=None) + target_weight = remaining_weight / self._n_partitions - partitions, lp, rp, cumsum = [], 0, 0, 0 - while rp < len(weight_sorted): - cumsum += weight_sorted[rp] - remaining_weight -= weight_sorted[rp] - if cumsum >= target_weight or rp == len(weight_sorted) - 1: - partitions.append(list(pop_sorted[lp:rp + 1])) - lp = rp + 1 - cumsum = 0 - rp += 1 + idx = torch.argsort(weights, descending=True) + weight_sorted, pop_sorted = weights[idx], population[idx] + + eps = torch.finfo(weight_sorted.dtype).eps + c = weight_sorted.cumsum(0) - eps + + bin_id = torch.div(c, target_weight, rounding_mode='floor').clamp_min(0).long() + _, counts = torch.unique_consecutive(bin_id, return_counts=True) + groups = torch.split(pop_sorted, counts.tolist()) + + # print("--------------") + # print(c) + # print(weight_sorted) + # print(bin_id) + # print(counts) + # print(groups) + # print("--------------") + + partitions = [Partition(elements=g, size=len(g)) for g in groups] return partitions - def _contiguous_partition(self, resposibility, choices, depth): + # def _fast_partition(self, responsibility: torch.Tensor, part: Partition) -> List[Partition]: + # perm = torch.randperm(part.elements) + + # population = part.elements[perm] + # weights = responsibility[population] + + # if weights.sum().item() == 0: weights = torch.full_like(weights, 1.0 / len(part)) + + # remaining_weight = weights.sum() + # target_weight = remaining_weight / self._n_partitions + + # idx = torch.argsort(weights, descending=True, stable=True) + # weight_sorted = weights[idx] + # pop_sorted = population[idx] + + # partitions, lp = [], 0 + # cumsum = torch.zeros((), dtype=weight_sorted.dtype, device=weight_sorted.device) + # for rp in range(weight_sorted.numel()): + # w = weight_sorted[rp] + # cumsum = cumsum + w + + # if (cumsum >= target_weight) or (rp == weight_sorted.numel() - 1): + # block = pop_sorted[lp:rp+1] + # coords = torch.stack(torch.unravel_index(block, self._original_shape), dim=-1) + + # partitions.append(Partition( + # elements = tuple(coords), + # size = block.numel() # (fix: inclusive slice) + # )) + + # lp = rp + 1 + # cumsum.zero_() + + # return partitions + + + + def _contiguous_partition(self, resposibility, part, depth): ndim = len(self._original_shape) - split_dim = depth % ndim + split_dim = -1 # find max and min values for dimension we are splitting - dmax, dmin = 0, max(self._original_shape) - for idx in choices: - coords = torch.unravel_index(torch.tensor(idx), 
self._original_shape) - - dmax = max(dmax, coords[split_dim]) - dmin = min(dmin, coords[split_dim]) - - # cant split this axis - n_splits = min((dmax - dmin), self._n_partitions) - split_points = sorted(list(set(random.sample(range(dmin + 1, dmax), n_splits - 1)))) - - bins = [[] for _ in range(n_splits - 1)] - for idx in choices: - for i, s in enumerate(split_points): - if s > torch.unravel_index(torch.tensor(idx), self._original_shape)[split_dim]: - bins[i].append(idx) - break + dmin, dmax = max(self._original_shape), 0 + for i in range(ndim): + candidate_dim = (i + depth) % ndim + dmin, dmax = tuple(part.borders[candidate_dim]) + + if dmax - dmin > 1: + split_dim = candidate_dim + break - print(bins) + n_splits = min((dmax - dmin), self._n_partitions) + + split_points = random.sample(range(dmin, dmax), n_splits - 1) + split_borders = sorted(set([dmin, *split_points, dmax])) + + bins = [] + for i in range(len(split_borders) - 1): + new_borders = list(part.borders) + new_borders[split_dim] = (split_borders[i], split_borders[i+1]) + + bins.append(Partition( + borders = tuple(new_borders), + size = math.prod(hi - lo for (lo, hi) in new_borders) + )) + return bins @@ -220,7 +341,7 @@ def _partition(self, responsibility: List[float], choices: List[int]) -> List[Li random.shuffle(population) weights = [responsibility[i] for i in population] - if sum(weights) == 0: weights = [1 for _ in choices] + if torch.sum(weights) == 0: weights = [1 for _ in choices] target_weight = sum(weights) / self._n_partitions partitions = [] @@ -245,50 +366,15 @@ def _partition(self, responsibility: List[float], choices: List[int]) -> List[Li curr_partition, curr_weight = [], 0.0 if curr_partition: - partitions.append(curr_partition) + partitions.append(Partition( + elements = set(curr_partition), + size = len(curr_partition) + )) return partitions - def _explain(self, input, baseline): - initial_prediction = self.forward_func(input) - flattened_input = self._flatten(input) - - n_features = flattened_input.numel() - feature_attribution = [0.0 for _ in range(n_features)] - - for _ in range(self._n_searches): - # by definition, root partition contains all indices - part_q = deque() - part_q.append((list(range(n_features)), 1.0, 0)) - - while part_q: - indices, parent_resp, depth = part_q.popleft() - partitions = self._contiguous_partition(feature_attribution, indices, depth) - - mutants = [] - for idx_combination in _partitions_combinations(partitions): - occluded_data = _occlude_data(flattened_input, idx_combination, baseline) - mut = Mutant(partitions=idx_combination, data=occluded_data) - mutants.append(mut) - - cst_set = list(filter( - lambda mut: self.forward_func(self._unflatten(mut.data)) == initial_prediction, - mutants - )) - - for part in partitions: - resp = _responsibility(part, cst_set) - feature_attribution = _apply_responsibility(feature_attribution, part, resp * parent_resp) - if resp > 0 and \ - len(part) > 1 and \ - depth < self._search_depth: - part_q.append((part, resp, depth + 1)) - - return self._unflatten(torch.tensor(feature_attribution)) - - def multiplies_by_inputs(self) -> bool: return False diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index fc90afab2..5f8b2b02b 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -9,69 +9,70 @@ class Test(BaseTest): # rename for convenience ts = torch.tensor - @parameterized.expand([ - # inputs: baselines: - (ts([1,2,3]), ts([[2,3], [3,4]])), - ((ts([1]),ts([2]),ts([3])), (ts([1]),ts([2]))), - ((ts([1])), ()), - ((), ts([1])) 
- ]) - def test_input_baseline_mismatch_throws(self, input, baseline): - rex = ReX(lambda x: 1/0) # dummy forward, should be unreachable - with self.assertRaises(AssertionError): - rex.attribute(input, baseline) - + # @parameterized.expand([ + # # inputs: baselines: + # (ts([1,2,3]), ts([[2,3], [3,4]])), + # ((ts([1]),ts([2]),ts([3])), (ts([1]),ts([2]))), + # ((ts([1])), ()), + # ((), ts([1])) + # ]) + # def test_input_baseline_mismatch_throws(self, input, baseline): + # rex = ReX(lambda x: 1/0) # dummy forward, should be unreachable + # with self.assertRaises(AssertionError): + # rex.attribute(input, baseline) - @parameterized.expand([ - (ts([1,2,3]), 0), - (ts([[1,2,3], [4,5,6]]), 0), - (ts([1,2,3,4]), ts([0,0,0,0])), - (ts([[1, 2], [1,2]])), ts([0,0]), - (ts([[[1,2], [3,4]], [[5,6], [7,8]]]), 0), - ((ts([1,2]), ts([3,4]), ts([5,6])), (0, 0, 0)), - ((ts([1,2]), ts([3,4]), ts([5,6])), (ts([0,0]), ts([0,0]), ts([0,0]))), - ((ts([1,2]), ts([3,4])), (ts([0,0]), ts([0, 0]))), - ]) - def test_valid_input_baseline(self, input, baseline): - rex = ReX(lambda x: True) - attributions = rex.attribute(input, baseline)[0] - if isinstance(input, tuple): input = input[0] + # @parameterized.expand([ + # (ts([1,2,3]), 0), + # (ts([[1,2,3], [4,5,6]]), 0), + # (ts([1,2,3,4]), ts([0,0,0,0])), + # (ts([[1, 2], [1,2]])), ts([0,0]), + # (ts([[[1,2], [3,4]], [[5,6], [7,8]]]), 0), + # ((ts([1,2]), ts([3,4]), ts([5,6])), (0, 0, 0)), + # ((ts([1,2]), ts([3,4]), ts([5,6])), (ts([0,0]), ts([0,0]), ts([0,0]))), + # ((ts([1,2]), ts([3,4])), (ts([0,0]), ts([0, 0]))), + # ]) + # def test_valid_input_baseline(self, input, baseline): + # rex = ReX(lambda x: True) - # Forward_func returns a constant, no responsibility in input - self.assertFalse(torch.sum(attributions, dim=None)) - self.assertEqual(attributions.size(), input.size()) + # attributions = rex.attribute(input, baseline)[0] + # if isinstance(input, tuple): input = input[0] - - @parameterized.expand([ - # input # selected_idx - (ts([1,2,3]), 0), - (ts([[1,2], [3,4]]), (0, 1)), - (ts([[[1, 2], [3, 4]], [[5,6], [7,8]]]), (0, 1, 0)) - ]) - def test_selector_function(self, input, idx): - rex = ReX(lambda x: x[idx]) - attributions = rex.attribute(input, 0)[0] - print(attributions) - self.assertTrue(attributions[idx] == 1) - - attributions[idx] = 0 - self.assertFalse(torch.sum(attributions, dim=None)) + # # Forward_func returns a constant, no responsibility in input + # self.assertFalse(torch.sum(attributions, dim=None)) + # self.assertEqual(attributions.size(), input.size()) # @parameterized.expand([ - # # input shape # important idx - # # ((12, 12, 12), (1,2,1)), - # # ((12, 12, 12, 6), (1,1,4,1)), - # # ((1920, 1080), (2, 4)) # image-like + # # input # selected_idx + # (ts([1,2,3]), 0), + # (ts([[1,2], [3,4]]), (0, 1)), + # (ts([[[1, 2], [3, 4]], [[5,6], [7,8]]]), (0, 1, 0)) # ]) - # def test_selector_function_large_input(self, input_shape, idx): + # def test_selector_function(self, input, idx): # rex = ReX(lambda x: x[idx]) + # attributions = rex.attribute(input, 0)[0] + # print(attributions) + # self.assertTrue(attributions[idx] == 1) - # input = torch.ones(*input_shape) - # attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=1)[0] - - # self.assertTrue(attributions[idx]) # attributions[idx] = 0 - # self.assertLess(torch.sum(attributions, dim=None), 1) + # self.assertFalse(torch.sum(attributions, dim=None)) + + + @parameterized.expand([ + # input shape # important idx + ((4,4), (0,0)), + # ((12, 12, 12), (1,2,1)), + # ((12, 12, 12, 
6), (1,1,4,1)), + ((1920, 1080), (1, 1)) # image-like + ]) + def test_selector_function_large_input(self, input_shape, idx): + rex = ReX(lambda x: x[idx]) + + input = torch.ones(*input_shape) + attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=3)[0] + print(attributions) + self.assertTrue(attributions[idx]) + attributions[idx] = 0 + self.assertLess(torch.sum(attributions, dim=None), 1) From 8e6f390da0a6349dc8946cbc5b4aca9fd96480da Mon Sep 17 00:00:00 2001 From: stav-af Date: Fri, 8 Aug 2025 13:53:04 +0100 Subject: [PATCH 05/21] Debugging --- captum/attr/_core/rex.py | 59 ++++++---------------------- tests/attr/test_rex.py | 84 ++++++++++++++++++++-------------------- 2 files changed, 53 insertions(+), 90 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 591412bb5..0620d3de1 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -223,8 +223,8 @@ def _explain(self, input, baseline): depth < self._search_depth: part_q.append((part, depth + 1)) - - feature_attribution /= feature_attribution.abs().sum() + asum = feature_attribution.abs().sum() + feature_attribution /= asum if asum != 0 else 1 return feature_attribution.clone().detach() @@ -235,8 +235,6 @@ def _fast_partition(self, responsibility: torch.Tensor, part: Partition) -> List population = part.elements[perm] weights = responsibility[tuple(population.T)] - #illustrate - if torch.sum(weights, dim=None) == 0: weights = torch.ones_like(weights) / len(weights) print(torch.sum(weights, dim=None)) @@ -245,6 +243,8 @@ def _fast_partition(self, responsibility: torch.Tensor, part: Partition) -> List idx = torch.argsort(weights, descending=True) + print("inb4", weights, population) + print(part.elements, part.size) weight_sorted, pop_sorted = weights[idx], population[idx] eps = torch.finfo(weight_sorted.dtype).eps @@ -255,55 +255,18 @@ def _fast_partition(self, responsibility: torch.Tensor, part: Partition) -> List _, counts = torch.unique_consecutive(bin_id, return_counts=True) groups = torch.split(pop_sorted, counts.tolist()) - # print("--------------") - # print(c) - # print(weight_sorted) - # print(bin_id) - # print(counts) - # print(groups) - # print("--------------") + print("--------------") + print(c) + print(weight_sorted) + print(bin_id) + print(counts) + print(groups) + print("--------------") partitions = [Partition(elements=g, size=len(g)) for g in groups] return partitions - # def _fast_partition(self, responsibility: torch.Tensor, part: Partition) -> List[Partition]: - # perm = torch.randperm(part.elements) - - # population = part.elements[perm] - # weights = responsibility[population] - - # if weights.sum().item() == 0: weights = torch.full_like(weights, 1.0 / len(part)) - - # remaining_weight = weights.sum() - # target_weight = remaining_weight / self._n_partitions - - # idx = torch.argsort(weights, descending=True, stable=True) - # weight_sorted = weights[idx] - # pop_sorted = population[idx] - - # partitions, lp = [], 0 - # cumsum = torch.zeros((), dtype=weight_sorted.dtype, device=weight_sorted.device) - # for rp in range(weight_sorted.numel()): - # w = weight_sorted[rp] - # cumsum = cumsum + w - - # if (cumsum >= target_weight) or (rp == weight_sorted.numel() - 1): - # block = pop_sorted[lp:rp+1] - # coords = torch.stack(torch.unravel_index(block, self._original_shape), dim=-1) - - # partitions.append(Partition( - # elements = tuple(coords), - # size = block.numel() # (fix: inclusive slice) - # )) - - # lp = rp + 1 - # cumsum.zero_() - - # 
return partitions - - - def _contiguous_partition(self, resposibility, part, depth): ndim = len(self._original_shape) split_dim = -1 diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index 5f8b2b02b..68916e199 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -9,54 +9,54 @@ class Test(BaseTest): # rename for convenience ts = torch.tensor - # @parameterized.expand([ - # # inputs: baselines: - # (ts([1,2,3]), ts([[2,3], [3,4]])), - # ((ts([1]),ts([2]),ts([3])), (ts([1]),ts([2]))), - # ((ts([1])), ()), - # ((), ts([1])) - # ]) - # def test_input_baseline_mismatch_throws(self, input, baseline): - # rex = ReX(lambda x: 1/0) # dummy forward, should be unreachable - # with self.assertRaises(AssertionError): - # rex.attribute(input, baseline) - + @parameterized.expand([ + # inputs: baselines: + (ts([1,2,3]), ts([[2,3], [3,4]])), + ((ts([1]),ts([2]),ts([3])), (ts([1]),ts([2]))), + ((ts([1])), ()), + ((), ts([1])) + ]) + def test_input_baseline_mismatch_throws(self, input, baseline): + rex = ReX(lambda x: 1/0) # dummy forward, should be unreachable + with self.assertRaises(AssertionError): + rex.attribute(input, baseline) - # @parameterized.expand([ - # (ts([1,2,3]), 0), - # (ts([[1,2,3], [4,5,6]]), 0), - # (ts([1,2,3,4]), ts([0,0,0,0])), - # (ts([[1, 2], [1,2]])), ts([0,0]), - # (ts([[[1,2], [3,4]], [[5,6], [7,8]]]), 0), - # ((ts([1,2]), ts([3,4]), ts([5,6])), (0, 0, 0)), - # ((ts([1,2]), ts([3,4]), ts([5,6])), (ts([0,0]), ts([0,0]), ts([0,0]))), - # ((ts([1,2]), ts([3,4])), (ts([0,0]), ts([0, 0]))), - # ]) - # def test_valid_input_baseline(self, input, baseline): - # rex = ReX(lambda x: True) - # attributions = rex.attribute(input, baseline)[0] - # if isinstance(input, tuple): input = input[0] + @parameterized.expand([ + (ts([1,2,3]), 0), + (ts([[1,2,3], [4,5,6]]), 0), + (ts([1,2,3,4]), ts([0,0,0,0])), + (ts([[1, 2], [1,2]]), ts([[0,0], [0,0]])), + (ts([[[1,2], [3,4]], [[5,6], [7,8]]]), 0), + ((ts([1,2]), ts([3,4]), ts([5,6])), (0, 0, 0)), + ((ts([1,2]), ts([3,4]), ts([5,6])), (ts([0,0]), ts([0,0]), ts([0,0]))), + ((ts([1,2]), ts([3,4])), (ts([0,0]), ts([0, 0]))), + ]) + def test_valid_input_baseline(self, input, baseline): + rex = ReX(lambda x: True) - # # Forward_func returns a constant, no responsibility in input - # self.assertFalse(torch.sum(attributions, dim=None)) - # self.assertEqual(attributions.size(), input.size()) + attributions = rex.attribute(input, baseline, n_partitions=2)[0] + if isinstance(input, tuple): input = input[0] + print(attributions) + # Forward_func returns a constant, no responsibility in input + self.assertFalse(torch.sum(attributions, dim=None)) + self.assertEqual(attributions.size(), input.size()) - # @parameterized.expand([ - # # input # selected_idx - # (ts([1,2,3]), 0), - # (ts([[1,2], [3,4]]), (0, 1)), - # (ts([[[1, 2], [3, 4]], [[5,6], [7,8]]]), (0, 1, 0)) - # ]) - # def test_selector_function(self, input, idx): - # rex = ReX(lambda x: x[idx]) - # attributions = rex.attribute(input, 0)[0] - # print(attributions) - # self.assertTrue(attributions[idx] == 1) + @parameterized.expand([ + # input # selected_idx + (ts([1,2,3]), 0), + (ts([[1,2], [3,4]]), (0, 1)), + (ts([[[1, 2], [3, 4]], [[5,6], [7,8]]]), (0, 1, 0)) + ]) + def test_selector_function(self, input, idx): + rex = ReX(lambda x: x[idx]) + attributions = rex.attribute(input, 0)[0] + print(attributions) + self.assertTrue(attributions[idx] == 1) - # attributions[idx] = 0 - # self.assertFalse(torch.sum(attributions, dim=None)) + attributions[idx] = 0 + 
self.assertFalse(torch.sum(attributions, dim=None)) @parameterized.expand([ From 82cf7e5e026718a4039bc0cc707157ab65941c38 Mon Sep 17 00:00:00 2001 From: stav-af Date: Sun, 10 Aug 2025 17:30:13 +0100 Subject: [PATCH 06/21] Comment, test, debug --- captum/attr/_core/rex.py | 170 ++++++++++++++------------------------- tests/attr/test_rex.py | 39 +++++---- 2 files changed, 86 insertions(+), 123 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 0620d3de1..1fa24ffcc 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -24,16 +24,16 @@ def __init__(self, borders: List[slice] = None, elements=None, size=None): self._mask = None def generate_mask(self, shape): - # function to generate a mask for a partition, polymorphic over - # splitting strategy + # generates a mask for a partition (False indicates membership) + if self._mask is not None: return self._mask + self._mask = torch.ones(shape, dtype=torch.bool) - if self._mask is None and self.elements is not None: - self._mask = torch.ones(shape, dtype=torch.bool) + # non-contiguous case + if self.elements is not None: self._mask[tuple(self.elements.T)] = False - elif self._mask is None and self.borders is not None: - self._mask = torch.ones(shape, dtype=torch.bool) - + # contiguous case + elif self.borders is not None: slices = list(slice(lo, hi) for (lo, hi) in self.borders) self._mask[slices] = False @@ -41,20 +41,19 @@ def generate_mask(self, shape): def __len__(self): return self.size + @dataclass(eq=False) class Mutant: partitions: List[List[int]] data: List[int] - # initialize a Mutant from some partitions - # eagerly create the underlying mutant data from partition masks - def __init__(self, data: torch.Tensor, partitions: List[Partition], neutral): - self.partitions = partitions - + # eagerly create the underlying mutant data + def __init__(self, partitions: List[Partition], data: torch.Tensor, neutral): mask = torch.ones_like(data, dtype=torch.bool) for part in partitions: mask &= part.generate_mask(mask.shape) + self.partitions = partitions self.data = torch.where(mask, data, neutral) def __len__(self): @@ -77,9 +76,9 @@ def _part_to_set(partition): return frozenset(frozenset(p) if isinstance(p, list) else p for p in partition) -def _responsibility(subject_partition: List, consistent_partitions: List[List[int]]) -> float: - witnesses = [mut.partitions for mut in consistent_partitions if subject_partition not in mut.partitions] - consistent_set = set(_part_to_set(part.partitions) for part in consistent_partitions) +def _calculate_responsibility(subject_partition: List, consistent_mutants: List[Mutant]) -> float: + witnesses = [mut.partitions for mut in consistent_mutants if subject_partition not in mut.partitions] + consistent_set = set(_part_to_set(part.partitions) for part in consistent_mutants) # a witness is valid if perturbing it results in a counterfactual # dependence on the subject partition @@ -101,16 +100,18 @@ def _responsibility(subject_partition: List, consistent_partitions: List[List[in def _generate_indices(ts): return torch.tensor(tuple(itertools.product(*(range(s) for s in ts.shape))), dtype=torch.long) + class ReX(PerturbationAttribution): """ A perturbation-based approach to computing attribution, based on the - Halpern-Pearl definition of actual causality[1]. + Halpern-Pearl definition of Actual Causality[1]. - The approach works by - partitioning the input space, and masking each partition. 
Intuitively, if masking a - partition changes the prediction of the model, then that partition has + ReX works by partitioning the input space, and masking each partition with the baseline value. It is fully + model agnostic, and relies only on a 'forward_func' returning a scalar. + + Intuitively, if masking a partition changes the prediction of the model, then that partition has some responsibility (attribution > 0). Such partially masked partitions are called - mutants. The responsibility of a subject partition is defined as 1/(1+k) where + mutants. The responsibility of a partition is defined as 1/(1+k) where k is a minimum number of occluded partitions in a mutant which make forward_func's output dependednt on the subject partition. @@ -134,11 +135,10 @@ def __init__(self, forward_func): def attribute(self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = 0, - *, search_depth: int = 10, - n_partitions: int = 8, + n_partitions: int = 4, n_searches: int = 5, - contiguous_partitions: bool = False) -> TensorOrTupleOfTensorsGeneric: + contiguous_partitioning: bool = False) -> TensorOrTupleOfTensorsGeneric: r""" Args: inputs: @@ -163,8 +163,9 @@ def attribute(self, _validate_input(inputs, baselines) self._n_partitions = n_partitions - self._search_depth = search_depth + self._max_depth = search_depth self._n_searches = n_searches + self._is_contiguous = contiguous_partitioning is_input_tuple = isinstance(inputs, tuple) is_baseline_tuple = isinstance(baselines, tuple) @@ -189,89 +190,76 @@ def _explain(self, input, baseline): self._size = input.numel() initial_prediction = self.forward_func(input) - feature_attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32) + attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32) initial_partition = Partition( borders = list((0, top) for top in self._original_shape), elements = _generate_indices(input), size = self._size ) + prev_depth = 0 + for _ in range(self._n_searches): - # by definition, root partition contains all indices - part_q = deque() - part_q.append(( - initial_partition, - 0 - )) + Q = deque() + Q.append((initial_partition, 0)) - while part_q: - prev_part, depth = part_q.popleft() - partitions = self._fast_partition(feature_attribution, prev_part) + while Q: + prev_part, depth = Q.popleft() + partitions = self._contiguous_partition(prev_part, depth) \ + if self._is_contiguous else self._partition(prev_part, attribution) - consistent_set = set() - for parts_combo in _powerset(partitions): - mut = Mutant(input, parts_combo, baseline) - if self.forward_func(mut.data) == initial_prediction: - consistent_set.add(mut) + mutants = [Mutant(ps, input, baseline) for ps in _powerset(partitions)] + consistent_mutants = [mut for mut in mutants if self.forward_func(mut.data) == initial_prediction] for part in partitions: - resp = _responsibility(part, consistent_set) - feature_attribution = _apply_responsibility(feature_attribution, part, resp) + resp = _calculate_responsibility(part, consistent_mutants) + attribution = _apply_responsibility(attribution, part, resp) + + if resp > 0 and len(part) > 1 and self._max_depth > depth: + Q.append((part, depth + 1)) - if resp > 0 and \ - len(part) > 1 and \ - depth < self._search_depth: - part_q.append((part, depth + 1)) + if depth != prev_depth: + asum = attribution.abs().sum() + attribution /= asum if asum != 0 else 1 - asum = feature_attribution.abs().sum() - feature_attribution /= asum if asum != 0 else 1 + prev_depth = depth - return 
feature_attribution.clone().detach() + asum = attribution.abs().sum() + attribution /= asum if asum != 0 else 1 + return attribution.clone().detach() - def _fast_partition(self, responsibility: torch.Tensor, part: Partition) -> List[Partition]: + def _partition(self, part: Partition, responsibility: torch.Tensor) -> List[Partition]: + # shuffle candidate indices (randomize tiebreakers) perm = torch.randperm(len(part.elements)) - population = part.elements[perm] weights = responsibility[tuple(population.T)] if torch.sum(weights, dim=None) == 0: weights = torch.ones_like(weights) / len(weights) - print(torch.sum(weights, dim=None)) - - remaining_weight = torch.sum(weights, dim=None) - target_weight = remaining_weight / self._n_partitions - + target_weight = torch.sum(weights) / self._n_partitions + # sort for greedy selection idx = torch.argsort(weights, descending=True) - print("inb4", weights, population) - print(part.elements, part.size) weight_sorted, pop_sorted = weights[idx], population[idx] + # cumulative sum of weights / weight per bucket rounded down gives us bucket ids eps = torch.finfo(weight_sorted.dtype).eps c = weight_sorted.cumsum(0) - eps - bin_id = torch.div(c, target_weight, rounding_mode='floor').clamp_min(0).long() + # count elements in each bucket, and split input accordingly _, counts = torch.unique_consecutive(bin_id, return_counts=True) groups = torch.split(pop_sorted, counts.tolist()) - print("--------------") - print(c) - print(weight_sorted) - print(bin_id) - print(counts) - print(groups) - print("--------------") - partitions = [Partition(elements=g, size=len(g)) for g in groups] return partitions - def _contiguous_partition(self, resposibility, part, depth): + def _contiguous_partition(self, part, depth): ndim = len(self._original_shape) split_dim = -1 - # find max and min values for dimension we are splitting + # find a dimension we can split dmin, dmax = max(self._original_shape), 0 for i in range(ndim): candidate_dim = (i + depth) % ndim @@ -280,10 +268,13 @@ def _contiguous_partition(self, resposibility, part, depth): if dmax - dmin > 1: split_dim = candidate_dim break + + if split_dim == -1: return [part] - n_splits = min((dmax - dmin), self._n_partitions) + n_splits = min((dmax - dmin), self._n_partitions) - 1 - split_points = random.sample(range(dmin, dmax), n_splits - 1) + # drop splits randomly + split_points = random.sample(range(dmin + 1, dmax), n_splits) split_borders = sorted(set([dmin, *split_points, dmax])) bins = [] @@ -299,45 +290,6 @@ def _contiguous_partition(self, resposibility, part, depth): return bins - def _partition(self, responsibility: List[float], choices: List[int]) -> List[List[int]]: - population = choices.copy() - random.shuffle(population) - - weights = [responsibility[i] for i in population] - if torch.sum(weights) == 0: weights = [1 for _ in choices] - - target_weight = sum(weights) / self._n_partitions - partitions = [] - - curr_weight = 0.0 - curr_partition = [] - - while population: - choice = random.choices(population, weights, k=1)[0] - idx = population.index(choice) - - population.pop(idx) - - weights = [responsibility[i] for i in population] - if sum(weights) == 0: weights = [1 for _ in population] - - curr_partition.append(choice) - curr_weight += responsibility[choice] - - if curr_weight > target_weight: - partitions.append(curr_partition) - curr_partition, curr_weight = [], 0.0 - - if curr_partition: - partitions.append(Partition( - elements = set(curr_partition), - size = len(curr_partition) - )) - - return partitions 
- - - def multiplies_by_inputs(self) -> bool: return False diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index 68916e199..dd4853c8d 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -8,6 +8,13 @@ class Test(BaseTest): # rename for convenience ts = torch.tensor + + depth_opts = range(4, 10) + n_partition_opts = range(2, 5) + n_search_opts = range(2, 5) + is_contiguous_opts = [False, True] + + all_options = list(itertools.product(depth_opts, n_partition_opts, n_search_opts, is_contiguous_opts)) @parameterized.expand([ # inputs: baselines: @@ -33,14 +40,17 @@ def test_input_baseline_mismatch_throws(self, input, baseline): ((ts([1,2]), ts([3,4])), (ts([0,0]), ts([0, 0]))), ]) def test_valid_input_baseline(self, input, baseline): - rex = ReX(lambda x: True) + for o in self.all_options: + rex = ReX(lambda x: True) + + attributions = rex.attribute(input, baseline, *o)[0] + + inp_unwrapped = input + if isinstance(input, tuple): inp_unwrapped = input[0] - attributions = rex.attribute(input, baseline, n_partitions=2)[0] - if isinstance(input, tuple): input = input[0] - print(attributions) - # Forward_func returns a constant, no responsibility in input - self.assertFalse(torch.sum(attributions, dim=None)) - self.assertEqual(attributions.size(), input.size()) + # Forward_func returns a constant, no responsibility in input + self.assertFalse(torch.sum(attributions, dim=None)) + self.assertEqual(attributions.size(), inp_unwrapped.size()) @parameterized.expand([ @@ -50,13 +60,15 @@ def test_valid_input_baseline(self, input, baseline): (ts([[[1, 2], [3, 4]], [[5,6], [7,8]]]), (0, 1, 0)) ]) def test_selector_function(self, input, idx): - rex = ReX(lambda x: x[idx]) - attributions = rex.attribute(input, 0)[0] - print(attributions) - self.assertTrue(attributions[idx] == 1) + for o in self.all_options: + rex = ReX(lambda x: x[idx]) - attributions[idx] = 0 - self.assertFalse(torch.sum(attributions, dim=None)) + attributions = rex.attribute(input, 0, *o)[0] + print(attributions, o) + self.assertTrue(attributions[idx] == 1) + + attributions[idx] = 0 + self.assertFalse(torch.sum(attributions, dim=None)) @parameterized.expand([ @@ -71,7 +83,6 @@ def test_selector_function_large_input(self, input_shape, idx): input = torch.ones(*input_shape) attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=3)[0] - print(attributions) self.assertTrue(attributions[idx]) attributions[idx] = 0 self.assertLess(torch.sum(attributions, dim=None), 1) From 036ed4e225bd047ba760a60432cbbd1c096ad0bc Mon Sep 17 00:00:00 2001 From: stav-af Date: Wed, 13 Aug 2025 15:48:09 +0100 Subject: [PATCH 07/21] Fix inversion logic, boolean OR --- captum/attr/_core/rex.py | 89 +++++++++++++++++++++------------------- tests/attr/test_rex.py | 24 +++++++++-- 2 files changed, 68 insertions(+), 45 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 1fa24ffcc..0cf877acf 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -26,32 +26,38 @@ def __init__(self, borders: List[slice] = None, elements=None, size=None): def generate_mask(self, shape): # generates a mask for a partition (False indicates membership) if self._mask is not None: return self._mask - self._mask = torch.ones(shape, dtype=torch.bool) + self._mask = torch.zeros(shape, dtype=torch.bool) # non-contiguous case if self.elements is not None: - self._mask[tuple(self.elements.T)] = False + self._mask[tuple(self.elements.T)] = True # contiguous case elif self.borders is not None: slices = 
list(slice(lo, hi) for (lo, hi) in self.borders) - self._mask[slices] = False + self._mask[slices] = True return self._mask def __len__(self): return self.size + def __str__(self): + # unsafe + return self.generate_mask(None).to(torch.int).__str__() + @dataclass(eq=False) class Mutant: - partitions: List[List[int]] - data: List[int] + partitions: List[Partition] + data: torch.Tensor # eagerly create the underlying mutant data - def __init__(self, partitions: List[Partition], data: torch.Tensor, neutral): - mask = torch.ones_like(data, dtype=torch.bool) - for part in partitions: mask &= part.generate_mask(mask.shape) + def __init__(self, partitions: List[Partition], data: torch.Tensor, neutral, shape): + + # A bitmap in the shape of the input indicating membership to a partition in this mutant + mask = torch.zeros(shape, dtype=torch.bool) + for part in partitions: mask |= part.generate_mask(mask.shape) self.partitions = partitions self.data = torch.where(mask, data, neutral) @@ -69,32 +75,38 @@ def _apply_responsibility(fi, part, responsibility): distributed = responsibility / len(part) mask = part.generate_mask(fi.shape) - return torch.where(mask, fi, (fi * distributed)) + return torch.where(mask, distributed, fi) def _part_to_set(partition): return frozenset(frozenset(p) if isinstance(p, list) else p for p in partition) -def _calculate_responsibility(subject_partition: List, consistent_mutants: List[Mutant]) -> float: - witnesses = [mut.partitions for mut in consistent_mutants if subject_partition not in mut.partitions] - consistent_set = set(_part_to_set(part.partitions) for part in consistent_mutants) - - # a witness is valid if perturbing it results in a counterfactual - # dependence on the subject partition - valid_witnesses = [] - for witness in witnesses: - counterfactual = _part_to_set([subject_partition] + witness) - if not counterfactual in consistent_set: - valid_witnesses.append(witness) +def _calculate_responsibility(subject_partition: Partition, + mutants: List[Mutant], + recovery_mutants: List[Mutant]) -> float: - if len(valid_witnesses) == 0: - return 0.0 - min_mutant = min(valid_witnesses, key=len) - minpart = len(min_mutant) + recovery_set = {_part_to_set(m.partitions) for m in recovery_mutants} - return 1.0 / (1.0 + float(minpart)) + valid_witnesses = [] + for m in mutants: + if subject_partition in m.partitions: + continue + W = m.partitions + W_set = _part_to_set(W) + W_plus_P_set = _part_to_set([subject_partition] + W) + + # W alone does NOT recover, but W ∪ {P} DOES recover. 
+ if (W_set not in recovery_set) and (W_plus_P_set in recovery_set): + valid_witnesses.append(W) + + if not valid_witnesses: + return 0.0 + + k = min(len(w) for w in valid_witnesses) + # Responsibility per your definition: 1 / (1 + k) + return 1.0 / (1.0 + float(k)) def _generate_indices(ts): @@ -186,19 +198,17 @@ def attribute(self, def _explain(self, input, baseline): - self._original_shape = input.shape + self._shape = input.shape self._size = input.numel() initial_prediction = self.forward_func(input) attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32) initial_partition = Partition( - borders = list((0, top) for top in self._original_shape), + borders = list((0, top) for top in self._shape), elements = _generate_indices(input), size = self._size ) - prev_depth = 0 - for _ in range(self._n_searches): Q = deque() Q.append((initial_partition, 0)) @@ -208,24 +218,19 @@ def _explain(self, input, baseline): partitions = self._contiguous_partition(prev_part, depth) \ if self._is_contiguous else self._partition(prev_part, attribution) - mutants = [Mutant(ps, input, baseline) for ps in _powerset(partitions)] + + mutants = [Mutant(ps, input, baseline, self._shape) for ps in _powerset(partitions)] consistent_mutants = [mut for mut in mutants if self.forward_func(mut.data) == initial_prediction] + prev_part.generate_mask(self._shape) for part in partitions: - resp = _calculate_responsibility(part, consistent_mutants) + resp = _calculate_responsibility(part, mutants, consistent_mutants) attribution = _apply_responsibility(attribution, part, resp) - if resp > 0 and len(part) > 1 and self._max_depth > depth: + if resp == 1 and len(part) > 1 and self._max_depth > depth: Q.append((part, depth + 1)) + - if depth != prev_depth: - asum = attribution.abs().sum() - attribution /= asum if asum != 0 else 1 - - prev_depth = depth - - asum = attribution.abs().sum() - attribution /= asum if asum != 0 else 1 return attribution.clone().detach() @@ -256,11 +261,11 @@ def _partition(self, part: Partition, responsibility: torch.Tensor) -> List[Part def _contiguous_partition(self, part, depth): - ndim = len(self._original_shape) + ndim = len(self._shape) split_dim = -1 # find a dimension we can split - dmin, dmax = max(self._original_shape), 0 + dmin, dmax = max(self._shape), 0 for i in range(ndim): candidate_dim = (i + depth) % ndim dmin, dmax = tuple(part.borders[candidate_dim]) diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index dd4853c8d..847ae4a46 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -10,8 +10,8 @@ class Test(BaseTest): ts = torch.tensor depth_opts = range(4, 10) - n_partition_opts = range(2, 5) - n_search_opts = range(2, 5) + n_partition_opts = range(3, 5) + n_search_opts = range(5, 15) is_contiguous_opts = [False, True] all_options = list(itertools.product(depth_opts, n_partition_opts, n_search_opts, is_contiguous_opts)) @@ -64,7 +64,6 @@ def test_selector_function(self, input, idx): rex = ReX(lambda x: x[idx]) attributions = rex.attribute(input, 0, *o)[0] - print(attributions, o) self.assertTrue(attributions[idx] == 1) attributions[idx] = 0 @@ -87,3 +86,22 @@ def test_selector_function_large_input(self, input_shape, idx): attributions[idx] = 0 self.assertLess(torch.sum(attributions, dim=None), 1) + @parameterized.expand([ + # input shape # lhs_idx # rhs_idx + ((2,4), (0,2), (1,3)) + + + ]) + def test_boolean_or(self, input_shape, lhs_idx, rhs_idx): + for o in self.all_options: + rex = ReX(lambda x: max(x[lhs_idx], x[rhs_idx])) + input = 
torch.ones(input_shape) + + attributions = rex.attribute(input, 0, *o)[0] + + self.assertTrue(attributions[lhs_idx] > 0.25, f"{attributions}") + self.assertTrue(attributions[rhs_idx] > 0.25, f"{attributions}") + + attributions[lhs_idx] = 0 + attributions[rhs_idx] = 0 + self.assertTrue(torch.sum(attributions) < 1, f"{attributions}") From 902271730355e449cbd26fb3ffda0b2cabd13652 Mon Sep 17 00:00:00 2001 From: stav-af Date: Sat, 16 Aug 2025 15:38:50 +0100 Subject: [PATCH 08/21] Test gaussian recovery --- captum/attr/_core/rex.py | 20 +++++-- tests/attr/test_rex.py | 109 +++++++++++++++++++++++++++++++++++---- 2 files changed, 114 insertions(+), 15 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 0cf877acf..a0a3e5b97 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -150,7 +150,8 @@ def attribute(self, search_depth: int = 10, n_partitions: int = 4, n_searches: int = 5, - contiguous_partitioning: bool = False) -> TensorOrTupleOfTensorsGeneric: + contiguous_partitioning: bool = False, + merge: bool = True) -> TensorOrTupleOfTensorsGeneric: r""" Args: inputs: @@ -178,6 +179,7 @@ def attribute(self, self._max_depth = search_depth self._n_searches = n_searches self._is_contiguous = contiguous_partitioning + self._merge = merge is_input_tuple = isinstance(inputs, tuple) is_baseline_tuple = isinstance(baselines, tuple) @@ -202,6 +204,8 @@ def _explain(self, input, baseline): self._size = input.numel() initial_prediction = self.forward_func(input) + + prev_attribution = torch.full_like(input, 0.0, dtype=torch.float32) attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32) initial_partition = Partition( @@ -209,7 +213,8 @@ def _explain(self, input, baseline): elements = _generate_indices(input), size = self._size ) - for _ in range(self._n_searches): + + for i in range(1, self._n_searches + 1): Q = deque() Q.append((initial_partition, 0)) @@ -229,7 +234,12 @@ def _explain(self, input, baseline): if resp == 1 and len(part) > 1 and self._max_depth > depth: Q.append((part, depth + 1)) + else: + attribution = _apply_responsibility(attribution, part, resp) + if self._merge: + prev_attribution += (1/i) * (attribution - prev_attribution) + attribution = prev_attribution return attribution.clone().detach() @@ -244,14 +254,14 @@ def _partition(self, part: Partition, responsibility: torch.Tensor) -> List[Part target_weight = torch.sum(weights) / self._n_partitions # sort for greedy selection - idx = torch.argsort(weights, descending=True) + idx = torch.argsort(weights, descending=False) weight_sorted, pop_sorted = weights[idx], population[idx] # cumulative sum of weights / weight per bucket rounded down gives us bucket ids eps = torch.finfo(weight_sorted.dtype).eps - c = weight_sorted.cumsum(0) - eps + c = weight_sorted.cumsum(0) + eps bin_id = torch.div(c, target_weight, rounding_mode='floor').clamp_min(0).long() - + # count elements in each bucket, and split input accordingly _, counts = torch.unique_consecutive(bin_id, return_counts=True) groups = torch.split(pop_sorted, counts.tolist()) diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index 847ae4a46..a761b2c31 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -3,19 +3,51 @@ from captum.testing.helpers.basic import BaseTest from parameterized import parameterized +import random +import statistics import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt + + +def visualize_tensor(tensor, cmap='viridis'): + """ + 
Simple heatmap visualizer for 2D PyTorch tensors. + Automatically moves tensor to CPU and detaches from graph. + """ + arr = tensor.detach().cpu().numpy() + plt.imshow(arr, cmap=cmap) + plt.colorbar() + plt.show() + class Test(BaseTest): # rename for convenience ts = torch.tensor depth_opts = range(4, 10) - n_partition_opts = range(3, 5) - n_search_opts = range(5, 15) + n_partition_opts = range(4, 5) + n_search_opts = range(10, 15) is_contiguous_opts = [False, True] all_options = list(itertools.product(depth_opts, n_partition_opts, n_search_opts, is_contiguous_opts)) + def _generate_gaussian_pdf(self, shape, mean): + k = len(shape) + + cov = 0.1 * torch.eye(k) * statistics.mean(shape) + dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov) + + grids = torch.meshgrid( + *[torch.arange(n, dtype=torch.float64) for n in shape], + indexing='ij' + ) + coords = torch.stack(grids, dim=-1).reshape(-1, k) + + pdf_vals = torch.exp(dist.log_prob(coords)) + return pdf_vals.reshape(*shape) + @parameterized.expand([ # inputs: baselines: (ts([1,2,3]), ts([[2,3], [3,4]])), @@ -43,7 +75,7 @@ def test_valid_input_baseline(self, input, baseline): for o in self.all_options: rex = ReX(lambda x: True) - attributions = rex.attribute(input, baseline, *o)[0] + attributions = rex.attribute(input, baseline, *o, merge=False)[0] inp_unwrapped = input if isinstance(input, tuple): inp_unwrapped = input[0] @@ -63,7 +95,7 @@ def test_selector_function(self, input, idx): for o in self.all_options: rex = ReX(lambda x: x[idx]) - attributions = rex.attribute(input, 0, *o)[0] + attributions = rex.attribute(input, 0, *o, merge=False)[0] self.assertTrue(attributions[idx] == 1) attributions[idx] = 0 @@ -81,27 +113,84 @@ def test_selector_function_large_input(self, input_shape, idx): rex = ReX(lambda x: x[idx]) input = torch.ones(*input_shape) - attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=3)[0] + attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=3, merge=False)[0] self.assertTrue(attributions[idx]) attributions[idx] = 0 self.assertLess(torch.sum(attributions, dim=None), 1) + @parameterized.expand([ # input shape # lhs_idx # rhs_idx ((2,4), (0,2), (1,3)) - - ]) def test_boolean_or(self, input_shape, lhs_idx, rhs_idx): for o in self.all_options: rex = ReX(lambda x: max(x[lhs_idx], x[rhs_idx])) input = torch.ones(input_shape) - attributions = rex.attribute(input, 0, *o)[0] + attributions = rex.attribute(input, 0, *o, merge=False)[0] - self.assertTrue(attributions[lhs_idx] > 0.25, f"{attributions}") - self.assertTrue(attributions[rhs_idx] > 0.25, f"{attributions}") + self.assertTrue(attributions[lhs_idx] == 1.0, f"{attributions}") + self.assertTrue(attributions[rhs_idx] == 1.0, f"{attributions}") attributions[lhs_idx] = 0 attributions[rhs_idx] = 0 self.assertTrue(torch.sum(attributions) < 1, f"{attributions}") + + + @parameterized.expand([ + # input shape # lhs_idx # rhs_idx + ((2,4), (0,2), (0,3)) + ]) + def test_boolean_and(self, input_shape, lhs_idx, rhs_idx): + for i, o in enumerate(self.all_options): + rex = ReX(lambda x: min(x[lhs_idx], x[rhs_idx])) + input = torch.ones(input_shape) + + attributions = rex.attribute(input, 0, *o, merge=False)[0] + + self.assertTrue(attributions[lhs_idx] == 0.5, f"{attributions}, {i}, {o}") + self.assertTrue(attributions[rhs_idx] == 0.5, f"{attributions}, {i}, {o}") + + attributions[lhs_idx] = 0 + attributions[rhs_idx] = 0 + self.assertTrue(torch.sum(attributions) < 1, f"{attributions}") + + + 
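For intuition, the values asserted in test_boolean_or and test_boolean_and above follow directly from the 1/(1+k) responsibility definition. A minimal standalone sketch (editorial, not part of the patch; it assumes the patched captum.attr._core.rex module is importable, and exact values depend on the search parameters):

import torch

from captum.attr._core.rex import ReX

x = torch.ones((2, 4))

# OR-like model: keeping either index alone already recovers the prediction,
# so the smallest witness is empty (k = 0) and the tests above expect
# responsibility 1 / (1 + 0) = 1.0 at both indices.
rex_or = ReX(lambda t: max(t[(0, 2)], t[(1, 3)]))
print(rex_or.attribute(x, 0, n_partitions=4, merge=False)[0])

# AND-like model: an index only matters once the other index is also kept
# (k = 1), so the tests above expect responsibility 1 / (1 + 1) = 0.5 at both.
rex_and = ReX(lambda t: min(t[(0, 2)], t[(0, 3)]))
print(rex_and.attribute(x, 0, n_partitions=4, merge=False)[0])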
@parameterized.expand([ + # shape # mean + # ((10,10), ts([4, 6])), + # ((50,50), ts([25, 25])), + # ((20,20), ts([10, 10])), + ((50, 50),) + ]) + def test_gaussian_recovery(self, shape): + random.seed() + p = torch.zeros(shape) + for _ in range(3): + center = self.ts([int(random.random() * dim) for dim in shape]) + p += self._generate_gaussian_pdf(shape, center) + + thresh = math.sqrt(torch.mean(p)) + def _forward(inp): + return 1 if torch.sum(inp, dim=None) > thresh else 0 + + rex = ReX(_forward) + for o in self.all_options[:20]: + # o = (o[0], o[1], 20, o[3]) + + attributions = rex.attribute(p, 0, *o, merge=True)[0] + eps = 1e-12 + + attributions += eps + attrib_norm = attributions / torch.sum(attributions) + + p += eps + p = p/torch.sum(p) + + # visualize_tensor(p) + # visualize_tensor(attrib_norm) + # visualize_tensor(p - attrib_norm) + # print(F.kl_div(p.log(), attrib_norm)) + + self.assertLess(F.kl_div(p.log(), attrib_norm), 0.1) From b596bea5dfcff0855ab69616275b7b793ed1d4a4 Mon Sep 17 00:00:00 2001 From: stav-af Date: Sat, 16 Aug 2025 16:27:33 +0100 Subject: [PATCH 09/21] Debug gaussian test --- tests/attr/test_rex.py | 46 +++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index a761b2c31..6b52cc49f 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -12,10 +12,6 @@ def visualize_tensor(tensor, cmap='viridis'): - """ - Simple heatmap visualizer for 2D PyTorch tensors. - Automatically moves tensor to CPU and detaches from graph. - """ arr = tensor.detach().cpu().numpy() plt.imshow(arr, cmap=cmap) plt.colorbar() @@ -27,7 +23,7 @@ class Test(BaseTest): ts = torch.tensor depth_opts = range(4, 10) - n_partition_opts = range(4, 5) + n_partition_opts = range(4, 7) n_search_opts = range(10, 15) is_contiguous_opts = [False, True] @@ -105,8 +101,8 @@ def test_selector_function(self, input, idx): @parameterized.expand([ # input shape # important idx ((4,4), (0,0)), - # ((12, 12, 12), (1,2,1)), - # ((12, 12, 12, 6), (1,1,4,1)), + ((12, 12, 12), (1,2,1)), + ((12, 12, 12, 6), (1,1,4,1)), ((1920, 1080), (1, 1)) # image-like ]) def test_selector_function_large_input(self, input_shape, idx): @@ -159,38 +155,46 @@ def test_boolean_and(self, input_shape, lhs_idx, rhs_idx): @parameterized.expand([ # shape # mean - # ((10,10), ts([4, 6])), - # ((50,50), ts([25, 25])), - # ((20,20), ts([10, 10])), - ((50, 50),) + ((30,30),), + ((50, 50),), + ((100,100),) ]) def test_gaussian_recovery(self, shape): random.seed() + eps = 1e-12 + p = torch.zeros(shape) for _ in range(3): center = self.ts([int(random.random() * dim) for dim in shape]) p += self._generate_gaussian_pdf(shape, center) + + p += eps + p = p/torch.sum(p) thresh = math.sqrt(torch.mean(p)) def _forward(inp): - return 1 if torch.sum(inp, dim=None) > thresh else 0 + return 1 if torch.sum(inp) > thresh else 0 rex = ReX(_forward) - for o in self.all_options[:20]: - # o = (o[0], o[1], 20, o[3]) + for b in self.n_partition_opts: + attributions = rex.attribute(p, + 0, + n_partitions=b, + search_depth=10, + n_searches=25, + contiguous_partitioning=True, + merge=True)[0] - attributions = rex.attribute(p, 0, *o, merge=True)[0] - eps = 1e-12 attributions += eps attrib_norm = attributions / torch.sum(attributions) - p += eps - p = p/torch.sum(p) - # visualize_tensor(p) # visualize_tensor(attrib_norm) # visualize_tensor(p - attrib_norm) - # print(F.kl_div(p.log(), attrib_norm)) + + mid = 0.5 * (p + attrib_norm) + jsd = 0.5 * F.kl_div(p.log(), 
mid, reduction="sum") \ + + 0.5 * F.kl_div(attrib_norm.log(), mid, reduction="sum") - self.assertLess(F.kl_div(p.log(), attrib_norm), 0.1) + self.assertLess(jsd, 0.1) From 9a1a416df5e5d1b602835813d6ae1580d70c6f33 Mon Sep 17 00:00:00 2001 From: stav-af Date: Mon, 18 Aug 2025 15:23:27 +0100 Subject: [PATCH 10/21] Cleanup, refine docstrings --- captum/attr/_core/rex.py | 95 ++++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index a0a3e5b97..75c3ca5e3 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -41,10 +41,6 @@ def generate_mask(self, shape): def __len__(self): return self.size - - def __str__(self): - # unsafe - return self.generate_mask(None).to(torch.int).__str__() @dataclass(eq=False) @@ -78,24 +74,19 @@ def _apply_responsibility(fi, part, responsibility): return torch.where(mask, distributed, fi) -def _part_to_set(partition): - return frozenset(frozenset(p) if isinstance(p, list) else p for p in partition) - - def _calculate_responsibility(subject_partition: Partition, mutants: List[Mutant], - recovery_mutants: List[Mutant]) -> float: - - - recovery_set = {_part_to_set(m.partitions) for m in recovery_mutants} + consistent_mutants: List[Mutant]) -> float: + recovery_set = {frozenset(m.partitions) for m in consistent_mutants} valid_witnesses = [] for m in mutants: if subject_partition in m.partitions: continue W = m.partitions - W_set = _part_to_set(W) - W_plus_P_set = _part_to_set([subject_partition] + W) + + W_set = frozenset(W) + W_plus_P_set = frozenset([subject_partition] + W) # W alone does NOT recover, but W ∪ {P} DOES recover. if (W_set not in recovery_set) and (W_plus_P_set in recovery_set): @@ -105,35 +96,36 @@ def _calculate_responsibility(subject_partition: Partition, return 0.0 k = min(len(w) for w in valid_witnesses) - # Responsibility per your definition: 1 / (1 + k) return 1.0 / (1.0 + float(k)) def _generate_indices(ts): + # return a tensor containing all indices in the input shape return torch.tensor(tuple(itertools.product(*(range(s) for s in ts.shape))), dtype=torch.long) class ReX(PerturbationAttribution): """ - A perturbation-based approach to computing attribution, based on the + A perturbation-based approach to computing attribution, derived from the Halpern-Pearl definition of Actual Causality[1]. - ReX works by partitioning the input space, and masking each partition with the baseline value. It is fully - model agnostic, and relies only on a 'forward_func' returning a scalar. - - Intuitively, if masking a partition changes the prediction of the model, then that partition has - some responsibility (attribution > 0). Such partially masked partitions are called - mutants. The responsibility of a partition is defined as 1/(1+k) where - k is a minimum number of occluded partitions in a mutant which make forward_func's - output dependednt on the subject partition. + ReX works through a recursive search on the input to find areas that are + most responsible[3] for a models prediction. ReX splits an input into "partitions", + and masks combinations of these partitions with baseline (neutral) values + to form "mutants". - Partitions with nonzero responsibility are recusrively re-partitioned and masked in a search. - The algorithm runs multiple such searches, where each subsequent search uses the previously - computed attribution map as a heuristic for partitioning. 
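The optional merging mentioned in this paragraph is implemented further down in this patch as an incremental-mean update over the per-search attribution maps. A minimal sketch of that update rule on made-up scalar values (editorial, not part of the patch):

# After processing the i-th value, prev holds the mean of the first i values;
# this is the same prev += (1 / i) * (attribution - prev) step used when merging.
prev = 0.0
for i, a in enumerate([1.0, 0.5, 0.75], start=1):  # hypothetical per-search values
    prev += (1 / i) * (a - prev)
print(prev)  # 0.75 == (1.0 + 0.5 + 0.75) / 3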
+ Intuitively, where masking a partition never changes a models + prediction, that partition is not responsible for the output. Conversely, where some + combination of masked partitions changes the prediction, each partition has responsibility 1/(1+k), where + k is the minimal number of *other* masked partitions required to create a dependence on a partition. + + Responsible partitions are recursively searched to refine responsibility estimates, and results + are (optionally) merged to produce the final attribution map. [1] - halpern 06 [2] - rex paper + [3] - responsibility and blame """ def __init__(self, forward_func): r""" @@ -155,38 +147,51 @@ def attribute(self, r""" Args: inputs: - An input or tuple of inputs whose corresponding output is to be explained. Each input + An input or tuple of inputs to be explain. Each input must be of the shape expected by the forward_func. Where multiple examples are provided, they must be listed in a tuple. baselines: - A neutral values to be used as occlusion values. Where a scalar is provided, it is used - as the masking value at each index. Where a tensor is provided, values are masked at - corresponding indices. Where a tuple of tensors is provided, it must be of the same length - as inputs; then baseline and input tensors are matched element-wise and treated as before. + A neutral values to be used as occlusion values. Where a scalar or tensor is provided, + they are broadcast to the input shape. Where tuples are provided, they are paired element-wise, + and must match the structure of the input search_depth (optional): - The maximum depth to which ReX will search. Where one is not provided, the default is 4 + The maximum depth to which ReX will refine responsibility estimates for causes. n_partitions (optional): - The number of partitions to be made out of the input at each search step. - This must be at most hte size of each input, and at least 1. + The maximum number of partitions to be made out of the input at each search step. + At least 1, and no larger than the partition size. Where ``contiguous partitioning`` is + set to False, partitions are created using previous attribution maps as heuristics. + + n_searches (optional): + The number of times the search is to be ran. + + contiguous_partitioning (optional): + If True, assumes locality of attribution and splits partitions contiguously along a dimension + (uselful for images). Otherwise, partitions are selected for element-wise using a greedy heuristic + approach. + + merge (optional): + If True, return the average of all search results across all n_searches. Otherwise return the + final search attribution without merging. 
""" + inputs, baselines = _format_input_baseline(inputs, baselines) _validate_input(inputs, baselines) - self._n_partitions = n_partitions - self._max_depth = search_depth - self._n_searches = n_searches + self._n_partitions = n_partitions + self._max_depth = search_depth + self._n_searches = n_searches self._is_contiguous = contiguous_partitioning - self._merge = merge + self._merge = merge is_input_tuple = isinstance(inputs, tuple) is_baseline_tuple = isinstance(baselines, tuple) attributions = [] - # match inputs and baselines, explain + # broadcast baselines, explain if is_input_tuple and is_baseline_tuple: for input, baseline in zip(inputs, baselines): attributions.append(self._explain(input, baseline)) @@ -203,7 +208,7 @@ def _explain(self, input, baseline): self._shape = input.shape self._size = input.numel() - initial_prediction = self.forward_func(input) + prediction = self.forward_func(input) prev_attribution = torch.full_like(input, 0.0, dtype=torch.float32) attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32) @@ -223,14 +228,11 @@ def _explain(self, input, baseline): partitions = self._contiguous_partition(prev_part, depth) \ if self._is_contiguous else self._partition(prev_part, attribution) - - mutants = [Mutant(ps, input, baseline, self._shape) for ps in _powerset(partitions)] - consistent_mutants = [mut for mut in mutants if self.forward_func(mut.data) == initial_prediction] + mutants = [Mutant(part, input, baseline, self._shape) for part in _powerset(partitions)] + consistent_mutants = [mut for mut in mutants if self.forward_func(mut.data) == prediction] - prev_part.generate_mask(self._shape) for part in partitions: resp = _calculate_responsibility(part, mutants, consistent_mutants) - attribution = _apply_responsibility(attribution, part, resp) if resp == 1 and len(part) > 1 and self._max_depth > depth: Q.append((part, depth + 1)) @@ -285,7 +287,6 @@ def _contiguous_partition(self, part, depth): break if split_dim == -1: return [part] - n_splits = min((dmax - dmin), self._n_partitions) - 1 # drop splits randomly From d5d75256c1af7980619f190e611a23b2d169dfe6 Mon Sep 17 00:00:00 2001 From: stav-af Date: Fri, 22 Aug 2025 11:26:09 +0100 Subject: [PATCH 11/21] Torchify --- captum/attr/_core/rex.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 75c3ca5e3..56c6f888f 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -23,10 +23,10 @@ def __init__(self, borders: List[slice] = None, elements=None, size=None): self._mask = None - def generate_mask(self, shape): + def generate_mask(self, shape, device): # generates a mask for a partition (False indicates membership) if self._mask is not None: return self._mask - self._mask = torch.zeros(shape, dtype=torch.bool) + self._mask = torch.zeros(shape, dtype=torch.bool, device=device) # non-contiguous case if self.elements is not None: @@ -52,8 +52,8 @@ class Mutant: def __init__(self, partitions: List[Partition], data: torch.Tensor, neutral, shape): # A bitmap in the shape of the input indicating membership to a partition in this mutant - mask = torch.zeros(shape, dtype=torch.bool) - for part in partitions: mask |= part.generate_mask(mask.shape) + mask = torch.zeros(shape, dtype=torch.bool, device=data.device) + for part in partitions: mask |= part.generate_mask(mask.shape, data.device) self.partitions = partitions self.data = torch.where(mask, data, neutral) @@ -69,7 +69,7 @@ def _powerset(s): 
def _apply_responsibility(fi, part, responsibility): distributed = responsibility / len(part) - mask = part.generate_mask(fi.shape) + mask = part.generate_mask(fi.shape, fi.device) return torch.where(mask, distributed, fi) @@ -203,15 +203,16 @@ def attribute(self, return _format_output(is_input_tuple, tuple(attributions)) - + @torch.no_grad() def _explain(self, input, baseline): + self._device = input.device self._shape = input.shape self._size = input.numel() prediction = self.forward_func(input) - prev_attribution = torch.full_like(input, 0.0, dtype=torch.float32) - attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32) + prev_attribution = torch.full_like(input, 0.0, dtype=torch.float32, device=self._device) + attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32, device=self._device) initial_partition = Partition( borders = list((0, top) for top in self._shape), @@ -232,7 +233,7 @@ def _explain(self, input, baseline): consistent_mutants = [mut for mut in mutants if self.forward_func(mut.data) == prediction] for part in partitions: - resp = _calculate_responsibility(part, mutants, consistent_mutants) + resp = _calculate_responsibility(part, mutants, consistent_mutants) if resp == 1 and len(part) > 1 and self._max_depth > depth: Q.append((part, depth + 1)) @@ -248,11 +249,11 @@ def _explain(self, input, baseline): def _partition(self, part: Partition, responsibility: torch.Tensor) -> List[Partition]: # shuffle candidate indices (randomize tiebreakers) - perm = torch.randperm(len(part.elements)) + perm = torch.randperm(len(part.elements), device=self._device) population = part.elements[perm] weights = responsibility[tuple(population.T)] - if torch.sum(weights, dim=None) == 0: weights = torch.ones_like(weights) / len(weights) + if torch.sum(weights, dim=None) == 0: weights = torch.ones_like(weights, device=self._device) / len(weights) target_weight = torch.sum(weights) / self._n_partitions # sort for greedy selection From be66dea0adb930b14a41d9c76e1cbd6d676cd676 Mon Sep 17 00:00:00 2001 From: stav-af Date: Sat, 23 Aug 2025 12:45:54 +0100 Subject: [PATCH 12/21] test cleanup --- tests/attr/test_rex.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index 6b52cc49f..a769abb11 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -77,7 +77,7 @@ def test_valid_input_baseline(self, input, baseline): if isinstance(input, tuple): inp_unwrapped = input[0] # Forward_func returns a constant, no responsibility in input - self.assertFalse(torch.sum(attributions, dim=None)) + self.assertEqual(torch.sum(attributions), 0) self.assertEqual(attributions.size(), inp_unwrapped.size()) @@ -95,7 +95,7 @@ def test_selector_function(self, input, idx): self.assertTrue(attributions[idx] == 1) attributions[idx] = 0 - self.assertFalse(torch.sum(attributions, dim=None)) + self.assertEqual(torch.sum(attributions), 0) @parameterized.expand([ @@ -110,7 +110,7 @@ def test_selector_function_large_input(self, input_shape, idx): input = torch.ones(*input_shape) attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=3, merge=False)[0] - self.assertTrue(attributions[idx]) + self.assertGreater(attributions[idx], 0) attributions[idx] = 0 self.assertLess(torch.sum(attributions, dim=None), 1) @@ -126,12 +126,12 @@ def test_boolean_or(self, input_shape, lhs_idx, rhs_idx): attributions = rex.attribute(input, 0, *o, merge=False)[0] - 
self.assertTrue(attributions[lhs_idx] == 1.0, f"{attributions}") - self.assertTrue(attributions[rhs_idx] == 1.0, f"{attributions}") + self.assertEqual(attributions[lhs_idx], 1.0, f"{attributions}") + self.assertEqual(attributions[rhs_idx], 1.0, f"{attributions}") attributions[lhs_idx] = 0 attributions[rhs_idx] = 0 - self.assertTrue(torch.sum(attributions) < 1, f"{attributions}") + self.assertLess(torch.sum(attributions), 1, f"{attributions}") @parameterized.expand([ @@ -145,12 +145,12 @@ def test_boolean_and(self, input_shape, lhs_idx, rhs_idx): attributions = rex.attribute(input, 0, *o, merge=False)[0] - self.assertTrue(attributions[lhs_idx] == 0.5, f"{attributions}, {i}, {o}") - self.assertTrue(attributions[rhs_idx] == 0.5, f"{attributions}, {i}, {o}") + self.assertEqual(attributions[lhs_idx], 0.5, f"{attributions}, {i}, {o}") + self.assertEqual(attributions[rhs_idx], 0.5, f"{attributions}, {i}, {o}") attributions[lhs_idx] = 0 attributions[rhs_idx] = 0 - self.assertTrue(torch.sum(attributions) < 1, f"{attributions}") + self.assertLess(torch.sum(attributions), 1, f"{attributions}") @parameterized.expand([ @@ -189,9 +189,9 @@ def _forward(inp): attributions += eps attrib_norm = attributions / torch.sum(attributions) - # visualize_tensor(p) - # visualize_tensor(attrib_norm) - # visualize_tensor(p - attrib_norm) + visualize_tensor(p) + visualize_tensor(attrib_norm) + visualize_tensor(p - attrib_norm) mid = 0.5 * (p + attrib_norm) jsd = 0.5 * F.kl_div(p.log(), mid, reduction="sum") \ From 681b395b18ac94944c8c520bdba995f60af32525 Mon Sep 17 00:00:00 2001 From: stav-af Date: Tue, 26 Aug 2025 12:35:34 +0100 Subject: [PATCH 13/21] checkpoint --- captum/attr/_core/rex.py | 33 ++++++++++++++------------------- tests/attr/test_rex.py | 25 ++++++++++++------------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 56c6f888f..c2160327d 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -142,8 +142,7 @@ def attribute(self, search_depth: int = 10, n_partitions: int = 4, n_searches: int = 5, - contiguous_partitioning: bool = False, - merge: bool = True) -> TensorOrTupleOfTensorsGeneric: + assume_locality: bool = False) -> TensorOrTupleOfTensorsGeneric: r""" Args: inputs: @@ -167,14 +166,10 @@ def attribute(self, n_searches (optional): The number of times the search is to be ran. - contiguous_partitioning (optional): - If True, assumes locality of attribution and splits partitions contiguously along a dimension - (uselful for images). Otherwise, partitions are selected for element-wise using a greedy heuristic - approach. - - merge (optional): - If True, return the average of all search results across all n_searches. Otherwise return the - final search attribution without merging. + assume_locality (optional): + Where True, partitioning is contiguous, and attribution maps are merged after each serach. + Otherwise, partitioning is initially random, then uses the previous attribution map + as a heuristic for further searches, returning the result of the final search. 
""" inputs, baselines = _format_input_baseline(inputs, baselines) @@ -183,8 +178,7 @@ def attribute(self, self._n_partitions = n_partitions self._max_depth = search_depth self._n_searches = n_searches - self._is_contiguous = contiguous_partitioning - self._merge = merge + self._assume_locality = assume_locality is_input_tuple = isinstance(inputs, tuple) is_baseline_tuple = isinstance(baselines, tuple) @@ -212,7 +206,7 @@ def _explain(self, input, baseline): prediction = self.forward_func(input) prev_attribution = torch.full_like(input, 0.0, dtype=torch.float32, device=self._device) - attribution = torch.full_like(input, 1.0/input.numel(), dtype=torch.float32, device=self._device) + attribution = torch.full_like(input, 1.0/self._size, dtype=torch.float32, device=self._device) initial_partition = Partition( borders = list((0, top) for top in self._shape), @@ -227,22 +221,23 @@ def _explain(self, input, baseline): while Q: prev_part, depth = Q.popleft() partitions = self._contiguous_partition(prev_part, depth) \ - if self._is_contiguous else self._partition(prev_part, attribution) + if self._assume_locality else self._partition(prev_part, attribution) mutants = [Mutant(part, input, baseline, self._shape) for part in _powerset(partitions)] consistent_mutants = [mut for mut in mutants if self.forward_func(mut.data) == prediction] for part in partitions: resp = _calculate_responsibility(part, mutants, consistent_mutants) - + attribution = _apply_responsibility(attribution, part, resp) + if resp == 1 and len(part) > 1 and self._max_depth > depth: Q.append((part, depth + 1)) - else: - attribution = _apply_responsibility(attribution, part, resp) - - if self._merge: + + # take average of responsibility maps + if self._assume_locality: prev_attribution += (1/i) * (attribution - prev_attribution) attribution = prev_attribution + return attribution.clone().detach() diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index a769abb11..9c3e2996a 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -25,9 +25,9 @@ class Test(BaseTest): depth_opts = range(4, 10) n_partition_opts = range(4, 7) n_search_opts = range(10, 15) - is_contiguous_opts = [False, True] + assume_locality_opts = [True, False] - all_options = list(itertools.product(depth_opts, n_partition_opts, n_search_opts, is_contiguous_opts)) + all_options = list(itertools.product(depth_opts, n_partition_opts, n_search_opts, assume_locality_opts)) def _generate_gaussian_pdf(self, shape, mean): k = len(shape) @@ -71,7 +71,7 @@ def test_valid_input_baseline(self, input, baseline): for o in self.all_options: rex = ReX(lambda x: True) - attributions = rex.attribute(input, baseline, *o, merge=False)[0] + attributions = rex.attribute(input, baseline, *o)[0] inp_unwrapped = input if isinstance(input, tuple): inp_unwrapped = input[0] @@ -91,8 +91,8 @@ def test_selector_function(self, input, idx): for o in self.all_options: rex = ReX(lambda x: x[idx]) - attributions = rex.attribute(input, 0, *o, merge=False)[0] - self.assertTrue(attributions[idx] == 1) + attributions = rex.attribute(input, 0, *o)[0] + self.assertEqual(attributions[idx], 1, f"expected 1 at {idx} but found {attributions}") attributions[idx] = 0 self.assertEqual(torch.sum(attributions), 0) @@ -109,7 +109,7 @@ def test_selector_function_large_input(self, input_shape, idx): rex = ReX(lambda x: x[idx]) input = torch.ones(*input_shape) - attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=3, merge=False)[0] + attributions = rex.attribute(input, 
0, n_partitions=2, search_depth=10, n_searches=3)[0] self.assertGreater(attributions[idx], 0) attributions[idx] = 0 self.assertLess(torch.sum(attributions, dim=None), 1) @@ -124,7 +124,7 @@ def test_boolean_or(self, input_shape, lhs_idx, rhs_idx): rex = ReX(lambda x: max(x[lhs_idx], x[rhs_idx])) input = torch.ones(input_shape) - attributions = rex.attribute(input, 0, *o, merge=False)[0] + attributions = rex.attribute(input, 0, *o)[0] self.assertEqual(attributions[lhs_idx], 1.0, f"{attributions}") self.assertEqual(attributions[rhs_idx], 1.0, f"{attributions}") @@ -143,7 +143,7 @@ def test_boolean_and(self, input_shape, lhs_idx, rhs_idx): rex = ReX(lambda x: min(x[lhs_idx], x[rhs_idx])) input = torch.ones(input_shape) - attributions = rex.attribute(input, 0, *o, merge=False)[0] + attributions = rex.attribute(input, 0, *o)[0] self.assertEqual(attributions[lhs_idx], 0.5, f"{attributions}, {i}, {o}") self.assertEqual(attributions[rhs_idx], 0.5, f"{attributions}, {i}, {o}") @@ -182,16 +182,15 @@ def _forward(inp): n_partitions=b, search_depth=10, n_searches=25, - contiguous_partitioning=True, - merge=True)[0] + assume_locality=True)[0] attributions += eps attrib_norm = attributions / torch.sum(attributions) - visualize_tensor(p) - visualize_tensor(attrib_norm) - visualize_tensor(p - attrib_norm) + # visualize_tensor(p) + # visualize_tensor(attrib_norm) + # visualize_tensor(p - attrib_norm) mid = 0.5 * (p + attrib_norm) jsd = 0.5 * F.kl_div(p.log(), mid, reduction="sum") \ From f3d59b26feb94c6689c70a51625d204c94216426 Mon Sep 17 00:00:00 2001 From: stav-af Date: Sat, 30 Aug 2025 13:05:45 +0100 Subject: [PATCH 14/21] Refinements --- captum/attr/_core/rex.py | 9 +++++---- tests/attr/test_rex.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index c2160327d..9523d4acc 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -101,7 +101,7 @@ def _calculate_responsibility(subject_partition: Partition, def _generate_indices(ts): # return a tensor containing all indices in the input shape - return torch.tensor(tuple(itertools.product(*(range(s) for s in ts.shape))), dtype=torch.long) + return torch.tensor(tuple(itertools.product(*(range(s) for s in ts.shape))), dtype=torch.long, device=ts.device) class ReX(PerturbationAttribution): @@ -109,14 +109,15 @@ class ReX(PerturbationAttribution): A perturbation-based approach to computing attribution, derived from the Halpern-Pearl definition of Actual Causality[1]. - ReX works through a recursive search on the input to find areas that are + ReX conducts a recursive search on the input to find areas that are most responsible[3] for a models prediction. ReX splits an input into "partitions", and masks combinations of these partitions with baseline (neutral) values to form "mutants". Intuitively, where masking a partition never changes a models prediction, that partition is not responsible for the output. Conversely, where some - combination of masked partitions changes the prediction, each partition has responsibility 1/(1+k), where + combination of masked partitions changes the prediction, each partition has some responsibility. + Specifically, their responsibility is 1/(1+k) where k is the minimal number of *other* masked partitions required to create a dependence on a partition. 
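(Illustrative worked example, not part of the patch: a partition that changes the prediction on its own has k = 0 and responsibility 1; one that only becomes decisive when two other partitions are masked alongside it has k = 2 and responsibility 1/(1+2) = 1/3.)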
Responsible partitions are recursively searched to refine responsibility estimates, and results @@ -125,7 +126,7 @@ class ReX(PerturbationAttribution): [1] - halpern 06 [2] - rex paper - [3] - responsibility and blame + [3] - Responsibility and Blame; https://arxiv.org/pdf/cs/0312038 """ def __init__(self, forward_func): r""" diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index 9c3e2996a..96ac258a2 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -1,13 +1,13 @@ -from captum.attr._core.rex import * +from captum.attr._core.rex import ReX from captum.testing.helpers.basic import BaseTest from parameterized import parameterized +import math import random import statistics import torch import torch.nn.functional as F -import numpy as np import matplotlib.pyplot as plt @@ -112,7 +112,7 @@ def test_selector_function_large_input(self, input_shape, idx): attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=3)[0] self.assertGreater(attributions[idx], 0) attributions[idx] = 0 - self.assertLess(torch.sum(attributions, dim=None), 1) + self.assertLess(torch.sum(attributions), 1) @parameterized.expand([ From 674b14ae252dac27c8ede7fe22dea61128ddb1ab Mon Sep 17 00:00:00 2001 From: stav-af Date: Sat, 30 Aug 2025 13:09:06 +0100 Subject: [PATCH 15/21] Self-hosted --- .github/workflows/lint.yml | 1 + .github/workflows/test-conda-cpu.yml | 1 + .github/workflows/test-pip-cpu-with-type-checks.yml | 1 + .github/workflows/test-pip-cpu.yml | 1 + .github/workflows/test-pip-gpu.yml | 1 + 5 files changed, 5 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ab4d71bc6..9eb4cbe98 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -10,6 +10,7 @@ on: jobs: tests: + runs-on: self-hosted uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.12xlarge diff --git a/.github/workflows/test-conda-cpu.yml b/.github/workflows/test-conda-cpu.yml index e0da5e42e..6c0a73c1e 100644 --- a/.github/workflows/test-conda-cpu.yml +++ b/.github/workflows/test-conda-cpu.yml @@ -13,6 +13,7 @@ env: jobs: tests: + runs-on: ubuntu-latest strategy: matrix: python_version: ["3.9", "3.10", "3.11", "3.12"] diff --git a/.github/workflows/test-pip-cpu-with-type-checks.yml b/.github/workflows/test-pip-cpu-with-type-checks.yml index 3336a76f6..284abb45b 100644 --- a/.github/workflows/test-pip-cpu-with-type-checks.yml +++ b/.github/workflows/test-pip-cpu-with-type-checks.yml @@ -10,6 +10,7 @@ on: jobs: tests: + runs-on: ubuntu-latest strategy: matrix: pytorch_args: ["", "-n"] diff --git a/.github/workflows/test-pip-cpu.yml b/.github/workflows/test-pip-cpu.yml index 83a513ac2..fa3042368 100644 --- a/.github/workflows/test-pip-cpu.yml +++ b/.github/workflows/test-pip-cpu.yml @@ -10,6 +10,7 @@ on: jobs: tests: + runs-on: ubuntu-latest strategy: matrix: pytorch_args: ["-v 1.10", "-v 1.11", "-v 1.12", "-v 1.13", "-v 2.0.0", "-v 2.1.0", "-v 2.2.0", "-v 2.3.0"] diff --git a/.github/workflows/test-pip-gpu.yml b/.github/workflows/test-pip-gpu.yml index 117f515f4..663c6a3bb 100644 --- a/.github/workflows/test-pip-gpu.yml +++ b/.github/workflows/test-pip-gpu.yml @@ -10,6 +10,7 @@ on: jobs: tests: + runs-on: ubuntu-latest strategy: matrix: cuda_arch_version: ["12.1"] From c47336715037c12435d46ab998e388df27b5376e Mon Sep 17 00:00:00 2001 From: stav-af Date: Sat, 30 Aug 2025 13:13:20 +0100 Subject: [PATCH 16/21] Retry --- .github/workflows/lint.yml | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 
deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9eb4cbe98..256e0f1cd 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -3,22 +3,24 @@ name: Captum Lint on: pull_request: push: - branches: - - master - + branches: [ master ] workflow_dispatch: jobs: tests: + # Make sure your runner has these labels (or just "self-hosted") runs-on: self-hosted - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main - with: - runner: linux.12xlarge - docker-image: cimg/python:3.11 - repository: pytorch/captum - script: | - sudo chmod -R 777 . - ./scripts/install_via_pip.sh - ufmt check . - flake8 - sphinx-build -WT --keep-going sphinx/source sphinx/build + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Lint & docs + run: | + sudo chmod -R 777 . + ./scripts/install_via_pip.sh + ufmt check . + flake8 + sphinx-build -WT --keep-going sphinx/source sphinx/build From 81e4f571a2c821dff31a67439de735e8934ec28b Mon Sep 17 00:00:00 2001 From: stav-af Date: Sat, 30 Aug 2025 13:17:12 +0100 Subject: [PATCH 17/21] CI --- .github/workflows/lint.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 256e0f1cd..e21e36fe4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,11 +12,6 @@ jobs: runs-on: self-hosted steps: - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - name: Lint & docs run: | sudo chmod -R 777 . From 287283afeca5bdbc9bb610ae13108dd10d2da436 Mon Sep 17 00:00:00 2001 From: stav-af Date: Sun, 31 Aug 2025 11:19:40 +0100 Subject: [PATCH 18/21] CI Compat --- captum/attr/_core/rex.py | 256 +++++++++++++++++++++++---------------- tests/attr/test_rex.py | 185 +++++++++++++++------------- 2 files changed, 257 insertions(+), 184 deletions(-) diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index 9523d4acc..cae57d765 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -2,46 +2,57 @@ # pyre-strict import itertools -from typing import List -import torch -from collections import deque -import random import math +import random +from collections import deque from dataclasses import dataclass +from typing import Tuple, cast, List, Sized -from captum.attr._utils.attribution import PerturbationAttribution +import torch from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric -from captum.attr._utils.common import _format_input_baseline, _validate_input, _format_output + +from captum.attr._utils.attribution import PerturbationAttribution +from captum.attr._utils.common import ( + _format_input_baseline, + _format_output, + _validate_input, +) from captum.log.dummy_log import log_usage -class Partition: - def __init__(self, borders: List[slice] = None, elements=None, size=None): +class Partition(Sized): + def __init__( + self, + borders: None | List[Tuple[int, int]] = None, + elements: None | torch.Tensor = None, + size: int = -1, + ): self.borders = borders self.elements = elements self.size = size - + self._mask = None def generate_mask(self, shape, device): # generates a mask for a partition (False indicates membership) - if self._mask is not None: return self._mask + if self._mask is not None: + return self._mask self._mask = torch.zeros(shape, dtype=torch.bool, device=device) # non-contiguous case if self.elements is not None: self._mask[tuple(self.elements.T)] = True - + # contiguous case 
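        # borders is a list of (lo, hi) pairs, one per input dimension; the
        # per-axis slices below select an axis-aligned block of the input for
        # this partition.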
elif self.borders is not None: slices = list(slice(lo, hi) for (lo, hi) in self.borders) self._mask[slices] = True - + return self._mask - + def __len__(self): return self.size - + @dataclass(eq=False) class Mutant: @@ -50,10 +61,12 @@ class Mutant: # eagerly create the underlying mutant data def __init__(self, partitions: List[Partition], data: torch.Tensor, neutral, shape): - - # A bitmap in the shape of the input indicating membership to a partition in this mutant + + # A bitmap in the shape of the input indicating membership to + # a partition in this mutant mask = torch.zeros(shape, dtype=torch.bool, device=data.device) - for part in partitions: mask |= part.generate_mask(mask.shape, data.device) + for part in partitions: + mask |= part.generate_mask(mask.shape, data.device) self.partitions = partitions self.data = torch.where(mask, data, neutral) @@ -63,8 +76,9 @@ def __len__(self): def _powerset(s): - return (list(combo) for r in range(len(s)+1) - for combo in itertools.combinations(s, r)) + return ( + list(combo) for r in range(len(s) + 1) for combo in itertools.combinations(s, r) + ) def _apply_responsibility(fi, part, responsibility): @@ -74,9 +88,11 @@ def _apply_responsibility(fi, part, responsibility): return torch.where(mask, distributed, fi) -def _calculate_responsibility(subject_partition: Partition, - mutants: List[Mutant], - consistent_mutants: List[Mutant]) -> float: +def _calculate_responsibility( + subject_partition: Partition, + mutants: List[Mutant], + consistent_mutants: List[Mutant], +) -> float: recovery_set = {frozenset(m.partitions) for m in consistent_mutants} valid_witnesses = [] @@ -84,7 +100,7 @@ def _calculate_responsibility(subject_partition: Partition, if subject_partition in m.partitions: continue W = m.partitions - + W_set = frozenset(W) W_plus_P_set = frozenset([subject_partition] + W) @@ -101,33 +117,39 @@ def _calculate_responsibility(subject_partition: Partition, def _generate_indices(ts): # return a tensor containing all indices in the input shape - return torch.tensor(tuple(itertools.product(*(range(s) for s in ts.shape))), dtype=torch.long, device=ts.device) + return torch.tensor( + tuple(itertools.product(*(range(s) for s in ts.shape))), + dtype=torch.long, + device=ts.device, + ) class ReX(PerturbationAttribution): """ A perturbation-based approach to computing attribution, derived from the - Halpern-Pearl definition of Actual Causality[1]. - - ReX conducts a recursive search on the input to find areas that are - most responsible[3] for a models prediction. ReX splits an input into "partitions", + Halpern-Pearl definition of Actual Causality[1]. + + ReX conducts a recursive search on the input to find areas that are + most responsible[3] for a models prediction. ReX splits an input into "partitions", and masks combinations of these partitions with baseline (neutral) values - to form "mutants". - + to form "mutants". + Intuitively, where masking a partition never changes a models - prediction, that partition is not responsible for the output. Conversely, where some - combination of masked partitions changes the prediction, each partition has some responsibility. - Specifically, their responsibility is 1/(1+k) where - k is the minimal number of *other* masked partitions required to create a dependence on a partition. + prediction, that partition is not responsible for the output. Conversely, + where some combination of masked partitions changes the prediction, each + partition has some responsibility. 
Specifically, their responsibility is 1/(1+k) + where k is the minimal number of *other* masked partitions required to create + a dependence on a partition. - Responsible partitions are recursively searched to refine responsibility estimates, and results - are (optionally) merged to produce the final attribution map. + Responsible partitions are recursively searched to refine responsibility estimates, + and results are (optionally) merged to produce the final attribution map. [1] - halpern 06 [2] - rex paper [3] - Responsibility and Blame; https://arxiv.org/pdf/cs/0312038 """ + def __init__(self, forward_func): r""" Args: @@ -137,49 +159,55 @@ def __init__(self, forward_func): PerturbationAttribution.__init__(self, forward_func) @log_usage(part_of_slo=True) - def attribute(self, - inputs: TensorOrTupleOfTensorsGeneric, - baselines: BaselineType = 0, - search_depth: int = 10, - n_partitions: int = 4, - n_searches: int = 5, - assume_locality: bool = False) -> TensorOrTupleOfTensorsGeneric: + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = 0, + search_depth: int = 10, + n_partitions: int = 4, + n_searches: int = 5, + assume_locality: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: r""" Args: inputs: An input or tuple of inputs to be explain. Each input - must be of the shape expected by the forward_func. Where multiple examples are - provided, they must be listed in a tuple. - - baselines: - A neutral values to be used as occlusion values. Where a scalar or tensor is provided, - they are broadcast to the input shape. Where tuples are provided, they are paired element-wise, - and must match the structure of the input + must be of the shape expected by the forward_func. Where multiple + examples are provided, they must be listed in a tuple. + + baselines: + A neutral values to be used as occlusion values. Where a scalar or + tensor is provided, they are broadcast to the input shape. Where + tuples are provided, they are paired element-wise, and must match + the structure of the input. search_depth (optional): - The maximum depth to which ReX will refine responsibility estimates for causes. - + The maximum depth to which ReX will refine responsibility estimates + for causes. + n_partitions (optional): - The maximum number of partitions to be made out of the input at each search step. - At least 1, and no larger than the partition size. Where ``contiguous partitioning`` is - set to False, partitions are created using previous attribution maps as heuristics. + The maximum number of partitions to be made out of the input at each + search step. At least 1, and no larger than the partition size. Where + ``contiguous partitioning`` is set to False, partitions are created + using previous attribution maps as heuristics. n_searches (optional): The number of times the search is to be ran. - + assume_locality (optional): - Where True, partitioning is contiguous, and attribution maps are merged after each serach. - Otherwise, partitioning is initially random, then uses the previous attribution map - as a heuristic for further searches, returning the result of the final search. + Where True, partitioning is contiguous and attribution maps are merged + after each search. Otherwise, partitioning is initially random, then + uses the previous attribution map as a heuristic for further searches, + returning the result of the final search. 
""" inputs, baselines = _format_input_baseline(inputs, baselines) _validate_input(inputs, baselines) - self._n_partitions = n_partitions - self._max_depth = search_depth - self._n_searches = n_searches - self._assume_locality = assume_locality + self._n_partitions: int = n_partitions + self._max_depth: int = search_depth + self._n_searches: int = n_searches + self._assume_locality: bool = assume_locality is_input_tuple = isinstance(inputs, tuple) is_baseline_tuple = isinstance(baselines, tuple) @@ -196,60 +224,81 @@ def attribute(self, else: attributions.append(self._explain(inputs, baselines)) - return _format_output(is_input_tuple, tuple(attributions)) + return cast( + TensorOrTupleOfTensorsGeneric, + _format_output(is_input_tuple, tuple(attributions)), + ) @torch.no_grad() - def _explain(self, input, baseline): + def _explain(self, input, baseline) -> torch.Tensor: self._device = input.device self._shape = input.shape self._size = input.numel() prediction = self.forward_func(input) - prev_attribution = torch.full_like(input, 0.0, dtype=torch.float32, device=self._device) - attribution = torch.full_like(input, 1.0/self._size, dtype=torch.float32, device=self._device) + prev_attribution = torch.full_like( + input, 0.0, dtype=torch.float32, device=self._device + ) + attribution = torch.full_like( + input, 1.0 / self._size, dtype=torch.float32, device=self._device + ) initial_partition = Partition( - borders = list((0, top) for top in self._shape), - elements = _generate_indices(input), - size = self._size + borders=list((0, top) for top in self._shape), + elements=_generate_indices(input), + size=self._size, ) for i in range(1, self._n_searches + 1): - Q = deque() + Q: deque = deque() Q.append((initial_partition, 0)) while Q: prev_part, depth = Q.popleft() - partitions = self._contiguous_partition(prev_part, depth) \ - if self._assume_locality else self._partition(prev_part, attribution) - - mutants = [Mutant(part, input, baseline, self._shape) for part in _powerset(partitions)] - consistent_mutants = [mut for mut in mutants if self.forward_func(mut.data) == prediction] + partitions = ( + self._contiguous_partition(prev_part, depth) + if self._assume_locality + else self._partition(prev_part, attribution) + ) + + mutants = [ + Mutant(part, input, baseline, self._shape) + for part in _powerset(partitions) + ] + consistent_mutants = [ + mut for mut in mutants if self.forward_func(mut.data) == prediction + ] for part in partitions: resp = _calculate_responsibility(part, mutants, consistent_mutants) attribution = _apply_responsibility(attribution, part, resp) - + if resp == 1 and len(part) > 1 and self._max_depth > depth: Q.append((part, depth + 1)) # take average of responsibility maps - if self._assume_locality: - prev_attribution += (1/i) * (attribution - prev_attribution) + if self._assume_locality: + prev_attribution += (1 / i) * (attribution - prev_attribution) attribution = prev_attribution - return attribution.clone().detach() - - def _partition(self, part: Partition, responsibility: torch.Tensor) -> List[Partition]: + def _partition( + self, part: Partition, responsibility: torch.Tensor + ) -> List[Partition]: # shuffle candidate indices (randomize tiebreakers) - perm = torch.randperm(len(part.elements), device=self._device) - population = part.elements[perm] + perm = torch.randperm(len(part), device=self._device) + + assert ( + part is not None + ), "Partitioning strategy changed mid search. 
Contact developer" + population = part.elements[perm] # type: ignore + weights = responsibility[tuple(population.T)] - - if torch.sum(weights, dim=None) == 0: weights = torch.ones_like(weights, device=self._device) / len(weights) + + if torch.sum(weights, dim=None) == 0: + weights = torch.ones_like(weights, device=self._device) / len(weights) target_weight = torch.sum(weights) / self._n_partitions # sort for greedy selection @@ -257,23 +306,22 @@ def _partition(self, part: Partition, responsibility: torch.Tensor) -> List[Part weight_sorted, pop_sorted = weights[idx], population[idx] # cumulative sum of weights / weight per bucket rounded down gives us bucket ids - eps = torch.finfo(weight_sorted.dtype).eps + eps = torch.finfo(weight_sorted.dtype).eps c = weight_sorted.cumsum(0) + eps - bin_id = torch.div(c, target_weight, rounding_mode='floor').clamp_min(0).long() - + bin_id = torch.div(c, target_weight, rounding_mode="floor").clamp_min(0).long() + # count elements in each bucket, and split input accordingly _, counts = torch.unique_consecutive(bin_id, return_counts=True) groups = torch.split(pop_sorted, counts.tolist()) - + partitions = [Partition(elements=g, size=len(g)) for g in groups] return partitions - def _contiguous_partition(self, part, depth): ndim = len(self._shape) split_dim = -1 - # find a dimension we can split + # find a dimension we can split dmin, dmax = max(self._shape), 0 for i in range(ndim): candidate_dim = (i + depth) % ndim @@ -283,7 +331,8 @@ def _contiguous_partition(self, part, depth): split_dim = candidate_dim break - if split_dim == -1: return [part] + if split_dim == -1: + return [part] n_splits = min((dmax - dmin), self._n_partitions) - 1 # drop splits randomly @@ -293,18 +342,19 @@ def _contiguous_partition(self, part, depth): bins = [] for i in range(len(split_borders) - 1): new_borders = list(part.borders) - new_borders[split_dim] = (split_borders[i], split_borders[i+1]) + new_borders[split_dim] = (split_borders[i], split_borders[i + 1]) - bins.append(Partition( - borders = tuple(new_borders), - size = math.prod(hi - lo for (lo, hi) in new_borders) - )) + bins.append( + Partition( + borders=tuple(new_borders), + size=math.prod(hi - lo for (lo, hi) in new_borders), + ) + ) return bins + def multiplies_by_inputs(self): + return False - def multiplies_by_inputs(self) -> bool: - return False - - def has_convergence_delta(self) -> bool: - return True \ No newline at end of file + def has_convergence_delta(self): + return True diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py index 96ac258a2..b6fc82f9c 100644 --- a/tests/attr/test_rex.py +++ b/tests/attr/test_rex.py @@ -1,17 +1,18 @@ -from captum.attr._core.rex import ReX - -from captum.testing.helpers.basic import BaseTest -from parameterized import parameterized - +import itertools import math import random import statistics + +import matplotlib.pyplot as plt import torch import torch.nn.functional as F -import matplotlib.pyplot as plt +from captum.attr._core.rex import ReX +from captum.testing.helpers.basic import BaseTest +from parameterized import parameterized -def visualize_tensor(tensor, cmap='viridis'): + +def visualize_tensor(tensor, cmap="viridis"): arr = tensor.detach().cpu().numpy() plt.imshow(arr, cmap=cmap) plt.colorbar() @@ -21,13 +22,17 @@ def visualize_tensor(tensor, cmap='viridis'): class Test(BaseTest): # rename for convenience ts = torch.tensor - + depth_opts = range(4, 10) n_partition_opts = range(4, 7) n_search_opts = range(10, 15) assume_locality_opts = [True, False] - 
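    # Note: the Cartesian product built below enumerates every combination of the
    # option lists above (6 * 3 * 5 * 2 = 180 configurations per parameterized case).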
all_options = list(itertools.product(depth_opts, n_partition_opts, n_search_opts, assume_locality_opts)) + all_options = list( + itertools.product( + depth_opts, n_partition_opts, n_search_opts, assume_locality_opts + ) + ) def _generate_gaussian_pdf(self, shape, mean): k = len(shape) @@ -36,37 +41,42 @@ def _generate_gaussian_pdf(self, shape, mean): dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov) grids = torch.meshgrid( - *[torch.arange(n, dtype=torch.float64) for n in shape], - indexing='ij' + *[torch.arange(n, dtype=torch.float64) for n in shape], indexing="ij" ) coords = torch.stack(grids, dim=-1).reshape(-1, k) pdf_vals = torch.exp(dist.log_prob(coords)) return pdf_vals.reshape(*shape) - @parameterized.expand([ + @parameterized.expand( + [ # inputs: baselines: - (ts([1,2,3]), ts([[2,3], [3,4]])), - ((ts([1]),ts([2]),ts([3])), (ts([1]),ts([2]))), - ((ts([1])), ()), - ((), ts([1])) - ]) + (ts([1, 2, 3]), ts([[2, 3], [3, 4]])), + ((ts([1]), ts([2]), ts([3])), (ts([1]), ts([2]))), + ((ts([1])), ()), + ((), ts([1])), + ] + ) def test_input_baseline_mismatch_throws(self, input, baseline): - rex = ReX(lambda x: 1/0) # dummy forward, should be unreachable + rex = ReX(lambda x: 1 / 0) # dummy forward, should be unreachable with self.assertRaises(AssertionError): rex.attribute(input, baseline) - - @parameterized.expand([ - (ts([1,2,3]), 0), - (ts([[1,2,3], [4,5,6]]), 0), - (ts([1,2,3,4]), ts([0,0,0,0])), - (ts([[1, 2], [1,2]]), ts([[0,0], [0,0]])), - (ts([[[1,2], [3,4]], [[5,6], [7,8]]]), 0), - ((ts([1,2]), ts([3,4]), ts([5,6])), (0, 0, 0)), - ((ts([1,2]), ts([3,4]), ts([5,6])), (ts([0,0]), ts([0,0]), ts([0,0]))), - ((ts([1,2]), ts([3,4])), (ts([0,0]), ts([0, 0]))), - ]) + @parameterized.expand( + [ + (ts([1, 2, 3]), 0), + (ts([[1, 2, 3], [4, 5, 6]]), 0), + (ts([1, 2, 3, 4]), ts([0, 0, 0, 0])), + (ts([[1, 2], [1, 2]]), ts([[0, 0], [0, 0]])), + (ts([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), 0), + ((ts([1, 2]), ts([3, 4]), ts([5, 6])), (0, 0, 0)), + ( + (ts([1, 2]), ts([3, 4]), ts([5, 6])), + (ts([0, 0]), ts([0, 0]), ts([0, 0])), + ), + ((ts([1, 2]), ts([3, 4])), (ts([0, 0]), ts([0, 0]))), + ] + ) def test_valid_input_baseline(self, input, baseline): for o in self.all_options: rex = ReX(lambda x: True) @@ -74,56 +84,64 @@ def test_valid_input_baseline(self, input, baseline): attributions = rex.attribute(input, baseline, *o)[0] inp_unwrapped = input - if isinstance(input, tuple): inp_unwrapped = input[0] + if isinstance(input, tuple): + inp_unwrapped = input[0] - # Forward_func returns a constant, no responsibility in input + # Forward_func returns a constant, no responsibility in input self.assertEqual(torch.sum(attributions), 0) self.assertEqual(attributions.size(), inp_unwrapped.size()) - - @parameterized.expand([ - # input # selected_idx - (ts([1,2,3]), 0), - (ts([[1,2], [3,4]]), (0, 1)), - (ts([[[1, 2], [3, 4]], [[5,6], [7,8]]]), (0, 1, 0)) - ]) + @parameterized.expand( + [ + # input # selected_idx + (ts([1, 2, 3]), 0), + (ts([[1, 2], [3, 4]]), (0, 1)), + (ts([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), (0, 1, 0)), + ] + ) def test_selector_function(self, input, idx): for o in self.all_options: rex = ReX(lambda x: x[idx]) attributions = rex.attribute(input, 0, *o)[0] - self.assertEqual(attributions[idx], 1, f"expected 1 at {idx} but found {attributions}") + self.assertEqual( + attributions[idx], 1, f"expected 1 at {idx} but found {attributions}" + ) attributions[idx] = 0 self.assertEqual(torch.sum(attributions), 0) - - @parameterized.expand([ - # input shape # important idx - 
((4,4), (0,0)), - ((12, 12, 12), (1,2,1)), - ((12, 12, 12, 6), (1,1,4,1)), - ((1920, 1080), (1, 1)) # image-like - ]) + @parameterized.expand( + [ + # input shape # important idx + ((4, 4), (0, 0)), + ((12, 12, 12), (1, 2, 1)), + ((12, 12, 12, 6), (1, 1, 4, 1)), + ((1920, 1080), (1, 1)), # image-like + ] + ) def test_selector_function_large_input(self, input_shape, idx): rex = ReX(lambda x: x[idx]) input = torch.ones(*input_shape) - attributions = rex.attribute(input, 0, n_partitions=2, search_depth=10, n_searches=3)[0] + attributions = rex.attribute( + input, 0, n_partitions=2, search_depth=10, n_searches=3 + )[0] self.assertGreater(attributions[idx], 0) attributions[idx] = 0 self.assertLess(torch.sum(attributions), 1) - - @parameterized.expand([ - # input shape # lhs_idx # rhs_idx - ((2,4), (0,2), (1,3)) - ]) + @parameterized.expand( + [ + # input shape # lhs_idx # rhs_idx + ((2, 4), (0, 2), (1, 3)) + ] + ) def test_boolean_or(self, input_shape, lhs_idx, rhs_idx): for o in self.all_options: rex = ReX(lambda x: max(x[lhs_idx], x[rhs_idx])) input = torch.ones(input_shape) - + attributions = rex.attribute(input, 0, *o)[0] self.assertEqual(attributions[lhs_idx], 1.0, f"{attributions}") @@ -133,16 +151,17 @@ def test_boolean_or(self, input_shape, lhs_idx, rhs_idx): attributions[rhs_idx] = 0 self.assertLess(torch.sum(attributions), 1, f"{attributions}") - - @parameterized.expand([ - # input shape # lhs_idx # rhs_idx - ((2,4), (0,2), (0,3)) - ]) + @parameterized.expand( + [ + # input shape # lhs_idx # rhs_idx + ((2, 4), (0, 2), (0, 3)) + ] + ) def test_boolean_and(self, input_shape, lhs_idx, rhs_idx): for i, o in enumerate(self.all_options): rex = ReX(lambda x: min(x[lhs_idx], x[rhs_idx])) input = torch.ones(input_shape) - + attributions = rex.attribute(input, 0, *o)[0] self.assertEqual(attributions[lhs_idx], 0.5, f"{attributions}, {i}, {o}") @@ -152,38 +171,41 @@ def test_boolean_and(self, input_shape, lhs_idx, rhs_idx): attributions[rhs_idx] = 0 self.assertLess(torch.sum(attributions), 1, f"{attributions}") - - @parameterized.expand([ - # shape # mean - ((30,30),), - ((50, 50),), - ((100,100),) - ]) + @parameterized.expand( + [ + # shape # mean + ((30, 30),), + ((50, 50),), + ((100, 100),), + ] + ) def test_gaussian_recovery(self, shape): random.seed() eps = 1e-12 - + p = torch.zeros(shape) for _ in range(3): center = self.ts([int(random.random() * dim) for dim in shape]) p += self._generate_gaussian_pdf(shape, center) - + p += eps - p = p/torch.sum(p) + p = p / torch.sum(p) thresh = math.sqrt(torch.mean(p)) + def _forward(inp): return 1 if torch.sum(inp) > thresh else 0 - + rex = ReX(_forward) for b in self.n_partition_opts: - attributions = rex.attribute(p, - 0, - n_partitions=b, - search_depth=10, - n_searches=25, - assume_locality=True)[0] - + attributions = rex.attribute( + p, + 0, + n_partitions=b, + search_depth=10, + n_searches=25, + assume_locality=True, + )[0] attributions += eps attrib_norm = attributions / torch.sum(attributions) @@ -193,7 +215,8 @@ def _forward(inp): # visualize_tensor(p - attrib_norm) mid = 0.5 * (p + attrib_norm) - jsd = 0.5 * F.kl_div(p.log(), mid, reduction="sum") \ - + 0.5 * F.kl_div(attrib_norm.log(), mid, reduction="sum") - + jsd = 0.5 * F.kl_div(p.log(), mid, reduction="sum") + 0.5 * F.kl_div( + attrib_norm.log(), mid, reduction="sum" + ) + self.assertLess(jsd, 0.1) From a78c659f3d226f14f9ff2b8a83f594021193d33b Mon Sep 17 00:00:00 2001 From: stav-af Date: Sun, 7 Sep 2025 18:35:00 +0100 Subject: [PATCH 19/21] Documentation --- 
 docs/algorithms_comparison_matrix.md | 11 ++++++++++-
 docs/attribution_algorithms.md       |  9 +++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/docs/algorithms_comparison_matrix.md b/docs/algorithms_comparison_matrix.md
index e74128ee9..21d27514c 100644
--- a/docs/algorithms_comparison_matrix.md
+++ b/docs/algorithms_comparison_matrix.md
@@ -207,7 +207,16 @@ Please, scroll to the right for more details.
 Depends on the choice of above mentioned attribution algorithm. Depends on the choice of above mentioned attribution algorithm. | Adds gaussian noise to each input example #samples times, calls any above mentioned attribution algorithm for all #samples per example and aggregates / smoothens them based on different techniques for each input example. Supported smoothing techniques include: smoothgrad, vargrad, smoothgrad_sq.
-
+
+  ReX
+  Perturbation
+  Any function returning a single value
+  O(#partitions ^ #max_depth) - user-defined values
+  Any function returning a single value
+  O(#iterations * (#partitions ^ #max_depth))
+  Yes (strong assumption regarding neutral baseline)
+  Perturbation-based approach built on a recursive search over the input. By recursively occluding partitions of an input, ReX searches for partitions whose values have predictive value with respect to the output.
+
 **^ Including Layer Variant**
diff --git a/docs/attribution_algorithms.md b/docs/attribution_algorithms.md
index f1d00a8f5..f2b701b35 100644
--- a/docs/attribution_algorithms.md
+++ b/docs/attribution_algorithms.md
@@ -134,6 +134,15 @@ Kernel SHAP is a method that uses the LIME framework to compute Shapley Values.
 To learn more about KernelSHAP, visit the following resources:
 - [Original paper](https://arxiv.org/abs/1705.07874)
 
+### ReX
+ReX is a perturbation-based explainability approach, grounded in the theory of Actual Causality[1]. It works by partitioning the input and occluding all combinations of partitions using a neutral masking value. Where masking some combination of partitions changes the output of the model, those partitions are recursively re-partitioned to search for ever-smaller parts of the input that are responsible for the final output.
+
+To learn more about actual causality, responsibility and ReX:
+- [Actual Causality](https://www.cs.cornell.edu/home/halpern/papers/causalitybook-ch1-3.html)
+- [Responsibility and Blame](https://arxiv.org/pdf/cs/0312038)
+- [ReX original paper (called DC-Causal there)](https://www.hanachockler.com/iccv2021/)
+
+
 ## Layer Attribution
 ### Layer Conductance
 Conductance combines the neuron activation with the partial derivatives of both the neuron with respect to the input and the output with respect to the neuron to build a more complete picture of neuron importance.
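The documentation above describes the ReX workflow only in prose; a small usage sketch may help. Everything below that is not named in this patch series is an assumption: the toy linear classifier, the 3x32x32 input, and the chosen parameter values are placeholders. Only the import path, the ReX constructor, and the attribute() keyword arguments come from the patches themselves.

# Illustrative usage sketch (toy model assumed, not part of the patch).
# ReX requires a forward function returning a single value for which
# equality is defined, so the classifier is wrapped to return its label.
import torch
from captum.attr._core.rex import ReX

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))

def predicted_label(x: torch.Tensor) -> int:
    # Add a batch dimension and report the winning class as a plain int.
    return int(model(x.unsqueeze(0)).argmax())

image = torch.rand(3, 32, 32)  # placeholder input
rex = ReX(predicted_label)
result = rex.attribute(
    image,
    baselines=0,        # neutral value used to occlude partitions
    n_partitions=4,
    search_depth=6,
    n_searches=5,
)
# The tests in this series index the result with [0] to obtain the map for a
# single input; each map has the same shape as its input tensor.

Returning the label as a plain int matters because ReX decides whether an occluded variant is consistent by comparing the perturbed output to the original prediction with the equality operator.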
From 38e47ef17f4cae3088f2d50d4ad583c364305a97 Mon Sep 17 00:00:00 2001 From: stav-af Date: Sun, 7 Sep 2025 18:38:55 +0100 Subject: [PATCH 20/21] Restore CI file --- .github/workflows/lint.yml | 26 ++++++++++--------- .github/workflows/test-conda-cpu.yml | 1 - .../test-pip-cpu-with-type-checks.yml | 1 - .github/workflows/test-pip-cpu.yml | 1 - .github/workflows/test-pip-gpu.yml | 1 - 5 files changed, 14 insertions(+), 16 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e21e36fe4..ab4d71bc6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -3,19 +3,21 @@ name: Captum Lint on: pull_request: push: - branches: [ master ] + branches: + - master + workflow_dispatch: jobs: tests: - # Make sure your runner has these labels (or just "self-hosted") - runs-on: self-hosted - steps: - - uses: actions/checkout@v4 - - name: Lint & docs - run: | - sudo chmod -R 777 . - ./scripts/install_via_pip.sh - ufmt check . - flake8 - sphinx-build -WT --keep-going sphinx/source sphinx/build + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + runner: linux.12xlarge + docker-image: cimg/python:3.11 + repository: pytorch/captum + script: | + sudo chmod -R 777 . + ./scripts/install_via_pip.sh + ufmt check . + flake8 + sphinx-build -WT --keep-going sphinx/source sphinx/build diff --git a/.github/workflows/test-conda-cpu.yml b/.github/workflows/test-conda-cpu.yml index 6c0a73c1e..e0da5e42e 100644 --- a/.github/workflows/test-conda-cpu.yml +++ b/.github/workflows/test-conda-cpu.yml @@ -13,7 +13,6 @@ env: jobs: tests: - runs-on: ubuntu-latest strategy: matrix: python_version: ["3.9", "3.10", "3.11", "3.12"] diff --git a/.github/workflows/test-pip-cpu-with-type-checks.yml b/.github/workflows/test-pip-cpu-with-type-checks.yml index 284abb45b..3336a76f6 100644 --- a/.github/workflows/test-pip-cpu-with-type-checks.yml +++ b/.github/workflows/test-pip-cpu-with-type-checks.yml @@ -10,7 +10,6 @@ on: jobs: tests: - runs-on: ubuntu-latest strategy: matrix: pytorch_args: ["", "-n"] diff --git a/.github/workflows/test-pip-cpu.yml b/.github/workflows/test-pip-cpu.yml index fa3042368..83a513ac2 100644 --- a/.github/workflows/test-pip-cpu.yml +++ b/.github/workflows/test-pip-cpu.yml @@ -10,7 +10,6 @@ on: jobs: tests: - runs-on: ubuntu-latest strategy: matrix: pytorch_args: ["-v 1.10", "-v 1.11", "-v 1.12", "-v 1.13", "-v 2.0.0", "-v 2.1.0", "-v 2.2.0", "-v 2.3.0"] diff --git a/.github/workflows/test-pip-gpu.yml b/.github/workflows/test-pip-gpu.yml index 663c6a3bb..117f515f4 100644 --- a/.github/workflows/test-pip-gpu.yml +++ b/.github/workflows/test-pip-gpu.yml @@ -10,7 +10,6 @@ on: jobs: tests: - runs-on: ubuntu-latest strategy: matrix: cuda_arch_version: ["12.1"] From bf1ca9ad6f588a2f60951232a52a328dc344bbf9 Mon Sep 17 00:00:00 2001 From: stav-af Date: Fri, 12 Sep 2025 07:24:38 +0100 Subject: [PATCH 21/21] Documentation and formatting --- captum/attr/_core/rex.py | 65 +++++++++++++++++------------------ sphinx/source/attribution.rst | 1 + sphinx/source/rex.rst | 6 ++++ 3 files changed, 38 insertions(+), 34 deletions(-) create mode 100644 sphinx/source/rex.rst diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py index cae57d765..e67ba9a2b 100644 --- a/captum/attr/_core/rex.py +++ b/captum/attr/_core/rex.py @@ -6,7 +6,7 @@ import random from collections import deque from dataclasses import dataclass -from typing import Tuple, cast, List, Sized +from typing import cast, List, Sized, Tuple import torch from captum._utils.typing import 
BaselineType, TensorOrTupleOfTensorsGeneric
@@ -145,16 +145,16 @@ class ReX(PerturbationAttribution):
     and results are (optionally) merged to produce the final attribution map.
 
-    [1] - halpern 06
-    [2] - rex paper
+    [1] - Cause: https://www.cs.cornell.edu/home/halpern/papers/modified-HPdef.pdf
+    [2] - ReX paper: https://arxiv.org/pdf/2411.08875
     [3] - Responsibility and Blame; https://arxiv.org/pdf/cs/0312038
     """
 
     def __init__(self, forward_func):
         r"""
         Args:
-            forward_func (Callable): The function to be explained. Must return
-                a scalar for which the equality operator is defined.
+            forward_func (Callable): The function to be explained. *Must* return
+                a scalar for which the equality operator is defined.
         """
         PerturbationAttribution.__init__(self, forward_func)
@@ -170,35 +170,32 @@ def attribute(
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
         Args:
-            inputs:
-                An input or tuple of inputs to be explain. Each input
-                must be of the shape expected by the forward_func. Where multiple
-                examples are provided, they must be listed in a tuple.
-
-            baselines:
-                A neutral values to be used as occlusion values. Where a scalar or
-                tensor is provided, they are broadcast to the input shape. Where
-                tuples are provided, they are paired element-wise, and must match
-                the structure of the input.
-
-            search_depth (optional):
-                The maximum depth to which ReX will refine responsibility estimates
-                for causes.
-
-            n_partitions (optional):
-                The maximum number of partitions to be made out of the input at each
-                search step. At least 1, and no larger than the partition size. Where
-                ``contiguous partitioning`` is set to False, partitions are created
-                using previous attribution maps as heuristics.
-
-            n_searches (optional):
-                The number of times the search is to be ran.
-
-            assume_locality (optional):
-                Where True, partitioning is contiguous and attribution maps are merged
-                after each search. Otherwise, partitioning is initially random, then
-                uses the previous attribution map as a heuristic for further searches,
-                returning the result of the final search.
+            inputs (Tensor or tuple[Tensor, ...]): An input or tuple of inputs
+                to be explained. Each input must be of the shape expected by
+                the forward_func. Where multiple examples are provided, they
+                must be listed in a tuple.
+
+            baselines (Tensor or tuple[Tensor, ...]): Neutral values to be used
+                as occlusion values. Where a scalar or tensor is provided, they
+                are broadcast to the input shape. Where tuples are provided,
+                they are paired element-wise, and must match the structure of
+                the input.
+
+            search_depth (int, optional): The maximum depth to which ReX will refine
+                responsibility estimates for causes.
+
+            n_partitions (int, optional): The maximum number of partitions to be made
+                out of the input at each search step. At least 1, and no larger than
+                the partition size. Where ``assume_locality`` is set to False,
+                partitions are created using previous attribution maps as heuristics.
+
+            n_searches (int, optional): The number of times the search is to be run.
+
+            assume_locality (bool, optional): Where True, partitioning is contiguous and
+                attribution maps are merged after each search. Otherwise,
+                partitioning is initially random, then uses the previous attribution
+                map as a heuristic for further searches, returning the result of the
+                final search.
""" inputs, baselines = _format_input_baseline(inputs, baselines) diff --git a/sphinx/source/attribution.rst b/sphinx/source/attribution.rst index ace52dd9a..922dbaa28 100644 --- a/sphinx/source/attribution.rst +++ b/sphinx/source/attribution.rst @@ -18,3 +18,4 @@ Attribution lime kernel_shap lrp + rex \ No newline at end of file diff --git a/sphinx/source/rex.rst b/sphinx/source/rex.rst new file mode 100644 index 000000000..daa69933f --- /dev/null +++ b/sphinx/source/rex.rst @@ -0,0 +1,6 @@ +ReX +=== + +.. autoclass:: captum.attr._core.rex.ReX + :members: + :show-inheritance: