diff --git a/README.md b/README.md index 886efed..0ffb32e 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ our issues in detail. In this fork, we modify some of the column binding operations to work more nicely with Spark. +We also implement the [`driver.QueryerContext`](https://pkg.go.dev/database/sql/driver#QueryerContext) +which honours the context passed in, and returns when the context times out or gets cancelled. + ## Original `README.md` odbc driver written in go. Implements database driver interface as used by standard database/sql package. It calls into odbc dll on Windows, and uses cgo (unixODBC) everywhere else. diff --git a/api/api.go b/api/api.go index 65c214a..0027c92 100644 --- a/api/api.go +++ b/api/api.go @@ -64,6 +64,7 @@ type ( //sys SQLRowCount(statementHandle SQLHSTMT, rowCountPtr *SQLLEN) (ret SQLRETURN) = odbc32.SQLRowCount //sys SQLSetEnvAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLSetEnvAttr //sys SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLSetConnectAttrW +//sys SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLCancel // UTF16ToString returns the UTF-8 encoding of the UTF-16 sequence s, // with a terminating NUL removed. diff --git a/api/zapi_unix.go b/api/zapi_unix.go index e36d53f..2055440 100644 --- a/api/zapi_unix.go +++ b/api/zapi_unix.go @@ -35,6 +35,11 @@ func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, in return SQLRETURN(r) } +func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) { + r := C.SQLCancel(C.SQLHSTMT(statementHandle)) + return SQLRETURN(r) +} + func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) { r := C.SQLCloseCursor(C.SQLHSTMT(statementHandle)) return SQLRETURN(r) diff --git a/api/zapi_windows.go b/api/zapi_windows.go index 3657da3..b2104c8 100644 --- a/api/zapi_windows.go +++ b/api/zapi_windows.go @@ -42,6 +42,7 @@ var ( procSQLAllocHandle = mododbc32.NewProc("SQLAllocHandle") procSQLBindCol = mododbc32.NewProc("SQLBindCol") procSQLBindParameter = mododbc32.NewProc("SQLBindParameter") + procSQLCancel = mododbc32.NewProc("SQLCancel") procSQLCloseCursor = mododbc32.NewProc("SQLCloseCursor") procSQLDescribeColW = mododbc32.NewProc("SQLDescribeColW") procSQLDescribeParam = mododbc32.NewProc("SQLDescribeParam") @@ -80,6 +81,12 @@ func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, in return } +func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLCancel.Addr(), 1, uintptr(statementHandle), 0, 0) + ret = SQLRETURN(r0) + return +} + func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) { r0, _, _ := syscall.Syscall(procSQLCloseCursor.Addr(), 1, uintptr(statementHandle), 0, 0) ret = SQLRETURN(r0) diff --git a/conn.go b/conn.go index 0238fb9..2f94009 100644 --- a/conn.go +++ b/conn.go @@ -5,7 +5,9 @@ package odbc import ( + "context" "database/sql/driver" + "errors" "strings" "unsafe" @@ -72,3 +74,96 @@ func (c *Conn) newError(apiName string, handle interface{}) error { } return err } + +// QueryContext implements the driver.QueryerContext interface. +// As per the specifications, it honours the context timeout and returns when the context is cancelled. +// When the context is cancelled, it first cancels the statement, closes it, and then returns an error. +func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + // Prepare a query + os, err := c.PrepareODBCStmt(query) + if err != nil { + return nil, err + } + + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + // Execute the statement + rowsChan := make(chan driver.Rows) + defer close(rowsChan) + errorChan := make(chan error) + defer close(errorChan) + + if ctx.Err() != nil { + os.closeByStmt() + return nil, ctx.Err() + } + + go c.wrapQuery(ctx, os, dargs, rowsChan, errorChan) + + var finalErr error + var finalRes driver.Rows + + select { + case <-ctx.Done(): + // Context has been cancelled or has expired, cancel the statement + if err := os.Cancel(); err != nil { + finalErr = err + break + } + + // The statement has been cancelled, the query execution should eventually fail now. + // We wait for it in order to avoid having a dangling goroutine running in the background + <-errorChan + finalErr = ctx.Err() + case err := <-errorChan: + finalErr = err + case rows := <-rowsChan: + finalRes = rows + } + + // Close the statement + os.closeByStmt() + os = nil + + return finalRes, finalErr +} + +// wrapQuery is following the same logic as `stmt.Query()` except that we don't use a lock +// because the ODBC statement doesn't get exposed externally. +func (c *Conn) wrapQuery(ctx context.Context, os *ODBCStmt, dargs []driver.Value, rowsChan chan<- driver.Rows, errorChan chan<- error) { + if err := os.Exec(dargs, c); err != nil { + errorChan <- err + return + } + + if err := os.BindColumns(); err != nil { + errorChan <- err + return + } + + os.usedByRows = true + rowsChan <- &Rows{os: os} + + // At the end of the execution, we check if the context has been cancelled + // to ensure the caller doesn't end up waiting for a message indefinitely (L119) + if ctx.Err() != nil { + errorChan <- ctx.Err() + } +} + +// namedValueToValue is a utility function that converts a driver.NamedValue into a driver.Value. +// Source: +// https://github.com/golang/go/blob/03ac39ce5e6af4c4bca58b54d5b160a154b7aa0e/src/database/sql/ctxutil.go#L137-L146 +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("sql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} diff --git a/odbcstmt.go b/odbcstmt.go index c9bdc20..3da85d0 100644 --- a/odbcstmt.go +++ b/odbcstmt.go @@ -158,3 +158,12 @@ func (s *ODBCStmt) BindColumns() error { } return nil } + +func (s *ODBCStmt) Cancel() error { + ret := api.SQLCancel(s.h) + if IsError(ret) { + return NewError("SQLCancel", s.h) + } + + return nil +}