Fix bug in timm.layers.drop.drop_block_2d; unify fast/slow versions; add model and unit tests. #2569

Open
wants to merge 13 commits into base: main

Conversation

crutcher

There are two bugs in the `valid_block` code for `drop_block_2d`.
- a (W, H) grid being reshaped as (H, W)

The current code uses (W, H) to generate the meshgrid,
but then uses `.reshape((1, 1, H, W))` to unsqueeze the block map.

The simplest fix to the first bug is a one-line change:
```python
h_i, w_i = ndgrid(torch.arange(H), torch.arange(W))
```

This is a longer patch that attempts to make the code testable.

Note: the current code behaves oddly when `block_size` or
`clipped_block_size` is even; I've added tests exposing the behavior,
but have not changed it.

When you trigger the reshape bug, you get wild results:
```
$ python scratch.py
{'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': False}
grid.shape=torch.Size([1, 1, 4, 5])
tensor([[[[False, False, False, False, False],
          [ True,  True, False, False,  True],
          [ True, False, False,  True,  True],
          [False, False, False, False, False]]]])

{'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': True}
grid.shape=torch.Size([1, 1, 4, 5])
tensor([[[[False, False, False, False, False],
          [False,  True,  True,  True, False],
          [False,  True,  True,  True, False],
          [False, False, False, False, False]]]])
```

Here's a tiny excerpt script showing the problem;
it generated the output above.

```python
import torch
from typing import Tuple

def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
    """generate N-D grid in dimension order.

    The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.

    That is, the statement
    [X1,X2,X3] = ndgrid(x1,x2,x3)

    produces the same result as

    [X2,X1,X3] = meshgrid(x2,x1,x3)

    This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
    torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').

    """
    try:
        return torch.meshgrid(*tensors, indexing='ij')
    except TypeError:
        # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
        # the old behaviour of meshgrid was 'ij'
        return torch.meshgrid(*tensors)

def valid_block(H, W, block_size, fix_reshape=False):
    clipped_block_size = min(block_size, H, W)

    if fix_reshape:
        # This should match the .reshape() dimension order below.
        h_i, w_i = ndgrid(torch.arange(H), torch.arange(W))
    else:
        # The original produces crazy stride patterns, due to .reshape() offset winding.
        # This is only visible when H != W.
        w_i, h_i = ndgrid(torch.arange(W), torch.arange(H))

    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                 ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))

    valid_block = torch.reshape(valid_block, (1, 1, H, W))

    return valid_block

def main():
    common_args = dict(H=4, W=5, block_size=3)

    for fix in [False, True]:
        args = dict(**common_args, fix_reshape=fix)
        grid = valid_block(**args)
        print(args)
        print(f"{grid.shape=}")
        print(grid)
        print()

if __name__ == "__main__":
    main()
```
@crutcher
Author

I realized I'd been so focused on fixing the bug that I hadn't noticed all of the meshgrid stuff was entirely unneeded. Switched to slice assignment; it does the same thing.
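
For reference, a minimal sketch of the slice-assignment form (the helper name here is illustrative, not the exact patch): the meshgrid + comparison + reshape dance reduces to writing True into the interior rectangle of a zero mask.

```python
import torch

def valid_block_slice(H: int, W: int, block_size: int) -> torch.Tensor:
    # Same semantics as the meshgrid version in the script above: True at
    # positions where a clipped block centered there stays inside the map.
    clipped_block_size = min(block_size, H, W)
    lo = clipped_block_size // 2            # top/left margin
    hi = (clipped_block_size - 1) // 2      # bottom/right margin
    mask = torch.zeros((1, 1, H, W), dtype=torch.bool)
    mask[:, :, lo:H - hi, lo:W - hi] = True
    return mask
```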

@rwightman
Collaborator

@crutcher indeed, yeah, the original should have flipped the assignment with ndgrid or used meshgrid + indexing='xy' ... but slice assignment is clearer. Additionally, with slice assignment there's no point in using bool; it should create a zeros array with x.dtype and assign 1.0, to avoid another allocation + dtype conversion.

I think there's also technically a problem with even block sizes and the padding + feature map sizing. I believe max pool w/ 'same' padding (asymmetric padding) is needed for even blocks to work, no?
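
A hedged sketch of that dtype-aware variant (the helper name and signature are assumptions, not the actual change): allocate the mask directly in `x.dtype` on `x.device` and write 1.0 into the valid region, so no extra bool tensor or cast is needed when it multiplies into the drop mask.

```python
import torch

def valid_block_like(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # Hypothetical helper: same interior-rectangle mask as above, but built
    # directly in x's dtype/device with 1.0 instead of a bool True.
    _, _, H, W = x.shape
    clipped_block_size = min(block_size, H, W)
    lo = clipped_block_size // 2
    hi = (clipped_block_size - 1) // 2
    mask = torch.zeros((1, 1, H, W), dtype=x.dtype, device=x.device)
    mask[:, :, lo:H - hi, lo:W - hi] = 1.0
    return mask
```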

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@crutcher
Author

@rwightman I did a deep dive on this yesterday (I'm implementing it for burn, which is how I found this).

The weird behavior with even-sized kernels is correct, but I don't want to submit yet.

max_pool2d (and the conv2d machinery it sits on) has weird but defined behavior about where it considers the midpoint location of even-sized kernels to be. I now (think I) understand the implications, and I want to document (and test) this so that the next person doesn't spend another day on it.
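
As a quick standalone check of the even-kernel sizing issue (not part of the patch): with stride=1, symmetric padding of `k // 2` preserves the spatial size for odd kernels but grows it for even ones, which is why 'same'-style asymmetric padding comes up.

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 1, 4, 5)

# Odd kernel: symmetric padding = k // 2 keeps the spatial size.
y3 = F.max_pool2d(x, kernel_size=3, stride=1, padding=3 // 2)
print(y3.shape)  # torch.Size([1, 1, 4, 5])

# Even kernel: symmetric padding = k // 2 grows each spatial dim by one,
# so symmetric padding alone cannot produce a 'same'-sized output.
y2 = F.max_pool2d(x, kernel_size=2, stride=1, padding=2 // 2)
print(y2.shape)  # torch.Size([1, 1, 5, 6])
```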

@crutcher
Author

@rwightman OK; I think this is in a much better state.

@crutcher changed the title from "Fix bug in timm.layers.drop.drop_block_2d when H != W." to "Fix bug in timm.layers.drop.drop_block_2d; unify fast/slow versions; add model and unit tests." on Aug 19, 2025
@crutcher
Author

I don't know enough about this test base to mask out the failing int/bool dtype tests for jit, so I just killed them.

@rwightman
Collaborator

@crutcher I'm out on vacation for a week and a bit so won't be able to take a closer look re merge for a bit...
