diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 19b47365..a66b225b 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -22,10 +22,11 @@ from geoapps_utils.utils.logger import get_logger from geoapps_utils.utils.numerical import inverse_weighted_operator from geoapps_utils.utils.plotting import symlog -from geoapps_utils.utils.transformations import cartesian_to_polar +from geoapps_utils.utils.transformations import cartesian_to_polar, rotate_xyz from geoh5py import Workspace from geoh5py.groups import PropertyGroup, SimPEGGroup -from geoh5py.objects import AirborneTEMReceivers, Surface +from geoh5py.objects import AirborneTEMReceivers, MaxwellPlate, Surface +from geoh5py.objects.maxwell_plate import PlateGeometry from geoh5py.ui_json import InputFile from scipy import signal from scipy.sparse import csr_matrix @@ -34,7 +35,7 @@ from simpeg_drivers.driver import BaseDriver from simpeg_drivers.plate_simulation.match.options import PlateMatchOptions -from simpeg_drivers.plate_simulation.options import PlateSimulationOptions +from simpeg_drivers.plate_simulation.options import ModelOptions, PlateSimulationOptions logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False) @@ -123,6 +124,41 @@ def start(cls, filepath: str | Path, mode="r+", **_) -> Self: return driver + def _create_plate_from_parameters( + self, index_center: int, model_options: ModelOptions, strike_angle: float + ) -> MaxwellPlate: + center = self.params.survey.vertices[index_center] + center[2] = ( + self._drape_heights[index_center] - model_options.overburden_model.thickness + ) + indices = self.params.survey.get_segment_indices( + index_center, self.params.max_distance + ) + segment = self.params.survey.vertices[indices] + delta = np.median(np.diff(segment, axis=0), axis=0) + azimuth = 90 - np.rad2deg(np.arctan2(delta[1], delta[0])) + + 
plate_geometry = PlateGeometry.model_validate( + { + "position": { + "x": center[0], + "y": center[1], + "z": center[2], + }, + "width": model_options.plate_model.dip_length, + "thickness": model_options.plate_model.width, + "length": model_options.plate_model.strike_length, + "dip": model_options.plate_model.dip, + "dip_direction": (azimuth + strike_angle) % 360, + } + ) + plate = MaxwellPlate.create( + self.params.geoh5, geometry=plate_geometry, parent=self.params.out_group + ) + plate.metadata = model_options.model_dump() + + return plate + def _get_drape_heights(self) -> np.ndarray: """Set drape heights based on topography object and optional topography data.""" @@ -159,6 +195,12 @@ def spatial_interpolation( origin=np.r_[self.params.survey.vertices[indices, :2].mean(axis=0), 0], ) local_polar[local_polar[:, 1] >= 180, 0] *= -1 # Wrap azimuths + + # Flip the line segment if the azimuth angle suggests the opposite direction + start_line = len(indices) // 2 + if np.median(local_polar[:start_line, 1]) < 180: + local_polar = local_polar[::-1, :] + local_polar[:, 1] = ( 0.0 if strike_angle is None else strike_angle ) # Align azimuths to zero @@ -186,8 +228,12 @@ def run(self): "Running %s . . 
.", self.params.title, ) - observed = normalized_data(self.params.data)[self._time_mask, :] + observed = get_data_array(self.params.data)[self._time_mask, :] tree = cKDTree(self.params.survey.vertices[:, :2]) + file_split = np.array_split( + self.params.simulation_files, np.maximum(1, len(self.workers) * 10) + ) + names = [] results = [] for ii, query in enumerate(self.params.queries.vertices): # Find the nearest survey location to the query point @@ -195,23 +241,25 @@ def run(self): indices = self.params.survey.get_segment_indices( nearest, self.params.max_distance ) - spatial_projection = self.spatial_interpolation( - indices, + strike_angle = ( 0 if self.params.strike_angles is None - else self.params.strike_angles.values[ii], - ) - file_split = np.array_split( - self.params.simulation_files, np.maximum(1, len(self.workers) * 10) + else self.params.strike_angles.values[ii] ) + data, flip = prepare_data(observed[:, indices]) + spatial_projection = self.spatial_interpolation( + indices, + np.abs(strike_angle), + ) tasks = [] + for file_batch in file_split: args = ( file_batch, spatial_projection, self._time_projection, - observed[:, indices], + data, ) tasks.append( @@ -225,36 +273,47 @@ def run(self): progress(tasks) tasks = self.client.gather(tasks) - scores = np.hstack(tasks) - ranked = np.argsort(scores)[::-1] - + scores, centers = np.vstack(tasks).T + ranked = np.argsort(scores) + best = ranked[0] # TODO: Return top N matches # for rank in ranked[-1:][::-1]: logger.info( "File: %s \nScore: %.4f", - self.params.simulation_files[ranked[0]].name, - scores[ranked[0]], + self.params.simulation_files[best].name, + scores[best], ) - with Workspace(self.params.simulation_files[ranked[0]], mode="r") as ws: + with Workspace(self.params.simulation_files[best], mode="r") as ws: survey = fetch_survey(ws) ui_json = survey.parent.parent.options ui_json["geoh5"] = ws ifile = InputFile(ui_json=ui_json) options = PlateSimulationOptions.build(ifile) - plate = 
survey.parent.parent.get_entity("plate")[0].copy( - parent=self.params.out_group - ) - - # Set position of plate to query location - center = self.params.survey.vertices[nearest] - center[2] = self._drape_heights[nearest] - plate.vertices = plate.vertices + center - plate.metadata = options.model.model_dump() + dir_correction = strike_angle + 180 if flip else strike_angle - results.append(self.params.simulation_files[ranked[0]].name) + plate = self._create_plate_from_parameters( + int(indices[int(centers[best])]), options.model, dir_correction + ) + plate.name = f"Query [{ii}]" + + names.append(self.params.simulation_files[best].name) + results.append(scores[best]) + + out = self.params.queries.copy(parent=self.params.out_group) + out.add_data( + { + "file": { + "values": np.array(names, dtype="U"), + "primitive_type": "TEXT", + }, + "score": { + "values": np.array(results), + }, + } + ) - return results + return out @classmethod def start_dask_run( @@ -281,20 +340,46 @@ def start_dask_run( ) -def normalized_data(property_group: PropertyGroup, threshold=5) -> np.ndarray: +def prepare_data(data: np.ndarray) -> tuple[np.ndarray, bool]: + """ + Prepare data for scoring by checking for multiple channels and normalizing. + + :param data: Array of data channels per location. + + :return: Tuple of prepared data array, whether locations were reversed. 
+ """ + data_array = normalized_data(data) + + # Guess what the down-dip direction is based on migration of peaks + max_ind = np.argmax(data_array, axis=1) + + # Check if peaks migrate in a consistent direction across channels + diffs = np.diff(max_ind) + if np.median(diffs) < 0: + return data_array[:, ::-1], True # Reverse location order if peaks migrate up-dip + + return data_array, False + + +def get_data_array(property_group: PropertyGroup) -> np.ndarray: + """Extract data array from a property group.""" + table = property_group.table() + return np.vstack([table[name] for name in table.dtype.names]) + + +def normalized_data(data: np.ndarray, threshold=5) -> np.ndarray: """ - Return data from a property group with symlog scaling and zero mean. + Return data from a property group with symlog, zero mean and unit max normalization. - :param property_group: Property group containing data channels. + :param data: Array of data channels per location. :param threshold: Percentile threshold for symlog normalization. :return: Normalized data array. """ - table = property_group.table() - data_array = np.vstack([table[name] for name in table.dtype.names]) - thresh = np.percentile(np.abs(data_array), threshold) - log_data = symlog(data_array, thresh) - return log_data - np.mean(log_data, axis=1)[:, None] + thresh = np.percentile(np.abs(data), threshold) + log_data = symlog(data, thresh) + centered_log = log_data - np.mean(log_data) + return centered_log / np.abs(centered_log).max() def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None: @@ -310,7 +395,7 @@ def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None: def batch_files_score( files: Path | list[Path], spatial_projection, time_projection, observed -) -> list[float]: +) -> list[tuple[float, int]]: """ Process a batch of simulation files and compute scores against observed data. 
@@ -334,24 +419,29 @@ def batch_files_score( logger.warning("No survey found in %s, skipping.", sim_file) continue - simulated = normalized_data(survey.get_entity("Iteration_0_z")[0]) + simulated = get_data_array(survey.get_entity("Iteration_0_z")[0]) pred = time_projection @ (spatial_projection @ simulated.T).T + pred = normalized_data(pred) score = 0.0 - + indices = [] # Metric: normalized cross-correlation for obs, pre in zip(observed, pred, strict=True): + # Scale pre on obs + vals = pre / np.abs(pre).max() * np.abs(obs).max() + # Full cross-correlation - corr = signal.correlate(obs, pre, mode="full") + corr = signal.correlate(obs, vals, mode="same") # Normalize by energy to get correlation coefficient in [-1, 1] - denom = np.linalg.norm(pre) * np.linalg.norm(obs) + denom = np.linalg.norm(vals) * np.linalg.norm(obs) if denom == 0: corr_norm = np.zeros_like(corr) else: corr_norm = corr / denom - score += np.max(corr_norm) + score += np.linalg.norm(obs - vals) + indices.append(np.argmax(corr_norm)) - scores.append(score) + scores.append((score, np.median(indices))) return scores diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py index 731d1140..56e84257 100644 --- a/tests/plate_simulation/runtest/match_test.py +++ b/tests/plate_simulation/runtest/match_test.py @@ -13,6 +13,7 @@ import numpy as np import pytest from geoapps_utils.utils.importing import GeoAppsError +from geoapps_utils.utils.transformations import rotate_xyz from geoh5py import Workspace from geoh5py.groups import PropertyGroup, SimPEGGroup from geoh5py.objects import Points @@ -40,8 +41,14 @@ def generate_example(geoh5: Workspace, n_grid_points: int, refinement: tuple[int]): opts = SyntheticsComponentsOptions( method="airborne tdem", - survey=SurveyOptions(n_stations=n_grid_points, n_lines=1, drape=10.0), - mesh=MeshOptions(refinement=refinement, padding_distance=400.0), + survey=SurveyOptions( + n_stations=n_grid_points, + n_lines=1, + 
width=1000, + drape=40.0, + topography=lambda x, y: np.zeros(x.shape), + ), + mesh=MeshOptions(refinement=refinement), model=ModelOptions(background=0.001), ) components = SyntheticsComponents(geoh5, options=opts) @@ -110,7 +117,7 @@ def test_matching_driver(tmp_path: Path): # Generate simulation files with get_workspace(tmp_path / f"{__name__}.geoh5") as geoh5: - components = generate_example(geoh5, n_grid_points=5, refinement=(2,)) + components = generate_example(geoh5, n_grid_points=32, refinement=(2,)) params = TDEMForwardOptions.build( geoh5=geoh5, @@ -132,6 +139,8 @@ def test_matching_driver(tmp_path: Path): ifile.data["simulation"] = fwr_driver.out_group plate_options = PlateSimulationOptions.build(ifile.data) + plate_options.model.overburden_model.thickness = 40.0 + plate_options.model.plate_model.dip_length = 300.0 driver = PlateSimulationDriver(plate_options) driver.run() @@ -154,18 +163,50 @@ def test_matching_driver(tmp_path: Path): child = survey.get_entity(uid)[0] child.values = child.values * scale - # Random choice of file + # Downsample data + mask = np.ones_like(child.values, dtype=bool) + mask[1::2] = False + survey.remove_vertices(mask) + indices = np.arange(survey.n_vertices) + survey.cells = np.c_[indices[:-1], indices[1:]] + + # Run the matching driver with geoh5.open(): survey = fetch_survey(geoh5) + + # Rotate the survey to test matching + survey.vertices = rotate_xyz(survey.vertices, [0, 0, 0], 215.0) + + # Flip the data to simulate up-dip measurements + prop_group = survey.get_entity("Iteration_0_z")[0] + for uid in prop_group.properties: + child = survey.get_entity(uid)[0] + child.values = child.values[::-1] + + # Change the strike angle to simulate a different orientation + strikes = components.queries.add_data( + { + "strike": { + "values": np.full(components.queries.n_vertices, -10.0), + } + } + ) + options = PlateMatchOptions( geoh5=geoh5, survey=survey, - data=survey.get_entity("Iteration_0_z")[0], + data=prop_group, 
queries=components.queries, + strike_angles=strikes, topography_object=components.topography, simulations=new_dir, ) match_driver = PlateMatchDriver(options) results = match_driver.run() - assert results[0] == file.stem + f"_[{4}].geoh5" + assert isinstance(results, Points) + + names = results.get_data("file")[0] + assert names.values[0] == file.stem + f"_[{4}].geoh5" + + assert geoh5.get_entity("Query [0]")[0].geometry.dip_direction == 45.0