diff --git a/cmd/root.go b/cmd/root.go
index 5bd3b86..aa3b793 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -87,6 +87,14 @@ func NewRootCmd() *cobra.Command {
 				return err
 			}
 
+			batchSize, err := cmd.Flags().GetInt("batch-size")
+			if err != nil {
+				return err
+			}
+			if batchSize < 1 {
+				return fmt.Errorf("batch-size must be positive")
+			}
+
 			var uri string
 			if len(args) > 0 {
 				uri = args[0]
@@ -105,7 +113,7 @@ func NewRootCmd() *cobra.Command {
 			// 	return fmt.Errorf("Too many arguments")
 			// }
 
-			return internal.Main(uri, showData, showAll, limit, processes, only, except, minCount, pattern, debug, format, include, exclude, output)
+			return internal.Main(uri, showData, showAll, limit, processes, only, except, minCount, pattern, debug, format, include, exclude, output, batchSize)
 		},
 	}
 	cmd.PersistentFlags().Bool("show-data", false, "Show data")
@@ -122,6 +130,7 @@ func NewRootCmd() *cobra.Command {
 	cmd.PersistentFlags().String("include", "", "Filter tables to scan (comma-separated, supports wildcards)")
 	cmd.PersistentFlags().String("exclude", "", "Exclude tables from scan (comma-separated, supports wildcards)")
 	cmd.PersistentFlags().StringP("output", "o", "", "Output file path (defaults to stdout)")
+	cmd.PersistentFlags().Int("batch-size", 1, "Number of concurrent database queries")
 
 	return cmd
 }
diff --git a/internal/data_store_adapter.go b/internal/data_store_adapter.go
index b255b87..d4fb0f0 100644
--- a/internal/data_store_adapter.go
+++ b/internal/data_store_adapter.go
@@ -4,6 +4,6 @@ type DataStoreAdapter interface {
 	TableName() string
 	RowName() string
 	Init(url string) error
-	FetchTables() ([]table, error)
+	FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error)
 	FetchTableData(table table, limit int) (*tableData, error)
 }
diff --git a/internal/elasticsearch_adapter.go b/internal/elasticsearch_adapter.go
index f9b9ea8..8a593bf 100644
--- a/internal/elasticsearch_adapter.go
+++ b/internal/elasticsearch_adapter.go
@@ -63,7 +63,7 @@ func (a *ElasticsearchAdapter) Init(urlStr string) error {
 	return nil
 }
 
