
Commit e843a51

Issue/259: abstract the softmax_cpu computation to reduce redundancy

1 parent 7583cea commit e843a51

6 files changed (+47, -48)

include/infiniop/ops/softmax.h

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
-#ifndef __INFINIOP_MLP_API_H__
-#define __INFINIOP_MLP_API_H__
+#ifndef __INFINIOP_SOFTMAX_API_H__
+#define __INFINIOP_SOFTMAX_API_H__
 
 #include "../operator_descriptor.h"

src/infiniop/ops/softmax/cpu/softmax_cpu.cc

Lines changed: 38 additions & 39 deletions

@@ -41,47 +41,46 @@ void softmax_cpu(const SoftmaxInfo &info,
     int dimsize = info.dimsize;
     int stride = info.stride;
     int othersize = info.otherdim_size;
-    if constexpr (std::is_same_v<T, fp16_t>) {
-        auto input = reinterpret_cast<const fp16_t *>(x);
-        auto output = reinterpret_cast<fp16_t *>(y);
-        for (int i = 0; i < othersize; i++) {
-            int tid = i % stride + (i - i % stride) * dimsize;
-            float max_data = -INFINITY;
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                max_data = fmax(max_data, utils::cast<float>(input[index]));
-            }
-            float sum_data = 0.0f;
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                sum_data += std::exp(utils::cast<float>(input[index]) - max_data);
-            }
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                output[index] = utils::cast<fp16_t>(std::exp(utils::cast<float>(input[index]) - max_data) / sum_data);
-            }
+    auto to_float = [](const T &val) -> float {
+        if constexpr (std::is_same_v<T, fp16_t>) {
+            return utils::cast<float>(val);
+        } else {
+            return val;
         }
-    } else if constexpr (std::is_same_v<T, float>) {
-        auto input = reinterpret_cast<const float *>(x);
-        auto output = reinterpret_cast<float *>(y);
-#pragma omp parallel for
-        for (int i = 0; i < othersize; i++) {
-            int tid = i % stride + (i - i % stride) * dimsize;
-            float max_data = -INFINITY;
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                max_data = fmax(max_data, input[index]);
-            }
-            float sum_data = 0.0f;
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                sum_data += std::exp(input[index] - max_data);
-            }
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                output[index] = std::exp(input[index] - max_data) / sum_data;
-            }
+    };
+
+    auto from_float = [](float val) -> T {
+        if constexpr (std::is_same_v<T, fp16_t>) {
+            return utils::cast<fp16_t>(val);
+        } else {
+            return val;
+        }
+    };
+
+    auto input = reinterpret_cast<const T *>(x);
+    auto output = reinterpret_cast<T *>(y);
+
+    auto compute_softmax = [&](int i) {
+        int tid = i % stride + (i - i % stride) * dimsize;
+        float max_data = -INFINITY;
+        for (int j = 0; j < dimsize; j++) {
+            int index = tid + j * stride;
+            max_data = fmax(max_data, to_float(input[index]));
+        }
+        float sum_data = 0.0f;
+        for (int j = 0; j < dimsize; j++) {
+            int index = tid + j * stride;
+            sum_data += std::exp(to_float(input[index]) - max_data);
         }
+        for (int j = 0; j < dimsize; j++) {
+            int index = tid + j * stride;
+            float result = std::exp(to_float(input[index]) - max_data) / sum_data;
+            output[index] = from_float(result);
+        }
+    };
+#pragma omp parallel for
+    for (int i = 0; i < othersize; i++) {
+        compute_softmax(i);
     }
 }
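The refactor collapses the separate fp16_t and float branches into one templated path: dtype conversion is isolated in the to_float/from_float lambdas, the per-slice work moves into compute_softmax, and both types now share the #pragma omp parallel for loop. Each slice computes the numerically stable softmax y_j = exp(x_j - m) / sum_k exp(x_k - m), where m is the slice maximum, so the exponentials cannot overflow. Below is a minimal self-contained sketch of the same pattern; names and the plain static_cast are illustrative stand-ins for the repo's utils::cast and fp16_t, and the loop is kept serial for clarity:

#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the repo's templated softmax_cpu: there, T may be
// fp16_t, and to_float/from_float route through utils::cast instead.
template <typename T>
void softmax_sketch(const T *x, T *y, int stride, int dimsize, int othersize) {
    auto to_float = [](const T &v) -> float { return static_cast<float>(v); };
    auto from_float = [](float v) -> T { return static_cast<T>(v); };

    for (int i = 0; i < othersize; i++) {
        // First element of the i-th softmax slice in the flattened tensor.
        int tid = i % stride + (i - i % stride) * dimsize;
        float max_data = -INFINITY; // slice maximum, for numerical stability
        for (int j = 0; j < dimsize; j++) {
            max_data = std::fmax(max_data, to_float(x[tid + j * stride]));
        }
        float sum_data = 0.0f; // normalizer: sum of shifted exponentials
        for (int j = 0; j < dimsize; j++) {
            sum_data += std::exp(to_float(x[tid + j * stride]) - max_data);
        }
        for (int j = 0; j < dimsize; j++) {
            int index = tid + j * stride;
            y[index] = from_float(std::exp(to_float(x[index]) - max_data) / sum_data);
        }
    }
}

int main() {
    // Softmax over the last axis of a row-major 2x3 tensor:
    // stride = 1, dimsize = 3, othersize = 2 (one slice per row).
    std::vector<float> x = {1.0f, 2.0f, 3.0f, 1.0f, 1.0f, 1.0f};
    std::vector<float> y(x.size());
    softmax_sketch(x.data(), y.data(), 1, 3, 2);
    for (float v : y) {
        std::printf("%.4f ", v); // each row sums to 1
    }
    std::printf("\n");
    return 0;
}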

src/infiniop/ops/softmax/cuda/softmax_kernel.cuh

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,6 @@
 #ifndef __SOFTMAX_CUDA_KERNEL_H__
 #define __SOFTMAX_CUDA_KERNEL_H__
+
 #include "../../../devices/cuda/cuda_kernel_common.cuh"
 #include "softmax_cuda.cuh"
 #include <cub/block/block_reduce.cuh>

@@ -74,7 +75,6 @@ i is (blockIdx.x * blockDim.y + threadIdx.y) / stride
 j is (blockIdx.x * blockDim.y + threadIdx.y) % stride
 then i maps to the linear offset i * stride * dimsize
 and j is simply added on top
-
 */
 template <int elemPerThread, int BLOCK_DIM_Y, int BLOCK_DIM_X, typename T>
 __global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {

@@ -236,4 +236,4 @@ infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, c
     return INFINI_STATUS_SUCCESS;
 }
 
-#endif // __SOFTMAX_CUDA_KERNEL_H__
+#endif // __SOFTMAX_CUDA_KERNEL_H__
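To make the index mapping in the translated comment above concrete, a small worked example (the shape is hypothetical): for a row-major tensor of shape (2, 3, 4) with softmax over axis 1, stride = 4, dimsize = 3, and otherdim_size = 2 * 4 = 8. Slice i = 5 starts at tid = i % stride + (i - i % stride) * dimsize and steps by stride:

#include <cstdio>

int main() {
    // Shape (2, 3, 4), softmax over axis 1: stride = 4, dimsize = 3.
    int stride = 4, dimsize = 3;
    int i = 5; // slice index: batch = i / stride = 1, column = i % stride = 1
    int tid = i % stride + (i - i % stride) * dimsize; // 1 + 4 * 3 = 13
    for (int j = 0; j < dimsize; j++) {
        // Prints 13 17 21: the linear offsets of (1, 0, 1), (1, 1, 1), (1, 2, 1).
        std::printf("%d ", tid + j * stride);
    }
    std::printf("\n");
    return 0;
}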

src/infiniop/ops/softmax/info.h

Lines changed: 3 additions & 3 deletions

@@ -1,5 +1,5 @@
-#ifndef __CONV_INFO_H__
-#define __CONV_INFO_H__
+#ifndef __SOFTMAX_INFO_H__
+#define __SOFTMAX_INFO_H__
 
 #include "../../../utils.h"
 #include "../../operator.h"

@@ -44,4 +44,4 @@ class SoftmaxInfo {
 };
 } // namespace op::softmax
 
-#endif // __CONV_INFO_H__
+#endif // __SOFTMAX_INFO_H__

src/infiniop/ops/softmax/operator.cc

Lines changed: 1 addition & 1 deletion

@@ -107,4 +107,4 @@ infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) {
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
 #undef DELETE
-}
+}

src/infiniop/ops/softmax/softmax.h

Lines changed: 1 addition & 1 deletion

@@ -46,4 +46,4 @@
         void *stream) const; \
     }; \
     }
-#endif // __CONV_H__
+#endif // __SOFTMAX_H__
