diff --git a/.gitignore b/.gitignore
index 5cd3e48..85d26c3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,4 +17,6 @@
 target/
 .env
 *.mp3
+output/
+
 logs/
\ No newline at end of file
diff --git a/assistant_v2/Cargo.toml b/assistant_v2/Cargo.toml
index 831c5bb..005277f 100644
--- a/assistant_v2/Cargo.toml
+++ b/assistant_v2/Cargo.toml
@@ -3,6 +3,10 @@ name = "assistant_v2"
 version = "0.1.0"
 edition = "2024"
 
+[[bin]]
+name = "assistant_v2"
+path = "src/main.rs"
+
 [dependencies]
 async-openai = "0.28.3"
 tokio = { version = "1.29.0", features = ["macros", "rt-multi-thread"] }
diff --git a/assistant_v2/FEATURE_PROGRESS.md b/assistant_v2/FEATURE_PROGRESS.md
index 46a702e..298970c 100644
--- a/assistant_v2/FEATURE_PROGRESS.md
+++ b/assistant_v2/FEATURE_PROGRESS.md
@@ -18,5 +18,6 @@ This document tracks which features from the original assistant have been implem
 | Mute/unmute voice output | Pending |
 | Open OpenAI billing page | Done |
 | Push-to-talk text-to-speech interface | Done |
+| Code interpreter support | Done |
 
 Update this table as features are migrated and verified to work in `assistant_v2`.
