Skip to content

Commit ee626ae

Browse files
committed
simplify test for center crop cvcuda
1 parent 5d5c436 commit ee626ae

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

test/test_transforms_v2.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4979,25 +4979,30 @@ def test_transform(self, make_input):
49794979
check_transform(transforms.CenterCrop(self.OUTPUT_SIZES[0]), make_input(self.INPUT_SIZE))
49804980

49814981
@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
4982+
@pytest.mark.parametrize(
4983+
"make_input",
4984+
[
4985+
make_image,
4986+
pytest.param(
4987+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
4988+
),
4989+
],
4990+
)
49824991
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
4983-
def test_image_correctness(self, output_size, fn):
4984-
image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
4992+
def test_image_correctness(self, output_size, make_input, fn):
4993+
image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
49854994

49864995
actual = fn(image, output_size)
4987-
expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))
49884996

4989-
assert_equal(actual, expected)
4990-
4991-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
4992-
@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
4993-
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
4994-
def test_cvcuda_correctness(self, output_size, fn):
4995-
image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8, device="cuda")
4997+
if make_input == make_image_cvcuda:
4998+
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
4999+
actual = actual.squeeze(0)
5000+
image = F.cvcuda_to_tensor(image).to(device="cpu")
5001+
image = image.squeeze(0)
49965002

4997-
actual = fn(image, output_size)
4998-
expected = F.center_crop(F.cvcuda_to_tensor(image), output_size)
5003+
expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))
49995004

5000-
assert_equal(F.cvcuda_to_tensor(actual), expected)
5005+
assert_equal(actual, expected)
50015006

50025007
def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
50035008
image_height, image_width = bounding_boxes.canvas_size

0 commit comments

Comments
 (0)