From 70cc8c2bb029fd4455b80b0ad60a353c33ce1663 Mon Sep 17 00:00:00 2001 From: Diego Giagio Date: Mon, 27 Oct 2025 15:21:17 -0400 Subject: [PATCH] Fix context loss in polling and connection close operations (#1) Signed-off-by: Diego Giagio --- connection.go | 11 ++-- .../arrowbased/arrowRecordIterator_test.go | 31 ++++++--- internal/rows/arrowbased/arrowRows_test.go | 7 ++- internal/rows/rows.go | 16 +++-- internal/rows/rows_test.go | 63 ++++++++++++++----- .../rows/rowscanner/resultPageIterator.go | 25 ++++---- 6 files changed, 96 insertions(+), 57 deletions(-) diff --git a/connection.go b/connection.go index 93de20e8..cf7e03a9 100644 --- a/connection.go +++ b/connection.go @@ -55,7 +55,7 @@ func (c *conn) Close() error { if err != nil { log.Err(err).Msg("databricks: failed to close connection") - return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err) + return dbsqlerrint.NewBadConnectionError(err) } return nil } @@ -168,9 +168,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp) } - corrId := driverctx.CorrelationIdFromContext(ctx) - rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) - + rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) return rows, err } @@ -367,7 +365,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) var statusResp *cli_service.TGetOperationStatusResp ctx = driverctx.NewContextWithConnId(ctx, c.id) - newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) + newCtx := context.WithoutCancel(ctx) pollSentinel := sentinel.Sentinel{ OnDoneFn: func(statusResp any) (any, error) { return statusResp, nil @@ -566,7 +564,6 @@ func (c *conn) execStagingOperation( return nil } - corrId := driverctx.CorrelationIdFromContext(ctx) var row driver.Rows var err error @@ -589,7 +586,7 @@ func (c *conn) execStagingOperation( } if len(driverctx.StagingPathsFromContext(ctx)) != 0 { - row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading row.", err) } diff --git a/internal/rows/arrowbased/arrowRecordIterator_test.go b/internal/rows/arrowbased/arrowRecordIterator_test.go index 95fcd1a8..13891fc9 100644 --- a/internal/rows/arrowbased/arrowRecordIterator_test.go +++ b/internal/rows/arrowbased/arrowRecordIterator_test.go @@ -8,6 +8,7 @@ import ( "os" "testing" + "github.com/databricks/databricks-sql-go/driverctx" "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/client" "github.com/databricks/databricks-sql-go/internal/config" @@ -32,15 +33,17 @@ func TestArrowRecordIterator(t *testing.T) { var fetchesInfo []fetchResultsInfo + ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId") + simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) rpi := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 7311), 5000, nil, false, simpleClient, - "connectionId", - "correlationId", logger, ) @@ -126,17 +129,19 @@ func TestArrowRecordIterator(t *testing.T) { fetchResp3 := cli_service.TFetchResultsResp{} loadTestData2(t, "multipleFetch/FetchResults3.json", &fetchResp3) + ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId") + var fetchesInfo []fetchResultsInfo simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) rpi := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 0), 5000, nil, false, simpleClient, - "connectionId", - "correlationId", logger, ) @@ -199,16 +204,18 @@ func TestArrowRecordIteratorSchema(t *testing.T) { fetchResp1 := cli_service.TFetchResultsResp{} loadTestData2(t, "directResultsMultipleFetch/FetchResults1.json", &fetchResp1) + ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId") + var fetchesInfo []fetchResultsInfo simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1}) rpi := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 0), 5000, nil, false, simpleClient, - "connectionId", - "correlationId", logger, ) @@ -251,16 +258,18 @@ func TestArrowRecordIteratorSchema(t *testing.T) { fetchResp1 := cli_service.TFetchResultsResp{} loadTestData2(t, "multipleFetch/FetchResults1.json", &fetchResp1) + ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId") + var fetchesInfo []fetchResultsInfo simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1}) rpi := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 0), 5000, nil, false, simpleClient, - "connectionId", - "correlationId", logger, ) @@ -293,14 +302,16 @@ func TestArrowRecordIteratorSchema(t *testing.T) { }, } + ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId") + rpi := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 0), 5000, nil, false, failingClient, - "connectionId", - "correlationId", logger, ) diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go index 336186d4..9d7eaba8 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -13,6 +13,7 @@ import ( "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" + "github.com/databricks/databricks-sql-go/driverctx" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/config" @@ -1525,18 +1526,20 @@ func TestArrowRowScanner(t *testing.T) { fetchResp2 := cli_service.TFetchResultsResp{} loadTestData2(t, "directResultsMultipleFetch/FetchResults2.json", &fetchResp2) + ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId") + var fetchesInfo []fetchResultsInfo client := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) logger := dbsqllog.WithContext("connectionId", "correlationId", "") rpi := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 7311), 5000, nil, false, client, - "connectionId", - "correlationId", logger) cfg := config.WithDefaults() diff --git a/internal/rows/rows.go b/internal/rows/rows.go index 85603fac..963a3ce1 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -67,22 +67,22 @@ var _ driver.RowsColumnTypeLength = (*rows)(nil) var _ dbsqlrows.Rows = (*rows)(nil) func NewRows( - connId string, - correlationId string, + ctx context.Context, opHandle *cli_service.TOperationHandle, client cli_service.TCLIService, config *config.Config, directResults *cli_service.TSparkDirectResults, ) (driver.Rows, dbsqlerr.DBError) { + connId := driverctx.ConnIdFromContext(ctx) + correlationId := driverctx.CorrelationIdFromContext(ctx) + var logger *dbsqllog.DBSQLLogger - var ctx context.Context if opHandle != nil { logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) - ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) + ctx = driverctx.NewContextWithQueryId(ctx, dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) } else { logger = dbsqllog.WithContext(connId, correlationId, "") - ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId) } if client == nil { @@ -140,13 +140,12 @@ func NewRows( // the operations. closedOnServer := directResults != nil && directResults.CloseOperation != nil r.ResultPageIterator = rowscanner.NewResultPageIterator( + ctx, d, pageSize, opHandle, closedOnServer, client, - connId, - correlationId, r.logger(), ) @@ -417,9 +416,8 @@ func (r *rows) getResultSetSchema() (*cli_service.TTableSchema, dbsqlerr.DBError req := cli_service.TGetResultSetMetadataReq{ OperationHandle: r.opHandle, } - ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId) - resp, err2 := r.client.GetResultSetMetadata(ctx, &req) + resp, err2 := r.client.GetResultSetMetadata(r.ctx, &req) if err2 != nil { r.logger().Err(err2).Msg(err2.Error()) return nil, dbsqlerr_int.NewRequestError(r.ctx, errRowsMetadataFetchFailed, err) diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index 65e391ab..fa6c913a 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/databricks/databricks-sql-go/driverctx" "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/client" "github.com/databricks/databricks-sql-go/internal/config" @@ -217,14 +218,16 @@ func TestRowsFetchResultPageNoDirectResults(t *testing.T) { client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount) rowSet := &rows{client: client} + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + resultPageIterator := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 0), 1000, nil, false, client, - "connId", - "corrId", rowSet.logger(), ) rowSet.ResultPageIterator = resultPageIterator @@ -311,14 +314,16 @@ func TestRowsFetchResultPageWithDirectResults(t *testing.T) { err1 := rowSet.makeRowScanner(firstPage) assert.Nil(t, err1) + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + resultPageIterator := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(rowSet.RowScanner.Start(), rowSet.RowScanner.Count()), 1000, nil, false, client, - "connId", - "corrId", rowSet.logger(), ) rowSet.ResultPageIterator = resultPageIterator @@ -413,7 +418,10 @@ func TestColumnsWithDirectResults(t *testing.T) { client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount) - d, err := NewRows("", "", nil, client, nil, nil) + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + + d, err := NewRows(ctx, nil, client, nil, nil) assert.Nil(t, err) rowSet := d.(*rows) @@ -460,14 +468,16 @@ func TestNextNoDirectResults(t *testing.T) { client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount) rowSet.client = client + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + resultPageIterator := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 0), 1000, nil, false, client, - "connId", - "corrId", rowSet.logger(), ) rowSet.ResultPageIterator = resultPageIterator @@ -707,8 +717,10 @@ func TestRowsCloseOptimization(t *testing.T) { }, } + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte{'f', 'o'}}} - rowSet, _ := NewRows("", "", opHandle, client, nil, nil) + rowSet, _ := NewRows(ctx, opHandle, client, nil, nil) // rowSet has no direct results calling Close should result in call to client to close operation err := rowSet.Close() @@ -721,7 +733,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } closeCount = 0 - rowSet, _ = NewRows("", "", opHandle, client, nil, directResults) + rowSet, _ = NewRows(ctx, opHandle, client, nil, directResults) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 1, closeCount) @@ -734,7 +746,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{Schema: &cli_service.TTableSchema{}}, ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } - rowSet, _ = NewRows("", "", opHandle, client, nil, directResults) + rowSet, _ = NewRows(ctx, opHandle, client, nil, directResults) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 0, closeCount) @@ -752,17 +764,19 @@ func TestFetchResultsWithRetries(t *testing.T) { // across multiple result pages. fetches := []fetch{} + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + client := getRowsTestSimpleClient2(&fetches) rowSet := &rows{client: client} resultPageIterator := rowscanner.NewResultPageIterator( + ctx, rowscanner.NewDelimiter(0, 0), 1000, nil, false, client, - "connId", - "corrId", rowSet.logger(), ) rowSet.ResultPageIterator = resultPageIterator @@ -797,9 +811,12 @@ func TestGetArrowBatches(t *testing.T) { fetchResp2 := cli_service.TFetchResultsResp{} loadTestData(t, "directResultsMultipleFetch/FetchResults2.json", &fetchResp2) + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) + rows, err := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -867,9 +884,12 @@ func TestGetArrowBatches(t *testing.T) { fetchResp3 := cli_service.TFetchResultsResp{} loadTestData(t, "multipleFetch/FetchResults3.json", &fetchResp3) + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, nil) + rows, err := NewRows(ctx, nil, client, cfg, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -925,9 +945,12 @@ func TestGetArrowBatches(t *testing.T) { fetchResp1 := cli_service.TFetchResultsResp{} loadTestData(t, "zeroRows/zeroRowsFetchResult.json", &fetchResp1) + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, nil) + rows, err := NewRows(ctx, nil, client, cfg, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -949,9 +972,12 @@ func TestGetArrowBatches(t *testing.T) { loadTestData(t, "zeroRows/zeroRowsDirectResults.json", &executeStatementResp) executeStatementResp.DirectResults.ResultSet.Results.ArrowBatches = []*cli_service.TSparkArrowBatch{} + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + client := getSimpleClient([]cli_service.TFetchResultsResp{}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) + rows, err := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -1525,9 +1551,12 @@ func TestFetchResultPage_PropagatesGetNextPageError(t *testing.T) { client := getErroringClient(expectedErr) + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") + executeStatementResp := cli_service.TExecuteStatementResp{} cfg := config.WithDefaults() - rows, _ := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) + rows, _ := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults) // Call Next and ensure it propagates the error from getNextPage actualErr := rows.Next(nil) diff --git a/internal/rows/rowscanner/resultPageIterator.go b/internal/rows/rowscanner/resultPageIterator.go index 43d45bbb..144d7447 100644 --- a/internal/rows/rowscanner/resultPageIterator.go +++ b/internal/rows/rowscanner/resultPageIterator.go @@ -45,32 +45,34 @@ func (d Direction) String() string { // Create a new result page iterator. func NewResultPageIterator( + ctx context.Context, delimiter Delimiter, maxPageSize int64, opHandle *cli_service.TOperationHandle, closedOnServer bool, client cli_service.TCLIService, - connectionId string, - correlationId string, logger *dbsqllog.DBSQLLogger, ) ResultPageIterator { // delimiter and hasMoreRows are used to set up the point in the paginated // result set that this iterator starts from. return &resultPageIterator{ + ctx: ctx, Delimiter: delimiter, isFinished: closedOnServer, maxPageSize: maxPageSize, opHandle: opHandle, closedOnServer: closedOnServer, client: client, - connectionId: connectionId, - correlationId: correlationId, + connectionId: driverctx.ConnIdFromContext(ctx), + correlationId: driverctx.CorrelationIdFromContext(ctx), logger: logger, } } type resultPageIterator struct { + ctx context.Context + // Gives the parameters of the current result page Delimiter @@ -167,7 +169,6 @@ func (rpf *resultPageIterator) getNextPage() (*cli_service.TFetchResultsResp, er nextPageStartRow := rpf.Start() + rpf.Count() rpf.logger.Debug().Msgf("databricks: fetching result page for row %d", nextPageStartRow) - ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), rpf.connectionId), rpf.correlationId) // Keep fetching in the appropriate direction until we have the expected page. var fetchResult *cli_service.TFetchResultsResp @@ -175,7 +176,7 @@ func (rpf *resultPageIterator) getNextPage() (*cli_service.TFetchResultsResp, er for b = rpf.Contains(nextPageStartRow); !b; b = rpf.Contains(nextPageStartRow) { direction := rpf.Direction(nextPageStartRow) - err := rpf.checkDirectionValid(ctx, direction) + err := rpf.checkDirectionValid(direction) if err != nil { return nil, err } @@ -190,10 +191,10 @@ func (rpf *resultPageIterator) getNextPage() (*cli_service.TFetchResultsResp, er IncludeResultSetMetadata: &includeResultSetMetadata, } - fetchResult, err = rpf.client.FetchResults(ctx, &req) + fetchResult, err = rpf.client.FetchResults(rpf.ctx, &req) if err != nil { rpf.logger.Err(err).Msg("databricks: Rows instance failed to retrieve results") - return nil, dbsqlerrint.NewRequestError(ctx, errRowsResultFetchFailed, err) + return nil, dbsqlerrint.NewRequestError(rpf.ctx, errRowsResultFetchFailed, err) } rpf.Delimiter = NewDelimiter(fetchResult.Results.StartRowOffset, CountRows(fetchResult.Results)) @@ -218,7 +219,7 @@ func (rpf *resultPageIterator) Close() (err error) { OperationHandle: rpf.opHandle, } - _, err = rpf.client.CloseOperation(context.Background(), &req) + _, err = rpf.client.CloseOperation(rpf.ctx, &req) return err } } @@ -283,11 +284,11 @@ func CountRows(rowSet *cli_service.TRowSet) int64 { } // Check if trying to fetch in the specified direction creates an error condition. -func (rpf *resultPageIterator) checkDirectionValid(ctx context.Context, direction Direction) error { +func (rpf *resultPageIterator) checkDirectionValid(direction Direction) error { if direction == DirBack { // can't fetch rows previous to the start if rpf.Start() == 0 { - return dbsqlerrint.NewDriverError(ctx, ErrRowsFetchPriorToStart, nil) + return dbsqlerrint.NewDriverError(rpf.ctx, ErrRowsFetchPriorToStart, nil) } } else if direction == DirForward { // can't fetch past the end of the query results @@ -296,7 +297,7 @@ func (rpf *resultPageIterator) checkDirectionValid(ctx context.Context, directio } } else { rpf.logger.Error().Msgf(errRowsUnandledFetchDirection(direction.String())) - return dbsqlerrint.NewDriverError(ctx, errRowsUnandledFetchDirection(direction.String()), nil) + return dbsqlerrint.NewDriverError(rpf.ctx, errRowsUnandledFetchDirection(direction.String()), nil) } return nil }