From 2a029f8549be09e94c24e8d3801107cb93210ece Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 6 Nov 2025 06:53:15 +0000 Subject: [PATCH 1/8] fix: no more multiple finish reason in stream Signed-off-by: ayushag --- .../protocols/openai/chat_completions/jail.rs | 65 ++++++++- lib/llm/tests/test_streaming_tool_parsers.rs | 133 +++++++++++++++++- 2 files changed, 195 insertions(+), 3 deletions(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index de394b1e4e..7efb47a1b4 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -370,12 +370,17 @@ impl ChoiceJailState { struct ChoiceJailStateCollection { /// Vec of states, always kept sorted by choice index for deterministic iteration states: Vec, + /// Track if any choice has emitted a finish_reason (per choice index) + finish_reason_emitted: std::collections::HashMap, } impl ChoiceJailStateCollection { /// Create a new empty collection fn new() -> Self { - Self { states: Vec::new() } + Self { + states: Vec::new(), + finish_reason_emitted: std::collections::HashMap::new(), + } } /// Get or create state for a choice index @@ -394,6 +399,19 @@ impl ChoiceJailStateCollection { } } } + + /// Check if a finish_reason has already been emitted for this choice + fn has_emitted_finish_reason(&self, index: u32) -> bool { + self.finish_reason_emitted + .get(&index) + .copied() + .unwrap_or(false) + } + + /// Mark that a finish_reason has been emitted for this choice + fn mark_finish_reason_emitted(&mut self, index: u32) { + self.finish_reason_emitted.insert(index, true); + } } /// Emission mode for handling multiple choices @@ -456,6 +474,17 @@ impl JailedStream { // Process each choice independently using the new architecture for choice in &chat_response.choices { + // if we've already emitted a finish_reason for this choice, + // skip any subsequent chunks with finish_reason + if 
choice.finish_reason.is_some() && choice_states.has_emitted_finish_reason(choice.index) { + tracing::debug!( + "Skipping chunk with finish_reason {:?} for choice {} - already emitted finish_reason", + choice.finish_reason, + choice.index + ); + continue; + } + if let Some(ref content) = choice.delta.content { let choice_state = choice_states.get_or_create_state(choice.index); @@ -509,8 +538,16 @@ impl JailedStream { last_annotated_event.clone(), last_annotated_comment.clone(), ); - let responses = self.emit_choice_emissions(tool_content_emissions, chat_response, preserved_metadata); + let responses = self.emit_choice_emissions(tool_content_emissions.clone(), chat_response, preserved_metadata); for emitted_response in responses { + // Mark finish_reason as emitted for choices that have it + if let Some(ref data) = emitted_response.data { + for choice in &data.choices { + if choice.finish_reason.is_some() { + choice_states.mark_finish_reason_emitted(choice.index); + } + } + } yield emitted_response; } } @@ -524,6 +561,14 @@ impl JailedStream { ); let responses = self.emit_choice_emissions(trailing_emissions, chat_response, preserved_metadata); for emitted_response in responses { + // Mark finish_reason as emitted for choices that have it + if let Some(ref data) = emitted_response.data { + for choice in &data.choices { + if choice.finish_reason.is_some() { + choice_states.mark_finish_reason_emitted(choice.index); + } + } + } yield emitted_response; } } @@ -533,6 +578,14 @@ impl JailedStream { let current_metadata = (response.id.clone(), response.event.clone(), response.comment.clone()); let responses = self.emit_choice_emissions(passthrough_emissions, chat_response, current_metadata); for emitted_response in responses { + // Mark finish_reason as emitted for choices that have it + if let Some(ref data) = emitted_response.data { + for choice in &data.choices { + if choice.finish_reason.is_some() { + choice_states.mark_finish_reason_emitted(choice.index); + } + } + } 
yield emitted_response; } } @@ -568,6 +621,14 @@ impl JailedStream { let final_metadata = (last_annotated_id, last_annotated_event, last_annotated_comment); let responses = self.emit_choice_emissions(final_emissions, &dummy_response, final_metadata); for emitted_response in responses { + // Mark finish_reason as emitted for choices that have it + if let Some(ref data) = emitted_response.data { + for choice in &data.choices { + if choice.finish_reason.is_some() { + choice_states.mark_finish_reason_emitted(choice.index); + } + } + } yield emitted_response; } } diff --git a/lib/llm/tests/test_streaming_tool_parsers.rs b/lib/llm/tests/test_streaming_tool_parsers.rs index 5f7622c355..a4735a8521 100644 --- a/lib/llm/tests/test_streaming_tool_parsers.rs +++ b/lib/llm/tests/test_streaming_tool_parsers.rs @@ -26,7 +26,7 @@ across backends. */ -use dynamo_async_openai::types::ChatChoiceStream; +use dynamo_async_openai::types::{ChatChoiceStream, FinishReason}; use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse; use dynamo_runtime::protocols::annotated::Annotated; @@ -304,6 +304,18 @@ mod tests { aggregated.has_tool_calls, expected_has_tool_calls, "Tool calls presence should match expected value" ); + + // Verify last chunk has Stop finish_reason for no-tool cases + let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + if let Some(data) = &last_chunk.data + && let Some(choice) = data.choices.first() + { + assert_eq!( + choice.finish_reason, + Some(FinishReason::Stop), + "Last chunk should have Stop finish_reason for non-tool call case" + ); + } } #[tokio::test] @@ -360,6 +372,22 @@ mod tests { // Verify tool calls assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); + + // Verify the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) + let last_chunk = output_chunks.last().expect("Should have at least one chunk"); 
+ if let Some(data) = &last_chunk.data + && let Some(choice) = data.choices.first() + { + assert_eq!( + choice.finish_reason, + Some(FinishReason::ToolCalls), + "Last chunk should have ToolCalls finish_reason (empty Stop chunks should be filtered)" + ); + assert!( + choice.delta.tool_calls.is_some(), + "Last chunk with ToolCalls finish_reason must have tool_calls data" + ); + } } #[tokio::test] @@ -403,6 +431,18 @@ mod tests { aggregated.has_tool_calls, expected_has_tool_calls, "Tool calls presence should match expected value" ); + + // Verify last chunk has Stop finish_reason for no-tool cases + let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + if let Some(data) = &last_chunk.data + && let Some(choice) = data.choices.first() + { + assert_eq!( + choice.finish_reason, + Some(FinishReason::Stop), + "Last chunk should have Stop finish_reason for non-tool call case" + ); + } } #[tokio::test] @@ -455,6 +495,22 @@ mod tests { // Verify tool calls assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); + + // Verify the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) + let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + if let Some(data) = &last_chunk.data + && let Some(choice) = data.choices.first() + { + assert_eq!( + choice.finish_reason, + Some(FinishReason::ToolCalls), + "Last chunk should have ToolCalls finish_reason (empty Stop chunks should be filtered)" + ); + assert!( + choice.delta.tool_calls.is_some(), + "Last chunk with ToolCalls finish_reason must have tool_calls data" + ); + } } #[tokio::test] @@ -511,6 +567,18 @@ mod tests { ); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); + + // Verify last chunk has Stop finish_reason for no-tool cases + let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + if let Some(data) = &last_chunk.data + && let Some(choice) = data.choices.first() + { + assert_eq!( 
+ choice.finish_reason, + Some(FinishReason::Stop), + "Last chunk should have Stop finish_reason for non-tool call case" + ); + } } #[tokio::test] @@ -567,6 +635,23 @@ mod tests { ); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); + + // Verify there is a chunk with ToolCalls finish_reason and tool call data + let has_tool_calls_chunk = output_chunks.iter().any(|chunk| { + chunk + .data + .as_ref() + .and_then(|d| d.choices.first()) + .map(|c| { + c.finish_reason == Some(FinishReason::ToolCalls) + && c.delta.tool_calls.is_some() + }) + .unwrap_or(false) + }); + assert!( + has_tool_calls_chunk, + "Should have a chunk with ToolCalls finish_reason and tool_calls data" + ); } #[tokio::test] @@ -620,6 +705,18 @@ mod tests { ); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); + + // Verify last chunk has Stop finish_reason for no-tool cases + let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + if let Some(data) = &last_chunk.data + && let Some(choice) = data.choices.first() + { + assert_eq!( + choice.finish_reason, + Some(FinishReason::Stop), + "Last chunk should have Stop finish_reason for non-tool call case" + ); + } } #[tokio::test] @@ -674,6 +771,23 @@ mod tests { "Tool calls presence should match expected value" ); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); + + // Verify there is a chunk with ToolCalls finish_reason and tool call data + let has_tool_calls_chunk = output_chunks.iter().any(|chunk| { + chunk + .data + .as_ref() + .and_then(|d| d.choices.first()) + .map(|c| { + c.finish_reason == Some(FinishReason::ToolCalls) + && c.delta.tool_calls.is_some() + }) + .unwrap_or(false) + }); + assert!( + has_tool_calls_chunk, + "Should have a chunk with ToolCalls finish_reason and tool_calls data" + ); } #[tokio::test] @@ -726,5 +840,22 @@ mod tests { // Verify tool calls assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); + + // Verify 
the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) + let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + if let Some(data) = &last_chunk.data + && let Some(choice) = data.choices.first() + { + assert_eq!( + choice.finish_reason, + Some(FinishReason::ToolCalls), + "Last chunk should have ToolCalls finish_reason (empty Stop chunks should be filtered)" + ); + assert!( + choice.delta.tool_calls.is_some(), + "Last chunk with ToolCalls finish_reason must have tool_calls data" + ); + } } + } From 6f2138ec0dca29e9bf93f5c504812aa1ed715bc4 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 6 Nov 2025 07:02:21 +0000 Subject: [PATCH 2/8] fix: fmt Signed-off-by: ayushag --- lib/llm/tests/test_streaming_tool_parsers.rs | 35 +++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/lib/llm/tests/test_streaming_tool_parsers.rs b/lib/llm/tests/test_streaming_tool_parsers.rs index a4735a8521..fcc56e0517 100644 --- a/lib/llm/tests/test_streaming_tool_parsers.rs +++ b/lib/llm/tests/test_streaming_tool_parsers.rs @@ -306,7 +306,9 @@ mod tests { ); // Verify last chunk has Stop finish_reason for no-tool cases - let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + let last_chunk = output_chunks + .last() + .expect("Should have at least one chunk"); if let Some(data) = &last_chunk.data && let Some(choice) = data.choices.first() { @@ -374,7 +376,9 @@ mod tests { assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); // Verify the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) - let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + let last_chunk = output_chunks + .last() + .expect("Should have at least one chunk"); if let Some(data) = &last_chunk.data && let Some(choice) = data.choices.first() { @@ -433,7 +437,9 @@ mod tests { ); // Verify last chunk has Stop finish_reason for no-tool cases - let 
last_chunk = output_chunks.last().expect("Should have at least one chunk"); + let last_chunk = output_chunks + .last() + .expect("Should have at least one chunk"); if let Some(data) = &last_chunk.data && let Some(choice) = data.choices.first() { @@ -497,7 +503,9 @@ mod tests { assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); // Verify the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) - let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + let last_chunk = output_chunks + .last() + .expect("Should have at least one chunk"); if let Some(data) = &last_chunk.data && let Some(choice) = data.choices.first() { @@ -569,7 +577,9 @@ mod tests { assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); // Verify last chunk has Stop finish_reason for no-tool cases - let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + let last_chunk = output_chunks + .last() + .expect("Should have at least one chunk"); if let Some(data) = &last_chunk.data && let Some(choice) = data.choices.first() { @@ -643,8 +653,7 @@ mod tests { .as_ref() .and_then(|d| d.choices.first()) .map(|c| { - c.finish_reason == Some(FinishReason::ToolCalls) - && c.delta.tool_calls.is_some() + c.finish_reason == Some(FinishReason::ToolCalls) && c.delta.tool_calls.is_some() }) .unwrap_or(false) }); @@ -707,7 +716,9 @@ mod tests { assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); // Verify last chunk has Stop finish_reason for no-tool cases - let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + let last_chunk = output_chunks + .last() + .expect("Should have at least one chunk"); if let Some(data) = &last_chunk.data && let Some(choice) = data.choices.first() { @@ -779,8 +790,7 @@ mod tests { .as_ref() .and_then(|d| d.choices.first()) .map(|c| { - c.finish_reason == Some(FinishReason::ToolCalls) - && c.delta.tool_calls.is_some() + 
c.finish_reason == Some(FinishReason::ToolCalls) && c.delta.tool_calls.is_some() }) .unwrap_or(false) }); @@ -842,7 +852,9 @@ mod tests { assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); // Verify the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) - let last_chunk = output_chunks.last().expect("Should have at least one chunk"); + let last_chunk = output_chunks + .last() + .expect("Should have at least one chunk"); if let Some(data) = &last_chunk.data && let Some(choice) = data.choices.first() { @@ -857,5 +869,4 @@ mod tests { ); } } - } From 29351afce9a1b9a344aa921e65b2ef37b86db921 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 6 Nov 2025 18:58:08 +0000 Subject: [PATCH 3/8] chore: vllm behaviour unleashed Signed-off-by: ayushag --- lib/llm/src/preprocessor.rs | 5 +- .../protocols/openai/chat_completions/jail.rs | 116 ++++----- lib/llm/tests/test_jail.rs | 186 ++++++++++---- lib/llm/tests/test_streaming_tool_parsers.rs | 235 ++++++++---------- 4 files changed, 294 insertions(+), 248 deletions(-) diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 9fda891fec..cfd3ab2e97 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -764,7 +764,10 @@ impl OpenAIPreprocessor { let jail = JailedStream::builder() .tool_call_parser(tool_call_parser) .build(); - jail.apply(stream) + let jailed_stream = jail.apply(stream); + + // Post-process to set finish reason to ToolCalls for the last chunk if any tool calls were emitted + JailedStream::fix_finish_reason(jailed_stream) } // Motivation: Each transformation on the stream should be a separate step to allow for more flexibility diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index 7efb47a1b4..abd8626e61 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -13,6 +13,7 @@ use 
dynamo_parsers::tool_calling::{ }; use dynamo_runtime::protocols::annotated::Annotated; use futures::{Stream, StreamExt}; +use std::collections::HashMap; use crate::utils::{MarkerMatcher, MatchResult}; @@ -370,17 +371,12 @@ impl ChoiceJailState { struct ChoiceJailStateCollection { /// Vec of states, always kept sorted by choice index for deterministic iteration states: Vec, - /// Track if any choice has emitted a finish_reason (per choice index) - finish_reason_emitted: std::collections::HashMap, } impl ChoiceJailStateCollection { /// Create a new empty collection fn new() -> Self { - Self { - states: Vec::new(), - finish_reason_emitted: std::collections::HashMap::new(), - } + Self { states: Vec::new() } } /// Get or create state for a choice index @@ -399,19 +395,6 @@ impl ChoiceJailStateCollection { } } } - - /// Check if a finish_reason has already been emitted for this choice - fn has_emitted_finish_reason(&self, index: u32) -> bool { - self.finish_reason_emitted - .get(&index) - .copied() - .unwrap_or(false) - } - - /// Mark that a finish_reason has been emitted for this choice - fn mark_finish_reason_emitted(&mut self, index: u32) { - self.finish_reason_emitted.insert(index, true); - } } /// Emission mode for handling multiple choices @@ -474,17 +457,6 @@ impl JailedStream { // Process each choice independently using the new architecture for choice in &chat_response.choices { - // if we've already emitted a finish_reason for this choice, - // skip any subsequent chunks with finish_reason - if choice.finish_reason.is_some() && choice_states.has_emitted_finish_reason(choice.index) { - tracing::debug!( - "Skipping chunk with finish_reason {:?} for choice {} - already emitted finish_reason", - choice.finish_reason, - choice.index - ); - continue; - } - if let Some(ref content) = choice.delta.content { let choice_state = choice_states.get_or_create_state(choice.index); @@ -538,16 +510,8 @@ impl JailedStream { last_annotated_event.clone(), 
last_annotated_comment.clone(), ); - let responses = self.emit_choice_emissions(tool_content_emissions.clone(), chat_response, preserved_metadata); + let responses = self.emit_choice_emissions(tool_content_emissions.clone(), chat_response, preserved_metadata, &choice_states); for emitted_response in responses { - // Mark finish_reason as emitted for choices that have it - if let Some(ref data) = emitted_response.data { - for choice in &data.choices { - if choice.finish_reason.is_some() { - choice_states.mark_finish_reason_emitted(choice.index); - } - } - } yield emitted_response; } } @@ -559,16 +523,8 @@ impl JailedStream { last_annotated_event.clone(), last_annotated_comment.clone(), ); - let responses = self.emit_choice_emissions(trailing_emissions, chat_response, preserved_metadata); + let responses = self.emit_choice_emissions(trailing_emissions, chat_response, preserved_metadata, &choice_states); for emitted_response in responses { - // Mark finish_reason as emitted for choices that have it - if let Some(ref data) = emitted_response.data { - for choice in &data.choices { - if choice.finish_reason.is_some() { - choice_states.mark_finish_reason_emitted(choice.index); - } - } - } yield emitted_response; } } @@ -576,16 +532,8 @@ impl JailedStream { // Emit pass-through content with current metadata if !passthrough_emissions.is_empty() { let current_metadata = (response.id.clone(), response.event.clone(), response.comment.clone()); - let responses = self.emit_choice_emissions(passthrough_emissions, chat_response, current_metadata); + let responses = self.emit_choice_emissions(passthrough_emissions, chat_response, current_metadata, &choice_states); for emitted_response in responses { - // Mark finish_reason as emitted for choices that have it - if let Some(ref data) = emitted_response.data { - for choice in &data.choices { - if choice.finish_reason.is_some() { - choice_states.mark_finish_reason_emitted(choice.index); - } - } - } yield emitted_response; } } @@ 
-619,16 +567,8 @@ impl JailedStream { }; let final_metadata = (last_annotated_id, last_annotated_event, last_annotated_comment); - let responses = self.emit_choice_emissions(final_emissions, &dummy_response, final_metadata); + let responses = self.emit_choice_emissions(final_emissions, &dummy_response, final_metadata, &choice_states); for emitted_response in responses { - // Mark finish_reason as emitted for choices that have it - if let Some(ref data) = emitted_response.data { - for choice in &data.choices { - if choice.finish_reason.is_some() { - choice_states.mark_finish_reason_emitted(choice.index); - } - } - } yield emitted_response; } } @@ -641,6 +581,7 @@ impl JailedStream { emissions: Vec, base_response: &NvCreateChatCompletionStreamResponse, annotated_metadata: (Option, Option, Option>), + _choice_states: &ChoiceJailStateCollection, ) -> Vec> { if emissions.is_empty() { return Vec::new(); @@ -770,14 +711,15 @@ impl JailedStream { .collect(); // Create choice with tool calls - return create_choice_stream( + let choice = create_choice_stream( choice_index, Some(Role::Assistant), normal_text.as_deref().unwrap_or(""), Some(tool_call_chunks), - Some(FinishReason::ToolCalls), + None, None, ); + return choice; } // No tool calls found or parsing failed, return content choice @@ -806,6 +748,44 @@ impl JailedStream { } false } + + /// Post-processor that sets finish_reason to ToolCalls when tool calls were emitted + /// This should be called after apply() to fix the finish_reason for tool call chunks + pub fn fix_finish_reason( + input_stream: S, + ) -> impl Stream> + Send + where + S: Stream> + Send + 'static, + { + stream! 
{ + tokio::pin!(input_stream); + let mut has_tool_calls_per_choice: HashMap = HashMap::new(); + + while let Some(mut response) = input_stream.next().await { + // Track if any choice emitted tool calls + if let Some(ref data) = response.data { + for choice in &data.choices { + if choice.delta.tool_calls.is_some() { + has_tool_calls_per_choice.insert(choice.index, true); + } + } + } + + // If this chunk has finish_reason and the choice had tool calls, override to ToolCalls + if let Some(ref mut data) = response.data { + for choice in &mut data.choices { + if choice.finish_reason.is_some() + && has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false) + { + choice.finish_reason = Some(FinishReason::ToolCalls); + } + } + } + + yield response; + } + } + } } /// Builder for configuring a JailedStream diff --git a/lib/llm/tests/test_jail.rs b/lib/llm/tests/test_jail.rs index c45207e569..22887c6e6b 100644 --- a/lib/llm/tests/test_jail.rs +++ b/lib/llm/tests/test_jail.rs @@ -179,6 +179,49 @@ mod tests { } } + /// Helper function to create a multi-choice finish_reason chunk + pub fn create_multi_choice_finish_chunk( + choice_indices: Vec, + ) -> Annotated { + let choices: Vec = choice_indices + .into_iter() + .map(|index| { + #[allow(deprecated)] + ChatChoiceStream { + index, + delta: ChatCompletionStreamResponseDelta { + role: None, + content: None, + tool_calls: None, + function_call: None, + refusal: None, + reasoning_content: None, + }, + finish_reason: Some(FinishReason::Stop), + logprobs: None, + } + }) + .collect(); + + let response = NvCreateChatCompletionStreamResponse { + id: "test-id".to_string(), + choices, + created: 1234567890, + model: "test-model".to_string(), + system_fingerprint: Some("test-fingerprint".to_string()), + object: "chat.completion.chunk".to_string(), + usage: None, + service_tier: None, + }; + + Annotated { + data: Some(response), + id: None, + event: None, + comment: None, + } + } + /// Helper to assert content in a result pub fn 
assert_content( result: &Annotated, @@ -337,7 +380,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // We should only get 3 chunks now: // 1. "Hello " (before jail) @@ -394,7 +438,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have jailed the content and parsed tool calls at the end assert!(!results.is_empty()); @@ -432,7 +477,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // We should get 2 chunks: // 1. 
"Normal text " (before jail) @@ -476,7 +522,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have exactly 2 chunks: tool call + trailing content assert_eq!( @@ -519,7 +566,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // === Verify chunk count === assert_eq!( @@ -573,7 +621,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -619,7 +668,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -661,7 +711,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -710,7 +761,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("phi4").build(); let 
jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -757,7 +809,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -798,7 +851,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // The "{" pattern triggers jailing, so some chunks get combined assert_eq!(results.len(), 2); @@ -840,7 +894,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Jailing combines the tool call content into fewer chunks assert_eq!( @@ -885,7 +940,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should handle partial tool call gracefully - releases accumulated content on stream end assert_eq!( @@ -925,7 +981,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> 
= fixed_stream.collect().await; // === Verify chunk count === assert_eq!( @@ -980,7 +1037,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // === Verify chunk count === assert_eq!( @@ -1088,7 +1146,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should consolidate extreme fragmentation into 3 clean chunks // Input: "I'll process your request. " + 54-char tool call + " Processing complete!" @@ -1142,6 +1201,7 @@ mod tests { create_mock_response_chunk("\"arguments\": {\"query\": \"test\"}}".to_string(), 0), create_mock_response_chunk("".to_string(), 0), create_mock_response_chunk(" Processing complete.".to_string(), 0), + test_utils::create_final_response_chunk(0), // Backend finish_reason chunk ]; let input_stream = stream::iter(chunks); @@ -1150,7 +1210,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should get 3 chunks: before jail, tool call response, after jail assert!( @@ -1159,14 +1220,14 @@ mod tests { results.len() ); - // Find the synthesized tool call response chunk + // Find the tool call chunk (the one with tool_calls, not the finish_reason chunk) let tool_call_chunk = results .iter() .find(|r| { r.data .as_ref() .and_then(|d| d.choices.first()) - .map(|c| c.finish_reason == Some(FinishReason::ToolCalls)) + .map(|c| c.delta.tool_calls.is_some()) .unwrap_or(false) }) .expect("Should have a 
tool call response chunk"); @@ -1233,7 +1294,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should get 2 chunks: first chunk passes through, stream end releases accumulated assert_eq!(results.len(), 2, "Should have exactly 2 chunks"); @@ -1291,6 +1353,7 @@ mod tests { ), create_mock_response_chunk("{\"name\": \"test\", \"arguments\": {}}".to_string(), 0), create_mock_response_chunk("".to_string(), 0), + test_utils::create_final_response_chunk(0), // Backend finish_reason chunk ]; let input_stream = stream::iter(chunks); @@ -1298,16 +1361,17 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; - // Find the tool call response + // Find the tool call chunk (the one with tool_calls, not the finish_reason chunk) let tool_call_chunk = results .iter() .find(|r| { r.data .as_ref() .and_then(|d| d.choices.first()) - .map(|c| c.finish_reason == Some(FinishReason::ToolCalls)) + .map(|c| c.delta.tool_calls.is_some()) .unwrap_or(false) }) .expect("Should have a tool call response chunk"); @@ -1353,7 +1417,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // === Verify chunk count === assert_eq!( @@ -1396,7 +1461,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); let 
jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have exactly 3 chunks: content + tool call + trailing assert_eq!( @@ -1453,6 +1519,8 @@ mod tests { ("Done with B. ".to_string(), 1), // Choice 1 continues ("".to_string(), 2), // Choice 2 unjails ]), + // Chunk 6: Backend finish_reason chunks for all choices + test_utils::create_multi_choice_finish_chunk(vec![0, 1, 2]), ]; let input_stream = stream::iter(chunks); @@ -1460,7 +1528,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // EXPECTED BEHAVIOR (will fail with current implementation): // - Choice 1 should stream continuously (never jailed) @@ -1529,6 +1598,7 @@ mod tests { 2, ), ]), + test_utils::create_multi_choice_finish_chunk(vec![0, 1, 2]), ]; let input_stream = stream::iter(chunks); @@ -1536,7 +1606,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Find all tool call responses let mut tool_call_responses: Vec<_> = results @@ -1559,25 +1630,32 @@ mod tests { // Run this test multiple times to verify determinism for run in 0..5 { - let chunks = vec![create_multi_choice_chunk(vec![ - ( - "{\"name\": \"tool_0\", \"arguments\": {}}".to_string(), - 0, - ), - ( - "{\"name\": \"tool_1\", \"arguments\": {}}".to_string(), - 1, - ), - ( - "{\"name\": \"tool_2\", \"arguments\": {}}".to_string(), - 2, - ), - ])]; + let 
chunks = vec![ + create_multi_choice_chunk(vec![ + ( + "{\"name\": \"tool_0\", \"arguments\": {}}" + .to_string(), + 0, + ), + ( + "{\"name\": \"tool_1\", \"arguments\": {}}" + .to_string(), + 1, + ), + ( + "{\"name\": \"tool_2\", \"arguments\": {}}" + .to_string(), + 2, + ), + ]), + test_utils::create_multi_choice_finish_chunk(vec![0, 1, 2]), + ]; let input_stream = stream::iter(chunks); let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jailed_stream = jail.apply(input_stream); - let run_results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let run_results: Vec<_> = fixed_stream.collect().await; let run_responses: Vec<_> = run_results .iter() @@ -1617,7 +1695,8 @@ mod tests { let jail = JailedStream::builder().build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // TODO: Once usage aggregation is implemented, verify: // - Usage chunk has choices: [] (empty array) @@ -1653,7 +1732,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // === Verify chunk count === assert_eq!( @@ -1709,7 +1789,8 @@ mod tests { .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // === Verify chunk count === assert_eq!( @@ -1764,7 +1845,8 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("harmony").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = 
JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have at least one output containing both analysis text and parsed tool call assert!(!results.is_empty()); @@ -1805,7 +1887,8 @@ mod tests { .tool_call_parser("deepseek_v3_1") .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have at least one output containing both analysis text and parsed tool call assert!(!results.is_empty()); @@ -1879,7 +1962,8 @@ mod tests { .tool_call_parser("deepseek_v3_1") .build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should have at least one output containing both analysis text and parsed tool call assert!(!results.is_empty()); @@ -1921,7 +2005,8 @@ mod tests { let input_stream = stream::iter(chunks); let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; assert!(results.len() >= 2); assert_content(&results[0], "Hey How"); @@ -1957,7 +2042,8 @@ mod tests { let input_stream = stream::iter(chunks); let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jailed_stream = jail.apply(input_stream); - let results: Vec<_> = jailed_stream.collect().await; + let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); + let results: Vec<_> = fixed_stream.collect().await; // Should preserve earlier content and also produce a tool call assert!(results.len() >= 2); diff --git 
a/lib/llm/tests/test_streaming_tool_parsers.rs b/lib/llm/tests/test_streaming_tool_parsers.rs index fcc56e0517..73b1bec1a9 100644 --- a/lib/llm/tests/test_streaming_tool_parsers.rs +++ b/lib/llm/tests/test_streaming_tool_parsers.rs @@ -251,6 +251,71 @@ fn aggregate_content_from_chunks( } } +/// Helper function to validate finish_reason in the stream +/// Returns true if: +/// 1. There is exactly one finish_reason in the entire stream +/// 2. The finish_reason is in the last chunk +/// 3. The finish_reason matches the expected value +fn validate_finish_reason( + chunks: &[Annotated], + expected_finish_reason: FinishReason, +) -> bool { + let mut finish_reason_count = 0; + let mut last_chunk_index = None; + let mut finish_reason_value = None; + + // Count finish_reason occurrences and track position + for (idx, chunk) in chunks.iter().enumerate() { + if let Some(ref response_data) = chunk.data { + for choice in &response_data.choices { + if let Some(reason) = choice.finish_reason { + finish_reason_count += 1; + last_chunk_index = Some(idx); + finish_reason_value = Some(reason); + } + } + } + } + + // Validate: + // 1. Exactly one finish_reason in the stream + if finish_reason_count != 1 { + eprintln!( + "Expected exactly 1 finish_reason, but found {}", + finish_reason_count + ); + return false; + } + + // 2. finish_reason is in the last chunk + if let Some(idx) = last_chunk_index { + if idx != chunks.len() - 1 { + eprintln!( + "Expected finish_reason in last chunk (index {}), but found at index {}", + chunks.len() - 1, + idx + ); + return false; + } + } else { + eprintln!("No finish_reason found in stream"); + return false; + } + + // 3. 
finish_reason matches expected value + if let Some(reason) = finish_reason_value + && reason != expected_finish_reason + { + eprintln!( + "Expected finish_reason {:?}, but found {:?}", + expected_finish_reason, reason + ); + return false; + } + + true +} + #[cfg(test)] mod tests { use super::*; @@ -305,19 +370,11 @@ mod tests { "Tool calls presence should match expected value" ); - // Verify last chunk has Stop finish_reason for no-tool cases - let last_chunk = output_chunks - .last() - .expect("Should have at least one chunk"); - if let Some(data) = &last_chunk.data - && let Some(choice) = data.choices.first() - { - assert_eq!( - choice.finish_reason, - Some(FinishReason::Stop), - "Last chunk should have Stop finish_reason for non-tool call case" - ); - } + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop + assert!( + validate_finish_reason(&output_chunks, FinishReason::Stop), + "finish_reason validation failed for non-tool call case" + ); } #[tokio::test] @@ -375,23 +432,11 @@ mod tests { // Verify tool calls assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); - // Verify the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) - let last_chunk = output_chunks - .last() - .expect("Should have at least one chunk"); - if let Some(data) = &last_chunk.data - && let Some(choice) = data.choices.first() - { - assert_eq!( - choice.finish_reason, - Some(FinishReason::ToolCalls), - "Last chunk should have ToolCalls finish_reason (empty Stop chunks should be filtered)" - ); - assert!( - choice.delta.tool_calls.is_some(), - "Last chunk with ToolCalls finish_reason must have tool_calls data" - ); - } + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls + assert!( + validate_finish_reason(&output_chunks, FinishReason::ToolCalls), + "finish_reason validation failed for tool call case" + ); } #[tokio::test] @@ -436,19 +481,11 @@ mod tests { "Tool calls 
presence should match expected value" ); - // Verify last chunk has Stop finish_reason for no-tool cases - let last_chunk = output_chunks - .last() - .expect("Should have at least one chunk"); - if let Some(data) = &last_chunk.data - && let Some(choice) = data.choices.first() - { - assert_eq!( - choice.finish_reason, - Some(FinishReason::Stop), - "Last chunk should have Stop finish_reason for non-tool call case" - ); - } + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop + assert!( + validate_finish_reason(&output_chunks, FinishReason::Stop), + "finish_reason validation failed for non-tool call case" + ); } #[tokio::test] @@ -502,23 +539,11 @@ mod tests { // Verify tool calls assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); - // Verify the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) - let last_chunk = output_chunks - .last() - .expect("Should have at least one chunk"); - if let Some(data) = &last_chunk.data - && let Some(choice) = data.choices.first() - { - assert_eq!( - choice.finish_reason, - Some(FinishReason::ToolCalls), - "Last chunk should have ToolCalls finish_reason (empty Stop chunks should be filtered)" - ); - assert!( - choice.delta.tool_calls.is_some(), - "Last chunk with ToolCalls finish_reason must have tool_calls data" - ); - } + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls + assert!( + validate_finish_reason(&output_chunks, FinishReason::ToolCalls), + "finish_reason validation failed for tool call case" + ); } #[tokio::test] @@ -576,19 +601,11 @@ mod tests { assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); - // Verify last chunk has Stop finish_reason for no-tool cases - let last_chunk = output_chunks - .last() - .expect("Should have at least one chunk"); - if let Some(data) = &last_chunk.data - && let Some(choice) = data.choices.first() - { - assert_eq!( - choice.finish_reason, - 
Some(FinishReason::Stop), - "Last chunk should have Stop finish_reason for non-tool call case" - ); - } + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop + assert!( + validate_finish_reason(&output_chunks, FinishReason::Stop), + "finish_reason validation failed for non-tool call case" + ); } #[tokio::test] @@ -646,20 +663,10 @@ mod tests { assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); - // Verify there is a chunk with ToolCalls finish_reason and tool call data - let has_tool_calls_chunk = output_chunks.iter().any(|chunk| { - chunk - .data - .as_ref() - .and_then(|d| d.choices.first()) - .map(|c| { - c.finish_reason == Some(FinishReason::ToolCalls) && c.delta.tool_calls.is_some() - }) - .unwrap_or(false) - }); + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls assert!( - has_tool_calls_chunk, - "Should have a chunk with ToolCalls finish_reason and tool_calls data" + validate_finish_reason(&output_chunks, FinishReason::ToolCalls), + "finish_reason validation failed for tool call case" ); } @@ -715,19 +722,11 @@ mod tests { assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); - // Verify last chunk has Stop finish_reason for no-tool cases - let last_chunk = output_chunks - .last() - .expect("Should have at least one chunk"); - if let Some(data) = &last_chunk.data - && let Some(choice) = data.choices.first() - { - assert_eq!( - choice.finish_reason, - Some(FinishReason::Stop), - "Last chunk should have Stop finish_reason for non-tool call case" - ); - } + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop + assert!( + validate_finish_reason(&output_chunks, FinishReason::Stop), + "finish_reason validation failed for non-tool call case" + ); } #[tokio::test] @@ -783,20 +782,10 @@ mod tests { ); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); - // Verify there is a chunk with ToolCalls 
finish_reason and tool call data - let has_tool_calls_chunk = output_chunks.iter().any(|chunk| { - chunk - .data - .as_ref() - .and_then(|d| d.choices.first()) - .map(|c| { - c.finish_reason == Some(FinishReason::ToolCalls) && c.delta.tool_calls.is_some() - }) - .unwrap_or(false) - }); + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls assert!( - has_tool_calls_chunk, - "Should have a chunk with ToolCalls finish_reason and tool_calls data" + validate_finish_reason(&output_chunks, FinishReason::ToolCalls), + "finish_reason validation failed for tool call case" ); } @@ -851,22 +840,10 @@ mod tests { // Verify tool calls assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); - // Verify the last chunk has ToolCalls finish_reason (empty Stop chunks should be filtered) - let last_chunk = output_chunks - .last() - .expect("Should have at least one chunk"); - if let Some(data) = &last_chunk.data - && let Some(choice) = data.choices.first() - { - assert_eq!( - choice.finish_reason, - Some(FinishReason::ToolCalls), - "Last chunk should have ToolCalls finish_reason (empty Stop chunks should be filtered)" - ); - assert!( - choice.delta.tool_calls.is_some(), - "Last chunk with ToolCalls finish_reason must have tool_calls data" - ); - } + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls + assert!( + validate_finish_reason(&output_chunks, FinishReason::ToolCalls), + "finish_reason validation failed for tool call case" + ); } } From 2cad55c97d624517a7f29b42be709d7d31d923c4 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 6 Nov 2025 19:01:25 +0000 Subject: [PATCH 4/8] fix: cleanup Signed-off-by: ayushag --- lib/llm/src/protocols/openai/chat_completions/jail.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index abd8626e61..06c4987027 100644 
--- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -510,7 +510,7 @@ impl JailedStream { last_annotated_event.clone(), last_annotated_comment.clone(), ); - let responses = self.emit_choice_emissions(tool_content_emissions.clone(), chat_response, preserved_metadata, &choice_states); + let responses = self.emit_choice_emissions(tool_content_emissions.clone(), chat_response, preserved_metadata); for emitted_response in responses { yield emitted_response; } @@ -523,7 +523,7 @@ impl JailedStream { last_annotated_event.clone(), last_annotated_comment.clone(), ); - let responses = self.emit_choice_emissions(trailing_emissions, chat_response, preserved_metadata, &choice_states); + let responses = self.emit_choice_emissions(trailing_emissions, chat_response, preserved_metadata); for emitted_response in responses { yield emitted_response; } @@ -532,7 +532,7 @@ impl JailedStream { // Emit pass-through content with current metadata if !passthrough_emissions.is_empty() { let current_metadata = (response.id.clone(), response.event.clone(), response.comment.clone()); - let responses = self.emit_choice_emissions(passthrough_emissions, chat_response, current_metadata, &choice_states); + let responses = self.emit_choice_emissions(passthrough_emissions, chat_response, current_metadata); for emitted_response in responses { yield emitted_response; } @@ -567,7 +567,7 @@ impl JailedStream { }; let final_metadata = (last_annotated_id, last_annotated_event, last_annotated_comment); - let responses = self.emit_choice_emissions(final_emissions, &dummy_response, final_metadata, &choice_states); + let responses = self.emit_choice_emissions(final_emissions, &dummy_response, final_metadata); for emitted_response in responses { yield emitted_response; } @@ -581,7 +581,6 @@ impl JailedStream { emissions: Vec, base_response: &NvCreateChatCompletionStreamResponse, annotated_metadata: (Option, Option, Option>), - _choice_states: 
&ChoiceJailStateCollection, ) -> Vec> { if emissions.is_empty() { return Vec::new(); From e8a02ab4aa4424c0c4888bec2c346f6c7c79d2ed Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 6 Nov 2025 19:34:26 +0000 Subject: [PATCH 5/8] chore: remove clone Signed-off-by: ayushag --- lib/llm/src/protocols/openai/chat_completions/jail.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index 06c4987027..6b0696b86f 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -510,7 +510,7 @@ impl JailedStream { last_annotated_event.clone(), last_annotated_comment.clone(), ); - let responses = self.emit_choice_emissions(tool_content_emissions.clone(), chat_response, preserved_metadata); + let responses = self.emit_choice_emissions(tool_content_emissions, chat_response, preserved_metadata); for emitted_response in responses { yield emitted_response; } From c214ba24efe1d83a8a6968b953d2fe73e0c103e9 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 6 Nov 2025 20:06:35 +0000 Subject: [PATCH 6/8] chore: highlevel wrapper method Signed-off-by: ayushag --- lib/llm/src/preprocessor.rs | 5 +- .../protocols/openai/chat_completions/jail.rs | 14 ++ lib/llm/tests/test_jail.rs | 142 +++++------------- 3 files changed, 55 insertions(+), 106 deletions(-) diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index cfd3ab2e97..e7374d8a3a 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -764,10 +764,7 @@ impl OpenAIPreprocessor { let jail = JailedStream::builder() .tool_call_parser(tool_call_parser) .build(); - let jailed_stream = jail.apply(stream); - - // Post-process to set finish reason to ToolCalls for the last chunk if any tool calls were emitted - JailedStream::fix_finish_reason(jailed_stream) + jail.apply_with_finish_reason(stream) } // Motivation: 
Each transformation on the stream should be a separate step to allow for more flexibility diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index 6b0696b86f..c03ff93469 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -429,6 +429,20 @@ impl JailedStream { JailedStreamBuilder::new() } + + /// Apply jail stream transformation with finish_reason fix + /// This is a convenience method that applies both apply() and fix_finish_reason() + pub fn apply_with_finish_reason( + self, + stream: S, + ) -> impl Stream> + Send + where + S: Stream> + Send + 'static, + { + let jailed_stream = self.apply(stream); + JailedStream::fix_finish_reason(jailed_stream) + } + /// Apply the jail transformation to a stream of chat completion responses /// Consumes self and returns the transformed stream pub fn apply( diff --git a/lib/llm/tests/test_jail.rs b/lib/llm/tests/test_jail.rs index 22887c6e6b..0f126fa175 100644 --- a/lib/llm/tests/test_jail.rs +++ b/lib/llm/tests/test_jail.rs @@ -379,9 +379,7 @@ mod tests { .jail_end_sequence("") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // We should only get 3 chunks now: // 1. 
"Hello " (before jail) @@ -437,9 +435,7 @@ mod tests { .tool_call_parser("nemotron_deci") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have jailed the content and parsed tool calls at the end assert!(!results.is_empty()); @@ -476,9 +472,7 @@ mod tests { .tool_call_parser("nemotron_deci") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // We should get 2 chunks: // 1. "Normal text " (before jail) @@ -521,9 +515,7 @@ mod tests { .tool_call_parser("nemotron_deci") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have exactly 2 chunks: tool call + trailing content assert_eq!( @@ -565,9 +557,7 @@ mod tests { .jail_start_sequence("") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // === Verify chunk count === assert_eq!( @@ -620,9 +610,7 @@ mod tests { // Create JailedStream with Hermes parser let jail = JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = 
jail.apply_with_finish_reason(input_stream).collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -667,9 +655,7 @@ mod tests { // Create JailedStream with Mistral parser let jail = JailedStream::builder().tool_call_parser("mistral").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -710,9 +696,7 @@ mod tests { // Create JailedStream with Mistral parser let jail = JailedStream::builder().tool_call_parser("mistral").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -760,9 +744,7 @@ mod tests { // Create JailedStream with Phi4 parser let jail = JailedStream::builder().tool_call_parser("phi4").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -808,9 +790,7 @@ mod tests { .tool_call_parser("llama3_json") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have exactly 3 chunks: content + tool call + content assert_eq!( @@ -850,9 +830,7 @@ mod tests 
{ // Create JailedStream with mistral parser (which specifically looks for [{ or [TOOL_CALLS] patterns) let jail = JailedStream::builder().tool_call_parser("mistral").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // The "{" pattern triggers jailing, so some chunks get combined assert_eq!(results.len(), 2); @@ -893,9 +871,7 @@ mod tests { .tool_call_parser("nemotron_deci") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Jailing combines the tool call content into fewer chunks assert_eq!( @@ -939,9 +915,7 @@ mod tests { .tool_call_parser("nemotron_deci") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should handle partial tool call gracefully - releases accumulated content on stream end assert_eq!( @@ -980,9 +954,7 @@ mod tests { .jail_end_sequence("") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // === Verify chunk count === assert_eq!( @@ -1036,9 +1008,7 @@ mod tests { .tool_call_parser("nemotron_deci") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: 
Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // === Verify chunk count === assert_eq!( @@ -1145,9 +1115,7 @@ mod tests { .tool_call_parser("nemotron_deci") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should consolidate extreme fragmentation into 3 clean chunks // Input: "I'll process your request. " + 54-char tool call + " Processing complete!" @@ -1209,9 +1177,7 @@ mod tests { // Create JailedStream with Hermes parser let jail = JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should get 3 chunks: before jail, tool call response, after jail assert!( @@ -1293,9 +1259,7 @@ mod tests { // Create JailedStream with Hermes parser let jail = JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should get 2 chunks: first chunk passes through, stream end releases accumulated assert_eq!(results.len(), 2, "Should have exactly 2 chunks"); @@ -1360,9 +1324,7 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Find the 
tool call chunk (the one with tool_calls, not the finish_reason chunk) let tool_call_chunk = results @@ -1416,9 +1378,7 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // === Verify chunk count === assert_eq!( @@ -1460,9 +1420,7 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have exactly 3 chunks: content + tool call + trailing assert_eq!( @@ -1527,9 +1485,7 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // EXPECTED BEHAVIOR (will fail with current implementation): // - Choice 1 should stream continuously (never jailed) @@ -1605,9 +1561,7 @@ mod tests { let jail = JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Find all tool call responses let mut tool_call_responses: Vec<_> = results @@ -1653,9 +1607,7 @@ mod tests { let input_stream = stream::iter(chunks); let jail = 
JailedStream::builder().tool_call_parser("hermes").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let run_results: Vec<_> = fixed_stream.collect().await; + let run_results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; let run_responses: Vec<_> = run_results .iter() @@ -1694,9 +1646,7 @@ mod tests { let jail = JailedStream::builder().build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // TODO: Once usage aggregation is implemented, verify: // - Usage chunk has choices: [] (empty array) @@ -1731,9 +1681,7 @@ mod tests { .tool_call_parser("nemotron_deci") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // === Verify chunk count === assert_eq!( @@ -1788,9 +1736,7 @@ mod tests { .jail_end_sequence("") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // === Verify chunk count === assert_eq!( @@ -1844,9 +1790,7 @@ mod tests { let input_stream = stream::iter(chunks); let jail = JailedStream::builder().tool_call_parser("harmony").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have at least one output containing 
both analysis text and parsed tool call assert!(!results.is_empty()); @@ -1886,9 +1830,8 @@ mod tests { let jail = JailedStream::builder() .tool_call_parser("deepseek_v3_1") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let jailed_stream = jail.apply_with_finish_reason(input_stream); + let results: Vec<_> = jailed_stream.collect().await; // Should have at least one output containing both analysis text and parsed tool call assert!(!results.is_empty()); @@ -1961,9 +1904,8 @@ mod tests { let jail = JailedStream::builder() .tool_call_parser("deepseek_v3_1") .build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let jailed_stream = jail.apply_with_finish_reason(input_stream); + let results: Vec<_> = jailed_stream.collect().await; // Should have at least one output containing both analysis text and parsed tool call assert!(!results.is_empty()); @@ -2004,9 +1946,7 @@ mod tests { let input_stream = stream::iter(chunks); let jail = JailedStream::builder().tool_call_parser("mistral").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; assert!(results.len() >= 2); assert_content(&results[0], "Hey How"); @@ -2041,9 +1981,7 @@ mod tests { let input_stream = stream::iter(chunks); let jail = JailedStream::builder().tool_call_parser("mistral").build(); - let jailed_stream = jail.apply(input_stream); - let fixed_stream = JailedStream::fix_finish_reason(jailed_stream); - let results: Vec<_> = fixed_stream.collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // 
Should preserve earlier content and also produce a tool call assert!(results.len() >= 2); @@ -2216,7 +2154,7 @@ mod parallel_jail_tests { ]; let input_stream = stream::iter(input_chunks); - let results: Vec<_> = jail.apply(input_stream).collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should have tool call results assert!(!results.is_empty(), "Should have results"); @@ -2289,7 +2227,7 @@ mod parallel_jail_tests { ]; let input_stream = stream::iter(input_chunks); - let results: Vec<_> = jail.apply(input_stream).collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; assert!(!results.is_empty(), "Should have results"); @@ -2326,7 +2264,7 @@ mod parallel_jail_tests { ]; let input_stream = stream::iter(input_chunks); - let results: Vec<_> = jail.apply(input_stream).collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; assert!(!results.is_empty(), "Should have results"); @@ -2396,7 +2334,7 @@ mod parallel_jail_tests { ]; let input_stream = stream::iter(input_chunks); - let results: Vec<_> = jail.apply(input_stream).collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; assert!(!results.is_empty(), "Should have results"); @@ -2634,7 +2572,7 @@ mod parallel_jail_tests { ]; let input_stream = stream::iter(input_chunks); - let results: Vec<_> = jail.apply(input_stream).collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; assert!(!results.is_empty(), "Should have results"); @@ -2679,7 +2617,7 @@ mod parallel_jail_tests { ]; let input_stream = stream::iter(input_chunks); - let results: Vec<_> = jail.apply(input_stream).collect().await; + let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; // Should still handle the incomplete stream gracefully assert!( From 46cd5776f352c5f25ddd19ee9f31c988743d676e Mon 
Sep 17 00:00:00 2001 From: ayushag Date: Thu, 6 Nov 2025 20:09:30 +0000 Subject: [PATCH 7/8] fix: fmt Signed-off-by: ayushag --- lib/llm/src/protocols/openai/chat_completions/jail.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index c03ff93469..9d7bbba5ad 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -429,7 +429,6 @@ impl JailedStream { JailedStreamBuilder::new() } - /// Apply jail stream transformation with finish_reason fix /// This is a convenience method that applies both apply() and fix_finish_reason() pub fn apply_with_finish_reason( From 25013b6ef47bdde362acf7ce55203d84261a0b71 Mon Sep 17 00:00:00 2001 From: ayushag Date: Fri, 7 Nov 2025 01:10:11 +0000 Subject: [PATCH 8/8] fix: edge cases Signed-off-by: ayushag --- .../protocols/openai/chat_completions/jail.rs | 26 ++++++++------ .../chat_completion_incomplete_tool.json | 21 +++++++++++ .../chat_completion_stream_finish_length.json | 20 +++++++++++ lib/llm/tests/test_streaming_tool_parsers.rs | 35 +++++++++++++++++++ 4 files changed, 91 insertions(+), 11 deletions(-) create mode 100644 lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json create mode 100644 lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_stream_finish_length.json diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index 9d7bbba5ad..cd10b6cff5 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -73,6 +73,8 @@ struct ChoiceJailState { accumulated_content: String, /// Buffer for partial marker matches across chunks partial_match_buffer: String, + /// Stream finish reason + stream_finish_reason: Option, } fn create_choice_stream( @@ -107,6 +109,7 @@ impl ChoiceJailState { is_jailed: 
false, accumulated_content: String::new(), partial_match_buffer: String::new(), + stream_finish_reason: None, } } @@ -131,7 +134,6 @@ impl ChoiceJailState { jail_stream: &JailedStream, ) -> Vec { let mut emissions = Vec::new(); - if !self.is_jailed { // Use the marker matcher to detect complete/partial markers let match_result = jail_stream @@ -153,7 +155,7 @@ impl ChoiceJailState { choice.delta.role, &prefix, None, - None, + choice.finish_reason, choice.logprobs.clone(), ); emissions.push(ChoiceEmission::PassThrough(prefix_choice)); @@ -193,7 +195,7 @@ impl ChoiceJailState { choice.delta.role, trailing_part, None, - None, + choice.finish_reason, choice.logprobs.clone(), ); emissions.push(ChoiceEmission::Trailing(trailing_choice)); @@ -225,7 +227,7 @@ impl ChoiceJailState { choice.delta.role, &prefix, None, - None, + choice.finish_reason, choice.logprobs.clone(), ); emissions.push(ChoiceEmission::PassThrough(prefix_choice)); @@ -268,7 +270,7 @@ impl ChoiceJailState { choice.delta.role, &content, None, - None, + choice.finish_reason, choice.logprobs.clone(), ); emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); @@ -313,7 +315,7 @@ impl ChoiceJailState { choice.delta.role, trailing_part, None, - None, + choice.finish_reason, choice.logprobs.clone(), ); emissions.push(ChoiceEmission::Trailing(trailing_choice)); @@ -324,7 +326,6 @@ impl ChoiceJailState { } // If not unjailing, don't emit anything (still accumulating) } - emissions } @@ -343,7 +344,7 @@ impl ChoiceJailState { Some(Role::Assistant), &self.accumulated_content, None, - None, + self.stream_finish_reason, // For the accumulated content, assign the original stream finish reason, otherwise it will get lost None, ); @@ -463,6 +464,7 @@ impl JailedStream { // Pin the stream for iteration (stack pinning is more efficient) tokio::pin!(stream); + // Process each item in the stream while let Some(response) = stream.next().await { if let Some(chat_response) = response.data.as_ref() { @@ -481,6 +483,9 
@@ impl JailedStream { last_annotated_comment = response.comment.clone(); } + // Track actual stream finish reason in the choice state + choice_state.stream_finish_reason = choice.finish_reason; + // Process this choice and get emissions let emissions = choice_state.process_content(choice, content, &self).await; all_emissions.extend(emissions); @@ -721,7 +726,6 @@ impl JailedStream { }), }) .collect(); - // Create choice with tool calls let choice = create_choice_stream( choice_index, @@ -740,7 +744,7 @@ impl JailedStream { Some(Role::Assistant), accumulated_content, None, - None, + base_choice.finish_reason, base_choice.logprobs.clone(), ) } @@ -786,7 +790,7 @@ impl JailedStream { // If this chunk has finish_reason and the choice had tool calls, override to ToolCalls if let Some(ref mut data) = response.data { for choice in &mut data.choices { - if choice.finish_reason.is_some() + if choice.finish_reason.is_some() && choice.finish_reason == Some(FinishReason::Stop) && has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false) { choice.finish_reason = Some(FinishReason::ToolCalls); diff --git a/lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json b/lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json new file mode 100644 index 0000000000..8898a0f587 --- /dev/null +++ b/lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json @@ -0,0 +1,21 @@ +{ + "request_id": "8f33c28b-cb52-4272-9ac5-0cb9f80386d3", + "expected_output": { + "normal_content": " the requested format.\n\n\n\n\n{\"name\":\"get" + }, + "input_stream": [ + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" the","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" 
requested","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" format","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":".\n","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"\n\n","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"\n","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"{\"","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"name","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + 
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"\":","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" \"","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"get","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}, "finish_reason":"length"}]}} + ] +} diff --git a/lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_stream_finish_length.json b/lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_stream_finish_length.json new file mode 100644 index 0000000000..e49d1ba328 --- /dev/null +++ b/lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_stream_finish_length.json @@ -0,0 +1,20 @@ +{ + "request_id": "8f33c28b-cb52-4272-9ac5-0cb9f80386d3", + "expected_output": { + "normal_content": "\nOkay, the user is asking for the weather in San Francisco in" + }, + "input_stream": [ + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"\n","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"Okay","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + 
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":",","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" the","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" user","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" is","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" asking","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" for","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" the","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" weather","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}}, + {"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" in","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null},"finish_reason":"length"}]}} + ] +} diff --git 
a/lib/llm/tests/test_streaming_tool_parsers.rs b/lib/llm/tests/test_streaming_tool_parsers.rs index 73b1bec1a9..70b106d676 100644 --- a/lib/llm/tests/test_streaming_tool_parsers.rs +++ b/lib/llm/tests/test_streaming_tool_parsers.rs @@ -846,4 +846,39 @@ mod tests { "finish_reason validation failed for tool call case" ); } + + #[tokio::test] + async fn test_qwen_finish_reason_length_vllm() { + let file_paths = vec![ + format!( + "{}/vllm/qwen3-0.6B/chat_completion_stream_finish_length.json", + DATA_ROOT_PATH + ), + format!( + "{}/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json", + DATA_ROOT_PATH + ), + ]; + + for file_path in file_paths { + let test_data = load_test_data(&file_path); + + // Create a stream from the mock chunks + let input_stream = stream::iter(test_data.stream_chunks); + + // Parse the response stream with tool parsing enabled + let output_chunks = + parse_response_stream(input_stream, true, false, Some("hermes".to_string()), None) + .await; + + // Verify we got output chunks + assert!(!output_chunks.is_empty(), "Should have output chunks"); + + // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Length + assert!( + validate_finish_reason(&output_chunks, FinishReason::Length), + "finish_reason validation failed for length finish case" + ); + } + } }