From 4acc14799eb32f359c2146aea1267e531438b62b Mon Sep 17 00:00:00 2001
From: sjjian <921465802@qq.com>
Date: Thu, 20 Jul 2023 06:37:57 +0000
Subject: [PATCH 1/3] =?UTF-8?q?code=20refactor,=20=E4=BD=BF=E7=94=A8?=
=?UTF-8?q?=E5=8A=A8=E6=80=81=E6=B3=A8=E5=85=A5=E7=9A=84=E5=8F=82=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
ast/context.go | 11 ++++++++++-
ast/mapper.go | 4 ++--
config.go | 9 +++++++++
parser.go | 15 +++++++++++----
parser_test.go | 10 ++++++++--
5 files changed, 40 insertions(+), 9 deletions(-)
create mode 100644 config.go
diff --git a/ast/context.go b/ast/context.go
index 7c98479..86f9339 100644
--- a/ast/context.go
+++ b/ast/context.go
@@ -4,12 +4,14 @@ type Context struct {
QueryType string // select, insert, update, delete
Variable map[string]string
Sqls map[string]*SqlNode
+ Config *Config
}
-func NewContext() *Context {
+func NewContext(config *Config) *Context {
return &Context{
Variable: map[string]string{},
Sqls: map[string]*SqlNode{},
+ Config: config,
}
}
@@ -26,3 +28,10 @@ func (c *Context) GetSql(k string) (*SqlNode, bool) {
sql, ok := c.Sqls[k]
return sql, ok
}
+
+type Config struct {
+ SkipErrorQuery bool
+ WithQueryId bool
+}
+
+type ConfigFn func() func(*Config)
diff --git a/ast/mapper.go b/ast/mapper.go
index 2254625..4dc8151 100644
--- a/ast/mapper.go
+++ b/ast/mapper.go
@@ -68,7 +68,7 @@ func (m *Mapper) GetStmt(ctx *Context) (string, error) {
return strings.TrimSuffix(buff.String(), "\n"), nil
}
-func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) {
+func (m *Mapper) GetStmts(ctx *Context) ([]string, error) {
var stmts []string
ctx.Sqls = m.SqlNodes
for _, a := range m.QueryNodes {
@@ -77,7 +77,7 @@ func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) {
stmts = append(stmts, data)
continue
}
- if skipErrorQuery {
+ if ctx.Config.SkipErrorQuery {
continue
}
return nil, err
diff --git a/config.go b/config.go
new file mode 100644
index 0000000..10c3d30
--- /dev/null
+++ b/config.go
@@ -0,0 +1,9 @@
+package parser
+
+import "github.com/actiontech/mybatis-mapper-2-sql/ast"
+
+func SkipErrorQuery() func(*ast.Config) {
+ return func(c *ast.Config) {
+ c.SkipErrorQuery = true
+ }
+}
diff --git a/parser.go b/parser.go
index f424043..dad656c 100644
--- a/parser.go
+++ b/parser.go
@@ -20,7 +20,7 @@ func ParseXML(data string) (string, error) {
if n == nil {
return "", nil
}
- stmt, err := n.GetStmt(ast.NewContext())
+ stmt, err := n.GetStmt(ast.NewContext(&ast.Config{}))
if err != nil {
return "", err
}
@@ -28,8 +28,9 @@ func ParseXML(data string) (string, error) {
}
// ParseXMLQuery is a parser for parse all query in XML to []string one by one;
-// you can set `skipErrorQuery` true to ignore invalid query.
-func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) {
+// ConfigFn:
+// `SkipErrorQuery` to ignore invalid query.
+func ParseXMLQuery(data string, configFns ...ast.ConfigFn) ([]string, error) {
r := strings.NewReader(data)
d := xml.NewDecoder(r)
n, err := parse(d)
@@ -43,7 +44,13 @@ func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) {
if !ok {
return nil, fmt.Errorf("the mapper is not found")
}
- stmts, err := m.GetStmts(ast.NewContext(), skipErrorQuery)
+
+ config := &ast.Config{}
+ for _, configFn := range configFns {
+ configFn()(config)
+ }
+
+ stmts, err := m.GetStmts(ast.NewContext(config))
if err != nil {
return nil, err
}
diff --git a/parser_test.go b/parser_test.go
index 89ce67b..d3961b4 100644
--- a/parser_test.go
+++ b/parser_test.go
@@ -2,6 +2,8 @@ package parser
import (
"testing"
+
+ "github.com/actiontech/mybatis-mapper-2-sql/ast"
)
func testParser(t *testing.T, xmlData, expect string) {
@@ -574,7 +576,11 @@ func TestParserSQLRefIdNotFound(t *testing.T) {
}
func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) {
- actual, err := ParseXMLQuery(xmlData, skipError)
+ configFns := []ast.ConfigFn{}
+ if skipError {
+ configFns = append(configFns, SkipErrorQuery)
+ }
+ actual, err := ParseXMLQuery(xmlData, configFns...)
if err != nil {
t.Errorf("parse error: %v", err)
return
@@ -796,7 +802,7 @@ func TestParserQueryHasInvalidQuery(t *testing.T) {
from t
-`, false)
+`)
if err == nil {
t.Errorf("expect has error, but no error")
}
From 804a2b3527b6924bb583efc069b0d3e15942381b Mon Sep 17 00:00:00 2001
From: sjjian <921465802@qq.com>
Date: Thu, 20 Jul 2023 07:06:51 +0000
Subject: [PATCH 2/3] =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=BE=93=E5=87=BA?=
=?UTF-8?q?=E8=AF=AD=E5=8F=A5=E7=9A=84ID,=20=E9=80=9A=E8=BF=87SQL=E6=B3=A8?=
=?UTF-8?q?=E9=87=8A=E7=9A=84=E5=BD=A2=E5=BC=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
ast/query.go | 12 +++++++++++-
config.go | 6 ++++++
2 files changed, 17 insertions(+), 1 deletion(-)
diff --git a/ast/query.go b/ast/query.go
index 7482d14..294fa1b 100644
--- a/ast/query.go
+++ b/ast/query.go
@@ -3,6 +3,7 @@ package ast
import (
"bytes"
"encoding/xml"
+ "fmt"
"github.com/actiontech/mybatis-mapper-2-sql/sqlfmt"
)
@@ -32,6 +33,7 @@ func (s *QueryNode) Scan(start *xml.StartElement) error {
func (s *QueryNode) GetStmt(ctx *Context) (string, error) {
buff := bytes.Buffer{}
ctx.QueryType = s.Type
+
for _, a := range s.Children {
data, err := a.GetStmt(ctx)
if err != nil {
@@ -39,5 +41,13 @@ func (s *QueryNode) GetStmt(ctx *Context) (string, error) {
}
buff.WriteString(data)
}
- return sqlfmt.FormatSQL(buff.String()), nil
+ fmtSQL := sqlfmt.FormatSQL(buff.String())
+ if ctx.Config.WithQueryId {
+ buff.Reset()
+ buff.WriteString(fmt.Sprintf("/* id: %s */\n", s.Id))
+ buff.WriteString(fmtSQL)
+ return buff.String(), nil
+ } else {
+ return fmtSQL, nil
+ }
}
diff --git a/config.go b/config.go
index 10c3d30..9711f99 100644
--- a/config.go
+++ b/config.go
@@ -7,3 +7,9 @@ func SkipErrorQuery() func(*ast.Config) {
c.SkipErrorQuery = true
}
}
+
+func WithQueryId() func(*ast.Config) {
+ return func(c *ast.Config) {
+ c.WithQueryId = true
+ }
+}
From a16f58f7b4183ff7582fb83a7ef85991241afc05 Mon Sep 17 00:00:00 2001
From: sjjian <921465802@qq.com>
Date: Thu, 20 Jul 2023 07:07:40 +0000
Subject: [PATCH 3/3] =?UTF-8?q?=E4=B8=BA=E6=94=AF=E6=8C=81=E8=BE=93?=
=?UTF-8?q?=E5=87=BA=E8=AF=AD=E5=8F=A5ID=E7=9A=84=E5=8A=9F=E8=83=BD?=
=?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
parser_ibatis_test.go | 11 +++--
parser_test.go | 112 ++++++++++++++++++++++++++++++++++++++----
2 files changed, 108 insertions(+), 15 deletions(-)
diff --git a/parser_ibatis_test.go b/parser_ibatis_test.go
index 30cf553..9b371d0 100644
--- a/parser_ibatis_test.go
+++ b/parser_ibatis_test.go
@@ -52,7 +52,7 @@ id = #id#
}
func TestParseIBatisInclude(t *testing.T) {
- testParserQuery(t, false, `
+ testParserQuery(t, `
@@ -76,7 +76,7 @@ SELECT id, name
"SELECT `id`,`name` FROM `items` WHERE `parentid`=6",
})
- testParserQuery(t, false, `
+ testParserQuery(t, `
@@ -102,7 +102,7 @@ SELECT id, name
}
func TestParseIBatisAll(t *testing.T) {
- testParserQuery(t, true, `
+ testParserQuery(t, `
@@ -195,5 +195,6 @@ func TestParseIBatisAll(t *testing.T) {
"SELECT * FROM `EMPLOYEE` WHERE (`username`=? OR `username`=?) AND `id` IS NULL AND `id`=?",
"SELECT * FROM `EMPLOYEE` WHERE `ACC_FIRST_NAME`=? OR `ACC_LAST_NAME`=? AND `ACC_EMAIL` LIKE ? AND `ACC_ID`=? ORDER BY `ACC_LAST_NAME`",
"SELECT * FROM `EMPLOYEE` ORDER BY ?",
- })
-}
\ No newline at end of file
+ },
+ SkipErrorQuery)
+}
diff --git a/parser_test.go b/parser_test.go
index d3961b4..32eb5c9 100644
--- a/parser_test.go
+++ b/parser_test.go
@@ -575,11 +575,7 @@ func TestParserSQLRefIdNotFound(t *testing.T) {
}
}
-func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) {
- configFns := []ast.ConfigFn{}
- if skipError {
- configFns = append(configFns, SkipErrorQuery)
- }
+func testParserQuery(t *testing.T, xmlData string, expect []string, configFns ...ast.ConfigFn) {
actual, err := ParseXMLQuery(xmlData, configFns...)
if err != nil {
t.Errorf("parse error: %v", err)
@@ -599,7 +595,7 @@ func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []stri
}
func TestParserQueryFullFile(t *testing.T) {
- testParserQuery(t, false,
+ testParserQuery(t,
`
@@ -812,7 +808,7 @@ func TestParserQueryHasInvalidQuery(t *testing.T) {
}
func TestParserQueryHasInvalidQueryButSkip(t *testing.T) {
- testParserQuery(t, true,
+ testParserQuery(t,
`
@@ -836,11 +832,12 @@ func TestParserQueryHasInvalidQueryButSkip(t *testing.T) {
`, []string{
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?",
- })
+ },
+ SkipErrorQuery)
}
func TestIssue302(t *testing.T) {
- testParserQuery(t, false,
+ testParserQuery(t,
`
`, []string{
"SELECT * FROM `user` WHERE `name`=? AND `name`=?",
})
- testParserQuery(t, false,
+ testParserQuery(t,
`