Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions examples/config/Demo_gno_vit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
basic_config: &basic_config
# Run settings
log_to_screen: !!bool True # Log progress to screen.
save_checkpoint: !!bool True # Save checkpoints
checkpoint_save_interval: 10 # Save every # epochs - also saves "best" according to val loss
true_time: !!bool False # Debugging setting - sets num workers to zero and activates syncs
num_data_workers: 6 # Generally pulling 8 cpu per process, so using 6 for DL - not sure if best ratio
enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now
compile: !!bool False # Compile model - Does not currently work
gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory
exp_dir: 'Demo_GNO'
log_interval: 1 # How often to log - Don't think this is actually implemented
pretrained: !!bool False # Whether to load a pretrained model
# Training settings
drop_path: 0.1
batch_size: 32 #1
max_epochs: 10
scheduler_epochs: -1
epoch_size: 50 #2000 # Artificial epoch size
rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1
  # optimizer: 'adan' # adam, adan, whatever else I end up adding - adan did better on the HP sweep
  optimizer: 'AdamW' # adam, adan, whatever else I end up adding - adan did better on the HP sweep
scheduler: 'cosine' # Only cosine implemented
warmup_steps: 5 # Warmup when not using DAdapt
learning_rate: 1e-3 # -1 means use DAdapt
weight_decay: 1e-3
# n_states: 12 # Number of state variables across the datasets - Can be larger than real number and things will just go unused
# n_states: 200 # Workaround PyTorch bug
n_states: 220 # Workaround PyTorch bug
n_states_cond: 0
state_names: ['Vx', 'Vy', 'Vz', 'Pressure'] # Should be sorted
dt: 1 # Striding of data - Not currently implemented > 1
leadtime_max: 1 #prediction lead time range [1, leadtime_max]
n_steps: 4 # Length of history to include in input
enforce_max_steps: !!bool False # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception.
accum_grad: 1 # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum
# Model settings
model_type: 'vit_all2all'
#space_type: '2D_attention' # Conditional on block type
tie_fields: !!bool False # Whether to use 1 embedding per field per data
embed_dim: 384 # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L
num_heads: 6 # Number of heads for attention - 3/6/12/16 for Ti/S/B/L
processor_blocks: 12 # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L
tokenizer_heads:
- head_name: "tk-3D"
patch_size: [[8, 8, 8]] # z, x, y
- head_name: "gno-flow3D"
patch_size: [[1, 1, 1]]
radius_in: 1.8
radius_out: 1.9
# resolution: [48, 48, 48]
resolution: [16, 16, 16]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolution and patch_size

n_channels: 4
sts_model: !!bool False
  sts_train: !!bool False # when True, we use a loss function with two parts: l_coarse/base + l_total, so that the coarse ViT approximates the true solutions directly
bias_type: 'none' # Options rel, continuous, none
# Data settings
train_val_test: [.8, .1, .1]
augmentation: !!bool False # Augmentation not implemented
use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets
tie_batches: !!bool False # Force everything in batch to come from one dset
extended_names: !!bool False # Whether to use extended names - not currently implemented
embedding_offset: 0 # Use when adding extra finetuning fields
# train_data_paths: [ ['/global/homes/a/aprokop/m4724/aprokop/Flowsaround3Dobjects/data/', 'flow3d', ''] ]
# valid_data_paths: [ ['/global/homes/a/aprokop/m4724/aprokop/Flowsaround3Dobjects/data/', 'flow3d', ''] ]
train_data_paths: [ ['/pscratch/sd/a/aprokop/Flowsaround3Dobjects_recompute/data/', 'flow3d', 'train', 'gno-flow3D'] ]
valid_data_paths: [ ['/pscratch/sd/a/aprokop/Flowsaround3Dobjects_recompute/data/', 'flow3d', 'val', 'gno-flow3D'] ]
append_datasets: [] # List of datasets to append to the input/output projections for finetuning
8 changes: 4 additions & 4 deletions matey/data_utils/blastnet_3Ddatasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,15 @@ def __getitem__(self, index):
tar = np.squeeze(tar[:, :, isz0:isz0+cbszz, isx0:isx0+cbszx, isy0:isy0+cbszy], axis=0) # C,D,H,W

else:
assert len(variables) in (2, 3)
assert len(variables) in (2, 3, 4)
trajectory, leadtime = variables[:2]

trajectory = trajectory[:, :, isz0:isz0+cbszz, isx0:isx0+cbszx, isy0:isy0+cbszy]
inp = trajectory[:-1] if self.leadtime_max > 0 else trajectory
tar = trajectory[-1]

if len(variables) == 3:
if len(variables) >= 3:
ret_dict["cond_fields"] = variables[2]
if len(variables) >= 4:
ret_dict["geometry"] = variables[3]

