Skip to content

Commit c362120

Browse files
committed
simplify elastic cvcuda code more
1 parent 156893f commit c362120

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2563,7 +2563,6 @@ def _elastic_cvcuda(
25632563
if not isinstance(displacement, torch.Tensor):
25642564
raise TypeError("Argument displacement should be a Tensor")
25652565

2566-
# Input image is NHWC format: (N, H, W, C)
25672566
batch_size, height, width, num_channels = image.shape
25682567
device = torch.device("cuda")
25692568
dtype = torch.float32
@@ -2579,6 +2578,10 @@ def _elastic_cvcuda(
25792578
elif num_channels == 1 and input_dtype != cvcuda.Type.F32:
25802579
raise ValueError(f"cvcuda.remap requires float32 dtype for 1-channel images, but got {input_dtype}")
25812580

2581+
interp = _cvcuda_interp.get(interpolation, cvcuda.Interp.LINEAR)
2582+
if interp is None:
2583+
raise ValueError(f"Invalid interpolation mode: {interpolation}")
2584+
25822585
# Build normalized grid: identity + displacement
25832586
# _create_identity_grid returns (1, H, W, 2) with values in [-1, 1]
25842587
identity_grid = _create_identity_grid((height, width), device=device, dtype=dtype)
@@ -2599,28 +2602,20 @@ def _elastic_cvcuda(
25992602
# Create cvcuda map tensor (NHWC layout with 2 channels for x,y)
26002603
cv_map = cvcuda.as_tensor(pixel_map.contiguous(), "NHWC")
26012604

2602-
# Resolve interpolation
2603-
src_interp = _cvcuda_interp.get(interpolation, cvcuda.Interp.LINEAR)
2604-
2605-
# Resolve border mode and value
2605+
border_mode = cvcuda.Border.CONSTANT
26062606
if fill is None:
2607-
border_mode = cvcuda.Border.CONSTANT
26082607
border_value = np.array([], dtype=np.float32)
26092608
elif isinstance(fill, (int, float)):
2610-
border_mode = cvcuda.Border.CONSTANT
26112609
border_value = np.array([fill], dtype=np.float32)
26122610
elif isinstance(fill, (list, tuple)):
2613-
border_mode = cvcuda.Border.CONSTANT
26142611
border_value = np.array(fill, dtype=np.float32)
26152612
else:
2616-
border_mode = cvcuda.Border.CONSTANT
26172613
border_value = np.array([], dtype=np.float32)
26182614

2619-
# Call cvcuda.remap
26202615
output = cvcuda.remap(
26212616
image,
26222617
cv_map,
2623-
src_interp=src_interp,
2618+
src_interp=interp,
26242619
map_interp=cvcuda.Interp.LINEAR,
26252620
map_type=cvcuda.Remap.ABSOLUTE,
26262621
align_corners=False,

0 commit comments

Comments
 (0)