From 03c57a497b7b03be01f8ca3d66d817736e7dfc9d Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Fri, 15 Aug 2025 22:05:32 -0700 Subject: [PATCH 01/14] Fix bug in timm.layers.drop.drop_block_2d when `H != W`. There are two bugs in the `valid_block` code for `drop_block_2d`. - a (W, H) grid being reshaped as (H, W) The current code uses (W, H) to generate the meshgrid; but then uses a `.reshape((1, 1, H, W))` to unsqueeze the block map. The simplest fix to the first bug is a one-line change: ```python h_i, w_i = ndgrid(torch.arange(H), torch.arange(W)) ``` This is a longer patch, that attempts to make the code testable. Note: The current code behaves oddly when the block_size or clipped_block_size is even; I've added tests exposing the behavior; but have not changed it. When you trigger the reshape bug, you get wild results: ``` $ python scratch.py {'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': False} grid.shape=torch.Size([1, 1, 4, 5]) tensor([[[[False, False, False, False, False], [ True, True, False, False, True], [ True, False, False, True, True], [False, False, False, False, False]]]]) {'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': True} grid.shape=torch.Size([1, 1, 4, 5]) tensor([[[[False, False, False, False, False], [False, True, True, True, False], [False, True, True, True, False], [False, False, False, False, False]]]]) ``` Here's a tiny exceprt script, showing the problem; it generated the above output. ```python import torch from typing import Tuple def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]: """generate N-D grid in dimension order. The ndgrid function is like meshgrid except that the order of the first two input arguments are switched. That is, the statement [X1,X2,X3] = ndgrid(x1,x2,x3) produces the same result as [X2,X1,X3] = meshgrid(x2,x1,x3) This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy'). """ try: return torch.meshgrid(*tensors, indexing='ij') except TypeError: # old PyTorch < 1.10 will follow this path as it does not have indexing arg, # the old behaviour of meshgrid was 'ij' return torch.meshgrid(*tensors) def valid_block(H, W, block_size, fix_reshape=False): clipped_block_size = min(block_size, H, W) if fix_reshape: # This should match the .reshape() dimension order below. h_i, w_i = ndgrid(torch.arange(H), torch.arange(W)) else: # The original produces crazy stride patterns, due to .reshape() offset winding. # This is only visible when H != W. w_i, h_i = ndgrid(torch.arange(W), torch.arange(H)) valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) valid_block = torch.reshape(valid_block, (1, 1, H, W)) return valid_block def main(): common_args = dict(H=4, W=5, block_size=3) for fix in [False, True]: args = dict(H=4, W=5, block_size=3, fix_reshape=fix) grid = valid_block(**args) print(args) print(f"{grid.shape=}") print(grid) print() if __name__ == "__main__": main() ``` --- tests/layers/__init__.py | 0 tests/layers/test_drop.py | 47 ++++++++++++++++++++++++++++ timm/layers/drop.py | 65 +++++++++++++++++++++++++++++++++++---- 3 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 tests/layers/__init__.py create mode 100644 tests/layers/test_drop.py diff --git a/tests/layers/__init__.py b/tests/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py new file mode 100644 index 0000000000..a46529effb --- /dev/null +++ b/tests/layers/test_drop.py @@ -0,0 +1,47 @@ +import importlib +import os +import unittest + +import torch + +from timm.layers import drop + +torch_backend = os.environ.get('TORCH_BACKEND') +if torch_backend is not None: + importlib.import_module(torch_backend) +torch_device = os.environ.get('TORCH_DEVICE', 'cpu') + +class ClipMaskTests(unittest.TestCase): + def test_clip_mask_2d_odd(self): + mask = drop.clip_mask_2d(h=5, w=7, kernel=3, device=torch_device) + assert mask.device == torch.device(torch_device) + assert mask.tolist() == \ + [ + [False, False, False, False, False, False, False], + [False, True, True, True, True, True, False], + [False, True, True, True, True, True, False], + [False, True, True, True, True, True, False], + [False, False, False, False, False, False, False], + ] + + def test_clip_mask_2d_even(self): + mask = drop.clip_mask_2d(h=5, w=7, kernel=2, device=torch_device) + assert mask.device == torch.device(torch_device) + # TODO: This is a suprising result; should even kernels be forbidden? + assert mask.tolist() == \ + [ + [False, False, False, False, False, False, False], + [False, True, True, True, True, True, True], + [False, True, True, True, True, True, True], + [False, True, True, True, True, True, True], + [False, True, True, True, True, True, True], + ] + + def test_clip_mask_2d_kernel_too_big(self): + try: + drop.clip_mask_2d(h=4, w=7, kernel=5, device=torch_device) + raise RuntimeError("Expected throw") + + except AssertionError as e: + assert "kernel=5 > min(h=4, w=7)" in e.args[0] + diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 289245f5ad..c60768d0d6 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -21,6 +21,43 @@ from .grid import ndgrid +def clip_mask_2d( + h: int, + w: int, + kernel: int, + device, +): + """Build a clip mask. + + Returns a mask of all points which permit a (kernel, kernel) sized + block to sit entirely within the (h, w) index space. + + Requires `kernel <= min(h, w)`. + + TODO: Should even kernels be forbidden? + Even kernels behave oddly, but are not forbidden for historical reasons. + + Args: + h: the height. + w: the width. + kernel_size: the size of the kernel. + device: the target device. + check_kernel: when true, assert that the kernel_size is odd. + + Returns: + a (h, w) bool mask tensor. + """ + assert kernel <= min(h, w), f"{kernel=} > min({h=}, {w=})" + + h_i, w_i = ndgrid(torch.arange(h, device=device), torch.arange(w, device=device)) + return ( + (h_i >= kernel // 2) & + (h_i < h - (kernel - 1) // 2) & + (w_i >= kernel // 2) & + (w_i < w - (kernel - 1) // 2) + ).reshape(h, w) + + def drop_block_2d( x, drop_prob: float = 0.1, @@ -30,23 +67,39 @@ def drop_block_2d( inplace: bool = False, batchwise: bool = False ): - """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf DropBlock with an experimental gaussian noise option. This layer has been tested on a few training runs with success, but needs further validation and possibly optimization for lower runtime impact. + + Args: + drop_prob: the probability of dropping any given block. + block_size: the size of the dropped blocks; should be odd. + gamma_scale: adjustment scale for the drop_prob. + with_noise: should normal noise be added to the dropped region? + inplace: if the drop should be applied in-place on the input tensor. + batchwise: should the entire batch use the same drop mask? + + Returns: + If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device. """ B, C, H, W = x.shape total_size = W * H - clipped_block_size = min(block_size, min(W, H)) + + # TODO: This behaves oddly when clipped_block_size < block_size, or block_size % 2 == 0. + clipped_block_size = min(block_size, W, H) + # seed_drop_rate, the gamma parameter gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( (W - block_size + 1) * (H - block_size + 1)) # Forces the block to be inside the feature map. - w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device)) - valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ - ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) - valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + valid_block = clip_mask_2d( + h=H, + w=W, + kernel=clipped_block_size, + device=x.device, + ).reshape((1, 1, H, W)).to(dtype=x.dtype) if batchwise: # one mask for whole batch, quite a bit faster From 0c80c3f5219a10693d1f8a0d354df893fad0ff38 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Fri, 15 Aug 2025 23:43:24 -0700 Subject: [PATCH 02/14] switch to slice assign --- tests/layers/test_drop.py | 2 ++ timm/layers/drop.py | 14 +++++--------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py index a46529effb..ae881182af 100644 --- a/tests/layers/test_drop.py +++ b/tests/layers/test_drop.py @@ -14,6 +14,7 @@ class ClipMaskTests(unittest.TestCase): def test_clip_mask_2d_odd(self): mask = drop.clip_mask_2d(h=5, w=7, kernel=3, device=torch_device) + print(mask) assert mask.device == torch.device(torch_device) assert mask.tolist() == \ [ @@ -26,6 +27,7 @@ def test_clip_mask_2d_odd(self): def test_clip_mask_2d_even(self): mask = drop.clip_mask_2d(h=5, w=7, kernel=2, device=torch_device) + print(mask) assert mask.device == torch.device(torch_device) # TODO: This is a suprising result; should even kernels be forbidden? assert mask.tolist() == \ diff --git a/timm/layers/drop.py b/timm/layers/drop.py index c60768d0d6..a18bd8a285 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -18,8 +18,6 @@ import torch.nn as nn import torch.nn.functional as F -from .grid import ndgrid - def clip_mask_2d( h: int, @@ -49,13 +47,11 @@ def clip_mask_2d( """ assert kernel <= min(h, w), f"{kernel=} > min({h=}, {w=})" - h_i, w_i = ndgrid(torch.arange(h, device=device), torch.arange(w, device=device)) - return ( - (h_i >= kernel // 2) & - (h_i < h - (kernel - 1) // 2) & - (w_i >= kernel // 2) & - (w_i < w - (kernel - 1) // 2) - ).reshape(h, w) + mask = torch.zeros((h, w), dtype=torch.bool, device=device) + start = kernel // 2 + end = ((kernel - 1) // 2) + mask[start:h-end, start:w-end] = True + return mask def drop_block_2d( From fa047b88a16bb34fc782cf7b982a4fd93d456ecf Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Mon, 18 Aug 2025 17:56:22 -0700 Subject: [PATCH 03/14] Update based on review discussion --- tests/layers/test_drop.py | 50 ++++++++++++++++++++++++++++++----- timm/layers/drop.py | 55 ++++++++++++++++++++++----------------- 2 files changed, 74 insertions(+), 31 deletions(-) diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py index ae881182af..7f7fd504a9 100644 --- a/tests/layers/test_drop.py +++ b/tests/layers/test_drop.py @@ -11,9 +11,9 @@ importlib.import_module(torch_backend) torch_device = os.environ.get('TORCH_DEVICE', 'cpu') -class ClipMaskTests(unittest.TestCase): - def test_clip_mask_2d_odd(self): - mask = drop.clip_mask_2d(h=5, w=7, kernel=3, device=torch_device) +class Conv2dKernelMidpointMask2d(unittest.TestCase): + def test_conv2d_kernel_midpoint_mask_odd_bool(self): + mask = drop.conv2d_kernel_midpoint_mask(shape=(5, 7), kernel=(3, 3), device=torch_device) print(mask) assert mask.device == torch.device(torch_device) assert mask.tolist() == \ @@ -25,8 +25,44 @@ def test_clip_mask_2d_odd(self): [False, False, False, False, False, False, False], ] - def test_clip_mask_2d_even(self): - mask = drop.clip_mask_2d(h=5, w=7, kernel=2, device=torch_device) + def test_conv2d_kernel_midpoint_mask_odd_float(self): + mask = drop.conv2d_kernel_midpoint_mask( + shape=(5, 7), + kernel=(3, 3), + device=torch_device, + dtype=torch.float32, + ) + print(mask) + assert mask.device == torch.device(torch_device) + assert mask.tolist() == \ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + + def test_conv2d_kernel_midpoint_mask_odd_int(self): + mask = drop.conv2d_kernel_midpoint_mask( + shape=(5, 7), + kernel=(3, 3), + device=torch_device, + dtype=torch.int32, + ) + print(mask) + assert mask.device == torch.device(torch_device) + assert mask.tolist() == \ + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 0], + [0, 1, 1, 1, 1, 1, 0], + [0, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + + def test_conv2d_kernel_midpoint_mask_even(self): + mask = drop.conv2d_kernel_midpoint_mask(shape=(5, 7), kernel=(2, 2), device=torch_device) print(mask) assert mask.device == torch.device(torch_device) # TODO: This is a suprising result; should even kernels be forbidden? @@ -41,9 +77,9 @@ def test_clip_mask_2d_even(self): def test_clip_mask_2d_kernel_too_big(self): try: - drop.clip_mask_2d(h=4, w=7, kernel=5, device=torch_device) + drop.conv2d_kernel_midpoint_mask(shape=(4, 7), kernel=(5, 5), device=torch_device) raise RuntimeError("Expected throw") except AssertionError as e: - assert "kernel=5 > min(h=4, w=7)" in e.args[0] + assert "kernel=(5, 5) ! <= shape=(4, 7)" in e.args[0] diff --git a/timm/layers/drop.py b/timm/layers/drop.py index a18bd8a285..4857a5f2a5 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -19,38 +19,45 @@ import torch.nn.functional as F -def clip_mask_2d( - h: int, - w: int, - kernel: int, +def conv2d_kernel_midpoint_mask( + shape: (int, int), + kernel: (int, int), device, + dtype = torch.bool, ): - """Build a clip mask. + """Build a mask of kernel midpoints. - Returns a mask of all points which permit a (kernel, kernel) sized - block to sit entirely within the (h, w) index space. + This predicts the kernel midpoints that conv2d (and related kernel functions) + would place a kernel. - Requires `kernel <= min(h, w)`. + The *midpoint* of a kernel is computed as ``size / 2``: + * the midpoint of odd kernels is the middle: `mid(3) == 1` + * the midpoint of even kernels is the first point in the second half: `mid(4) == 2` - TODO: Should even kernels be forbidden? - Even kernels behave oddly, but are not forbidden for historical reasons. + Requires `kernel <= min(h, w)`. Args: - h: the height. - w: the width. - kernel_size: the size of the kernel. + shape: the (h, w) shape of the tensor. + kernel: the (kh, hw) shape of the kernel. device: the target device. check_kernel: when true, assert that the kernel_size is odd. Returns: a (h, w) bool mask tensor. """ - assert kernel <= min(h, w), f"{kernel=} > min({h=}, {w=})" + h, w = shape + kh, kw = kernel + assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}" + + mask = torch.zeros((h, w), dtype=dtype, device=device) + + h_start = kh // 2 + h_end = (kh - 1) // 2 + + w_start = kw // 2 + w_end = (kw - 1) // 2 - mask = torch.zeros((h, w), dtype=torch.bool, device=device) - start = kernel // 2 - end = ((kernel - 1) // 2) - mask[start:h-end, start:w-end] = True + mask[h_start:h - h_end, w_start:w - w_end] = 1 return mask @@ -82,7 +89,7 @@ def drop_block_2d( B, C, H, W = x.shape total_size = W * H - # TODO: This behaves oddly when clipped_block_size < block_size, or block_size % 2 == 0. + # TODO: This behaves oddly when clipped_block_size < block_size. clipped_block_size = min(block_size, W, H) # seed_drop_rate, the gamma parameter @@ -90,12 +97,12 @@ def drop_block_2d( (W - block_size + 1) * (H - block_size + 1)) # Forces the block to be inside the feature map. - valid_block = clip_mask_2d( - h=H, - w=W, - kernel=clipped_block_size, + valid_block = conv2d_kernel_midpoint_mask( + shape=(H, W), + kernel=(clipped_block_size, clipped_block_size), device=x.device, - ).reshape((1, 1, H, W)).to(dtype=x.dtype) + dtype=x.dtype, + ).unsqueeze().unsqueeze() if batchwise: # one mask for whole batch, quite a bit faster From 2ffb37c89abe33b144882ead9cba8357308092d1 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Mon, 18 Aug 2025 20:44:23 -0700 Subject: [PATCH 04/14] Unify drop_block_2d / drop_block_fast_2d; add some actual tests --- tests/layers/test_drop.py | 26 +++++ tests/test_models.py | 4 +- timm/layers/drop.py | 202 +++++++++++++++++++++++--------------- timm/models/resnet.py | 72 +++++++++++++- 4 files changed, 220 insertions(+), 84 deletions(-) diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py index 7f7fd504a9..8602e3948f 100644 --- a/tests/layers/test_drop.py +++ b/tests/layers/test_drop.py @@ -25,6 +25,32 @@ def test_conv2d_kernel_midpoint_mask_odd_bool(self): [False, False, False, False, False, False, False], ] + def test_conv2d_kernel_midpoint_mask_odd_float_inplace(self): + mask = torch.tensor( + [ + [2.0, 1.0, 1.0, 1.0, 1.0, 7.0, 1.0], + [1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 8.0], + [9.0, 1.0, 4.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 5.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 6.0, 1.0, 1.0], + ], + device=torch_device, + ) + drop.conv2d_kernel_midpoint_mask( + kernel=(3, 3), + inplace_mask=mask, + ) + print(mask) + assert mask.device == torch.device(torch_device) + assert mask.tolist() == \ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 4.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 5.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + def test_conv2d_kernel_midpoint_mask_odd_float(self): mask = drop.conv2d_kernel_midpoint_mask( shape=(5, 7), diff --git a/tests/test_models.py b/tests/test_models.py index 028440179d..8ed2ca5231 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -204,7 +204,7 @@ def test_model_forward(model_name, batch_size): @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward(model_name, batch_size): - """Run a single forward pass with each model""" + """Run a single forward and backward pass with each model""" input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) if max(input_size) > MAX_BWD_SIZE: pytest.skip("Fixed input size model > limit.") @@ -594,7 +594,7 @@ def _create_fx_model(model, train=False): return fx_model -EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*'] +EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*', '*dropblock*'] # not enough memory to run fx on more models than other tests if 'GITHUB_ACTIONS' in os.environ: EXCLUDE_FX_FILTERS += [ diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 4857a5f2a5..413febbab5 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -14,16 +14,19 @@ Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F def conv2d_kernel_midpoint_mask( - shape: (int, int), - kernel: (int, int), - device, - dtype = torch.bool, + kernel: Tuple[int, int], + *, + inplace_mask = None, + shape: Optional[Tuple[int, int]] = None, + device = None, + dtype = None, ): """Build a mask of kernel midpoints. @@ -36,28 +39,53 @@ def conv2d_kernel_midpoint_mask( Requires `kernel <= min(h, w)`. + When an `inplace_mask` is not provided, a new mask of `1`s is allocated, + and then the `0` locations are cleared. + + When an `inplace_mask` is provided, the `0` locations are cleared on the mask, + and no other changes are made. `shape`, `dtype`, and `device` must match, if + they are provided. + Args: - shape: the (h, w) shape of the tensor. kernel: the (kh, hw) shape of the kernel. + inplace_mask: if supplied, updates will apply to the inplace_mask, + and device and dtype will be ignored. Only clears 'false' locations. + shape: the (h, w) shape of the tensor. device: the target device. - check_kernel: when true, assert that the kernel_size is odd. + dtype: the target dtype. Returns: a (h, w) bool mask tensor. """ + if inplace_mask is not None: + mask = inplace_mask + + if shape: + assert shape == mask.shape[-2], f"{shape=} !~= {mask.shape=}" + + shape = mask.shape + + if device: + device = torch.device(device) + assert device == mask.device, f"{device=} != {mask.device=}" + + if dtype: + dtype = torch.dtype(dtype) + assert dtype == inplace_mask.dtype, f"{dtype=} != {mask.dtype=}" + + else: + mask = torch.ones(shape, dtype=dtype, device=device) + h, w = shape kh, kw = kernel assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}" - mask = torch.zeros((h, w), dtype=dtype, device=device) + # Set to 0, rather than set to 1, so we can clear the inplace mask. + mask[:kh // 2, :] = 0 + mask[h - (kh - 1) // 2:, :] = 0 + mask[:, :kw // 2] = 0 + mask[:, w - (kw - 1) // 2:] = 0 - h_start = kh // 2 - h_end = (kh - 1) // 2 - - w_start = kw // 2 - w_end = (kw - 1) // 2 - - mask[h_start:h - h_end, w_start:w - w_end] = 1 return mask @@ -68,7 +96,8 @@ def drop_block_2d( gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, - batchwise: bool = False + batchwise: bool = False, + messy: bool = False, ): """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf @@ -82,6 +111,7 @@ def drop_block_2d( with_noise: should normal noise be added to the dropped region? inplace: if the drop should be applied in-place on the input tensor. batchwise: should the entire batch use the same drop mask? + messy: partial-blocks at the edges, faster. Returns: If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device. @@ -90,44 +120,55 @@ def drop_block_2d( total_size = W * H # TODO: This behaves oddly when clipped_block_size < block_size. - clipped_block_size = min(block_size, W, H) + clipped_block_size = min(block_size, H, W) + + gamma = ( + float(gamma_scale * drop_prob * total_size) + / float(clipped_block_size ** 2) + / float((H - block_size + 1) * (W - block_size + 1)) + ) - # seed_drop_rate, the gamma parameter - gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( - (W - block_size + 1) * (H - block_size + 1)) + # batchwise => one mask for whole batch, quite a bit faster + mask_shape = (1 if batchwise else B, C, H, W) - # Forces the block to be inside the feature map. - valid_block = conv2d_kernel_midpoint_mask( - shape=(H, W), - kernel=(clipped_block_size, clipped_block_size), - device=x.device, + block_mask = torch.empty( + mask_shape, dtype=x.dtype, - ).unsqueeze().unsqueeze() + device=x.device + ).bernoulli_(gamma) - if batchwise: - # one mask for whole batch, quite a bit faster - uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) - else: - uniform_noise = torch.rand_like(x) - block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) - block_mask = -F.max_pool2d( - -block_mask, - kernel_size=clipped_block_size, # block_size, + if not messy: + conv2d_kernel_midpoint_mask( + kernel=(clipped_block_size, clipped_block_size), + inplace_mask=block_mask, + ) + + block_mask = F.max_pool2d( + block_mask, + kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + if inplace: + x.mul_(block_mask) + else: + x = x * block_mask + + # From this point on, we do inplace ops on X. + if with_noise: - normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) - if inplace: - x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) - else: - x = x * block_mask + normal_noise * (1 - block_mask) + noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device) + # x += (noise * (1 - block_mask)) + block_mask.neg_().add_(1) + noise.mul_(block_mask) + x.add_(noise) + else: - normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) - if inplace: - x.mul_(block_mask * normalize_scale) - else: - x = x * block_mask * normalize_scale + # x *= (size(block_mask) / sum(block_mask)) + total = block_mask.to(dtype=torch.float32).sum() + normalize_scale = block_mask.numel() / total.add(1e-7).to(x.dtype) + x.mul_(normalize_scale) + return x @@ -144,35 +185,37 @@ def drop_block_fast_2d( DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid block mask at edges. """ - B, C, H, W = x.shape - total_size = W * H - clipped_block_size = min(block_size, min(W, H)) - gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( - (W - block_size + 1) * (H - block_size + 1)) - - block_mask = torch.empty_like(x).bernoulli_(gamma) - block_mask = F.max_pool2d( - block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) - - if with_noise: - normal_noise = torch.empty_like(x).normal_() - if inplace: - x.mul_(1. - block_mask).add_(normal_noise * block_mask) - else: - x = x * (1. - block_mask) + normal_noise * block_mask - else: - block_mask = 1 - block_mask - normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) - if inplace: - x.mul_(block_mask * normalize_scale) - else: - x = x * block_mask * normalize_scale - return x + drop_block_2d( + x=x, + drop_prob=drop_prob, + block_size=block_size, + gamma_scale=gamma_scale, + with_noise=with_noise, + inplace=inplace, + batchwise=True, + messy=True, + ) class DropBlock2d(nn.Module): - """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + Args: + drop_prob: the probability of dropping any given block. + block_size: the size of the dropped blocks; should be odd. + gamma_scale: adjustment scale for the drop_prob. + with_noise: should normal noise be added to the dropped region? + inplace: if the drop should be applied in-place on the input tensor. + batchwise: should the entire batch use the same drop mask? + messy: partial-blocks at the edges, faster. """ + drop_prob: float + block_size: int + gamma_scale: float + with_noise: bool + inplace: bool + batchwise: bool + messy: bool def __init__( self, @@ -182,7 +225,8 @@ def __init__( with_noise: bool = False, inplace: bool = False, batchwise: bool = False, - fast: bool = True): + messy: bool = True, + ): super(DropBlock2d, self).__init__() self.drop_prob = drop_prob self.gamma_scale = gamma_scale @@ -190,17 +234,21 @@ def __init__( self.with_noise = with_noise self.inplace = inplace self.batchwise = batchwise - self.fast = fast # FIXME finish comparisons of fast vs not + self.messy = messy def forward(self, x): if not self.training or not self.drop_prob: return x - if self.fast: - return drop_block_fast_2d( - x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) - else: - return drop_block_2d( - x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + return drop_block_2d( + x=x, + drop_prob=self.drop_prob, + block_size=self.block_size, + gamma_scale=self.gamma_scale, + with_noise=self.with_noise, + inplace=self.inplace, + batchwise=self.batchwise, + messy=self.messy) def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): diff --git a/timm/models/resnet.py b/timm/models/resnet.py index ca07682b80..8fe6900299 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -288,19 +288,32 @@ def downsample_avg( ]) -def drop_blocks(drop_prob: float = 0.) -> List[Optional[partial]]: +def drop_blocks(drop_prob: float = 0., **kwargs) -> List[Optional[partial]]: """Create DropBlock layer instances for each stage. Args: drop_prob: Drop probability for DropBlock. + fast: use the fast / unsafe variant. Returns: List of DropBlock partial instances or None for each stage. """ return [ - None, None, - partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None, - partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None] + None, + None, + partial( + DropBlock2d, + drop_prob=drop_prob, + block_size=5, + gamma_scale=0.25, + **kwargs) if drop_prob else None, + partial( + DropBlock2d, + drop_prob=drop_prob, + block_size=3, + gamma_scale=1.00, + **kwargs) if drop_prob else None, + ] def make_blocks( @@ -314,6 +327,8 @@ def make_blocks( avg_down: bool = False, drop_block_rate: float = 0., drop_path_rate: float = 0., + drop_block_batchwise: bool = False, + drop_block_messy: bool = True, **kwargs, ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]: """Create ResNet stages with specified block configurations. @@ -328,6 +343,8 @@ def make_blocks( down_kernel_size: Kernel size for downsample layers. avg_down: Use average pooling for downsample. drop_block_rate: DropBlock drop rate. + drop_block_batchwise: Batchwise block dropping, faster. + drop_block_messy: dropping produces partial blocks on the edge, faster. drop_path_rate: Drop path rate for stochastic depth. **kwargs: Additional arguments passed to block constructors. @@ -340,7 +357,15 @@ def make_blocks( net_block_idx = 0 net_stride = 4 dilation = prev_dilation = 1 - for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip(block_fns, channels, block_repeats, drop_blocks(drop_block_rate))): + for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip( + block_fns, + channels, + block_repeats, + drop_blocks( + drop_prob=drop_block_rate, + batchwise=drop_block_batchwise, + messy=drop_block_messy, + ))): stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it stride = 1 if stage_idx == 0 else 2 if net_stride >= output_stride: @@ -442,6 +467,8 @@ def __init__( drop_rate: float = 0.0, drop_path_rate: float = 0., drop_block_rate: float = 0., + drop_block_batchwise: bool = True, + drop_block_messy: bool = True, zero_init_last: bool = True, block_args: Optional[Dict[str, Any]] = None, ): @@ -472,6 +499,8 @@ def __init__( drop_rate (float): Dropout probability before classifier, for training (default 0.) drop_path_rate (float): Stochastic depth drop-path rate (default 0.) drop_block_rate (float): Drop block rate (default 0.) + drop_block_batchwise (bool): Sample blocks batchwise, faster. + drop_block_messy (bool): Partial block dropping at the edges, faster. zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) block_args (dict): Extra kwargs to pass through to block module """ @@ -542,6 +571,8 @@ def __init__( norm_layer=norm_layer, aa_layer=aa_layer, drop_block_rate=drop_block_rate, + drop_block_batchwise=drop_block_batchwise, + drop_block_messy=drop_block_messy, drop_path_rate=drop_path_rate, **block_args, ) @@ -1427,6 +1458,37 @@ def resnet10t(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict(block=BasicBlock, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True) return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) +@register_model +def resnet10t_dropblock_correct(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-10-T model with drop_block_rate=0.05, using the most accurate DropBlock2d features. + """ + model_args = dict( + block=BasicBlock, + layers=(1, 1, 1, 1), + stem_width=32, + stem_type='deep_tiered', + avg_down=True, + drop_block_rate=0.05, + drop_block_batchwise=True, + drop_block_messy=True, + ) + return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) + +@register_model +def resnet10t_dropblock_fast(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-10-T model with drop_block_rate=0.05, using the fastest DropBlock2d features. + """ + model_args = dict( + block=BasicBlock, + layers=(1, 1, 1, 1), + stem_width=32, + stem_type='deep_tiered', + avg_down=True, + drop_block_rate=0.05, + drop_block_batchwise=False, + drop_block_messy=False, + ) + return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) @register_model def resnet14t(pretrained: bool = False, **kwargs) -> ResNet: From 81714c1799769fe8b9af287200be3877f86a3655 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Tue, 19 Aug 2025 02:03:54 -0700 Subject: [PATCH 05/14] fix a bug in drop_filter vs keep_filter; even more tests --- tests/layers/test_drop.py | 127 ++++++++++++++++++++++++++-------- timm/layers/drop.py | 141 +++++++++++++++++++++----------------- 2 files changed, 176 insertions(+), 92 deletions(-) diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py index 8602e3948f..948cc55197 100644 --- a/tests/layers/test_drop.py +++ b/tests/layers/test_drop.py @@ -13,7 +13,12 @@ class Conv2dKernelMidpointMask2d(unittest.TestCase): def test_conv2d_kernel_midpoint_mask_odd_bool(self): - mask = drop.conv2d_kernel_midpoint_mask(shape=(5, 7), kernel=(3, 3), device=torch_device) + mask = drop.conv2d_kernel_midpoint_mask( + shape=(5, 7), + kernel=(3, 3), + device=torch_device, + dtype=torch.bool, + ) print(mask) assert mask.device == torch.device(torch_device) assert mask.tolist() == \ @@ -25,32 +30,6 @@ def test_conv2d_kernel_midpoint_mask_odd_bool(self): [False, False, False, False, False, False, False], ] - def test_conv2d_kernel_midpoint_mask_odd_float_inplace(self): - mask = torch.tensor( - [ - [2.0, 1.0, 1.0, 1.0, 1.0, 7.0, 1.0], - [1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 8.0], - [9.0, 1.0, 4.0, 1.0, 1.0, 1.0, 1.0], - [1.0, 1.0, 1.0, 5.0, 1.0, 1.0, 1.0], - [1.0, 1.0, 1.0, 1.0, 6.0, 1.0, 1.0], - ], - device=torch_device, - ) - drop.conv2d_kernel_midpoint_mask( - kernel=(3, 3), - inplace_mask=mask, - ) - print(mask) - assert mask.device == torch.device(torch_device) - assert mask.tolist() == \ - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 3.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 4.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 5.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ] - def test_conv2d_kernel_midpoint_mask_odd_float(self): mask = drop.conv2d_kernel_midpoint_mask( shape=(5, 7), @@ -88,10 +67,14 @@ def test_conv2d_kernel_midpoint_mask_odd_int(self): ] def test_conv2d_kernel_midpoint_mask_even(self): - mask = drop.conv2d_kernel_midpoint_mask(shape=(5, 7), kernel=(2, 2), device=torch_device) + mask = drop.conv2d_kernel_midpoint_mask( + shape=(5, 7), + kernel=(2, 2), + device=torch_device, + dtype=torch.bool, + ) print(mask) assert mask.device == torch.device(torch_device) - # TODO: This is a suprising result; should even kernels be forbidden? assert mask.tolist() == \ [ [False, False, False, False, False, False, False], @@ -103,9 +86,93 @@ def test_conv2d_kernel_midpoint_mask_even(self): def test_clip_mask_2d_kernel_too_big(self): try: - drop.conv2d_kernel_midpoint_mask(shape=(4, 7), kernel=(5, 5), device=torch_device) + drop.conv2d_kernel_midpoint_mask( + shape=(4, 7), + kernel=(5, 5), + device=torch_device, + dtype=torch.bool, + ) raise RuntimeError("Expected throw") except AssertionError as e: assert "kernel=(5, 5) ! <= shape=(4, 7)" in e.args[0] + +class DropBlock2dDropFilterTest(unittest.TestCase): + def test_drop_filter(self): + selection = torch.tensor( + [ + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=torch_device, + ).unsqueeze(0).unsqueeze(0) + + result = drop.drop_block_2d_drop_filter_( + selection=selection, + kernel=(2, 3), + messy=False + ).squeeze() + print(result) + assert result.device == torch.device(torch_device) + assert result.tolist() == \ + [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + + def test_drop_filter_messy(self): + selection = torch.tensor( + [ + [0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1], + ], + device=torch_device, + dtype=torch.int32, + ).unsqueeze(0).unsqueeze(0) + + result = drop.drop_block_2d_drop_filter_( + selection=selection, + kernel=(2, 3), + messy=True + ).squeeze() + print(result) + assert result.device == torch.device(torch_device) + assert result.tolist() == \ + [ + [1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 0, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 0, 1, 1], + ] + +class DropBlock2dTest(unittest.TestCase): + def test_drop_block_2d(self): + tensor = torch.ones((1, 1, 200, 300), device=torch_device) + + drop_prob=0.1 + keep_prob = 1.0 - drop_prob + + result = drop.drop_block_2d( + tensor, + drop_prob=drop_prob, + with_noise=True, + ).squeeze() + + numel = float(result.numel()) + unchanged = float(len(result[result == 1.0])) + keep_ratio = unchanged / numel + + assert abs(keep_ratio - keep_prob) < 0.05, \ + f"abs({keep_ratio=} - {keep_prob=}) ! < 0.05" + diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 413febbab5..66b85baa67 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -21,12 +21,11 @@ def conv2d_kernel_midpoint_mask( - kernel: Tuple[int, int], *, - inplace_mask = None, - shape: Optional[Tuple[int, int]] = None, - device = None, - dtype = None, + shape: Tuple[int, int], + kernel: Tuple[int, int], + device, + dtype, ): """Build a mask of kernel midpoints. @@ -39,17 +38,10 @@ def conv2d_kernel_midpoint_mask( Requires `kernel <= min(h, w)`. - When an `inplace_mask` is not provided, a new mask of `1`s is allocated, - and then the `0` locations are cleared. - - When an `inplace_mask` is provided, the `0` locations are cleared on the mask, - and no other changes are made. `shape`, `dtype`, and `device` must match, if - they are provided. + A new mask of `1`s is allocated, and then the `0` locations are cleared. Args: kernel: the (kh, hw) shape of the kernel. - inplace_mask: if supplied, updates will apply to the inplace_mask, - and device and dtype will be ignored. Only clears 'false' locations. shape: the (h, w) shape of the tensor. device: the target device. dtype: the target dtype. @@ -57,36 +49,58 @@ def conv2d_kernel_midpoint_mask( Returns: a (h, w) bool mask tensor. """ - if inplace_mask is not None: - mask = inplace_mask + h, w = shape + kh, kw = kernel + assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}" - if shape: - assert shape == mask.shape[-2], f"{shape=} !~= {mask.shape=}" + mask = torch.zeros(shape, dtype=dtype, device=device) - shape = mask.shape + mask[kh//2: h - ((kh - 1) // 2), kw//2: w - ((kw - 1) // 2)] = 1.0 - if device: - device = torch.device(device) - assert device == mask.device, f"{device=} != {mask.device=}" + return mask - if dtype: - dtype = torch.dtype(dtype) - assert dtype == inplace_mask.dtype, f"{dtype=} != {mask.dtype=}" - else: - mask = torch.ones(shape, dtype=dtype, device=device) +def drop_block_2d_drop_filter_( + *, + selection, + kernel: Tuple[int, int], + messy: bool +): + """Convert drop block gamma noise to a drop filter. + + This is a deterministic internal component of drop_block_2d. + + Args: + selection: 4D (B, C, H, W) input selection noise; + `1.0` at the midpoints of selected blocks to drop, + `0.0` everywhere else. Expected to be gamma noise. + kernel: the shape of the 2d kernel. + messy: permit partial blocks at the edges, faster. + + Returns: + A drop filter, `1.0` at points to drop, `0.0` at points to keep. + """ + + if not messy: + selection = selection * conv2d_kernel_midpoint_mask( + shape=selection.shape[-2:], + kernel=kernel, + dtype=selection.dtype, + device=selection.device, + ) - h, w = shape kh, kw = kernel - assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}" - # Set to 0, rather than set to 1, so we can clear the inplace mask. - mask[:kh // 2, :] = 0 - mask[h - (kh - 1) // 2:, :] = 0 - mask[:, :kw // 2] = 0 - mask[:, w - (kw - 1) // 2:] = 0 + drop_filter = F.max_pool2d( + selection, + kernel_size=kernel, + stride=1, + padding=[kh // 2, kw // 2], + ) + if (kh % 2 == 0) or (kw % 2 == 0): + drop_filter = drop_filter[..., (kh%2==0):, (kw%2==0):] - return mask + return drop_filter def drop_block_2d( @@ -117,57 +131,60 @@ def drop_block_2d( If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device. """ B, C, H, W = x.shape - total_size = W * H # TODO: This behaves oddly when clipped_block_size < block_size. - clipped_block_size = min(block_size, H, W) + kh = kw = block_size + + kernel = [min(kh, H), min(kw, W)] + kh, kw = kernel gamma = ( - float(gamma_scale * drop_prob * total_size) - / float(clipped_block_size ** 2) - / float((H - block_size + 1) * (W - block_size + 1)) + float(gamma_scale * drop_prob * H * W) + / float(kh * kw) + / float((H - kh + 1) * (W - kw + 1)) ) # batchwise => one mask for whole batch, quite a bit faster mask_shape = (1 if batchwise else B, C, H, W) - block_mask = torch.empty( + selection = torch.empty( mask_shape, dtype=x.dtype, device=x.device ).bernoulli_(gamma) - if not messy: - conv2d_kernel_midpoint_mask( - kernel=(clipped_block_size, clipped_block_size), - inplace_mask=block_mask, - ) - - block_mask = F.max_pool2d( - block_mask, - kernel_size=clipped_block_size, - stride=1, - padding=clipped_block_size // 2) + drop_filter = drop_block_2d_drop_filter_( + selection=selection, + kernel=kernel, + messy=messy, + ) + keep_filter = 1.0 - drop_filter if inplace: - x.mul_(block_mask) + x.mul_(keep_filter) else: - x = x * block_mask - - # From this point on, we do inplace ops on X. + x = x * keep_filter if with_noise: - noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device) # x += (noise * (1 - block_mask)) - block_mask.neg_().add_(1) - noise.mul_(block_mask) - x.add_(noise) + noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device) + + if inplace: + noise.mul_(drop_filter) + x.add_(noise) + else: + x = x + noise * drop_filter else: # x *= (size(block_mask) / sum(block_mask)) - total = block_mask.to(dtype=torch.float32).sum() - normalize_scale = block_mask.numel() / total.add(1e-7).to(x.dtype) - x.mul_(normalize_scale) + count = keep_filter.numel() + total = keep_filter.to(dtype=torch.float32).sum() + normalize_scale = count / total.add(1e-7).to(x.dtype) + + if inplace: + x.mul_(normalize_scale) + else: + x = x * normalize_scale return x From 6bc3dc005153e9011bffaef229d6e9ea6676626d Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Tue, 19 Aug 2025 10:18:34 -0700 Subject: [PATCH 06/14] kill off int/bool tests; they aren't supported by all backends --- tests/layers/test_drop.py | 67 ++++++++++----------------------------- 1 file changed, 16 insertions(+), 51 deletions(-) diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py index 948cc55197..07183c7b12 100644 --- a/tests/layers/test_drop.py +++ b/tests/layers/test_drop.py @@ -12,25 +12,7 @@ torch_device = os.environ.get('TORCH_DEVICE', 'cpu') class Conv2dKernelMidpointMask2d(unittest.TestCase): - def test_conv2d_kernel_midpoint_mask_odd_bool(self): - mask = drop.conv2d_kernel_midpoint_mask( - shape=(5, 7), - kernel=(3, 3), - device=torch_device, - dtype=torch.bool, - ) - print(mask) - assert mask.device == torch.device(torch_device) - assert mask.tolist() == \ - [ - [False, False, False, False, False, False, False], - [False, True, True, True, True, True, False], - [False, True, True, True, True, True, False], - [False, True, True, True, True, True, False], - [False, False, False, False, False, False, False], - ] - - def test_conv2d_kernel_midpoint_mask_odd_float(self): + def test_conv2d_kernel_midpoint_mask_odd(self): mask = drop.conv2d_kernel_midpoint_mask( shape=(5, 7), kernel=(3, 3), @@ -48,23 +30,6 @@ def test_conv2d_kernel_midpoint_mask_odd_float(self): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ] - def test_conv2d_kernel_midpoint_mask_odd_int(self): - mask = drop.conv2d_kernel_midpoint_mask( - shape=(5, 7), - kernel=(3, 3), - device=torch_device, - dtype=torch.int32, - ) - print(mask) - assert mask.device == torch.device(torch_device) - assert mask.tolist() == \ - [ - [0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 1, 1, 0], - [0, 1, 1, 1, 1, 1, 0], - [0, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - ] def test_conv2d_kernel_midpoint_mask_even(self): mask = drop.conv2d_kernel_midpoint_mask( @@ -77,11 +42,11 @@ def test_conv2d_kernel_midpoint_mask_even(self): assert mask.device == torch.device(torch_device) assert mask.tolist() == \ [ - [False, False, False, False, False, False, False], - [False, True, True, True, True, True, True], - [False, True, True, True, True, True, True], - [False, True, True, True, True, True, True], - [False, True, True, True, True, True, True], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], ] def test_clip_mask_2d_kernel_too_big(self): @@ -130,11 +95,11 @@ def test_drop_filter(self): def test_drop_filter_messy(self): selection = torch.tensor( [ - [0, 0, 0, 1, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], ], device=torch_device, dtype=torch.int32, @@ -149,11 +114,11 @@ def test_drop_filter_messy(self): assert result.device == torch.device(torch_device) assert result.tolist() == \ [ - [1, 1, 1, 1, 1, 0, 0], - [1, 1, 1, 0, 1, 1, 1], - [0, 0, 0, 0, 1, 1, 1], - [0, 0, 0, 0, 0, 1, 1], - [0, 0, 0, 0, 0, 1, 1], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], ] class DropBlock2dTest(unittest.TestCase): From c8b7af0f48e0e1c9263010cc643ddd60b0e72623 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Tue, 19 Aug 2025 13:52:59 -0700 Subject: [PATCH 07/14] fix types --- tests/layers/test_drop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py index 07183c7b12..16974450cf 100644 --- a/tests/layers/test_drop.py +++ b/tests/layers/test_drop.py @@ -36,7 +36,7 @@ def test_conv2d_kernel_midpoint_mask_even(self): shape=(5, 7), kernel=(2, 2), device=torch_device, - dtype=torch.bool, + dtype=torch.float32, ) print(mask) assert mask.device == torch.device(torch_device) @@ -55,7 +55,7 @@ def test_clip_mask_2d_kernel_too_big(self): shape=(4, 7), kernel=(5, 5), device=torch_device, - dtype=torch.bool, + dtype=torch.float32, ) raise RuntimeError("Expected throw") @@ -102,7 +102,7 @@ def test_drop_filter_messy(self): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], ], device=torch_device, - dtype=torch.int32, + dtype=torch.float32, ).unsqueeze(0).unsqueeze(0) result = drop.drop_block_2d_drop_filter_( From ea5d119cd6993e4963a4f926a49a74a718aa7058 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Thu, 21 Aug 2025 10:14:36 -0700 Subject: [PATCH 08/14] s/messy/partial_edge_blocks/g --- tests/layers/test_drop.py | 6 +++--- timm/layers/drop.py | 24 ++++++++++++------------ timm/models/resnet.py | 16 ++++++++-------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py index 16974450cf..b81abafb31 100644 --- a/tests/layers/test_drop.py +++ b/tests/layers/test_drop.py @@ -79,7 +79,7 @@ def test_drop_filter(self): result = drop.drop_block_2d_drop_filter_( selection=selection, kernel=(2, 3), - messy=False + partial_edge_blocks=False ).squeeze() print(result) assert result.device == torch.device(torch_device) @@ -92,7 +92,7 @@ def test_drop_filter(self): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ] - def test_drop_filter_messy(self): + def test_drop_filter_partial_edge_blocks(self): selection = torch.tensor( [ [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], @@ -108,7 +108,7 @@ def test_drop_filter_messy(self): result = drop.drop_block_2d_drop_filter_( selection=selection, kernel=(2, 3), - messy=True + partial_edge_blocks=True ).squeeze() print(result) assert result.device == torch.device(torch_device) diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 66b85baa67..11e7989a52 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -64,7 +64,7 @@ def drop_block_2d_drop_filter_( *, selection, kernel: Tuple[int, int], - messy: bool + partial_edge_blocks: bool ): """Convert drop block gamma noise to a drop filter. @@ -75,13 +75,13 @@ def drop_block_2d_drop_filter_( `1.0` at the midpoints of selected blocks to drop, `0.0` everywhere else. Expected to be gamma noise. kernel: the shape of the 2d kernel. - messy: permit partial blocks at the edges, faster. + partial_edge_blocks: permit partial blocks at the edges, faster. Returns: A drop filter, `1.0` at points to drop, `0.0` at points to keep. """ - if not messy: + if not partial_edge_blocks: selection = selection * conv2d_kernel_midpoint_mask( shape=selection.shape[-2:], kernel=kernel, @@ -111,7 +111,7 @@ def drop_block_2d( with_noise: bool = False, inplace: bool = False, batchwise: bool = False, - messy: bool = False, + partial_edge_blocks: bool = False, ): """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf @@ -125,7 +125,7 @@ def drop_block_2d( with_noise: should normal noise be added to the dropped region? inplace: if the drop should be applied in-place on the input tensor. batchwise: should the entire batch use the same drop mask? - messy: partial-blocks at the edges, faster. + partial_edge_blocks: partial-blocks at the edges, faster. Returns: If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device. @@ -156,7 +156,7 @@ def drop_block_2d( drop_filter = drop_block_2d_drop_filter_( selection=selection, kernel=kernel, - messy=messy, + partial_edge_blocks=partial_edge_blocks, ) keep_filter = 1.0 - drop_filter @@ -210,7 +210,7 @@ def drop_block_fast_2d( with_noise=with_noise, inplace=inplace, batchwise=True, - messy=True, + partial_edge_blocks=True, ) @@ -224,7 +224,7 @@ class DropBlock2d(nn.Module): with_noise: should normal noise be added to the dropped region? inplace: if the drop should be applied in-place on the input tensor. batchwise: should the entire batch use the same drop mask? - messy: partial-blocks at the edges, faster. + partial_edge_blocks: partial-blocks at the edges, faster. """ drop_prob: float block_size: int @@ -232,7 +232,7 @@ class DropBlock2d(nn.Module): with_noise: bool inplace: bool batchwise: bool - messy: bool + partial_edge_blocks: bool def __init__( self, @@ -242,7 +242,7 @@ def __init__( with_noise: bool = False, inplace: bool = False, batchwise: bool = False, - messy: bool = True, + partial_edge_blocks: bool = True, ): super(DropBlock2d, self).__init__() self.drop_prob = drop_prob @@ -251,7 +251,7 @@ def __init__( self.with_noise = with_noise self.inplace = inplace self.batchwise = batchwise - self.messy = messy + self.partial_edge_blocks = partial_edge_blocks def forward(self, x): if not self.training or not self.drop_prob: @@ -265,7 +265,7 @@ def forward(self, x): with_noise=self.with_noise, inplace=self.inplace, batchwise=self.batchwise, - messy=self.messy) + partial_edge_blocks=self.partial_edge_blocks) def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 8fe6900299..776bd216f5 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -328,7 +328,7 @@ def make_blocks( drop_block_rate: float = 0., drop_path_rate: float = 0., drop_block_batchwise: bool = False, - drop_block_messy: bool = True, + drop_block_partial_edge_blocks: bool = True, **kwargs, ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]: """Create ResNet stages with specified block configurations. @@ -344,7 +344,7 @@ def make_blocks( avg_down: Use average pooling for downsample. drop_block_rate: DropBlock drop rate. drop_block_batchwise: Batchwise block dropping, faster. - drop_block_messy: dropping produces partial blocks on the edge, faster. + drop_block_partial_edge_blocks: dropping produces partial blocks on the edge, faster. drop_path_rate: Drop path rate for stochastic depth. **kwargs: Additional arguments passed to block constructors. @@ -364,7 +364,7 @@ def make_blocks( drop_blocks( drop_prob=drop_block_rate, batchwise=drop_block_batchwise, - messy=drop_block_messy, + partial_edge_blocks=drop_block_partial_edge_blocks, ))): stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it stride = 1 if stage_idx == 0 else 2 @@ -468,7 +468,7 @@ def __init__( drop_path_rate: float = 0., drop_block_rate: float = 0., drop_block_batchwise: bool = True, - drop_block_messy: bool = True, + drop_block_partial_edge_blocks: bool = True, zero_init_last: bool = True, block_args: Optional[Dict[str, Any]] = None, ): @@ -500,7 +500,7 @@ def __init__( drop_path_rate (float): Stochastic depth drop-path rate (default 0.) drop_block_rate (float): Drop block rate (default 0.) drop_block_batchwise (bool): Sample blocks batchwise, faster. - drop_block_messy (bool): Partial block dropping at the edges, faster. + drop_block_partial_edge_blocks (bool): Partial block dropping at the edges, faster. zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) block_args (dict): Extra kwargs to pass through to block module """ @@ -572,7 +572,7 @@ def __init__( aa_layer=aa_layer, drop_block_rate=drop_block_rate, drop_block_batchwise=drop_block_batchwise, - drop_block_messy=drop_block_messy, + drop_block_partial_edge_blocks=drop_block_partial_edge_blocks, drop_path_rate=drop_path_rate, **block_args, ) @@ -1470,7 +1470,7 @@ def resnet10t_dropblock_correct(pretrained: bool = False, **kwargs) -> ResNet: avg_down=True, drop_block_rate=0.05, drop_block_batchwise=True, - drop_block_messy=True, + drop_block_partial_edge_blocks=True, ) return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) @@ -1486,7 +1486,7 @@ def resnet10t_dropblock_fast(pretrained: bool = False, **kwargs) -> ResNet: avg_down=True, drop_block_rate=0.05, drop_block_batchwise=False, - drop_block_messy=False, + drop_block_partial_edge_blocks=False, ) return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) From a7a3186f8f861d7ecf095543a7d3bf6fe65d6792 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Thu, 21 Aug 2025 10:16:51 -0700 Subject: [PATCH 09/14] format --- tests/layers/test_drop.py | 133 +++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 66 deletions(-) diff --git a/tests/layers/test_drop.py b/tests/layers/test_drop.py index b81abafb31..8f0af005a3 100644 --- a/tests/layers/test_drop.py +++ b/tests/layers/test_drop.py @@ -6,10 +6,11 @@ from timm.layers import drop -torch_backend = os.environ.get('TORCH_BACKEND') +torch_backend = os.environ.get("TORCH_BACKEND") if torch_backend is not None: importlib.import_module(torch_backend) -torch_device = os.environ.get('TORCH_DEVICE', 'cpu') +torch_device = os.environ.get("TORCH_DEVICE", "cpu") + class Conv2dKernelMidpointMask2d(unittest.TestCase): def test_conv2d_kernel_midpoint_mask_odd(self): @@ -21,15 +22,13 @@ def test_conv2d_kernel_midpoint_mask_odd(self): ) print(mask) assert mask.device == torch.device(torch_device) - assert mask.tolist() == \ - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ] - + assert mask.tolist() == [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] def test_conv2d_kernel_midpoint_mask_even(self): mask = drop.conv2d_kernel_midpoint_mask( @@ -40,14 +39,13 @@ def test_conv2d_kernel_midpoint_mask_even(self): ) print(mask) assert mask.device == torch.device(torch_device) - assert mask.tolist() == \ - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - ] + assert mask.tolist() == [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ] def test_clip_mask_2d_kernel_too_big(self): try: @@ -65,67 +63,70 @@ def test_clip_mask_2d_kernel_too_big(self): class DropBlock2dDropFilterTest(unittest.TestCase): def test_drop_filter(self): - selection = torch.tensor( - [ - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - ], - device=torch_device, - ).unsqueeze(0).unsqueeze(0) + selection = ( + torch.tensor( + [ + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=torch_device, + ) + .unsqueeze(0) + .unsqueeze(0) + ) result = drop.drop_block_2d_drop_filter_( - selection=selection, - kernel=(2, 3), - partial_edge_blocks=False + selection=selection, kernel=(2, 3), partial_edge_blocks=False ).squeeze() print(result) assert result.device == torch.device(torch_device) - assert result.tolist() == \ - [ - [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ] + assert result.tolist() == [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] def test_drop_filter_partial_edge_blocks(self): - selection = torch.tensor( - [ - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - ], - device=torch_device, - dtype=torch.float32, - ).unsqueeze(0).unsqueeze(0) + selection = ( + torch.tensor( + [ + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=torch_device, + dtype=torch.float32, + ) + .unsqueeze(0) + .unsqueeze(0) + ) result = drop.drop_block_2d_drop_filter_( - selection=selection, - kernel=(2, 3), - partial_edge_blocks=True + selection=selection, kernel=(2, 3), partial_edge_blocks=True ).squeeze() print(result) assert result.device == torch.device(torch_device) - assert result.tolist() == \ - [ - [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], - ] + assert result.tolist() == [ + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + ] + class DropBlock2dTest(unittest.TestCase): def test_drop_block_2d(self): tensor = torch.ones((1, 1, 200, 300), device=torch_device) - drop_prob=0.1 + drop_prob = 0.1 keep_prob = 1.0 - drop_prob result = drop.drop_block_2d( @@ -138,6 +139,6 @@ def test_drop_block_2d(self): unchanged = float(len(result[result == 1.0])) keep_ratio = unchanged / numel - assert abs(keep_ratio - keep_prob) < 0.05, \ - f"abs({keep_ratio=} - {keep_prob=}) ! < 0.05" - + assert ( + abs(keep_ratio - keep_prob) < 0.05 + ), f"abs({keep_ratio=} - {keep_prob=}) ! < 0.05" From d77d77ad4fa1d07ffe4f963176bfc062dce635b9 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Thu, 21 Aug 2025 10:17:37 -0700 Subject: [PATCH 10/14] format --- timm/layers/drop.py | 98 +++++++++++++++++++++++---------------------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 11e7989a52..1485d9b191 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -1,4 +1,4 @@ -""" DropBlock, DropPath +"""DropBlock, DropPath PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. @@ -14,6 +14,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ + from typing import Optional, Tuple import torch import torch.nn as nn @@ -21,11 +22,11 @@ def conv2d_kernel_midpoint_mask( - *, - shape: Tuple[int, int], - kernel: Tuple[int, int], - device, - dtype, + *, + shape: Tuple[int, int], + kernel: Tuple[int, int], + device, + dtype, ): """Build a mask of kernel midpoints. @@ -55,16 +56,13 @@ def conv2d_kernel_midpoint_mask( mask = torch.zeros(shape, dtype=dtype, device=device) - mask[kh//2: h - ((kh - 1) // 2), kw//2: w - ((kw - 1) // 2)] = 1.0 + mask[kh // 2 : h - ((kh - 1) // 2), kw // 2 : w - ((kw - 1) // 2)] = 1.0 return mask def drop_block_2d_drop_filter_( - *, - selection, - kernel: Tuple[int, int], - partial_edge_blocks: bool + *, selection, kernel: Tuple[int, int], partial_edge_blocks: bool ): """Convert drop block gamma noise to a drop filter. @@ -98,20 +96,20 @@ def drop_block_2d_drop_filter_( padding=[kh // 2, kw // 2], ) if (kh % 2 == 0) or (kw % 2 == 0): - drop_filter = drop_filter[..., (kh%2==0):, (kw%2==0):] + drop_filter = drop_filter[..., (kh % 2 == 0) :, (kw % 2 == 0) :] return drop_filter def drop_block_2d( - x, - drop_prob: float = 0.1, - block_size: int = 7, - gamma_scale: float = 1.0, - with_noise: bool = False, - inplace: bool = False, - batchwise: bool = False, - partial_edge_blocks: bool = False, + x, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, + partial_edge_blocks: bool = False, ): """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf @@ -147,11 +145,9 @@ def drop_block_2d( # batchwise => one mask for whole batch, quite a bit faster mask_shape = (1 if batchwise else B, C, H, W) - selection = torch.empty( - mask_shape, - dtype=x.dtype, - device=x.device - ).bernoulli_(gamma) + selection = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_( + gamma + ) drop_filter = drop_block_2d_drop_filter_( selection=selection, @@ -190,14 +186,14 @@ def drop_block_2d( def drop_block_fast_2d( - x: torch.Tensor, - drop_prob: float = 0.1, - block_size: int = 7, - gamma_scale: float = 1.0, - with_noise: bool = False, - inplace: bool = False, + x: torch.Tensor, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, ): - """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid block mask at edges. @@ -226,6 +222,7 @@ class DropBlock2d(nn.Module): batchwise: should the entire batch use the same drop mask? partial_edge_blocks: partial-blocks at the edges, faster. """ + drop_prob: float block_size: int gamma_scale: float @@ -235,14 +232,14 @@ class DropBlock2d(nn.Module): partial_edge_blocks: bool def __init__( - self, - drop_prob: float = 0.1, - block_size: int = 7, - gamma_scale: float = 1.0, - with_noise: bool = False, - inplace: bool = False, - batchwise: bool = False, - partial_edge_blocks: bool = True, + self, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, + partial_edge_blocks: bool = True, ): super(DropBlock2d, self).__init__() self.drop_prob = drop_prob @@ -265,10 +262,13 @@ def forward(self, x): with_noise=self.with_noise, inplace=self.inplace, batchwise=self.batchwise, - partial_edge_blocks=self.partial_edge_blocks) + partial_edge_blocks=self.partial_edge_blocks, + ) -def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, @@ -278,10 +278,12 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b 'survival rate' as the argument. """ - if drop_prob == 0. or not training: + if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0 and scale_by_keep: random_tensor.div_(keep_prob) @@ -289,9 +291,9 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - """ - def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep @@ -300,4 +302,4 @@ def forward(self, x): return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) def extra_repr(self): - return f'drop_prob={round(self.drop_prob,3):0.3f}' + return f"drop_prob={round(self.drop_prob,3):0.3f}" From dc887adfa6018830820d498af495bac2244af2b8 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Thu, 21 Aug 2025 10:20:47 -0700 Subject: [PATCH 11/14] format --- timm/layers/drop.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 1485d9b191..bf347884f8 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -56,7 +56,10 @@ def conv2d_kernel_midpoint_mask( mask = torch.zeros(shape, dtype=dtype, device=device) - mask[kh // 2 : h - ((kh - 1) // 2), kw // 2 : w - ((kw - 1) // 2)] = 1.0 + mask[ + kh // 2 : h - ((kh - 1) // 2), + kw // 2 : w - ((kw - 1) // 2), + ] = 1.0 return mask @@ -145,9 +148,11 @@ def drop_block_2d( # batchwise => one mask for whole batch, quite a bit faster mask_shape = (1 if batchwise else B, C, H, W) - selection = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_( - gamma - ) + selection = torch.empty( + mask_shape, + dtype=x.dtype, + device=x.device, + ).bernoulli_(gamma) drop_filter = drop_block_2d_drop_filter_( selection=selection, @@ -163,7 +168,11 @@ def drop_block_2d( if with_noise: # x += (noise * (1 - block_mask)) - noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device) + noise = torch.randn( + mask_shape, + dtype=x.dtype, + device=x.device, + ) if inplace: noise.mul_(drop_filter) @@ -281,10 +290,12 @@ def drop_path( if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * ( - x.ndim - 1 - ) # work with diff dim tensors, not just 2D ConvNets + + # work with diff dim tensors, not just 2D ConvNets + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: random_tensor.div_(keep_prob) return x * random_tensor From 8521edc5976b637f653a927aa2cd80152f33512f Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Thu, 21 Aug 2025 10:22:45 -0700 Subject: [PATCH 12/14] format --- timm/layers/drop.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/timm/layers/drop.py b/timm/layers/drop.py index bf347884f8..284d08a6bc 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -65,7 +65,10 @@ def conv2d_kernel_midpoint_mask( def drop_block_2d_drop_filter_( - *, selection, kernel: Tuple[int, int], partial_edge_blocks: bool + *, + selection, + kernel: Tuple[int, int], + partial_edge_blocks: bool, ): """Convert drop block gamma noise to a drop filter. @@ -81,7 +84,6 @@ def drop_block_2d_drop_filter_( Returns: A drop filter, `1.0` at points to drop, `0.0` at points to keep. """ - if not partial_edge_blocks: selection = selection * conv2d_kernel_midpoint_mask( shape=selection.shape[-2:], @@ -276,7 +278,10 @@ def forward(self, x): def drop_path( - x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True + x, + drop_prob: float = 0.0, + training: bool = False, + scale_by_keep: bool = True, ): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). @@ -304,13 +309,22 @@ def drop_path( class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + def __init__( + self, + drop_prob: float = 0.0, + scale_by_keep: bool = True, + ): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep def forward(self, x): - return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + return drop_path( + x, + drop_prob=self.drop_prob, + training=self.training, + scale_by_keep=self.scale_by_keep, + ) def extra_repr(self): return f"drop_prob={round(self.drop_prob,3):0.3f}" From b2fe962c29ecb874ef569740034c83052db17076 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Thu, 21 Aug 2025 10:50:50 -0700 Subject: [PATCH 13/14] inplace speedups --- timm/layers/drop.py | 109 +++++++++++++++++++++++++------------------- 1 file changed, 63 insertions(+), 46 deletions(-) diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 284d08a6bc..78e28f782a 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -23,10 +23,11 @@ def conv2d_kernel_midpoint_mask( *, - shape: Tuple[int, int], kernel: Tuple[int, int], - device, - dtype, + inplace=None, + shape: Optional[Tuple[int, int]] = None, + device=None, + dtype=None, ): """Build a mask of kernel midpoints. @@ -43,6 +44,7 @@ def conv2d_kernel_midpoint_mask( Args: kernel: the (kh, hw) shape of the kernel. + inplace: use the provided tensor as the mask; set masked-out values to 0. shape: the (h, w) shape of the tensor. device: the target device. dtype: the target dtype. @@ -50,16 +52,30 @@ def conv2d_kernel_midpoint_mask( Returns: a (h, w) bool mask tensor. """ + if inplace is None: + assert shape is not None, f"shape is required when inplace is None." + assert dtype is not None, f"dtype is required when inplace is None." + assert device is not None, f"device is required when inplace is None." + + mask = torch.ones(shape, dtype=dtype, device=device) + else: + assert shape is None, f"shape and inplace are incompatile" + assert dtype is None, f"dtype and inplace are incompatile" + assert device is None, f"device and inplace are incompatile" + + mask = inplace + shape = inplace.shape[-2:] + device = inplace.device + dtype = inplace.dtype + h, w = shape kh, kw = kernel assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}" - mask = torch.zeros(shape, dtype=dtype, device=device) - - mask[ - kh // 2 : h - ((kh - 1) // 2), - kw // 2 : w - ((kw - 1) // 2), - ] = 1.0 + mask[..., 0 : kh // 2, :] = 0 + mask[..., :, 0 : kw // 2 :] = 0 + mask[..., h - ((kh - 1) // 2) :, :] = 0 + mask[..., :, w - ((kw - 1) // 2) :] = 0 return mask @@ -67,6 +83,7 @@ def conv2d_kernel_midpoint_mask( def drop_block_2d_drop_filter_( *, selection, + inplace: bool = False, kernel: Tuple[int, int], partial_edge_blocks: bool, ): @@ -78,6 +95,7 @@ def drop_block_2d_drop_filter_( selection: 4D (B, C, H, W) input selection noise; `1.0` at the midpoints of selected blocks to drop, `0.0` everywhere else. Expected to be gamma noise. + inplace: permit in-place updates to `selection`. kernel: the shape of the 2d kernel. partial_edge_blocks: permit partial blocks at the edges, faster. @@ -85,12 +103,18 @@ def drop_block_2d_drop_filter_( A drop filter, `1.0` at points to drop, `0.0` at points to keep. """ if not partial_edge_blocks: - selection = selection * conv2d_kernel_midpoint_mask( - shape=selection.shape[-2:], - kernel=kernel, - dtype=selection.dtype, - device=selection.device, - ) + if inplace: + selection = conv2d_kernel_midpoint_mask( + kernel=kernel, + inplace=selection, + ) + else: + selection = selection * conv2d_kernel_midpoint_mask( + shape=selection.shape[-2:], + kernel=kernel, + dtype=selection.dtype, + device=selection.device, + ) kh, kw = kernel @@ -136,62 +160,55 @@ def drop_block_2d( B, C, H, W = x.shape # TODO: This behaves oddly when clipped_block_size < block_size. - kh = kw = block_size - - kernel = [min(kh, H), min(kw, W)] + # We could expose non-square blocks above this layer. + kernel = [min(block_size, H), min(block_size, W)] kh, kw = kernel + # batchwise => one mask for whole batch, quite a bit faster + noise_shape = (1 if batchwise else B, C, H, W) + gamma = ( float(gamma_scale * drop_prob * H * W) / float(kh * kw) / float((H - kh + 1) * (W - kw + 1)) ) - # batchwise => one mask for whole batch, quite a bit faster - mask_shape = (1 if batchwise else B, C, H, W) - - selection = torch.empty( - mask_shape, - dtype=x.dtype, - device=x.device, - ).bernoulli_(gamma) - drop_filter = drop_block_2d_drop_filter_( - selection=selection, kernel=kernel, partial_edge_blocks=partial_edge_blocks, + inplace=True, + selection=torch.empty( + noise_shape, + dtype=x.dtype, + device=x.device, + ).bernoulli_(gamma), ) keep_filter = 1.0 - drop_filter - if inplace: - x.mul_(keep_filter) - else: - x = x * keep_filter - if with_noise: - # x += (noise * (1 - block_mask)) - noise = torch.randn( - mask_shape, - dtype=x.dtype, - device=x.device, - ) + # x += (noise * drop_filter) + drop_noise = torch.randn_like(drop_filter) + drop_noise.mul_(drop_filter) if inplace: - noise.mul_(drop_filter) - x.add_(noise) + x.mul_(keep_filter) + x.add_(drop_noise) + else: - x = x + noise * drop_filter + x = x * keep_filter + drop_noise else: - # x *= (size(block_mask) / sum(block_mask)) + # x *= (size(keep_filter) / (sum(keep_filter) + eps)) count = keep_filter.numel() total = keep_filter.to(dtype=torch.float32).sum() - normalize_scale = count / total.add(1e-7).to(x.dtype) + keep_scale = count / total.add(1e-7).to(x.dtype) + + keep_filter.mul_(keep_scale) if inplace: - x.mul_(normalize_scale) + x.mul_(keep_filter) else: - x = x * normalize_scale + x = x * keep_filter return x From f195cc32ac43417b0dea9672457556d0582b3822 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Mon, 25 Aug 2025 15:01:31 -0700 Subject: [PATCH 14/14] couple_channels; no_grad --- timm/layers/drop.py | 59 +++++++++++++++++++++++++------------------ timm/models/resnet.py | 31 +++++++++++++++-------- 2 files changed, 54 insertions(+), 36 deletions(-) diff --git a/timm/layers/drop.py b/timm/layers/drop.py index 78e28f782a..e784edad99 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -138,6 +138,7 @@ def drop_block_2d( with_noise: bool = False, inplace: bool = False, batchwise: bool = False, + couple_channels: bool = False, partial_edge_blocks: bool = False, ): """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf @@ -151,8 +152,10 @@ def drop_block_2d( gamma_scale: adjustment scale for the drop_prob. with_noise: should normal noise be added to the dropped region? inplace: if the drop should be applied in-place on the input tensor. - batchwise: should the entire batch use the same drop mask? - partial_edge_blocks: partial-blocks at the edges, faster. + batchwise: when true, the entire batch is shares the same drop mask; much faster. + couple_channels: when true, channels share the same drop mask; + much faster, with significant semantic impact. + partial_edge_blocks: partial-blocks at the edges; minor speedup, minor semantic impact. Returns: If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device. @@ -165,7 +168,7 @@ def drop_block_2d( kh, kw = kernel # batchwise => one mask for whole batch, quite a bit faster - noise_shape = (1 if batchwise else B, C, H, W) + noise_shape = (1 if batchwise else B, 1 if couple_channels else C, H, W) gamma = ( float(gamma_scale * drop_prob * H * W) @@ -173,37 +176,39 @@ def drop_block_2d( / float((H - kh + 1) * (W - kw + 1)) ) - drop_filter = drop_block_2d_drop_filter_( - kernel=kernel, - partial_edge_blocks=partial_edge_blocks, - inplace=True, - selection=torch.empty( - noise_shape, - dtype=x.dtype, - device=x.device, - ).bernoulli_(gamma), - ) - keep_filter = 1.0 - drop_filter + with torch.no_grad(): + drop_filter = drop_block_2d_drop_filter_( + kernel=kernel, + partial_edge_blocks=partial_edge_blocks, + inplace=True, + selection=torch.empty( + noise_shape, + dtype=x.dtype, + device=x.device, + ).bernoulli_(gamma), + ) + keep_filter = 1.0 - drop_filter if with_noise: # x += (noise * drop_filter) - drop_noise = torch.randn_like(drop_filter) - drop_noise.mul_(drop_filter) + with torch.no_grad(): + drop_noise = torch.randn_like(drop_filter) + drop_noise.mul_(drop_filter) if inplace: x.mul_(keep_filter) x.add_(drop_noise) - else: x = x * keep_filter + drop_noise else: # x *= (size(keep_filter) / (sum(keep_filter) + eps)) - count = keep_filter.numel() - total = keep_filter.to(dtype=torch.float32).sum() - keep_scale = count / total.add(1e-7).to(x.dtype) + with torch.no_grad(): + count = keep_filter.numel() + total = keep_filter.to(dtype=torch.float32).sum() + keep_scale = count / total.add(1e-7).to(x.dtype) - keep_filter.mul_(keep_scale) + keep_filter.mul_(keep_scale) if inplace: x.mul_(keep_filter) @@ -247,8 +252,10 @@ class DropBlock2d(nn.Module): gamma_scale: adjustment scale for the drop_prob. with_noise: should normal noise be added to the dropped region? inplace: if the drop should be applied in-place on the input tensor. - batchwise: should the entire batch use the same drop mask? - partial_edge_blocks: partial-blocks at the edges, faster. + batchwise: when true, the entire batch is shares the same drop mask; much faster. + couple_channels: when true, channels share the same drop mask; + much faster, with significant semantic impact. + partial_edge_blocks: partial-blocks at the edges; minor speedup, minor semantic impact. """ drop_prob: float @@ -257,6 +264,7 @@ class DropBlock2d(nn.Module): with_noise: bool inplace: bool batchwise: bool + couple_channels: bool partial_edge_blocks: bool def __init__( @@ -266,8 +274,9 @@ def __init__( gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, - batchwise: bool = False, - partial_edge_blocks: bool = True, + batchwise: bool = True, + couple_channels: bool = False, + partial_edge_blocks: bool = False, ): super(DropBlock2d, self).__init__() self.drop_prob = drop_prob diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 776bd216f5..54aad03576 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -326,9 +326,10 @@ def make_blocks( down_kernel_size: int = 1, avg_down: bool = False, drop_block_rate: float = 0., - drop_path_rate: float = 0., - drop_block_batchwise: bool = False, + drop_block_batchwise: bool = True, + drop_block_couple_channels: bool = False, drop_block_partial_edge_blocks: bool = True, + drop_path_rate: float = 0., **kwargs, ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]: """Create ResNet stages with specified block configurations. @@ -343,8 +344,10 @@ def make_blocks( down_kernel_size: Kernel size for downsample layers. avg_down: Use average pooling for downsample. drop_block_rate: DropBlock drop rate. - drop_block_batchwise: Batchwise block dropping, faster. - drop_block_partial_edge_blocks: dropping produces partial blocks on the edge, faster. + drop_block_batchwise: Batchwise block dropping, much faster. + drop_block_couple_channels: Couple channel drops. + drop_block_partial_edge_blocks: Permit partial drop blocks on the edge, + slightly faster. drop_path_rate: Drop path rate for stochastic depth. **kwargs: Additional arguments passed to block constructors. @@ -364,6 +367,7 @@ def make_blocks( drop_blocks( drop_prob=drop_block_rate, batchwise=drop_block_batchwise, + couple_channels=drop_block_couple_channels, partial_edge_blocks=drop_block_partial_edge_blocks, ))): stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it @@ -465,10 +469,11 @@ def __init__( norm_layer: LayerType = nn.BatchNorm2d, aa_layer: Optional[Type[nn.Module]] = None, drop_rate: float = 0.0, - drop_path_rate: float = 0., drop_block_rate: float = 0., drop_block_batchwise: bool = True, + drop_block_couple_channels: bool = False, drop_block_partial_edge_blocks: bool = True, + drop_path_rate: float = 0., zero_init_last: bool = True, block_args: Optional[Dict[str, Any]] = None, ): @@ -497,10 +502,11 @@ def __init__( norm_layer (str, nn.Module): normalization layer aa_layer (nn.Module): anti-aliasing layer drop_rate (float): Dropout probability before classifier, for training (default 0.) - drop_path_rate (float): Stochastic depth drop-path rate (default 0.) drop_block_rate (float): Drop block rate (default 0.) - drop_block_batchwise (bool): Sample blocks batchwise, faster. + drop_block_batchwise (bool): Sample blocks batchwise, significantly faster. + drop_block_couple_channels (bool): couple channels when dropping blocks. drop_block_partial_edge_blocks (bool): Partial block dropping at the edges, faster. + drop_path_rate (float): Stochastic depth drop-path rate (default 0.) zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) block_args (dict): Extra kwargs to pass through to block module """ @@ -572,6 +578,7 @@ def __init__( aa_layer=aa_layer, drop_block_rate=drop_block_rate, drop_block_batchwise=drop_block_batchwise, + drop_block_couple_channels=drop_block_couple_channels, drop_block_partial_edge_blocks=drop_block_partial_edge_blocks, drop_path_rate=drop_path_rate, **block_args, @@ -1459,8 +1466,8 @@ def resnet10t(pretrained: bool = False, **kwargs) -> ResNet: return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) @register_model -def resnet10t_dropblock_correct(pretrained: bool = False, **kwargs) -> ResNet: - """Constructs a ResNet-10-T model with drop_block_rate=0.05, using the most accurate DropBlock2d features. +def resnet10t_dropblock_slow(pretrained: bool = False, **kwargs) -> ResNet: + """Constructs a ResNet-10-T model with drop_block_rate=0.05, using the slowest DropBlock2d features. """ model_args = dict( block=BasicBlock, @@ -1469,7 +1476,8 @@ def resnet10t_dropblock_correct(pretrained: bool = False, **kwargs) -> ResNet: stem_type='deep_tiered', avg_down=True, drop_block_rate=0.05, - drop_block_batchwise=True, + drop_block_batchwise=False, + drop_block_couple_channels=False, drop_block_partial_edge_blocks=True, ) return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) @@ -1485,7 +1493,8 @@ def resnet10t_dropblock_fast(pretrained: bool = False, **kwargs) -> ResNet: stem_type='deep_tiered', avg_down=True, drop_block_rate=0.05, - drop_block_batchwise=False, + drop_block_batchwise=True, + drop_block_couple_channels=True, drop_block_partial_edge_blocks=False, ) return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))