Skip to content

Commit 98293b5

Browse files
authored
[REFACTOR] Optimize dual contouring vertex generation and add vertex overlap configuration (#36)
# Optimize Dual Contouring Vertex Generation and Add Configuration Flag This PR introduces several optimizations to the dual contouring vertex generation process: - Adds a new `DUAL_CONTOURING_VERTEX_OVERLAP` configuration flag to control vertex overlap behavior - Refactors the vertex generation code with significant performance optimizations: - Eliminates redundant memory copies and intermediate arrays - Improves matrix multiplication efficiency with backend-specific optimizations - Enhances voxel filtering logic to only process valid voxels - Optimizes mass point calculation with more efficient algorithms - Simplifies function interfaces by merging `_generate_vertices` into `generate_dual_contouring_vertices` - Improves vertex overlap handling with better parameter naming and conditional execution - Updates overlap logging to show only relevant keys instead of full data structures These changes maintain compatibility while improving performance and code clarity in the dual contouring implementation.
2 parents 5999fb1 + 158f5d5 commit 98293b5

File tree

6 files changed

+85
-90
lines changed

6 files changed

+85
-90
lines changed

gempy_engine/API/dual_contouring/multi_scalar_dual_contouring.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
import numpy as np
55

66
from gempy_engine.modules.dual_contouring._dual_contouring import compute_dual_contouring
7+
8+
from ...modules.dual_contouring._dual_contouring_v2 import compute_dual_contouring_v2
79
from ._experimental_water_tight_DC_1 import _experimental_water_tight
810
from ._mask_buffer import MaskBuffer
911
from ..interp_single.interp_features import interpolate_all_fields_no_octree
12+
from ...config import DUAL_CONTOURING_VERTEX_OVERLAP
1013
from ...core.backend_tensor import BackendTensor
1114
from ...core.data import InterpolationOptions
1215
from ...core.data.dual_contouring_data import DualContouringData
@@ -100,6 +103,8 @@ def dual_contouring_multi_scalar(
100103

101104
# endregion
102105

106+
compute_overlap = (len(all_left_right_codes) > 1) and DUAL_CONTOURING_VERTEX_OVERLAP
107+
103108
# region Vertex gen and triangulation
104109
left_right_per_mesh = []
105110
# Generate meshes for each scalar field
@@ -132,19 +137,18 @@ def dual_contouring_multi_scalar(
132137
)
133138

134139
dc_data_per_surface_all.append(dc_data_per_surface)
140+
if (compute_overlap):
141+
left_right_per_mesh.append(all_left_right_codes[n_scalar_field][dc_data_per_surface.valid_voxels])
135142

136-
from gempy_engine.modules.dual_contouring._dual_contouring_v2 import compute_dual_contouring_v2
137143
all_meshes = compute_dual_contouring_v2(
138144
dc_data_list=dc_data_per_surface_all,
139145
)
140146
# endregion
141-
if (options.debug or len(all_left_right_codes) > 1) and False:
142-
apply_faults_vertex_overlap(all_meshes, data_descriptor.stack_structure, left_right_per_mesh)
147+
if compute_overlap:
148+
apply_faults_vertex_overlap(all_meshes, data_descriptor.stack_structure, left_right_per_mesh)
143149

144150
return all_meshes
145151

146-
# ... existing code ...
147-
148152

149153
def _compute_meshes_legacy(all_left_right_codes: list[Any], all_mask_arrays: np.ndarray,
150154
all_meshes: list[DualContouringMesh], all_stack_intersection: list[Any],

gempy_engine/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class AvailableBackends(Flag):
3232
LINE_PROFILER_ENABLED = os.getenv('LINE_PROFILER_ENABLED', 'False') == 'True'
3333
SET_RAW_ARRAYS_IN_SOLUTION = os.getenv('SET_RAW_ARRAYS_IN_SOLUTION', 'True') == 'True'
3434
NOT_MAKE_INPUT_DEEP_COPY = os.getenv('NOT_MAKE_INPUT_DEEP_COPY', 'False') == 'True'
35+
DUAL_CONTOURING_VERTEX_OVERLAP = os.getenv('NOT_MAKE_INPUT_DEEP_COPY', 'False') == 'True'
3536

3637
is_numpy_installed = find_spec("numpy") is not None
3738
is_tensorflow_installed = find_spec("tensorflow") is not None

gempy_engine/modules/dual_contouring/_dual_contouring_v2.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import List
33

4-
from ._gen_vertices import _generate_vertices
4+
from ._gen_vertices import generate_dual_contouring_vertices
55
from ._parallel_triangulation import _should_use_parallel_processing, _init_worker
66
from ._sequential_triangulation import _compute_triangulation
77
from ... import optional_dependencies
@@ -130,8 +130,7 @@ def _process_surface_batch_v2(surface_indices, dc_data_dicts, left_right_codes):
130130

131131
return results
132132
def _process_one_surface(dc_data: DualContouringData, left_right_codes) -> DualContouringMesh:
133-
vertices = _generate_vertices(dc_data, False, None)
134-
133+
vertices = generate_dual_contouring_vertices(dc_data, slice_surface=None, debug=False)
135134
# * Average gradient for the edges
136135
valid_edges = dc_data.valid_edges
137136
edges_normals = BackendTensor.t.zeros((valid_edges.shape[0], 12, 3), dtype=BackendTensor.dtype_obj)

gempy_engine/modules/dual_contouring/_gen_vertices.py

Lines changed: 64 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def _compute_vertices(dc_data_per_stack: DualContouringData,
1414
valid_edges_per_surface) -> tuple[DualContouringData, Any]:
1515
"""Compute vertices for a specific surface."""
1616
valid_edges: np.ndarray = valid_edges_per_surface[surface_i]
17-
17+
1818
slice_object = _surface_slicer(surface_i, valid_edges_per_surface)
1919

2020
dc_data_per_surface = DualContouringData(
@@ -27,19 +27,10 @@ def _compute_vertices(dc_data_per_stack: DualContouringData,
2727
tree_depth=dc_data_per_stack.tree_depth
2828
)
2929

30-
vertices_numpy = _generate_vertices(dc_data_per_surface, debug, slice_object)
30+
vertices_numpy = generate_dual_contouring_vertices(dc_data_per_surface, slice_object, debug)
3131
return dc_data_per_surface, vertices_numpy
3232

3333

34-
def _generate_vertices(dc_data_per_surface: DualContouringData, debug: bool, slice_object: slice) -> Any:
35-
vertices: np.ndarray = generate_dual_contouring_vertices(
36-
dc_data_per_stack=dc_data_per_surface,
37-
slice_surface=slice_object,
38-
debug=debug
39-
)
40-
return vertices
41-
42-
4334
def generate_dual_contouring_vertices(dc_data_per_stack: DualContouringData, slice_surface: Optional[slice] = None, debug: bool = False):
4435
# @off
4536
n_edges = dc_data_per_stack.n_valid_edges
@@ -48,75 +39,77 @@ def generate_dual_contouring_vertices(dc_data_per_stack: DualContouringData, sli
4839
if slice_surface is not None:
4940
xyz_on_edge = dc_data_per_stack.xyz_on_edge[slice_surface]
5041
gradients = dc_data_per_stack.gradients[slice_surface]
51-
else:
42+
else:
5243
xyz_on_edge = dc_data_per_stack.xyz_on_edge
53-
gradients = dc_data_per_stack.gradients
44+
gradients = dc_data_per_stack.gradients
5445
# @on
5546

56-
# * Coordinates for all posible edges (12) and 3 dummy edges_normals in the center
57-
edges_xyz = BackendTensor.tfnp.zeros((n_edges, 15, 3), dtype=BackendTensor.dtype_obj)
58-
valid_edges = valid_edges > 0
59-
edges_xyz[:, :12][valid_edges] = xyz_on_edge
60-
61-
# Normals
62-
edges_normals = BackendTensor.tfnp.zeros((n_edges, 15, 3), dtype=BackendTensor.dtype_obj)
63-
edges_normals[:, :12][valid_edges] = gradients
64-
65-
if OLD_METHOD := False:
66-
# ! Moureze model does not seems to work with the new method
67-
# ! This branch is all nans at least with ch1_1 model
68-
bias_xyz = BackendTensor.tfnp.copy(edges_xyz[:, :12])
69-
isclose = BackendTensor.tfnp.isclose(bias_xyz, 0)
70-
bias_xyz[isclose] = BackendTensor.tfnp.nan # zero values to nans
71-
mass_points = BackendTensor.tfnp.nanmean(bias_xyz, axis=1) # Mean ignoring nans
72-
else: # ? This is actually doing something
73-
bias_xyz = BackendTensor.tfnp.copy(edges_xyz[:, :12])
74-
if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
75-
# PyTorch doesn't have masked arrays, so we'll use a different approach
76-
mask = bias_xyz == 0
77-
# Replace zeros with NaN for mean calculation
78-
bias_xyz_masked = BackendTensor.tfnp.where(mask, float('nan'), bias_xyz)
79-
mass_points = BackendTensor.tfnp.nanmean(bias_xyz_masked, axis=1)
80-
else:
81-
# NumPy approach with masked arrays
82-
bias_xyz = BackendTensor.tfnp.to_numpy(bias_xyz)
83-
import numpy as np
84-
mask = bias_xyz == 0
85-
masked_arr = np.ma.masked_array(bias_xyz, mask)
86-
mass_points = masked_arr.mean(axis=1)
87-
mass_points = BackendTensor.tfnp.array(mass_points)
88-
89-
edges_xyz[:, 12] = mass_points
90-
edges_xyz[:, 13] = mass_points
91-
edges_xyz[:, 14] = mass_points
92-
93-
BIAS_STRENGTH = 1
94-
95-
bias_x = BackendTensor.tfnp.array([BIAS_STRENGTH, 0, 0], dtype=BackendTensor.dtype_obj)
96-
bias_y = BackendTensor.tfnp.array([0, BIAS_STRENGTH, 0], dtype=BackendTensor.dtype_obj)
97-
bias_z = BackendTensor.tfnp.array([0, 0, BIAS_STRENGTH], dtype=BackendTensor.dtype_obj)
47+
n_valid_voxels = BackendTensor.tfnp.sum(valid_voxels)
48+
edges_xyz = BackendTensor.tfnp.zeros((n_valid_voxels, 15, 3), dtype=BackendTensor.dtype_obj)
49+
edges_normals = BackendTensor.tfnp.zeros((n_valid_voxels, 15, 3), dtype=BackendTensor.dtype_obj)
50+
51+
# Filter valid_edges to only valid voxels
52+
valid_edges_bool = valid_edges[valid_voxels] > 0
53+
54+
# Assign edge data (now only to valid voxels)
55+
edges_xyz[:, :12][valid_edges_bool] = xyz_on_edge
56+
edges_normals[:, :12][valid_edges_bool] = gradients
57+
58+
# Use nanmean directly without intermediate copy
59+
bias_xyz_slice = edges_xyz[:, :12]
60+
61+
if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
62+
mask = bias_xyz_slice == 0
63+
bias_xyz_masked = BackendTensor.tfnp.where(mask, float('nan'), bias_xyz_slice)
64+
mass_points = BackendTensor.tfnp.nanmean(bias_xyz_masked, axis=1)
65+
else:
66+
# NumPy: more efficient approach using sum and count
67+
mask = bias_xyz_slice != 0
68+
sum_valid = (bias_xyz_slice * mask).sum(axis=1)
69+
count_valid = mask.sum(axis=1)
70+
# Avoid division by zero
71+
count_valid = BackendTensor.tfnp.maximum(count_valid, 1)
72+
mass_points = sum_valid / count_valid
9873

99-
edges_normals[:, 12] = bias_x
100-
edges_normals[:, 13] = bias_y
101-
edges_normals[:, 14] = bias_z
74+
# Assign mass points to bias positions
75+
edges_xyz[:, 12:15] = mass_points[:, None, :]
10276

103-
# Remove unused voxels
104-
edges_xyz = edges_xyz[valid_voxels]
105-
edges_normals = edges_normals[valid_voxels]
77+
BIAS_STRENGTH = 1
78+
bias_normals = BackendTensor.tfnp.array([
79+
[BIAS_STRENGTH, 0, 0],
80+
[0, BIAS_STRENGTH, 0],
81+
[0, 0, BIAS_STRENGTH]
82+
], dtype=BackendTensor.dtype_obj)
83+
84+
edges_normals[:, 12:15] = bias_normals[None, :, :]
10685

107-
# Compute LSTSQS in all voxels at the same time
10886
A = edges_normals
109-
b = (A * edges_xyz).sum(axis=2)
110-
87+
88+
# Compute A^T @ A more efficiently
11189
if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
112-
transpose_shape = (2, 1, 0) # For PyTorch: (batch, dim2, dim1)
90+
# For PyTorch: use bmm (batch matrix multiply) which is optimized
91+
A_T = A.transpose(1, 2)
92+
ATA = BackendTensor.tfnp.matmul(A_T, A) # (n_voxels, 3, 3)
93+
94+
# Compute A^T @ (A * edges_xyz).sum(axis=2)
95+
b = (A * edges_xyz).sum(axis=2) # (n_voxels, 15)
96+
ATb = BackendTensor.tfnp.matmul(A_T, b.unsqueeze(-1)).squeeze(-1) # (n_voxels, 3)
97+
98+
# Solve ATA @ x = ATb
99+
ATA_inv = BackendTensor.tfnp.linalg.inv(ATA)
100+
vertices = BackendTensor.tfnp.matmul(ATA_inv, ATb.unsqueeze(-1)).squeeze(-1)
113101
else:
114-
transpose_shape = (0, 2, 1) # For NumPy: (batch, dim2, dim1)
115-
116-
term1 = BackendTensor.tfnp.einsum("ijk, ilj->ikl", A, BackendTensor.tfnp.transpose(A, transpose_shape))
117-
term2 = BackendTensor.tfnp.linalg.inv(term1)
118-
term3 = BackendTensor.tfnp.einsum("ijk,ik->ij", BackendTensor.tfnp.transpose(A, transpose_shape), b)
119-
vertices = BackendTensor.tfnp.einsum("ijk, ij->ik", term2, term3)
102+
# NumPy: use efficient einsum
103+
b = (A * edges_xyz).sum(axis=2)
104+
105+
# A^T @ A
106+
ATA = BackendTensor.tfnp.einsum("ijk,ijl->ikl", A, A)
107+
# A^T @ b
108+
ATb = BackendTensor.tfnp.einsum("ijk,ij->ik", A, b)
109+
110+
# Solve
111+
ATA_inv = BackendTensor.tfnp.linalg.inv(ATA)
112+
vertices = BackendTensor.tfnp.einsum("ijk,ij->ik", ATA_inv, ATb)
120113

121114
if debug:
122115
dc_data_per_stack.bias_center_mass = edges_xyz[:, 12:].reshape(-1, 3)

gempy_engine/modules/dual_contouring/_vertex_overlap.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ def _apply_fault_relations_to_overlaps(
3535

3636
if overlap_key in voxel_overlaps:
3737
_apply_vertex_sharing(
38-
all_meshes,
39-
origin_stack,
40-
surface_n,
41-
voxel_overlaps[overlap_key]
38+
all_meshes=all_meshes,
39+
origin_mesh_idx=origin_stack,
40+
destination_mesh_idx=surface_n,
41+
overlap_data=voxel_overlaps[overlap_key]
4242
)
4343

4444

@@ -135,9 +135,10 @@ def _find_overlaps_between_stacks(
135135
for i in range(len(stack_codes)):
136136
for j in range(i + 1, len(stack_codes)):
137137
overlap_data = _process_stack_pair(
138-
stack_codes[i], stack_codes[j],
139-
all_left_right_codes[i], all_left_right_codes[j],
140-
i, j
138+
codes_i=stack_codes[i],
139+
codes_j=stack_codes[j],
140+
left_right_i=all_left_right_codes[i],
141+
left_right_j=all_left_right_codes[j],
141142
)
142143

143144
if overlap_data:
@@ -151,8 +152,6 @@ def _process_stack_pair(
151152
codes_j: np.ndarray,
152153
left_right_i: np.ndarray,
153154
left_right_j: np.ndarray,
154-
stack_i: int,
155-
stack_j: int
156155
) -> Optional[dict]:
157156
"""Process a pair of stacks to find overlapping voxels."""
158157
if codes_i.size == 0 or codes_j.size == 0:

gempy_engine/modules/dual_contouring/dual_contouring_interface.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from ...core.backend_tensor import BackendTensor
99
from ...core.data import InterpolationOptions
1010
from ...core.data.dual_contouring_mesh import DualContouringMesh
11-
from ...core.data.input_data_descriptor import InputDataDescriptor
1211
from ...core.data.interp_output import InterpOutput
1312
from ...core.data.octree_level import OctreeLevel
1413
from ...core.data.options import MeshExtractionMaskingOptions
@@ -205,5 +204,5 @@ def apply_faults_vertex_overlap(all_meshes: list[DualContouringMesh],
205204
voxel_overlaps = find_repeated_voxels_across_stacks(left_right_per_mesh)
206205

207206
if voxel_overlaps:
208-
print(f"Found voxel overlaps between stacks: {voxel_overlaps}")
207+
print(f"Found voxel overlaps between stacks: {voxel_overlaps.keys()}")
209208
_apply_fault_relations_to_overlaps(all_meshes, voxel_overlaps, stack_structure)

0 commit comments

Comments
 (0)