Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions keras/src/ops/operation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,26 @@ def broadcast_shapes(shape1, shape2):
origin_shape1 = shape1
origin_shape2 = shape2

if len(shape1) > len(shape2):
shape2 = [1] * (len(shape1) - len(shape2)) + shape2
if len(shape1) < len(shape2):
shape1 = [1] * (len(shape2) - len(shape1)) + shape1
output_shape = list(shape1)
for i in range(len(shape1)):
if shape1[i] == 1:
output_shape[i] = shape2[i]
elif shape1[i] is None:
output_shape[i] = None if shape2[i] == 1 else shape2[i]
len1, len2 = len(shape1), len(shape2)
if len1 > len2:
shape2 = [1] * (len1 - len2) + shape2
elif len1 < len2:
shape1 = [1] * (len2 - len1) + shape1

# Avoid the unnecessary output_shape = list(shape1): just use a new list since most elements may be replaced
output_shape = []
for dim1, dim2 in zip(shape1, shape2):
if dim1 == 1:
output_shape.append(dim2)
elif dim1 is None:
output_shape.append(None if dim2 == 1 else dim2)
else:
if shape2[i] == 1 or shape2[i] is None or shape2[i] == shape1[i]:
output_shape[i] = shape1[i]
if dim2 == 1 or dim2 is None or dim2 == dim1:
output_shape.append(dim1)
else:
raise ValueError(
"Cannot broadcast shape, the failure dim has value "
f"{shape1[i]}, which cannot be broadcasted to {shape2[i]}. "
f"{dim1}, which cannot be broadcasted to {dim2}. "
f"Input shapes are: {origin_shape1} and {origin_shape2}."
)

Expand Down Expand Up @@ -353,9 +356,15 @@ def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):
input_shape = list(input_shape)
indices_shape = list(indices_shape)
if axis is None:
input_shape = (
[None] if None in input_shape else [int(np.prod(input_shape))]
)
# Avoid np.prod if there are no None values (np.prod is relatively expensive for short lists)
if None in input_shape:
input_shape = [None]
else:
prod = 1
for d in input_shape:
prod *= d
input_shape = [int(prod)]


if len(input_shape) != len(indices_shape):
raise ValueError(
Expand Down