@@ -569,14 +569,11 @@ struct ReduceOp {
569569 using load_t = at::native::memory::aligned_vector<scalar_t , output_vec_size>;
570570
571571 // Multiple accumulators to remove dependency between unrolled loops.
572- arg_vec_t value_list[vt0] ;
572+ arg_vec_t value_list;
573573
574574 #pragma unroll
575- for (int i = 0 ; i < vt0; i++) {
576- #pragma unroll
577- for (int j = 0 ; j < output_vec_size; j++) {
578- value_list[i][j] = ident;
579- }
575+ for (int j = 0 ; j < output_vec_size; j++) {
576+ value_list[j] = ident;
580577 }
581578
582579 load_t values[vt0];
@@ -591,7 +588,7 @@ struct ReduceOp {
591588 for (index_t i = 0 ; i < vt0; i++) {
592589 #pragma unroll
593590 for (index_t j = 0 ; j < output_vec_size; j++) {
594- value_list[i][ j] = ops.reduce (value_list[i] [j], values[i].val [j], idx + i * stride);
591+ value_list[j] = ops.reduce (value_list[j], values[i].val [j], idx + i * stride);
595592 }
596593 }
597594 idx += stride * vt0;
@@ -616,20 +613,12 @@ struct ReduceOp {
616613 }
617614 #pragma unroll
618615 for (index_t j = 0 ; j < output_vec_size; j++) {
619- value_list[i][ j] = ops.reduce (value_list[i] [j], values[i].val [j], idx);
616+ value_list[j] = ops.reduce (value_list[j], values[i].val [j], idx);
620617 }
621618 idx += stride;
622619 }
623620
624- // combine accumulators
625- #pragma unroll
626- for (int i = 1 ; i < vt0; i++) {
627- #pragma unroll
628- for (index_t j = 0 ; j < output_vec_size; j++) {
629- value_list[0 ][j] = ops.combine (value_list[0 ][j], value_list[i][j]);
630- }
631- }
632- return value_list[0 ];
621+ return value_list;
633622 }
634623
635624 template <int output_vec_size>
0 commit comments