diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 19d0d02..3a161f1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,7 +8,10 @@ on: jobs: test: - runs-on: ubuntu-latest + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 diff --git a/main.go b/main.go index 43d3074..65b40ce 100644 --- a/main.go +++ b/main.go @@ -15,7 +15,7 @@ import ( "github.com/houqp/sqlvet/pkg/vet" ) -const version = "1.1.10" +const version = "1.1.11" var ( gitCommit = "?" diff --git a/pkg/vet/gosource.go b/pkg/vet/gosource.go index e171b88..40c7656 100644 --- a/pkg/vet/gosource.go +++ b/pkg/vet/gosource.go @@ -30,6 +30,20 @@ var ( ErrQueryArgTODO = errors.New("TODO: support this type") ) +const ( + sqlxLib = "github.com/jmoiron/sqlx" + dbSqlLib = "database/sql" + gormLib = "github.com/jinzhu/gorm" + goGorpLib = "go-gorp/gorp" + gorpV1Lib = "gopkg.in/gorp.v1" + + queryArgName = "query" + sqlArgName = "sql" + + rebindMethodName = "Rebind" + rebindxMethodName = "Rebindx" +) + type QuerySite struct { Called string Position token.Position @@ -100,10 +114,17 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc { sqlfuncs := []MatchedSqlFunc{} s.IterPackageExportedFuncs(func(fobj *types.Func) { + ssaFunc := prog.FuncValue(fobj) + + // Skip pass-through functions that shouldn't be validated as SQL functions + if isPassThroughFunc(ssaFunc) { + return + } + for _, rule := range s.Rules { if rule.FuncName != "" && fobj.Name() == rule.FuncName { sqlfuncs = append(sqlfuncs, MatchedSqlFunc{ - SSA: prog.FuncValue(fobj), + SSA: ssaFunc, QueryArgPos: rule.QueryArgPos, }) // callable matched one rule, no need to go through the rest @@ -120,7 +141,7 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc { continue } sqlfuncs = append(sqlfuncs, MatchedSqlFunc{ - SSA: prog.FuncValue(fobj), + SSA: ssaFunc, QueryArgPos: rule.QueryArgPos, }) // callable matched one rule, no need to go through the rest @@ -132,13 +153,29 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc { return sqlfuncs } +// isNamedQueryFunc checks if a function name is a "named query" function +// that expects named parameters (like :param) instead of positional ($1, $2) +func isNamedQueryFunc(funcName string) bool { + // Check for sqlx named query functions + switch funcName { + case "NamedExec", "NamedQuery", "NamedExecContext", "NamedQueryContext", + "NamedQueryRow", "NamedQueryRowContext": + return true + } + // Also check if the function name contains "Named" (catches custom wrappers) + return strings.Contains(funcName, "Named") +} + func handleQuery(ctx VetContext, qs *QuerySite) { - // TODO: apply named query resolution based on v.X type and v.Sel.Name - // e.g. for sqlx, only apply to NamedExec and NamedQuery - qs.Query, _, qs.Err = parseutil.CompileNamedQuery( - []byte(qs.Query), parseutil.BindType("postgres")) - if qs.Err != nil { - return + // Only apply named query resolution for named query functions + // (e.g., NamedExec, NamedQuery, NamedExecContext, NamedQueryContext) + // to avoid breaking PostgreSQL type casts (::) in regular queries + if isNamedQueryFunc(qs.Called) { + qs.Query, _, qs.Err = parseutil.CompileNamedQuery( + []byte(qs.Query), parseutil.BindType("postgres")) + if qs.Err != nil { + return + } } var queryParams []QueryParam @@ -160,31 +197,31 @@ func handleQuery(ctx VetContext, qs *QuerySite) { func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher { matchers := []*SqlFuncMatcher{ { - PkgPath: "github.com/jmoiron/sqlx", + PkgPath: sqlxLib, Rules: []SqlFuncMatchRule{ - {QueryArgName: "query"}, - {QueryArgName: "sql"}, + {QueryArgName: queryArgName}, + {QueryArgName: sqlArgName}, // for methods with Context suffix - {QueryArgName: "query", QueryArgPos: 1}, - {QueryArgName: "sql", QueryArgPos: 1}, - {QueryArgName: "query", QueryArgPos: 2}, - {QueryArgName: "sql", QueryArgPos: 2}, + {QueryArgName: queryArgName, QueryArgPos: 1}, + {QueryArgName: sqlArgName, QueryArgPos: 1}, + {QueryArgName: queryArgName, QueryArgPos: 2}, + {QueryArgName: sqlArgName, QueryArgPos: 2}, }, }, { - PkgPath: "database/sql", + PkgPath: dbSqlLib, Rules: []SqlFuncMatchRule{ - {QueryArgName: "query"}, - {QueryArgName: "sql"}, + {QueryArgName: queryArgName}, + {QueryArgName: sqlArgName}, // for methods with Context suffix - {QueryArgName: "query", QueryArgPos: 1}, - {QueryArgName: "sql", QueryArgPos: 1}, + {QueryArgName: queryArgName, QueryArgPos: 1}, + {QueryArgName: sqlArgName, QueryArgPos: 1}, }, }, { - PkgPath: "github.com/jinzhu/gorm", + PkgPath: gormLib, Rules: []SqlFuncMatchRule{ - {QueryArgName: "sql"}, + {QueryArgName: sqlArgName}, }, }, // TODO: xorm uses vararg, which is not supported yet @@ -201,15 +238,15 @@ func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher { // }, // }, { - PkgPath: "go-gorp/gorp", + PkgPath: goGorpLib, Rules: []SqlFuncMatchRule{ - {QueryArgName: "query"}, + {QueryArgName: queryArgName}, }, }, { - PkgPath: "gopkg.in/gorp.v1", + PkgPath: gorpV1Lib, Rules: []SqlFuncMatchRule{ - {QueryArgName: "query"}, + {QueryArgName: queryArgName}, }, }, } @@ -240,7 +277,7 @@ func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error) } dirAbs, err := filepath.Abs(dir) if err != nil { - return nil, fmt.Errorf("Invalid path: %w", err) + return nil, fmt.Errorf("invalid path: %w", err) } pkgPath := dirAbs + "/..." pkgs, err := packages.Load(cfg, pkgPath) @@ -250,12 +287,54 @@ func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error) // return early if any syntax error for _, pkg := range pkgs { if len(pkg.Errors) > 0 { - return nil, fmt.Errorf("Failed to load package, %w", pkg.Errors[0]) + return nil, fmt.Errorf("failed to load package, %w", pkg.Errors[0]) } } return pkgs, nil } +// isPassThroughMethodName checks if a method name is known to be a pass-through +func isPassThroughMethodName(methodName string) bool { + switch methodName { + case rebindMethodName, rebindxMethodName: + return true + } + return false +} + +// isPassThroughFunc checks if a function is known to be a pass-through +// that transforms query syntax without changing semantic meaning +func isPassThroughFunc(fn *ssa.Function) bool { + if fn == nil { + return false + } + + // Get the package path and function name + if fn.Pkg != nil && fn.Pkg.Pkg != nil { + pkgPath := fn.Pkg.Pkg.Path() + funcName := fn.Name() + + // sqlx package pass-through functions + if pkgPath == sqlxLib && isPassThroughMethodName(funcName) { + return true + } + } + + // Check by receiver type for methods + if fn.Signature.Recv() != nil { + recv := fn.Signature.Recv() + recvType := recv.Type().String() + funcName := fn.Name() + + // sqlx methods that are pass-through + if strings.HasPrefix(recvType, sqlxLib+".") && isPassThroughMethodName(funcName) { + return true + } + } + + return false +} + func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) { queryStr := "" @@ -292,11 +371,95 @@ func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) { return "", ErrQueryArgTODO case *ssa.Extract: // query string is from one of the multi return values - // need to figure out how to trace string from function returns + // Try to trace the source of the multi-value return + if queryArg.Tuple == nil { + return "", ErrQueryArgTODO + } + + // Check if the tuple comes from a function call + if call, ok := queryArg.Tuple.(*ssa.Call); ok { + callee := call.Call.StaticCallee() + if callee == nil { + return "", ErrQueryArgTODO + } + + // Check if the function has a body + if len(callee.Blocks) == 0 { + // External function, can't trace further + return "", ErrQueryArgTODO + } + + // Look for return instructions and extract the specific index + for _, block := range callee.Blocks { + for _, instr := range block.Instrs { + if ret, ok := instr.(*ssa.Return); ok { + if queryArg.Index >= len(ret.Results) { + continue + } + // Extract the query string from the specific return value at this index + return extractQueryStrFromSsaValue(ret.Results[queryArg.Index]) + } + } + } + } + return "", ErrQueryArgTODO case *ssa.Call: // return value from a function call - // TODO: trace caller function + // Try to trace the function to extract the query string + callee := queryArg.Call.StaticCallee() + + // Check if this is a known pass-through function call + // For interface calls, callee will be nil, so we check by method name + if callee == nil { + // Dynamic call (interface method, function value, etc.) + // Check if it's a known pass-through method by name + if queryArg.Call.IsInvoke() { + method := queryArg.Call.Method + if method != nil && isPassThroughMethodName(method.Name()) { + // Extract the query from the first argument + callArgs := queryArg.Call.Args + if len(callArgs) > 0 { + return extractQueryStrFromSsaValue(callArgs[0]) + } + } + } + return "", ErrQueryArgUnsafe + } + + // Handle known pass-through functions that just transform the query + // without changing its semantic meaning (e.g., sqlx.Rebind) + if isPassThroughFunc(callee) { + // Extract the query from the first argument + callArgs := queryArg.Call.Args + if len(callArgs) > 0 { + // For method calls, the receiver is not in Args, so Args[0] is the first parameter + return extractQueryStrFromSsaValue(callArgs[0]) + } + return "", ErrQueryArgUnsafe + } + + // Check if the function has a body (not external or builtin) + if len(callee.Blocks) == 0 { + return "", ErrQueryArgUnsafe + } + + // Look for return instructions in the function + // This handles simple cases where the function returns a constant or computed value + for _, block := range callee.Blocks { + for _, instr := range block.Instrs { + if ret, ok := instr.(*ssa.Return); ok { + if len(ret.Results) == 0 { + continue + } + // Recursively extract the query string from the first return value + // This handles cases like: + // func getQuery() string { return "SELECT * FROM users" } + return extractQueryStrFromSsaValue(ret.Results[0]) + } + } + } + return "", ErrQueryArgUnsafe case *ssa.MakeInterface: // query function takes interface as input @@ -346,7 +509,7 @@ func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool { } func iterCallGraphNodeCallees(ctx VetContext, cgNode *callgraph.Node, prog *ssa.Program, sqlfunc MatchedSqlFunc, ignoreNodes []ast.Node) []*QuerySite { - queries := []*QuerySite{} + var queries []*QuerySite for _, inEdge := range cgNode.In { callerFunc := inEdge.Caller.Func @@ -490,7 +653,7 @@ func getSortedIgnoreNodes(pkgs []*packages.Package) []ast.Node { func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMatcher) ([]*QuerySite, error) { _, err := os.Stat(filepath.Join(dir, "go.mod")) if os.IsNotExist(err) { - return nil, errors.New("sqlvet only supports projects using go modules for now.") + return nil, errors.New("sqlvet only supports projects using go modules for now") } pkgs, err := loadGoPackages(dir, buildFlags) @@ -520,11 +683,11 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat mode := ssa.InstantiateGenerics prog, ssaPkgs := ssautil.Packages(pkgs, mode) - log.Debug("Performaing whole-program analysis...") + log.Debug("Performing whole-program analysis...") prog.Build() // find ssa.Function for matched sqlfuncs from program - sqlfuncs := []MatchedSqlFunc{} + var sqlfuncs []MatchedSqlFunc for _, matcher := range matchers { if !matcher.PackageImported() { // if package is not imported, then no sqlfunc should be matched @@ -538,7 +701,7 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat mains := ssautil.MainPackages(ssaPkgs) log.Debug("Building call graph...") - funcs := []*ssa.Function{} + var funcs []*ssa.Function for _, fn := range mains { if main := fn.Func("main"); main != nil { funcs = append(funcs, main) @@ -553,7 +716,7 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat return nil, nil } - queries := []*QuerySite{} + var queries []*QuerySite cg := rtaRes.CallGraph for _, sqlfunc := range sqlfuncs { cgNode := cg.CreateNode(sqlfunc.SSA) diff --git a/pkg/vet/gosource_test.go b/pkg/vet/gosource_test.go index 7d9bdc1..e6bdcef 100644 --- a/pkg/vet/gosource_test.go +++ b/pkg/vet/gosource_test.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/houqp/sqlvet/pkg/schema" "github.com/houqp/sqlvet/pkg/vet" ) @@ -21,7 +22,7 @@ func (s GoSourceTmpDir) Construct(t *testing.T, fixtures struct{}) (string, stri assert.NoError(t, err) modpath := filepath.Join(dir, "go.mod") - err = ioutil.WriteFile(modpath, []byte(` + err = os.WriteFile(modpath, []byte(` module github.com/houqp/sqlvettest `), 0644) assert.NoError(t, err) @@ -52,7 +53,7 @@ func (s *GoSourceTests) SubTestInvalidSyntax(t *testing.T, fixtures struct { dir := fixtures.TmpDir fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, []byte(` + err := os.WriteFile(fpath, []byte(` package main func main() { @@ -102,7 +103,7 @@ func main() { `) fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) + err := os.WriteFile(fpath, source, 0644) assert.NoError(t, err) queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil) @@ -159,7 +160,7 @@ func main() { `) fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) + err := os.WriteFile(fpath, source, 0644) assert.NoError(t, err) queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil) @@ -191,6 +192,165 @@ func main() { assert.Equal(t, "SELECT id FROM foo", queries[5].Query) } +func (s *GoSourceTests) SubTestPkgDatabaseSqlxWithCustomWrapperAndRebindForDB(t *testing.T, fixtures struct { + TmpDir string `fixture:"GoSourceTmpDir"` +}) { + dir := fixtures.TmpDir + + source := []byte(` +package main + +import ( + "context" + "github.com/jmoiron/sqlx" + "database/sql" +) + +type CustomWrapper interface { + sqlx.QueryerContext + sqlx.ExecerContext + + SelectContext(ctx context.Context, dest any, query string, args ...any) error + GetContext(ctx context.Context, dest any, query string, args ...any) error + NamedExecContext(ctx context.Context, query string, arg any) (sql.Result, error) + BindNamed(query string, arg any) (string, []any, error) + Rebind(query string) string +} + +func main() { + var ctx = context.Background() + const queryTmpl = "SELECT * FROM test_schema.test_table WHERE id IN (?);" + var query, data, err = sqlx.In(queryTmpl, getArgs()) // sqlvet: ignore + if err != nil { + panic(err) + } + + var db CustomWrapper = &sqlx.DB{} + var entities []any + if err := db.GetContext(ctx, &entities, db.Rebind(query), data...); err != nil { + panic(err) + } +} + +func getArgs() []string { + return []string{"one","two"} +} + `) + + fpath := filepath.Join(dir, "main.go") + err := os.WriteFile(fpath, source, 0644) + assert.NoError(t, err) + + gomod := ` +module github.com/houqp/sqlvet + +go 1.24.0 + +require github.com/jmoiron/sqlx v1.4.0` + fpathGoMod := filepath.Join(dir, "go.mod") + err = os.WriteFile(fpathGoMod, []byte(gomod), 0644) + assert.NoError(t, err) + + gosum := ` + github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= + github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= + ` + fpathGoSum := filepath.Join(dir, "go.sum") + err = os.WriteFile(fpathGoSum, []byte(gosum), 0644) + assert.NoError(t, err) + + queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil) + if err != nil { + t.Fatalf("Failed to load package: %s", err.Error()) + return + } + assert.Equal(t, 0, len(queries)) +} + +func (s *GoSourceTests) SubTestPkgDatabaseSqlxWithCustomWrapperAndTypeConversion(t *testing.T, fixtures struct { + TmpDir string `fixture:"GoSourceTmpDir"` +}) { + dir := fixtures.TmpDir + + source := []byte(` +package main + +import ( + "context" + "github.com/jmoiron/sqlx" + "database/sql" +) + +type CustomWrapper interface { + sqlx.QueryerContext + sqlx.ExecerContext + + SelectContext(ctx context.Context, dest any, query string, args ...any) error + GetContext(ctx context.Context, dest any, query string, args ...any) error + NamedExecContext(ctx context.Context, query string, arg any) (sql.Result, error) + BindNamed(query string, arg any) (string, []any, error) + Rebind(query string) string +} + +func main() { + var ctx = context.Background() + const query = "SELECT CASE WHEN id = 'test1' AND coalesce((jsonbField::jsonb)->>'in_field', '') <> 'test2' THEN (jsonbField::jsonb) - 'in_field_2' ELSE jsonbField::jsonb END AS jsonbField FROM test_schema.test_tbl WHERE id = $1;" + + var db CustomWrapper = &sqlx.DB{} + var entities []any + if err := db.SelectContext(ctx, &entities, query, "test1"); err != nil { + panic(err) + } +} + `) + + fpath := filepath.Join(dir, "main.go") + err := os.WriteFile(fpath, source, 0644) + assert.NoError(t, err) + + gomod := ` +module github.com/houqp/sqlvet + +go 1.24.0 + +require github.com/jmoiron/sqlx v1.4.0` + fpathGoMod := filepath.Join(dir, "go.mod") + err = os.WriteFile(fpathGoMod, []byte(gomod), 0644) + assert.NoError(t, err) + + gosum := ` + github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= + github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= + ` + fpathGoSum := filepath.Join(dir, "go.sum") + err = os.WriteFile(fpathGoSum, []byte(gosum), 0644) + assert.NoError(t, err) + + schemaCtx := vet.NewContext(map[string]schema.Table{ + "test_tbl": { + Name: "test_tbl", + Columns: map[string]schema.Column{ + "id": {Name: "id", Type: "text"}, + "jsonbfield": {Name: "jsonbfield", Type: "jsonb"}, + }, + }, + }) + + queries, err := vet.CheckDir(schemaCtx, dir, "", nil) + if err != nil { + t.Fatalf("Failed to load package: %s", err.Error()) + return + } + + // Expect 2 queries (one for DB, one for Tx implementation) + // Both should validate successfully with :: operators preserved + assert.Equal(t, 2, len(queries)) + for _, q := range queries { + assert.NoError(t, q.Err) + assert.Contains(t, q.Query, "::") + } +} + // run sqlvet from parent dir func (s *GoSourceTests) SubTestCheckRelativeDir(t *testing.T, fixtures struct { TmpDir string `fixture:"GoSourceTmpDir"` @@ -205,7 +365,7 @@ func main() { `) fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) + err := os.WriteFile(fpath, source, 0644) assert.NoError(t, err) cwd, err := os.Getwd() @@ -255,7 +415,7 @@ func main() { `) fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) + err := os.WriteFile(fpath, source, 0644) assert.NoError(t, err) queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil) @@ -305,7 +465,7 @@ func main() { `) fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) + err := os.WriteFile(fpath, source, 0644) assert.NoError(t, err) _, err = vet.CheckDir(vet.VetContext{}, dir, "", nil) diff --git a/pkg/vet/vet.go b/pkg/vet/vet.go index a7e190c..89f13a5 100644 --- a/pkg/vet/vet.go +++ b/pkg/vet/vet.go @@ -413,6 +413,36 @@ func parseExpression(ctx VetContext, clause *pg_query.Node, parseRe *ParseResult parseRe.PostponedNodes = &PostponedNodes{} } parseRe.PostponedNodes.RangeSubselectNodes = append(parseRe.PostponedNodes.RangeSubselectNodes, clause.GetRangeSubselect()) + return nil + case clause.GetCaseExpr() != nil: + // CASE WHEN condition THEN result ELSE default END + caseExpr := clause.GetCaseExpr() + + // Parse all WHEN clauses + for _, arg := range caseExpr.Args { + if caseWhen := arg.GetCaseWhen(); caseWhen != nil { + // Parse the WHEN condition + if caseWhen.Expr != nil { + if err := parseExpression(ctx, caseWhen.Expr, parseRe); err != nil { + return err + } + } + // Parse the THEN result + if caseWhen.Result != nil { + if err := parseExpression(ctx, caseWhen.Result, parseRe); err != nil { + return err + } + } + } + } + + // Parse the ELSE default result + if caseExpr.Defresult != nil { + if err := parseExpression(ctx, caseExpr.Defresult, parseRe); err != nil { + return err + } + } + return nil default: return fmt.Errorf(