Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,19 @@ import nerfview


def render_fn(
camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
) -> np.ndarray:
# Parse camera state for camera-to-world matrix (c2w) and intrinsic (K) as
# float64 numpy arrays.
if render_tab_state.preview_render:
width = render_tab_state.render_width
height = render_tab_state.render_height
else:
width = render_tab_state.viewer_width
height = render_tab_state.viewer_height

c2w = camera_state.c2w
K = camera_state.get_K(img_wh)
K = camera_state.get_K([width, height])
# Do your things and get an image as a uint8 numpy array.
img = your_rendering_logic(...)
return img
Expand Down Expand Up @@ -139,7 +146,7 @@ which we include here to be self-contained.
# Only need to run once the first time.
bash examples/assets/download_gsplat_ckpt.sh
CUDA_VISIBLE_DEVICES=0 python examples/03_gsplat_rendering.py \
--ckpt results/garden/ckpts/ckpt_6999_crop.pt
--ckpt examples/assets/ckpt_6999_crop.pt
```

</details>
Expand Down
15 changes: 11 additions & 4 deletions examples/00_dummy_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,26 @@ def main(port: int = 8080, rendering_latency: float = 0.0):
"""

def render_fn(
camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
) -> UInt8[np.ndarray, "H W 3"]:
# Get camera parameters.
W, H = img_wh
if render_tab_state.preview_render:
width = render_tab_state.render_width
height = render_tab_state.render_height
else:
width = render_tab_state.viewer_width
height = render_tab_state.viewer_height
c2w = camera_state.c2w
K = camera_state.get_K(img_wh)
K = camera_state.get_K([width, height])

# Render a dummy image as a function of camera direction.
camera_dirs = np.einsum(
"ij,hwj->hwi",
np.linalg.inv(K),
np.pad(
np.stack(np.meshgrid(np.arange(W), np.arange(H), indexing="xy"), -1)
np.stack(
np.meshgrid(np.arange(width), np.arange(height), indexing="xy"), -1
)
+ 0.5,
((0, 0), (0, 0), (0, 1)),
constant_values=1.0,
Expand Down
21 changes: 14 additions & 7 deletions examples/01_dummy_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,26 @@ def main(port: int = 8080, max_steps: int = 50, rendering_latency: float = 0.0):
step: int = 0

def render_fn(
camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
) -> UInt8[np.ndarray, "H W 3"]:
# Get camera parameters.
W, H = img_wh
if render_tab_state.preview_render:
width = render_tab_state.render_width
height = render_tab_state.render_height
else:
width = render_tab_state.viewer_width
height = render_tab_state.viewer_height
c2w = camera_state.c2w
K = camera_state.get_K(img_wh)
K = camera_state.get_K([width, height])

# Render a dummy image as a function of camera direction.
camera_dirs = np.einsum(
"ij,hwj->hwi",
np.linalg.inv(K),
np.pad(
np.stack(np.meshgrid(np.arange(W), np.arange(H), indexing="xy"), -1)
np.stack(
np.meshgrid(np.arange(width), np.arange(height), indexing="xy"), -1
)
+ 0.5,
((0, 0), (0, 0), (0, 1)),
constant_values=1.0,
Expand Down Expand Up @@ -83,11 +90,11 @@ def training_step():
# Optionally make the training utility lower such that we update the scene
# more frequently in this example. You dont need to do this in your own
# code.
viewer._train_util_slider.value = 0.5
viewer._training_tab_handles["train_util_slider"].value = 0.5

for step in tqdm(range(max_steps)):
# Allow user to pause the training process.
while viewer.state.status == "paused":
while viewer.state == "paused":
time.sleep(0.01)
# Do the training step and compute the number of training rays per second.
tic = time.time()
Expand All @@ -96,7 +103,7 @@ def training_step():
num_train_steps_per_sec = 1.0 / (time.time() - tic)
num_train_rays_per_sec = num_train_rays_per_step * num_train_steps_per_sec
# Update the viewer state.
viewer.state.num_train_rays_per_sec = num_train_rays_per_sec
viewer.render_tab_state.num_train_rays_per_sec = num_train_rays_per_sec
# Update the scene.
viewer.update(step, num_train_rays_per_step)
viewer.complete()
Expand Down
16 changes: 12 additions & 4 deletions examples/02_mesh_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,21 @@ def main(port: int = 8080):
)

def render_fn(
camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
) -> UInt8[np.ndarray, "H W 3"]:
# nvdiffrast requires the image size to be multiples of 8.
img_wh = (img_wh[0] // 8 * 8, img_wh[1] // 8 * 8)

# Get camera parameters.
if render_tab_state.preview_render:
width = render_tab_state.render_width
height = render_tab_state.render_height
else:
width = render_tab_state.viewer_width
height = render_tab_state.viewer_height

# nvdiffrast requires the image size to be multiples of 8.
width = width // 8 * 8
height = height // 8 * 8
c2w = camera_state.c2w
img_wh = [width, height]
K = camera_state.get_K(img_wh)

# Compute the normal map.
Expand Down
13 changes: 10 additions & 3 deletions examples/03_gsplat_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,17 @@

# register and open viewer
@torch.no_grad()
def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]):
width, height = img_wh
def viewer_render_fn(
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
):
if render_tab_state.preview_render:
width = render_tab_state.render_width
height = render_tab_state.render_height
else:
width = render_tab_state.viewer_width
height = render_tab_state.viewer_height
c2w = camera_state.c2w
K = camera_state.get_K(img_wh)
K = camera_state.get_K([width, height])
c2w = torch.from_numpy(c2w).float().to(device)
K = torch.from_numpy(K).float().to(device)
viewmat = c2w.inverse()
Expand Down
21 changes: 16 additions & 5 deletions examples/04_gsplat_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,20 +916,31 @@ def render_traj(self, step: int):

@torch.no_grad()
def _viewer_render_fn(
self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
self,
camera_state: nerfview.CameraState,
render_tab_state: nerfview.RenderTabState,
):
"""Callable function for the viewer."""
W, H = img_wh
if render_tab_state.preview_render:
width, height = (
render_tab_state.render_width,
render_tab_state.render_height,
)
else:
width, height = (
render_tab_state.viewer_width,
render_tab_state.viewer_height,
)
c2w = camera_state.c2w
K = camera_state.get_K(img_wh)
K = camera_state.get_K([width, height])
c2w = torch.from_numpy(c2w).float().to(self.device)
K = torch.from_numpy(K).float().to(self.device)

render_colors, _, _ = self.rasterize_splats(
camtoworlds=c2w[None],
Ks=K[None],
width=W,
height=H,
width=width,
height=height,
sh_degree=self.cfg.sh_degree, # active all SH degrees
radius_clip=3.0, # skip GSs that have small image radius (in pixels)
backgrounds=torch.ones(1, 3, device=self.device),
Expand Down
1 change: 1 addition & 0 deletions nerfview/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .render_panel import RenderTabState
from .version import __version__
from .viewer import VIEWER_LOCK, CameraState, Viewer, with_viewer_lock
41 changes: 32 additions & 9 deletions nerfview/_renderer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Modified from nerfview/_renderer.py
"""

import dataclasses
import os
import sys
Expand All @@ -9,7 +13,7 @@
import viser

if TYPE_CHECKING:
from .viewer import CameraState, Viewer
from examples.viewer import CameraState, Viewer

RenderState = Literal["low_move", "low_static", "high"]
RenderAction = Literal["rerender", "move", "static", "update"]
Expand Down Expand Up @@ -53,14 +57,15 @@ def __init__(
self.lock = lock

self.running = True
self.is_prepared_fn = lambda: self.viewer.state.status != "preparing"
self.is_prepared_fn = lambda: self.viewer.state != "preparing"

self._render_event = threading.Event()
self._state: RenderState = "low_static"
self._task: Optional[RenderTask] = None

self._target_fps = 30
self._may_interrupt_render = False
self._old_version = False

self._define_transitions()

Expand All @@ -84,16 +89,17 @@ def _may_interrupt_trace(self, frame, event, arg):
return self._may_interrupt_trace

def _get_img_wh(self, aspect: float) -> Tuple[int, int]:
max_img_res = self.viewer._max_img_res_slider.value
if self._state == "high":
# we always trade off speed for quality
max_img_res = self.viewer.render_tab_state.viewer_res
if self._state in ["high"]:
# if True:
H = max_img_res
W = int(H * aspect)
if W > max_img_res:
W = max_img_res
H = int(W / aspect)
elif self._state in ["low_move", "low_static"]:
num_view_rays_per_sec = self.viewer.state.num_view_rays_per_sec
num_view_rays_per_sec = self.viewer.render_tab_state.num_view_rays_per_sec
target_fps = self._target_fps
num_viewer_rays = num_view_rays_per_sec / target_fps
H = (num_viewer_rays / aspect) ** 0.5
Expand Down Expand Up @@ -141,13 +147,31 @@ def run(self):
with self.lock, set_trace_context(self._may_interrupt_trace):
tic = time.time()
W, H = img_wh = self._get_img_wh(task.camera_state.aspect)
rendered = self.viewer.render_fn(task.camera_state, img_wh)
self.viewer.render_tab_state.viewer_width = W
self.viewer.render_tab_state.viewer_height = H

if not self._old_version:
try:
rendered = self.viewer.render_fn(
task.camera_state,
self.viewer.render_tab_state,
)
except TypeError:
self._old_version = True
print(
"[WARNING] Your API will be deprecated in the future, please update your render_fn."
)
rendered = self.viewer.render_fn(task.camera_state, img_wh)
else:
rendered = self.viewer.render_fn(task.camera_state, img_wh)

self.viewer._after_render()
if isinstance(rendered, tuple):
img, depth = rendered
else:
img, depth = rendered, None
self.viewer.state.num_view_rays_per_sec = (W * H) / (
max(time.time() - tic, 1e-6)
self.viewer.render_tab_state.num_view_rays_per_sec = (W * H) / (
time.time() - tic
)
except InterruptRenderException:
continue
Expand All @@ -160,4 +184,3 @@ def run(self):
jpeg_quality=70 if task.action in ["static", "update"] else 40,
depth=depth,
)
self.client.flush()
Loading