176 changes: 133 additions & 43 deletions simpeg_drivers/plate_simulation/match/driver.py
@@ -22,10 +22,11 @@
from geoapps_utils.utils.logger import get_logger
from geoapps_utils.utils.numerical import inverse_weighted_operator
from geoapps_utils.utils.plotting import symlog
from geoapps_utils.utils.transformations import cartesian_to_polar
from geoapps_utils.utils.transformations import cartesian_to_polar, rotate_xyz
from geoh5py import Workspace
from geoh5py.groups import PropertyGroup, SimPEGGroup
from geoh5py.objects import AirborneTEMReceivers, Surface
from geoh5py.objects import AirborneTEMReceivers, MaxwellPlate, Surface
from geoh5py.objects.maxwell_plate import PlateGeometry
from geoh5py.ui_json import InputFile
from scipy import signal
from scipy.sparse import csr_matrix
@@ -34,7 +35,7 @@

from simpeg_drivers.driver import BaseDriver
from simpeg_drivers.plate_simulation.match.options import PlateMatchOptions
from simpeg_drivers.plate_simulation.options import PlateSimulationOptions
from simpeg_drivers.plate_simulation.options import ModelOptions, PlateSimulationOptions


logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False)
@@ -123,6 +124,41 @@ def start(cls, filepath: str | Path, mode="r+", **_) -> Self:

return driver

def _create_plate_from_parameters(
self, index_center: int, model_options: ModelOptions, strike_angle: float
) -> MaxwellPlate:
center = self.params.survey.vertices[index_center]
center[2] = (
self._drape_heights[index_center] - model_options.overburden_model.thickness
)
indices = self.params.survey.get_segment_indices(
index_center, self.params.max_distance
)
segment = self.params.survey.vertices[indices]
delta = np.median(np.diff(segment, axis=0), axis=0)
azimuth = 90 - np.rad2deg(np.arctan2(delta[1], delta[0]))

plate_geometry = PlateGeometry.model_validate(
{
"position": {
"x": center[0],
"y": center[1],
"z": center[2],
},
"width": model_options.plate_model.dip_length,
"thickness": model_options.plate_model.width,
"length": model_options.plate_model.strike_length,
"dip": model_options.plate_model.dip,
"dip_direction": (azimuth + strike_angle) % 360,
}
)
plate = MaxwellPlate.create(
self.params.geoh5, geometry=plate_geometry, parent=self.params.out_group
)
plate.metadata = model_options.model_dump()

return plate
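As a reference for the convention assumed in `_create_plate_from_parameters`, here is a toy sketch (with made-up segment coordinates) of how the along-line step is converted from a counter-clockwise-from-east angle to a clockwise-from-north azimuth:

```python
import numpy as np

# Made-up vertices for a flight line heading north-east.
segment = np.array([[0.0, 0.0], [10.0, 10.0], [20.0, 20.0]])
delta = np.median(np.diff(segment, axis=0), axis=0)  # median along-line step
azimuth = 90 - np.rad2deg(np.arctan2(delta[1], delta[0]))  # CCW-from-east -> CW-from-north
print(azimuth % 360)  # 45.0
```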

def _get_drape_heights(self) -> np.ndarray:
"""Set drape heights based on topography object and optional topography data."""

@@ -159,6 +195,12 @@ def spatial_interpolation(
origin=np.r_[self.params.survey.vertices[indices, :2].mean(axis=0), 0],
)
local_polar[local_polar[:, 1] >= 180, 0] *= -1 # Wrap azimuths

# Flip the line segment if the azimuth angle suggests the opposite direction
start_line = len(indices) // 2
if np.median(local_polar[:start_line, 1]) < 180:
local_polar = local_polar[::-1, :]

local_polar[:, 1] = (
0.0 if strike_angle is None else strike_angle
) # Align azimuths to zero
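The wrap above folds stations on opposite sides of the local origin onto one signed along-line coordinate; a minimal sketch with made-up (radius, azimuth) pairs:

```python
import numpy as np

# Made-up polar coordinates straddling the local origin.
local_polar = np.array([[2.0, 225.0], [1.0, 225.0], [1.0, 45.0], [2.0, 45.0]])
local_polar[local_polar[:, 1] >= 180, 0] *= -1  # far side gets a negative radius
print(local_polar[:, 0])  # [-2. -1.  1.  2.] -- signed distances along the line
```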
Expand Down Expand Up @@ -186,32 +228,38 @@ def run(self):
"Running %s . . .",
self.params.title,
)
observed = normalized_data(self.params.data)[self._time_mask, :]
observed = get_data_array(self.params.data)[self._time_mask, :]
tree = cKDTree(self.params.survey.vertices[:, :2])
file_split = np.array_split(
self.params.simulation_files, np.maximum(1, len(self.workers) * 10)
)
names = []
results = []
for ii, query in enumerate(self.params.queries.vertices):
# Find the nearest survey location to the query point
nearest = tree.query(query[:2], k=1)[1]
indices = self.params.survey.get_segment_indices(
nearest, self.params.max_distance
)
spatial_projection = self.spatial_interpolation(
indices,
0
if self.params.strike_angles is None
else self.params.strike_angles.values[ii],
)
file_split = np.array_split(
self.params.simulation_files, np.maximum(1, len(self.workers) * 10)
)
strike_angle = (
0
if self.params.strike_angles is None
else self.params.strike_angles.values[ii]
)
data, flip = prepare_data(observed[:, indices])

spatial_projection = self.spatial_interpolation(
indices,
np.abs(strike_angle),
)
tasks = []

for file_batch in file_split:
args = (
file_batch,
spatial_projection,
self._time_projection,
observed[:, indices],
data,
)

tasks.append(
@@ -225,36 +273,47 @@ def run(self):
progress(tasks)
tasks = self.client.gather(tasks)

scores = np.hstack(tasks)
ranked = np.argsort(scores)[::-1]

scores, centers = np.vstack(tasks).T
ranked = np.argsort(scores)
best = ranked[0]
# TODO: Return top N matches
# for rank in ranked[-1:][::-1]:
logger.info(
"File: %s \nScore: %.4f",
self.params.simulation_files[ranked[0]].name,
scores[ranked[0]],
self.params.simulation_files[best].name,
scores[best],
)
with Workspace(self.params.simulation_files[ranked[0]], mode="r") as ws:
with Workspace(self.params.simulation_files[best], mode="r") as ws:
survey = fetch_survey(ws)
ui_json = survey.parent.parent.options
ui_json["geoh5"] = ws
ifile = InputFile(ui_json=ui_json)
options = PlateSimulationOptions.build(ifile)

plate = survey.parent.parent.get_entity("plate")[0].copy(
parent=self.params.out_group
)

# Set position of plate to query location
center = self.params.survey.vertices[nearest]
center[2] = self._drape_heights[nearest]
plate.vertices = plate.vertices + center
plate.metadata = options.model.model_dump()
dir_correction = strike_angle + 180 if flip else strike_angle

results.append(self.params.simulation_files[ranked[0]].name)
plate = self._create_plate_from_parameters(
int(indices[int(centers[best])]), options.model, dir_correction
)
plate.name = f"Query [{ii}]"

names.append(self.params.simulation_files[best].name)
results.append(scores[best])

out = self.params.queries.copy(parent=self.params.out_group)
out.add_data(
{
"file": {
"values": np.array(names, dtype="U"),
"primitive_type": "TEXT",
},
"score": {
"values": np.array(results),
},
}
)

return results
return out
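Downstream code can read the matches back off the returned Points object; a hedged usage sketch, assuming a `PlateMatchDriver` configured as in the test below:

```python
out = match_driver.run()
names = out.get_data("file")[0].values
scores = out.get_data("score")[0].values
for name, score in zip(names, scores, strict=True):
    print(f"{name}: score={score:.4f}")
```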

@classmethod
def start_dask_run(
@@ -281,20 +340,46 @@ def start_dask_run(
)


def normalized_data(property_group: PropertyGroup, threshold=5) -> np.ndarray:
def prepare_data(data: np.ndarray) -> tuple[np.ndarray, bool]:
"""
Prepare data for scoring by normalizing and orienting the channels down-dip.

:param data: Array of data channels per location.

:return: Tuple of the prepared data array and whether the locations were reversed.
"""
data_array = normalized_data(data)

# Guess the down-dip direction from the migration of peaks across channels
max_ind = np.argmax(data_array, axis=1)

# Check if peaks migrate in a consistent direction across channels
diffs = np.diff(max_ind)
if np.median(diffs) < 0:
return data_array[:, ::-1], True # Reverse channels if peaks migrate up-dip

return data_array, False
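A toy check of the peak-migration heuristic, using synthetic channels whose peaks drift toward lower station indices:

```python
import numpy as np

data = np.zeros((4, 10))
for channel, peak in enumerate([7, 6, 5, 4]):  # peaks migrating toward station 0
    data[channel, peak] = 1.0
max_ind = np.argmax(data, axis=1)  # [7 6 5 4]
print(np.median(np.diff(max_ind)))  # -1.0 -> station order would be flipped
```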


def get_data_array(property_group: PropertyGroup) -> np.ndarray:
"""Extract data array from a property group."""
table = property_group.table()
return np.vstack([table[name] for name in table.dtype.names])
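The stacking yields one row per channel; a self-contained sketch with a hypothetical two-channel structured table:

```python
import numpy as np

# Hypothetical table: two time channels over five stations.
table = np.zeros(5, dtype=[("channel_0", float), ("channel_1", float)])
data_array = np.vstack([table[name] for name in table.dtype.names])
print(data_array.shape)  # (2, 5): rows are channels, columns are stations
```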


def normalized_data(data: np.ndarray, threshold=5) -> np.ndarray:
"""
Return data from a property group with symlog scaling and zero mean.
Return data with symlog scaling, zero mean, and unit-max normalization.

:param property_group: Property group containing data channels.
:param data: Array of data channels per location.
:param threshold: Percentile threshold for symlog normalization.

:return: Normalized data array.
"""
table = property_group.table()
data_array = np.vstack([table[name] for name in table.dtype.names])
thresh = np.percentile(np.abs(data_array), threshold)
log_data = symlog(data_array, thresh)
return log_data - np.mean(log_data, axis=1)[:, None]
thresh = np.percentile(np.abs(data), threshold)
log_data = symlog(data, thresh)
centered_log = log_data - np.mean(log_data)
return centered_log / np.abs(centered_log).max()
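A standalone sketch of the same normalization chain, substituting a hypothetical `symlog` for the geoapps_utils implementation:

```python
import numpy as np

def symlog(x, thresh):
    # Assumed form: sign-preserving log scaling, roughly linear below thresh.
    return np.sign(x) * np.log10(1.0 + np.abs(x) / thresh)

rng = np.random.default_rng(0)
data = rng.standard_normal((3, 8)) * np.logspace(0, 3, 8)
thresh = np.percentile(np.abs(data), 5)
normalized = symlog(data, thresh)
normalized -= normalized.mean()
normalized /= np.abs(normalized).max()
print(np.abs(normalized).max())  # 1.0
```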


def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None:
@@ -310,7 +395,7 @@ def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None:

def batch_files_score(
files: Path | list[Path], spatial_projection, time_projection, observed
) -> list[float]:
) -> list[tuple[float, int]]:
"""
Process a batch of simulation files and compute scores against observed data.

@@ -334,24 +419,29 @@ def batch_files_score(
logger.warning("No survey found in %s, skipping.", sim_file)
continue

simulated = normalized_data(survey.get_entity("Iteration_0_z")[0])
simulated = get_data_array(survey.get_entity("Iteration_0_z")[0])
pred = time_projection @ (spatial_projection @ simulated.T).T
pred = normalized_data(pred)
score = 0.0

indices = []
# Metric: residual misfit norm, with normalized cross-correlation to locate the peak
for obs, pre in zip(observed, pred, strict=True):
# Scale pre to the amplitude of obs
vals = pre / np.abs(pre).max() * np.abs(obs).max()

# Cross-correlation with the scaled prediction
corr = signal.correlate(obs, pre, mode="full")
corr = signal.correlate(obs, vals, mode="same")
# Normalize by energy to get correlation coefficient in [-1, 1]
denom = np.linalg.norm(pre) * np.linalg.norm(obs)
denom = np.linalg.norm(vals) * np.linalg.norm(obs)
if denom == 0:
corr_norm = np.zeros_like(corr)
else:
corr_norm = corr / denom

score += np.max(corr_norm)
score += np.linalg.norm(obs - vals)
indices.append(np.argmax(corr_norm))

scores.append(score)
scores.append((score, np.median(indices)))

return scores
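A minimal sketch of the per-channel scoring with a synthetic shifted channel, showing the misfit norm and the cross-correlation peak used as the plate-center index:

```python
import numpy as np
from scipy import signal

obs = np.sin(np.linspace(0, np.pi, 50))
pre = np.roll(obs, 5)  # hypothetical prediction offset by five stations
vals = pre / np.abs(pre).max() * np.abs(obs).max()  # scale pre on obs
corr = signal.correlate(obs, vals, mode="same")
corr_norm = corr / (np.linalg.norm(vals) * np.linalg.norm(obs))
print(np.linalg.norm(obs - vals), np.argmax(corr_norm))  # misfit, best-lag index
```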

53 changes: 47 additions & 6 deletions tests/plate_simulation/runtest/match_test.py
@@ -13,6 +13,7 @@
import numpy as np
import pytest
from geoapps_utils.utils.importing import GeoAppsError
from geoapps_utils.utils.transformations import rotate_xyz
from geoh5py import Workspace
from geoh5py.groups import PropertyGroup, SimPEGGroup
from geoh5py.objects import Points
@@ -40,8 +41,14 @@
def generate_example(geoh5: Workspace, n_grid_points: int, refinement: tuple[int]):
opts = SyntheticsComponentsOptions(
method="airborne tdem",
survey=SurveyOptions(n_stations=n_grid_points, n_lines=1, drape=10.0),
mesh=MeshOptions(refinement=refinement, padding_distance=400.0),
survey=SurveyOptions(
n_stations=n_grid_points,
n_lines=1,
width=1000,
drape=40.0,
topography=lambda x, y: np.zeros(x.shape),
),
mesh=MeshOptions(refinement=refinement),
model=ModelOptions(background=0.001),
)
components = SyntheticsComponents(geoh5, options=opts)
@@ -110,7 +117,7 @@ def test_matching_driver(tmp_path: Path):

# Generate simulation files
with get_workspace(tmp_path / f"{__name__}.geoh5") as geoh5:
components = generate_example(geoh5, n_grid_points=5, refinement=(2,))
components = generate_example(geoh5, n_grid_points=32, refinement=(2,))

params = TDEMForwardOptions.build(
geoh5=geoh5,
@@ -132,6 +139,8 @@
ifile.data["simulation"] = fwr_driver.out_group

plate_options = PlateSimulationOptions.build(ifile.data)
plate_options.model.overburden_model.thickness = 40.0
plate_options.model.plate_model.dip_length = 300.0
driver = PlateSimulationDriver(plate_options)
driver.run()

@@ -154,18 +163,50 @@
child = survey.get_entity(uid)[0]
child.values = child.values * scale

# Random choice of file
# Downsample data
mask = np.ones_like(child.values, dtype=bool)
mask[1::2] = False
survey.remove_vertices(mask)
indices = np.arange(survey.n_vertices)
survey.cells = np.c_[indices[:-1], indices[1:]]
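A toy version of the connectivity rebuild above, pairing consecutive surviving vertices into cells:

```python
import numpy as np

indices = np.arange(4)  # hypothetical remaining vertex indices
cells = np.c_[indices[:-1], indices[1:]]
print(cells)  # [[0 1] [1 2] [2 3]] -- one segment per consecutive pair
```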

# Run the matching driver
with geoh5.open():
survey = fetch_survey(geoh5)

# Rotate the survey to test matching
survey.vertices = rotate_xyz(survey.vertices, [0, 0, 0], 215.0)

# Flip the data to simulate up-dip measurements
prop_group = survey.get_entity("Iteration_0_z")[0]
for uid in prop_group.properties:
child = survey.get_entity(uid)[0]
child.values = child.values[::-1]

# Change the strike angle to simulate a different orientation
strikes = components.queries.add_data(
{
"strike": {
"values": np.full(components.queries.n_vertices, -10.0),
}
}
)

options = PlateMatchOptions(
geoh5=geoh5,
survey=survey,
data=survey.get_entity("Iteration_0_z")[0],
data=prop_group,
queries=components.queries,
strike_angles=strikes,
topography_object=components.topography,
simulations=new_dir,
)
match_driver = PlateMatchDriver(options)
results = match_driver.run()

assert results[0] == file.stem + f"_[{4}].geoh5"
assert isinstance(results, Points)

names = results.get_data("file")[0]
assert names.values[0] == file.stem + f"_[{4}].geoh5"

assert geoh5.get_entity("Query [0]")[0].geometry.dip_direction == 45.0