From 560f2037c094c19709b6123730218afefe56d403 Mon Sep 17 00:00:00 2001 From: Logan King Date: Thu, 19 Jun 2025 19:30:18 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20add=20web=20search=20using=20openai?= =?UTF-8?q?=20agents?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 2 +- src/main.rs | 30 ++++++++++++- src/web_search.rs | 109 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 2 deletions(-) create mode 100644 src/web_search.rs diff --git a/Cargo.toml b/Cargo.toml index a49ecfe..53e6668 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ clipboard = "0.5.0" reqwest = { version = "0.12.1", features = ["json"] } serde = { version = "1.0", features = ["derive"] } once_cell = "1.19.0" -speakstream = { version = "0.1.2", path = "../speakstream" } +speakstream = "0.1.2" windows = { version = "0.52.0", features = [ "Win32_System_Com", "Win32_System_Com_StructuredStorage", diff --git a/src/main.rs b/src/main.rs index 2a908e5..e93ac62 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,6 +25,7 @@ use futures::stream::StreamExt; // For `.next()` on FuturesOrdered. use std::thread; use tempfile::Builder; mod record; +mod web_search; use crate::default_device_sink::{ default_device_name as get_default_output_device, list_output_devices as list_audio_output_devices, set_output_device, DefaultDeviceSink, @@ -58,6 +59,7 @@ use tracing::{debug, error, info, instrument, warn}; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use speakstream::ss::SpeakStream; +use web_search::search_web; #[derive(Debug, Subcommand)] pub enum SubCommands { @@ -434,6 +436,22 @@ fn call_fn( Err(err) => Some(format!("Failed to create runtime: {}", err)), }, + "search_web" => match tokio::runtime::Runtime::new() { + Ok(rt) => { + let args: serde_json::Value = serde_json::from_str(fn_args).unwrap(); + let query = args["query"].as_str().unwrap_or(""); + let api_key = match std::env::var("OPENAI_API_KEY") { + Ok(k) => k, + Err(_) => return Some("OPENAI_API_KEY not set".to_string()), + }; + match rt.block_on(search_web(&api_key, query)) { + Ok(ans) => Some(ans), + Err(err) => Some(format!("Web search failed: {}", err)), + } + } + Err(err) => Some(format!("Failed to create runtime: {}", err)), + }, + "set_timer_at" => { let args: serde_json::Value = serde_json::from_str(fn_args).unwrap(); let time_str = args["time"].as_str().unwrap(); @@ -1578,7 +1596,7 @@ async fn main() -> Result<(), Box> { })) .build().unwrap(), - ChatCompletionFunctionsArgs::default() + ChatCompletionFunctionsArgs::default() .name("get_location") .description("Returns an approximate location based on the machine's IP address.") .parameters(json!({ @@ -1588,6 +1606,16 @@ async fn main() -> Result<(), Box> { })) .build().unwrap(), + ChatCompletionFunctionsArgs::default() + .name("search_web") + .description("Searches the web using OpenAI's browser tool.") + .parameters(json!({ + "type": "object", + "properties": { "query": { "type": "string" } }, + "required": ["query"], + })) + .build().unwrap(), + ChatCompletionFunctionsArgs::default() .name("list_output_devices") .description("Lists available audio output devices. The default device is marked with '(Default)'.") diff --git a/src/web_search.rs b/src/web_search.rs new file mode 100644 index 0000000..1652434 --- /dev/null +++ b/src/web_search.rs @@ -0,0 +1,109 @@ +use reqwest::Client; +use serde_json::json; +use std::{error::Error, time::Duration}; +use tokio::time::sleep; + +pub async fn search_web(api_key: &str, query: &str) -> Result> { + let client = Client::new(); + let base = "https://api.openai.com/v1"; + + let assistant_res: serde_json::Value = client + .post(&format!("{}/assistants", base)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("OpenAI-Beta", "assistants=v1") + .json(&json!({ + "model": "gpt-4o", + "instructions": "You are a web search assistant.", + "tools": [{"type": "browser"}] + })) + .send() + .await? + .json() + .await?; + + let assistant_id = assistant_res["id"] + .as_str() + .ok_or("missing assistant id")? + .to_string(); + + let thread_res: serde_json::Value = client + .post(&format!("{}/threads", base)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("OpenAI-Beta", "assistants=v1") + .send() + .await? + .json() + .await?; + + let thread_id = thread_res["id"] + .as_str() + .ok_or("missing thread id")? + .to_string(); + + client + .post(&format!("{}/threads/{}/messages", base, thread_id)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("OpenAI-Beta", "assistants=v1") + .json(&json!({"role": "user", "content": query})) + .send() + .await?; + + let run_res: serde_json::Value = client + .post(&format!("{}/threads/{}/runs", base, thread_id)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("OpenAI-Beta", "assistants=v1") + .json(&json!({"assistant_id": assistant_id})) + .send() + .await? + .json() + .await?; + + let run_id = run_res["id"].as_str().ok_or("missing run id")?.to_string(); + + loop { + let run_status: serde_json::Value = client + .get(&format!("{}/threads/{}/runs/{}", base, thread_id, run_id)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("OpenAI-Beta", "assistants=v1") + .send() + .await? + .json() + .await?; + + match run_status["status"].as_str() { + Some("completed") => break, + Some("failed") | Some("expired") | Some("cancelled") => return Err("run failed".into()), + _ => sleep(Duration::from_secs(1)).await, + } + } + + let messages: serde_json::Value = client + .get(&format!("{}/threads/{}/messages", base, thread_id)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("OpenAI-Beta", "assistants=v1") + .send() + .await? + .json() + .await?; + + let answer = messages["data"][0]["content"][0]["text"]["value"] + .as_str() + .unwrap_or("") + .to_string(); + + // cleanup + let _ = client + .delete(&format!("{}/assistants/{}", base, assistant_id)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("OpenAI-Beta", "assistants=v1") + .send() + .await; + let _ = client + .delete(&format!("{}/threads/{}", base, thread_id)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("OpenAI-Beta", "assistants=v1") + .send() + .await; + + Ok(answer) +}