diff --git a/tests/layers/__init__.py b/tests/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py new file mode 100644 index 0000000000..8f0af005a3 --- /dev/null +++ b/tests/layers/test_drop.py @@ -0,0 +1,144 @@ +import importlib +import os +import unittest + +import torch + +from timm.layers import drop + +torch_backend = os.environ.get("TORCH_BACKEND") +if torch_backend is not None: + importlib.import_module(torch_backend) +torch_device = os.environ.get("TORCH_DEVICE", "cpu") + + +class Conv2dKernelMidpointMask2d(unittest.TestCase): + def test_conv2d_kernel_midpoint_mask_odd(self): + mask = drop.conv2d_kernel_midpoint_mask( + shape=(5, 7), + kernel=(3, 3), + device=torch_device, + dtype=torch.float32, + ) + print(mask) + assert mask.device == torch.device(torch_device) + assert mask.tolist() == [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + + def test_conv2d_kernel_midpoint_mask_even(self): + mask = drop.conv2d_kernel_midpoint_mask( + shape=(5, 7), + kernel=(2, 2), + device=torch_device, + dtype=torch.float32, + ) + print(mask) + assert mask.device == torch.device(torch_device) + assert mask.tolist() == [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ] + + def test_clip_mask_2d_kernel_too_big(self): + try: + drop.conv2d_kernel_midpoint_mask( + shape=(4, 7), + kernel=(5, 5), + device=torch_device, + dtype=torch.float32, + ) + raise RuntimeError("Expected throw") + + except AssertionError as e: + assert "kernel=(5, 5) ! 
<= shape=(4, 7)" in e.args[0] + + +class DropBlock2dDropFilterTest(unittest.TestCase): + def test_drop_filter(self): + selection = ( + torch.tensor( + [ + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=torch_device, + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + result = drop.drop_block_2d_drop_filter_( + selection=selection, kernel=(2, 3), partial_edge_blocks=False + ).squeeze() + print(result) + assert result.device == torch.device(torch_device) + assert result.tolist() == [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + + def test_drop_filter_partial_edge_blocks(self): + selection = ( + torch.tensor( + [ + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=torch_device, + dtype=torch.float32, + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + result = drop.drop_block_2d_drop_filter_( + selection=selection, kernel=(2, 3), partial_edge_blocks=True + ).squeeze() + print(result) + assert result.device == torch.device(torch_device) + assert result.tolist() == [ + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + ] + + +class DropBlock2dTest(unittest.TestCase): + def test_drop_block_2d(self): + tensor = torch.ones((1, 1, 200, 300), device=torch_device) + + drop_prob = 0.1 + keep_prob = 1.0 - drop_prob + + result = drop.drop_block_2d( + tensor, + drop_prob=drop_prob, + with_noise=True, + ).squeeze() + + numel = float(result.numel()) + unchanged = float(len(result[result == 1.0])) + keep_ratio = unchanged / numel + + assert ( + abs(keep_ratio - keep_prob) < 0.05 + ), f"abs({keep_ratio=} - {keep_prob=}) ! < 0.05" diff --git a/tests/test_models.py b/tests/test_models.py index 028440179d..8ed2ca5231 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -204,7 +204,7 @@ def test_model_forward(model_name, batch_size): @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward(model_name, batch_size): - """Run a single forward pass with each model""" + """Run a single forward and backward pass with each model""" input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) if max(input_size) > MAX_BWD_SIZE: pytest.skip("Fixed input size model > limit.") @@ -594,7 +594,7 @@ def _create_fx_model(model, train=False): return fx_model -EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*'] +EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*', '*dropblock*'] # not enough memory to run fx on more models than other tests if 'GITHUB_ACTIONS' in os.environ: EXCLUDE_FX_FILTERS += [ diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 289245f5ad..e784edad99 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -1,4 +1,4 @@ -""" DropBlock, DropPath +"""DropBlock, DropPath PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 
@@ -14,119 +14,270 @@ Hacked together by / Copyright 2020 Ross Wightman """ + +from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F -from .grid import ndgrid + +def conv2d_kernel_midpoint_mask( + *, + kernel: Tuple[int, int], + inplace=None, + shape: Optional[Tuple[int, int]] = None, + device=None, + dtype=None, +): + """Build a mask of kernel midpoints. + + This predicts the kernel midpoints that conv2d (and related kernel functions) + would place a kernel. + + The *midpoint* of a kernel is computed as ``size / 2``: + * the midpoint of odd kernels is the middle: `mid(3) == 1` + * the midpoint of even kernels is the first point in the second half: `mid(4) == 2` + + Requires `kernel <= min(h, w)`. + + A new mask of `1`s is allocated, and then the `0` locations are cleared. + + Args: + kernel: the (kh, hw) shape of the kernel. + inplace: use the provided tensor as the mask; set masked-out values to 0. + shape: the (h, w) shape of the tensor. + device: the target device. + dtype: the target dtype. + + Returns: + a (h, w) bool mask tensor. + """ + if inplace is None: + assert shape is not None, f"shape is required when inplace is None." + assert dtype is not None, f"dtype is required when inplace is None." + assert device is not None, f"device is required when inplace is None." + + mask = torch.ones(shape, dtype=dtype, device=device) + else: + assert shape is None, f"shape and inplace are incompatile" + assert dtype is None, f"dtype and inplace are incompatile" + assert device is None, f"device and inplace are incompatile" + + mask = inplace + shape = inplace.shape[-2:] + device = inplace.device + dtype = inplace.dtype + + h, w = shape + kh, kw = kernel + assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}" + + mask[..., 0 : kh // 2, :] = 0 + mask[..., :, 0 : kw // 2 :] = 0 + mask[..., h - ((kh - 1) // 2) :, :] = 0 + mask[..., :, w - ((kw - 1) // 2) :] = 0 + + return mask + + +def drop_block_2d_drop_filter_( + *, + selection, + inplace: bool = False, + kernel: Tuple[int, int], + partial_edge_blocks: bool, +): + """Convert drop block gamma noise to a drop filter. + + This is a deterministic internal component of drop_block_2d. + + Args: + selection: 4D (B, C, H, W) input selection noise; + `1.0` at the midpoints of selected blocks to drop, + `0.0` everywhere else. Expected to be gamma noise. + inplace: permit in-place updates to `selection`. + kernel: the shape of the 2d kernel. + partial_edge_blocks: permit partial blocks at the edges, faster. + + Returns: + A drop filter, `1.0` at points to drop, `0.0` at points to keep. 
+ """ + if not partial_edge_blocks: + if inplace: + selection = conv2d_kernel_midpoint_mask( + kernel=kernel, + inplace=selection, + ) + else: + selection = selection * conv2d_kernel_midpoint_mask( + shape=selection.shape[-2:], + kernel=kernel, + dtype=selection.dtype, + device=selection.device, + ) + + kh, kw = kernel + + drop_filter = F.max_pool2d( + selection, + kernel_size=kernel, + stride=1, + padding=[kh // 2, kw // 2], + ) + if (kh % 2 == 0) or (kw % 2 == 0): + drop_filter = drop_filter[..., (kh % 2 == 0) :, (kw % 2 == 0) :] + + return drop_filter def drop_block_2d( - x, - drop_prob: float = 0.1, - block_size: int = 7, - gamma_scale: float = 1.0, - with_noise: bool = False, - inplace: bool = False, - batchwise: bool = False + x, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, + couple_channels: bool = False, + partial_edge_blocks: bool = False, ): - """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf DropBlock with an experimental gaussian noise option. This layer has been tested on a few training runs with success, but needs further validation and possibly optimization for lower runtime impact. + + Args: + drop_prob: the probability of dropping any given block. + block_size: the size of the dropped blocks; should be odd. + gamma_scale: adjustment scale for the drop_prob. + with_noise: should normal noise be added to the dropped region? + inplace: if the drop should be applied in-place on the input tensor. + batchwise: when true, the entire batch is shares the same drop mask; much faster. + couple_channels: when true, channels share the same drop mask; + much faster, with significant semantic impact. + partial_edge_blocks: partial-blocks at the edges; minor speedup, minor semantic impact. + + Returns: + If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device. """ B, C, H, W = x.shape - total_size = W * H - clipped_block_size = min(block_size, min(W, H)) - # seed_drop_rate, the gamma parameter - gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( - (W - block_size + 1) * (H - block_size + 1)) - - # Forces the block to be inside the feature map. - w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device)) - valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ - ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) - valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) - - if batchwise: - # one mask for whole batch, quite a bit faster - uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) - else: - uniform_noise = torch.rand_like(x) - block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) - block_mask = -F.max_pool2d( - -block_mask, - kernel_size=clipped_block_size, # block_size, - stride=1, - padding=clipped_block_size // 2) + + # TODO: This behaves oddly when clipped_block_size < block_size. + # We could expose non-square blocks above this layer. 
+    kernel = [min(block_size, H), min(block_size, W)]
+    kh, kw = kernel
+
+    # batchwise => one mask for whole batch, quite a bit faster
+    noise_shape = (1 if batchwise else B, 1 if couple_channels else C, H, W)
+
+    gamma = (
+        float(gamma_scale * drop_prob * H * W)
+        / float(kh * kw)
+        / float((H - kh + 1) * (W - kw + 1))
+    )
+
+    with torch.no_grad():
+        drop_filter = drop_block_2d_drop_filter_(
+            kernel=kernel,
+            partial_edge_blocks=partial_edge_blocks,
+            inplace=True,
+            selection=torch.empty(
+                noise_shape,
+                dtype=x.dtype,
+                device=x.device,
+            ).bernoulli_(gamma),
+        )
+        keep_filter = 1.0 - drop_filter

     if with_noise:
-        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+        # x += (noise * drop_filter)
+        with torch.no_grad():
+            drop_noise = torch.randn_like(drop_filter)
+            drop_noise.mul_(drop_filter)
+
         if inplace:
-            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
+            x.mul_(keep_filter)
+            x.add_(drop_noise)
         else:
-            x = x * block_mask + normal_noise * (1 - block_mask)
+            x = x * keep_filter + drop_noise
+
     else:
-        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
+        # x *= (size(keep_filter) / (sum(keep_filter) + eps))
+        with torch.no_grad():
+            count = keep_filter.numel()
+            total = keep_filter.to(dtype=torch.float32).sum()
+            # compute the ratio in float32, then cast, so low-precision dtypes don't overflow
+            keep_scale = (count / total.add(1e-7)).to(x.dtype)
+
+            keep_filter.mul_(keep_scale)
+
         if inplace:
-            x.mul_(block_mask * normalize_scale)
+            x.mul_(keep_filter)
         else:
-            x = x * block_mask * normalize_scale
+            x = x * keep_filter
+
     return x


 def drop_block_fast_2d(
-        x: torch.Tensor,
-        drop_prob: float = 0.1,
-        block_size: int = 7,
-        gamma_scale: float = 1.0,
-        with_noise: bool = False,
-        inplace: bool = False,
+    x: torch.Tensor,
+    drop_prob: float = 0.1,
+    block_size: int = 7,
+    gamma_scale: float = 1.0,
+    with_noise: bool = False,
+    inplace: bool = False,
 ):
-    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

     DropBlock with an experimental gaussian noise option. Simplied from
     above without concern for valid block mask at edges.
     """
