Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
44db71c
implement additional cvcuda infra for all branches to avoid duplicate…
justincdavis Nov 25, 2025
e3dd700
update make_image_cvcuda to have default batch dim
justincdavis Nov 25, 2025
c035df1
add standardized setup to main for easier updating of PRs and branches
justincdavis Dec 2, 2025
98d7dfb
update is_cvcuda_tensor
justincdavis Dec 2, 2025
ddc116d
add cvcuda to pil compatible to transforms by default
justincdavis Dec 2, 2025
e51dc7e
remove cvcuda from transform class
justincdavis Dec 2, 2025
e14e210
merge with main
justincdavis Dec 4, 2025
4939355
resolve more formatting naming
justincdavis Dec 4, 2025
ec76196
initial draft of to_dtype_cvcuda
justincdavis Nov 18, 2025
bd823cf
fix: to_dtype_cvcuda conventions
justincdavis Nov 20, 2025
f7aa94a
remove staticmethod from reference todtype
justincdavis Nov 24, 2025
b21d9f0
add docstring for explain scaling setup, combine correctness checks
justincdavis Nov 24, 2025
973e058
resolve more review comments
justincdavis Nov 24, 2025
d871331
simplify todtype testing
justincdavis Nov 26, 2025
736a2e6
add int -> int scaling setup for cvcuda, use bit diff for scale
justincdavis Nov 26, 2025
7a231b1
further simplify todtype test
justincdavis Nov 26, 2025
d3e4573
update todtype based on PR reviews
justincdavis Dec 2, 2025
ec93ba3
cleanup comment, variable names
justincdavis Dec 2, 2025
89122db
update to_dtype_cvcuda name
justincdavis Dec 4, 2025
1b0d295
update to standards from flip PR
justincdavis Dec 4, 2025
009f925
remove cvcuda updates to augment
justincdavis Dec 4, 2025
41af724
remove cvcuda refs from color
justincdavis Dec 4, 2025
d12e4df
refactor dtype converters to be in utils
justincdavis Dec 4, 2025
c198cf0
add type checking for cvcuda
justincdavis Dec 4, 2025
18df67f
provide better error for todtype
justincdavis Dec 4, 2025
c5a2a5a
refactor to simplify setup for dtype conversions
justincdavis Dec 5, 2025
915ffb1
Merge branch 'main' into feat/dtype_cvcuda
justincdavis Dec 5, 2025
7f41c95
fix: not testing transform class correctness in ToDtype, resolved
justincdavis Dec 5, 2025
9b41552
preserve previous torchvision test behavior for non cvcuda inputs
justincdavis Dec 8, 2025
b9c378b
further simplify branching flow of testtodtype image correctness
justincdavis Dec 8, 2025
1781244
add functional signature tests, fix bug in type check in todtype tran…
justincdavis Dec 8, 2025
626b47a
add consolidated cvcuda test markers
justincdavis Dec 9, 2025
e8540ba
finalize consolidated cvcuda skip behavior
justincdavis Dec 9, 2025
5aa4b3d
revert var name change back to input
justincdavis Dec 9, 2025
7cbf30e
drop the dimensions and num channels variants for cvcuda
justincdavis Dec 12, 2025
93bf674
drop _is_cvcuda_tensor from _utils query_size query_chw unused in thi…
justincdavis Dec 12, 2025
ecf3c58
simplify ToDtype class
justincdavis Dec 12, 2025
8cd76dc
refactor to move the dtype tables to _misc adjacent with to_dtype_ima…
justincdavis Dec 12, 2025
713810f
drop the evergreen cvcuda import at file level
justincdavis Dec 12, 2025
a68b7b5
make atol thresholds clearer and smaller, drop uint16 to uint8 for cv…
justincdavis Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 107 additions & 61 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand All @@ -60,9 +60,11 @@
)


