Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cinn/backends/codegen_cuda_dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
}

bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, ir::Expr index_expr, bool is_store) {
static constexpr char index2suffix[8] = {'x', 'y', 'z', 'w', 'v', 'u', 't', 's'};
// static constexpr char index2suffix[8] = {'x', 'y', 'z', 'w', 'v', 'u', 't', 's'};

// addr of op should be a place of tensor and the index is simple int number
if (!op->is_addr_tensor() || !index_expr.As<ir::IntImm>()) {
Expand All @@ -361,7 +361,12 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op,
if (is_store && tensor->type().is_cpp_handle()) {
os() << tensor->name << "[" << index << "]";
} else {
os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".") << index2suffix[index];
if (tensor->type().is_cpp_handle()) {
// os() << "(*" << tensor->name << ")[" << index << "]";
os() << "(*" << tensor->name << ")";
} else {
os() << tensor->name;
}
}
return true;
}
Expand Down
211 changes: 199 additions & 12 deletions cinn/common/float16.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,6 @@ struct CINN_ALIGN(2) float16 {
#endif // __cplusplus
};

// Legacy fixed-lane vector types: one named field per lane, matching the
// codegen lane-suffix table {'x','y','z','w','v','u','t','s'}.
// CINN_ALIGN(N) pins each struct to an N-byte boundary so a whole vector
// can be moved with a single aligned load/store.
struct CINN_ALIGN(32) float8 {
float x, y, z, w, v, u, t, s;
};

// Eight fp16 lanes, 16-byte aligned (8 lanes * 2 bytes).
struct CINN_ALIGN(16) half8 {
float16 x, y, z, w, v, u, t, s;
};

// Four fp16 lanes, 8-byte aligned (4 lanes * 2 bytes).
struct CINN_ALIGN(8) half4 {
float16 x, y, z, w;
};

#ifdef __cplusplus
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
Expand Down Expand Up @@ -584,6 +572,169 @@ __host__ __device__ inline float16(abs)(const float16& a) {

__host__ __device__ inline float16(log)(const float16& a) { return float16(std::log(static_cast<float>(a))); }

// Overload of vectorized types.
struct half4;
struct half8;

// Fixed-size array wrapper whose alignment equals its total byte size
// (sizeof(T) * Size), so the whole vector can be read/written with one
// aligned vectorized memory instruction on the GPU.
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
// Read-only lane access.
__host__ __device__ inline const T& operator[](int i) const { return val[i]; }
// Mutable lane access.
__host__ __device__ inline T& operator[](int i) { return val[i]; }
};

// Expands (inside a vector-struct body) to element-wise overloads of a
// binary operator `op` over `num` lanes of the member `data`.
// Four overloads are produced per op, distinguished by const-ness of the
// object and of the argument.
// NOTE(review): the non-const overloads write the result back into `data`
// (or into `x` for the last one) and return a reference — so `a op b` on a
// non-const left operand behaves like `a op= b`. This is unusual for binary
// operators; confirm the vectorized codegen deliberately relies on this
// in-place behavior before changing it.
#define BINARY_FUNC_FOR_VECTORIZED_INPUT(vector_t, num, op) \
__device__ inline vector_t& operator op(const vector_t& x) { \
_Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op x[i]; } \
return *this; \
} \
__device__ inline vector_t operator op(const vector_t& x) const { \
vector_t rslt; \
_Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = data[i] op(x[i]); } \
return rslt; \
} \
__device__ inline vector_t& operator op(vector_t& x) { \
_Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op(x[i]); } \
return *this; \
} \
__device__ inline vector_t& operator op(vector_t& x) const { \
_Pragma("unroll") for (int i = 0; i < num; ++i) { x[i] = data[i] op(x[i]); } \
return x; \
}

// Expands to element-wise overloads of binary operator `op` where the
// right-hand operand is a scalar of type `scalar_t`; the scalar is first
// cast to `cast_t` (the lane element type) before the per-lane operation.
// The non-const overload compounds in place, the const one returns a new
// vector.
#define BINARY_FUNC_FOR_SCALARIZED_INPUT(vector_t, scalar_t, cast_t, num, op) \
__device__ inline vector_t& operator op(const scalar_t& x) { \
_Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = data[i] op static_cast<const cast_t>(x); } \
return *this; \
} \
__device__ inline vector_t operator op(const scalar_t& x) const { \
vector_t rslt; \
_Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = data[i] op static_cast<const cast_t>(x); } \
return rslt; \
}
// Expands to an explicit broadcasting constructor: fills all `num` lanes
// of `data` with the scalar `x` converted to the lane type `cast_t`.
#define VECTORIZED_CAST_SCALAR_FUNC(vector_t, scalar_t, cast_t, num) \
__device__ inline explicit vector_t(const scalar_t x) { \
_Pragma("unroll") for (int i = 0; i < num; ++i) { data[i] = static_cast<const cast_t>(x); } \
}