-    B, C, H, W = x.shape
-    total_size = W * H
-    clipped_block_size = min(block_size, min(W, H))
-    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
-        (W - block_size + 1) * (H - block_size + 1))
-
-    block_mask = torch.empty_like(x).bernoulli_(gamma)
-    block_mask = F.max_pool2d(
-        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
-
-    if with_noise:
-        normal_noise = torch.empty_like(x).normal_()
-        if inplace:
-            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
-        else:
-            x = x * (1. - block_mask) + normal_noise * block_mask
-    else:
-        block_mask = 1 - block_mask
-        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
-        if inplace:
-            x.mul_(block_mask * normalize_scale)
-        else:
-            x = x * block_mask * normalize_scale
-    return x
+    return drop_block_2d(
+        x=x,
+        drop_prob=drop_prob,
+        block_size=block_size,
+        gamma_scale=gamma_scale,
+        with_noise=with_noise,
+        inplace=inplace,
+        batchwise=True,
+        partial_edge_blocks=True,
+    )


 class DropBlock2d(nn.Module):
-    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    Args:
+        drop_prob: the probability of dropping any given block.
+        block_size: the size of the dropped blocks; should be odd.
+        gamma_scale: adjustment scale for the drop_prob.
+        with_noise: should normal noise be added to the dropped region?
+        inplace: if the drop should be applied in-place on the input tensor.
+        batchwise: when true, the entire batch shares the same drop mask; much faster.
+        couple_channels: when true, channels share the same drop mask;
+            much faster, with significant semantic impact.
+        partial_edge_blocks: permit partial blocks at the edges; minor speedup, minor semantic impact.
     """
+    drop_prob: float
+    block_size: int
+    gamma_scale: float
+    with_noise: bool
+    inplace: bool
+    batchwise: bool
+    couple_channels: bool
+    partial_edge_blocks: bool
+
     def __init__(
-            self,
-            drop_prob: float = 0.1,
-            block_size: int = 7,
-            gamma_scale: float = 1.0,
-            with_noise: bool = False,
-            inplace: bool = False,
-            batchwise: bool = False,
-            fast: bool = True):
+        self,
+        drop_prob: float = 0.1,
+        block_size: int = 7,
+        gamma_scale: float = 1.0,
+        with_noise: bool = False,
+        inplace: bool = False,
+        batchwise: bool = True,
+        couple_channels: bool = False,
+        partial_edge_blocks: bool = False,
+    ):
         super(DropBlock2d, self).__init__()
         self.drop_prob = drop_prob
         self.gamma_scale = gamma_scale
@@ -134,20 +285,30 @@ def __init__(
         self.with_noise = with_noise
         self.inplace = inplace
         self.batchwise = batchwise
-        self.fast = fast  # FIXME finish comparisons of fast vs not
+        self.couple_channels = couple_channels
+        self.partial_edge_blocks = partial_edge_blocks

     def forward(self, x):
         if not self.training or not self.drop_prob:
             return x
-        if self.fast:
-            return drop_block_fast_2d(
-                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
-        else:
-            return drop_block_2d(
-                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+
+        return drop_block_2d(
+            x=x,
+            drop_prob=self.drop_prob,
+            block_size=self.block_size,
+            gamma_scale=self.gamma_scale,
+            with_noise=self.with_noise,
+            inplace=self.inplace,
+            batchwise=self.batchwise,
+            couple_channels=self.couple_channels,
+            partial_edge_blocks=self.partial_edge_blocks,
+        )


-def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+def drop_path(
+    x,
+    drop_prob: float = 0.0,
+    training: bool = False,
+    scale_by_keep: bool = True,
+):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
@@ -157,26 +318,39 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
     'survival rate' as the argument.

     """
-    if drop_prob == 0. or not training:
+    if drop_prob == 0.0 or not training:
         return x
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+
+    # work with diff dim tensors, not just 2D ConvNets
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+
     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+
     if keep_prob > 0.0 and scale_by_keep:
         random_tensor.div_(keep_prob)
     return x * random_tensor


 class DropPath(nn.Module):
-    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- """ - def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__( + self, + drop_prob: float = 0.0, + scale_by_keep: bool = True, + ): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep def forward(self, x): - return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + return drop_path( + x, + drop_prob=self.drop_prob, + training=self.training, + scale_by_keep=self.scale_by_keep, + ) def extra_repr(self): - return f'drop_prob={round(self.drop_prob,3):0.3f}' + return f"drop_prob={round(self.drop_prob,3):0.3f}" diff --git a/timm/models/resnet.py b/timm/models/resnet.py index ca07682b80..54aad03576 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -288,19 +288,32 @@ def downsample_avg( ]) -def drop_blocks(drop_prob: float = 0.) -> List[Optional[partial]]: +def drop_blocks(drop_prob: float = 0., **kwargs) -> List[Optional[partial]]: """Create DropBlock layer instances for each stage. Args: drop_prob: Drop probability for DropBlock. + fast: use the fast / unsafe variant. Returns: List of DropBlock partial instances or None for each stage. """ return [ - None, None, - partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None, - partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None] + None, + None, + partial( + DropBlock2d, + drop_prob=drop_prob, + block_size=5, + gamma_scale=0.25, + **kwargs) if drop_prob else None, + partial( + DropBlock2d, + drop_prob=drop_prob, + block_size=3, + gamma_scale=1.00, + **kwargs) if drop_prob else None, + ] def make_blocks( @@ -313,6 +326,9 @@ def make_blocks( down_kernel_size: int = 1, avg_down: bool = False, drop_block_rate: float = 0., + drop_block_batchwise: bool = True, + drop_block_couple_channels: bool = False, + drop_block_partial_edge_blocks: bool = True, drop_path_rate: float = 0., **kwargs, ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]: @@ -328,6 +344,10 @@ def make_blocks( down_kernel_size: Kernel size for downsample layers. avg_down: Use average pooling for downsample. drop_block_rate: DropBlock drop rate. + drop_block_batchwise: Batchwise block dropping, much faster. + drop_block_couple_channels: Couple channel drops. + drop_block_partial_edge_blocks: Permit partial drop blocks on the edge, + slightly faster. drop_path_rate: Drop path rate for stochastic depth. **kwargs: Additional arguments passed to block constructors. 
@@ -340,7 +360,16 @@ def make_blocks( net_block_idx = 0 net_stride = 4 dilation = prev_dilation = 1 - for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip(block_fns, channels, block_repeats, drop_blocks(drop_block_rate))): + for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip( + block_fns, + channels, + block_repeats, + drop_blocks( + drop_prob=drop_block_rate, + batchwise=drop_block_batchwise, + couple_channels=drop_block_couple_channels, + partial_edge_blocks=drop_block_partial_edge_blocks, + ))): stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it stride = 1 if stage_idx == 0 else 2 if net_stride >= output_stride: @@ -440,8 +469,11 @@ def __init__( norm_layer: LayerType = nn.BatchNorm2d, aa_layer: Optional[Type[nn.Module]] = None, drop_rate: float = 0.0, - drop_path_rate: float = 0., drop_block_rate: float = 0., + drop_block_batchwise: bool = True, + drop_block_couple_channels: bool = False, + drop_block_partial_edge_blocks: bool = True, + drop_path_rate: float = 0., zero_init_last: bool = True, block_args: Optional[Dict[str, Any]] = None, ): @@ -470,8 +502,11 @@ def __init__( norm_layer (str, nn.Module): normalization layer aa_layer (nn.Module): anti-aliasing layer drop_rate (float): Dropout probability before classifier, for training (default 0.) - drop_path_rate (float): Stochastic depth drop-path rate (default 0.) drop_block_rate (float): Drop block rate (default 0.) + drop_block_batchwise (bool): Sample blocks batchwise, significantly faster. + drop_block_couple_channels (bool): couple channels when dropping blocks. + drop_block_partial_edge_blocks (bool): Partial block dropping at the edges, faster. + drop_path_rate (float): Stochastic depth drop-path rate (default 0.) zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) block_args (dict): Extra kwargs to pass through to block module """ @@ -542,6 +577,9 @@ def __init__( norm_layer=norm_layer, aa_layer=aa_layer, drop_block_rate=drop_block_rate, + drop_block_batchwise=drop_block_batchwise, + drop_block_couple_channels=drop_block_couple_channels, + drop_block_partial_edge_blocks=drop_block_partial_edge_blocks, drop_path_rate=drop_path_rate, **block_args, ) @@ -1427,6 +1465,39 @@ def resnet10t(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict(block=BasicBlock, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True) return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) +@register_model +def resnet10t_dropblock_slow(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-10-T model with drop_block_rate=0.05, using the slowest DropBlock2d features. + """ + model_args = dict( + block=BasicBlock, + layers=(1, 1, 1, 1), + stem_width=32, + stem_type='deep_tiered', + avg_down=True, + drop_block_rate=0.05, + drop_block_batchwise=False, + drop_block_couple_channels=False, + drop_block_partial_edge_blocks=True, + ) + return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) + +@register_model +def resnet10t_dropblock_fast(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-10-T model with drop_block_rate=0.05, using the fastest DropBlock2d features. 
+ """ + model_args = dict( + block=BasicBlock, + layers=(1, 1, 1, 1), + stem_width=32, + stem_type='deep_tiered', + avg_down=True, + drop_block_rate=0.05, + drop_block_batchwise=True, + drop_block_couple_channels=True, + drop_block_partial_edge_blocks=False, + ) + return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) @register_model def resnet14t(pretrained: bool = False, **kwargs) -> ResNet: