Skip to content
Open
Empty file added tests/layers/__init__.py
Empty file.
144 changes: 144 additions & 0 deletions tests/layers/test_drop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import importlib
import os
import unittest

import torch

from timm.layers import drop

torch_backend = os.environ.get("TORCH_BACKEND")
if torch_backend is not None:
importlib.import_module(torch_backend)
torch_device = os.environ.get("TORCH_DEVICE", "cpu")


class Conv2dKernelMidpointMask2d(unittest.TestCase):
def test_conv2d_kernel_midpoint_mask_odd(self):
mask = drop.conv2d_kernel_midpoint_mask(
shape=(5, 7),
kernel=(3, 3),
device=torch_device,
dtype=torch.float32,
)
print(mask)
assert mask.device == torch.device(torch_device)
assert mask.tolist() == [
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]

def test_conv2d_kernel_midpoint_mask_even(self):
mask = drop.conv2d_kernel_midpoint_mask(
shape=(5, 7),
kernel=(2, 2),
device=torch_device,
dtype=torch.float32,
)
print(mask)
assert mask.device == torch.device(torch_device)
assert mask.tolist() == [
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
]

def test_clip_mask_2d_kernel_too_big(self):
try:
drop.conv2d_kernel_midpoint_mask(
shape=(4, 7),
kernel=(5, 5),
device=torch_device,
dtype=torch.float32,
)
raise RuntimeError("Expected throw")

except AssertionError as e:
assert "kernel=(5, 5) ! <= shape=(4, 7)" in e.args[0]


class DropBlock2dDropFilterTest(unittest.TestCase):
def test_drop_filter(self):
selection = (
torch.tensor(
[
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
],
device=torch_device,
)
.unsqueeze(0)
.unsqueeze(0)
)

result = drop.drop_block_2d_drop_filter_(
selection=selection, kernel=(2, 3), partial_edge_blocks=False
).squeeze()
print(result)
assert result.device == torch.device(torch_device)
assert result.tolist() == [
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]

def test_drop_filter_partial_edge_blocks(self):
selection = (
torch.tensor(
[
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
],
device=torch_device,
dtype=torch.float32,
)
.unsqueeze(0)
.unsqueeze(0)
)

result = drop.drop_block_2d_drop_filter_(
selection=selection, kernel=(2, 3), partial_edge_blocks=True
).squeeze()
print(result)
assert result.device == torch.device(torch_device)
assert result.tolist() == [
[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
]


class DropBlock2dTest(unittest.TestCase):
def test_drop_block_2d(self):
tensor = torch.ones((1, 1, 200, 300), device=torch_device)

drop_prob = 0.1
keep_prob = 1.0 - drop_prob

result = drop.drop_block_2d(
tensor,
drop_prob=drop_prob,
with_noise=True,
).squeeze()

numel = float(result.numel())
unchanged = float(len(result[result == 1.0]))
keep_ratio = unchanged / numel

assert (
abs(keep_ratio - keep_prob) < 0.05
), f"abs({keep_ratio=} - {keep_prob=}) ! < 0.05"
4 changes: 2 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_model_forward(model_name, batch_size):
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward(model_name, batch_size):
"""Run a single forward pass with each model"""
"""Run a single forward and backward pass with each model"""
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
if max(input_size) > MAX_BWD_SIZE:
pytest.skip("Fixed input size model > limit.")
Expand Down Expand Up @@ -594,7 +594,7 @@ def _create_fx_model(model, train=False):
return fx_model


EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*']
EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*', '*dropblock*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
Expand Down
Loading