diff --git a/examples/config/Demo_gno_vit.yaml b/examples/config/Demo_gno_vit.yaml new file mode 100644 index 0000000..8f5c318 --- /dev/null +++ b/examples/config/Demo_gno_vit.yaml @@ -0,0 +1,68 @@ +basic_config: &basic_config + # Run settings + log_to_screen: !!bool True # Log progress to screen. + save_checkpoint: !!bool True # Save checkpoints + checkpoint_save_interval: 10 # Save every # epochs - also saves "best" according to val loss + true_time: !!bool False # Debugging setting - sets num workers to zero and activates syncs + num_data_workers: 6 # Generally pulling 8 cpu per process, so using 6 for DL - not sure if best ratio + enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now + compile: !!bool False # Compile model - Does not currently work + gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory + exp_dir: 'Demo_GNO' + log_interval: 1 # How often to log - Don't think this is actually implemented + pretrained: !!bool False # Whether to load a pretrained model + # Training settings + drop_path: 0.1 + batch_size: 32 #1 + max_epochs: 10 + scheduler_epochs: -1 + epoch_size: 50 #2000 # Artificial epoch size + rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1 + # optimizer: 'adan' # adam, adan, whatever else i end up adding - adan did better on HP sweep + optimizer: 'AdamW' # adam, adan, whatever else i end up adding - adan did better on HP sweep + scheduler: 'cosine' # Only cosine implemented + warmup_steps: 5 # Warmup when not using DAdapt + learning_rate: 1e-3 # -1 means use DAdapt + weight_decay: 1e-3 + # n_states: 12 # Number of state variables across the datasets - Can be larger than real number and things will just go unused + # n_states: 200 # Workaround PyTorch bug + n_states: 220 # Workaround PyTorch bug + n_states_cond: 0 + state_names: ['Vx', 'Vy', 'Vz', 'Pressure'] # Should be sorted + dt: 1 # Striding of data - Not currently implemented > 1 + leadtime_max: 1 #prediction lead time range [1, leadtime_max] + n_steps: 4 # Length of history to include in input + enforce_max_steps: !!bool False # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception. + accum_grad: 1 # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum + # Model settings + model_type: 'vit_all2all' + #space_type: '2D_attention' # Conditional on block type + tie_fields: !!bool False # Whether to use 1 embedding per field per data + embed_dim: 384 # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L + num_heads: 6 # Number of heads for attention - 3/6/12/16 for Ti/S/B/L + processor_blocks: 12 # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L + tokenizer_heads: + - head_name: "tk-3D" + patch_size: [[8, 8, 8]] # z, x, y + - head_name: "gno-flow3D" + patch_size: [[1, 1, 1]] + radius_in: 1.8 + radius_out: 1.9 + # resolution: [48, 48, 48] + resolution: [16, 16, 16] + n_channels: 4 + sts_model: !!bool False + sts_train: !!bool False #when True, we use loss function with two parts: l_coarse/base + l_total, so that the coarse ViT approximates true solutions directly + bias_type: 'none' # Options rel, continuous, none + # Data settings + train_val_test: [.8, .1, .1] + augmentation: !!bool False # Augmentation not implemented + use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets + tie_batches: !!bool False # Force everything in batch to come from one dset + extended_names: !!bool False # Whether to use extended names - not currently implemented + embedding_offset: 0 # Use when adding extra finetuning fields + # train_data_paths: [ ['/global/homes/a/aprokop/m4724/aprokop/Flowsaround3Dobjects/data/', 'flow3d', ''] ] + # valid_data_paths: [ ['/global/homes/a/aprokop/m4724/aprokop/Flowsaround3Dobjects/data/', 'flow3d', ''] ] + train_data_paths: [ ['/pscratch/sd/a/aprokop/Flowsaround3Dobjects_recompute/data/', 'flow3d', 'train', 'gno-flow3D'] ] + valid_data_paths: [ ['/pscratch/sd/a/aprokop/Flowsaround3Dobjects_recompute/data/', 'flow3d', 'val', 'gno-flow3D'] ] + append_datasets: [] # List of datasets to append to the input/output projections for finetuning diff --git a/matey/data_utils/blastnet_3Ddatasets.py b/matey/data_utils/blastnet_3Ddatasets.py index e70a29b..029b164 100644 --- a/matey/data_utils/blastnet_3Ddatasets.py +++ b/matey/data_utils/blastnet_3Ddatasets.py @@ -194,15 +194,15 @@ def __getitem__(self, index): tar = np.squeeze(tar[:, :, isz0:isz0+cbszz, isx0:isx0+cbszx, isy0:isy0+cbszy], axis=0) # C,D,H,W else: - assert len(variables) in (2, 3) + assert len(variables) in (2, 3, 4) trajectory, leadtime = variables[:2] - trajectory = trajectory[:, :, isz0:isz0+cbszz, isx0:isx0+cbszx, isy0:isy0+cbszy] inp = trajectory[:-1] if self.leadtime_max > 0 else trajectory tar = trajectory[-1] - - if len(variables) == 3: + if len(variables) >= 3: ret_dict["cond_fields"] = variables[2] + if len(variables) >= 4: + ret_dict["geometry"] = variables[3] ret_dict["x"] = inp ret_dict["y"] = tar diff --git a/matey/data_utils/datasets.py b/matey/data_utils/datasets.py index 64520ee..7f8aa35 100644 --- a/matey/data_utils/datasets.py +++ b/matey/data_utils/datasets.py @@ -301,6 +301,9 @@ def __getitem__(self, index): if "cond_input" in variables: datasamples["cond_input"] = variables["cond_input"] + + if "geometry" in variables: + datasamples["geometry"] = variables["geometry"] return datasamples diff --git a/matey/data_utils/flow3d_datasets.py b/matey/data_utils/flow3d_datasets.py index e625755..9a7114e 100644 --- a/matey/data_utils/flow3d_datasets.py +++ b/matey/data_utils/flow3d_datasets.py @@ -18,6 +18,7 @@ class Flow3D_Object(BaseBLASTNET3DDataset): # cond_field_names = ["cell_types"] # cond_field_names = ["sdf_obstacle"] cond_field_names = ["sdf_obstacle", "sdf_channel"] + provides_geometry = True @staticmethod def _specifics(): @@ -32,7 +33,7 @@ def _specifics(): # field_names = ["Vx", "Vy", "Vw", "Pressure", "k", "nut"] field_names = ["Vx", "Vy", "Vw", "Pressure"] type = "flow3d" - cubsizes = [192, 48, 48] + cubsizes = [194, 50, 50] case_str = "*" split_level = "case" return time_index, sample_index, field_names, type, cubsizes, case_str, split_level @@ -167,7 +168,7 @@ def compute_and_save_sdf(self, f, sdf_path, mode = "negative_one"): def _get_filesinfo(self, file_paths): dictcase = {} - for datacasedir in file_paths: + for case_id, datacasedir in enumerate(file_paths): file = os.path.join(datacasedir, "data.h5") f = h5py.File(file) nsteps = 5000 @@ -185,6 +186,7 @@ def _get_filesinfo(self, file_paths): dictcase[datacasedir]["ntimes"] = nsteps dictcase[datacasedir]["features"] = features dictcase[datacasedir]["features_mapping"] = features_mapping + dictcase[datacasedir]["geometry_id"] = case_id sdf_path = os.path.join(datacasedir, "sdf_neg_one.npz") if not os.path.exists(sdf_path): @@ -202,6 +204,15 @@ def _get_filesinfo(self, file_paths): else: dictcase[datacasedir]["stats"] = self.compute_and_save_stats(f, json_path) + # Store mesh coordinates in a [N, 3] tensor in H,W,D order + nx = [self.cubsizes[0], self.cubsizes[1], self.cubsizes[2]] + res = nx + tx = torch.linspace(0, nx[0], res[0], dtype=torch.float32) + ty = torch.linspace(0, nx[1], res[1], dtype=torch.float32) + tz = torch.linspace(0, nx[2], res[2], dtype=torch.float32) + X, Y, Z = torch.meshgrid(tx, ty, tz, indexing="ij") + self.grid = torch.flatten(torch.stack((X, Y, Z), dim=-1), end_dim=-2) + return dictcase def _reconstruct_sample(self, dictcase, time_idx, leadtime): @@ -278,26 +289,28 @@ def get_data(start, end): data = data.transpose((0, -1, 3, 1, 2)) cond_data = cond_data.transpose((0, -1, 3, 1, 2)) - # We reduce H dimension from 194 x 50 x 50 to 192 x 50 x 50 to - # allow reasonable patch sizes. Otherwise, as 194 = 2 x 97, it - # would only allow us patch size of 2 or 97, neither of which are - # reasonable. + # Geometry mask in [(HWD)] order + geometry_mask = np.full((total_padded_cells), False) + geometry_mask[inside_idx] = True + for ft in features: + bcs = f['boundary-conditions'][ft] + for name, desc in bcs.items(): + if desc.attrs['type'] == 'fixed-value': + boundary_idx = np.array(f['grid/boundaries'][name]) + geometry_mask[boundary_idx] = True - # Pressure has fixed-value boundary condition on the outflow. If we - # simply reduce the dimension by eliminating the last two layers, - # we lose that information. Instead, set the last H layer to be the - # outflow. - # cond_data[:,:,:,-3,:] = cond_data[:,:,:,-1,:] # only for cell types - # data[:,ft_mapping['p'],:,-3,:] = data[:,ft_mapping['p'],:,-1,:] + return data, cond_data, geometry_mask - return data, cond_data + comb_x, cond_data, geometry_mask = get_data(time_idx, time_idx + self.nsteps_input) + comb_y, _, _ = get_data(time_idx + self.nsteps_input + leadtime - 1, time_idx + self.nsteps_input + leadtime) - comb_x, cond_data = get_data(time_idx, time_idx + self.nsteps_input) - comb_y, _ = get_data(time_idx + self.nsteps_input + leadtime - 1, time_idx + self.nsteps_input + leadtime) + # Make sure that the generated sample matches the cubsizes + D, H, W = comb_x.shape[-3:] + assert [H, W, D] == self.cubsizes comb = np.concatenate((comb_x, comb_y), axis=0) - return torch.from_numpy(comb), leadtime.to(torch.float32), torch.from_numpy(cond_data) + return torch.from_numpy(comb), leadtime.to(torch.float32), torch.from_numpy(cond_data), {"geometry_id": dictcase["geometry_id"], "grid_coords": self.grid, "geometry_mask": torch.from_numpy(geometry_mask)} def _get_specific_bcs(self): # FIXME: not used for now diff --git a/matey/inference.py b/matey/inference.py index 9a5c2a7..f55aec6 100644 --- a/matey/inference.py +++ b/matey/inference.py @@ -147,6 +147,13 @@ def inference(self): else: cond_input = None + if "geometry" in data: + geometry = data["geometry"] + geometry["grid_coords"] = geometry["grid_coords"].to(self.device) + else: + geometry = None + + cond_dict = {} try: cond_dict["labels"] = data["cond_field_labels"].to(self.device) @@ -164,15 +171,20 @@ def inference(self): tar = tar.to(self.device) imod = self.params.hierarchical["nlevels"]-1 if hasattr(self.params, "hierarchical") else 0 if "graph" in data: - isgraph = True + tkhead_type = 'graph' inp = graphdata imod_bottom = imod + elif "geometry" in data: + inp = rearrange(inp.to(self.device), 'b t c d h w -> t b c d h w') + tkhead_type = 'gno' + inp = (inp, geometry) + imod_bottom = -1 # not used else: inp = rearrange(inp.to(self.device), 'b t c d h w -> t b c d h w') - isgraph = False + tkhead_type = 'default' imod_bottom = determine_turt_levels(self.model.module.tokenizer_heads_params[tkhead_name][-1], inp.shape[-3:], imod) if imod>0 else 0 seq_group = self.current_group if dset_type in self.valid_dataset.DP_dsets else None - print(f"Rank {self.global_rank} input shape {inp.shape if not isgraph else inp}, dset_type {dset_type}", flush=True) + print(f"Rank {self.global_rank} input shape {inp.shape if tkhead_type == 'default' else inp[0].shape if tkhead_type == 'gno' else inp}, dset_type {dset_type}", flush=True) opts = ForwardOptionsBase( imod=imod, imod_bottom=imod_bottom, @@ -182,7 +194,7 @@ def inference(self): blockdict=copy.deepcopy(blockdict), cond_dict=copy.deepcopy(cond_dict), cond_input=cond_input, - isgraph=isgraph, + tkhead_type=tkhead_type, field_labels_out= field_labels_out ) output, rollout_steps = self.model_forward(inp, field_labels, bcs, opts) @@ -190,8 +202,8 @@ def inference(self): if rollout_steps is None: rollout_steps = leadtime.view(-1).long() tar = tar[:, rollout_steps-1, :] # B,C,D,H,W - update_loss_logs_inplace_eval(output, tar, graphdata if isgraph else None, logs, loss_dset_logs, loss_l1_dset_logs, loss_rmse_dset_logs, dset_type) - if not isgraph and getattr(self.params, "log_ssim", False): + update_loss_logs_inplace_eval(output, tar, graphdata if tkhead_type == 'graph' else None, logs, loss_dset_logs, loss_l1_dset_logs, loss_rmse_dset_logs, dset_type) + if tkhead_type == 'default' and getattr(self.params, "log_ssim", False): avg_ssim = get_ssim(output, tar, blockdict, self.global_rank, self.current_group, self.group_rank, self.group_size, self.device, self.valid_dataset, dset_index) logs['valid_ssim'] += avg_ssim self.single_print('DONE VALIDATING - NOW SYNCING') diff --git a/matey/models/avit.py b/matey/models/avit.py index 9ff2eea..c03a83f 100644 --- a/matey/models/avit.py +++ b/matey/models/avit.py @@ -156,16 +156,16 @@ def forward(self, x, state_labels, bcs, opts: ForwardOptionsBase, train_opts: Op #unpack arguments imod = opts.imod tkhead_name = opts.tkhead_name + tkhead_type = opts.tkhead_type sequence_parallel_group = opts.sequence_parallel_group leadtime = opts.leadtime blockdict = opts.blockdict cond_dict = opts.cond_dict refine_ratio = opts.refine_ratio cond_input = opts.cond_input - isgraph = opts.isgraph ################################################################## conditioning = (cond_dict != None and bool(cond_dict) and self.conditioning) - assert not isgraph, "graph is not supported in AViT" + assert tkhead_type == 'default', "graph or gno are not supported in AViT" #T,B,C,D,H,W T, _, _, D, _, _ = x.shape if self.tokenizer_heads_gammaref[tkhead_name] is None and refine_ratio is None: diff --git a/matey/models/basemodel.py b/matey/models/basemodel.py index 1ff3cba..9317545 100644 --- a/matey/models/basemodel.py +++ b/matey/models/basemodel.py @@ -7,7 +7,7 @@ import torch.nn.functional as F import numpy as np from einops import rearrange, repeat -from .spatial_modules import hMLP_stem, hMLP_output, SubsampledLinear, GraphhMLP_stem, GraphhMLP_output +from .spatial_modules import hMLP_stem, hMLP_output, SubsampledLinear, GraphhMLP_stem, GraphhMLP_output, GNOhMLP_stem, GNOhMLP_output from .time_modules import leadtimeMLP from .input_modules import input_control_MLP from .positionbias_modules import positionbias_mod @@ -76,6 +76,11 @@ def __init__(self, tokenizer_heads, n_states=6, n_states_out=None, n_states_cond debed_ensemble.append(GraphhMLP_output(patch_size=ps_scale_out, embed_dim=embed_dim, out_chans=n_states_out, smooth=smooth)) if self.conditioning: embed_ensemble_cond.append(GraphhMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim)) + elif "gno" in head_name: + embed_ensemble.append(GNOhMLP_stem(tk, in_chans=embed_dim//4, out_chans=embed_dim)) + debed_ensemble.append(GNOhMLP_output(tk, in_chans=embed_dim, mid_chans=embed_dim//16, out_chans=n_states_out)) + if self.conditioning: + embed_ensemble_cond.append(GNOhMLP_stem(tk, in_chans=embed_dim//4, embed_dim=embed_dim)) else: embed_ensemble.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim)) debed_ensemble.append(hMLP_output(patch_size=ps_scale_out, embed_dim=embed_dim, out_chans=n_states_out, notransposed=notransposed, smooth=smooth)) @@ -187,8 +192,8 @@ def debug_nan(self,x, message=""): #print("No NAN in model parameters: ", name, param.data.numel()) sys.exit(-1) - def get_unified_preembedding(self, x, state_labels, op, isgraph=False): - if not isgraph: + def get_unified_preembedding(self, x, state_labels, op, tkhead_type='default'): + if tkhead_type == 'default' or tkhead_type == 'gno': ## input tensor x: [t, b, c, d, h, w]; state_labels[b, c] # state_labels: variable index to consider varying datasets # return [t, b, c_emb//4, d, h, w] @@ -198,7 +203,7 @@ def get_unified_preembedding(self, x, state_labels, op, isgraph=False): x = rearrange(x, 't b d h w c -> t b c d h w') #self.debug_nan(x) return x - else: + elif tkhead_type == 'graph': #input: (node_features, batch, edge_index); output: (emb node_features, batch, edge_index) node_features, batch, edge_index = x #node_features [nnodes, t, c] @@ -206,8 +211,8 @@ def get_unified_preembedding(self, x, state_labels, op, isgraph=False): node_features = op(node_features, state_labels) #[nnodes, t, c_emb//4] return (node_features, batch, edge_index) - def get_structured_sequence(self, x, embed_index, tokenizer, isgraph=False): - if not isgraph: + def get_structured_sequence(self, x, embed_index, tokenizer, tkhead_type='default'): + if tkhead_type == 'default': ## input tensor x: [t, b, c_emb//4, d, h, w] # embed_index: tokenization at different resolutions; ## and return patch sequences in shape [t, b, c_emd, ntoken_z, ntoken_x, ntoken_y] @@ -216,7 +221,11 @@ def get_structured_sequence(self, x, embed_index, tokenizer, isgraph=False): x = tokenizer[embed_index](x) x = rearrange(x, '(t b) c d h w -> t b c d h w', t=T) #self.debug_nan(x, message="embed_ensemble") - else: + elif tkhead_type == 'gno': + ## input: (x, geometry); output: (x) + # x: [t, b, c_emb//4, d, h, w] + x = tokenizer[embed_index](x) + elif tkhead_type == 'graph': #input: (node_features, batch, edge_index); output: (node_features, batch, edge_index) x = tokenizer[embed_index](x) return x @@ -406,7 +415,7 @@ def get_chosenrefinedpatches(self, x_refine, refineind, t_pos_area_refine, embed leadtime = leadtime.repeat_interleave(ncoarse, dim=0)[mask] return x_local, t_pos_area_local, patch_ids, leadtime - def get_patchsequence(self, x, state_labels, tkhead_name, refineind=None, leadtime=None, blockdict=None, ilevel=0, conditioning: bool = False, isgraph = False): + def get_patchsequence(self, x, state_labels, tkhead_name, refineind=None, leadtime=None, blockdict=None, ilevel=0, conditioning: bool = False, tkhead_type='default'): """ ### intput tensors # x: [T, B, C, D, H, W] @@ -431,12 +440,16 @@ def get_patchsequence(self, x, state_labels, tkhead_name, refineind=None, leadt ######################################################## #[T, B, C_emb//4, D, H, W] op = self.space_bag[ilevel] if not conditioning else self.space_bag_cond[ilevel] - x_pre = self.get_unified_preembedding(x, state_labels, op, isgraph=isgraph) + if tkhead_type == 'gno': + x, geometry = x + x_pre = self.get_unified_preembedding(x, state_labels, op, tkhead_type) + if tkhead_type == 'gno': + x_pre = (x_pre, geometry) ##############tokenizie at the coarse scale############## # x in shape [T, B, C_emb, ntoken_z, ntoken_x, ntoken_y] tokenizer = self.tokenizer_ensemble_heads[ilevel][tkhead_name]["embed" if not conditioning else "embed_cond"] - x = self.get_structured_sequence(x_pre, -1, tokenizer, isgraph=isgraph) - if isgraph: + x = self.get_structured_sequence(x_pre, -1, tokenizer, tkhead_type=tkhead_type) + if tkhead_type == 'graph': #x: (node_features, batch, edge_index) node_emb, batch, edge_index = x x, mask = graph_to_densenodes(node_emb, batch) #[B, Max_nodes, T, C_inp] @@ -445,7 +458,11 @@ def get_patchsequence(self, x, state_labels, tkhead_name, refineind=None, leadt #t_pos_area, _ = self.get_t_pos_area(x_pre, -1, tkhead_name, blockdict=blockdict, ilevel=ilevel) t_pos_area = None return x, None, None, mask_padding, None, None, t_pos_area, None - else: + elif tkhead_type == 'gno': + x = rearrange(x, 't b c d h w -> t b c (d h w)') + t_pos_area = None + return x, None, None, None, None, None, t_pos_area, None + elif tkhead_type == 'default': x = rearrange(x, 't b c d h w -> t b c (d h w)') t_pos_area, _ = self.get_t_pos_area(x_pre, -1, tkhead_name, blockdict=blockdict, ilevel=ilevel) t_pos_area = rearrange(t_pos_area, 'b t d h w c-> b t (d h w) c') @@ -469,7 +486,7 @@ def get_patchsequence(self, x, state_labels, tkhead_name, refineind=None, leadt #mask_padding: [B, ntoken_tot] return x_padding, patch_ids, patch_ids_ref, mask_padding, None, None, t_pos_area_padding, None - def get_spatiotemporalfromsequence(self, x_padding, patch_ids, patch_ids_ref, space_dims, tkhead_name, ilevel=0, isgraph=False): + def get_spatiotemporalfromsequence(self, x_padding, patch_ids, patch_ids_ref, space_dims, tkhead_name, ilevel=0, tkhead_type='default'): #taking token sequences, x_padding, in shape [T, B, C_emb, ntoken_tot] as input #patch_ids_ref: [npatches] (ids of effective tokens in x_local) #patch_ids: [npatches] #selected token ids with sample pos inside batch considered @@ -479,8 +496,10 @@ def get_spatiotemporalfromsequence(self, x_padding, patch_ids, patch_ids_ref, sp embed_ensemble = self.tokenizer_ensemble_heads[ilevel][tkhead_name]["embed"] debed_ensemble = self.tokenizer_ensemble_heads[ilevel][tkhead_name]["debed"] ######################################################################## - if isgraph: + if tkhead_type == 'graph': return debed_ensemble[-1](x_padding) #return batched graph + elif tkhead_type == 'gno': + return debed_ensemble[-1](x_padding, space_dims) T, B = x_padding.shape[:2] ntokendim =[] diff --git a/matey/models/spatial_modules.py b/matey/models/spatial_modules.py index f67162b..8dabb81 100644 --- a/matey/models/spatial_modules.py +++ b/matey/models/spatial_modules.py @@ -12,6 +12,19 @@ from einops import rearrange, repeat from ..utils.distributed_utils import closest_factors from torch_geometric.nn import GCNConv, GraphNorm +from typing import List, Literal, Optional, Callable +import time +try: + from neuralop.layers.gno_block import GNOBlock + from neuralop.layers.channel_mlp import ChannelMLP + neuralop_exist = True +except ImportError: + neuralop_exist = False +try: + import sklearn + sklearn_exist = True +except ImportError: + sklearn_exist = False ### Space utils #FIXME: this function causes training instability. Keeping it now for reproducibility; We'll remove it @@ -385,4 +398,244 @@ def forward(self, data): h = self.smooth(h, edge_index) x_list.append(h) x_out = torch.stack(x_list, dim=1) - return (x_out, batch, edge_index) \ No newline at end of file + return (x_out, batch, edge_index) + + +class CustomNeighborSearch(nn.Module): + def __init__(self, return_norm=False): + super().__init__() + self.search_fn = custom_neighbor_search + self.return_norm = return_norm + + def forward(self, data, queries, radius): + return_dict = self.search_fn(data, queries, radius, self.return_norm) + return return_dict + +def custom_neighbor_search(data: torch.Tensor, queries: torch.Tensor, radius: float, return_norm: bool=False): + if not sklearn_exist: + raise RuntimeError("sklearn is required for constructing neighbors.") + + kdtree = sklearn.neighbors.KDTree(data.cpu(), leaf_size=2) + + if return_norm: + indices, dists = kdtree.query_radius(queries.cpu(), r=radius, return_distance=True) + weights = torch.from_numpy(np.concatenate(dists)) + else: + indices = kdtree.query_radius(queries.cpu(), r=radius) + + sizes = np.array([arr.size for arr in indices]) + nbr_indices = torch.from_numpy(np.concatenate(indices)) + nbrhd_sizes = torch.cumsum(torch.from_numpy(sizes), dim=0) + splits = torch.cat((torch.tensor([0.]), nbrhd_sizes)) + + # print(f'nbr_indices: {nbr_indices.shape[0]}', flush=True) + # print(f'max nbrhd size: {np.max(sizes)}, min nbrhd size: {np.min(sizes)}, avg nbhrd size: {float(nbr_indices.shape[0])/len(sizes)}', flush=True) + + nbr_dict = {} + nbr_dict['neighbors_index'] = nbr_indices.long() + nbr_dict['neighbors_row_splits'] = splits.long() + if return_norm: + nbr_dict['weights'] = weights**2 + + return nbr_dict + + +class ModifiedGNOBlock(GNOBlock): + def __init__(self, + in_channels: int, + out_channels: int, + coord_dim: int, + radius: float, + transform_type="linear", + weighting_fn: Optional[Callable]=None, + reduction: Literal['sum', 'mean']='sum', + pos_embedding_type: str='transformer', + pos_embedding_channels: int=32, + pos_embedding_max_positions: int=10000, + channel_mlp_layers: List[int]=[128,256,128], + channel_mlp_non_linearity=F.gelu, + channel_mlp: nn.Module=None, + use_torch_scatter_reduce: bool=True): + super().__init__(in_channels, out_channels, coord_dim, radius, + transform_type, weighting_fn, reduction, + pos_embedding_type, pos_embedding_channels, + pos_embedding_max_positions, channel_mlp_layers, + channel_mlp_non_linearity, channel_mlp, + use_torch_scatter_reduce) + + self.neighbor_search = CustomNeighborSearch(return_norm=weighting_fn is not None) + + self.neighbors_dict = {} + + def forward(self, y, x, f_y, key): + key = f'{key}:{self.radius}:{y.shape}:{x.shape}' + if not key in self.neighbors_dict: + # print(f'{key}: building new neighbors') + neigh = self.neighbor_search(data=y, queries=x, radius=self.radius) + self.neighbors_dict[key] = neigh + else: + # print(f'{key}: using cached neighbors') + pass + + if self.pos_embedding is not None: + y_embed = self.pos_embedding(y) + x_embed = self.pos_embedding(x) + else: + y_embed = y + x_embed = x + + neighbors = self.neighbors_dict[key] + for item in neighbors: + neighbors[item] = neighbors[item].to(f_y.device) + out_features = self.integral_transform(y=y_embed, + x=x_embed, + neighbors=neighbors, + f_y=f_y) + + return out_features + +class GNOhMLP_stem(nn.Module): + """Geometry to patch embedding""" + def __init__(self, params, in_chans, out_chans): + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.radius = params["radius_in"] + + self.gno = ModifiedGNOBlock( + in_channels=in_chans, + out_channels=in_chans, + coord_dim=3, + radius=self.radius, + channel_mlp_layers=[16,16], + ) + + fno_lifting_channel_ratio = 4 + fno_hidden_channels = 16 + self.lifting = ChannelMLP( + in_channels=in_chans, + hidden_channels=fno_lifting_channel_ratio * fno_hidden_channels, + out_channels=out_chans, + n_layers=2, + ) + + self.res = params["resolution"] # z, x, y + + # Latent grid is [(HWD) x 3] + tx = torch.linspace(0, 1, self.res[1], dtype=torch.float32) + ty = torch.linspace(0, 1, self.res[2], dtype=torch.float32) + tz = torch.linspace(0, 1, self.res[0], dtype=torch.float32) + X, Y, Z = torch.meshgrid(tx, ty, tz, indexing="ij") + grid = torch.stack((X, Y, Z), dim=-1) + self.latent_grid = torch.flatten(grid, end_dim=-2) + + + def forward(self, data): + """ + data: (x, geometry) + """ + x, geometry = data + + T, B, _, D, H, W = x.shape + Dlat, Hlat, Wlat = self.res[0], self.res[1], self.res[2] + + out = torch.zeros(T, B, self.out_chans, Dlat, Hlat, Wlat, device=x.device) + + # The challenge is that different samples in the same batch may correspond to different geometries + for b in range(B): + geometry_id = geometry["geometry_id"][b] + geometry_mask = geometry["geometry_mask"][b] + + input_grid = geometry["grid_coords"][b] + input_grid = input_grid[geometry_mask,:] + + xin = x[:,b,:] + xin = rearrange(xin, 't c d h w -> t (h w d) c') + xin = xin[:,geometry_mask,:] + + # Rescale auxiliary grid + bmin = input_grid.min(dim=0).values + bmax = input_grid.max(dim=0).values + latent_grid = bmin + (bmax - bmin) * self.latent_grid.to(device=x.device) + + # Use T as batch + aux = self.gno(y=input_grid, x=latent_grid, f_y=xin, key=str(geometry_id) + ":in") + aux = rearrange(aux, 't (hwd) c -> t c (hwd)') + aux = self.lifting(aux) + aux = rearrange(aux, 't c (h w d) -> t c d h w', d=Dlat, h=Hlat, w=Wlat) + out[:,b,:] = aux + + return out + +class GNOhMLP_output(nn.Module): + """Patch to geometry de-bedding""" + def __init__(self, params, in_chans, mid_chans, out_chans): + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.radius = params["radius_out"] + + self.gno = ModifiedGNOBlock( + in_channels=in_chans, + out_channels=in_chans, + coord_dim=3, + radius=self.radius, + channel_mlp_layers=[16,16], + ) + + projection_channel_ratio = 4 + fno_hidden_channels = 16 + self.projection = ChannelMLP( + in_channels=in_chans, + hidden_channels=projection_channel_ratio * fno_hidden_channels, + out_channels=out_chans, + n_layers=2, + ) + + self.res = params["resolution"] # z, x, y + + # Latent grid is [(HWD) x 3] + tx = torch.linspace(0, 1, self.res[1], dtype=torch.float32) + ty = torch.linspace(0, 1, self.res[2], dtype=torch.float32) + tz = torch.linspace(0, 1, self.res[0], dtype=torch.float32) + X, Y, Z = torch.meshgrid(tx, ty, tz, indexing="ij") + grid = torch.stack((X, Y, Z), dim=-1) + self.latent_grid = torch.flatten(grid, end_dim=-2) + + def forward(self, data, space_dims): + """ + data: (x, geometry) + """ + x, geometry = data + + T, B, C, _ = x.shape + D, H, W = space_dims + Dlat, Hlat, Wlat = self.res[0], self.res[1], self.res[2] + + out = torch.zeros(T, B, self.out_chans, D, H, W, device=x.device) + + x = rearrange(x, 't b c (d h w) -> b t (h w d) c', d=Dlat, h=Hlat, w=Wlat) + for b in range(B): + geometry_id = geometry["geometry_id"][b] + geometry_mask = geometry["geometry_mask"][b] + + output_grid = geometry["grid_coords"][b] + # output_grid = output_grid[geometry_mask,:] + + # Rescale auxiliary grid + bmin = output_grid.min(dim=0).values + bmax = output_grid.max(dim=0).values + latent_grid = bmin + (bmax - bmin) * self.latent_grid.to(device=x.device) + + # Use T as batch + # FIXME: can we use masked output_grid here + aux = self.gno(y=latent_grid, x=output_grid, f_y=x[b], key=str(geometry_id) + ":out") + aux = rearrange(aux, 't (hwd) c -> t c (hwd)') + aux = self.projection(aux) + aux[:,:,~geometry_mask] = 0 + aux = rearrange(aux, 't c (h w d) -> t c d h w', d=D, h=H, w=W) + out[:,b,:] = aux + + return out diff --git a/matey/models/svit.py b/matey/models/svit.py index c5ccbdc..c9f957c 100644 --- a/matey/models/svit.py +++ b/matey/models/svit.py @@ -125,13 +125,13 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: #unpack arguments imod = opts.imod tkhead_name = opts.tkhead_name + tkhead_type = opts.tkhead_type sequence_parallel_group = opts.sequence_parallel_group leadtime = opts.leadtime blockdict = opts.blockdict cond_dict = opts.cond_dict refine_ratio = opts.refine_ratio cond_input = opts.cond_input - isgraph=opts.isgraph field_labels_out=opts.field_labels_out ################################################################## conditioning = (cond_dict != None and bool(cond_dict) and self.conditioning) @@ -139,14 +139,14 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: if field_labels_out is None: field_labels_out = state_labels - if isgraph: + if tkhead_type == 'graph': x = data.x#nnodes, T, C edge_index = data.edge_index # batch = data.batch ##[N_total] x, data_mean, data_std = normalize_spatiotemporal_persample_graph(x, batch) #node features, mean_g:[G,C], std_g:[G,C] refineind=None x = (x, batch, edge_index) - else: + elif tkhead_type == 'default': x = data #T,B,C,D,H,W T, _, _, D, H, W = x.shape @@ -165,19 +165,19 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: leadtime = self.inconMLP[imod](cond_input) if leadtime is None else leadtime+self.inconMLP[imod](cond_input) ########Encode and get patch sequences [T, B, C_emb, ntoken_len_tot]######## if self.sts_model: - assert not isgraph, "Not set sts_model yet" + assert tkhead_type == 'default', "Not set sts_model yet" #x_padding: coarse tokens; x_local: refined local tokens x_padding, patch_ids, _, _, x_local, leadtime_local, tposarea_padding, tposarea_local = self.get_patchsequence(x, state_labels, tkhead_name, refineind=refineind, leadtime=leadtime, blockdict=blockdict) mask_padding = None x_local = rearrange(x_local, 'nrfb t c dhw_sts -> t nrfb c dhw_sts') else: - x_padding, patch_ids, patch_ids_ref, mask_padding, _, _, tposarea_padding, _ = self.get_patchsequence(x, state_labels, tkhead_name, refineind=refineind, blockdict=blockdict, ilevel=imod, isgraph=isgraph) + x_padding, patch_ids, patch_ids_ref, mask_padding, _, _, tposarea_padding, _ = self.get_patchsequence(x, state_labels, tkhead_name, refineind=refineind, blockdict=blockdict, ilevel=imod, tkhead_type=tkhead_type) # Repeat the steps for conditioning if present if conditioning: assert self.sts_model == False assert refineind == None - assert not isgraph, "Not set conditioning yet" + assert tkhead_type == 'default', "Not set conditioning yet" c, _, _, _, _, _, _, _ = self.get_patchsequence(cond_dict["fields"], cond_dict["labels"], tkhead_name, refineind=refineind, blockdict=blockdict, ilevel=imod, conditioning=conditioning) ################################################################################ if self.posbias[imod] is not None and tposarea_padding is not None: @@ -202,15 +202,15 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: xbase = self.get_spatiotemporalfromsequence(x_padding, None, None, [D, H, W], tkhead_name) x = self.add_sts_model(xbase, patch_ids, x_local, bcs, tkhead_name, leadtime=leadtime_local, t_pos_area=tposarea_local) else: - if isgraph: + if tkhead_type == 'graph': x_padding = rearrange(x_padding, 't b c ntoken_tot -> b ntoken_tot t c') #input:[B, Max_nodes, T, C] and mask: [B, Max_nodes] # #output: [N_total, T, C] (only real nodes) x= densenodes_to_graphnodes(x_padding, mask_padding) #[nnodes, T, C] x_padding = (x, batch, edge_index) D, H, W = -1, -1, -1 #place holder - x = self.get_spatiotemporalfromsequence(x_padding, patch_ids, patch_ids_ref, [D, H, W], tkhead_name, ilevel=imod, isgraph=isgraph) - if isgraph: + x = self.get_spatiotemporalfromsequence(x_padding, patch_ids, patch_ids_ref, [D, H, W], tkhead_name, ilevel=imod, tkhead_type=tkhead_type) + if tkhead_type == 'graph': node_ft, batch, edge_index = x #node_ft: [nnodes, T, C] x = node_ft[:,:,field_labels_out[0]] diff --git a/matey/models/turbt.py b/matey/models/turbt.py index 69ab547..a80f2c1 100644 --- a/matey/models/turbt.py +++ b/matey/models/turbt.py @@ -215,12 +215,12 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase): imod = opts.imod imod_bottom = opts.imod_bottom tkhead_name = opts.tkhead_name + tkhead_type = opts.tkhead_type sequence_parallel_group = opts.sequence_parallel_group leadtime = opts.leadtime blockdict = opts.blockdict refine_ratio = opts.refine_ratio cond_input = opts.cond_input - isgraph=opts.isgraph field_labels_out=opts.field_labels_out ################################################################## if refine_ratio is None: @@ -231,7 +231,7 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase): if field_labels_out is None: field_labels_out = state_labels - if isgraph: + if tkhead_type == 'graph': """ For graph objects: support one level for now FIXME: extend to multiple levels @@ -266,7 +266,7 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase): if self.cond_input and cond_input is not None: leadtime = self.inconMLP[imod](cond_input) if leadtime is None else leadtime+self.inconMLP[imod](cond_input) ########Encode and get patch sequences [B, C_emb, T*ntoken_len_tot]######## - x, patch_ids, patch_ids_ref, mask_padding, _, _, tposarea_padding, _ = self.get_patchsequence(x, state_labels, tkhead_name, refineind=refineind, blockdict=blockdict, ilevel=imod, isgraph=isgraph) + x, patch_ids, patch_ids_ref, mask_padding, _, _, tposarea_padding, _ = self.get_patchsequence(x, state_labels, tkhead_name, refineind=refineind, blockdict=blockdict, ilevel=imod, tkhead_type=tkhead_type) x = rearrange(x, 't b c ntoken_tot -> b c (t ntoken_tot)') ################################################################################ if self.posbias[imod] is not None and tposarea_padding is not None: @@ -278,7 +278,7 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase): ######## Process ######## #only send mask if mask_padding indicates padding tokens mask4attblk = None if (mask_padding is not None and mask_padding.all()) else mask_padding - local_att = not isgraph and imod>imod_bottom + local_att = (tkhead_type != 'graph') and imod>imod_bottom if local_att: #each mode similar cost nfact=max(2**(2*(imod-imod_bottom))//blockdict["nproc_blocks"][-1], 1) if blockdict is not None else max(2**(2*(imod-imod_bottom)), 1) @@ -291,7 +291,7 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase): for iblk, blk in enumerate(self.module_blocks[str(imod)]): if iblk==0: b_mod=x.shape[0] - if not isgraph and leadtime is not None: + if tkhead_type != 'graph' and leadtime is not None: leadtime = leadtime.repeat(b_mod // B, 1) x = blk(x, sequence_parallel_group=sequence_parallel_group, bcs=bcs, leadtime=leadtime, mask_padding=mask4attblk, local_att=local_att) else: @@ -303,7 +303,7 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase): ################################################################################ x = rearrange(x, 'b c (t ntoken_tot) -> t b c ntoken_tot', t=T) ################################################################################# - if isgraph: + if tkhead_type == 'graph': x = rearrange(x, 't b c ntoken_tot -> b ntoken_tot t c') #input:[B, Max_nodes, T, C] and mask: [B, Max_nodes] #output: [N_total, T, C] (only real nodes) @@ -311,8 +311,8 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase): x = (x, batch, edge_index) D, H, W = -1, -1, -1 #place holder ######## Decode ######## - x = self.get_spatiotemporalfromsequence(x, patch_ids, patch_ids_ref, [D, H, W], tkhead_name, ilevel=imod, isgraph=isgraph) - if isgraph: + x = self.get_spatiotemporalfromsequence(x, patch_ids, patch_ids_ref, [D, H, W], tkhead_name, ilevel=imod, tkhead_type=tkhead_type) + if tkhead_type == 'graph': node_ft, batch, edge_index = x #node_ft: [nnodes, T, C] x = node_ft[:,:,field_labels_out[0]] diff --git a/matey/models/vit.py b/matey/models/vit.py index 6664c82..5816448 100644 --- a/matey/models/vit.py +++ b/matey/models/vit.py @@ -122,13 +122,13 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: #unpack arguments imod = opts.imod tkhead_name = opts.tkhead_name + tkhead_type = opts.tkhead_type sequence_parallel_group = opts.sequence_parallel_group leadtime = opts.leadtime blockdict = opts.blockdict cond_dict = opts.cond_dict refine_ratio = opts.refine_ratio cond_input = opts.cond_input - isgraph=opts.isgraph field_labels_out=opts.field_labels_out ################################################################## conditioning = (cond_dict != None and bool(cond_dict) and self.conditioning) @@ -136,7 +136,7 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: if field_labels_out is None: field_labels_out = state_labels - if isgraph: + if tkhead_type == 'graph': x = data.x#[nnodes, T, C] edge_index = data.edge_index # batch = data.batch ##[N_total] @@ -144,6 +144,12 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: x, data_mean, data_std = normalize_spatiotemporal_persample_graph(x, batch) #node features, mean_g:[G,C], std_g:[G,C] refineind=None x = (x, batch, edge_index) + elif tkhead_type == 'gno': + x, geometry = data + x, data_mean, data_std = normalize_spatiotemporal_persample(x) + T, _, _ , D, H, W = x.shape + refineind=None + x = (x, geometry) else: x = data #T,B,C,D,H,W @@ -164,20 +170,20 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: leadtime = self.inconMLP[imod](cond_input) if leadtime is None else leadtime+self.inconMLP[imod](cond_input) ########Encode and get patch sequences [B, C_emb, T*ntoken_len_tot]######## if self.sts_model: - assert not isgraph, "Not set sts_model yet" + assert tkhead_type == 'default', "Not set sts_model yet" #x_padding: coarse tokens; x_local: refined local tokens x_padding, patch_ids, _, _, x_local, leadtime_local, tposarea_padding, tposarea_local = self.get_patchsequence(x, state_labels, tkhead_name, refineind=refineind, leadtime=leadtime, blockdict=blockdict) mask_padding = None x_local = rearrange(x_local, 'nrfb t c dhw_sts -> nrfb c (t dhw_sts)') else: - x_padding, patch_ids, patch_ids_ref, mask_padding, _, _, tposarea_padding, _ = self.get_patchsequence(x, state_labels, tkhead_name, refineind=refineind, blockdict=blockdict, ilevel=imod, isgraph=isgraph) + x_padding, patch_ids, patch_ids_ref, mask_padding, _, _, tposarea_padding, _ = self.get_patchsequence(x, state_labels, tkhead_name, refineind=refineind, blockdict=blockdict, ilevel=imod, tkhead_type=tkhead_type) x_padding = rearrange(x_padding, 't b c ntoken_tot -> b c (t ntoken_tot)') # Repeat the steps for conditioning if present if conditioning: assert self.sts_model == False assert refineind == None - assert not isgraph, "Not set conditioning yet" + assert tkhead_type == 'default', "Not set conditioning yet" c, _, _, _, _, _, _, _ = self.get_patchsequence(cond_dict["fields"], cond_dict["labels"], tkhead_name, refineind=refineind, blockdict=blockdict, ilevel=imod, conditioning=conditioning) c = rearrange(c, 't b c ntoken_tot -> b c (t ntoken_tot)') ################################################################################ @@ -204,16 +210,19 @@ def forward(self, data, state_labels, bcs, opts: ForwardOptionsBase, train_opts: xbase = self.get_spatiotemporalfromsequence(x_padding, None, None, [D, H, W], tkhead_name, ilevel=0) x = self.add_sts_model(xbase, patch_ids, x_local, bcs, tkhead_name, leadtime=leadtime_local, t_pos_area=tposarea_local) else: - if isgraph: + if tkhead_type == 'graph': x_padding = rearrange(x_padding, 't b c ntoken_tot -> b ntoken_tot t c') #input:[B, Max_nodes, T, C] and mask: [B, Max_nodes] #output: [N_total, T, C] (only real nodes) x= densenodes_to_graphnodes(x_padding, mask_padding) #[nnodes, T, C] x_padding = (x, batch, edge_index) D, H, W = -1, -1, -1 #place holder + elif tkhead_type == 'gno': + x_padding = (x_padding, geometry) + + x = self.get_spatiotemporalfromsequence(x_padding, patch_ids, patch_ids_ref, [D, H, W], tkhead_name, ilevel=imod, tkhead_type=tkhead_type) - x = self.get_spatiotemporalfromsequence(x_padding, patch_ids, patch_ids_ref, [D, H, W], tkhead_name, ilevel=imod, isgraph=isgraph) - if isgraph: + if tkhead_type == 'graph': node_ft, batch, edge_index = x #node_ft: [nnodes, T, C] x = node_ft[:,:,field_labels_out[0]] diff --git a/matey/train.py b/matey/train.py index e7717c6..1f2e737 100644 --- a/matey/train.py +++ b/matey/train.py @@ -450,6 +450,13 @@ def train_one_epoch(self): cond_input = data["cond_input"].to(self.device) else: cond_input = None + + if "geometry" in data: + geometry = data["geometry"] + geometry["grid_coords"] = geometry["grid_coords"].to(self.device) + else: + geometry = None + cond_dict = {} try: cond_dict["labels"] = data["cond_field_labels"].to(self.device) @@ -472,12 +479,17 @@ def train_one_epoch(self): tar = tar.to(self.device) imod = self.params.hierarchical["nlevels"]-1 if hasattr(self.params, "hierarchical") else 0 if "graph" in data: - isgraph = True + tkhead_type = 'graph' inp = graphdata imod_bottom = imod + elif "geometry" in data: + inp = rearrange(inp.to(self.device), 'b t c d h w -> t b c d h w') + tkhead_type = 'gno' + inp = (inp, geometry) + imod_bottom = imod else: inp = rearrange(inp.to(self.device), 'b t c d h w -> t b c d h w') - isgraph = False + tkhead_type = 'default' imod_bottom = determine_turt_levels(self.model.module.tokenizer_heads_params[tkhead_name][-1], inp.shape[-3:], imod) if imod>0 else 0 #if self.global_rank == 0: # print(f"input shape {inp.shape}, dset_type {dset_type}, nlevels-1 {imod}, imod_bottom {imod_bottom}, {self.global_rank}, {blockdict}", flush=True) @@ -491,7 +503,7 @@ def train_one_epoch(self): blockdict=copy.deepcopy(blockdict), cond_dict=copy.deepcopy(cond_dict), cond_input=cond_input, - isgraph=isgraph, + tkhead_type=tkhead_type, field_labels_out= field_labels_out ) with record_function_opt("model forward", enabled=self.profiling): @@ -499,14 +511,14 @@ def train_one_epoch(self): if tar.ndim == 6:# B,T,C,D,H,W; For autoregressive, update the target with the returned actual rollout_steps tar = tar[:, rollout_steps-1, :] # B,C,D,H,W #compute loss and update (in-place) logging dicts. - loss, log_nrmse = compute_loss_and_logs(output, tar, graphdata if isgraph else None, logs, loss_logs, dset_type, self.params) + loss, log_nrmse = compute_loss_and_logs(output, tar, graphdata if tkhead_type == 'graph' else None, logs, loss_logs, dset_type, self.params) bad = torch.isnan(loss).any() or torch.isinf(loss) torch.distributed.all_reduce(bad, op=torch.distributed.ReduceOp.SUM) if bad.item() > 0: print(f"INF: {torch.isinf(inp).any(), torch.isinf(tar).any(), torch.isinf(output).any(), bad} for {dset_type}") print(f"NAN: {torch.isnan(inp).any(), torch.isnan(tar).any(), torch.isnan(output).any(), bad} for {dset_type}") continue - if not isgraph: + if tkhead_type != 'graph': if self.params.pei_debug: checking_data_pred_tar(tar, output, blockdict, self.global_rank, self.current_group, self.group_rank, self.group_size, self.device, self.params.debug_outdir, istep=steps, imod=-1) @@ -541,7 +553,7 @@ def train_one_epoch(self): print(f"Epoch {self.epoch} Batch {batch_idx} Train Loss {log_nrmse.item()}") if self.log_to_screen: print('Total Times. Batch: {}, Rank: {}, Data Shape: {}, Data time: {}, Forward: {}, Backward: {}, Optimizer: {}, lr:{}, leadtime.max: {}'.format( - batch_idx, self.global_rank, inp.shape if not isgraph else graphdata, dtime, forward_time, backward_time, optimizer_step, self.optimizer.param_groups[0]['lr'], leadtime.max())) + batch_idx, self.global_rank, inp.shape if tkhead_type == 'default' else (graphdata if tkhead_type == 'graph' else inp[0].shape), dtime, forward_time, backward_time, optimizer_step, self.optimizer.param_groups[0]['lr'], leadtime.max())) data_start = self.timer.get_time() self.check_memory("train-end %d"%batch_idx) if self.params.scheduler == 'steplr': @@ -616,6 +628,12 @@ def validate_one_epoch(self, full=False, cutoff_skip=False): else: cond_input = None + if "geometry" in data: + geometry = data["geometry"] + geometry["grid_coords"] = geometry["grid_coords"].to(self.device) + else: + geometry = None + cond_dict = {} try: cond_dict["labels"] = data["cond_field_labels"].to(self.device) @@ -639,12 +657,17 @@ def validate_one_epoch(self, full=False, cutoff_skip=False): tar = tar.to(self.device) imod = self.params.hierarchical["nlevels"]-1 if hasattr(self.params, "hierarchical") else 0 if "graph" in data: - isgraph = True + tkhead_type = 'graph' inp = graphdata imod_bottom = imod + elif "geometry" in data: + inp = rearrange(inp.to(self.device), 'b t c d h w -> t b c d h w') + tkhead_type = 'gno' + inp = (inp, geometry) + imod_bottom = imod else: inp = rearrange(inp.to(self.device), 'b t c d h w -> t b c d h w') - isgraph = False + tkhead_type = 'default' imod_bottom = determine_turt_levels(self.model.module.tokenizer_heads_params[tkhead_name][-1], inp.shape[-3:], imod) if imod>0 else 0 seq_group = self.current_group if dset_type in self.valid_dataset.DP_dsets else None opts = ForwardOptionsBase( @@ -656,14 +679,14 @@ def validate_one_epoch(self, full=False, cutoff_skip=False): blockdict=copy.deepcopy(blockdict), cond_dict=copy.deepcopy(cond_dict), cond_input=cond_input, - isgraph=isgraph, + tkhead_type=tkhead_type, field_labels_out= field_labels_out ) output, rollout_steps = self.model_forward(inp, field_labels, bcs, opts) if tar.ndim == 6:# B,T,C,D,H,W; For autoregressive, update the target with the returned actual rollout_steps tar = tar[:, rollout_steps-1, :] # B,C,D,H,W - update_loss_logs_inplace_eval(output, tar, graphdata if isgraph else None, logs, loss_dset_logs, loss_l1_dset_logs, loss_rmse_dset_logs, dset_type) - if not isgraph and getattr(self.params, "log_ssim", False): + update_loss_logs_inplace_eval(output, tar, graphdata if tkhead_type == 'graph' else None, logs, loss_dset_logs, loss_l1_dset_logs, loss_rmse_dset_logs, dset_type) + if tkhead_type != 'graph' and getattr(self.params, "log_ssim", False): avg_ssim = get_ssim(output, tar, blockdict, self.global_rank, self.current_group, self.group_rank, self.group_size, self.device, self.valid_dataset, dset_index) logs['valid_ssim'] += avg_ssim self.check_memory("validate-end") diff --git a/matey/utils/forward_options.py b/matey/utils/forward_options.py index 986eafc..a27db4c 100644 --- a/matey/utils/forward_options.py +++ b/matey/utils/forward_options.py @@ -21,8 +21,8 @@ class ForwardOptionsBase: blockdict: Optional[Dict[str, Any]] = None cond_dict: Optional[Dict[str, Any]] = None cond_input: Optional[Tensor] = None - isgraph: Optional[bool] = False + tkhead_type: Optional[str] = 'default' # 'default', 'graph', 'gno' field_labels_out: Optional[Tensor] = None #adaptive tokenization (1 of 2 settings) refine_ratio: Optional[float] = None - imod_bottom: int = 0 #needed only by turbt \ No newline at end of file + imod_bottom: int = 0 #needed only by turbt diff --git a/matey/utils/training_utils.py b/matey/utils/training_utils.py index 5118c75..ca0fdc4 100644 --- a/matey/utils/training_utils.py +++ b/matey/utils/training_utils.py @@ -60,7 +60,7 @@ def autoregressive_rollout(model, inp, field_labels, bcs, opts: ForwardOptionsBa """ rollout_steps = preprocess_target(opts.leadtime) x_t = inp - if opts.isgraph: + if opts.tkhead_type == 'graph': n_steps = x_t.x.shape[1] #[nnodes, T, C] #FIXME: I realize it takes more to make this function work for graphs and will open a seperate PR on this raise ValueError("Autoregressive rollout is not supported yet for graphs")