diff --git a/examples/chat_reasoning_simple.rs b/examples/chat_reasoning_simple.rs new file mode 100644 index 0000000..ca65d5a --- /dev/null +++ b/examples/chat_reasoning_simple.rs @@ -0,0 +1,87 @@ +use std::io::{stdout, Write}; + +use dotenvy::dotenv; +use openai::{ + chat::{ChatCompletion, ChatCompletionDelta, ChatCompletionMessage, ChatCompletionMessageRole}, + Credentials, +}; +use tokio::sync::mpsc::{error::TryRecvError, Receiver}; + +#[tokio::main(flavor = "current_thread")] +async fn main() { + dotenv().unwrap(); + let credentials = Credentials::from_env(); + let mut messages = vec![ChatCompletionMessage { + role: ChatCompletionMessageRole::System, + content: Some("You're an AI that replies to each message verbosely.".to_string()), + ..Default::default() + }]; + + stdout().flush().unwrap(); + + let user_message_content = "what tools do you have?".to_string(); + + messages.push(ChatCompletionMessage { + role: ChatCompletionMessageRole::User, + content: Some(user_message_content), + ..Default::default() + }); + + let chat_stream = ChatCompletionDelta::builder("qwen3.5-plus", messages.clone()) + .credentials(credentials.clone()) + .create_stream() + .await + .unwrap(); + + let chat_completion: ChatCompletion = listen_for_tokens(chat_stream).await; + let returned_message = chat_completion.choices.first().unwrap().message.clone(); + + messages.push(returned_message); +} + +async fn listen_for_tokens(mut chat_stream: Receiver) -> ChatCompletion { + let mut merged: Option = None; + let mut thingking = false; + loop { + match chat_stream.try_recv() { + Ok(delta) => { + let choice = &delta.choices[0]; + + if let Some(role) = &choice.delta.role { + print!("{:#?}: ", role); + } + if thingking == false && choice.delta.reasoning_content.is_some() { + thingking = true; + print!("šŸ¤” -> \n"); + } + if thingking == true && choice.delta.reasoning_content.is_none() { + thingking = false; + print!("\nšŸ˜„ -> \n"); + } + if let Some(content) = &choice.delta.content { + print!("{}", content); + } + if let Some(reason_content) = &choice.delta.reasoning_content { + print!("{}", reason_content); + } + stdout().flush().unwrap(); + // Merge token into full completion. + match merged.as_mut() { + Some(c) => { + c.merge(delta).unwrap(); + } + None => merged = Some(delta), + }; + } + Err(TryRecvError::Empty) => { + let duration = std::time::Duration::from_millis(50); + tokio::time::sleep(duration).await; + } + Err(TryRecvError::Disconnected) => { + break; + } + }; + } + println!(); + merged.unwrap().into() +} diff --git a/src/chat.rs b/src/chat.rs index aeba4ca..f8c38d7 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -88,6 +88,9 @@ pub struct ChatCompletionMessageDelta { pub role: Option, /// The contents of the message pub content: Option, + /// The contents of the reasoning message + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, /// The name of the user in a multi-user chat #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, @@ -440,6 +443,27 @@ impl ChatCompletionChoiceDelta { } } }; + // Merge reasonging contents. + match self.delta.reasoning_content.as_mut() { + Some(content) => { + match &other.delta.reasoning_content { + Some(other_content) => { + // Push other content into this one. + content.push_str(other_content) + } + None => {} + } + } + None => { + match &other.delta.reasoning_content { + Some(other_content) => { + // Set this content to other content. + self.delta.reasoning_content = Some(other_content.clone()); + } + None => {} + } + } + }; // merge function calls // function call names are concatenated diff --git a/src/lib.rs b/src/lib.rs index 8a0339c..d803645 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +use reqwest::header::USER_AGENT; use reqwest::multipart::Form; use reqwest::{header::AUTHORIZATION, Client, Method, RequestBuilder, Response}; use reqwest_eventsource::{CannotCloneRequestError, EventSource, RequestBuilderExt}; @@ -184,6 +185,7 @@ where request = builder(request); let response = request .header(AUTHORIZATION, format!("Bearer {}", credentials.api_key)) + .header(USER_AGENT, format!("openai")) .send() .await?; Ok(response) @@ -205,6 +207,7 @@ where request = builder(request); let stream = request .header(AUTHORIZATION, format!("Bearer {}", credentials.api_key)) + .header(USER_AGENT, format!("openai")) .eventsource()?; Ok(stream) }