diff --git a/connector.go b/connector.go index 53908b4c..8842a6b0 100644 --- a/connector.go +++ b/connector.go @@ -269,6 +269,13 @@ func WithCloudFetch(useCloudFetch bool) ConnOption { } } +// WithHTTPClient allows a custom http client to be used for cloud fetch. Default is http.DefaultClient. +func WithHTTPClient(httpClient *http.Client) ConnOption { + return func(c *config.Config) { + c.UserConfig.CloudFetchConfig.HTTPClient = httpClient + } +} + // WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10. func WithMaxDownloadThreads(numThreads int) ConnOption { return func(c *config.Config) { diff --git a/connector_test.go b/connector_test.go index 57554b98..a5e8632f 100644 --- a/connector_test.go +++ b/connector_test.go @@ -246,6 +246,42 @@ func TestNewConnector(t *testing.T) { require.True(t, ok) assert.False(t, coni.cfg.EnableMetricViewMetadata) }) + + t.Run("Connector test WithHTTPClient sets custom client", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + customClient := &http.Client{Timeout: 5 * time.Second} + + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + WithHTTPClient(customClient), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.Equal(t, customClient, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient) + }) + + t.Run("Connector test HTTPClient is nil when WithHTTPClient is not used", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.Nil(t, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient) + }) } type mockRoundTripper struct{} diff --git a/internal/config/config.go b/internal/config/config.go index 
67437a9c..e13cb98f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -479,6 +479,7 @@ type CloudFetchConfig struct { MaxFilesInMemory int MinTimeToExpiry time.Duration CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1) + HTTPClient *http.Client } func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig { diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index d26d8a4a..4d718b9e 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -40,6 +40,7 @@ func NewCloudIPCStreamIterator( startRowOffset: startRowOffset, pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), downloadTasks: NewQueue[cloudFetchDownloadTask](), + httpClient: cfg.UserConfig.CloudFetchConfig.HTTPClient, } for _, link := range files { @@ -140,6 +141,7 @@ type cloudIPCStreamIterator struct { startRowOffset int64 pendingLinks Queue[cli_service.TSparkArrowResultLink] downloadTasks Queue[cloudFetchDownloadTask] + httpClient *http.Client } var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil) @@ -162,6 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { resultChan: make(chan cloudFetchDownloadTaskResult), minTimeToExpiry: bi.cfg.MinTimeToExpiry, speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps, + httpClient: bi.httpClient, } task.Run() bi.downloadTasks.Enqueue(task) @@ -210,6 +213,7 @@ type cloudFetchDownloadTask struct { link *cli_service.TSparkArrowResultLink resultChan chan cloudFetchDownloadTaskResult speedThresholdMbps float64 + httpClient *http.Client } func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { @@ -252,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.StartRowOffset, cft.link.RowCount, ) - data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps) + data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, 
cft.speedThresholdMbps, cft.httpClient) if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} return @@ -300,6 +304,7 @@ func fetchBatchBytes( link *cli_service.TSparkArrowResultLink, minTimeToExpiry time.Duration, speedThresholdMbps float64, + httpClient *http.Client, ) (io.ReadCloser, error) { if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { return nil, errors.New(dbsqlerr.ErrLinkExpired) @@ -317,9 +322,12 @@ func fetchBatchBytes( } } + if httpClient == nil { + httpClient = http.DefaultClient + } + startTime := time.Now() - client := http.DefaultClient - res, err := client.Do(req) + res, err := httpClient.Do(req) if err != nil { return nil, err } diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index b018eb6d..c30e0e0b 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -253,6 +253,103 @@ func TestCloudFetchIterator(t *testing.T) { assert.NotNil(t, err3) assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) }) + + t.Run("should use custom HTTPClient when provided", func(t *testing.T) { + customClient := &http.Client{Timeout: 5 * time.Second} + requestCount := 0 + + handler = func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.UserConfig.CloudFetchConfig.HTTPClient = customClient + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + assert.Nil(t, err) + + // Verify custom 
client is passed through the iterator chain + wrapper, ok := bi.(*batchIterator) + assert.True(t, ok) + cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) + assert.True(t, ok) + assert.Equal(t, customClient, cbi.httpClient) + + // Fetch should work with custom client + sab1, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab1) + assert.Greater(t, requestCount, 0) // Verify request was made + }) + + t.Run("should use http.DefaultClient when HTTPClient is nil", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + // HTTPClient is nil by default + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + assert.Nil(t, err) + + // Verify nil client is passed through + wrapper, ok := bi.(*batchIterator) + assert.True(t, ok) + cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) + assert.True(t, ok) + assert.Nil(t, cbi.httpClient) + + // Fetch should work (falls back to http.DefaultClient) + sab1, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab1) + }) } func generateArrowRecord() arrow.Record {