Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"jsonargparse>=4.37.0,<5.0.0",
# TODO imagecodecs version > 2023.9.18 produce much larger JPEG XL files: https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/issues/3
# but for numpy > 2, we require a modern imagecodecs
# "imagecodecs[all]>=2024.12.30",
#"imagecodecs[all]>=2024.12.30",
"PyYAML>=6.0.2,<7.0.0",
"opencv-python>=4.11.0.86,<5.0.0.0",
"pillow>=11.1.0,<12.0.0",
Expand Down Expand Up @@ -158,6 +158,7 @@ ignore = [

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S101"]
"src/ffsplat/models/transformations.py"=["C901"]

[tool.ruff.format]
preview = true
Expand Down
2 changes: 1 addition & 1 deletion src/ffsplat/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def rasterize_splats(
means=gaussians.means.data,
quats=gaussians.quaternions.data,
scales=gaussians.scales.data,
opacities=gaussians.opacities.data,
opacities=gaussians.opacities.data.squeeze(),
colors=colors,
viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
Ks=Ks, # [C, 3, 3]
Expand Down
28 changes: 15 additions & 13 deletions src/ffsplat/cli/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
from ..render.viewer import CameraState, Viewer

available_output_format: list[str] = [
"SOG-PlayCanvas",
"SOG-web",
"3DGS_INRIA_ply",
"3DGS_INRIA_nosh_ply",
"SOG-web",
"SOG-web-png",
"SOG-web-nosh",
"SOG-web-sh-split",
Expand Down Expand Up @@ -397,17 +398,24 @@ def full_evaluation(self, _):
self.enable_convert_ui()
self.enable_load_buttons()

def _build_transform_folder(self, transform_folder, description, transformation, transform_type):
def _build_transform_folder(self, transform_folder, operation, transformation, transform_type):
# clear transform folder for rebuild
for child in tuple(transform_folder._children.values()):
child.remove()
dynamic_params_conf = get_dynamic_params(transformation)

if isinstance(operation["input_fields"], list):
input_field = operation["input_fields"]
description = f"input fields: {input_field}"
else:
input_field = operation["input_fields"]["from_fields_with_prefix"]
description = f"input fields from prefix: {input_field}"
dynamic_params_conf = get_dynamic_params(transformation, input_field)
initial_values = transformation[transform_type]

with transform_folder:
self.viewer.server.gui.add_markdown(description)
rebuild_fn = partial(
self._build_transform_folder, transform_folder, description, transformation, transform_type
self._build_transform_folder, transform_folder, operation, transformation, transform_type
)
self._build_options_for_transformation(dynamic_params_conf, initial_values, rebuild_fn)

Expand Down Expand Up @@ -503,20 +511,14 @@ def _build_convert_options(self):
for transformation in operation["transforms"]:
# get list with customizable options

dynamic_transform_conf = get_dynamic_params(transformation)
dynamic_transform_conf = get_dynamic_params(transformation, operation["input_fields"])
if len(dynamic_transform_conf) == 0:
continue
transform_type = next(iter(transformation.keys()))
transform_folder = self.viewer.server.gui.add_folder(transform_type)
self.viewer.convert_gui_handles.append(transform_folder)
if isinstance(operation["input_fields"], list):
description = f"input fields: {operation["input_fields"]}"
else:
description = (
f"input fields from prefix: {operation["input_fields"]["from_fields_with_prefix"]}"
)

self._build_transform_folder(transform_folder, description, transformation, transform_type)
self._build_transform_folder(transform_folder, operation, transformation, transform_type)

@torch.no_grad()
def render_fn(
Expand All @@ -541,7 +543,7 @@ def render_fn(
means_t, # [N, 3]
quats_t, # [N, 4]
scales_t, # [N, 3]
opacities_t, # [N]
opacities_t.squeeze(), # [N]
colors, # [N, S, 3]
viewmat[None], # [1, 4, 4]
K[None], # [1, 3, 3]
Expand Down
2 changes: 1 addition & 1 deletion src/ffsplat/cli/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def render_fn(
means_t, # [N, 3]
quats_t, # [N, 4]
scales_t, # [N, 3]
opacities_t, # [N]
opacities_t.squeeze(), # [N]
colors, # [N, S, 3]
viewmat[None], # [1, 4, 4]
K[None], # [1, 3, 3]
Expand Down
9 changes: 7 additions & 2 deletions src/ffsplat/coding/scene_decoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
Expand Down Expand Up @@ -47,6 +48,9 @@ def with_input_path(self, input_path: Path) -> "DecodingParams":
self.ops[0]["transforms"][0]["read_file"]["file_path"] = str(input_path)
return self

def get_ops_hashable(self) -> str:
return json.dumps(self.ops, sort_keys=False)


@lru_cache
def process_operation(
Expand All @@ -56,7 +60,7 @@ def process_operation(
"""Process the operation and return the new fields and decoding updates."""
if verbose:
print(f"Decoding {op}...")
return op.apply(verbose=verbose)[0]
return op.apply(verbose=verbose, decoding_params_hashable="")[0]


@dataclass
Expand All @@ -77,11 +81,12 @@ def _process_fields(self, verbose: bool = False) -> None:
def _create_scene(self) -> None:
match self.decoding_params.scene.get("primitives"):
case "3DGS-INRIA":
opacities_field = self.fields["opacities"]
self.scene = Gaussians(
means=self.fields["means"],
quaternions=self.fields["quaternions"],
scales=self.fields["scales"],
opacities=self.fields["opacities"],
opacities=Field(opacities_field.data.unsqueeze(-1), opacities_field.op),
sh=self.fields["sh"],
)
case _:
Expand Down
29 changes: 20 additions & 9 deletions src/ffsplat/coding/scene_encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import json
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field, is_dataclass
Expand Down Expand Up @@ -92,6 +93,18 @@ def reverse_ops(self) -> None:
continue
self.ops[-1]["transforms"].reverse()

def get_ops_hashable(self) -> str:
return json.dumps(self.ops, sort_keys=False)

def to_yaml(self) -> str:
"""Convert the decoding parameters to a YAML string."""
return yaml.dump(
asdict(self),
Dumper=SerializableDumper,
default_flow_style=False,
sort_keys=False,
)


@dataclass
class EncodingParams:
Expand Down Expand Up @@ -128,14 +141,11 @@ def to_yaml_file(self, yaml_path: Path) -> None:


@lru_cache
def process_operation(
op: Operation,
verbose: bool,
) -> tuple[dict[str, Field], list[dict[str, Any]]]:
def process_operation(op: Operation, verbose: bool, decoding_ops: str) -> tuple[dict[str, Field], list[dict[str, Any]]]:
"""Process the operation and return the new fields and decoding updates."""
if verbose:
print(f"Encoding {op}...")
return op.apply(verbose=verbose)
return op.apply(verbose=verbose, decoding_params_hashable=decoding_ops)


@dataclass
Expand All @@ -154,10 +164,9 @@ def _encode_fields(self, verbose: bool) -> None:
input_fields_params = op_params["input_fields"]
for transform_param in op_params["transforms"]:
op = Operation.from_json(input_fields_params, transform_param, self.fields, self.output_path)
if op.transform_type != "write_file":
new_fields, decoding_updates = process_operation(op, verbose=verbose)
else:
new_fields, decoding_updates = op.apply(verbose=verbose)
new_fields, decoding_updates = process_operation(
op, verbose=verbose, decoding_ops=self.decoding_params.to_yaml()
)
# if the coding_updates are not a copy the cache will be wrong
for decoding_update in copy.deepcopy(decoding_updates):
# if the last decoding update has the same input fields we can combine the transforms into one list
Expand Down Expand Up @@ -201,6 +210,8 @@ def encode_gaussians(gaussians: Gaussians, output_path: Path, output_format: str
encoding_params = EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-web-nosh.yaml"))
case "SOG-web-sh-split":
encoding_params = EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-web-sh-split.yaml"))
case "SOG-canvas":
encoding_params = EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-canvas.yaml"))
case _:
raise ValueError(f"Unsupported output format: {output_format}")

Expand Down
Loading
Loading