From 579aa418b28ab80cce987c71e8f92f6df54ca5b2 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 9 May 2025 22:43:36 +0200 Subject: [PATCH 01/11] implements hogs for default strategy --- examples/simple_trainer.py | 31 +++++++++++++++++++++++++++++++ gsplat/strategy/default.py | 10 ++++++---- gsplat/strategy/ops.py | 21 +++++++++++++++++---- gsplat/utils.py | 6 ++++++ 4 files changed, 60 insertions(+), 8 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 3aa8caf87..661194da0 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -37,6 +37,7 @@ from gsplat.optimizers import SelectiveAdam from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy +from gsplat.utils import xyz_to_polar from gsplat_viewer import GsplatViewer, GsplatRenderTabState from nerfview import CameraState, RenderTabState, apply_float_colormap @@ -133,6 +134,9 @@ class Config: # Scale regularization scale_reg: float = 0.0 + # Use homogeneous coordinates, 50k max_steps and 30k steps densifications recommended! + use_hom_coords: bool = True + # Enable camera optimization. pose_opt: bool = False # Learning rate for camera optimization @@ -202,6 +206,7 @@ def create_splats_with_optimizers( visible_adam: bool = False, batch_size: int = 1, feature_dim: Optional[int] = None, + use_hom_coords: bool = False, device: str = "cuda", world_rank: int = 0, world_size: int = 1, @@ -215,6 +220,12 @@ def create_splats_with_optimizers( else: raise ValueError("Please specify a correct init_type: sfm or random") + if use_hom_coords: + w, _ = xyz_to_polar(points) + points *= w.unsqueeze(1) + w = torch.log(w) + w = w[world_rank::world_size] + # Initialize the GS size to be the average dist of the 3 nearest neighbors dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] dist_avg = torch.sqrt(dist2_avg) @@ -237,6 +248,9 @@ def create_splats_with_optimizers( ("opacities", torch.nn.Parameter(opacities), 5e-2), ] + if use_hom_coords: + params.append(("w", torch.nn.Parameter(w), 0.0002 * scene_scale)) + if feature_dim is None: # color is SH coefficients. colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] @@ -337,6 +351,7 @@ def __init__( visible_adam=cfg.visible_adam, batch_size=cfg.batch_size, feature_dim=feature_dim, + use_hom_coords=cfg.use_hom_coords, device=self.device, world_rank=world_rank, world_size=world_size, @@ -466,6 +481,11 @@ def rasterize_splats( scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + if cfg.use_hom_coords: + w_inv = 1.0 / torch.exp(self.splats["w"]).unsqueeze(1) + means *= w_inv + scales *= w_inv + image_ids = kwargs.pop("image_ids", None) if self.cfg.app_opt: colors = self.app_module( @@ -529,6 +549,12 @@ def train(self): self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) ), ] + if cfg.use_hom_coords: + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["w"], gamma=0.1 ** (0.5 / max_steps) + ) + ) if cfg.pose_opt: # pose optimization has a learning rate schedule schedulers.append( @@ -770,6 +796,11 @@ def train(self): means = self.splats["means"] scales = self.splats["scales"] + if cfg.use_hom_coords: + w_inv = 1.0 / torch.exp(self.splats["w"]).unsqueeze(1) + means *= w_inv + scales = torch.log(torch.exp(scales) * w_inv) + quats = self.splats["quats"] opacities = self.splats["opacities"] export_splats( diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index 9075278a9..df87407ab 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -272,10 +272,12 @@ def _grow_gs( device = grads.device is_grad_high = grads > self.grow_grad2d - is_small = ( - torch.exp(params["scales"]).max(dim=-1).values - <= self.grow_scale3d * state["scene_scale"] - ) + scales = torch.exp(params["scales"]) + if "w" in params: + w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1) + scales *= w_inv + + is_small = scales.max(dim=-1).values <= self.grow_scale3d * state["scene_scale"] is_dupli = is_grad_high & is_small n_dupli = is_dupli.sum().item() diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 83c90a25e..19deca615 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -7,7 +7,7 @@ from gsplat import quat_scale_to_covar_preci from gsplat.relocation import compute_relocation -from gsplat.utils import normalized_quat_to_rotmat +from gsplat.utils import normalized_quat_to_rotmat, xyz_to_polar @torch.no_grad() @@ -141,7 +141,11 @@ def split( sel = torch.where(mask)[0] rest = torch.where(~mask)[0] - scales = torch.exp(params["scales"][sel]) + scales = torch.exp(params["scales"]) + means = params["means"][sel] + if "w" in params: + w_inv = 1.0 / torch.exp(params["w"][sel]).unsqueeze(1) + scales *= w_inv quats = F.normalize(params["quats"][sel], dim=-1) rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3] samples = torch.einsum( @@ -151,15 +155,24 @@ def split( torch.randn(2, len(scales), 3, device=device), ) # [2, N, 3] + means = means + samples + if "w" in params: + w, _ = xyz_to_polar(means) + means = means * w.unsqueeze(1) + def param_fn(name: str, p: Tensor) -> Tensor: repeats = [2] + [1] * (p.dim() - 1) if name == "means": - p_split = (p[sel] + samples).reshape(-1, 3) # [2N, 3] + p_split = means.reshape(-1, 3) # [2N, 3] elif name == "scales": - p_split = torch.log(scales / 1.6).repeat(2, 1) # [2N, 3] + p_split = torch.log(torch.exp(params["scales"]) / 1.6).repeat( + 2, 1 + ) # [2N, 3] elif name == "opacities" and revised_opacity: new_opacities = 1.0 - torch.sqrt(1.0 - torch.sigmoid(p[sel])) p_split = torch.logit(new_opacities).repeat(repeats) # [2N] + elif name == "w": + p_split = torch.log(w) else: p_split = p[sel].repeat(repeats) p_new = torch.cat([p[rest], p_split]) diff --git a/gsplat/utils.py b/gsplat/utils.py index e56692958..50e393246 100644 --- a/gsplat/utils.py +++ b/gsplat/utils.py @@ -133,6 +133,12 @@ def normalized_quat_to_rotmat(quat: Tensor) -> Tensor: return mat.reshape(quat.shape[:-1] + (3, 3)) +def xyz_to_polar(means): + x, y, z = means[:, 0], means[:, 1], means[:, 2] + r = torch.sqrt(x**2 + y**2 + z**2) + return 1 / r, r + + def log_transform(x): return torch.sign(x) * torch.log1p(torch.abs(x)) From d608d0f2591818159f7936f77dbc81a5deb2b854 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 9 May 2025 23:16:04 +0200 Subject: [PATCH 02/11] fixes --- examples/simple_trainer.py | 18 +++++++++++++----- gsplat/strategy/default.py | 2 +- gsplat/strategy/ops.py | 11 ++++++----- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 661194da0..158da0ef6 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -54,13 +54,13 @@ class Config: render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset - data_dir: str = "data/360_v2/garden" + data_dir: str = "/home/paja/data/bike_aliked" # Downsample factor for the dataset - data_factor: int = 4 + data_factor: int = 1 # Directory to save results result_dir: str = "results/garden" # Every N images there is a test image - test_every: int = 8 + test_every: int = 300 # Random crop size for training (experimental) patch_size: Optional[int] = None # A global scaler that applies to the scene size related parameters @@ -483,8 +483,8 @@ def rasterize_splats( if cfg.use_hom_coords: w_inv = 1.0 / torch.exp(self.splats["w"]).unsqueeze(1) - means *= w_inv - scales *= w_inv + means = means * w_inv + scales = scales * w_inv image_ids = kwargs.pop("image_ids", None) if self.cfg.app_opt: @@ -1211,7 +1211,15 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): ), ), } + cfg = tyro.extras.overridable_config_cli(configs) + if cfg.use_hom_coords: + cfg.max_steps = 50_000 + if isinstance(cfg.strategy, DefaultStrategy): + cfg.strategy.refine_stop_iter = 30_000 + cfg.strategy.refine_every = 200 + cfg.strategy.reset_every = 6_000 + cfg.adjust_steps(cfg.steps_scaler) # try import extra dependencies diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index df87407ab..2619b4537 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -275,7 +275,7 @@ def _grow_gs( scales = torch.exp(params["scales"]) if "w" in params: w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1) - scales *= w_inv + scales = scales * w_inv is_small = scales.max(dim=-1).values <= self.grow_scale3d * state["scene_scale"] is_dupli = is_grad_high & is_small diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 19deca615..9d89a20d6 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -141,11 +141,11 @@ def split( sel = torch.where(mask)[0] rest = torch.where(~mask)[0] - scales = torch.exp(params["scales"]) + scales = torch.exp(params["scales"][sel]) means = params["means"][sel] if "w" in params: w_inv = 1.0 / torch.exp(params["w"][sel]).unsqueeze(1) - scales *= w_inv + scales = scales * w_inv quats = F.normalize(params["quats"][sel], dim=-1) rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3] samples = torch.einsum( @@ -157,13 +157,14 @@ def split( means = means + samples if "w" in params: - w, _ = xyz_to_polar(means) - means = means * w.unsqueeze(1) + flat_means = means.reshape(-1, 3) # [2N, 3] + w, _ = xyz_to_polar(flat_means) # [2N] + means = flat_means * w.unsqueeze(1) # [2N, 3] def param_fn(name: str, p: Tensor) -> Tensor: repeats = [2] + [1] * (p.dim() - 1) if name == "means": - p_split = means.reshape(-1, 3) # [2N, 3] + p_split = means # [2N, 3] elif name == "scales": p_split = torch.log(torch.exp(params["scales"]) / 1.6).repeat( 2, 1 From 223410ef2bd00b553a325eeb2d4ea9e0c4533ac0 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Sun, 11 May 2025 21:36:21 +0200 Subject: [PATCH 03/11] fixes --- examples/simple_trainer.py | 6 ++++-- gsplat/strategy/default.py | 9 +++++++-- gsplat/strategy/ops.py | 7 ++++--- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 158da0ef6..eb7717009 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -54,7 +54,7 @@ class Config: render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset - data_dir: str = "/home/paja/data/bike_aliked" + data_dir: str = "/home/paja/data/alex_new" # Downsample factor for the dataset data_factor: int = 1 # Directory to save results @@ -552,7 +552,7 @@ def train(self): if cfg.use_hom_coords: schedulers.append( torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["w"], gamma=0.1 ** (0.5 / max_steps) + self.optimizers["w"], gamma=0.1 ** (0.005 / max_steps) ) ) if cfg.pose_opt: @@ -1219,6 +1219,8 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): cfg.strategy.refine_stop_iter = 30_000 cfg.strategy.refine_every = 200 cfg.strategy.reset_every = 6_000 + cfg.strategy.refine_start_iter = 1_500 + cfg.strategy.prune_too_big = False cfg.adjust_steps(cfg.steps_scaler) diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index 2619b4537..071755492 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -83,6 +83,7 @@ class DefaultStrategy(Strategy): prune_scale3d: float = 0.1 prune_scale2d: float = 0.15 refine_scale2d_stop_iter: int = 0 + prune_too_big: bool = True refine_start_iter: int = 500 refine_stop_iter: int = 15_000 reset_every: int = 3000 @@ -319,9 +320,13 @@ def _prune_gs( step: int, ) -> int: is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa - if step > self.reset_every: + if step > self.reset_every and self.prune_too_big: + scales = torch.exp(params["scales"]) + if "w" in params: + w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1) + scales = scales * w_inv is_too_big = ( - torch.exp(params["scales"]).max(dim=-1).values + scales.max(dim=-1).values > self.prune_scale3d * state["scene_scale"] ) # The official code also implements sreen-size pruning but diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 9d89a20d6..a6c0cfa35 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -146,6 +146,7 @@ def split( if "w" in params: w_inv = 1.0 / torch.exp(params["w"][sel]).unsqueeze(1) scales = scales * w_inv + means = means * w_inv quats = F.normalize(params["quats"][sel], dim=-1) rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3] samples = torch.einsum( @@ -165,15 +166,15 @@ def param_fn(name: str, p: Tensor) -> Tensor: repeats = [2] + [1] * (p.dim() - 1) if name == "means": p_split = means # [2N, 3] + elif name == "w": + p_split = torch.log(w) elif name == "scales": - p_split = torch.log(torch.exp(params["scales"]) / 1.6).repeat( + p_split = torch.log(torch.exp(params["scales"][sel]) / 1.6).repeat( 2, 1 ) # [2N, 3] elif name == "opacities" and revised_opacity: new_opacities = 1.0 - torch.sqrt(1.0 - torch.sigmoid(p[sel])) p_split = torch.logit(new_opacities).repeat(repeats) # [2N] - elif name == "w": - p_split = torch.log(w) else: p_split = p[sel].repeat(repeats) p_new = torch.cat([p[rest], p_split]) From 340ba81f25775019ea2c4ea48094e787d8a03ee1 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Sun, 11 May 2025 22:14:55 +0200 Subject: [PATCH 04/11] formating --- gsplat/strategy/default.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index 071755492..a1e2b81ab 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -326,8 +326,7 @@ def _prune_gs( w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1) scales = scales * w_inv is_too_big = ( - scales.max(dim=-1).values - > self.prune_scale3d * state["scene_scale"] + scales.max(dim=-1).values > self.prune_scale3d * state["scene_scale"] ) # The official code also implements sreen-size pruning but # it's actually not being used due to a bug: From 67bee0aef45c68d788f4ba1aa625e9be03a7eb1e Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Sun, 11 May 2025 22:27:46 +0200 Subject: [PATCH 05/11] restore default values! --- examples/simple_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index bd002ee6b..a23df8fd8 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -54,13 +54,13 @@ class Config: render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset - data_dir: str = "/home/paja/data/alex_new" + data_dir: str = "data/360_v2/garden" # Downsample factor for the dataset - data_factor: int = 1 + data_factor: int = 4 # Directory to save results result_dir: str = "results/garden" # Every N images there is a test image - test_every: int = 300 + test_every: int = 8 # Random crop size for training (experimental) patch_size: Optional[int] = None # A global scaler that applies to the scene size related parameters From 17a885ba4132b59f6b22464b1865bde0329014d0 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Mon, 12 May 2025 12:36:35 +0200 Subject: [PATCH 06/11] support mcmc --- examples/simple_trainer.py | 8 +++++++- gsplat/strategy/ops.py | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a23df8fd8..6f731ff88 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -148,7 +148,7 @@ class Config: scale_reg: float = 0.0 # Use homogeneous coordinates, 50k max_steps and 30k steps densifications recommended! - use_hom_coords: bool = True + use_hom_coords: bool = False # Enable camera optimization. pose_opt: bool = False @@ -1252,6 +1252,12 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): cfg.strategy.reset_every = 6_000 cfg.strategy.refine_start_iter = 1_500 cfg.strategy.prune_too_big = False + elif isinstance(cfg.strategy, MCMCStrategy): + cfg.strategy.refine_start_iter: 1_500 + cfg.strategy.refine_stop_iter: 40_000 + cfg.strategy.refine_every: int = 200 + else: + assert_never(cfg.strategy) cfg.adjust_steps(cfg.steps_scaler) diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index a6c0cfa35..eda14ff34 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -284,9 +284,14 @@ def relocate( probs = opacities[alive_indices].flatten() # ensure its shape is [N,] sampled_idxs = _multinomial_sample(probs, n, replacement=True) sampled_idxs = alive_indices[sampled_idxs] + + scales = torch.exp(params["scales"])[sampled_idxs] + if "w" in params: + w = torch.exp(params["w"][sampled_idxs]).unsqueeze(1) + scales = scales / w new_opacities, new_scales = compute_relocation( opacities=opacities[sampled_idxs], - scales=torch.exp(params["scales"])[sampled_idxs], + scales=scales, ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, binoms=binoms, ) @@ -296,7 +301,7 @@ def param_fn(name: str, p: Tensor) -> Tensor: if name == "opacities": p[sampled_idxs] = torch.logit(new_opacities) elif name == "scales": - p[sampled_idxs] = torch.log(new_scales) + p[sampled_idxs] = torch.log(new_scales * w) p[dead_indices] = p[sampled_idxs] return torch.nn.Parameter(p, requires_grad=p.requires_grad) @@ -326,9 +331,14 @@ def sample_add( eps = torch.finfo(torch.float32).eps probs = opacities.flatten() sampled_idxs = _multinomial_sample(probs, n, replacement=True) + + scales = torch.exp(params["scales"])[sampled_idxs] + if "w" in params: + w = torch.exp(params["w"][sampled_idxs]).unsqueeze(1) + scales = scales / w new_opacities, new_scales = compute_relocation( opacities=opacities[sampled_idxs], - scales=torch.exp(params["scales"])[sampled_idxs], + scales=scales, ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, binoms=binoms, ) @@ -338,7 +348,10 @@ def param_fn(name: str, p: Tensor) -> Tensor: if name == "opacities": p[sampled_idxs] = torch.logit(new_opacities) elif name == "scales": - p[sampled_idxs] = torch.log(new_scales) + if "w" in params: + p[sampled_idxs] = torch.log(new_scales * w) + else: + p[sampled_idxs] = torch.log(new_scales) p_new = torch.cat([p, p[sampled_idxs]]) return torch.nn.Parameter(p_new, requires_grad=p.requires_grad) @@ -363,7 +376,13 @@ def inject_noise_to_position( scaler: float, ): opacities = torch.sigmoid(params["opacities"].flatten()) + means = params["means"] scales = torch.exp(params["scales"]) + if "w" in params: + w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1) + means = means * w_inv + scales = scales * w_inv + covars, _ = quat_scale_to_covar_preci( params["quats"], scales, @@ -381,4 +400,10 @@ def op_sigmoid(x, k=100, x0=0.995): * scaler ) noise = torch.einsum("bij,bj->bi", covars, noise) - params["means"].add_(noise) + means = means + noise + + w, _ = xyz_to_polar(means) + means = means * w.unsqueeze(1) + + params["means"].data = means + params["w"].data = torch.log(w.clamp_min(1e-8)) From ac31f886950620951aca6df3cf2b72199cf9f279 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Mon, 12 May 2025 12:45:00 +0200 Subject: [PATCH 07/11] update w only if available --- gsplat/strategy/ops.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index eda14ff34..3e77427a9 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -402,8 +402,9 @@ def op_sigmoid(x, k=100, x0=0.995): noise = torch.einsum("bij,bj->bi", covars, noise) means = means + noise - w, _ = xyz_to_polar(means) - means = means * w.unsqueeze(1) + if "w" in params: + w, _ = xyz_to_polar(means) + means = means * w.unsqueeze(1) + params["w"].data = torch.log(w.clamp_min(1e-8)) params["means"].data = means - params["w"].data = torch.log(w.clamp_min(1e-8)) From 2455a5fd011bfc4cccf3d7453a503dfe71bc6d3b Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Mon, 12 May 2025 13:00:33 +0200 Subject: [PATCH 08/11] only apply w if available --- gsplat/strategy/ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 3e77427a9..b254d9b6a 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -301,7 +301,10 @@ def param_fn(name: str, p: Tensor) -> Tensor: if name == "opacities": p[sampled_idxs] = torch.logit(new_opacities) elif name == "scales": - p[sampled_idxs] = torch.log(new_scales * w) + if "w" in params: + p[sampled_idxs] = torch.log(new_scales * w) + else: + p[sampled_idxs] = torch.log(new_scales) p[dead_indices] = p[sampled_idxs] return torch.nn.Parameter(p, requires_grad=p.requires_grad) From fc5228204d89e94ea5b183e03abec83a72c4df62 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Tue, 13 May 2025 08:40:46 +0200 Subject: [PATCH 09/11] fix bug --- examples/simple_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 6f731ff88..4fcf2ac3f 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -829,7 +829,7 @@ def train(self): scales = self.splats["scales"] if cfg.use_hom_coords: w_inv = 1.0 / torch.exp(self.splats["w"]).unsqueeze(1) - means *= w_inv + means = means * w_inv scales = torch.log(torch.exp(scales) * w_inv) quats = self.splats["quats"] From 96771ffb5b2168e3d2520616cdbf461edab9a77a Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Tue, 13 May 2025 10:26:54 +0200 Subject: [PATCH 10/11] fix simple_viewer with hom coords --- examples/simple_viewer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/simple_viewer.py b/examples/simple_viewer.py index 63597033a..8b0facc76 100644 --- a/examples/simple_viewer.py +++ b/examples/simple_viewer.py @@ -103,12 +103,18 @@ def main(local_rank: int, world_rank, world_size: int, args): means, quats, scales, opacities, sh0, shN = [], [], [], [], [], [] for ckpt_path in args.ckpt: ckpt = torch.load(ckpt_path, map_location=device)["splats"] - means.append(ckpt["means"]) + if "w" in ckpt: + w_inv = 1.0 / torch.exp(ckpt["w"]).unsqueeze(1) + means.append(ckpt["means"] * w_inv) + scales.append(torch.exp(ckpt["scales"]) * w_inv) + else: + means.append(ckpt["means"]) + scales.append(torch.exp(ckpt["scales"])) quats.append(F.normalize(ckpt["quats"], p=2, dim=-1)) - scales.append(torch.exp(ckpt["scales"])) opacities.append(torch.sigmoid(ckpt["opacities"])) sh0.append(ckpt["sh0"]) shN.append(ckpt["shN"]) + means = torch.cat(means, dim=0) quats = torch.cat(quats, dim=0) scales = torch.cat(scales, dim=0) From e757770152370196379ffcdaa2df720c310bb177 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 14 May 2025 09:26:33 +0200 Subject: [PATCH 11/11] fix dim --- gsplat/strategy/ops.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index b254d9b6a..26544cf7d 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -156,11 +156,10 @@ def split( torch.randn(2, len(scales), 3, device=device), ) # [2, N, 3] - means = means + samples + means = (means + samples).reshape(-1, 3) if "w" in params: - flat_means = means.reshape(-1, 3) # [2N, 3] - w, _ = xyz_to_polar(flat_means) # [2N] - means = flat_means * w.unsqueeze(1) # [2N, 3] + w, _ = xyz_to_polar(means) # [2N] + means = means * w.unsqueeze(1) # [2N, 3] def param_fn(name: str, p: Tensor) -> Tensor: repeats = [2] + [1] * (p.dim() - 1)