diff --git a/keras/src/backend/common/backend_utils.py b/keras/src/backend/common/backend_utils.py index fb809c2cc7b2..8b4999b38592 100644 --- a/keras/src/backend/common/backend_utils.py +++ b/keras/src/backend/common/backend_utils.py @@ -1,5 +1,4 @@ import functools -import operator import re import warnings @@ -262,14 +261,13 @@ def compute_conv_transpose_output_shape( def canonicalize_axis(axis, num_dims): """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" - axis = operator.index(axis) if not -num_dims <= axis < num_dims: raise ValueError( f"axis {axis} is out of bounds for an array with dimension " f"{num_dims}." ) if axis < 0: - axis = axis + num_dims + axis += num_dims return axis diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..5f533cba8ab5 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -68,9 +68,10 @@ def compute_expand_dims_output_shape(input_shape, axis): axis = to_tuple_or_list(axis) out_ndim = len(axis) + len(input_shape) axis = [canonicalize_axis(a, out_ndim) for a in axis] + axis_set = set(axis) shape_iter = iter(input_shape) new_shape = [ - 1 if ax in axis else next(shape_iter) for ax in range(out_ndim) + 1 if ax in axis_set else next(shape_iter) for ax in range(out_ndim) ] return tuple(new_shape)