From 3a8bf4f1424c52878364a667857de0efa5624233 Mon Sep 17 00:00:00 2001 From: Lachlan Donald Date: Tue, 5 Aug 2025 14:11:40 +1000 Subject: [PATCH 1/2] fix: handle SSE error events from providers like Groq - Parse error events that use explicit SSE event types (event: error) - Extract only JSON data from error events, skipping event type lines - Reset error accumulator after each error to prevent data corruption - Add comprehensive test coverage for both OpenAI and Groq error formats - Maintain backward compatibility with OpenAI's simpler format Fixes parsing of error events from Groq API when invalid tool calls occur --- stream_reader.go | 42 +++++++++++--- stream_reader_test.go | 127 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 159 insertions(+), 10 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index 6faefe0a7..745004fc5 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -3,6 +3,7 @@ package openai import ( "bufio" "bytes" + "encoding/json" "fmt" "io" "net/http" @@ -40,6 +41,12 @@ func (stream *streamReader[T]) Recv() (response T, err error) { err = stream.unmarshaler.Unmarshal(rawLine, &response) if err != nil { + // If we get a JSON parsing error, it might be because we got an error event + // Check if we have accumulated error data + if _, ok := err.(*json.SyntaxError); ok && len(stream.errAccumulator.Bytes()) > 0 { + // We have error data, return a more informative error + return response, fmt.Errorf("failed to parse response (error event received): %s", string(stream.errAccumulator.Bytes())) + } return } return response, nil @@ -65,7 +72,18 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { if readErr != nil || hasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { - return nil, fmt.Errorf("error, %w", respErr.Error) + return nil, respErr.Error + } + // If we detected an error event but couldn't parse it, and the stream ended, + // return a more informative error. This handles cases where providers send + // error events that don't match the expected format and immediately close. + if hasErrorPrefix && readErr == io.EOF { + // Check if we have error data that failed to parse + errBytes := stream.errAccumulator.Bytes() + if len(errBytes) > 0 { + return nil, fmt.Errorf("failed to parse error event: %s", string(errBytes)) + } + return nil, fmt.Errorf("stream ended after error event") } return nil, readErr } @@ -73,20 +91,24 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { noSpaceLine := bytes.TrimSpace(rawLine) if errorPrefix.Match(noSpaceLine) { hasErrorPrefix = true - } - if !headerData.Match(noSpaceLine) || hasErrorPrefix { - if hasErrorPrefix { - noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) - } - writeErr := stream.errAccumulator.Write(noSpaceLine) + // Extract just the JSON part after "data: " prefix + // This handles both OpenAI format (data: {"error": ...}) and + // Groq format (event: error\ndata: {"error": ...}) + jsonData := headerData.ReplaceAll(noSpaceLine, nil) + writeErr := stream.errAccumulator.Write(jsonData) if writeErr != nil { return nil, writeErr } + continue + } + + // Skip non-data lines (e.g., "event: error" from Groq) + // This allows us to handle SSE streams that use explicit event types + if !headerData.Match(noSpaceLine) { emptyMessagesCount++ if emptyMessagesCount > stream.emptyMessagesLimit { return nil, ErrTooManyEmptyStreamMessages } - continue } @@ -111,6 +133,10 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { errResp = nil } + // Reset the error accumulator for future error events + // A new accumulator is created to avoid potential interface issues + stream.errAccumulator = utils.NewErrorAccumulator() + return } diff --git a/stream_reader_test.go b/stream_reader_test.go index 449a14b43..fd901e0e6 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -54,11 +54,12 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { stream := &streamReader[ChatCompletionStreamResponse]{ - reader: bufio.NewReader(bytes.NewReader([]byte("\n"))), + reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"error\": {\"message\": \"test error\"}}\n"))), errAccumulator: &utils.DefaultErrorAccumulator{ Buffer: &test.FailingErrorBuffer{}, }, - unmarshaler: &utils.JSONUnmarshaler{}, + unmarshaler: &utils.JSONUnmarshaler{}, + emptyMessagesLimit: 5, } _, err := stream.Recv() checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) @@ -76,3 +77,125 @@ func TestStreamReaderRecvRaw(t *testing.T) { t.Fatalf("Did not return raw line: %v", string(rawLine)) } } + +func TestStreamReaderParsesErrorEvents(t *testing.T) { + // Test case simulating Groq's error event format + errorEvent := `event: error +data: {"error":{"message":"Invalid tool_call: tool \"name_unknown\" does not exist.","type":"invalid_request_error","code":"invalid_tool_call"}} + +` + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte(errorEvent))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + emptyMessagesLimit: 5, + } + + // Process the error event + _, err := stream.Recv() + if err == nil { + t.Fatal("Expected error but got nil") + } + + // Verify it's an APIError + apiErr, ok := err.(*APIError) + if !ok { + t.Fatalf("Expected APIError type but got %T: %v", err, err) + } + + // Verify the error fields are correctly parsed + if apiErr.Message != "Invalid tool_call: tool \"name_unknown\" does not exist." { + t.Errorf("Unexpected error message: %s", apiErr.Message) + } + if apiErr.Type != "invalid_request_error" { + t.Errorf("Unexpected error type: %s", apiErr.Type) + } + if apiErr.Code != "invalid_tool_call" { + t.Errorf("Unexpected error code: %v", apiErr.Code) + } +} + +func TestStreamReaderHandlesErrorEventWithExtraData(t *testing.T) { + // Test case with error event followed by more data + errorEvent := `data: {"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]} +event: error +data: {"error":{"message":"Stream interrupted","type":"server_error"}} +data: [DONE] +` + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte(errorEvent))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + emptyMessagesLimit: 5, + } + + // First recv should return the chat completion + resp, err := stream.Recv() + if err != nil { + t.Fatalf("First recv failed: %v", err) + } + if resp.ID != "chatcmpl-123" { + t.Errorf("Unexpected response ID: %s", resp.ID) + } + + // Second recv should return the error + _, err = stream.Recv() + if err == nil { + t.Fatal("Expected error but got nil") + } + + // Verify it's an APIError + apiErr, ok := err.(*APIError) + if !ok { + t.Fatalf("Expected APIError type but got %T: %v", err, err) + } + + if apiErr.Message != "Stream interrupted" { + t.Errorf("Unexpected error message: %s", apiErr.Message) + } +} + +func TestStreamReaderResetsErrorAccumulator(t *testing.T) { + // Test that error accumulator is reset after processing an error + multipleErrors := `event: error +data: {"error":{"message":"First error","type":"error_type_1"}} + +event: error +data: {"error":{"message":"Second error","type":"error_type_2"}} +` + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte(multipleErrors))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + emptyMessagesLimit: 5, + } + + // First recv should return the first error + _, err1 := stream.Recv() + if err1 == nil { + t.Fatal("Expected first error but got nil") + } + apiErr1, ok := err1.(*APIError) + if !ok { + t.Fatalf("Expected APIError type but got %T: %v", err1, err1) + } + if apiErr1.Message != "First error" { + t.Errorf("Unexpected first error message: %s", apiErr1.Message) + } + + // Second recv should return the second error (not a concatenation) + _, err2 := stream.Recv() + if err2 == nil { + t.Fatal("Expected second error but got nil") + } + apiErr2, ok := err2.(*APIError) + if !ok { + t.Fatalf("Expected APIError type but got %T: %v", err2, err2) + } + if apiErr2.Message != "Second error" { + t.Errorf("Unexpected second error message: %s", apiErr2.Message) + } + if apiErr2.Type != "error_type_2" { + t.Errorf("Unexpected second error type: %s", apiErr2.Type) + } +} From 366dd4642bc275f3a3a1dd8bebc06711c5c39067 Mon Sep 17 00:00:00 2001 From: Lachlan Donald Date: Tue, 5 Aug 2025 14:36:09 +1000 Subject: [PATCH 2/2] fix: address golangci-lint issues - Replace type assertions with errors.As for errorlint compliance - Break long lines to stay under 120 character limit - Maintain all existing functionality and test coverage --- stream_reader.go | 7 +++++-- stream_reader_test.go | 19 ++++++++++--------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index 745004fc5..deb3dad49 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -43,9 +44,11 @@ func (stream *streamReader[T]) Recv() (response T, err error) { if err != nil { // If we get a JSON parsing error, it might be because we got an error event // Check if we have accumulated error data - if _, ok := err.(*json.SyntaxError); ok && len(stream.errAccumulator.Bytes()) > 0 { + var syntaxErr *json.SyntaxError + if errors.As(err, &syntaxErr) && len(stream.errAccumulator.Bytes()) > 0 { // We have error data, return a more informative error - return response, fmt.Errorf("failed to parse response (error event received): %s", string(stream.errAccumulator.Bytes())) + return response, fmt.Errorf("failed to parse response (error event received): %s", + string(stream.errAccumulator.Bytes())) } return } diff --git a/stream_reader_test.go b/stream_reader_test.go index fd901e0e6..49dec7ee5 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -81,7 +81,8 @@ func TestStreamReaderRecvRaw(t *testing.T) { func TestStreamReaderParsesErrorEvents(t *testing.T) { // Test case simulating Groq's error event format errorEvent := `event: error -data: {"error":{"message":"Invalid tool_call: tool \"name_unknown\" does not exist.","type":"invalid_request_error","code":"invalid_tool_call"}} +data: {"error":{"message":"Invalid tool_call: tool \"name_unknown\" does not exist.",` + + `"type":"invalid_request_error","code":"invalid_tool_call"}} ` stream := &streamReader[ChatCompletionStreamResponse]{ @@ -98,8 +99,8 @@ data: {"error":{"message":"Invalid tool_call: tool \"name_unknown\" does not exi } // Verify it's an APIError - apiErr, ok := err.(*APIError) - if !ok { + var apiErr *APIError + if !errors.As(err, &apiErr) { t.Fatalf("Expected APIError type but got %T: %v", err, err) } @@ -145,8 +146,8 @@ data: [DONE] } // Verify it's an APIError - apiErr, ok := err.(*APIError) - if !ok { + var apiErr *APIError + if !errors.As(err, &apiErr) { t.Fatalf("Expected APIError type but got %T: %v", err, err) } @@ -175,8 +176,8 @@ data: {"error":{"message":"Second error","type":"error_type_2"}} if err1 == nil { t.Fatal("Expected first error but got nil") } - apiErr1, ok := err1.(*APIError) - if !ok { + var apiErr1 *APIError + if !errors.As(err1, &apiErr1) { t.Fatalf("Expected APIError type but got %T: %v", err1, err1) } if apiErr1.Message != "First error" { @@ -188,8 +189,8 @@ data: {"error":{"message":"Second error","type":"error_type_2"}} if err2 == nil { t.Fatal("Expected second error but got nil") } - apiErr2, ok := err2.(*APIError) - if !ok { + var apiErr2 *APIError + if !errors.As(err2, &apiErr2) { t.Fatalf("Expected APIError type but got %T: %v", err2, err2) } if apiErr2.Message != "Second error" {