From 4acc14799eb32f359c2146aea1267e531438b62b Mon Sep 17 00:00:00 2001 From: sjjian <921465802@qq.com> Date: Thu, 20 Jul 2023 06:37:57 +0000 Subject: [PATCH 1/3] =?UTF-8?q?code=20refactor,=20=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E5=8A=A8=E6=80=81=E6=B3=A8=E5=85=A5=E7=9A=84=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ast/context.go | 11 ++++++++++- ast/mapper.go | 4 ++-- config.go | 9 +++++++++ parser.go | 15 +++++++++++---- parser_test.go | 10 ++++++++-- 5 files changed, 40 insertions(+), 9 deletions(-) create mode 100644 config.go diff --git a/ast/context.go b/ast/context.go index 7c98479..86f9339 100644 --- a/ast/context.go +++ b/ast/context.go @@ -4,12 +4,14 @@ type Context struct { QueryType string // select, insert, update, delete Variable map[string]string Sqls map[string]*SqlNode + Config *Config } -func NewContext() *Context { +func NewContext(config *Config) *Context { return &Context{ Variable: map[string]string{}, Sqls: map[string]*SqlNode{}, + Config: config, } } @@ -26,3 +28,10 @@ func (c *Context) GetSql(k string) (*SqlNode, bool) { sql, ok := c.Sqls[k] return sql, ok } + +type Config struct { + SkipErrorQuery bool + WithQueryId bool +} + +type ConfigFn func() func(*Config) diff --git a/ast/mapper.go b/ast/mapper.go index 2254625..4dc8151 100644 --- a/ast/mapper.go +++ b/ast/mapper.go @@ -68,7 +68,7 @@ func (m *Mapper) GetStmt(ctx *Context) (string, error) { return strings.TrimSuffix(buff.String(), "\n"), nil } -func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) { +func (m *Mapper) GetStmts(ctx *Context) ([]string, error) { var stmts []string ctx.Sqls = m.SqlNodes for _, a := range m.QueryNodes { @@ -77,7 +77,7 @@ func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) { stmts = append(stmts, data) continue } - if skipErrorQuery { + if ctx.Config.SkipErrorQuery { continue } return nil, err diff --git a/config.go b/config.go new file mode 100644 index 0000000..10c3d30 --- /dev/null +++ b/config.go @@ -0,0 +1,9 @@ +package parser + +import "github.com/actiontech/mybatis-mapper-2-sql/ast" + +func SkipErrorQuery() func(*ast.Config) { + return func(c *ast.Config) { + c.SkipErrorQuery = true + } +} diff --git a/parser.go b/parser.go index f424043..dad656c 100644 --- a/parser.go +++ b/parser.go @@ -20,7 +20,7 @@ func ParseXML(data string) (string, error) { if n == nil { return "", nil } - stmt, err := n.GetStmt(ast.NewContext()) + stmt, err := n.GetStmt(ast.NewContext(&ast.Config{})) if err != nil { return "", err } @@ -28,8 +28,9 @@ func ParseXML(data string) (string, error) { } // ParseXMLQuery is a parser for parse all query in XML to []string one by one; -// you can set `skipErrorQuery` true to ignore invalid query. -func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) { +// ConfigFn: +// `SkipErrorQuery` to ignore invalid query. +func ParseXMLQuery(data string, configFns ...ast.ConfigFn) ([]string, error) { r := strings.NewReader(data) d := xml.NewDecoder(r) n, err := parse(d) @@ -43,7 +44,13 @@ func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) { if !ok { return nil, fmt.Errorf("the mapper is not found") } - stmts, err := m.GetStmts(ast.NewContext(), skipErrorQuery) + + config := &ast.Config{} + for _, configFn := range configFns { + configFn()(config) + } + + stmts, err := m.GetStmts(ast.NewContext(config)) if err != nil { return nil, err } diff --git a/parser_test.go b/parser_test.go index 89ce67b..d3961b4 100644 --- a/parser_test.go +++ b/parser_test.go @@ -2,6 +2,8 @@ package parser import ( "testing" + + "github.com/actiontech/mybatis-mapper-2-sql/ast" ) func testParser(t *testing.T, xmlData, expect string) { @@ -574,7 +576,11 @@ func TestParserSQLRefIdNotFound(t *testing.T) { } func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) { - actual, err := ParseXMLQuery(xmlData, skipError) + configFns := []ast.ConfigFn{} + if skipError { + configFns = append(configFns, SkipErrorQuery) + } + actual, err := ParseXMLQuery(xmlData, configFns...) if err != nil { t.Errorf("parse error: %v", err) return @@ -796,7 +802,7 @@ func TestParserQueryHasInvalidQuery(t *testing.T) { from t -`, false) +`) if err == nil { t.Errorf("expect has error, but no error") } From 804a2b3527b6924bb583efc069b0d3e15942381b Mon Sep 17 00:00:00 2001 From: sjjian <921465802@qq.com> Date: Thu, 20 Jul 2023 07:06:51 +0000 Subject: [PATCH 2/3] =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E8=AF=AD=E5=8F=A5=E7=9A=84ID,=20=E9=80=9A=E8=BF=87SQL=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E7=9A=84=E5=BD=A2=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ast/query.go | 12 +++++++++++- config.go | 6 ++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ast/query.go b/ast/query.go index 7482d14..294fa1b 100644 --- a/ast/query.go +++ b/ast/query.go @@ -3,6 +3,7 @@ package ast import ( "bytes" "encoding/xml" + "fmt" "github.com/actiontech/mybatis-mapper-2-sql/sqlfmt" ) @@ -32,6 +33,7 @@ func (s *QueryNode) Scan(start *xml.StartElement) error { func (s *QueryNode) GetStmt(ctx *Context) (string, error) { buff := bytes.Buffer{} ctx.QueryType = s.Type + for _, a := range s.Children { data, err := a.GetStmt(ctx) if err != nil { @@ -39,5 +41,13 @@ func (s *QueryNode) GetStmt(ctx *Context) (string, error) { } buff.WriteString(data) } - return sqlfmt.FormatSQL(buff.String()), nil + fmtSQL := sqlfmt.FormatSQL(buff.String()) + if ctx.Config.WithQueryId { + buff.Reset() + buff.WriteString(fmt.Sprintf("/* id: %s */\n", s.Id)) + buff.WriteString(fmtSQL) + return buff.String(), nil + } else { + return fmtSQL, nil + } } diff --git a/config.go b/config.go index 10c3d30..9711f99 100644 --- a/config.go +++ b/config.go @@ -7,3 +7,9 @@ func SkipErrorQuery() func(*ast.Config) { c.SkipErrorQuery = true } } + +func WithQueryId() func(*ast.Config) { + return func(c *ast.Config) { + c.WithQueryId = true + } +} From a16f58f7b4183ff7582fb83a7ef85991241afc05 Mon Sep 17 00:00:00 2001 From: sjjian <921465802@qq.com> Date: Thu, 20 Jul 2023 07:07:40 +0000 Subject: [PATCH 3/3] =?UTF-8?q?=E4=B8=BA=E6=94=AF=E6=8C=81=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E8=AF=AD=E5=8F=A5ID=E7=9A=84=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- parser_ibatis_test.go | 11 +++-- parser_test.go | 112 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 108 insertions(+), 15 deletions(-) diff --git a/parser_ibatis_test.go b/parser_ibatis_test.go index 30cf553..9b371d0 100644 --- a/parser_ibatis_test.go +++ b/parser_ibatis_test.go @@ -52,7 +52,7 @@ id = #id# } func TestParseIBatisInclude(t *testing.T) { - testParserQuery(t, false, ` + testParserQuery(t, ` @@ -76,7 +76,7 @@ SELECT id, name "SELECT `id`,`name` FROM `items` WHERE `parentid`=6", }) - testParserQuery(t, false, ` + testParserQuery(t, ` @@ -102,7 +102,7 @@ SELECT id, name } func TestParseIBatisAll(t *testing.T) { - testParserQuery(t, true, ` + testParserQuery(t, ` @@ -195,5 +195,6 @@ func TestParseIBatisAll(t *testing.T) { "SELECT * FROM `EMPLOYEE` WHERE (`username`=? OR `username`=?) AND `id` IS NULL AND `id`=?", "SELECT * FROM `EMPLOYEE` WHERE `ACC_FIRST_NAME`=? OR `ACC_LAST_NAME`=? AND `ACC_EMAIL` LIKE ? AND `ACC_ID`=? ORDER BY `ACC_LAST_NAME`", "SELECT * FROM `EMPLOYEE` ORDER BY ?", - }) -} \ No newline at end of file + }, + SkipErrorQuery) +} diff --git a/parser_test.go b/parser_test.go index d3961b4..32eb5c9 100644 --- a/parser_test.go +++ b/parser_test.go @@ -575,11 +575,7 @@ func TestParserSQLRefIdNotFound(t *testing.T) { } } -func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) { - configFns := []ast.ConfigFn{} - if skipError { - configFns = append(configFns, SkipErrorQuery) - } +func testParserQuery(t *testing.T, xmlData string, expect []string, configFns ...ast.ConfigFn) { actual, err := ParseXMLQuery(xmlData, configFns...) if err != nil { t.Errorf("parse error: %v", err) @@ -599,7 +595,7 @@ func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []stri } func TestParserQueryFullFile(t *testing.T) { - testParserQuery(t, false, + testParserQuery(t, ` @@ -812,7 +808,7 @@ func TestParserQueryHasInvalidQuery(t *testing.T) { } func TestParserQueryHasInvalidQueryButSkip(t *testing.T) { - testParserQuery(t, true, + testParserQuery(t, ` @@ -836,11 +832,12 @@ func TestParserQueryHasInvalidQueryButSkip(t *testing.T) { `, []string{ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?", - }) + }, + SkipErrorQuery) } func TestIssue302(t *testing.T) { - testParserQuery(t, false, + testParserQuery(t, ` @@ -1052,3 +1049,98 @@ func TestOtherwise_issue1193(t *testing.T) { `, "SELECT * FROM `fruits` WHERE `name`=? AND `price`=? AND `category`=?;", ) } + +func TestWithQueryId_issue1331(t *testing.T) { + testParserQuery(t, ` + + + + `, []string{"SELECT * FROM `fruits` WHERE `name`=? AND `price`=?"}, + ) + testParserQuery(t, ` + + + + `, []string{"/* id: testChoose */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?"}, + WithQueryId, + ) + testParserQuery(t, ` + + + + + `, []string{ + "/* id: testChoose */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?", + "/* id: testChoose2 */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?", + }, + WithQueryId, + ) +}