From 92e63d15a59cb865fbeb59f3e3baa7c7a4945254 Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Wed, 8 Jan 2025 21:35:41 +0000 Subject: [PATCH 1/9] Add assistants, threads, runs, and messages APIs (assistants beta v2) --- Cargo.toml | 12 +- src/assistants/assistants.rs | 141 +++++++++++++++++ src/assistants/messages.rs | 116 ++++++++++++++ src/assistants/mod.rs | 6 + src/assistants/runs.rs | 288 +++++++++++++++++++++++++++++++++++ src/assistants/threads.rs | 15 ++ src/client.rs | 173 +++++++++++++++++++++ src/lib.rs | 15 +- 8 files changed, 761 insertions(+), 5 deletions(-) create mode 100644 src/assistants/assistants.rs create mode 100644 src/assistants/messages.rs create mode 100644 src/assistants/mod.rs create mode 100644 src/assistants/runs.rs create mode 100644 src/assistants/threads.rs create mode 100644 src/client.rs diff --git a/Cargo.toml b/Cargo.toml index 9523548..1a504b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,13 +11,21 @@ keywords = ["ai", "machine-learning", "openai", "library"] [dependencies] serde_json = "1.0.94" derive_builder = "0.20.0" -reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"], optional = true } +reqwest = { version = "0.12", default-features = false, features = [ + "json", + "stream", + "multipart", +], optional = true } serde = { version = "1.0.157", features = ["derive"] } reqwest-eventsource = "0.6" tokio = { version = "1.26.0", features = ["full"] } -anyhow = "1.0.70" +anyhow = "1.0" futures-util = "0.3.28" bytes = "1.4.0" +schemars = "0.8" +either = { version = "1.8.1", features = ["serde"] } +serde-double-tag = "0.0.4" +log = "0.4" [dev-dependencies] dotenvy = "0.15.7" diff --git a/src/assistants/assistants.rs b/src/assistants/assistants.rs new file mode 100644 index 0000000..98fd2fb --- /dev/null +++ b/src/assistants/assistants.rs @@ -0,0 +1,141 @@ +use std::collections::HashMap; + +use schemars::schema::RootSchema; +use serde::{Deserialize, Serialize}; + +use crate::{client::{Empty, OpenAiClient}, ApiResponseOrError}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Assistant { + pub id: String, + pub object: String, + pub created_at: u32, + /// The name of the assistant. The maximum length is 256 characters. + pub name: Option, + /// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them. + pub model: String, + /// The system instructions that the assistant uses. The maximum length is 256,000 characters. + pub instructions: Option, + pub tools: Vec, + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. + pub tool_resources: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, + /// The default model to use for this assistant. + pub response_format: Option, +} + +#[derive(Debug, Clone, serde_double_tag::Deserialize, serde_double_tag::Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum Tool { + CodeInterpreter, + Function(Function), + FileSearch(FileSearch), +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Function { + pub name: String, + pub description: String, + pub parameters: RootSchema, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FunctionParameters { + pub title: String, + pub description: String, + #[serde(rename = "type")] + pub type_: String, + pub required: Vec, + pub properties: HashMap, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FunctionProperty { + pub description: String, + #[serde(rename = "type")] + pub type_: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FileSearch { + pub max_num_results: usize, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ToolResources { + pub code_interpreter: Option, + pub file_search: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CodeInterpreterResources { + /// A list of file IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool. + pub file_ids: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FileSearchResources { + /// The ID of the vector store attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + pub vector_store_ids: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ResponseFormat { + Auto, +} + +#[derive(Serialize, Default, Debug, Clone)] +pub struct CreateAssistantRequest { + /// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them. + pub model: String, + + /// The name of the assistant. The maximum length is 256 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// The description of the assistant. The maximum length is 256 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// The system instructions that the assistant uses. The maximum length is 256,000 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + /// A set of tools that the assistant can use. + pub tools: Vec, + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_resources: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + /// The default model to use for this assistant. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, +} + +impl OpenAiClient { + pub async fn create_assistant( + &self, + request: CreateAssistantRequest, + ) -> ApiResponseOrError { + self.post("assistants", Some(request)).await + } + + pub async fn get_assistant(&self, assistant_id: &str) -> ApiResponseOrError { + self.get(format!("assistants/{}", assistant_id)).await + } + + pub async fn delete_assistant(&self, assistant_id: &str) -> ApiResponseOrError { + self.delete(format!("assistants/{}", assistant_id)).await + } + + pub async fn update_assistant( + &self, + assistant_id: &str, + request: CreateAssistantRequest, + ) -> ApiResponseOrError { + self.post(format!("assistants/{}", assistant_id), Some(request)) + .await + } +} diff --git a/src/assistants/messages.rs b/src/assistants/messages.rs new file mode 100644 index 0000000..4c001f5 --- /dev/null +++ b/src/assistants/messages.rs @@ -0,0 +1,116 @@ +use crate::{assistants::Tool, client::OpenAiClient, ApiResponseOrError}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Message { + pub id: String, + pub object: String, + pub created_at: u32, + /// The thread ID that this message belongs to. + pub thread_id: String, + /// The status of the message, which can be either in_progress, incomplete, or completed. + pub status: Option, + /// On an incomplete message, details about why the message is incomplete. + pub incomplete_details: Option, + /// The Unix timestamp (in seconds) for when the message was completed. + pub completed_at: Option, + /// The Unix timestamp (in seconds) for when the message was marked as incomplete. + pub incomplete_at: Option, + /// The entity that produced the message. One of user or assistant + pub role: Role, + /// The content of the message. + pub content: Vec, + /// The assistant that produced the message. + pub assistant_id: Option, + /// The ID of the run associated with the creation of this message. Value is null when messages are created manually using the create message or create thread endpoints. + pub run_id: Option, + /// A list of files attached to the message. + pub attachments: Option>, + /// A set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum Status { + InProgress, + Incomplete, + Completed, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct IncompleteDetails { + pub reason: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum Role { + User, + Assistant, +} + +#[derive(Debug, serde_double_tag::Serialize, serde_double_tag::Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum Content { + Text(Text), + ImageFile(ImageFile), + ImageUrl(ImageUrl), + Refusal(Refusal), +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct Text { + pub value: String, + pub annotations: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Annotation { + #[serde(rename = "type")] + pub kind: String, + pub text: String, + pub start_index: u32, + pub end_index: u32, + pub file_citation: FileCitation, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FileCitation { + pub file_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageFile { + pub file_id: String, + pub detail: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageUrl { + pub image_url: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Refusal { + pub refusal: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Attachment { + pub file_id: String, + pub tools: Tool, +} + +impl OpenAiClient { + pub async fn list_messages( + &self, + thread_id: &str, + after_id: Option, + ) -> ApiResponseOrError> { + self.list(format!("threads/{thread_id}/messages"), after_id) + .await + } +} diff --git a/src/assistants/mod.rs b/src/assistants/mod.rs new file mode 100644 index 0000000..b72a9cf --- /dev/null +++ b/src/assistants/mod.rs @@ -0,0 +1,6 @@ +pub mod assistants; +pub use assistants::*; + +pub mod messages; +pub mod runs; +pub mod threads; diff --git a/src/assistants/runs.rs b/src/assistants/runs.rs new file mode 100644 index 0000000..de220c9 --- /dev/null +++ b/src/assistants/runs.rs @@ -0,0 +1,288 @@ +use derive_builder::Builder; +use either::Either; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::{assistants::Tool, chat::ToolCall, client::OpenAiClient, ApiResponseOrError}; + +use super::{ + messages::{Attachment, IncompleteDetails, Role}, + ResponseFormat, ToolResources, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Run { + pub id: String, + pub object: String, + pub created_at: u32, + /// The ID of the assistant used for this run. + pub assistant_id: String, + /// The ID of the thread associated with this run. + pub thread_id: String, + /// The status of the run. + pub status: Status, + /// Details on the action required to continue the run. Will be null if no action is required. + pub required_action: Option, + + /// The last error that occurred during this run. + pub last_error: Option, + + /// The time at which the run will expire. + pub expires_at: Option, + /// The time at which the run was started. + pub started_at: Option, + /// The time at which the run was completed. + pub completed_at: Option, + /// The time at which the run was cancelled. + pub cancelled_at: Option, + /// The time at which the run was failed. + pub failed_at: Option, + /// The time at which the run was incomplete. + pub incomplete_details: Option, + + /// The model used for this run. + pub model: String, + + /// The instructions given to the assistant. + pub instructions: String, + + /// The tools used for this run. + pub tools: Vec, + + /// The usage of the run. + pub usage: Option, + + /// The truncation strategy used for this run. + pub truncation_strategy: Option, + + /// Whether to run tool calls in parallel. + pub parallel_tool_calls: bool, + + /// The tool choice used for this run. + pub tool_choice: ToolChoice, + + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum Status { + Queued, + InProgress, + RequiresAction, + Cancelling, + Cancelled, + Failed, + Completed, + Incomplete, + Expired, +} + +impl Status { + pub fn is_terminal(&self) -> bool { + !matches!(self, Status::InProgress | Status::Queued) + } +} + +#[derive(Debug, serde_double_tag::Deserialize, serde_double_tag::Serialize, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum RequiredAction { + SubmitToolOutputs { tool_calls: Vec }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct LastError { + pub code: String, + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct TruncationStrategy { + #[serde(rename = "type")] + pub kind: String, + pub last_messages: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(transparent)] +pub struct ToolChoice { + #[serde(with = "either::serde_untagged")] + pub inner: Either, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceStrategy { + None, + Auto, + Required, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceFunction { + FileSearch, + Function { name: String }, +} + +#[derive(Serialize, Builder, Debug, Clone, Default)] +#[builder(pattern = "owned")] +#[builder(name = "CreateThreadRunBuilder")] +#[builder(setter(strip_option, into))] +pub struct CreateThreadRunRequest { + /// ID of the assistant to use. + pub assistant_id: String, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub model: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tools: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tool_resources: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub metadata: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub parallel_tool_calls: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub response_format: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tool_choice: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub max_completion_tokens: Option, + + /// the thread to create + pub thread: CreateThreadRequest, +} + +#[derive(Serialize, Builder, Debug, Clone, Default)] +#[builder(pattern = "owned")] +#[builder(name = "CreateThreadBuilder")] +#[builder(setter(strip_option, into))] +pub struct CreateThreadRequest { + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tool_resources: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub metadata: Option>, +} + +#[derive(Serialize, Builder, Debug, Clone)] +#[builder(pattern = "owned")] +#[builder(name = "CreateThreadMessageBuilder")] +#[builder(setter(strip_option, into))] +pub struct CreateThreadMessageRequest { + pub role: Role, + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub attachments: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SubmitToolOutputsRequest { + pub tool_outputs: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ToolOutput { + pub tool_call_id: String, + pub output: String, +} + +#[derive(Serialize, Builder, Debug, Clone, Default)] +#[builder(pattern = "owned")] +#[builder(name = "CreateRunBuilder")] +#[builder(setter(strip_option, into))] +pub struct CreateRunRequest { + pub assistant_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub additional_messages: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub max_completion_tokens: Option, +} + +impl OpenAiClient { + pub async fn create_thread_run( + &self, + request: CreateThreadRunRequest, + ) -> ApiResponseOrError { + self.post(format!("threads/runs"), Some(request)).await + } + + pub async fn create_run( + &self, + thread_id: &str, + request: CreateRunRequest, + ) -> ApiResponseOrError { + self.post(format!("threads/{thread_id}/runs"), Some(request)) + .await + } + + pub async fn poll_run(&self, mut run: Run) -> ApiResponseOrError { + while !run.status.is_terminal() { + run = self + .get_run(run.thread_id.as_str(), run.id.as_str()) + .await?; + } + Ok(run) + } + + pub async fn get_run(&self, thread_id: &str, run_id: &str) -> ApiResponseOrError { + self.get(format!("threads/{thread_id}/runs/{run_id}")).await + } + + pub async fn submit_tool_outputs_and_poll( + &self, + run: Run, + request: SubmitToolOutputsRequest, + ) -> ApiResponseOrError { + let run: Run = self + .post( + format!( + "threads/{}/runs/{}/submit_tool_outputs", + run.thread_id, run.id + ), + Some(request), + ) + .await?; + + self.poll_run(run).await + } +} diff --git a/src/assistants/threads.rs b/src/assistants/threads.rs new file mode 100644 index 0000000..f4e8d6f --- /dev/null +++ b/src/assistants/threads.rs @@ -0,0 +1,15 @@ +use serde::Deserialize; +use std::collections::HashMap; + +use crate::assistants::ToolResources; + +#[derive(Debug, Deserialize, Clone)] +pub struct Thread { + pub id: String, + pub object: String, + pub created_at: u32, + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. + pub tool_resources: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..25ad858 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,173 @@ +use std::str::FromStr; + +use crate::{ApiResponseOrError, Credentials, OpenAiError, DEFAULT_CREDENTIALS}; +use anyhow::Result; +use reqwest::{ + header::{HeaderName, HeaderValue, AUTHORIZATION}, + Client, Method, Response, +}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +#[derive(Clone)] +pub struct OpenAiClient { + credentials: Credentials, + client: Client, +} + +impl std::fmt::Debug for OpenAiClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "OpenAiClient") + } +} + +#[derive(Debug, Clone, Deserialize)] +struct OpenAiErrorWrapper { + error: OpenAiError, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Empty {} + +impl OpenAiClient { + pub fn default() -> Result { + Self::new(DEFAULT_CREDENTIALS.read().unwrap().clone()) + } + + pub fn new(credentials: Credentials) -> Result { + let client = Client::builder() + .default_headers( + [ + ( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {}", credentials.api_key))?, + ), + ( + HeaderName::from_str("OpenAI-Beta")?, + HeaderValue::from_str("assistants=v2")?, + ), + ] + .into_iter() + .collect(), + ) + .build()?; + + Ok(Self { + credentials, + client, + }) + } + + async fn request_inner( + &self, + method: Method, + route: R, + body: Option, + ) -> Result + where + R: Into, + S: Serialize, + { + let url = format!("{}{}", self.credentials.base_url, route.into()); + log::debug!("OpenAI Request[{}] {}", method.to_string(), url); + + let mut request = self.client.request(method.clone(), url.clone()); + + if let Some(body) = body { + request = request.json(&body); + } + + let response = request.send().await?; + + log::debug!( + "OpenAI Response[{}] {} {url}", + method.to_string(), + response.status().as_str() + ); + Ok(response) + } + + pub async fn request( + &self, + method: Method, + route: R, + body: Option, + ) -> ApiResponseOrError + where + R: Into, + S: Serialize, + T: DeserializeOwned, + { + let response = self.request_inner(method, route, body).await?; + let api_response = if response.status().is_success() { + response.json::().await? + } else { + let result = response.text().await?; + if let Ok(api_response) = serde_json::from_str::(&result) { + return Err(api_response.error); + } else { + return Err(OpenAiError::new(result, "unknown".to_string())); + } + }; + + Ok(api_response) + } + pub async fn get(&self, route: R) -> ApiResponseOrError + where + R: Into, + T: DeserializeOwned, + { + self.request::<(), R, T>(Method::GET, route, None).await + } + + pub async fn post(&self, route: R, body: S) -> ApiResponseOrError + where + R: Into, + S: Serialize, + T: DeserializeOwned, + { + self.request(Method::POST, route, Some(body)).await + } + + pub async fn delete(&self, route: R) -> ApiResponseOrError + where + R: Into, + { + self.request::<(), R, Empty>(Method::DELETE, route, None) + .await + } + + pub async fn list(&self, route: R, after: Option) -> ApiResponseOrError> + where + R: Into, + T: DeserializeOwned + std::fmt::Debug, + { + let mut route = if let Some(after) = after { + format!("{}?order=asc&after={after}", route.into()) + } else { + format!("{}?order=asc", route.into()) + }; + + let mut has_more = true; + let mut data = Vec::new(); + + while has_more { + let list: List = self.get(&route).await?; + data.extend(list.data); + has_more = list.has_more; + route = format!( + "{route}?order=asc&after={}", + list.last_id.unwrap_or_default() + ); + } + + Ok(data) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct List { + pub first_id: Option, + pub last_id: Option, + pub data: Vec, + pub has_more: bool, +} diff --git a/src/lib.rs b/src/lib.rs index fd24555..1a44d03 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,9 @@ use std::env; use std::env::VarError; use std::sync::{LazyLock, RwLock}; +pub mod assistants; pub mod chat; +pub mod client; pub mod completions; pub mod edits; pub mod embeddings; @@ -61,7 +63,7 @@ impl Credentials { #[derive(Deserialize, Debug, Clone, Eq, PartialEq)] pub struct OpenAiError { - pub message: String, + pub message: Option, #[serde(rename = "type")] pub error_type: String, pub param: Option, @@ -71,7 +73,7 @@ pub struct OpenAiError { impl OpenAiError { fn new(message: String, error_type: String) -> OpenAiError { OpenAiError { - message, + message: Some(message), error_type, param: None, code: None, @@ -81,7 +83,12 @@ impl OpenAiError { impl std::fmt::Display for OpenAiError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(&self.message) + f.write_str( + &self + .message + .as_ref() + .unwrap_or(&"empty error message".to_string()), + ) } } @@ -105,6 +112,7 @@ pub type ApiResponseOrError = Result; impl From for OpenAiError { fn from(value: reqwest::Error) -> Self { + println!("{:?}", &value); OpenAiError::new(value.to_string(), "reqwest".to_string()) } } @@ -150,6 +158,7 @@ where let mut request = client.request(method, format!("{}{route}", credentials.base_url)); request = builder(request); let response = request + .header("OpenAI-Beta", "assistants=v2") .header(AUTHORIZATION, format!("Bearer {}", credentials.api_key)) .send() .await?; From 267571ff14fc464a1a58efd4f9564b557b176b62 Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Sat, 11 Jan 2025 18:46:25 +0000 Subject: [PATCH 2/9] remove print --- src/lib.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1a44d03..c98f645 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -112,8 +112,7 @@ pub type ApiResponseOrError = Result; impl From for OpenAiError { fn from(value: reqwest::Error) -> Self { - println!("{:?}", &value); - OpenAiError::new(value.to_string(), "reqwest".to_string()) + OpenAiError::new(format!("{:?}", value), "reqwest".to_string()) } } From b12998e43399a3fe5b376abb3efd70f5afa0878f Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Wed, 15 Jan 2025 22:04:16 +0000 Subject: [PATCH 3/9] support for files --- Cargo.toml | 2 ++ src/assistants/assistants.rs | 11 +++--- src/assistants/mod.rs | 2 ++ src/client.rs | 66 ++++++++++++++++++++++++++++-------- src/files.rs | 2 +- 5 files changed, 63 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1a504b5..0c7eb05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,8 @@ schemars = "0.8" either = { version = "1.8.1", features = ["serde"] } serde-double-tag = "0.0.4" log = "0.4" +strum = { version = "0.26", features = ["derive"] } +strum_macros = "0.26" [dev-dependencies] dotenvy = "0.15.7" diff --git a/src/assistants/assistants.rs b/src/assistants/assistants.rs index 98fd2fb..e5a4e6d 100644 --- a/src/assistants/assistants.rs +++ b/src/assistants/assistants.rs @@ -3,7 +3,10 @@ use std::collections::HashMap; use schemars::schema::RootSchema; use serde::{Deserialize, Serialize}; -use crate::{client::{Empty, OpenAiClient}, ApiResponseOrError}; +use crate::{ + client::{Empty, OpenAiClient}, + ApiResponseOrError, +}; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Assistant { @@ -58,12 +61,12 @@ pub struct FunctionProperty { pub type_: String, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, Default)] pub struct FileSearch { - pub max_num_results: usize, + pub max_num_results: Option, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, Default)] pub struct ToolResources { pub code_interpreter: Option, pub file_search: Option, diff --git a/src/assistants/mod.rs b/src/assistants/mod.rs index b72a9cf..651476d 100644 --- a/src/assistants/mod.rs +++ b/src/assistants/mod.rs @@ -1,6 +1,8 @@ pub mod assistants; pub use assistants::*; +pub mod files; pub mod messages; pub mod runs; pub mod threads; +pub mod vector_stores; diff --git a/src/client.rs b/src/client.rs index 25ad858..bcf2ba9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,7 +4,8 @@ use crate::{ApiResponseOrError, Credentials, OpenAiError, DEFAULT_CREDENTIALS}; use anyhow::Result; use reqwest::{ header::{HeaderName, HeaderValue, AUTHORIZATION}, - Client, Method, Response, + multipart::Form, + Client, Method, RequestBuilder, Response, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -28,6 +29,21 @@ struct OpenAiErrorWrapper { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Empty {} +enum RequestBody { + Json(S), + Multipart(Form), + None, +} + +impl From> for RequestBody { + fn from(value: Option) -> Self { + match value { + Some(value) => RequestBody::Json(value), + None => RequestBody::None, + } + } +} + impl OpenAiClient { pub fn default() -> Result { Self::new(DEFAULT_CREDENTIALS.read().unwrap().clone()) @@ -57,47 +73,58 @@ impl OpenAiClient { }) } + fn request_builder(&self, method: Method, route: R) -> RequestBuilder + where + R: Into, + { + let url = format!("{}{}", self.credentials.base_url, route.into()); + log::debug!("OpenAI Request[{}] {}", method.to_string(), url); + + self.client.request(method.clone(), url.clone()) + } + async fn request_inner( &self, method: Method, route: R, - body: Option, + body: RequestBody, ) -> Result where R: Into, S: Serialize, { - let url = format!("{}{}", self.credentials.base_url, route.into()); - log::debug!("OpenAI Request[{}] {}", method.to_string(), url); - - let mut request = self.client.request(method.clone(), url.clone()); + let mut request = self.request_builder(method.clone(), route); - if let Some(body) = body { - request = request.json(&body); + match body { + RequestBody::Json(body) => request = request.json(&body), + RequestBody::Multipart(body) => request = request.multipart(body), + RequestBody::None => (), } let response = request.send().await?; log::debug!( - "OpenAI Response[{}] {} {url}", + "OpenAI Response[{}] {} {}", method.to_string(), - response.status().as_str() + response.status().as_str(), + response.url() ); Ok(response) } - pub async fn request( + async fn request( &self, method: Method, route: R, - body: Option, + body: B, ) -> ApiResponseOrError where R: Into, + B: Into>, S: Serialize, T: DeserializeOwned, { - let response = self.request_inner(method, route, body).await?; + let response = self.request_inner(method, route, body.into()).await?; let api_response = if response.status().is_success() { response.json::().await? } else { @@ -116,7 +143,7 @@ impl OpenAiClient { R: Into, T: DeserializeOwned, { - self.request::<(), R, T>(Method::GET, route, None).await + self.request::<_, (), R, T>(Method::GET, route, None).await } pub async fn post(&self, route: R, body: S) -> ApiResponseOrError @@ -128,11 +155,20 @@ impl OpenAiClient { self.request(Method::POST, route, Some(body)).await } + pub async fn post_multipart(&self, route: R, form: Form) -> ApiResponseOrError + where + R: Into, + T: DeserializeOwned, + { + self.request::<_, (), R, T>(Method::POST, route, RequestBody::Multipart(form)) + .await + } + pub async fn delete(&self, route: R) -> ApiResponseOrError where R: Into, { - self.request::<(), R, Empty>(Method::DELETE, route, None) + self.request::<_, (), R, Empty>(Method::DELETE, route, None) .await } diff --git a/src/files.rs b/src/files.rs index 7c167fa..5b1c6b7 100644 --- a/src/files.rs +++ b/src/files.rs @@ -351,7 +351,7 @@ mod tests { assert_eq!(openapi_err.error_type, "io"); assert_eq!( openapi_err.message, - "No such file or directory (os error 2)" + Some("No such file or directory (os error 2)".to_string()) ) } From 411c58191fc6aee8c9aa94d7ba5bf0517e981f22 Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Wed, 15 Jan 2025 22:04:32 +0000 Subject: [PATCH 4/9] files and vs --- src/assistants/files.rs | 49 +++++++++++++++++ src/assistants/vector_stores.rs | 95 +++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+) create mode 100644 src/assistants/files.rs create mode 100644 src/assistants/vector_stores.rs diff --git a/src/assistants/files.rs b/src/assistants/files.rs new file mode 100644 index 0000000..6c95111 --- /dev/null +++ b/src/assistants/files.rs @@ -0,0 +1,49 @@ +use crate::{client::OpenAiClient, ApiResponseOrError}; +use reqwest::{ + multipart::{Form, Part}, + Body, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct File { + pub id: String, + pub object: String, + pub created_at: u32, + pub bytes: u32, + pub filename: String, + pub purpose: FilePurpose, +} + +#[derive(Debug, Serialize, Deserialize, Clone, strum_macros::Display)] +#[strum(serialize_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum FilePurpose { + Assistants, + AssistantsOutput, + Batch, + BatchOutput, + FineTune, + FineTuneResults, + Vision, +} + +impl OpenAiClient { + pub async fn upload_file>( + &self, + filename: &str, + mime_type: &str, + bytes: B, + purpose: FilePurpose, + ) -> ApiResponseOrError { + let file_part = Part::stream(bytes) + .file_name(filename.to_string()) + .mime_str(mime_type)?; + + let form = Form::new() + .part("file", file_part) + .text("purpose", purpose.to_string()); + + self.post_multipart("files", form).await + } +} diff --git a/src/assistants/vector_stores.rs b/src/assistants/vector_stores.rs new file mode 100644 index 0000000..c822cd0 --- /dev/null +++ b/src/assistants/vector_stores.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; + +use crate::{client::OpenAiClient, ApiResponseOrError}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct VectorStore { + pub id: String, + pub object: String, + pub created_at: u32, + pub name: String, + pub usage_bytes: u32, + pub file_counts: FileCounts, + pub status: VectorStoreStatus, + pub expires_after: Option, + pub expires_at: Option, + pub last_active_at: Option, + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FileCounts { + pub in_progress: u32, + pub completed: u32, + pub failed: u32, + pub cancelled: u32, + pub total: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone, strum_macros::Display)] +#[strum(serialize_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum VectorStoreStatus { + Expired, + InProgress, + Completed, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ExpiresAfter { + pub anchor: String, + pub days: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct CreateVectorStoreRequest { + pub name: String, + pub file_ids: Option>, + pub metadata: Option>, + pub expires_after: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct VectorStoreFile { + pub id: String, + pub object: String, + pub created_at: u32, + pub file_id: String, + pub vector_store_id: String, + pub usage_bytes: u32, + pub status: VectorStoreFileStatus, + pub last_error: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, strum_macros::Display)] +#[strum(serialize_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum VectorStoreFileStatus { + InProgress, + Completed, + Cancelled, + Failed, +} + +impl OpenAiClient { + pub async fn create_vector_store( + &self, + params: CreateVectorStoreRequest, + ) -> ApiResponseOrError { + self.post("vector_stores", params).await + } + + pub async fn attach_file_to_vector_store( + &self, + vector_store_id: &str, + file_id: &str, + ) -> ApiResponseOrError { + self.post( + &format!("vector_stores/{}/files", vector_store_id), + json!({ file_id: file_id }), + ) + .await + } +} From d4bff90d3c7c533fe69fdfee23139112c937b214 Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Mon, 20 Jan 2025 20:43:21 +0000 Subject: [PATCH 5/9] make fields public --- src/chat.rs | 47 ++++++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/src/chat.rs b/src/chat.rs index 270bdae..25047aa 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,7 +1,7 @@ //! Given a chat conversation, the model will return a chat completion response. use super::{openai_post, ApiResponseOrError, Credentials, Usage}; -use crate::openai_request_stream; +use crate::{client::OpenAiClient, openai_request_stream}; use derive_builder::Builder; use futures_util::StreamExt; use reqwest::Method; @@ -166,7 +166,7 @@ pub enum ChatCompletionMessageRole { Tool, } -#[derive(Serialize, Builder, Debug, Clone)] +#[derive(Serialize, Builder, Debug, Clone, Default)] #[builder(derive(Clone, Debug, PartialEq))] #[builder(pattern = "owned")] #[builder(name = "ChatCompletionBuilder")] @@ -174,62 +174,62 @@ pub enum ChatCompletionMessageRole { pub struct ChatCompletionRequest { /// ID of the model to use. Currently, only `gpt-3.5-turbo`, `gpt-3.5-turbo-0301` and `gpt-4` /// are supported. - model: String, + pub model: String, /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction). - messages: Vec, + pub messages: Vec, /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. /// /// We generally recommend altering this or `top_p` but not both. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, + pub temperature: Option, /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. /// /// We generally recommend altering this or `temperature` but not both. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - top_p: Option, + pub top_p: Option, /// How many chat completion choices to generate for each input message. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - n: Option, + pub n: Option, #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, + pub stream: Option, /// Up to 4 sequences where the API will stop generating further tokens. #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] - stop: Vec, + pub stop: Vec, /// This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - seed: Option, + pub seed: Option, /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens). #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, + pub max_tokens: Option, /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. /// /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - presence_penalty: Option, + pub presence_penalty: Option, /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. /// /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - frequency_penalty: Option, + pub frequency_penalty: Option, /// Modify the likelihood of specified tokens appearing in the completion. /// /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - logit_bias: Option>, + pub logit_bias: Option>, /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). #[builder(default)] #[serde(skip_serializing_if = "String::is_empty")] - user: String, + pub user: String, /// Describe functions that ChatGPT can call /// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt. /// For example, you can define a function called "get_weather" that returns the weather in a given city @@ -238,7 +238,7 @@ pub struct ChatCompletionRequest { /// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling) #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] - functions: Vec, + pub functions: Vec, /// A string or object of the function to call /// /// Controls how the model responds to function calls @@ -250,17 +250,17 @@ pub struct ChatCompletionRequest { /// "none" is the default when no functions are present. "auto" is the default if functions are present. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - function_call: Option, + pub function_call: Option, /// An object specifying the format that the model must output. Compatible with GPT-4 Turbo and all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. /// Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. /// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - response_format: Option, + pub response_format: Option, /// The credentials to use for this request. #[serde(skip_serializing)] #[builder(default)] - credentials: Option, + pub credentials: Option, } #[derive(Serialize, Debug, Clone, Eq, PartialEq)] @@ -302,6 +302,15 @@ impl ChatCompletion { } } +impl OpenAiClient { + pub async fn create_chat_completion( + &self, + request: ChatCompletionRequest, + ) -> ApiResponseOrError { + self.post("chat/completions", request).await + } +} + impl ChatCompletionDelta { pub async fn create( request: ChatCompletionRequest, From 34978407923837af13b5344077112d52d7cabbb0 Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Tue, 21 Jan 2025 16:42:00 -0500 Subject: [PATCH 6/9] add sleep on poll --- src/assistants/runs.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/assistants/runs.rs b/src/assistants/runs.rs index de220c9..0386dd4 100644 --- a/src/assistants/runs.rs +++ b/src/assistants/runs.rs @@ -257,6 +257,7 @@ impl OpenAiClient { pub async fn poll_run(&self, mut run: Run) -> ApiResponseOrError { while !run.status.is_terminal() { + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; run = self .get_run(run.thread_id.as_str(), run.id.as_str()) .await?; From 6859a100b5da57bdd59ad3885abdafebfdfaa3e1 Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Tue, 21 Jan 2025 21:46:49 +0000 Subject: [PATCH 7/9] fix serde version --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0c7eb05..704d98c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,9 +16,9 @@ reqwest = { version = "0.12", default-features = false, features = [ "stream", "multipart", ], optional = true } -serde = { version = "1.0.157", features = ["derive"] } +serde = { version = "1.0", features = ["derive"] } reqwest-eventsource = "0.6" -tokio = { version = "1.26.0", features = ["full"] } +tokio = { version = "1.0", features = ["full"] } anyhow = "1.0" futures-util = "0.3.28" bytes = "1.4.0" From 2b56e2c9eee3e86237f4e54edcb396807e499fcc Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Tue, 21 Jan 2025 21:49:38 +0000 Subject: [PATCH 8/9] fix serde version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 704d98c..5bc05a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ reqwest = { version = "0.12", default-features = false, features = [ "stream", "multipart", ], optional = true } -serde = { version = "1.0", features = ["derive"] } +serde = { version = "^1.0", features = ["derive"] } reqwest-eventsource = "0.6" tokio = { version = "1.0", features = ["full"] } anyhow = "1.0" From ed64c5272ea8aa96e64db34e1c2f48994673799a Mon Sep 17 00:00:00 2001 From: Alex Grinman Date: Tue, 21 Jan 2025 23:01:12 +0000 Subject: [PATCH 9/9] use once_cell for more compat --- Cargo.toml | 1 + src/lib.rs | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5bc05a9..1f1d3f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ serde-double-tag = "0.0.4" log = "0.4" strum = { version = "0.26", features = ["derive"] } strum_macros = "0.26" +once_cell = "^1" [dev-dependencies] dotenvy = "0.15.7" diff --git a/src/lib.rs b/src/lib.rs index c98f645..b03db5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,11 @@ +use once_cell::sync::Lazy; use reqwest::multipart::Form; use reqwest::{header::AUTHORIZATION, Client, Method, RequestBuilder, Response}; use reqwest_eventsource::{CannotCloneRequestError, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::env; use std::env::VarError; -use std::sync::{LazyLock, RwLock}; +use std::sync::RwLock; pub mod assistants; pub mod chat; @@ -16,10 +17,10 @@ pub mod files; pub mod models; pub mod moderations; -pub static DEFAULT_BASE_URL: LazyLock = - LazyLock::new(|| String::from("https://api.openai.com/v1/")); -static DEFAULT_CREDENTIALS: LazyLock> = - LazyLock::new(|| RwLock::new(Credentials::from_env())); +pub static DEFAULT_BASE_URL: Lazy = + Lazy::new(|| String::from("https://api.openai.com/v1/")); +static DEFAULT_CREDENTIALS: Lazy> = + Lazy::new(|| RwLock::new(Credentials::from_env())); /// Holds the API key and base URL for an OpenAI-compatible API. #[derive(Debug, Clone, Eq, PartialEq)]