Skip to content

Commit 27663fa

Browse files
committed
handle some comments from other prs review
1 parent bda6570 commit 27663fa

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

test/test_transforms_v2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3473,12 +3473,14 @@ def test_functional(self, make_input):
34733473
(F.crop_keypoints, tv_tensors.KeyPoints),
34743474
pytest.param(
34753475
F._geometry._crop_cvcuda,
3476-
_import_cvcuda().Tensor,
3476+
"cvcuda.Tensor",
34773477
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
34783478
),
34793479
],
34803480
)
34813481
def test_functional_signature(self, kernel, input_type):
3482+
if input_type == "cvcuda.Tensor":
3483+
input_type = _import_cvcuda().Tensor
34823484
check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type)
34833485

34843486
@pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
@@ -4932,12 +4934,14 @@ def test_functional(self, make_input):
49324934
(F.center_crop_keypoints, tv_tensors.KeyPoints),
49334935
pytest.param(
49344936
F._geometry._center_crop_cvcuda,
4935-
_import_cvcuda().Tensor,
4937+
"cvcuda.Tensor",
49364938
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
49374939
),
49384940
],
49394941
)
49404942
def test_functional_signature(self, kernel, input_type):
4943+
if input_type == "cvcuda.Tensor":
4944+
input_type = _import_cvcuda().Tensor
49414945
check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=input_type)
49424946

49434947
@pytest.mark.parametrize(

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,7 +1952,7 @@ def _crop_cvcuda(
19521952

19531953

19541954
if CVCUDA_AVAILABLE:
1955-
_crop_cvcuda_registered = _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
1955+
_register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
19561956

19571957

19581958
def perspective(
@@ -2741,9 +2741,7 @@ def _center_crop_cvcuda(
27412741

27422742

27432743
if CVCUDA_AVAILABLE:
2744-
_center_crop_cvcuda_registered = _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(
2745-
_center_crop_cvcuda
2746-
)
2744+
_register_kernel_internal(center_crop, _import_cvcuda().Tensor)(_center_crop_cvcuda)
27472745

27482746

27492747
def resized_crop(

0 commit comments

Comments
 (0)