
Commit 8287e48

Fix Torch output_padding constraint for ConvTranspose layers (#21852)
* Fix Conv3DTranspose padding conversion for torch backend
1 parent 6fd26b4 commit 8287e48
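The underlying issue is a PyTorch constraint that the Keras argument conversion cannot always satisfy: Torch rejects transposed convolutions whose output_padding is not smaller than the stride (or dilation). A minimal standalone illustration of that constraint, not part of this commit and using only the public torch.nn.ConvTranspose1d API:

import torch

# Torch requires output_padding < stride (or < dilation); the rejection can
# surface when the layer is built or applied, so both steps are guarded here.
try:
    layer = torch.nn.ConvTranspose1d(
        in_channels=3, out_channels=2, kernel_size=2, stride=1, output_padding=1
    )
    layer(torch.randn(1, 3, 8))  # (batch, channels, length)
except Exception as err:
    print(type(err).__name__, err)

Before this commit, Keras raised a ValueError whenever the converted arguments hit this constraint on the torch backend; with this change the value is clamped to stride - 1 and a UserWarning is emitted instead.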

File tree: 3 files changed, +58 -14 lines changed

keras/src/backend/common/backend_utils.py
keras/src/backend/common/backend_utils_test.py
keras/src/layers/convolutional/conv_transpose_test.py

keras/src/backend/common/backend_utils.py

Lines changed: 22 additions & 6 deletions
@@ -96,13 +96,13 @@ def _convert_conv_transpose_padding_args_from_keras_to_torch(
         )
 
     if torch_output_padding >= stride:
-        raise ValueError(
-            f"The padding arguments (padding={padding}) and "
-            f"output_padding={output_padding}) lead to a Torch "
-            f"output_padding ({torch_output_padding}) that is greater than "
-            f"strides ({stride}). This is not supported. You can change the "
-            f"padding arguments, kernel or stride, or run on another backend. "
+        warnings.warn(
+            f"Torch backend requires output_padding < stride. "
+            f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
+            f"for stride {stride}.",
+            UserWarning,
         )
+        torch_output_padding = stride - 1
 
     return torch_padding, torch_output_padding
 
@@ -184,6 +184,22 @@ def compute_conv_transpose_padding_args_for_torch(
         torch_paddings.append(torch_padding)
         torch_output_paddings.append(torch_output_padding)
 
+    # --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---
+    corrected_output_paddings = []
+    for s, op in zip(
+        strides
+        if isinstance(strides, (list, tuple))
+        else [strides] * num_spatial_dims,
+        torch_output_paddings,
+    ):
+        max_allowed = max(0, s - 1)
+        if op > max_allowed:
+            corrected_output_paddings.append(max_allowed)
+        else:
+            corrected_output_paddings.append(op)
+
+    torch_output_paddings = corrected_output_paddings
+
     return torch_paddings, torch_output_paddings
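The correction loop added above is equivalent to capping each per-dimension value at max(0, stride - 1). A standalone sketch of that rule (the helper name clamp_output_padding is illustrative, not part of Keras):

def clamp_output_padding(output_padding, stride):
    # Enforce Torch's requirement output_padding < stride by capping the
    # value at max(0, stride - 1), as the per-dimension loop above does.
    return min(output_padding, max(0, stride - 1))

print(clamp_output_padding(1, 1))  # 0: the combination that previously raised
print(clamp_output_padding(0, 2))  # 0: already valid, left unchanged
print(clamp_output_padding(2, 2))  # 1: capped at stride - 1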

keras/src/backend/common/backend_utils_test.py

Lines changed: 19 additions & 0 deletions
@@ -170,6 +170,25 @@ def test_valid_padding_with_output_padding(self):
         self.assertEqual(torch_padding, 0)
         self.assertEqual(torch_output_padding, 1)
 
+    def test_output_padding_clamped_for_torch_constraint(self):
+        """Test that output_padding is clamped
+        when >= stride (Torch constraint).
+        """
+        (
+            torch_paddings,
+            torch_output_paddings,
+        ) = compute_conv_transpose_padding_args_for_torch(
+            input_shape=(1, 8, 8, 8, 16),  # any shape
+            kernel_shape=(2, 2, 2, 16, 32),  # Keras kernel shape
+            strides=1,
+            padding="same",
+            output_padding=1,  # Keras wants this
+            dilation_rate=1,
+        )
+        # Torch expects output_padding < stride (1)
+        # so output_padding should be clamped to 0
+        self.assertEqual(torch_output_paddings, [0, 0, 0])
+
 
 class GetOutputShapeGivenTFPaddingTest(test_case.TestCase):
     def test_valid_padding_without_output_padding(self):

keras/src/layers/convolutional/conv_transpose_test.py

Lines changed: 17 additions & 8 deletions
@@ -825,21 +825,30 @@ def test_conv1d_transpose_consistency(
 
         # Special cases for Torch
         if backend.backend() == "torch":
-            # The following set of arguments lead to Torch output padding to be
-            # greater than strides, which is not supported by Torch.
-            # An error is raised.
+            # Args that cause output_padding >= strides
+            # are clamped with a warning.
             if (kernel_size, strides, padding, output_padding) in [
                 (2, 1, "same", None),
                 (4, 1, "same", None),
             ]:
-                with pytest.raises(ValueError):
+                clamped_output_padding = strides - 1  # usually 0 when stride=1
+                expected_res = np_conv1d_transpose(
+                    x=input,
+                    kernel_weights=kernel_weights,
+                    bias_weights=np.zeros(shape=(1,)),
+                    strides=strides,
+                    padding=padding,
+                    output_padding=clamped_output_padding,
+                    data_format=backend.config.image_data_format(),
+                    dilation_rate=1,
+                )
+                with pytest.warns(UserWarning):
                     kc_res = kc_layer(input)
+                self.assertAllClose(expected_res, kc_res, atol=1e-5)
                 return
 
-            # When both torch_padding and torch_output_padding are greater
-            # than 0, Torch outputs are inconsistent with the ones from
-            # Tensorflow. A warning is raised, and we expect the results to be
-            # different.
+            # torch_padding > 0 and torch_output_padding > 0 case
+            # Torch output differs from TF.
             (
                 torch_padding,
                 torch_output_padding,
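For reference, a hedged end-to-end sketch of the behaviour the updated consistency test exercises, assuming the torch backend is active (for example via KERAS_BACKEND=torch): kernel_size=2, strides=1, padding="same" previously raised a ValueError during the Keras-to-Torch padding conversion and now only warns.

import warnings

import numpy as np
import keras

layer = keras.layers.Conv1DTranspose(
    filters=4, kernel_size=2, strides=1, padding="same"
)
x = np.random.rand(2, 10, 3).astype("float32")  # (batch, steps, channels)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = layer(x)  # emits the clamping UserWarning instead of raising

print([str(w.message) for w in caught])
print(y.shape)  # values may differ slightly from other backends due to the clamp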