ret_dict["x"] = inp
ret_dict["y"] = tar
Expand Down
3 changes: 3 additions & 0 deletions matey/data_utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ def __getitem__(self, index):

if "cond_input" in variables:
datasamples["cond_input"] = variables["cond_input"]

if "geometry" in variables:
datasamples["geometry"] = variables["geometry"]

return datasamples

Expand Down
45 changes: 29 additions & 16 deletions matey/data_utils/flow3d_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Flow3D_Object(BaseBLASTNET3DDataset):
# cond_field_names = ["cell_types"]
# cond_field_names = ["sdf_obstacle"]
cond_field_names = ["sdf_obstacle", "sdf_channel"]
provides_geometry = True

@staticmethod
def _specifics():
Expand All @@ -32,7 +33,7 @@ def _specifics():
# field_names = ["Vx", "Vy", "Vw", "Pressure", "k", "nut"]
field_names = ["Vx", "Vy", "Vw", "Pressure"]
type = "flow3d"
cubsizes = [192, 48, 48]
cubsizes = [194, 50, 50]
case_str = "*"
split_level = "case"
return time_index, sample_index, field_names, type, cubsizes, case_str, split_level
Expand Down Expand Up @@ -167,7 +168,7 @@ def compute_and_save_sdf(self, f, sdf_path, mode = "negative_one"):

def _get_filesinfo(self, file_paths):
dictcase = {}
for datacasedir in file_paths:
for case_id, datacasedir in enumerate(file_paths):
file = os.path.join(datacasedir, "data.h5")
f = h5py.File(file)
nsteps = 5000
Expand All @@ -185,6 +186,7 @@ def _get_filesinfo(self, file_paths):
dictcase[datacasedir]["ntimes"] = nsteps
dictcase[datacasedir]["features"] = features
dictcase[datacasedir]["features_mapping"] = features_mapping
dictcase[datacasedir]["geometry_id"] = case_id

sdf_path = os.path.join(datacasedir, "sdf_neg_one.npz")
if not os.path.exists(sdf_path):
Expand All @@ -202,6 +204,15 @@ def _get_filesinfo(self, file_paths):
else:
dictcase[datacasedir]["stats"] = self.compute_and_save_stats(f, json_path)

# Store mesh coordinates in a [N, 3] tensor in H,W,D order
nx = [self.cubsizes[0], self.cubsizes[1], self.cubsizes[2]]
res = nx
tx = torch.linspace(0, nx[0], res[0], dtype=torch.float32)
ty = torch.linspace(0, nx[1], res[1], dtype=torch.float32)
tz = torch.linspace(0, nx[2], res[2], dtype=torch.float32)
X, Y, Z = torch.meshgrid(tx, ty, tz, indexing="ij")
self.grid = torch.flatten(torch.stack((X, Y, Z), dim=-1), end_dim=-2)

return dictcase

def _reconstruct_sample(self, dictcase, time_idx, leadtime):
Expand Down Expand Up @@ -278,26 +289,28 @@ def get_data(start, end):
data = data.transpose((0, -1, 3, 1, 2))
cond_data = cond_data.transpose((0, -1, 3, 1, 2))

# We reduce H dimension from 194 x 50 x 50 to 192 x 50 x 50 to
# allow reasonable patch sizes. Otherwise, as 194 = 2 x 97, it
# would only allow us patch size of 2 or 97, neither of which are
# reasonable.
# Geometry mask in [(HWD)] order
geometry_mask = np.full((total_padded_cells), False)
geometry_mask[inside_idx] = True
for ft in features:
bcs = f['boundary-conditions'][ft]
for name, desc in bcs.items():
if desc.attrs['type'] == 'fixed-value':
boundary_idx = np.array(f['grid/boundaries'][name])
geometry_mask[boundary_idx] = True

# Pressure has fixed-value boundary condition on the outflow. If we
# simply reduce the dimension by eliminating the last two layers,
# we lose that information. Instead, set the last H layer to be the
# outflow.
# cond_data[:,:,:,-3,:] = cond_data[:,:,:,-1,:] # only for cell types
# data[:,ft_mapping['p'],:,-3,:] = data[:,ft_mapping['p'],:,-1,:]
return data, cond_data, geometry_mask

return data, cond_data
comb_x, cond_data, geometry_mask = get_data(time_idx, time_idx + self.nsteps_input)
comb_y, _, _ = get_data(time_idx + self.nsteps_input + leadtime - 1, time_idx + self.nsteps_input + leadtime)

comb_x, cond_data = get_data(time_idx, time_idx + self.nsteps_input)
comb_y, _ = get_data(time_idx + self.nsteps_input + leadtime - 1, time_idx + self.nsteps_input + leadtime)
# Make sure that the generated sample matches the cubsizes
D, H, W = comb_x.shape[-3:]
assert [H, W, D] == self.cubsizes

