diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 96166e05e9a..07e2eeb6f51 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -1070,7 +1070,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: if kx % 2 == 0: kx += 1 dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) - dx = dx * self.alpha[0] / size[0] + dx = dx * self.alpha[0] / size[1] dy = torch.rand([1, 1] + size) * 2 - 1 if self.sigma[1] > 0.0: @@ -1079,7 +1079,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: if ky % 2 == 0: ky += 1 dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma)) - dy = dy * self.alpha[1] / size[1] + dy = dy * self.alpha[1] / size[0] displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 return dict(displacement=displacement)