Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ on:

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-go@v2
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/houqp/sqlvet/pkg/vet"
)

const version = "1.1.10"
const version = "1.1.11"

var (
gitCommit = "?"
Expand Down
235 changes: 199 additions & 36 deletions pkg/vet/gosource.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ var (
ErrQueryArgTODO = errors.New("TODO: support this type")
)

const (
sqlxLib = "github.com/jmoiron/sqlx"
dbSqlLib = "database/sql"
gormLib = "github.com/jinzhu/gorm"
goGorpLib = "go-gorp/gorp"
gorpV1Lib = "gopkg.in/gorp.v1"

queryArgName = "query"
sqlArgName = "sql"

rebindMethodName = "Rebind"
rebindxMethodName = "Rebindx"
)

type QuerySite struct {
Called string
Position token.Position
Expand Down Expand Up @@ -100,10 +114,17 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
sqlfuncs := []MatchedSqlFunc{}

s.IterPackageExportedFuncs(func(fobj *types.Func) {
ssaFunc := prog.FuncValue(fobj)

// Skip pass-through functions that shouldn't be validated as SQL functions
if isPassThroughFunc(ssaFunc) {
return
}

for _, rule := range s.Rules {
if rule.FuncName != "" && fobj.Name() == rule.FuncName {
sqlfuncs = append(sqlfuncs, MatchedSqlFunc{
SSA: prog.FuncValue(fobj),
SSA: ssaFunc,
QueryArgPos: rule.QueryArgPos,
})
// callable matched one rule, no need to go through the rest
Expand All @@ -120,7 +141,7 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
continue
}
sqlfuncs = append(sqlfuncs, MatchedSqlFunc{
SSA: prog.FuncValue(fobj),
SSA: ssaFunc,
QueryArgPos: rule.QueryArgPos,
})
// callable matched one rule, no need to go through the rest
Expand All @@ -132,13 +153,29 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
return sqlfuncs
}

// isNamedQueryFunc checks if a function name is a "named query" function
// that expects named parameters (like :param) instead of positional ($1, $2)
func isNamedQueryFunc(funcName string) bool {
// Check for sqlx named query functions
switch funcName {
case "NamedExec", "NamedQuery", "NamedExecContext", "NamedQueryContext",
"NamedQueryRow", "NamedQueryRowContext":
return true
}
// Also check if the function name contains "Named" (catches custom wrappers)
return strings.Contains(funcName, "Named")
}

func handleQuery(ctx VetContext, qs *QuerySite) {
// TODO: apply named query resolution based on v.X type and v.Sel.Name
// e.g. for sqlx, only apply to NamedExec and NamedQuery
qs.Query, _, qs.Err = parseutil.CompileNamedQuery(
[]byte(qs.Query), parseutil.BindType("postgres"))
if qs.Err != nil {
return
// Only apply named query resolution for named query functions
// (e.g., NamedExec, NamedQuery, NamedExecContext, NamedQueryContext)
// to avoid breaking PostgreSQL type casts (::) in regular queries
if isNamedQueryFunc(qs.Called) {
qs.Query, _, qs.Err = parseutil.CompileNamedQuery(
[]byte(qs.Query), parseutil.BindType("postgres"))
if qs.Err != nil {
return
}
}

var queryParams []QueryParam
Expand All @@ -160,31 +197,31 @@ func handleQuery(ctx VetContext, qs *QuerySite) {
func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher {
matchers := []*SqlFuncMatcher{
{
PkgPath: "github.com/jmoiron/sqlx",
PkgPath: sqlxLib,
Rules: []SqlFuncMatchRule{
{QueryArgName: "query"},
{QueryArgName: "sql"},
{QueryArgName: queryArgName},
{QueryArgName: sqlArgName},
// for methods with Context suffix
{QueryArgName: "query", QueryArgPos: 1},
{QueryArgName: "sql", QueryArgPos: 1},
{QueryArgName: "query", QueryArgPos: 2},
{QueryArgName: "sql", QueryArgPos: 2},
{QueryArgName: queryArgName, QueryArgPos: 1},
{QueryArgName: sqlArgName, QueryArgPos: 1},
{QueryArgName: queryArgName, QueryArgPos: 2},
{QueryArgName: sqlArgName, QueryArgPos: 2},
},
},
{
PkgPath: "database/sql",
PkgPath: dbSqlLib,
Rules: []SqlFuncMatchRule{
{QueryArgName: "query"},
{QueryArgName: "sql"},
{QueryArgName: queryArgName},
{QueryArgName: sqlArgName},
// for methods with Context suffix
{QueryArgName: "query", QueryArgPos: 1},
{QueryArgName: "sql", QueryArgPos: 1},
{QueryArgName: queryArgName, QueryArgPos: 1},
{QueryArgName: sqlArgName, QueryArgPos: 1},
},
},
{
PkgPath: "github.com/jinzhu/gorm",
PkgPath: gormLib,
Rules: []SqlFuncMatchRule{
{QueryArgName: "sql"},
{QueryArgName: sqlArgName},
},
},
// TODO: xorm uses vararg, which is not supported yet
Expand All @@ -201,15 +238,15 @@ func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher {
// },
// },
{
PkgPath: "go-gorp/gorp",
PkgPath: goGorpLib,
Rules: []SqlFuncMatchRule{
{QueryArgName: "query"},
{QueryArgName: queryArgName},
},
},
{
PkgPath: "gopkg.in/gorp.v1",
PkgPath: gorpV1Lib,
Rules: []SqlFuncMatchRule{
{QueryArgName: "query"},
{QueryArgName: queryArgName},
},
},
}
Expand Down Expand Up @@ -240,7 +277,7 @@ func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error)
}
dirAbs, err := filepath.Abs(dir)
if err != nil {
return nil, fmt.Errorf("Invalid path: %w", err)
return nil, fmt.Errorf("invalid path: %w", err)
}
pkgPath := dirAbs + "/..."
pkgs, err := packages.Load(cfg, pkgPath)
Expand All @@ -250,12 +287,54 @@ func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error)
// return early if any syntax error
for _, pkg := range pkgs {
if len(pkg.Errors) > 0 {
return nil, fmt.Errorf("Failed to load package, %w", pkg.Errors[0])
return nil, fmt.Errorf("failed to load package, %w", pkg.Errors[0])
}
}
return pkgs, nil
}

// isPassThroughMethodName checks if a method name is known to be a pass-through
func isPassThroughMethodName(methodName string) bool {
switch methodName {
case rebindMethodName, rebindxMethodName:
return true
}
return false
}

// isPassThroughFunc checks if a function is known to be a pass-through
// that transforms query syntax without changing semantic meaning
func isPassThroughFunc(fn *ssa.Function) bool {
if fn == nil {
return false
}

// Get the package path and function name
if fn.Pkg != nil && fn.Pkg.Pkg != nil {
pkgPath := fn.Pkg.Pkg.Path()
funcName := fn.Name()

// sqlx package pass-through functions
if pkgPath == sqlxLib && isPassThroughMethodName(funcName) {
return true
}
}

// Check by receiver type for methods
if fn.Signature.Recv() != nil {
recv := fn.Signature.Recv()
recvType := recv.Type().String()
funcName := fn.Name()

// sqlx methods that are pass-through
if strings.HasPrefix(recvType, sqlxLib+".") && isPassThroughMethodName(funcName) {
return true
}
}

return false
}

func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) {
queryStr := ""

Expand Down Expand Up @@ -292,11 +371,95 @@ func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) {
return "", ErrQueryArgTODO
case *ssa.Extract:
// query string is from one of the multi return values
// need to figure out how to trace string from function returns
// Try to trace the source of the multi-value return
if queryArg.Tuple == nil {
return "", ErrQueryArgTODO
}

// Check if the tuple comes from a function call
if call, ok := queryArg.Tuple.(*ssa.Call); ok {
callee := call.Call.StaticCallee()
if callee == nil {
return "", ErrQueryArgTODO
}

// Check if the function has a body
if len(callee.Blocks) == 0 {
// External function, can't trace further
return "", ErrQueryArgTODO
}

// Look for return instructions and extract the specific index
for _, block := range callee.Blocks {
for _, instr := range block.Instrs {
if ret, ok := instr.(*ssa.Return); ok {
if queryArg.Index >= len(ret.Results) {
continue
}
// Extract the query string from the specific return value at this index
return extractQueryStrFromSsaValue(ret.Results[queryArg.Index])
}
}
}
}

return "", ErrQueryArgTODO
case *ssa.Call:
// return value from a function call
// TODO: trace caller function
// Try to trace the function to extract the query string
callee := queryArg.Call.StaticCallee()

// Check if this is a known pass-through function call
// For interface calls, callee will be nil, so we check by method name
if callee == nil {
// Dynamic call (interface method, function value, etc.)
// Check if it's a known pass-through method by name
if queryArg.Call.IsInvoke() {
method := queryArg.Call.Method
if method != nil && isPassThroughMethodName(method.Name()) {
// Extract the query from the first argument
callArgs := queryArg.Call.Args
if len(callArgs) > 0 {
return extractQueryStrFromSsaValue(callArgs[0])
}
}
}
return "", ErrQueryArgUnsafe
}

// Handle known pass-through functions that just transform the query
// without changing its semantic meaning (e.g., sqlx.Rebind)
if isPassThroughFunc(callee) {
// Extract the query from the first argument
callArgs := queryArg.Call.Args
if len(callArgs) > 0 {
// For method calls, the receiver is not in Args, so Args[0] is the first parameter
return extractQueryStrFromSsaValue(callArgs[0])
}
return "", ErrQueryArgUnsafe
}

// Check if the function has a body (not external or builtin)
if len(callee.Blocks) == 0 {
return "", ErrQueryArgUnsafe
}

// Look for return instructions in the function
// This handles simple cases where the function returns a constant or computed value
for _, block := range callee.Blocks {
for _, instr := range block.Instrs {
if ret, ok := instr.(*ssa.Return); ok {
if len(ret.Results) == 0 {
continue
}
// Recursively extract the query string from the first return value
// This handles cases like:
// func getQuery() string { return "SELECT * FROM users" }
return extractQueryStrFromSsaValue(ret.Results[0])
}
}
}

return "", ErrQueryArgUnsafe
case *ssa.MakeInterface:
// query function takes interface as input
Expand Down Expand Up @@ -346,7 +509,7 @@ func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool {
}

func iterCallGraphNodeCallees(ctx VetContext, cgNode *callgraph.Node, prog *ssa.Program, sqlfunc MatchedSqlFunc, ignoreNodes []ast.Node) []*QuerySite {
queries := []*QuerySite{}
var queries []*QuerySite

for _, inEdge := range cgNode.In {
callerFunc := inEdge.Caller.Func
Expand Down Expand Up @@ -490,7 +653,7 @@ func getSortedIgnoreNodes(pkgs []*packages.Package) []ast.Node {
func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMatcher) ([]*QuerySite, error) {
_, err := os.Stat(filepath.Join(dir, "go.mod"))
if os.IsNotExist(err) {
return nil, errors.New("sqlvet only supports projects using go modules for now.")
return nil, errors.New("sqlvet only supports projects using go modules for now")
}

pkgs, err := loadGoPackages(dir, buildFlags)
Expand Down Expand Up @@ -520,11 +683,11 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat

mode := ssa.InstantiateGenerics
prog, ssaPkgs := ssautil.Packages(pkgs, mode)
log.Debug("Performaing whole-program analysis...")
log.Debug("Performing whole-program analysis...")
prog.Build()

// find ssa.Function for matched sqlfuncs from program
sqlfuncs := []MatchedSqlFunc{}
var sqlfuncs []MatchedSqlFunc
for _, matcher := range matchers {
if !matcher.PackageImported() {
// if package is not imported, then no sqlfunc should be matched
Expand All @@ -538,7 +701,7 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat
mains := ssautil.MainPackages(ssaPkgs)

log.Debug("Building call graph...")
funcs := []*ssa.Function{}
var funcs []*ssa.Function
for _, fn := range mains {
if main := fn.Func("main"); main != nil {
funcs = append(funcs, main)
Expand All @@ -553,7 +716,7 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat
return nil, nil
}

queries := []*QuerySite{}
var queries []*QuerySite
cg := rtaRes.CallGraph
for _, sqlfunc := range sqlfuncs {
cgNode := cg.CreateNode(sqlfunc.SSA)
Expand Down
Loading