Skip to content

Commit c32af30

Browse files
committed
Fix AdvancedSubtensor static shape with newaxis
1 parent 695574b commit c32af30

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

pytensor/tensor/subtensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2630,7 +2630,7 @@ def make_node(self, x, *indices):
26302630
adv_group_axis = None
26312631
last_adv_group_axis = None
26322632
expanded_x_shape = tuple(
2633-
np.insert(np.array(x.type.shape, dtype=object), 1, new_axes)
2633+
np.insert(np.array(x.type.shape, dtype=object), new_axes, values=1)
26342634
)
26352635
for i, (idx, dim_length) in enumerate(
26362636
zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)

tests/tensor/test_subtensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,6 +1856,7 @@ def test_static_shape(self):
18561856

18571857
assert x[idx1].type.shape == (10, None)
18581858
assert x[:, idx1].type.shape == (None, 10)
1859+
assert x[None, :, idx1].type.shape == (1, None, 10)
18591860
assert x[idx2, :5].type.shape == (3, None, None)
18601861
assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
18611862
assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)

0 commit comments

Comments
 (0)