-func (a ElasticsearchAdapter) FetchTables() ([]table, error) {
+func (a ElasticsearchAdapter) FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error) {
 	tables := []table{}
 
 	es := a.DB
diff --git a/internal/main.go b/internal/main.go
index d5034b1..b99a42c 100644
--- a/internal/main.go
+++ b/internal/main.go
@@ -30,9 +30,10 @@ type ScanOpts struct {
 	MatchConfig *MatchConfig
 	Include     string
 	Exclude     string
+	BatchSize   int
 }
 
-func Main(urlStr string, showData bool, showAll bool, limit int, processes int, only string, except string, minCount int, pattern string, debug bool, format string, include string, exclude string, output string) error {
+func Main(urlStr string, showData bool, showAll bool, limit int, processes int, only string, except string, minCount int, pattern string, debug bool, format string, include string, exclude string, output string, batchSize int) error {
 	runtime.GOMAXPROCS(processes)
 
 	var writer io.Writer = os.Stdout
@@ -98,7 +99,7 @@ func Main(urlStr string, showData bool, showAll bool, limit int, processes int,
 		adapter = &SqlAdapter{}
 	}
 
-	matchList, err := adapter.Scan(ScanOpts{urlStr, showData, showAll, limit, debug, formatter, writer, &matchConfig, include, exclude})
+	matchList, err := adapter.Scan(ScanOpts{urlStr, showData, showAll, limit, debug, formatter, writer, &matchConfig, include, exclude, batchSize})
 
 	if err != nil {
 		return err
@@ -131,7 +132,9 @@ func scanDataStore(adapter DataStoreAdapter, scanOpts ScanOpts) ([]ruleMatch, er
 		return nil, err
 	}
 
-	tables, err := adapter.FetchTables()
+	includeSchemas := extractSchemas(scanOpts.Include)
+	excludeSchemas := extractExcludedSchemas(scanOpts.Exclude)
+	tables, err := adapter.FetchTables(includeSchemas, excludeSchemas)
 	if err != nil {
 		return nil, err
 	}
@@ -150,7 +153,14 @@ func scanDataStore(adapter DataStoreAdapter, scanOpts ScanOpts) ([]ruleMatch, er
 
 	var g errgroup.Group
 	var appendMutex sync.Mutex
-	var queryMutex sync.Mutex
+	// sem is a counting semaphore bounding concurrent FetchTableData calls.
+	// Its capacity is clamped to at least 1: a capacity of 0 would make every
+	// send block forever, and a negative one would make make() panic.
+	batchSize := scanOpts.BatchSize
+	if batchSize < 1 {
+		batchSize = 1
+	}
+	sem := make(chan struct{}, batchSize)
 
 	for _, table := range tables {
 		// important - do not remove
@@ -160,10 +170,10 @@ func scanDataStore(adapter DataStoreAdapter, scanOpts ScanOpts) ([]ruleMatch, er
 		g.Go(func() error {
 			start := time.Now()
 
-			// limit to one query at a time
-			queryMutex.Lock()
+			// limit concurrent queries to batch size
+			sem <- struct{}{}
 			tableData, err := adapter.FetchTableData(table, limit)
-			queryMutex.Unlock()
+			<-sem
 
 			if scanOpts.Debug {
 				duration := time.Since(start)
@@ -415,3 +425,56 @@ func makeValidNames(matchConfig *MatchConfig) map[string]bool {
 	}
 	return validNames
 }
+
+// extractSchemas returns the literal schema names that the include patterns
+// are confined to, so an adapter can narrow its table listing up front.
+//
+// patterns is a comma-separated list of "schema.table" expressions. The result
+// is nil, meaning "do not restrict by schema", when patterns is empty or when
+// any pattern is not tied to one literal schema (no schema part, or a schema
+// part containing a wildcard): restricting the listing in that case could
+// drop tables the pattern is meant to match.
+// NOTE(review): *, ? and [ are treated as wildcard characters; confirm this
+// matches the syntax of the include/exclude pattern matcher.
+func extractSchemas(patterns string) []string {
+	var schemas []string
+	seen := map[string]bool{}
+	for _, p := range strings.Split(patterns, ",") {
+		p = strings.TrimSpace(p)
+		if p == "" {
+			continue
+		}
+		parts := strings.SplitN(p, ".", 2)
+		if len(parts) < 2 || parts[0] == "" || strings.ContainsAny(parts[0], "*?[") {
+			return nil
+		}
+		if !seen[parts[0]] {
+			seen[parts[0]] = true
+			// Append in first-seen order so the generated SQL is deterministic.
+			schemas = append(schemas, parts[0])
+		}
+	}
+	return schemas
+}
+
+// extractExcludedSchemas returns the schemas that the exclude patterns remove
+// in their entirety: patterns of the exact form "schema.*" with a literal
+// schema name. A pattern that names specific tables must not exclude the whole
+// schema, so it is not turned into a schema-level filter.
+func extractExcludedSchemas(patterns string) []string {
+	var schemas []string
+	seen := map[string]bool{}
+	for _, p := range strings.Split(patterns, ",") {
+		parts := strings.SplitN(strings.TrimSpace(p), ".", 2)
+		if len(parts) < 2 || parts[1] != "*" {
+			continue
+		}
+		schema := parts[0]
+		if schema == "" || strings.ContainsAny(schema, "*?[") || seen[schema] {
+			continue
+		}
+		seen[schema] = true
+		schemas = append(schemas, schema)
+	}
+	return schemas
+}
diff --git a/internal/mongodb_adapter.go b/internal/mongodb_adapter.go
index 05475e6..a83f0ba 100644
--- a/internal/mongodb_adapter.go
+++ b/internal/mongodb_adapter.go
@@ -52,7 +52,7 @@ func (a *MongodbAdapter) Init(urlStr string) error {
 	return nil
 }
 
-func (a MongodbAdapter) FetchTables() ([]table, error) {
+func (a MongodbAdapter) FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error) {
 	tables := []table{}
 
 	db := a.DB
diff --git a/internal/redis_adapter.go b/internal/redis_adapter.go
index 378185d..bbd4308 100644
--- a/internal/redis_adapter.go
+++ b/internal/redis_adapter.go
@@ -42,7 +42,7 @@ func (a *RedisAdapter) Init(urlStr string) error {
 	return nil
 }
 
-func (a RedisAdapter) FetchTables() ([]table, error) {
+func (a RedisAdapter) FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error) {
 	return []table{{Schema: "", Name: ""}}, nil
 }
 
diff --git a/internal/sql_adapter.go b/internal/sql_adapter.go
index de65fb8..b545fd2 100644
--- a/internal/sql_adapter.go
+++ b/internal/sql_adapter.go
@@ -2,6 +2,7 @@ package internal
 
 import (
 	"fmt"
+	"strings"
 
 	"github.com/jmoiron/sqlx"
 	"github.com/lib/pq"
@@ -50,7 +51,7 @@ func (a *SqlAdapter) Init(url string) error {
 	return nil
 }
 
-func (a SqlAdapter) FetchTables() ([]table, error) {
+func (a SqlAdapter) FetchTables(includeSchemas []string, excludeSchemas []string) ([]table, error) {
 	tables := []table{}
 
 	db := a.DB
@@ -61,13 +62,27 @@ func (a SqlAdapter) FetchTables() ([]table, error) {
 	case "sqlite3":
 		query = `SELECT '' AS table_schema, name AS table_name FROM sqlite_master WHERE type = 'table' AND name != 'sqlite_sequence' ORDER BY name`
 	case "mysql":
-		query = `SELECT table_schema AS table_schema, table_name AS table_name FROM information_schema.tables WHERE table_schema = DATABASE() OR (DATABASE() IS NULL AND table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')) ORDER BY table_schema, table_name`
+		// The OR condition is parenthesized so that the schema filters appended
+		// below with AND apply to the whole condition, not just its second branch.
+		query = `SELECT table_schema AS table_schema, table_name AS table_name FROM information_schema.tables WHERE (table_schema = DATABASE() OR (DATABASE() IS NULL AND table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')))`
 	case "sqlserver":
-		query = `SELECT table_schema, table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' ORDER BY table_schema, table_name`
+		query = `SELECT table_schema, table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE'`
 	case "redshift":
-		query = `SELECT table_schema, table_name FROM svv_tables WHERE table_catalog = current_database() ORDER BY table_schema, table_name`
+		query = `SELECT table_schema, table_name FROM svv_tables WHERE table_catalog = current_database()`
 	default:
-		query = `SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema NOT IN ('information_schema', 'pg_catalog') ORDER BY table_schema, table_name`
+		query = `SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema NOT IN ('information_schema', 'pg_catalog')`
+	}
+
+	// sqlite3 is skipped: its query reports an empty table_schema and already
+	// ends with its own ORDER BY, so nothing may be appended to it.
+	if a.UnaliasedDriverName != "sqlite3" {
+		if len(includeSchemas) > 0 {
+			query += fmt.Sprintf(" AND table_schema IN (%s)", schemaList(includeSchemas))
+		}
+		if len(excludeSchemas) > 0 {
+			query += fmt.Sprintf(" AND table_schema NOT IN (%s)", schemaList(excludeSchemas))
+		}
+		query += " ORDER BY table_schema, table_name"
 	}
 
 	err := db.Select(&tables, query)
@@ -78,6 +93,21 @@ func (a SqlAdapter) FetchTables() ([]table, error) {
 	return tables, nil
 }
 
+// schemaList renders schema names as a comma-separated list of single-quoted
+// SQL string literals for use inside an IN (...) clause.
+func schemaList(schemas []string) string {
+	quoted := make([]string, len(schemas))
+	for i, s := range schemas {
+		// The names come from user-supplied --include/--exclude patterns, so
+		// double any embedded single quote to keep the literal well-formed.
+		// NOTE(review): MySQL also treats backslash as an escape character
+		// unless NO_BACKSLASH_ESCAPES is set; bound parameters would be the
+		// more robust follow-up.
+		quoted[i] = "'" + strings.ReplaceAll(s, "'", "''") + "'"
+	}
+	return strings.Join(quoted, ", ")
+}
+
 func (a SqlAdapter) FetchTableData(table table, limit int) (*tableData, error) {
 	db := a.DB
 