@@ -2604,28 +2604,48 @@ def make_node(self, x, *inputs):
26042604 inputs = tuple (as_tensor_variable (a ) for a in inputs )
26052605
26062606 idx_list = list (self .idx_list )
2607- if len (idx_list ) > x .type .ndim :
2607+ if ( len ([ entry for entry in idx_list if entry is not np . newaxis ] ) > x .type .ndim ) :
26082608 raise IndexError ("too many indices for array" )
26092609
2610- # Get input types from idx_list - only process numerical indices
2611- input_types = []
2612- input_idx = 0
2610+ # Validate input count matches expected from idx_list
2611+ expected_inputs = get_slice_elements (idx_list , lambda entry : isinstance (entry , Type ))
2612+ if len (inputs ) != len (expected_inputs ):
2613+ raise ValueError (f"Expected { len (expected_inputs )} inputs but got { len (inputs )} " )
2614+
2615+ # Build explicit_indices for shape inference
26132616 explicit_indices = []
26142617 new_axes = []
2618+ input_idx = 0
26152619
26162620 for i , entry in enumerate (idx_list ):
2617- if isinstance (entry , slice ):
2618- # Slices are stored in idx_list, not passed as inputs
2619- explicit_indices .append (entry )
2620- elif entry is np .newaxis :
2621- # Newaxis stored in idx_list, not passed as inputs
2621+ if entry is np .newaxis :
26222622 new_axes .append (len (explicit_indices ))
2623- explicit_indices .append (entry )
2623+ explicit_indices .append (np .newaxis )
2624+ elif isinstance (entry , slice ):
2625+ # Reconstruct slice with actual values from inputs
2626+ if entry .start is not None and isinstance (entry .start , Type ):
2627+ start_val = inputs [input_idx ]
2628+ input_idx += 1
2629+ else :
2630+ start_val = entry .start
2631+
2632+ if entry .stop is not None and isinstance (entry .stop , Type ):
2633+ stop_val = inputs [input_idx ]
2634+ input_idx += 1
2635+ else :
2636+ stop_val = entry .stop
2637+
2638+ if entry .step is not None and isinstance (entry .step , Type ):
2639+ step_val = inputs [input_idx ]
2640+ input_idx += 1
2641+ else :
2642+ step_val = entry .step
2643+
2644+ explicit_indices .append (slice (start_val , stop_val , step_val ))
26242645 elif isinstance (entry , Type ):
2625- # This is a numerical index - should have corresponding input
2626- if input_idx >= len (inputs ):
2627- raise ValueError (f"Missing input for index { i } " )
2646+ # This is a numerical index
26282647 inp = inputs [input_idx ]
2648+ input_idx += 1
26292649
26302650 # Handle boolean indices
26312651 if inp .dtype == "bool" :
@@ -2649,26 +2669,18 @@ def make_node(self, x, *inputs):
26492669 f"boolean index did not match indexed tensor along axis { axis + j } ;"
26502670 f"size of axis is { indexed_length } but size of corresponding boolean axis is { indexer_length } "
26512671 )
2652- # Convert boolean indices to integer with nonzero, to reason about static shape next
2672+ # Convert boolean indices to integer with nonzero
26532673 if isinstance (inp , Constant ):
26542674 nonzero_indices = [tensor_constant (i ) for i in inp .data .nonzero ()]
26552675 else :
2656- # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
2657- # and seeing that other integer indices cannot possible match it
26582676 nonzero_indices = inp .nonzero ()
26592677 explicit_indices .extend (nonzero_indices )
26602678 else :
26612679 # Regular numerical index
26622680 explicit_indices .append (inp )
2663-
2664- input_types .append (entry )
2665- input_idx += 1
26662681 else :
26672682 raise ValueError (f"Invalid entry in idx_list: { entry } " )
26682683
2669- if input_idx != len (inputs ):
2670- raise ValueError (f"Expected { input_idx } inputs but got { len (inputs )} " )
2671-
26722684 if (len (explicit_indices ) - len (new_axes )) > x .type .ndim :
26732685 raise IndexError (
26742686 f"too many indices for array: tensor is { x .type .ndim } -dimensional, but { len (explicit_indices ) - len (new_axes )} were indexed"
@@ -2740,20 +2752,40 @@ def is_bool_index(idx):
27402752 or getattr (idx , "dtype" , None ) == "bool"
27412753 )
27422754
2743- # Reconstruct full index list from idx_list and inputs
2744- indices = node .inputs [1 :]
2755+ # Reconstruct the full indices from idx_list and inputs (like perform method)
2756+ inputs = node .inputs [1 :]
2757+
27452758 full_indices = []
27462759 input_idx = 0
27472760
27482761 for entry in self .idx_list :
2749- if isinstance (entry , slice ):
2750- full_indices .append (entry )
2751- elif entry is np .newaxis :
2752- full_indices .append (entry )
2762+ if entry is np .newaxis :
2763+ full_indices .append (np .newaxis )
2764+ elif isinstance (entry , slice ):
2765+ # Reconstruct slice from idx_list and inputs
2766+ if entry .start is not None and isinstance (entry .start , Type ):
2767+ start_val = inputs [input_idx ]
2768+ input_idx += 1
2769+ else :
2770+ start_val = entry .start
2771+
2772+ if entry .stop is not None and isinstance (entry .stop , Type ):
2773+ stop_val = inputs [input_idx ]
2774+ input_idx += 1
2775+ else :
2776+ stop_val = entry .stop
2777+
2778+ if entry .step is not None and isinstance (entry .step , Type ):
2779+ step_val = inputs [input_idx ]
2780+ input_idx += 1
2781+ else :
2782+ step_val = entry .step
2783+
2784+ full_indices .append (slice (start_val , stop_val , step_val ))
27532785 elif isinstance (entry , Type ):
27542786 # This is a numerical index - get from inputs
2755- if input_idx < len (indices ):
2756- full_indices .append (indices [input_idx ])
2787+ if input_idx < len (inputs ):
2788+ full_indices .append (inputs [input_idx ])
27572789 input_idx += 1
27582790 else :
27592791 raise ValueError ("Mismatch between idx_list and inputs" )
@@ -2771,7 +2803,7 @@ def is_bool_index(idx):
27712803 index_shapes .extend ((shape0_op (nz_dim ),) for nz_dim in nonzero (idx ))
27722804 else :
27732805 # Get ishape for this input
2774- input_shape_idx = indices .index (idx ) + 1 # +1 because ishapes[0] is x
2806+ input_shape_idx = inputs .index (idx ) + 1 # +1 because ishapes[0] is x
27752807 index_shapes .append (ishapes [input_shape_idx ])
27762808 else :
27772809 index_shapes .append (idx )
@@ -2813,10 +2845,29 @@ def perform(self, node, inputs, out_):
28132845 input_idx = 0
28142846
28152847 for entry in self .idx_list :
2816- if isinstance (entry , slice ):
2817- full_indices .append (entry )
2818- elif entry is np .newaxis :
2848+ if entry is np .newaxis :
28192849 full_indices .append (np .newaxis )
2850+ elif isinstance (entry , slice ):
2851+ # Reconstruct slice from idx_list and inputs
2852+ if entry .start is not None and isinstance (entry .start , Type ):
2853+ start_val = tensor_inputs [input_idx ]
2854+ input_idx += 1
2855+ else :
2856+ start_val = entry .start
2857+
2858+ if entry .stop is not None and isinstance (entry .stop , Type ):
2859+ stop_val = tensor_inputs [input_idx ]
2860+ input_idx += 1
2861+ else :
2862+ stop_val = entry .stop
2863+
2864+ if entry .step is not None and isinstance (entry .step , Type ):
2865+ step_val = tensor_inputs [input_idx ]
2866+ input_idx += 1
2867+ else :
2868+ step_val = entry .step
2869+
2870+ full_indices .append (slice (start_val , stop_val , step_val ))
28202871 elif isinstance (entry , Type ):
28212872 # This is a numerical index - get from inputs
28222873 if input_idx < len (tensor_inputs ):
@@ -2989,10 +3040,29 @@ def perform(self, node, inputs, out_):
29893040 input_idx = 0
29903041
29913042 for entry in self .idx_list :
2992- if isinstance (entry , slice ):
2993- full_indices .append (entry )
2994- elif entry is np .newaxis :
3043+ if entry is np .newaxis :
29953044 full_indices .append (np .newaxis )
3045+ elif isinstance (entry , slice ):
3046+ # Reconstruct slice from idx_list and inputs
3047+ if entry .start is not None and isinstance (entry .start , Type ):
3048+ start_val = tensor_inputs [input_idx ]
3049+ input_idx += 1
3050+ else :
3051+ start_val = entry .start
3052+
3053+ if entry .stop is not None and isinstance (entry .stop , Type ):
3054+ stop_val = tensor_inputs [input_idx ]
3055+ input_idx += 1
3056+ else :
3057+ stop_val = entry .stop
3058+
3059+ if entry .step is not None and isinstance (entry .step , Type ):
3060+ step_val = tensor_inputs [input_idx ]
3061+ input_idx += 1
3062+ else :
3063+ step_val = entry .step
3064+
3065+ full_indices .append (slice (start_val , stop_val , step_val ))
29963066 elif isinstance (entry , Type ):
29973067 # This is a numerical index - get from inputs
29983068 if input_idx < len (tensor_inputs ):
@@ -3108,75 +3178,111 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
31083178def advanced_subtensor (x , * args ):
31093179 """Create an AdvancedSubtensor operation.
31103180
3111- This function processes the arguments to separate numerical indices from
3112- slice/newaxis information and creates the appropriate AdvancedSubtensor op .
3181+ This function converts the arguments to work with the new AdvancedSubtensor
3182+ interface that separates slice structure from variable inputs .
31133183 """
3114- # Process args to extract idx_list and numerical inputs
3115- idx_list = []
3116- numerical_inputs = []
3117-
3184+ # Convert raw args to proper form first
3185+ processed_args = []
31183186 for arg in args :
31193187 if arg is None :
3120- idx_list .append (np . newaxis )
3188+ processed_args .append (NoneConst . clone () )
31213189 elif isinstance (arg , slice ):
3122- idx_list .append (arg )
3123- elif isinstance (arg , Variable ) and isinstance (arg .type , SliceType ):
3124- # Convert SliceType variable back to slice - this should be a constant
3190+ processed_args .append (make_slice (arg ))
3191+ else :
3192+ processed_args .append (as_tensor_variable (arg ))
3193+
3194+ # Now create idx_list and extract inputs
3195+ idx_list = []
3196+ input_vars = []
3197+
3198+ for arg in processed_args :
3199+ if isinstance (arg .type , NoneTypeT ):
3200+ idx_list .append (np .newaxis )
3201+ elif isinstance (arg .type , SliceType ):
3202+ # Handle SliceType - extract components and structure
31253203 if isinstance (arg , Constant ):
3204+ # Constant slice
31263205 idx_list .append (arg .data )
31273206 elif arg .owner and isinstance (arg .owner .op , MakeSlice ):
3128- # Convert MakeSlice back to slice
3207+ # Variable slice - extract components
31293208 start , stop , step = arg .owner .inputs
3130- start_val = start .data if isinstance (start , Constant ) else start
3131- stop_val = stop .data if isinstance (stop , Constant ) else stop
3132- step_val = step .data if isinstance (step , Constant ) else step
3133- idx_list .append (slice (start_val , stop_val , step_val ))
3209+
3210+ # Convert components to types for idx_list
3211+ start_type = index_vars_to_types (start , False ) if not isinstance (start .type , NoneTypeT ) else None
3212+ stop_type = index_vars_to_types (stop , False ) if not isinstance (stop .type , NoneTypeT ) else None
3213+ step_type = index_vars_to_types (step , False ) if not isinstance (step .type , NoneTypeT ) else None
3214+
3215+ idx_list .append (slice (start_type , stop_type , step_type ))
3216+
3217+ # Add variable components to inputs
3218+ if not isinstance (start .type , NoneTypeT ):
3219+ input_vars .append (start )
3220+ if not isinstance (stop .type , NoneTypeT ):
3221+ input_vars .append (stop )
3222+ if not isinstance (step .type , NoneTypeT ):
3223+ input_vars .append (step )
31343224 else :
3135- # This is a symbolic slice that we need to handle
3136- # For now, convert to a generic slice - this may need more work
3225+ # Other slice case
31373226 idx_list .append (slice (None ))
3138- elif isinstance (arg , Variable ) and isinstance (arg .type , NoneTypeT ):
3139- idx_list .append (np .newaxis )
31403227 else :
3141- # This is a numerical index (tensor, scalar, etc.)
3142- idx_list .append (index_vars_to_types (as_tensor_variable ( arg ) ))
3143- numerical_inputs .append (as_tensor_variable ( arg ) )
3228+ # Tensor index
3229+ idx_list .append (index_vars_to_types (arg ))
3230+ input_vars .append (arg )
31443231
3145- return AdvancedSubtensor (idx_list ).make_node (x , * numerical_inputs ).outputs [0 ]
3232+ return AdvancedSubtensor (idx_list ).make_node (x , * input_vars ).outputs [0 ]
31463233
31473234
31483235def advanced_inc_subtensor (x , y , * args , ** kwargs ):
31493236 """Create an AdvancedIncSubtensor operation for incrementing."""
3150- # Process args to extract idx_list and numerical inputs
3151- idx_list = []
3152- numerical_inputs = []
3153-
3237+ # Convert raw args to proper form first
3238+ processed_args = []
31543239 for arg in args :
31553240 if arg is None :
3156- idx_list .append (np . newaxis )
3241+ processed_args .append (NoneConst . clone () )
31573242 elif isinstance (arg , slice ):
3158- idx_list .append (arg )
3159- elif isinstance (arg , Variable ) and isinstance (arg .type , SliceType ):
3160- # Convert SliceType variable back to slice
3243+ processed_args .append (make_slice (arg ))
3244+ else :
3245+ processed_args .append (as_tensor_variable (arg ))
3246+
3247+ # Now create idx_list and extract inputs
3248+ idx_list = []
3249+ input_vars = []
3250+
3251+ for arg in processed_args :
3252+ if isinstance (arg .type , NoneTypeT ):
3253+ idx_list .append (np .newaxis )
3254+ elif isinstance (arg .type , SliceType ):
3255+ # Handle SliceType - extract components and structure
31613256 if isinstance (arg , Constant ):
3257+ # Constant slice
31623258 idx_list .append (arg .data )
31633259 elif arg .owner and isinstance (arg .owner .op , MakeSlice ):
3164- # Convert MakeSlice back to slice
3260+ # Variable slice - extract components
31653261 start , stop , step = arg .owner .inputs
3166- start_val = start .data if isinstance (start , Constant ) else start
3167- stop_val = stop .data if isinstance (stop , Constant ) else stop
3168- step_val = step .data if isinstance (step , Constant ) else step
3169- idx_list .append (slice (start_val , stop_val , step_val ))
3262+
3263+ # Convert components to types for idx_list
3264+ start_type = index_vars_to_types (start , False ) if not isinstance (start .type , NoneTypeT ) else None
3265+ stop_type = index_vars_to_types (stop , False ) if not isinstance (stop .type , NoneTypeT ) else None
3266+ step_type = index_vars_to_types (step , False ) if not isinstance (step .type , NoneTypeT ) else None
3267+
3268+ idx_list .append (slice (start_type , stop_type , step_type ))
3269+
3270+ # Add variable components to inputs
3271+ if not isinstance (start .type , NoneTypeT ):
3272+ input_vars .append (start )
3273+ if not isinstance (stop .type , NoneTypeT ):
3274+ input_vars .append (stop )
3275+ if not isinstance (step .type , NoneTypeT ):
3276+ input_vars .append (step )
31703277 else :
3278+ # Other slice case
31713279 idx_list .append (slice (None ))
3172- elif isinstance (arg , Variable ) and isinstance (arg .type , NoneTypeT ):
3173- idx_list .append (np .newaxis )
31743280 else :
3175- # This is a numerical index
3176- idx_list .append (index_vars_to_types (as_tensor_variable ( arg ) ))
3177- numerical_inputs .append (as_tensor_variable ( arg ) )
3281+ # Tensor index
3282+ idx_list .append (index_vars_to_types (arg ))
3283+ input_vars .append (arg )
31783284
3179- return AdvancedIncSubtensor (idx_list , ** kwargs ).make_node (x , y , * numerical_inputs ).outputs [0 ]
3285+ return AdvancedIncSubtensor (idx_list , ** kwargs ).make_node (x , y , * input_vars ).outputs [0 ]
31803286
31813287
31823288def advanced_set_subtensor (x , y , * args , ** kwargs ):
0 commit comments