diff --git a/cinn/ir/CMakeLists.txt b/cinn/ir/CMakeLists.txt index b4618cbd0f..340365b7c7 100755 --- a/cinn/ir/CMakeLists.txt +++ b/cinn/ir/CMakeLists.txt @@ -7,6 +7,7 @@ gather_srcs(cinnapi_src SRCS ir_base.cc ir_schedule.cc ir_schedule_util.cc + ir_schedule_error.cc ir_visitor.cc ir_printer.cc ir_mutator.cc diff --git a/cinn/ir/ir_schedule.cc b/cinn/ir/ir_schedule.cc index eb2d934e0f..3664092061 100644 --- a/cinn/ir/ir_schedule.cc +++ b/cinn/ir/ir_schedule.cc @@ -50,8 +50,10 @@ namespace ir { class ScheduleImpl { public: ScheduleImpl() = default; - explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false) - : module_expr_(module_expr), debug_flag_(debug_flag) {} + explicit ScheduleImpl(const ModuleExpr& module_expr, + bool debug_flag = false, + ScheduleErrorMessageLevel err_msg_level = ScheduleErrorMessageLevel::kBlank) + : module_expr_(module_expr), debug_flag_(debug_flag), err_msg_level_(err_msg_level) {} explicit ScheduleImpl(ModuleExpr&& module_expr) : module_expr_(std::move(module_expr)) {} //! Set the debug flag. @@ -114,8 +116,32 @@ class ScheduleImpl { ModuleExpr module_expr_; bool debug_flag_{false}; + ScheduleErrorMessageLevel err_msg_level_; }; +/** \brief A macro that guards the beginning of each implementation of schedule */ +#define CINN_IR_SCHEDULE_BEGIN() try { +/** + * \brief A macro that pairs with `CINN_IR_SCHEDULE_BEGIN`, handling potential errors and error + * message printing + * \param primitive A string representing the kind of schedule primitive + * \param err_msg_level A ScheduleErrorMessageLevel enum, level of error message printing + */ +#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \ + } \ + catch (const IRScheduleErrorHandler& err_hanlder) { \ + switch (err_msg_level) { \ + case ScheduleErrorMessageLevel::kDetailed: \ + throw std::runtime_error(err_hanlder.FormatErrorMessage(primitive)); \ + case ScheduleErrorMessageLevel::kGenearl: \ + throw std::runtime_error(err_hanlder.GeneralErrorMessage()); \ + case ScheduleErrorMessageLevel::kBlank: \ + throw std::runtime_error("IRScheduleError occurred! (No more error message)"); \ + default: \ + throw std::runtime_error("IRScheduleError occurred! (No more error message)"); \ + } \ + } + std::vector ScheduleImpl::Split(const Expr& loop, const std::vector& factors) { CHECK(loop.As()) << "Expr param of Split must be For node! Please check."; auto* for_node = loop.As(); @@ -126,8 +152,10 @@ std::vector ScheduleImpl::Split(const Expr& loop, const std::vector& VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " << tot_extent << ") to (" << cinn::utils::Join(factors, ", ") << ") at loop:\n" << loop; - - auto processed_factors = ValidateFactors(factors, tot_extent); + std::vector processed_factors; + CINN_IR_SCHEDULE_BEGIN(); + processed_factors = ValidateFactors(factors, tot_extent); + CINN_IR_SCHEDULE_END("split", this->err_msg_level_); int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, std::multiplies()); std::vector new_loop_vars; Expr substitute_value(0); @@ -1971,8 +1999,11 @@ Expr ScheduleImpl::SampleCategorical(utils::LinearRandomEngine::StateType* rand_ IRSchedule::IRSchedule() {} -IRSchedule::IRSchedule(const ModuleExpr& module_expr, utils::LinearRandomEngine::StateType rand_seed, bool debug_flag) { - impl_ = std::make_unique(module_expr, debug_flag); +IRSchedule::IRSchedule(const ModuleExpr& module_expr, + utils::LinearRandomEngine::StateType rand_seed, + bool debug_flag, + ScheduleErrorMessageLevel err_msg_level) { + impl_ = std::make_unique(module_expr, debug_flag, err_msg_level); this->InitSeed(rand_seed); } diff --git a/cinn/ir/ir_schedule.h b/cinn/ir/ir_schedule.h index 6b7b252a57..2bcc41ebcb 100644 --- a/cinn/ir/ir_schedule.h +++ b/cinn/ir/ir_schedule.h @@ -22,6 +22,7 @@ #include "cinn/ir/ir.h" #include "cinn/ir/ir_base.h" #include "cinn/ir/ir_mutator.h" +#include "cinn/ir/ir_schedule_error.h" #include "cinn/ir/schedule_desc.h" #include "cinn/ir/tensor.h" #include "cinn/utils/random_engine.h" @@ -67,7 +68,8 @@ class IRSchedule { IRSchedule(); explicit IRSchedule(const ModuleExpr& modexpr, utils::LinearRandomEngine::StateType rand_seed = -1, - bool debug_flag = false); + bool debug_flag = false, + ScheduleErrorMessageLevel err_msg_level = ScheduleErrorMessageLevel::kBlank); IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed = -1); IRSchedule(const IRSchedule& other); IRSchedule& operator=(const IRSchedule& src); diff --git a/cinn/ir/ir_schedule_error.cc b/cinn/ir/ir_schedule_error.cc new file mode 100644 index 0000000000..c54433e17b --- /dev/null +++ b/cinn/ir/ir_schedule_error.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_schedule_error.h" + +namespace cinn { +namespace ir { + +std::string IRScheduleErrorHandler::FormatErrorMessage(const std::string &primitive) const { + std::ostringstream os; + std::string err_msg = DetailedErrorMessage(); + + os << "[IRScheduleError] An error occurred in the scheduel primitive <" << primitive << ">. " << std::endl; + os << "Error info: " << err_msg; + return os.str(); +} + +} // namespace ir +} // namespace cinn diff --git a/cinn/ir/ir_schedule_error.h b/cinn/ir/ir_schedule_error.h new file mode 100644 index 0000000000..1cdc29efc0 --- /dev/null +++ b/cinn/ir/ir_schedule_error.h @@ -0,0 +1,67 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace cinn { +namespace ir { + +/** + * \brief Indicates the level of printing error message in the current Schedule + */ +enum class ScheduleErrorMessageLevel : int32_t { + /** \brief No error message*/ + kBlank = 0, + /** \brief Print an error message in short mode*/ + kGenearl = 1, + /** \brief Print an error message in detailed mode*/ + kDetailed = 2, +}; + +/** + * This handler is dealing with the errors happen in in the current Scheduling. + */ +class IRScheduleErrorHandler : public std::runtime_error { + public: + IRScheduleErrorHandler() : std::runtime_error("") {} + /** + * \brief constructor + * \param s the error message + */ + explicit IRScheduleErrorHandler(const std::string &s) : std::runtime_error(s) {} + + /** + * \brief Returns a detailed error message corresponding to the kDetailed error level. + */ + std::string FormatErrorMessage(const std::string &primitive) const; + + /** + * \brief Returns a short error message corresponding to the kGeneral error level. + */ + virtual std::string GeneralErrorMessage() const = 0; + + /** + * \brief Returns a detailed error message corresponding to the kDetailed error level. + */ + virtual std::string DetailedErrorMessage() const = 0; +}; + +} // namespace ir +} // namespace cinn diff --git a/cinn/ir/ir_schedule_util.cc b/cinn/ir/ir_schedule_util.cc index 054e05dee0..af1b1adf01 100644 --- a/cinn/ir/ir_schedule_util.cc +++ b/cinn/ir/ir_schedule_util.cc @@ -29,6 +29,7 @@ #include "cinn/ir/ir.h" #include "cinn/ir/ir_operators.h" #include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule_error.h" #include "cinn/ir/ir_visitor.h" #include "cinn/lang/compute.h" #include "cinn/optim/ir_copy.h" @@ -196,14 +197,66 @@ void ReplaceExpr(Expr* source, const std::vector& replaced, const std::vect } std::vector ValidateFactors(const std::vector& factors, int total_extent) { + class NegativeFactorErrorHandler : public IRScheduleErrorHandler { + public: + explicit NegativeFactorErrorHandler(int64_t factor, size_t idx) : factor_(factor), idx_(idx) {} + + std::string GeneralErrorMessage() const final { + return "[IRScheduleError]: The params in factors of Split should be positive. However, some " + "factor is zero or negative."; + } + + std::string DetailedErrorMessage() const final { + std::ostringstream os; + os << "The params in factors of Split should be positive. However, the factor at position " << idx_ << " is " + << factor_; + return os.str(); + } + + private: + int64_t factor_; + size_t idx_; + }; + + class InferFactorErrorHandler : public IRScheduleErrorHandler { + public: + std::string GeneralErrorMessage() const final { + return "[IRScheduleError]: The params in factors of Split should not be less than -1 or have more than one -1!"; + } + + std::string DetailedErrorMessage() const final { + std::ostringstream os; + os << "The params in factors of Split should not be less than -1 or have more than one -1!"; + return os.str(); + } + }; + + class FactorProductErrorHandler : public IRScheduleErrorHandler { + public: + std::string GeneralErrorMessage() const final { + return "[IRScheduleError]: In Split, the factors' product should be not larger than or equal to original loop's " + "extent!"; + } + + std::string DetailedErrorMessage() const final { + std::ostringstream os; + os << "In Split, the factors' product should be not larger than or equal to original loop's extent!"; + return os.str(); + } + }; + CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check."; bool has_minus_one = false; int product = 1; + int idx = -1; for (auto& i : factors) { - CHECK(i != 0) << "The params in factors of Split should not be 0! Please check."; - CHECK(i >= -1) << "The params in factors of Split should not be less than -1! Please check."; - if (i == -1) { - CHECK(!has_minus_one) << "The params in factors of Split should not have more than one -1! Please check."; + idx++; + if (i == 0 || i < -1) { + throw NegativeFactorErrorHandler(i, idx); + } else if (i == -1) { + if (has_minus_one) { + throw InferFactorErrorHandler(); + } has_minus_one = true; } else { product *= i; @@ -211,12 +264,14 @@ std::vector ValidateFactors(const std::vector& factors, int total_exte } std::vector validated_factors = factors; if (!has_minus_one) { - CHECK_GE(product, total_extent) - << "In Split, the factors' product should be equal to original loop's extent! Please check."; + if (product < total_extent) { + throw FactorProductErrorHandler(); + } return validated_factors; } else { - CHECK_LE(product, total_extent) << "In Split, when there is -1 in factors, the other factors' product should be <= " - "original loop's extent! Please check."; + if (product > total_extent) { + throw FactorProductErrorHandler(); + } int minus_one_candidate = (int)ceil((double)total_extent / (double)product); for (int i = 0; i < validated_factors.size(); ++i) { if (validated_factors[i] == -1) {