diff --git a/matey/models/basemodel.py b/matey/models/basemodel.py
index 1cd8c75..75ffb67 100644
--- a/matey/models/basemodel.py
+++ b/matey/models/basemodel.py
@@ -28,7 +28,7 @@ class BaseModel(nn.Module):
         embed_dim (int): Dimension of the embedding
         n_states (int): Number of input state variables.
     """
-    def __init__(self, tokenizer_heads, n_states=6, n_states_out=None, n_states_cond=None, embed_dim=768, leadtime=False, cond_input=False, n_steps=1, bias_type="none", SR_ratio=[1,1,1], model_SR=False, hierarchical=None, notransposed=False, nlevels=1, smooth=False):
+    def __init__(self, tokenizer_heads, n_states=6, n_states_out=None, n_states_cond=None, embed_dim=768, leadtime=False, cond_input=False, n_steps=1, bias_type="none", SR_ratio=[1,1,1], model_SR=False, hierarchical=None, notransposed=False, nlevels=1, smooth=False, use_linear=False):
         super().__init__()
         self.space_bag = nn.ModuleList([SubsampledLinear(n_states, embed_dim//4) for _ in range(nlevels)])
         self.conditioning = (n_states_cond is not None and n_states_cond > 0)
@@ -77,10 +77,10 @@ def __init__(self, tokenizer_heads, n_states=6, n_states_out=None, n_states_cond
                 if self.conditioning:
                     embed_ensemble_cond.append(GraphhMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim))
             else:
-                embed_ensemble.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim))
-                debed_ensemble.append(hMLP_output(patch_size=ps_scale_out, embed_dim=embed_dim, out_chans=n_states_out, notransposed=notransposed, smooth=smooth))
+                embed_ensemble.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim, use_linear=use_linear))
+                debed_ensemble.append(hMLP_output(patch_size=ps_scale_out, embed_dim=embed_dim, out_chans=n_states_out, notransposed=notransposed, smooth=smooth, use_linear=use_linear))
                 if self.conditioning:
-                    embed_ensemble_cond.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim))
+                    embed_ensemble_cond.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim, use_linear=use_linear))
             tokenizer_ensemble_heads_level[head_name]["embed"] = embed_ensemble
             tokenizer_ensemble_heads_level[head_name]["debed"] = debed_ensemble
             tokenizer_ensemble_heads_level[head_name]["embed_cond"] = embed_ensemble_cond
@@ -105,8 +105,9 @@ def expand_conv_projections(self, refine_resol):
         debed_ensemble_new = nn.ModuleList()
         #for ps_scale in self.patch_size:
         for ilevel in range(self.token_level):
-            embed_ensemble_new.append(hMLP_stem(patch_size=self.patch_size[ilevel], in_chans=embed_dim//4, embed_dim=embed_dim))
-            debed_ensemble_new.append(hMLP_output(patch_size=self.patch_size[ilevel], embed_dim=embed_dim, out_chans=n_states))
+            _use_linear = getattr(self.embed_ensemble[0], 'use_linear', False) if len(self.embed_ensemble) > 0 else False
+            embed_ensemble_new.append(hMLP_stem(patch_size=self.patch_size[ilevel], in_chans=embed_dim//4, embed_dim=embed_dim, use_linear=_use_linear))
+            debed_ensemble_new.append(hMLP_output(patch_size=self.patch_size[ilevel], embed_dim=embed_dim, out_chans=n_states, use_linear=_use_linear))
 
         if self.token_level>1:
             embed_ensemble_new[-1]=self.embed_ensemble[0]
@@ -154,6 +155,7 @@ def expand_projections(self, expansion_amount):
             out_chans=new_out,
             notransposed=old_debed.notransposed,
             smooth=old_debed.smooth,
+            use_linear=getattr(old_debed, 'use_linear', False),
         )
 
         old_head = old_debed.out_head
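One note on the basemodel.py hunks above: `use_linear` is recovered from existing modules with `getattr(..., 'use_linear', False)` rather than a direct attribute read, so instances built or checkpointed before the flag existed still work. A minimal sketch of that fallback; `LegacyStem` is a hypothetical stand-in, not a class from this repo:

```python
import torch.nn as nn

class LegacyStem(nn.Module):
    """Hypothetical stand-in for an hMLP_stem built before use_linear existed."""
    def __init__(self):
        super().__init__()  # note: never sets self.use_linear

old = LegacyStem()
# A bare `old.use_linear` would raise AttributeError; the getattr default
# routes such modules onto the existing conv path instead.
assert getattr(old, 'use_linear', False) is False
```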
diff --git a/matey/models/spatial_modules.py b/matey/models/spatial_modules.py
index f67162b..3f3ef25 100644
--- a/matey/models/spatial_modules.py
+++ b/matey/models/spatial_modules.py
@@ -209,35 +209,60 @@ def calc_ks4conv(patch_size=(1,16,16), nconv=3):
 class hMLP_stem(nn.Module):
     """ Image to Patch Embedding
     """
-    def __init__(self, patch_size=(1,16,16), in_chans=3, embed_dim=768, nconv=3):
+    def __init__(self, patch_size=(1,16,16), in_chans=3, embed_dim=768, nconv=3, use_linear=False):
         #patch_size: (ps_z, ps_x, ps_y)
         super().__init__()
         self.patch_size = patch_size
         self.in_chans = in_chans
         self.embed_dim = embed_dim
         self.nconv = nconv
+        self.use_linear = use_linear
         self.ks = calc_ks4conv(patch_size=self.patch_size, nconv=self.nconv)
-        modulelist = []
-        for ilayer in range(self.nconv):
-            in_chans_ilayer = in_chans if ilayer==0 else embed_dim//4
-            embed_ilayer = embed_dim if ilayer==self.nconv-1 else embed_dim//4
-            ks_ilayer = self.ks[ilayer]
-            #modulelist.append(nn.Conv2d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
-            #modulelist.append(RMSInstanceNorm2d(embed_ilayer, affine=True)) #changed to RMSInstanceNormSpace
-            modulelist.append(nn.Conv3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
-            modulelist.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
-            modulelist.append(nn.GELU())
-        self.in_proj = torch.nn.Sequential(*modulelist)
+        if self.use_linear:
+            self.linears = nn.ModuleList()
+            self.norms = nn.ModuleList()
+            self.acts = nn.ModuleList()
+            for ilayer in range(self.nconv):
+                in_chans_ilayer = in_chans if ilayer==0 else embed_dim//4
+                embed_ilayer = embed_dim if ilayer==self.nconv-1 else embed_dim//4
+                kD, kH, kW = self.ks[ilayer]
+                self.linears.append(nn.Linear(in_chans_ilayer * kD * kH * kW, embed_ilayer, bias=False))
+                self.norms.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
+                self.acts.append(nn.GELU())
+        else:
+            modulelist = []
+            for ilayer in range(self.nconv):
+                in_chans_ilayer = in_chans if ilayer==0 else embed_dim//4
+                embed_ilayer = embed_dim if ilayer==self.nconv-1 else embed_dim//4
+                ks_ilayer = self.ks[ilayer]
+                modulelist.append(nn.Conv3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
+                modulelist.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
+                modulelist.append(nn.GELU())
+            self.in_proj = torch.nn.Sequential(*modulelist)
 
     def forward(self, x):
-        x = self.in_proj(x)
-        return x
+        if self.use_linear:
+            for ilayer in range(self.nconv):
+                TB = x.shape[0]
+                kD, kH, kW = self.ks[ilayer]
+                D, H, W = x.shape[2], x.shape[3], x.shape[4]
+                x = rearrange(x, 'tb cin (nd kd) (nh kh) (nw kw) -> (tb nd nh nw) (cin kd kh kw)',
+                              kd=kD, kh=kH, kw=kW)
+                x = self.linears[ilayer](x)
+                x = rearrange(x, '(tb nd nh nw) cout -> tb cout nd nh nw',
+                              tb=TB, nd=D//kD, nh=H//kH, nw=W//kW)
+                x = self.norms[ilayer](x)
+                x = self.acts[ilayer](x)
+            return x
+        else:
+            x = self.in_proj(x)
+            return x
 
 
 class hMLP_output(nn.Module):
     """ Patch to Image De-bedding
     """
-    def __init__(self, patch_size=(1,16,16), out_chans=3, embed_dim=768, nconv=3, notransposed=False, smooth=False):
+    def __init__(self, patch_size=(1,16,16), out_chans=3, embed_dim=768, nconv=3, notransposed=False, smooth=False, use_linear=False):
         #patch_size: (ps_z, ps_x, ps_y)
         super().__init__()
         self.patch_size = patch_size
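A quick note on the `hMLP_stem` hunk above: with stride equal to kernel size, the rearrange-plus-`nn.Linear` layer computes exactly the strided `nn.Conv3d` it replaces; the Linear path just trades the conv kernel for a plain matmul over tokens. A minimal sketch of that equivalence, assuming `einops` is importable (as the `forward` above already does) and using illustrative shapes, not ones taken from this repo:

```python
import torch
import torch.nn as nn
from einops import rearrange

cin, cout = 4, 8
kd, kh, kw = 1, 4, 4
conv = nn.Conv3d(cin, cout, kernel_size=(kd, kh, kw), stride=(kd, kh, kw), bias=False)
lin = nn.Linear(cin * kd * kh * kw, cout, bias=False)
# Conv3d weight has shape (cout, cin, kD, kH, kW); flattening the last four
# axes in C-order matches the '(cin kd kh kw)' feature axis from rearrange.
with torch.no_grad():
    lin.weight.copy_(conv.weight.reshape(cout, -1))

x = torch.randn(2, cin, 2, 16, 16)  # (batch, cin, D, H, W)
tokens = rearrange(x, 'tb cin (nd kd) (nh kh) (nw kw) -> (tb nd nh nw) (cin kd kh kw)',
                   kd=kd, kh=kh, kw=kw)
y_lin = rearrange(lin(tokens), '(tb nd nh nw) cout -> tb cout nd nh nw',
                  tb=2, nd=2, nh=4, nw=4)
assert torch.allclose(conv(x), y_lin, atol=1e-5)
```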
@@ -247,51 +272,70 @@ def __init__(self, patch_size=(1,16,16), out_chans=3, embed_dim=768, nconv=3, no
         self.ks = calc_ks4conv(patch_size=self.patch_size, nconv=self.nconv)
         self.notransposed = notransposed
         self.smooth = smooth
-
-        modulelist = []
-        for ilayer in range(self.nconv-1):
-            in_chans_ilayer = embed_dim if ilayer==0 else embed_dim//4
-            embed_ilayer = embed_dim//4
-            ks_ilayer = self.ks[-(ilayer+1)]
-            if self.notransposed:
-                modulelist.append(UpsampleConv3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, bias=False))
-            else:
-                modulelist.append(nn.ConvTranspose3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
-            modulelist.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
-            modulelist.append(nn.GELU())
-        self.out_proj = torch.nn.Sequential(*modulelist)
-        if self.notransposed:
-            out_head = UpsampleConv3d(embed_dim//4, out_chans, kernel_size=self.ks[0])
-            self.out_head = out_head
-        else:
-            self.out_head = nn.ConvTranspose3d(embed_dim//4, out_chans, kernel_size=self.ks[0], stride=self.ks[0])
+        self.use_linear = use_linear and not notransposed  # linear only applies to the ConvTranspose3d path
+
+        if self.use_linear:
+            self.linears = nn.ModuleList()
+            self.norms = nn.ModuleList()
+            self.acts = nn.ModuleList()
+            for ilayer in range(self.nconv-1):
+                in_chans_ilayer = embed_dim if ilayer==0 else embed_dim//4
+                embed_ilayer = embed_dim//4
+                kD, kH, kW = self.ks[-(ilayer+1)]
+                self.linears.append(nn.Linear(in_chans_ilayer, embed_ilayer * kD * kH * kW, bias=False))
+                self.norms.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
+                self.acts.append(nn.GELU())
+            # Final head
+            kD, kH, kW = self.ks[0]
+            self.out_head = nn.Linear(embed_dim//4, out_chans * kD * kH * kW)
+            self.out_head_ks = self.ks[0]
+        else:
+            modulelist = []
+            for ilayer in range(self.nconv-1):
+                in_chans_ilayer = embed_dim if ilayer==0 else embed_dim//4
+                embed_ilayer = embed_dim//4
+                ks_ilayer = self.ks[-(ilayer+1)]
+                if self.notransposed:
+                    modulelist.append(UpsampleConv3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, bias=False))
+                else:
+                    modulelist.append(nn.ConvTranspose3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
+                modulelist.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
+                modulelist.append(nn.GELU())
+            self.out_proj = torch.nn.Sequential(*modulelist)
+            if self.notransposed:
+                out_head = UpsampleConv3d(embed_dim//4, out_chans, kernel_size=self.ks[0])
+                self.out_head = out_head
+            else:
+                self.out_head = nn.ConvTranspose3d(embed_dim//4, out_chans, kernel_size=self.ks[0], stride=self.ks[0])
         if self.smooth:
             self.smooth = nn.Conv3d(out_chans, out_chans, kernel_size=self.ks[0], stride=1, groups=out_chans, padding="same", padding_mode="reflect")
-        """
-        #previous implementation
-        out_head = nn.ConvTranspose3d(embed_dim//4, out_chans, kernel_size=self.ks[0], stride=self.ks[0])
-        self.out_stride = self.ks[0]
-        self.out_kernel = nn.Parameter(out_head.weight)
-        self.out_bias = nn.Parameter(out_head.bias)
-        """
 
     def forward(self, x):
         #B,C,D,H,W
-        x = self.out_proj(x)#.flatten(2).transpose(1, 2)
-        if self.notransposed:
-            #x = self.out_upsample(x)
-            x = self.out_head(x)
-            #x = F.conv3d(x, self.out_kernel[state_labels, :], self.out_bias[state_labels], stride=self.out_stride)
-            #x = x[:,state_labels,...]
-        else:
-            """
-            x = F.conv_transpose3d(x, self.out_kernel[:, state_labels], self.out_bias[state_labels], stride=self.out_stride)
-            """
+        if self.use_linear:
+            for ilayer in range(self.nconv-1):
+                TB, _, D, H, W = x.shape
+                kD, kH, kW = self.ks[-(ilayer+1)]
+                cout = self.linears[ilayer].out_features // (kD * kH * kW)
+                x = rearrange(x, 'tb cin d h w -> (tb d h w) cin')
+                x = self.linears[ilayer](x)
+                x = rearrange(x, '(tb d h w) (cout kd kh kw) -> tb cout (d kd) (h kh) (w kw)',
+                              tb=TB, d=D, h=H, w=W, cout=cout, kd=kD, kh=kH, kw=kW)
+                x = self.norms[ilayer](x)
+                x = self.acts[ilayer](x)
+            # Final head
+            TB, _, D, H, W = x.shape
+            kD, kH, kW = self.out_head_ks
+            x = rearrange(x, 'tb cin d h w -> (tb d h w) cin')
             x = self.out_head(x)
+            x = rearrange(x, '(tb d h w) (cout kd kh kw) -> tb cout (d kd) (h kh) (w kw)',
+                          tb=TB, d=D, h=H, w=W, cout=self.out_chans, kd=kD, kh=kH, kw=kW)
+        else:
+            x = self.out_proj(x)
+            x = self.out_head(x)
         if self.smooth:
             x = self.smooth(x)
-        #x = x[:,state_labels,...]
         return x
 
 class GraphhMLP_stem(nn.Module):
     """graph to patch embedding"""
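The same check holds in the de-bedding direction: expanding each token with `nn.Linear` to `out_chans * kD * kH * kW` and scattering via rearrange matches `ConvTranspose3d` with stride equal to kernel size, bias included. A sketch under the same assumptions (`einops` available; shapes illustrative):

```python
import torch
import torch.nn as nn
from einops import rearrange

cin, cout = 8, 4
kd, kh, kw = 1, 4, 4
tconv = nn.ConvTranspose3d(cin, cout, kernel_size=(kd, kh, kw), stride=(kd, kh, kw))
lin = nn.Linear(cin, cout * kd * kh * kw)
with torch.no_grad():
    # ConvTranspose3d weight has shape (cin, cout, kD, kH, kW); putting cin
    # last and flattening (cout, kD, kH, kW) matches '(cout kd kh kw)'.
    lin.weight.copy_(tconv.weight.permute(1, 2, 3, 4, 0).reshape(-1, cin))
    # The per-channel bias repeats across all kD*kH*kW positions of a patch.
    lin.bias.copy_(tconv.bias.repeat_interleave(kd * kh * kw))

x = torch.randn(2, cin, 2, 4, 4)  # (batch, cin, d, h, w) token grid
y_lin = rearrange(lin(rearrange(x, 'tb cin d h w -> (tb d h w) cin')),
                  '(tb d h w) (cout kd kh kw) -> tb cout (d kd) (h kh) (w kw)',
                  tb=2, d=2, h=4, w=4, cout=cout, kd=kd, kh=kh, kw=kw)
assert torch.allclose(tconv(x), y_lin, atol=1e-5)
```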
diff --git a/matey/models/vit.py b/matey/models/vit.py
index 6664c82..12b5095 100644
--- a/matey/models/vit.py
+++ b/matey/models/vit.py
@@ -35,7 +35,8 @@ def build_vit(params):
         n_steps=params.n_steps,
         bias_type=params.bias_type,
         replace_patch=getattr(params, 'replace_patch', True),
-        hierarchical=getattr(params, 'hierarchical', None)
+        hierarchical=getattr(params, 'hierarchical', None),
+        use_linear=getattr(params, 'use_linear', False),
     )
 
     return model
@@ -52,9 +53,9 @@ class ViT_all2all(BaseModel):
         sts_f
     """
     def __init__(self, tokenizer_heads=None, embed_dim=768, num_heads=12, processor_blocks=8, n_states=6, n_states_cond=None,
-                 drop_path=.2, sts_train=False, sts_model=False, leadtime=False, cond_input=False, n_steps=1, bias_type="none", replace_patch=True, SR_ratio=[1,1,1], hierarchical=None):
-        super().__init__(tokenizer_heads=tokenizer_heads, n_states=n_states, n_states_cond=n_states_cond, embed_dim=embed_dim, leadtime=leadtime,
-                         cond_input=cond_input, n_steps=n_steps, bias_type=bias_type,SR_ratio=SR_ratio, hierarchical=hierarchical)
+                 drop_path=.2, sts_train=False, sts_model=False, leadtime=False, cond_input=False, n_steps=1, bias_type="none", replace_patch=True, SR_ratio=[1,1,1], hierarchical=None, use_linear=False):
+        super().__init__(tokenizer_heads=tokenizer_heads, n_states=n_states, n_states_cond=n_states_cond, embed_dim=embed_dim, leadtime=leadtime,
+                         cond_input=cond_input, n_steps=n_steps, bias_type=bias_type, SR_ratio=SR_ratio, hierarchical=hierarchical, use_linear=use_linear)
         self.drop_path = drop_path
         self.dp = np.linspace(0, drop_path, processor_blocks)
         self.blocks = nn.ModuleList([SpaceTimeBlock_all2all(embed_dim, num_heads,drop_path=self.dp[i])
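Usage-wise, the flag is opt-in through the `params` object that `build_vit` consumes. A hedged sketch, with every field other than `use_linear` elided because the rest of the params schema isn't shown in this diff:

```python
from types import SimpleNamespace

# Only the new field is real here; build_vit also reads embed_dim, n_states,
# n_steps, bias_type, ... from the same object (all elided, schema not shown).
params = SimpleNamespace(
    use_linear=True,  # route hMLP_stem / hMLP_output through the Linear path
)
```

Because `build_vit` reads the flag with `getattr(params, 'use_linear', False)`, configs that never mention it keep today's conv tokenizers unchanged.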