From e510e9cfbb6d978ab42518b03ad4b06892ed005c Mon Sep 17 00:00:00 2001 From: Logan King Date: Mon, 30 Jun 2025 01:22:07 -0700 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9C=A8=20(assistant=5Fv2):=20add=20code?= =?UTF-8?q?=20interpreter=20example?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + assistant_v2/Cargo.toml | 8 ++ assistant_v2/FEATURE_PROGRESS.md | 1 + assistant_v2/README.md | 10 ++ assistant_v2/src/code_interpreter.rs | 181 +++++++++++++++++++++++++++ input/CASTHPI.csv | 5 + 6 files changed, 207 insertions(+) create mode 100644 assistant_v2/src/code_interpreter.rs create mode 100644 input/CASTHPI.csv diff --git a/.gitignore b/.gitignore index 5cd3e48..85d26c3 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,6 @@ target/ .env *.mp3 +output/ + logs/ \ No newline at end of file diff --git a/assistant_v2/Cargo.toml b/assistant_v2/Cargo.toml index 831c5bb..0d81cb8 100644 --- a/assistant_v2/Cargo.toml +++ b/assistant_v2/Cargo.toml @@ -3,6 +3,14 @@ name = "assistant_v2" version = "0.1.0" edition = "2024" +[[bin]] +name = "assistant_v2" +path = "src/main.rs" + +[[bin]] +name = "code_interpreter" +path = "src/code_interpreter.rs" + [dependencies] async-openai = "0.28.3" tokio = { version = "1.29.0", features = ["macros", "rt-multi-thread"] } diff --git a/assistant_v2/FEATURE_PROGRESS.md b/assistant_v2/FEATURE_PROGRESS.md index 46a702e..58b0fcb 100644 --- a/assistant_v2/FEATURE_PROGRESS.md +++ b/assistant_v2/FEATURE_PROGRESS.md @@ -18,5 +18,6 @@ This document tracks which features from the original assistant have been implem | Mute/unmute voice output | Pending | | Open OpenAI billing page | Done | | Push-to-talk text-to-speech interface | Done | +| Code interpreter demo | Done | Update this table as features are migrated and verified to work in `assistant_v2`. diff --git a/assistant_v2/README.md b/assistant_v2/README.md index 11a9928..256bb89 100644 --- a/assistant_v2/README.md +++ b/assistant_v2/README.md @@ -16,3 +16,13 @@ cargo run --manifest-path assistant_v2/Cargo.toml ``` The program creates a temporary assistant that answers a simple weather query using function calling. + +## Code Interpreter Example + +`src/code_interpreter.rs` demonstrates how to use the Assistants API code interpreter tool with a CSV file. + +Run it with: + +```bash +cargo run --manifest-path assistant_v2/Cargo.toml --bin code_interpreter +``` diff --git a/assistant_v2/src/code_interpreter.rs b/assistant_v2/src/code_interpreter.rs new file mode 100644 index 0000000..a47c3e6 --- /dev/null +++ b/assistant_v2/src/code_interpreter.rs @@ -0,0 +1,181 @@ +use std::error::Error; + +use async_openai::{ + types::{ + AssistantToolCodeInterpreterResources, AssistantTools, CreateAssistantRequestArgs, + CreateFileRequest, CreateMessageRequestArgs, CreateRunRequest, CreateThreadRequest, + FilePurpose, MessageContent, MessageContentTextAnnotations, MessageRole, RunStatus, + }, + Client, +}; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + unsafe { + std::env::set_var("RUST_LOG", "ERROR"); + } + + tracing_subscriber::registry() + .with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .init(); + + let client = Client::new(); + + // Upload data file with "assistants" purpose + let data_file = client + .files() + .create(CreateFileRequest { + file: "./input/CASTHPI.csv".into(), + purpose: FilePurpose::Assistants, + }) + .await?; + + // Create an assistant with code_interpreter tool with the uploaded file + let create_assistant_request = CreateAssistantRequestArgs::default() + .instructions("You are a data processor. When asked a question about data in a file, write and run code to answer the question.") + .model("gpt-4o") + .tools(vec![AssistantTools::CodeInterpreter]) + .tool_resources(AssistantToolCodeInterpreterResources { file_ids: vec![data_file.id.clone()] }) + .build()?; + + let assistant = client.assistants().create(create_assistant_request).await?; + + // create a thread + let create_message_request = CreateMessageRequestArgs::default() + .role(MessageRole::User) + .content("Generate a graph of price index vs year in png format") + .build()?; + + let create_thread_request = CreateThreadRequest { + messages: Some(vec![create_message_request]), + ..Default::default() + }; + + let thread = client.threads().create(create_thread_request).await?; + + // create run and check the output + let create_run_request = CreateRunRequest { + assistant_id: assistant.id.clone(), + ..Default::default() + }; + + let mut run = client + .threads() + .runs(&thread.id) + .create(create_run_request) + .await?; + + let mut generated_file_ids: Vec = vec![]; + + // poll the status of run until it's in a terminal state + loop { + match run.status { + RunStatus::Completed => { + let messages = client + .threads() + .messages(&thread.id) + .list(&[("limit", "10")]) + .await?; + + for message_obj in messages.data { + for message_content in message_obj.content { + match message_content { + MessageContent::Text(text) => { + let text_data = text.text; + println!("{}", text_data.value); + for annotation in text_data.annotations { + match annotation { + MessageContentTextAnnotations::FileCitation(object) => { + println!("annotation: file citation : {object:?}"); + } + MessageContentTextAnnotations::FilePath(object) => { + println!("annotation: file path: {object:?}"); + generated_file_ids.push(object.file_path.file_id); + } + } + } + } + MessageContent::ImageFile(object) => { + let file_id = object.image_file.file_id; + println!("Retrieving image file_id: {}", file_id); + let contents = client.files().content(&file_id).await?; + let path = "./output/price_index_vs_year_graph.png"; + tokio::fs::write(path, contents).await?; + println!("Graph file: {path}"); + generated_file_ids.push(file_id); + } + MessageContent::ImageUrl(object) => { + eprintln!("Got Image URL instead: {object:?}"); + } + MessageContent::Refusal(refusal) => { + println!("{refusal:?}"); + } + } + } + } + + break; + } + RunStatus::Failed => { + println!("> Run Failed: {:#?}", run); + break; + } + RunStatus::Queued => { + println!("> Run Queued"); + } + RunStatus::Cancelling => { + println!("> Run Cancelling"); + } + RunStatus::Cancelled => { + println!("> Run Cancelled"); + break; + } + RunStatus::Expired => { + println!("> Run Expired"); + break; + } + RunStatus::RequiresAction => { + println!("> Run Requires Action"); + } + RunStatus::InProgress => { + println!("> In Progress ..."); + } + RunStatus::Incomplete => { + println!("> Run Incomplete"); + } + } + + // wait for 1 sec before polling run object again + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + // retrieve the run + run = client.threads().runs(&thread.id).retrieve(&run.id).await?; + } + + // clean up + client.threads().delete(&thread.id).await?; + client.files().delete(&data_file.id).await?; + for file_id in generated_file_ids { + client.files().delete(&file_id).await?; + } + client.assistants().delete(&assistant.id).await?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn build_code_interpreter_request() { + let req = CreateAssistantRequestArgs::default() + .model("gpt-4o") + .tools(vec![AssistantTools::CodeInterpreter]) + .build(); + assert!(req.is_ok()); + } +} + diff --git a/input/CASTHPI.csv b/input/CASTHPI.csv new file mode 100644 index 0000000..adc96d8 --- /dev/null +++ b/input/CASTHPI.csv @@ -0,0 +1,5 @@ +date,price_index +2020,100 +2021,110 +2022,120 +2023,125 From c95207e90525fcb025c6208c2a365da9b7d57662 Mon Sep 17 00:00:00 2001 From: Logan King Date: Mon, 30 Jun 2025 01:37:36 -0700 Subject: [PATCH 2/2] =?UTF-8?q?=E2=9C=A8=20(assistant=5Fv2):=20integrate?= =?UTF-8?q?=20code=20interpreter=20tool?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- assistant_v2/Cargo.toml | 4 - assistant_v2/FEATURE_PROGRESS.md | 2 +- assistant_v2/README.md | 8 +- assistant_v2/src/code_interpreter.rs | 181 --------------------------- assistant_v2/src/main.rs | 96 +++++++++++++- 5 files changed, 98 insertions(+), 193 deletions(-) delete mode 100644 assistant_v2/src/code_interpreter.rs diff --git a/assistant_v2/Cargo.toml b/assistant_v2/Cargo.toml index 0d81cb8..005277f 100644 --- a/assistant_v2/Cargo.toml +++ b/assistant_v2/Cargo.toml @@ -7,10 +7,6 @@ edition = "2024" name = "assistant_v2" path = "src/main.rs" -[[bin]] -name = "code_interpreter" -path = "src/code_interpreter.rs" - [dependencies] async-openai = "0.28.3" tokio = { version = "1.29.0", features = ["macros", "rt-multi-thread"] } diff --git a/assistant_v2/FEATURE_PROGRESS.md b/assistant_v2/FEATURE_PROGRESS.md index 58b0fcb..298970c 100644 --- a/assistant_v2/FEATURE_PROGRESS.md +++ b/assistant_v2/FEATURE_PROGRESS.md @@ -18,6 +18,6 @@ This document tracks which features from the original assistant have been implem | Mute/unmute voice output | Pending | | Open OpenAI billing page | Done | | Push-to-talk text-to-speech interface | Done | -| Code interpreter demo | Done | +| Code interpreter support | Done | Update this table as features are migrated and verified to work in `assistant_v2`. diff --git a/assistant_v2/README.md b/assistant_v2/README.md index 256bb89..9615d37 100644 --- a/assistant_v2/README.md +++ b/assistant_v2/README.md @@ -17,12 +17,12 @@ cargo run --manifest-path assistant_v2/Cargo.toml The program creates a temporary assistant that answers a simple weather query using function calling. -## Code Interpreter Example +## Code Interpreter Support -`src/code_interpreter.rs` demonstrates how to use the Assistants API code interpreter tool with a CSV file. +The main assistant binary now uses the Assistants API code interpreter tool. A sample CSV file is provided in `input/CASTHPI.csv` and is uploaded automatically on startup. -Run it with: +Run the assistant with: ```bash -cargo run --manifest-path assistant_v2/Cargo.toml --bin code_interpreter +cargo run --manifest-path assistant_v2/Cargo.toml ``` diff --git a/assistant_v2/src/code_interpreter.rs b/assistant_v2/src/code_interpreter.rs deleted file mode 100644 index a47c3e6..0000000 --- a/assistant_v2/src/code_interpreter.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::error::Error; - -use async_openai::{ - types::{ - AssistantToolCodeInterpreterResources, AssistantTools, CreateAssistantRequestArgs, - CreateFileRequest, CreateMessageRequestArgs, CreateRunRequest, CreateThreadRequest, - FilePurpose, MessageContent, MessageContentTextAnnotations, MessageRole, RunStatus, - }, - Client, -}; -use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; - -#[tokio::main] -async fn main() -> Result<(), Box> { - unsafe { - std::env::set_var("RUST_LOG", "ERROR"); - } - - tracing_subscriber::registry() - .with(fmt::layer()) - .with(EnvFilter::from_default_env()) - .init(); - - let client = Client::new(); - - // Upload data file with "assistants" purpose - let data_file = client - .files() - .create(CreateFileRequest { - file: "./input/CASTHPI.csv".into(), - purpose: FilePurpose::Assistants, - }) - .await?; - - // Create an assistant with code_interpreter tool with the uploaded file - let create_assistant_request = CreateAssistantRequestArgs::default() - .instructions("You are a data processor. When asked a question about data in a file, write and run code to answer the question.") - .model("gpt-4o") - .tools(vec![AssistantTools::CodeInterpreter]) - .tool_resources(AssistantToolCodeInterpreterResources { file_ids: vec![data_file.id.clone()] }) - .build()?; - - let assistant = client.assistants().create(create_assistant_request).await?; - - // create a thread - let create_message_request = CreateMessageRequestArgs::default() - .role(MessageRole::User) - .content("Generate a graph of price index vs year in png format") - .build()?; - - let create_thread_request = CreateThreadRequest { - messages: Some(vec![create_message_request]), - ..Default::default() - }; - - let thread = client.threads().create(create_thread_request).await?; - - // create run and check the output - let create_run_request = CreateRunRequest { - assistant_id: assistant.id.clone(), - ..Default::default() - }; - - let mut run = client - .threads() - .runs(&thread.id) - .create(create_run_request) - .await?; - - let mut generated_file_ids: Vec = vec![]; - - // poll the status of run until it's in a terminal state - loop { - match run.status { - RunStatus::Completed => { - let messages = client - .threads() - .messages(&thread.id) - .list(&[("limit", "10")]) - .await?; - - for message_obj in messages.data { - for message_content in message_obj.content { - match message_content { - MessageContent::Text(text) => { - let text_data = text.text; - println!("{}", text_data.value); - for annotation in text_data.annotations { - match annotation { - MessageContentTextAnnotations::FileCitation(object) => { - println!("annotation: file citation : {object:?}"); - } - MessageContentTextAnnotations::FilePath(object) => { - println!("annotation: file path: {object:?}"); - generated_file_ids.push(object.file_path.file_id); - } - } - } - } - MessageContent::ImageFile(object) => { - let file_id = object.image_file.file_id; - println!("Retrieving image file_id: {}", file_id); - let contents = client.files().content(&file_id).await?; - let path = "./output/price_index_vs_year_graph.png"; - tokio::fs::write(path, contents).await?; - println!("Graph file: {path}"); - generated_file_ids.push(file_id); - } - MessageContent::ImageUrl(object) => { - eprintln!("Got Image URL instead: {object:?}"); - } - MessageContent::Refusal(refusal) => { - println!("{refusal:?}"); - } - } - } - } - - break; - } - RunStatus::Failed => { - println!("> Run Failed: {:#?}", run); - break; - } - RunStatus::Queued => { - println!("> Run Queued"); - } - RunStatus::Cancelling => { - println!("> Run Cancelling"); - } - RunStatus::Cancelled => { - println!("> Run Cancelled"); - break; - } - RunStatus::Expired => { - println!("> Run Expired"); - break; - } - RunStatus::RequiresAction => { - println!("> Run Requires Action"); - } - RunStatus::InProgress => { - println!("> In Progress ..."); - } - RunStatus::Incomplete => { - println!("> Run Incomplete"); - } - } - - // wait for 1 sec before polling run object again - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - - // retrieve the run - run = client.threads().runs(&thread.id).retrieve(&run.id).await?; - } - - // clean up - client.threads().delete(&thread.id).await?; - client.files().delete(&data_file.id).await?; - for file_id in generated_file_ids { - client.files().delete(&file_id).await?; - } - client.assistants().delete(&assistant.id).await?; - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn build_code_interpreter_request() { - let req = CreateAssistantRequestArgs::default() - .model("gpt-4o") - .tools(vec![AssistantTools::CodeInterpreter]) - .build(); - assert!(req.is_ok()); - } -} - diff --git a/assistant_v2/src/main.rs b/assistant_v2/src/main.rs index 336ba00..0f0299c 100644 --- a/assistant_v2/src/main.rs +++ b/assistant_v2/src/main.rs @@ -1,9 +1,10 @@ use async_openai::{ config::OpenAIConfig, types::{ - AssistantStreamEvent, CreateAssistantRequestArgs, CreateMessageRequest, CreateRunRequest, - CreateThreadRequest, FunctionObject, MessageDeltaContent, MessageRole, RunObject, - SubmitToolOutputsRunRequest, ToolsOutputs, Voice, + AssistantStreamEvent, AssistantToolCodeInterpreterResources, AssistantTools, + CreateAssistantRequestArgs, CreateFileRequest, CreateMessageRequest, CreateRunRequest, + CreateThreadRequest, FilePurpose, FunctionObject, MessageContent, MessageContentTextAnnotations, + MessageDeltaContent, MessageRole, RunObject, SubmitToolOutputsRunRequest, ToolsOutputs, Voice, }, Client, }; @@ -14,6 +15,7 @@ use colored::Colorize; use clap::Parser; use clipboard::{ClipboardContext, ClipboardProvider}; use open; +use std::collections::HashSet; use std::error::Error; use std::path::PathBuf; use std::sync::{Arc, Mutex}; @@ -55,10 +57,19 @@ async fn main() -> Result<(), Box> { let client = Client::new(); + let data_file = client + .files() + .create(CreateFileRequest { + file: "./input/CASTHPI.csv".into(), + purpose: FilePurpose::Assistants, + }) + .await?; + let create_assistant_request = CreateAssistantRequestArgs::default() .instructions("You are a weather bot. Use the provided functions to answer questions.") .model("gpt-4o") .tools(vec![ + AssistantTools::CodeInterpreter, FunctionObject { name: "get_current_temperature".into(), description: Some( @@ -125,6 +136,7 @@ async fn main() -> Result<(), Box> { } .into(), ]) + .tool_resources(AssistantToolCodeInterpreterResources { file_ids: vec![data_file.id.clone()] }) .build()?; let assistant = client.assistants().create(create_assistant_request).await?; @@ -144,6 +156,9 @@ async fn main() -> Result<(), Box> { let (audio_tx, audio_rx) = flume::unbounded(); start_ptt_thread(audio_tx.clone(), speak_stream.clone(), opt.duck_ptt); + std::fs::create_dir_all("output").ok(); + let mut seen_messages: HashSet = HashSet::new(); + loop { let audio_path = audio_rx.recv().unwrap(); let transcription = transcribe::transcribe(&client, &audio_path).await?; @@ -174,6 +189,7 @@ async fn main() -> Result<(), Box> { let client_cloned = client.clone(); let mut task_handle = None; let mut displayed_ai_label = false; + let mut run_completed = false; while let Some(event) = event_stream.next().await { match event { @@ -203,6 +219,12 @@ async fn main() -> Result<(), Box> { } } } + AssistantStreamEvent::ThreadRunCompleted(_) => { + run_completed = true; + } + AssistantStreamEvent::Done(_) => { + break; + } _ => {} }, Err(e) => eprintln!("Error: {e}"), @@ -215,6 +237,12 @@ async fn main() -> Result<(), Box> { speak_stream.lock().unwrap().complete_sentence(); println!(); + + if run_completed { + if let Err(e) = print_new_messages(&client, &thread.id, &mut seen_messages).await { + eprintln!("Error retrieving messages: {e}"); + } + } } } @@ -362,6 +390,56 @@ async fn submit_tool_outputs( Ok(()) } +async fn print_new_messages( + client: &Client, + thread_id: &str, + seen: &mut HashSet, +) -> Result<(), Box> { + let messages = client + .threads() + .messages(thread_id) + .list(&[("limit", "20")]) + .await?; + + for message in messages.data { + if seen.insert(message.id.clone()) { + for content in message.content { + match content { + MessageContent::Text(text) => { + let data = text.text; + println!("{}", data.value); + for ann in data.annotations { + match ann { + MessageContentTextAnnotations::FileCitation(cit) => { + println!("annotation: file citation: {:?}", cit); + } + MessageContentTextAnnotations::FilePath(fp) => { + println!("annotation: file path: {:?}", fp); + } + } + } + } + MessageContent::ImageFile(obj) => { + let file_id = obj.image_file.file_id; + let contents = client.files().content(&file_id).await?; + let path = format!("output/{}.png", file_id); + tokio::fs::write(&path, contents).await?; + println!("Saved image to {path}"); + } + MessageContent::ImageUrl(u) => { + println!("Image URL: {:?}", u); + } + MessageContent::Refusal(r) => { + println!("{r:?}"); + } + } + } + } + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -401,4 +479,16 @@ mod tests { _ => false, })); } + + #[test] + fn includes_code_interpreter_tool() { + let req = CreateAssistantRequestArgs::default() + .model("gpt-4o") + .tools(vec![AssistantTools::CodeInterpreter]) + .build() + .unwrap(); + + let tools = req.tools.unwrap(); + assert!(tools.iter().any(|t| matches!(t, AssistantTools::CodeInterpreter))); + } }