From ec1f7faeb6df1326cd8d81b1954fd75a1ba38e65 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Tue, 3 Feb 2026 12:37:24 -0800 Subject: [PATCH 1/6] Correct for survey azimuth --- .../plate_simulation/match/driver.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 19b47365..789a4b40 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -22,7 +22,7 @@ from geoapps_utils.utils.logger import get_logger from geoapps_utils.utils.numerical import inverse_weighted_operator from geoapps_utils.utils.plotting import symlog -from geoapps_utils.utils.transformations import cartesian_to_polar +from geoapps_utils.utils.transformations import cartesian_to_polar, rotate_xyz from geoh5py import Workspace from geoh5py.groups import PropertyGroup, SimPEGGroup from geoh5py.objects import AirborneTEMReceivers, Surface @@ -195,11 +195,14 @@ def run(self): indices = self.params.survey.get_segment_indices( nearest, self.params.max_distance ) - spatial_projection = self.spatial_interpolation( - indices, + strike_angle = ( 0 if self.params.strike_angles is None - else self.params.strike_angles.values[ii], + else np.abs(self.params.strike_angles.values[ii]) + ) + spatial_projection = self.spatial_interpolation( + indices, + strike_angle, ) file_split = np.array_split( self.params.simulation_files, np.maximum(1, len(self.workers) * 10) @@ -249,8 +252,20 @@ def run(self): # Set position of plate to query location center = self.params.survey.vertices[nearest] center[2] = self._drape_heights[nearest] - plate.vertices = plate.vertices + center - plate.metadata = options.model.model_dump() + + # Rotate along line + delta = ( + self.params.survey.vertices[nearest + 1] + - self.params.survey.vertices[nearest] + ) + azm = np.rad2deg(np.arctan2(delta[1], delta[0])) + strike_angle + vertices = plate.vertices + center + vertices = rotate_xyz(vertices, center, azm) + + plate.vertices = vertices + metadata = options.model.model_dump() + metadata.update({"UUID": self.params.simulation_files[ranked[0]].name}) + plate.metadata = metadata results.append(self.params.simulation_files[ranked[0]].name) From 5212b3cf6f1888cf8fc058eb3a7a50cf9d77ad71 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Thu, 5 Feb 2026 10:54:08 -0800 Subject: [PATCH 2/6] Store indices of correlation --- simpeg_drivers/plate_simulation/match/driver.py | 12 +++++++----- tests/plate_simulation/runtest/match_test.py | 9 ++++++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 789a4b40..4bccc820 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -188,6 +188,7 @@ def run(self): ) observed = normalized_data(self.params.data)[self._time_mask, :] tree = cKDTree(self.params.survey.vertices[:, :2]) + names = [] results = [] for ii, query in enumerate(self.params.queries.vertices): # Find the nearest survey location to the query point @@ -228,7 +229,7 @@ def run(self): progress(tasks) tasks = self.client.gather(tasks) - scores = np.hstack(tasks) + scores, indices = np.vstack(tasks).T ranked = np.argsort(scores)[::-1] # TODO: Return top N matches @@ -325,7 +326,7 @@ def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None: def batch_files_score( files: Path | list[Path], spatial_projection, time_projection, observed -) -> list[float]: +) -> list[tuple[float, int]]: """ Process a batch of simulation files and compute scores against observed data. @@ -352,11 +353,11 @@ def batch_files_score( simulated = normalized_data(survey.get_entity("Iteration_0_z")[0]) pred = time_projection @ (spatial_projection @ simulated.T).T score = 0.0 - + indices = [] # Metric: normalized cross-correlation for obs, pre in zip(observed, pred, strict=True): # Full cross-correlation - corr = signal.correlate(obs, pre, mode="full") + corr = signal.correlate(obs, pre, mode="same") # Normalize by energy to get correlation coefficient in [-1, 1] denom = np.linalg.norm(pre) * np.linalg.norm(obs) if denom == 0: @@ -365,8 +366,9 @@ def batch_files_score( corr_norm = corr / denom score += np.max(corr_norm) + indices.append(np.argmax(corr_norm)) - scores.append(score) + scores.append((score, np.median(indices))) return scores diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py index 731d1140..08f6f5bf 100644 --- a/tests/plate_simulation/runtest/match_test.py +++ b/tests/plate_simulation/runtest/match_test.py @@ -110,7 +110,7 @@ def test_matching_driver(tmp_path: Path): # Generate simulation files with get_workspace(tmp_path / f"{__name__}.geoh5") as geoh5: - components = generate_example(geoh5, n_grid_points=5, refinement=(2,)) + components = generate_example(geoh5, n_grid_points=15, refinement=(2,)) params = TDEMForwardOptions.build( geoh5=geoh5, @@ -154,6 +154,13 @@ def test_matching_driver(tmp_path: Path): child = survey.get_entity(uid)[0] child.values = child.values * scale + # Downsample data + mask = np.ones_like(child.values, dtype=bool) + mask[1::3] = False + survey.remove_vertices(mask) + indices = np.arange(survey.n_vertices) + survey.cells = np.c_[indices[:-1], indices[1:]] + # Random choice of file with geoh5.open(): survey = fetch_survey(geoh5) From 90bd13346ddeea6cc61dac2f164d6ad0c7943048 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Thu, 5 Feb 2026 11:31:20 -0800 Subject: [PATCH 3/6] Export scores and file name for each query point --- .../plate_simulation/match/driver.py | 18 ++++++++++++++++-- tests/plate_simulation/runtest/match_test.py | 5 ++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 4bccc820..0086800c 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -268,9 +268,23 @@ def run(self): metadata.update({"UUID": self.params.simulation_files[ranked[0]].name}) plate.metadata = metadata - results.append(self.params.simulation_files[ranked[0]].name) + names.append(self.params.simulation_files[ranked[0]].name) + results.append(scores[ranked[0]]) + + out = self.params.queries.copy(parent=self.params.out_group) + out.add_data( + { + "file": { + "values": np.array(names, dtype="U"), + "primitive_type": "TEXT", + }, + "score": { + "values": np.array(results), + }, + } + ) - return results + return out @classmethod def start_dask_run( diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py index 08f6f5bf..3b0cd9bc 100644 --- a/tests/plate_simulation/runtest/match_test.py +++ b/tests/plate_simulation/runtest/match_test.py @@ -175,4 +175,7 @@ def test_matching_driver(tmp_path: Path): match_driver = PlateMatchDriver(options) results = match_driver.run() - assert results[0] == file.stem + f"_[{4}].geoh5" + assert isinstance(results, Points) + + names = results.get_data("file")[0] + assert names.values[0] == file.stem + f"_[{4}].geoh5" From e1b6478c1b9d3cbc1eadc1f64893c1a48cb004cd Mon Sep 17 00:00:00 2001 From: dominiquef Date: Fri, 6 Feb 2026 10:48:02 -0800 Subject: [PATCH 4/6] Move plate creation to seperate method. Update unit test --- .../plate_simulation/match/driver.py | 93 ++++++++++++------- tests/plate_simulation/runtest/match_test.py | 23 ++++- 2 files changed, 75 insertions(+), 41 deletions(-) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 0086800c..b7aeedeb 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -25,7 +25,8 @@ from geoapps_utils.utils.transformations import cartesian_to_polar, rotate_xyz from geoh5py import Workspace from geoh5py.groups import PropertyGroup, SimPEGGroup -from geoh5py.objects import AirborneTEMReceivers, Surface +from geoh5py.objects import AirborneTEMReceivers, MaxwellPlate, Surface +from geoh5py.objects.maxwell_plate import PlateGeometry from geoh5py.ui_json import InputFile from scipy import signal from scipy.sparse import csr_matrix @@ -34,7 +35,7 @@ from simpeg_drivers.driver import BaseDriver from simpeg_drivers.plate_simulation.match.options import PlateMatchOptions -from simpeg_drivers.plate_simulation.options import PlateSimulationOptions +from simpeg_drivers.plate_simulation.options import ModelOptions, PlateSimulationOptions logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False) @@ -123,6 +124,41 @@ def start(cls, filepath: str | Path, mode="r+", **_) -> Self: return driver + def _create_plate_from_parameters( + self, index_center: int, model_options: ModelOptions, strike_angle: float + ) -> MaxwellPlate: + center = self.params.survey.vertices[index_center] + center[2] = ( + self._drape_heights[index_center] - model_options.overburden_model.thickness + ) + indices = self.params.survey.get_segment_indices( + index_center, self.params.max_distance + ) + segment = self.params.survey.vertices[indices] + delta = np.mean(segment - segment[0, :], axis=0) + azimuth = 90 - np.rad2deg(np.arctan2(delta[0], delta[1])) + + plate_geometry = PlateGeometry.model_validate( + { + "position": { + "x": center[0], + "y": center[1], + "z": center[2], + }, + "width": model_options.plate_model.dip_length, + "thickness": model_options.plate_model.width, + "length": model_options.plate_model.strike_length, + "dip": model_options.plate_model.dip, + "dip_direction": azimuth + strike_angle, + } + ) + plate = MaxwellPlate.create( + self.params.geoh5, geometry=plate_geometry, parent=self.params.out_group + ) + plate.metadata = model_options.model_dump() + + return plate + def _get_drape_heights(self) -> np.ndarray: """Set drape heights based on topography object and optional topography data.""" @@ -229,47 +265,28 @@ def run(self): progress(tasks) tasks = self.client.gather(tasks) - scores, indices = np.vstack(tasks).T - ranked = np.argsort(scores)[::-1] - + scores, centers = np.vstack(tasks).T + ranked = np.argsort(scores) + best = ranked[0] # TODO: Return top N matches # for rank in ranked[-1:][::-1]: logger.info( "File: %s \nScore: %.4f", - self.params.simulation_files[ranked[0]].name, - scores[ranked[0]], + self.params.simulation_files[best].name, + scores[best], ) - with Workspace(self.params.simulation_files[ranked[0]], mode="r") as ws: + with Workspace(self.params.simulation_files[best], mode="r") as ws: survey = fetch_survey(ws) ui_json = survey.parent.parent.options ui_json["geoh5"] = ws ifile = InputFile(ui_json=ui_json) options = PlateSimulationOptions.build(ifile) - - plate = survey.parent.parent.get_entity("plate")[0].copy( - parent=self.params.out_group - ) - - # Set position of plate to query location - center = self.params.survey.vertices[nearest] - center[2] = self._drape_heights[nearest] - - # Rotate along line - delta = ( - self.params.survey.vertices[nearest + 1] - - self.params.survey.vertices[nearest] + self._create_plate_from_parameters( + int(indices[int(centers[best])]), options.model, strike_angle ) - azm = np.rad2deg(np.arctan2(delta[1], delta[0])) + strike_angle - vertices = plate.vertices + center - vertices = rotate_xyz(vertices, center, azm) - plate.vertices = vertices - metadata = options.model.model_dump() - metadata.update({"UUID": self.params.simulation_files[ranked[0]].name}) - plate.metadata = metadata - - names.append(self.params.simulation_files[ranked[0]].name) - results.append(scores[ranked[0]]) + names.append(self.params.simulation_files[best].name) + results.append(scores[best]) out = self.params.queries.copy(parent=self.params.out_group) out.add_data( @@ -313,7 +330,7 @@ def start_dask_run( def normalized_data(property_group: PropertyGroup, threshold=5) -> np.ndarray: """ - Return data from a property group with symlog scaling and zero mean. + Return data from a property group with symlog, zero mean and unit max normalization. :param property_group: Property group containing data channels. :param threshold: Percentile threshold for symlog normalization. @@ -324,7 +341,8 @@ def normalized_data(property_group: PropertyGroup, threshold=5) -> np.ndarray: data_array = np.vstack([table[name] for name in table.dtype.names]) thresh = np.percentile(np.abs(data_array), threshold) log_data = symlog(data_array, thresh) - return log_data - np.mean(log_data, axis=1)[:, None] + centered_log = log_data - np.mean(log_data, axis=1)[:, None] + return centered_log / np.abs(centered_log).max() def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None: @@ -370,16 +388,19 @@ def batch_files_score( indices = [] # Metric: normalized cross-correlation for obs, pre in zip(observed, pred, strict=True): + # Scale pre on obs + vals = pre / np.abs(pre).max() * np.abs(obs).max() + # Full cross-correlation - corr = signal.correlate(obs, pre, mode="same") + corr = signal.correlate(obs, vals, mode="same") # Normalize by energy to get correlation coefficient in [-1, 1] - denom = np.linalg.norm(pre) * np.linalg.norm(obs) + denom = np.linalg.norm(vals) * np.linalg.norm(obs) if denom == 0: corr_norm = np.zeros_like(corr) else: corr_norm = corr / denom - score += np.max(corr_norm) + score += np.linalg.norm(obs - vals) indices.append(np.argmax(corr_norm)) scores.append((score, np.median(indices))) diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py index 3b0cd9bc..3ab34a2a 100644 --- a/tests/plate_simulation/runtest/match_test.py +++ b/tests/plate_simulation/runtest/match_test.py @@ -13,6 +13,7 @@ import numpy as np import pytest from geoapps_utils.utils.importing import GeoAppsError +from geoapps_utils.utils.transformations import rotate_xyz from geoh5py import Workspace from geoh5py.groups import PropertyGroup, SimPEGGroup from geoh5py.objects import Points @@ -40,8 +41,14 @@ def generate_example(geoh5: Workspace, n_grid_points: int, refinement: tuple[int]): opts = SyntheticsComponentsOptions( method="airborne tdem", - survey=SurveyOptions(n_stations=n_grid_points, n_lines=1, drape=10.0), - mesh=MeshOptions(refinement=refinement, padding_distance=400.0), + survey=SurveyOptions( + n_stations=n_grid_points, + n_lines=1, + width=1000, + drape=40.0, + topography=lambda x, y: np.zeros(x.shape), + ), + mesh=MeshOptions(refinement=refinement), model=ModelOptions(background=0.001), ) components = SyntheticsComponents(geoh5, options=opts) @@ -110,7 +117,7 @@ def test_matching_driver(tmp_path: Path): # Generate simulation files with get_workspace(tmp_path / f"{__name__}.geoh5") as geoh5: - components = generate_example(geoh5, n_grid_points=15, refinement=(2,)) + components = generate_example(geoh5, n_grid_points=32, refinement=(2,)) params = TDEMForwardOptions.build( geoh5=geoh5, @@ -132,6 +139,8 @@ def test_matching_driver(tmp_path: Path): ifile.data["simulation"] = fwr_driver.out_group plate_options = PlateSimulationOptions.build(ifile.data) + plate_options.model.overburden_model.thickness = 40.0 + plate_options.model.plate_model.dip_length = 300.0 driver = PlateSimulationDriver(plate_options) driver.run() @@ -156,14 +165,18 @@ def test_matching_driver(tmp_path: Path): # Downsample data mask = np.ones_like(child.values, dtype=bool) - mask[1::3] = False + mask[1::2] = False survey.remove_vertices(mask) indices = np.arange(survey.n_vertices) survey.cells = np.c_[indices[:-1], indices[1:]] - # Random choice of file + # Run the matching driver with geoh5.open(): survey = fetch_survey(geoh5) + + # Rotate the survey to test matching + survey.vertices = rotate_xyz(survey.vertices, [0, 0, 0], 225.0) + options = PlateMatchOptions( geoh5=geoh5, survey=survey, From bc4374dbfdc27b7611921fe0170f055359365551 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Fri, 6 Feb 2026 12:45:32 -0800 Subject: [PATCH 5/6] Add flipping of line if detected up-dip. Flip interpolation if surveyed opposite to template direction --- .../plate_simulation/match/driver.py | 70 ++++++++++++++----- tests/plate_simulation/runtest/match_test.py | 10 ++- 2 files changed, 63 insertions(+), 17 deletions(-) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index b7aeedeb..efcc1506 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -149,7 +149,7 @@ def _create_plate_from_parameters( "thickness": model_options.plate_model.width, "length": model_options.plate_model.strike_length, "dip": model_options.plate_model.dip, - "dip_direction": azimuth + strike_angle, + "dip_direction": (azimuth + strike_angle) % 360, } ) plate = MaxwellPlate.create( @@ -195,6 +195,12 @@ def spatial_interpolation( origin=np.r_[self.params.survey.vertices[indices, :2].mean(axis=0), 0], ) local_polar[local_polar[:, 1] >= 180, 0] *= -1 # Wrap azimuths + + # Flip the line segment if the azimuth angle suggests the opposite direction + start_line = len(indices) // 2 + if np.median(local_polar[:start_line, 1]) < 180: + local_polar = local_polar[::-1, :] + local_polar[:, 1] = ( 0.0 if strike_angle is None else strike_angle ) # Align azimuths to zero @@ -222,8 +228,11 @@ def run(self): "Running %s . . .", self.params.title, ) - observed = normalized_data(self.params.data)[self._time_mask, :] + observed = get_data_array(self.params.data)[self._time_mask, :] tree = cKDTree(self.params.survey.vertices[:, :2]) + file_split = np.array_split( + self.params.simulation_files, np.maximum(1, len(self.workers) * 10) + ) names = [] results = [] for ii, query in enumerate(self.params.queries.vertices): @@ -237,21 +246,20 @@ def run(self): if self.params.strike_angles is None else np.abs(self.params.strike_angles.values[ii]) ) + data, flip = prepare_data(observed[:, indices]) + spatial_projection = self.spatial_interpolation( indices, strike_angle, ) - file_split = np.array_split( - self.params.simulation_files, np.maximum(1, len(self.workers) * 10) - ) - tasks = [] + for file_batch in file_split: args = ( file_batch, spatial_projection, self._time_projection, - observed[:, indices], + data, ) tasks.append( @@ -281,8 +289,11 @@ def run(self): ui_json["geoh5"] = ws ifile = InputFile(ui_json=ui_json) options = PlateSimulationOptions.build(ifile) + + dir_correction = strike_angle + 180 if flip else strike_angle + self._create_plate_from_parameters( - int(indices[int(centers[best])]), options.model, strike_angle + int(indices[int(centers[best])]), options.model, dir_correction ) names.append(self.params.simulation_files[best].name) @@ -328,20 +339,45 @@ def start_dask_run( ) -def normalized_data(property_group: PropertyGroup, threshold=5) -> np.ndarray: +def prepare_data(data: np.ndarray) -> tuple[np.ndarray, bool]: + """ + Prepare data for scoring by checking for multiple channels and normalizing. + + param data_array: Array of data channels per location. + + :return: Tuple of prepared data array, whether locations were reversed. + """ + data_array = normalized_data(data) + + # Guess what the down-dip direction is based on migration of peaks + max_ind = np.argmax(data_array, axis=1) + + # Check if peaks migrate in a consistent direction across channels + diffs = np.diff(max_ind) + if np.mean(diffs) < 0: + return data_array[:, ::-1], True # Reverse channels if peaks migrate up-dip + + return data_array, False + + +def get_data_array(property_group: PropertyGroup) -> np.ndarray: + """Extract data array from a property group.""" + table = property_group.table() + return np.vstack([table[name] for name in table.dtype.names]) + + +def normalized_data(data: np.ndarray, threshold=5) -> np.ndarray: """ Return data from a property group with symlog, zero mean and unit max normalization. - :param property_group: Property group containing data channels. + :param data: Array of data channels per location. :param threshold: Percentile threshold for symlog normalization. :return: Normalized data array. """ - table = property_group.table() - data_array = np.vstack([table[name] for name in table.dtype.names]) - thresh = np.percentile(np.abs(data_array), threshold) - log_data = symlog(data_array, thresh) - centered_log = log_data - np.mean(log_data, axis=1)[:, None] + thresh = np.percentile(np.abs(data), threshold) + log_data = symlog(data, thresh) + centered_log = log_data - np.mean(log_data) return centered_log / np.abs(centered_log).max() @@ -382,7 +418,9 @@ def batch_files_score( logger.warning("No survey found in %s, skipping.", sim_file) continue - simulated = normalized_data(survey.get_entity("Iteration_0_z")[0]) + simulated = normalized_data( + get_data_array(survey.get_entity("Iteration_0_z")[0]) + ) pred = time_projection @ (spatial_projection @ simulated.T).T score = 0.0 indices = [] diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py index 3ab34a2a..a1d9779c 100644 --- a/tests/plate_simulation/runtest/match_test.py +++ b/tests/plate_simulation/runtest/match_test.py @@ -177,10 +177,16 @@ def test_matching_driver(tmp_path: Path): # Rotate the survey to test matching survey.vertices = rotate_xyz(survey.vertices, [0, 0, 0], 225.0) + # Flip the data to simulate up-dip measurements + prop_group = survey.get_entity("Iteration_0_z")[0] + for uid in prop_group.properties: + child = survey.get_entity(uid)[0] + child.values = child.values[::-1] + options = PlateMatchOptions( geoh5=geoh5, survey=survey, - data=survey.get_entity("Iteration_0_z")[0], + data=prop_group, queries=components.queries, topography_object=components.topography, simulations=new_dir, @@ -192,3 +198,5 @@ def test_matching_driver(tmp_path: Path): names = results.get_data("file")[0] assert names.values[0] == file.stem + f"_[{4}].geoh5" + + assert geoh5.get_entity("Maxwell Plate")[0].geometry.dip_direction == 45.0 From 502a58a0c211c9c0b035618e9bcad76fedbb1a88 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Fri, 6 Feb 2026 15:06:51 -0800 Subject: [PATCH 6/6] More robust azimuth calcs. Augment test --- .../plate_simulation/match/driver.py | 18 +++++++++--------- tests/plate_simulation/runtest/match_test.py | 14 ++++++++++++-- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index efcc1506..a66b225b 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -135,8 +135,8 @@ def _create_plate_from_parameters( index_center, self.params.max_distance ) segment = self.params.survey.vertices[indices] - delta = np.mean(segment - segment[0, :], axis=0) - azimuth = 90 - np.rad2deg(np.arctan2(delta[0], delta[1])) + delta = np.median(np.diff(segment, axis=0), axis=0) + azimuth = 90 - np.rad2deg(np.arctan2(delta[1], delta[0])) plate_geometry = PlateGeometry.model_validate( { @@ -244,13 +244,13 @@ def run(self): strike_angle = ( 0 if self.params.strike_angles is None - else np.abs(self.params.strike_angles.values[ii]) + else self.params.strike_angles.values[ii] ) data, flip = prepare_data(observed[:, indices]) spatial_projection = self.spatial_interpolation( indices, - strike_angle, + np.abs(strike_angle), ) tasks = [] @@ -292,9 +292,10 @@ def run(self): dir_correction = strike_angle + 180 if flip else strike_angle - self._create_plate_from_parameters( + plate = self._create_plate_from_parameters( int(indices[int(centers[best])]), options.model, dir_correction ) + plate.name = f"Query [{ii}]" names.append(self.params.simulation_files[best].name) results.append(scores[best]) @@ -354,7 +355,7 @@ def prepare_data(data: np.ndarray) -> tuple[np.ndarray, bool]: # Check if peaks migrate in a consistent direction across channels diffs = np.diff(max_ind) - if np.mean(diffs) < 0: + if np.median(diffs) < 0: return data_array[:, ::-1], True # Reverse channels if peaks migrate up-dip return data_array, False @@ -418,10 +419,9 @@ def batch_files_score( logger.warning("No survey found in %s, skipping.", sim_file) continue - simulated = normalized_data( - get_data_array(survey.get_entity("Iteration_0_z")[0]) - ) + simulated = get_data_array(survey.get_entity("Iteration_0_z")[0]) pred = time_projection @ (spatial_projection @ simulated.T).T + pred = normalized_data(pred) score = 0.0 indices = [] # Metric: normalized cross-correlation diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py index a1d9779c..56e84257 100644 --- a/tests/plate_simulation/runtest/match_test.py +++ b/tests/plate_simulation/runtest/match_test.py @@ -175,7 +175,7 @@ def test_matching_driver(tmp_path: Path): survey = fetch_survey(geoh5) # Rotate the survey to test matching - survey.vertices = rotate_xyz(survey.vertices, [0, 0, 0], 225.0) + survey.vertices = rotate_xyz(survey.vertices, [0, 0, 0], 215.0) # Flip the data to simulate up-dip measurements prop_group = survey.get_entity("Iteration_0_z")[0] @@ -183,11 +183,21 @@ def test_matching_driver(tmp_path: Path): child = survey.get_entity(uid)[0] child.values = child.values[::-1] + # Change the strike angle to simulate a different orientation + strikes = components.queries.add_data( + { + "strike": { + "values": np.full(components.queries.n_vertices, -10.0), + } + } + ) + options = PlateMatchOptions( geoh5=geoh5, survey=survey, data=prop_group, queries=components.queries, + strike_angles=strikes, topography_object=components.topography, simulations=new_dir, ) @@ -199,4 +209,4 @@ def test_matching_driver(tmp_path: Path): names = results.get_data("file")[0] assert names.values[0] == file.stem + f"_[{4}].geoh5" - assert geoh5.get_entity("Maxwell Plate")[0].geometry.dip_direction == 45.0 + assert geoh5.get_entity("Query [0]")[0].geometry.dip_direction == 45.0