
Commit 80dc7dd

RGB to gray and gray to RGB done
1 parent e3dd700 commit 80dc7dd

4 files changed: +181 -11 lines changed

test/test_transforms_v2.py

Lines changed: 101 additions & 8 deletions
@@ -6357,7 +6357,17 @@ class TestRgbToGrayscale:
     def test_kernel_image(self, dtype, device):
         check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))

-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_functional(self, make_input):
         check_functional(F.rgb_to_grayscale, make_input())

@@ -6367,23 +6377,62 @@ def test_functional(self, make_input):
             (F.rgb_to_grayscale_image, torch.Tensor),
             (F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
             (F.rgb_to_grayscale_image, tv_tensors.Image),
+            pytest.param(
+                F._color._rgb_to_grayscale_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_transform(self, transform, make_input):
+        if make_input is make_image_cvcuda and isinstance(transform, transforms.RandomGrayscale):
+            pytest.skip("CV-CUDA does not support RandomGrayscale, will have num_output_channels == 3")
         check_transform(transform, make_input())

     @pytest.mark.parametrize("num_output_channels", [1, 3])
     @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
-    def test_image_correctness(self, num_output_channels, color_space, fn):
-        image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)
+    def test_image_correctness(self, num_output_channels, color_space, make_input, fn):
+        if make_input is make_image_cvcuda and num_output_channels == 3:
+            pytest.skip("CV-CUDA does not support num_output_channels == 3")
+
+        image = make_input(dtype=torch.uint8, device="cpu", color_space=color_space)

         actual = fn(image, num_output_channels=num_output_channels)
+
+        if make_input is make_image_cvcuda:
+            # drop the batch dimension
+            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
+            actual = actual.squeeze(0)
+            image = F.cvcuda_to_tensor(image).to(device="cpu")
+            image = image.squeeze(0)
+
         expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))

         assert_equal(actual, expected, rtol=0, atol=1)
@@ -6421,7 +6470,17 @@ class TestGrayscaleToRgb:
     def test_kernel_image(self, dtype, device):
         check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))

-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_functional(self, make_input):
         check_functional(F.grayscale_to_rgb, make_input())

@@ -6431,20 +6490,54 @@ def test_functional(self, make_input):
             (F.rgb_to_grayscale_image, torch.Tensor),
             (F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
             (F.rgb_to_grayscale_image, tv_tensors.Image),
+            pytest.param(
+                F._color._rgb_to_grayscale_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)

-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_transform(self, make_input):
         check_transform(transforms.RGB(), make_input(color_space="GRAY"))

+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
-    def test_image_correctness(self, fn):
-        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
+    def test_image_correctness(self, make_input, fn):
+        image = make_input(dtype=torch.uint8, device="cpu", color_space="GRAY")

         actual = fn(image)
+
+        if make_input is make_image_cvcuda:
+            # drop the batch dimension
+            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
+            actual = actual.squeeze(0)
+            image = F.cvcuda_to_tensor(image).to(device="cpu")
+            image = image.squeeze(0)
+
         expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))

         assert_equal(actual, expected, rtol=0, atol=1)
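Note (not part of the diff): the CV-CUDA branches of the correctness tests all follow the same round trip, converting the cvcuda.Tensor back to a torch tensor, moving it to the CPU, and dropping the leading batch dimension before comparing against the PIL reference. A minimal sketch of that pattern outside pytest, assuming CV-CUDA plus a CUDA device; make_image_cvcuda is the test helper referenced in the hunks above, and its import path here is an assumption.

import torch
import torchvision.transforms.v2.functional as F
from common_utils import make_image_cvcuda  # assumed: test-suite helper used in the diff above

def to_cpu_chw(t):
    # cvcuda_to_tensor returns a batched torch tensor on the GPU; squeeze(0) drops
    # the batch dimension so shapes line up with the CHW tensors the PIL path produces.
    return F.cvcuda_to_tensor(t).to(device="cpu").squeeze(0)

image_cv = make_image_cvcuda(dtype=torch.uint8, device="cpu", color_space="RGB")
actual = to_cpu_chw(F.rgb_to_grayscale(image_cv, num_output_channels=1))
image = to_cpu_chw(image_cv)
expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=1))
# same tolerance as the test: results may differ from the PIL reference by at most 1
assert int((actual.int() - expected.int()).abs().max()) <= 1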

torchvision/transforms/v2/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
+        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
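Note (not part of the diff): with is_cvcuda_tensor added to the check, query_chw now also collects channel/height/width from CV-CUDA inputs, relying on the get_dimensions kernel registered in _meta.py below. A minimal sketch, assuming CV-CUDA and a CUDA device; wrapping a torch tensor with cvcuda.as_tensor is one way to build an NHWC cvcuda.Tensor and is an assumption of this example.

import cvcuda
import torch
from torchvision.transforms.v2._utils import query_chw

# NHWC batch of one RGB image on the GPU, wrapped as a cvcuda.Tensor
img = cvcuda.as_tensor(torch.zeros(1, 240, 320, 3, dtype=torch.uint8, device="cuda"), "NHWC")

# Before this change the CV-CUDA input was filtered out and query_chw raised
# "No image or video was found in the sample"; now it reports the CHW triple.
print(query_chw([img]))  # (3, 240, 320)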

torchvision/transforms/v2/functional/_color.py

Lines changed: 57 additions & 0 deletions
@@ -73,6 +73,38 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
     return _FP.to_grayscale(image, num_output_channels=num_output_channels)


+def _rgb_to_grayscale_cvcuda(
+    image: "cvcuda.Tensor",
+    num_output_channels: int = 1,
+) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
+    if num_output_channels not in (1, 3):
+        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+
+    if num_output_channels == 3:
+        raise ValueError("num_output_channels must be 1 for CV-CUDA, got 3.")
+
+    if image.shape[3] == 1:
+        # if we already have a single channel, just clone the tensor
+        # we will use copymakeborder since CV-CUDA has no native clone
+        return cvcuda.copymakeborder(
+            image,
+            border_mode=cvcuda.Border.CONSTANT,
+            border_value=[0],
+            top=0,
+            left=0,
+            bottom=0,
+            right=0,
+        )
+
+    return cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2GRAY)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(rgb_to_grayscale, _import_cvcuda().Tensor)(_rgb_to_grayscale_cvcuda)
+
+
 def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
     """See :class:`~torchvision.transforms.v2.RGB` for details."""
     if torch.jit.is_scripting():
@@ -99,6 +131,31 @@ def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
     return image.convert(mode="RGB")


+def _grayscale_to_rgb_cvcuda(
+    image: "cvcuda.Tensor",
+) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
+    if image.shape[3] == 3:
+        # if we already have RGB channels, just clone the tensor
+        # we will use copymakeborder since CV-CUDA has no native clone
+        return cvcuda.copymakeborder(
+            image,
+            border_mode=cvcuda.Border.CONSTANT,
+            border_value=[0],
+            top=0,
+            left=0,
+            bottom=0,
+            right=0,
+        )
+
+    return cvcuda.cvtcolor(image, cvcuda.ColorConversion.GRAY2RGB)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(grayscale_to_rgb, _import_cvcuda().Tensor)(_grayscale_to_rgb_cvcuda)
+
+
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
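Note (not part of the diff): the two kernels above are registered for cvcuda.Tensor, so the public functionals dispatch to them automatically; only num_output_channels=1 is accepted on the CV-CUDA path, as enforced in _rgb_to_grayscale_cvcuda. A minimal usage sketch under the same CV-CUDA/CUDA assumptions, with cvcuda.as_tensor again assumed as the way to build the NHWC input:

import cvcuda
import torch
import torchvision.transforms.v2.functional as F

# NHWC uint8 batch of one RGB image on the GPU
rgb = cvcuda.as_tensor(torch.randint(0, 256, (1, 32, 48, 3), dtype=torch.uint8, device="cuda"), "NHWC")

gray = F.rgb_to_grayscale(rgb, num_output_channels=1)  # dispatches to _rgb_to_grayscale_cvcuda
rgb_again = F.grayscale_to_rgb(gray)                   # dispatches to _grayscale_to_rgb_cvcuda

# Convert back to a torch tensor for inspection; squeeze(0) drops the batch dimension.
print(F.cvcuda_to_tensor(rgb_again).squeeze(0).shape)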

torchvision/transforms/v2/functional/_meta.py

Lines changed: 22 additions & 2 deletions
@@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]:
     return get_dimensions_image(video)


+def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]:
+    # CV-CUDA tensor is always in NHWC layout
+    # get_dimensions is CHW
+    return [image.shape[3], image.shape[1], image.shape[2]]
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda)
+
+
 def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image(inpt)
@@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int:
 get_image_num_channels = get_num_channels


+def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int:
+    # CV-CUDA tensor is always in NHWC layout
+    # get_num_channels is C
+    return image.shape[3]
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda)
+
+
 def get_size(inpt: torch.Tensor) -> list[int]:
     if torch.jit.is_scripting():
         return get_size_image(inpt)
@@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]:
     return [height, width]


-def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
+def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]:
     """Get size of `cvcuda.Tensor` with NHWC layout."""
     hw = list(image.shape[-3:-1])
     ndims = len(hw)
@@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:


 if CVCUDA_AVAILABLE:
-    _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
+    _register_kernel_internal(get_size, _import_cvcuda().Tensor)(_get_size_cvcuda)


 @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
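Note (not part of the diff): with these registrations the metadata queries work directly on CV-CUDA tensors, translating the fixed NHWC layout into the CHW-style values the rest of the v2 transforms expect. A minimal sketch under the same CV-CUDA/CUDA assumptions as above:

import cvcuda
import torch
import torchvision.transforms.v2.functional as F

img = cvcuda.as_tensor(torch.zeros(2, 224, 224, 3, dtype=torch.uint8, device="cuda"), "NHWC")

print(F.get_dimensions(img))    # [3, 224, 224]  via _get_dimensions_cvcuda
print(F.get_num_channels(img))  # 3              via _get_num_channels_cvcuda
print(F.get_size(img))          # [224, 224]     via _get_size_cvcuda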
