From a1479500c908c02ad7cabb3aebbf7672d37e1fcb Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Tue, 3 Mar 2026 23:15:06 +0100 Subject: [PATCH 01/22] Initial implementation of GDDPM model --- conda_environment.yaml | 5 +- imitation/config/policy/gddpm_policy.yaml | 34 ++ imitation/model/gddpm.py | 418 ++++++++++++++++++++++ 3 files changed, 454 insertions(+), 3 deletions(-) create mode 100644 imitation/config/policy/gddpm_policy.yaml create mode 100644 imitation/model/gddpm.py diff --git a/conda_environment.yaml b/conda_environment.yaml index 36eb619..0037d35 100644 --- a/conda_environment.yaml +++ b/conda_environment.yaml @@ -14,8 +14,8 @@ dependencies: - pytorchvideo - gymnasium==0.28.1 - gym - - -e git+https://github.com/ARISE-Initiative/robomimic@main#egg=robomimic - - diffusers + - robomimic + - diffusers==0.29.2 - zarr - einops - tqdm @@ -33,6 +33,5 @@ dependencies: - wandb # related work dependencies - -e git+https://github.com/columbia-ai-robotics/diffusion_policy@main#egg=diffusion_policy - - -e git+https://github.com/anindex/stoch_gpmp@main#egg=stoch_gpmp - -e git+https://github.com/anindex/torch_robotics@main#egg=torch_robotics \ No newline at end of file diff --git a/imitation/config/policy/gddpm_policy.yaml b/imitation/config/policy/gddpm_policy.yaml new file mode 100644 index 0000000..35fb8b6 --- /dev/null +++ b/imitation/config/policy/gddpm_policy.yaml @@ -0,0 +1,34 @@ +_target_: imitation.policy.graph_ddpm_policy.GraphConditionalDDPMPolicy + +obs_dim: ${task.obs_dim} +action_dim: ${task.action_dim} + +node_feature_dim: 1 # [joint_val] (same as graph_ddpm_policy) +num_edge_types: 2 # robot joints, object-robot +pred_horizon: ${pred_horizon} +obs_horizon: ${obs_horizon} +action_horizon: ${action_horizon} +num_diffusion_iters: 100 +dataset: ${task.dataset} + +denoising_network: + _target_: imitation.model.gddpm.GDDPMNoisePred + node_feature_dim: ${policy.node_feature_dim} + cond_feature_dim: 6 # 6-D rotation obs features (excl. 
node-id) + obs_horizon: ${obs_horizon} + pred_horizon: ${pred_horizon} + edge_feature_dim: 1 + num_edge_types: ${policy.num_edge_types} + # GDDPM-specific hyper-parameters + residual_layers: 8 + residual_channels: 32 + dilation_cycle_length: 2 + hidden_dim: 256 + diffusion_step_embed_dim: 64 + num_diffusion_steps: ${policy.num_diffusion_iters} + +ckpt_path: ./weights/gddpm_policy_${task.task_name}_${task.dataset_type}_${task.control_mode}_${policy.num_diffusion_iters}iters.pt +lr: 2e-4 +batch_size: 128 +use_normalization: True +keep_first_action: True diff --git a/imitation/model/gddpm.py b/imitation/model/gddpm.py new file mode 100644 index 0000000..24cc896 --- /dev/null +++ b/imitation/model/gddpm.py @@ -0,0 +1,418 @@ +""" +GDDPM: Graph-based Denoising Diffusion Probabilistic Model +Adapted from: https://github.com/AmirMiraki/GDDPM + +Original paper: + AmirMiraki et al. "Probabilistic forecasting of renewable energy and electricity demand + using Graph-based Denoising Diffusion Probabilistic Model", Energy and AI, 2024. + https://doi.org/10.1016/j.egyai.2024.100459 + +This file adapts the GDDPM EpsilonTheta denoising network to the robot-graph diffusion +setting used in this project. The key architectural differences from the project's existing +ConditionalGraphNoisePred are: + + 1. Temporal backbone: dilated 1-D convolutions (vs. EGNN message-passing). + 2. Spatial backbone: GatedGraphConv (torch_geometric) inside each residual block, + applied jointly over the dilated-conv output. + 3. Conditioning: same EGraphConditionEncoder is re-used for graph-level FiLM conditioning; + the resulting vector is upsampled with a CondUpsampler MLP before being injected into + every residual block (instead of FiLM scales/biases per EGNN layer). 
+ +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Linear +from torch_geometric.nn import GatedGraphConv +from torch_geometric.utils import add_self_loops +from torch_geometric.nn.pool import global_mean_pool + +from imitation.model.graph_diffusion import EGraphConditionEncoder + + +# --------------------------------------------------------------------------- +# Diffusion step embedding (Section 3.2 of the GDDPM paper) +# --------------------------------------------------------------------------- + +class DiffusionEmbedding(nn.Module): + """ + Sinusoidal embedding for the diffusion timestep, then projected through + two linear layers. Adapted from epsilon_theta.py in the GDDPM repo. + """ + def __init__(self, dim: int, proj_dim: int, max_steps: int = 500): + super().__init__() + self.register_buffer("embedding", + self._build_embedding(dim, max_steps), + persistent=False) + self.projection1 = nn.Linear(dim * 2, proj_dim) + self.projection2 = nn.Linear(proj_dim, proj_dim) + + def forward(self, diffusion_step: torch.Tensor) -> torch.Tensor: + x = self.embedding[diffusion_step] # (B, dim*2) + x = F.silu(self.projection1(x)) # (B, proj_dim) + x = F.silu(self.projection2(x)) # (B, proj_dim) + return x + + @staticmethod + def _build_embedding(dim: int, max_steps: int) -> torch.Tensor: + steps = torch.arange(max_steps).unsqueeze(1) # [T, 1] + dims = torch.arange(dim).unsqueeze(0) # [1, dim] + table = steps * 10.0 ** (dims * 4.0 / dim) # [T, dim] + table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) + return table # [T, dim*2] + + +# --------------------------------------------------------------------------- +# Conditioning upsampler MLP (lightweight replacement for RNN encoder in GDDPM) +# --------------------------------------------------------------------------- + +class CondUpsampler(nn.Module): + """ + Two-layer MLP that projects the global graph conditioning vector to target_dim, + the number 
class CondUpsampler(nn.Module):
    """
    Two-layer MLP that projects the global graph-conditioning vector to
    ``target_dim`` — the number of nodes (or a spatial dim used by the
    residual blocks). Lightweight replacement for the RNN encoder in GDDPM.
    """
    def __init__(self, cond_length: int, target_dim: int):
        super().__init__()
        half = target_dim // 2
        self.linear1 = nn.Linear(cond_length, half)
        self.linear2 = nn.Linear(half, target_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same 0.4 leaky-ReLU slope as the rest of the GDDPM stack.
        hidden = F.leaky_relu(self.linear1(x), 0.4)
        return F.leaky_relu(self.linear2(hidden), 0.4)
class ResidualBlock(nn.Module):
    """
    Core GDDPM residual block, adapted for per-node action sequences.

    Inputs (per node, batched):
        x              - (N, residual_channels, pred_horizon) current node representation
        conditioner    - (N, 1, target_dim) upsampled global conditioning
        diffusion_step - (N, hidden_size) projected diffusion embedding
        edge_index     - graph connectivity
        edge_weight    - optional edge weights for GatedGraphConv (unused here;
                         accepted for interface symmetry)

    The spatial GatedGraphConv is applied on x averaged across the time axis
    (N, channels), then expanded back over time before mixing back in.
    """
    def __init__(self,
                 hidden_size: int,
                 residual_channels: int,
                 dilation: int):
        super().__init__()
        self.residual_channels = residual_channels

        # temporal: dilated causal conv (C -> 2C for gated activation);
        # padding=dilation with kernel 3 keeps the output length equal to T.
        self.dilated_conv = nn.Conv1d(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            padding_mode="circular",
        )

        # spatial: GatedGraphConv operates on C features (time-averaged),
        # then expanded to 2C via a linear layer before mixing
        self.graph_conv = GatedGraphConv(residual_channels, num_layers=1)
        self.graph_expand = nn.Linear(residual_channels, 2 * residual_channels)

        # diffusion step projection: hidden_size -> 2C (added before gating)
        self.diffusion_projection = nn.Linear(hidden_size, 2 * residual_channels)
        # conditioner: (N,1,T_up) -> (N, 2C, T_up)
        # NOTE(review): kernel_size=1 with padding=2 lengthens the output by 4
        # positions (presumably inherited from the GDDPM/TimeGrad reference
        # implementation); the extra positions are trimmed in forward().
        self.conditioner_projection = nn.Conv1d(
            1, 2 * residual_channels, kernel_size=1, padding=2, padding_mode="circular"
        )
        # output: C -> 2C (residual + skip)
        self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)

        nn.init.kaiming_normal_(self.conditioner_projection.weight)
        nn.init.kaiming_normal_(self.output_projection.weight)

    def forward(self,
                x: torch.Tensor,
                conditioner: torch.Tensor,
                diffusion_step: torch.Tensor,
                edge_index: torch.Tensor,
                edge_weight: torch.Tensor = None) -> tuple:
        """
        x: (N, residual_channels, T)
        conditioner: (N, 1, T_cond) - output of CondUpsampler unsqueezed
        diffusion_step: (N, hidden_size)
        Returns: (residual_out, skip_connection), both (N, residual_channels, T_res)
        where T_res <= T after the length-alignment trims below.
        """
        N, C, T = x.shape

        # --- conditioner projection (N,1,T_up) -> (N, 2C, T'); trim to T
        cond_proj = self.conditioner_projection(conditioner)   # (N, 2C, T')
        min_cond = min(cond_proj.shape[-1], T)
        cond_proj = cond_proj[..., :min_cond]                  # (N, 2C, T_trim)

        # --- diffusion step: (N, 2C) -> broadcast over time
        diff_proj = self.diffusion_projection(diffusion_step)  # (N, 2C)
        diff_proj = diff_proj.unsqueeze(-1)                    # (N, 2C, 1)

        # --- temporal dilated conv (N, C, T) -> (N, 2C, T')
        y_temporal = self.dilated_conv(x)                      # (N, 2C, T')
        T_conv = y_temporal.shape[-1]

        # --- spatial GatedGraphConv on time-averaged features
        x_flat = x.mean(dim=-1)                                # (N, C)
        y_graph = self.graph_conv(x_flat, edge_index)          # (N, C)
        y_graph = self.graph_expand(y_graph)                   # (N, 2C)
        y_spatial = y_graph.unsqueeze(-1).expand(N, 2 * C, T_conv)  # (N, 2C, T')

        # --- combine: align all branches to the shortest common length
        T_min = min(T_conv, min_cond)
        y = (y_temporal[..., :T_min]
             + y_spatial[..., :T_min]
             + cond_proj[..., :T_min]
             + diff_proj.expand(N, 2 * C, T_min))              # (N, 2C, T_min)

        # --- gated activation (WaveNet-style sigmoid * tanh split)
        gate, filt = torch.chunk(y, 2, dim=1)                  # each (N, C, T_min)
        y = torch.sigmoid(gate) * torch.tanh(filt)             # (N, C, T_min)

        # --- output: (N, C, T_min) -> (N, 2C, T_min)
        y = F.leaky_relu(self.output_projection(y), 0.4)
        residual, skip = torch.chunk(y, 2, dim=1)              # each (N, C, T_min)

        # --- residual skip: align back to original T; sqrt(2) keeps variance
        # roughly constant through the stack
        T_res = min(residual.shape[-1], T)
        residual_out = (x[..., :T_res] + residual[..., :T_res]) / math.sqrt(2.0)
        return residual_out, skip[..., :T_res]
class GDDPMNoisePred(nn.Module):
    """
    GDDPM denoising network with the same interface as ConditionalGraphNoisePred.

    Architecture summary:
        - EGraphConditionEncoder encodes the graph-structured observation into a
          per-graph conditioning vector (same as in the project's existing model).
        - CondUpsampler projects it to pred_horizon so it broadcasts per time step.
        - A stack of ResidualBlocks (dilated conv + GatedGraphConv) predicts noise.

    Args:
        node_feature_dim: feature dimension per node per step (e.g. 1 for joint value)
        cond_feature_dim: obs feature dim (excl. node-id), e.g. 6 for 6D rotation
        obs_horizon: number of observation steps for conditioning
        pred_horizon: number of prediction steps (action horizon)
        edge_feature_dim: edge attribute size (usually 1).
            NOTE(review): accepted but never read in this class — presumably
            kept for constructor compatibility with ConditionalGraphNoisePred.
        num_edge_types: number of edge type categories (also unused here)
        residual_layers: number of ResidualBlock layers (must be >= 1, or the
            skip_sum accumulation in forward() would remain None)
        residual_channels: channels inside each block
        dilation_cycle_length: dilation doubles every this many layers
        hidden_dim: hidden size for EGraphConditionEncoder and diffusion embed
        diffusion_step_embed_dim: raw sinusoidal embedding size (<= hidden_dim)
        num_diffusion_steps: total DDPM timesteps (for embedding table)
        device: torch device (auto-detected if None)
    """

    def __init__(self,
                 node_feature_dim: int,
                 cond_feature_dim: int,
                 obs_horizon: int,
                 pred_horizon: int,
                 edge_feature_dim: int,
                 num_edge_types: int,
                 residual_layers: int = 8,
                 residual_channels: int = 8,
                 dilation_cycle_length: int = 2,
                 hidden_dim: int = 256,
                 diffusion_step_embed_dim: int = 64,
                 num_diffusion_steps: int = 100,
                 device=None):
        super().__init__()
        # Auto-select GPU when available unless the caller pins a device.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.node_feature_dim = node_feature_dim
        self.cond_feature_dim = cond_feature_dim
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.hidden_dim = hidden_dim
        self.residual_channels = residual_channels
        self.num_diffusion_steps = num_diffusion_steps

        # --- Observation encoder (same as ConditionalGraphNoisePred) ----------
        # cond_channels: output length from EGraphConditionEncoder.
        # We use it as the `cond_length` fed into CondUpsampler.
        self.cond_channels = hidden_dim
        self.cond_encoder = EGraphConditionEncoder(
            input_dim=cond_feature_dim * obs_horizon,
            output_dim=self.cond_channels,
            hidden_dim=hidden_dim,
            device=self.device,
        ).to(self.device)

        # --- Diffusion step embedding ------------------------------------------
        self.diffusion_embedding = DiffusionEmbedding(
            dim=diffusion_step_embed_dim,
            proj_dim=hidden_dim,
            max_steps=num_diffusion_steps,
        ).to(self.device)

        # --- Conditioning upsampler ---------------------------------------------
        # Output size = pred_horizon so that it can be broadcast per time step.
        self.cond_upsampler = CondUpsampler(
            cond_length=self.cond_channels,
            target_dim=pred_horizon,
        ).to(self.device)

        # --- Input projection: (node_feature_dim) -> residual_channels ----------
        # NOTE(review): kernel_size=1 with padding=2 adds 4 extra time steps
        # (same quirk as ResidualBlock.conditioner_projection); downstream
        # trimming restores pred_horizon.
        self.input_projection = nn.Conv1d(
            node_feature_dim,
            residual_channels,
            kernel_size=1,
            padding=2,
            padding_mode="circular",
        ).to(self.device)

        # --- Residual stack ------------------------------------------------------
        # Dilation cycles as 2 ** (i % dilation_cycle_length): 1, 2, 1, 2, ...
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(
                hidden_size=hidden_dim,
                residual_channels=residual_channels,
                dilation=2 ** (i % dilation_cycle_length),
            )
            for i in range(residual_layers)
        ])
        self.residual_blocks.to(self.device)

        # --- Output projection ---------------------------------------------------
        # NOTE(review): kernel_size=3 with no padding shrinks T by 2 per conv;
        # forward() crops/zero-pads back to pred_horizon afterwards.
        self.skip_projection = nn.Conv1d(
            residual_channels, residual_channels, kernel_size=3
        ).to(self.device)
        self.output_projection = nn.Conv1d(
            residual_channels, node_feature_dim, kernel_size=3
        ).to(self.device)

        nn.init.kaiming_normal_(self.input_projection.weight)
        nn.init.kaiming_normal_(self.skip_projection.weight)
        # Zero-init output weights: the network initially predicts ~zero noise.
        nn.init.zeros_(self.output_projection.weight)

    # ------------------------------------------------------------------
    def forward(self,
                x: torch.Tensor,
                edge_index: torch.Tensor,
                edge_attr: torch.Tensor,
                x_coord: torch.Tensor,
                cond: torch.Tensor,
                timesteps: torch.Tensor,
                batch: torch.Tensor = None):
        """
        Drop-in equivalent of ConditionalGraphNoisePred.forward.

        Args:
            x: (N_total, pred_horizon, node_feature_dim) noisy action
            edge_index: (2, E)
            edge_attr: (E,) or (E, 1) — edge attributes / types
            x_coord: (N_total, 3) — 3D node positions
            cond: (N_total, obs_horizon, cond_feature_dim+1) — obs (+node-id last)
            timesteps: (B,) — diffusion timestep per graph in batch
            batch: (N_total,) — maps each node to its graph index; a single
                   all-zeros graph assignment is assumed when None

        Returns:
            noise_pred: (N_total, pred_horizon, node_feature_dim)
            x_coord: (N_total, 3) unchanged (kept for API compatibility)
        """
        # ---- move to device / cast ----------------------------------------
        x = x.float().to(self.device)                 # (N, T, F)
        edge_attr = edge_attr.float().to(self.device)
        edge_index = edge_index.to(self.device)
        x_coord = x_coord.float().to(self.device)
        timesteps = timesteps.to(self.device)
        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=self.device)
        else:
            batch = batch.long().to(self.device)

        # separate node-id from conditioning features (last channel of cond);
        # the id is taken from the first obs step only (constant over time).
        ids = cond[:, 0, -1].long().to(self.device)
        cond_feats = cond[:, :, :-1].float().to(self.device)  # (N, obs_horizon, C)

        # ---- add self-loops for GatedGraphConv compatibility ---------------
        # fill_value=0.0 gives self-loop edges a neutral attribute.
        edge_attr_1d = edge_attr.reshape(-1)
        edge_index_sl, edge_attr_sl = add_self_loops(
            edge_index, edge_attr_1d,
            num_nodes=x.shape[0], fill_value=0.0
        )

        # ---- Graph-level conditioning vector --------------------------------
        # EGraphConditionEncoder returns (B, cond_channels)
        # NOTE(review): call signature assumed from usage here — confirm against
        # imitation.model.graph_diffusion.EGraphConditionEncoder.
        graph_cond = self.cond_encoder(
            cond_feats, edge_index_sl, x_coord, edge_attr_sl.unsqueeze(-1),
            batch=batch, ids=ids
        )  # (B, cond_channels)

        # ---- Up-sample conditioning to pred_horizon -------------------------
        cond_up = self.cond_upsampler(graph_cond)     # (B, pred_horizon)

        # Broadcast from per-graph to per-node
        cond_up_node = cond_up[batch]                 # (N, pred_horizon)
        cond_up_node = cond_up_node.unsqueeze(1)      # (N, 1, pred_horizon)

        # ---- Diffusion step embedding ----------------------------------------
        diffusion_step = self.diffusion_embedding(timesteps)  # (B, hidden_dim)
        diffusion_step_node = diffusion_step[batch]           # (N, hidden_dim)

        # ---- Reshape x: (N, T, F) -> (N, F, T) for Conv1d ------------------
        x_conv = x.permute(0, 2, 1)                                # (N, F, T)
        h = F.leaky_relu(self.input_projection(x_conv), 0.4)       # (N, C, T')

        # ---- Residual stack --------------------------------------------------
        # Accumulate skip connections; blocks may shrink T, so skips are
        # aligned to the shortest length seen so far.
        # NOTE(review): assumes at least one residual block (residual_layers >= 1),
        # otherwise skip_sum stays None and the division below raises.
        skip_sum = None
        for block in self.residual_blocks:
            h, skip = block(
                h,
                cond_up_node,
                diffusion_step_node,
                edge_index_sl,
                edge_weight=None,
            )
            if skip_sum is None:
                skip_sum = skip
            else:
                # align lengths
                min_T = min(skip_sum.shape[-1], skip.shape[-1])
                skip_sum = (skip_sum[..., :min_T] + skip[..., :min_T])

        # 1/sqrt(L) scaling keeps the summed-skip variance layer-count invariant.
        n_layers = len(self.residual_blocks)
        skip_sum = skip_sum / math.sqrt(n_layers)     # (N, C, T')

        # ---- Output projection -----------------------------------------------
        out = F.leaky_relu(self.skip_projection(skip_sum), 0.4)  # (N, C, T'')
        out = self.output_projection(out)                        # (N, F, T''')

        # ---- Crop / pad to exactly pred_horizon ------------------------------
        T_out = out.shape[-1]
        if T_out >= self.pred_horizon:
            out = out[..., :self.pred_horizon]
        else:
            # pad with zeros if output is too short (edge case)
            pad = torch.zeros(
                out.shape[0], out.shape[1], self.pred_horizon - T_out,
                device=self.device
            )
            out = torch.cat([out, pad], dim=-1)

        # (N, F, T) -> (N, T, F)
        noise_pred = out.permute(0, 2, 1)

        return noise_pred, x_coord
def _make_dataset_node_pos_fn():
    """
    Return a callable that mirrors RobomimicGraphDataset._get_node_pos.

    Binds the method to a lightweight mock that carries only the attributes
    needed: BASE_LINK_SHIFT, BASE_LINK_ROTATION, num_robots,
    rotation_transformer, object_state_keys, object_state_sizes, num_objects.
    This avoids constructing the full dataset object (heavy dependencies).
    """
    import types
    # Fix: `import importlib` alone does not guarantee that the
    # `importlib.util` submodule is bound; import it explicitly so
    # spec_from_file_location below cannot raise AttributeError.
    import importlib.util
    from diffusion_policy.model.common.rotation_transformer import RotationTransformer

    # Import the unbound method directly from the module file.
    # NOTE(review): relies on `os` being imported at module level — confirm.
    spec = importlib.util.spec_from_file_location(
        "rg_dataset",
        os.path.join(os.path.dirname(__file__), "..", "imitation", "dataset",
                     "robomimic_graph_dataset.py")
    )
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    # Lift-task config (mirrors lift_graph.yaml)
    mock = types.SimpleNamespace(
        num_robots=1,
        BASE_LINK_SHIFT=[[-0.56, 0.0, 0.912]],
        BASE_LINK_ROTATION=[[0.0, 0.0, 0.0, 1.0]],
        rotation_transformer=RotationTransformer(from_rep="quaternion", to_rep="rotation_6d"),
        object_state_keys={"cube": ["cube_pos", "cube_quat"]},
        object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3},
        num_objects=1,
    )

    # Bind the method to our mock
    get_node_pos = mod.RobomimicGraphDataset._get_node_pos.__get__(mock)
    get_obj_pos = mod.RobomimicGraphDataset._get_object_pos.__get__(mock)
    # Also bind _get_object_pos so the call chain inside _get_node_pos works
    mock._get_object_pos = get_obj_pos

    return get_node_pos
+ """ + import types + import importlib + from diffusion_policy.model.common.rotation_transformer import RotationTransformer + + spec = importlib.util.spec_from_file_location( + "rg_wrapper", + os.path.join(os.path.dirname(__file__), "..", "imitation", "env", + "robomimic_graph_wrapper.py") + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + mock = types.SimpleNamespace( + num_robots=1, + BASE_LINK_SHIFT=[[-0.56, 0.0, 0.912]], + BASE_LINK_ROTATION=[[0.0, 0.0, 0.0, 1.0]], + rotation_transformer=RotationTransformer(from_rep="quaternion", to_rep="rotation_6d"), + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + num_objects=1, + ) + + get_obj_pos = mod.RobomimicGraphWrapper._get_object_pos.__get__(mock) + mock._get_object_pos = get_obj_pos + get_node_pos = mod.RobomimicGraphWrapper._get_node_pos.__get__(mock) + + return get_node_pos + + +def _build_data_dict_for_t(joint_pos_T, gripper_qpos_T, t, + object_state=None): + """ + Build the 'data' dict expected by both _get_node_pos implementations. + object_state: (10,) floats for cube_pos(3) + cube_quat(4) + pad(3), + or None (zeros are used). + """ + data = { + "robot0_joint_pos": joint_pos_T, # (T, 7) – dataset form + "robot0_gripper_qpos": gripper_qpos_T, # (T, 2) + "object": np.zeros((gripper_qpos_T.shape[0], 10), dtype=np.float32) + if object_state is None else object_state, + } + return data + + +# ── test class ──────────────────────────────────────────────────────────────── + +class TestGraphDatasetAndEnvConsistency: + """ + Validate that RobomimicGraphDataset._get_node_pos and + RobomimicGraphWrapper._get_node_pos produce identical robot-link xyz + positions for the same joint inputs, and that both agree with the + standalone compute_node_pos_xyz helper used throughout the existing tests. 
class TestGraphDatasetAndEnvConsistency:
    """
    Validate that RobomimicGraphDataset._get_node_pos and
    RobomimicGraphWrapper._get_node_pos produce identical robot-link xyz
    positions for the same joint inputs, and that both agree with the
    standalone compute_node_pos_xyz helper used throughout the existing tests.

    Only the xyz columns (first 3) of node_pos are compared; the rotation
    representation columns are not the focus of this test class.
    """

    @pytest.fixture(scope="class")
    def fk_fns(self):
        """Cache the two bound _get_node_pos callables (heavy to build)."""
        return {
            "dataset": _make_dataset_node_pos_fn(),
            "wrapper": _make_wrapper_node_pos_fn(),
        }

    def test_dataset_node_pos_xyz_matches_helper(self, episode_data, fk_fns):
        """
        RobomimicGraphDataset._get_node_pos[:, :3] must equal
        compute_node_pos_xyz for every timestep in the episode.

        This confirms that the standalone test helper faithfully replicates
        the dataset's FK pipeline, so failures in replay tests can be
        attributed to the environment rather than the helper function.
        """
        joint_pos, gripper_qpos, _, _, _ = episode_data
        T = len(joint_pos)
        dataset_fn = fk_fns["dataset"]

        # Track the worst-case error over the whole episode.
        max_err = 0.0
        worst_t = -1
        for t in range(T):
            data = _build_data_dict_for_t(joint_pos, gripper_qpos, t)
            pos_dataset = dataset_fn(data, t)[:9, :3]  # robot nodes only, xyz
            pos_helper = compute_node_pos_xyz(joint_pos[t], gripper_qpos[t])
            err = float(torch.max(torch.abs(pos_dataset - pos_helper)).item())
            if err > max_err:
                max_err = err
                worst_t = t

        print(f"\n── Dataset vs helper node-pos ({EPISODE_KEY}) ────────────")
        print(f"   Steps checked : {T}")
        print(f"   Max xyz err   : {max_err*1e3:.4f} mm at step {worst_t}")

        assert max_err < 1e-5, (
            f"Dataset _get_node_pos xyz differs from compute_node_pos_xyz by "
            f"{max_err*1e6:.2f} μm at step {worst_t}. "
            f"They should be numerically identical."
        )

    def test_wrapper_node_pos_xyz_matches_helper(self, episode_data, fk_fns):
        """
        RobomimicGraphWrapper._get_node_pos[:, :3] must equal
        compute_node_pos_xyz for every timestep in the episode.

        Confirms the wrapper's FK logic is identical to the dataset's.
        """
        joint_pos, gripper_qpos, _, _, _ = episode_data
        T = len(joint_pos)
        wrapper_fn = fk_fns["wrapper"]

        max_err = 0.0
        worst_t = -1
        for t in range(T):
            # Wrapper's _get_node_pos takes a flat obs-dict (not T-indexed)
            data_live = {
                "robot0_joint_pos": joint_pos[t],
                "robot0_gripper_qpos": gripper_qpos[t],
                "object": np.zeros(10, dtype=np.float32),
            }
            pos_wrapper = wrapper_fn(data_live)[:9, :3]
            pos_helper = compute_node_pos_xyz(joint_pos[t], gripper_qpos[t])
            err = float(torch.max(torch.abs(pos_wrapper - pos_helper)).item())
            if err > max_err:
                max_err = err
                worst_t = t

        print(f"\n── Wrapper vs helper node-pos ({EPISODE_KEY}) ────────────")
        print(f"   Steps checked : {T}")
        print(f"   Max xyz err   : {max_err*1e3:.4f} mm at step {worst_t}")

        assert max_err < 1e-5, (
            f"Wrapper _get_node_pos xyz differs from compute_node_pos_xyz by "
            f"{max_err*1e6:.2f} μm at step {worst_t}. "
            f"They should be numerically identical."
        )

    def test_dataset_and_wrapper_node_pos_agree(self, episode_data, fk_fns):
        """
        Cross-check: dataset vs wrapper must agree to floating-point precision.

        Both implement the same FK + base_link transformation; any divergence
        indicates a drift or bug in one of the two implementations.
        """
        joint_pos, gripper_qpos, _, _, _ = episode_data
        T = len(joint_pos)
        dataset_fn = fk_fns["dataset"]
        wrapper_fn = fk_fns["wrapper"]

        max_err = 0.0
        worst_t = -1
        for t in range(T):
            data_ds = _build_data_dict_for_t(joint_pos, gripper_qpos, t)
            data_live = {
                "robot0_joint_pos": joint_pos[t],
                "robot0_gripper_qpos": gripper_qpos[t],
                "object": np.zeros(10, dtype=np.float32),
            }
            pos_ds = dataset_fn(data_ds, t)[:9, :3]
            pos_wrapper = wrapper_fn(data_live)[:9, :3]
            err = float(torch.max(torch.abs(pos_ds - pos_wrapper)).item())
            if err > max_err:
                max_err = err
                worst_t = t

        print(f"\n── Dataset vs wrapper node-pos ({EPISODE_KEY}) ───────────")
        print(f"   Steps checked : {T}")
        print(f"   Max xyz err   : {max_err*1e3:.4f} mm at step {worst_t}")

        assert max_err < 1e-5, (
            f"Dataset and wrapper _get_node_pos xyz differ by "
            f"{max_err*1e6:.2f} μm at step {worst_t}. "
            f"The two FK implementations are inconsistent."
        )

    def test_wrapper_graph_obs_pos_matches_dataset_node_pos(
        self, episode_data, fk_fns
    ):
        """
        End-to-end check: the graph observation produced by a live
        RobomimicGraphWrapper should have pos[:9, :3] that matches
        RobomimicGraphDataset._get_node_pos for the same joint state.

        This exercises the full graph construction path in the wrapper
        (including _robosuite_obs_to_robomimic_graph) and checks that
        the pos field stored on the Data object is what we expect.

        To avoid spinning up a full robosuite environment, we verify the
        wrapper's _get_node_pos directly (already covered above) and
        additionally check that a hand-crafted graph Data object using
        wrapper_fn matches dataset_fn – completing the triangle:

            dataset._get_node_pos ↔ wrapper._get_node_pos ↔ graph.pos

        All three should agree to within 1e-5 m.
        """
        joint_pos, gripper_qpos, _, _, _ = episode_data
        T = len(joint_pos)
        dataset_fn = fk_fns["dataset"]
        wrapper_fn = fk_fns["wrapper"]

        # NOTE(review): imported but not used below — presumably kept so the
        # test fails fast if torch_geometric is missing; confirm intent.
        import torch_geometric.data

        max_err = 0.0
        worst_t = -1
        for t in range(T):
            data_ds = _build_data_dict_for_t(joint_pos, gripper_qpos, t)
            data_live = {
                "robot0_joint_pos": joint_pos[t],
                "robot0_gripper_qpos": gripper_qpos[t],
                "object": np.zeros(10, dtype=np.float32),
            }
            pos_wrapper = wrapper_fn(data_live)
            # Simulate what _robosuite_obs_to_robomimic_graph does: it stores
            # node_pos = self._get_node_pos(data) and passes it as graph.pos
            graph_pos = pos_wrapper[:9, :3]  # robot nodes xyz
            pos_dataset = dataset_fn(data_ds, t)[:9, :3]

            err = float(torch.max(torch.abs(graph_pos - pos_dataset)).item())
            if err > max_err:
                max_err = err
                worst_t = t

        print(f"\n── Graph obs pos vs dataset node-pos ({EPISODE_KEY}) ─────")
        print(f"   Steps checked : {T}")
        print(f"   Max xyz err   : {max_err*1e3:.4f} mm at step {worst_t}")

        assert max_err < 1e-5, (
            f"Graph observation pos[:9,:3] differs from dataset node_pos by "
            f"{max_err*1e6:.2f} μm at step {worst_t}.\n"
            f"This may indicate: (1) wrapper._get_node_pos diverges from "
            f"dataset._get_node_pos, or (2) _robosuite_obs_to_robomimic_graph "
            f"stores a different pos than _get_node_pos returns."
        )
