@@ -14,7 +14,7 @@ def _compute_vertices(dc_data_per_stack: DualContouringData,
1414 valid_edges_per_surface ) -> tuple [DualContouringData , Any ]:
1515 """Compute vertices for a specific surface."""
1616 valid_edges : np .ndarray = valid_edges_per_surface [surface_i ]
17-
17+
1818 slice_object = _surface_slicer (surface_i , valid_edges_per_surface )
1919
2020 dc_data_per_surface = DualContouringData (
@@ -27,19 +27,10 @@ def _compute_vertices(dc_data_per_stack: DualContouringData,
2727 tree_depth = dc_data_per_stack .tree_depth
2828 )
2929
30- vertices_numpy = _generate_vertices (dc_data_per_surface , debug , slice_object )
30+ vertices_numpy = generate_dual_contouring_vertices (dc_data_per_surface , slice_object , debug )
3131 return dc_data_per_surface , vertices_numpy
3232
3333
34- def _generate_vertices (dc_data_per_surface : DualContouringData , debug : bool , slice_object : slice ) -> Any :
35- vertices : np .ndarray = generate_dual_contouring_vertices (
36- dc_data_per_stack = dc_data_per_surface ,
37- slice_surface = slice_object ,
38- debug = debug
39- )
40- return vertices
41-
42-
4334def generate_dual_contouring_vertices (dc_data_per_stack : DualContouringData , slice_surface : Optional [slice ] = None , debug : bool = False ):
4435 # @off
4536 n_edges = dc_data_per_stack .n_valid_edges
@@ -48,75 +39,77 @@ def generate_dual_contouring_vertices(dc_data_per_stack: DualContouringData, sli
4839 if slice_surface is not None :
4940 xyz_on_edge = dc_data_per_stack .xyz_on_edge [slice_surface ]
5041 gradients = dc_data_per_stack .gradients [slice_surface ]
51- else :
42+ else :
5243 xyz_on_edge = dc_data_per_stack .xyz_on_edge
53- gradients = dc_data_per_stack .gradients
44+ gradients = dc_data_per_stack .gradients
5445 # @on
5546
56- # * Coordinates for all possible edges (12) and 3 dummy edges_normals in the center
57- edges_xyz = BackendTensor .tfnp .zeros ((n_edges , 15 , 3 ), dtype = BackendTensor .dtype_obj )
58- valid_edges = valid_edges > 0
59- edges_xyz [:, :12 ][valid_edges ] = xyz_on_edge
60-
61- # Normals
62- edges_normals = BackendTensor .tfnp .zeros ((n_edges , 15 , 3 ), dtype = BackendTensor .dtype_obj )
63- edges_normals [:, :12 ][valid_edges ] = gradients
64-
65- if OLD_METHOD := False :
66- # ! Moureze model does not seem to work with the new method
67- # ! This branch is all nans at least with ch1_1 model
68- bias_xyz = BackendTensor .tfnp .copy (edges_xyz [:, :12 ])
69- isclose = BackendTensor .tfnp .isclose (bias_xyz , 0 )
70- bias_xyz [isclose ] = BackendTensor .tfnp .nan # zero values to nans
71- mass_points = BackendTensor .tfnp .nanmean (bias_xyz , axis = 1 ) # Mean ignoring nans
72- else : # ? This is actually doing something
73- bias_xyz = BackendTensor .tfnp .copy (edges_xyz [:, :12 ])
74- if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
75- # PyTorch doesn't have masked arrays, so we'll use a different approach
76- mask = bias_xyz == 0
77- # Replace zeros with NaN for mean calculation
78- bias_xyz_masked = BackendTensor .tfnp .where (mask , float ('nan' ), bias_xyz )
79- mass_points = BackendTensor .tfnp .nanmean (bias_xyz_masked , axis = 1 )
80- else :
81- # NumPy approach with masked arrays
82- bias_xyz = BackendTensor .tfnp .to_numpy (bias_xyz )
83- import numpy as np
84- mask = bias_xyz == 0
85- masked_arr = np .ma .masked_array (bias_xyz , mask )
86- mass_points = masked_arr .mean (axis = 1 )
87- mass_points = BackendTensor .tfnp .array (mass_points )
88-
89- edges_xyz [:, 12 ] = mass_points
90- edges_xyz [:, 13 ] = mass_points
91- edges_xyz [:, 14 ] = mass_points
92-
93- BIAS_STRENGTH = 1
94-
95- bias_x = BackendTensor .tfnp .array ([BIAS_STRENGTH , 0 , 0 ], dtype = BackendTensor .dtype_obj )
96- bias_y = BackendTensor .tfnp .array ([0 , BIAS_STRENGTH , 0 ], dtype = BackendTensor .dtype_obj )
97- bias_z = BackendTensor .tfnp .array ([0 , 0 , BIAS_STRENGTH ], dtype = BackendTensor .dtype_obj )
47+ n_valid_voxels = BackendTensor .tfnp .sum (valid_voxels )
48+ edges_xyz = BackendTensor .tfnp .zeros ((n_valid_voxels , 15 , 3 ), dtype = BackendTensor .dtype_obj )
49+ edges_normals = BackendTensor .tfnp .zeros ((n_valid_voxels , 15 , 3 ), dtype = BackendTensor .dtype_obj )
50+
51+ # Filter valid_edges to only valid voxels
52+ valid_edges_bool = valid_edges [valid_voxels ] > 0
53+
54+ # Assign edge data (now only to valid voxels)
55+ edges_xyz [:, :12 ][valid_edges_bool ] = xyz_on_edge
56+ edges_normals [:, :12 ][valid_edges_bool ] = gradients
57+
58+ # Use nanmean directly without intermediate copy
59+ bias_xyz_slice = edges_xyz [:, :12 ]
60+
61+ if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
62+ mask = bias_xyz_slice == 0
63+ bias_xyz_masked = BackendTensor .tfnp .where (mask , float ('nan' ), bias_xyz_slice )
64+ mass_points = BackendTensor .tfnp .nanmean (bias_xyz_masked , axis = 1 )
65+ else :
66+ # NumPy: more efficient approach using sum and count
67+ mask = bias_xyz_slice != 0
68+ sum_valid = (bias_xyz_slice * mask ).sum (axis = 1 )
69+ count_valid = mask .sum (axis = 1 )
70+ # Avoid division by zero
71+ count_valid = BackendTensor .tfnp .maximum (count_valid , 1 )
72+ mass_points = sum_valid / count_valid
9873
99- edges_normals [:, 12 ] = bias_x
100- edges_normals [:, 13 ] = bias_y
101- edges_normals [:, 14 ] = bias_z
74+ # Assign mass points to bias positions
75+ edges_xyz [:, 12 :15 ] = mass_points [:, None , :]
10276
103- # Remove unused voxels
104- edges_xyz = edges_xyz [valid_voxels ]
105- edges_normals = edges_normals [valid_voxels ]
77+ BIAS_STRENGTH = 1
78+ bias_normals = BackendTensor .tfnp .array ([
79+ [BIAS_STRENGTH , 0 , 0 ],
80+ [0 , BIAS_STRENGTH , 0 ],
81+ [0 , 0 , BIAS_STRENGTH ]
82+ ], dtype = BackendTensor .dtype_obj )
83+
84+ edges_normals [:, 12 :15 ] = bias_normals [None , :, :]
10685
107- # Compute LSTSQS in all voxels at the same time
10886 A = edges_normals
109- b = ( A * edges_xyz ). sum ( axis = 2 )
110-
87+
88+ # Compute A^T @ A more efficiently
11189 if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
112- transpose_shape = (2 , 1 , 0 ) # For PyTorch: (batch, dim2, dim1)
90+ # For PyTorch: use bmm (batch matrix multiply) which is optimized
91+ A_T = A .transpose (1 , 2 )
92+ ATA = BackendTensor .tfnp .matmul (A_T , A ) # (n_voxels, 3, 3)
93+
94+ # Compute A^T @ (A * edges_xyz).sum(axis=2)
95+ b = (A * edges_xyz ).sum (axis = 2 ) # (n_voxels, 15)
96+ ATb = BackendTensor .tfnp .matmul (A_T , b .unsqueeze (- 1 )).squeeze (- 1 ) # (n_voxels, 3)
97+
98+ # Solve ATA @ x = ATb
99+ ATA_inv = BackendTensor .tfnp .linalg .inv (ATA )
100+ vertices = BackendTensor .tfnp .matmul (ATA_inv , ATb .unsqueeze (- 1 )).squeeze (- 1 )
113101 else :
114- transpose_shape = (0 , 2 , 1 ) # For NumPy: (batch, dim2, dim1)
115-
116- term1 = BackendTensor .tfnp .einsum ("ijk, ilj->ikl" , A , BackendTensor .tfnp .transpose (A , transpose_shape ))
117- term2 = BackendTensor .tfnp .linalg .inv (term1 )
118- term3 = BackendTensor .tfnp .einsum ("ijk,ik->ij" , BackendTensor .tfnp .transpose (A , transpose_shape ), b )
119- vertices = BackendTensor .tfnp .einsum ("ijk, ij->ik" , term2 , term3 )
102+ # NumPy: use efficient einsum
103+ b = (A * edges_xyz ).sum (axis = 2 )
104+
105+ # A^T @ A
106+ ATA = BackendTensor .tfnp .einsum ("ijk,ijl->ikl" , A , A )
107+ # A^T @ b
108+ ATb = BackendTensor .tfnp .einsum ("ijk,ij->ik" , A , b )
109+
110+ # Solve
111+ ATA_inv = BackendTensor .tfnp .linalg .inv (ATA )
112+ vertices = BackendTensor .tfnp .einsum ("ijk,ij->ik" , ATA_inv , ATb )
120113
121114 if debug :
122115 dc_data_per_stack .bias_center_mass = edges_xyz [:, 12 :].reshape (- 1 , 3 )
0 commit comments