@@ -167,8 +167,7 @@ type Query struct {
167167 Comments []string
168168
169169 // XXX: Hack
170- NeedsEdit bool
171- Filename string
170+ Filename string
172171}
173172
174173type Result struct {
@@ -289,7 +288,7 @@ func lineno(source string, head int) (int, int) {
289288func pluckQuery (source string , n nodes.RawStmt ) (string , error ) {
290289 head := n .StmtLocation
291290 tail := n .StmtLocation + n .StmtLen
292- return strings . TrimSpace ( source [head :tail ]) , nil
291+ return source [head :tail ], nil
293292}
294293
295294func rangeVars (root nodes.Node ) []nodes.RangeVar {
@@ -403,7 +402,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
403402 if err := validateFuncCall (& c , raw ); err != nil {
404403 return nil , err
405404 }
406- name , cmd , err := parseMetadata (rawSQL )
405+ name , cmd , err := parseMetadata (strings . TrimSpace ( rawSQL ) )
407406 if err != nil {
408407 return nil , err
409408 }
@@ -421,20 +420,23 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
421420 if err != nil {
422421 return nil , err
423422 }
423+ expanded , err := expand (c , raw , rawSQL )
424+ if err != nil {
425+ return nil , err
426+ }
424427
425- trimmed , comments , err := stripComments (rawSQL )
428+ trimmed , comments , err := stripComments (strings . TrimSpace ( expanded ) )
426429 if err != nil {
427430 return nil , err
428431 }
429432
430433 return & Query {
431- Cmd : cmd ,
432- Comments : comments ,
433- Name : name ,
434- Params : params ,
435- Columns : cols ,
436- SQL : trimmed ,
437- NeedsEdit : needsEdit (stmt ),
434+ Cmd : cmd ,
435+ Comments : comments ,
436+ Name : name ,
437+ Params : params ,
438+ Columns : cols ,
439+ SQL : trimmed ,
438440 }, nil
439441}
440442
@@ -454,6 +456,134 @@ func stripComments(sql string) (string, []string, error) {
454456 return strings .Join (lines , "\n " ), comments , s .Err ()
455457}
456458
459+ type edit struct {
460+ Location int
461+ Old string
462+ New string
463+ }
464+
465+ func expand (c core.Catalog , raw nodes.RawStmt , sql string ) (string , error ) {
466+ list := search (raw , func (node nodes.Node ) bool {
467+ switch node .(type ) {
468+ case nodes.DeleteStmt :
469+ case nodes.InsertStmt :
470+ case nodes.SelectStmt :
471+ case nodes.UpdateStmt :
472+ default :
473+ return false
474+ }
475+ return true
476+ })
477+ if len (list .Items ) == 0 {
478+ return sql , nil
479+ }
480+ var edits []edit
481+ for _ , item := range list .Items {
482+ edit , err := expandStmt (c , raw , item )
483+ if err != nil {
484+ return "" , err
485+ }
486+ edits = append (edits , edit ... )
487+ }
488+ return editQuery (sql , edits )
489+ }
490+
491+ func expandStmt (c core.Catalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
492+ tables , err := sourceTables (c , node )
493+ if err != nil {
494+ return nil , err
495+ }
496+
497+ var targets nodes.List
498+ switch n := node .(type ) {
499+ case nodes.DeleteStmt :
500+ targets = n .ReturningList
501+ case nodes.InsertStmt :
502+ targets = n .ReturningList
503+ case nodes.SelectStmt :
504+ targets = n .TargetList
505+ case nodes.UpdateStmt :
506+ targets = n .ReturningList
507+ default :
508+ return nil , fmt .Errorf ("outputColumns: unsupported node type: %T" , n )
509+ }
510+
511+ var edits []edit
512+ for _ , target := range targets .Items {
513+ res , ok := target .(nodes.ResTarget )
514+ if ! ok {
515+ continue
516+ }
517+ ref , ok := res .Val .(nodes.ColumnRef )
518+ if ! ok {
519+ continue
520+ }
521+ if ! HasStarRef (ref ) {
522+ continue
523+ }
524+ var parts , cols []string
525+ for _ , f := range ref .Fields .Items {
526+ switch field := f .(type ) {
527+ case nodes.String :
528+ parts = append (parts , field .Str )
529+ case nodes.A_Star :
530+ parts = append (parts , "*" )
531+ default :
532+ return nil , fmt .Errorf ("unknown field in ColumnRef: %T" , f )
533+ }
534+ }
535+ for _ , t := range tables {
536+ scope := join (ref .Fields , "." )
537+ if scope != "" && scope != t .Name {
538+ continue
539+ }
540+ for _ , c := range t .Columns {
541+ cname := c .Name
542+ if res .Name != nil {
543+ cname = * res .Name
544+ }
545+ if scope != "" {
546+ cname = scope + "." + cname
547+ }
548+ cols = append (cols , cname )
549+ }
550+ }
551+ edits = append (edits , edit {
552+ Location : res .Location - raw .StmtLocation ,
553+ Old : strings .Join (parts , "." ),
554+ New : strings .Join (cols , ", " ),
555+ })
556+ }
557+ return edits , nil
558+ }
559+
560+ func editQuery (raw string , a []edit ) (string , error ) {
561+ if len (a ) == 0 {
562+ return raw , nil
563+ }
564+ sort .Slice (a , func (i , j int ) bool { return a [i ].Location > a [j ].Location })
565+ s := raw
566+ for _ , edit := range a {
567+ start := edit .Location
568+ if start > len (s ) {
569+ return "" , fmt .Errorf ("edit start location is out of bounds" )
570+ }
571+ if len (edit .New ) <= 0 {
572+ return "" , fmt .Errorf ("empty edit contents" )
573+ }
574+ if len (edit .Old ) <= 0 {
575+ return "" , fmt .Errorf ("empty edit contents" )
576+ }
577+ stop := edit .Location + len (edit .Old ) - 1 // Assumes edit.New is non-empty
578+ if stop < len (s ) {
579+ s = s [:start ] + edit .New + s [stop + 1 :]
580+ } else {
581+ s = s [:start ] + edit .New
582+ }
583+ }
584+ return s , nil
585+ }
586+
457587type QueryCatalog struct {
458588 catalog core.Catalog
459589 ctes map [string ]core.Table
@@ -653,6 +783,7 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) {
653783
654784 case nodes.ColumnRef :
655785 if HasStarRef (n ) {
786+ // TODO: This code is copied in func expand()
656787 for _ , t := range tables {
657788 scope := join (n .Fields , "." )
658789 if scope != "" && scope != t .Name {
@@ -916,24 +1047,6 @@ func findParameters(root nodes.Node) []paramRef {
9161047 return refs
9171048}
9181049
919- type starWalker struct {
920- found bool
921- }
922-
923- func (s * starWalker ) Visit (node nodes.Node ) Visitor {
924- if _ , ok := node .(nodes.A_Star ); ok {
925- s .found = true
926- return nil
927- }
928- return s
929- }
930-
931- func needsEdit (root nodes.Node ) bool {
932- v := & starWalker {}
933- Walk (v , root )
934- return v .found
935- }
936-
9371050type nodeSearch struct {
9381051 list nodes.List
9391052 check func (nodes.Node ) bool
0 commit comments