Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dbsql

Check failure on line 1 in connection.go

View workflow job for this annotation

GitHub Actions / Lint

: # github.com/databricks/databricks-sql-go [github.com/databricks/databricks-sql-go.test]

import (
"context"
Expand Down Expand Up @@ -55,7 +55,7 @@

if err != nil {
log.Err(err).Msg("databricks: failed to close connection")
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
return dbsqlerrint.NewBadConnectionError(err)
}
return nil
}
Expand Down Expand Up @@ -168,9 +168,7 @@
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
}

corrId := driverctx.CorrelationIdFromContext(ctx)
rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)

rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
return rows, err

}
Expand Down Expand Up @@ -367,7 +365,7 @@
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
var statusResp *cli_service.TGetOperationStatusResp
ctx = driverctx.NewContextWithConnId(ctx, c.id)
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
newCtx := context.WithoutCancel(ctx)

Check failure on line 368 in connection.go

View workflow job for this annotation

GitHub Actions / Test and Build (1.20.x, ubuntu-latest)

undefined: context.WithoutCancel

Check failure on line 368 in connection.go

View workflow job for this annotation

GitHub Actions / Test and Build (1.20.x, ubuntu-latest)

undefined: context.WithoutCancel

Check failure on line 368 in connection.go

View workflow job for this annotation

GitHub Actions / Lint

undefined: context.WithoutCancel (typecheck)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if context is cancelled, polling will not have cancellation information. Won't this affect cancellation of long running queries?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that context.WithoutCancel was added in Go 1.21, but go.mod pins to v1.20. This will fail.

Do we need a version change in go.mod as well?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would like to be compatible with 1.20 and hence would recommend finding a way to avoid using context.WithoutCancel

pollSentinel := sentinel.Sentinel{
OnDoneFn: func(statusResp any) (any, error) {
return statusResp, nil
Expand Down Expand Up @@ -566,7 +564,6 @@
return nil
}

corrId := driverctx.CorrelationIdFromContext(ctx)
var row driver.Rows
var err error

Expand All @@ -589,7 +586,7 @@
}

if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
}
Expand Down
31 changes: 21 additions & 10 deletions internal/rows/arrowbased/arrowRecordIterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"testing"

"github.com/databricks/databricks-sql-go/driverctx"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/client"
"github.com/databricks/databricks-sql-go/internal/config"
Expand All @@ -32,15 +33,17 @@ func TestArrowRecordIterator(t *testing.T) {

var fetchesInfo []fetchResultsInfo

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2})
rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 7311),
5000,
nil,
false,
simpleClient,
"connectionId",
"correlationId",
logger,
)

Expand Down Expand Up @@ -126,17 +129,19 @@ func TestArrowRecordIterator(t *testing.T) {
fetchResp3 := cli_service.TFetchResultsResp{}
loadTestData2(t, "multipleFetch/FetchResults3.json", &fetchResp3)

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

var fetchesInfo []fetchResultsInfo

simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3})
rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 0),
5000,
nil,
false,
simpleClient,
"connectionId",
"correlationId",
logger,
)

Expand Down Expand Up @@ -199,16 +204,18 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
fetchResp1 := cli_service.TFetchResultsResp{}
loadTestData2(t, "directResultsMultipleFetch/FetchResults1.json", &fetchResp1)

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

var fetchesInfo []fetchResultsInfo
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 0),
5000,
nil,
false,
simpleClient,
"connectionId",
"correlationId",
logger,
)

Expand Down Expand Up @@ -251,16 +258,18 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
fetchResp1 := cli_service.TFetchResultsResp{}
loadTestData2(t, "multipleFetch/FetchResults1.json", &fetchResp1)

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

var fetchesInfo []fetchResultsInfo
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 0),
5000,
nil,
false,
simpleClient,
"connectionId",
"correlationId",
logger,
)

Expand Down Expand Up @@ -293,14 +302,16 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
},
}

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 0),
5000,
nil,
false,
failingClient,
"connectionId",
"correlationId",
logger,
)

Expand Down
7 changes: 5 additions & 2 deletions internal/rows/arrowbased/arrowRows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/databricks/databricks-sql-go/driverctx"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/config"
Expand Down Expand Up @@ -1525,18 +1526,20 @@ func TestArrowRowScanner(t *testing.T) {
fetchResp2 := cli_service.TFetchResultsResp{}
loadTestData2(t, "directResultsMultipleFetch/FetchResults2.json", &fetchResp2)

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

var fetchesInfo []fetchResultsInfo
client := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2})
logger := dbsqllog.WithContext("connectionId", "correlationId", "")

rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 7311),
5000,
nil,
false,
client,
"connectionId",
"correlationId",
logger)

cfg := config.WithDefaults()
Expand Down
16 changes: 7 additions & 9 deletions internal/rows/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,22 @@ var _ driver.RowsColumnTypeLength = (*rows)(nil)
var _ dbsqlrows.Rows = (*rows)(nil)

func NewRows(
connId string,
correlationId string,
ctx context.Context,
opHandle *cli_service.TOperationHandle,
client cli_service.TCLIService,
config *config.Config,
directResults *cli_service.TSparkDirectResults,
) (driver.Rows, dbsqlerr.DBError) {

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need a check for context being nil?

connId := driverctx.ConnIdFromContext(ctx)
correlationId := driverctx.CorrelationIdFromContext(ctx)

var logger *dbsqllog.DBSQLLogger
var ctx context.Context
if opHandle != nil {
logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
ctx = driverctx.NewContextWithQueryId(ctx, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
} else {
logger = dbsqllog.WithContext(connId, correlationId, "")
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId)
}

if client == nil {
Expand Down Expand Up @@ -140,13 +140,12 @@ func NewRows(
// the operations.
closedOnServer := directResults != nil && directResults.CloseOperation != nil
r.ResultPageIterator = rowscanner.NewResultPageIterator(
ctx,
d,
pageSize,
opHandle,
closedOnServer,
client,
connId,
correlationId,
r.logger(),
)

Expand Down Expand Up @@ -417,9 +416,8 @@ func (r *rows) getResultSetSchema() (*cli_service.TTableSchema, dbsqlerr.DBError
req := cli_service.TGetResultSetMetadataReq{
OperationHandle: r.opHandle,
}
ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId)

resp, err2 := r.client.GetResultSetMetadata(ctx, &req)
resp, err2 := r.client.GetResultSetMetadata(r.ctx, &req)
if err2 != nil {
r.logger().Err(err2).Msg(err2.Error())
return nil, dbsqlerr_int.NewRequestError(r.ctx, errRowsMetadataFetchFailed, err)
Expand Down
Loading
Loading