From 1f7da140224a314cdb1c01895681e93ab155eed1 Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Sun, 22 Feb 2026 20:26:27 -0800 Subject: [PATCH 1/6] go ci go --- .golangci.yml | 66 +++++++++++++++++++++++++++++++++++++ generator/generator.go | 50 ++++++++++++++++++++++++++-- generator/generator_test.go | 18 ++++++++++ generator/helpers.go | 55 ++++++++++++++++++++++++++++++- generator/types.go | 4 +-- main.go | 1 + 6 files changed, 188 insertions(+), 6 deletions(-) create mode 100644 .golangci.yml diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..bdbb9d6 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,66 @@ +version: "2" +formatters: + enable: + - gci + - gofumpt + settings: + gci: + sections: + - standard + - default + - alias + - prefix(github.com/kalbasit/fastcdc) + - blank + - dot + custom-order: true + exclusions: + generated: lax +linters: + enable: + - err113 + - errname + - exhaustive + - gochecknoglobals + - gochecknoinits + - goconst + - godot + - goheader + - gosec + - importas + - lll + - makezero + - misspell + - nakedret + - nestif + - nilerr + - nilnil + - nlreturn + - noctx + - nolintlint + - paralleltest + - prealloc + - predeclared + - revive + - rowserrcheck + - staticcheck + - tagliatelle + - testifylint + - testpackage + - unconvert + - unparam + - wastedassign + - whitespace + - wsl_v5 + - zerologlint + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + settings: + wsl_v5: + allow-first-in-block: true + allow-whole-block: false + branch-max-lines: 2 diff --git a/generator/generator.go b/generator/generator.go index 1595eca..d1094ba 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -29,6 +29,7 @@ func Run(querierPath string) { if err != nil { log.Fatalf("resolving querier path: %v", err) } + sourceDir := filepath.Dir(absQuerierPath) targetDir := filepath.Dir(sourceDir) // Parent of postgresdb is pkg/database @@ -37,6 +38,7 @@ func Run(querierPath string) { // 2. Identify used structs from source methods usedStructNames := make(map[string]bool) + for _, m := range sourceData.Methods { for _, p := range m.Params { cleanType := strings.TrimPrefix(p.Type, "[]") @@ -44,6 +46,7 @@ func Run(querierPath string) { usedStructNames[cleanType] = true } } + for _, r := range m.Returns { cleanType := strings.TrimPrefix(r.Type, "[]") if _, exists := sourceData.Structs[cleanType]; exists { @@ -56,6 +59,7 @@ func Run(querierPath string) { for name := range usedStructNames { sortedStructs = append(sortedStructs, sourceData.Structs[name]) } + sort.Slice(sortedStructs, func(i, j int) bool { return sortedStructs[i].Name < sortedStructs[j].Name }) @@ -65,25 +69,32 @@ func Run(querierPath string) { if !isDomainStruct(name) || strings.HasSuffix(name, "Params") || strings.HasSuffix(name, "Row") { continue } + hasID := false + for _, f := range sourceData.Structs[name].Fields { if f.Name == "ID" { hasID = true + break } } + if !hasID { continue } methodName := "Get" + name + "ByID" found := false + for _, m := range sourceData.Methods { if m.Name == methodName { found = true + break } } + if !found { log.Printf("Synthesizing %s\n", methodName) sourceData.Methods = append(sourceData.Methods, MethodInfo{ @@ -120,6 +131,7 @@ func Run(querierPath string) { // 6. Parse all target packages engineData := make(map[string]PackageData) + for _, engine := range engines { engineDir := filepath.Join(targetDir, engine.Package) engineData[engine.Name] = parsePackage(engineDir) @@ -133,12 +145,14 @@ func Run(querierPath string) { func parsePackage(dir string) PackageData { fset := token.NewFileSet() + pkgs, err := parser.ParseDir(fset, dir, nil, parser.ParseComments) if err != nil { log.Fatal(err) } var methods []MethodInfo + structs := make(map[string]StructInfo) for _, pkg := range pkgs { @@ -154,6 +168,7 @@ func parsePackage(dir string) PackageData { if !ok { return true } + for _, field := range interfaceType.Methods.List { m := MethodInfo{Name: field.Names[0].Name} if field.Doc != nil { @@ -178,18 +193,21 @@ func parsePackage(dir string) PackageData { if funcType.Results != nil { for _, res := range funcType.Results.List { typeStr := exprToString(res.Type) + m.Returns = append(m.Returns, Return{Type: typeStr}) - if typeStr == "error" { + switch typeStr { + case "error": m.ReturnsError = true - } else if typeStr == "Querier" { + case "Querier": m.ReturnsSelf = true m.HasValue = true - } else { + default: m.HasValue = true m.ReturnElem = strings.TrimPrefix(typeStr, "[]") } } } + m.IsCreate = strings.HasPrefix(m.Name, "Create") && isDomainStruct(m.ReturnElem) m.IsUpdate = strings.HasPrefix(m.Name, "Update") && isDomainStruct(m.ReturnElem) methods = append(methods, m) @@ -198,17 +216,21 @@ func parsePackage(dir string) PackageData { if structType, ok := typeSpec.Type.(*ast.StructType); ok { s := StructInfo{Name: typeSpec.Name.Name} + if structType.Fields != nil { for _, field := range structType.Fields.List { typeStr := exprToString(field.Type) tag := "" + if field.Tag != nil { unquoted, err := strconv.Unquote(field.Tag.Value) if err != nil { log.Fatalf("failed to unquote struct tag %s: %v", field.Tag.Value, err) } + tag = unquoted } + if len(field.Names) > 0 { for _, name := range field.Names { s.Fields = append(s.Fields, FieldInfo{Name: name.Name, Type: typeStr, Tag: tag}) @@ -218,8 +240,10 @@ func parsePackage(dir string) PackageData { } } } + structs[s.Name] = s } + return true }) } @@ -234,7 +258,9 @@ func parsePackage(dir string) PackageData { func generateModels(dir, packageName string, structs []StructInfo) { t := template.Must(template.New("models").Parse(modelsTemplate)) + var buf bytes.Buffer + data := map[string]interface{}{ "PackageName": packageName, "Structs": structs, @@ -242,6 +268,7 @@ func generateModels(dir, packageName string, structs []StructInfo) { if err := t.Execute(&buf, data); err != nil { log.Fatalf("executing models template: %v", err) } + writeFile(dir, generatedFilePrefix+"models.go", buf.Bytes()) } @@ -252,6 +279,7 @@ func generateQuerier(dir, packageName string, methods []MethodInfo) { }).Parse(querierTemplate)) var buf bytes.Buffer + data := map[string]interface{}{ "PackageName": packageName, "Methods": methods, @@ -259,18 +287,22 @@ func generateQuerier(dir, packageName string, methods []MethodInfo) { if err := t.Execute(&buf, data); err != nil { log.Fatalf("executing querier template: %v", err) } + writeFile(dir, generatedFilePrefix+"querier.go", buf.Bytes()) } func generateErrors(dir, packageName string) { t := template.Must(template.New("errors").Parse(errorsTemplate)) + var buf bytes.Buffer + data := map[string]interface{}{ "PackageName": packageName, } if err := t.Execute(&buf, data); err != nil { log.Fatalf("executing errors template: %v", err) } + writeFile(dir, generatedFilePrefix+"errors.go", buf.Bytes()) } @@ -293,38 +325,46 @@ func generateWrapper(dir, packageName, importBase string, engine Engine, methods return m } } + return MethodInfo{} }, "getTargetStruct": func(name string) StructInfo { if engData.Structs == nil { return StructInfo{} } + return engData.Structs[name] }, "joinParamsCall": func(params []Param, engPkg string, targetMethodName string) (string, error) { targetMethod := MethodInfo{} + if engData.Methods != nil { for _, m := range engData.Methods { if m.Name == targetMethodName { targetMethod = m + break } } } + return joinParamsCall(params, engPkg, targetMethod, engData.Structs, structs) }, "dict": func(values ...interface{}) (map[string]interface{}, error) { if len(values)%2 != 0 { return nil, fmt.Errorf("invalid dict call") } + dict := make(map[string]interface{}, len(values)/2) for i := 0; i < len(values); i += 2 { key, ok := values[i].(string) if !ok { return nil, fmt.Errorf("dict keys must be strings") } + dict[key] = values[i+1] } + return dict, nil }, "hasSuffix": strings.HasSuffix, @@ -355,11 +395,13 @@ func generateWrapper(dir, packageName, importBase string, engine Engine, methods if clause.keyword == "INSERT INTO " { tableName = strings.Trim(tableName, "()") } + return tableName, true } } } } + return "", false } @@ -381,6 +423,7 @@ func generateWrapper(dir, packageName, importBase string, engine Engine, methods } } } + return strings.ToLower(inflection.Plural(structName)) }, }).Parse(wrapperTemplate)) @@ -398,6 +441,7 @@ func generateWrapper(dir, packageName, importBase string, engine Engine, methods if err := t.Execute(&buf, data); err != nil { log.Fatalf("executing wrapper template: %v", err) } + writeFile(dir, fmt.Sprintf("%swrapper_%s.go", generatedFilePrefix, engine.Name), buf.Bytes()) } diff --git a/generator/generator_test.go b/generator/generator_test.go index c2c7ad7..427f7fd 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -186,8 +186,10 @@ func TestJoinParamsCall(t *testing.T) { got, err := generator.JoinParamsCall(tt.params, tt.engPkg, generator.MethodInfo{}, nil, nil) if (err != nil) != tt.wantErr { t.Errorf("JoinParamsCall() error = %v, wantErr %v", err, tt.wantErr) + return } + if got != tt.want { t.Errorf("JoinParamsCall() = %v, want %v", got, tt.want) } @@ -246,14 +248,17 @@ func TestWrapperTemplate(t *testing.T) { if len(values)%2 != 0 { return nil, fmt.Errorf("invalid dict call") } + dict := make(map[string]interface{}, len(values)/2) for i := 0; i < len(values); i += 2 { key, ok := values[i].(string) if !ok { return nil, fmt.Errorf("dict keys must be strings") } + dict[key] = values[i+1] } + return dict, nil }, "getTargetMethod": func(name string) generator.MethodInfo { @@ -267,6 +272,7 @@ func TestWrapperTemplate(t *testing.T) { Returns: []generator.Return{{Type: "error"}}, } } + return generator.MethodInfo{} }, "getTargetStruct": func(name string) generator.StructInfo { return structs[name] }, @@ -282,6 +288,7 @@ func TestWrapperTemplate(t *testing.T) { Returns: []generator.Return{{Type: "error"}}, } } + return generator.JoinParamsCall(params, engPkg, targetMethod, structs, structs) }, "hasSuffix": strings.HasSuffix, @@ -292,6 +299,7 @@ func TestWrapperTemplate(t *testing.T) { if m.ReturnsSelf { return "nil" } + return "0" }, "getTableName": func(structName string) string { return "users" }, @@ -341,7 +349,9 @@ func TestWrapperTemplate(t *testing.T) { } data["Methods"] = methods + buf.Reset() + if err := tmpl.Execute(&buf, data); err != nil { t.Fatalf("failed to execute template: %v", err) } @@ -365,7 +375,9 @@ func TestWrapperTemplate(t *testing.T) { } data["Methods"] = methods + buf.Reset() + if err := tmpl.Execute(&buf, data); err != nil { t.Fatalf("failed to execute template: %v", err) } @@ -374,6 +386,7 @@ func TestWrapperTemplate(t *testing.T) { if !strings.Contains(output, "nil, ErrNotFound") { t.Errorf("expected output to contain 'nil, ErrNotFound' for WithTx, but it didn't\n%s", output) } + if !strings.Contains(output, "nil, err") { t.Errorf("expected output to contain 'nil, err' for WithTx, but it didn't\n%s", output) } @@ -414,6 +427,7 @@ func TestWrapperTemplate(t *testing.T) { Returns: []generator.Return{{Type: "error"}}, } } + return generator.MethodInfo{} } @@ -426,6 +440,7 @@ func TestWrapperTemplate(t *testing.T) { }, Returns: []generator.Return{{Type: "error"}}, } + return generator.JoinParamsCall(params, engPkg, targetMethod, structs, structs) } funcMap["getTableName"] = func(structName string) string { return "users" } @@ -436,12 +451,15 @@ func TestWrapperTemplate(t *testing.T) { } data["Methods"] = methods + buf.Reset() + if err := tmpl.Execute(&buf, data); err != nil { t.Fatalf("failed to execute template: %v", err) } output = buf.String() + expectedConversion := "Bio: sql.NullString{String: user.Bio, Valid: true}" if !strings.Contains(output, expectedConversion) { t.Errorf("expected output to contain '%s', but it didn't\n%s", expectedConversion, output) diff --git a/generator/helpers.go b/generator/helpers.go index 355c57b..582c9d4 100644 --- a/generator/helpers.go +++ b/generator/helpers.go @@ -14,16 +14,19 @@ import ( func toSnakeCase(s string) string { var res []rune + for i, r := range s { if i > 0 && r >= 'A' && r <= 'Z' { // Check if previous was also uppercase (e.g. ID) prev := rune(s[i-1]) - if !(prev >= 'A' && prev <= 'Z') { + if prev < 'A' || prev > 'Z' { res = append(res, '_') } } + res = append(res, []rune(strings.ToLower(string(r)))[0]) } + return string(res) } @@ -31,6 +34,7 @@ func quote(e Engine, s string) string { if e.IsMySQL() { return "`" + s + "`" } + return `\"` + s + `\"` } @@ -41,6 +45,7 @@ func extractBulkFor(comment string) string { return parts[i+1] } } + return "" } @@ -67,6 +72,7 @@ func writeFile(dir, filename string, content []byte) { if err := os.WriteFile(filepath.Join(dir, filename), formatted, 0o644); err != nil { log.Fatal(err) } + fmt.Printf("Generated %s\n", filename) } @@ -77,6 +83,7 @@ func hasParam(name string, params []Param) bool { return true } } + return false } @@ -85,6 +92,7 @@ func paramHasField(paramName string, fieldName string, params []Param, structs m if param.Name == paramName { typeName := strings.TrimPrefix(param.Type, "[]") typeName = strings.TrimPrefix(typeName, "*") + typeParts := strings.Split(typeName, ".") if len(typeParts) > 1 { typeName = typeParts[len(typeParts)-1] @@ -97,9 +105,11 @@ func paramHasField(paramName string, fieldName string, params []Param, structs m } } } + return false } } + return false } @@ -108,6 +118,7 @@ func joinParamsSignature(params []Param) string { for _, param := range params { p = append(p, fmt.Sprintf("%s %s", param.Name, param.Type)) } + return strings.Join(p, ", ") } @@ -118,6 +129,7 @@ func JoinParamsCall(params []Param, engPkg string, targetMethod MethodInfo, targ func joinParamsCall(params []Param, engPkg string, targetMethod MethodInfo, targetStructs map[string]StructInfo, sourceStructs map[string]StructInfo) (string, error) { var p []string + for i, param := range params { if isDomainStructFunc(param.Type) { if strings.HasPrefix(param.Type, "[]") { @@ -133,13 +145,17 @@ func joinParamsCall(params []Param, engPkg string, targetMethod MethodInfo, targ targetStruct := targetStructs[targetParamType] var fields []string + for _, targetField := range targetStruct.Fields { var sourceField FieldInfo + found := false + for _, sf := range sourceStruct.Fields { if sf.Name == targetField.Name { sourceField = sf found = true + break } } @@ -154,6 +170,7 @@ func joinParamsCall(params []Param, engPkg string, targetMethod MethodInfo, targ fields = append(fields, conversion) } } + p = append(p, fmt.Sprintf("%s.%s{\n%s,\n}", engPkg, targetParamType, strings.Join(fields, ",\n"))) } else { p = append(p, fmt.Sprintf("%s.%s(%s)", engPkg, param.Type, param.Name)) @@ -172,6 +189,7 @@ func joinParamsCall(params []Param, engPkg string, targetMethod MethodInfo, targ } } } + return strings.Join(p, ", "), nil } @@ -180,6 +198,7 @@ func joinReturns(returns []Return) string { for _, ret := range returns { r = append(r, ret.Type) } + return strings.Join(r, ", ") } @@ -191,12 +210,14 @@ func firstReturnType(returns []Return) string { if len(returns) > 0 { return returns[0].Type } + return "" } // isDomainStructFunc checks if type is a "Domain Struct" based on naming convention. func isDomainStructFunc(t string) bool { t = strings.TrimPrefix(t, "[]") + return len(t) > 0 && t[0] >= 'A' && t[0] <= 'Z' && !strings.Contains(t, ".") && t != "Querier" } @@ -209,6 +230,7 @@ func zeroValue(t string) string { if isNumeric(t) { return "0" } + switch t { case "bool": return "false" @@ -217,12 +239,15 @@ func zeroValue(t string) string { case "error": return "nil" } + if strings.HasPrefix(t, "*") || strings.HasPrefix(t, "[]") || strings.HasPrefix(t, "map[") || t == "interface{}" { return "nil" } + if t == "sql.Result" || t == "Querier" { return "nil" } + return fmt.Sprintf("%s{}", t) } @@ -237,6 +262,7 @@ func isNumeric(t string) bool { case "byte", "rune": return true } + return false } @@ -244,6 +270,7 @@ func isStructType(t string) bool { if strings.HasPrefix(t, "sql.Null") { return true } + return false } @@ -330,9 +357,11 @@ func generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, // Case 4: Both are sql.Null* types but different if isSqlNullType(sourceFieldType) && isSqlNullType(targetFieldType) { sourcePrimitive := getPrimitiveFromNullType(sourceFieldType) + targetPrimitive := getPrimitiveFromNullType(targetFieldType) if sourcePrimitive != "" && targetPrimitive != "" { sourceFieldName := getFieldNameForNullType(sourceFieldType) + targetValueFieldName := getFieldNameForNullType(targetFieldType) if sourcePrimitive == targetPrimitive { return fmt.Sprintf("%s: %s{%s: %s.%s, Valid: %s.Valid}", targetFieldName, targetFieldType, targetValueFieldName, sourceExpr, sourceFieldName, sourceExpr) @@ -347,9 +376,11 @@ func generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, expectedPrimitive := getPrimitiveFromNullType(targetFieldType) if expectedPrimitive == sourceFieldType { fieldName := getFieldNameForNullType(targetFieldType) + return fmt.Sprintf("%s: %s{%s: %s, Valid: true}", targetFieldName, targetFieldType, fieldName, sourceExpr) } else if expectedPrimitive != "" { fieldName := getFieldNameForNullType(targetFieldType) + return fmt.Sprintf("%s: %s{%s: %s(%s), Valid: true}", targetFieldName, targetFieldType, fieldName, expectedPrimitive, sourceExpr) } } @@ -359,9 +390,11 @@ func generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, primitive := getPrimitiveFromNullType(sourceFieldType) if primitive == targetFieldType { fieldName := getFieldNameForNullType(sourceFieldType) + return fmt.Sprintf("%s: %s.%s", targetFieldName, sourceExpr, fieldName) } else if primitive != "" { fieldName := getFieldNameForNullType(sourceFieldType) + return fmt.Sprintf("%s: %s(%s.%s)", targetFieldName, targetFieldType, sourceExpr, fieldName) } } @@ -374,6 +407,7 @@ func generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, // Case 5b: interface{} source → sql.Null* target (SQLite nullable columns come as interface{}) if sourceFieldType == "interface{}" && isSqlNullType(targetFieldType) { primitive := getPrimitiveFromNullType(targetFieldType) + fieldName := getFieldNameForNullType(targetFieldType) if primitive != "" && fieldName != "" { return fmt.Sprintf( @@ -396,6 +430,7 @@ func hasSliceField(s StructInfo) bool { return true } } + return false } @@ -405,6 +440,7 @@ func getSliceField(s StructInfo) FieldInfo { return f } } + return FieldInfo{} } @@ -420,30 +456,39 @@ func findImportBase(targetDir string) string { if err != nil { log.Fatalf("reading go.mod at %s: %v", goModPath, err) } + moduleName := "" + for _, line := range strings.Split(string(data), "\n") { line = strings.TrimSpace(line) if strings.HasPrefix(line, "module ") { moduleName = strings.TrimSpace(strings.TrimPrefix(line, "module ")) + break } } + if moduleName == "" { log.Fatalf("could not find module directive in %s", goModPath) } + relPath, err := filepath.Rel(dir, targetDir) if err != nil { log.Fatalf("computing relative path: %v", err) } + if relPath == "." { return moduleName } + return moduleName + "/" + relPath } + parent := filepath.Dir(dir) if parent == dir { log.Fatalf("no go.mod found walking up from %s", targetDir) } + dir = parent } } @@ -454,24 +499,30 @@ func detectPackageName(dir string) string { if err != nil { return filepath.Base(dir) } + for _, e := range entries { if e.IsDir() { continue } + name := e.Name() if !strings.HasSuffix(name, ".go") { continue } + if strings.HasPrefix(name, "generated_") { continue } + if strings.HasSuffix(name, "_test.go") { continue } + data, err := os.ReadFile(filepath.Join(dir, name)) if err != nil { continue } + for _, line := range strings.Split(string(data), "\n") { line = strings.TrimSpace(line) if strings.HasPrefix(line, "package ") { @@ -480,10 +531,12 @@ func detectPackageName(dir string) string { if idx := strings.Index(pkg, " "); idx != -1 { pkg = pkg[:idx] } + return pkg } } } + return filepath.Base(dir) } diff --git a/generator/types.go b/generator/types.go index 7c027e6..cca0e08 100644 --- a/generator/types.go +++ b/generator/types.go @@ -1,6 +1,6 @@ package generator -// Engine configuration +// Engine configuration. type Engine struct { Name string // e.g. "sqlite" Package string // e.g. "sqlitedb" @@ -9,7 +9,7 @@ type Engine struct { func (e Engine) IsMySQL() bool { return e.Name == "mysql" } func (e Engine) IsPostgres() bool { return e.Name == "postgres" } -// MethodInfo holds extracted data from the AST +// MethodInfo holds extracted data from the AST. type MethodInfo struct { Name string Params []Param diff --git a/main.go b/main.go index 0c76633..89180c5 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ func main() { for _, arg := range os.Args[1:] { if arg != "--" && !strings.HasPrefix(arg, "-") { querierPath = arg + break } } From 8c20e9df26e58f5081291f4474949a3f852bf5c5 Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Sun, 22 Feb 2026 20:28:58 -0800 Subject: [PATCH 2/6] update import --- .golangci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index bdbb9d6..4e25594 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,7 +9,7 @@ formatters: - standard - default - alias - - prefix(github.com/kalbasit/fastcdc) + - prefix(github.com/kalbasit/sqlc-multi-db) - blank - dot custom-order: true From 670b3b0260068edc154a4ba344807026a5a6f7f0 Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Sun, 22 Feb 2026 20:38:21 -0800 Subject: [PATCH 3/6] add nix stuff --- .envrc | 3 + .gitignore | 48 ++++++++++++ flake.lock | 129 ++++++++++++++++++++++++++++++++ flake.nix | 40 ++++++++++ go.mod | 2 +- nix/devshells/flake-module.nix | 32 ++++++++ nix/formatter/flake-module.nix | 27 +++++++ nix/pre-commit/flake-module.nix | 22 ++++++ 8 files changed, 302 insertions(+), 1 deletion(-) create mode 100644 .envrc create mode 100644 .gitignore create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 nix/devshells/flake-module.nix create mode 100644 nix/formatter/flake-module.nix create mode 100644 nix/pre-commit/flake-module.nix diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..dfc636e --- /dev/null +++ b/.envrc @@ -0,0 +1,3 @@ +watch_dir nix + +use_flake diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..98865fb --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Editor files +/.vscode/** +!/.vscode/settings.json +!/.vscode/tasks.json +!/.vscode/launch.json +!/.vscode/extensions.json + +# Go workspace file +go.work +go.work.sum + +# env file used by dbmate +.env + +# Go and Nix-related build result +/ncps +/result* +/.direnv + +# Pre-commit configuration (auto-generated) +/.pre-commit-config.yaml + +# Python Cache +__pycache__/ +*.pyc +*.pyo + +# Agents +/.claude/settings.local.json diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..ec518df --- /dev/null +++ b/flake.lock @@ -0,0 +1,129 @@ +{ + "nodes": { + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1767039857, + "narHash": "sha256-vNpUSpF5Nuw8xvDLj2KCwwksIbjua2LZCqhV1LNRDns=", + "owner": "NixOS", + "repo": "flake-compat", + "rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1769996383, + "narHash": "sha256-AnYjnFWgS49RlqX7LrC4uA+sCCDBj0Ry/WOJ5XWAsa0=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "57928607ea566b5db3ad13af0e57e921e6b12381", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "git-hooks-nix": { + "inputs": { + "flake-compat": "flake-compat", + "gitignore": "gitignore", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1770726378, + "narHash": "sha256-kck+vIbGOaM/dHea7aTBxdFYpeUl/jHOy5W3eyRvVx8=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "5eaaedde414f6eb1aea8b8525c466dc37bba95ae", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "git-hooks-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1771574726, + "narHash": "sha256-D1PA3xQv/s4W3lnR9yJFSld8UOLr0a/cBWMQMXS+1Qg=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c217913993d6c6f6805c3b1a3bda5e639adfde6d", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-25.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-parts": "flake-parts", + "git-hooks-nix": "git-hooks-nix", + "nixpkgs": "nixpkgs", + "treefmt-nix": "treefmt-nix" + } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1770228511, + "narHash": "sha256-wQ6NJSuFqAEmIg2VMnLdCnUc0b7vslUohqqGGD+Fyxk=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "337a4fe074be1042a35086f15481d763b8ddc0e7", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..8f51ef8 --- /dev/null +++ b/flake.nix @@ -0,0 +1,40 @@ +{ + description = "TODO"; + + inputs = { + + flake-parts = { + inputs.nixpkgs-lib.follows = "nixpkgs"; + url = "github:hercules-ci/flake-parts"; + }; + + git-hooks-nix = { + inputs.nixpkgs.follows = "nixpkgs"; + url = "github:cachix/git-hooks.nix"; + }; + + nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.11"; + + treefmt-nix = { + inputs.nixpkgs.follows = "nixpkgs"; + url = "github:numtide/treefmt-nix"; + }; + + }; + + outputs = + inputs@{ flake-parts, ... }: + flake-parts.lib.mkFlake { inherit inputs; } { + imports = [ + ./nix/devshells/flake-module.nix + ./nix/formatter/flake-module.nix + ./nix/pre-commit/flake-module.nix + ]; + systems = [ + "x86_64-linux" + "aarch64-linux" + "aarch64-darwin" + "x86_64-darwin" + ]; + }; +} diff --git a/go.mod b/go.mod index fdfc1d8..7c5cb8c 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/kalbasit/sqlc-multi-db -go 1.25.5 +go 1.25.6 require ( github.com/jinzhu/inflection v1.0.0 diff --git a/nix/devshells/flake-module.nix b/nix/devshells/flake-module.nix new file mode 100644 index 0000000..02d0f69 --- /dev/null +++ b/nix/devshells/flake-module.nix @@ -0,0 +1,32 @@ +{ + perSystem = + { + config, + pkgs, + ... + }: + { + devShells.default = pkgs.mkShell { + buildInputs = [ + pkgs.delve + pkgs.go + pkgs.golangci-lint + pkgs.pre-commit + ]; + + _GO_VERSION = "${pkgs.go.version}"; + _DBMATE_VERSION = "${pkgs.dbmate.version}"; + + # Disable hardening for fortify otherwize it's not possible to use Delve. + hardeningDisable = [ "fortify" ]; + + shellHook = '' + ${config.pre-commit.installationScript} + + if [[ "$(${pkgs.gnugrep}/bin/grep '^\(go \)[0-9.]*$' go.mod)" != "go ''${_GO_VERSION}" ]]; then + ${pkgs.gnused}/bin/sed -e "s:^\(go \)[0-9.]*$:\1''${_GO_VERSION}:" -i go.mod + fi + ''; + }; + }; +} diff --git a/nix/formatter/flake-module.nix b/nix/formatter/flake-module.nix new file mode 100644 index 0000000..c7b07ef --- /dev/null +++ b/nix/formatter/flake-module.nix @@ -0,0 +1,27 @@ +{ inputs, ... }: +{ + imports = [ inputs.treefmt-nix.flakeModule ]; + + perSystem = { + treefmt = { + settings.global.excludes = [ + ".agent/skills/**/*.md" + ".agent/workflows/*.md" + ".env" + ".envrc" + "LICENSE" + "renovate.json" + ]; + + programs = { + actionlint.enable = true; + deadnix.enable = true; + gofumpt.enable = true; + mdformat.enable = true; + nixfmt.enable = true; + statix.enable = true; + yamlfmt.enable = true; + }; + }; + }; +} diff --git a/nix/pre-commit/flake-module.nix b/nix/pre-commit/flake-module.nix new file mode 100644 index 0000000..4237ce4 --- /dev/null +++ b/nix/pre-commit/flake-module.nix @@ -0,0 +1,22 @@ +{ inputs, ... }: + +{ + imports = [ + inputs.git-hooks-nix.flakeModule + ]; + + perSystem = { + pre-commit.check.enable = false; + pre-commit.settings.hooks = { + check-merge-conflicts.enable = true; + deadnix.enable = true; + golangci-lint.enable = true; + no-commit-to-branch.enable = true; + no-commit-to-branch.settings.branch = [ "main" ]; + nixfmt-rfc-style.enable = true; + statix.enable = true; + trim-trailing-whitespace.enable = true; + yamlfmt.enable = true; + }; + }; +} From 06a0ccbf7e76402d27b2b6839026153bd979235a Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Sun, 22 Feb 2026 20:51:33 -0800 Subject: [PATCH 4/6] fix lint issues --- .golangci.yml | 4 + generator/constants.go | 23 +++ generator/errors.go | 21 +++ generator/exports.go | 39 +++-- generator/generator.go | 204 +++++++++++---------- generator/generator_test.go | 39 ++++- generator/helpers.go | 341 ++++++++++++++++++++---------------- 7 files changed, 412 insertions(+), 259 deletions(-) create mode 100644 generator/constants.go create mode 100644 generator/errors.go diff --git a/.golangci.yml b/.golangci.yml index 4e25594..09511fa 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -59,6 +59,10 @@ linters: - common-false-positives - legacy - std-error-handling + rules: + - path: generator/templates\.go + linters: + - lll settings: wsl_v5: allow-first-in-block: true diff --git a/generator/constants.go b/generator/constants.go new file mode 100644 index 0000000..00599ba --- /dev/null +++ b/generator/constants.go @@ -0,0 +1,23 @@ +package generator + +const ( + typeQuerier = "Querier" + typeAny = "interface{}" + typeBool = "bool" + typeString = "string" + zeroNil = "nil" + typeInt16 = "int16" + typeInt32 = "int32" + typeInt64 = "int64" + typeFloat64 = "float64" + typeByte = "byte" + + sqlNullString = "sql.NullString" + sqlNullInt64 = "sql.NullInt64" + sqlNullInt32 = "sql.NullInt32" + sqlNullInt16 = "sql.NullInt16" + sqlNullBool = "sql.NullBool" + sqlNullFloat64 = "sql.NullFloat64" + sqlNullTime = "sql.NullTime" + sqlNullByte = "sql.NullByte" +) diff --git a/generator/errors.go b/generator/errors.go new file mode 100644 index 0000000..96eb519 --- /dev/null +++ b/generator/errors.go @@ -0,0 +1,21 @@ +package generator + +import ( + "errors" + "fmt" +) + +var ( + errInvalidDictCall = errors.New("invalid dict call") + errDictKeysMustBeStrings = errors.New("dict keys must be strings") + + errSliceDomainStructNotSupported = errors.New( + "slices of domain structs are not supported as direct parameters, as they require a conversion loop" + + " to be generated. The auto-looping for bulk inserts handles this by operating on a struct" + + " parameter containing a slice", + ) +) + +func errUnsupportedSliceDomainStruct(t string) error { + return fmt.Errorf("unsupported parameter type: slice of domain struct %s: %w", t, errSliceDomainStructNotSupported) +} diff --git a/generator/exports.go b/generator/exports.go index dadc347..39eb54b 100644 --- a/generator/exports.go +++ b/generator/exports.go @@ -1,55 +1,60 @@ package generator +import "go/ast" + // This file exports internal functions for use in tests and by external callers. // ExprToString converts an AST expression to its string representation. -// This is exported for testing purposes. -var ExprToString = exprToString +func ExprToString(expr ast.Expr) string { return exprToString(expr) } // IsDomainStructFunc checks if a type string represents a domain struct. -var IsDomainStructFunc = isDomainStructFunc +func IsDomainStructFunc(t string) bool { return isDomainStructFunc(t) } // ZeroValue returns the zero value expression for a given type string. -var ZeroValue = zeroValue +func ZeroValue(t string) string { return zeroValue(t) } // ExtractBulkFor extracts the @bulk-for annotation value from a comment. -var ExtractBulkFor = extractBulkFor +func ExtractBulkFor(comment string) string { return extractBulkFor(comment) } // ToSingular converts a plural word to singular form. -var ToSingular = toSingular +func ToSingular(s string) string { return toSingular(s) } // JoinParamsSignature joins parameters into a function signature string. -var JoinParamsSignature = joinParamsSignature +func JoinParamsSignature(params []Param) string { return joinParamsSignature(params) } // JoinReturns joins return types into a comma-separated string. -var JoinReturns = joinReturns +func JoinReturns(returns []Return) string { return joinReturns(returns) } // IsSlice checks if a type string represents a slice. -var IsSlice = isSlice +func IsSlice(retType string) bool { return isSlice(retType) } // FirstReturnType returns the first return type from a Returns slice. -var FirstReturnType = firstReturnType +func FirstReturnType(returns []Return) string { return firstReturnType(returns) } // HasParam checks if a parameter with the given name exists. -var HasParam = hasParam +func HasParam(name string, params []Param) bool { return hasParam(name, params) } // ParamHasField checks if a parameter's struct type has a given field. -var ParamHasField = paramHasField +func ParamHasField(paramName, fieldName string, params []Param, structs map[string]StructInfo) bool { + return paramHasField(paramName, fieldName, params, structs) +} // HasSliceField checks if a struct has a slice field. -var HasSliceField = hasSliceField +func HasSliceField(s StructInfo) bool { return hasSliceField(s) } // GetSliceField returns the first slice field of a struct. -var GetSliceField = getSliceField +func GetSliceField(s StructInfo) FieldInfo { return getSliceField(s) } // ToSnakeCase converts a CamelCase string to snake_case. -var ToSnakeCase = toSnakeCase +func ToSnakeCase(s string) string { return toSnakeCase(s) } // Quote wraps a string in engine-appropriate quotes. -var Quote = quote +func Quote(e Engine, s string) string { return quote(e, s) } // GenerateFieldConversion generates field conversion code. -var GenerateFieldConversion = generateFieldConversion +func GenerateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, sourceExpr string) string { + return generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, sourceExpr) +} // WrapperTemplate is the template for generating wrapper files. const WrapperTemplate = wrapperTemplate diff --git a/generator/generator.go b/generator/generator.go index d1094ba..feceb74 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -16,15 +16,15 @@ import ( "github.com/jinzhu/inflection" ) -var engines = []Engine{ - {Name: "sqlite", Package: "sqlitedb"}, - {Name: "postgres", Package: "postgresdb"}, - {Name: "mysql", Package: "mysqldb"}, -} - // Run is the main entry point for the generator. // querierPath is the path to the source querier.go file (e.g., postgresdb/querier.go). func Run(querierPath string) { + engines := []Engine{ + {Name: "sqlite", Package: "sqlitedb"}, + {Name: "postgres", Package: "postgresdb"}, + {Name: "mysql", Package: "mysqldb"}, + } + absQuerierPath, err := filepath.Abs(querierPath) if err != nil { log.Fatalf("resolving querier path: %v", err) @@ -55,7 +55,7 @@ func Run(querierPath string) { } } - var sortedStructs []StructInfo + sortedStructs := make([]StructInfo, 0, len(usedStructNames)) for name := range usedStructNames { sortedStructs = append(sortedStructs, sourceData.Structs[name]) } @@ -139,8 +139,102 @@ func Run(querierPath string) { // 7. Generate wrappers for _, engine := range engines { - generateWrapper(targetDir, packageName, importBase, engine, sourceData.Methods, sourceData.Structs, engineData[engine.Name]) + generateWrapper( + targetDir, packageName, importBase, engine, + sourceData.Methods, sourceData.Structs, engineData[engine.Name], + ) + } +} + +func parseQuerierInterface(typeSpec *ast.TypeSpec) ([]MethodInfo, bool) { + if typeSpec.Name.Name != typeQuerier { + return nil, false + } + + interfaceType, ok := typeSpec.Type.(*ast.InterfaceType) + if !ok { + return nil, false + } + + methods := make([]MethodInfo, 0, len(interfaceType.Methods.List)) + + for _, field := range interfaceType.Methods.List { + m := MethodInfo{Name: field.Names[0].Name} + if field.Doc != nil { + for _, comment := range field.Doc.List { + m.Docs = append(m.Docs, comment.Text) + if strings.Contains(comment.Text, "@bulk-for") { + if bulkFor := extractBulkFor(comment.Text); bulkFor != "" { + m.BulkFor = bulkFor + } + } + } + } + + funcType := field.Type.(*ast.FuncType) + for _, param := range funcType.Params.List { + typeStr := exprToString(param.Type) + for _, name := range param.Names { + m.Params = append(m.Params, Param{Name: name.Name, Type: typeStr}) + } + } + + if funcType.Results != nil { + for _, res := range funcType.Results.List { + typeStr := exprToString(res.Type) + + m.Returns = append(m.Returns, Return{Type: typeStr}) + switch typeStr { + case "error": + m.ReturnsError = true + case typeQuerier: + m.ReturnsSelf = true + m.HasValue = true + default: + m.HasValue = true + m.ReturnElem = strings.TrimPrefix(typeStr, "[]") + } + } + } + + m.IsCreate = strings.HasPrefix(m.Name, "Create") && isDomainStruct(m.ReturnElem) + m.IsUpdate = strings.HasPrefix(m.Name, "Update") && isDomainStruct(m.ReturnElem) + methods = append(methods, m) + } + + return methods, true +} + +func parseStructType(typeSpec *ast.TypeSpec, structType *ast.StructType) StructInfo { + s := StructInfo{Name: typeSpec.Name.Name} + + if structType.Fields == nil { + return s + } + + for _, field := range structType.Fields.List { + typeStr := exprToString(field.Type) + tag := "" + + if field.Tag != nil { + unquoted, err := strconv.Unquote(field.Tag.Value) + if err != nil { + log.Fatalf("failed to unquote struct tag %s: %v", field.Tag.Value, err) + } + + tag = unquoted + } + + if len(field.Names) > 0 { + for _, name := range field.Names { + s.Fields = append(s.Fields, FieldInfo{Name: name.Name, Type: typeStr, Tag: tag}) + } + } else { + s.Fields = append(s.Fields, FieldInfo{Name: "", Type: typeStr, Tag: tag}) + } } + + return s } func parsePackage(dir string) PackageData { @@ -151,7 +245,7 @@ func parsePackage(dir string) PackageData { log.Fatal(err) } - var methods []MethodInfo + methods := make([]MethodInfo, 0, 32) structs := make(map[string]StructInfo) @@ -163,84 +257,12 @@ func parsePackage(dir string) PackageData { return true } - if typeSpec.Name.Name == "Querier" { - interfaceType, ok := typeSpec.Type.(*ast.InterfaceType) - if !ok { - return true - } - - for _, field := range interfaceType.Methods.List { - m := MethodInfo{Name: field.Names[0].Name} - if field.Doc != nil { - for _, comment := range field.Doc.List { - m.Docs = append(m.Docs, comment.Text) - if strings.Contains(comment.Text, "@bulk-for") { - if bulkFor := extractBulkFor(comment.Text); bulkFor != "" { - m.BulkFor = bulkFor - } - } - } - } - - funcType := field.Type.(*ast.FuncType) - for _, param := range funcType.Params.List { - typeStr := exprToString(param.Type) - for _, name := range param.Names { - m.Params = append(m.Params, Param{Name: name.Name, Type: typeStr}) - } - } - - if funcType.Results != nil { - for _, res := range funcType.Results.List { - typeStr := exprToString(res.Type) - - m.Returns = append(m.Returns, Return{Type: typeStr}) - switch typeStr { - case "error": - m.ReturnsError = true - case "Querier": - m.ReturnsSelf = true - m.HasValue = true - default: - m.HasValue = true - m.ReturnElem = strings.TrimPrefix(typeStr, "[]") - } - } - } - - m.IsCreate = strings.HasPrefix(m.Name, "Create") && isDomainStruct(m.ReturnElem) - m.IsUpdate = strings.HasPrefix(m.Name, "Update") && isDomainStruct(m.ReturnElem) - methods = append(methods, m) - } + if querierMethods, matched := parseQuerierInterface(typeSpec); matched { + methods = append(methods, querierMethods...) } if structType, ok := typeSpec.Type.(*ast.StructType); ok { - s := StructInfo{Name: typeSpec.Name.Name} - - if structType.Fields != nil { - for _, field := range structType.Fields.List { - typeStr := exprToString(field.Type) - tag := "" - - if field.Tag != nil { - unquoted, err := strconv.Unquote(field.Tag.Value) - if err != nil { - log.Fatalf("failed to unquote struct tag %s: %v", field.Tag.Value, err) - } - - tag = unquoted - } - - if len(field.Names) > 0 { - for _, name := range field.Names { - s.Fields = append(s.Fields, FieldInfo{Name: name.Name, Type: typeStr, Tag: tag}) - } - } else { - s.Fields = append(s.Fields, FieldInfo{Name: "", Type: typeStr, Tag: tag}) - } - } - } - + s := parseStructType(typeSpec, structType) structs[s.Name] = s } @@ -306,7 +328,13 @@ func generateErrors(dir, packageName string) { writeFile(dir, generatedFilePrefix+"errors.go", buf.Bytes()) } -func generateWrapper(dir, packageName, importBase string, engine Engine, methods []MethodInfo, structs map[string]StructInfo, engData PackageData) { +func generateWrapper( + dir, packageName, importBase string, + engine Engine, + methods []MethodInfo, + structs map[string]StructInfo, + engData PackageData, +) { t := template.Must(template.New("wrapper").Funcs(template.FuncMap{ "joinParamsSignature": joinParamsSignature, "joinReturns": joinReturns, @@ -352,14 +380,14 @@ func generateWrapper(dir, packageName, importBase string, engine Engine, methods }, "dict": func(values ...interface{}) (map[string]interface{}, error) { if len(values)%2 != 0 { - return nil, fmt.Errorf("invalid dict call") + return nil, errInvalidDictCall } dict := make(map[string]interface{}, len(values)/2) for i := 0; i < len(values); i += 2 { key, ok := values[i].(string) if !ok { - return nil, fmt.Errorf("dict keys must be strings") + return nil, errDictKeysMustBeStrings } dict[key] = values[i+1] @@ -456,7 +484,7 @@ func exprToString(expr ast.Expr) string { case *ast.ArrayType: return "[]" + exprToString(t.Elt) case *ast.InterfaceType: - return "interface{}" + return typeAny default: panic(fmt.Sprintf("unhandled expression type: %T", t)) } diff --git a/generator/generator_test.go b/generator/generator_test.go index 427f7fd..dd8cb45 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -2,7 +2,7 @@ package generator_test import ( "bytes" - "fmt" + "errors" "go/ast" "strings" "testing" @@ -11,7 +11,14 @@ import ( "github.com/kalbasit/sqlc-multi-db/generator" ) +var ( + errTestInvalidDictCall = errors.New("invalid dict call") + errTestDictKeysMustBeStrings = errors.New("dict keys must be strings") +) + func TestExprToString(t *testing.T) { + t.Parallel() + tests := []struct { name string expr ast.Expr @@ -47,6 +54,8 @@ func TestExprToString(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + defer func() { if r := recover(); r != nil { if !tt.panics { @@ -66,6 +75,8 @@ func TestExprToString(t *testing.T) { } func TestIsDomainStructFunc(t *testing.T) { + t.Parallel() + tests := []struct { inputType string want bool @@ -88,6 +99,8 @@ func TestIsDomainStructFunc(t *testing.T) { } func TestZeroValue(t *testing.T) { + t.Parallel() + tests := []struct { typeName string want string @@ -109,6 +122,8 @@ func TestZeroValue(t *testing.T) { } func TestExtractBulkFor(t *testing.T) { + t.Parallel() + tests := []struct { comment string want string @@ -128,6 +143,8 @@ func TestExtractBulkFor(t *testing.T) { } func TestToSingular(t *testing.T) { + t.Parallel() + tests := []struct { input string want string @@ -147,6 +164,8 @@ func TestToSingular(t *testing.T) { } func TestJoinParamsCall(t *testing.T) { + t.Parallel() + tests := []struct { name string params []generator.Param @@ -183,6 +202,8 @@ func TestJoinParamsCall(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := generator.JoinParamsCall(tt.params, tt.engPkg, generator.MethodInfo{}, nil, nil) if (err != nil) != tt.wantErr { t.Errorf("JoinParamsCall() error = %v, wantErr %v", err, tt.wantErr) @@ -198,6 +219,8 @@ func TestJoinParamsCall(t *testing.T) { } func TestWrapperTemplate(t *testing.T) { + t.Parallel() + // Mock engines sqlite := generator.Engine{Name: "sqlite", Package: "sqlitedb"} @@ -246,14 +269,14 @@ func TestWrapperTemplate(t *testing.T) { "trimPrefix": strings.TrimPrefix, "dict": func(values ...interface{}) (map[string]interface{}, error) { if len(values)%2 != 0 { - return nil, fmt.Errorf("invalid dict call") + return nil, errTestInvalidDictCall } dict := make(map[string]interface{}, len(values)/2) for i := 0; i < len(values); i += 2 { key, ok := values[i].(string) if !ok { - return nil, fmt.Errorf("dict keys must be strings") + return nil, errTestDictKeysMustBeStrings } dict[key] = values[i+1] @@ -302,7 +325,7 @@ func TestWrapperTemplate(t *testing.T) { return "0" }, - "getTableName": func(structName string) string { return "users" }, + "getTableName": func(_ string) string { return "users" }, } tmpl, err := template.New("wrapper").Funcs(funcMap).Parse(generator.WrapperTemplate) @@ -431,7 +454,7 @@ func TestWrapperTemplate(t *testing.T) { return generator.MethodInfo{} } - funcMap["joinParamsCall"] = func(params []generator.Param, engPkg string, targetMethodName string) (string, error) { + funcMap["joinParamsCall"] = func(params []generator.Param, engPkg string, _ string) (string, error) { targetMethod := generator.MethodInfo{ Name: "CreateUser", Params: []generator.Param{ @@ -443,7 +466,7 @@ func TestWrapperTemplate(t *testing.T) { return generator.JoinParamsCall(params, engPkg, targetMethod, structs, structs) } - funcMap["getTableName"] = func(structName string) string { return "users" } + funcMap["getTableName"] = func(_ string) string { return "users" } tmpl, err = template.New("wrapper").Funcs(funcMap).Parse(generator.WrapperTemplate) if err != nil { @@ -467,6 +490,8 @@ func TestWrapperTemplate(t *testing.T) { } func TestGenerateFieldConversion(t *testing.T) { + t.Parallel() + tests := []struct { name string targetFieldName string @@ -519,6 +544,8 @@ func TestGenerateFieldConversion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := generator.GenerateFieldConversion(tt.targetFieldName, tt.targetFieldType, tt.sourceFieldType, tt.sourceExpr) if got != tt.want { t.Errorf("GenerateFieldConversion() = %v, want %v", got, tt.want) diff --git a/generator/helpers.go b/generator/helpers.go index 582c9d4..0be4f0f 100644 --- a/generator/helpers.go +++ b/generator/helpers.go @@ -13,7 +13,7 @@ import ( ) func toSnakeCase(s string) string { - var res []rune + res := make([]rune, 0, len(s)) for i, r := range s { if i > 0 && r >= 'A' && r <= 'Z' { @@ -69,7 +69,7 @@ func writeFile(dir, filename string, content []byte) { log.Fatalf("formatting %s: %v", filename, err) } - if err := os.WriteFile(filepath.Join(dir, filename), formatted, 0o644); err != nil { + if err := os.WriteFile(filepath.Join(dir, filename), formatted, 0o644); err != nil { //nolint:gosec log.Fatal(err) } @@ -114,7 +114,7 @@ func paramHasField(paramName string, fieldName string, params []Param, structs m } func joinParamsSignature(params []Param) string { - var p []string + p := make([]string, 0, len(params)) for _, param := range params { p = append(p, fmt.Sprintf("%s %s", param.Name, param.Type)) } @@ -123,70 +123,102 @@ func joinParamsSignature(params []Param) string { } // JoinParamsCall is exported for use in tests. -func JoinParamsCall(params []Param, engPkg string, targetMethod MethodInfo, targetStructs map[string]StructInfo, sourceStructs map[string]StructInfo) (string, error) { +func JoinParamsCall( + params []Param, + engPkg string, + targetMethod MethodInfo, + targetStructs map[string]StructInfo, + sourceStructs map[string]StructInfo, +) (string, error) { return joinParamsCall(params, engPkg, targetMethod, targetStructs, sourceStructs) } -func joinParamsCall(params []Param, engPkg string, targetMethod MethodInfo, targetStructs map[string]StructInfo, sourceStructs map[string]StructInfo) (string, error) { - var p []string - - for i, param := range params { - if isDomainStructFunc(param.Type) { - if strings.HasPrefix(param.Type, "[]") { - return "", fmt.Errorf("unsupported parameter type: slice of domain struct %s. Slices of domain structs are not supported as direct parameters, as they require a conversion loop to be generated. The auto-looping for bulk inserts handles this by operating on a struct parameter containing a slice.", param.Type) - } else { - targetParamType := "" - if i < len(targetMethod.Params) { - targetParamType = targetMethod.Params[i].Type - } - - if targetParamType != "" { - sourceStruct := sourceStructs[param.Type] - targetStruct := targetStructs[targetParamType] +func joinDomainStructParam( + param Param, + i int, + engPkg string, + targetMethod MethodInfo, + targetStructs map[string]StructInfo, + sourceStructs map[string]StructInfo, +) (string, error) { + if strings.HasPrefix(param.Type, "[]") { + return "", errUnsupportedSliceDomainStruct(param.Type) + } - var fields []string + targetParamType := "" + if i < len(targetMethod.Params) { + targetParamType = targetMethod.Params[i].Type + } - for _, targetField := range targetStruct.Fields { - var sourceField FieldInfo + if targetParamType != "" { + sourceStruct := sourceStructs[param.Type] + targetStruct := targetStructs[targetParamType] - found := false + var fields []string - for _, sf := range sourceStruct.Fields { - if sf.Name == targetField.Name { - sourceField = sf - found = true + for _, targetField := range targetStruct.Fields { + var sourceField FieldInfo - break - } - } + found := false - if found { - conversion := generateFieldConversion( - targetField.Name, - targetField.Type, - sourceField.Type, - fmt.Sprintf("%s.%s", param.Name, sourceField.Name), - ) - fields = append(fields, conversion) - } - } + for _, sf := range sourceStruct.Fields { + if sf.Name == targetField.Name { + sourceField = sf + found = true - p = append(p, fmt.Sprintf("%s.%s{\n%s,\n}", engPkg, targetParamType, strings.Join(fields, ",\n"))) - } else { - p = append(p, fmt.Sprintf("%s.%s(%s)", engPkg, param.Type, param.Name)) + break } } - } else { - targetParamType := "" - if i < len(targetMethod.Params) { - targetParamType = targetMethod.Params[i].Type + + if found { + conversion := generateFieldConversion( + targetField.Name, + targetField.Type, + sourceField.Type, + fmt.Sprintf("%s.%s", param.Name, sourceField.Name), + ) + fields = append(fields, conversion) } + } + + return fmt.Sprintf("%s.%s{\n%s,\n}", engPkg, targetParamType, strings.Join(fields, ",\n")), nil + } + + return fmt.Sprintf("%s.%s(%s)", engPkg, param.Type, param.Name), nil +} + +func joinNonDomainParam(param Param, i int, targetMethod MethodInfo) string { + targetParamType := "" + if i < len(targetMethod.Params) { + targetParamType = targetMethod.Params[i].Type + } - if targetParamType != "" && targetParamType != param.Type { - p = append(p, fmt.Sprintf("%s(%s)", targetParamType, param.Name)) - } else { - p = append(p, param.Name) + if targetParamType != "" && targetParamType != param.Type { + return fmt.Sprintf("%s(%s)", targetParamType, param.Name) + } + + return param.Name +} + +func joinParamsCall( + params []Param, + engPkg string, + targetMethod MethodInfo, + targetStructs map[string]StructInfo, + sourceStructs map[string]StructInfo, +) (string, error) { + p := make([]string, 0, len(params)) + + for i, param := range params { + if isDomainStructFunc(param.Type) { + result, err := joinDomainStructParam(param, i, engPkg, targetMethod, targetStructs, sourceStructs) + if err != nil { + return "", err } + + p = append(p, result) + } else { + p = append(p, joinNonDomainParam(param, i, targetMethod)) } } @@ -194,7 +226,7 @@ func joinParamsCall(params []Param, engPkg string, targetMethod MethodInfo, targ } func joinReturns(returns []Return) string { - var r []string + r := make([]string, 0, len(returns)) for _, ret := range returns { r = append(r, ret.Type) } @@ -218,7 +250,7 @@ func firstReturnType(returns []Return) string { func isDomainStructFunc(t string) bool { t = strings.TrimPrefix(t, "[]") - return len(t) > 0 && t[0] >= 'A' && t[0] <= 'Z' && !strings.Contains(t, ".") && t != "Querier" + return len(t) > 0 && t[0] >= 'A' && t[0] <= 'Z' && !strings.Contains(t, ".") && t != typeQuerier } // isDomainStruct is used during parsing, same logic. @@ -232,20 +264,20 @@ func zeroValue(t string) string { } switch t { - case "bool": + case typeBool: return "false" - case "string": + case typeString: return `""` case "error": - return "nil" + return zeroNil } - if strings.HasPrefix(t, "*") || strings.HasPrefix(t, "[]") || strings.HasPrefix(t, "map[") || t == "interface{}" { - return "nil" + if strings.HasPrefix(t, "*") || strings.HasPrefix(t, "[]") || strings.HasPrefix(t, "map[") || t == typeAny { + return zeroNil } - if t == "sql.Result" || t == "Querier" { - return "nil" + if t == "sql.Result" || t == typeQuerier { + return zeroNil } return fmt.Sprintf("%s{}", t) @@ -253,13 +285,13 @@ func zeroValue(t string) string { func isNumeric(t string) bool { switch t { - case "int", "int8", "int16", "int32", "int64": + case "int", "int8", typeInt16, typeInt32, typeInt64: return true case "uint", "uint8", "uint16", "uint32", "uint64": return true - case "float32", "float64", "complex64", "complex128": + case "float32", typeFloat64, "complex64", "complex128": return true - case "byte", "rune": + case typeByte, "rune": return true } @@ -267,35 +299,31 @@ func isNumeric(t string) bool { } func isStructType(t string) bool { - if strings.HasPrefix(t, "sql.Null") { - return true - } - - return false + return strings.HasPrefix(t, "sql.Null") } -func isSqlNullType(t string) bool { +func isSQLNullType(t string) bool { return strings.HasPrefix(t, "sql.Null") } func getPrimitiveFromNullType(t string) string { switch t { - case "sql.NullString": - return "string" - case "sql.NullInt64": - return "int64" - case "sql.NullInt32": - return "int32" - case "sql.NullInt16": - return "int16" - case "sql.NullBool": - return "bool" - case "sql.NullFloat64": - return "float64" - case "sql.NullTime": + case sqlNullString: + return typeString + case sqlNullInt64: + return typeInt64 + case sqlNullInt32: + return typeInt32 + case sqlNullInt16: + return typeInt16 + case sqlNullBool: + return typeBool + case sqlNullFloat64: + return typeFloat64 + case sqlNullTime: return "time.Time" - case "sql.NullByte": - return "byte" + case sqlNullByte: + return typeByte default: return "" } @@ -303,22 +331,22 @@ func getPrimitiveFromNullType(t string) string { func getNullTypeFromPrimitive(t string) string { switch t { - case "string": - return "sql.NullString" - case "int64": - return "sql.NullInt64" - case "int32": - return "sql.NullInt32" - case "int16": - return "sql.NullInt16" - case "bool": - return "sql.NullBool" - case "float64": - return "sql.NullFloat64" + case typeString: + return sqlNullString + case typeInt64: + return sqlNullInt64 + case typeInt32: + return sqlNullInt32 + case typeInt16: + return sqlNullInt16 + case typeBool: + return sqlNullBool + case typeFloat64: + return sqlNullFloat64 case "time.Time": - return "sql.NullTime" - case "byte": - return "sql.NullByte" + return sqlNullTime + case typeByte: + return sqlNullByte default: return "" } @@ -326,21 +354,21 @@ func getNullTypeFromPrimitive(t string) string { func getFieldNameForNullType(t string) string { switch t { - case "sql.NullString": + case sqlNullString: return "String" - case "sql.NullInt64": + case sqlNullInt64: return "Int64" - case "sql.NullInt32": + case sqlNullInt32: return "Int32" - case "sql.NullInt16": + case sqlNullInt16: return "Int16" - case "sql.NullBool": + case sqlNullBool: return "Bool" - case "sql.NullFloat64": + case sqlNullFloat64: return "Float64" - case "sql.NullTime": + case sqlNullTime: return "Time" - case "sql.NullByte": + case sqlNullByte: return "Byte" default: return "" @@ -355,7 +383,7 @@ func generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, } // Case 4: Both are sql.Null* types but different - if isSqlNullType(sourceFieldType) && isSqlNullType(targetFieldType) { + if isSQLNullType(sourceFieldType) && isSQLNullType(targetFieldType) { sourcePrimitive := getPrimitiveFromNullType(sourceFieldType) targetPrimitive := getPrimitiveFromNullType(targetFieldType) @@ -364,15 +392,23 @@ func generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, targetValueFieldName := getFieldNameForNullType(targetFieldType) if sourcePrimitive == targetPrimitive { - return fmt.Sprintf("%s: %s{%s: %s.%s, Valid: %s.Valid}", targetFieldName, targetFieldType, targetValueFieldName, sourceExpr, sourceFieldName, sourceExpr) - } else { - return fmt.Sprintf("%s: %s{%s: %s(%s.%s), Valid: %s.Valid}", targetFieldName, targetFieldType, targetValueFieldName, targetPrimitive, sourceExpr, sourceFieldName, sourceExpr) + return fmt.Sprintf( + "%s: %s{%s: %s.%s, Valid: %s.Valid}", + targetFieldName, targetFieldType, targetValueFieldName, + sourceExpr, sourceFieldName, sourceExpr, + ) } + + return fmt.Sprintf( + "%s: %s{%s: %s(%s.%s), Valid: %s.Valid}", + targetFieldName, targetFieldType, targetValueFieldName, + targetPrimitive, sourceExpr, sourceFieldName, sourceExpr, + ) } } // Case 2: Converting from primitive to sql.Null* (skip interface{} — handled by Case 5b) - if isSqlNullType(targetFieldType) && sourceFieldType != "interface{}" { + if isSQLNullType(targetFieldType) && sourceFieldType != typeAny { expectedPrimitive := getPrimitiveFromNullType(targetFieldType) if expectedPrimitive == sourceFieldType { fieldName := getFieldNameForNullType(targetFieldType) @@ -381,12 +417,15 @@ func generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, } else if expectedPrimitive != "" { fieldName := getFieldNameForNullType(targetFieldType) - return fmt.Sprintf("%s: %s{%s: %s(%s), Valid: true}", targetFieldName, targetFieldType, fieldName, expectedPrimitive, sourceExpr) + return fmt.Sprintf( + "%s: %s{%s: %s(%s), Valid: true}", + targetFieldName, targetFieldType, fieldName, expectedPrimitive, sourceExpr, + ) } } // Case 3: Converting from sql.Null* to primitive - if isSqlNullType(sourceFieldType) { + if isSQLNullType(sourceFieldType) { primitive := getPrimitiveFromNullType(sourceFieldType) if primitive == targetFieldType { fieldName := getFieldNameForNullType(sourceFieldType) @@ -405,13 +444,14 @@ func generateFieldConversion(targetFieldName, targetFieldType, sourceFieldType, } // Case 5b: interface{} source → sql.Null* target (SQLite nullable columns come as interface{}) - if sourceFieldType == "interface{}" && isSqlNullType(targetFieldType) { + if sourceFieldType == typeAny && isSQLNullType(targetFieldType) { primitive := getPrimitiveFromNullType(targetFieldType) fieldName := getFieldNameForNullType(targetFieldType) if primitive != "" && fieldName != "" { return fmt.Sprintf( - "%s: func() %s { if %s == nil { return %s{} }; v, ok := %s.(%s); if !ok { return %s{} }; return %s{%s: v, Valid: true} }()", + "%s: func() %s { if %s == nil { return %s{} }; v, ok := %s.(%s); if !ok { return %s{} };"+ + " return %s{%s: v, Valid: true} }()", targetFieldName, targetFieldType, sourceExpr, targetFieldType, sourceExpr, primitive, @@ -444,44 +484,49 @@ func getSliceField(s StructInfo) FieldInfo { return FieldInfo{} } -// findImportBase walks up from targetDir to find the nearest go.mod and computes -// the full import path for targetDir. -func findImportBase(targetDir string) string { - dir := targetDir - for { - goModPath := filepath.Join(dir, "go.mod") - if _, err := os.Stat(goModPath); err == nil { - // Found go.mod — read module name - data, err := os.ReadFile(goModPath) - if err != nil { - log.Fatalf("reading go.mod at %s: %v", goModPath, err) - } +func parseGoMod(goModPath, targetDir string) string { + data, err := os.ReadFile(goModPath) + if err != nil { + log.Fatalf("reading go.mod at %s: %v", goModPath, err) + } - moduleName := "" + moduleName := "" - for _, line := range strings.Split(string(data), "\n") { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "module ") { - moduleName = strings.TrimSpace(strings.TrimPrefix(line, "module ")) + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "module ") { + moduleName = strings.TrimSpace(strings.TrimPrefix(line, "module ")) - break - } - } + break + } + } - if moduleName == "" { - log.Fatalf("could not find module directive in %s", goModPath) - } + if moduleName == "" { + log.Fatalf("could not find module directive in %s", goModPath) + } - relPath, err := filepath.Rel(dir, targetDir) - if err != nil { - log.Fatalf("computing relative path: %v", err) - } + dir := filepath.Dir(goModPath) - if relPath == "." { - return moduleName - } + relPath, err := filepath.Rel(dir, targetDir) + if err != nil { + log.Fatalf("computing relative path: %v", err) + } + + if relPath == "." { + return moduleName + } - return moduleName + "/" + relPath + return moduleName + "/" + relPath +} + +// findImportBase walks up from targetDir to find the nearest go.mod and computes +// the full import path for targetDir. +func findImportBase(targetDir string) string { + dir := targetDir + for { + goModPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(goModPath); err == nil { + return parseGoMod(goModPath, targetDir) } parent := filepath.Dir(dir) From 636a1b14849c8dff9e1a3c89a4e395b1779e4621 Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Sun, 22 Feb 2026 20:57:36 -0800 Subject: [PATCH 5/6] add workflow --- .github/workflows/flake-update.yml | 49 ++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 .github/workflows/flake-update.yml diff --git a/.github/workflows/flake-update.yml b/.github/workflows/flake-update.yml new file mode 100644 index 0000000..481d04f --- /dev/null +++ b/.github/workflows/flake-update.yml @@ -0,0 +1,49 @@ +name: "Flake.lock: update Nix dependencies" +on: + workflow_dispatch: # allows manual triggering + schedule: + - cron: "0 0 * * 0" # runs weekly on Sunday at 00:00 +jobs: + nix-flake-update: + permissions: + contents: write + id-token: write + issues: write + pull-requests: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + token: ${{ secrets.GHA_PAT_TOKEN }} + - uses: cachix/install-nix-action@v31 + - uses: cachix/cachix-action@v16 + with: + name: kalbasit + authToken: ${{ secrets.CACHIX_AUTH_TOKEN }} + - name: update flake.lock and run go mod tidy + id: update-flake-lock-and-go-mod-tidy + run: | + nix flake update + nix develop --command go mod tidy + - uses: EndBug/add-and-commit@v9 + if: ${{ steps.update-flake-lock-and-go-mod-tidy.outcome == 'success' }} + id: commit + with: + default_author: github_actions + message: "chore: update flake.lock and run go mod tidy" + fetch: false + new_branch: "update-flake-lock" + push: --set-upstream origin "update-flake-lock" --force + - uses: thomaseizinger/create-pull-request@1.4.0 + if: ${{ steps.commit.outputs.pushed == 'true' }} + id: create_pr + with: + github_token: ${{ secrets.GHA_PAT_TOKEN }} + head: "update-flake-lock" + base: main + title: "chore: update flake.lock and run go mod tidy" + - name: enable automerge + if: ${{ steps.create_pr.outputs.created }} + run: gh pr merge --squash --auto "${{ steps.create_pr.outputs.number }}" + env: + GH_TOKEN: "${{ secrets.GHA_PAT_TOKEN }}" From 6e513ab9490eb296ed08ab3a035c5ab44a357069 Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Sun, 22 Feb 2026 21:03:24 -0800 Subject: [PATCH 6/6] fix(generator): skip embedded interfaces in parseQuerierInterface --- generator/generator.go | 4 ++++ generator/helpers.go | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/generator/generator.go b/generator/generator.go index feceb74..5ff4fed 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -159,6 +159,10 @@ func parseQuerierInterface(typeSpec *ast.TypeSpec) ([]MethodInfo, bool) { methods := make([]MethodInfo, 0, len(interfaceType.Methods.List)) for _, field := range interfaceType.Methods.List { + if len(field.Names) == 0 { + continue + } + m := MethodInfo{Name: field.Names[0].Name} if field.Doc != nil { for _, comment := range field.Doc.List { diff --git a/generator/helpers.go b/generator/helpers.go index 0be4f0f..84582a6 100644 --- a/generator/helpers.go +++ b/generator/helpers.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "strings" + "unicode" "github.com/jinzhu/inflection" "golang.org/x/tools/imports" @@ -24,7 +25,7 @@ func toSnakeCase(s string) string { } } - res = append(res, []rune(strings.ToLower(string(r)))[0]) + res = append(res, unicode.ToLower(r)) } return string(res)