@@ -948,6 +948,9 @@ def perform(self, node, inputs, out_):
948
948
out [0 ] = np .asarray (x .__getitem__ (cdata ))
949
949
950
950
def infer_shape (self , fgraph , node , shapes ):
951
+ def _is_constant (const , x ):
952
+ return isinstance (const , Constant ) and const .data .item () == x
953
+
951
954
xshp = shapes [0 ]
952
955
assert len (xshp ) == node .inputs [0 ].ndim
953
956
outshp = []
@@ -961,10 +964,17 @@ def infer_shape(self, fgraph, node, shapes):
961
964
# If it is the default (None, None, None) slice, or a variant,
962
965
# the shape will be xl
963
966
if (
964
- (idx .start in [None , 0 ])
965
- and (idx .stop in [None , sys .maxsize ])
966
- and (idx .step is None or idx .step == 1 )
967
+ (idx .start is None or _is_constant (idx .start , 0 ))
968
+ and (idx .stop is None or _is_constant (idx .stop , sys .maxsize ))
969
+ and (idx .step is None or _is_constant (idx .step , 1 ))
970
+ ):
971
+ outshp .append (xl )
972
+ elif (
973
+ (idx .start is None )
974
+ and (idx .stop is None )
975
+ and _is_constant (idx .step , - 1 )
967
976
):
977
+ # Reverse slice
968
978
outshp .append (xl )
969
979
else :
970
980
cnf = get_canonical_form_slice (idx , xl )[0 ]
0 commit comments