Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ cd ../../datasets
mv * ~/workspace/GraphDiffusionImitate/data/
```

## Testing

Tests live in the `tests/` directory and use [pytest](https://docs.pytest.org/). Make sure the `imitation` conda environment is active before running any tests.

```bash
mamba activate imitation
```

### Run all tests

```bash
pytest tests/ -v
```


## Graph Diffusion Policies

### GraphDDPM
Expand Down
353 changes: 353 additions & 0 deletions tests/test_node_pos_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
"""
Test: _get_node_pos consistency between dataset FK and live robosuite environment.

Goal
----
Verify that `_get_node_pos` (in RobomimicGraphDataset) computes joint-link
positions that are consistent with what the robosuite simulator reports when
the same episode is replayed action-by-action from the exact same starting state.

Strategy
--------
The dataset (data/lift/ph/low_dim_v141.hdf5) was recorded with an OSC_POSE
controller (6-DOF EEF-delta + 1 gripper = 7D action) at control_freq=20.

For each timestep t in demo_0:
1. Restore the simulator to the recorded state at t=0 via
sim.set_state_from_flattened(states[0]) + sim.forward().
2. Apply actions[0..t] to the live environment.
3. Compare the resulting live robot0_joint_pos against dataset obs[t+1].
4. Run the FK pipeline (calculate_panda_joints_positions + base_link_shift)
on both the dataset and live joint positions, and compare Cartesian node
positions.

Timing:
dataset obs[t] ──actions[t]──► dataset obs[t+1]
≈ live obs after env.step(actions[t])

Tolerances
---------------------------
- Joint-space : 0.06 rad (small accumulated integration drift is acceptable)
- Task-space : 6 mm (FK Cartesian node positions, derived from joint tol)

Configuration (from lift_graph.yaml)
-------------------------------------
- base_link_shift = [-0.56, 0.0, 0.912]
- base_link_rotation = [0, 0, 0, 1] (identity)
"""

import importlib.util
import os

import h5py
import numpy as np
import pytest
import torch
from scipy.spatial.transform import Rotation as R

# ── paths ────────────────────────────────────────────────────────────────────
DATASET_PATH = "data/lift/ph/low_dim_v141.hdf5"
EPISODE_KEY = "demo_0"

# ── tolerances ──────────────────────────────────────────────────────────────
JOINT_POS_TOL = 0.06 # rad - max per-joint error over the full episode
CART_POS_TOL = 6e-3 # m - max Cartesian node-position error
EEF_POS_TOL = 1.25e-2 # m - max EEF-position error (live env vs. dataset)
# FK node[8] is the wrist link origin; robosuite's robot0_eef_pos is the
# fingertip TCP. The distance between them is a fixed structural length set
# by the gripper geometry (~96.5 mm for the Panda + Robotiq default).
# We verify this distance is consistent (not drifting) across all timesteps.
FK_EEF_OFFSET_STD_TOL = 2e-3 # m - max std of the FK→EEF distance across the episode

# ── lift-task config (lift_graph.yaml) ───────────────────────────────────────
BASE_LINK_SHIFT = np.array([-0.56, 0.0, 0.912])
BASE_LINK_ROTATION = [0.0, 0.0, 0.0, 1.0] # identity (x,y,z,w)


# ── FK helper ─────────────────────────────────────────────────────────────────

def _load_calculate_panda_joints_positions():
"""
Load calculate_panda_joints_positions directly from generic.py, bypassing
imitation/utils/__init__.py which would import torch_geometric transitively.
"""
spec = importlib.util.spec_from_file_location(
"imitation_generic",
os.path.join(os.path.dirname(__file__), "..", "imitation", "utils", "generic.py")
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod.calculate_panda_joints_positions


_calculate_panda_joints_positions = None # cached after first call


def compute_node_pos_xyz(joint_pos_7: np.ndarray,
gripper_qpos_2: np.ndarray) -> torch.Tensor:
"""
Mirrors RobomimicGraphDataset._get_node_pos for a single robot.
Returns shape (9, 3) - Cartesian [x, y, z] of each Panda link node,
after applying the base_link rotation and shift from lift_graph.yaml.
"""
global _calculate_panda_joints_positions
if _calculate_panda_joints_positions is None:
_calculate_panda_joints_positions = _load_calculate_panda_joints_positions()

joints = [*joint_pos_7.tolist(), *gripper_qpos_2.tolist()]
node_pos = _calculate_panda_joints_positions(joints) # (9, 7): xyz + quat

rotation_matrix = R.from_quat(BASE_LINK_ROTATION)
rot_mat_t = torch.tensor(rotation_matrix.as_matrix()).to(node_pos.dtype)
node_pos[:, :3] = torch.matmul(node_pos[:, :3], rot_mat_t)
node_pos[:, :3] += torch.tensor(BASE_LINK_SHIFT).to(node_pos.dtype)

return node_pos[:, :3] # (9, 3)


# ── robosuite env factory ─────────────────────────────────────────────────────

def make_env(dataset_path: str = DATASET_PATH):
"""
Re-creates the robosuite environment whose parameters are stored verbatim
in the HDF5 ``data/env_args`` attribute written at record time.

Reading from the dataset (instead of hardcoding controller parameters)
guarantees that the replay configuration always matches the recording
even if the dataset or controller settings change in the future.

The only values overridden here are the three renderer/reward flags that
must be ``False`` for a headless test replay:
- has_renderer
- has_offscreen_renderer
- reward_shaping
"""
import json
import robosuite as suite
import h5py

# ── read env_args written by robomimic at record time ─────────────────────
with h5py.File(dataset_path, "r") as f:
env_args = json.loads(f["data"].attrs["env_args"])

env_name = env_args["env_name"] # e.g. "Lift"
env_kwargs = dict(env_args["env_kwargs"]) # shallow copy so we can override

# Force headless / no reward-shaping for test replay
env_kwargs["has_renderer"] = False
env_kwargs["has_offscreen_renderer"] = False
env_kwargs["reward_shaping"] = False

return suite.make(env_name, **env_kwargs)


# ── fixtures ──────────────────────────────────────────────────────────────────

@pytest.fixture(scope="module")
def episode_data():
"""Load demo_0 observations, actions and sim states from the HDF5 file."""
with h5py.File(DATASET_PATH, "r") as f:
ep = f[f"data/{EPISODE_KEY}"]
joint_pos = ep["obs/robot0_joint_pos"][:] # (T, 7)
gripper_qpos = ep["obs/robot0_gripper_qpos"][:] # (T, 2)
eef_pos = ep["obs/robot0_eef_pos"][:] # (T, 3): EEF xyz in world frame
actions = ep["actions"][:] # (T, 7): OSC_POSE + gripper
states = ep["states"][:] # (T, 32): flat MuJoCo sim state
return joint_pos, gripper_qpos, eef_pos, actions, states


# ── tests ─────────────────────────────────────────────────────────────────────

class TestNodePosConsistency:
"""
Replays demo_0 in a live robosuite env (from the exact recorded initial
state) and checks that FK positions from dataset joint_pos match those
from live joint observations.
"""

def test_fk_from_dataset_obs_matches_live_env(self, episode_data):
"""
Main consistency test.

The env is restored to the recorded t=0 sim state via
sim.set_state_from_flattened(states[0]). We then replay each action
from the dataset and compare:
(a) Joint-space: dataset obs[t+1] vs. live joint_pos - must be
within JOINT_POS_TOL (0.06 rad).
(b) Task-space: FK node positions from (a) must agree within
CART_POS_TOL (6 mm).
(c) EEF-position: live env's robot0_eef_pos vs. dataset obs[t+1]
eef_pos - must be within EEF_POS_TOL (6 mm). This verifies
that the simulated state is consistent with robosuite's own
EEF reporting at record time.

A failure means either:
- The FK in _get_node_pos uses stale / off-by-one joint data.
- The base_link_shift/rotation is applied incorrectly in the dataset.
- There is accumulated integration drift (expected to be small for
a deterministic OSC_POSE controller given the same starting state).
"""
dataset_joint_pos, dataset_gripper_qpos, dataset_eef_pos, actions, states = episode_data
T = actions.shape[0]

env = make_env()
env.reset()

# Restore exact initial sim state from the recording
env.sim.set_state_from_flattened(states[0])
env.sim.forward()

max_joint_err = 0.0
max_cart_err = 0.0
max_eef_err = 0.0
worst_joint_t = -1
worst_cart_t = -1
worst_eef_t = -1
per_step_joint_errs = []
per_step_cart_errs = []
per_step_eef_errs = []

# Iterate only while t+1 is a valid dataset index so we can always
# compare the post-step live state against obs[t+1] without clamping.
for t in range(T - 1):
live_obs, _, _, _ = env.step(actions[t])
live_joint_pos = live_obs["robot0_joint_pos"] # (7,)
live_gripper = live_obs["robot0_gripper_qpos"] # (2,)
live_eef_pos = live_obs["robot0_eef_pos"] # (3,)
# Dataset state after action[t] is always obs[t+1]
ds_joint_pos = dataset_joint_pos[t + 1]
ds_gripper = dataset_gripper_qpos[t + 1]
ds_eef_pos = dataset_eef_pos[t + 1] # (3,)

# ── (a) joint-space ────────────────────────────────────────────────
joint_err = float(np.max(np.abs(ds_joint_pos - live_joint_pos)))
per_step_joint_errs.append(joint_err)
if joint_err > max_joint_err:
max_joint_err = joint_err
worst_joint_t = t

# ── (b) FK task-space ──────────────────────────────────────────────
pos_ds = compute_node_pos_xyz(ds_joint_pos, ds_gripper)
pos_live = compute_node_pos_xyz(live_joint_pos, live_gripper)
cart_err = float(torch.max(torch.abs(pos_ds - pos_live)).item())
per_step_cart_errs.append(cart_err)
if cart_err > max_cart_err:
max_cart_err = cart_err
worst_cart_t = t

# ── (c) EEF position (live env vs. dataset) ────────────────────────
# Compares the EEF xyz reported by robosuite live against the value
# stored in the dataset at the same timestep. This is independent
# of the FK pipeline and directly validates that the replayed sim
# state matches the original recording.
eef_err = float(np.max(np.abs(live_eef_pos - ds_eef_pos)))
per_step_eef_errs.append(eef_err)
if eef_err > max_eef_err:
max_eef_err = eef_err
worst_eef_t = t

env.close()

print(f"\n── Episode replay summary ({EPISODE_KEY}) ──────────────────")
print(f" Steps replayed : {len(per_step_joint_errs)}")
print(f" Max joint-pos error (rad) : {max_joint_err:.6f} at step {worst_joint_t}")
print(f" Mean joint-pos error (rad) : {np.mean(per_step_joint_errs):.6f}")
print(f" Max FK Cartesian err (m) : {max_cart_err:.6f} at step {worst_cart_t}")
print(f" Mean FK Cartesian err (m) : {np.mean(per_step_cart_errs):.6f}")
print(f" Max EEF position err (m) : {max_eef_err:.6f} at step {worst_eef_t}")
print(f" Mean EEF position err (m) : {np.mean(per_step_eef_errs):.6f}")

assert max_joint_err <= JOINT_POS_TOL, (
f"Joint-position error {max_joint_err:.5f} rad at step {worst_joint_t} "
f"exceeds tolerance {JOINT_POS_TOL} rad.\n"
f"This may indicate: (1) _get_node_pos uses stale/off-by-one joint data "
f"from the dataset, (2) base_link_shift/rotation is wrong, or "
f"(3) accumulated integration drift exceeds the tolerance."
)

assert max_cart_err <= CART_POS_TOL, (
f"FK Cartesian node-position error {max_cart_err:.5f} m at step {worst_cart_t} "
f"exceeds tolerance {CART_POS_TOL} m.\n"
f"This may indicate: (1) _get_node_pos is producing inconsistent node "
f"positions relative to the live robosuite environment, (2) the "
f"FK pipeline (including base_link_shift/rotation) is misconfigured, or "
f"(3) accumulated joint-position drift is being non-linearly amplified "
f"by the arm kinematics beyond the allowed bound."
)
assert max_eef_err <= EEF_POS_TOL, (
f"EEF-position error {max_eef_err:.5f} m at step {worst_eef_t} "
f"exceeds tolerance {EEF_POS_TOL} m.\n"
f"The live robosuite EEF position diverged from the dataset recording. "
f"This may indicate: (1) the sim state restoration is incomplete, "
f"(2) the controller or environment parameters differ from the recording, or "
f"(3) accumulated integration drift in the end-effector pose."
)

def test_fk_is_deterministic(self, episode_data):
"""Sanity: FK must be bit-for-bit deterministic for the same inputs."""
joint_pos, gripper_qpos, _, _, _ = episode_data
pos_a = compute_node_pos_xyz(joint_pos[0], gripper_qpos[0])
pos_b = compute_node_pos_xyz(joint_pos[0], gripper_qpos[0])
assert torch.allclose(pos_a, pos_b), \
"FK is not deterministic - unexpected randomness in calculate_panda_joints_positions."

def test_node_pos_changes_over_episode(self, episode_data):
"""Sanity: FK positions must vary along the episode (data is not static/zero)."""
joint_pos, gripper_qpos, _, _, _ = episode_data
pos_0 = compute_node_pos_xyz(joint_pos[0], gripper_qpos[0])
pos_10 = compute_node_pos_xyz(joint_pos[10], gripper_qpos[10])
assert not torch.allclose(pos_0, pos_10, atol=1e-4), \
"FK positions unchanged between step 0 and step 10 - check data loading."

def test_initial_joint_positions_are_nonzero(self, episode_data):
"""Sanity: the recorded initial joint positions should not be all zeros."""
joint_pos, _, _, _, _ = episode_data
assert np.any(np.abs(joint_pos[0]) > 1e-4), \
"Initial joint positions are all near zero - dataset may not be loaded correctly."

def test_fk_eef_tracks_dataset_eef_pos(self, episode_data):
"""
Validate that the FK pipeline (calculate_panda_joints_positions +
base_link_shift) consistently tracks the EEF position reported by
robosuite in robot0_eef_pos.

The last FK node (node[8]) corresponds to the wrist link origin, which
sits at a fixed structural distance from the EEF fingertip TCP.
This test verifies:
1. The structural distance (norm of FK_node[8] - eef_pos) is
approximately constant across all dataset timesteps - a variable
distance would indicate the FK is drifting relative to the EEF.
2. The mean structural offset matches the expected gripper geometry
(~96.5 mm for the Panda), confirming base_link_shift is correct.

This is a dataset-only check (no live env needed).
"""
joint_pos, gripper_qpos, eef_pos, _, _ = episode_data
T = len(joint_pos)

offset_norms = []
for t in range(T):
fk_xyz = compute_node_pos_xyz(joint_pos[t], gripper_qpos[t]) # (9, 3)
fk_eef = fk_xyz[-1].detach().numpy() # node[8]
dist = float(np.linalg.norm(fk_eef - eef_pos[t]))
offset_norms.append(dist)

offset_norms = np.array(offset_norms)
mean_offset = float(offset_norms.mean())
std_offset = float(offset_norms.std())

print(f"\n── FK → EEF structural offset ({EPISODE_KEY}) ─────────────")
print(f" FK node[8] to robot0_eef_pos distance:")
print(f" Mean : {mean_offset*1e3:.2f} mm (expected ~96.5 mm)")
print(f" Std : {std_offset*1e3:.3f} mm (should be < {FK_EEF_OFFSET_STD_TOL*1e3:.1f} mm)")
print(f" Max : {offset_norms.max()*1e3:.2f} mm")
print(f" Min : {offset_norms.min()*1e3:.2f} mm")

assert std_offset <= FK_EEF_OFFSET_STD_TOL, (
f"FK→EEF offset norm std ({std_offset*1e3:.2f} mm) exceeds "
f"{FK_EEF_OFFSET_STD_TOL*1e3:.1f} mm tolerance.\n"
f"The FK pipeline is not consistently tracking the EEF position. "
f"This may indicate: (1) base_link_shift/rotation is wrong, "
f"(2) calculate_panda_joints_positions uses an unexpected link frame, or "
f"(3) a bug in compute_node_pos_xyz."
)