Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 63 additions & 26 deletions keras/src/ops/operation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,40 +115,77 @@ def compute_pooling_output_shape(
(32, 2, 2, 3)
"""
strides = pool_size if strides is None else strides
input_shape_origin = list(input_shape)
input_shape = np.array(input_shape)

# Use tuple instead of list for input_shape_origin (to avoid unnecessary copy)
input_shape_origin = tuple(input_shape)
# Only convert to numpy if needed later
if data_format == "channels_last":
spatial_shape = input_shape[1:-1]
else:
spatial_shape = input_shape[2:]
none_dims = []
for i in range(len(spatial_shape)):
if spatial_shape[i] is None:
# Set `None` shape to a manual value so that we can run numpy
# computation on `spatial_shape`.
spatial_shape[i] = -1
have_none = False

# Fast path: all dimensions known and integers
for i, dim in enumerate(spatial_shape):
if dim is None:
have_none = True
none_dims.append(i)
pool_size = np.array(pool_size)
if padding == "valid":
output_spatial_shape = (
np.floor((spatial_shape - pool_size) / strides) + 1

# If there are None dimensions, we must use numpy, otherwise stick with ints
# This saves most allocations for the common case
if have_none:
spatial_shape_arr = np.array(
[(-1 if dim is None else dim) for dim in spatial_shape],
dtype=np.intp
)
for i in range(len(output_spatial_shape)):
if i not in none_dims and output_spatial_shape[i] < 0:
raise ValueError(
"Computed output size would be negative. Received: "
f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
)
elif padding == "same":
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
pool_size_arr = np.array(pool_size, dtype=np.intp)
strides_arr = np.array(strides, dtype=np.intp)
if padding == "valid":
output_spatial_shape = np.floor_divide(
spatial_shape_arr - pool_size_arr, strides_arr
) + 1
for i in range(len(output_spatial_shape)):
if i not in none_dims and output_spatial_shape[i] < 0:
raise ValueError(
"Computed output size would be negative. Received: "
f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
)
elif padding == "same":
output_spatial_shape = np.floor_divide(
spatial_shape_arr - 1, strides_arr
) + 1
else:
raise ValueError(
"Argument `padding` must be either 'valid' or 'same'. Received: "
f"padding={padding}"
)
output_spatial_shape = [int(i) if i != -1 else None for i in output_spatial_shape]
else:
raise ValueError(
"Argument `padding` must be either 'valid' or 'same'. Received: "
f"padding={padding}"
)
output_spatial_shape = [int(i) for i in output_spatial_shape]
for i in none_dims:
output_spatial_shape[i] = None
# No Nones; use only native int math, no numpy
if padding == "valid":
output_spatial_shape = []
for i, (dim, psize, stride) in enumerate(zip(spatial_shape, pool_size, strides)):
val = (dim - psize) // stride + 1
if val < 0:
raise ValueError(
"Computed output size would be negative. Received: "
f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
)
output_spatial_shape.append(val)
elif padding == "same":
output_spatial_shape = [
(dim - 1) // stride + 1
for dim, stride in zip(spatial_shape, strides)
]
else:
raise ValueError(
"Argument `padding` must be either 'valid' or 'same'. Received: "
f"padding={padding}"
)
# Assign back None for unknown dims
for idx in none_dims:
output_spatial_shape[idx] = None
output_spatial_shape = tuple(output_spatial_shape)
if data_format == "channels_last":
output_shape = (
Expand Down