diff --git a/egomimic/algo/hpt.py b/egomimic/algo/hpt.py index 5de7e0b7..6ae47832 100644 --- a/egomimic/algo/hpt.py +++ b/egomimic/algo/hpt.py @@ -19,8 +19,6 @@ STD_SCALE, EinOpsRearrange, download_from_huggingface, - draw_actions, - draw_rotation_text, frechet_gaussian_over_time, get_sinusoid_encoding_table, reverse_kl_from_samples, @@ -800,6 +798,7 @@ def __init__( encoder_specs: dict = None, domains: list = None, auxiliary_ac_keys: dict = {}, + viz_func: dict = None, # --------------------------- # Pretrained # --------------------------- @@ -812,6 +811,7 @@ def __init__( ): self.nets = nn.ModuleDict() self.data_schematic = data_schematic + self.viz_func = viz_func self.camera_transforms = camera_transforms self.train_image_augs = train_image_augs @@ -1250,69 +1250,12 @@ def visualize_preds(self, predictions, batch): Returns: ims (np.ndarray): (B, H, W, 3) - images with actions drawn on top """ - + if self.viz_func is None: + raise ValueError("viz_func is not set") embodiment_id = batch["embodiment"][0].item() embodiment_name = get_embodiment(embodiment_id).lower() - ac_key = self.ac_keys[embodiment_id] - - viz_img_key = self.data_schematic.viz_img_key()[embodiment_id] - ims = (batch[viz_img_key].cpu().numpy().transpose((0, 2, 3, 1)) * 255).astype( - np.uint8 - ) - for key in batch: - if f"{embodiment_name}_{key}" in predictions: - preds = predictions[f"{embodiment_name}_{key}"] - gt = batch[key] - - if self.is_6dof and ac_key == "actions_cartesian": - gt, gt_rot = self._extract_xyz(gt) - preds, preds_rot = self._extract_xyz(preds) - - for b in range(ims.shape[0]): - if preds.shape[-1] == 7 or preds.shape[-1] == 14: - ac_type = "joints" - elif preds.shape[-1] == 3 or preds.shape[-1] == 6: - ac_type = "xyz" - else: - raise ValueError( - f"Unknown action type with shape {preds.shape}" - ) - # Determine arm from embodiment name, not action shape - if "bimanual" in embodiment_name: - arm = "both" - elif "left" in embodiment_name: - arm = "left" - elif "right" in embodiment_name: - arm = "right" - else: - raise ValueError(f"Unknown embodiment name: {embodiment_name}") - ims[b] = draw_actions( - ims[b], - ac_type, - "Purples", - preds[b].cpu().numpy(), - self.camera_transforms[embodiment_name].extrinsics, - self.camera_transforms[embodiment_name].intrinsics, - arm=arm, - kinematics_solver=self.kinematics_solver, - ) - ims[b] = draw_actions( - ims[b], - ac_type, - "Greens", - gt[b].cpu().numpy(), - self.camera_transforms[embodiment_name].extrinsics, - self.camera_transforms[embodiment_name].intrinsics, - arm=arm, - kinematics_solver=self.kinematics_solver, - ) - - if self.is_6dof and ac_key == "actions_cartesian": - ims[b] = draw_rotation_text( - ims[b], gt_rot[b][0], preds_rot[b][0], position=(340, 20) - ) - return ims + return self.viz_func[embodiment_name](predictions, batch) @override def compute_losses(self, predictions, batch): diff --git a/egomimic/hydra_configs/train_zarr.yaml b/egomimic/hydra_configs/train_zarr.yaml index 4a91c21b..5da14448 100644 --- a/egomimic/hydra_configs/train_zarr.yaml +++ b/egomimic/hydra_configs/train_zarr.yaml @@ -1,10 +1,11 @@ defaults: - - model: hpt_bc_flow_eva + - model: hpt_cotrain_flow_shared_head + - visualization: eva_cartesian_aria_cartesian - paths: default - trainer: ddp - debug: null - logger: wandb - - data: eva + - data: eva_human_cotrain - callbacks: checkpoints - override hydra/launcher: submitit - _self_ @@ -15,6 +16,8 @@ ckpt_path: null train: true eval: false +norm_stat_fraction: 1.0 + eval_class: _target_: egomimic.scripts.evaluation.Eve mode: real @@ -32,7 +35,6 @@ launch_params: gpus_per_node: 1 nodes: 1 - data_schematic: # Dynamically fill in these shapes from the dataset _target_: egomimic.rldb.zarr.utils.DataSchematic norm_mode: quantile @@ -101,15 +103,5 @@ data_schematic: # Dynamically fill in these shapes from the dataset embodiment: key_type: metadata_keys zarr_key: metadata.embodiment - viz_img_key: - eva_bimanual: - front_img_1 - aria_bimanual: - front_img_1 - mecka_bimanual: - front_img_1 - scale_bimanual: - front_img_1 seed: 42 -norm_stat_fraction: 1.0 # fraction of dataset to calculate norm stats over (out of 1.0) \ No newline at end of file diff --git a/egomimic/hydra_configs/visualization/eva_cartesian_aria_cartesian.yaml b/egomimic/hydra_configs/visualization/eva_cartesian_aria_cartesian.yaml new file mode 100644 index 00000000..d4ed7a7a --- /dev/null +++ b/egomimic/hydra_configs/visualization/eva_cartesian_aria_cartesian.yaml @@ -0,0 +1,10 @@ +eva_bimanual: + _target_: egomimic.rldb.embodiment.eva.Eva.viz_cartesian_gt_preds + _partial_: true + image_key: front_img_1 + action_key: actions_cartesian +aria_bimanual: + _target_: egomimic.rldb.embodiment.human.Human.viz_cartesian_gt_preds + _partial_: true + image_key: front_img_1 + action_key: actions_cartesian diff --git a/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints.yaml b/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints.yaml new file mode 100644 index 00000000..a2311911 --- /dev/null +++ b/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints.yaml @@ -0,0 +1,10 @@ +eva_bimanual: + _target_: egomimic.rldb.embodiment.eva.Eva.viz_cartesian_gt_preds + _partial_: true + image_key: front_img_1 + action_key: actions_eva_cart_aria_keypoints +aria_bimanual: + _target_: egomimic.rldb.embodiment.human.Aria.viz_keypoints_gt_preds + _partial_: true + image_key: front_img_1 + action_key: actions_eva_cart_aria_keypoints diff --git a/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints_wrist.yaml b/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints_wrist.yaml new file mode 100644 index 00000000..c2bceea0 --- /dev/null +++ b/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints_wrist.yaml @@ -0,0 +1,16 @@ +eva_bimanual: + _target_: egomimic.rldb.embodiment.eva.Eva.viz_gt_preds + _partial_: true + image_key: front_img_1 + action_key: actions_eva_cart_aria_keypoints + transform_list: + _target_: egomimic.rldb.embodiment.eva._build_eva_bimanual_revert_eef_frame_transform_list + is_quat: false +aria_bimanual: + _target_: egomimic.rldb.embodiment.human.Aria.viz_gt_preds + _partial_: true + image_key: front_img_1 + action_key: actions_eva_cart_aria_keypoints + transform_list: + _target_: egomimic.rldb.embodiment.human._build_aria_keypoints_revert_eef_frame_transform_list + is_quat: false diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py index f08f50a8..d954fd79 100644 --- a/egomimic/trainHydra.py +++ b/egomimic/trainHydra.py @@ -120,10 +120,16 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: ), ) + viz_func = cfg.visualization + viz_func_dict = {} + for embodiment_name, embodiment_viz_func in viz_func.items(): + viz_func_dict[embodiment_name] = hydra.utils.instantiate(embodiment_viz_func) + # NOTE: We also pass the data_schematic_dict into the robomimic model's instatiation now that we've initialzied the shapes and norm stats. In theory, upon loading the PL checkpoint, it will remember this, but let's see. log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate( - cfg.model, robomimic_model={"data_schematic": data_schematic} + cfg.model, + robomimic_model={"data_schematic": data_schematic, "viz_func": viz_func_dict}, ) _log_dataset_frame_counts(train_datasets, valid_datasets)