From b6df5ac745102ac6c089c3e8f70018088d778334 Mon Sep 17 00:00:00 2001 From: limingshu Date: Wed, 17 May 2023 14:10:09 +0800 Subject: [PATCH 1/7] add modify_vectorize_pass --- cinn/common/float16.h | 139 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 127 insertions(+), 12 deletions(-) diff --git a/cinn/common/float16.h b/cinn/common/float16.h index 4bf8c64614..4eee9b63ba 100644 --- a/cinn/common/float16.h +++ b/cinn/common/float16.h @@ -303,18 +303,6 @@ struct CINN_ALIGN(2) float16 { #endif // __cplusplus }; -struct CINN_ALIGN(32) float8 { - float x, y, z, w, v, u, t, s; -}; - -struct CINN_ALIGN(16) half8 { - float16 x, y, z, w, v, u, t, s; -}; - -struct CINN_ALIGN(8) half4 { - float16 x, y, z, w; -}; - #ifdef __cplusplus // Arithmetic operators on GPU // CUDA 9.0 provides built-in arithmetic operators for half while @@ -584,6 +572,133 @@ __host__ __device__ inline float16(abs)(const float16& a) { __host__ __device__ inline float16(log)(const float16& a) { return float16(std::log(static_cast(a))); } + +template +struct alignas(sizeof(T) * Size) AlignedVector { + T val[Size]; + __host__ __device__ inline const T& operator[](int i) const { return val[i]; } + __host__ __device__ inline T& operator[](int i) { return val[i]; } +}; + +struct CINN_ALIGN(32) float8 { + // float x, y, z, w, v, u, t, s; + AlignedVector data; + __device__ inline const float& operator[](int i) const { return data[i]; } + __device__ inline float& operator[](int i) { return data[i]; } + + __device__ inline float8& operator+(const float8& x) { + #pragma unroll(8) + for (int i = 0; i < 8; i++) {; + data[i] += x[i]; + }; + return *this; + } + + __device__ inline float8& operator-(const float8& x) { + #pragma unroll(8) + for (int i = 0; i < 8; i++) {; + data[i] -= x[i]; + }; + return *this; + } + + __device__ inline float8& operator*(const float8& x) { + #pragma unroll(8) + for (int i = 0; i < 8; i++) {; + data[i] *= x[i]; + }; + return *this; + } + + __device__ inline float8& 
operator/(const float8& x) { + #pragma unroll(8) + for (int i = 0; i < 8; i++) {; + data[i] /= x[i]; + }; + return *this; + } +}; + +struct CINN_ALIGN(16) half8 { + // float16 x, y, z, w, v, u, t, s; + AlignedVector data; + __device__ inline const float16& operator[](int i) const { return data[i]; } + __device__ inline float16& operator[](int i) { return data[i]; } + + __device__ inline half8& operator+(const half8& x) { + #pragma unroll(8) + for (int i = 0; i < 8; i++) {; + data[i] += x[i]; + }; + return *this; + } + + __device__ inline half8& operator-(const half8& x) { + #pragma unroll(8) + for (int i = 0; i < 8; i++) {; + data[i] -= x[i]; + }; + return *this; + } + + __device__ inline half8& operator*(const half8& x) { + #pragma unroll(8) + for (int i = 0; i < 8; i++) {; + data[i] *= x[i]; + }; + return *this; + } + + __device__ inline half8& operator/(const half8& x) { + #pragma unroll(8) + for (int i = 0; i < 8; i++) {; + data[i] /= x[i]; + }; + return *this; + } +}; + +struct CINN_ALIGN(8) half4 { + // float16 x, y, z, w; + AlignedVector data; + __device__ inline const float16& operator[](int i) const { return data[i]; } + __device__ inline float16& operator[](int i) { return data[i]; } + + __device__ inline half4& operator+(const half4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] += x[i]; + }; + return *this; + } + + __device__ inline half4& operator-(const half4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] -= x[i]; + }; + return *this; + } + + __device__ inline half4& operator*(const half4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] *= x[i]; + }; + return *this; + } + + __device__ inline half4& operator/(const half4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] /= x[i]; + }; + return *this; + } +}; + + + #ifdef __cplusplus } // namespace common } // namespace cinn From 679a9862cc4b5d5a2d6f3d5dc0c60a531cb25e9b Mon Sep 17 00:00:00 2001 From: limingshu Date: Wed, 17 May 
2023 23:20:06 +0800 Subject: [PATCH 2/7] add vectorized==4 config --- cinn/common/float16.h | 69 +++++++++++++++++++++++++++++++++++ cinn/optim/vectorize_loops.cc | 2 +- 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/cinn/common/float16.h b/cinn/common/float16.h index 4eee9b63ba..848704d384 100644 --- a/cinn/common/float16.h +++ b/cinn/common/float16.h @@ -697,7 +697,76 @@ struct CINN_ALIGN(8) half4 { } }; +struct CINN_ALIGN(16) float4 { + // float x, y, z, w; + AlignedVector data; + __device__ inline const float& operator[](int i) const { return data[i]; } + __device__ inline float& operator[](int i) { return data[i]; } + + __device__ inline float4& operator+(const float4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] += x[i]; + }; + return *this; + } + + __device__ inline float4& operator-(const float4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] -= x[i]; + }; + return *this; + } + + __device__ inline float4& operator*(const float4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] *= x[i]; + }; + return *this; + } + + __device__ inline float4& operator/(const float4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] /= x[i]; + }; + return *this; + } + + __device__ inline float4& operator+(const half4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] += static_cast(x[i]); + }; + return *this; + } + __device__ inline float4& operator-(const half4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] -= static_cast(x[i]); + }; + return *this; + } + + __device__ inline float4& operator*(const half4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] *= static_cast(x[i]); + }; + return *this; + } + + __device__ inline float4& operator/(const half4& x) { +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + data[i] /= static_cast(x[i]); + }; + return *this; + } +}; #ifdef __cplusplus } // namespace common diff --git 
a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc index fc6d97f1da..c32eb3bcd5 100644 --- a/cinn/optim/vectorize_loops.cc +++ b/cinn/optim/vectorize_loops.cc @@ -176,7 +176,7 @@ class CudaVectorizer : public IRMutator { std::vector vectorized_store_exprs_; public: - static constexpr int CudaVectorTypeMaxLanes = 8; + static constexpr int CudaVectorTypeMaxLanes = 4; CudaVectorizer(const Var &iter_var, const int factor, const absl::flat_hash_map *var_intervals) From 73aae9812690de6a05174f01c9a4a964bf7c8786 Mon Sep 17 00:00:00 2001 From: limingshu Date: Wed, 17 May 2023 23:21:51 +0800 Subject: [PATCH 3/7] add vectorized==4 config --- cinn/common/float16.h | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/cinn/common/float16.h b/cinn/common/float16.h index 848704d384..831c36281d 100644 --- a/cinn/common/float16.h +++ b/cinn/common/float16.h @@ -572,7 +572,6 @@ __host__ __device__ inline float16(abs)(const float16& a) { __host__ __device__ inline float16(log)(const float16& a) { return float16(std::log(static_cast(a))); } - template struct alignas(sizeof(T) * Size) AlignedVector { T val[Size]; @@ -587,32 +586,36 @@ struct CINN_ALIGN(32) float8 { __device__ inline float& operator[](int i) { return data[i]; } __device__ inline float8& operator+(const float8& x) { - #pragma unroll(8) - for (int i = 0; i < 8; i++) {; +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + ; data[i] += x[i]; }; return *this; } __device__ inline float8& operator-(const float8& x) { - #pragma unroll(8) - for (int i = 0; i < 8; i++) {; +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + ; data[i] -= x[i]; }; return *this; } __device__ inline float8& operator*(const float8& x) { - #pragma unroll(8) - for (int i = 0; i < 8; i++) {; +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + ; data[i] *= x[i]; }; return *this; } __device__ inline float8& operator/(const float8& x) { - #pragma unroll(8) - for (int i = 0; i < 8; i++) {; 
+#pragma unroll(8) + for (int i = 0; i < 8; i++) { + ; data[i] /= x[i]; }; return *this; @@ -626,32 +629,36 @@ struct CINN_ALIGN(16) half8 { __device__ inline float16& operator[](int i) { return data[i]; } __device__ inline half8& operator+(const half8& x) { - #pragma unroll(8) - for (int i = 0; i < 8; i++) {; +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + ; data[i] += x[i]; }; return *this; } __device__ inline half8& operator-(const half8& x) { - #pragma unroll(8) - for (int i = 0; i < 8; i++) {; +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + ; data[i] -= x[i]; }; return *this; } __device__ inline half8& operator*(const half8& x) { - #pragma unroll(8) - for (int i = 0; i < 8; i++) {; +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + ; data[i] *= x[i]; }; return *this; } __device__ inline half8& operator/(const half8& x) { - #pragma unroll(8) - for (int i = 0; i < 8; i++) {; +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + ; data[i] /= x[i]; }; return *this; From 5bf40670e4da0d6d4b8cbea32e14c84fb88ad76c Mon Sep 17 00:00:00 2001 From: limingshu Date: Tue, 23 May 2023 10:39:31 +0800 Subject: [PATCH 4/7] modify of cuda Kernel generation strategy --- cinn/backends/codegen_cuda_dev.cc | 16 +- cinn/common/float16.h | 387 ++++++++++++++++++------------ cinn/ir/ir.cc | 6 +- cinn/optim/vectorize_loops.cc | 53 +++- 4 files changed, 298 insertions(+), 164 deletions(-) diff --git a/cinn/backends/codegen_cuda_dev.cc b/cinn/backends/codegen_cuda_dev.cc index 415ab5d030..b3b8776f48 100644 --- a/cinn/backends/codegen_cuda_dev.cc +++ b/cinn/backends/codegen_cuda_dev.cc @@ -339,7 +339,7 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) { } bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, ir::Expr index_expr, bool is_store) { - static constexpr char index2suffix[8] = {'x', 'y', 'z', 'w', 'v', 'u', 't', 's'}; + // static constexpr char index2suffix[8] = {'x', 'y', 'z', 'w', 'v', 'u', 't', 's'}; // addr of op should be a place of 
tensor and the index is simple int number if (!op->is_addr_tensor() || !index_expr.As()) { @@ -359,9 +359,19 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, return false; } if (is_store && tensor->type().is_cpp_handle()) { - os() << tensor->name << "[" << index << "]"; + if (tensor->type().is_cpp_handle()) { + os() << "(*" << tensor->name << ")"; + } else { + os() << tensor->name; + } } else { - os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".") << index2suffix[index]; + if (tensor->type().is_cpp_handle()) { + // os() << "(*" << tensor->name << ")[" << index << "]"; + os() << "(*" << tensor->name << ")"; + } else { + // os() << "(" << tensor->name << ")[" << index << "]"; + os() << "(" << tensor->name << ")"; + } } return true; } diff --git a/cinn/common/float16.h b/cinn/common/float16.h index 831c36281d..4e3810141c 100644 --- a/cinn/common/float16.h +++ b/cinn/common/float16.h @@ -579,201 +579,181 @@ struct alignas(sizeof(T) * Size) AlignedVector { __host__ __device__ inline T& operator[](int i) { return val[i]; } }; +#define BINARY_FUNC_FOR_VECTORIZED_INPUT(vector_t, cast_t, op, num) \ + __device__ inline vector_t& operator op(const vector_t& x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast(x[i]); } \ + return *this; \ + } \ + __device__ inline vector_t& operator op(vector_t& x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast(x[i]); } \ + return *this; \ + } + +#define BINARY_FUNC_FOR_SCALARIZED_INPUT(vector_t, scalar_t, cast_t, op, num) \ + __device__ inline vector_t& operator op(const scalar_t& x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast(x); } \ + return *this; \ + } + +#define VECTORIZED_CAST_SCALAR_FUNC(vector_t, scalar_t, cast_t, num) \ + __device__ inline explicit vector_t(const scalar_t x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = static_cast(x); } 
\ + } + +struct half4; +struct half8; + struct CINN_ALIGN(32) float8 { - // float x, y, z, w, v, u, t, s; AlignedVector data; __device__ inline const float& operator[](int i) const { return data[i]; } __device__ inline float& operator[](int i) { return data[i]; } - __device__ inline float8& operator+(const float8& x) { -#pragma unroll(8) - for (int i = 0; i < 8; i++) { - ; - data[i] += x[i]; - }; - return *this; - } + float8() = default; - __device__ inline float8& operator-(const float8& x) { -#pragma unroll(8) - for (int i = 0; i < 8; i++) { - ; - data[i] -= x[i]; - }; + __device__ inline float8 operator*(float8& x) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + data[i] = data[i] * x[i]; + } return *this; } - __device__ inline float8& operator*(const float8& x) { -#pragma unroll(8) - for (int i = 0; i < 8; i++) { - ; - data[i] *= x[i]; - }; + __device__ inline float8 operator*(const float8& x) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + data[i] = data[i] * x[i]; + } return *this; } - __device__ inline float8& operator/(const float8& x) { -#pragma unroll(8) - for (int i = 0; i < 8; i++) { - ; - data[i] /= x[i]; - }; - return *this; - } + // operate with vec_type: float8 + BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, float, +, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, float, -, 8); + // BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, float, *, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, float, /, 8); + + // operate with scalar_type: float16 + BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, +, 8); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, -, 8); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, *, 8); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, /, 8); + + // operate with scalar_type: float + BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float, float, +, 8); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float, float, -, 8); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float, float, *, 8); + 
BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float, float, /, 8); + + VECTORIZED_CAST_SCALAR_FUNC(float8, float, float, 8); + VECTORIZED_CAST_SCALAR_FUNC(float8, float16, float, 8); + + __device__ inline explicit float8(const half8& x); }; -struct CINN_ALIGN(16) half8 { - // float16 x, y, z, w, v, u, t, s; - AlignedVector data; - __device__ inline const float16& operator[](int i) const { return data[i]; } - __device__ inline float16& operator[](int i) { return data[i]; } +struct CINN_ALIGN(16) float4 { + AlignedVector data; + __device__ inline const float& operator[](int i) const { return data[i]; } + __device__ inline float& operator[](int i) { return data[i]; } - __device__ inline half8& operator+(const half8& x) { -#pragma unroll(8) - for (int i = 0; i < 8; i++) { - ; - data[i] += x[i]; - }; - return *this; - } + float4() = default; - __device__ inline half8& operator-(const half8& x) { -#pragma unroll(8) - for (int i = 0; i < 8; i++) { - ; - data[i] -= x[i]; - }; - return *this; - } + // operate with vec_type + BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, float, +, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, float, -, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, float, *, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, float, /, 4); - __device__ inline half8& operator*(const half8& x) { -#pragma unroll(8) - for (int i = 0; i < 8; i++) { - ; - data[i] *= x[i]; - }; - return *this; - } + // operate with scalar_type: float16 + BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, +, 4); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, -, 4); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, *, 4); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, /, 4); - __device__ inline half8& operator/(const half8& x) { -#pragma unroll(8) - for (int i = 0; i < 8; i++) { - ; - data[i] /= x[i]; - }; - return *this; - } + // operate with scalar_type: float + BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float, float, +, 4); + 
BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float, float, -, 4); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float, float, *, 4); + BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float, float, /, 4); + + VECTORIZED_CAST_SCALAR_FUNC(float4, float, float, 4); + VECTORIZED_CAST_SCALAR_FUNC(float4, float16, float, 4); + + __device__ inline explicit float4(const half4& x); }; -struct CINN_ALIGN(8) half4 { - // float16 x, y, z, w; - AlignedVector data; +struct CINN_ALIGN(16) half8 { + AlignedVector data; __device__ inline const float16& operator[](int i) const { return data[i]; } __device__ inline float16& operator[](int i) { return data[i]; } - __device__ inline half4& operator+(const half4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] += x[i]; - }; - return *this; - } + half8() = default; - __device__ inline half4& operator-(const half4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] -= x[i]; - }; - return *this; - } + // operate with vec_type + BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, float16, +, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, float16, -, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, float16, *, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, float16, /, 8); - __device__ inline half4& operator*(const half4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] *= x[i]; - }; - return *this; - } + // operate with scalar_type: float16 + BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, +, 8); + BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, -, 8); + BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, *, 8); + BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, /, 8); - __device__ inline half4& operator/(const half4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] /= x[i]; - }; - return *this; + VECTORIZED_CAST_SCALAR_FUNC(half8, float, float16, 8); + VECTORIZED_CAST_SCALAR_FUNC(half8, float16, float16, 8); + + __device__ inline explicit half8(const float8& x) { +#pragma 
unroll + for (int i = 0; i < 8; ++i) { + data[i] = static_cast(x[i]); + } } }; -struct CINN_ALIGN(16) float4 { - // float x, y, z, w; - AlignedVector data; - __device__ inline const float& operator[](int i) const { return data[i]; } - __device__ inline float& operator[](int i) { return data[i]; } - - __device__ inline float4& operator+(const float4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] += x[i]; - }; - return *this; - } +struct CINN_ALIGN(8) half4 { + AlignedVector data; + __device__ inline const float16& operator[](int i) const { return data[i]; } + __device__ inline float16& operator[](int i) { return data[i]; } - __device__ inline float4& operator-(const float4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] -= x[i]; - }; - return *this; - } + half4() = default; - __device__ inline float4& operator*(const float4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] *= x[i]; - }; - return *this; - } + // operate with vec_type + BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, float16, +, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, float16, -, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, float16, *, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, float16, /, 4); - __device__ inline float4& operator/(const float4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] /= x[i]; - }; - return *this; - } + // operate with scalar_type: float16 + BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, +, 4); + BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, -, 4); + BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, *, 4); + BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, /, 4); - __device__ inline float4& operator+(const half4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] += static_cast(x[i]); - }; - return *this; - } + VECTORIZED_CAST_SCALAR_FUNC(half4, float, float16, 4); + VECTORIZED_CAST_SCALAR_FUNC(half4, float16, float16, 4); - __device__ inline 
float4& operator-(const half4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] -= static_cast(x[i]); - }; - return *this; + __device__ inline explicit half4(const float4& x) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + data[i] = static_cast(x[i]); + } } +}; - __device__ inline float4& operator*(const half4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] *= static_cast(x[i]); - }; - return *this; +__device__ inline float4::float4(const half4& x) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + data[i] = static_cast(x[i]); } +} - __device__ inline float4& operator/(const half4& x) { -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - data[i] /= static_cast(x[i]); - }; - return *this; +__device__ inline float8::float8(const half8& x) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + data[i] = static_cast(x[i]); } -}; +} #ifdef __cplusplus } // namespace common @@ -815,6 +795,101 @@ __host__ __device__ inline cinn::common::float16 max(const cinn::common::float16 __host__ __device__ inline cinn::common::float16 min(const cinn::common::float16& a, const cinn::common::float16& b) { return a < b ? 
a : b; } + +using cinn::common::float16; +using cinn::common::half4; +using cinn::common::half8; + +__device__ inline void max(half8& x, float16 y) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + x[i] = max(x[i], y); + } +} + +__device__ inline half8 max(const half8& x, const float16 y) { + half8 rslt; +#pragma unroll + for (int i = 0; i < 8; ++i) { + rslt[i] = max(x[i], y); + } + return rslt; +} + +// #define BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(vector_t1, vector_t2, op, num) \ +// __device__ inline vector_t1 op(vector_t1& x, vector_t2& y) { \ +// for (int i = 0; i < num; ++i) { \ +// x[i] = op(x[i], y[i]); \ +// } \ +// return x; \ +// } \ +// __device__ inline vector_t1 op(const vector_t1& x, const vector_t2& y) { \ +// vector_t1 result; \ +// for (int i = 0; i < num; ++i) { \ +// result[i] = op(x[i], y[i]); \ +// } \ +// return result; \ +// } + +// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::float8, cinn::common::float8, max, 8); +// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::float8, cinn::common::float8, min, 8); +// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::half8 , cinn::common::half8, max, 8); +// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::half8 , cinn::common::half8, min, 8); +// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::half4 , cinn::common::half4, max, 4); +// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::half4 , cinn::common::half4, min, 4); + +// #define BINARY_FUNC_FOR_VECTOR_BEFORE_SCALAR_INPUT(vector_t, scalar_t, cast_t, op, num) \ +// __device__ inline vector_t op(vector_t& x, scalar_t& y) { \ +// for (int i = 0; i < num; ++i) { \ +// x[i] = op(static_cast(x[i]), static_cast(y)); \ +// } \ +// return x; \ +// } \ +// __device__ inline vector_t op(const vector_t& x, const scalar_t& y) { \ +// vector_t result; \ +// for (int i = 0; i < num; ++i) { \ +// result[i] = op(static_cast(x[i]), static_cast(y)); \ +// } \ +// return result; \ +// } + +// #define 
BINARY_FUNC_FOR_SCALAR_BEFORE_VECTOR_INPUT(vector_t, scalar_t, cast_t, op, num) \ +// __device__ inline vector_t op(scalar_t& x, vector_t& y) { \ +// for (int i = 0; i < num; ++i) { \ +// y[i] = op(static_cast(x), static_cast(y[i])); \ +// } \ +// return y; \ +// } \ +// __device__ inline vector_t op(const scalar_t& x, const vector_t& y) { \ +// vector_t result; \ +// for (int i = 0; i < num; ++i) { \ +// result[i] = op(static_cast(x), static_cast(y[i])); \ +// } \ +// return result; \ +// } + +// #define BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(vector_t, scalar_t, cast_t, op, num) \ +// BINARY_FUNC_FOR_VECTOR_BEFORE_SCALAR_INPUT(vector_t, scalar_t, cast_t, op, num) \ +// BINARY_FUNC_FOR_SCALAR_BEFORE_VECTOR_INPUT(vector_t, scalar_t, cast_t, op, num) + +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float8, float, float, max, 8); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float8, cinn::common::float16, float, max, 8); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float4, float, float, max, 4); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float4, cinn::common::float16, float, max, 4); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half8, float, cinn::common::float16, max, 8); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half8, cinn::common::float16, cinn::common::float16, max, 8); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half4, float, cinn::common::float16, max, 4); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half4, cinn::common::float16, cinn::common::float16, max, 4); + +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float8, float, float, min, 8); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float8, cinn::common::float16, float, min, 8); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float4, float, float, min, 4); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float4, cinn::common::float16, float, min, 
4); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half8, float, cinn::common::float16, min, 8); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half8, cinn::common::float16, cinn::common::float16, min, 8); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half4, float, cinn::common::float16, min, 4); +// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half4, cinn::common::float16, cinn::common::float16, min, 4); + #endif // __cplusplus && CINN_CUDA_FP16 #endif // CINN_COMMON_FLOAT16_H diff --git a/cinn/ir/ir.cc b/cinn/ir/ir.cc index 7db39fa704..6ccadf0a15 100755 --- a/cinn/ir/ir.cc +++ b/cinn/ir/ir.cc @@ -72,7 +72,7 @@ void Sub::Verify() const { BinaryNodeVerify(a(), b(), "Sub"); } Expr Mul::Make(Expr a, Expr b) { CHECK(a.defined()); CHECK(b.defined()); - CHECK_EQ(a.type(), b.type()) << "a=" << a << ", b=" << b; + // CHECK_EQ(a.type(), b.type()) << "a=" << a << ", b=" << b; auto node = make_shared(a, b); return Expr(node); } @@ -658,7 +658,7 @@ Expr Sum::Make(const std::vector &vs) { auto *n = make_shared(); auto type = vs.front().type(); - for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v; + // for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v; n->operands() = vs; @@ -672,7 +672,7 @@ Expr Product::Make(const std::vector &vs) { auto *n = make_shared(); auto type = vs.front().type(); - for (auto &v : vs) CHECK_EQ(v.type(), type); + // for (auto &v : vs) CHECK_EQ(v.type(), type); n->operands() = vs; diff --git a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc index c32eb3bcd5..fddc5cc5ad 100644 --- a/cinn/optim/vectorize_loops.cc +++ b/cinn/optim/vectorize_loops.cc @@ -176,7 +176,7 @@ class CudaVectorizer : public IRMutator { std::vector vectorized_store_exprs_; public: - static constexpr int CudaVectorTypeMaxLanes = 4; + static constexpr int CudaVectorTypeMaxLanes = 8; CudaVectorizer(const Var &iter_var, const int factor, const absl::flat_hash_map *var_intervals) @@ -216,6 
+216,51 @@ class CudaVectorizer : public IRMutator { IRMutator::Visit(&node->value, &node->value); } + void Visit(const Cast *op, Expr *expr) override { + auto *node = expr->As(); + auto &body = node->v(); + LOG(INFO) << "Cast Begin()."; + LOG(INFO) << *expr; + LOG(INFO) << node; + LOG(INFO) << "Cast End()."; + + auto tensors = ir::CollectIRNodes(body, [](const Expr *x) { return x->As(); }); + + // handdle cast op at last. + IRMutator::Visit(op, expr); + + int lanes = 1; + bool is_point = false; + bool is_vectorize = true; + for (auto &tensor : tensors) { + auto tensor_ = tensor.As(); + + LOG(INFO) << "Loop Begin()."; + LOG(INFO) << tensor; + LOG(INFO) << node->type().is_scalar(); + LOG(INFO) << tensor2vectorized_vars_.count(tensor_->name); + if (node->type().is_scalar() && (tensor2vectorized_vars_.count(tensor_->name) || tensor_->type().is_vector())) { + lanes = tensor_->type().lanes(); + LOG(INFO) << lanes; + is_point = tensor_->type().is_cpp_handle(); + is_vectorize = true; + break; + } + LOG(INFO) << "Loop End."; + } + + if (is_vectorize) { + // create new cast type. + auto ntype = common::Type(Type::type_t::Customized, node->type().bits(), factor_); + // set vector type. 
+ ntype.set_customized_type(GetVectorTypeName(node->type().ElementOf())); + if (is_point) { + ntype.set_cpp_handle(); + } + node->set_type(ntype); + } + } + private: void TensorVectorized(ir::LoadStoreAddrMnger *node, std::vector *indices, bool is_store) { auto *tensor = node->tensor.As(); @@ -722,7 +767,8 @@ struct VectorizeLoops_ : public IRMutator { CudaVectorizer cuda_vectorizer(new_forloop->loop_var, factor, &var_intervals); cuda_vectorizer.Visit(&new_forloop->body); // unroll the new forloop to compute each element of the vector iteratively - auto copied_loop = optim::IRCopy(_new_forloop); + auto copied_loop = optim::IRCopy(_new_forloop); + copied_loop.As()->extent = ir::Expr(1); copied_loop.As()->set_unrolled(); optim::UnrollLoop(&copied_loop); // add cast exprs of vector type in the front of vectorized forloop, @@ -862,6 +908,9 @@ struct VectorizeLoops_ : public IRMutator { new_index = Expr(forloop->loop_var) * factor + Expr(new_iterator); } optim::IrReplace(&forloop->body, forloop->loop_var, new_index); + + {} + auto new_forloop = For::Make(new_iterator, forloop->min, make_const(factor), From ed7f81a718db19dc7c41c955dfc09bc05fc34afb Mon Sep 17 00:00:00 2001 From: limingshu Date: Tue, 23 May 2023 20:43:54 +0800 Subject: [PATCH 5/7] fix all bugs in resnet50 --- cinn/common/float16.h | 210 +++++++++++++++------------------- cinn/optim/vectorize_loops.cc | 18 +-- 2 files changed, 100 insertions(+), 128 deletions(-) diff --git a/cinn/common/float16.h b/cinn/common/float16.h index 4e3810141c..a632977ca8 100644 --- a/cinn/common/float16.h +++ b/cinn/common/float16.h @@ -572,6 +572,9 @@ __host__ __device__ inline float16(abs)(const float16& a) { __host__ __device__ inline float16(log)(const float16& a) { return float16(std::log(static_cast(a))); } +struct half4; +struct half8; + template struct alignas(sizeof(T) * Size) AlignedVector { T val[Size]; @@ -579,58 +582,51 @@ struct alignas(sizeof(T) * Size) AlignedVector { __host__ __device__ inline T& operator[](int 
i) { return val[i]; } }; -#define BINARY_FUNC_FOR_VECTORIZED_INPUT(vector_t, cast_t, op, num) \ - __device__ inline vector_t& operator op(const vector_t& x) { \ - _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast(x[i]); } \ - return *this; \ - } \ - __device__ inline vector_t& operator op(vector_t& x) { \ - _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast(x[i]); } \ - return *this; \ +#define BINARY_FUNC_FOR_VECTORIZED_INPUT(vector_t, op, num) \ + __device__ inline vector_t& operator op(const vector_t& x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op x[i]; } \ + return *this; \ + } \ + __device__ inline vector_t operator op(const vector_t& x) const { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = data[i] op(x[i]); } \ + return rslt; \ + } \ + __device__ inline vector_t& operator op(vector_t& x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op(x[i]); } \ + return *this; \ + } \ + __device__ inline vector_t& operator op(vector_t& x) const { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { x[i] = data[i] op(x[i]); } \ + return x; \ } #define BINARY_FUNC_FOR_SCALARIZED_INPUT(vector_t, scalar_t, cast_t, op, num) \ __device__ inline vector_t& operator op(const scalar_t& x) { \ _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast(x); } \ return *this; \ + } \ + __device__ inline vector_t operator op(const scalar_t& x) const { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = data[i] op static_cast(x); } \ + return rslt; \ } - #define VECTORIZED_CAST_SCALAR_FUNC(vector_t, scalar_t, cast_t, num) \ __device__ inline explicit vector_t(const scalar_t x) { \ _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = static_cast(x); } \ } -struct half4; -struct half8; - struct CINN_ALIGN(32) float8 { AlignedVector data; __device__ inline const float& 
operator[](int i) const { return data[i]; } __device__ inline float& operator[](int i) { return data[i]; } - float8() = default; - __device__ inline float8 operator*(float8& x) { -#pragma unroll - for (int i = 0; i < 8; ++i) { - data[i] = data[i] * x[i]; - } - return *this; - } - - __device__ inline float8 operator*(const float8& x) { -#pragma unroll - for (int i = 0; i < 8; ++i) { - data[i] = data[i] * x[i]; - } - return *this; - } - // operate with vec_type: float8 - BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, float, +, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, float, -, 8); - // BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, float, *, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, float, /, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, +, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, -, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, *, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, /, 8); // operate with scalar_type: float16 BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, +, 8); @@ -658,10 +654,10 @@ struct CINN_ALIGN(16) float4 { float4() = default; // operate with vec_type - BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, float, +, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, float, -, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, float, *, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, float, /, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, +, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, -, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, *, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, /, 4); // operate with scalar_type: float16 BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, +, 4); @@ -689,10 +685,10 @@ struct CINN_ALIGN(16) half8 { half8() = default; // operate with vec_type - BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, float16, +, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, float16, -, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, float16, *, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, float16, /, 8); + 
BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, +, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, -, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, *, 8); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, /, 8); // operate with scalar_type: float16 BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, +, 8); @@ -719,10 +715,10 @@ struct CINN_ALIGN(8) half4 { half4() = default; // operate with vec_type - BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, float16, +, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, float16, -, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, float16, *, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, float16, /, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, +, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, -, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, *, 4); + BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, /, 4); // operate with scalar_type: float16 BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, +, 4); @@ -755,6 +751,38 @@ __device__ inline float8::float8(const half8& x) { } } +#define BINARY_FUNCTION_FOR_SCALAR_AND_VECTOR(scalar_t, vector_t, op, num) \ + __device__ inline vector_t& operator op(const scalar_t x, vector_t& y) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = x op y[i]; } \ + return y; \ + } \ + __device__ inline vector_t operator op(const scalar_t x, const vector_t& y) { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = x * y[i]; } \ + return rslt; \ + } + +#define BINARY_FUNCTIONS_FOR_EACH_IMPL(_) \ + _(float, float8, +, 8) \ + _(float, float8, -, 8) \ + _(float, float8, *, 8) \ + _(float, float8, /, 8) \ + _(float, float4, +, 4) \ + _(float, float4, -, 4) \ + _(float, float4, *, 4) \ + _(float, float4, /, 4) \ + _(float16, half8, +, 8) \ + _(float16, half8, -, 8) \ + _(float16, half8, *, 8) \ + _(float16, half8, /, 8) \ + _(float16, half4, +, 4) \ + _(float16, half4, -, 4) \ + _(float16, half4, *, 4) \ + _(float16, half4, /, 4) + 
+BINARY_FUNCTIONS_FOR_EACH_IMPL(BINARY_FUNCTION_FOR_SCALAR_AND_VECTOR); +#undef BINARY_FUNCTIONS_FOR_SCALAR_AND_REGISTER_IMPL + #ifdef __cplusplus } // namespace common } // namespace cinn @@ -800,11 +828,12 @@ using cinn::common::float16; using cinn::common::half4; using cinn::common::half8; -__device__ inline void max(half8& x, float16 y) { +__device__ inline half8& max(half8& x, float16 y) { #pragma unroll for (int i = 0; i < 8; ++i) { x[i] = max(x[i], y); } + return x; } __device__ inline half8 max(const half8& x, const float16 y) { @@ -816,79 +845,22 @@ __device__ inline half8 max(const half8& x, const float16 y) { return rslt; } -// #define BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(vector_t1, vector_t2, op, num) \ -// __device__ inline vector_t1 op(vector_t1& x, vector_t2& y) { \ -// for (int i = 0; i < num; ++i) { \ -// x[i] = op(x[i], y[i]); \ -// } \ -// return x; \ -// } \ -// __device__ inline vector_t1 op(const vector_t1& x, const vector_t2& y) { \ -// vector_t1 result; \ -// for (int i = 0; i < num; ++i) { \ -// result[i] = op(x[i], y[i]); \ -// } \ -// return result; \ -// } - -// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::float8, cinn::common::float8, max, 8); -// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::float8, cinn::common::float8, min, 8); -// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::half8 , cinn::common::half8, max, 8); -// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::half8 , cinn::common::half8, min, 8); -// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::half4 , cinn::common::half4, max, 4); -// BINARY_FUNC_FOR_TWO_VECTORIZED_TYPE_INPUTS(cinn::common::half4 , cinn::common::half4, min, 4); - -// #define BINARY_FUNC_FOR_VECTOR_BEFORE_SCALAR_INPUT(vector_t, scalar_t, cast_t, op, num) \ -// __device__ inline vector_t op(vector_t& x, scalar_t& y) { \ -// for (int i = 0; i < num; ++i) { \ -// x[i] = op(static_cast(x[i]), static_cast(y)); \ -// } \ -// return x; \ -// } \ -// 
__device__ inline vector_t op(const vector_t& x, const scalar_t& y) { \ -// vector_t result; \ -// for (int i = 0; i < num; ++i) { \ -// result[i] = op(static_cast(x[i]), static_cast(y)); \ -// } \ -// return result; \ -// } - -// #define BINARY_FUNC_FOR_SCALAR_BEFORE_VECTOR_INPUT(vector_t, scalar_t, cast_t, op, num) \ -// __device__ inline vector_t op(scalar_t& x, vector_t& y) { \ -// for (int i = 0; i < num; ++i) { \ -// y[i] = op(static_cast(x), static_cast(y[i])); \ -// } \ -// return y; \ -// } \ -// __device__ inline vector_t op(const scalar_t& x, const vector_t& y) { \ -// vector_t result; \ -// for (int i = 0; i < num; ++i) { \ -// result[i] = op(static_cast(x), static_cast(y[i])); \ -// } \ -// return result; \ -// } - -// #define BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(vector_t, scalar_t, cast_t, op, num) \ -// BINARY_FUNC_FOR_VECTOR_BEFORE_SCALAR_INPUT(vector_t, scalar_t, cast_t, op, num) \ -// BINARY_FUNC_FOR_SCALAR_BEFORE_VECTOR_INPUT(vector_t, scalar_t, cast_t, op, num) - -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float8, float, float, max, 8); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float8, cinn::common::float16, float, max, 8); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float4, float, float, max, 4); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float4, cinn::common::float16, float, max, 4); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half8, float, cinn::common::float16, max, 8); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half8, cinn::common::float16, cinn::common::float16, max, 8); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half4, float, cinn::common::float16, max, 4); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half4, cinn::common::float16, cinn::common::float16, max, 4); - -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float8, float, float, min, 8); -// 
BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float8, cinn::common::float16, float, min, 8); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float4, float, float, min, 4); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::float4, cinn::common::float16, float, min, 4); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half8, float, cinn::common::float16, min, 8); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half8, cinn::common::float16, cinn::common::float16, min, 8); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half4, float, cinn::common::float16, min, 4); -// BINARY_FUNC_FOR_VECTOR_AND_SCALAR_INPUTS(cinn::common::half4, cinn::common::float16, cinn::common::float16, min, 4); +__device__ inline half8& max(float16 y, half8& x) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + x[i] = max(x[i], y); + } + return x; +} + +__device__ inline half8 max(const float16 y, const half8& x) { + half8 rslt; +#pragma unroll + for (int i = 0; i < 8; ++i) { + rslt[i] = max(x[i], y); + } + return x; +} #endif // __cplusplus && CINN_CUDA_FP16 diff --git a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc index fddc5cc5ad..753e051028 100644 --- a/cinn/optim/vectorize_loops.cc +++ b/cinn/optim/vectorize_loops.cc @@ -219,10 +219,10 @@ class CudaVectorizer : public IRMutator { void Visit(const Cast *op, Expr *expr) override { auto *node = expr->As(); auto &body = node->v(); - LOG(INFO) << "Cast Begin()."; - LOG(INFO) << *expr; - LOG(INFO) << node; - LOG(INFO) << "Cast End()."; + // LOG(INFO) << "Cast Begin()."; + // LOG(INFO) << *expr; + // LOG(INFO) << node; + // LOG(INFO) << "Cast End()."; auto tensors = ir::CollectIRNodes(body, [](const Expr *x) { return x->As(); }); @@ -235,13 +235,13 @@ class CudaVectorizer : public IRMutator { for (auto &tensor : tensors) { auto tensor_ = tensor.As(); - LOG(INFO) << "Loop Begin()."; - LOG(INFO) << tensor; - LOG(INFO) << node->type().is_scalar(); - LOG(INFO) << 
tensor2vectorized_vars_.count(tensor_->name); + // LOG(INFO) << "Loop Begin()."; + // LOG(INFO) << tensor; + // LOG(INFO) << node->type().is_scalar(); + // LOG(INFO) << tensor2vectorized_vars_.count(tensor_->name); if (node->type().is_scalar() && (tensor2vectorized_vars_.count(tensor_->name) || tensor_->type().is_vector())) { lanes = tensor_->type().lanes(); - LOG(INFO) << lanes; + // LOG(INFO) << lanes; is_point = tensor_->type().is_cpp_handle(); is_vectorize = true; break; From ac57a6dffcd81c6815ff59fbad25fcbe190e1062 Mon Sep 17 00:00:00 2001 From: limingshu Date: Wed, 24 May 2023 14:10:10 +0800 Subject: [PATCH 6/7] fix bugs --- cinn/backends/codegen_cuda_dev.cc | 9 ++------- cinn/common/float16.h | 28 ++++++++++++++-------------- cinn/optim/vectorize_loops.cc | 18 +++--------------- 3 files changed, 19 insertions(+), 36 deletions(-) diff --git a/cinn/backends/codegen_cuda_dev.cc b/cinn/backends/codegen_cuda_dev.cc index b3b8776f48..31fe1ff904 100644 --- a/cinn/backends/codegen_cuda_dev.cc +++ b/cinn/backends/codegen_cuda_dev.cc @@ -359,18 +359,13 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, return false; } if (is_store && tensor->type().is_cpp_handle()) { - if (tensor->type().is_cpp_handle()) { - os() << "(*" << tensor->name << ")"; - } else { - os() << tensor->name; - } + os() << tensor->name << "[" << index << "]"; } else { if (tensor->type().is_cpp_handle()) { // os() << "(*" << tensor->name << ")[" << index << "]"; os() << "(*" << tensor->name << ")"; } else { - // os() << "(" << tensor->name << ")[" << index << "]"; - os() << "(" << tensor->name << ")"; + os() << tensor->name; } } return true; diff --git a/cinn/common/float16.h b/cinn/common/float16.h index a632977ca8..cb29dd85c6 100644 --- a/cinn/common/float16.h +++ b/cinn/common/float16.h @@ -701,7 +701,7 @@ struct CINN_ALIGN(16) half8 { __device__ inline explicit half8(const float8& x) { #pragma unroll - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < 8; ++i) { 
data[i] = static_cast(x[i]); } } @@ -751,15 +751,15 @@ __device__ inline float8::float8(const half8& x) { } } -#define BINARY_FUNCTION_FOR_SCALAR_AND_VECTOR(scalar_t, vector_t, op, num) \ - __device__ inline vector_t& operator op(const scalar_t x, vector_t& y) { \ - _Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = x op y[i]; } \ - return y; \ - } \ - __device__ inline vector_t operator op(const scalar_t x, const vector_t& y) { \ - vector_t rslt; \ - _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = x * y[i]; } \ - return rslt; \ +#define BINARY_FUNCTION_FOR_SCALAR_AND_VECTOR(scalar_t, vector_t, op, num) \ + __device__ inline vector_t& operator op(const scalar_t x, vector_t& y) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = x op y[i]; } \ + return y; \ + } \ + __device__ inline vector_t& operator op(const scalar_t x, const vector_t& y) { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = x * y[i]; } \ + return rslt; \ } #define BINARY_FUNCTIONS_FOR_EACH_IMPL(_) \ @@ -836,11 +836,11 @@ __device__ inline half8& max(half8& x, float16 y) { return x; } -__device__ inline half8 max(const half8& x, const float16 y) { +__device__ inline half8& max(const half8& x, const float16 y) { half8 rslt; #pragma unroll for (int i = 0; i < 8; ++i) { - rslt[i] = max(rslt[i], y); + rslt[i] = max(x[i], y); } return rslt; } @@ -853,13 +853,13 @@ __device__ inline half8& max(float16 y, half8& x) { return x; } -__device__ inline half8 max(const float16 y, const half8& x) { +__device__ inline half8& max(const float16 y, const half8& x) { half8 rslt; #pragma unroll for (int i = 0; i < 8; ++i) { rslt[i] = max(x[i], y); } - return x; + return rslt; } #endif // __cplusplus && CINN_CUDA_FP16 diff --git a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc index 753e051028..74e4b7153a 100644 --- a/cinn/optim/vectorize_loops.cc +++ b/cinn/optim/vectorize_loops.cc @@ -217,13 +217,8 @@ class CudaVectorizer : public IRMutator { 
} void Visit(const Cast *op, Expr *expr) override { - auto *node = expr->As(); - auto &body = node->v(); - // LOG(INFO) << "Cast Begin()."; - // LOG(INFO) << *expr; - // LOG(INFO) << node; - // LOG(INFO) << "Cast End()."; - + auto *node = expr->As(); + auto &body = node->v(); auto tensors = ir::CollectIRNodes(body, [](const Expr *x) { return x->As(); }); // handdle cast op at last. @@ -234,19 +229,12 @@ class CudaVectorizer : public IRMutator { bool is_vectorize = true; for (auto &tensor : tensors) { auto tensor_ = tensor.As(); - - // LOG(INFO) << "Loop Begin()."; - // LOG(INFO) << tensor; - // LOG(INFO) << node->type().is_scalar(); - // LOG(INFO) << tensor2vectorized_vars_.count(tensor_->name); if (node->type().is_scalar() && (tensor2vectorized_vars_.count(tensor_->name) || tensor_->type().is_vector())) { - lanes = tensor_->type().lanes(); - // LOG(INFO) << lanes; + lanes = tensor_->type().lanes(); is_point = tensor_->type().is_cpp_handle(); is_vectorize = true; break; } - LOG(INFO) << "Loop End."; } if (is_vectorize) { From a984b02909eba9fc9d6017db04c08ce8f081bdea Mon Sep 17 00:00:00 2001 From: limingshu Date: Wed, 24 May 2023 17:03:52 +0800 Subject: [PATCH 7/7] concat the code with more macro --- cinn/common/float16.h | 219 ++++++++++++++++-------------------------- 1 file changed, 84 insertions(+), 135 deletions(-) diff --git a/cinn/common/float16.h b/cinn/common/float16.h index cb29dd85c6..f85c25d9b9 100644 --- a/cinn/common/float16.h +++ b/cinn/common/float16.h @@ -572,6 +572,7 @@ __host__ __device__ inline float16(abs)(const float16& a) { __host__ __device__ inline float16(log)(const float16& a) { return float16(std::log(static_cast(a))); } +// Overload of vectorized types. 
struct half4; struct half8; @@ -582,7 +583,7 @@ struct alignas(sizeof(T) * Size) AlignedVector { __host__ __device__ inline T& operator[](int i) { return val[i]; } }; -#define BINARY_FUNC_FOR_VECTORIZED_INPUT(vector_t, op, num) \ +#define BINARY_FUNC_FOR_VECTORIZED_INPUT(vector_t, num, op) \ __device__ inline vector_t& operator op(const vector_t& x) { \ _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op x[i]; } \ return *this; \ @@ -601,7 +602,7 @@ struct alignas(sizeof(T) * Size) AlignedVector { return x; \ } -#define BINARY_FUNC_FOR_SCALARIZED_INPUT(vector_t, scalar_t, cast_t, op, num) \ +#define BINARY_FUNC_FOR_SCALARIZED_INPUT(vector_t, scalar_t, cast_t, num, op) \ __device__ inline vector_t& operator op(const scalar_t& x) { \ _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast(x); } \ return *this; \ @@ -616,29 +617,22 @@ struct alignas(sizeof(T) * Size) AlignedVector { _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = static_cast(x); } \ } +#define BASIC_BINARY_ARITHMETICS_FOR_EACH(_, ...) 
\ + _(__VA_ARGS__, +); \ + _(__VA_ARGS__, -); \ + _(__VA_ARGS__, *); \ + _(__VA_ARGS__, /); + struct CINN_ALIGN(32) float8 { AlignedVector data; __device__ inline const float& operator[](int i) const { return data[i]; } __device__ inline float& operator[](int i) { return data[i]; } float8() = default; - // operate with vec_type: float8 - BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, +, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, -, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, *, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float8, /, 8); - - // operate with scalar_type: float16 - BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, +, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, -, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, *, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float16, float, /, 8); - - // operate with scalar_type: float - BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float, float, +, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float, float, -, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float, float, *, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float8, float, float, /, 8); + // operate with different input types. 
+ BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, float8, 8); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float8, float16, float, 8); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float8, float, float, 8) VECTORIZED_CAST_SCALAR_FUNC(float8, float, float, 8); VECTORIZED_CAST_SCALAR_FUNC(float8, float16, float, 8); @@ -653,23 +647,10 @@ struct CINN_ALIGN(16) float4 { float4() = default; - // operate with vec_type - BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, +, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, -, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, *, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(float4, /, 4); - - // operate with scalar_type: float16 - BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, +, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, -, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, *, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float16, float, /, 4); - - // operate with scalar_type: float - BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float, float, +, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float, float, -, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float, float, *, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(float4, float, float, /, 4); + // operate with different input types. + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, float4, 4); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float4, float16, float, 4); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float4, float, float, 4); VECTORIZED_CAST_SCALAR_FUNC(float4, float, float, 4); VECTORIZED_CAST_SCALAR_FUNC(float4, float16, float, 4); @@ -682,29 +663,20 @@ struct CINN_ALIGN(16) half8 { __device__ inline const float16& operator[](int i) const { return data[i]; } __device__ inline float16& operator[](int i) { return data[i]; } + // construction. 
half8() = default; - // operate with vec_type - BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, +, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, -, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, *, 8); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half8, /, 8); + __device__ explicit half8(const float8& x) { + _Pragma("unroll") for (int i = 0; i < 8; ++i) { data[i] = static_cast(x[i]); } + } - // operate with scalar_type: float16 - BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, +, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, -, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, *, 8); - BINARY_FUNC_FOR_SCALARIZED_INPUT(half8, float16, float16, /, 8); + // Operate with different input types. + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, half8, 8); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, half8, float16, float16, 8); + // Cast operation overload. VECTORIZED_CAST_SCALAR_FUNC(half8, float, float16, 8); VECTORIZED_CAST_SCALAR_FUNC(half8, float16, float16, 8); - - __device__ inline explicit half8(const float8& x) { -#pragma unroll - for (int i = 0; i < 8; ++i) { - data[i] = static_cast(x[i]); - } - } }; struct CINN_ALIGN(8) half4 { @@ -712,31 +684,22 @@ struct CINN_ALIGN(8) half4 { __device__ inline const float16& operator[](int i) const { return data[i]; } __device__ inline float16& operator[](int i) { return data[i]; } + // construction. 
half4() = default; + __device__ explicit half4(const float4& x) { + _Pragma("unroll") for (int i = 0; i < 4; ++i) { data[i] = static_cast(x[i]); } + } - // operate with vec_type - BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, +, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, -, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, *, 4); - BINARY_FUNC_FOR_VECTORIZED_INPUT(half4, /, 4); - - // operate with scalar_type: float16 - BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, +, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, -, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, *, 4); - BINARY_FUNC_FOR_SCALARIZED_INPUT(half4, float16, float16, /, 4); + // Operate with different input types. + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, half4, 4); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, half4, float16, float16, 4); + // Cast operation overload. VECTORIZED_CAST_SCALAR_FUNC(half4, float, float16, 4); VECTORIZED_CAST_SCALAR_FUNC(half4, float16, float16, 4); - - __device__ inline explicit half4(const float4& x) { -#pragma unroll - for (int i = 0; i < 4; ++i) { - data[i] = static_cast(x[i]); - } - } }; +// Cast operation overload. __device__ inline float4::float4(const half4& x) { #pragma unroll for (int i = 0; i < 4; ++i) { @@ -744,6 +707,7 @@ __device__ inline float4::float4(const half4& x) { } } +// Cast operation overload. 
__device__ inline float8::float8(const half8& x) { #pragma unroll for (int i = 0; i < 8; ++i) { @@ -751,37 +715,25 @@ __device__ inline float8::float8(const half8& x) { } } -#define BINARY_FUNCTION_FOR_SCALAR_AND_VECTOR(scalar_t, vector_t, op, num) \ - __device__ inline vector_t& operator op(const scalar_t x, vector_t& y) { \ - _Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = x op y[i]; } \ - return y; \ - } \ - __device__ inline vector_t& operator op(const scalar_t x, const vector_t& y) { \ - vector_t rslt; \ - _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = x * y[i]; } \ - return rslt; \ - } - -#define BINARY_FUNCTIONS_FOR_EACH_IMPL(_) \ - _(float, float8, +, 8) \ - _(float, float8, -, 8) \ - _(float, float8, *, 8) \ - _(float, float8, /, 8) \ - _(float, float4, +, 4) \ - _(float, float4, -, 4) \ - _(float, float4, *, 4) \ - _(float, float4, /, 4) \ - _(float16, half8, +, 8) \ - _(float16, half8, -, 8) \ - _(float16, half8, *, 8) \ - _(float16, half8, /, 8) \ - _(float16, half4, +, 4) \ - _(float16, half4, -, 4) \ - _(float16, half4, *, 4) \ - _(float16, half4, /, 4) - -BINARY_FUNCTIONS_FOR_EACH_IMPL(BINARY_FUNCTION_FOR_SCALAR_AND_VECTOR); -#undef BINARY_FUNCTIONS_FOR_SCALAR_AND_REGISTER_IMPL +#define SCALAR_FIRST_BASIC_BINARY_FUNCTIONS(scalar_t, vector_t, num, op) \ + __device__ inline vector_t& operator op(const scalar_t x, vector_t& y) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = x op y[i]; } \ + return y; \ + } \ + __device__ inline vector_t operator op(const scalar_t x, const vector_t& y) { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = x op y[i]; } \ + return rslt; \ + } + +#define SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(_, IMPL) \ + _(IMPL, float, float8, 8) \ + _(IMPL, float, float4, 4) \ + _(IMPL, float16, half8, 8) \ + _(IMPL, float16, half4, 4) + +SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(BASIC_BINARY_ARITHMETICS_FOR_EACH, SCALAR_FIRST_BASIC_BINARY_FUNCTIONS); +#undef 
SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL #ifdef __cplusplus } // namespace common @@ -828,39 +780,36 @@ using cinn::common::float16; using cinn::common::half4; using cinn::common::half8; -__device__ inline half8& max(half8& x, float16 y) { -#pragma unroll - for (int i = 0; i < 8; ++i) { - x[i] = max(x[i], y); - } - return x; -} - -__device__ inline half8& max(const half8& x, const float16 y) { - half8 rslt; -#pragma unroll - for (int i = 0; i < 8; ++i) { - rslt[i] = max(x[i], y); - } - return rslt; -} - -__device__ inline half8& max(float16 y, half8& x) { -#pragma unroll - for (int i = 0; i < 8; ++i) { - x[i] = max(x[i], y); - } - return x; -} - -__device__ inline half8& max(const float16 y, const half8& x) { - half8 rslt; -#pragma unroll - for (int i = 0; i < 8; ++i) { - rslt[i] = max(x[i], y); - } - return rslt; -} +#define BASIC_BINARY_FUNC_FOR_VECTOR_AND_SCALAR_TYPES(vector_t, scalar_t, num, op) \ + __device__ inline vector_t& op(vector_t& x, scalar_t y) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { x[i] = op(x[i], y); } \ + return x; \ + } \ + __device__ inline vector_t op(const vector_t& x, const scalar_t y) { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = op(x[i], y); } \ + return rslt; \ + } \ + __device__ inline vector_t& op(scalar_t x, vector_t& y) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = op(x, y[i]); } \ + return y; \ + } \ + __device__ inline vector_t op(const scalar_t x, const vector_t& y) { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = op(x, y[i]); } \ + return rslt; \ + } + +#define BASIC_BINARY_FUNC_FOR_EACH(_, ...) 
\ + _(__VA_ARGS__, max); \ + _(__VA_ARGS__, min); + +#define SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(_, IMPL) \ + _(IMPL, half8, float16, 8); \ + _(IMPL, half4, float16, 4); + +SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(BASIC_BINARY_FUNC_FOR_EACH, BASIC_BINARY_FUNC_FOR_VECTOR_AND_SCALAR_TYPES); +#undef SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL #endif // __cplusplus && CINN_CUDA_FP16