Skip to content

Commit c16a033

Browse files
committed
remove extra parameterize for dtype
1 parent f1bb502 commit c16a033

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

test/test_transforms_v2.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5608,10 +5608,9 @@ def test_correctness_image(self, mean, std, dtype, fn):
56085608

56095609
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
56105610
@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
5611-
@pytest.mark.parametrize("dtype", [torch.float32])
56125611
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
5613-
def test_correctness_cvcuda(self, mean, std, dtype, fn):
5614-
image = make_image(batch_dims=(1,), dtype=dtype, device="cuda")
5612+
def test_correctness_cvcuda(self, mean, std, fn):
5613+
image = make_image(batch_dims=(1,), dtype=torch.float32, device="cuda")
56155614
cvc_image = F.to_cvcuda_tensor(image)
56165615
actual = F._misc._normalize_cvcuda(cvc_image, mean=mean, std=std)
56175616
expected = fn(image, mean=mean, std=std)

0 commit comments

Comments
 (0)