Skip to content

Commit c18b322

Browse files
CopilotricardoV94
authored andcommitted
Complete refactoring with improved factory functions and proper slice handling
Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com>
1 parent 3cfbd0d commit c18b322

File tree

1 file changed

+185
-79
lines changed

1 file changed

+185
-79
lines changed

pytensor/tensor/subtensor.py

Lines changed: 185 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2604,28 +2604,48 @@ def make_node(self, x, *inputs):
26042604
inputs = tuple(as_tensor_variable(a) for a in inputs)
26052605

26062606
idx_list = list(self.idx_list)
2607-
if len(idx_list) > x.type.ndim:
2607+
if (len([entry for entry in idx_list if entry is not np.newaxis]) > x.type.ndim):
26082608
raise IndexError("too many indices for array")
26092609

2610-
# Get input types from idx_list - only process numerical indices
2611-
input_types = []
2612-
input_idx = 0
2610+
# Validate input count matches expected from idx_list
2611+
expected_inputs = get_slice_elements(idx_list, lambda entry: isinstance(entry, Type))
2612+
if len(inputs) != len(expected_inputs):
2613+
raise ValueError(f"Expected {len(expected_inputs)} inputs but got {len(inputs)}")
2614+
2615+
# Build explicit_indices for shape inference
26132616
explicit_indices = []
26142617
new_axes = []
2618+
input_idx = 0
26152619

26162620
for i, entry in enumerate(idx_list):
2617-
if isinstance(entry, slice):
2618-
# Slices are stored in idx_list, not passed as inputs
2619-
explicit_indices.append(entry)
2620-
elif entry is np.newaxis:
2621-
# Newaxis stored in idx_list, not passed as inputs
2621+
if entry is np.newaxis:
26222622
new_axes.append(len(explicit_indices))
2623-
explicit_indices.append(entry)
2623+
explicit_indices.append(np.newaxis)
2624+
elif isinstance(entry, slice):
2625+
# Reconstruct slice with actual values from inputs
2626+
if entry.start is not None and isinstance(entry.start, Type):
2627+
start_val = inputs[input_idx]
2628+
input_idx += 1
2629+
else:
2630+
start_val = entry.start
2631+
2632+
if entry.stop is not None and isinstance(entry.stop, Type):
2633+
stop_val = inputs[input_idx]
2634+
input_idx += 1
2635+
else:
2636+
stop_val = entry.stop
2637+
2638+
if entry.step is not None and isinstance(entry.step, Type):
2639+
step_val = inputs[input_idx]
2640+
input_idx += 1
2641+
else:
2642+
step_val = entry.step
2643+
2644+
explicit_indices.append(slice(start_val, stop_val, step_val))
26242645
elif isinstance(entry, Type):
2625-
# This is a numerical index - should have corresponding input
2626-
if input_idx >= len(inputs):
2627-
raise ValueError(f"Missing input for index {i}")
2646+
# This is a numerical index
26282647
inp = inputs[input_idx]
2648+
input_idx += 1
26292649

26302650
# Handle boolean indices
26312651
if inp.dtype == "bool":
@@ -2649,26 +2669,18 @@ def make_node(self, x, *inputs):
26492669
f"boolean index did not match indexed tensor along axis {axis + j};"
26502670
f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
26512671
)
2652-
# Convert boolean indices to integer with nonzero, to reason about static shape next
2672+
# Convert boolean indices to integer with nonzero
26532673
if isinstance(inp, Constant):
26542674
nonzero_indices = [tensor_constant(i) for i in inp.data.nonzero()]
26552675
else:
2656-
# Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
2657-
# and seeing that other integer indices cannot possible match it
26582676
nonzero_indices = inp.nonzero()
26592677
explicit_indices.extend(nonzero_indices)
26602678
else:
26612679
# Regular numerical index
26622680
explicit_indices.append(inp)
2663-
2664-
input_types.append(entry)
2665-
input_idx += 1
26662681
else:
26672682
raise ValueError(f"Invalid entry in idx_list: {entry}")
26682683

