Skip to content
This repository was archived by the owner on Jan 14, 2025. It is now read-only.

Commit f8e328c

Browse files
authored
fix: Support postgres, mysql diarect (#42)
2 parents a64c8e7 + 35b5276 commit f8e328c

File tree

11 files changed

+439
-155
lines changed

11 files changed

+439
-155
lines changed

internal/arcgen/lang/go/generate.go

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,23 +57,7 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
5757
if config.GenerateGoCRUDPackage() {
5858
crudFileExt := ".crud" + genFileExt
5959

60-
if err := func() error {
61-
filename := filepath.Join(config.GoCRUDPackagePath(), "common"+crudFileExt)
62-
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, rw_r__r__)
63-
if err != nil {
64-
return errorz.Errorf("os.OpenFile: %w", err)
65-
}
66-
defer f.Close()
67-
68-
if err := fprintCRUDCommon(f, bytes.NewBuffer(nil), arcSrcSetSlice); err != nil {
69-
return errorz.Errorf("sprint: %w", err)
70-
}
71-
72-
return nil
73-
}(); err != nil {
74-
return errorz.Errorf("f: %w", err)
75-
}
76-
60+
crudFiles := make([]string, 0)
7761
for _, arcSrcSet := range arcSrcSetSlice {
7862
// closure for defer
7963
if err := func() error {
@@ -84,7 +68,7 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
8468
return errorz.Errorf("os.OpenFile: %w", err)
8569
}
8670
defer f.Close()
87-
f.Name()
71+
crudFiles = append(crudFiles, filename)
8872

8973
if err := fprintCRUD(
9074
f,
@@ -98,6 +82,23 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
9882
return errorz.Errorf("f: %w", err)
9983
}
10084
}
85+
86+
if err := func() error {
87+
filename := filepath.Join(config.GoCRUDPackagePath(), "common"+crudFileExt)
88+
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, rw_r__r__)
89+
if err != nil {
90+
return errorz.Errorf("os.OpenFile: %w", err)
91+
}
92+
defer f.Close()
93+
94+
if err := fprintCRUDCommon(f, bytes.NewBuffer(nil), arcSrcSetSlice, crudFiles); err != nil {
95+
return errorz.Errorf("sprint: %w", err)
96+
}
97+
98+
return nil
99+
}(); err != nil {
100+
return errorz.Errorf("f: %w", err)
101+
}
101102
}
102103

103104
return nil

internal/arcgen/lang/go/generate_crud_common.go

Lines changed: 149 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,22 @@ package arcgengo
22

33
import (
44
"go/ast"
5+
"go/parser"
56
"go/printer"
67
"go/token"
78
"io"
9+
"path/filepath"
810
"strconv"
911
"strings"
1012

1113
errorz "github.com/kunitsucom/util.go/errors"
1214

15+
"github.com/kunitsucom/arcgen/internal/arcgen/lang/util"
1316
"github.com/kunitsucom/arcgen/internal/config"
1417
)
1518

16-
func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice) error {
17-
content, err := generateCRUDCommonFileContent(buf, arcSrcSetSlice)
19+
func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice, crudFiles []string) error {
20+
content, err := generateCRUDCommonFileContent(buf, arcSrcSetSlice, crudFiles)
1821
if err != nil {
1922
return errorz.Errorf("generateCRUDCommonFileContent: %w", err)
2023
}
@@ -27,8 +30,13 @@ func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlic
2730
return nil
2831
}
2932

