From bdd476e593039e3f7390996ea9877a132c4eddd5 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Mon, 2 Jun 2025 17:31:20 +0200 Subject: [PATCH 01/24] sog-canvas intermediate --- pyproject.toml | 2 +- src/ffsplat/coding/scene_encoder.py | 2 + src/ffsplat/conf/format/SOG-canvas.yaml | 298 ++++++++++++++++++++++++ src/ffsplat/models/transformations.py | 75 ++++++ 4 files changed, 376 insertions(+), 1 deletion(-) create mode 100644 src/ffsplat/conf/format/SOG-canvas.yaml diff --git a/pyproject.toml b/pyproject.toml index fc2bfd0..6f69210 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "jsonargparse>=4.37.0,<5.0.0", # TODO imagecodecs version > 2023.9.18 produce much larger JPEG XL files: https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/issues/3 # but for numpy > 2, we require a modern imagecodecs - # "imagecodecs[all]>=2024.12.30", + "imagecodecs[all]>=2024.12.30", "PyYAML>=6.0.2,<7.0.0", "opencv-python>=4.11.0.86,<5.0.0.0", "pillow>=11.1.0,<12.0.0", diff --git a/src/ffsplat/coding/scene_encoder.py b/src/ffsplat/coding/scene_encoder.py index 4d3e38d..427603c 100644 --- a/src/ffsplat/coding/scene_encoder.py +++ b/src/ffsplat/coding/scene_encoder.py @@ -201,6 +201,8 @@ def encode_gaussians(gaussians: Gaussians, output_path: Path, output_format: str encoding_params = EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-web-nosh.yaml")) case "SOG-web-sh-split": encoding_params = EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-web-sh-split.yaml")) + case "SOG-canvas": + encoding_params = EncodingParams.from_yaml_file(Path("src/ffsplat/conf/format/SOG-canvas.yaml")) case _: raise ValueError(f"Unsupported output format: {output_format}") diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml new file mode 100644 index 0000000..8a2c83b --- /dev/null +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -0,0 +1,298 @@ +profile: SOG-canvas +profile_version: 0.1 + +scene: + 
primitives: 3DGS-INRIA + params: + - means + - scales + - opacities + - quaternions + - sh + +ops: + - input_fields: [sh] + transforms: + - split: + to_field_list: [sh0, shN] #in sogs f_dc=sh0, frest=shN + split_size_or_sections: [1, 15] + dim: 1 + squeeze: false + #TODO: in sogs transpose (1,2) is done after reading from .ply here not? + + - input_fields: [means] + transforms: + - remapping: + method: signed-log + to_field: means + #- input_fields: [quaternions] + #transforms: + #- + - input_fields: [means, sh0, shN, opacities, scales, quaternions] + transforms: + - plas: + prune_by: opacities #isn't done in SOG-canvas + scaling_fn: none #standardize + # activated: true + shuffle: true + improvement_break: 1e-4 + to_field: sorted_indices + weights: + means: 1.0 + sh0: 1.0 + shN: 0.0 #shN not in sortkeys + opacities: 0.0 + scales: 1.0 + quaternions: 1.0 + + #means + - input_fields: [means, sorted_indices] + transforms: + - reindex: + to_field: means_reindexed + indices_field: sorted_indices + + - input_fields: [means_reindexed] + transforms: + - remapping: + method: channelwise-minmax + min: 0 #for dyncamic minmax + max: 1 + dim: 2 + - to_dtype: + dtype: uint16 + round_to_int: true + - split_bytes: + to_fields_with_prefix: means_bytes_ + num_bytes: 2 + + - input_fields: [means_bytes_0] + transforms: + -to_dtype: + dtype: uint8 + round_to_int: false + to_field: means_l + + - input_fields: [means_bytes_1] + transforms: + -to_dtype: + dtype: uint8 + round_to_int: false + to_field: means_u + + #scales + - input_fields: [scales, sorted_indices] + transforms: + - reindex: + to_field: scales + indices_field: sorted_indices + + - input_fields: [scales] + transforms: + - remapping: + method: channelwise-minmax + min: 0 #for dyncamic minmax + max: 1 + dim: 2 + - to_dtype: + dtype: uint8 + round_to_int: true + + - input_fields: [opacities, sorted_indices] + transforms: + - reindex: + indices_field: sorted_indices + + - input_fields: [sh0, sorted_indices] + transforms: + - 
reindex: + src_field: sh0 + index_field: sorted_indices + + - input_fields: [shN, sorted_indices] + transforms: + - reindex: + src_field: shN + index_field: sorted_indices + + #opacity-sh0 rgba + - input_fields: [opacities, sh0] + transforms: + - combine: + method: concat + dim: -1 + to_field: sh0 + + - input_fields: [sh0] + transforms: + - remapping: + method: channelwise-minmax + min: 0 + max: 255 + dim: 2 + - to_dtype: + dtype: uint8 + round_to_int: true + #shN + - input_fields: [shN] + transforms: + #TODO: reshaping? + - reshape: + to_shape: [15, -1] #params.shape[0],-1 (shN) + - permute: + dims: [2, 1, 0] # probably incorrect + - permute: + dims: [1, 0] # in sog-web 0,2,1 + - cluster: + method: kmeans + num_clusters: 4096 #TODO: is dynamic in sogs + distance: manhattan + to_fields_with_prefix: shN_ + + - input_fields: [shN_centroids] + transfroms: + - permute: + to_shape: [1, 0] + - remapping: + method: minmax + - to_dtype: + dtype: uint8 + round_to_int: true + + - input_fields: [shN_centroids] + transforms: + - permute: + dims: [2, 1, 0] #transpose in sogs + to_field: shN_centroids_transposed + + - input_fields: [shN_centroids_transposed] + transforms: + - sort: + method: lexsort + to_field: shN_centroids_indices + + - input_fields: [shN_centroids_indices] + transforms: + - permute: + dims: [0, 1] #TODO: make this sequence simple + - reshape: + to_shape: [64, -1] + - permute: + dims: [0, 1] + - reshape: + to_shape: [-1] + + - input_fields: [shN_centroids, shN_centroids_indices] + transforms: + - reindex: + src_field: shN_centroids + index_field: shN_centroids_indices + + - input_fields: [shN_centroids] + transforms: + - pack: + multiple: 64 + length: 15 #should be dynamic based number of spherical harmonics ShN parameters? 
+ to_field: shN_centroids + + - input_fields: [shN_centroids_indices] + transforms: + - sort: + method: argsort + to_field: shN_centroids_indices_sorted_inverse + + - input_fields: [shN_labels, shN_centroids_indices_sorted_inverse] + transforms: + - reindex: + src_field: shN_labels + index_field: shN_centroids_indices_sorted_inverse + + - input_fields: [shN_labels] + transfroms: + - remapping: + #reindex (inverse) + to_dtype: + dtype: uint16 + round_to_int: False + reshape: + to_shape: [n_sidelen, n_sidelen] + split_bytes: + to_fields_with_prefix: shN_labels_ + num_bytes: 2 + - input_fields: [shN_labels_0, shN_labels_1] + transforms: + - combine: + method: concat-zeros + dim: 2 + to_field: shN_labels + - to_dtype: + dtype: uint8 + round_to_int: true + # add zeros in 3rd dimension + # map to unit +#img file outputs, sogs uses lossless webp, sog uses avif +outputs: + - input_fields: [means_l] + transforms: + - write_file: + type: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [means_u] + transforms: + - write_file: + type: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [scales] #TODO: CHECK, if it is rgba + transforms: + - write_file: + type: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [sh0] #TODO: CHECK, if it is rgba + transforms: + - write_file: + type: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [shN_centroids] #rgb + transforms: + - write_file: + type: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [shN_labels] #rgb + transforms: + - write_file: + type: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true +#Done: means,scales, opacities, sh0, shN, lexsort, argsort, output webp +#TODO: dynamic number of clusters +#TODO: quaternions +# check reshaping and transposes +# 
TODO: side_len grid reshaping?-> doen by sorting topk already?, meta-data sogs, meta-data ffsplat diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index a17615c..f6e51f7 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -73,6 +73,15 @@ def write_image(output_file_path: Path, field_data: Tensor, file_type: str, codi chroma=coding_params.get("chroma", 444), matrix_coefficients=coding_params.get("matrix_coefficients", 0), ) + case "webp": + Image.fromarray(field_data.cpu().numpy()).save( + output_file_path, + format="WEBP", + quality=coding_params.get("quality", 100), + lossless=coding_params.get("lossless", False), + method=coding_params.get("method", 6), + exact=coding_params.get("exact", False), + ) case _: raise ValueError(f"Unsupported file type: {file_type}") @@ -461,6 +470,36 @@ def apply( return new_fields, decoding_update +class Pack(Transformation): + @staticmethod + @override + def apply( + params: dict[str, Any], parentOp: "Operation", verbose: bool = False + ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + input_fields = parentOp.input_fields + + field_name = next(iter(input_fields.keys())) + field_data = input_fields[field_name].data + + new_fields: dict[str, Field] = {} + decoding_update: list[dict[str, Any]] = [] + + match params: + case {"multiple": multiple}: + # /num clusters / 3 + multiple = params.get("mutliple", 64) + length = params.get("length", 1) + field_data = field_data.reshape((-1, int(length * multiple), 3)) + case _: + raise ValueError(f"Unknown Pack parameters: {params}") + decoding_update.append({ + "input_fields": [field_name], + "transforms": [{"pack": {"multiple": multiple, "length": length}}], + }) + new_fields[field_name] = Field(field_data, parentOp) + return new_fields, decoding_update + + class Permute(Transformation): @staticmethod @override @@ -609,6 +648,37 @@ def apply( return new_fields, decoding_update +class 
Sort(Transformation): + @staticmethod + @override + def apply( + params: dict[str, Any], parentOp: "Operation", verbose: bool = False + ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + input_fields = parentOp.input_fields + + field_name = next(iter(input_fields.keys())) + field_data = input_fields[field_name].data + + new_fields: dict[str, Field] = {} + decoding_update: list[dict[str, Any]] = [] + + if field_data is None: + raise ValueError("Field data is None before lexicographic sorting") + match params: + case {"method": "lexicographic"}: + sorted_indices = np.lexsort(field_data.cpu().numpy()) + case {"method": "argsort"}: + sorted_indices = torch.argsort(field_data).cpu().numpy() + case _: + raise ValueError(f"Unknown Sort parameters: {params}") + # sorted_indices = sorted_indeces.reshape(params.get("shape",(64,-1))) + # new_fields[field_name] = Field(field_data_sorted, parentOp) + new_fields["to_field"] = Field(sorted_indices, parentOp) + + # TODO: decoding update??? + return new_fields, decoding_update + + class PLAS(Transformation): @staticmethod def as_grid_img(tensor: Tensor) -> Tensor: @@ -805,6 +875,9 @@ def apply( field_data = torch.stack(tensors, dim=dim) elif method == "concat": field_data = torch.cat(tensors, dim=dim) + elif method == "concat-zeros": + zeros = torch.zeros(tensors[0].shape, dtype=tensors[0].dtype, device=tensors[0].device) + field_data = torch.cat(tensors + zeros, dim=dim) else: raise ValueError(f"Unsupported combine method: {method}") new_fields[to_field_name] = Field(field_data, parentOp) @@ -1128,9 +1201,11 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "to_dtype": ToDType, "split_bytes": SplitBytes, "reindex": Reindex, + "sort": Sort, "plas": PLAS, "lookup": Lookup, "combine": Combine, + "pack": Pack, "write_file": WriteFile, "read_file": ReadFile, "simple_quantize": SimpleQuantize, From 4d4e9d42d5b81096135f48714ea50d04f22f6b65 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Tue, 3 Jun 2025 
16:55:27 +0200 Subject: [PATCH 02/24] means working, no sogs-metadata yet --- src/ffsplat/conf/format/SOG-canvas.yaml | 410 ++++++++++++------------ src/ffsplat/models/transformations.py | 7 +- 2 files changed, 212 insertions(+), 205 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 8a2c83b..6413f3f 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -49,18 +49,16 @@ ops: - input_fields: [means, sorted_indices] transforms: - reindex: - to_field: means_reindexed - indices_field: sorted_indices + src_field: means + index_field: sorted_indices - - input_fields: [means_reindexed] + - input_fields: [means] transforms: - - remapping: - method: channelwise-minmax - min: 0 #for dyncamic minmax - max: 1 - dim: 2 - - to_dtype: + - simple_quantize: dtype: uint16 + min: 0 + max: 65535 + dim: 2 round_to_int: true - split_bytes: to_fields_with_prefix: means_bytes_ @@ -68,229 +66,237 @@ ops: - input_fields: [means_bytes_0] transforms: - -to_dtype: - dtype: uint8 - round_to_int: false - to_field: means_l + - to_field: + to_field_name: means_l - input_fields: [means_bytes_1] transforms: - -to_dtype: - dtype: uint8 - round_to_int: false - to_field: means_u + - to_field: + to_field_name: means_u #scales - - input_fields: [scales, sorted_indices] - transforms: - - reindex: - to_field: scales - indices_field: sorted_indices + #- input_fields: [scales, sorted_indices] + #transforms: + #- reindex: + #srd_field: scales + #index_field: sorted_indices - - input_fields: [scales] - transforms: - - remapping: - method: channelwise-minmax - min: 0 #for dyncamic minmax - max: 1 - dim: 2 - - to_dtype: - dtype: uint8 - round_to_int: true + #- input_fields: [scales] + #transforms: + #- remapping: + #method: channelwise-minmax + #min: 0 #for dyncamic minmax + #max: 1 + #dim: 2 + #- to_dtype: + #dtype: uint8 + #round_to_int: true - - input_fields: [opacities, sorted_indices] - transforms: 
- - reindex: - indices_field: sorted_indices + #- input_fields: [opacities, sorted_indices] + #transforms: + #- reindex: + #src_field: opacities + #index_field: sorted_indices - - input_fields: [sh0, sorted_indices] - transforms: - - reindex: - src_field: sh0 - index_field: sorted_indices + #- input_fields: [sh0, sorted_indices] + #transforms: + #- reindex: + #src_field: sh0 + #index_field: sorted_indices - - input_fields: [shN, sorted_indices] - transforms: - - reindex: - src_field: shN - index_field: sorted_indices + #- input_fields: [shN, sorted_indices] + #transforms: + #- reindex: + #src_field: shN + #index_field: sorted_indices - #opacity-sh0 rgba - - input_fields: [opacities, sh0] - transforms: - - combine: - method: concat - dim: -1 - to_field: sh0 + ##opacity-sh0 rgba + #- input_fields: [opacities, sh0] + #transforms: + #- combine: + #method: concat + #dim: -1 + #to_field: sh0 - - input_fields: [sh0] - transforms: - - remapping: - method: channelwise-minmax - min: 0 - max: 255 - dim: 2 - - to_dtype: - dtype: uint8 - round_to_int: true - #shN - - input_fields: [shN] - transforms: - #TODO: reshaping? - - reshape: - to_shape: [15, -1] #params.shape[0],-1 (shN) - - permute: - dims: [2, 1, 0] # probably incorrect - - permute: - dims: [1, 0] # in sog-web 0,2,1 - - cluster: - method: kmeans - num_clusters: 4096 #TODO: is dynamic in sogs - distance: manhattan - to_fields_with_prefix: shN_ + #- input_fields: [sh0] + #transforms: + #- remapping: + #method: channelwise-minmax + #min: 0 + #max: 255 + #dim: 2 + #- to_dtype: + #dtype: uint8 + #round_to_int: true + ##shN + #- input_fields: [shN] + #transforms: + ##TODO: reshaping? 
+ #- reshape: + #to_shape: [15, -1] #params.shape[0],-1 (shN) + #- permute: + #dims: [2, 1, 0] # probably incorrect + #- permute: + #dims: [1, 0] # in sog-web 0,2,1 + #- cluster: + #method: kmeans + #num_clusters: 4096 #TODO: is dynamic in sogs + #distance: manhattan + #to_fields_with_prefix: shN_ - - input_fields: [shN_centroids] - transfroms: - - permute: - to_shape: [1, 0] - - remapping: - method: minmax - - to_dtype: - dtype: uint8 - round_to_int: true + #- input_fields: [shN_centroids] + #transfroms: + #- permute: + #to_shape: [1, 0] + #- remapping: + #method: minmax + #- to_dtype: + #dtype: uint8 + #round_to_int: true - - input_fields: [shN_centroids] - transforms: - - permute: - dims: [2, 1, 0] #transpose in sogs - to_field: shN_centroids_transposed + #- input_fields: [shN_centroids] + #transforms: + #- permute: + #dims: [2, 1, 0] #transpose in sogs + #to_field: shN_centroids_transposed - - input_fields: [shN_centroids_transposed] - transforms: - - sort: - method: lexsort - to_field: shN_centroids_indices + #- input_fields: [shN_centroids_transposed] + #transforms: + #- sort: + #method: lexsort + #to_field: shN_centroids_indices - - input_fields: [shN_centroids_indices] - transforms: - - permute: - dims: [0, 1] #TODO: make this sequence simple - - reshape: - to_shape: [64, -1] - - permute: - dims: [0, 1] - - reshape: - to_shape: [-1] + #- input_fields: [shN_centroids_indices] + #transforms: + #- permute: + #dims: [0, 1] #TODO: make this sequence simple + #- reshape: + #to_shape: [64, -1] + #- permute: + #dims: [0, 1] + #- reshape: + #to_shape: [-1] - - input_fields: [shN_centroids, shN_centroids_indices] - transforms: - - reindex: - src_field: shN_centroids - index_field: shN_centroids_indices + #- input_fields: [shN_centroids, shN_centroids_indices] + #transforms: + #- reindex: + #src_field: shN_centroids + #index_field: shN_centroids_indices - - input_fields: [shN_centroids] - transforms: - - pack: - multiple: 64 - length: 15 #should be dynamic based 
number of spherical harmonics ShN parameters? - to_field: shN_centroids + #- input_fields: [shN_centroids] + #transforms: + #- pack: + #multiple: 64 + #length: 15 #should be dynamic based number of spherical harmonics ShN parameters? + #to_field: shN_centroids - - input_fields: [shN_centroids_indices] - transforms: - - sort: - method: argsort - to_field: shN_centroids_indices_sorted_inverse + #- input_fields: [shN_centroids_indices] + #transforms: + #- sort: + #method: argsort + #to_field: shN_centroids_indices_sorted_inverse - - input_fields: [shN_labels, shN_centroids_indices_sorted_inverse] - transforms: - - reindex: - src_field: shN_labels - index_field: shN_centroids_indices_sorted_inverse + #- input_fields: [shN_labels, shN_centroids_indices_sorted_inverse] + #transforms: + #- reindex: + #src_field: shN_labels + #index_field: shN_centroids_indices_sorted_inverse + + #- input_fields: [shN_labels] + #transfroms: + #- remapping: + ##reindex (inverse) + #to_dtype: + #dtype: uint16 + #round_to_int: False + #reshape: + #to_shape: [n_sidelen, n_sidelen] + #split_bytes: + #to_fields_with_prefix: shN_labels_ + #num_bytes: 2 + #- input_fields: [shN_labels_0, shN_labels_1] + #transforms: + #- combine: + #method: concat-zeros + #dim: 2 + #to_field: shN_labels + #- to_dtype: + #dtype: uint8 + #round_to_int: true + ## add zeros in 3rd dimension + ## map to unit + ##img file outputs, sogs uses lossless webp, sog uses avif - - input_fields: [shN_labels] - transfroms: - - remapping: - #reindex (inverse) - to_dtype: - dtype: uint16 - round_to_int: False - reshape: - to_shape: [n_sidelen, n_sidelen] - split_bytes: - to_fields_with_prefix: shN_labels_ - num_bytes: 2 - - input_fields: [shN_labels_0, shN_labels_1] - transforms: - - combine: - method: concat-zeros - dim: 2 - to_field: shN_labels - - to_dtype: - dtype: uint8 - round_to_int: true - # add zeros in 3rd dimension - # map to unit -#img file outputs, sogs uses lossless webp, sog uses avif -outputs: - input_fields: 
[means_l] transforms: - write_file: - type: webp - coding_params: - lossless: true - quality: 100 - method: 6 - exact: true - + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true - input_fields: [means_u] transforms: - write_file: - type: webp - coding_params: - lossless: true - quality: 100 - method: 6 - exact: true + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true - - input_fields: [scales] #TODO: CHECK, if it is rgba - transforms: - - write_file: - type: webp - coding_params: - lossless: true - quality: 100 - method: 6 - exact: true + #- input_fields: + #from_fields_with_prefix: means_bytes_ + #transforms: + #- write_file: + #type: image + #image_codec: avif + #coding_params: + #quality: -1 + #chroma: 444 + #matrix_coefficients: 0 + #- input_fields: [scales] #TODO: CHECK, if it is rgba + #transforms: + #- write_file: + #type: webp + #coding_params: + #lossless: true + #quality: 100 + #method: 6 + #exact: true - - input_fields: [sh0] #TODO: CHECK, if it is rgba - transforms: - - write_file: - type: webp - coding_params: - lossless: true - quality: 100 - method: 6 - exact: true + #- input_fields: [sh0] #TODO: CHECK, if it is rgba + #transforms: + #- write_file: + #type: webp + #coding_params: + #lossless: true + #quality: 100 + #method: 6 + #exact: true - - input_fields: [shN_centroids] #rgb - transforms: - - write_file: - type: webp - coding_params: - lossless: true - quality: 100 - method: 6 - exact: true + #- input_fields: [shN_centroids] #rgb + #transforms: + #- write_file: + #type: webp + #coding_params: + #lossless: true + #quality: 100 + #method: 6 + #exact: true - - input_fields: [shN_labels] #rgb - transforms: - - write_file: - type: webp - coding_params: - lossless: true - quality: 100 - method: 6 - exact: true + #- input_fields: [shN_labels] #rgb + #transforms: + #- write_file: + #type: webp + #coding_params: + #lossless: true + #quality: 
100 + #method: 6 + #exact: true #Done: means,scales, opacities, sh0, shN, lexsort, argsort, output webp #TODO: dynamic number of clusters #TODO: quaternions diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index f6e51f7..2e40c53 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -78,9 +78,9 @@ def write_image(output_file_path: Path, field_data: Tensor, file_type: str, codi output_file_path, format="WEBP", quality=coding_params.get("quality", 100), - lossless=coding_params.get("lossless", False), + lossless=coding_params.get("lossless", True), method=coding_params.get("method", 6), - exact=coding_params.get("exact", False), + exact=coding_params.get("exact", True), ) case _: raise ValueError(f"Unsupported file type: {file_type}") @@ -877,7 +877,8 @@ def apply( field_data = torch.cat(tensors, dim=dim) elif method == "concat-zeros": zeros = torch.zeros(tensors[0].shape, dtype=tensors[0].dtype, device=tensors[0].device) - field_data = torch.cat(tensors + zeros, dim=dim) + tensors.append(zeros) + field_data = torch.cat(tensors, dim=dim) else: raise ValueError(f"Unsupported combine method: {method}") new_fields[to_field_name] = Field(field_data, parentOp) From d43829fbd4378558cadbd230b435b904a4e3eaed Mon Sep 17 00:00:00 2001 From: fleischmann Date: Wed, 4 Jun 2025 14:45:08 +0200 Subject: [PATCH 03/24] scales, spherical harmonics 0, opacities --- src/ffsplat/conf/format/SOG-canvas.yaml | 156 +++++++++++++----------- 1 file changed, 82 insertions(+), 74 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 6413f3f..2553e71 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -17,7 +17,7 @@ ops: to_field_list: [sh0, shN] #in sogs f_dc=sh0, frest=shN split_size_or_sections: [1, 15] dim: 1 - squeeze: false + squeeze: true #false #TODO: in sogs transpose (1,2) is done after reading 
from .ply here not? - input_fields: [means] @@ -28,10 +28,20 @@ ops: #- input_fields: [quaternions] #transforms: #- + - input_fields: [scales] + transforms: + - remapping: + method: log + + - input_fields: [opacities] + transforms: + - remapping: + method: inverse-sigmoid + - input_fields: [means, sh0, shN, opacities, scales, quaternions] transforms: - plas: - prune_by: opacities #isn't done in SOG-canvas + prune_by: opacities # in sogs purned by 0.5 of number gaussians instead scaling_fn: none #standardize # activated: true shuffle: true @@ -75,59 +85,67 @@ ops: to_field_name: means_u #scales - #- input_fields: [scales, sorted_indices] - #transforms: - #- reindex: - #srd_field: scales - #index_field: sorted_indices + - input_fields: [scales, sorted_indices] + transforms: + - reindex: + src_field: scales + index_field: sorted_indices - #- input_fields: [scales] - #transforms: - #- remapping: - #method: channelwise-minmax - #min: 0 #for dyncamic minmax - #max: 1 - #dim: 2 - #- to_dtype: - #dtype: uint8 - #round_to_int: true + - input_fields: [scales] + transforms: + - simple_quantize: + min: 0 + max: 255 + dim: 2 + dtype: uint8 + round_to_int: true - #- input_fields: [opacities, sorted_indices] - #transforms: - #- reindex: - #src_field: opacities - #index_field: sorted_indices + - input_fields: [opacities, sorted_indices] + transforms: + - reindex: + src_field: opacities + index_field: sorted_indices - #- input_fields: [sh0, sorted_indices] + #- input_fields: [sh0] #transforms: - #- reindex: - #src_field: sh0 - #index_field: sorted_indices + #- reshape: + #start_dim: 2 + #shape: [3] + + - input_fields: [sh0, sorted_indices] + transforms: + - reindex: + src_field: sh0 + index_field: sorted_indices + + - input_fields: [opacities] + transforms: + - reshape: + start_dim: 2 + shape: [1] #- input_fields: [shN, sorted_indices] #transforms: #- reindex: #src_field: shN #index_field: sorted_indices + # opacity-sh0 rgba + - input_fields: [opacities, sh0] + transforms: + - 
combine: + method: concat + dim: 2 + to_field: sh0 - ##opacity-sh0 rgba - #- input_fields: [opacities, sh0] - #transforms: - #- combine: - #method: concat - #dim: -1 - #to_field: sh0 + - input_fields: [sh0] + transforms: + - simple_quantize: + min: 0 + max: 255 + dim: 2 + dtype: uint8 + round_to_int: true - #- input_fields: [sh0] - #transforms: - #- remapping: - #method: channelwise-minmax - #min: 0 - #max: 255 - #dim: 2 - #- to_dtype: - #dtype: uint8 - #round_to_int: true ##shN #- input_fields: [shN] #transforms: @@ -226,7 +244,6 @@ ops: ## add zeros in 3rd dimension ## map to unit ##img file outputs, sogs uses lossless webp, sog uses avif - - input_fields: [means_l] transforms: - write_file: @@ -237,6 +254,7 @@ ops: quality: 100 method: 6 exact: true + - input_fields: [means_u] transforms: - write_file: @@ -248,36 +266,27 @@ ops: method: 6 exact: true - #- input_fields: - #from_fields_with_prefix: means_bytes_ - #transforms: - #- write_file: - #type: image - #image_codec: avif - #coding_params: - #quality: -1 - #chroma: 444 - #matrix_coefficients: 0 - #- input_fields: [scales] #TODO: CHECK, if it is rgba - #transforms: - #- write_file: - #type: webp - #coding_params: - #lossless: true - #quality: 100 - #method: 6 - #exact: true - - #- input_fields: [sh0] #TODO: CHECK, if it is rgba - #transforms: - #- write_file: - #type: webp - #coding_params: - #lossless: true - #quality: 100 - #method: 6 - #exact: true + - input_fields: [scales] + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + - input_fields: [sh0] #TODO: CHECK, if it is rgba + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true #- input_fields: [shN_centroids] #rgb #transforms: #- write_file: @@ -287,7 +296,6 @@ ops: #quality: 100 #method: 6 #exact: true - #- input_fields: [shN_labels] #rgb #transforms: #- write_file: From 
b5c47079712a49c820659b06421fff5169d94109 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Wed, 4 Jun 2025 14:55:37 +0200 Subject: [PATCH 04/24] corrected order of opacities+sh0 webp --- src/ffsplat/conf/format/SOG-canvas.yaml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 2553e71..7b06858 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -124,13 +124,7 @@ ops: start_dim: 2 shape: [1] - #- input_fields: [shN, sorted_indices] - #transforms: - #- reindex: - #src_field: shN - #index_field: sorted_indices - # opacity-sh0 rgba - - input_fields: [opacities, sh0] + - input_fields: [sh0, opacities] #opacity-sh0 rgba transforms: - combine: method: concat From 908003ce95d2f45704fcdcf853cda629193c0bf7 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Wed, 4 Jun 2025 17:10:38 +0200 Subject: [PATCH 05/24] shN intermediate --- src/ffsplat/conf/format/SOG-canvas.yaml | 95 ++++++++++++------------- src/ffsplat/models/transformations.py | 22 +++--- 2 files changed, 58 insertions(+), 59 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 7b06858..bcd014a 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -118,6 +118,12 @@ ops: src_field: sh0 index_field: sorted_indices + - input_fields: [shN, sorted_indices] + transforms: + - reindex: + src_field: shN + index_field: sorted_indices + - input_fields: [opacities] transforms: - reshape: @@ -141,59 +147,50 @@ ops: round_to_int: true ##shN - #- input_fields: [shN] - #transforms: - ##TODO: reshaping? 
- #- reshape: - #to_shape: [15, -1] #params.shape[0],-1 (shN) - #- permute: - #dims: [2, 1, 0] # probably incorrect - #- permute: - #dims: [1, 0] # in sog-web 0,2,1 - #- cluster: - #method: kmeans - #num_clusters: 4096 #TODO: is dynamic in sogs - #distance: manhattan - #to_fields_with_prefix: shN_ - - #- input_fields: [shN_centroids] - #transfroms: - #- permute: - #to_shape: [1, 0] - #- remapping: - #method: minmax - #- to_dtype: - #dtype: uint8 - #round_to_int: true + - input_fields: [shN] + transforms: + - reshape: + shape: [-1, 45] # 3*15 spherical harmonics + - cluster: + method: kmeans + num_clusters: 65536 #TODO: is dynamic in sogs + distance: manhattan + to_fields_with_prefix: shN_ - #- input_fields: [shN_centroids] - #transforms: - #- permute: - #dims: [2, 1, 0] #transpose in sogs - #to_field: shN_centroids_transposed + - input_fields: [shN_centroids] + transforms: + - simple_quantize: + min: 0 + max: 255 + dim: 2 #dummy dim + dtype: uint8 + round_to_int: true - #- input_fields: [shN_centroids_transposed] - #transforms: - #- sort: - #method: lexsort - #to_field: shN_centroids_indices + - input_fields: [shN_centroids] + transforms: + - sort: + method: lexicographic + to_field: shN_centroids_indices - #- input_fields: [shN_centroids_indices] - #transforms: - #- permute: - #dims: [0, 1] #TODO: make this sequence simple - #- reshape: - #to_shape: [64, -1] - #- permute: - #dims: [0, 1] - #- reshape: - #to_shape: [-1] + - input_fields: [shN_centroids_indices] + transforms: + - reshape: + start_dim: 1 + shape: [1] + #- permute: + #dims: [1, 0] #TODO: make this sequence simple + #- reshape: + #shape: [64, -1] + #- permute: + #dims: [1, 0] + #- reshape: + #shape: [-1] - #- input_fields: [shN_centroids, shN_centroids_indices] - #transforms: - #- reindex: - #src_field: shN_centroids - #index_field: shN_centroids_indices + - input_fields: [shN_centroids, shN_centroids_indices] + transforms: + - reindex: + src_field: shN_centroids + index_field: shN_centroids_indices #- 
input_fields: [shN_centroids] #transforms: diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 2e40c53..c1be7ac 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -485,17 +485,17 @@ def apply( decoding_update: list[dict[str, Any]] = [] match params: - case {"multiple": multiple}: + case {"multiple": multiple, "length": length}: # /num clusters / 3 - multiple = params.get("mutliple", 64) - length = params.get("length", 1) - field_data = field_data.reshape((-1, int(length * multiple), 3)) + # multiple = params.get("mutliple", 64) + # length = params.get("length", 1) + field_data = field_data.reshape((-1, int(length * multiple / 3), 3)) case _: raise ValueError(f"Unknown Pack parameters: {params}") decoding_update.append({ "input_fields": [field_name], "transforms": [{"pack": {"multiple": multiple, "length": length}}], - }) + }) # TODO: this is not correct, need to flatten new_fields[field_name] = Field(field_data, parentOp) return new_fields, decoding_update @@ -630,7 +630,7 @@ def apply( new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] - + # TODO: make this compatible with single match params: case {"src_field": src_field_name, "index_field": index_field_name}: index_field_obj = input_fields[index_field_name] @@ -662,18 +662,20 @@ def apply( new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] + sorted_indices = None if field_data is None: - raise ValueError("Field data is None before lexicographic sorting") + raise ValueError("Field data is None before sorting") match params: case {"method": "lexicographic"}: - sorted_indices = np.lexsort(field_data.cpu().numpy()) + sorted_indices = np.lexsort(field_data.permute(dims=(1, 0)).cpu().numpy()) + sorted_indices = torch.tensor(sorted_indices, device=field_data.device) case {"method": "argsort"}: - sorted_indices = torch.argsort(field_data).cpu().numpy() + sorted_indices = 
torch.argsort(field_data) case _: raise ValueError(f"Unknown Sort parameters: {params}") # sorted_indices = sorted_indeces.reshape(params.get("shape",(64,-1))) # new_fields[field_name] = Field(field_data_sorted, parentOp) - new_fields["to_field"] = Field(sorted_indices, parentOp) + new_fields[params["to_field"]] = Field(sorted_indices, parentOp) # TODO: decoding update??? return new_fields, decoding_update From f29b68d4db965c1850a2c5148dfc5adee15fc10d Mon Sep 17 00:00:00 2001 From: fleischmann Date: Wed, 11 Jun 2025 11:03:11 +0200 Subject: [PATCH 06/24] shN encoding --- src/ffsplat/conf/format/SOG-canvas.yaml | 157 +++++++++++++----------- src/ffsplat/models/transformations.py | 37 +----- 2 files changed, 89 insertions(+), 105 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index bcd014a..81e5f38 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -162,7 +162,7 @@ ops: - simple_quantize: min: 0 max: 255 - dim: 2 #dummy dim + dim: 20 #dummy dim dtype: uint8 round_to_int: true @@ -172,19 +172,17 @@ ops: method: lexicographic to_field: shN_centroids_indices - - input_fields: [shN_centroids_indices] + - input_fields: [shN_centroids_indices] #for reindexing + transforms: + - reshape: + start_dim: 1 + shape: [1] + + - input_fields: [shN_labels] #for reindexing transforms: - reshape: start_dim: 1 shape: [1] - #- permute: - #dims: [1, 0] #TODO: make this sequence simple - #- reshape: - #shape: [64, -1] - #- permute: - #dims: [1, 0] - #- reshape: - #shape: [-1] - input_fields: [shN_centroids, shN_centroids_indices] transforms: @@ -192,49 +190,62 @@ ops: src_field: shN_centroids index_field: shN_centroids_indices - #- input_fields: [shN_centroids] - #transforms: - #- pack: - #multiple: 64 - #length: 15 #should be dynamic based number of spherical harmonics ShN parameters? 
- #to_field: shN_centroids + - input_fields: [shN_centroids_indices] + transforms: + - reshape: + shape: [-1] #for argsort - #- input_fields: [shN_centroids_indices] - #transforms: - #- sort: - #method: argsort - #to_field: shN_centroids_indices_sorted_inverse + - input_fields: [shN_centroids] + transforms: + - reshape: + shape: [-1, 960, 3] # int(num_clusters*num_spherical_harmonics/3 = - #- input_fields: [shN_labels, shN_centroids_indices_sorted_inverse] - #transforms: - #- reindex: - #src_field: shN_labels - #index_field: shN_centroids_indices_sorted_inverse - - #- input_fields: [shN_labels] - #transfroms: - #- remapping: - ##reindex (inverse) - #to_dtype: - #dtype: uint16 - #round_to_int: False - #reshape: - #to_shape: [n_sidelen, n_sidelen] - #split_bytes: - #to_fields_with_prefix: shN_labels_ - #num_bytes: 2 - #- input_fields: [shN_labels_0, shN_labels_1] - #transforms: - #- combine: - #method: concat-zeros - #dim: 2 - #to_field: shN_labels - #- to_dtype: - #dtype: uint8 - #round_to_int: true - ## add zeros in 3rd dimension - ## map to unit - ##img file outputs, sogs uses lossless webp, sog uses avif + - input_fields: [shN_centroids_indices] + transforms: + - sort: + method: argsort + to_field: shN_centroids_indices_sorted_inverse + + - input_fields: [shN_centroids_indices_sorted_inverse] + transforms: + - reshape: + start_dim: 1 + shape: [1] #for reindexing + + - input_fields: [shN_labels, shN_centroids_indices_sorted_inverse] + transforms: + - reindex: + src_field: shN_centroids_indices_sorted_inverse + index_field: shN_labels + + #bug here: + - input_fields: [shN_centroids_indices_sorted_inverse] + transforms: + - to_field: + to_field_name: shN_labels + + - input_fields: [shN_labels] + transforms: + - reshape: + shape: [588, 588] #n_sidelen + - simple_quantize: + min: 0 + max: 65535 + dim: 2 + dtype: uint16 + round_to_int: False + - split_bytes: + to_fields_with_prefix: shN_labels_ + num_bytes: 2 + + - input_fields: [shN_labels_0, shN_labels_1] + 
transforms: + - combine: + method: stack-zeros + dim: 2 + to_field: shN_labels + + #img file outputs - input_fields: [means_l] transforms: - write_file: @@ -268,7 +279,7 @@ ops: method: 6 exact: true - - input_fields: [sh0] #TODO: CHECK, if it is rgba + - input_fields: [sh0] #rgba transforms: - write_file: type: image @@ -278,25 +289,29 @@ ops: quality: 100 method: 6 exact: true - #- input_fields: [shN_centroids] #rgb - #transforms: - #- write_file: - #type: webp - #coding_params: - #lossless: true - #quality: 100 - #method: 6 - #exact: true - #- input_fields: [shN_labels] #rgb - #transforms: - #- write_file: - #type: webp - #coding_params: - #lossless: true - #quality: 100 - #method: 6 - #exact: true -#Done: means,scales, opacities, sh0, shN, lexsort, argsort, output webp + + - input_fields: [shN_centroids] #rgb + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + + - input_fields: [shN_labels] #rgb , result is suprisingly less noisy than in sogs + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true +#Done: means,scales, opacities, sh0, shN, lexsort, argsort, output webp, #TODO: dynamic number of clusters #TODO: quaternions # check reshaping and transposes diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index c1be7ac..dab2c35 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -470,36 +470,6 @@ def apply( return new_fields, decoding_update -class Pack(Transformation): - @staticmethod - @override - def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False - ) -> tuple[dict[str, Field], list[dict[str, Any]]]: - input_fields = parentOp.input_fields - - field_name = next(iter(input_fields.keys())) - field_data = input_fields[field_name].data - - new_fields: dict[str, Field] = {} - 
decoding_update: list[dict[str, Any]] = [] - - match params: - case {"multiple": multiple, "length": length}: - # /num clusters / 3 - # multiple = params.get("mutliple", 64) - # length = params.get("length", 1) - field_data = field_data.reshape((-1, int(length * multiple / 3), 3)) - case _: - raise ValueError(f"Unknown Pack parameters: {params}") - decoding_update.append({ - "input_fields": [field_name], - "transforms": [{"pack": {"multiple": multiple, "length": length}}], - }) # TODO: this is not correct, need to flatten - new_fields[field_name] = Field(field_data, parentOp) - return new_fields, decoding_update - - class Permute(Transformation): @staticmethod @override @@ -630,7 +600,7 @@ def apply( new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] - # TODO: make this compatible with single + # TODO: make this compatible with 1D-tensors match params: case {"src_field": src_field_name, "index_field": index_field_name}: index_field_obj = input_fields[index_field_name] @@ -877,10 +847,10 @@ def apply( field_data = torch.stack(tensors, dim=dim) elif method == "concat": field_data = torch.cat(tensors, dim=dim) - elif method == "concat-zeros": + elif method == "stack-zeros": zeros = torch.zeros(tensors[0].shape, dtype=tensors[0].dtype, device=tensors[0].device) tensors.append(zeros) - field_data = torch.cat(tensors, dim=dim) + field_data = torch.stack(tensors, dim=dim) else: raise ValueError(f"Unsupported combine method: {method}") new_fields[to_field_name] = Field(field_data, parentOp) @@ -1208,7 +1178,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "plas": PLAS, "lookup": Lookup, "combine": Combine, - "pack": Pack, "write_file": WriteFile, "read_file": ReadFile, "simple_quantize": SimpleQuantize, From 3152359cbb126cd515e1939af6c9557db6f85ba2 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Wed, 11 Jun 2025 17:30:39 +0200 Subject: [PATCH 07/24] 1D-reindexing, quaternions (intermediate) --- 
src/ffsplat/conf/format/SOG-canvas.yaml | 67 ++++++++++--------- src/ffsplat/models/transformations.py | 86 +++++++++++++++++++++++-- 2 files changed, 117 insertions(+), 36 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 81e5f38..ce8a55a 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -25,9 +25,7 @@ ops: - remapping: method: signed-log to_field: means - #- input_fields: [quaternions] - #transforms: - #- + - input_fields: [scales] transforms: - remapping: @@ -38,6 +36,12 @@ ops: - remapping: method: inverse-sigmoid + - input_fields: [quaternions] + transforms: + - reparametize: + method: unit_sphere + dim: -1 + - input_fields: [means, sh0, shN, opacities, scales, quaternions] transforms: - plas: @@ -106,11 +110,22 @@ ops: src_field: opacities index_field: sorted_indices - #- input_fields: [sh0] - #transforms: - #- reshape: - #start_dim: 2 - #shape: [3] + - input_fields: [quaternions, sorted_indices] + transforms: + - reindex: + src_field: quaternions + index_field: sorted_indices + + - input_fields: [quaternions] #pack quaternions + transforms: + - reparametize: + method: pack_dynamic + - simple_quantize: + min_values: [0, 0, 0, 252] + max_values: [255, 255, 255, 255] + dim: 2 + dtype: uint8 + round_to_int: true - input_fields: [sh0, sorted_indices] transforms: @@ -162,7 +177,7 @@ ops: - simple_quantize: min: 0 max: 255 - dim: 20 #dummy dim + dim: 2 dtype: uint8 round_to_int: true @@ -172,29 +187,12 @@ ops: method: lexicographic to_field: shN_centroids_indices - - input_fields: [shN_centroids_indices] #for reindexing - transforms: - - reshape: - start_dim: 1 - shape: [1] - - - input_fields: [shN_labels] #for reindexing - transforms: - - reshape: - start_dim: 1 - shape: [1] - - input_fields: [shN_centroids, shN_centroids_indices] transforms: - reindex: src_field: shN_centroids index_field: shN_centroids_indices - - input_fields: 
[shN_centroids_indices] - transforms: - - reshape: - shape: [-1] #for argsort - - input_fields: [shN_centroids] transforms: - reshape: @@ -206,12 +204,6 @@ ops: method: argsort to_field: shN_centroids_indices_sorted_inverse - - input_fields: [shN_centroids_indices_sorted_inverse] - transforms: - - reshape: - start_dim: 1 - shape: [1] #for reindexing - - input_fields: [shN_labels, shN_centroids_indices_sorted_inverse] transforms: - reindex: @@ -279,6 +271,17 @@ ops: method: 6 exact: true + - input_fields: [quaternions] #rgba + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + - input_fields: [sh0] #rgba transforms: - write_file: diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index dab2c35..5a22192 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -382,6 +382,79 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: return dynamic_params_config +class Reparametrize(Transformation): + @staticmethod + @override + def apply( + params: dict[str, Any], parentOp: "Operation", verbose: bool = False + ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + input_fields = parentOp.input_fields + + field_name = next(iter(input_fields.keys())) + field_data = input_fields[field_name].data + + new_fields: dict[str, Field] = {} + decoding_update: list[dict[str, Any]] = [] + + match params: + case {"method": "unit_sphere", "dim": dim}: + field_data = field_data / torch.linalg.norm(field_data, dim=dim, keepdim=True) + + sign_mask = field_data.select(dim, 3) < 0 + # sign_mask = field_data.select(dim, 0) < 0 + field_data[sign_mask] *= -1 + new_fields[field_name] = Field(field_data, parentOp) + + # field_data = field_data.narrow(dim, 0, field_data.shape[dim] - 1) + # field_data = field_data.narrow(dim, 1,field_data.shape[dim] - 1) + # decoding_update.append({ + # "unit_sphere_recover_last": 
{ + # "dim": dim, + # } + # }) + case {"method": "pack_dynamic"}: + abs_field_data = field_data.abs() + max_idx = abs_field_data.argmax(dim=-1) + + # ensure largest component is positive + max_vals = field_data.gather(-1, max_idx.unsqueeze(-1)).squeeze(-1) + sign = max_vals.sign() + sign[sign == 0] = 1 + q_signed = field_data * sign.unsqueeze(-1) + + # build variants dropping each component + variants = [] + for i in range(4): + dims = list(range(4)) + dims.remove(i) + variants.append(q_signed[..., dims]) # (...,3) + stacked = torch.stack(variants, dim=-2) # (...,4,3) + + # select the appropriate 3-vector based on max_idx + idx_exp = max_idx.unsqueeze(-1).unsqueeze(-1).expand(*max_idx.shape, 1, 3) + small = torch.gather(stacked, dim=-2, index=idx_exp).squeeze(-2) # (...,3) + + max_idx = max_idx.to(small.dtype) + + # scale by sqrt(2) to normalize range to [-1,1] + # Apply brightness increase by multiplying by sqrt(2) + # TODO: this needs to be done s.t. it doesn't get lost in "simple quantizaton" + small = small * torch.sqrt(torch.tensor(2.0, device=small.device, dtype=small.dtype)) + + # Ensure the brightness increase is preserved after SimpleQuantize + # Map from [-1,1] to [0,1] before quantization + # small = small * 0.5 + 0.5 + small = small.clamp(max=1.0) + # max_idx = (252.0 + max_idx.to(torch.float32)) / 255.0 + packed = torch.cat([small, max_idx.unsqueeze(-1)], dim=-1) + new_fields[field_name] = Field(packed, parentOp) + + case _: + raise ValueError(f"Unknown ToField parameters: {params}") + + return new_fields, decoding_update + + class ToField(Transformation): @staticmethod @override @@ -604,8 +677,8 @@ def apply( match params: case {"src_field": src_field_name, "index_field": index_field_name}: index_field_obj = input_fields[index_field_name] - if len(index_field_obj.data.shape) != 2: - raise ValueError("Expecting grid for re-index operation") + # if len(index_field_obj.data.shape) != 2: + # raise ValueError("Expecting grid for re-index operation") 
decoding_update.append({ "input_fields": [src_field_name], "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], @@ -1107,12 +1180,11 @@ def apply( "max_values": max_values, "dim": dim, "dtype": dtype_str, + "round_to_int": round_to_int, }: if field_data is None: raise ValueError("Field data is None before channelwise remapping") - field_data = convert_to_dtype(field_data, dtype_str) - min_tensor = torch.tensor(min_values, device=field_data.device, dtype=torch.float32) max_tensor = torch.tensor(max_values, device=field_data.device, dtype=torch.float32) @@ -1128,6 +1200,11 @@ def apply( field_data = minmax(field_data) field_data = field_data * field_range + min_tensor + + if round_to_int: + field_data = torch.round(field_data) + field_data = convert_to_dtype(field_data, dtype_str) + case _: raise ValueError(f"Unknown remapping parameters: {params}") new_fields[field_name] = Field(field_data, parentOp) @@ -1169,6 +1246,7 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "flatten": Flatten, "reshape": Reshape, "remapping": Remapping, + "reparametize": Reparametrize, "to_field": ToField, "permute": Permute, "to_dtype": ToDType, From 34aa55b38da9cdcd7f0cbba3edf9948e1dba7cb1 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Wed, 11 Jun 2025 17:43:34 +0200 Subject: [PATCH 08/24] quaternions brightness increase --- src/ffsplat/models/transformations.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 5a22192..ad1f5a5 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -436,16 +436,13 @@ def apply( max_idx = max_idx.to(small.dtype) + # TODO: this should be done as remapping # scale by sqrt(2) to normalize range to [-1,1] # Apply brightness increase by multiplying by sqrt(2) - # TODO: this needs to be done s.t. 
it doesn't get lost in "simple quantizaton" small = small * torch.sqrt(torch.tensor(2.0, device=small.device, dtype=small.dtype)) - # Ensure the brightness increase is preserved after SimpleQuantize - # Map from [-1,1] to [0,1] before quantization - # small = small * 0.5 + 0.5 - small = small.clamp(max=1.0) - # max_idx = (252.0 + max_idx.to(torch.float32)) / 255.0 + small = small * 0.5 + 0.5 + small = torch.clamp(small * 255.0, max=255.0) packed = torch.cat([small, max_idx.unsqueeze(-1)], dim=-1) new_fields[field_name] = Field(packed, parentOp) From 479c4c3f3953783bdc52b4c08e722d3babd8caf1 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Mon, 16 Jun 2025 11:42:57 +0200 Subject: [PATCH 09/24] seperate quaternion packing and brightness increase --- src/ffsplat/conf/format/SOG-canvas.yaml | 38 +++++++++++++++---- src/ffsplat/models/transformations.py | 49 +++++++++++++++++++------ 2 files changed, 69 insertions(+), 18 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index ce8a55a..1c2bb5c 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -120,6 +120,34 @@ ops: transforms: - reparametize: method: pack_dynamic + to_fields_with_prefix: quaternions_packed_ + + - input_fields: [quaternions_packed_values] + transforms: + - multiply_add: + multiply: sqrt2 + add: 0 + clamp: false + - multiply_add: + multiply: 0.5 + add: 0.5 + clamp: false + - multiply_add: # could be clipped to 0,1 here, quantization minmax to 0-255 range, but sogs clips after * 255 + multiply: 255 + add: 0 + clamp: true + min: 0 + max: 255 + + - input_fields: [quaternions_packed_values, quaternions_packed_indices] + transforms: + - combine: + method: concat + dim: 2 + to_field: quaternions + + - input_fields: [quaternions] + transforms: - simple_quantize: min_values: [0, 0, 0, 252] max_values: [255, 255, 255, 255] @@ -168,7 +196,7 @@ ops: shape: [-1, 45] # 3*15 spherical harmonics - cluster: method: 
kmeans - num_clusters: 65536 #TODO: is dynamic in sogs + num_clusters: 65536 #is dynamic in sogs distance: manhattan to_fields_with_prefix: shN_ @@ -196,7 +224,7 @@ ops: - input_fields: [shN_centroids] transforms: - reshape: - shape: [-1, 960, 3] # int(num_clusters*num_spherical_harmonics/3 = + shape: [-1, 960, 3] # int(num_clusters*num_spherical_harmonics/3) = - input_fields: [shN_centroids_indices] transforms: @@ -314,8 +342,4 @@ ops: quality: 100 method: 6 exact: true -#Done: means,scales, opacities, sh0, shN, lexsort, argsort, output webp, -#TODO: dynamic number of clusters -#TODO: quaternions -# check reshaping and transposes -# TODO: side_len grid reshaping?-> doen by sorting topk already?, meta-data sogs, meta-data ffsplat +#TODO: decoding, output of metadata diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index ad1f5a5..ff502b5 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -382,6 +382,38 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: return dynamic_params_config +class MultiplyAdd(Transformation): + @staticmethod + @override + def apply( + params: dict[str, Any], parentOp: "Operation", verbose: bool = False + ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + input_fields = parentOp.input_fields + + field_name = next(iter(input_fields.keys())) + field_data = input_fields[field_name].data + + new_fields: dict[str, Field] = {} + decoding_update: list[dict[str, Any]] = [] + + match params: + case {"multiply": mul, "add": add, "clamp": clamp_}: + if mul == "sqrt2": + mul = torch.sqrt(torch.tensor(2.0, device=field_data.device, dtype=field_data.dtype)) + field_data = field_data * mul + add + if clamp_: + match params: + case {"min": min_, "max": max_}: + field_data = field_data.clamp(min=min_, max=max_) + case _: + raise ValueError("Clamp parameters must contain 'min' and 'max' values") + case _: + raise ValueError(f"Unknown MulAdd 
parameters: {params}") + + new_fields[field_name] = Field(field_data, parentOp) + return new_fields, decoding_update + + class Reparametrize(Transformation): @staticmethod @override @@ -412,7 +444,7 @@ def apply( # "dim": dim, # } # }) - case {"method": "pack_dynamic"}: + case {"method": "pack_dynamic", "to_fields_with_prefix": to_fields_with_prefix}: abs_field_data = field_data.abs() max_idx = abs_field_data.argmax(dim=-1) @@ -432,19 +464,13 @@ def apply( # select the appropriate 3-vector based on max_idx idx_exp = max_idx.unsqueeze(-1).unsqueeze(-1).expand(*max_idx.shape, 1, 3) - small = torch.gather(stacked, dim=-2, index=idx_exp).squeeze(-2) # (...,3) + values = torch.gather(stacked, dim=-2, index=idx_exp).squeeze(-2) # (...,3) - max_idx = max_idx.to(small.dtype) + max_idx = max_idx.to(values.dtype).unsqueeze(-1) - # TODO: this should be done as remapping - # scale by sqrt(2) to normalize range to [-1,1] # Apply brightness increase by multiplying by sqrt(2) - small = small * torch.sqrt(torch.tensor(2.0, device=small.device, dtype=small.dtype)) - # Ensure the brightness increase is preserved after SimpleQuantize - small = small * 0.5 + 0.5 - small = torch.clamp(small * 255.0, max=255.0) - packed = torch.cat([small, max_idx.unsqueeze(-1)], dim=-1) - new_fields[field_name] = Field(packed, parentOp) + new_fields[f"{to_fields_with_prefix}indices"] = Field(max_idx, parentOp) + new_fields[f"{to_fields_with_prefix}values"] = Field(values, parentOp) case _: raise ValueError(f"Unknown ToField parameters: {params}") @@ -1243,6 +1269,7 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "flatten": Flatten, "reshape": Reshape, "remapping": Remapping, + "multiply_add": MultiplyAdd, "reparametize": Reparametrize, "to_field": ToField, "permute": Permute, From 1a18a3b156484f65c16ea4ceb906708cb1416c46 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Mon, 16 Jun 2025 16:55:22 +0200 Subject: [PATCH 10/24] decoding multiply-add --- 
src/ffsplat/conf/format/SOG-canvas.yaml | 19 ++--- src/ffsplat/models/transformations.py | 104 +++++++++++++++++++++--- 2 files changed, 100 insertions(+), 23 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 1c2bb5c..c78f01d 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -124,18 +124,17 @@ ops: - input_fields: [quaternions_packed_values] transforms: - - multiply_add: - multiply: sqrt2 - add: 0 - clamp: false - - multiply_add: + - linear: + method: multiply + value: sqrt2 + - linear: + method: multiply-add multiply: 0.5 add: 0.5 - clamp: false - - multiply_add: # could be clipped to 0,1 here, quantization minmax to 0-255 range, but sogs clips after * 255 - multiply: 255 - add: 0 - clamp: true + - linear: + method: multiply + value: 255 + - clamp: min: 0 max: 255 diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index ff502b5..45a04c8 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -382,7 +382,7 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: return dynamic_params_config -class MultiplyAdd(Transformation): +class Linear(Transformation): @staticmethod @override def apply( @@ -397,18 +397,95 @@ def apply( decoding_update: list[dict[str, Any]] = [] match params: - case {"multiply": mul, "add": add, "clamp": clamp_}: - if mul == "sqrt2": - mul = torch.sqrt(torch.tensor(2.0, device=field_data.device, dtype=field_data.dtype)) - field_data = field_data * mul + add - if clamp_: - match params: - case {"min": min_, "max": max_}: - field_data = field_data.clamp(min=min_, max=max_) - case _: - raise ValueError("Clamp parameters must contain 'min' and 'max' values") + case {"method": "multiply-add", "multiply": multiply_, "add": add_}: + if multiply_ == "sqrt2": + multiply_ = torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, 
device=field_data.device)) + # add = torch.tensor([add],dtype=field_data.dtype,device=field_data.device) + field_data = field_data * multiply_ + add_ + multiply_ = multiply_.item() if isinstance(multiply_, Tensor) else multiply_ + # add = add.item() if isinstance(add, Tensor) else add + decoding_update.append({ + "input_fields": [field_name], + "transforms": [ + { + "linear": { + "method": "add-multiply", + "add": -add_, + "multiply": 1.0 / multiply_ if multiply_ != 0 else 0, + } + } + ], + }) + case {"method": "add-multiply", "add": add_, "multiply": multiply_}: + if multiply_ == "sqrt2": + multiply_ = torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) + field_data = (field_data + add_) * multiply_ + multiply_ = multiply_.item() if isinstance(multiply_, Tensor) else multiply_ + decoding_update.append({ + "input_fields": [field_name], + "transforms": [ + { + "linear": { + "method": "multiply-add", + "multiply": 1.0 / multiply_ if multiply_ != 0 else 0, + "add": -add_, + } + } + ], + }) + case {"method": "add", "value": value}: + field_data = field_data + value + decoding_update.append({ + "input_fields": [field_name], + "transforms": [ + { + "linear": { + "method": "add", + "value": -value, + } + } + ], + }) + case {"method": "multiply", "value": value}: + if value == "sqrt2": + value = torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) + field_data = field_data * value + value = value.item() if isinstance(value, Tensor) else value + decoding_update.append({ + "input_fields": [field_name], + "transforms": [ + { + "linear": { + "method": "multiply", + "value": 1.0 / value if value != 0 else 0, + } + } + ], + }) + + new_fields[field_name] = Field(field_data, parentOp) + return new_fields, decoding_update + + +class Clamp(Transformation): # this is lossy, decoding can't recover values outside the range + @staticmethod + @override + def apply( + params: dict[str, Any], parentOp: "Operation", verbose: bool = 
False + ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + input_fields = parentOp.input_fields + + field_name = next(iter(input_fields.keys())) + field_data = input_fields[field_name].data + + new_fields: dict[str, Field] = {} + decoding_update: list[dict[str, Any]] = [] + + match params: + case {"min": min_val, "max": max_val}: + field_data = field_data.clamp(min=min_val, max=max_val) case _: - raise ValueError(f"Unknown MulAdd parameters: {params}") + raise ValueError(f"Unknown Clamp parameters: {params}") new_fields[field_name] = Field(field_data, parentOp) return new_fields, decoding_update @@ -1269,7 +1346,8 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "flatten": Flatten, "reshape": Reshape, "remapping": Remapping, - "multiply_add": MultiplyAdd, + "linear": Linear, + "clamp": Clamp, "reparametize": Reparametrize, "to_field": ToField, "permute": Permute, From 9cda62a6207e7f1a8fc23b0741f5714173636a12 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Mon, 16 Jun 2025 16:57:26 +0200 Subject: [PATCH 11/24] decoding packed quaternions --- src/ffsplat/models/transformations.py | 74 +++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 45a04c8..932c00d 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -514,16 +514,9 @@ def apply( field_data[sign_mask] *= -1 new_fields[field_name] = Field(field_data, parentOp) - # field_data = field_data.narrow(dim, 0, field_data.shape[dim] - 1) - # field_data = field_data.narrow(dim, 1,field_data.shape[dim] - 1) - # decoding_update.append({ - # "unit_sphere_recover_last": { - # "dim": dim, - # } - # }) - case {"method": "pack_dynamic", "to_fields_with_prefix": to_fields_with_prefix}: + case {"method": "pack_dynamic", "to_fields_with_prefix": to_fields_with_prefix, "dim": dim}: abs_field_data = field_data.abs() - max_idx = 
abs_field_data.argmax(dim=-1) + max_idx = abs_field_data.argmax(dim=dim) # ensure largest component is positive max_vals = field_data.gather(-1, max_idx.unsqueeze(-1)).squeeze(-1) @@ -548,6 +541,69 @@ def apply( # Apply brightness increase by multiplying by sqrt(2) new_fields[f"{to_fields_with_prefix}indices"] = Field(max_idx, parentOp) new_fields[f"{to_fields_with_prefix}values"] = Field(values, parentOp) + decoding_update.append({ + "input_fields": [f"{to_fields_with_prefix}indices", f"{to_fields_with_prefix}values"], + "transforms": [ + { + "reparametize": { + "method": "unpack_dynamic", + "from_fields_with_prefix": to_fields_with_prefix, + "dim": -1, + "to_field_name": field_name, + } + } + ], + }) + + case { + "method": "unpack_dynamic", + "from_fields_with_prefix": from_fields_with_prefix, + "dim": dim, + "to_field_name": to_field_name, + }: + # Retrieve the indices and values fields + indices_field = input_fields[f"{from_fields_with_prefix}indices"].data + values_field = input_fields[f"{from_fields_with_prefix}values"].data + + if values_field is None: + raise ValueError("values field data is None before unit sphere recovery") + if values_field.shape[dim] != 3: + raise ValueError(f"Field data shape mismatch for unit sphere recovery: {field_data.shape}") + + # Ensure indices are integers + indices_field = indices_field.to(torch.int64).squeeze(-1) + + # Compute squared norm of the partial vector + partial_norm_sq = (values_field**2).sum(dim=dim, keepdim=True) + + # Recover the missing component (always non-negative) + w = torch.sqrt(torch.clamp(1.0 - partial_norm_sq, min=0.0)) + + # Create a tensor to hold the reconstructed data + num_components = values_field.shape[-1] + 1 # Original data had one more component + reconstructed = torch.zeros( + *values_field.shape[:-1], num_components, dtype=values_field.dtype, device=values_field.device + ) + + # Scatter the values back into the reconstructed tensor + reconstructed.scatter_(-1, indices_field.unsqueeze(-1), 
src=w) + + # Dynamically place values_field into the indices not covered by indices_field + full_indices = torch.arange(num_components, device=values_field.device).view(1, -1) + + indices_field_exp = indices_field.unsqueeze(-1) + + # correct positions of values + value_mask = (full_indices != indices_field_exp).to(values_field.dtype) + + # Set values_field into the masked slots + value_positions = value_mask.bool() + reconstructed[value_positions] = values_field.flatten() + + # Optional re-normalization to mitigate numerical drift + field_data = reconstructed / torch.linalg.norm(reconstructed, dim=dim, keepdim=True) + + new_fields[to_field_name] = Field(field_data, parentOp) case _: raise ValueError(f"Unknown ToField parameters: {params}") From 6cf77ec380b7867fd75c15156709c419e3b579df Mon Sep 17 00:00:00 2001 From: fleischmann Date: Thu, 19 Jun 2025 15:33:43 +0200 Subject: [PATCH 12/24] simplify quaternion-remapping --- pyproject.toml | 1 + src/ffsplat/conf/format/SOG-canvas.yaml | 24 ++-- src/ffsplat/models/transformations.py | 150 ++++++------------------ 3 files changed, 51 insertions(+), 124 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f69210..f8ea5a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,6 +158,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "tests/*" = ["S101"] +"src/ffsplat/models/transformations.py"=["C901"] [tool.ruff.format] preview = true diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index c78f01d..b2c80cc 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -121,20 +121,22 @@ ops: - reparametize: method: pack_dynamic to_fields_with_prefix: quaternions_packed_ + dim: -1 + + #- input_fields: [quaternions_packed_values,quaternions_packed_indices] + #transforms: + #- reparametize: + #method: unpack_dynamic + #dim: 2 + #from_fields_with_prefix: quaternions_packed_ + #to_field_name: quaternions_new - input_fields: 
[quaternions_packed_values] transforms: - - linear: - method: multiply - value: sqrt2 - - linear: - method: multiply-add - multiply: 0.5 - add: 0.5 - - linear: - method: multiply - value: 255 - - clamp: + - remapping: + method: scale-sqrt2 + - remapping: + method: "minmax" min: 0 max: 255 diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 932c00d..0f2abe2 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -220,7 +220,9 @@ def apply( "to_field_list": to_field_list, }: chunks = field_data.split(split_size_or_sections, dim) - for target_field_name, chunk in zip(to_field_list, chunks, strict=False): + for target_field_name, chunk in zip(to_field_list, chunks): + if target_field_name == "_": + continue if squeeze: chunk = chunk.squeeze(dim) new_fields[target_field_name] = Field(chunk, parentOp) @@ -280,6 +282,20 @@ def apply( new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] match params: + case {"method": "scale-sqrt2"}: + scale = torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) + field_data = field_data * scale + decoding_update.append({ + "input_fields": [field_name], + "transforms": [{"remapping": {"method": "scale-inverse-sqrt2"}}], + }) + case {"method": "scale-inverse-sqrt2"}: + scale = 1.0 / torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) + field_data = field_data * scale + decoding_update.append({ + "input_fields": [field_name], + "transforms": [{"remapping": {"method": "scale-sqrt2"}}], + }) case {"method": "exp"}: field_data = torch.exp(field_data) case {"method": "sigmoid"}: @@ -364,6 +380,9 @@ def apply( normalized = normalized * (max_val_f - min_val_f) + min_val_f field_data = normalized + + case {"method": "canvas", "min": min_val, "max": max_val}: + field_data = field_data.clamp(min=min_val, max=max_val) case _: raise ValueError(f"Unknown remapping parameters: {params}") 
new_fields[field_name] = Field(field_data, parentOp) @@ -382,115 +401,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: return dynamic_params_config -class Linear(Transformation): - @staticmethod - @override - def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False - ) -> tuple[dict[str, Field], list[dict[str, Any]]]: - input_fields = parentOp.input_fields - - field_name = next(iter(input_fields.keys())) - field_data = input_fields[field_name].data - - new_fields: dict[str, Field] = {} - decoding_update: list[dict[str, Any]] = [] - - match params: - case {"method": "multiply-add", "multiply": multiply_, "add": add_}: - if multiply_ == "sqrt2": - multiply_ = torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) - # add = torch.tensor([add],dtype=field_data.dtype,device=field_data.device) - field_data = field_data * multiply_ + add_ - multiply_ = multiply_.item() if isinstance(multiply_, Tensor) else multiply_ - # add = add.item() if isinstance(add, Tensor) else add - decoding_update.append({ - "input_fields": [field_name], - "transforms": [ - { - "linear": { - "method": "add-multiply", - "add": -add_, - "multiply": 1.0 / multiply_ if multiply_ != 0 else 0, - } - } - ], - }) - case {"method": "add-multiply", "add": add_, "multiply": multiply_}: - if multiply_ == "sqrt2": - multiply_ = torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) - field_data = (field_data + add_) * multiply_ - multiply_ = multiply_.item() if isinstance(multiply_, Tensor) else multiply_ - decoding_update.append({ - "input_fields": [field_name], - "transforms": [ - { - "linear": { - "method": "multiply-add", - "multiply": 1.0 / multiply_ if multiply_ != 0 else 0, - "add": -add_, - } - } - ], - }) - case {"method": "add", "value": value}: - field_data = field_data + value - decoding_update.append({ - "input_fields": [field_name], - "transforms": [ - { - "linear": { - "method": "add", - 
"value": -value, - } - } - ], - }) - case {"method": "multiply", "value": value}: - if value == "sqrt2": - value = torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) - field_data = field_data * value - value = value.item() if isinstance(value, Tensor) else value - decoding_update.append({ - "input_fields": [field_name], - "transforms": [ - { - "linear": { - "method": "multiply", - "value": 1.0 / value if value != 0 else 0, - } - } - ], - }) - - new_fields[field_name] = Field(field_data, parentOp) - return new_fields, decoding_update - - -class Clamp(Transformation): # this is lossy, decoding can't recover values outside the range - @staticmethod - @override - def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False - ) -> tuple[dict[str, Field], list[dict[str, Any]]]: - input_fields = parentOp.input_fields - - field_name = next(iter(input_fields.keys())) - field_data = input_fields[field_name].data - - new_fields: dict[str, Field] = {} - decoding_update: list[dict[str, Any]] = [] - - match params: - case {"min": min_val, "max": max_val}: - field_data = field_data.clamp(min=min_val, max=max_val) - case _: - raise ValueError(f"Unknown Clamp parameters: {params}") - - new_fields[field_name] = Field(field_data, parentOp) - return new_fields, decoding_update - - class Reparametrize(Transformation): @staticmethod @override @@ -1045,6 +955,7 @@ def apply( ) -> tuple[dict[str, Field], list[dict[str, Any]]]: new_fields: dict[str, Field] = {} field_data: Tensor = torch.empty(0) + decoding_update: list[dict[str, Any]] = [] match params: case {"method": "bytes", "to_field": to_field_name}: @@ -1080,13 +991,28 @@ def apply( zeros = torch.zeros(tensors[0].shape, dtype=tensors[0].dtype, device=tensors[0].device) tensors.append(zeros) field_data = torch.stack(tensors, dim=dim) + decoding_update.append({ + "input_fields": [to_field_name], + "transforms": [ + { + "split": { + "split_size_or_sections": [ + t.shape[dim] if dim < 
len(t.shape) else 1 for t in tensors + ] + + [1], + "dim": dim, + "to_field_list": [*list(parentOp.input_fields), "_"], + } + } + ], + }) else: raise ValueError(f"Unsupported combine method: {method}") new_fields[to_field_name] = Field(field_data, parentOp) case _: raise ValueError(f"Unknown Combine parameters: {params}") - return new_fields, [] + return new_fields, decoding_update class Lookup(Transformation): @@ -1402,8 +1328,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "flatten": Flatten, "reshape": Reshape, "remapping": Remapping, - "linear": Linear, - "clamp": Clamp, "reparametize": Reparametrize, "to_field": ToField, "permute": Permute, From a4a9cc5ee81b7ad179c2bac1745f545151e133ff Mon Sep 17 00:00:00 2001 From: fleischmann Date: Tue, 24 Jun 2025 23:12:00 +0200 Subject: [PATCH 13/24] meta output canvas intermediate --- src/ffsplat/coding/scene_decoder.py | 6 +- src/ffsplat/coding/scene_encoder.py | 27 ++++-- src/ffsplat/conf/format/SOG-canvas.yaml | 20 ++--- src/ffsplat/models/operations.py | 4 +- src/ffsplat/models/transformations.py | 104 +++++++++++++++++++----- 5 files changed, 115 insertions(+), 46 deletions(-) diff --git a/src/ffsplat/coding/scene_decoder.py b/src/ffsplat/coding/scene_decoder.py index 99d294a..c975ecc 100644 --- a/src/ffsplat/coding/scene_decoder.py +++ b/src/ffsplat/coding/scene_decoder.py @@ -1,3 +1,4 @@ +import json from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path @@ -47,6 +48,9 @@ def with_input_path(self, input_path: Path) -> "DecodingParams": self.ops[0]["transforms"][0]["read_file"]["file_path"] = str(input_path) return self + def get_ops_hashable(self) -> str: + return json.dumps(self.ops, sort_keys=False) + @lru_cache def process_operation( @@ -56,7 +60,7 @@ def process_operation( """Process the operation and return the new fields and decoding updates.""" if verbose: print(f"Decoding {op}...") - return op.apply(verbose=verbose)[0] + return 
op.apply(verbose=verbose, decoding_params_hashable="")[0] @dataclass diff --git a/src/ffsplat/coding/scene_encoder.py b/src/ffsplat/coding/scene_encoder.py index 427603c..e18216e 100644 --- a/src/ffsplat/coding/scene_encoder.py +++ b/src/ffsplat/coding/scene_encoder.py @@ -1,4 +1,5 @@ import copy +import json from collections import defaultdict from collections.abc import Iterable from dataclasses import asdict, dataclass, field, is_dataclass @@ -92,6 +93,18 @@ def reverse_ops(self) -> None: continue self.ops[-1]["transforms"].reverse() + def get_ops_hashable(self) -> str: + return json.dumps(self.ops, sort_keys=False) + + def to_yaml(self) -> str: + """Convert the decoding parameters to a YAML string.""" + return yaml.dump( + asdict(self), + Dumper=SerializableDumper, + default_flow_style=False, + sort_keys=False, + ) + @dataclass class EncodingParams: @@ -128,14 +141,11 @@ def to_yaml_file(self, yaml_path: Path) -> None: @lru_cache -def process_operation( - op: Operation, - verbose: bool, -) -> tuple[dict[str, Field], list[dict[str, Any]]]: +def process_operation(op: Operation, verbose: bool, decoding_ops: str) -> tuple[dict[str, Field], list[dict[str, Any]]]: """Process the operation and return the new fields and decoding updates.""" if verbose: print(f"Encoding {op}...") - return op.apply(verbose=verbose) + return op.apply(verbose=verbose, decoding_params_hashable=decoding_ops) @dataclass @@ -154,10 +164,9 @@ def _encode_fields(self, verbose: bool) -> None: input_fields_params = op_params["input_fields"] for transform_param in op_params["transforms"]: op = Operation.from_json(input_fields_params, transform_param, self.fields, self.output_path) - if op.transform_type != "write_file": - new_fields, decoding_updates = process_operation(op, verbose=verbose) - else: - new_fields, decoding_updates = op.apply(verbose=verbose) + new_fields, decoding_updates = process_operation( + op, verbose=verbose, decoding_ops=self.decoding_params.to_yaml() + ) # , 
decoding_ops=self.decoding_params.ops) # if the coding_updates are not a copy the cache will be wrong for decoding_update in copy.deepcopy(decoding_updates): # if the last decoding update has the same input fields we can combine the transforms into one list diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index b2c80cc..283497b 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -145,9 +145,9 @@ ops: - combine: method: concat dim: 2 - to_field: quaternions + to_field: quats - - input_fields: [quaternions] + - input_fields: [quats] transforms: - simple_quantize: min_values: [0, 0, 0, 252] @@ -300,7 +300,7 @@ ops: method: 6 exact: true - - input_fields: [quaternions] #rgba + - input_fields: [quats] #rgba transforms: - write_file: type: image @@ -311,7 +311,7 @@ ops: method: 6 exact: true - - input_fields: [sh0] #rgba + - input_fields: [shN_centroids] #rgb transforms: - write_file: type: image @@ -322,7 +322,7 @@ ops: method: 6 exact: true - - input_fields: [shN_centroids] #rgb + - input_fields: [shN_labels] #rgb , result is suprisingly less noisy than in sogs transforms: - write_file: type: image @@ -333,14 +333,8 @@ ops: method: 6 exact: true - - input_fields: [shN_labels] #rgb , result is suprisingly less noisy than in sogs + - input_fields: [means, scales, quats, sh0, shN_centroids, shN_labels] transforms: - write_file: - type: image - image_codec: webp - coding_params: - lossless: true - quality: 100 - method: 6 - exact: true + type: canvas-metadata #TODO: decoding, output of metadata diff --git a/src/ffsplat/models/operations.py b/src/ffsplat/models/operations.py index 7f3071e..73c801b 100644 --- a/src/ffsplat/models/operations.py +++ b/src/ffsplat/models/operations.py @@ -69,5 +69,5 @@ def to_json(self) -> dict[str, Any]: "params": self.params, } - def apply(self, verbose: bool) -> tuple[dict[str, "Field"], list[dict[str, Any]]]: - return apply_transform(self, 
verbose=verbose) + def apply(self, verbose: bool, decoding_params_hashable: str) -> tuple[dict[str, "Field"], list[dict[str, Any]]]: + return apply_transform(self, verbose=verbose, decoding_params_hashable=decoding_params_hashable) diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 0f2abe2..529966d 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -1,3 +1,4 @@ +import json import math from abc import ABC, abstractmethod from collections.abc import Callable @@ -8,6 +9,7 @@ import cv2 import numpy as np import torch +import yaml from PIL import Image from pillow_heif import register_avif_opener # type: ignore[import-untyped] from plas import sort_with_plas # type: ignore[import-untyped] @@ -86,6 +88,13 @@ def write_image(output_file_path: Path, field_data: Tensor, file_type: str, codi raise ValueError(f"Unsupported file type: {file_type}") +def write_json(output_file_path: Path, data: dict[str, Any]) -> None: + """Write a dictionary to a JSON file.""" + + with open(output_file_path, "w") as f: + json.dump(data, f, indent=4) + + @dataclass class PLASConfig: """Configuration for PLAS sorting.""" @@ -121,7 +130,7 @@ class Transformation(ABC): @staticmethod @abstractmethod def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: """Apply the transformation to the input fields. returns new/updated fields and decoding updates. 
Transformations that are only available for decoding do return empty decoding updates""" @@ -137,7 +146,7 @@ class Cluster(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: # Implement the clustering logic here input_fields = parentOp.input_fields @@ -203,7 +212,7 @@ class Split(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -272,7 +281,7 @@ class Remapping(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -405,7 +414,7 @@ class Reparametrize(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -525,7 +534,7 @@ class ToField(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -551,7 +560,7 @@ class Flatten(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", 
verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -583,7 +592,7 @@ class Reshape(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -613,7 +622,7 @@ class Permute(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -640,7 +649,7 @@ class ToDType(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -681,7 +690,7 @@ class SplitBytes(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -733,7 +742,7 @@ class Reindex(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -761,7 +770,7 @@ class Sort(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", 
verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -881,7 +890,7 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: new_fields: dict[str, Field] = {} @@ -951,7 +960,7 @@ class Combine(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: new_fields: dict[str, Field] = {} field_data: Tensor = torch.empty(0) @@ -1019,7 +1028,7 @@ class Lookup(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -1038,8 +1047,12 @@ class WriteFile(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], + parentOp: "Operation", + verbose: bool = False, + **kwargs: Any, ) -> tuple[dict[str, Field], list[dict[str, Any]]]: + decoding_ops: list[dict[str, Any]] = kwargs.get("decoding_ops", []) decoding_update: list[dict[str, Any]] = [] match params: case {"type": "ply", "file_path": file_path, "base_path": base_path, "field_prefix": field_prefix}: @@ -1079,6 +1092,47 @@ def apply( } ], }) + case {"type": "canvas-metadata", "base_path": base_path}: + file_path = "meta.json" + output_file_path = Path(base_path) / Path(file_path) + field_names = list(parentOp.input_fields.keys()) + meta = {} + + for 
field_name in field_names: + field = parentOp.input_fields[field_name] + shape = field.data.shape + shape_meta = [int(np.array(shape[:-1]).prod()), shape[-1]] + meta[field_name] = {"shape": shape_meta, "files": []} + # decoding_ops = decoding_ops['ops'] + for op in decoding_ops: + transfroms_str = list(op["transforms"][0].keys()) + transform_types = [transformation_map[transform] for transform in transfroms_str] + # transform = list(op['transforms'][0].keys())[0] + if len(op["input_fields"]) == 0: + continue + input_field = op["input_fields"][0] + has_field_name_prefix = input_field.startswith(field_name) + # outputfiles + for t_str, t in zip(transfroms_str, transform_types): + if (t is WriteFile) and has_field_name_prefix: + codec = op["transforms"][0][t_str].get("image_codec") + # meta[field_name]["files"].append([f"{input_field}.{codec}"]) + # mins, maxs, dtype, in index + elif (t is SimpleQuantize) and input_field == field_name: + pass + # mins = [0] + # maxs = [0] + # dtype = "hell0" + # meta[field_name]["mins"] = mins + # meta[field_name]["maxs"] = maxs + # meta[field_name]["dtype"] = dtype + + # get original_input names -> via prefix + # assign output file names via prefix + # get pre quantization mins and maxs via ??? 
+ + write_json(output_file_path, meta) + case _: raise ValueError(f"Unknown WriteFile parameters: {params}") @@ -1176,7 +1230,7 @@ class ReadFile(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] @@ -1206,7 +1260,7 @@ class SimpleQuantize(Transformation): @staticmethod @override def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False + params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any ) -> tuple[dict[str, Field], list[dict[str, Any]]]: input_fields = parentOp.input_fields @@ -1344,11 +1398,19 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: } -def apply_transform(parentOp: "Operation", verbose: bool) -> tuple[dict[str, "Field"], list[dict[str, Any]]]: +def apply_transform( + parentOp: "Operation", verbose: bool, decoding_params_hashable: str +) -> tuple[dict[str, "Field"], list[dict[str, Any]]]: transformation = transformation_map.get(parentOp.transform_type) if transformation is None: raise ValueError(f"Unknown transformation: {parentOp.transform_type}") - return transformation.apply(parentOp.params[parentOp.transform_type], parentOp, verbose) + elif transformation is WriteFile: + decoding_ops: list[dict[str, Any]] = yaml.load(decoding_params_hashable, Loader=yaml.SafeLoader)["ops"] + return transformation.apply( + parentOp.params[parentOp.transform_type], parentOp, verbose=verbose, decoding_ops=decoding_ops + ) + else: + return transformation.apply(parentOp.params[parentOp.transform_type], parentOp, verbose=verbose) def get_dynamic_params(params: dict[str, dict[str, Any]]) -> list[dict[str, Any]]: From abeab708e58c93748ff6f6ddb8cffac9d4f55234 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Thu, 
26 Jun 2025 16:31:26 +0200 Subject: [PATCH 14/24] meta output canvas --- pyproject.toml | 2 +- src/ffsplat/conf/format/SOG-canvas.yaml | 31 +++++--- src/ffsplat/models/transformations.py | 100 +++++++++++++++--------- 3 files changed, 82 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f8ea5a6..add9078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "jsonargparse>=4.37.0,<5.0.0", # TODO imagecodecs version > 2023.9.18 produce much larger JPEG XL files: https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/issues/3 # but for numpy > 2, we require a modern imagecodecs - "imagecodecs[all]>=2024.12.30", + #"imagecodecs[all]>=2024.12.30", "PyYAML>=6.0.2,<7.0.0", "opencv-python>=4.11.0.86,<5.0.0.0", "pillow>=11.1.0,<12.0.0", diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 283497b..a93c955 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -18,7 +18,6 @@ ops: split_size_or_sections: [1, 15] dim: 1 squeeze: true #false - #TODO: in sogs transpose (1,2) is done after reading from .ply here not? 
- input_fields: [means] transforms: @@ -116,21 +115,18 @@ ops: src_field: quaternions index_field: sorted_indices - - input_fields: [quaternions] #pack quaternions + - input_fields: [quaternions] # for tracking decoding_ops + transforms: + - to_field: + to_field_name: quats + + - input_fields: [quats] transforms: - reparametize: method: pack_dynamic to_fields_with_prefix: quaternions_packed_ dim: -1 - #- input_fields: [quaternions_packed_values,quaternions_packed_indices] - #transforms: - #- reparametize: - #method: unpack_dynamic - #dim: 2 - #from_fields_with_prefix: quaternions_packed_ - #to_field_name: quaternions_new - - input_fields: [quaternions_packed_values] transforms: - remapping: @@ -311,6 +307,17 @@ ops: method: 6 exact: true + - input_fields: [sh0] #rgb + transforms: + - write_file: + type: image + image_codec: webp + coding_params: + lossless: true + quality: 100 + method: 6 + exact: true + - input_fields: [shN_centroids] #rgb transforms: - write_file: @@ -333,8 +340,8 @@ ops: method: 6 exact: true - - input_fields: [means, scales, quats, sh0, shN_centroids, shN_labels] + - input_fields: [means, scales, quats, sh0, shN] transforms: - write_file: type: canvas-metadata -#TODO: decoding, output of metadata +#TODO: decoding, diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 529966d..7d8adf0 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -92,7 +92,7 @@ def write_json(output_file_path: Path, data: dict[str, Any]) -> None: """Write a dictionary to a JSON file.""" with open(output_file_path, "w") as f: - json.dump(data, f, indent=4) + json.dump(data, f, indent=2) @dataclass @@ -1096,40 +1096,51 @@ def apply( file_path = "meta.json" output_file_path = Path(base_path) / Path(file_path) field_names = list(parentOp.input_fields.keys()) - meta = {} - + meta: dict[str, Any] = {} + # Readfile no input_fields for field_name in field_names: field = 
parentOp.input_fields[field_name] shape = field.data.shape shape_meta = [int(np.array(shape[:-1]).prod()), shape[-1]] - meta[field_name] = {"shape": shape_meta, "files": []} - # decoding_ops = decoding_ops['ops'] + if field_name == "sh0": + shape_meta = [shape_meta[0], 1, shape_meta[1]] + # to keep the order semi-consistent with original sogs + meta[field_name] = {"shape": shape_meta, "dtype": None, "mins": [], "maxs": [], "files": []} + for op in decoding_ops: - transfroms_str = list(op["transforms"][0].keys()) - transform_types = [transformation_map[transform] for transform in transfroms_str] - # transform = list(op['transforms'][0].keys())[0] - if len(op["input_fields"]) == 0: - continue - input_field = op["input_fields"][0] - has_field_name_prefix = input_field.startswith(field_name) - # outputfiles - for t_str, t in zip(transfroms_str, transform_types): - if (t is WriteFile) and has_field_name_prefix: - codec = op["transforms"][0][t_str].get("image_codec") - # meta[field_name]["files"].append([f"{input_field}.{codec}"]) - # mins, maxs, dtype, in index - elif (t is SimpleQuantize) and input_field == field_name: - pass - # mins = [0] - # maxs = [0] - # dtype = "hell0" - # meta[field_name]["mins"] = mins - # meta[field_name]["maxs"] = maxs - # meta[field_name]["dtype"] = dtype - - # get original_input names -> via prefix - # assign output file names via prefix - # get pre quantization mins and maxs via ??? 
+ transforms_str = [next(iter(t_str_braced.keys())) for t_str_braced in op["transforms"]] + transform_types = [transformation_map[transform] for transform in transforms_str] + input_field = op["input_fields"][0] if len(op["input_fields"]) > 0 else "" + for idx, (t_str, t) in enumerate(zip(transforms_str, transform_types)): + if t is ReadFile: + field_name_read = op["transforms"][idx][t_str].get("field_name") + if field_name_read.startswith(field_name): + codec = op["transforms"][idx][t_str].get("image_codec") + meta[field_name]["files"].append(f"{field_name_read}.{codec}") + elif ( + (t is SimpleQuantize) + and input_field.startswith(field_name) + and not input_field.endswith("labels") + ): + meta[field_name]["dtype"] = op["transforms"][idx][t_str]["dtype"] + meta[field_name]["mins"] = op["transforms"][idx][t_str]["min_values"] + meta[field_name]["maxs"] = op["transforms"][idx][t_str]["max_values"] + elif t is Reparametrize: + if op["transforms"][idx][t_str].get("method") == "unpack_dynamic": + to_field_name = op["transforms"][idx][t_str].get("to_field_name") + if to_field_name == field_name: + meta[to_field_name]["encoding"] = "quaternion_packed" + + # inconsitencies in sogs, that are hardcoded: + if field_name == "quats": + if "mins" in meta[field_name]: + del meta[field_name]["mins"] + if "maxs" in meta[field_name]: + del meta[field_name]["maxs"] + meta[field_name]["dtype"] = "uint8" # is also hardcoded in sogs + if field_name == "shN": + # because is 8 by default and it's the only parameter requiring encoding parameters of more than one step back + meta[field_name]["quantization"] = 8 write_json(output_file_path, meta) @@ -1269,6 +1280,13 @@ def apply( new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] + torch_dtype_to_str = { + torch.float32: "float32", + torch.uint8: "uint8", + torch.uint16: "uint16", + torch.uint32: "uint32", + torch.int32: "int32", + } match params: case {"min": min_val, "max": max_val, "dim": dim, "dtype": 
dtype_str, "round_to_int": round_to_int}: min_val_f = float(min_val) @@ -1285,14 +1303,6 @@ def apply( field_data = normalized - torch_dtype_to_str = { - torch.float32: "float32", - torch.uint8: "uint8", - torch.uint16: "uint16", - torch.uint32: "uint32", - torch.int32: "int32", - } - decoding_update.append({ "input_fields": [field_name], "transforms": [ @@ -1337,6 +1347,20 @@ def apply( field_data = minmax(field_data) field_data = field_data * field_range + min_tensor + decoding_update.append({ + "input_fields": [field_name], + "transforms": [ + { + "simple_quantize": { + "min_values": min_tensor.tolist(), + "max_values": max_tensor.tolist(), + "dim": dim, + "dtype": torch_dtype_to_str[field_data.dtype], + } + } + ], + }) + if round_to_int: field_data = torch.round(field_data) field_data = convert_to_dtype(field_data, dtype_str) From 33883313f43ac8a3f4707233869abd3ea89fa016 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Mon, 30 Jun 2025 14:10:45 +0200 Subject: [PATCH 15/24] fixed decoding updates for live --- src/ffsplat/conf/format/SOG-canvas.yaml | 20 +++++-- src/ffsplat/models/transformations.py | 80 ++++++++++++++++--------- src/ffsplat/render/viewer.py | 9 +++ 3 files changed, 74 insertions(+), 35 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index a93c955..8ff7f77 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -17,7 +17,7 @@ ops: to_field_list: [sh0, shN] #in sogs f_dc=sh0, frest=shN split_size_or_sections: [1, 15] dim: 1 - squeeze: true #false + squeeze: false - input_fields: [means] transforms: @@ -44,7 +44,7 @@ ops: - input_fields: [means, sh0, shN, opacities, scales, quaternions] transforms: - plas: - prune_by: opacities # in sogs purned by 0.5 of number gaussians instead + prune_by: opacities # in sogs pruned by 0.5 of number gaussians instead scaling_fn: none #standardize # activated: true shuffle: true @@ -170,6 +170,13 @@ ops: 
start_dim: 2 shape: [1] + - input_fields: [sh0] + transforms: + - permute: + dims: [0, 1, 3, 2] + - flatten: + start_dim: 2 + - input_fields: [sh0, opacities] #opacity-sh0 rgba transforms: - combine: @@ -186,11 +193,13 @@ ops: dtype: uint8 round_to_int: true - ##shN - input_fields: [shN] transforms: - - reshape: - shape: [-1, 45] # 3*15 spherical harmonics + - flatten: + start_dim: 2 + - flatten: + start_dim: 0 + end_dim: 1 - cluster: method: kmeans num_clusters: 65536 #is dynamic in sogs @@ -344,4 +353,3 @@ ops: transforms: - write_file: type: canvas-metadata -#TODO: decoding, diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 7d8adf0..0dd8aeb 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -490,7 +490,7 @@ def apply( raise ValueError(f"Field data shape mismatch for unit sphere recovery: {field_data.shape}") # Ensure indices are integers - indices_field = indices_field.to(torch.int64).squeeze(-1) + indices_field = indices_field.round().to(torch.int64).squeeze(-1) # Compute squared norm of the partial vector partial_norm_sq = (values_field**2).sum(dim=dim, keepdim=True) @@ -573,6 +573,11 @@ def apply( decoding_update: list[dict[str, Any]] = [] match params: case {"start_dim": start_dim, "end_dim": end_dim}: + target_shape = field_data.shape + decoding_update.append({ + "input_fields": [field_name], + "transforms": [{"reshape": {"shape": target_shape}}], + }) field_data = field_data.flatten(start_dim=start_dim, end_dim=end_dim) case {"start_dim": start_dim}: target_shape = field_data.shape[start_dim:] @@ -603,6 +608,10 @@ def apply( decoding_update: list[dict[str, Any]] = [] match params: case {"start_dim": start_dim, "shape": shape}: + decoding_update.append({ + "input_fields": [field_name], + "transforms": [{"reshape": {"shape": field_data.shape}}], + }) target_shape = list(field_data.shape[:start_dim]) + list(shape) field_data = field_data.reshape(*target_shape) case 
{"shape": shape}: @@ -752,14 +761,14 @@ def apply( match params: case {"src_field": src_field_name, "index_field": index_field_name}: index_field_obj = input_fields[index_field_name] - # if len(index_field_obj.data.shape) != 2: - # raise ValueError("Expecting grid for re-index operation") - decoding_update.append({ - "input_fields": [src_field_name], - "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], - }) original_data = input_fields[src_field_name].data new_fields[src_field_name] = Field(original_data[index_field_obj.data], parentOp) + + if index_field_obj.data.ndim > 1: + decoding_update.append({ + "input_fields": [src_field_name], + "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], + }) case _: raise ValueError(f"Unknown Reindex parameters: {params}") @@ -791,8 +800,6 @@ def apply( sorted_indices = torch.argsort(field_data) case _: raise ValueError(f"Unknown Sort parameters: {params}") - # sorted_indices = sorted_indeces.reshape(params.get("shape",(64,-1))) - # new_fields[field_name] = Field(field_data_sorted, parentOp) new_fields[params["to_field"]] = Field(sorted_indices, parentOp) # TODO: decoding update??? 
@@ -992,31 +999,39 @@ def apply(
             tensors: list[Tensor] = [
                 parentOp.input_fields[source_field_name].data for source_field_name in parentOp.input_fields
             ]
+            decoding_update_dict = {
+                "input_fields": [to_field_name],
+                "transforms": [
+                    {
+                        "split": {
+                            "dim": dim,
+                            "to_field_list": list(parentOp.input_fields),
+                        }
+                    }
+                ],
+            }
             if method == "stack":
                 field_data = torch.stack(tensors, dim=dim)
+                decoding_update_dict["transforms"][0]["split"]["squeeze"] = True
             elif method == "concat":
                 field_data = torch.cat(tensors, dim=dim)
+                decoding_update_dict["transforms"][0]["split"]["squeeze"] = False
             elif method == "stack-zeros":
                 zeros = torch.zeros(tensors[0].shape, dtype=tensors[0].dtype, device=tensors[0].device)
                 tensors.append(zeros)
                 field_data = torch.stack(tensors, dim=dim)
-                decoding_update.append({
-                    "input_fields": [to_field_name],
-                    "transforms": [
-                        {
-                            "split": {
-                                "split_size_or_sections": [
-                                    t.shape[dim] if dim < len(t.shape) else 1 for t in tensors
-                                ]
-                                + [1],
-                                "dim": dim,
-                                "to_field_list": [*list(parentOp.input_fields), "_"],
-                            }
-                        }
-                    ],
-                })
+                decoding_update_dict["transforms"][0]["split"]["squeeze"] = True
+                decoding_update_dict["transforms"][0]["split"]["to_field_list"] = [
+                    *list(parentOp.input_fields),
+                    "_",
+                ]
+
             else:
                 raise ValueError(f"Unsupported combine method: {method}")
+            decoding_update_dict["transforms"][0]["split"]["split_size_or_sections"] = [
+                t.shape[dim] if dim < len(t.shape) else 1 for t in tensors
+            ]
+            decoding_update.append(decoding_update_dict)
             new_fields[to_field_name] = Field(field_data, parentOp)
         case _:
             raise ValueError(f"Unknown Combine parameters: {params}")
@@ -1259,6 +1274,8 @@ def apply(
                 # TODO: only do this once?
register_avif_opener() img_field_data = torch.tensor(np.array(Image.open(file_path))) + case "webp": + img_field_data = torch.tensor(np.array(Image.open(file_path))) new_fields[field_name] = Field.from_file(img_field_data, file_path, field_name) case _: @@ -1308,10 +1325,11 @@ def apply( "transforms": [ { "simple_quantize": { - "min_values": min_vals.tolist(), - "max_values": max_vals.tolist(), + "min_values": min_vals.tolist() if min_vals.ndim > 0 else [min_vals.item()], + "max_values": max_vals.tolist() if max_vals.ndim > 0 else [max_vals.item()], "dim": dim, "dtype": torch_dtype_to_str[field_data.dtype], + "round_to_int": False, # in backmapping usually don't want rounding, right? } } ], @@ -1331,6 +1349,9 @@ def apply( if field_data is None: raise ValueError("Field data is None before channelwise remapping") + min_vals = torch.amin(field_data, dim=[d for d in range(field_data.ndim) if d != dim]) + max_vals = torch.amax(field_data, dim=[d for d in range(field_data.ndim) if d != dim]) + min_tensor = torch.tensor(min_values, device=field_data.device, dtype=torch.float32) max_tensor = torch.tensor(max_values, device=field_data.device, dtype=torch.float32) @@ -1352,10 +1373,11 @@ def apply( "transforms": [ { "simple_quantize": { - "min_values": min_tensor.tolist(), - "max_values": max_tensor.tolist(), + "min_values": min_vals.tolist(), + "max_values": max_vals.tolist(), "dim": dim, "dtype": torch_dtype_to_str[field_data.dtype], + "round_to_int": False, # in backmapping usually don't want rounding, right? 
} } ], @@ -1366,7 +1388,7 @@ def apply( field_data = convert_to_dtype(field_data, dtype_str) case _: - raise ValueError(f"Unknown remapping parameters: {params}") + raise ValueError(f"Unknown simple_quantize parameters: {params}") new_fields[field_name] = Field(field_data, parentOp) return new_fields, decoding_update diff --git a/src/ffsplat/render/viewer.py b/src/ffsplat/render/viewer.py index 067aec4..21261f7 100644 --- a/src/ffsplat/render/viewer.py +++ b/src/ffsplat/render/viewer.py @@ -14,6 +14,15 @@ from ._renderer import Renderer, RenderTask +available_output_format: list[str] = [ + "3DGS_INRIA_ply", + "3DGS_INRIA_nosh_ply", + "SOG-web", + "SOG-web-nosh", + "SOG-web-sh-split", + "SOG-canvas", +] + @dataclasses.dataclass class CameraState: From 9c5ef8efe9c96c3d8c83335eadea29d7d91eeef3 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Mon, 30 Jun 2025 14:51:52 +0200 Subject: [PATCH 16/24] plas as sort method, fixed no_sh yaml --- src/ffsplat/conf/format/SOG-canvas.yaml | 3 +- src/ffsplat/conf/format/SOG-web-nosh.yaml | 7 ++-- src/ffsplat/conf/format/SOG-web-sh-split.yaml | 3 +- src/ffsplat/conf/format/SOG-web.yaml | 3 +- src/ffsplat/models/transformations.py | 32 ++++++------------- 5 files changed, 19 insertions(+), 29 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 8ff7f77..83d18c1 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -43,7 +43,8 @@ ops: - input_fields: [means, sh0, shN, opacities, scales, quaternions] transforms: - - plas: + - sort: + method: plas prune_by: opacities # in sogs pruned by 0.5 of number gaussians instead scaling_fn: none #standardize # activated: true diff --git a/src/ffsplat/conf/format/SOG-web-nosh.yaml b/src/ffsplat/conf/format/SOG-web-nosh.yaml index 8a0e8e4..64e9de7 100644 --- a/src/ffsplat/conf/format/SOG-web-nosh.yaml +++ b/src/ffsplat/conf/format/SOG-web-nosh.yaml @@ -28,15 +28,16 @@ ops: - input_fields: 
[sh] transforms: - split: - split_size_or_sections: [1] + split_size_or_sections: [1, 15] dim: 1 squeeze: false - to_field_list: [f_dc] + to_field_list: [f_dc, _] - input_fields: [means_bytes_0, means_bytes_1, f_dc, opacities, scales, quaternions] transforms: - - plas: + - sort: + method: plas prune_by: opacities scaling_fn: standardize # activated: true diff --git a/src/ffsplat/conf/format/SOG-web-sh-split.yaml b/src/ffsplat/conf/format/SOG-web-sh-split.yaml index 5c97125..057fe9c 100644 --- a/src/ffsplat/conf/format/SOG-web-sh-split.yaml +++ b/src/ffsplat/conf/format/SOG-web-sh-split.yaml @@ -44,7 +44,8 @@ ops: quaternions, ] transforms: - - plas: + - sort: + method: plas prune_by: opacities scaling_fn: standardize # activated: true diff --git a/src/ffsplat/conf/format/SOG-web.yaml b/src/ffsplat/conf/format/SOG-web.yaml index 8315aa6..20ba699 100644 --- a/src/ffsplat/conf/format/SOG-web.yaml +++ b/src/ffsplat/conf/format/SOG-web.yaml @@ -44,7 +44,8 @@ ops: quaternions, ] transforms: - - plas: + - sort: + method: plas prune_by: opacities scaling_fn: standardize # activated: true diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 0dd8aeb..1c239de 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -235,6 +235,7 @@ def apply( if squeeze: chunk = chunk.squeeze(dim) new_fields[target_field_name] = Field(chunk, parentOp) + to_field_list = [name for name in to_field_list if name != "_"] decoding_update.append({ "input_fields": to_field_list, @@ -798,15 +799,21 @@ def apply( sorted_indices = torch.tensor(sorted_indices, device=field_data.device) case {"method": "argsort"}: sorted_indices = torch.argsort(field_data) + case {"method": "plas"}: + plas_cfg = {k: v for k, v in params.items() if k != "method"} + sorted_indices = PLAS.plas_preprocess( + plas_cfg=PLASConfig(**plas_cfg), + fields=parentOp.input_fields, + verbose=verbose, + ) case _: raise ValueError(f"Unknown Sort 
parameters: {params}") new_fields[params["to_field"]] = Field(sorted_indices, parentOp) - # TODO: decoding update??? return new_fields, decoding_update -class PLAS(Transformation): +class PLAS: @staticmethod def as_grid_img(tensor: Tensor) -> Tensor: num_primitives = tensor.shape[0] @@ -894,26 +901,6 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo return sorted_indices - @staticmethod - @override - def apply( - params: dict[str, Any], parentOp: "Operation", verbose: bool = False, **kwargs: Any - ) -> tuple[dict[str, Field], list[dict[str, Any]]]: - new_fields: dict[str, Field] = {} - - plas_cfg_dict = params - if isinstance(plas_cfg_dict, dict): - sorted_indices = PLAS.plas_preprocess( - plas_cfg=PLASConfig(**plas_cfg_dict), - fields=parentOp.input_fields, - verbose=verbose, - ) - new_fields[plas_cfg_dict["to_field"]] = Field(sorted_indices, parentOp) - else: - raise TypeError(f"Unknown PLAS parameters: {params}") - - return new_fields, [] - @staticmethod def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: """Get the dynamic parameters for a given transformation type. 
This might modify the values in params.""" @@ -1435,7 +1422,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "split_bytes": SplitBytes, "reindex": Reindex, "sort": Sort, - "plas": PLAS, "lookup": Lookup, "combine": Combine, "write_file": WriteFile, From 808695c646943f83960ee87b20278bd7ad24b18b Mon Sep 17 00:00:00 2001 From: fleischmann Date: Mon, 30 Jun 2025 15:21:28 +0200 Subject: [PATCH 17/24] fixed quaternion unpacking --- src/ffsplat/models/transformations.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 1c239de..235a7e0 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -482,7 +482,8 @@ def apply( "to_field_name": to_field_name, }: # Retrieve the indices and values fields - indices_field = input_fields[f"{from_fields_with_prefix}indices"].data + # indices_field = input_fields[f"{from_fields_with_prefix}indices"].data + # Debug: check if indices_field contains anything else than the number 3 values_field = input_fields[f"{from_fields_with_prefix}values"].data if values_field is None: @@ -490,12 +491,8 @@ def apply( if values_field.shape[dim] != 3: raise ValueError(f"Field data shape mismatch for unit sphere recovery: {field_data.shape}") - # Ensure indices are integers - indices_field = indices_field.round().to(torch.int64).squeeze(-1) - - # Compute squared norm of the partial vector + # indices_field = indices_field.round().to(torch.int64).squeeze(-1) # indices_field is useless, all are 3 partial_norm_sq = (values_field**2).sum(dim=dim, keepdim=True) - # Recover the missing component (always non-negative) w = torch.sqrt(torch.clamp(1.0 - partial_norm_sq, min=0.0)) @@ -505,24 +502,12 @@ def apply( *values_field.shape[:-1], num_components, dtype=values_field.dtype, device=values_field.device ) - # Scatter the values back into the reconstructed tensor - 
reconstructed.scatter_(-1, indices_field.unsqueeze(-1), src=w) - - # Dynamically place values_field into the indices not covered by indices_field - full_indices = torch.arange(num_components, device=values_field.device).view(1, -1) - - indices_field_exp = indices_field.unsqueeze(-1) - - # correct positions of values - value_mask = (full_indices != indices_field_exp).to(values_field.dtype) - - # Set values_field into the masked slots - value_positions = value_mask.bool() - reconstructed[value_positions] = values_field.flatten() + # For wxyz convention + reconstructed[..., 0] = w.squeeze(-1) + reconstructed[..., 1:] = values_field # Optional re-normalization to mitigate numerical drift field_data = reconstructed / torch.linalg.norm(reconstructed, dim=dim, keepdim=True) - new_fields[to_field_name] = Field(field_data, parentOp) case _: From 15b5549495fe2e6652a42cb237d73d3d6622cc9a Mon Sep 17 00:00:00 2001 From: fleischmann Date: Thu, 3 Jul 2025 18:28:44 +0200 Subject: [PATCH 18/24] simplify tranforms, dynamic params, cleanup --- src/ffsplat/coding/scene_encoder.py | 2 +- src/ffsplat/conf/format/SOG-canvas.yaml | 69 ++++----- src/ffsplat/conf/format/SOG-web-nosh.yaml | 4 +- src/ffsplat/conf/format/SOG-web-png.yaml | 5 +- src/ffsplat/conf/format/SOG-web-sh-split.yaml | 2 +- src/ffsplat/conf/format/SOG-web.yaml | 2 +- src/ffsplat/models/transformations.py | 134 +++++++++--------- 7 files changed, 104 insertions(+), 114 deletions(-) diff --git a/src/ffsplat/coding/scene_encoder.py b/src/ffsplat/coding/scene_encoder.py index e18216e..449967c 100644 --- a/src/ffsplat/coding/scene_encoder.py +++ b/src/ffsplat/coding/scene_encoder.py @@ -166,7 +166,7 @@ def _encode_fields(self, verbose: bool) -> None: op = Operation.from_json(input_fields_params, transform_param, self.fields, self.output_path) new_fields, decoding_updates = process_operation( op, verbose=verbose, decoding_ops=self.decoding_params.to_yaml() - ) # , decoding_ops=self.decoding_params.ops) + ) # if the 
coding_updates are not a copy the cache will be wrong for decoding_update in copy.deepcopy(decoding_updates): # if the last decoding update has the same input fields we can combine the transforms into one list diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 83d18c1..cf8cc9c 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -14,7 +14,7 @@ ops: - input_fields: [sh] transforms: - split: - to_field_list: [sh0, shN] #in sogs f_dc=sh0, frest=shN + to_field_list: [sh0, shN] split_size_or_sections: [1, 15] dim: 1 squeeze: false @@ -45,16 +45,15 @@ ops: transforms: - sort: method: plas - prune_by: opacities # in sogs pruned by 0.5 of number gaussians instead - scaling_fn: none #standardize - # activated: true + prune_by: opacities + scaling_fn: none shuffle: true improvement_break: 1e-4 - to_field: sorted_indices + to_fields_with_prefix: sorted_ weights: means: 1.0 sh0: 1.0 - shN: 0.0 #shN not in sortkeys + shN: 0.0 opacities: 0.0 scales: 1.0 quaternions: 1.0 @@ -124,35 +123,35 @@ ops: - input_fields: [quats] transforms: - reparametize: - method: pack_dynamic - to_fields_with_prefix: quaternions_packed_ + method: pack_quaternions + to_fields_with_prefix: quats_packed_ dim: -1 - - input_fields: [quaternions_packed_values] + - input_fields: [quats_packed_indices] transforms: - - remapping: - method: scale-sqrt2 - - remapping: - method: "minmax" - min: 0 + - simple_quantize: + min: 252 max: 255 - - - input_fields: [quaternions_packed_values, quaternions_packed_indices] - transforms: - - combine: - method: concat dim: 2 - to_field: quats + dtype: uint8 + round_to_int: true - - input_fields: [quats] + - input_fields: [quats_packed_values] transforms: - simple_quantize: - min_values: [0, 0, 0, 252] - max_values: [255, 255, 255, 255] + min: 0 + max: 255 dim: 2 dtype: uint8 round_to_int: true + - input_fields: [quats_packed_values, quats_packed_indices] + transforms: + - 
combine: + method: concat + dim: 2 + to_field: quats + - input_fields: [sh0, sorted_indices] transforms: - reindex: @@ -203,7 +202,7 @@ ops: end_dim: 1 - cluster: method: kmeans - num_clusters: 65536 #is dynamic in sogs + num_clusters: 65536 distance: manhattan to_fields_with_prefix: shN_ @@ -220,7 +219,7 @@ ops: transforms: - sort: method: lexicographic - to_field: shN_centroids_indices + to_fields_with_prefix: shN_centroids_ - input_fields: [shN_centroids, shN_centroids_indices] transforms: @@ -233,20 +232,13 @@ ops: - reshape: shape: [-1, 960, 3] # int(num_clusters*num_spherical_harmonics/3) = - - input_fields: [shN_centroids_indices] - transforms: - - sort: - method: argsort - to_field: shN_centroids_indices_sorted_inverse - - - input_fields: [shN_labels, shN_centroids_indices_sorted_inverse] + - input_fields: [shN_labels, shN_centroids_indices_inverse] #labels are indices, centroids_inverse_indices are original location, match original location with transforms: - reindex: - src_field: shN_centroids_indices_sorted_inverse + src_field: shN_centroids_indices_inverse index_field: shN_labels - #bug here: - - input_fields: [shN_centroids_indices_sorted_inverse] + - input_fields: [shN_centroids_indices_inverse] transforms: - to_field: to_field_name: shN_labels @@ -272,7 +264,6 @@ ops: dim: 2 to_field: shN_labels - #img file outputs - input_fields: [means_l] transforms: - write_file: @@ -306,7 +297,7 @@ ops: method: 6 exact: true - - input_fields: [quats] #rgba + - input_fields: [quats] transforms: - write_file: type: image @@ -328,7 +319,7 @@ ops: method: 6 exact: true - - input_fields: [shN_centroids] #rgb + - input_fields: [shN_centroids] transforms: - write_file: type: image @@ -339,7 +330,7 @@ ops: method: 6 exact: true - - input_fields: [shN_labels] #rgb , result is suprisingly less noisy than in sogs + - input_fields: [shN_labels] transforms: - write_file: type: image diff --git a/src/ffsplat/conf/format/SOG-web-nosh.yaml 
b/src/ffsplat/conf/format/SOG-web-nosh.yaml index 64e9de7..a1c054d 100644 --- a/src/ffsplat/conf/format/SOG-web-nosh.yaml +++ b/src/ffsplat/conf/format/SOG-web-nosh.yaml @@ -28,7 +28,7 @@ ops: - input_fields: [sh] transforms: - split: - split_size_or_sections: [1, 15] + split_size_or_sections: [1] dim: 1 squeeze: false to_field_list: [f_dc, _] @@ -44,7 +44,7 @@ ops: shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_field: sorted_indices + to_fields_with_prefix: sorted_ weights: means_bytes_0: 0.1 means_bytes_1: 1.0 diff --git a/src/ffsplat/conf/format/SOG-web-png.yaml b/src/ffsplat/conf/format/SOG-web-png.yaml index 3a3d0ec..526bd93 100644 --- a/src/ffsplat/conf/format/SOG-web-png.yaml +++ b/src/ffsplat/conf/format/SOG-web-png.yaml @@ -44,14 +44,15 @@ ops: quaternions, ] transforms: - - plas: + - sort: + method: plas prune_by: opacities scaling_fn: standardize # activated: true shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_field: sorted_indices + to_fields_with_prefix: sorted_ weights: means_bytes_0: 0.2 means_bytes_1: 1.0 diff --git a/src/ffsplat/conf/format/SOG-web-sh-split.yaml b/src/ffsplat/conf/format/SOG-web-sh-split.yaml index 057fe9c..b86f085 100644 --- a/src/ffsplat/conf/format/SOG-web-sh-split.yaml +++ b/src/ffsplat/conf/format/SOG-web-sh-split.yaml @@ -52,7 +52,7 @@ ops: shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_field: sorted_indices + to_fields_with_prefix: sorted_ weights: means_bytes_0: 0.1 means_bytes_1: 1.0 diff --git a/src/ffsplat/conf/format/SOG-web.yaml b/src/ffsplat/conf/format/SOG-web.yaml index 20ba699..eb1613d 100644 --- a/src/ffsplat/conf/format/SOG-web.yaml +++ b/src/ffsplat/conf/format/SOG-web.yaml @@ -52,7 +52,7 @@ ops: shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_field: sorted_indices + to_fields_with_prefix: sorted_ weights: means_bytes_0: 0.2 means_bytes_1: 1.0 diff --git a/src/ffsplat/models/transformations.py 
b/src/ffsplat/models/transformations.py index 235a7e0..e4c5dce 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -202,7 +202,7 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: dynamic_params_config.append({ "label": "distance", "type": "dropdown", - "values": ["euclidean", "cosine", "manhatten"], + "values": ["euclidean", "cosine", "manhattan"], }) return dynamic_params_config @@ -229,7 +229,7 @@ def apply( "to_field_list": to_field_list, }: chunks = field_data.split(split_size_or_sections, dim) - for target_field_name, chunk in zip(to_field_list, chunks): + for target_field_name, chunk in zip(to_field_list, chunks, strict=False): if target_field_name == "_": continue if squeeze: @@ -292,20 +292,6 @@ def apply( new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] match params: - case {"method": "scale-sqrt2"}: - scale = torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) - field_data = field_data * scale - decoding_update.append({ - "input_fields": [field_name], - "transforms": [{"remapping": {"method": "scale-inverse-sqrt2"}}], - }) - case {"method": "scale-inverse-sqrt2"}: - scale = 1.0 / torch.sqrt(torch.tensor(2.0, dtype=field_data.dtype, device=field_data.device)) - field_data = field_data * scale - decoding_update.append({ - "input_fields": [field_name], - "transforms": [{"remapping": {"method": "scale-sqrt2"}}], - }) case {"method": "exp"}: field_data = torch.exp(field_data) case {"method": "sigmoid"}: @@ -390,9 +376,6 @@ def apply( normalized = normalized * (max_val_f - min_val_f) + min_val_f field_data = normalized - - case {"method": "canvas", "min": min_val, "max": max_val}: - field_data = field_data.clamp(min=min_val, max=max_val) case _: raise ValueError(f"Unknown remapping parameters: {params}") new_fields[field_name] = Field(field_data, parentOp) @@ -430,35 +413,19 @@ def apply( field_data = field_data / 
torch.linalg.norm(field_data, dim=dim, keepdim=True) sign_mask = field_data.select(dim, 3) < 0 - # sign_mask = field_data.select(dim, 0) < 0 field_data[sign_mask] *= -1 new_fields[field_name] = Field(field_data, parentOp) - case {"method": "pack_dynamic", "to_fields_with_prefix": to_fields_with_prefix, "dim": dim}: - abs_field_data = field_data.abs() - max_idx = abs_field_data.argmax(dim=dim) - - # ensure largest component is positive - max_vals = field_data.gather(-1, max_idx.unsqueeze(-1)).squeeze(-1) - sign = max_vals.sign() + case {"method": "pack_quaternions", "to_fields_with_prefix": to_fields_with_prefix, "dim": dim}: + # Ensure the first component is positive + sign = field_data.select(dim, 0).sign() sign[sign == 0] = 1 q_signed = field_data * sign.unsqueeze(-1) - # build variants dropping each component - variants = [] - for i in range(4): - dims = list(range(4)) - dims.remove(i) - variants.append(q_signed[..., dims]) # (...,3) - stacked = torch.stack(variants, dim=-2) # (...,4,3) - - # select the appropriate 3-vector based on max_idx - idx_exp = max_idx.unsqueeze(-1).unsqueeze(-1).expand(*max_idx.shape, 1, 3) - values = torch.gather(stacked, dim=-2, index=idx_exp).squeeze(-2) # (...,3) - - max_idx = max_idx.to(values.dtype).unsqueeze(-1) + # Drop the first component (is always the largest) + values = q_signed.narrow(dim, 1, 3) + max_idx = torch.zeros_like(sign, dtype=values.dtype).unsqueeze(-1) - # Apply brightness increase by multiplying by sqrt(2) new_fields[f"{to_fields_with_prefix}indices"] = Field(max_idx, parentOp) new_fields[f"{to_fields_with_prefix}values"] = Field(values, parentOp) decoding_update.append({ @@ -466,7 +433,7 @@ def apply( "transforms": [ { "reparametize": { - "method": "unpack_dynamic", + "method": "unpack_quaternions", "from_fields_with_prefix": to_fields_with_prefix, "dim": -1, "to_field_name": field_name, @@ -476,14 +443,11 @@ def apply( }) case { - "method": "unpack_dynamic", + "method": "unpack_quaternions", 
"from_fields_with_prefix": from_fields_with_prefix, "dim": dim, "to_field_name": to_field_name, }: - # Retrieve the indices and values fields - # indices_field = input_fields[f"{from_fields_with_prefix}indices"].data - # Debug: check if indices_field contains anything else than the number 3 values_field = input_fields[f"{from_fields_with_prefix}values"].data if values_field is None: @@ -491,18 +455,16 @@ def apply( if values_field.shape[dim] != 3: raise ValueError(f"Field data shape mismatch for unit sphere recovery: {field_data.shape}") - # indices_field = indices_field.round().to(torch.int64).squeeze(-1) # indices_field is useless, all are 3 partial_norm_sq = (values_field**2).sum(dim=dim, keepdim=True) # Recover the missing component (always non-negative) w = torch.sqrt(torch.clamp(1.0 - partial_norm_sq, min=0.0)) - # Create a tensor to hold the reconstructed data num_components = values_field.shape[-1] + 1 # Original data had one more component reconstructed = torch.zeros( *values_field.shape[:-1], num_components, dtype=values_field.dtype, device=values_field.device ) - # For wxyz convention + # wxyz convention reconstructed[..., 0] = w.squeeze(-1) reconstructed[..., 1:] = values_field @@ -778,22 +740,28 @@ def apply( sorted_indices = None if field_data is None: raise ValueError("Field data is None before sorting") + prefix = params["to_fields_with_prefix"] match params: case {"method": "lexicographic"}: sorted_indices = np.lexsort(field_data.permute(dims=(1, 0)).cpu().numpy()) sorted_indices = torch.tensor(sorted_indices, device=field_data.device) - case {"method": "argsort"}: - sorted_indices = torch.argsort(field_data) + inverse_sorted_indices = torch.argsort(sorted_indices) + # square_keep_indices = PLAS.primitive_filter_pruning_to_square_shape(inverse_sorted_indices, verbose=False) + # inverse_sorted_indices = inverse_sorted_indices[square_keep_indices] case {"method": "plas"}: - plas_cfg = {k: v for k, v in params.items() if k != "method"} + plas_cfg = 
{k: v for k, v in params.items() if k not in ["method", "to_fields_with_prefix"]} + plas_cfg["to_field"] = f"{prefix}indices" sorted_indices = PLAS.plas_preprocess( plas_cfg=PLASConfig(**plas_cfg), fields=parentOp.input_fields, verbose=verbose, ) + inverse_sorted_indices = torch.argsort(sorted_indices) + case _: raise ValueError(f"Unknown Sort parameters: {params}") - new_fields[params["to_field"]] = Field(sorted_indices, parentOp) + new_fields[f"{prefix}indices"] = Field(sorted_indices, parentOp) + new_fields[f"{prefix}indices_inverse"] = Field(inverse_sorted_indices, parentOp) return new_fields, decoding_update @@ -1098,7 +1066,7 @@ def apply( transforms_str = [next(iter(t_str_braced.keys())) for t_str_braced in op["transforms"]] transform_types = [transformation_map[transform] for transform in transforms_str] input_field = op["input_fields"][0] if len(op["input_fields"]) > 0 else "" - for idx, (t_str, t) in enumerate(zip(transforms_str, transform_types)): + for idx, (t_str, t) in enumerate(zip(transforms_str, transform_types, strict=False)): if t is ReadFile: field_name_read = op["transforms"][idx][t_str].get("field_name") if field_name_read.startswith(field_name): @@ -1113,24 +1081,12 @@ def apply( meta[field_name]["mins"] = op["transforms"][idx][t_str]["min_values"] meta[field_name]["maxs"] = op["transforms"][idx][t_str]["max_values"] elif t is Reparametrize: - if op["transforms"][idx][t_str].get("method") == "unpack_dynamic": + if op["transforms"][idx][t_str].get("method") == "unpack_quaternions": to_field_name = op["transforms"][idx][t_str].get("to_field_name") if to_field_name == field_name: meta[to_field_name]["encoding"] = "quaternion_packed" - # inconsitencies in sogs, that are hardcoded: - if field_name == "quats": - if "mins" in meta[field_name]: - del meta[field_name]["mins"] - if "maxs" in meta[field_name]: - del meta[field_name]["maxs"] - meta[field_name]["dtype"] = "uint8" # is also hardcoded in sogs - if field_name == "shN": - # because is 8 by 
default and it's the only parameter requiring encoding parameters of more than one step back - meta[field_name]["quantization"] = 8 - write_json(output_file_path, meta) - case _: raise ValueError(f"Unknown WriteFile parameters: {params}") @@ -1147,12 +1103,14 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: dynamic_params_config.append({ "label": "image_codec", "type": "dropdown", - "values": ["avif", "png"], + "values": ["avif", "png", "webp"], "rebuild": True, }) # coding_params for image file dynamic_coding_params: list[dict[str, Any]] = [] + dynamic_coding_params: list[dict[str, Any]] = [] + coding_params: dict[str, Any] = params.get("coding_params", {}) match params["image_codec"]: case "avif": # check whether we need to update the coding params default: @@ -1217,6 +1175,46 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "type": "heading", "params": dynamic_coding_params, }) + case "webp": + if not all(key in coding_params for key in ["quality", "method", "exact", "lossless"]): + coding_params.clear() + coding_params["quality"] = 100 + coding_params["method"] = 6 + coding_params["exact"] = True + coding_params["lossless"] = True + dynamic_coding_params.append({ + "label": "exact", + "type": "bool", + "values": ["True", "False"], + }) + dynamic_coding_params.append({ + "label": "lossless", + "type": "bool", + "values": ["True", "False"], + }) + # filesize-speed tradeoff when lossless + dynamic_coding_params.append({ + "label": "quality", + "type": "number", + "min": 0, + "max": 100, + "step": 1, + "dtype": int, + }) + # also filesize-speed tradeoff, all compatible with lossless + dynamic_coding_params.append({ + "label": "method", + "type": "number", + "dtype": int, + "min": 0, + "max": 6, + "step": 1, + }) + dynamic_params_config.append({ + "label": "coding_params", + "type": "heading", + "params": dynamic_coding_params, + }) case _: raise ValueError(f"unknown image codec: {params["image_codec"]}") From 
b2cb9f63bb82712992d50493de0f9b12acf4f964 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Mon, 7 Jul 2025 18:04:27 +0200 Subject: [PATCH 19/24] dynamic sort params, moved lexsort label update into sort --- src/ffsplat/cli/live.py | 23 +++---- src/ffsplat/conf/format/SOG-canvas.yaml | 17 +----- src/ffsplat/models/transformations.py | 81 ++++++++++++++++++++++--- 3 files changed, 87 insertions(+), 34 deletions(-) diff --git a/src/ffsplat/cli/live.py b/src/ffsplat/cli/live.py index ea3dbba..9431615 100644 --- a/src/ffsplat/cli/live.py +++ b/src/ffsplat/cli/live.py @@ -397,17 +397,24 @@ def full_evaluation(self, _): self.enable_convert_ui() self.enable_load_buttons() - def _build_transform_folder(self, transform_folder, description, transformation, transform_type): + def _build_transform_folder(self, transform_folder, operation, transformation, transform_type): # clear transform folder for rebuild for child in tuple(transform_folder._children.values()): child.remove() - dynamic_params_conf = get_dynamic_params(transformation) + + if isinstance(operation["input_fields"], list): + input_field = operation["input_fields"] + description = f"input fields: {input_field}" + else: + input_field = operation["input_fields"]["from_fields_with_prefix"] + description = f"input fields from prefix: {input_field}" + dynamic_params_conf = get_dynamic_params(transformation, input_field) initial_values = transformation[transform_type] with transform_folder: self.viewer.server.gui.add_markdown(description) rebuild_fn = partial( - self._build_transform_folder, transform_folder, description, transformation, transform_type + self._build_transform_folder, transform_folder, operation, transformation, transform_type ) self._build_options_for_transformation(dynamic_params_conf, initial_values, rebuild_fn) @@ -503,20 +510,14 @@ def _build_convert_options(self): for transformation in operation["transforms"]: # get list with customizable options - dynamic_transform_conf = 
get_dynamic_params(transformation) + dynamic_transform_conf = get_dynamic_params(transformation, operation["input_fields"]) if len(dynamic_transform_conf) == 0: continue transform_type = next(iter(transformation.keys())) transform_folder = self.viewer.server.gui.add_folder(transform_type) self.viewer.convert_gui_handles.append(transform_folder) - if isinstance(operation["input_fields"], list): - description = f"input fields: {operation["input_fields"]}" - else: - description = ( - f"input fields from prefix: {operation["input_fields"]["from_fields_with_prefix"]}" - ) - self._build_transform_folder(transform_folder, description, transformation, transform_type) + self._build_transform_folder(transform_folder, operation, transformation, transform_type) @torch.no_grad() def render_fn( diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index cf8cc9c..bc2eab7 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -215,10 +215,12 @@ ops: dtype: uint8 round_to_int: true - - input_fields: [shN_centroids] + - input_fields: [shN_centroids, shN_labels] transforms: - sort: method: lexicographic + labels: shN_labels + target: shN_centroids to_fields_with_prefix: shN_centroids_ - input_fields: [shN_centroids, shN_centroids_indices] @@ -232,21 +234,8 @@ ops: - reshape: shape: [-1, 960, 3] # int(num_clusters*num_spherical_harmonics/3) = - - input_fields: [shN_labels, shN_centroids_indices_inverse] #labels are indices, centroids_inverse_indices are original location, match original location with - transforms: - - reindex: - src_field: shN_centroids_indices_inverse - index_field: shN_labels - - - input_fields: [shN_centroids_indices_inverse] - transforms: - - to_field: - to_field_name: shN_labels - - input_fields: [shN_labels] transforms: - - reshape: - shape: [588, 588] #n_sidelen - simple_quantize: min: 0 max: 65535 diff --git a/src/ffsplat/models/transformations.py 
b/src/ffsplat/models/transformations.py index e4c5dce..84aebbc 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -389,8 +389,27 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: dynamic_params_config.append({ "label": "method", "type": "dropdown", - "values": ["log", "signed-log", "inverse-sigmoid"], + "values": ["log", "signed-log", "inverse-sigmoid", "minmax"], }) + if params.get("method") == "minmax": + dynamic_params_config.append({ + "label": "min", + "type": "number", + "min": 0, + "max": 4294967295, # unit32 max + "step": 1, + "dtype": int, + "set": "min", + }) + dynamic_params_config.append({ + "label": "max", + "type": "number", + "min": 0, + "max": 4294967295, # unit32 max + "step": 1, + "dtype": int, + "set": "max", + }) return dynamic_params_config @@ -707,7 +726,10 @@ def apply( decoding_update: list[dict[str, Any]] = [] # TODO: make this compatible with 1D-tensors match params: - case {"src_field": src_field_name, "index_field": index_field_name}: + case { + "src_field": src_field_name, + "index_field": index_field_name, + }: index_field_obj = input_fields[index_field_name] original_data = input_fields[src_field_name].data new_fields[src_field_name] = Field(original_data[index_field_obj.data], parentOp) @@ -742,12 +764,22 @@ def apply( raise ValueError("Field data is None before sorting") prefix = params["to_fields_with_prefix"] match params: - case {"method": "lexicographic"}: - sorted_indices = np.lexsort(field_data.permute(dims=(1, 0)).cpu().numpy()) + case {"method": "lexicographic", "labels": labels_name, "target": target_name}: + target = input_fields[target_name].data + sorted_indices = np.lexsort(target.permute(dims=(1, 0)).cpu().numpy()) sorted_indices = torch.tensor(sorted_indices, device=field_data.device) inverse_sorted_indices = torch.argsort(sorted_indices) - # square_keep_indices = PLAS.primitive_filter_pruning_to_square_shape(inverse_sorted_indices, verbose=False) 
- # inverse_sorted_indices = inverse_sorted_indices[square_keep_indices] + + if labels_name != "_": + original_labels = input_fields[labels_name].data + # orignal_labels_grid = PLAS.primitive_filter_pruning_to_square_shape(original_labels,verbose=verbose) + original_labels_grid = PLAS.as_grid_img(original_labels) + updated_labels = inverse_sorted_indices[original_labels_grid] + new_fields[labels_name] = Field(updated_labels, parentOp) + decoding_update.append({ + "input_fields": [labels_name], + "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], + }) case {"method": "plas"}: plas_cfg = {k: v for k, v in params.items() if k not in ["method", "to_fields_with_prefix"]} plas_cfg["to_field"] = f"{prefix}indices" @@ -765,6 +797,20 @@ def apply( return new_fields, decoding_update + @staticmethod + def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: + """Get the dynamic parameters for a given transformation type. This might modify the values in params.""" + dynamic_params_config: list[dict[str, Any]] = [] + dynamic_params_config.append({ + "label": "method", + "type": "dropdown", + "values": ["lexicographic", "plas"], + }) + if params.get("method") == "plas": + dynamic_params_config.extend(PLAS.get_dynamic_params(params)) + + return dynamic_params_config + class PLAS: @staticmethod @@ -858,12 +904,20 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: """Get the dynamic parameters for a given transformation type. 
This might modify the values in params.""" - if params.get("weights") is None: - raise ValueError(f"PLAS parameters is missing weights: {params}") + # if params.get("weights") is None: + # raise ValueError(f"PLAS parameters is missing weights: {params}") field_names = list(params["weights"].keys()) scaling_functions = ["standardize", "minmax", "none"] dynamic_params_config: list[dict[str, Any]] = [] + + params.setdefault("scaling_fn", "standardize") + params.setdefault("shuffle", True) + params.setdefault("improvement_break", 1e-4) + # coding_params.setdefault("quality", -1) + # coding_params.setdefault("chroma", 444) + # coding_params.setdefault("matrix_coefficients", 0) + # TODO: initial values for scaling function, improvement break!!!! dynamic_params_config.append({ "label": "scaling_fn", "type": "dropdown", @@ -885,6 +939,8 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: }) weight_config: list[dict[str, Any]] = [] + # if params.get("weights") is None: + # params. for field_name in field_names: weight_config.append({ "label": field_name, @@ -1428,10 +1484,17 @@ def apply_transform( return transformation.apply(parentOp.params[parentOp.transform_type], parentOp, verbose=verbose) -def get_dynamic_params(params: dict[str, dict[str, Any]]) -> list[dict[str, Any]]: +def get_dynamic_params(params: dict[str, dict[str, Any]], input_field: list | str) -> list[dict[str, Any]]: """Get the dynamic parameters for a given transformation type. 
This might modify the values in params.""" transform_type = next(iter(params.keys())) transformation = transformation_map.get(transform_type) if transformation is None: raise ValueError(f"Unknown transformation: {transform_type}") + elif ( + transformation is Sort + and params[transform_type]["method"] == "plas" + and "weights" not in params[transform_type] + ): + params[transform_type]["weights"] = {k: 1.0 for k in input_field} + return transformation.get_dynamic_params(params[transform_type]) From aea0f16bdb748e092d00b9d0e30d81a0417cb06f Mon Sep 17 00:00:00 2001 From: fleischmann Date: Tue, 8 Jul 2025 16:55:37 +0200 Subject: [PATCH 20/24] plas opacity issue solved #11 --- src/ffsplat/cli/eval.py | 2 +- src/ffsplat/cli/live.py | 2 +- src/ffsplat/cli/view.py | 2 +- src/ffsplat/coding/scene_decoder.py | 3 ++- src/ffsplat/conf/format/SOG-canvas.yaml | 6 ------ src/ffsplat/models/transformations.py | 4 +++- 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/ffsplat/cli/eval.py b/src/ffsplat/cli/eval.py index a0e0d19..7f002d6 100644 --- a/src/ffsplat/cli/eval.py +++ b/src/ffsplat/cli/eval.py @@ -37,7 +37,7 @@ def rasterize_splats( means=gaussians.means.data, quats=gaussians.quaternions.data, scales=gaussians.scales.data, - opacities=gaussians.opacities.data, + opacities=gaussians.opacities.data.squeeze(), colors=colors, viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] Ks=Ks, # [C, 3, 3] diff --git a/src/ffsplat/cli/live.py b/src/ffsplat/cli/live.py index 9431615..5da0e1c 100644 --- a/src/ffsplat/cli/live.py +++ b/src/ffsplat/cli/live.py @@ -542,7 +542,7 @@ def render_fn( means_t, # [N, 3] quats_t, # [N, 4] scales_t, # [N, 3] - opacities_t, # [N] + opacities_t.squeeze(), # [N] colors, # [N, S, 3] viewmat[None], # [1, 4, 4] K[None], # [1, 3, 3] diff --git a/src/ffsplat/cli/view.py b/src/ffsplat/cli/view.py index c38ffad..cf09e91 100644 --- a/src/ffsplat/cli/view.py +++ b/src/ffsplat/cli/view.py @@ -35,7 +35,7 @@ def render_fn( means_t, # [N, 3] 
quats_t, # [N, 4] scales_t, # [N, 3] - opacities_t, # [N] + opacities_t.squeeze(), # [N] colors, # [N, S, 3] viewmat[None], # [1, 4, 4] K[None], # [1, 3, 3] diff --git a/src/ffsplat/coding/scene_decoder.py b/src/ffsplat/coding/scene_decoder.py index c975ecc..98574c9 100644 --- a/src/ffsplat/coding/scene_decoder.py +++ b/src/ffsplat/coding/scene_decoder.py @@ -81,11 +81,12 @@ def _process_fields(self, verbose: bool = False) -> None: def _create_scene(self) -> None: match self.decoding_params.scene.get("primitives"): case "3DGS-INRIA": + opacities_field = self.fields["opacities"] self.scene = Gaussians( means=self.fields["means"], quaternions=self.fields["quaternions"], scales=self.fields["scales"], - opacities=self.fields["opacities"], + opacities=Field(opacities_field.data.unsqueeze(-1), opacities_field.op), sh=self.fields["sh"], ) case _: diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index bc2eab7..3a680b9 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -164,12 +164,6 @@ ops: src_field: shN index_field: sorted_indices - - input_fields: [opacities] - transforms: - - reshape: - start_dim: 2 - shape: [1] - - input_fields: [sh0] transforms: - permute: diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 84aebbc..c1ce83c 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -840,7 +840,7 @@ def primitive_filter_pruning_to_square_shape(data: Tensor, verbose: bool) -> Ten f"Removing {num_to_remove}/{num_primitives} primitives to fit the grid. 
({100 * num_to_remove / num_primitives:.4f}%)" ) - _, keep_indices = torch.topk(data, k=grid_sidelen * grid_sidelen) + _, keep_indices = torch.topk(data.squeeze(), k=grid_sidelen * grid_sidelen) sorted_keep_indices = torch.sort(keep_indices)[0] return sorted_keep_indices @@ -1081,6 +1081,8 @@ def apply( case {"type": "image", "image_codec": codec, "coding_params": coding_params, "base_path": base_path}: for field_name, field_obj in parentOp.input_fields.items(): field_data = field_obj.data + if field_data.shape[-1] == 1: + field_data = field_data.squeeze(-1) file_path = f"{field_name}.{codec}" output_file_path = Path(base_path) / Path(file_path) write_image( From 4b11c2d4d83fb5fb9a554bee60efc1d0f3818d1f Mon Sep 17 00:00:00 2001 From: fleischmann Date: Tue, 8 Jul 2025 17:04:43 +0200 Subject: [PATCH 21/24] add metadata --- src/ffsplat/models/transformations.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index c1ce83c..59ff48b 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -792,6 +792,7 @@ def apply( case _: raise ValueError(f"Unknown Sort parameters: {params}") + # TODO: optional label-update new_fields[f"{prefix}indices"] = Field(sorted_indices, parentOp) new_fields[f"{prefix}indices_inverse"] = Field(inverse_sorted_indices, parentOp) @@ -1109,7 +1110,10 @@ def apply( file_path = "meta.json" output_file_path = Path(base_path) / Path(file_path) field_names = list(parentOp.input_fields.keys()) - meta: dict[str, Any] = {} + meta: dict[str, Any] = { + "packer": "ffsplat", + "version": 1, + } # Readfile no input_fields for field_name in field_names: field = parentOp.input_fields[field_name] From d2d4e9084396b3cd97e65ff956eb3bf7da3309bc Mon Sep 17 00:00:00 2001 From: fleischmann Date: Tue, 8 Jul 2025 18:10:03 +0200 Subject: [PATCH 22/24] cleanup --- src/ffsplat/cli/live.py | 1 + src/ffsplat/conf/format/SOG-canvas.yaml 
| 4 +- src/ffsplat/conf/format/SOG-web-nosh.yaml | 2 +- src/ffsplat/conf/format/SOG-web-png.yaml | 2 +- src/ffsplat/conf/format/SOG-web-sh-split.yaml | 2 +- src/ffsplat/conf/format/SOG-web.yaml | 2 +- src/ffsplat/models/transformations.py | 63 +++++-------------- 7 files changed, 23 insertions(+), 53 deletions(-) diff --git a/src/ffsplat/cli/live.py b/src/ffsplat/cli/live.py index 5da0e1c..f4e8401 100644 --- a/src/ffsplat/cli/live.py +++ b/src/ffsplat/cli/live.py @@ -42,6 +42,7 @@ "SOG-web-png", "SOG-web-nosh", "SOG-web-sh-split", + "SOG-canvas", ] diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index 3a680b9..b258216 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -49,7 +49,7 @@ ops: scaling_fn: none shuffle: true improvement_break: 1e-4 - to_fields_with_prefix: sorted_ + to_field: sorted_indices weights: means: 1.0 sh0: 1.0 @@ -215,7 +215,7 @@ ops: method: lexicographic labels: shN_labels target: shN_centroids - to_fields_with_prefix: shN_centroids_ + to_field: shN_centroids_indices - input_fields: [shN_centroids, shN_centroids_indices] transforms: diff --git a/src/ffsplat/conf/format/SOG-web-nosh.yaml b/src/ffsplat/conf/format/SOG-web-nosh.yaml index a1c054d..008bbb1 100644 --- a/src/ffsplat/conf/format/SOG-web-nosh.yaml +++ b/src/ffsplat/conf/format/SOG-web-nosh.yaml @@ -44,7 +44,7 @@ ops: shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_fields_with_prefix: sorted_ + to_field: sorted_indices weights: means_bytes_0: 0.1 means_bytes_1: 1.0 diff --git a/src/ffsplat/conf/format/SOG-web-png.yaml b/src/ffsplat/conf/format/SOG-web-png.yaml index 526bd93..5af9306 100644 --- a/src/ffsplat/conf/format/SOG-web-png.yaml +++ b/src/ffsplat/conf/format/SOG-web-png.yaml @@ -52,7 +52,7 @@ ops: shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_fields_with_prefix: sorted_ + to_field: sorted_indices weights: means_bytes_0: 0.2 means_bytes_1: 1.0 
diff --git a/src/ffsplat/conf/format/SOG-web-sh-split.yaml b/src/ffsplat/conf/format/SOG-web-sh-split.yaml index b86f085..656f93b 100644 --- a/src/ffsplat/conf/format/SOG-web-sh-split.yaml +++ b/src/ffsplat/conf/format/SOG-web-sh-split.yaml @@ -52,7 +52,7 @@ ops: shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_fields_with_prefix: sorted_ + to_field: sorted_indices weights: means_bytes_0: 0.1 means_bytes_1: 1.0 diff --git a/src/ffsplat/conf/format/SOG-web.yaml b/src/ffsplat/conf/format/SOG-web.yaml index eb1613d..20ba699 100644 --- a/src/ffsplat/conf/format/SOG-web.yaml +++ b/src/ffsplat/conf/format/SOG-web.yaml @@ -52,7 +52,7 @@ ops: shuffle: true improvement_break: 1e-4 # improvement_break: 0.1 - to_fields_with_prefix: sorted_ + to_field: sorted_indices weights: means_bytes_0: 0.2 means_bytes_1: 1.0 diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 59ff48b..74d1ac0 100644 --- a/src/ffsplat/models/transformations.py +++ b/src/ffsplat/models/transformations.py @@ -389,27 +389,8 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: dynamic_params_config.append({ "label": "method", "type": "dropdown", - "values": ["log", "signed-log", "inverse-sigmoid", "minmax"], + "values": ["log", "signed-log", "inverse-sigmoid"], }) - if params.get("method") == "minmax": - dynamic_params_config.append({ - "label": "min", - "type": "number", - "min": 0, - "max": 4294967295, # unit32 max - "step": 1, - "dtype": int, - "set": "min", - }) - dynamic_params_config.append({ - "label": "max", - "type": "number", - "min": 0, - "max": 4294967295, # unit32 max - "step": 1, - "dtype": int, - "set": "max", - }) return dynamic_params_config @@ -726,10 +707,7 @@ def apply( decoding_update: list[dict[str, Any]] = [] # TODO: make this compatible with 1D-tensors match params: - case { - "src_field": src_field_name, - "index_field": 
index_field_name}: index_field_obj = input_fields[index_field_name] original_data = input_fields[src_field_name].data new_fields[src_field_name] = Field(original_data[index_field_obj.data], parentOp) @@ -762,17 +740,20 @@ def apply( sorted_indices = None if field_data is None: raise ValueError("Field data is None before sorting") - prefix = params["to_fields_with_prefix"] + to_field_name = params["to_field"] match params: - case {"method": "lexicographic", "labels": labels_name, "target": target_name}: + case {"method": "lexicographic", "target": target_name}: + # plas preprocess, use weights 1 instead of target_name + plas_cfg = {k: v for k, v in params.items() if k not in ["method", "labels"]} + target = input_fields[target_name].data sorted_indices = np.lexsort(target.permute(dims=(1, 0)).cpu().numpy()) sorted_indices = torch.tensor(sorted_indices, device=field_data.device) - inverse_sorted_indices = torch.argsort(sorted_indices) - if labels_name != "_": + if "labels" in params: + inverse_sorted_indices = torch.argsort(sorted_indices) + labels_name = params["labels"] original_labels = input_fields[labels_name].data - # orignal_labels_grid = PLAS.primitive_filter_pruning_to_square_shape(original_labels,verbose=verbose) original_labels_grid = PLAS.as_grid_img(original_labels) updated_labels = inverse_sorted_indices[original_labels_grid] new_fields[labels_name] = Field(updated_labels, parentOp) @@ -780,21 +761,22 @@ def apply( "input_fields": [labels_name], "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], }) + case {"method": "plas"}: - plas_cfg = {k: v for k, v in params.items() if k not in ["method", "to_fields_with_prefix"]} - plas_cfg["to_field"] = f"{prefix}indices" + plas_cfg = {k: v for k, v in params.items() if k not in ["method"]} + + plas_cfg["to_field"] = params["to_field"] sorted_indices = PLAS.plas_preprocess( plas_cfg=PLASConfig(**plas_cfg), fields=parentOp.input_fields, verbose=verbose, ) - inverse_sorted_indices = 
torch.argsort(sorted_indices) case _: raise ValueError(f"Unknown Sort parameters: {params}") + # TODO: optional label-update - new_fields[f"{prefix}indices"] = Field(sorted_indices, parentOp) - new_fields[f"{prefix}indices_inverse"] = Field(inverse_sorted_indices, parentOp) + new_fields[to_field_name] = Field(sorted_indices, parentOp) return new_fields, decoding_update @@ -905,8 +887,6 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: """Get the dynamic parameters for a given transformation type. This might modify the values in params.""" - # if params.get("weights") is None: - # raise ValueError(f"PLAS parameters is missing weights: {params}") field_names = list(params["weights"].keys()) scaling_functions = ["standardize", "minmax", "none"] @@ -915,10 +895,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: params.setdefault("scaling_fn", "standardize") params.setdefault("shuffle", True) params.setdefault("improvement_break", 1e-4) - # coding_params.setdefault("quality", -1) - # coding_params.setdefault("chroma", 444) - # coding_params.setdefault("matrix_coefficients", 0) - # TODO: initial values for scaling function, improvement break!!!! dynamic_params_config.append({ "label": "scaling_fn", "type": "dropdown", @@ -940,8 +916,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: }) weight_config: list[dict[str, Any]] = [] - # if params.get("weights") is None: - # params. 
for field_name in field_names: weight_config.append({ "label": field_name, @@ -1169,15 +1143,11 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "rebuild": True, }) # coding_params for image file - dynamic_coding_params: list[dict[str, Any]] = [] - dynamic_coding_params: list[dict[str, Any]] = [] coding_params: dict[str, Any] = params.get("coding_params", {}) match params["image_codec"]: case "avif": # check whether we need to update the coding params default: - coding_params: dict[str, Any] = params.get("coding_params", {}) - if not all(key in coding_params for key in ["quality", "chroma", "matrix_coefficients"]): coding_params.clear() coding_params["quality"] = -1 @@ -1219,7 +1189,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: }) case "png": - coding_params = params.get("coding_params", {}) if "compression_level" not in coding_params: coding_params.clear() coding_params["compression_level"] = 3 From f3dd67578b48d98dcf8c32f91ebd656c96ea7113 Mon Sep 17 00:00:00 2001 From: fleischmann Date: Wed, 9 Jul 2025 11:03:30 +0200 Subject: [PATCH 23/24] lexsort compatible with grid-pruning and reshaping to grid, cleanup --- src/ffsplat/conf/format/SOG-canvas.yaml | 4 +- src/ffsplat/models/transformations.py | 90 ++++++++++++++++--------- 2 files changed, 62 insertions(+), 32 deletions(-) diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-canvas.yaml index b258216..ce6c4ee 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-canvas.yaml @@ -214,7 +214,9 @@ ops: - sort: method: lexicographic labels: shN_labels - target: shN_centroids + weights: + shN_labels: 0.0 + shN_centroids: 1.0 to_field: shN_centroids_indices - input_fields: [shN_centroids, shN_centroids_indices] diff --git a/src/ffsplat/models/transformations.py b/src/ffsplat/models/transformations.py index 74d1ac0..de46609 100644 --- a/src/ffsplat/models/transformations.py +++ 
b/src/ffsplat/models/transformations.py @@ -705,7 +705,6 @@ def apply( new_fields: dict[str, Field] = {} decoding_update: list[dict[str, Any]] = [] - # TODO: make this compatible with 1D-tensors match params: case {"src_field": src_field_name, "index_field": index_field_name}: index_field_obj = input_fields[index_field_name] @@ -742,12 +741,20 @@ def apply( raise ValueError("Field data is None before sorting") to_field_name = params["to_field"] match params: - case {"method": "lexicographic", "target": target_name}: - # plas preprocess, use weights 1 instead of target_name + case {"method": "lexicographic"}: + # weight is effectively a boolean in this case plas_cfg = {k: v for k, v in params.items() if k not in ["method", "labels"]} - - target = input_fields[target_name].data - sorted_indices = np.lexsort(target.permute(dims=(1, 0)).cpu().numpy()) + plas_cfg["scaling_fn"] = "none" + plas_cfg["shuffle"] = False + plas_cfg.setdefault("improvement_break", 1e-4) + plas_cfg.setdefault("prune_by", None) # only reshapes to grid if not None + preprocess_dict = PLAS.plas_preprocess( + plas_cfg=PLASConfig(**plas_cfg), fields=parentOp.input_fields, verbose=verbose + ) + params_tensor = preprocess_dict["params_tensor"] + sorted_indices = np.lexsort( + params_tensor.permute(dims=(1, 0)).cpu().numpy(), + ) sorted_indices = torch.tensor(sorted_indices, device=field_data.device) if "labels" in params: @@ -755,27 +762,28 @@ def apply( labels_name = params["labels"] original_labels = input_fields[labels_name].data original_labels_grid = PLAS.as_grid_img(original_labels) + updated_labels = inverse_sorted_indices[original_labels_grid] new_fields[labels_name] = Field(updated_labels, parentOp) decoding_update.append({ "input_fields": [labels_name], "transforms": [{"flatten": {"start_dim": 0, "end_dim": 1}}], }) + if "prune_by" in params: + sorted_indices = PLAS.as_grid_img(sorted_indices) case {"method": "plas"}: plas_cfg = {k: v for k, v in params.items() if k not in ["method"]} - - 
plas_cfg["to_field"] = params["to_field"] - sorted_indices = PLAS.plas_preprocess( + preprocess_dict = PLAS.plas_preprocess( plas_cfg=PLASConfig(**plas_cfg), fields=parentOp.input_fields, verbose=verbose, ) + sorted_indices = PLAS.sort(**preprocess_dict) case _: raise ValueError(f"Unknown Sort parameters: {params}") - # TODO: optional label-update new_fields[to_field_name] = Field(sorted_indices, parentOp) return new_fields, decoding_update @@ -792,6 +800,23 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: if params.get("method") == "plas": dynamic_params_config.extend(PLAS.get_dynamic_params(params)) + weight_config: list[dict[str, Any]] = [] + + for field_name in list(params["weights"].keys()): + weight_config.append({ + "label": field_name, + "type": "number", + "min": 0.0, + "max": 1.0, + "step": 0.05, + "dtype": float, + }) + dynamic_params_config.append({ + "label": "weights", + "type": "heading", + "params": weight_config, + }) + return dynamic_params_config @@ -828,8 +853,11 @@ def primitive_filter_pruning_to_square_shape(data: Tensor, verbose: bool) -> Ten return sorted_keep_indices @staticmethod - def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: bool) -> Tensor: - primitive_filter = PLAS.primitive_filter_pruning_to_square_shape(fields[plas_cfg.prune_by].data, verbose) + def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: bool) -> dict[str, Any]: + if plas_cfg.prune_by is not None: + primitive_filter = PLAS.primitive_filter_pruning_to_square_shape(fields[plas_cfg.prune_by].data, verbose) + else: + primitive_filter = None # TODO untested match plas_cfg.scaling_fn: @@ -861,11 +889,28 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo params_tensor = torch.cat(params_to_sort, dim=1) if plas_cfg.shuffle: - # TODO shuffling should be an option of sort_with_plas torch.manual_seed(42) shuffled_indices = torch.randperm(params_tensor.shape[0], 
device=params_tensor.device) params_tensor = params_tensor[shuffled_indices] + else: + shuffled_indices = None + + return { + "plas_cfg": plas_cfg, + "params_tensor": params_tensor, + "shuffled_indices": shuffled_indices, + "primitive_filter": primitive_filter, + "verbose": verbose, + } + @staticmethod + def sort( + plas_cfg: PLASConfig, + params_tensor: Tensor, + shuffled_indices: Tensor | None, + primitive_filter: Tensor | None, + verbose: bool, + ) -> Tensor: grid_to_sort = PLAS.as_grid_img(params_tensor).permute(2, 0, 1) _, sorted_indices_ret = sort_with_plas( grid_to_sort, improvement_break=float(plas_cfg.improvement_break), verbose=verbose @@ -873,7 +918,7 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo sorted_indices: Tensor = sorted_indices_ret.squeeze(0).to(params_tensor.device) - if plas_cfg.shuffle: + if plas_cfg.shuffle and shuffled_indices is not None: flat_indices = sorted_indices.flatten() unshuffled_flat_indices = shuffled_indices[flat_indices] sorted_indices = unshuffled_flat_indices.reshape(sorted_indices.shape) @@ -887,7 +932,6 @@ def plas_preprocess(plas_cfg: PLASConfig, fields: dict[str, Field], verbose: boo def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: """Get the dynamic parameters for a given transformation type. 
This might modify the values in params.""" - field_names = list(params["weights"].keys()) scaling_functions = ["standardize", "minmax", "none"] dynamic_params_config: list[dict[str, Any]] = [] @@ -915,22 +959,6 @@ def get_dynamic_params(params: dict[str, Any]) -> list[dict[str, Any]]: "inverse_mapping": lambda x: np.log10(x), }) - weight_config: list[dict[str, Any]] = [] - for field_name in field_names: - weight_config.append({ - "label": field_name, - "type": "number", - "min": 0.0, - "max": 1.0, - "step": 0.05, - "dtype": float, - }) - dynamic_params_config.append({ - "label": "weights", - "type": "heading", - "params": weight_config, - }) - return dynamic_params_config From c792675fbd8aee6211bf62c6f5ea2e9ab11e1f63 Mon Sep 17 00:00:00 2001 From: Wieland Morgenstern Date: Wed, 9 Jul 2025 17:16:49 +0200 Subject: [PATCH 24/24] rename SOG-canvas as SOG-PlayCanvas, and make it first selection in compression flavor dropdown --- src/ffsplat/cli/live.py | 4 ++-- .../conf/format/{SOG-canvas.yaml => SOG-PlayCanvas.yaml} | 9 +++++++-- src/ffsplat/render/viewer.py | 4 ++-- 3 files changed, 11 insertions(+), 6 deletions(-) rename src/ffsplat/conf/format/{SOG-canvas.yaml => SOG-PlayCanvas.yaml} (97%) diff --git a/src/ffsplat/cli/live.py b/src/ffsplat/cli/live.py index f4e8401..b6fc1d4 100644 --- a/src/ffsplat/cli/live.py +++ b/src/ffsplat/cli/live.py @@ -36,13 +36,13 @@ from ..render.viewer import CameraState, Viewer available_output_format: list[str] = [ + "SOG-PlayCanvas", + "SOG-web", "3DGS_INRIA_ply", "3DGS_INRIA_nosh_ply", - "SOG-web", "SOG-web-png", "SOG-web-nosh", "SOG-web-sh-split", - "SOG-canvas", ] diff --git a/src/ffsplat/conf/format/SOG-canvas.yaml b/src/ffsplat/conf/format/SOG-PlayCanvas.yaml similarity index 97% rename from src/ffsplat/conf/format/SOG-canvas.yaml rename to src/ffsplat/conf/format/SOG-PlayCanvas.yaml index ce6c4ee..77ef7e7 100644 --- a/src/ffsplat/conf/format/SOG-canvas.yaml +++ b/src/ffsplat/conf/format/SOG-PlayCanvas.yaml @@ -1,5 +1,5 @@ 
-profile: SOG-canvas -profile_version: 0.1 +profile: SOG-PlayCanvas +profile_version: 1.0 scene: primitives: 3DGS-INRIA @@ -164,6 +164,10 @@ ops: src_field: shN index_field: sorted_indices + # shN[sorted_indices] -> shN + + # new_blah[labels] -> labels + - input_fields: [sh0] transforms: - permute: @@ -178,6 +182,7 @@ ops: dim: 2 to_field: sh0 + # TODO: quantize sh0 and opacities with different ranges - input_fields: [sh0] transforms: - simple_quantize: diff --git a/src/ffsplat/render/viewer.py b/src/ffsplat/render/viewer.py index 21261f7..b3f1516 100644 --- a/src/ffsplat/render/viewer.py +++ b/src/ffsplat/render/viewer.py @@ -15,12 +15,12 @@ from ._renderer import Renderer, RenderTask available_output_format: list[str] = [ + "SOG-PlayCanvas", + "SOG-web", "3DGS_INRIA_ply", "3DGS_INRIA_nosh_ply", - "SOG-web", "SOG-web-nosh", "SOG-web-sh-split", - "SOG-canvas", ]