diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..231c9a2b4924 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -197,25 +197,18 @@ def compute_conv_output_shape( f"`dilation_rate={dilation_rate}` and " f"input of shape {input_shape}." ) - none_dims = [] - spatial_shape = np.array(spatial_shape) - for i in range(len(spatial_shape)): - if spatial_shape[i] is None: - # Set `None` shape to a manual value so that we can run numpy - # computation on `spatial_shape`. - spatial_shape[i] = -1 - none_dims.append(i) + none_dims = [i for i, s in enumerate(spatial_shape) if s is None] + spatial_calc = tuple(-1 if s is None else s for s in spatial_shape) - kernel_spatial_shape = np.array(kernel_shape[:-2]) - dilation_rate = np.array(dilation_rate) + kernel_spatial_shape = kernel_shape[:-2] if padding == "valid": - output_spatial_shape = ( - np.floor( - (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1) - / strides - ) - + 1 - ) + output_spatial_shape = [ + int(np.floor( + (spatial_calc[i] - dilation_rate[i] * (kernel_spatial_shape[i] - 1) - 1) + / strides[i] + 1 + )) if spatial_calc[i] != -1 else -1 + for i in range(len(spatial_shape)) + ] for i in range(len(output_spatial_shape)): if i not in none_dims and output_spatial_shape[i] < 0: raise ValueError( @@ -225,13 +218,16 @@ def compute_conv_output_shape( f"`dilation_rate={dilation_rate}`." ) elif padding == "same" or padding == "causal": - output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1 + output_spatial_shape = [ + int(np.floor((spatial_calc[i] - 1) / strides[i]) + 1) + if spatial_calc[i] != -1 else -1 + for i in range(len(spatial_shape)) + ] else: raise ValueError( "`padding` must be either `'valid'` or `'same'`. Received " f"{padding}." ) - output_spatial_shape = [int(i) for i in output_spatial_shape] for i in none_dims: output_spatial_shape[i] = None output_spatial_shape = tuple(output_spatial_shape)