Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ func NewRootCmd() *cobra.Command {
return err
}

batchSize, err := cmd.Flags().GetInt("batch-size")
if err != nil {
return err
}
if batchSize < 1 {
return fmt.Errorf("batch-size must be positive")
}

var uri string
if len(args) > 0 {
uri = args[0]
Expand All @@ -105,7 +113,7 @@ func NewRootCmd() *cobra.Command {
// return fmt.Errorf("Too many arguments")
// }

return internal.Main(uri, showData, showAll, limit, processes, only, except, minCount, pattern, debug, format, include, exclude, output)
return internal.Main(uri, showData, showAll, limit, processes, only, except, minCount, pattern, debug, format, include, exclude, output, batchSize)
},
}
cmd.PersistentFlags().Bool("show-data", false, "Show data")
Expand All @@ -122,6 +130,7 @@ func NewRootCmd() *cobra.Command {
cmd.PersistentFlags().String("include", "", "Filter tables to scan (comma-separated, supports wildcards)")
cmd.PersistentFlags().String("exclude", "", "Exclude tables from scan (comma-separated, supports wildcards)")
cmd.PersistentFlags().StringP("output", "o", "", "Output file path (defaults to stdout)")
cmd.PersistentFlags().Int("batch-size", 1, "Number of concurrent database queries")
return cmd
}

Expand Down
2 changes: 1 addition & 1 deletion internal/data_store_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ type DataStoreAdapter interface {
TableName() string
RowName() string
Init(url string) error
FetchTables() ([]table, error)
FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error)
FetchTableData(table table, limit int) (*tableData, error)
}
2 changes: 1 addition & 1 deletion internal/elasticsearch_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (a *ElasticsearchAdapter) Init(urlStr string) error {
return nil
}