CVCUDA_AVAILABLE = _is_cvcuda_available()
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda()
# Marks reused by the CV-CUDA parametrizations and test classes in this module:
# skip the case when _is_cvcuda_available() reports False, and tag it with the
# ``needs_cuda`` marker.
CV_CUDA_TEST = [
pytest.mark.skipif(not _is_cvcuda_available(), reason="CVCUDA is not available"),
pytest.mark.needs_cuda,
]


# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]
Expand Down Expand Up @@ -1240,10 +1242,7 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note that you should be able to remove these changes once #9305 lands.

make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1259,11 +1258,7 @@ def test_functional(self, make_input):
(F.horizontal_flip_image, torch.Tensor),
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
(F.horizontal_flip_image, tv_tensors.Image),
pytest.param(
F._geometry._horizontal_flip_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(F._geometry._horizontal_flip_image_cvcuda, None, marks=CV_CUDA_TEST),
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.horizontal_flip_mask, tv_tensors.Mask),
(F.horizontal_flip_video, tv_tensors.Video),
Expand All @@ -1281,10 +1276,7 @@ def test_functional_signature(self, kernel, input_type):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1302,10 +1294,7 @@ def test_transform(self, make_input, device):
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
],
)
def test_image_correctness(self, fn, make_input):
Expand Down Expand Up @@ -1370,10 +1359,7 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand Down Expand Up @@ -1882,10 +1868,7 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1901,11 +1884,7 @@ def test_functional(self, make_input):
(F.vertical_flip_image, torch.Tensor),
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
(F.vertical_flip_image, tv_tensors.Image),
pytest.param(
F._geometry._vertical_flip_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(F._geometry._vertical_flip_image_cvcuda, None, marks=CV_CUDA_TEST),
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.vertical_flip_mask, tv_tensors.Mask),
(F.vertical_flip_video, tv_tensors.Video),
Expand All @@ -1923,10 +1902,7 @@ def test_functional_signature(self, kernel, input_type):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1942,10 +1918,7 @@ def test_transform(self, make_input, device):
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
],
)
def test_image_correctness(self, fn, make_input):
Expand Down Expand Up @@ -2006,10 +1979,7 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand Down Expand Up @@ -2627,7 +2597,32 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca
scale=scale,
)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
    ("kernel", "input_type"),
    [
        (F.to_dtype_image, torch.Tensor),
        (F.to_dtype_video, tv_tensors.Video),
        pytest.param(F._misc._to_dtype_image_cvcuda, None, marks=CV_CUDA_TEST),
    ],
)
def test_functional_signature(self, kernel, input_type):
    """Check that each ``to_dtype`` kernel's signature matches ``F.to_dtype``."""
    # The CV-CUDA case is parametrized with ``None`` as a placeholder; the real
    # tensor type is only looked up here, inside the test body.
    # NOTE(review): presumably so that collection does not need CV-CUDA - confirm.
    resolved_type = _import_cvcuda().Tensor if kernel is F._misc._to_dtype_image_cvcuda else input_type
    check_functional_kernel_signature_match(F.to_dtype, kernel=kernel, input_type=resolved_type)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this test!


