From 72ed7d10c6d1a9cc8d48df3876a68e7419f5658c Mon Sep 17 00:00:00 2001 From: zhx06 Date: Wed, 25 Mar 2026 10:37:20 -0700 Subject: [PATCH 01/12] batch support for solver based on batched bbox --- .../environments/arena_env_builder.py | 45 +++- isaaclab_arena/relations/object_placer.py | 118 +++++++---- .../relations/object_placer_params.py | 2 +- isaaclab_arena/relations/placement_events.py | 92 ++++++++ isaaclab_arena/relations/placement_result.py | 18 ++ .../relations/relation_loss_strategies.py | 196 ++++++++++-------- isaaclab_arena/relations/relation_solver.py | 47 +++-- .../relations/relation_solver_params.py | 2 +- .../relations/relation_solver_state.py | 76 +++++-- isaaclab_arena/relations/relations.py | 13 +- .../tests/test_no_collision_loss.py | 40 ++++ .../test_object_placer_reproducibility.py | 23 ++ .../tests/test_relation_loss_strategies.py | 35 ++++ ...e_multi_object_no_collision_environment.py | 8 +- 14 files changed, 528 insertions(+), 187 deletions(-) create mode 100644 isaaclab_arena/relations/placement_events.py diff --git a/isaaclab_arena/environments/arena_env_builder.py b/isaaclab_arena/environments/arena_env_builder.py index 3d1543142..2a093ff1b 100644 --- a/isaaclab_arena/environments/arena_env_builder.py +++ b/isaaclab_arena/environments/arena_env_builder.py @@ -27,7 +27,9 @@ from isaaclab_arena.metrics.recorder_manager_utils import metrics_to_recorder_manager_cfg from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams -from isaaclab_arena.relations.relations import IsAnchor, NoCollision +from isaaclab_arena.relations.placement_events import make_placement_event_cfg +from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult +from isaaclab_arena.relations.relations import IsAnchor, NoCollision, get_anchor_objects from isaaclab_arena.tasks.no_task import NoTask from isaaclab_arena.utils.configclass import combine_configclass_instances from isaaclab_arena.utils.multiprocess import get_local_rank @@ -101,9 +103,35 @@ def _solve_relations(self) -> None: # Run the ObjectPlacer (default on_relation_z_tolerance_m accommodates solver residual). placement_seed = getattr(self.args, "placement_seed", None) placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=placement_seed)) - result = placer.place(objects=objects_with_relations) - - if result.success: + num_envs = self.args.num_envs + result = placer.place(objects_with_relations, num_envs=num_envs) + if isinstance(result, MultiEnvPlacementResult) and result.results: + positions_all_envs_by_name = [ + {obj.name: result.results[e].positions[obj] for obj in result.results[0].positions} + for e in range(len(result.results)) + ] + object_names = [obj.name for obj in objects_with_relations] + anchor_names = [a.name for a in get_anchor_objects(objects_with_relations)] + self._placement_event_cfg = make_placement_event_cfg( + positions_all_envs_by_name, + object_names, + anchor_names, + ) + else: + self._placement_event_cfg = None + + # Log outcome + results_per_env = result.results if isinstance(result, MultiEnvPlacementResult) else None + if results_per_env is not None: + n_succeeded = sum(1 for r in results_per_env if r.success) + if n_succeeded == num_envs: + print(f"Relation solving succeeded for all {num_envs} env(s) after {result.attempts} attempt(s)") + else: + print( + f"Relation solving: {n_succeeded}/{num_envs} env(s) passed validation after" + f" {result.attempts} attempt(s)." + ) + elif result.success: print(f"Relation solving succeeded after {result.attempts} attempt(s)") else: print(f"Relation solving not completed after {result.attempts} attempt(s)") @@ -152,12 +180,15 @@ def compose_manager_cfg(self) -> IsaacLabArenaManagerBasedRLEnvCfg: embodiment.get_observation_cfg(), task.get_observation_cfg(), ) - events_cfg = combine_configclass_instances( - "EventsCfg", + events_sources = [ embodiment.get_events_cfg(), self.arena_env.scene.get_events_cfg(), task.get_events_cfg(), - ) + ] + placement_event = getattr(self, "_placement_event_cfg", None) + if placement_event is not None: + events_sources.append(placement_event) + events_cfg = combine_configclass_instances("EventsCfg", *events_sources) termination_cfg = combine_configclass_instances( "TerminationCfg", task.get_termination_cfg(), diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 8f9d62928..3c370c603 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams -from isaaclab_arena.relations.placement_result import PlacementResult +from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relations import On, RandomAroundSolution, RotateAroundSolution, get_anchor_objects from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, get_random_pose_within_bounding_box @@ -29,6 +29,8 @@ class ObjectPlacer: 3. Validating the result 4. Retrying if necessary 5. Applying solved positions to objects + + Supports single-env (num_envs=1) and batched (num_envs>1) placement. """ def __init__(self, params: ObjectPlacerParams | None = None): @@ -43,15 +45,18 @@ def __init__(self, params: ObjectPlacerParams | None = None): def place( self, objects: list[Object | ObjectReference], - ) -> PlacementResult: + num_envs: int = 1, + ) -> PlacementResult | MultiEnvPlacementResult: """Place objects according to their spatial relations. Args: objects: List of objects to place. Must include at least one object marked with IsAnchor() which serves as a fixed reference. + num_envs: Number of environments. 1 for single-env; > 1 for batched + placement (one layout per env). Returns: - PlacementResult with success status, positions, loss, and attempt count. + PlacementResult when num_envs is 1; MultiEnvPlacementResult when num_envs > 1. """ # Validate all objects have at least one relation for obj in objects: @@ -76,61 +81,90 @@ def place( anchor_objects_set = set(anchor_objects) - # Save RNG state and set seed if provided (for reproducibility without affecting Isaac Sim) - rng_state = None - if self.params.placement_seed is not None: - rng_state = torch.get_rng_state() - torch.manual_seed(self.params.placement_seed) - # Determine bounds for random position initialization from the first anchor object # TODO(cvolk): The user should not need to know about the bounds to set. # Implement an initialization strategy that infers from the Relations(s). init_bounds = self._get_init_bounds(anchor_objects[0]) - # Placement loop with retries - best_positions: dict[Object | ObjectReference, tuple[float, float, float]] = {} - best_loss = float("inf") - success = False + # Placement loop with retries (per-env tracking) + best_valid_loss_per_env: list[float] = [float("inf")] * num_envs + best_valid_positions_per_env: list[dict | None] = [None] * num_envs + best_any_loss_per_env: list[float] = [float("inf")] * num_envs + best_any_positions_per_env: list[dict] = [dict() for _ in range(num_envs)] for attempt in range(self.params.max_placement_attempts): - # Generate starting positions (anchors from their poses, others random) - initial_positions = self._generate_initial_positions(objects, anchor_objects_set, init_bounds) + # Generate starting positions per env (anchors from their poses, others random) + initial_positions: list[dict] = [] + for env_i in range(num_envs): + rng_state = None + if self.params.placement_seed is not None: + rng_state = torch.get_rng_state() + torch.manual_seed(self.params.placement_seed + env_i + attempt * (num_envs + 1)) + initial_positions.append(self._generate_initial_positions(objects, anchor_objects_set, init_bounds)) + if rng_state is not None: + torch.set_rng_state(rng_state) + + # solve() returns list[dict] when given list[dict] initial_positions + positions_per_env: list[dict] = self._solver.solve(objects, initial_positions) # type: ignore[assignment] # overload returns list[dict] for list input + per_env_loss = ( + self._solver.last_loss_per_env.cpu().tolist() + if self._solver.last_loss_per_env is not None + else [float("inf")] * num_envs + ) - # Solve - positions = self._solver.solve(objects, initial_positions) - loss = self._solver.last_loss_history[-1] if self._solver.last_loss_history else float("inf") + # Check if placement is valid (per env); update best valid and best-by-loss fallback + for e in range(num_envs): + loss_e = per_env_loss[e] if e < len(per_env_loss) else float("inf") + valid = self._validate_placement(positions_per_env[e]) + if valid and loss_e < best_valid_loss_per_env[e]: + best_valid_loss_per_env[e] = loss_e + best_valid_positions_per_env[e] = positions_per_env[e] + if loss_e < best_any_loss_per_env[e]: + best_any_loss_per_env[e] = loss_e + best_any_positions_per_env[e] = positions_per_env[e] if self.params.verbose: - print(f"Attempt {attempt + 1}/{self.params.max_placement_attempts}: loss = {loss:.6f}") - - # Check if placement is valid - if self._validate_placement(positions): - best_loss = loss - best_positions = positions - success = True + mean_loss = sum(per_env_loss) / num_envs + n_succeeded = sum(1 for p in best_valid_positions_per_env if p is not None) + print( + f"Attempt {attempt + 1}/{self.params.max_placement_attempts}:" + f" loss = {mean_loss:.6f}, envs validated = {n_succeeded}/{num_envs}" + ) + + if all(best_valid_positions_per_env): if self.params.verbose: print(f"Success on attempt {attempt + 1}") break - # Track best invalid result as fallback - if loss < best_loss: - best_loss = loss - best_positions = positions + # Per env: use best valid if any, else best-by-loss fallback + final_per_env: list[dict] = [ + best_valid_positions_per_env[e] if best_valid_positions_per_env[e] is not None + else best_any_positions_per_env[e] + for e in range(num_envs) + ] + + results_per_env = [ + PlacementResult( + success=best_valid_positions_per_env[e] is not None, + positions=final_per_env[e], + final_loss=( + best_valid_loss_per_env[e] if best_valid_positions_per_env[e] is not None + else best_any_loss_per_env[e] + ), + attempts=attempt + 1, + ) + for e in range(num_envs) + ] # Apply solved positions to objects - if self.params.apply_positions_to_objects: - self._apply_positions(best_positions, anchor_objects_set) - - # Restore RNG state if we changed it - if rng_state is not None: - torch.set_rng_state(rng_state) - - return PlacementResult( - success=success, - positions=best_positions, - final_loss=best_loss, - attempts=attempt + 1, - ) + # TODO(@zhx06): Consider applying via event for consistency with multi_env. + if num_envs == 1 and self.params.apply_positions_to_objects: + self._apply_positions(final_per_env[0], anchor_objects_set) + + if num_envs == 1: + return results_per_env[0] + # Multi-env: layouts applied at reset via placement event (builder builds event_cfg from result) + return MultiEnvPlacementResult(results=results_per_env) def _get_init_bounds(self, anchor_object: Object | ObjectReference) -> AxisAlignedBoundingBox: """Get bounds for random position initialization. diff --git a/isaaclab_arena/relations/object_placer_params.py b/isaaclab_arena/relations/object_placer_params.py index d5d1ea305..25dca3f65 100644 --- a/isaaclab_arena/relations/object_placer_params.py +++ b/isaaclab_arena/relations/object_placer_params.py @@ -34,7 +34,7 @@ class ObjectPlacerParams: placement_seed: int | None = None """Random seed for reproducible placement. If None, uses current RNG state.""" - min_separation_m: float = 0.0 + min_separation_m: float = 0.005 """Minimum separation (meters) required between object bounding boxes. Set to 0.0 to only reject actual overlaps. A small positive value (e.g. 0.005) adds a safety margin between objects.""" diff --git a/isaaclab_arena/relations/placement_events.py b/isaaclab_arena/relations/placement_events.py new file mode 100644 index 000000000..53c2a81d4 --- /dev/null +++ b/isaaclab_arena/relations/placement_events.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Placement event: apply placement layouts per env on reset (num_envs>1).""" + +from __future__ import annotations + +import dataclasses +import torch + +from isaaclab.envs import ManagerBasedEnv +from isaaclab.managers import EventTermCfg, SceneEntityCfg +from isaaclab.utils import configclass + +from isaaclab_arena.terms.events import set_object_pose_per_env +from isaaclab_arena.utils.pose import Pose + + +def _resolve_env_ids(env: ManagerBasedEnv, env_ids) -> list[int] | None: + """Normalize env_ids from the event manager to a list of env indices.""" + if env_ids is None: + return None + if isinstance(env_ids, slice): + if env_ids == slice(None): + return list(range(env.num_envs)) + start, stop, step = env_ids.indices(env.num_envs) + return list(range(start, stop, step)) + if hasattr(env_ids, "tolist"): + return env_ids.tolist() + return list(env_ids) + + +@configclass +class PlacementEventsCfg: + """Event config for applying placement layouts per env on reset.""" + + set_object_pose_per_env_from_layouts: EventTermCfg = dataclasses.MISSING # type: ignore[assignment] + + +def make_placement_event_cfg( + positions_all_envs_by_name: list[dict[str, tuple[float, float, float]]], + object_names: list[str], + anchor_names: list[str] | None = None, +) -> PlacementEventsCfg: + """Build event config for applying placement layouts per env on reset.""" + params: dict = { + "positions_all_envs_by_name": positions_all_envs_by_name, + "object_names": object_names, + "anchor_names": anchor_names or [], + } + return PlacementEventsCfg( + set_object_pose_per_env_from_layouts=EventTermCfg( + func=set_object_pose_per_env_from_layouts, + mode="reset", + params=params, + ) + ) + + +def set_object_pose_per_env_from_layouts( + env: ManagerBasedEnv, + env_ids, + positions_all_envs_by_name: list[dict[str, tuple[float, float, float]]], + object_names: list[str], + anchor_names: list[str] | None = None, +) -> None: + """Set each object's root pose per env from layout dicts; anchors first.""" + resolved = _resolve_env_ids(env, env_ids) + if not resolved: + return + env_ids_t = torch.tensor(resolved, device=env.device) + anchor_set = set(anchor_names or []) + ordered_names = [n for n in object_names if n in anchor_set] + ordered_names += [n for n in object_names if n not in anchor_set] + identity_quat_wxyz = (1.0, 0.0, 0.0, 0.0) + for name in ordered_names: + if name not in env.scene.keys(): + continue + asset = env.scene[name] + if not hasattr(asset, "write_root_pose_to_sim"): + continue + pose_list = [] + for e in range(len(positions_all_envs_by_name)): + xyz = positions_all_envs_by_name[e].get(name) + if xyz is not None: + x, y, z = xyz + pose_list.append(Pose(position_xyz=(x, y, z), rotation_wxyz=identity_quat_wxyz)) + else: + pose_list.append(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_wxyz=identity_quat_wxyz)) + set_object_pose_per_env(env, env_ids_t, SceneEntityCfg(name), pose_list) diff --git a/isaaclab_arena/relations/placement_result.py b/isaaclab_arena/relations/placement_result.py index e74e266f7..ecd47e8fa 100644 --- a/isaaclab_arena/relations/placement_result.py +++ b/isaaclab_arena/relations/placement_result.py @@ -27,3 +27,21 @@ class PlacementResult: attempts: int """Number of attempts made.""" + + +@dataclass +class MultiEnvPlacementResult: + """Result of an ObjectPlacer.place() call for multiple environments.""" + + results: list[PlacementResult] + """One PlacementResult per environment (same length as num_envs).""" + + @property + def success(self) -> bool: + """True if every environment's placement succeeded.""" + return all(r.success for r in self.results) + + @property + def attempts(self) -> int: + """Number of attempts (same for all envs in the batched run).""" + return self.results[0].attempts if self.results else 0 diff --git a/isaaclab_arena/relations/relation_loss_strategies.py b/isaaclab_arena/relations/relation_loss_strategies.py index 079162c8f..1be500d39 100644 --- a/isaaclab_arena/relations/relation_loss_strategies.py +++ b/isaaclab_arena/relations/relation_loss_strategies.py @@ -79,11 +79,12 @@ def compute_loss( Args: relation: The relation object containing constraint metadata. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box (extents relative to origin). + child_pos: Child object position tensor. Accepts (3,) for single-env + backward compat or (N, 3) for batched. + child_bbox: Child object local bounding box (N=1). Returns: - Scalar loss tensor representing the constraint violation. + Scalar loss tensor when child_pos is (3,), or (N,) tensor when (N, 3). """ pass @@ -103,12 +104,13 @@ def compute_loss( Args: relation: The relation object containing relationship metadata. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box (extents relative to origin). + child_pos: Child object position tensor. Accepts (3,) for single-env + backward compat or (N, 3) for batched. + child_bbox: Child object local bounding box (N=1). parent_world_bbox: Parent bounding box in world coordinates. Returns: - Scalar loss tensor representing the constraint violation. + Scalar loss tensor when child_pos is (3,), or (N,) tensor when (N, 3). """ pass @@ -145,45 +147,49 @@ def compute_loss( Args: relation: NextTo relation with side and distance attributes. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box. + child_pos: Child object position (N, 3) in world coords. + child_bbox: Child object local bounding box (N=1). parent_world_bbox: Parent bounding box in world coordinates. Returns: - Weighted loss tensor. + Weighted loss tensor of shape (N,). """ + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + cfg = SIDE_CONFIGS[relation.side] distance = relation.distance_m assert distance >= 0.0, f"NextTo distance must be non-negative, got {distance}" # Parent world extents from the world bounding box if cfg.direction == Direction.POSITIVE: - parent_edge = parent_world_bbox.max_point[0, cfg.primary_axis] - child_offset = child_bbox.min_point[0, cfg.primary_axis] + parent_edge = parent_world_bbox.max_point[:, cfg.primary_axis] + child_offset = child_bbox.min_point[:, cfg.primary_axis] penalty_side = "less" else: - parent_edge = parent_world_bbox.min_point[0, cfg.primary_axis] - child_offset = child_bbox.max_point[0, cfg.primary_axis] + parent_edge = parent_world_bbox.min_point[:, cfg.primary_axis] + child_offset = child_bbox.max_point[:, cfg.primary_axis] penalty_side = "greater" # 1. Half-plane loss: child must be on correct side of parent edge half_plane_loss = single_boundary_linear_loss( - child_pos[cfg.primary_axis], + child_pos[:, cfg.primary_axis], parent_edge, slope=self.slope, penalty_side=penalty_side, ) # 2. Band position loss: child placed at target position within parent's perpendicular extent - parent_band_min = parent_world_bbox.min_point[0, cfg.band_axis] - parent_band_max = parent_world_bbox.max_point[0, cfg.band_axis] - valid_band_min = parent_band_min - child_bbox.min_point[0, cfg.band_axis] - valid_band_max = parent_band_max - child_bbox.max_point[0, cfg.band_axis] + parent_band_min = parent_world_bbox.min_point[:, cfg.band_axis] + parent_band_max = parent_world_bbox.max_point[:, cfg.band_axis] + valid_band_min = parent_band_min - child_bbox.min_point[:, cfg.band_axis] + valid_band_max = parent_band_max - child_bbox.max_point[:, cfg.band_axis] # Convert cross_position_ratio [-1, 1] to interpolation factor [0, 1]: -1 = min, 0 = center, 1 = max t = (relation.cross_position_ratio + 1.0) / 2.0 target_band_pos = valid_band_min + t * (valid_band_max - valid_band_min) band_loss = single_point_linear_loss( - child_pos[cfg.band_axis], + child_pos[:, cfg.band_axis], target_band_pos, slope=self.slope, ) @@ -192,31 +198,32 @@ def compute_loss( # For direction +1: target = parent_max + distance - child_min # For direction -1: target = parent_min - distance - child_max target_pos = parent_edge + cfg.direction * distance - child_offset - distance_loss = single_point_linear_loss(child_pos[cfg.primary_axis], target_pos, slope=self.slope) + distance_loss = single_point_linear_loss(child_pos[:, cfg.primary_axis], target_pos, slope=self.slope) - if self.debug: + if self.debug and child_pos.shape[0] == 1: axis_name = cfg.primary_axis.name band_axis_name = cfg.band_axis.name print( f" [NextTo] {relation.side.value}: child_{axis_name.lower()}=" - f"{child_pos[cfg.primary_axis].item():.4f}, parent_edge={parent_edge.item():.4f}," - f" loss={half_plane_loss.item():.6f}" + f"{child_pos[0, cfg.primary_axis].item():.4f}, parent_edge={parent_edge[0].item():.4f}," + f" loss={half_plane_loss[0].item():.6f}" ) print( f" [NextTo] {band_axis_name} band: child_{band_axis_name.lower()}=" - f"{child_pos[cfg.band_axis].item():.4f}, target={target_band_pos.item():.4f}" + f"{child_pos[0, cfg.band_axis].item():.4f}, target={target_band_pos[0].item():.4f}" f" (cross_position_ratio={relation.cross_position_ratio:.2f}," - f" range=[{valid_band_min.item():.4f}, {valid_band_max.item():.4f}])," - f" loss={band_loss.item():.6f}" + f" range=[{valid_band_min[0].item():.4f}, {valid_band_max[0].item():.4f}])," + f" loss={band_loss[0].item():.6f}" ) print( f" [NextTo] Distance: child_{axis_name.lower()}=" - f"{child_pos[cfg.primary_axis].item():.4f}, target={target_pos.item():.4f}," - f" loss={distance_loss.item():.6f}" + f"{child_pos[0, cfg.primary_axis].item():.4f}, target={target_pos[0].item():.4f}," + f" loss={distance_loss[0].item():.6f}" ) total_loss = half_plane_loss + band_loss + distance_loss - return relation.relation_loss_weight * total_loss + result = relation.relation_loss_weight * total_loss + return result.squeeze(0) if single_input else result class OnLossStrategy(RelationLossStrategy): @@ -249,31 +256,33 @@ def compute_loss( Args: relation: On relation with clearance_m attribute. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box. + child_pos: Child object position (N, 3) in world coords. + child_bbox: Child object local bounding box (N=1). parent_world_bbox: Parent bounding box in world coordinates. Returns: - Weighted loss tensor. + Weighted loss tensor of shape (N,). """ + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + # Parent world-space extents from the world bounding box - parent_x_min = parent_world_bbox.min_point[0, 0] - parent_x_max = parent_world_bbox.max_point[0, 0] - parent_y_min = parent_world_bbox.min_point[0, 1] - parent_y_max = parent_world_bbox.max_point[0, 1] - parent_z_max = parent_world_bbox.max_point[0, 2] # Top surface + parent_x_min = parent_world_bbox.min_point[:, 0] + parent_x_max = parent_world_bbox.max_point[:, 0] + parent_y_min = parent_world_bbox.min_point[:, 1] + parent_y_max = parent_world_bbox.max_point[:, 1] + parent_z_max = parent_world_bbox.max_point[:, 2] # Top surface # Compute valid position ranges such that child's entire footprint is within parent - # Child left edge = child_pos[0] + child_bbox.min_point[0, 0], must be >= parent_x_min - # Child right edge = child_pos[0] + child_bbox.max_point[0, 0], must be <= parent_x_max - valid_x_min = parent_x_min - child_bbox.min_point[0, 0] # child's left at parent's left - valid_x_max = parent_x_max - child_bbox.max_point[0, 0] # child's right at parent's right - valid_y_min = parent_y_min - child_bbox.min_point[0, 1] - valid_y_max = parent_y_max - child_bbox.max_point[0, 1] + valid_x_min = parent_x_min - child_bbox.min_point[:, 0] # child's left at parent's left + valid_x_max = parent_x_max - child_bbox.max_point[:, 0] # child's right at parent's right + valid_y_min = parent_y_min - child_bbox.min_point[:, 1] + valid_y_max = parent_y_max - child_bbox.max_point[:, 1] # 1. X band loss: child's footprint entirely within parent's X extent x_band_loss = linear_band_loss( - child_pos[0], + child_pos[:, 0], lower_bound=valid_x_min, upper_bound=valid_x_max, slope=self.slope, @@ -281,32 +290,33 @@ def compute_loss( # 2. Y band loss: child's footprint entirely within parent's Y extent y_band_loss = linear_band_loss( - child_pos[1], + child_pos[:, 1], lower_bound=valid_y_min, upper_bound=valid_y_max, slope=self.slope, ) # 3. Z point loss: child bottom = parent top + clearance - target_z = parent_z_max + relation.clearance_m - child_bbox.min_point[0, 2] - z_loss = single_point_linear_loss(child_pos[2], target_z, slope=self.slope) + target_z = parent_z_max + relation.clearance_m - child_bbox.min_point[:, 2] + z_loss = single_point_linear_loss(child_pos[:, 2], target_z, slope=self.slope) - if self.debug: + if self.debug and child_pos.shape[0] == 1: print( - f" [On] X: child_pos={child_pos[0].item():.4f}, valid_range=[{valid_x_min.item():.4f}," - f" {valid_x_max.item():.4f}], loss={x_band_loss.item():.6f}" + f" [On] X: child_pos={child_pos[0, 0].item():.4f}, valid_range=[{valid_x_min[0].item():.4f}," + f" {valid_x_max[0].item():.4f}], loss={x_band_loss[0].item():.6f}" ) print( - f" [On] Y: child_pos={child_pos[1].item():.4f}, valid_range=[{valid_y_min.item():.4f}," - f" {valid_y_max.item():.4f}], loss={y_band_loss.item():.6f}" + f" [On] Y: child_pos={child_pos[0, 1].item():.4f}, valid_range=[{valid_y_min[0].item():.4f}," + f" {valid_y_max[0].item():.4f}], loss={y_band_loss[0].item():.6f}" ) print( - f" [On] Z: child_pos={child_pos[2].item():.4f}, target={target_z.item():.4f}," - f" loss={z_loss.item():.6f}" + f" [On] Z: child_pos={child_pos[0, 2].item():.4f}, target={target_z[0].item():.4f}," + f" loss={z_loss[0].item():.6f}" ) total_loss = x_band_loss + y_band_loss + z_loss - return relation.relation_loss_weight * total_loss + result = relation.relation_loss_weight * total_loss + return result.squeeze(0) if single_input else result class NoCollisionLossStrategy(RelationLossStrategy): @@ -340,54 +350,59 @@ def compute_loss( Args: relation: NoCollision relation with relation_loss_weight. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box. + child_pos: Child object position (N, 3) in world coords. + child_bbox: Child object local bounding box (N=1). parent_world_bbox: Parent bounding box in world coordinates. Returns: - Weighted loss tensor. + Weighted loss tensor of shape (N,). """ + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + # Parent world extents from the world bounding box, expanded by clearance_m c = relation.clearance_m - parent_x_min = parent_world_bbox.min_point[0, 0] - c - parent_x_max = parent_world_bbox.max_point[0, 0] + c - parent_y_min = parent_world_bbox.min_point[0, 1] - c - parent_y_max = parent_world_bbox.max_point[0, 1] + c - parent_z_min = parent_world_bbox.min_point[0, 2] - c - parent_z_max = parent_world_bbox.max_point[0, 2] + c + parent_x_min = parent_world_bbox.min_point[:, 0] - c + parent_x_max = parent_world_bbox.max_point[:, 0] + c + parent_y_min = parent_world_bbox.min_point[:, 1] - c + parent_y_max = parent_world_bbox.max_point[:, 1] + c + parent_z_min = parent_world_bbox.min_point[:, 2] - c + parent_z_max = parent_world_bbox.max_point[:, 2] + c # Child world extents - child_world_min = child_pos + child_bbox.min_point[0] - child_world_max = child_pos + child_bbox.max_point[0] + child_world_min = child_pos + child_bbox.min_point + child_world_max = child_pos + child_bbox.max_point # 1. Per-axis overlap: zero when separated; else overlap length (default slope 1.0 gives length in m) - overlap_x = interval_overlap_axis_loss(child_world_min[0], child_world_max[0], parent_x_min, parent_x_max) - overlap_y = interval_overlap_axis_loss(child_world_min[1], child_world_max[1], parent_y_min, parent_y_max) - overlap_z = interval_overlap_axis_loss(child_world_min[2], child_world_max[2], parent_z_min, parent_z_max) + overlap_x = interval_overlap_axis_loss(child_world_min[:, 0], child_world_max[:, 0], parent_x_min, parent_x_max) + overlap_y = interval_overlap_axis_loss(child_world_min[:, 1], child_world_max[:, 1], parent_y_min, parent_y_max) + overlap_z = interval_overlap_axis_loss(child_world_min[:, 2], child_world_max[:, 2], parent_z_min, parent_z_max) # 2. Volume loss: slope * product of per-axis overlap lengths (overlap volume when slope 1.0) overlap_volume = overlap_x * overlap_y * overlap_z total_loss = self.slope * overlap_volume - if self.debug: + if self.debug and child_pos.shape[0] == 1: print( - f" [NoCollision] X: overlap={overlap_x.item():.6f} (child_x=[{child_world_min[0].item():.4f}," - f" {child_world_max[0].item():.4f}], parent_x=[{parent_x_min.item():.4f}," - f" {parent_x_max.item():.4f}])" + f" [NoCollision] X: overlap={overlap_x[0].item():.6f} (child_x=[{child_world_min[0, 0].item():.4f}," + f" {child_world_max[0, 0].item():.4f}], parent_x=[{parent_x_min[0].item():.4f}," + f" {parent_x_max[0].item():.4f}])" ) print( - f" [NoCollision] Y: overlap={overlap_y.item():.6f} (child_y=[{child_world_min[1].item():.4f}," - f" {child_world_max[1].item():.4f}], parent_y=[{parent_y_min.item():.4f}," - f" {parent_y_max.item():.4f}])" + f" [NoCollision] Y: overlap={overlap_y[0].item():.6f} (child_y=[{child_world_min[0, 1].item():.4f}," + f" {child_world_max[0, 1].item():.4f}], parent_y=[{parent_y_min[0].item():.4f}," + f" {parent_y_max[0].item():.4f}])" ) print( - f" [NoCollision] Z: overlap={overlap_z.item():.6f} (child_z=[{child_world_min[2].item():.4f}," - f" {child_world_max[2].item():.4f}], parent_z=[{parent_z_min.item():.4f}," - f" {parent_z_max.item():.4f}])" + f" [NoCollision] Z: overlap={overlap_z[0].item():.6f} (child_z=[{child_world_min[0, 2].item():.4f}," + f" {child_world_max[0, 2].item():.4f}], parent_z=[{parent_z_min[0].item():.4f}," + f" {parent_z_max[0].item():.4f}])" ) - print(f" [NoCollision] volume={overlap_volume.item():.6f}, loss={total_loss.item():.6f}") + print(f" [NoCollision] volume={overlap_volume[0].item():.6f}, loss={total_loss[0].item():.6f}") - return relation.relation_loss_weight * total_loss + result = relation.relation_loss_weight * total_loss + return result.squeeze(0) if single_input else result class AtPositionLossStrategy(UnaryRelationLossStrategy): @@ -415,27 +430,32 @@ def compute_loss( Args: relation: AtPosition relation with x, y, z target coordinates. - child_pos: Child object position tensor (x, y, z) in world coords. + child_pos: Child object position (N, 3) in world coords. child_bbox: Child object local bounding box (unused, for signature consistency). Returns: - Weighted loss tensor. + Weighted loss tensor of shape (N,). """ - total_loss = torch.tensor(0.0, dtype=child_pos.dtype, device=child_pos.device) + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + + total_loss = torch.zeros(child_pos.shape[0], dtype=child_pos.dtype, device=child_pos.device) # X position constraint if relation.x is not None: - x_loss = single_point_linear_loss(child_pos[0], relation.x, slope=self.slope) + x_loss = single_point_linear_loss(child_pos[:, 0], relation.x, slope=self.slope) total_loss = total_loss + x_loss # Y position constraint if relation.y is not None: - y_loss = single_point_linear_loss(child_pos[1], relation.y, slope=self.slope) + y_loss = single_point_linear_loss(child_pos[:, 1], relation.y, slope=self.slope) total_loss = total_loss + y_loss # Z position constraint if relation.z is not None: - z_loss = single_point_linear_loss(child_pos[2], relation.z, slope=self.slope) + z_loss = single_point_linear_loss(child_pos[:, 2], relation.z, slope=self.slope) total_loss = total_loss + z_loss - return relation.relation_loss_weight * total_loss + result = relation.relation_loss_weight * total_loss + return result.squeeze(0) if single_input else result diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index be96be7ac..ed62a10f4 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -39,6 +39,7 @@ def __init__( self.params = params or RelationSolverParams() self._last_loss_history: list[float] = [] self._last_position_history: list = [] + self._last_loss_per_env: torch.Tensor | None = None def _get_strategy(self, relation: RelationBase) -> RelationLossStrategy | UnaryRelationLossStrategy: """Look up the appropriate strategy for a relation type. @@ -68,9 +69,11 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - debug: If True, print detailed loss breakdown. Returns: - Total loss tensor. + Scalar loss tensor (mean over envs). Per-env loss stored in _last_loss_per_env. """ - total_loss = torch.tensor(0.0) + N = state.num_envs + device = state.optimizable_positions.device if state.optimizable_positions is not None else None + total_loss = torch.zeros(N, device=device, dtype=torch.float32) # Compute loss from all spatial relations using strategies for obj in state.optimizable_objects: @@ -86,7 +89,7 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - child_bbox=obj.get_bounding_box(), ) if debug: - _print_unary_relation_debug(obj, relation, child_pos, loss) + _print_unary_relation_debug(obj, relation, child_pos[0], loss.sum()) # Handle binary relations (with parent) like On, NextTo elif isinstance(relation, Relation): # Build parent world bbox: anchors have a known fixed pose, @@ -96,9 +99,7 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - parent_world_bbox = parent.get_world_bounding_box() else: parent_pos = state.get_position(parent) - parent_world_bbox = parent.get_bounding_box().translated( - (float(parent_pos[0]), float(parent_pos[1]), float(parent_pos[2])) - ) + parent_world_bbox = parent.get_bounding_box().translated(parent_pos) loss = strategy.compute_loss( relation=relation, child_pos=child_pos, @@ -107,29 +108,38 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - ) if debug: parent_pos = state.get_position(parent) - _print_relation_debug(obj, relation, child_pos, parent_pos, loss) + _print_relation_debug(obj, relation, child_pos[0], parent_pos[0], loss.sum()) else: raise ValueError(f"Unknown relation type: {type(relation).__name__}") total_loss = total_loss + loss - return total_loss + self._last_loss_per_env = total_loss.detach().clone() + return total_loss.mean() def solve( self, objects: list[Object | ObjectReference], - initial_positions: dict[Object | ObjectReference, tuple[float, float, float]], - ) -> dict[Object | ObjectReference, tuple[float, float, float]]: + initial_positions: ( + dict[Object | ObjectReference, tuple[float, float, float]] + | list[dict[Object | ObjectReference, tuple[float, float, float]]] + ), + ) -> ( + dict[Object | ObjectReference, tuple[float, float, float]] + | list[dict[Object | ObjectReference, tuple[float, float, float]]] + ): """Solve for optimal positions of all objects. Args: objects: List of Object or ObjectReference instances. Must include at least one object marked with IsAnchor() which serves as a fixed reference. - initial_positions: Starting positions for all objects (including anchors). + initial_positions: A single dict (backward compat, single-env) or a list + of dicts (one per env for batched). Returns: - Dictionary mapping object instances to final (x, y, z) positions. + Single dict when input is a dict, or list of dicts when input is a list. """ + single_input = isinstance(initial_positions, dict) state = RelationSolverState(objects, initial_positions) if self.params.verbose: @@ -145,7 +155,8 @@ def solve( print("No optimizable objects, skipping solver.") self._last_loss_history = [0.0] self._last_position_history = [state.get_all_positions_snapshot()] - return state.get_final_positions_dict() + final = state.get_final_positions() + return final[0] if single_input else final # Setup optimizer (only for optimizable positions) optimizer = torch.optim.Adam([state.optimizable_positions], lr=self.params.lr) @@ -188,13 +199,19 @@ def solve( self._last_loss_history = loss_history self._last_position_history = position_history - return state.get_final_positions_dict() + final = state.get_final_positions() + return final[0] if single_input else final @property def last_loss_history(self) -> list[float]: """Loss values from the most recent solve() call.""" return self._last_loss_history + @property + def last_loss_per_env(self) -> torch.Tensor | None: + """Per-env loss (N,) from the last solve() call.""" + return self._last_loss_per_env + @property def last_position_history(self) -> list: """Position snapshots from the most recent solve() call.""" @@ -220,7 +237,7 @@ def debug_losses(self, objects: list[Object | ObjectReference]) -> None: # Build positions dict from final position history final_positions = {obj: (pos[0], pos[1], pos[2]) for obj, pos in zip(objects, final_positions_list)} - state = RelationSolverState(objects, final_positions) + state = RelationSolverState(objects, [final_positions]) self._compute_total_loss(state, debug=True) print("\n" + "=" * 60) diff --git a/isaaclab_arena/relations/relation_solver_params.py b/isaaclab_arena/relations/relation_solver_params.py index bd78408f1..c6f14b3f8 100644 --- a/isaaclab_arena/relations/relation_solver_params.py +++ b/isaaclab_arena/relations/relation_solver_params.py @@ -21,7 +21,7 @@ def _default_strategies() -> dict[type[RelationBase], RelationLossStrategy | Una return { NextTo: NextToLossStrategy(slope=10.0), On: OnLossStrategy(slope=100.0), - NoCollision: NoCollisionLossStrategy(slope=100.0), + NoCollision: NoCollisionLossStrategy(slope=500.0), AtPosition: AtPositionLossStrategy(slope=100.0), } diff --git a/isaaclab_arena/relations/relation_solver_state.py b/isaaclab_arena/relations/relation_solver_state.py index 75c00169d..b45c0f33d 100644 --- a/isaaclab_arena/relations/relation_solver_state.py +++ b/isaaclab_arena/relations/relation_solver_state.py @@ -21,20 +21,30 @@ class RelationSolverState: This class manages the mapping between objects and their positions, keeping anchor (fixed) and optimizable positions separate internally while providing an interface for position lookups. + + Positions are always stored as (N, num_objects, 3) where N = num_envs + (N=1 for single-env). """ def __init__( self, objects: list[Object | ObjectReference], - initial_positions: dict[Object | ObjectReference, tuple[float, float, float]], + initial_positions: ( + dict[Object | ObjectReference, tuple[float, float, float]] + | list[dict[Object | ObjectReference, tuple[float, float, float]]] + ), ): """Initialize optimization state. Args: objects: List of all Object or ObjectReference instances to track. Must include at least one object marked with IsAnchor() which serves as a fixed reference. - initial_positions: Starting positions for all objects (including anchors). + initial_positions: A single dict (backward compat, treated as single-env) or a list + of dicts (one per env). Length 1 = single-env, length > 1 = batched. """ + if isinstance(initial_positions, dict): + initial_positions = [initial_positions] + assert len(initial_positions) >= 1, "initial_positions must contain at least one dict." anchor_objects = get_anchor_objects(objects) assert len(anchor_objects) > 0, "No anchor object found in objects list." @@ -45,27 +55,44 @@ def __init__( # Build object-to-index mapping self._obj_to_idx: dict[Object | ObjectReference, int] = {obj: i for i, obj in enumerate(objects)} - # Extract positions from the provided dict - positions = [] - for obj in objects: - assert obj in initial_positions, f"Missing initial position for {obj.name}" - positions.append(torch.tensor(initial_positions[obj], dtype=torch.float32)) + # Extract positions from each env's dict + self._num_envs = len(initial_positions) + positions_per_env = [] + for d in initial_positions: + positions = [] + for obj in objects: + assert obj in d, f"Missing initial position for {obj.name}" + positions.append(torch.tensor(d[obj], dtype=torch.float32)) + positions_per_env.append(positions) # Separate anchor positions from optimizable positions self._anchor_indices: set[int] = {self._obj_to_idx[obj] for obj in self._anchor_objects} - self._anchor_positions: dict[int, torch.Tensor] = {idx: positions[idx].clone() for idx in self._anchor_indices} + # Anchors are fixed (same position in all envs), so env 0 is representative. + self._anchor_positions: dict[int, torch.Tensor] = { + idx: positions_per_env[0][idx].clone() for idx in self._anchor_indices + } # Build optimizable positions tensor (excludes all anchors) + # Always stored as (N, num_opt, 3) where N = num_envs self._optimizable_indices = [i for i in range(len(objects)) if i not in self._anchor_indices] if self._optimizable_indices: - self._optimizable_positions = torch.stack([positions[i] for i in self._optimizable_indices]) + opt_tensors = [ + torch.stack([positions_per_env[e][i] for e in range(self._num_envs)]) + for i in self._optimizable_indices + ] + self._optimizable_positions = torch.stack(opt_tensors, dim=1) # (N, num_opt, 3) self._optimizable_positions.requires_grad = True else: self._optimizable_positions = None + @property + def num_envs(self) -> int: + """Number of environments (leading dimension N).""" + return self._num_envs + @property def optimizable_positions(self) -> torch.Tensor | None: - """Tensor of optimizable positions (shape: [N-num_anchors, 3]), or None if all objects are anchors. + """Tensor of optimizable positions (N, num_opt, 3), or None if all objects are anchors. This is the tensor that should be passed to the optimizer. """ @@ -88,7 +115,7 @@ def get_position(self, obj: Object | ObjectReference) -> torch.Tensor: obj: The object to get position for. Returns: - Position tensor (x, y, z). + Position tensor of shape (N, 3). Raises: KeyError: If object is not tracked by this state. @@ -96,28 +123,31 @@ def get_position(self, obj: Object | ObjectReference) -> torch.Tensor: """ idx = self._obj_to_idx[obj] if idx in self._anchor_indices: - return self._anchor_positions[idx] + return self._anchor_positions[idx].unsqueeze(0).expand(self._num_envs, 3) if self._optimizable_positions is None: raise RuntimeError(f"No optimizable positions available for object '{obj.name}'") opt_idx = self._optimizable_indices.index(idx) - return self._optimizable_positions[opt_idx] + return self._optimizable_positions[:, opt_idx, :] def get_all_positions_snapshot(self) -> list[tuple[float, float, float]]: """Get detached copy of all positions for history tracking. Returns: - List of (x, y, z) positions for each object (in original order). + List of (x, y, z) positions for each object (in original order). Uses env 0. """ - return [tuple(self.get_position(obj).detach().tolist()) for obj in self._all_objects] + return [tuple(self.get_position(obj)[0].detach().tolist()) for obj in self._all_objects] - def get_final_positions_dict(self) -> dict[Object | ObjectReference, tuple[float, float, float]]: - """Get final positions as a dictionary mapping objects to positions. + def get_final_positions(self) -> list[dict[Object | ObjectReference, tuple[float, float, float]]]: + """Get final positions as a list of dicts, one per env. Returns: - Dictionary with object instances as keys and (x, y, z) tuples as values. + List of dictionaries with object instances as keys and (x, y, z) tuples as values. """ - result: dict[Object | ObjectReference, tuple[float, float, float]] = {} - for obj in self._all_objects: - pos = self.get_position(obj).detach().tolist() - result[obj] = (pos[0], pos[1], pos[2]) - return result + out = [] + for e in range(self._num_envs): + d: dict[Object | ObjectReference, tuple[float, float, float]] = {} + for obj in self._all_objects: + pos = self.get_position(obj)[e].detach().tolist() + d[obj] = (pos[0], pos[1], pos[2]) + out.append(d) + return out diff --git a/isaaclab_arena/relations/relations.py b/isaaclab_arena/relations/relations.py index f64167881..4275765ec 100644 --- a/isaaclab_arena/relations/relations.py +++ b/isaaclab_arena/relations/relations.py @@ -126,23 +126,24 @@ class NoCollision(Relation): Note: Loss computation is handled by NoCollisionLossStrategy in relation_loss_strategies.py. - NOTE: If both A.add_relation(NoCollision(B)) and B.add_relation(NoCollision(A)) are present, - the solver will compute the loss twice and the relation graph becomes cyclic, which can cause - issues during environment creation. Deduplication or cycle detection should be addressed at a - higher level. + NOTE: RelationSolver._compute_total_loss iterates every relation on every object with no + deduplication. If both A.add_relation(NoCollision(B)) and B.add_relation(NoCollision(A)) + are present, loss is computed twice. Bidirectional NoCollision can also make the relation + graph cyclic and cause issues when creating the environment. Deduplication and/or + higher-level handling of symmetric relations to be addressed in a future commit. """ def __init__( self, parent: Object | ObjectReference, relation_loss_weight: float = 1.0, - clearance_m: float = 0.01, + clearance_m: float = 0.02, ): """ Args: parent: The other object that this object must not collide with. relation_loss_weight: Weight for the relationship loss function. - clearance_m: Minimum clearance between bounding boxes in meters (default: 1cm). + clearance_m: Minimum clearance between bounding boxes in meters (default: 2cm). """ super().__init__(parent, relation_loss_weight) assert clearance_m >= 0.0, f"clearance_m must be non-negative, got {clearance_m}" diff --git a/isaaclab_arena/tests/test_no_collision_loss.py b/isaaclab_arena/tests/test_no_collision_loss.py index 77768d768..944d9759f 100644 --- a/isaaclab_arena/tests/test_no_collision_loss.py +++ b/isaaclab_arena/tests/test_no_collision_loss.py @@ -217,3 +217,43 @@ def test_relation_solver_no_collision_same_inputs_reproducible(): assert result1[box_a1] == result2[box_a2], "box_a positions should match" assert result1[box_b1] == result2[box_b2], "box_b positions should match" + + +def test_no_collision_loss_multi_env_shape_and_values(): + """Test that NoCollision with batched (N,3) input returns (N,) loss with correct per-env values.""" + box_a = _create_box("box_a") + box_b = _create_box("box_b") + relation = NoCollision(box_b, clearance_m=0.0) + strategy = NoCollisionLossStrategy(slope=10.0) + + child_pos = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.1, 0.0]]) + parent_world_bbox = AxisAlignedBoundingBox( + min_point=torch.tensor([[1.0, 0.0, 0.0], [0.05, 0.05, 0.0]]), + max_point=torch.tensor([[1.2, 0.2, 0.2], [0.25, 0.25, 0.2]]), + ) + + loss = strategy.compute_loss(relation, child_pos, box_a.bounding_box, parent_world_bbox) + assert loss.shape == (2,) + assert torch.isclose(loss[0], torch.tensor(0.0), atol=1e-5) + assert loss[1] > 0.0 + + +def test_relation_solver_multi_env_returns_list_of_dicts(): + """Test that solver returns list[dict] when given list[dict] input.""" + table, box_a, box_b = _create_no_collision_scene() + objects = [table, box_a, box_b] + initial_positions = [ + {table: (0.0, 0.0, 0.0), box_a: (0.2, 0.2, 0.11), box_b: (0.25, 0.25, 0.11)}, + {table: (0.0, 0.0, 0.0), box_a: (0.3, 0.3, 0.11), box_b: (0.6, 0.6, 0.11)}, + ] + + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + solver = RelationSolver(params=solver_params) + result = solver.solve(objects=objects, initial_positions=initial_positions) + + assert isinstance(result, list) + assert len(result) == 2 + for d in result: + assert isinstance(d, dict) + assert box_a in d + assert box_b in d diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 81daa8bf6..3331d43de 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -128,3 +128,26 @@ def test_object_placer_different_seeds_produce_different_results(): break assert any_different, "Different seeds should produce different results" + + +def test_relation_solver_multi_env_batched_positions(): + """Test that solver with list[dict] input returns list[dict] output.""" + solver_params = RelationSolverParams(max_iters=50) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + + initial_positions = [ + {desk: (0.0, 0.0, 0.0), box1: (0.2, 0.2, 0.11), box2: (0.5, 0.5, 0.11)}, + {desk: (0.0, 0.0, 0.0), box1: (0.3, 0.3, 0.11), box2: (0.6, 0.6, 0.11)}, + ] + + solver = RelationSolver(params=solver_params) + result = solver.solve(objects=objects, initial_positions=initial_positions) + + assert isinstance(result, list) + assert len(result) == 2 + for d in result: + assert isinstance(d, dict) + for obj in objects: + assert obj in d + assert len(d[obj]) == 3 diff --git a/isaaclab_arena/tests/test_relation_loss_strategies.py b/isaaclab_arena/tests/test_relation_loss_strategies.py index d7221d0ea..f2463c167 100644 --- a/isaaclab_arena/tests/test_relation_loss_strategies.py +++ b/isaaclab_arena/tests/test_relation_loss_strategies.py @@ -229,3 +229,38 @@ def test_next_to_zero_distance_raises(): with pytest.raises(AssertionError, match="Distance must be positive"): NextTo(parent_obj, side=Side.POSITIVE_X, distance_m=0.0) + + +def test_on_loss_strategy_multi_env_shape_and_values(): + """Test that On with batched (N,3) input returns (N,) loss with correct per-env values.""" + table = _create_table() + box = _create_box() + relation = On(table, clearance_m=0.01) + strategy = OnLossStrategy(slope=10.0) + + child_pos = torch.tensor([[0.4, 0.4, 0.11], [0.4, 0.4, 0.5]]) + parent_world_bbox = AxisAlignedBoundingBox( + min_point=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), + max_point=torch.tensor([[1.0, 1.0, 0.1], [1.0, 1.0, 0.1]]), + ) + + loss = strategy.compute_loss(relation, child_pos, box.bounding_box, parent_world_bbox) + assert loss.shape == (2,) + assert torch.isclose(loss[0], torch.tensor(0.0), atol=1e-4) + assert loss[1] > 0.0 + + +def test_next_to_loss_strategy_multi_env_shape_and_values(): + """Test that NextTo with batched (N,3) input returns (N,) loss with correct per-env values.""" + parent_obj = _create_table() + child_obj = _create_box() + relation = NextTo(parent_obj, side=Side.POSITIVE_X, distance_m=0.05) + strategy = NextToLossStrategy(slope=10.0) + + # Env 0: perfectly placed. Env 1: wrong side. + child_pos = torch.tensor([[1.05, 0.4, 0.0], [-0.5, 0.5, 0.0]]) + + loss = strategy.compute_loss(relation, child_pos, child_obj.bounding_box, parent_obj.bounding_box) + assert loss.shape == (2,) + assert torch.isclose(loss[0], torch.tensor(0.0), atol=1e-4) + assert loss[1] > 0.0 diff --git a/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py b/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py index 3e0428fc1..02b0939e6 100644 --- a/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py +++ b/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py @@ -10,7 +10,7 @@ Example: python isaaclab_arena/evaluation/policy_runner.py --policy_type zero_action --num_steps 500 \\ - --num_envs 1 --enable_cameras gr1_table_multi_object_no_collision --embodiment gr1_joint + --num_envs 16 --env_spacing 4.0 --enable_cameras gr1_table_multi_object_no_collision --embodiment gr1_joint """ import argparse @@ -19,13 +19,13 @@ DEFAULT_TABLE_OBJECTS = [ "cracker_box", - "mustard_bottle", "sugar_box", "tomato_soup_can", "mug", - "brown_box", "dex_cube", -] # Default objects on table (On + pairwise NoCollision) + "power_drill", + "red_container", +] # 7 objects on table (On + pairwise NoCollision). class GR1TableMultiObjectNoCollisionEnvironment(ExampleEnvironmentBase): From 142af71c763094823cc2c1198724313d9e58bf36 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Mon, 30 Mar 2026 13:56:08 -0700 Subject: [PATCH 02/12] fix overlap validation, revert hyperparameters --- isaaclab_arena/relations/object_placer.py | 16 +++++++++++++- .../relations/object_placer_params.py | 2 +- isaaclab_arena/relations/relation_solver.py | 2 +- .../relations/relation_solver_params.py | 2 +- isaaclab_arena/relations/relations.py | 4 ++-- .../test_object_placer_reproducibility.py | 22 +++++++++++++++++++ ...e_multi_object_no_collision_environment.py | 6 +++-- 7 files changed, 46 insertions(+), 8 deletions(-) diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 3c370c603..64bbe1385 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -253,11 +253,25 @@ def _validate_no_overlap( self, positions: dict[Object | ObjectReference, tuple[float, float, float]], ) -> bool: - """Check that no two objects overlap in 3D (axis-aligned bbox with margin).""" + """Validate that no two objects overlap in 3D (axis-aligned bbox with margin). + + Pairs linked by an On relation are skipped (validated separately by + _validate_on_relations). + """ + # Build set of On-related pairs to skip (child, parent) and (parent, child). + on_pairs: set[tuple] = set() + for obj in positions: + for rel in obj.get_relations(): + if isinstance(rel, On) and rel.parent in positions: + on_pairs.add((id(obj), id(rel.parent))) + on_pairs.add((id(rel.parent), id(obj))) + objects = list(positions.keys()) for i in range(len(objects)): for j in range(i + 1, len(objects)): a, b = objects[i], objects[j] + if (id(a), id(b)) in on_pairs: + continue a_world = a.get_bounding_box().translated(positions[a]) b_world = b.get_bounding_box().translated(positions[b]) diff --git a/isaaclab_arena/relations/object_placer_params.py b/isaaclab_arena/relations/object_placer_params.py index 25dca3f65..d5d1ea305 100644 --- a/isaaclab_arena/relations/object_placer_params.py +++ b/isaaclab_arena/relations/object_placer_params.py @@ -34,7 +34,7 @@ class ObjectPlacerParams: placement_seed: int | None = None """Random seed for reproducible placement. If None, uses current RNG state.""" - min_separation_m: float = 0.005 + min_separation_m: float = 0.0 """Minimum separation (meters) required between object bounding boxes. Set to 0.0 to only reject actual overlaps. A small positive value (e.g. 0.005) adds a safety margin between objects.""" diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index ed62a10f4..2ff0f592d 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -237,7 +237,7 @@ def debug_losses(self, objects: list[Object | ObjectReference]) -> None: # Build positions dict from final position history final_positions = {obj: (pos[0], pos[1], pos[2]) for obj, pos in zip(objects, final_positions_list)} - state = RelationSolverState(objects, [final_positions]) + state = RelationSolverState(objects, final_positions) self._compute_total_loss(state, debug=True) print("\n" + "=" * 60) diff --git a/isaaclab_arena/relations/relation_solver_params.py b/isaaclab_arena/relations/relation_solver_params.py index c6f14b3f8..bd78408f1 100644 --- a/isaaclab_arena/relations/relation_solver_params.py +++ b/isaaclab_arena/relations/relation_solver_params.py @@ -21,7 +21,7 @@ def _default_strategies() -> dict[type[RelationBase], RelationLossStrategy | Una return { NextTo: NextToLossStrategy(slope=10.0), On: OnLossStrategy(slope=100.0), - NoCollision: NoCollisionLossStrategy(slope=500.0), + NoCollision: NoCollisionLossStrategy(slope=100.0), AtPosition: AtPositionLossStrategy(slope=100.0), } diff --git a/isaaclab_arena/relations/relations.py b/isaaclab_arena/relations/relations.py index 4275765ec..15427489f 100644 --- a/isaaclab_arena/relations/relations.py +++ b/isaaclab_arena/relations/relations.py @@ -137,13 +137,13 @@ def __init__( self, parent: Object | ObjectReference, relation_loss_weight: float = 1.0, - clearance_m: float = 0.02, + clearance_m: float = 0.01, ): """ Args: parent: The other object that this object must not collide with. relation_loss_weight: Weight for the relationship loss function. - clearance_m: Minimum clearance between bounding boxes in meters (default: 2cm). + clearance_m: Minimum clearance between bounding boxes in meters (default: 1cm). """ super().__init__(parent, relation_loss_weight) assert clearance_m >= 0.0, f"clearance_m must be non-negative, got {clearance_m}" diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 3331d43de..0cf81487f 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -8,6 +8,7 @@ from isaaclab_arena.assets.dummy_object import DummyObject from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams +from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from isaaclab_arena.relations.relations import IsAnchor, NextTo, On, Side @@ -130,6 +131,27 @@ def test_object_placer_different_seeds_produce_different_results(): assert any_different, "Different seeds should produce different results" +def test_object_placer_multi_env_returns_multi_env_result(): + """Test that ObjectPlacer.place with num_envs>1 returns MultiEnvPlacementResult.""" + num_envs = 4 + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + placer = ObjectPlacer( + params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params) + ) + result = placer.place(objects, num_envs=num_envs) + + assert isinstance(result, MultiEnvPlacementResult) + assert len(result.results) == num_envs + for r in result.results: + assert isinstance(r, PlacementResult) + assert box1 in r.positions + assert box2 in r.positions + assert len(r.positions[box1]) == 3 + assert len(r.positions[box2]) == 3 + + def test_relation_solver_multi_env_batched_positions(): """Test that solver with list[dict] input returns list[dict] output.""" solver_params = RelationSolverParams(max_iters=50) diff --git a/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py b/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py index 02b0939e6..f3750fddc 100644 --- a/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py +++ b/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py @@ -21,11 +21,13 @@ "cracker_box", "sugar_box", "tomato_soup_can", - "mug", "dex_cube", "power_drill", "red_container", -] # 7 objects on table (On + pairwise NoCollision). +] +# NOTE: The gradient-based solver does not guarantee collision-free placement for all +# objects. Better initialization strategies and constraining unchanged pose dimensions +# are needed in the near future. class GR1TableMultiObjectNoCollisionEnvironment(ExampleEnvironmentBase): From 7c0079cc2a9d2bd2a6464038d492f3f0d51549c5 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Mon, 30 Mar 2026 14:16:17 -0700 Subject: [PATCH 03/12] pre-commit check --- isaaclab_arena/relations/object_placer.py | 10 +++++++--- isaaclab_arena/relations/relation_solver_state.py | 3 +-- .../tests/test_object_placer_reproducibility.py | 4 +--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 64bbe1385..4899d7179 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -138,8 +138,11 @@ def place( # Per env: use best valid if any, else best-by-loss fallback final_per_env: list[dict] = [ - best_valid_positions_per_env[e] if best_valid_positions_per_env[e] is not None - else best_any_positions_per_env[e] + ( + best_valid_positions_per_env[e] + if best_valid_positions_per_env[e] is not None + else best_any_positions_per_env[e] + ) for e in range(num_envs) ] @@ -148,7 +151,8 @@ def place( success=best_valid_positions_per_env[e] is not None, positions=final_per_env[e], final_loss=( - best_valid_loss_per_env[e] if best_valid_positions_per_env[e] is not None + best_valid_loss_per_env[e] + if best_valid_positions_per_env[e] is not None else best_any_loss_per_env[e] ), attempts=attempt + 1, diff --git a/isaaclab_arena/relations/relation_solver_state.py b/isaaclab_arena/relations/relation_solver_state.py index b45c0f33d..b8a9de8f4 100644 --- a/isaaclab_arena/relations/relation_solver_state.py +++ b/isaaclab_arena/relations/relation_solver_state.py @@ -77,8 +77,7 @@ def __init__( self._optimizable_indices = [i for i in range(len(objects)) if i not in self._anchor_indices] if self._optimizable_indices: opt_tensors = [ - torch.stack([positions_per_env[e][i] for e in range(self._num_envs)]) - for i in self._optimizable_indices + torch.stack([positions_per_env[e][i] for e in range(self._num_envs)]) for i in self._optimizable_indices ] self._optimizable_positions = torch.stack(opt_tensors, dim=1) # (N, num_opt, 3) self._optimizable_positions.requires_grad = True diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 0cf81487f..10b9d5829 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -137,9 +137,7 @@ def test_object_placer_multi_env_returns_multi_env_result(): solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) desk, box1, box2 = _create_test_objects() objects = [desk, box1, box2] - placer = ObjectPlacer( - params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params) - ) + placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params)) result = placer.place(objects, num_envs=num_envs) assert isinstance(result, MultiEnvPlacementResult) From debb6bfdbdd78d3caafb49f5b42ac596f4f8f6a1 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Mon, 30 Mar 2026 15:21:46 -0700 Subject: [PATCH 04/12] rename env_id related function --- isaaclab_arena/relations/placement_events.py | 26 +++++++++----------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/isaaclab_arena/relations/placement_events.py b/isaaclab_arena/relations/placement_events.py index 53c2a81d4..010f18302 100644 --- a/isaaclab_arena/relations/placement_events.py +++ b/isaaclab_arena/relations/placement_events.py @@ -18,18 +18,16 @@ from isaaclab_arena.utils.pose import Pose -def _resolve_env_ids(env: ManagerBasedEnv, env_ids) -> list[int] | None: - """Normalize env_ids from the event manager to a list of env indices.""" - if env_ids is None: - return None +def _env_ids_to_list(env: ManagerBasedEnv, env_ids: torch.Tensor | slice) -> list[int]: + """Convert env_ids (``torch.Tensor`` or ``slice(None)``) to a plain list of int indices. + + Args: + env: The environment instance (used to determine num_envs for slice). + env_ids: Environment indices from the Isaac Lab event manager. + """ if isinstance(env_ids, slice): - if env_ids == slice(None): - return list(range(env.num_envs)) - start, stop, step = env_ids.indices(env.num_envs) - return list(range(start, stop, step)) - if hasattr(env_ids, "tolist"): - return env_ids.tolist() - return list(env_ids) + return list(range(env.num_envs)) + return env_ids.tolist() @configclass @@ -67,10 +65,10 @@ def set_object_pose_per_env_from_layouts( anchor_names: list[str] | None = None, ) -> None: """Set each object's root pose per env from layout dicts; anchors first.""" - resolved = _resolve_env_ids(env, env_ids) - if not resolved: + env_id_list = _env_ids_to_list(env, env_ids) + if not env_id_list: return - env_ids_t = torch.tensor(resolved, device=env.device) + env_ids_t = torch.tensor(env_id_list, device=env.device) anchor_set = set(anchor_names or []) ordered_names = [n for n in object_names if n in anchor_set] ordered_names += [n for n in object_names if n not in anchor_set] From c433c1c45f4fa3eb92fd7dd69eae4c7eace26e85 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Mon, 30 Mar 2026 15:51:05 -0700 Subject: [PATCH 05/12] add assert and type hints --- isaaclab_arena/relations/object_placer.py | 4 ++-- isaaclab_arena/relations/placement_events.py | 3 +++ isaaclab_arena/relations/relation_solver_state.py | 8 +++++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 4899d7179..e782dd214 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -192,7 +192,7 @@ def _get_init_bounds(self, anchor_object: Object | ObjectReference) -> AxisAlign def _generate_initial_positions( self, objects: list[Object | ObjectReference], - anchor_objects: Object | ObjectReference, + anchor_objects: set[Object | ObjectReference], init_bounds: AxisAlignedBoundingBox, ) -> dict[Object | ObjectReference, tuple[float, float, float]]: """Generate initial positions for all objects. @@ -303,7 +303,7 @@ def _validate_placement( def _apply_positions( self, positions: dict[Object | ObjectReference, tuple[float, float, float]], - anchor_objects: Object | ObjectReference, + anchor_objects: set[Object | ObjectReference], ) -> None: """Apply solved positions to objects (skipping anchors). diff --git a/isaaclab_arena/relations/placement_events.py b/isaaclab_arena/relations/placement_events.py index 010f18302..0df972a04 100644 --- a/isaaclab_arena/relations/placement_events.py +++ b/isaaclab_arena/relations/placement_events.py @@ -65,6 +65,9 @@ def set_object_pose_per_env_from_layouts( anchor_names: list[str] | None = None, ) -> None: """Set each object's root pose per env from layout dicts; anchors first.""" + assert ( + len(positions_all_envs_by_name) == env.num_envs + ), f"Expected {env.num_envs} layout dicts, got {len(positions_all_envs_by_name)}" env_id_list = _env_ids_to_list(env, env_ids) if not env_id_list: return diff --git a/isaaclab_arena/relations/relation_solver_state.py b/isaaclab_arena/relations/relation_solver_state.py index b8a9de8f4..990bb34b5 100644 --- a/isaaclab_arena/relations/relation_solver_state.py +++ b/isaaclab_arena/relations/relation_solver_state.py @@ -67,7 +67,13 @@ def __init__( # Separate anchor positions from optimizable positions self._anchor_indices: set[int] = {self._obj_to_idx[obj] for obj in self._anchor_objects} - # Anchors are fixed (same position in all envs), so env 0 is representative. + # Anchors must be identical across envs (they are fixed reference points). + for idx in self._anchor_indices: + for e in range(1, self._num_envs): + assert torch.allclose(positions_per_env[0][idx], positions_per_env[e][idx]), ( + f"Anchor '{objects[idx].name}' has different positions across envs " + f"(env 0: {positions_per_env[0][idx].tolist()}, env {e}: {positions_per_env[e][idx].tolist()})" + ) self._anchor_positions: dict[int, torch.Tensor] = { idx: positions_per_env[0][idx].clone() for idx in self._anchor_indices } From ab1bff2b9180e1a92b20c3a54e5d6d6330bcc0d4 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 31 Mar 2026 10:24:31 -0700 Subject: [PATCH 06/12] potential bug fix --- isaaclab_arena/environments/arena_env_builder.py | 3 +-- isaaclab_arena/relations/placement_events.py | 1 + isaaclab_arena/relations/relation_solver.py | 4 ++-- .../tests/test_object_placer_reproducibility.py | 16 ++++++++++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/isaaclab_arena/environments/arena_env_builder.py b/isaaclab_arena/environments/arena_env_builder.py index 2a093ff1b..2a35a1d61 100644 --- a/isaaclab_arena/environments/arena_env_builder.py +++ b/isaaclab_arena/environments/arena_env_builder.py @@ -107,8 +107,7 @@ def _solve_relations(self) -> None: result = placer.place(objects_with_relations, num_envs=num_envs) if isinstance(result, MultiEnvPlacementResult) and result.results: positions_all_envs_by_name = [ - {obj.name: result.results[e].positions[obj] for obj in result.results[0].positions} - for e in range(len(result.results)) + {obj.name: pos for obj, pos in r.positions.items()} for r in result.results ] object_names = [obj.name for obj in objects_with_relations] anchor_names = [a.name for a in get_anchor_objects(objects_with_relations)] diff --git a/isaaclab_arena/relations/placement_events.py b/isaaclab_arena/relations/placement_events.py index 0df972a04..080a1a348 100644 --- a/isaaclab_arena/relations/placement_events.py +++ b/isaaclab_arena/relations/placement_events.py @@ -89,5 +89,6 @@ def set_object_pose_per_env_from_layouts( x, y, z = xyz pose_list.append(Pose(position_xyz=(x, y, z), rotation_wxyz=identity_quat_wxyz)) else: + print(f"Warning: object '{name}' missing from layout dict for env {e}; defaulting to origin") pose_list.append(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_wxyz=identity_quat_wxyz)) set_object_pose_per_env(env, env_ids_t, SceneEntityCfg(name), pose_list) diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index 2ff0f592d..e8e5e4e3f 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -89,7 +89,7 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - child_bbox=obj.get_bounding_box(), ) if debug: - _print_unary_relation_debug(obj, relation, child_pos[0], loss.sum()) + _print_unary_relation_debug(obj, relation, child_pos[0], loss.mean()) # Handle binary relations (with parent) like On, NextTo elif isinstance(relation, Relation): # Build parent world bbox: anchors have a known fixed pose, @@ -108,7 +108,7 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - ) if debug: parent_pos = state.get_position(parent) - _print_relation_debug(obj, relation, child_pos[0], parent_pos[0], loss.sum()) + _print_relation_debug(obj, relation, child_pos[0], parent_pos[0], loss.mean()) else: raise ValueError(f"Unknown relation type: {type(relation).__name__}") diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 10b9d5829..fb7a1021a 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -150,6 +150,22 @@ def test_object_placer_multi_env_returns_multi_env_result(): assert len(r.positions[box2]) == 3 +def test_object_placer_multi_env_produces_different_positions(): + """Test that multi-env placement produces different positions across environments.""" + num_envs = 4 + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params)) + result = placer.place(objects, num_envs=num_envs) + + assert isinstance(result, MultiEnvPlacementResult) + # At least one pair of envs should have different positions for a non-anchor object. + positions_box1 = [result.results[e].positions[box1] for e in range(num_envs)] + any_different = any(positions_box1[i] != positions_box1[j] for i in range(num_envs) for j in range(i + 1, num_envs)) + assert any_different, "Multi-env placement should produce different positions across environments" + + def test_relation_solver_multi_env_batched_positions(): """Test that solver with list[dict] input returns list[dict] output.""" solver_params = RelationSolverParams(max_iters=50) From 2b1e0a432be6bb223109a0f0c5da64e09296f1ca Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 31 Mar 2026 10:25:26 -0700 Subject: [PATCH 07/12] pre commit fix --- isaaclab_arena/environments/arena_env_builder.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/isaaclab_arena/environments/arena_env_builder.py b/isaaclab_arena/environments/arena_env_builder.py index 2a35a1d61..45a914e69 100644 --- a/isaaclab_arena/environments/arena_env_builder.py +++ b/isaaclab_arena/environments/arena_env_builder.py @@ -106,9 +106,7 @@ def _solve_relations(self) -> None: num_envs = self.args.num_envs result = placer.place(objects_with_relations, num_envs=num_envs) if isinstance(result, MultiEnvPlacementResult) and result.results: - positions_all_envs_by_name = [ - {obj.name: pos for obj, pos in r.positions.items()} for r in result.results - ] + positions_all_envs_by_name = [{obj.name: pos for obj, pos in r.positions.items()} for r in result.results] object_names = [obj.name for obj in objects_with_relations] anchor_names = [a.name for a in get_anchor_objects(objects_with_relations)] self._placement_event_cfg = make_placement_event_cfg( From 459a7451c319419cafc34b6f4846cff31186f042 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 31 Mar 2026 15:16:32 -0700 Subject: [PATCH 08/12] address comments --- isaaclab_arena/assets/object_base.py | 27 ++- .../environments/arena_env_builder.py | 24 +-- isaaclab_arena/relations/object_placer.py | 174 +++++++++--------- isaaclab_arena/relations/placement_events.py | 94 ---------- .../test_object_placer_reproducibility.py | 53 +++++- isaaclab_arena/tests/test_pose.py | 17 +- isaaclab_arena/utils/pose.py | 8 + 7 files changed, 185 insertions(+), 212 deletions(-) delete mode 100644 isaaclab_arena/relations/placement_events.py diff --git a/isaaclab_arena/assets/object_base.py b/isaaclab_arena/assets/object_base.py index 85585c6ca..cab16fe97 100644 --- a/isaaclab_arena/assets/object_base.py +++ b/isaaclab_arena/assets/object_base.py @@ -15,8 +15,8 @@ from isaaclab_arena.assets.asset import Asset from isaaclab_arena.relations.relations import AtPosition, Relation, RelationBase -from isaaclab_arena.terms.events import set_object_pose -from isaaclab_arena.utils.pose import Pose, PoseRange +from isaaclab_arena.terms.events import set_object_pose, set_object_pose_per_env +from isaaclab_arena.utils.pose import Pose, PosePerEnv, PoseRange from isaaclab_arena.utils.velocity import Velocity @@ -42,13 +42,13 @@ def __init__( prim_path = "{ENV_REGEX_NS}/" + self.name self.prim_path = prim_path self.object_type = object_type - self.initial_pose: Pose | PoseRange | None = None + self.initial_pose: Pose | PoseRange | PosePerEnv | None = None self.initial_velocity: Velocity | None = None self.object_cfg = None self.event_cfg = None self.relations: list[RelationBase] = [] - def get_initial_pose(self) -> Pose | PoseRange | None: + def get_initial_pose(self) -> Pose | PoseRange | PosePerEnv | None: """Return the current initial pose of this object. Subclasses may override to derive the pose from other sources @@ -60,20 +60,24 @@ def _get_initial_pose_as_pose(self) -> Pose | None: """Return a single ``Pose`` suitable for *init_state* and bounding-box calculations. If the initial pose is a ``PoseRange``, its midpoint is returned. + If the initial pose is a ``PosePerEnv``, the first environment's pose is returned. If the initial pose is ``None``, ``None`` is returned. """ initial_pose = self.get_initial_pose() if initial_pose is None: return None + if isinstance(initial_pose, PosePerEnv): + return initial_pose.poses[0] if isinstance(initial_pose, PoseRange): return initial_pose.get_midpoint() return initial_pose - def set_initial_pose(self, pose: Pose | PoseRange) -> None: + def set_initial_pose(self, pose: Pose | PoseRange | PosePerEnv) -> None: """Set / override the initial pose and rebuild derived configs. Args: - pose: A fixed ``Pose`` or a ``PoseRange`` (randomised on reset). + pose: A fixed ``Pose``, a ``PoseRange`` (randomised on reset), + or a ``PosePerEnv`` (distinct pose per environment). """ self.initial_pose = pose initial_pose = self._get_initial_pose_as_pose() @@ -116,7 +120,16 @@ def _init_event_cfg(self) -> EventTermCfg | None: return None initial_pose = self.get_initial_pose() - if isinstance(initial_pose, PoseRange): + if isinstance(initial_pose, PosePerEnv): + return EventTermCfg( + func=set_object_pose_per_env, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg(self.name), + "pose_list": initial_pose.poses, + }, + ) + elif isinstance(initial_pose, PoseRange): return EventTermCfg( func=randomize_object_pose, mode="reset", diff --git a/isaaclab_arena/environments/arena_env_builder.py b/isaaclab_arena/environments/arena_env_builder.py index 45a914e69..fe1e8e9c5 100644 --- a/isaaclab_arena/environments/arena_env_builder.py +++ b/isaaclab_arena/environments/arena_env_builder.py @@ -27,9 +27,8 @@ from isaaclab_arena.metrics.recorder_manager_utils import metrics_to_recorder_manager_cfg from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams -from isaaclab_arena.relations.placement_events import make_placement_event_cfg from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult -from isaaclab_arena.relations.relations import IsAnchor, NoCollision, get_anchor_objects +from isaaclab_arena.relations.relations import IsAnchor, NoCollision from isaaclab_arena.tasks.no_task import NoTask from isaaclab_arena.utils.configclass import combine_configclass_instances from isaaclab_arena.utils.multiprocess import get_local_rank @@ -101,26 +100,16 @@ def _solve_relations(self) -> None: self._add_pairwise_no_collision(objects_with_relations) # Run the ObjectPlacer (default on_relation_z_tolerance_m accommodates solver residual). + # Positions are applied to objects via set_initial_pose (single-env: Pose/PoseRange, + # multi-env: PosePerEnv), so each object's event_cfg handles its own reset. placement_seed = getattr(self.args, "placement_seed", None) placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=placement_seed)) num_envs = self.args.num_envs result = placer.place(objects_with_relations, num_envs=num_envs) - if isinstance(result, MultiEnvPlacementResult) and result.results: - positions_all_envs_by_name = [{obj.name: pos for obj, pos in r.positions.items()} for r in result.results] - object_names = [obj.name for obj in objects_with_relations] - anchor_names = [a.name for a in get_anchor_objects(objects_with_relations)] - self._placement_event_cfg = make_placement_event_cfg( - positions_all_envs_by_name, - object_names, - anchor_names, - ) - else: - self._placement_event_cfg = None # Log outcome - results_per_env = result.results if isinstance(result, MultiEnvPlacementResult) else None - if results_per_env is not None: - n_succeeded = sum(1 for r in results_per_env if r.success) + if isinstance(result, MultiEnvPlacementResult): + n_succeeded = sum(1 for r in result.results if r.success) if n_succeeded == num_envs: print(f"Relation solving succeeded for all {num_envs} env(s) after {result.attempts} attempt(s)") else: @@ -182,9 +171,6 @@ def compose_manager_cfg(self) -> IsaacLabArenaManagerBasedRLEnvCfg: self.arena_env.scene.get_events_cfg(), task.get_events_cfg(), ] - placement_event = getattr(self, "_placement_event_cfg", None) - if placement_event is not None: - events_sources.append(placement_event) events_cfg = combine_configclass_instances("EventsCfg", *events_sources) termination_cfg = combine_configclass_instances( "TerminationCfg", diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index e782dd214..187e01fc2 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -5,6 +5,8 @@ from __future__ import annotations +import math + import torch from typing import TYPE_CHECKING @@ -13,7 +15,7 @@ from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relations import On, RandomAroundSolution, RotateAroundSolution, get_anchor_objects from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, get_random_pose_within_bounding_box -from isaaclab_arena.utils.pose import Pose +from isaaclab_arena.utils.pose import Pose, PosePerEnv if TYPE_CHECKING: from isaaclab_arena.assets.object import Object @@ -23,11 +25,12 @@ class ObjectPlacer: """High-level API for placing objects according to their spatial relations. - Encapsulates the workflow of: - 1. Random initialization of object positions - 2. Running the RelationSolver - 3. Validating the result - 4. Retrying if necessary + Generates a pool of candidate layouts in a single batched solver call, + validates each candidate, and selects the best ones: + 1. Random initialization of object positions for all candidates + 2. Running the RelationSolver on the full candidate pool + 3. Validating and ranking candidates by (validity, loss) + 4. Selecting the best num_envs results 5. Applying solved positions to objects Supports single-env (num_envs=1) and batched (num_envs>1) placement. @@ -46,6 +49,7 @@ def place( self, objects: list[Object | ObjectReference], num_envs: int = 1, + result_per_env: bool = True, ) -> PlacementResult | MultiEnvPlacementResult: """Place objects according to their spatial relations. @@ -54,9 +58,13 @@ def place( marked with IsAnchor() which serves as a fixed reference. num_envs: Number of environments. 1 for single-env; > 1 for batched placement (one layout per env). + result_per_env: When True (default), each environment gets a distinct + layout. When False, a single best layout is solved and applied + identically to all environments (useful for deterministic evaluation). Returns: - PlacementResult when num_envs is 1; MultiEnvPlacementResult when num_envs > 1. + PlacementResult when a single layout is produced (num_envs=1 or + result_per_env=False); MultiEnvPlacementResult otherwise. """ # Validate all objects have at least one relation for obj in objects: @@ -86,88 +94,65 @@ def place( # Implement an initialization strategy that infers from the Relations(s). init_bounds = self._get_init_bounds(anchor_objects[0]) - # Placement loop with retries (per-env tracking) - best_valid_loss_per_env: list[float] = [float("inf")] * num_envs - best_valid_positions_per_env: list[dict | None] = [None] * num_envs - best_any_loss_per_env: list[float] = [float("inf")] * num_envs - best_any_positions_per_env: list[dict] = [dict() for _ in range(num_envs)] - - for attempt in range(self.params.max_placement_attempts): - # Generate starting positions per env (anchors from their poses, others random) - initial_positions: list[dict] = [] - for env_i in range(num_envs): - rng_state = None - if self.params.placement_seed is not None: - rng_state = torch.get_rng_state() - torch.manual_seed(self.params.placement_seed + env_i + attempt * (num_envs + 1)) - initial_positions.append(self._generate_initial_positions(objects, anchor_objects_set, init_bounds)) - if rng_state is not None: - torch.set_rng_state(rng_state) - - # solve() returns list[dict] when given list[dict] initial_positions - positions_per_env: list[dict] = self._solver.solve(objects, initial_positions) # type: ignore[assignment] # overload returns list[dict] for list input - per_env_loss = ( - self._solver.last_loss_per_env.cpu().tolist() - if self._solver.last_loss_per_env is not None - else [float("inf")] * num_envs - ) + # When result_per_env is True, each env needs its own layout, so we + # generate max_placement_attempts * num_envs candidates and pick the + # best num_envs. When False, we only need one layout applied to all envs. + num_results = num_envs if result_per_env else 1 + num_candidates = self.params.max_placement_attempts * num_results + + initial_positions: list[dict] = [] + for candidate_idx in range(num_candidates): + rng_state = None + if self.params.placement_seed is not None: + rng_state = torch.get_rng_state() + torch.manual_seed(self.params.placement_seed + candidate_idx) + initial_positions.append(self._generate_initial_positions(objects, anchor_objects_set, init_bounds)) + if rng_state is not None: + torch.set_rng_state(rng_state) + + all_positions: list[dict] = self._solver.solve(objects, initial_positions) # type: ignore[assignment] # overload returns list[dict] for list input + all_losses = ( + self._solver.last_loss_per_env.cpu().tolist() + if self._solver.last_loss_per_env is not None + else [float("inf")] * num_candidates + ) - # Check if placement is valid (per env); update best valid and best-by-loss fallback - for e in range(num_envs): - loss_e = per_env_loss[e] if e < len(per_env_loss) else float("inf") - valid = self._validate_placement(positions_per_env[e]) - if valid and loss_e < best_valid_loss_per_env[e]: - best_valid_loss_per_env[e] = loss_e - best_valid_positions_per_env[e] = positions_per_env[e] - if loss_e < best_any_loss_per_env[e]: - best_any_loss_per_env[e] = loss_e - best_any_positions_per_env[e] = positions_per_env[e] - - if self.params.verbose: - mean_loss = sum(per_env_loss) / num_envs - n_succeeded = sum(1 for p in best_valid_positions_per_env if p is not None) - print( - f"Attempt {attempt + 1}/{self.params.max_placement_attempts}:" - f" loss = {mean_loss:.6f}, envs validated = {n_succeeded}/{num_envs}" - ) - - if all(best_valid_positions_per_env): - if self.params.verbose: - print(f"Success on attempt {attempt + 1}") - break - - # Per env: use best valid if any, else best-by-loss fallback - final_per_env: list[dict] = [ - ( - best_valid_positions_per_env[e] - if best_valid_positions_per_env[e] is not None - else best_any_positions_per_env[e] + all_candidates: list[tuple[float, dict, bool]] = [] + for idx in range(num_candidates): + loss = all_losses[idx] if idx < len(all_losses) else float("inf") + is_valid = self._validate_placement(all_positions[idx]) + all_candidates.append((loss, all_positions[idx], is_valid)) + + # Sort: valid solutions first (by loss), then invalid (by loss) + all_candidates.sort(key=lambda c: (not c[2], c[0])) + selected = all_candidates[:num_results] + + n_valid = sum(1 for c in selected if c[2]) + if self.params.verbose: + total_valid = sum(1 for c in all_candidates if c[2]) + finite_losses = [c[0] for c in all_candidates if math.isfinite(c[0])] + mean_loss = sum(finite_losses) / len(finite_losses) if finite_losses else float("inf") + print( + f"Solved {num_candidates} candidates in one batch: mean loss = {mean_loss:.6f}," + f" {total_valid} valid, selected best {num_results} ({n_valid} valid)" ) - for e in range(num_envs) - ] + final_per_env: list[dict] = [c[1] for c in selected] results_per_env = [ PlacementResult( - success=best_valid_positions_per_env[e] is not None, - positions=final_per_env[e], - final_loss=( - best_valid_loss_per_env[e] - if best_valid_positions_per_env[e] is not None - else best_any_loss_per_env[e] - ), - attempts=attempt + 1, + success=c[2], + positions=c[1], + final_loss=c[0], + attempts=self.params.max_placement_attempts, ) - for e in range(num_envs) + for c in selected ] - # Apply solved positions to objects - # TODO(@zhx06): Consider applying via event for consistency with multi_env. - if num_envs == 1 and self.params.apply_positions_to_objects: - self._apply_positions(final_per_env[0], anchor_objects_set) + if self.params.apply_positions_to_objects: + self._apply_positions(final_per_env, anchor_objects_set) - if num_envs == 1: + if num_results == 1: return results_per_env[0] - # Multi-env: layouts applied at reset via placement event (builder builds event_cfg from result) return MultiEnvPlacementResult(results=results_per_env) def _get_init_bounds(self, anchor_object: Object | ObjectReference) -> AxisAlignedBoundingBox: @@ -302,29 +287,38 @@ def _validate_placement( def _apply_positions( self, - positions: dict[Object | ObjectReference, tuple[float, float, float]], + positions_per_env: list[dict[Object | ObjectReference, tuple[float, float, float]]], anchor_objects: set[Object | ObjectReference], ) -> None: """Apply solved positions to objects (skipping anchors). - If RandomAroundSolution marker is present, sets a PoseRange (for reset-time randomization). - Rotation is taken from RotateAroundSolution marker if present, otherwise keep the identity rotation. + Handles both single-env and multi-env placement: + - Single-env: sets a fixed Pose or PoseRange (with RandomAroundSolution). + - Multi-env: sets a PosePerEnv with one Pose per environment. + + Rotation is taken from RotateAroundSolution marker if present, otherwise identity. """ - for obj, pos in positions.items(): + num_envs = len(positions_per_env) + for obj in positions_per_env[0]: if obj in anchor_objects: continue - random_marker = self._get_random_around_solution(obj) rotate_marker = self._get_rotate_around_solution(obj) rotation_wxyz = rotate_marker.get_rotation_wxyz() if rotate_marker else (1.0, 0.0, 0.0, 0.0) - if random_marker is not None: - # We need to set a PoseRange for the randomization to be picked up on reset. - # Set a PoseRange with the explicit rotation from RotateAroundSolution if present - obj.set_initial_pose(random_marker.to_pose_range_centered_at(pos, rotation_wxyz=rotation_wxyz)) + if num_envs == 1: + pos = positions_per_env[0][obj] + random_marker = self._get_random_around_solution(obj) + if random_marker is not None: + obj.set_initial_pose(random_marker.to_pose_range_centered_at(pos, rotation_wxyz=rotation_wxyz)) + else: + obj.set_initial_pose(Pose(position_xyz=pos, rotation_wxyz=rotation_wxyz)) else: - # Without randomization, we can set a fixed Pose. - obj.set_initial_pose(Pose(position_xyz=pos, rotation_wxyz=rotation_wxyz)) + poses = [ + Pose(position_xyz=positions_per_env[env_idx][obj], rotation_wxyz=rotation_wxyz) + for env_idx in range(num_envs) + ] + obj.set_initial_pose(PosePerEnv(poses=poses)) def _get_random_around_solution(self, obj: Object | ObjectReference) -> RandomAroundSolution | None: """Get RandomAroundSolution marker from object if present. diff --git a/isaaclab_arena/relations/placement_events.py b/isaaclab_arena/relations/placement_events.py deleted file mode 100644 index 080a1a348..000000000 --- a/isaaclab_arena/relations/placement_events.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -"""Placement event: apply placement layouts per env on reset (num_envs>1).""" - -from __future__ import annotations - -import dataclasses -import torch - -from isaaclab.envs import ManagerBasedEnv -from isaaclab.managers import EventTermCfg, SceneEntityCfg -from isaaclab.utils import configclass - -from isaaclab_arena.terms.events import set_object_pose_per_env -from isaaclab_arena.utils.pose import Pose - - -def _env_ids_to_list(env: ManagerBasedEnv, env_ids: torch.Tensor | slice) -> list[int]: - """Convert env_ids (``torch.Tensor`` or ``slice(None)``) to a plain list of int indices. - - Args: - env: The environment instance (used to determine num_envs for slice). - env_ids: Environment indices from the Isaac Lab event manager. - """ - if isinstance(env_ids, slice): - return list(range(env.num_envs)) - return env_ids.tolist() - - -@configclass -class PlacementEventsCfg: - """Event config for applying placement layouts per env on reset.""" - - set_object_pose_per_env_from_layouts: EventTermCfg = dataclasses.MISSING # type: ignore[assignment] - - -def make_placement_event_cfg( - positions_all_envs_by_name: list[dict[str, tuple[float, float, float]]], - object_names: list[str], - anchor_names: list[str] | None = None, -) -> PlacementEventsCfg: - """Build event config for applying placement layouts per env on reset.""" - params: dict = { - "positions_all_envs_by_name": positions_all_envs_by_name, - "object_names": object_names, - "anchor_names": anchor_names or [], - } - return PlacementEventsCfg( - set_object_pose_per_env_from_layouts=EventTermCfg( - func=set_object_pose_per_env_from_layouts, - mode="reset", - params=params, - ) - ) - - -def set_object_pose_per_env_from_layouts( - env: ManagerBasedEnv, - env_ids, - positions_all_envs_by_name: list[dict[str, tuple[float, float, float]]], - object_names: list[str], - anchor_names: list[str] | None = None, -) -> None: - """Set each object's root pose per env from layout dicts; anchors first.""" - assert ( - len(positions_all_envs_by_name) == env.num_envs - ), f"Expected {env.num_envs} layout dicts, got {len(positions_all_envs_by_name)}" - env_id_list = _env_ids_to_list(env, env_ids) - if not env_id_list: - return - env_ids_t = torch.tensor(env_id_list, device=env.device) - anchor_set = set(anchor_names or []) - ordered_names = [n for n in object_names if n in anchor_set] - ordered_names += [n for n in object_names if n not in anchor_set] - identity_quat_wxyz = (1.0, 0.0, 0.0, 0.0) - for name in ordered_names: - if name not in env.scene.keys(): - continue - asset = env.scene[name] - if not hasattr(asset, "write_root_pose_to_sim"): - continue - pose_list = [] - for e in range(len(positions_all_envs_by_name)): - xyz = positions_all_envs_by_name[e].get(name) - if xyz is not None: - x, y, z = xyz - pose_list.append(Pose(position_xyz=(x, y, z), rotation_wxyz=identity_quat_wxyz)) - else: - print(f"Warning: object '{name}' missing from layout dict for env {e}; defaulting to origin") - pose_list.append(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_wxyz=identity_quat_wxyz)) - set_object_pose_per_env(env, env_ids_t, SceneEntityCfg(name), pose_list) diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index fb7a1021a..9e0d50082 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -13,7 +13,7 @@ from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from isaaclab_arena.relations.relations import IsAnchor, NextTo, On, Side from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, get_random_pose_within_bounding_box -from isaaclab_arena.utils.pose import Pose +from isaaclab_arena.utils.pose import Pose, PosePerEnv def _create_test_objects() -> tuple[DummyObject, DummyObject, DummyObject]: @@ -187,3 +187,54 @@ def test_relation_solver_multi_env_batched_positions(): for obj in objects: assert obj in d assert len(d[obj]) == 3 + + +def test_object_placer_result_per_env_false_returns_single_result(): + """Test that place(num_envs>1, result_per_env=False) returns PlacementResult.""" + num_envs = 4 + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params)) + result = placer.place(objects, num_envs=num_envs, result_per_env=False) + + assert isinstance(result, PlacementResult), "result_per_env=False should return PlacementResult" + assert not isinstance(result, MultiEnvPlacementResult) + assert box1 in result.positions + assert box2 in result.positions + assert len(result.positions[box1]) == 3 + assert len(result.positions[box2]) == 3 + + +def test_object_placer_result_per_env_false_applies_pose_not_pose_per_env(): + """Test that result_per_env=False sets a single Pose (not PosePerEnv) on each object.""" + num_envs = 4 + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + placer = ObjectPlacer( + params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=True) + ) + placer.place(objects, num_envs=num_envs, result_per_env=False) + + for obj in [box1, box2]: + pose = obj.get_initial_pose() + assert isinstance(pose, Pose), f"{obj.name} should have a Pose, got {type(pose).__name__}" + assert not isinstance(pose, PosePerEnv) + + +def test_object_placer_result_per_env_true_applies_pose_per_env(): + """Test that result_per_env=True (default) sets PosePerEnv on each object when num_envs>1.""" + num_envs = 4 + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + placer = ObjectPlacer( + params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=True) + ) + placer.place(objects, num_envs=num_envs, result_per_env=True) + + for obj in [box1, box2]: + pose = obj.get_initial_pose() + assert isinstance(pose, PosePerEnv), f"{obj.name} should have PosePerEnv, got {type(pose).__name__}" + assert len(pose.poses) == num_envs diff --git a/isaaclab_arena/tests/test_pose.py b/isaaclab_arena/tests/test_pose.py index 718fc8c76..8a3a5c3bc 100644 --- a/isaaclab_arena/tests/test_pose.py +++ b/isaaclab_arena/tests/test_pose.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from isaaclab_arena.utils.pose import Pose +from isaaclab_arena.utils.pose import Pose, PosePerEnv def test_pose_composition(): @@ -14,3 +14,18 @@ def test_pose_composition(): assert T_C_A.position_xyz == (3.0, 0.0, 0.0) assert T_C_A.rotation_wxyz == (1.0, 0.0, 0.0, 0.0) + + +def test_pose_per_env_stores_poses(): + """Test that PosePerEnv stores the list of Pose objects correctly.""" + poses = [ + Pose(position_xyz=(1.0, 2.0, 3.0)), + Pose(position_xyz=(4.0, 5.0, 6.0)), + Pose(position_xyz=(7.0, 8.0, 9.0)), + ] + pose_per_env = PosePerEnv(poses=poses) + + assert len(pose_per_env.poses) == 3 + assert pose_per_env.poses[0].position_xyz == (1.0, 2.0, 3.0) + assert pose_per_env.poses[1].position_xyz == (4.0, 5.0, 6.0) + assert pose_per_env.poses[2].position_xyz == (7.0, 8.0, 9.0) diff --git a/isaaclab_arena/utils/pose.py b/isaaclab_arena/utils/pose.py index 57babdc58..088f1e311 100644 --- a/isaaclab_arena/utils/pose.py +++ b/isaaclab_arena/utils/pose.py @@ -74,6 +74,14 @@ def compose_poses(T_C_B: Pose, T_B_A: Pose) -> Pose: return Pose(position_xyz=tuple(t_C_A.tolist()), rotation_wxyz=tuple(q_C_A.tolist())) +@dataclass +class PosePerEnv: + """Per-environment poses (one Pose per env, used for batched placement).""" + + poses: list[Pose] + """One Pose per environment.""" + + @dataclass class PoseRange: """Range of poses. From 1680357d386692807861345c1b2870799f2e73ef Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 31 Mar 2026 15:25:42 -0700 Subject: [PATCH 09/12] pre commit check --- isaaclab_arena/relations/object_placer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 187e01fc2..41c32a690 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -6,7 +6,6 @@ from __future__ import annotations import math - import torch from typing import TYPE_CHECKING @@ -25,12 +24,11 @@ class ObjectPlacer: """High-level API for placing objects according to their spatial relations. - Generates a pool of candidate layouts in a single batched solver call, - validates each candidate, and selects the best ones: - 1. Random initialization of object positions for all candidates - 2. Running the RelationSolver on the full candidate pool - 3. Validating and ranking candidates by (validity, loss) - 4. Selecting the best num_envs results + Encapsulates the workflow of: + 1. Random initialization of object positions + 2. Running the RelationSolver + 3. Validating the result + 4. Retrying if necessary 5. Applying solved positions to objects Supports single-env (num_envs=1) and batched (num_envs>1) placement. From 8c7f5a322854f655f0b0bffafd85b5f904200836 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 31 Mar 2026 17:07:15 -0700 Subject: [PATCH 10/12] add comments for Object Placer --- isaaclab_arena/relations/object_placer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 41c32a690..c4bc40b22 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -92,9 +92,9 @@ def place( # Implement an initialization strategy that infers from the Relations(s). init_bounds = self._get_init_bounds(anchor_objects[0]) - # When result_per_env is True, each env needs its own layout, so we - # generate max_placement_attempts * num_envs candidates and pick the - # best num_envs. When False, we only need one layout applied to all envs. + # Pool-based placement: generate all candidates in one batched call, + # then pick the best num_results (environments are homogeneous so any + # valid solution can serve any environment). num_results = num_envs if result_per_env else 1 num_candidates = self.params.max_placement_attempts * num_results From 2b9f06d92d91ff9b21697e46ef81fbf9b2b0d69f Mon Sep 17 00:00:00 2001 From: zhx06 Date: Wed, 1 Apr 2026 11:08:40 -0700 Subject: [PATCH 11/12] simplify types to ObjectBase, always use list --- isaaclab_arena/assets/object_base.py | 11 ++ .../environments/arena_env_builder.py | 6 +- isaaclab_arena/relations/object_placer.py | 127 ++++++++++-------- isaaclab_arena/relations/placement_result.py | 4 +- isaaclab_arena/relations/relation_solver.py | 38 ++---- .../relations/relation_solver_state.py | 32 ++--- isaaclab_arena/relations/relations.py | 15 +-- .../tests/test_no_collision_loss.py | 6 +- .../test_object_placer_reproducibility.py | 4 +- 9 files changed, 128 insertions(+), 115 deletions(-) diff --git a/isaaclab_arena/assets/object_base.py b/isaaclab_arena/assets/object_base.py index cab16fe97..b8b3c4e4b 100644 --- a/isaaclab_arena/assets/object_base.py +++ b/isaaclab_arena/assets/object_base.py @@ -16,6 +16,7 @@ from isaaclab_arena.assets.asset import Asset from isaaclab_arena.relations.relations import AtPosition, Relation, RelationBase from isaaclab_arena.terms.events import set_object_pose, set_object_pose_per_env +from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox from isaaclab_arena.utils.pose import Pose, PosePerEnv, PoseRange from isaaclab_arena.utils.velocity import Velocity @@ -56,6 +57,16 @@ def get_initial_pose(self) -> Pose | PoseRange | PosePerEnv | None: """ return self.initial_pose + @abstractmethod + def get_bounding_box(self) -> AxisAlignedBoundingBox: + """Get local bounding box (relative to object origin).""" + ... + + @abstractmethod + def get_world_bounding_box(self) -> AxisAlignedBoundingBox: + """Get bounding box in world coordinates (local bbox rotated and translated).""" + ... + def _get_initial_pose_as_pose(self) -> Pose | None: """Return a single ``Pose`` suitable for *init_state* and bounding-box calculations. diff --git a/isaaclab_arena/environments/arena_env_builder.py b/isaaclab_arena/environments/arena_env_builder.py index fe1e8e9c5..e5a01e4e5 100644 --- a/isaaclab_arena/environments/arena_env_builder.py +++ b/isaaclab_arena/environments/arena_env_builder.py @@ -166,12 +166,12 @@ def compose_manager_cfg(self) -> IsaacLabArenaManagerBasedRLEnvCfg: embodiment.get_observation_cfg(), task.get_observation_cfg(), ) - events_sources = [ + events_cfg = combine_configclass_instances( + "EventsCfg", embodiment.get_events_cfg(), self.arena_env.scene.get_events_cfg(), task.get_events_cfg(), - ] - events_cfg = combine_configclass_instances("EventsCfg", *events_sources) + ) termination_cfg = combine_configclass_instances( "TerminationCfg", task.get_termination_cfg(), diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index c4bc40b22..2c552cdd7 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -7,6 +7,7 @@ import math import torch +from dataclasses import dataclass from typing import TYPE_CHECKING from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams @@ -17,8 +18,21 @@ from isaaclab_arena.utils.pose import Pose, PosePerEnv if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object - from isaaclab_arena.assets.object_reference import ObjectReference + from isaaclab_arena.assets.object_base import ObjectBase + + +@dataclass +class _Candidate: + """A single placement candidate produced by the solver.""" + + loss: float + """Loss value returned by the solver.""" + + positions: dict[ObjectBase, tuple[float, float, float]] + """Solved positions for each object.""" + + is_valid: bool + """Whether the placement passed validation checks.""" class ObjectPlacer: @@ -45,7 +59,7 @@ def __init__(self, params: ObjectPlacerParams | None = None): def place( self, - objects: list[Object | ObjectReference], + objects: list[ObjectBase], num_envs: int = 1, result_per_env: bool = True, ) -> PlacementResult | MultiEnvPlacementResult: @@ -98,49 +112,38 @@ def place( num_results = num_envs if result_per_env else 1 num_candidates = self.params.max_placement_attempts * num_results - initial_positions: list[dict] = [] - for candidate_idx in range(num_candidates): - rng_state = None - if self.params.placement_seed is not None: - rng_state = torch.get_rng_state() - torch.manual_seed(self.params.placement_seed + candidate_idx) - initial_positions.append(self._generate_initial_positions(objects, anchor_objects_set, init_bounds)) - if rng_state is not None: - torch.set_rng_state(rng_state) + initial_positions = self._generate_initial_positions(objects, anchor_objects_set, init_bounds, num_candidates) - all_positions: list[dict] = self._solver.solve(objects, initial_positions) # type: ignore[assignment] # overload returns list[dict] for list input - all_losses = ( - self._solver.last_loss_per_env.cpu().tolist() - if self._solver.last_loss_per_env is not None - else [float("inf")] * num_candidates - ) + all_positions = self._solver.solve(objects, initial_positions) + assert self._solver.last_loss_per_env is not None + all_losses: list[float] = self._solver.last_loss_per_env.cpu().tolist() - all_candidates: list[tuple[float, dict, bool]] = [] + all_candidates: list[_Candidate] = [] for idx in range(num_candidates): - loss = all_losses[idx] if idx < len(all_losses) else float("inf") + loss = all_losses[idx] is_valid = self._validate_placement(all_positions[idx]) - all_candidates.append((loss, all_positions[idx], is_valid)) + all_candidates.append(_Candidate(loss, all_positions[idx], is_valid)) # Sort: valid solutions first (by loss), then invalid (by loss) - all_candidates.sort(key=lambda c: (not c[2], c[0])) + all_candidates.sort(key=lambda c: (not c.is_valid, c.loss)) selected = all_candidates[:num_results] - n_valid = sum(1 for c in selected if c[2]) + n_valid = sum(1 for c in selected if c.is_valid) if self.params.verbose: - total_valid = sum(1 for c in all_candidates if c[2]) - finite_losses = [c[0] for c in all_candidates if math.isfinite(c[0])] + total_valid = sum(1 for c in all_candidates if c.is_valid) + finite_losses = [c.loss for c in all_candidates if math.isfinite(c.loss)] mean_loss = sum(finite_losses) / len(finite_losses) if finite_losses else float("inf") print( f"Solved {num_candidates} candidates in one batch: mean loss = {mean_loss:.6f}," f" {total_valid} valid, selected best {num_results} ({n_valid} valid)" ) - final_per_env: list[dict] = [c[1] for c in selected] + final_per_env: list[dict] = [c.positions for c in selected] results_per_env = [ PlacementResult( - success=c[2], - positions=c[1], - final_loss=c[0], + success=c.is_valid, + positions=c.positions, + final_loss=c.loss, attempts=self.params.max_placement_attempts, ) for c in selected @@ -153,7 +156,7 @@ def place( return results_per_env[0] return MultiEnvPlacementResult(results=results_per_env) - def _get_init_bounds(self, anchor_object: Object | ObjectReference) -> AxisAlignedBoundingBox: + def _get_init_bounds(self, anchor_object: ObjectBase) -> AxisAlignedBoundingBox: """Get bounds for random position initialization. If init_bounds is provided in params, use it. @@ -174,29 +177,41 @@ def _get_init_bounds(self, anchor_object: Object | ObjectReference) -> AxisAlign def _generate_initial_positions( self, - objects: list[Object | ObjectReference], - anchor_objects: set[Object | ObjectReference], + objects: list[ObjectBase], + anchor_objects: set[ObjectBase], init_bounds: AxisAlignedBoundingBox, - ) -> dict[Object | ObjectReference, tuple[float, float, float]]: - """Generate initial positions for all objects. + num_candidates: int, + ) -> list[dict[ObjectBase, tuple[float, float, float]]]: + """Generate initial positions for ``num_candidates`` placement candidates. + + Each candidate maps every object to a starting position: anchors keep + their current ``initial_pose``; others receive a random position within + ``init_bounds``. When ``placement_seed`` is set, each candidate gets a + deterministic seed (``placement_seed + candidate_idx``). + """ + results: list[dict[ObjectBase, tuple[float, float, float]]] = [] + for candidate_idx in range(num_candidates): + rng_state = None + if self.params.placement_seed is not None: + rng_state = torch.get_rng_state() + torch.manual_seed(self.params.placement_seed + candidate_idx) - Anchors keep their current initial_pose, others get random positions. + positions: dict[ObjectBase, tuple[float, float, float]] = {} + for obj in objects: + if obj in anchor_objects: + positions[obj] = obj.get_initial_pose().position_xyz + else: + random_pose = get_random_pose_within_bounding_box(init_bounds) + positions[obj] = random_pose.position_xyz + results.append(positions) - Returns: - Dictionary mapping all objects to their starting positions. - """ - positions: dict[Object | ObjectReference, tuple[float, float, float]] = {} - for obj in objects: - if obj in anchor_objects: - positions[obj] = obj.get_initial_pose().position_xyz - else: - random_pose = get_random_pose_within_bounding_box(init_bounds) - positions[obj] = random_pose.position_xyz - return positions + if rng_state is not None: + torch.set_rng_state(rng_state) + return results def _validate_on_relations( self, - positions: dict[Object | ObjectReference, tuple[float, float, float]], + positions: dict[ObjectBase, tuple[float, float, float]], ) -> bool: """Validate each On relation; logic matches OnLossStrategy (relation_loss_strategies.py). @@ -238,7 +253,7 @@ def _validate_on_relations( def _validate_no_overlap( self, - positions: dict[Object | ObjectReference, tuple[float, float, float]], + positions: dict[ObjectBase, tuple[float, float, float]], ) -> bool: """Validate that no two objects overlap in 3D (axis-aligned bbox with margin). @@ -257,6 +272,7 @@ def _validate_no_overlap( for i in range(len(objects)): for j in range(i + 1, len(objects)): a, b = objects[i], objects[j] + # Pairs related by an On relation are excluded from the overlap check. if (id(a), id(b)) in on_pairs: continue @@ -271,7 +287,7 @@ def _validate_no_overlap( def _validate_placement( self, - positions: dict[Object | ObjectReference, tuple[float, float, float]], + positions: dict[ObjectBase, tuple[float, float, float]], ) -> bool: """Validate that no two objects overlap in 3D and On relations are satisfied. @@ -285,8 +301,8 @@ def _validate_placement( def _apply_positions( self, - positions_per_env: list[dict[Object | ObjectReference, tuple[float, float, float]]], - anchor_objects: set[Object | ObjectReference], + positions_per_env: list[dict[ObjectBase, tuple[float, float, float]]], + anchor_objects: set[ObjectBase], ) -> None: """Apply solved positions to objects (skipping anchors). @@ -297,7 +313,10 @@ def _apply_positions( Rotation is taken from RotateAroundSolution marker if present, otherwise identity. """ num_envs = len(positions_per_env) - for obj in positions_per_env[0]: + # Objects are the same for every environment. Extract them. + objects = list(positions_per_env[0]) + # Apply pose for each object. + for obj in objects: if obj in anchor_objects: continue @@ -318,7 +337,7 @@ def _apply_positions( ] obj.set_initial_pose(PosePerEnv(poses=poses)) - def _get_random_around_solution(self, obj: Object | ObjectReference) -> RandomAroundSolution | None: + def _get_random_around_solution(self, obj: ObjectBase) -> RandomAroundSolution | None: """Get RandomAroundSolution marker from object if present. Args: @@ -332,7 +351,7 @@ def _get_random_around_solution(self, obj: Object | ObjectReference) -> RandomAr return rel return None - def _get_rotate_around_solution(self, obj: Object | ObjectReference) -> RotateAroundSolution | None: + def _get_rotate_around_solution(self, obj: ObjectBase) -> RotateAroundSolution | None: """Get RotateAroundSolution marker from object if present. Args: diff --git a/isaaclab_arena/relations/placement_result.py b/isaaclab_arena/relations/placement_result.py index ecd47e8fa..22f76a14d 100644 --- a/isaaclab_arena/relations/placement_result.py +++ b/isaaclab_arena/relations/placement_result.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object + from isaaclab_arena.assets.object_base import ObjectBase @dataclass @@ -19,7 +19,7 @@ class PlacementResult: success: bool """Whether placement passed validation checks.""" - positions: dict[Object, tuple[float, float, float]] + positions: dict[ObjectBase, tuple[float, float, float]] """Final positions for each object.""" final_loss: float diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index e8e5e4e3f..9df5155b5 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -14,8 +14,7 @@ from isaaclab_arena.relations.relations import AtPosition, Relation, RelationBase if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object - from isaaclab_arena.assets.object_reference import ObjectReference + from isaaclab_arena.assets.object_base import ObjectBase class RelationSolver: @@ -119,27 +118,20 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - def solve( self, - objects: list[Object | ObjectReference], - initial_positions: ( - dict[Object | ObjectReference, tuple[float, float, float]] - | list[dict[Object | ObjectReference, tuple[float, float, float]]] - ), - ) -> ( - dict[Object | ObjectReference, tuple[float, float, float]] - | list[dict[Object | ObjectReference, tuple[float, float, float]]] - ): + objects: list[ObjectBase], + initial_positions: list[dict[ObjectBase, tuple[float, float, float]]], + ) -> list[dict[ObjectBase, tuple[float, float, float]]]: """Solve for optimal positions of all objects. Args: - objects: List of Object or ObjectReference instances. Must include at least one object + objects: List of ObjectBase instances. Must include at least one object marked with IsAnchor() which serves as a fixed reference. - initial_positions: A single dict (backward compat, single-env) or a list - of dicts (one per env for batched). + initial_positions: List of dicts (one per env). Use a single-element list + for single-env placement. Returns: - Single dict when input is a dict, or list of dicts when input is a list. + List of dicts (one per env) mapping objects to their solved (x, y, z) positions. """ - single_input = isinstance(initial_positions, dict) state = RelationSolverState(objects, initial_positions) if self.params.verbose: @@ -155,8 +147,7 @@ def solve( print("No optimizable objects, skipping solver.") self._last_loss_history = [0.0] self._last_position_history = [state.get_all_positions_snapshot()] - final = state.get_final_positions() - return final[0] if single_input else final + return state.get_final_positions() # Setup optimizer (only for optimizable positions) optimizer = torch.optim.Adam([state.optimizable_positions], lr=self.params.lr) @@ -199,8 +190,7 @@ def solve( self._last_loss_history = loss_history self._last_position_history = position_history - final = state.get_final_positions() - return final[0] if single_input else final + return state.get_final_positions() @property def last_loss_history(self) -> list[float]: @@ -217,7 +207,7 @@ def last_position_history(self) -> list: """Position snapshots from the most recent solve() call.""" return self._last_position_history - def debug_losses(self, objects: list[Object | ObjectReference]) -> None: + def debug_losses(self, objects: list[ObjectBase]) -> None: """Print detailed loss breakdown for all relations using final positions. Call this after solve() to inspect why objects may not be correctly positioned. @@ -237,13 +227,13 @@ def debug_losses(self, objects: list[Object | ObjectReference]) -> None: # Build positions dict from final position history final_positions = {obj: (pos[0], pos[1], pos[2]) for obj, pos in zip(objects, final_positions_list)} - state = RelationSolverState(objects, final_positions) + state = RelationSolverState(objects, [final_positions]) self._compute_total_loss(state, debug=True) print("\n" + "=" * 60) def _print_relation_debug( - obj: Object | ObjectReference, + obj: ObjectBase, relation: Relation, child_pos: torch.Tensor, parent_pos: torch.Tensor, @@ -289,7 +279,7 @@ def _print_relation_debug( def _print_unary_relation_debug( - obj: Object, + obj: ObjectBase, relation: AtPosition, child_pos: torch.Tensor, loss: torch.Tensor, diff --git a/isaaclab_arena/relations/relation_solver_state.py b/isaaclab_arena/relations/relation_solver_state.py index 990bb34b5..d6d5d1aa1 100644 --- a/isaaclab_arena/relations/relation_solver_state.py +++ b/isaaclab_arena/relations/relation_solver_state.py @@ -11,8 +11,7 @@ from isaaclab_arena.relations.relations import get_anchor_objects if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object - from isaaclab_arena.assets.object_reference import ObjectReference + from isaaclab_arena.assets.object_base import ObjectBase class RelationSolverState: @@ -28,32 +27,27 @@ class RelationSolverState: def __init__( self, - objects: list[Object | ObjectReference], - initial_positions: ( - dict[Object | ObjectReference, tuple[float, float, float]] - | list[dict[Object | ObjectReference, tuple[float, float, float]]] - ), + objects: list[ObjectBase], + initial_positions: list[dict[ObjectBase, tuple[float, float, float]]], ): """Initialize optimization state. Args: - objects: List of all Object or ObjectReference instances to track. Must include at least one + objects: List of all ObjectBase instances to track. Must include at least one object marked with IsAnchor() which serves as a fixed reference. - initial_positions: A single dict (backward compat, treated as single-env) or a list - of dicts (one per env). Length 1 = single-env, length > 1 = batched. + initial_positions: List of dicts (one per env). Length 1 = single-env, + length > 1 = batched. """ - if isinstance(initial_positions, dict): - initial_positions = [initial_positions] assert len(initial_positions) >= 1, "initial_positions must contain at least one dict." anchor_objects = get_anchor_objects(objects) assert len(anchor_objects) > 0, "No anchor object found in objects list." self._all_objects = objects - self._anchor_objects: set[Object] = set(anchor_objects) + self._anchor_objects: set[ObjectBase] = set(anchor_objects) self._optimizable_objects = [obj for obj in objects if obj not in self._anchor_objects] # Build object-to-index mapping - self._obj_to_idx: dict[Object | ObjectReference, int] = {obj: i for i, obj in enumerate(objects)} + self._obj_to_idx: dict[ObjectBase, int] = {obj: i for i, obj in enumerate(objects)} # Extract positions from each env's dict self._num_envs = len(initial_positions) @@ -104,16 +98,16 @@ def optimizable_positions(self) -> torch.Tensor | None: return self._optimizable_positions @property - def optimizable_objects(self) -> list[Object]: + def optimizable_objects(self) -> list[ObjectBase]: """List of optimizable objects (excludes anchors).""" return self._optimizable_objects @property - def anchor_objects(self) -> set[Object]: + def anchor_objects(self) -> set[ObjectBase]: """Set of anchor objects (fixed during optimization).""" return self._anchor_objects - def get_position(self, obj: Object | ObjectReference) -> torch.Tensor: + def get_position(self, obj: ObjectBase) -> torch.Tensor: """Get current position for an object. Args: @@ -142,7 +136,7 @@ def get_all_positions_snapshot(self) -> list[tuple[float, float, float]]: """ return [tuple(self.get_position(obj)[0].detach().tolist()) for obj in self._all_objects] - def get_final_positions(self) -> list[dict[Object | ObjectReference, tuple[float, float, float]]]: + def get_final_positions(self) -> list[dict[ObjectBase, tuple[float, float, float]]]: """Get final positions as a list of dicts, one per env. Returns: @@ -150,7 +144,7 @@ def get_final_positions(self) -> list[dict[Object | ObjectReference, tuple[float """ out = [] for e in range(self._num_envs): - d: dict[Object | ObjectReference, tuple[float, float, float]] = {} + d: dict[ObjectBase, tuple[float, float, float]] = {} for obj in self._all_objects: pos = self.get_position(obj)[e].detach().tolist() d[obj] = (pos[0], pos[1], pos[2]) diff --git a/isaaclab_arena/relations/relations.py b/isaaclab_arena/relations/relations.py index 15427489f..d64559011 100644 --- a/isaaclab_arena/relations/relations.py +++ b/isaaclab_arena/relations/relations.py @@ -14,8 +14,7 @@ from isaaclab_arena.utils.pose import PoseRange if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object - from isaaclab_arena.assets.object_reference import ObjectReference + from isaaclab_arena.assets.object_base import ObjectBase class Side(Enum): @@ -41,10 +40,10 @@ class RelationBase: class Relation(RelationBase): """Base class for spatial relationships between objects.""" - def __init__(self, parent: Object | ObjectReference, relation_loss_weight: float = 1.0): + def __init__(self, parent: ObjectBase, relation_loss_weight: float = 1.0): """ Args: - parent: The parent asset in the relationship (Object or ObjectReference). + parent: The parent asset in the relationship. relation_loss_weight: Weight for the relationship loss function. """ self.parent = parent @@ -62,7 +61,7 @@ class NextTo(Relation): def __init__( self, - parent: Object | ObjectReference, + parent: ObjectBase, relation_loss_weight: float = 1.0, distance_m: float = 0.05, side: Side = Side.POSITIVE_X, @@ -102,7 +101,7 @@ class On(Relation): def __init__( self, - parent: Object | ObjectReference, + parent: ObjectBase, relation_loss_weight: float = 1.0, clearance_m: float = 0.01, ): @@ -135,7 +134,7 @@ class NoCollision(Relation): def __init__( self, - parent: Object | ObjectReference, + parent: ObjectBase, relation_loss_weight: float = 1.0, clearance_m: float = 0.01, ): @@ -346,7 +345,7 @@ def __init__( self.relation_loss_weight = relation_loss_weight -def get_anchor_objects(objects: list[Object | ObjectReference]) -> list[Object | ObjectReference]: +def get_anchor_objects(objects: list[ObjectBase]) -> list[ObjectBase]: """Get all anchor objects from a list of objects. Anchor objects are marked with IsAnchor() relation and serve as diff --git a/isaaclab_arena/tests/test_no_collision_loss.py b/isaaclab_arena/tests/test_no_collision_loss.py index 944d9759f..6c3e485fd 100644 --- a/isaaclab_arena/tests/test_no_collision_loss.py +++ b/isaaclab_arena/tests/test_no_collision_loss.py @@ -190,7 +190,7 @@ def test_relation_solver_no_collision_produces_separated_positions(): solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) solver = RelationSolver(params=solver_params) - result = solver.solve(objects=objects, initial_positions=initial_positions) + result = solver.solve(objects=objects, initial_positions=[initial_positions])[0] pos_a = result[box_a] pos_b = result[box_b] @@ -208,12 +208,12 @@ def test_relation_solver_no_collision_same_inputs_reproducible(): solver_params = RelationSolverParams(max_iters=50) solver1 = RelationSolver(params=solver_params) - result1 = solver1.solve(objects=[table1, box_a1, box_b1], initial_positions=initial_positions1) + result1 = solver1.solve(objects=[table1, box_a1, box_b1], initial_positions=[initial_positions1])[0] table2, box_a2, box_b2 = _create_no_collision_scene() initial_positions2 = {table2: initial[0], box_a2: initial[1], box_b2: initial[2]} solver2 = RelationSolver(params=solver_params) - result2 = solver2.solve(objects=[table2, box_a2, box_b2], initial_positions=initial_positions2) + result2 = solver2.solve(objects=[table2, box_a2, box_b2], initial_positions=[initial_positions2])[0] assert result1[box_a1] == result2[box_a2], "box_a positions should match" assert result1[box_b1] == result2[box_b2], "box_b positions should match" diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 9e0d50082..892d96cbf 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -63,14 +63,14 @@ def test_relation_solver_same_inputs_produces_identical_result(): initial_positions1 = {desk1: desk_pos, box1_run1: fixed_box1_pos, box2_run1: fixed_box2_pos} solver1 = RelationSolver(params=solver_params) - result1 = solver1.solve(objects=[desk1, box1_run1, box2_run1], initial_positions=initial_positions1) + result1 = solver1.solve(objects=[desk1, box1_run1, box2_run1], initial_positions=[initial_positions1])[0] # Run 2 (fresh objects, same initial positions) desk2, box1_run2, box2_run2 = _create_test_objects() initial_positions2 = {desk2: desk_pos, box1_run2: fixed_box1_pos, box2_run2: fixed_box2_pos} solver2 = RelationSolver(params=solver_params) - result2 = solver2.solve(objects=[desk2, box1_run2, box2_run2], initial_positions=initial_positions2) + result2 = solver2.solve(objects=[desk2, box1_run2, box2_run2], initial_positions=[initial_positions2])[0] # Compare by name (different object instances) for obj1 in result1: From 68784da6cba7a803d862c01d961724bc3bd3de37 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Wed, 1 Apr 2026 12:40:05 -0700 Subject: [PATCH 12/12] update dataclass name --- isaaclab_arena/relations/object_placer.py | 26 +++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 2c552cdd7..bfba41096 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -22,8 +22,8 @@ @dataclass -class _Candidate: - """A single placement candidate produced by the solver.""" +class PlacementCandidate: + """A single solver result, ranked and selected in ObjectPlacer.place().""" loss: float """Loss value returned by the solver.""" @@ -118,35 +118,35 @@ def place( assert self._solver.last_loss_per_env is not None all_losses: list[float] = self._solver.last_loss_per_env.cpu().tolist() - all_candidates: list[_Candidate] = [] + all_candidates: list[PlacementCandidate] = [] for idx in range(num_candidates): loss = all_losses[idx] is_valid = self._validate_placement(all_positions[idx]) - all_candidates.append(_Candidate(loss, all_positions[idx], is_valid)) + all_candidates.append(PlacementCandidate(loss, all_positions[idx], is_valid)) # Sort: valid solutions first (by loss), then invalid (by loss) - all_candidates.sort(key=lambda c: (not c.is_valid, c.loss)) + all_candidates.sort(key=lambda candidate: (not candidate.is_valid, candidate.loss)) selected = all_candidates[:num_results] - n_valid = sum(1 for c in selected if c.is_valid) + n_valid = sum(1 for candidate in selected if candidate.is_valid) if self.params.verbose: - total_valid = sum(1 for c in all_candidates if c.is_valid) - finite_losses = [c.loss for c in all_candidates if math.isfinite(c.loss)] + total_valid = sum(1 for candidate in all_candidates if candidate.is_valid) + finite_losses = [candidate.loss for candidate in all_candidates if math.isfinite(candidate.loss)] mean_loss = sum(finite_losses) / len(finite_losses) if finite_losses else float("inf") print( f"Solved {num_candidates} candidates in one batch: mean loss = {mean_loss:.6f}," f" {total_valid} valid, selected best {num_results} ({n_valid} valid)" ) - final_per_env: list[dict] = [c.positions for c in selected] + final_per_env: list[dict] = [candidate.positions for candidate in selected] results_per_env = [ PlacementResult( - success=c.is_valid, - positions=c.positions, - final_loss=c.loss, + success=candidate.is_valid, + positions=candidate.positions, + final_loss=candidate.loss, attempts=self.params.max_placement_attempts, ) - for c in selected + for candidate in selected ] if self.params.apply_positions_to_objects: