diff --git a/egomimic/rldb/embodiment/embodiment.py b/egomimic/rldb/embodiment/embodiment.py index 24fab039..798760d8 100644 --- a/egomimic/rldb/embodiment/embodiment.py +++ b/egomimic/rldb/embodiment/embodiment.py @@ -1,7 +1,12 @@ +import copy from abc import ABC from enum import Enum +import numpy as np +import torch + from egomimic.rldb.zarr.action_chunk_transforms import Transform +from egomimic.utils.type_utils import _to_numpy class EMBODIMENT(Enum): @@ -53,3 +58,85 @@ def viz_transformed_batch(batch): def get_keymap(): """Returns a dictionary mapping from the raw keys in the dataset to the canonical keys used by the model.""" raise NotImplementedError + + @classmethod + def viz_gt_preds( + cls, + predictions, + batch, + image_key, + action_key, + transform_list=None, + mode="cartesian", + **kwargs, + ): + embodiment_id = batch["embodiment"][0].item() + embodiment_name = get_embodiment(embodiment_id).lower() + + pred_actions = predictions[ + f"{embodiment_name}_{action_key}" + ] # TODO: make this work with groundtruth, clone batch and replace actions_keypoints with pred_actions + if transform_list is not None: + pred_batch = copy.deepcopy(batch) + pred_batch[action_key] = pred_actions + batch = cls.apply_transform(batch, transform_list) + pred_batch = cls.apply_transform(pred_batch, transform_list) + pred_actions = pred_batch[action_key] + + images = batch[image_key] + actions = batch[action_key] + ims_list = [] + images = _to_numpy(images) + actions = _to_numpy(actions) + pred_actions = _to_numpy(pred_actions) + for i in range(images.shape[0]): + image = images[i] + action = actions[i] + pred_action = pred_actions[i] + ims = cls.viz(image, action, mode=mode, color="Reds", **kwargs) + ims = cls.viz(ims, pred_action, mode=mode, color="Greens", **kwargs) + ims_list.append(ims) + ims = np.stack(ims_list, axis=0) + return ims + + @classmethod + def apply_transform(cls, batch, transform_list: list[Transform]): + if transform_list: + batch_size = None + for v in batch.values(): + if isinstance(v, (np.ndarray, torch.Tensor)): + batch_size = v.shape[0] + break + + if batch_size is not None: + # Apply transforms per-sample (matching how ZarrDataset applies them) + results = [] + for i in range(batch_size): + sample = { + k: (v[i].cpu().numpy() if isinstance(v, torch.Tensor) else v[i]) + if isinstance(v, (np.ndarray, torch.Tensor)) + else v + for k, v in batch.items() + } + for transform in transform_list: + sample = transform.transform(sample) + results.append(sample) + + batch = {} + for k in results[0]: + vals = [r[k] for r in results] + if isinstance(vals[0], np.ndarray): + batch[k] = np.stack(vals, axis=0) + elif isinstance(vals[0], torch.Tensor): + batch[k] = torch.stack(vals, dim=0) + else: + batch[k] = vals + else: + for transform in transform_list: + batch = transform.transform(batch) + + for k, v in batch.items(): + if isinstance(v, np.ndarray): + batch[k] = torch.from_numpy(v).to(torch.float32) + + return batch diff --git a/egomimic/rldb/embodiment/eva.py b/egomimic/rldb/embodiment/eva.py index d77e7407..49676830 100644 --- a/egomimic/rldb/embodiment/eva.py +++ b/egomimic/rldb/embodiment/eva.py @@ -1,71 +1,106 @@ from __future__ import annotations -import cv2 -import numpy as np -from scipy.spatial.transform import Rotation as R +from typing import Literal from egomimic.rldb.embodiment.embodiment import Embodiment from egomimic.rldb.zarr.action_chunk_transforms import ( ActionChunkCoordinateFrameTransform, + BatchQuaternionPoseToYPR, ConcatKeys, DeleteKeys, InterpolateLinear, InterpolatePose, - PoseCoordinateFrameTransform, NumpyToTensor, + PoseCoordinateFrameTransform, + QuaternionPoseToYPR, + SplitKeys, Transform, XYZWXYZ_to_XYZYPR, ) from egomimic.utils.egomimicUtils import ( EXTRINSICS, - INTRINSICS, - cam_frame_to_cam_pixels, - draw_actions, ) from egomimic.utils.pose_utils import ( _matrix_to_xyzwxyz, ) +from egomimic.utils.type_utils import _to_numpy +from egomimic.utils.viz_utils import ( + _viz_axes, + _viz_traj, +) class Eva(Embodiment): + VIZ_INTRINSICS_KEY = "base" VIZ_IMAGE_KEY = "observations.images.front_img_1" @staticmethod - def get_transform_list() -> list[Transform]: - return _build_eva_bimanual_transform_list() + def get_transform_list( + mode: Literal[ + "cartesian", "cartesian_wristframe_ypr", "cartesian_wristframe_quat" + ] = "cartesian", + ) -> list[Transform]: + if mode == "cartesian": + return _build_eva_bimanual_transform_list() + elif mode == "cartesian_wristframe_ypr": + return _build_eva_bimanual_eef_frame_transform_list(is_quat=False) + elif mode == "cartesian_wristframe_quat": + return _build_eva_bimanual_eef_frame_transform_list(is_quat=True) @classmethod - def viz_transformed_batch(cls, batch, mode=""): - """ - Visualize one transformed EVA batch sample. - - Modes: - - palm_traj: draw left/right palm trajectories from actions_cartesian. - - palm_axes: draw local xyz axes at each palm anchor using ypr. - """ - image_key = cls.VIZ_IMAGE_KEY - action_key = "actions_cartesian" - intrinsics_key = "base" - mode = (mode or "palm_traj").lower() + def viz_transformed_batch( + cls, + batch, + mode=Literal["traj", "axes"], + action_key="actions_cartesian", + image_key=None, + transform_list=None, + **kwargs, + ): + if transform_list is not None: + batch = cls.apply_transform(batch, transform_list) + image_key = image_key or cls.VIZ_IMAGE_KEY + action_key = action_key or "actions_cartesian" + intrinsics_key = cls.VIZ_INTRINSICS_KEY + mode = (mode or "traj").lower() + + images = _to_numpy(batch[image_key][0]) + actions = _to_numpy(batch[action_key][0]) + + return cls.viz( + images=images, + actions=actions, + mode=mode, + intrinsics_key=intrinsics_key, + **kwargs, + ) - if mode == "palm_traj": - return _viz_batch_palm_traj( - batch=batch, - image_key=image_key, - action_key=action_key, + @classmethod + def viz( + cls, + images, + actions, + mode=Literal["traj", "axes"], + intrinsics_key=None, + **kwargs, + ): + intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY + if mode == "traj": + return _viz_traj( + images=images, + actions=actions, intrinsics_key=intrinsics_key, + **kwargs, ) - if mode == "palm_axes": - return _viz_batch_palm_axes( - batch=batch, - image_key=image_key, - action_key=action_key, + if mode == "axes": + return _viz_axes( + images=images, + actions=actions, intrinsics_key=intrinsics_key, + **kwargs, ) - raise ValueError( - f"Unsupported mode '{mode}'. Expected one of: " - f"('palm_traj', 'palm_axes', 'keypoints')." + f"Unsupported mode '{mode}'. Expected one of: " f"('traj', 'axes')." ) @classmethod @@ -122,6 +157,246 @@ def get_keymap(cls): } +def _build_eva_bimanual_revert_eef_frame_transform_list( + *, + action_key: str = "actions_cartesian", + obs_key: str = "observations.state.ee_pose", + left_cmd_wristframe: str = "left.cmd_ee_pose_wristframe", + right_cmd_wristframe: str = "right.cmd_ee_pose_wristframe", + left_gripper: str = "left.gripper", + right_gripper: str = "right.gripper", + left_obs_camframe: str = "left.obs_ee_pose_camframe", + right_obs_camframe: str = "right.obs_ee_pose_camframe", + left_obs_gripper: str = "left.obs_gripper", + right_obs_gripper: str = "right.obs_gripper", + left_cmd_camframe: str = "left.cmd_ee_pose_camframe", + right_cmd_camframe: str = "right.cmd_ee_pose_camframe", + is_quat: bool = True, +) -> list[Transform]: + """Revert wrist-frame EVA actions back to camera frame for visualization.""" + if is_quat: + pose_shape = 7 + else: + pose_shape = 6 + transform_list = [ + # Extract obs camframe poses from the concatenated obs key + SplitKeys( + input_key=obs_key, + output_key_list=[ + (left_obs_camframe, pose_shape), + (left_obs_gripper, 1), + (right_obs_camframe, pose_shape), + (right_obs_gripper, 1), + ], + ), + # Split wrist-frame actions into per-arm chunks + SplitKeys( + input_key=action_key, + output_key_list=[ + (left_cmd_wristframe, pose_shape), + (left_gripper, 1), + (right_cmd_wristframe, pose_shape), + (right_gripper, 1), + ], + ), + # Revert wrist frame → camera frame (inverse=False: target_se3 @ chunk_se3) + ActionChunkCoordinateFrameTransform( + target_world=left_obs_camframe, + chunk_world=left_cmd_wristframe, + transformed_key_name=left_cmd_camframe, + mode="xyzypr", + inverse=False, + ), + ActionChunkCoordinateFrameTransform( + target_world=right_obs_camframe, + chunk_world=right_cmd_wristframe, + transformed_key_name=right_cmd_camframe, + mode="xyzypr", + inverse=False, + ), + ConcatKeys( + key_list=[ + left_cmd_camframe, + left_gripper, + right_cmd_camframe, + right_gripper, + ], + new_key_name=action_key, + delete_old_keys=True, + ), + ] + return transform_list + + +def _build_eva_bimanual_eef_frame_transform_list( + *, + left_target_world: str = "left_extrinsics_pose", + right_target_world: str = "right_extrinsics_pose", + left_cmd_world: str = "left.cmd_ee_pose", + right_cmd_world: str = "right.cmd_ee_pose", + left_obs_pose: str = "left.obs_ee_pose", + right_obs_pose: str = "right.obs_ee_pose", + left_obs_gripper: str = "left.obs_gripper", + right_obs_gripper: str = "right.obs_gripper", + left_gripper: str = "left.gripper", + right_gripper: str = "right.gripper", + left_cmd_camframe: str = "left.cmd_ee_pose_camframe", + right_cmd_camframe: str = "right.cmd_ee_pose_camframe", + left_obs_camframe: str = "left.obs_ee_pose_camframe", + right_obs_camframe: str = "right.obs_ee_pose_camframe", + left_cmd_wristframe: str = "left.cmd_ee_pose_wristframe", + right_cmd_wristframe: str = "right.cmd_ee_pose_wristframe", + actions_key: str = "actions_cartesian", + obs_key: str = "observations.state.ee_pose", + chunk_length: int = 100, + stride: int = 1, + extrinsics_key: str = "x5Dec13_2", + is_quat: bool = True, +) -> list[Transform]: + """EVA bimanual transform pipeline with actions expressed relative to the + current EEF pose (wrist frame), analogous to keypoints relative to wrist pose.""" + extrinsics = EXTRINSICS[extrinsics_key] + left_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics["left"][None, :])[0] + right_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics["right"][None, :])[0] + left_extra_batch_key = {"left_extrinsics_pose": left_extrinsics_pose} + right_extra_batch_key = {"right_extrinsics_pose": right_extrinsics_pose} + + # Step 1: transform cmd and obs into camera frame using extrinsics + transform_list = [ + ActionChunkCoordinateFrameTransform( + target_world=left_target_world, + chunk_world=left_cmd_world, + transformed_key_name=left_cmd_camframe, + extra_batch_key=left_extra_batch_key, + mode="xyzwxyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=right_target_world, + chunk_world=right_cmd_world, + transformed_key_name=right_cmd_camframe, + extra_batch_key=right_extra_batch_key, + mode="xyzwxyz", + ), + PoseCoordinateFrameTransform( + target_world=left_target_world, + pose_world=left_obs_pose, + transformed_key_name=left_obs_camframe, + mode="xyzwxyz", + ), + PoseCoordinateFrameTransform( + target_world=right_target_world, + pose_world=right_obs_pose, + transformed_key_name=right_obs_camframe, + mode="xyzwxyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=left_cmd_camframe, + output_action_key=left_cmd_camframe, + stride=stride, + mode="xyzwxyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=right_cmd_camframe, + output_action_key=right_cmd_camframe, + stride=stride, + mode="xyzwxyz", + ), + InterpolateLinear( + new_chunk_length=chunk_length, + action_key=left_gripper, + output_action_key=left_gripper, + stride=stride, + ), + InterpolateLinear( + new_chunk_length=chunk_length, + action_key=right_gripper, + output_action_key=right_gripper, + stride=stride, + ), + # Step 2: transform camera-frame actions into EEF-relative (wrist) frame + ActionChunkCoordinateFrameTransform( + target_world=left_obs_camframe, + chunk_world=left_cmd_camframe, + transformed_key_name=left_cmd_wristframe, + mode="xyzwxyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=right_obs_camframe, + chunk_world=right_cmd_camframe, + transformed_key_name=right_cmd_wristframe, + mode="xyzwxyz", + ), + ] + + if not is_quat: + transform_list.extend( + [ + BatchQuaternionPoseToYPR( + pose_key=left_cmd_wristframe, + output_key=left_cmd_wristframe, + ), + BatchQuaternionPoseToYPR( + pose_key=right_cmd_wristframe, + output_key=right_cmd_wristframe, + ), + QuaternionPoseToYPR( + pose_key=left_obs_camframe, + output_key=left_obs_camframe, + ), + QuaternionPoseToYPR( + pose_key=right_obs_camframe, + output_key=right_obs_camframe, + ), + ] + ) + + transform_list.extend( + [ + ConcatKeys( + key_list=[ + left_cmd_wristframe, + left_gripper, + right_cmd_wristframe, + right_gripper, + ], + new_key_name=actions_key, + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_obs_camframe, + left_obs_gripper, + right_obs_camframe, + right_obs_gripper, + ], + new_key_name=obs_key, + delete_old_keys=True, + ), + DeleteKeys( + keys_to_delete=[ + left_cmd_world, + right_cmd_world, + left_obs_pose, + right_obs_pose, + left_cmd_camframe, + right_cmd_camframe, + left_target_world, + right_target_world, + ] + ), + NumpyToTensor( + keys=[ + actions_key, + obs_key, + ] + ), + ] + ) + return transform_list + + def _build_eva_bimanual_transform_list( *, left_target_world: str = "left_extrinsics_pose", @@ -149,46 +424,48 @@ def _build_eva_bimanual_transform_list( right_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics["right"][None, :])[0] left_extra_batch_key = {"left_extrinsics_pose": left_extrinsics_pose} right_extra_batch_key = {"right_extrinsics_pose": right_extrinsics_pose} + + mode = "xyzwxyz" if is_quat else "xyzypr" transform_list = [ ActionChunkCoordinateFrameTransform( target_world=left_target_world, chunk_world=left_cmd_world, transformed_key_name=left_cmd_camframe, extra_batch_key=left_extra_batch_key, - is_quat=is_quat, + mode=mode, ), ActionChunkCoordinateFrameTransform( target_world=right_target_world, chunk_world=right_cmd_world, transformed_key_name=right_cmd_camframe, extra_batch_key=right_extra_batch_key, - is_quat=is_quat, + mode=mode, ), PoseCoordinateFrameTransform( target_world=left_target_world, pose_world=left_obs_pose, transformed_key_name=left_obs_pose, - is_quat=is_quat, + mode=mode, ), PoseCoordinateFrameTransform( target_world=right_target_world, pose_world=right_obs_pose, transformed_key_name=right_obs_pose, - is_quat=is_quat, + mode=mode, ), InterpolatePose( new_chunk_length=chunk_length, action_key=left_cmd_camframe, output_action_key=left_cmd_camframe, stride=stride, - is_quat=is_quat, + mode=mode, ), InterpolatePose( new_chunk_length=chunk_length, action_key=right_cmd_camframe, output_action_key=right_cmd_camframe, stride=stride, - is_quat=is_quat, + mode=mode, ), InterpolateLinear( new_chunk_length=chunk_length, @@ -255,165 +532,3 @@ def _build_eva_bimanual_transform_list( ] ) return transform_list - - -def _to_numpy(arr): - if hasattr(arr, "detach"): - arr = arr.detach() - if hasattr(arr, "cpu"): - arr = arr.cpu() - if hasattr(arr, "numpy"): - return arr.numpy() - return np.asarray(arr) - - -def _split_action_pose(actions): - # 14D layout: [L xyz ypr g, R xyz ypr g] - # 12D layout: [L xyz ypr, R xyz ypr] - if actions.shape[-1] == 14: - left_xyz = actions[:, :3] - left_ypr = actions[:, 3:6] - right_xyz = actions[:, 7:10] - right_ypr = actions[:, 10:13] - elif actions.shape[-1] == 12: - left_xyz = actions[:, :3] - left_ypr = actions[:, 3:6] - right_xyz = actions[:, 6:9] - right_ypr = actions[:, 9:12] - else: - raise ValueError(f"Unsupported action dim {actions.shape[-1]}") - return left_xyz, left_ypr, right_xyz, right_ypr - - -def _prepare_viz_image(batch, image_key): - img = _to_numpy(batch[image_key][0]) - if img.ndim == 3 and img.shape[0] in (1, 3): - img = np.transpose(img, (1, 2, 0)) - - if img.dtype != np.uint8: - if img.max() <= 1.0: - img = (img * 255.0).clip(0, 255).astype(np.uint8) - else: - img = img.clip(0, 255).astype(np.uint8) - - if img.ndim == 2: - img = np.repeat(img[:, :, None], 3, axis=-1) - elif img.shape[-1] == 1: - img = np.repeat(img, 3, axis=-1) - - return img - - -def _viz_batch_palm_traj(batch, image_key, action_key, intrinsics_key): - img_np = _prepare_viz_image(batch, image_key) - intrinsics = INTRINSICS[intrinsics_key] - actions = _to_numpy(batch[action_key][0]) - left_xyz, _, right_xyz, _ = _split_action_pose(actions) - - vis = draw_actions( - img_np.copy(), - type="xyz", - color="Blues", - actions=left_xyz, - extrinsics=None, - intrinsics=intrinsics, - arm="left", - ) - vis = draw_actions( - vis, - type="xyz", - color="Reds", - actions=right_xyz, - extrinsics=None, - intrinsics=intrinsics, - arm="right", - ) - return vis - - -def _viz_batch_palm_axes(batch, image_key, action_key, intrinsics_key, axis_len_m=0.04): - img_np = _prepare_viz_image(batch, image_key) - intrinsics = INTRINSICS[intrinsics_key] - actions = _to_numpy(batch[action_key][0]) - left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions) - vis = img_np.copy() - - def _draw_axis_color_legend(frame): - _, w = frame.shape[:2] - x_right = w - 12 - y_start = 14 - y_step = 12 - line_len = 24 - axis_legend = [ - ("x", (255, 0, 0)), - ("y", (0, 255, 0)), - ("z", (0, 0, 255)), - ] - for i, (name, color) in enumerate(axis_legend): - y = y_start + i * y_step - x0 = x_right - line_len - x1 = x_right - cv2.line(frame, (x0, y), (x1, y), color, 3) - cv2.putText( - frame, - name, - (x0 - 12, y + 4), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - color, - 1, - cv2.LINE_AA, - ) - return frame - - def _draw_rotation_at_palm(frame, xyz_seq, ypr_seq, label, anchor_color): - if len(xyz_seq) == 0 or len(ypr_seq) == 0: - return frame - - palm_xyz = xyz_seq[0] - palm_ypr = ypr_seq[0] - rot = R.from_euler("ZYX", palm_ypr, degrees=False).as_matrix() - - axis_points_cam = np.vstack( - [ - palm_xyz, - palm_xyz + rot[:, 0] * axis_len_m, - palm_xyz + rot[:, 1] * axis_len_m, - palm_xyz + rot[:, 2] * axis_len_m, - ] - ) - - px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2] - if not np.isfinite(px).all(): - return frame - pts = np.round(px).astype(np.int32) - - h, w = frame.shape[:2] - x0, y0 = pts[0] - if not (0 <= x0 < w and 0 <= y0 < h): - return frame - - cv2.circle(frame, (x0, y0), 4, anchor_color, -1) - axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] - for i, color in enumerate(axis_colors, start=1): - x1, y1 = pts[i] - if 0 <= x1 < w and 0 <= y1 < h: - cv2.line(frame, (x0, y0), (x1, y1), color, 2) - cv2.circle(frame, (x1, y1), 2, color, -1) - - cv2.putText( - frame, - label, - (x0 + 6, max(12, y0 - 8)), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - anchor_color, - 1, - cv2.LINE_AA, - ) - return frame - - vis = _draw_rotation_at_palm(vis, left_xyz, left_ypr, "L rot", (255, 180, 80)) - vis = _draw_rotation_at_palm(vis, right_xyz, right_ypr, "R rot", (80, 180, 255)) - vis = _draw_axis_color_legend(vis) - return vis diff --git a/egomimic/rldb/embodiment/human.py b/egomimic/rldb/embodiment/human.py index 1939fd91..981ff98c 100644 --- a/egomimic/rldb/embodiment/human.py +++ b/egomimic/rldb/embodiment/human.py @@ -1,19 +1,28 @@ from __future__ import annotations +from typing import Literal + from egomimic.rldb.embodiment.embodiment import Embodiment -from egomimic.rldb.embodiment.eva import ( - _viz_batch_palm_axes, - _viz_batch_palm_traj, -) from egomimic.rldb.zarr.action_chunk_transforms import ( ActionChunkCoordinateFrameTransform, + BatchQuaternionPoseToYPR, ConcatKeys, DeleteKeys, InterpolatePose, PoseCoordinateFrameTransform, + QuaternionPoseToYPR, + Reshape, + SplitKeys, Transform, XYZWXYZ_to_XYZYPR, ) +from egomimic.utils.type_utils import _to_numpy +from egomimic.utils.viz_utils import ( + ColorPalette, + _viz_axes, + _viz_keypoints, + _viz_traj, +) class Human(Embodiment): @@ -22,75 +31,242 @@ class Human(Embodiment): ACTION_STRIDE = 3 @classmethod - def get_transform_list(cls) -> list[Transform]: - return _build_aria_bimanual_transform_list(stride=cls.ACTION_STRIDE) + def viz_transformed_batch( + cls, + batch, + mode=Literal["traj", "axes", "keypoints"], + action_key="actions_cartesian", + image_key=None, + transform_list=None, + **kwargs, + ): + if transform_list is not None: + batch = cls.apply_transform(batch, transform_list) - @classmethod - def viz_transformed_batch(cls, batch, mode=""): - image_key = cls.VIZ_IMAGE_KEY - action_key = "actions_cartesian" + image_key = image_key or cls.VIZ_IMAGE_KEY + action_key = action_key or "actions_cartesian" intrinsics_key = cls.VIZ_INTRINSICS_KEY - mode = (mode or "palm_traj").lower() + mode = (mode or "traj").lower() + images = _to_numpy(batch[image_key][0]) + actions = _to_numpy(batch[action_key][0]) + + return cls.viz( + images=images, + actions=actions, + mode=mode, + intrinsics_key=intrinsics_key, + **kwargs, + ) - if mode == "palm_traj": - return _viz_batch_palm_traj( - batch=batch, - image_key=image_key, - action_key=action_key, + @classmethod + def viz( + cls, + images, + actions, + mode=Literal["traj", "axes", "keypoints"], + intrinsics_key=None, + **kwargs, + ): + intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY + if mode == "traj": + return _viz_traj( + images=images, + actions=actions, intrinsics_key=intrinsics_key, + **kwargs, ) - if mode == "palm_axes": - return _viz_batch_palm_axes( - batch=batch, - image_key=image_key, - action_key=action_key, + if mode == "axes": + return _viz_axes( + images=images, + actions=actions, intrinsics_key=intrinsics_key, + **kwargs, ) if mode == "keypoints": - raise NotImplementedError( - "mode='keypoints' is reserved and not implemented yet." + color = kwargs.get("color", None) + if color is not None and ColorPalette.is_valid(color): + n = len(cls.FINGER_COLORS) + colors = { + finger: ColorPalette.to_rgb(color, value=(i + 1) / (n + 1)) + for i, finger in enumerate(cls.FINGER_COLORS) + } + dot_color = ColorPalette.to_rgb(color, value=0.7) + else: + colors = cls.FINGER_COLORS + dot_color = cls.DOT_COLOR + return _viz_keypoints( + images=images, + actions=actions, + intrinsics_key=intrinsics_key, + edges=cls.FINGER_EDGES, + edge_ranges=cls.FINGER_EDGE_RANGES, + colors=colors, + dot_color=dot_color, + **kwargs, ) - raise ValueError( f"Unsupported mode '{mode}'. Expected one of: " - f"('palm_traj', 'palm_axes', 'keypoints')." + f"('traj', 'axes', 'keypoints')." ) @classmethod - def get_keymap(cls): - return { - cls.VIZ_IMAGE_KEY: { - "key_type": "camera_keys", - "zarr_key": "images.front_1", - }, - "right.action_ee_pose": { - "key_type": "action_keys", - "zarr_key": "right.obs_ee_pose", - "horizon": 30, - }, - "left.action_ee_pose": { - "key_type": "action_keys", - "zarr_key": "left.obs_ee_pose", - "horizon": 30, - }, - "right.obs_ee_pose": { - "key_type": "proprio_keys", - "zarr_key": "right.obs_ee_pose", - }, - "left.obs_ee_pose": { - "key_type": "proprio_keys", - "zarr_key": "left.obs_ee_pose", - }, - "obs_head_pose": { - "key_type": "proprio_keys", - "zarr_key": "obs_head_pose", - }, - } + def get_keymap(cls, mode: Literal["cartesian", "keypoints"]): + if mode == "cartesian": + key_map = { + cls.VIZ_IMAGE_KEY: { + "key_type": "camera_keys", + "zarr_key": "images.front_1", + }, + "right.action_ee_pose": { + "key_type": "action_keys", + "zarr_key": "right.obs_ee_pose", + "horizon": 30, + }, + "left.action_ee_pose": { + "key_type": "action_keys", + "zarr_key": "left.obs_ee_pose", + "horizon": 30, + }, + "right.obs_ee_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_ee_pose", + }, + "left.obs_ee_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_ee_pose", + }, + "obs_head_pose": { + "key_type": "proprio_keys", + "zarr_key": "obs_head_pose", + }, + } + elif mode == "keypoints": + key_map = { + cls.VIZ_IMAGE_KEY: { + "key_type": "camera_keys", + "zarr_key": "images.front_1", + }, + "left.action_keypoints": { + "key_type": "action_keys", + "zarr_key": "left.obs_keypoints", + "horizon": 30, + }, + "right.action_keypoints": { + "key_type": "action_keys", + "zarr_key": "right.obs_keypoints", + "horizon": 30, + }, + "left.action_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_wrist_pose", + "horizon": 30, + }, + "right.action_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_wrist_pose", + "horizon": 30, + }, + "left.obs_keypoints": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_keypoints", + }, + "right.obs_keypoints": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_keypoints", + }, + "left.obs_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_wrist_pose", + }, + "right.obs_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_wrist_pose", + }, + "obs_head_pose": { + "key_type": "proprio_keys", + "zarr_key": "obs_head_pose", + }, + } + else: + raise ValueError( + f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'keypoints'." + ) + return key_map class Aria(Human): VIZ_INTRINSICS_KEY = "base" ACTION_STRIDE = 3 + FINGER_EDGES = [ + ( + 5, + 6, + ), + (6, 7), + (7, 0), # thumb + (5, 8), + (8, 9), + (9, 10), + (9, 1), # index + (5, 11), + (11, 12), + (12, 13), + (13, 2), # middle + (5, 14), + (14, 15), + (15, 16), + (16, 3), # ring + (5, 17), + (17, 18), + (18, 19), + (19, 4), # pinky + ] + FINGER_COLORS = { + "thumb": (255, 100, 100), # red + "index": (100, 255, 100), # green + "middle": (100, 100, 255), # blue + "ring": (255, 255, 100), # yellow + "pinky": (255, 100, 255), # magenta + } + FINGER_EDGE_RANGES = [ + ("thumb", 0, 3), + ("index", 3, 7), + ("middle", 7, 11), + ("ring", 11, 15), + ("pinky", 15, 19), + ] + DOT_COLOR = (255, 165, 0) + + @classmethod + def get_transform_list( + cls, + mode: Literal[ + "cartesian", + "keypoints_headframe", + "keypoints_wristframe_ypr", + "keypoints_wristframe_quat", + ], + ) -> list[Transform]: + if mode == "cartesian": + return _build_aria_cartesian_bimanual_transform_list( + stride=cls.ACTION_STRIDE + ) + elif mode == "keypoints_headframe": + return _build_aria_keypoints_bimanual_transform_list( + stride=cls.ACTION_STRIDE + ) + elif mode == "keypoints_wristframe_ypr": + return _build_aria_keypoints_eef_frame_transform_list( + stride=cls.ACTION_STRIDE, is_quat=False + ) + elif mode == "keypoints_wristframe_quat": + return _build_aria_keypoints_eef_frame_transform_list( + stride=cls.ACTION_STRIDE, is_quat=True + ) + else: + raise ValueError( + f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'keypoints', 'keypoints_wristframe'." + ) class Scale(Human): @@ -103,7 +279,476 @@ class Mecka(Human): ACTION_STRIDE = 1 -def _build_aria_bimanual_transform_list( +# this works for quat and ypr since actionChunkCoordinateFrameTransform works for both +def _build_aria_keypoints_revert_eef_frame_transform_list( + *, + action_key: str = "actions_keypoints", + left_keypoints_action_wristframe: str = "left.action_keypoints_wristframe", + right_keypoints_action_wristframe: str = "right.action_keypoints_wristframe", + left_wrist_obs_headframe: str = "left.obs_wrist_pose_headframe", + right_wrist_obs_headframe: str = "right.obs_wrist_pose_headframe", + left_wrist_action_headframe: str = "left.action_wrist_pose_headframe", + right_wrist_action_headframe: str = "right.action_wrist_pose_headframe", + left_wrist_action_wristframe: str = "left.action_wrist_pose_wristframe", + right_wrist_action_wristframe: str = "right.action_wrist_pose_wristframe", + left_keypoints_action_headframe: str = "left.action_keypoints_headframe", + right_keypoints_action_headframe: str = "right.action_keypoints_headframe", + is_quat: bool = True, +) -> list[Transform]: + if is_quat: + pose_shape = 7 + else: + pose_shape = 6 + transform_list = [ + SplitKeys( + input_key=action_key, + output_key_list=[ + (left_wrist_action_wristframe, pose_shape), + (left_keypoints_action_wristframe, 63), + (right_wrist_action_wristframe, pose_shape), + (right_keypoints_action_wristframe, 63), + ], + ), + Reshape( + input_key=left_keypoints_action_wristframe, + output_key=left_keypoints_action_wristframe, + shape=(100, 21, 3), + ), + Reshape( + input_key=right_keypoints_action_wristframe, + output_key=right_keypoints_action_wristframe, + shape=(100, 21, 3), + ), + ActionChunkCoordinateFrameTransform( + target_world=left_wrist_obs_headframe, + chunk_world=left_keypoints_action_wristframe, + transformed_key_name=left_keypoints_action_headframe, + mode="xyz", + inverse=False, + ), + ActionChunkCoordinateFrameTransform( + target_world=right_wrist_obs_headframe, + chunk_world=right_keypoints_action_wristframe, + transformed_key_name=right_keypoints_action_headframe, + mode="xyz", + inverse=False, + ), + Reshape( + input_key=left_keypoints_action_headframe, + output_key=left_keypoints_action_headframe, + shape=(100, 63), + ), + Reshape( + input_key=right_keypoints_action_headframe, + output_key=right_keypoints_action_headframe, + shape=(100, 63), + ), + ConcatKeys( + key_list=[ + left_keypoints_action_headframe, + right_keypoints_action_headframe, + ], + new_key_name=action_key, + delete_old_keys=True, + ), + ] + return transform_list + + +def _build_aria_keypoints_eef_frame_transform_list( + *, + target_world: str = "obs_head_pose", + target_world_ypr: str = "obs_head_pose_ypr", + target_world_is_quat: bool = True, + left_keypoints_action_world: str = "left.action_keypoints", + right_keypoints_action_world: str = "right.action_keypoints", + left_keypoints_obs_pose: str = "left.obs_keypoints", + right_keypoints_obs_pose: str = "right.obs_keypoints", + left_keypoints_action_headframe: str = "left.action_keypoints_headframe", + right_keypoints_action_headframe: str = "right.action_keypoints_headframe", + left_keypoints_obs_headframe: str = "left.obs_keypoints_headframe", + right_keypoints_obs_headframe: str = "right.obs_keypoints_headframe", + left_wrist_action_world: str = "left.action_wrist_pose", + right_wrist_action_world: str = "right.action_wrist_pose", + left_keypoints_action_wristframe: str = "left.action_keypoints_wristframe", + right_keypoints_action_wristframe: str = "right.action_keypoints_wristframe", + left_wrist_action_wristframe: str = "left.action_wrist_pose_wristframe", + right_wrist_action_wristframe: str = "right.action_wrist_pose_wristframe", + left_wrist_obs_pose: str = "left.obs_wrist_pose", + right_wrist_obs_pose: str = "right.obs_wrist_pose", + left_wrist_action_headframe: str = "left.action_wrist_pose_headframe", + right_wrist_action_headframe: str = "right.action_wrist_pose_headframe", + left_wrist_obs_headframe: str = "left.obs_wrist_pose_headframe", + right_wrist_obs_headframe: str = "right.obs_wrist_pose_headframe", + left_keypoints_obs_wristframe: str = "left.obs_keypoints_wristframe", + right_keypoints_obs_wristframe: str = "right.obs_keypoints_wristframe", + delete_target_world: bool = True, + chunk_length: int = 100, + stride: int = 3, + is_quat: bool = True, +) -> list[Transform]: + transform_list = _build_aria_keypoints_bimanual_transform_list( + target_world=target_world, + target_world_ypr=target_world_ypr, + target_world_is_quat=target_world_is_quat, + delete_target_world=delete_target_world, + chunk_length=chunk_length, + stride=stride, + concat_keys=False, + ) + delete_keys = [ + left_keypoints_action_world, + right_keypoints_action_world, + left_keypoints_obs_pose, + right_keypoints_obs_pose, + left_wrist_action_world, + right_wrist_action_world, + left_wrist_obs_pose, + right_wrist_obs_pose, + left_keypoints_action_headframe, + right_keypoints_action_headframe, + left_keypoints_obs_headframe, + right_keypoints_obs_headframe, + left_wrist_action_headframe, + right_wrist_action_headframe, + ] + if delete_target_world: + delete_keys.append(target_world) + if target_world_is_quat: + delete_keys.append(target_world_ypr) + transform_list.extend( + [ + Reshape( + input_key=left_keypoints_action_headframe, + output_key=left_keypoints_action_headframe, + shape=(chunk_length, 21, 3), + ), + Reshape( + input_key=right_keypoints_action_headframe, + output_key=right_keypoints_action_headframe, + shape=(chunk_length, 21, 3), + ), + ActionChunkCoordinateFrameTransform( + target_world=left_wrist_obs_headframe, + chunk_world=left_keypoints_action_headframe, + transformed_key_name=left_keypoints_action_wristframe, + mode="xyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=right_wrist_obs_headframe, + chunk_world=right_keypoints_action_headframe, + transformed_key_name=right_keypoints_action_wristframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_action_wristframe, + output_key=left_keypoints_action_wristframe, + shape=(chunk_length, 63), + ), + Reshape( + input_key=right_keypoints_action_wristframe, + output_key=right_keypoints_action_wristframe, + shape=(chunk_length, 63), + ), + Reshape( + input_key=left_keypoints_obs_headframe, + output_key=left_keypoints_obs_headframe, + shape=(21, 3), + ), + Reshape( + input_key=right_keypoints_obs_headframe, + output_key=right_keypoints_obs_headframe, + shape=(21, 3), + ), + PoseCoordinateFrameTransform( + target_world=left_wrist_obs_headframe, + pose_world=left_keypoints_obs_headframe, + transformed_key_name=left_keypoints_obs_wristframe, + mode="xyz", + ), + PoseCoordinateFrameTransform( + target_world=right_wrist_obs_headframe, + pose_world=right_keypoints_obs_headframe, + transformed_key_name=right_keypoints_obs_wristframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_obs_wristframe, + output_key=left_keypoints_obs_wristframe, + shape=(63,), + ), + Reshape( + input_key=right_keypoints_obs_wristframe, + output_key=right_keypoints_obs_wristframe, + shape=(63,), + ), + ActionChunkCoordinateFrameTransform( + target_world=left_wrist_obs_headframe, + chunk_world=left_wrist_action_headframe, + transformed_key_name=left_wrist_action_wristframe, + mode="xyzwxyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=right_wrist_obs_headframe, + chunk_world=right_wrist_action_headframe, + transformed_key_name=right_wrist_action_wristframe, + mode="xyzwxyz", + ), + ] + ) + if not is_quat: + transform_list.extend( + [ + BatchQuaternionPoseToYPR( + pose_key=left_wrist_action_wristframe, + output_key=left_wrist_action_wristframe, + ), + BatchQuaternionPoseToYPR( + pose_key=right_wrist_action_wristframe, + output_key=right_wrist_action_wristframe, + ), + QuaternionPoseToYPR( + pose_key=left_wrist_obs_headframe, + output_key=left_wrist_obs_headframe, + ), + QuaternionPoseToYPR( + pose_key=right_wrist_obs_headframe, + output_key=right_wrist_obs_headframe, + ), + ] + ) + transform_list.extend( + [ + ConcatKeys( + key_list=[ + left_wrist_action_wristframe, + left_keypoints_action_wristframe, + right_wrist_action_wristframe, + right_keypoints_action_wristframe, + ], + new_key_name="actions_keypoints", + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_keypoints_obs_wristframe, + right_keypoints_obs_wristframe, + ], + new_key_name="observations.state.keypoints", + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_wrist_obs_headframe, + right_wrist_obs_headframe, + ], + new_key_name="observations.state.wrist_pose", + delete_old_keys=False, + ), + DeleteKeys(keys_to_delete=delete_keys), + ] + ) + return transform_list + + +def _build_aria_keypoints_bimanual_transform_list( + *, + target_world: str = "obs_head_pose", + target_world_ypr: str = "obs_head_pose_ypr", + target_world_is_quat: bool = True, + left_keypoints_action_world: str = "left.action_keypoints", + right_keypoints_action_world: str = "right.action_keypoints", + left_keypoints_obs_pose: str = "left.obs_keypoints", + right_keypoints_obs_pose: str = "right.obs_keypoints", + left_keypoints_action_headframe: str = "left.action_keypoints_headframe", + right_keypoints_action_headframe: str = "right.action_keypoints_headframe", + left_keypoints_obs_headframe: str = "left.obs_keypoints_headframe", + right_keypoints_obs_headframe: str = "right.obs_keypoints_headframe", + left_wrist_action_world: str = "left.action_wrist_pose", + right_wrist_action_world: str = "right.action_wrist_pose", + left_wrist_obs_pose: str = "left.obs_wrist_pose", + right_wrist_obs_pose: str = "right.obs_wrist_pose", + left_wrist_action_headframe: str = "left.action_wrist_pose_headframe", + right_wrist_action_headframe: str = "right.action_wrist_pose_headframe", + left_wrist_obs_headframe: str = "left.obs_wrist_pose_headframe", + right_wrist_obs_headframe: str = "right.obs_wrist_pose_headframe", + delete_target_world: bool = True, + chunk_length: int = 100, + stride: int = 3, + concat_keys: bool = True, +) -> list[Transform]: + keys_to_delete = list( + { + left_keypoints_action_world, + right_keypoints_action_world, + left_keypoints_obs_pose, + right_keypoints_obs_pose, + left_wrist_action_world, + right_wrist_action_world, + left_wrist_obs_pose, + right_wrist_obs_pose, + left_keypoints_action_headframe, + right_keypoints_action_headframe, + left_keypoints_obs_headframe, + right_keypoints_obs_headframe, + left_wrist_action_headframe, + right_wrist_action_headframe, + left_wrist_obs_headframe, + right_wrist_obs_headframe, + } + ) + if delete_target_world: + keys_to_delete.append(target_world) + if target_world_is_quat: + keys_to_delete.append(target_world_ypr) + transform_list: list[Transform] = [ + Reshape( + input_key=left_keypoints_action_world, + output_key=left_keypoints_action_world, + shape=(30, 21, 3), + ), + Reshape( + input_key=right_keypoints_action_world, + output_key=right_keypoints_action_world, + shape=(30, 21, 3), + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=left_keypoints_action_world, + transformed_key_name=left_keypoints_action_headframe, + mode="xyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=right_keypoints_action_world, + transformed_key_name=right_keypoints_action_headframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_obs_pose, + output_key=left_keypoints_obs_pose, + shape=(21, 3), + ), + Reshape( + input_key=right_keypoints_obs_pose, + output_key=right_keypoints_obs_pose, + shape=(21, 3), + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=left_keypoints_obs_pose, + transformed_key_name=left_keypoints_obs_headframe, + mode="xyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=right_keypoints_obs_pose, + transformed_key_name=right_keypoints_obs_headframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_obs_headframe, + output_key=left_keypoints_obs_headframe, + shape=(63,), + ), + Reshape( + input_key=right_keypoints_obs_headframe, + output_key=right_keypoints_obs_headframe, + shape=(63,), + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=left_keypoints_action_headframe, + output_action_key=left_keypoints_action_headframe, + stride=stride, + mode="xyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=right_keypoints_action_headframe, + output_action_key=right_keypoints_action_headframe, + stride=stride, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_action_headframe, + output_key=left_keypoints_action_headframe, + shape=(chunk_length, 63), + ), + Reshape( + input_key=right_keypoints_action_headframe, + output_key=right_keypoints_action_headframe, + shape=(chunk_length, 63), + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=left_wrist_action_world, + transformed_key_name=left_wrist_action_headframe, + mode="xyzwxyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=right_wrist_action_world, + transformed_key_name=right_wrist_action_headframe, + mode="xyzwxyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=left_wrist_obs_pose, + transformed_key_name=left_wrist_obs_headframe, + mode="xyzwxyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=right_wrist_obs_pose, + transformed_key_name=right_wrist_obs_headframe, + mode="xyzwxyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=left_wrist_action_headframe, + output_action_key=left_wrist_action_headframe, + stride=stride, + mode="xyzwxyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=right_wrist_action_headframe, + output_action_key=right_wrist_action_headframe, + stride=stride, + mode="xyzwxyz", + ), + ] + if concat_keys: + transform_list.extend( + [ + ConcatKeys( + key_list=[ + left_wrist_action_headframe, + left_keypoints_action_headframe, + right_wrist_action_headframe, + right_keypoints_action_headframe, + ], + new_key_name="actions_keypoints", + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_wrist_obs_headframe, + left_keypoints_obs_headframe, + right_wrist_obs_headframe, + right_keypoints_obs_headframe, + ], + new_key_name="observations.state.keypoints", + delete_old_keys=True, + ), + DeleteKeys(keys_to_delete=keys_to_delete), + ] + ) + return transform_list + + +def _build_aria_cartesian_bimanual_transform_list( *, target_world: str = "obs_head_pose", target_world_ypr: str = "obs_head_pose_ypr", @@ -147,39 +792,39 @@ def _build_aria_bimanual_transform_list( target_world=target_pose_key, chunk_world=left_action_world, transformed_key_name=left_action_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), ActionChunkCoordinateFrameTransform( target_world=target_pose_key, chunk_world=right_action_world, transformed_key_name=right_action_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), PoseCoordinateFrameTransform( target_world=target_pose_key, pose_world=left_obs_pose, transformed_key_name=left_obs_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), PoseCoordinateFrameTransform( target_world=target_pose_key, pose_world=right_obs_pose, transformed_key_name=right_obs_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), InterpolatePose( new_chunk_length=chunk_length, action_key=left_action_headframe, output_action_key=left_action_headframe, stride=stride, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), InterpolatePose( new_chunk_length=chunk_length, action_key=right_action_headframe, output_action_key=right_action_headframe, stride=stride, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), ] diff --git a/egomimic/rldb/zarr/action_chunk_transforms.py b/egomimic/rldb/zarr/action_chunk_transforms.py index ea0ce6ad..379c3a6a 100644 --- a/egomimic/rldb/zarr/action_chunk_transforms.py +++ b/egomimic/rldb/zarr/action_chunk_transforms.py @@ -13,19 +13,26 @@ from __future__ import annotations from abc import abstractmethod +from typing import Literal import numpy as np +import torch from projectaria_tools.core.sophus import SE3 from scipy.spatial.transform import Rotation as R -import torch from egomimic.utils.pose_utils import ( _interpolate_euler, _interpolate_linear, _interpolate_quat_wxyz, + _interpolate_xyz, + _matrix_to_xyz, _matrix_to_xyzwxyz, _matrix_to_xyzypr, + _xyz_to_matrix, _xyzwxyz_to_matrix, + _xyzypr_to_matrix, + wxyz_to_xyzw, + xyzw_to_wxyz, ) # --------------------------------------------------------------------------- @@ -56,7 +63,7 @@ def __init__( action_key: str, output_action_key: str, stride: int = 1, - is_quat: bool = False, + mode: Literal["xyzwxyz", "xyzypr"] = "xyzwxyz", ): if stride <= 0: raise ValueError(f"stride must be positive, got {stride}") @@ -64,12 +71,12 @@ def __init__( self.action_key = action_key self.output_action_key = output_action_key self.stride = int(stride) - self.is_quat = is_quat + self.mode = mode def transform(self, batch: dict) -> dict: actions = np.asarray(batch[self.action_key]) actions = actions[:: self.stride] - if self.is_quat: + if self.mode == "xyzwxyz": if actions.ndim != 2 or actions.shape[-1] != 7: raise ValueError( f"InterpolatePose expects (T, 7) when is_quat=True, got " @@ -78,7 +85,7 @@ def transform(self, batch: dict) -> dict: batch[self.output_action_key] = _interpolate_quat_wxyz( actions, self.new_chunk_length ) - else: + elif self.mode == "xyzypr": if actions.ndim != 2 or actions.shape[-1] != 6: raise ValueError( f"InterpolatePose expects (T, 6), got {actions.shape} for key " @@ -87,6 +94,15 @@ def transform(self, batch: dict) -> dict: batch[self.output_action_key] = _interpolate_euler( actions, self.new_chunk_length ) + else: + if actions.shape[-1] != 3: + raise ValueError( + f"InterpolatePose expects (T, 3) or (T, K, 3), got {actions.shape} for key " + f"'{self.action_key}'" + ) + batch[self.output_action_key] = _interpolate_xyz( + actions, self.new_chunk_length + ) return batch @@ -126,27 +142,6 @@ def transform(self, batch: dict) -> dict: # --------------------------------------------------------------------------- -def _xyzypr_to_matrix(xyzypr: np.ndarray) -> np.ndarray: - """ - args: - xyzypr: (B, 6) np.array of [[x, y, z, yaw, pitch, roll]] - returns: - (B, 4, 4) array of SE3 transformation matrices - """ - if xyzypr.ndim != 2 or xyzypr.shape[-1] != 6: - raise ValueError(f"Expected (B, 6) array, got shape {xyzypr.shape}") - - B = xyzypr.shape[0] - dtype = xyzypr.dtype if np.issubdtype(xyzypr.dtype, np.floating) else np.float64 - - mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() - # Input is [yaw, pitch, roll], so use ZYX order (Rz @ Ry @ Rx). - mats[:, :3, :3] = R.from_euler("ZYX", xyzypr[:, 3:6], degrees=False).as_matrix() - mats[:, :3, 3] = xyzypr[:, :3] - - return mats - - class ActionChunkCoordinateFrameTransform(Transform): def __init__( self, @@ -154,7 +149,8 @@ def __init__( chunk_world: str, transformed_key_name: str, extra_batch_key: dict = None, - is_quat: bool = False, + mode: Literal["xyz", "xyzwxyz", "xyzypr"] = "xyzwxyz", + inverse: bool = True, ): """ args: @@ -167,7 +163,8 @@ def __init__( self.chunk_world = chunk_world self.transformed_key_name = transformed_key_name self.extra_batch_key = extra_batch_key - self.is_quat = is_quat + self.mode = mode + self.inverse = inverse def transform(self, batch): """ @@ -183,25 +180,56 @@ def transform(self, batch): if is_quat=False: (T, 6) xyz + ypr if is_quat=True: (T, 7) xyz + quat(wxyz) """ + # flatten to (T, D) + # target world is head pose, chunk world is keypoints batch.update(self.extra_batch_key or {}) target_world = np.asarray(batch[self.target_world]) chunk_world = np.asarray(batch[self.chunk_world]) - to_matrix_fn = _xyzwxyz_to_matrix if self.is_quat else _xyzypr_to_matrix + chunk_world_shape = None + + if chunk_world.ndim > 2: + chunk_world_shape = chunk_world.shape + chunk_world = chunk_world.reshape(-1, chunk_world_shape[-1]) + + to_matrix_fn = None + if self.mode == "xyzwxyz": + to_matrix_fn = _xyzwxyz_to_matrix + elif self.mode == "xyzypr": + to_matrix_fn = _xyzypr_to_matrix + elif self.mode == "xyz": + to_matrix_fn = _xyz_to_matrix + else: + raise ValueError(f"Invalid mode: {self.mode}") + target_world_to_matrix_fn = ( + _xyzwxyz_to_matrix if target_world.shape[-1] == 7 else _xyzypr_to_matrix + ) # Convert to SE3 for transformation - target_se3 = SE3.from_matrix(to_matrix_fn(target_world[None, :])[0]) # (4, 4) + target_se3 = SE3.from_matrix( + target_world_to_matrix_fn(target_world[None, :])[0] + ) # (4, 4) chunk_se3 = SE3.from_matrix(to_matrix_fn(chunk_world)) # (T, 4, 4) # Compute relative transform and apply to chunk - chunk_in_target_frame = target_se3.inverse() @ chunk_se3 + if self.inverse: + chunk_in_target_frame = target_se3.inverse() @ chunk_se3 + else: + chunk_in_target_frame = target_se3 @ chunk_se3 chunk_mats = chunk_in_target_frame.to_matrix() if chunk_mats.ndim == 2: chunk_mats = chunk_mats[None, ...] - chunk_in_target_frame = ( - _matrix_to_xyzwxyz(chunk_mats) - if self.is_quat - else _matrix_to_xyzypr(chunk_mats) - ) + + if self.mode == "xyzwxyz": + chunk_in_target_frame = _matrix_to_xyzwxyz(chunk_mats) + elif self.mode == "xyzypr": + chunk_in_target_frame = _matrix_to_xyzypr(chunk_mats) + elif self.mode == "xyz": + chunk_in_target_frame = _matrix_to_xyz(chunk_mats) + else: + raise ValueError(f"Invalid mode: {self.mode}") + + if chunk_world_shape is not None: + chunk_in_target_frame = chunk_in_target_frame.reshape(*chunk_world_shape) # Store transformed chunk back in batch batch[self.transformed_key_name] = chunk_in_target_frame @@ -224,11 +252,75 @@ def transform(self, batch: dict) -> dict: f"'{self.pose_key}'" ) xyz = pose[:3] - ypr = R.from_quat(pose[3:7]).as_euler("ZYX", degrees=False) + xyzw = wxyz_to_xyzw(pose[3:7]) + ypr = R.from_quat(xyzw).as_euler("ZYX", degrees=False) batch[self.output_key] = np.concatenate([xyz, ypr], axis=0) return batch +class YPRToQuaternionPose(Transform): + """Convert a single pose from xyz + ypr to xyz + quat(x,y,z,w).""" + + def __init__(self, pose_key: str, output_key: str): + self.pose_key = pose_key + self.output_key = output_key + + def transform(self, batch: dict) -> dict: + pose = np.asarray(batch[self.pose_key]) + if pose.shape != (6,): + raise ValueError( + f"YPRToQuaternionPose expects shape (6,), got {pose.shape} for key " + f"'{self.pose_key}'" + ) + xyz = pose[:3] + quat = R.from_euler("ZYX", pose[3:6], degrees=False).as_quat() # (x,y,z,w) + quat = xyzw_to_wxyz(quat) + batch[self.output_key] = np.concatenate([xyz, quat], axis=0) + return batch + + +class BatchQuaternionPoseToYPR(Transform): + """Convert a batch of poses from xyz + quat(x,y,z,w) to xyz + ypr.""" + + def __init__(self, pose_key: str, output_key: str): + self.pose_key = pose_key + self.output_key = output_key + + def transform(self, batch: dict) -> dict: + pose = np.asarray(batch[self.pose_key]) + if pose.ndim != 2 or pose.shape[-1] != 7: + raise ValueError( + f"BatchQuaternionPoseToYPR expects shape (N, 7), got {pose.shape} for key " + f"'{self.pose_key}'" + ) + xyz = pose[:, :3] + xyzw = wxyz_to_xyzw(pose[:, 3:7]) + ypr = R.from_quat(xyzw).as_euler("ZYX", degrees=False) # (N, 3) + batch[self.output_key] = np.concatenate([xyz, ypr], axis=1) + return batch + + +class BatchYPRToQuaternionPose(Transform): + """Convert a batch of poses from xyz + ypr to xyz + quat(x,y,z,w).""" + + def __init__(self, pose_key: str, output_key: str): + self.pose_key = pose_key + self.output_key = output_key + + def transform(self, batch: dict) -> dict: + pose = np.asarray(batch[self.pose_key]) + if pose.ndim != 2 or pose.shape[-1] != 6: + raise ValueError( + f"BatchYPRToQuaternionPose expects shape (N, 6), got {pose.shape} for key " + f"'{self.pose_key}'" + ) + xyz = pose[:, :3] + quat = R.from_euler("ZYX", pose[:, 3:6], degrees=False).as_quat() # (N, 4) + quat = xyzw_to_wxyz(quat) + batch[self.output_key] = np.concatenate([xyz, quat], axis=1) + return batch + + class PoseCoordinateFrameTransform(Transform): """Transform a single pose into a target frame pose.""" @@ -237,27 +329,21 @@ def __init__( target_world: str, pose_world: str, transformed_key_name: str, - is_quat: bool = False, + mode: Literal["xyzwxyz", "xyzypr", "xyz"] = "xyzwxyz", ): self.target_world = target_world self.pose_world = pose_world self.transformed_key_name = transformed_key_name - self.is_quat = is_quat + self.mode = mode self._chunk_transform = ActionChunkCoordinateFrameTransform( target_world=target_world, chunk_world=pose_world, transformed_key_name=transformed_key_name, - is_quat=is_quat, + mode=mode, ) def transform(self, batch: dict) -> dict: pose_world = np.asarray(batch[self.pose_world]) - expected_shape = (7,) if self.is_quat else (6,) - if pose_world.shape != expected_shape: - raise ValueError( - f"Expected pose_world shape {expected_shape}, got {pose_world.shape}" - ) - transformed = self._chunk_transform.transform( { self.target_world: batch[self.target_world], @@ -390,6 +476,17 @@ def transform(self, batch): # --------------------------------------------------------------------------- # Shape Transforms # --------------------------------------------------------------------------- +class SplitKeys(Transform): + def __init__(self, input_key: str, output_key_list: list[(str, int)]): + self.input_key = input_key + self.output_key_list = list(output_key_list) + + def transform(self, batch: dict) -> dict: + prev_end = 0 + for key, size in self.output_key_list: + batch[key] = batch[self.input_key][..., prev_end : prev_end + size] + prev_end += size + return batch class ConcatKeys(Transform): @@ -414,10 +511,23 @@ def transform(self, batch): return batch + +class Reshape(Transform): + def __init__(self, input_key: str, output_key: str, shape: tuple): + self.input_key = input_key + self.output_key = output_key + self.shape = shape + + def transform(self, batch: dict) -> dict: + batch[self.output_key] = batch[self.input_key].reshape(*self.shape) + return batch + + # --------------------------------------------------------------------------- # Type Transforms # --------------------------------------------------------------------------- + class NumpyToTensor(Transform): def __init__(self, keys: list[str]): self.keys = keys @@ -429,5 +539,7 @@ def transform(self, batch: dict) -> dict: elif isinstance(batch[key], torch.Tensor): batch[key] = batch[key].clone() else: - raise ValueError(f"NumpyToTensor expects key '{key}' to be a numpy array or torch tensor, got {type(batch[key])}") + raise ValueError( + f"NumpyToTensor expects key '{key}' to be a numpy array or torch tensor, got {type(batch[key])}" + ) return batch diff --git a/egomimic/scripts/tutorials/zarr_data_viz.ipynb b/egomimic/scripts/tutorials/zarr_data_viz.ipynb index 45b93fb3..c061fc17 100644 --- a/egomimic/scripts/tutorials/zarr_data_viz.ipynb +++ b/egomimic/scripts/tutorials/zarr_data_viz.ipynb @@ -1,5 +1,16 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "7cd73ea6", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, { "cell_type": "markdown", "id": "29aeeb40", @@ -47,7 +58,7 @@ "metadata": {}, "outputs": [], "source": [ - "TEMP_DIR = \"/tmp/\" # replace with your own temp directory for caching S3 data\n", + "TEMP_DIR = \"/coc/flash7/paphiwetsa3/projects/EgoVerse/tmp\" # replace with your own temp directory for caching S3 data\n", "load_env()" ] }, @@ -83,7 +94,7 @@ "source": [ "# Separate YPR visualization preview\n", "for batch in loader:\n", - " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"palm_axes\")\n", + " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"axes\")\n", " mpy.show_image(vis_ypr)\n", " break" ] @@ -97,7 +108,7 @@ "source": [ "images = []\n", "for i, batch in enumerate(loader):\n", - " vis = Eva.viz_transformed_batch(batch, mode=\"palm_traj\")\n", + " vis = Eva.viz_transformed_batch(batch, mode=\"traj\")\n", " images.append(vis)\n", " if i > 10:\n", " break\n", @@ -123,8 +134,8 @@ "source": [ "intrinsics_key = \"base\"\n", "\n", - "key_map = Aria.get_keymap()\n", - "transform_list = Aria.get_transform_list()\n", + "key_map = Aria.get_keymap(mode=\"cartesian\")\n", + "transform_list = Aria.get_transform_list(mode=\"cartesian\")\n", "\n", "resolver = S3EpisodeResolver(\n", " TEMP_DIR,\n", @@ -151,7 +162,7 @@ "source": [ "ims = []\n", "for i, batch in enumerate(loader):\n", - " vis = Aria.viz_transformed_batch(batch, mode=\"palm_traj\")\n", + " vis = Aria.viz_transformed_batch(batch, mode=\"traj\")\n", " ims.append(vis)\n", " # mpy.show_image(vis)\n", "\n", @@ -174,7 +185,7 @@ "# Aria YPR video (same data loop, YPR overlay)\n", "ims_ypr = []\n", "for i, batch in enumerate(loader):\n", - " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"palm_axes\")\n", + " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"axes\")\n", " ims_ypr.append(vis_ypr)\n", " if i > 20:\n", " break\n", @@ -182,6 +193,46 @@ "mpy.show_video(ims_ypr, fps=30)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "db296b65", + "metadata": {}, + "outputs": [], + "source": [ + "key_map = Aria.get_keymap(mode=\"keypoints\")\n", + "transform_list = Aria.get_transform_list(mode=\"keypoints_headframe\")\n", + "\n", + "resolver = S3EpisodeResolver(\n", + " TEMP_DIR,\n", + " key_map=key_map,\n", + " transform_list=transform_list,\n", + ")\n", + "\n", + "cloudflare_ds = MultiDataset._from_resolver(\n", + " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", + ")\n", + "\n", + "loader = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07a82130", + "metadata": {}, + "outputs": [], + "source": [ + "ims_keypoints = []\n", + "for i, batch in enumerate(loader):\n", + " vis_keypoints = Aria.viz_transformed_batch(batch, mode=\"keypoints\", action_key=\"actions_keypoints\")\n", + " ims_keypoints.append(vis_keypoints)\n", + " if i > 20:\n", + " break\n", + "\n", + "mpy.show_video(ims_keypoints, fps=30)" + ] + }, { "cell_type": "markdown", "id": "3fdb997a", @@ -241,7 +292,7 @@ " frame = batch['images.front_1'][0].permute(1,2,0).numpy() * 255\n", " vis = draw_dot_on_frame(frame, gaze_point_pixel, show=False)\n", " ims_gaze.append(vis)\n", - " if i > 600:\n", + " if i > 30:\n", " break" ] }, @@ -258,7 +309,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/egomimic/utils/pose_utils.py b/egomimic/utils/pose_utils.py index ac284e1a..c7c529bd 100644 --- a/egomimic/utils/pose_utils.py +++ b/egomimic/utils/pose_utils.py @@ -8,6 +8,10 @@ def xyzw_to_wxyz(xyzw): return np.concatenate([xyzw[..., 3:4], xyzw[..., :3]], axis=-1) +def wxyz_to_xyzw(wxyz): + return np.concatenate([wxyz[..., 1:4], wxyz[..., 0:1]], axis=-1) + + def _interpolate_euler(seq: np.ndarray, chunk_length: int) -> np.ndarray: """Euler-aware interpolation for a single (T, 6) or (T, 7) sequence.""" T, D = seq.shape @@ -80,6 +84,14 @@ def _interpolate_quat_wxyz(seq: np.ndarray, chunk_length: int) -> np.ndarray: ) +def _interpolate_xyz(seq: np.ndarray, chunk_length: int) -> np.ndarray: + """Linear interpolation for arbitrary (T, 3) arrays or (T, K, 3) arrays.""" + T = seq.shape[0] + old_time = np.linspace(0, 1, T) + new_time = np.linspace(0, 1, chunk_length) + return interp1d(old_time, seq, axis=0, kind="linear")(new_time) + + def _matrix_to_xyzypr(mats: np.ndarray) -> np.ndarray: """ args: @@ -99,6 +111,24 @@ def _matrix_to_xyzypr(mats: np.ndarray) -> np.ndarray: return np.concatenate([xyz, ypr], axis=-1).astype(dtype, copy=False) +def _xyzypr_to_matrix(xyzypr: np.ndarray) -> np.ndarray: + """ + args: + xyzypr: (B, 6) np.array of [[x, y, z, yaw, pitch, roll]] + returns: + (B, 4, 4) array of SE3 transformation matrices + """ + if xyzypr.ndim != 2 or xyzypr.shape[-1] != 6: + raise ValueError(f"Expected (B, 6) array, got shape {xyzypr.shape}") + B = xyzypr.shape[0] + dtype = xyzypr.dtype if np.issubdtype(xyzypr.dtype, np.floating) else np.float64 + + mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() + mats[:, :3, :3] = R.from_euler("ZYX", xyzypr[:, 3:6], degrees=False).as_matrix() + mats[:, :3, 3] = xyzypr[:, :3] + return mats + + def _matrix_to_xyzwxyz(mats: np.ndarray) -> np.ndarray: """ args: @@ -149,3 +179,71 @@ def T_rot_orientation(T: np.ndarray, rot_orientation: np.ndarray) -> np.ndarray: rot = rot @ rot_orientation T[:3, :3] = rot return T +def _xyz_to_matrix(xyz: np.ndarray) -> np.ndarray: + """ + args: + xyz: (B, 3) np.array of [[x, y, z]] + returns: + (B, 4, 4) array of SE3 transformation matrices + """ + if xyz.ndim != 2 or xyz.shape[-1] != 3: + raise ValueError(f"Expected (B, 3) array, got shape {xyz.shape}") + B = xyz.shape[0] + dtype = xyz.dtype if np.issubdtype(xyz.dtype, np.floating) else np.float64 + mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() + mats[:, :3, 3] = xyz + return mats + + +def _matrix_to_xyz(mats: np.ndarray) -> np.ndarray: + """ + args: + mats: (B, 4, 4) array of SE3 transformation matrices + returns: + (B, 3) np.array of [[x, y, z]] + """ + if mats.ndim != 3 or mats.shape[-2:] != (4, 4): + raise ValueError(f"Expected (B, 4, 4) array, got shape {mats.shape}") + mats = np.asarray(mats) + dtype = mats.dtype if np.issubdtype(mats.dtype, np.floating) else np.float64 + return mats[:, :3, 3].astype(dtype, copy=False) + + +def _split_action_pose(actions): + # 14D layout: [L xyz ypr g, R xyz ypr g] + # 12D layout: [L xyz ypr, R xyz ypr] + if actions.shape[-1] == 14: + left_xyz = actions[..., :3] + left_ypr = actions[..., 3:6] + right_xyz = actions[..., 7:10] + right_ypr = actions[..., 10:13] + elif actions.shape[-1] == 12: + left_xyz = actions[..., :3] + left_ypr = actions[..., 3:6] + right_xyz = actions[..., 6:9] + right_ypr = actions[..., 9:12] + else: + raise ValueError(f"Unsupported action dim {actions.shape[-1]}") + return left_xyz, left_ypr, right_xyz, right_ypr + + +def _split_keypoints(keypoints, wrist_in_data: bool = False): + if wrist_in_data: + left_xyz = keypoints[..., :3] + left_wxyz = keypoints[..., 3:7] + left_keypoints = keypoints[..., 7:70] + right_xyz = keypoints[..., 70:73] + right_wxyz = keypoints[..., 73:77] + right_keypoints = keypoints[..., 77:140] + return ( + left_xyz, + left_wxyz, + left_keypoints, + right_xyz, + right_wxyz, + right_keypoints, + ) + else: + left_keypoints = keypoints[..., :63] + right_keypoints = keypoints[..., 63:] + return left_keypoints, right_keypoints diff --git a/egomimic/utils/type_utils.py b/egomimic/utils/type_utils.py new file mode 100644 index 00000000..e6608ecb --- /dev/null +++ b/egomimic/utils/type_utils.py @@ -0,0 +1,11 @@ +import numpy as np + + +def _to_numpy(arr): + if hasattr(arr, "detach"): + arr = arr.detach() + if hasattr(arr, "cpu"): + arr = arr.cpu() + if hasattr(arr, "numpy"): + return arr.numpy() + return np.asarray(arr) diff --git a/egomimic/utils/viz_utils.py b/egomimic/utils/viz_utils.py new file mode 100644 index 00000000..c929fcd7 --- /dev/null +++ b/egomimic/utils/viz_utils.py @@ -0,0 +1,247 @@ +import cv2 +import matplotlib.pyplot as plt +import numpy as np +from scipy.spatial.transform import Rotation as R + +from egomimic.utils.egomimicUtils import ( + INTRINSICS, + cam_frame_to_cam_pixels, + draw_actions, +) +from egomimic.utils.pose_utils import _split_action_pose, _split_keypoints + + +class ColorPalette: + Blues = "Blues" + Greens = "Greens" + Reds = "Reds" + Oranges = "Oranges" + Purples = "Purples" + Greys = "Greys" + + @classmethod + def is_valid(cls, name: str) -> bool: + return name in vars(cls).values() + + @classmethod + def to_rgb(cls, cmap_name: str, value: float = 0.7) -> tuple[int, int, int]: + """Convert a ColorPalette cmap name to an RGB tuple (0-255). + value: 0-1, where higher = darker shade.""" + rgba = plt.get_cmap(cmap_name)(value) + return tuple(int(c * 255) for c in rgba[:3]) + + +def _prepare_viz_image(img): + if img.ndim == 3 and img.shape[0] in (1, 3): + img = np.transpose(img, (1, 2, 0)) + + if img.dtype != np.uint8: + if img.max() <= 1.0: + img = (img * 255.0).clip(0, 255).astype(np.uint8) + else: + img = img.clip(0, 255).astype(np.uint8) + + if img.ndim == 2: + img = np.repeat(img[:, :, None], 3, axis=-1) + elif img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) + + return img + + +def _viz_traj(images, actions, intrinsics_key, **kwargs): + color = kwargs.get("color", "Blues") + if not ColorPalette.is_valid(color): + raise ValueError(f"Invalid color palette: {color}") + + images = _prepare_viz_image(images) + intrinsics = INTRINSICS[intrinsics_key] + left_xyz, _, right_xyz, _ = _split_action_pose(actions) + + vis = draw_actions( + images.copy(), + type="xyz", + color=color, + actions=left_xyz, + extrinsics=None, + intrinsics=intrinsics, + arm="left", + ) + vis = draw_actions( + vis, + type="xyz", + color=color, + actions=right_xyz, + extrinsics=None, + intrinsics=intrinsics, + arm="right", + ) + return vis + + +def _viz_axes(images, actions, intrinsics_key, axis_len_m=0.04, **kwargs): + images = _prepare_viz_image(images) + intrinsics = INTRINSICS[intrinsics_key] + left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions) + vis = images.copy() + + def _draw_axis_color_legend(frame): + _, w = frame.shape[:2] + x_right = w - 12 + y_start = 14 + y_step = 12 + line_len = 24 + axis_legend = [ + ("x", (255, 0, 0)), + ("y", (0, 255, 0)), + ("z", (0, 0, 255)), + ] + for i, (name, color) in enumerate(axis_legend): + y = y_start + i * y_step + x0 = x_right - line_len + x1 = x_right + cv2.line(frame, (x0, y), (x1, y), color, 3) + cv2.putText( + frame, + name, + (x0 - 12, y + 4), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + color, + 1, + cv2.LINE_AA, + ) + return frame + + def _draw_rotation_at_anchor( + frame, xyz_seq, ypr_seq, label, anchor_color, **kwargs + ): + if len(xyz_seq) == 0 or len(ypr_seq) == 0: + return frame + + palm_xyz = xyz_seq[0] + palm_ypr = ypr_seq[0] + rot = R.from_euler("ZYX", palm_ypr, degrees=False).as_matrix() + + axis_points_cam = np.vstack( + [ + palm_xyz, + palm_xyz + rot[:, 0] * axis_len_m, + palm_xyz + rot[:, 1] * axis_len_m, + palm_xyz + rot[:, 2] * axis_len_m, + ] + ) + + px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2] + if not np.isfinite(px).all(): + return frame + pts = np.round(px).astype(np.int32) + + h, w = frame.shape[:2] + x0, y0 = pts[0] + if not (0 <= x0 < w and 0 <= y0 < h): + return frame + + cv2.circle(frame, (x0, y0), 4, anchor_color, -1) + axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] + for i, color in enumerate(axis_colors, start=1): + x1, y1 = pts[i] + if 0 <= x1 < w and 0 <= y1 < h: + cv2.line(frame, (x0, y0), (x1, y1), color, 2) + cv2.circle(frame, (x1, y1), 2, color, -1) + + cv2.putText( + frame, + label, + (x0 + 6, max(12, y0 - 8)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + anchor_color, + 1, + cv2.LINE_AA, + ) + return frame + + vis = _draw_rotation_at_anchor(vis, left_xyz, left_ypr, "L rot", (255, 180, 80)) + vis = _draw_rotation_at_anchor(vis, right_xyz, right_ypr, "R rot", (80, 180, 255)) + vis = _draw_axis_color_legend(vis) + return vis + + +def _viz_keypoints( + images, + actions, + intrinsics_key, + edges, + colors, + edge_ranges, + dot_color=None, + **kwargs, +): + """Visualize all 21 MANO keypoints per hand, projected onto the image.""" + # Prepare image + images = _prepare_viz_image(images) + + intrinsics = INTRINSICS[intrinsics_key] + + vis = images.copy() + h, w = vis.shape[:2] + + if actions.shape[-1] == 140: + left_xyz, left_wxyz, left_keypoints, right_xyz, right_wxyz, right_keypoints = ( + _split_keypoints(actions, wrist_in_data=True) + ) + else: + left_keypoints, right_keypoints = _split_keypoints(actions, wrist_in_data=False) + keypoints = {} + keypoints["left"] = left_keypoints.reshape(-1, 3) + keypoints["right"] = right_keypoints.reshape(-1, 3) + _default_dot_colors = {"left": (0, 120, 255), "right": (255, 80, 0)} + for hand in ("left", "right"): + hand_dot_color = ( + dot_color if dot_color is not None else _default_dot_colors[hand] + ) + kps_cam = keypoints[hand] + # Camera frame -> pixels + kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (42, 3+) 21 per arm + + # Identify valid keypoints (z > 0 and in image bounds) + valid = kps_cam[:, 2] > 0.01 + valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w) + valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h) + + # Draw skeleton edges (colored by finger) + for finger, start, end in edge_ranges: + color = colors[finger] + for edge_idx in range(start, end): + i, j = edges[edge_idx] + if valid[i] and valid[j]: + p1 = (int(kps_px[i, 0]), int(kps_px[i, 1])) + p2 = (int(kps_px[j, 0]), int(kps_px[j, 1])) + cv2.line(vis, p1, p2, color, 2) + + # Draw keypoint dots on top + for k in range(21): + if valid[k]: + center = (int(kps_px[k, 0]), int(kps_px[k, 1])) + cv2.circle(vis, center, 4, hand_dot_color, -1) + cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border + + # Label wrist + if valid[0]: + wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6) + cv2.putText( + vis, + f"{hand[0].upper()}", + wrist_px, + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + hand_dot_color, + 2, + ) + + return vis + + +def save_image(image: np.ndarray, path: str) -> None: + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))