Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,13 @@ func WithCloudFetch(useCloudFetch bool) ConnOption {
}
}

// WithHTTPClient sets the HTTP client used for cloud fetch downloads.
// A nil client is accepted; downloads then fall back to http.DefaultClient.
func WithHTTPClient(client *http.Client) ConnOption {
	return func(cfg *config.Config) {
		cfg.UserConfig.CloudFetchConfig.HTTPClient = client
	}
}

// WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10.
func WithMaxDownloadThreads(numThreads int) ConnOption {
return func(c *config.Config) {
Expand Down
36 changes: 36 additions & 0 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,42 @@ func TestNewConnector(t *testing.T) {
require.True(t, ok)
assert.False(t, coni.cfg.EnableMetricViewMetadata)
})

t.Run("Connector test WithCloudFetchHTTPClient sets custom client", func(t *testing.T) {
host := "databricks-host"
accessToken := "token"
httpPath := "http-path"
customClient := &http.Client{Timeout: 5 * time.Second}

con, err := NewConnector(
WithServerHostname(host),
WithAccessToken(accessToken),
WithHTTPPath(httpPath),
WithHTTPClient(customClient),
)
assert.Nil(t, err)

coni, ok := con.(*connector)
require.True(t, ok)
assert.Equal(t, customClient, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient)
})

t.Run("Connector test WithCloudFetchHTTPClient with nil client is accepted", func(t *testing.T) {
host := "databricks-host"
accessToken := "token"
httpPath := "http-path"

con, err := NewConnector(
WithServerHostname(host),
WithAccessToken(accessToken),
WithHTTPPath(httpPath),
)
assert.Nil(t, err)

coni, ok := con.(*connector)
require.True(t, ok)
assert.Nil(t, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient)
})
}

type mockRoundTripper struct{}
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ type CloudFetchConfig struct {
MaxFilesInMemory int
MinTimeToExpiry time.Duration
CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1)
HTTPClient *http.Client
}

func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {
Expand Down
14 changes: 11 additions & 3 deletions internal/rows/arrowbased/batchloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func NewCloudIPCStreamIterator(
startRowOffset: startRowOffset,
pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](),
downloadTasks: NewQueue[cloudFetchDownloadTask](),
httpClient: cfg.UserConfig.CloudFetchConfig.HTTPClient,
}

for _, link := range files {
Expand Down Expand Up @@ -140,6 +141,7 @@ type cloudIPCStreamIterator struct {
startRowOffset int64
pendingLinks Queue[cli_service.TSparkArrowResultLink]
downloadTasks Queue[cloudFetchDownloadTask]
httpClient *http.Client
}

var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
Expand All @@ -162,6 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
resultChan: make(chan cloudFetchDownloadTaskResult),
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
httpClient: bi.httpClient,
}
task.Run()
bi.downloadTasks.Enqueue(task)
Expand Down Expand Up @@ -210,6 +213,7 @@ type cloudFetchDownloadTask struct {
link *cli_service.TSparkArrowResultLink
resultChan chan cloudFetchDownloadTaskResult
speedThresholdMbps float64
httpClient *http.Client
}

func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) {
Expand Down Expand Up @@ -252,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() {
cft.link.StartRowOffset,
cft.link.RowCount,
)
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps)
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient)
if err != nil {
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
return
Expand Down Expand Up @@ -300,6 +304,7 @@ func fetchBatchBytes(
link *cli_service.TSparkArrowResultLink,
minTimeToExpiry time.Duration,
speedThresholdMbps float64,
httpClient *http.Client,
) (io.ReadCloser, error) {
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
return nil, errors.New(dbsqlerr.ErrLinkExpired)
Expand All @@ -317,9 +322,12 @@ func fetchBatchBytes(
}
}

if httpClient == nil {
httpClient = http.DefaultClient
}

startTime := time.Now()
client := http.DefaultClient
res, err := client.Do(req)
res, err := httpClient.Do(req)
if err != nil {
return nil, err
}
Expand Down
97 changes: 97 additions & 0 deletions internal/rows/arrowbased/batchloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,103 @@ func TestCloudFetchIterator(t *testing.T) {
assert.NotNil(t, err3)
assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound))
})

t.Run("should use custom HTTPClient when provided", func(t *testing.T) {
customClient := &http.Client{Timeout: 5 * time.Second}
requestCount := 0

handler = func(w http.ResponseWriter, r *http.Request) {
requestCount++
w.WriteHeader(http.StatusOK)
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
if err != nil {
panic(err)
}
}

startRowOffset := int64(100)

links := []*cli_service.TSparkArrowResultLink{
{
FileLink: server.URL,
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
StartRowOffset: startRowOffset,
RowCount: 1,
},
}

cfg := config.WithDefaults()
cfg.UseLz4Compression = false
cfg.MaxDownloadThreads = 1
cfg.UserConfig.CloudFetchConfig.HTTPClient = customClient

bi, err := NewCloudBatchIterator(
context.Background(),
links,
startRowOffset,
cfg,
)
assert.Nil(t, err)

// Verify custom client is passed through the iterator chain
wrapper, ok := bi.(*batchIterator)
assert.True(t, ok)
cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator)
assert.True(t, ok)
assert.Equal(t, customClient, cbi.httpClient)

// Fetch should work with custom client
sab1, nextErr := bi.Next()
assert.Nil(t, nextErr)
assert.NotNil(t, sab1)
assert.Greater(t, requestCount, 0) // Verify request was made
})

t.Run("should use http.DefaultClient when HTTPClient is nil", func(t *testing.T) {
handler = func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
if err != nil {
panic(err)
}
}

startRowOffset := int64(100)

links := []*cli_service.TSparkArrowResultLink{
{
FileLink: server.URL,
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
StartRowOffset: startRowOffset,
RowCount: 1,
},
}

cfg := config.WithDefaults()
cfg.UseLz4Compression = false
cfg.MaxDownloadThreads = 1
// HTTPClient is nil by default

bi, err := NewCloudBatchIterator(
context.Background(),
links,
startRowOffset,
cfg,
)
assert.Nil(t, err)

// Verify nil client is passed through
wrapper, ok := bi.(*batchIterator)
assert.True(t, ok)
cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator)
assert.True(t, ok)
assert.Nil(t, cbi.httpClient)

// Fetch should work (falls back to http.DefaultClient)
sab1, nextErr := bi.Next()
assert.Nil(t, nextErr)
assert.NotNil(t, sab1)
})
}

func generateArrowRecord() arrow.Record {
Expand Down