diff --git a/cinn/backends/compiler.cc b/cinn/backends/compiler.cc index 880441ff63..798b0a96a2 100644 --- a/cinn/backends/compiler.cc +++ b/cinn/backends/compiler.cc @@ -25,6 +25,7 @@ #include "cinn/backends/nvrtc/nvrtc_util.h" #include "cinn/runtime/cuda/cuda_module.h" #include "cinn/runtime/cuda/cuda_util.h" +#include "cinn/runtime/flags.h" #endif DECLARE_string(cinn_source_code_save_path); @@ -123,16 +124,13 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code) SourceCodePrint::GetInstance()->write(source_code); using runtime::cuda::CUDAModule; - backends::nvrtc::Compiler compiler; - + nvrtc::Compiler compiler; auto ptx = compiler(source_code); CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << source_code; - cuda_module_.reset( new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); RuntimeSymbols symbols; - for (auto& fn : device_module.functions()) { std::string kernel_fn_name = fn->name; auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name); diff --git a/cinn/backends/nvrtc/nvrtc_util.cc b/cinn/backends/nvrtc/nvrtc_util.cc index 101406984e..4598054701 100644 --- a/cinn/backends/nvrtc/nvrtc_util.cc +++ b/cinn/backends/nvrtc/nvrtc_util.cc @@ -17,12 +17,20 @@ #include #include #include +#include +#include +#include + +#include +#include #include "cinn/backends/cuda_util.h" #include "cinn/backends/nvrtc/header_generator.h" #include "cinn/common/common.h" +#include "cinn/runtime/flags.h" #include "cinn/utils/string.h" +DECLARE_string(cinn_nvcc_cmd_path); DECLARE_bool(nvrtc_compile_to_cubin); namespace cinn { @@ -30,6 +38,9 @@ namespace backends { namespace nvrtc { std::string Compiler::operator()(const std::string& code, bool include_headers) { + if (runtime::CanUseNvccCompiler()) { + return CompileWithNvcc(code); + } return CompileCudaSource(code, include_headers); } @@ -140,6 +151,89 @@ std::string Compiler::CompileCudaSource(const std::string& code, bool 
include_he return data; } +std::string Compiler::CompileWithNvcc(const std::string& cuda_c) { + // read dir source + std::string dir = "./source"; + if (access(dir.c_str(), 0) == -1) { + CHECK(mkdir(dir.c_str(), 0755) != -1) << "Fail to mkdir " << dir; + } + + // get unique prefix name + prefix_name_ = dir + "/" + common::UniqName("rtc_tmp"); + + auto cuda_c_file = prefix_name_ + ".cu"; + std::ofstream ofs(cuda_c_file, std::ios::out); + CHECK(ofs.is_open()) << "Fail to open file " << cuda_c_file; + ofs << cuda_c; + ofs.close(); + + CompileToPtx(); + CompileToCubin(); + + return prefix_name_ + ".cubin"; +} + +// std::string Compiler::GetPtx() { return ReadFile(prefix_name_ + ".ptx", std::ios::in); } + +void Compiler::CompileToPtx() { + auto include_dir = common::Context::Global().runtime_include_dir(); + std::string include_dir_str = ""; + for (auto dir : include_dir) { + if (include_dir_str.empty()) { + include_dir_str = dir; + } else { + include_dir_str += ":" + dir; + } + } + + std::string options = std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + + std::string(":$PATH && nvcc -std=c++14 --ptx -O3 -I ") + include_dir_str; + options += " -arch=" + GetDeviceArch(); + options += " -o " + prefix_name_ + ".ptx"; + options += " " + prefix_name_ + ".cu"; + + VLOG(2) << "Nvcc Compile Options : " << options; + CHECK(system(options.c_str()) == 0) << options; +} + +void Compiler::CompileToCubin() { + std::string options = + std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + std::string(":$PATH && nvcc --cubin -O3"); + options += " -arch=" + GetDeviceArch(); + options += " -o " + prefix_name_ + ".cubin"; + options += " " + prefix_name_ + ".ptx"; + + VLOG(2) << "Nvcc Compile Options : " << options; + CHECK(system(options.c_str()) == 0) << options; +} + +std::string Compiler::GetDeviceArch() { + int major = 0, minor = 0; + if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0) == cudaSuccess && + cudaDeviceGetAttribute(&minor, 
cudaDevAttrComputeCapabilityMinor, 0) == cudaSuccess) { + return "sm_" + std::to_string(major) + std::to_string(minor); + } else { + LOG(WARNING) << "cannot detect compute capability from your device, " + << "fall back to compute_30."; + return "sm_30"; + } +} + +std::string Compiler::ReadFile(const std::string& file_name, std::ios_base::openmode mode) { + // open cubin file + std::ifstream ifs(file_name, mode); + CHECK(ifs.is_open()) << "Fail to open file " << file_name; + ifs.seekg(0, std::ios::end); + auto len = ifs.tellg(); + ifs.seekg(0); + + // read cubin file + std::string file_data(len, ' '); + ifs.read(&file_data[0], len); + ifs.close(); + return file_data; +} + } // namespace nvrtc } // namespace backends } // namespace cinn diff --git a/cinn/backends/nvrtc/nvrtc_util.h b/cinn/backends/nvrtc/nvrtc_util.h index a5f8424a31..b13c24c550 100644 --- a/cinn/backends/nvrtc/nvrtc_util.h +++ b/cinn/backends/nvrtc/nvrtc_util.h @@ -70,6 +70,19 @@ class Compiler { * whether to compile the source code into cubin, only works with cuda version > 11.1 */ bool compile_to_cubin_{false}; + + // compile with nvcc + std::string CompileWithNvcc(const std::string&); + + // compile to ptx + void CompileToPtx(); + // compile to cubin + void CompileToCubin(); + std::string GetDeviceArch(); + + std::string ReadFile(const std::string&, std::ios_base::openmode); + + std::string prefix_name_{""}; }; } // namespace nvrtc diff --git a/cinn/hlir/framework/parallel_compiler.cc b/cinn/hlir/framework/parallel_compiler.cc index aa22dfce65..ede13cab04 100644 --- a/cinn/hlir/framework/parallel_compiler.cc +++ b/cinn/hlir/framework/parallel_compiler.cc @@ -28,6 +28,7 @@ #include "cinn/common/context.h" #include "cinn/hlir/framework/pass.h" #include "cinn/ir/module.h" +#include "cinn/runtime/flags.h" DECLARE_int32(cinn_parallel_compile_size); DECLARE_int32(cinn_parallel_compile_thread); @@ -178,10 +179,9 @@ void ParallelCompiler::Task::CodegenAndJit() { backends::nvrtc::Compiler compiler; 
auto ptx = compiler(cuda_c); CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c; - graph->SavePTXCode(ptx); - // load cumodule cumodule.reset(new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); + // register kernel backends::RuntimeSymbols symbols; for (auto& fn : dmodule.functions()) { diff --git a/cinn/hlir/pe/pe_transform_test.cc b/cinn/hlir/pe/pe_transform_test.cc index 58f0c109f2..f5a76014e8 100644 --- a/cinn/hlir/pe/pe_transform_test.cc +++ b/cinn/hlir/pe/pe_transform_test.cc @@ -224,719 +224,6 @@ TEST(Concat, ConcatCase0) { #endif } -TEST(Reduce, Reduce_Test_0) { - int m = 128; - int n = 128; - Expr M(m), N(n); - - Placeholder A("A", {M, N}); - Placeholder B("B", {M, N}); - - auto C = hlir::pe::Add(A.tensor(), B.tensor()); - auto D = hlir::pe::ReduceSum(C, {0}); - auto stages = CreateStages({C, D}); - stages[C]->SetBuffer("local"); - stages[C]->Reorder({1, 0}); - stages[D]->Bind(0, "threadIdx.x"); - stages[C]->SimpleComputeAt(stages[D], 1); - - auto func = Lower("fn", stages, {A, B, D}); - LOG(INFO) << "func:\n" << func; - -#ifdef CINN_WITH_CUDA - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -#endif -} - -#ifdef CINN_WITH_CUDA -void CudaReduceReorder(poly::StageMap stages, ir::Tensor input, const std::vector &axes) { - auto &shape = input->shape; - std::vector order; - for (int idx = 0; idx < 
shape.size(); ++idx) { - if (std::find(axes.begin(), axes.end(), idx) == axes.end()) { - order.push_back(idx); - } - } - for (auto axis : axes) { - order.push_back(axis); - } - stages[input]->Reorder(order); - - int last_dimension_num = shape.size() - axes.back() - 1; - int index = shape.size() - last_dimension_num - axes.size(); - for (auto idx = index; idx < index + last_dimension_num - 1; ++idx) { - stages[input]->Fuse(index, index + 1); - } - - if (stages[input]->GetDimRange(index) > 1024) { - stages[input]->Split(index, 1024); - } - - for (int idx = 0; idx < index - 1; ++idx) { - stages[input]->Fuse(0, 1); - } -} - -TEST(Reduce, Reduce_Test_1) { - int m = 128; - int n = 128; - Expr M(m), N(n); - - Placeholder A("A", {M, M, M, N, N}); - Placeholder B("B", {M, M, M, N, N}); - - auto C = hlir::pe::Add(A.tensor(), B.tensor()); - auto D = hlir::pe::ReduceSum(C, {0, 2}); - auto stages = CreateStages({C, D}); - hlir::pe::CudaReduceSchedule(stages, D, 2, common::DefaultNVGPUTarget()); - CudaReduceReorder(stages, C, {0, 2}); - stages[C]->SetBuffer("local"); - stages[C]->SimpleComputeAt(stages[D], stages[D]->n_out_dims() - 1); - // stages[C]->ComputeInline(); - - auto func = Lower("fn", stages, {A, B, D}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_2) { - int m = 10201; - int n = 50; - Expr M(m), N(n); - - 
Placeholder A("A", {M, N}); - - auto reduce_out = hlir::pe::BlockShuffleReduceSum(A.tensor(), {0}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 3"; - auto stages = CreateStages({A, reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaBlockShuffleReduceSchedule(stages, reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_2_1) { - int m = 10240; - int n = 64; - Expr M(m), N(n); - - Placeholder A("A", {M, N}); - - auto reduce_out = hlir::pe::BlockShuffleReduceSum(A.tensor(), {0}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 3"; - auto stages = CreateStages({A, reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaBlockShuffleReduceSchedule(stages, reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - 
auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_2_2) { - int m = 10240; - int n = 64; - Expr M(m), N(n); - - Placeholder A("A", {N, M, N}); - - auto reduce_out = hlir::pe::BlockShuffleReduceSum(A.tensor(), {1}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 3"; - auto stages = CreateStages({A, reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaBlockShuffleReduceSchedule(stages, reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_2_3) { - int m = 10240; - int n = 16; - Expr M(m), N(n); - - Placeholder A("A", {M, N, N}); - - auto reduce_out = hlir::pe::BlockShuffleReduceSum(A.tensor(), {0}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 3"; - auto stages = CreateStages({A, reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaBlockShuffleReduceSchedule(stages, reduce_out[2], reduce_out[1], 
reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_3) { - int m = 10201; - Expr M(m); - - Placeholder A("A", {M}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {0}, false); - CHECK_EQ(reduce_out.size(), 4) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaTwoStepReduceSchedule( - stages, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - 
CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_3_1) { - int m = 10240; - Expr M(m); - - Placeholder A("A", {M}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {0}, false); - CHECK_EQ(reduce_out.size(), 4) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaTwoStepReduceSchedule( - stages, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_3_2) { - int m = 10240; - int n = 64; - Expr M(m), N(n); - - Placeholder A("A", {N, M}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {1}, false); - CHECK_EQ(reduce_out.size(), 4) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaTwoStepReduceSchedule( - stages, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = 
builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_4) { - int m = 10201; - int n = 64; - Expr M(m), N(n); - - Placeholder A("A", {N, M}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {1}, false); - CHECK_EQ(reduce_out.size(), 4) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaTwoStepReduceSchedule( - stages, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_5) { - int m = 32; - int n = 64; - Expr M(m), N(n); - - Placeholder A("A", {N, M, M}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {1, 2}, false); - CHECK_EQ(reduce_out.size(), 2) << "the output of 
reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[1], reduce_out[0]}); - - CudaBlockReduceInternalSchedule(stages, reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_6) { - int m = 32; - int n = 64; - Expr M(m), N(n); - - Placeholder A("A", {N, N, M, M}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {0, 2, 3}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[1], reduce_out[0]}); - - CudaBlockReduceSchedule(stages, reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" 
<< source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_7) { - int m = 10201; - int n = 64; - Expr M(m), N(n); - - Placeholder A("A", {N, N, M}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {1, 2}, false); - CHECK_EQ(reduce_out.size(), 4) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaTwoStepReduceSchedule( - stages, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_8) { - int m = 128; - int n = 112; - Expr M(m), N(n); - - Placeholder A("A", {M, M, N, N}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {0, 2, 3}, false); - CHECK_EQ(reduce_out.size(), 4) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaTwoStepReduceSchedule( - stages, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - 
- auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_9) { - int m = 128; - int n = 56; - Expr M(m), N(n); - - Placeholder A("A", {M, M, N, N}); - - auto reduce_out = hlir::pe::TwoStepBlockReduceSum(A.tensor(), {0, 2, 3}, false); - CHECK_EQ(reduce_out.size(), 4) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaTwoStepReduceSchedule( - stages, reduce_out[3], reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_10) { - int m = 128; - int n = 128; - Expr M(m), N(n); - - 
Placeholder A("A", {M, N}); - Placeholder B("B", {M, N}); - - auto c = hlir::pe::Add(A.tensor(), B.tensor()); - auto reduce_out = hlir::pe::BlockShuffleReduceSum(c, {0}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, B, c, reduce_out[2], reduce_out[1], reduce_out[0]}); - - stages[c]->Split(0, 8); - stages[c]->Fuse(1, 2); - stages[c]->Reorder({1, 0}); - stages[c]->SetBuffer("local"); - stages[c]->SimpleComputeAt(stages[reduce_out[1]], 1); - stages[reduce_out[2]]->ComputeInline(); - stages[reduce_out[1]]->Bind(0, "threadIdx.x"); - stages[reduce_out[1]]->SetBuffer("shared"); - stages[reduce_out[0]]->Bind(0, "threadIdx.x"); - - auto func = Lower("fn", stages, {A, B, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_11) { - int m = 10; - int n = 10; - Expr M(m), N(n); - - Placeholder A("A", {M, N, N}); - auto reduce_out = hlir::pe::BlockShuffleReduceSum(A.tensor(), {0, 1}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaBlockShuffleReduceSchedule(stages, reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << 
"func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_12) { - int m = 10; - int n = 10; - Expr M(m), N(n); - - Placeholder A("A", {M, M, N, N}); - auto reduce_out = hlir::pe::BlockShuffleReduceSum(A.tensor(), {0, 1, 2}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaBlockShuffleReduceSchedule(stages, reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} - -TEST(Reduce, Reduce_Test_13) { - int m = 16; - int n = 16; - Expr M(m), N(n); - - Placeholder A("A", 
{M, M, N, N}); - auto reduce_out = hlir::pe::BlockShuffleReduceSum(A.tensor(), {0, 1, 2}, false); - CHECK_EQ(reduce_out.size(), 3) << "the output of reduce is not equal to 4!"; - auto stages = CreateStages({A, reduce_out[2], reduce_out[1], reduce_out[0]}); - - CudaBlockShuffleReduceSchedule(stages, reduce_out[2], reduce_out[1], reduce_out[0], common::DefaultNVGPUTarget()); - auto func = Lower("fn", stages, {A, reduce_out[0]}); - LOG(INFO) << "func:\n" << func; - - auto target = common::DefaultNVGPUTarget(); - Module::Builder builder("Concat_Builder", target); - builder.AddFunction(func); - - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); - - backends::CodeGenCUDA_Dev codegen(target); - auto source_code = codegen.Compile(builder.Build()); - LOG(INFO) << "compiled code:\n\n\n" << source_code; - - // nv jit compile to ptx - backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); - CHECK(!ptx.empty()); -} -#endif - } // namespace pe } // namespace hlir } // namespace cinn diff --git a/cinn/runtime/cuda/cuda_module.cc b/cinn/runtime/cuda/cuda_module.cc index 0ec5aa0bfe..5a90d0cccb 100644 --- a/cinn/runtime/cuda/cuda_module.cc +++ b/cinn/runtime/cuda/cuda_module.cc @@ -25,6 +25,7 @@ #include "cinn/backends/cuda_util.h" #include "cinn/runtime/cuda/cuda_util.h" +#include "cinn/runtime/flags.h" namespace cinn { namespace runtime { @@ -103,16 +104,11 @@ CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name) jit_options[4] = CU_JIT_GENERATE_LINE_INFO; jit_opt_vals[4] = reinterpret_cast(value); - CUresult status = cuModuleLoadDataEx( - &module_per_card_[device_id], data_.c_str(), jit_num_options, jit_options.data(), jit_opt_vals.data()); - - if (CUDA_SUCCESS != status) { - RAW_LOG(ERROR, "PTX JIT ERROR LOG: %s\n.", log_buffer.data()); - const char* 
name; - cuGetErrorName(status, &name); - const char* msg; - cuGetErrorString(status, &msg); - RAW_LOG(FATAL, "The error `%s` occurs while compiling the ptx! And its message is `%s`.", name, msg); + if (runtime::CanUseNvccCompiler()) { + CUDA_DRIVER_CALL(cuModuleLoad(&module_per_card_[device_id], data_.c_str())); + } else { + CUDA_DRIVER_CALL(cuModuleLoadDataEx( + &module_per_card_[device_id], data_.c_str(), jit_num_options, jit_options.data(), jit_opt_vals.data())); } } @@ -124,11 +120,15 @@ CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name) CUdeviceptr CUDAModule::GetGlobal(int device_id, const std::string& name, size_t nbytes) { if (!module_per_card_[device_id]) { std::lock_guard lock(mutex_); - CUDA_DRIVER_CALL(cuModuleLoadData(&module_per_card_[device_id], data_.c_str())); + if (runtime::CanUseNvccCompiler()) { + CUDA_DRIVER_CALL(cuModuleLoad(&module_per_card_[device_id], data_.c_str())); + } else { + CUDA_DRIVER_CALL(cuModuleLoadData(&module_per_card_[device_id], data_.c_str())); + } } - CUdeviceptr global; size_t _nbytes; + CUdeviceptr global; CUDA_DRIVER_CALL(cuModuleGetGlobal(&global, &_nbytes, module_per_card_[device_id], name.c_str())); return global; } diff --git a/cinn/runtime/flags.cc b/cinn/runtime/flags.cc index 5b6b523ecc..9ce3f0d3af 100644 --- a/cinn/runtime/flags.cc +++ b/cinn/runtime/flags.cc @@ -16,6 +16,9 @@ #include #include +#include +#include +#include #include @@ -35,6 +38,9 @@ using ::GFLAGS_NAMESPACE::Int64FromEnv; using ::GFLAGS_NAMESPACE::StringFromEnv; DEFINE_string(cinn_x86_builtin_code_root, StringFromEnv("FLAGS_cinn_x86_builtin_code_root", ""), ""); +DEFINE_string(cinn_nvcc_cmd_path, + StringFromEnv("FLAGS_cinn_nvcc_cmd_path", "/usr/local/cuda/bin"), + "Setting nvcc default path!"); DEFINE_int32(cinn_parallel_compile_size, Int32FromEnv("FLAGS_cinn_parallel_compile_size", 16), @@ -82,9 +88,13 @@ DEFINE_bool(cinn_use_dense_merge_pass, "Whether use dense merge pass."); DEFINE_bool(nvrtc_compile_to_cubin, - 
BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", false), + BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", true), "Whether nvrtc compile cuda source into cubin instead of ptx (only works after cuda-11.1)."); +DEFINE_bool(cinn_compile_with_nvrtc, + BoolFromEnv("FLAGS_cinn_compile_with_nvrtc", true), + "Whether nvrtc compile cuda source with nvrtc(default nvcc)."); + // FLAGS for performance analysis and accuracy debug DEFINE_bool(cinn_sync_run, BoolFromEnv("FLAGS_cinn_sync_run", false), @@ -175,6 +185,11 @@ unsigned long long RandomSeed::Clear() { return old_seed; } +bool CanUseNvccCompiler() { + std::string nvcc_dir = FLAGS_cinn_nvcc_cmd_path + "/nvcc"; + return (access(nvcc_dir.c_str(), 0) == -1 ? false : true) && (!FLAGS_cinn_compile_with_nvrtc); +} + bool IsCompiledWithCUDA() { #if !defined(CINN_WITH_CUDA) return false; diff --git a/cinn/runtime/flags.h b/cinn/runtime/flags.h index 4b4f19f322..6a663d12af 100644 --- a/cinn/runtime/flags.h +++ b/cinn/runtime/flags.h @@ -27,6 +27,8 @@ bool CheckStringFlagFalse(const std::string &flag); void SetCinnCudnnDeterministic(bool state); bool GetCinnCudnnDeterministic(); +bool CanUseNvccCompiler(); + class RandomSeed { public: static unsigned long long GetOrSet(unsigned long long seed = 0);