+ +These tests expose disconnects between the training data format and the +live observation format that would cause the policy to perform well on the +offline validation set but poorly in the real environment. + +Tests +----- +1. test_obs_deque_y_matches_dataset_y + The y tensor passed to the policy at inference is assembled from an + obs_deque of RobomimicGraphWrapper observations. The y tensor used + during training comes from RobomimicGraphDataset.get_y_feats(). + They must agree for the same joint state. + +2. test_dataset_playback_obs_format + Feed a sequence of dataset samples through the policy's obs_deque + assembly logic (mirrors get_action()'s first few lines) and verify that + the resulting nobs tensor has the correct shape and is normalised to [-1,1]. + +3. test_action_step_matches_dataset_transition + Step the live robosuite environment with the *dataset's recorded actions* + (not policy-predicted actions) and verify the resulting joint positions + agree with the next dataset observation within tolerance. + This confirms that the action representation used in the dataset is + compatible with the wrapper's step() interface. + +4. test_dataset_y_and_wrapper_y_feature_order_match + The obs feature vector (y) must have the same column ordering between + dataset and wrapper, because the normalizer is fit on the dataset's y. + Columns: [joint_pos(7), gripper_qpos(2), node_id(1)] for robot nodes. 
import importlib.util
import json
import os
import types
import collections

import h5py
import numpy as np
import pytest
import torch
from scipy.spatial.transform import Rotation as R

# ── paths ─────────────────────────────────────────────────────────────────────
DATASET_PATH = "data/lift/ph/low_dim_v141.hdf5"
EPISODE_KEY = "demo_0"

# ── lift-task config (lift_graph.yaml) ────────────────────────────────────────
BASE_LINK_SHIFT = [[-0.56, 0.0, 0.912]]
BASE_LINK_ROTATION = [[0.0, 0.0, 0.0, 1.0]]

# ── tolerances ────────────────────────────────────────────────────────────────
Y_MATCH_TOL = 1e-5     # m — y tensors from dataset vs wrapper must agree
NORM_RANGE_TOL = 1.05  # normalised values must lie in [-NORM_RANGE_TOL, NORM_RANGE_TOL]
JOINT_STEP_TOL = 0.08  # rad — max joint error after one env.step from dataset action


# ── module loaders ────────────────────────────────────────────────────────────

def _load_module(name, rel_path):
    """
    Import a project module by file path, relative to the repo root.

    Bypasses the normal import system so tests can load modules that live in
    packages with heavy import-time dependencies.

    Args:
        name: module name to register the loaded module under
        rel_path: path relative to the repository root (one level above tests/)

    Returns:
        The executed module object.
    """
    spec = importlib.util.spec_from_file_location(
        name,
        os.path.join(os.path.dirname(__file__), "..", rel_path)
    )
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


# ── fixtures ──────────────────────────────────────────────────────────────────

@pytest.fixture(scope="module")
def episode_data():
    """Load demo_0 from the HDF5 file."""
    # Arrays are materialised with [:] so the file can be closed immediately.
    with h5py.File(DATASET_PATH, "r") as f:
        ep = f[f"data/{EPISODE_KEY}"]
        joint_pos = ep["obs/robot0_joint_pos"][:]
        gripper_qpos = ep["obs/robot0_gripper_qpos"][:]
        gripper_qvel = ep["obs/robot0_gripper_qvel"][:]
        joint_vel = ep["obs/robot0_joint_vel"][:]
        object_obs = ep["obs/object"][:]
        actions = ep["actions"][:]
        states = ep["states"][:]
    return dict(
        joint_pos=joint_pos,
        gripper_qpos=gripper_qpos,
        gripper_qvel=gripper_qvel,
        joint_vel=joint_vel,
        object_obs=object_obs,
        actions=actions,
        states=states,
    )