2669-
if input_idx != len(inputs):
2670-
raise ValueError(f"Expected {input_idx} inputs but got {len(inputs)}")
2671-
26722684
if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
26732685
raise IndexError(
26742686
f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
@@ -2740,20 +2752,40 @@ def is_bool_index(idx):
27402752
or getattr(idx, "dtype", None) == "bool"
27412753
)
27422754

2743-
# Reconstruct full index list from idx_list and inputs
2744-
indices = node.inputs[1:]
2755+
# Reconstruct the full indices from idx_list and inputs (like perform method)
2756+
inputs = node.inputs[1:]
2757+
27452758
full_indices = []
27462759
input_idx = 0
27472760

27482761
for entry in self.idx_list:
2749-
if isinstance(entry, slice):
2750-
full_indices.append(entry)
2751-
elif entry is np.newaxis:
2752-
full_indices.append(entry)
2762+
if entry is np.newaxis:
2763+
full_indices.append(np.newaxis)
2764+
elif isinstance(entry, slice):
2765+
# Reconstruct slice from idx_list and inputs
2766+
if entry.start is not None and isinstance(entry.start, Type):
2767+
start_val = inputs[input_idx]
2768+
input_idx += 1
2769+
else:
2770+
start_val = entry.start
2771+
2772+
if entry.stop is not None and isinstance(entry.stop, Type):
2773+
stop_val = inputs[input_idx]
2774+
input_idx += 1
2775+
else:
2776+
stop_val = entry.stop
2777+
2778+
if entry.step is not None and isinstance(entry.step, Type):
2779+
step_val = inputs[input_idx]
2780+
input_idx += 1
2781+
else:
2782+
step_val = entry.step
2783+
2784+
full_indices.append(slice(start_val, stop_val, step_val))
27532785
elif isinstance(entry, Type):
27542786
# This is a numerical index - get from inputs
2755-
if input_idx < len(indices):
2756-
full_indices.append(indices[input_idx])
2787+
if input_idx < len(inputs):
2788+
full_indices.append(inputs[input_idx])
27572789
input_idx += 1
27582790
else:
27592791
raise ValueError("Mismatch between idx_list and inputs")
@@ -2771,7 +2803,7 @@ def is_bool_index(idx):
27712803
index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
27722804
else:
27732805
# Get ishape for this input
2774-
input_shape_idx = indices.index(idx) + 1 # +1 because ishapes[0] is x
2806+
input_shape_idx = inputs.index(idx) + 1 # +1 because ishapes[0] is x
27752807
index_shapes.append(ishapes[input_shape_idx])
27762808
else:
27772809
index_shapes.append(idx)
@@ -2813,10 +2845,29 @@ def perform(self, node, inputs, out_):
28132845
input_idx = 0
28142846

28152847
for entry in self.idx_list:
2816-
if isinstance(entry, slice):
2817-
full_indices.append(entry)
2818-
elif entry is np.newaxis:
2848+
if entry is np.newaxis:
28192849
full_indices.append(np.newaxis)
2850+
elif isinstance(entry, slice):
2851+
# Reconstruct slice from idx_list and inputs
2852+
if entry.start is not None and isinstance(entry.start, Type):
2853+
start_val = tensor_inputs[input_idx]
2854+
input_idx += 1
2855+
else:
2856+
start_val = entry.start
2857+
2858+
if entry.stop is not None and isinstance(entry.stop, Type):
2859+
stop_val = tensor_inputs[input_idx]
2860+
input_idx += 1
2861+
else:
2862+
stop_val = entry.stop
2863+
2864+
if entry.step is not None and isinstance(entry.step, Type):
2865+
step_val = tensor_inputs[input_idx]
2866+
input_idx += 1
2867+
else:
2868+
step_val = entry.step
2869+
2870+
full_indices.append(slice(start_val, stop_val, step_val))
28202871
elif isinstance(entry, Type):
28212872
# This is a numerical index - get from inputs
28222873
if input_idx < len(tensor_inputs):
@@ -2989,10 +3040,29 @@ def perform(self, node, inputs, out_):
29893040
input_idx = 0
29903041