30-
//nolint:funlen
31-
func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, error) {
33+
const (
34+
sqlQueryerContextVarName = "sqlContext"
35+
sqlQueryerContextTypeName = "sqlQueryerContext"
36+
)
37+
38+
//nolint:cyclop,funlen,gocognit,maintidx
39+
func generateCRUDCommonFileContent(buf buffer, arcSrcSetSlice ARCSourceSetSlice, crudFiles []string) (string, error) {
3240
astFile := &ast.File{
3341
// package
3442
Name: &ast.Ident{
@@ -38,18 +46,19 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
3846
Decls: []ast.Decl{},
3947
}
4048

41-
// // Since all directories are the same from arcSrcSetSlice[0].Filename to arcSrcSetSlice[len(-1)].Filename,
42-
// // get the package path from arcSrcSetSlice[0].Filename.
43-
// dir := filepath.Dir(arcSrcSetSlice[0].Filename)
44-
// structPackagePath, err := util.GetPackagePath(dir)
45-
// if err != nil {
46-
// return "", errorz.Errorf("GetPackagePath: %w", err)
47-
// }
49+
// Since all directories are the same from arcSrcSetSlice[0].Filename to arcSrcSetSlice[len(-1)].Filename,
50+
// get the package path from arcSrcSetSlice[0].Filename.
51+
dir := filepath.Dir(arcSrcSetSlice[0].Filename)
52+
structPackagePath, err := util.GetPackagePath(dir)
53+
if err != nil {
54+
return "", errorz.Errorf("GetPackagePath: %w", err)
55+
}
4856

4957
astFile.Decls = append(astFile.Decls,
5058
// import (
5159
// "context"
5260
// "database/sql"
61+
// "log/slog"
5362
//
5463
// dao "path/to/your/dao"
5564
// )
@@ -62,15 +71,18 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
6271
&ast.ImportSpec{
6372
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("database/sql")},
6473
},
65-
// &ast.ImportSpec{
66-
// Name: &ast.Ident{Name: "dao"},
67-
// Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(structPackagePath)},
68-
// },
74+
&ast.ImportSpec{
75+
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("log/slog")},
76+
},
77+
&ast.ImportSpec{
78+
Name: &ast.Ident{Name: "dao"},
79+
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(structPackagePath)},
80+
},
6981
},
7082
},
7183
)
7284

