diff --git a/egomimic/rldb/embodiment/embodiment.py b/egomimic/rldb/embodiment/embodiment.py index 24fab039..13dc1b8f 100644 --- a/egomimic/rldb/embodiment/embodiment.py +++ b/egomimic/rldb/embodiment/embodiment.py @@ -1,7 +1,11 @@ from abc import ABC from enum import Enum +import numpy as np +import torch + from egomimic.rldb.zarr.action_chunk_transforms import Transform +from egomimic.utils.type_utils import _to_numpy class EMBODIMENT(Enum): @@ -53,3 +57,67 @@ def viz_transformed_batch(batch): def get_keymap(): """Returns a dictionary mapping from the raw keys in the dataset to the canonical keys used by the model.""" raise NotImplementedError + + @classmethod + def viz_cartesian_gt_preds(cls, predictions, batch, image_key, action_key): + embodiment_id = batch["embodiment"][0].item() + embodiment_name = get_embodiment(embodiment_id).lower() + + images = batch[image_key] + actions = batch[action_key] + pred_actions = predictions[f"{embodiment_name}_{action_key}"] + ims_list = [] + images = _to_numpy(images) + actions = _to_numpy(actions) + pred_actions = _to_numpy(pred_actions) + for i in range(images.shape[0]): + image = images[i] + action = actions[i] + pred_action = pred_actions[i] + ims = cls.viz(image, action, mode="traj", color="Reds") + ims = cls.viz(ims, pred_action, mode="traj", color="Greens") + ims_list.append(ims) + ims = np.stack(ims_list, axis=0) + return ims + + @classmethod + def apply_transform(cls, batch, transform_list: list[Transform]): + if transform_list: + batch_size = None + for v in batch.values(): + if isinstance(v, (np.ndarray, torch.Tensor)): + batch_size = v.shape[0] + break + + if batch_size is not None: + # Apply transforms per-sample (matching how ZarrDataset applies them) + results = [] + for i in range(batch_size): + sample = { + k: (v[i].numpy() if isinstance(v, torch.Tensor) else v[i]) + if isinstance(v, (np.ndarray, torch.Tensor)) + else v + for k, v in batch.items() + } + for transform in transform_list: + sample = transform.transform(sample) + results.append(sample) + + batch = {} + for k in results[0]: + vals = [r[k] for r in results] + if isinstance(vals[0], np.ndarray): + batch[k] = np.stack(vals, axis=0) + elif isinstance(vals[0], torch.Tensor): + batch[k] = torch.stack(vals, dim=0) + else: + batch[k] = vals + else: + for transform in transform_list: + batch = transform.transform(batch) + + for k, v in batch.items(): + if isinstance(v, np.ndarray): + batch[k] = torch.from_numpy(v).to(torch.float32) + + return batch diff --git a/egomimic/rldb/embodiment/human.py b/egomimic/rldb/embodiment/human.py index 1939fd91..7fce24ba 100644 --- a/egomimic/rldb/embodiment/human.py +++ b/egomimic/rldb/embodiment/human.py @@ -1,19 +1,28 @@ from __future__ import annotations -from egomimic.rldb.embodiment.embodiment import Embodiment -from egomimic.rldb.embodiment.eva import ( - _viz_batch_palm_axes, - _viz_batch_palm_traj, -) +from typing import Literal + +import numpy as np + +from egomimic.rldb.embodiment.embodiment import Embodiment, get_embodiment from egomimic.rldb.zarr.action_chunk_transforms import ( ActionChunkCoordinateFrameTransform, ConcatKeys, DeleteKeys, InterpolatePose, PoseCoordinateFrameTransform, + Reshape, + SplitKeys, Transform, XYZWXYZ_to_XYZYPR, ) +from egomimic.utils.type_utils import _to_numpy +from egomimic.utils.viz_utils import ( + ColorPalette, + _viz_axes, + _viz_keypoints, + _viz_traj, +) class Human(Embodiment): @@ -22,75 +31,259 @@ class Human(Embodiment): ACTION_STRIDE = 3 @classmethod - def get_transform_list(cls) -> list[Transform]: - return _build_aria_bimanual_transform_list(stride=cls.ACTION_STRIDE) + def viz_keypoints_gt_preds( + cls, predictions, batch, image_key, action_key, transform_list=None, **kwargs + ): + if transform_list is not None: + batch = cls.apply_transform(batch, transform_list) + embodiment_id = batch["embodiment"][0].item() + embodiment_name = get_embodiment(embodiment_id).lower() + + images = batch[image_key] + actions = batch[action_key] + pred_actions = predictions[f"{embodiment_name}_{action_key}"] + ims_list = [] + images = _to_numpy(images) + actions = _to_numpy(actions) + pred_actions = _to_numpy(pred_actions) + for i in range(images.shape[0]): + image = images[i] + action = actions[i] + pred_action = pred_actions[i] + ims = cls.viz(image, action, mode="keypoints", color="Reds", **kwargs) + ims = cls.viz(ims, pred_action, mode="keypoints", color="Greens", **kwargs) + ims_list.append(ims) + ims = np.stack(ims_list, axis=0) + return ims @classmethod - def viz_transformed_batch(cls, batch, mode=""): - image_key = cls.VIZ_IMAGE_KEY - action_key = "actions_cartesian" + def viz_transformed_batch( + cls, + batch, + mode=Literal["traj", "axes", "keypoints"], + action_key="actions_cartesian", + image_key=None, + transform_list=None, + **kwargs, + ): + if transform_list is not None: + batch = cls.apply_transform(batch, transform_list) + + image_key = image_key or cls.VIZ_IMAGE_KEY + action_key = action_key or "actions_cartesian" intrinsics_key = cls.VIZ_INTRINSICS_KEY - mode = (mode or "palm_traj").lower() + mode = (mode or "traj").lower() + images = _to_numpy(batch[image_key][0]) + actions = _to_numpy(batch[action_key][0]) + + return cls.viz( + images=images, + actions=actions, + mode=mode, + intrinsics_key=intrinsics_key, + **kwargs, + ) - if mode == "palm_traj": - return _viz_batch_palm_traj( - batch=batch, - image_key=image_key, - action_key=action_key, + @classmethod + def viz( + cls, + images, + actions, + mode=Literal["traj", "axes", "keypoints"], + intrinsics_key=None, + **kwargs, + ): + intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY + if mode == "traj": + return _viz_traj( + images=images, + actions=actions, intrinsics_key=intrinsics_key, + **kwargs, ) - if mode == "palm_axes": - return _viz_batch_palm_axes( - batch=batch, - image_key=image_key, - action_key=action_key, + if mode == "axes": + return _viz_axes( + images=images, + actions=actions, intrinsics_key=intrinsics_key, + **kwargs, ) if mode == "keypoints": - raise NotImplementedError( - "mode='keypoints' is reserved and not implemented yet." - ) + color = kwargs.get("color", None) + if color is not None and ColorPalette.is_valid(color): + n = len(cls.FINGER_COLORS) + colors = { + finger: ColorPalette.to_rgb(color, value=(i + 1) / (n + 1)) + for i, finger in enumerate(cls.FINGER_COLORS) + } + dot_color = ColorPalette.to_rgb(color, value=0.7) + else: + colors = cls.FINGER_COLORS + dot_color = cls.DOT_COLOR + return _viz_keypoints( + images=images, + actions=actions, + intrinsics_key=intrinsics_key, + edges=cls.FINGER_EDGES, + edge_ranges=cls.FINGER_EDGE_RANGES, + colors=colors, + dot_color=dot_color, + **kwargs, + ) raise ValueError( f"Unsupported mode '{mode}'. Expected one of: " - f"('palm_traj', 'palm_axes', 'keypoints')." + f"('traj', 'axes', 'keypoints')." ) @classmethod - def get_keymap(cls): - return { - cls.VIZ_IMAGE_KEY: { - "key_type": "camera_keys", - "zarr_key": "images.front_1", - }, - "right.action_ee_pose": { - "key_type": "action_keys", - "zarr_key": "right.obs_ee_pose", - "horizon": 30, - }, - "left.action_ee_pose": { - "key_type": "action_keys", - "zarr_key": "left.obs_ee_pose", - "horizon": 30, - }, - "right.obs_ee_pose": { - "key_type": "proprio_keys", - "zarr_key": "right.obs_ee_pose", - }, - "left.obs_ee_pose": { - "key_type": "proprio_keys", - "zarr_key": "left.obs_ee_pose", - }, - "obs_head_pose": { - "key_type": "proprio_keys", - "zarr_key": "obs_head_pose", - }, - } + def get_keymap(cls, mode: Literal["cartesian", "keypoints"]): + if mode == "cartesian": + key_map = { + cls.VIZ_IMAGE_KEY: { + "key_type": "camera_keys", + "zarr_key": "images.front_1", + }, + "right.action_ee_pose": { + "key_type": "action_keys", + "zarr_key": "right.obs_ee_pose", + "horizon": 30, + }, + "left.action_ee_pose": { + "key_type": "action_keys", + "zarr_key": "left.obs_ee_pose", + "horizon": 30, + }, + "right.obs_ee_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_ee_pose", + }, + "left.obs_ee_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_ee_pose", + }, + "obs_head_pose": { + "key_type": "proprio_keys", + "zarr_key": "obs_head_pose", + }, + } + elif mode == "keypoints": + key_map = { + cls.VIZ_IMAGE_KEY: { + "key_type": "camera_keys", + "zarr_key": "images.front_1", + }, + "left.action_keypoints": { + "key_type": "action_keys", + "zarr_key": "left.obs_keypoints", + "horizon": 30, + }, + "right.action_keypoints": { + "key_type": "action_keys", + "zarr_key": "right.obs_keypoints", + "horizon": 30, + }, + "left.action_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_wrist_pose", + "horizon": 30, + }, + "right.action_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_wrist_pose", + "horizon": 30, + }, + "left.obs_keypoints": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_keypoints", + }, + "right.obs_keypoints": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_keypoints", + }, + "left.obs_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_wrist_pose", + }, + "right.obs_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_wrist_pose", + }, + "obs_head_pose": { + "key_type": "proprio_keys", + "zarr_key": "obs_head_pose", + }, + } + else: + raise ValueError( + f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'keypoints'." + ) + return key_map class Aria(Human): VIZ_INTRINSICS_KEY = "base" ACTION_STRIDE = 3 + FINGER_EDGES = [ + ( + 5, + 6, + ), + (6, 7), + (7, 0), # thumb + (5, 8), + (8, 9), + (9, 10), + (9, 1), # index + (5, 11), + (11, 12), + (12, 13), + (13, 2), # middle + (5, 14), + (14, 15), + (15, 16), + (16, 3), # ring + (5, 17), + (17, 18), + (18, 19), + (19, 4), # pinky + ] + FINGER_COLORS = { + "thumb": (255, 100, 100), # red + "index": (100, 255, 100), # green + "middle": (100, 100, 255), # blue + "ring": (255, 255, 100), # yellow + "pinky": (255, 100, 255), # magenta + } + FINGER_EDGE_RANGES = [ + ("thumb", 0, 3), + ("index", 3, 6), + ("middle", 6, 9), + ("ring", 9, 12), + ("pinky", 12, 15), + ] + DOT_COLOR = (255, 165, 0) + + @classmethod + def get_transform_list( + cls, mode: Literal["cartesian", "keypoints_headframe", "keypoints_wristframe"] + ) -> list[Transform]: + if mode == "cartesian": + return _build_aria_cartesian_bimanual_transform_list( + stride=cls.ACTION_STRIDE + ) + elif mode == "keypoints": + return _build_aria_keypoints_bimanual_transform_list( + stride=cls.ACTION_STRIDE + ) + elif mode == "keypoints_wristframe": + return _build_aria_keypoints_eef_frame_transform_list( + stride=cls.ACTION_STRIDE + ) + else: + raise ValueError( + f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'keypoints', 'keypoints_wristframe'." + ) class Scale(Human): @@ -103,7 +296,428 @@ class Mecka(Human): ACTION_STRIDE = 1 -def _build_aria_bimanual_transform_list( +def _build_aria_keypoints_revert_eef_frame_transform_list( + *, + action_key: str = "actions_keypoints", + left_keypoints_action_wristframe: str = "left.action_keypoints_wristframe", + right_keypoints_action_wristframe: str = "right.action_keypoints_wristframe", + left_wrist_obs_headframe: str = "left.obs_wrist_pose_headframe", + right_wrist_obs_headframe: str = "right.obs_wrist_pose_headframe", + left_wrist_action_headframe: str = "left.action_wrist_pose_headframe", + right_wrist_action_headframe: str = "right.action_wrist_pose_headframe", +) -> list[Transform]: + transform_list = [ + SplitKeys( + input_key=action_key, + output_key_list=[ + (left_keypoints_action_wristframe, 63), + (right_keypoints_action_wristframe, 63), + ], + ), + Reshape( + input_key=left_keypoints_action_wristframe, + output_key=left_keypoints_action_wristframe, + shape=(100, 21, 3), + ), + Reshape( + input_key=right_keypoints_action_wristframe, + output_key=right_keypoints_action_wristframe, + shape=(100, 21, 3), + ), + ActionChunkCoordinateFrameTransform( + target_world=left_wrist_obs_headframe, + chunk_world=left_keypoints_action_wristframe, + transformed_key_name=left_wrist_action_headframe, + mode="xyz", + inverse=False, + ), + ActionChunkCoordinateFrameTransform( + target_world=right_wrist_obs_headframe, + chunk_world=right_keypoints_action_wristframe, + transformed_key_name=right_wrist_action_headframe, + mode="xyz", + inverse=False, + ), + Reshape( + input_key=left_wrist_action_headframe, + output_key=left_wrist_action_headframe, + shape=(100, 63), + ), + Reshape( + input_key=right_wrist_action_headframe, + output_key=right_wrist_action_headframe, + shape=(100, 63), + ), + ConcatKeys( + key_list=[ + left_wrist_action_headframe, + right_wrist_action_headframe, + ], + new_key_name=action_key, + delete_old_keys=True, + ), + ] + return transform_list + + +def _build_aria_keypoints_eef_frame_transform_list( + *, + target_world: str = "obs_head_pose", + target_world_ypr: str = "obs_head_pose_ypr", + target_world_is_quat: bool = True, + left_keypoints_action_world: str = "left.action_keypoints", + right_keypoints_action_world: str = "right.action_keypoints", + left_keypoints_obs_pose: str = "left.obs_keypoints", + right_keypoints_obs_pose: str = "right.obs_keypoints", + left_keypoints_action_headframe: str = "left.action_keypoints_headframe", + right_keypoints_action_headframe: str = "right.action_keypoints_headframe", + left_keypoints_obs_headframe: str = "left.obs_keypoints_headframe", + right_keypoints_obs_headframe: str = "right.obs_keypoints_headframe", + left_wrist_action_world: str = "left.action_wrist_pose", + right_wrist_action_world: str = "right.action_wrist_pose", + left_keypoints_action_wristframe: str = "left.action_keypoints_wristframe", + right_keypoints_action_wristframe: str = "right.action_keypoints_wristframe", + left_wrist_action_wristframe: str = "left.action_wrist_pose_wristframe", + right_wrist_action_wristframe: str = "right.action_wrist_pose_wristframe", + left_wrist_obs_pose: str = "left.obs_wrist_pose", + right_wrist_obs_pose: str = "right.obs_wrist_pose", + left_wrist_action_headframe: str = "left.action_wrist_pose_headframe", + right_wrist_action_headframe: str = "right.action_wrist_pose_headframe", + left_wrist_obs_headframe: str = "left.obs_wrist_pose_headframe", + right_wrist_obs_headframe: str = "right.obs_wrist_pose_headframe", + left_keypoints_obs_wristframe: str = "left.obs_keypoints_wristframe", + right_keypoints_obs_wristframe: str = "right.obs_keypoints_wristframe", + delete_target_world: bool = True, + chunk_length: int = 100, + stride: int = 3, +) -> list[Transform]: + transform_list = _build_aria_keypoints_bimanual_transform_list( + target_world=target_world, + target_world_ypr=target_world_ypr, + target_world_is_quat=target_world_is_quat, + delete_target_world=delete_target_world, + chunk_length=chunk_length, + stride=stride, + concat_keys=False, + ) + delete_keys = [ + left_keypoints_action_world, + right_keypoints_action_world, + left_keypoints_obs_pose, + right_keypoints_obs_pose, + left_wrist_action_world, + right_wrist_action_world, + left_wrist_obs_pose, + right_wrist_obs_pose, + left_keypoints_action_headframe, + right_keypoints_action_headframe, + left_keypoints_obs_headframe, + right_keypoints_obs_headframe, + left_wrist_action_headframe, + right_wrist_action_headframe, + ] + if delete_target_world: + delete_keys.append(target_world) + if target_world_is_quat: + delete_keys.append(target_world_ypr) + transform_list.extend( + [ + Reshape( + input_key=left_keypoints_action_headframe, + output_key=left_keypoints_action_headframe, + shape=(chunk_length, 21, 3), + ), + Reshape( + input_key=right_keypoints_action_headframe, + output_key=right_keypoints_action_headframe, + shape=(chunk_length, 21, 3), + ), + ActionChunkCoordinateFrameTransform( + target_world=left_wrist_obs_headframe, + chunk_world=left_keypoints_action_headframe, + transformed_key_name=left_keypoints_action_wristframe, + mode="xyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=right_wrist_obs_headframe, + chunk_world=right_keypoints_action_headframe, + transformed_key_name=right_keypoints_action_wristframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_action_wristframe, + output_key=left_keypoints_action_wristframe, + shape=(chunk_length, 63), + ), + Reshape( + input_key=right_keypoints_action_wristframe, + output_key=right_keypoints_action_wristframe, + shape=(chunk_length, 63), + ), + Reshape( + input_key=left_keypoints_obs_headframe, + output_key=left_keypoints_obs_headframe, + shape=(21, 3), + ), + Reshape( + input_key=right_keypoints_obs_headframe, + output_key=right_keypoints_obs_headframe, + shape=(21, 3), + ), + PoseCoordinateFrameTransform( + target_world=left_wrist_obs_headframe, + pose_world=left_keypoints_obs_headframe, + transformed_key_name=left_keypoints_obs_wristframe, + mode="xyz", + ), + PoseCoordinateFrameTransform( + target_world=right_wrist_obs_headframe, + pose_world=right_keypoints_obs_headframe, + transformed_key_name=right_keypoints_obs_wristframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_obs_wristframe, + output_key=left_keypoints_obs_wristframe, + shape=(63,), + ), + Reshape( + input_key=right_keypoints_obs_wristframe, + output_key=right_keypoints_obs_wristframe, + shape=(63,), + ), + ] + ) + transform_list.extend( + [ + ConcatKeys( + key_list=[ + left_keypoints_action_wristframe, + right_keypoints_action_wristframe, + ], + new_key_name="actions_keypoints", + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_keypoints_obs_wristframe, + right_keypoints_obs_wristframe, + ], + new_key_name="observations.state.keypoints", + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_wrist_obs_headframe, + right_wrist_obs_headframe, + ], + new_key_name="observations.state.wrist_pose", + delete_old_keys=False, + ), + DeleteKeys(keys_to_delete=delete_keys), + ] + ) + return transform_list + + +def _build_aria_keypoints_bimanual_transform_list( + *, + target_world: str = "obs_head_pose", + target_world_ypr: str = "obs_head_pose_ypr", + target_world_is_quat: bool = True, + left_keypoints_action_world: str = "left.action_keypoints", + right_keypoints_action_world: str = "right.action_keypoints", + left_keypoints_obs_pose: str = "left.obs_keypoints", + right_keypoints_obs_pose: str = "right.obs_keypoints", + left_keypoints_action_headframe: str = "left.action_keypoints_headframe", + right_keypoints_action_headframe: str = "right.action_keypoints_headframe", + left_keypoints_obs_headframe: str = "left.obs_keypoints_headframe", + right_keypoints_obs_headframe: str = "right.obs_keypoints_headframe", + left_wrist_action_world: str = "left.action_wrist_pose", + right_wrist_action_world: str = "right.action_wrist_pose", + left_wrist_obs_pose: str = "left.obs_wrist_pose", + right_wrist_obs_pose: str = "right.obs_wrist_pose", + left_wrist_action_headframe: str = "left.action_wrist_pose_headframe", + right_wrist_action_headframe: str = "right.action_wrist_pose_headframe", + left_wrist_obs_headframe: str = "left.obs_wrist_pose_headframe", + right_wrist_obs_headframe: str = "right.obs_wrist_pose_headframe", + delete_target_world: bool = True, + chunk_length: int = 100, + stride: int = 3, + concat_keys: bool = True, +) -> list[Transform]: + keys_to_delete = list( + { + left_keypoints_action_world, + right_keypoints_action_world, + left_keypoints_obs_pose, + right_keypoints_obs_pose, + left_wrist_action_world, + right_wrist_action_world, + left_wrist_obs_pose, + right_wrist_obs_pose, + left_keypoints_action_headframe, + right_keypoints_action_headframe, + left_keypoints_obs_headframe, + right_keypoints_obs_headframe, + left_wrist_action_headframe, + right_wrist_action_headframe, + left_wrist_obs_headframe, + right_wrist_obs_headframe, + } + ) + if delete_target_world: + keys_to_delete.append(target_world) + if target_world_is_quat: + keys_to_delete.append(target_world_ypr) + transform_list: list[Transform] = [ + Reshape( + input_key=left_keypoints_action_world, + output_key=left_keypoints_action_world, + shape=(30, 21, 3), + ), + Reshape( + input_key=right_keypoints_action_world, + output_key=right_keypoints_action_world, + shape=(30, 21, 3), + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=left_keypoints_action_world, + transformed_key_name=left_keypoints_action_headframe, + mode="xyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=right_keypoints_action_world, + transformed_key_name=right_keypoints_action_headframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_obs_pose, + output_key=left_keypoints_obs_pose, + shape=(21, 3), + ), + Reshape( + input_key=right_keypoints_obs_pose, + output_key=right_keypoints_obs_pose, + shape=(21, 3), + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=left_keypoints_obs_pose, + transformed_key_name=left_keypoints_obs_headframe, + mode="xyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=right_keypoints_obs_pose, + transformed_key_name=right_keypoints_obs_headframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_obs_headframe, + output_key=left_keypoints_obs_headframe, + shape=(63,), + ), + Reshape( + input_key=right_keypoints_obs_headframe, + output_key=right_keypoints_obs_headframe, + shape=(63,), + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=left_keypoints_action_headframe, + output_action_key=left_keypoints_action_headframe, + stride=stride, + mode="xyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=right_keypoints_action_headframe, + output_action_key=right_keypoints_action_headframe, + stride=stride, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_action_headframe, + output_key=left_keypoints_action_headframe, + shape=(chunk_length, 63), + ), + Reshape( + input_key=right_keypoints_action_headframe, + output_key=right_keypoints_action_headframe, + shape=(chunk_length, 63), + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=left_wrist_action_world, + transformed_key_name=left_wrist_action_headframe, + mode="xyzwxyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=right_wrist_action_world, + transformed_key_name=right_wrist_action_headframe, + mode="xyzwxyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=left_wrist_obs_pose, + transformed_key_name=left_wrist_obs_headframe, + mode="xyzwxyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=right_wrist_obs_pose, + transformed_key_name=right_wrist_obs_headframe, + mode="xyzwxyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=left_wrist_action_headframe, + output_action_key=left_wrist_action_headframe, + stride=stride, + mode="xyzwxyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=right_wrist_action_headframe, + output_action_key=right_wrist_action_headframe, + stride=stride, + mode="xyzwxyz", + ), + ] + if concat_keys: + transform_list.extend( + [ + ConcatKeys( + key_list=[ + left_wrist_action_headframe, + left_keypoints_action_headframe, + right_wrist_action_headframe, + right_keypoints_action_headframe, + ], + new_key_name="actions_keypoints", + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_wrist_obs_headframe, + left_keypoints_obs_headframe, + right_wrist_obs_headframe, + right_keypoints_obs_headframe, + ], + new_key_name="observations.state.keypoints", + delete_old_keys=True, + ), + DeleteKeys(keys_to_delete=keys_to_delete), + ] + ) + return transform_list + + +def _build_aria_cartesian_bimanual_transform_list( *, target_world: str = "obs_head_pose", target_world_ypr: str = "obs_head_pose_ypr", @@ -147,39 +761,39 @@ def _build_aria_bimanual_transform_list( target_world=target_pose_key, chunk_world=left_action_world, transformed_key_name=left_action_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), ActionChunkCoordinateFrameTransform( target_world=target_pose_key, chunk_world=right_action_world, transformed_key_name=right_action_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), PoseCoordinateFrameTransform( target_world=target_pose_key, pose_world=left_obs_pose, transformed_key_name=left_obs_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), PoseCoordinateFrameTransform( target_world=target_pose_key, pose_world=right_obs_pose, transformed_key_name=right_obs_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), InterpolatePose( new_chunk_length=chunk_length, action_key=left_action_headframe, output_action_key=left_action_headframe, stride=stride, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), InterpolatePose( new_chunk_length=chunk_length, action_key=right_action_headframe, output_action_key=right_action_headframe, stride=stride, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), ] diff --git a/egomimic/rldb/zarr/action_chunk_transforms.py b/egomimic/rldb/zarr/action_chunk_transforms.py index ea0ce6ad..6ff5fdc2 100644 --- a/egomimic/rldb/zarr/action_chunk_transforms.py +++ b/egomimic/rldb/zarr/action_chunk_transforms.py @@ -13,19 +13,24 @@ from __future__ import annotations from abc import abstractmethod +from typing import Literal import numpy as np +import torch from projectaria_tools.core.sophus import SE3 from scipy.spatial.transform import Rotation as R -import torch from egomimic.utils.pose_utils import ( _interpolate_euler, _interpolate_linear, _interpolate_quat_wxyz, + _interpolate_xyz, + _matrix_to_xyz, _matrix_to_xyzwxyz, _matrix_to_xyzypr, + _xyz_to_matrix, _xyzwxyz_to_matrix, + _xyzypr_to_matrix, ) # --------------------------------------------------------------------------- @@ -56,7 +61,7 @@ def __init__( action_key: str, output_action_key: str, stride: int = 1, - is_quat: bool = False, + mode: Literal["xyzwxyz", "xyzypr"] = "xyzwxyz", ): if stride <= 0: raise ValueError(f"stride must be positive, got {stride}") @@ -64,12 +69,12 @@ def __init__( self.action_key = action_key self.output_action_key = output_action_key self.stride = int(stride) - self.is_quat = is_quat + self.mode = mode def transform(self, batch: dict) -> dict: actions = np.asarray(batch[self.action_key]) actions = actions[:: self.stride] - if self.is_quat: + if self.mode == "xyzwxyz": if actions.ndim != 2 or actions.shape[-1] != 7: raise ValueError( f"InterpolatePose expects (T, 7) when is_quat=True, got " @@ -78,7 +83,7 @@ def transform(self, batch: dict) -> dict: batch[self.output_action_key] = _interpolate_quat_wxyz( actions, self.new_chunk_length ) - else: + elif self.mode == "xyzypr": if actions.ndim != 2 or actions.shape[-1] != 6: raise ValueError( f"InterpolatePose expects (T, 6), got {actions.shape} for key " @@ -87,6 +92,15 @@ def transform(self, batch: dict) -> dict: batch[self.output_action_key] = _interpolate_euler( actions, self.new_chunk_length ) + else: + if actions.shape[-1] != 3: + raise ValueError( + f"InterpolatePose expects (T, 3) or (T, K, 3), got {actions.shape} for key " + f"'{self.action_key}'" + ) + batch[self.output_action_key] = _interpolate_xyz( + actions, self.new_chunk_length + ) return batch @@ -126,27 +140,6 @@ def transform(self, batch: dict) -> dict: # --------------------------------------------------------------------------- -def _xyzypr_to_matrix(xyzypr: np.ndarray) -> np.ndarray: - """ - args: - xyzypr: (B, 6) np.array of [[x, y, z, yaw, pitch, roll]] - returns: - (B, 4, 4) array of SE3 transformation matrices - """ - if xyzypr.ndim != 2 or xyzypr.shape[-1] != 6: - raise ValueError(f"Expected (B, 6) array, got shape {xyzypr.shape}") - - B = xyzypr.shape[0] - dtype = xyzypr.dtype if np.issubdtype(xyzypr.dtype, np.floating) else np.float64 - - mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() - # Input is [yaw, pitch, roll], so use ZYX order (Rz @ Ry @ Rx). - mats[:, :3, :3] = R.from_euler("ZYX", xyzypr[:, 3:6], degrees=False).as_matrix() - mats[:, :3, 3] = xyzypr[:, :3] - - return mats - - class ActionChunkCoordinateFrameTransform(Transform): def __init__( self, @@ -154,7 +147,8 @@ def __init__( chunk_world: str, transformed_key_name: str, extra_batch_key: dict = None, - is_quat: bool = False, + mode: Literal["xyz", "xyzwxyz", "xyzypr"] = "xyzwxyz", + inverse: bool = True, ): """ args: @@ -167,7 +161,8 @@ def __init__( self.chunk_world = chunk_world self.transformed_key_name = transformed_key_name self.extra_batch_key = extra_batch_key - self.is_quat = is_quat + self.mode = mode + self.inverse = inverse def transform(self, batch): """ @@ -183,25 +178,56 @@ def transform(self, batch): if is_quat=False: (T, 6) xyz + ypr if is_quat=True: (T, 7) xyz + quat(wxyz) """ + # flatten to (T, D) + # target world is head pose, chunk world is keypoints batch.update(self.extra_batch_key or {}) target_world = np.asarray(batch[self.target_world]) chunk_world = np.asarray(batch[self.chunk_world]) - to_matrix_fn = _xyzwxyz_to_matrix if self.is_quat else _xyzypr_to_matrix + chunk_world_shape = None + + if chunk_world.ndim > 2: + chunk_world_shape = chunk_world.shape + chunk_world = chunk_world.reshape(-1, chunk_world_shape[-1]) + + to_matrix_fn = None + if self.mode == "xyzwxyz": + to_matrix_fn = _xyzwxyz_to_matrix + elif self.mode == "xyzypr": + to_matrix_fn = _xyzypr_to_matrix + elif self.mode == "xyz": + to_matrix_fn = _xyz_to_matrix + else: + raise ValueError(f"Invalid mode: {self.mode}") + target_world_to_matrix_fn = ( + _xyzwxyz_to_matrix if target_world.shape[-1] == 7 else _xyzypr_to_matrix + ) # Convert to SE3 for transformation - target_se3 = SE3.from_matrix(to_matrix_fn(target_world[None, :])[0]) # (4, 4) + target_se3 = SE3.from_matrix( + target_world_to_matrix_fn(target_world[None, :])[0] + ) # (4, 4) chunk_se3 = SE3.from_matrix(to_matrix_fn(chunk_world)) # (T, 4, 4) # Compute relative transform and apply to chunk - chunk_in_target_frame = target_se3.inverse() @ chunk_se3 + if self.inverse: + chunk_in_target_frame = target_se3.inverse() @ chunk_se3 + else: + chunk_in_target_frame = target_se3 @ chunk_se3 chunk_mats = chunk_in_target_frame.to_matrix() if chunk_mats.ndim == 2: chunk_mats = chunk_mats[None, ...] - chunk_in_target_frame = ( - _matrix_to_xyzwxyz(chunk_mats) - if self.is_quat - else _matrix_to_xyzypr(chunk_mats) - ) + + if self.mode == "xyzwxyz": + chunk_in_target_frame = _matrix_to_xyzwxyz(chunk_mats) + elif self.mode == "xyzypr": + chunk_in_target_frame = _matrix_to_xyzypr(chunk_mats) + elif self.mode == "xyz": + chunk_in_target_frame = _matrix_to_xyz(chunk_mats) + else: + raise ValueError(f"Invalid mode: {self.mode}") + + if chunk_world_shape is not None: + chunk_in_target_frame = chunk_in_target_frame.reshape(*chunk_world_shape) # Store transformed chunk back in batch batch[self.transformed_key_name] = chunk_in_target_frame @@ -237,27 +263,21 @@ def __init__( target_world: str, pose_world: str, transformed_key_name: str, - is_quat: bool = False, + mode: Literal["xyzwxyz", "xyzypr", "xyz"] = "xyzwxyz", ): self.target_world = target_world self.pose_world = pose_world self.transformed_key_name = transformed_key_name - self.is_quat = is_quat + self.mode = mode self._chunk_transform = ActionChunkCoordinateFrameTransform( target_world=target_world, chunk_world=pose_world, transformed_key_name=transformed_key_name, - is_quat=is_quat, + mode=mode, ) def transform(self, batch: dict) -> dict: pose_world = np.asarray(batch[self.pose_world]) - expected_shape = (7,) if self.is_quat else (6,) - if pose_world.shape != expected_shape: - raise ValueError( - f"Expected pose_world shape {expected_shape}, got {pose_world.shape}" - ) - transformed = self._chunk_transform.transform( { self.target_world: batch[self.target_world], @@ -390,6 +410,17 @@ def transform(self, batch): # --------------------------------------------------------------------------- # Shape Transforms # --------------------------------------------------------------------------- +class SplitKeys(Transform): + def __init__(self, input_key: str, output_key_list: list[(str, int)]): + self.input_key = input_key + self.output_key_list = list(output_key_list) + + def transform(self, batch: dict) -> dict: + prev_end = 0 + for key, size in self.output_key_list: + batch[key] = batch[self.input_key][..., prev_end : prev_end + size] + prev_end += size + return batch class ConcatKeys(Transform): @@ -414,10 +445,23 @@ def transform(self, batch): return batch + +class Reshape(Transform): + def __init__(self, input_key: str, output_key: str, shape: tuple): + self.input_key = input_key + self.output_key = output_key + self.shape = shape + + def transform(self, batch: dict) -> dict: + batch[self.output_key] = batch[self.input_key].reshape(*self.shape) + return batch + + # --------------------------------------------------------------------------- # Type Transforms # --------------------------------------------------------------------------- + class NumpyToTensor(Transform): def __init__(self, keys: list[str]): self.keys = keys @@ -429,5 +473,7 @@ def transform(self, batch: dict) -> dict: elif isinstance(batch[key], torch.Tensor): batch[key] = batch[key].clone() else: - raise ValueError(f"NumpyToTensor expects key '{key}' to be a numpy array or torch tensor, got {type(batch[key])}") + raise ValueError( + f"NumpyToTensor expects key '{key}' to be a numpy array or torch tensor, got {type(batch[key])}" + ) return batch diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index bb7b08b7..41a0e729 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -738,12 +738,11 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: if self.transform: for transform in self.transform or []: try: + # breakpoint() data = transform.transform(data) except Exception as e: logger.error(f"Error transforming data: {e}") - logger.error(f"Data: {data}") logger.error(f"Transform: {transform}") - logger.error(f"Error: {e}") if idx == 0: logger.error("Error in first frame") raise e diff --git a/egomimic/scripts/tutorials/zarr_data_viz.ipynb b/egomimic/scripts/tutorials/zarr_data_viz.ipynb index 45b93fb3..c16fcfb9 100644 --- a/egomimic/scripts/tutorials/zarr_data_viz.ipynb +++ b/egomimic/scripts/tutorials/zarr_data_viz.ipynb @@ -258,7 +258,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "emimic (3.11.14)", "language": "python", "name": "python3" }, diff --git a/egomimic/scripts/zarr_data_viz.ipynb b/egomimic/scripts/zarr_data_viz.ipynb new file mode 100644 index 00000000..1a1e4302 --- /dev/null +++ b/egomimic/scripts/zarr_data_viz.ipynb @@ -0,0 +1,420 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "79d184b3", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "29aeeb40", + "metadata": {}, + "source": [ + "# Eva Data\n", + "\n", + "This notebook builds a `MultiDataset` containing exactly one `ZarrDataset`, loads one batch, visualizes one image with `mediapy`, and prints the rest of the batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32d9110f", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import cv2\n", + "import imageio_ffmpeg\n", + "import mediapy as mpy\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from egomimic.rldb.embodiment.eva import Eva\n", + "from egomimic.rldb.embodiment.human import Aria\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset, ZarrDataset\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import S3EpisodeResolver\n", + "from egomimic.utils.egomimicUtils import (\n", + " INTRINSICS,\n", + " cam_frame_to_cam_pixels,\n", + " nds,\n", + ")\n", + "from egomimic.utils.aws.aws_data_utils import load_env\n", + "\n", + "# Ensure mediapy can find an ffmpeg executable in this environment\n", + "mpy.set_ffmpeg(imageio_ffmpeg.get_ffmpeg_exe())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc9edba1", + "metadata": {}, + "outputs": [], + "source": [ + "TEMP_DIR = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n", + "load_env()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4aa1a05", + "metadata": {}, + "outputs": [], + "source": [ + "# Point this at a single episode directory, e.g. /path/to/episode_hash.zarr\n", + "# EPISODE_PATH = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/1767495035712.zarr\")\n", + "\n", + "key_map = Eva.get_keymap()\n", + "transform_list = Eva.get_transform_list()\n", + "\n", + "# Build a MultiDataset with exactly one ZarrDataset inside\n", + "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n", + "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n", + "\n", + "# multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n", + "resolver = S3EpisodeResolver(\n", + " TEMP_DIR, key_map=key_map, transform_list=transform_list\n", + ")\n", + "filters = {\n", + " \"episode_hash\": \"2025-12-26-18-07-46-296000\"\n", + "}\n", + "multi_ds = MultiDataset._from_resolver(\n", + " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", + ")\n", + "\n", + "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b72f3bb", + "metadata": {}, + "outputs": [], + "source": [ + "# Separate YPR visualization preview\n", + "for batch in loader:\n", + " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"axes\")\n", + " mpy.show_image(vis_ypr)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d8c3da2", + "metadata": {}, + "outputs": [], + "source": [ + "images = []\n", + "for i, batch in enumerate(loader):\n", + " vis = Eva.viz_transformed_batch(batch, mode=\"traj\")\n", + " images.append(vis)\n", + " if i > 10:\n", + " break\n", + "\n", + "mpy.show_video(images, fps=30)" + ] + }, + { + "cell_type": "markdown", + "id": "1a3382f1", + "metadata": {}, + "source": [ + "## Human Datasets\n", + "Mecka, Scale and Aria should all run exactly the same" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7384468", + "metadata": {}, + "outputs": [], + "source": [ + "temp_dir = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n", + "\n", + "intrinsics_key = \"base\"\n", + "\n", + "key_map = Aria.get_keymap(mode=\"keypoints\")\n", + "transform_list = Aria.get_transform_list(mode=\"keypoints\")\n", + "\n", + "resolver = S3EpisodeResolver(\n", + " temp_dir,\n", + " key_map=key_map,\n", + " transform_list=transform_list,\n", + ")\n", + "\n", + "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"} #aria\n", + "# filters = {\"episode_hash\": \"692ee048ef7557106e6c4b8d\"} # mecka\n", + "\n", + "cloudflare_ds = MultiDataset._from_resolver(\n", + " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", + ")\n", + "\n", + "loader = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af65095a", + "metadata": {}, + "outputs": [], + "source": [ + "ims = []\n", + "for i, batch in enumerate(loader):\n", + " vis = Aria.viz_transformed_batch(batch, mode=\"traj\")\n", + " ims.append(vis)\n", + " if i > 10:\n", + " break\n", + "\n", + "mpy.show_video(ims, fps=30)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6d8d872", + "metadata": {}, + "outputs": [], + "source": [ + "# Aria YPR video (same data loop, YPR overlay)\n", + "ims_ypr = []\n", + "for i, batch in enumerate(loader):\n", + " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"axes\")\n", + " ims_ypr.append(vis_ypr)\n", + " if i > 20:\n", + " break\n", + "\n", + "mpy.show_video(ims_ypr, fps=30)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60723adf", + "metadata": {}, + "outputs": [], + "source": [ + "ims_keypoints = []\n", + "for i, batch in enumerate(loader):\n", + " vis_keypoints = Aria.viz_transformed_batch(batch, mode=\"keypoints\")\n", + " ims_keypoints.append(vis_keypoints)\n", + " if i > 360:\n", + " break\n", + "\n", + "mpy.show_video(ims_keypoints, fps=20)" + ] + }, + { + "cell_type": "markdown", + "id": "efecaba7", + "metadata": {}, + "source": [ + "## Keypoint Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e39bca03", + "metadata": {}, + "outputs": [], + "source": [ + "# Load Scale episode with raw keypoints (no action chunking needed)\n", + "\n", + "from egomimic.rldb.zarr.action_chunk_transforms import _xyzwxyz_to_matrix\n", + "\n", + "key_map_kp = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"left.obs_keypoints\": {\"zarr_key\": \"left.obs_keypoints\"},\n", + " \"right.obs_keypoints\": {\"zarr_key\": \"right.obs_keypoints\"},\n", + " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n", + "}\n", + "\n", + "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"}\n", + "\n", + "resolver = S3EpisodeResolver(\n", + " temp_dir,\n", + " key_map=key_map\n", + ")\n", + "\n", + "cloudflare_ds = MultiDataset._from_resolver(\n", + " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", + ")\n", + "\n", + "loader_kp = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "848c6d74", + "metadata": {}, + "outputs": [], + "source": [ + "# ARIA Keypoint Viz\n", + "# MANO skeleton edges: (parent, child) for drawing bones\n", + "MANO_EDGES = [\n", + " (0, 1), (1, 2), (2, 3), (3, 4), # thumb\n", + " (0, 5), (5, 6), (6, 7), (7, 8), # index\n", + " (0, 9), (9, 10), (10, 11), (11, 12), # middle\n", + " (0, 13), (13, 14), (14, 15), (15, 16), # ring\n", + " (0, 17), (17, 18), (18, 19), (19, 20), # pinky\n", + "]\n", + "\n", + "# aria configuration\n", + "MANO_EDGES = [\n", + " (5, 6,), (6, 7), (7, 0), # thumb\n", + " (5, 8), (8, 9), (9, 10), (9, 1), # index\n", + " (5, 11), (11, 12), (12, 13), (13, 2), # middle\n", + " (5, 14), (14, 15), (15, 16), (16, 3), # ring\n", + " (5, 17), (17, 18), (18, 19), (19, 4), # pinky\n", + "]\n", + "\n", + "FINGER_COLORS = {\n", + " \"thumb\": (255, 100, 100), # red\n", + " \"index\": (100, 255, 100), # green\n", + " \"middle\": (100, 100, 255), # blue\n", + " \"ring\": (255, 255, 100), # yellow\n", + " \"pinky\": (255, 100, 255), # magenta\n", + "}\n", + "FINGER_EDGE_RANGES = [\n", + " (\"thumb\", 0, 3), (\"index\", 3, 6), (\"middle\", 6, 9),\n", + " (\"ring\", 9, 12), (\"pinky\", 12, 15),\n", + "]\n", + "\n", + "\n", + "def viz_keypoints(batch, image_key=\"observations.images.front_img_1\"):\n", + " \"\"\"Visualize all 21 MANO keypoints per hand, projected onto the image.\"\"\"\n", + " # Prepare image\n", + " img = batch[image_key][0].detach().cpu()\n", + " if img.shape[0] in (1, 3):\n", + " img = img.permute(1, 2, 0)\n", + " img_np = img.numpy()\n", + " if img_np.dtype != np.uint8:\n", + " if img_np.max() <= 1.0:\n", + " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", + " else:\n", + " img_np = img_np.clip(0, 255).astype(np.uint8)\n", + " if img_np.shape[-1] == 1:\n", + " img_np = np.repeat(img_np, 3, axis=-1)\n", + "\n", + " intrinsics = INTRINSICS[\"base\"]\n", + " head_pose = batch[\"obs_head_pose\"][0].detach().cpu().numpy() # (6,)\n", + "\n", + " # T_head_world: camera pose in world (camera-to-world)\n", + " # We need world-to-camera = inv(T_head_world)\n", + " T_head_world = _xyzwxyz_to_matrix(head_pose[None, :])[0] # (4, 4)\n", + " T_world_to_cam = np.linalg.inv(T_head_world)\n", + "\n", + " vis = img_np.copy()\n", + " h, w = vis.shape[:2]\n", + "\n", + " for hand, dot_color in [(\"left\", (0, 120, 255)), (\"right\", (255, 80, 0))]:\n", + " kps_key = f\"{hand}.obs_keypoints\"\n", + " if kps_key not in batch:\n", + " continue\n", + " kps_flat = batch[kps_key][0].detach().cpu().numpy() # (63,)\n", + " kps_world = kps_flat.reshape(21, 3)\n", + "\n", + " # Skip if keypoints are all zero (invalid, clamped from 1e9)\n", + " if np.allclose(kps_world, 0.0, atol=1e-3):\n", + " continue\n", + "\n", + " # World -> camera frame\n", + " kps_h = np.concatenate([kps_world, np.ones((21, 1))], axis=1) # (21, 4)\n", + " kps_cam = (T_world_to_cam @ kps_h.T).T[:, :3] # (21, 3)\n", + "\n", + " # Camera frame -> pixels\n", + " kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (21, 3+)\n", + "\n", + " # Identify valid keypoints (z > 0 and in image bounds)\n", + " valid = (kps_cam[:, 2] > 0.01)\n", + " valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w)\n", + " valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h)\n", + "\n", + " # Draw skeleton edges (colored by finger)\n", + " for finger, start, end in FINGER_EDGE_RANGES:\n", + " color = FINGER_COLORS[finger]\n", + " for edge_idx in range(start, end):\n", + " i, j = MANO_EDGES[edge_idx]\n", + " if valid[i] and valid[j]:\n", + " p1 = (int(kps_px[i, 0]), int(kps_px[i, 1]))\n", + " p2 = (int(kps_px[j, 0]), int(kps_px[j, 1]))\n", + " cv2.line(vis, p1, p2, color, 2)\n", + "\n", + " # Draw keypoint dots on top\n", + " for k in range(21):\n", + " if valid[k]:\n", + " center = (int(kps_px[k, 0]), int(kps_px[k, 1]))\n", + " cv2.circle(vis, center, 4, dot_color, -1)\n", + " cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border\n", + "\n", + " # Label wrist\n", + " if valid[0]:\n", + " wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6)\n", + " cv2.putText(vis, f\"{hand[0].upper()}\", wrist_px,\n", + " cv2.FONT_HERSHEY_SIMPLEX, 0.5, dot_color, 2)\n", + "\n", + " return vis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75dbfa95", + "metadata": {}, + "outputs": [], + "source": [ + "# Render keypoint video\n", + "ims_kp = []\n", + "for i, batch_kp in enumerate(loader_kp):\n", + " vis = viz_keypoints(batch_kp)\n", + " ims_kp.append(vis)\n", + " if i > 10:\n", + " break\n", + "\n", + "mpy.show_video(ims_kp, fps=30)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f4fbaec", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/egomimic/utils/pose_utils.py b/egomimic/utils/pose_utils.py index ac284e1a..1c9e142c 100644 --- a/egomimic/utils/pose_utils.py +++ b/egomimic/utils/pose_utils.py @@ -80,6 +80,14 @@ def _interpolate_quat_wxyz(seq: np.ndarray, chunk_length: int) -> np.ndarray: ) +def _interpolate_xyz(seq: np.ndarray, chunk_length: int) -> np.ndarray: + """Linear interpolation for arbitrary (T, 3) arrays or (T, K, 3) arrays.""" + T = seq.shape[0] + old_time = np.linspace(0, 1, T) + new_time = np.linspace(0, 1, chunk_length) + return interp1d(old_time, seq, axis=0, kind="linear")(new_time) + + def _matrix_to_xyzypr(mats: np.ndarray) -> np.ndarray: """ args: @@ -99,6 +107,24 @@ def _matrix_to_xyzypr(mats: np.ndarray) -> np.ndarray: return np.concatenate([xyz, ypr], axis=-1).astype(dtype, copy=False) +def _xyzypr_to_matrix(xyzypr: np.ndarray) -> np.ndarray: + """ + args: + xyzypr: (B, 6) np.array of [[x, y, z, yaw, pitch, roll]] + returns: + (B, 4, 4) array of SE3 transformation matrices + """ + if xyzypr.ndim != 2 or xyzypr.shape[-1] != 6: + raise ValueError(f"Expected (B, 6) array, got shape {xyzypr.shape}") + B = xyzypr.shape[0] + dtype = xyzypr.dtype if np.issubdtype(xyzypr.dtype, np.floating) else np.float64 + + mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() + mats[:, :3, :3] = R.from_euler("ZYX", xyzypr[:, 3:6], degrees=False).as_matrix() + mats[:, :3, 3] = xyzypr[:, :3] + return mats + + def _matrix_to_xyzwxyz(mats: np.ndarray) -> np.ndarray: """ args: @@ -149,3 +175,72 @@ def T_rot_orientation(T: np.ndarray, rot_orientation: np.ndarray) -> np.ndarray: rot = rot @ rot_orientation T[:3, :3] = rot return T + +def _xyz_to_matrix(xyz: np.ndarray) -> np.ndarray: + """ + args: + xyz: (B, 3) np.array of [[x, y, z]] + returns: + (B, 4, 4) array of SE3 transformation matrices + """ + if xyz.ndim != 2 or xyz.shape[-1] != 3: + raise ValueError(f"Expected (B, 3) array, got shape {xyz.shape}") + B = xyz.shape[0] + dtype = xyz.dtype if np.issubdtype(xyz.dtype, np.floating) else np.float64 + mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() + mats[:, :3, 3] = xyz + return mats + + +def _matrix_to_xyz(mats: np.ndarray) -> np.ndarray: + """ + args: + mats: (B, 4, 4) array of SE3 transformation matrices + returns: + (B, 3) np.array of [[x, y, z]] + """ + if mats.ndim != 3 or mats.shape[-2:] != (4, 4): + raise ValueError(f"Expected (B, 4, 4) array, got shape {mats.shape}") + mats = np.asarray(mats) + dtype = mats.dtype if np.issubdtype(mats.dtype, np.floating) else np.float64 + return mats[:, :3, 3].astype(dtype, copy=False) + + +def _split_action_pose(actions): + # 14D layout: [L xyz ypr g, R xyz ypr g] + # 12D layout: [L xyz ypr, R xyz ypr] + if actions.shape[-1] == 14: + left_xyz = actions[..., :3] + left_ypr = actions[..., 3:6] + right_xyz = actions[..., 7:10] + right_ypr = actions[..., 10:13] + elif actions.shape[-1] == 12: + left_xyz = actions[..., :3] + left_ypr = actions[..., 3:6] + right_xyz = actions[..., 6:9] + right_ypr = actions[..., 9:12] + else: + raise ValueError(f"Unsupported action dim {actions.shape[-1]}") + return left_xyz, left_ypr, right_xyz, right_ypr + + +def _split_keypoints(keypoints, wrist_in_data: bool = False): + if wrist_in_data: + left_xyz = keypoints[..., :3] + left_wxyz = keypoints[..., 3:7] + left_keypoints = keypoints[..., 7:70] + right_xyz = keypoints[..., 70:73] + right_wxyz = keypoints[..., 73:77] + right_keypoints = keypoints[..., 77:140] + return ( + left_xyz, + left_wxyz, + left_keypoints, + right_xyz, + right_wxyz, + right_keypoints, + ) + else: + left_keypoints = keypoints[..., :63] + right_keypoints = keypoints[..., 63:] + return left_keypoints, right_keypoints diff --git a/egomimic/utils/type_utils.py b/egomimic/utils/type_utils.py new file mode 100644 index 00000000..e6608ecb --- /dev/null +++ b/egomimic/utils/type_utils.py @@ -0,0 +1,11 @@ +import numpy as np + + +def _to_numpy(arr): + if hasattr(arr, "detach"): + arr = arr.detach() + if hasattr(arr, "cpu"): + arr = arr.cpu() + if hasattr(arr, "numpy"): + return arr.numpy() + return np.asarray(arr) diff --git a/egomimic/utils/viz_utils.py b/egomimic/utils/viz_utils.py new file mode 100644 index 00000000..da6328fc --- /dev/null +++ b/egomimic/utils/viz_utils.py @@ -0,0 +1,232 @@ +import cv2 +import matplotlib.pyplot as plt +import numpy as np +from scipy.spatial.transform import Rotation as R + +from egomimic.utils.egomimicUtils import ( + INTRINSICS, + cam_frame_to_cam_pixels, + draw_actions, +) +from egomimic.utils.pose_utils import _split_action_pose, _split_keypoints + + +class ColorPalette: + Blues = "Blues" + Greens = "Greens" + Reds = "Reds" + Oranges = "Oranges" + Purples = "Purples" + Greys = "Greys" + + @classmethod + def is_valid(cls, name: str) -> bool: + return name in vars(cls).values() + + @classmethod + def to_rgb(cls, cmap_name: str, value: float = 0.7) -> tuple[int, int, int]: + """Convert a ColorPalette cmap name to an RGB tuple (0-255). + value: 0-1, where higher = darker shade.""" + rgba = plt.get_cmap(cmap_name)(value) + return tuple(int(c * 255) for c in rgba[:3]) + + +def _prepare_viz_image(img): + if img.ndim == 3 and img.shape[0] in (1, 3): + img = np.transpose(img, (1, 2, 0)) + + if img.dtype != np.uint8: + if img.max() <= 1.0: + img = (img * 255.0).clip(0, 255).astype(np.uint8) + else: + img = img.clip(0, 255).astype(np.uint8) + + if img.ndim == 2: + img = np.repeat(img[:, :, None], 3, axis=-1) + elif img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) + + return img + + +def _viz_traj(images, actions, intrinsics_key, color="Reds"): + images = _prepare_viz_image(images) + intrinsics = INTRINSICS[intrinsics_key] + left_xyz, _, right_xyz, _ = _split_action_pose(actions) + + vis = draw_actions( + images.copy(), + type="xyz", + color=color, + actions=left_xyz, + extrinsics=None, + intrinsics=intrinsics, + arm="left", + ) + vis = draw_actions( + vis, + type="xyz", + color=color, + actions=right_xyz, + extrinsics=None, + intrinsics=intrinsics, + arm="right", + ) + return vis + + +def _viz_axes(images, actions, intrinsics_key, axis_len_m=0.04): + images = _prepare_viz_image(images) + intrinsics = INTRINSICS[intrinsics_key] + left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions) + vis = images.copy() + + def _draw_axis_color_legend(frame): + _, w = frame.shape[:2] + x_right = w - 12 + y_start = 14 + y_step = 12 + line_len = 24 + axis_legend = [ + ("x", (255, 0, 0)), + ("y", (0, 255, 0)), + ("z", (0, 0, 255)), + ] + for i, (name, color) in enumerate(axis_legend): + y = y_start + i * y_step + x0 = x_right - line_len + x1 = x_right + cv2.line(frame, (x0, y), (x1, y), color, 3) + cv2.putText( + frame, + name, + (x0 - 12, y + 4), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + color, + 1, + cv2.LINE_AA, + ) + return frame + + def _draw_rotation_at_anchor(frame, xyz_seq, ypr_seq, label, anchor_color): + if len(xyz_seq) == 0 or len(ypr_seq) == 0: + return frame + + palm_xyz = xyz_seq[0] + palm_ypr = ypr_seq[0] + rot = R.from_euler("ZYX", palm_ypr, degrees=False).as_matrix() + + axis_points_cam = np.vstack( + [ + palm_xyz, + palm_xyz + rot[:, 0] * axis_len_m, + palm_xyz + rot[:, 1] * axis_len_m, + palm_xyz + rot[:, 2] * axis_len_m, + ] + ) + + px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2] + if not np.isfinite(px).all(): + return frame + pts = np.round(px).astype(np.int32) + + h, w = frame.shape[:2] + x0, y0 = pts[0] + if not (0 <= x0 < w and 0 <= y0 < h): + return frame + + cv2.circle(frame, (x0, y0), 4, anchor_color, -1) + axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] + for i, color in enumerate(axis_colors, start=1): + x1, y1 = pts[i] + if 0 <= x1 < w and 0 <= y1 < h: + cv2.line(frame, (x0, y0), (x1, y1), color, 2) + cv2.circle(frame, (x1, y1), 2, color, -1) + + cv2.putText( + frame, + label, + (x0 + 6, max(12, y0 - 8)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + anchor_color, + 1, + cv2.LINE_AA, + ) + return frame + + vis = _draw_rotation_at_anchor(vis, left_xyz, left_ypr, "L rot", (255, 180, 80)) + vis = _draw_rotation_at_anchor(vis, right_xyz, right_ypr, "R rot", (80, 180, 255)) + vis = _draw_axis_color_legend(vis) + return vis + + +def _viz_keypoints( + images, + actions, + intrinsics_key, + edges, + colors, + edge_ranges, + dot_color=None, + **kwargs, +): + """Visualize all 21 MANO keypoints per hand, projected onto the image.""" + # Prepare image + images = _prepare_viz_image(images) + + intrinsics = INTRINSICS[intrinsics_key] + + vis = images.copy() + h, w = vis.shape[:2] + + left_keypoints, right_keypoints = _split_keypoints(actions, wrist_in_data=False) + keypoints = {} + keypoints["left"] = left_keypoints.reshape(-1, 3) + keypoints["right"] = right_keypoints.reshape(-1, 3) + _default_dot_colors = {"left": (0, 120, 255), "right": (255, 80, 0)} + for hand in ("left", "right"): + hand_dot_color = ( + dot_color if dot_color is not None else _default_dot_colors[hand] + ) + kps_cam = keypoints[hand] + # Camera frame -> pixels + kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (42, 3+) 21 per arm + + # Identify valid keypoints (z > 0 and in image bounds) + valid = kps_cam[:, 2] > 0.01 + valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w) + valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h) + + # Draw skeleton edges (colored by finger) + for finger, start, end in edge_ranges: + color = colors[finger] + for edge_idx in range(start, end): + i, j = edges[edge_idx] + if valid[i] and valid[j]: + p1 = (int(kps_px[i, 0]), int(kps_px[i, 1])) + p2 = (int(kps_px[j, 0]), int(kps_px[j, 1])) + cv2.line(vis, p1, p2, color, 2) + + # Draw keypoint dots on top + for k in range(21): + if valid[k]: + center = (int(kps_px[k, 0]), int(kps_px[k, 1])) + cv2.circle(vis, center, 4, hand_dot_color, -1) + cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border + + # Label wrist + if valid[0]: + wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6) + cv2.putText( + vis, + f"{hand[0].upper()}", + wrist_px, + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + hand_dot_color, + 2, + ) + + return vis