diff --git a/src/tbp/monty/frameworks/models/evidence_matching/hypotheses_updater.py b/src/tbp/monty/frameworks/models/evidence_matching/hypotheses_updater.py index 04f5d166c..f69661d87 100644 --- a/src/tbp/monty/frameworks/models/evidence_matching/hypotheses_updater.py +++ b/src/tbp/monty/frameworks/models/evidence_matching/hypotheses_updater.py @@ -35,7 +35,10 @@ DefaultHypothesesDisplacer, HypothesisDisplacerTelemetry, ) -from tbp.monty.frameworks.utils.evidence_matching import ChannelMapper +from tbp.monty.frameworks.utils.evidence_matching import ( + ChannelMapper, + ConsistentHypothesesIds, +) from tbp.monty.frameworks.utils.graph_matching_utils import ( get_initial_possible_poses, possible_sensed_directions, @@ -56,6 +59,14 @@ class ChannelHypothesesUpdateTelemetry: class HypothesesUpdater(Protocol): + def pre_step(self) -> None: + """Runs once per step before updating the hypotheses.""" + ... + + def post_step(self) -> None: + """Runs once per step after updating the hypotheses.""" + ... + def update_hypotheses( self, hypotheses: Hypotheses, @@ -81,6 +92,19 @@ def update_hypotheses( """ ... + def remap_hypotheses_ids_to_present( + self, hypotheses_ids: ConsistentHypothesesIds + ) -> ConsistentHypothesesIds: + """Update hypotheses ids based on resizing of hypothesis space. + + Args: + hypotheses_ids: Hypotheses ids to be updated + + Returns: + The list of the updated hypotheses ids. + """ + ... + class DefaultHypothesesUpdater(HypothesesUpdater): def __init__( @@ -183,6 +207,14 @@ def __init__( use_features_for_matching=self.use_features_for_matching, ) + def pre_step(self) -> None: + """Runs once per step before updating the hypotheses.""" + ... + + def post_step(self) -> None: + """Runs once per step after updating the hypotheses.""" + ... + def update_hypotheses( self, hypotheses: Hypotheses, @@ -403,6 +435,22 @@ def _get_initial_hypothesis_space( poses=initial_possible_channel_rotations, ) + def remap_hypotheses_ids_to_present( + self, hypotheses_ids: ConsistentHypothesesIds + ) -> ConsistentHypothesesIds: + """Update hypotheses ids based on resizing of hypothesis space. + + We do not resize the hypotheses space when using `DefaultHypothesesUpdater`, + therefore, we return the same ids without update. + + Args: + hypotheses_ids: Hypotheses ids to be updated + + Returns: + The list of the updated hypotheses ids. + """ + return hypotheses_ids + def all_usable_input_channels( features: dict, all_input_channels: list[str] diff --git a/src/tbp/monty/frameworks/models/evidence_matching/learning_module.py b/src/tbp/monty/frameworks/models/evidence_matching/learning_module.py index 5ae515413..c53f3339f 100644 --- a/src/tbp/monty/frameworks/models/evidence_matching/learning_module.py +++ b/src/tbp/monty/frameworks/models/evidence_matching/learning_module.py @@ -15,6 +15,7 @@ import time import numpy as np +import numpy.typing as npt from scipy.spatial import KDTree from scipy.spatial.transform import Rotation @@ -35,6 +36,7 @@ from tbp.monty.frameworks.models.states import State from tbp.monty.frameworks.utils.evidence_matching import ( ChannelMapper, + ConsistentHypothesesIds, evidence_update_threshold, ) from tbp.monty.frameworks.utils.graph_matching_utils import ( @@ -521,18 +523,30 @@ def get_unique_pose_if_available(self, object_id): # Only try to determine object pose if the evidence for it is high enough. 
if possible_object_hypotheses_ids is not None: mlh = self.get_current_mlh() + # Check if all possible poses are similar pose_is_unique = self._check_for_unique_poses( object_id, possible_object_hypotheses_ids, mlh["rotation"] ) + # Check for symmetry + last_possible_hypotheses_remapped = ( + self.hypotheses_updater.remap_hypotheses_ids_to_present( + self.last_possible_hypotheses + ) + ) symmetry_detected = self._check_for_symmetry( - possible_object_hypotheses_ids, + object_id=object_id, + last_possible_object_hypotheses=last_possible_hypotheses_remapped, + possible_object_hypotheses_ids=possible_object_hypotheses_ids, # Don't increment symmetry counter if LM didn't process observation increment_evidence=self.buffer.get_last_obs_processed(), ) - - self.last_possible_hypotheses = possible_object_hypotheses_ids + self.last_possible_hypotheses = ConsistentHypothesesIds( + hypotheses_ids=possible_object_hypotheses_ids, + channel_sizes=self.channel_hypothesis_mapping[object_id].channel_sizes, + graph_id=object_id, + ) if pose_is_unique or symmetry_detected: r_inv = mlh["rotation"].inv() @@ -561,18 +575,18 @@ def get_unique_pose_if_available(self, object_id): if symmetry_detected: symmetry_stats = { "symmetric_rotations": np.array(self.possible_poses[object_id])[ - self.last_possible_hypotheses + self.last_possible_hypotheses.hypotheses_ids ], "symmetric_locations": self.possible_locations[object_id][ - self.last_possible_hypotheses + self.last_possible_hypotheses.hypotheses_ids ], } self.buffer.add_overall_stats(symmetry_stats) return pose_and_scale - logger.debug(f"object {object_id} detected but pose not resolved yet.") return None + self.last_possible_hypotheses = None return None def get_current_mlh(self): @@ -603,16 +617,22 @@ def get_top_two_mlh_ids(self): """ graph_ids, graph_evidences = self.get_evidence_for_each_graph() + # If all hypothesis spaces are empty return None for both mlh ids. The gsg will + # not generate a goal state. + if len(graph_ids) == 0: + return None, None + + # If we have a single hypothesis space, return the second object id as None. + # The gsg will focus on pose to generate a goal state. + if len(graph_ids) == 1: + return graph_ids[0], None + # Note the indices below will be ordered with the 2nd MLH appearing first, and # the 1st MLH appearing second. top_indices = np.argsort(graph_evidences)[-2:] + top_id = graph_ids[top_indices[1]] + second_id = graph_ids[top_indices[0]] - if len(top_indices) > 1: - top_id = graph_ids[top_indices[1]] - second_id = graph_ids[top_indices[0]] - else: - top_id = graph_ids[top_indices[0]] - second_id = top_id # Account for the case where we have multiple top evidences with the same value. 
# In this case argsort and argmax (used to get current_mlh) will return # different results but some downstream logic (in gsg) expects them to be the @@ -626,6 +646,7 @@ def get_top_two_mlh_ids(self): # and keep the second id as is (since this means there is a threeway # tie in evidence values so its not like top is more likely than second) top_id = self.current_mlh["graph_id"] + return top_id, second_id def get_top_two_pose_hypotheses_for_graph_id(self, graph_id): @@ -683,15 +704,24 @@ def get_evidence_for_each_graph(self): graph_ids = self.get_all_known_object_ids() if graph_ids[0] not in self.evidence.keys(): return ["patch_off_object"], [0] - graph_evidences = [] + + available_graph_ids = [] + available_graph_evidences = [] for graph_id in graph_ids: - graph_evidences.append(np.max(self.evidence[graph_id])) - return graph_ids, np.array(graph_evidences) + if len(self.hyp_evidences_for_object(graph_id)): + available_graph_ids.append(graph_id) + available_graph_evidences.append(np.max(self.evidence[graph_id])) + + return available_graph_ids, np.array(available_graph_evidences) def get_all_evidences(self): """Return evidence for each pose on each graph (pointer).""" return self.evidence + def hyp_evidences_for_object(self, object_id): + """Return evidences for a specific object_id.""" + return self.evidence[object_id] + # ------------------ Logging & Saving ---------------------- def collect_stats_to_save(self): """Get all stats that this LM should store in the buffer for logging. @@ -710,6 +740,8 @@ def collect_stats_to_save(self): def _update_possible_matches(self, query): """Update evidence for each hypothesis instead of removing them.""" + self.hypotheses_updater.pre_step() + thread_list = [] for graph_id in self.get_all_known_object_ids(): if self.use_multithreading: @@ -738,6 +770,8 @@ def _update_possible_matches(self, query): self.previous_mlh = self.current_mlh self.current_mlh = self._calculate_most_likely_hypothesis() + self.hypotheses_updater.post_step() + def _update_evidence( self, features: dict, @@ -800,12 +834,20 @@ def _update_evidence( self._set_hypotheses_in_hpspace(graph_id=graph_id, new_hypotheses=update) end_time = time.time() - assert not np.isnan(np.max(self.evidence[graph_id])), "evidence contains NaN." - logger.debug( + + logger_msg = ( f"evidence update for {graph_id} took " f"{np.round(end_time - start_time, 2)} seconds." - f" New max evidence: {np.round(np.max(self.evidence[graph_id]), 3)}" ) + graph_evidence = self.hyp_evidences_for_object(graph_id) + if len(graph_evidence): + assert not np.isnan(np.max(self.evidence[graph_id])), ( + "evidence contains NaN." + ) + logger_msg += ( + f" New max evidence: {np.round(np.max(self.evidence[graph_id]), 3)}" + ) + logger.debug(logger_msg) def _set_hypotheses_in_hpspace( self, @@ -992,7 +1034,13 @@ def _check_for_unique_poses( return location_unique and rotation_unique - def _check_for_symmetry(self, possible_object_hypotheses_ids, increment_evidence): + def _check_for_symmetry( + self, + object_id: str, + last_possible_object_hypotheses: ConsistentHypothesesIds | None, + possible_object_hypotheses_ids: npt.NDArray[np.int64], + increment_evidence: bool, + ): """Check whether the most likely hypotheses stayed the same over the past steps. Since the definition of possible_object_hypotheses is a bit murky and depends @@ -1001,6 +1049,9 @@ def _check_for_symmetry(self, possible_object_hypotheses_ids, increment_evidence not sure if this is the best way to check for symmetry... 
Args: + object_id: identifier of the object being checked for symmetry + last_possible_object_hypotheses: All the possible hypotheses + from the last step. possible_object_hypotheses_ids: List of IDs of all possible hypotheses. increment_evidence: Whether to increment symmetry evidence or not. We may want this to be False for example if we did not receive a new @@ -1009,14 +1060,17 @@ def _check_for_symmetry(self, possible_object_hypotheses_ids, increment_evidence Returns: Whether symmetry was detected. """ - if self.last_possible_hypotheses is None: + if ( + last_possible_object_hypotheses is None + or last_possible_object_hypotheses.graph_id != object_id + ): return False # need more steps to meet symmetry condition logger.debug( f"\n\nchecking for symmetry for hp ids {possible_object_hypotheses_ids}" f" with last ids {self.last_possible_hypotheses}" ) if increment_evidence: - previous_hyps = set(self.last_possible_hypotheses) + previous_hyps = set(last_possible_object_hypotheses.hypotheses_ids) current_hyps = set(possible_object_hypotheses_ids) hypothesis_overlap = previous_hyps.intersection(current_hyps) if len(hypothesis_overlap) / len(current_hyps) > 0.9: @@ -1106,7 +1160,13 @@ def _threshold_possible_matches(self, x_percent_scale_factor=1.0): if len(self.graph_memory) == 0: logger.info("no objects in memory yet.") return [] + graph_ids, graph_evidences = self.get_evidence_for_each_graph() + + if len(graph_ids) == 0: + logger.info("All hypothesis spaces are empty. No possible matches.") + return [] + # median_ge = np.median(graph_evidences) mean_ge = np.mean(graph_evidences) max_ge = np.max(graph_evidences) @@ -1167,23 +1227,28 @@ def _calculate_most_likely_hypothesis(self, graph_id=None): """ mlh = {} if graph_id is not None: - mlh_id = np.argmax(self.evidence[graph_id]) - mlh = self._get_mlh_dict_from_id(graph_id, mlh_id) + graph_evidence = self.hyp_evidences_for_object(graph_id) + if len(graph_evidence): + mlh_id = np.argmax(graph_evidence) + mlh = self._get_mlh_dict_from_id(graph_id, mlh_id) else: highest_evidence_so_far = -np.inf - for next_graph_id in self.get_all_known_object_ids(): - mlh_id = np.argmax(self.evidence[next_graph_id]) - evidence = self.evidence[next_graph_id][mlh_id] - if evidence > highest_evidence_so_far: - mlh = self._get_mlh_dict_from_id(next_graph_id, mlh_id) - highest_evidence_so_far = evidence - if not mlh: # No objects in memory - mlh = self.current_mlh - mlh["graph_id"] = "new_object0" - logger.info( - f"current most likely hypothesis: {mlh['graph_id']} " - f"with evidence {np.round(mlh['evidence'], 2)}" - ) + for graph_id in self.get_all_known_object_ids(): + graph_evidence = self.hyp_evidences_for_object(graph_id) + if len(graph_evidence): + mlh_id = np.argmax(graph_evidence) + evidence = graph_evidence[mlh_id] + if evidence > highest_evidence_so_far: + mlh = self._get_mlh_dict_from_id(graph_id, mlh_id) + highest_evidence_so_far = evidence + + if not mlh: # No objects in memory + mlh = self.current_mlh + mlh["graph_id"] = "new_object0" + logger.info( + f"current most likely hypothesis: {mlh['graph_id']} " + f"with evidence {np.round(mlh['evidence'], 2)}" + ) return mlh def _get_node_distance_weights(self, distances): diff --git a/src/tbp/monty/frameworks/models/evidence_matching/resampling_hypotheses_updater.py b/src/tbp/monty/frameworks/models/evidence_matching/resampling_hypotheses_updater.py index 9abbf848b..8d5858c2c 100644 --- a/src/tbp/monty/frameworks/models/evidence_matching/resampling_hypotheses_updater.py +++ 
b/src/tbp/monty/frameworks/models/evidence_matching/resampling_hypotheses_updater.py @@ -9,7 +9,7 @@ from __future__ import annotations -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, replace from typing import Any, Literal import numpy as np @@ -42,7 +42,9 @@ ) from tbp.monty.frameworks.utils.evidence_matching import ( ChannelMapper, + ConsistentHypothesesIds, EvidenceSlopeTracker, + HypothesesSelection, InvalidEvidenceThresholdConfig, ) from tbp.monty.frameworks.utils.graph_matching_utils import ( @@ -71,32 +73,44 @@ class ChannelHypothesesResamplingTelemetry(ChannelHypothesesUpdateTelemetry): ages: npt.NDArray[np.int_] evidence_slopes: npt.NDArray[np.float64] removed_ids: npt.NDArray[np.int_] + max_slope: float class ResamplingHypothesesUpdater: - """Hypotheses updater that resamples hypotheses at every step. - - This updater enables updating of the hypothesis space by resampling and rebuilding - the hypothesis space at every step. We resample hypotheses from the existing - hypothesis space, as well as new hypotheses informed by the sensed pose. - - The resampling process is governed by two main parameters: - - `hypotheses_count_multiplier`: scales the total number of hypotheses every step. - - `hypotheses_existing_to_new_ratio`: controls the proportion of existing vs. - informed hypotheses during resampling. + """Hypotheses updater that adds and deletes hypotheses based on evidence slope. + + This updater enables updating of the hypothesis space by intelligently resampling + and rebuilding the hypothesis space when the model's prediction error is high. The + prediction error is determined based on the highest evidence slope over all the + objects' hypothesis spaces. If even the hypothesis with the highest slope is unable to + accumulate evidence quickly enough, i.e., none of the current hypotheses + match the incoming observations well, a sampling burst is triggered. A sampling + burst adds new hypotheses over a specified `sampling_burst_duration` number of + consecutive steps to all hypothesis spaces. This burst duration reduces the effect + of sensor noise. Hypotheses are deleted when their smoothed evidence slope is below + `deletion_trigger_slope`. + + The resampling process is governed by four main parameters: + - `resampling_multiplier`: Determines the number of the hypotheses to resample + as a multiplier of the object graph nodes. + - `deletion_trigger_slope`: Hypotheses below this threshold are deleted. + - `sampling_burst_duration`: The number of consecutive steps in each burst. + - `burst_trigger_slope`: The threshold for triggering a sampling burst. This + threshold is applied to the highest global slope over all the hypotheses (i.e., + over all objects' hypothesis spaces). The range of this slope is [-1, 2]. To reproduce the behavior of `DefaultHypothesesUpdater` sampling a fixed number of - hypotheses only at the beginning of the episode, you can set - `hypotheses_count_multiplier=1.0` and `hypotheses_existing_to_new_ratio=0.0`. - - Note: - It would be better to decouple the amount of hypotheses added from the amount - deleted in each step. At the moment, this is decided by the - `hypotheses_count_multiplier`. For example, when the multiplier is set to 1.0, - the hypotheses sampled is equal to the hypotheses removed. We ideally can - decouple theses then use a slope threshold to decide on which hypotheses to - remove, and use prediction error or other heuristics to decide to how many - hypotheses to resample.
+ hypotheses only at the beginning of the episode, you can set: + - `resampling_multiplier=2` (or `umbilical_num_poses` if PC undefined) + - `deletion_trigger_slope=-1.0` (no deletion is allowed) + - `sampling_burst_duration=1` (sample the full burst over a single step) + - `burst_trigger_slope=-1.0` (never trigger additional bursts) + + These parameters will trigger a single-step burst at the first step of the episode. + Note that if the PC of the first observation is undetermined, + `resampling_multiplier` should be set to `umbilical_num_poses` to reproduce the + exact results of `DefaultHypothesesUpdater`. In practice, this is difficult to + predict because it relies on the first sampled observation. """ def __init__( @@ -113,8 +127,10 @@ def __init__( features_for_matching_selector: type[FeaturesForMatchingSelector] = ( DefaultFeaturesForMatchingSelector ), - hypotheses_count_multiplier: float = 1.0, - hypotheses_existing_to_new_ratio: float = 0.1, + resampling_multiplier: float = 0.4, + deletion_trigger_slope: float = 0.5, + sampling_burst_duration: int = 5, + burst_trigger_slope: float = 1.0, include_telemetry: bool = False, initial_possible_poses: Literal["uniform", "informed"] | list[Rotation] = "informed", @@ -148,11 +164,17 @@ def __init__( features_for_matching_selector: Class to select if features should be used for matching. Defaults to the default selector. - hypotheses_count_multiplier: Scales the total number of hypotheses - every step. Defaults to 1.0. - hypotheses_existing_to_new_ratio: Controls the proportion of the - existing vs. newly sampled hypotheses during resampling. Defaults to - 0.0. + resampling_multiplier: Determines the number of the hypotheses to resample + as a multiplier of the object graph nodes. Value of 0.0 results in no + resampling. Value can be greater than 1 but not to exceed the + `num_hyps_per_node` of the current step. Defaults to 0.4. + deletion_trigger_slope: Hypotheses below this threshold are deleted. + Expected range matches the range of step evidence change, i.e., + [-1.0, 2.0]. Defaults to 0.5. + sampling_burst_duration: The number of steps in every sampling burst. + Defaults to 5. + burst_trigger_slope: A threshold below which a sampling burst is triggered. + Defaults to 1.0. include_telemetry: Flag to control if we want to calculate and return the resampling telemetry in the `update_hypotheses` method. Defaults to False. @@ -177,6 +199,7 @@ def __init__( umbilical points (i.e., points where PC directions are undefined). Raises: + ValueError: If the resampling_multiplier is less than 0 InvalidEvidenceThresholdConfig: If `evidence_threshold_config` is not set to "all". 
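For illustration, a brief sketch (plain Python, not taken from the diff) of the keyword settings that the class docstring above says reproduce the `DefaultHypothesesUpdater` behavior; the argument names come from this diff, while the variable name and how these values are wired into a learning-module config are assumptions:

default_like_updater_kwargs = dict(
    resampling_multiplier=2,      # per the docstring; use umbilical_num_poses if the PC of the first observation is undefined
    deletion_trigger_slope=-1.0,  # slopes stay within [-1, 2], so nothing is ever deleted
    sampling_burst_duration=1,    # the initial burst is sampled in a single step
    burst_trigger_slope=-1.0,     # per the docstring, no additional bursts are triggered
)
# Hypothetical usage: pass these as keyword arguments when constructing
# ResamplingHypothesesUpdater; the remaining constructor arguments are omitted here.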
@@ -191,6 +214,10 @@ def __init__( self.feature_evidence_increment = feature_evidence_increment self.feature_weights = feature_weights self.features_for_matching_selector = features_for_matching_selector + self.resampling_multiplier = resampling_multiplier + self.deletion_trigger_slope = deletion_trigger_slope + self.sampling_burst_duration = sampling_burst_duration + self.burst_trigger_slope = burst_trigger_slope self.graph_memory = graph_memory self.include_telemetry = include_telemetry self.initial_possible_poses = get_initial_possible_poses(initial_possible_poses) @@ -214,19 +241,41 @@ def __init__( use_features_for_matching=self.use_features_for_matching, ) - # Controls the shrinking or growth of hypothesis space size - # Cannot be less than 0 - self.hypotheses_count_multiplier = max(0, hypotheses_count_multiplier) - - # Controls the ratio of existing to newly sampled hypotheses - # Bounded between 0 and 1 - self.hypotheses_existing_to_new_ratio = max( - 0, min(hypotheses_existing_to_new_ratio, 1) - ) + # resampling multiplier should not be less than 0 (no resampling) + if self.resampling_multiplier < 0: + raise ValueError("resampling_multiplier should be >= 0") # Dictionary of slope trackers, one for each graph_id self.evidence_slope_trackers: dict[str, EvidenceSlopeTracker] = {} + # Dictionary of resampling telemetry for each channel in each graph_id + self.resampling_telemetry: dict[str, dict[str, HypothesesUpdateTelemetry]] = {} + + # Trigger a burst at the beginning of the episode + self.sampling_burst_steps = self.sampling_burst_duration + + def pre_step(self) -> None: + """Runs once per step before updating the hypotheses. + + We calculate the max slope and update resampling parameters before running the + hypotheses update loop/threads over all the graph_ids and channels. + """ + self.max_slope = self._max_global_slope() + + if ( + self.max_slope <= self.burst_trigger_slope + and self.sampling_burst_steps == 0 + ): + self.sampling_burst_steps = self.sampling_burst_duration + + def post_step(self) -> None: + """Runs once per step after updating the hypotheses. + + We decrement the burst steps by 1 every step for the duration of the burst. + """ + if self.sampling_burst_steps > 0: + self.sampling_burst_steps -= 1 + def update_hypotheses( self, hypotheses: Hypotheses, @@ -284,7 +333,7 @@ def update_hypotheses( for input_channel in input_channels_to_use: # Calculate sample count for each type - existing_count, informed_count = self._sample_count( + hypotheses_selection, informed_count = self._sample_count( input_channel=input_channel, channel_features=features[input_channel], graph_id=graph_id, @@ -293,8 +342,8 @@ def update_hypotheses( ) # Sample hypotheses based on their type - existing_hypotheses, remove_ids = self._sample_existing( - existing_count=existing_count, + existing_hypotheses = self._sample_existing( + hypotheses_selection=hypotheses_selection, hypotheses=hypotheses, input_channel=input_channel, mapper=mapper, @@ -310,7 +359,7 @@ def update_hypotheses( # We only displace existing hypotheses since the newly resampled hypotheses # should not be affected by the displacement from the last sensory input. 
- if existing_count > 0: + if len(hypotheses_selection.maintain_ids) > 0: existing_hypotheses, channel_hypothesis_displacer_telemetry = ( self.hypotheses_displacer.displace_hypotheses_and_compute_evidence( channel_displacement=displacements[input_channel], @@ -338,32 +387,44 @@ def update_hypotheses( # Update tracker evidence tracker.update(channel_hypotheses.evidence, input_channel) - if self.include_telemetry: - resampling_telemetry[input_channel] = asdict( - ChannelHypothesesResamplingTelemetry( - channel_hypothesis_displacer_telemetry=channel_hypothesis_displacer_telemetry, - added_ids=( - np.arange(len(channel_hypotheses.evidence))[ - -len(informed_hypotheses.evidence) : - ] - if len(informed_hypotheses.evidence) > 0 - else np.array([], dtype=np.int_) - ), - ages=tracker.hyp_ages(input_channel), - evidence_slopes=tracker.calculate_slopes(input_channel), - removed_ids=remove_ids, - ) + # Telemetry update + resampling_telemetry[input_channel] = asdict( + ChannelHypothesesResamplingTelemetry( + channel_hypothesis_displacer_telemetry=channel_hypothesis_displacer_telemetry, + added_ids=( + np.arange(len(channel_hypotheses.evidence))[ + -len(informed_hypotheses.evidence) : + ] + if len(informed_hypotheses.evidence) > 0 + else np.array([], dtype=np.int_) + ), + ages=tracker.hyp_ages(input_channel), + evidence_slopes=tracker.calculate_slopes(input_channel), + removed_ids=hypotheses_selection.remove_ids, + max_slope=self.max_slope, ) - else: - # Still return prediction error. - # TODO: make this nicer like dependent on log_level. - resampling_telemetry[input_channel] = asdict( + ) + + self.resampling_telemetry[graph_id] = resampling_telemetry + + # Still return prediction error. + # TODO: make this nicer like dependent on log_level. + if not self.include_telemetry: + updater_telemetry = { + k: asdict( ChannelHypothesesUpdateTelemetry( - channel_hypothesis_displacer_telemetry=channel_hypothesis_displacer_telemetry + channel_hypothesis_displacer_telemetry=v[ + "channel_hypothesis_displacer_telemetry" + ] ) ) + for k, v in resampling_telemetry.items() + } - return hypotheses_updates, resampling_telemetry + return ( + hypotheses_updates, + resampling_telemetry if self.include_telemetry else updater_telemetry, + ) def _num_hyps_per_node(self, channel_features: dict) -> int: """Calculate the number of hypotheses per node. @@ -390,7 +451,7 @@ def _sample_count( graph_id: str, mapper: ChannelMapper, tracker: EvidenceSlopeTracker, - ) -> tuple[int, int]: + ) -> tuple[HypothesesSelection, int]: """Calculates the number of existing and informed hypotheses needed. Args: @@ -403,67 +464,53 @@ def _sample_count( graph_id Returns: - A tuple containing the number of existing and new hypotheses needed. - Existing hypotheses are maintained from existing ones while new hypotheses - will be initialized, informed by pose sensory information. + A tuple containing the hypotheses selection and count of new hypotheses + needed. Hypotheses selection are maintained from existing ones while new + hypotheses will be initialized, informed by pose sensory information. Notes: - This function takes into account the following ratios: - - `hypotheses_count_multiplier`: multiplier for total count calculation. - - `hypotheses_existing_to_new_ratio`: ratio between existing and new - hypotheses to be sampled. + This function takes into account the following parameters: + - `resampling_multiplier`: The number of hypotheses to resample. This + is defined as a multiplier of the number of nodes in the object graph. 
+ - `deletion_trigger_slope`: This dictates how many hypotheses to + delete. Hypotheses below this threshold are deleted. + - `sampling_burst_steps`: The remaining number of burst steps. This value + is decremented in the `post_step` function. """ - graph_num_points = self.graph_memory.get_locations_in_graph( - graph_id, input_channel - ).shape[0] - num_hyps_per_node = self._num_hyps_per_node(channel_features) - full_informed_count = graph_num_points * num_hyps_per_node + new_informed = 0 + if self.sampling_burst_steps > 0: + graph_num_points = self.graph_memory.get_locations_in_graph( + graph_id, input_channel + ).shape[0] + num_hyps_per_node = self._num_hyps_per_node(channel_features) - # If hypothesis space does not exist, we initialize with informed hypotheses - if input_channel not in mapper.channels: - return 0, full_informed_count - - # Calculate the total number of hypotheses needed - current = mapper.channel_size(input_channel) - needed = current * self.hypotheses_count_multiplier - - # Calculate how many existing and new hypotheses needed - existing_maintained, new_informed = ( - needed * (1 - self.hypotheses_existing_to_new_ratio), - needed * self.hypotheses_existing_to_new_ratio, - ) + # This makes sure that we do not request more than the available number of + # informed hypotheses + resampling_multiplier = min(self.resampling_multiplier, num_hyps_per_node) - # Needed existing hypotheses should not exceed the existing hypotheses - # if trying to maintain more hypotheses, set the available count as ceiling + # Calculate the total number of informed hypotheses to be resampled + new_informed = round(graph_num_points * resampling_multiplier) - # We make sure that `new_informed` is divisible by the number of hypotheses - # per graph node. This allows for sampling the graph nodes first (according - # to evidence) then multiply by the `num_hyps_per_node`, as shown in - # `_sample_informed`. - if existing_maintained > current: - existing_maintained = current - new_informed = needed - current + # Ensure the `new_informed` is divisible by `num_hyps_per_node` new_informed -= new_informed % num_hyps_per_node - # Needed informed hypotheses should not exceed the available informed hypotheses - # If trying to sample more hypotheses, set the available count as ceiling - if new_informed > full_informed_count: - new_informed = full_informed_count - - # Additional adjustment based on valid mask - must_keep = int(np.sum(~tracker.removable_indices_mask(input_channel))) - if must_keep > existing_maintained: - existing_maintained = must_keep - new_informed = needed - existing_maintained + # Returns a selection of hypotheses to maintain/delete + hypotheses_selection = ( + tracker.select_hypotheses( + slope_threshold=self.deletion_trigger_slope, channel=input_channel + ) + if input_channel in mapper.channels + else HypothesesSelection(maintain_mask=[]) + ) return ( - int(existing_maintained), - int(new_informed), + hypotheses_selection, + new_informed, ) def _sample_existing( self, - existing_count: int, + hypotheses_selection: HypothesesSelection, hypotheses: Hypotheses, input_channel: str, mapper: ChannelMapper, @@ -472,7 +519,7 @@ def _sample_existing( """Samples the specified number of existing hypotheses to retain. Args: - existing_count: Number of existing hypotheses to sample. + hypotheses_selection: The selection of hypotheses to maintain/remove. hypotheses: Hypotheses for all input channels in the graph_id. input_channel: The channel for which to sample existing hypotheses. 
mapper: Mapper for the graph_id to extract data from @@ -484,36 +531,30 @@ def _sample_existing( A tuple of sampled existing hypotheses and the IDs of the hypotheses to remove. """ + maintain_ids = hypotheses_selection.maintain_ids + # Return empty arrays for no hypotheses to sample - if existing_count == 0: + if len(maintain_ids) == 0: # Clear all channel hypotheses from the tracker - remove_ids = np.arange(tracker.total_size(input_channel)) tracker.clear_hyp(input_channel) - channel_hypotheses = ChannelHypotheses( + return ChannelHypotheses( input_channel=input_channel, locations=np.zeros((0, 3)), poses=np.zeros((0, 3, 3)), evidence=np.zeros(0), ) - return channel_hypotheses, remove_ids - - keep_ids, remove_ids = tracker.calculate_keep_and_remove_ids( - num_keep=existing_count, - channel=input_channel, - ) # Update tracker by removing the remove_ids - tracker.remove_hyp(remove_ids, input_channel) + tracker.remove_hyp(hypotheses_selection.remove_ids, input_channel) channel_hypotheses = mapper.extract_hypotheses(hypotheses, input_channel) - maintained_channel_hypotheses = ChannelHypotheses( + return ChannelHypotheses( input_channel=channel_hypotheses.input_channel, - locations=channel_hypotheses.locations[keep_ids], - poses=channel_hypotheses.poses[keep_ids], - evidence=channel_hypotheses.evidence[keep_ids], + locations=channel_hypotheses.locations[maintain_ids], + poses=channel_hypotheses.poses[maintain_ids], + evidence=channel_hypotheses.evidence[maintain_ids], ) - return maintained_channel_hypotheses, remove_ids def _sample_informed( self, @@ -649,3 +690,116 @@ def _sample_informed( poses=selected_rotations, evidence=selected_feature_evidence, ) + + def remap_hypotheses_ids_to_present( + self, + hypotheses_ids: ConsistentHypothesesIds, + ) -> ConsistentHypothesesIds: + """Update hypotheses ids based on resizing of hypothesis space. + + This function will receive hypotheses ids in a hypothesis space from + the previous timestep and find the ids of those same hypotheses in the + current hypothesis space (i.e. after resizing). + + Within a single channel, we only need the `removed_ids` to shift the + `hypotheses_ids`. This is because `added_ids` are appended to the end + of the channel. However, when dealing with multiple stacked channels, + the resizing of one channel affects the subsequent channels. + + We perform two main operations here: + - Channel rebasing: This takes care of the full channel shift that is needed + due to resizing of preceding channels. This is done by changing the + starting index of the whole channel. + - Channel-specific id shifting: This uses the `removed_ids` to shift the ids + within the channel itself. + + Note that we do not remap the full hypothesis space, we only remap + a selection of hypotheses ids (defined using `ConsistentHypothesesIds`). + To remap the full hypothesis space, `hypotheses_ids` should contain all of + the ids in the hypothesis space. + + Args: + hypotheses_ids: Previous timestep hypotheses ids to be updated + + Returns: + The list of the updated hypotheses ids in the current timestep/hypothesis + space. 
+ """ + # Exit if no hypotheses_ids or no telemetry for this graph + if ( + hypotheses_ids is None + or hypotheses_ids.graph_id not in self.resampling_telemetry + or not len(hypotheses_ids.channel_sizes) + ): + return hypotheses_ids + + telemetry = self.resampling_telemetry[hypotheses_ids.graph_id] + ids = np.asarray(hypotheses_ids.hypotheses_ids, dtype=np.int64) + names, sizes = zip(*hypotheses_ids.channel_sizes.items()) + sizes = np.asarray(sizes, dtype=np.int64) + starts = np.r_[0, np.cumsum(sizes)[:-1]] + + # Collect removed and added counts per channel for "rebasing" ids + # Note that rebasing ids is important to cancel out the id shift + # effect of `added_ids` in preceding channels. + removed_map = {} + rem_counts, add_counts = [], [] + for name in names: + removed_ids = np.asarray(telemetry[name]["removed_ids"], np.int64) + removed_map[name] = np.sort(removed_ids) + rem_counts.append(removed_ids.size) + add_counts.append(len(telemetry[name]["added_ids"])) + rem_counts = np.asarray(rem_counts, dtype=np.int64) + add_counts = np.asarray(add_counts, dtype=np.int64) + + # Calculate the new bases of all channels after deletions and appended additions + new_lens = sizes - rem_counts + add_counts + new_bases = np.r_[0, np.cumsum(new_lens)[:-1]] + + out = [] + for name, start, size, new_base in zip(names, starts, sizes, new_bases): + # Pick only the ids that belong to this channel's original span + channel_mask = (ids >= start) & (ids < start + size) + if not channel_mask.any(): + continue + + # Convert from global to channel-local indices + local = ids[channel_mask] - start + channel_removed_ids = removed_map[name] + + if channel_removed_ids.size: + # Drop removed ids + keep_mask = ~np.isin(local, channel_removed_ids) + if keep_mask.any(): + keep_ids = local[keep_mask] + # Shift non-dropped ids left by how many removed indices are less + shift = np.searchsorted(channel_removed_ids, keep_ids, side="left") + out.append(new_base + (keep_ids - shift)) # Rebase and shift ids + else: + out.append(new_base + local) # Only rebase ids + + new_ids = np.concatenate(out) if out else np.empty(0, dtype=np.int64) + return replace(hypotheses_ids, hypotheses_ids=new_ids) + + def _max_global_slope(self) -> float: + """Compute the maximum slope over all objects and channels. + + Returns: + The maximum global slope if finitie, otherwise float("nan") + """ + max_slope = float("-inf") + + for tracker in self.evidence_slope_trackers.values(): + for channel in tracker.evidence_buffer.keys(): + if tracker.total_size(channel) == 0: + continue + + slopes = tracker.calculate_slopes(channel) + if slopes.size == 0: + continue + + finite_slopes = slopes[np.isfinite(slopes)] + if finite_slopes.size: + max_slope = max(max_slope, np.max(finite_slopes)) + + return float(max_slope if np.isfinite(max_slope) else "nan") diff --git a/src/tbp/monty/frameworks/models/goal_state_generation.py b/src/tbp/monty/frameworks/models/goal_state_generation.py index 95b04d21e..c644d5d4b 100644 --- a/src/tbp/monty/frameworks/models/goal_state_generation.py +++ b/src/tbp/monty/frameworks/models/goal_state_generation.py @@ -882,7 +882,7 @@ def _check_conditions_for_hypothesis_test(self): pose and object ID determination, i.e. determines whether there is a good chance of discriminating between conflicting object IDs or poses. 
- The schedule is designed to balance descriminating the pose and objects as + The schedule is designed to balance discriminating the pose and objects as efficiently as possible; TODO M future work can use the schedule conditions as primitives and use RL or evolutionary algorithms to optimize the relevant parameters. @@ -909,7 +909,16 @@ def _check_conditions_for_hypothesis_test(self): top_id, second_id = self.parent_lm.get_top_two_mlh_ids() - if top_id == second_id: + # This happens when all hypothesis spaces are empty + if top_id is None and second_id is None: + return False + + if second_id is None: + # If we only have one object with a single hypothesis, we should not + # attempt to generate a goal state. + if len(self.parent_lm.hyp_evidences_for_object(top_id)) == 1: + return False + # If we only know (i.e. have learned) about one object, we can focus on pose # In this case, get_top_two_mlh_ids returns the same IDs for top_id and # second_id @@ -920,7 +929,7 @@ def _check_conditions_for_hypothesis_test(self): top_mlh = self.parent_lm.get_current_mlh() # If the MLH evidence is significantly above the second MLH (where "significant" - # is determined by x_percent_scale_factor below), then focus on descriminating + # is determined by x_percent_scale_factor below), then focus on discriminating # its pose on some (random) occasions; always focus on pose if we've convereged # to one object # TODO M update so that not accessing private methods here; part of 2nd phase @@ -933,6 +942,11 @@ def _check_conditions_for_hypothesis_test(self): if ( len(pm_smaller_thresh) == 1 and (self.parent_lm.rng.uniform() <= 0.5) ) or len(pm_base_thresh) == 1: + # If we only have one object with a single hypothesis, we should not + # attempt to generate a goal state. + if len(self.parent_lm.hyp_evidences_for_object(top_id)) == 1: + return False + # We always focus on pose if there is just 1 possible match - if we are part # of the way towards being certain about the ID # (len(pm_smaller_thresh) == 1), then we sometimes (hence the randomness) diff --git a/src/tbp/monty/frameworks/models/mixins/no_reset_evidence.py b/src/tbp/monty/frameworks/models/mixins/no_reset_evidence.py index af5706a58..4d2ae1ac3 100644 --- a/src/tbp/monty/frameworks/models/mixins/no_reset_evidence.py +++ b/src/tbp/monty/frameworks/models/mixins/no_reset_evidence.py @@ -101,7 +101,7 @@ def _add_detailed_stats(self, stats: dict[str, Any]) -> dict[str, Any]: Returns: Updated statistics dictionary. """ - stats["max_evidence"] = {k: max(v) for k, v in self.evidence.items()} + stats["max_evidence"] = {k: max(v) for k, v in self.evidence.items() if len(v)} stats["target_object_theoretical_limit"] = ( self._theoretical_limit_target_object_pose_error() ) @@ -137,6 +137,20 @@ def _channel_telemetry( HypothesesUpdaterChannelTelemetry for the given graph ID and input channel. 
""" mapper = self.channel_hypothesis_mapping[graph_id] + + if input_channel not in mapper.channels: + channel_evidence = np.empty(shape=(0,), dtype=np.float64) + channel_rotations_inv = np.empty(shape=(0, 3, 3), dtype=np.float64) + channel_locations = np.empty(shape=(0, 3), dtype=np.float64) + + return HypothesesUpdaterChannelTelemetry( + hypotheses_updater=channel_telemetry.copy(), + evidence=np.empty(shape=(0,), dtype=np.float64), + rotations=np.empty(shape=(0, 3, 3), dtype=np.float64), + locations=np.empty(shape=(0, 3), dtype=np.float64), + pose_errors=np.empty(shape=(0,), dtype=np.float64), + ) + channel_rotations = mapper.extract(self.possible_poses[graph_id], input_channel) channel_rotations_inv = Rotation.from_matrix(channel_rotations).inv() channel_evidence = mapper.extract(self.evidence[graph_id], input_channel) diff --git a/src/tbp/monty/frameworks/utils/evidence_matching.py b/src/tbp/monty/frameworks/utils/evidence_matching.py index 9161464d2..6551218d5 100644 --- a/src/tbp/monty/frameworks/utils/evidence_matching.py +++ b/src/tbp/monty/frameworks/utils/evidence_matching.py @@ -9,6 +9,7 @@ from __future__ import annotations from collections import OrderedDict +from dataclasses import dataclass from typing import OrderedDict as OrderedDictType import numpy as np @@ -96,40 +97,40 @@ def channel_range(self, channel_name: str) -> tuple[int, int]: return (start, start + size) start += size - def resize_channel_by(self, channel_name: str, value: int) -> None: - """Increases or decreases the channel by a specific amount. + def resize_channel_to(self, channel_name: str, new_size: int) -> None: + """Sets the size of the given channel to a specific value. + + This function will also delete the channel if the `new_size` is 0. Args: channel_name: The name of the channel. - value: The value used to modify the channel size. - Use a negative value to decrease the size. + new_size: The new size to set for the channel. Raises: - ValueError: If the channel is not found or the requested size is negative. + ValueError: If the channel is not found or if the new size is not positive. """ if channel_name not in self.channel_sizes: raise ValueError(f"Channel '{channel_name}' not found.") - if self.channel_sizes[channel_name] + value <= 0: - raise ValueError( - f"Channel '{channel_name}' size cannot be negative or zero." - ) - self.channel_sizes[channel_name] += value + if new_size < 0: + raise ValueError(f"Channel '{channel_name}' size must be positive.") + if new_size == 0: + self.delete_channel(channel_name) + return - def resize_channel_to(self, channel_name: str, new_size: int) -> None: - """Sets the size of the given channel to a specific value. + self.channel_sizes[channel_name] = new_size + + def delete_channel(self, channel_name: str) -> None: + """Delete a channel from the mapping. Args: - channel_name: The name of the channel. - new_size: The new size to set for the channel. + channel_name: The name of the channel to delete. Raises: - ValueError: If the channel is not found or if the new size is not positive. + ValueError: If the channel is not found. 
""" if channel_name not in self.channel_sizes: raise ValueError(f"Channel '{channel_name}' not found.") - if new_size <= 0: - raise ValueError(f"Channel '{channel_name}' size must be positive.") - self.channel_sizes[channel_name] = new_size + del self.channel_sizes[channel_name] def add_channel( self, channel_name: str, size: int, position: int | None = None @@ -277,7 +278,7 @@ class EvidenceSlopeTracker: hyp_age: Maps channel names to hypothesis age counters. """ - def __init__(self, window_size: int = 3, min_age: int = 5) -> None: + def __init__(self, window_size: int = 10, min_age: int = 5) -> None: """Initializes the EvidenceSlopeTracker. Args: @@ -416,50 +417,160 @@ def clear_hyp(self, channel: str) -> None: if channel in self.evidence_buffer: self.remove_hyp(np.arange(self.total_size(channel)), channel) - def calculate_keep_and_remove_ids( - self, num_keep: int, channel: str - ) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]: - """Determines which hypotheses to keep and which to remove in a channel. + def select_hypotheses( + self, slope_threshold: float, channel: str + ) -> HypothesesSelection: + """Returns a hypotheses selection given a slope threshold. - Hypotheses with the lowest average slope are selected for removal. + A hypothesis is maintained if: + - Its slope is >= the threshold, OR + - It is not yet removable due to age. Args: - num_keep: Requested number of hypotheses to retain. + slope_threshold: Minimum slope value to keep a removable (sufficiently old) + hypothesis. channel: Name of the input channel. Returns: - - to_keep: Indices of hypotheses to retain. - - to_remove: Indices of hypotheses to remove. + A selection of hypotheses to maintain. Raises: ValueError: If the channel does not exist. - ValueError: If the requested hypotheses to retain are more than available - hypotheses. """ if channel not in self.evidence_buffer: raise ValueError(f"Channel '{channel}' does not exist.") - total_size = self.total_size(channel) - if num_keep > total_size: - raise ValueError( - f"Cannot keep {num_keep} hypotheses; only {total_size} exist." - ) - total_ids = np.arange(total_size) - num_remove = total_size - num_keep - - # Retrieve valid slopes and sort them - removable_mask = self.removable_indices_mask(channel) slopes = self.calculate_slopes(channel) - removable_slopes = slopes[removable_mask] - removable_ids = total_ids[removable_mask] - sorted_indices = np.argsort(removable_slopes) + removable_mask = self.removable_indices_mask(channel) + + maintain_mask = (slopes >= slope_threshold) | (~removable_mask) + + return HypothesesSelection(maintain_mask) + + +class HypothesesSelection: + """Encapsulates the selection of hypotheses to maintain or remove. + + This class stores a boolean mask indicating which hypotheses should be maintained. + From this mask, it can return the indices and masks for both the maintained and + removed hypotheses. It also provides convenience constructors for creating a + selection from maintain/remove masks or from maintain/remove index lists. + + Attributes: + _maintain_mask: Boolean mask of shape (N,) where True indicates a maintain + hypothesis and False indicates a remove hypothesis. 
+ """ - # Calculate which ids to keep and which to remove - to_remove = removable_ids[sorted_indices[:num_remove]] - to_remove_mask = np.zeros(total_size, dtype=bool) - to_remove_mask[to_remove] = True - to_keep = total_ids[~to_remove_mask] - return to_keep, to_remove + def __init__(self, maintain_mask: npt.NDArray[np.bool_]) -> None: + """Initializes a HypothesesSelection from a maintain mask. + + Args: + maintain_mask: Boolean array-like of shape (N,) where True indicates a + maintained hypothesis and False indicates a removed hypothesis. + """ + self._maintain_mask = np.asarray(maintain_mask, dtype=bool) + + @classmethod + def from_maintain_mask(cls, mask: npt.NDArray[np.bool_]) -> HypothesesSelection: + """Creates a selection from a maintain mask. + + Args: + mask: Boolean array-like where True indicates a maintained hypothesis. + + Returns: + A HypothesesSelection instance. + + Note: + This method is added from completeness, but it is redundant as it calls the + default class `__init__` function. + """ + return cls(mask) + + @classmethod + def from_remove_mask(cls, mask: npt.NDArray[np.bool_]) -> HypothesesSelection: + """Creates a hypotheses selection from a remove mask. + + Args: + mask: Boolean array-like where True indicates a hypothesis to remove. + + Returns: + A HypothesesSelection instance. + """ + return cls(~mask) + + @classmethod + def from_maintain_ids( + cls, total_size: int, ids: npt.NDArray[np.int_] + ) -> HypothesesSelection: + """Creates a hypotheses selection from maintain indices. + + Args: + total_size: Total number of hypotheses. + ids: Indices of hypotheses to maintain. + + Returns: + A HypothesesSelection instance. + + Raises: + IndexError: If any index is out of range [0, total_size). + """ + mask = np.zeros(int(total_size), dtype=bool) + + if ids.size: + if ids.min() < 0 or ids.max() >= total_size: + raise IndexError(f"maintain_ids outside [0, {total_size})") + mask[np.unique(ids)] = True + + return cls(mask) + + @classmethod + def from_remove_ids( + cls, total_size: int, ids: npt.NDArray[np.int_] + ) -> HypothesesSelection: + """Creates a selection from remove indices. + + Args: + total_size: Total number of hypotheses. + ids: Indices of hypotheses to remove. + + Returns: + A HypothesesSelection instance. + + Raises: + IndexError: If any index is out of range [0, total_size). + """ + mask = np.ones(int(total_size), dtype=bool) + + if ids.size: + if ids.min() < 0 or ids.max() >= total_size: + raise IndexError(f"remove_ids outside [0, {total_size})") + mask[np.unique(ids)] = False + + return cls(mask) + + @property + def maintain_mask(self) -> npt.NDArray[np.bool_]: + """Returns the maintain mask.""" + return self._maintain_mask + + @property + def remove_mask(self) -> npt.NDArray[np.bool_]: + """Returns the remove mask.""" + return ~self._maintain_mask + + @property + def maintain_ids(self) -> npt.NDArray[np.int_]: + """Returns the indices of maintained hypotheses.""" + return np.flatnonzero(self._maintain_mask).astype(int) + + @property + def remove_ids(self) -> npt.NDArray[np.int_]: + """Returns the indices of removed hypotheses.""" + return np.flatnonzero(~self._maintain_mask).astype(int) + + def __len__(self) -> int: + """Returns the total number of hypotheses in the selection.""" + return int(self._maintain_mask.size) def evidence_update_threshold( @@ -532,6 +643,20 @@ def evidence_update_threshold( ) +@dataclass +class ConsistentHypothesesIds: + """Contains hypotheses ids for symmetry detection. 
+ + These ids will be updated when using the `ResamplingHypothesesUpdater`. + The update makes sure the ids are consistent across matching steps despite + resizing of hypothesis spaces. + """ + + hypotheses_ids: npt.NDArray[np.int_] + channel_sizes: OrderedDictType[str, int] + graph_id: str + + class InvalidEvidenceThresholdConfig(ValueError): """Raised when the evidence update threshold is invalid.""" diff --git a/src/tbp/monty/frameworks/utils/graph_matching_utils.py b/src/tbp/monty/frameworks/utils/graph_matching_utils.py index b33f4198b..83ad728a7 100644 --- a/src/tbp/monty/frameworks/utils/graph_matching_utils.py +++ b/src/tbp/monty/frameworks/utils/graph_matching_utils.py @@ -239,24 +239,28 @@ def get_scaled_evidences(evidences, per_object=False): scaled_evidences = {} if per_object: for graph_id in evidences.keys(): - scaled_evidences[graph_id] = ( - evidences[graph_id] - np.min(evidences[graph_id]) - ) / (np.max(evidences[graph_id]) - np.min(evidences[graph_id])) - # put in range(-1, 1) - scaled_evidences[graph_id] = (scaled_evidences[graph_id] - 0.5) * 2 + if len(evidences[graph_id]): + graph_evidences = evidences[graph_id] + min_evidence = np.min(graph_evidences) + max_evidence = np.max(graph_evidences) + scaled_evidences[graph_id] = (evidences[graph_id] - min_evidence) / ( + max_evidence - min_evidence + ) + # put in range(-1, 1) + scaled_evidences[graph_id] = (scaled_evidences[graph_id] - 0.5) * 2 else: min_evidence = np.inf max_evidence = -np.inf for graph_id in evidences.keys(): - minev = np.min(evidences[graph_id]) - if minev < min_evidence: - min_evidence = minev - maxev = np.max(evidences[graph_id]) - if maxev > max_evidence: - max_evidence = maxev + graph_evidences = evidences[graph_id] + if len(graph_evidences): + min_evidence = min(min_evidence, np.min(graph_evidences)) + max_evidence = max(max_evidence, np.max(graph_evidences)) + for graph_id in evidences.keys(): + graph_evidences = evidences[graph_id] if max_evidence >= 1: - scaled_evidences[graph_id] = (evidences[graph_id] - min_evidence) / ( + scaled_evidences[graph_id] = (graph_evidences - min_evidence) / ( max_evidence - min_evidence ) # put in range(-1, 1) @@ -264,7 +268,7 @@ def get_scaled_evidences(evidences, per_object=False): else: # If largest value is <1, don't scale them -> don't increase any # evidences. Instead just make sure they are in the right range. - scaled_evidences[graph_id] = np.clip(evidences[graph_id], -1, 1) + scaled_evidences[graph_id] = np.clip(graph_evidences, -1, 1) return scaled_evidences diff --git a/tests/unit/frameworks/models/evidence_matching/resampling_hypotheses_updater_test.py b/tests/unit/frameworks/models/evidence_matching/resampling_hypotheses_updater_test.py index ee644d1e8..cb3ae594e 100644 --- a/tests/unit/frameworks/models/evidence_matching/resampling_hypotheses_updater_test.py +++ b/tests/unit/frameworks/models/evidence_matching/resampling_hypotheses_updater_test.py @@ -6,6 +6,8 @@ # Use of this source code is governed by the MIT # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. 
+from collections import OrderedDict + import pytest pytest.importorskip( @@ -20,10 +22,19 @@ from tbp.monty.frameworks.utils.evidence_matching import ( ChannelMapper, + ConsistentHypothesesIds, EvidenceSlopeTracker, ) +def make_consistent_ids(graph_id, sizes, ids): + return ConsistentHypothesesIds( + graph_id=graph_id, + channel_sizes=OrderedDict(sizes), + hypotheses_ids=np.asarray(ids, dtype=np.int64), + ) + + class ResamplingHypothesesUpdaterTest(TestCase): def setUp(self) -> None: super().setUp() @@ -39,6 +50,13 @@ def setUp(self) -> None: ], ) + def get_resampling_updater(self): + train_config = copy.deepcopy(self.pretraining_configs) + with MontySupervisedObjectPretrainingExperiment(train_config) as train_exp: + train_exp.setup_experiment(train_exp.config) + + return train_exp.model.learning_modules[0].hypotheses_updater + def get_pretrained_resampling_lm(self): exp = hydra.utils.instantiate(self.cfg.test) with exp: @@ -62,13 +80,13 @@ def _num_hyps_multiplier(self, rlm, pose_defined): def run_sample_count( self, rlm, - count_multiplier, - existing_to_new_ratio, + resampling_multiplier, + deletion_trigger_slope, pose_defined, graph_id, ): - rlm.hypotheses_updater.hypotheses_count_multiplier = count_multiplier - rlm.hypotheses_updater.hypotheses_existing_to_new_ratio = existing_to_new_ratio + rlm.hypotheses_updater.resampling_multiplier = resampling_multiplier + rlm.hypotheses_updater.deletion_trigger_slope = deletion_trigger_slope test_features = {"patch": {"pose_fully_defined": pose_defined}} return rlm.hypotheses_updater._sample_count( input_channel="patch", @@ -78,34 +96,12 @@ def run_sample_count( tracker=rlm.hypotheses_updater.evidence_slope_trackers[graph_id], ) - def _initial_count(self, rlm, pose_defined): - """This tests that the initial requested number of hypotheses is correct. - - In order to initialize a hypothesis space, the `_sample_count` should request - that all resampled hypotheses be of the type informed. This tests the informed - sampling with defined and undefined poses. - """ - graph_id = "capsule3DSolid" - existing_count, informed_count = self.run_sample_count( - rlm=rlm, - count_multiplier=1, - existing_to_new_ratio=0.1, - pose_defined=pose_defined, - graph_id=graph_id, - ) - self.assertEqual(existing_count, 0) - self.assertEqual( - informed_count, - self._graph_node_count(rlm, graph_id) - * self._num_hyps_multiplier(rlm, pose_defined), - ) - - def _count_multiplier(self, rlm): - """This tests that the count multiplier correctly scales the hypothesis space. + def _resampling_multiplier(self, rlm): + """Tests that the resampling multiplier correctly scales the hypothesis space. - The count multiplier parameter is used to scale the hypothesis space between - steps. For example, a multiplier of 2, will request to double the number of - hypotheses. + The resampling multiplier parameter is used to scale the hypothesis space + between steps. For example, a multiplier of 2, will request to increase the + number of hypotheses by 2x the number of graph nodes. 
""" graph_id = "capsule3DSolid" pose_defined = True @@ -115,147 +111,209 @@ def _count_multiplier(self, rlm): rlm.hypotheses_updater.evidence_slope_trackers[graph_id].add_hyp( before_count, "patch" ) - count_multipliers = [0.5, 1, 2] + resampling_multipliers = [0.5, 1, 2] - for count_multiplier in count_multipliers: - existing_count, informed_count = self.run_sample_count( + for resampling_multiplier in resampling_multipliers: + _, informed_count = self.run_sample_count( rlm=rlm, - count_multiplier=count_multiplier, - existing_to_new_ratio=0.5, + resampling_multiplier=resampling_multiplier, + deletion_trigger_slope=0.0, pose_defined=pose_defined, graph_id=graph_id, ) - self.assertEqual( - before_count * count_multiplier, (existing_count + informed_count) - ) + self.assertEqual(graph_num_nodes * resampling_multiplier, informed_count) # Reset mapper rlm.channel_hypothesis_mapping[graph_id] = ChannelMapper() - def _count_multiplier_maximum(self, rlm, pose_defined): - """This tests that the count multiplier respects the maximum scaling boundary. - - The count multiplier parameter is used to scale the hypothesis space between - steps. For example, a multiplier of 2, will request to double the number of - hypotheses. However, there is a limit to how many hypotheses we can resample. - For existing hypotheses, the limit is to resample all of them. For newly - resampled informed hypotheses, the limit depends on whether the pose is defined - or not. This test ensures that `_sample_count` respects the maximum sampling - limit. + def _resampling_multiplier_maximum(self, rlm, pose_defined): + """Tests that the resampling multiplier respects the maximum scaling boundary. - In the case of `pose_defined = True` - Existing is 72 and informed is 2*36=72 (total is 144) - Maximum multiplier can be 2 if the pose is defined + The resampling multiplier is used to scale the hypothesis space between + steps. For example, a multiplier of 2, will request to add hypotheses of + count that is twice the number of nodes in the object graph. However, there + is a limit to how many hypotheses we can resample. For existing hypotheses, + the limit is to resample all of them. For newly resampled informed hypotheses, + the limit depends on whether the pose is defined or not. This test ensures + that `_sample_count` respects the maximum sampling limit. - In the case of `pose_defined = False` - Existing is 72 and informed is 8*36=288 (total is 360) - Maximum multiplier can be umbilical_num_poses if the pose is undefined + Maximum multiplier cannot exceed the num_hyps_per_node (2 if + `pose_defined=True` or umbilical_num_poses if `pose_defined=False`). 
""" graph_id = "capsule3DSolid" graph_num_nodes = self._graph_node_count(rlm, graph_id) before_count = graph_num_nodes * self._num_hyps_multiplier(rlm, pose_defined) rlm.channel_hypothesis_mapping[graph_id].add_channel("patch", before_count) - requested_count_multiplier = 100 + resampling_multiplier = 100 expected_count = before_count + ( graph_num_nodes * self._num_hyps_multiplier(rlm, pose_defined) ) - existing_count, informed_count = self.run_sample_count( + _, informed_count = self.run_sample_count( rlm=rlm, - count_multiplier=requested_count_multiplier, - existing_to_new_ratio=0.5, + resampling_multiplier=resampling_multiplier, + deletion_trigger_slope=0.0, pose_defined=pose_defined, graph_id=graph_id, ) - self.assertEqual(expected_count, existing_count + informed_count) + self.assertEqual(expected_count, before_count + informed_count) # Reset mapper rlm.channel_hypothesis_mapping[graph_id] = ChannelMapper() - def _count_ratio(self, rlm, pose_defined): - """This tests that the resampling ratio of new hypotheses is correct. - - The existing_to_new_ratio parameter is used to control the ratio of how many - existing vs. informed hypotheses to resample. This test ensures that the - `_sample_count` function follows the expected behavior of this ratio parameter. - - Note that the `_sample_count` function will prioritize the multiplier count - parameter over this ratio parameter. In other words, if not enough existing - hypotheses are available, the function will attempt to fill the missing - existing hypotheses with informed hypotheses. + def test_sampling_count(self): + """This function tests different aspects of _sample_count. + We define three different tests of `_sample_count`: + - Testing the requested count for initialization of hypotheses space + - Testing the resampling multiplier parameter + - Testing the resampling multiplier parameter maximum limit """ - graph_id = "capsule3DSolid" - graph_num_nodes = self._graph_node_count(rlm, graph_id) - available_existing_count = graph_num_nodes * self._num_hyps_multiplier( - rlm, pose_defined - ) - rlm.channel_hypothesis_mapping[graph_id].add_channel( - "patch", available_existing_count - ) - rlm.hypotheses_updater.evidence_slope_trackers[graph_id].add_hyp( - available_existing_count, "patch" - ) - count_multiplier = 2 + rlm = self.get_pretrained_resampling_lm() - for ratio in [0.0, 0.1, 0.5, 0.9, 1.0]: - requested_existing_count = ( - available_existing_count * count_multiplier * (1.0 - ratio) - ) - requested_informed_count = ( - available_existing_count * count_multiplier * ratio - ) - maximum_available_existing_count = available_existing_count - maximum_available_informed_count = ( - graph_num_nodes * self._num_hyps_multiplier(rlm, pose_defined) - ) + # test count multiplier + self._resampling_multiplier(rlm) + self._resampling_multiplier_maximum(rlm, pose_defined=True) + self._resampling_multiplier_maximum(rlm, pose_defined=False) + + def _single_channel_no_changes(self, updater): + updater.resampling_telemetry = { + "mug": {"patch": {"removed_ids": [], "added_ids": []}} + } + hyp_ids = make_consistent_ids( + graph_id="mug", sizes=[("patch", 5)], ids=[0, 1, 3, 4] + ) - existing_count, informed_count = self.run_sample_count( - rlm=rlm, - count_multiplier=count_multiplier, - existing_to_new_ratio=ratio, - pose_defined=pose_defined, - graph_id=graph_id, - ) - expected_existing_count = min( - maximum_available_existing_count, - requested_existing_count, - ) - self.assertEqual(existing_count, int(expected_existing_count)) + hyp_ids = 
updater.remap_hypotheses_ids_to_present(hyp_ids) + np.testing.assert_array_equal(hyp_ids.hypotheses_ids, np.array([0, 1, 3, 4])) - # `missing_existing_hypotheses` will be zero, or otherwise the count that - # informed hypotheses need to fill in - missing_existing_hypotheses = ( - requested_existing_count - expected_existing_count - ) - expected_informed_count = min( - maximum_available_informed_count, - (requested_informed_count + missing_existing_hypotheses), - ) - self.assertEqual(informed_count, int(expected_informed_count)) + def _single_channel_with_removals_shifts(self, updater): + updater.resampling_telemetry = { + "mug": {"patch": {"removed_ids": [1, 4, 6], "added_ids": []}} + } + hyp_ids = make_consistent_ids( + graph_id="mug", sizes=[("patch", 8)], ids=[0, 2, 3, 5, 7] + ) + hyp_ids = updater.remap_hypotheses_ids_to_present(hyp_ids) - # Reset mapper - rlm.channel_hypothesis_mapping[graph_id] = ChannelMapper() + # Shift per searchsorted([1,4,6], x, 'left'): 0->0, 2->1, 3->1, 5->2, 7->3 + # new locals after shifting: [0,1,2,3,4] + np.testing.assert_array_equal(hyp_ids.hypotheses_ids, np.array([0, 1, 2, 3, 4])) - def test_sampling_count(self): - """This function tests different aspects of _sample_count. + def _single_channel_full_remap_misses_added(self, updater): + """Tests that added ids do not show up in remapping. - We define three different tests of `_sample_count`: - - Testing the requested count for initialization of hypotheses space - - Testing the count multiplier parameter - - Testing the count ratio of resampled hypotheses + The remapping function finds the mapping between ids from the previous step + to the current time step. The added_ids did not exist in previous steps, + therefore should not appear in the mapping. """ - rlm = self.get_pretrained_resampling_lm() + added_ids = [5, 6] - # test initial count - self._initial_count(rlm, pose_defined=True) - self._initial_count(rlm, pose_defined=False) + updater.resampling_telemetry = { + "mug": {"patch": {"removed_ids": [], "added_ids": added_ids}} + } + hyp_ids = make_consistent_ids( + graph_id="mug", sizes=[("patch", 5)], ids=list(range(5)) + ) + hyp_ids = updater.remap_hypotheses_ids_to_present(hyp_ids) + + # In patch0: locals ids = [0,1,2,3,4]; removed = []; shift = [0,0,0,0,0]. + # So [0,1,2,3,4] becomes [0,1,2,3,4] + np.testing.assert_array_equal(hyp_ids.hypotheses_ids, np.array([0, 1, 2, 3, 4])) + + # Added ids should NOT appear in the remapped ids + self.assertFalse(np.isin(added_ids, hyp_ids.hypotheses_ids).any()) + + def _multi_channel_rebase_due_to_resizing(self, updater): + updater.resampling_telemetry = { + "mug": { + "patch0": {"removed_ids": [1, 3], "added_ids": [5, 6]}, + "patch1": {"removed_ids": [2], "added_ids": [4]}, + } + } + hyp_ids = make_consistent_ids( + graph_id="mug", sizes=[("patch0", 5), ("patch1", 4)], ids=[0, 2, 4, 5, 7] + ) + hyp_ids = updater.remap_hypotheses_ids_to_present(hyp_ids) + + # In patch0: locals ids = [0,2,4]; removed = [1,3]; shift = [0,1,2]. + # So [0, 2, 4] becomes [0, 1, 2] + + # new bases are the same since patch0 removed 2 and added 2. 
+        # So new_bases = [0,5]
+
+        # In patch1: locals ids = [0,2]; removed = [2]; new base = 5
+        # So [5, 7] becomes [5]
+        np.testing.assert_array_equal(hyp_ids.hypotheses_ids, np.array([0, 1, 2, 5]))
+
+    def _rebase_when_first_channel_shrinks(self, updater):
+        updater.resampling_telemetry = {
+            "mug": {
+                "patch0": {"removed_ids": [1, 3], "added_ids": []},  # shrink by 2
+                "patch1": {"removed_ids": [], "added_ids": []},
+            }
+        }
+        hyp_ids = make_consistent_ids(
+            graph_id="mug",
+            sizes=[("patch0", 5), ("patch1", 4)],
+            ids=[0, 2, 4, 5, 7],
+        )
+        hyp_ids = updater.remap_hypotheses_ids_to_present(hyp_ids)
+
+        # In patch0: local = [0,2,4]; removed = [1,3]; shifts = [0,1,2]
+        # So [0,2,4] becomes [0,1,2]
+
+        # New bases: [0,3] so patch1 base is 3 (not 5)
+
+        # In patch1: locals = [0,2]
+        # So [5,7] becomes [3,5]
+        np.testing.assert_array_equal(hyp_ids.hypotheses_ids, np.array([0, 1, 2, 3, 5]))
+
+    def _rebase_when_first_channel_grows(self, updater):
+        updater.resampling_telemetry = {
+            "mug": {
+                "patch0": {"removed_ids": [], "added_ids": [5, 6]},  # grow by 2
+                "patch1": {"removed_ids": [], "added_ids": []},
+            }
+        }
+        hyp_ids = make_consistent_ids(
+            graph_id="mug",
+            sizes=[("patch0", 5), ("patch1", 4)],
+            ids=[0, 4, 5, 6, 8],
+        )
+        out = updater.remap_hypotheses_ids_to_present(hyp_ids)
+
+        # In patch0: local = [0,4]; added = [5,6]; no shifts
+        # So [0,4] becomes [0,4]
+
+        # New bases: [0,7]
+
+        # In patch1: locals = [0,1,3]
+        # So [0,1,3] becomes [7,8,10]
+        np.testing.assert_array_equal(out.hypotheses_ids, np.array([0, 4, 7, 8, 10]))
+
+    def _all_ids_removed_in_a_channel(self, updater):
+        updater.resampling_telemetry = {
+            "mug": {
+                "patch0": {"removed_ids": [0, 1, 2], "added_ids": []},
+                "patch1": {"removed_ids": [], "added_ids": []},
+            }
+        }
+        hyp_ids = make_consistent_ids(
+            graph_id="mug", sizes=[("patch0", 3), ("patch1", 3)], ids=[0, 1, 2, 3, 4, 5]
+        )
+        hyp_ids = updater.remap_hypotheses_ids_to_present(hyp_ids)
 
-        # test count multiplier
-        self._count_multiplier(rlm)
-        self._count_multiplier_maximum(rlm, pose_defined=True)
-        self._count_multiplier_maximum(rlm, pose_defined=False)
+        # Removed [0, 1, 2], so [3, 4, 5] was rebased to [0, 1, 2]
+        np.testing.assert_array_equal(hyp_ids.hypotheses_ids, np.array([0, 1, 2]))
+
+    def test_remap_hypotheses_ids(self):
+        updater = self.get_resampling_updater()
 
-        # test existing to informed ratio
-        self._count_ratio(rlm, pose_defined=True)
-        self._count_ratio(rlm, pose_defined=False)
+        self._single_channel_no_changes(updater)
+        self._single_channel_with_removals_shifts(updater)
+        self._single_channel_full_remap_misses_added(updater)
+        self._multi_channel_rebase_due_to_resizing(updater)
+        self._rebase_when_first_channel_shrinks(updater)
+        self._rebase_when_first_channel_grows(updater)
+        self._all_ids_removed_in_a_channel(updater)
diff --git a/tests/unit/frameworks/utils/evidence_matching_test.py b/tests/unit/frameworks/utils/evidence_matching_test.py
index 41a65e25e..a0c06591e 100644
--- a/tests/unit/frameworks/utils/evidence_matching_test.py
+++ b/tests/unit/frameworks/utils/evidence_matching_test.py
@@ -45,21 +45,6 @@ def test_channel_size(self) -> None:
         with self.assertRaises(ValueError):
             self.mapper.channel_size("D")
 
-    def test_resize_channel_by_positive(self) -> None:
-        """Test increasing channel sizes."""
-        self.mapper.resize_channel_by("B", 5)
-        self.assertEqual(self.mapper.channel_range("B"), (5, 20))
-        self.assertEqual(self.mapper.total_size, 35)
-
-    def test_resize_channel_by_negative(self) -> None:
-        """Test decreasing channel sizes."""
-        
self.mapper.resize_channel_by("B", -5) - self.assertEqual(self.mapper.channel_range("B"), (5, 10)) - self.assertEqual(self.mapper.total_size, 25) - - with self.assertRaises(ValueError): - self.mapper.resize_channel_by("A", -10) - def test_resize_channel_to_valid(self) -> None: """Test setting a new size for an existing channel.""" self.mapper.resize_channel_to("A", 8) @@ -68,18 +53,67 @@ def test_resize_channel_to_valid(self) -> None: self.assertEqual(self.mapper.channel_range("C"), (18, 33)) self.assertEqual(self.mapper.total_size, 33) + def test_resize_channel_to_zero_deletes(self) -> None: + """Resizing a channel to zero removes it.""" + self.mapper.resize_channel_to("B", 0) + self.assertEqual(self.mapper.channels, ["A", "C"]) + self.assertEqual(self.mapper.total_size, 20) + self.assertEqual(self.mapper.channel_range("A"), (0, 5)) + self.assertEqual(self.mapper.channel_range("C"), (5, 20)) + # "B" is gone + with self.assertRaises(ValueError): + self.mapper.channel_range("B") + def test_resize_channel_to_invalid_channel(self) -> None: """Test resizing a non-existent channel.""" with self.assertRaises(ValueError): self.mapper.resize_channel_to("Z", 5) def test_resize_channel_to_invalid_size(self) -> None: - """Test resizing a channel to a non-positive size.""" - with self.assertRaises(ValueError): - self.mapper.resize_channel_to("B", 0) + """Test resizing a channel to a negative size.""" with self.assertRaises(ValueError): self.mapper.resize_channel_to("B", -3) + def test_delete_channel_middle(self) -> None: + """Deleting a middle channel updates order, ranges, and total size.""" + self.mapper.delete_channel("B") + self.assertEqual(self.mapper.channels, ["A", "C"]) + self.assertEqual(self.mapper.total_size, 20) + self.assertEqual(self.mapper.channel_range("A"), (0, 5)) + self.assertEqual(self.mapper.channel_range("C"), (5, 20)) + + def test_delete_channel_first(self) -> None: + """Deleting the first channel shifts subsequent ranges correctly.""" + self.mapper.delete_channel("A") + self.assertEqual(self.mapper.channels, ["B", "C"]) + self.assertEqual(self.mapper.total_size, 25) + self.assertEqual(self.mapper.channel_range("B"), (0, 10)) + self.assertEqual(self.mapper.channel_range("C"), (10, 25)) + + def test_delete_channel_last(self) -> None: + """Deleting the last channel leaves earlier ranges unchanged.""" + self.mapper.delete_channel("C") + self.assertEqual(self.mapper.channels, ["A", "B"]) + self.assertEqual(self.mapper.total_size, 15) + self.assertEqual(self.mapper.channel_range("A"), (0, 5)) + self.assertEqual(self.mapper.channel_range("B"), (5, 15)) + + def test_delete_channel_nonexistent(self) -> None: + """Deleting an unknown channel raises.""" + with self.assertRaises(ValueError): + self.mapper.delete_channel("Z") + + def test_delete_all_channels(self) -> None: + """Deleting all channels yields an empty mapper.""" + self.mapper.delete_channel("A") + self.mapper.delete_channel("B") + self.mapper.delete_channel("C") + self.assertEqual(self.mapper.channels, []) + self.assertEqual(self.mapper.total_size, 0) + # Follow-up operations should still error cleanly + with self.assertRaises(ValueError): + self.mapper.channel_range("A") + def test_add_channel(self) -> None: """Test adding a new channel.""" self.mapper.add_channel("D", 8) @@ -278,27 +312,34 @@ def test_removable_indices_mask_matches_min_age(self) -> None: mask = self.tracker.removable_indices_mask(self.channel) np.testing.assert_array_equal(mask, [False, True, True]) - def 
test_calculate_keep_and_remove_ids_returns_expected(self) -> None: - """Test that hypotheses with the lowest slopes are selected for removal.""" - self.tracker.add_hyp(3, self.channel) - self.tracker.update(np.array([1.0, 3.0, 1.0]), self.channel) - self.tracker.update(np.array([2.0, 2.0, 1.0]), self.channel) - self.tracker.update(np.array([3.0, 1.0, 1.0]), self.channel) + def test_select_hypotheses_threshold_and_age(self) -> None: + """Test that select_hypotheses respects slope threshold and min_age.""" + self.tracker.add_hyp(4, self.channel) + + # slopes are [1, 0, -1, -1] + self.tracker.update(np.array([1.0, 2.0, 3.0, 3.0]), self.channel) + self.tracker.update(np.array([2.0, 2.0, 2.0, 2.0]), self.channel) + self.tracker.update(np.array([3.0, 2.0, 1.0, 1.0]), self.channel) + + # Force ages so only last hyp is too young to remove. + self.tracker.hyp_age[self.channel] = np.array([3, 3, 3, 1], dtype=int) - # Slopes = [1.0, -1.0, 0.0] - to_keep, to_remove = self.tracker.calculate_keep_and_remove_ids( - num_keep=2, channel=self.channel + selection = self.tracker.select_hypotheses( + slope_threshold=-0.5, channel=self.channel ) - np.testing.assert_array_equal(np.sort(to_keep), [0, 2]) - np.testing.assert_array_equal(to_remove, [1]) + # 0,1 have higher slopes, 3 is too young + expected_keep = np.array([0, 1, 3], dtype=int) + expected_keep_mask = np.array([True, True, False, True], dtype=bool) - def test_keep_more_than_total_raises(self) -> None: - """Test that asking to keep more hypotheses than exist raises an error.""" - self.tracker.add_hyp(2, self.channel) - self.tracker.hyp_age[self.channel][:] = [2, 2] - with self.assertRaises(ValueError): - self.tracker.calculate_keep_and_remove_ids(3, self.channel) + # lower slope than threshold (-1 < -0.5) + expected_remove = np.array([2], dtype=int) + expected_remove_mask = np.array([False, False, True, False], dtype=bool) + + np.testing.assert_array_equal(selection.maintain_ids, expected_keep) + np.testing.assert_array_equal(selection.remove_ids, expected_remove) + np.testing.assert_array_equal(selection.maintain_mask, expected_keep_mask) + np.testing.assert_array_equal(selection.remove_mask, expected_remove_mask) if __name__ == "__main__":
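Illustrative notes on the behaviour these tests pin down follow. All snippets are sketches written against the assertions above, not the actual tbp.monty implementations, and every helper name in them is hypothetical.

The multiplier tests assert that `_sample_count` requests `resampling_multiplier` informed hypotheses per graph node, capped at the per-node maximum. A minimal sketch of that cap, assuming `num_hyps_per_node` is 2 when the pose is fully defined and `umbilical_num_poses` otherwise:

    def informed_sample_count(num_nodes, resampling_multiplier, num_hyps_per_node):
        # Request resampling_multiplier new (informed) hypotheses per graph node...
        requested = num_nodes * resampling_multiplier
        # ...but never more than can be generated per node, which is the limit
        # _resampling_multiplier_maximum exercises with a multiplier of 100.
        maximum = num_nodes * num_hyps_per_node
        return min(requested, maximum)

For example, with 36 graph nodes and a fully defined pose (num_hyps_per_node = 2), any multiplier of 2 or more is clamped to 72 informed hypotheses.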
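The remapping tests encode a per-channel shift-and-rebase rule: previous-step ids are made channel-local, ids whose local index was removed are dropped, survivors are shifted left past the removals with `np.searchsorted`, and each channel is rebased onto the cumulative sizes of the resized channels. A self-contained sketch of that rule (the function name and the telemetry layout are inferred from the tests, not taken from the ResamplingHypothesesUpdater API):

    from collections import OrderedDict

    import numpy as np

    def remap_ids(channel_sizes, hypotheses_ids, telemetry):
        """Sketch of the remapping the tests above assert.

        channel_sizes: OrderedDict of channel -> size at the previous step.
        hypotheses_ids: global ids (into the concatenated hypothesis space)
            from the previous step.
        telemetry: channel -> {"removed_ids": [...], "added_ids": [...]} with
            sorted, channel-local indices.
        """
        hypotheses_ids = np.asarray(hypotheses_ids, dtype=np.int64)
        old_base, new_base, remapped = 0, 0, []
        for channel, old_size in channel_sizes.items():
            removed = np.asarray(telemetry[channel]["removed_ids"], dtype=np.int64)
            added = telemetry[channel]["added_ids"]
            # Select the previous-step ids belonging to this channel, made local.
            in_channel = (hypotheses_ids >= old_base) & (
                hypotheses_ids < old_base + old_size
            )
            local = hypotheses_ids[in_channel] - old_base
            # Drop removed ids, then shift survivors left past the removals.
            local = local[~np.isin(local, removed)]
            local = local - np.searchsorted(removed, local, side="left")
            # Rebase onto the resized space; earlier channels may have grown or shrunk.
            remapped.append(local + new_base)
            old_base += old_size
            new_base += old_size - len(removed) + len(added)
        return np.concatenate(remapped)

    sizes = OrderedDict([("patch0", 5), ("patch1", 4)])
    telemetry = {
        "patch0": {"removed_ids": [1, 3], "added_ids": [5, 6]},
        "patch1": {"removed_ids": [2], "added_ids": [4]},
    }
    print(remap_ids(sizes, [0, 2, 4, 5, 7], telemetry))  # [0 1 2 5]

The printed result matches `_multi_channel_rebase_due_to_resizing`; added ids never appear in the output because the mapping is only defined for ids that existed at the previous step.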
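The ChannelMapper changes under test (resizing a channel to zero deletes it; `delete_channel` shifts the ranges of later channels) reduce to ordered bookkeeping over channel sizes. A toy stand-in, for illustration only, not the real ChannelMapper:

    from collections import OrderedDict

    class TinyMapper:
        def __init__(self, sizes):
            self.sizes = OrderedDict(sizes)  # e.g. [("A", 5), ("B", 10), ("C", 15)]

        @property
        def channels(self):
            return list(self.sizes)

        @property
        def total_size(self):
            return sum(self.sizes.values())

        def channel_range(self, channel):
            if channel not in self.sizes:
                raise ValueError(f"unknown channel {channel!r}")
            start = 0
            for name, size in self.sizes.items():
                if name == channel:
                    return (start, start + size)
                start += size

        def delete_channel(self, channel):
            if channel not in self.sizes:
                raise ValueError(f"unknown channel {channel!r}")
            del self.sizes[channel]

        def resize_channel_to(self, channel, new_size):
            if channel not in self.sizes or new_size < 0:
                raise ValueError("unknown channel or negative size")
            if new_size == 0:
                # Resizing to zero removes the channel entirely, as the new test expects.
                self.delete_channel(channel)
            else:
                self.sizes[channel] = new_size

Starting from A=5, B=10, C=15, deleting "B" leaves channels ["A", "C"] with ranges (0, 5) and (5, 20) and total size 20, which is what test_delete_channel_middle asserts.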
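The new `select_hypotheses` test replaces keep-N selection with a slope threshold plus an age gate. A sketch of the rule the assertions encode, assuming a least-squares slope over the buffered evidence and a minimum age of 2 steps before removal is allowed (both details live in the real EvidenceSlopeTracker and may be computed differently there):

    import numpy as np

    def select_by_slope(evidence_history, ages, slope_threshold, min_age=2):
        # One least-squares slope per hypothesis (columns of evidence_history).
        steps = np.arange(evidence_history.shape[0])
        slopes = np.polyfit(steps, evidence_history, deg=1)[0]
        # Remove only hypotheses that are both declining past the threshold and
        # old enough for the slope estimate to be trusted.
        remove_mask = (slopes < slope_threshold) & (ages >= min_age)
        return np.flatnonzero(~remove_mask), np.flatnonzero(remove_mask)

    history = np.array(
        [[1.0, 2.0, 3.0, 3.0], [2.0, 2.0, 2.0, 2.0], [3.0, 2.0, 1.0, 1.0]]
    )
    keep, remove = select_by_slope(history, np.array([3, 3, 3, 1]), slope_threshold=-0.5)
    # keep == [0, 1, 3], remove == [2], matching the maintain/remove ids in the test.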