// X-macro driver: instantiates the generator macro `_` once for each of the
// four basic arithmetic operators, forwarding any extra arguments.
#define BASIC_BINARY_ARITHMETICS_FOR_EACH(_, ...) \
_(__VA_ARGS__, +); \
_(__VA_ARGS__, -); \
_(__VA_ARGS__, *); \
_(__VA_ARGS__, /);

// Eight float lanes backed by AlignedVector; 32-byte aligned for a single
// vectorized load/store. Arithmetic and broadcast constructors are stamped
// out by the macros above.
struct CINN_ALIGN(32) float8 {
AlignedVector<float, 8> data;
__device__ inline const float& operator[](int i) const { return data[i]; }
__device__ inline float& operator[](int i) { return data[i]; }
float8() = default;

// Element-wise +,-,*,/ against float8, float16 scalars, and float scalars.
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, float8, 8);
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float8, float16, float, 8);
// NOTE(review): missing trailing ';' vs. the invocations above — harmless
// (the macro expansion already ends in ';') but inconsistent.
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float8, float, float, 8)

// Explicit broadcast constructors from float / float16 scalars.
VECTORIZED_CAST_SCALAR_FUNC(float8, float, float, 8);
VECTORIZED_CAST_SCALAR_FUNC(float8, float16, float, 8);

// Lane-wise widening conversion from half8; defined after half8 below.
__device__ inline explicit float8(const half8& x);
};

// Four float lanes, 16-byte aligned.
// NOTE(review): the name collides with CUDA's builtin `float4`; this
// presumably lives inside the cinn::common namespace so lookup stays
// unambiguous — confirm no `using namespace` leaks it into device code.
struct CINN_ALIGN(16) float4 {
AlignedVector<float, 4> data;
__device__ inline const float& operator[](int i) const { return data[i]; }
__device__ inline float& operator[](int i) { return data[i]; }

float4() = default;

// Element-wise +,-,*,/ against float4, float16 scalars, and float scalars.
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, float4, 4);
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float4, float16, float, 4);
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, float4, float, float, 4);

// Explicit broadcast constructors from float / float16 scalars.
VECTORIZED_CAST_SCALAR_FUNC(float4, float, float, 4);
VECTORIZED_CAST_SCALAR_FUNC(float4, float16, float, 4);

// Lane-wise widening conversion from half4; defined after half4 below.
__device__ inline explicit float4(const half4& x);
};

// Eight float16 lanes, 16-byte aligned (8 lanes * 2 bytes).
struct CINN_ALIGN(16) half8 {
AlignedVector<float16, 8> data;
__device__ inline const float16& operator[](int i) const { return data[i]; }
__device__ inline float16& operator[](int i) { return data[i]; }

// construction.
half8() = default;

// Lane-wise narrowing conversion from float8 (float -> float16).
__device__ explicit half8(const float8& x) {
_Pragma("unroll") for (int i = 0; i < 8; ++i) { data[i] = static_cast<const float16>(x[i]); }
}

// Element-wise +,-,*,/ against half8 and float16 scalars.
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, half8, 8);
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, half8, float16, float16, 8);

// Explicit broadcast constructors from float / float16 scalars.
VECTORIZED_CAST_SCALAR_FUNC(half8, float, float16, 8);
VECTORIZED_CAST_SCALAR_FUNC(half8, float16, float16, 8);
};

// Four float16 lanes, 8-byte aligned (4 lanes * 2 bytes).
struct CINN_ALIGN(8) half4 {
AlignedVector<float16, 4> data;
__device__ inline const float16& operator[](int i) const { return data[i]; }
__device__ inline float16& operator[](int i) { return data[i]; }

// construction.
half4() = default;
// Lane-wise narrowing conversion from float4 (float -> float16).
__device__ explicit half4(const float4& x) {
_Pragma("unroll") for (int i = 0; i < 4; ++i) { data[i] = static_cast<const float16>(x[i]); }
}

// Element-wise +,-,*,/ against half4 and float16 scalars.
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_VECTORIZED_INPUT, half4, 4);
BASIC_BINARY_ARITHMETICS_FOR_EACH(BINARY_FUNC_FOR_SCALARIZED_INPUT, half4, float16, float16, 4);

// Explicit broadcast constructors from float / float16 scalars.
VECTORIZED_CAST_SCALAR_FUNC(half4, float, float16, 4);
VECTORIZED_CAST_SCALAR_FUNC(half4, float16, float16, 4);
};

