@@ -102,6 +102,25 @@ def codegen_store(
102102 ) -> ast .AST :
103103 indexing = SubscriptIndexing .create (state , fake_tensor , subscript , extra_mask )
104104 name = state .device_function .tensor_arg (fake_tensor ).name
105+
106+ # Check if value is a tensor load (Name node with id matching a tensor arg)
107+ if isinstance (value , ast .Name ) and hasattr (state .device_function , '_tensor_args' ):
108+ # Check if this name corresponds to a tensor argument
109+ for tensor , tensor_arg in state .device_function ._tensor_args .items ():
110+ if tensor_arg .name == value .id :
111+ # This is a tensor value, we need to load from it
112+ # Get the shape of the slice we're storing to
113+ output_shape = SubscriptIndexing .compute_shape (fake_tensor , subscript )
114+ if len (output_shape ) == 1 and tensor .ndim == 1 :
115+ # Load the entire 1D tensor
116+ value_indexing = SubscriptIndexing .create (state , tensor , [slice (None )], None )
117+ value = expr_from_string (
118+ f"tl.load({ value .id } + offset, mask)" ,
119+ offset = value_indexing .index_expr ,
120+ mask = value_indexing .mask_expr ,
121+ )
122+ break
123+
105124 return expr_from_string (
106125 f"tl.store({ name } + offset, value, mask)" ,
107126 value = value ,
@@ -511,7 +530,14 @@ def compute_shape(
511530 output_size .extend (k .size ())
512531 else :
513532 raise exc .InvalidIndexingType (k )
514- assert len (input_size ) == 0 , "invalid subscript"
533+ # For partial indexing, append remaining dimensions to output
534+ while input_size :
535+ size = input_size .popleft ()
536+ if size != 1 :
537+ rdim = env .allocate_reduction_dimension (size )
538+ output_size .append (rdim .var )
539+ else :
540+ output_size .append (1 )
515541 return output_size
516542
517543 @staticmethod
@@ -648,6 +674,22 @@ def create(
648674 )
649675 else :
650676 raise exc .InvalidIndexingType (type (k ))
677+
678+ # Handle remaining dimensions for partial indexing
679+ while len (index_values ) < fake_value .ndim :
680+ expand = tile_strategy .expand_str (output_size , output_idx )
681+ size = fake_value .size (len (index_values ))
682+ if size != 1 :
683+ rdim = env .allocate_reduction_dimension (size )
684+ block_idx = rdim .block_id
685+ index_var = state .codegen .index_var (block_idx )
686+ index_values .append (f"({ index_var } ){ expand } " )
687+ if mask := state .codegen .mask_var (block_idx ):
688+ mask_values .setdefault (f"({ mask } ){ expand } " )
689+ else :
690+ index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
691+ output_idx += 1
692+
651693 assert len (output_size ) == output_idx
652694 assert len (index_values ) == fake_value .ndim
653695 index_expr = []
0 commit comments