Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ clipboard = "0.5.0"
reqwest = { version = "0.12.1", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
once_cell = "1.19.0"
speakstream = { version = "0.1.2", path = "../speakstream" }
speakstream = "0.1.2"
windows = { version = "0.52.0", features = [
"Win32_System_Com",
"Win32_System_Com_StructuredStorage",
Expand Down
30 changes: 29 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use futures::stream::StreamExt; // For `.next()` on FuturesOrdered.
use std::thread;
use tempfile::Builder;
mod record;
mod web_search;
use crate::default_device_sink::{
default_device_name as get_default_output_device,
list_output_devices as list_audio_output_devices, set_output_device, DefaultDeviceSink,
Expand Down Expand Up @@ -58,6 +59,7 @@ use tracing::{debug, error, info, instrument, warn};
use tracing_appender::rolling::{RollingFileAppender, Rotation};

use speakstream::ss::SpeakStream;
use web_search::search_web;

#[derive(Debug, Subcommand)]
pub enum SubCommands {
Expand Down Expand Up @@ -434,6 +436,22 @@ fn call_fn(
Err(err) => Some(format!("Failed to create runtime: {}", err)),
},

"search_web" => match tokio::runtime::Runtime::new() {
Ok(rt) => {
let args: serde_json::Value = serde_json::from_str(fn_args).unwrap();
let query = args["query"].as_str().unwrap_or("");
let api_key = match std::env::var("OPENAI_API_KEY") {
Ok(k) => k,
Err(_) => return Some("OPENAI_API_KEY not set".to_string()),
};
match rt.block_on(search_web(&api_key, query)) {
Ok(ans) => Some(ans),
Err(err) => Some(format!("Web search failed: {}", err)),
}
}
Err(err) => Some(format!("Failed to create runtime: {}", err)),
},

"set_timer_at" => {
let args: serde_json::Value = serde_json::from_str(fn_args).unwrap();
let time_str = args["time"].as_str().unwrap();
Expand Down Expand Up @@ -1578,7 +1596,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
}))
.build().unwrap(),

ChatCompletionFunctionsArgs::default()
ChatCompletionFunctionsArgs::default()
.name("get_location")
.description("Returns an approximate location based on the machine's IP address.")
.parameters(json!({
Expand All @@ -1588,6 +1606,16 @@ async fn main() -> Result<(), Box<dyn Error>> {
}))
.build().unwrap(),

ChatCompletionFunctionsArgs::default()
.name("search_web")
.description("Searches the web using OpenAI's browser tool.")
.parameters(json!({
"type": "object",
"properties": { "query": { "type": "string" } },
"required": ["query"],
}))
.build().unwrap(),

ChatCompletionFunctionsArgs::default()
.name("list_output_devices")
.description("Lists available audio output devices. The default device is marked with '(Default)'.")
Expand Down
109 changes: 109 additions & 0 deletions src/web_search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use reqwest::Client;
use serde_json::json;
use std::{error::Error, time::Duration};
use tokio::time::sleep;

pub async fn search_web(api_key: &str, query: &str) -> Result<String, Box<dyn Error>> {
let client = Client::new();
let base = "https://api.openai.com/v1";

let assistant_res: serde_json::Value = client
.post(&format!("{}/assistants", base))
.header("Authorization", format!("Bearer {}", api_key))
.header("OpenAI-Beta", "assistants=v1")
.json(&json!({
"model": "gpt-4o",
"instructions": "You are a web search assistant.",
"tools": [{"type": "browser"}]
}))
.send()
.await?
.json()
.await?;

let assistant_id = assistant_res["id"]
.as_str()
.ok_or("missing assistant id")?
.to_string();

let thread_res: serde_json::Value = client
.post(&format!("{}/threads", base))
.header("Authorization", format!("Bearer {}", api_key))
.header("OpenAI-Beta", "assistants=v1")
.send()
.await?
.json()
.await?;

let thread_id = thread_res["id"]
.as_str()
.ok_or("missing thread id")?
.to_string();

client
.post(&format!("{}/threads/{}/messages", base, thread_id))
.header("Authorization", format!("Bearer {}", api_key))
.header("OpenAI-Beta", "assistants=v1")
.json(&json!({"role": "user", "content": query}))
.send()
.await?;

let run_res: serde_json::Value = client
.post(&format!("{}/threads/{}/runs", base, thread_id))
.header("Authorization", format!("Bearer {}", api_key))
.header("OpenAI-Beta", "assistants=v1")
.json(&json!({"assistant_id": assistant_id}))
.send()
.await?
.json()
.await?;

let run_id = run_res["id"].as_str().ok_or("missing run id")?.to_string();

loop {
let run_status: serde_json::Value = client
.get(&format!("{}/threads/{}/runs/{}", base, thread_id, run_id))
.header("Authorization", format!("Bearer {}", api_key))
.header("OpenAI-Beta", "assistants=v1")
.send()
.await?
.json()
.await?;

match run_status["status"].as_str() {
Some("completed") => break,
Some("failed") | Some("expired") | Some("cancelled") => return Err("run failed".into()),
_ => sleep(Duration::from_secs(1)).await,
}
}

let messages: serde_json::Value = client
.get(&format!("{}/threads/{}/messages", base, thread_id))
.header("Authorization", format!("Bearer {}", api_key))
.header("OpenAI-Beta", "assistants=v1")
.send()
.await?
.json()
.await?;

let answer = messages["data"][0]["content"][0]["text"]["value"]
.as_str()
.unwrap_or("")
.to_string();

// cleanup
let _ = client
.delete(&format!("{}/assistants/{}", base, assistant_id))
.header("Authorization", format!("Bearer {}", api_key))
.header("OpenAI-Beta", "assistants=v1")
.send()
.await;
let _ = client
.delete(&format!("{}/threads/{}", base, thread_id))
.header("Authorization", format!("Bearer {}", api_key))
.header("OpenAI-Beta", "assistants=v1")
.send()
.await;

Ok(answer)
}