Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 5 additions & 62 deletions egomimic/algo/hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
STD_SCALE,
EinOpsRearrange,
download_from_huggingface,
draw_actions,
draw_rotation_text,
frechet_gaussian_over_time,
get_sinusoid_encoding_table,
reverse_kl_from_samples,
Expand Down Expand Up @@ -800,6 +798,7 @@ def __init__(
encoder_specs: dict = None,
domains: list = None,
auxiliary_ac_keys: dict = {},
viz_func: dict = None,
# ---------------------------
# Pretrained
# ---------------------------
Expand All @@ -812,6 +811,7 @@ def __init__(
):
self.nets = nn.ModuleDict()
self.data_schematic = data_schematic
self.viz_func = viz_func

self.camera_transforms = camera_transforms
self.train_image_augs = train_image_augs
Expand Down Expand Up @@ -1250,69 +1250,12 @@ def visualize_preds(self, predictions, batch):
Returns:
ims (np.ndarray): (B, H, W, 3) - images with actions drawn on top
"""

if self.viz_func is None:
raise ValueError("viz_func is not set")
embodiment_id = batch["embodiment"][0].item()
embodiment_name = get_embodiment(embodiment_id).lower()
ac_key = self.ac_keys[embodiment_id]

viz_img_key = self.data_schematic.viz_img_key()[embodiment_id]
ims = (batch[viz_img_key].cpu().numpy().transpose((0, 2, 3, 1)) * 255).astype(
np.uint8
)
for key in batch:
if f"{embodiment_name}_{key}" in predictions:
preds = predictions[f"{embodiment_name}_{key}"]
gt = batch[key]

if self.is_6dof and ac_key == "actions_cartesian":
gt, gt_rot = self._extract_xyz(gt)
preds, preds_rot = self._extract_xyz(preds)

for b in range(ims.shape[0]):
if preds.shape[-1] == 7 or preds.shape[-1] == 14:
ac_type = "joints"
elif preds.shape[-1] == 3 or preds.shape[-1] == 6:
ac_type = "xyz"
else:
raise ValueError(
f"Unknown action type with shape {preds.shape}"
)

# Determine arm from embodiment name, not action shape
if "bimanual" in embodiment_name:
arm = "both"
elif "left" in embodiment_name:
arm = "left"
elif "right" in embodiment_name:
arm = "right"
else:
raise ValueError(f"Unknown embodiment name: {embodiment_name}")
ims[b] = draw_actions(
ims[b],
ac_type,
"Purples",
preds[b].cpu().numpy(),
self.camera_transforms[embodiment_name].extrinsics,
self.camera_transforms[embodiment_name].intrinsics,
arm=arm,
kinematics_solver=self.kinematics_solver,
)
ims[b] = draw_actions(
ims[b],
ac_type,
"Greens",
gt[b].cpu().numpy(),
self.camera_transforms[embodiment_name].extrinsics,
self.camera_transforms[embodiment_name].intrinsics,
arm=arm,
kinematics_solver=self.kinematics_solver,
)

if self.is_6dof and ac_key == "actions_cartesian":
ims[b] = draw_rotation_text(
ims[b], gt_rot[b][0], preds_rot[b][0], position=(340, 20)
)
return ims
return self.viz_func[embodiment_name](predictions, batch)

@override
def compute_losses(self, predictions, batch):
Expand Down
18 changes: 5 additions & 13 deletions egomimic/hydra_configs/train_zarr.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
defaults:
- model: hpt_bc_flow_eva
- model: hpt_cotrain_flow_shared_head
- visualization: eva_cartesian_aria_cartesian
- paths: default
- trainer: ddp
- debug: null
- logger: wandb
- data: eva
- data: eva_human_cotrain
- callbacks: checkpoints
- override hydra/launcher: submitit
- _self_
Expand All @@ -15,6 +16,8 @@ ckpt_path: null
train: true
eval: false

norm_stat_fraction: 1.0

eval_class:
_target_: egomimic.scripts.evaluation.Eve
mode: real
Expand All @@ -32,7 +35,6 @@ launch_params:
gpus_per_node: 1
nodes: 1


data_schematic: # Dynamically fill in these shapes from the dataset
_target_: egomimic.rldb.zarr.utils.DataSchematic
norm_mode: quantile
Expand Down Expand Up @@ -101,15 +103,5 @@ data_schematic: # Dynamically fill in these shapes from the dataset
embodiment:
key_type: metadata_keys
zarr_key: metadata.embodiment
viz_img_key:
eva_bimanual:
front_img_1
aria_bimanual:
front_img_1
mecka_bimanual:
front_img_1
scale_bimanual:
front_img_1

seed: 42
norm_stat_fraction: 1.0 # fraction of dataset to calculate norm stats over (out of 1.0)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
eva_bimanual:
_target_: egomimic.rldb.embodiment.eva.Eva.viz_cartesian_gt_preds
_partial_: true
image_key: front_img_1
action_key: actions_cartesian
aria_bimanual:
_target_: egomimic.rldb.embodiment.human.Human.viz_cartesian_gt_preds
_partial_: true
image_key: front_img_1
action_key: actions_cartesian
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
eva_bimanual:
_target_: egomimic.rldb.embodiment.eva.Eva.viz_cartesian_gt_preds
_partial_: true
image_key: front_img_1
action_key: actions_eva_cart_aria_keypoints
aria_bimanual:
_target_: egomimic.rldb.embodiment.human.Aria.viz_keypoints_gt_preds
_partial_: true
image_key: front_img_1
action_key: actions_eva_cart_aria_keypoints
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
eva_bimanual:
_target_: egomimic.rldb.embodiment.eva.Eva.viz_gt_preds
_partial_: true
image_key: front_img_1
action_key: actions_eva_cart_aria_keypoints
transform_list:
_target_: egomimic.rldb.embodiment.eva._build_eva_bimanual_revert_eef_frame_transform_list
is_quat: false
aria_bimanual:
_target_: egomimic.rldb.embodiment.human.Aria.viz_gt_preds
_partial_: true
image_key: front_img_1
action_key: actions_eva_cart_aria_keypoints
transform_list:
_target_: egomimic.rldb.embodiment.human._build_aria_keypoints_revert_eef_frame_transform_list
is_quat: false
8 changes: 7 additions & 1 deletion egomimic/trainHydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,16 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
),
)

viz_func = cfg.visualization
viz_func_dict = {}
for embodiment_name, embodiment_viz_func in viz_func.items():
viz_func_dict[embodiment_name] = hydra.utils.instantiate(embodiment_viz_func)

# NOTE: We also pass the data_schematic_dict into the robomimic model's instantiation now that we've initialized the shapes and norm stats. In theory, upon loading the PL checkpoint, it will remember this, but let's see.
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
cfg.model, robomimic_model={"data_schematic": data_schematic}
cfg.model,
robomimic_model={"data_schematic": data_schematic, "viz_func": viz_func_dict},
)

_log_dataset_frame_counts(train_datasets, valid_datasets)
Expand Down