RobomimicGraphDataset (uses the processed cache).""" + mod = _load_module("rg_dataset", "imitation/dataset/robomimic_graph_dataset.py") + ds = mod.RobomimicGraphDataset( + dataset_path=DATASET_PATH, + robots=["Panda"], + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + pred_horizon=16, + obs_horizon=4, + control_mode="JOINT_VELOCITY", + base_link_shift=BASE_LINK_SHIFT, + base_link_rotation=BASE_LINK_ROTATION, + ) + return ds + + +@pytest.fixture(scope="module") +def wrapper_get_y_fn(): + """Return a bound _get_y_feats callable from RobomimicGraphWrapper.""" + from diffusion_policy.model.common.rotation_transformer import RotationTransformer + + mod = _load_module("rg_wrapper", "imitation/env/robomimic_graph_wrapper.py") + + mock = types.SimpleNamespace( + num_robots=1, + BASE_LINK_SHIFT=BASE_LINK_SHIFT, + BASE_LINK_ROTATION=BASE_LINK_ROTATION, + rotation_transformer=RotationTransformer(from_rep="quaternion", to_rep="rotation_6d"), + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + num_objects=1, + ) + get_obj_pos = mod.RobomimicGraphWrapper._get_object_pos.__get__(mock) + mock._get_object_pos = get_obj_pos + get_y_feats = mod.RobomimicGraphWrapper._get_y_feats.__get__(mock) + return get_y_feats + + +@pytest.fixture(scope="module") +def dataset_get_y_fn(): + """Return a bound get_y_feats callable from RobomimicGraphDataset.""" + from diffusion_policy.model.common.rotation_transformer import RotationTransformer + + mod = _load_module("rg_dataset2", "imitation/dataset/robomimic_graph_dataset.py") + + mock = types.SimpleNamespace( + num_robots=1, + BASE_LINK_SHIFT=BASE_LINK_SHIFT, + BASE_LINK_ROTATION=BASE_LINK_ROTATION, + rotation_transformer=RotationTransformer(from_rep="quaternion", to_rep="rotation_6d"), + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + 
object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + num_objects=1, + obs_feature_dim=7, + ) + get_obj_pos = mod.RobomimicGraphDataset._get_object_pos.__get__(mock) + mock._get_object_pos = get_obj_pos + get_y_feats = mod.RobomimicGraphDataset.get_y_feats.__get__(mock) + return get_y_feats + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _build_wrapper_obs_dict(episode_data, t): + """Build the flat obs dict expected by wrapper._get_y_feats at timestep t.""" + return { + "robot0_joint_pos": episode_data["joint_pos"][t], + "robot0_joint_vel": episode_data["joint_vel"][t], + "robot0_gripper_qpos": episode_data["gripper_qpos"][t], + "robot0_gripper_qvel": episode_data["gripper_qvel"][t], + "object": episode_data["object_obs"][t], + } + + +def _build_dataset_data_dict(episode_data): + """Build the time-indexed data dict expected by dataset.get_y_feats.""" + return { + "robot0_joint_pos": episode_data["joint_pos"], + "robot0_gripper_qpos": episode_data["gripper_qpos"], + "robot0_joint_vel": episode_data["joint_vel"], + "robot0_gripper_qvel": episode_data["gripper_qvel"], + "object": episode_data["object_obs"], + } + + +# ── Test 1: obs y tensors agree ─────────────────────────────────────────────── + +class TestObsYConsistency: + """ + The y tensor (observations) fed to the GDDPM must be identical whether + it comes from the dataset (training path) or from the wrapper (eval path). + + A mismatch here means the network sees a completely different conditioning + signal at eval time than it was trained on — a guaranteed performance cliff. + """ + + def test_wrapper_y_matches_dataset_y_at_each_step( + self, episode_data, wrapper_get_y_fn, dataset_get_y_fn + ): + """ + For every timestep t, compare: + - wrapper._get_y_feats(obs_dict_at_t) → shape (num_nodes, feat) + - dataset.get_y_feats(data_dict, t_vals=[t]) → shape (num_nodes, 1, feat) + + Both should agree on the robot-node rows (indices 0..8). 
+ The last column (node ID) is part of y in both paths; it is the + running index 0..num_nodes-1 and should be identical. + """ + data_dict = _build_dataset_data_dict(episode_data) + T = len(episode_data["joint_pos"]) + + max_err = 0.0 + worst_t = -1 + for t in range(T): + obs_dict = _build_wrapper_obs_dict(episode_data, t) + y_wrapper = wrapper_get_y_fn(obs_dict) # (num_nodes, feat) + y_dataset = dataset_get_y_fn(data_dict, [t]) # (num_nodes, 1, feat) + y_ds_t = y_dataset[:, 0, :] # (num_nodes, feat) + + # Compare robot nodes only (first 9) + robot_rows_w = y_wrapper[:9, :] + robot_rows_d = y_ds_t[:9, :] + + err = float(torch.max(torch.abs(robot_rows_w - robot_rows_d)).item()) + if err > max_err: + max_err = err + worst_t = t + + print(f"\n── Wrapper y vs dataset y (robot nodes) ─────────────────") + print(f" Steps checked : {T}") + print(f" Max element error : {max_err:.6f} at step {worst_t}") + + assert max_err <= Y_MATCH_TOL, ( + f"Wrapper._get_y_feats and dataset.get_y_feats disagree by " + f"{max_err:.2e} at step {worst_t} (tolerance {Y_MATCH_TOL:.0e}).\n" + f"The network sees different obs conditioning at train vs eval time.\n" + f"Check that both use the same feature ordering: " + f"[joint_pos(7), gripper_qpos(2), node_id(1)] for robot nodes." + ) + + def test_obs_y_feature_shape_is_consistent( + self, episode_data, wrapper_get_y_fn, dataset_get_y_fn + ): + """ + y from the wrapper (single step) and from the dataset (single step) + must have the same number of columns (feature dimensionality). 
+ """ + data_dict = _build_dataset_data_dict(episode_data) + obs_dict = _build_wrapper_obs_dict(episode_data, 0) + + y_wrapper = wrapper_get_y_fn(obs_dict) + y_dataset = dataset_get_y_fn(data_dict, [0]) + + print(f"\n── y feature shape ──────────────────────────────────────") + print(f" wrapper y.shape : {tuple(y_wrapper.shape)}") + print(f" dataset y.shape : {tuple(y_dataset[:, 0, :].shape)}") + + assert y_wrapper.shape == y_dataset[:, 0, :].shape, ( + f"y shape mismatch: wrapper {tuple(y_wrapper.shape)} vs " + f"dataset {tuple(y_dataset[:, 0, :].shape)}.\n" + f"The policy obs conditioning tensor has the wrong number of features." + ) + + +# ── Test 2: nobs format when assembled via obs_deque ───────────────────────── + +class TestObsDequeAssembly: + """ + In get_action(), the policy assembles nobs as: + + for i in range(len(obs_deque)): + obs_cond.append(obs_deque[i].y.unsqueeze(1)) + obs = torch.cat(obs_cond, dim=1) # (nodes, obs_horizon, feat) + nobs = dataset.normalize_data(obs, 'obs') + + This test simulates that assembly using dataset samples and verifies: + (a) shape is (num_nodes, obs_horizon, obs_feat_dim) + (b) normalised values are in [-1, 1] + """ + + OBS_HORIZON = 4 + + def _assemble_nobs(self, dataset, start_idx): + """Assemble nobs the same way get_action() does, using dataset samples.""" + obs_cond = [] + for i in range(self.OBS_HORIZON): + idx = max(0, start_idx - (self.OBS_HORIZON - 1 - i)) + data = dataset.get(idx) + # data.y shape: (nodes, obs_horizon, feat) — take last step + obs_cond.append(data.y[:, -1:, :]) # (nodes, 1, feat) + obs = torch.cat(obs_cond, dim=1) # (nodes, obs_horizon, feat) + return obs + + def test_nobs_shape(self, dataset): + """nobs assembled from obs_deque has the expected shape.""" + num_nodes = 9 + 1 # 9 robot + 1 object for lift task + obs = self._assemble_nobs(dataset, start_idx=10) + + print(f"\n── nobs shape check ─────────────────────────────────────") + print(f" Assembled nobs shape : {tuple(obs.shape)}") + 
print(f" Expected : ({num_nodes}, {self.OBS_HORIZON}, *)") + + assert obs.shape[0] == num_nodes, ( + f"nobs has {obs.shape[0]} nodes, expected {num_nodes}." + ) + assert obs.shape[1] == self.OBS_HORIZON, ( + f"nobs has {obs.shape[1]} obs steps, expected {self.OBS_HORIZON}." + ) + + def test_nobs_normalised_range(self, dataset): + """Normalized nobs (excluding node-ID column) is in [-NORM_RANGE_TOL, NORM_RANGE_TOL].""" + CHECK_N = 20 + step = max(1, dataset.len() // CHECK_N) + + all_norm = [] + for start_idx in range(0, dataset.len(), step): + obs = self._assemble_nobs(dataset, start_idx=start_idx) + nobs = dataset.normalize_data(obs, stats_key="obs") + # Exclude node-ID column (last feature) + nobs_no_id = nobs[:, :, :-1] + all_norm.append(nobs_no_id.reshape(-1).detach().numpy()) + + import numpy as np + all_norm = np.concatenate(all_norm) + out_of_range = np.abs(all_norm) > NORM_RANGE_TOL + frac_oob = out_of_range.mean() + + print(f"\n── nobs normalisation range ─────────────────────────────") + print(f" Samples : {CHECK_N}") + print(f" Min norm : {all_norm.min():.4f}") + print(f" Max norm : {all_norm.max():.4f}") + print(f" Frac OOB : {frac_oob*100:.3f} %") + + assert frac_oob == 0.0, ( + f"{frac_oob*100:.2f}% of nobs values outside ±{NORM_RANGE_TOL} " + f"(min={all_norm.min():.4f}, max={all_norm.max():.4f}).\n" + f"Obs normalizer saturates the conditioning signal before it reaches " + f"the GDDPM, causing it to discard information." + ) + + +# ── Test 3: dataset action → env step → next obs ───────────────────────────── + +class TestActionStepMatchesDatasetTransition: + """ + Replay dataset-recorded actions through the wrapper's step() and verify + that the resulting joint positions match dataset obs[t+1] within tolerance. + + This is the decisive end-to-end check: it validates that the action format + the policy outputs (graph node velocities) is correctly interpreted by the + wrapper. A failure here means even a *perfect* policy would fail at eval. 
+ + To map dataset actions (shape (T, 7) OSC_POSE / JOINT_VELOCITY) to the + 9-element graph action format: + graph_action[0:7] = dataset joint velocities (7 DOF) + graph_action[7] = dataset gripper velocity (finger 0) ← often 0 or ±1 + graph_action[8] = dataset gripper velocity (finger 1) + + The wrapper then uses action[:7] + action[8], discarding action[7]. + """ + + N_STEPS = 20 # replay first N steps to keep test fast + + def test_dataset_action_produces_correct_next_obs(self, episode_data): + """ + Restore the simulator to the recorded t=0 state, then apply the + first N_STEPS dataset actions through a live robosuite env with the + SAME control_freq as the dataset. Compare resulting joint_pos with + dataset obs[t+1]. + + Uses make_env() (reads control_freq from HDF5) so this test is + self-consistent regardless of the wrapper's hard-coded control_freq. + """ + import robosuite as suite + + # Build the env from the recorded env_args (bypasses wrapper) + with h5py.File(DATASET_PATH, "r") as f: + env_args = json.loads(f["data"].attrs["env_args"]) + env_kwargs = dict(env_args["env_kwargs"]) + env_kwargs["has_renderer"] = False + env_kwargs["has_offscreen_renderer"] = False + env_kwargs["reward_shaping"] = False + env = suite.make(env_args["env_name"], **env_kwargs) + + env.reset() + env.sim.set_state_from_flattened(episode_data["states"][0]) + env.sim.forward() + + actions = episode_data["actions"] # (T, 7) — raw OSC / JV actions + joint_pos = episode_data["joint_pos"] # (T, 7) — ground truth obs + + max_err = 0.0 + worst_t = -1 + per_step = [] + for t in range(min(self.N_STEPS, len(actions) - 1)): + live_obs, _, _, _ = env.step(actions[t]) + # Raw robosuite env (no GymWrapper) provides sin/cos instead of + # joint_pos directly. Reconstruct via arctan2. 
+            sin_q = live_obs["robot0_joint_pos_sin"]   # (7,)
+            cos_q = live_obs["robot0_joint_pos_cos"]   # (7,)
+            live_jpos = np.arctan2(sin_q, cos_q)
+            ds_jpos = joint_pos[t + 1]
+            err = float(np.max(np.abs(live_jpos - ds_jpos)))
+            per_step.append(err)
+            if err > max_err:
+                max_err = err
+                worst_t = t
+
+        env.close()
+
+        print(f"\n── Dataset action → env step → joint_pos match ──────────")
+        print(f"   Steps replayed   : {len(per_step)}")
+        print(f"   Max joint error  : {max_err:.5f} rad at step {worst_t}")
+        print(f"   Mean joint error : {np.mean(per_step):.5f} rad")
+
+        assert max_err <= JOINT_STEP_TOL, (
+            f"Applying dataset action at step {worst_t} yielded joint_pos error "
+            f"{max_err:.5f} rad (tolerance {JOINT_STEP_TOL} rad).\n"
+            f"This means the dataset action format is NOT compatible with the "
+            f"environment's step() interface — the policy will fail at eval even if "
+            f"it perfectly reproduces the training actions.\n"
+            f"Likely cause: action convention mismatch (OSC_POSE vs JOINT_VELOCITY) "
+            f"or control_freq mismatch."
+        )
+
+    def test_wrapper_step_with_dataset_action_matches_next_obs(self, episode_data):
+        """
+        Same as above, but stepping through RobomimicGraphWrapper.step() to
+        test the wrapper's action interpretation end-to-end.
+
+        The wrapper expects a 9-element action vector (one per graph node).
+        We extend the 7-DOF dataset action with the recorded gripper_qvel
+        values at positions 7 and 8 (gripper fingers), matching that format.
+
+        NOTE: this test deliberately targets the first N_STEPS of demo_0
+        to keep runtime short. A failure indicates the wrapper's step()
+        action slicing is wrong. 
+ """ + from imitation.env.robomimic_graph_wrapper import RobomimicGraphWrapper + + wrapper = RobomimicGraphWrapper( + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + task="Lift", + has_renderer=False, + robots=["Panda"], + control_mode="JOINT_VELOCITY", + base_link_shift=BASE_LINK_SHIFT, + base_link_rotation=BASE_LINK_ROTATION, + ) + + # Restore to recorded t=0 via the inner robosuite env + wrapper.env.env.reset() + wrapper.env.env.sim.set_state_from_flattened(episode_data["states"][0]) + wrapper.env.env.sim.forward() + + actions_raw = episode_data["actions"] # (T, 7) raw velocities + joint_pos = episode_data["joint_pos"] # (T, 7) ground truth + + max_err = 0.0 + worst_t = -1 + per_step = [] + + for t in range(min(self.N_STEPS, len(actions_raw) - 1)): + # Pad to 9-element graph action: [j0..j6, gripper_f0, gripper_f1] + # Dataset stores 7-DOF velocities; gripper comes from gripper_qvel + gripper_vel = episode_data["gripper_qvel"][t] + + # Build and pass a 9-element action to wrapper.step() + graph_action = np.concatenate([ + actions_raw[t], # 7 joint velocities / OSC DOF + gripper_vel[:2], # 2 gripper DOF + ]) # total: 9 elements + + graph_obs, _, done, _ = wrapper.step(graph_action) + + # Extract joint_pos from the graph observation's y field + # y shape: (num_nodes, feat) where feat = [jp0..jp6, gp0, gp1, node_id] + # Robot nodes 0..8; joint pos is stored in y[:9, 0..6] + live_jpos = graph_obs.y[:7, 0].detach().numpy() # nodes 0-6 → 7 joints + ds_jpos = joint_pos[t + 1] + + err = float(np.max(np.abs(live_jpos - ds_jpos))) + per_step.append(err) + if err > max_err: + max_err = err + worst_t = t + + if done: + break + + wrapper.close() + + print(f"\n── Wrapper step joint_pos match ─────────────────────────") + print(f" Steps replayed : {len(per_step)}") + print(f" Max joint error : {max_err:.5f} rad at step {worst_t}") + print(f" Mean joint error : {np.mean(per_step):.5f} 
rad") + + assert max_err <= JOINT_STEP_TOL, ( + f"wrapper.step() joint_pos error {max_err:.5f} rad at step {worst_t} " + f"exceeds {JOINT_STEP_TOL} rad.\n" + f"This verifies that even with perfect dataset actions, the wrapper " + f"does not correctly advance the simulator state.\n" + f"Likely causes: (1) control_freq mismatch between dataset and wrapper, " + f"(2) wrong action slicing (action[j+8] instead of action[j+7]), " + f"(3) wrong control mode (OSC_POSE vs JOINT_VELOCITY)." + ) + + +# ── Test 4: y feature ordering ──────────────────────────────────────────────── + +class TestObsYFeatureOrdering: + """ + Validate that wrapper._get_y_feats and dataset.get_y_feats produce + identical feature *ordering* for robot nodes: + col 0..6 : joint_pos (7 values) + col 7..8 : gripper_qpos (2 values) + col 9 : node_id + + A column-ordering mismatch would mean the normalizer scales the wrong + physical quantities, making the policy conditioning signal meaningless. + """ + + def test_robot_y_columns_are_joint_pos_then_gripper_then_id( + self, episode_data, wrapper_get_y_fn, dataset_get_y_fn + ): + """ + At a known timestep, check that columns 0-6 of y[:9] match joint_pos, + columns 7-8 match gripper_qpos, and column 9 (if present) matches + the node index 0..8. + + This pins down the actual in-memory layout, making any accidental + reordering immediately visible. + """ + t = 5 # arbitrary mid-episode step + obs_dict = _build_wrapper_obs_dict(episode_data, t) + data_dict = _build_dataset_data_dict(episode_data) + + y_wrapper = wrapper_get_y_fn(obs_dict) # (num_nodes, feat) + y_dataset = dataset_get_y_fn(data_dict, [t])[:, 0, :] # (num_nodes, feat) + + jp = torch.tensor(episode_data["joint_pos"][t]) # (7,) + gp = torch.tensor(episode_data["gripper_qpos"][t]) # (2,) + + # Robot nodes 0..6 correspond to 7 joints; nodes 7 & 8 are gripper nodes. 
+ # get_y_feats packs each robot node with its own joint feature: + # node i → [joint_i_val, 0, 0, ..., node_id] (sparse, one joint per node) + + print(f"\n── y feature ordering check (t={t}) ──────────────────────") + print(f" Wrapper y[:10,:] =\n{y_wrapper[:10,:]}") + print(f" Dataset y[:10,:] =\n{y_dataset[:10,:]}") + print(f" Expected jp: {jp.numpy()}") + print(f" Expected gp: {gp.numpy()}") + + # Verify node IDs (last column) for both wrapper and dataset + num_robot_nodes = 9 + expected_node_ids = torch.arange(num_robot_nodes, dtype=y_wrapper.dtype) + + wrapper_node_ids = y_wrapper[:num_robot_nodes, -1] + dataset_node_ids = y_dataset[:num_robot_nodes, -1] + + assert torch.allclose(wrapper_node_ids, expected_node_ids, atol=1e-3), ( + f"Wrapper y node IDs {wrapper_node_ids.tolist()} != expected {expected_node_ids.tolist()}.\n" + f"The node-ID column ordering is wrong in the wrapper." + ) + assert torch.allclose(dataset_node_ids, expected_node_ids, atol=1e-3), ( + f"Dataset y node IDs {dataset_node_ids.tolist()} != expected {expected_node_ids.tolist()}.\n" + f"The node-ID column ordering is wrong in the dataset." + ) + + +# ── Test 5: OSC_POSE node features shape and content ───────────────────────── + +class TestOscPoseNodeFeats: + """ + Verify that RobomimicGraphWrapper._get_node_feats for control_mode='OSC_POSE' + produces the same 9-node structure as JOINT modes, preserving graph topology. 
+ + Expected behavior (node_feature_dim=1): + - Shape: (9, 1) -- 9 robot nodes x 1 scalar feature each + - Nodes 0-2: eef_pos components (3D position) + - Nodes 3-6: eef_quat components (4D quaternion) + - Node 7: unused (0.0) + - Node 8: mean gripper_qpos + """ + + def _make_wrapper_node_feats_fn(self): + """Build bound _get_node_feats callable with OSC_POSE control mode.""" + import types + from diffusion_policy.model.common.rotation_transformer import RotationTransformer + + mod = _load_module("rg_wrapper_osc", "imitation/env/robomimic_graph_wrapper.py") + mock = types.SimpleNamespace( + num_robots=1, + control_mode="OSC_POSE", + BASE_LINK_SHIFT=BASE_LINK_SHIFT, + BASE_LINK_ROTATION=BASE_LINK_ROTATION, + rotation_transformer=RotationTransformer(from_rep="quaternion", to_rep="rotation_6d"), + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + num_objects=1, + ) + return mod.RobomimicGraphWrapper._get_node_feats.__get__(mock) + + def _build_obs_dict_osc(self, episode_data, t): + """Build obs dict with eef_pos and eef_quat_raw for OSC_POSE _get_node_feats.""" + with h5py.File(DATASET_PATH, "r") as f: + ep = f["data/demo_0"] + eef_pos = ep["obs/robot0_eef_pos"][t] + eef_quat = ep["obs/robot0_eef_quat"][t] + return { + "robot0_joint_pos": episode_data["joint_pos"][t], + "robot0_joint_vel": episode_data["joint_vel"][t], + "robot0_gripper_qpos": episode_data["gripper_qpos"][t], + "robot0_gripper_qvel": episode_data["gripper_qvel"][t], + "robot0_eef_pos": eef_pos, + "robot0_eef_quat_raw": eef_quat, # raw 4D + } + + def test_osc_pose_node_feats_shape(self, episode_data): + """OSC_POSE node features must be (9, 1) -- matching JOINT mode topology.""" + get_node_feats = self._make_wrapper_node_feats_fn() + obs_dict = self._build_obs_dict_osc(episode_data, 10) + feats = get_node_feats(obs_dict, control_mode="OSC_POSE") + assert feats.shape == (9, 1), ( + f"OSC_POSE node features have shape 
{tuple(feats.shape)}, expected (9, 1).\n" + f"The graph topology must have 9 robot nodes for GDDPM compatibility." + ) + + def test_osc_pose_nodes_0to2_match_eef_pos(self, episode_data): + """Nodes 0-2 must match the raw eef_pos values.""" + get_node_feats = self._make_wrapper_node_feats_fn() + t = 10 + obs_dict = self._build_obs_dict_osc(episode_data, t) + feats = get_node_feats(obs_dict, control_mode="OSC_POSE") + expected_pos = torch.tensor(obs_dict["robot0_eef_pos"], dtype=torch.float32) + assert torch.allclose(feats[:3, 0], expected_pos, atol=1e-5), ( + f"OSC_POSE nodes 0-2 (eef_pos) mismatch:\n" + f" got {feats[:3, 0].tolist()}\n" + f" expected {expected_pos.tolist()}" + ) + + def test_osc_pose_nodes_3to6_match_eef_quat(self, episode_data): + """Nodes 3-6 must match the raw eef_quat (4D) values.""" + get_node_feats = self._make_wrapper_node_feats_fn() + t = 10 + obs_dict = self._build_obs_dict_osc(episode_data, t) + feats = get_node_feats(obs_dict, control_mode="OSC_POSE") + expected_quat = torch.tensor(obs_dict["robot0_eef_quat_raw"], dtype=torch.float32) + assert torch.allclose(feats[3:7, 0], expected_quat, atol=1e-5), ( + f"OSC_POSE nodes 3-6 (eef_quat) mismatch:\n" + f" got {feats[3:7, 0].tolist()}\n" + f" expected {expected_quat.tolist()}" + ) + + def test_osc_pose_node7_is_zero(self, episode_data): + """Node 7 (unused) must be 0.0.""" + get_node_feats = self._make_wrapper_node_feats_fn() + obs_dict = self._build_obs_dict_osc(episode_data, 10) + feats = get_node_feats(obs_dict, control_mode="OSC_POSE") + assert float(feats[7, 0]) == 0.0, ( + f"OSC_POSE node 7 (unused) is {float(feats[7, 0])}, expected 0.0." 
+ ) + + def test_osc_pose_node8_is_gripper(self, episode_data): + """Node 8 must contain the mean gripper_qpos.""" + get_node_feats = self._make_wrapper_node_feats_fn() + for t in range(len(episode_data["gripper_qpos"])): + if np.any(np.abs(episode_data["gripper_qpos"][t]) > 0.01): + break + obs_dict = self._build_obs_dict_osc(episode_data, t) + feats = get_node_feats(obs_dict, control_mode="OSC_POSE") + expected_val = float(np.mean(episode_data["gripper_qpos"][t])) + assert abs(float(feats[8, 0]) - expected_val) < 1e-5, ( + f"OSC_POSE gripper node {float(feats[8, 0]):.6f} != expected {expected_val:.6f}" + ) + + + +# ── Test 6: OSC_POSE wrapper step replay ───────────────────────────────────── + +class TestOscPoseWrapperStep: + """ + Replay OSC_POSE dataset actions (7D) through RobomimicGraphWrapper with + control_mode='OSC_POSE' and verify end-to-end correctness: + + 1. Graph observations have the expected 10-node count (9 robot + 1 object). + 2. Node positions (pos) contain no NaN or Inf values. + 3. Joint positions after each step match dataset obs[t+1] within tolerance. + + The dataset was recorded with OSC_POSE, so replaying its actions in the + same environment with OSC_POSE should reproduce the original trajectory. + """ + + N_STEPS = 20 # replay first N steps (fast smoke test) + JOINT_TOL = 0.08 # rad — same as TestActionStepMatchesDatasetTransition + + def test_osc_pose_step_graph_structure(self, episode_data): + """ + Each step must return a graph with 10 nodes (9 robot + 1 cube object). 
+ """ + from imitation.env.robomimic_graph_wrapper import RobomimicGraphWrapper + + wrapper = RobomimicGraphWrapper( + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + task="Lift", + has_renderer=False, + robots=["Panda"], + control_mode="OSC_POSE", + base_link_shift=BASE_LINK_SHIFT, + base_link_rotation=BASE_LINK_ROTATION, + ) + wrapper.env.env.reset() + wrapper.env.env.sim.set_state_from_flattened(episode_data["states"][0]) + wrapper.env.env.sim.forward() + + actions = episode_data["actions"] # (T, 7) — OSC_POSE raw actions + for t in range(min(self.N_STEPS, len(actions) - 1)): + graph_obs, reward, done, info = wrapper.step(actions[t]) + assert graph_obs.x.shape[0] == 10, ( + f"Step {t}: graph has {graph_obs.x.shape[0]} nodes, expected 10 " + f"(9 robot + 1 object). OSC_POSE changed the graph topology." + ) + if done: + break + + wrapper.close() + + def test_osc_pose_step_pos_no_nan(self, episode_data): + """ + Node positions must be finite (no NaN / Inf) after each OSC_POSE step. 
+ """ + from imitation.env.robomimic_graph_wrapper import RobomimicGraphWrapper + + wrapper = RobomimicGraphWrapper( + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + task="Lift", + has_renderer=False, + robots=["Panda"], + control_mode="OSC_POSE", + base_link_shift=BASE_LINK_SHIFT, + base_link_rotation=BASE_LINK_ROTATION, + ) + wrapper.env.env.reset() + wrapper.env.env.sim.set_state_from_flattened(episode_data["states"][0]) + wrapper.env.env.sim.forward() + + actions = episode_data["actions"] + for t in range(min(self.N_STEPS, len(actions) - 1)): + graph_obs, reward, done, info = wrapper.step(actions[t]) + pos = graph_obs.pos + assert torch.all(torch.isfinite(pos)), ( + f"Step {t}: graph_obs.pos contains NaN or Inf:\n{pos}" + ) + if done: + break + + wrapper.close() + + def test_osc_pose_step_joint_pos_matches_dataset(self, episode_data): + """ + After replaying dataset OSC_POSE actions, the resulting joint positions + must match dataset obs[t+1] within JOINT_TOL radians. + + This is the key functional test: it verifies that the wrapper correctly + forwards 7D OSC_POSE actions to robosuite without any reformatting. 
+ """ + from imitation.env.robomimic_graph_wrapper import RobomimicGraphWrapper + + wrapper = RobomimicGraphWrapper( + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + task="Lift", + has_renderer=False, + robots=["Panda"], + control_mode="OSC_POSE", + base_link_shift=BASE_LINK_SHIFT, + base_link_rotation=BASE_LINK_ROTATION, + ) + wrapper.env.env.reset() + wrapper.env.env.sim.set_state_from_flattened(episode_data["states"][0]) + wrapper.env.env.sim.forward() + + actions = episode_data["actions"] # (T, 7) OSC_POSE actions + joint_pos = episode_data["joint_pos"] # (T, 7) ground truth + + max_err = 0.0 + worst_t = -1 + per_step = [] + + for t in range(min(self.N_STEPS, len(actions) - 1)): + graph_obs, reward, done, info = wrapper.step(actions[t]) + + # y field stores [joint_pos(7), gripper(2), node_id] per robot node + # Each robot node i stores its own joint value at y[i, 0] + live_jpos = graph_obs.y[:7, 0].detach().numpy() + ds_jpos = joint_pos[t + 1] + err = float(np.max(np.abs(live_jpos - ds_jpos))) + per_step.append(err) + if err > max_err: + max_err = err + worst_t = t + + if done: + break + + wrapper.close() + + print(f"\n── OSC_POSE wrapper replay (joint_pos) ─────────────────────") + print(f" Steps replayed : {len(per_step)}") + print(f" Max joint error : {max_err:.5f} rad at step {worst_t}") + print(f" Mean joint error : {np.mean(per_step):.5f} rad") + + assert max_err <= self.JOINT_TOL, ( + f"OSC_POSE wrapper.step() joint_pos error {max_err:.5f} rad at step {worst_t} " + f"exceeds tolerance {self.JOINT_TOL} rad.\n" + f"Likely causes:\n" + f" (1) Actions not forwarded as-is to robosuite (check step() for OSC_POSE branch)\n" + f" (2) control_freq mismatch between HDF5 env_args and wrapper\n" + f" (3) Sim state restoration at t=0 incomplete" + ) + diff --git a/tests/test_train_eval_consistency.py b/tests/test_train_eval_consistency.py new file mode 100644 index 
0000000..8eb95dd --- /dev/null +++ b/tests/test_train_eval_consistency.py @@ -0,0 +1,492 @@ +""" +Test Suite: Training / Evaluation Consistency Checks +===================================================== + +Goal +---- +Validate that the data pipeline used during GDDPM *training* is fully +consistent with what is presented at *evaluation* time. These tests do NOT +require a live robosuite environment – they operate purely on the dataset and +on lightweight mock objects that mirror the real classes. + +These tests act as **regression guards**: they pass when the code is correct +and fail if any of the three bugs are re-introduced. + +Bugs addressed (now fixed): +1. control_freq was 30 Hz; dataset recorded at 20 Hz → 33% velocity scaling error. +2. pos was _get_node_pos(data, idx-1) i.e. one step stale vs actions at idx. +3. Gripper routing used action[j+8], silently dropping action[j+7] (finger 0). +""" + +import importlib.util +import json +import os + +import h5py +import numpy as np +import pytest +import torch +from scipy.spatial.transform import Rotation as R + +# ── paths ───────────────────────────────────────────────────────────────────── +DATASET_PATH = "data/lift/ph/low_dim_v141.hdf5" +EPISODE_KEY = "demo_0" + +# ── lift-task config (lift_graph.yaml) ──────────────────────────────────────── +BASE_LINK_SHIFT = np.array([-0.56, 0.0, 0.912]) +BASE_LINK_ROTATION = [0.0, 0.0, 0.0, 1.0] # identity (x,y,z,w) + +# ── tolerances ──────────────────────────────────────────────────────────────── +# Maximum *mean* per-node pos drift between consecutive timesteps (metres). +# A non-zero value here is the root cause of the training/eval pos mismatch. +POS_DRIFT_TOL = 0.0 # exact zero: ANY drift is a mismatch + +# Normalizer output should be clipped to this range. +NORM_CLIP_TOL = 1.05 # allow 5 % headroom above ±1 for float imprecision + +# Gripper action at index 7 (the discarded dimension) materialness threshold. 
+GRIPPER_DIM7_TOL = 1e-3 # rad – treat as "materially non-zero" if above this + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _load_module(name, rel_path): + """Load a Python module directly from a relative file path.""" + spec = importlib.util.spec_from_file_location( + name, + os.path.join(os.path.dirname(__file__), "..", rel_path) + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def _load_calculate_panda_joints_positions(): + mod = _load_module("imitation_generic", "imitation/utils/generic.py") + return mod.calculate_panda_joints_positions + + +_calc_panda = None + + +def compute_node_pos_xyz(joint_pos_7, gripper_qpos_2): + """Mirror of the test helper in test_node_pos_consistency.py.""" + global _calc_panda + if _calc_panda is None: + _calc_panda = _load_calculate_panda_joints_positions() + joints = [*joint_pos_7.tolist(), *gripper_qpos_2.tolist()] + node_pos = _calc_panda(joints) + rot_mat = torch.tensor(R.from_quat(BASE_LINK_ROTATION).as_matrix()).to(node_pos.dtype) + node_pos[:, :3] = torch.matmul(node_pos[:, :3], rot_mat) + node_pos[:, :3] += torch.tensor(BASE_LINK_SHIFT).to(node_pos.dtype) + return node_pos[:, :3] # (9, 3) + + +# ── fixtures ────────────────────────────────────────────────────────────────── + +@pytest.fixture(scope="module") +def episode_data(): + """Load demo_0 from the HDF5 file.""" + with h5py.File(DATASET_PATH, "r") as f: + ep = f[f"data/{EPISODE_KEY}"] + joint_pos = ep["obs/robot0_joint_pos"][:] # (T, 7) + gripper_qpos = ep["obs/robot0_gripper_qpos"][:] # (T, 2) + gripper_qvel = ep["obs/robot0_gripper_qvel"][:] # (T, 2) + actions = ep["actions"][:] # (T, 7) + return joint_pos, gripper_qpos, gripper_qvel, actions + + +@pytest.fixture(scope="module") +def dataset_env_args(): + """Read env_args attribute written by robomimic at record time.""" + with h5py.File(DATASET_PATH, "r") as f: + return json.loads(f["data"].attrs["env_args"]) + + +# ── 
Test 1: control_freq ──────────────────────────────────────

+class TestControlFreqConsistency:
+    """
+    The dataset is recorded at a specific control_freq.
+    RobomimicGraphWrapper previously hard-coded control_freq=30 Hz (now
+    fixed to 20 Hz). A mismatch would mean joint-velocity actions
+    from the offline dataset are applied over the wrong Δt, causing
+    the robot to systematically under/overshoot.
+    """
+
+    def test_control_freq_matches_wrapper(self, dataset_env_args):
+        """
+        The control_freq stored in the HDF5 must match the value used by
+        RobomimicGraphWrapper (now fixed to 20 Hz to match dataset).
+
+        Regression guard: fails if control_freq is changed back to 30 Hz,
+        which would scale every JOINT_VELOCITY action by 0.667×.
+        """
+        WRAPPER_CONTROL_FREQ = 20  # fixed in robomimic_graph_wrapper.py (was 30)
+
+        recorded_freq = dataset_env_args["env_kwargs"].get("control_freq")
+        assert recorded_freq is not None, (
+            "control_freq not found in dataset env_args – cannot verify consistency."
+        )
+
+        print(f"\n── Control frequency check ──────────────────────────────")
+        print(f"   Dataset recorded at : {recorded_freq} Hz")
+        print(f"   Wrapper uses        : {WRAPPER_CONTROL_FREQ} Hz")
+        if recorded_freq != WRAPPER_CONTROL_FREQ:
+            ratio = recorded_freq / WRAPPER_CONTROL_FREQ
+            print(f"   MISMATCH: velocity scale factor = {ratio:.3f}x")
+            print(f"   Actions will under/overshoot by {abs(1-ratio)*100:.1f} %")
+
+        assert recorded_freq == WRAPPER_CONTROL_FREQ, (
+            f"control_freq MISMATCH: dataset recorded at {recorded_freq} Hz but "
+            f"RobomimicGraphWrapper uses {WRAPPER_CONTROL_FREQ} Hz.\n"
+            f"JOINT_VELOCITY actions will be scaled by {recorded_freq/WRAPPER_CONTROL_FREQ:.3f}x "
+            f"relative to training, causing the arm to systematically "
+            f"{'over' if recorded_freq > WRAPPER_CONTROL_FREQ else 'under'}shoot.\n"
+            f"Fix: set control_freq={recorded_freq} in RobomimicGraphWrapper.__init__() "
+            f"or in lift_graph.yaml → env_runner.env." 
+ ) + + def test_horizon_sanity(self, dataset_env_args): + """ + The 'horizon' (max episode length) in the dataset should be ≥ our + configured max_steps (500 for lift/ph). If it is much shorter, the + env will terminate before the policy has time to complete the task. + """ + EXPECTED_MAX_STEPS = 500 + recorded_horizon = dataset_env_args["env_kwargs"].get("horizon", None) + print(f"\n── Horizon check ────────────────────────────────────────") + print(f" Dataset horizon : {recorded_horizon}") + print(f" Config max_steps: {EXPECTED_MAX_STEPS}") + if recorded_horizon is not None: + assert recorded_horizon >= EXPECTED_MAX_STEPS, ( + f"Dataset horizon ({recorded_horizon}) < config max_steps ({EXPECTED_MAX_STEPS}). " + f"The environment may terminate prematurely during evaluation." + ) + + +# ── Test 2: pos indexing alignment ──────────────────────────────────────────── + +class TestDatasetPosIndexingAlignment: + """ + In RobomimicGraphDataset.process(): + + for idx in range(1, episode_length - pred_horizon): + node_feats = _get_node_feats_horizon(data, idx, pred_horizon) # at idx + y = _get_y_horizon(data, idx, obs_horizon) # at idx + pos = _get_node_pos(data, idx - 1) # at idx-1 ← ! + + pos is one step behind x and y. During evaluation, the wrapper computes + pos from the *current* observation (no −1 offset). This mismatch means + the network sees different (pos, x/y) correlations at train vs eval time. + """ + + def test_pos_is_one_step_stale_in_training(self, episode_data): + """ + Regression guard: dataset.process() must use idx (not idx-1) for pos, + so that graph coordinates align with the actions and observations at + the same timestep. + + This test verifies that drift between consecutive timestep positions + is below the 1 mm threshold (i.e. pos is taken at idx, not idx-1). + Fails if the idx-1 off-by-one regression is re-introduced. 
+ """ + joint_pos, gripper_qpos, _, _ = episode_data + T = len(joint_pos) + + per_step_drift = [] + max_node_drift = [] + + for t in range(1, T): # idx runs from 1, so compare pos[t-1] vs pos[t] + pos_stale = compute_node_pos_xyz(joint_pos[t-1], gripper_qpos[t-1]) # what idx-1 gives + pos_current = compute_node_pos_xyz(joint_pos[t], gripper_qpos[t]) # what idx gives + + # Per-node max-axis drift (metres) + drift = float(torch.max(torch.abs(pos_stale - pos_current)).item()) + per_step_drift.append(drift) + max_node_drift.append(float(torch.max(torch.norm(pos_stale - pos_current, dim=1)).item())) + + per_step_drift = np.array(per_step_drift) + max_node_drift = np.array(max_node_drift) + + print(f"\n── pos indexing drift (idx-1 vs idx) ───────────────────") + print(f" Steps checked : {T-1}") + print(f" Mean per-step max-axis drift: {per_step_drift.mean()*1e3:.2f} mm") + print(f" Max per-step max-axis drift: {per_step_drift.max()*1e3:.2f} mm") + print(f" Mean per-node L2 drift : {max_node_drift.mean()*1e3:.2f} mm") + print(f" Max per-node L2 drift : {max_node_drift.max()*1e3:.2f} mm") + print(f" (drift represents the training/eval mismatch if pos=idx-1 is used)") + + # Now that the fix is applied (pos=idx), we verify the code itself at + # runtime by importing the dataset module and checking the source. + # The drift numbers here characterise the *magnitude* of the bug if it + # were re-introduced, not the current state. 
+ import inspect + import importlib + ds_spec = importlib.util.spec_from_file_location( + "_rg_ds_check", + os.path.join(os.path.dirname(__file__), "..", + "imitation", "dataset", "robomimic_graph_dataset.py") + ) + ds_mod = importlib.util.module_from_spec(ds_spec) + ds_spec.loader.exec_module(ds_mod) + process_src = inspect.getsource(ds_mod.RobomimicGraphDataset.process) + + # Regression guard: 'idx - 1' must NOT appear in the pos= line + assert "_get_node_pos(data_raw, idx - 1)" not in process_src, ( + f"REGRESSION: dataset.process() still uses pos=_get_node_pos(data, idx-1).\n" + f"This causes a mean pos drift of {per_step_drift.mean()*1e3:.2f} mm between " + f"training and evaluation.\n" + f"Fix: change 'pos = self._get_node_pos(data_raw, idx - 1)' to " + f"'pos = self._get_node_pos(data_raw, idx)' in dataset.process()." + ) + print(f" Source check: pos=idx confirmed (not idx-1). Regression guard PASSED.") + + def test_pos_drift_distribution(self, episode_data): + """ + Report the full distribution of drift between consecutive pos timesteps. + Also serves as a regression guard: checks that the source uses idx not idx-1. + Printed drift stats quantify the magnitude of the bug if re-introduced. 
+ """ + joint_pos, gripper_qpos, _, _ = episode_data + T = len(joint_pos) + + per_step_drift = [] + for t in range(1, T): + pos_prev = compute_node_pos_xyz(joint_pos[t-1], gripper_qpos[t-1]) + pos_curr = compute_node_pos_xyz(joint_pos[t], gripper_qpos[t]) + drift = float(torch.max(torch.abs(pos_prev - pos_curr)).item()) + per_step_drift.append(drift) + + per_step_drift = np.array(per_step_drift) + p90 = np.percentile(per_step_drift, 90) + p99 = np.percentile(per_step_drift, 99) + + print(f"\n── pos drift distribution (consecutive timestep delta) ──") + print(f" (This is the magnitude of the old idx-1 bug — informational only)") + print(f" p50 : {np.median(per_step_drift)*1e3:.2f} mm") + print(f" p90 : {p90*1e3:.2f} mm") + print(f" p99 : {p99*1e3:.2f} mm") + print(f" max : {per_step_drift.max()*1e3:.2f} mm") + + # Regression guard via source inspection (same as test above) + import inspect, importlib + ds_spec = importlib.util.spec_from_file_location( + "_rg_ds_check2", + os.path.join(os.path.dirname(__file__), "..", + "imitation", "dataset", "robomimic_graph_dataset.py") + ) + ds_mod = importlib.util.module_from_spec(ds_spec) + ds_spec.loader.exec_module(ds_mod) + process_src = inspect.getsource(ds_mod.RobomimicGraphDataset.process) + + assert "_get_node_pos(data_raw, idx - 1)" not in process_src, ( + f"REGRESSION: dataset.process() uses pos=_get_node_pos(data, idx-1).\n" + f"This causes p90 drift of {p90*1e3:.2f} mm between training and eval pos.\n" + f"Fix: change to pos=_get_node_pos(data_raw, idx)." + ) + + +# ── Test 3: normalizer range ─────────────────────────────────────────────────── + +class TestNormalizerRange: + """ + After fitting LinearNormalizer on the dataset, every element of + normalize_data(obs) and normalize_data(action) must lie in [-1, 1]. 
+ + The GDDPM clips its noisy input to [-1, 1] (clip_sample=True in + DDPMScheduler), so if the normalizer maps anything outside this range + the observation/action is saturated during training and the network + never learns to reconstruct the extreme values. + + We test a subset of the dataset (first 200 samples) to avoid loading + the full dataset. + """ + + @pytest.fixture(scope="class") + def dataset(self): + """Instantiate RobomimicGraphDataset directly (reads processed cache).""" + import types + import importlib + spec = importlib.util.spec_from_file_location( + "rg_dataset", + os.path.join(os.path.dirname(__file__), "..", "imitation", "dataset", + "robomimic_graph_dataset.py") + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + ds = mod.RobomimicGraphDataset( + dataset_path=DATASET_PATH, + robots=["Panda"], + object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, + object_state_keys={"cube": ["cube_pos", "cube_quat"]}, + pred_horizon=16, + obs_horizon=4, + control_mode="JOINT_VELOCITY", + base_link_shift=[[-0.56, 0.0, 0.912]], + base_link_rotation=[[0.0, 0.0, 0.0, 1.0]], + ) + return ds + + def _sample_indices(self, ds, n=200): + """Return up to n evenly-spaced indices into the dataset.""" + total = ds.len() + step = max(1, total // n) + return list(range(0, total, step))[:n] + + def test_obs_normalizer_range(self, dataset): + """ + Normalizing dataset y (observations) must produce values in [-1, 1]. + The last column (node IDs) is exempt – it is excluded from normalization. 
+ """ + ds = dataset + indices = self._sample_indices(ds) + + all_y_norm = [] + for i in indices: + data = ds.get(i) + y = data.y # (nodes, obs_horizon, feat) + y_norm = ds.normalize_data(y, stats_key="obs") + y_norm_no_id = y_norm[:, :, :-1] # exclude node-ID column + all_y_norm.append(y_norm_no_id.reshape(-1).detach().numpy()) + + all_y_norm = np.concatenate(all_y_norm) + out_of_range = np.abs(all_y_norm) > NORM_CLIP_TOL + frac_oob = out_of_range.mean() + + print(f"\n── Obs normalizer range check ───────────────────────────") + print(f" Samples checked : {len(indices)}") + print(f" Min normalized : {all_y_norm.min():.4f}") + print(f" Max normalized : {all_y_norm.max():.4f}") + print(f" Fraction out of ±{NORM_CLIP_TOL:.2f}: {frac_oob*100:.3f} %") + + assert frac_oob == 0.0, ( + f"{frac_oob*100:.2f}% of normalized obs values are outside ±{NORM_CLIP_TOL}.\n" + f"(min={all_y_norm.min():.4f}, max={all_y_norm.max():.4f})\n" + f"The normalizer was likely fit on a different distribution than what " + f"the policy sees at inference time. Check whether the normalizer stats " + f"are recomputed after changing pred_horizon/obs_horizon." + ) + + def test_action_normalizer_range(self, dataset): + """ + Normalizing dataset x (actions) must produce values in [-1, 1]. + Actions that are clipped by the normalizer cause the policy to learn + on a saturated action space, leading to poor reconstruction at eval. 
+ """ + ds = dataset + indices = self._sample_indices(ds) + + all_x_norm = [] + for i in indices: + data = ds.get(i) + x = data.x # (nodes, pred_horizon, feat) + # Only the first feature dim (joint value/velocity), excluding node-type + x_val = x[:, :, :1] + x_norm = ds.normalize_data( + torch.cat([x_val, torch.zeros_like(x[:,:,1:])], dim=2), + stats_key="action" + )[:, :, :1] + all_x_norm.append(x_norm.reshape(-1).detach().numpy()) + + all_x_norm = np.concatenate(all_x_norm) + out_of_range = np.abs(all_x_norm) > NORM_CLIP_TOL + frac_oob = out_of_range.mean() + + print(f"\n── Action normalizer range check ────────────────────────") + print(f" Samples checked : {len(indices)}") + print(f" Min normalized : {all_x_norm.min():.4f}") + print(f" Max normalized : {all_x_norm.max():.4f}") + print(f" Fraction out of ±{NORM_CLIP_TOL:.2f}: {frac_oob*100:.3f} %") + + assert frac_oob == 0.0, ( + f"{frac_oob*100:.2f}% of normalized action values are outside ±{NORM_CLIP_TOL}.\n" + f"(min={all_x_norm.min():.4f}, max={all_x_norm.max():.4f})\n" + f"The normalizer was likely fit on a different distribution than what " + f"the policy sees at inference time." + ) + + +# ── Test 4: gripper action routing ──────────────────────────────────────────── + +class TestGripperActionRouting: + """ + RobomimicGraphWrapper.step() slices a 9-D action vector as: + + robot_joint_pos = action[j:j+7] # correct + robot_gripper_pos = action[j+8] # ← index 8, skipping index 7! + + For the Panda, the dataset 'actions' are 7-D (OSC_POSE or JOINT_VELOCITY), + but the graph action representation packs joint values as node features for + nodes 0-8 (9 nodes total, node 7 = gripper finger 0, node 8 = gripper + finger 1). So action[7] is the first gripper finger and action[8] is the + second. If the wrapper skips action[7], one gripper DOF is never actuated. 
+ + This test measures whether action[7] (the dropped index) is materially + non-zero in the dataset, which would confirm the routing bug causes + meaningful control errors. + """ + + def test_action_index_7_is_nonzero(self, episode_data): + """ + Informational test: measures that gripper finger 0 (action[7]) carries + real non-zero signal, confirming the routing fix matters. + Now that the fix is applied (wrapper uses action[j+7]), this test only + prints diagnostics — the structural assertion is in the contract test. + """ + joint_pos, gripper_qpos, gripper_qvel, raw_actions = episode_data + + gripper_finger_0 = gripper_qpos[:, 0] # action[7] - now correctly used + gripper_finger_1 = gripper_qpos[:, 1] # action[8] + + range_f0 = float(gripper_finger_0.max() - gripper_finger_0.min()) + range_f1 = float(gripper_finger_1.max() - gripper_finger_1.min()) + + print(f"\n── Gripper action routing check ─────────────────────────") + print(f" action[7] (gripper finger 0, now USED by wrapper):") + print(f" max |val| : {float(np.max(np.abs(gripper_finger_0))):.4f} range : {range_f0:.4f}") + print(f" action[8] (gripper finger 1):") + print(f" max |val| : {float(np.max(np.abs(gripper_finger_1))):.4f} range : {range_f1:.4f}") + + # Fingers of the Panda gripper move symmetrically; confirm high correlation + correlation = float(np.corrcoef(gripper_finger_0, gripper_finger_1)[0, 1]) + print(f" Correlation between finger 0 and finger 1: {correlation:.3f}") + # Both fingers should move together (Panda gripper is symmetric) + assert abs(correlation) > 0.5, ( + f"Gripper fingers 0 and 1 have unexpectedly low correlation ({correlation:.3f}).\n" + f"Check action ordering: both fingers should move symmetrically." + ) + + def test_wrapper_step_action_slice_matches_dataset_convention(self, episode_data): + """ + Regression guard: wrapper.step() must use action[j+7] for the gripper, + NOT action[j+8]. Verified by inspecting the wrapper source code. 
+ + Fails if the gripper routing regression is re-introduced. + """ + import inspect, importlib + wrap_spec = importlib.util.spec_from_file_location( + "_rg_wrap_check", + os.path.join(os.path.dirname(__file__), "..", + "imitation", "env", "robomimic_graph_wrapper.py") + ) + wrap_mod = importlib.util.module_from_spec(wrap_spec) + wrap_spec.loader.exec_module(wrap_mod) + step_src = inspect.getsource(wrap_mod.RobomimicGraphWrapper.step) + + print(f"\n── Wrapper action slicing contract ──────────────────────") + + # Regression guard: the old buggy line must not be present + assert "action[j + 8]" not in step_src and "action[j+8]" not in step_src, ( + f"REGRESSION: wrapper.step() still uses action[j+8] for gripper.\n" + f"This silently drops gripper finger 0 (action[j+7]).\n" + f"Fix: change 'robot_gripper_pos = action[j + 8]' to " + f"'robot_gripper_pos = action[j + 7]' in RobomimicGraphWrapper.step()." + ) + + # Positive check: the correct line IS present + assert "action[j + 7]" in step_src or "action[j+7]" in step_src, ( + f"Expected 'action[j + 7]' in wrapper.step() for gripper routing, " + f"but it was not found.\nCheck RobomimicGraphWrapper.step()." 
+ ) + print(f" Gripper routing uses action[j+7] — regression guard PASSED.") From ed3836f3ce5b2da2331e37f47c8e8edb08551a5c Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 9 Mar 2026 12:46:32 +0100 Subject: [PATCH 03/22] Update train, test and gddpm_policy configs --- imitation/config/policy/gddpm_policy.yaml | 4 ++-- imitation/config/test.yaml | 6 +++--- imitation/config/train.yaml | 6 +++--- train.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/imitation/config/policy/gddpm_policy.yaml b/imitation/config/policy/gddpm_policy.yaml index 35fb8b6..3795393 100644 --- a/imitation/config/policy/gddpm_policy.yaml +++ b/imitation/config/policy/gddpm_policy.yaml @@ -14,7 +14,7 @@ dataset: ${task.dataset} denoising_network: _target_: imitation.model.gddpm.GDDPMNoisePred node_feature_dim: ${policy.node_feature_dim} - cond_feature_dim: 6 # 6-D rotation obs features (excl. node-id) + cond_feature_dim: 9 # 9-D obs features: pos(3) + rot_6d(6) (excl. node-id) obs_horizon: ${obs_horizon} pred_horizon: ${pred_horizon} edge_feature_dim: 1 @@ -28,7 +28,7 @@ denoising_network: num_diffusion_steps: ${policy.num_diffusion_iters} ckpt_path: ./weights/gddpm_policy_${task.task_name}_${task.dataset_type}_${task.control_mode}_${policy.num_diffusion_iters}iters.pt -lr: 2e-4 +lr: 1e-4 batch_size: 128 use_normalization: True keep_first_action: True diff --git a/imitation/config/test.yaml b/imitation/config/test.yaml index 497d590..1efb1b5 100644 --- a/imitation/config/test.yaml +++ b/imitation/config/test.yaml @@ -1,7 +1,7 @@ defaults: - _self_ - task: lift_graph - - policy: graph_ddpm_policy + - policy: gddpm_policy output_video: False render: True @@ -13,8 +13,8 @@ output_dir: ./outputs pred_horizon: 16 obs_horizon: 4 -action_horizon: 4 -action_offset: 1 # action offset for the policy, 1 if first action is to be ignored +action_horizon: 16 +action_offset: 0 # action offset for the policy, 1 if first action is to be ignored agent: diff --git 
a/imitation/config/train.yaml b/imitation/config/train.yaml index 22e6774..152ab5e 100644 --- a/imitation/config/train.yaml +++ b/imitation/config/train.yaml @@ -1,7 +1,7 @@ defaults: - _self_ - task: lift_graph - - policy: graph_ddpm_policy + - policy: gddpm_policy output_dir: ./outputs # on evaluating, for environment wrapper @@ -10,8 +10,8 @@ output_video: True pred_horizon: 16 obs_horizon: 4 -action_horizon: 2 -action_offset: 1 # action offset for the policy, 1 if first action is to be ignored +action_horizon: 16 +action_offset: 0 # action offset for the policy, 1 if first action is to be ignored # Training parameters num_epochs: 500 val_fraction: 0.1 diff --git a/train.py b/train.py index 3c82028..78eaec7 100644 --- a/train.py +++ b/train.py @@ -39,7 +39,7 @@ def train(cfg: DictConfig) -> None: wandb.init( project=policy.__class__.__name__, group=cfg.task.task_name, - name=f"v1.2.2", + name=f"v1.2.3 - GDDPM", # track hyperparameters and run metadata config={ "policy": cfg.policy, From 86142ad33c51a196894eb2d8e6c410709252a173 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 9 Mar 2026 12:53:30 +0100 Subject: [PATCH 04/22] git commit -m "Add FlatGDDPMNoisePred for OSC_POSE and fix GDDPMNoisePred obs/action node mismatch" --- imitation/config/policy/osc_ddpm_policy.yaml | 30 ++ imitation/config/task/lift_graph_osc.yaml | 53 +++ imitation/model/gddpm.py | 320 ++++++++++++++++- imitation/policy/osc_ddpm_policy.py | 346 +++++++++++++++++++ 4 files changed, 732 insertions(+), 17 deletions(-) create mode 100644 imitation/config/policy/osc_ddpm_policy.yaml create mode 100644 imitation/config/task/lift_graph_osc.yaml create mode 100644 imitation/policy/osc_ddpm_policy.py diff --git a/imitation/config/policy/osc_ddpm_policy.yaml b/imitation/config/policy/osc_ddpm_policy.yaml new file mode 100644 index 0000000..2d70d13 --- /dev/null +++ b/imitation/config/policy/osc_ddpm_policy.yaml @@ -0,0 +1,30 @@ +_target_: imitation.policy.osc_ddpm_policy.OSCGraphDDPMPolicy + 
+action_dim: ${task.action_dim} # 7 for OSC_POSE +num_edge_types: 2 +pred_horizon: ${pred_horizon} +obs_horizon: ${obs_horizon} +action_horizon: ${action_horizon} +num_diffusion_iters: 100 +dataset: ${task.dataset} + +denoising_network: + _target_: imitation.model.gddpm.FlatGDDPMNoisePred + action_dim: ${policy.action_dim} + cond_feature_dim: 9 # 9-D obs features: pos(3) + rot_6d(6) + obs_horizon: ${obs_horizon} + pred_horizon: ${pred_horizon} + edge_feature_dim: 1 + num_edge_types: ${policy.num_edge_types} + residual_layers: 8 + residual_channels: 32 + dilation_cycle_length: 2 + hidden_dim: 256 + diffusion_step_embed_dim: 64 + num_diffusion_steps: ${policy.num_diffusion_iters} + +ckpt_path: ./weights/osc_ddpm_policy_${task.task_name}_${task.control_mode}_${policy.num_diffusion_iters}iters.pt +lr: 1e-4 +batch_size: 128 +use_normalization: True +keep_first_action: True diff --git a/imitation/config/task/lift_graph_osc.yaml b/imitation/config/task/lift_graph_osc.yaml new file mode 100644 index 0000000..9dccd8d --- /dev/null +++ b/imitation/config/task/lift_graph_osc.yaml @@ -0,0 +1,53 @@ + +task_name: &task_name lift +dataset_type: &dataset_type ph +dataset_path: &dataset_path ./data/lift/${task.dataset_type}/low_dim_v141.hdf5 + +max_steps: 500 + +control_mode: "OSC_POSE" + +obs_dim: 10 # 9 robot + 1 object nodes +action_dim: 7 # flat 7-D EEF vector (xyz + rotation + gripper) + +robots: ["Panda"] + +object_state_sizes: &object_state_sizes + cube_pos: 3 + cube_quat: 4 + gripper_to_cube_pos: 3 + +object_state_keys: &object_state_keys + cube: ["cube_pos", "cube_quat"] + +env_runner: + _target_: imitation.env_runner.robomimic_lowdim_runner.RobomimicEnvRunner + output_dir: ${output_dir} + action_horizon: ${action_horizon} + obs_horizon: ${obs_horizon} + action_offset: ${action_offset} + render: ${render} + output_video: ${output_video} + use_full_pred_after: 0.4 + env: + _target_: imitation.env.robomimic_graph_wrapper.RobomimicGraphWrapper + object_state_sizes: 
*object_state_sizes + object_state_keys: *object_state_keys + max_steps: ${task.max_steps} + task: "Lift" + has_renderer: ${render} + robots: ${task.robots} + output_video: ${output_video} + control_mode: ${task.control_mode} + base_link_shift: [[-0.56, 0, 0.912]] + +dataset: + _target_: imitation.dataset.robomimic_graph_dataset.RobomimicGraphDataset + dataset_path: ${task.dataset_path} + robots: ${task.robots} + pred_horizon: ${pred_horizon} + obs_horizon: ${obs_horizon} + object_state_sizes: *object_state_sizes + object_state_keys: *object_state_keys + control_mode: ${task.control_mode} + base_link_shift: [[-0.56, 0, 0.912]] diff --git a/imitation/model/gddpm.py b/imitation/model/gddpm.py index 24cc896..963d23b 100644 --- a/imitation/model/gddpm.py +++ b/imitation/model/gddpm.py @@ -332,58 +332,84 @@ def forward(self, x_coord: (N_total, 3) unchanged (kept for API compatibility) """ # ---- move to device / cast ---------------------------------------- - x = x.float().to(self.device) # (N, T, F) + x = x.float().to(self.device) # (N_act, T, F) edge_attr = edge_attr.float().to(self.device) edge_index = edge_index.to(self.device) x_coord = x_coord.float().to(self.device) timesteps = timesteps.to(self.device) + cond = cond.float().to(self.device) + # obs_batch maps OBS nodes to graphs; x may have fewer nodes per graph if batch is None: - batch = torch.zeros(x.shape[0], dtype=torch.long, device=self.device) + obs_batch = torch.zeros(cond.shape[0], dtype=torch.long, device=self.device) else: - batch = batch.long().to(self.device) + obs_batch = batch.long().to(self.device) + + B = obs_batch.max().item() + 1 + obs_npg = cond.shape[0] // B # obs nodes per graph (e.g. 10) + act_npg = x.shape[0] // B # action nodes per graph (e.g. 
9) + + # action_batch: maps action nodes to their graph index + action_batch = torch.arange(B, dtype=torch.long, device=self.device).repeat_interleave(act_npg) # separate node-id from conditioning features (last channel of cond) ids = cond[:, 0, -1].long().to(self.device) - cond_feats = cond[:, :, :-1].float().to(self.device) # (N, obs_horizon, C) + cond_feats = cond[:, :, :-1].float().to(self.device) # (N_obs, obs_horizon, C) - # ---- add self-loops for GatedGraphConv compatibility --------------- + # ---- obs edge_index with self-loops (for EGraphConditionEncoder) ---- edge_attr_1d = edge_attr.reshape(-1) - edge_index_sl, edge_attr_sl = add_self_loops( + obs_edge_index_sl, obs_edge_attr_sl = add_self_loops( edge_index, edge_attr_1d, + num_nodes=cond.shape[0], fill_value=0.0 + ) + + # ---- action edge_index: filter to robot-only edges, remap indices --- + if act_npg < obs_npg: + src, dst = edge_index + src_local = src % obs_npg + dst_local = dst % obs_npg + mask = (src_local < act_npg) & (dst_local < act_npg) + act_ei = edge_index[:, mask] + act_ea = edge_attr_1d[mask] + graph_ids_ei = act_ei[0] // obs_npg + act_ei = act_ei % obs_npg + graph_ids_ei * act_npg + else: + act_ei = edge_index + act_ea = edge_attr_1d + act_edge_index_sl, act_edge_attr_sl = add_self_loops( + act_ei, act_ea, num_nodes=x.shape[0], fill_value=0.0 ) # ---- Graph-level conditioning vector -------------------------------- # EGraphConditionEncoder returns (B, cond_channels) graph_cond = self.cond_encoder( - cond_feats, edge_index_sl, x_coord, edge_attr_sl.unsqueeze(-1), - batch=batch, ids=ids + cond_feats, obs_edge_index_sl, x_coord, obs_edge_attr_sl.unsqueeze(-1), + batch=obs_batch, ids=ids ) # (B, cond_channels) # ---- Up-sample conditioning to pred_horizon ------------------------- cond_up = self.cond_upsampler(graph_cond) # (B, pred_horizon) - # Broadcast from per-graph to per-node - cond_up_node = cond_up[batch] # (N, pred_horizon) - cond_up_node = cond_up_node.unsqueeze(1) # (N, 1, 
pred_horizon) + # Broadcast from per-graph to per-ACTION-node + cond_up_node = cond_up[action_batch] # (N_act, pred_horizon) + cond_up_node = cond_up_node.unsqueeze(1) # (N_act, 1, pred_horizon) # ---- Diffusion step embedding ---------------------------------------- diffusion_step = self.diffusion_embedding(timesteps) # (B, hidden_dim) - diffusion_step_node = diffusion_step[batch] # (N, hidden_dim) + diffusion_step_node = diffusion_step[action_batch] # (N_act, hidden_dim) - # ---- Reshape x: (N, T, F) -> (N, F, T) for Conv1d ------------------ - x_conv = x.permute(0, 2, 1) # (N, F, T) - h = F.leaky_relu(self.input_projection(x_conv), 0.4) # (N, C, T') + # ---- Reshape x: (N_act, T, F) -> (N_act, F, T) for Conv1d ---------- + x_conv = x.permute(0, 2, 1) # (N_act, F, T) + h = F.leaky_relu(self.input_projection(x_conv), 0.4) # (N_act, C, T') # ---- Residual stack -------------------------------------------------- - # Trim h to pred_horizon (circular padding may have added extra steps) skip_sum = None for block in self.residual_blocks: h, skip = block( h, cond_up_node, diffusion_step_node, - edge_index_sl, + act_edge_index_sl, edge_weight=None, ) if skip_sum is None: @@ -416,3 +442,263 @@ def forward(self, noise_pred = out.permute(0, 2, 1) return noise_pred, x_coord + + +# --------------------------------------------------------------------------- +# Flat residual block: dilated conv only, no GatedGraphConv +# --------------------------------------------------------------------------- + +class FlatResidualBlock(nn.Module): + """ + Dilated-conv residual block for flat (non-graph) action sequences. + + Operates on batch-level tensors (B, C, T) instead of node-level (N_total, C, T). + Same gated activation as ResidualBlock, without the GatedGraphConv branch. 
+ """ + def __init__(self, + hidden_size: int, + residual_channels: int, + dilation: int): + super().__init__() + self.residual_channels = residual_channels + + self.dilated_conv = nn.Conv1d( + residual_channels, + 2 * residual_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + padding_mode="circular", + ) + self.diffusion_projection = nn.Linear(hidden_size, 2 * residual_channels) + self.conditioner_projection = nn.Conv1d( + 1, 2 * residual_channels, kernel_size=1, padding=2, padding_mode="circular" + ) + self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) + + nn.init.kaiming_normal_(self.conditioner_projection.weight) + nn.init.kaiming_normal_(self.output_projection.weight) + + def forward(self, + x: torch.Tensor, + conditioner: torch.Tensor, + diffusion_step: torch.Tensor) -> tuple: + """ + x: (B, residual_channels, T) + conditioner: (B, 1, pred_horizon) + diffusion_step: (B, hidden_size) + Returns: (residual_out, skip) both (B, residual_channels, T_min) + """ + B, C, T = x.shape + + cond_proj = self.conditioner_projection(conditioner) # (B, 2C, T') + min_cond = min(cond_proj.shape[-1], T) + cond_proj = cond_proj[..., :min_cond] + + diff_proj = self.diffusion_projection(diffusion_step) # (B, 2C) + diff_proj = diff_proj.unsqueeze(-1) # (B, 2C, 1) + + y_temporal = self.dilated_conv(x) # (B, 2C, T') + T_conv = y_temporal.shape[-1] + + T_min = min(T_conv, min_cond) + y = (y_temporal[..., :T_min] + + cond_proj[..., :T_min] + + diff_proj.expand(B, 2 * C, T_min)) # (B, 2C, T_min) + + gate, filt = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filt) # (B, C, T_min) + + y = F.leaky_relu(self.output_projection(y), 0.4) + residual, skip = torch.chunk(y, 2, dim=1) # each (B, C, T_min) + + T_res = min(residual.shape[-1], T) + residual_out = (x[..., :T_res] + residual[..., :T_res]) / math.sqrt(2.0) + return residual_out, skip[..., :T_res] + + +# 
--------------------------------------------------------------------------- +# Flat GDDPM noise predictor: graph obs encoding + flat action denoising +# --------------------------------------------------------------------------- + +class FlatGDDPMNoisePred(nn.Module): + """ + Drop-in replacement for GDDPMNoisePred where the action is a flat + (B, pred_horizon, action_dim) tensor instead of per-node. + + The graph structure is used *only* for observation encoding via + EGraphConditionEncoder. The residual denoising blocks operate on the + full action batch at graph granularity (B, ...). + + Args: + action_dim: flat action dimensionality (e.g. 7 for OSC_POSE) + cond_feature_dim: obs feature dim (excl. node-id), e.g. 9 + obs_horizon: number of observation steps for conditioning + pred_horizon: number of prediction steps + edge_feature_dim: edge attribute size (usually 1) + num_edge_types: number of edge type categories + residual_layers: number of FlatResidualBlock layers + residual_channels: channels inside each block + dilation_cycle_length: dilation doubles every this many layers + hidden_dim: hidden size for EGraphConditionEncoder + diffusion_step_embed_dim: sinusoidal embedding size + num_diffusion_steps: total DDPM timesteps + """ + + def __init__(self, + action_dim: int, + cond_feature_dim: int, + obs_horizon: int, + pred_horizon: int, + edge_feature_dim: int, + num_edge_types: int, + residual_layers: int = 8, + residual_channels: int = 32, + dilation_cycle_length: int = 2, + hidden_dim: int = 256, + diffusion_step_embed_dim: int = 64, + num_diffusion_steps: int = 100, + device=None): + super().__init__() + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = device + + self.action_dim = action_dim + self.pred_horizon = pred_horizon + self.hidden_dim = hidden_dim + self.residual_channels = residual_channels + + self.cond_channels = hidden_dim + self.cond_encoder = EGraphConditionEncoder( + 
input_dim=cond_feature_dim * obs_horizon, + output_dim=self.cond_channels, + hidden_dim=hidden_dim, + device=self.device, + ).to(self.device) + + self.diffusion_embedding = DiffusionEmbedding( + dim=diffusion_step_embed_dim, + proj_dim=hidden_dim, + max_steps=num_diffusion_steps, + ).to(self.device) + + self.cond_upsampler = CondUpsampler( + cond_length=self.cond_channels, + target_dim=pred_horizon, + ).to(self.device) + + self.input_projection = nn.Conv1d( + action_dim, + residual_channels, + kernel_size=1, + padding=2, + padding_mode="circular", + ).to(self.device) + + self.residual_blocks = nn.ModuleList([ + FlatResidualBlock( + hidden_size=hidden_dim, + residual_channels=residual_channels, + dilation=2 ** (i % dilation_cycle_length), + ) + for i in range(residual_layers) + ]) + self.residual_blocks.to(self.device) + + self.skip_projection = nn.Conv1d( + residual_channels, residual_channels, kernel_size=3 + ).to(self.device) + self.output_projection = nn.Conv1d( + residual_channels, action_dim, kernel_size=3 + ).to(self.device) + + nn.init.kaiming_normal_(self.input_projection.weight) + nn.init.kaiming_normal_(self.skip_projection.weight) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, + x: torch.Tensor, + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + x_coord: torch.Tensor, + cond: torch.Tensor, + timesteps: torch.Tensor, + batch: torch.Tensor = None): + """ + Args: + x: (B, pred_horizon, action_dim) flat noisy action + edge_index: (2, E) + edge_attr: (E,) or (E, 1) + x_coord: (N_total, 3) + cond: (N_total, obs_horizon, cond_feature_dim+1) graph obs (+node-id) + timesteps: (B,) + batch: (N_total,) node-to-graph mapping for EGraphConditionEncoder + + Returns: + noise_pred: (B, pred_horizon, action_dim) + x_coord: (N_total, 3) unchanged + """ + x = x.float().to(self.device) + edge_attr = edge_attr.float().to(self.device) + edge_index = edge_index.to(self.device) + x_coord = x_coord.float().to(self.device) + timesteps = 
timesteps.to(self.device) + if batch is None: + batch = torch.zeros(x_coord.shape[0], dtype=torch.long, device=self.device) + else: + batch = batch.long().to(self.device) + + ids = cond[:, 0, -1].long().to(self.device) + cond_feats = cond[:, :, :-1].float().to(self.device) + + edge_attr_1d = edge_attr.reshape(-1) + edge_index_sl, edge_attr_sl = add_self_loops( + edge_index, edge_attr_1d, + num_nodes=x_coord.shape[0], fill_value=0.0 + ) + + # Graph-level conditioning: (B, cond_channels) + graph_cond = self.cond_encoder( + cond_feats, edge_index_sl, x_coord, edge_attr_sl.unsqueeze(-1), + batch=batch, ids=ids + ) + + # Up-sample conditioning to pred_horizon: (B, pred_horizon) -> (B, 1, pred_horizon) + cond_up = self.cond_upsampler(graph_cond).unsqueeze(1) + + # Diffusion step embedding: (B, hidden_dim) + diffusion_step = self.diffusion_embedding(timesteps) + + # x: (B, T, Da) -> (B, Da, T) for Conv1d + x_conv = x.permute(0, 2, 1) + h = F.leaky_relu(self.input_projection(x_conv), 0.4) + + skip_sum = None + for block in self.residual_blocks: + h, skip = block(h, cond_up, diffusion_step) + if skip_sum is None: + skip_sum = skip + else: + min_T = min(skip_sum.shape[-1], skip.shape[-1]) + skip_sum = (skip_sum[..., :min_T] + skip[..., :min_T]) + + n_layers = len(self.residual_blocks) + skip_sum = skip_sum / math.sqrt(n_layers) + + out = F.leaky_relu(self.skip_projection(skip_sum), 0.4) + out = self.output_projection(out) # (B, Da, T''') + + T_out = out.shape[-1] + if T_out >= self.pred_horizon: + out = out[..., :self.pred_horizon] + else: + pad = torch.zeros( + out.shape[0], out.shape[1], self.pred_horizon - T_out, + device=self.device + ) + out = torch.cat([out, pad], dim=-1) + + noise_pred = out.permute(0, 2, 1) # (B, T, Da) + return noise_pred, x_coord diff --git a/imitation/policy/osc_ddpm_policy.py b/imitation/policy/osc_ddpm_policy.py new file mode 100644 index 0000000..00cf1bb --- /dev/null +++ b/imitation/policy/osc_ddpm_policy.py @@ -0,0 +1,346 @@ +import logging 
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.auto import tqdm
from diffusers.optimization import get_scheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
import wandb

from imitation.policy.base_policy import BasePolicy
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
from torch_geometric.data import DataLoader

log = logging.getLogger(__name__)


class OSCGraphDDPMPolicy(BasePolicy):
    """
    DDPM policy for OSC_POSE control.

    The graph is used *only* for observation encoding. The action space is a
    flat (B, pred_horizon, action_dim) tensor (e.g. 7-D EEF vector for OSC_POSE)
    with no graph structure.

    Differences from GraphConditionalDDPMPolicy:
      - No `node_feature_dim` — actions are flat, not per-node.
      - `last_naction` shape is (1, pred_horizon, action_dim).
      - Training reshapes batch.y from (action_dim*B, pred_horizon, 1) to
        (B, pred_horizon, action_dim) before computing diffusion loss.
      - Inference initialises noise as (1, pred_horizon, action_dim).
    """

    def __init__(self,
                 action_dim: int,
                 num_edge_types: int,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int,
                 num_diffusion_iters: int,
                 dataset: BaseLowdimDataset,
                 denoising_network: nn.Module,
                 ckpt_path=None,
                 lr: float = 1e-4,
                 batch_size: int = 256,
                 use_normalization: bool = True,
                 keep_first_action: bool = True):
        """
        Args:
            action_dim: dimensionality of the flat action vector (7 for OSC_POSE).
            num_edge_types: kept for config symmetry with the graph policy
                (not read anywhere in this class).
            pred_horizon: number of action steps predicted per diffusion call.
            obs_horizon: number of past observation steps conditioned on.
            action_horizon: number of predicted steps actually executed.
            num_diffusion_iters: DDPM training/inference timesteps.
            dataset: dataset providing normalize/unnormalize helpers.
            denoising_network: noise-prediction net (e.g. GDDPMNoisePred).
            ckpt_path: optional checkpoint path (stored, not auto-loaded here).
            lr / batch_size / use_normalization / keep_first_action:
                training and inference hyper-parameters.
        """
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.action_dim = action_dim
        self.ckpt_path = ckpt_path

        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.action_horizon = action_horizon
        self.num_diffusion_iters = num_diffusion_iters
        self.lr = lr
        self.use_normalization = use_normalization
        self.keep_first_action = keep_first_action
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        log.info(f"Using device {self.device}")

        self.noise_pred_net = denoising_network.to(self.device)
        # NOTE(review): this ALIASES the raw network — ema_noise_pred_net and
        # noise_pred_net are the same module object. The EMAModel created in
        # train() maintains shadow weights internally but they are never
        # copied out, so save_nets() saves the raw (non-EMA) weights.
        # Confirm whether a copy.deepcopy was intended here.
        self.ema_noise_pred_net = self.noise_pred_net.to(self.device)

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=self.num_diffusion_iters,
            beta_schedule='squaredcos_cap_v2',
            clip_sample=True,
            prediction_type='epsilon',
            beta_start=1e-4,
            beta_end=2e-2,
        )

        # Lazily created in train() so a policy used for inference only
        # never allocates optimizer state.
        self.lr_scheduler = None
        self.optimizer = None
        self.num_epochs = None

        self.global_epoch = 0
        # (1, pred_horizon, action_dim)
        self.last_naction = torch.zeros(
            (1, self.pred_horizon, self.action_dim), device=self.device
        )
        self.playback_count = 0

    def load_nets(self, ckpt_path):
        """Load checkpoint weights into the (aliased) EMA network.

        Falls back to the raw network on any failure; never raises.
        """
        if ckpt_path is None:
            log.info('No pretrained weights given.')
            self.ema_noise_pred_net = self.noise_pred_net.to(self.device)
            return
        if not os.path.isfile(ckpt_path):
            log.error(f"Pretrained weights not found at {ckpt_path}.")
            self.ema_noise_pred_net = self.noise_pred_net.to(self.device)
            return
        try:
            state_dict = torch.load(ckpt_path, map_location=self.device)
            self.ema_noise_pred_net = self.noise_pred_net
            self.ema_noise_pred_net.load_state_dict(state_dict)
            self.ema_noise_pred_net.to(self.device)
            log.info('Pretrained weights loaded.')
        except Exception:
            # NOTE(review): broad except deliberately keeps the policy usable
            # with random weights, but it also hides the load error — consider
            # logging the exception itself.
            log.error('Error loading pretrained weights.')
            self.ema_noise_pred_net = self.noise_pred_net.to(self.device)

    def save_nets(self, ckpt_path):
        """Persist the current network weights (see EMA aliasing note in __init__)."""
        torch.save(self.ema_noise_pred_net.state_dict(), ckpt_path)
        log.info(f"Model saved at {ckpt_path}")

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def get_action(self, obs_deque):
        """
        obs_deque: deque of PyG Data objects (length == obs_horizon).
        Returns: action (action_horizon, action_dim) numpy array.

        NOTE(review): in the normalized branch the return value is whatever
        `unnormalize_data` yields (likely a tensor), while the unnormalized
        branch returns a numpy array — callers should not rely on the type.
        """
        obs_cond = []
        G_t = obs_deque[-1]
        for i in range(len(obs_deque)):
            obs_cond.append(obs_deque[i].x.unsqueeze(1))
        obs = torch.cat(obs_cond, dim=1)  # (N_nodes, obs_horizon, feat)

        if self.use_normalization:
            nobs = self.dataset.normalize_data(obs, stats_key='obs')
            nobs[:, :, -1] = obs[:, :, -1]  # preserve node IDs
        else:
            nobs = obs

        with torch.no_grad():
            # Start the reverse diffusion from pure Gaussian noise.
            noisy_action = torch.randn(
                (1, self.pred_horizon, self.action_dim), device=self.device
            )

            if self.keep_first_action:
                # Anchor step 0 to the tail of the previously executed plan
                # for temporal continuity between planning cycles.
                noisy_action[:, 0, :] = self.last_naction[:, -1, :]

            # Single graph at inference: all nodes map to graph 0.
            batch_idx = torch.zeros(
                G_t.x.shape[0], dtype=torch.long, device=self.device
            )

            self.noise_scheduler.set_timesteps(self.num_diffusion_iters)

            for k in self.noise_scheduler.timesteps:
                noise_pred, _ = self.ema_noise_pred_net(
                    x=noisy_action,
                    edge_index=G_t.edge_index,
                    edge_attr=G_t.edge_attr,
                    x_coord=G_t.pos[:, :3],
                    cond=nobs,
                    timesteps=torch.tensor([k], dtype=torch.long, device=self.device),
                    batch=batch_idx,
                )
                noisy_action = self.noise_scheduler.step(
                    model_output=noise_pred,
                    timestep=k,
                    sample=noisy_action,
                ).prev_sample

                # Re-anchor after every denoising step so the constraint
                # survives the scheduler update.
                if self.keep_first_action:
                    noisy_action[:, 0, :] = self.last_naction[:, -1, :]

        naction = noisy_action.detach().cpu()  # (1, pred_horizon, action_dim)
        # Kept in normalized space (when normalization is on) for the next
        # call's keep_first_action anchor.
        self.last_naction = naction

        if self.use_normalization:
            # Reshape to per-node format for unnormalize_data, then back
            naction_node = (
                naction.permute(0, 2, 1)   # (1, action_dim, pred_horizon)
                .unsqueeze(-1)             # (1, action_dim, pred_horizon, 1)
                .reshape(self.action_dim, self.pred_horizon, 1)
            )
            naction_node = self.dataset.unnormalize_data(naction_node, stats_key='action')
            # (action_horizon, action_dim)
            action = naction_node[:, :self.action_horizon, 0].T
        else:
            action = naction[0, :self.action_horizon, :].numpy()

        return action

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate(self, dataset=None, model_path="last.pt"):
        """Mean diffusion MSE over `dataset`; reloads weights from model_path
        on every call."""
        log.info('Validating noise prediction network.')
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        self.load_nets(model_path)
        self.ema_noise_pred_net.eval()

        with torch.no_grad():
            val_loss = []
            for batch in dataloader:
                B = batch.num_graphs
                nobs = batch.x
                if self.use_normalization:
                    nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device)
                    nobs[:, :, -1] = batch.x[:, :, -1]
                    naction_raw = self.dataset.normalize_data(batch.y, stats_key='action').to(self.device)
                else:
                    naction_raw = batch.y.to(self.device)

                # batch.y: (action_dim*B, pred_horizon, 1)
                # reshape to (B, pred_horizon, action_dim)
                # (PyG concatenates per-graph y graph-major, so this view
                # recovers the per-graph grouping.)
                naction = naction_raw.view(B, self.action_dim, self.pred_horizon, 1)
                naction = naction[:, :, :, 0].permute(0, 2, 1)  # (B, T, Da)

                # One random diffusion step per sample.
                timesteps = torch.randint(
                    0, self.noise_scheduler.config.num_train_timesteps,
                    (B,), device=self.device
                ).long()

                noise = torch.randn_like(naction)
                noisy_actions = self.noise_scheduler.add_noise(naction, noise, timesteps)

                if self.keep_first_action:
                    # Mirror inference: the model always sees a clean step 0.
                    noisy_actions[:, 0, :] = naction[:, 0, :]

                obs_cond = nobs.float()
                noisy_actions = noisy_actions.float()

                noise_pred, _ = self.ema_noise_pred_net(
                    noisy_actions,
                    batch.edge_index,
                    batch.edge_attr,
                    x_coord=batch.pos[:, :3],
                    cond=obs_cond,
                    timesteps=timesteps,
                    batch=batch.batch,
                )
                loss = F.mse_loss(noise_pred, noise)
                val_loss.append(loss.item())

        return np.mean(val_loss)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def train(self,
              dataset=None,
              num_epochs=100,
              model_path="last.pt",
              seed=0):
        """Standard DDPM training loop; saves a checkpoint every epoch.

        Args:
            dataset: PyG dataset of (x=obs graph, y=actions) samples.
            num_epochs: epochs for THIS call (self.num_epochs, once set,
                fixes the LR schedule length).
            model_path: checkpoint file written after each epoch.
            seed: torch RNG seed.
        """
        log.info('Training noise prediction network.')

        if self.num_epochs is None:
            log.warn(f"Global num_epochs not set. Using {num_epochs}.")
            self.num_epochs = num_epochs

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        # EMA shadow of the network parameters.
        # NOTE(review): these EMA weights are never copied back into
        # ema_noise_pred_net (which aliases noise_pred_net), so checkpoints
        # contain raw weights — see __init__.
        ema = EMAModel(parameters=self.ema_noise_pred_net.parameters(), power=0.75)

        self.noise_pred_net.to(self.device)
        if self.optimizer is None:
            self.optimizer = torch.optim.AdamW(
                params=self.noise_pred_net.parameters(),
                lr=self.lr,
                weight_decay=1e-6,
                betas=[0.95, 0.999],
                eps=1e-8,
            )

        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                name='cosine',
                optimizer=self.optimizer,
                num_warmup_steps=500,
                num_training_steps=len(dataloader) * self.num_epochs,
            )

        with tqdm(range(num_epochs), desc='Epoch') as tglobal:
            for epoch_idx in tglobal:
                epoch_loss = []
                with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
                    for batch in tepoch:
                        B = batch.num_graphs

                        nobs = batch.x
                        if self.use_normalization:
                            nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device)
                            # Node-ID column must stay integral (embedding index).
                            nobs[:, :, -1] = batch.x[:, :, -1]
                            naction_raw = self.dataset.normalize_data(batch.y, stats_key='action').to(self.device)
                        else:
                            naction_raw = batch.y.to(self.device)

                        # batch.y: (action_dim*B, pred_horizon, 1) from PyG DataLoader
                        # Reshape to (B, pred_horizon, action_dim)
                        naction = naction_raw.view(B, self.action_dim, self.pred_horizon, 1)
                        naction = naction[:, :, :, 0].permute(0, 2, 1)  # (B, T, Da)

                        timesteps = torch.randint(
                            0, self.noise_scheduler.config.num_train_timesteps,
                            (B,), device=self.device
                        ).long()

                        # Forward diffusion: corrupt clean actions with noise
                        # scaled by each sample's timestep.
                        noise = torch.randn_like(naction)
                        noisy_actions = self.noise_scheduler.add_noise(naction, noise, timesteps)

                        if self.keep_first_action:
                            # NOTE(review): step 0 of the input is replaced by
                            # the clean action, but the MSE target below still
                            # includes the noise at step 0 — confirm intended.
                            noisy_actions[:, 0, :] = naction[:, 0, :]

                        noisy_actions = noisy_actions.float()
                        obs_cond = nobs.float()

                        noise_pred, _ = self.noise_pred_net(
                            noisy_actions,
                            batch.edge_index,
                            batch.edge_attr,
                            x_coord=batch.pos[:, :3],
                            cond=obs_cond,
                            timesteps=timesteps,
                            batch=batch.batch,
                        )

                        loss = F.mse_loss(noise_pred, noise)
                        # Per-batch logging (loss is a 0-d tensor; wandb accepts it).
                        wandb.log({
                            'noise_pred_loss': loss,
                            'lr': self.lr_scheduler.get_last_lr()[0],
                        })

                        loss.backward()
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        self.lr_scheduler.step()
                        # Update the EMA shadow after each optimizer step.
                        ema.step(self.noise_pred_net.parameters())

                        loss_cpu = loss.item()
                        epoch_loss.append(loss_cpu)
                        tepoch.set_postfix(loss=loss_cpu)

                tglobal.set_postfix(loss=np.mean(epoch_loss))
                wandb.log({'epoch': self.global_epoch, 'epoch_loss': np.mean(epoch_loss)})
                self.save_nets(model_path)
                self.global_epoch += 1
                tglobal.set_description(f"Epoch: {self.global_epoch}")
From 13093c8e7f800aabf7e47bc41dc4ed73e648b020 Mon Sep 17 00:00:00 2001
From: Caio Freitas
Date: Mon, 9 Mar 2026 12:56:18 +0100
Subject: [PATCH 05/22] git commit -m " Refactor graph data convention: x=observations, y=actions "

---
 imitation/dataset/robomimic_graph_dataset.py | 90 ++++----
 imitation/env/robomimic_graph_wrapper.py | 210 ++++++++++---------
 imitation/policy/graph_ddpm_policy.py | 78 +++----
 3 files changed, 188 insertions(+), 190 deletions(-)

diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py
index 97fe720..4d81b04 100644
--- a/imitation/dataset/robomimic_graph_dataset.py
+++ b/imitation/dataset/robomimic_graph_dataset.py
@@ -77,11 +77,6 @@ def
processed_file_names(self): names = [f"data_{i}.pt" for i in range(self.len())] return names - @lru_cache(maxsize=None) - def _get_object_feats(self, num_objects, node_feature_dim, T): # no associated joint value - # create tensor of same dimension return super()._get_node_feats(data, t) as node_feats - obj_state_tensor = torch.zeros((num_objects, T, node_feature_dim)) - return obj_state_tensor def _get_object_pos(self, data, t): obj_state_tensor = torch.zeros((self.num_objects, 9)) # 3 for position, 6 for 6D rotation @@ -117,7 +112,7 @@ def _get_node_pos(self, data, t): node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) return node_pos - def _get_node_feats(self, data, t_vals, control_mode=None): + def _get_target_actions(self, data, t_vals, actions, control_mode=None): ''' Calculate node features for time steps t_vals t_vals: list of time steps @@ -127,8 +122,9 @@ def _get_node_feats(self, data, t_vals, control_mode=None): if control_mode is None: control_mode = self.control_mode if control_mode == "OSC_POSE": - for i in range(self.num_robots): - node_feats.append(torch.cat([torch.tensor(data["robot0_eef_pos"][t_vals]), torch.tensor(data["robot0_eef_quat"][t_vals])], dim=0)) + # actions_raw is (episode_len, 7) — index time axis first, action dims second + action_slice = torch.tensor(np.array(actions[t_vals, :]), dtype=torch.float32) # (T, 7) + node_feats.append(action_slice.T.unsqueeze(2)) # (7, T, 1) elif control_mode == "JOINT_POSITION": for i in range(self.num_robots): node_feats.append(torch.cat([ @@ -140,55 +136,46 @@ def _get_node_feats(self, data, t_vals, control_mode=None): torch.tensor(data[f"robot{i}_joint_vel"][t_vals]), torch.tensor(data[f"robot{i}_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2)) node_feats = torch.cat(node_feats, dim=0) # [num_robots*num_joints, T, 1] - obj_state_tensor = self._get_object_feats(self.num_objects, self.node_feature_dim, T) - - # add dimension for NODE_TYPE - node_feats = torch.cat((node_feats, 
self.ROBOT_NODE_TYPE*torch.ones((node_feats.shape[0], node_feats.shape[1],1))), dim=2) - obj_state_tensor[:, :, -1] = self.OBJECT_NODE_TYPE - - node_feats = torch.cat((node_feats, obj_state_tensor), dim=0) return node_feats - def get_y_feats(self, data, t_vals): + def _get_x_feats(self, data, t_vals): ''' Calculate observation node features for time steps t_vals ''' T = len(t_vals) - y = [] + x = [] for i in range(self.num_robots): - y.append(torch.cat([ + x.append(torch.cat([ torch.tensor(data[f"robot{i}_joint_pos"][t_vals]), torch.tensor(data[f"robot{i}_gripper_qpos"][t_vals])], dim=1).T.unsqueeze(2)) - y = torch.cat(y, dim=0) # [num_robots*num_joints, T, 1] + x = torch.cat(x, dim=0) # [num_robots*num_joints, T, 1] obj_state_tensor = [] for t in t_vals: obj_state_tensor.append(self._get_object_pos(data, t)) obj_state_tensor = torch.stack(obj_state_tensor, dim=1) - # remove positions - obj_state_tensor = obj_state_tensor[:,:,3:] # add dimensions to match with obj_state_tensor for concatenation - y = torch.cat((y, torch.zeros((y.shape[0], obj_state_tensor.shape[1], obj_state_tensor.shape[2] - 1))), dim=2) + x = torch.cat((x, torch.zeros((x.shape[0], obj_state_tensor.shape[1], obj_state_tensor.shape[2] - 1))), dim=2) - y = torch.cat((y, obj_state_tensor), dim=0) + x = torch.cat((x, obj_state_tensor), dim=0) - # Add node ID to node features - node_id = torch.arange(y.shape[0]).unsqueeze(1).unsqueeze(2).repeat(1, T, 1) - y = torch.cat((y, node_id), dim=2) + # add column for node ID (used as embedding index by the model) + num_nodes = x.shape[0] + node_ids = torch.arange(num_nodes, dtype=torch.float32).unsqueeze(1).unsqueeze(1) # (N, 1, 1) + node_ids = node_ids.expand(-1, x.shape[1], -1) # (N, T, 1) + x = torch.cat((x, node_ids), dim=2) - return y + return x - def _get_node_feats_horizon(self, data, idx, horizon): + def _get_target_actions_horizon(self, data, idx, horizon, actions): ''' - Calculate node features for self.obs_horizon time steps + Calculate node features 
for self.pred_horizon time steps. + For OSC_POSE, ``actions`` is the episode-level actions array (Tx7). ''' - node_feats = [] - # calculate node features for timesteps idx to idx + horizon t_vals = list(range(idx, idx + horizon)) - node_feats = self._get_node_feats(data, t_vals) - return node_feats + return self._get_target_actions(data, t_vals, actions) @lru_cache(maxsize=None) def _get_edge_attrs(self, edge_index): @@ -228,17 +215,17 @@ def _get_edge_index(self, num_nodes): edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() return edge_index - def _get_y_horizon(self, data, idx, horizon): + def _get_x_horizon(self, data, idx, horizon): ''' Get y (observation) for time step t. Should contain only task-space joint positions. ''' y = [] if idx - horizon < 0: - y.append(self.get_y_feats(data, [0]).repeat(1,horizon-idx,1)) # use fixed first observation for beginning of episode - y.append(self.get_y_feats(data, [t for t in range(0, idx)])) + y.append(self._get_x_feats(data, [0]).repeat(1,horizon-idx,1)) # use fixed first observation for beginning of episode + y.append(self._get_x_feats(data, [t for t in range(0, idx)])) y = torch.cat(y, dim=1) else: # get all observation steps with single call - y = self.get_y_feats(data, list(range(idx - horizon, idx))) + y = self._get_x_feats(data, list(range(idx - horizon, idx))) return y @@ -247,21 +234,25 @@ def process(self): for key in tqdm(self.dataset_keys): episode_length = self.dataset_root[f"data/{key}/obs/object"].shape[0] - + data_raw = self.dataset_root["data"][key]["obs"] + # actions are stored alongside obs, not inside it + actions_raw = self.dataset_root["data"][key]["actions"] # (episode_len, 7) for OSC_POSE + for idx in range(1, episode_length - self.pred_horizon): - - data_raw = self.dataset_root["data"][key]["obs"] - node_feats = self._get_node_feats_horizon(data_raw, idx, self.pred_horizon) - edge_index = self._get_edge_index(node_feats.shape[0]) + actions = 
self._get_target_actions_horizon(data_raw, idx, self.pred_horizon, + actions=actions_raw) + observations= self._get_x_horizon(data_raw, idx, self.obs_horizon) + # edge_index must cover ALL observation nodes (robot + objects) + # so that PyG batching (which uses x's node count) is consistent + edge_index = self._get_edge_index(observations.shape[0]) edge_attrs = self._get_edge_attrs(edge_index) - y = self._get_y_horizon(data_raw, idx, self.obs_horizon) - pos = self._get_node_pos(data_raw, idx - 1) + pos = self._get_node_pos(data_raw, idx) data = Data( - x=node_feats, + x=observations, edge_index=edge_index, edge_attr=edge_attrs, - y=y, + y=actions, time=torch.tensor([idx], dtype=torch.long)/ episode_length, pos=pos ) @@ -306,8 +297,8 @@ def get_normalizer(self): data_action = [] for idx in range(self.len()): data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt')) - data_obs.append(data["y"]) - data_action.append(data["x"]) + data_obs.append(data["x"]) # x = observations + data_action.append(data["y"]) # y = actions data_obs = torch.cat(data_obs, dim=1) data_action = torch.cat(data_action, dim=1) @@ -329,5 +320,6 @@ def to_obs_deque(self, data): return obs_deque def get_action(self, data): - return data.x[:self.eef_idx[-1] + 1,:,0].T.numpy() + # y = actions; for OSC_POSE: (7+obj, T, 1) → extract 7 action dims + return data.y[:self.eef_idx[-1] + 1,:,0].T.numpy() \ No newline at end of file diff --git a/imitation/env/robomimic_graph_wrapper.py b/imitation/env/robomimic_graph_wrapper.py index 88d52c2..e6d2ed6 100644 --- a/imitation/env/robomimic_graph_wrapper.py +++ b/imitation/env/robomimic_graph_wrapper.py @@ -81,7 +81,7 @@ def __init__(self, has_offscreen_renderer=output_video, # not needed since not using pixel obs has_renderer=has_renderer, # make sure we can render to the screen reward_shaping=True, # use dense rewards - control_freq=30, # control should happen fast enough so that simulation looks smooth + control_freq=20, # must match dataset recording 
frequency (20 Hz for lift/ph) horizon=max_steps, # long horizon so we can sample high rewards controller_configs=controller_config, ), @@ -107,6 +107,9 @@ def __init__(self, self.eef_idx = [-1, 8] # end-effector index if self.num_robots == 2: self.eef_idx += [17] + # Initialise the last action buffer (used by _get_node_feats for OSC_POSE). + # For OSC_POSE, actions are 7D: [dx, dy, dz, ax, ay, az, gripper] + self._last_osc_action = np.zeros(7 * self.num_robots, dtype=np.float32) def scaled_tanh(self, x, max_val=0.01, min_val=-0.07, k=200, threshold=-0.03): @@ -114,7 +117,7 @@ def scaled_tanh(self, x, max_val=0.01, min_val=-0.07, k=200, threshold=-0.03): def control_loop(self, tgt_jpos, max_n=20, eps=0.02): obs = self.env._get_observations() - tgt_jpos[-1] = self.scaled_tanh(tgt_jpos[-1]) + # tgt_jpos[-1] = self.scaled_tanh(tgt_jpos[-1]) for i in range(max_n): obs = self.env._get_observations() current_jpos = [] @@ -131,26 +134,36 @@ def control_loop(self, tgt_jpos, max_n=20, eps=0.02): if self.has_renderer: self.env.render() return obs_final, reward, done, _, info - - def _get_object_feats(self, data): - # create tensor of same dimension return super()._get_node_feats(data, t) as node_feats - obj_state_tensor = torch.zeros((self.num_objects, self.node_feature_dim)) - return obj_state_tensor + def _get_object_pos(self, data): obj_state_tensor = torch.zeros((self.num_objects, 9)) # 3 for position, 6 for rotation + obj_buf = data["object"] + buf_len = len(obj_buf) for object, object_state_items in enumerate(self.object_state_keys.values()): - i = 0 + i = 0 # offset into the flat object buffer + out_col = 0 # column in obj_state_tensor (quat->6d expands by 2) for object_state in object_state_items: + field_size = self.object_state_sizes[object_state] + if i + field_size > buf_len: + # Field not present in this observation buffer — leave as zeros + out_col += 6 if "quat" in object_state else field_size + i += field_size + continue if "quat" in object_state: - assert 
self.object_state_sizes[object_state] == 4, "Quaternion must have size 4" - rot = self.rotation_transformer.forward(torch.tensor(data["object"][i:i + self.object_state_sizes[object_state]])) - obj_state_tensor[object,i:i + 6] = rot + assert field_size == 4, "Quaternion must have size 4" + rot = self.rotation_transformer.forward( + torch.tensor(obj_buf[i:i + field_size], dtype=torch.float32) + ) + obj_state_tensor[object, out_col:out_col + 6] = rot + out_col += 6 else: - obj_state_tensor[object,i:i + self.object_state_sizes[object_state]] = torch.from_numpy(data["object"][i:i + self.object_state_sizes[object_state]]) - - i += self.object_state_sizes[object_state] + obj_state_tensor[object, out_col:out_col + field_size] = torch.from_numpy( + obj_buf[i:i + field_size] + ) + out_col += field_size + i += field_size return obj_state_tensor @@ -172,46 +185,29 @@ def _get_node_pos(self, data): node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) return node_pos - - - def _get_node_feats(self, data, control_mode=None): - ''' - Returns node features from data - ''' - if control_mode is None: - control_mode = self.control_mode - node_feats = [] - for i in range(self.num_robots): - if control_mode == "OSC_POSE": - node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_quat"])], dim=0).reshape(1, -1)) # add dimension - elif control_mode == "JOINT_VELOCITY": - node_feats.append(torch.tensor([*data[f"robot{i}_joint_vel"], *data[f"robot{i}_gripper_qvel"]]).reshape(1,-1).T) - elif control_mode == "JOINT_POSITION": - node_feats.append(torch.tensor([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]]).reshape(1,-1).T) - node_feats = torch.cat(node_feats, dim=0) - return node_feats - - def _get_y_feats(self, data): + def _get_x_feats(self, data): ''' - Returns observation node features from data + Returns observation node features from data. + Output shape: (num_nodes, obs_feat_dim) where obs_feat_dim includes node_type. 
''' - y = [] + x = [] for i in range(self.num_robots): - y.append(torch.tensor([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]]).reshape(1,-1).T) - y = torch.cat(y, dim=0) + x.append(torch.tensor([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]], + dtype=torch.float32).reshape(-1, 1)) # (9, 1) + x = torch.cat(x, dim=0) # (9, 1) obj_state_tensor = self._get_object_pos(data) - # remove positions - obj_state_tensor = obj_state_tensor[:,3:] - # add dimensions to match with obj_state_tensor for concatenation - y = torch.cat([y, torch.zeros((y.shape[0], obj_state_tensor.shape[1]-y.shape[1]))], dim=1) - y = torch.cat([y, obj_state_tensor], dim=0) + # pad robot features to match object feature width for concatenation + x = torch.cat([x, torch.zeros((x.shape[0], obj_state_tensor.shape[1] - x.shape[1]))], dim=1) # (9, 9) + x = torch.cat([x, obj_state_tensor], dim=0) # (10, 9) - # add node ID to node features - node_id = torch.arange(y.shape[0]).reshape(-1, 1) - y = torch.cat((y, node_id), dim=-1) - return y + # add column for node ID (used as embedding index by the model) + num_nodes = x.shape[0] + node_ids = torch.arange(num_nodes, dtype=torch.float32).unsqueeze(1) # (N, 1) + x = torch.cat([x, node_ids], dim=1) # (10, 7) + + return x @lru_cache(maxsize=128) def _get_edge_index(self, num_nodes): @@ -254,57 +250,64 @@ def _get_edge_attrs(self, edge_index): def _robosuite_obs_to_robomimic_graph(self, obs): ''' - Converts robosuite Gym Wrapper observations to the RobomimicGraphDataset format - * requires robot_joint to be "flagged" in robomimic environment + Converts robosuite Gym Wrapper (robot0_proprio-state, object-state) flat + observations into the RobomimicGraphDataset format. 
+ + robot0_proprio-state layout (32 elements per robot): + [0:7] cos(joint_pos) — 7 joint cosines + [7:14] sin(joint_pos) — 7 joint sines + [14:21] joint_vel — 7 joint velocities + [21:24] eef_pos — 3D end-effector position + [24:28] eef_quat_raw — 4D end-effector quaternion + [28:30] gripper_qpos — 2 gripper finger positions + [30:32] gripper_qvel — 2 gripper finger velocities ''' - node_feats = torch.tensor([]) - node_pos = torch.tensor([]) - node_obs = torch.tensor([]) + PROPRIO_SIZE = 32 # elements per robot in proprio-state robot_i_data = {} for i in range(self.num_robots): - j = i*39 - - # 7 - joint angle values - robot_joint_pos = obs[j:j + 7] - # 7 - sin of joint angles - # robot_joint_sin = obs[j + 7:j + 14] - # 7 - cos of joint angles - # robot_joint_cos = obs[j + 14:j + 21] - # 7 - joint velocities - robot_joint_vel = obs[j + 21:j + 28] - eef_pose = obs[j + 28:j + 31] - eef_quat = obs[j + 31:j + 35] - eef_6d = self.rotation_transformer.forward(eef_quat) - gripper_pose = obs[j + 35:j + 37] - gripper_vel = obs[j + 37:j + 39] - # Skip 2 - gripper joint velocities + j = i * PROPRIO_SIZE + + # Reconstruct raw joint_pos from sin/cos (robosuite stores sin/cos, not raw) + joint_cos = obs[j:j + 7] + joint_sin = obs[j + 7:j + 14] + robot_joint_pos = np.arctan2(joint_sin, joint_cos) + + robot_joint_vel = obs[j + 14:j + 21] + eef_pose = obs[j + 21:j + 24] + eef_quat_raw = obs[j + 24:j + 28] + eef_6d = self.rotation_transformer.forward( + torch.tensor(eef_quat_raw, dtype=torch.float32) + ) + gripper_pose = obs[j + 28:j + 30] + gripper_vel = obs[j + 30:j + 32] + robot_i_data.update({ - f"robot{i}_joint_pos": robot_joint_pos, - f"robot{i}_joint_vel": robot_joint_vel, - f"robot{i}_eef_pos": eef_pose, - f"robot{i}_eef_quat": eef_6d, + f"robot{i}_joint_pos": robot_joint_pos, + f"robot{i}_joint_vel": robot_joint_vel, + f"robot{i}_eef_pos": eef_pose, + f"robot{i}_eef_quat": eef_6d, f"robot{i}_gripper_qpos": gripper_pose, - f"robot{i}_gripper_qvel": gripper_vel + 
f"robot{i}_gripper_qvel": gripper_vel, + f"_osc_action_{i}": self._last_osc_action[i*7:(i+1)*7], }) - robot_i_data["object"] = obs[self.num_robots*39:] - - node_feats = torch.cat([node_feats, self._get_node_feats(robot_i_data)], dim=0) + robot_i_data["object"] = obs[self.num_robots * PROPRIO_SIZE:] node_pos = self._get_node_pos(robot_i_data) - y = torch.cat([node_obs, self._get_y_feats(robot_i_data)], dim=0) - obj_feats_tensor = self._get_object_feats(obs) + observations = self._get_x_feats(robot_i_data) - # add dimension for NODE_TYPE, which is 0 for robot and 1 for objects - node_feats = torch.cat((node_feats, self.ROBOT_NODE_TYPE*torch.ones((node_feats.shape[0],1))), dim=1) - obj_feats_tensor[:, -1] = self.OBJECT_NODE_TYPE - - node_feats = torch.cat((node_feats, obj_feats_tensor), dim=0) - - edge_index = self._get_edge_index(node_feats.shape[0]) + # Use total node count (robot + objects) for edge computation + num_nodes = observations.shape[0] + edge_index = self._get_edge_index(num_nodes) edge_attrs = self._get_edge_attrs(edge_index) - # create graph - graph = torch_geometric.data.Data(x=node_feats, edge_index=edge_index, edge_attr=edge_attrs, y=y, pos=node_pos) + # create graph: x = observations, y is not set at inference + # (the policy generates actions via diffusion, it doesn't need y) + graph = torch_geometric.data.Data( + x=observations, + edge_index=edge_index, + edge_attr=edge_attrs, + pos=node_pos + ) return graph @@ -315,22 +318,25 @@ def reset(self): def step(self, action): - final_action = [] - for i in range(self.num_robots): - ''' - Robosuite's action space is composed of 7 joint velocities and 1 gripper velocity, while - in the robomimic datasets, it's composed of 7 joint velocities and 2 gripper velocities (for each "finger"). 
- ''' - j = i*9 - robot_joint_pos = action[j:j + 7] - robot_gripper_pos = action[j + 8] - final_action = [*final_action, *robot_joint_pos, robot_gripper_pos] - if self.control_mode == "JOINT_VELOCITY": - obs, reward, done, _, info = self.env.step(final_action) - elif self.control_mode == "JOINT_POSITION": - obs, reward, done, _, info = self.control_loop(final_action) + if self.control_mode == "OSC_POSE": + # OSC_POSE actions are 7-D EEF vectors — pass through directly. + obs, reward, done, _, info = self.env.step(action) else: - raise ValueError("Invalid control mode") + # JOINT_VELOCITY / JOINT_POSITION: action is a 9-D graph vector per robot + # (7 joint DOF + 2 gripper fingers). Robosuite expects 8-D (7 + 1 gripper). + # Use finger 0 (index j+7) — the two fingers move symmetrically. + final_action = [] + for i in range(self.num_robots): + j = i * 9 + robot_joint_pos = action[j:j + 7] + robot_gripper_pos = action[j + 7] + final_action = [*final_action, *robot_joint_pos, robot_gripper_pos] + if self.control_mode == "JOINT_VELOCITY": + obs, reward, done, _, info = self.env.step(final_action) + elif self.control_mode == "JOINT_POSITION": + obs, reward, done, _, info = self.control_loop(final_action) + else: + raise ValueError(f"Invalid control mode: {self.control_mode}") if reward == 1: done = True diff --git a/imitation/policy/graph_ddpm_policy.py b/imitation/policy/graph_ddpm_policy.py index ab6d4f5..3bd24cf 100644 --- a/imitation/policy/graph_ddpm_policy.py +++ b/imitation/policy/graph_ddpm_policy.py @@ -103,8 +103,8 @@ def save_nets(self, ckpt_path): def MOCK_get_graph_from_obs(self): # for testing purposes, remove before merge # plays back observation from dataset playback_graph = self.dataset[self.playback_count] - obs_cond = playback_graph.y - playback_graph.x = playback_graph.x[:,0,:] + obs_cond = playback_graph.x + playback_graph.y = playback_graph.y[:,0,:] self.playback_count += 7 log.info(f"Playing back observation {self.playback_count}") return obs_cond, 
playback_graph @@ -116,24 +116,29 @@ def get_action(self, obs_deque): pos = [] G_t = obs_deque[-1] for i in range(len(obs_deque)): - obs_cond.append(obs_deque[i].y.unsqueeze(1)) + obs_cond.append(obs_deque[i].x.unsqueeze(1)) # x = observations pos.append(obs_deque[i].pos) obs = torch.cat(obs_cond, dim=1) obs_pos = torch.cat(pos, dim=0) if self.use_normalization: nobs = self.dataset.normalize_data(obs, stats_key='obs') nobs[:,:,-1] = obs[:,:,-1] # skip normalization for node IDs - self.last_naction = self.dataset.normalize_data(G_t.x.unsqueeze(1), stats_key='action').to(self.device) + # Use last obs step as initial action estimate (y holds actions at inference via dataset replay) + if hasattr(G_t, 'y') and G_t.y is not None: + self.last_naction = self.dataset.normalize_data(G_t.y.unsqueeze(1), stats_key='action').to(self.device) else: - self.last_naction = G_t.x.unsqueeze(1).to(self.device) + if hasattr(G_t, 'y') and G_t.y is not None: + self.last_naction = G_t.y.unsqueeze(1).to(self.device) with torch.no_grad(): - # initialize action from Guassian noise - noisy_action = torch.randn((G_t.x.shape[0], self.pred_horizon, self.node_feature_dim), device=self.device) + # initialize action from Gaussian noise + # For per-node actions: (num_nodes, pred_horizon, node_feature_dim) + num_nodes = G_t.x.shape[0] + noisy_action = torch.randn((num_nodes, self.pred_horizon, self.node_feature_dim), device=self.device) naction = noisy_action - if self.keep_first_action: - naction[:,0,:] = self.last_naction[:,-1,:1] + if self.keep_first_action and self.last_naction.shape[0] == num_nodes: + naction[:,0,:] = self.last_naction[:,-1,:self.node_feature_dim] # init scheduler self.noise_scheduler.set_timesteps(self.num_diffusion_iters) @@ -157,15 +162,16 @@ def get_action(self, obs_deque): sample=naction ).prev_sample - if self.keep_first_action: - naction[:,0,:] = self.last_naction[:,-1,:1] + if self.keep_first_action and self.last_naction.shape[0] == num_nodes: + naction[:,0,:] = 
self.last_naction[:,-1,:self.node_feature_dim] - # add node dimension, to pass through normalizer - naction = torch.cat([naction, torch.zeros((naction.shape[0], self.pred_horizon, 1), device=self.device)], dim=2) naction = naction.detach().to('cpu') if self.use_normalization: naction = self.dataset.unnormalize_data(naction, stats_key='action').numpy() + else: + naction = naction.numpy() + # Extract action: first action_dim nodes, all pred_horizon steps, first feature action = naction[:self.action_dim,:,0].T # (action_horizon, action_dim) @@ -187,20 +193,20 @@ def validate(self, dataset=None, model_path="last.pt"): with torch.no_grad(): val_loss = list() for batch in dataloader: - nobs = batch.y + # x = observations, y = actions + nobs = batch.x if self.use_normalization: # normalize observation - nobs = self.dataset.normalize_data(batch.y, stats_key='obs').to(self.device) - nobs[:,:,-1] = batch.y[:,:,-1] # skip normalization for node IDs + nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device) + nobs[:,:,-1] = batch.x[:,:,-1] # skip normalization for node IDs # normalize action - naction = self.dataset.normalize_data(batch.x, stats_key='action').to(self.device) + naction = self.dataset.normalize_data(batch.y, stats_key='action').to(self.device) + else: + naction = batch.y.to(self.device) B = batch.num_graphs # observation as FiLM conditioning - # (B, node, obs_horizon, obs_dim) obs_cond = nobs - naction = naction[:,:,:1] # joint value - # (B, obs_horizon * obs_dim) # sample a diffusion iteration for each data point timesteps = torch.randint( @@ -208,14 +214,11 @@ def validate(self, dataset=None, model_path="last.pt"): (B,), device=self.device ).long() - # add noise to the clean images according to the noise magnitude at each diffusion iteration - # (this is the forward diffusion process) - - # split naction into (B, N_nodes, pred_horizon, node_feature_dim), selecting the items from each batch.batch - naction = torch.cat([naction[batch.batch == 
i].unsqueeze(0) for i in batch.batch.unique()], dim=0) + # Reshape: batch.y is (action_dim*B, pred_horizon, node_feature_dim) + # → (B, action_dim, pred_horizon, node_feature_dim) + naction = naction.view(B, self.action_dim, self.pred_horizon, self.node_feature_dim) # sample noise to add to actions - noise = torch.randn(naction.shape, device=self.device, dtype=torch.float32) noisy_actions = self.noise_scheduler.add_noise( @@ -306,19 +309,20 @@ def train(self, # batch loop with tqdm(dataloader, desc='Batch', leave=False) as tepoch: for batch in tepoch: - nobs = batch.y + # x = observations, y = actions + nobs = batch.x if self.use_normalization: # normalize observation - nobs = self.dataset.normalize_data(batch.y, stats_key='obs').to(self.device) - nobs[:,:,-1] = batch.y[:,:,-1] # skip normalization for node IDs + nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device) + nobs[:,:,-1] = batch.x[:,:,-1] # skip normalization for node IDs # normalize action - naction = self.dataset.normalize_data(batch.x, stats_key='action').to(self.device) + naction = self.dataset.normalize_data(batch.y, stats_key='action').to(self.device) + else: + naction = batch.y.to(self.device) B = batch.num_graphs # observation as FiLM conditioning - # (B, node, obs_horizon, obs_dim) obs_cond = nobs - naction = naction[:,:,:1] # joint value # sample a diffusion iteration for each data point timesteps = torch.randint( @@ -326,13 +330,9 @@ def train(self, (B,), device=self.device ).long() - # add noise to the clean images according to the noise magnitude at each diffusion iteration - # (this is the forward diffusion process) - # split naction into (B, N_nodes, pred_horizon, node_feature_dim), selecting the items from each batch.batch - - naction = torch.cat([naction[batch.batch == i].unsqueeze(0) for i in batch.batch.unique()], dim=0) - - # add noise to first action instead of sampling from Gaussian + # Reshape: batch.y is (action_dim*B, pred_horizon, node_feature_dim) + # → 
(B, action_dim, pred_horizon, node_feature_dim) + naction = naction.view(B, self.action_dim, self.pred_horizon, self.node_feature_dim) noise = torch.randn(naction.shape, device=self.device, dtype=torch.float32) noisy_actions = self.noise_scheduler.add_noise( From e090c113d19657b8be5ef16242caa1224a067f81 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 9 Mar 2026 12:57:40 +0100 Subject: [PATCH 06/22] Update tests for x=obs, y=actions convention --- tests/test_graph_data.py | 6 +- tests/test_policy_dataset_replay.py | 346 +++++++++++++-------------- tests/test_train_eval_consistency.py | 4 +- 3 files changed, 168 insertions(+), 188 deletions(-) diff --git a/tests/test_graph_data.py b/tests/test_graph_data.py index ccb8c19..a50c2d1 100644 --- a/tests/test_graph_data.py +++ b/tests/test_graph_data.py @@ -48,6 +48,8 @@ def test_lift_edge_data_match(lift_wrapper, lift_dataset): def test_lift_node_data_match(lift_wrapper, lift_dataset): env_obs = lift_wrapper.reset() G_0 = lift_dataset[0] + # Both x fields should share the same node type flags (last column) assert (env_obs.x[:,-1] == G_0.x[:,0,-1]).all() - assert (env_obs.y.shape == G_0.y[:,0,:].shape) - + # Wrapper x is 2D (nodes, feat), dataset x is 3D (nodes, obs_horizon, feat) + assert (env_obs.x.shape[0] == G_0.x.shape[0]) # same number of nodes + assert (env_obs.x.shape[-1] == G_0.x.shape[-1]) # same feature dim diff --git a/tests/test_policy_dataset_replay.py b/tests/test_policy_dataset_replay.py index dbbbe91..1c2334c 100644 --- a/tests/test_policy_dataset_replay.py +++ b/tests/test_policy_dataset_replay.py @@ -13,10 +13,10 @@ Tests ----- -1. test_obs_deque_y_matches_dataset_y - The y tensor passed to the policy at inference is assembled from an - obs_deque of RobomimicGraphWrapper observations. The y tensor used - during training comes from RobomimicGraphDataset.get_y_feats(). +1. 
test_obs_x_matches_dataset_x + The x tensor passed to the policy at inference is assembled from an + obs_deque of RobomimicGraphWrapper observations. The x tensor used + during training comes from RobomimicGraphDataset._get_x_feats(). They must agree for the same joint state. 2. test_dataset_playback_obs_format @@ -31,9 +31,9 @@ This confirms that the action representation used in the dataset is compatible with the wrapper's step() interface. -4. test_dataset_y_and_wrapper_y_feature_order_match - The obs feature vector (y) must have the same column ordering between - dataset and wrapper, because the normalizer is fit on the dataset's y. +4. test_dataset_x_and_wrapper_x_feature_order_match + The obs feature vector (x) must have the same column ordering between + dataset and wrapper, because the normalizer is fit on the dataset's x. Columns: [joint_pos(7), gripper_qpos(2), node_id(1)] for robot nodes. """ @@ -119,8 +119,8 @@ def dataset(): @pytest.fixture(scope="module") -def wrapper_get_y_fn(): - """Return a bound _get_y_feats callable from RobomimicGraphWrapper.""" +def wrapper_get_x_fn(): + """Return a bound _get_x_feats callable from RobomimicGraphWrapper.""" from diffusion_policy.model.common.rotation_transformer import RotationTransformer mod = _load_module("rg_wrapper", "imitation/env/robomimic_graph_wrapper.py") @@ -133,16 +133,18 @@ def wrapper_get_y_fn(): object_state_keys={"cube": ["cube_pos", "cube_quat"]}, object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, num_objects=1, + ROBOT_NODE_TYPE=1, + OBJECT_NODE_TYPE=-1, ) get_obj_pos = mod.RobomimicGraphWrapper._get_object_pos.__get__(mock) mock._get_object_pos = get_obj_pos - get_y_feats = mod.RobomimicGraphWrapper._get_y_feats.__get__(mock) - return get_y_feats + get_x_feats = mod.RobomimicGraphWrapper._get_x_feats.__get__(mock) + return get_x_feats @pytest.fixture(scope="module") -def dataset_get_y_fn(): - """Return a bound get_y_feats callable from RobomimicGraphDataset.""" +def 
dataset_get_x_fn(): + """Return a bound _get_x_feats callable from RobomimicGraphDataset.""" from diffusion_policy.model.common.rotation_transformer import RotationTransformer mod = _load_module("rg_dataset2", "imitation/dataset/robomimic_graph_dataset.py") @@ -156,11 +158,13 @@ def dataset_get_y_fn(): object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, num_objects=1, obs_feature_dim=7, + ROBOT_NODE_TYPE=1, + OBJECT_NODE_TYPE=-1, ) get_obj_pos = mod.RobomimicGraphDataset._get_object_pos.__get__(mock) mock._get_object_pos = get_obj_pos - get_y_feats = mod.RobomimicGraphDataset.get_y_feats.__get__(mock) - return get_y_feats + get_x_feats = mod.RobomimicGraphDataset._get_x_feats.__get__(mock) + return get_x_feats # ── helpers ─────────────────────────────────────────────────────────────────── @@ -187,28 +191,28 @@ def _build_dataset_data_dict(episode_data): } -# ── Test 1: obs y tensors agree ─────────────────────────────────────────────── +# ── Test 1: obs x tensors agree ─────────────────────────────────────────────── -class TestObsYConsistency: +class TestObsXConsistency: """ - The y tensor (observations) fed to the GDDPM must be identical whether + The x tensor (observations) fed to the GDDPM must be identical whether it comes from the dataset (training path) or from the wrapper (eval path). A mismatch here means the network sees a completely different conditioning signal at eval time than it was trained on — a guaranteed performance cliff. 
""" - def test_wrapper_y_matches_dataset_y_at_each_step( - self, episode_data, wrapper_get_y_fn, dataset_get_y_fn + def test_wrapper_x_matches_dataset_x_at_each_step( + self, episode_data, wrapper_get_x_fn, dataset_get_x_fn ): """ For every timestep t, compare: - - wrapper._get_y_feats(obs_dict_at_t) → shape (num_nodes, feat) - - dataset.get_y_feats(data_dict, t_vals=[t]) → shape (num_nodes, 1, feat) + - wrapper._get_x_feats(obs_dict_at_t) → shape (num_nodes, feat) + - dataset._get_x_feats(data_dict, t_vals=[t]) → shape (num_nodes, 1, feat) Both should agree on the robot-node rows (indices 0..8). - The last column (node ID) is part of y in both paths; it is the - running index 0..num_nodes-1 and should be identical. + The last column (node ID) is part of x in both paths; it is the + node type flag and should be identical. """ data_dict = _build_dataset_data_dict(episode_data) T = len(episode_data["joint_pos"]) @@ -216,52 +220,52 @@ def test_wrapper_y_matches_dataset_y_at_each_step( max_err = 0.0 worst_t = -1 for t in range(T): - obs_dict = _build_wrapper_obs_dict(episode_data, t) - y_wrapper = wrapper_get_y_fn(obs_dict) # (num_nodes, feat) - y_dataset = dataset_get_y_fn(data_dict, [t]) # (num_nodes, 1, feat) - y_ds_t = y_dataset[:, 0, :] # (num_nodes, feat) + obs_dict = _build_wrapper_obs_dict(episode_data, t) + x_wrapper = wrapper_get_x_fn(obs_dict) # (num_nodes, feat) + x_dataset = dataset_get_x_fn(data_dict, [t]) # (num_nodes, 1, feat) + x_ds_t = x_dataset[:, 0, :] # (num_nodes, feat) # Compare robot nodes only (first 9) - robot_rows_w = y_wrapper[:9, :] - robot_rows_d = y_ds_t[:9, :] + robot_rows_w = x_wrapper[:9, :] + robot_rows_d = x_ds_t[:9, :] err = float(torch.max(torch.abs(robot_rows_w - robot_rows_d)).item()) if err > max_err: max_err = err worst_t = t - print(f"\n── Wrapper y vs dataset y (robot nodes) ─────────────────") + print(f"\n── Wrapper x vs dataset x (robot nodes) ─────────────────") print(f" Steps checked : {T}") print(f" Max element error : 
{max_err:.6f} at step {worst_t}") assert max_err <= Y_MATCH_TOL, ( - f"Wrapper._get_y_feats and dataset.get_y_feats disagree by " + f"Wrapper._get_x_feats and dataset._get_x_feats disagree by " f"{max_err:.2e} at step {worst_t} (tolerance {Y_MATCH_TOL:.0e}).\n" f"The network sees different obs conditioning at train vs eval time.\n" f"Check that both use the same feature ordering: " f"[joint_pos(7), gripper_qpos(2), node_id(1)] for robot nodes." ) - def test_obs_y_feature_shape_is_consistent( - self, episode_data, wrapper_get_y_fn, dataset_get_y_fn + def test_obs_x_feature_shape_is_consistent( + self, episode_data, wrapper_get_x_fn, dataset_get_x_fn ): """ - y from the wrapper (single step) and from the dataset (single step) + x from the wrapper (single step) and from the dataset (single step) must have the same number of columns (feature dimensionality). """ data_dict = _build_dataset_data_dict(episode_data) obs_dict = _build_wrapper_obs_dict(episode_data, 0) - y_wrapper = wrapper_get_y_fn(obs_dict) - y_dataset = dataset_get_y_fn(data_dict, [0]) + x_wrapper = wrapper_get_x_fn(obs_dict) + x_dataset = dataset_get_x_fn(data_dict, [0]) - print(f"\n── y feature shape ──────────────────────────────────────") - print(f" wrapper y.shape : {tuple(y_wrapper.shape)}") - print(f" dataset y.shape : {tuple(y_dataset[:, 0, :].shape)}") + print(f"\n── x feature shape ──────────────────────────────────────") + print(f" wrapper x.shape : {tuple(x_wrapper.shape)}") + print(f" dataset x.shape : {tuple(x_dataset[:, 0, :].shape)}") - assert y_wrapper.shape == y_dataset[:, 0, :].shape, ( - f"y shape mismatch: wrapper {tuple(y_wrapper.shape)} vs " - f"dataset {tuple(y_dataset[:, 0, :].shape)}.\n" + assert x_wrapper.shape == x_dataset[:, 0, :].shape, ( + f"x shape mismatch: wrapper {tuple(x_wrapper.shape)} vs " + f"dataset {tuple(x_dataset[:, 0, :].shape)}.\n" f"The policy obs conditioning tensor has the wrong number of features." 
) @@ -290,8 +294,8 @@ def _assemble_nobs(self, dataset, start_idx): for i in range(self.OBS_HORIZON): idx = max(0, start_idx - (self.OBS_HORIZON - 1 - i)) data = dataset.get(idx) - # data.y shape: (nodes, obs_horizon, feat) — take last step - obs_cond.append(data.y[:, -1:, :]) # (nodes, 1, feat) + # data.x shape: (nodes, obs_horizon, feat) — take last step + obs_cond.append(data.x[:, -1:, :]) # (nodes, 1, feat) obs = torch.cat(obs_cond, dim=1) # (nodes, obs_horizon, feat) return obs @@ -429,16 +433,16 @@ def test_dataset_action_produces_correct_next_obs(self, episode_data): def test_wrapper_step_with_dataset_action_matches_next_obs(self, episode_data): """ - Same as above, but stepping through RobomimicGraphWrapper.step() to - test the wrapper's action interpretation end-to-end. + Step through RobomimicGraphWrapper.step() with OSC_POSE control mode + and verify the resulting joint positions match dataset obs[t+1]. - The wrapper expects a 9-element action vector (one per graph node). - We pad the 7-DOF dataset action with zeros at positions 7 and 8 - (gripper fingers), matching the expected format. + The dataset (low_dim_v141.hdf5) was recorded with OSC_POSE, so we use + an OSC_POSE wrapper and pass the raw 7-D dataset actions directly. + This is the correct control-mode pairing for this dataset. NOTE: this test deliberately targets the first N_STEPS of demo_0 to keep runtime short. A failure indicates the wrapper's step() - action slicing is wrong. + action forwarding is wrong. 
""" from imitation.env.robomimic_graph_wrapper import RobomimicGraphWrapper @@ -448,7 +452,7 @@ def test_wrapper_step_with_dataset_action_matches_next_obs(self, episode_data): task="Lift", has_renderer=False, robots=["Panda"], - control_mode="JOINT_VELOCITY", + control_mode="OSC_POSE", base_link_shift=BASE_LINK_SHIFT, base_link_rotation=BASE_LINK_ROTATION, ) @@ -458,7 +462,7 @@ def test_wrapper_step_with_dataset_action_matches_next_obs(self, episode_data): wrapper.env.env.sim.set_state_from_flattened(episode_data["states"][0]) wrapper.env.env.sim.forward() - actions_raw = episode_data["actions"] # (T, 7) raw velocities + actions_raw = episode_data["actions"] # (T, 7) OSC_POSE actions joint_pos = episode_data["joint_pos"] # (T, 7) ground truth max_err = 0.0 @@ -466,22 +470,10 @@ def test_wrapper_step_with_dataset_action_matches_next_obs(self, episode_data): per_step = [] for t in range(min(self.N_STEPS, len(actions_raw) - 1)): - # Pad to 9-element graph action: [j0..j6, gripper_f0, gripper_f1] - # Dataset stores 7-DOF velocities; gripper comes from gripper_qvel - gripper_vel = episode_data["gripper_qvel"][t] - - # Build and pass a 9-element action to wrapper.step() - graph_action = np.concatenate([ - actions_raw[t], # 7 joint velocities / OSC DOF - gripper_vel[:2], # 2 gripper DOF - ]) # total: 9 elements - - graph_obs, _, done, _ = wrapper.step(graph_action) - - # Extract joint_pos from the graph observation's y field - # y shape: (num_nodes, feat) where feat = [jp0..jp6, gp0, gp1, node_id] - # Robot nodes 0..8; joint pos is stored in y[:9, 0..6] - live_jpos = graph_obs.y[:7, 0].detach().numpy() # nodes 0-6 → 7 joints + # For OSC_POSE, pass the 7-D action directly (no padding needed) + graph_obs, _, done, _ = wrapper.step(actions_raw[t]) + + live_jpos = graph_obs.x[:7, 0].detach().numpy() # nodes 0-6 → 7 joints ds_jpos = joint_pos[t + 1] err = float(np.max(np.abs(live_jpos - ds_jpos))) @@ -495,7 +487,7 @@ def 
test_wrapper_step_with_dataset_action_matches_next_obs(self, episode_data): wrapper.close() - print(f"\n── Wrapper step joint_pos match ─────────────────────────") + print(f"\n── Wrapper step joint_pos match (OSC_POSE) ──────────────") print(f" Steps replayed : {len(per_step)}") print(f" Max joint error : {max_err:.5f} rad at step {worst_t}") print(f" Mean joint error : {np.mean(per_step):.5f} rad") @@ -503,93 +495,82 @@ def test_wrapper_step_with_dataset_action_matches_next_obs(self, episode_data): assert max_err <= JOINT_STEP_TOL, ( f"wrapper.step() joint_pos error {max_err:.5f} rad at step {worst_t} " f"exceeds {JOINT_STEP_TOL} rad.\n" - f"This verifies that even with perfect dataset actions, the wrapper " - f"does not correctly advance the simulator state.\n" + f"This verifies that OSC_POSE dataset actions correctly advance the " + f"simulator state through the wrapper.\n" f"Likely causes: (1) control_freq mismatch between dataset and wrapper, " - f"(2) wrong action slicing (action[j+8] instead of action[j+7]), " - f"(3) wrong control mode (OSC_POSE vs JOINT_VELOCITY)." + f"(2) wrong control mode, (3) sim state restoration incomplete." ) -# ── Test 4: y feature ordering ──────────────────────────────────────────────── +# ── Test 4: x feature ordering ──────────────────────────────────────────────── -class TestObsYFeatureOrdering: +class TestObsXFeatureOrdering: """ - Validate that wrapper._get_y_feats and dataset.get_y_feats produce + Validate that wrapper._get_x_feats and dataset._get_x_feats produce identical feature *ordering* for robot nodes: - col 0..6 : joint_pos (7 values) - col 7..8 : gripper_qpos (2 values) - col 9 : node_id + col 0..6 : joint_pos (7 values) — stored sparsely, one per node + col 7..8 : gripper_qpos (2 values) — stored sparsely + col -1 : sequential node ID (0-indexed, used as embedding index) A column-ordering mismatch would mean the normalizer scales the wrong physical quantities, making the policy conditioning signal meaningless. 
""" - def test_robot_y_columns_are_joint_pos_then_gripper_then_id( - self, episode_data, wrapper_get_y_fn, dataset_get_y_fn + def test_robot_x_columns_agree_between_wrapper_and_dataset( + self, episode_data, wrapper_get_x_fn, dataset_get_x_fn ): """ - At a known timestep, check that columns 0-6 of y[:9] match joint_pos, - columns 7-8 match gripper_qpos, and column 9 (if present) matches - the node index 0..8. + At a known timestep, check that the wrapper and dataset x feature + tensors agree on robot nodes and sequential node IDs (last column). - This pins down the actual in-memory layout, making any accidental - reordering immediately visible. + Node IDs are sequential integers (0, 1, ..., N-1) used as embedding + indices by EGraphConditionEncoder — not node-type flags. """ t = 5 # arbitrary mid-episode step obs_dict = _build_wrapper_obs_dict(episode_data, t) data_dict = _build_dataset_data_dict(episode_data) - y_wrapper = wrapper_get_y_fn(obs_dict) # (num_nodes, feat) - y_dataset = dataset_get_y_fn(data_dict, [t])[:, 0, :] # (num_nodes, feat) - - jp = torch.tensor(episode_data["joint_pos"][t]) # (7,) - gp = torch.tensor(episode_data["gripper_qpos"][t]) # (2,) - - # Robot nodes 0..6 correspond to 7 joints; nodes 7 & 8 are gripper nodes. 
- # get_y_feats packs each robot node with its own joint feature: - # node i → [joint_i_val, 0, 0, ..., node_id] (sparse, one joint per node) + x_wrapper = wrapper_get_x_fn(obs_dict) # (num_nodes, feat) + x_dataset = dataset_get_x_fn(data_dict, [t])[:, 0, :] # (num_nodes, feat) - print(f"\n── y feature ordering check (t={t}) ──────────────────────") - print(f" Wrapper y[:10,:] =\n{y_wrapper[:10,:]}") - print(f" Dataset y[:10,:] =\n{y_dataset[:10,:]}") - print(f" Expected jp: {jp.numpy()}") - print(f" Expected gp: {gp.numpy()}") + num_nodes = x_wrapper.shape[0] - # Verify node IDs (last column) for both wrapper and dataset - num_robot_nodes = 9 - expected_node_ids = torch.arange(num_robot_nodes, dtype=y_wrapper.dtype) + print(f"\n── x feature ordering check (t={t}) ──────────────────────") + print(f" Wrapper x[:10,:] =\n{x_wrapper[:10,:]}") + print(f" Dataset x[:10,:] =\n{x_dataset[:10,:]}") - wrapper_node_ids = y_wrapper[:num_robot_nodes, -1] - dataset_node_ids = y_dataset[:num_robot_nodes, -1] + # Last column must be sequential node IDs (0, 1, ..., N-1) for both + expected_ids = torch.arange(num_nodes, dtype=torch.float32) + wrapper_ids = x_wrapper[:, -1] + dataset_ids = x_dataset[:, -1] - assert torch.allclose(wrapper_node_ids, expected_node_ids, atol=1e-3), ( - f"Wrapper y node IDs {wrapper_node_ids.tolist()} != expected {expected_node_ids.tolist()}.\n" - f"The node-ID column ordering is wrong in the wrapper." + assert torch.all(wrapper_ids == expected_ids), ( + f"Wrapper x node IDs {wrapper_ids.tolist()} != expected {expected_ids.tolist()}.\n" + f"The node-ID column (last) must be sequential 0..N-1 for the embedding lookup." ) - assert torch.allclose(dataset_node_ids, expected_node_ids, atol=1e-3), ( - f"Dataset y node IDs {dataset_node_ids.tolist()} != expected {expected_node_ids.tolist()}.\n" - f"The node-ID column ordering is wrong in the dataset." 
+ assert torch.all(dataset_ids == expected_ids), ( + f"Dataset x node IDs {dataset_ids.tolist()} != expected {expected_ids.tolist()}.\n" + f"The node-ID column (last) must be sequential 0..N-1 for the embedding lookup." ) -# ── Test 5: OSC_POSE node features shape and content ───────────────────────── +# ── Test 5: OSC_POSE observation features shape and content ─────────────────── class TestOscPoseNodeFeats: """ - Verify that RobomimicGraphWrapper._get_node_feats for control_mode='OSC_POSE' - produces the same 9-node structure as JOINT modes, preserving graph topology. - - Expected behavior (node_feature_dim=1): - - Shape: (9, 1) -- 9 robot nodes x 1 scalar feature each - - Nodes 0-2: eef_pos components (3D position) - - Nodes 3-6: eef_quat components (4D quaternion) - - Node 7: unused (0.0) - - Node 8: mean gripper_qpos + Verify that RobomimicGraphWrapper._get_x_feats produces the same 10-node + graph observation structure for OSC_POSE as for JOINT modes. + + For OSC_POSE, the graph observations (x) are identical to JOINT modes: + joint_pos + gripper_qpos per robot node, object pos/rot for object nodes. + Only the actions (y) differ (flat 7-D EEF vector vs per-node joint values). + + Expected shape: (10, K) — 9 robot + 1 object nodes, K features including + the sequential node-ID column at the end. 
""" - def _make_wrapper_node_feats_fn(self): - """Build bound _get_node_feats callable with OSC_POSE control mode.""" + def _make_wrapper_x_feats_fn(self): + """Build bound _get_x_feats callable with OSC_POSE control mode.""" import types from diffusion_policy.model.common.rotation_transformer import RotationTransformer @@ -603,80 +584,77 @@ def _make_wrapper_node_feats_fn(self): object_state_keys={"cube": ["cube_pos", "cube_quat"]}, object_state_sizes={"cube_pos": 3, "cube_quat": 4, "gripper_to_cube_pos": 3}, num_objects=1, + ROBOT_NODE_TYPE=1, + OBJECT_NODE_TYPE=-1, ) - return mod.RobomimicGraphWrapper._get_node_feats.__get__(mock) + get_obj_pos = mod.RobomimicGraphWrapper._get_object_pos.__get__(mock) + mock._get_object_pos = get_obj_pos + return mod.RobomimicGraphWrapper._get_x_feats.__get__(mock) - def _build_obs_dict_osc(self, episode_data, t): - """Build obs dict with eef_pos and eef_quat_raw for OSC_POSE _get_node_feats.""" - with h5py.File(DATASET_PATH, "r") as f: - ep = f["data/demo_0"] - eef_pos = ep["obs/robot0_eef_pos"][t] - eef_quat = ep["obs/robot0_eef_quat"][t] + def _build_obs_dict(self, episode_data, t): + """Build obs dict for _get_x_feats (uses joint_pos + gripper_qpos).""" return { "robot0_joint_pos": episode_data["joint_pos"][t], "robot0_joint_vel": episode_data["joint_vel"][t], "robot0_gripper_qpos": episode_data["gripper_qpos"][t], "robot0_gripper_qvel": episode_data["gripper_qvel"][t], - "robot0_eef_pos": eef_pos, - "robot0_eef_quat_raw": eef_quat, # raw 4D + "object": episode_data["object_obs"][t], } - def test_osc_pose_node_feats_shape(self, episode_data): - """OSC_POSE node features must be (9, 1) -- matching JOINT mode topology.""" - get_node_feats = self._make_wrapper_node_feats_fn() - obs_dict = self._build_obs_dict_osc(episode_data, 10) - feats = get_node_feats(obs_dict, control_mode="OSC_POSE") - assert feats.shape == (9, 1), ( - f"OSC_POSE node features have shape {tuple(feats.shape)}, expected (9, 1).\n" - f"The graph topology 
must have 9 robot nodes for GDDPM compatibility." + def test_osc_pose_x_feats_shape(self, episode_data): + """_get_x_feats for OSC_POSE must return (10, K) — 9 robot + 1 object nodes.""" + get_x_feats = self._make_wrapper_x_feats_fn() + obs_dict = self._build_obs_dict(episode_data, 10) + feats = get_x_feats(obs_dict) + assert feats.shape[0] == 10, ( + f"OSC_POSE x feats have {feats.shape[0]} nodes, expected 10 " + f"(9 robot + 1 object). Graph topology must be preserved." ) - def test_osc_pose_nodes_0to2_match_eef_pos(self, episode_data): - """Nodes 0-2 must match the raw eef_pos values.""" - get_node_feats = self._make_wrapper_node_feats_fn() - t = 10 - obs_dict = self._build_obs_dict_osc(episode_data, t) - feats = get_node_feats(obs_dict, control_mode="OSC_POSE") - expected_pos = torch.tensor(obs_dict["robot0_eef_pos"], dtype=torch.float32) - assert torch.allclose(feats[:3, 0], expected_pos, atol=1e-5), ( - f"OSC_POSE nodes 0-2 (eef_pos) mismatch:\n" - f" got {feats[:3, 0].tolist()}\n" - f" expected {expected_pos.tolist()}" + def test_osc_pose_x_feats_node_ids_sequential(self, episode_data): + """Last column of _get_x_feats must be sequential node IDs 0..9.""" + get_x_feats = self._make_wrapper_x_feats_fn() + obs_dict = self._build_obs_dict(episode_data, 10) + feats = get_x_feats(obs_dict) + num_nodes = feats.shape[0] + expected_ids = torch.arange(num_nodes, dtype=torch.float32) + assert torch.all(feats[:, -1] == expected_ids), ( + f"OSC_POSE node IDs {feats[:, -1].tolist()} != expected {expected_ids.tolist()}.\n" + f"Node IDs must be sequential 0..N-1 for embedding lookup." 
) - def test_osc_pose_nodes_3to6_match_eef_quat(self, episode_data): - """Nodes 3-6 must match the raw eef_quat (4D) values.""" - get_node_feats = self._make_wrapper_node_feats_fn() + def test_osc_pose_robot_nodes_first_feature_is_joint_pos(self, episode_data): + """Robot nodes 0..6 must have their joint_pos value in the first feature column.""" + get_x_feats = self._make_wrapper_x_feats_fn() t = 10 - obs_dict = self._build_obs_dict_osc(episode_data, t) - feats = get_node_feats(obs_dict, control_mode="OSC_POSE") - expected_quat = torch.tensor(obs_dict["robot0_eef_quat_raw"], dtype=torch.float32) - assert torch.allclose(feats[3:7, 0], expected_quat, atol=1e-5), ( - f"OSC_POSE nodes 3-6 (eef_quat) mismatch:\n" - f" got {feats[3:7, 0].tolist()}\n" - f" expected {expected_quat.tolist()}" + obs_dict = self._build_obs_dict(episode_data, t) + feats = get_x_feats(obs_dict) + expected_jp = torch.tensor(episode_data["joint_pos"][t], dtype=torch.float32) + assert torch.allclose(feats[:7, 0], expected_jp, atol=1e-5), ( + f"OSC_POSE robot node joint_pos mismatch:\n" + f" got {feats[:7, 0].tolist()}\n" + f" expected {expected_jp.tolist()}" ) - def test_osc_pose_node7_is_zero(self, episode_data): - """Node 7 (unused) must be 0.0.""" - get_node_feats = self._make_wrapper_node_feats_fn() - obs_dict = self._build_obs_dict_osc(episode_data, 10) - feats = get_node_feats(obs_dict, control_mode="OSC_POSE") - assert float(feats[7, 0]) == 0.0, ( - f"OSC_POSE node 7 (unused) is {float(feats[7, 0])}, expected 0.0." - ) + def test_osc_pose_x_feats_match_joint_velocity_x_feats( + self, episode_data, wrapper_get_x_fn + ): + """ + _get_x_feats output must be the same for OSC_POSE and JOINT_VELOCITY wrappers. + Graph observations are control-mode-agnostic (joint_pos + gripper_qpos). 
+ """ + t = 10 + obs_dict = self._build_obs_dict(episode_data, t) + x_osc = self._make_wrapper_x_feats_fn()(obs_dict) + x_jv = wrapper_get_x_fn(obs_dict) - def test_osc_pose_node8_is_gripper(self, episode_data): - """Node 8 must contain the mean gripper_qpos.""" - get_node_feats = self._make_wrapper_node_feats_fn() - for t in range(len(episode_data["gripper_qpos"])): - if np.any(np.abs(episode_data["gripper_qpos"][t]) > 0.01): - break - obs_dict = self._build_obs_dict_osc(episode_data, t) - feats = get_node_feats(obs_dict, control_mode="OSC_POSE") - expected_val = float(np.mean(episode_data["gripper_qpos"][t])) - assert abs(float(feats[8, 0]) - expected_val) < 1e-5, ( - f"OSC_POSE gripper node {float(feats[8, 0]):.6f} != expected {expected_val:.6f}" + assert x_osc.shape == x_jv.shape, ( + f"OSC_POSE _get_x_feats shape {tuple(x_osc.shape)} != " + f"JOINT_VELOCITY shape {tuple(x_jv.shape)}" + ) + assert torch.allclose(x_osc, x_jv, atol=1e-5), ( + f"_get_x_feats differs between OSC_POSE and JOINT_VELOCITY wrappers.\n" + f"Graph observations must be control-mode-agnostic." 
) @@ -797,9 +775,9 @@ def test_osc_pose_step_joint_pos_matches_dataset(self, episode_data): for t in range(min(self.N_STEPS, len(actions) - 1)): graph_obs, reward, done, info = wrapper.step(actions[t]) - # y field stores [joint_pos(7), gripper(2), node_id] per robot node - # Each robot node i stores its own joint value at y[i, 0] - live_jpos = graph_obs.y[:7, 0].detach().numpy() + # x field stores [joint_pos(7), gripper(2), node_type] per robot node + # Each robot node i stores its own joint value at x[i, 0] + live_jpos = graph_obs.x[:7, 0].detach().numpy() ds_jpos = joint_pos[t + 1] err = float(np.max(np.abs(live_jpos - ds_jpos))) per_step.append(err) diff --git a/tests/test_train_eval_consistency.py b/tests/test_train_eval_consistency.py index 8eb95dd..0c8c16e 100644 --- a/tests/test_train_eval_consistency.py +++ b/tests/test_train_eval_consistency.py @@ -345,7 +345,7 @@ def test_obs_normalizer_range(self, dataset): all_y_norm = [] for i in indices: data = ds.get(i) - y = data.y # (nodes, obs_horizon, feat) + y = data.x # (nodes, obs_horizon, feat) y_norm = ds.normalize_data(y, stats_key="obs") y_norm_no_id = y_norm[:, :, :-1] # exclude node-ID column all_y_norm.append(y_norm_no_id.reshape(-1).detach().numpy()) @@ -380,7 +380,7 @@ def test_action_normalizer_range(self, dataset): all_x_norm = [] for i in indices: data = ds.get(i) - x = data.x # (nodes, pred_horizon, feat) + x = data.y # (nodes, pred_horizon, feat) — y holds actions # Only the first feature dim (joint value/velocity), excluding node-type x_val = x[:, :, :1] x_norm = ds.normalize_data( From e3fae0ba3234ec98a2403aaf2aecc515e65fd56d Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 9 Mar 2026 20:48:31 +0100 Subject: [PATCH 07/22] Fix loading nets on test.py script --- test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test.py b/test.py index 4c6b749..3a2c7c8 100644 --- a/test.py +++ b/test.py @@ -25,6 +25,7 @@ def test(cfg): runner = hydra.utils.instantiate(cfg.task.env_runner) # 
instanciate policy from cfg file policy = hydra.utils.instantiate(cfg.policy) + policy.load_nets(policy.ckpt_path) # instanciate agent from policy agent = hydra.utils.instantiate(cfg.agent, policy=policy, env=runner.env) From ed163b9a0aead4ba18349fe930f162c56a48a412 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 9 Mar 2026 21:02:09 +0100 Subject: [PATCH 08/22] Update policies lr scheduling and improve gddpm --- imitation/config/policy/argd_policy.yaml | 3 +- imitation/config/policy/diffusion_policy.yaml | 3 +- imitation/config/policy/gddpm_policy.yaml | 1 + imitation/config/policy/osc_ddpm_policy.yaml | 5 +- imitation/dataset/robomimic_graph_dataset.py | 25 ++- .../env_runner/robomimic_lowdim_runner.py | 1 + imitation/policy/ar_graph_diffusion_policy.py | 6 +- imitation/policy/diffusion_policy.py | 6 +- imitation/policy/graph_ddpm_policy.py | 6 +- imitation/policy/osc_ddpm_policy.py | 143 +++++++++++------- 10 files changed, 135 insertions(+), 64 deletions(-) diff --git a/imitation/config/policy/argd_policy.yaml b/imitation/config/policy/argd_policy.yaml index 7fe80f8..7e8e5c8 100644 --- a/imitation/config/policy/argd_policy.yaml +++ b/imitation/config/policy/argd_policy.yaml @@ -16,4 +16,5 @@ denoising_network: num_edge_types: ${policy.num_edge_types} num_layers: 2 hidden_dim: 512 -use_normalization: False \ No newline at end of file +use_normalization: False +num_warmup_steps: 100 \ No newline at end of file diff --git a/imitation/config/policy/diffusion_policy.yaml b/imitation/config/policy/diffusion_policy.yaml index 7a9f07e..d6ba543 100644 --- a/imitation/config/policy/diffusion_policy.yaml +++ b/imitation/config/policy/diffusion_policy.yaml @@ -7,4 +7,5 @@ action_horizon: ${action_horizon} num_diffusion_iters: 100 dataset: ${task.dataset} ckpt_path: ./weights/diffusion_policy_${task.task_name}_${task.dataset_type}.pt -lr: 0.00001 \ No newline at end of file +lr: 0.00001 +num_warmup_steps: 100 \ No newline at end of file diff --git 
a/imitation/config/policy/gddpm_policy.yaml b/imitation/config/policy/gddpm_policy.yaml index 3795393..083293b 100644 --- a/imitation/config/policy/gddpm_policy.yaml +++ b/imitation/config/policy/gddpm_policy.yaml @@ -32,3 +32,4 @@ lr: 1e-4 batch_size: 128 use_normalization: True keep_first_action: True +num_warmup_steps: 100 diff --git a/imitation/config/policy/osc_ddpm_policy.yaml b/imitation/config/policy/osc_ddpm_policy.yaml index 2d70d13..38f8148 100644 --- a/imitation/config/policy/osc_ddpm_policy.yaml +++ b/imitation/config/policy/osc_ddpm_policy.yaml @@ -5,7 +5,7 @@ num_edge_types: 2 pred_horizon: ${pred_horizon} obs_horizon: ${obs_horizon} action_horizon: ${action_horizon} -num_diffusion_iters: 100 +num_diffusion_iters: 50 dataset: ${task.dataset} denoising_network: @@ -27,4 +27,5 @@ ckpt_path: ./weights/osc_ddpm_policy_${task.task_name}_${task.control_mode}_${po lr: 1e-4 batch_size: 128 use_normalization: True -keep_first_action: True +keep_first_action: False +num_warmup_steps: 100 diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 4d81b04..71c4cd4 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -300,8 +300,14 @@ def get_normalizer(self): data_obs.append(data["x"]) # x = observations data_action.append(data["y"]) # y = actions data_obs = torch.cat(data_obs, dim=1) - data_action = torch.cat(data_action, dim=1) - + data_action = torch.cat(data_action, dim=1) # (action_dim, N*T, feat_dim) + + if self.control_mode == "OSC_POSE": + # data_action shape: (action_dim=7, N*T, 1) + # Reshape to (N*T, action_dim) so LinearNormalizer fits per-dim stats + # (with last_n_dims=1, scale/offset will be shape (action_dim,)). 
+ data_action = data_action[:, :, 0].T.contiguous() # (N*T, action_dim) + normalizer.fit( { "obs": data_obs, @@ -309,7 +315,20 @@ def get_normalizer(self): } ) return normalizer - + + def refit_normalizer(self, indices): + """Recompute normalizer stats using only the given sample indices (e.g. train set).""" + data_obs, data_action = [], [] + for idx in indices: + data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt')) + data_obs.append(data["x"]) + data_action.append(data["y"]) + data_obs = torch.cat(data_obs, dim=1) + data_action = torch.cat(data_action, dim=1) + if self.control_mode == "OSC_POSE": + data_action = data_action[:, :, 0].T.contiguous() + self.normalizer.fit({'obs': data_obs, 'action': data_action}) + def to_obs_deque(self, data): obs_deque = collections.deque(maxlen=self.obs_horizon) data_t = data.clone() diff --git a/imitation/env_runner/robomimic_lowdim_runner.py b/imitation/env_runner/robomimic_lowdim_runner.py index 013d442..218a079 100644 --- a/imitation/env_runner/robomimic_lowdim_runner.py +++ b/imitation/env_runner/robomimic_lowdim_runner.py @@ -62,6 +62,7 @@ def reset(self) -> None: [self.obs] * self.obs_horizon, maxlen=self.obs_horizon) def run(self, agent: BaseAgent, n_steps: int) -> Dict: + agent.policy.reset() log.info(f"Running agent {agent.__class__.__name__} for {n_steps} steps") if self.output_video: self.start_video() diff --git a/imitation/policy/ar_graph_diffusion_policy.py b/imitation/policy/ar_graph_diffusion_policy.py index d274720..559c7ad 100644 --- a/imitation/policy/ar_graph_diffusion_policy.py +++ b/imitation/policy/ar_graph_diffusion_policy.py @@ -27,7 +27,8 @@ def __init__(self, lr=1e-4, ckpt_path=None, device = None, - use_normalization = False,): + use_normalization = False, + num_warmup_steps: int = 100): super(AutoregressiveGraphDiffusionPolicy, self).__init__() if device == None: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -39,6 +40,7 @@ def __init__(self, 
self.num_edge_types = num_edge_types self.model = denoising_network self.use_normalization = use_normalization + self.num_warmup_steps = num_warmup_steps self.masker = NodeMasker(dataset) self.global_epoch = 0 @@ -207,7 +209,7 @@ def train(self, dataset, num_epochs=100, model_path=None, seed=0): self.lr_scheduler = get_scheduler( name='cosine', optimizer=self.optimizer, - num_warmup_steps=1000, + num_warmup_steps=self.num_warmup_steps, num_training_steps=len(dataset) * self.num_epochs ) diff --git a/imitation/policy/diffusion_policy.py b/imitation/policy/diffusion_policy.py index e673f61..92d8382 100644 --- a/imitation/policy/diffusion_policy.py +++ b/imitation/policy/diffusion_policy.py @@ -42,7 +42,8 @@ def __init__(self, dataset: BaseLowdimDataset, ckpt_path= None, lr: float = 1e-4, - batch_size: int = 256): + batch_size: int = 256, + num_warmup_steps: int = 100): super().__init__() self.dataset = dataset self.batch_size = batch_size @@ -55,6 +56,7 @@ def __init__(self, self.action_horizon = action_horizon self.num_diffusion_iters = num_diffusion_iters self.lr = lr + self.num_warmup_steps = num_warmup_steps self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") log.info(f"Using device {self.device}") @@ -204,7 +206,7 @@ def train(self, lr_scheduler = get_scheduler( name='cosine', optimizer=optimizer, - num_warmup_steps=500, + num_warmup_steps=self.num_warmup_steps, num_training_steps=len(self.dataloader) * num_epochs ) diff --git a/imitation/policy/graph_ddpm_policy.py b/imitation/policy/graph_ddpm_policy.py index 3bd24cf..cccf19c 100644 --- a/imitation/policy/graph_ddpm_policy.py +++ b/imitation/policy/graph_ddpm_policy.py @@ -35,7 +35,8 @@ def __init__(self, lr: float = 1e-4, batch_size: int = 256, use_normalization: bool = True, - keep_first_action: bool = True,): + keep_first_action: bool = True, + num_warmup_steps: int = 100): super().__init__() self.dataset = dataset self.batch_size = batch_size @@ -51,6 +52,7 @@ def __init__(self, 
self.lr = lr self.use_normalization = use_normalization self.keep_first_action = keep_first_action + self.num_warmup_steps = num_warmup_steps self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') log.info(f"Using device {self.device}") # create network object @@ -298,7 +300,7 @@ def train(self, self.lr_scheduler = get_scheduler( name='cosine', optimizer=self.optimizer, - num_warmup_steps=500, + num_warmup_steps=self.num_warmup_steps, num_training_steps=len(dataloader) * self.num_epochs ) diff --git a/imitation/policy/osc_ddpm_policy.py b/imitation/policy/osc_ddpm_policy.py index 00cf1bb..2ac45e7 100644 --- a/imitation/policy/osc_ddpm_policy.py +++ b/imitation/policy/osc_ddpm_policy.py @@ -19,6 +19,17 @@ log = logging.getLogger(__name__) +def compute_snr(timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: + """Compute signal-to-noise ratio at each timestep.""" + alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device) + sqrt_alphas_cumprod = alphas_cumprod ** 0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + alpha_t = sqrt_alphas_cumprod[timesteps] + sigma_t = sqrt_one_minus_alphas_cumprod[timesteps] + snr = (alpha_t / sigma_t) ** 2 + return snr + + class OSCGraphDDPMPolicy(BasePolicy): """ DDPM policy for OSC_POSE control. 
@@ -48,7 +59,8 @@ def __init__(self, lr: float = 1e-4, batch_size: int = 256, use_normalization: bool = True, - keep_first_action: bool = True): + keep_first_action: bool = True, + num_warmup_steps: int = 100): super().__init__() self.dataset = dataset self.batch_size = batch_size @@ -62,6 +74,7 @@ def __init__(self, self.lr = lr self.use_normalization = use_normalization self.keep_first_action = keep_first_action + self.num_warmup_steps = num_warmup_steps self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') log.info(f"Using device {self.device}") @@ -88,6 +101,13 @@ def __init__(self, ) self.playback_count = 0 + def reset(self): + """Reset stateful inference buffers between episodes.""" + self.last_naction = torch.zeros( + (1, self.pred_horizon, self.action_dim), device=self.device + ) + self.playback_count = 0 + def load_nets(self, ckpt_path): if ckpt_path is None: log.info('No pretrained weights given.') @@ -169,17 +189,14 @@ def get_action(self, obs_deque): self.last_naction = naction if self.use_normalization: - # Reshape to per-node format for unnormalize_data, then back - naction_node = ( - naction.permute(0, 2, 1) # (1, action_dim, pred_horizon) - .unsqueeze(-1) # (1, action_dim, pred_horizon, 1) - .reshape(self.action_dim, self.pred_horizon, 1) - ) - naction_node = self.dataset.unnormalize_data(naction_node, stats_key='action') - # (action_horizon, action_dim) - action = naction_node[:, :self.action_horizon, 0].T + # naction[0] is (pred_horizon, action_dim) — matches per-dim normalizer + # (scale shape (action_dim,)) which expects last dim = action_dim + action = self.dataset.unnormalize_data( + naction[0], stats_key='action' + ) # (pred_horizon, action_dim) + action = action[:self.pred_horizon, :].numpy() else: - action = naction[0, :self.action_horizon, :].numpy() + action = naction[0, :self.pred_horizon, :].numpy() return action @@ -198,17 +215,17 @@ def validate(self, dataset=None, model_path="last.pt"): for batch in dataloader: B 
= batch.num_graphs nobs = batch.x + # batch.y: (action_dim*B, pred_horizon, 1) from PyG DataLoader + # Reshape to (B, pred_horizon, action_dim) BEFORE normalizing so + # the per-dim normalizer (scale shape (action_dim,)) sees last dim=action_dim + action_raw = batch.y.view(B, self.action_dim, self.pred_horizon, 1) + action_raw = action_raw[:, :, :, 0].permute(0, 2, 1) # (B, T, Da) if self.use_normalization: nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device) nobs[:, :, -1] = batch.x[:, :, -1] - naction_raw = self.dataset.normalize_data(batch.y, stats_key='action').to(self.device) + naction = self.dataset.normalize_data(action_raw, stats_key='action').to(self.device) else: - naction_raw = batch.y.to(self.device) - - # batch.y: (action_dim*B, pred_horizon, 1) - # reshape to (B, pred_horizon, action_dim) - naction = naction_raw.view(B, self.action_dim, self.pred_horizon, 1) - naction = naction[:, :, :, 0].permute(0, 2, 1) # (B, T, Da) + naction = action_raw.to(self.device) timesteps = torch.randint( 0, self.noise_scheduler.config.num_train_timesteps, @@ -224,16 +241,22 @@ def validate(self, dataset=None, model_path="last.pt"): obs_cond = nobs.float() noisy_actions = noisy_actions.float() - noise_pred, _ = self.ema_noise_pred_net( - noisy_actions, - batch.edge_index, - batch.edge_attr, - x_coord=batch.pos[:, :3], - cond=obs_cond, - timesteps=timesteps, - batch=batch.batch, - ) - loss = F.mse_loss(noise_pred, noise) + with torch.cuda.amp.autocast(): + noise_pred, _ = self.ema_noise_pred_net( + noisy_actions, + batch.edge_index, + batch.edge_attr, + x_coord=batch.pos[:, :3], + cond=obs_cond, + timesteps=timesteps, + batch=batch.batch, + ) + # Per-sample MSE loss + loss_per_sample = F.mse_loss(noise_pred, noise, reduction='none').mean(dim=(1, 2)) # (B,) + # Min-SNR-5 reweighting + snr = compute_snr(timesteps, self.noise_scheduler) + min_snr_weight = torch.clamp(snr, max=5.0) / snr + loss = (min_snr_weight * loss_per_sample).mean() 
val_loss.append(loss.item()) return np.mean(val_loss) @@ -250,7 +273,7 @@ def train(self, log.info('Training noise prediction network.') if self.num_epochs is None: - log.warn(f"Global num_epochs not set. Using {num_epochs}.") + log.warning(f"Global num_epochs not set. Using {num_epochs}.") self.num_epochs = num_epochs torch.manual_seed(seed) @@ -275,29 +298,31 @@ def train(self, self.lr_scheduler = get_scheduler( name='cosine', optimizer=self.optimizer, - num_warmup_steps=500, + num_warmup_steps=self.num_warmup_steps, num_training_steps=len(dataloader) * self.num_epochs, ) + scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None + with tqdm(range(num_epochs), desc='Epoch') as tglobal: - for epoch_idx in tglobal: + for _ in tglobal: epoch_loss = [] with tqdm(dataloader, desc='Batch', leave=False) as tepoch: for batch in tepoch: B = batch.num_graphs nobs = batch.x + # batch.y: (action_dim*B, pred_horizon, 1) from PyG DataLoader + # Reshape to (B, pred_horizon, action_dim) BEFORE normalizing so + # the per-dim normalizer (scale shape (action_dim,)) sees last dim=action_dim + action_raw = batch.y.view(B, self.action_dim, self.pred_horizon, 1) + action_raw = action_raw[:, :, :, 0].permute(0, 2, 1) # (B, T, Da) if self.use_normalization: nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device) nobs[:, :, -1] = batch.x[:, :, -1] - naction_raw = self.dataset.normalize_data(batch.y, stats_key='action').to(self.device) + naction = self.dataset.normalize_data(action_raw, stats_key='action').to(self.device) else: - naction_raw = batch.y.to(self.device) - - # batch.y: (action_dim*B, pred_horizon, 1) from PyG DataLoader - # Reshape to (B, pred_horizon, action_dim) - naction = naction_raw.view(B, self.action_dim, self.pred_horizon, 1) - naction = naction[:, :, :, 0].permute(0, 2, 1) # (B, T, Da) + naction = action_raw.to(self.device) timesteps = torch.randint( 0, self.noise_scheduler.config.num_train_timesteps, @@ -313,24 +338,40 @@ def 
train(self, noisy_actions = noisy_actions.float() obs_cond = nobs.float() - noise_pred, _ = self.noise_pred_net( - noisy_actions, - batch.edge_index, - batch.edge_attr, - x_coord=batch.pos[:, :3], - cond=obs_cond, - timesteps=timesteps, - batch=batch.batch, - ) - - loss = F.mse_loss(noise_pred, noise) + with torch.cuda.amp.autocast(): + noise_pred, _ = self.noise_pred_net( + noisy_actions, + batch.edge_index, + batch.edge_attr, + x_coord=batch.pos[:, :3], + cond=obs_cond, + timesteps=timesteps, + batch=batch.batch, + ) + + # Per-sample MSE loss + loss_per_sample = F.mse_loss(noise_pred, noise, reduction='none').mean(dim=(1, 2)) # (B,) + # Min-SNR-5 reweighting + snr = compute_snr(timesteps, self.noise_scheduler) + min_snr_weight = torch.clamp(snr, max=5.0) / snr + loss = (min_snr_weight * loss_per_sample).mean() + wandb.log({ 'noise_pred_loss': loss, 'lr': self.lr_scheduler.get_last_lr()[0], + 'min_snr_weight_mean': min_snr_weight.mean().item(), }) - loss.backward() - self.optimizer.step() + if scaler is not None: + scaler.scale(loss).backward() + scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.noise_pred_net.parameters(), max_norm=1.0) + scaler.step(self.optimizer) + scaler.update() + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(self.noise_pred_net.parameters(), max_norm=1.0) + self.optimizer.step() self.optimizer.zero_grad() self.lr_scheduler.step() ema.step(self.noise_pred_net.parameters()) From 63b5401b73787373427fbc61c37658465c70c318 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 9 Mar 2026 21:04:26 +0100 Subject: [PATCH 09/22] Fix dataset split, len() bug, and parametrize warmup steps --- imitation/config/test.yaml | 4 +-- imitation/config/train.yaml | 6 ++-- imitation/dataset/robomimic_graph_dataset.py | 2 +- train.py | 31 +++++++++++++++++--- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/imitation/config/test.yaml b/imitation/config/test.yaml index 1efb1b5..d97329b 100644 --- 
a/imitation/config/test.yaml +++ b/imitation/config/test.yaml @@ -1,7 +1,7 @@ defaults: - _self_ - task: lift_graph - - policy: gddpm_policy + - policy: osc_ddpm_policy output_video: False render: True @@ -13,7 +13,7 @@ output_dir: ./outputs pred_horizon: 16 obs_horizon: 4 -action_horizon: 16 +action_horizon: 12 action_offset: 0 # action offset for the policy, 1 if first action is to be ignored diff --git a/imitation/config/train.yaml b/imitation/config/train.yaml index 152ab5e..4103d2c 100644 --- a/imitation/config/train.yaml +++ b/imitation/config/train.yaml @@ -13,7 +13,7 @@ obs_horizon: 4 action_horizon: 16 action_offset: 0 # action offset for the policy, 1 if first action is to be ignored # Training parameters -num_epochs: 500 +num_epochs: 50 val_fraction: 0.1 seed: 0 load_ckpt: False # start training from scratch @@ -29,7 +29,7 @@ max_steps: ${task.max_steps} # Evaluation during training eval_params: - eval_every: 50 # evaluate every 50 epochs + eval_every: 25 # evaluate every 25 epochs val_every: 1 task: ${task} policy: ${policy} @@ -37,7 +37,7 @@ eval_params: output_video: True load_ckpt: True # always load for evaluation - num_episodes: 50 + num_episodes: 5 max_steps: ${task.max_steps} output_dir: ./outputs diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 71c4cd4..8bc5207 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -264,7 +264,7 @@ def len(self): # calculate length of dataset based on self.dataset_root length = 0 for key in self.dataset_keys: - length += self.dataset_root[f"data/{key}/obs/object"].shape[0] - self.pred_horizon - self.obs_horizon - 1 + length += self.dataset_root[f"data/{key}/obs/object"].shape[0] - self.pred_horizon - 1 return length def get(self, idx): diff --git a/train.py b/train.py index 78eaec7..5b372fa 100644 --- a/train.py +++ b/train.py @@ -8,6 +8,7 @@ import pathlib import hydra +import numpy as np
import torch import wandb @@ -57,10 +58,32 @@ def train(cfg: DictConfig) -> None: if torch.cuda.is_available(): torch.cuda.manual_seed(cfg.seed) - # Split the dataset into train and validation - train_dataset, val_dataset = torch.utils.data.random_split( - policy.dataset, [len(policy.dataset) - int(cfg.val_fraction * len(policy.dataset)), int(cfg.val_fraction * len(policy.dataset))] - ) + # Episode-aware train/val split — keeps entire episodes on one side to avoid + # temporal leakage between samples from the same demonstration. + ds = policy.dataset + episode_sample_ranges = [] + idx_global = 0 + for key in ds.dataset_keys: + ep_len = ds.dataset_root[f"data/{key}/obs/object"].shape[0] + n_samples = ep_len - ds.pred_horizon - 1 + episode_sample_ranges.append(list(range(idx_global, idx_global + n_samples))) + idx_global += n_samples + + rng = np.random.default_rng(cfg.seed) + episode_order = rng.permutation(len(episode_sample_ranges)).tolist() + n_val_eps = max(1, int(cfg.val_fraction * len(episode_order))) + val_eps = episode_order[:n_val_eps] + train_eps = episode_order[n_val_eps:] + + train_indices = [i for ep in train_eps for i in episode_sample_ranges[ep]] + val_indices = [i for ep in val_eps for i in episode_sample_ranges[ep]] + + train_dataset = torch.utils.data.Subset(ds, train_indices) + val_dataset = torch.utils.data.Subset(ds, val_indices) + + # Recompute normalizer on training data only (removes val leakage). 
+ log.info("Refitting normalizer on training episodes...") + ds.refit_normalizer(train_indices) E = cfg.num_epochs V = cfg.num_epochs From 3fa01b198a8a3fa54dddf2ad3ac4c3df04dc2941 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Tue, 10 Mar 2026 21:09:47 +0100 Subject: [PATCH 10/22] Add square_graph_osc task config --- imitation/config/task/square_graph_osc.yaml | 55 +++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 imitation/config/task/square_graph_osc.yaml diff --git a/imitation/config/task/square_graph_osc.yaml b/imitation/config/task/square_graph_osc.yaml new file mode 100644 index 0000000..b9ed161 --- /dev/null +++ b/imitation/config/task/square_graph_osc.yaml @@ -0,0 +1,55 @@ + +task_name: &task_name square +dataset_type: &dataset_type ph +dataset_path: &dataset_path ./data/${task.task_name}/${task.dataset_type}/low_dim_v141.hdf5 + +max_steps: ${eval:'800 if "${task.dataset_type}" == "mh" else 600'} + +control_mode: "OSC_POSE" + +obs_dim: 10 # 9 robot + 1 object nodes +action_dim: 7 # flat 7-D EEF vector (xyz + rotation + gripper) + +robots: ["Panda"] + +object_state_sizes: &object_state_sizes + nut_pos: 3 + nut_quat: 4 + +object_state_keys: &object_state_keys + nut: ["nut_pos", "nut_quat"] + +env_runner: + _target_: imitation.env_runner.robomimic_lowdim_runner.RobomimicEnvRunner + output_dir: ${output_dir} + action_horizon: ${action_horizon} + obs_horizon: ${obs_horizon} + action_offset: ${action_offset} + render: ${render} + output_video: ${output_video} + use_full_pred_after: 0.8 + env: + _target_: imitation.env.robomimic_graph_wrapper.RobomimicGraphWrapper + object_state_sizes: *object_state_sizes + object_state_keys: *object_state_keys + max_steps: ${task.max_steps} + task: "NutAssemblySquare" + has_renderer: ${render} + robots: ${task.robots} + output_video: ${output_video} + control_mode: ${task.control_mode} + controller_config: + interpolation: "linear" + ramp_ratio: 0.2 + base_link_shift: [[-0.56, 0, 0.912]] + +dataset: + 
_target_: imitation.dataset.robomimic_graph_dataset.RobomimicGraphDataset + dataset_path: ${task.dataset_path} + robots: ${task.robots} + pred_horizon: ${pred_horizon} + obs_horizon: ${obs_horizon} + object_state_sizes: *object_state_sizes + object_state_keys: *object_state_keys + control_mode: ${task.control_mode} + base_link_shift: [[-0.56, 0, 0.912]] From 48a68b677335e5c2998dc7f931360ab6a4433a35 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Tue, 10 Mar 2026 21:10:39 +0100 Subject: [PATCH 11/22] Update gitignore --- .gitignore | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index bc6e18b..6ec2c1e 100644 --- a/.gitignore +++ b/.gitignore @@ -94,4 +94,10 @@ weights/ archive/ # Notebooks -notebooks/ \ No newline at end of file +notebooks/ + +# Docs +docs/ + +# Claude +CLAUDE.md \ No newline at end of file From 4bedd02b89da046112e18a0715d461159706ed10 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Thu, 12 Mar 2026 19:57:36 +0100 Subject: [PATCH 12/22] Use EMA for osc_ddpm_policy --- imitation/config/policy/osc_ddpm_policy.yaml | 1 + imitation/policy/osc_ddpm_policy.py | 59 +++++++++++++++----- 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/imitation/config/policy/osc_ddpm_policy.yaml b/imitation/config/policy/osc_ddpm_policy.yaml index 38f8148..7913215 100644 --- a/imitation/config/policy/osc_ddpm_policy.yaml +++ b/imitation/config/policy/osc_ddpm_policy.yaml @@ -29,3 +29,4 @@ batch_size: 128 use_normalization: True keep_first_action: False num_warmup_steps: 100 +ema_decay: 0.9999 diff --git a/imitation/policy/osc_ddpm_policy.py b/imitation/policy/osc_ddpm_policy.py index 2ac45e7..be6433e 100644 --- a/imitation/policy/osc_ddpm_policy.py +++ b/imitation/policy/osc_ddpm_policy.py @@ -1,3 +1,4 @@ +import copy import logging import os @@ -60,7 +61,8 @@ def __init__(self, batch_size: int = 256, use_normalization: bool = True, keep_first_action: bool = True, - num_warmup_steps: int = 100): + 
num_warmup_steps: int = 100, + ema_decay: float = 0.9999): super().__init__() self.dataset = dataset self.batch_size = batch_size @@ -75,11 +77,25 @@ def __init__(self, self.use_normalization = use_normalization self.keep_first_action = keep_first_action self.num_warmup_steps = num_warmup_steps + self.ema_decay = ema_decay self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') log.info(f"Using device {self.device}") self.noise_pred_net = denoising_network.to(self.device) - self.ema_noise_pred_net = self.noise_pred_net.to(self.device) + # ema_noise_pred_net is a separate deep-copy used for inference + self.ema_noise_pred_net = copy.deepcopy(self.noise_pred_net).to(self.device) + + # EMA tracker — persistent across train() calls so state accumulates + if ema_decay is not None and ema_decay > 0: + self._ema = EMAModel( + parameters=self.noise_pred_net.parameters(), + decay=ema_decay, + power=0.75, + ) + log.info(f"EMA enabled with decay={ema_decay}") + else: + self._ema = None + log.info("EMA disabled") self.noise_scheduler = DDPMScheduler( num_train_timesteps=self.num_diffusion_iters, @@ -108,27 +124,44 @@ def reset(self): ) self.playback_count = 0 + def _sync_ema_to_inference_net(self): + """Copy current EMA averaged weights into ema_noise_pred_net for inference.""" + if self._ema is not None: + self._ema.copy_to(self.ema_noise_pred_net.parameters()) + else: + self.ema_noise_pred_net.load_state_dict(self.noise_pred_net.state_dict()) + def load_nets(self, ckpt_path): if ckpt_path is None: log.info('No pretrained weights given.') - self.ema_noise_pred_net = self.noise_pred_net.to(self.device) + self._sync_ema_to_inference_net() return if not os.path.isfile(ckpt_path): log.error(f"Pretrained weights not found at {ckpt_path}.") - self.ema_noise_pred_net = self.noise_pred_net.to(self.device) + self._sync_ema_to_inference_net() return try: - state_dict = torch.load(ckpt_path, map_location=self.device) - self.ema_noise_pred_net = self.noise_pred_net - 
self.ema_noise_pred_net.load_state_dict(state_dict) - self.ema_noise_pred_net.to(self.device) + checkpoint = torch.load(ckpt_path, map_location=self.device) + # Support both new dict format and legacy bare state-dict format + if isinstance(checkpoint, dict) and 'model' in checkpoint: + self.noise_pred_net.load_state_dict(checkpoint['model']) + if self._ema is not None and 'ema' in checkpoint: + self._ema.load_state_dict(checkpoint['ema']) + log.info('EMA state loaded from checkpoint.') + else: + # Legacy: bare state dict (EMA weights saved directly) + self.noise_pred_net.load_state_dict(checkpoint) + self._sync_ema_to_inference_net() log.info('Pretrained weights loaded.') except Exception: log.error('Error loading pretrained weights.') - self.ema_noise_pred_net = self.noise_pred_net.to(self.device) + self._sync_ema_to_inference_net() def save_nets(self, ckpt_path): - torch.save(self.ema_noise_pred_net.state_dict(), ckpt_path) + checkpoint = {'model': self.noise_pred_net.state_dict()} + if self._ema is not None: + checkpoint['ema'] = self._ema.state_dict() + torch.save(checkpoint, ckpt_path) log.info(f"Model saved at {ckpt_path}") # ------------------------------------------------------------------ @@ -282,8 +315,6 @@ def train(self, dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) - ema = EMAModel(parameters=self.ema_noise_pred_net.parameters(), power=0.75) - self.noise_pred_net.to(self.device) if self.optimizer is None: self.optimizer = torch.optim.AdamW( @@ -374,7 +405,8 @@ def train(self, self.optimizer.step() self.optimizer.zero_grad() self.lr_scheduler.step() - ema.step(self.noise_pred_net.parameters()) + if self._ema is not None: + self._ema.step(self.noise_pred_net.parameters()) loss_cpu = loss.item() epoch_loss.append(loss_cpu) @@ -382,6 +414,7 @@ def train(self, tglobal.set_postfix(loss=np.mean(epoch_loss)) wandb.log({'epoch': self.global_epoch, 'epoch_loss': np.mean(epoch_loss)}) + self._sync_ema_to_inference_net() 
self.save_nets(model_path) self.global_epoch += 1 tglobal.set_description(f"Epoch: {self.global_epoch}") From e89411116576545845ff55dbfd97274f3ac673b5 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Sun, 15 Mar 2026 13:29:38 +0100 Subject: [PATCH 13/22] Change default ema_decay to None, so EMA is not used when it is not set --- imitation/policy/osc_ddpm_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imitation/policy/osc_ddpm_policy.py b/imitation/policy/osc_ddpm_policy.py index be6433e..0ad308b 100644 --- a/imitation/policy/osc_ddpm_policy.py +++ b/imitation/policy/osc_ddpm_policy.py @@ -62,7 +62,7 @@ def __init__(self, use_normalization: bool = True, keep_first_action: bool = True, num_warmup_steps: int = 100, - ema_decay: float = 0.9999): + ema_decay: float = None): super().__init__() self.dataset = dataset self.batch_size = batch_size From 1003d254a3b8c53df4fa1b207c23ecea2e919d26 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Sun, 15 Mar 2026 15:20:02 +0100 Subject: [PATCH 14/22] Extend load_nets to better handle EMA parameters --- imitation/policy/osc_ddpm_policy.py | 11 +++++++++ 1 file changed, 11 insertions(+) diff --git a/imitation/policy/osc_ddpm_policy.py b/imitation/policy/osc_ddpm_policy.py index 0ad308b..d5b5642 100644 --- a/imitation/policy/osc_ddpm_policy.py +++ b/imitation/policy/osc_ddpm_policy.py @@ -148,6 +148,13 @@ def load_nets(self, ckpt_path): if self._ema is not None and 'ema' in checkpoint: self._ema.load_state_dict(checkpoint['ema']) log.info('EMA state loaded from checkpoint.') + elif self._ema is not None: + # Checkpoint has no EMA state — reinitialize shadow params from + # the loaded model weights so inference doesn't use random weights. 
+ for s_param, param in zip(self._ema.shadow_params, + self.noise_pred_net.parameters()): + s_param.data.copy_(param.data) + log.info('No EMA state in checkpoint; shadow params reset from model weights.') else: # Legacy: bare state dict (EMA weights saved directly) self.noise_pred_net.load_state_dict(checkpoint) @@ -155,6 +162,10 @@ def load_nets(self, ckpt_path): log.info('Pretrained weights loaded.') except Exception: log.error('Error loading pretrained weights.') + if self._ema is not None: + for s_param, param in zip(self._ema.shadow_params, + self.noise_pred_net.parameters()): + s_param.data.copy_(param.data) self._sync_ema_to_inference_net() def save_nets(self, ckpt_path): From 8e32496762e0a977c9cec0bda43ea3de90fa99ef Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Sun, 15 Mar 2026 22:03:20 +0100 Subject: [PATCH 15/22] Remove node IDs from graph data --- imitation/dataset/robomimic_graph_dataset.py | 6 --- imitation/env/robomimic_graph_wrapper.py | 5 -- imitation/model/gddpm.py | 21 ++++---- imitation/policy/graph_ddpm_policy.py | 3 -- imitation/policy/osc_ddpm_policy.py | 3 -- tests/test_policy_dataset_replay.py | 50 +++++--------------- 6 files changed, 23 insertions(+), 65 deletions(-) diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 8bc5207..b8bac26 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -161,12 +161,6 @@ def _get_x_feats(self, data, t_vals): x = torch.cat((x, obj_state_tensor), dim=0) - # add column for node ID (used as embedding index by the model) - num_nodes = x.shape[0] - node_ids = torch.arange(num_nodes, dtype=torch.float32).unsqueeze(1).unsqueeze(1) # (N, 1, 1) - node_ids = node_ids.expand(-1, x.shape[1], -1) # (N, T, 1) - x = torch.cat((x, node_ids), dim=2) - return x def _get_target_actions_horizon(self, data, idx, horizon, actions): diff --git a/imitation/env/robomimic_graph_wrapper.py 
b/imitation/env/robomimic_graph_wrapper.py index e6d2ed6..d26b088 100644 --- a/imitation/env/robomimic_graph_wrapper.py +++ b/imitation/env/robomimic_graph_wrapper.py @@ -202,11 +202,6 @@ def _get_x_feats(self, data): x = torch.cat([x, torch.zeros((x.shape[0], obj_state_tensor.shape[1] - x.shape[1]))], dim=1) # (9, 9) x = torch.cat([x, obj_state_tensor], dim=0) # (10, 9) - # add column for node ID (used as embedding index by the model) - num_nodes = x.shape[0] - node_ids = torch.arange(num_nodes, dtype=torch.float32).unsqueeze(1) # (N, 1) - x = torch.cat([x, node_ids], dim=1) # (10, 7) - return x @lru_cache(maxsize=128) diff --git a/imitation/model/gddpm.py b/imitation/model/gddpm.py index 963d23b..076c288 100644 --- a/imitation/model/gddpm.py +++ b/imitation/model/gddpm.py @@ -207,7 +207,7 @@ class GDDPMNoisePred(nn.Module): Args: node_feature_dim: feature dimension per node per step (e.g. 1 for joint value) - cond_feature_dim: obs feature dim (excl. node-id), e.g. 6 for 6D rotation + cond_feature_dim: obs feature dim, e.g. 
9 for joint_pos+gripper_qpos features obs_horizon: number of observation steps for conditioning pred_horizon: number of prediction steps (action horizon) edge_feature_dim: edge attribute size (usually 1) @@ -323,7 +323,7 @@ def forward(self, edge_index: (2, E) edge_attr: (E,) or (E, 1) — edge attributes / types x_coord: (N_total, 3) — 3D node positions - cond: (N_total, obs_horizon, cond_feature_dim+1) — obs (+node-id last) + cond: (N_total, obs_horizon, cond_feature_dim) — obs features timesteps: (B,) — diffusion timestep per graph in batch batch: (N_total,) — maps each node to its graph index @@ -351,9 +351,9 @@ def forward(self, # action_batch: maps action nodes to their graph index action_batch = torch.arange(B, dtype=torch.long, device=self.device).repeat_interleave(act_npg) - # separate node-id from conditioning features (last channel of cond) - ids = cond[:, 0, -1].long().to(self.device) - cond_feats = cond[:, :, :-1].float().to(self.device) # (N_obs, obs_horizon, C) + # auto-generate node IDs from node order (no longer stored in cond) + ids = torch.arange(cond.shape[0], device=self.device) % obs_npg + cond_feats = cond.float().to(self.device) # (N_obs, obs_horizon, C) # ---- obs edge_index with self-loops (for EGraphConditionEncoder) ---- edge_attr_1d = edge_attr.reshape(-1) @@ -532,7 +532,7 @@ class FlatGDDPMNoisePred(nn.Module): Args: action_dim: flat action dimensionality (e.g. 7 for OSC_POSE) - cond_feature_dim: obs feature dim (excl. node-id), e.g. 9 + cond_feature_dim: obs feature dim, e.g. 
9 for joint_pos+gripper_qpos features obs_horizon: number of observation steps for conditioning pred_horizon: number of prediction steps edge_feature_dim: edge attribute size (usually 1) @@ -571,7 +571,7 @@ def __init__(self, self.residual_channels = residual_channels self.cond_channels = hidden_dim - self.cond_encoder = EGraphConditionEncoder( + self.cond_encoder = EGraphConditionEncoder( # TODO create clean graph encoder for this model, based on the paper input_dim=cond_feature_dim * obs_horizon, output_dim=self.cond_channels, hidden_dim=hidden_dim, @@ -632,7 +632,7 @@ def forward(self, edge_index: (2, E) edge_attr: (E,) or (E, 1) x_coord: (N_total, 3) - cond: (N_total, obs_horizon, cond_feature_dim+1) graph obs (+node-id) + cond: (N_total, obs_horizon, cond_feature_dim) graph obs features timesteps: (B,) batch: (N_total,) node-to-graph mapping for EGraphConditionEncoder @@ -650,8 +650,9 @@ def forward(self, else: batch = batch.long().to(self.device) - ids = cond[:, 0, -1].long().to(self.device) - cond_feats = cond[:, :, :-1].float().to(self.device) + # auto-generate node IDs from node order (no longer stored in cond) + ids = torch.arange(cond.shape[0], device=self.device) % (cond.shape[0] // B) + cond_feats = cond.float().to(self.device) edge_attr_1d = edge_attr.reshape(-1) edge_index_sl, edge_attr_sl = add_self_loops( diff --git a/imitation/policy/graph_ddpm_policy.py b/imitation/policy/graph_ddpm_policy.py index cccf19c..026af7d 100644 --- a/imitation/policy/graph_ddpm_policy.py +++ b/imitation/policy/graph_ddpm_policy.py @@ -124,7 +124,6 @@ def get_action(self, obs_deque): obs_pos = torch.cat(pos, dim=0) if self.use_normalization: nobs = self.dataset.normalize_data(obs, stats_key='obs') - nobs[:,:,-1] = obs[:,:,-1] # skip normalization for node IDs # Use last obs step as initial action estimate (y holds actions at inference via dataset replay) if hasattr(G_t, 'y') and G_t.y is not None: self.last_naction = self.dataset.normalize_data(G_t.y.unsqueeze(1), 
stats_key='action').to(self.device) @@ -200,7 +199,6 @@ def validate(self, dataset=None, model_path="last.pt"): if self.use_normalization: # normalize observation nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device) - nobs[:,:,-1] = batch.x[:,:,-1] # skip normalization for node IDs # normalize action naction = self.dataset.normalize_data(batch.y, stats_key='action').to(self.device) else: @@ -316,7 +314,6 @@ def train(self, if self.use_normalization: # normalize observation nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device) - nobs[:,:,-1] = batch.x[:,:,-1] # skip normalization for node IDs # normalize action naction = self.dataset.normalize_data(batch.y, stats_key='action').to(self.device) else: diff --git a/imitation/policy/osc_ddpm_policy.py b/imitation/policy/osc_ddpm_policy.py index d5b5642..8996788 100644 --- a/imitation/policy/osc_ddpm_policy.py +++ b/imitation/policy/osc_ddpm_policy.py @@ -192,7 +192,6 @@ def get_action(self, obs_deque): if self.use_normalization: nobs = self.dataset.normalize_data(obs, stats_key='obs') - nobs[:, :, -1] = obs[:, :, -1] # preserve node IDs else: nobs = obs @@ -266,7 +265,6 @@ def validate(self, dataset=None, model_path="last.pt"): action_raw = action_raw[:, :, :, 0].permute(0, 2, 1) # (B, T, Da) if self.use_normalization: nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device) - nobs[:, :, -1] = batch.x[:, :, -1] naction = self.dataset.normalize_data(action_raw, stats_key='action').to(self.device) else: naction = action_raw.to(self.device) @@ -361,7 +359,6 @@ def train(self, action_raw = action_raw[:, :, :, 0].permute(0, 2, 1) # (B, T, Da) if self.use_normalization: nobs = self.dataset.normalize_data(batch.x, stats_key='obs').to(self.device) - nobs[:, :, -1] = batch.x[:, :, -1] naction = self.dataset.normalize_data(action_raw, stats_key='action').to(self.device) else: naction = action_raw.to(self.device) diff --git a/tests/test_policy_dataset_replay.py 
b/tests/test_policy_dataset_replay.py index 1c2334c..91f9b83 100644 --- a/tests/test_policy_dataset_replay.py +++ b/tests/test_policy_dataset_replay.py @@ -34,7 +34,7 @@ 4. test_dataset_x_and_wrapper_x_feature_order_match The obs feature vector (x) must have the same column ordering between dataset and wrapper, because the normalizer is fit on the dataset's x. - Columns: [joint_pos(7), gripper_qpos(2), node_id(1)] for robot nodes. + Columns: [joint_pos(7), gripper_qpos(2)] for robot nodes. """ import importlib.util @@ -243,7 +243,7 @@ def test_wrapper_x_matches_dataset_x_at_each_step( f"{max_err:.2e} at step {worst_t} (tolerance {Y_MATCH_TOL:.0e}).\n" f"The network sees different obs conditioning at train vs eval time.\n" f"Check that both use the same feature ordering: " - f"[joint_pos(7), gripper_qpos(2), node_id(1)] for robot nodes." + f"[joint_pos(7), gripper_qpos(2)] for robot nodes." ) def test_obs_x_feature_shape_is_consistent( @@ -316,7 +316,7 @@ def test_nobs_shape(self, dataset): ) def test_nobs_normalised_range(self, dataset): - """Normalized nobs (excluding node-ID column) is in [-NORM_RANGE_TOL, NORM_RANGE_TOL].""" + """Normalized nobs is in [-NORM_RANGE_TOL, NORM_RANGE_TOL].""" CHECK_N = 20 step = max(1, dataset.len() // CHECK_N) @@ -324,9 +324,7 @@ def test_nobs_normalised_range(self, dataset): for start_idx in range(0, dataset.len(), step): obs = self._assemble_nobs(dataset, start_idx=start_idx) nobs = dataset.normalize_data(obs, stats_key="obs") - # Exclude node-ID column (last feature) - nobs_no_id = nobs[:, :, :-1] - all_norm.append(nobs_no_id.reshape(-1).detach().numpy()) + all_norm.append(nobs.reshape(-1).detach().numpy()) import numpy as np all_norm = np.concatenate(all_norm) @@ -510,7 +508,6 @@ class TestObsXFeatureOrdering: identical feature *ordering* for robot nodes: col 0..6 : joint_pos (7 values) — stored sparsely, one per node col 7..8 : gripper_qpos (2 values) — stored sparsely - col -1 : sequential node ID (0-indexed, used as 
embedding index) A column-ordering mismatch would mean the normalizer scales the wrong physical quantities, making the policy conditioning signal meaningless. @@ -521,10 +518,7 @@ def test_robot_x_columns_agree_between_wrapper_and_dataset( ): """ At a known timestep, check that the wrapper and dataset x feature - tensors agree on robot nodes and sequential node IDs (last column). - - Node IDs are sequential integers (0, 1, ..., N-1) used as embedding - indices by EGraphConditionEncoder — not node-type flags. + tensors agree on robot nodes. """ t = 5 # arbitrary mid-episode step obs_dict = _build_wrapper_obs_dict(episode_data, t) @@ -533,24 +527,16 @@ def test_robot_x_columns_agree_between_wrapper_and_dataset( x_wrapper = wrapper_get_x_fn(obs_dict) # (num_nodes, feat) x_dataset = dataset_get_x_fn(data_dict, [t])[:, 0, :] # (num_nodes, feat) - num_nodes = x_wrapper.shape[0] - print(f"\n── x feature ordering check (t={t}) ──────────────────────") print(f" Wrapper x[:10,:] =\n{x_wrapper[:10,:]}") print(f" Dataset x[:10,:] =\n{x_dataset[:10,:]}") - # Last column must be sequential node IDs (0, 1, ..., N-1) for both - expected_ids = torch.arange(num_nodes, dtype=torch.float32) - wrapper_ids = x_wrapper[:, -1] - dataset_ids = x_dataset[:, -1] - - assert torch.all(wrapper_ids == expected_ids), ( - f"Wrapper x node IDs {wrapper_ids.tolist()} != expected {expected_ids.tolist()}.\n" - f"The node-ID column (last) must be sequential 0..N-1 for the embedding lookup." + assert x_wrapper.shape == x_dataset.shape, ( + f"Wrapper x shape {tuple(x_wrapper.shape)} != dataset x shape {tuple(x_dataset.shape)}." ) - assert torch.all(dataset_ids == expected_ids), ( - f"Dataset x node IDs {dataset_ids.tolist()} != expected {expected_ids.tolist()}.\n" - f"The node-ID column (last) must be sequential 0..N-1 for the embedding lookup." 
+ assert torch.allclose(x_wrapper.float(), x_dataset.float(), atol=1e-5), ( + f"Wrapper and dataset x feature tensors disagree.\n" + f"A column-ordering mismatch means the normalizer scales wrong features." ) @@ -565,8 +551,8 @@ class TestOscPoseNodeFeats: joint_pos + gripper_qpos per robot node, object pos/rot for object nodes. Only the actions (y) differ (flat 7-D EEF vector vs per-node joint values). - Expected shape: (10, K) — 9 robot + 1 object nodes, K features including - the sequential node-ID column at the end. + Expected shape: (10, K) — 9 robot + 1 object nodes, K features + (joint_pos(7) + gripper_qpos(2) per robot node, object pose for object node). """ def _make_wrapper_x_feats_fn(self): @@ -611,18 +597,6 @@ def test_osc_pose_x_feats_shape(self, episode_data): f"(9 robot + 1 object). Graph topology must be preserved." ) - def test_osc_pose_x_feats_node_ids_sequential(self, episode_data): - """Last column of _get_x_feats must be sequential node IDs 0..9.""" - get_x_feats = self._make_wrapper_x_feats_fn() - obs_dict = self._build_obs_dict(episode_data, 10) - feats = get_x_feats(obs_dict) - num_nodes = feats.shape[0] - expected_ids = torch.arange(num_nodes, dtype=torch.float32) - assert torch.all(feats[:, -1] == expected_ids), ( - f"OSC_POSE node IDs {feats[:, -1].tolist()} != expected {expected_ids.tolist()}.\n" - f"Node IDs must be sequential 0..N-1 for embedding lookup." 
- ) - def test_osc_pose_robot_nodes_first_feature_is_joint_pos(self, episode_data): """Robot nodes 0..6 must have their joint_pos value in the first feature column.""" get_x_feats = self._make_wrapper_x_feats_fn() From acd861e763c7b0c350a66995e99de723ff3df022 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 16 Mar 2026 14:37:17 +0100 Subject: [PATCH 16/22] Fix issue with video writer during evaluation --- imitation/env_runner/robomimic_lowdim_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/imitation/env_runner/robomimic_lowdim_runner.py b/imitation/env_runner/robomimic_lowdim_runner.py index 218a079..ac8f7dc 100644 --- a/imitation/env_runner/robomimic_lowdim_runner.py +++ b/imitation/env_runner/robomimic_lowdim_runner.py @@ -33,15 +33,15 @@ def __init__(self, self.use_full_pred_after = use_full_pred_after self.output_dir = output_dir self.curr_video = None - if self.output_video: # don't create video writer if not needed - self.start_video() - + self.video_writer = None # keep a queue of last obs_horizon steps of observations self.reset() def start_video(self): + if self.video_writer is not None: + self.video_writer.close() self.curr_video = f"{self.output_dir}/output_{time.time()}.mp4" self.video_writer = imageio.get_writer(self.curr_video, fps=30) From 7440aa04c67e5c6af63f4f79692869b35af25575 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 16 Mar 2026 14:38:06 +0100 Subject: [PATCH 17/22] Update graph encoder to more appropriate GatedGraphConv --- .../config/policy/graph_ddpm_policy.yaml | 2 +- imitation/model/gddpm.py | 87 +++++++++++++++---- 2 files changed, 70 insertions(+), 19 deletions(-) diff --git a/imitation/config/policy/graph_ddpm_policy.yaml b/imitation/config/policy/graph_ddpm_policy.yaml index c9d77b0..b78b670 100644 --- a/imitation/config/policy/graph_ddpm_policy.yaml +++ b/imitation/config/policy/graph_ddpm_policy.yaml @@ -23,7 +23,7 @@ denoising_network: diffusion_step_embed_dim: 64 
num_diffusion_steps: ${policy.num_diffusion_iters} ckpt_path: ./weights/diffusion_graph_policy_${task.task_name}_${task.dataset_type}_${task.control_mode}_${policy.num_diffusion_iters}iters.pt -lr: 1e-4 +lr: 1e-5 batch_size: 128 use_normalization: True keep_first_action: True \ No newline at end of file diff --git a/imitation/model/gddpm.py b/imitation/model/gddpm.py index 076c288..36ab661 100644 --- a/imitation/model/gddpm.py +++ b/imitation/model/gddpm.py @@ -14,8 +14,9 @@ 1. Temporal backbone: dilated 1-D convolutions (vs. EGNN message-passing). 2. Spatial backbone: GatedGraphConv (torch_geometric) inside each residual block, applied jointly over the dilated-conv output. - 3. Conditioning: same EGraphConditionEncoder is re-used for graph-level FiLM conditioning; - the resulting vector is upsampled with a CondUpsampler MLP before being injected into + 3. Conditioning: GraphCondEncoder encodes graph-structured observations into a + per-graph conditioning vector using GatedGraphConv (same spatial operator as + the residual blocks); upsampled with CondUpsampler MLP before injection into every residual block (instead of FiLM scales/biases per EGNN layer). """ @@ -30,7 +31,6 @@ from torch_geometric.utils import add_self_loops from torch_geometric.nn.pool import global_mean_pool -from imitation.model.graph_diffusion import EGraphConditionEncoder # --------------------------------------------------------------------------- @@ -85,6 +85,59 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +# --------------------------------------------------------------------------- +# Graph conditioning encoder (replaces EGraphConditionEncoder from graph_diffusion.py) +# --------------------------------------------------------------------------- + +class GraphCondEncoder(nn.Module): + """ + GNN encoder for graph-structured observations, aligned with the GDDPM paper. 
+ + Uses GatedGraphConv (same spatial operator as the ResidualBlocks) rather + than the E(N)-equivariant EGNN from EGraphConditionEncoder. No coordinate + updates, no equivariance overhead. + + Architecture: + 1. Flatten temporal obs: (N, obs_horizon, F) -> (N, obs_horizon*F) + 2. Concatenate learned node-ID embedding (N, 16) + 3. Linear input projection -> (N, hidden_dim) + 4. GatedGraphConv message passing (n_layers iterations, weight-shared) + 5. Global mean pooling per graph -> (B, hidden_dim) + 6. Linear output projection -> (B, output_dim) + + Args: + input_dim: obs_horizon * cond_feature_dim (flattened obs per node) + hidden_dim: internal feature dimension + output_dim: output conditioning vector size + n_layers: GatedGraphConv iterations (default 3) + max_nodes: node ID embedding table size (default 30) + """ + def __init__(self, input_dim, hidden_dim, output_dim, n_layers=3, max_nodes=30): + super().__init__() + self.id_embedding = nn.Embedding(max_nodes, 16) + self.input_proj = nn.Linear(input_dim + 16, hidden_dim) + self.gnn = GatedGraphConv(hidden_dim, num_layers=n_layers) + self.output_proj = nn.Linear(hidden_dim, output_dim) + + def forward(self, x, edge_index, batch, ids): + """ + Args: + x: (N, obs_horizon, cond_feature_dim) + edge_index: (2, E) — with self-loops already added by caller + batch: (N,) — node-to-graph mapping + ids: (N,) — node index within each graph + Returns: + (B, output_dim) — graph-level conditioning vector + """ + h = x.float().flatten(start_dim=1) # (N, input_dim) + id_embed = self.id_embedding(ids.long()) # (N, 16) + h = torch.cat([h, id_embed], dim=-1) # (N, input_dim+16) + h = F.leaky_relu(self.input_proj(h), 0.4) # (N, hidden_dim) + h = self.gnn(h, edge_index) # (N, hidden_dim) + g = global_mean_pool(h, batch=batch) # (B, hidden_dim) + return self.output_proj(g) # (B, output_dim) + + # --------------------------------------------------------------------------- # Residual block: dilated conv + GatedGraphConv + conditioner # 
--------------------------------------------------------------------------- @@ -200,8 +253,8 @@ class GDDPMNoisePred(nn.Module): GDDPM denoising network with the same interface as ConditionalGraphNoisePred. Architecture summary: - - EGraphConditionEncoder encodes the graph-structured observation into a - per-graph conditioning vector (same as in the project's existing model). + - GraphCondEncoder encodes the graph-structured observation into a + per-graph conditioning vector using GatedGraphConv. - CondUpsampler projects it to a node-dimension matching the graph size. - A stack of ResidualBlocks (dilated conv + GatedGraphConv) predicts noise. @@ -215,7 +268,7 @@ class GDDPMNoisePred(nn.Module): residual_layers: number of ResidualBlock layers residual_channels: channels inside each block dilation_cycle_length: dilation doubles every this many layers - hidden_dim: hidden size for EGraphConditionEncoder and diffusion embed + hidden_dim: hidden size for GraphCondEncoder and diffusion embed diffusion_step_embed_dim: raw sinusoidal embedding size (≤ hidden_dim) num_diffusion_steps: total DDPM timesteps (for embedding table) device: torch device (auto-detected if None) @@ -253,11 +306,10 @@ def __init__(self, # cond_channels: output length from EGraphConditionEncoder. # We use it as the `cond_length` fed into CondUpsampler. 
self.cond_channels = hidden_dim - self.cond_encoder = EGraphConditionEncoder( + self.cond_encoder = GraphCondEncoder( input_dim=cond_feature_dim * obs_horizon, - output_dim=self.cond_channels, hidden_dim=hidden_dim, - device=self.device, + output_dim=self.cond_channels, ).to(self.device) # --- Diffusion step embedding ------------------------------------------ @@ -381,9 +433,8 @@ def forward(self, ) # ---- Graph-level conditioning vector -------------------------------- - # EGraphConditionEncoder returns (B, cond_channels) graph_cond = self.cond_encoder( - cond_feats, obs_edge_index_sl, x_coord, obs_edge_attr_sl.unsqueeze(-1), + cond_feats, obs_edge_index_sl, batch=obs_batch, ids=ids ) # (B, cond_channels) @@ -527,7 +578,7 @@ class FlatGDDPMNoisePred(nn.Module): (B, pred_horizon, action_dim) tensor instead of per-node. The graph structure is used *only* for observation encoding via - EGraphConditionEncoder. The residual denoising blocks operate on the + GraphCondEncoder. The residual denoising blocks operate on the full action batch at graph granularity (B, ...). 
Args: @@ -540,7 +591,7 @@ class FlatGDDPMNoisePred(nn.Module): residual_layers: number of FlatResidualBlock layers residual_channels: channels inside each block dilation_cycle_length: dilation doubles every this many layers - hidden_dim: hidden size for EGraphConditionEncoder + hidden_dim: hidden size for GraphCondEncoder diffusion_step_embed_dim: sinusoidal embedding size num_diffusion_steps: total DDPM timesteps """ @@ -571,11 +622,10 @@ def __init__(self, self.residual_channels = residual_channels self.cond_channels = hidden_dim - self.cond_encoder = EGraphConditionEncoder( # TODO create clean graph encoder for this model, based on the paper + self.cond_encoder = GraphCondEncoder( input_dim=cond_feature_dim * obs_horizon, - output_dim=self.cond_channels, hidden_dim=hidden_dim, - device=self.device, + output_dim=self.cond_channels, ).to(self.device) self.diffusion_embedding = DiffusionEmbedding( @@ -634,7 +684,7 @@ def forward(self, x_coord: (N_total, 3) cond: (N_total, obs_horizon, cond_feature_dim) graph obs features timesteps: (B,) - batch: (N_total,) node-to-graph mapping for EGraphConditionEncoder + batch: (N_total,) node-to-graph mapping for GraphCondEncoder Returns: noise_pred: (B, pred_horizon, action_dim) @@ -651,6 +701,7 @@ def forward(self, batch = batch.long().to(self.device) # auto-generate node IDs from node order (no longer stored in cond) + B = timesteps.shape[0] ids = torch.arange(cond.shape[0], device=self.device) % (cond.shape[0] // B) cond_feats = cond.float().to(self.device) @@ -662,7 +713,7 @@ def forward(self, # Graph-level conditioning: (B, cond_channels) graph_cond = self.cond_encoder( - cond_feats, edge_index_sl, x_coord, edge_attr_sl.unsqueeze(-1), + cond_feats, edge_index_sl, batch=batch, ids=ids ) From 12d77c6fa5ffee89afa2d8d089891e0cdd02a4d2 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 16 Mar 2026 14:39:03 +0100 Subject: [PATCH 18/22] Reduce action horizon for higher precision --- imitation/config/train.yaml | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/imitation/config/train.yaml b/imitation/config/train.yaml index 4103d2c..2f01de4 100644 --- a/imitation/config/train.yaml +++ b/imitation/config/train.yaml @@ -10,7 +10,7 @@ output_video: True pred_horizon: 16 obs_horizon: 4 -action_horizon: 16 +action_horizon: 4 action_offset: 0 # action offset for the policy, 1 if first action is to be ignored # Training parameters num_epochs: 50 From 1d170ef07a91c21f6534b47b22a280e1d90be97f Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 16 Mar 2026 14:45:37 +0100 Subject: [PATCH 19/22] Update train.py version --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 5b372fa..152789d 100644 --- a/train.py +++ b/train.py @@ -40,7 +40,7 @@ def train(cfg: DictConfig) -> None: wandb.init( project=policy.__class__.__name__, group=cfg.task.task_name, - name=f"v1.2.3 - GDDPM", + name=f"v1.2.4 - GDDPM", # track hyperparameters and run metadata config={ "policy": cfg.policy, From e512cfbf89c5527ea786e5d7cc6fe7b9d8026319 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 16 Mar 2026 15:23:30 +0100 Subject: [PATCH 20/22] Keep weights/ directory --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 6ec2c1e..7af0c31 100644 --- a/.gitignore +++ b/.gitignore @@ -88,7 +88,7 @@ dmypy.json multirun/ # Model weights -weights/ +weights/* # Archive files archive/ From 71e9c675591f618ab00df037eefdc715283cca8d Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Wed, 18 Mar 2026 10:30:30 +0100 Subject: [PATCH 21/22] Remove node IDs from GDDPM model --- imitation/model/gddpm.py | 35 +++++++++++------------------------ 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/imitation/model/gddpm.py b/imitation/model/gddpm.py index 36ab661..39aa01e 100644 --- a/imitation/model/gddpm.py +++ b/imitation/model/gddpm.py @@ -99,39 +99,33 @@ class GraphCondEncoder(nn.Module): Architecture: 1. 
Flatten temporal obs: (N, obs_horizon, F) -> (N, obs_horizon*F) - 2. Concatenate learned node-ID embedding (N, 16) - 3. Linear input projection -> (N, hidden_dim) - 4. GatedGraphConv message passing (n_layers iterations, weight-shared) - 5. Global mean pooling per graph -> (B, hidden_dim) - 6. Linear output projection -> (B, output_dim) + 2. Linear input projection -> (N, hidden_dim) + 3. GatedGraphConv message passing (n_layers iterations, weight-shared) + 4. Global mean pooling per graph -> (B, hidden_dim) + 5. Linear output projection -> (B, output_dim) Args: input_dim: obs_horizon * cond_feature_dim (flattened obs per node) hidden_dim: internal feature dimension output_dim: output conditioning vector size n_layers: GatedGraphConv iterations (default 3) - max_nodes: node ID embedding table size (default 30) """ - def __init__(self, input_dim, hidden_dim, output_dim, n_layers=3, max_nodes=30): + def __init__(self, input_dim, hidden_dim, output_dim, n_layers=3): super().__init__() - self.id_embedding = nn.Embedding(max_nodes, 16) - self.input_proj = nn.Linear(input_dim + 16, hidden_dim) + self.input_proj = nn.Linear(input_dim, hidden_dim) self.gnn = GatedGraphConv(hidden_dim, num_layers=n_layers) self.output_proj = nn.Linear(hidden_dim, output_dim) - def forward(self, x, edge_index, batch, ids): + def forward(self, x, edge_index, batch): """ Args: x: (N, obs_horizon, cond_feature_dim) edge_index: (2, E) — with self-loops already added by caller batch: (N,) — node-to-graph mapping - ids: (N,) — node index within each graph Returns: (B, output_dim) — graph-level conditioning vector """ h = x.float().flatten(start_dim=1) # (N, input_dim) - id_embed = self.id_embedding(ids.long()) # (N, 16) - h = torch.cat([h, id_embed], dim=-1) # (N, input_dim+16) h = F.leaky_relu(self.input_proj(h), 0.4) # (N, hidden_dim) h = self.gnn(h, edge_index) # (N, hidden_dim) g = global_mean_pool(h, batch=batch) # (B, hidden_dim) @@ -403,10 +397,6 @@ def forward(self, # action_batch: maps 
action nodes to their graph index action_batch = torch.arange(B, dtype=torch.long, device=self.device).repeat_interleave(act_npg) - # auto-generate node IDs from node order (no longer stored in cond) - ids = torch.arange(cond.shape[0], device=self.device) % obs_npg - cond_feats = cond.float().to(self.device) # (N_obs, obs_horizon, C) - # ---- obs edge_index with self-loops (for EGraphConditionEncoder) ---- edge_attr_1d = edge_attr.reshape(-1) obs_edge_index_sl, obs_edge_attr_sl = add_self_loops( @@ -434,8 +424,8 @@ def forward(self, # ---- Graph-level conditioning vector -------------------------------- graph_cond = self.cond_encoder( - cond_feats, obs_edge_index_sl, - batch=obs_batch, ids=ids + cond.float().to(self.device), obs_edge_index_sl, + batch=obs_batch, ) # (B, cond_channels) # ---- Up-sample conditioning to pred_horizon ------------------------- @@ -700,10 +690,7 @@ def forward(self, else: batch = batch.long().to(self.device) - # auto-generate node IDs from node order (no longer stored in cond) B = timesteps.shape[0] - ids = torch.arange(cond.shape[0], device=self.device) % (cond.shape[0] // B) - cond_feats = cond.float().to(self.device) edge_attr_1d = edge_attr.reshape(-1) edge_index_sl, edge_attr_sl = add_self_loops( @@ -713,8 +700,8 @@ def forward(self, # Graph-level conditioning: (B, cond_channels) graph_cond = self.cond_encoder( - cond_feats, edge_index_sl, - batch=batch, ids=ids + cond.float().to(self.device), edge_index_sl, + batch=batch, ) # Up-sample conditioning to pred_horizon: (B, pred_horizon) -> (B, 1, pred_horizon) From 22c7b8027ea30a6b17d1e9a8ce3ecc9e5682ef9f Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Tue, 31 Mar 2026 19:50:34 +0200 Subject: [PATCH 22/22] Add positions as node features for GDDPM encoder --- imitation/model/gddpm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/imitation/model/gddpm.py b/imitation/model/gddpm.py index 39aa01e..82962ec 100644 --- a/imitation/model/gddpm.py +++ 
b/imitation/model/gddpm.py @@ -93,9 +93,7 @@ class GraphCondEncoder(nn.Module): """ GNN encoder for graph-structured observations, aligned with the GDDPM paper. - Uses GatedGraphConv (same spatial operator as the ResidualBlocks) rather - than the E(N)-equivariant EGNN from EGraphConditionEncoder. No coordinate - updates, no equivariance overhead. + Uses GatedGraphConv (same spatial operator as the ResidualBlocks). Architecture: 1. Flatten temporal obs: (N, obs_horizon, F) -> (N, obs_horizon*F) @@ -112,20 +110,23 @@ class GraphCondEncoder(nn.Module): """ def __init__(self, input_dim, hidden_dim, output_dim, n_layers=3): super().__init__() - self.input_proj = nn.Linear(input_dim, hidden_dim) + self.input_proj = nn.Linear(input_dim + 3, hidden_dim) self.gnn = GatedGraphConv(hidden_dim, num_layers=n_layers) self.output_proj = nn.Linear(hidden_dim, output_dim) - def forward(self, x, edge_index, batch): + def forward(self, x, edge_index, batch, pos=None): """ Args: x: (N, obs_horizon, cond_feature_dim) edge_index: (2, E) — with self-loops already added by caller batch: (N,) — node-to-graph mapping + pos: (N, 3) — optional 3D Cartesian node positions Returns: (B, output_dim) — graph-level conditioning vector """ h = x.float().flatten(start_dim=1) # (N, input_dim) + if pos is not None: + h = torch.cat([h, pos.float()], dim=-1) # (N, input_dim + 3) h = F.leaky_relu(self.input_proj(h), 0.4) # (N, hidden_dim) h = self.gnn(h, edge_index) # (N, hidden_dim) g = global_mean_pool(h, batch=batch) # (B, hidden_dim) @@ -426,6 +427,7 @@ def forward(self, graph_cond = self.cond_encoder( cond.float().to(self.device), obs_edge_index_sl, batch=obs_batch, + pos=x_coord, ) # (B, cond_channels) # ---- Up-sample conditioning to pred_horizon ------------------------- @@ -702,6 +704,7 @@ def forward(self, graph_cond = self.cond_encoder( cond.float().to(self.device), edge_index_sl, batch=batch, + pos=x_coord, ) # Up-sample conditioning to pred_horizon: (B, pred_horizon) -> (B, 1, pred_horizon)