diff --git a/examples/gsplat_viewer.py b/examples/gsplat_viewer.py index effaa3aef..e35749832 100644 --- a/examples/gsplat_viewer.py +++ b/examples/gsplat_viewer.py @@ -18,10 +18,10 @@ class GsplatRenderTabState(RenderTabState): eps2d: float = 0.3 backgrounds: Tuple[float, float, float] = (0.0, 0.0, 0.0) render_mode: Literal[ - "rgb", "depth(accumulated)", "depth(expected)", "alpha" + "rgb", "depth(accumulated)", "depth(expected)", "alpha", "diffuse", "specular" ] = "rgb" normalize_nearfar: bool = False - inverse: bool = True + inverse: bool = False colormap: Literal[ "turbo", "viridis", "magma", "inferno", "cividis", "gray" ] = "turbo" @@ -141,7 +141,14 @@ def _(_) -> None: render_mode_dropdown = server.gui.add_dropdown( "Render Mode", - ("rgb", "depth(accumulated)", "depth(expected)", "alpha"), + ( + "rgb", + "depth(accumulated)", + "depth(expected)", + "alpha", + "diffuse", + "specular", + ), initial_value=self.render_tab_state.render_mode, hint="Render mode to use.", ) @@ -150,12 +157,10 @@ def _(_) -> None: def _(_) -> None: if "depth" in render_mode_dropdown.value: normalize_nearfar_checkbox.disabled = False + inverse_checkbox.disabled = False else: normalize_nearfar_checkbox.disabled = True - if render_mode_dropdown.value == "rgb": inverse_checkbox.disabled = True - else: - inverse_checkbox.disabled = False self.render_tab_state.render_mode = render_mode_dropdown.value self.rerender(_) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 53ffe7080..4def28718 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -1095,6 +1095,8 @@ def _viewer_render_fn( "depth(accumulated)": "D", "depth(expected)": "ED", "alpha": "RGB", + "diffuse": "Diffuse", + "specular": "Specular", } render_colors, render_alphas, info = self.rasterize_splats( @@ -1116,11 +1118,7 @@ def _viewer_render_fn( render_tab_state.total_gs_count = len(self.splats["means"]) render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item() - if render_tab_state.render_mode == "rgb": - # colors represented with sh are not guranteed to be in [0, 1] - render_colors = render_colors[0, ..., 0:3].clamp(0, 1) - renders = render_colors.cpu().numpy() - elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]: + if render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]: # normalize depth to [0, 1] depth = render_colors[0, ..., 0:1] if render_tab_state.normalize_nearfar: @@ -1145,6 +1143,10 @@ def _viewer_render_fn( renders = ( apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy() ) + else: + # colors represented with sh are not guranteed to be in [0, 1] + render_colors = render_colors[0, ..., 0:3].clamp(0, 1) + renders = render_colors.cpu().numpy() return renders diff --git a/examples/simple_viewer.py b/examples/simple_viewer.py index 63597033a..84367fab3 100644 --- a/examples/simple_viewer.py +++ b/examples/simple_viewer.py @@ -140,6 +140,8 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState "depth(accumulated)": "D", "depth(expected)": "ED", "alpha": "RGB", + "diffuse": "Diffuse", + "specular": "Specular", } render_colors, render_alphas, info = rasterization( @@ -173,11 +175,7 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState render_tab_state.total_gs_count = len(means) render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item() - if render_tab_state.render_mode == "rgb": - # colors represented with sh are not guranteed to be in [0, 1] - render_colors = render_colors[0, ..., 0:3].clamp(0, 1) - renders = render_colors.cpu().numpy() - elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]: + if render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]: # normalize depth to [0, 1] depth = render_colors[0, ..., 0:1] if render_tab_state.normalize_nearfar: @@ -197,11 +195,13 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState ) elif render_tab_state.render_mode == "alpha": alpha = render_alphas[0, ..., 0:1] - if render_tab_state.inverse: - alpha = 1 - alpha renders = ( apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy() ) + else: + # colors represented with sh are not guranteed to be in [0, 1] + render_colors = render_colors[0, ..., 0:3].clamp(0, 1) + renders = render_colors.cpu().numpy() return renders server = viser.ViserServer(port=args.port, verbose=False) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index d9b4fd20f..a99d6b7e4 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -46,7 +46,9 @@ def rasterization( packed: bool = True, tile_size: int = 16, backgrounds: Optional[Tensor] = None, - render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", + render_mode: Literal[ + "RGB", "D", "ED", "RGB+D", "RGB+ED", "Diffuse", "Specular" + ] = "RGB", sparse_grad: bool = False, absgrad: bool = False, rasterize_mode: Literal["classic", "antialiased"] = "classic", @@ -98,8 +100,10 @@ def rasterization( .. note:: **Depth Rendering**: This function supports colors or/and depths via `render_mode`. - The supported modes are "RGB", "D", "ED", "RGB+D", and "RGB+ED". "RGB" renders the - colored image that respects the `colors` argument. "D" renders the accumulated z-depth + The supported modes are "RGB", "D", "ED", "RGB+D", "RGB+ED", "Diffuse", and "Specular". + "RGB" renders the colored image that respects the `colors` argument. + "Diffuse" renders view-independent RGB components while "Specular" renders view-dependent RGB components. + "D" renders the accumulated z-depth :math:`\\sum_i w_i z_i`. "ED" renders the expected z-depth :math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`. "RGB+D" and "RGB+ED" render both the colored image and the depth, in which the depth is the last channel of the output. @@ -181,8 +185,10 @@ def rasterization( tile_size: The size of the tiles for rasterization. Default is 16. (Note: other values are not tested) backgrounds: The background colors. [..., C, D]. Default is None. - render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D", - and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated depth, and + render_mode: The rendering mode. + Supported modes are "RGB", "D", "ED", "RGB+D", "RGB+ED", "Diffuse", and "Specular". + "RGB" renders the colored image, "Diffuse" renders view-independent RGB, + "Specular" renders view-dependent RGB, "D" renders the accumulated depth, and "ED" renders the expected depth. Default is "RGB". sparse_grad: If true, the gradients for {means, quats, scales} will be stored in a COO sparse layout. This can be helpful for saving memory. Default is False. @@ -280,7 +286,15 @@ def rasterization( assert opacities.shape == batch_dims + (N,), opacities.shape assert viewmats.shape == batch_dims + (C, 4, 4), viewmats.shape assert Ks.shape == batch_dims + (C, 3, 3), Ks.shape - assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + assert render_mode in [ + "RGB", + "D", + "ED", + "RGB+D", + "RGB+ED", + "Diffuse", + "Specular", + ], render_mode def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor: view_list = list( @@ -481,6 +495,11 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso pass else: # Colors are SH coefficients, with shape [..., N, K, 3] or [..., C, N, K, 3] + # SH = 0 is the view-independent diffuse component, SH >= 1 are view-dependent specular components. + if render_mode in ("Diffuse", "Specular"): + colors = colors.clone() + sel = slice(1, None) if render_mode == "Diffuse" else slice(0, 1) + colors[..., sel, :] = 0.0 # Zero out the unwanted SH components. campos = torch.inverse(viewmats)[..., :3, 3] # [..., C, 3] if viewmats_rs is not None: campos_rs = torch.inverse(viewmats_rs)[..., :3, 3] @@ -515,7 +534,9 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso sh_degree, dirs, shs, masks=masks ) # [..., C, N, 3] # make it apple-to-apple with Inria's CUDA Backend. - colors = torch.clamp_min(colors + 0.5, 0.0) + # view-dependent components are in [0, 1]. No need to add and clamp. + if render_mode != "Specular": + colors = torch.clamp_min(colors + 0.5, 0.0) # If in distributed mode, we need to scatter the GSs to the destination ranks, based # on which cameras they are visible to, which we already figured out in the projection