29913042
for entry in self.idx_list:
2992-
if isinstance(entry, slice):
2993-
full_indices.append(entry)
2994-
elif entry is np.newaxis:
3043+
if entry is np.newaxis:
29953044
full_indices.append(np.newaxis)
3045+
elif isinstance(entry, slice):
3046+
# Reconstruct slice from idx_list and inputs
3047+
if entry.start is not None and isinstance(entry.start, Type):
3048+
start_val = tensor_inputs[input_idx]
3049+
input_idx += 1
3050+
else:
3051+
start_val = entry.start
3052+
3053+
if entry.stop is not None and isinstance(entry.stop, Type):
3054+
stop_val = tensor_inputs[input_idx]
3055+
input_idx += 1
3056+
else:
3057+
stop_val = entry.stop
3058+
3059+
if entry.step is not None and isinstance(entry.step, Type):
3060+
step_val = tensor_inputs[input_idx]
3061+
input_idx += 1
3062+
else:
3063+
step_val = entry.step
3064+
3065+
full_indices.append(slice(start_val, stop_val, step_val))
29963066
elif isinstance(entry, Type):
29973067
# This is a numerical index - get from inputs
29983068
if input_idx < len(tensor_inputs):
@@ -3108,75 +3178,111 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
31083178
def advanced_subtensor(x, *args):
31093179
"""Create an AdvancedSubtensor operation.
31103180
3111-
This function processes the arguments to separate numerical indices from
3112-
slice/newaxis information and creates the appropriate AdvancedSubtensor op.
3181+
This function converts the arguments to work with the new AdvancedSubtensor
3182+
interface that separates slice structure from variable inputs.
31133183
"""
3114-
# Process args to extract idx_list and numerical inputs
3115-
idx_list = []
3116-
numerical_inputs = []
3117-
3184+
# Convert raw args to proper form first
3185+
processed_args = []
31183186
for arg in args:
31193187
if arg is None:
3120-
idx_list.append(np.newaxis)
3188+
processed_args.append(NoneConst.clone())
31213189
elif isinstance(arg, slice):
3122-
idx_list.append(arg)
3123-
elif isinstance(arg, Variable) and isinstance(arg.type, SliceType):
3124-
# Convert SliceType variable back to slice - this should be a constant
3190+
processed_args.append(make_slice(arg))
3191+
else:
3192+
processed_args.append(as_tensor_variable(arg))
3193+
3194+
# Now create idx_list and extract inputs
3195+
idx_list = []
3196+
input_vars = []
3197+
3198+
for arg in processed_args:
3199+
if isinstance(arg.type, NoneTypeT):
3200+
idx_list.append(np.newaxis)
3201+
elif isinstance(arg.type, SliceType):
3202+
# Handle SliceType - extract components and structure
31253203
if isinstance(arg, Constant):
3204+
# Constant slice
31263205
idx_list.append(arg.data)
31273206
elif arg.owner and isinstance(arg.owner.op, MakeSlice):
3128-
# Convert MakeSlice back to slice
3207+
# Variable slice - extract components
31293208
start, stop, step = arg.owner.inputs
3130-
start_val = start.data if isinstance(start, Constant) else start
3131-
stop_val = stop.data if isinstance(stop, Constant) else stop
3132-
step_val = step.data if isinstance(step, Constant) else step
3133-
idx_list.append(slice(start_val, stop_val, step_val))
3209+
3210+
# Convert components to types for idx_list
3211+
start_type = index_vars_to_types(start, False) if not isinstance(start.type, NoneTypeT) else None
3212+
stop_type = index_vars_to_types(stop, False) if not isinstance(stop.type, NoneTypeT) else None
3213+
step_type = index_vars_to_types(step, False) if not isinstance(step.type, NoneTypeT) else None
3214+
3215+
idx_list.append(slice(start_type, stop_type, step_type))
3216+
3217+
# Add variable components to inputs
3218+
if not isinstance(start.type, NoneTypeT):
3219+
input_vars.append(start)
3220+
if not isinstance(stop.type, NoneTypeT):
3221+
input_vars.append(stop)
3222+
if not isinstance(step.type, NoneTypeT):
3223+
input_vars.append(step)
31343224
else:
3135-
# This is a symbolic slice that we need to handle
3136-
# For now, convert to a generic slice - this may need more work
3225+
# Other slice case
31373226
idx_list.append(slice(None))
3138-
elif isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT):
3139-
idx_list.append(np.newaxis)
31403227
else:
3141-
# This is a numerical index (tensor, scalar, etc.)
3142-
idx_list.append(index_vars_to_types(as_tensor_variable(arg)))
3143-
numerical_inputs.append(as_tensor_variable(arg))
3228+
# Tensor index
3229+
idx_list.append(index_vars_to_types(arg))
3230+
input_vars.append(arg)
31443231