// Lane-wise widening conversion half4 -> float4.
// Fix: cast each fp16 lane directly to the element type `float`. The
// previous `static_cast<const float16>(x[i])` was an identity cast whose
// result was then implicitly converted to float — same value, but
// misleading and inconsistent with the float8(half8) converter below.
__device__ inline float4::float4(const half4& x) {
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    data[i] = static_cast<float>(x[i]);
  }
}

// Lane-wise widening conversion half8 -> float8.
// Fix: drop the meaningless top-level `const` on the cast target type
// (`static_cast<const float>` and `static_cast<float>` are identical for a
// prvalue result); behavior is unchanged.
__device__ inline float8::float8(const half8& x) {
#pragma unroll
  for (int i = 0; i < 8; ++i) {
    data[i] = static_cast<float>(x[i]);
  }
}

// Expands to free-function overloads `scalar op vector` (the member
// overloads above only cover `vector op ...`).
// NOTE(review): the non-const overload stores the result back into `y` and
// returns it by reference — `s op v` mutates `v` when `v` is a non-const
// lvalue; confirm this in-place behavior is intended by the codegen.
#define SCALAR_FIRST_BASIC_BINARY_FUNCTIONS(scalar_t, vector_t, num, op) \
__device__ inline vector_t& operator op(const scalar_t x, vector_t& y) { \
_Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = x op y[i]; } \
return y; \
} \
__device__ inline vector_t operator op(const scalar_t x, const vector_t& y) { \
vector_t rslt; \
_Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = x op y[i]; } \
return rslt; \
}

// Applies the operator-set driver `_` with the generator `IMPL` to each
// (scalar, vector, lanes) combination, producing scalar-first +,-,*,/
// overloads for all four vector types.
#define SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(_, IMPL) \
_(IMPL, float, float8, 8) \
_(IMPL, float, float4, 4) \
_(IMPL, float16, half8, 8) \
_(IMPL, float16, half4, 4)

SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(BASIC_BINARY_ARITHMETICS_FOR_EACH, SCALAR_FIRST_BASIC_BINARY_FUNCTIONS);
// Undefine the driver so the identically-named macro later in this header
// (for max/min) can be redefined without a redefinition warning.
#undef SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL

#ifdef __cplusplus
} // namespace common
} // namespace cinn
Expand Down Expand Up @@ -624,6 +775,42 @@ __host__ __device__ inline cinn::common::float16 max(const cinn::common::float16
// Returns the smaller of the two float16 operands; on a tie (neither
// strictly less) the second operand is returned, exactly as `a < b ? a : b`.
__host__ __device__ inline cinn::common::float16 min(const cinn::common::float16& a, const cinn::common::float16& b) {
  if (a < b) {
    return a;
  }
  return b;
}

using cinn::common::float16;
using cinn::common::half4;
using cinn::common::half8;

// Expands to four overloads of a named binary function `op` (here max/min)
// mixing a vector type and a scalar: (vector&, scalar), (const vector&,
// scalar), (scalar, vector&), (scalar, const vector&), applied lane-wise.
// NOTE(review): the overloads taking a non-const vector reference write the
// result back into that argument and return it — `max(v, s)` mutates `v`
// for non-const `v`; confirm this in-place behavior is intended.
#define BASIC_BINARY_FUNC_FOR_VECTOR_AND_SCALAR_TYPES(vector_t, scalar_t, num, op) \
__device__ inline vector_t& op(vector_t& x, scalar_t y) { \
_Pragma("unroll") for (int i = 0; i < num; ++i) { x[i] = op(x[i], y); } \
return x; \
} \
__device__ inline vector_t op(const vector_t& x, const scalar_t y) { \
vector_t rslt; \
_Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = op(x[i], y); } \
return rslt; \
} \
__device__ inline vector_t& op(scalar_t x, vector_t& y) { \
_Pragma("unroll") for (int i = 0; i < num; ++i) { y[i] = op(x, y[i]); } \
return y; \
} \
__device__ inline vector_t op(const scalar_t x, const vector_t& y) { \
vector_t rslt; \
_Pragma("unroll") for (int i = 0; i < num; ++i) { rslt[i] = op(x, y[i]); } \
return rslt; \
}

// Driver: stamps the generator out for max and min.
#define BASIC_BINARY_FUNC_FOR_EACH(_, ...) \
_(__VA_ARGS__, max); \
_(__VA_ARGS__, min);

// (vector, scalar, lanes) combinations that get max/min overloads.
#define SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(_, IMPL) \
_(IMPL, half8, float16, 8); \
_(IMPL, half4, float16, 4);

SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL(BASIC_BINARY_FUNC_FOR_EACH, BASIC_BINARY_FUNC_FOR_VECTOR_AND_SCALAR_TYPES);
#undef SCALAR_FIRST_BINARY_FUNCTIONS_FOR_EACH_IMPL

