diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..187eff370243 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -29,23 +29,26 @@ def broadcast_shapes(shape1, shape2): origin_shape1 = shape1 origin_shape2 = shape2 - if len(shape1) > len(shape2): - shape2 = [1] * (len(shape1) - len(shape2)) + shape2 - if len(shape1) < len(shape2): - shape1 = [1] * (len(shape2) - len(shape1)) + shape1 - output_shape = list(shape1) - for i in range(len(shape1)): - if shape1[i] == 1: - output_shape[i] = shape2[i] - elif shape1[i] is None: - output_shape[i] = None if shape2[i] == 1 else shape2[i] + len1, len2 = len(shape1), len(shape2) + if len1 > len2: + shape2 = [1] * (len1 - len2) + shape2 + elif len1 < len2: + shape1 = [1] * (len2 - len1) + shape1 + + # Avoid the unnecessary output_shape = list(shape1): just use a new list since most elements may be replaced + output_shape = [] + for dim1, dim2 in zip(shape1, shape2): + if dim1 == 1: + output_shape.append(dim2) + elif dim1 is None: + output_shape.append(None if dim2 == 1 else dim2) else: - if shape2[i] == 1 or shape2[i] is None or shape2[i] == shape1[i]: - output_shape[i] = shape1[i] + if dim2 == 1 or dim2 is None or dim2 == dim1: + output_shape.append(dim1) else: raise ValueError( "Cannot broadcast shape, the failure dim has value " - f"{shape1[i]}, which cannot be broadcasted to {shape2[i]}. " + f"{dim1}, which cannot be broadcasted to {dim2}. " f"Input shapes are: {origin_shape1} and {origin_shape2}." ) @@ -353,9 +356,15 @@ def compute_take_along_axis_output_shape(input_shape, indices_shape, axis): input_shape = list(input_shape) indices_shape = list(indices_shape) if axis is None: - input_shape = ( - [None] if None in input_shape else [int(np.prod(input_shape))] - ) + # Avoid np.prod if there are no None values (np.prod is relatively expensive for short lists) + if None in input_shape: + input_shape = [None] + else: + prod = 1 + for d in input_shape: + prod *= d + input_shape = [int(prod)] + if len(input_shape) != len(indices_shape): raise ValueError(