diff --git a/.gitignore b/.gitignore index 0942096..5ecdddb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,17 @@ -projects +submit.py +scripts/run_all_exps.sh +misc viewer run.sh run-long.sh /database/processed /database/configs /database/raw -/logdir +/database/ama +/database/polycam +/logdir* /tmp +projects/csim lab4d.egg-info __pycache__/ @@ -16,3 +21,4 @@ __pycache__/ preprocess/third_party/vcnplus/vcn_rob.pth preprocess/third_party/viewpoint/human.pth preprocess/third_party/viewpoint/quad.pth +preprocess/third_party/omnivision/*.ckpt diff --git a/.gitmodules b/.gitmodules index ec4ec5b..dede49e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,3 +8,9 @@ [submodule "docs/pytorch_sphinx_theme"] path = docs/pytorch_sphinx_theme url = https://github.com/gengshan-y/pytorch_sphinx_theme +[submodule "projects/ppr/ppr-diffphys"] + path = projects/ppr/ppr-diffphys + url = git@github.com:gengshan-y/ppr-diffphys.git +[submodule "projects/ppr/eval/third_party/ChamferDistancePytorch"] + path = projects/ppr/eval/third_party/ChamferDistancePytorch + url = https://github.com/ThibaultGROUEIX/ChamferDistancePytorch diff --git a/docs/source/tutorials/single_video_cat.rst b/docs/source/tutorials/single_video_cat.rst index 5f54eb4..c74ac64 100644 --- a/docs/source/tutorials/single_video_cat.rst +++ b/docs/source/tutorials/single_video_cat.rst @@ -123,7 +123,7 @@ To render novel views, run:: To render a video of the proxy geometry and cameras over training iterations, run:: - python scripts/render_intermediate.py --testdir logdir/$logname/ + python lab4d/render_intermediate.py --testdir logdir/$logname/ .. raw:: html diff --git a/environment.yml b/environment.yml index 9e98aba..a396fab 100644 --- a/environment.yml +++ b/environment.yml @@ -7,8 +7,9 @@ dependencies: - python=3.9 - setuptools=66.0.0 - pip - - pytorch==2.0.0=py3.9_cuda11.7_cudnn8.5.0_0 - - torchvision + - pytorch=2.0.0 + - pytorch-cuda=11.7 + - torchvision=0.15.2 - cudatoolkit-dev=11.7 - gcc_linux-64=10 - gxx_linux-64=10 @@ -38,3 +39,5 @@ dependencies: - groundingdino @ git+https://github.com/IDEA-Research/GroundingDINO.git - openmim - pyrender + - open3d==0.17.0 + - geomloss==0.2.6 diff --git a/lab4d/config.py b/lab4d/config.py index c6fcaa8..7e93955 100644 --- a/lab4d/config.py +++ b/lab4d/config.py @@ -10,27 +10,28 @@ class TrainModelConfig: # weights of reconstruction terms flags.DEFINE_float("mask_wt", 0.1, "weight for silhouette loss") flags.DEFINE_float("rgb_wt", 0.1, "weight for color loss") - flags.DEFINE_float("depth_wt", 1e-4, "weight for depth loss") + flags.DEFINE_float("depth_wt", 0.0, "weight for depth loss") + flags.DEFINE_float("normal_wt", 0.0, "weight for normal loss") flags.DEFINE_float("flow_wt", 0.5, "weight for flow loss") flags.DEFINE_float("vis_wt", 1e-2, "weight for visibility loss") flags.DEFINE_float("feature_wt", 1e-2, "weight for feature reconstruction loss") - flags.DEFINE_float("feat_reproj_wt", 5e-2, "weight for feature reprojection loss") + flags.DEFINE_float("feat_reproj_wt", 0.05, "weight for feature reprojection loss") # weights of regularization terms flags.DEFINE_float( "reg_visibility_wt", 1e-4, "weight for visibility regularization" ) - flags.DEFINE_float("reg_eikonal_wt", 1e-3, "weight for eikonal regularization") + flags.DEFINE_float("reg_eikonal_wt", 0.01, "weight for eikonal regularization") + flags.DEFINE_float("reg_eikonal_scale_max", 1, "max scaling for eikonal reg") flags.DEFINE_float( - "reg_deform_cyc_wt", 0.01, "weight for deform cyc regularization" - ) - 
flags.DEFINE_float("reg_delta_skin_wt", 5e-3, "weight for delta skinning reg") - flags.DEFINE_float("reg_skin_entropy_wt", 5e-4, "weight for delta skinning reg") - flags.DEFINE_float( - "reg_gauss_skin_wt", 1e-3, "weight for gauss skinning consistency" + "reg_deform_cyc_wt", 0.05, "weight for deform cyc regularization" ) + flags.DEFINE_float("reg_delta_skin_wt", 1e-3, "weight for delta skinning reg") + flags.DEFINE_float("reg_skin_entropy_wt", 0.0, "weight for delta skinning reg") + flags.DEFINE_float("reg_gauss_skin_wt", 0.02, "weight for gauss density loss in 3D") + # flags.DEFINE_float("reg_gauss_skin_wt", 0.0, "weight for gauss density loss in 3D") flags.DEFINE_float("reg_cam_prior_wt", 0.1, "weight for camera regularization") - flags.DEFINE_float("reg_skel_prior_wt", 0.1, "weight for skeleton regularization") + flags.DEFINE_float("reg_skel_prior_wt", 0.01, "weight for skeleton regularization") flags.DEFINE_float( "reg_gauss_mask_wt", 0.01, "weight for gauss mask regularization" ) @@ -41,7 +42,8 @@ class TrainModelConfig: flags.DEFINE_string( "fg_motion", "rigid", "{rigid, dense, bob, skel-human, skel-quad}" ) - flags.DEFINE_bool("single_inst", True, "assume the same morphology over objs") + flags.DEFINE_bool("single_inst", True, "assume the same morphology over videos") + flags.DEFINE_bool("single_scene", True, "assume the same scene over videos") class TrainOptConfig: @@ -57,22 +59,25 @@ class TrainOptConfig: flags.DEFINE_string("feature_type", "dinov2", "{dinov2, cse}") flags.DEFINE_string("load_path", "", "path to load pretrained model") - # accuracy-related + # optimization-related flags.DEFINE_float("learning_rate", 5e-4, "learning rate") flags.DEFINE_integer("num_rounds", 20, "number of rounds to train") + flags.DEFINE_integer("num_rounds_cam_init", 10, "number of rounds for camera init") flags.DEFINE_integer("iters_per_round", 200, "number of iterations per round") flags.DEFINE_integer("imgs_per_gpu", 128, "images samples per iter, per gpu") flags.DEFINE_integer("pixels_per_image", 16, "pixel samples per image") # flags.DEFINE_integer("imgs_per_gpu", 1, "size of minibatches per iter") # flags.DEFINE_integer("pixels_per_image", 4096, "number of pixel samples per image") - flags.DEFINE_boolean( - "freeze_bone_len", False, "do not change bone length of skeleton" - ) + flags.DEFINE_boolean("use_freq_anneal", True, "whether to use frequency annealing") flags.DEFINE_boolean( "reset_steps", True, "reset steps of loss scheduling, set to False if resuming training", ) + flags.DEFINE_boolean("pose_correction", False, "whether to execute pose correction") + flags.DEFINE_boolean("alter_flow", False, "alternatve between flow and all terms") + flags.DEFINE_boolean("freeze_intrinsics", False, "whether to freeze intrinsics") + flags.DEFINE_boolean("absorb_base", True, "whether to absorb se3 into base") # efficiency-related flags.DEFINE_integer("ngpu", 1, "number of gpus to use") diff --git a/lab4d/config_omega.py b/lab4d/config_omega.py index 827c49e..8d9b97c 100644 --- a/lab4d/config_omega.py +++ b/lab4d/config_omega.py @@ -26,6 +26,7 @@ "field_type": "bg", # {bg, fg, comp} "fg_motion": "rigid", # {rigid, dense, bob, skel} "single_inst": True, # assume the same morphology over objs + "single_scene": True, # assume the same scene over videos }, "io": { "seqname": "cat", # name of the sequence diff --git a/lab4d/dataloader/data_utils.py b/lab4d/dataloader/data_utils.py index d56775e..7f61470 100644 --- a/lab4d/dataloader/data_utils.py +++ b/lab4d/dataloader/data_utils.py @@ -1,6 +1,7 @@ # 
Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. import configparser import glob +import os import random import numpy as np @@ -234,6 +235,7 @@ def get_data_info(loader): intrinsics = [] raw_size = [] feature_pxs = [] + motion_scales = [] for dataset in dataset_list: frame_info = FrameInfo(dataset.dict_list["ref"]) @@ -253,11 +255,29 @@ def get_data_info(loader): num_skip = max(1, len(feature_array) // 1000) feature_pxs.append(feature_array[::num_skip]) + # compute motion magnitude + mask = dataset.mmap_list["mask"][:-1, ..., 0].copy() + if dataset.field_type == "bg": + mask = np.logical_not(mask) + elif dataset.field_type == "fg": + pass + elif dataset.field_type == "comp": + mask[:] = True + else: + raise ValueError("Unknown field type: %s" % dataset.field_type) + flow = dataset.mmap_list["flowfw"][1][mask, :2] + motion_scale = np.linalg.norm(flow, 2, -1).mean() + motion_scales.append(motion_scale) + # compute PCA on non-zero features feature_pxs = np.concatenate(feature_pxs, 0) feature_pxs = feature_pxs[np.linalg.norm(feature_pxs, 2, -1) > 0] data_info["apply_pca_fn"] = pca_numpy(feature_pxs, n_components=3) + # store motion magnitude + data_info["motion_scales"] = motion_scales + # print("motion scales: ", motion_scales) + frame_info = {} frame_info["frame_offset"] = np.asarray(frame_offset).cumsum() frame_info["frame_offset_raw"] = np.asarray(frame_offset_raw).cumsum() @@ -310,23 +330,43 @@ def load_small_files(data_path_dict): # [np.load(path).astype(np.float32) for path in data_path_dict["crop2raw"]], 0 # ) # N,4 - rtmat_bg = np.concatenate( - [np.load(path).astype(np.float32) for path in data_path_dict["cambg"]], 0 - ) # N,4,4 - rtmat_fg = np.concatenate( - [np.load(path).astype(np.float32) for path in data_path_dict["camfg"]], 0 - ) # N,4,4 + # bg/fg camera + rtmat_bg = [] + for vid, path in enumerate(data_path_dict["cambg"]): + # get N + num_frames = np.load(data_path_dict["is_detected"][vid]).shape[0] + if os.path.exists(path): + rtmat_bg.append(np.load(path).astype(np.float32)) + else: + rtmat_bg.append(np.eye(4)[None].repeat(num_frames, 0)) + print("Warning: no bg camera found at %s" % path) + rtmat_bg = np.concatenate(rtmat_bg, 0) # N,4,4 + + rtmat_fg = [] + for vid, path in enumerate(data_path_dict["camfg"]): + # get N + num_frames = np.load(data_path_dict["is_detected"][vid]).shape[0] + if os.path.exists(path): + rtmat_fg.append(np.load(path).astype(np.float32)) + else: + rtmat_fg.append(np.eye(4)[None].repeat(num_frames, 0)) + print("Warning: no fg camera found at %s" % path) + + rtmat_fg = np.concatenate(rtmat_fg, 0) + # hard-code for now vis_info = {"bg": 0, "fg": 1} # video instance segmentation info data_info["vis_info"] = vis_info data_info["rtmat"] = np.stack([rtmat_bg, rtmat_fg], 0) # path to centered mesh files - camera_prefix = data_path_dict["cambg"][0].rsplit("/", 1)[0] - data_info["geom_path"] = [ - "%s/mesh-00-centered.obj" % camera_prefix, - "%s/mesh-01-centered.obj" % camera_prefix, - ] + geom_path_bg = [] + geom_path_fg = [] + for path in data_path_dict["cambg"]: + camera_prefix = path.rsplit("/", 1)[0] + geom_path_bg.append("%s/mesh-00-centered.obj" % camera_prefix) + geom_path_fg.append("%s/mesh-01-centered.obj" % camera_prefix) + data_info["geom_path"] = [geom_path_bg, geom_path_fg] return data_info diff --git a/lab4d/dataloader/vidloader.py b/lab4d/dataloader/vidloader.py index bac7263..9b7a512 100644 --- a/lab4d/dataloader/vidloader.py +++ b/lab4d/dataloader/vidloader.py @@ -3,6 +3,7 @@ from pathlib import Path import numpy as np +import 
cv2 import torch from torch.utils.data import Dataset @@ -57,6 +58,7 @@ class VidDataset(Dataset): def __init__(self, opts, rgblist, dataid, ks, raw_size): self.delta_list = opts["delta_list"] + self.field_type = opts["field_type"] self.dict_list = self.construct_data_list( rgblist, opts["data_prefix"], opts["feature_type"] ) @@ -66,6 +68,7 @@ def __init__(self, opts, rgblist, dataid, ks, raw_size): self.ks = ks self.raw_size = raw_size self.img_size = np.load(self.dict_list["rgb"]).shape[1:3] # (H, W) + self.res = (opts["eval_res"], opts["eval_res"]) self.load_data_list(self.dict_list) self.idx_sampler = RangeSampler(num_elems=self.img_size[0] * self.img_size[1]) @@ -88,13 +91,28 @@ def construct_data_list(self, reflist, prefix, feature_type): flowfw_path = rgb_path.replace("JPEGImages", "FlowFW") flowbw_path = rgb_path.replace("JPEGImages", "FlowBW") depth_path = rgb_path.replace("JPEGImages", "Depth") + normal_path = rgb_path.replace("JPEGImages", "Normal") + if self.field_type == "bg": + group_id = 0 + else: + group_id = 1 feature_path = str( Path(rgb_path.replace("JPEGImages", "Features")).parent - ) + "/%s-%s-01.npy" % (prefix, feature_type) + ) + "/%s-%s-%02d.npy" % (prefix, feature_type, group_id) - camlist_bg = ( - reflist[0].replace("JPEGImages", "Cameras").replace("00000.jpg", "00.npy") - ) # bg + canonical_path_bg = ( + reflist[0] + .replace("JPEGImages", "Cameras") + .replace("00000.jpg", "00-canonical.npy") + ) + if os.path.exists(canonical_path_bg): + camlist_bg = canonical_path_bg + else: + camlist_bg = ( + reflist[0] + .replace("JPEGImages", "Cameras") + .replace("00000.jpg", "00.npy") + ) # bg camlist_fg = ( reflist[0] .replace("JPEGImages", "Cameras") @@ -115,6 +133,7 @@ def construct_data_list(self, reflist, prefix, feature_type): "flowfw": flowfw_path, "flowbw": flowbw_path, "depth": depth_path, + "normal": normal_path, "feature": feature_path, "crop2raw": crop2raw_path, "is_detected": is_detected_path, @@ -155,7 +174,11 @@ def load_data_list(self, dict_list): self.mmap_list[k] = np.load(path, mmap_mode="r") except: print(f"Warning: cannot load {path}") - self.mmap_list[k] = np.random.rand(self.__len__() + 1, 112, 112, 16) + if k=="feature": + self.mmap_list[k] = np.random.rand(self.__len__() + 1, 112, 112, 16) + else: + self.mmap_list[k] = np.random.rand(self.__len__() + 1, self.img_size[0], self.img_size[1], 3) + def __len__(self): return len(self.dict_list["ref"]) - 1 @@ -228,12 +251,13 @@ def read_raw(self, im0idx, delta, rand_xy=None): delta (int): Distance to other frame id in the pair rand_xy (array or None): (N, 2) pixels to load, if given Returns: - data_dict (Dict): Dict with keys "rgb", "mask", "depth", "feature", + data_dict (Dict): Dict with keys "rgb", "mask", "depth", "normal", "feature", "flow", "vis2d", "crop2raw", "dataid", "frameid_sub", "hxy" """ rgb = self.read_rgb(im0idx, rand_xy=rand_xy) mask, vis2d, crop2raw, is_detected = self.read_mask(im0idx, rand_xy=rand_xy) depth = self.read_depth(im0idx, rand_xy=rand_xy) + normal = self.read_normal(im0idx, rand_xy=rand_xy) flow = self.read_flow(im0idx, delta, rand_xy=rand_xy) feature = self.read_feature(im0idx, rand_xy=rand_xy) @@ -250,6 +274,7 @@ def read_raw(self, im0idx, delta, rand_xy=None): data_dict["rgb"] = rgb data_dict["mask"] = mask data_dict["depth"] = depth + data_dict["normal"] = normal data_dict["feature"] = feature data_dict["flow"] = flow[..., :2] data_dict["flow_uct"] = flow[..., 2:] @@ -272,7 +297,9 @@ def read_rgb(self, im0idx, rand_xy=None): """ rgb = self.mmap_list["rgb"][im0idx] shape 
= rgb.shape - if rand_xy is not None: + if rand_xy is None: + rgb = cv2.resize(rgb.astype(np.float32), self.res) + else: rgb = rgb[rand_xy[:, 1], rand_xy[:, 0]] # N,3 if len(shape) == 2: # gray image @@ -294,7 +321,10 @@ def read_mask(self, im0idx, rand_xy=None): from cropped (H,W) image to raw image, (fx, fy, cx, cy) """ mask = self.mmap_list["mask"][im0idx] - if rand_xy is not None: + if rand_xy is None: + mask = mask.astype(int) + mask = cv2.resize(mask, self.res, interpolation=cv2.INTER_NEAREST) + else: mask = mask[rand_xy[:, 1], rand_xy[:, 0]] # N,3 vis2d = mask[..., 1:] @@ -314,11 +344,29 @@ def read_depth(self, im0idx, rand_xy=None): depth (np.array): (H,W,1) or (N,1) Depth map, float16 """ depth = self.mmap_list["depth"][im0idx] - if rand_xy is not None: + if rand_xy is None: + depth = cv2.resize(depth.astype(np.float32), self.res) + else: depth = depth[rand_xy[:, 1], rand_xy[:, 0]] return depth[..., None] + def read_normal(self, im0idx, rand_xy=None): + """Read surface normal map for a single frame + + Args: + im0idx (int): Frame id to load + rand_xy (np.array or None): (N,2) Pixels to load, if given + Returns: + normal (np.array): (H,W,3) or (N,3) Surface normal map, float16 + """ + normal = self.mmap_list["normal"][im0idx] + if rand_xy is None: + normal = cv2.resize(normal.astype(np.float32), self.res) + else: + normal = normal[rand_xy[:, 1], rand_xy[:, 0]] + return normal + def read_feature(self, im0idx, rand_xy=None): """Read feature map for a single frame @@ -329,9 +377,13 @@ def read_feature(self, im0idx, rand_xy=None): feat (np.array): (112,112,16) or (N,16) Feature map, float32 """ feat = self.mmap_list["feature"][im0idx] # (112,112,16) - if rand_xy is not None: + if rand_xy is None: + feat = cv2.resize(feat.astype(np.float32), self.res) + else: rand_xy = rand_xy / self.img_size[0] * 112 feat = bilinear_interp(feat, rand_xy) + # normalize + feat = feat / (np.linalg.norm(feat, axis=-1, keepdims=True) + 1e-6) feat = feat.astype(np.float32) return feat @@ -351,7 +403,11 @@ def read_flow(self, im0idx, delta, rand_xy=None): flow = self.mmap_list["flowfw"][delta][im0idx // delta] else: flow = self.mmap_list["flowbw"][delta][im0idx // delta - 1] - if rand_xy is not None: + if rand_xy is None: + flow = cv2.resize(flow.astype(np.float32), self.res) + flow[..., 0] *= self.res[1] / self.img_size[1] + flow[..., 1] *= self.res[0] / self.img_size[0] + else: flow = flow[rand_xy[:, 1], rand_xy[:, 0]] flow = flow.astype(np.float32) diff --git a/lab4d/engine/model.py b/lab4d/engine/model.py index 1718cb3..d7b2a20 100644 --- a/lab4d/engine/model.py +++ b/lab4d/engine/model.py @@ -7,7 +7,9 @@ from tqdm import tqdm from lab4d.engine.train_utils import get_local_rank -from lab4d.nnutils.intrinsics import IntrinsicsMLP +from lab4d.nnutils.intrinsics import IntrinsicsMLP, IntrinsicsMLP_delta +from lab4d.nnutils.pose import CameraMLP_so3 +from lab4d.nnutils.feature import FeatureNeRF from lab4d.nnutils.multifields import MultiFields from lab4d.utils.geom_utils import K2inv, K2mat from lab4d.utils.numpy_utils import interp_wt @@ -32,9 +34,8 @@ def __init__(self, config, data_info): data_info=data_info, field_type=config["field_type"], fg_motion=config["fg_motion"], - num_inst=1 - if config["single_inst"] - else len(data_info["frame_info"]["frame_offset"]) - 1, + single_inst=config["single_inst"], + single_scene=config["single_scene"], ) self.intrinsics = IntrinsicsMLP( self.data_info["intrinsics"], @@ -92,56 +93,135 @@ def process_frameid(self, batch): # convert frameid_sub to frameid 
batch["frameid"] = batch["frameid_sub"] + self.offset_cache[batch["dataid"]] - def set_progress(self, current_steps): + def set_progress(self, current_steps, progress): """Adjust loss weights and other constants throughout training Args: current_steps (int): Number of optimization steps so far + progress (float): Fraction of training completed (in the current stage) """ - # positional encoding annealing - anchor_x = (0, 4000) - anchor_y = (0.6, 1) - type = "linear" - alpha = interp_wt(anchor_x, anchor_y, current_steps, type=type) - if alpha >= 1: - alpha = None - self.fields.set_alpha(alpha) + self.current_steps = current_steps + config = self.config + if self.config["use_freq_anneal"]: + # positional encoding annealing + anchor_x = (1000, 2000) + anchor_y = (0.6, 1) + type = "linear" + alpha = interp_wt(anchor_x, anchor_y, current_steps, type=type) + if alpha >= 1: + alpha = -1 + alpha = torch.tensor(alpha, device=self.device, dtype=torch.float32) + self.fields.set_alpha(alpha) + + # # use 2k steps to warmup + # if current_steps < 2000: + # self.fields.set_importance_sampling(False) + # else: + # self.fields.set_importance_sampling(True) + self.fields.set_importance_sampling(False) + + # pose correction: steps(0->2k, 1->0) + if config["pose_correction"]: + anchor_x = (0.8, 1.0) + type = "linear" + wt_modifier_dict = { + "feat_reproj_wt": 10.0, + "mask_wt": 0.0, + "rgb_wt": 0.0, + "flow_wt": 0.0, + "feature_wt": 0.0, + "reg_gauss_mask_wt": 0.0, + } + for loss_name, wt_modifier in wt_modifier_dict.items(): + anchor_y = (wt_modifier, 1.0) + self.set_loss_weight(loss_name, anchor_x, anchor_y, progress, type=type) + + if config["pose_correction"]: + sample_around_surface = True + else: + sample_around_surface = False + if "fg" in self.fields.field_params.keys() and isinstance( + self.fields.field_params["fg"], FeatureNeRF + ): + self.fields.field_params["fg"].set_match_region(sample_around_surface) + + if config["alter_flow"]: + # alternating between flow and all losses for initialzation + switch_list = [ + "mask_wt", + "rgb_wt", + "normal_wt", + "reg_gauss_mask_wt", + ] + if current_steps < 1600 and current_steps % 2 == 0: + # set to 0 + for key in switch_list: + self.set_loss_weight( + key, (0, 1), (0, 0), current_steps, type="linear" + ) + else: + # set to 1x + for key in switch_list: + self.set_loss_weight( + key, (0, 1), (1, 1), current_steps, type="linear" + ) - # beta_prob: steps(0->2k, 1->0.2), range (0.2,1) + # anneal geometry/appearance code for foreground: steps(0->2k, 1->0.2), range (0.2,1) anchor_x = (0, 2000) anchor_y = (1.0, 0.2) type = "linear" beta_prob = interp_wt(anchor_x, anchor_y, current_steps, type=type) self.fields.set_beta_prob(beta_prob) - # camera prior wt: steps(0->800, 1->0), range (0,1) + # camera prior wt: steps(0->1000, 1->0), range (0,1) loss_name = "reg_cam_prior_wt" - anchor_x = (0, 800) + anchor_x = (0, config["num_rounds_cam_init"] * config["iters_per_round"]) anchor_y = (1, 0) type = "linear" self.set_loss_weight(loss_name, anchor_x, anchor_y, current_steps, type=type) # reg_eikonal_wt: steps(0->24000, 1->100), range (1,100) loss_name = "reg_eikonal_wt" - anchor_x = (0, 4000) - anchor_y = (1, 100) - type = "log" + anchor_x = (800, 2000) + anchor_y = (1, config["reg_eikonal_scale_max"]) + type = "linear" self.set_loss_weight(loss_name, anchor_x, anchor_y, current_steps, type=type) - # skel prior wt: steps(0->4000, 1->0), range (0,1) + # skel prior wt: steps(0->4000, 1->0), to discouage large changes when shape is not good loss_name = "reg_skel_prior_wt" - 
anchor_x = (0, 4000) - anchor_y = (1, 0) - type = "linear" + anchor_x = (200, 400) + anchor_y = (10, 1) + type = "log" self.set_loss_weight(loss_name, anchor_x, anchor_y, current_steps, type=type) - # gauss mask wt: steps(0->4000, 1->0), range (0,1) - loss_name = "reg_gauss_mask_wt" - anchor_x = (0, 4000) - anchor_y = (1, 0) + # # gauss mask wt: steps(0->4000, 1->0), range (0,1) + # loss_name = "reg_gauss_mask_wt" + # anchor_x = (0, 2000) + # anchor_y = (1, 0.1) + # type = "log" + # self.set_loss_weight(loss_name, anchor_x, anchor_y, current_steps, type=type) + + # delta skin wt: steps(0->2000, 1->0.1), to make skinning more flexible + loss_name = "reg_delta_skin_wt" + anchor_x = (0, 2000) + anchor_y = (1, 0.01) + type = "log" + self.set_loss_weight(loss_name, anchor_x, anchor_y, current_steps, type=type) + + # gauss skin wt: steps(0->2000, 1->0), to align skeleton with shape + loss_name = "reg_gauss_skin_wt" + anchor_x = (1000, 2000) + anchor_y = (0.05, 1) type = "linear" self.set_loss_weight(loss_name, anchor_x, anchor_y, current_steps, type=type) + # # learn feature field before reproj error + # loss_name = "feat_reproj_wt" + # anchor_x = (200, 400) + # anchor_y = (0, 1) + # type = "linear" + # self.set_loss_weight(loss_name, anchor_x, anchor_y, current_steps, type=type) + def set_loss_weight( self, loss_name, anchor_x, anchor_y, current_steps, type="linear" ): @@ -190,10 +270,17 @@ def evaluate(self, batch, is_pair=True): batch_sub[k][k2] = v2[i * div_factor : (i + 1) * div_factor] else: batch_sub[k] = v[i * div_factor : (i + 1) * div_factor] - rendered_sub = self.render(batch_sub)["rendered"] + results_sub = self.render(batch_sub) + rendered_sub, aux = results_sub["rendered"], results_sub["aux_dict"] for k, v in rendered_sub.items(): res = int(np.sqrt(v.shape[1])) rendered[k].append(v.view(div_factor, res, res, -1)[0]) + for k, v in aux["fg"].items(): + res = int(np.sqrt(v.shape[1])) + rendered["%s_id-fg" % k].append(v.view(div_factor, res, res, -1)[0]) + for k, v in aux["bg"].items(): + res = int(np.sqrt(v.shape[1])) + rendered["%s_id-bg" % k].append(v.view(div_factor, res, res, -1)[0]) for k, v in rendered.items(): rendered[k] = torch.stack(v, 0) @@ -202,14 +289,35 @@ def evaluate(self, batch, is_pair=True): for k, v in rendered.items(): if "mask" in k: continue + elif "xyz_matches" in k or "xyz_reproj" in k: + rendered[k] = rendered[k] * (rendered["mask_id-fg"] > 0.5).float() + elif "xy_reproj" in k: + mask = batch["feature"][::div_factor].norm(2, -1, keepdim=True) > 0 + res = rendered[k].shape[1] + rendered[k] = rendered[k] * mask.float().view(-1, res, res, 1) else: - rendered[k] = rendered[k] * rendered["mask"] + if "id-fg" in k: + mask = rendered["mask_id-fg"] + elif "id-bg" in k: + mask = rendered["mask_id-bg"] + else: + mask = rendered["mask"] + rendered[k] = rendered[k] * mask return rendered def update_geometry_aux(self): """Extract proxy geometry for all neural fields""" self.fields.update_geometry_aux() + def update_camera_aux(self): + if isinstance(self.intrinsics, IntrinsicsMLP_delta): + self.intrinsics.update_base_focal() + + # update camera mlp base quat + for field in self.fields.field_params.values(): + if isinstance(field.camera_mlp, CameraMLP_so3): + field.camera_mlp.update_base_quat() + def export_geometry_aux(self, path): """Export proxy geometry for all neural fields""" return self.fields.export_geometry_aux(path) @@ -230,7 +338,10 @@ def render(self, batch, flow_thresh=None): results["aux_dict"]["fg"]: "xy_reproj" (M,N,2) and "feature" (M,N,16) """ samples_dict 
= self.get_samples(batch) - results = self.render_samples_chunk(samples_dict, flow_thresh=flow_thresh) + if self.training: + results = self.render_samples(samples_dict, flow_thresh=flow_thresh) + else: + results = self.render_samples_chunk(samples_dict, flow_thresh=flow_thresh) return results def get_samples(self, batch): @@ -256,7 +367,7 @@ def get_samples(self, batch): samples_dict = self.fields.get_samples(Kinv, batch) return samples_dict - def render_samples_chunk(self, samples_dict, flow_thresh=None, chunk_size=8192): + def render_samples_chunk(self, samples_dict, flow_thresh=None, chunk_size=2048): """Render outputs from all neural fields. Divide in chunks along pixel dimension N to avoid running out of memory. @@ -273,14 +384,12 @@ def render_samples_chunk(self, samples_dict, flow_thresh=None, chunk_size=8192): """ # get chunk size category = list(samples_dict.keys())[0] - total_pixels = ( - samples_dict[category]["hxy"].shape[0] - * samples_dict[category]["hxy"].shape[1] - ) + num_imgs, num_pixels, _ = samples_dict[category]["hxy"].shape + total_pixels = num_imgs * num_pixels num_chunks = int(np.ceil(total_pixels / chunk_size)) - chunk_size_n = int( - np.ceil(chunk_size // samples_dict[category]["hxy"].shape[0]) - ) # at n dimension + + # break into chunks at pixel dimension + chunk_size_px = int(np.ceil(chunk_size // num_imgs)) results = { "rendered": defaultdict(list), @@ -292,14 +401,23 @@ def render_samples_chunk(self, samples_dict, flow_thresh=None, chunk_size=8192): for category, category_v in samples_dict.items(): samples_dict_chunk[category] = defaultdict(list) for k, v in category_v.items(): - if k == "hxy": - samples_dict_chunk[category][k] = v[ - :, i * chunk_size_n : (i + 1) * chunk_size_n - ] + # only break for pixel-ish elements + if ( + isinstance(v, torch.Tensor) + and v.ndim == 3 + and v.shape[1] == num_pixels + ): + chunk_px = v[:, i * chunk_size_px : (i + 1) * chunk_size_px] + samples_dict_chunk[category][k] = chunk_px.clone() else: samples_dict_chunk[category][k] = v # get chunk output + if not self.training: + # clear cache for evaluation + torch.cuda.empty_cache() + # print("allocated: %.2f M" % (torch.cuda.memory_allocated() / (1024**2))) + # print("cached: %.2f M" % (torch.cuda.memory_cached() / (1024**2))) results_chunk = self.render_samples( samples_dict_chunk, flow_thresh=flow_thresh ) @@ -392,10 +510,12 @@ def compute_loss(self, batch, results): """ config = self.config loss_dict = {} - self.compute_recon_loss(loss_dict, results, batch, config) + self.compute_recon_loss(loss_dict, results, batch, config, self.current_steps) self.mask_losses(loss_dict, batch, config) self.compute_reg_loss(loss_dict, results) - self.apply_loss_weights(loss_dict, config) + motion_scale = torch.tensor(self.data_info["motion_scales"], device=self.device) + motion_scale = motion_scale[batch["dataid"]] + self.apply_loss_weights(loss_dict, config, motion_scale) return loss_dict @staticmethod @@ -422,7 +542,7 @@ def get_mask_balance_wt(mask, vis2d, is_detected): return mask_balance_wt @staticmethod - def compute_recon_loss(loss_dict, results, batch, config): + def compute_recon_loss(loss_dict, results, batch, config, current_steps): """Compute reconstruction losses. 
Args: @@ -446,6 +566,7 @@ def compute_recon_loss(loss_dict, results, batch, config): rendered_fg_mask = rendered["mask"] elif config["field_type"] == "comp": rendered_fg_mask = rendered["mask_fg"] + # rendered_fg_mask = aux_dict["fg"]["mask"] elif config["field_type"] == "bg": rendered_fg_mask = None else: @@ -460,9 +581,10 @@ def compute_recon_loss(loss_dict, results, batch, config): loss_dict["mask"] = (rendered_fg_mask - batch["mask"].float()).pow(2) loss_dict["mask"] *= mask_balance_wt elif config["field_type"] == "comp": - loss_dict["mask"] = (rendered_fg_mask - batch["mask"].float()).pow(2) - loss_dict["mask"] *= mask_balance_wt - loss_dict["mask"] += (rendered["mask"] - 1).pow(2) + # loss_dict["mask"] = (rendered_fg_mask - batch["mask"].float()).pow(2) + # loss_dict["mask"] *= mask_balance_wt + # loss_dict["mask"] += (rendered["mask"] - 1).pow(2) + loss_dict["mask"] = (rendered["mask"] - 1).pow(2) else: raise ("field_type %s not supported" % config["field_type"]) @@ -470,13 +592,15 @@ def compute_recon_loss(loss_dict, results, batch, config): loss_dict["feature"] = (aux_dict["fg"]["feature"] - batch["feature"]).norm( 2, -1, keepdim=True ) - loss_dict["feat_reproj"] = ( - aux_dict["fg"]["xy_reproj"] - batch["hxy"][..., :2] - ).norm(2, -1, keepdim=True) + loss_dict["feat_reproj"] = aux_dict["fg"]["xy_reproj"] loss_dict["rgb"] = (rendered["rgb"] - batch["rgb"]).pow(2) - loss_dict["depth"] = ( - (rendered["depth"] - batch["depth"]).norm(2, -1, keepdim=True).clone() + loss_dict["depth"] = (rendered["depth"] - batch["depth"]).abs() + loss_dict["normal"] = (rendered["normal"] - batch["normal"]).pow(2) + # remove pixels not sampled to render normals + loss_dict["normal"] = ( + loss_dict["normal"] + * (rendered["normal"].norm(2, -1, keepdim=True) > 0).float() ) loss_dict["flow"] = (rendered["flow"] - batch["flow"]).norm(2, -1, keepdim=True) @@ -496,9 +620,20 @@ def compute_recon_loss(loss_dict, results, batch, config): # consistency between rendered mask and gauss mask if "gauss_mask" in rendered.keys(): - loss_dict["reg_gauss_mask"] = ( - aux_dict["fg"]["gauss_mask"] - rendered_fg_mask.detach() - ).pow(2) + if current_steps < 4000: + # supervise with a fixed target + loss_dict["reg_gauss_mask"] = ( + aux_dict["fg"]["gauss_mask"] - batch["mask"].float() + ).pow(2) + else: + loss_dict["reg_gauss_mask"] = ( + aux_dict["fg"]["gauss_mask"] - (rendered_fg_mask > 0.5).float() + ).pow(2) + + # # downweight pixels with low opacity (e.g., mask not aligned with gt) + # density_related_loss = ["rgb", "depth", "normal", "feature", "flow"] + # for k in density_related_loss: + # loss_dict[k] = loss_dict[k] * rendered["mask"].detach() def compute_reg_loss(self, loss_dict, results): """Compute regularization losses. 
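# --- Editor's note (hedged sketch, not part of the patch above) ---
# The reg_gauss_mask term added in compute_recon_loss switches its supervision target
# during training: before step 4000 the Gaussian-density mask is regressed toward the
# ground-truth silhouette, afterwards toward the binarized rendered foreground mask.
# The helper below is hypothetical; the argument names mirror the variables used in the
# hunk above, and the tensors are assumed to be torch tensors of matching shape.
def gauss_mask_target(current_steps, gt_mask, rendered_fg_mask, switch_step=4000):
    """Pick the target that aux_dict["fg"]["gauss_mask"] is regressed to."""
    if current_steps < switch_step:
        return gt_mask.float()                  # fixed ground-truth silhouette early on
    return (rendered_fg_mask > 0.5).float()     # binarized rendered mask later in training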
@@ -521,7 +656,8 @@ def compute_reg_loss(self, loss_dict, results): loss_dict["reg_delta_skin"] = aux_dict["fg"]["delta_skin"] loss_dict["reg_skin_entropy"] = aux_dict["fg"]["skin_entropy"] loss_dict["reg_soft_deform"] = self.fields.soft_deform_loss() - loss_dict["reg_gauss_skin"] = self.fields.gauss_skin_consistency_loss() + if self.config["reg_gauss_skin_wt"] > 0: + loss_dict["reg_gauss_skin"] = self.fields.gauss_skin_consistency_loss() loss_dict["reg_cam_prior"] = self.fields.cam_prior_loss() loss_dict["reg_skel_prior"] = self.fields.skel_prior_loss() @@ -548,7 +684,7 @@ def mask_losses(loss_dict, batch, config): # always mask-out non-object pixels keys_fg = ["feature", "feat_reproj"] # field type specific keys - keys_type_specific = ["rgb", "depth", "flow", "vis"] + keys_type_specific = ["rgb", "depth", "flow", "vis", "normal"] # type-specific masking rules vis2d = batch["vis2d"].float() @@ -576,13 +712,18 @@ def mask_losses(loss_dict, batch, config): raise ("loss %s not defined" % k) # mask out the following losses if obj is not detected - keys_mask_not_detected = ["mask", "feature", "feat_reproj"] + keys_mask_not_detected = ["feature", "feat_reproj"] + is_detected = batch["is_detected"].float()[:, None, None] for k, v in loss_dict.items(): if k in keys_mask_not_detected: - loss_dict[k] = v * batch["is_detected"].float()[:, None, None] + loss_dict[k] = v * is_detected + + # remove mask loss for frames without detection + if config["field_type"] == "fg" or config["field_type"] == "comp": + loss_dict["mask"] = loss_dict["mask"] * is_detected @staticmethod - def apply_loss_weights(loss_dict, config): + def apply_loss_weights(loss_dict, config, motion_scale): """Weigh each loss term according to command-line configs Args: @@ -595,11 +736,22 @@ def apply_loss_weights(loss_dict, config): "reg_cam_prior" (0,), and "reg_skel_prior" (0,). Modified in place to multiply each term with a scalar weight. 
config (Dict): Command-line options + motion_scale (Tensor): Motion magnitude for each data sample (M,) """ - px_unit_keys = ["flow", "feat_reproj"] + # px_unit_keys = ["feat_reproj"] + # motion_unit_keys = ["flow"] + px_unit_keys = ["feat_reproj", "flow"] for k, v in loss_dict.items(): + # # scale with motion magnitude + # if k in motion_unit_keys: + # loss_dict[k] /= motion_scale.clamp(1, 20).view(-1, 1, 1) + # average over non-zero pixels - loss_dict[k] = v[v > 0].mean() + v = v[v > 0] + if v.numel() > 0: + loss_dict[k] = v.mean() + else: + loss_dict[k] = v.sum() # return zero # scale with image resolution if k in px_unit_keys: @@ -609,3 +761,14 @@ def apply_loss_weights(loss_dict, config): wt_name = k + "_wt" if wt_name in config.keys(): loss_dict[k] *= config[wt_name] + + def get_field_betas(self): + """Get beta values for all neural fields + + Returns: + betas (Dict): Beta values for each neural field + """ + beta_dicts = {} + for field in self.fields.field_params.values(): + beta_dicts["beta/%s" % field.category] = field.logibeta.exp() + return beta_dicts diff --git a/lab4d/engine/train_utils.py b/lab4d/engine/train_utils.py index 14fc81b..a9bd44c 100644 --- a/lab4d/engine/train_utils.py +++ b/lab4d/engine/train_utils.py @@ -12,6 +12,41 @@ def get_local_rank(): return 0 +def match_param_name(name, param_lr, type): + """ + Match the param name with the param_lr dict + + Args: + name (str): the name of the param + param_lr (Dict): the param_lr dict + type (str): "with" or "startwith" + + Returns: + bool, lr + """ + matched_param = [] + matched_lr = [] + + for params_name, lr in param_lr.items(): + if type == "with": + if params_name in name: + matched_param.append(params_name) + matched_lr.append(lr) + elif type == "startwith": + if name.startswith(params_name): + matched_param.append(params_name) + matched_lr.append(lr) + else: + raise ValueError("type not found") + + if len(matched_param) == 0: + return False, 0.0 + elif len(matched_param) == 1: + return True, matched_lr[0] + else: + raise ValueError("multiple matches found", matched_param) + + class DataParallelPassthrough(torch.nn.parallel.DistributedDataParallel): """For multi-GPU access, forward attributes to the inner module.""" diff --git a/lab4d/engine/trainer.py b/lab4d/engine/trainer.py index f1cc640..89b7b7c 100644 --- a/lab4d/engine/trainer.py +++ b/lab4d/engine/trainer.py @@ -3,6 +3,7 @@ import time from collections import defaultdict from copy import deepcopy +import gc import numpy as np import torch @@ -15,7 +16,11 @@ from lab4d.dataloader import data_utils from lab4d.dataloader.vidloader import VidDataset from lab4d.engine.model import dvr_model -from lab4d.engine.train_utils import DataParallelPassthrough, get_local_rank +from lab4d.engine.train_utils import ( + DataParallelPassthrough, + get_local_rank, + match_param_name, +) from lab4d.utils.profile_utils import torch_profile from lab4d.utils.torch_utils import remove_ddp_prefix from lab4d.utils.vis_utils import img2color, make_image_grid @@ -29,7 +34,6 @@ def __init__(self, opts): opts (Dict): Command-line args from absl (defined in lab4d/config.py) """ # When profiling, use fewer iterations per round so trace files are smaller - is_resumed = opts["load_path"] != "" if opts["profile"]: opts["iters_per_round"] = 10 @@ -38,11 +42,19 @@ def __init__(self, opts): self.define_dataset() self.trainer_init() self.define_model() - self.optimizer_init(is_resumed=is_resumed) + + # move model to ddp + self.model = DataParallelPassthrough( + self.model, + 
device_ids=[get_local_rank()], + output_device=get_local_rank(), + find_unused_parameters=False, + ) + + self.optimizer_init(is_resumed=opts["load_path"] != "") # load model - if is_resumed: - self.load_checkpoint_train() + self.load_checkpoint_train() def trainer_init(self): """Initialize logger and other misc things""" @@ -63,9 +75,12 @@ def trainer_init(self): self.current_steps = 0 # 0-total_steps self.current_round = 0 # 0-num_rounds + self.first_round = 0 # 0 + self.first_step = 0 # 0 # 0-last image in eval dataset self.eval_fid = np.linspace(0, len(self.evalloader) - 1, 9).astype(int) + # self.eval_fid = np.linspace(1200, 1200, 9).astype(int) # torch.manual_seed(8) # do it again # torch.cuda.manual_seed(1) @@ -88,18 +103,17 @@ def define_dataset(self): def init_model(self): """Initialize camera transforms, geometry, articulations, and camera intrinsics from external priors, if this is the first run""" - opts = self.opts # init mlp if get_local_rank() == 0: self.model.mlp_init() - def define_model(self): + def define_model(self, model=dvr_model): """Define a Lab4D model and wrap it with DistributedDataParallel""" opts = self.opts data_info = self.data_info self.device = torch.device("cuda:{}".format(get_local_rank())) - self.model = dvr_model(opts, data_info) + self.model = model(opts, data_info) # ddp self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model) @@ -107,19 +121,23 @@ def define_model(self): self.init_model() - self.model = DataParallelPassthrough( - self.model, - device_ids=[get_local_rank()], - output_device=get_local_rank(), - find_unused_parameters=False, - ) - # cache queue of length 2 self.model_cache = [None, None] self.optimizer_cache = [None, None] self.scheduler_cache = [None, None] - def get_lr_dict(self): + self.grad_queue = {} + self.param_clip_startwith = { + "module.fields.field_params.fg.camera_mlp": 10.0, + "module.fields.field_params.fg.warp.articulation": 10.0, + "module.fields.field_params.fg.basefield": 10.0, + "module.fields.field_params.fg.sdf": 10.0, + "module.fields.field_params.bg.camera_mlp": 10.0, + "module.fields.field_params.bg.basefield": 10.0, + "module.fields.field_params.bg.sdf": 10.0, + } + + def get_lr_dict(self, pose_correction=False): """Return the learning rate for each category of trainable parameters Returns: @@ -130,21 +148,33 @@ def get_lr_dict(self): opts = self.opts lr_base = opts["learning_rate"] lr_explicit = lr_base * 10 + lr_intrinsics = 0.0 if opts["freeze_intrinsics"] else lr_base param_lr_startwith = { "module.fields.field_params": lr_base, - "module.intrinsics": lr_base, + "module.intrinsics": lr_intrinsics, } param_lr_with = { ".logibeta": lr_explicit, ".logsigma": lr_explicit, ".logscale": lr_explicit, - ".log_gauss": lr_explicit, - ".base_quat": lr_explicit, - ".base_logfocal": lr_explicit, - ".base_ppoint": lr_explicit, + ".log_gauss": 0.0, + ".base_quat": 0.0, ".shift": lr_explicit, + ".orient": lr_explicit, } + + if pose_correction: + del param_lr_with[".logscale"] + del param_lr_with[".log_gauss"] + param_lr_with_pose_correction = { + "module.fields.field_params.fg.basefield.": 0.0, + "module.fields.field_params.fg.sdf.": 0.0, + "module.fields.field_params.fg.feature_field": 0.0, + "module.fields.field_params.fg.warp.skinning_model": 0.0, + } + param_lr_with.update(param_lr_with_pose_correction) + return param_lr_startwith, param_lr_with def optimizer_init(self, is_resumed=False): @@ -155,60 +185,95 @@ def optimizer_init(self, is_resumed=False): is_resumed (bool): True if resuming from checkpoint 
""" opts = self.opts - - param_lr_startwith, param_lr_with = self.get_lr_dict() - - if opts["freeze_bone_len"]: - param_lr_with[".log_bone_len"] = 0 - - params_list = [] - lr_list = [] - for name, p in self.model.named_parameters(): - name_found = False - for params_name, lr in param_lr_with.items(): - if params_name in name: - params_list.append({"params": p}) - lr_list.append(lr) - name_found = True - if get_local_rank() == 0: - print(name, p.shape, lr) - - if name_found: - continue - for params_name, lr in param_lr_startwith.items(): - if name.startswith(params_name): - params_list.append({"params": p}) - lr_list.append(lr) - if get_local_rank() == 0: - print(name, p.shape, lr) - + self.params_ref_list, params_list, lr_list = self.get_optimizable_param_list() + + # # one cycle lr + # self.optimizer = torch.optim.AdamW( + # params_list, + # lr=opts["learning_rate"], + # betas=(0.9, 0.999), + # weight_decay=1e-4, + # ) + # # initial_lr = lr/div_factor + # # min_lr = initial_lr/final_div_factor + # # if is_resumed: + # if False: + # div_factor = 1.0 + # final_div_factor = 25.0 + # pct_start = 0.0 # cannot be 0 + # else: + # div_factor = 25.0 + # final_div_factor = 1.0 + # pct_start = min( + # 1 - 1e-5, 2.0 / opts["num_rounds"] + # ) # use 2 epochs to warm up + # self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + # self.optimizer, + # lr_list, + # int(self.total_steps), + # pct_start=pct_start, + # cycle_momentum=False, + # anneal_strategy="linear", + # div_factor=div_factor, + # final_div_factor=final_div_factor, + # ) + + # cyclic lr + assert self.total_steps // 2000 * 2000 == self.total_steps # dividible by 2k self.optimizer = torch.optim.AdamW( params_list, lr=opts["learning_rate"], - betas=(0.9, 0.999), + betas=(0.9, 0.99), weight_decay=1e-4, ) - # initial_lr = lr/div_factor - # min_lr = initial_lr/final_div_factor - if is_resumed: - div_factor = 1.0 - final_div_factor = 5.0 - pct_start = 0.0 # cannot be 0 - else: - div_factor = 25.0 - final_div_factor = 1.0 - pct_start = 2.0 / opts["num_rounds"] # use 2 epochs to warm up - self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.scheduler = torch.optim.lr_scheduler.CyclicLR( self.optimizer, + [i * 0.01 for i in lr_list], lr_list, - int(self.total_steps), - pct_start=pct_start, + step_size_up=10, + step_size_down=1990, + mode="triangular", + gamma=1.0, + scale_mode="cycle", cycle_momentum=False, - anneal_strategy="linear", - div_factor=div_factor, - final_div_factor=final_div_factor, ) + def get_optimizable_param_list(self): + """ + Get the optimizable param list + Returns: + params_ref_list (List): List of params + params_list (List): List of params + lr_list (List): List of learning rates + """ + param_lr_startwith, param_lr_with = self.get_lr_dict( + pose_correction=self.opts["pose_correction"] + ) + params_ref_list = [] + params_list = [] + lr_list = [] + + for name, p in self.model.named_parameters(): + matched_loose, lr_loose = match_param_name(name, param_lr_with, type="with") + matched_strict, lr_strict = match_param_name( + name, param_lr_startwith, type="startwith" + ) + if matched_loose > 0: + lr = lr_loose # higher priority + elif matched_strict > 0: + lr = lr_strict + else: + lr = 0.0 # not found + # print(name, "not found") + if lr > 0: + params_ref_list.append({name: p}) + params_list.append({"params": p}) + lr_list.append(lr) + if get_local_rank() == 0: + print(name, p.shape, lr) + + return params_ref_list, params_list, lr_list + def train(self): """Training loop""" opts = self.opts @@ -221,36 +286,38 @@ def 
train(self): # start training loop self.save_checkpoint(round_count=self.current_round) - for round_count in range( - self.current_round, self.current_round + opts["num_rounds"] - ): + for _ in range(self.current_round, self.current_round + opts["num_rounds"]): start_time = time.time() with torch_profile( - self.save_dir, f"{round_count:03d}", enabled=opts["profile"] + self.save_dir, f"{self.current_round:03d}", enabled=opts["profile"] ): - self.run_one_round(round_count) + self.run_one_round() if get_local_rank() == 0: - print(f"Round {round_count:03d}: time={time.time() - start_time:.3f}s") - - def run_one_round(self, round_count): - """Evaluation and training for a single round + print( + f"Round {self.current_round:03d}: time={time.time() - start_time:.3f}s" + ) + self.save_checkpoint(round_count=self.current_round) - Args: - round_count (int): Current round index - """ - self.model.eval() + def run_one_round(self): + """Evaluation and training for a single round""" if get_local_rank() == 0: - with torch.no_grad(): + if self.current_round == self.first_round: self.model_eval() self.model.update_geometry_aux() - self.model.export_geometry_aux("%s/%03d" % (self.save_dir, round_count)) + self.model.export_geometry_aux("%s/%03d" % (self.save_dir, self.current_round)) + if ( + self.current_round > self.opts["num_rounds_cam_init"] + and self.opts["absorb_base"] + ): + self.model.update_camera_aux() self.model.train() - self.train_one_round(round_count) + self.train_one_round() self.current_round += 1 - self.save_checkpoint(round_count=self.current_round) + if get_local_rank() == 0: + self.model_eval() def save_checkpoint(self, round_count): """Save model checkpoint to disk @@ -300,6 +367,9 @@ def load_checkpoint(load_path, model, optimizer=None): model_states = remove_ddp_prefix(model_states) model.load_state_dict(model_states, strict=False) + # reset near_far + model.fields.reset_geometry_aux() + # if optimizer is not None: # # use the new param_groups that contains the learning rate # checkpoint["optimizer"]["param_groups"] = optimizer.state_dict()[ @@ -312,33 +382,33 @@ def load_checkpoint_train(self): """Load a checkpoint at training time and update the current step count and round count """ - # training time - checkpoint = self.load_checkpoint( - self.opts["load_path"], self.model, optimizer=self.optimizer - ) - if not self.opts["reset_steps"]: - self.current_steps = checkpoint["current_steps"] - self.current_round = checkpoint["current_round"] - - # reset near_far - self.model.fields.reset_geometry_aux() - - def train_one_round(self, round_count): - """Train a single round (going over mini-batches) - - Args: - round_count (int): round index - """ + if self.opts["load_path"] != "": + # training time + checkpoint = self.load_checkpoint( + self.opts["load_path"], self.model, optimizer=self.optimizer + ) + if not self.opts["reset_steps"]: + self.current_steps = checkpoint["current_steps"] + self.current_round = checkpoint["current_round"] + self.first_round = self.current_round + self.first_step = self.current_steps + + def train_one_round(self): + """Train a single round (going over mini-batches)""" opts = self.opts + gc.collect() # need to be used together with empty_cache() torch.cuda.empty_cache() self.model.train() + self.optimizer.zero_grad() - self.trainloader.sampler.set_epoch(round_count) # necessary for shuffling + # necessary for shuffling + self.trainloader.sampler.set_epoch(self.current_round) for i, batch in enumerate(self.trainloader): if i == opts["iters_per_round"]: break 
- self.model.set_progress(self.current_steps) + progress = (self.current_steps - self.first_step) / self.total_steps + self.model.set_progress(self.current_steps, progress) loss_dict = self.model(batch) total_loss = torch.sum(torch.stack(list(loss_dict.values()))) @@ -346,12 +416,16 @@ def train_one_round(self, round_count): # print(total_loss) # self.print_sum_params() - self.check_grad() + grad_dict = self.check_grad() self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() if get_local_rank() == 0: + # update scalar dict + loss_dict["loss/total"] = total_loss + loss_dict.update(self.model.get_field_betas()) + loss_dict.update(grad_dict) self.add_scalar(self.log, loss_dict, self.current_steps) self.current_steps += 1 @@ -370,6 +444,8 @@ def construct_dataset_opts(opts, is_eval=False, dataset_constructor=VidDataset): opts_dict["load_pair"] = True opts_dict["data_prefix"] = "%s-%d" % (opts["data_prefix"], opts["train_res"]) opts_dict["feature_type"] = opts["feature_type"] + opts_dict["field_type"] = opts["field_type"] + opts_dict["eval_res"] = opts["eval_res"] opts_dict["dataset_constructor"] = dataset_constructor if is_eval: @@ -397,8 +473,11 @@ def print_sum_params(self): sum += p.abs().sum() print(f"{sum:.16f}") + @torch.no_grad() def model_eval(self): """Evaluate the current model""" + self.model.eval() + gc.collect() # need to be used together with empty_cache() torch.cuda.empty_cache() ref_dict, batch = self.load_batch(self.evalloader.dataset, self.eval_fid) self.construct_eval_batch(batch) @@ -447,7 +526,7 @@ def load_batch(dataset, fids): """ ref_dict = defaultdict(list) batch_aggr = defaultdict(list) - ref_keys = ["rgb", "mask", "depth", "feature", "vis2d"] + ref_keys = ["rgb", "mask", "depth", "feature", "vis2d", "normal"] batch_keys = ["dataid", "frameid_sub", "crop2raw"] for fid in fids: batch = dataset[fid] @@ -578,27 +657,68 @@ def construct_test_model(opts): return model, data_info, ref_dict - def check_grad(self, thresh=5.0): + def check_grad(self, thresh=10.0): """Check if gradients are above a threshold Args: thresh (float): Gradient clipping threshold """ - # parameters that are sensitive to large gradients - - param_list = [] - for name, p in self.model.named_parameters(): - if p.requires_grad: - param_list.append(p) - - grad_norm = torch.nn.utils.clip_grad_norm_(param_list, thresh) - if grad_norm > thresh: + # detect large gradients and reload model + params_list = [] + for param_dict in self.params_ref_list: + ((name, p),) = param_dict.items() + if p.requires_grad and p.grad is not None: + params_list.append(p) + # if p.grad.isnan().any(): + # p.grad.zero_() + + # check individual parameters + grad_norm = torch.nn.utils.clip_grad_norm_(params_list, thresh) + if grad_norm > thresh or torch.isnan(grad_norm): # clear gradients self.optimizer.zero_grad() + if get_local_rank() == 0: + print("large grad: %.2f, clear gradients" % grad_norm) # load cached model from two rounds ago if self.model_cache[0] is not None: if get_local_rank() == 0: - print("large grad: %.2f, resume from cached weights" % grad_norm) + print("fallback to cached model") self.model.load_state_dict(self.model_cache[0]) self.optimizer.load_state_dict(self.optimizer_cache[0]) self.scheduler.load_state_dict(self.scheduler_cache[0]) + return {} + + # clip individual parameters + grad_dict = {} + queue_length = 10 + for param_dict in self.params_ref_list: + ((name, p),) = param_dict.items() + if p.requires_grad and p.grad is not None: + grad = p.grad.reshape(-1).norm(2, -1) + grad_dict["grad/" 
+ name] = grad + # maintain a queue of grad norm, and clip outlier grads + matched_strict, clip_strict = match_param_name( + name, self.param_clip_startwith, type="startwith" + ) + if matched_strict: + scale_threshold = clip_strict + else: + continue + + # check the gradient norm + if name not in self.grad_queue: + self.grad_queue[name] = [] + if len(self.grad_queue[name]) > queue_length: + med_grad = torch.stack(self.grad_queue[name][:-1]).median() + grad_dict["grad_med/" + name] = med_grad + if grad > scale_threshold * med_grad: + torch.nn.utils.clip_grad_norm_(p, med_grad) + # if get_local_rank() == 0: + # print("large grad: %.2f, clear %s" % (grad, name)) + else: + self.grad_queue[name].append(grad) + self.grad_queue[name].pop(0) + else: + self.grad_queue[name].append(grad) + + return grad_dict diff --git a/lab4d/export.py b/lab4d/export.py index 10781f8..7467bb2 100644 --- a/lab4d/export.py +++ b/lab4d/export.py @@ -27,6 +27,7 @@ dual_quaternion_to_se3, quaternion_translation_to_se3, ) +from lab4d.utils.vis_utils import append_xz_plane cudnn.benchmark = True @@ -37,6 +38,10 @@ class ExportMeshFlags: flags.DEFINE_float( "level", 0.0, "contour value of marching cubes use to search for isosurfaces" ) + flags.DEFINE_float( + "vis_thresh", 0.0, "visibility threshold to remove invisible pts, -inf to inf" + ) + flags.DEFINE_boolean("extend_aabb", False, "use extended aabb for meshing (for bg)") class MotionParamsExpl(NamedTuple): @@ -53,26 +58,37 @@ class MotionParamsExpl(NamedTuple): bone_t: trimesh.Trimesh # bone center at time t -def extract_deformation(field, mesh_rest, inst_id, frame_ids): +def extract_deformation(field, mesh_rest, inst_id): + # get corresponding frame ids + frame_mapping = field.camera_mlp.time_embedding.frame_mapping + frame_offset = field.frame_offset + frame_ids = frame_mapping[frame_offset[inst_id] : frame_offset[inst_id + 1]] + start_id = frame_ids[0] + print("Extracting motion parameters for inst id:", inst_id) + print("Frame ids with the video:", frame_ids - start_id) + device = next(field.parameters()).device xyz = torch.tensor(mesh_rest.vertices, dtype=torch.float32, device=device) inst_id = torch.tensor([inst_id], dtype=torch.long, device=device) motion_tuples = {} for frame_id in frame_ids: - frame_id_torch = torch.tensor([frame_id], dtype=torch.long, device=device) - field2cam = field.camera_mlp.get_vals(frame_id_torch) + frame_id = frame_id[None] + field2cam = field.camera_mlp.get_vals(frame_id) samples_dict = {} - if isinstance(field.warp, SkinningWarp): + se3_mat = quaternion_translation_to_se3(field2cam[0], field2cam[1])[0] + se3_mat = se3_mat.cpu().numpy() + if hasattr(field, "warp") and isinstance(field.warp, SkinningWarp): ( samples_dict["t_articulation"], samples_dict["rest_articulation"], - ) = field.warp.articulation.get_vals_and_mean(frame_id_torch) + ) = field.warp.articulation.get_vals_and_mean(frame_id) t_articulation = samples_dict["t_articulation"] if isinstance(field.warp.articulation, ArticulationSkelMLP): - so3 = field.warp.articulation.get_vals(frame_id_torch, return_so3=True) + so3 = field.warp.articulation.get_vals(frame_id, return_so3=True)[0] + so3 = so3.cpu().numpy() else: so3 = None @@ -84,34 +100,36 @@ def extract_deformation(field, mesh_rest, inst_id, frame_ids): ), field.warp.articulation.edges, ) - se3_mat = quaternion_translation_to_se3(field2cam[0], field2cam[1])[0] - mesh_bones_t.apply_transform(se3_mat.cpu().numpy()) + # 1,K,4,4 + t_articulation = dual_quaternion_to_se3(t_articulation)[0] + t_articulation = 
t_articulation.cpu().numpy() else: t_articulation = None so3 = None mesh_bones_t = None - xyz_t = field.forward_warp( - xyz[None, None], - field2cam, - frame_id_torch, - inst_id, - samples_dict=samples_dict, - ) - xyz_t = xyz_t[0, 0] - mesh_t = trimesh.Trimesh(vertices=xyz_t.cpu().numpy(), faces=mesh_rest.faces) + if hasattr(field, "warp"): + # warp mesh + xyz_t = field.warp( + xyz[None, None], frame_id, inst_id, samples_dict=samples_dict + )[0, 0] + mesh_t = trimesh.Trimesh( + vertices=xyz_t.cpu().numpy(), faces=mesh_rest.faces, process=False + ) + else: + mesh_t = mesh_rest.copy() - field2cam[1][:] /= field.logscale.exp() # to world scale motion_expl = MotionParamsExpl( - field2cam=field2cam, + field2cam=se3_mat, t_articulation=t_articulation, so3=so3, mesh_t=mesh_t, bone_t=mesh_bones_t, ) - motion_tuples[frame_id] = motion_expl + frame_id_sub = (frame_id[0] - start_id).cpu() + motion_tuples[frame_id_sub] = motion_expl - if isinstance(field.warp, SkinningWarp): + if hasattr(field, "warp") and isinstance(field.warp, SkinningWarp): # modify rest mesh based on instance morphological changes on bones # idendity transformation of cameras field2cam_rot_idn = torch.zeros_like(field2cam[0]) @@ -134,37 +152,48 @@ def extract_deformation(field, mesh_rest, inst_id, frame_ids): return mesh_rest, motion_tuples +def rescale_motion_tuples(motion_tuples, field_scale): + """ + rescale motion tuples to world scale + """ + for frame_id, motion_tuple in motion_tuples.items(): + motion_tuple.field2cam[:3, 3] /= field_scale + motion_tuple.mesh_t.apply_scale(1.0 / field_scale) + if motion_tuple.bone_t is not None: + motion_tuple.bone_t.apply_scale(1.0 / field_scale) + if motion_tuple.t_articulation is not None: + motion_tuple.t_articulation[1][:] /= field_scale + return + + def save_motion_params(meshes_rest, motion_tuples, save_dir): for cate, mesh_rest in meshes_rest.items(): - mesh_rest.export("%s/%s.obj" % (save_dir, cate)) + mesh_rest.export("%s/%s-mesh.obj" % (save_dir, cate)) motion_params = {"field2cam": [], "t_articulation": [], "joint_so3": []} + os.makedirs("%s/fg/mesh/" % save_dir, exist_ok=True) + os.makedirs("%s/bg/mesh/" % save_dir, exist_ok=True) + os.makedirs("%s/fg/bone/" % save_dir, exist_ok=True) for frame_id, motion_expl in motion_tuples[cate].items(): # save mesh - motion_expl.mesh_t.export("%s/%s-%05d.obj" % (save_dir, cate, frame_id)) + motion_expl.mesh_t.export( + "%s/%s/mesh/%05d.obj" % (save_dir, cate, frame_id) + ) if motion_expl.bone_t is not None: motion_expl.bone_t.export( - "%s/%s-%05d-bone.obj" % (save_dir, cate, frame_id) + "%s/%s/bone/%05d.obj" % (save_dir, cate, frame_id) ) # save motion params - field2cam = quaternion_translation_to_se3( - motion_expl.field2cam[0], motion_expl.field2cam[1] - ) # 1,4,4 - motion_params["field2cam"].append(field2cam.cpu().numpy()[0].tolist()) + motion_params["field2cam"].append(motion_expl.field2cam.tolist()) if motion_expl.t_articulation is not None: - t_articulation = dual_quaternion_to_se3( - motion_expl.t_articulation - ) # 1,K,4,4 motion_params["t_articulation"].append( - t_articulation.cpu().numpy()[0].tolist() + motion_expl.t_articulation.tolist() ) if motion_expl.so3 is not None: - motion_params["joint_so3"].append( - motion_expl.so3.cpu().numpy()[0].tolist() - ) # K,3 + motion_params["joint_so3"].append(motion_expl.so3.tolist()) # K,3 - with open("%s/%s-motion.json" % (save_dir, cate), "w") as fp: + with open("%s/%s/motion.json" % (save_dir, cate), "w") as fp: json.dump(motion_params, fp) @@ -175,26 +204,48 @@ def 
extract_motion_params(model, opts, data_info): grid_size=opts["grid_size"], level=opts["level"], inst_id=opts["inst_id"], - use_visibility=False, - use_extend_aabb=False, + vis_thresh=opts["vis_thresh"], + use_extend_aabb=opts["extend_aabb"], ) - # get absolute frame ids - inst_id = opts["inst_id"] - frame_mapping = data_info["frame_info"]["frame_mapping"] - frame_offset = data_info["frame_info"]["frame_offset"] - frame_ids = frame_mapping[frame_offset[inst_id] : frame_offset[inst_id + 1]] - print("Extracting motion parameters for frame ids:", frame_ids) - # get deformation motion_tuples = {} for cate, field in model.fields.field_params.items(): meshes_rest[cate], motion_tuples[cate] = extract_deformation( - field, meshes_rest[cate], opts["inst_id"], frame_ids=frame_ids + field, meshes_rest[cate], opts["inst_id"] + ) + + # scale + if "bg" in model.fields.field_params.keys(): + bg_field = model.fields.field_params["bg"] + bg_scale = bg_field.logscale.exp().cpu().numpy() + if "fg" in model.fields.field_params.keys(): + fg_field = model.fields.field_params["fg"] + fg_scale = fg_field.logscale.exp().cpu().numpy() + + if ( + "bg" in model.fields.field_params.keys() + and model.fields.field_params["bg"].valid_field2world() + ): + # visualize ground plane + field2world = ( + model.fields.field_params["bg"].get_field2world(opts["inst_id"]).cpu() + ) + field2world[..., :3, 3] *= bg_scale + meshes_rest["bg"] = append_xz_plane( + meshes_rest["bg"], field2world.inverse(), scale=20 * bg_scale ) + + if "fg" in model.fields.field_params.keys(): + meshes_rest["fg"] = meshes_rest["fg"].apply_scale(1.0 / fg_scale) + rescale_motion_tuples(motion_tuples["fg"], fg_scale) + if "bg" in model.fields.field_params.keys(): + meshes_rest["bg"] = meshes_rest["bg"].apply_scale(1.0 / bg_scale) + rescale_motion_tuples(motion_tuples["bg"], bg_scale) return meshes_rest, motion_tuples +@torch.no_grad() def export(opts): model, data_info, ref_dict = Trainer.construct_test_model(opts) save_dir = make_save_dir(opts, sub_dir="export_%04d" % (opts["inst_id"])) @@ -203,11 +254,33 @@ def export(opts): meshes_rest, motion_tuples = extract_motion_params(model, opts, data_info) save_motion_params(meshes_rest, motion_tuples, save_dir) + # save scene to world transform + if ( + "bg" in model.fields.field_params.keys() + and model.fields.field_params["bg"].valid_field2world() + ): + field2world = model.fields.field_params["bg"].get_field2world(opts["inst_id"]) + field2world = field2world.cpu().numpy().tolist() + json.dump(field2world, open("%s/bg/field2world.json" % (save_dir), "w")) + + # same raw image size and intrinsics + with torch.no_grad(): + intrinsics = model.intrinsics.get_intrinsics(opts["inst_id"]) + camera_info = {} + camera_info["raw_size"] = data_info["raw_size"][opts["inst_id"]].tolist() + camera_info["intrinsics"] = intrinsics.cpu().numpy().tolist() + json.dump(camera_info, open("%s/camera.json" % (save_dir), "w")) + # save reference images raw_size = data_info["raw_size"][opts["inst_id"]] # full range of pixels save_rendered(ref_dict, save_dir, raw_size, data_info["apply_pca_fn"]) print("Saved to %s" % save_dir) + # mesh rendering + cmd = "python lab4d/render_mesh.py --testdir %s" % (save_dir) + print("Running: %s" % cmd) + os.system(cmd) + def main(_): opts = get_config() diff --git a/lab4d/mesh_viewer.py b/lab4d/mesh_viewer.py new file mode 100644 index 0000000..fee6420 --- /dev/null +++ b/lab4d/mesh_viewer.py @@ -0,0 +1,183 @@ +"""modified from 
https://github.com/nerfstudio-project/viser/blob/main/examples/07_record3d_visualizer.py +python lab4d/mesh_viewer.py --testdir logdir//ama-bouncing-4v-ppr-exp/export_0000/ +""" + +import os, sys +import pdb +import time +from pathlib import Path +from typing import List +import argparse + +import cv2 +import numpy as np +import tyro +from tqdm.auto import tqdm + +import viser +import viser.extras +import viser.transforms as tf + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) +from lab4d.utils.mesh_loader import MeshLoader + + +parser = argparse.ArgumentParser(description="script to render extraced meshes") +parser.add_argument("--testdir", default="", help="path to the directory with results") +parser.add_argument("--fps", default=30, type=int, help="fps of the video") +parser.add_argument("--mode", default="", type=str, help="{shape, bone}") +parser.add_argument("--compose_mode", default="", type=str, help="{object, scene}") +parser.add_argument("--ghosting", action="store_true", help="ghosting") +parser.add_argument("--view", default="ref", type=str, help="{ref, bev, front}") +args = parser.parse_args() + + +def find_seqname(testdir): + parts = [part for part in testdir.split("/") if part] + logdir = "/".join(parts[:2]) + logdir = os.path.join(logdir, "opts.log") + with open(logdir, "r") as file: + for line in file: + if "--seqname" in line: + seqname = line.split("--")[1].split("=")[1].strip() + break + if "seqname" not in locals(): + raise ValueError("Could not find seqname in opts.log") + inst_id = int(parts[2].split("_")[-1]) + seqname = "%s-%04d" % (seqname, inst_id) + return seqname + + +def main( + share: bool = False, +) -> None: + server = viser.ViserServer(share=share) + + downsample_factor = 4 + print("Loading frames!") + loader = MeshLoader(args.testdir, args.mode, args.compose_mode) + loader.print_info() + loader.load_files(ghosting=args.ghosting) + num_frames = len(loader) + fps = args.fps + + # load images + seqname = find_seqname(args.testdir) + img_dir = "database/processed/JPEGImages/Full-Resolution/%s/" % seqname + rgb_list = [cv2.imread("%s/%05d.jpg" % (img_dir, i)) for i in range(num_frames)] + rgb_list = [rgb[::downsample_factor, ::downsample_factor, ::-1] for rgb in rgb_list] + + # Add playback UI. + with server.add_gui_folder("Playback"): + gui_timestep = server.add_gui_slider( + "Timestep", + min=0, + max=num_frames - 1, + step=1, + initial_value=0, + disabled=True, + ) + gui_next_frame = server.add_gui_button("Next Frame", disabled=True) + gui_prev_frame = server.add_gui_button("Prev Frame", disabled=True) + gui_playing = server.add_gui_checkbox("Playing", True) + gui_framerate = server.add_gui_slider( + "FPS", min=1, max=60, step=0.1, initial_value=fps + ) + gui_framerate_options = server.add_gui_button_group( + "FPS options", ("10", "20", "30", "60") + ) + + # Frame step buttons. + @gui_next_frame.on_click + def _(_) -> None: + gui_timestep.value = (gui_timestep.value + 1) % num_frames + + @gui_prev_frame.on_click + def _(_) -> None: + gui_timestep.value = (gui_timestep.value - 1) % num_frames + + # Disable frame controls when we're playing. + @gui_playing.on_update + def _(_) -> None: + gui_timestep.disabled = gui_playing.value + gui_next_frame.disabled = gui_playing.value + gui_prev_frame.disabled = gui_playing.value + + # Set the framerate when we click one of the options. 
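The new lab4d/mesh_viewer.py drives playback by creating one /frames/t{i} node per timestep and toggling node visibility from GUI callbacks. Below is a minimal, self-contained sketch of that slider-driven pattern, with placeholder sphere meshes standing in for MeshLoader.query_frame(); apart from the viser calls already used in this file, the names are illustrative, not lab4d APIs.

import time
import trimesh
import viser

# Minimal sketch of the slider-driven playback pattern used in mesh_viewer.py.
# The spheres are placeholders; the real script loads per-frame meshes via
# MeshLoader.query_frame(i).
server = viser.ViserServer()
num_frames = 10
meshes = [trimesh.creation.icosphere(radius=0.1 + 0.02 * i) for i in range(num_frames)]

timestep = server.add_gui_slider(
    "Timestep", min=0, max=num_frames - 1, step=1, initial_value=0
)
playing = server.add_gui_checkbox("Playing", True)

# One frame node per timestep; only the active one is visible.
frame_nodes = []
for i, mesh in enumerate(meshes):
    node = server.add_frame(f"/frames/t{i}", show_axes=False)
    server.add_mesh_trimesh(name=f"/frames/t{i}/shape", mesh=mesh)
    node.visible = i == 0
    frame_nodes.append(node)

@timestep.on_update
def _(_) -> None:
    for i, node in enumerate(frame_nodes):
        node.visible = i == timestep.value

# Advance the timestep while "Playing" is checked.
while True:
    if playing.value:
        timestep.value = (timestep.value + 1) % num_frames
    time.sleep(1.0 / 30)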
+ @gui_framerate_options.on_click + def _(_) -> None: + gui_framerate.value = int(gui_framerate_options.value) + + prev_timestep = gui_timestep.value + + # Toggle frame visibility when the timestep slider changes. + @gui_timestep.on_update + def _(_) -> None: + nonlocal prev_timestep + current_timestep = gui_timestep.value + with server.atomic(): + frame_nodes[current_timestep].visible = True + frame_nodes[prev_timestep].visible = False + prev_timestep = current_timestep + + # Load in frames. + server.add_frame( + "/frames", + wxyz=tf.SO3.exp(np.array([-np.pi / 2, 0.0, 0.0])).wxyz, + position=(0, 0, 0), + show_axes=True, + ) + frame_nodes: List[viser.FrameHandle] = [] + input_dict = loader.query_frame(0) + if "scene" in input_dict: + server.add_mesh_trimesh(name=f"/frames/scene", mesh=input_dict["scene"]) + for i in tqdm(range(num_frames)): + # Add base frame. + frame_nodes.append(server.add_frame(f"/frames/t{i}", show_axes=False)) + + input_dict = loader.query_frame(i) + server.add_mesh_trimesh(name=f"/frames/t{i}/shape", mesh=input_dict["shape"]) + if "bone" in input_dict: + server.add_mesh_trimesh(name=f"/frames/t{i}/bone", mesh=input_dict["bone"]) + + # Place the frustum. + rgb = rgb_list[i] + extrinsics = np.linalg.inv(loader.extr_dict[i]) + intrinsics = loader.intrinsics[i] / downsample_factor + fov = 2 * np.arctan2(rgb.shape[0] / 2, intrinsics[0]) + aspect = rgb.shape[1] / rgb.shape[0] + server.add_camera_frustum( + f"/frames/t{i}/frustum", + fov=fov, + aspect=aspect, + scale=0.3, + image=rgb, + wxyz=tf.SO3.from_matrix(extrinsics[:3, :3]).wxyz, + position=extrinsics[:3, 3], + ) + + # Add some axes. + server.add_frame( + f"/frames/t{i}/frustum/axes", + axes_length=0.01, + axes_radius=0.005, + ) + + # Hide all but the current frame. + for i, frame_node in enumerate(frame_nodes): + frame_node.visible = i == gui_timestep.value + + # Playback update loop. + prev_timestep = gui_timestep.value + while True: + if gui_playing.value: + gui_timestep.value = (gui_timestep.value + 1) % num_frames + + time.sleep(1.0 / gui_framerate.value) + + +if __name__ == "__main__": + # tyro.cli(main) + main() diff --git a/lab4d/nnutils/base.py b/lab4d/nnutils/base.py index 8115002..28b11b6 100644 --- a/lab4d/nnutils/base.py +++ b/lab4d/nnutils/base.py @@ -1,8 +1,10 @@ # Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. import torch import torch.nn as nn +import torch.nn.functional as F from lab4d.nnutils.embedding import InstEmbedding +from functorch import vmap, combine_state_for_ensemble class ScaleLayer(nn.Module): @@ -155,3 +157,184 @@ def get_dim_inst(num_inst, inst_channels): return inst_channels else: return 0 + + +# class PosEncArch(nn.Module): +# def __init__(self, in_channels, N_freqs) -> None: +# super().__init__() +# self.pos_embedding = PosEmbedding(in_channels, N_freqs) + + +class DictMLP(BaseMLP): + """A MLP that accepts both input `x` and condition `c` + + Args: + num_inst (int): Number of distinct object instances. If --nosingle_inst + is passed, this is equal to the number of videos, as we assume each + video captures a different instance. Otherwise, we assume all videos + capture the same instance and set this to 1. 
+ D (int): Number of linear layers for density (sigma) encoder + W (int): Number of hidden units in each MLP layer + in_channels (int): Number of channels in input `x` + inst_channels (int): Number of channels in condition `c` + out_channels (int): Number of output channels + skips (List(int)): List of layers to add skip connections at + activation (Function): Activation function to use (e.g. nn.ReLU()) + final_act (bool): If True, apply the activation function to the output + """ + + def __init__( + self, + num_inst, + D=8, + W=256, + in_channels=63, + inst_channels=32, + out_channels=3, + skips=[4], + activation=nn.ReLU(True), + final_act=False, + ): + super().__init__( + D=D, + W=W, + in_channels=in_channels + inst_channels, + out_channels=out_channels, + skips=skips, + activation=activation, + final_act=False, + ) + + self.basis = BaseMLP( + D=D, + W=W, + in_channels=in_channels, + out_channels=out_channels, + skips=skips, + activation=activation, + final_act=final_act, + ) + + self.inst_embedding = InstEmbedding(num_inst, inst_channels) + + def forward(self, feat, inst_id): + """ + Args: + feat: (M, ..., self.in_channels) + inst_id: (M,) Instance id, or None to use the average instance + Returns: + out: (M, ..., self.out_channels) + """ + if inst_id is None: + if self.inst_embedding.out_channels > 0: + inst_code = self.inst_embedding.get_mean_embedding() + inst_code = inst_code.expand(feat.shape[:-1] + (-1,)) + # print("inst_embedding exists but inst_id is None, using mean inst_code") + else: + # empty, falls back to single-instance NeRF + inst_code = torch.zeros(feat.shape[:-1] + (0,), device=feat.device) + else: + inst_code = self.inst_embedding(inst_id) + inst_code = inst_code.view( + inst_code.shape[:1] + (1,) * (feat.ndim - 2) + (-1,) + ) + inst_code = inst_code.expand(feat.shape[:-1] + (-1,)) + + out = torch.cat([feat, inst_code], -1) + # if both input feature and inst_code are empty, return zeros + if out.shape[-1] == 0: + return out + coeff = super().forward(out) + coeff = F.normalize(coeff, dim=-1) + basis = self.basis(feat) + out = coeff * basis + return out + + @staticmethod + def get_dim_inst(num_inst, inst_channels): + if num_inst > 1: + return inst_channels + else: + return 0 + + +class MultiMLP(nn.Module): + """Independent MLP for each instance""" + + def __init__(self, num_inst, inst_channels=32, **kwargs): + super(MultiMLP, self).__init__() + self.in_channels = kwargs["in_channels"] + self.out_channels = kwargs["out_channels"] + self.num_inst = num_inst + # ensemble version + self.nets = [] + for i in range(num_inst): + self.nets.append(BaseMLP(**kwargs)) + self.nets = nn.ModuleList(self.nets) + + def forward(self, feat, inst_id): + """ + Args: + feat: (M, ..., self.in_channels) + inst_id: (M,) Instance id, or None to use the average instance + Returns: + out: (M, ..., self.out_channels) + """ + # rearrange the batch dimension + shape = feat.shape[:-1] + device = feat.device + inst_id = inst_id.view((-1,) + (1,) * (len(shape) - 1)) + inst_id = inst_id.expand(shape) + + # sequential version: avoid duplicate computation + out = torch.zeros(shape + (self.out_channels,), device=feat.device) + empty_input = torch.zeros(1, 1, self.in_channels, device=feat.device) + for it, net in enumerate(self.nets): + id_sel = inst_id == it + if id_sel.sum() == 0: + out = out + self.nets[it](empty_input).mean() * 0 + continue + out[id_sel] = net(feat[id_sel]) + return out + + +class MixMLP(nn.Module): + """Mixing CondMLP and MultiMLP""" + + def __init__(self, num_inst, inst_channels=32, 
**kwargs): + super(MixMLP, self).__init__() + self.multimlp = MultiMLP(num_inst, inst_channels=inst_channels, **kwargs) + kwargs["D"] *= 5 # 5 + kwargs["W"] *= 2 # 128 + self.condmlp = CondMLP(num_inst, inst_channels=inst_channels, **kwargs) + + def forward(self, feat, inst_id): + out1 = self.condmlp(feat, inst_id) + out2 = self.multimlp(feat, inst_id) + out = out1 + out2 + return out + + +# class Triplane(nn.Module): +# """Triplane""" + +# def __init__(self, num_inst, inst_channels=32, **kwargs) -> None: +# super(Triplane, self).__init__() +# init_scale = 0.1 +# resolution = 128 +# num_components = 24 +# self.plane = nn.Parameter( +# init_scale * torch.randn((3 * resolution * resolution, num_components)) +# ) + +# def forward(self, feat, inst_id): +# """ +# Args: +# feat: (M, ..., self.in_channels) +# inst_id: (M,) Instance id, or None to use the average instance +# Returns: +# out: (M, ..., self.out_channels) +# """ +# # rearrange the batch dimension +# shape = feat.shape[:-1] +# return out diff --git a/lab4d/nnutils/bgnerf.py b/lab4d/nnutils/bgnerf.py new file mode 100644 index 0000000..e5344e5 --- /dev/null +++ b/lab4d/nnutils/bgnerf.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn +from lab4d.nnutils.nerf import NeRF + +import trimesh +from pysdf import SDF + +from lab4d.utils.quat_transform import quaternion_translation_to_se3 +from lab4d.utils.geom_utils import get_near_far, extend_aabb +from lab4d.nnutils.base import MixMLP, MultiMLP, CondMLP, DictMLP +from lab4d.nnutils.visibility import VisField + + +class BGNeRF(NeRF): + """A static neural radiance field with an MLP backbone. Specialized to background.""" + + # def __init__(self, data_info, field_arch=CondMLP, D=5, W=128, **kwargs): + # def __init__(self, data_info, field_arch=MixMLP, D=1, W=64, **kwargs): + def __init__(self, data_info, field_arch=DictMLP, D=8, W=256, **kwargs): + super(BGNeRF, self).__init__( + data_info, field_arch=field_arch, D=D, W=W, **kwargs + ) + # self.vis_mlp = VisField(self.num_inst, D=D, W=W, field_arch=field_arch) + self.vis_mlp = VisField(self.num_inst, D=1, W=64, field_arch=MixMLP) + # TODO: update per-scene beta + # TODO: update per-scene scale + + def init_proxy(self, geom_paths, init_scale): + """Initialize the geometry from a mesh + + Args: + geom_path (Listy(str)): Initial shape mesh + init_scale (float): Geometry scale factor + """ + meshes = [] + for geom_path in geom_paths: + mesh = trimesh.load(geom_path) + mesh.vertices = mesh.vertices * init_scale + meshes.append(mesh) + self.proxy_geometry = meshes + + def get_proxy_geometry(self): + """Get proxy geometry + + Returns: + proxy_geometry (Trimesh): Proxy geometry + """ + return self.proxy_geometry[0] + + def init_aabb(self): + """Initialize axis-aligned bounding box""" + self.register_buffer("aabb", torch.zeros(len(self.proxy_geometry), 2, 3)) + self.update_aabb(beta=0) + + def get_init_sdf_fn(self): + """Initialize signed distance function from mesh geometry + + Returns: + sdf_fn_torch (Function): Signed distance function + """ + + def sdf_fn_torch_sphere(pts): + radius = 0.1 + # l2 distance to a unit sphere + dis = (pts).pow(2).sum(-1, keepdim=True) + sdf = torch.sqrt(dis) - radius # negative inside, postive outside + return sdf + + return sdf_fn_torch_sphere + + def update_proxy(self): + """Extract proxy geometry using marching cubes""" + for inst_id in range(self.num_inst): + mesh = self.extract_canonical_mesh(level=0.005, inst_id=inst_id) + if len(mesh.vertices) > 0: + self.proxy_geometry[inst_id] = mesh + + def 
get_aabb(self, inst_id=None): + """Get axis-aligned bounding box + Args: + inst_id: (N,) Instance id + Returns: + aabb: (1,2,3) Axis-aligned bounding box if inst_id is None, (N,2,3) otherwise + """ + if inst_id is None: + return self.aabb.mean(0, keepdim=True) + else: + return self.aabb[inst_id] + + def update_aabb(self, beta=0.9): + """Update axis-aligned bounding box by interpolating with the current + proxy geometry's bounds + + Args: + beta (float): Interpolation factor between previous/current values + """ + device = self.aabb.device + for inst_id in range(self.num_inst): + bounds = self.proxy_geometry[inst_id].bounds + if bounds is not None: + aabb = torch.tensor(bounds, dtype=torch.float32, device=device) + aabb = extend_aabb(aabb, factor=0.2) # 1.4x larger + self.aabb[inst_id] = self.aabb[inst_id] * beta + aabb * (1 - beta) + + def update_near_far(self, beta=0.9): + """Update near-far bounds by interpolating with the current near-far bounds + + Args: + beta (float): Interpolation factor between previous/current values + """ + device = next(self.parameters()).device + with torch.no_grad(): + quat, trans = self.camera_mlp.get_vals() # (B, 4, 4) + rtmat = quaternion_translation_to_se3(quat, trans) + + frame_id_all = list(range(self.num_frames)) + frame_offset = self.frame_offset + near_far_all = [] + for inst_id in range(self.num_inst): + verts = self.proxy_geometry[inst_id].vertices + frame_id = frame_id_all[frame_offset[inst_id] : frame_offset[inst_id + 1]] + proxy_pts = torch.tensor(verts, dtype=torch.float32, device=device) + near_far = get_near_far(proxy_pts, rtmat[frame_id]).to(device) + near_far_all.append( + self.near_far[frame_id].data * beta + near_far * (1 - beta) + ) + self.near_far.data = torch.cat(near_far_all, 0) + + def get_near_far(self, frame_id, field2cam): + device = next(self.parameters()).device + frame_id_all = list(range(self.num_frames)) + frame_offset = self.frame_offset + field2cam_mat = quaternion_translation_to_se3(field2cam[0], field2cam[1]) + + near_far_all = [] + for inst_id in range(self.num_inst): + frame_id_sel = frame_id_all[ + frame_offset[inst_id] : frame_offset[inst_id + 1] + ] + # find the overlap of frame_id and frame_id_sel + id_sel = [i for i, x in enumerate(frame_id) if x in frame_id_sel] + if len(id_sel) == 0: + continue + corners = trimesh.bounds.corners(self.proxy_geometry[inst_id].bounds) + corners = torch.tensor(corners, dtype=torch.float32, device=device) + near_far = get_near_far(corners, field2cam_mat[id_sel], tol_fac=1.5) + near_far_all.append(near_far) + near_far = torch.cat(near_far_all, 0) + return near_far diff --git a/lab4d/nnutils/deformable.py b/lab4d/nnutils/deformable.py index 6410c13..7b0d779 100644 --- a/lab4d/nnutils/deformable.py +++ b/lab4d/nnutils/deformable.py @@ -4,13 +4,16 @@ import trimesh from torch import nn from torch.nn import functional as F +import sys +import os + +os.environ["CUDA_PATH"] = sys.prefix # needed for geomloss +from geomloss import SamplesLoss from lab4d.nnutils.feature import FeatureNeRF from lab4d.nnutils.warping import SkinningWarp, create_warp from lab4d.utils.decorator import train_only_fields -from lab4d.utils.geom_utils import extend_aabb -from lab4d.utils.loss_utils import align_vectors -from lab4d.engine.train_utils import get_local_rank +from lab4d.utils.geom_utils import extend_aabb, check_inside_aabb class Deformable(FeatureNeRF): @@ -83,6 +86,15 @@ def __init__( self.warp = create_warp(fg_motion, data_info) self.fg_motion = fg_motion + # def update_aabb(self, beta=0.5): + # 
"""Update axis-aligned bounding box by interpolating with the current + # proxy geometry's bounds + + # Args: + # beta (float): Interpolation factor between previous/current values + # """ + # super().update_aabb(beta=beta) + def init_proxy(self, geom_path, init_scale): """Initialize proxy geometry as a sphere @@ -111,7 +123,7 @@ def sdf_fn_torch_skel(pts): sdf = self.warp.get_gauss_sdf(pts) return sdf - if "skel-" in self.fg_motion: + if "skel-" in self.fg_motion or "urdf-" in self.fg_motion: return sdf_fn_torch_skel else: return sdf_fn_torch_sphere @@ -141,7 +153,7 @@ def backward_warp( xyz_t, frame_id, inst_id, - backward=True, + type="backward", samples_dict=samples_dict, return_aux=True, ) @@ -170,6 +182,34 @@ def forward_warp(self, xyz, field2cam, frame_id, inst_id, samples_dict={}): xyz_cam = self.field_to_cam(xyz_next, field2cam) return xyz_cam + def flow_warp( + self, + xyz_1, + field2cam_flip, + frame_id, + inst_id, + samples_dict={}, + ): + """Warp points from camera space from time t1 to time t2 + + Args: + xyz_1: (M,N,D,3) Points along rays in canonical space at time t1 + field2cam_flip: (M,SE(3)) Object-to-camera SE(3) transform at time t2 + frame_id: (M,) Frame id. If None, warp for all frames + inst_id: (M,) Instance id. If None, warp for the average instance + samples_dict (Dict): Time-dependent bone articulations. Keys: + "rest_articulation": ((M,B,4), (M,B,4)) and + "t_articulation": ((M,B,4), (M,B,4)) + + Returns: + xyz_2: (M,N,D,3) Points along rays in camera space at time t2 + """ + xyz_2 = self.warp( + xyz_1, frame_id, inst_id, type="flow", samples_dict=samples_dict + ) + xyz_2 = self.field_to_cam(xyz_2, field2cam_flip) + return xyz_2 + @train_only_fields def cycle_loss(self, xyz, xyz_t, frame_id, inst_id, samples_dict={}): """Enforce cycle consistency between points in object canonical space, @@ -197,22 +237,68 @@ def cycle_loss(self, xyz, xyz_t, frame_id, inst_id, samples_dict={}): cyc_dict.update(warp_dict) return cyc_dict - def gauss_skin_consistency_loss(self, nsample=2048): - """Enforce consistency between the NeRF's SDF and the SDF of Gaussian bones + def gauss_skin_consistency_loss(self, type="optimal_transport"): + """Enforce consistency between the NeRF's SDF and the SDF of Gaussian bones, + + Args: + type (str): "optimal_transport" or "density" + Returns: + loss: (0,) Skinning consistency loss + """ + if type == "optimal_transport": + return self.gauss_optimal_transport_loss() + elif type == "density": + return self.gauss_skin_density_loss() + else: + raise NotImplementedError + + def gauss_skin_density_loss(self, nsample=4096): + """Enforce consistency between the NeRF's SDF and the SDF of Gaussian bones, + based on density. 
Args: nsample (int): Number of samples to take from both distance fields Returns: loss: (0,) Skinning consistency loss """ - pts = self.sample_points_aabb(nsample, extend_factor=0.25) + pts, frame_id, _ = self.sample_points_aabb(nsample, extend_factor=0.5) + inst_id = None + samples_dict = {} + ( + samples_dict["t_articulation"], + samples_dict["rest_articulation"], + ) = self.warp.articulation.get_vals_and_mean(frame_id) # match the gauss density to the reconstructed density - density_gauss = self.warp.get_gauss_density(pts) # (N,1) + bones2obj = samples_dict["t_articulation"] + bones2obj = ( + torch.cat([bones2obj[0], samples_dict["rest_articulation"][0]], 0), + torch.cat([bones2obj[1], samples_dict["rest_articulation"][1]], 0), + ) + pts_gauss = torch.cat([pts, pts], dim=0) + density_gauss = self.warp.get_gauss_density(pts_gauss, bone2obj=bones2obj) + with torch.no_grad(): - density = self.forward(pts, inst_id=None, get_density=True) + density = torch.zeros_like(density_gauss) + pts_warped = self.warp( + pts[:, None, None], + frame_id, + inst_id, + type="backward", + samples_dict=samples_dict, + return_aux=False, + )[:, 0, 0] + pts = torch.cat([pts_warped, pts], dim=0) + + # check whether the point is inside the aabb + aabb = self.get_aabb() + aabb = extend_aabb(aabb) + inside_aabb = check_inside_aabb(pts, aabb) + + _, density[inside_aabb] = self.forward(pts[inside_aabb], inst_id=inst_id) density = density / self.logibeta.exp() # (0,1) + # loss = ((density_gauss - density).pow(2)).mean() # binary cross entropy loss to align gauss density to the reconstructed density # weight the loss such that: # wp lp = wn ln @@ -220,10 +306,10 @@ def gauss_skin_consistency_loss(self, nsample=2048): weight_pos = 0.5 / (1e-6 + density.mean()) weight_neg = 0.5 / (1e-6 + 1 - density).mean() weight = density * weight_pos + (1 - density) * weight_neg - # loss = ((density_gauss - density).pow(2) * weight.detach()).mean() - loss = F.binary_cross_entropy( - density_gauss, density.detach(), weight=weight.detach() - ) + loss = ((density_gauss - density).pow(2) * weight.detach()).mean() + # loss = F.binary_cross_entropy( + # density_gauss, density.detach(), weight=weight.detach() + # ) # if get_local_rank() == 0: # is_inside = density > 0.5 @@ -235,6 +321,34 @@ def gauss_skin_consistency_loss(self, nsample=2048): # mesh.export("tmp/1.obj") return loss + def gauss_optimal_transport_loss(self, nsample=1024): + """Enforce consistency between the NeRF's proxy rest shape + and the gaussian bones, based on optimal transport. + + Args: + nsample (int): Number of samples to take from proxy geometry + Returns: + loss: (0,) Gaussian optimal transport loss + """ + # optimal transport loss + device = self.parameters().__next__().device + pts = self.get_proxy_geometry().vertices + # sample points from the proxy geometry + pts = pts[np.random.choice(len(pts), nsample)] + pts = torch.tensor(pts, device=device, dtype=torch.float32) + pts_gauss = self.warp.get_gauss_pts() + samploss = SamplesLoss( + loss="sinkhorn", p=2, blur=0.002, scaling=0.5, truncate=1 + ) + scale_proxy = self.get_scale() # to normalize pts to 1 + loss = samploss(2 * pts_gauss / scale_proxy, 2 * pts / scale_proxy).mean() + # if get_local_rank() == 0: + # mesh = trimesh.Trimesh(vertices=pts.detach().cpu()) + # mesh.export("tmp/0.obj") + # mesh = trimesh.Trimesh(vertices=pts_gauss.detach().cpu()) + # mesh.export("tmp/1.obj") + return loss + def soft_deform_loss(self, nsample=1024): """Minimize soft deformation so it doesn't overpower the skeleton. 
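The optimal-transport variant instead compares two point clouds, proxy-surface samples versus Gaussian bone centers, with a Sinkhorn divergence from geomloss after normalizing both by the object scale. A minimal sketch with random clouds standing in for both sets (the SamplesLoss arguments mirror the hunk above; the cloud sizes and scale are made up):

import torch
from geomloss import SamplesLoss

# Sketch of the Sinkhorn point-cloud term in gauss_optimal_transport_loss.
# Random clouds stand in for proxy-surface samples and Gaussian bone centers.
pts = torch.randn(1024, 3) * 0.1        # samples on the proxy surface
pts_gauss = torch.randn(25, 3) * 0.1    # Gaussian bone centers
scale = 0.2                             # object scale, e.g. from get_scale()

samploss = SamplesLoss(loss="sinkhorn", p=2, blur=0.002, scaling=0.5, truncate=1)
loss = samploss(2 * pts_gauss / scale, 2 * pts / scale).mean()
print(float(loss))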
Compute L2 distance of points before and after soft deformation @@ -244,10 +358,7 @@ def soft_deform_loss(self, nsample=1024): Returns: loss: (0,) Soft deformation loss """ - device = next(self.parameters()).device - pts = self.sample_points_aabb(nsample, extend_factor=1.0) - frame_id = torch.randint(0, self.num_frames, (nsample,), device=device) - inst_id = torch.randint(0, self.num_inst, (nsample,), device=device) + pts, frame_id, inst_id = self.sample_points_aabb(nsample, extend_factor=1.0) dist2 = self.warp.compute_post_warp_dist2(pts[:, None, None], frame_id, inst_id) return dist2.mean() @@ -293,7 +404,7 @@ def mlp_init(self): from an external skeleton """ super().mlp_init() - if self.fg_motion.startswith("skel"): + if "skel-" in self.fg_motion or "urdf-" in self.fg_motion: if hasattr(self.warp.articulation, "init_vals"): self.warp.articulation.mlp_init() @@ -321,12 +432,13 @@ def query_field(self, samples_dict, flow_thresh=None): # xyz = feat_dict["xyz"].detach() # don't backprop to cam/dfm fields xyz = feat_dict["xyz"] - gauss_field = self.compute_gauss_density(xyz, samples_dict) + xyz_t = feat_dict["xyz_t"] + gauss_field = self.compute_gauss_density(xyz, xyz_t, samples_dict) feat_dict.update(gauss_field) return feat_dict, deltas, aux_dict - def compute_gauss_density(self, xyz, samples_dict): + def compute_gauss_density(self, xyz, xyz_t, samples_dict): """If this is a SkinningWarp, compute density from Gaussian bones Args: @@ -339,18 +451,30 @@ def compute_gauss_density(self, xyz, samples_dict): Returns: gauss_field (Dict): Density. Keys: "gauss_density" (M,N,D,1) """ + M, N, D, _ = xyz.shape gauss_field = {} if isinstance(self.warp, SkinningWarp): - shape = xyz.shape[:-1] - if "rest_articulation" in samples_dict: - rest_articulation = ( - samples_dict["rest_articulation"][0][:1], - samples_dict["rest_articulation"][1][:1], - ) - xyz = xyz.view(-1, 3) - gauss_density = self.warp.get_gauss_density(xyz, bone2obj=rest_articulation) + # supervise t articulation + xyz_t = xyz_t.view(-1, 3).detach() + t_articulation = ( + samples_dict["t_articulation"][0][:, None] + .repeat(1, N * D, 1, 1) + .view(M * N * D, -1, 4), + samples_dict["t_articulation"][1][:, None] + .repeat(1, N * D, 1, 1) + .view(M * N * D, -1, 4), + ) + gauss_density = self.warp.get_gauss_density(xyz_t, bone2obj=t_articulation) + + # supervise rest articulation + # rest_articulation = ( + # samples_dict["rest_articulation"][0][:1], + # samples_dict["rest_articulation"][1][:1], + # ) + # xyz = xyz.view(-1, 3).detach() + # gauss_density = self.warp.get_gauss_density(xyz, bone2obj=rest_articulation) # gauss_density = gauss_density * 100 # [0,100] heuristic value gauss_density = gauss_density * self.warp.logibeta.exp() - gauss_field["gauss_density"] = gauss_density.view(shape + (1,)) + gauss_field["gauss_density"] = gauss_density.view((M, N, D, 1)) return gauss_field diff --git a/lab4d/nnutils/embedding.py b/lab4d/nnutils/embedding.py index 57f3bb7..f11903e 100644 --- a/lab4d/nnutils/embedding.py +++ b/lab4d/nnutils/embedding.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from lab4d.utils.torch_utils import frameid_to_vid +from lab4d.utils.geom_utils import get_pre_rotation def get_fourier_embed_dim(in_channels, N_freqs): @@ -31,13 +32,20 @@ class PosEmbedding(nn.Module): in_channels (int): Number of input channels (3 for both xyz, direction) N_freqs (int): Number of frequency bands logscale (bool): If True, construct frequency bands in log-space + pre_rotate (bool): If True, pre-rotate the input along each plane """ - def 
__init__(self, in_channels, N_freqs, logscale=True): + def __init__(self, in_channels, N_freqs, logscale=True, pre_rotate=False): super().__init__() self.N_freqs = N_freqs self.in_channels = in_channels + if pre_rotate: + # rotate along each dimension for 45 degrees + rot_mat = get_pre_rotation(in_channels) + rot_mat = torch.tensor(rot_mat, dtype=torch.float32) + self.register_buffer("rot_mat", rot_mat, persistent=False) + # no embedding if N_freqs == -1: self.out_channels = 0 @@ -52,8 +60,7 @@ def __init__(self, in_channels, N_freqs, logscale=True): else: freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs) self.register_buffer("freq_bands", freq_bands, persistent=False) - - self.set_alpha(None) + self.register_buffer("alpha", torch.tensor(-1.0, dtype=torch.float32)) def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @@ -62,9 +69,9 @@ def set_alpha(self, alpha): """Set the alpha parameter for the annealing window Args: - alpha (float or None): 0 to 1 + alpha (float): 0 to 1, -1 represents full frequency band """ - self.alpha = alpha + self.alpha.data = alpha def forward(self, x): """Embeds x to (x, sin(2^k x), cos(2^k x), ...) @@ -92,18 +99,27 @@ def forward(self, x): out = torch.empty(x.shape[0], output_dim, dtype=x.dtype, device=device) out[:, :input_dim] = x + if hasattr(self, "rot_mat"): + x = x @ self.rot_mat.T + x = x.view(x.shape[0], input_dim, -1) + # assign fourier features to the remaining channels out_bands = out[:, input_dim:].view( -1, self.N_freqs, self.nfuncs, input_dim ) for i, func in enumerate(self.funcs): # (B, nfreqs, input_dim) = (1, nfreqs, 1) * (B, 1, input_dim) - out_bands[:, :, i] = func( - self.freq_bands[None, :, None] * x[:, None, :] - ) + if hasattr(self, "rot_mat"): + signal = self.freq_bands[None, :, None, None] * x[:, None] + response = func(signal) + response = response.view(-1, self.N_freqs, input_dim, x.shape[-1]) + response = response.mean(-1) + else: + signal = self.freq_bands[None, :, None] * x[:, None, :] + response = func(signal) + out_bands[:, :, i] = response self.apply_annealing(out_bands) - out = out.view(out_shape) else: out = x @@ -116,7 +132,7 @@ def apply_annealing(self, out_bands): out_bands: (..., N_freqs, nfuncs, in_channels) Frequency bands """ device = out_bands.device - if self.alpha is not None: + if self.alpha >= 0: alpha_freq = self.alpha * self.N_freqs window = alpha_freq - torch.arange(self.N_freqs).to(device) window = torch.clamp(window, 0.0, 1.0) @@ -150,6 +166,7 @@ def __init__(self, num_freq_t, frame_info, out_channels=128, time_scale=1.0): self.out_channels = out_channels self.frame_offset = frame_info["frame_offset"] + self.frame_offset_raw = frame_info["frame_offset_raw"] self.num_frames = self.frame_offset[-1] self.num_vids = len(self.frame_offset) - 1 @@ -170,6 +187,9 @@ def __init__(self, num_freq_t, frame_info, out_channels=128, time_scale=1.0): ) # M, in range [0,N-1], M 0: - inst_id = self.randomize_instance(inst_id) - inst_code = self.mapping(inst_id) + inst_code = self.mapping(torch.zeros_like(inst_id)) + else: + if self.training and self.beta_prob > 0: + inst_id = self.randomize_instance(inst_id) + inst_code = self.mapping(inst_id) return inst_code def randomize_instance(self, inst_id): diff --git a/lab4d/nnutils/feature.py b/lab4d/nnutils/feature.py index 113f34d..38c4db5 100644 --- a/lab4d/nnutils/feature.py +++ b/lab4d/nnutils/feature.py @@ -2,13 +2,19 @@ import numpy as np import torch import trimesh +import cv2 from torch import nn from lab4d.nnutils.base import BaseMLP from 
lab4d.nnutils.embedding import PosEmbedding from lab4d.nnutils.nerf import NeRF from lab4d.utils.decorator import train_only_fields -from lab4d.utils.geom_utils import Kmatinv, pinhole_projection +from lab4d.utils.geom_utils import ( + Kmatinv, + pinhole_projection, + extend_aabb, + check_inside_aabb, +) class FeatureNeRF(NeRF): @@ -85,6 +91,10 @@ def __init__( sigma = torch.tensor([1.0]) self.logsigma = nn.Parameter(sigma.log()) + self.set_match_region(sample_around_surface=True) + + def set_match_region(self, sample_around_surface): + self.sample_around_surface = sample_around_surface def query_field(self, samples_dict, flow_thresh=None): """Render outputs from a neural radiance field. @@ -119,6 +129,8 @@ def query_field(self, samples_dict, flow_thresh=None): # global matching if "feature" in samples_dict and "feature" in feat_dict: feature = feat_dict["feature"] + sdf = feat_dict["sdf"] + feature, xyz = self.propose_matches(feature, xyz.detach(), sdf) xyz_matches = self.global_match(samples_dict["feature"], feature, xyz) xy_reproj, xyz_reproj = self.forward_project( xyz_matches, @@ -130,10 +142,98 @@ def query_field(self, samples_dict, flow_thresh=None): ) aux_dict["xyz_matches"] = xyz_matches aux_dict["xyz_reproj"] = xyz_reproj - aux_dict["xy_reproj"] = xy_reproj + hxy = samples_dict["hxy"][..., :2] + aux_dict["xy_reproj"] = (xy_reproj - hxy).norm(2, -1, keepdim=True) + # # visualize matches + # if not self.training: + # img = self.plot_xy_matches(xy_reproj, samples_dict) + # cv2.imwrite("tmp/arrow.png", img) + # trimesh.Trimesh(vertices=xyz_matches[0].cpu().numpy()).export( + # "tmp/matches.obj" + # ) + # import pdb + + # pdb.set_trace() return feat_dict, deltas, aux_dict - @train_only_fields + def plot_xy_matches(self, xy_reproj, samples_dict): + # plot arrow from hxy to xy_reproj + res = int(np.sqrt(samples_dict["hxy"].shape[1])) + img = np.zeros((res * 16, res * 16, 3), dtype=np.uint8) + hxy_vis = samples_dict["hxy"].view(-1, res, res, 3)[..., :2].cpu().numpy() + xy_reproj_vis = xy_reproj.view(-1, res, res, 2).cpu().numpy() + feature_vis = samples_dict["feature"].view(-1, res, res, 16) + for i in range(res): + for j in range(res): + if feature_vis[0, i, j].norm(2, -1) == 0: + continue + # draw a line + img = cv2.arrowedLine( + img, + tuple(hxy_vis[0, i, j] * 16), + tuple(xy_reproj_vis[0, i, j] * 16), + (0, 255, 0), + 1, + ) + return img + + def propose_matches(self, feature, xyz, sdf, num_candidates=8192): + """Sample canonical points for global matching + Args: + feature: (M,N,D,feature_channels) Pixel features + xyz: (M,N,D,3) Points in field coordinates + num_candidates: Number of candidates to sample + Returns: + feature: (num_candidates, feature_channels) Canonical features + xyz: (num_candidates, 3) Points in field coordinates + """ + # threshold + if self.sample_around_surface: + thresh = 0.005 + else: + thresh = 1 + # sample canonical points + feature = feature.view(-1, feature.shape[-1]) # (M*N*D, feature_channels) + xyz = xyz.view(-1, 3) # (M*N*D, 3) + + # remove points outsize aabb + aabb = self.get_aabb() + aabb = extend_aabb(aabb, 0.1) + inside_aabb = check_inside_aabb(xyz, aabb) + feature = feature[inside_aabb] + xyz = xyz[inside_aabb] + sdf = sdf.view(-1)[inside_aabb] + + # remove points far from the surface beyond a sdf threshold + is_near_surface = sdf.abs() < thresh + feature = feature[is_near_surface] + xyz = xyz[is_near_surface] + + num_candidates = min(num_candidates, feature.shape[0]) + idx = torch.randperm(feature.shape[0])[: num_candidates // 2] + feature = 
feature[idx] # (num_candidates, feature_channels) + xyz = xyz[idx] # (num_candidates, 3) + + # sample additional points + if self.sample_around_surface: + # sample from proxy geometry on the surface + proxy_geometry = self.get_proxy_geometry() + rand_xyz, _ = trimesh.sample.sample_surface( + proxy_geometry, num_candidates // 2 + ) + rand_xyz = torch.tensor(rand_xyz, dtype=torch.float32, device=xyz.device) + else: + # sample from aabb + rand_xyz, _, _ = self.sample_points_aabb( + num_candidates // 2, extend_factor=0.1 + ) + rand_feat = self.compute_feat(rand_xyz)["feature"] + + # combine + feature = torch.cat([feature, rand_feat], dim=0) + xyz = torch.cat([xyz, rand_xyz], dim=0) + return feature, xyz + def compute_feat(self, xyz): """Render feature field @@ -154,47 +254,35 @@ def global_match( feat_px, feat_canonical, xyz_canonical, - num_candidates=1024, - num_grad=128, + num_grad=0, ): """Match pixel features to canonical features, which combats local minima in differentiable rendering optimization Args: feat: (M,N,feature_channels) Pixel features - feat_canonical: (M,N,D,feature_channels) Canonical features - xyz_canonical: (M,N,D,3) Canonical points + feat_canonical: (...,feature_channels) Canonical features + xyz_canonical: (...,3) Canonical points Returns: xyz_matched: (M,N,3) Matched xyz """ shape = feat_px.shape feat_px = feat_px.view(-1, shape[-1]) # (M*N, feature_channels) - feat_canonical = feat_canonical.view(-1, shape[-1]) # (M*N*D, feature_channels) - xyz_canonical = xyz_canonical.view(-1, 3) # (M*N*D, 3) - - # sample canonical points - num_candidates = min(num_candidates, feat_canonical.shape[0]) - idx = torch.randperm(feat_canonical.shape[0])[:num_candidates] - feat_canonical = feat_canonical[idx] # (num_candidates, feature_channels) - xyz_canonical = xyz_canonical[idx] # (num_candidates, 3) # compute similarity score = torch.matmul(feat_px, feat_canonical.t()) # (M*N, num_candidates) - # # find top K candidates - # num_grad = min(num_grad, score.shape[1]) - # score, idx = torch.topk(score, num_grad, dim=1, largest=True) - # score = score * self.logsigma.exp() # temperature + # find top K candidates + if num_grad > 0: + num_grad = min(num_grad, score.shape[1]) + score, idx = torch.topk(score, num_grad, dim=1, largest=True) + xyz_canonical = xyz_canonical[idx] - # # soft argmin - # prob = torch.softmax(score, dim=1) - # xyz_matched = torch.sum(prob.unsqueeze(-1) * xyz_canonical[idx], dim=1) - - # use all candidates + # soft argmin + # score = score.detach() # do not backprop to features score = score * self.logsigma.exp() # temperature prob = torch.softmax(score, dim=1) xyz_matched = torch.sum(prob.unsqueeze(-1) * xyz_canonical, dim=1) - xyz_matched = xyz_matched.view(shape[:-1] + (-1,)) return xyz_matched diff --git a/lab4d/nnutils/intrinsics.py b/lab4d/nnutils/intrinsics.py index f19441b..30ed3ba 100644 --- a/lab4d/nnutils/intrinsics.py +++ b/lab4d/nnutils/intrinsics.py @@ -4,6 +4,7 @@ import torch.nn as nn from lab4d.nnutils.time import TimeMLP +from lab4d.utils.torch_utils import reinit_model class IntrinsicsMLP(TimeMLP): @@ -64,12 +65,16 @@ def __init__( "init_vals", torch.tensor(intrinsics, dtype=torch.float32), persistent=False ) - def mlp_init(self): + def base_init(self): """Initialize camera intrinsics from external values""" intrinsics = self.init_vals frame_offset = self.get_frame_offset() self.base_logfocal.data = intrinsics[frame_offset[:-1], :2].log() self.base_ppoint.data = intrinsics[frame_offset[:-1], 2:] + + def mlp_init(self): + """Initialize camera 
intrinsics from external values""" + self.base_init() super().mlp_init(termination_loss=1.0) def forward(self, t_embed): @@ -87,7 +92,7 @@ def get_vals(self, frame_id=None): """Compute camera intrinsics at the given frames. Args: - frame_id: (...,) Frame id. If None, compute at all frames + frame_id: (M,) Frame id. If None, compute at all frames Returns: intrinsics: (..., 4) Output camera intrinsics """ @@ -105,3 +110,58 @@ def get_vals(self, frame_id=None): ppoint = base_ppoint.expand_as(focal) intrinsics = torch.cat([focal, ppoint], dim=-1) return intrinsics + + def get_intrinsics(self, inst_id=None): + if inst_id is None: + frame_id = None + else: + raw_fid_to_vid = self.time_embedding.raw_fid_to_vid + frame_id = (raw_fid_to_vid == inst_id).nonzero()[:, 0] + intrinsics = self.get_vals(frame_id=frame_id) + return intrinsics + + +class IntrinsicsMLP_delta(IntrinsicsMLP): + """Encode camera intrinsics over time with an MLP + + Args: + intrinsics: (N,4) Camera intrinsics (fx, fy, cx, cy) + frame_info (Dict): Metadata about the frames in a dataset + D (int): Number of linear layers + W (int): Number of hidden units in each MLP layer + num_freq_t (int): Number of frequencies in the time embedding + skips (List(int)): List of layers to add skip connections at + activation (Function): Activation function to use (e.g. nn.ReLU()) + time_scale (float): Control the sensitivity to time by scaling. + Lower values make the module less sensitive to time. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + del self.base_logfocal + del self.base_ppoint + self.register_buffer( + "base_logfocal", torch.zeros(self.time_embedding.num_frames, 2) + ) + self.register_buffer( + "base_ppoint", torch.zeros(self.time_embedding.num_frames, 2) + ) + + def update_base_focal(self): + """Update base camera rotations from current camera trajectory""" + intrinsics = self.get_vals() + focal, ppoint = intrinsics[..., :2], intrinsics[..., 2:] + self.base_logfocal.data = focal.log() + self.base_ppoint.data = ppoint + # reinit the mlp head + reinit_model(self.focal, std=0.01) + + def base_init(self): + """Initialize camera intrinsics from external values""" + intrinsics = self.init_vals + frame_offset = self.get_frame_offset() + for i in range(len(frame_offset) - 1): + focal = intrinsics[frame_offset[i], :2] + ppoint = intrinsics[frame_offset[i], 2:] + self.base_logfocal.data[frame_offset[i] : frame_offset[i + 1]] = focal.log() + self.base_ppoint.data[frame_offset[i] : frame_offset[i + 1]] = ppoint diff --git a/lab4d/nnutils/multifields.py b/lab4d/nnutils/multifields.py index ad9e6eb..e007ad5 100644 --- a/lab4d/nnutils/multifields.py +++ b/lab4d/nnutils/multifields.py @@ -8,10 +8,12 @@ from lab4d.nnutils.deformable import Deformable from lab4d.nnutils.nerf import NeRF -from lab4d.nnutils.pose import ArticulationSkelMLP +from lab4d.nnutils.bgnerf import BGNeRF +from lab4d.nnutils.pose import ArticulationSkelMLP, CameraMLP_so3 from lab4d.nnutils.warping import ComposedWarp, SkinningWarp from lab4d.utils.quat_transform import quaternion_translation_to_se3 from lab4d.utils.vis_utils import draw_cams, mesh_cat +from lab4d.utils.geom_utils import extend_aabb class MultiFields(nn.Module): @@ -22,10 +24,8 @@ class MultiFields(nn.Module): field_type (str): Field type ("comp", "fg", or "bg") fg_motion (str): Foreground motion type ("rigid", "dense", "bob", "skel-{human,quad}", or "comp_skel-{human,quad}_{bob,dense}") - num_inst (int): Number of distinct object instances. 
If --nosingle_inst - is passed, this is equal to the number of videos, as we assume each - video captures a different instance. Otherwise, we assume all videos - capture the same instance and set this to 1. + single_inst (bool): If True, assume the same morphology over videos + single_scene (bool): If True, assume the same scene over videos """ def __init__( @@ -33,7 +33,8 @@ def __init__( data_info, field_type="bg", fg_motion="rigid", - num_inst=None, + single_inst=True, + single_scene=True, ): vis_info = data_info["vis_info"] @@ -41,7 +42,8 @@ def __init__( field_params = nn.ParameterDict() self.field_type = field_type self.fg_motion = fg_motion - self.num_inst = num_inst + self.single_inst = single_inst + self.single_scene = single_scene # specify field type if field_type == "comp": @@ -72,6 +74,7 @@ def define_field(self, category, data_info, tracklet_id): # which is identical to video frameid if # instance=1 data_info["rtmat"] = data_info["rtmat"][tracklet_id] data_info["geom_path"] = data_info["geom_path"][tracklet_id] + num_inst = len(data_info["frame_info"]["frame_offset"]) - 1 if category == "fg": # TODO add a flag to decide rigid fg vs deformable fg nerf = Deformable( @@ -79,17 +82,23 @@ def define_field(self, category, data_info, tracklet_id): data_info, num_freq_dir=-1, appr_channels=32, - num_inst=self.num_inst, + num_inst=1 if self.single_inst else num_inst, init_scale=0.2, ) # no directional encoding elif category == "bg": - nerf = NeRF( + if self.single_scene: + bg_arch = NeRF + else: + bg_arch = BGNeRF + nerf = bg_arch( data_info, - num_freq_xyz=6, + D=8, num_freq_dir=0, appr_channels=0, - init_scale=0.1, + num_inst=num_inst, + init_scale=0.05, + # init_scale=0.2, ) else: # exit with an error raise ValueError("Invalid category") @@ -115,6 +124,13 @@ def set_alpha(self, alpha): field.pos_embedding.set_alpha(alpha) field.pos_embedding_color.set_alpha(alpha) + def set_importance_sampling(self, use_importance_sampling): + """ + Set inverse sampling for all child fields + """ + for field in self.field_params.values(): + field.use_importance_sampling = use_importance_sampling + def set_beta_prob(self, beta_prob): """Set beta probability for all child fields. This determines the probability of instance code swapping @@ -136,6 +152,7 @@ def update_geometry_aux(self): def reset_geometry_aux(self): """Reset proxy geometry and bounds for all child fields""" for field in self.field_params.values(): + print("resetting geometry aux for %s" % field.category) field.update_proxy() field.update_aabb(beta=0) field.update_near_far(beta=0) @@ -146,7 +163,7 @@ def extract_canonical_meshes( grid_size=64, level=0.0, inst_id=None, - use_visibility=True, + vis_thresh=0.0, use_extend_aabb=True, ): """Extract canonical mesh using marching cubes for all child fields @@ -155,9 +172,8 @@ def extract_canonical_meshes( grid_size (int): Marching cubes resolution level (float): Contour value to search for isosurfaces on the signed distance function - inst_id: (M,) Instance id. If None, extract for the average instance - use_visibility (bool): If True, use visibility mlp to mask out invisible - region. + inst_id: (int) Instance id. If None, extract for the average instance + vis_thresh (float): threshold for visibility value to mask out invisible points. use_extend_aabb (bool): If True, extend aabb by 50% to get a loose proxy. Used at training time. 
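extract_canonical_meshes evaluates the SDF on a regular grid inside the (optionally extended) aabb, runs marching cubes at the requested level, and masks out points whose visibility falls below vis_thresh. A generic sketch of that grid-evaluation step with scikit-image, using an analytic sphere SDF as a stand-in for the learned field (lab4d's own marching_cubes helper additionally takes a visibility function and can keep only the largest connected component):

import numpy as np
import trimesh
from skimage import measure

# Generic SDF -> mesh extraction on a regular grid, as in extract_canonical_meshes.
def sdf_func(pts):  # (N, 3) -> (N,)
    return np.linalg.norm(pts, axis=-1) - 0.5

aabb = np.array([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])  # (2, 3) min/max corners
grid_size = 64
axes = [np.linspace(aabb[0, i], aabb[1, i], grid_size) for i in range(3)]
grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
sdf = sdf_func(grid.reshape(-1, 3)).reshape(grid.shape[:3])

# Marching cubes returns vertices in voxel units; map them back into the aabb.
verts, faces, _, _ = measure.marching_cubes(sdf, level=0.0)
spacing = (aabb[1] - aabb[0]) / (grid_size - 1)
verts = verts * spacing + aabb[0]
mesh = trimesh.Trimesh(verts, faces)
print(mesh.bounds)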
Returns: @@ -169,7 +185,7 @@ def extract_canonical_meshes( grid_size=grid_size, level=level, inst_id=inst_id, - use_visibility=use_visibility, + vis_thresh=vis_thresh, use_extend_aabb=use_extend_aabb, ) meshes[category] = mesh @@ -184,7 +200,7 @@ def export_geometry_aux(self, path): """ for category, field in self.field_params.items(): # print(field.near_far) - mesh_geo = field.proxy_geometry + mesh_geo = field.get_proxy_geometry() quat, trans = field.camera_mlp.get_vals() rtmat = quaternion_translation_to_se3(quat, trans).cpu() # evenly pick max 200 cameras @@ -194,7 +210,8 @@ def export_geometry_aux(self, path): mesh_cam = draw_cams(rtmat) mesh = mesh_cat(mesh_geo, mesh_cam) if category == "fg": - mesh_gauss, mesh_sdf = field.warp.get_template_vis(aabb=field.aabb) + aabb = extend_aabb(field.aabb, factor=0.5) + mesh_gauss, mesh_sdf = field.warp.get_template_vis(aabb=aabb) mesh_gauss.export("%s-%s-gauss.obj" % (path, category)) mesh_sdf.export("%s-%s-sdf.obj" % (path, category)) mesh.export("%s-%s-proxy.obj" % (path, category)) @@ -410,13 +427,19 @@ def get_cameras(self, frame_id=None): field2cam[cate] = quaternion_translation_to_se3(quat, trans) return field2cam - def get_aabb(self): + def get_aabb(self, inst_id=None): """Compute axis aligned bounding box + Args: + inst_id (int or tensor): Instance id. If None, return aabb for all instances Returns: - aabb (Dict): Maps field names ("fg" or "bg") to (2,3) aabb + aabb (Dict): Maps field names ("fg" or "bg") to (1/N,2,3) aabb """ + if inst_id is not None: + if not torch.is_tensor(inst_id): + inst_id = torch.tensor(inst_id, dtype=torch.long) + inst_id = inst_id.view(-1) aabb = {} for cate, field in self.field_params.items(): - aabb[cate] = field.aabb / field.logscale.exp() + aabb[cate] = field.get_aabb(inst_id=inst_id) / field.logscale.exp() return aabb diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index eec4b08..68df8ae 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -5,11 +5,13 @@ import trimesh from pysdf import SDF from torch import nn +from torch.autograd.functional import jacobian + from lab4d.nnutils.appearance import AppearanceEmbedding from lab4d.nnutils.base import CondMLP from lab4d.nnutils.embedding import PosEmbedding -from lab4d.nnutils.pose import CameraMLP +from lab4d.nnutils.pose import CameraMLP, CameraMLP_so3 from lab4d.nnutils.visibility import VisField from lab4d.utils.decorator import train_only_fields from lab4d.utils.geom_utils import ( @@ -20,8 +22,9 @@ marching_cubes, pinhole_projection, check_inside_aabb, + compute_rectification_se3, ) -from lab4d.utils.loss_utils import align_vectors +from lab4d.utils.loss_utils import align_tensors from lab4d.utils.quat_transform import ( quaternion_apply, quaternion_translation_inverse, @@ -30,7 +33,9 @@ dual_quaternion_to_quaternion_translation, ) from lab4d.utils.render_utils import sample_cam_rays, sample_pdf, compute_weights -from lab4d.utils.torch_utils import compute_gradient +from lab4d.utils.torch_utils import compute_gradient, flip_pair, compute_gradients_sdf +from lab4d.utils.vis_utils import append_xz_plane + class NeRF(nn.Module): """A static neural radiance field with an MLP backbone. 
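Several hunks extend the aabb before using it (the factor=0.2 call is annotated as "1.4x larger", i.e. each axis grows by factor of its size on both ends) and divide field-space bounds by logscale.exp() to report them in world units, as in MultiFields.get_aabb above. An illustrative re-implementation of that interpretation (extend_aabb_sketch is a hypothetical helper, not lab4d's extend_aabb):

import torch

# Illustrative aabb extension: each side grows by `factor` times the box size,
# so factor=0.2 turns a box of size 2 into one of size 2.8 (1.4x larger).
def extend_aabb_sketch(aabb, factor=0.2):
    # aabb: (2, 3), rows are the (min, max) corners
    size = aabb[1] - aabb[0]
    return torch.stack([aabb[0] - factor * size, aabb[1] + factor * size], 0)

aabb = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])
print(extend_aabb_sketch(aabb, 0.2))  # spans [-1.4, 1.4] per axis

# Field-space bounds to world units, as in MultiFields.get_aabb:
logscale = torch.tensor(0.7)
print(aabb / logscale.exp())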
@@ -63,7 +68,7 @@ def __init__( self, data_info, D=5, - W=128, + W=256, num_freq_xyz=10, num_freq_dir=4, appr_channels=32, @@ -75,6 +80,7 @@ def __init__( init_beta=0.1, init_scale=0.1, color_act=True, + field_arch=CondMLP, ): rtmat = data_info["rtmat"] frame_info = data_info["frame_info"] @@ -96,7 +102,7 @@ def __init__( # xyz encoding layers # TODO: add option to replace with instNGP - self.basefield = CondMLP( + self.basefield = field_arch( num_inst=self.num_inst, D=D, W=W, @@ -109,8 +115,8 @@ def __init__( ) # color - self.pos_embedding_color = PosEmbedding(3, num_freq_xyz + 2) - self.colorfield = CondMLP( + self.pos_embedding_color = PosEmbedding(3, 12) + self.colorfield = field_arch( num_inst=self.num_inst, D=2, W=W, @@ -149,22 +155,28 @@ def __init__( # camera pose: field to camera rtmat[..., :3, 3] *= init_scale - self.camera_mlp = CameraMLP(rtmat, frame_info=frame_info) + self.camera_mlp = CameraMLP_so3(rtmat, frame_info=frame_info) + # self.camera_mlp = CameraMLP(rtmat, frame_info=frame_info) # visibility mlp - self.vis_mlp = VisField(self.num_inst) + self.vis_mlp = VisField(self.num_inst, field_arch=field_arch) - # load initial mesh + # load initial mesh, define aabb self.init_proxy(geom_path, init_scale) - self.register_buffer("aabb", torch.zeros(2, 3)) - self.update_aabb(beta=0) + self.init_aabb() # non-parameters are not synchronized self.register_buffer( "near_far", torch.zeros(frame_offset_raw[-1], 2), persistent=False ) - def forward(self, xyz, dir=None, frame_id=None, inst_id=None, get_density=True): + field2world = torch.zeros(4, 4)[None].expand(self.num_inst, -1, -1).clone() + self.register_buffer("field2world", field2world, persistent=True) + + # inverse sampling + self.use_importance_sampling = True + + def forward(self, xyz, dir=None, frame_id=None, inst_id=None): """ Args: xyz: (M,N,D,3) Points along ray in object canonical space @@ -173,26 +185,22 @@ def forward(self, xyz, dir=None, frame_id=None, inst_id=None, get_density=True): inst_id: (M,) Instance id. If None, render for the average instance Returns: rgb: (M,N,D,3) Rendered RGB - sigma: (M,N,D,1) If get_density=True, return density. Otherwise - return signed distance (negative inside) + sdf: (M,N,D,1) Signed distance (negative inside) + sigma: (M,N,D,1) Denstiy """ if frame_id is not None: assert frame_id.ndim == 1 if inst_id is not None: assert inst_id.ndim == 1 - xyz_embed = self.pos_embedding(xyz) - xyz_feat = self.basefield(xyz_embed, inst_id) - sdf = self.sdf(xyz_feat) # negative inside, positive outside - if get_density: - ibeta = self.logibeta.exp() - # density = torch.sigmoid(-sdf * ibeta) * ibeta # neus - density = ( - 0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() * ibeta) - ) * ibeta # volsdf - out = density - else: - out = sdf + sdf, xyz_feat = self.forward_sdf(xyz, inst_id=inst_id) + + ibeta = self.logibeta.exp() + # density = torch.sigmoid(-sdf * ibeta) * ibeta # neus + density = ( + 0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() * ibeta) + ) * ibeta # volsdf + out = sdf, density if dir is not None: dir_embed = self.dir_embedding(dir) @@ -211,9 +219,25 @@ def forward(self, xyz, dir=None, frame_id=None, inst_id=None, get_density=True): rgb = self.rgb(torch.cat([xyz_feat, appr_embed], -1)) if self.color_act: rgb = rgb.sigmoid() - out = rgb, out + out = (rgb,) + out return out + def forward_sdf(self, xyz, inst_id=None): + """Forward pass for signed distance function + Args: + xyz: (M,N,D,3) Points along ray in object canonical space + inst_id: (M,) Instance id. 
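The forward pass above maps signed distance to density with the VolSDF Laplace-CDF form, density(s) = CDF_Laplace(-s; scale=beta) * ibeta with ibeta = 1/beta; the sign/expm1 expression is a numerically convenient way to write that CDF. A quick standalone check of the equivalence:

import torch

# Check that the expm1 form used in forward() equals the Laplace CDF (VolSDF).
sdf = torch.linspace(-0.2, 0.2, 5)   # negative inside, positive outside
ibeta = torch.tensor(10.0)           # 1 / beta

density = (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() * ibeta)) * ibeta

beta = 1.0 / ibeta
cdf = torch.where(
    sdf <= 0,
    1 - 0.5 * torch.exp(sdf / beta),
    0.5 * torch.exp(-sdf / beta),
)
print(torch.allclose(density, cdf * ibeta))  # True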
If None, render for the average instance + + Returns: + sdf: (M,N,D,1) Signed distance (negative inside) + xyz_feat: (M,N,D,W) Features from the xyz encoder + """ + xyz_embed = self.pos_embedding(xyz) + xyz_feat = self.basefield(xyz_embed, inst_id) + + sdf = self.sdf(xyz_feat) # negative inside, positive outside + return sdf, xyz_feat + def get_init_sdf_fn(self): """Initialize signed distance function from mesh geometry @@ -241,14 +265,27 @@ def init_proxy(self, geom_path, init_scale): """Initialize the geometry from a mesh Args: - geom_path (str): Initial shape mesh + geom_path (List(str)): paths to initial shape mesh init_scale (float): Geometry scale factor """ - mesh = trimesh.load(geom_path) + mesh = trimesh.load(geom_path[0]) mesh.vertices = mesh.vertices * init_scale self.proxy_geometry = mesh - def geometry_init(self, sdf_fn, nsample=256): + def get_proxy_geometry(self): + """Get proxy geometry + + Returns: + proxy_geometry (Trimesh): Proxy geometry + """ + return self.proxy_geometry + + def init_aabb(self): + """Initialize axis-aligned bounding box""" + self.register_buffer("aabb", torch.zeros(2, 3)) + self.update_aabb(beta=0) + + def geometry_init(self, sdf_fn, nsample=4096): """Initialize SDF using tsdf-fused geometry if radius is not given. Otherwise, initialize sdf using a unit sphere @@ -261,21 +298,21 @@ def geometry_init(self, sdf_fn, nsample=256): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) # optimize - for i in range(500): + for i in range(1000): optimizer.zero_grad() # sample points and gt sdf inst_id = torch.randint(0, self.num_inst, (nsample,), device=device) # sample points - pts = self.sample_points_aabb(nsample, extend_factor=0.25) + pts, _, _ = self.sample_points_aabb(nsample, extend_factor=0.5) # get sdf from proxy geometry sdf_gt = sdf_fn(pts) # evaluate sdf loss - sdf = self.forward(pts, inst_id=inst_id, get_density=False) - scale = align_vectors(sdf, sdf_gt) + sdf, _ = self.forward(pts, inst_id=inst_id) + scale = align_tensors(sdf, sdf_gt) sdf_loss = (sdf * scale.detach() - sdf_gt).pow(2).mean() # evaluate visibility loss @@ -284,9 +321,11 @@ def geometry_init(self, sdf_fn, nsample=256): vis_loss = vis_loss * 0.01 # evaluate eikonal loss - eikonal_loss = self.compute_eikonal(pts[:, None, None], inst_id=inst_id) + eikonal_loss, _ = self.compute_eikonal( + pts[:, None, None], inst_id=inst_id, sample_ratio=1 + ) eikonal_loss = eikonal_loss[eikonal_loss > 0].mean() - eikonal_loss = eikonal_loss * 1e-4 + eikonal_loss = eikonal_loss * 1e-3 total_loss = sdf_loss + vis_loss + eikonal_loss total_loss.backward() @@ -297,7 +336,7 @@ def geometry_init(self, sdf_fn, nsample=256): def update_proxy(self): """Extract proxy geometry using marching cubes""" mesh = self.extract_canonical_mesh(level=0.005) - if mesh is not None: + if len(mesh.vertices) > 3: self.proxy_geometry = mesh @torch.no_grad() @@ -306,7 +345,7 @@ def extract_canonical_mesh( grid_size=64, level=0.0, inst_id=None, - use_visibility=True, + vis_thresh=0.0, use_extend_aabb=True, ): """Extract canonical mesh using marching cubes @@ -315,9 +354,8 @@ def extract_canonical_mesh( grid_size (int): Marching cubes resolution level (float): Contour value to search for isosurfaces on the signed distance function - inst_id: (M,) Instance id. If None, extract for the average instance - use_visibility (bool): If True, use visibility mlp to mask out invisible - region. + inst_id: (int) Instance id. 
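geometry_init pre-fits the field's SDF by sampling points in the aabb and regressing the analytic signed distance of a simple target (a small sphere here, or the Gaussian-bone SDF for skeleton motion), together with visibility and eikonal terms. A toy version of the same fitting loop (the 3-layer MLP and box sampler are placeholders, not lab4d modules):

import torch
import torch.nn as nn

# Toy geometry_init: regress an MLP SDF toward an analytic sphere SDF.
mlp = nn.Sequential(
    nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1)
)
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

def sdf_sphere(pts, radius=0.1):
    # negative inside, positive outside
    return pts.norm(dim=-1, keepdim=True) - radius

for i in range(1000):
    optimizer.zero_grad()
    pts = torch.rand(4096, 3) - 0.5      # samples inside a unit "aabb"
    loss = (mlp(pts) - sdf_sphere(pts)).pow(2).mean()  # lab4d adds vis/eikonal terms
    loss.backward()
    optimizer.step()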
If None, extract for the average instance + vis_thresh (float): threshold for visibility value to remove invisible pts. use_extend_aabb (bool): If True, extend aabb by 50% to get a loose proxy. Used at training time. Returns: @@ -325,22 +363,41 @@ def extract_canonical_mesh( """ if inst_id is not None: inst_id = torch.tensor([inst_id], device=next(self.parameters()).device) - sdf_func = lambda xyz: self.forward(xyz, inst_id=inst_id, get_density=False) - vis_func = lambda xyz: self.vis_mlp(xyz, inst_id=inst_id) > 0 - if use_extend_aabb: - aabb = extend_aabb(self.aabb, factor=0.5) + aabb = self.get_aabb(inst_id=inst_id) # 2,3 else: - aabb = self.aabb + aabb = self.get_aabb() + sdf_func = lambda xyz: self.forward(xyz, inst_id=inst_id)[0] + vis_func = lambda xyz: self.vis_mlp(xyz, inst_id=inst_id) > vis_thresh + if use_extend_aabb: + aabb = extend_aabb(aabb, factor=0.5) mesh = marching_cubes( sdf_func, - aabb, - visibility_func=vis_func if use_visibility else None, + aabb[0], + visibility_func=vis_func, grid_size=grid_size, level=level, apply_connected_component=True if self.category == "fg" else False, ) return mesh + def get_aabb(self, inst_id=None): + """Get axis-aligned bounding box + Args: + inst_id: (N,) Instance id + Returns: + aabb: (2,3) Axis-aligned bounding box if inst_id is None, (N,2,3) otherwise + """ + if inst_id is None: + return self.aabb[None] + else: + return self.aabb[None].repeat(len(inst_id), 1, 1) + + def get_scale(self): + """Get scale of the proxy geometry""" + assert self.category == "fg" + aabb = self.get_aabb()[0] + return (aabb[1] - aabb[0]).mean() + def update_aabb(self, beta=0.9): """Update axis-aligned bounding box by interpolating with the current proxy geometry's bounds @@ -352,6 +409,7 @@ def update_aabb(self, beta=0.9): bounds = self.proxy_geometry.bounds if bounds is not None: aabb = torch.tensor(bounds, dtype=torch.float32, device=device) + aabb = extend_aabb(aabb, factor=0.2) # 1.4x larger self.aabb = self.aabb * beta + aabb * (1 - beta) def update_near_far(self, beta=0.9): @@ -375,24 +433,29 @@ def update_near_far(self, beta=0.9): frame_mapping ] * beta + near_far * (1 - beta) - def sample_points_aabb(self, nsample, extend_factor=1.0): + def sample_points_aabb(self, nsample, extend_factor=1.0, aabb=None): """Sample points within axis-aligned bounding box Args: nsample (int): Number of samples extend_factor (float): Extend aabb along each side by factor of the previous size + aabb: (2,3) Axis-aligned bounding box to sample from, optional Returns: pts: (nsample, 3) Sampled points """ device = next(self.parameters()).device - aabb = extend_aabb(self.aabb, factor=extend_factor) + frame_id = torch.randint(0, self.num_frames, (nsample,), device=device) + inst_id = torch.randint(0, self.num_inst, (nsample,), device=device) + if aabb is None: + aabb = self.get_aabb(inst_id=inst_id) + aabb = extend_aabb(aabb, factor=extend_factor) pts = ( torch.rand(nsample, 3, dtype=torch.float32, device=device) - * (aabb[1:] - aabb[:1]) - + aabb[:1] + * (aabb[..., 1, :] - aabb[..., 0, :]) + + aabb[..., 0, :] ) - return pts + return pts, frame_id, inst_id def visibility_decay_loss(self, nsample=512): """Encourage visibility to be low at random points within the aabb. 
The @@ -404,9 +467,7 @@ def visibility_decay_loss(self, nsample=512): loss: (0,) Visibility decay loss """ # sample random points - device = next(self.parameters()).device - pts = self.sample_points_aabb(nsample) - inst_id = torch.randint(0, self.num_inst, (nsample,), device=device) + pts, _, inst_id = self.sample_points_aabb(nsample) # evaluate loss vis = self.vis_mlp(pts, inst_id=inst_id) @@ -414,7 +475,7 @@ def visibility_decay_loss(self, nsample=512): return loss def compute_eikonal(self, xyz, inst_id=None, sample_ratio=16): - """Compute eikonal loss + """Compute eikonal loss and normal in the canonical space Args: xyz: (M,N,D,3) Input coordinates in canonical space @@ -432,6 +493,7 @@ def compute_eikonal(self, xyz, inst_id=None, sample_ratio=16): inst_id = inst_id[:, None].expand(-1, N) inst_id = inst_id.reshape(-1) eikonal_loss = torch.zeros_like(xyz[..., 0]) + normal = torch.zeros_like(xyz) # subsample to make it more efficient if M * N > sample_size: @@ -444,15 +506,26 @@ def compute_eikonal(self, xyz, inst_id=None, sample_ratio=16): rand_inds = Ellipsis xyz = xyz.detach() - inst_id = inst_id.detach() if inst_id is not None else None - fn_sdf = lambda x: self.forward(x, inst_id=inst_id, get_density=False) - g = compute_gradient(fn_sdf, xyz)[..., 0] + fn_sdf = lambda x: self.forward(x, inst_id=inst_id)[0] + g = compute_gradients_sdf(fn_sdf, xyz, training=self.training) + # g = compute_gradient(fn_sdf, xyz)[..., 0] + + # def fn_sdf(x): + # sdf, _ = self.forward(x, inst_id=inst_id) + # sdf_sum = sdf.sum() + # return sdf_sum + + # g = jacobian(fn_sdf, xyz, create_graph=True, strict=True) eikonal_loss[rand_inds] = (g.norm(2, dim=-1) - 1) ** 2 eikonal_loss = eikonal_loss.reshape(M, N, D, 1) - return eikonal_loss + normal[rand_inds] = g # self.grad_to_normal(g) + normal = normal.reshape(M, N, D, 3) + return eikonal_loss, normal - def compute_normal(self, xyz_cam, dir_cam, field2cam, frame_id=None, inst_id=None, samples_dict={}): + def compute_eikonal_view( + self, xyz_cam, dir_cam, field2cam, frame_id=None, inst_id=None, samples_dict={} + ): """Compute eikonal loss and normals in camera space Args: @@ -469,6 +542,17 @@ def compute_normal(self, xyz_cam, dir_cam, field2cam, frame_id=None, inst_id=Non """ M, N, D, _ = xyz_cam.shape + xyz_cam = xyz_cam.detach() + dir_cam = dir_cam.detach() + field2cam = (field2cam[0].detach(), field2cam[1].detach()) + samples_dict_copy = {} + for k, v in samples_dict.items(): + if isinstance(v, tuple): + samples_dict_copy[k] = (v[0].detach(), v[1].detach()) + else: + samples_dict_copy[k] = v.detach() + samples_dict = samples_dict_copy + def fn_sdf(xyz_cam): xyz = self.backward_warp( xyz_cam, @@ -478,19 +562,30 @@ def fn_sdf(xyz_cam): inst_id=inst_id, samples_dict=samples_dict, )["xyz"] - sdf = self.forward(xyz, inst_id=inst_id, get_density=False) + sdf, _ = self.forward(xyz, inst_id=inst_id) return sdf - g = compute_gradient(fn_sdf, xyz_cam)[..., 0] + # g = compute_gradient(fn_sdf, xyz_cam)[..., 0] + g = compute_gradients_sdf(fn_sdf, xyz_cam, training=self.training) eikonal = (g.norm(2, dim=-1, keepdim=True) - 1) ** 2 - normal = torch.nn.functional.normalize(g, dim=-1) + normal = g # self.grad_to_normal(g) + return eikonal, normal + + @staticmethod + def grad_to_normal(g): + """ + Args: + g: (...,3) Gradient of sdf + Returns: + normal: (...,3) Normal vector field + """ + normal = F.normalize(g, dim=-1) # Multiply by [1, -1, -1] to match normal conventions from ECON # 
https://github.com/YuliangXiu/ECON/blob/d98e9cbc96c31ecaa696267a072cdd5ef78d14b8/apps/infer.py#L257 normal = normal * torch.tensor([1, -1, -1], device="cuda") - - return eikonal, normal + return normal @torch.no_grad() def get_valid_idx(self, xyz, xyz_t=None, vis_score=None, samples_dict={}): @@ -504,7 +599,8 @@ def get_valid_idx(self, xyz, xyz_t=None, vis_score=None, samples_dict={}): valid_idx: (M,N,D) Visibility mask, bool """ # check whether the point is inside the aabb - aabb = extend_aabb(self.aabb) + aabb = self.get_aabb(samples_dict["inst_id"]) + aabb = extend_aabb(aabb) # (M,N,D), whether the point is inside the aabb inside_aabb = check_inside_aabb(xyz, aabb) @@ -518,7 +614,7 @@ def get_valid_idx(self, xyz, xyz_t=None, vis_score=None, samples_dict={}): )[1][0] t_aabb = torch.stack([t_bones.min(0)[0], t_bones.max(0)[0]], 0) t_aabb = extend_aabb(t_aabb, factor=1.0) - inside_aabb = check_inside_aabb(xyz_t, t_aabb) + inside_aabb = check_inside_aabb(xyz_t, t_aabb[None]) valid_idx = valid_idx & inside_aabb # temporally disable visibility mask @@ -559,10 +655,7 @@ def get_samples(self, Kinv, batch): near_far = self.near_far.to(device) near_far = near_far[batch["frameid"]] else: - corners = trimesh.bounds.corners(self.proxy_geometry.bounds) - corners = torch.tensor(corners, dtype=torch.float32, device=device) - field2cam_mat = quaternion_translation_to_se3(field2cam[0], field2cam[1]) - near_far = get_near_far(corners, field2cam_mat, tol_fac=1.5) + near_far = self.get_near_far(frame_id, field2cam) # auxiliary outputs samples_dict = {} @@ -577,6 +670,14 @@ def get_samples(self, Kinv, batch): samples_dict["feature"] = batch["feature"] return samples_dict + def get_near_far(self, frame_id, field2cam): + device = next(self.parameters()).device + corners = trimesh.bounds.corners(self.proxy_geometry.bounds) + corners = torch.tensor(corners, dtype=torch.float32, device=device) + field2cam_mat = quaternion_translation_to_se3(field2cam[0], field2cam[1]) + near_far = get_near_far(corners, field2cam_mat, tol_fac=1.5) + return near_far + def query_field(self, samples_dict, flow_thresh=None): """Render outputs from a neural radiance field. 
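For reference, the hunks above change NeRF.forward to return the signed distance alongside a VolSDF-style density, and compute_eikonal now also returns normals. Below is a minimal, self-contained sketch (not lab4d code) of those two ideas: the Laplace-CDF conversion from SDF to density mirrors the expression used in forward(), while the eikonal penalty uses plain autograd instead of lab4d's compute_gradients_sdf helper; the toy sphere SDF and tensor shapes are illustrative assumptions only.

import torch

def volsdf_density(sdf, ibeta):
    # Laplace-CDF density: (0.5 + 0.5 * sign(sdf) * expm1(-|sdf| * ibeta)) * ibeta
    return (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() * ibeta)) * ibeta

def sphere_sdf(xyz, radius=0.5):
    # toy analytic SDF standing in for the basefield + sdf head
    return xyz.norm(dim=-1, keepdim=True) - radius

def eikonal_penalty(sdf_fn, xyz):
    # ||grad sdf|| should be close to 1 everywhere for a valid signed distance field
    xyz = xyz.clone().requires_grad_(True)
    (grad,) = torch.autograd.grad(sdf_fn(xyz).sum(), xyz, create_graph=True)
    return ((grad.norm(dim=-1) - 1.0) ** 2).mean()

pts = torch.randn(1024, 3)
ibeta = torch.tensor(10.0)  # 1 / beta; larger means a sharper surface
sigma = volsdf_density(sphere_sdf(pts), ibeta)
loss = eikonal_penalty(sphere_sdf, pts)
print(sigma.shape, float(loss))  # exact sphere SDF gives a loss close to 0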
@@ -602,7 +703,7 @@ def query_field(self, samples_dict, flow_thresh=None): hxy = samples_dict["hxy"] # (M,N,2) # sample camera space rays - if not self.training: + if self.use_importance_sampling: # importance sampling xyz_cam, dir_cam, deltas, depth = self.importance_sampling( hxy, @@ -612,10 +713,15 @@ def query_field(self, samples_dict, flow_thresh=None): frame_id, inst_id, samples_dict, + n_depth=64, ) else: xyz_cam, dir_cam, deltas, depth = sample_cam_rays( - hxy, Kinv, near_far, perturb=False + hxy, + Kinv, + near_far, + n_depth=64, + perturb=False, ) # (M, N, D, x) # backward warping @@ -627,7 +733,7 @@ def query_field(self, samples_dict, flow_thresh=None): xyz_t = backwarp_dict["xyz_t"] # visibility - vis_score = self.vis_mlp(xyz, inst_id=inst_id) # (M, N, D, 1) + vis_score = self.vis_mlp(xyz.detach(), inst_id=inst_id) # (M, N, D, 1) # compute valid_indices to speed up querying fields if self.training: @@ -674,6 +780,7 @@ def query_field(self, samples_dict, flow_thresh=None): # canonical point feat_dict["xyz"] = xyz + feat_dict["xyz_t"] = xyz_t feat_dict["xyz_cam"] = xyz_cam # depth @@ -683,7 +790,6 @@ def query_field(self, samples_dict, flow_thresh=None): aux_dict = {} return feat_dict, deltas, aux_dict - @torch.no_grad() def importance_sampling( self, hxy, @@ -693,51 +799,79 @@ def importance_sampling( frame_id, inst_id, samples_dict, - n_depth=64, + n_depth, ): """ importance sampling coarse """ - # sample camera space rays - xyz_cam, dir_cam, deltas, depth = sample_cam_rays( - hxy, Kinv, near_far, perturb=False, n_depth=n_depth // 2 - ) # (M, N, D, x) + with torch.no_grad(): + # sample camera space rays + xyz_cam, dir_cam, deltas, depth = sample_cam_rays( + hxy, Kinv, near_far, n_depth // 2, perturb=False ) # (M, N, D, x) - # backward warping - xyz = self.backward_warp( - xyz_cam, dir_cam, field2cam, frame_id, inst_id, samples_dict=samples_dict - )["xyz"] + # backward warping + xyz = self.backward_warp( + xyz_cam, + dir_cam, + field2cam, + frame_id, + inst_id, + samples_dict=samples_dict, + )["xyz"] - # get pdf - density = self.forward( - xyz, - dir=None, - frame_id=frame_id, - inst_id=inst_id, - ) # (M, N, D, x) - weights, _ = compute_weights(density, deltas) # (M, N, D, x) + # get pdf + _, density = self.forward( + xyz, + dir=None, + frame_id=frame_id, + inst_id=inst_id, + ) # (M, N, D, x) + weights, _ = compute_weights(density, deltas) # (M, N, D, 1) + weights = weights.view(-1, n_depth // 2)[:, 1:-1] # (M*N, D-2) + # modify the weights so that importance sampling only takes effect when there is a clear surface (weight is high); otherwise the filled-in mass keeps the samples near-uniform + weights_fill = 1 - weights.sum(-1, keepdim=True) + weights = weights + weights_fill / (n_depth // 2 - 2) + # assert torch.allclose(weights.sum(-1), torch.ones_like(weights[:, 0])) + + depth_mid = 0.5 * (depth[:, :, :-1] + depth[:, :, 1:]) # (M, N, D-1) + depth_mid = depth_mid.view(-1, n_depth // 2 - 1) # (M*N, D-1) + + depth_ = sample_pdf(depth_mid, weights, n_depth // 2, det=True) + depth_ = depth_.reshape(depth.shape) # (M, N, D, 1) - depth_mid = 0.5 * (depth[:, :, :-1] + depth[:, :, 1:]) # (M, N, D-1) - is_det = not self.training - depth_mid = depth_mid.view(-1, n_depth // 2 - 1) - weights = weights.view(-1, n_depth // 2) + depth, _ = torch.sort(torch.cat([depth, depth_], -2), -2) # (M, N, D, 1) - depth_ = sample_pdf( - depth_mid, weights[:, 1:-1], n_depth // 2, det=is_det - ).detach() - depth_ = depth_.reshape(depth.shape) - # detach so that grad doesn't propogate to weights_sampled from here - 
depth, _ = torch.sort(torch.cat([depth, depth_], -2), -2) # (M, N, D) + # pdb.set_trace() + + # valid_ind = weights.sum(-1) > 0 + # plt.figure() + # depth_vis = depth[0, :, :, 0][valid_ind].cpu().numpy() + + # plt.plot(depth_vis[::10].T) + # plt.show() + # plt.savefig("tmp/depth.png") + + # plt.figure() + # weights_vis = weights[valid_ind].cpu().numpy() + # plt.plot(weights_vis[::10].T) + # plt.show() + # plt.savefig("tmp/weights.png") # sample camera space rays xyz_cam, dir_cam, deltas, depth = sample_cam_rays( - hxy, Kinv, near_far, depth=depth, perturb=False + hxy, Kinv, near_far, None, depth=depth, perturb=False ) return xyz_cam, dir_cam, deltas, depth - def compute_jacobian(self, xyz, xyz_cam, dir_cam, field2cam, frame_id, inst_id, samples_dict): + def compute_jacobian( + self, xyz, xyz_cam, dir_cam, field2cam, frame_id, inst_id, samples_dict + ): """Compute eikonal and normal fields from Jacobian of SDF Args: @@ -758,12 +892,24 @@ def compute_jacobian(self, xyz, xyz_cam, dir_cam, field2cam, frame_id, inst_id, jacob_dict = {} if self.training: # For efficiency, compute subsampled eikonal loss in canonical space - jacob_dict["eikonal"] = self.compute_eikonal(xyz, inst_id=inst_id) + jacob_dict["eikonal"], jacob_dict["normal"] = self.compute_eikonal( + xyz, inst_id=inst_id + ) + # convert to camera space + jacob_dict["normal"] = quaternion_apply( + field2cam[0][:, None, None] + .expand(jacob_dict["normal"].shape[:-1] + (4,)) + .clone(), + jacob_dict["normal"], + ) else: # For rendering, compute full eikonal loss and normals in camera space - jacob_dict["eikonal"], jacob_dict["normal"] = self.compute_normal( + jacob_dict["eikonal"], jacob_dict["normal"] = self.compute_eikonal_view( xyz_cam, dir_cam, field2cam, frame_id, inst_id, samples_dict ) + # jacob_dict["eikonal"], jacob_dict["normal"] = self.compute_eikonal( + # xyz, inst_id=inst_id, sample_ratio=1.0 + # ) return jacob_dict def query_nerf(self, xyz, dir, frame_id, inst_id, valid_idx=None): @@ -788,6 +934,7 @@ def query_nerf(self, xyz, dir, frame_id, inst_id, valid_idx=None): % self.category: torch.zeros( valid_idx.shape + (1,), device=xyz.device ), + "sdf": torch.zeros(valid_idx.shape + (1,), device=xyz.device), } return field_dict # reshape @@ -797,18 +944,43 @@ def query_nerf(self, xyz, dir, frame_id, inst_id, valid_idx=None): frame_id = frame_id[:, None, None].expand(shape[:3])[valid_idx] inst_id = inst_id[:, None, None].expand(shape[:3])[valid_idx] - rgb, density = self.forward( + # # symmetrically normalize + # symm_ratio = 0.5 + # xyz_x = xyz[..., :1].clone() + # symm_mask = torch.rand_like(xyz_x) < symm_ratio + # xyz_x[symm_mask] = -xyz_x[symm_mask] + # xyz = torch.cat([xyz_x, xyz[..., 1:3]], -1) + + rgb, sdf, density = self.forward( xyz, dir=dir, frame_id=frame_id, inst_id=inst_id, ) # (M, N, D, x) + # # density drop out, to enforce motion to explain the missing density + # # get aabb + # ratio = 4 + # aabb = self.get_aabb() + # # select a random box from aabb with 1/ratio size + # aabb_size = aabb[..., 1, :] - aabb[..., 0, :] + # aabb_size_sub = aabb_size / ratio + # aabb_sub_min = aabb[..., 0, :] + torch.rand_like(aabb_size) * ( + # aabb_size - aabb_size_sub + # ) + # aabb_sub_max = aabb_sub_min + aabb_size_sub + # aabb_sub = torch.stack([aabb_sub_min, aabb_sub_max], -2) + # # check whether the point is inside the aabb + # inside_aabb = check_inside_aabb(xyz, aabb_sub) + # density[inside_aabb] = 0 + # reshape field_dict = { "rgb": rgb, + "sdf": sdf, "density": density, - "density_%s" % self.category: density, + 
"density_%s" + % self.category: (density / self.logibeta.exp()).detach(), # (0,1) } if valid_idx is not None: @@ -818,6 +990,25 @@ def query_nerf(self, xyz, dir, frame_id, inst_id, valid_idx=None): field_dict[k] = tmpv return field_dict + def wipe_loss(self, nsample=512): + # density drop out, to enforce motion to explain the missing density + # get aabb + ratio = 4 + aabb = self.get_aabb() + # select a random box from aabb with 1/ratio size + aabb_size = aabb[..., 1, :] - aabb[..., 0, :] + aabb_size_sub = aabb_size / ratio + aabb_sub_min = aabb[..., 0, :] + torch.rand_like(aabb_size) * ( + aabb_size - aabb_size_sub + ) + aabb_sub_max = aabb_sub_min + aabb_size_sub + aabb_sub = torch.stack([aabb_sub_min, aabb_sub_max], -2) + pts, frame_id, inst_id = self.sample_points_aabb(nsample, aabb=aabb_sub) + # check whether the point is inside the aabb + sdf, _ = self.forward(pts, frame_id=frame_id, inst_id=inst_id) + wipe_loss = (-sdf).exp().mean() + return wipe_loss + @staticmethod def cam_to_field(xyz_cam, dir_cam, field2cam): """Transform rays from camera SE(3) to object SE(3) @@ -926,25 +1117,6 @@ def cycle_loss(self, xyz, xyz_t, frame_id, inst_id, samples_dict={}): } return cyc_dict - @staticmethod - def flip_pair(tensor): - """Flip the tensor along the pair dimension - - Args: - tensor: (M*2, ...) Inputs [x0, x1, x2, x3, ..., x_{2k}, x_{2k+1}] - - Returns: - tensor: (M*2, ...) Outputs [x1, x0, x3, x2, ..., x_{2k+1}, x_{2k}] - """ - if torch.is_tensor(tensor): - if len(tensor) < 2: - return tensor - return tensor.view(tensor.shape[0] // 2, 2, -1).flip(1).view(tensor.shape) - elif isinstance(tensor, tuple): - return tuple([NeRF.flip_pair(t) for t in tensor]) - elif isinstance(tensor, dict): - return {k: NeRF.flip_pair(v) for k, v in tensor.items()} - @train_only_fields def compute_flow( self, @@ -962,7 +1134,7 @@ def compute_flow( Args: hxy: (M,N,D,3) Homogeneous pixel coordinates on the image plane - xyz: (M,N,D,3) Canonical field coordinates + xyz_t: (M,N,D,3) Canonical field coordinates at time t Kinv: (M,3,3) Inverse of camera intrinsics flow_thresh (float): Threshold for flow magnitude @@ -970,15 +1142,18 @@ def compute_flow( flow: (M,N,D,2) Optical flow proposal """ # flip the frame id - frame_id_next = self.flip_pair(frame_id) - field2cam_next = (self.flip_pair(field2cam[0]), self.flip_pair(field2cam[1])) - Kinv_next = self.flip_pair(Kinv) - samples_dict_next = self.flip_pair(samples_dict) + frame_id_next = flip_pair(frame_id) + field2cam_next = (flip_pair(field2cam[0]), flip_pair(field2cam[1])) + Kinv_next = flip_pair(Kinv) + samples_dict_next = flip_pair(samples_dict) # forward warp points to camera space xyz_cam_next = self.forward_warp( xyz, field2cam_next, frame_id_next, inst_id, samples_dict=samples_dict_next ) + # xyz_cam_next = self.flow_warp( + # xyz_t, field2cam_next, frame_id, inst_id, samples_dict + # ) # project to next camera image plane Kmat_next = Kmatinv(Kinv_next) # (M,1,1,3,3) @ (M,N,D,3) = (M,N,D,3) @@ -1004,3 +1179,63 @@ def cam_prior_loss(self): """ loss = self.camera_mlp.compute_distance_to_prior() return loss + + def get_camera(self, frame_id=None): + """Compute camera matrices in world units + + Returns: + field2cam (Dict): Maps field names ("fg" or "bg") to (M,4,4) cameras + """ + quat, trans = self.camera_mlp.get_vals(frame_id=frame_id) + trans = trans / self.logscale.exp() + field2cam = quaternion_translation_to_se3(quat, trans) + return field2cam + + def compute_field2world(self, up_direction=[0, -1, 0]): + """Compute SE(3) to transform points in the scene 
space to world space + For background, this is computed by detecting planes with ransac. + + Returns: + rect_se3: (4,4) SE(3) transform + """ + for inst_id in range(self.num_inst): + # TODO: move this to background nerf, and use each proxy geometry + mesh = self.extract_canonical_mesh(level=0.0, inst_id=inst_id) + self.field2world[inst_id] = compute_rectification_se3(mesh, up_direction) + + def get_field2world(self, inst_id=None): + """Compute SE(3) to transform points in the scene space to world space + For background, this is computed by detecting planes with ransac. + + Returns: + rect_se3: (4,4) SE(3) transform + """ + if inst_id is None: + field2world = self.field2world + else: + field2world = self.field2world[inst_id] + field2world = field2world.clone() + field2world[..., :3, 3] /= self.logscale.exp() + return field2world + + @torch.no_grad() + def visualize_floor_mesh(self, inst_id, to_world=False): + """Visualize floor and canonical mesh in the world space + Args: + inst_id: (int) Instance id + """ + field2world = self.get_field2world(inst_id) + world2field = field2world.inverse().cpu() + mesh = self.extract_canonical_mesh(level=0.0, inst_id=inst_id) + scale = self.logscale.exp().cpu().numpy() + mesh.vertices /= scale + mesh = append_xz_plane(mesh, world2field, gl=False, scale=20 * scale) + if to_world: + mesh.apply_transform(field2world.cpu().numpy()) + return mesh + + def valid_field2world(self): + if self.field2world.abs().sum() == 0: + return False + else: + return True diff --git a/lab4d/nnutils/pose.py b/lab4d/nnutils/pose.py index b0dad96..394d859 100644 --- a/lab4d/nnutils/pose.py +++ b/lab4d/nnutils/pose.py @@ -7,26 +7,34 @@ from lab4d.nnutils.base import CondMLP, BaseMLP, ScaleLayer from lab4d.nnutils.time import TimeMLP -from lab4d.utils.geom_utils import so3_to_exp_map +from lab4d.nnutils.embedding import TimeEmbedding +from lab4d.utils.geom_utils import ( + so3_to_exp_map, + rot_angle, + interpolate_slerp, + interpolate_linear, +) from lab4d.utils.quat_transform import ( axis_angle_to_quaternion, matrix_to_quaternion, quaternion_mul, quaternion_translation_to_dual_quaternion, - dual_quaternion_to_quaternion_translation, + dual_quaternion_mul, quaternion_translation_to_se3, + dual_quaternion_to_quaternion_translation, ) from lab4d.utils.skel_utils import ( fk_se3, get_predefined_skeleton, rest_joints_to_local, shift_joints_to_bones_dq, - shift_joints_to_bones, + apply_root_offset, ) from lab4d.utils.vis_utils import draw_cams +from lab4d.utils.torch_utils import reinit_model -class CameraMLP(TimeMLP): +class CameraMLP_old(TimeMLP): """Encode camera pose over time (rotation + translation) with an MLP Args: @@ -80,9 +88,11 @@ def __init__( # camera pose: field to camera self.base_quat = nn.Parameter(torch.zeros(self.time_embedding.num_vids, 4)) + self.base_trans = nn.Parameter(torch.zeros(self.time_embedding.num_vids, 3)) self.register_buffer( "init_vals", torch.tensor(rtmat, dtype=torch.float32), persistent=False ) + self.base_init() # override the loss function def loss_fn(gt): @@ -100,10 +110,10 @@ def base_init(self): base_rmat = rtmat[frame_offset[:-1], :3, :3] base_quat = matrix_to_quaternion(base_rmat) self.base_quat.data = base_quat + self.base_trans.data = rtmat[frame_offset[:-1], :3, 3] def mlp_init(self): """Initialize camera SE(3) transforms from external priors""" - self.base_init() super().mlp_init() # with torch.no_grad(): @@ -141,14 +151,328 @@ def get_vals(self, frame_id=None): if frame_id is None: inst_id = self.time_embedding.frame_to_vid else: - 
inst_id = self.time_embedding.raw_fid_to_vid[frame_id] + inst_id = self.time_embedding.raw_fid_to_vid[frame_id.long()] # multiply with per-instance base rotation base_quat = self.base_quat[inst_id] base_quat = F.normalize(base_quat, dim=-1) quat = quaternion_mul(quat, base_quat) + + base_trans = self.base_trans[inst_id] + trans = trans + base_trans return quat, trans +class CameraMLP(TimeMLP): + """Encode camera pose over time (rotation + translation) with an MLP + + Args: + rtmat: (N,4,4) Object to camera transform + frame_info (Dict): Metadata about the frames in a dataset + D (int): Number of linear layers + W (int): Number of hidden units in each MLP layer + num_freq_t (int): Number of frequencies in time Fourier embedding + skips (List(int)): List of layers to add skip connections at + activation (Function): Activation function to use (e.g. nn.ReLU()) + """ + + def __init__( + self, + rtmat, + frame_info=None, + D=2, + W=256, + num_freq_t=6, + skips=[], + activation=nn.ReLU(True), + ): + if frame_info is None: + num_frames = len(rtmat) + frame_info = { + "frame_offset": np.asarray([0, num_frames]), + "frame_mapping": list(range(num_frames)), + "frame_offset_raw": np.asarray([0, num_frames]), + } + # xyz encoding layers + super().__init__( + frame_info, + D=D, + W=W, + num_freq_t=num_freq_t, + skips=skips, + activation=activation, + ) + + self.time_embedding_rot = TimeEmbedding( + num_freq_t, + frame_info, + out_channels=W, + time_scale=1, + ) + + self.base_rot = BaseMLP( + D=D, + W=W, + in_channels=W, + out_channels=W, + skips=skips, + activation=activation, + final_act=True, + ) + + # output layers + self.trans = nn.Sequential( + nn.Linear(W, W // 2), + activation, + nn.Linear(W // 2, 3), + ) + self.quat = nn.Sequential( + nn.Linear(W, W // 2), + activation, + nn.Linear(W // 2, 4), + ) + + # camera pose: field to camera + self.base_quat = nn.Parameter(torch.zeros(self.time_embedding.num_vids, 4)) + self.register_buffer( + "init_vals", torch.tensor(rtmat, dtype=torch.float32), persistent=False + ) + self.base_init() + + # override the loss function + def loss_fn(gt): + quat, trans = self.get_vals() + pred = quaternion_translation_to_se3(quat, trans) + loss = F.mse_loss(pred, gt) + return loss + + self.loss_fn = loss_fn + + def base_init(self): + """Initialize base camera rotations from initial camera trajectory""" + rtmat = self.init_vals + frame_offset = self.get_frame_offset() + base_rmat = rtmat[frame_offset[:-1], :3, :3] + base_quat = matrix_to_quaternion(base_rmat) + self.base_quat.data = base_quat + + def mlp_init(self): + """Initialize camera SE(3) transforms from external priors""" + super().mlp_init() + + # with torch.no_grad(): + # os.makedirs("tmp", exist_ok=True) + # draw_cams(rtmat.cpu().numpy()).export("tmp/cameras_gt.obj") + # quat, trans = self.get_vals() + # rtmat_pred = quaternion_translation_to_se3(quat, trans) + # draw_cams(rtmat_pred.cpu()).export("tmp/cameras_pred.obj") + + def forward(self, t_embed, t_embed_rot): + """ + Args: + t_embed: (M, self.W) Input Fourier time embeddings + Returns: + quat: (M, 4) Output camera rotation quaternions + trans: (M, 3) Output camera translations + """ + t_feat = super().forward(t_embed) + trans = self.trans(t_feat) + quat = self.quat(self.base_rot(t_embed_rot)) + quat = F.normalize(quat, dim=-1) + return quat, trans + + def get_vals(self, frame_id=None): + """Compute camera pose at the given frames. + + Args: + frame_id: (M,) Frame id. 
If None, compute values at all frames + Returns: + quat: (M, 4) Output camera rotations + trans: (M, 3) Output camera translations + """ + t_embed = self.time_embedding(frame_id) + t_embed_rot = self.time_embedding_rot(frame_id) + quat, trans = self.forward(t_embed, t_embed_rot) + if frame_id is None: + inst_id = self.time_embedding.frame_to_vid + else: + inst_id = self.time_embedding.raw_fid_to_vid[frame_id.long()] + + # multiply with per-instance base rotation + base_quat = self.base_quat[inst_id] + base_quat = F.normalize(base_quat, dim=-1) + quat = quaternion_mul(quat, base_quat) + return quat, trans + + +class CameraMLP_so3(TimeMLP): + """Encode camera pose over time (rotation + translation) with an MLP + + Args: + rtmat: (N,4,4) Object to camera transform + frame_info (Dict): Metadata about the frames in a dataset + D (int): Number of linear layers + W (int): Number of hidden units in each MLP layer + num_freq_t (int): Number of frequencies in time Fourier embedding + skips (List(int)): List of layers to add skip connections at + activation (Function): Activation function to use (e.g. nn.ReLU()) + """ + + def __init__( + self, + rtmat, + frame_info=None, + D=5, + W=256, + num_freq_t=6, + skips=[], + activation=nn.ReLU(True), + ): + if frame_info is None: + num_frames = len(rtmat) + frame_info = { + "frame_offset": np.asarray([0, num_frames]), + "frame_mapping": list(range(num_frames)), + "frame_offset_raw": np.asarray([0, num_frames]), + } + # xyz encoding layers + super().__init__( + frame_info, + D=D, + W=W, + num_freq_t=num_freq_t, + skips=skips, + activation=activation, + ) + + self.time_embedding_rot = TimeEmbedding( + num_freq_t, + frame_info, + out_channels=W, + time_scale=1, + ) + + self.base_rot = BaseMLP( + D=D, + W=W, + in_channels=W, + out_channels=W, + skips=skips, + activation=activation, + final_act=True, + ) + + # output layers + self.trans = nn.Sequential( + nn.Linear(W, 3), + ) + self.so3 = nn.Sequential( + nn.Linear(W, 3), + ) + + # camera pose: field to camera + base_quat = torch.zeros(frame_info["frame_offset"][-1], 4) + base_trans = torch.zeros(frame_info["frame_offset"][-1], 3) + self.register_buffer("base_quat", base_quat) + self.register_buffer("base_trans", base_trans) + self.register_buffer( + "init_vals", torch.tensor(rtmat, dtype=torch.float32), persistent=False + ) + self.base_init() + + # override the loss function + def loss_fn(gt): + quat, trans = self.get_vals() + pred = quaternion_translation_to_se3(quat, trans) + loss = F.mse_loss(pred, gt) + return loss + + self.loss_fn = loss_fn + + def base_init(self): + """Initialize base camera rotations from initial camera trajectory""" + rtmat = self.init_vals + + # initialize with corresponding frame rotation + # self.base_quat.data = matrix_to_quaternion(rtmat[:, :3, :3]) + self.base_trans.data = rtmat[:, :3, 3] + + # initialize with per-sequence pose + frame_offset = self.get_frame_offset() + for i in range(len(frame_offset) - 1): + base_rmat = rtmat[frame_offset[i], :3, :3] + base_quat = matrix_to_quaternion(base_rmat) + self.base_quat.data[frame_offset[i] : frame_offset[i + 1]] = base_quat + + def mlp_init(self): + """Initialize camera SE(3) transforms from external priors""" + super().mlp_init() + + # with torch.no_grad(): + # os.makedirs("tmp", exist_ok=True) + # draw_cams(rtmat.cpu().numpy()).export("tmp/cameras_gt.obj") + # quat, trans = self.get_vals() + # rtmat_pred = quaternion_translation_to_se3(quat, trans) + # draw_cams(rtmat_pred.cpu()).export("tmp/cameras_pred.obj") + + def forward(self, 
t_embed, t_embed_rot): + """ + Args: + t_embed: (M, self.W) Input Fourier time embeddings + Returns: + quat: (M, 4) Output camera rotation quaternions + trans: (M, 3) Output camera translations + """ + t_feat = super().forward(t_embed) + trans = self.trans(t_feat) + so3 = self.so3(self.base_rot(t_embed_rot)) + quat = axis_angle_to_quaternion(so3) + return quat, trans + + def get_vals(self, frame_id=None): + """Compute camera pose at the given frames. + + Args: + frame_id: (M,) Frame id. If None, compute values at all frames + Returns: + quat: (M, 4) Output camera rotations + trans: (M, 3) Output camera translations + """ + t_embed = self.time_embedding(frame_id) + t_embed_rot = self.time_embedding_rot(frame_id) + quat, trans = self.forward(t_embed, t_embed_rot) + + # multiply with per-instance base rotation + if frame_id is None: + base_quat = self.base_quat + base_trans = self.base_trans + else: + base_quat, base_trans = self.interpolate_base(frame_id) + base_quat = F.normalize(base_quat, dim=-1) + quat = quaternion_mul(quat, base_quat) + trans = trans + base_trans + return quat, trans + + def update_base_quat(self): + """Update base camera rotations from current camera trajectory""" + self.base_quat.data, self.base_trans.data = self.get_vals() + # reinit the mlp head + reinit_model(self.so3, std=0.01) + reinit_model(self.trans, std=0.01) + + def interpolate_base(self, frame_id): + idx = self.time_embedding.frame_mapping_inv[frame_id.long()] + idx_ceil = idx + 1 + idx_ceil.clamp_(max=self.time_embedding.num_frames - 1) + t_len = ( + self.time_embedding.frame_mapping[idx_ceil] + - self.time_embedding.frame_mapping[idx] + ) + t_frac = frame_id - self.time_embedding.frame_mapping[idx] + t_frac = t_frac / (1e-6 + t_len) + base_quat = interpolate_slerp(self.base_quat, idx, idx + 1, t_frac) + base_trans = interpolate_linear(self.base_trans, idx, idx + 1, t_frac) + return base_quat, base_trans + class ArticulationBaseMLP(TimeMLP): """Base class for bone articulation model (bag-of-bones or skeleton) @@ -377,6 +701,8 @@ def __init__( self.logscale = nn.Parameter(torch.zeros(1)) self.shift = nn.Parameter(torch.zeros(3)) + self.register_buffer("orient", torch.tensor([1.0, 0.0, 0.0, 0.0])) + # instance bone length num_inst = len(frame_info["frame_offset"]) - 1 self.log_bone_len = CondMLP( @@ -459,10 +785,14 @@ def forward( local_rest_joints = override_local_rest_joints # run forward kinematics - out = fk_se3(local_rest_joints, so3, self.edges) - out = shift_joints_to_bones_dq(out, self.edges, shift=self.shift) + out = self.fk_se3(local_rest_joints, so3, self.edges) + out = self.shift_joints_to_bones(out) + out = apply_root_offset(out, self.shift, self.orient) return out + def shift_joints_to_bones(self, se3): + return shift_joints_to_bones_dq(se3, self.edges) + def compute_rel_rest_joints(self, inst_id=None, override_log_bone_len=None): """Compute relative position difference from parent to child bone coordinate frames, without scale @@ -475,7 +805,7 @@ def compute_rel_rest_joints(self, inst_id=None, override_log_bone_len=None): rel_rest_joints: Translations from parent to child joints """ # get relative joints - rel_rest_joints = rest_joints_to_local(self.rest_joints, self.edges) + rel_rest_joints = self.rest_joints_to_local(self.rest_joints, self.edges) # match the shape rel_rest_joints = rel_rest_joints[None] @@ -493,6 +823,14 @@ def compute_rel_rest_joints(self, inst_id=None, override_log_bone_len=None): rel_rest_joints = rel_rest_joints * bone_length[..., None] return rel_rest_joints + def 
fk_se3(self, local_rest_joints, so3, edges): + """Forward kinematics for a skeleton""" + return fk_se3(local_rest_joints, so3, edges) + + def rest_joints_to_local(self, rest_joints, edges): + """Convert rest joints to local coordinates""" + return rest_joints_to_local(rest_joints, edges) + def get_vals(self, frame_id=None, return_so3=False, override_so3=None): """Compute articulation parameters at the given frames. @@ -508,7 +846,7 @@ def get_vals(self, frame_id=None, return_so3=False, override_so3=None): if frame_id is None: inst_id = self.time_embedding.frame_to_vid else: - inst_id = self.time_embedding.raw_fid_to_vid[frame_id] + inst_id = self.time_embedding.raw_fid_to_vid[frame_id.long()] t_embed = self.time_embedding(frame_id) pred = self.forward( t_embed, inst_id, return_so3=return_so3, override_so3=override_so3 @@ -580,9 +918,12 @@ def skel_prior_loss(self): loss_so3 = so3.pow(2).mean() # get average log bone length increment + # inst_id = torch.arange(0, self.time_embedding.num_vids).long().to(device) + # empty_feat = torch.zeros_like(inst_id[:, None][:, :0]) # (1, 0) + # log_bone_len_inc = self.log_bone_len(empty_feat, inst_id) empty_feat = torch.zeros_like(so3[..., 0, :0]) # (1, 0) log_bone_len_inc = self.log_bone_len(empty_feat, None) - loss_bone = 0.02 * log_bone_len_inc.pow(2).mean() + loss_bone = 0.2 * log_bone_len_inc.pow(2).mean() loss = loss_so3 + loss_bone @@ -590,11 +931,136 @@ def skel_prior_loss(self): # device = self.parameters().__next__().device # t_embed = self.time_embedding.get_mean_embedding(device) # bones_dq = self.forward(t_embed, None) - # bones_pred = dual_quaternion_to_quaternion_translation(bones_dq)[1][0] # B,3 + # trans_pred, rot_pred = dual_quaternion_to_quaternion_translation(bones_dq)[1] - # joints_gt = self.rest_joints * self.logscale.exp() + self.shift[None] - # bones_gt = shift_joints_to_bones(joints_gt, self.edges) + # bones_dq = self.forward( + # None, + # None, + # override_so3=torch.zeros(1, self.num_se3, 3, device=device), + # override_log_bone_len=torch.zeros(1, self.num_se3, device=device), + # ) + # trans_gt, rot_gt = dual_quaternion_to_quaternion_translation(bones_dq)[1] # B,3 - # loss = (bones_gt - bones_pred).norm(2, -1).mean() - # loss = loss * 0.2 + # loss = (trans_gt - trans_pred).norm(2, -1).mean() + # trimesh.Trimesh(vertices=bones_pred.detach().cpu()).export("tmp/bones_pred.obj") + # trimesh.Trimesh(vertices=bones_gt.detach().cpu()).export("tmp/bones_gt.obj") return loss + + +class ArticulationURDFMLP(ArticulationSkelMLP): + """Encode a skeleton over time using an MLP + + Args: + frame_info (FrameInfo): Metadata about the frames in a dataset + skel_type (str): Skeleton type ("human" or "quad") + joint_angles: (B, 3) If provided, initial joint angles + num_se3 (int): Number of bones + D (int): Number of linear layers + W (int): Number of hidden units in each MLP layer + num_freq_t (int): Number of frequencies in time Fourier embedding + skips (List(int)): List of layers to add skip connections at + activation (Function): Activation function to use (e.g. 
nn.ReLU()) + """ + + def __init__( + self, + frame_info, + skel_type, + joint_angles, + D=5, + W=256, + num_freq_t=6, + skips=[], + activation=nn.ReLU(True), + ): + super().__init__( + frame_info, + skel_type, + joint_angles, + D=D, + W=W, + num_freq_t=num_freq_t, + skips=skips, + activation=activation, + ) + + ( + local_rest_coord, + scale_factor, + orient, + offset, + bone_centers, + bone_sizes, + ) = self.parse_urdf(skel_type) + self.logscale.data = torch.log(scale_factor) + self.shift.data = offset # same scale as object field + self.orient.data = orient + self.register_buffer("bone_centers", bone_centers, persistent=False) + self.register_buffer("bone_sizes", bone_sizes, persistent=False) + + # get local rest rotation matrices, pick the first coordinate in rpy of ball joints + # by default: transform points from child to parent + local_rest_coord = torch.tensor(local_rest_coord, dtype=torch.float32) + self.register_buffer("local_rest_coord", local_rest_coord, persistent=False) + self.rest_joints = None + + def parse_urdf(self, urdf_name): + """Load the URDF file for the skeleton""" + from urdfpy import URDF + + urdf_path = f"projects/ppr/ppr-diffphys/data/urdf_templates/{urdf_name}.urdf" + urdf = URDF.load(urdf_path) + + local_rest_coord = np.stack([i.origin for i in urdf.joints], 0)[::3] + + if urdf_name == "human": + offset = torch.tensor([0.0, 0.0, 0.0]) + orient = torch.tensor([0.0, -1.0, 0.0, 0.0]) # wxyz + scale_factor = torch.tensor([0.1]) + elif urdf_name == "quad": + offset = torch.tensor([0.0, -0.02, 0.02]) + orient = torch.tensor([1.0, -0.8, 0.0, 0.0]) + scale_factor = torch.tensor([0.1]) + else: + raise NotImplementedError + orient = F.normalize(orient, dim=-1) + + # get center/size of each link + bone_centers = [] + bone_sizes = [] + for link in urdf._reverse_topo: + if len(link.visuals) == 0: + continue + bone_bounds = link.collision_mesh.bounds + center = (bone_bounds[1] + bone_bounds[0]) / 2 + size = (bone_bounds[1] - bone_bounds[0]) / 2 + center = torch.tensor(center, dtype=torch.float) + size = torch.tensor(size, dtype=torch.float) + bone_centers.append(center) + bone_sizes.append(size) + + bone_centers = torch.stack(bone_centers, dim=0)[1:] # skip root + bone_sizes = torch.stack(bone_sizes, dim=0)[1:] # skip root + return local_rest_coord, scale_factor, orient, offset, bone_centers, bone_sizes + + def fk_se3(self, local_rest_joints, so3, edges): + return fk_se3( + local_rest_joints, + so3, + edges, + local_rest_coord=self.local_rest_coord.clone(), + ) + + def rest_joints_to_local(self, rest_joints, edges): + return self.local_rest_coord[:, :3, 3].clone() + + def shift_joints_to_bones(self, bone_to_obj): + idn_quat = torch.zeros_like(bone_to_obj[0]) + idn_quat[..., 0] = 1.0 + bone_centers = self.bone_centers.expand_as(idn_quat[..., :3]) + bone_centers = bone_centers * self.logscale.exp().clone() + link_transform = quaternion_translation_to_dual_quaternion( + idn_quat, bone_centers + ) + bone_to_obj = dual_quaternion_mul(bone_to_obj, link_transform) + return bone_to_obj diff --git a/lab4d/nnutils/skinning.py b/lab4d/nnutils/skinning.py index 6227c3d..5b65e37 100644 --- a/lab4d/nnutils/skinning.py +++ b/lab4d/nnutils/skinning.py @@ -10,6 +10,7 @@ from lab4d.utils.quat_transform import ( dual_quaternion_to_quaternion_translation, quaternion_to_matrix, + dual_quaternion_apply, ) from lab4d.utils.transforms import get_bone_coords from lab4d.utils.vis_utils import get_colormap @@ -58,17 +59,22 @@ def __init__( ): super().__init__() - # 3D gaussians - gaussians = init_scale 
* torch.ones( - num_coords, 3 - ) # scale of bone skinning field + # 3D gaussians: scale of bone skinning field + if torch.is_tensor(init_scale): + gaussians = init_scale + else: + gaussians = init_scale * torch.ones(num_coords, 3) + # clip minimum radius to 0.01 + gaussians = torch.clamp(gaussians, min=0.01) self.log_gauss = nn.Parameter(torch.log(gaussians)) + # self.register_buffer("log_gauss", torch.log(gaussians), persistent=False) + self.logscale = nn.Parameter(torch.zeros(1)) self.num_coords = num_coords if delta_skin: # position and direction embedding - self.pos_embedding = PosEmbedding(3 * num_coords, num_freq_xyz) - self.time_embedding = TimeEmbedding(num_freq_t, frame_info) + self.pos_embedding = PosEmbedding(4 * num_coords, num_freq_xyz) + # self.time_embedding = TimeEmbedding(num_freq_t, frame_info) # xyz encoding layers self.delta_field = CondMLP( @@ -76,9 +82,10 @@ def __init__( D=D, W=W, in_channels=self.pos_embedding.out_channels - + self.time_embedding.out_channels, + # + self.time_embedding.out_channels, + , inst_channels=inst_channels, - out_channels=num_coords, + out_channels=num_coords * 2, skips=skips, activation=activation, final_act=False, @@ -105,18 +112,24 @@ def forward(self, xyz, bone2obj, frame_id, inst_id): dist2 = xyz_bone.pow(2).sum(dim=-1) if hasattr(self, "delta_field"): + xyz_bone = torch.cat([xyz_bone, xyz_bone.norm(2, -1)[..., None]], dim=-1) # modulate with t/inst xyz_embed = self.pos_embedding(xyz_bone.reshape(xyz.shape[:-1] + (-1,))) - if frame_id is None: - t_embed = self.time_embedding.get_mean_embedding(xyz.device) - else: - t_embed = self.time_embedding(frame_id) - t_embed = t_embed.reshape(-1, 1, 1, t_embed.shape[-1]) - t_embed = t_embed.expand(xyz.shape[:-1] + (-1,)) - xyzt_embed = torch.cat([xyz_embed, t_embed], dim=-1) - delta = self.delta_field(xyzt_embed, inst_id) - delta = F.relu(delta) * 0.1 - skin = -(dist2 + delta) + # if frame_id is None: + # t_embed = self.time_embedding.get_mean_embedding(xyz.device) + # else: + # t_embed = self.time_embedding(frame_id) + # t_embed = t_embed.reshape(-1, 1, 1, t_embed.shape[-1]) + # t_embed = t_embed.expand(xyz.shape[:-1] + (-1,)) + # xyzt_embed = torch.cat([xyz_embed, t_embed], dim=-1) + # delta = self.delta_field(xyzt_embed, inst_id) + delta = self.delta_field(xyz_embed, inst_id) + # delta = F.relu(delta) * 0.1 + # skin = -(dist2 + delta) + logscale, shift = torch.split(delta, delta.shape[-1] // 2, dim=-1) + dist2 = dist2 * (0.1 * logscale).exp() + dist2 = dist2 + 0.1 * shift + skin = -dist2 else: skin = -dist2 delta = None @@ -150,9 +163,35 @@ def get_gauss(self): log_gauss = self.log_gauss if self.symm_idx is not None: log_gauss = (log_gauss[self.symm_idx] + log_gauss) / 2 + log_gauss = log_gauss + self.logscale return log_gauss.exp() - def draw_gaussian(self, articulation, edges): + def get_gauss_pts(self, articulation): + """ + Compute gaussian points (differentiable wrt articulation) + Args: + articulation: ((B,4), (B,4)) Bone-to-object SE(3) transforms, + written as dual quaternions + """ + dev = articulation[0].device + gaussians = self.get_gauss() # B,3 + + # append gaussians + sph = trimesh.creation.uv_sphere(radius=1, count=[4, 4]) + pts = torch.tensor(sph.vertices, device=dev, dtype=torch.float32) + pts = pts[:, None] * gaussians[None] # N,B,3 + + # apply articulation + articulation = ( + articulation[0][None].repeat(pts.shape[0], 1, 1), + articulation[1][None].repeat(pts.shape[0], 1, 1), + ) + pts = dual_quaternion_apply(articulation, pts) # N,B,3 + pts = pts.view(-1, 3) # NB,3 + 
return pts + + @torch.no_grad() + def draw_gaussian(self, articulation, edges, show_joints=False): """Visualize Gaussian bones as a mesh Args: @@ -161,41 +200,42 @@ def draw_gaussian(self, articulation, edges): edges (Dict(int, int) or None): If given, a mapping from each joint to its parent joint on an articulated skeleton """ - with torch.no_grad(): - meshes = [] - gaussians = self.get_gauss().cpu().numpy() - - qr, trans = dual_quaternion_to_quaternion_translation(articulation) - articulation = np.eye(4, 4)[None].repeat(len(qr), axis=0) - articulation[:, :3, :3] = quaternion_to_matrix(qr).cpu().numpy() - articulation[:, :3, 3] = trans.cpu().numpy() - - # add bone center / joints - sph = trimesh.creation.uv_sphere(radius=1, count=[4, 4]) - colormap = get_colormap(self.num_coords, repeat=sph.vertices.shape[0]) - for k, gauss in enumerate(gaussians): - ellips = sph.copy() - # make it smaller for visualization + meshes = [] + gaussians = self.get_gauss().cpu().numpy() + + qr, trans = dual_quaternion_to_quaternion_translation(articulation) + articulation = np.eye(4, 4)[None].repeat(len(qr), axis=0) + articulation[:, :3, :3] = quaternion_to_matrix(qr).cpu().numpy() + articulation[:, :3, 3] = trans.cpu().numpy() + + # add bone center / joints + sph = trimesh.creation.uv_sphere(radius=1, count=[4, 4]) + colormap = get_colormap(self.num_coords, repeat=sph.vertices.shape[0]) + for k, gauss in enumerate(gaussians): + ellips = sph.copy() + # make it smaller for visualization + if show_joints: ellips.vertices *= 5e-3 - # ellips.vertices *= gauss[None] - ellips.apply_transform(articulation[k]) - meshes.append(ellips) - - # add edges if any - if edges is not None: - # rad = gaussians.mean() * 0.1 - rad = 5e-4 - for idx, parent_idx in edges.items(): - if parent_idx == 0: - continue - parent_center = articulation[parent_idx - 1][:3, 3] - child_center = articulation[idx - 1][:3, 3] - cyl = np.stack([parent_center, child_center], 0) - cyl = trimesh.creation.cylinder(rad, segment=cyl, sections=3) - meshes.append(cyl) - - meshes = trimesh.util.concatenate(meshes) - colormap_pad = np.ones((meshes.vertices.shape[0] - colormap.shape[0], 3)) - colormap = np.concatenate([colormap, 192 * colormap_pad], 0) - meshes.visual.vertex_colors = colormap - return meshes + else: + ellips.vertices *= gauss[None] + ellips.apply_transform(articulation[k]) + meshes.append(ellips) + + # add edges if any + if edges is not None: + # rad = gaussians.mean() * 0.1 + rad = 5e-4 + for idx, parent_idx in edges.items(): + if parent_idx == 0: + continue + parent_center = articulation[parent_idx - 1][:3, 3] + child_center = articulation[idx - 1][:3, 3] + cyl = np.stack([parent_center, child_center], 0) + cyl = trimesh.creation.cylinder(rad, segment=cyl, sections=3) + meshes.append(cyl) + + meshes = trimesh.util.concatenate(meshes) + colormap_pad = np.ones((meshes.vertices.shape[0] - colormap.shape[0], 3)) + colormap = np.concatenate([colormap, 192 * colormap_pad], 0) + meshes.visual.vertex_colors = colormap + return meshes diff --git a/lab4d/nnutils/time.py b/lab4d/nnutils/time.py index 4f4b05a..f66f79f 100644 --- a/lab4d/nnutils/time.py +++ b/lab4d/nnutils/time.py @@ -5,7 +5,12 @@ import torch.nn.functional as F from lab4d.nnutils.base import BaseMLP -from lab4d.nnutils.embedding import PosEmbedding, TimeEmbedding, get_fourier_embed_dim +from lab4d.nnutils.embedding import ( + PosEmbedding, + TimeEmbedding, + TimeEmbeddingRest, + get_fourier_embed_dim, +) class TimeMLP(BaseMLP): @@ -31,7 +36,12 @@ def __init__( skips=[], 
activation=nn.ReLU(True), time_scale=1.0, + bottleneck_dim=None, + has_rest=False, ): + if bottleneck_dim is None: + bottleneck_dim = W + frame_offset = frame_info["frame_offset"] # frame_offset_raw = frame_info["frame_offset_raw"] if num_freq_t > 0: @@ -46,15 +56,20 @@ def __init__( super().__init__( D=D, W=W, - in_channels=W, + in_channels=bottleneck_dim, out_channels=W, skips=skips, activation=activation, final_act=True, ) - self.time_embedding = TimeEmbedding( - num_freq_t, frame_info, out_channels=W, time_scale=time_scale + if has_rest: + arch = TimeEmbeddingRest + else: + arch = TimeEmbedding + + self.time_embedding = arch( + num_freq_t, frame_info, out_channels=bottleneck_dim, time_scale=time_scale ) def loss_fn(y): diff --git a/lab4d/nnutils/visibility.py b/lab4d/nnutils/visibility.py index 437046c..f70e477 100644 --- a/lab4d/nnutils/visibility.py +++ b/lab4d/nnutils/visibility.py @@ -31,6 +31,7 @@ def __init__( inst_channels=32, skips=[4], activation=nn.ReLU(True), + field_arch=CondMLP, ): super().__init__() @@ -38,7 +39,7 @@ def __init__( self.pos_embedding = PosEmbedding(3, num_freq_xyz) # xyz encoding layers - self.basefield = CondMLP( + self.basefield = field_arch( num_inst=num_inst, D=D, W=W, diff --git a/lab4d/nnutils/warping.py b/lab4d/nnutils/warping.py index 0361c63..340ad11 100644 --- a/lab4d/nnutils/warping.py +++ b/lab4d/nnutils/warping.py @@ -6,13 +6,23 @@ from lab4d.nnutils.base import CondMLP from lab4d.nnutils.embedding import PosEmbedding, TimeEmbedding -from lab4d.nnutils.pose import ArticulationFlatMLP, ArticulationSkelMLP +from lab4d.nnutils.pose import ( + ArticulationFlatMLP, + ArticulationSkelMLP, + ArticulationURDFMLP, +) from lab4d.nnutils.skinning import SkinningField from lab4d.third_party.nvp import NVP -from lab4d.utils.geom_utils import dual_quaternion_skinning, marching_cubes, extend_aabb +from lab4d.utils.geom_utils import ( + dual_quaternion_skinning, + marching_cubes, + extend_aabb, + linear_blend_skinning, +) from lab4d.utils.quat_transform import dual_quaternion_inverse, dual_quaternion_mul from lab4d.utils.transforms import get_xyz_bone_distance, get_bone_coords from lab4d.utils.loss_utils import entropy_loss, cross_entropy_skin_loss +from lab4d.utils.torch_utils import flip_pair def create_warp(fg_motion, data_info): @@ -41,7 +51,13 @@ def create_warp(fg_motion, data_info): elif fg_motion.startswith("skel-"): warp = SkinningWarp( frame_info, - skel_type=fg_motion.split("-")[1], + skel_type=fg_motion, + joint_angles=joint_angles, + ) + elif fg_motion.startswith("urdf-"): + warp = SkinningWarp( + frame_info, + skel_type=fg_motion, joint_angles=joint_angles, ) elif fg_motion.startswith("comp"): @@ -71,14 +87,14 @@ def __init__(self, frame_info, num_freq_xyz=10, num_freq_t=6): self.num_inst = len(frame_info["frame_offset"]) - 1 def forward( - self, xyz, frame_id, inst_id, backward=False, samples_dict={}, return_aux=False + self, xyz, frame_id, inst_id, type="forward", samples_dict={}, return_aux=False ): """ Args: xyz: (M,N,D,3) Points in object canonical space frame_id: (M,) Frame id. If None, warp for all frames inst_id: (M,) Instance id. 
If None, warp for the average instance - backward (bool): Forward (=> deformed) or backward (=> canonical) + type (str): Forward (=> deformed), backward (=> canonical), or flow (t1=>t2) samples_dict (Dict): Only used for SkeletonWarp Returns: xyz: (M,N,D,3) Warped xyz coordinates @@ -141,14 +157,14 @@ def __init__(self, frame_info, num_freq_xyz=6, num_freq_t=6, D=6, W=256): ) def forward( - self, xyz, frame_id, inst_id, backward=False, samples_dict={}, return_aux=False + self, xyz, frame_id, inst_id, type="forward", samples_dict={}, return_aux=False ): """ Args: xyz: (M,N,D,3) Points in object canonical space frame_id: (M,) Frame id. If None, warp for all frames inst_id: (M,) Instance id. If None, warp for the average instance - backward (bool): Forward (=> deformed) or backward (=> canonical) + type (str): Forward (=> deformed), backward (=> canonical), or flow (t1=>t2) samples_dict (Dict): Only used for SkeletonWarp Returns: xyz: (M,N,D,3) Warped xyz coordinates @@ -158,11 +174,25 @@ def forward( t_embed = t_embed.reshape(-1, 1, 1, t_embed.shape[-1]) t_embed = t_embed.expand(xyz.shape[:-1] + (-1,)) embed = torch.cat([xyz_embed, t_embed], dim=-1) - if backward: + if type == "backward": motion = self.backward_map(embed, inst_id) - else: + out = xyz + motion * 0.1 # control the scale + elif type == "forward": motion = self.forward_map(embed, inst_id) - out = xyz + motion * 0.1 # control the scale + out = xyz + motion * 0.1 # control the scale + elif type == "flow": + # TODO: use dx/dt to compute flow + raise NotImplementedError + # motion = self.backward_map(embed, inst_id) + # xyz_canonical = xyz + motion * 0.1 + # xyz_canonical_embed = self.pos_embedding(xyz_canonical) + # t_embed_flip = flip_pair(t_embed) + # embed_flip = torch.cat([xyz_canonical_embed, t_embed_flip], dim=-1) + # motion = self.forward_map(embed_flip, inst_id) + # out = xyz_canonical + motion * 0.1 # control the scale + else: + raise NotImplementedError + warp_dict = {} if return_aux: return out, warp_dict @@ -199,14 +229,14 @@ def __init__(self, frame_info, num_freq_xyz=6, num_freq_t=6, D=2): ) def forward( - self, xyz, frame_id, inst_id, backward=False, samples_dict={}, return_aux=False + self, xyz, frame_id, inst_id, type="forward", samples_dict={}, return_aux=False ): """ Args: xyz: (M,N,D,3) Points in object canonical space frame_id: (M) Frame id. If None, warp for all frames inst_id: (M) Instance id. 
If None, warp for the average instance - backward (bool): Forward (=> deformed) or backward (=> canonical) + type (str): Forward (=> deformed), backward (=> canonical), or flow (t1=>t2) samples_dict (Dict): Only used for SkeletonWarp Returns: out: (..., 3) Warped xyz coordinates @@ -215,10 +245,18 @@ def forward( t_embed = t_embed.reshape(-1, 1, 1, t_embed.shape[-1]) t_embed = t_embed.expand(xyz.shape[:-1] + (-1,)) # (M, N, D, x) t_embed = t_embed[:, 0] # (M, D, x) vs (M, N, D, 3) - if backward: + if type == "backward": out = self.map.inverse(t_embed, xyz) - else: + elif type == "forward": out = self.map.forward(t_embed, xyz) + elif type == "flow": + # TODO: use dx/dt to compute flow + raise NotImplementedError + # out = self.map.inverse(t_embed, xyz) + # t_embed_flip = flip_pair(t_embed) + # out = self.map.forward(t_embed_flip, out) + else: + raise NotImplementedError warp_dict = {} if return_aux: return out, warp_dict @@ -257,10 +295,21 @@ def __init__( if skel_type == "flat": self.articulation = ArticulationFlatMLP(frame_info, num_se3) symm_idx = None - else: + elif skel_type.startswith("skel-"): + skel_type = skel_type.split("-")[1] self.articulation = ArticulationSkelMLP(frame_info, skel_type, joint_angles) num_se3 = self.articulation.num_se3 symm_idx = self.articulation.symm_idx + elif skel_type.startswith("urdf-"): + skel_type = skel_type.split("-")[1] + self.articulation = ArticulationURDFMLP(frame_info, skel_type, joint_angles) + num_se3 = self.articulation.num_se3 + symm_idx = self.articulation.symm_idx + init_gauss_scale = ( + self.articulation.bone_sizes * self.articulation.logscale.exp() + ) + else: + raise NotImplementedError self.skinning_model = SkinningField( num_se3, @@ -275,7 +324,7 @@ def __init__( self.logibeta = nn.Parameter(-beta.log()) # beta: transparency def forward( - self, xyz, frame_id, inst_id, backward=False, samples_dict={}, return_aux=False + self, xyz, frame_id, inst_id, type="forward", samples_dict={}, return_aux=False ): """Warp points according to a skinning field and articulated bones @@ -283,7 +332,7 @@ def forward( xyz: (M,N,D,3) Points in object canonical space frame_id: (M,) Frame id. If None, warp for all frames inst_id: (M,) Instance id. If None, warp for the mean instance - backward (bool): Forward (=> deformed) or backward (=> canonical) + type (str): Forward (=> deformed), backward (=> canonical), or flow (t1=>t2) samples_dict: Time-dependent bone articulations. 
Keys: "rest_articulation": ((M,B,4), (M,B,4)) and "t_articulation": ((M,B,4), (M,B,4)) @@ -301,17 +350,25 @@ def forward( ) = self.articulation.get_vals_and_mean(frame_id) # compute per bone se3 - if backward: + if type == "backward": se3 = dual_quaternion_mul( rest_articulation, dual_quaternion_inverse(t_articulation) ) articulation = t_articulation - else: + elif type == "forward": se3 = dual_quaternion_mul( t_articulation, dual_quaternion_inverse(rest_articulation) ) articulation = rest_articulation frame_id = None + elif type == "flow": + t_articulation_flip = flip_pair(t_articulation) + se3 = dual_quaternion_mul( + t_articulation_flip, dual_quaternion_inverse(t_articulation) + ) + articulation = t_articulation + else: + raise NotImplementedError articulation = ( articulation[0][:, None, None].expand(xyz.shape[:3] + (-1, -1)), @@ -320,13 +377,22 @@ def forward( # skinning weights skin, delta_skin = self.skinning_model(xyz, articulation, frame_id, inst_id) + # # debug: hard selection + # max_bone = 3 + # topk, indices = skin.topk(max_bone, 3, largest=True) + # skin = torch.zeros_like(skin).fill_(-torch.inf) + # skin = skin.scatter(3, indices, topk) skin_prob = skin.softmax(-1) # left-multiply per-point se3 out = dual_quaternion_skinning(se3, xyz, skin_prob) + # # linear blend skinning + # out = linear_blend_skinning(se3, xyz, skin_prob) + warp_dict = {} - warp_dict["skin_entropy"] = cross_entropy_skin_loss(skin)[..., None] + # warp_dict["skin_entropy"] = cross_entropy_skin_loss(skin)[..., None] + warp_dict["skin_entropy"] = torch.zeros_like(skin[..., :1]) # TODO: remove if delta_skin is not None: # (M, N, D, 1) warp_dict["delta_skin"] = delta_skin.pow(2).mean(-1, keepdims=True) @@ -352,6 +418,27 @@ def get_gauss_sdf(self, xyz, bias=0.0): sdf = sdf + bias return sdf + def get_xyz_bone_distance(self, xyz, bone2obj=None): + """ + Args: + xyz: (N, 3) Points in object canonical space + bone2obj: ((1/N,B,4), (M,B,4)) Bone-to-object SE(3) transforms, + Returns: + dist2: (N,B) Squared distance to each bone + """ + if isinstance(self.articulation, ArticulationURDFMLP): + # gauss bones + skinning + xyz = xyz[:, None, None] # (N,1,1,3) + bone2obj = ( + bone2obj[0][:, None, None].expand(xyz.shape[0], -1, -1, -1, -1), + bone2obj[1][:, None, None].expand(xyz.shape[0], -1, -1, -1, -1), + ) # (N,1,1,K,4) + dist2 = -self.skinning_model.forward(xyz, bone2obj, None, None)[0][:, 0, 0] + else: + dist2 = get_xyz_bone_distance(xyz, bone2obj) # N,K + dist2 = dist2 / (0.01) ** 2 # assuming spheres of radius 0.01 + return dist2 + def get_gauss_density(self, xyz, bone2obj=None): """Sample volumetric density at Gaussian bones @@ -364,45 +451,54 @@ def get_gauss_density(self, xyz, bone2obj=None): """ if bone2obj is None: bone2obj = self.articulation.get_mean_vals() # 1,K,4,4 - - dist2 = get_xyz_bone_distance(xyz, bone2obj) # N,K - dist2 = dist2 / (0.01) ** 2 # assuming spheres of radius 0.01 - - # # gauss bones - # xyz = xyz[:, None, None] # (N,1,1,3) - # bone2obj = ( - # bone2obj[0][None, None].repeat(xyz.shape[0], 1, 1, 1, 1), - # bone2obj[1][None, None].repeat(xyz.shape[0], 1, 1, 1, 1), - # ) # (N,1,1,K,4) - # dist2 = -self.skinning_model.forward( - # xyz, bone2obj, None, None, normalize=False - # )[0][:, 0, 0] - + dist2 = self.get_xyz_bone_distance(xyz, bone2obj) # (N,K) score = (-0.5 * dist2).exp() # (N,K) - # hard selection - density = score.max(-1)[0] # (N,) + # # hard selection + # density = score.max(-1)[0] # (N,) + + # soften the density selection + mix_prob = F.softmax(-dist2, -1) + density = (score * 
mix_prob).sum(-1) density = density[..., None] return density - def get_template_vis(self, aabb): - """Visualize Gaussian density and SDF as meshes. + def get_gauss_pts(self): + """Sample points from Gaussian bones""" + articulation = self.articulation.get_mean_vals() # (1,K,4,4) + articulation = (articulation[0][0], articulation[1][0]) + pts = self.skinning_model.get_gauss_pts(articulation) + return pts + + def get_gauss_vis(self, show_joints=False): + """Visualize Gaussians as meshes. Args: aabb: (2,3) Axis-aligned bounding box Returns: mesh_gauss (Trimesh): Gaussian density mesh - mesh_sdf (Trimesh): SDF mesh """ articulation = self.articulation.get_mean_vals() # (1,K,4,4) articulation = (articulation[0][0], articulation[1][0]) mesh_gauss = self.skinning_model.draw_gaussian( - articulation, self.articulation.edges + articulation, self.articulation.edges, show_joints=show_joints ) + return mesh_gauss + + def get_template_vis(self, aabb): + """Visualize Gaussian density and SDF as meshes. + + Args: + aabb: (2,3) Axis-aligned bounding box + Returns: + mesh_gauss (Trimesh): Gaussian density mesh + mesh_sdf (Trimesh): SDF mesh + """ + mesh_gauss = self.get_gauss_vis() sdf_func = lambda xyz: self.get_gauss_sdf(xyz) - mesh_sdf = marching_cubes(sdf_func, aabb, level=0.005) + mesh_sdf = marching_cubes(sdf_func, aabb, level=0.0) return mesh_gauss, mesh_sdf @@ -427,14 +523,14 @@ def __init__( # e.g., comp_skel-human_dense, limited to skel+another type of field type_list = warp_type.split("_")[1:] assert len(type_list) == 2 - assert type_list[0] in ["skel-human", "skel-quad"] + assert type_list[0] in ["skel-human", "skel-quad", "urdf-human", "urdf-quad"] assert type_list[1] in ["bob", "dense"] if type_list[1] == "bob": raise NotImplementedError super().__init__( frame_info, - skel_type=type_list[0].split("-")[1], + skel_type=type_list[0], joint_angles=joint_angles, ) # self.post_warp = DenseWarp(frame_info, D=2, W=64) @@ -443,7 +539,7 @@ def __init__( # self.post_warp = NVPWarp(frame_info) def forward( - self, xyz, frame_id, inst_id, backward=False, samples_dict={}, return_aux=False + self, xyz, frame_id, inst_id, type="forward", samples_dict={}, return_aux=False ): """Warp points according to a skinning field and articulated bones @@ -451,7 +547,7 @@ def forward( xyz: (M,N,D,3) Points in object canonical space frame_id: (M,) Frame id. If None, warp for all frames inst_id: (M,) Instance id. If None, warp for the mean instance - backward (bool): Forward (=> deformed) or backward (=> canonical) + type (str): Forward (=> deformed), backward (=> canonical), or flow (t1=>t2) samples_dict: Time-dependent bone articulations. 
Keys: "rest_articulation": ((M,B,4), (M,B,4)) and "t_articulation": ((M,B,4), (M,B,4)) @@ -459,23 +555,25 @@ def forward( out: (M,N,D,3) Warped xyz coordinates """ # if forward, and has frame_id - if not backward and frame_id is not None: + if type == "flow": + raise NotImplementedError + if (type == "forward" or type == "flow") and frame_id is not None: xyz = self.post_warp.forward( - xyz, frame_id, inst_id, backward=False, samples_dict=samples_dict + xyz, frame_id, inst_id, type="forward", samples_dict=samples_dict ) out, warp_dict = super().forward( xyz, frame_id, inst_id, - backward=backward, + type=type, samples_dict=samples_dict, return_aux=True, ) - if backward and frame_id is not None: + if (type == "backward" or type == "flow") and frame_id is not None: out = self.post_warp.forward( - out, frame_id, inst_id, backward=True, samples_dict=samples_dict + out, frame_id, inst_id, type="backward", samples_dict=samples_dict ) if return_aux: return out, warp_dict @@ -493,11 +591,11 @@ def compute_post_warp_dist2(self, xyz, frame_id, inst_id): Returns: dist2: (M, ...) Squared soft deformation distance """ - xyz_t = self.post_warp.forward(xyz, frame_id, inst_id, backward=False) + xyz_t = self.post_warp.forward(xyz, frame_id, inst_id, type="forward") dist2 = (xyz_t - xyz).pow(2).sum(-1) # additional cycle consistency regularization for soft deformation if isinstance(self.post_warp, DenseWarp): - xyz_back = self.post_warp.forward(xyz_t, frame_id, inst_id, backward=True) + xyz_back = self.post_warp.forward(xyz_t, frame_id, inst_id, type="backward") dist2 = (dist2 + (xyz_t - xyz_back).pow(2).sum(-1)) * 0.5 return dist2 diff --git a/lab4d/render.py b/lab4d/render.py index b2aea7a..f41480c 100644 --- a/lab4d/render.py +++ b/lab4d/render.py @@ -74,6 +74,9 @@ def construct_batch_from_opts(opts, model, data_info): frameid_start = data_info["frame_info"]["frame_offset_raw"][video_id] frameid_sub = frameid - frameid_start render_length = len(frameid) + # remove last frame to be consistent with flow + frameid_sub = frameid_sub[:-1] + render_length = render_length - 1 elif opts["freeze_id"] >= 0 and opts["freeze_id"] < vid_length: if opts["num_frames"] <= 0: num_frames = vid_length @@ -87,11 +90,12 @@ def construct_batch_from_opts(opts, model, data_info): # get cameras wrt each field with torch.no_grad(): + frameid = torch.tensor(frameid, device=device) field2cam_fr = model.fields.get_cameras(frame_id=frameid) intrinsics_fr = model.intrinsics.get_vals( frameid_sub + data_info["frame_info"]["frame_offset_raw"][video_id] ) - aabb = model.fields.get_aabb() + aabb = model.fields.get_aabb(inst_id=opts["inst_id"]) # convert to numpy for k, v in field2cam_fr.items(): field2cam_fr[k] = v.cpu().numpy() @@ -114,7 +118,7 @@ def construct_batch_from_opts(opts, model, data_info): elev, max_angle = [int(val) for val in opts["viewpoint"].split("-")[1:]] # bg_to_cam - obj_size = (aabb["fg"][1, :] - aabb["fg"][0, :]).max() + obj_size = (aabb["fg"][0, 1, :] - aabb["fg"][0, 0, :]).max() cam_traj = get_rotating_cam( len(frameid_sub), distance=obj_size * 2.5, max_angle=max_angle ) @@ -191,6 +195,10 @@ def render(opts, construct_batch_func): opts["logroot"] = sys.argv[1].split("=")[1].rsplit("/", 2)[0] model, data_info, ref_dict = Trainer.construct_test_model(opts) batch, raw_size = construct_batch_func(opts, model, data_info) + # # TODO: make eval_res and render_res consistent + if opts["render_res"] == opts["eval_res"] and opts["viewpoint"] == "ref": + feature = ref_dict["ref_feature"].reshape(-1, opts["eval_res"] ** 2, 16) 
+ batch["feature"] = torch.tensor(feature, device="cuda") save_dir = make_save_dir( opts, sub_dir="renderings_%04d/%s" % (opts["inst_id"], opts["viewpoint"]) ) diff --git a/lab4d/render_intermediate.py b/lab4d/render_intermediate.py new file mode 100644 index 0000000..0ed8458 --- /dev/null +++ b/lab4d/render_intermediate.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +# python lab4d/render_intermediate.py --testdir logdir/human-48-category-comp/ +import sys, os +import pdb + +import glob +import numpy as np +import cv2 +import argparse +import trimesh +import tqdm + + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) +from lab4d.utils.io import save_vid +from lab4d.utils.pyrender_wrapper import PyRenderWrapper + +parser = argparse.ArgumentParser(description="script to render cameras over epochs") +parser.add_argument("--testdir", default="", help="path to test dir") +parser.add_argument("--show_bones", action="store_true", help="if render bones") +parser.add_argument( + "--data_class", default="fg", type=str, help="which data to render, {fg, bg}" +) +args = parser.parse_args() + + +def main(): + renderer = PyRenderWrapper() + # io + path_list = [ + i for i in glob.glob("%s/*-%s-proxy.obj" % (args.testdir, args.data_class)) + ] + if len(path_list) == 0: + print("no mesh found in %s for %s" % (args.testdir, args.data_class)) + return + path_list = sorted(path_list, key=lambda x: int(x.split("/")[-1].split("-")[0])) + outdir = "%s/renderings_proxy" % args.testdir + os.makedirs(outdir, exist_ok=True) + + mesh_dict = {} + bone_dict = {} + aabb_min = np.asarray([np.inf, np.inf]) + aabb_max = np.asarray([-np.inf, -np.inf]) + for mesh_path in path_list: + batch_idx = int(mesh_path.split("/")[-1].split("-")[0]) + mesh_obj = trimesh.load(mesh_path) + if args.show_bones: + bone_dict[batch_idx] = trimesh.load(mesh_path.replace("-proxy", "-gauss")) + mesh_dict[batch_idx] = mesh_obj + + # update aabb + aabb_min = np.minimum(aabb_min, mesh_obj.bounds[0, [0, 2]]) # x,z coords + aabb_max = np.maximum(aabb_max, mesh_obj.bounds[1, [0, 2]]) + + # set camera translation + renderer.set_camera_bev(depth=max(aabb_max - aabb_min)) + + # render + frames = [] + for batch_idx, mesh_obj in tqdm.tqdm(mesh_dict.items()): + input_dict = {"shape": mesh_obj} + if args.show_bones: + input_dict["bone"] = bone_dict[batch_idx] + input_dict["shape"].visual.vertex_colors[3:] = 192 + color = renderer.render(input_dict)[0] + # add text + color = color.astype(np.uint8) + color = cv2.putText( + color, + "batch: %02d" % batch_idx, + (30, 50), + cv2.FONT_HERSHEY_SIMPLEX, + 2, + (256, 0, 0), + 2, + ) + frames.append(color) + + save_path = "%s/%s" % (outdir, args.data_class) + save_vid(save_path, frames, suffix=".mp4", upsample_frame=-1) + print("saved to %s.mp4" % (save_path)) + + +if __name__ == "__main__": + main() diff --git a/lab4d/render_mesh.py b/lab4d/render_mesh.py new file mode 100644 index 0000000..3c09d13 --- /dev/null +++ b/lab4d/render_mesh.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. 
+# python lab4d/render_mesh.py --testdir logdir//ama-bouncing-4v-ppr/export_0000/ --view bev --ghosting +import sys, os +import pdb +import json +import glob +import numpy as np +import cv2 +import argparse +import trimesh +import tqdm + + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) +from lab4d.utils.io import save_vid +from lab4d.utils.pyrender_wrapper import PyRenderWrapper +from lab4d.utils.mesh_loader import MeshLoader + +parser = argparse.ArgumentParser(description="script to render extraced meshes") +parser.add_argument("--testdir", default="", help="path to the directory with results") +parser.add_argument("--fps", default=30, type=int, help="fps of the video") +parser.add_argument("--mode", default="", type=str, help="{shape, bone}") +parser.add_argument("--compose_mode", default="", type=str, help="{object, scene}") +parser.add_argument("--ghosting", action="store_true", help="ghosting") +parser.add_argument("--view", default="ref", type=str, help="{ref, bev, front}") +args = parser.parse_args() + + +def main(): + loader = MeshLoader(args.testdir, args.mode, args.compose_mode) + loader.print_info() + loader.load_files(ghosting=args.ghosting) + + # render + raw_size = loader.raw_size + renderer = PyRenderWrapper(raw_size) + print("Rendering [%s]:" % args.view) + frames = [] + for frame_idx, mesh_obj in tqdm.tqdm(loader.mesh_dict.items()): + # input dict + input_dict = loader.query_frame(frame_idx) + + if loader.compose_mode == "primary": + # set camera extrinsics + renderer.set_camera(loader.extr_dict[frame_idx]) + # set camera intrinsics + renderer.set_intrinsics(loader.intrinsics[frame_idx]) + else: + if args.view == "ref": + # set camera extrinsics + renderer.set_camera(loader.extr_dict[frame_idx]) + # set camera intrinsics + renderer.set_intrinsics(loader.intrinsics[frame_idx]) + elif args.view == "bev": + # bev + renderer.set_camera_bev(depth=max(loader.aabb_max - loader.aabb_min)) + # set camera intrinsics + fl = max(raw_size) + intr = np.asarray([fl * 2, fl * 2, raw_size[1] / 2, raw_size[0] / 2]) + renderer.set_intrinsics(intr) + elif args.view == "front": + # frontal view + renderer.set_camera_frontal(25, delta=0.0) + # set camera intrinsics + fl = max(raw_size) + intr = np.asarray( + [fl * 4, fl * 4, raw_size[1] / 2, raw_size[0] / 4 * 3] + ) + renderer.set_intrinsics(intr) + renderer.align_light_to_camera() + + color = renderer.render(input_dict)[0] + # add text + color = color.astype(np.uint8) + color = cv2.putText( + color, + "frame: %02d" % frame_idx, + (30, 50), + cv2.FONT_HERSHEY_SIMPLEX, + 2, + (256, 0, 0), + 2, + ) + frames.append(color) + + save_path = "%s/render-%s-%s-%s" % ( + args.testdir, + loader.mode, + loader.compose_mode, + args.view, + ) + save_vid( + save_path, + frames, + suffix=".mp4", + upsample_frame=-1, + fps=args.fps, + ) + print("saved to %s.mp4" % save_path) + + +if __name__ == "__main__": + main() diff --git a/lab4d/utils/geom_utils.py b/lab4d/utils/geom_utils.py index cf015eb..1348ff5 100644 --- a/lab4d/utils/geom_utils.py +++ b/lab4d/utils/geom_utils.py @@ -1,13 +1,17 @@ # Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. 
+import cv2 import numpy as np import torch +import torch.nn.functional as F import trimesh from scipy.spatial.transform import Rotation as R from skimage import measure +import open3d as o3d from lab4d.utils.quat_transform import ( dual_quaternion_apply, quaternion_translation_apply, + dual_quaternion_to_se3, ) @@ -83,6 +87,101 @@ def dual_quaternion_skinning(dual_quat, pts, skin): return pts +def linear_blend_skinning(dual_quat, xyz, skin_prob): + """Attach points to SE(3) bones according to skinning weights + + Args: + dual_quat: ((M,B,4), (M,B,4)) per-bone SE(3) transforms, + written as dual quaternions + xyz: (M, ..., 3) Points in object canonical space + skin_prob: (M, ..., B) Skinning weights from each point to each bone + Returns: + pts: (M, ..., 3) Articulated points + """ + shape = xyz.shape + xyz = xyz.view(shape[0], -1, 3) # M, N*D, 3 + skin_prob = skin_prob.view(shape[0], -1, skin_prob.shape[-1]) # M, N*D, B + se3 = dual_quaternion_to_se3(dual_quat) # M,B,4,4 + # M ND B 4 4 + out = se3[:, None, :, :3, :3] @ xyz[:, :, None, :, None] + out = out + se3[:, None, :, :3, 3:4] # M,ND,B,3,1 + out = (out[..., 0] * skin_prob[..., None]).sum(-2) # M,ND,B,3 + out = out.view(shape) + return out + + +def slerp(val, low, high, eps=1e-6): + """ + Args: + val: (M,) Interpolation value + low: (M,4) Low quaternions + high: (M,4) High quaternions + Returns: + out: (M,4) Interpolated quaternions + """ + # Normalize input quaternions. + low_norm = F.normalize(low, dim=1) + high_norm = F.normalize(high, dim=1) + + # Compute cosine of angle between quaternions. + cos_angle = torch.clamp((low_norm * high_norm).sum(dim=1), -1.0 + eps, 1.0 - eps) + omega = torch.acos(cos_angle) + + so = torch.sin(omega) + t1 = torch.sin((1.0 - val) * omega) / (so + eps) + t2 = torch.sin(val * omega) / (so + eps) + return t1.unsqueeze(-1) * low + t2.unsqueeze(-1) * high + + +def interpolate_slerp(y, idx_floor, idx_ceil, t_frac): + """ + Args: + y: (N,4) Quaternions + idx_floor: (M,) Floor indices + idx_ceil: (M,) Ceil indices + t_frac: (M,) Fractional indices (0-1) + Returns: + y_interpolated: (M,4) Interpolated quaternions + """ + # Use integer parts to index y + idx_ceil.clamp_(max=len(y) - 1) + y_floor = y[idx_floor] + y_ceil = y[idx_ceil] + + # Check dot product to ensure the shortest path + dp = torch.sum(y_floor * y_ceil, dim=-1, keepdim=True) + y_ceil = torch.where(dp < 0.0, -y_ceil, y_ceil) + + # Normalize quaternions to be sure + y_floor_norm = F.normalize(y_floor, dim=1) + y_ceil_norm = F.normalize(y_ceil, dim=1) + + # Compute interpolated quaternion + y_interpolated = slerp(t_frac, y_floor_norm, y_ceil_norm) + y_interpolated_norm = F.normalize(y_interpolated, dim=1) + return y_interpolated_norm + + +def interpolate_linear(y, idx_floor, idx_ceil, t_frac): + """ + Args: + y: (N,4) translation + idx_floor: (M,) Floor indices + idx_ceil: (M,) Ceil indices + t_frac: (M,) Fractional indices (0-1) + Returns: + y_interpolated: (M,4) Interpolated translation + """ + # Use integer parts to index y + idx_ceil.clamp_(max=len(y) - 1) + y_floor = y[idx_floor] + y_ceil = y[idx_ceil] + + # Compute interpolated quaternion + y_interpolated = y_floor + t_frac[..., None] * (y_ceil - y_floor) + return y_interpolated + + def hat_map(v): """Returns the skew-symmetric matrix corresponding to the last dimension of a PyTorch tensor. 
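The slerp helpers added to lab4d/utils/geom_utils.py above interpolate per-frame quaternions and translations from floor/ceil frame indices plus a fractional weight. A minimal sketch of how they could be driven from continuous frame times (the pose tensors are made-up placeholders; assumes the repo is importable):

import torch
import torch.nn.functional as F
from lab4d.utils.geom_utils import interpolate_slerp, interpolate_linear

# hypothetical per-frame camera poses: N frames of (quaternion, translation)
N = 8
quats = F.normalize(torch.randn(N, 4), dim=-1)  # (N,4) unit quaternions
trans = torch.randn(N, 3)                       # (N,3) translations

# query at continuous (fractional) frame times
t = torch.tensor([0.25, 3.5, 6.9])
idx_floor = t.floor().long()  # (M,) lower frame index
idx_ceil = idx_floor + 1      # (M,) upper frame index (clamped inside the helpers)
t_frac = t - idx_floor        # (M,) interpolation weight in [0,1)

quat_t = interpolate_slerp(quats, idx_floor, idx_ceil, t_frac)    # (M,4), shortest-path slerp
trans_t = interpolate_linear(trans, idx_floor, idx_ceil, t_frac)  # (M,3), linear interpolation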
@@ -411,14 +510,15 @@ def extend_aabb(aabb, factor=0.1): If aabb = [-1,1] and factor = 1, the extended aabb will be [-3,3] Args: - aabb: Axis-aligned bounding box, (2,3) + aabb: Axis-aligned bounding box, ((N,)2,3) factor (float): Amount to extend on each side Returns: - aabb_new: Extended aabb, (2,3) + aabb_new: Extended aabb, ((N,)2,3) """ aabb_new = aabb.clone() - aabb_new[0] = aabb[0] - (aabb[1] - aabb[0]) * factor - aabb_new[1] = aabb[1] + (aabb[1] - aabb[0]) * factor + size = (aabb[..., 1, :] - aabb[..., 0, :]) * factor + aabb_new[..., 0, :] = aabb[..., 0, :] - size + aabb_new[..., 1, :] = aabb[..., 1, :] + size return aabb_new @@ -507,11 +607,276 @@ def check_inside_aabb(xyz, aabb): """Return a mask of whether the input poins are inside the aabb Args: - xyz: (N,3) Points in object canonical space to query - aabb: (2,3) axis-aligned bounding box + xyz: (N,...,3) Points in object canonical space to query + aabb: (N,2,3) axis-aligned bounding box Returns: - inside_aabb: (N) Inside mask, bool + inside_aabb: (N,...,) Inside mask, bool """ # check whether the point is inside the aabb - inside_aabb = ((xyz > aabb[:1]) & (xyz < aabb[1:])).all(-1) + shape = xyz.shape[:-1] + aabb = aabb.view((aabb.shape[0], 2) + (1,) * (len(shape) - 1) + (3,)) + inside_aabb = ((xyz > aabb[:, 0]) & (xyz < aabb[:, 1])).all(-1) return inside_aabb + + +def compute_rectification_se3(mesh, up_direction, threshold=0.01, init_n=3, iter=2000): + # run ransac to get plane + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(mesh.vertices) + hypos = [] + n_hypo = 5 + for _ in range(n_hypo): + if len(pcd.points) < 3: + break + best_eq, index = pcd.segment_plane(threshold, init_n, iter) + # visibile plane given z direction + if best_eq[2] > 0: + best_eq = -1 * best_eq + + segmented_pts = pcd.select_by_index(index) + pts_left = np.asarray(pcd.points)[~np.isin(np.arange(len(pcd.points)), index)] + pcd.points = o3d.utility.Vector3dVector(pts_left) + # print("segmented plane pts: ", len(segmented_pts.points) / len(mesh.vertices)) + score = np.asarray(up_direction).dot(best_eq[:3]) + hypos.append((best_eq, segmented_pts, score)) + # find the one with best score + best_eq, segmented_pts, score = sorted(hypos, key=lambda x: x[-1])[-1] + + # get se3 + plane_n = np.asarray(best_eq[:3]) + center = np.asarray(segmented_pts.points).mean(0) + dist = (center * plane_n).sum() + best_eq[3] + plane_o = center - plane_n * dist + plane = np.concatenate([plane_o, plane_n]) + + # xz plane + bg2world = plane_transform(origin=plane[:3], normal=plane[3:6], axis=[0, -1, 0]) + + # further transform the xz plane center to align with origin + mesh_rectified = mesh.copy() + mesh_rectified.apply_transform(bg2world) + bounds = mesh_rectified.bounds + center = (bounds[0] + bounds[1]) / 2 + bg2world[0, 3] -= center[0] + bg2world[2, 3] -= center[2] + + # # DEBUG only + # mesh.export("tmp/raw.obj") + # mesh.apply_transform(bg2world) + # mesh.export("tmp/rect.obj") + # import pdb + + # pdb.set_trace() + + bg2world = torch.Tensor(bg2world) + return bg2world + + +def plane_transform(origin, normal, axis=[0, 1, 0]): + """ + # modified from https://github.com/mikedh/trimesh/blob/main/trimesh/geometry.py#L14 + Given the origin and normal of a plane find the transform + that will move that plane to be coplanar with the XZ plane. 
+ Parameters + ---------- + origin : (3,) float + Point that lies on the plane + normal : (3,) float + Vector that points along normal of plane + Returns + --------- + transform: (4,4) float + Transformation matrix to move points onto XZ plane + """ + normal = normal / (1e-6 + np.linalg.norm(normal)) + # transform = align_vectors(normal, axis) + transform = np.eye(4) + transform[:3, :3] = align_vector_a_to_b(normal, axis) + if origin is not None: + transform[:3, 3] = -np.dot(transform, np.append(origin, 1))[:3] + return transform + + +def align_vector_a_to_b(a, b): + """Find the rotation matrix that transforms one 3D vector + to another. + Args: + a : (3,) float + Unit vector + b : (3,) float + Unit vector + Returns: + matrix : (3, 3) float + Rotation matrix to rotate from `a` to `b` + """ + # Ensure the vectors are numpy arrays + a = np.array(a) + b = np.array(b) + + # Check if vectors are non-zero + if np.linalg.norm(a) == 0 or np.linalg.norm(b) == 0: + raise ValueError("Vectors must be non-zero") + + # Normalize the vectors + a_hat = a / np.linalg.norm(a) + b_hat = b / np.linalg.norm(b) + + # Compute the rotation axis (normal to the plane formed by a and b) + axis = np.cross(a_hat, b_hat) + + # Compute the cosine of the angle between a_hat and b_hat + cos_angle = np.dot(a_hat, b_hat) + + # Handling numerical imprecision + cos_angle = np.clip(cos_angle, -1.0, 1.0) + + # Compute the angle of rotation + angle = np.arccos(cos_angle) + + # If vectors are parallel or anti-parallel, no axis is determined. Handle separately + if np.isclose(angle, 0.0): + return np.eye(3) # Identity matrix, no rotation needed + elif np.isclose(angle, np.pi): + # Find a perpendicular vector + axis = np.cross(a_hat, np.array([1, 0, 0])) + if np.linalg.norm(axis) < 1e-10: + axis = np.cross(a_hat, np.array([0, 1, 0])) + axis = axis / np.linalg.norm(axis) # Normalize axis + + # Compute the rotation matrix using the axis-angle representation + axis_matrix = np.array( + [[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]] + ) + + rotation_matrix = ( + np.eye(3) + + np.sin(angle) * axis_matrix + + (1 - np.cos(angle)) * np.dot(axis_matrix, axis_matrix) + ) + + return rotation_matrix + + +def align_vectors(a, b, return_angle=False): + """ + # modified from https://github.com/mikedh/trimesh/blob/main/trimesh/geometry.py#L38 + Find the rotation matrix that transforms one 3D vector + to another. 
+ Parameters + ------------ + a : (3,) float + Unit vector + b : (3,) float + Unit vector + return_angle : bool + Return the angle between vectors or not + Returns + ------------- + matrix : (4, 4) float + Homogeneous transform to rotate from `a` to `b` + angle : float + If `return_angle` angle in radians between `a` and `b` + """ + a = np.array(a, dtype=np.float64) + b = np.array(b, dtype=np.float64) + if a.shape != (3,) or b.shape != (3,): + raise ValueError("vectors must be (3,)!") + + # find the SVD of the two vectors + au = np.linalg.svd(a.reshape((-1, 1)))[0] + bu = np.linalg.svd(b.reshape((-1, 1)))[0] + + if np.linalg.det(au) < 0: + au[:, -1] *= -1.0 + if np.linalg.det(bu) < 0: + bu[:, -1] *= -1.0 + + # put rotation into homogeneous transformation + matrix = np.eye(4) + matrix[:3, :3] = bu.dot(au.T) + + if return_angle: + # projection of a onto b + # first row of SVD result is normalized source vector + dot = np.dot(au[0], bu[0]) + # clip to avoid floating point error + angle = np.arccos(np.clip(dot, -1.0, 1.0)) + if dot < -1e-5: + angle += np.pi + return matrix, angle + + return matrix + + +def se3_inv(rtmat): + """Invert an SE(3) matrix + + Args: + rtmat: (..., 4, 4) SE(3) matrix + Returns: + rtmat_inv: (..., 4, 4) Inverse SE(3) matrix + """ + rmat, tmat = se3_mat2rt(rtmat) + rmat = rmat.transpose(-1, -2) + tmat = -rmat @ tmat[..., None] + rtmat[..., :3, :3] = rmat + rtmat[..., :3, 3] = tmat[..., 0] + return rtmat + + +def rotation_over_plane(N, dim1, dim2, angle): + """ + Create an N x N rotation matrix for a rotation over the plane defined by + the two dimensions dim1 and dim2. + + Parameters: + - N (int): The dimensionality. + - dim1/2 (int): The axis around which to rotate (0-based index). + - angle (float): The rotation angle in degrees. + + Returns: + - ndarray: The rotation matrix. 
+ """ + + # Basic error check + if N == 1: + return np.eye(1) + if dim1 >= N or dim1 < 0 or dim2 >= N or dim2 < 0: + raise ValueError("The axis index i is out of bounds for dimensionality N.") + + # Calculate cosine and sine values for the rotation angle + c = np.cos(angle) + s = np.sin(angle) + + # Create the 2D rotation block + R_2D = np.array([[c, -s], [s, c]]) + + # Insert the 2D rotation block into the top-left + R_ND = np.eye(N) + R_ND[:2, :2] = R_2D + + # If dim is not 0, create the permutation matrix and apply the axis swapping + P = np.eye(N) + P[0], P[dim1] = P[dim1].copy(), P[0].copy() # Swap rows + P[1], P[dim2] = P[dim2].copy(), P[1].copy() # Swap columns + R_ND = P @ R_ND @ P.T # Apply permutation to the base rotation matrix + + return R_ND + + +def get_pre_rotation(in_channels): + """Get the pre-rotation matrix for the input coordinates in positional encoding + + Args: + in_channels (int): Number of input channels + + Returns: + rot_mat (ndarray): Rotation matrix + """ + rot_mat = [np.eye(in_channels)] + angle = np.pi / 4 + for dim1 in range(in_channels): + for dim2 in range(dim1): + rot_mat.append(rotation_over_plane(in_channels, dim1, dim2, angle)) + rot_mat = np.concatenate(rot_mat, axis=0) + return rot_mat diff --git a/lab4d/utils/io.py b/lab4d/utils/io.py index ffd44ed..0ea3dac 100644 --- a/lab4d/utils/io.py +++ b/lab4d/utils/io.py @@ -25,6 +25,13 @@ def make_save_dir(opts, sub_dir="renderings"): return save_dir +def resize_to_nearest_multiple(image, multiple=16): + height, width = image.shape[:2] + new_height = int(np.ceil(height / multiple) * multiple) + new_width = int(np.ceil(width / multiple) * multiple) + return cv2.resize(image, (new_width, new_height)) + + def save_vid( outpath, frames, @@ -67,7 +74,10 @@ def save_vid( frame = cv2.resize(frame, (w, h)) frame_150.append(frame) - imageio.mimsave("%s%s" % (outpath, suffix), frame_150, fps=fps) + + # to make divisible by 16 + frame_150_resized = [resize_to_nearest_multiple(frame) for frame in frame_150] + imageio.mimsave("%s%s" % (outpath, suffix), frame_150_resized, fps=fps) def save_rendered(rendered, save_dir, raw_size, pca_fn): diff --git a/lab4d/utils/loss_utils.py b/lab4d/utils/loss_utils.py index c75e555..7e55080 100644 --- a/lab4d/utils/loss_utils.py +++ b/lab4d/utils/loss_utils.py @@ -42,17 +42,23 @@ def cross_entropy_skin_loss(skin): return cross_entropy -def align_vectors(v1, v2): +def align_tensors(v1, v2, dim=None): """Return the scale that best aligns v1 to v2 in the L2 sense: min || kv1-v2 ||^2 Args: v1: (...,) Source vector v2: (...,) Target vector + dim: Dimension to align. If None, return a scalar Returns: scale_fac (1,): Scale factor """ - scale_fac = (v1 * v2).sum() / (v1 * v1).sum() - if scale_fac < 0: - scale_fac = torch.tensor([1.0], device=scale_fac.device) - return scale_fac + if dim is None: + scale = (v1 * v2).sum() / (v1 * v1).sum() + if scale < 0: + scale = torch.tensor([1.0], device=scale.device) + return scale + else: + scale = (v1 * v2).sum(dim, keepdim=True) / (v1 * v1).sum(dim, keepdim=True) + scale[scale < 0] = 1.0 + return scale diff --git a/lab4d/utils/mesh_loader.py b/lab4d/utils/mesh_loader.py new file mode 100644 index 0000000..3aef542 --- /dev/null +++ b/lab4d/utils/mesh_loader.py @@ -0,0 +1,184 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. 
+import json +import glob +import numpy as np +import cv2 +import argparse +import trimesh +import tqdm + + +class MeshLoader: + def __init__(self, testdir, mode, compose_mode): + # io + camera_info = json.load(open("%s/camera.json" % (testdir), "r")) + intrinsics = np.asarray(camera_info["intrinsics"], dtype=np.float32) + raw_size = camera_info["raw_size"] # h,w + if len(glob.glob("%s/fg/mesh/*.obj" % (testdir))) > 0: + primary_dir = "%s/fg" % testdir + secondary_dir = "%s/bg" % testdir + else: + primary_dir = "%s/bg" % testdir + secondary_dir = "%s/fg" % testdir # never use fg for secondary + path_list = sorted([i for i in glob.glob("%s/mesh/*.obj" % (primary_dir))]) + if len(path_list) == 0: + print("no mesh found that matches %s*" % (primary_dir)) + raise ValueError + + # check render mode + if mode != "": + pass + elif len(glob.glob("%s/bone/*" % primary_dir)) > 0: + mode = "bone" + else: + mode = "shape" + + if compose_mode != "": + pass + elif len(glob.glob("%s/mesh/*" % secondary_dir)) > 0: + compose_mode = "compose" + else: + compose_mode = "primary" + + # get cam dict + field2cam_fg_dict = json.load(open("%s/motion.json" % (primary_dir), "r")) + field2cam_fg_dict = field2cam_fg_dict["field2cam"] + if compose_mode == "compose": + field2cam_bg_dict = json.load(open("%s/motion.json" % (secondary_dir), "r")) + field2cam_bg_dict = np.asarray(field2cam_bg_dict["field2cam"]) + + field2world_path = "%s/bg/field2world.json" % (testdir) + field2world = np.asarray(json.load(open(field2world_path, "r"))) + world2field = np.linalg.inv(field2world) + + self.mode = mode + self.compose_mode = compose_mode + self.testdir = testdir + self.intrinsics = intrinsics + self.raw_size = raw_size + self.path_list = path_list + self.field2cam_fg_dict = field2cam_fg_dict + if compose_mode == "compose": + self.field2cam_bg_dict = field2cam_bg_dict + self.field2world = field2world + self.world2field = world2field + else: + self.field2cam_bg_dict = None + self.field2world = None + self.world2field = None + + def __len__(self): + return len(self.path_list) + + def load_files(self, ghosting=False): + mode = self.mode + compose_mode = self.compose_mode + path_list = self.path_list + field2cam_fg_dict = self.field2cam_fg_dict + field2cam_bg_dict = self.field2cam_bg_dict + field2world = self.field2world + world2field = self.world2field + + mesh_dict = {} + extr_dict = {} + bone_dict = {} + scene_dict = {} + ghost_dict = {} + aabb_min = np.asarray([np.inf, np.inf]) + aabb_max = np.asarray([-np.inf, -np.inf]) + for counter, mesh_path in enumerate(path_list): + frame_idx = int(mesh_path.split("/")[-1].split(".")[0]) + mesh = trimesh.load(mesh_path, process=False) + mesh.visual.vertex_colors = ( + mesh.visual.vertex_colors + ) # visual.kind = 'vertex' + field2cam_fg = np.asarray(field2cam_fg_dict[frame_idx]) + + # post-modify the scale of the fg + # mesh.vertices = mesh.vertices / 2 + # field2cam_fg[:3, 3] = field2cam_fg[:3, 3] / 2 + + mesh_dict[frame_idx] = mesh + extr_dict[frame_idx] = field2cam_fg + + if mode == "bone": + # load bone + bone_path = mesh_path.replace("mesh", "bone") + bone = trimesh.load(bone_path, process=False) + bone.visual.vertex_colors = bone.visual.vertex_colors + bone_dict[frame_idx] = bone + + if compose_mode == "compose": + # load scene + scene_path = mesh_path.replace("fg/mesh", "bg/mesh") + scene = trimesh.load(scene_path, process=False) + scene.visual.vertex_colors = scene.visual.vertex_colors + + # align bg floor with xz plane + scene.vertices = ( + scene.vertices @ field2world[:3, :3].T 
+ field2world[:3, 3] + ) + field2cam_bg = field2cam_bg_dict[frame_idx] @ world2field + field2cam_bg_dict[frame_idx] = field2cam_bg + + scene_dict[frame_idx] = scene + # use scene camera + extr_dict[frame_idx] = field2cam_bg_dict[frame_idx] + # transform to scene + object_to_scene = ( + np.linalg.inv(field2cam_bg_dict[frame_idx]) @ field2cam_fg + ) + mesh_dict[frame_idx].apply_transform(object_to_scene) + if mode == "bone": + bone_dict[frame_idx].apply_transform(object_to_scene) + + if ghosting: + total_ghost = 10 + ghost_skip = len(path_list) // total_ghost + if "ghost_list" in locals(): + if counter % ghost_skip == 0: + mesh_ghost = mesh_dict[frame_idx].copy() + mesh_ghost.visual.vertex_colors[:, 3] = 102 + ghost_list.append(mesh_ghost) + else: + ghost_list = [mesh_dict[frame_idx]] + ghost_dict[frame_idx] = [mesh.copy() for mesh in ghost_list] + + # update aabb # x,z coords + if compose_mode == "compose": + bounds = scene_dict[frame_idx].bounds + else: + bounds = mesh_dict[frame_idx].bounds + aabb_min = np.minimum(aabb_min, bounds[0, [0, 2]]) + aabb_max = np.maximum(aabb_max, bounds[1, [0, 2]]) + + self.mesh_dict = mesh_dict + self.extr_dict = extr_dict + self.bone_dict = bone_dict + self.scene_dict = scene_dict + self.ghost_dict = ghost_dict + self.aabb_min = aabb_min + self.aabb_max = aabb_max + + def query_frame(self, frame_idx): + input_dict = {} + input_dict["shape"] = self.mesh_dict[frame_idx] + if self.mode == "bone": + input_dict["bone"] = self.bone_dict[frame_idx] + # make shape transparent and gray + input_dict["shape"].visual.vertex_colors[:3] = 102 + input_dict["shape"].visual.vertex_colors[3:] = 192 + if self.compose_mode == "compose": + scene_mesh = self.scene_dict[frame_idx] + scene_mesh.visual.vertex_colors[:, :3] = np.asarray([[224, 224, 54]]) + input_dict["scene"] = scene_mesh + if len(self.ghost_dict) > 0: + ghost_mesh = trimesh.util.concatenate(self.ghost_dict[frame_idx]) + input_dict["ghost"] = ghost_mesh + return input_dict + + def print_info(self): + print( + "[mode=%s, compose=%s] rendering %d meshes from %s" + % (self.mode, self.compose_mode, len(self), self.testdir) + ) diff --git a/lab4d/utils/numpy_utils.py b/lab4d/utils/numpy_utils.py index af7bdd1..9e93bc2 100644 --- a/lab4d/utils/numpy_utils.py +++ b/lab4d/utils/numpy_utils.py @@ -35,7 +35,18 @@ def interp_wt(x, y, x2, type="linear"): # Transform back to original space y2 = 10**log_y2 + elif type == "exp": + # clip + assert x0 >= 1 + assert x1 >= 1 + x2 = np.clip(x2, x0, x1) + # Transform to log space + log_x0 = np.log10(x0) + log_x1 = np.log10(x1) + log_x2 = np.log10(x2) + # Perform linear interpolation in log space + y2 = y0 + (log_x2 - log_x0) * (y1 - y0) / (log_x1 - log_x0) else: raise ValueError("interpolation_type must be 'linear' or 'log'") diff --git a/lab4d/utils/pyrender_wrapper.py b/lab4d/utils/pyrender_wrapper.py new file mode 100644 index 0000000..18a4a4a --- /dev/null +++ b/lab4d/utils/pyrender_wrapper.py @@ -0,0 +1,152 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. 
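A quick numeric illustration of the new "exp" branch of interp_wt in lab4d/utils/numpy_utils.py above: the query x2 is interpolated on a log10 axis while y stays linear, so a scheduled value ramps quickly at small step counts and flattens out later. This is a standalone mirror of that branch (assuming x=(x0,x1) and y=(y0,y1) as in the existing branches), not a call into the repo:

import numpy as np

def interp_wt_exp(x, y, x2):
    # same math as the "exp" branch: linear in y over a log10-spaced x axis
    x0, x1 = x
    y0, y1 = y
    x2 = np.clip(x2, x0, x1)
    log_x0, log_x1, log_x2 = np.log10(x0), np.log10(x1), np.log10(x2)
    return y0 + (log_x2 - log_x0) * (y1 - y0) / (log_x1 - log_x0)

# e.g., ramp a loss weight from 0.0 to 1.0 between steps 10 and 10000
for step in [10, 100, 1000, 10000]:
    print(step, interp_wt_exp((10, 10000), (0.0, 1.0), step))
# 10 -> 0.0, 100 -> 0.33, 1000 -> 0.67, 10000 -> 1.0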
+ +import os +import numpy as np +import cv2 +import pdb +import pyrender +import trimesh +from pyrender import ( + IntrinsicsCamera, + Mesh, + Node, + Scene, + OffscreenRenderer, + MetallicRoughnessMaterial, +) + +os.environ["PYOPENGL_PLATFORM"] = "egl" + + +class PyRenderWrapper: + def __init__(self, image_size=(1024, 1024)) -> None: + # renderer + self.image_size = image_size + render_size = max(image_size) + self.r = OffscreenRenderer(render_size, render_size) + self.intrinsics = IntrinsicsCamera( + render_size, render_size, render_size / 2, render_size / 2 + ) + # light + self.light_pose = np.eye(4) + self.set_light_topdown() + self.direc_l = pyrender.DirectionalLight(color=np.ones(3), intensity=5.0) + self.material = MetallicRoughnessMaterial( + roughnessFactor=0.75, metallicFactor=0.75, alphaMode="BLEND" + ) + self.init_camera() + + def init_camera(self): + # cv to gl coords + self.flip_pose = -np.eye(4) + self.flip_pose[0, 0] = 1 + self.flip_pose[-1, -1] = 1 + self.set_camera(np.eye(4)) + + def set_camera_bev(self, depth, gl=False): + # object to camera transforms + if gl: + rot = cv2.Rodrigues(np.asarray([-np.pi / 2, 0, 0]))[0] + else: + rot = cv2.Rodrigues(np.asarray([np.pi / 2, 0, 0]))[0] + scene_to_cam = np.eye(4) + scene_to_cam[:3, :3] = rot + scene_to_cam[2, 3] = depth + self.scene_to_cam = self.flip_pose @ scene_to_cam + + def set_camera_frontal(self, depth, gl=False, delta=0.0): + # object to camera transforms + if gl: + rot = cv2.Rodrigues(np.asarray([np.pi + np.pi / 180, delta, 0]))[0] + else: + rot = cv2.Rodrigues(np.asarray([np.pi / 180, delta, 0]))[0] + scene_to_cam = np.eye(4) + scene_to_cam[:3, :3] = rot + scene_to_cam[2, 3] = depth + self.scene_to_cam = self.flip_pose @ scene_to_cam + + def set_camera(self, scene_to_cam): + # object to camera transforms + self.scene_to_cam = self.flip_pose @ scene_to_cam + + def set_light_topdown(self, gl=False): + # top down light, slightly closer to the camera + if gl: + rot = cv2.Rodrigues(np.asarray([-np.pi / 2, 0, 0]))[0] + else: + rot = cv2.Rodrigues(np.asarray([np.pi / 2, 0, 0]))[0] + self.light_pose[:3, :3] = rot + + def align_light_to_camera(self): + self.light_pose = np.linalg.inv(self.scene_to_cam) + + def set_intrinsics(self, intrinsics): + """ + Args: + intrinsics: (4,) fx,fy,px,py + """ + self.intrinsics = IntrinsicsCamera( + intrinsics[0], intrinsics[1], intrinsics[2], intrinsics[3] + ) + + def get_cam_to_scene(self): + cam_to_scene = np.eye(4) + cam_to_scene[:3, :3] = self.scene_to_cam[:3, :3].T + cam_to_scene[:3, 3] = -self.scene_to_cam[:3, :3].T @ self.scene_to_cam[:3, 3] + return cam_to_scene + + def render(self, input_dict): + """ + Args: + input_dict: Dict of trimesh objects. 
Keys: shape, bone + "shape": trimesh object + "bone": trimesh object + Returns: + color: (H,W,3) + depth: (H,W) + """ + scene = Scene(ambient_light=0.1 * np.asarray([1.0, 1.0, 1.0, 1.0])) + + # add shape / camera + if "bone" in input_dict: + # add bone + mesh_pyrender = Mesh.from_trimesh(input_dict["bone"], smooth=False) + mesh_pyrender.primitives[0].material = self.material + scene.add_node(Node(mesh=mesh_pyrender)) + # else: + # # make shape gray + # input_dict["shape"].visual.vertex_colors[:, :3] = 102 + + if "scene" in input_dict: + # add scene + mesh_pyrender = Mesh.from_trimesh(input_dict["scene"], smooth=False) + mesh_pyrender.primitives[0].material = self.material + scene.add_node(Node(mesh=mesh_pyrender)) + + # shape + mesh_pyrender = Mesh.from_trimesh(input_dict["shape"], smooth=False) + mesh_pyrender.primitives[0].material = self.material + scene.add_node(Node(mesh=mesh_pyrender)) + if "ghost" in input_dict: + mesh_pyrender = Mesh.from_trimesh(input_dict["ghost"], smooth=False) + mesh_pyrender.primitives[0].material = self.material + scene.add_node(Node(mesh=mesh_pyrender)) + + # camera + scene.add(self.intrinsics, pose=self.get_cam_to_scene()) + + # light + scene.add(self.direc_l, pose=self.light_pose) + + # render + if "ghost" in input_dict: + flags = 0 + else: + flags = pyrender.RenderFlags.SHADOWS_DIRECTIONAL + color, depth = self.r.render(scene, flags=flags) + color = color[: self.image_size[0], : self.image_size[1]] + depth = depth[: self.image_size[0], : self.image_size[1]] + return color, depth + + def delete(self): + self.r.delete() diff --git a/lab4d/utils/render_utils.py b/lab4d/utils/render_utils.py index 342482d..7ebc2a9 100644 --- a/lab4d/utils/render_utils.py +++ b/lab4d/utils/render_utils.py @@ -5,7 +5,7 @@ import torch.nn.functional as F -def sample_cam_rays(hxy, Kinv, near_far, n_depth=64, depth=None, perturb=False): +def sample_cam_rays(hxy, Kinv, near_far, n_depth, depth=None, perturb=False): """Sample NeRF rays in camera space Args: @@ -14,7 +14,7 @@ def sample_cam_rays(hxy, Kinv, near_far, n_depth=64, depth=None, perturb=False): near_far: (M,2) Location of near/far planes per frame n_depth (int): Number of points to sample along each ray depth: (M,N,D,1) If provided, use these Z-coordinates for each ray sample - perturb (bool): If True, use stratified sampling and perturb depth samples + perturb (bool): If True, perturb depth samples Returns: xyz: (M,N,D,3) Ray points in camera space dir: (M,N,D,3) Ray directions in camera space @@ -73,20 +73,18 @@ def render_pixel(field_dict, deltas): # auxiliary outputs if "eikonal" in field_dict: - rendered["eikonal"] = field_dict["eikonal"].mean(dim=(-1, -2)) # (M, N) + # rendered["eikonal"] = field_dict["eikonal"].mean(dim=(-1, -2)) # (M, N) + rendered["eikonal"] = (field_dict["eikonal"][..., 0] * weights.detach()).sum(-1) if "delta_skin" in field_dict: rendered["delta_skin"] = field_dict["delta_skin"].mean(dim=(-1, -2)) # visibility loss + is_visible = (transmit[..., None] > 0.4).float() # a loose threshold # part of binary cross entropy: -label * log(sigmoid(vis)), where label is transmit - transmit = transmit[..., None].detach() - # sharpness = 20 # 0.6->0.88 - # is_visible = torch.sigmoid(sharpness * (transmit - 0.5)) - is_visible = transmit vis_loss = -(F.logsigmoid(field_dict["vis"]) * is_visible).mean(-2) # normalize by the number of visible points - vis_loss = vis_loss / is_visible.mean().detach() + vis_loss = vis_loss / is_visible.mean() rendered["vis"] = vis_loss # mask for gaussian density @@ -110,19 +108,19 
@@ def compute_weights(density, deltas): alpha_p = 1 - torch.exp(-density) # (M, N, D) alpha_p = torch.cat( [alpha_p, torch.ones_like(alpha_p[:, :, :1])], dim=-1 - ) # (M, N, D), [a1,a2,a3,...,an,1] + ) # (M, N, D+1), [a1,a2,a3,...,an,1], adding a inf seg at the end transmit = torch.cumsum(density, dim=-1) transmit = torch.exp(-transmit) # (M, N, D) transmit = torch.cat( [torch.ones_like(transmit[:, :, :1]), transmit], dim=-1 - ) # (M, N, D), [1, (1-a1), (1-a1)(1-a2), ..., (1-a1)(1-a2)...(1-an)] + ) # (M, N, D+1), [1, (1-a1), (1-a1)(1-a2), ..., (1-a1)(1-a2)...(1-an)] # aggregate: sum to 1 # [a1, (1-a1)a2, (1-a1)(1-a2)a3, ..., (1-a1)(1-a2)...(1-an)1] weights = alpha_p * transmit # (M, N, D+1) - weights = weights[..., :-1] # (M, N, D), only take the first D weights - transmit = transmit[..., 1:] # (M, N, D), only take the first D transmits + weights = weights[..., :-1] # (M, N, D), first D weights (might not sum up to 1) + transmit = transmit[..., 1:] # (M, N, D) return weights, transmit @@ -190,7 +188,7 @@ def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5): Sample @N_importance samples from @bins with distribution defined by @weights. Inputs: - bins: (N_rays, n_samples1) where n_samples is "the number of coarse samples per ray - 2" + bins: (N_rays, n_samples+1) where n_samples is "the number of coarse samples per ray - 2" weights: (N_rays, n_samples) N_importance: the number of samples to draw from the distribution det: deterministic or not diff --git a/lab4d/utils/skel_utils.py b/lab4d/utils/skel_utils.py index 4936073..779955b 100644 --- a/lab4d/utils/skel_utils.py +++ b/lab4d/utils/skel_utils.py @@ -4,12 +4,13 @@ import cv2 import numpy as np import torch +import torch.nn.functional as F from lab4d.utils.geom_utils import so3_to_exp_map from lab4d.utils.quat_transform import ( axis_angle_to_quaternion, matrix_to_quaternion, - quaternion_translation_mul, + dual_quaternion_mul, quaternion_translation_to_dual_quaternion, dual_quaternion_to_quaternion_translation, ) @@ -47,20 +48,22 @@ def rest_joints_to_local(rest_joints, edges): return local_rest_joints -def fk_se3(local_rest_joints, so3, edges, to_dq=True): - """Compute forward kinematics given joint angles on a skeleton +def fk_se3(local_rest_joints, so3, edges, to_dq=True, local_rest_coord=None): + """Compute forward kinematics given joint angles on a skeleton. + If local_rest_rmat is None, assuming identity rotation in zero configuration. Args: local_rest_joints: (B, 3) Translations from parent to current joints, - assuming identity rotation in zero configuration so3: (..., B, 3) Axis-angles at each joint edges (Dict(int, int)): Maps each joint to its parent joint to_dq (bool): If True, output link rigid transforms as dual quaternions, otherwise output SE(3) + local_rest_rot: (B, 3, 3) Local rotations Returns: out: Location of each joint. This is written as dual quaternions ((..., B, 4), (..., B, 4)) if to_dq=True, otherwise it is written as (..., B, 4, 4) SE(3) matrices. 
+ link to global transforms X_global = T_1...T_k x X_k """ assert local_rest_joints.shape == so3.shape shape = so3.shape @@ -70,15 +73,22 @@ def fk_se3(local_rest_joints, so3, edges, to_dq=True): identity_rt = identity_rt.view((1,) * (len(shape) - 2) + (-1, 4, 4)) identity_rt = identity_rt.expand(*shape[:-1], -1, -1).clone() identity_rt_slice = identity_rt[..., 0, :, :].clone() - local_to_parent = identity_rt.clone() global_rt = identity_rt.clone() + if local_rest_coord is None: + local_rmat = so3_to_exp_map(so3) + else: + local_rmat = local_rest_coord[:, :3, :3] + local_rmat = local_rmat.view((1,) * (len(shape) - 2) + (-1, 3, 3)) + local_rmat = local_rmat @ so3_to_exp_map(so3) + + local_to_parent = torch.cat([local_rmat, local_rest_joints[..., None]], -1) + local_to_parent = torch.cat([local_to_parent, identity_rt[..., -1:, :]], -2) + # get local rt transformation: (..., k, 4, 4) + # parent ... child # first rotate around joint i # then translate wrt the relative position of the parent to i - local_to_parent[..., :3, :3] = so3_to_exp_map(so3) - local_to_parent[..., :3, 3] = local_rest_joints - for idx, parent_idx in edges.items(): if parent_idx > 0: parent_to_global = global_rt[..., parent_idx - 1, :, :].clone() @@ -98,7 +108,7 @@ return global_rt -def shift_joints_to_bones_dq(dq, edges, shift=None): +def shift_joints_to_bones_dq(dq, edges): """Compute bone centers and orientations from joint locations Args: @@ -110,8 +120,6 @@ written as dual quaternions """ quat, joints = dual_quaternion_to_quaternion_translation(dq) - if shift is not None: - joints += shift.reshape((1,) * (joints[0].ndim - 1) + (3,)) joints = shift_joints_to_bones(joints, edges) dq = quaternion_translation_to_dual_quaternion(quat, joints) return dq @@ -137,6 +145,31 @@ def shift_joints_to_bones(joints, edges): return joints +def apply_root_offset(dq, shift, orient): + """Apply a root offset (orientation and translation) to joint transforms + + Args: + dq: ((..., B, 4), (..., B, 4)) Location of each joint, written as dual + quaternions + shift: (3,) Root translation offset + orient: (4,) Root orientation offset, written as a quaternion + Returns: + dq: ((..., B, 4), (..., B, 4)) Joint transforms with the root offset applied, + written as dual quaternions + """ + # normalize the quaternion + orient = F.normalize(orient, 2, dim=-1) + ndim = dq[0].ndim + shape = dq[0].shape + shift = shift.reshape((1,) * (ndim - 1) + (3,)) + shift = shift.expand(*shape[:-1], -1) + orient = orient.reshape((1,) * (ndim - 1) + (4,)) + orient = orient.expand(*shape[:-1], -1) + offset_dq = quaternion_translation_to_dual_quaternion(orient, shift) + dq = dual_quaternion_mul(offset_dq, dq) + + return dq + + def get_predefined_skeleton(skel_type): """Compute pre-defined skeletons Args: @@ -230,7 +263,7 @@ def get_predefined_skeleton(skel_type): 22: 0, # right hip 2: 1, # spine 2 3: 2, # spine 3 - 4: 3, # spine 4 + 4: 3, # head 5: 3, # left shoulder 9: 3, # right shoulder 6: 5, # left elbow diff --git a/lab4d/utils/torch_utils.py b/lab4d/utils/torch_utils.py index e478752..e150e63 100644 --- a/lab4d/utils/torch_utils.py +++ b/lab4d/utils/torch_utils.py @@ -1,5 +1,66 @@ # Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University.
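The fk_se3 docstring in lab4d/utils/skel_utils.py above summarizes the recursion as X_global = T_1...T_k x X_k: each link's local transform is left-multiplied by its parent's global transform, walking from parent to child. A toy version with plain 4x4 homogeneous matrices (an illustration of the recursion only, not the dual-quaternion code path; the two-joint chain and angles are made up):

import numpy as np

def local_T(theta, offset):
    # rotate about z at the joint, then translate to the child joint
    c, s = np.cos(theta), np.sin(theta)
    T = np.eye(4)
    T[:2, :2] = [[c, -s], [s, c]]
    T[:3, 3] = offset
    return T

# two-joint chain; edges maps each joint to its parent (0 = root), loosely following skel_utils
edges = {1: 0, 2: 1}
local = {1: local_T(np.pi / 2, [1.0, 0.0, 0.0]),
         2: local_T(0.0,       [1.0, 0.0, 0.0])}

global_T = {}
for idx, parent_idx in edges.items():
    parent_T = np.eye(4) if parent_idx == 0 else global_T[parent_idx]
    global_T[idx] = parent_T @ local[idx]  # X_global = T_parent @ T_local

print(np.round(global_T[2][:3, 3], 3))  # [1. 1. 0.]: end of the chain after the 90-degree bend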
import torch +import torch.nn as nn + + +def reinit_model(model, std=1): + for m in model.modules(): + if isinstance(m, nn.Linear): + if hasattr(m.weight, "data"): + nn.init.normal_(m.weight, mean=0.0, std=std) + if hasattr(m.bias, "data"): + m.bias.data.zero_() + + +def flip_pair(tensor): + """Flip the tensor along the pair dimension + + Args: + tensor: (M*2, ...) Inputs [x0, x1, x2, x3, ..., x_{2k}, x_{2k+1}] + + Returns: + tensor: (M*2, ...) Outputs [x1, x0, x3, x2, ..., x_{2k+1}, x_{2k}] + """ + if torch.is_tensor(tensor): + if len(tensor) < 2: + return tensor + return tensor.view(tensor.shape[0] // 2, 2, -1).flip(1).view(tensor.shape) + elif isinstance(tensor, tuple): + return tuple([flip_pair(t) for t in tensor]) + elif isinstance(tensor, dict): + return {k: flip_pair(v) for k, v in tensor.items()} + + +def compute_gradients_sdf(fn, x, training=False, sdf=None, mode="numerical", eps=1e-3): + """ + Taken from https://github.com/nvlabs/neuralangelo + """ + x = x.detach() + if mode == "analytical": + requires_grad = x.requires_grad + with torch.enable_grad(): + # 1st-order gradient + x.requires_grad_(True) + sdf = fn(x) + gradient = torch.autograd.grad(sdf.sum(), x, create_graph=True)[0] + # 2nd-order gradient (hessian) + if training: + pass + else: + gradient = gradient.detach() + x.requires_grad_(requires_grad) + elif mode == "numerical": + k1 = torch.tensor([1, -1, -1], dtype=x.dtype, device=x.device) # [3] + k2 = torch.tensor([-1, -1, 1], dtype=x.dtype, device=x.device) # [3] + k3 = torch.tensor([-1, 1, -1], dtype=x.dtype, device=x.device) # [3] + k4 = torch.tensor([1, 1, 1], dtype=x.dtype, device=x.device) # [3] + sdf1 = fn(x + k1 * eps) # [...,1] + sdf2 = fn(x + k2 * eps) # [...,1] + sdf3 = fn(x + k3 * eps) # [...,1] + sdf4 = fn(x + k4 * eps) # [...,1] + gradient = (k1 * sdf1 + k2 * sdf2 + k3 * sdf3 + k4 * sdf4) / (4.0 * eps) + return gradient + @torch.enable_grad() def compute_gradient(fn, x): @@ -26,6 +87,7 @@ def compute_gradient(fn, x): gradients = torch.cat(gradients, -1) # ...,input-dim, output-dim return gradients + def frameid_to_vid(fid, frame_offset): """Given absolute frame ids [0, ..., N], compute the video id of each frame. 
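compute_gradients_sdf above (numerical mode) estimates the SDF gradient from four tetrahedral taps around each query point, following the Neuralangelo scheme it credits. A small sanity check against the analytic normal of a sphere SDF (radius and sample points are arbitrary; assumes the repo is importable):

import torch
from lab4d.utils.torch_utils import compute_gradients_sdf

radius = 0.5
sphere_sdf = lambda x: x.norm(dim=-1, keepdim=True) - radius  # (...,1)

x = torch.randn(1024, 3)
grad_num = compute_gradients_sdf(sphere_sdf, x, mode="numerical", eps=1e-3)  # (...,3)
grad_ana = x / x.norm(dim=-1, keepdim=True)  # analytic gradient of a sphere SDF

print((grad_num - grad_ana).abs().max())  # small: the finite-difference taps match the analytic normal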
diff --git a/lab4d/utils/transforms.py b/lab4d/utils/transforms.py index c25e23d..2a6c7cb 100644 --- a/lab4d/utils/transforms.py +++ b/lab4d/utils/transforms.py @@ -21,6 +21,11 @@ def get_bone_coords(xyz, bone2obj): # reshape xyz = xyz[..., None, :].expand(xyz.shape[:-1] + (bone2obj[0].shape[-2], 3)).clone() + expand_shape = xyz.shape[:-2] + (-1, -1) + obj2bone = ( + obj2bone[0].expand(expand_shape).clone(), + obj2bone[1].expand(expand_shape).clone(), + ) xyz_bone = dual_quaternion_apply(obj2bone, xyz) return xyz_bone @@ -29,11 +34,11 @@ def get_xyz_bone_distance(xyz, bone2obj): """Compute squared distances from points to bone centers Args: - xyz: (..., 3) Points in object canonical space - bone2obj: ((..., B, 4), (..., B, 4)) Bone-to-object SE(3) transforms, written as dual quaternions + xyz: (M, 3) Points in object canonical space + bone2obj: ((M, B, 4), (M, B, 4)) Bone-to-object SE(3) transforms, written as dual quaternions Returns: - dist2: (..., B) Squared distance to each bone center + dist2: (M, B) Squared distance to each bone center """ _, center = dual_quaternion_to_quaternion_translation(bone2obj) dist2 = (xyz[..., None, :] - center).pow(2).sum(-1) # M, K diff --git a/lab4d/utils/vis_utils.py b/lab4d/utils/vis_utils.py index b482116..3f31966 100644 --- a/lab4d/utils/vis_utils.py +++ b/lab4d/utils/vis_utils.py @@ -46,6 +46,9 @@ def img2color(tag, img, pca_fn=None): if "vis2d" in tag: img = minmax_normalize(img) + + if "xy_reproj" in tag: + img = minmax_normalize(img) return img @@ -377,3 +380,65 @@ def image_to_mesh(image_path, z_displacement=0.04, mesh_scale=0.005, mesh_res=5e mesh = trimesh.Trimesh(vertices=points, faces=faces, face_colors=colors) return mesh + + +def create_plane(size, offset): + """ + Create a plane mesh spanning the x,z plane + """ + vertices = np.array( + [ + [-0.5, 0, -0.5], # vertex 0 + [0.5, 0, -0.5], # vertex 1 + [0.5, 0, 0.5], # vertex 2 + [-0.5, 0, 0.5], # vertex 3 + ] + ) + vertices = vertices * size + np.asarray(offset) + + faces = np.array( + [ + [0, 2, 1], # triangle 0 + [2, 0, 3], # triangle 1 + ] + ) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces) + return mesh + + +def create_floor_mesh(scale=20, gl=True): + # create scene + floor1 = create_plane(scale, [0, 0, 0]) + floor1.visual.vertex_colors[:, 0] = 10 + floor1.visual.vertex_colors[:, 1] = 255 + floor1.visual.vertex_colors[:, 2] = 102 + floor1.visual.vertex_colors[:, 3] = 102 + + floor2 = create_plane(scale / 4, [0, scale * 0.001, 0]) + floor2.visual.vertex_colors[:, 0] = 10 + floor2.visual.vertex_colors[:, 1] = 102 + floor2.visual.vertex_colors[:, 2] = 255 + floor2.visual.vertex_colors[:, 3] = 102 + + floor = trimesh.util.concatenate([floor1, floor2]) + if not gl: + floor.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0])) + return floor + + +# visualize all meshes +def visualize_trajectory(all_verts, tag): + meshes_concat = [] + for it, verts_pred in enumerate(all_verts[::30]): + verts_pred = verts_pred.cpu().numpy() + verts_pred[:, 0] += 1 * it + meshes_concat.append([trimesh.Trimesh(vertices=verts_pred)]) + trimesh.util.concatenate(meshes_concat).export("tmp/%s.obj" % tag) + + +def append_xz_plane(mesh, world_to_cam, scale=5, gl=True): + mesh.visual.vertex_colors = mesh.visual.vertex_colors # visual.kind = 'vertex' + # scale = np.abs(mesh.vertices).max() * 2 + plane = create_floor_mesh(scale=scale, gl=gl) + plane.apply_transform(world_to_cam) + return trimesh.util.concatenate([mesh, plane]) diff --git a/preprocess/libs/io.py index
4129e9b..c1c5589 100644 --- a/preprocess/libs/io.py +++ b/preprocess/libs/io.py @@ -109,6 +109,13 @@ def read_depth(depth_path, shape): return depth +def read_normal(normal_path, shape): + normal = np.load(normal_path).astype(np.float32) + if normal.shape[0] != shape[0] or normal.shape[1] != shape[1]: + normal = cv2.resize(normal, shape[:2][::-1], interpolation=cv2.INTER_LINEAR) + return normal + + @record_function("read_raw") def read_raw(img_path, delta, crop_size, use_full, with_flow=True): img = cv2.imread(img_path)[..., ::-1] / 255.0 @@ -120,6 +127,8 @@ def read_raw(img_path, delta, crop_size, use_full, with_flow=True): crop2raw = compute_crop_params(mask, crop_size=crop_size, use_full=use_full) depth_path = img_path.replace("JPEGImages", "Depth").replace(".jpg", ".npy") depth = read_depth(depth_path, shape) + normal_path = img_path.replace("JPEGImages", "Normal").replace(".jpg", ".npy") + normal = read_normal(normal_path, shape) is_fw = delta > 0 delta = abs(delta) @@ -148,6 +157,7 @@ def read_raw(img_path, delta, crop_size, use_full, with_flow=True): flow = cv2.remap(flow, x0, y0, interpolation=cv2.INTER_LINEAR) occ = cv2.remap(occ, x0, y0, interpolation=cv2.INTER_LINEAR) depth = cv2.remap(depth, x0, y0, interpolation=cv2.INTER_LINEAR) + normal = cv2.remap(normal, x0, y0, interpolation=cv2.INTER_LINEAR) # print('crop:%f'%(time.time()-ss)) data_dict = {} @@ -157,6 +167,7 @@ def read_raw(img_path, delta, crop_size, use_full, with_flow=True): data_dict["flow"] = flow data_dict["occ"] = occ data_dict["depth"] = depth.astype(np.float16) + data_dict["normal"] = normal.astype(np.float16) data_dict["crop2raw"] = crop2raw data_dict["hxy"] = hp_crop data_dict["hp_raw"] = hp_raw diff --git a/preprocess/libs/torch_models.py b/preprocess/libs/torch_models.py index 59c6aa8..ed07fbd 100644 --- a/preprocess/libs/torch_models.py +++ b/preprocess/libs/torch_models.py @@ -41,6 +41,14 @@ def forward(self, unary_wt=1.0, pairwise_wt=1.0): @ self.cams_canonical[self.annotated_idx, :3, :3].permute(0, 2, 1) ) + loss_unary_translation = torch.norm( + cams_pred[self.annotated_idx, :3, 3] + - self.cams_canonical[self.annotated_idx, :3, 3], + 2, + dim=-1, + ) + loss_unary = loss_unary + loss_unary_translation + # (2) relative translation should be close to procrustes cams_rel = cams_pred[1:, :3, :3] @ cams_pred[:-1, :3, :3].permute(0, 2, 1) diff --git a/preprocess/scripts/canonical_registration.py b/preprocess/scripts/canonical_registration.py index 3b49aa2..ffa962b 100644 --- a/preprocess/scripts/canonical_registration.py +++ b/preprocess/scripts/canonical_registration.py @@ -76,8 +76,9 @@ def canonical_registration(seqname, crop_size, obj_class, component_id=1): if obj_class == "other": import json, pdb - cam_path = ( - "database/processed/Cameras/Full-Resolution/%s/01-manual.json" % seqname + cam_path = "database/processed/Cameras/Full-Resolution/%s/%02d-manual.json" % ( + seqname, + component_id, ) with open(cam_path) as f: cams_canonical = json.load(f) @@ -126,30 +127,31 @@ def canonical_registration(seqname, crop_size, obj_class, component_id=1): quat, trans = registration.optimize() cams_pred = quaternion_translation_to_se3(quat, trans).cpu().numpy() - # fixed depth - cams_pred[:, :2, 3] = 0 - cams_pred[:, 2, 3] = 3 - - # compute initial camera trans with 2d bbox - # depth = focal * sqrt(surface_area / bbox_area) = focal / bbox_size - # xytrn = depth * (pxy - crop_size/2) / focal - # surface_area = 1 - for it, imgpath in enumerate(imglist): - bbox = get_bbox(imgpath, component_id=component_id) - if bbox 
is None: - continue - shape = cv2.imread(imgpath).shape[:2] - - focal = max(shape) - depth = focal / np.sqrt(bbox[2] * bbox[3]) - depth = min(depth, 10) # depth might be too large for mis-detected frames - - center_bbox = bbox[:2] + bbox[2:] / 2 - center_img = np.array(shape[::-1]) / 2 - xytrn = depth * (center_bbox - center_img) / focal - - cams_pred[it, 2, 3] = depth - cams_pred[it, :2, 3] = xytrn + if component_id == 1: + # fixed depth + cams_pred[:, :2, 3] = 0 + cams_pred[:, 2, 3] = 3 + + # compute initial camera trans with 2d bbox + # depth = focal * sqrt(surface_area / bbox_area) = focal / bbox_size + # xytrn = depth * (pxy - crop_size/2) / focal + # surface_area = 1 + for it, imgpath in enumerate(imglist): + bbox = get_bbox(imgpath, component_id=component_id) + if bbox is None: + continue + shape = cv2.imread(imgpath).shape[:2] + + focal = max(shape) + depth = focal / np.sqrt(bbox[2] * bbox[3]) + depth = min(depth, 10) # depth might be too large for mis-detected frames + + center_bbox = bbox[:2] + bbox[2:] / 2 + center_img = np.array(shape[::-1]) / 2 + xytrn = depth * (center_bbox - center_img) / focal + + cams_pred[it, 2, 3] = depth + cams_pred[it, :2, 3] = xytrn np.save("%s/%02d-canonical.npy" % (save_path, component_id), cams_pred) draw_cams(cams_pred, rgbpath_list=imglist).export( diff --git a/preprocess/scripts/crop.py b/preprocess/scripts/crop.py index aaed5f0..8ac376e 100644 --- a/preprocess/scripts/crop.py +++ b/preprocess/scripts/crop.py @@ -29,6 +29,7 @@ def extract_crop(seqname, crop_size, use_full): rgb_list = [] mask_list = [] depth_list = [] + normal_list = [] crop2raw_list = [] is_detected_list = [] @@ -51,6 +52,7 @@ def extract_crop(seqname, crop_size, use_full): rgb_list.append(data_dict0["img"]) mask_list.append(data_dict0["mask"]) depth_list.append(data_dict0["depth"]) + normal_list.append(data_dict0["normal"]) crop2raw_list.append(data_dict0["crop2raw"]) is_detected_list.append(data_dict0["is_detected"]) @@ -58,6 +60,7 @@ def extract_crop(seqname, crop_size, use_full): rgb_list.append(data_dict1["img"]) mask_list.append(data_dict1["mask"]) depth_list.append(data_dict1["depth"]) + normal_list.append(data_dict1["normal"]) crop2raw_list.append(data_dict1["crop2raw"]) is_detected_list.append(data_dict1["is_detected"]) @@ -95,6 +98,11 @@ def extract_crop(seqname, crop_size, use_full): np.stack(depth_list, 0), ) + np.save( + "database/processed/Normal/Full-Resolution/%s/%s.npy" % (seqname, save_prefix), + np.stack(normal_list, 0), + ) + np.save( "database/processed/Annotations/Full-Resolution/%s/%s-crop2raw.npy" % (seqname, save_prefix), diff --git a/preprocess/scripts/extract_frames.py b/preprocess/scripts/extract_frames.py index 612cfb1..daddd20 100644 --- a/preprocess/scripts/extract_frames.py +++ b/preprocess/scripts/extract_frames.py @@ -6,10 +6,19 @@ import numpy as np -def extract_frames(in_path, out_path): +def extract_frames(in_path, out_path, desired_fps=30): print("extracting frames: ", in_path) # Open the video file reader = imageio.get_reader(in_path) + original_fps = reader.get_meta_data()["fps"] + # If a desired frame rate is higher than original + if original_fps < desired_fps: + desired_fps = original_fps + + # If a desired frame rate is given, calculate the frame skip rate + skip_rate = 1 + if desired_fps: + skip_rate = int(original_fps / desired_fps) # Find the first non-black frame for i, im in enumerate(reader): @@ -17,10 +26,10 @@ def extract_frames(in_path, out_path): start_frame = i break - # Write the video starting from the first non-black 
frame + # Write the video starting from the first non-black frame, considering the desired frame rate count = 0 for i, im in enumerate(reader): - if i >= start_frame: + if i >= start_frame and i % skip_rate == 0: imageio.imsave("%s/%05d.jpg" % (out_path, count), im) count += 1 diff --git a/preprocess/scripts/fake_data.py b/preprocess/scripts/fake_data.py new file mode 100644 index 0000000..202a4c5 --- /dev/null +++ b/preprocess/scripts/fake_data.py @@ -0,0 +1,16 @@ +import numpy as np +import cv2 +import os +import glob + + +def create_fake_masks(seqname, outdir): + anno_dir = f"{outdir}/Annotations/Full-Resolution/{seqname}" + os.makedirs(anno_dir, exist_ok=True) + ref_list = sorted(glob.glob(f"{outdir}/JPEGImages/Full-Resolution/{seqname}/*")) + shape = cv2.imread(ref_list[0]).shape[:2] + mask = -1 * np.ones(shape).astype(np.int8) + for ref in ref_list: + img_ext = ref.split("/")[-1].split(".")[0] + save_path = "%s/%s.npy" % (anno_dir, img_ext) + np.save(save_path, mask) diff --git a/preprocess/scripts/tsdf_fusion.py b/preprocess/scripts/tsdf_fusion.py index 3ec26b7..64aa607 100644 --- a/preprocess/scripts/tsdf_fusion.py +++ b/preprocess/scripts/tsdf_fusion.py @@ -39,7 +39,9 @@ # return cam2scene -def tsdf_fusion(seqname, component_id, crop_size=256, use_full=True): +def tsdf_fusion( + seqname, component_id, crop_size=256, use_full=True, voxel_size=0.2, use_gpu=False +): # load rgb/depth imgdir = "database/processed/JPEGImages/Full-Resolution/%s" % seqname imglist = sorted(glob.glob("%s/*.jpg" % imgdir)) @@ -69,7 +71,7 @@ def tsdf_fusion(seqname, component_id, crop_size=256, use_full=True): view_frust_pts = fusion.get_view_frustum(depth, K0, cam2scene) vol_bnds[:, 0] = np.minimum(vol_bnds[:, 0], np.amin(view_frust_pts, axis=1)) vol_bnds[:, 1] = np.maximum(vol_bnds[:, 1], np.amax(view_frust_pts, axis=1)) - tsdf_vol = fusion.TSDFVolume(vol_bnds, voxel_size=0.2, use_gpu=False) + tsdf_vol = fusion.TSDFVolume(vol_bnds, voxel_size=voxel_size, use_gpu=use_gpu) # fusion for it, imgpath in enumerate(imglist[:-1]): diff --git a/preprocess/third_party/omnivision/modules/__init__.py b/preprocess/third_party/omnivision/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocess/third_party/omnivision/modules/channel_attention.py b/preprocess/third_party/omnivision/modules/channel_attention.py new file mode 100644 index 0000000..9588506 --- /dev/null +++ b/preprocess/third_party/omnivision/modules/channel_attention.py @@ -0,0 +1,130 @@ +import torch +from torch import nn + + +class ECALayer(nn.Module): + """Constructs a ECA module. 
+ Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + def __init__(self, channel, k_size=3): + super(ECALayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class ChannelAttention(nn.Module): + def __init__(self, num_features, reduction): + super(ChannelAttention, self).__init__() + self.module = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(num_features, num_features // reduction, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(num_features // reduction, num_features, kernel_size=1), + nn.Sigmoid() + ) + + def forward(self, x): + return x * self.module(x) + + +class RCAB(nn.Module): + def __init__(self, num_features, reduction): + super(RCAB, self).__init__() + self.module = nn.Sequential( + nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), + ChannelAttention(num_features, reduction) + ) + + def forward(self, x): + return x + self.module(x) + + +class RG(nn.Module): + def __init__(self, num_features, num_rcab, reduction): + super(RG, self).__init__() + self.module = [RCAB(num_features, reduction) for _ in range(num_rcab)] + self.module.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)) + self.module = nn.Sequential(*self.module) + + def forward(self, x): + return x + self.module(x) + + +class RCAN(nn.Module): + def __init__(self, scale, num_features, num_rg, num_rcab, reduction): + super(RCAN, self).__init__() + self.sf = nn.Conv2d(3, num_features, kernel_size=3, padding=1) + self.rgs = nn.Sequential(*[RG(num_features, num_rcab, reduction) for _ in range(num_rg)]) + self.conv1 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1) + self.upscale = nn.Sequential( + nn.Conv2d(num_features, num_features * (scale ** 2), kernel_size=3, padding=1), + nn.PixelShuffle(scale) + ) + self.conv2 = nn.Conv2d(num_features, 3, kernel_size=3, padding=1) + + def forward(self, x): + x = self.sf(x) + residual = x + x = self.rgs(x) + x = self.conv1(x) + x += residual + x = self.upscale(x) + x = self.conv2(x) + return x + + +class CBAMChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=16): + super(CBAMChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) + self.relu1 = nn.ReLU() + self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) + + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) + max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) + out = avg_out + max_out + return self.sigmoid(out) + + +class CBAMSpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(CBAMSpatialAttention, self).__init__() + + assert kernel_size in (3, 7), 'kernel size must be 3 or 7' + padding = 3 if kernel_size == 7 else 1 + + self.conv1 = nn.Conv2d(2, 1, 
kernel_size, padding=padding, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv1(x) + return self.sigmoid(x) diff --git a/preprocess/third_party/omnivision/modules/midas/__init__.py b/preprocess/third_party/omnivision/modules/midas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocess/third_party/omnivision/modules/midas/base_model.py b/preprocess/third_party/omnivision/modules/midas/base_model.py new file mode 100644 index 0000000..5cf4302 --- /dev/null +++ b/preprocess/third_party/omnivision/modules/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/preprocess/third_party/omnivision/modules/midas/blocks.py b/preprocess/third_party/omnivision/modules/midas/blocks.py new file mode 100644 index 0000000..2145d18 --- /dev/null +++ b/preprocess/third_party/omnivision/modules/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, 
kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. 
+ """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/preprocess/third_party/omnivision/modules/midas/dpt_depth.py b/preprocess/third_party/omnivision/modules/midas/dpt_depth.py new file mode 100644 index 0000000..738cbf5 --- /dev/null +++ b/preprocess/third_party/omnivision/modules/midas/dpt_depth.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + True, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) \ No newline at end of file diff --git a/preprocess/third_party/omnivision/modules/midas/midas_net.py b/preprocess/third_party/omnivision/modules/midas/midas_net.py 
new file mode 100644 index 0000000..8a95497 --- /dev/null +++ b/preprocess/third_party/omnivision/modules/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/preprocess/third_party/omnivision/modules/midas/midas_net_custom.py b/preprocess/third_party/omnivision/modules/midas/midas_net_custom.py new file mode 100644 index 0000000..50e4acb --- /dev/null +++ b/preprocess/third_party/omnivision/modules/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. 
+ """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/preprocess/third_party/omnivision/modules/midas/transforms.py b/preprocess/third_party/omnivision/modules/midas/transforms.py new file mode 100644 index 0000000..350cbc1 --- /dev/null +++ b/preprocess/third_party/omnivision/modules/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
+ + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if 
self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/preprocess/third_party/omnivision/modules/midas/vit.py b/preprocess/third_party/omnivision/modules/midas/vit.py new file mode 100644 index 0000000..ea46b1b --- /dev/null +++ b/preprocess/third_party/omnivision/modules/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // 
pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + 
pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + 
use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/preprocess/third_party/omnivision/modules/unet.py b/preprocess/third_party/omnivision/modules/unet.py new file mode 100644 index 0000000..a8b9d1f --- /dev/null +++ b/preprocess/third_party/omnivision/modules/unet.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .channel_attention import * + + +class UNet_up_block(nn.Module): + def __init__(self, prev_channel, input_channel, output_channel, up_sample=True, use_skip=True): + super().__init__() + self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + if use_skip: + self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3, padding=1) + else: + self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1) + self.bn1 = nn.GroupNorm(8, output_channel) + self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1) + self.bn2 = nn.GroupNorm(8, output_channel) + self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1) + self.bn3 = nn.GroupNorm(8, output_channel) + self.relu = torch.nn.ReLU() + self.up_sample = up_sample + + def forward(self, x, prev_feature_map=None): + if self.up_sample: + x = self.up_sampling(x) + if prev_feature_map is not None: + x = torch.cat((x, prev_feature_map), dim=1) + x = self.relu(self.bn1(self.conv1(x))) + x = self.relu(self.bn2(self.conv2(x))) + x = self.relu(self.bn3(self.conv3(x))) + return x + + +class UNet_down_block(nn.Module): + def __init__(self, input_channel, output_channel, down_size=True): + super().__init__() + self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1) + self.bn1 = nn.GroupNorm(8, output_channel) + self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1) + self.bn2 = nn.GroupNorm(8, output_channel) + self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1) + self.bn3 = nn.GroupNorm(8, output_channel) + self.max_pool = nn.MaxPool2d(2, 2) + self.relu = nn.ReLU() + self.down_size = down_size + + def forward(self, x): + x = self.relu(self.bn1(self.conv1(x))) + x = self.relu(self.bn2(self.conv2(x))) + x = self.relu(self.bn3(self.conv3(x))) + if self.down_size: + x = self.max_pool(x) + return x + + +class UNet(nn.Module): + def __init__(self, downsample=6, in_channels=3, out_channels=3, patch_size=1): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.downsample = downsample + self.patch_size = patch_size + + self.down1 = UNet_down_block(in_channels, 16, False) + self.down_blocks = nn.ModuleList( + [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)] + ) + + bottleneck = 2**(4 + downsample) + self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn1 = nn.GroupNorm(8, bottleneck) + self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn2 = nn.GroupNorm(8, bottleneck) + self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn3 = nn.GroupNorm(8, bottleneck) + + 
self.up_blocks = nn.ModuleList( + [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)] + ) + + self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1) + self.last_bn = nn.GroupNorm(8, 16) + self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.down1(x) + xvals = [x] + for i in range(0, self.downsample): + x = self.down_blocks[i](x) + xvals.append(x) + + x = self.relu(self.bn1(self.mid_conv1(x))) + x = self.relu(self.bn2(self.mid_conv2(x))) + x = self.relu(self.bn3(self.mid_conv3(x))) + + for i in range(0, self.downsample)[::-1]: + x = self.up_blocks[i](x, xvals[i]) + + x = self.relu(self.last_bn(self.last_conv1(x))) + x = self.last_conv2(x) + # x = F.interpolate(x, scale_factor=(1/self.patch_size, 1/self.patch_size), mode='bilinear', align_corners=False) + return x + + + + +class UNetRelu(nn.Module): + def __init__(self, downsample=6, in_channels=3, out_channels=3, patch_size=1): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.downsample = downsample + self.patch_size = patch_size + + self.down1 = UNet_down_block(in_channels, 16, False) + self.down_blocks = nn.ModuleList( + [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)] + ) + + bottleneck = 2**(4 + downsample) + self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn1 = nn.GroupNorm(8, bottleneck) + self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn2 = nn.GroupNorm(8, bottleneck) + self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn3 = nn.GroupNorm(8, bottleneck) + + self.up_blocks = nn.ModuleList( + [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i)) for i in range(0, downsample)] + ) + + self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1) + self.last_bn = nn.GroupNorm(8, 16) + self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.down1(x) + xvals = [x] + for i in range(0, self.downsample): + x = self.down_blocks[i](x) + xvals.append(x) + + x = self.relu(self.bn1(self.mid_conv1(x))) + x = self.relu(self.bn2(self.mid_conv2(x))) + x = self.relu(self.bn3(self.mid_conv3(x))) + + for i in range(0, self.downsample)[::-1]: + x = self.up_blocks[i](x, xvals[i]) + + x = self.relu(self.last_bn(self.last_conv1(x))) + x = self.last_conv2(x) + # x = F.interpolate(x, scale_factor=(1/self.patch_size, 1/self.patch_size), mode='bilinear', align_corners=False) + return self.relu(x) + + +class UNetV2(nn.Module): + def __init__(self, in_channels=3, out_channels=3, patch_size=1): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_size = patch_size + + self.down1 = UNet_down_block(in_channels, 16, False) + self.down_blocks = nn.ModuleList([ + UNet_down_block(16, 32, True), + UNet_down_block(32, 64, True), + UNet_down_block(64, 256, True), + UNet_down_block(256, 256, True), + UNet_down_block(256, 512, True), + UNet_down_block(512, 1024, True), + ]) + + bottleneck = 1024 + self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn1 = nn.GroupNorm(8, bottleneck) + self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn2 = nn.GroupNorm(8, bottleneck) + self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1) + self.bn3 = nn.GroupNorm(8, bottleneck) + + self.up_blocks = nn.ModuleList([ + UNet_up_block(512, 1024, 512), + UNet_up_block(256, 512, 256), + UNet_up_block(256, 256, 256), 
+ UNet_up_block(64, 256, 64), + UNet_up_block(32, 64, 32), + UNet_up_block(16, 32, 16), + ]) + + self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1) + self.last_bn = nn.InstanceNorm2d(16) + self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0) + self.attention = ECALayer(out_channels, k_size=7) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.down1(x) + xvals = [x] + for down_block in self.down_blocks: + x = down_block(x) + xvals.append(x) + + x = self.relu(self.bn1(self.mid_conv1(x))) + x = self.relu(self.bn2(self.mid_conv2(x))) + x = self.relu(self.bn3(self.mid_conv3(x))) + + for up_block, xval in zip(self.up_blocks, xvals[::-1][1:len(self.up_blocks)+1]): + x = up_block(x, xval) + + x = self.relu(self.last_bn(self.last_conv1(x))) + x = self.last_conv2(x) + x = self.attention(x) + x = F.interpolate(x, scale_factor=(1/self.patch_size, 1/self.patch_size), mode='bilinear', align_corners=False) + return x + diff --git a/preprocess/third_party/omnivision/normal.py b/preprocess/third_party/omnivision/normal.py new file mode 100644 index 0000000..b0915ef --- /dev/null +++ b/preprocess/third_party/omnivision/normal.py @@ -0,0 +1,102 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +# python preprocess/third_party/omnivision/normal.py cat-pikachu-0-0000 +import pdb +import glob +import os +import sys + +import numpy as np +import torch +import cv2 + +sys.path.insert( + 0, + "%s/../../" % os.path.join(os.path.dirname(__file__)), +) + +from libs.utils import resize_to_target + +sys.path.insert( + 0, + "%s/" % os.path.join(os.path.dirname(__file__)), +) +from modules.midas.dpt_depth import DPTDepthModel + + +def load_model(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + pretrained_weights_path = ( + "./preprocess/third_party/omnivision/omnidata_dpt_normal_v2_cleaned.ckpt" + ) + model = DPTDepthModel(backbone="vitb_rn50_384", num_channels=3) # DPT Hybrid + checkpoint = torch.load(pretrained_weights_path) + if "state_dict" in checkpoint: + state_dict = {} + for k, v in checkpoint["state_dict"].items(): + state_dict[k[6:]] = v + else: + state_dict = checkpoint + + model.load_state_dict(state_dict) + model.to(device) + return model + + +@torch.no_grad() +def predict_normal(model, img): + # resize + testres = np.sqrt(2e5 / (img.shape[0] * img.shape[1])) + maxh = img.shape[0] * testres + maxw = img.shape[1] * testres + max_h = int(maxh // 64 * 64) + max_w = int(maxw // 64 * 64) + if max_h < maxh: + max_h += 64 + if max_w < maxw: + max_w += 64 + + input_size = img.shape + img = cv2.resize(img, (max_w, max_h)) + img = np.transpose(img, [2, 0, 1])[None] + img_tensor = torch.Tensor(img / 255.0).cuda() + if img_tensor.shape[1] == 1: + img_tensor = img_tensor.repeat_interleave(3, 1) + + output = model(img_tensor).clamp(min=0, max=1) + normal = output[0].permute(1, 2, 0).cpu().numpy() + normal = cv2.resize( + normal, (input_size[1], input_size[0]), interpolation=cv2.INTER_LINEAR + ) + + return normal + + +def extract_normal(seqname): + image_dir = "database/processed/JPEGImages/Full-Resolution/%s/" % seqname + output_dir = image_dir.replace("JPEGImages", "Normal") + os.makedirs(output_dir, exist_ok=True) + + model = load_model() + + for img_path in sorted(glob.glob(f"{image_dir}/*.jpg")): + # print(img_path) + img = cv2.imread(img_path)[..., ::-1] + normal = predict_normal(model, img) + normal = resize_to_target(normal, is_flow=False) + normal = 2 * normal - 1 # [-1, 1] + normal = normal / (1e-6 + np.linalg.norm(normal, 2, -1)[..., None]) + + 
out_path = f"{output_dir}/{os.path.basename(img_path).replace('.jpg', '.npy')}" + np.save(out_path, normal.astype(np.float16)) + vis_path = f"{output_dir}/vis-{os.path.basename(img_path)}" + normal = (normal + 1) / 2 + cv2.imwrite(vis_path, normal[..., ::-1] * 255) + + print("surface normal saved to %s" % output_dir) + + +if __name__ == "__main__": + seqname = sys.argv[1] + + extract_normal(seqname) diff --git a/projects/ppr/config.py b/projects/ppr/config.py new file mode 100644 index 0000000..4f9dd97 --- /dev/null +++ b/projects/ppr/config.py @@ -0,0 +1,40 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +import os + +from absl import flags + +opts = flags.FLAGS + + +class PPRConfig: + # configs related to ppr + flags.DEFINE_string("urdf_template", "", "whether to use predefined skeleton") + flags.DEFINE_float("timestep", 1e-3, "time step of simulation") + flags.DEFINE_float("frame_interval", 0.1, "time between two frames") + flags.DEFINE_float("ratio_phys_cycle", 0.5, "number of iterations per round") + flags.DEFINE_float("secs_per_wdw", 2.4, "length of the physics opt window in secs") + flags.DEFINE_string( + "phys_vid", "0", "whether to optimize selected videos, e.g., 0,1,2" + ) + flags.DEFINE_integer("phys_vis_interval", 20, "visualization interval") + flags.DEFINE_integer("warmup_iters", 0, "warmup iterations, only >0 for DR+DP") + flags.DEFINE_float("phys_learning_rate", 5e-4, "learning rate") + flags.DEFINE_float("noise_std", 2e-3, "noise std added to initial states") + + # weights + flags.DEFINE_float("traj_wt", 0.01, "weight for traj matching loss") + flags.DEFINE_float("pos_state_wt", 2e-4, "weight for position matching reg") + flags.DEFINE_float("vel_state_wt", 0.0, "weight for velocity matching reg") + flags.DEFINE_float("pos_distill_wt", 0.1, "weight for distilling proxy kinematics") + + # regs + flags.DEFINE_float("reg_torque_wt", 0.0, "weight for torque regularization") + flags.DEFINE_float("reg_res_f_wt", 0.0, "weight for residual force regularization") + flags.DEFINE_float("reg_foot_wt", 0.0, "weight for foot contact regularization") + flags.DEFINE_float("reg_root_wt", 0.0, "weight for root pose regularization") + flags.DEFINE_float("reg_phys_q_wt", 0.1, "weight for soft physics regularization") + flags.DEFINE_float("reg_phys_ja_wt", 0.02, "weight for soft physics regularization") + + # io-related + flags.DEFINE_string("load_path_bg", "", "path to load pretrained model") + flags.DEFINE_string("load_suffix_phys", "", "suffix of params, {latest, 0, 10, ...}") diff --git a/projects/ppr/eval/compute_metrics.py b/projects/ppr/eval/compute_metrics.py new file mode 100644 index 0000000..81738cc --- /dev/null +++ b/projects/ppr/eval/compute_metrics.py @@ -0,0 +1,163 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. 
+# python projects/ppr/eval/compute_metrics.py --pred_prefix "" --fps 10 --skip 3 +# python projects/ppr/eval/compute_metrics.py +import sys, os +import pdb +import json +import glob +import numpy as np +import argparse +import trimesh +import tqdm + +sys.path.insert(0, os.getcwd()) +from eval_utils import load_ama_intrinsics, ama_eval + + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) +from lab4d.utils.io import save_vid +from lab4d.utils.pyrender_wrapper import PyRenderWrapper +from lab4d.utils.vis_utils import append_xz_plane + +parser = argparse.ArgumentParser(description="script to render extraced meshes") +parser.add_argument( + "--testdir", + default="logdir/ama-samba-4v-fg-bob-2g-r120-cse/export_0000/", + help="path to the directory with results", +) +parser.add_argument( + "--gt_seq", default="T_samba-1", help="name of the ground-truth sequqnce" +) +parser.add_argument( + "--pred_prefix", + default="/data3/gengshay/eval/humor/T_samba1", + help="name of the pred sequqnce", +) +parser.add_argument("--skip", default=1, type=int, help="pred mesh has n times less") +parser.add_argument("--fps", default=30, type=int, help="fps of the video") +args = parser.parse_args() + + +def main(): + ama_path = "database/ama/" + # gt + gt_name, gt_cam_id = args.gt_seq.split("-") + gt_cam_path = "%s/%s/calibration/Camera%s.Pmat.cal" % (ama_path, gt_name, gt_cam_id) + intrinsics_gt, Gmat_gt = load_ama_intrinsics(gt_cam_path) + + # glob meshes + gt_mesh_dir = "%s/%s/meshes/" % (ama_path, gt_name) + gt_mesh_dict = {} + for fidx, path in enumerate(sorted(glob.glob("%s/mesh_*.obj" % (gt_mesh_dir)))): + if fidx % args.skip == 0: + gt_mesh_dict[fidx] = trimesh.load(path, process=False) + gt_mesh_dict[fidx].apply_transform(Gmat_gt) + if len(gt_mesh_dict) == 0: + print("no mesh found that matches %s*" % (args.testdir)) + return + print("found %d groune-truth meshes" % (len(gt_mesh_dict))) + + # pred (from lab4d) + camera_info = json.load(open("%s/camera.json" % (args.testdir), "r")) + raw_size = camera_info["raw_size"] # h,w + + # glob predicted meshes (from either lab4d or other methods) + if args.pred_prefix == "": + pred_prefix = "%s/fg/mesh/" % (args.testdir) # use lab4d + pred_mesh_paths = glob.glob("%s*.obj" % (pred_prefix)) + intrinsics = np.asarray(camera_info["intrinsics"], dtype=np.float32) + # transform to view coord + extrinsics = json.load(open("%s/fg/motion.json" % (args.testdir), "r")) + extrinsics = np.asarray(extrinsics["field2cam"]) + + if os.path.exists("%s/bg/motion.json" % (args.testdir)): + extrinsics_bg = json.load(open("%s/bg/motion.json" % (args.testdir), "r")) + extrinsics_bg = np.asarray(extrinsics_bg["field2cam"]) + + # align bg floor with xz plane + field2world_path = "%s/bg/field2world.json" % (args.testdir) + field2world = np.asarray(json.load(open(field2world_path, "r"))) + world2field = np.linalg.inv(field2world) + else: + pred_mesh_paths = glob.glob("%s*.obj" % (args.pred_prefix)) + pred_camera_paths = sorted(glob.glob("%s*.txt" % (args.pred_prefix))) + cameras = np.stack([np.loadtxt(i) for i in pred_camera_paths], 0) + intrinsics = cameras[:, 3] + extrinsics = np.repeat(np.eye(4)[None], len(pred_mesh_paths), axis=0) + extrinsics[:, :3] = cameras[:, :3] + pred_mesh_dict = {} + for fidx, mesh_path in enumerate(sorted(pred_mesh_paths)): + fidx = int(mesh_path.split("/")[-1].split("-")[-1].split(".")[0]) + pred_mesh_dict[args.skip * fidx] = trimesh.load(mesh_path, process=False) + pred_mesh_dict[args.skip * fidx].apply_transform(extrinsics[fidx]) + # 
pred_mesh_dict[args.skip * fidx].apply_transform(np.linalg.inv(Gmat_gt)) + assert len(pred_mesh_dict) == len(gt_mesh_dict) + + if os.path.exists("%s/bg/motion.json" % (args.testdir)): + pred_mesh_paths_bg = glob.glob("%s/bg/mesh/*.obj" % (args.testdir)) + pred_mesh_dict_bg = {} + for fidx, mesh_path in enumerate(sorted(pred_mesh_paths_bg)): + fidx = int(mesh_path.split("/")[-1].split("-")[-1].split(".")[0]) + pred_mesh_dict_bg[args.skip * fidx] = trimesh.load(mesh_path, process=False) + pred_mesh_dict_bg[args.skip * fidx].apply_transform(extrinsics_bg[fidx]) + + # evaluate + # ama_eval(all_verts_gt, all_verts_gt, verbose=True) + ( + cd_avg, + f010_avg, + f005_avg, + f002_avg, + pred_mesh_dict, + pred_cd_dict, + gt_cd_dict, + ) = ama_eval(pred_mesh_dict, gt_mesh_dict, verbose=True) + + # render + renderer_gt = PyRenderWrapper(raw_size) + renderer_pred = PyRenderWrapper(raw_size) + frames = [] + for fidx, mesh_obj in tqdm.tqdm(gt_mesh_dict.items(), desc=f"Rendering:"): + # world_to_cam_pred = extrinsics_bg[fidx] @ world2field + mesh_obj = append_xz_plane(mesh_obj, Gmat_gt) + gt_cd_dict[fidx] = append_xz_plane(gt_cd_dict[fidx], Gmat_gt) + pred_mesh_dict[fidx] = append_xz_plane(pred_mesh_dict[fidx], Gmat_gt) + pred_cd_dict[fidx] = append_xz_plane(pred_cd_dict[fidx], Gmat_gt) + # pred_mesh_dict[fidx] = trimesh.util.concatenate( + # [pred_mesh_dict[fidx], pred_mesh_dict_bg[fidx]] + # ) + # mesh_obj.export("tmp/0.obj") + # pred_mesh_dict[fidx].export("tmp/1.obj") + # pdb.set_trace() + + # renderer_gt.set_camera_frontal(4, gl=True) + renderer_gt.set_intrinsics(intrinsics_gt) + renderer_gt.align_light_to_camera() + color_gt = renderer_gt.render({"shape": mesh_obj})[0] + cd_gt = renderer_gt.render({"shape": gt_cd_dict[fidx]})[0] + + # renderer_pred.set_camera_frontal(4, gl=True) + renderer_pred.set_intrinsics(intrinsics[fidx // args.skip]) + renderer_pred.align_light_to_camera() + color_pred = renderer_pred.render({"shape": pred_mesh_dict[fidx]})[0] + cd_pred = renderer_pred.render({"shape": pred_cd_dict[fidx]})[0] + + color = np.concatenate([color_gt, color_pred], axis=1) + cd = np.concatenate([cd_gt, cd_pred], axis=1) + final = np.concatenate([color, cd], axis=0) + frames.append(final.astype(np.uint8)) + + save_vid( + "%s/render" % args.testdir, + frames, + suffix=".mp4", + upsample_frame=-1, + fps=args.fps, + ) + print("saved to %s/render.mp4" % args.testdir) + + +if __name__ == "__main__": + main() diff --git a/projects/ppr/eval/compute_phys_metrics.py b/projects/ppr/eval/compute_phys_metrics.py new file mode 100644 index 0000000..88cd08f --- /dev/null +++ b/projects/ppr/eval/compute_phys_metrics.py @@ -0,0 +1,110 @@ +# """ +# WIP: compute physical metrics for the predicted keypoints +# python scripts/eval/eval_phys.py logdir/bgnerf-new-human_mod-ama-d-e120-b96-pft3/D_handstand5-kps.npy +# """ +# from absl import flags, app +# import cv2 +# import trimesh +# import json +# import glob +# import sys + +# sys.path.insert(0, "") +# sys.path.insert(0, "third_party") +# import configparser +# import numpy as np +# import pdb +# import imageio +# import pyrender +# from scipy.spatial.transform import Rotation as R +# import torch + +# from nnutils.geom_utils import vec_to_sim3, optimize_scale, fit_plane_contact +# from nnutils.geom_utils import ( +# extract_mesh, +# zero_to_rest_bone, +# zero_to_rest_dpose, +# skinning, +# lbs, +# se3_vec2mat, +# ) +# from nnutils.urdf_utils import articulate_robot, angles2cfg +# from utils.io import vis_kps, draw_cams +# from utils.io import save_vid, str_to_frame, 
save_bones +# from nnutils.train_utils import v2s_trainer +# from sklearn.metrics import f1_score + +# kp_path = sys.argv[1] +# seqname = kp_path.split("/")[-1].rsplit("-", 1)[0] +# if "D_bouncing" in kp_path or "D_handstand" in kp_path or "T_samba" in kp_path: +# seqname = seqname[:-1] +# is_ama = True +# else: +# is_ama = False # assume cat +# frame_duration = 1 / 30.0 + + +# def main(_): +# # get numbers +# kps = np.load(kp_path) +# contact_labels = np.loadtxt("misc/gt_contact/%s-contact.txt" % seqname) +# n_kp = kps.shape[1] +# n_fr = min(kps.shape[0], contact_labels.shape[0]) + +# contact_labels = contact_labels[:n_fr] +# kps = kps[:n_fr] + +# pdb.set_trace() +# if is_ama: +# contact_pred = np.abs(kps[..., 1]) < 0.1 +# else: +# contact_pred = np.abs(kps[..., 1]) < 0.2 +# f1 = f1_score(contact_labels.flatten(), contact_pred.flatten()) + +# # TODO visualize +# try: +# kps_gt = np.load("misc/gt_contact/%s-kps.npy" % seqname) +# kps_gt = kps_gt[:n_fr] +# vis_kps( +# np.transpose(kps_gt, [0, 2, 1]), +# "tmp/kps_gt.obj", +# binary_labels=contact_labels, +# ) +# except: +# pass +# vis_kps( +# np.transpose(kps, [0, 2, 1]), +# "tmp/kps_pred.obj", +# binary_labels=contact_pred == contact_labels, +# ) + +# ## jerk: this is not accurate due to finite difference +# # kps_vel = (kps[2:] - kps[:-2]) / (2*frame_duration) +# # kps_acc = (kps_vel[2:] - kps_vel[:-2]) / (2*frame_duration) +# ##kps_acn = np.linalg.norm(kps_acc, 2,-1).mean() +# # kps_jrk = np.linalg.norm(kps_acc[1:] - kps_acc[:-1], 2,-1).mean() +# # +# # kps_vel_gt = (kps_gt[2:] - kps_gt[:-2]) / (2*frame_duration) +# # kps_acc_gt = (kps_vel_gt[2:] - kps_vel_gt[:-2]) / (2*frame_duration) +# ##kps_acn_gt = np.linalg.norm(kps_acc_gt, 2,-1).mean() +# # kps_jrk_gt = np.linalg.norm(kps_acc_gt[1:] - kps_acc_gt[:-1], 2,-1).mean() + +# # pdb.set_trace() +# # acc_err = np.abs(kps_acn - kps_acn_gt).mean() + +# # skate +# move_dis = np.linalg.norm(kps[1:] - kps[:-1], 2, -1) +# move_dis_all = move_dis[np.logical_and(contact_labels[:-1], contact_labels[1:])] + +# state_5cm = (np.asarray(move_dis_all) > 0.05).mean() +# state_ave = np.asarray(move_dis_all).mean() + +# # print('tv: %.1f'%(kps_jrk)) +# # print('tv-gt: %.1f'%(kps_jrk_gt)) +# print("f1: %.1f" % (f1 * 100)) +# print("sk-5cm: %.1f" % (state_5cm * 100)) +# print("sk-ave: %.1f" % (state_ave * 100)) + + +# if __name__ == "__main__": +# app.run(main) diff --git a/projects/ppr/eval/eval_utils.py b/projects/ppr/eval/eval_utils.py new file mode 100644 index 0000000..6b63364 --- /dev/null +++ b/projects/ppr/eval/eval_utils.py @@ -0,0 +1,489 @@ +"""utility functions for evaluation. 
some functions are taken from DASR: +https://github.com/jefftan969/dasr/blob/main/eval_utils.py#L86 + +pip install -e third_party/ChamferDistancePytorch/chamfer3D/ +""" + +import sys, os +import pdb +import trimesh +from copy import deepcopy +import cv2 +import numpy as np +import torch +import tqdm +from matplotlib import pyplot as plt + +cmap = plt.get_cmap("plasma") + +sys.path.insert( + 0, + "%s/third_party/ChamferDistancePytorch/" % os.path.join(os.path.dirname(__file__)), +) + + +from chamfer3D.dist_chamfer_3D import chamfer_3DDist + +import pytorch3d +from pytorch3d.ops.knn import _KNN +from pytorch3d.ops.points_alignment import ( + ICPSolution, + SimilarityTransform, + corresponding_points_alignment, + _apply_similarity_transform, +) +from pytorch3d.ops.utils import wmean + +from lab4d.utils.vis_utils import visualize_trajectory + + +def load_ama_intrinsics(path): + pmat = np.loadtxt(path) + K, R, T, _, _, _, _ = cv2.decomposeProjectionMatrix(pmat) + Rmat_gt = R + Tmat_gt = T[:3, 0] / T[-1, 0] + Tmat_gt = Rmat_gt.dot(-Tmat_gt[..., None])[..., 0] + K = K / K[-1, -1] + intrinscs_gt = np.asarray([K[0, 0], K[1, 1], K[0, 2], K[1, 2]]) + Gmat_gt = np.eye(4) + Gmat_gt[:3, :3] = Rmat_gt + Gmat_gt[:3, 3] = Tmat_gt + return intrinscs_gt, Gmat_gt + + +def ama_eval( + pred_mesh_dict, + gt_mesh_dict, + verbose=False, + device="cuda", + shape_scale=1, +): + """Evaluate a sequence of AMA videos + Modified from DASR: https://github.com/jefftan969/dasr/blob/main/eval_utils.py#L86 + + Args + load_dir [str]: Directory to load predicted meshes from + seqname [str]: Name of sequence (e.g. T_samba) + vidid [int]: Video identifier (e.g. 1) + verbose [bool]: Whether to print eval metrics + render_vid [str]: If provided, output an error video to this path + + Returns: + cd_avg [float]: Chamfer distance (cm), averaged across all frames + f010_avg [float]: F-score at 10cm threshold, averaged across all frames + f005_avg [float]: F-score at 5cm threshold, averaged across all frames + f002_avg [float]: F-score at 2cm threshold, averaged across all frames + """ + all_verts_pred = [v.vertices for k, v in pred_mesh_dict.items()] + all_verts_gt = [v.vertices for k, v in gt_mesh_dict.items()] + + # Evaluate metrics: chamfer distance and f-score (@10cm, @5cm, @2cm) + nframes = len(all_verts_gt) + metrics = torch.zeros(nframes, 4, dtype=torch.float32, device=device) # nframes, 4 + all_verts_pred = torch.tensor(all_verts_pred, device=device, dtype=torch.float32) + all_verts_gt = torch.tensor(all_verts_gt, device=device, dtype=torch.float32) + + # visualize_trajectory(all_verts_pred, "pred") + # global sim3 alignment, translation, rotation, scale + all_verts_pred = align_seqs( + all_verts_pred[:, None], + all_verts_gt[:, None], + align_se3=True, + verbose=verbose, + ) + all_verts_pred = [x[0] for x in all_verts_pred] + # visualize_trajectory(all_verts_pred, "aligned") + # visualize_trajectory(all_verts_gt, "gt") + # pdb.set_trace() + + chamLoss = chamfer_3DDist() + pred_cd_list = [] + gt_cd_list = [] + for idx in tqdm.trange(nframes, desc=f"Evaluating:"): + raw_cd_fw, raw_cd_bw, _, _ = chamLoss( + all_verts_gt[idx][None], all_verts_pred[idx][None] + ) # 1, npts_gt | 1, npts_pred + raw_cd_fw = raw_cd_fw.squeeze(0) # npts_gt + raw_cd_bw = raw_cd_bw.squeeze(0) # npts_pred + pred_cd_list.append(raw_cd_bw) + gt_cd_list.append(raw_cd_fw) + + cd = torch.mean(torch.sqrt(raw_cd_fw)) + torch.mean(torch.sqrt(raw_cd_bw)) + f010, _, _ = fscore(raw_cd_fw, raw_cd_bw, threshold=(shape_scale * 0.10) ** 2) + f005, _, _ = 
fscore(raw_cd_fw, raw_cd_bw, threshold=(shape_scale * 0.05) ** 2) + f002, _, _ = fscore(raw_cd_fw, raw_cd_bw, threshold=(shape_scale * 0.02) ** 2) + + metrics[idx, 0] = cd + metrics[idx, 1] = f010 + metrics[idx, 2] = f005 + metrics[idx, 3] = f002 + + if verbose: + print( + f"Frame {idx}: CD={100 * cd:.2f}cm, f@10cm={100 * f010:.1f}%, " + f"f@5cm={100 * f005:.1f}%, f@2cm={100 * f002:.1f}%" + ) + + metrics = torch.mean(metrics, dim=0) # 4, + cd_avg, f010_avg, f005_avg, f002_avg = tuple(float(x) for x in metrics) + + if verbose: + print(f"Finished evaluation") + print(f" Avg chamfer dist: {100 * cd_avg:.2f}cm") + print(f" Avg f-score at d=10cm: {100 * f010_avg:.1f}%") + print(f" Avg f-score at d=5cm: {100 * f005_avg:.1f}%") + print(f" Avg f-score at d=2cm: {100 * f002_avg:.1f}%") + + # assign aligned vertices + for fidx in pred_mesh_dict.keys(): + pred_mesh_dict[fidx].vertices = all_verts_pred[fidx].cpu().numpy() + + pred_cd_dict = deepcopy(pred_mesh_dict) + gt_cd_dict = deepcopy(gt_mesh_dict) + vis_err_max = 0.02 # 2cm + for idx, fidx in tqdm.tqdm(enumerate(pred_cd_dict.keys()), desc=f"Evaluating:"): + pred_cd_dict[fidx].visual.vertex_colors = 255 * cmap( + pred_cd_list[idx].cpu().numpy() / vis_err_max + ) + gt_cd_dict[fidx].visual.vertex_colors = 255 * cmap( + gt_cd_list[idx].cpu().numpy() / vis_err_max + ) + + return ( + cd_avg, + f010_avg, + f005_avg, + f002_avg, + pred_mesh_dict, + pred_cd_dict, + gt_cd_dict, + ) + + +def fscore(dist1, dist2, threshold=0.001): + """ + Calculates the F-score between two point clouds with the corresponding threshold value. + modified from https://github.com/ThibaultGROUEIX/ChamferDistancePytorch + :param dist1: N-Points + :param dist2: N-Points + :param th: float + :return: fscore, precision, recall + """ + # NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly. 
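+    # e.g., with meshes in meters, a 5 cm threshold is passed as threshold=0.05**2,
+    # which is what the ama_eval caller above does (scaled by shape_scale)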
+ precision_1 = torch.mean((dist1 < threshold).float()) + precision_2 = torch.mean((dist2 < threshold).float()) + fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2) + fscore[torch.isnan(fscore)] = 0 + return fscore, precision_1, precision_2 + + +def timeseries_pointclouds_to_tensor(X_pts): + """Convert a time series of variable-length point clouds to a padded tensor + + Args: + X_pts [List(bs, npts[t], dim)]: List of length T, containing a + time-series of variable-length point cloud batches + + Returns: + X [bs, T, npts, dim]: Padded pointcloud tensor + num_points_X [bs, T]: Number of points in each point cloud + """ + bs, _, dim = X_pts[0].shape + T = len(X_pts) + device = X_pts[0].device + + num_points_X = torch.tensor( + [X_pts[t].shape[1] for t in range(T)], dtype=torch.int64 + ) # T, + num_points_X = num_points_X.to(device)[None, :].repeat(bs, 1) # bs, T + + npts = torch.max(num_points_X) + X = X_pts[0].new_zeros(bs, T, npts, dim) # bs, T, npts, dim + for t in range(T): + npts_t = X_pts[t].shape[1] + X[:, t, :npts_t] = X_pts[t] # bs, T, npts[t], dim + + return X, num_points_X + + +def timeseries_iterative_closest_point( + X_pts, + Y_pts, + init_transform=None, + max_iterations=100, + relative_rmse_thr=1e-6, + estimate_scale=False, + allow_reflection=False, + verbose=False, +): + """Execute the ICP algorithm to find a similarity transform (R, T, s) + between two time series of differently-sized point clouds + + Args: + X_pts [List(bs, npts[t], dim)]: Time-series of variable-length + point cloud batches + Y_pts [List(bs, npts[t], dim)]: Time-series of variable-length + point cloud batches + init_transform [SimilarityTransform]: If provided, initialization for + the similarity transform, containing orthonormal matrices + R [bs, dim, dim], translations T [bs, dim], and scaling s[bs,] + max_iterations (int): Maximum number of ICP iterations + relative_rmse_thr (float): Threshold on relative root mean square error + used to terminate the algorithm + estimate_scale (bool): If True, estimate a scaling component of the + transformation, otherwise assume identity scale + allow_reflection (bool): If True, allow algorithm to return `R` + which is orthonormal but has determinant -1 + verbose: If True, print status messages during each ICP iteration + + Returns: ICPSolution with the following fields + converged (bool): Boolean flag denoting whether the algorithm converged + rmse (float): Attained root mean squared error after termination + Xt [bs, T, size_X, dim]: Point cloud X transformed with final similarity + transformation (R, T, s) + RTs (SimilarityTransform): Named tuple containing a batch of similarity transforms: + R [bs, dim, dim] Orthonormal matrices + T [bs, dim]: Translations + s [bs,]: Scaling factors + t_history (list(SimilarityTransform)): List of similarity transform + parameters after each ICP iteration + """ + # Convert input Pointclouds structures to padded tensors + X, num_points_X = timeseries_pointclouds_to_tensor( + X_pts + ) # bs, T, size_X, dim | bs, T + Y, num_points_Y = timeseries_pointclouds_to_tensor( + Y_pts + ) # bs, T, size_Y, dim | bs, T + + if ( + (X.shape[3] != Y.shape[3]) + or (X.shape[1] != Y.shape[1]) + or (X.shape[0] != Y.shape[0]) + ): + raise ValueError( + "X and Y should have same number of batch, time, and data dimensions" + ) + bs, T, size_X, dim = X.shape + bs, T, size_Y, dim = Y.shape + + # Handle heterogeneous input + if ((num_points_Y < size_Y).any() or (num_points_X < size_X).any()) and ( + num_points_Y != num_points_X + ).any(): 
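+        # point clouds have different lengths: build a per-point validity mask
+        # (1 for real points, 0 for padding); it is used as the weights for the
+        # alignment and for the rmse computation below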
+ mask_X = ( + torch.arange(size_X, dtype=torch.int64, device=X.device)[None, None, :] + < num_points_X[:, :, None] + ).type_as( + X + ) # bs, T, size_X + else: + mask_X = X.new_ones(bs, T, size_X) # bs, T, size_X + + X = X.reshape(bs, T * size_X, dim) # bs, T*size_X, dim + Y = Y.reshape(bs, T * size_Y, dim) # bs, T*size_Y, dim + mask_X = mask_X.reshape(bs, T * size_X) # bs, T*size_X + + # Clone the initial point cloud + X_init = X.clone() # bs, T*size_X, dim + + # Initialize transformation with identity + sim_R = torch.eye(dim, device=X.device, dtype=X.dtype)[None].repeat( + bs, 1, 1 + ) # bs, 3, 3 + sim_T = X.new_zeros((bs, dim)) # bs, dim + sim_s = X.new_ones(bs) # bs, + + prev_rmse = None + rmse = None + iteration = -1 + converged = False + t_history = [] + + # Main loop over ICP iterations + for iteration in range(max_iterations): + X_nn_points = timeseries_knn_points( + X.reshape(bs, T, size_X, dim), + Y.reshape(bs, T, size_Y, dim), + lengths_X=num_points_X, + lengths_Y=num_points_Y, + K=1, + return_nn=True, + ).knn[:, :, 0, :] + + # Get alignment of nearest neighbors from Y with X_init + sim_R, sim_T, sim_s = corresponding_points_alignment( + X_init, + X_nn_points, + weights=mask_X, + estimate_scale=estimate_scale, + allow_reflection=allow_reflection, + ) + + # Apply the estimated similarity transform to X_init + X = _apply_similarity_transform(X_init, sim_R, sim_T, sim_s) + + # Add current transformation to history + t_history.append(SimilarityTransform(sim_R, sim_T, sim_s)) + + # Compute root mean squared error + X_sq_diff = torch.sum((X - X_nn_points) ** 2, dim=2) + rmse = wmean(X_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0] + + # Compute relative rmse change + if prev_rmse is None: + relative_rmse = rmse.new_ones(bs) + else: + relative_rmse = (prev_rmse - rmse) / prev_rmse + + if verbose: + print( + f"ICP iteration {iteration}: mean/max rmse={rmse.mean():1.2e}/{rmse.max():1.2e}; " + f"mean relative rmse={relative_rmse.mean():1.2e}" + ) + + # Check for convergence + if (relative_rmse <= relative_rmse_thr).all(): + converged = True + break + + # Update the previous rmse + prev_rmse = rmse + + X = X.reshape(bs, T, size_X, dim) # bs, T, size_X, dim + return ICPSolution( + converged, rmse, X, SimilarityTransform(sim_R, sim_T, sim_s), t_history + ) + + +def align_seqs(all_verts_pred, all_verts_gt, align_se3=True, verbose=False): + """Align predicted mesh sequence to the ground-truths + Taken from DASR: https://github.com/jefftan969/dasr/blob/main/eval_utils.py#L86 + + + Args: + all_verts_pred (List(bs, npts[t], 3)): Time-series of predicted mesh batches + all_verts_gt (List(bs, npts[t], 3)): Time-series of ground-truth mesh batches + verbose (bool): Whether to print ICP results + + Returns: + out_verts_pred (List(bs, npts[t], 3)): Time-series of aligned predicted mesh batches + """ + device = all_verts_pred[0].device + nframes = len(all_verts_pred) + + # Compute coarse scale estimate (in the correct order of magnitude) + fitted_scale = torch.zeros(nframes, dtype=torch.float32, device=device) # nframes, + for i in range(nframes): + verts_pred = all_verts_pred[i] # 1, npts_pred, 3 + verts_gt = all_verts_gt[i] # 1, npts_gt, 3 + fitted_scale[i] = ( + torch.max(verts_gt[..., -1]) + torch.min(verts_gt[..., -1]) + ) / (torch.max(verts_pred[..., -1]) + torch.min(verts_pred[..., -1])) + fitted_scale = torch.mean(fitted_scale) + + out_verts_pred = [verts_pred * fitted_scale for verts_pred in all_verts_pred] + + if align_se3: + # Use ICP to align the first frame and fine-tune the scale estimate 
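+        # (note: the first-frame ICP below keeps estimate_scale=False, so only
+        # rotation/translation are refined and the coarse scale above is kept)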
+ # scale estimation with ICP is not reliable + frts0 = timeseries_iterative_closest_point( + out_verts_pred[:1], + all_verts_gt[:1], + estimate_scale=False, + max_iterations=100, + verbose=verbose, + ) + R_icp0, T_icp0, s_icp0 = frts0.RTs # 1, 3, 3 | 1, 3 | 1, 1 + + for i in range(nframes): + out_verts_pred[i] = _apply_similarity_transform( + out_verts_pred[i], R_icp0, T_icp0, s_icp0 + ) + + # Run global ICP across the point cloud time-series + frts = timeseries_iterative_closest_point( + out_verts_pred, + all_verts_gt, + estimate_scale=True, + max_iterations=100, + verbose=verbose, + ) + R_icp, T_icp, s_icp = frts.RTs # 1, 3, 3 | 1, 3 | 1, 1 + + for i in range(nframes): + out_verts_pred[i] = _apply_similarity_transform( + out_verts_pred[i], R_icp, T_icp, s_icp + ) + + return out_verts_pred + + +def timeseries_knn_points( + X, + Y, + lengths_X=None, + lengths_Y=None, + K=1, + version=-1, + return_nn=False, + return_sorted=True, +): + """K-nearest neighbors on two time series of point clouds. + + Args: + X [bs, T, size_X, dim]: A batch of `bs` time series, each with `T` + point clouds containing `size_X` points of dimension `dim` + Y [bs, T, size_Y, dim]: A batch of `bs` time series, each with `T` + point clouds containing `size_Y` points of dimension `dim` + lengths_X [bs, T]: Length of each point cloud in X, in range [0, size_X] + lengths_Y [bs, T]: Length of each point cloud in Y, in range [0, size_Y] + norm (int): Which norm to use, either 1 for L1-norm or 2 for L2-norm + K (int): Number of nearest neighbors to return + version (int): Which KNN implementation to use in the backend + return_nn (bool): If True, returns K nearest neighbors in p2 for each point + return_sorted (bool0: If True, return nearest neighbors sorted in + ascending order of distance + + Returns: + dists [bs, T*size_X, K]: Squared distances to nearest neighbors + idx [bs, T*size_X, K]: Indices of K nearest neighbors from X to Y. + If `X_idx[n, t, i, k] = j` then `Y[n, j]` is the k-th nearest + neighbor to `X_idx[n, t, i]` in `Y[n]`. + nn [bs, T*size_X, K, dim]: Coords of the K-nearest neighbors from X to Y. 
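+
+    Shape example (illustrative; assuming bs=2, T=4, size_X=500, size_Y=600, dim=3, K=1):
+        knn = timeseries_knn_points(X, Y, lengths_X, lengths_Y, K=1, return_nn=True)
+        # knn.dists: (2, 2000, 1), knn.idx: (2, 2000, 1), knn.knn: (2, 2000, 1, 3)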
+ """ + if ( + (X.shape[3] != Y.shape[3]) + or (X.shape[1] != Y.shape[1]) + or (X.shape[0] != Y.shape[0]) + ): + raise ValueError( + "X and Y should have same number of batch, time, and data dimensions" + ) + bs, T, size_X, dim = X.shape + bs, T, size_Y, dim = Y.shape + + # Call knn_points, treating time as a batch dimension + dists, idx, nn = pytorch3d.ops.knn_points( + X.reshape(bs * T, size_X, dim), + Y.reshape(bs * T, size_Y, dim), + lengths1=lengths_X.reshape(bs * T), + lengths2=lengths_Y.reshape(bs * T), + K=K, + version=version, + return_nn=return_nn, + return_sorted=return_sorted, + ) # bs*T, size_X, K | bs*T, size_X, K | bs*T, size_X, K, dim + + # Reshape into batched time-series of points, and offset points along T-dimension + dists = dists.reshape(bs, T * size_X, K) # bs, T*size_X, K + nn = ( + nn.reshape(bs, T * size_X, K, dim) if return_nn else None + ) # bs, T*size_X, K, dim + + idx = idx.reshape(bs, T, size_X, K) # bs, T, size_X, K + offsets = torch.cumsum(lengths_Y, dim=-1) - lengths_Y # bs, T + idx += offsets[:, :, None, None].repeat(1, 1, size_X, K) # bs, T, size_X, K + idx = idx.reshape(bs, T * size_X, K) # bs, T*size_X, K + + return _KNN(dists=dists, idx=idx, knn=nn) diff --git a/projects/ppr/eval/third_party/ChamferDistancePytorch b/projects/ppr/eval/third_party/ChamferDistancePytorch new file mode 160000 index 0000000..364c03c --- /dev/null +++ b/projects/ppr/eval/third_party/ChamferDistancePytorch @@ -0,0 +1 @@ +Subproject commit 364c03c4ec5febc1e21068ffac362eca4a8f61d9 diff --git a/projects/ppr/export.py b/projects/ppr/export.py new file mode 100644 index 0000000..12ddaf1 --- /dev/null +++ b/projects/ppr/export.py @@ -0,0 +1,23 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +""" +python projects/ppr/export.py --flagfile=logdir/cat-85-sub-sub-bob-pika-cate-b02/opts.log --load_suffix latest --inst_id 0 +""" + +import os, sys +from absl import app + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) + +from lab4d.export import export, get_config +import config + + +def main(_): + opts = get_config() + export(opts) + + +if __name__ == "__main__": + app.run(main) diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys new file mode 160000 index 0000000..42a52e2 --- /dev/null +++ b/projects/ppr/ppr-diffphys @@ -0,0 +1 @@ +Subproject commit 42a52e206d1dce25fee40b946ce0b1c38eb7327b diff --git a/projects/ppr/render.py b/projects/ppr/render.py new file mode 100644 index 0000000..b357691 --- /dev/null +++ b/projects/ppr/render.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +# python scripts/render.py --seqname --flagfile=logdir/cat-0t10-fg-bob-d0-long/opts.log --load_suffix latest + +import os, sys +from absl import app + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) + +from lab4d.render import render, get_config, construct_batch_from_opts +import config + + +def main(_): + opts = get_config() + render(opts, construct_batch_func=construct_batch_from_opts) + + +if __name__ == "__main__": + app.run(main) diff --git a/projects/ppr/render_intermediate.py b/projects/ppr/render_intermediate.py new file mode 100644 index 0000000..d5a9ce6 --- /dev/null +++ b/projects/ppr/render_intermediate.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. 
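+# Renders the intermediate meshes dumped during PPR training (e.g., the simulated
+# trajectory when --data_class sim) from two fixed viewpoints; the run_*.sh
+# scripts below invoke it as:
+#   python projects/ppr/render_intermediate.py --testdir logdir/$seqname-ppr/ --data_class sim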
+# python lab4d/render_intermediate.py --testdir logdir/human-48-category-comp/ +import sys, os +import pdb + +import glob +import numpy as np +import cv2 +import argparse +import trimesh +import tqdm + + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) +from lab4d.utils.io import save_vid +from lab4d.utils.pyrender_wrapper import PyRenderWrapper + +parser = argparse.ArgumentParser(description="script to render cameras over epochs") +parser.add_argument("--testdir", default="", help="path to test dir") +parser.add_argument( + "--data_class", default="distilled", type=str, help="which data to render, {fg, bg}" +) +args = parser.parse_args() + + +def main(): + renderer = PyRenderWrapper() + # io + path_list = [i for i in glob.glob("%s/%s_*.obj" % (args.testdir, args.data_class))] + if len(path_list) == 0: + print("no mesh found in %s for %s" % (args.testdir, args.data_class)) + return + path_list = sorted( + path_list, key=lambda x: int(x.split("/")[-1].split("-")[-1][:-4]) + ) + outdir = "%s/renderings_trajs" % args.testdir + os.makedirs(outdir, exist_ok=True) + + mesh_dict = {} + aabb_min = np.asarray([np.inf, np.inf, np.inf]) + aabb_max = np.asarray([-np.inf, -np.inf, -np.inf]) + for mesh_path in path_list: + batch_idx = int(mesh_path.split("/")[-1].split("-")[-1][:-4]) + mesh_obj = trimesh.load(mesh_path) + mesh_dict[batch_idx] = mesh_obj + + # update aabb + aabb_min = np.minimum(aabb_min, mesh_obj.bounds[0]) + aabb_max = np.maximum(aabb_max, mesh_obj.bounds[1]) + + # set camera translation + # renderer.set_camera_bev(depth=max(aabb_max - aabb_min) * 1.2, gl=True) + aabb_range = max(aabb_max - aabb_min) + scene_to_cam = np.eye(4) + rot = cv2.Rodrigues(np.asarray([-np.pi * 8 / 9, 0, 0]))[0] + scene_to_cam[:3, :3] = rot + scene_to_cam[2, 3] = aabb_range + renderer.set_camera(scene_to_cam) + renderer.set_light_topdown(gl=True) + + # render + frames = [] + for batch_idx, mesh_obj in tqdm.tqdm(mesh_dict.items()): + # percentage = batch_idx / list(mesh_dict.keys())[-1] + # scene_to_cam[0, 3] = aabb_range * (percentage - 0.5) * 0.5 + renderer.set_camera(scene_to_cam) + input_dict = {"shape": mesh_obj} + color = renderer.render(input_dict)[0] + + # render another view + scene_to_cam_vp2 = scene_to_cam.copy() + # rotate 90 degrees along y axis + rot = cv2.Rodrigues(np.asarray([0, np.pi / 2, 0]))[0] + scene_to_cam_vp2[:3, :3] = scene_to_cam_vp2[:3, :3] @ rot + renderer.set_camera(scene_to_cam_vp2) + color_vp2 = renderer.render(input_dict)[0] + color = np.concatenate([color, color_vp2], axis=1) + + # add text + color = color.astype(np.uint8) + color = cv2.putText( + color, + "iteration: %04d" % batch_idx, + (30, 50), + cv2.FONT_HERSHEY_SIMPLEX, + 2, + (256, 0, 0), + 2, + ) + + frames.append(color) + + save_path = "%s/%s" % (outdir, args.data_class) + vid_secs = 5 # 5s + fps = len(frames) / vid_secs + save_vid(save_path, frames, suffix=".mp4", upsample_frame=-1, fps=fps) + print("saved to %s.mp4" % (save_path)) + + +if __name__ == "__main__": + main() diff --git a/projects/ppr/run_ama.sh b/projects/ppr/run_ama.sh new file mode 100644 index 0000000..cec0dc4 --- /dev/null +++ b/projects/ppr/run_ama.sh @@ -0,0 +1,60 @@ +# Description: Run experiments for PPR + +# ******ama-samba +seqname=ama-samba-4v + +# download pre-processed data +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/yikd46stoxe8p3m5tvpe1/ama-samba-4v.zip?rlkey=mc78xpctmis3cw6j0gzk84r2f&dl=0" + +# alternatively, run this if you want to process raw video +python scripts/run_preprocess.py $seqname human 
"0,1"; + +# scene reconstruction +# rm -rf logdir/$seqname-bg +bash scripts/train.sh lab4d/train.py 0 --seqname $seqname --logname bg \ + --field_type bg --data_prefix full --num_rounds 20 --alter_flow --mask_wt 0.01 --normal_wt 1e-2 --reg_eikonal_wt 0.01 --nosingle_scene --freeze_intrinsics + +# foreground reconstruction +# rm -rf logdir/$seqname-fg-urdf +bash scripts/train.sh lab4d/train.py 0 --seqname $seqname --logname fg-urdf --fg_motion urdf-human --num_rounds 20 --feature_type cse --freeze_intrinsics + +# physical reconstruction +# rm -rf logdir/$seqname-ppr +bash scripts/train.sh projects/ppr/train.py 0 --seqname $seqname --logname ppr --field_type comp --fg_motion urdf-human --feature_type cse --nosingle_scene \ + --num_rounds 20 --iters_per_round 100 --ratio_phys_cycle 0.5 --phys_vis_interval 20 --frame_interval 0.0333 --secs_per_wdw 2.0 --warmup_iters 100 \ + --pixels_per_image 12 --noreset_steps --learning_rate 1e-4 --noabsorb_base \ + --load_path logdir/$seqname-fg-urdf/ckpt_latest.pth \ + --load_path_bg logdir/$seqname-bg/ckpt_latest.pth + +# export meshes and visualize results, run +python projects/ppr/render_intermediate.py --testdir logdir/$seqname-ppr/ --data_class sim +python projects/ppr/export.py --flagfile=logdir/$seqname-ppr/opts.log --load_suffix latest --inst_id 0 --vis_thresh 0 --extend_aabb +python lab4d/render_mesh.py --testdir logdir/$seqname-ppr/export_0000/ --view bev --ghosting + +# *****ama-bouncing +seqname=ama-bouncing-4v + +# download pre-processed data +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/hld0yyofjl5gb3hbdnra2/ama-bouncing-4v.zip?rlkey=uzoluxprkm33sryt49726wlee&dl=0" + +# alternatively, run this if you want to process raw video +python scripts/run_preprocess.py $seqname human "0,1"; + +# rm -rf logdir/$seqname-bg +bash scripts/train.sh lab4d/train.py 0 --seqname $seqname --logname bg \ + --field_type bg --data_prefix full --num_rounds 20 --alter_flow --mask_wt 0.01 --normal_wt 1e-2 --reg_eikonal_wt 0.01 --nosingle_scene --freeze_intrinsics + +# rm -rf logdir/$seqname-fg-urdf +bash scripts/train.sh lab4d/train.py 0 --seqname $seqname --logname fg-urdf --fg_motion urdf-human --num_rounds 20 --feature_type cse --freeze_intrinsics + +# rm -rf logdir/$seqname-ppr +bash scripts/train.sh projects/ppr/train.py 0 --seqname $seqname --logname ppr --field_type comp --fg_motion urdf-human --feature_type cse --nosingle_scene \ + --num_rounds 20 --iters_per_round 100 --frame_interval 0.0333 --secs_per_wdw 1.0 --warmup_iters 100 \ + --pixels_per_image 12 --noreset_steps --learning_rate 1e-4 --noabsorb_base \ + --load_path logdir/$seqname-fg-urdf/ckpt_latest.pth \ + --load_path_bg logdir/$seqname-bg/ckpt_latest.pth + +# export meshes and visualize results, run +python projects/ppr/render_intermediate.py --testdir logdir/$seqname-ppr/ --data_class sim +python projects/ppr/export.py --flagfile=logdir/$seqname-ppr/opts.log --load_suffix latest --inst_id 0 --vis_thresh 0 --extend_aabb +python lab4d/render_mesh.py --testdir logdir/$seqname-ppr/export_0000/ --view bev --ghosting \ No newline at end of file diff --git a/projects/ppr/run_cat.sh b/projects/ppr/run_cat.sh new file mode 100644 index 0000000..d2fb4c5 --- /dev/null +++ b/projects/ppr/run_cat.sh @@ -0,0 +1,19 @@ +## download data +#bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/j2bztcv49mqc4a6ngj8jp/cat-pikachu-0.zip?rlkey=g2sw8te19bsirr5srdl17kzt1&dl=0" + +# reconstruct the scene +# Args: gpu-id, sequence name, hyper-parameters defined in lab4d/config.py +bash 
scripts/train.sh lab4d/train.py 0 --seqname cat-pikachu-0 --logname bg --field_type bg --data_prefix full --num_rounds 60 --alter_flow --mask_wt 0.01 --normal_wt 1e-2 --reg_eikonal_wt 0.01 + +# reconstruct the object: +# Args: gpu-id, sequence name, hyper-parameters defined in lab4d/config.py +bash scripts/train.sh lab4d/train.py 0 --seqname cat-pikachu-0 --logname fg-urdf --fg_motion urdf-quad --num_rounds 20 --feature_type cse + +# physics-informed optimization +# Args: gpu-id sequence name, hyper-parameters in both lab4d/config.py and and projects/ppr/config.py +bash scripts/train.sh projects/ppr/train.py 0 --seqname cat-pikachu-0 --logname ppr --field_type comp --fg_motion urdf-quad --feature_type cse --num_rounds 20 --learning_rate 1e-4 --pixels_per_image 12 --iters_per_round 100 --secs_per_wdw 2.4 --noreset_steps --noabsorb_base --load_path logdir/cat-pikachu-0-fg-urdf/ckpt_latest.pth --load_path_bg logdir/cat-pikachu-0-bg/ckpt_latest.pth + +# visualization +python projects/ppr/render_intermediate.py --testdir logdir/cat-pikachu-0-ppr/ --data_class sim +python projects/ppr/export.py --flagfile=logdir/cat-pikachu-0-ppr/opts.log --load_suffix latest --inst_id 0 --vis_thresh -10 --extend_aabb +python lab4d/render_mesh.py --testdir logdir/cat-pikachu-0-ppr/export_0000/ --view bev --ghosting diff --git a/projects/ppr/simulate.py b/projects/ppr/simulate.py new file mode 100644 index 0000000..eaef38f --- /dev/null +++ b/projects/ppr/simulate.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +""" +python projects/ppr/simulate.py --flagfile=logdir/cat-pikachu-0-ppr/opts.log --load_suffix latest --load_suffix_phys latest --inst_id 0 +""" + +import os, sys +import numpy as np +import cv2 +import pdb +import json +from absl import app, flags + + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) + +from lab4d.config import get_config +from lab4d.utils.io import make_save_dir, save_rendered + +sys.path.insert(0, "%s/../" % os.path.join(os.path.dirname(__file__))) +from trainer import PPRTrainer +from trainer import PhysVisualizer + + +class ExportMeshFlags: + flags.DEFINE_integer("inst_id", 0, "video/instance id") + + +def export_simulate_mesh(save_dir, data, tag): + traj = data[tag] + camera = data["camera"] + save_dir = os.path.join(save_dir, tag) + save_dir_fg = os.path.join(save_dir, "fg/mesh") + save_dir_bg = os.path.join(save_dir, "bg/mesh") + os.makedirs(save_dir_fg, exist_ok=True) + os.makedirs(save_dir_bg, exist_ok=True) + + # save fg + bg_motion = {"field2cam": []} + fg_motion = {"field2cam": []} + for frame_idx in range(len(traj)): + # mesh + mesh = traj[frame_idx] + mesh.export(os.path.join(save_dir_fg, "%05d.obj" % frame_idx)) + + # camera pose + camera_pose = np.eye(4) + camera_pose[:3] = camera[frame_idx][:3] + bg_motion["field2cam"].append(camera_pose.tolist()) + fg_motion["field2cam"].append(camera_pose.tolist()) + + # save bg mesh + data["floor"].export(os.path.join(save_dir_bg, "%05d.obj" % frame_idx)) + + # save to json + with open(os.path.join(save_dir, "bg/motion.json"), "w") as file: + json.dump(bg_motion, file) + + with open(os.path.join(save_dir, "fg/motion.json"), "w") as file: + json.dump(fg_motion, file) + + # save field2world + with open(os.path.join(save_dir, "bg/field2world.json"), "w") as file: + field2world = np.eye(4) + field2world[:3, :3] = cv2.Rodrigues(np.asarray([np.pi, 0, 0]))[0] + json.dump(field2world.tolist(), file) + + # save intrinsics + with open(os.path.join(save_dir, "camera.json"), "w") as file: 
+ camera_info = {} + camera_info["raw_size"] = [] + camera_info["intrinsics"] = camera[:, 3].tolist() + json.dump(camera_info, open("%s/camera.json" % (save_dir), "w")) + + +def simulate(opts): + opts["urdf_template"] = opts["fg_motion"].split("-")[1].split("_")[0] + ( + model, + data_info, + ref_dict, + phys_model, + ) = PPRTrainer.construct_test_model(opts) + + save_dir = make_save_dir(opts, sub_dir="simulate_%04d" % (opts["inst_id"])) + phys_visualizer = PhysVisualizer(save_dir) + + # reset scale to avoid initial penetration + data = PPRTrainer.simulate(phys_model, data_info, opts["inst_id"]) + fps = 1.0 / phys_model.frame_interval + phys_visualizer.show("simulated_ref", data, fps=fps, view_mode="ref") + phys_visualizer.show("simulated_bev", data, fps=fps, view_mode="bev") + phys_visualizer.show("simulated_front", data, fps=fps, view_mode="front") + + data["floor"] = phys_visualizer.floor + export_simulate_mesh(save_dir, data, tag="sim_traj") + + print("Results saved to %s" % (save_dir)) + return + + +def main(_): + opts = get_config() + simulate(opts) + + +if __name__ == "__main__": + app.run(main) diff --git a/projects/ppr/train.py b/projects/ppr/train.py new file mode 100644 index 0000000..6c48182 --- /dev/null +++ b/projects/ppr/train.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +import os +import sys +import pdb +from absl import app + +sys.path.insert(0, "%s/../../" % os.path.join(os.path.dirname(__file__))) +from lab4d.train import train_ddp + +sys.path.insert(0, "%s/../" % os.path.join(os.path.dirname(__file__))) +from ppr.trainer import PPRTrainer + + +def main(_): + train_ddp(PPRTrainer) + + +if __name__ == "__main__": + app.run(main) diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py new file mode 100644 index 0000000..ba18a0a --- /dev/null +++ b/projects/ppr/trainer.py @@ -0,0 +1,351 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. 
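+# PPRTrainer alternates standard lab4d rendering rounds with physics-optimization
+# cycles (ppr-diffphys); dvr_phys_reg extends dvr_model with soft losses that pull
+# the kinematic root pose / joint angles toward the physics-simulated trajectory.
+# See projects/ppr/run_cat.sh and run_ama.sh for end-to-end training commands.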
+import os, sys +import pdb +import torch +import numpy as np +import tqdm +import gc + +from lab4d.engine.trainer import Trainer +from lab4d.engine.trainer import get_local_rank +from lab4d.engine.model import dvr_model +from ppr import config + +sys.path.insert(0, "%s/ppr-diffphys" % os.path.join(os.path.dirname(__file__))) +from diffphys.dp_interface import phys_interface, query_q +from diffphys.vis import PhysVisualizer +from diffphys.dp_utils import se3_loss + + +class dvr_phys_reg(dvr_model): + """A model that contains a collection of static/deformable neural fields + + Args: + config (Dict): Command-line args + data_info (Dict): Dataset metadata from get_data_info() + """ + + @torch.no_grad() + def copy_phys_traj(self, phys_model): + phys_traj = {} + phys_traj["steps_fr"] = torch.arange( + phys_model.total_frames, device=self.device + ) + # phys_traj["phys_q"] = phys_model.root_pose_mlp(phys_traj["steps_fr"]) + # phys_traj["phys_ja"] = phys_model.joint_angle_mlp(phys_traj["steps_fr"]) + # N, 7/dof + phys_traj["phys_q"] = phys_model.root_pose_distilled(phys_traj["steps_fr"]) + phys_traj["phys_ja"] = phys_model.joint_angle_distilled(phys_traj["steps_fr"]) + self.phys_traj = phys_traj + + def forward(self, batch): + loss_dict = super().forward(batch) + reg_phys_q, reg_phys_ja = self.compute_kinemaics_phys_diff() + loss_dict["phys_q_reg"] = self.config["reg_phys_q_wt"] * reg_phys_q + loss_dict["phys_ja_reg"] = self.config["reg_phys_ja_wt"] * reg_phys_ja + return loss_dict + + def compute_kinemaics_phys_diff(self): + """ + compute the difference between the target kinematics and kinematics estimated by physics proxy + """ + if not hasattr(self, "phys_traj"): + return ( + torch.zeros(1).to(self.device).mean(), + torch.zeros(1).to(self.device).mean(), + ) + steps_fr = self.phys_traj["steps_fr"] + phys_q = self.phys_traj["phys_q"] + phys_ja = self.phys_traj["phys_ja"] + + object_field = self.fields.field_params["fg"] + scene_field = self.fields.field_params["bg"] + kinematics_q, _ = query_q(steps_fr, object_field, scene_field) + kinematics_ja = object_field.warp.articulation.get_vals( + steps_fr, return_so3=True + ) + + loss_q = se3_loss(phys_q, kinematics_q).mean() + loss_ja = (phys_ja - kinematics_ja).pow(2).mean() + # print("loss_q:", loss_q) + # print("loss_ja:", loss_ja) + return loss_q, loss_ja + + +class PPRTrainer(Trainer): + def __init__(self, opts): + """Train and evaluate a Lab4D model. 
+ + Args: + opts (Dict): Command-line args from absl (defined in lab4d/config.py) + """ + opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] + opts["urdf_template"] = opts["fg_motion"].split("-")[1].split("_")[0] + + super().__init__(opts) + self.model.fields.field_params["bg"].compute_field2world() + for vidid in opts["phys_vid"]: + mesh = self.model.fields.field_params["bg"].visualize_floor_mesh( + vidid, to_world=True + ) + mesh.export("%s/floor_%02d.obj" % (self.save_dir, vidid)) + + # after loading the ckeckpoints + self.floor_fitting() + self.init_phys_coupling() + + def floor_fitting(self): + """ + fit floor to the background reconstruction + """ + self.model.fields.field_params["bg"].compute_field2world() + if get_local_rank() == 0: + for vidid in self.opts["phys_vid"]: + mesh = self.model.fields.field_params["bg"].visualize_floor_mesh( + vidid, to_world=True + ) + mesh.export("%s/floor_%02d.obj" % (self.save_dir, vidid)) + + def init_phys_coupling(self): + """ + initialize scale lowest point fitting + """ + # initialize control input of phys model to kinematics + self.phys_model.override_control_ref_states() + self.phys_model.override_distilled_states() + + # reset scale to avoid initial penetration + frame_offset_raw = self.phys_model.frame_offset_raw + vid_frames = [] + for vidid in self.opts["phys_vid"]: + vid_frame = range(frame_offset_raw[vidid], frame_offset_raw[vidid + 1]) + vid_frames += vid_frame + self.phys_model.correct_scale(vid_frames) + if get_local_rank() == 0: + self.run_phys_visualization(tag="kinematics") + + def trainer_init(self): + super().trainer_init() + + opts = self.opts + self.current_steps_phys = 0 # 0-total_steps + self.current_round_phys = 0 # 0-total_rounds + self.iters_per_phys_cycle = int( + opts["ratio_phys_cycle"] * opts["iters_per_round"] + ) + print("# iterations per phys cycle:", self.iters_per_phys_cycle) + + def init_model(self): + """Initialize camera transforms, geometry, articulations, and camera + intrinsics from external priors, if this is the first run""" + if self.opts["load_path"] == "": + super().init_model() + return + + def define_model(self, model=dvr_phys_reg): + super().define_model(model=model) + self.phys_model = self.define_phys_standalone( + self.model, self.opts, self.data_info + ) + self.phys_visualizer = PhysVisualizer(self.save_dir) + + # move model to device + self.device = torch.device("cuda:{}".format(get_local_rank())) + self.phys_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.phys_model) + self.phys_model = self.phys_model.to(self.device) + + @staticmethod + def define_phys_standalone(model, opts, data_info): + """Define a standalon phys model""" + model_dict = {} + model_dict["scene_field"] = model.fields.field_params["bg"] + model_dict["object_field"] = model.fields.field_params["fg"] + model_dict["intrinsics"] = model.intrinsics + model_dict["frame_interval"] = opts["frame_interval"] + model_dict["frame_info"] = data_info["frame_info"] + + # define phys model + phys_model = phys_interface(opts, model_dict, dt=opts["timestep"]) + return phys_model + + def load_checkpoint_train(self): + """Load a checkpoint at training time and update the current step count + and round count + """ + if self.opts["load_path_bg"] != "": + # load background and intrinsics model + checkpoint = torch.load(self.opts["load_path_bg"]) + model_states = checkpoint["model"] + self.model.load_state_dict(model_states, strict=False) + super().load_checkpoint_train() + + # # reset beta + # beta = 
torch.tensor([0.01]).to(self.device) + # self.model.fields.field_params["fg"].logibeta.data = -beta.log() + # self.model.fields.field_params["bg"].logibeta.data = -beta.log() + + self.model.fields.reset_geometry_aux() + + def get_lr_dict(self, pose_correction=False): + """Return the learning rate for each category of trainable parameters + + Returns: + param_lr_startwith (Dict(str, float)): Learning rate for base model + param_lr_with (Dict(str, float)): Learning rate for explicit params + """ + # define a dict for (tensor_name, learning) pair + param_lr_startwith, param_lr_with = super().get_lr_dict( + pose_correction=pose_correction + ) + opts = self.opts + + param_lr_with.update( + { + "module.fields.field_params.fg.basefield.": 0.0, + # "module.fields.field_params.fg.colorfield.": 0.0, + "module.fields.field_params.fg.sdf.": 0.0, + # "module.fields.field_params.fg.rgb.": 0.0, + "module.fields.field_params.fg.vis_mlp.": 0.0, + "module.fields.field_params.bg.basefield.": 0.0, + # "module.fields.field_params.bg.colorfield.": 0.0, + "module.fields.field_params.bg.sdf.": 0.0, + # "module.fields.field_params.bg.rgb.": 0.0, + "module.fields.field_params.bg.vis_mlp.": 0.0, + "module.fields.field_params.fg.warp.articulation.logscale": 0.0, + "module.fields.field_params.fg.warp.articulation.log_bone_len": 0.0, + "module.fields.field_params.bg.camera_mlp.": 0.0, + } + ) + del param_lr_with[".logscale"] # do not update scale of the urdf + return param_lr_startwith, param_lr_with + + def run_one_round(self): + # run dr cycle + super().run_one_round() + if self.opts["ratio_phys_cycle"] > 0: + self.run_one_round_phys() + + def run_one_round_phys(self): + # determine wdw size + secs_per_wdw = self.opts["secs_per_wdw"] + # # schedule: 0-end, 0.2-2s + # progress = self.current_steps_phys / self.phys_model.total_iters + # secs_per_wdw = (1 - progress) * 0.5 + progress * 2 + + # warmup phys + if self.current_round_phys == 0 and self.opts["warmup_iters"] > 0: + self.run_phys_cycle(0.2, num_iters=self.opts["warmup_iters"]) + + # run physics cycle + if self.phys_model.copy_weights: + # transfer dvr kinematics to phys + self.phys_model.override_distilled_states() + self.run_phys_cycle(secs_per_wdw) + # transfer phys-optimized kinematics to dvr + self.phys_model.override_states_inv() + else: + self.run_phys_cycle(secs_per_wdw) + # transfer phys-optimized kinematics to dvr as soft constriaints + self.model.copy_phys_traj(self.phys_model) + self.current_round_phys += 1 + + def init_phys_env_train(self, ses_per_wdw): + opts = self.opts + # to use the same amount memory as DR + total_timesteps = ses_per_wdw / opts["timestep"] + num_envs = int(96000 / total_timesteps) + frames_per_wdw = int(ses_per_wdw / self.phys_model.frame_interval) + 1 + overwrite = self.opts["warmup_iters"] > 0 + print("num_envs:", num_envs) + print("frames_per_wdw:", frames_per_wdw) + self.phys_model.train() + self.phys_model.reinit_envs( + num_envs, + frames_per_wdw=frames_per_wdw, + is_eval=False, + overwrite=overwrite, + ) + + def run_phys_cycle(self, secs_per_wdw, num_iters=0): + opts = self.opts + gc.collect() # need to be used together with empty_cache() + torch.cuda.empty_cache() + + self.init_phys_env_train(secs_per_wdw) + num_iters = self.iters_per_phys_cycle if num_iters == 0 else num_iters + for i in tqdm.tqdm(range(num_iters)): + self.phys_model.set_progress(self.current_steps_phys) + self.run_phys_iter() + self.current_steps_phys += 1 + if self.current_steps_phys % opts["phys_vis_interval"] == 0: + # eval + 
self.phys_model.save_checkpoint(self.current_steps_phys) + self.run_phys_visualization(tag="phys") + self.init_phys_env_train(secs_per_wdw) + + def run_phys_iter(self): + """Run physics optimization""" + phys_aux = self.phys_model() + self.phys_model.backward(phys_aux["total_loss"]) + grad_dict = self.phys_model.update() + phys_aux.update(grad_dict) + if get_local_rank() == 0: + del phys_aux["total_loss"] + self.add_scalar(self.log, phys_aux, self.current_steps_phys) + + @torch.no_grad() + def run_phys_visualization(self, tag=""): + self.phys_model.eval() + opts = self.opts + frame_offset_raw = self.phys_model.frame_offset_raw + for vidid in opts["phys_vid"]: + data = self.simulate(self.phys_model, self.data_info, vidid) + self.phys_visualizer.show( + "%s-%02d-%05d" % (tag, vidid, self.current_steps_phys), + data, + fps=1.0 / self.phys_model.frame_interval, + view_mode="front", + ) + + @staticmethod + @torch.no_grad() + def simulate(phys_model, data_info, vidid): + """ + run phys simulation for a video in eval mode + """ + device = phys_model.device + frame_offset_raw = phys_model.frame_offset_raw + num_frames = frame_offset_raw[vidid + 1] - frame_offset_raw[vidid] + phys_model.reinit_envs(1, frames_per_wdw=num_frames, is_eval=True) + frame_start = torch.zeros(1) + frame_offset_raw[vidid] + _ = phys_model(frame_start=frame_start.to(device)) + img_size = tuple(data_info["raw_size"][vidid]) + img_size = img_size + (0.5,) # scale + data = phys_model.query(img_size=img_size) + return data + + @staticmethod + def construct_test_model(opts): + """Load a model at test time + + Args: + opts (Dict): Command-line options + """ + # io + logname = "%s-%s" % (opts["seqname"], opts["logname"]) + + # construct dvr model + model, data_info, ref_dict = Trainer.construct_test_model(opts) + + # construct phys model + phys_model = PPRTrainer.define_phys_standalone(model, opts, data_info) + load_path = "%s/%s/ckpt_phys_%s.pth" % ( + opts["logroot"], + logname, + opts["load_suffix_phys"], + ) + phys_model.load_checkpoint(load_path) + phys_model.cuda() + phys_model.eval() + + return model, data_info, ref_dict, phys_model diff --git a/projects/ppr/viewer.py b/projects/ppr/viewer.py new file mode 100644 index 0000000..b597e39 --- /dev/null +++ b/projects/ppr/viewer.py @@ -0,0 +1,23 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. 
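+# Interactive viewer entry point; the invocation presumably mirrors export.py, e.g.
+#   python projects/ppr/viewer.py --flagfile=logdir/cat-pikachu-0-ppr/opts.log --load_suffix latest --inst_id 0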
+""" +python projects/ppr/export.py --flagfile=logdir/cat-85-sub-sub-bob-pika-cate-b02/opts.log --load_suffix latest --inst_id 0 +""" + +import os, sys +from absl import app + +cwd = os.getcwd() +if cwd not in sys.path: + sys.path.insert(0, cwd) + +from viewer.viewer import run_viewer, get_config +import config + + +def main(_): + opts = get_config() + run_viewer(opts) + + +if __name__ == "__main__": + app.run(main) diff --git a/scripts/download_all_data.sh b/scripts/download_all_data.sh new file mode 100644 index 0000000..65664d4 --- /dev/null +++ b/scripts/download_all_data.sh @@ -0,0 +1,13 @@ +# download all test data +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/5wfbc692qhpejhyo8u9r0/car-turnaround-2.zip?rlkey=riq060i3wm5raynxryf8g2hcw&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/6w0qgeuc6gh02ix1o1tck/finch.zip?rlkey=jkz09o6ipw0yb78s7qnt9l1n6&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/s/mb7zgk73oomix4s/cat-pikachu-0.zip?dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/mitk5e36hz4anmbksmgki/squirrel.zip?rlkey=xwgee3bc5t0e9lyu8r9oz3oag&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/8yc8vuaimpzctiiszdbku/dog-robolounge.zip?rlkey=ky21wq5ah0na4xutqks6lwzvy&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/s/3w0vhh05olzwwn4/cat-pikachu.zip?dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/5ok2s27p1d1q6wg47ljiu/shiba-haru.zip?rlkey=qqmk353oysw1q05l6xepjw01m&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/yse6ohs6cinot228fup9p/human-cap.zip?rlkey=zwf5t8pefcp0ndebphlyngt9t&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/lgljmx9ckmfif7ovajv6q/penguin.zip?rlkey=rmakdtigf06mqdbu0omr0w569&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/c6lrg2aaabat4gu57avbq/human-48.zip?rlkey=ezpc3k13qgm1yqzm4v897whcj&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/xfaot22qbzz0o0ncl5bna/cat-85.zip?rlkey=wcer6lf0u4en7tjzaonj5v96q&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/h2m7f3jqzm4a2u3lpxhki/dog-98.zip?rlkey=x4fy74mbk7qrhc5ovmt4lwpkg&dl=0" \ No newline at end of file diff --git a/scripts/download_all_log.sh b/scripts/download_all_log.sh new file mode 100644 index 0000000..8c31e55 --- /dev/null +++ b/scripts/download_all_log.sh @@ -0,0 +1,13 @@ +# download all checkpoints +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/3g03jso6803ck4irg4ha2/log-car-turnaround-2-fg-rigid-b120.zip?rlkey=9ear4wux3noato7lhkfdclw3a&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/el4mlo3x0o50ktcgmsvhl/log-cat-pikachu-0-fg-skel-b120.zip?rlkey=lwc9gis8whn3gyfo3a0ct86uv&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/iow542jki6krk25oqxrpq/log-cat-pikachu-0-comp-comp-s2.zip?rlkey=iiuh40c19qc4kcdbm9t002ujn&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/hr526prumgkicpcabo7bd/log-squirrel-fg-comp-b120.zip?rlkey=ndkc918ww45e03wgfzb2tqsde&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/xcu57yshzahbrs6u17wht/log-dog-robolounge-fg-comp-b120.zip?rlkey=7cloqjq97rv4e81w2414dlwsn&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/f9i7hdni7tldwx96owstj/log-cat-pikachu-fg-bob-b120.zip?rlkey=00ipeg8w6se7baf1njf00qa8g&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/5t5p070obyszffifb5xsc/log-shiba-haru-fg-comp-b120.zip?rlkey=pt8dqh4oft52gdp7usu0prv4j&dl=0" +bash 
scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/hcnbvbmp9kegpmb4xv8x4/log-human-cap-fg-comp-b120.zip?rlkey=qd7p0u9mirwb9t6zxgd9tqh22&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/iykw85xdx502l8a53cflb/log-penguin-fg-skel-b120.zip?rlkey=to9zt5x4uocj2xj5yd0gazzx8&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/8px220byvcv8912x2q3mu/log-human-48-category-comp.zip?rlkey=7z4me9mzmwto9nh34ihuojosh&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/rcm2jur101issowcpdihq/log-cat-85-category-comp.zip?rlkey=w1b317frn7ct1oa81bipmmt18&dl=0" +bash scripts/download_unzip.sh "https://www.dropbox.com/scl/fi/5zkottt2xug6e0dhd3t15/log-dog-98-category-comp.zip?rlkey=vg6qarpmb9fdi3i1wwcz7hpdu&dl=0" \ No newline at end of file diff --git a/scripts/install-deps.sh b/scripts/install-deps.sh index 315e3cc..47c8aba 100644 --- a/scripts/install-deps.sh +++ b/scripts/install-deps.sh @@ -9,3 +9,5 @@ wget https://www.dropbox.com/s/bgsodsnnbxdoza3/vcn_rob.pth -O ./preprocess/third wget https://www.dropbox.com/s/51cjzo8zgz966t5/human.pth -O preprocess/third_party/viewpoint/human.pth wget https://www.dropbox.com/s/1464pg6c9ce8rve/quad.pth -O preprocess/third_party/viewpoint/quad.pth + +wget "https://www.dropbox.com/scl/fi/hbemdtw4fuzlrfgz36ni6/omnidata_dpt_normal_v2_cleaned.ckpt?rlkey=duhix8g259wtcdfyxlf0jyvq1" -O preprocess/third_party/omnivision/omnidata_dpt_normal_v2_cleaned.ckpt diff --git a/scripts/render_intermediate.py b/scripts/render_intermediate.py deleted file mode 100644 index 46ef90d..0000000 --- a/scripts/render_intermediate.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. -# python scripts/render_intermediate.py --testdir logdir/human-48-category-comp/ -import sys, os -import pdb - -os.environ["PYOPENGL_PLATFORM"] = "egl" # opengl seems to only work with TPU -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -import glob -import numpy as np -import cv2 -import argparse -import trimesh -import pyrender -from pyrender import IntrinsicsCamera, Mesh, Node, Scene, OffscreenRenderer -import matplotlib -import tqdm - -from lab4d.utils.io import save_vid - -cmap = matplotlib.colormaps.get_cmap("cool") - -parser = argparse.ArgumentParser(description="script to render cameras over epochs") -parser.add_argument("--testdir", default="", help="path to test dir") -parser.add_argument( - "--data_class", default="fg", type=str, help="which data to render, {fg, bg}" -) -args = parser.parse_args() - -img_size = 1024 - -# renderer -r = OffscreenRenderer(img_size, img_size) -cam = IntrinsicsCamera(img_size, img_size, img_size / 2, img_size / 2) -# light -direc_l = pyrender.DirectionalLight(color=np.ones(3), intensity=3.0) -light_pose = np.asarray( - [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=float -) -# cv to gl coords -cam_pose = -np.eye(4) -cam_pose[0, 0] = 1 -cam_pose[-1, -1] = 1 -rtmat = np.eye(4) -# object to camera transforms -rtmat[:3, :3] = cv2.Rodrigues(np.asarray([np.pi / 2, 0, 0]))[0] # bev - - -def main(): - # io - path_list = [ - i for i in glob.glob("%s/*-%s-proxy.obj" % (args.testdir, args.data_class)) - ] - if len(path_list) == 0: - print("no mesh found in %s for %s" % (args.testdir, args.data_class)) - return - path_list = sorted(path_list, key=lambda x: int(x.split("/")[-1].split("-")[0])) - outdir = "%s/renderings_proxy" % args.testdir - os.makedirs(outdir, exist_ok=True) - - mesh_dict = {} - aabb_min = np.asarray([np.inf, 
np.inf, np.inf]) - aabb_max = np.asarray([-np.inf, -np.inf, -np.inf]) - for mesh_path in path_list: - batch_idx = int(mesh_path.split("/")[-1].split("-")[0]) - mesh_obj = trimesh.load(mesh_path) - mesh_dict[batch_idx] = mesh_obj - - # update aabb - aabb_min = np.minimum(aabb_min, mesh_obj.bounds[0]) - aabb_max = np.maximum(aabb_max, mesh_obj.bounds[1]) - - # set camera translation - rtmat[2, 3] = max(aabb_max - aabb_min) * 1.2 - - # render - frames = [] - for batch_idx, mesh_obj in tqdm.tqdm(mesh_dict.items()): - scene = Scene(ambient_light=0.4 * np.asarray([1.0, 1.0, 1.0, 1.0])) - - # add object / camera - mesh_obj.apply_transform(rtmat) - scene.add_node(Node(mesh=Mesh.from_trimesh(mesh_obj))) - - # camera - scene.add(cam, pose=cam_pose) - - # light - scene.add(direc_l, pose=light_pose) - - # render - color, depth = r.render( - scene, - flags=pyrender.RenderFlags.SHADOWS_DIRECTIONAL - | pyrender.RenderFlags.SKIP_CULL_FACES, - ) - # add text - color = color.astype(np.uint8) - color = cv2.putText( - color, - "batch: %02d" % batch_idx, - (30, 50), - cv2.FONT_HERSHEY_SIMPLEX, - 2, - (256, 0, 0), - 2, - ) - frames.append(color) - - save_vid("%s/fg" % outdir, frames, suffix=".mp4", upsample_frame=-1) - print("saved to %s/fg.mp4" % outdir) - - -if __name__ == "__main__": - main() diff --git a/scripts/run_preprocess.py b/scripts/run_preprocess.py index aa5f339..07ab88d 100644 --- a/scripts/run_preprocess.py +++ b/scripts/run_preprocess.py @@ -21,6 +21,7 @@ from preprocess.scripts.write_config import write_config from preprocess.third_party.vcnplus.compute_flow import compute_flow from preprocess.third_party.vcnplus.frame_filter import frame_filter +from preprocess.third_party.omnivision.normal import extract_normal track_anything_module = importlib.import_module( "preprocess.third_party.Track-Anything.app" @@ -75,6 +76,7 @@ def run_extract_priors(seqname, outdir, obj_class_cam): # depth extract_depth(seqname) + extract_normal(seqname) # crop around object and process flow extract_crop(seqname, 256, 0) @@ -107,7 +109,7 @@ def run_extract_priors(seqname, outdir, obj_class_cam): # True: manually annotate camera for key frames use_manual_cameras = True if obj_class_cam == "other" else False # True: filter frame based on motion magnitude | False: use all frames - use_filter_frames = True + use_filter_frames = False outdir = "database/processed/" viddir = "database/raw/%s" % vidname diff --git a/scripts/run_preprocess_bg.py b/scripts/run_preprocess_bg.py new file mode 100644 index 0000000..81a7833 --- /dev/null +++ b/scripts/run_preprocess_bg.py @@ -0,0 +1,125 @@ +# python scripts/run_preprocess.py shiba-haru "0,1,2,3,4,5,6,7" +import configparser +import glob +import os +import pdb +import numpy as np +import cv2 +import struct +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from lab4d.utils.gpu_utils import gpu_map +from preprocess.libs.io import run_bash_command +from preprocess.scripts.download import download_seq +from preprocess.scripts.camera_registration import camera_registration +from preprocess.scripts.canonical_registration import canonical_registration +from preprocess.scripts.crop import extract_crop +from preprocess.scripts.depth import extract_depth +from preprocess.scripts.extract_dinov2 import extract_dinov2 +from preprocess.scripts.extract_frames import extract_frames +from preprocess.scripts.tsdf_fusion import tsdf_fusion +from preprocess.scripts.write_config import write_config +from preprocess.scripts.fake_data import 
create_fake_masks
+from preprocess.third_party.vcnplus.compute_flow import compute_flow
+from preprocess.third_party.vcnplus.frame_filter import frame_filter
+from preprocess.third_party.omnivision.normal import extract_normal
+
+
+def remove_exist_dir(seqname, outdir):
+    run_bash_command(f"rm -rf {outdir}/JPEGImages/Full-Resolution/{seqname}")
+    run_bash_command(f"rm -rf {outdir}/Cameras/Full-Resolution/{seqname}")
+    run_bash_command(f"rm -rf {outdir}/Features/Full-Resolution/{seqname}")
+    run_bash_command(f"rm -rf {outdir}/Depth/Full-Resolution/{seqname}")
+    run_bash_command(f"rm -rf {outdir}/Flow*/Full-Resolution/{seqname}")
+    run_bash_command(f"rm -rf {outdir}/Annotations/Full-Resolution/{seqname}")
+    run_bash_command(f"rm -rf {outdir}/Normal/Full-Resolution/{seqname}")
+
+
+def run_extract_frames(seqname, outdir, infile, use_filter_frames, fps):
+    # extract frames
+    imgpath = f"{outdir}/JPEGImagesRaw/Full-Resolution/{seqname}"
+    run_bash_command(f"rm -rf {imgpath}")
+    os.makedirs(imgpath, exist_ok=True)
+    extract_frames(infile, imgpath, desired_fps=fps)
+
+    # remove existing dirs for preprocessing
+    remove_exist_dir(seqname, outdir)
+
+    # filter frames without motion: frame id is the time stamp
+    if use_filter_frames:
+        frame_filter(seqname, outdir)
+    else:
+        outpath = f"{outdir}/JPEGImages/Full-Resolution/{seqname}"
+        run_bash_command(f"rm -rf {outpath}")
+        os.makedirs(outpath, exist_ok=True)
+        run_bash_command(f"cp {imgpath}/* {outpath}/")
+
+
+def run_extract_priors(seqname, outdir):
+    print("extracting priors: ", seqname)
+    # flow
+    for dframe in [1, 2, 4, 8]:
+        compute_flow(seqname, outdir, dframe)
+
+    # depth
+    extract_depth(seqname)
+    extract_normal(seqname)
+
+    # TODO create fake masks
+    create_fake_masks(seqname, outdir)
+
+    # crop around object and process flow
+    extract_crop(seqname, 256, 1)
+
+    # compute bg/fg cameras
+    camera_registration(seqname, 0)
+    tsdf_fusion(seqname, 0)
+    canonical_registration(seqname, 256, "other")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        print(f"Usage: python {sys.argv[0]} <vidname> <gpulist>")
+        print(f" Example: python {sys.argv[0]} cat-pikachu-0 '0,1,2,3,4,5,6,7'")
+        exit()
+    vidname = sys.argv[1]
+    gpulist = [int(n) for n in sys.argv[2].split(",")]
+
+    use_filter_frames = False
+    fps = 10
+
+    outdir = "database/processed/"
+    viddir = "database/raw/%s" % vidname
+    print("using gpus: ", gpulist)
+    os.makedirs("tmp", exist_ok=True)
+
+    # download the videos
+    download_seq(vidname)
+
+    # set up parallel extraction
+    frame_args = []
+    for counter, infile in enumerate(sorted(glob.glob("%s/*" % viddir))):
+        seqname = "%s-%04d" % (vidname, counter)
+        frame_args.append((seqname, outdir, infile, use_filter_frames, fps))
+
+    # extract frames and filter frames without motion: frame id is the time stamp
+    gpu_map(run_extract_frames, frame_args, gpus=gpulist)
+
+    # write config
+    write_config(vidname)
+
+    # read config
+    config = configparser.RawConfigParser()
+    config.read("database/configs/%s.config" % vidname)
+    prior_args = []
+    for vidid in range(len(config.sections()) - 1):
+        seqname = config.get("data_%d" % vidid, "img_path").strip("/").split("/")[-1]
+        prior_args.append((seqname, outdir))
+
+    # extract flow/depth/camera/etc
+    gpu_map(run_extract_priors, prior_args, gpus=gpulist)
+
+    # extract dinov2 features
+    extract_dinov2(vidname, 256, component_id=0, gpulist=gpulist)
diff --git a/scripts/run_rendering_parallel.py b/scripts/run_rendering_parallel.py
index 89bc2f1..5422078 100644
--- a/scripts/run_rendering_parallel.py
+++ 
b/scripts/run_rendering_parallel.py @@ -27,7 +27,7 @@ # render proxy over rounds logdir = flagfile.rsplit("/", 1)[0] subprocess.Popen( - f"python scripts/render_intermediate.py --testdir {logdir}/", shell=True + f"python lab4d/render_intermediate.py --testdir {logdir}/", shell=True ) # Loop over each device. diff --git a/scripts/train.sh b/scripts/train.sh index f3dc29a..5533c72 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -9,6 +9,7 @@ echo "using "$ngpu "gpus" # assign random port # https://github.com/pytorch/pytorch/issues/73320 +# torchrun \ CUDA_VISIBLE_DEVICES=$dev torchrun \ --nproc_per_node $ngpu --nnodes 1 --rdzv_backend c10d --rdzv_endpoint localhost:0 \ $main_func \ diff --git a/scripts/zip_all_log.sh b/scripts/zip_all_log.sh new file mode 100644 index 0000000..6665f45 --- /dev/null +++ b/scripts/zip_all_log.sh @@ -0,0 +1,12 @@ +python scripts/zip_logdir.py logdir/car-turnaround-2-fg-rigid-b120/ +python scripts/zip_logdir.py logdir/cat-pikachu-0-fg-skel-b120/ +python scripts/zip_logdir.py logdir/cat-pikachu-0-comp-comp-s2/ +python scripts/zip_logdir.py logdir/squirrel-fg-comp-b120/ +python scripts/zip_logdir.py logdir/dog-robolounge-fg-comp-b120/ +python scripts/zip_logdir.py logdir/cat-pikachu-fg-bob-b120/ +python scripts/zip_logdir.py logdir/shiba-haru-fg-comp-b120/ +python scripts/zip_logdir.py logdir/human-cap-fg-comp-b120/ +python scripts/zip_logdir.py logdir/penguin-fg-skel-b120/ +python scripts/zip_logdir.py logdir/human-48-category-comp/ +python scripts/zip_logdir.py logdir/cat-85-category-comp/ +python scripts/zip_logdir.py logdir/dog-98-category-comp/
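A minimal usage sketch for the helper scripts introduced in this patch; the sequence name and GPU list are illustrative, following the example printed by run_preprocess_bg.py itself:

    bash scripts/download_all_data.sh                                     # download all test data
    bash scripts/download_all_log.sh                                      # download all checkpoints
    python scripts/run_preprocess_bg.py cat-pikachu-0 "0,1,2,3,4,5,6,7"   # background/scene preprocessing across GPUs 0-7
    bash scripts/zip_all_log.sh                                           # zip each log directory under logdir/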