#endif // __cplusplus && CINN_CUDA_FP16

#endif // CINN_COMMON_FLOAT16_H
6 changes: 3 additions & 3 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void Sub::Verify() const { BinaryNodeVerify(a(), b(), "Sub"); }
// Builds a Mul IR node from two defined operand expressions.
Expr Mul::Make(Expr a, Expr b) {
CHECK(a.defined());
CHECK(b.defined());
// NOTE(review): the operand-type equality check was disabled by this
// change, presumably to allow mixed scalar/vector operand types produced by
// the CUDA vectorizer. This also silences genuine type-mismatch bugs —
// consider relaxing the check (scalar-vs-vector of same element type)
// instead of removing it entirely.
// CHECK_EQ(a.type(), b.type()) << "a=" << a << ", b=" << b;
auto node = make_shared<Mul>(a, b);
return Expr(node);
}
Expand Down Expand Up @@ -658,7 +658,7 @@ Expr Sum::Make(const std::vector<Expr> &vs) {

auto *n = make_shared<Sum>();
auto type = vs.front().type();
for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v;
// for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v;

n->operands() = vs;

Expand All @@ -672,7 +672,7 @@ Expr Product::Make(const std::vector<Expr> &vs) {

auto *n = make_shared<Product>();
auto type = vs.front().type();
for (auto &v : vs) CHECK_EQ(v.type(), type);
// for (auto &v : vs) CHECK_EQ(v.type(), type);

n->operands() = vs;

Expand Down
39 changes: 38 additions & 1 deletion cinn/optim/vectorize_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,39 @@ class CudaVectorizer : public IRMutator<Expr *> {
IRMutator::Visit(&node->value, &node->value);
}

// Rewrites a scalar Cast node's result type to the corresponding customized
// vector type (e.g. "float8") when the value being cast reads a tensor that
// has been vectorized, so the generated CUDA code casts whole vectors.
void Visit(const Cast *op, Expr *expr) override {
  auto *node = expr->As<Cast>();
  auto &body = node->v();
  // Collect every tensor referenced inside the cast's operand.
  auto tensors = ir::CollectIRNodes(body, [](const Expr *x) { return x->As<ir::_Tensor_>(); });

  // Visit the operand first; the cast node itself is handled last.
  IRMutator::Visit(op, expr);

  bool is_point = false;
  // BUGFIX: must start false and only become true when a vectorized tensor
  // is actually found. The previous `= true` initializer made the detection
  // loop pointless and rewrote EVERY scalar cast to a vector type, even
  // when no operand was vectorized.
  bool is_vectorize = false;
  for (auto &tensor : tensors) {
    auto tensor_ = tensor.As<ir::_Tensor_>();
    if (node->type().is_scalar() && (tensor2vectorized_vars_.count(tensor_->name) || tensor_->type().is_vector())) {
      // Pointer-ness of the tensor decides whether the new type is a handle.
      is_point = tensor_->type().is_cpp_handle();
      is_vectorize = true;
      break;
    }
  }

  if (is_vectorize) {
    // Build the vectorized replacement type with `factor_` lanes, named
    // after the scalar element type (e.g. float -> "float8").
    auto ntype = common::Type(Type::type_t::Customized, node->type().bits(), factor_);
    ntype.set_customized_type(GetVectorTypeName(node->type().ElementOf()));
    if (is_point) {
      // Preserve pointer-ness so codegen emits `(*name)`-style access.
      ntype.set_cpp_handle();
    }
    node->set_type(ntype);
  }
}

private:
void TensorVectorized(ir::LoadStoreAddrMnger *node, std::vector<Expr> *indices, bool is_store) {
auto *tensor = node->tensor.As<ir::_Tensor_>();
Expand Down Expand Up @@ -722,7 +755,8 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
CudaVectorizer cuda_vectorizer(new_forloop->loop_var, factor, &var_intervals);
cuda_vectorizer.Visit(&new_forloop->body);
// unroll the new forloop to compute each element of the vector iteratively
auto copied_loop = optim::IRCopy(_new_forloop);
auto copied_loop = optim::IRCopy(_new_forloop);
copied_loop.As<ir::For>()->extent = ir::Expr(1);
copied_loop.As<ir::For>()->set_unrolled();
optim::UnrollLoop(&copied_loop);
// add cast exprs of vector type in the front of vectorized forloop,
Expand Down Expand Up @@ -862,6 +896,9 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
new_index = Expr(forloop->loop_var) * factor + Expr(new_iterator);
}
optim::IrReplace(&forloop->body, forloop->loop_var, new_index);

{}

auto new_forloop = For::Make(new_iterator,
forloop->min,
make_const(factor),
Expand Down