diff --git a/connection.go b/connection.go index 6464a562..b3ee9482 100644 --- a/connection.go +++ b/connection.go @@ -117,6 +117,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name stagingErr := c.execStagingOperation(exStmtResp, ctx) if exStmtResp != nil && exStmtResp.OperationHandle != nil { + // we have an operation id so update the logger + log, _ := client.LoggerAndContext(ctx, exStmtResp) + // since we have an operation handle we can close the operation if necessary alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) @@ -167,7 +170,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } corrId := driverctx.CorrelationIdFromContext(ctx) - rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, logger.FromContext(ctx)) return rows, err @@ -340,7 +343,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) + log := logger.AddContext(logger.FromContext(ctx), c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) var statusResp *cli_service.TGetOperationStatusResp ctx = driverctx.NewContextWithConnId(ctx, c.id) newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) @@ -559,7 +562,7 @@ func (c *conn) execStagingOperation( } if len(driverctx.StagingPathsFromContext(ctx)) != 0 { - row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, logger.FromContext(ctx)) if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading row.", err) } diff --git a/connector.go b/connector.go index 96a88319..a9470303 100644 --- a/connector.go +++ b/connector.go @@ -61,7 +61,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { client: tclient, session: session, } - log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "") + log := logger.AddContext(logger.FromContext(ctx), conn.id, driverctx.CorrelationIdFromContext(ctx), "") log.Info().Msgf("connect: host=%s port=%d httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath) diff --git a/doc.go b/doc.go index 9463d771..f6ea19ff 100644 --- a/doc.go +++ b/doc.go @@ -154,6 +154,18 @@ The result log may look like this: {"level":"debug","connId":"01ed6545-5669-1ec7-8c7e-6d8a1ea0ab16","corrId":"workflow-example","queryId":"01ed6545-57cc-188a-bfc5-d9c0eaf8e189","time":1668558402,"message":"Run Main elapsed time: 1.298712292s"} +You may customize the log by passing it using Zerolog's context support. This allows customziation of the output, as well as inclusion of additionl metadata. + +For example, + + log := zerolog.New(DefaultLogOutput).With("service_id", "workflow-example")).Logger() + ctx = log.WithContext(context.Background()) + ... + db, err := sql.Open("databricks", "") + ... + rows, err := db.QueryContext(ctx, `select * from sometable`) + ... + # Programmatically Retrieving Connection and Query Id Use the driverctx package under driverctx/ctx.go to add callbacks to the query context to receive the connection id and query id. diff --git a/internal/client/client.go b/internal/client/client.go index fda1053e..36bdd54f 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -403,7 +403,7 @@ func LoggerAndContext(ctx context.Context, c any) (*logger.DBSQLLogger, context. queryId = guidFromHasOpHandle(c) ctx = driverctx.NewContextWithQueryId(ctx, queryId) } - log := logger.WithContext(connId, corrId, queryId) + log := logger.AddContext(logger.FromContext(ctx), connId, corrId, queryId) return log, ctx } diff --git a/internal/rows/rows.go b/internal/rows/rows.go index c9581e21..cc77c206 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -74,15 +74,15 @@ func NewRows( client cli_service.TCLIService, config *config.Config, directResults *cli_service.TSparkDirectResults, + logger *dbsqllog.DBSQLLogger, ) (driver.Rows, dbsqlerr.DBError) { - var logger *dbsqllog.DBSQLLogger var ctx context.Context if opHandle != nil { - logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) + logger = dbsqllog.AddContext(logger, connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID)) } else { - logger = dbsqllog.WithContext(connId, correlationId, "") + logger = dbsqllog.AddContext(logger, connId, correlationId, "") ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId) } diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index b11868f4..e3d724f4 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -413,7 +413,7 @@ func TestColumnsWithDirectResults(t *testing.T) { client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount) - d, err := NewRows("", "", nil, client, nil, nil) + d, err := NewRows("", "", nil, client, nil, nil, nil) assert.Nil(t, err) rowSet := d.(*rows) @@ -708,7 +708,7 @@ func TestRowsCloseOptimization(t *testing.T) { } opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte{'f', 'o'}}} - rowSet, _ := NewRows("", "", opHandle, client, nil, nil) + rowSet, _ := NewRows("", "", opHandle, client, nil, nil, nil) // rowSet has no direct results calling Close should result in call to client to close operation err := rowSet.Close() @@ -721,7 +721,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } closeCount = 0 - rowSet, _ = NewRows("", "", opHandle, client, nil, directResults) + rowSet, _ = NewRows("", "", opHandle, client, nil, directResults, nil) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 1, closeCount) @@ -734,7 +734,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{Schema: &cli_service.TTableSchema{}}, ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } - rowSet, _ = NewRows("", "", opHandle, client, nil, directResults) + rowSet, _ = NewRows("", "", opHandle, client, nil, directResults, nil) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 0, closeCount) @@ -799,7 +799,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) + rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -869,7 +869,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, nil) + rows, err := NewRows("connId", "corrId", nil, client, cfg, nil, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -927,7 +927,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, nil) + rows, err := NewRows("connId", "corrId", nil, client, cfg, nil, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -951,7 +951,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{}) cfg := config.WithDefaults() - rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults) + rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) diff --git a/logger/logger.go b/logger/logger.go index 683501a1..f2edfe37 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,6 +1,7 @@ package logger import ( + "context" "io" "os" "runtime" @@ -123,9 +124,27 @@ func Err(err error) *zerolog.Event { return Logger.Err(err) } +// FromContext returns a DBSQLLogger from the provided context. If no logger is +// found, the default logger is returned. +func FromContext(ctx context.Context) *DBSQLLogger { + l := zerolog.Ctx(ctx) + if l == zerolog.DefaultContextLogger { + return Logger + } + return &DBSQLLogger{*l} +} + +// AddContext sets connectionId, correlationId, and queryId as fields on the provided logger. +func AddContext(l *DBSQLLogger, connectionId string, correlationId string, queryId string) *DBSQLLogger { + if l == nil { + l = Logger + } + return &DBSQLLogger{l.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()} +} + // WithContext sets connectionId, correlationId, and queryId to be used as fields. func WithContext(connectionId string, correlationId string, queryId string) *DBSQLLogger { - return &DBSQLLogger{Logger.With().Str("connId", connectionId).Str("corrId", correlationId).Str("queryId", queryId).Logger()} + return AddContext(nil, connectionId, correlationId, queryId) } // Track is a convenience function to track time spent