Skip to content

Commit 0ce5359

Browse files
committed
Simplify Subtensor.infer_shape for reversed slices
1 parent 5021a44 commit 0ce5359

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

pytensor/tensor/subtensor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,9 @@ def perform(self, node, inputs, out_):
948948
out[0] = np.asarray(x.__getitem__(cdata))
949949

950950
def infer_shape(self, fgraph, node, shapes):
951+
def _is_constant(const, x):
952+
return isinstance(const, Constant) and const.data.item() == x
953+
951954
xshp = shapes[0]
952955
assert len(xshp) == node.inputs[0].ndim
953956
outshp = []
@@ -961,10 +964,17 @@ def infer_shape(self, fgraph, node, shapes):
961964
# If it is the default (None, None, None) slice, or a variant,
962965
# the shape will be xl
963966
if (
964-
(idx.start in [None, 0])
965-
and (idx.stop in [None, sys.maxsize])
966-
and (idx.step is None or idx.step == 1)
967+
(idx.start is None or _is_constant(idx.start, 0))
968+
and (idx.stop is None or _is_constant(idx.stop, sys.maxsize))
969+
and (idx.step is None or _is_constant(idx.step, 1))
970+
):
971+
outshp.append(xl)
972+
elif (
973+
(idx.start is None)
974+
and (idx.stop is None)
975+
and _is_constant(idx.step, -1)
967976
):
977+
# Reverse slice
968978
outshp.append(xl)
969979
else:
970980
cnf = get_canonical_form_slice(idx, xl)[0]

0 commit comments

Comments
 (0)