diff --git a/assistant_v2/README.md b/assistant_v2/README.md
index 11a9928..9615d37 100644
--- a/assistant_v2/README.md
+++ b/assistant_v2/README.md
@@ -16,3 +16,13 @@ cargo run --manifest-path assistant_v2/Cargo.toml
 ```
 
 The program creates a temporary assistant that answers a simple weather query using function calling.
+
+## Code Interpreter Support
+
+The main assistant binary now uses the Assistants API code interpreter tool. A sample CSV file is provided in `input/CASTHPI.csv` and is uploaded automatically on startup.
+
+Run the assistant with:
+
+```bash
+cargo run --manifest-path assistant_v2/Cargo.toml
+```
diff --git a/assistant_v2/src/main.rs b/assistant_v2/src/main.rs
index 336ba00..0f0299c 100644
--- a/assistant_v2/src/main.rs
+++ b/assistant_v2/src/main.rs
@@ -1,9 +1,10 @@
 use async_openai::{
     config::OpenAIConfig,
     types::{
-        AssistantStreamEvent, CreateAssistantRequestArgs, CreateMessageRequest, CreateRunRequest,
-        CreateThreadRequest, FunctionObject, MessageDeltaContent, MessageRole, RunObject,
-        SubmitToolOutputsRunRequest, ToolsOutputs, Voice,
+        AssistantStreamEvent, AssistantToolCodeInterpreterResources, AssistantTools,
+        CreateAssistantRequestArgs, CreateFileRequest, CreateMessageRequest, CreateRunRequest,
+        CreateThreadRequest, FilePurpose, FunctionObject, MessageContent, MessageContentTextAnnotations,
+        MessageDeltaContent, MessageRole, RunObject, SubmitToolOutputsRunRequest, ToolsOutputs, Voice,
     },
     Client,
 };
@@ -14,6 +15,7 @@ use colored::Colorize;
 use clap::Parser;
 use clipboard::{ClipboardContext, ClipboardProvider};
 use open;
+use std::collections::HashSet;
 use std::error::Error;
 use std::path::PathBuf;
 use std::sync::{Arc, Mutex};
@@ -55,10 +57,19 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
     let client = Client::new();
 
+    let data_file = client
+        .files()
+        .create(CreateFileRequest {
+            file: "./input/CASTHPI.csv".into(),
+            purpose: FilePurpose::Assistants,
+        })
+        .await?;
+
     let create_assistant_request = CreateAssistantRequestArgs::default()
         .instructions("You are a weather bot. Use the provided functions to answer questions.")
         .model("gpt-4o")
         .tools(vec![
+            AssistantTools::CodeInterpreter,
             FunctionObject {
                 name: "get_current_temperature".into(),
                 description: Some(
@@ -125,6 +136,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
             }
             .into(),
         ])
+        .tool_resources(AssistantToolCodeInterpreterResources { file_ids: vec![data_file.id.clone()] })
         .build()?;
 
     let assistant = client.assistants().create(create_assistant_request).await?;
@@ -144,6 +156,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
     let (audio_tx, audio_rx) = flume::unbounded();
     start_ptt_thread(audio_tx.clone(), speak_stream.clone(), opt.duck_ptt);
 
+    std::fs::create_dir_all("output").ok();
+    let mut seen_messages: HashSet<String> = HashSet::new();
+
     loop {
         let audio_path = audio_rx.recv().unwrap();
         let transcription = transcribe::transcribe(&client, &audio_path).await?;
@@ -174,6 +189,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
         let client_cloned = client.clone();
         let mut task_handle = None;
         let mut displayed_ai_label = false;
+        let mut run_completed = false;
 
         while let Some(event) = event_stream.next().await {
             match event {
@@ -203,6 +219,12 @@ async fn main() -> Result<(), Box<dyn Error>> {
                             }
                         }
                     }
+                    AssistantStreamEvent::ThreadRunCompleted(_) => {
+                        run_completed = true;
+                    }
+                    AssistantStreamEvent::Done(_) => {
+                        break;
+                    }
                     _ => {}
                 },
                 Err(e) => eprintln!("Error: {e}"),
@@ -215,6 +237,12 @@ async fn main() -> Result<(), Box<dyn Error>> {
         }
         speak_stream.lock().unwrap().complete_sentence();
         println!();
+
+        if run_completed {
+            if let Err(e) = print_new_messages(&client, &thread.id, &mut seen_messages).await {
+                eprintln!("Error retrieving messages: {e}");
+            }
+        }
     }
 }
 
@@ -362,6 +390,56 @@ async fn submit_tool_outputs(
     Ok(())
 }
 
+async fn print_new_messages(
+    client: &Client<OpenAIConfig>,
+    thread_id: &str,
+    seen: &mut HashSet<String>,
+) -> Result<(), Box<dyn Error>> {
+    let messages = client
+        .threads()
+        .messages(thread_id)
+        .list(&[("limit", "20")])
+        .await?;
+
+    for message in messages.data {
+        if seen.insert(message.id.clone()) {
+            for content in message.content {
+                match content {
+                    MessageContent::Text(text) => {
+                        let data = text.text;
+                        println!("{}", data.value);
+                        for ann in data.annotations {
+                            match ann {
+                                MessageContentTextAnnotations::FileCitation(cit) => {
+                                    println!("annotation: file citation: {:?}", cit);
+                                }
+                                MessageContentTextAnnotations::FilePath(fp) => {
+                                    println!("annotation: file path: {:?}", fp);
+                                }
+                            }
+                        }
+                    }
+                    MessageContent::ImageFile(obj) => {
+                        let file_id = obj.image_file.file_id;
+                        let contents = client.files().content(&file_id).await?;
+                        let path = format!("output/{}.png", file_id);
+                        tokio::fs::write(&path, contents).await?;
+                        println!("Saved image to {path}");
+                    }
+                    MessageContent::ImageUrl(u) => {
+                        println!("Image URL: {:?}", u);
+                    }
+                    MessageContent::Refusal(r) => {
+                        println!("{r:?}");
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -401,4 +479,16 @@
             _ => false,
         }));
     }
+
+    #[test]
+    fn includes_code_interpreter_tool() {
+        let req = CreateAssistantRequestArgs::default()
+            .model("gpt-4o")
+            .tools(vec![AssistantTools::CodeInterpreter])
+            .build()
+            .unwrap();
+
+        let tools = req.tools.unwrap();
+        assert!(tools.iter().any(|t| matches!(t, AssistantTools::CodeInterpreter)));
+    }
 }
diff --git a/input/CASTHPI.csv b/input/CASTHPI.csv
new file mode 100644
index 0000000..adc96d8
--- /dev/null
+++ b/input/CASTHPI.csv
@@ -0,0 +1,5 @@
+date,price_index
+2020,100
+2021,110
+2022,120
+2023,125