73-
// type sqlContext interface {
85+
// type sqlQueryerContext interface {
7486
// QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
7587
// QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
7688
// ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
@@ -80,7 +92,8 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
8092
Tok: token.TYPE,
8193
Specs: []ast.Spec{
8294
&ast.TypeSpec{
83-
Name: &ast.Ident{Name: "sqlContext"},
95+
// Assign: token.Pos(1),
96+
Name: &ast.Ident{Name: sqlQueryerContextTypeName},
8497
Type: &ast.InterfaceType{
8598
Methods: &ast.FieldList{
8699
List: []*ast.Field{
@@ -133,27 +146,138 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
133146
},
134147
)
135148

136-
// type Queryer struct {}
149+
// type _CRUD struct {
150+
// }
137151
astFile.Decls = append(astFile.Decls,
138152
&ast.GenDecl{
139153
Tok: token.TYPE,
140154
Specs: []ast.Spec{
141155
&ast.TypeSpec{
142-
Name: &ast.Ident{Name: "Queryer"},
156+
Name: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()},
143157
Type: &ast.StructType{Fields: &ast.FieldList{}},
144158
},
145159
},
146160
},
147161
)
148162

149-
// func NewQueryer() *Query {
150-
// return &Queryer{}
151-
// }
163+
// func LoggerFromContext(ctx context.Context) *slog.Logger {
164+
// if ctx == nil {
165+
// return slog.Default()
166+
// }
167+
// if logger, ok := ctx.Value((*slog.Logger)(nil)).(*slog.Logger); ok {
168+
// return logger
169+
// }
170+
// return slog.Default()
171+
// }
172+
astFile.Decls = append(astFile.Decls,
173+
&ast.FuncDecl{
174+
Name: &ast.Ident{Name: "LoggerFromContext"},
175+
Type: &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}}}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}}}},
176+
Body: &ast.BlockStmt{
177+
List: []ast.Stmt{
178+
&ast.IfStmt{
179+
Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "ctx"}, Op: token.EQL, Y: &ast.Ident{Name: "nil"}},
180+
Body: &ast.BlockStmt{List: []ast.Stmt{
181+
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Default"}}}}},
182+
}},
183+
},
184+
&ast.IfStmt{
185+
// if logger, ok := ctx.Value((*slog.Logger)(nil)).(*slog.Logger); ok {
186+
Init: &ast.AssignStmt{
187+
Lhs: []ast.Expr{&ast.Ident{Name: "logger"}, &ast.Ident{Name: "ok"}},
188+
Tok: token.DEFINE,
189+
Rhs: []ast.Expr{
190+
&ast.TypeAssertExpr{
191+
X: &ast.CallExpr{
192+
Fun: &ast.Ident{Name: "ctx.Value"},
193+
Args: []ast.Expr{&ast.CallExpr{Fun: &ast.ParenExpr{X: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}, Args: []ast.Expr{&ast.Ident{Name: "nil"}}}},
194+
},
195+
Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}},
196+
},
197+
},
198+
},
199+
Cond: &ast.Ident{Name: "ok"},
200+
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.Ident{Name: "logger"}}}}},
201+
},
202+
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Default"}}}}},
203+
},
204+
},
205+
},
206+
)
207+
208+
// func LoggerWithContext(ctx context.Context, logger *slog.Logger) context.Context {
209+
// return context.WithValue(ctx, (*slog.Logger)(nil), logger)
210+
// }
211+
astFile.Decls = append(astFile.Decls,
212+
&ast.FuncDecl{
213+
Name: &ast.Ident{Name: "LoggerWithContext"},
214+
Type: &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}, {Names: []*ast.Ident{{Name: "logger"}}, Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}}}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: "context.Context"}}}}},
215+
Body: &ast.BlockStmt{
216+
List: []ast.Stmt{
217+
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "context"}, Sel: &ast.Ident{Name: "WithValue"}}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}, &ast.CallExpr{Fun: &ast.ParenExpr{X: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}, Args: []ast.Expr{&ast.Ident{Name: "nil"}}}, &ast.Ident{Name: "logger"}}}}},
218+
},
219+
},
220+
},
221+
)
222+
223+
// type CRUD interface {
224+
// Create{StructName}(ctx context.Context, sqlQueryer sqlQueryerContext, s *{Struct}) error
225+
// ...
226+
// }
227+
methods := make([]*ast.Field, 0)
228+
fset := token.NewFileSet()
229+
for _, crudFile := range crudFiles {
230+
rootNode, err := parser.ParseFile(fset, crudFile, nil, parser.ParseComments)
231+
if err != nil {
232+
// MEMO: parser.ParseFile err contains file path, so no need to log it
233+
return "", errorz.Errorf("parser.ParseFile: %w", err)
234+
}
235+
236+
// MEMO: Inspect is used to get the method declaration from the file
237+
ast.Inspect(rootNode, func(node ast.Node) bool {
238+
switch n := node.(type) {
239+
case *ast.FuncDecl:
240+
//nolint:nestif
241+
if n.Recv != nil && len(n.Recv.List) > 0 {
242+
if t, ok := n.Recv.List[0].Type.(*ast.StarExpr); ok {
243+
if ident, ok := t.X.(*ast.Ident); ok {
244+
if ident.Name == config.GoCRUDTypeNameUnexported() {
245+
methods = append(methods, &ast.Field{
246+
Names: []*ast.Ident{{Name: n.Name.Name}},
247+
Type: n.Type,
248+
})
249+
}
250+
}
251+
}
252+
}
253+
default:
254+
// noop
255+
}
256+
return true
257+
})
258+
}
259+
astFile.Decls = append(astFile.Decls,
260+
&ast.GenDecl{
261+
Tok: token.TYPE,
262+
Specs: []ast.Spec{
263+
&ast.TypeSpec{
264+
Name: &ast.Ident{Name: config.GoCRUDTypeName()},
265+
Type: &ast.InterfaceType{
266+
Methods: &ast.FieldList{List: methods},
267+
},
268+
},
269+
},
270+
},
271+
)
272+
273+
// func NewCRUD() CRUD {
274+
// return &_CRUD{}
275+
// }
152276
astFile.Decls = append(astFile.Decls,
153277
&ast.FuncDecl{
154-
Name: &ast.Ident{Name: "NewQueryer"},
155-
Type: &ast.FuncType{Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}}},
156-
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.Ident{Name: "Queryer{}"}}}}}},
278+
Name: &ast.Ident{Name: "New" + config.GoCRUDTypeName()},
279+
Type: &ast.FuncType{Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: config.GoCRUDTypeName()}}}}},
280+
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported() + "{}"}}}}}},
157281
},
158282
)
159283

internal/arcgen/lang/go/generate_crud_create.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"go/token"
66
"strconv"
77
"strings"
8+
9+
"github.com/kunitsucom/arcgen/internal/config"
810
)
911

1012
//nolint:funlen
@@ -13,13 +15,14 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
1315
structName := arcSrc.extractStructName()
1416
tableName := arcSrc.extractTableNameFromCommentGroup()
1517
tableInfo := arcSrc.extractFieldNamesAndColumnNames()
16-
columnNames := tableInfo.ColumnNames()
18+
columnNames := tableInfo.Columns.ColumnNames()
1719

