From 0725b4da7396576c83e85937bdffd709b3336db9 Mon Sep 17 00:00:00 2001 From: Richard Keo Date: Tue, 1 Mar 2022 21:04:07 +0000 Subject: [PATCH 1/4] Implement driver.QueryerContext interface As per the [specifications](https://pkg.go.dev/database/sql/driver#QueryerContext), `QueryerContext` is an optional interface. When it's not implemented, the driver falls back to the `Queryer` interface that is also optional, and eventually defaults to: - preparing a query - executing the statement - closing the statement This means that the context passed to `QueryContext` gets completely ignored and therefore timeouts aren't honoured. This commit provides an implementation of the `QueryerContext` interface that honours the context. When the context expires or gets cancelled, the statement gets cancelled and closed, and upon completion, an error is returned. --- README.md | 3 ++ api/api.go | 1 + api/zapi_unix.go | 5 +++ conn.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++ odbcstmt.go | 9 +++++ 5 files changed, 113 insertions(+) diff --git a/README.md b/README.md index 886efed..0ffb32e 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ our issues in detail. In this fork, we modify some of the column binding operations to work more nicely with Spark. +We also implement the [`driver.QueryerContext`](https://pkg.go.dev/database/sql/driver#QueryerContext) +which honours the context passed in, and returns when the context times out or gets cancelled. + ## Original `README.md` odbc driver written in go. Implements database driver interface as used by standard database/sql package. It calls into odbc dll on Windows, and uses cgo (unixODBC) everywhere else. diff --git a/api/api.go b/api/api.go index 65c214a..0027c92 100644 --- a/api/api.go +++ b/api/api.go @@ -64,6 +64,7 @@ type ( //sys SQLRowCount(statementHandle SQLHSTMT, rowCountPtr *SQLLEN) (ret SQLRETURN) = odbc32.SQLRowCount //sys SQLSetEnvAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLSetEnvAttr //sys SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLSetConnectAttrW +//sys SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLCancel // UTF16ToString returns the UTF-8 encoding of the UTF-16 sequence s, // with a terminating NUL removed. diff --git a/api/zapi_unix.go b/api/zapi_unix.go index e36d53f..3d4b85e 100644 --- a/api/zapi_unix.go +++ b/api/zapi_unix.go @@ -124,3 +124,8 @@ func SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr r := C.SQLSetConnectAttrW(C.SQLHDBC(connectionHandle), C.SQLINTEGER(attribute), C.SQLPOINTER(valuePtr), C.SQLINTEGER(stringLength)) return SQLRETURN(r) } + +func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) { + r := C.SQLCancel(C.SQLHSTMT(statementHandle)) + return SQLRETURN(r) +} diff --git a/conn.go b/conn.go index 0238fb9..20aaec6 100644 --- a/conn.go +++ b/conn.go @@ -5,7 +5,9 @@ package odbc import ( + "context" "database/sql/driver" + "errors" "strings" "unsafe" @@ -72,3 +74,96 @@ func (c *Conn) newError(apiName string, handle interface{}) error { } return err } + +// QueryContext implements the driver.QueryerContext interface. +// As per the specifications, it honours the context timeout and +// returns when the context is cancelled. +// When the context is cancelled, it first cancels the statement, +// then closes it, and returns an error. +func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + // Prepare a query + os, err := c.PrepareODBCStmt(query) + if err != nil { + return nil, err + } + + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + // Execute the statement + rowsChan := make(chan driver.Rows) + defer close(rowsChan) + errorChan := make(chan error) + defer close(errorChan) + + if ctx.Err() != nil { + os.closeByStmt() + return nil, ctx.Err() + } + + runQuery := func() { + err := os.Exec(dargs, c) + if err != nil { + errorChan <- err + return + } + + err = os.BindColumns() + if err != nil { + errorChan <- err + return + } + + os.usedByRows = true + rowsChan <- &Rows{os: os} + + // At the end of the execution, we check if the context has been cancelled + // to ensure there's no race condition below (L144). + if ctx.Err() != nil { + errorChan <- err + } + } + + go runQuery() + + var finalErr error + var finalRes driver.Rows + + select { + case <-ctx.Done(): + err := os.Cancel() + if err != nil { + finalErr = err + break + } + + // The statement has been cancelled, the query execution should eventually fail now. + // We wait for it in order to avoid having a dangling goroutine running in the background + <-errorChan + finalErr = ctx.Err() + case err := <-errorChan: + finalErr = err + case rows := <-rowsChan: + finalRes = rows + } + + // Close the statement + os.closeByStmt() + os = nil + + return finalRes, finalErr +} + +// namedValueToValue is a utility function that converts a driver.NamedValue into a driver.Value. +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("sql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} diff --git a/odbcstmt.go b/odbcstmt.go index c9bdc20..3da85d0 100644 --- a/odbcstmt.go +++ b/odbcstmt.go @@ -158,3 +158,12 @@ func (s *ODBCStmt) BindColumns() error { } return nil } + +func (s *ODBCStmt) Cancel() error { + ret := api.SQLCancel(s.h) + if IsError(ret) { + return NewError("SQLCancel", s.h) + } + + return nil +} From 86533c791e007a62c1d4a96adc65e3a2cb814b40 Mon Sep 17 00:00:00 2001 From: Richard Keo Date: Wed, 2 Mar 2022 10:31:57 +0000 Subject: [PATCH 2/4] Add code attribution --- conn.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/conn.go b/conn.go index 20aaec6..8932850 100644 --- a/conn.go +++ b/conn.go @@ -157,6 +157,8 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam } // namedValueToValue is a utility function that converts a driver.NamedValue into a driver.Value. +// Source: +// https://github.com/golang/go/blob/03ac39ce5e6af4c4bca58b54d5b160a154b7aa0e/src/database/sql/ctxutil.go#L137-L146 func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { dargs := make([]driver.Value, len(named)) for n, param := range named { From c8a991156ad8ad4d8374025a66c5ec9885aced44 Mon Sep 17 00:00:00 2001 From: Richard Keo Date: Wed, 2 Mar 2022 15:03:15 +0000 Subject: [PATCH 3/4] Address comments --- conn.go | 52 ++++++++++++++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/conn.go b/conn.go index 8932850..1419706 100644 --- a/conn.go +++ b/conn.go @@ -103,38 +103,15 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam return nil, ctx.Err() } - runQuery := func() { - err := os.Exec(dargs, c) - if err != nil { - errorChan <- err - return - } - - err = os.BindColumns() - if err != nil { - errorChan <- err - return - } - - os.usedByRows = true - rowsChan <- &Rows{os: os} - - // At the end of the execution, we check if the context has been cancelled - // to ensure there's no race condition below (L144). - if ctx.Err() != nil { - errorChan <- err - } - } - - go runQuery() + go c.wrapQuery(ctx, os, dargs, rowsChan, errorChan) var finalErr error var finalRes driver.Rows select { case <-ctx.Done(): - err := os.Cancel() - if err != nil { + // Context has been cancelled or has expired, cancel the statement + if err := os.Cancel(); err != nil { finalErr = err break } @@ -156,6 +133,29 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam return finalRes, finalErr } +// wrapQuery is following the same logic as `stmt.Query()` except that we don't use a lock +// because the ODBC statement doesn't get exposed externally. +func (c *Conn) wrapQuery(ctx context.Context, os *ODBCStmt, dargs []driver.Value, rowsChan chan<- driver.Rows, errorChan chan<- error) { + if err := os.Exec(dargs, c); err != nil { + errorChan <- err + return + } + + if err := os.BindColumns(); err != nil { + errorChan <- err + return + } + + os.usedByRows = true + rowsChan <- &Rows{os: os} + + // At the end of the execution, we check if the context has been cancelled + // to ensure the caller doesn't end up waiting for a message indefinitely (L121) + if ctx.Err() != nil { + errorChan <- ctx.Err() + } +} + // namedValueToValue is a utility function that converts a driver.NamedValue into a driver.Value. // Source: // https://github.com/golang/go/blob/03ac39ce5e6af4c4bca58b54d5b160a154b7aa0e/src/database/sql/ctxutil.go#L137-L146 From 7962bce848ca5c6755236800f3bde2999ecf468e Mon Sep 17 00:00:00 2001 From: Richard Keo Date: Thu, 3 Mar 2022 10:11:21 +0000 Subject: [PATCH 4/4] Provide implementation for zapi_windows --- api/zapi_unix.go | 10 +++++----- api/zapi_windows.go | 7 +++++++ conn.go | 8 +++----- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/api/zapi_unix.go b/api/zapi_unix.go index 3d4b85e..2055440 100644 --- a/api/zapi_unix.go +++ b/api/zapi_unix.go @@ -35,6 +35,11 @@ func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, in return SQLRETURN(r) } +func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) { + r := C.SQLCancel(C.SQLHSTMT(statementHandle)) + return SQLRETURN(r) +} + func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) { r := C.SQLCloseCursor(C.SQLHSTMT(statementHandle)) return SQLRETURN(r) @@ -124,8 +129,3 @@ func SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr r := C.SQLSetConnectAttrW(C.SQLHDBC(connectionHandle), C.SQLINTEGER(attribute), C.SQLPOINTER(valuePtr), C.SQLINTEGER(stringLength)) return SQLRETURN(r) } - -func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) { - r := C.SQLCancel(C.SQLHSTMT(statementHandle)) - return SQLRETURN(r) -} diff --git a/api/zapi_windows.go b/api/zapi_windows.go index 3657da3..b2104c8 100644 --- a/api/zapi_windows.go +++ b/api/zapi_windows.go @@ -42,6 +42,7 @@ var ( procSQLAllocHandle = mododbc32.NewProc("SQLAllocHandle") procSQLBindCol = mododbc32.NewProc("SQLBindCol") procSQLBindParameter = mododbc32.NewProc("SQLBindParameter") + procSQLCancel = mododbc32.NewProc("SQLCancel") procSQLCloseCursor = mododbc32.NewProc("SQLCloseCursor") procSQLDescribeColW = mododbc32.NewProc("SQLDescribeColW") procSQLDescribeParam = mododbc32.NewProc("SQLDescribeParam") @@ -80,6 +81,12 @@ func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, in return } +func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLCancel.Addr(), 1, uintptr(statementHandle), 0, 0) + ret = SQLRETURN(r0) + return +} + func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) { r0, _, _ := syscall.Syscall(procSQLCloseCursor.Addr(), 1, uintptr(statementHandle), 0, 0) ret = SQLRETURN(r0) diff --git a/conn.go b/conn.go index 1419706..2f94009 100644 --- a/conn.go +++ b/conn.go @@ -76,10 +76,8 @@ func (c *Conn) newError(apiName string, handle interface{}) error { } // QueryContext implements the driver.QueryerContext interface. -// As per the specifications, it honours the context timeout and -// returns when the context is cancelled. -// When the context is cancelled, it first cancels the statement, -// then closes it, and returns an error. +// As per the specifications, it honours the context timeout and returns when the context is cancelled. +// When the context is cancelled, it first cancels the statement, closes it, and then returns an error. func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { // Prepare a query os, err := c.PrepareODBCStmt(query) @@ -150,7 +148,7 @@ func (c *Conn) wrapQuery(ctx context.Context, os *ODBCStmt, dargs []driver.Value rowsChan <- &Rows{os: os} // At the end of the execution, we check if the context has been cancelled - // to ensure the caller doesn't end up waiting for a message indefinitely (L121) + // to ensure the caller doesn't end up waiting for a message indefinitely (L119) if ctx.Err() != nil { errorChan <- ctx.Err() }