2020from torch .testing ._comparison import BooleanPair , NonePair , not_close_error_metas , NumberPair , TensorLikePair
2121from torchvision import io , tv_tensors
2222from torchvision .transforms ._functional_tensor import _max_value as get_max_value
23- from torchvision .transforms .v2 .functional import to_cvcuda_tensor , to_image , to_pil_image
23+ from torchvision .transforms .v2 .functional import cvcuda_to_tensor , to_cvcuda_tensor , to_image , to_pil_image
24+ from torchvision .transforms .v2 .functional ._utils import _import_cvcuda , _is_cvcuda_available
2425from torchvision .utils import _Image_fromarray
2526
2627
2728IN_OSS_CI = any (os .getenv (var ) == "true" for var in ["CIRCLECI" , "GITHUB_ACTIONS" ])
2829IN_RE_WORKER = os .environ .get ("INSIDE_RE_WORKER" ) is not None
2930IN_FBCODE = os .environ .get ("IN_FBCODE_TORCHVISION" ) == "1"
31+ CVCUDA_AVAILABLE = _is_cvcuda_available ()
3032CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
3133MPS_NOT_AVAILABLE_MSG = "MPS device not available"
3234OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
275277 return [dict (zip (kwargs .keys (), values )) for values in itertools .product (* kwargs .values ())]
276278
277279
280+ def cvcuda_to_pil_compatible_tensor (tensor : "cvcuda.Tensor" ) -> torch .Tensor :
281+ tensor = cvcuda_to_tensor (tensor )
282+ if tensor .ndim != 4 :
283+ raise ValueError (f"CV-CUDA Tensor should be 4 dimensional. Got { tensor .ndim } dimensions." )
284+ if tensor .shape [0 ] != 1 :
285+ raise ValueError (
286+ f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got { tensor .shape [0 ]} ."
287+ )
288+ return tensor .squeeze (0 ).cpu ()
289+
290+
278291class ImagePair (TensorLikePair ):
279292 def __init__ (
280293 self ,
@@ -287,6 +300,11 @@ def __init__(
287300 if all (isinstance (input , PIL .Image .Image ) for input in [actual , expected ]):
288301 actual , expected = (to_image (input ) for input in [actual , expected ])
289302
303+ # handle check for CV-CUDA Tensors
304+ if CVCUDA_AVAILABLE and isinstance (actual , _import_cvcuda ().Tensor ):
305+ # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
306+ actual = cvcuda_to_pil_compatible_tensor (actual )
307+
290308 super ().__init__ (actual , expected , ** other_parameters )
291309 self .mae = mae
292310
@@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs):
401419
402420
403421def make_image_cvcuda (* args , batch_dims = (1 ,), ** kwargs ):
404- # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4)
405422 return to_cvcuda_tensor (make_image (* args , batch_dims = batch_dims , ** kwargs ))
406423
407424
0 commit comments