Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions examples/gsplat_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class GsplatRenderTabState(RenderTabState):
eps2d: float = 0.3
backgrounds: Tuple[float, float, float] = (0.0, 0.0, 0.0)
render_mode: Literal[
"rgb", "depth(accumulated)", "depth(expected)", "alpha"
"rgb", "depth(accumulated)", "depth(expected)", "alpha", "diffuse", "specular"
] = "rgb"
normalize_nearfar: bool = False
inverse: bool = True
inverse: bool = False
colormap: Literal[
"turbo", "viridis", "magma", "inferno", "cividis", "gray"
] = "turbo"
Expand Down Expand Up @@ -141,7 +141,14 @@ def _(_) -> None:

render_mode_dropdown = server.gui.add_dropdown(
"Render Mode",
("rgb", "depth(accumulated)", "depth(expected)", "alpha"),
(
"rgb",
"depth(accumulated)",
"depth(expected)",
"alpha",
"diffuse",
"specular",
),
initial_value=self.render_tab_state.render_mode,
hint="Render mode to use.",
)
Expand All @@ -150,12 +157,10 @@ def _(_) -> None:
def _(_) -> None:
if "depth" in render_mode_dropdown.value:
normalize_nearfar_checkbox.disabled = False
inverse_checkbox.disabled = False
else:
normalize_nearfar_checkbox.disabled = True
if render_mode_dropdown.value == "rgb":
inverse_checkbox.disabled = True
else:
inverse_checkbox.disabled = False
self.render_tab_state.render_mode = render_mode_dropdown.value
self.rerender(_)

Expand Down
12 changes: 7 additions & 5 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,8 @@ def _viewer_render_fn(
"depth(accumulated)": "D",
"depth(expected)": "ED",
"alpha": "RGB",
"diffuse": "Diffuse",
"specular": "Specular",
Comment on lines +1098 to +1099
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't touch rendering.py, we should remove these from here.

}

render_colors, render_alphas, info = self.rasterize_splats(
Expand All @@ -1116,11 +1118,7 @@ def _viewer_render_fn(
render_tab_state.total_gs_count = len(self.splats["means"])
render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()

if render_tab_state.render_mode == "rgb":
# colors represented with sh are not guaranteed to be in [0, 1]
render_colors = render_colors[0, ..., 0:3].clamp(0, 1)
renders = render_colors.cpu().numpy()
elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]:
if render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]:
# normalize depth to [0, 1]
depth = render_colors[0, ..., 0:1]
if render_tab_state.normalize_nearfar:
Expand All @@ -1145,6 +1143,10 @@ def _viewer_render_fn(
renders = (
apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy()
)
else:
# colors represented with sh are not guaranteed to be in [0, 1]
render_colors = render_colors[0, ..., 0:3].clamp(0, 1)
renders = render_colors.cpu().numpy()
return renders


Expand Down
14 changes: 7 additions & 7 deletions examples/simple_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
"depth(accumulated)": "D",
"depth(expected)": "ED",
"alpha": "RGB",
"diffuse": "Diffuse",
"specular": "Specular",
}

render_colors, render_alphas, info = rasterization(
Expand Down Expand Up @@ -173,11 +175,7 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
render_tab_state.total_gs_count = len(means)
render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()

if render_tab_state.render_mode == "rgb":
# colors represented with sh are not guaranteed to be in [0, 1]
render_colors = render_colors[0, ..., 0:3].clamp(0, 1)
renders = render_colors.cpu().numpy()
elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]:
if render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]:
# normalize depth to [0, 1]
depth = render_colors[0, ..., 0:1]
if render_tab_state.normalize_nearfar:
Expand All @@ -197,11 +195,13 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
)
elif render_tab_state.render_mode == "alpha":
alpha = render_alphas[0, ..., 0:1]
if render_tab_state.inverse:
alpha = 1 - alpha
renders = (
apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy()
)
else:
# colors represented with sh are not guaranteed to be in [0, 1]
render_colors = render_colors[0, ..., 0:3].clamp(0, 1)
renders = render_colors.cpu().numpy()
return renders

