14 changes: 8 additions & 6 deletions matey/models/basemodel.py
@@ -28,7 +28,7 @@ class BaseModel(nn.Module):
embed_dim (int): Dimension of the embedding
n_states (int): Number of input state variables.
"""
def __init__(self, tokenizer_heads, n_states=6, n_states_out=None, n_states_cond=None, embed_dim=768, leadtime=False, cond_input=False, n_steps=1, bias_type="none", SR_ratio=[1,1,1], model_SR=False, hierarchical=None, notransposed=False, nlevels=1, smooth=False):
def __init__(self, tokenizer_heads, n_states=6, n_states_out=None, n_states_cond=None, embed_dim=768, leadtime=False, cond_input=False, n_steps=1, bias_type="none", SR_ratio=[1,1,1], model_SR=False, hierarchical=None, notransposed=False, nlevels=1, smooth=False, use_linear=False):
super().__init__()
self.space_bag = nn.ModuleList([SubsampledLinear(n_states, embed_dim//4) for _ in range(nlevels)])
self.conditioning = (n_states_cond is not None and n_states_cond > 0)
@@ -77,10 +77,10 @@ def __init__(self, tokenizer_heads, n_states=6, n_states_out=None, n_states_cond
if self.conditioning:
embed_ensemble_cond.append(GraphhMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim))
else:
embed_ensemble.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim))
debed_ensemble.append(hMLP_output(patch_size=ps_scale_out, embed_dim=embed_dim, out_chans=n_states_out, notransposed=notransposed, smooth=smooth))
embed_ensemble.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim, use_linear=use_linear))
debed_ensemble.append(hMLP_output(patch_size=ps_scale_out, embed_dim=embed_dim, out_chans=n_states_out, notransposed=notransposed, smooth=smooth, use_linear=use_linear))
if self.conditioning:
embed_ensemble_cond.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim))
embed_ensemble_cond.append(hMLP_stem(patch_size=ps_scale, in_chans=embed_dim//4, embed_dim=embed_dim, use_linear=use_linear))
tokenizer_ensemble_heads_level[head_name]["embed"] = embed_ensemble
tokenizer_ensemble_heads_level[head_name]["debed"] = debed_ensemble
tokenizer_ensemble_heads_level[head_name]["embed_cond"] = embed_ensemble_cond
@@ -105,8 +105,9 @@ def expand_conv_projections(self, refine_resol):
debed_ensemble_new = nn.ModuleList()
#for ps_scale in self.patch_size:
for ilevel in range(self.token_level):
embed_ensemble_new.append(hMLP_stem(patch_size=self.patch_size[ilevel], in_chans=embed_dim//4, embed_dim=embed_dim))
debed_ensemble_new.append(hMLP_output(patch_size=self.patch_size[ilevel], embed_dim=embed_dim, out_chans=n_states))
_use_linear = getattr(self.embed_ensemble[0], 'use_linear', False) if len(self.embed_ensemble) > 0 else False
embed_ensemble_new.append(hMLP_stem(patch_size=self.patch_size[ilevel], in_chans=embed_dim//4, embed_dim=embed_dim, use_linear=_use_linear))
debed_ensemble_new.append(hMLP_output(patch_size=self.patch_size[ilevel], embed_dim=embed_dim, out_chans=n_states, use_linear=_use_linear))

if self.token_level>1:
embed_ensemble_new[-1]=self.embed_ensemble[0]
@@ -154,6 +155,7 @@ def expand_projections(self, expansion_amount):
out_chans=new_out,
notransposed=old_debed.notransposed,
smooth=old_debed.smooth,
use_linear=getattr(old_debed, 'use_linear', False),
)

old_head = old_debed.out_head
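The getattr fallbacks added in expand_conv_projections and expand_projections are what keep modules built before this flag existed working: an attribute that was never set simply reads as False. A minimal standalone sketch of that fallback, with hypothetical class names:

import torch.nn as nn

class OldStem(nn.Module):
    # built before the flag existed; no use_linear attribute
    pass

class NewStem(nn.Module):
    def __init__(self, use_linear=False):
        super().__init__()
        self.use_linear = use_linear

for stem in (OldStem(), NewStem(use_linear=True)):
    print(getattr(stem, 'use_linear', False))  # prints False, then True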
156 changes: 104 additions & 52 deletions matey/models/spatial_modules.py
@@ -209,35 +209,60 @@ def calc_ks4conv(patch_size=(1,16,16), nconv=3):
class hMLP_stem(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, patch_size=(1,16,16), in_chans=3, embed_dim=768, nconv=3):
def __init__(self, patch_size=(1,16,16), in_chans=3, embed_dim=768, nconv=3, use_linear=False):
#patch_size: (ps_z, ps_x, ps_y)
super().__init__()
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.nconv = nconv
self.use_linear = use_linear
self.ks = calc_ks4conv(patch_size=self.patch_size, nconv=self.nconv)

modulelist = []
for ilayer in range(self.nconv):
in_chans_ilayer = in_chans if ilayer==0 else embed_dim//4
embed_ilayer = embed_dim if ilayer==self.nconv-1 else embed_dim//4
ks_ilayer = self.ks[ilayer]
#modulelist.append(nn.Conv2d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
#modulelist.append(RMSInstanceNorm2d(embed_ilayer, affine=True)) #changed to RMSInstanceNormSpace
modulelist.append(nn.Conv3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
modulelist.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
modulelist.append(nn.GELU())
self.in_proj = torch.nn.Sequential(*modulelist)
if self.use_linear:
self.linears = nn.ModuleList()
self.norms = nn.ModuleList()
self.acts = nn.ModuleList()
for ilayer in range(self.nconv):
in_chans_ilayer = in_chans if ilayer==0 else embed_dim//4
embed_ilayer = embed_dim if ilayer==self.nconv-1 else embed_dim//4
kD, kH, kW = self.ks[ilayer]
self.linears.append(nn.Linear(in_chans_ilayer * kD * kH * kW, embed_ilayer, bias=False))
self.norms.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
self.acts.append(nn.GELU())
else:
modulelist = []
for ilayer in range(self.nconv):
in_chans_ilayer = in_chans if ilayer==0 else embed_dim//4
embed_ilayer = embed_dim if ilayer==self.nconv-1 else embed_dim//4
ks_ilayer = self.ks[ilayer]
modulelist.append(nn.Conv3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
modulelist.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
modulelist.append(nn.GELU())
self.in_proj = torch.nn.Sequential(*modulelist)

def forward(self, x):
x = self.in_proj(x)
return x
if self.use_linear:
for ilayer in range(self.nconv):
TB = x.shape[0]
kD, kH, kW = self.ks[ilayer]
D, H, W = x.shape[2], x.shape[3], x.shape[4]
x = rearrange(x, 'tb cin (nd kd) (nh kh) (nw kw) -> (tb nd nh nw) (cin kd kh kw)',
kd=kD, kh=kH, kw=kW)
x = self.linears[ilayer](x)
x = rearrange(x, '(tb nd nh nw) cout -> tb cout nd nh nw',
tb=TB, nd=D//kD, nh=H//kH, nw=W//kW)
x = self.norms[ilayer](x)
x = self.acts[ilayer](x)
return x
else:
x = self.in_proj(x)
return x
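When kernel_size equals stride, the rearrange-plus-Linear path above computes exactly what the strided Conv3d path computes: the patches are disjoint, and the Linear weight is just the Conv3d weight flattened over (cin, kd, kh, kw). A self-contained equivalence check, with all sizes below assumed for illustration:

import torch
import torch.nn as nn
from einops import rearrange

C, E = 4, 16              # in_chans, embed_dim (assumed)
kD, kH, kW = 1, 4, 4      # one kernel factor, as calc_ks4conv might return (assumed)
x = torch.randn(2, C, 2, 8, 8)  # (batch, C, D, H, W)

conv = nn.Conv3d(C, E, kernel_size=(kD, kH, kW), stride=(kD, kH, kW), bias=False)
lin = nn.Linear(C * kD * kH * kW, E, bias=False)
with torch.no_grad():
    lin.weight.copy_(conv.weight.reshape(E, -1))  # share the same weights

y_conv = conv(x)
patches = rearrange(x, 'b c (nd kd) (nh kh) (nw kw) -> (b nd nh nw) (c kd kh kw)',
                    kd=kD, kh=kH, kw=kW)
y_lin = rearrange(lin(patches), '(b nd nh nw) e -> b e nd nh nw',
                  b=2, nd=2 // kD, nh=8 // kH, nw=8 // kW)
print(torch.allclose(y_conv, y_lin, atol=1e-6))  # True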

class hMLP_output(nn.Module):
""" Patch to Image De-bedding
"""
def __init__(self, patch_size=(1,16,16), out_chans=3, embed_dim=768, nconv=3, notransposed=False, smooth=False):
def __init__(self, patch_size=(1,16,16), out_chans=3, embed_dim=768, nconv=3, notransposed=False, smooth=False, use_linear=False):
#patch_size: (ps_z, ps_x, ps_y)
super().__init__()
self.patch_size = patch_size
@@ -247,51 +272,78 @@ def __init__(self, patch_size=(1,16,16), out_chans=3, embed_dim=768, nconv=3, no
self.ks = calc_ks4conv(patch_size=self.patch_size, nconv=self.nconv)
self.notransposed = notransposed
self.smooth = smooth

modulelist = []
for ilayer in range(self.nconv-1):
in_chans_ilayer = embed_dim if ilayer==0 else embed_dim//4
embed_ilayer = embed_dim//4
ks_ilayer = self.ks[-(ilayer+1)]
if self.notransposed:
modulelist.append(UpsampleConv3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, bias=False))
else:
modulelist.append(nn.ConvTranspose3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
modulelist.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
modulelist.append(nn.GELU())
self.out_proj = torch.nn.Sequential(*modulelist)
if self.notransposed:
out_head = UpsampleConv3d(embed_dim//4, out_chans, kernel_size=self.ks[0])
self.out_head = out_head
else:
self.out_head = nn.ConvTranspose3d(embed_dim//4, out_chans, kernel_size=self.ks[0], stride=self.ks[0])
self.use_linear = use_linear and not notransposed # linear only applies to ConvTranspose3d path

if self.use_linear:
self.linears = nn.ModuleList()
self.norms = nn.ModuleList()
self.acts = nn.ModuleList()
for ilayer in range(self.nconv-1):
in_chans_ilayer = embed_dim if ilayer==0 else embed_dim//4
embed_ilayer = embed_dim//4
kD, kH, kW = self.ks[-(ilayer+1)]
self.linears.append(nn.Linear(in_chans_ilayer, embed_ilayer * kD * kH * kW, bias=False))
self.norms.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
self.acts.append(nn.GELU())
# Final head
kD, kH, kW = self.ks[0]
self.out_head = nn.Linear(embed_dim//4, out_chans * kD * kH * kW)
self.out_head_ks = self.ks[0]
if self.smooth:
self.smooth = nn.Conv3d(out_chans, out_chans, kernel_size=self.ks[0], stride=1, groups=out_chans, padding="same", padding_mode="reflect")
"""
#previous implementation
out_head = nn.ConvTranspose3d(embed_dim//4, out_chans, kernel_size=self.ks[0], stride=self.ks[0])
self.out_stride = self.ks[0]
self.out_kernel = nn.Parameter(out_head.weight)
self.out_bias = nn.Parameter(out_head.bias)
"""
else:
modulelist = []
for ilayer in range(self.nconv-1):
in_chans_ilayer = embed_dim if ilayer==0 else embed_dim//4
embed_ilayer = embed_dim//4
ks_ilayer = self.ks[-(ilayer+1)]
if self.notransposed:
modulelist.append(UpsampleConv3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, bias=False))
else:
modulelist.append(nn.ConvTranspose3d(in_chans_ilayer, embed_ilayer, kernel_size=ks_ilayer, stride=ks_ilayer, bias=False))
modulelist.append(nn.InstanceNorm3d(embed_ilayer, affine=True))
modulelist.append(nn.GELU())
self.out_proj = torch.nn.Sequential(*modulelist)
if self.notransposed:
out_head = UpsampleConv3d(embed_dim//4, out_chans, kernel_size=self.ks[0])
self.out_head = out_head
else:
self.out_head = nn.ConvTranspose3d(embed_dim//4, out_chans, kernel_size=self.ks[0], stride=self.ks[0])
if self.smooth:
self.smooth = nn.Conv3d(out_chans, out_chans, kernel_size=self.ks[0], stride=1, groups=out_chans, padding="same", padding_mode="reflect")

def forward(self, x):
#B,C,D,H,W
x = self.out_proj(x)#.flatten(2).transpose(1, 2)
if self.notransposed:
#x = self.out_upsample(x)
x = self.out_head(x)
#x = F.conv3d(x, self.out_kernel[state_labels, :], self.out_bias[state_labels], stride=self.out_stride)
#x = x[:,state_labels,...]
else:
"""
x = F.conv_transpose3d(x, self.out_kernel[:, state_labels], self.out_bias[state_labels], stride=self.out_stride)
"""
if self.use_linear:
for ilayer in range(self.nconv-1):
TB, _, D, H, W = x.shape
kD, kH, kW = self.ks[-(ilayer+1)]
cout = self.linears[ilayer].out_features // (kD * kH * kW)
x = rearrange(x, 'tb cin d h w -> (tb d h w) cin')
x = self.linears[ilayer](x)
x = rearrange(x, '(tb d h w) (cout kd kh kw) -> tb cout (d kd) (h kh) (w kw)',
tb=TB, d=D, h=H, w=W, cout=cout, kd=kD, kh=kH, kw=kW)
x = self.norms[ilayer](x)
x = self.acts[ilayer](x)
# Final head
TB, _, D, H, W = x.shape
kD, kH, kW = self.out_head_ks
x = rearrange(x, 'tb cin d h w -> (tb d h w) cin')
x = self.out_head(x)
x = rearrange(x, '(tb d h w) (cout kd kh kw) -> tb cout (d kd) (h kh) (w kw)',
tb=TB, d=D, h=H, w=W, cout=self.out_chans, kd=kD, kh=kH, kw=kW)
if self.smooth:
x = self.smooth(x)
#x = x[:,state_labels,...]
return x
return x
else:
x = self.out_proj(x)
if self.notransposed:
x = self.out_head(x)
else:
x = self.out_head(x)
if self.smooth:
x = self.smooth(x)
return x
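The de-bedding direction mirrors this: with kernel_size equal to stride, ConvTranspose3d writes each token into its own disjoint kD x kH x kW block, which is exactly what the Linear followed by the '(cout kd kh kw)' rearrange produces. A standalone check under assumed sizes:

import torch
import torch.nn as nn
from einops import rearrange

E, O = 16, 3          # embed_dim//4, out_chans (assumed)
kD, kH, kW = 1, 4, 4
x = torch.randn(2, E, 2, 4, 4)

convT = nn.ConvTranspose3d(E, O, kernel_size=(kD, kH, kW), stride=(kD, kH, kW))
lin = nn.Linear(E, O * kD * kH * kW)
with torch.no_grad():
    # ConvTranspose3d weight is (in, out, kD, kH, kW); reorder it to the
    # (cout, kd, kh, kw) flattening that the rearrange expects
    lin.weight.copy_(convT.weight.permute(1, 2, 3, 4, 0).reshape(O * kD * kH * kW, E))
    lin.bias.copy_(convT.bias.repeat_interleave(kD * kH * kW))

y_ref = convT(x)
tokens = rearrange(x, 'b e d h w -> (b d h w) e')
y_lin = rearrange(lin(tokens), '(b d h w) (o kd kh kw) -> b o (d kd) (h kh) (w kw)',
                  b=2, d=2, h=4, w=4, o=O, kd=kD, kh=kH, kw=kW)
print(torch.allclose(y_ref, y_lin, atol=1e-5))  # True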

class GraphhMLP_stem(nn.Module):
"""graph to patch embedding"""
9 changes: 5 additions & 4 deletions matey/models/vit.py
@@ -35,7 +35,8 @@ def build_vit(params):
n_steps=params.n_steps,
bias_type=params.bias_type,
replace_patch=getattr(params, 'replace_patch', True),
hierarchical=getattr(params, 'hierarchical', None)
hierarchical=getattr(params, 'hierarchical', None),
use_linear=getattr(params, 'use_linear', False),
)
return model

@@ -52,9 +53,9 @@ class ViT_all2all(BaseModel):
sts_f
"""
def __init__(self, tokenizer_heads=None, embed_dim=768, num_heads=12, processor_blocks=8, n_states=6, n_states_cond=None,
drop_path=.2, sts_train=False, sts_model=False, leadtime=False, cond_input=False, n_steps=1, bias_type="none", replace_patch=True, SR_ratio=[1,1,1], hierarchical=None):
super().__init__(tokenizer_heads=tokenizer_heads, n_states=n_states, n_states_cond=n_states_cond, embed_dim=embed_dim, leadtime=leadtime,
cond_input=cond_input, n_steps=n_steps, bias_type=bias_type,SR_ratio=SR_ratio, hierarchical=hierarchical)
drop_path=.2, sts_train=False, sts_model=False, leadtime=False, cond_input=False, n_steps=1, bias_type="none", replace_patch=True, SR_ratio=[1,1,1], hierarchical=None, use_linear=False):
super().__init__(tokenizer_heads=tokenizer_heads, n_states=n_states, n_states_cond=n_states_cond, embed_dim=embed_dim, leadtime=leadtime,
cond_input=cond_input, n_steps=n_steps, bias_type=bias_type, SR_ratio=SR_ratio, hierarchical=hierarchical, use_linear=use_linear)
self.drop_path = drop_path
self.dp = np.linspace(0, drop_path, processor_blocks)
self.blocks = nn.ModuleList([SpaceTimeBlock_all2all(embed_dim, num_heads,drop_path=self.dp[i])
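Because build_vit reads the flag with getattr(params, 'use_linear', False), any attribute-style config object works, and configs that predate this change keep the convolutional tokenizer. A hypothetical usage sketch:

from types import SimpleNamespace

params = SimpleNamespace(use_linear=True)    # plus the usual build_vit fields, elided here
print(getattr(params, 'use_linear', False))  # True: Linear+rearrange tokenizer

legacy = SimpleNamespace()                   # flag absent
print(getattr(legacy, 'use_linear', False))  # False: conv path, unchanged behavior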