18-
// const Create{StructName}Query = `INSERT INTO {table_name} ({column_name1}, {column_name2}) VALUES (?, ?)`
20+
// const Create{StructName}Query = `INSERT INTO {table_name} ({column_name1}, {column_name2}) VALUES ($1, $2)`
1921
//
20-
// func (q *query) Create{StructName}(ctx context.Context, queryer sqlContext, s *{Struct}) error {
21-
// if _, err := queryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
22-
// return fmt.Errorf("q.queryer.ExecContext: %w", err)
22+
// func (q *query) Create{StructName}(ctx context.Context, queryer sqlQueryerContext, s *{Struct}) error {
23+
// LoggerFromContext(ctx).Debug(Create{StructName}Query)
24+
// if _, err := sqlContext.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
25+
// return fmt.Errorf("sqlContext.ExecContext: %w", err)
2326
// }
2427
// return nil
2528
// }
@@ -33,18 +36,18 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
3336
Names: []*ast.Ident{{Name: queryName}},
3437
Values: []ast.Expr{&ast.BasicLit{
3538
Kind: token.STRING,
36-
Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (?" + strings.Repeat(", ?", len(columnNames)-1) + ")`",
39+
Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (" + columnValuesPlaceholder(columnNames) + ")`",
3740
}},
3841
},
3942
},
4043
},
4144
&ast.FuncDecl{
42-
Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}},
45+
Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}}}},
4346
Name: &ast.Ident{Name: funcName},
4447
Type: &ast.FuncType{
4548
Params: &ast.FieldList{List: []*ast.Field{
4649
{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}},
47-
{Names: []*ast.Ident{{Name: "sqlCtx"}}, Type: &ast.Ident{Name: "sqlContext"}},
50+
{Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, Type: &ast.Ident{Name: sqlQueryerContextTypeName}},
4851
{Names: []*ast.Ident{{Name: "s"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "dao." + structName}}},
4952
}},
5053
Results: &ast.FieldList{List: []*ast.Field{
@@ -53,14 +56,24 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
5356
},
5457
Body: &ast.BlockStmt{
5558
List: []ast.Stmt{
59+
&ast.ExprStmt{
60+
// LoggerFromContext(ctx).Debug(queryName)
61+
X: &ast.CallExpr{
62+
Fun: &ast.SelectorExpr{
63+
X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}},
64+
Sel: &ast.Ident{Name: "Debug"},
65+
},
66+
Args: []ast.Expr{&ast.Ident{Name: queryName}},
67+
},
68+
},
5669
&ast.IfStmt{
57-
// if _, err := queryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
70+
// if _, err := sqlQueryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
5871
Init: &ast.AssignStmt{
5972
Lhs: []ast.Expr{&ast.Ident{Name: "_"}, &ast.Ident{Name: "err"}},
6073
Tok: token.DEFINE,
6174
Rhs: []ast.Expr{&ast.CallExpr{
6275
Fun: &ast.SelectorExpr{
63-
X: &ast.Ident{Name: "sqlCtx"},
76+
X: &ast.Ident{Name: sqlQueryerContextVarName},
6477
Sel: &ast.Ident{Name: "ExecContext"},
6578
},
6679
Args: append(
@@ -80,10 +93,10 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
8093
// err != nil {
8194
Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "err"}, Op: token.NEQ, Y: &ast.Ident{Name: "nil"}},
8295
Body: &ast.BlockStmt{List: []ast.Stmt{
83-
// return fmt.Errorf("queryer.ExecContext: %w", err)
96+
// return fmt.Errorf("sqlContext.ExecContext: %w", err)
8497
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{
8598
Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "fmt"}, Sel: &ast.Ident{Name: "Errorf"}},
86-
Args: []ast.Expr{&ast.Ident{Name: strconv.Quote("queryer.ExecContext: %w")}, &ast.Ident{Name: "err"}},
99+
Args: []ast.Expr{&ast.Ident{Name: strconv.Quote(sqlQueryerContextVarName + ".ExecContext: %w")}, &ast.Ident{Name: "err"}},
87100
}}},
88101
}},
89102
},

0 commit comments

Comments
 (0)