Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,13 @@ func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
var tables []*plugin.Table
for _, t := range s.Tables {
var columns []*plugin.Column
var excludedCols []*plugin.Column
for _, c := range t.Columns {
l := -1
if c.Length != nil {
l = *c.Length
}
columns = append(columns, &plugin.Column{
col := &plugin.Column{
Name: c.Name,
Type: &plugin.Identifier{
Catalog: c.Type.Catalog,
Expand All @@ -125,7 +126,12 @@ func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
},
SourceLocation: pluginSourceLocation(c.SourceLocation),
TypeMods: pluginTypeMods(c.Type.Typmods),
})
}
if c.IsExcluded {
excludedCols = append(excludedCols, col)
} else {
columns = append(columns, col)
}
}
var indexes []*plugin.Index
for _, idx := range t.Indexes {
Expand Down Expand Up @@ -154,11 +160,12 @@ func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
Schema: t.Rel.Schema,
Name: t.Rel.Name,
},
Columns: columns,
Comment: t.Comment,
Indexes: indexes,
Constraints: pluginTableConstraints(t.Constraints),
SourceLocation: pluginSourceLocation(t.SourceLocation),
Columns: columns,
ExcludedColumns: excludedCols,
Comment: t.Comment,
Indexes: indexes,
Constraints: pluginTableConstraints(t.Constraints),
SourceLocation: pluginSourceLocation(t.SourceLocation),
})
}
schemas = append(schemas, &plugin.Schema{
Expand Down Expand Up @@ -256,9 +263,13 @@ func pluginQueries(r *compiler.Result) []*plugin.Query {
for _, q := range r.Queries {
var params []*plugin.Parameter
var columns []*plugin.Column
var excludedColumns []*plugin.Column
for _, c := range q.Columns {
columns = append(columns, pluginQueryColumn(c))
}
for _, c := range q.ExcludedColumns {
excludedColumns = append(excludedColumns, pluginQueryColumn(c))
}
for _, p := range q.Params {
params = append(params, pluginQueryParam(p))
}
Expand All @@ -276,6 +287,7 @@ func pluginQueries(r *compiler.Result) []*plugin.Query {
Text: q.SQL,
Comments: q.Metadata.Comments,
Columns: columns,
ExcludedColumns: excludedColumns,
Params: params,
Filename: q.Metadata.Filename,
InsertIntoTable: iit,
Expand Down
53 changes: 41 additions & 12 deletions internal/compiler/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import (
)

type analysis struct {
Table *ast.TableName
Columns []*Column
Parameters []Parameter
Named *named.ParamSet
Query string
Table *ast.TableName
Columns []*Column
ExcludedColumns []*Column
Parameters []Parameter
Named *named.ParamSet
Query string
}

func convertTableName(id *analyzer.Identifier) *ast.TableName {
Expand Down Expand Up @@ -185,12 +186,12 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
if err := check(err); err != nil {
return nil, err
}
cols, err := c.outputColumns(qc, raw.Stmt)
cols, excludedFromOutput, err := c.outputColumnsWithExcluded(qc, raw.Stmt)
if err := check(err); err != nil {
return nil, err
}

expandEdits, err := c.expand(qc, raw)
expandEdits, excludedFromExpand, err := c.expand(qc, raw)
if check(err); err != nil {
return nil, err
}
Expand All @@ -200,16 +201,44 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
return nil, err
}

// Combine excluded columns from both outputColumns and expand.
// Note: This does not recursively collect excluded columns from subqueries,
// CTEs, or UNION clauses. Excluded columns are only collected from the
// top-level star expansions.
// Use a map to deduplicate based on table and column name.
excludedMap := make(map[string]*Column)
for _, col := range excludedFromOutput {
key := ""
if col.Table != nil {
key = col.Table.Schema + "." + col.Table.Name + "."
}
key += col.Name
excludedMap[key] = col
}
for _, col := range excludedFromExpand {
key := ""
if col.Table != nil {
key = col.Table.Schema + "." + col.Table.Name + "."
}
key += col.Name
excludedMap[key] = col
}
var excludedCols []*Column
for _, col := range excludedMap {
excludedCols = append(excludedCols, col)
}

var rerr error
if len(errors) > 0 {
rerr = errors[0]
}

return &analysis{
Table: table,
Columns: cols,
Parameters: params,
Query: expanded,
Named: namedParams,
Table: table,
Columns: cols,
ExcludedColumns: excludedCols,
Parameters: params,
Query: expanded,
Named: namedParams,
}, rerr
}
4 changes: 4 additions & 0 deletions internal/compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func (c *Compiler) parseCatalog(schemas []string) error {
if len(merr.Errs()) > 0 {
return merr
}

// Mark excluded columns in the catalog
c.MarkExcludedColumns()

return nil
}

Expand Down
47 changes: 38 additions & 9 deletions internal/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,30 @@ import (
"github.com/sqlc-dev/sqlc/internal/engine/sqlite"
sqliteanalyze "github.com/sqlc-dev/sqlc/internal/engine/sqlite/analyzer"
"github.com/sqlc-dev/sqlc/internal/opts"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
)

type Compiler struct {
conf config.SQL
combo config.CombinedSettings
catalog *catalog.Catalog
parser Parser
result *Result
analyzer analyzer.Analyzer
client dbmanager.Client
selector selector
conf config.SQL
combo config.CombinedSettings
catalog *catalog.Catalog
parser Parser
result *Result
analyzer analyzer.Analyzer
client dbmanager.Client
selector selector
excludeFilter *ExcludeColumnsFilter

schema []string
}

func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, error) {
c := &Compiler{conf: conf, combo: combo}
c := &Compiler{
conf: conf,
combo: combo,
excludeFilter: NewExcludeColumnsFilter(conf.ExcludeColumns),
}

if conf.Database != nil && conf.Database.Managed {
client := dbmanager.NewClient(combo.Global.Servers)
Expand Down Expand Up @@ -103,3 +109,26 @@ func (c *Compiler) Close(ctx context.Context) {
c.client.Close(ctx)
}
}

// MarkExcludedColumns marks columns as excluded based on exclude_columns config.
// This should be called after ParseCatalog to set IsExcluded flag on catalog columns.
func (c *Compiler) MarkExcludedColumns() {
if c.excludeFilter == nil || len(c.conf.ExcludeColumns) == 0 {
return
}

for _, schema := range c.catalog.Schemas {
for _, table := range schema.Tables {
tableName := &ast.TableName{
Catalog: table.Rel.Catalog,
Schema: table.Rel.Schema,
Name: table.Rel.Name,
}
for _, col := range table.Columns {
if c.excludeFilter.ShouldExclude(tableName, col.Name) {
col.IsExcluded = true
}
}
}
}
}
48 changes: 48 additions & 0 deletions internal/compiler/exclude_columns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package compiler

import (
"strings"

"github.com/sqlc-dev/sqlc/internal/sql/ast"
)

// ExcludeColumnsFilter filters columns that should be excluded from star expansion.
type ExcludeColumnsFilter struct {
// excludeMap holds normalized (lowercase) column identifiers.
// Keys are in "table.column" or "schema.table.column" format.
excludeMap map[string]struct{}
}

// NewExcludeColumnsFilter creates a new ExcludeColumnsFilter from a list of column identifiers.
// Each identifier should be in "table.column" or "schema.table.column" format.
func NewExcludeColumnsFilter(excludeColumns []string) *ExcludeColumnsFilter {
m := make(map[string]struct{})
for _, col := range excludeColumns {
m[strings.ToLower(col)] = struct{}{}
}
return &ExcludeColumnsFilter{excludeMap: m}
}

// ShouldExclude returns true if the specified column should be excluded from star expansion.
// tableName should be the actual table name (not an alias).
func (f *ExcludeColumnsFilter) ShouldExclude(tableName *ast.TableName, columnName string) bool {
if f == nil || len(f.excludeMap) == 0 || tableName == nil {
return false
}

// Check "table.column" format
key := strings.ToLower(tableName.Name + "." + columnName)
if _, ok := f.excludeMap[key]; ok {
return true
}

// Check "schema.table.column" format
if tableName.Schema != "" {
key = strings.ToLower(tableName.Schema + "." + tableName.Name + "." + columnName)
if _, ok := f.excludeMap[key]; ok {
return true
}
}

return false
}
113 changes: 113 additions & 0 deletions internal/compiler/exclude_columns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package compiler

import (
"testing"

"github.com/sqlc-dev/sqlc/internal/sql/ast"
)

func TestExcludeColumnsFilter_ShouldExclude(t *testing.T) {
tests := []struct {
name string
excludeColumns []string
tableName *ast.TableName
columnName string
want bool
}{
{
name: "table.column format matches",
excludeColumns: []string{"users.password"},
tableName: &ast.TableName{Name: "users"},
columnName: "password",
want: true,
},
{
name: "table.column format does not match different column",
excludeColumns: []string{"users.password"},
tableName: &ast.TableName{Name: "users"},
columnName: "name",
want: false,
},
{
name: "table.column format does not match different table",
excludeColumns: []string{"users.password"},
tableName: &ast.TableName{Name: "admins"},
columnName: "password",
want: false,
},
{
name: "schema.table.column format matches",
excludeColumns: []string{"public.users.password"},
tableName: &ast.TableName{Schema: "public", Name: "users"},
columnName: "password",
want: true,
},
{
name: "schema.table.column format does not match different schema",
excludeColumns: []string{"public.users.password"},
tableName: &ast.TableName{Schema: "private", Name: "users"},
columnName: "password",
want: false,
},
{
name: "case insensitive match",
excludeColumns: []string{"Users.Password"},
tableName: &ast.TableName{Name: "users"},
columnName: "password",
want: true,
},
{
name: "case insensitive match with schema",
excludeColumns: []string{"Public.Users.Password"},
tableName: &ast.TableName{Schema: "public", Name: "users"},
columnName: "password",
want: true,
},
{
name: "empty filter excludes nothing",
excludeColumns: []string{},
tableName: &ast.TableName{Name: "users"},
columnName: "password",
want: false,
},
{
name: "nil table name excludes nothing",
excludeColumns: []string{"users.password"},
tableName: nil,
columnName: "password",
want: false,
},
{
name: "multiple exclusions",
excludeColumns: []string{"users.password", "users.ssn", "accounts.secret"},
tableName: &ast.TableName{Name: "users"},
columnName: "ssn",
want: true,
},
{
name: "table.column matches even when schema present in tableName",
excludeColumns: []string{"users.password"},
tableName: &ast.TableName{Schema: "public", Name: "users"},
columnName: "password",
want: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
filter := NewExcludeColumnsFilter(tt.excludeColumns)
got := filter.ShouldExclude(tt.tableName, tt.columnName)
if got != tt.want {
t.Errorf("ShouldExclude() = %v, want %v", got, tt.want)
}
})
}
}

func TestExcludeColumnsFilter_NilFilter(t *testing.T) {
var filter *ExcludeColumnsFilter
got := filter.ShouldExclude(&ast.TableName{Name: "users"}, "password")
if got != false {
t.Errorf("ShouldExclude() on nil filter = %v, want false", got)
}
}
Loading
Loading