Skip to content

Commit 01efae7

Browse files
committed
add comment explaining mean/std behavior, one-line intermediate creation
1 parent 1e864d8 commit 01efae7

File tree

1 file changed

+12
-8
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+12
-8
lines changed

torchvision/transforms/v2/functional/_misc.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,18 @@ def normalize_cvcuda(
9898
elif len(std) != channels:
9999
raise ValueError(f"Std should have {channels} elements. Got {len(std)}.")
100100

101-
mean = torch.as_tensor(mean, dtype=torch.float32)
102-
std = torch.as_tensor(std, dtype=torch.float32)
103-
mean_tensor = mean.reshape(1, 1, 1, channels)
104-
std_tensor = std.reshape(1, 1, 1, channels)
105-
mean_tensor = mean_tensor.cuda()
106-
std_tensor = std_tensor.cuda()
107-
mean_cv = cvcuda.as_tensor(mean_tensor, cvcuda.TensorLayout.NHWC)
108-
std_cv = cvcuda.as_tensor(std_tensor, cvcuda.TensorLayout.NHWC)
101+
# CV-CUDA requires float32 tensors for the mean/std parameters
102+
# at small batches, this is costly relative to the normalize operation
103+
# if CV-CUDA is known to be a backend, could optimize this
104+
# For Normalize class:
105+
# by creating tensors at class initialization time
106+
# For functional API:
107+
# by storing cached tensors in a helper function with functools.lru_cache (would it even be worth it?)
108+
# Since CV-CUDA is 1) not the default backend, 2) only strictly faster at large batch sizes, ignore this for now
109+
mt = torch.as_tensor(mean, dtype=torch.float32).reshape(1, 1, 1, channels).cuda()
110+
st = torch.as_tensor(std, dtype=torch.float32).reshape(1, 1, 1, channels).cuda()
111+
mean_cv = cvcuda.as_tensor(mt, cvcuda.TensorLayout.NHWC)
112+
std_cv = cvcuda.as_tensor(st, cvcuda.TensorLayout.NHWC)
109113

110114
return cvcuda.normalize(image, base=mean_cv, scale=std_cv, flags=cvcuda.NormalizeFlags.SCALE_IS_STDDEV)
111115

0 commit comments

Comments
 (0)