comb = np.concatenate((comb_x, comb_y), axis=0)

return torch.from_numpy(comb), leadtime.to(torch.float32), torch.from_numpy(cond_data)
return torch.from_numpy(comb), leadtime.to(torch.float32), torch.from_numpy(cond_data), {"geometry_id": dictcase["geometry_id"], "grid_coords": self.grid, "geometry_mask": torch.from_numpy(geometry_mask)}

def _get_specific_bcs(self):
# FIXME: not used for now
Expand Down
24 changes: 18 additions & 6 deletions matey/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ def inference(self):
else:
cond_input = None

if "geometry" in data:
geometry = data["geometry"]
geometry["grid_coords"] = geometry["grid_coords"].to(self.device)
else:
geometry = None


cond_dict = {}
try:
cond_dict["labels"] = data["cond_field_labels"].to(self.device)
Expand All @@ -164,15 +171,20 @@ def inference(self):
tar = tar.to(self.device)
imod = self.params.hierarchical["nlevels"]-1 if hasattr(self.params, "hierarchical") else 0
if "graph" in data:
isgraph = True
tkhead_type = 'graph'
inp = graphdata
imod_bottom = imod
elif "geometry" in data:
inp = rearrange(inp.to(self.device), 'b t c d h w -> t b c d h w')
tkhead_type = 'gno'
inp = (inp, geometry)
imod_bottom = -1 # not used
else:
inp = rearrange(inp.to(self.device), 'b t c d h w -> t b c d h w')
isgraph = False
tkhead_type = 'default'
imod_bottom = determine_turt_levels(self.model.module.tokenizer_heads_params[tkhead_name][-1], inp.shape[-3:], imod) if imod>0 else 0
seq_group = self.current_group if dset_type in self.valid_dataset.DP_dsets else None
print(f"Rank {self.global_rank} input shape {inp.shape if not isgraph else inp}, dset_type {dset_type}", flush=True)
print(f"Rank {self.global_rank} input shape {inp.shape if tkhead_type == 'default' else inp[0].shape if tkhead_type == 'gno' else inp}, dset_type {dset_type}", flush=True)
opts = ForwardOptionsBase(
imod=imod,
imod_bottom=imod_bottom,
Expand All @@ -182,16 +194,16 @@ def inference(self):
blockdict=copy.deepcopy(blockdict),
cond_dict=copy.deepcopy(cond_dict),
cond_input=cond_input,
isgraph=isgraph,
tkhead_type=tkhead_type,
field_labels_out= field_labels_out
)
output, rollout_steps = self.model_forward(inp, field_labels, bcs, opts)
if tar.ndim == 6: #B,T,C,D,H,W
if rollout_steps is None:
rollout_steps = leadtime.view(-1).long()
tar = tar[:, rollout_steps-1, :] # B,C,D,H,W
update_loss_logs_inplace_eval(output, tar, graphdata if isgraph else None, logs, loss_dset_logs, loss_l1_dset_logs, loss_rmse_dset_logs, dset_type)
if not isgraph and getattr(self.params, "log_ssim", False):
update_loss_logs_inplace_eval(output, tar, graphdata if tkhead_type == 'graph' else None, logs, loss_dset_logs, loss_l1_dset_logs, loss_rmse_dset_logs, dset_type)
if tkhead_type == 'default' and getattr(self.params, "log_ssim", False):
avg_ssim = get_ssim(output, tar, blockdict, self.global_rank, self.current_group, self.group_rank, self.group_size, self.device, self.valid_dataset, dset_index)
logs['valid_ssim'] += avg_ssim
self.single_print('DONE VALIDATING - NOW SYNCING')
Expand Down
4 changes: 2 additions & 2 deletions matey/models/avit.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,16 @@ def forward(self, x, state_labels, bcs, opts: ForwardOptionsBase, train_opts: Op
#unpack arguments
imod = opts.imod
tkhead_name = opts.tkhead_name
tkhead_type = opts.tkhead_type
sequence_parallel_group = opts.sequence_parallel_group
leadtime = opts.leadtime
blockdict = opts.blockdict
cond_dict = opts.cond_dict
refine_ratio = opts.refine_ratio
cond_input = opts.cond_input
isgraph = opts.isgraph
##################################################################
conditioning = (cond_dict != None and bool(cond_dict) and self.conditioning)
assert not isgraph, "graph is not supported in AViT"
assert tkhead_type == 'default', "graph or gno are not supported in AViT"
#T,B,C,D,H,W
T, _, _, D, _, _ = x.shape
if self.tokenizer_heads_gammaref[tkhead_name] is None and refine_ratio is None:
Expand Down
Loading