From d9f7aa415e82e5ab7bb02fbaa9e29a0db68ed8ab Mon Sep 17 00:00:00 2001 From: Jaebeom Date: Wed, 3 Dec 2025 22:43:00 +0900 Subject: [PATCH] Fix incorrect normalization axis in v2.ElasticTransform --- torchvision/transforms/v2/_geometry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 1418a6b4953..a3d0346134f 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -1062,7 +1062,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: if kx % 2 == 0: kx += 1 dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) - dx = dx * self.alpha[0] / size[0] + dx = dx * self.alpha[0] / size[1] dy = torch.rand([1, 1] + size) * 2 - 1 if self.sigma[1] > 0.0: @@ -1071,7 +1071,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: if ky % 2 == 0: ky += 1 dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma)) - dy = dy * self.alpha[1] / size[1] + dy = dy * self.alpha[1] / size[0] displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 return dict(displacement=displacement)