diff --git a/api_client.go b/api_client.go index 40eec042..ec3c6c62 100644 --- a/api_client.go +++ b/api_client.go @@ -60,15 +60,19 @@ func sendStreamRequest[T responseStream[R], R any](ctx context.Context, ac *apiC var cancel context.CancelFunc if timeout != nil && *timeout > 0*time.Second && isTimeoutBeforeDeadline(ctx, *timeout) { requestContext, cancel = context.WithTimeout(ctx, *timeout) - defer cancel() } req = req.WithContext(requestContext) resp, err := doRequest(ac, req) if err != nil { + if cancel != nil { + cancel() + } return err } + // Transfer cancel ownership to the stream; the iterator will call it when done. + output.cancel = cancel // resp.Body will be closed by the iterator return deserializeStreamResponse(resp, output) } @@ -375,15 +379,19 @@ func deserializeUnaryResponse(resp *http.Response) (map[string]any, error) { } type responseStream[R any] struct { - r *bufio.Scanner - rc io.ReadCloser - h http.Header + r *bufio.Scanner + rc io.ReadCloser + h http.Header + cancel context.CancelFunc // cancels the request timeout context; called when iterator completes } func iterateResponseStream[R any](rs *responseStream[R], responseConverter func(responseMap map[string]any) (*R, error)) iter.Seq2[*R, error] { return func(yield func(*R, error) bool) { defer func() { - // Close the response body range over function is done. + // Cancel the request timeout context first, then close the response body. + if rs.cancel != nil { + rs.cancel() + } if err := rs.rc.Close(); err != nil { log.Printf("Error closing response body: %v", err) } diff --git a/api_client_test.go b/api_client_test.go index ddd7253f..919c6444 100644 --- a/api_client_test.go +++ b/api_client_test.go @@ -647,6 +647,49 @@ func TestSendStreamRequest(t *testing.T) { } } +func TestSendStreamRequestTimeoutNotCancelledEarly(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + flusher := w.(http.Flusher) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + for i := 1; i <= 4; i++ { + fmt.Fprintf(w, "data:{\"chunk\":%d}\n\n", i) + flusher.Flush() + time.Sleep(10 * time.Millisecond) + } + })) + defer ts.Close() + + ac := &apiClient{clientConfig: &ClientConfig{ + Backend: BackendGeminiAPI, + HTTPOptions: HTTPOptions{ + BaseURL: ts.URL, + APIVersion: "v0", + Headers: http.Header{"User-Agent": {"test-user-agent"}, "X-Goog-Api-Key": {"test-api-key"}}, + }, + HTTPClient: ts.Client(), + }} + + requestTimeout := 5 * time.Second + var output responseStream[map[string]any] + if err := sendStreamRequest(context.Background(), ac, "test", "POST", map[string]any{"key": "value"}, &HTTPOptions{Timeout: &requestTimeout, BaseURL: ac.clientConfig.HTTPOptions.BaseURL}, &output); err != nil { + t.Fatalf("sendStreamRequest() error = %v", err) + } + + var got []map[string]any + for resp, err := range iterateResponseStream(&output, func(m map[string]any) (*map[string]any, error) { return &m, nil }) { + if err != nil { + t.Fatalf("iterateResponseStream() error = %v", err) + } + got = append(got, *resp) + } + + want := []map[string]any{{"chunk": float64(1)}, {"chunk": float64(2)}, {"chunk": float64(3)}, {"chunk": float64(4)}} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("stream response mismatch (-want +got):\n%s", diff) + } +} + func TestMapToStruct(t *testing.T) { testCases := []struct { name string