@@ -41,47 +41,46 @@ void softmax_cpu(const SoftmaxInfo &info,
4141 int dimsize = info.dimsize ;
4242 int stride = info.stride ;
4343 int othersize = info.otherdim_size ;
44- if constexpr (std::is_same_v<T, fp16_t >) {
45- auto input = reinterpret_cast <const fp16_t *>(x);
46- auto output = reinterpret_cast <fp16_t *>(y);
47- for (int i = 0 ; i < othersize; i++) {
48- int tid = i % stride + (i - i % stride) * dimsize;
49- float max_data = -INFINITY;
50- for (int j = 0 ; j < dimsize; j++) {
51- int index = tid + j * stride;
52- max_data = fmax (max_data, utils::cast<float >(input[index]));
53- }
54- float sum_data = 0 .0f ;
55- for (int j = 0 ; j < dimsize; j++) {
56- int index = tid + j * stride;
57- sum_data += std::exp (utils::cast<float >(input[index]) - max_data);
58- }
59- for (int j = 0 ; j < dimsize; j++) {
60- int index = tid + j * stride;
61- output[index] = utils::cast<fp16_t >(std::exp (utils::cast<float >(input[index]) - max_data) / sum_data);
62- }
44+ auto to_float = [](const T &val) -> float {
45+ if constexpr (std::is_same_v<T, fp16_t >) {
46+ return utils::cast<float >(val);
47+ } else {
48+ return val;
6349 }
64- } else if constexpr (std::is_same_v<T, float >) {
65- auto input = reinterpret_cast <const float *>(x);
66- auto output = reinterpret_cast <float *>(y);
67- #pragma omp parallel for
68- for (int i = 0 ; i < othersize; i++) {
69- int tid = i % stride + (i - i % stride) * dimsize;
70- float max_data = -INFINITY;
71- for (int j = 0 ; j < dimsize; j++) {
72- int index = tid + j * stride;
73- max_data = fmax (max_data, input[index]);
74- }
75- float sum_data = 0 .0f ;
76- for (int j = 0 ; j < dimsize; j++) {
77- int index = tid + j * stride;
78- sum_data += std::exp (input[index] - max_data);
79- }
80- for (int j = 0 ; j < dimsize; j++) {
81- int index = tid + j * stride;
82- output[index] = std::exp (input[index] - max_data) / sum_data;
83- }
50+ };
51+
52+ auto from_float = [](float val) -> T {
53+ if constexpr (std::is_same_v<T, fp16_t >) {
54+ return utils::cast<fp16_t >(val);
55+ } else {
56+ return val;
57+ }
58+ };
59+
60+ auto input = reinterpret_cast <const T *>(x);
61+ auto output = reinterpret_cast <T *>(y);
62+
63+ auto compute_softmax = [&](int i) {
64+ int tid = i % stride + (i - i % stride) * dimsize;
65+ float max_data = -INFINITY;
66+ for (int j = 0 ; j < dimsize; j++) {
67+ int index = tid + j * stride;
68+ max_data = fmax (max_data, to_float (input[index]));
69+ }
70+ float sum_data = 0 .0f ;
71+ for (int j = 0 ; j < dimsize; j++) {
72+ int index = tid + j * stride;
73+ sum_data += std::exp (to_float (input[index]) - max_data);
8474 }
75+ for (int j = 0 ; j < dimsize; j++) {
76+ int index = tid + j * stride;
77+ float result = std::exp (to_float (input[index]) - max_data) / sum_data;
78+ output[index] = from_float (result);
79+ }
80+ };
81+ #pragma omp parallel for
82+ for (int i = 0 ; i < othersize; i++) {
83+ compute_softmax (i);
8584 }
8685}
8786
0 commit comments