Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 171 additions & 18 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -2822,7 +2822,18 @@ class TestAdjustBrightness:
def test_kernel(self, kernel, make_input, dtype, device):
check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

Expand All @@ -2833,19 +2844,42 @@ def test_functional(self, make_input):
(F._color._adjust_brightness_image_pil, PIL.Image.Image),
(F.adjust_brightness_image, tv_tensors.Image),
(F.adjust_brightness_video, tv_tensors.Video),
pytest.param(
F._color._adjust_brightness_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._color._adjust_brightness_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
def test_image_correctness(self, brightness_factor):
image = make_image(dtype=torch.uint8, device="cpu")
def test_image_correctness(self, make_input, brightness_factor):
image = make_input(dtype=torch.uint8, device="cpu")

actual = F.adjust_brightness(image, brightness_factor=brightness_factor)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor))

torch.testing.assert_close(actual, expected)
if make_input is make_image_cvcuda:
assert_close(actual, expected, rtol=0, atol=1)
else:
assert_close(actual, expected)


class TestCutMixMixUp:
Expand Down Expand Up @@ -6053,7 +6087,18 @@ def test_kernel_image(self, dtype, device):
def test_kernel_video(self):
check_kernel(F.adjust_contrast_video, make_video(), contrast_factor=0.5)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_image_pil,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
check_functional(F.adjust_contrast, make_input(), contrast_factor=0.5)

Expand All @@ -6064,9 +6109,16 @@ def test_functional(self, make_input):
(F._color._adjust_contrast_image_pil, PIL.Image.Image),
(F.adjust_contrast_image, tv_tensors.Image),
(F.adjust_contrast_video, tv_tensors.Video),
pytest.param(
F._color._adjust_contrast_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._color._adjust_contrast_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.adjust_contrast, kernel=kernel, input_type=input_type)

def test_functional_error(self):
Expand All @@ -6076,11 +6128,24 @@ def test_functional_error(self):
with pytest.raises(ValueError, match="is not non-negative"):
F.adjust_contrast(make_image(), contrast_factor=-1)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("contrast_factor", [0.1, 0.5, 1.0])
def test_correctness_image(self, contrast_factor):
image = make_image(dtype=torch.uint8, device="cpu")
def test_correctness_image(self, make_input, contrast_factor):
image = make_input(dtype=torch.uint8, device="cpu")

actual = F.adjust_contrast(image, contrast_factor=contrast_factor)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.adjust_contrast(F.to_pil_image(image), contrast_factor=contrast_factor))

assert_close(actual, expected, rtol=0, atol=1)
Expand Down Expand Up @@ -6135,7 +6200,18 @@ def test_kernel_image(self, dtype, device):
def test_kernel_video(self):
check_kernel(F.adjust_hue_video, make_video(), hue_factor=0.25)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_image_pil,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
check_functional(F.adjust_hue, make_input(), hue_factor=0.25)

Expand All @@ -6146,9 +6222,16 @@ def test_functional(self, make_input):
(F._color._adjust_hue_image_pil, PIL.Image.Image),
(F.adjust_hue_image, tv_tensors.Image),
(F.adjust_hue_video, tv_tensors.Video),
pytest.param(
F._color._adjust_hue_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._color._adjust_hue_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.adjust_hue, kernel=kernel, input_type=input_type)

def test_functional_error(self):
Expand All @@ -6159,11 +6242,25 @@ def test_functional_error(self):
with pytest.raises(ValueError, match=re.escape("is not in [-0.5, 0.5]")):
F.adjust_hue(make_image(), hue_factor=hue_factor)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("hue_factor", [-0.5, -0.3, 0.0, 0.2, 0.5])
def test_correctness_image(self, hue_factor):
image = make_image(dtype=torch.uint8, device="cpu")
def test_correctness_image(self, make_input, hue_factor):
image = make_input(dtype=torch.uint8, device="cpu")

actual = F.adjust_hue(image, hue_factor=hue_factor)

if make_input is make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual)[0].cpu()
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.adjust_hue(F.to_pil_image(image), hue_factor=hue_factor))

mae = (actual.float() - expected.float()).abs().mean()
Expand All @@ -6179,7 +6276,18 @@ def test_kernel_image(self, dtype, device):
def test_kernel_video(self):
check_kernel(F.adjust_saturation_video, make_video(), saturation_factor=0.5)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_image_pil,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
check_functional(F.adjust_saturation, make_input(), saturation_factor=0.5)

Expand All @@ -6190,9 +6298,16 @@ def test_functional(self, make_input):
(F._color._adjust_saturation_image_pil, PIL.Image.Image),
(F.adjust_saturation_image, tv_tensors.Image),
(F.adjust_saturation_video, tv_tensors.Video),
pytest.param(
F._color._adjust_saturation_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._color._adjust_saturation_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type)

def test_functional_error(self):
Expand All @@ -6202,11 +6317,25 @@ def test_functional_error(self):
with pytest.raises(ValueError, match="is not non-negative"):
F.adjust_saturation(make_image(), saturation_factor=-1)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0])
def test_correctness_image(self, saturation_factor):
image = make_image(dtype=torch.uint8, device="cpu")
def test_correctness_image(self, make_input, color_space, saturation_factor):
image = make_input(dtype=torch.uint8, color_space=color_space, device="cpu")

actual = F.adjust_saturation(image, saturation_factor=saturation_factor)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor))

assert_close(actual, expected, rtol=0, atol=1)
Expand Down Expand Up @@ -6339,7 +6468,15 @@ def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
class TestColorJitter:
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
Expand Down Expand Up @@ -6383,24 +6520,40 @@ def test_transform_error(self):
with pytest.raises(ValueError, match="values should be between"):
transforms.ColorJitter(hue=1)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("brightness", [None, 0.1, (0.2, 0.3)])
@pytest.mark.parametrize("contrast", [None, 0.4, (0.5, 0.6)])
@pytest.mark.parametrize("saturation", [None, 0.7, (0.8, 0.9)])
@pytest.mark.parametrize("hue", [None, 0.3, (-0.1, 0.2)])
def test_transform_correctness(self, brightness, contrast, saturation, hue):
image = make_image(dtype=torch.uint8, device="cpu")
def test_transform_correctness(self, make_input, brightness, contrast, saturation, hue):
image = make_input(dtype=torch.uint8, device="cpu")

transform = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)

with freeze_rng_state():
torch.manual_seed(0)
actual = transform(image)

if make_input is make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual)[0].cpu()
image = F.cvcuda_to_tensor(image)[0].cpu()

torch.manual_seed(0)
expected = F.to_image(transform(F.to_pil_image(image)))

mae = (actual.float() - expected.float()).abs().mean()
assert mae < 2
mae_threshold = 2
if make_input is make_image_cvcuda:
mae_threshold = 3
assert mae < mae_threshold, f"MAE: {mae}"


class TestRgbToGrayscale:
Expand Down
3 changes: 3 additions & 0 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor

from ._transform import _RandomApplyTransform
from ._utils import query_chw
Expand Down Expand Up @@ -96,6 +97,8 @@ class ColorJitter(Transform):

_v1_transform_cls = _transforms.ColorJitter

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def _extract_params_for_v1_transform(self) -> dict[str, Any]:
return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}

Expand Down
Loading