diff --git a/keras/src/backend/common/backend_utils.py b/keras/src/backend/common/backend_utils.py index fb809c2cc7b2..e5f051b186f5 100644 --- a/keras/src/backend/common/backend_utils.py +++ b/keras/src/backend/common/backend_utils.py @@ -262,14 +262,16 @@ def compute_conv_transpose_output_shape( def canonicalize_axis(axis, num_dims): """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" - axis = operator.index(axis) + # Removed operator.index() as it is redundant if axis is already int (which it is, from reduce_shape's usage) + if not isinstance(axis, int): + axis = operator.index(axis) if not -num_dims <= axis < num_dims: raise ValueError( f"axis {axis} is out of bounds for an array with dimension " f"{num_dims}." ) if axis < 0: - axis = axis + num_dims + axis += num_dims return axis diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..8944852beac4 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -369,23 +369,28 @@ def compute_take_along_axis_output_shape(input_shape, indices_shape, axis): def reduce_shape(shape, axis=None, keepdims=False): - shape = list(shape) + len_shape = len(shape) + if axis is None: if keepdims: - return tuple([1 for _ in shape]) + return (1,) * len_shape else: - return tuple([]) + return () elif isinstance(axis, int): axis = (axis,) - axis = tuple(canonicalize_axis(a, len(shape)) for a in axis) + # Tuple here triggers generator expression once, not multiple times as in usage below. + canonical_axes = tuple(canonicalize_axis(a, len_shape) for a in axis) + + shape = list(shape) # Only create a list when axis is not None! if keepdims: - for ax in axis: + for ax in canonical_axes: shape[ax] = 1 return tuple(shape) else: - for ax in sorted(axis, reverse=True): + # Use reverse sorted to avoid index shifts + for ax in sorted(canonical_axes, reverse=True): del shape[ax] return tuple(shape)