diff --git a/.editorconfig b/.editorconfig deleted file mode 100644 index 7bd3346..0000000 --- a/.editorconfig +++ /dev/null @@ -1,2 +0,0 @@ -[*.yml] -indent_size = 2 diff --git a/.gitignore b/.gitignore index 12550c5..4b97790 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ Cargo.lock # Jetbrains .idea *.iml +/.vscode \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 25ae499..660b03d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,13 +11,18 @@ keywords = ["ai", "machine-learning", "openai", "library"] [dependencies] serde_json = "1.0.94" derive_builder = "0.20.0" -reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"], optional = true } +reqwest = { version = "0.12", default-features = false, features = [ + "json", + "stream", + "multipart", +], optional = true } serde = { version = "1.0.157", features = ["derive"] } reqwest-eventsource = "0.6" tokio = { version = "1.26.0", features = ["full"] } anyhow = "1.0.70" futures-util = "0.3.28" bytes = "1.4.0" +tracing = "0.1.41" [dev-dependencies] dotenvy = "0.15.7" diff --git a/examples/chat_cli.rs b/examples/chat_cli.rs index acbf393..3abcd2d 100644 --- a/examples/chat_cli.rs +++ b/examples/chat_cli.rs @@ -1,6 +1,6 @@ use dotenvy::dotenv; use openai::{ - chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole}, + chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole, Content}, Credentials, }; use std::io::{stdin, stdout, Write}; @@ -13,7 +13,7 @@ async fn main() { let mut messages = vec![ChatCompletionMessage { role: ChatCompletionMessageRole::System, - content: Some("You are a large language model built into a command line interface as an example of what the `openai` Rust library made by Valentine Briese can do.".to_string()), + content: Some(Content::new_str("You are a large language model built into a command line interface as an example of what the `openai` Rust library made by Valentine Briese can do.")), ..Default::default() }]; @@ -26,7 +26,7 @@ async fn main() { stdin().read_line(&mut user_message_content).unwrap(); messages.push(ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some(user_message_content), + content: Some(Content::new_str(&user_message_content)), ..Default::default() }); @@ -40,7 +40,7 @@ async fn main() { println!( "{:#?}: {}", &returned_message.role, - &returned_message.content.clone().unwrap().trim() + &returned_message.content.clone().unwrap() ); messages.push(returned_message); diff --git a/examples/chat_simple.rs b/examples/chat_simple.rs index 46a5eba..915826b 100644 --- a/examples/chat_simple.rs +++ b/examples/chat_simple.rs @@ -1,6 +1,6 @@ use dotenvy::dotenv; use openai::{ - chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole}, + chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole, Content}, Credentials, }; @@ -13,12 +13,12 @@ async fn main() { let messages = vec![ ChatCompletionMessage { role: ChatCompletionMessageRole::System, - content: Some("You are a helpful assistant.".to_string()), + content: Some(Content::new_str("You are a helpful assistant.")), ..Default::default() }, ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some("Tell me a random crab fact".to_string()), + content: Some(Content::new_str("Tell me a random crab fact")), ..Default::default() }, ]; @@ -32,6 +32,6 @@ async fn main() { println!( "{:#?}: {}", returned_message.role, - returned_message.content.unwrap().trim() + returned_message.content.unwrap() ); } diff --git a/examples/chat_stream_cli.rs b/examples/chat_stream_cli.rs index 6544813..8370a59 100644 --- a/examples/chat_stream_cli.rs +++ b/examples/chat_stream_cli.rs @@ -1,5 +1,5 @@ use dotenvy::dotenv; -use openai::chat::{ChatCompletion, ChatCompletionDelta}; +use openai::chat::{ChatCompletion, ChatCompletionDelta, Content}; use openai::{ chat::{ChatCompletionMessage, ChatCompletionMessageRole}, Credentials, @@ -15,7 +15,9 @@ async fn main() { let mut messages = vec![ChatCompletionMessage { role: ChatCompletionMessageRole::System, - content: Some("You're an AI that replies to each message verbosely.".to_string()), + content: Some(Content::new_str( + "You're an AI that replies to each message verbosely.", + )), ..Default::default() }]; @@ -28,7 +30,7 @@ async fn main() { stdin().read_line(&mut user_message_content).unwrap(); messages.push(ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some(user_message_content), + content: Some(Content::new_str(&user_message_content)), ..Default::default() }); diff --git a/src/chat.rs b/src/chat.rs index 5d423d2..cde75b5 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,161 +1,14 @@ //! Given a chat conversation, the model will return a chat completion response. +pub mod modules; +pub mod requests; +pub mod types; +pub mod utils; -use super::{ - openai_get, openai_get_with_query, openai_post, ApiResponseOrError, Credentials, - RequestPagination, Usage, -}; -use crate::openai_request_stream; -use derive_builder::Builder; -use futures_util::StreamExt; -use reqwest::Method; -use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource}; +pub use modules::*; +pub use requests::*; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::collections::HashMap; -use tokio::sync::mpsc::{channel, Receiver, Sender}; - -/// A full chat completion. -pub type ChatCompletion = ChatCompletionGeneric; - -/// A delta chat completion, which is streamed token by token. -pub type ChatCompletionDelta = ChatCompletionGeneric; - -#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] -pub struct ChatCompletionGeneric { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] -pub struct ChatCompletionChoice { - pub index: u64, - pub finish_reason: String, - pub message: ChatCompletionMessage, -} - -#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] -pub struct ChatCompletionChoiceDelta { - pub index: u64, - pub finish_reason: Option, - pub delta: ChatCompletionMessageDelta, -} - -fn is_none_or_empty_vec(opt: &Option>) -> bool { - opt.as_ref().map(|v| v.is_empty()).unwrap_or(true) -} - -#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)] -pub struct ChatCompletionMessage { - /// The role of the author of this message. - pub role: ChatCompletionMessageRole, - /// The contents of the message - /// - /// This is always required for all messages, except for when ChatGPT calls - /// a function. - pub content: Option, - /// The name of the user in a multi-user chat - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - /// The function that ChatGPT called. This should be "None" usually, and is returned by ChatGPT and not provided by the developer - /// - /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, - /// Tool call that this message is responding to. - /// Required if the role is `Tool`. - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, - /// Tool calls that the assistant is requesting to invoke. - /// Can only be populated if the role is `Assistant`, - /// otherwise it should be empty. - #[serde(skip_serializing_if = "is_none_or_empty_vec")] - pub tool_calls: Option>, -} - -/// Same as ChatCompletionMessage, but received during a response stream. -#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] -pub struct ChatCompletionMessageDelta { - /// The role of the author of this message. - pub role: Option, - /// The contents of the message - pub content: Option, - /// The name of the user in a multi-user chat - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - /// The function that ChatGPT called - /// - /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, - /// Tool call that this message is responding to. - /// Required if the role is `Tool`. - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, - /// Tool calls that the assistant is requesting to invoke. - /// Can only be populated if the role is `Assistant`, - /// otherwise it should be empty. - #[serde(skip_serializing_if = "is_none_or_empty_vec")] - pub tool_calls: Option>, -} - -#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] -pub struct ToolCall { - /// The ID of the tool call. - pub id: String, - /// The type of the tool. Currently, only `function` is supported. - pub r#type: String, - /// The function that the model called. - pub function: ToolCallFunction, -} - -#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] -pub struct ToolCallFunction { - /// The name of the function to call. - pub name: String, - /// The arguments to call the function with, as generated by the model in - /// JSON format. - /// Note that the model does not always generate valid JSON, and may - /// hallucinate parameters not defined by your function schema. - /// Validate the arguments in your code before calling your function. - pub arguments: String, -} - -#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] -pub struct ChatCompletionFunctionDefinition { - /// The name of the function - pub name: String, - /// The description of the function - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// The parameters of the function formatted in JSON Schema - /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-parameters) - /// [See more information about JSON Schema.](https://json-schema.org/understanding-json-schema/) - #[serde(skip_serializing_if = "Option::is_none")] - pub parameters: Option, -} - -#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] -pub struct ChatCompletionFunctionCall { - /// The name of the function ChatGPT called - pub name: String, - /// The arguments that ChatGPT called (formatted in JSON) - /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) - pub arguments: String, -} - -/// Same as ChatCompletionFunctionCall, but received during a response stream. -#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] -pub struct ChatCompletionFunctionCallDelta { - /// The name of the function ChatGPT called - pub name: Option, - /// The arguments that ChatGPT called (formatted in JSON) - /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) - pub arguments: Option, -} +pub use types::*; +pub use utils::*; #[derive(Deserialize, Serialize, Debug, Clone, Copy, Eq, PartialEq)] #[serde(rename_all = "lowercase")] @@ -168,364 +21,12 @@ pub enum ChatCompletionMessageRole { Developer, } -#[derive(Serialize, Builder, Debug, Clone)] -#[builder(derive(Clone, Debug, PartialEq))] -#[builder(pattern = "owned")] -#[builder(name = "ChatCompletionBuilder")] -#[builder(setter(strip_option, into))] -pub struct ChatCompletionRequest { - /// ID of the model to use. Currently, only `gpt-3.5-turbo`, `gpt-3.5-turbo-0301` and `gpt-4` - /// are supported. - model: String, - /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction). - messages: Vec, - /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. - /// - /// We generally recommend altering this or `top_p` but not both. - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. - /// - /// We generally recommend altering this or `temperature` but not both. - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - top_p: Option, - /// How many chat completion choices to generate for each input message. - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - n: Option, - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, - /// Up to 4 sequences where the API will stop generating further tokens. - #[builder(default)] - #[serde(skip_serializing_if = "Vec::is_empty")] - stop: Vec, - /// This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend. - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - seed: Option, - /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens). - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, - /// The maximum number of tokens allowed for the generated answer. - /// For reasoning models such as o1 and o3-mini, this does not include reasoning tokens. - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - max_completion_tokens: Option, - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. - /// - /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - presence_penalty: Option, - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - /// - /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - frequency_penalty: Option, - /// Modify the likelihood of specified tokens appearing in the completion. - /// - /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - logit_bias: Option>, - /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). - #[builder(default)] - #[serde(skip_serializing_if = "String::is_empty")] - user: String, - /// Describe functions that ChatGPT can call - /// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt. - /// For example, you can define a function called "get_weather" that returns the weather in a given city - /// - /// [Function calling API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions) - /// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling) - #[builder(default)] - #[serde(skip_serializing_if = "Vec::is_empty")] - functions: Vec, - /// A string or object of the function to call - /// - /// Controls how the model responds to function calls - /// - /// - "none" means the model does not call a function, and responds to the end-user. - /// - "auto" means the model can pick between an end-user or calling a function. - /// - Specifying a particular function via {"name":\ "my_function"} forces the model to call that function. - /// - /// "none" is the default when no functions are present. "auto" is the default if functions are present. - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - function_call: Option, - /// An object specifying the format that the model must output. Compatible with GPT-4 Turbo and all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. - /// Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. - /// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length. - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - response_format: Option, - /// The credentials to use for this request. - #[serde(skip_serializing)] - #[builder(default)] - credentials: Option, - /// Parameters unique to the Venice API. - /// https://docs.venice.ai/api-reference/api-spec - #[builder(default)] - #[serde(skip_serializing_if = "Option::is_none")] - venice_parameters: Option, - /// Whether to store the completion for use in distillation or evals. - #[serde(skip_serializing_if = "Option::is_none")] - #[builder(default)] - pub store: Option, -} - -#[derive(Serialize, Debug, Clone, Eq, PartialEq)] -pub struct VeniceParameters { - pub include_venice_system_prompt: bool, -} - -#[derive(Serialize, Debug, Clone, Eq, PartialEq)] -pub struct ChatCompletionResponseFormat { - /// Must be one of text or json_object (defaults to text) - #[serde(rename = "type")] - typ: String, -} - -impl ChatCompletionResponseFormat { - pub fn json_object() -> Self { - ChatCompletionResponseFormat { - typ: "json_object".to_string(), - } - } - - pub fn text() -> Self { - ChatCompletionResponseFormat { - typ: "text".to_string(), - } - } -} - -impl ChatCompletionGeneric { - pub fn builder( - model: &str, - messages: impl Into>, - ) -> ChatCompletionBuilder { - ChatCompletionBuilder::create_empty() - .model(model) - .messages(messages) - } -} - -#[derive(Serialize, Builder, Debug, Clone, Default)] -#[builder(derive(Clone, Debug, PartialEq))] -#[builder(pattern = "owned")] -#[builder(name = "ChatCompletionMessagesRequestBuilder")] -#[builder(setter(strip_option, into))] -pub struct ChatCompletionMessagesRequest { - #[serde(skip_serializing)] - pub completion_id: String, - - #[builder(default)] - #[serde(skip_serializing)] - pub credentials: Option, - - #[builder(default)] - #[serde(flatten)] - pub pagination: RequestPagination, -} - -/// A list of messages for a chat completion. -#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] -pub struct ChatCompletionMessages { - pub data: Vec, - pub object: String, - pub first_id: Option, - pub last_id: Option, - pub has_more: bool, -} - -impl ChatCompletion { - pub async fn create(request: ChatCompletionRequest) -> ApiResponseOrError { - let credentials_opt = request.credentials.clone(); - openai_post("chat/completions", &request, credentials_opt).await - } - - /// Get a stored completion. - pub async fn get(id: &str, credentials: Credentials) -> ApiResponseOrError { - let route = format!("chat/completions/{}", id); - openai_get(route.as_str(), Some(credentials)).await - } -} - -impl ChatCompletionDelta { - pub async fn create( - request: ChatCompletionRequest, - ) -> Result, CannotCloneRequestError> { - let credentials_opt = request.credentials.clone(); - let stream = openai_request_stream( - Method::POST, - "chat/completions", - |r| r.json(&request), - credentials_opt, - ) - .await?; - let (tx, rx) = channel::(32); - tokio::spawn(forward_deserialized_chat_response_stream(stream, tx)); - Ok(rx) - } - - /// Merges the input delta completion into `self`. - pub fn merge( - &mut self, - other: ChatCompletionDelta, - ) -> Result<(), ChatCompletionDeltaMergeError> { - if other.id.ne(&self.id) { - return Err(ChatCompletionDeltaMergeError::DifferentCompletionIds); - } - for other_choice in other.choices.iter() { - for choice in self.choices.iter_mut() { - if choice.index != other_choice.index { - continue; - } - choice.merge(other_choice)?; - } - } - Ok(()) - } -} - -impl ChatCompletionChoiceDelta { - pub fn merge( - &mut self, - other: &ChatCompletionChoiceDelta, - ) -> Result<(), ChatCompletionDeltaMergeError> { - if self.index != other.index { - return Err(ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices); - } - if self.delta.role.is_none() { - if let Some(other_role) = other.delta.role { - // Set role to other_role. - self.delta.role = Some(other_role); - } - } - if self.delta.name.is_none() { - if let Some(other_name) = &other.delta.name { - // Set name to other_name. - self.delta.name = Some(other_name.clone()); - } - } - // Merge contents. - match self.delta.content.as_mut() { - Some(content) => { - match &other.delta.content { - Some(other_content) => { - // Push other content into this one. - content.push_str(other_content) - } - None => {} - } - } - None => { - match &other.delta.content { - Some(other_content) => { - // Set this content to other content. - self.delta.content = Some(other_content.clone()); - } - None => {} - } - } - }; - - // merge function calls - // function call names are concatenated - // arguments are merged by concatenating them - match self.delta.function_call.as_mut() { - Some(function_call) => { - match &other.delta.function_call { - Some(other_function_call) => { - // push the arguments string of the other function call into this one - match (&mut function_call.arguments, &other_function_call.arguments) { - (Some(function_call), Some(other_function_call)) => { - function_call.push_str(&other_function_call); - } - (None, Some(other_function_call)) => { - function_call.arguments = Some(other_function_call.clone()); - } - _ => {} - } - } - None => {} - } - } - None => { - match &other.delta.function_call { - Some(other_function_call) => { - // Set this content to other content. - self.delta.function_call = Some(other_function_call.clone()); - } - None => {} - } - } - }; - Ok(()) - } -} - -impl From for ChatCompletion { - fn from(delta: ChatCompletionDelta) -> Self { - ChatCompletion { - id: delta.id, - object: delta.object, - created: delta.created, - model: delta.model, - usage: delta.usage, - choices: delta - .choices - .iter() - .map(|choice| ChatCompletionChoice { - index: choice.index, - finish_reason: clone_default_unwrapped_option_string(&choice.finish_reason), - message: ChatCompletionMessage { - role: choice - .delta - .role - .unwrap_or_else(|| ChatCompletionMessageRole::System), - content: choice.delta.content.clone(), - name: choice.delta.name.clone(), - function_call: choice.delta.function_call.clone().map(|f| f.into()), - tool_call_id: None, - tool_calls: Some(Vec::new()), - }, - }) - .collect(), - } - } -} - -impl From for ChatCompletionFunctionCall { - fn from(delta: ChatCompletionFunctionCallDelta) -> Self { - ChatCompletionFunctionCall { - name: delta.name.unwrap_or("".to_string()), - arguments: delta.arguments.unwrap_or_default(), - } - } -} - -impl ChatCompletionMessages { - /// Create a builder for fetching messages for a stored completion. - pub fn builder(completion_id: String) -> ChatCompletionMessagesRequestBuilder { - ChatCompletionMessagesRequestBuilder::create_empty() - .completion_id(completion_id.to_string()) - } - - /// Fetch messages for a stored completion. - pub async fn fetch( - request: ChatCompletionMessagesRequest, - ) -> ApiResponseOrError { - let route = format!("chat/completions/{}/messages", request.completion_id); - let credentials = request.credentials.clone(); - openai_get_with_query(route.as_str(), &request, credentials).await - } +#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ToolChoiceMode { + None, + Auto, + Required, } #[derive(Debug)] @@ -553,50 +54,6 @@ impl std::fmt::Display for ChatCompletionDeltaMergeError { impl std::error::Error for ChatCompletionDeltaMergeError {} -async fn forward_deserialized_chat_response_stream( - mut stream: EventSource, - tx: Sender, -) -> anyhow::Result<()> { - while let Some(event) = stream.next().await { - let event = event?; - match event { - Event::Message(event) => { - let completion = serde_json::from_str::(&event.data)?; - tx.send(completion).await?; - } - _ => {} - } - } - Ok(()) -} - -impl ChatCompletionBuilder { - pub async fn create(self) -> ApiResponseOrError { - ChatCompletion::create(self.build().unwrap()).await - } - - pub async fn create_stream( - mut self, - ) -> Result, CannotCloneRequestError> { - self.stream = Some(Some(true)); - ChatCompletionDelta::create(self.build().unwrap()).await - } -} - -impl ChatCompletionMessagesRequestBuilder { - /// Fetch messages for the specified completion. - pub async fn fetch(self) -> ApiResponseOrError { - ChatCompletionMessages::fetch(self.build().unwrap()).await - } -} - -fn clone_default_unwrapped_option_string(string: &Option) -> String { - match string { - Some(value) => value.clone(), - None => "".to_string(), - } -} - impl Default for ChatCompletionMessageRole { fn default() -> Self { Self::User @@ -606,8 +63,11 @@ impl Default for ChatCompletionMessageRole { #[cfg(test)] mod tests { use super::*; + use crate::{Credentials, RequestPagination}; use dotenvy::dotenv; + use serde_json::Value; use std::time::Duration; + use tokio::sync::mpsc::Receiver; use tokio::time::sleep; #[tokio::test] @@ -619,7 +79,7 @@ mod tests { "gpt-3.5-turbo", [ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some("Hello!".to_string()), + content: Some(Content::new_str("Hello!")), name: None, function_call: None, tool_call_id: None, @@ -642,7 +102,7 @@ mod tests { .content .as_ref() .unwrap(), - "Hello! How can I assist you today?" + &Content::new_str("Hello! How can I assist you today?") ); } @@ -657,10 +117,9 @@ mod tests { "gpt-3.5-turbo", [ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some( - "What type of seed does Mr. England sow in the song? Reply with 1 word." - .to_string(), - ), + content: Some(Content::new_str( + "What type of seed does Mr. England sow in the song? Reply with 1 word.", + )), name: None, function_call: None, tool_call_id: None, @@ -684,7 +143,7 @@ mod tests { .content .as_ref() .unwrap(), - "Love" + &Content::new_str("Love") ); } @@ -697,7 +156,7 @@ mod tests { "gpt-3.5-turbo", [ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some("Hello!".to_string()), + content: Some(Content::new_str("Hello!")), name: None, function_call: None, tool_call_id: None, @@ -721,7 +180,7 @@ mod tests { .content .as_ref() .unwrap(), - "Hello! How can I assist you today?" + &Content::new_str("Hello! How can I assist you today?") ); } @@ -735,7 +194,7 @@ mod tests { [ ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some("What is the weather in Boston?".to_string()), + content: Some(Content::new_str("What is the weather in Boston?")), name: None, function_call: None, tool_call_id: None, @@ -796,51 +255,6 @@ mod tests { ); } - #[tokio::test] - async fn chat_response_format_json() { - dotenv().ok(); - let credentials = Credentials::from_env(); - let chat_completion = ChatCompletion::builder( - "gpt-3.5-turbo", - [ChatCompletionMessage { - role: ChatCompletionMessageRole::User, - content: Some("Write an example JSON for a JWT header using RS256".to_string()), - name: None, - function_call: None, - tool_call_id: None, - tool_calls: Some(Vec::new()), - }], - ) - .temperature(0.0) - .seed(1337u64) - .response_format(ChatCompletionResponseFormat::json_object()) - .credentials(credentials) - .create() - .await - .unwrap(); - let response_string = chat_completion - .choices - .first() - .unwrap() - .message - .content - .as_ref() - .unwrap(); - #[derive(Deserialize, Eq, PartialEq, Debug)] - struct Response { - alg: String, - typ: String, - } - let response = serde_json::from_str::(response_string).unwrap(); - assert_eq!( - response, - Response { - alg: "RS256".to_owned(), - typ: "JWT".to_owned() - } - ); - } - #[test] fn builder_clone_and_eq() { let builder_a = ChatCompletion::builder("gpt-4", []) @@ -881,14 +295,13 @@ mod tests { [ ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some( + content: Some(Content::new_str( "What's 0.9102847*28456? \ reply in plain text, \ round the number to to 2 decimals \ and reply with the result number only, \ - with no full stop at the end" - .to_string(), - ), + with no full stop at the end", + )), name: None, function_call: None, tool_call_id: None, @@ -896,7 +309,7 @@ mod tests { }, ChatCompletionMessage { role: ChatCompletionMessageRole::Assistant, - content: Some("Let me calculate that for you.".to_string()), + content: Some(Content::new_str("Let me calculate that for you.")), name: None, function_call: None, tool_call_id: None, @@ -911,7 +324,7 @@ mod tests { }, ChatCompletionMessage { role: ChatCompletionMessageRole::Tool, - content: Some("the result is 25903.061423199997".to_string()), + content: Some(Content::new_str("the result is 25903.061423199997")), name: None, function_call: None, tool_call_id: Some("the_tool_call".to_owned()), @@ -936,7 +349,7 @@ mod tests { .content .as_ref() .unwrap(), - "25903.06" + &Content::new_str("25903.06") ); } @@ -949,7 +362,7 @@ mod tests { "gpt-3.5-turbo", [ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some("Hello!".to_string()), + content: Some(Content::new_str("Hello!")), ..Default::default() }], ) @@ -987,7 +400,7 @@ mod tests { let user_message = ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some("Tell me a short joke".to_string()), + content: Some(Content::new_str("Tell me a short joke")), ..Default::default() }; @@ -1018,7 +431,7 @@ mod tests { let user_message = ChatCompletionMessage { role: ChatCompletionMessageRole::User, - content: Some("Tell me a short joke".to_string()), + content: Some(Content::new_str("Tell me a short joke")), ..Default::default() }; diff --git a/src/chat/modules.rs b/src/chat/modules.rs new file mode 100644 index 0000000..0c4fc1e --- /dev/null +++ b/src/chat/modules.rs @@ -0,0 +1,316 @@ +use super::{ + requests::ChatCompletionRequest, types::*, utils::forward_deserialized_chat_response_stream, + ChatCompletionDeltaMergeError, ChatCompletionMessageRole, +}; +use crate::{ + openai_get, openai_post, openai_request_stream, ApiResponseOrError, Credentials, Usage, +}; +use reqwest::Method; +use reqwest_eventsource::CannotCloneRequestError; +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc::{channel, Receiver}; + +pub type ChatCompletion = ChatCompletionGeneric; + +/// A delta chat completion, which is streamed token by token. +pub type ChatCompletionDelta = ChatCompletionGeneric; + +#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] +pub struct ChatCompletionGeneric { + #[serde(default)] + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] +pub struct ChatCompletionChoice { + pub index: u64, + pub finish_reason: String, + pub message: ChatCompletionMessage, +} + +#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] +pub struct ChatCompletionChoiceDelta { + pub index: u64, + pub finish_reason: Option, + pub delta: ChatCompletionMessageDelta, +} + +fn is_none_or_empty_vec(opt: &Option>) -> bool { + opt.as_ref().map(|v| v.is_empty()).unwrap_or(true) +} + +#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)] +pub struct ChatCompletionMessage { + /// The role of the author of this message. + pub role: ChatCompletionMessageRole, + /// The contents of the message + /// + /// This is always required for all messages, except for when ChatGPT calls + /// a function. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// The name of the user in a multi-user chat + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// The function that ChatGPT called. This should be "None" usually, and is returned by ChatGPT and not provided by the developer + /// + /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + /// Tool call that this message is responding to. + /// Required if the role is `Tool`. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Tool calls that the assistant is requesting to invoke. + /// Can only be populated if the role is `Assistant`, + /// otherwise it should be empty. + #[serde(skip_serializing_if = "is_none_or_empty_vec")] + pub tool_calls: Option>, +} + +/// Same as ChatCompletionMessage, but received during a response stream. +#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] +pub struct ChatCompletionMessageDelta { + /// The role of the author of this message. + pub role: Option, + /// The contents of the message + pub content: Option, + /// The name of the user in a multi-user chat + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// The function that ChatGPT called + /// + /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + /// Tool call that this message is responding to. + /// Required if the role is `Tool`. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Tool calls that the assistant is requesting to invoke. + /// Can only be populated if the role is `Assistant`, + /// otherwise it should be empty. + #[serde(skip_serializing_if = "is_none_or_empty_vec")] + pub tool_calls: Option>, +} + +impl ChatCompletionChoiceDelta { + pub fn merge( + &mut self, + other: &ChatCompletionChoiceDelta, + ) -> Result<(), ChatCompletionDeltaMergeError> { + if self.index != other.index { + return Err(ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices); + } + if self.delta.role.is_none() { + if let Some(other_role) = other.delta.role { + // Set role to other_role. + self.delta.role = Some(other_role); + } + } + if self.finish_reason.is_none() { + if let Some(other_finish_reason) = &other.finish_reason { + // Set finish_reason to other_finish_reason. + self.finish_reason = Some(other_finish_reason.clone()); + } + } + if self.delta.name.is_none() { + if let Some(other_name) = &other.delta.name { + // Set name to other_name. + self.delta.name = Some(other_name.clone()); + } + } + if self.delta.tool_call_id.is_none() { + if let Some(other_tool_call_id) = &other.delta.tool_call_id { + // Set tool_call_id to other_tool_call_id. + self.delta.tool_call_id = Some(other_tool_call_id.clone()); + } + } + + // Merge contents. + match self.delta.content.as_mut() { + Some(content) => { + match &other.delta.content { + Some(other_content) => { + // Push other content into this one. + // TODO 这边要添加完整的合并逻辑 + if let Content::Str(content) = content { + if let Content::Str(other_content) = other_content { + content.push_str(other_content); + } + } + } + None => {} + } + } + None => { + match &other.delta.content { + Some(other_content) => { + // Set this content to other content. + self.delta.content = Some(other_content.clone()); + } + None => {} + } + } + }; + + // merge function calls + // function call names are concatenated + // arguments are merged by concatenating them + match self.delta.function_call.as_mut() { + Some(function_call) => { + match &other.delta.function_call { + Some(other_function_call) => { + // push the arguments string of the other function call into this one + match (&mut function_call.arguments, &other_function_call.arguments) { + (Some(function_call), Some(other_function_call)) => { + function_call.push_str(&other_function_call); + } + (None, Some(other_function_call)) => { + function_call.arguments = Some(other_function_call.clone()); + } + _ => {} + } + } + None => {} + } + } + None => { + match &other.delta.function_call { + Some(other_function_call) => { + // Set this content to other content. + self.delta.function_call = Some(other_function_call.clone()); + } + None => {} + } + } + }; + + // merge tools + match self.delta.tool_calls.as_mut() { + Some(tool_calls) => { + if let Some(other_tool_calls) = &other.delta.tool_calls { + tool_calls.iter_mut().zip(other_tool_calls).for_each( + |(tool_call, other_tool_call)| { + tool_call.merge(other_tool_call); + }, + ); + } + } + None => { + match &other.delta.tool_calls { + Some(other_tool_calls) => { + // Set this content to other content. + self.delta.tool_calls = Some(other_tool_calls.clone()); + } + None => {} + } + } + }; + Ok(()) + } +} + +impl From for ChatCompletionMessage { + fn from(value: ChatCompletionMessageDelta) -> ChatCompletionMessage { + ChatCompletionMessage { + role: value.role.unwrap_or(ChatCompletionMessageRole::Assistant), + content: value.content, + name: value.name, + function_call: value.function_call.map(ChatCompletionFunctionCall::from), + tool_call_id: value.tool_call_id, + tool_calls: value.tool_calls, + } + } +} + +impl From for ChatCompletion { + fn from(delta: ChatCompletionDelta) -> Self { + ChatCompletion { + id: delta.id, + object: delta.object, + created: delta.created, + model: delta.model, + usage: delta.usage, + choices: delta + .choices + .iter() + .map(|choice| ChatCompletionChoice { + index: choice.index, + finish_reason: clone_default_unwrapped_option_string(&choice.finish_reason), + message: choice.delta.clone().into(), + }) + .collect(), + } + } +} + +impl ChatCompletion { + pub async fn create(request: ChatCompletionRequest) -> ApiResponseOrError { + let credentials_opt = request.credentials.clone(); + openai_post("chat/completions", &request, credentials_opt).await + } + + /// Get a stored completion. + pub async fn get(id: &str, credentials: Credentials) -> ApiResponseOrError { + let route = format!("chat/completions/{}", id); + openai_get(route.as_str(), Some(credentials)).await + } +} + +impl ChatCompletionDelta { + pub async fn create( + request: ChatCompletionRequest, + ) -> Result, CannotCloneRequestError> { + let credentials_opt = request.credentials.clone(); + let stream = openai_request_stream( + Method::POST, + "chat/completions", + move |r| r.json(&request), + credentials_opt, + ) + .await?; + let (tx, rx) = channel::(32); + tokio::spawn(forward_deserialized_chat_response_stream(stream, tx)); + Ok(rx) + } + pub fn merge( + &mut self, + other: ChatCompletionDelta, + ) -> Result<(), ChatCompletionDeltaMergeError> { + if other.id.ne(&self.id) { + return Err(ChatCompletionDeltaMergeError::DifferentCompletionIds); + } + for other_choice in other.choices.iter() { + for choice in self.choices.iter_mut() { + if choice.index != other_choice.index { + continue; + } + choice.merge(other_choice)?; + } + } + Ok(()) + } +} + +/// A list of messages for a chat completion. +#[derive(Deserialize, Clone, Debug, Eq, PartialEq)] +pub struct ChatCompletionMessages { + pub data: Vec, + pub object: String, + pub first_id: Option, + pub last_id: Option, + pub has_more: bool, +} + +fn clone_default_unwrapped_option_string(string: &Option) -> String { + match string { + Some(value) => value.clone(), + None => "".to_string(), + } +} diff --git a/src/chat/requests.rs b/src/chat/requests.rs new file mode 100644 index 0000000..56ca089 --- /dev/null +++ b/src/chat/requests.rs @@ -0,0 +1,208 @@ +use super::{modules::*, types::*, ToolChoiceMode}; +use crate::{openai_get_with_query, ApiResponseOrError, Credentials, RequestPagination}; +use derive_builder::Builder; +use reqwest_eventsource::CannotCloneRequestError; +use serde::Serialize; +use serde_json::Value; +use std::collections::HashMap; +use tokio::sync::mpsc::Receiver; + +#[derive(Serialize, Builder, Debug, Clone)] +#[builder(derive(Clone, Debug, PartialEq))] +#[builder(pattern = "owned")] +#[builder(name = "ChatCompletionBuilder")] +#[builder(setter(strip_option, into))] +pub struct ChatCompletionRequest { + /// ID of the model to use. Currently, only `gpt-3.5-turbo`, `gpt-3.5-turbo-0301` and `gpt-4` + /// are supported. + model: String, + /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction). + messages: Vec, + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// We generally recommend altering this or `top_p` but not both. + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or `temperature` but not both. + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + top_p: Option, + /// How many chat completion choices to generate for each input message. + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + n: Option, + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + stream: Option, + /// Up to 4 sequences where the API will stop generating further tokens. + #[builder(default)] + #[serde(skip_serializing_if = "Vec::is_empty")] + stop: Vec, + /// This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend. + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + seed: Option, + /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens). + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + /// The maximum number of tokens allowed for the generated answer. + /// For reasoning models such as o1 and o3-mini, this does not include reasoning tokens. + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + max_completion_tokens: Option, + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + /// + /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + presence_penalty: Option, + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + /// + /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + frequency_penalty: Option, + /// Modify the likelihood of specified tokens appearing in the completion. + /// + /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + logit_bias: Option>, + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + #[builder(default)] + #[serde(skip_serializing_if = "String::is_empty")] + user: String, + /// A list of tools the model may call during execution. + /// Currently, only function-based tools (`ChatCompletionToolDefinition::Function`) are supported. + /// + /// When tools are provided, the model can choose to call them using the behavior specified by `tool_choice`. + /// If no tools are provided, `tool_choice` defaults to `None`, and the model will not call any tools. + /// + /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools) + #[builder(default)] + #[serde(skip_serializing_if = "Vec::is_empty")] + tools: Vec, + /// Controls how the model responds to tool calls. + /// + /// - "none" means the model will not call any tools and will respond directly to the user. + /// - "auto" means the model can choose between responding directly or calling one or more tools. + /// - "required" means the model must call at least one tool. + /// + /// The default is "none" when no tools are provided, and "auto" when tools are available. + /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice) + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, + /// Describe functions that ChatGPT can call + /// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt. + /// For example, you can define a function called "get_weather" that returns the weather in a given city + /// + /// [Function calling API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions) + /// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling) + #[deprecated(note = "Use tools instead")] + #[builder(default)] + #[serde(skip_serializing_if = "Vec::is_empty")] + functions: Vec, + /// A string or object of the function to call + /// + /// Controls how the model responds to function calls + /// + /// - "none" means the model does not call a function, and responds to the end-user. + /// - "auto" means the model can pick between an end-user or calling a function. + /// - Specifying a particular function via {"name":\ "my_function"} forces the model to call that function. + /// + /// "none" is the default when no functions are present. "auto" is the default if functions are present. + #[deprecated(note = "Use tool_choice instead")] + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + function_call: Option, + /// An object specifying the format that the model must output. Compatible with GPT-4 Turbo and all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. + /// Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + /// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length. + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + response_format: Option, + /// The credentials to use for this request. + #[serde(skip_serializing)] + #[builder(default)] + pub credentials: Option, + /// Parameters unique to the Venice API. + /// https://docs.venice.ai/api-reference/api-spec + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + venice_parameters: Option, + /// Whether to store the completion for use in distillation or evals. + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub store: Option, +} + +impl ChatCompletionBuilder { + pub async fn create(self) -> ApiResponseOrError { + ChatCompletion::create(self.build().unwrap()).await + } + + pub async fn create_stream( + mut self, + ) -> Result, CannotCloneRequestError> { + self.stream = Some(Some(true)); + ChatCompletionDelta::create(self.build().unwrap()).await + } +} + +impl ChatCompletionGeneric { + pub fn builder( + model: &str, + messages: impl Into>, + ) -> ChatCompletionBuilder { + ChatCompletionBuilder::create_empty() + .model(model) + .messages(messages) + } +} + +#[derive(Serialize, Builder, Debug, Clone, Default)] +#[builder(derive(Clone, Debug, PartialEq))] +#[builder(pattern = "owned")] +#[builder(name = "ChatCompletionMessagesRequestBuilder")] +#[builder(setter(strip_option, into))] +pub struct ChatCompletionMessagesRequest { + #[serde(skip_serializing)] + pub completion_id: String, + + #[builder(default)] + #[serde(skip_serializing)] + pub credentials: Option, + + #[builder(default)] + #[serde(flatten)] + pub pagination: RequestPagination, +} + +impl ChatCompletionMessages { + /// Create a builder for fetching messages for a stored completion. + pub fn builder(completion_id: String) -> ChatCompletionMessagesRequestBuilder { + ChatCompletionMessagesRequestBuilder::create_empty() + .completion_id(completion_id.to_string()) + } + + /// Fetch messages for a stored completion. + pub async fn fetch( + request: ChatCompletionMessagesRequest, + ) -> ApiResponseOrError { + let route = format!("chat/completions/{}/messages", request.completion_id); + let credentials = request.credentials.clone(); + openai_get_with_query(route.as_str(), &request, credentials).await + } +} + +impl ChatCompletionMessagesRequestBuilder { + /// Fetch messages for the specified completion. + pub async fn fetch(self) -> ApiResponseOrError { + ChatCompletionMessages::fetch(self.build().unwrap()).await + } +} diff --git a/src/chat/types.rs b/src/chat/types.rs new file mode 100644 index 0000000..3d5c675 --- /dev/null +++ b/src/chat/types.rs @@ -0,0 +1,191 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum Content { + Str(String), + Object(Vec), +} + +impl Display for Content { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Content::Str(s) => write!(f, "{}", s), + Content::Object(m) => write!(f, "{:?}", m), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Message { + Text { text: String }, + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct ImageUrl { + url: String, +} + +impl Content { + pub fn new_str(s: &str) -> Content { + Content::Str(s.to_string()) + } + + pub fn new_object(m: Message) -> Content { + Content::Object(vec![m]) + } + + pub fn new_text(text: &str) -> Content { + Content::Object(vec![Message::Text { + text: text.to_string(), + }]) + } + + pub fn new_image_url(url: &str) -> Content { + Content::Object(vec![Message::ImageUrl { + image_url: ImageUrl { + url: url.to_string(), + }, + }]) + } +} + +#[macro_export] +macro_rules! new_content { + ($($json:tt)+) => {{ + use serde_json::Value; + use $crate::chat::types::Content; + match serde_json::json!($($json)+) { + Value::String(s) => Content::new_str(&s), + Value::Array(a) => serde_json::from_value::(Value::Array(a)).expect("Failed to parse array as Content"), + Value::Object(o) => serde_json::from_value(Value::Array(vec![Value::Object(o)])).expect("Failed to parse object as Content"), + _ => panic!("Invalid Content"), + } + }}; +} + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +pub struct ToolCall { + /// The ID of the tool call. + pub id: String, + /// The type of the tool. Currently, only `function` is supported. + pub r#type: String, + /// The function that the model called. + pub function: ToolCallFunction, +} + +impl ToolCall { + pub fn merge(&mut self, other: &ToolCall) { + if self.id.is_empty() { + self.id = other.id.clone(); + } + if self.r#type.is_empty() { + self.r#type = other.r#type.clone(); + } + self.function.merge(&other.function); + } +} + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +pub struct ToolCallFunction { + /// The name of the function to call. + #[serde(default)] + pub name: String, + /// The arguments to call the function with, as generated by the model in + /// JSON format. + /// Note that the model does not always generate valid JSON, and may + /// hallucinate parameters not defined by your function schema. + /// Validate the arguments in your code before calling your function. + #[serde(default)] + pub arguments: String, +} + +impl ToolCallFunction { + pub fn merge(&mut self, other: &ToolCallFunction) { + if self.name.is_empty() { + self.name = other.name.clone(); + } + self.arguments.push_str(&other.arguments); + } +} + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +pub struct ChatCompletionFunctionCall { + /// The name of the function ChatGPT called + pub name: String, + /// The arguments that ChatGPT called (formatted in JSON) + /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) + pub arguments: String, +} + +/// Same as ChatCompletionFunctionCall, but received during a response stream. +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +pub struct ChatCompletionFunctionCallDelta { + /// The name of the function ChatGPT called + pub name: Option, + /// The arguments that ChatGPT called (formatted in JSON) + /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) + pub arguments: Option, +} + +impl From for ChatCompletionFunctionCall { + fn from(delta: ChatCompletionFunctionCallDelta) -> Self { + ChatCompletionFunctionCall { + name: delta.name.unwrap_or("".to_string()), + arguments: delta.arguments.unwrap_or_default(), + } + } +} + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChatCompletionToolDefinition { + Function { + function: ChatCompletionFunctionDefinition, + }, +} + +#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] +pub struct ChatCompletionFunctionDefinition { + /// The name of the function + pub name: String, + /// The description of the function + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// The parameters of the function formatted in JSON Schema + /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-parameters) + /// [See more information about JSON Schema.](https://json-schema.org/understanding-json-schema/) + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, +} + +#[derive(Serialize, Debug, Clone, Eq, PartialEq)] +pub struct ChatCompletionResponseFormat { + /// Must be one of text or json_object (defaults to text) + #[serde(rename = "type")] + typ: String, +} + +impl ChatCompletionResponseFormat { + pub fn json_object() -> Self { + ChatCompletionResponseFormat { + typ: "json_object".to_string(), + } + } + + pub fn text() -> Self { + ChatCompletionResponseFormat { + typ: "text".to_string(), + } + } +} + +#[derive(Serialize, Debug, Clone, Eq, PartialEq)] +pub struct VeniceParameters { + pub include_venice_system_prompt: bool, +} diff --git a/src/chat/utils.rs b/src/chat/utils.rs new file mode 100644 index 0000000..7fc8e54 --- /dev/null +++ b/src/chat/utils.rs @@ -0,0 +1,37 @@ +use futures_util::TryStreamExt; +use reqwest_eventsource::{Event, EventSource}; +use tokio::sync::mpsc::Sender; +use tracing::warn; + +use super::modules::ChatCompletionDelta; + +pub async fn forward_deserialized_chat_response_stream( + stream: EventSource, + tx: Sender, +) -> anyhow::Result<()> { + stream + .try_for_each(async |event| { + match event { + Event::Message(event) => { + match serde_json::from_str::(&event.data) { + Ok(completion) => { + if tx.send(completion).await.is_err() { + warn!("Failed to send completion delta: channel closed"); + } + } + Err(e) => { + warn!( + "Failed to deserialize ChatCompletionDelta from JSON data '{}': {}", + &event.data, e + ); + } + } + } + _ => {} + } + Ok::<_, reqwest_eventsource::Error>(()) + }) + .await?; + drop(tx); + Ok(()) +} diff --git a/test_data/file_upload_test1.jsonl b/test_data/file_upload_test1.jsonl deleted file mode 100644 index d2c6770..0000000 --- a/test_data/file_upload_test1.jsonl +++ /dev/null @@ -1,3 +0,0 @@ -{"prompt": "example data: the most correct data\n###\n", "completion": "yes"} -{"prompt": "example data: totally wrong data\n###\n", "completion": "no"} -{"prompt": "example data: very correct data\n###\n", "completion": "yes"} \ No newline at end of file