From 399cd7ff8fb1c661ebb5058008f1ece32499f137 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 8 May 2023 06:14:09 +0000 Subject: [PATCH 1/3] format --- cinn/frontend/op_mappers/paddle/matmul.cc | 69 +++++++++++++++++++++++ python/tests/op_mappers/test_matmul_op.py | 51 +++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 python/tests/op_mappers/test_matmul_op.py diff --git a/cinn/frontend/op_mappers/paddle/matmul.cc b/cinn/frontend/op_mappers/paddle/matmul.cc index 7db1c86fea..be5e735a85 100644 --- a/cinn/frontend/op_mappers/paddle/matmul.cc +++ b/cinn/frontend/op_mappers/paddle/matmul.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "cinn/frontend/op_mapper_registry.h" #include "cinn/frontend/op_mappers/common_utils.h" @@ -46,6 +48,72 @@ void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c ctx.AddVarModelToProgram(out_name, out->id); } +void MatMulGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { + // get dy + CHECK_EQ(op_desc.Input(paddle::GradVarName("Out")).size(), 1UL); + auto dout_name = op_desc.Input(paddle::GradVarName("Out")).front(); + + // get intput X and Y + CHECK_EQ(op_desc.Input("X").size(), 1UL); + auto x_name = op_desc.Input("X").front(); + CHECK_EQ(op_desc.Input("Y").size(), 1UL); + auto y_name = op_desc.Input("Y").front(); + + // get d_x + std::string dx_name, dy_name; + bool has_dx = !op_desc.Output(paddle::GradVarName("X")).empty(); + bool has_dy = !op_desc.Output(paddle::GradVarName("Y")).empty(); + if (has_dx) { + CHECK_EQ(op_desc.Output(paddle::GradVarName("X")).size(), 1UL); + dx_name = op_desc.Output(paddle::GradVarName("X")).front(); + } + if (has_dy) { + CHECK_EQ(op_desc.Output(paddle::GradVarName("Y")).size(), 1UL); + dy_name = op_desc.Output(paddle::GradVarName("Y")).front(); + } + + // get attr + auto trans_x = utils::GetAttrOrDefault(op_desc, "trans_x", false); + trans_x = utils::GetAttrOrDefault(op_desc, "transpose_X", trans_x); + + auto trans_y = utils::GetAttrOrDefault(op_desc, "trans_y", false); + trans_y = utils::GetAttrOrDefault(op_desc, "transpose_Y", trans_y); + + auto alpha = utils::GetAttrOrDefault(op_desc, "alpha", 1.0f); + + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); + auto dout = ctx.GetVar(dout_name); + if (has_dx) { + absl::optional dx; + if (trans_x && trans_y) { + dx = ctx.Builder()->Matmul(y, dout, true, true, alpha); + } else if (trans_x) { + dx = ctx.Builder()->Matmul(y, dout, false, true, alpha); + } else if (trans_y) { + dx = ctx.Builder()->Matmul(dout, y, false, false, alpha); + } else { + dx = ctx.Builder()->Matmul(dout, y, false, true, alpha); + } + ctx.AddVar(dx_name, dx.value()); + ctx.AddVarModelToProgram(dx_name, dx.value()->id); + } + if (has_dy) { + absl::optional dy; + if (trans_x && trans_y) { + dy = ctx.Builder()->Matmul(dout, x, true, true, alpha); + } else if (trans_x) { + dy = ctx.Builder()->Matmul(x, dout, false, false, alpha); + } else if (trans_y) { + dy = ctx.Builder()->Matmul(dout, x, true, false, alpha); + } else { + dy = ctx.Builder()->Matmul(x, dout, true, false, alpha); + } + ctx.AddVar(dy_name, dy.value()); + ctx.AddVarModelToProgram(dy_name, dy.value()->id); + } +} + } // namespace paddle_mappers } // namespace frontend } // namespace cinn @@ -53,5 +121,6 @@ void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c CINN_REGISTER_HELPER(paddle_matmul) { CINN_REGISTER_OP_MAPPER(matmul, cinn::frontend::paddle_mappers::MatMulOpMapper) CINN_REGISTER_OP_MAPPER(matmul_v2, cinn::frontend::paddle_mappers::MatMulOpMapper) + CINN_REGISTER_OP_MAPPER(matmul_v2_grad, cinn::frontend::paddle_mappers::MatMulGradOpMapper) return true; } diff --git a/python/tests/op_mappers/test_matmul_op.py b/python/tests/op_mappers/test_matmul_op.py new file mode 100644 index 0000000000..777b552c84 --- /dev/null +++ b/python/tests/op_mappers/test_matmul_op.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_mapper_test import OpMapperTest, logger +import paddle + + +class TestMatmulOp(OpMapperTest): + def init_input_data(self): + self.feed_data = { + "x": self.random([16, 32], "float32"), + "y": self.random([32, 16], "float32") + } + + def set_op_type(self): + return "matmul" + + def set_op_inputs(self): + x = paddle.static.data('X', self.feed_data["x"].shape, + self.feed_data["x"].dtype) + x = paddle.static.data('Y', self.feed_data["y"].shape, + self.feed_data["Y"].dtype) + return {'X': [x], 'Y': [y]} + + def set_op_attrs(self): + return {"trans_x": False, "trans_y": False} + + def set_op_outputs(self): + return {'Out': [str(self.feed_data['x'].dtype)]} + + def test_check_results(self): + self.check_outputs_and_grads() + + +if __name__ == "__main__": + unittest.main() From a15fb82a0aef80a1f6b4c5e6f8d6bef62cb5a95b Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 9 May 2023 06:43:57 +0000 Subject: [PATCH 2/3] fix test --- python/tests/op_mappers/test_matmul_op.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tests/op_mappers/test_matmul_op.py b/python/tests/op_mappers/test_matmul_op.py index 777b552c84..78200fee0a 100644 --- a/python/tests/op_mappers/test_matmul_op.py +++ b/python/tests/op_mappers/test_matmul_op.py @@ -23,17 +23,17 @@ class TestMatmulOp(OpMapperTest): def init_input_data(self): self.feed_data = { - "x": self.random([16, 32], "float32"), - "y": self.random([32, 16], "float32") + "X": self.random([16, 32], "float32"), + "Y": self.random([32, 16], "float32") } def set_op_type(self): return "matmul" def set_op_inputs(self): - x = paddle.static.data('X', self.feed_data["x"].shape, - self.feed_data["x"].dtype) - x = paddle.static.data('Y', self.feed_data["y"].shape, + x = paddle.static.data('X', self.feed_data["X"].shape, + self.feed_data["X"].dtype) + x = paddle.static.data('Y', self.feed_data["Y"].shape, self.feed_data["Y"].dtype) return {'X': [x], 'Y': [y]} @@ -41,7 +41,7 @@ def set_op_attrs(self): return {"trans_x": False, "trans_y": False} def set_op_outputs(self): - return {'Out': [str(self.feed_data['x'].dtype)]} + return {'Out': [str(self.feed_data['X'].dtype)]} def test_check_results(self): self.check_outputs_and_grads() From 23211b8cbafd5c3c1e48cc5c6f839adf0f58e700 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 11 May 2023 09:02:38 +0000 Subject: [PATCH 3/3] fix test --- python/tests/op_mappers/test_matmul_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tests/op_mappers/test_matmul_op.py b/python/tests/op_mappers/test_matmul_op.py index 78200fee0a..495baed975 100644 --- a/python/tests/op_mappers/test_matmul_op.py +++ b/python/tests/op_mappers/test_matmul_op.py @@ -31,11 +31,11 @@ def set_op_type(self): return "matmul" def set_op_inputs(self): - x = paddle.static.data('X', self.feed_data["X"].shape, + X = paddle.static.data('X', self.feed_data["X"].shape, self.feed_data["X"].dtype) - x = paddle.static.data('Y', self.feed_data["Y"].shape, + Y = paddle.static.data('Y', self.feed_data["Y"].shape, self.feed_data["Y"].dtype) - return {'X': [x], 'Y': [y]} + return {'X': [X], 'Y': [Y]} def set_op_attrs(self): return {"trans_x": False, "trans_y": False}