diff --git a/README.md b/README.md
index 122510c..1032e8f 100644
--- a/README.md
+++ b/README.md
@@ -54,12 +54,19 @@ import nerfview
def render_fn(
- camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
+ camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
) -> np.ndarray:
# Parse camera state for camera-to-world matrix (c2w) and intrinsic (K) as
# float64 numpy arrays.
+ if render_tab_state.preview_render:
+ width = render_tab_state.render_width
+ height = render_tab_state.render_height
+ else:
+ width = render_tab_state.viewer_width
+ height = render_tab_state.viewer_height
+
c2w = camera_state.c2w
- K = camera_state.get_K(img_wh)
+ K = camera_state.get_K([width, height])
# Do your things and get an image as a uint8 numpy array.
img = your_rendering_logic(...)
return img
@@ -139,7 +146,7 @@ which we include here to be self-contained.
# Only need to run once the first time.
bash examples/assets/download_gsplat_ckpt.sh
CUDA_VISIBLE_DEVICES=0 python examples/03_gsplat_rendering.py \
- --ckpt results/garden/ckpts/ckpt_6999_crop.pt
+ --ckpt examples/assets/ckpt_6999_crop.pt
```
diff --git a/examples/00_dummy_rendering.py b/examples/00_dummy_rendering.py
index 66962f4..6c32858 100644
--- a/examples/00_dummy_rendering.py
+++ b/examples/00_dummy_rendering.py
@@ -24,19 +24,26 @@ def main(port: int = 8080, rendering_latency: float = 0.0):
"""
def render_fn(
- camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
+ camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
) -> UInt8[np.ndarray, "H W 3"]:
# Get camera parameters.
- W, H = img_wh
+ if render_tab_state.preview_render:
+ width = render_tab_state.render_width
+ height = render_tab_state.render_height
+ else:
+ width = render_tab_state.viewer_width
+ height = render_tab_state.viewer_height
c2w = camera_state.c2w
- K = camera_state.get_K(img_wh)
+ K = camera_state.get_K([width, height])
# Render a dummy image as a function of camera direction.
camera_dirs = np.einsum(
"ij,hwj->hwi",
np.linalg.inv(K),
np.pad(
- np.stack(np.meshgrid(np.arange(W), np.arange(H), indexing="xy"), -1)
+ np.stack(
+ np.meshgrid(np.arange(width), np.arange(height), indexing="xy"), -1
+ )
             + 0.5,
((0, 0), (0, 0), (0, 1)),
constant_values=1.0,
diff --git a/examples/01_dummy_training.py b/examples/01_dummy_training.py
index 7696827..9d8d51a 100644
--- a/examples/01_dummy_training.py
+++ b/examples/01_dummy_training.py
@@ -29,19 +29,26 @@ def main(port: int = 8080, max_steps: int = 50, rendering_latency: float = 0.0):
step: int = 0
def render_fn(
- camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
+ camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
) -> UInt8[np.ndarray, "H W 3"]:
# Get camera parameters.
- W, H = img_wh
+ if render_tab_state.preview_render:
+ width = render_tab_state.render_width
+ height = render_tab_state.render_height
+ else:
+ width = render_tab_state.viewer_width
+ height = render_tab_state.viewer_height
c2w = camera_state.c2w
- K = camera_state.get_K(img_wh)
+ K = camera_state.get_K([width, height])
# Render a dummy image as a function of camera direction.
camera_dirs = np.einsum(
"ij,hwj->hwi",
np.linalg.inv(K),
np.pad(
- np.stack(np.meshgrid(np.arange(W), np.arange(H), indexing="xy"), -1)
+ np.stack(
+ np.meshgrid(np.arange(width), np.arange(height), indexing="xy"), -1
+ )
             + 0.5,
((0, 0), (0, 0), (0, 1)),
constant_values=1.0,
@@ -83,11 +90,11 @@ def training_step():
# Optionally make the training utility lower such that we update the scene
# more frequently in this example. You dont need to do this in your own
# code.
- viewer._train_util_slider.value = 0.5
+ viewer._training_tab_handles["train_util_slider"].value = 0.5
for step in tqdm(range(max_steps)):
# Allow user to pause the training process.
- while viewer.state.status == "paused":
+ while viewer.state == "paused":
time.sleep(0.01)
# Do the training step and compute the number of training rays per second.
tic = time.time()
@@ -96,7 +103,7 @@ def training_step():
num_train_steps_per_sec = 1.0 / (time.time() - tic)
num_train_rays_per_sec = num_train_rays_per_step * num_train_steps_per_sec
# Update the viewer state.
- viewer.state.num_train_rays_per_sec = num_train_rays_per_sec
+ viewer.render_tab_state.num_train_rays_per_sec = num_train_rays_per_sec
# Update the scene.
viewer.update(step, num_train_rays_per_step)
viewer.complete()
diff --git a/examples/02_mesh_rendering.py b/examples/02_mesh_rendering.py
index a6d6110..767021e 100644
--- a/examples/02_mesh_rendering.py
+++ b/examples/02_mesh_rendering.py
@@ -172,13 +172,21 @@ def main(port: int = 8080):
)
def render_fn(
- camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
+ camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
) -> UInt8[np.ndarray, "H W 3"]:
- # nvdiffrast requires the image size to be multiples of 8.
- img_wh = (img_wh[0] // 8 * 8, img_wh[1] // 8 * 8)
-
# Get camera parameters.
+ if render_tab_state.preview_render:
+ width = render_tab_state.render_width
+ height = render_tab_state.render_height
+ else:
+ width = render_tab_state.viewer_width
+ height = render_tab_state.viewer_height
+
+ # nvdiffrast requires the image size to be multiples of 8.
+ width = width // 8 * 8
+ height = height // 8 * 8
c2w = camera_state.c2w
+ img_wh = [width, height]
K = camera_state.get_K(img_wh)
# Compute the normal map.
diff --git a/examples/03_gsplat_rendering.py b/examples/03_gsplat_rendering.py
index 473baf0..1ddd2e8 100644
--- a/examples/03_gsplat_rendering.py
+++ b/examples/03_gsplat_rendering.py
@@ -135,10 +135,17 @@
# register and open viewer
@torch.no_grad()
-def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]):
- width, height = img_wh
+def viewer_render_fn(
+ camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
+):
+ if render_tab_state.preview_render:
+ width = render_tab_state.render_width
+ height = render_tab_state.render_height
+ else:
+ width = render_tab_state.viewer_width
+ height = render_tab_state.viewer_height
c2w = camera_state.c2w
- K = camera_state.get_K(img_wh)
+ K = camera_state.get_K([width, height])
c2w = torch.from_numpy(c2w).float().to(device)
K = torch.from_numpy(K).float().to(device)
viewmat = c2w.inverse()
diff --git a/examples/04_gsplat_training.py b/examples/04_gsplat_training.py
index b3124a7..cf8debe 100644
--- a/examples/04_gsplat_training.py
+++ b/examples/04_gsplat_training.py
@@ -916,20 +916,31 @@ def render_traj(self, step: int):
@torch.no_grad()
def _viewer_render_fn(
- self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
+ self,
+ camera_state: nerfview.CameraState,
+ render_tab_state: nerfview.RenderTabState,
):
"""Callable function for the viewer."""
- W, H = img_wh
+ if render_tab_state.preview_render:
+ width, height = (
+ render_tab_state.render_width,
+ render_tab_state.render_height,
+ )
+ else:
+ width, height = (
+ render_tab_state.viewer_width,
+ render_tab_state.viewer_height,
+ )
c2w = camera_state.c2w
- K = camera_state.get_K(img_wh)
+ K = camera_state.get_K([width, height])
c2w = torch.from_numpy(c2w).float().to(self.device)
K = torch.from_numpy(K).float().to(self.device)
render_colors, _, _ = self.rasterize_splats(
camtoworlds=c2w[None],
Ks=K[None],
- width=W,
- height=H,
+ width=width,
+ height=height,
sh_degree=self.cfg.sh_degree, # active all SH degrees
radius_clip=3.0, # skip GSs that have small image radius (in pixels)
backgrounds=torch.ones(1, 3, device=self.device),
diff --git a/nerfview/__init__.py b/nerfview/__init__.py
index 7578c49..ef8c5dd 100644
--- a/nerfview/__init__.py
+++ b/nerfview/__init__.py
@@ -1,2 +1,3 @@
+from .render_panel import RenderTabState
from .version import __version__
from .viewer import VIEWER_LOCK, CameraState, Viewer, with_viewer_lock
diff --git a/nerfview/_renderer.py b/nerfview/_renderer.py
index be01125..0d55982 100644
--- a/nerfview/_renderer.py
+++ b/nerfview/_renderer.py
@@ -1,3 +1,7 @@
+"""
+Modified from nerfview/_renderer.py
+"""
+
import dataclasses
import os
import sys
@@ -9,7 +13,7 @@
import viser
if TYPE_CHECKING:
- from .viewer import CameraState, Viewer
+ from examples.viewer import CameraState, Viewer
RenderState = Literal["low_move", "low_static", "high"]
RenderAction = Literal["rerender", "move", "static", "update"]
@@ -53,7 +57,7 @@ def __init__(
self.lock = lock
self.running = True
- self.is_prepared_fn = lambda: self.viewer.state.status != "preparing"
+ self.is_prepared_fn = lambda: self.viewer.state != "preparing"
self._render_event = threading.Event()
self._state: RenderState = "low_static"
@@ -61,6 +65,7 @@ def __init__(
self._target_fps = 30
self._may_interrupt_render = False
+ self._old_version = False
self._define_transitions()
@@ -84,8 +89,9 @@ def _may_interrupt_trace(self, frame, event, arg):
return self._may_interrupt_trace
def _get_img_wh(self, aspect: float) -> Tuple[int, int]:
- max_img_res = self.viewer._max_img_res_slider.value
- if self._state == "high":
+ # we always trade off speed for quality
+ max_img_res = self.viewer.render_tab_state.viewer_res
+ if self._state in ["high"]:
# if True:
H = max_img_res
W = int(H * aspect)
@@ -93,7 +99,7 @@ def _get_img_wh(self, aspect: float) -> Tuple[int, int]:
W = max_img_res
H = int(W / aspect)
elif self._state in ["low_move", "low_static"]:
- num_view_rays_per_sec = self.viewer.state.num_view_rays_per_sec
+ num_view_rays_per_sec = self.viewer.render_tab_state.num_view_rays_per_sec
target_fps = self._target_fps
num_viewer_rays = num_view_rays_per_sec / target_fps
H = (num_viewer_rays / aspect) ** 0.5
@@ -141,13 +147,31 @@ def run(self):
with self.lock, set_trace_context(self._may_interrupt_trace):
tic = time.time()
W, H = img_wh = self._get_img_wh(task.camera_state.aspect)
- rendered = self.viewer.render_fn(task.camera_state, img_wh)
+ self.viewer.render_tab_state.viewer_width = W
+ self.viewer.render_tab_state.viewer_height = H
+
+ if not self._old_version:
+ try:
+ rendered = self.viewer.render_fn(
+ task.camera_state,
+ self.viewer.render_tab_state,
+ )
+ except TypeError:
+ self._old_version = True
+ print(
+ "[WARNING] Your API will be deprecated in the future, please update your render_fn."
+ )
+ rendered = self.viewer.render_fn(task.camera_state, img_wh)
+ else:
+ rendered = self.viewer.render_fn(task.camera_state, img_wh)
+
+ self.viewer._after_render()
if isinstance(rendered, tuple):
img, depth = rendered
else:
img, depth = rendered, None
- self.viewer.state.num_view_rays_per_sec = (W * H) / (
- max(time.time() - tic, 1e-6)
+ self.viewer.render_tab_state.num_view_rays_per_sec = (W * H) / (
+ time.time() - tic
)
except InterruptRenderException:
continue
@@ -160,4 +184,3 @@ def run(self):
jpeg_quality=70 if task.action in ["static", "update"] else 40,
depth=depth,
)
- self.client.flush()
diff --git a/nerfview/render_panel.py b/nerfview/render_panel.py
new file mode 100644
index 0000000..a81fa3e
--- /dev/null
+++ b/nerfview/render_panel.py
@@ -0,0 +1,1407 @@
+# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import colorsys
+import dataclasses
+import json
+import os
+import threading
+import time
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
+import imageio
+import matplotlib
+import numpy as np
+import splines
+import splines.quaternion
+import torch
+import viser
+import viser.transforms as tf
+from jaxtyping import Float
+from rich.console import Console
+from scipy import interpolate
+from torch import Tensor
+
+
+@dataclasses.dataclass
+class Keyframe:
+ position: np.ndarray
+ wxyz: np.ndarray
+ override_fov_enabled: bool
+ override_fov_rad: float
+ override_time_enabled: bool
+ override_time_val: float
+ aspect: float
+ override_transition_enabled: bool
+ override_transition_sec: Optional[float]
+
+ @staticmethod
+ def from_camera(camera: viser.CameraHandle, aspect: float) -> Keyframe:
+ return Keyframe(
+ camera.position,
+ camera.wxyz,
+ override_fov_enabled=False,
+ override_fov_rad=camera.fov,
+ override_time_enabled=False,
+ override_time_val=0.0,
+ aspect=aspect,
+ override_transition_enabled=False,
+ override_transition_sec=None,
+ )
+
+
+class CameraPath:
+ def __init__(
+ self,
+ server: viser.ViserServer,
+ duration_element: viser.GuiInputHandle[float],
+ time_enabled: bool = False,
+ ):
+ self._server = server
+ self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {}
+ self._keyframe_counter: int = 0
+ self._spline_nodes: List[viser.SceneNodeHandle] = []
+ self._camera_edit_panel: Optional[viser.Gui3dContainerHandle] = None
+
+ self._orientation_spline: Optional[splines.quaternion.KochanekBartels] = None
+ self._position_spline: Optional[splines.KochanekBartels] = None
+ self._fov_spline: Optional[splines.KochanekBartels] = None
+ self._keyframes_visible: bool = True
+
+ self._duration_element = duration_element
+
+ # These parameters should be overridden externally.
+ self.loop: bool = False
+ self.framerate: float = 30.0
+ self.tension: float = 0.5 # Tension / alpha term.
+ self.default_fov: float = 0.0
+ self.time_enabled = time_enabled
+ self.default_render_time: float = 0.0
+ self.default_transition_sec: float = 0.0
+ self.show_spline: bool = True
+
+ def set_keyframes_visible(self, visible: bool) -> None:
+ self._keyframes_visible = visible
+ for keyframe in self._keyframes.values():
+ keyframe[1].visible = visible
+
+ def add_camera(
+ self, keyframe: Keyframe, keyframe_index: Optional[int] = None
+ ) -> None:
+ """Add a new camera, or replace an old one if `keyframe_index` is passed in."""
+ server = self._server
+
+ # Add a keyframe if we aren't replacing an existing one.
+ if keyframe_index is None:
+ keyframe_index = self._keyframe_counter
+ self._keyframe_counter += 1
+
+ frustum_handle = server.scene.add_camera_frustum(
+ f"/render_cameras/{keyframe_index}",
+ fov=(
+ keyframe.override_fov_rad
+ if keyframe.override_fov_enabled
+ else self.default_fov
+ ),
+ aspect=keyframe.aspect,
+ scale=0.1,
+ color=(200, 10, 30),
+ wxyz=keyframe.wxyz,
+ position=keyframe.position,
+ visible=self._keyframes_visible,
+ )
+ self._server.scene.add_icosphere(
+ f"/render_cameras/{keyframe_index}/sphere",
+ radius=0.03,
+ color=(200, 10, 30),
+ )
+
+ @frustum_handle.on_click
+ def _(_) -> None:
+ if self._camera_edit_panel is not None:
+ self._camera_edit_panel.remove()
+ self._camera_edit_panel = None
+
+ with server.scene.add_3d_gui_container(
+ "/camera_edit_panel",
+ position=keyframe.position,
+ ) as camera_edit_panel:
+ self._camera_edit_panel = camera_edit_panel
+ override_fov = server.gui.add_checkbox(
+ "Override FOV", initial_value=keyframe.override_fov_enabled
+ )
+ override_fov_degrees_slider = server.gui.add_slider(
+ "Override FOV (degrees)",
+ 5.0,
+ 175.0,
+ step=0.1,
+ initial_value=keyframe.override_fov_rad * 180.0 / np.pi,
+ disabled=not keyframe.override_fov_enabled,
+ )
+ if self.time_enabled:
+ override_time = server.gui.add_checkbox(
+ "Override Time", initial_value=keyframe.override_time_enabled
+ )
+ override_time_val = server.gui.add_slider(
+ "Override Time",
+ 0.0,
+ 1.0,
+ step=0.01,
+ initial_value=keyframe.override_time_val,
+ disabled=not keyframe.override_time_enabled,
+ )
+
+ @override_time.on_update
+ def _(_) -> None:
+ keyframe.override_time_enabled = override_time.value
+ override_time_val.disabled = not override_time.value
+ self.add_camera(keyframe, keyframe_index)
+
+ @override_time_val.on_update
+ def _(_) -> None:
+ keyframe.override_time_val = override_time_val.value
+ self.add_camera(keyframe, keyframe_index)
+
+ delete_button = server.gui.add_button(
+ "Delete", color="red", icon=viser.Icon.TRASH
+ )
+ go_to_button = server.gui.add_button("Go to")
+ close_button = server.gui.add_button("Close")
+
+ @override_fov.on_update
+ def _(_) -> None:
+ keyframe.override_fov_enabled = override_fov.value
+ override_fov_degrees_slider.disabled = not override_fov.value
+ self.add_camera(keyframe, keyframe_index)
+
+ @override_fov_degrees_slider.on_update
+ def _(_) -> None:
+ keyframe.override_fov_rad = (
+ override_fov_degrees_slider.value / 180.0 * np.pi
+ )
+ self.add_camera(keyframe, keyframe_index)
+
+ @delete_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ assert event.client is not None
+ with event.client.gui.add_modal("Confirm") as modal:
+ event.client.gui.add_markdown("Delete keyframe?")
+ confirm_button = event.client.gui.add_button(
+ "Yes", color="red", icon=viser.Icon.TRASH
+ )
+ exit_button = event.client.gui.add_button("Cancel")
+
+ @confirm_button.on_click
+ def _(_) -> None:
+ assert camera_edit_panel is not None
+
+ keyframe_id = None
+ for i, keyframe_tuple in self._keyframes.items():
+ if keyframe_tuple[1] is frustum_handle:
+ keyframe_id = i
+ break
+ assert keyframe_id is not None
+
+ self._keyframes.pop(keyframe_id)
+ frustum_handle.remove()
+ camera_edit_panel.remove()
+ self._camera_edit_panel = None
+ modal.close()
+ self.update_spline()
+
+ @exit_button.on_click
+ def _(_) -> None:
+ modal.close()
+
+ @go_to_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ assert event.client is not None
+ client = event.client
+ T_world_current = tf.SE3.from_rotation_and_translation(
+ tf.SO3(client.camera.wxyz), client.camera.position
+ )
+ T_world_target = tf.SE3.from_rotation_and_translation(
+ tf.SO3(keyframe.wxyz), keyframe.position
+ ) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5]))
+
+ T_current_target = T_world_current.inverse() @ T_world_target
+
+ for j in range(10):
+ T_world_set = T_world_current @ tf.SE3.exp(
+ T_current_target.log() * j / 9.0
+ )
+
+ # Important bit: we atomically set both the orientation and the position
+ # of the camera.
+ with client.atomic():
+ client.camera.wxyz = T_world_set.rotation().wxyz
+ client.camera.position = T_world_set.translation()
+ time.sleep(1.0 / 30.0)
+
+ @close_button.on_click
+ def _(_) -> None:
+ assert camera_edit_panel is not None
+ camera_edit_panel.remove()
+ self._camera_edit_panel = None
+
+ self._keyframes[keyframe_index] = (keyframe, frustum_handle)
+
+ def update_aspect(self, aspect: float) -> None:
+ for keyframe_index, frame in self._keyframes.items():
+ frame = dataclasses.replace(frame[0], aspect=aspect)
+ self.add_camera(frame, keyframe_index=keyframe_index)
+
+ def get_aspect(self) -> float:
+ """Get W/H aspect ratio, which is shared across all keyframes."""
+ assert len(self._keyframes) > 0
+ return next(iter(self._keyframes.values()))[0].aspect
+
+ def reset(self) -> None:
+ for frame in self._keyframes.values():
+ frame[1].remove()
+ self._keyframes.clear()
+ self.update_spline()
+
+ def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray:
+ """From a time value in seconds, compute a t value for our geometric
+ spline interpolation. An increment of 1 for the latter will move the
+ camera forward by one keyframe.
+
+ We use a PCHIP spline here to guarantee monotonicity.
+ """
+ transition_times_cumsum = self.compute_transition_times_cumsum()
+ spline_indices = np.arange(transition_times_cumsum.shape[0])
+
+ if self.loop:
+ # In the case of a loop, we pad the spline to match the start/end
+ # slopes.
+ interpolator = interpolate.PchipInterpolator(
+ x=np.concatenate(
+ [
+ [-(transition_times_cumsum[-1] - transition_times_cumsum[-2])],
+ transition_times_cumsum,
+ transition_times_cumsum[-1:] + transition_times_cumsum[1:2],
+ ],
+ axis=0,
+ ),
+ y=np.concatenate(
+ [[-1], spline_indices, [spline_indices[-1] + 1]], axis=0
+ ),
+ )
+ else:
+ interpolator = interpolate.PchipInterpolator(
+ x=transition_times_cumsum, y=spline_indices
+ )
+
+ # Clip to account for floating point error.
+ return np.clip(interpolator(time), 0, spline_indices[-1])
+
+ def interpolate_pose_and_fov_rad(
+ self, normalized_t: float
+ ) -> Optional[Union[Tuple[tf.SE3, float], Tuple[tf.SE3, float, float]]]:
+ if len(self._keyframes) < 2:
+ return None
+
+ self._fov_spline = splines.KochanekBartels(
+ [
+ (
+ keyframe[0].override_fov_rad
+ if keyframe[0].override_fov_enabled
+ else self.default_fov
+ )
+ for keyframe in self._keyframes.values()
+ ],
+ tcb=(self.tension, 0.0, 0.0),
+ endconditions="closed" if self.loop else "natural",
+ )
+
+ self._time_spline = splines.KochanekBartels(
+ [
+ (
+ keyframe[0].override_time_val
+ if keyframe[0].override_time_enabled
+ else self.default_render_time
+ )
+ for keyframe in self._keyframes.values()
+ ],
+ tcb=(self.tension, 0.0, 0.0),
+ endconditions="closed" if self.loop else "natural",
+ )
+
+ assert self._orientation_spline is not None
+ assert self._position_spline is not None
+ assert self._fov_spline is not None
+ if self.time_enabled:
+ assert self._time_spline is not None
+ max_t = self.compute_duration()
+ t = max_t * normalized_t
+ spline_t = float(self.spline_t_from_t_sec(np.array(t)))
+
+ quat = self._orientation_spline.evaluate(spline_t)
+ assert isinstance(quat, splines.quaternion.UnitQuaternion)
+ if self.time_enabled:
+ return (
+ tf.SE3.from_rotation_and_translation(
+ tf.SO3(np.array([quat.scalar, *quat.vector])),
+ self._position_spline.evaluate(spline_t),
+ ),
+ float(self._fov_spline.evaluate(spline_t)),
+ float(self._time_spline.evaluate(spline_t)),
+ )
+ else:
+ return (
+ tf.SE3.from_rotation_and_translation(
+ tf.SO3(np.array([quat.scalar, *quat.vector])),
+ self._position_spline.evaluate(spline_t),
+ ),
+ float(self._fov_spline.evaluate(spline_t)),
+ )
+
+ def update_spline(self) -> None:
+ num_frames = int(self.compute_duration() * self.framerate)
+ keyframes = list(self._keyframes.values())
+
+ if num_frames <= 0 or not self.show_spline or len(keyframes) < 2:
+ for node in self._spline_nodes:
+ node.remove()
+ self._spline_nodes.clear()
+ return
+
+ transition_times_cumsum = self.compute_transition_times_cumsum()
+
+ self._orientation_spline = splines.quaternion.KochanekBartels(
+ [
+ splines.quaternion.UnitQuaternion.from_unit_xyzw(
+ np.roll(keyframe[0].wxyz, shift=-1)
+ )
+ for keyframe in keyframes
+ ],
+ tcb=(self.tension, 0.0, 0.0),
+ endconditions="closed" if self.loop else "natural",
+ )
+ self._position_spline = splines.KochanekBartels(
+ [keyframe[0].position for keyframe in keyframes],
+ tcb=(self.tension, 0.0, 0.0),
+ endconditions="closed" if self.loop else "natural",
+ )
+
+ # Update visualized spline.
+ points_array = self._position_spline.evaluate(
+ self.spline_t_from_t_sec(
+ np.linspace(0, transition_times_cumsum[-1], num_frames)
+ )
+ )
+ colors_array = np.array(
+ [
+ colorsys.hls_to_rgb(h, 0.5, 1.0)
+ for h in np.linspace(0.0, 1.0, len(points_array))
+ ]
+ )
+
+ # Clear prior spline nodes.
+ for node in self._spline_nodes:
+ node.remove()
+ self._spline_nodes.clear()
+
+ self._spline_nodes.append(
+ self._server.scene.add_spline_catmull_rom(
+ "/render_camera_spline",
+ positions=points_array,
+ color=(220, 220, 220),
+ closed=self.loop,
+ line_width=1.0,
+ segments=points_array.shape[0] + 1,
+ )
+ )
+ self._spline_nodes.append(
+ self._server.scene.add_point_cloud(
+ "/render_camera_spline/points",
+ points=points_array,
+ colors=colors_array,
+ point_size=0.04,
+ )
+ )
+
+ def make_transition_handle(i: int) -> None:
+ assert self._position_spline is not None
+ transition_pos = self._position_spline.evaluate(
+ float(
+ self.spline_t_from_t_sec(
+ (transition_times_cumsum[i] + transition_times_cumsum[i + 1])
+ / 2.0,
+ )
+ )
+ )
+ transition_sphere = self._server.scene.add_icosphere(
+ f"/render_camera_spline/transition_{i}",
+ radius=0.04,
+ color=(255, 0, 0),
+ position=transition_pos,
+ )
+ self._spline_nodes.append(transition_sphere)
+
+ @transition_sphere.on_click
+ def _(_) -> None:
+ server = self._server
+
+ if self._camera_edit_panel is not None:
+ self._camera_edit_panel.remove()
+ self._camera_edit_panel = None
+
+ keyframe_index = (i + 1) % len(self._keyframes)
+ keyframe = keyframes[keyframe_index][0]
+
+ with server.scene.add_3d_gui_container(
+ "/camera_edit_panel",
+ position=transition_pos,
+ ) as camera_edit_panel:
+ self._camera_edit_panel = camera_edit_panel
+ override_transition_enabled = server.gui.add_checkbox(
+ "Override transition",
+ initial_value=keyframe.override_transition_enabled,
+ )
+ override_transition_sec = server.gui.add_number(
+ "Override transition (sec)",
+ initial_value=(
+ keyframe.override_transition_sec
+ if keyframe.override_transition_sec is not None
+ else self.default_transition_sec
+ ),
+ min=0.001,
+ max=30.0,
+ step=0.001,
+ disabled=not override_transition_enabled.value,
+ )
+ close_button = server.gui.add_button("Close")
+
+ @override_transition_enabled.on_update
+ def _(_) -> None:
+ keyframe.override_transition_enabled = (
+ override_transition_enabled.value
+ )
+ override_transition_sec.disabled = (
+ not override_transition_enabled.value
+ )
+ self._duration_element.value = self.compute_duration()
+
+ @override_transition_sec.on_update
+ def _(_) -> None:
+ keyframe.override_transition_sec = override_transition_sec.value
+ self._duration_element.value = self.compute_duration()
+
+ @close_button.on_click
+ def _(_) -> None:
+ assert camera_edit_panel is not None
+ camera_edit_panel.remove()
+ self._camera_edit_panel = None
+
+ (num_transitions_plus_1,) = transition_times_cumsum.shape
+ for i in range(num_transitions_plus_1 - 1):
+ make_transition_handle(i)
+
+ # for i in range(transition_times.shape[0])
+
+ def compute_duration(self) -> float:
+ """Compute the total duration of the trajectory."""
+ total = 0.0
+ for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
+ if i == 0 and not self.loop:
+ continue
+ del frustum
+ total += (
+ keyframe.override_transition_sec
+ if keyframe.override_transition_enabled
+ and keyframe.override_transition_sec is not None
+ else self.default_transition_sec
+ )
+ return total
+
+ def compute_transition_times_cumsum(self) -> np.ndarray:
+ """Compute the total duration of the trajectory."""
+ total = 0.0
+ out = [0.0]
+ for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
+ if i == 0:
+ continue
+ del frustum
+ total += (
+ keyframe.override_transition_sec
+ if keyframe.override_transition_enabled
+ and keyframe.override_transition_sec is not None
+ else self.default_transition_sec
+ )
+ out.append(total)
+
+ if self.loop:
+ keyframe = next(iter(self._keyframes.values()))[0]
+ total += (
+ keyframe.override_transition_sec
+ if keyframe.override_transition_enabled
+ and keyframe.override_transition_sec is not None
+ else self.default_transition_sec
+ )
+ out.append(total)
+
+ return np.array(out)
+
+
+@dataclasses.dataclass
+class RenderTabState:
+ """Useful GUI handles exposed by the render tab."""
+
+ num_train_rays_per_sec: Optional[float] = None
+ num_view_rays_per_sec: float = 100000.0
+ preview_render: bool = False
+ preview_fov: float = 0.0
+ preview_time: float = 0.0
+ preview_aspect: float = 1.0
+ viewer_res: int = 2048
+ viewer_width: int = 1280
+ viewer_height: int = 960
+ render_width: int = 1280
+ render_height: int = 960
+
+
+Colormaps = Literal["turbo", "viridis", "magma", "inferno", "cividis", "gray"]
+
+
+def apply_float_colormap(
+ image: Float[Tensor, "*bs 1"], colormap: Colormaps = "viridis"
+) -> Float[Tensor, "*bs rgb=3"]:
+ """Copied from nerfstudio/utils/colormaps.py
+ Convert single channel to a color image.
+
+ Args:
+ image: Single channel image.
+ colormap: Colormap for image.
+
+ Returns:
+ Tensor: Colored image with colors in [0, 1]
+ """
+
+ image = torch.nan_to_num(image, 0)
+ if colormap == "gray":
+ return image.repeat(1, 1, 3)
+ image_long = (image * 255).long()
+ image_long_min = torch.min(image_long)
+ image_long_max = torch.max(image_long)
+ assert image_long_min >= 0, f"the min value is {image_long_min}"
+ assert image_long_max <= 255, f"the max value is {image_long_max}"
+ return torch.tensor(matplotlib.colormaps[colormap].colors, device=image.device)[
+ image_long[..., 0]
+ ]
+
+
+def populate_general_render_tab(
+ server: viser.ViserServer,
+ output_dir: Path,
+ folder: viser.GuiFolderHandle,
+ render_tab_state: RenderTabState,
+ extra_handles: Optional[Dict[str, viser.GuiInputHandle]] = None,
+ scale_ratio: float = 10.0, # VISER_NERFSTUDIO_SCALE_RATIO
+ time_enabled: bool = False,
+) -> Dict[str, viser.GuiInputHandle]:
+ """
+ Populate the render tab with general controls.
+ Args:
+ server: The server to populate the render tab on.
+ output_dir: The path to the output folder.
+ folder: The folder to populate the render tab on.
+ render_tab_state: The render tab state exposed to the outer scope.
+ extra_handles: Extra handles needed to be disabled during dump_video.
+ scale_ratio: The scale ratio for the render tab.
+ time_enabled: Whether to enable the time slider.
+ Returns:
+ A dictionary of handles populated in the render tab.
+ """
+ with folder:
+ fov_degrees_slider = server.gui.add_slider(
+ "FOV",
+ initial_value=50.0,
+ min=0.1,
+ max=175.0,
+ step=0.01,
+ hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.",
+ )
+
+ render_time = None
+ if time_enabled:
+ render_time = server.gui.add_slider(
+ "Default Time",
+ initial_value=0.0,
+ min=0.0,
+ max=1.0,
+ step=0.01,
+ hint="Rendering time step, which can also be overridden on a per-keyframe basis.",
+ )
+
+ @render_time.on_update
+ def _(_) -> None:
+ camera_path.default_render_time = render_time.value
+
+ @fov_degrees_slider.on_update
+ def _(_) -> None:
+ fov_radians = fov_degrees_slider.value / 180.0 * np.pi
+ for client in server.get_clients().values():
+ client.camera.fov = fov_radians
+ camera_path.default_fov = fov_radians
+
+ # Updating the aspect ratio will also re-render the camera frustums.
+ # Could rethink this.
+ camera_path.update_aspect(
+ render_res_vec2.value[0] / render_res_vec2.value[1]
+ )
+ compute_and_update_preview_camera_state()
+
+ render_res_vec2 = server.gui.add_vector2(
+ "Render Res",
+ initial_value=(1280, 960),
+ min=(50, 50),
+ max=(10_000, 10_000),
+ step=1,
+ hint="Rendering resolution.",
+ )
+
+ @render_res_vec2.on_update
+ def _(_) -> None:
+ camera_path.update_aspect(
+ render_res_vec2.value[0] / render_res_vec2.value[1]
+ )
+ compute_and_update_preview_camera_state()
+ render_tab_state.render_width = int(render_res_vec2.value[0])
+ render_tab_state.render_height = int(render_res_vec2.value[1])
+
+ add_keyframe_button = server.gui.add_button(
+ "Add Keyframe",
+ icon=viser.Icon.PLUS,
+ hint="Add a new keyframe at the current pose.",
+ )
+
+ @add_keyframe_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ assert event.client_id is not None
+ camera = server.get_clients()[event.client_id].camera
+
+ # Add this camera to the path.
+ camera_path.add_camera(
+ Keyframe.from_camera(
+ camera,
+ aspect=render_res_vec2.value[0] / render_res_vec2.value[1],
+ ),
+ )
+ duration_number.value = camera_path.compute_duration()
+ camera_path.update_spline()
+
+ clear_keyframes_button = server.gui.add_button(
+ "Clear Keyframes",
+ icon=viser.Icon.TRASH,
+ hint="Remove all keyframes from the render path.",
+ )
+
+ @clear_keyframes_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ assert event.client_id is not None
+ client = server.get_clients()[event.client_id]
+ with client.atomic(), client.gui.add_modal("Confirm") as modal:
+ client.gui.add_markdown("Clear all keyframes?")
+ confirm_button = client.gui.add_button(
+ "Yes", color="red", icon=viser.Icon.TRASH
+ )
+ exit_button = client.gui.add_button("Cancel")
+
+ @confirm_button.on_click
+ def _(_) -> None:
+ camera_path.reset()
+ modal.close()
+
+ duration_number.value = camera_path.compute_duration()
+
+ # Clear move handles.
+ if len(transform_controls) > 0:
+ for t in transform_controls:
+ t.remove()
+ transform_controls.clear()
+ return
+
+ @exit_button.on_click
+ def _(_) -> None:
+ modal.close()
+
+ reset_up_button = server.gui.add_button(
+ "Reset Up Direction",
+ icon=viser.Icon.ARROW_BIG_UP_LINES,
+ color="gray",
+ hint="Set the up direction of the camera orbit controls to the camera's current up direction.",
+ )
+
+ @reset_up_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ assert event.client is not None
+ event.client.camera.up_direction = tf.SO3(
+ event.client.camera.wxyz
+ ) @ np.array([0.0, -1.0, 0.0])
+
+ loop_checkbox = server.gui.add_checkbox(
+ "Loop",
+ False,
+ hint="Add a segment between the first and last keyframes.",
+ )
+
+ @loop_checkbox.on_update
+ def _(_) -> None:
+ camera_path.loop = loop_checkbox.value
+ duration_number.value = camera_path.compute_duration()
+
+ tension_slider = server.gui.add_slider(
+ "Spline tension",
+ min=0.0,
+ max=1.0,
+ initial_value=0.0,
+ step=0.01,
+ hint="Tension parameter for adjusting smoothness of spline interpolation.",
+ )
+
+ @tension_slider.on_update
+ def _(_) -> None:
+ camera_path.tension = tension_slider.value
+ camera_path.update_spline()
+
+ move_checkbox = server.gui.add_checkbox(
+ "Move keyframes",
+ initial_value=False,
+ hint="Toggle move handles for keyframes in the scene.",
+ )
+
+ transform_controls: List[viser.SceneNodeHandle] = []
+
+ @move_checkbox.on_update
+ def _(event: viser.GuiEvent) -> None:
+ # Clear move handles when toggled off.
+ if move_checkbox.value is False:
+ for t in transform_controls:
+ t.remove()
+ transform_controls.clear()
+ return
+
+ def _make_transform_controls_callback(
+ keyframe: Tuple[Keyframe, viser.SceneNodeHandle],
+ controls: viser.TransformControlsHandle,
+ ) -> None:
+ @controls.on_update
+ def _(_) -> None:
+ keyframe[0].wxyz = controls.wxyz
+ keyframe[0].position = controls.position
+
+ keyframe[1].wxyz = controls.wxyz
+ keyframe[1].position = controls.position
+
+ camera_path.update_spline()
+
+ # Show move handles.
+ assert event.client is not None
+ for keyframe_index, keyframe in camera_path._keyframes.items():
+ controls = event.client.scene.add_transform_controls(
+ f"/keyframe_move/{keyframe_index}",
+ scale=0.4,
+ wxyz=keyframe[0].wxyz,
+ position=keyframe[0].position,
+ )
+ transform_controls.append(controls)
+ _make_transform_controls_callback(keyframe, controls)
+
+ show_keyframe_checkbox = server.gui.add_checkbox(
+ "Show keyframes",
+ initial_value=True,
+ hint="Show keyframes in the scene.",
+ )
+
+ @show_keyframe_checkbox.on_update
+ def _(_: viser.GuiEvent) -> None:
+ camera_path.set_keyframes_visible(show_keyframe_checkbox.value)
+
+ show_spline_checkbox = server.gui.add_checkbox(
+ "Show spline",
+ initial_value=True,
+ hint="Show camera path spline in the scene.",
+ )
+
+ @show_spline_checkbox.on_update
+ def _(_) -> None:
+ camera_path.show_spline = show_spline_checkbox.value
+ camera_path.update_spline()
+
+ transition_sec_number = server.gui.add_number(
+ "Transition (sec)",
+ min=0.001,
+ max=30.0,
+ step=0.001,
+ initial_value=2.0,
+ hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.",
+ )
+ framerate_number = server.gui.add_number(
+ "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0
+ )
+ duration_number = server.gui.add_number(
+ "Duration (sec)",
+ min=0.0,
+ max=1e8,
+ step=0.001,
+ initial_value=0.0,
+ disabled=True,
+ )
+
+ @transition_sec_number.on_update
+ def _(_) -> None:
+ camera_path.default_transition_sec = transition_sec_number.value
+ duration_number.value = camera_path.compute_duration()
+
+ # Text input for naming the trajectory (used for save/load file names).
+ trajectory_name_text = server.gui.add_text(
+ "Name",
+ initial_value="default",
+ hint="Name of the trajectory",
+ )
+
+ # add button for loading existing path
+ load_camera_path_button = server.gui.add_button(
+ "Load Trajectory",
+ icon=viser.Icon.FOLDER_OPEN,
+ hint="Load an existing camera path.",
+ )
+
+ save_camera_path_button = server.gui.add_button(
+ "Save Trajectory",
+ icon=viser.Icon.FILE_EXPORT,
+ hint="Save the current trajectory to a json file.",
+ )
+
+ play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY)
+ pause_button = server.gui.add_button(
+ "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False
+ )
+ preview_save_camera_path_button = server.gui.add_button(
+ "Preview Render",
+ icon=viser.Icon.EYE,
+ hint="Show a preview of the render in the viewport.",
+ )
+ preview_render_stop_button = server.gui.add_button(
+ "Exit Render Preview", color="red", visible=False
+ )
+ dump_video_button = server.gui.add_button(
+ "Dump Video",
+ color="green",
+ icon=viser.Icon.PLAYER_PLAY,
+ hint="Dump the current trajectory as a video.",
+ )
+
+ def get_max_frame_index() -> int:
+ return max(1, int(framerate_number.value * duration_number.value) - 1)
+
+ preview_camera_handle: Optional[viser.SceneNodeHandle] = None
+
+ def remove_preview_camera() -> None:
+ nonlocal preview_camera_handle
+ if preview_camera_handle is not None:
+ preview_camera_handle.remove()
+ preview_camera_handle = None
+
+ def compute_and_update_preview_camera_state() -> (
+ Optional[Union[Tuple[tf.SE3, float], Tuple[tf.SE3, float, float]]]
+ ):
+ """Update the render tab state with the current preview camera pose.
+ Returns current camera pose + FOV if available."""
+
+ if preview_frame_slider is None:
+ return
+ maybe_pose_and_fov_rad = camera_path.interpolate_pose_and_fov_rad(
+ preview_frame_slider.value / get_max_frame_index()
+ )
+ if maybe_pose_and_fov_rad is None:
+ remove_preview_camera()
+ return
+ time = None
+ if len(maybe_pose_and_fov_rad) == 3: # Time is enabled.
+ pose, fov_rad, time = maybe_pose_and_fov_rad
+ render_tab_state.preview_time = time
+ else:
+ pose, fov_rad = maybe_pose_and_fov_rad
+ render_tab_state.preview_fov = fov_rad
+ render_tab_state.preview_aspect = camera_path.get_aspect()
+
+ if time is not None:
+ return pose, fov_rad, time
+ else:
+ return pose, fov_rad
+
+ def add_preview_frame_slider() -> Optional[viser.GuiInputHandle[int]]:
+ """Helper for creating the current frame # slider. This is removed and
+ re-added anytime the `max` value changes."""
+
+ with folder:
+ preview_frame_slider = server.gui.add_slider(
+ "Preview frame",
+ min=0,
+ max=get_max_frame_index(),
+ step=1,
+ initial_value=0,
+ # Place right after the trajectory name text
+ order=trajectory_name_text.order + 0.01,
+ disabled=get_max_frame_index() == 1,
+ )
+ play_button.disabled = preview_frame_slider.disabled
+ preview_save_camera_path_button.disabled = preview_frame_slider.disabled
+ save_camera_path_button.disabled = preview_frame_slider.disabled
+ dump_video_button.disabled = preview_frame_slider.disabled
+
+ @preview_frame_slider.on_update
+ def _(_) -> None:
+ nonlocal preview_camera_handle
+ maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
+ if maybe_pose_and_fov_rad is None:
+ return
+ if len(maybe_pose_and_fov_rad) == 3: # Time is enabled.
+ pose, fov_rad, time = maybe_pose_and_fov_rad
+ else:
+ pose, fov_rad = maybe_pose_and_fov_rad
+
+ preview_camera_handle = server.scene.add_camera_frustum(
+ "/preview_camera",
+ fov=fov_rad,
+ aspect=render_res_vec2.value[0] / render_res_vec2.value[1],
+ scale=0.35,
+ wxyz=pose.rotation().wxyz,
+ position=pose.translation(),
+ color=(10, 200, 30),
+ )
+ if render_tab_state.preview_render:
+ for client in server.get_clients().values():
+ # aspect ratio is not assignable, pass args in get_render instead
+ client.camera.wxyz = pose.rotation().wxyz
+ client.camera.position = pose.translation()
+ client.camera.fov = fov_rad
+
+ return preview_frame_slider
+
+ # We back up the camera poses before and after we start previewing renders.
+ camera_pose_backup_from_id: Dict[int, tuple] = {}
+
+ @preview_save_camera_path_button.on_click
+ def _(_) -> None:
+ render_tab_state.preview_render = True
+ preview_save_camera_path_button.visible = False
+ preview_render_stop_button.visible = True
+ dump_video_button.disabled = True
+
+ maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
+ if maybe_pose_and_fov_rad is None:
+ remove_preview_camera()
+ return
+ if len(maybe_pose_and_fov_rad) == 3: # Time is enabled.
+ pose, fov, time = maybe_pose_and_fov_rad
+ else:
+ pose, fov = maybe_pose_and_fov_rad
+ del fov
+
+ # Hide all scene nodes when we're previewing the render.
+ server.scene.set_global_visibility(False)
+
+ # Back up and then set camera poses.
+ for client in server.get_clients().values():
+ camera_pose_backup_from_id[client.client_id] = (
+ client.camera.position,
+ client.camera.look_at,
+ client.camera.up_direction,
+ )
+ client.camera.wxyz = pose.rotation().wxyz
+ client.camera.position = pose.translation()
+
+ @preview_render_stop_button.on_click
+ def _(_) -> None:
+ render_tab_state.preview_render = False
+ preview_save_camera_path_button.visible = True
+ preview_render_stop_button.visible = False
+ dump_video_button.disabled = False
+
+ # Revert camera poses.
+ for client in server.get_clients().values():
+ if client.client_id not in camera_pose_backup_from_id:
+ continue
+ cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(
+ client.client_id
+ )
+ client.camera.position = cam_position
+ client.camera.look_at = cam_look_at
+ client.camera.up_direction = cam_up
+ client.flush()
+
+ # Un-hide scene nodes.
+ server.scene.set_global_visibility(True)
+
+ preview_frame_slider = add_preview_frame_slider()
+ handles = {
+ "fov_degrees_slider": fov_degrees_slider,
+ "render_res_vec2": render_res_vec2,
+ "add_keyframe_button": add_keyframe_button,
+ "clear_keyframes_button": clear_keyframes_button,
+ "reset_up_button": reset_up_button,
+ "loop_checkbox": loop_checkbox,
+ "tension_slider": tension_slider,
+ "move_checkbox": move_checkbox,
+ "show_keyframe_checkbox": show_keyframe_checkbox,
+ "show_spline_checkbox": show_spline_checkbox,
+ "transition_sec_number": transition_sec_number,
+ "framerate_number": framerate_number,
+ "duration_number": duration_number,
+ "trajectory_name_text": trajectory_name_text,
+ "preview_frame_slider": preview_frame_slider,
+ "load_camera_path_button": load_camera_path_button,
+ "save_camera_path_button": save_camera_path_button,
+ "play_button": play_button,
+ "pause_button": pause_button,
+ "preview_save_camera_path_button": preview_save_camera_path_button,
+ "preview_render_stop_button": preview_render_stop_button,
+ "dump_video_button": dump_video_button,
+ }
+ if time_enabled:
+ handles["render_time"] = render_time
+
+ # Update the # of frames.
+ @duration_number.on_update
+ @framerate_number.on_update
+ def _(_) -> None:
+ remove_preview_camera() # Will be re-added when slider is updated.
+
+ nonlocal preview_frame_slider
+ old = preview_frame_slider
+ assert old is not None
+
+ preview_frame_slider = add_preview_frame_slider()
+ if preview_frame_slider is not None:
+ old.remove()
+ else:
+ preview_frame_slider = old
+
+ handles["preview_frame_slider"] = preview_frame_slider
+ camera_path.framerate = framerate_number.value
+ camera_path.update_spline()
+
+ # Play the camera trajectory when the play button is pressed.
+ @play_button.on_click
+ def _(_) -> None:
+ play_button.visible = False
+ pause_button.visible = True
+ dump_video_button.disabled = True
+
+ def play() -> None:
+ while not play_button.visible:
+ max_frame = int(framerate_number.value * duration_number.value)
+ if max_frame > 0:
+ assert preview_frame_slider is not None
+ preview_frame_slider.value = (
+ preview_frame_slider.value + 1
+ ) % max_frame
+ time.sleep(1.0 / framerate_number.value)
+
+ play_thread = threading.Thread(target=play)
+ play_thread.start()
+ play_thread.join()
+ dump_video_button.disabled = False
+
+ # Pause the camera trajectory playback when the pause button is pressed.
+ @pause_button.on_click
+ def _(_) -> None:
+ play_button.visible = True
+ pause_button.visible = False
+
+ @load_camera_path_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ assert event.client is not None
+ camera_path_dir = output_dir / "camera_paths"
+ camera_path_dir.mkdir(parents=True, exist_ok=True)
+ preexisting_camera_paths = list(camera_path_dir.glob("*.json"))
+ preexisting_camera_filenames = [p.name for p in preexisting_camera_paths]
+
+ with event.client.gui.add_modal("Load Path") as modal:
+ if len(preexisting_camera_filenames) == 0:
+ event.client.gui.add_markdown("No existing paths found")
+ else:
+ event.client.gui.add_markdown("Select existing camera path:")
+ camera_path_dropdown = event.client.gui.add_dropdown(
+ label="Camera Path",
+ options=[str(p) for p in preexisting_camera_filenames],
+ initial_value=str(preexisting_camera_filenames[0]),
+ )
+ load_button = event.client.gui.add_button("Load")
+
+ @load_button.on_click
+ def _(_) -> None:
+ # load the json file
+ json_path = output_dir / "camera_paths" / camera_path_dropdown.value
+ with open(json_path, "r") as f:
+ json_data = json.load(f)
+
+ keyframes = json_data["keyframes"]
+ camera_path.reset()
+ for i in range(len(keyframes)):
+ frame = keyframes[i]
+ pose = tf.SE3.from_matrix(
+ np.array(frame["matrix"]).reshape(4, 4)
+ )
+ # apply the x rotation by 180 deg
+ pose = tf.SE3.from_rotation_and_translation(
+ pose.rotation() @ tf.SO3.from_x_radians(np.pi),
+ pose.translation(),
+ )
+ camera_path.add_camera(
+ Keyframe(
+ position=pose.translation() * scale_ratio,
+ wxyz=pose.rotation().wxyz,
+ # Degree/radian conversions introduce floating-point error, so the fov and
+ # default_fov values will not match exactly; compare with a tolerance.
+ override_fov_enabled=abs(
+ frame["fov"] - json_data.get("default_fov", 0.0)
+ )
+ > 1e-3,
+ override_fov_rad=frame["fov"] / 180.0 * np.pi,
+ override_time_enabled=frame.get(
+ "override_time_enabled", False
+ ),
+ override_time_val=frame.get("render_time", None),
+ aspect=frame["aspect"],
+ override_transition_enabled=frame.get(
+ "override_transition_enabled", None
+ ),
+ override_transition_sec=frame.get(
+ "override_transition_sec", None
+ ),
+ ),
+ )
+
+ transition_sec_number.value = json_data.get(
+ "default_transition_sec", 0.5
+ )
+
+ # update the render name
+ trajectory_name_text.value = json_path.stem
+ camera_path.update_spline()
+ modal.close()
+
+ # visualize the camera path
+ server.scene.set_global_visibility(True)
+
+ cancel_button = event.client.gui.add_button("Cancel")
+
+ @cancel_button.on_click
+ def _(_) -> None:
+ modal.close()
+
+ @save_camera_path_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ assert event.client is not None
+ num_frames = int(framerate_number.value * duration_number.value)
+ json_data = {}
+ # json data has the properties:
+ # keyframes: list of keyframes with
+ # matrix : flattened 4x4 matrix
+ # fov: float in degrees
+ # aspect: float
+ # render_height: int
+ # render_width: int
+ # fps: int
+ # seconds: float
+ # is_cycle: bool
+ # smoothness_value: float
+ # camera_path: list of frames with properties
+ # camera_to_world: flattened 4x4 matrix
+ # fov: float in degrees
+ # aspect: float
+ # first populate the keyframes:
+ keyframes = []
+ for keyframe, dummy in camera_path._keyframes.values():
+ pose = tf.SE3.from_rotation_and_translation(
+ tf.SO3(keyframe.wxyz) @ tf.SO3.from_x_radians(np.pi),
+ keyframe.position / scale_ratio,
+ )
+ keyframe_dict = {
+ "matrix": pose.as_matrix().flatten().tolist(),
+ "fov": (
+ np.rad2deg(keyframe.override_fov_rad)
+ if keyframe.override_fov_enabled
+ else fov_degrees_slider.value
+ ),
+ "aspect": keyframe.aspect,
+ "override_transition_enabled": keyframe.override_transition_enabled,
+ "override_transition_sec": keyframe.override_transition_sec,
+ }
+ keyframes.append(keyframe_dict)
+ json_data["default_fov"] = fov_degrees_slider.value
+ json_data["default_transition_sec"] = transition_sec_number.value
+ json_data["keyframes"] = keyframes
+ json_data["render_height"] = render_res_vec2.value[1]
+ json_data["render_width"] = render_res_vec2.value[0]
+ json_data["fps"] = framerate_number.value
+ json_data["seconds"] = duration_number.value
+ json_data["is_cycle"] = loop_checkbox.value
+ json_data["smoothness_value"] = tension_slider.value
+ # now populate the camera path:
+ camera_path_list = []
+ for i in range(num_frames):
+ maybe_pose_and_fov = camera_path.interpolate_pose_and_fov_rad(
+ i / num_frames
+ )
+ if maybe_pose_and_fov is None:
+ return
+ time = None
+ if len(maybe_pose_and_fov) == 3: # Time is enabled.
+ pose, fov, time = maybe_pose_and_fov
+ else:
+ pose, fov = maybe_pose_and_fov
+ # rotate the axis of the camera 180 about x axis
+ pose = tf.SE3.from_rotation_and_translation(
+ pose.rotation() @ tf.SO3.from_x_radians(np.pi),
+ pose.translation() / scale_ratio,
+ )
+ camera_path_list_dict = {
+ "camera_to_world": pose.as_matrix().flatten().tolist(),
+ "fov": np.rad2deg(fov),
+ "aspect": render_res_vec2.value[0] / render_res_vec2.value[1],
+ }
+ if time is not None:
+ camera_path_list_dict["render_time"] = time
+ camera_path_list.append(camera_path_list_dict)
+ json_data["camera_path"] = camera_path_list
+ # finally add crop data if crop is enabled
+ # if control_panel is not None:
+ # if control_panel.crop_viewport:
+ # obb = control_panel.crop_obb
+ # rpy = tf.SO3.from_matrix(obb.R.numpy()).as_rpy_radians()
+ # color = control_panel.background_color
+ # json_data["crop"] = {
+ # "crop_center": obb.T.tolist(),
+ # "crop_scale": obb.S.tolist(),
+ # "crop_rot": [rpy.roll, rpy.pitch, rpy.yaw],
+ # "crop_bg_color": {"r": color[0], "g": color[1], "b": color[2]},
+ # }
+
+ # now write the json file
+ try:
+ json_outfile = (
+ output_dir / "camera_paths" / f"{trajectory_name_text.value}.json"
+ )
+ json_outfile.parent.mkdir(parents=True, exist_ok=True)
+ except Exception:
+ Console(width=120).print(
+ "[bold yellow]Warning: Failed to write the camera path to the data directory. Saving to the output directory instead."
+ )
+ json_outfile = (
+ output_dir / "camera_paths" / f"{trajectory_name_text.value}.json"
+ )
+ json_outfile.parent.mkdir(parents=True, exist_ok=True)
+ with open(json_outfile.absolute(), "w") as outfile:
+ json.dump(json_data, outfile)
+ print(f"Camera path saved to {json_outfile.absolute()}")
+
+ @dump_video_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ # enter into preview render mode
+ render_tab_state.preview_render = True
+ maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
+ if maybe_pose_and_fov_rad is None:
+ remove_preview_camera()
+ return
+ if len(maybe_pose_and_fov_rad) == 3: # Time is enabled.
+ pose, fov, time = maybe_pose_and_fov_rad
+ else:
+ pose, fov = maybe_pose_and_fov_rad
+ del fov
+
+ # Hide all scene nodes when we're previewing the render.
+ server.scene.set_global_visibility(False)
+
+ # Back up and then set camera poses.
+ for client in server.get_clients().values():
+ camera_pose_backup_from_id[client.client_id] = (
+ client.camera.position,
+ client.camera.look_at,
+ client.camera.up_direction,
+ )
+ client.camera.wxyz = pose.rotation().wxyz
+ client.camera.position = pose.translation()
+
+ # disable all the trajectory control widgets
+ handles_to_disable = list(handles.values()) + list(extra_handles.values())
+ original_disabled = [handle.disabled for handle in handles_to_disable]
+ for handle in handles_to_disable:
+ handle.disabled = True
+
+ def dump() -> None:
+ os.makedirs(output_dir / "videos", exist_ok=True)
+ writer = imageio.get_writer(
+ f"{output_dir}/videos/traj_{trajectory_name_text.value}.mp4",
+ fps=framerate_number.value,
+ )
+ max_frame = int(framerate_number.value * duration_number.value)
+ assert max_frame > 0 and preview_frame_slider is not None
+ preview_frame_slider.value = 0
+ for _ in range(max_frame):
+ preview_frame_slider.value = (
+ preview_frame_slider.value + 1
+ ) % max_frame
+ # NOTE(review): get_render blocks until the client returns a frame — confirm this is the intended API here.
+ image = client.camera.get_render(
+ height=render_res_vec2.value[1],
+ width=render_res_vec2.value[0],
+ )
+ writer.append_data(image)
+ writer.close()
+ print(f"Video saved to videos/traj_{trajectory_name_text.value}.mp4")
+
+ dump_thread = threading.Thread(target=dump)
+ dump_thread.start()
+ dump_thread.join()
+
+ # restore the original disabled state
+ for handle, original_disabled in zip(handles_to_disable, original_disabled):
+ handle.disabled = original_disabled
+
+ # exit preview render mode
+ render_tab_state.preview_render = False
+
+ # Revert camera poses.
+ for client in server.get_clients().values():
+ if client.client_id not in camera_pose_backup_from_id:
+ continue
+ cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(
+ client.client_id
+ )
+ client.camera.position = cam_position
+ client.camera.look_at = cam_look_at
+ client.camera.up_direction = cam_up
+ client.flush()
+
+ # Un-hide scene nodes.
+ server.scene.set_global_visibility(True)
+
+ camera_path = CameraPath(server, duration_number)
+ camera_path.tension = tension_slider.value
+ camera_path.default_fov = fov_degrees_slider.value / 180.0 * np.pi
+ camera_path.default_transition_sec = transition_sec_number.value
+
+ return handles
diff --git a/nerfview/version.py b/nerfview/version.py
index 27fdca4..3dc1f76 100644
--- a/nerfview/version.py
+++ b/nerfview/version.py
@@ -1 +1 @@
-__version__ = "0.0.3"
+__version__ = "0.1.0"
diff --git a/nerfview/viewer.py b/nerfview/viewer.py
index a42518e..8cf3229 100644
--- a/nerfview/viewer.py
+++ b/nerfview/viewer.py
@@ -1,14 +1,16 @@
import dataclasses
import time
+from pathlib import Path
from threading import Lock
-from typing import Callable, Literal, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Tuple
import numpy as np
import viser
import viser.transforms as vt
-from jaxtyping import Float32, UInt8
+from jaxtyping import Float32
from ._renderer import Renderer, RenderTask
+from .render_panel import RenderTabState, populate_general_render_tab
@dataclasses.dataclass
@@ -30,15 +32,6 @@ def get_K(self, img_wh: Tuple[int, int]) -> Float32[np.ndarray, "3 3"]:
return K
-@dataclasses.dataclass
-class ViewerState(object):
- num_train_rays_per_sec: Optional[float] = None
- num_view_rays_per_sec: float = 100000.0
- status: Literal[
- "rendering", "preparing", "training", "paused", "completed"
- ] = "training"
-
-
VIEWER_LOCK = Lock()
@@ -70,13 +63,8 @@ class Viewer(object):
def __init__(
self,
server: viser.ViserServer,
- render_fn: Callable[
- [CameraState, Tuple[int, int]],
- Union[
- UInt8[np.ndarray, "H W 3"],
- Tuple[UInt8[np.ndarray, "H W 3"], Optional[Float32[np.ndarray, "H W"]]],
- ],
- ],
+ render_fn: Callable,
+ output_dir: Optional[Path] = None,
mode: Literal["rendering", "training"] = "rendering",
):
# Public states.
@@ -84,9 +72,8 @@ def __init__(
self.render_fn = render_fn
self.mode = mode
self.lock = VIEWER_LOCK
- self.state = ViewerState()
- if self.mode == "rendering":
- self.state.status = "rendering"
+ self.state = "preparing"
+ self.output_dir = output_dir if output_dir is not None else Path("./results")
# Private states.
self._renderers: dict[int, Renderer] = {}
@@ -94,53 +81,109 @@ def __init__(
self._last_update_step: int = 0
self._last_move_time: float = 0.0
+ # Initialize and populate GUIs.
+ server.scene.set_global_visibility(True)
server.on_client_disconnect(self._disconnect_client)
server.on_client_connect(self._connect_client)
-
- self._define_guis()
-
- def _define_guis(self):
- with self.server.gui.add_folder(
- "Stats", visible=self.mode == "training"
- ) as self._stats_folder:
- self._stats_text_fn = (
- lambda: f"""
- Step: {self._step}\\
- Last Update: {self._last_update_step}
- """
+ server.gui.set_panel_label("basic viewer")
+ server.gui.configure_theme(
+ control_layout="collapsible",
+ dark_mode=True,
+ brand_color=(255, 211, 105),
+ )
+ if self.mode == "training":
+ self._init_training_tab()
+ self._populate_training_tab()
+ self._init_rendering_tab()
+ self._populate_rendering_tab()
+ self.state = mode
+
+ def _init_training_tab(self):
+ self._training_tab_handles = {}
+ self._training_folder = self.server.gui.add_folder("Training")
+
+ def _populate_training_tab(self):
+ server = self.server
+ with self._training_folder:
+ step_number = server.gui.add_number(
+ "Step",
+ min=0,
+ max=1000000,
+ step=1,
+ disabled=True,
+ initial_value=0,
)
- self._stats_text = self.server.gui.add_markdown(self._stats_text_fn())
-
- with self.server.gui.add_folder(
- "Training", visible=self.mode == "training"
- ) as self._training_folder:
- self._pause_train_button = self.server.gui.add_button("Pause")
- self._pause_train_button.on_click(self._toggle_train_buttons)
- self._pause_train_button.on_click(self._toggle_train_s)
- self._resume_train_button = self.server.gui.add_button("Resume")
- self._resume_train_button.visible = False
- self._resume_train_button.on_click(self._toggle_train_buttons)
- self._resume_train_button.on_click(self._toggle_train_s)
-
- self._train_util_slider = self.server.gui.add_slider(
- "Train Util", min=0.0, max=1.0, step=0.05, initial_value=0.9
+ pause_train_button = server.gui.add_button(
+ "Pause",
+ icon=viser.Icon.PLAYER_PAUSE,
+ hint="Pause the training.",
)
- self._train_util_slider.on_update(self.rerender)
-
- with self.server.gui.add_folder("Rendering") as self._rendering_folder:
- self._max_img_res_slider = self.server.gui.add_slider(
- "Max Img Res", min=64, max=2048, step=1, initial_value=2048
+ resume_train_button = server.gui.add_button(
+ "Resume",
+ icon=viser.Icon.PLAYER_PLAY,
+ visible=False,
+ hint="Resume the training.",
)
- self._max_img_res_slider.on_update(self.rerender)
- def _toggle_train_buttons(self, _):
- self._pause_train_button.visible = not self._pause_train_button.visible
- self._resume_train_button.visible = not self._resume_train_button.visible
+ @pause_train_button.on_click
+ @resume_train_button.on_click
+ def _(_) -> None:
+ pause_train_button.visible = not pause_train_button.visible
+ resume_train_button.visible = not resume_train_button.visible
+ if self.state != "completed":
+ self.state = "paused" if self.state == "training" else "training"
- def _toggle_train_s(self, _):
- if self.state.status == "completed":
- return
- self.state.status = "paused" if self.state.status == "training" else "training"
+ train_util_slider = self.server.gui.add_slider(
+ "Train Util", min=0.0, max=1.0, step=0.05, initial_value=0.9
+ )
+ train_util_slider.on_update(self.rerender)
+
+ self._training_tab_handles = {
+ "step_number": step_number,
+ "pause_train_button": pause_train_button,
+ "resume_train_button": resume_train_button,
+ "train_util_slider": train_util_slider,
+ }
+
+ def _init_rendering_tab(self):
+ # Allow subclasses to override for custom rendering table
+ self.render_tab_state = RenderTabState()
+ self._rendering_tab_handles = {}
+ self._rendering_folder = self.server.gui.add_folder("Rendering")
+
+ def _populate_rendering_tab(self):
+ # Allow subclasses to override for custom rendering table
+ assert self.render_tab_state is not None, "Render tab state is not initialized"
+ assert self._rendering_folder is not None, "Rendering folder is not initialized"
+ with self._rendering_folder:
+ viewer_res_slider = self.server.gui.add_slider(
+ "Viewer Res",
+ min=64,
+ max=2048,
+ step=1,
+ initial_value=2048,
+ hint="Maximum resolution of the viewer rendered image.",
+ )
+
+ @viewer_res_slider.on_update
+ def _(_) -> None:
+ self.render_tab_state.viewer_res = int(viewer_res_slider.value)
+ self.rerender(_)
+
+ self._rendering_tab_handles["viewer_res_slider"] = viewer_res_slider
+
+ # training tab handles should also be disabled during dumping video.
+ extra_handles = self._rendering_tab_handles.copy()
+ if self.mode == "training":
+ extra_handles.update(self._training_tab_handles)
+ handles = populate_general_render_tab(
+ self.server,
+ output_dir=self.output_dir,
+ folder=self._rendering_folder,
+ render_tab_state=self.render_tab_state,
+ extra_handles=extra_handles,
+ )
+ self._rendering_tab_handles.update(handles)
def rerender(self, _):
clients = self.server.get_clients()
@@ -193,21 +236,23 @@ def update(self, step: int, num_train_rays_per_step: int):
if step < 5:
return
self._step = step
- with self.server.atomic(), self._stats_folder:
- self._stats_text.content = self._stats_text_fn()
+ self._training_tab_handles["step_number"].value = step
if len(self._renderers) == 0:
return
# Stop training while user moves camera to make viewing smoother.
while time.time() - self._last_move_time < 0.1:
time.sleep(0.05)
- if self.state.status == "training" and self._train_util_slider.value != 1:
+ if (
+ self.state == "training"
+ and self._training_tab_handles["train_util_slider"].value != 1
+ ):
assert (
- self.state.num_train_rays_per_sec is not None
+ self.render_tab_state.num_train_rays_per_sec is not None
), "User must keep track of `num_train_rays_per_sec` to use `update`."
- train_s = self.state.num_train_rays_per_sec
- view_s = self.state.num_view_rays_per_sec
- train_util = self._train_util_slider.value
- view_n = self._max_img_res_slider.value**2
+ train_s = self.render_tab_state.num_train_rays_per_sec
+ view_s = self.render_tab_state.num_view_rays_per_sec
+ train_util = self._training_tab_handles["train_util_slider"].value
+ view_n = self.render_tab_state.viewer_res**2
train_n = num_train_rays_per_step
train_time = train_n / train_s
view_time = view_n / view_s
@@ -223,16 +268,15 @@ def update(self, step: int, num_train_rays_per_step: int):
self._renderers[client_id].submit(
RenderTask("update", camera_state)
)
- with self.server.atomic(), self._stats_folder:
- self._stats_text.content = self._stats_text_fn()
+
+ def _after_render(self):
+ # This function will be called each time render_fn is called.
+ # It can be used to update the viewer panel.
+ pass
def complete(self):
- self.state.status = "completed"
- self._pause_train_button.disabled = True
- self._resume_train_button.disabled = True
- self._train_util_slider.disabled = True
- with self.server.atomic(), self._stats_folder:
- self._stats_text.content = f"""
- Step: {self._step}\\
- Training Completed!
- """
+ print("Training complete, disable training tab.")
+ self.state = "completed"
+ self._training_tab_handles["pause_train_button"].disabled = True
+ self._training_tab_handles["resume_train_button"].disabled = True
+ self._training_tab_handles["train_util_slider"].disabled = True