@@ -919,24 +919,23 @@ def store_reduction_epilogue(self, name, index, value):
919919 compute_index_var = "," .join (partial_zero_var_list )
920920
921921 with self .override_buffer_cse (buffer = self .reductions_suffix ):
922- out = ops ._load (partial_vec_size , mlir_dtype , sram_var , compute_index_var , partial_tile_shape )
922+ out = ops ._load (partial_vec_size , mlir_dtype , value , compute_index_var , partial_tile_shape )
923923 ops ._store (init_vec , value , compute_index_var , partial_tile_shape ) # Clear the partial buffer to zero
924924
925925 # 2 step reduction
926926 new_vec_size = 2
927- new_reduced_shape = f"<{ new_vec_size } x{ mlir_dtype } >"
927+ new_reduced_shape = f"vector <{ new_vec_size } x{ mlir_dtype } >"
928928 reduction_type = self .reduction_info [value ][0 ]
929- out = ops .multi_reduction (out , init_vec , partial_vec_size , new_vec_size , reduction_type , partial_vshape , self . reduction_info [ value ][ 0 ] , mlir_dtype )
929+ out = ops .multi_reduction (out , init_vec2 , partial_vec_size , new_vec_size , partial_vshape , reduction_type , mlir_dtype )
930930
931931 out2 = self .cse .generate (self .reductions_suffix , f"vector.shuffle %{ out } , %{ out } [1, 0] : { new_reduced_shape } , { new_reduced_shape } " )
932932 self .register_var_info (out2 , [new_vec_size , mlir_dtype ])
933933
934934 with self .override_buffer_cse (buffer = self .reductions_suffix ):
935935 out = reduction_partial_combine_vec (self .reduction_info [value ][0 ], out , out2 )
936936
937- if self .welford_reduce_out is not None :
938- # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2
939- with self .override_buffer_cse (buffer = self .reductions_suffix ):
937+ if self .welford_reduce_out is not None :
938+ # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2
940939 divider = ops .constant (float (self .reduction_axis_size ), "f32" )
941940 if self .buffer_types [name ][1 ] > 1 :
942941 divider_vec = ops .broadcast (divider , new_vec_size )
@@ -955,9 +954,9 @@ def store_reduction_epilogue(self, name, index, value):
955954 m2 = ops .mul (variance , divider_vec )
956955 out = m2
957956
958- final_zero_var_list [- 1 ] = f"%{ body_index_var } "
959- final_compute_index_var = "," .join (final_zero_var_list )
960- ops ._store (out , sram_var , final_compute_index_var , final_tile_shape , buffer_name = name )
957+ final_zero_var_list [- 1 ] = f"%{ body_index_var } "
958+ final_compute_index_var = "," .join (final_zero_var_list )
959+ ops ._store (out , sram_var , final_compute_index_var , final_tile_shape , buffer_name = name )
961960
962961 # MVOUT Encoding
963962 # Generate DMA instruction
0 commit comments