diff --git a/im2sim/layers/layer_util.py b/im2sim/layers/layer_util.py
index 4e08c8f..e2d3713 100644
--- a/im2sim/layers/layer_util.py
+++ b/im2sim/layers/layer_util.py
@@ -1,6 +1,7 @@
 from torch import nn
 import torch.nn.functional as F
 import torch_geometric.nn as gnn
+import math
 
 
 def get_image_layer(name, rank):
@@ -155,5 +156,448 @@ def standardize_spatial_factors(factors, rank):
         raise TypeError(
             f"Each factor must be an int, tuple, or list, got {type(f).__name__}"
         )
-    return standardized
\ No newline at end of file
+    return standardized
+
+
+def _expand_to_4d(param, name="parameter"):
+    """
+    Convert an int or 4-tuple/list into a 4-tuple: (T, D, H, W).
+    """
+    if isinstance(param, int):
+        return (param, param, param, param)
+    if isinstance(param, (tuple, list)):
+        if len(param) != 4:
+            raise ValueError(f"{name} must have 4 elements (T, D, H, W), got {param}")
+        return tuple(param)
+    raise TypeError(f"{name} must be an int or a tuple/list of length 4")
+
+
+def _same_padding_4d(kernel_size):
+    """
+    Compute 'same-like' padding for odd kernel sizes.
+    """
+    k = _expand_to_4d(kernel_size, "kernel_size")
+    return tuple(kk // 2 for kk in k)
+
+
+def _get_spatial_op(rank, op2d, op3d, name="spatial op"):
+    if rank == 2:
+        return op2d
+    if rank == 3:
+        return op3d
+    raise ValueError(f"{name} only supports rank=2 or rank=3, got {rank}")
+
+
+def _default_spatial_mode(rank):
+    return "bilinear" if rank == 2 else "trilinear"
+
+
+def _same_padding_time(kernel_size, rank):
+    """
+    Return same-style padding for (T + spatial) kernel sizes.
+
+    rank=2 -> expects int or (T, H, W)
+    rank=3 -> expects int or (T, D, H, W)
+    """
+    if isinstance(kernel_size, int):
+        k = (kernel_size,) * (rank + 1)
+    elif isinstance(kernel_size, (tuple, list)):
+        if len(kernel_size) != rank + 1:
+            raise ValueError(
+                f"kernel_size must have length {rank + 1} for rank={rank}, got {kernel_size}"
+            )
+        k = tuple(kernel_size)
+    else:
+        raise TypeError("kernel_size must be an int or tuple/list")
+
+    return tuple(kk // 2 for kk in k)
+
+
+class TimeDistributed(nn.Module):
+    """
+    Apply a module independently to each time step.
+
+    Expected input:
+        2D+1: (N, C, T, H, W)
+        3D+1: (N, C, T, D, H, W)
+
+    Wrapped module should accept:
+        2D+1: (N, C, H, W)
+        3D+1: (N, C, D, H, W)
+    """
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, x):
+        if x.ndim not in (5, 6):
+            raise ValueError(
+                f"TimeDistributed expects 5D or 6D input [N, C, T, ...], got shape {tuple(x.shape)}"
+            )
+
+        n, c, t = x.shape[:3]
+        spatial = x.shape[3:]
+
+        # (N, C, T, *S) -> (N, T, C, *S)
+        permute_order = [0, 2, 1] + list(range(3, x.ndim))
+        x = x.permute(*permute_order).contiguous()
+
+        # (N, T, C, *S) -> (N*T, C, *S)
+        x = x.reshape(n * t, c, *spatial)
+
+        y = self.module(x)  # (N*T, C_out, *S_out)
+
+        c_out = y.shape[1]
+        spatial_out = y.shape[2:]
+
+        # (N*T, C_out, *S_out) -> (N, T, C_out, *S_out)
+        y = y.reshape(n, t, c_out, *spatial_out)
+
+        # (N, T, C_out, *S_out) -> (N, C_out, T, *S_out)
+        permute_back = [0, 2, 1] + list(range(3, y.ndim))
+        y = y.permute(*permute_back).contiguous()
+
+        return y
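+
+# Example (illustrative only; assumes `import torch` at the call site):
+# wrapping a 2D conv so it runs per frame on input of shape [N, C, T, H, W].
+#   conv = TimeDistributed(nn.Conv2d(3, 8, kernel_size=3, padding=1))
+#   y = conv(torch.randn(2, 3, 5, 32, 32))  # -> [2, 8, 5, 32, 32]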
+
+
+class SpaceDistributed(nn.Module):
+    """
+    Apply a module independently at each spatial location over time.
+
+    Expected input:
+        2D+1: (N, C, T, H, W)
+        3D+1: (N, C, T, D, H, W)
+
+    Wrapped module should accept:
+        (N_flat, C, T)
+    """
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, x):
+        if x.ndim not in (5, 6):
+            raise ValueError(
+                f"SpaceDistributed expects 5D or 6D input [N, C, T, ...], got shape {tuple(x.shape)}"
+            )
+
+        n, c, t = x.shape[:3]
+        spatial = x.shape[3:]
+        spatial_rank = len(spatial)
+        spatial_prod = math.prod(spatial)
+
+        # (N, C, T, *S) -> (N, *S, C, T)
+        permute_order = [0] + list(range(3, x.ndim)) + [1, 2]
+        x = x.permute(*permute_order).contiguous()
+
+        # (N, *S, C, T) -> (N*prod(S), C, T)
+        x = x.reshape(n * spatial_prod, c, t)
+
+        y = self.module(x)  # (N*prod(S), C_out, T_out)
+
+        c_out = y.shape[1]
+        t_out = y.shape[2]
+
+        # (N*prod(S), C_out, T_out) -> (N, *S, C_out, T_out)
+        y = y.reshape(n, *spatial, c_out, t_out)
+
+        # (N, *S, C_out, T_out) -> (N, C_out, T_out, *S)
+        permute_back = [0, spatial_rank + 1, spatial_rank + 2] + list(range(1, spatial_rank + 1))
+        y = y.permute(*permute_back).contiguous()
+
+        return y
+
+
+class ConvTime(nn.Module):
+    """
+    Separable (rank)+1D convolution: a spatial ConvNd per time step,
+    followed by a temporal Conv1d per spatial location.
+    """
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+        padding_mode="zeros",
+        device=None,
+        dtype=None,
+        rank=3,
+    ):
+        super().__init__()
+        self.rank = rank
+
+        self.time_kernel, self.space_kernel = self._split_param(kernel_size, "kernel_size")
+        self.time_stride, self.space_stride = self._split_param(stride, "stride")
+        self.time_padding, self.space_padding = self._split_param(padding, "padding")
+        self.time_dilation, self.space_dilation = self._split_param(dilation, "dilation")
+
+        spatial_conv = _get_spatial_op(rank, nn.Conv2d, nn.Conv3d, "ConvTime spatial conv")
+
+        self.conv_spatial = TimeDistributed(
+            spatial_conv(
+                in_channels,
+                out_channels,
+                kernel_size=self.space_kernel,
+                stride=self.space_stride,
+                padding=self.space_padding,
+                dilation=self.space_dilation,
+                groups=groups,
+                bias=bias,
+                padding_mode=padding_mode,
+                device=device,
+                dtype=dtype,
+            )
+        )
+
+        # Note: `groups` applies to the spatial conv only; the temporal conv
+        # always mixes channels with groups=1.
+        self.conv_time = SpaceDistributed(
+            nn.Conv1d(
+                out_channels,
+                out_channels,
+                kernel_size=self.time_kernel,
+                stride=self.time_stride,
+                padding=self.time_padding,
+                dilation=self.time_dilation,
+                groups=1,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+        )
+
+    def _split_param(self, param, name="parameter"):
+        if isinstance(param, (tuple, list)):
+            if len(param) != self.rank + 1:
+                raise ValueError(f"{name} must have {self.rank + 1} elements (T + spatial dims)")
+            return param[0], tuple(param[1:])
+        return param, (param,) * self.rank
+
+    def forward(self, x):
+        x = self.conv_spatial(x)
+        x = self.conv_time(x)
+        return x
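+
+# Example (illustrative): a separable space+time convolution on a rank=2
+# clip of shape [N, C, T, H, W]; padding=1 keeps every dimension.
+#   conv = ConvTime(3, 16, kernel_size=3, stride=1, padding=1, rank=2)
+#   y = conv(torch.randn(2, 3, 8, 32, 32))  # -> [2, 16, 8, 32, 32]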
+
+
+class UpsampleTime(nn.Module):
+    """
+    Separable (rank)+1D upsampling:
+        - spatial upsampling per time step
+        - temporal upsampling per spatial location
+
+    rank=2:
+        input  (B, C, T, H, W)
+        output (B, C, T_out, H_out, W_out)
+
+    rank=3:
+        input  (B, C, T, D, H, W)
+        output (B, C, T_out, D_out, H_out, W_out)
+    """
+    def __init__(self, scale_factor=None, size=None, mode=None, align_corners=False, rank=3):
+        super().__init__()
+        self.rank = rank
+
+        if (scale_factor is None) == (size is None):
+            raise ValueError("Provide exactly one of scale_factor or size")
+
+        if mode is None:
+            mode = _default_spatial_mode(rank)
+        elif rank == 2 and mode == "trilinear":
+            mode = "bilinear"
+
+        self.use_scale_factor = scale_factor is not None
+
+        if self.use_scale_factor:
+            self.time_value, self.space_value = self._split_param(scale_factor, "scale_factor")
+        else:
+            self.time_value, self.space_value = self._split_param(size, "size")
+
+        # Both configurations build the same pair of modules; only the
+        # keyword (scale_factor vs. size) differs.
+        kw = "scale_factor" if self.use_scale_factor else "size"
+
+        self.up_spatial = TimeDistributed(
+            nn.Upsample(
+                **{kw: self.space_value},
+                mode=mode,
+                align_corners=align_corners if "linear" in mode else None,
+            )
+        )
+        self.up_time = SpaceDistributed(
+            nn.Upsample(
+                **{kw: self.time_value},
+                mode="linear",
+                align_corners=align_corners,
+            )
+        )
+
+    def _split_param(self, param, name):
+        if isinstance(param, (tuple, list)):
+            if len(param) != self.rank + 1:
+                raise ValueError(f"{name} must have {self.rank + 1} elements (T + spatial dims)")
+            return param[0], tuple(param[1:])
+        return param, (param,) * self.rank
+
+    def forward(self, x):
+        x = self.up_spatial(x)
+        x = self.up_time(x)
+        return x
+
+
+class ConvTransTime(nn.Module):
+    """
+    Separable (rank)+1D transposed convolution:
+        1) spatial ConvTransposeNd applied independently per time step
+        2) temporal ConvTranspose1d applied independently per spatial location
+    """
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        output_padding=0,
+        bias=True,
+        dilation=1,
+        rank=3,
+    ):
+        super().__init__()
+        self.rank = rank
+
+        self.time_kernel, self.space_kernel = self._split_param(kernel_size, "kernel_size")
+        self.time_stride, self.space_stride = self._split_param(stride, "stride")
+        self.time_padding, self.space_padding = self._split_param(padding, "padding")
+        self.time_outpad, self.space_outpad = self._split_param(output_padding, "output_padding")
+        self.time_dilation, self.space_dilation = self._split_param(dilation, "dilation")
+
+        spatial_conv_trans = _get_spatial_op(
+            rank, nn.ConvTranspose2d, nn.ConvTranspose3d, "ConvTransTime spatial conv"
+        )
+
+        self.up_spatial = TimeDistributed(
+            spatial_conv_trans(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=self.space_kernel,
+                stride=self.space_stride,
+                padding=self.space_padding,
+                output_padding=self.space_outpad,
+                dilation=self.space_dilation,
+                bias=bias,
+            )
+        )
+
+        self.up_time = SpaceDistributed(
+            nn.ConvTranspose1d(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=self.time_kernel,
+                stride=self.time_stride,
+                padding=self.time_padding,
+                output_padding=self.time_outpad,
+                dilation=self.time_dilation,
+                bias=bias,
+            )
+        )
+
+    def _split_param(self, param, name="parameter"):
+        if isinstance(param, (tuple, list)):
+            if len(param) != self.rank + 1:
+                raise ValueError(f"{name} must have {self.rank + 1} elements (T + spatial dims)")
+            return param[0], tuple(param[1:])
+        return param, (param,) * self.rank
+
+    def forward(self, x):
+        x = self.up_spatial(x)
+        x = self.up_time(x)
+        return x
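+
+# Example (illustrative): double H and W while keeping T, with rank=2.
+#   up = ConvTransTime(16, 8, kernel_size=(1, 2, 2), stride=(1, 2, 2), rank=2)
+#   y = up(torch.randn(2, 16, 4, 8, 8))  # -> [2, 8, 4, 16, 16]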
+
+
+class MaxPoolTime(nn.Module):
+    """
+    Separable (rank)+1D max pooling: spatial pooling per time step,
+    then temporal pooling per spatial location.
+    """
+    def __init__(self, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, rank=3):
+        super().__init__()
+        self.rank = rank
+
+        # Mirror nn.MaxPoolNd: stride defaults to kernel_size.
+        stride = kernel_size if stride is None else stride
+
+        self.time_kernel, self.space_kernel = self._split_param(kernel_size, "kernel_size")
+        self.time_stride, self.space_stride = self._split_param(stride, "stride")
+        self.time_padding, self.space_padding = self._split_param(padding, "padding")
+        self.time_dilation, self.space_dilation = self._split_param(dilation, "dilation")
+
+        spatial_pool = _get_spatial_op(rank, nn.MaxPool2d, nn.MaxPool3d, "MaxPoolTime spatial pool")
+
+        self.pool_spatial = TimeDistributed(
+            spatial_pool(
+                kernel_size=self.space_kernel,
+                stride=self.space_stride,
+                padding=self.space_padding,
+                dilation=self.space_dilation,
+                ceil_mode=ceil_mode,
+            )
+        )
+
+        self.pool_time = SpaceDistributed(
+            nn.MaxPool1d(
+                kernel_size=self.time_kernel,
+                stride=self.time_stride,
+                padding=self.time_padding,
+                dilation=self.time_dilation,
+                ceil_mode=ceil_mode,
+            )
+        )
+
+    def _split_param(self, param, name="parameter"):
+        if isinstance(param, (tuple, list)):
+            if len(param) != self.rank + 1:
+                raise ValueError(f"{name} must have {self.rank + 1} elements (T + spatial dims)")
+            return param[0], tuple(param[1:])
+        return param, (param,) * self.rank
+
+    def forward(self, x):
+        x = self.pool_spatial(x)
+        x = self.pool_time(x)
+        return x
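+
+# Example (illustrative): spatial-only pooling with rank=2; T is untouched
+# because the temporal kernel is 1.
+#   pool = MaxPoolTime(kernel_size=(1, 2, 2), rank=2)
+#   y = pool(torch.randn(2, 8, 4, 32, 32))  # -> [2, 8, 4, 16, 16]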
+
+
+class BatchNormTime(nn.Module):
+    """
+    Batch norm over (rank)D space per time step, followed by batch norm
+    over time per spatial location.
+    """
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, rank=3):
+        super().__init__()
+
+        spatial_bn = _get_spatial_op(rank, nn.BatchNorm2d, nn.BatchNorm3d, "BatchNormTime spatial norm")
+
+        self.bn_spatial = TimeDistributed(
+            spatial_bn(
+                num_features,
+                eps=eps,
+                momentum=momentum,
+                affine=affine,
+                track_running_stats=track_running_stats,
+            )
+        )
+
+        self.bn_time = SpaceDistributed(
+            nn.BatchNorm1d(
+                num_features,
+                eps=eps,
+                momentum=momentum,
+                affine=affine,
+                track_running_stats=track_running_stats,
+            )
+        )
+
+    def forward(self, x):
+        x = self.bn_spatial(x)
+        x = self.bn_time(x)
+        return x
\ No newline at end of file
diff --git a/im2sim/models/UNet.py b/im2sim/models/UNet.py
index 4bb7af9..ceef47f 100644
--- a/im2sim/models/UNet.py
+++ b/im2sim/models/UNet.py
@@ -3,7 +3,6 @@
 import torch.nn.functional as F
 from ..layers.layer_util import get_image_layer, get_activation, standardize_spatial_factors
 
-
 class ImageConvBlock(nn.Module):
     """
     A convolutional block for image data
diff --git a/im2sim/models/UNet_time.py b/im2sim/models/UNet_time.py
new file mode 100644
index 0000000..0efd14b
--- /dev/null
+++ b/im2sim/models/UNet_time.py
@@ -0,0 +1,435 @@
+import torch
+from torch import nn
+from ..layers.layer_util import get_activation, standardize_spatial_factors, _same_padding_time
+from ..layers.layer_util import MaxPoolTime, UpsampleTime, ConvTime, BatchNormTime, ConvTransTime
+
+
+class ImageConvBlockTime(nn.Module):
+    """
+    A convolutional block for image data with time.
+
+    Expected input:
+        rank=2: [B, C, T, H, W]
+        rank=3: [B, C, T, D, H, W]
+    """
+    def __init__(self,
+                 in_channels,
+                 filters=32,
+                 kernel_size=3,
+                 depth=1,
+                 rank=3,
+                 activation='relu',
+                 norm_type=None,
+                 dropout_rate=None):
+        super().__init__()
+
+        conv = ConvTime
+        padding = _same_padding_time(kernel_size, rank=rank)
+
+        self.convs = nn.ModuleList([
+            conv(in_channels if i == 0 else filters, filters, kernel_size, padding=padding, rank=rank)
+            for i in range(depth)
+        ])
+
+        if norm_type is None:
+            self.norms = nn.ModuleList([nn.Identity() for _ in range(depth)])
+        elif norm_type == "BatchNorm":
+            self.norms = nn.ModuleList([BatchNormTime(filters, rank=rank) for _ in range(depth)])
+        else:
+            raise ValueError(
+                f"ImageConvBlockTime currently only supports norm_type=None or 'BatchNorm', got {norm_type}"
+            )
+
+        self.drop = nn.Dropout(p=dropout_rate) if dropout_rate else nn.Identity()
+
+        act_cls = get_activation(activation)
+        self.act = act_cls(inplace=True) if activation.lower() == 'relu' else act_cls()
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): Input feature maps [B, in_channels, T, ...] where
+                the number of dims in ... corresponds to rank
+
+        Returns:
+            torch.Tensor: Output feature maps [B, filters, T, ...]
+        """
+        for conv, norm in zip(self.convs, self.norms):
+            x = norm(self.act(conv(x)))
+        return self.drop(x)
+
+
+class ImageEncoderTime(nn.Module):
+    """
+    A CNN encoder for (rank)+1D image-time data.
+
+    Expected input:
+        rank=2: x: [B, C, T, H, W]
+        rank=3: x: [B, C, T, D, H, W]
+
+    pool_sizes entries should be (rank+1)-D factors:
+        rank=2: (T, H, W)
+        rank=3: (T, D, H, W)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 filters=[16, 32, 64, 128, 256],
+                 pool_sizes=None,
+                 kernel_size=3,
+                 conv_blocks_per_level=1,
+                 rank=3,
+                 norm_type=None,
+                 activation='relu',
+                 dropout_rate=None):
+        super().__init__()
+
+        if rank not in (2, 3):
+            raise ValueError(f"ImageEncoderTime currently only supports rank=2 or rank=3, got {rank}")
+
+        n_levels = len(filters)
+
+        # Default: no temporal pooling, halve each spatial dim per level.
+        if pool_sizes is None:
+            pool_sizes = [tuple([1] + [2] * rank) for _ in range(n_levels - 1)]
+
+        if not isinstance(pool_sizes, (list, tuple)):
+            raise TypeError("pool_sizes must be a list or tuple")
+
+        if len(pool_sizes) != n_levels - 1:
+            raise ValueError(
+                f"pool_sizes must have length {n_levels - 1} for filters of length {n_levels}, "
+                f"got {len(pool_sizes)}"
+            )
+
+        pool_sizes_standard = standardize_spatial_factors(pool_sizes, rank + 1)
+
+        self.conv_blocks = nn.ModuleList([
+            ImageConvBlockTime(in_channels=in_channels if i == 0 else filters[i - 1],
+                               filters=filters[i],
+                               kernel_size=kernel_size,
+                               depth=conv_blocks_per_level,
+                               rank=rank,
+                               activation=activation,
+                               norm_type=norm_type,
+                               dropout_rate=dropout_rate)
+            for i in range(n_levels)
+        ])
+
+        self.maxpools = nn.ModuleList([
+            MaxPoolTime(pool_sizes_standard[i - 1], rank=rank) if i > 0 else nn.Identity()
+            for i in range(n_levels)
+        ])
+
+    def forward(self, x):
+        outputs = []
+        for pool, conv in zip(self.maxpools, self.conv_blocks):
+            x = conv(pool(x))
+            outputs.append(x)
+        return outputs
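+
+# Example (illustrative): two-level encoder on [B, C, T, H, W] with rank=2.
+# The default pool_sizes leave T alone and halve H and W per level.
+#   enc = ImageEncoderTime(in_channels=1, filters=[8, 16], rank=2)
+#   f0, f1 = enc(torch.randn(2, 1, 4, 32, 32))
+#   # f0: [2, 8, 4, 32, 32], f1: [2, 16, 4, 16, 16]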
+ """ + + def __init__(self, + filters=[16,32,64,128,256], + kernel_size=3, + pool_sizes=None, + upsample_sizes = None, + conv_blocks_per_level=1, + rank=3, + upsample_type="ConvTranspose", + activation="relu", + norm_type=None, + dropout_rate=None, + skip=True): + super().__init__() + + self.skip = skip + self.rank = rank + + if rank not in (2, 3): + raise ValueError(f"UNetTime currently only supports rank=2 or rank=3, got {rank}") + + n_levels = len(filters) + + if pool_sizes is None: + pool_sizes = [tuple([1] + [2] * rank) for _ in range(n_levels - 1)] + + if not isinstance(pool_sizes, (list, tuple)): + raise TypeError("pool_sizes must be a list or tuple") + + if len(pool_sizes) != n_levels - 1: + raise ValueError( + f"pool_sizes must have length {n_levels - 1} for filters of length {n_levels}, " + f"got {len(pool_sizes)}" + ) + + pool_sizes_standard = standardize_spatial_factors(pool_sizes, rank+1) + + rev_filters = filters[::-1] + rev_pool_sizes = pool_sizes_standard[::-1] + + if upsample_sizes is None: + upsample_sizes = rev_pool_sizes + else: + if not isinstance(upsample_sizes, (list, tuple)): + raise TypeError("upsample_sizes must be a list or tuple") + + if len(upsample_sizes) != n_levels - 1: + raise ValueError( + f"upsample_sizes must have length {n_levels - 1} for filters of length {n_levels}, " + f"got {len(upsample_sizes)}" + ) + + upsample_sizes = standardize_spatial_factors(upsample_sizes, rank + 1) + + if upsample_type.lower() == 'upsample': + self.ups = nn.ModuleList([ + UpsampleTime(scale_factor=upsample_sizes[i], mode='trilinear' if rank==3 else 'bilinear', align_corners=True, rank=rank) + for i in range(n_levels - 1) + ]) + elif upsample_type.lower() == 'convtranspose': + self.ups = nn.ModuleList([ + ConvTransTime(rev_filters[i], rev_filters[i+1], kernel_size=upsample_sizes[i], stride=upsample_sizes[i],rank=rank) + for i in range(n_levels - 1) + ]) + else: + raise ValueError(f"upsample_type must be 'Upsample' or 'ConvTranspose', got {upsample_type}") + + if upsample_type.lower() == "upsample": + self.up_channel_adjust = nn.ModuleList([ + ConvTime( + in_channels=rev_filters[i], + out_channels=rev_filters[i + 1], + kernel_size=1, + padding=0, + rank=rank + ) + for i in range(n_levels - 1) + ]) + else: + self.up_channel_adjust = nn.ModuleList([ + nn.Identity() + for _ in range(n_levels - 1) + ]) + + + self.conv_blocks = nn.ModuleList([ + ImageConvBlockTime( + in_channels=rev_filters[i+1] * (2 if skip else 1), + filters=rev_filters[i+1], + kernel_size=kernel_size, + depth=conv_blocks_per_level, + rank=rank, + activation=activation, + norm_type=norm_type, + dropout_rate=dropout_rate + ) + for i in range(n_levels - 1) + ]) + + + def _match_size(self, x, skip): + """ + Match skip size to x size for 3+1D tensors [B, C, T, D, H, W]. 
+
+
+class UNetTime(nn.Module):
+    """
+    Flexible UNet for 2D+time or 3D+time image data.
+
+    Args:
+        in_channels (int): Input channels.
+        out_channels (int): Output channels.
+        filters (List[int]): Encoder filter sizes.
+        pool_sizes (tuple or list): Pool factors per level.
+        upsample_sizes (tuple or list): Upsample factors per level.
+        kernel_size (int): Conv kernel size.
+        conv_blocks_per_level (int): Depth per level.
+        rank (int): Spatial rank.
+        upsample_type (str): "ConvTranspose" or "Upsample".
+        activation (str): Activation function.
+        norm_type (str): Normalization type.
+        dropout_rate (float): Dropout.
+        final_activation (str): Output activation.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 filters=[16, 32, 64, 128, 256],
+                 pool_sizes=None,
+                 upsample_sizes=None,
+                 kernel_size=3,
+                 conv_blocks_per_level=1,
+                 rank=3,
+                 upsample_type="ConvTranspose",
+                 activation="relu",
+                 norm_type=None,
+                 dropout_rate=None,
+                 final_activation="linear"):
+
+        if rank not in (2, 3):
+            raise ValueError(f"UNetTime currently only supports rank=2 or rank=3, got {rank}")
+
+        if pool_sizes is None:
+            pool_sizes = [tuple([1] + [2] * rank) for _ in range(len(filters) - 1)]
+
+        if not isinstance(pool_sizes, (list, tuple)):
+            raise TypeError("pool_sizes must be a list or tuple")
+
+        for p in pool_sizes:
+            if isinstance(p, int) and not isinstance(p, bool):
+                if p < 1:
+                    raise ValueError("pool_sizes entries must be positive")
+            elif isinstance(p, (tuple, list)):
+                if len(p) != rank + 1:
+                    raise ValueError(f"pool_sizes tuples must have length rank + 1 ({rank + 1})")
+                for ps in p:
+                    if not isinstance(ps, int) or isinstance(ps, bool):
+                        raise TypeError("pool_sizes elements must be ints")
+                    if ps < 1:
+                        raise ValueError("pool_sizes entries must be positive")
+            else:
+                raise TypeError("each entry in pool_sizes must be either an int or a tuple or list")
+
+        if len(filters) != (len(pool_sizes) + 1):
+            raise ValueError(
+                f"pool_sizes do not match number of filters: for {len(filters)} filters, "
+                f"please provide {len(filters) - 1} pool sizes."
+            )
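+
+        # upsample_sizes is optional: when omitted, the decoder falls back to
+        # the reversed pool_sizes, so only caller-supplied values are checked.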
+        if upsample_sizes is not None:
+            if not isinstance(upsample_sizes, (list, tuple)):
+                raise TypeError("upsample_sizes must be a list or tuple")
+
+            for p in upsample_sizes:
+                if isinstance(p, int) and not isinstance(p, bool):
+                    if p < 1:
+                        raise ValueError("upsample_sizes entries must be positive")
+                elif isinstance(p, (tuple, list)):
+                    if len(p) != rank + 1:
+                        raise ValueError(f"upsample_sizes tuples must have length rank + 1 ({rank + 1})")
+                    for ps in p:
+                        if not isinstance(ps, int) or isinstance(ps, bool):
+                            raise TypeError("upsample_sizes elements must be ints")
+                        if ps < 1:
+                            raise ValueError("upsample_sizes entries must be positive")
+                else:
+                    raise TypeError("each entry in upsample_sizes must be either an int or a tuple or list")
+
+            if len(filters) != (len(upsample_sizes) + 1):
+                raise ValueError(
+                    f"upsample_sizes do not match number of filters: for {len(filters)} filters, "
+                    f"please provide {len(filters) - 1} upsample sizes."
+                )
+
+        super().__init__()
+
+        self.encoder = ImageEncoderTime(
+            in_channels=in_channels,
+            filters=filters,
+            pool_sizes=pool_sizes,
+            kernel_size=kernel_size,
+            conv_blocks_per_level=conv_blocks_per_level,
+            rank=rank,
+            activation=activation,
+            norm_type=norm_type,
+            dropout_rate=dropout_rate
+        )
+
+        self.decoder = ImageDecoderTime(
+            filters=filters,
+            pool_sizes=pool_sizes,
+            upsample_sizes=upsample_sizes,
+            kernel_size=kernel_size,
+            conv_blocks_per_level=conv_blocks_per_level,
+            rank=rank,
+            upsample_type=upsample_type,
+            activation=activation,
+            norm_type=norm_type,
+            dropout_rate=dropout_rate
+        )
+
+        self.final_conv = ConvTime(filters[0], out_channels, kernel_size=1, padding=0, rank=rank)
+
+        self.final_act = (
+            get_activation(final_activation)()
+            if final_activation.lower() != "linear"
+            else nn.Identity()
+        )
+
+    def forward(self, x):
+        enc_feats = self.encoder(x)
+        x = self.decoder(enc_feats)
+        x = self.final_conv(x)
+        x = self.final_act(x)
+        return x
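+
+
+# Example (illustrative): end-to-end shape check for a small rank=2 model.
+#   model = UNetTime(in_channels=1, out_channels=1, filters=[8, 16], rank=2)
+#   y = model(torch.randn(2, 1, 4, 32, 32))  # -> [2, 1, 4, 32, 32]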