
Commit c24e3e4

[WIP] Trying multiprocess
1 parent 587096a commit c24e3e4

File tree

3 files changed, +368 -74 lines changed
Lines changed: 92 additions & 74 deletions
@@ -1,6 +1,8 @@
+
 import numpy
 import warnings
 from typing import List
+import os
 
 import numpy as np
 
@@ -11,95 +13,62 @@
 from ...core.data.dual_contouring_data import DualContouringData
 from ...core.data.dual_contouring_mesh import DualContouringMesh
 from ...core.utils import gempy_profiler_decorator
+from ...modules.dual_contouring._parallel_triangulation import _should_use_parallel_processing, _process_surface_batch, _init_worker
+from ...modules.dual_contouring._sequential_triangulation import _sequential_triangulation
 from ...modules.dual_contouring.dual_contouring_interface import triangulate_dual_contouring, generate_dual_contouring_vertices
 from ...modules.dual_contouring.fancy_triangulation import triangulate
 
+# Multiprocessing imports
+try:
+    import torch.multiprocessing as mp
+    MULTIPROCESSING_AVAILABLE = True
+except ImportError:
+    import multiprocessing as mp
+    MULTIPROCESSING_AVAILABLE = False
+
+
+
 
 @gempy_profiler_decorator
 def compute_dual_contouring(dc_data_per_stack: DualContouringData, left_right_codes=None, debug: bool = False) -> List[DualContouringMesh]:
     valid_edges_per_surface = dc_data_per_stack.valid_edges.reshape((dc_data_per_stack.n_surfaces_to_export, -1, 12))
 
-    # ? Is there a way to cut also the vertices?
-
+    # Check if we should use parallel processing
+    use_parallel = _should_use_parallel_processing(dc_data_per_stack.n_surfaces_to_export, BackendTensor.engine_backend)
+
+    if use_parallel:
+        print(f"Using parallel processing for {dc_data_per_stack.n_surfaces_to_export} surfaces")
+        parallel_results = _parallel_process_surfaces(dc_data_per_stack, left_right_codes, debug)
+
+        if parallel_results is not None:
+            # Convert parallel results to DualContouringMesh objects
+            stack_meshes = []
+            for vertices_numpy, indices_numpy in parallel_results:
+                if TRIMESH_LAST_PASS := True:
+                    vertices_numpy, indices_numpy = _last_pass(vertices_numpy, indices_numpy)
+
+                stack_meshes.append(
+                    DualContouringMesh(
+                        vertices_numpy,
+                        indices_numpy,
+                        dc_data_per_stack
+                    )
+                )
+            return stack_meshes
+
+    # Fall back to sequential processing
+    print(f"Using sequential processing for {dc_data_per_stack.n_surfaces_to_export} surfaces")
     stack_meshes: List[DualContouringMesh] = []
 
     last_surface_edge_idx = 0
     for i in range(dc_data_per_stack.n_surfaces_to_export):
         # @off
-        valid_edges : np.ndarray = valid_edges_per_surface[i]
-        next_surface_edge_idx: int = valid_edges.sum() + last_surface_edge_idx
-        slice_object : slice = slice(last_surface_edge_idx, next_surface_edge_idx)
-        last_surface_edge_idx: int = next_surface_edge_idx
-
-        dc_data_per_surface = DualContouringData(
-            xyz_on_edge = dc_data_per_stack.xyz_on_edge,
-            valid_edges = valid_edges,
-            xyz_on_centers = dc_data_per_stack.xyz_on_centers,
-            dxdydz = dc_data_per_stack.dxdydz,
-            exported_fields_on_edges = dc_data_per_stack.exported_fields_on_edges,
-            n_surfaces_to_export = dc_data_per_stack.n_surfaces_to_export,
-            tree_depth = dc_data_per_stack.tree_depth
+        indices_numpy, vertices_numpy = _sequential_triangulation(dc_data_per_stack, debug, i, last_surface_edge_idx, left_right_codes, valid_edges_per_surface)
 
-        )
-        vertices: np.ndarray = generate_dual_contouring_vertices(
-            dc_data_per_stack = dc_data_per_surface,
-            slice_surface = slice_object,
-            debug = debug
-        )
-
-        if left_right_codes is None:
-            # * Legacy triangulation
-            indices = triangulate_dual_contouring(dc_data_per_surface)
-        else:
-            # * Fancy triangulation 👗
-
-            # * Average gradient for the edges
-            edges_normals = BackendTensor.t.zeros((valid_edges.shape[0], 12, 3), dtype=BackendTensor.dtype_obj)
-            edges_normals[:] = np.nan
-            edges_normals[valid_edges] = dc_data_per_stack.gradients[slice_object]
-
-            # if LEGACY:=True:
-            if BackendTensor.engine_backend != AvailableBackends.PYTORCH:
-                with warnings.catch_warnings():
-                    warnings.simplefilter("ignore", category=RuntimeWarning)
-                    voxel_normal = np.nanmean(edges_normals, axis=1)
-                    voxel_normal = voxel_normal[(~np.isnan(voxel_normal).any(axis=1))]  # drop nans
-                    pass
-            else:
-                # Assuming edges_normals is a PyTorch tensor
-                nan_mask = BackendTensor.t.isnan(edges_normals)
-                valid_count = (~nan_mask).sum(dim=1)
-
-                # Replace NaNs with 0 for sum calculation
-                safe_normals = edges_normals.clone()
-                safe_normals[nan_mask] = 0
-
-                # Compute the sum of non-NaN elements
-                sum_normals = BackendTensor.t.sum(safe_normals, 1)
-
-                # Calculate the mean, avoiding division by zero
-                voxel_normal = sum_normals / valid_count.clamp(min=1)
-
-                # Remove rows where all elements were NaN (and hence valid_count is 0)
-                voxel_normal = voxel_normal[valid_count > 0].reshape(-1, 3)
-
-
-            valid_voxels = dc_data_per_surface.valid_voxels
-            indices = triangulate(
-                left_right_array = left_right_codes[valid_voxels],
-                valid_edges = dc_data_per_surface.valid_edges[valid_voxels],
-                tree_depth = dc_data_per_surface.tree_depth,
-                voxel_normals = voxel_normal
-            )
-            indices = BackendTensor.t.concatenate(indices, axis=0)
-
-        # @on
-        vertices_numpy = BackendTensor.t.to_numpy(vertices)
-        indices_numpy = BackendTensor.t.to_numpy(indices)
-
         if TRIMESH_LAST_PASS := True:
             vertices_numpy, indices_numpy = _last_pass(vertices_numpy, indices_numpy)
-
+
         stack_meshes.append(
             DualContouringMesh(
                 vertices_numpy,
@@ -110,6 +79,55 @@ def compute_dual_contouring(dc_data_per_stack: DualContouringData, left_right_co
     return stack_meshes
 
 
+
+
+def _parallel_process_surfaces(dc_data_per_stack, left_right_codes, debug, num_workers=None, chunk_size=2):
+    """Process surfaces in parallel using multiprocessing."""
+    if num_workers is None:
+        num_workers = max(1, min(os.cpu_count() // 2, dc_data_per_stack.n_surfaces_to_export // 2))
+
+    # Prepare data for serialization
+    dc_data_dict = {
+        'xyz_on_edge'             : dc_data_per_stack.xyz_on_edge,
+        'valid_edges'             : dc_data_per_stack.valid_edges,
+        'xyz_on_centers'          : dc_data_per_stack.xyz_on_centers,
+        'dxdydz'                  : dc_data_per_stack.dxdydz,
+        'exported_fields_on_edges': dc_data_per_stack.exported_fields_on_edges,
+        'n_surfaces_to_export'    : dc_data_per_stack.n_surfaces_to_export,
+        'tree_depth'              : dc_data_per_stack.tree_depth,
+        # 'gradients': getattr(dc_data_per_stack, 'gradients', None)
+    }
+
+    # Create surface index chunks
+    surface_indices = list(range(dc_data_per_stack.n_surfaces_to_export))
+    chunks = [surface_indices[i:i + chunk_size] for i in range(0, len(surface_indices), chunk_size)]
+
+    try:
+        # Use spawn context for better PyTorch compatibility
+        ctx = mp.get_context("spawn") if MULTIPROCESSING_AVAILABLE else mp
+
+        with ctx.Pool(processes=num_workers, initializer=_init_worker) as pool:
+            # Submit all chunks
+            async_results = []
+            for chunk in chunks:
+                result = pool.apply_async(
+                    _process_surface_batch,
+                    (chunk, dc_data_dict, left_right_codes, debug)
+                )
+                async_results.append(result)
+
+            # Collect results
+            all_results = []
+            for async_result in async_results:
+                batch_results = async_result.get()
+                all_results.extend(batch_results)
+
+            return all_results
+
+    except Exception as e:
+        print(f"Parallel processing failed: {e}. Falling back to sequential processing.")
+        return None
+
 def _last_pass(vertices, indices):
     # Check if trimesh is available
     try:
@@ -118,4 +136,4 @@ def _last_pass(vertices, indices):
         mesh.fill_holes()
         return mesh.vertices, mesh.faces
     except ImportError:
-        return vertices, indices
+        return vertices, indices
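Note on the dispatch pattern: `_parallel_process_surfaces` above is a chunked fan-out over a process pool, using the spawn start method and `apply_async`. Below is a minimal, self-contained sketch of that same pattern; the `_square_batch` worker and the `payload` dict are placeholders standing in for `_process_surface_batch` and `dc_data_dict`, not part of this commit.

# Minimal sketch of chunked apply_async dispatch with a spawn pool.
# Worker function and payload are illustrative placeholders.
import multiprocessing as mp
import os


def _square_batch(indices, payload):
    # Stand-in for _process_surface_batch: handle one chunk of indices.
    return [(i, payload["scale"] * i * i) for i in indices]


def run_chunked(n_items=10, chunk_size=2, num_workers=None):
    if num_workers is None:
        num_workers = max(1, min(os.cpu_count() // 2, n_items // 2))

    payload = {"scale": 3}  # picklable stand-in for the serialized data dict
    indices = list(range(n_items))
    chunks = [indices[i:i + chunk_size] for i in range(0, len(indices), chunk_size)]

    ctx = mp.get_context("spawn")  # spawn: fresh interpreter per worker
    with ctx.Pool(processes=num_workers) as pool:
        async_results = [pool.apply_async(_square_batch, (chunk, payload)) for chunk in chunks]
        results = []
        for r in async_results:
            results.extend(r.get())  # .get() re-raises worker exceptions in the parent
    return results


if __name__ == "__main__":  # required under the spawn start method
    print(run_chunked())

Because `.get()` re-raises worker exceptions in the parent, the `except Exception` in `_parallel_process_surfaces` can catch them and return `None`, which is what triggers the sequential fallback in `compute_dual_contouring`.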
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
+import numpy as np
+import os
+import warnings
+
+from gempy_engine.config import AvailableBackends
+from ...core.backend_tensor import BackendTensor
+from ...core.data.dual_contouring_data import DualContouringData
+from ...modules.dual_contouring.dual_contouring_interface import triangulate_dual_contouring
+from ...modules.dual_contouring.fancy_triangulation import triangulate
+
+# Multiprocessing imports
+try:
+    import torch.multiprocessing as mp
+    MULTIPROCESSING_AVAILABLE = True
+except ImportError:
+    import multiprocessing as mp
+    MULTIPROCESSING_AVAILABLE = False
+
+
+def _should_use_parallel_processing(n_surfaces: int, backend: AvailableBackends) -> bool:
+    """Determine if parallel processing should be used."""
+    # Only use parallel processing for PyTorch CPU backend with sufficient surfaces
+    if backend == AvailableBackends.PYTORCH and MULTIPROCESSING_AVAILABLE:
+        # Check if we're on CPU (not GPU)
+        try:
+            import torch
+            if torch.cuda.is_available():
+                # If CUDA is available, check if default tensor type is CPU
+                dummy = BackendTensor.t.zeros(1)
+                is_cpu = dummy.device.type == 'cpu' if hasattr(dummy, 'device') else True
+            else:
+                is_cpu = True
+
+            # Use parallel processing if we have CPU tensors and enough surfaces to justify overhead
+            return is_cpu and n_surfaces >= 4
+        except ImportError:
+            return False
+    return False
+
+
+def _init_worker():
+    """Initialize worker process to avoid thread oversubscription."""
+    # Set environment variables for NumPy/OpenMP/MKL
+    os.environ['OMP_NUM_THREADS'] = '1'
+    os.environ['MKL_NUM_THREADS'] = '1'
+    os.environ['OPENBLAS_NUM_THREADS'] = '1'
+    os.environ['NUMEXPR_NUM_THREADS'] = '1'
+
+    # For PyTorch, set environment variables before import
+    os.environ['TORCH_NUM_THREADS'] = '1'
+    os.environ['TORCH_NUM_INTEROP_THREADS'] = '1'
+
+    # Now import torch in the worker process
+    try:
+        import torch
+        # These calls might still work if torch hasn't done any parallel work yet in this process
+        try:
+            torch.set_num_threads(1)
+            torch.set_num_interop_threads(1)
+        except RuntimeError:
+            # If the above fails, the environment variables should handle it
+            pass
+    except ImportError:
+        pass
+
+
+def _process_surface_batch(surface_indices_batch, dc_data_dict, left_right_codes, debug):
+    """Process a batch of surfaces in a worker process."""
+    _init_worker()
+
+    # Reconstruct dc_data_per_stack from dictionary
+    dc_data_per_stack = DualContouringData(**dc_data_dict)
+    valid_edges_per_surface = dc_data_per_stack.valid_edges.reshape((dc_data_per_stack.n_surfaces_to_export, -1, 12))
+
+    batch_results = []
+
+    for i in surface_indices_batch:
+        result = _process_single_surface(
+            i, dc_data_per_stack, valid_edges_per_surface, left_right_codes, debug
+        )
+        batch_results.append(result)
+
+    return batch_results
+
+def _process_single_surface(i, dc_data_per_stack, valid_edges_per_surface, left_right_codes, debug):
+    """Process a single surface and return vertices and indices."""
+    try:
+        valid_edges = valid_edges_per_surface[i]
+
+        # Calculate edge indices for this surface
+        last_surface_edge_idx = sum(valid_edges_per_surface[j].sum() for j in range(i))
+        next_surface_edge_idx = valid_edges.sum() + last_surface_edge_idx
+        slice_object = slice(last_surface_edge_idx, next_surface_edge_idx)
+
+        dc_data_per_surface = DualContouringData(
+            xyz_on_edge=dc_data_per_stack.xyz_on_edge,
+            valid_edges=valid_edges,
+            xyz_on_centers=dc_data_per_stack.xyz_on_centers,
+            dxdydz=dc_data_per_stack.dxdydz,
+            exported_fields_on_edges=dc_data_per_stack.exported_fields_on_edges,
+            n_surfaces_to_export=dc_data_per_stack.n_surfaces_to_export,
+            tree_depth=dc_data_per_stack.tree_depth
+        )
+
+        print(f"DEBUG: Processing surface {i}")
+
+        if left_right_codes is None:
+            # Legacy triangulation
+            indices = triangulate_dual_contouring(dc_data_per_surface)
+        else:
+            # Fancy triangulation
+            print(f"DEBUG: Creating edges_normals tensor")
+
+            # Check BackendTensor.dtype_obj
+            print(f"DEBUG: BackendTensor.dtype_obj = {BackendTensor.dtype_obj}")
+
+            edges_normals = BackendTensor.t.zeros((valid_edges.shape[0], 12, 3), dtype=BackendTensor.dtype_obj)
+            print(f"DEBUG: edges_normals dtype: {edges_normals.dtype if hasattr(edges_normals, 'dtype') else 'No dtype attr'}")
+
+            # Set to NaN - this might be where the error occurs
+            print(f"DEBUG: Setting edges_normals to NaN")
+            if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
+                edges_normals[:] = float('nan')  # Use Python float nan instead of np.nan
+            else:
+                edges_normals[:] = np.nan
+
+            # Get gradient data
+            print(f"DEBUG: Getting gradient data")
+            gradient_data = dc_data_per_stack.gradients[slice_object]
+            print(f"DEBUG: gradient_data shape: {gradient_data.shape}, dtype: {gradient_data.dtype if hasattr(gradient_data, 'dtype') else 'No dtype attr'}")
+
+            # Fix dtype mismatch by ensuring compatible dtypes
+            if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
+                if hasattr(gradient_data, 'dtype') and hasattr(edges_normals, 'dtype'):
+                    print(f"DEBUG: Comparing dtypes - edges_normals: {edges_normals.dtype}, gradient_data: {gradient_data.dtype}")
+                    if gradient_data.dtype != edges_normals.dtype:
+                        print(f"DEBUG: Converting gradient_data from {gradient_data.dtype} to {edges_normals.dtype}")
+                        gradient_data = gradient_data.to(edges_normals.dtype)
+
+            print(f"DEBUG: Assigning gradient data to edges_normals")
+            print(f"DEBUG: valid_edges shape: {valid_edges.shape}, sum: {valid_edges.sum()}")
+            edges_normals[valid_edges] = gradient_data
+
+            if BackendTensor.engine_backend != AvailableBackends.PYTORCH:
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore", category=RuntimeWarning)
+                    voxel_normal = np.nanmean(edges_normals, axis=1)
+                    voxel_normal = voxel_normal[(~np.isnan(voxel_normal).any(axis=1))]
+            else:
+                print(f"DEBUG: Computing voxel normals with PyTorch")
+                # PyTorch tensor operations
+                nan_mask = BackendTensor.t.isnan(edges_normals)
+                valid_count = (~nan_mask).sum(dim=1)
+                safe_normals = edges_normals.clone()
+                safe_normals[nan_mask] = 0
+                sum_normals = BackendTensor.t.sum(safe_normals, 1)
+                voxel_normal = sum_normals / valid_count.clamp(min=1)
+                voxel_normal = voxel_normal[valid_count > 0].reshape(-1, 3)
+                print(f"DEBUG: voxel_normal shape: {voxel_normal.shape}, dtype: {voxel_normal.dtype}")
+
+            valid_voxels = dc_data_per_surface.valid_voxels
+            left_right_per_surface = left_right_codes[valid_voxels]
+            valid_voxels_per_surface = dc_data_per_surface.valid_edges[valid_voxels]
+            voxel_normal_per_surface = voxel_normal[valid_voxels]
+            tree_depth_per_surface = dc_data_per_surface.tree_depth
+
+            print(f"DEBUG: Calling triangulate function")
+            indices = triangulate(
+                left_right_array=left_right_per_surface,
+                valid_edges=valid_voxels_per_surface,
+                tree_depth=tree_depth_per_surface,
+                voxel_normals=voxel_normal
+            )
+            print(f"DEBUG: triangulate returned, concatenating indices")
+            indices = BackendTensor.t.concatenate(indices, axis=0)
+
+        print(f"DEBUG: Converting to numpy")
+        # vertices_numpy = BackendTensor.t.to_numpy(vertices)
+        indices_numpy = BackendTensor.t.to_numpy(indices)
+
+        print(f"DEBUG: Successfully processed surface {i}")
+        return indices_numpy
+
+    except Exception as e:
+        print(f"ERROR in _process_single_surface for surface {i}: {e}")
+        import traceback
+        traceback.print_exc()
+        raise
+
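Note on the voxel-normal averaging: missing edge normals are stored as NaN, so the PyTorch branch of `_process_single_surface` re-implements a NaN-aware mean (mask, zero-fill, sum, divide by clamped count). A standalone sketch of that computation, assuming torch is installed and using made-up shapes (2 voxels x 12 edges x 3 components):

# Standalone sketch of the NaN-aware mean used for voxel normals (PyTorch branch).
import torch

edges_normals = torch.full((2, 12, 3), float('nan'))
edges_normals[0, 0] = torch.tensor([1.0, 0.0, 0.0])  # voxel 0 has two valid edge normals
edges_normals[0, 3] = torch.tensor([0.0, 1.0, 0.0])
# voxel 1 keeps all NaNs and should be dropped

nan_mask = torch.isnan(edges_normals)
valid_count = (~nan_mask).sum(dim=1)                 # per-component count of valid entries
safe = edges_normals.clone()
safe[nan_mask] = 0.0                                  # zero-fill so the sum ignores NaNs
voxel_normal = safe.sum(dim=1) / valid_count.clamp(min=1)
voxel_normal = voxel_normal[valid_count > 0].reshape(-1, 3)  # drop all-NaN voxels

print(voxel_normal)  # tensor([[0.5000, 0.5000, 0.0000]])

On recent PyTorch versions, torch.nanmean(edges_normals, dim=1) computes the same per-voxel mean, but it keeps all-NaN rows as NaN instead of dropping them, so the explicit mask-and-filter version above matches the NumPy branch more closely.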

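Note on the per-surface edge offsets: the sequential loop in compute_dual_contouring carries `last_surface_edge_idx` across iterations, while `_process_single_surface` recomputes the offset of surface i from scratch as the sum of valid-edge counts of all earlier surfaces, so each surface can be handled independently in a worker. A tiny NumPy check with made-up data showing that both routes produce the same slices into the flat edge arrays:

# Toy check: running offset vs. recomputed offset give identical slices.
import numpy as np

rng = np.random.default_rng(0)
valid_edges_per_surface = rng.random((3, 4, 12)) > 0.7  # 3 surfaces, 4 voxels, 12 edge flags

# Sequential style: carry the running offset across iterations.
seq_slices, last = [], 0
for i in range(3):
    nxt = int(valid_edges_per_surface[i].sum()) + last
    seq_slices.append(slice(last, nxt))
    last = nxt

# Parallel style: recompute the offset of surface i from scratch.
par_slices = []
for i in range(3):
    last = int(sum(valid_edges_per_surface[j].sum() for j in range(i)))
    nxt = int(valid_edges_per_surface[i].sum()) + last
    par_slices.append(slice(last, nxt))

assert seq_slices == par_slices
print(par_slices)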