28 changes: 22 additions & 6 deletions keras/src/backend/common/backend_utils.py
@@ -96,13 +96,13 @@ def _convert_conv_transpose_padding_args_from_keras_to_torch(
)

if torch_output_padding >= stride:
raise ValueError(
f"The padding arguments (padding={padding}) and "
f"output_padding={output_padding}) lead to a Torch "
f"output_padding ({torch_output_padding}) that is greater than "
f"strides ({stride}). This is not supported. You can change the "
f"padding arguments, kernel or stride, or run on another backend. "
warnings.warn(
f"Torch backend requires output_padding < stride. "
f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
f"for stride {stride}.",
UserWarning,
)
torch_output_padding = stride - 1
Contributor comment on lines +99 to +105 (severity: medium):

This clamping logic is a good improvement. However, the redundant clamping logic later in compute_conv_transpose_padding_args_for_torch (lines 187-201) uses max(0, s - 1), which is safer as it prevents negative padding if stride is less than 1. It would be best to incorporate that safer logic here and remove the redundant code block. Using a temporary variable for the new padding value would also improve the warning message's clarity.

Suggested change
warnings.warn(
f"Torch backend requires output_padding < stride. "
f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
f"for stride {stride}.",
UserWarning,
)
torch_output_padding = stride - 1
new_output_padding = max(0, stride - 1)
warnings.warn(
f"Torch backend requires output_padding < stride. "
f"Clamping output_padding {torch_output_padding} -> {new_output_padding} "
f"for stride {stride}.",
UserWarning,
)
torch_output_padding = new_output_padding
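
For illustration only, a minimal standalone sketch of the behaviour this suggestion argues for (the `clamp_output_padding` helper below is hypothetical, not part of Keras): `max(0, stride - 1)` never produces a negative padding, even for a degenerate stride.

```python
import warnings


def clamp_output_padding(output_padding, stride):
    """Hypothetical helper mirroring the suggested clamping logic."""
    max_allowed = max(0, stride - 1)  # never negative, even if stride < 1
    if output_padding > max_allowed:
        warnings.warn(
            f"Torch backend requires output_padding < stride. "
            f"Clamping output_padding {output_padding} -> {max_allowed} "
            f"for stride {stride}.",
            UserWarning,
        )
        return max_allowed
    return output_padding


print(clamp_output_padding(1, 1))  # 0, emitted with a warning
print(clamp_output_padding(0, 2))  # 0, unchanged
```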


return torch_padding, torch_output_padding

@@ -184,6 +184,22 @@ def compute_conv_transpose_padding_args_for_torch(
torch_paddings.append(torch_padding)
torch_output_paddings.append(torch_output_padding)

# --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---
corrected_output_paddings = []
for s, op in zip(
strides
if isinstance(strides, (list, tuple))
else [strides] * num_spatial_dims,
torch_output_paddings,
):
max_allowed = max(0, s - 1)
if op > max_allowed:
corrected_output_paddings.append(max_allowed)
else:
corrected_output_paddings.append(op)

torch_output_paddings = corrected_output_paddings
Contributor comment on lines +187 to +201 (severity: medium):

With the clamping logic now handled within _convert_conv_transpose_padding_args_from_keras_to_torch (and improved with the suggestion in the other comment), this entire block of code becomes redundant. Removing it will simplify the function and avoid logic duplication.
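
To make the redundancy concrete, a small standalone sketch (hypothetical values, not Keras code): once each entry already leaves the per-dimension converter clamped to `max(0, s - 1)`, a second clamping pass returns the list unchanged.

```python
strides = [1, 2, 2]
# Values as they would leave the per-dimension converter if it already
# clamps each output padding to max(0, s - 1):
torch_output_paddings = [0, 1, 1]

# Equivalent of the per-dimension clamping block on lines 187-201:
second_pass = [
    min(op, max(0, s - 1)) for s, op in zip(strides, torch_output_paddings)
]
assert second_pass == torch_output_paddings  # no-op, so the block can be dropped
```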


return torch_paddings, torch_output_paddings


19 changes: 19 additions & 0 deletions keras/src/backend/common/backend_utils_test.py
@@ -170,6 +170,25 @@ def test_valid_padding_with_output_padding(self):
self.assertEqual(torch_padding, 0)
self.assertEqual(torch_output_padding, 1)

def test_output_padding_clamped_for_torch_constraint(self):
"""Test that output_padding is clamped
when >= stride (Torch constraint).
"""
(
torch_paddings,
torch_output_paddings,
) = compute_conv_transpose_padding_args_for_torch(
input_shape=(1, 8, 8, 8, 16), # any shape
kernel_shape=(2, 2, 2, 16, 32), # Keras kernel shape
strides=1,
padding="same",
output_padding=1, # Keras wants this
dilation_rate=1,
)
# Torch expects output_padding < stride (1)
# so output_padding should be clamped to 0
self.assertEqual(torch_output_paddings, [0, 0, 0])
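
For context on the constraint this new test encodes, a minimal PyTorch sketch (behaviour assumed from recent torch releases; the exact error type and message may vary): a conv-transpose layer with `output_padding >= stride` is rejected when it is run, while the clamped configuration works.

```python
import torch

x = torch.randn(1, 16, 8)  # (batch, channels, length)

# output_padding (1) >= stride (1): Torch rejects this at call time.
bad = torch.nn.ConvTranspose1d(16, 32, kernel_size=2, stride=1, output_padding=1)
try:
    bad(x)
except (RuntimeError, ValueError) as e:
    print("rejected:", e)

# Clamped to 0, the same layer runs fine.
ok = torch.nn.ConvTranspose1d(16, 32, kernel_size=2, stride=1, output_padding=0)
print(ok(x).shape)  # torch.Size([1, 32, 9])
```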


class GetOutputShapeGivenTFPaddingTest(test_case.TestCase):
def test_valid_padding_without_output_padding(self):
25 changes: 17 additions & 8 deletions keras/src/layers/convolutional/conv_transpose_test.py
@@ -825,21 +825,30 @@ def test_conv1d_transpose_consistency(

# Special cases for Torch
if backend.backend() == "torch":
# The following set of arguments lead to Torch output padding to be
# greater than strides, which is not supported by Torch.
# An error is raised.
# Args that cause output_padding >= strides
# are clamped with a warning.
if (kernel_size, strides, padding, output_padding) in [
(2, 1, "same", None),
(4, 1, "same", None),
]:
with pytest.raises(ValueError):
clamped_output_padding = strides - 1  # 0 for the stride=1 cases above
expected_res = np_conv1d_transpose(
x=input,
kernel_weights=kernel_weights,
bias_weights=np.zeros(shape=(1,)),
strides=strides,
padding=padding,
output_padding=clamped_output_padding,
data_format=backend.config.image_data_format(),
dilation_rate=1,
)
with pytest.warns(UserWarning):
kc_res = kc_layer(input)
self.assertAllClose(expected_res, kc_res, atol=1e-5)
return

# When both torch_padding and torch_output_padding are greater
# than 0, Torch outputs are inconsistent with the ones from
# Tensorflow. A warning is raised, and we expect the results to be
# different.
# torch_padding > 0 and torch_output_padding > 0 case
# Torch output differs from TF.
(
torch_padding,
torch_output_padding,