Changes from all commits
39 commits
efa8696
feat: decouple resampling params (#1)
ramyamounir Aug 15, 2025
fb16234
Merge branch 'main' into dev
ramyamounir Aug 20, 2025
ce1c37d
Merge branch 'main' into dev
ramyamounir Aug 28, 2025
fcb98a2
Merge branch 'main' into dev
ramyamounir Aug 31, 2025
9e72e94
Merge branch 'main' into dev
ramyamounir Sep 3, 2025
0b9ec0e
Merge branch 'main' into dev
ramyamounir Sep 8, 2025
823c741
Merge branch 'main' into dev
ramyamounir Sep 15, 2025
03daf45
Merge branch 'main' into dev
ramyamounir Sep 16, 2025
cea1465
Merge branch 'main' into dev
ramyamounir Sep 27, 2025
867b6c0
Merge branch 'main' into dev
ramyamounir Oct 2, 2025
7f07e97
Merge branch 'main' into dev
ramyamounir Oct 2, 2025
54fe5cd
Merge branch 'main' into dev
ramyamounir Oct 14, 2025
d84b031
feat!: add support for minimum maintained hypotheses (#2)
ramyamounir Oct 16, 2025
93f4e47
Merge branch 'main' into dev
ramyamounir Oct 16, 2025
d3da576
feat: symmetry remapping fix for consistent ids (#3)
ramyamounir Oct 17, 2025
0dff3d7
Merge branch 'main' into dev
ramyamounir Oct 17, 2025
7a65bcb
refactor: adjust default value of evidence_slope_threshold to 0.3
ramyamounir Oct 20, 2025
0227614
Merge branch 'main' into dev
ramyamounir Oct 21, 2025
262216f
refactor: update gsg and learning module to handle empty hypothesis s…
ramyamounir Oct 30, 2025
66d1bad
Merge branch 'main' into dev
ramyamounir Nov 2, 2025
f650dba
tests: provide init_hyp_space to sample count tests
ramyamounir Nov 2, 2025
ab2c18f
docs: update docstring to add evidence_slope_threshold expected range
ramyamounir Nov 10, 2025
a123357
refactor: update last_possible_hypotheses remapping
ramyamounir Nov 10, 2025
65119de
chore: added type hinting to _check_for_symmetry
ramyamounir Nov 10, 2025
da1dff6
refactor: move object_id check in symmetry logic
ramyamounir Nov 10, 2025
732c084
refactor: remove unneccessary variable
ramyamounir Nov 10, 2025
9855df4
style: ruff RET504
ramyamounir Nov 10, 2025
3d1fcee
docs: fix range of evidence_slope_threshold in docstring
ramyamounir Nov 11, 2025
97a92fa
docs: add comment about `new_informed` being divisible by `num_hyps_p…
ramyamounir Nov 11, 2025
085a1e1
Merge branch 'main' into dev
ramyamounir Nov 11, 2025
9166334
Merge branch 'main' into dev
ramyamounir Nov 11, 2025
c26d9dc
Merge branch 'main' into dev
ramyamounir Nov 11, 2025
dbecf0a
refactor: return update telemetry for prediction error as dict
ramyamounir Nov 11, 2025
9412d28
fix: temporary fix for init_hyp_space conditions
ramyamounir Nov 11, 2025
1ef9cf8
Merge branch 'main' into dev
ramyamounir Nov 12, 2025
fcf0546
Merge branch 'main' into dev
ramyamounir Nov 19, 2025
7985888
feat!: burst sampling added to the resampling updater (#6)
ramyamounir Nov 24, 2025
451a418
Merge branch 'main' into dev
ramyamounir Nov 24, 2025
c31557c
Merge branch 'main' into dev
ramyamounir Dec 19, 2025
@@ -35,7 +35,10 @@
DefaultHypothesesDisplacer,
HypothesisDisplacerTelemetry,
)
from tbp.monty.frameworks.utils.evidence_matching import ChannelMapper
from tbp.monty.frameworks.utils.evidence_matching import (
ChannelMapper,
ConsistentHypothesesIds,
)
from tbp.monty.frameworks.utils.graph_matching_utils import (
get_initial_possible_poses,
possible_sensed_directions,
@@ -56,6 +59,14 @@ class ChannelHypothesesUpdateTelemetry:


class HypothesesUpdater(Protocol):
def pre_step(self) -> None:
"""Runs once per step before updating the hypotheses."""
...

def post_step(self) -> None:
"""Runs once per step after updating the hypotheses."""
...

def update_hypotheses(
self,
hypotheses: Hypotheses,
@@ -81,6 +92,19 @@ def update_hypotheses(
"""
...

def remap_hypotheses_ids_to_present(
self, hypotheses_ids: ConsistentHypothesesIds
) -> ConsistentHypothesesIds:
"""Update hypotheses ids based on resizing of hypothesis space.

Args:
hypotheses_ids: Hypotheses ids to be updated

Returns:
The updated hypotheses ids.
"""
...


class DefaultHypothesesUpdater(HypothesesUpdater):
def __init__(
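For context, the HypothesesUpdater protocol extended above now carries a per-step lifecycle: pre_step runs once before any per-graph updates, post_step once after all of them. A minimal conforming updater might look like the sketch below (hypothetical class, not part of this PR; it mirrors DefaultHypothesesUpdater's no-op hooks and identity remapping):

    class NoOpHypothesesUpdater:
        """Hypothetical updater satisfying the extended protocol."""

        def pre_step(self) -> None:
            """Nothing to prepare before the step."""

        def post_step(self) -> None:
            """Nothing to tear down after the step."""

        def update_hypotheses(self, hypotheses, **kwargs):
            # Leave the hypothesis space untouched; no channel updates.
            return []

        def remap_hypotheses_ids_to_present(self, hypotheses_ids):
            # The space was never resized, so stored ids remain valid.
            return hypotheses_ids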
@@ -183,6 +207,14 @@ def __init__(
use_features_for_matching=self.use_features_for_matching,
)

def pre_step(self) -> None:
"""Runs once per step before updating the hypotheses."""
...

def post_step(self) -> None:
"""Runs once per step after updating the hypotheses."""
...

def update_hypotheses(
self,
hypotheses: Hypotheses,
@@ -403,6 +435,22 @@ def _get_initial_hypothesis_space(
poses=initial_possible_channel_rotations,
)

def remap_hypotheses_ids_to_present(
self, hypotheses_ids: ConsistentHypothesesIds
) -> ConsistentHypothesesIds:
"""Update hypotheses ids based on resizing of hypothesis space.

We do not resize the hypotheses space when using `DefaultHypothesesUpdater`,
therefore, we return the same ids without update.

Args:
hypotheses_ids: Hypotheses ids to be updated

Returns:
The updated hypotheses ids.
"""
return hypotheses_ids


def all_usable_input_channels(
features: dict, all_input_channels: list[str]
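Why id remapping exists at all: an updater that resamples (as in the burst-sampling commit above) can shrink or reorder a channel's hypothesis space between steps, so flat indices saved on step t can point at different or nonexistent hypotheses on step t+1. A toy illustration with made-up sizes:

    # Hypothetical: step t stored flat ids into a 100-hypothesis channel.
    stored_ids = [5, 42, 99]
    # Step t+1: suppose resampling kept only the even-indexed hypotheses,
    # so old id k survives as new id k // 2.
    surviving = {k: k // 2 for k in range(0, 100, 2)}
    remapped = [surviving[k] for k in stored_ids if k in surviving]
    assert remapped == [21]  # 42 -> 21; ids 5 and 99 were dropped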
137 changes: 101 additions & 36 deletions src/tbp/monty/frameworks/models/evidence_matching/learning_module.py
@@ -15,6 +15,7 @@
import time

import numpy as np
import numpy.typing as npt
from scipy.spatial import KDTree
from scipy.spatial.transform import Rotation

@@ -35,6 +36,7 @@
from tbp.monty.frameworks.models.states import State
from tbp.monty.frameworks.utils.evidence_matching import (
ChannelMapper,
ConsistentHypothesesIds,
evidence_update_threshold,
)
from tbp.monty.frameworks.utils.graph_matching_utils import (
@@ -521,18 +523,30 @@ def get_unique_pose_if_available(self, object_id):
# Only try to determine object pose if the evidence for it is high enough.
if possible_object_hypotheses_ids is not None:
mlh = self.get_current_mlh()

# Check if all possible poses are similar
pose_is_unique = self._check_for_unique_poses(
object_id, possible_object_hypotheses_ids, mlh["rotation"]
)

# Check for symmetry
last_possible_hypotheses_remapped = (
self.hypotheses_updater.remap_hypotheses_ids_to_present(
self.last_possible_hypotheses
)
)
symmetry_detected = self._check_for_symmetry(
possible_object_hypotheses_ids,
object_id=object_id,
last_possible_object_hypotheses=last_possible_hypotheses_remapped,
possible_object_hypotheses_ids=possible_object_hypotheses_ids,
# Don't increment symmetry counter if LM didn't process observation
increment_evidence=self.buffer.get_last_obs_processed(),
)

self.last_possible_hypotheses = possible_object_hypotheses_ids
self.last_possible_hypotheses = ConsistentHypothesesIds(
hypotheses_ids=possible_object_hypotheses_ids,
channel_sizes=self.channel_hypothesis_mapping[object_id].channel_sizes,
graph_id=object_id,
)

if pose_is_unique or symmetry_detected:
r_inv = mlh["rotation"].inv()
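The snapshot above is where this diff constructs ConsistentHypothesesIds, passing hypotheses_ids, channel_sizes, and graph_id (the branch body continues in the collapsed region). Its definition is not shown in this PR; inferred from that usage, it presumably looks roughly like the following (field types are guesses):

    from dataclasses import dataclass

    import numpy as np
    import numpy.typing as npt

    @dataclass
    class ConsistentHypothesesIds:
        """Sketch inferred from usage; the real class may differ."""

        hypotheses_ids: npt.NDArray[np.int64]  # flat ids into the hypothesis space
        channel_sizes: dict[str, int]  # per-channel sizes when the ids were taken
        graph_id: str  # object graph the ids belong to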
@@ -561,18 +575,18 @@ if symmetry_detected:
if symmetry_detected:
symmetry_stats = {
"symmetric_rotations": np.array(self.possible_poses[object_id])[
self.last_possible_hypotheses
self.last_possible_hypotheses.hypotheses_ids
],
"symmetric_locations": self.possible_locations[object_id][
self.last_possible_hypotheses
self.last_possible_hypotheses.hypotheses_ids
],
}
self.buffer.add_overall_stats(symmetry_stats)
return pose_and_scale

logger.debug(f"object {object_id} detected but pose not resolved yet.")
return None

self.last_possible_hypotheses = None
return None

def get_current_mlh(self):
@@ -603,16 +617,22 @@ def get_top_two_mlh_ids(self):
"""
graph_ids, graph_evidences = self.get_evidence_for_each_graph()

# If all hypothesis spaces are empty return None for both mlh ids. The gsg will
# not generate a goal state.
if len(graph_ids) == 0:
return None, None

# If we have a single hypothesis space, return the second object id as None.
# The gsg will focus on pose to generate a goal state.
if len(graph_ids) == 1:
return graph_ids[0], None

# Note the indices below will be ordered with the 2nd MLH appearing first, and
# the 1st MLH appearing second.
top_indices = np.argsort(graph_evidences)[-2:]
top_id = graph_ids[top_indices[1]]
second_id = graph_ids[top_indices[0]]

if len(top_indices) > 1:
top_id = graph_ids[top_indices[1]]
second_id = graph_ids[top_indices[0]]
else:
top_id = graph_ids[top_indices[0]]
second_id = top_id
# Account for the case where we have multiple top evidences with the same value.
# In this case argsort and argmax (used to get current_mlh) will return
# different results but some downstream logic (in gsg) expects them to be the
@@ -626,6 +646,7 @@
# and keep the second id as is (since this means there is a threeway
# tie in evidence values so its not like top is more likely than second)
top_id = self.current_mlh["graph_id"]

return top_id, second_id
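For context, the selection above relies on np.argsort returning indices in ascending evidence order, which is why the comment notes the 2nd MLH appears first. A quick illustration with invented values:

    import numpy as np

    graph_ids = ["mug", "bowl", "spoon"]
    graph_evidences = np.array([2.0, 5.0, 3.5])
    top_indices = np.argsort(graph_evidences)[-2:]  # ascending: 2nd MLH first
    second_id, top_id = graph_ids[top_indices[0]], graph_ids[top_indices[1]]
    assert (top_id, second_id) == ("bowl", "spoon")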

def get_top_two_pose_hypotheses_for_graph_id(self, graph_id):
Expand Down Expand Up @@ -683,15 +704,24 @@ def get_evidence_for_each_graph(self):
graph_ids = self.get_all_known_object_ids()
if graph_ids[0] not in self.evidence.keys():
return ["patch_off_object"], [0]
graph_evidences = []

available_graph_ids = []
available_graph_evidences = []
for graph_id in graph_ids:
graph_evidences.append(np.max(self.evidence[graph_id]))
return graph_ids, np.array(graph_evidences)
if len(self.hyp_evidences_for_object(graph_id)):
available_graph_ids.append(graph_id)
available_graph_evidences.append(np.max(self.evidence[graph_id]))

return available_graph_ids, np.array(available_graph_evidences)

def get_all_evidences(self):
"""Return evidence for each pose on each graph (pointer)."""
return self.evidence

def hyp_evidences_for_object(self, object_id):
"""Return evidences for a specific object_id."""
return self.evidence[object_id]
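The loop in get_evidence_for_each_graph above now skips graphs whose hypothesis space has been emptied (for example by resampling), since np.max raises on an empty array. A condensed, self-contained sketch of that guard with invented data:

    import numpy as np

    evidence = {"mug": np.array([1.0, 4.0]), "bowl": np.array([])}
    graph_ids, graph_evidences = [], []
    for graph_id, hyp_evidences in evidence.items():
        if len(hyp_evidences):  # skip emptied hypothesis spaces
            graph_ids.append(graph_id)
            graph_evidences.append(np.max(hyp_evidences))
    assert graph_ids == ["mug"] and graph_evidences == [4.0]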

# ------------------ Logging & Saving ----------------------
def collect_stats_to_save(self):
"""Get all stats that this LM should store in the buffer for logging.
@@ -710,6 +740,8 @@

def _update_possible_matches(self, query):
"""Update evidence for each hypothesis instead of removing them."""
self.hypotheses_updater.pre_step()

thread_list = []
for graph_id in self.get_all_known_object_ids():
if self.use_multithreading:
Expand Down Expand Up @@ -738,6 +770,8 @@ def _update_possible_matches(self, query):
self.previous_mlh = self.current_mlh
self.current_mlh = self._calculate_most_likely_hypothesis()

self.hypotheses_updater.post_step()

def _update_evidence(
self,
features: dict,
@@ -800,12 +834,20 @@
self._set_hypotheses_in_hpspace(graph_id=graph_id, new_hypotheses=update)

end_time = time.time()
assert not np.isnan(np.max(self.evidence[graph_id])), "evidence contains NaN."
logger.debug(

logger_msg = (
f"evidence update for {graph_id} took "
f"{np.round(end_time - start_time, 2)} seconds."
f" New max evidence: {np.round(np.max(self.evidence[graph_id]), 3)}"
)
graph_evidence = self.hyp_evidences_for_object(graph_id)
if len(graph_evidence):
assert not np.isnan(np.max(self.evidence[graph_id])), (
"evidence contains NaN."
)
logger_msg += (
f" New max evidence: {np.round(np.max(self.evidence[graph_id]), 3)}"
)
logger.debug(logger_msg)

def _set_hypotheses_in_hpspace(
self,
@@ -992,7 +1034,13 @@ def _check_for_unique_poses(

return location_unique and rotation_unique

def _check_for_symmetry(self, possible_object_hypotheses_ids, increment_evidence):
def _check_for_symmetry(
self,
object_id: str,
last_possible_object_hypotheses: ConsistentHypothesesIds | None,
possible_object_hypotheses_ids: npt.NDArray[np.int64],
increment_evidence: bool,
):
"""Check whether the most likely hypotheses stayed the same over the past steps.

Since the definition of possible_object_hypotheses is a bit murky and depends
@@ -1001,6 +1049,9 @@ def _check_for_symmetry(self, possible_object_hypotheses_ids, increment_evidence
not sure if this is the best way to check for symmetry...

Args:
object_id: identifier of the object being checked for symmetry
last_possible_object_hypotheses: All the possible hypotheses
from the last step.
possible_object_hypotheses_ids: List of IDs of all possible hypotheses.
increment_evidence: Whether to increment symmetry evidence or not. We
may want this to be False for example if we did not receive a new
Expand All @@ -1009,14 +1060,17 @@ def _check_for_symmetry(self, possible_object_hypotheses_ids, increment_evidence
Returns:
Whether symmetry was detected.
"""
if self.last_possible_hypotheses is None:
if (
last_possible_object_hypotheses is None
or last_possible_object_hypotheses.graph_id != object_id
):
return False # need more steps to meet symmetry condition
logger.debug(
f"\n\nchecking for symmetry for hp ids {possible_object_hypotheses_ids}"
f" with last ids {self.last_possible_hypotheses}"
)
if increment_evidence:
previous_hyps = set(self.last_possible_hypotheses)
previous_hyps = set(last_possible_object_hypotheses.hypotheses_ids)
current_hyps = set(possible_object_hypotheses_ids)
hypothesis_overlap = previous_hyps.intersection(current_hyps)
if len(hypothesis_overlap) / len(current_hyps) > 0.9:
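The test above declares symmetry when the set of possible hypotheses stays essentially stable across steps (the branch body is collapsed in this view). The core computation in isolation, with threshold handling and evidence counters elided and illustrative names:

    def hypotheses_overlap_ratio(previous_ids, current_ids) -> float:
        previous, current = set(previous_ids), set(current_ids)
        return len(previous & current) / len(current)

    # 9 of the 10 current hypotheses were also possible last step -> 0.9
    assert hypotheses_overlap_ratio(range(10), [*range(9), 42]) == 0.9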
@@ -1106,7 +1160,13 @@ def _threshold_possible_matches(self, x_percent_scale_factor=1.0):
if len(self.graph_memory) == 0:
logger.info("no objects in memory yet.")
return []

graph_ids, graph_evidences = self.get_evidence_for_each_graph()

if len(graph_ids) == 0:
logger.info("All hypothesis spaces are empty. No possible matches.")
return []

# median_ge = np.median(graph_evidences)
mean_ge = np.mean(graph_evidences)
max_ge = np.max(graph_evidences)
@@ -1167,23 +1227,28 @@ def _calculate_most_likely_hypothesis(self, graph_id=None):
"""
mlh = {}
if graph_id is not None:
mlh_id = np.argmax(self.evidence[graph_id])
mlh = self._get_mlh_dict_from_id(graph_id, mlh_id)
graph_evidence = self.hyp_evidences_for_object(graph_id)
if len(graph_evidence):
mlh_id = np.argmax(graph_evidence)
mlh = self._get_mlh_dict_from_id(graph_id, mlh_id)
else:
highest_evidence_so_far = -np.inf
for next_graph_id in self.get_all_known_object_ids():
mlh_id = np.argmax(self.evidence[next_graph_id])
evidence = self.evidence[next_graph_id][mlh_id]
if evidence > highest_evidence_so_far:
mlh = self._get_mlh_dict_from_id(next_graph_id, mlh_id)
highest_evidence_so_far = evidence
if not mlh: # No objects in memory
mlh = self.current_mlh
mlh["graph_id"] = "new_object0"
logger.info(
f"current most likely hypothesis: {mlh['graph_id']} "
f"with evidence {np.round(mlh['evidence'], 2)}"
)
for graph_id in self.get_all_known_object_ids():
graph_evidence = self.hyp_evidences_for_object(graph_id)
if len(graph_evidence):
mlh_id = np.argmax(graph_evidence)
evidence = graph_evidence[mlh_id]
if evidence > highest_evidence_so_far:
mlh = self._get_mlh_dict_from_id(graph_id, mlh_id)
highest_evidence_so_far = evidence

if not mlh: # No objects in memory
mlh = self.current_mlh
mlh["graph_id"] = "new_object0"
logger.info(
f"current most likely hypothesis: {mlh['graph_id']} "
f"with evidence {np.round(mlh['evidence'], 2)}"
)
return mlh

def _get_node_distance_weights(self, distances):
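The same empty-space guard appears in _calculate_most_likely_hypothesis above: np.argmax raises a ValueError on an empty array, so graphs with no remaining hypotheses are skipped when picking the MLH. A toy version with invented evidence values:

    import numpy as np

    evidence = {"mug": np.array([]), "bowl": np.array([0.5, 2.5])}
    best, highest = None, -np.inf
    for graph_id, hyp_evidences in evidence.items():
        if len(hyp_evidences):  # np.argmax would raise on "mug"
            mlh_id = int(np.argmax(hyp_evidences))
            if hyp_evidences[mlh_id] > highest:
                best, highest = (graph_id, mlh_id), hyp_evidences[mlh_id]
    assert best == ("bowl", 1)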