Skip to content

Commit c227aa0

Browse files
committed
fix: to_dtype_cvcuda conventions
1 parent 20aa030 commit c227aa0

File tree

3 files changed

+27
-33
lines changed

3 files changed

+27
-33
lines changed

test/test_transforms_v2.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2628,7 +2628,17 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca
26282628
scale=scale,
26292629
)
26302630

2631-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
2631+
@pytest.mark.parametrize(
2632+
"make_input",
2633+
[
2634+
make_image_tensor,
2635+
make_image,
2636+
make_video,
2637+
pytest.param(
2638+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
2639+
),
2640+
],
2641+
)
26322642
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
26332643
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
26342644
@pytest.mark.parametrize("device", cpu_and_cuda())
@@ -2643,7 +2653,16 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale):
26432653

26442654
@pytest.mark.parametrize(
26452655
"make_input",
2646-
[make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
2656+
[
2657+
make_image_tensor,
2658+
make_image,
2659+
make_bounding_boxes,
2660+
make_segmentation_mask,
2661+
make_video,
2662+
pytest.param(
2663+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
2664+
),
2665+
],
26472666
)
26482667
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
26492668
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@@ -2807,38 +2826,12 @@ def test_uint16(self):
28072826
assert_equal(F.to_dtype(img_float32, torch.uint8, scale=True), img_uint8)
28082827
assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2)
28092828

2810-
2811-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda is not available")
2812-
@needs_cuda
2813-
class TestToDtypeCVCUDA:
2814-
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
2815-
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
2816-
@pytest.mark.parametrize("device", cpu_and_cuda())
2817-
@pytest.mark.parametrize("scale", (True, False))
2818-
def test_functional(self, input_dtype, output_dtype, device, scale):
2819-
check_functional(
2820-
F.to_dtype,
2821-
make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device),
2822-
dtype=output_dtype,
2823-
scale=scale,
2824-
)
2825-
2826-
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
2827-
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
2828-
@pytest.mark.parametrize("device", cpu_and_cuda())
2829-
@pytest.mark.parametrize("scale", (True, False))
2830-
@pytest.mark.parametrize("as_dict", (True, False))
2831-
def test_transform(self, input_dtype, output_dtype, device, scale, as_dict):
2832-
cvc_input = make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device)
2833-
if as_dict:
2834-
output_dtype = {type(cvc_input): output_dtype}
2835-
check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), cvc_input, check_sample_input=not as_dict)
2836-
2829+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
28372830
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
28382831
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
28392832
@pytest.mark.parametrize("device", cpu_and_cuda())
28402833
@pytest.mark.parametrize("scale", (True, False))
2841-
def test_image_correctness(self, input_dtype, output_dtype, device, scale):
2834+
def test_cvcuda_parity(self, input_dtype, output_dtype, device, scale):
28422835
if input_dtype.is_floating_point and output_dtype == torch.int64:
28432836
pytest.xfail("float to int64 conversion is not supported")
28442837
if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":

torchvision/transforms/v2/functional/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@
158158
sanitize_bounding_boxes,
159159
sanitize_keypoints,
160160
to_dtype,
161-
to_dtype_cvcuda,
162161
to_dtype_image,
163162
to_dtype_video,
164163
)

torchvision/transforms/v2/functional/_misc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,13 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo
371371
_cvcuda_to_torch_dtypes = {v: k for k, v in _torch_to_cvcuda_dtypes.items()}
372372

373373

374-
def to_dtype_cvcuda(
374+
def _to_dtype_cvcuda(
375375
inpt: "cvcuda.Tensor",
376376
dtype: torch.dtype,
377377
scale: bool = False,
378378
) -> "cvcuda.Tensor":
379+
cvcuda = _import_cvcuda()
380+
379381
dtype_in = _cvcuda_to_torch_dtypes[inpt.dtype]
380382
cvc_dtype = _torch_to_cvcuda_dtypes[dtype]
381383

@@ -409,7 +411,7 @@ def to_dtype_cvcuda(
409411

410412

411413
if CVCUDA_AVAILABLE:
412-
_register_kernel_internal(to_dtype, cvcuda.Tensor)(to_dtype_cvcuda)
414+
_to_dtype_cvcuda_registered = _register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_cvcuda)
413415

414416

415417
def sanitize_bounding_boxes(

0 commit comments

Comments
 (0)