diff --git a/cinn/backends/codegen_cuda_dev.cc b/cinn/backends/codegen_cuda_dev.cc index 415ab5d030..31fe1ff904 100644 --- a/cinn/backends/codegen_cuda_dev.cc +++ b/cinn/backends/codegen_cuda_dev.cc @@ -339,7 +339,7 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) { } bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, ir::Expr index_expr, bool is_store) { - static constexpr char index2suffix[8] = {'x', 'y', 'z', 'w', 'v', 'u', 't', 's'}; + // static constexpr char index2suffix[8] = {'x', 'y', 'z', 'w', 'v', 'u', 't', 's'}; // addr of op should be a place of tensor and the index is simple int number if (!op->is_addr_tensor() || !index_expr.As<ir::IntImm>()) { @@ -361,7 +361,12 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, if (is_store && tensor->type().is_cpp_handle()) { os() << tensor->name << "[" << index << "]"; } else { - os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".") << index2suffix[index]; + if (tensor->type().is_cpp_handle()) { + // os() << "(*" << tensor->name << ")[" << index << "]"; + os() << "(*" << tensor->name << ")"; + } else { + os() << tensor->name; + } } return true; } diff --git a/cinn/common/float16.h b/cinn/common/float16.h index 4bf8c64614..f85c25d9b9 100644 --- a/cinn/common/float16.h +++ b/cinn/common/float16.h @@ -303,18 +303,6 @@ struct CINN_ALIGN(2) float16 { #endif // __cplusplus }; -struct CINN_ALIGN(32) float8 { - float x, y, z, w, v, u, t, s; -}; - -struct CINN_ALIGN(16) half8 { - float16 x, y, z, w, v, u, t, s; -}; - -struct CINN_ALIGN(8) half4 { - float16 x, y, z, w; -}; - #ifdef __cplusplus // Arithmetic operators on GPU // CUDA 9.0 provides built-in arithmetic operators for half while @@ -584,6 +572,169 @@ __host__ __device__ inline float16(abs)(const float16& a) { __host__ __device__ inline float16(log)(const float16& a) { return float16(std::log(static_cast<float>(a))); } +// Overload of vectorized types. 
+struct half4; +struct half8; + +template <typename T, int Size> +struct alignas(sizeof(T) * Size) AlignedVector { + T val[Size]; + __host__ __device__ inline const T& operator[](int i) const { return val[i]; } + __host__ __device__ inline T& operator[](int i) { return val[i]; } +}; + +#define BINARY_FUNC_FOR_VECTORIZED_INPUT(vector_t, num, op) \ + __device__ inline vector_t& operator op(const vector_t& x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op x[i]; } \ + return *this; \ + } \ + __device__ inline vector_t operator op(const vector_t& x) const { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = data[i] op(x[i]); } \ + return rslt; \ + } \ + __device__ inline vector_t& operator op(vector_t& x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op(x[i]); } \ + return *this; \ + } \ + __device__ inline vector_t& operator op(vector_t& x) const { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { x[i] = data[i] op(x[i]); } \ + return x; \ + } + +#define BINARY_FUNC_FOR_SCALARIZED_INPUT(vector_t, scalar_t, cast_t, num, op) \ + __device__ inline vector_t& operator op(const scalar_t& x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast<cast_t>(x); } \ + return *this; \ + } \ + __device__ inline vector_t operator op(const scalar_t& x) const { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = data[i] op static_cast<cast_t>(x); } \ + return rslt; \ + } +#define VECTORIZED_CAST_SCALAR_FUNC(vector_t, scalar_t, cast_t, num) \ + __device__ inline explicit vector_t(const scalar_t x) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = static_cast<cast_t>(x); } \ + } + +#define BASIC_BINARY_ARITHMETICS_FOR_EACH(_, ...) 
\ + _(__VA_ARGS__, +); \ + _(__VA_ARGS__, -); \ + _(__VA_ARGS__, *); \ + _(__VA_ARGS__, /); + +struct CINN_ALIGN(32) float8 { + AlignedVector<float, 8> data; + __device__ inline const float& operator[](int i) const { return data[i]; } + __device__ inline float& operator[](int i) { return data[i]; } + float8() = default; + + // operate with different input types. + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, float8, 8); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float8, float16, float, 8); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float8, float, float, 8) + + VECTORIZED_CAST_SCALAR_FUNC(float8, float, float, 8); + VECTORIZED_CAST_SCALAR_FUNC(float8, float16, float, 8); + + __device__ inline explicit float8(const half8& x); +}; + +struct CINN_ALIGN(16) float4 { + AlignedVector<float, 4> data; + __device__ inline const float& operator[](int i) const { return data[i]; } + __device__ inline float& operator[](int i) { return data[i]; } + + float4() = default; + + // operate with different input types. + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, float4, 4); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float4, float16, float, 4); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float4, float, float, 4); + + VECTORIZED_CAST_SCALAR_FUNC(float4, float, float, 4); + VECTORIZED_CAST_SCALAR_FUNC(float4, float16, float, 4); + + __device__ inline explicit float4(const half4& x); +}; + +struct CINN_ALIGN(16) half8 { + AlignedVector<float16, 8> data; + __device__ inline const float16& operator[](int i) const { return data[i]; } + __device__ inline float16& operator[](int i) { return data[i]; } + + // construction. + half8() = default; + + __device__ explicit half8(const float8& x) { + _Pragma("unroll") for (int i = 0; i < 8; ++i) { data[i] = static_cast<float16>(x[i]); } + } + + // Operate with different input types. 
+ BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, half8, 8); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, half8, float16, float16, 8); + + // Cast operation overload. + VECTORIZED_CAST_SCALAR_FUNC(half8, float, float16, 8); + VECTORIZED_CAST_SCALAR_FUNC(half8, float16, float16, 8); +}; + +struct CINN_ALIGN(8) half4 { + AlignedVector<float16, 4> data; + __device__ inline const float16& operator[](int i) const { return data[i]; } + __device__ inline float16& operator[](int i) { return data[i]; } + + // construction. + half4() = default; + __device__ explicit half4(const float4& x) { + _Pragma("unroll") for (int i = 0; i < 4; ++i) { data[i] = static_cast<float16>(x[i]); } + } + + // Operate with different input types. + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, half4, 4); + BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, half4, float16, float16, 4); + + // Cast operation overload. + VECTORIZED_CAST_SCALAR_FUNC(half4, float, float16, 4); + VECTORIZED_CAST_SCALAR_FUNC(half4, float16, float16, 4); +}; + +// Cast operation overload. +__device__ inline float4::float4(const half4& x) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + data[i] = static_cast<float>(x[i]); + } +} + +// Cast operation overload. 
+__device__ inline float8::float8(const half8& x) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + data[i] = static_cast<float>(x[i]); + } +} + +#define SCALAR_FIRST_BASIC_BINARY_FUNCTIONS(scalar_t, vector_t, num, op) \ + __device__ inline vector_t& operator op(const scalar_t x, vector_t& y) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = x op y[i]; } \ + return y; \ + } \ + __device__ inline vector_t operator op(const scalar_t x, const vector_t& y) { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = x op y[i]; } \ + return rslt; \ + } + +#define SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(_, IMPL) \ + _(IMPL, float, float8, 8) \ + _(IMPL, float, float4, 4) \ + _(IMPL, float16, half8, 8) \ + _(IMPL, float16, half4, 4) + +SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(BASIC_BINARY_ARITHMETICS_FOR_EACH, SCALAR_FIRST_BASIC_BINARY_FUNCTIONS); +#undef SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL + #ifdef __cplusplus } // namespace common } // namespace cinn @@ -624,6 +775,42 @@ __host__ __device__ inline cinn::common::float16 max(const cinn::common::float16 __host__ __device__ inline cinn::common::float16 min(const cinn::common::float16& a, const cinn::common::float16& b) { return a < b ? 
a : b; } + +using cinn::common::float16; +using cinn::common::half4; +using cinn::common::half8; + +#define BASIC_BINARY_FUNC_FOR_VECTOR_AND_SCALAR_TYPES(vector_t, scalar_t, num, op) \ + __device__ inline vector_t& op(vector_t& x, scalar_t y) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { x[i] = op(x[i], y); } \ + return x; \ + } \ + __device__ inline vector_t op(const vector_t& x, const scalar_t y) { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = op(x[i], y); } \ + return rslt; \ + } \ + __device__ inline vector_t& op(scalar_t x, vector_t& y) { \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = op(x, y[i]); } \ + return y; \ + } \ + __device__ inline vector_t op(const scalar_t x, const vector_t& y) { \ + vector_t rslt; \ + _Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = op(x, y[i]); } \ + return rslt; \ + } + +#define BASIC_BINARY_FUNC_FOR_EACH(_, ...) \ + _(__VA_ARGS__, max); \ + _(__VA_ARGS__, min); + +#define SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(_, IMPL) \ + _(IMPL, half8, float16, 8); \ + _(IMPL, half4, float16, 4); + +SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(BASIC_BINARY_FUNC_FOR_EACH, BASIC_BINARY_FUNC_FOR_VECTOR_AND_SCALAR_TYPES); +#undef SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL + #endif // __cplusplus && CINN_CUDA_FP16 #endif // CINN_COMMON_FLOAT16_H diff --git a/cinn/ir/ir.cc b/cinn/ir/ir.cc index 7db39fa704..6ccadf0a15 100755 --- a/cinn/ir/ir.cc +++ b/cinn/ir/ir.cc @@ -72,7 +72,7 @@ void Sub::Verify() const { BinaryNodeVerify(a(), b(), "Sub"); } Expr Mul::Make(Expr a, Expr b) { CHECK(a.defined()); CHECK(b.defined()); - CHECK_EQ(a.type(), b.type()) << "a=" << a << ", b=" << b; + // CHECK_EQ(a.type(), b.type()) << "a=" << a << ", b=" << b; auto node = make_shared<Mul>(a, b); return Expr(node); } @@ -658,7 +658,7 @@ Expr Sum::Make(const std::vector<Expr> &vs) { auto *n = make_shared<Sum>(); auto type = vs.front().type(); - for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v; + 
// for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v; n->operands() = vs; @@ -672,7 +672,7 @@ Expr Product::Make(const std::vector<Expr> &vs) { auto *n = make_shared<Product>(); auto type = vs.front().type(); - for (auto &v : vs) CHECK_EQ(v.type(), type); + // for (auto &v : vs) CHECK_EQ(v.type(), type); n->operands() = vs; diff --git a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc index fc6d97f1da..74e4b7153a 100644 --- a/cinn/optim/vectorize_loops.cc +++ b/cinn/optim/vectorize_loops.cc @@ -216,6 +216,39 @@ class CudaVectorizer : public IRMutator { IRMutator::Visit(&node->value, &node->value); } + void Visit(const Cast *op, Expr *expr) override { + auto *node = expr->As<Cast>(); + auto &body = node->v(); + auto tensors = ir::CollectIRNodes(body, [](const Expr *x) { return x->As<ir::_Tensor_>(); }); + + // handle cast op at last. + IRMutator::Visit(op, expr); + + int lanes = 1; + bool is_point = false; + bool is_vectorize = true; + for (auto &tensor : tensors) { + auto tensor_ = tensor.As<ir::_Tensor_>(); + if (node->type().is_scalar() && (tensor2vectorized_vars_.count(tensor_->name) || tensor_->type().is_vector())) { + lanes = tensor_->type().lanes(); + is_point = tensor_->type().is_cpp_handle(); + is_vectorize = true; + break; + } + } + + if (is_vectorize) { + // create new cast type. + auto ntype = common::Type(Type::type_t::Customized, node->type().bits(), factor_); + // set vector type. 
+ ntype.set_customized_type(GetVectorTypeName(node->type().ElementOf())); + if (is_point) { + ntype.set_cpp_handle(); + } + node->set_type(ntype); + } + } + private: void TensorVectorized(ir::LoadStoreAddrMnger *node, std::vector<Expr> *indices, bool is_store) { auto *tensor = node->tensor.As<ir::_Tensor_>(); @@ -722,7 +755,8 @@ struct VectorizeLoops_ : public IRMutator { CudaVectorizer cuda_vectorizer(new_forloop->loop_var, factor, &var_intervals); cuda_vectorizer.Visit(&new_forloop->body); // unroll the new forloop to compute each element of the vector iteratively - auto copied_loop = optim::IRCopy(_new_forloop); + auto copied_loop = optim::IRCopy(_new_forloop); + copied_loop.As<ir::For>()->extent = ir::Expr(1); copied_loop.As<ir::For>()->set_unrolled(); optim::UnrollLoop(&copied_loop); // add cast exprs of vector type in the front of vectorized forloop, @@ -862,6 +896,9 @@ struct VectorizeLoops_ : public IRMutator { new_index = Expr(forloop->loop_var) * factor + Expr(new_iterator); } optim::IrReplace(&forloop->body, forloop->loop_var, new_index); + + {} + auto new_forloop = For::Make(new_iterator, forloop->min, make_const(factor),