From 6ee51eaae74438596410c95057a77509dcfe509a Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Sun, 11 Jan 2026 16:27:17 -0500 Subject: [PATCH 1/8] Let Flow3D dataset to also provide geometry information --- matey/data_utils/blastnet_3Ddatasets.py | 8 ++++---- matey/data_utils/datasets.py | 3 +++ matey/data_utils/flow3d_datasets.py | 23 +++++++++++++++++++---- matey/train.py | 11 +++++++++++ 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/matey/data_utils/blastnet_3Ddatasets.py b/matey/data_utils/blastnet_3Ddatasets.py index e70a29b..029b164 100644 --- a/matey/data_utils/blastnet_3Ddatasets.py +++ b/matey/data_utils/blastnet_3Ddatasets.py @@ -194,15 +194,15 @@ def __getitem__(self, index): tar = np.squeeze(tar[:, :, isz0:isz0+cbszz, isx0:isx0+cbszx, isy0:isy0+cbszy], axis=0) # C,D,H,W else: - assert len(variables) in (2, 3) + assert len(variables) in (2, 3, 4) trajectory, leadtime = variables[:2] - trajectory = trajectory[:, :, isz0:isz0+cbszz, isx0:isx0+cbszx, isy0:isy0+cbszy] inp = trajectory[:-1] if self.leadtime_max > 0 else trajectory tar = trajectory[-1] - - if len(variables) == 3: + if len(variables) >= 3: ret_dict["cond_fields"] = variables[2] + if len(variables) >= 4: + ret_dict["geometry"] = variables[3] ret_dict["x"] = inp ret_dict["y"] = tar diff --git a/matey/data_utils/datasets.py b/matey/data_utils/datasets.py index 64520ee..7f8aa35 100644 --- a/matey/data_utils/datasets.py +++ b/matey/data_utils/datasets.py @@ -301,6 +301,9 @@ def __getitem__(self, index): if "cond_input" in variables: datasamples["cond_input"] = variables["cond_input"] + + if "geometry" in variables: + datasamples["geometry"] = variables["geometry"] return datasamples diff --git a/matey/data_utils/flow3d_datasets.py b/matey/data_utils/flow3d_datasets.py index e625755..ce2437f 100644 --- a/matey/data_utils/flow3d_datasets.py +++ b/matey/data_utils/flow3d_datasets.py @@ -18,6 +18,7 @@ class Flow3D_Object(BaseBLASTNET3DDataset): # cond_field_names = ["cell_types"] # cond_field_names = ["sdf_obstacle"] cond_field_names = ["sdf_obstacle", "sdf_channel"] + provides_geometry = True @staticmethod def _specifics(): @@ -202,6 +203,13 @@ def _get_filesinfo(self, file_paths): else: dictcase[datacasedir]["stats"] = self.compute_and_save_stats(f, json_path) + nx = [50, 192, 50] + res = nx + tx = np.linspace(0, nx[0], res[0], dtype=np.float32) + ty = np.linspace(0, nx[1], res[1], dtype=np.float32) + tz = np.linspace(0, nx[2], res[2], dtype=np.float32) + self.geometry = np.stack(np.meshgrid(tx, ty, tz, indexing="ij"), axis=-1) + return dictcase def _reconstruct_sample(self, dictcase, time_idx, leadtime): @@ -290,14 +298,21 @@ def get_data(start, end): # cond_data[:,:,:,-3,:] = cond_data[:,:,:,-1,:] # only for cell types # data[:,ft_mapping['p'],:,-3,:] = data[:,ft_mapping['p'],:,-1,:] - return data, cond_data + indices_x = slice(6, 54) + indices_y = slice(1, 49) + indices_z = slice(1, 49) + + data = data [:,:,indices_z, indices_x, indices_y] + cond_data = cond_data[:,:,indices_z, indices_x, indices_y] + + return data, cond_data, indices_x, indices_y, indices_z - comb_x, cond_data = get_data(time_idx, time_idx + self.nsteps_input) - comb_y, _ = get_data(time_idx + self.nsteps_input + leadtime - 1, time_idx + self.nsteps_input + leadtime) + comb_x, cond_data, indices_x, indices_y, indices_z = get_data(time_idx, time_idx + self.nsteps_input) + comb_y, _, _, _, _ = get_data(time_idx + self.nsteps_input + leadtime - 1, time_idx + self.nsteps_input + leadtime) comb = np.concatenate((comb_x, comb_y), axis=0) - return torch.from_numpy(comb), leadtime.to(torch.float32), torch.from_numpy(cond_data) + return torch.from_numpy(comb), leadtime.to(torch.float32), torch.from_numpy(cond_data), torch.from_numpy(self.geometry[indices_z, indices_x, indices_y]) def _get_specific_bcs(self): # FIXME: not used for now diff --git a/matey/train.py b/matey/train.py index e7717c6..7f62625 100644 --- a/matey/train.py +++ b/matey/train.py @@ -450,6 +450,12 @@ def train_one_epoch(self): cond_input = data["cond_input"].to(self.device) else: cond_input = None + + try: + geometry = data["geometry"].to(self.device) + except: + pass + cond_dict = {} try: cond_dict["labels"] = data["cond_field_labels"].to(self.device) @@ -616,6 +622,11 @@ def validate_one_epoch(self, full=False, cutoff_skip=False): else: cond_input = None + try: + geometry = data["geometry"].to(self.device) + except: + pass + cond_dict = {} try: cond_dict["labels"] = data["cond_field_labels"].to(self.device) From a119bf1cc9c544525fcb8b6a0c5b08519af76ee3 Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Tue, 13 Jan 2026 13:36:34 -0500 Subject: [PATCH 2/8] Add GNO model --- matey/models/gno.py | 245 ++++++++++++++++++++++++++++++++++++++++++++ matey/train.py | 12 ++- 2 files changed, 253 insertions(+), 4 deletions(-) create mode 100644 matey/models/gno.py diff --git a/matey/models/gno.py b/matey/models/gno.py new file mode 100644 index 0000000..94e64f2 --- /dev/null +++ b/matey/models/gno.py @@ -0,0 +1,245 @@ +import torch +import torch.nn as nn +import numpy as np +from neuralop.layers.channel_mlp import LinearChannelMLP +from neuralop.layers.integral_transform import IntegralTransform +from neuralop.layers.embeddings import SinusoidalEmbedding +from neuralop.layers.gno_block import GNOBlock +import sklearn +import torch.nn.functional as F +from ..utils.forward_options import ForwardOptionsBase, TrainOptionsBase +from typing import List, Literal, Optional, Callable +from einops import rearrange +import psutil + +class CustomNeighborSearch(nn.Module): + def __init__(self, return_norm=False): + super().__init__() + self.search_fn = custom_neighbor_search + self.return_norm = return_norm + + def forward(self, data, queries, radius): + return_dict = self.search_fn(data, queries, radius, self.return_norm) + return return_dict + +def custom_neighbor_search(data: torch.Tensor, queries: torch.Tensor, radius: float, return_norm: bool=False): + if not hasattr(custom_neighbor_search, "nbr_dict"): + custom_neighbor_search.nbr_dict = {} + + key = (tuple(data.shape), tuple(queries.shape), radius) + + if key not in custom_neighbor_search.nbr_dict: + kdtree = sklearn.neighbors.KDTree(data.cpu(), leaf_size=2) + + if return_norm: + indices, dists = kdtree.query_radius(queries.cpu(), r=radius, return_distance=True) + weights = torch.from_numpy(np.concatenate(dists)).to(queries.device) + else: + indices = kdtree.query_radius(queries.cpu(), r=radius) + + sizes = np.array([arr.size for arr in indices]) + nbr_indices = torch.from_numpy(np.concatenate(indices)).to(queries.device) + nbrhd_sizes = torch.cumsum(torch.from_numpy(sizes).to(queries.device), dim=0) + if return_norm: + custom_neighbor_search.nbr_dict[key] = (nbr_indices, nbrhd_sizes, weights) + else: + custom_neighbor_search.nbr_dict[key] = (nbr_indices, nbrhd_sizes) + + if return_norm: + nbr_indices, nbrhd_sizes, weights = custom_neighbor_search.nbr_dict[key] + else: + nbr_indices, nbrhd_sizes = custom_neighbor_search.nbr_dict[key] + + splits = torch.cat((torch.tensor([0.]).to(queries.device), nbrhd_sizes)) + + nbr_dict = {} + nbr_dict['neighbors_index'] = nbr_indices.long().to(queries.device) + nbr_dict['neighbors_row_splits'] = splits.long() + if return_norm: + nbr_dict['weights'] = weights**2 + + return nbr_dict + +class ModifiedGNOBlock(nn.Module): + """ + The code is equivalent to the original GNOBlock in neuraloperator, except + for the use of custom neighbor search + """ + def __init__(self, + in_channels: int, + out_channels: int, + coord_dim: int, + radius: float, + transform_type="linear", + weighting_fn: Optional[Callable]=None, + reduction: Literal['sum', 'mean']='sum', + pos_embedding_type: str='transformer', + pos_embedding_channels: int=32, + pos_embedding_max_positions: int=10000, + channel_mlp_layers: List[int]=[128,256,128], + channel_mlp_non_linearity=F.gelu, + channel_mlp: nn.Module=None, + use_torch_scatter_reduce: bool=True): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.coord_dim = coord_dim + + self.radius = radius + + # Apply sinusoidal positional embedding + self.pos_embedding_type = pos_embedding_type + if self.pos_embedding_type in ['nerf', 'transformer']: + self.pos_embedding = SinusoidalEmbedding( + in_channels=coord_dim, + num_frequencies=pos_embedding_channels, + embedding_type=pos_embedding_type, + max_positions=pos_embedding_max_positions + ) + else: + self.pos_embedding = None + + # Create in-to-out nb search module + self.neighbor_search = CustomNeighborSearch(return_norm=weighting_fn is not None) + + # create proper kernel input channel dim + if self.pos_embedding is None: + # x and y dim will be coordinate dim if no pos embedding is applied + kernel_in_dim = self.coord_dim * 2 + kernel_in_dim_str = "dim(y) + dim(x)" + else: + # x and y dim will be embedding dim if pos embedding is applied + kernel_in_dim = self.pos_embedding.out_channels * 2 + kernel_in_dim_str = "dim(y_embed) + dim(x_embed)" + + if transform_type == "nonlinear" or transform_type == "nonlinear_kernelonly": + kernel_in_dim += self.in_channels + kernel_in_dim_str += " + dim(f_y)" + + if channel_mlp is not None: + assert channel_mlp.in_channels == kernel_in_dim, f"Error: expected ChannelMLP to take\ + input with {kernel_in_dim} channels (feature channels={kernel_in_dim_str}),\ + got {channel_mlp.in_channels}." + assert channel_mlp.out_channels == out_channels, f"Error: expected ChannelMLP to have\ + {out_channels=} but got {channel_mlp.in_channels=}." + channel_mlp = channel_mlp + + elif channel_mlp_layers is not None: + if channel_mlp_layers[0] != kernel_in_dim: + channel_mlp_layers = [kernel_in_dim] + channel_mlp_layers + if channel_mlp_layers[-1] != self.out_channels: + channel_mlp_layers.append(self.out_channels) + channel_mlp = LinearChannelMLP(layers=channel_mlp_layers, non_linearity=channel_mlp_non_linearity) + + # Create integral transform module + self.integral_transform = IntegralTransform( + channel_mlp=channel_mlp, + transform_type=transform_type, + use_torch_scatter=use_torch_scatter_reduce, + weighting_fn=weighting_fn, + reduction=reduction + ) + + def forward(self, y, x, f_y=None): + if f_y is not None: + if f_y.ndim == 3 and f_y.shape[0] == -1: + f_y = f_y.squeeze(0) + + neighbors_dict = self.neighbor_search(data=y, queries=x, radius=self.radius) + + if self.pos_embedding is not None: + y_embed = self.pos_embedding(y) + x_embed = self.pos_embedding(x) + else: + y_embed = y + x_embed = x + + out_features = self.integral_transform(y=y_embed, + x=x_embed, + neighbors=neighbors_dict, + f_y=f_y) + + return out_features + + +def build_gno(num_channels, inner_model, params): + model = GNOModel(num_channels, inner_model, params) + + return model + + +class GNOModel(nn.Module): + def __init__(self, num_channels, inner_model, params=None): + super().__init__() + + print(params, flush=True) + self.gno_in = ModifiedGNOBlock( + in_channels=num_channels, + out_channels=num_channels, + coord_dim=3, + radius=params.gno["radius_in"] + # weighting_fn=params.weighting_fn, + # reduction=params.reduction + ) + self.model = inner_model + self.gno_out = ModifiedGNOBlock( + in_channels=num_channels, + out_channels=num_channels, + coord_dim=3, + radius=params.gno["radius_out"], + # weighting_fn=params.gno.weighting_fn, + # reduction=params.gno.reduction + ) + + self.model = inner_model + + self.res = params.gno["resolution"] + + bmin = [0, 0, 0] + bmax = [1, 1, 1] + self.latent_geom = self.generate_geometry(bmin, bmax, self.res) + + def generate_geometry(self, bmin, bmax, res): + tx = np.linspace(bmin[0], bmax[0], res[0], dtype=np.float32) + ty = np.linspace(bmin[1], bmax[1], res[1], dtype=np.float32) + tz = np.linspace(bmin[2], bmax[2], res[2], dtype=np.float32) + + geometry = torch.from_numpy(np.stack(np.meshgrid(tx, ty, tz, indexing="ij"), axis=-1)) + return torch.flatten(geometry, end_dim=-2) + + def forward(self, x, state_labels, bcs, geometry, opts: ForwardOptionsBase, train_opts: Optional[TrainOptionsBase]=None): + assert geometry != None, "GNOModel requires geometry input" + + # We assume that all geometries in a batch are identical for now + input_geom = torch.flatten(geometry[0], end_dim=-2) + + # Rescale auxiliary grid + latent_geom = self.latent_geom.to(device=x.device) + bmin = [0, 0, 0] + bmax = [1, 1, 1] + for d in range(3): + bmin[d] = input_geom[:,d].min() + bmax[d] = input_geom[:,d].max() + for d in range(3): + latent_geom[:,d] = bmin[d] + (bmax[d] - bmin[d]) * latent_geom[:,d] + + T, B, C, D, H, W = x.shape + Dlat, Hlat, Wlat = self.res[0], self.res[1], self.res[2] + + # Pre-process using GNO + out = torch.zeros(T, B, C, Dlat, Hlat, Wlat, device=x.device) + for t in range(T): + y = rearrange(x[t,:], 'b c d h w -> b (h w d) c') + out_y = self.gno_in(y=input_geom, x=latent_geom, f_y=y) + out[t,:] = rearrange(out_y, 'b (h w d) c -> b c d h w', d=Dlat, h=Hlat, w=Wlat) + + # Run regular model + out = self.model(out, state_labels, bcs, opts, train_opts) + + # Post-process using GNO + out = rearrange(out, 'b c d h w -> b (h w d) c') + out = self.gno_out(y=latent_geom, x=input_geom, f_y=out) + out = rearrange(out, 'b (h w d) c -> b c d h w', d=D, h=H, w=W) + + return out diff --git a/matey/train.py b/matey/train.py index 7f62625..cd8599c 100644 --- a/matey/train.py +++ b/matey/train.py @@ -20,6 +20,7 @@ from .models.avit import build_avit from .models.svit import build_svit from .models.vit import build_vit +from .models.gno import build_gno from .models.turbt import build_turbt from .utils.logging_utils import Timer, record_function_opt from .utils.distributed_utils import get_sequence_parallel_group, add_weight_decay, CosineNoIncrease, determine_turt_levels @@ -177,6 +178,9 @@ def initialize_model(self): elif self.params.model_type == "turbt": self.model = build_turbt(self.params).to(self.device) + num_channels = 4 + self.model = build_gno(num_channels, self.model, self.params).to(self.device) + if self.params.compile: print('WARNING: BFLOAT NOT SUPPORTED IN SOME COMPILE OPS SO SWITCHING TO FLOAT16') self.mp_type = torch.half @@ -395,11 +399,11 @@ def freeze_model_pretraining(self): self.model = self.model.to(self.device) - def model_forward(self, inp, field_labels, bcs, opts: ForwardOptionsBase, pushforward=True): + def model_forward(self, inp, field_labels, bcs, geometry, opts: ForwardOptionsBase, pushforward=True): # Handles a forward pass through the model, either normal or autoregressive rollout. autoregressive = getattr(self.params, "autoregressive", False) if not autoregressive: - output = self.model(inp, field_labels, bcs, opts) + output = self.model(inp, field_labels, bcs, geometry, opts) return output, None else: # autoregressive rollout @@ -501,7 +505,7 @@ def train_one_epoch(self): field_labels_out= field_labels_out ) with record_function_opt("model forward", enabled=self.profiling): - output, rollout_steps = self.model_forward(inp, field_labels, bcs, opts) + output, rollout_steps = self.model_forward(inp, field_labels, bcs, geometry, opts) if tar.ndim == 6:# B,T,C,D,H,W; For autoregressive, update the target with the returned actual rollout_steps tar = tar[:, rollout_steps-1, :] # B,C,D,H,W #compute loss and update (in-place) logging dicts. @@ -670,7 +674,7 @@ def validate_one_epoch(self, full=False, cutoff_skip=False): isgraph=isgraph, field_labels_out= field_labels_out ) - output, rollout_steps = self.model_forward(inp, field_labels, bcs, opts) + output, rollout_steps = self.model_forward(inp, field_labels, bcs, geometry, opts) if tar.ndim == 6:# B,T,C,D,H,W; For autoregressive, update the target with the returned actual rollout_steps tar = tar[:, rollout_steps-1, :] # B,C,D,H,W update_loss_logs_inplace_eval(output, tar, graphdata if isgraph else None, logs, loss_dset_logs, loss_l1_dset_logs, loss_rmse_dset_logs, dset_type) From af2e2fdfa69683f8272bf0c2f906a392dd315322 Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Fri, 23 Jan 2026 14:02:41 -0800 Subject: [PATCH 3/8] Make geometry optional --- matey/models/gno.py | 8 +++++--- matey/train.py | 21 ++++++++++++--------- matey/utils/forward_options.py | 3 ++- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/matey/models/gno.py b/matey/models/gno.py index 94e64f2..047500d 100644 --- a/matey/models/gno.py +++ b/matey/models/gno.py @@ -208,11 +208,13 @@ def generate_geometry(self, bmin, bmax, res): geometry = torch.from_numpy(np.stack(np.meshgrid(tx, ty, tz, indexing="ij"), axis=-1)) return torch.flatten(geometry, end_dim=-2) - def forward(self, x, state_labels, bcs, geometry, opts: ForwardOptionsBase, train_opts: Optional[TrainOptionsBase]=None): - assert geometry != None, "GNOModel requires geometry input" + def forward(self, x, state_labels, bcs, opts: ForwardOptionsBase, train_opts: Optional[TrainOptionsBase]=None): + if opts.geometry == None: + # Pass-through option without using geometry + return self.model(x, state_labels, bcs, opts, train_opts) # We assume that all geometries in a batch are identical for now - input_geom = torch.flatten(geometry[0], end_dim=-2) + input_geom = torch.flatten(opts.geometry[0], end_dim=-2) # Rescale auxiliary grid latent_geom = self.latent_geom.to(device=x.device) diff --git a/matey/train.py b/matey/train.py index cd8599c..12a6a75 100644 --- a/matey/train.py +++ b/matey/train.py @@ -179,7 +179,8 @@ def initialize_model(self): self.model = build_turbt(self.params).to(self.device) num_channels = 4 - self.model = build_gno(num_channels, self.model, self.params).to(self.device) + if hasattr(self.params, "gno"): + self.model = build_gno(num_channels, self.model, self.params).to(self.device) if self.params.compile: print('WARNING: BFLOAT NOT SUPPORTED IN SOME COMPILE OPS SO SWITCHING TO FLOAT16') @@ -399,11 +400,11 @@ def freeze_model_pretraining(self): self.model = self.model.to(self.device) - def model_forward(self, inp, field_labels, bcs, geometry, opts: ForwardOptionsBase, pushforward=True): + def model_forward(self, inp, field_labels, bcs, opts: ForwardOptionsBase, pushforward=True): # Handles a forward pass through the model, either normal or autoregressive rollout. autoregressive = getattr(self.params, "autoregressive", False) if not autoregressive: - output = self.model(inp, field_labels, bcs, geometry, opts) + output = self.model(inp, field_labels, bcs, opts) return output, None else: # autoregressive rollout @@ -458,7 +459,7 @@ def train_one_epoch(self): try: geometry = data["geometry"].to(self.device) except: - pass + geometry = None cond_dict = {} try: @@ -502,10 +503,11 @@ def train_one_epoch(self): cond_dict=copy.deepcopy(cond_dict), cond_input=cond_input, isgraph=isgraph, - field_labels_out= field_labels_out + field_labels_out= field_labels_out, + geometry=geometry ) with record_function_opt("model forward", enabled=self.profiling): - output, rollout_steps = self.model_forward(inp, field_labels, bcs, geometry, opts) + output, rollout_steps = self.model_forward(inp, field_labels, bcs, opts) if tar.ndim == 6:# B,T,C,D,H,W; For autoregressive, update the target with the returned actual rollout_steps tar = tar[:, rollout_steps-1, :] # B,C,D,H,W #compute loss and update (in-place) logging dicts. @@ -629,7 +631,7 @@ def validate_one_epoch(self, full=False, cutoff_skip=False): try: geometry = data["geometry"].to(self.device) except: - pass + geometry = None cond_dict = {} try: @@ -672,9 +674,10 @@ def validate_one_epoch(self, full=False, cutoff_skip=False): cond_dict=copy.deepcopy(cond_dict), cond_input=cond_input, isgraph=isgraph, - field_labels_out= field_labels_out + field_labels_out= field_labels_out, + geometry=geometry ) - output, rollout_steps = self.model_forward(inp, field_labels, bcs, geometry, opts) + output, rollout_steps = self.model_forward(inp, field_labels, bcs, opts) if tar.ndim == 6:# B,T,C,D,H,W; For autoregressive, update the target with the returned actual rollout_steps tar = tar[:, rollout_steps-1, :] # B,C,D,H,W update_loss_logs_inplace_eval(output, tar, graphdata if isgraph else None, logs, loss_dset_logs, loss_l1_dset_logs, loss_rmse_dset_logs, dset_type) diff --git a/matey/utils/forward_options.py b/matey/utils/forward_options.py index 986eafc..116adbf 100644 --- a/matey/utils/forward_options.py +++ b/matey/utils/forward_options.py @@ -25,4 +25,5 @@ class ForwardOptionsBase: field_labels_out: Optional[Tensor] = None #adaptive tokenization (1 of 2 settings) refine_ratio: Optional[float] = None - imod_bottom: int = 0 #needed only by turbt \ No newline at end of file + imod_bottom: int = 0 #needed only by turbt + geometry: Optional[Tensor] = None From 5aff0acea24c4f1a66ba637f2bd00096698844a5 Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Tue, 27 Jan 2026 10:31:59 -0800 Subject: [PATCH 4/8] Add neighbor timers --- matey/models/gno.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/matey/models/gno.py b/matey/models/gno.py index 047500d..9146ca7 100644 --- a/matey/models/gno.py +++ b/matey/models/gno.py @@ -12,6 +12,8 @@ from einops import rearrange import psutil +import time + class CustomNeighborSearch(nn.Module): def __init__(self, return_norm=False): super().__init__() @@ -29,13 +31,20 @@ def custom_neighbor_search(data: torch.Tensor, queries: torch.Tensor, radius: fl key = (tuple(data.shape), tuple(queries.shape), radius) if key not in custom_neighbor_search.nbr_dict: + start = time.time() kdtree = sklearn.neighbors.KDTree(data.cpu(), leaf_size=2) + construction_time = time.time() - start + start = time.time() if return_norm: indices, dists = kdtree.query_radius(queries.cpu(), r=radius, return_distance=True) weights = torch.from_numpy(np.concatenate(dists)).to(queries.device) else: indices = kdtree.query_radius(queries.cpu(), r=radius) + query_time = time.time() - start + + print(f'neighbors: indices = {indices.size}, avg_indices = {indices.size//int(queries.shape[0])}') + print(f'neighbors: construction = {construction_time}, query = {query_time}', flush=True) sizes = np.array([arr.size for arr in indices]) nbr_indices = torch.from_numpy(np.concatenate(indices)).to(queries.device) From 665cbd178dd53a76423196dc2a4318ed7e14b06b Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Tue, 27 Jan 2026 13:43:03 -0800 Subject: [PATCH 5/8] Return geometry id --- matey/data_utils/flow3d_datasets.py | 6 ++++-- matey/train.py | 14 ++++++++------ matey/utils/forward_options.py | 2 +- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/matey/data_utils/flow3d_datasets.py b/matey/data_utils/flow3d_datasets.py index ce2437f..accf2bc 100644 --- a/matey/data_utils/flow3d_datasets.py +++ b/matey/data_utils/flow3d_datasets.py @@ -168,7 +168,7 @@ def compute_and_save_sdf(self, f, sdf_path, mode = "negative_one"): def _get_filesinfo(self, file_paths): dictcase = {} - for datacasedir in file_paths: + for case_id, datacasedir in enumerate(file_paths): file = os.path.join(datacasedir, "data.h5") f = h5py.File(file) nsteps = 5000 @@ -186,6 +186,7 @@ def _get_filesinfo(self, file_paths): dictcase[datacasedir]["ntimes"] = nsteps dictcase[datacasedir]["features"] = features dictcase[datacasedir]["features_mapping"] = features_mapping + dictcase[datacasedir]["geometry_id"] = case_id sdf_path = os.path.join(datacasedir, "sdf_neg_one.npz") if not os.path.exists(sdf_path): @@ -312,7 +313,8 @@ def get_data(start, end): comb = np.concatenate((comb_x, comb_y), axis=0) - return torch.from_numpy(comb), leadtime.to(torch.float32), torch.from_numpy(cond_data), torch.from_numpy(self.geometry[indices_z, indices_x, indices_y]) + # print(f'Returning geometry id {dictcase["geometry_id"]}', flush=True) + return torch.from_numpy(comb), leadtime.to(torch.float32), torch.from_numpy(cond_data), {"geometry_id": dictcase["geometry_id"], "geometry": torch.from_numpy(self.geometry[indices_z, indices_x, indices_y])} def _get_specific_bcs(self): # FIXME: not used for now diff --git a/matey/train.py b/matey/train.py index 12a6a75..22f65d3 100644 --- a/matey/train.py +++ b/matey/train.py @@ -456,9 +456,10 @@ def train_one_epoch(self): else: cond_input = None - try: - geometry = data["geometry"].to(self.device) - except: + if "geometry" in data: + geometry = data["geometry"] + geometry["geometry"] = geometry["geometry"].to(self.device) + else: geometry = None cond_dict = {} @@ -628,9 +629,10 @@ def validate_one_epoch(self, full=False, cutoff_skip=False): else: cond_input = None - try: - geometry = data["geometry"].to(self.device) - except: + if "geometry" in data: + geometry = data["geometry"] + geometry["geometry"] = geometry["geometry"].to(self.device) + else: geometry = None cond_dict = {} diff --git a/matey/utils/forward_options.py b/matey/utils/forward_options.py index 116adbf..122f554 100644 --- a/matey/utils/forward_options.py +++ b/matey/utils/forward_options.py @@ -26,4 +26,4 @@ class ForwardOptionsBase: #adaptive tokenization (1 of 2 settings) refine_ratio: Optional[float] = None imod_bottom: int = 0 #needed only by turbt - geometry: Optional[Tensor] = None + geometry: Optional[tuple[int, Tensor]] = None From 9d85475f34b729ce3df94f1806c0dff7f0b1b44e Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Tue, 27 Jan 2026 13:43:46 -0800 Subject: [PATCH 6/8] Update GNO to deal with different geometries in batch and geometry ids --- matey/models/gno.py | 113 +++++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 55 deletions(-) diff --git a/matey/models/gno.py b/matey/models/gno.py index 9146ca7..e94ef14 100644 --- a/matey/models/gno.py +++ b/matey/models/gno.py @@ -25,40 +25,23 @@ def forward(self, data, queries, radius): return return_dict def custom_neighbor_search(data: torch.Tensor, queries: torch.Tensor, radius: float, return_norm: bool=False): - if not hasattr(custom_neighbor_search, "nbr_dict"): - custom_neighbor_search.nbr_dict = {} - - key = (tuple(data.shape), tuple(queries.shape), radius) - - if key not in custom_neighbor_search.nbr_dict: - start = time.time() - kdtree = sklearn.neighbors.KDTree(data.cpu(), leaf_size=2) - construction_time = time.time() - start - - start = time.time() - if return_norm: - indices, dists = kdtree.query_radius(queries.cpu(), r=radius, return_distance=True) - weights = torch.from_numpy(np.concatenate(dists)).to(queries.device) - else: - indices = kdtree.query_radius(queries.cpu(), r=radius) - query_time = time.time() - start - - print(f'neighbors: indices = {indices.size}, avg_indices = {indices.size//int(queries.shape[0])}') - print(f'neighbors: construction = {construction_time}, query = {query_time}', flush=True) - - sizes = np.array([arr.size for arr in indices]) - nbr_indices = torch.from_numpy(np.concatenate(indices)).to(queries.device) - nbrhd_sizes = torch.cumsum(torch.from_numpy(sizes).to(queries.device), dim=0) - if return_norm: - custom_neighbor_search.nbr_dict[key] = (nbr_indices, nbrhd_sizes, weights) - else: - custom_neighbor_search.nbr_dict[key] = (nbr_indices, nbrhd_sizes) + start = time.time() + kdtree = sklearn.neighbors.KDTree(data.cpu(), leaf_size=2) + construction_time = time.time() - start + start = time.time() if return_norm: - nbr_indices, nbrhd_sizes, weights = custom_neighbor_search.nbr_dict[key] + indices, dists = kdtree.query_radius(queries.cpu(), r=radius, return_distance=True) + weights = torch.from_numpy(np.concatenate(dists)).to(queries.device) else: - nbr_indices, nbrhd_sizes = custom_neighbor_search.nbr_dict[key] + indices = kdtree.query_radius(queries.cpu(), r=radius) + query_time = time.time() - start + print(f'neighbors: construction = {construction_time}, query = {query_time}', flush=True) + + sizes = np.array([arr.size for arr in indices]) + nbr_indices = torch.from_numpy(np.concatenate(indices)).to(queries.device) + nbrhd_sizes = torch.cumsum(torch.from_numpy(sizes).to(queries.device), dim=0) splits = torch.cat((torch.tensor([0.]).to(queries.device), nbrhd_sizes)) nbr_dict = {} @@ -150,12 +133,21 @@ def __init__(self, reduction=reduction ) - def forward(self, y, x, f_y=None): + self.neighbors_dict = {} + + def forward(self, y, x, f_y, key): if f_y is not None: if f_y.ndim == 3 and f_y.shape[0] == -1: f_y = f_y.squeeze(0) - neighbors_dict = self.neighbor_search(data=y, queries=x, radius=self.radius) + key = f'{key}:{self.radius}:{y.shape}:{x.shape}' + if not key in self.neighbors_dict: + print(f'{key}: building new neighbors') + neigh = self.neighbor_search(data=y, queries=x, radius=self.radius) + self.neighbors_dict[key] = neigh + else: + # print(f'{key}: using cached neighbors') + pass if self.pos_embedding is not None: y_embed = self.pos_embedding(y) @@ -166,7 +158,7 @@ def forward(self, y, x, f_y=None): out_features = self.integral_transform(y=y_embed, x=x_embed, - neighbors=neighbors_dict, + neighbors=self.neighbors_dict[key], f_y=f_y) return out_features @@ -182,12 +174,15 @@ class GNOModel(nn.Module): def __init__(self, num_channels, inner_model, params=None): super().__init__() + self.radius_in = params.gno["radius_in"] + self.radius_out = params.gno["radius_out"] + print(params, flush=True) self.gno_in = ModifiedGNOBlock( in_channels=num_channels, out_channels=num_channels, coord_dim=3, - radius=params.gno["radius_in"] + radius=self.radius_in # weighting_fn=params.weighting_fn, # reduction=params.reduction ) @@ -196,7 +191,7 @@ def __init__(self, num_channels, inner_model, params=None): in_channels=num_channels, out_channels=num_channels, coord_dim=3, - radius=params.gno["radius_out"], + radius=self.radius_out # weighting_fn=params.gno.weighting_fn, # reduction=params.gno.reduction ) @@ -222,35 +217,43 @@ def forward(self, x, state_labels, bcs, opts: ForwardOptionsBase, train_opts: Op # Pass-through option without using geometry return self.model(x, state_labels, bcs, opts, train_opts) - # We assume that all geometries in a batch are identical for now - input_geom = torch.flatten(opts.geometry[0], end_dim=-2) - - # Rescale auxiliary grid - latent_geom = self.latent_geom.to(device=x.device) - bmin = [0, 0, 0] - bmax = [1, 1, 1] - for d in range(3): - bmin[d] = input_geom[:,d].min() - bmax[d] = input_geom[:,d].max() - for d in range(3): - latent_geom[:,d] = bmin[d] + (bmax[d] - bmin[d]) * latent_geom[:,d] - T, B, C, D, H, W = x.shape Dlat, Hlat, Wlat = self.res[0], self.res[1], self.res[2] - # Pre-process using GNO out = torch.zeros(T, B, C, Dlat, Hlat, Wlat, device=x.device) - for t in range(T): - y = rearrange(x[t,:], 'b c d h w -> b (h w d) c') - out_y = self.gno_in(y=input_geom, x=latent_geom, f_y=y) - out[t,:] = rearrange(out_y, 'b (h w d) c -> b c d h w', d=Dlat, h=Hlat, w=Wlat) + + # Pre-process using GNO + # The challenge is that different samples in the same batch may correspond to different geometries + input_geom = [None] * B + latent_geom = [None] * B + for b in range(B): + geometry_id = opts.geometry["geometry_id"][b] + input_geom[b] = torch.flatten(opts.geometry["geometry"][b], end_dim=-2) + + # Rescale auxiliary grid + bmin = [None] * 3 + bmax = [None] * 3 + for d in range(3): + bmin[d] = input_geom[b][:,d].min() + bmax[d] = input_geom[b][:,d].max() + latent_geom[b] = self.latent_geom.to(device=x.device) + for d in range(3): + latent_geom[b][:,d] = bmin[d] + (bmax[d] - bmin[d]) * latent_geom[b][:,d] + + # Use T as batch + y = rearrange(x[:,b,:,:,:,:], 't c d h w -> t (h w d) c') + out_y = self.gno_in(y=input_geom[b], x=latent_geom[b], f_y=y, key=str(geometry_id) + ":in") + out[:,b,:,:,:,:] = rearrange(out_y, 't (h w d) c -> t c d h w', d=Dlat, h=Hlat, w=Wlat) # Run regular model - out = self.model(out, state_labels, bcs, opts, train_opts) + out_model = self.model(out, state_labels, bcs, opts, train_opts) # Post-process using GNO + out = torch.zeros(B, C, D, H, W, device=x.device) + out_model = rearrange(out, 'b c d h w -> b (h w d) c') out = rearrange(out, 'b c d h w -> b (h w d) c') - out = self.gno_out(y=latent_geom, x=input_geom, f_y=out) + for b in range(B): + out[b] = self.gno_out(y=latent_geom[b], x=input_geom[b], f_y=out_model[b], key=str(geometry_id) + ":out") out = rearrange(out, 'b (h w d) c -> b c d h w', d=D, h=H, w=W) return out From df9384d630d58ff46315aa70b3d5aa0206ad6e8e Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Tue, 17 Feb 2026 14:08:36 -0500 Subject: [PATCH 7/8] Make neuralop and sklearn module loading optional --- matey/models/gno.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/matey/models/gno.py b/matey/models/gno.py index e94ef14..19456d9 100644 --- a/matey/models/gno.py +++ b/matey/models/gno.py @@ -1,11 +1,19 @@ import torch import torch.nn as nn import numpy as np -from neuralop.layers.channel_mlp import LinearChannelMLP -from neuralop.layers.integral_transform import IntegralTransform -from neuralop.layers.embeddings import SinusoidalEmbedding -from neuralop.layers.gno_block import GNOBlock -import sklearn +try: + from neuralop.layers.channel_mlp import LinearChannelMLP + from neuralop.layers.integral_transform import IntegralTransform + from neuralop.layers.embeddings import SinusoidalEmbedding + from neuralop.layers.gno_block import GNOBlock + neuralop_exist = True +except ImportError: + neuralop_exist = False +try: + import sklearn + sklearn_exist = True +except ImportError: + sklearn_exist = False import torch.nn.functional as F from ..utils.forward_options import ForwardOptionsBase, TrainOptionsBase from typing import List, Literal, Optional, Callable @@ -25,6 +33,9 @@ def forward(self, data, queries, radius): return return_dict def custom_neighbor_search(data: torch.Tensor, queries: torch.Tensor, radius: float, return_norm: bool=False): + if not sklearn_exist: + raise RuntimeError("sklearn is required for constructing neighbors.") + start = time.time() kdtree = sklearn.neighbors.KDTree(data.cpu(), leaf_size=2) construction_time = time.time() - start @@ -80,6 +91,9 @@ def __init__(self, self.radius = radius + if not neuralop_exist: + raise RuntimeError("NeuralOp is required for running GNO module.") + # Apply sinusoidal positional embedding self.pos_embedding_type = pos_embedding_type if self.pos_embedding_type in ['nerf', 'transformer']: From 16399a53924a21e1ead668d9b3d6b1abb3be126d Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Tue, 17 Feb 2026 14:15:11 -0500 Subject: [PATCH 8/8] Pass num_channels to GNO through params --- matey/models/gno.py | 7 ++++--- matey/train.py | 3 +-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/matey/models/gno.py b/matey/models/gno.py index 19456d9..deae2bd 100644 --- a/matey/models/gno.py +++ b/matey/models/gno.py @@ -178,16 +178,17 @@ def forward(self, y, x, f_y, key): return out_features -def build_gno(num_channels, inner_model, params): - model = GNOModel(num_channels, inner_model, params) +def build_gno(inner_model, params): + model = GNOModel(inner_model, params) return model class GNOModel(nn.Module): - def __init__(self, num_channels, inner_model, params=None): + def __init__(self, inner_model, params=None): super().__init__() + num_channels = params.gno["n_channels"] self.radius_in = params.gno["radius_in"] self.radius_out = params.gno["radius_out"] diff --git a/matey/train.py b/matey/train.py index 22f65d3..4060107 100644 --- a/matey/train.py +++ b/matey/train.py @@ -178,9 +178,8 @@ def initialize_model(self): elif self.params.model_type == "turbt": self.model = build_turbt(self.params).to(self.device) - num_channels = 4 if hasattr(self.params, "gno"): - self.model = build_gno(num_channels, self.model, self.params).to(self.device) + self.model = build_gno(self.model, self.params).to(self.device) if self.params.compile: print('WARNING: BFLOAT NOT SUPPORTED IN SOME COMPILE OPS SO SWITCHING TO FLOAT16')