3145-
return AdvancedSubtensor(idx_list).make_node(x, *numerical_inputs).outputs[0]
3232+
return AdvancedSubtensor(idx_list).make_node(x, *input_vars).outputs[0]
31463233

31473234

31483235
def advanced_inc_subtensor(x, y, *args, **kwargs):
31493236
"""Create an AdvancedIncSubtensor operation for incrementing."""
3150-
# Process args to extract idx_list and numerical inputs
3151-
idx_list = []
3152-
numerical_inputs = []
3153-
3237+
# Convert raw args to proper form first
3238+
processed_args = []
31543239
for arg in args:
31553240
if arg is None:
3156-
idx_list.append(np.newaxis)
3241+
processed_args.append(NoneConst.clone())
31573242
elif isinstance(arg, slice):
3158-
idx_list.append(arg)
3159-
elif isinstance(arg, Variable) and isinstance(arg.type, SliceType):
3160-
# Convert SliceType variable back to slice
3243+
processed_args.append(make_slice(arg))
3244+
else:
3245+
processed_args.append(as_tensor_variable(arg))
3246+
3247+
# Now create idx_list and extract inputs
3248+
idx_list = []
3249+
input_vars = []
3250+
3251+
for arg in processed_args:
3252+
if isinstance(arg.type, NoneTypeT):
3253+
idx_list.append(np.newaxis)
3254+
elif isinstance(arg.type, SliceType):
3255+
# Handle SliceType - extract components and structure
31613256
if isinstance(arg, Constant):
3257+
# Constant slice
31623258
idx_list.append(arg.data)
31633259
elif arg.owner and isinstance(arg.owner.op, MakeSlice):
3164-
# Convert MakeSlice back to slice
3260+
# Variable slice - extract components
31653261
start, stop, step = arg.owner.inputs
3166-
start_val = start.data if isinstance(start, Constant) else start
3167-
stop_val = stop.data if isinstance(stop, Constant) else stop
3168-
step_val = step.data if isinstance(step, Constant) else step
3169-
idx_list.append(slice(start_val, stop_val, step_val))
3262+
3263+
# Convert components to types for idx_list
3264+
start_type = index_vars_to_types(start, False) if not isinstance(start.type, NoneTypeT) else None
3265+
stop_type = index_vars_to_types(stop, False) if not isinstance(stop.type, NoneTypeT) else None
3266+
step_type = index_vars_to_types(step, False) if not isinstance(step.type, NoneTypeT) else None
3267+
3268+
idx_list.append(slice(start_type, stop_type, step_type))
3269+
3270+
# Add variable components to inputs
3271+
if not isinstance(start.type, NoneTypeT):
3272+
input_vars.append(start)
3273+
if not isinstance(stop.type, NoneTypeT):
3274+
input_vars.append(stop)
3275+
if not isinstance(step.type, NoneTypeT):
3276+
input_vars.append(step)
31703277
else:
3278+
# Other slice case
31713279
idx_list.append(slice(None))
3172-
elif isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT):
3173-
idx_list.append(np.newaxis)
31743280
else:
3175-
# This is a numerical index
3176-
idx_list.append(index_vars_to_types(as_tensor_variable(arg)))
3177-
numerical_inputs.append(as_tensor_variable(arg))
3281+
# Tensor index
3282+
idx_list.append(index_vars_to_types(arg))
3283+
input_vars.append(arg)
31783284

3179-
return AdvancedIncSubtensor(idx_list, **kwargs).make_node(x, y, *numerical_inputs).outputs[0]
3285+
return AdvancedIncSubtensor(idx_list, **kwargs).make_node(x, y, *input_vars).outputs[0]
31803286

31813287

31823288
def advanced_set_subtensor(x, y, *args, **kwargs):

0 commit comments

Comments
 (0)