23 changes: 23 additions & 0 deletions src/main.rs
@@ -20,6 +20,7 @@ use tracing_subscriber::Registry;
mod default_device_sink;
mod timers;
mod transcribe;
mod web_search;
use chrono::{DateTime, Local};
use futures::stream::StreamExt; // For `.next()` on FuturesOrdered.
use std::thread;
@@ -434,6 +435,18 @@ fn call_fn(
Err(err) => Some(format!("Failed to create runtime: {}", err)),
},

"web_search" => {
let args: serde_json::Value = serde_json::from_str(fn_args).unwrap();
let query = args["query"].as_str().unwrap_or("");
match tokio::runtime::Runtime::new() {
Ok(rt) => match rt.block_on(web_search::web_search(query)) {
Ok(answer) => Some(answer),
Err(err) => Some(format!("Web search failed: {}", err)),
},
Err(err) => Some(format!("Failed to create runtime: {}", err)),
}
}

"set_timer_at" => {
let args: serde_json::Value = serde_json::from_str(fn_args).unwrap();
let time_str = args["time"].as_str().unwrap();
@@ -1460,6 +1473,16 @@ async fn main() -> Result<(), Box<dyn Error>> {
            }))
            .build().unwrap(),

        ChatCompletionFunctionsArgs::default()
            .name("web_search")
            .description("Searches the web using the OpenAI Assistants API and returns the answer.")
            .parameters(json!({
                "type": "object",
                "properties": { "query": { "type": "string" } },
                "required": ["query"],
            }))
            .build().unwrap(),

        ChatCompletionFunctionsArgs::default()
            .name("set_timer_at")
            .description("Sets a timer to go off at a specific time. Pass the time as rfc3339 datetime string. Example: \"2024-12-04T00:44:00-08:00\". The description field is optional, add descriptions that will tell you what to remind the user to do, if anything, after the timer goes off.")
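For context on this registration: the model returns `arguments` as a JSON-encoded string, which `call_fn` parses before dispatching. A minimal sketch, using a hypothetical query and a hand-built message, of the function-call shape this produces and how it is consumed:

```rust
use serde_json::{json, Value};

fn main() {
    // Hypothetical assistant message for the registered `web_search` function;
    // in the binary this comes back from the chat completion response.
    let msg = json!({
        "role": "assistant",
        "content": null,
        "function_call": {
            "name": "web_search",
            "arguments": "{\"query\": \"weather in Seattle today\"}"
        }
    });

    // `call_fn` is handed the function name and the raw arguments string,
    // parses the JSON, and pulls out `query`, as in the hunk above.
    let name = msg["function_call"]["name"].as_str().unwrap();
    let raw_args = msg["function_call"]["arguments"].as_str().unwrap();
    let args: Value = serde_json::from_str(raw_args).unwrap();
    assert_eq!(name, "web_search");
    assert_eq!(args["query"].as_str().unwrap(), "weather in Seattle today");
}
```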
122 changes: 122 additions & 0 deletions src/web_search.rs
@@ -0,0 +1,122 @@
use anyhow::{Context, Result};
use reqwest::Client;
use serde_json::json;

/// Searches the web using a temporary OpenAI assistant with browsing enabled.
/// The query is passed to the assistant and the assistant's final reply is returned.
/// Requires the `OPENAI_API_KEY` environment variable to be set.
pub async fn web_search(query: &str) -> Result<String> {
    let api_key = std::env::var("OPENAI_API_KEY")
        .context("OPENAI_API_KEY environment variable not set")?;

    let client = Client::new();

    // Create a temporary assistant with the browser tool enabled
    let assistant_res: serde_json::Value = client
        .post("https://api.openai.com/v1/assistants")
        .bearer_auth(&api_key)
        .json(&json!({
            "model": "gpt-4o", // default model with browsing
            "instructions": "Answer user questions using web search.",
            "tools": [{"type": "browser"}],
        }))
        .send()
        .await
        .context("Failed to create assistant")?
        .json()
        .await
        .context("Failed to parse assistant response")?;

    let assistant_id = assistant_res["id"].as_str().context("No assistant id")?.to_string();

    // Create a thread
    let thread_res: serde_json::Value = client
        .post("https://api.openai.com/v1/threads")
        .bearer_auth(&api_key)
        .json(&json!({}))
        .send()
        .await
        .context("Failed to create thread")?
        .json()
        .await
        .context("Failed to parse thread response")?;

    let thread_id = thread_res["id"].as_str().context("No thread id")?.to_string();

    // Add user message
    client
        .post(&format!(
            "https://api.openai.com/v1/threads/{}/messages",
            thread_id
        ))
        .bearer_auth(&api_key)
        .json(&json!({"role": "user", "content": query}))
        .send()
        .await
        .context("Failed to add message")?;

    // Start the run
    let run_res: serde_json::Value = client
        .post(&format!(
            "https://api.openai.com/v1/threads/{}/runs",
            thread_id
        ))
        .bearer_auth(&api_key)
        .json(&json!({"assistant_id": assistant_id}))
        .send()
        .await
        .context("Failed to start run")?
        .json()
        .await
        .context("Failed to parse run response")?;

    let run_id = run_res["id"].as_str().context("No run id")?.to_string();

    // Poll the run status
    loop {
        let status_res: serde_json::Value = client
            .get(&format!(
                "https://api.openai.com/v1/threads/{}/runs/{}",
                thread_id, run_id
            ))
            .bearer_auth(&api_key)
            .send()
            .await
            .context("Failed to fetch run status")?
            .json()
            .await
            .context("Failed to parse run status")?;

        match status_res["status"].as_str() {
            Some("completed") => break,
            Some("failed") => return Err(anyhow::anyhow!("run failed")),
            _ => tokio::time::sleep(std::time::Duration::from_secs(1)).await,
        }
    }

    // Fetch messages and return the last assistant response
    let messages_res: serde_json::Value = client
        .get(&format!(
            "https://api.openai.com/v1/threads/{}/messages",
            thread_id
        ))
        .bearer_auth(&api_key)
        .send()
        .await
        .context("Failed to fetch messages")?
        .json()
        .await
        .context("Failed to parse messages")?;

    let messages = messages_res["data"].as_array().context("No messages array")?;
    let response = messages
        .iter()
        .filter(|m| m["role"] == "assistant")
        .max_by_key(|m| m["created_at"].as_i64().unwrap_or(0))
        .and_then(|m| m["content"][0]["text"]["value"].as_str())
        .unwrap_or("")
        .to_string();

    Ok(response)
}
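A minimal usage sketch, assuming `OPENAI_API_KEY` is exported, tokio's `macros` feature is enabled, and `mod web_search;` is declared as in main.rs; the binary itself drives this call through `call_fn` on a fresh Tokio runtime with `block_on` rather than an async main:

```rust
// Hypothetical standalone driver for the function above.
mod web_search;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let answer = web_search::web_search("current weather in Berlin").await?;
    println!("{answer}");
    Ok(())
}
```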