diff --git a/pyproject.toml b/pyproject.toml index fc2bfd0..add9078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "jsonargparse>=4.37.0,<5.0.0", # TODO imagecodecs version > 2023.9.18 produce much larger JPEG XL files: https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/issues/3 # but for numpy > 2, we require a modern imagecodecs - # "imagecodecs[all]>=2024.12.30", + #"imagecodecs[all]>=2024.12.30", "PyYAML>=6.0.2,<7.0.0", "opencv-python>=4.11.0.86,<5.0.0.0", "pillow>=11.1.0,<12.0.0", @@ -158,6 +158,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "tests/*" = ["S101"] +"src/ffsplat/models/transformations.py"=["C901"] [tool.ruff.format] preview = true diff --git a/src/ffsplat/cli/eval.py b/src/ffsplat/cli/eval.py index a0e0d19..7f002d6 100644 --- a/src/ffsplat/cli/eval.py +++ b/src/ffsplat/cli/eval.py @@ -37,7 +37,7 @@ def rasterize_splats( means=gaussians.means.data, quats=gaussians.quaternions.data, scales=gaussians.scales.data, - opacities=gaussians.opacities.data, + opacities=gaussians.opacities.data.squeeze(), colors=colors, viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] Ks=Ks, # [C, 3, 3] diff --git a/src/ffsplat/cli/live.py b/src/ffsplat/cli/live.py index ea3dbba..b6fc1d4 100644 --- a/src/ffsplat/cli/live.py +++ b/src/ffsplat/cli/live.py @@ -36,9 +36,10 @@ from ..render.viewer import CameraState, Viewer available_output_format: list[str] = [ + "SOG-PlayCanvas", + "SOG-web", "3DGS_INRIA_ply", "3DGS_INRIA_nosh_ply", - "SOG-web", "SOG-web-png", "SOG-web-nosh", "SOG-web-sh-split", @@ -397,17 +398,24 @@ def full_evaluation(self, _): self.enable_convert_ui() self.enable_load_buttons() - def _build_transform_folder(self, transform_folder, description, transformation, transform_type): + def _build_transform_folder(self, transform_folder, operation, transformation, transform_type): # clear transform folder for rebuild for child in tuple(transform_folder._children.values()): child.remove() - dynamic_params_conf = 
get_dynamic_params(transformation) + + if isinstance(operation["input_fields"], list): + input_field = operation["input_fields"] + description = f"input fields: {input_field}" + else: + input_field = operation["input_fields"]["from_fields_with_prefix"] + description = f"input fields from prefix: {input_field}" + dynamic_params_conf = get_dynamic_params(transformation, input_field) initial_values = transformation[transform_type] with transform_folder: self.viewer.server.gui.add_markdown(description) rebuild_fn = partial( - self._build_transform_folder, transform_folder, description, transformation, transform_type + self._build_transform_folder, transform_folder, operation, transformation, transform_type ) self._build_options_for_transformation(dynamic_params_conf, initial_values, rebuild_fn) @@ -503,20 +511,14 @@ def _build_convert_options(self): for transformation in operation["transforms"]: # get list with customizable options - dynamic_transform_conf = get_dynamic_params(transformation) + dynamic_transform_conf = get_dynamic_params(transformation, operation["input_fields"]) if len(dynamic_transform_conf) == 0: continue transform_type = next(iter(transformation.keys())) transform_folder = self.viewer.server.gui.add_folder(transform_type) self.viewer.convert_gui_handles.append(transform_folder) - if isinstance(operation["input_fields"], list): - description = f"input fields: {operation["input_fields"]}" - else: - description = ( - f"input fields from prefix: {operation["input_fields"]["from_fields_with_prefix"]}" - ) - self._build_transform_folder(transform_folder, description, transformation, transform_type) + self._build_transform_folder(transform_folder, operation, transformation, transform_type) @torch.no_grad() def render_fn( @@ -541,7 +543,7 @@ def render_fn( means_t, # [N, 3] quats_t, # [N, 4] scales_t, # [N, 3] - opacities_t, # [N] + opacities_t.squeeze(), # [N] colors, # [N, S, 3] viewmat[None], # [1, 4, 4] K[None], # [1, 3, 3] diff --git 
a/src/ffsplat/cli/view.py b/src/ffsplat/cli/view.py index c38ffad..cf09e91 100644 --- a/src/ffsplat/cli/view.py +++ b/src/ffsplat/cli/view.py @@ -35,7 +35,7 @@ def render_fn( means_t, # [N, 3] quats_t, # [N, 4] scales_t, # [N, 3] - opacities_t, # [N] + opacities_t.squeeze(), # [N] colors, # [N, S, 3] viewmat[None], # [1, 4, 4] K[None], # [1, 3, 3] diff --git a/src/ffsplat/coding/scene_decoder.py b/src/ffsplat/coding/scene_decoder.py index 99d294a..98574c9 100644 --- a/src/ffsplat/coding/scene_decoder.py +++ b/src/ffsplat/coding/scene_decoder.py @@ -1,3 +1,4 @@ +import json from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path @@ -47,6 +48,9 @@ def with_input_path(self, input_path: Path) -> "DecodingParams": self.ops[0]["transforms"][0]["read_file"]["file_path"] = str(input_path) return self + def get_ops_hashable(self) -> str: + return json.dumps(self.ops, sort_keys=False) + @lru_cache def process_operation( @@ -56,7 +60,7 @@ def process_operation( """Process the operation and return the new fields and decoding updates.""" if verbose: print(f"Decoding {op}...") - return op.apply(verbose=verbose)[0] + return op.apply(verbose=verbose, decoding_params_hashable="")[0] @dataclass @@ -77,11 +81,12 @@ def _process_fields(self, verbose: bool = False) -> None: def _create_scene(self) -> None: match self.decoding_params.scene.get("primitives"): case "3DGS-INRIA": + opacities_field = self.fields["opacities"] self.scene = Gaussians( means=self.fields["means"], quaternions=self.fields["quaternions"], scales=self.fields["scales"], - opacities=self.fields["opacities"], + opacities=Field(opacities_field.data.unsqueeze(-1), opacities_field.op), sh=self.fields["sh"], ) case _: diff --git a/src/ffsplat/coding/scene_encoder.py b/src/ffsplat/coding/scene_encoder.py index 4d3e38d..449967c 100644 --- a/src/ffsplat/coding/scene_encoder.py +++ b/src/ffsplat/coding/scene_encoder.py @@ -1,4 +1,5 @@ import copy +import json from collections import 
defaultdict from collections.abc import Iterable from dataclasses import asdict, dataclass, field, is_dataclass @@ -92,6 +93,18 @@ def reverse_ops(self) -> None: continue self.ops[-1]["transforms"].reverse() + def get_ops_hashable(self) -> str: + return json.dumps(self.ops, sort_keys=False) + + def to_yaml(self) -> str: + """Convert the decoding parameters to a YAML string.""" + return yaml.dump( + asdict(self), + Dumper=SerializableDumper, + default_flow_style=False, + sort_keys=False, + ) + @dataclass class EncodingParams: @@ -128,14 +141,11 @@ def to_yaml_file(self, yaml_path: Path) -> None: @lru_cache -def process_operation( - op: Operation, - verbose: bool, -) -> tuple[dict[str, Field], list[dict[str, Any]]]: +def process_operation(op: Operation, verbose: bool, decoding_ops: str) -> tuple[dict[str, Field], list[dict[str, Any]]]: """Process the operation and return the new fields and decoding updates.""" if verbose: print(f"Encoding {op}...") - return op.apply(verbose=verbose) + return op.apply(verbose=verbose, decoding_params_hashable=decoding_ops) @dataclass @@ -154,10 +164,9 @@ def _encode_fields(self, verbose: bool) -> None: input_fields_params = op_params["input_fields"] for transform_param in op_params["transforms"]: op = Operation.from_json(input_fields_params, transform_param, self.fields, self.output_path) - if op.transform_type != "write_file": - new_fields, decoding_updates = process_operation(op, verbose=verbose) - else: - new_fields, decoding_updates = op.apply(verbose=verbose) + new_fields, decoding_updates = process_operation( + op, verbose=verbose, decoding_ops=self.decoding_params.to_yaml() + ) # if the coding_updates are not a copy the cache will be wrong for decoding_update in copy.deepcopy(decoding_updates): # if the last decoding update has the same input fields we can combine the transforms into one list @@ -201,6 +210,8 @@ def encode_gaussians(gaussians: Gaussians, output_path: Path, output_format: str encoding_params = 
EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-web-nosh.yaml")) case "SOG-web-sh-split": encoding_params = EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-web-sh-split.yaml")) + case "SOG-PlayCanvas": + encoding_params = EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-PlayCanvas.yaml")) case _: raise ValueError(f"Unsupported output format: {output_format}") diff --git a/src/ffsplat/conf/format/SOG-PlayCanvas.yaml b/src/ffsplat/conf/format/SOG-PlayCanvas.yaml new file mode 100644 index 0000000..77ef7e7 --- /dev/null +++ b/src/ffsplat/conf/format/SOG-PlayCanvas.yaml @@ -0,0 +1,337 @@ +profile: SOG-PlayCanvas +profile_version: 1.0 + +scene: + primitives: 3DGS-INRIA + params: + - means + - scales + - opacities + - quaternions + - sh + +ops: + - input_fields: [sh] + transforms: + - split: + to_field_list: [sh0, shN] + split_size_or_sections: [1, 15] + dim: 1 + squeeze: false + + - input_fields: [means] + transforms: + - remapping: + method: signed-log + to_field: means + + - input_fields: [scales] + transforms: + - remapping: + method: log + + - input_fields: [opacities] + transforms: + - remapping: + method: inverse-sigmoid + + - input_fields: [quaternions] + transforms: + - reparametize: + method: unit_sphere + dim: -1 + + - input_fields: [means, sh0, shN, opacities, scales, quaternions] + transforms: + - sort: + method: plas + prune_by: opacities + scaling_fn: none + shuffle: true + improvement_break: 1e-4 + to_field: sorted_indices + weights: + means: 1.0 + sh0: 1.0 + shN: 0.0 + opacities: 0.0 + scales: 1.0 + quaternions: 1.0 + + #means + - input_fields: [means, sorted_indices] + transforms: + - reindex: + src_field: means + index_field: sorted_indices + + - input_fields: [means] + transforms: + - simple_quantize: + dtype: uint16 + min: 0 + max: 65535 + dim: 2 + round_to_int: true + - split_bytes: + to_fields_with_prefix: means_bytes_ + num_bytes: 2 + + - input_fields: [means_bytes_0] + transforms: + - to_field: + 
to_field_name: means_l + + - input_fields: [means_bytes_1] + transforms: + - to_field: + to_field_name: means_u + + #scales + - input_fields: [scales, sorted_indices] + transforms: + - reindex: + src_field: scales + index_field: sorted_indices + + - input_fields: [scales] + transforms: + - simple_quantize: + min: 0 + max: 255 + dim: 2 + dtype: uint8 + round_to_int: true + + - input_fields: [opacities, sorted_indices] + transforms: + - reindex: + src_field: opacities + index_field: sorted_indices + + - input_fields: [quaternions, sorted_indices] + transforms: + - reindex: + src_field: quaternions + index_field: sorted_indices + + - input_fields: [quaternions] # for tracking decoding_ops + transforms: + - to_field: + to_field_name: quats + + - input_fields: [quats] + transforms: + - reparametize: + method: pack_quaternions + to_fields_with_prefix: quats_packed_ + dim: -1 + + - input_fields: [quats_packed_indices] + transforms: + - simple_quantize: + min: 252 + max: 255 + dim: 2 + dtype: uint8 + round_to_int: true + + - input_fields: [quats_packed_values] + transforms: + - simple_quantize: + min: 0 + max: 255 + dim: 2 + dtype: uint8 + round_to_int: true + + - input_fields: [quats_packed_values, quats_packed_indices] + transforms: + - combine: + method: concat + dim: 2 + to_field: quats + + - input_fields: [sh0, sorted_indices] + transforms: + - reindex: + src_field: sh0 + index_field: sorted_indices + + - input_fields: [shN, sorted_indices] + transforms: + - reindex: + src_field: shN + index_field: sorted_indices + + # shN[sorted_indices] -> shN + + # new_blah[labels] -> labels + + - input_fields: [sh0] + transforms: + - permute: + dims: [0, 1, 3, 2] + - flatten: + start_dim: 2 + + - input_fields: [sh0, opacities] #opacity-sh0 rgba + transforms: + - combine: + method: concat + dim: 2 + to_field: sh0 + + # TODO: quantize sh0 and opacities with different ranges + - input_fields: [sh0] + transforms: + - simple_quantize: + min: 0 + max: 255 + dim: 2 + dtype: uint8 + 
round_to_int: true + + - input_fields: [shN] + transforms: + - flatten: + start_dim: 2 + - flatten: + start_dim: 0 + end_dim: 1 + - cluster: + method: kmeans + num_clusters: 65536 + distance: manhattan + to_fields_with_prefix: shN_ + + - input_fields: [shN_centroids] + transforms: + - simple_quantize: + min: 0 + max: 255 + dim: 2 + dtype: uint8 + round_to_int: true + + - input_fields: [shN_centroids, shN_labels] + transforms: + - sort: + method: lexicographic + labels: shN_labels + weights: + shN_labels: 0.0 + shN_centroids: 1.0 + to_field: shN_centroids_indices + + - input_fields: [shN_centroids, shN_centroids_indices] + transforms: + - reindex: + src_field: shN_centroids + index_field: shN_centroids_indices + + - input_fields: [shN_centroids] + transforms: + - reshape: + shape: [-1, 960, 3] # int(num_clusters*num_spherical_harmonics/3) = + + - input_fields: [shN_labels] + transforms: + - simple_quantize: + min: 0 + max: 65535 + dim: 2 + dtype: uint16 + round_to_int: False + - split_bytes: + to_fields_with_prefix: shN_labels_ + num_bytes: 2 + + - input_fields: [shN_labels_0, shN_labels_1] + transforms: + - combine: + method: stack-zeros + dim: 2 + to_field: shN_labels + + - input_fields: [means_l] + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [means_u] + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [scales] + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [quats] + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [sh0] #rgb + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + 
quality: 100 + method: 6 + exact: true + + - input_fields: [shN_centroids] + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [shN_labels] + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [means, scales, quats, sh0, shN] + transforms: + - write_file: + type: canvas-metadata diff --git a/src/ffsplat/conf/format/SOG-web-nosh.yaml b/src/ffsplat/conf/format/SOG-web-nosh.yaml index 8a0e8e4..008bbb1 100644 --- a/src/ffsplat/conf/format/SOG-web-nosh.yaml +++ b/src/ffsplat/conf/format/SOG-web-nosh.yaml @@ -31,12 +31,13 @@ ops: split_size_or_sections: [1] dim: 1 squeeze: false - to_field_list: [f_dc] + to_field_list: [f_dc, _] - input_fields: [means_bytes_0, means_bytes_1, f_dc, opacities, scales, quaternions] transforms: - - plas: + - sort: + method: plas prune_by: opacities scaling_fn: standardize # activated: true diff --git a/src/ffsplat/conf/format/SOG-web-png.yaml b/src/ffsplat/conf/format/SOG-web-png.yaml index 3a3d0ec..5af9306 100644 --- a/src/ffsplat/conf/format/SOG-web-png.yaml +++ b/src/ffsplat/conf/format/SOG-web-png.yaml @@ -44,7 +44,8 @@ ops: quaternions, ] transforms: - - plas: + - sort: + method: plas prune_by: opacities scaling_fn: standardize # activated: true diff --git a/src/ffsplat/conf/format/SOG-web-sh-split.yaml b/src/ffsplat/conf/format/SOG-web-sh-split.yaml index 5c97125..656f93b 100644 --- a/src/ffsplat/conf/format/SOG-web-sh-split.yaml +++ b/src/ffsplat/conf/format/SOG-web-sh-split.yaml @@ -44,14 +44,15 @@ ops: quaternions, ] transforms: - - plas: + - sort: + method: plas prune_by: opacities scaling_fn: standardize # activated: true shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_field: sorted_indices + to_field: sorted_indices weights: means_bytes_0: 0.1 means_bytes_1: 1.0 diff --git 
a/src/ffsplat/conf/format/SOG-web.yaml b/src/ffsplat/conf/format/SOG-web.yaml index 8315aa6..20ba699 100644 --- a/src/ffsplat/conf/format/SOG-web.yaml +++ b/src/ffsplat/conf/format/SOG-web.yaml @@ -44,7 +44,8 @@ ops: quaternions, ] transforms: - - plas: + - sort: + method: plas prune_by: opacities scaling_fn: standardize # activated: true diff --git a/src/ffsplat/models/operations.py b/src/ffsplat/models/operations.py index 7f3071e..73c801b 100644 --- a/src/ffsplat/models/operations.py +++ b/src/ffsplat/models/operations.py @@ -69,5 +69,5 @@ def to_json(self) -> dict[str, Any]: "params": self.params, } - def apply(self, verbose: bool) -> tuple[dict[str, "Field"], list[dict[str, Any]]]: - return apply_transform(self, verbose=verbose) + def apply(self, verbose: bool, decoding_params_hashable: str) -> tuple[dict[str, "Field"], list[dict[str, Any]]]: + return apply_transform(self, verbose=verbose, decoding_params_hashable=decoding_params_hashable) diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index a17615c..de46609 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -1,3 +1,4 @@ +import json import math from abc import ABC, abstractmethod from collections.abc import Callable @@ -8,6 +9,7 @@ import cv2 import numpy as np import torch +import yaml from PIL import Image from pillow_heif import register_avif_opener # type: ignore[import-untyped] from plas import sort_with_plas # type: ignore[import-untyped] @@ -73,10 +75,26 @@ def write_image(output_file_path: Path, field_data: Tensor, file_type: str, codi chroma=coding_params.get("chroma", 444), matrix_coefficients=coding_params.get("matrix_coefficients", 0), ) + case "webp": + Image.fromarray(field_data.cpu().numpy()).save( + output_file_path, + format="WEBP", + quality=coding_params.get("quality", 100), + lossless=coding_params.get("lossless", True), + method=coding_params.get("method", 6), + exact=coding_params.get("exact", 
True), + ) case _: raise ValueError(f"Unsupported file type: {file_type}") +def write_json(output_file_path: Path, data: dict[str, Any]) -> None: + """Write a dictionary to a JSON file.""" + + with open(output_file_path, "w") as f: + json.dump(data, f, indent=2) + + @dataclass class PLASConfig: """Configuration for PLAS sorting.""" @@ -112,7 +130,7 @@ class Transformation(ABC): @staticmethod @abstractmethod def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: """Apply the transformation to the input fields. returns new/updated fields and decoding updates. Transformations that are only available for decoding do return empty decoding updates""" @@ -128,7 +146,7 @@ class Cluster(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: # Implement the clustering logic here input_fields = parentOp.input_fields @@ -184,7 +202,7 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: dynamic_params_config.append({ "label": "distance", "type": "dropdown", - "values": ["euclidean", "cosine", "manhatten"], + "values": ["euclidean", "cosine", "manhattan"], }) return dynamic_params_config @@ -194,7 +212,7 @@ class Split(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -212,9 +230,12 @@ def apply( }: chunks = field_data.split(split_size_or_sections, dim) for target_field_name, chunk in zip(to_field_list, chunks, strict=False): + if 
target_field_name == "_": + continue if squeeze: chunk = chunk.squeeze(dim) new_fields[target_field_name] = Field(chunk, parentOp) + to_field_list = [name for name in to_field_list if name != "_"] decoding_update.append({ "input_fields": to_field_list, @@ -261,7 +282,7 @@ class Remapping(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -373,11 +394,95 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: return dynamic_params_config +class Reparametrize(Transformation): + @staticmethod + @override + def apply( + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any + ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + input_fields = parentOp.input_fields + + field_name = next(iter(input_fields.keys())) + field_data = input_fields[field_name].data + + new_fields: dict[str, Field] = {} + decoding_update: list[dict[str, Any]] = [] + + match params: + case {"method": "unit_sphere", "dim": dim}: + field_data = field_data / torch.linalg.norm(field_data, dim=dim, keepdim=True) + + sign_mask = field_data.select(dim, 3) < 0 + field_data[sign_mask] *= -1 + new_fields[field_name] = Field(field_data, parentOp) + + case {"method": "pack_quaternions", "to_fields_with_prefix": to_fields_with_prefix, "dim": dim}: + # Ensure the first component is positive + sign = field_data.select(dim, 0).sign() + sign[sign == 0] = 1 + q_signed = field_data * sign.unsqueeze(-1) + + # Drop the first component (is always the largest) + values = q_signed.narrow(dim, 1, 3) + max_idx = torch.zeros_like(sign, dtype=values.dtype).unsqueeze(-1) + + new_fields[f"{to_fields_with_prefix}indices"] = Field(max_idx, parentOp) + new_fields[f"{to_fields_with_prefix}values"] = Field(values, parentOp) + 
decoding_update.append({ + "input_fields": [f"{to_fields_with_prefix}indices", f"{to_fields_with_prefix}values"], + "transforms": [ + { + "reparametize": { + "method": "unpack_quaternions", + "from_fields_with_prefix": to_fields_with_prefix, + "dim": -1, + "to_field_name": field_name, + } + } + ], + }) + + case { + "method": "unpack_quaternions", + "from_fields_with_prefix": from_fields_with_prefix, + "dim": dim, + "to_field_name": to_field_name, + }: + values_field = input_fields[f"{from_fields_with_prefix}values"].data + + if values_field is None: + raise ValueError("values field data is None before unit sphere recovery") + if values_field.shape[dim] != 3: + raise ValueError(f"Field data shape mismatch for unit sphere recovery: {field_data.shape}") + + partial_norm_sq = (values_field**2).sum(dim=dim, keepdim=True) + # Recover the missing component (always non-negative) + w = torch.sqrt(torch.clamp(1.0 - partial_norm_sq, min=0.0)) + + num_components = values_field.shape[-1] + 1 # Original data had one more component + reconstructed = torch.zeros( + *values_field.shape[:-1], num_components, dtype=values_field.dtype, device=values_field.device + ) + + # wxyz convention + reconstructed[..., 0] = w.squeeze(-1) + reconstructed[..., 1:] = values_field + + # Optional re-normalization to mitigate numerical drift + field_data = reconstructed / torch.linalg.norm(reconstructed, dim=dim, keepdim=True) + new_fields[to_field_name] = Field(field_data, parentOp) + + case _: + raise ValueError(f"Unknown ToField parameters: {params}") + + return new_fields, decoding_update + + class ToField(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -403,7 +508,7 @@ class Flatten(Transformation): @staticmethod @override def apply( - params: 
dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -416,6 +521,11 @@ def apply( decoding_update: list[dict[str, Any]] = [] match params: case {"start_dim": start_dim, "end_dim": end_dim}: + target_shape = field_data.shape + decoding_update.append({ + "input_fields": [field_name], + "transforms": [{"reshape": {"shape": target_shape}}], + }) field_data = field_data.flatten(start_dim=start_dim, end_dim=end_dim) case {"start_dim": start_dim}: target_shape = field_data.shape[start_dim:] @@ -435,7 +545,7 @@ class Reshape(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -446,6 +556,10 @@ def apply( decoding_update: list[dict[str, Any]] = [] match params: case {"start_dim": start_dim, "shape": shape}: + decoding_update.append({ + "input_fields": [field_name], + "transforms": [{"reshape": {"shape": field_data.shape}}], + }) target_shape = list(field_data.shape[:start_dim]) + list(shape) field_data = field_data.reshape(*target_shape) case {"shape": shape}: @@ -465,7 +579,7 @@ class Permute(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -492,7 +606,7 @@ class ToDType(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, 
Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -533,7 +647,7 @@ class SplitBytes(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -585,31 +699,128 @@ class Reindex(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] - match params: case {"src_field": src_field_name, "index_field": index_field_name}: index_field_obj = input_fields[index_field_name] - if len(index_field_obj.data.shape) != 2: - raise ValueError("Expecting grid for re-index operation") - decoding_update.append({ - "input_fields": [src_field_name], - "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], - }) original_data = input_fields[src_field_name].data new_fields[src_field_name] = Field(original_data[index_field_obj.data], parentOp) + + if index_field_obj.data.ndim > 1: + decoding_update.append({ + "input_fields": [src_field_name], + "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], + }) case _: raise ValueError(f"Unknown Reindex parameters: {params}") return new_fields, decoding_update -class PLAS(Transformation): +class Sort(Transformation): + @staticmethod + @override + def apply( + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any + ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + input_fields = parentOp.input_fields + + field_name = next(iter(input_fields.keys())) + field_data = input_fields[field_name].data + + new_fields: dict[str, Field] = {} + 
decoding_update: list[dict[str, Any]] = [] + + sorted_indices = None + if field_data is None: + raise ValueError("Field data is None before sorting") + to_field_name = params["to_field"] + match params: + case {"method": "lexicographic"}: + # weight is effectively a boolean in this case + plas_cfg = {k: v for k, v in params.items() if k not in ["method", "labels"]} + plas_cfg["scaling_fn"] = "none" + plas_cfg["shuffle"] = False + plas_cfg.setdefault("improvement_break", 1e-4) + plas_cfg.setdefault("prune_by", None) # only reshapes to grid if not None + preprocess_dict = PLAS.plas_preprocess( + plas_cfg=PLASConfig(**plas_cfg), fields=parentOp.input_fields, verbose=verbose + ) + params_tensor = preprocess_dict["params_tensor"] + sorted_indices = np.lexsort( + params_tensor.permute(dims=(1, 0)).cpu().numpy(), + ) + sorted_indices = torch.tensor(sorted_indices, device=field_data.device) + + if "labels" in params: + inverse_sorted_indices = torch.argsort(sorted_indices) + labels_name = params["labels"] + original_labels = input_fields[labels_name].data + original_labels_grid = PLAS.as_grid_img(original_labels) + + updated_labels = inverse_sorted_indices[original_labels_grid] + new_fields[labels_name] = Field(updated_labels, parentOp) + decoding_update.append({ + "input_fields": [labels_name], + "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], + }) + if "prune_by" in params: + sorted_indices = PLAS.as_grid_img(sorted_indices) + + case {"method": "plas"}: + plas_cfg = {k: v for k, v in params.items() if k not in ["method"]} + preprocess_dict = PLAS.plas_preprocess( + plas_cfg=PLASConfig(**plas_cfg), + fields=parentOp.input_fields, + verbose=verbose, + ) + sorted_indices = PLAS.sort(**preprocess_dict) + + case _: + raise ValueError(f"Unknown Sort parameters: {params}") + + new_fields[to_field_name] = Field(sorted_indices, parentOp) + + return new_fields, decoding_update + + @staticmethod + def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: 
+ """Get the dynamic parameters for a given transformation type. This might modify the values in params.""" + dynamic_params_config: list[dict[str, Any]] = [] + dynamic_params_config.append({ + "label": "method", + "type": "dropdown", + "values": ["lexicographic", "plas"], + }) + if params.get("method") == "plas": + dynamic_params_config.extend(PLAS.get_dynamic_params(params)) + + weight_config: list[dict[str, Any]] = [] + + for field_name in list(params["weights"].keys()): + weight_config.append({ + "label": field_name, + "type": "number", + "min": 0.0, + "max": 1.0, + "step": 0.05, + "dtype": float, + }) + dynamic_params_config.append({ + "label": "weights", + "type": "heading", + "params": weight_config, + }) + + return dynamic_params_config + + +class PLAS: @staticmethod def as_grid_img(tensor: Tensor) -> Tensor: num_primitives = tensor.shape[0] @@ -637,13 +848,16 @@ def primitive_filter_pruning_to_square_shape(data: Tensor, verbose: bool) -> Ten f"Removing {num_to_remove}/{num_primitives} primitives to fit the grid. 
({100 * num_to_remove / num_primitives:.4f}%)" ) - _, keep_indices = torch.topk(data, k=grid_sidelen * grid_sidelen) + _, keep_indices = torch.topk(data.squeeze(), k=grid_sidelen * grid_sidelen) sorted_keep_indices = torch.sort(keep_indices)[0] return sorted_keep_indices @staticmethod - def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: bool) -> Tensor: - primitive_filter = PLAS.primitive_filter_pruning_to_square_shape(fields[plas_cfg.prune_by].data, verbose) + def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: bool) -> dict[str, Any]: + if plas_cfg.prune_by is not None: + primitive_filter = PLAS.primitive_filter_pruning_to_square_shape(fields[plas_cfg.prune_by].data, verbose) + else: + primitive_filter = None # TODO untested match plas_cfg.scaling_fn: @@ -675,11 +889,28 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo params_tensor = torch.cat(params_to_sort, dim=1) if plas_cfg.shuffle: - # TODO shuffling should be an option of sort_with_plas torch.manual_seed(42) shuffled_indices = torch.randperm(params_tensor.shape[0], device=params_tensor.device) params_tensor = params_tensor[shuffled_indices] + else: + shuffled_indices = None + + return { + "plas_cfg": plas_cfg, + "params_tensor": params_tensor, + "shuffled_indices": shuffled_indices, + "primitive_filter": primitive_filter, + "verbose": verbose, + } + @staticmethod + def sort( + plas_cfg: PLASConfig, + params_tensor: Tensor, + shuffled_indices: Tensor | None, + primitive_filter: Tensor | None, + verbose: bool, + ) -> Tensor: grid_to_sort = PLAS.as_grid_img(params_tensor).permute(2, 0, 1) _, sorted_indices_ret = sort_with_plas( grid_to_sort, improvement_break=float(plas_cfg.improvement_break), verbose=verbose @@ -687,7 +918,7 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo sorted_indices: Tensor = sorted_indices_ret.squeeze(0).to(params_tensor.device) - if plas_cfg.shuffle: + if 
plas_cfg.shuffle and shuffled_indices is not None: flat_indices = sorted_indices.flatten() unshuffled_flat_indices = shuffled_indices[flat_indices] sorted_indices = unshuffled_flat_indices.reshape(sorted_indices.shape) @@ -697,36 +928,17 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo return sorted_indices - @staticmethod - @override - def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False - ) -> tuple[dict[str, Field], list[dict[str, Any]]]: - new_fields: dict[str, Field] = {} - - plas_cfg_dict = params - if isinstance(plas_cfg_dict, dict): - sorted_indices = PLAS.plas_preprocess( - plas_cfg=PLASConfig(**plas_cfg_dict), - fields=parentOp.input_fields, - verbose=verbose, - ) - new_fields[plas_cfg_dict["to_field"]] = Field(sorted_indices, parentOp) - else: - raise TypeError(f"Unknown PLAS parameters: {params}") - - return new_fields, [] - @staticmethod def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: """Get the dynamic parameters for a given transformation type. 
This might modify the values in params.""" - if params.get("weights") is None: - raise ValueError(f"PLAS parameters is missing weights: {params}") - field_names = list(params["weights"].keys()) scaling_functions = ["standardize", "minmax", "none"] dynamic_params_config: list[dict[str, Any]] = [] + + params.setdefault("scaling_fn", "standardize") + params.setdefault("shuffle", True) + params.setdefault("improvement_break", 1e-4) dynamic_params_config.append({ "label": "scaling_fn", "type": "dropdown", @@ -747,22 +959,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "inverse_mapping": lambda x: np.log10(x), }) - weight_config: list[dict[str, Any]] = [] - for field_name in field_names: - weight_config.append({ - "label": field_name, - "type": "number", - "min": 0.0, - "max": 1.0, - "step": 0.05, - "dtype": float, - }) - dynamic_params_config.append({ - "label": "weights", - "type": "heading", - "params": weight_config, - }) - return dynamic_params_config @@ -770,10 +966,11 @@ class Combine(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: new_fields: dict[str, Field] = {} field_data: Tensor = torch.empty(0) + decoding_update: list[dict[str, Any]] = [] match params: case {"method": "bytes", "to_field": to_field_name}: @@ -801,24 +998,51 @@ def apply( tensors: list[Tensor] = [ parentOp.input_fields[source_field_name].data for source_field_name in parentOp.input_fields ] + decoding_update_dict = { + "input_fields": [to_field_name], + "transforms": [ + { + "split": { + "dim": dim, + "to_field_list": list(parentOp.input_fields), + } + } + ], + } if method == "stack": field_data = torch.stack(tensors, dim=dim) + decoding_update_dict["transforms"][0]["split"]["squeeze"] = (True,) elif method == "concat": field_data = torch.cat(tensors, 
dim=dim) + decoding_update_dict["transforms"][0]["split"]["squeeze"] = False + elif method == "stack-zeros": + zeros = torch.zeros(tensors[0].shape, dtype=tensors[0].dtype, device=tensors[0].device) + tensors.append(zeros) + field_data = torch.stack(tensors, dim=dim) + decoding_update_dict["transforms"][0]["split"]["squeeze"] = True + decoding_update_dict["transforms"][0]["split"]["to_field_list"] = [ + *list(parentOp.input_fields), + "_", + ] + else: raise ValueError(f"Unsupported combine method: {method}") + decoding_update_dict["transforms"][0]["split"]["split_size_or_sections"] = [ + t.shape[dim] if dim < len(t.shape) else 1 for t in tensors + ] + decoding_update.append(decoding_update_dict) new_fields[to_field_name] = Field(field_data, parentOp) case _: raise ValueError(f"Unknown Combine parameters: {params}") - return new_fields, [] + return new_fields, decoding_update class Lookup(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -837,8 +1061,12 @@ class WriteFile(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], + parentOp: "Operation", + verbose: bool = False, + **kwargs: Any, ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + decoding_ops: list[dict[str, Any]] = kwargs.get("decoding_ops", []) decoding_update: list[dict[str, Any]] = [] match params: case {"type": "ply", "file_path": file_path, "base_path": base_path, "field_prefix": field_prefix}: @@ -856,6 +1084,8 @@ def apply( case {"type": "image", "image_codec": codec, "coding_params": coding_params, "base_path": base_path}: for field_name, field_obj in parentOp.input_fields.items(): field_data = field_obj.data + if field_data.shape[-1] == 1: + field_data 
= field_data.squeeze(-1) file_path = f"{field_name}.{codec}" output_file_path = Path(base_path) / Path(file_path) write_image( @@ -878,6 +1108,49 @@ def apply( } ], }) + case {"type": "canvas-metadata", "base_path": base_path}: + file_path = "meta.json" + output_file_path = Path(base_path) / Path(file_path) + field_names = list(parentOp.input_fields.keys()) + meta: dict[str, Any] = { + "packer": "ffsplat", + "version": 1, + } + # Readfile no input_fields + for field_name in field_names: + field = parentOp.input_fields[field_name] + shape = field.data.shape + shape_meta = [int(np.array(shape[:-1]).prod()), shape[-1]] + if field_name == "sh0": + shape_meta = [shape_meta[0], 1, shape_meta[1]] + # to keep the order semi-consistent with original sogs + meta[field_name] = {"shape": shape_meta, "dtype": None, "mins": [], "maxs": [], "files": []} + + for op in decoding_ops: + transforms_str = [next(iter(t_str_braced.keys())) for t_str_braced in op["transforms"]] + transform_types = [transformation_map[transform] for transform in transforms_str] + input_field = op["input_fields"][0] if len(op["input_fields"]) > 0 else "" + for idx, (t_str, t) in enumerate(zip(transforms_str, transform_types, strict=False)): + if t is ReadFile: + field_name_read = op["transforms"][idx][t_str].get("field_name") + if field_name_read.startswith(field_name): + codec = op["transforms"][idx][t_str].get("image_codec") + meta[field_name]["files"].append(f"{field_name_read}.{codec}") + elif ( + (t is SimpleQuantize) + and input_field.startswith(field_name) + and not input_field.endswith("labels") + ): + meta[field_name]["dtype"] = op["transforms"][idx][t_str]["dtype"] + meta[field_name]["mins"] = op["transforms"][idx][t_str]["min_values"] + meta[field_name]["maxs"] = op["transforms"][idx][t_str]["max_values"] + elif t is Reparametrize: + if op["transforms"][idx][t_str].get("method") == "unpack_quaternions": + to_field_name = op["transforms"][idx][t_str].get("to_field_name") + if to_field_name == 
field_name: + meta[to_field_name]["encoding"] = "quaternion_packed" + + write_json(output_file_path, meta) case _: raise ValueError(f"Unknown WriteFile parameters: {params}") @@ -894,17 +1167,15 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: dynamic_params_config.append({ "label": "image_codec", "type": "dropdown", - "values": ["avif", "png"], + "values": ["avif", "png", "webp"], "rebuild": True, }) # coding_params for image file dynamic_coding_params: list[dict[str, Any]] = [] - + coding_params: dict[str, Any] = params.get("coding_params", {}) match params["image_codec"]: case "avif": # check whether we need to update the coding params default: - coding_params: dict[str, Any] = params.get("coding_params", {}) - if not all(key in coding_params for key in ["quality", "chroma", "matrix_coefficients"]): coding_params.clear() coding_params["quality"] = -1 @@ -946,7 +1217,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: }) case "png": - coding_params = params.get("coding_params", {}) if "compression_level" not in coding_params: coding_params.clear() coding_params["compression_level"] = 3 @@ -964,6 +1234,46 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "type": "heading", "params": dynamic_coding_params, }) + case "webp": + if not all(key in coding_params for key in ["quality", "method", "exact", "lossless"]): + coding_params.clear() + coding_params["quality"] = 100 + coding_params["method"] = 6 + coding_params["exact"] = True + coding_params["lossless"] = True + dynamic_coding_params.append({ + "label": "exact", + "type": "bool", + "values": ["True", "False"], + }) + dynamic_coding_params.append({ + "label": "lossless", + "type": "bool", + "values": ["True", "False"], + }) + # filesize-speed tradeoff when lossless + dynamic_coding_params.append({ + "label": "quality", + "type": "number", + "min": 0, + "max": 100, + "step": 1, + "dtype": int, + }) + # also filesize-speed tradeoff, all 
compatible with lossless + dynamic_coding_params.append({ + "label": "method", + "type": "number", + "dtype": int, + "min": 0, + "max": 6, + "step": 1, + }) + dynamic_params_config.append({ + "label": "coding_params", + "type": "heading", + "params": dynamic_coding_params, + }) case _: raise ValueError(f"unknown image codec: {params["image_codec"]}") @@ -975,7 +1285,7 @@ class ReadFile(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] @@ -993,6 +1303,8 @@ def apply( # TODO: only do this once? register_avif_opener() img_field_data = torch.tensor(np.array(Image.open(file_path))) + case "webp": + img_field_data = torch.tensor(np.array(Image.open(file_path))) new_fields[field_name] = Field.from_file(img_field_data, file_path, field_name) case _: @@ -1005,7 +1317,7 @@ class SimpleQuantize(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -1014,6 +1326,13 @@ def apply( new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] + torch_dtype_to_str = { + torch.float32: "float32", + torch.uint8: "uint8", + torch.uint16: "uint16", + torch.uint32: "uint32", + torch.int32: "int32", + } match params: case {"min": min_val, "max": max_val, "dim": dim, "dtype": dtype_str, "round_to_int": round_to_int}: min_val_f = float(min_val) @@ -1030,23 +1349,16 @@ def apply( field_data = normalized - torch_dtype_to_str = { - torch.float32: "float32", - torch.uint8: "uint8", - torch.uint16: "uint16", - torch.uint32: "uint32", - torch.int32: 
"int32", - } - decoding_update.append({ "input_fields": [field_name], "transforms": [ { "simple_quantize": { - "min_values": min_vals.tolist(), - "max_values": max_vals.tolist(), + "min_values": min_vals.tolist() if min_vals.ndim > 0 else [min_vals.item()], + "max_values": max_vals.tolist() if max_vals.ndim > 0 else [max_vals.item()], "dim": dim, "dtype": torch_dtype_to_str[field_data.dtype], + "round_to_int": False, # in backmapping usually don't want rounding, right? } } ], @@ -1061,11 +1373,13 @@ def apply( "max_values": max_values, "dim": dim, "dtype": dtype_str, + "round_to_int": round_to_int, }: if field_data is None: raise ValueError("Field data is None before channelwise remapping") - field_data = convert_to_dtype(field_data, dtype_str) + min_vals = torch.amin(field_data, dim=[d for d in range(field_data.ndim) if d != dim]) + max_vals = torch.amax(field_data, dim=[d for d in range(field_data.ndim) if d != dim]) min_tensor = torch.tensor(min_values, device=field_data.device, dtype=torch.float32) max_tensor = torch.tensor(max_values, device=field_data.device, dtype=torch.float32) @@ -1082,8 +1396,28 @@ def apply( field_data = minmax(field_data) field_data = field_data * field_range + min_tensor + + decoding_update.append({ + "input_fields": [field_name], + "transforms": [ + { + "simple_quantize": { + "min_values": min_vals.tolist(), + "max_values": max_vals.tolist(), + "dim": dim, + "dtype": torch_dtype_to_str[field_data.dtype], + "round_to_int": False, # in backmapping usually don't want rounding, right? 
+ } + } + ], + }) + + if round_to_int: + field_data = torch.round(field_data) + field_data = convert_to_dtype(field_data, dtype_str) + case _: - raise ValueError(f"Unknown remapping parameters: {params}") + raise ValueError(f"Unknown simple_quantize parameters: {params}") new_fields[field_name] = Field(field_data, parentOp) return new_fields, decoding_update @@ -1123,12 +1457,13 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "flatten": Flatten, "reshape": Reshape, "remapping": Remapping, + "reparametize": Reparametrize, "to_field": ToField, "permute": Permute, "to_dtype": ToDType, "split_bytes": SplitBytes, "reindex": Reindex, - "plas": PLAS, + "sort": Sort, "lookup": Lookup, "combine": Combine, "write_file": WriteFile, @@ -1137,17 +1472,32 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: } -def apply_transform(parentOp: "Operation", verbose: bool) -> tuple[dict[str, "Field"], list[dict[str, Any]]]: +def apply_transform( + parentOp: "Operation", verbose: bool, decoding_params_hashable: str +) -> tuple[dict[str, "Field"], list[dict[str, Any]]]: transformation = transformation_map.get(parentOp.transform_type) if transformation is None: raise ValueError(f"Unknown transformation: {parentOp.transform_type}") - return transformation.apply(parentOp.params[parentOp.transform_type], parentOp, verbose) + elif transformation is WriteFile: + decoding_ops: list[dict[str, Any]] = yaml.load(decoding_params_hashable, Loader=yaml.SafeLoader)["ops"] + return transformation.apply( + parentOp.params[parentOp.transform_type], parentOp, verbose=verbose, decoding_ops=decoding_ops + ) + else: + return transformation.apply(parentOp.params[parentOp.transform_type], parentOp, verbose=verbose) -def get_dynamic_params(params: dict[str, dict[str, Any]]) -> list[dict[str, Any]]: +def get_dynamic_params(params: dict[str, dict[str, Any]], input_field: list | str) -> list[dict[str, Any]]: """Get the dynamic parameters for a given transformation 
type. This might modify the values in params.""" transform_type = next(iter(params.keys())) transformation = transformation_map.get(transform_type) if transformation is None: raise ValueError(f"Unknown transformation: {transform_type}") + elif ( + transformation is Sort + and params[transform_type]["method"] == "plas" + and "weights" not in params[transform_type] + ): + params[transform_type]["weights"] = {k: 1.0 for k in input_field} + return transformation.get_dynamic_params(params[transform_type]) diff --git a/src/ffsplat/render/viewer.py b/src/ffsplat/render/viewer.py index 067aec4..b3f1516 100644 --- a/src/ffsplat/render/viewer.py +++ b/src/ffsplat/render/viewer.py @@ -14,6 +14,15 @@ from ._renderer import Renderer, RenderTask +available_output_format: list[str] = [ + "SOG-PlayCanvas", + "SOG-web", + "3DGS_INRIA_ply", + "3DGS_INRIA_nosh_ply", + "SOG-web-nosh", + "SOG-web-sh-split", +] + @dataclasses.dataclass class CameraState: