diff --git a/cinn/optim/transform_gpu_forloop.cc b/cinn/optim/transform_gpu_forloop.cc index 47ba3b50ee..b79ae313cf 100644 --- a/cinn/optim/transform_gpu_forloop.cc +++ b/cinn/optim/transform_gpu_forloop.cc @@ -170,23 +170,31 @@ void CudaSyncThreadsDropIfThenElse(Expr *expr) { Mutator()(expr); } +class RestructureVarNodes : public ir::IRMutator<> { + public: + void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } + + private: + void Visit(const ir::_Var_ *var, Expr *op) override { *op = IRCopy(*op); } +}; + class ReplaceIndexToBindExpr : public ir::IRMutator<> { public: void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } private: void Visit(const ir::ScheduleBlockRealize *op, Expr *expr) override { - auto *schedule_block_realize = expr->As(); + ir::ScheduleBlockRealize *schedule_block_realize = expr->As(); CHECK(schedule_block_realize->schedule_block.As()); - auto iter_values = schedule_block_realize->iter_values; - auto body_copy = schedule_block_realize->schedule_block.As()->body; - auto iter_vars = schedule_block_realize->schedule_block.As()->iter_vars; + std::vector iter_values = schedule_block_realize->iter_values; + ir::Expr body = schedule_block_realize->schedule_block.As()->body; + std::vector iter_vars = schedule_block_realize->schedule_block.As()->iter_vars; CHECK_EQ(iter_values.size(), iter_vars.size()); for (int idx = 0; idx < iter_values.size(); ++idx) { - ReplaceVarWithExpr(&body_copy, iter_vars[idx], iter_values[idx]); + ReplaceVarWithExpr(&body, iter_vars[idx], iter_values[idx]); } - ir::IRMutator<>::Visit(&body_copy, &body_copy); + ir::IRMutator<>::Visit(&body, &body); } }; @@ -594,6 +602,11 @@ class ReplaceVarToZero : public ir::IRMutator<> { void OptimizeExprGPU(Expr *expr) { VLOG(2) << "Before Optimize Expr:\n" << *expr; + + // copy var nodes to prevent one modification leading to multiple changes + RestructureVarNodes restructure_var_nodes; + restructure_var_nodes(expr); + // replace var to bind expr ReplaceIndexToBindExpr replace_index_to_bind_expr; replace_index_to_bind_expr(expr);