diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..a997c9e1d448 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -115,40 +115,77 @@ def compute_pooling_output_shape( (32, 2, 2, 3) """ strides = pool_size if strides is None else strides - input_shape_origin = list(input_shape) - input_shape = np.array(input_shape) + + # Use tuple instead of list for input_shape_origin (to avoid unnecessary copy) + input_shape_origin = tuple(input_shape) + # Only convert to numpy if needed later if data_format == "channels_last": spatial_shape = input_shape[1:-1] else: spatial_shape = input_shape[2:] none_dims = [] - for i in range(len(spatial_shape)): - if spatial_shape[i] is None: - # Set `None` shape to a manual value so that we can run numpy - # computation on `spatial_shape`. - spatial_shape[i] = -1 + have_none = False + + # Fast path: all dimensions known and integers + for i, dim in enumerate(spatial_shape): + if dim is None: + have_none = True none_dims.append(i) - pool_size = np.array(pool_size) - if padding == "valid": - output_spatial_shape = ( - np.floor((spatial_shape - pool_size) / strides) + 1 + + # If there are None dimensions, we must use numpy, otherwise stick with ints + # This saves most allocations for the common case + if have_none: + spatial_shape_arr = np.array( + [(-1 if dim is None else dim) for dim in spatial_shape], + dtype=np.intp ) - for i in range(len(output_spatial_shape)): - if i not in none_dims and output_spatial_shape[i] < 0: - raise ValueError( - "Computed output size would be negative. Received: " - f"`inputs.shape={input_shape}` and `pool_size={pool_size}`." - ) - elif padding == "same": - output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1 + pool_size_arr = np.array(pool_size, dtype=np.intp) + strides_arr = np.array(strides, dtype=np.intp) + if padding == "valid": + output_spatial_shape = np.floor_divide( + spatial_shape_arr - pool_size_arr, strides_arr + ) + 1 + for i in range(len(output_spatial_shape)): + if i not in none_dims and output_spatial_shape[i] < 0: + raise ValueError( + "Computed output size would be negative. Received: " + f"`inputs.shape={input_shape}` and `pool_size={pool_size}`." + ) + elif padding == "same": + output_spatial_shape = np.floor_divide( + spatial_shape_arr - 1, strides_arr + ) + 1 + else: + raise ValueError( + "Argument `padding` must be either 'valid' or 'same'. Received: " + f"padding={padding}" + ) + output_spatial_shape = [int(i) if i != -1 else None for i in output_spatial_shape] else: - raise ValueError( - "Argument `padding` must be either 'valid' or 'same'. Received: " - f"padding={padding}" - ) - output_spatial_shape = [int(i) for i in output_spatial_shape] - for i in none_dims: - output_spatial_shape[i] = None + # No Nones; use only native int math, no numpy + if padding == "valid": + output_spatial_shape = [] + for i, (dim, psize, stride) in enumerate(zip(spatial_shape, pool_size, strides)): + val = (dim - psize) // stride + 1 + if val < 0: + raise ValueError( + "Computed output size would be negative. Received: " + f"`inputs.shape={input_shape}` and `pool_size={pool_size}`." + ) + output_spatial_shape.append(val) + elif padding == "same": + output_spatial_shape = [ + (dim - 1) // stride + 1 + for dim, stride in zip(spatial_shape, strides) + ] + else: + raise ValueError( + "Argument `padding` must be either 'valid' or 'same'. Received: " + f"padding={padding}" + ) + # Assign back None for unknown dims + for idx in none_dims: + output_spatial_shape[idx] = None output_spatial_shape = tuple(output_spatial_shape) if data_format == "channels_last": output_shape = (