diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 926f7a6..53f3fc7 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,7 +1,7 @@ # This workflow will build a golang project # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go -name: go-querysql-test +name: sqlcode on: pull_request: @@ -11,19 +11,16 @@ jobs: build: runs-on: ubuntu-latest - env: - SQLSERVER_DSN: "sqlserver://127.0.0.1:1433?database=master&user id=sa&password=VippsPw1" + strategy: + matrix: + driver: ['mssql'] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.23' - - - name: Start db - run: docker compose -f docker-compose.test.yml up -d + go-version: '1.25' - name: Test - # Skip the example folder because it has examples of what-not-to-do - run: go test -v $(go list ./... | grep -v './example') + run: docker compose -f docker-compose.${{ matrix.driver }}.yml run test \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e7e4e6c --- /dev/null +++ b/Makefile @@ -0,0 +1,4 @@ +test: test_mssql + +test_mssql: + docker compose --progress plain -f docker-compose.mssql.yml run test \ No newline at end of file diff --git a/cli/cmd/build.go b/cli/cmd/build.go index 1ffdde2..32e309e 100644 --- a/cli/cmd/build.go +++ b/cli/cmd/build.go @@ -3,8 +3,10 @@ package cmd import ( "errors" "fmt" + + mssql "github.com/microsoft/go-mssqldb" "github.com/spf13/cobra" - "github.com/vippsas/sqlcode" + "github.com/vippsas/sqlcode/v2" ) var ( @@ -23,7 +25,7 @@ var ( return err } - preprocessed, err := sqlcode.Preprocess(d.CodeBase, schemasuffix) + preprocessed, err := sqlcode.Preprocess(d.CodeBase, schemasuffix, &mssql.Driver{}) if err != nil { return err } diff --git a/cli/cmd/config.go b/cli/cmd/config.go index 6bebbf1..6968802 100644 --- a/cli/cmd/config.go +++ b/cli/cmd/config.go @@ -5,16 +5,17 
@@ import ( "database/sql" "errors" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/azuread" - "golang.org/x/net/proxy" "io/ioutil" "os" "path" "strings" - _ "github.com/denisenkom/go-mssqldb/azuread" - "github.com/denisenkom/go-mssqldb/msdsn" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/azuread" + "golang.org/x/net/proxy" + + _ "github.com/microsoft/go-mssqldb/azuread" + "github.com/microsoft/go-mssqldb/msdsn" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" ) diff --git a/cli/cmd/constants.go b/cli/cmd/constants.go index a71b364..90d071a 100644 --- a/cli/cmd/constants.go +++ b/cli/cmd/constants.go @@ -20,18 +20,18 @@ var ( if err != nil { return err } - if len(d.CodeBase.Creates) == 0 && len(d.CodeBase.Declares) == 0 { + if d.CodeBase.Empty() { fmt.Println("No SQL code found in given paths") } - if len(d.CodeBase.Errors) > 0 { + if d.CodeBase.HasErrors() { fmt.Println("Errors:") - for _, e := range d.CodeBase.Errors { + for _, e := range d.CodeBase.Errors() { fmt.Printf("%s:%d:%d: %s\n", e.Pos.File, e.Pos.Line, e.Pos.Line, e.Message) } return nil } fmt.Println("declare") - for i, c := range d.CodeBase.Declares { + for i, c := range d.CodeBase.Declares() { var prefix string if i == 0 { prefix = " " diff --git a/cli/cmd/dep.go b/cli/cmd/dep.go index c0d110c..0317937 100644 --- a/cli/cmd/dep.go +++ b/cli/cmd/dep.go @@ -6,7 +6,7 @@ import ( "os" "github.com/spf13/cobra" - "github.com/vippsas/sqlcode" + "github.com/vippsas/sqlcode/v2" ) func dep(partialParseResults bool) (d sqlcode.Deployable, err error) { @@ -36,16 +36,16 @@ var ( fmt.Println() err = nil } - if len(d.CodeBase.Creates) == 0 && len(d.CodeBase.Declares) == 0 { + if d.CodeBase.Empty() { fmt.Println("No SQL code found in given paths") } - if len(d.CodeBase.Errors) > 0 { + if d.CodeBase.HasErrors() { fmt.Println("Errors:") - for _, e := range d.CodeBase.Errors { + for _, e := range d.CodeBase.Errors() { fmt.Printf("%s:%d:%d: %s\n", 
e.Pos.File, e.Pos.Line, e.Pos.Line, e.Message) } } - for _, c := range d.CodeBase.Creates { + for _, c := range d.CodeBase.Creates() { fmt.Println(c.QuotedName.String() + ":") if len(c.DependsOn) > 0 { fmt.Println(" Uses:") diff --git a/cli/main.go b/cli/main.go index df6a89a..042c27a 100644 --- a/cli/main.go +++ b/cli/main.go @@ -1,10 +1,11 @@ package main import ( - "github.com/vippsas/sqlcode/cli/cmd" "math/rand" "os" "time" + + "github.com/vippsas/sqlcode/v2/cli/cmd" ) func main() { diff --git a/dbintf.go b/dbintf.go index 8257e11..1495942 100644 --- a/dbintf.go +++ b/dbintf.go @@ -3,6 +3,7 @@ package sqlcode import ( "context" "database/sql" + "database/sql/driver" ) type DB interface { @@ -11,6 +12,7 @@ type DB interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row Conn(ctx context.Context) (*sql.Conn, error) BeginTx(ctx context.Context, txOptions *sql.TxOptions) (*sql.Tx, error) + Driver() driver.Driver } var _ DB = &sql.DB{} diff --git a/dbops.go b/dbops.go index 05e5a88..d63278f 100644 --- a/dbops.go +++ b/dbops.go @@ -3,11 +3,26 @@ package sqlcode import ( "context" "database/sql" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" ) func Exists(ctx context.Context, dbc DB, schemasuffix string) (bool, error) { var schemaID int - err := dbc.QueryRowContext(ctx, `select isnull(schema_id(@p1), 0)`, SchemaName(schemasuffix)).Scan(&schemaID) + + driver := dbc.Driver() + var qs string + + if _, ok := driver.(*mssql.Driver); ok { + qs = `select isnull(schema_id(@p1), 0)` + } + if _, ok := driver.(*stdlib.Driver); ok { + qs = `select coalesce((select oid from pg_namespace where nspname = $1),0)` + } + + err := dbc.QueryRowContext(ctx, qs, SchemaName(schemasuffix)).Scan(&schemaID) if err != nil { return false, err } @@ -19,8 +34,24 @@ func Drop(ctx context.Context, dbc DB, schemasuffix string) error { if err != nil { return err } - _, err = tx.ExecContext(ctx, 
`sqlcode.DropCodeSchema`, - sql.Named("schemasuffix", schemasuffix)) + + var qs string + var arg = []interface{}{} + driver := dbc.Driver() + + if _, ok := driver.(*mssql.Driver); ok { + qs = `sqlcode.DropCodeSchema` + arg = []interface{}{sql.Named("schemasuffix", schemasuffix)} + } + + if _, ok := dbc.Driver().(*stdlib.Driver); ok { + qs = `call sqlcode.dropcodeschema(@schemasuffix)` + arg = []interface{}{ + pgx.NamedArgs{"schemasuffix": schemasuffix}, + } + } + + _, err = tx.ExecContext(ctx, qs, arg...) if err != nil { _ = tx.Rollback() return err diff --git a/deployable.go b/deployable.go index 7e2b178..c661785 100644 --- a/deployable.go +++ b/deployable.go @@ -10,14 +10,18 @@ import ( "strings" "time" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/vippsas/sqlcode/sqlparser" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + pgxstdlib "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" + "github.com/vippsas/sqlcode/v2/sqlparser" + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" ) type Deployable struct { SchemaSuffix string ParsedFiles []string // mainly for use in error messages etc - CodeBase sqlparser.Document + CodeBase sqldocument.Document // cache over whether it has been uploaded to a given DB // (the same physical DB can be in this map multiple times under @@ -77,24 +81,25 @@ func impersonate(ctx context.Context, dbc DB, username string, f func(conn *sql. // Upload will create and upload the schema; resulting in an error // if the schema already exists func (d *Deployable) Upload(ctx context.Context, dbc DB) error { - // First, impersonate a user with minimal privileges to get at least - // some level of sandboxing so that migration scripts can't do anything - // the caller didn't expect them to. 
- return impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", func(conn *sql.Conn) error { + driver := dbc.Driver() + qs := make(map[string][]interface{}, 1) + + var uploadFunc = func(conn *sql.Conn) error { tx, err := conn.BeginTx(ctx, nil) if err != nil { return err } - _, err = tx.ExecContext(ctx, `sqlcode.CreateCodeSchema`, - sql.Named("schemasuffix", d.SchemaSuffix), - ) - if err != nil { - _ = tx.Rollback() - return err + for q, args := range qs { + _, err = tx.ExecContext(ctx, q, args...) + + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to execute (%s) with arg(%s) in schema %s: %w", q, args, d.SchemaSuffix, err) + } } - preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix) + preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix, dbc.Driver()) if err != nil { _ = tx.Rollback() return err @@ -103,15 +108,16 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { _, err := tx.ExecContext(ctx, b.Lines) if err != nil { _ = tx.Rollback() - sqlerr, ok := err.(mssql.Error) - if !ok { - return err - } else { - return SQLUserError{ + if sqlerr, ok := err.(mssql.Error); ok { + return MSSQLUserError{ Wrapped: sqlerr, Batch: b, } } + + // TODO(ks) PGSQLUserError + return fmt.Errorf("failed to upload deployable:%s in schema:%s:%w", d.CodeBase, d.SchemaSuffix, err) + } } err = tx.Commit() @@ -123,8 +129,36 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { return nil - }) + } + + if _, ok := driver.(*mssql.Driver); ok { + // First, impersonate a user with minimal privileges to get at least + // some level of sandboxing so that migration scripts can't do anything + // the caller didn't expect them to. 
+ qs["sqlcode.CreateCodeSchema"] = []interface { + }{ + sql.Named("schemasuffix", d.SchemaSuffix), + } + + return impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", uploadFunc) + } + + if _, ok := driver.(*stdlib.Driver); ok { + qs[`set role "sqlcode-deploy-sandbox-user"`] = nil + qs[`call sqlcode.createcodeschema(@schemasuffix)`] = []interface{}{ + pgx.NamedArgs{"schemasuffix": d.SchemaSuffix}, + } + conn, err := dbc.Conn(ctx) + if err != nil { + return err + } + defer func() { + _ = conn.Close() + }() + return uploadFunc(conn) + } + return fmt.Errorf("failed to determine sql driver to upload schema: %s", d.SchemaSuffix) } // EnsureUploaded checks that the schema with the suffix already exists, @@ -137,36 +171,51 @@ func (d *Deployable) EnsureUploaded(ctx context.Context, dbc DB) error { return nil } + driver := dbc.Driver() lockResourceName := "sqlcode.EnsureUploaded/" + d.SchemaSuffix + var lockRetCode int + var lockQs string + var unlockQs string + var err error + // When a lock is opened with the Transaction lock owner, // that lock is released when the transaction is committed or rolled back. 
- var lockRetCode int - err := dbc.QueryRowContext(ctx, ` -declare @retcode int; -exec @retcode = sp_getapplock @Resource = @resource, @LockMode = 'Shared', @LockOwner = 'Session', @LockTimeout = @timeoutMs; -select @retcode; -`, - sql.Named("resource", lockResourceName), - sql.Named("timeoutMs", 20000), - ).Scan(&lockRetCode) + if _, ok := driver.(*pgxstdlib.Driver); ok { + lockQs = `select sqlcode.get_applock(@resource, @timeoutMs)` + unlockQs = `select sqlcode.release_applock(@resource)` + + err = dbc.QueryRowContext(ctx, lockQs, pgx.NamedArgs{ + "resource": lockResourceName, + "timeoutMs": 20000, + }).Scan(&lockRetCode) + + defer func() { + dbc.ExecContext(ctx, unlockQs, pgx.NamedArgs{"resource": lockResourceName}) + }() + } + + if _, ok := driver.(*mssql.Driver); ok { + // TODO + + defer func() { + // TODO: This returns an error if the lock is already released + _, _ = dbc.ExecContext(ctx, unlockQs, + sql.Named("Resource", lockResourceName), + sql.Named("LockOwner", "Session"), + ) + }() + } + if err != nil { return err } if lockRetCode < 0 { return errors.New("was not able to get lock before timeout") } - - defer func() { - _, _ = dbc.ExecContext(ctx, `sp_releaseapplock`, - sql.Named("Resource", lockResourceName), - sql.Named("LockOwner", "Session"), - ) - }() - exists, err := Exists(ctx, dbc, d.SchemaSuffix) if err != nil { - return err + return fmt.Errorf("unable to determine if schema %s exists: %w", d.SchemaSuffix, err) } if exists { @@ -195,11 +244,28 @@ func (d Deployable) DropAndUpload(ctx context.Context, dbc DB) error { } // Patch will preprocess the sql passed in so that it will call SQL code -// deployed by the receiver Deployable +// deployed by the receiver Deployable for SQL Server. +// NOTE: This will be deprecated and eventually replaced with CodePatch. 
func (d Deployable) Patch(sql string) string { return preprocessString(d.SchemaSuffix, sql) } +// CodePatch will preprocess the sql passed in to call +// the correct SQL code deployed to the provided database. +// Q: Nameing? DBPatch, PatchV2, ??? +func (d Deployable) CodePatch(dbc *sql.DB, sql string) string { + driver := dbc.Driver() + if _, ok := driver.(*mssql.Driver); ok { + return codeSchemaRegexp.ReplaceAllString(sql, fmt.Sprintf(`[code@%s]`, d.SchemaSuffix)) + } + + if _, ok := driver.(*stdlib.Driver); ok { + return codeSchemaRegexp.ReplaceAllString(sql, fmt.Sprintf(`"code@%s"`, d.SchemaSuffix)) + } + + panic("unhandled sql driver") +} + func (d *Deployable) markAsUploaded(dbc DB) { d.uploaded[dbc] = struct{}{} } @@ -211,9 +277,8 @@ func (d *Deployable) IsUploadedFromCache(dbc DB) bool { // TODO: StringConst. This requires parsing a SQL literal, a bit too complex // to code up just-in-case - func (d Deployable) IntConst(s string) (int, error) { - for _, declare := range d.CodeBase.Declares { + for _, declare := range d.CodeBase.Declares() { if declare.VariableName == s { // TODO: more robust integer SQL parsing than this; only works // in most common cases @@ -247,8 +312,8 @@ type Options struct { func Include(opts Options, fsys ...fs.FS) (result Deployable, err error) { parsedFiles, doc, err := sqlparser.ParseFilesystems(fsys, opts.IncludeTags) - if len(doc.Errors) > 0 && !opts.PartialParseResults { - return Deployable{}, SQLCodeParseErrors{Errors: doc.Errors} + if doc.HasErrors() && !opts.PartialParseResults { + return Deployable{}, SQLCodeParseErrors{Errors: doc.Errors()} } result.CodeBase = doc @@ -280,10 +345,28 @@ func (s *SchemaObject) Suffix() string { // Return a list of sqlcode schemas that have been uploaded to the database. // This includes all current and unused schemas. 
-func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) []*SchemaObject { +func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) ([]*SchemaObject, error) { objects := []*SchemaObject{} - impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", func(conn *sql.Conn) error { - rows, err := conn.QueryContext(ctx, ` + driver := dbc.Driver() + var qs string + + var list = func(conn *sql.Conn) error { + rows, err := conn.QueryContext(ctx, qs) + if err != nil { + return err + } + + for rows.Next() { + zero := &SchemaObject{} + rows.Scan(&zero.Name, &zero.Objects, &zero.SchemaId, &zero.CreateDate, &zero.ModifyDate) + objects = append(objects, zero) + } + + return nil + } + + if _, ok := driver.(*mssql.Driver); ok { + qs = ` select s.name , s.schema_id @@ -298,18 +381,32 @@ func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) []*SchemaObject { from sys.objects o where o.schema_id = s.schema_id ) as o - where s.name like 'code@%'`) + where s.name like 'code@%'` + impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", list) + } + + // TODO(ks) the timestamps for schemas + if _, ok := driver.(*stdlib.Driver); ok { + qs = `select nspname as name + , oid as schema_id + , 0 as objects + , '' as create_date + , '' as modify_date + from pg_namespace + where nspname like 'code@%' + order by nspname` + conn, err := dbc.Conn(ctx) if err != nil { - return err + return nil, err } - - for rows.Next() { - zero := &SchemaObject{} - rows.Scan(&zero.Name, &zero.Objects, &zero.SchemaId, &zero.CreateDate, &zero.ModifyDate) - objects = append(objects, zero) + err = list(conn) + if err != nil { + return nil, err } + defer func() { + _ = conn.Close() + }() + } - return nil - }) - return objects + return objects, nil } diff --git a/deployable_test.go b/deployable_test.go index 1e9dac5..7b87b57 100644 --- a/deployable_test.go +++ b/deployable_test.go @@ -21,5 +21,11 @@ declare @EnumInt int = 1, @EnumString varchar(max) = 'hello'; n, err := d.IntConst("@EnumInt") 
require.NoError(t, err) assert.Equal(t, 1, n) +} +func TestPatch(t *testing.T) { + t.Run("mssql schemasuffix", func(t *testing.T) { + d := Deployable{} + require.Equal(t, "[code@].Foo", d.Patch("[code].Foo")) + }) } diff --git a/docker-compose.mssql.yml b/docker-compose.mssql.yml new file mode 100644 index 0000000..06105a3 --- /dev/null +++ b/docker-compose.mssql.yml @@ -0,0 +1,27 @@ +services: + mssql: + image: mcr.microsoft.com/mssql/server:2022-CU18-ubuntu-22.04 + networks: + - mssql + environment: + ACCEPT_EULA: "Y" + SA_PASSWORD: VippsPw1 + healthcheck: + test: ["CMD", "/opt/mssql-tools18/bin/sqlcmd", "-C", "-Usa", "-PVippsPw1", "-Q", "select 1"] + interval: 1s + retries: 20 + test: + build: + context: . + no_cache: true + dockerfile: dockerfile.test + networks: + - mssql + environment: + SQLSERVER_DSN: sqlserver://mssql:1433?database=master&user id=sa&password=VippsPw1 + GODEBUG: "x509negativeserial=1" + depends_on: + mssql: + condition: service_healthy +networks: + mssql: diff --git a/docker-compose.test.yml b/docker-compose.test.yml deleted file mode 100644 index 94da3ec..0000000 --- a/docker-compose.test.yml +++ /dev/null @@ -1,15 +0,0 @@ -services: - # - # mssql - # - mssql: - image: mcr.microsoft.com/mssql/server:latest - - hostname: mssql - container_name: mssql - network_mode: bridge - ports: - - "1433:1433" - environment: - ACCEPT_EULA: "Y" - SA_PASSWORD: VippsPw1 diff --git a/dockerfile.test b/dockerfile.test new file mode 100644 index 0000000..dbd2061 --- /dev/null +++ b/dockerfile.test @@ -0,0 +1,5 @@ +FROM golang:1.25 AS builder +WORKDIR /sqlcode +COPY . . 
+RUN go mod tidy +CMD ["go", "test", "-v", "./..."] \ No newline at end of file diff --git a/example/basic/example.go b/example/basic/example.go index abe1194..9406915 100644 --- a/example/basic/example.go +++ b/example/basic/example.go @@ -1,3 +1,6 @@ +//go:build examples +// +build examples + package example import ( diff --git a/example/basic/example_test.go b/example/basic/example_test.go index 079bd91..e03c582 100644 --- a/example/basic/example_test.go +++ b/example/basic/example_test.go @@ -1,13 +1,17 @@ +//go:build examples +// +build examples + package example import ( "context" "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/vippsas/sqlcode/sqltest" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vippsas/sqlcode/v2/sqltest" ) func TestPreprocess(t *testing.T) { diff --git a/go.mod b/go.mod index e497cd0..3ef2f55 100644 --- a/go.mod +++ b/go.mod @@ -1,31 +1,42 @@ -module github.com/vippsas/sqlcode +module github.com/vippsas/sqlcode/v2 go 1.24.3 require ( github.com/alecthomas/repr v0.5.2 - github.com/denisenkom/go-mssqldb v0.12.3 github.com/gofrs/uuid v4.4.0+incompatible + github.com/jackc/pgx/v5 v5.7.6 + github.com/microsoft/go-mssqldb v1.9.5 github.com/sirupsen/logrus v1.9.3 github.com/smasher164/xid v0.1.2 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 + github.com/vippsas/sqlcode v1.1.0 golang.org/x/net v0.48.0 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // 
indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/pflag v1.0.9 // indirect golang.org/x/crypto v0.46.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.32.0 // indirect ) diff --git a/go.sum b/go.sum index 480c669..7902f8f 100644 --- a/go.sum +++ b/go.sum @@ -1,34 +1,70 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 h1:lhSJz9RMbJcTgxifR1hUNJnn6CNYtbgEDtQV22/9RBA= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 h1:OYa9vmRX2XC5GXRAzeggG12sF/z5D9Ahtdm9EJ00WN4= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 h1:v9p9TfTbf7AwNb5NYQt7hI41IfPoLFiFkLtb+bmGjT0= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4= 
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= 
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw= -github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= -github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod 
h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= -github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0= +github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= 
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= +github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smasher164/xid v0.1.2 h1:erplXSdBRIIw+MrwjJ/m8sLN2XY16UGzpTA0E2Ru6HA= @@ -38,41 +74,30 @@ github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiT github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/vippsas/sqlcode v1.1.0 h1:ExW73SqJcCC6m98+20XQ50T7Wqar13eNc9g9lfqHFcU= +github.com/vippsas/sqlcode 
v1.1.0/go.mod h1:DrqItRntmb9OsgqBO63j63geFoEw4LFIm3m2O8dxM8Y= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/error.go b/mssql_error.go similarity index 79% rename from error.go rename to mssql_error.go index 6131fbf..aaa971b 100644 --- a/error.go +++ b/mssql_error.go @@ -3,17 +3,18 @@ package sqlcode import ( "bytes" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - 
"github.com/vippsas/sqlcode/sqlparser" "strings" + + mssql "github.com/microsoft/go-mssqldb" + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" ) -type SQLUserError struct { +type MSSQLUserError struct { Wrapped mssql.Error Batch Batch } -func (s SQLUserError) Error() string { +func (s MSSQLUserError) Error() string { var buf bytes.Buffer if _, fmterr := fmt.Fprintf(&buf, "\n"); fmterr != nil { @@ -32,7 +33,7 @@ func (s SQLUserError) Error() string { } type SQLCodeParseErrors struct { - Errors []sqlparser.Error + Errors []sqldocument.Error } func (e SQLCodeParseErrors) Error() string { diff --git a/preprocess.go b/preprocess.go index 5a8adab..f921f91 100644 --- a/preprocess.go +++ b/preprocess.go @@ -2,20 +2,24 @@ package sqlcode import ( "crypto/sha256" + "database/sql/driver" "encoding/hex" "errors" "fmt" - "github.com/vippsas/sqlcode/sqlparser" + "reflect" "regexp" "strings" + + "github.com/vippsas/sqlcode/v2/sqlparser" + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" ) -func SchemaSuffixFromHash(doc sqlparser.Document) string { +func SchemaSuffixFromHash(doc sqldocument.Document) string { hasher := sha256.New() - for _, dec := range doc.Declares { + for _, dec := range doc.Declares() { hasher.Write([]byte(dec.String() + "\n")) } - for _, c := range doc.Creates { + for _, c := range doc.Creates() { if err := c.SerializeBytes(hasher); err != nil { panic(err) // asserting that sha256 will never return a write error... 
} @@ -41,7 +45,7 @@ type lineNumberCorrection struct { } type Batch struct { - StartPos sqlparser.Pos + StartPos sqldocument.Pos Lines string // lineNumberCorrections contains data that helps us map from errors in the `Lines` @@ -86,7 +90,7 @@ type PreprocessedFile struct { } type PreprocessorError struct { - Pos sqlparser.Pos + Pos sqldocument.Pos Message string } @@ -96,7 +100,7 @@ func (p PreprocessorError) Error() string { var codeSchemaRegexp = regexp.MustCompile(`(?i)\[code\]`) -func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quotedTargetSchema string) (result Batch, err error) { +func sqlcodeTransformCreate(declares map[string]string, c sqldocument.Create, quotedTargetSchema string) (result Batch, err error) { var w strings.Builder if len(c.Body) > 0 { @@ -110,12 +114,14 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot // A @Enum replacement can lead to line numbers changing due to \n present in the literal. // For this reason we need to make a mapping between source line numbers and result // line numbers + // + // TODO: The sqldocument should be responsible for transforming itself, not this function. 
 for _, u := range c.Body { token := u.RawValue switch { - case u.Type == sqlparser.QuotedIdentifierToken && u.RawValue == "[code]": + case u.Type == sqldocument.QuotedIdentifierToken && u.RawValue == "[code]": token = quotedTargetSchema - case u.Type == sqlparser.VariableIdentifierToken && sqlparser.IsSqlcodeConstVariable(u.RawValue): + case u.Type == sqldocument.VariableIdentifierToken && sqlparser.IsSqlcodeConstVariable(u.RawValue): constLiteral, ok := declares[u.RawValue] if !ok { err = PreprocessorError{u.Start, fmt.Sprintf("sqlcode constant `%s` not declared", u.RawValue)} @@ -128,7 +134,6 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot result.lineNumberCorrections = append(result.lineNumberCorrections, lineNumberCorrection{relativeLine, newlineCount}) } } - if _, err = w.WriteString(token); err != nil { return } @@ -138,7 +143,7 @@ return } -func Preprocess(doc sqlparser.Document, schemasuffix string) (PreprocessedFile, error) { +func Preprocess(doc sqldocument.Document, schemasuffix string, driver driver.Driver) (PreprocessedFile, error) { var result PreprocessedFile if strings.Contains(schemasuffix, "]") { @@ -146,17 +151,27 @@ func Preprocess(doc sqlparser.Document, schemasuffix string) (PreprocessedFile, } declares := make(map[string]string) - for _, dec := range doc.Declares { + for _, dec := range doc.Declares() { declares[dec.VariableName] = dec.Literal.RawValue } - for _, create := range doc.Creates { + // The current sql driver that we are preparing for + currentDriver := reflect.TypeOf(driver) + + // the default target for mssql + target := fmt.Sprintf(`[code@%s]`, schemasuffix) + + for _, create := range doc.Creates() { if len(create.Body) == 0 { continue } - batch, err := sqlcodeTransformCreate(declares, create, "[code@"+schemasuffix+"]") + if !currentDriver.AssignableTo(reflect.TypeOf(create.Driver)) { + // this batch is for a 
different sql driver + continue + } + batch, err := sqlcodeTransformCreate(declares, create, target) if err != nil { - return result, err + return result, fmt.Errorf("failed to transform create: %w", err) } result.Batches = append(result.Batches, batch) } diff --git a/preprocess_test.go b/preprocess_test.go index bf976e8..ed62b5e 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -1,22 +1,21 @@ package sqlcode import ( + "strings" "testing" + mssql "github.com/microsoft/go-mssqldb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/vippsas/sqlcode/sqlparser" + "github.com/vippsas/sqlcode/v2/sqlparser" + mssql17 "github.com/vippsas/sqlcode/v2/sqlparser/mssql" + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" ) -func TestSchemaSuffixFromHash(t *testing.T) { - t.Run("returns a unique hash", func(t *testing.T) { - doc := sqlparser.Document{ - Declares: []sqlparser.Declare{}, - } - - value := SchemaSuffixFromHash(doc) - require.Equal(t, value, SchemaSuffixFromHash(doc)) - }) +func ParseString(t *testing.T, file, input string) *mssql17.TSqlDocument { + d := &mssql17.TSqlDocument{} + assert.NoError(t, d.Parse([]byte(input), sqldocument.FileRef(file))) + return d } func TestLineNumberInInput(t *testing.T) { @@ -63,3 +62,318 @@ func TestLineNumberInInput(t *testing.T) { } assert.Equal(t, expectedInputLineNumbers, inputlines[1:]) } + +func TestSchemaSuffixFromHash(t *testing.T) { + t.Run("returns a unique hash", func(t *testing.T) { + doc := sqlparser.NewDocumentFromExtension(".sql") + value := SchemaSuffixFromHash(doc) + require.Equal(t, value, SchemaSuffixFromHash(doc)) + }) + + t.Run("returns consistent hash", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +declare @EnumFoo int = 1; +go +create procedure [code].Test as begin end +`) + + suffix1 := SchemaSuffixFromHash(doc) + suffix2 := SchemaSuffixFromHash(doc) + + assert.Equal(t, suffix1, suffix2) + assert.Len(t, suffix1, 12) // 6 bytes = 12 hex chars + }) 
+ + t.Run("different content yields different hash", func(t *testing.T) { + doc1 := ParseString(t, "test.sql", ` +declare @EnumFoo int = 1; +go +create procedure [code].Test1 as begin end +`) + doc2 := ParseString(t, "test.sql", ` +declare @EnumFoo int = 2; +go +create procedure [code].Test2 as begin end +`) + + suffix1 := SchemaSuffixFromHash(doc1) + suffix2 := SchemaSuffixFromHash(doc2) + + assert.NotEqual(t, suffix1, suffix2) + }) +} + +func TestSchemaName(t *testing.T) { + assert.Equal(t, "code@abc123", SchemaName("abc123")) + assert.Equal(t, "code@", SchemaName("")) +} + +func TestBatchLineNumberInInput(t *testing.T) { + t.Run("no corrections", func(t *testing.T) { + b := Batch{ + StartPos: sqldocument.Pos{Line: 10, Col: 1}, + Lines: "line1\nline2\nline3", + } + + assert.Equal(t, 10, b.LineNumberInInput(1)) + assert.Equal(t, 11, b.LineNumberInInput(2)) + assert.Equal(t, 12, b.LineNumberInInput(3)) + }) + + t.Run("with corrections", func(t *testing.T) { + b := Batch{ + StartPos: sqldocument.Pos{Line: 10, Col: 1}, + Lines: "line1\nline2\nextra1\nextra2\nline3", + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 2, extraLinesInOutput: 2}, // line 2 became 3 lines + }, + } + + assert.Equal(t, 10, b.LineNumberInInput(1)) // line 1 -> input line 10 + assert.Equal(t, 11, b.LineNumberInInput(2)) // line 2 -> input line 11 + assert.Equal(t, 11, b.LineNumberInInput(3)) // extra line -> still input line 11 + assert.Equal(t, 11, b.LineNumberInInput(4)) // extra line -> still input line 11 + assert.Equal(t, 12, b.LineNumberInInput(5)) // line 3 -> input line 12 + }) +} + +func TestBatchRelativeLineNumberInInput(t *testing.T) { + t.Run("simple case with no corrections", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{}, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(5)) + }) + + t.Run("single correction", func(t *testing.T) { + b := Batch{ + 
lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 3, extraLinesInOutput: 2}, + }, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(2)) + assert.Equal(t, 3, b.RelativeLineNumberInInput(3)) + assert.Equal(t, 3, b.RelativeLineNumberInInput(4)) // extra line + assert.Equal(t, 3, b.RelativeLineNumberInInput(5)) // extra line + assert.Equal(t, 4, b.RelativeLineNumberInInput(6)) + }) + + t.Run("multiple corrections", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 2, extraLinesInOutput: 1}, + {inputLineNumber: 5, extraLinesInOutput: 3}, + }, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(2)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(3)) // extra from line 2 + assert.Equal(t, 3, b.RelativeLineNumberInInput(4)) + assert.Equal(t, 4, b.RelativeLineNumberInInput(5)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(6)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(7)) // extra from line 5 + assert.Equal(t, 5, b.RelativeLineNumberInInput(8)) // extra from line 5 + assert.Equal(t, 5, b.RelativeLineNumberInInput(9)) // extra from line 5 + assert.Equal(t, 6, b.RelativeLineNumberInInput(10)) + }) +} + +func TestPreprocess(t *testing.T) { + t.Run("basic procedure with schema replacement", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +create procedure [code].Test as +begin + select 1 +end +`) + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + assert.Contains(t, result.Batches[0].Lines, "[code@abc123].") + assert.NotContains(t, result.Batches[0].Lines, "[code].") + }) + + t.Run("replaces enum constants", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +declare @EnumStatus int = 42; +go +create procedure [code].Test as +begin + select @EnumStatus +end +`) + result, err := Preprocess(doc, 
"abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "42/*=@EnumStatus*/") + assert.NotContains(t, batch, "@EnumStatus\n") // shouldn't have bare reference + }) + + t.Run("handles multiline string constants", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +declare @EnumMulti nvarchar(max) = N'line1 +line2 +line3'; +go +create procedure [code].Test as +begin + select @EnumMulti +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0] + assert.Contains(t, batch.Lines, "N'line1\nline2\nline3'/*=@EnumMulti*/") + // Should have line number corrections for the 2 extra lines + assert.Len(t, batch.lineNumberCorrections, 1) + assert.Equal(t, 2, batch.lineNumberCorrections[0].extraLinesInOutput) + }) + + t.Run("error on undeclared constant", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +create procedure [code].Test as +begin + select @EnumUndeclared +end +`) + _, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.Error(t, err) + + var preprocErr PreprocessorError + require.ErrorAs(t, err, &preprocErr) + assert.Contains(t, preprocErr.Message, "@EnumUndeclared") + assert.Contains(t, preprocErr.Message, "not declared") + }) + + t.Run("error on schema suffix with bracket", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +create procedure [code].Test as begin end +`) + _, err := Preprocess(doc, "abc]123", &mssql.Driver{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "schemasuffix cannot contain") + }) + + t.Run("handles multiple creates", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +create procedure [code].Proc1 as begin select 1 end +go +create procedure [code].Proc2 as begin select 2 end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + assert.Len(t, 
result.Batches, 2) + + assert.Contains(t, result.Batches[0].Lines, "Proc1") + assert.Contains(t, result.Batches[1].Lines, "Proc2") + }) + + t.Run("handles multiple constants in same procedure", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +declare @EnumA int = 1, @EnumB int = 2; +go +create procedure [code].Test as +begin + select @EnumA, @EnumB +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "1/*=@EnumA*/") + assert.Contains(t, batch, "2/*=@EnumB*/") + }) + + t.Run("preserves comments and formatting", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +-- This is a test procedure +create procedure [code].Test as +begin + /* multi + line + comment */ + select 1 +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "-- This is a test procedure") + assert.Contains(t, batch, "/* multi") + }) + + t.Run("handles const and global prefixes", func(t *testing.T) { + doc := ParseString(t, "test.sql", ` +declare @ConstValue int = 100; +declare @GlobalSetting nvarchar(50) = N'test'; +go +create procedure [code].Test as +begin + select @ConstValue, @GlobalSetting +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "100/*=@ConstValue*/") + assert.NotContains(t, batch, "N'test'/*=@GlobalSetting*/") + }) +} + +func TestPreprocessString(t *testing.T) { + t.Run("replaces code schema", func(t *testing.T) { + result := preprocessString("abc123", "select * from [code].Table") + assert.Equal(t, "select * from [code@abc123].Table", result) + }) + + t.Run("case insensitive replacement", func(t *testing.T) { + result := 
preprocessString("abc123", "select * from [CODE].Table and [Code].Other") + assert.Contains(t, result, "[code@abc123].Table") + assert.Contains(t, result, "[code@abc123].Other") + }) + + t.Run("multiple occurrences", func(t *testing.T) { + sql := ` + select * from [code].A + join [code].B on A.id = B.id + where exists (select 1 from [code].C) + ` + result := preprocessString("abc123", sql) + assert.Equal(t, 3, strings.Count(result, "[code@abc123]")) + assert.NotContains(t, result, "[code].") + }) + + t.Run("no replacement needed", func(t *testing.T) { + sql := "select * from dbo.Table" + result := preprocessString("abc123", sql) + assert.Equal(t, sql, result) + }) +} + +func TestPreprocessorError(t *testing.T) { + t.Run("formats error message", func(t *testing.T) { + err := PreprocessorError{ + Pos: sqldocument.Pos{File: "test.sql", Line: 10, Col: 5}, + Message: "something went wrong", + } + + assert.Equal(t, "test.sql:10:5: something went wrong", err.Error()) + }) +} diff --git a/sqlcode.yaml b/sqlcode.yaml index 549c23f..d979dd1 100644 --- a/sqlcode.yaml +++ b/sqlcode.yaml @@ -1,7 +1,7 @@ databases: - localtest: - connection: sqlserver://localhost:1433?database=foo&user id=foouser&password=FooPasswd1 - + mssql: + connection: sqlserver://mssql:1433?database=foo&user id=foouser&password=FooPasswd1 + # One option is to list other paths to include ('dependencies') here. 
# Commands to set up for testing with credentials above: diff --git a/sqlparser/dom.go b/sqlparser/dom.go deleted file mode 100644 index cc661f4..0000000 --- a/sqlparser/dom.go +++ /dev/null @@ -1,239 +0,0 @@ -package sqlparser - -import ( - "fmt" - "gopkg.in/yaml.v3" - "io" - "strings" -) - -type Unparsed struct { - Type TokenType - Start, Stop Pos - RawValue string -} - -func (u Unparsed) WithoutPos() Unparsed { - return Unparsed{ - Type: u.Type, - Start: Pos{}, - Stop: Pos{}, - RawValue: u.RawValue, - } -} - -type Declare struct { - Start Pos - Stop Pos - VariableName string - Datatype Type - Literal Unparsed -} - -func (d Declare) String() string { - // silly thing just meant for use for hashing and debugging, not legal SQL.. - return fmt.Sprintf("declare %s %s(%s) = %s", - d.VariableName, - d.Datatype.BaseType, - strings.Join(d.Datatype.Args, ","), - d.Literal.RawValue) -} - -func (d Declare) WithoutPos() Declare { - return Declare{ - Start: Pos{}, - Stop: Pos{}, - VariableName: d.VariableName, - Datatype: d.Datatype, - Literal: d.Literal.WithoutPos(), - } -} - -// A string that has a Pos-ition in a source document -type PosString struct { - Pos - Value string -} - -func (p PosString) String() string { - return p.Value -} - -type Create struct { - CreateType string // "procedure", "function" or "type" - QuotedName PosString // proc/func/type name, including [] - Body []Unparsed - DependsOn []PosString - Docstring []PosString // comment lines before the create statement. Note: this is also part of Body -} - -func (c Create) DocstringAsString() string { - var result []string - for _, line := range c.Docstring { - result = append(result, line.Value) - } - return strings.Join(result, "\n") -} - -func (c Create) DocstringYamldoc() (string, error) { - var yamldoc []string - parsing := false - for _, line := range c.Docstring { - if strings.HasPrefix(line.Value, "--!") { - parsing = true - if !strings.HasPrefix(line.Value, "--! 
") { - return "", Error{line.Pos, "YAML document in docstring; missing space after `--!`"} - } - yamldoc = append(yamldoc, line.Value[4:]) - } else if parsing { - return "", Error{line.Pos, "once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement"} - } - } - return strings.Join(yamldoc, "\n"), nil -} - -func (c Create) ParseYamlInDocstring(out any) error { - yamldoc, err := c.DocstringYamldoc() - if err != nil { - return err - } - return yaml.Unmarshal([]byte(yamldoc), out) -} - -type Type struct { - BaseType string - Args []string -} - -func (t Type) String() (result string) { - result = t.BaseType - if len(t.Args) > 0 { - result = fmt.Sprintf("(%s)", strings.Join(t.Args, ",")) - } - return result -} - -type Error struct { - Pos Pos - Message string -} - -func (e Error) Error() string { - return fmt.Sprintf("%s:%d:%d %s", e.Pos.File, e.Pos.Line, e.Pos.Col, e.Message) -} - -func (e Error) WithoutPos() Error { - return Error{Message: e.Message} -} - -type Document struct { - PragmaIncludeIf []string - Creates []Create - Declares []Declare - Errors []Error -} - -func (c Create) Serialize(w io.StringWriter) error { - for _, l := range c.Body { - if _, err := w.WriteString(l.RawValue); err != nil { - return err - } - } - return nil -} - -func (c Create) SerializeBytes(w io.Writer) error { - for _, l := range c.Body { - if _, err := w.Write([]byte(l.RawValue)); err != nil { - return err - } - } - return nil -} - -func (c Create) String() string { - var buf strings.Builder - err := c.Serialize(&buf) - if err != nil { - panic(err) - } - return buf.String() -} - -func (c Create) WithoutPos() Create { - var body []Unparsed - for _, x := range c.Body { - body = append(body, x.WithoutPos()) - } - return Create{ - CreateType: c.CreateType, - QuotedName: c.QuotedName, - DependsOn: c.DependsOn, - Body: body, - } -} - -func (c Create) DependsOnStrings() (result []string) { - for _, x := range c.DependsOn { - result = 
append(result, x.Value) - } - return -} - -// Transform a Document to remove all Position information; this is used -// to 'unclutter' a DOM to more easily write assertions on it. -func (d Document) WithoutPos() Document { - var cs []Create - for _, x := range d.Creates { - cs = append(cs, x.WithoutPos()) - } - var ds []Declare - for _, x := range d.Declares { - ds = append(ds, x.WithoutPos()) - } - var es []Error - for _, x := range d.Errors { - es = append(es, x.WithoutPos()) - } - return Document{ - Creates: cs, - Declares: ds, - Errors: es, - } -} - -func (d *Document) Include(other Document) { - // Do not copy PragmaIncludeIf, since that is local to a single file. - // Its contents is also present in each Create. - d.Declares = append(d.Declares, other.Declares...) - d.Creates = append(d.Creates, other.Creates...) - d.Errors = append(d.Errors, other.Errors...) -} - -func (d *Document) parseSinglePragma(s *Scanner) { - pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) - if pragma == "" { - return - } - parts := strings.Split(pragma, " ") - if len(parts) != 2 { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - if parts[0] != "include-if" { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - d.PragmaIncludeIf = append(d.PragmaIncludeIf, strings.Split(parts[1], ",")...) -} - -func (d *Document) parsePragmas(s *Scanner) { - for s.TokenType() == PragmaToken { - d.parseSinglePragma(s) - s.NextNonWhitespaceToken() - } -} - -func (d Document) Empty() bool { - return len(d.Creates) > 0 || len(d.Declares) > 0 -} diff --git a/sqlparser/mssql/README.md b/sqlparser/mssql/README.md new file mode 100644 index 0000000..6f470bb --- /dev/null +++ b/sqlparser/mssql/README.md @@ -0,0 +1,48 @@ + +Package mssql provides a T-SQL (Microsoft SQL Server) parser for the sqlcode library. + +# Overview +This package implements a lexical scanner and document parser specifically designed +for T-SQL syntax. 
It is part of the sqlcode toolchain that manages SQL database +objects (procedures, functions, types) with dependency tracking and code generation. + +# Architecture +The parser follows a two-layer architecture: + 1. Scanner (scanner.go): A lexical tokenizer that breaks T-SQL source into tokens. + It handles T-SQL-specific constructs like N'unicode strings', [bracketed identifiers], + and the GO batch separator. + 2. Document (document.go): A higher-level parser that processes token streams to + extract CREATE statements, DECLARE constants, and dependency information. + +# Token System +T-SQL tokens are divided into two categories: + - Common tokens (defined in sqldocument): Shared across SQL dialects (e.g., parentheses, + whitespace, identifiers). These use token type values 0-999. + - T-SQL-specific tokens (defined in tokens.go): Dialect-specific tokens like + VarcharLiteralToken ('...') and NVarcharLiteralToken (N'...'). These use values 1000-1999. + +# Batch Separator Handling +T-SQL uses GO as a batch separator with special rules: + - GO must appear at the start of a line (only whitespace/comments before it) + - Nothing except whitespace may follow GO on the same line + - GO is not a reserved word; it's a client tool command +The scanner tracks line position state to correctly identify GO as a BatchSeparatorToken +rather than an identifier. Malformed separators (GO followed by non-whitespace) are +reported as MalformedBatchSeparatorToken. + +# Document Structure +The parser recognizes: + - CREATE PROCEDURE/FUNCTION/TYPE statements in the [code] schema + - DECLARE statements for constants (variables starting with @Enum, @Global, or @Const) + - Dependencies between objects via [code].ObjectName references + - Pragma comments (--sqlcode:...) for build-time directives + +# Dependency Tracking +When parsing CREATE statements, the parser scans for [code].ObjectName patterns +to build a dependency graph. 
This enables topological sorting of objects so they +are created in the correct order during deployment. + +# Error Recovery +The parser uses a recovery strategy that skips to the next statement-starting +keyword (CREATE, DECLARE, GO) when encountering syntax errors. This allows +partial parsing of files with errors while collecting all error messages. diff --git a/sqlparser/mssql/document.go b/sqlparser/mssql/document.go new file mode 100644 index 0000000..b6e4e8a --- /dev/null +++ b/sqlparser/mssql/document.go @@ -0,0 +1,527 @@ +package mssql + +import ( + "fmt" + "sort" + "strings" + + mssql "github.com/microsoft/go-mssqldb" + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" +) + +// TSQLStatementTokens defines the keywords that start new statements. +// Used by error recovery to find a safe point to resume parsing. +var TSQLStatementTokens = []string{"create", "declare", "go"} + +// TSqlDocument represents a T-SQL source file. +// +// The document contains: +// - creates: CREATE PROCEDURE/FUNCTION/TYPE statements with dependency info +// - declares: DECLARE statements for sqlcode constants (@Enum*, @Global*, @Const*) +// - errors: Syntax and semantic errors encountered during parsing +// - pragmaIncludeIf: Conditional compilation directives from --sqlcode:include-if +// +// Parsing follows T-SQL batch semantics where batches are separated by GO. +// The first batch may contain DECLARE statements for constants. +// Subsequent batches contain CREATE statements for database objects. +type TSqlDocument struct { + pragmaIncludeIf []string + creates []sqldocument.Create + declares []sqldocument.Declare + errors []sqldocument.Error + + sqldocument.Pragma +} + +// Parse processes a T-SQL source file from the given input. +// +// Parsing proceeds in phases: +// 1. Parse pragma comments at the file start (--sqlcode:...) +// 2. 
 Parse batches sequentially, separated by GO +// +// The first batch has special rules: it may contain DECLARE statements +// for sqlcode constants. CREATE statements may appear in any batch, +// but procedures/functions must be alone in their batch (T-SQL requirement). +// +// Errors are accumulated in the document rather than stopping parsing, +// allowing partial results even with syntax errors. +func (d *TSqlDocument) Parse(input []byte, file sqldocument.FileRef) error { + s := &Scanner{} + s.SetInput(input) + s.SetFile(file) + + // Functions typically consume *after* the keyword that triggered their + // invocation; e.g. parseCreate parses from first non-whitespace-token + // *after* `create`. + // + // On return, `s` is positioned at the token that starts the next statement/ + // sub-expression. In particular trailing ';' and whitespace has been consumed. + // + // `s` will typically never be positioned on whitespace except in + // whitespace-preserving parsing + s.NextNonWhitespaceToken() + err := d.ParsePragmas(s) + if err != nil { + d.addError(s, err.Error()) + } + + hasMore := d.parseBatch(s, true) + for hasMore { + hasMore = d.parseBatch(s, false) + } + + return nil +} + +func (d TSqlDocument) HasErrors() bool { + return len(d.errors) > 0 +} + +func (d TSqlDocument) Creates() []sqldocument.Create { + return d.creates +} + +func (d TSqlDocument) Declares() []sqldocument.Declare { + return d.declares +} + +func (d TSqlDocument) Errors() []sqldocument.Error { + return d.errors +} + +func (d *TSqlDocument) Sort() { + // Do the topological sort; and include any error with it as part + // of `result`, *not* return it as err + sortedCreates, errpos, sortErr := sqldocument.TopologicalSort(d.creates) + + if sortErr != nil { + d.errors = append(d.errors, sqldocument.Error{ + Pos: errpos, + Message: sortErr.Error(), + }) + } else { + d.creates = sortedCreates + } +} + +func (d *TSqlDocument) Include(other sqldocument.Document) { + // Do not copy pragmaIncludeIf, 
since that is local to a single file. + // Its contents is also present in each Create. + d.declares = append(d.declares, other.Declares()...) + d.creates = append(d.creates, other.Creates()...) + d.errors = append(d.errors, other.Errors()...) +} + +func (d TSqlDocument) Empty() bool { + return len(d.creates) == 0 || len(d.declares) == 0 +} + +func (d *TSqlDocument) addError(s sqldocument.Scanner, msg string) { + d.errors = append(d.errors, sqldocument.Error{ + Pos: s.Start(), + Message: msg, + }) +} + +func (d *TSqlDocument) unexpectedTokenError(s sqldocument.Scanner) { + d.addError(s, "Unexpected: "+s.Token()) +} + +func (doc *TSqlDocument) parseTypeExpression(s sqldocument.Scanner) (t sqldocument.Type) { + parseArgs := func() { + // parses *after* the initial (; consumes trailing ) + for { + switch { + case s.TokenType() == sqldocument.NumberToken: + t.Args = append(t.Args, s.Token()) + case s.TokenType() == sqldocument.UnquotedIdentifierToken && s.TokenLower() == "max": + t.Args = append(t.Args, "max") + default: + doc.unexpectedTokenError(s) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) + return + } + s.NextNonWhitespaceCommentToken() + switch { + case s.TokenType() == sqldocument.CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case s.TokenType() == sqldocument.RightParenToken: + s.NextNonWhitespaceCommentToken() + return + default: + doc.unexpectedTokenError(s) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) + return + } + } + } + + if s.TokenType() != sqldocument.UnquotedIdentifierToken { + panic("assertion failed, bug in caller") + } + t.BaseType = s.Token() + s.NextNonWhitespaceCommentToken() + if s.TokenType() == sqldocument.LeftParenToken { + s.NextNonWhitespaceCommentToken() + parseArgs() + } + return +} + +func (doc *TSqlDocument) parseDeclare(s sqldocument.Scanner) (result []sqldocument.Declare) { + declareStart := s.Start() + // parse what is *after* the `declare` reserved keyword +loop: + for { + if 
s.TokenType() != sqldocument.VariableIdentifierToken { + doc.unexpectedTokenError(s) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) + return + } + + variableName := s.Token() + if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && + !strings.HasPrefix(strings.ToLower(variableName), "@global") && + !strings.HasPrefix(strings.ToLower(variableName), "@const") { + doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) + } + + s.NextNonWhitespaceCommentToken() + var variableType sqldocument.Type + switch s.TokenType() { + case sqldocument.EqualToken: + doc.addError(s, "sqlcode constants needs a type declared explicitly") + s.NextNonWhitespaceCommentToken() + case sqldocument.UnquotedIdentifierToken: + variableType = doc.parseTypeExpression(s) + } + + if s.TokenType() != sqldocument.EqualToken { + doc.addError(s, "sqlcode constants needs to be assigned at once using =") + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) + } + + switch s.NextNonWhitespaceCommentToken() { + case sqldocument.NumberToken, NVarcharLiteralToken, VarcharLiteralToken: + declare := sqldocument.Declare{ + Start: declareStart, + Stop: s.Stop(), + VariableName: variableName, + Datatype: variableType, + Literal: sqldocument.CreateUnparsed(s), + } + result = append(result, declare) + default: + doc.unexpectedTokenError(s) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) + return + } + + switch s.NextNonWhitespaceCommentToken() { + case sqldocument.CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case sqldocument.SemicolonToken: + s.NextNonWhitespaceCommentToken() + break loop + default: + break loop + } + } + if len(result) == 0 { + doc.addError(s, "incorrect syntax; no variables successfully declared") + } + return +} + +func (doc *TSqlDocument) parseBatchSeparator(s sqldocument.Scanner) { + // just saw a 'go'; just make sure there's nothing bad trailing it + // (if there is, convert to 
errors and move on until the line is consumed) + errorEmitted := false + // continuously process tokens until a non-whitespace, non-malformed token is encountered. + for { + switch s.NextToken() { + case sqldocument.WhitespaceToken: + continue + case sqldocument.MalformedBatchSeparatorToken: + if !errorEmitted { + doc.addError(s, "`go` should be alone on a line without any comments") + errorEmitted = true + } + continue + default: + return + } + } +} + +func (doc *TSqlDocument) parseDeclareBatch(s sqldocument.Scanner) (hasMore bool) { + if s.ReservedWord() != "declare" { + panic("assertion failed, incorrect use in caller") + } + for { + tt := s.TokenType() + switch { + case tt == sqldocument.EOFToken: + return false + case tt == sqldocument.ReservedWordToken && s.ReservedWord() == "declare": + s.NextNonWhitespaceCommentToken() + d := doc.parseDeclare(s) + doc.declares = append(doc.declares, d...) + case tt == sqldocument.ReservedWordToken && s.ReservedWord() != "declare": + doc.addError(s, "Only 'declare' allowed in this batch") + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) + case tt == sqldocument.BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) + } + } +} + +// parseBatch processes a single T-SQL batch (content between GO separators). +// +// Batch processing strategy: +// - Track tokens before the first significant statement for docstrings +// - Dispatch to specialized parsers based on statement type (CREATE, DECLARE) +// - Handle batch separator (GO) to signal batch boundary +// +// The isFirst parameter indicates whether this is the first batch in the file, +// which affects whether DECLARE statements are allowed. 
+func (doc *TSqlDocument) parseBatch(s sqldocument.Scanner, isFirst bool) (hasMore bool) { + batch := sqldocument.NewBatch() + batch.BatchSeparatorHandler = func(s sqldocument.Scanner, b *sqldocument.Batch) { + errorEmitted := false + for { + switch s.NextToken() { + case sqldocument.WhitespaceToken: + continue + case sqldocument.MalformedBatchSeparatorToken: + if !errorEmitted { + b.Errors = append(b.Errors, sqldocument.Error{ + Pos: s.Start(), + Message: "`go` should be alone on a line without any comments", + }) + errorEmitted = true + } + continue + default: + return + } + } + } + batch.TokenHandlers = map[string]func(sqldocument.Scanner, *sqldocument.Batch) int{ + "declare": func(s sqldocument.Scanner, _ *sqldocument.Batch) int { + // First declare-statement; enter a mode where we assume all contents + // of batch are declare statements + if !isFirst { + doc.addError(s, "'declare' statement only allowed in first batch") + } + + // regardless of errors, go on and parse as far as we get... + hasMore := doc.parseDeclareBatch(s) + if hasMore { + return 1 + } + + return -1 + }, + "create": func(s sqldocument.Scanner, b *sqldocument.Batch) int { + counts, exists := b.TokenCalls["create"] + if !exists { + counts = 0 + } + + // should be start of create procedure or create function... + c := doc.parseCreate(s, counts) + c.Driver = &mssql.Driver{} + + // *prepend* what we saw before getting to the 'create' + c.Body = append(b.Nodes, c.Body...) + c.Docstring = b.DocString + doc.creates = append(doc.creates, c) + + //continue parsing + return 0 + }, + } + + hasMore = batch.Parse(s) + if batch.HasErrors() { + doc.errors = append(doc.errors, batch.Errors...) + } + + return hasMore +} + +// parseCreate parses CREATE PROCEDURE/FUNCTION/TYPE statements. +// +// This is the core of the sqlcode parser. It: +// 1. Validates the CREATE type is one we support (procedure/function/type) +// 2. Extracts the object name from [code].ObjectName syntax +// 3. 
Copies the entire statement body for later emission +// 4. Tracks dependencies by finding [code].OtherObject references +// +// The parser is intentionally permissive about T-SQL syntax details, +// delegating full validation to SQL Server. It focuses on extracting +// the structural information needed for dependency ordering and code generation. +// +// Parameters: +// - s: Scanner positioned on the CREATE keyword +// - createCountInBatch: Number of CREATE statements already seen in this batch +// (used to enforce "one procedure/function per batch" rule) +func (d *TSqlDocument) parseCreate(s sqldocument.Scanner, createCountInBatch int) (result sqldocument.Create) { + if s.ReservedWord() != "create" { + panic("illegal use by caller") + } + sqldocument.CopyToken(s, &result.Body) + + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) + + createType, exists := sqldocument.CreateTypeMapping[strings.ToLower(s.Token())] + + if !exists { + d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) + sqldocument.RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) + return + } + if (createType == sqldocument.SQLProcedure || createType == sqldocument.SQLFunction) && createCountInBatch > 0 { + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + sqldocument.RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) + return + } + + result.CreateType = createType + sqldocument.CopyToken(s, &result.Body) + + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) + + // Insist on [code]. 
+ if s.TokenType() != sqldocument.QuotedIdentifierToken || s.Token() != "[code]" { + d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) + sqldocument.RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) + return + } + + var err error + result.QuotedName, err = sqldocument.ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) + if err != nil { + d.addError(s, err.Error()) + } + if result.QuotedName.String() == "" { + return + } + + // We have matched "create [code]."; at this + // point we copy the rest until the batch ends; *but* track dependencies + // + some other details mentioned below + + //firstAs := true // See comment below on rowcount + +tailloop: + for { + tt := s.TokenType() + switch { + case tt == sqldocument.ReservedWordToken && s.ReservedWord() == "create": + // So, we're currently parsing 'create ...' and we see another 'create'. + // We split in two cases depending on the context we are currently in + // (createType is referring to how we entered this function, *NOT* the + // `create` statement we are looking at now + switch createType { // note: this is the *outer* create type, not the one of current scanner position + case sqldocument.SQLFunction, sqldocument.SQLProcedure: + // Within a function/procedure we can allow 'create index', 'create table' and nothing + // else. 
(Well, only procedures can have them, but we'll leave it to T-SQL to complain + // about that aspect, not relevant for batch / dependency parsing) + // + // What is important is a function/procedure/type isn't started on without a 'go' + // in between; so we block those 3 from appearing in the same batch + sqldocument.CopyToken(s, &result.Body) + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) + tt2 := s.TokenType() + + if (tt2 == sqldocument.ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || + (tt2 == sqldocument.UnquotedIdentifierToken && s.TokenLower() == "type") { + sqldocument.RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + return + } + case sqldocument.SQLType: + // We allow more than one type creation in a batch; and 'create' can never appear + // scoped within 'create type'. So at a new create we are done with the previous + // one, and return it -- the caller can then re-enter this function from the top + break tailloop + default: + panic("assertion failed") + } + + case tt == sqldocument.EOFToken || tt == sqldocument.BatchSeparatorToken: + break tailloop + case tt == sqldocument.QuotedIdentifierToken && s.Token() == "[code]": + // Parse a dependency + dep, err := sqldocument.ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) + if err != nil { + d.addError(s, err.Error()) + } + found := false + for _, existing := range result.DependsOn { + if existing.Value == dep.Value { + found = true + break + } + } + if !found { + result.DependsOn = append(result.DependsOn, dep) + } + case tt == sqldocument.ReservedWordToken && s.Token() == "as": + sqldocument.CopyToken(s, &result.Body) + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) + /* + TODO: Fix and re-enable + This code add RoutineName for convenience. 
So: + + create procedure [code@5420c0269aaf].Test as + begin + select 1 + end + go + + becomes: + + create procedure [code@5420c0269aaf].Test as + declare @RoutineName nvarchar(128) + set @RoutineName = 'Test' + begin + select 1 + end + go + + However, for some very strange reason, @@rowcount is 1 with the first version, + and it is 2 with the second version. + if firstAs { + // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name + // from inside the procedure (for example, when logging) + if result.CreateType == "procedure" { + procNameToken := Unparsed{ + Type: OtherToken, + RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), + } + result.Body = append(result.Body, procNameToken) + } + firstAs = false + } + */ + + default: + sqldocument.CopyToken(s, &result.Body) + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) + } + } + + sort.Slice(result.DependsOn, func(i, j int) bool { + return result.DependsOn[i].Value < result.DependsOn[j].Value + }) + return +} diff --git a/sqlparser/mssql/document_test.go b/sqlparser/mssql/document_test.go new file mode 100644 index 0000000..15a685b --- /dev/null +++ b/sqlparser/mssql/document_test.go @@ -0,0 +1,289 @@ +package mssql + +import ( + "strings" + "testing" + + mssql "github.com/microsoft/go-mssqldb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" +) + +func ParseString(t *testing.T, file, input string) *TSqlDocument { + d := &TSqlDocument{} + assert.NoError(t, d.Parse([]byte(input), sqldocument.FileRef(file))) + return d +} + +func TestParserSmokeTest(t *testing.T) { + doc := ParseString(t, "test.sql", ` +/* test is a test + +declare @EnumFoo int = 2; + +*/ + +declare/*comment*/@EnumBar1 varchar (max) = N'declare @EnumThisIsInString'; +declare + + + @EnumBar2 int = 20, + @EnumBar3 int=21; + +GO + +declare @EnumNextBatch int = 3; + +go +-- preceding 
comment 1 +/* preceding comment 2 + +asdfasdf */create procedure [code].TestFunc as begin + refers to [code].OtherFunc [code].HelloFunc; + create table x ( int x not null ); -- should be ok +end; + +/* trailing comment */ +`) + assert.Equal(t, 1, len(doc.Creates())) + + c := doc.Creates()[0] + assert.Equal(t, &mssql.Driver{}, c.Driver) + assert.Equal(t, "[TestFunc]", c.QuotedName.Value) + assert.Equal(t, []string{"[HelloFunc]", "[OtherFunc]"}, c.DependsOnStrings()) + assert.Equal(t, `-- preceding comment 1 +/* preceding comment 2 + +asdfasdf */create procedure [code].TestFunc as begin + refers to [code].OtherFunc [code].HelloFunc; + create table x ( int x not null ); -- should be ok +end; + +/* trailing comment */ +`, c.String()) + + assert.Equal(t, "'declare' statement only allowed in first batch", doc.Errors()[0].Message) + + assert.Equal(t, "@EnumBar1", doc.Declares()[0].VariableName) + assert.Equal(t, "N'declare @EnumThisIsInString'", doc.Declares()[0].Literal.RawValue) + assert.Equal(t, NVarcharLiteralToken, doc.Declares()[0].Literal.Type) + assert.Equal(t, "varchar", doc.Declares()[0].Datatype.BaseType) + assert.Equal(t, "max", doc.Declares()[0].Datatype.Args[0]) + + assert.Equal(t, "@EnumNextBatch", doc.Declares()[3].VariableName) + assert.Equal(t, "int", doc.Declares()[3].Datatype.BaseType) + assert.Equal(t, sqldocument.NumberToken, doc.Declares()[3].Literal.Type) + assert.Equal(t, "3", doc.Declares()[3].Literal.RawValue) +} + +func TestParserDisallowMultipleCreates(t *testing.T) { + // Test that we get an error if we create two things in same batch; + // the test above tests that it is still OK to create a table within + // a procedure.. + doc := ParseString(t, "test.sql", ` +create function [code].One(); +-- the following should give an error; not that One() depends on Two()... 
+-- (we don't parse body start/end yet) +create function [code].Two(); +`) + + assert.Equal(t, "a procedure/function must be alone in a batch; use 'go' to split batches", doc.Errors()[0].Message) +} + +func TestBuggyDeclare(t *testing.T) { + // this caused parses to infinitely loop; regression test... + doc := ParseString(t, "test.sql", `declare @EnumA int = 4 @EnumB tinyint = 5 @ENUM_C bigint = 435;`) + assert.Equal(t, 1, len(doc.Errors())) + assert.Equal(t, "Unexpected: @EnumB", doc.Errors()[0].Message) +} + +func TestCreateType(t *testing.T) { + doc := ParseString(t, "test.sql", `create type [code].MyType as table (x int not null primary key);`) + assert.Equal(t, 1, len(doc.Creates())) + assert.Equal(t, sqldocument.SQLType, doc.Creates()[0].CreateType) + assert.Equal(t, "[MyType]", doc.Creates()[0].QuotedName.Value) +} + +func TestPragma(t *testing.T) { + doc := ParseString(t, "test.sql", `--sqlcode:include-if one,two +--sqlcode:include-if three + +create procedure [code].ProcedureShouldAlsoHavePragmasAnnotated() +`) + assert.Equal(t, []string{"one", "two", "three"}, doc.PragmaIncludeIf()) +} + +func TestInfiniteLoopRegression(t *testing.T) { + // success if we terminate!... + doc := ParseString(t, "test.sql", `@declare`) + assert.Equal(t, 1, len(doc.Errors())) +} + +func TestDeclareSeparation(t *testing.T) { + // Trying out many possible ways to separate declare statements: + // Comman, semicolon, simply starting a new declare with or without + // whitespace in between. 
+ // Yes, ='hello'declare @EnumThird really does parse as T-SQL + doc := ParseString(t, "test.sql", ` +declare @EnumFirst int = 3, @EnumSecond varchar(max) = 'hello'declare @EnumThird int=3 declare @EnumFourth int=4;declare @EnumFifth int =5 +`) + assert.Len(t, doc.Declares(), 5) + assert.Equal(t, "@EnumFirst", doc.Declares()[0].VariableName) + assert.Equal(t, "int", doc.Declares()[0].Datatype.BaseType) + assert.Equal(t, sqldocument.NumberToken, doc.Declares()[0].Literal.Type) + assert.Equal(t, "3", doc.Declares()[0].Literal.RawValue) + +} + +func TestBatchDivisionsAndCreateStatements(t *testing.T) { + // Had a bug where comments where repeated on each create statement in different batches, discovery & regression + // (The bug was that a token too much was consumed in parseCreate, consuming the `go` token..) + doc := ParseString(t, "test.sql", ` +create type [code].Batch1 as table (x int); +go +-- a comment in 2nd batch +create procedure [code].Batch2 as table (x int); +go +create type [code].Batch3 as table (x int); +`) + commentCount := 0 + for _, c := range doc.Creates() { + for _, b := range c.Body { + if strings.Contains(b.RawValue, "2nd") { + commentCount++ + } + assert.NotEqual(t, "go", b.RawValue) + } + } + assert.Equal(t, 1, commentCount) +} + +func TestCreateTypes(t *testing.T) { + // Apparently there can be several 'create type' per batch, but only one function/procedure... + // Check we catch all 3 types + doc := ParseString(t, "test.sql", ` +create type [code].Type1 as table (x int); +create type [code].Type2 as table (x int); +create type [code].Type3 as table (x int); +`) + assert.Len(t, doc.Creates(), 3) + assert.Equal(t, "[Type1]", doc.Creates()[0].QuotedName.Value) + assert.Equal(t, "[Type3]", doc.Creates()[2].QuotedName.Value) + // There was a bug that the last item in the body would be the 'create' + // of the next statement; regression test.. 
+ assert.Equal(t, "\n", doc.Creates()[0].Body[len(doc.Creates()[0].Body)-1].RawValue) + assert.Equal(t, "create", doc.Creates()[1].Body[0].RawValue) +} + +func TestCreateProcs(t *testing.T) { + // Apparently there can be several 'create type' per batch, but only one function/procedure... + // Check that we get an error for all further create statements in the same batch + doc := ParseString(t, "test.sql", ` +create procedure [code].FirstProc as table (x int) +create function [code].MyFunction () +create type [code].MyType () +create procedure [code].MyProcedure () +`) + + // First function and last procedure triggers errors. + assert.Len(t, doc.Errors(), 2) + emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" + assert.Equal(t, emsg, doc.Errors()[0].Message) + assert.Equal(t, emsg, doc.Errors()[1].Message) + +} + +func TestCreateProcs2(t *testing.T) { + // Create type first, then create proc... should give an error still.. + doc := ParseString(t, "test.sql", ` +create type [code].MyType () +create procedure [code].FirstProc as table (x int) +`) + // Code above was mainly to be able to step through parser in a given way. + // First function triggers an error. Then create type is parsed which is + // fine sharing a batch with others. + require.Equal(t, 1, len(doc.Errors())) + emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" + assert.Equal(t, emsg, doc.Errors()[0].Message) +} + +func TestGoWithoutNewline(t *testing.T) { + doc := ParseString(t, "test.sql", ` +create procedure [code].Foo() as begin +end; +go create function [code].Bar() returns int as begin +end +`) + // Code above was mainly to be able to step through parser in a given way. + // First function triggers an error. Then create type is parsed which is + // fine sharing a batch with others. 
+ require.Equal(t, 1, len(doc.Errors())) + assert.Equal(t, "`go` should be alone on a line without any comments", doc.Errors()[0].Message) +} + +func TestCreateAnnotationHappyDay(t *testing.T) { + // Comment / annotations on create statements + doc := ParseString(t, "test.sql", ` +-- Not part of annotation +--! key4: 1 + +-- This is part of annotation +--! key1: a +--! key2: b +--! key3: [1,2,3] +create procedure [code].Foo as begin end + +`) + assert.Equal(t, + "-- This is part of annotation\n--! key1: a\n--! key2: b\n--! key3: [1,2,3]", + doc.Creates()[0].DocstringAsString()) + s, err := doc.Creates()[0].DocstringYamldoc() + assert.NoError(t, err) + assert.Equal(t, + "key1: a\nkey2: b\nkey3: [1,2,3]", + s) + + var x struct { + Key1 string `yaml:"key1"` + } + require.NoError(t, doc.Creates()[0].ParseYamlInDocstring(&x)) + assert.Equal(t, "a", x.Key1) +} + +func TestCreateAnnotationAfterPragma(t *testing.T) { + // Comment / annotations on create statement, with pragma at start of file + doc := ParseString(t, "test.sql", ` +--sqlcode: include-if foo + +-- docstring here +create procedure [code].Foo as begin end + +`) + assert.Equal(t, + "-- docstring here", + doc.Creates()[0].DocstringAsString()) +} + +func TestCreateAnnotationErrors(t *testing.T) { + // Multiple embedded yaml documents .. + doc := ParseString(t, "test.sql", ` +--! key4: 1 +-- This comment after yamldoc is illegal; this also prevents multiple embedded YAML documents +create procedure [code].Foo as begin end +`) + _, err := doc.Creates()[0].DocstringYamldoc() + assert.Equal(t, "test.sql:3:1 once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement", + err.Error()) + + // No whitespace after ! 
+ doc = ParseString(t, "test.sql", ` +-- Docstring here +--!key4: 1 +create procedure [code].Foo as begin end +`) + _, err = doc.Creates()[0].DocstringYamldoc() + assert.Equal(t, "test.sql:3:1 YAML document in docstring; missing space after `--!`", + err.Error()) + +} diff --git a/sqlparser/scanner.go b/sqlparser/mssql/scanner.go similarity index 69% rename from sqlparser/scanner.go rename to sqlparser/mssql/scanner.go index 3103894..bd32833 100644 --- a/sqlparser/scanner.go +++ b/sqlparser/mssql/scanner.go @@ -1,95 +1,133 @@ -package sqlparser +package mssql import ( - "github.com/smasher164/xid" "regexp" "strings" "unicode" "unicode/utf8" -) -// dedicated type for reference to file, in case we need to refactor this later.. -type FileRef string - -type Pos struct { - File FileRef - Line, Col int -} + "github.com/smasher164/xid" + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" +) -// We don't do the lexer/parser split / token stream, but simply use the -// Scanner directly from the recursive descent parser; it is simply a cursor -// in the buffer with associated utility methods +// Scanner is a lexical scanner for T-SQL source code. +// +// Unlike traditional lexer/parser architectures with a token stream, Scanner +// is used directly by the recursive descent parser as a cursor into the input +// buffer. It provides utility methods for tokenization and position tracking. +// +// The scanner handles T-SQL specific constructs including: +// - String literals ('...' 
and N'...') +// - Quoted identifiers ([...]) +// - Single-line (--) and multi-line (/* */) comments +// - Batch separators (GO) +// - Reserved words +// - Variables (@identifier) type Scanner struct { - input string - file FileRef + input string // The complete source code being scanned + file sqldocument.FileRef // Reference to the source file for error reporting - startIndex int // start of this item - curIndex int // current position of the Scanner - tokenType TokenType + startIndex int // Byte index where current token starts + curIndex int // Current byte position in Input + tokenType sqldocument.TokenType // Type of the current token - // NextToken() has a small state machine to implement the rules of batch seperators - // using these two states - startOfLine bool // have we seen anything non-whitespace, non-comment since start of line? Only used for BatchSeparatorToken - afterBatchSeparator bool // raise an error if we see anything except whitespace and comments after 'go' + // Batch separator state machine fields. + // The GO batch separator has special rules: it must appear at the start + // of a line and nothing except whitespace can follow it on the same line. 
+ startOfLine bool // True if no non-whitespace/comment seen since start of line + afterBatchSeparator bool // True if we just saw GO; used to detect malformed separators - startLine int - stopLine int - indexAtStartLine int // value of `curIndex` after newline char - indexAtStopLine int // value of `curIndex` after newline char + startLine int // Line number (0-indexed) where current token starts + stopLine int // Line number (0-indexed) where current token ends + indexAtStartLine int // Byte index at the start of startLine (after newline) + indexAtStopLine int // Byte index at the start of stopLine (after newline) - reservedWord string // in the event that the token is a ReservedWordToken, this contains the lower-case version + reservedWord string // Lowercase version of token if it's a reserved word, empty otherwise } -type TokenType int +// NewScanner creates a new Scanner for the given T-SQL source file and input string. +// The scanner is positioned before the first token; call NextToken() to advance. +func NewScanner(path sqldocument.FileRef, input string) *Scanner { + return &Scanner{input: input, file: path} +} -func (s *Scanner) TokenType() TokenType { +// TokenType returns the type of the current token. +func (s *Scanner) TokenType() sqldocument.TokenType { return s.tokenType } -// Returns a clone of the scanner; this is used to do look-ahead parsing +func (s *Scanner) SetInput(input []byte) { + s.input = string(input) +} + +func (s *Scanner) SetFile(file sqldocument.FileRef) { + s.file = file +} + +func (s *Scanner) File() sqldocument.FileRef { + return s.file +} + +// Clone returns a copy of the scanner at its current position. +// This is used for look-ahead parsing where we need to tentatively +// scan tokens without committing to consuming them. func (s Scanner) Clone() *Scanner { result := new(Scanner) *result = s return result } +// Token returns the text of the current token as a substring of Input. 
func (s *Scanner) Token() string { return s.input[s.startIndex:s.curIndex] } +// TokenLower returns the current token text converted to lowercase. +// Useful for case-insensitive keyword matching. func (s *Scanner) TokenLower() string { return strings.ToLower(s.Token()) } +// ReservedWord returns the lowercase reserved word if the current token +// is a ReservedWordToken, or an empty string otherwise. func (s *Scanner) ReservedWord() string { return s.reservedWord } -func (s *Scanner) Start() Pos { - return Pos{ +// Start returns the position where the current token begins. +// Line and column are 1-indexed. +func (s *Scanner) Start() sqldocument.Pos { + return sqldocument.Pos{ Line: s.startLine + 1, Col: s.startIndex - s.indexAtStartLine + 1, File: s.file, } } -func (s *Scanner) Stop() Pos { - return Pos{ +// Stop returns the position where the current token ends. +// Line and column are 1-indexed. +func (s *Scanner) Stop() sqldocument.Pos { + return sqldocument.Pos{ Line: s.stopLine + 1, Col: s.curIndex - s.indexAtStopLine + 1, File: s.file, } } +// bumpLine increments the line counter and records the byte position +// after the newline character. The offset parameter is the position +// of the newline within the current scan operation. func (s *Scanner) bumpLine(offset int) { s.stopLine++ s.indexAtStopLine = s.curIndex + offset + 1 } +// SkipWhitespaceComments advances past any whitespace and comment tokens. +// Stops when a non-whitespace, non-comment token is encountered. func (s *Scanner) SkipWhitespaceComments() { for { switch s.TokenType() { - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + case sqldocument.WhitespaceToken, sqldocument.MultilineCommentToken, sqldocument.SinglelineCommentToken: default: return } @@ -97,10 +135,13 @@ func (s *Scanner) SkipWhitespaceComments() { } } +// SkipWhitespace advances past any whitespace tokens. +// Stops when a non-whitespace token is encountered. 
+// Unlike SkipWhitespaceComments, this preserves comments. func (s *Scanner) SkipWhitespace() { for { switch s.TokenType() { - case WhitespaceToken: + case sqldocument.WhitespaceToken: default: return } @@ -108,21 +149,35 @@ func (s *Scanner) SkipWhitespace() { } } -func (s *Scanner) NextNonWhitespaceToken() TokenType { +// NextNonWhitespaceToken advances to the next token and then skips +// any whitespace, returning the type of the first non-whitespace token. +func (s *Scanner) NextNonWhitespaceToken() sqldocument.TokenType { s.NextToken() s.SkipWhitespace() return s.TokenType() } -func (s *Scanner) NextNonWhitespaceCommentToken() TokenType { +// NextNonWhitespaceCommentToken advances to the next token and then skips +// any whitespace and comments, returning the type of the first significant token. +func (s *Scanner) NextNonWhitespaceCommentToken() sqldocument.TokenType { s.NextToken() s.SkipWhitespaceComments() return s.TokenType() } -// NextToken scans the NextToken token and advances the Scanner's position to -// after the token -func (s *Scanner) NextToken() TokenType { +// NextToken scans the next token and advances the scanner's position. +// +// This method wraps the raw tokenization with batch separator handling. +// The GO batch separator has special rules in T-SQL: +// - It must appear at the start of a line (only whitespace/comments before it) +// - Nothing except whitespace may follow it on the same line +// - It is not processed inside [names], 'strings', or /*comments*/ +// +// If GO is followed by non-whitespace on the same line, subsequent tokens +// are returned as MalformedBatchSeparatorToken until end of line. +// +// Returns the TokenType of the scanned token. +func (s *Scanner) NextToken() sqldocument.TokenType { // handle startOfLine flag here; this is used to parse the 'go' batch separator s.tokenType = s.nextToken() @@ -139,12 +194,12 @@ func (s *Scanner) NextToken() TokenType { // on the same line as 'go'. 
And doing so will not turn it into a literal, // but instead return MalformedBatchSeparatorToken - if s.startOfLine && s.tokenType == UnquotedIdentifierToken && s.TokenLower() == "go" { - s.tokenType = BatchSeparatorToken + if s.startOfLine && s.tokenType == sqldocument.UnquotedIdentifierToken && s.TokenLower() == "go" { + s.tokenType = sqldocument.BatchSeparatorToken s.afterBatchSeparator = true - } else if s.afterBatchSeparator && s.tokenType != WhitespaceToken && s.tokenType != EOFToken { - s.tokenType = MalformedBatchSeparatorToken - } else if s.tokenType == WhitespaceToken { + } else if s.afterBatchSeparator && s.tokenType != sqldocument.WhitespaceToken && s.tokenType != sqldocument.EOFToken { + s.tokenType = sqldocument.MalformedBatchSeparatorToken + } else if s.tokenType == sqldocument.WhitespaceToken { // If we just saw the whitespace token that bumped the linecount, // we are at the "start of line", even if this contains some space after the \n: if s.stopLine > s.startLine { @@ -155,10 +210,11 @@ func (s *Scanner) NextToken() TokenType { } else { s.startOfLine = false } + return s.tokenType } -func (s *Scanner) nextToken() TokenType { +func (s *Scanner) nextToken() sqldocument.TokenType { s.startIndex = s.curIndex s.reservedWord = "" s.startLine = s.stopLine @@ -168,29 +224,29 @@ func (s *Scanner) nextToken() TokenType { // First, decisions that can be made after one character: switch { case r == utf8.RuneError && w == 0: - return EOFToken + return sqldocument.EOFToken case r == utf8.RuneError && w == -1: // not UTF-8, we can't really proceed so not advancing Scanner, // caller should take care to always exit.. 
- return NonUTF8ErrorToken + return sqldocument.NonUTF8ErrorToken case r == '(': s.curIndex += w - return LeftParenToken + return sqldocument.LeftParenToken case r == ')': s.curIndex += w - return RightParenToken + return sqldocument.RightParenToken case r == ';': s.curIndex += w - return SemicolonToken + return sqldocument.SemicolonToken case r == '=': s.curIndex += w - return EqualToken + return sqldocument.EqualToken case r == ',': s.curIndex += w - return CommaToken + return sqldocument.CommaToken case r == '.': s.curIndex += w - return DotToken + return sqldocument.DotToken case r == '\'': s.curIndex += w return s.scanStringLiteral(VarcharLiteralToken) @@ -213,14 +269,15 @@ func (s *Scanner) nextToken() TokenType { s.curIndex += w s.scanIdentifier() if r == '@' { - return VariableIdentifierToken + return sqldocument.VariableIdentifierToken } else { rw := strings.ToLower(s.Token()) - if _, ok := reservedWords[rw]; ok { + _, ok := reservedWords[rw] + if ok { s.reservedWord = rw - return ReservedWordToken + return sqldocument.ReservedWordToken } else { - return UnquotedIdentifierToken + return sqldocument.UnquotedIdentifierToken } } } @@ -238,11 +295,12 @@ func (s *Scanner) nextToken() TokenType { // no, it is instead an identifier starting with N... 
s.scanIdentifier() rw := strings.ToLower(s.Token()) - if _, ok := reservedWords[rw]; ok { + _, ok := reservedWords[rw] + if ok { s.reservedWord = rw - return ReservedWordToken + return sqldocument.ReservedWordToken } else { - return UnquotedIdentifierToken + return sqldocument.UnquotedIdentifierToken } case r == '/' && r2 == '*': s.curIndex += w + w2 @@ -255,28 +313,28 @@ func (s *Scanner) nextToken() TokenType { } s.curIndex += w - return OtherToken + return sqldocument.OtherToken } // scanMultilineComment assumes one has advanced over '/*' -func (s *Scanner) scanMultilineComment() TokenType { +func (s *Scanner) scanMultilineComment() sqldocument.TokenType { prevWasStar := false for i, r := range s.input[s.curIndex:] { if r == '*' { prevWasStar = true } else if prevWasStar && r == '/' { s.curIndex += i + 1 - return MultilineCommentToken + return sqldocument.MultilineCommentToken } else if r == '\n' { s.bumpLine(i) } } s.curIndex = len(s.input) - return MultilineCommentToken + return sqldocument.MultilineCommentToken } // scanSinglelineComment assumes one has advanced over -- -func (s *Scanner) scanSinglelineComment() TokenType { +func (s *Scanner) scanSinglelineComment() sqldocument.TokenType { isPragma := strings.HasPrefix(s.input[s.curIndex:], "sqlcode:") end := strings.Index(s.input[s.curIndex:], "\n") if end == -1 { @@ -288,20 +346,20 @@ func (s *Scanner) scanSinglelineComment() TokenType { s.curIndex += end } if isPragma { - return PragmaToken + return sqldocument.PragmaToken } else { - return SinglelineCommentToken + return sqldocument.SinglelineCommentToken } } // scanStringLiteral assumes one has scanned ' or N' (depending on param); // then scans until the end of the string -func (s *Scanner) scanStringLiteral(tokenType TokenType) TokenType { +func (s *Scanner) scanStringLiteral(tokenType sqldocument.TokenType) sqldocument.TokenType { return s.scanUntilSingleDoubleEscapes('\'', tokenType, UnterminatedVarcharLiteralErrorToken) } -func (s *Scanner) 
scanQuotedIdentifier() TokenType { - return s.scanUntilSingleDoubleEscapes(']', QuotedIdentifierToken, UnterminatedQuotedIdentifierErrorToken) +func (s *Scanner) scanQuotedIdentifier() sqldocument.TokenType { + return s.scanUntilSingleDoubleEscapes(']', sqldocument.QuotedIdentifierToken, UnterminatedQuotedIdentifierErrorToken) } // scanIdentifier assumes first character of an identifier has been identified, @@ -316,8 +374,12 @@ func (s *Scanner) scanIdentifier() { s.curIndex = len(s.input) } -// DRY helper to handle both '' and ]] escapes -func (s *Scanner) scanUntilSingleDoubleEscapes(endmarker rune, tokenType TokenType, unterminatedTokenType TokenType) TokenType { +// DRY helper to handle both ” and ]] escapes +func (s *Scanner) scanUntilSingleDoubleEscapes( + endmarker rune, + tokenType sqldocument.TokenType, + unterminatedTokenType sqldocument.TokenType, +) sqldocument.TokenType { skipnext := false for i, r := range s.input[s.curIndex:] { if skipnext { @@ -345,7 +407,7 @@ func (s *Scanner) scanUntilSingleDoubleEscapes(endmarker rune, tokenType TokenTy var numberRegexp = regexp.MustCompile(`^[+-]?\d+\.?\d*([eE][+-]?\d*)?`) -func (s *Scanner) scanNumber() TokenType { +func (s *Scanner) scanNumber() sqldocument.TokenType { // T-SQL seems to scan a number until the // end and then allowing a literal to start without whitespace or other things // in between... 
@@ -357,24 +419,25 @@ func (s *Scanner) scanNumber() TokenType { panic("should always have a match according to regex and conditions in caller") } s.curIndex += loc[1] - return NumberToken + return sqldocument.NumberToken } -func (s *Scanner) scanWhitespace() TokenType { +func (s *Scanner) scanWhitespace() sqldocument.TokenType { for i, r := range s.input[s.curIndex:] { if r == '\n' { s.bumpLine(i) } if !unicode.IsSpace(r) { s.curIndex += i - return WhitespaceToken + return sqldocument.WhitespaceToken } } // eof s.curIndex = len(s.input) - return WhitespaceToken + return sqldocument.WhitespaceToken } +// tsql (mssql) reservered words var reservedWords = map[string]struct{}{ "add": struct{}{}, "external": struct{}{}, diff --git a/sqlparser/mssql/scanner_test.go b/sqlparser/mssql/scanner_test.go new file mode 100644 index 0000000..d9019fb --- /dev/null +++ b/sqlparser/mssql/scanner_test.go @@ -0,0 +1,487 @@ +package mssql + +import ( + "testing" + + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" +) + +// Helper to collect all tokens from input +func collectTokens(input string) []struct { + Type sqldocument.TokenType + Value string +} { + s := NewScanner("test.sql", input) + var tokens []struct { + Type sqldocument.TokenType + Value string + } + for { + tt := s.NextToken() + tokens = append(tokens, struct { + Type sqldocument.TokenType + Value string + }{tt, s.Token()}) + if tt == sqldocument.EOFToken { + break + } + } + return tokens +} + +func TestScanner_SimpleTokens(t *testing.T) { + tests := []struct { + name string + input string + expected []sqldocument.TokenType + }{ + { + name: "parentheses and punctuation", + input: "( ) ; = , .", + expected: []sqldocument.TokenType{ + sqldocument.LeftParenToken, + sqldocument.WhitespaceToken, + sqldocument.RightParenToken, + sqldocument.WhitespaceToken, + sqldocument.SemicolonToken, + sqldocument.WhitespaceToken, + sqldocument.EqualToken, + sqldocument.WhitespaceToken, + sqldocument.CommaToken, + 
sqldocument.WhitespaceToken, + sqldocument.DotToken, + sqldocument.EOFToken, + }, + }, + { + name: "empty input", + input: "", + expected: []sqldocument.TokenType{ + sqldocument.EOFToken, + }, + }, + { + name: "whitespace only", + input: " \t\n ", + expected: []sqldocument.TokenType{ + sqldocument.WhitespaceToken, + sqldocument.EOFToken, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens := collectTokens(tt.input) + if len(tokens) != len(tt.expected) { + t.Fatalf("expected %d tokens, got %d", len(tt.expected), len(tokens)) + } + for i, exp := range tt.expected { + if tokens[i].Type != exp { + t.Errorf("token %d: expected type %v, got %v (value: %q)", + i, exp, tokens[i].Type, tokens[i].Value) + } + } + }) + } +} + +func TestScanner_StringLiterals(t *testing.T) { + tests := []struct { + name string + input string + expectedType sqldocument.TokenType + expectedValue string + }{ + { + name: "simple varchar", + input: "'hello world'", + expectedType: VarcharLiteralToken, + expectedValue: "'hello world'", + }, + { + name: "varchar with escaped quote", + input: "'it''s working'", + expectedType: VarcharLiteralToken, + expectedValue: "'it''s working'", + }, + { + name: "empty varchar", + input: "''", + expectedType: VarcharLiteralToken, + expectedValue: "''", + }, + { + name: "simple nvarchar", + input: "N'unicode string'", + expectedType: NVarcharLiteralToken, + expectedValue: "N'unicode string'", + }, + { + name: "nvarchar with escaped quote", + input: "N'say ''hello'''", + expectedType: NVarcharLiteralToken, + expectedValue: "N'say ''hello'''", + }, + { + name: "multiline varchar", + input: "'line1\nline2\nline3'", + expectedType: VarcharLiteralToken, + expectedValue: "'line1\nline2\nline3'", + }, + { + name: "unterminated varchar", + input: "'unterminated", + expectedType: UnterminatedVarcharLiteralErrorToken, + expectedValue: "'unterminated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := 
NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != tt.expectedType { + t.Errorf("expected type %v, got %v", tt.expectedType, s.TokenType()) + } + if s.Token() != tt.expectedValue { + t.Errorf("expected value %q, got %q", tt.expectedValue, s.Token()) + } + }) + } +} + +func TestScanner_QuotedIdentifiers(t *testing.T) { + tests := []struct { + name string + input string + expectedType sqldocument.TokenType + expectedValue string + }{ + { + name: "simple bracket identifier", + input: "[MyTable]", + expectedType: sqldocument.QuotedIdentifierToken, + expectedValue: "[MyTable]", + }, + { + name: "bracket identifier with space", + input: "[My Table Name]", + expectedType: sqldocument.QuotedIdentifierToken, + expectedValue: "[My Table Name]", + }, + { + name: "bracket identifier with escaped bracket", + input: "[My]]Table]", + expectedType: sqldocument.QuotedIdentifierToken, + expectedValue: "[My]]Table]", + }, + { + name: "code schema identifier", + input: "[code]", + expectedType: sqldocument.QuotedIdentifierToken, + expectedValue: "[code]", + }, + { + name: "unterminated bracket identifier", + input: "[unterminated", + expectedType: UnterminatedQuotedIdentifierErrorToken, + expectedValue: "[unterminated", + }, + { + name: "double quote error", + input: "\"identifier\"", + expectedType: DoubleQuoteErrorToken, + expectedValue: "\"", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != tt.expectedType { + t.Errorf("expected type %v, got %v", tt.expectedType, s.TokenType()) + } + if s.Token() != tt.expectedValue { + t.Errorf("expected value %q, got %q", tt.expectedValue, s.Token()) + } + }) + } +} + +func TestScanner_Identifiers(t *testing.T) { + tests := []struct { + name string + input string + expectedType sqldocument.TokenType + expectedWord string // for reserved words + }{ + { + name: "simple identifier", + input: "MyProc", + expectedType: 
sqldocument.UnquotedIdentifierToken, + }, + { + name: "identifier with underscore", + input: "my_procedure_name", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "identifier with numbers", + input: "Proc123", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "identifier starting with underscore", + input: "_private", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "identifier with hash (temp table)", + input: "#TempTable", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "global temp table", + input: "##GlobalTemp", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "variable identifier", + input: "@myVariable", + expectedType: sqldocument.VariableIdentifierToken, + }, + { + name: "system variable", + input: "@@ROWCOUNT", + expectedType: sqldocument.VariableIdentifierToken, + }, + { + name: "reserved word CREATE", + input: "CREATE", + expectedType: sqldocument.ReservedWordToken, + expectedWord: "create", + }, + { + name: "reserved word lowercase", + input: "select", + expectedType: sqldocument.ReservedWordToken, + expectedWord: "select", + }, + { + name: "reserved word mixed case", + input: "DeClaRe", + expectedType: sqldocument.ReservedWordToken, + expectedWord: "declare", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != tt.expectedType { + t.Errorf("expected type %v, got %v", tt.expectedType, s.TokenType()) + } + if tt.expectedWord != "" && s.ReservedWord() != tt.expectedWord { + t.Errorf("expected reserved word %q, got %q", tt.expectedWord, s.ReservedWord()) + } + }) + } +} + +func TestScanner_Numbers(t *testing.T) { + tests := []struct { + name string + input string + expectedValue string + }{ + {"integer", "42", "42"}, + {"negative integer", "-42", "-42"}, + {"positive integer", "+42", "+42"}, + {"decimal", "3.14159", "3.14159"}, + {"negative decimal", 
"-3.14", "-3.14"}, + {"scientific notation", "1.5e10", "1.5e10"}, + {"scientific negative exponent", "1.5e-10", "1.5e-10"}, + {"scientific positive exponent", "1.5e+10", "1.5e+10"}, + {"integer scientific", "1e5", "1e5"}, + {"leading decimal", "123.", "123."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != sqldocument.NumberToken { + t.Errorf("expected NumberToken, got %v", s.TokenType()) + } + if s.Token() != tt.expectedValue { + t.Errorf("expected %q, got %q", tt.expectedValue, s.Token()) + } + }) + } +} + +func TestScanner_Comments(t *testing.T) { + tests := []struct { + name string + input string + expectedType sqldocument.TokenType + expectedValue string + }{ + { + name: "single line comment", + input: "-- this is a comment", + expectedType: sqldocument.SinglelineCommentToken, + expectedValue: "-- this is a comment", + }, + { + name: "single line comment before newline", + input: "-- comment\ncode", + expectedType: sqldocument.SinglelineCommentToken, + expectedValue: "-- comment", + }, + { + name: "multiline comment", + input: "/* this is\na multiline\ncomment */", + expectedType: sqldocument.MultilineCommentToken, + expectedValue: "/* this is\na multiline\ncomment */", + }, + { + name: "multiline comment with asterisks", + input: "/* * * * */", + expectedType: sqldocument.MultilineCommentToken, + expectedValue: "/* * * * */", + }, + { + name: "pragma comment", + input: "--sqlcode:include-if foo", + expectedType: sqldocument.PragmaToken, + expectedValue: "--sqlcode:include-if foo", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != tt.expectedType { + t.Errorf("expected type %v, got %v", tt.expectedType, s.TokenType()) + } + if s.Token() != tt.expectedValue { + t.Errorf("expected value %q, got %q", tt.expectedValue, s.Token()) + } + }) + } +} + +func 
TestScanner_Position(t *testing.T) { + input := "SELECT\n @var\n FROM" + s := NewScanner("test.sql", input) + + // SELECT + s.NextToken() + start := s.Start() + if start.Line != 1 || start.Col != 1 { + t.Errorf("SELECT start: expected (1,1), got (%d,%d)", start.Line, start.Col) + } + + // whitespace (includes newline) + s.NextToken() + + // @var + s.NextToken() + start = s.Start() + if start.Line != 2 || start.Col != 3 { + t.Errorf("@var start: expected (2,3), got (%d,%d)", start.Line, start.Col) + } + + // whitespace + s.NextToken() + + // FROM + s.NextToken() + start = s.Start() + if start.Line != 3 || start.Col != 3 { + t.Errorf("FROM start: expected (3,3), got (%d,%d)", start.Line, start.Col) + } +} + +func TestScanner_ComplexStatement(t *testing.T) { + input := `CREATE PROCEDURE [code].[MyProc] + @Param1 nvarchar(100), + @Param2 int = 42 +AS +BEGIN + SELECT @Param1, @Param2 +END` + + s := NewScanner("test.sql", input) + + // Verify we can tokenize the entire statement without errors + tokenCount := 0 + for { + tt := s.NextToken() + tokenCount++ + if tt == sqldocument.EOFToken { + break + } + if tt == sqldocument.NonUTF8ErrorToken { + t.Fatalf("unexpected non-UTF8 error at token %d", tokenCount) + } + } + + if tokenCount < 30 { + t.Errorf("expected at least 30 tokens, got %d", tokenCount) + } +} + +func TestScanner_Clone(t *testing.T) { + input := "SELECT FROM WHERE" + s := NewScanner("test.sql", input) + + s.NextToken() // SELECT + s.NextToken() // whitespace + + clone := s.Clone() + + // Advance original + s.NextToken() // FROM + + // Clone should still be at whitespace position + if clone.Token() != " " { + t.Errorf("clone should still be at whitespace, got %q", clone.Token()) + } + + // Advance clone independently + clone.NextToken() + if clone.Token() != "FROM" { + t.Errorf("clone should now be at FROM, got %q", clone.Token()) + } +} + +func TestScanner_SkipMethods(t *testing.T) { + input := "SELECT /* comment */ @var" + s := NewScanner("test.sql", input) 
+ + s.NextToken() // SELECT + s.NextToken() // whitespace + + // SkipWhitespace should stop at comment + s.SkipWhitespace() + if s.TokenType() != sqldocument.MultilineCommentToken { + t.Errorf("SkipWhitespace should stop at comment, got %v", s.TokenType()) + } + + // Reset and test SkipWhitespaceComments + s = NewScanner("test.sql", input) + s.NextToken() // SELECT + tt := s.NextNonWhitespaceCommentToken() + if tt != sqldocument.VariableIdentifierToken { + t.Errorf("NextNonWhitespaceCommentToken should return @var token type, got %v", tt) + } + if s.Token() != "@var" { + t.Errorf("should be at @var, got %q", s.Token()) + } +} diff --git a/sqlparser/mssql/tokens.go b/sqlparser/mssql/tokens.go new file mode 100644 index 0000000..c17c0e2 --- /dev/null +++ b/sqlparser/mssql/tokens.go @@ -0,0 +1,57 @@ +package mssql + +import "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" + +// T-SQL specific tokens (range 1000-1999) +// +// Token values are partitioned by dialect to avoid collisions: +// - 0-999: Common tokens shared across dialects (sqldocument package) +// - 1000-1999: T-SQL specific tokens (this package) +// - 2000-2999: Reserved for other dialects (e.g., PostgreSQL) +// +// This design allows dialect-specific code to use concrete token types +// while common code can use ToCommonToken() for abstraction. +const ( + // T-SQL specific string literals + // + // T-SQL distinguishes between varchar ('...') and nvarchar (N'...') + // string literals. Both use single quotes with '' as the escape sequence. + VarcharLiteralToken sqldocument.TokenType = iota + sqldocument.TSQLTokenStart + NVarcharLiteralToken + + // T-SQL specific identifier styles + // + // T-SQL uses square brackets for quoted identifiers: [My Table] + // Brackets are escaped by doubling: [My]]Table] represents "My]Table" + BracketQuotedIdentifierToken // [identifier] + + // T-SQL specific errors + // + // Unlike standard SQL, T-SQL does not support double-quoted strings. 
+ // Double quotes are reserved for QUOTED_IDENTIFIER mode identifiers, + // but sqlcode requires bracket notation for consistency. + DoubleQuoteErrorToken // T-SQL doesn't support double-quoted strings + UnterminatedVarcharLiteralErrorToken + UnterminatedQuotedIdentifierErrorToken +) + +// ToCommonToken maps T-SQL specific tokens to their common equivalents +// for dialect-agnostic processing. +// +// This abstraction layer allows higher-level code to work with logical +// token categories (e.g., "string literal") without knowing the specific +// dialect syntax (varchar vs nvarchar, brackets vs double quotes). +// +// Tokens that are already common tokens pass through unchanged. +func ToCommonToken(tt sqldocument.TokenType) sqldocument.TokenType { + switch tt { + case VarcharLiteralToken, NVarcharLiteralToken: + return sqldocument.StringLiteralToken + case BracketQuotedIdentifierToken: + return sqldocument.QuotedIdentifierToken + case UnterminatedVarcharLiteralErrorToken, UnterminatedQuotedIdentifierErrorToken: + return sqldocument.UnterminatedStringErrorToken + default: + return tt + } +} diff --git a/sqlparser/parser.go b/sqlparser/parser.go index 40eebe9..6dc0e9a 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -6,562 +6,50 @@ package sqlparser import ( "crypto/sha256" - "errors" "fmt" "io/fs" + "path/filepath" "regexp" - "sort" + "slices" "strings" -) - -var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" - -func CopyToken(s *Scanner, target *[]Unparsed) { - *target = append(*target, CreateUnparsed(s)) -} - -// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered -// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace -// token, and target is either unmodified or filled with some whitespace nodes. 
-func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - return - } - } - -} - -// AdvanceAndCopy is like NextToken; advance to next token that is not whitespace and return -// Note: The 'go' and EOF tokens are *not* copied -func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - // copy, and return - CopyToken(s, target) - return - } - } -} - -func CreateUnparsed(s *Scanner) Unparsed { - return Unparsed{ - Type: s.TokenType(), - Start: s.Start(), - Stop: s.Stop(), - RawValue: s.Token(), - } -} - -func (d *Document) addError(s *Scanner, msg string) { - d.Errors = append(d.Errors, Error{ - Pos: s.Start(), - Message: msg, - }) -} - -func (d *Document) unexpectedTokenError(s *Scanner) { - d.addError(s, "Unexpected: "+s.Token()) -} - -func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { - parseArgs := func() { - // parses *after* the initial (; consumes trailing ) - for { - switch { - case s.TokenType() == NumberToken: - t.Args = append(t.Args, s.Token()) - case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": - t.Args = append(t.Args, "max") - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - s.NextNonWhitespaceCommentToken() - switch { - case s.TokenType() == CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case s.TokenType() == RightParenToken: - s.NextNonWhitespaceCommentToken() - return - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - } - } - - 
if s.TokenType() != UnquotedIdentifierToken { - panic("assertion failed, bug in caller") - } - t.BaseType = s.Token() - s.NextNonWhitespaceCommentToken() - if s.TokenType() == LeftParenToken { - s.NextNonWhitespaceCommentToken() - parseArgs() - } - return -} - -func (doc *Document) parseDeclare(s *Scanner) (result []Declare) { - declareStart := s.Start() - // parse what is *after* the `declare` reserved keyword -loop: - for { - if s.TokenType() != VariableIdentifierToken { - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - variableName := s.Token() - if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && - !strings.HasPrefix(strings.ToLower(variableName), "@global") && - !strings.HasPrefix(strings.ToLower(variableName), "@const") { - doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) - } - s.NextNonWhitespaceCommentToken() - var variableType Type - switch s.TokenType() { - case EqualToken: - doc.addError(s, "sqlcode constants needs a type declared explicitly") - s.NextNonWhitespaceCommentToken() - case UnquotedIdentifierToken: - variableType = doc.parseTypeExpression(s) - } - - if s.TokenType() != EqualToken { - doc.addError(s, "sqlcode constants needs to be assigned at once using =") - doc.recoverToNextStatement(s) - } - - switch s.NextNonWhitespaceCommentToken() { - case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - result = append(result, Declare{ - Start: declareStart, - Stop: s.Stop(), - VariableName: variableName, - Datatype: variableType, - Literal: CreateUnparsed(s), - }) - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - switch s.NextNonWhitespaceCommentToken() { - case CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case SemicolonToken: - s.NextNonWhitespaceCommentToken() - break loop - default: - break loop - } - } - if len(result) == 0 { - doc.addError(s, "incorrect syntax; no variables 
successfully declared") - } - return -} - -func (doc *Document) parseBatchSeparator(s *Scanner) { - // just saw a 'go'; just make sure there's nothing bad trailing it - // (if there is, convert to errors and move on until the line is consumed - errorEmitted := false - for { - switch s.NextToken() { - case WhitespaceToken: - continue - case MalformedBatchSeparatorToken: - if !errorEmitted { - doc.addError(s, "`go` should be alone on a line without any comments") - errorEmitted = true - } - continue - default: - return - } - } -} - -func (doc *Document) parseDeclareBatch(s *Scanner) (hasMore bool) { - if s.ReservedWord() != "declare" { - panic("assertion failed, incorrect use in caller") - } - for { - tt := s.TokenType() - switch { - case tt == EOFToken: - return false - case tt == ReservedWordToken && s.ReservedWord() == "declare": - s.NextNonWhitespaceCommentToken() - d := doc.parseDeclare(s) - doc.Declares = append(doc.Declares, d...) - case tt == ReservedWordToken && s.ReservedWord() != "declare": - doc.addError(s, "Only 'declare' allowed in this batch") - doc.recoverToNextStatement(s) - case tt == BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - } - } -} - -func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - var createCountInBatch int - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset token for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // We build up a list of single line comments for the "docstring"; - // it is reset whenever we 
encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case "declare": - // First declare-statement; enter a mode where we assume all contents - // of batch are declare statements - if !isFirst { - doc.addError(s, "'declare' statement only allowed in first batch") - } - // regardless of errors, go on and parse as far as we get... - return doc.parseDeclareBatch(s) - case "create": - // should be start of create procedure or create function... - c := doc.parseCreate(s, createCountInBatch) - // *prepend* what we saw before getting to the 'create' - createCountInBatch++ - c.Body = append(nodes, c.Body...) - c.Docstring = docstring - doc.Creates = append(doc.Creates, c) - default: - doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) - s.NextToken() - } - case BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - s.NextToken() - docstring = nil - } - } -} -func (d *Document) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - default: - CopyToken(s, target) - } - } -} - -func (d *Document) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... 
as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - } - } -} + "github.com/vippsas/sqlcode/v2/sqlparser/mssql" + "github.com/vippsas/sqlcode/v2/sqlparser/sqldocument" +) -// parseCodeschemaName parses `[code] . something`, and returns `something` -// in quoted form (`[something]`). Also copy to `target`. Empty string on error. -// Note: To follow conventions, consume one extra token at the end even if we know -// it fill not be consumed by this function... -func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { - CopyToken(s, target) - NextTokenCopyingWhitespace(s, target) - if s.TokenType() != DotToken { - d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } - CopyToken(s, target) +var ( + templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" + supportedSqlExtensions []string = []string{".sql"} + // consider something a "sqlcode source file" if it contains [code] + // or a --sqlcode: header + isSqlCodeRegex = regexp.MustCompile(`^--sqlcode:|\[code\]`) +) - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case UnquotedIdentifierToken: - // To get something uniform for comparison, quote all names - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} - NextTokenCopyingWhitespace(s, target) - return result - case QuotedIdentifierToken: - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: s.Token()} - NextTokenCopyingWhitespace(s, target) - return result +// Based on the input file extension, create the appropriate Document type +func NewDocumentFromExtension(extension string) 
sqldocument.Document { + switch extension { + case ".sql", "sql": + return &mssql.TSqlDocument{} default: - d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} + panic("unhandled document type: " + extension) } } -// parseCreate parses anything that starts with "create". Position is -// *on* the create token. -// At this stage in sqlcode parser development we're only interested -// in procedures/functions/types as opaque blocks of SQL code where -// we only track dependencies between them and their declared name; -// so we treat them with the same code. We consume until the end of -// the batch; only one declaration allowed per batch. Everything -// parsed here will also be added to `batch`. On any error, copying -// to batch stops / becomes erratic.. -func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Create) { - if s.ReservedWord() != "create" { - panic("illegal use by caller") - } - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - createType := strings.ToLower(s.Token()) - if !(createType == "procedure" || createType == "function" || createType == "type") { - d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - d.recoverToNextStatementCopying(s, &result.Body) - return - } - - result.CreateType = createType - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - // Insist on [code]. 
- if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { - d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - result.QuotedName = d.parseCodeschemaName(s, &result.Body) - if result.QuotedName.String() == "" { - return - } - - // We have matched "create [code]."; at this - // point we copy the rest until the batch ends; *but* track dependencies - // + some other details mentioned below - - //firstAs := true // See comment below on rowcount - -tailloop: - for { - tt := s.TokenType() - switch { - case tt == ReservedWordToken && s.ReservedWord() == "create": - // So, we're currently parsing 'create ...' and we see another 'create'. - // We split in two cases depending on the context we are currently in - // (createType is referring to how we entered this function, *NOT* the - // `create` statement we are looking at now - switch createType { // note: this is the *outer* create type, not the one of current scanner position - case "function", "procedure": - // Within a function/procedure we can allow 'create index', 'create table' and nothing - // else. 
(Well, only procedures can have them, but we'll leave it to T-SQL to complain - // about that aspect, not relevant for batch / dependency parsing) - // - // What is important is a function/procedure/type isn't started on without a 'go' - // in between; so we block those 3 from appearing in the same batch - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - tt2 := s.TokenType() - - if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || - (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - d.recoverToNextStatementCopying(s, &result.Body) - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - return - } - case "type": - // We allow more than one type creation in a batch; and 'create' can never appear - // scoped within 'create type'. So at a new create we are done with the previous - // one, and return it -- the caller can then re-enter this function from the top - break tailloop - default: - panic("assertion failed") - } - - case tt == EOFToken || tt == BatchSeparatorToken: - break tailloop - case tt == QuotedIdentifierToken && s.Token() == "[code]": - // Parse a dependency - dep := d.parseCodeschemaName(s, &result.Body) - found := false - for _, existing := range result.DependsOn { - if existing.Value == dep.Value { - found = true - break - } - } - if !found { - result.DependsOn = append(result.DependsOn, dep) - } - case tt == ReservedWordToken && s.Token() == "as": - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - /* - TODO: Fix and re-enable - This code add RoutineName for convenience. 
So: - - create procedure [code@5420c0269aaf].Test as - begin - select 1 - end - go - - becomes: - - create procedure [code@5420c0269aaf].Test as - declare @RoutineName nvarchar(128) - set @RoutineName = 'Test' - begin - select 1 - end - go - - However, for some very strange reason, @@rowcount is 1 with the first version, - and it is 2 with the second version. - if firstAs { - // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name - // from inside the procedure (for example, when logging) - if result.CreateType == "procedure" { - procNameToken := Unparsed{ - Type: OtherToken, - RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), - } - result.Body = append(result.Body, procNameToken) - } - firstAs = false - } - */ - - default: - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - } - } - - sort.Slice(result.DependsOn, func(i, j int) bool { - return result.DependsOn[i].Value < result.DependsOn[j].Value - }) - return -} - -func Parse(s *Scanner, result *Document) { - // Top-level parse; this focuses on splitting into "batches" separated - // by 'go'. - - // CONVENTION: - // All functions should expect `s` positioned on what they are documented - // to consume/parse. - // - // Functions typically consume *after* the keyword that triggered their - // invoication; e.g. parseCreate parses from first non-whitespace-token - // *after* `create`. - // - // On return, `s` is positioned at the token that starts the next statement/ - // sub-expression. In particular trailing ';' and whitespace has been consumed. 
- // - // `s` will typically never be positioned on whitespace except in - // whitespace-preserving parsing - - s.NextNonWhitespaceToken() - result.parsePragmas(s) - hasMore := result.parseBatch(s, true) - for hasMore { - hasMore = result.parseBatch(s, false) - } - return -} - -func ParseString(filename FileRef, input string) (result Document) { - Parse(&Scanner{input: input, file: filename}, &result) - return +// Helper function +func ParseString(file, input string) sqldocument.Document { + doc := NewDocumentFromExtension(filepath.Ext(file)) + doc.Parse([]byte(input), sqldocument.FileRef(file)) + return doc } -// ParseFileystems iterates through a list of filesystems and parses all files -// matching `*.sql`, determines which one are sqlcode files from the contents, -// and returns the combination of all of them. +// ParseFileystems iterates through a list of filesystems and parses all supported +// SQL files and returns the combination of all of them. // // err will only return errors related to filesystems/reading. Errors // related to parsing/sorting will be in result.Errors. // // ParseFilesystems will also sort create statements topologically. -func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, result Document, err error) { +func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, result sqldocument.Document, err error) { // We are being passed several *filesystems* here. It may be easy to pass in the same // directory twice but that should not be encouraged, so if we get the same hash from // two files, return an error. 
Only files containing [code] in some way will be @@ -569,6 +57,10 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, hashes := make(map[[32]byte]string) + if result == nil { + result = &mssql.TSqlDocument{} + } + for fidx, fsys := range fslst { // WalkDir is in lexical order according to docs, so output should be stable err = fs.WalkDir(fsys, ".", @@ -580,7 +72,9 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, if strings.HasPrefix(path, ".") || strings.Contains(path, "/.") { return nil } - if !strings.HasSuffix(path, ".sql") { + + extension := filepath.Ext(path) + if !slices.Contains(supportedSqlExtensions, extension) { return nil } @@ -600,15 +94,19 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, hash := sha256.Sum256(buf) existingPathDesc, hashExists := hashes[hash] if hashExists { - return errors.New(fmt.Sprintf("file %s has exact same contents as %s (possibly in different filesystems)", - pathDesc, existingPathDesc)) + return fmt.Errorf("file %s has exact same contents as %s (possibly in different filesystems)", + pathDesc, existingPathDesc) } hashes[hash] = pathDesc - var fdoc Document - Parse(&Scanner{input: string(buf), file: FileRef(path)}, &fdoc) + fdoc := NewDocumentFromExtension(extension) + err = fdoc.Parse(buf, sqldocument.FileRef(path)) + if err != nil { + return fmt.Errorf("error parsing file %s: %w", pathDesc, err) + } - if matchesIncludeTags(fdoc.PragmaIncludeIf, includeTags) { + // only include if include tags match + if matchesIncludeTags(fdoc.PragmaIncludeIf(), includeTags) { filenames = append(filenames, pathDesc) result.Include(fdoc) } @@ -620,17 +118,7 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, } } - // Do the topological sort; and include any error with it as part - // of `result`, *not* return it as err - sortedCreates, errpos, sortErr := TopologicalSort(result.Creates) - if sortErr != nil { - 
result.Errors = append(result.Errors, Error{ - Pos: errpos, - Message: sortErr.Error(), - }) - } else { - result.Creates = sortedCreates - } + result.Sort() return } @@ -659,7 +147,3 @@ func IsSqlcodeConstVariable(varname string) bool { strings.HasPrefix(varname, "@CONST_") || strings.HasPrefix(varname, "@const_") } - -// consider something a "sqlcode source file" if it contains [code] -// or a --sqlcode: header -var isSqlCodeRegex = regexp.MustCompile(`^--sqlcode:|\[code\]`) diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index 7bd20b8..45ece29 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -1,394 +1,231 @@ package sqlparser import ( - "fmt" - "strings" + "io/fs" "testing" + "testing/fstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestParserSmokeTest(t *testing.T) { - doc := ParseString("test.sql", ` -/* test is a test - -declare @EnumFoo int = 2; - -*/ - -declare/*comment*/@EnumBar1 varchar (max) = N'declare @EnumThisIsInString'; -declare - - - @EnumBar2 int = 20, - @EnumBar3 int=21; - -GO - -declare @EnumNextBatch int = 3; - +func TestParseFilesystems(t *testing.T) { + t.Run("basic parsing of sql files", func(t *testing.T) { + fsys := fstest.MapFS{ + "test1.sql": &fstest.MapFile{ + Data: []byte(` +declare @EnumFoo int = 1; go --- preceding comment 1 -/* preceding comment 2 - -asdfasdf */create procedure [code].TestFunc as begin - refers to [code].OtherFunc [code].HelloFunc; - create table x ( int x not null ); -- should be ok -end; +create procedure [code].Proc1 as begin end +`), + }, + "test2.sql": &fstest.MapFile{ + Data: []byte(` +create function [code].Func1() returns int as begin return 1 end +`), + }, + } -/* trailing comment */ -`) - docNoPos := doc.WithoutPos() + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Len(t, doc.Creates(), 2) + assert.Len(t, doc.Declares(), 1) + }) + + t.Run("filters by 
include tags", func(t *testing.T) { + fsys := fstest.MapFS{ + "included.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if foo,bar +create procedure [code].Included as begin end +`), + }, + "excluded.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if baz +create procedure [code].Excluded as begin end +`), + }, + } - require.Equal(t, 1, len(doc.Creates)) - c := doc.Creates[0] + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, []string{"foo", "bar"}) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "included.sql") + assert.Len(t, doc.Creates(), 1) + assert.Equal(t, "[Included]", doc.Creates()[0].QuotedName.Value) + }) - assert.Equal(t, "[TestFunc]", c.QuotedName.Value) - assert.Equal(t, []string{"[HelloFunc]", "[OtherFunc]"}, c.DependsOnStrings()) - assert.Equal(t, `-- preceding comment 1 -/* preceding comment 2 + t.Run("detects duplicate files with same hash", func(t *testing.T) { + contents := []byte(`create procedure [code].Test as begin end`) -asdfasdf */create procedure [code].TestFunc as begin - refers to [code].OtherFunc [code].HelloFunc; - create table x ( int x not null ); -- should be ok -end; + fs1 := fstest.MapFS{ + "test.sql": &fstest.MapFile{Data: contents}, + } + fs2 := fstest.MapFS{ + "test.sql": &fstest.MapFile{Data: contents}, + } -/* trailing comment */ -`, c.String()) + _, _, err := ParseFilesystems([]fs.FS{fs1, fs2}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "exact same contents") + }) - assert.Equal(t, - []Error{ - { - Message: "'declare' statement only allowed in first batch", + t.Run("skips non-sqlcode files", func(t *testing.T) { + fsys := fstest.MapFS{ + "regular.sql": &fstest.MapFile{ + Data: []byte(`select * from table1`), }, - }, docNoPos.Errors) - - assert.Equal(t, - []Declare{ - { - VariableName: "@EnumBar1", - Datatype: Type{ - BaseType: "varchar", - Args: []string{ - "max", - }, - }, - Literal: Unparsed{ - Type: NVarcharLiteralToken, - RawValue: 
"N'declare @EnumThisIsInString'", - }, + "sqlcode.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Test as begin end`), }, - { - VariableName: "@EnumBar2", - Datatype: Type{ - BaseType: "int", - }, - Literal: Unparsed{ - Type: NumberToken, - RawValue: "20", - }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "sqlcode.sql") + assert.Len(t, doc.Creates(), 1) + }) + + t.Run("skips hidden directories", func(t *testing.T) { + fsys := fstest.MapFS{ + "visible.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Visible as begin end`), }, - { - VariableName: "@EnumBar3", - Datatype: Type{ - BaseType: "int", - }, - Literal: Unparsed{ - Type: NumberToken, - RawValue: "21", - }, + ".hidden/test.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Hidden as begin end`), }, - { - VariableName: "@EnumNextBatch", - Datatype: Type{ - BaseType: "int", - }, - Literal: Unparsed{ - Type: NumberToken, - RawValue: "3", - }, + "dir/.git/test.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Git as begin end`), }, - }, - docNoPos.Declares, - ) - // repr.Println(doc) -} - -func TestParserDisallowMultipleCreates(t *testing.T) { - // Test that we get an error if we create two things in same batch; - // the test above tests that it is still OK to create a table within - // a procedure.. - doc := ParseString("test.sql", ` -create function [code].One(); --- the following should give an error; not that One() depends on Two()... --- (we don't parse body start/end yet) -create function [code].Two(); -`).WithoutPos() - - assert.Equal(t, []Error{ - { - Message: "a procedure/function must be alone in a batch; use 'go' to split batches", - }, - }, doc.Errors) -} - -func TestBuggyDeclare(t *testing.T) { - // this caused parses to infinitely loop; regression test... 
- doc := ParseString("test.sql", `declare @EnumA int = 4 @EnumB tinyint = 5 @ENUM_C bigint = 435;`) - assert.Equal(t, 1, len(doc.Errors)) - assert.Equal(t, "Unexpected: @EnumB", doc.Errors[0].Message) -} - -func TestCreateType(t *testing.T) { - doc := ParseString("test.sql", `create type [code].MyType as table (x int not null primary key);`) - assert.Equal(t, 1, len(doc.Creates)) - assert.Equal(t, "type", doc.Creates[0].CreateType) - assert.Equal(t, "[MyType]", doc.Creates[0].QuotedName.Value) -} - -func TestPragma(t *testing.T) { - doc := ParseString("test.sql", `--sqlcode:include-if one,two ---sqlcode:include-if three - -create procedure [code].ProcedureShouldAlsoHavePragmasAnnotated() -`) - assert.Equal(t, []string{"one", "two", "three"}, doc.PragmaIncludeIf) -} - -func TestInfiniteLoopRegression(t *testing.T) { - // success if we terminate!... - doc := ParseString("test.sql", `@declare`) - assert.Equal(t, 1, len(doc.Errors)) -} - -func TestDeclareSeparation(t *testing.T) { - // Trying out many possible ways to separate declare statements: - // Comman, semicolon, simply starting a new declare with or without - // whitespace in between. 
- // Yes, ='hello'declare @EnumThird really does parse as T-SQL - doc := ParseString("test.sql", ` -declare @EnumFirst int = 3, @EnumSecond varchar(max) = 'hello'declare @EnumThird int=3 declare @EnumFourth int=4;declare @EnumFifth int =5 -`) - //repr.Println(doc.Declares) - require.Equal(t, []Declare{ - { - VariableName: "@EnumFirst", - Datatype: Type{BaseType: "int"}, - Literal: Unparsed{Type: NumberToken, RawValue: "3"}, - }, - { - VariableName: "@EnumSecond", - Datatype: Type{BaseType: "varchar", Args: []string{"max"}}, - Literal: Unparsed{Type: VarcharLiteralToken, RawValue: "'hello'"}, - }, - { - VariableName: "@EnumThird", - Datatype: Type{BaseType: "int"}, - Literal: Unparsed{Type: NumberToken, RawValue: "3"}, - }, - { - VariableName: "@EnumFourth", - Datatype: Type{BaseType: "int"}, - Literal: Unparsed{Type: NumberToken, RawValue: "4"}, - }, - { - VariableName: "@EnumFifth", - Datatype: Type{BaseType: "int"}, - Literal: Unparsed{Type: NumberToken, RawValue: "5"}, - }, - }, doc.WithoutPos().Declares) -} - -func TestBatchDivisionsAndCreateStatements(t *testing.T) { - // Had a bug where comments where repeated on each create statement in different batches, discovery & regression - // (The bug was that a token too much was consumed in parseCreate, consuming the `go` token..) - doc := ParseString("test.sql", ` -create type [code].Batch1 as table (x int); -go --- a comment in 2nd batch -create procedure [code].Batch2 as table (x int); -go -create type [code].Batch3 as table (x int); -`) - commentCount := 0 - for _, c := range doc.Creates { - for _, b := range c.Body { - if strings.Contains(b.RawValue, "2nd") { - commentCount++ - } - assert.NotEqual(t, "go", b.RawValue) } - } - assert.Equal(t, 1, commentCount) -} -func TestCreateTypes(t *testing.T) { - // Apparently there can be several 'create type' per batch, but only one function/procedure... 
- // Check we catch all 3 types - doc := ParseString("test.sql", ` -create type [code].Type1 as table (x int); -create type [code].Type2 as table (x int); -create type [code].Type3 as table (x int); -`) - require.Equal(t, 3, len(doc.Creates)) - assert.Equal(t, "[Type1]", doc.Creates[0].QuotedName.Value) - assert.Equal(t, "[Type3]", doc.Creates[2].QuotedName.Value) - // There was a bug that the last item in the body would be the 'create' - // of the next statement; regression test.. - assert.Equal(t, "\n", doc.Creates[0].Body[len(doc.Creates[0].Body)-1].RawValue) - assert.Equal(t, "create", doc.Creates[1].Body[0].RawValue) -} + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "visible.sql") + assert.Len(t, doc.Creates(), 1) + }) + + t.Run("handles dependencies and topological sort", func(t *testing.T) { + fsys := fstest.MapFS{ + "proc1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc1 as begin exec [code].Proc2 end`), + }, + "proc2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc2 as begin select 1 end`), + }, + } -func TestCreateProcs(t *testing.T) { - // Apparently there can be several 'create type' per batch, but only one function/procedure... - // Check that we get an error for all further create statements in the same batch - doc := ParseString("test.sql", ` -create procedure [code].FirstProc as table (x int) -create function [code].MyFunction () -create type [code].MyType () -create procedure [code].MyProcedure () -`) - // First function and last procedure triggers errors. 
- require.Equal(t, 2, len(doc.Errors)) - emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" - assert.Equal(t, emsg, doc.Errors[0].Message) - assert.Equal(t, emsg, doc.Errors[1].Message) + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Len(t, doc.Creates(), 2) + // Proc2 should come before Proc1 due to dependency + assert.Equal(t, "[Proc2]", doc.Creates()[0].QuotedName.Value) + assert.Equal(t, "[Proc1]", doc.Creates()[1].QuotedName.Value) + }) + + t.Run("reports topological sort errors", func(t *testing.T) { + fsys := fstest.MapFS{ + "circular1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].A as begin exec [code].B end`), + }, + "circular2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].B as begin exec [code].A end`), + }, + } -} + _, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) // filesystem error should be nil + assert.NotEmpty(t, doc.Errors()) // but parsing errors should exist + assert.Contains(t, doc.Errors()[0].Message, "Detected a dependency cycle") + }) -func TestCreateProcs2(t *testing.T) { - // Create type first, then create proc... should give an error still.. - doc := ParseString("test.sql", ` -create type [code].MyType () -create procedure [code].FirstProc as table (x int) -`) - //repr.Println(doc.Errors) + t.Run("handles multiple filesystems", func(t *testing.T) { + fs1 := fstest.MapFS{ + "test1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc1 as begin end`), + }, + } + fs2 := fstest.MapFS{ + "test2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc2 as begin end`), + }, + } - // Code above was mainly to be able to step through parser in a given way. - // First function triggers an error. Then create type is parsed which is - // fine sharing a batch with others. 
- require.Equal(t, 1, len(doc.Errors)) - emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" - assert.Equal(t, emsg, doc.Errors[0].Message) -} + filenames, doc, err := ParseFilesystems([]fs.FS{fs1, fs2}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Contains(t, filenames[0], "fs[0]:") + assert.Contains(t, filenames[1], "fs[1]:") + assert.Len(t, doc.Creates(), 2) + }) + + t.Run("detects sqlcode files by pragma header", func(t *testing.T) { + fsys := fstest.MapFS{ + "test.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if foo +create procedure NotInCodeSchema.Test as begin end`), + }, + } -func TestCreateProcsAndCheckForRoutineName(t *testing.T) { - t.Skip() // Routine name is disabled for now - testcases := []struct { - name string - doc Document - expectedProcName string - expectedIndex int + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, []string{"foo"}) + require.NoError(t, err) + assert.Len(t, filenames, 1) + // Should still parse even though it will have errors (not in [code] schema) + assert.NotEmpty(t, doc.Errors()) + }) + + t.Run("empty filesystem returns empty results", func(t *testing.T) { + fsys := fstest.MapFS{} + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Empty(t, filenames) + assert.Empty(t, doc.Creates()) + assert.Empty(t, doc.Declares()) + }) +} + +func TestMatchesIncludeTags(t *testing.T) { + t.Run("empty requirements matches anything", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{}, []string{})) + assert.True(t, matchesIncludeTags([]string{}, []string{"foo"})) + }) + + t.Run("all requirements must be met", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"foo", "bar", "baz"})) + assert.False(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"foo"})) + assert.False(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"bar"})) + }) + + t.Run("exact match", 
func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{"foo"}, []string{"foo"})) + assert.False(t, matchesIncludeTags([]string{"foo"}, []string{"bar"})) + }) +} + +func TestIsSqlcodeConstVariable(t *testing.T) { + testCases := []struct { + name string + varname string + expected bool }{ - { - name: "Test simple proc", - expectedProcName: "FirstProc", - doc: ParseString("test.sql", ` -create procedure [code].FirstProc as -begin -end -`), - expectedIndex: 10, - }, - { - name: "Test proc with args", - expectedProcName: "transform:safeguarding.Calculation/HEAD", - doc: ParseString("test.sql", ` -create procedure [code].[transform:safeguarding.Calculation/HEAD](@now datetime2, -@count bigint output) as -`), - expectedIndex: 22, - }, - } - for _, tc := range testcases { - require.Equal(t, 0, len(tc.doc.Errors)) - assert.Len(t, tc.doc.Creates, 1) - assert.Greater(t, len(tc.doc.Creates[0].Body), tc.expectedIndex) - assert.Equal(t, - fmt.Sprintf(templateRoutineName, tc.expectedProcName), - tc.doc.Creates[0].Body[tc.expectedIndex].RawValue, - ) + {"@Enum prefix", "@EnumFoo", true}, + {"@ENUM_ prefix", "@ENUM_FOO", true}, + {"@enum_ prefix", "@enum_foo", true}, + {"@Const prefix", "@ConstFoo", true}, + {"@CONST_ prefix", "@CONST_FOO", true}, + {"@const_ prefix", "@const_foo", true}, + {"regular variable", "@MyVariable", false}, + {"@Global prefix", "@GlobalVar", false}, + {"no @ prefix", "EnumFoo", false}, } -} - -func TestGoWithoutNewline(t *testing.T) { - doc := ParseString("test.sql", ` -create procedure [code].Foo() as begin -end; -go create function [code].Bar() returns int as begin -end -`) - // Code above was mainly to be able to step through parser in a given way. - // First function triggers an error. Then create type is parsed which is - // fine sharing a batch with others. 
- require.Equal(t, 2, len(doc.Errors)) - assert.Equal(t, "`go` should be alone on a line without any comments", doc.Errors[0].Message) - assert.Equal(t, "Expected 'declare' or 'create', got: end", doc.Errors[1].Message) -} - -func TestCreateAnnotationHappyDay(t *testing.T) { - // Comment / annotations on create statements - doc := ParseString("test.sql", ` --- Not part of annotation ---! key4: 1 - --- This is part of annotation ---! key1: a ---! key2: b ---! key3: [1,2,3] -create procedure [code].Foo as begin end -`) - assert.Equal(t, - "-- This is part of annotation\n--! key1: a\n--! key2: b\n--! key3: [1,2,3]", - doc.Creates[0].DocstringAsString()) - s, err := doc.Creates[0].DocstringYamldoc() - assert.NoError(t, err) - assert.Equal(t, - "key1: a\nkey2: b\nkey3: [1,2,3]", - s) - - var x struct { - Key1 string `yaml:"key1"` + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsSqlcodeConstVariable(tc.varname)) + }) } - require.NoError(t, doc.Creates[0].ParseYamlInDocstring(&x)) - assert.Equal(t, "a", x.Key1) -} - -func TestCreateAnnotationAfterPragma(t *testing.T) { - // Comment / annotations on create statement, with pragma at start of file - doc := ParseString("test.sql", ` ---sqlcode: include-if foo - --- docstring here -create procedure [code].Foo as begin end - -`) - assert.Equal(t, - "-- docstring here", - doc.Creates[0].DocstringAsString()) -} - -func TestCreateAnnotationErrors(t *testing.T) { - // Multiple embedded yaml documents .. - doc := ParseString("test.sql", ` ---! key4: 1 --- This comment after yamldoc is illegal; this also prevents multiple embedded YAML documents -create procedure [code].Foo as begin end -`) - _, err := doc.Creates[0].DocstringYamldoc() - assert.Equal(t, "test.sql:3:1 once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement", - err.Error()) - - // No whitespace after ! 
- doc = ParseString("test.sql", ` --- Docstring here ---!key4: 1 -create procedure [code].Foo as begin end -`) - _, err = doc.Creates[0].DocstringYamldoc() - assert.Equal(t, "test.sql:3:1 YAML document in docstring; missing space after `--!`", - err.Error()) - } diff --git a/sqlparser/scanner_test.go b/sqlparser/scanner_test.go deleted file mode 100644 index cff40fc..0000000 --- a/sqlparser/scanner_test.go +++ /dev/null @@ -1,292 +0,0 @@ -package sqlparser - -import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "strings" - "testing" -) - -func TestNextToken(t *testing.T) { - // just check that regexp should return nil if we didn't start to match... - assert.Equal(t, []int(nil), numberRegexp.FindStringIndex("a123")) - - testExt := func(startOfLine bool, prefix, input string, expectedTokenType TokenType, expected string, extraAssertion ...func(s *Scanner)) func(*testing.T) { - return func(t *testing.T) { - s := &Scanner{input: prefix + input, curIndex: len(prefix), startOfLine: startOfLine} - tt := s.NextToken() - assert.Equal(t, expectedTokenType, tt) - assert.Equal(t, expected, s.Token()) - for _, a := range extraAssertion { - a(s) - } - } - } - - test := func(input string, expectedTokenType TokenType, expected string, extraAssertion ...func(s *Scanner)) func(*testing.T) { - return testExt(false, "abcd", input, expectedTokenType, expected, extraAssertion...) 
- } - - t.Run("", test(" ", WhitespaceToken, " ")) - t.Run("", test(" a ", WhitespaceToken, " ")) - t.Run("", test(" \t\t\n\n \t \nasdf", WhitespaceToken, " \t\t\n\n \t \n")) - - t.Run("", test("123", NumberToken, "123")) - t.Run("", test("123;\n", NumberToken, "123")) - t.Run("", test("123\n", NumberToken, "123")) - t.Run("", test("123 ", NumberToken, "123")) - t.Run("", test("+123.e-3_asdf", NumberToken, "+123.e-3")) - t.Run("", test("-123.e+3+a", NumberToken, "-123.e+3")) - t.Run("", test("-123.12e3+a", NumberToken, "-123.12e3")) - t.Run("", test("-123.12e-35+a", NumberToken, "-123.12e-35")) - t.Run("", test("-123.12ea", NumberToken, "-123.12e")) - t.Run("", test("-123.12;\n", NumberToken, "-123.12")) - - t.Run("", test("'hello world'", VarcharLiteralToken, "'hello world'")) - t.Run("", test("'hello world'after", VarcharLiteralToken, "'hello world'")) - t.Run("", test("'hello '' world'after", VarcharLiteralToken, "'hello '' world'")) - t.Run("", test("''''", VarcharLiteralToken, "''''")) - t.Run("", test("''", VarcharLiteralToken, "''")) - - t.Run("", test("N'hello world'after", NVarcharLiteralToken, "N'hello world'")) - t.Run("", test("N''", NVarcharLiteralToken, "N''")) - - t.Run("", test("'''hello", UnterminatedVarcharLiteralErrorToken, "'''hello")) - t.Run("", test("N'''hello", UnterminatedVarcharLiteralErrorToken, "N'''hello")) - - t.Run("", test("[ quote \n quote]] hi]asdf", QuotedIdentifierToken, "[ quote \n quote]] hi]")) - t.Run("", test("[][]", QuotedIdentifierToken, "[]")) - t.Run("", test("[]]]", QuotedIdentifierToken, "[]]]")) - t.Run("", test("[]]test", UnterminatedQuotedIdentifierErrorToken, "[]]test")) - - t.Run("", test("/* comment\n\n */asdf", MultilineCommentToken, "/* comment\n\n */")) - t.Run("", test("/* comment\n\n ****/asdf", MultilineCommentToken, "/* comment\n\n ****/")) - // unterminated multiline comment is treated like a comment - t.Run("", test("/* comment\n\n asdf", MultilineCommentToken, "/* comment\n\n asdf")) - - // single 
stopLine comment .. trailing \n is not considered part of token - t.Run("", test("-- test\nhello", SinglelineCommentToken, "-- test")) - t.Run("", test("-- test", SinglelineCommentToken, "-- test")) - - t.Run("", test(`"asdf`, DoubleQuoteErrorToken, `"`)) - - t.Run("", test(``, EOFToken, ``)) - - t.Run("", test("abc", UnquotedIdentifierToken, "abc")) - t.Run("", test("@a#$$__bc a", VariableIdentifierToken, "@a#$$__bc")) - // identifier starting with N is special branch - t.Run("", test("N@a#$$__bc a", UnquotedIdentifierToken, "N@a#$$__bc")) - - t.Run("", test("