server = viser.ViserServer(port=args.port, verbose=False)
Expand Down
35 changes: 28 additions & 7 deletions gsplat/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def rasterization(
packed: bool = True,
tile_size: int = 16,
backgrounds: Optional[Tensor] = None,
render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB",
render_mode: Literal[
"RGB", "D", "ED", "RGB+D", "RGB+ED", "Diffuse", "Specular"
] = "RGB",
sparse_grad: bool = False,
absgrad: bool = False,
rasterize_mode: Literal["classic", "antialiased"] = "classic",
Expand Down Expand Up @@ -98,8 +100,10 @@ def rasterization(

.. note::
**Depth Rendering**: This function supports colors or/and depths via `render_mode`.
The supported modes are "RGB", "D", "ED", "RGB+D", and "RGB+ED". "RGB" renders the
colored image that respects the `colors` argument. "D" renders the accumulated z-depth
The supported modes are "RGB", "D", "ED", "RGB+D", "RGB+ED", "Diffuse", and "Specular".
"RGB" renders the colored image that respects the `colors` argument.
"Diffuse" renders view-independent RGB components while "Specular" renders view-dependent RGB components.
"D" renders the accumulated z-depth
:math:`\\sum_i w_i z_i`. "ED" renders the expected z-depth
:math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`. "RGB+D" and "RGB+ED" render both
the colored image and the depth, in which the depth is the last channel of the output.
Expand Down Expand Up @@ -181,8 +185,10 @@ def rasterization(
tile_size: The size of the tiles for rasterization. Default is 16.
(Note: other values are not tested)
backgrounds: The background colors. [..., C, D]. Default is None.
render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D",
and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated depth, and
render_mode: The rendering mode.
Supported modes are "RGB", "D", "ED", "RGB+D", "RGB+ED", "Diffuse", and "Specular".
"RGB" renders the colored image, "Diffuse" renders view-independent RGB,
"Specular" renders view-dependent RGB, "D" renders the accumulated depth, and
"ED" renders the expected depth. Default is "RGB".
sparse_grad: If true, the gradients for {means, quats, scales} will be stored in
a COO sparse layout. This can be helpful for saving memory. Default is False.
Expand Down Expand Up @@ -280,7 +286,15 @@ def rasterization(
assert opacities.shape == batch_dims + (N,), opacities.shape
assert viewmats.shape == batch_dims + (C, 4, 4), viewmats.shape
assert Ks.shape == batch_dims + (C, 3, 3), Ks.shape
assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode
assert render_mode in [
"RGB",
"D",
"ED",
"RGB+D",
"RGB+ED",
"Diffuse",
"Specular",
], render_mode

def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
view_list = list(
Expand Down Expand Up @@ -481,6 +495,11 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
pass
else:
# Colors are SH coefficients, with shape [..., N, K, 3] or [..., C, N, K, 3]
# SH = 0 is the view-independent diffuse component, SH >= 1 are view-dependent specular components.
if render_mode in ("Diffuse", "Specular"):
colors = colors.clone()
sel = slice(1, None) if render_mode == "Diffuse" else slice(0, 1)
colors[..., sel, :] = 0.0 # Zero out the unwanted SH components.
Comment on lines +498 to +502
Copy link
Copy Markdown
Collaborator

@liruilong940607 liruilong940607 May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think this logic should live outside of the rasterization function (in viewer_render_fn), for two reasons:

  • What if I want to render a diffuse color + depth map? There is no "Diffuse+D" mode here. Users would end up having to write this logic outside of the rasterization function and use the "RGB+D" mode anyway. (The major reason)
  • I'm still not quite happy with the in-place operation colors[..., sel, :] = 0.0 and colors.clone(). It's not a good idea to have in-place operations in the torch differentiable graph: they might break the autodiff graph or lead to unused parameters.

campos = torch.inverse(viewmats)[..., :3, 3] # [..., C, 3]
if viewmats_rs is not None:
campos_rs = torch.inverse(viewmats_rs)[..., :3, 3]
Expand Down Expand Up @@ -515,7 +534,9 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
sh_degree, dirs, shs, masks=masks
) # [..., C, N, 3]
# make it apple-to-apple with Inria's CUDA Backend.
colors = torch.clamp_min(colors + 0.5, 0.0)
# view-dependent components are in [0, 1]. No need to add and clamp.
if render_mode != "Specular":
colors = torch.clamp_min(colors + 0.5, 0.0)

# If in distributed mode, we need to scatter the GSs to the destination ranks, based
# on which cameras they are visible to, which we already figured out in the projection
Expand Down