@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
Expand All @@ -2642,7 +2637,14 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale):

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
[
make_image_tensor,
make_image,
make_bounding_boxes,
make_segmentation_mask,
make_video,
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
Expand Down Expand Up @@ -2688,25 +2690,69 @@ def fn(value):

return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

def _get_dtype_conversion_atol_cvcuda(self, input_dtype, output_dtype):
    """Pick the absolute tolerance for comparing a dtype conversion result.

    Args:
        input_dtype: ``torch.dtype`` of the data before conversion.
        output_dtype: ``torch.dtype`` of the data after conversion.

    Returns:
        ``1`` for integer -> narrower integer and for float -> integer,
        ``1e-7`` for integer -> float, and ``0`` for every other pair.
    """
    src_is_float = input_dtype.is_floating_point
    dst_is_float = output_dtype.is_floating_point

    # A bit width is only looked up for the integer side(s) of the conversion.
    src_bits = None if src_is_float else torch.iinfo(input_dtype).bits
    dst_bits = None if dst_is_float else torch.iinfo(output_dtype).bits

    # int -> int that drops bits: results may differ by one due to rounding.
    if src_bits is not None and dst_bits is not None and dst_bits < src_bits:
        return 1

    # float -> int: the same off-by-one allowance for rounding of the float value.
    if src_is_float and not dst_is_float:
        return 1

    # int -> float: permit a small floating-point rounding error.
    if dst_is_float and not src_is_float:
        return 1e-7

    # Everything else (e.g. uint8 -> uint16 promotion, float -> float) must match exactly.
    return 0

@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_image_correctness(self, input_dtype, output_dtype, device, scale):
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST),
],
)
@pytest.mark.parametrize("fn", [F.to_dtype, transform_cls_to_functional(transforms.ToDtype)])
def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_input, fn):
if input_dtype.is_floating_point and output_dtype == torch.int64:
pytest.xfail("float to int64 conversion is not supported")
if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
pytest.xfail("uint8 to uint16 conversion is not supported on cuda")
if (
input_dtype == torch.uint16
and output_dtype == torch.uint8
and not scale
and make_input is make_image_cvcuda
):
pytest.xfail("uint16 to uint8 conversion without scale is not supported for CV-CUDA.")

input = make_image(dtype=input_dtype, device=device)
input = make_input(dtype=input_dtype, device=device)
out = fn(input, dtype=output_dtype, scale=scale)

if make_input is make_image_cvcuda:
input = F.cvcuda_to_tensor(input)
out = F.cvcuda_to_tensor(out)

out = F.to_dtype(input, dtype=output_dtype, scale=scale)
expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)

if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
torch.testing.assert_close(out, expected, atol=1, rtol=0)
else:
torch.testing.assert_close(out, expected)
atol, rtol = None, None
if make_input is make_image_cvcuda:
atol = self._get_dtype_conversion_atol_cvcuda(input_dtype, output_dtype)
rtol = 0
elif input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
atol, rtol = 1, 0

torch.testing.assert_close(out, expected, atol=atol, rtol=rtol)

def was_scaled(self, inpt):
# this assumes the target dtype is float
Expand Down Expand Up @@ -6794,9 +6840,9 @@ def test_functional_error(self):
F.pil_to_tensor(object())


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@needs_cuda
class TestToCVCUDATensor:
pytestmark = CV_CUDA_TEST

@pytest.mark.parametrize("image_type", (torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
@pytest.mark.parametrize("device", cpu_and_cuda())
Expand All @@ -6813,7 +6859,7 @@ def test_functional_and_transform(self, image_type, dtype, device, color_space,
assert is_pure_tensor(image)
output = fn(image)

assert isinstance(output, cvcuda.Tensor)
assert isinstance(output, _import_cvcuda().Tensor)
assert F.get_size(output) == F.get_size(image)
assert output is not None

Expand Down Expand Up @@ -6856,9 +6902,9 @@ def test_round_trip(self, dtype, device, color_space, batch_size):
assert result_tensor.shape[0] == batch_size


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@needs_cuda
class TestCVDUDAToTensor:
class TestCVCUDAToTensor:
pytestmark = CV_CUDA_TEST

@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
Expand Down
13 changes: 10 additions & 3 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor

from ._utils import (
_parse_labels_getter,
Expand Down Expand Up @@ -267,7 +268,7 @@ class ToDtype(Transform):
Default: ``False``.
"""

_transformed_types = (torch.Tensor,)
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(
self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False
Expand All @@ -294,7 +295,11 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
if isinstance(self.dtype, torch.dtype):
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# is a simple torch.dtype
if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
if (
not is_pure_tensor(inpt)
and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
and not _is_cvcuda_tensor(inpt)
):
return inpt

dtype: Optional[torch.dtype] = self.dtype
Expand All @@ -311,7 +316,9 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
)

supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
supports_scaling = (
is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or _is_cvcuda_tensor(inpt)
)
if dtype is None:
if self.scale and supports_scaling:
warnings.warn(
Expand Down
Loading