func (a ElasticsearchAdapter) FetchTables() ([]table, error) {
func (a ElasticsearchAdapter) FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error) {
tables := []table{}

es := a.DB
Expand Down
35 changes: 28 additions & 7 deletions internal/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ type ScanOpts struct {
MatchConfig *MatchConfig
Include string
Exclude string
BatchSize int
}

func Main(urlStr string, showData bool, showAll bool, limit int, processes int, only string, except string, minCount int, pattern string, debug bool, format string, include string, exclude string, output string) error {
func Main(urlStr string, showData bool, showAll bool, limit int, processes int, only string, except string, minCount int, pattern string, debug bool, format string, include string, exclude string, output string, batchSize int) error {
runtime.GOMAXPROCS(processes)

var writer io.Writer = os.Stdout
Expand Down Expand Up @@ -98,7 +99,7 @@ func Main(urlStr string, showData bool, showAll bool, limit int, processes int,
adapter = &SqlAdapter{}
}

matchList, err := adapter.Scan(ScanOpts{urlStr, showData, showAll, limit, debug, formatter, writer, &matchConfig, include, exclude})
matchList, err := adapter.Scan(ScanOpts{urlStr, showData, showAll, limit, debug, formatter, writer, &matchConfig, include, exclude, batchSize})

if err != nil {
return err
Expand Down Expand Up @@ -131,7 +132,9 @@ func scanDataStore(adapter DataStoreAdapter, scanOpts ScanOpts) ([]ruleMatch, er
return nil, err
}

tables, err := adapter.FetchTables()
includeSchemas := extractSchemas(scanOpts.Include)
excludeSchemas := extractSchemas(scanOpts.Exclude)
tables, err := adapter.FetchTables(includeSchemas, excludeSchemas)
if err != nil {
return nil, err
}
Expand All @@ -150,7 +153,7 @@ func scanDataStore(adapter DataStoreAdapter, scanOpts ScanOpts) ([]ruleMatch, er

var g errgroup.Group
var appendMutex sync.Mutex
var queryMutex sync.Mutex
sem := make(chan struct{}, scanOpts.BatchSize)

for _, table := range tables {
// important - do not remove
Expand All @@ -160,10 +163,10 @@ func scanDataStore(adapter DataStoreAdapter, scanOpts ScanOpts) ([]ruleMatch, er
g.Go(func() error {
start := time.Now()

// limit to one query at a time
queryMutex.Lock()
// limit concurrent queries to batch size
sem <- struct{}{}
tableData, err := adapter.FetchTableData(table, limit)
queryMutex.Unlock()
<-sem

if scanOpts.Debug {
duration := time.Since(start)
Expand Down Expand Up @@ -415,3 +418,21 @@ func makeValidNames(matchConfig *MatchConfig) map[string]bool {
}
return validNames
}

// extractSchemas derives the distinct schema names referenced by a
// comma-separated list of table patterns (e.g. "public.users,admin.*").
// The part before the first "." is taken as the schema; a schema part
// of "*" or "" contributes nothing. A pattern with no "." is treated
// as a schema name, matching the original behavior. An empty patterns
// string yields nil.
//
// Schemas are returned in first-seen order so callers build
// deterministic SQL (previously the order came from random map
// iteration, making the generated queries vary between runs).
func extractSchemas(patterns string) []string {
	if patterns == "" {
		return nil
	}
	seen := map[string]bool{}
	result := []string{}
	for _, p := range strings.Split(patterns, ",") {
		schema, _, _ := strings.Cut(strings.TrimSpace(p), ".")
		if schema == "" || schema == "*" || seen[schema] {
			continue
		}
		seen[schema] = true
		result = append(result, schema)
	}
	return result
}
2 changes: 1 addition & 1 deletion internal/mongodb_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (a *MongodbAdapter) Init(urlStr string) error {
return nil
}

func (a MongodbAdapter) FetchTables() ([]table, error) {
func (a MongodbAdapter) FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error) {
tables := []table{}

db := a.DB
Expand Down
2 changes: 1 addition & 1 deletion internal/redis_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (a *RedisAdapter) Init(urlStr string) error {
return nil
}

func (a RedisAdapter) FetchTables() ([]table, error) {
func (a RedisAdapter) FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error) {
return []table{{Schema: "", Name: ""}}, nil
}

Expand Down
29 changes: 24 additions & 5 deletions internal/sql_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package internal

import (
"fmt"
"strings"

"github.com/jmoiron/sqlx"
"github.com/lib/pq"
Expand Down Expand Up @@ -50,7 +51,7 @@ func (a *SqlAdapter) Init(url string) error {
return nil
}

func (a SqlAdapter) FetchTables() ([]table, error) {
func (a SqlAdapter) FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error) {
tables := []table{}

db := a.DB
Expand All @@ -61,13 +62,23 @@ func (a SqlAdapter) FetchTables() ([]table, error) {
case "sqlite3":
query = `SELECT '' AS table_schema, name AS table_name FROM sqlite_master WHERE type = 'table' AND name != 'sqlite_sequence' ORDER BY name`
case "mysql":
query = `SELECT table_schema AS table_schema, table_name AS table_name FROM information_schema.tables WHERE table_schema = DATABASE() OR (DATABASE() IS NULL AND table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')) ORDER BY table_schema, table_name`
query = `SELECT table_schema AS table_schema, table_name AS table_name FROM information_schema.tables WHERE table_schema = DATABASE() OR (DATABASE() IS NULL AND table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys'))`
case "sqlserver":
query = `SELECT table_schema, table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' ORDER BY table_schema, table_name`
query = `SELECT table_schema, table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE'`
case "redshift":
query = `SELECT table_schema, table_name FROM svv_tables WHERE table_catalog = current_database() ORDER BY table_schema, table_name`
query = `SELECT table_schema, table_name FROM svv_tables WHERE table_catalog = current_database()`
default:
query = `SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema NOT IN ('information_schema', 'pg_catalog') ORDER BY table_schema, table_name`
query = `SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema NOT IN ('information_schema', 'pg_catalog')`
}

if a.UnaliasedDriverName != "sqlite3" {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need support for any engines other than MySQL and PostgreSQL.

I understand it's AI-generated, but we should stay in control of it as well.

if len(includeSchemas) > 0 {
query += fmt.Sprintf(" AND table_schema IN (%s)", schemaList(includeSchemas))
}
if len(excludeSchemas) > 0 {
query += fmt.Sprintf(" AND table_schema NOT IN (%s)", schemaList(excludeSchemas))
}
query += " ORDER BY table_schema, table_name"
}

err := db.Select(&tables, query)
Expand All @@ -78,6 +89,14 @@ func (a SqlAdapter) FetchTables() ([]table, error) {
return tables, nil
}

// schemaList renders schema names as a quoted, comma-separated list
// suitable for interpolation into a SQL IN (...) clause,
// e.g. 'public', 'admin'.
//
// Single quotes inside a name are doubled ('' is the standard SQL
// escape) so a quote character in a user-supplied --include/--exclude
// pattern cannot break out of the string literal and inject SQL.
func schemaList(schemas []string) string {
	quoted := make([]string, len(schemas))
	for i, s := range schemas {
		quoted[i] = "'" + strings.ReplaceAll(s, "'", "''") + "'"
	}
	return strings.Join(quoted, ", ")
}

func (a SqlAdapter) FetchTableData(table table, limit int) (*tableData, error) {
db := a.DB

Expand Down