From ec850f1114e27131596ac6dc862fd78e84dd3bde Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Mon, 22 May 2023 20:53:32 +0800 Subject: [PATCH 1/3] fix(ir): fix copy error of var in ir --- cinn/optim/transform_gpu_forloop.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cinn/optim/transform_gpu_forloop.cc b/cinn/optim/transform_gpu_forloop.cc index 47ba3b50ee..57c19516de 100644 --- a/cinn/optim/transform_gpu_forloop.cc +++ b/cinn/optim/transform_gpu_forloop.cc @@ -179,7 +179,8 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> { auto *schedule_block_realize = expr->As(); CHECK(schedule_block_realize->schedule_block.As()); auto iter_values = schedule_block_realize->iter_values; - auto body_copy = schedule_block_realize->schedule_block.As()->body; + auto body = schedule_block_realize->schedule_block.As()->body; + auto body_copy = IRCopy(body); auto iter_vars = schedule_block_realize->schedule_block.As()->iter_vars; CHECK_EQ(iter_values.size(), iter_vars.size()); @@ -187,6 +188,7 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> { ReplaceVarWithExpr(&body_copy, iter_vars[idx], iter_values[idx]); } ir::IRMutator<>::Visit(&body_copy, &body_copy); + schedule_block_realize->schedule_block.As()->body = body_copy; } }; From 06dc71b3969ed55d4e0df80fc6e7cc720051a618 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 23 May 2023 14:11:46 +0800 Subject: [PATCH 2/3] replace type deduction with explicit type --- cinn/optim/transform_gpu_forloop.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cinn/optim/transform_gpu_forloop.cc b/cinn/optim/transform_gpu_forloop.cc index 57c19516de..3b8e056381 100644 --- a/cinn/optim/transform_gpu_forloop.cc +++ b/cinn/optim/transform_gpu_forloop.cc @@ -176,12 +176,12 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> { private: void Visit(const ir::ScheduleBlockRealize *op, Expr *expr) override { - auto *schedule_block_realize = expr->As(); + ir::ScheduleBlockRealize *schedule_block_realize = expr->As(); CHECK(schedule_block_realize->schedule_block.As()); - auto iter_values = schedule_block_realize->iter_values; - auto body = schedule_block_realize->schedule_block.As()->body; - auto body_copy = IRCopy(body); - auto iter_vars = schedule_block_realize->schedule_block.As()->iter_vars; + std::vector iter_values = schedule_block_realize->iter_values; + ir::Expr body = schedule_block_realize->schedule_block.As()->body; + ir::Expr body_copy = IRCopy(body); + std::vector iter_vars = schedule_block_realize->schedule_block.As()->iter_vars; CHECK_EQ(iter_values.size(), iter_vars.size()); for (int idx = 0; idx < iter_values.size(); ++idx) { From e2385a68e0f70b51f516fda4a0c501254e6ed4ab Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Thu, 25 May 2023 10:59:59 +0800 Subject: [PATCH 3/3] fix ir copy of var --- cinn/optim/transform_gpu_forloop.cc | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/cinn/optim/transform_gpu_forloop.cc b/cinn/optim/transform_gpu_forloop.cc index 3b8e056381..b79ae313cf 100644 --- a/cinn/optim/transform_gpu_forloop.cc +++ b/cinn/optim/transform_gpu_forloop.cc @@ -170,6 +170,14 @@ void CudaSyncThreadsDropIfThenElse(Expr *expr) { Mutator()(expr); } +class RestructureVarNodes : public ir::IRMutator<> { + public: + void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } + + private: + void Visit(const ir::_Var_ *var, Expr *op) override { *op = IRCopy(*op); } +}; + class ReplaceIndexToBindExpr : public ir::IRMutator<> { public: void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } @@ -180,15 +188,13 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> { CHECK(schedule_block_realize->schedule_block.As()); std::vector iter_values = schedule_block_realize->iter_values; ir::Expr body = schedule_block_realize->schedule_block.As()->body; - ir::Expr body_copy = IRCopy(body); std::vector iter_vars = schedule_block_realize->schedule_block.As()->iter_vars; CHECK_EQ(iter_values.size(), iter_vars.size()); for (int idx = 0; idx < iter_values.size(); ++idx) { - ReplaceVarWithExpr(&body_copy, iter_vars[idx], iter_values[idx]); + ReplaceVarWithExpr(&body, iter_vars[idx], iter_values[idx]); } - ir::IRMutator<>::Visit(&body_copy, &body_copy); - schedule_block_realize->schedule_block.As()->body = body_copy; + ir::IRMutator<>::Visit(&body, &body); } }; @@ -596,6 +602,11 @@ class ReplaceVarToZero : public ir::IRMutator<> { void OptimizeExprGPU(Expr *expr) { VLOG(2) << "Before Optimize Expr:\n" << *expr; + + // copy var nodes to prevent one modification leading to multiple changes + RestructureVarNodes restructure_var_nodes; + restructure_var_nodes(expr); + // replace var to bind expr ReplaceIndexToBindExpr replace_index_to_bind_expr; replace_index_to_bind_expr(expr);