diff --git a/app/package.json b/app/package.json index 14f1d95db..832144532 100644 --- a/app/package.json +++ b/app/package.json @@ -40,6 +40,7 @@ "rust:check": "cargo check --manifest-path src-tauri/Cargo.toml", "rust:format": "cargo fmt --manifest-path ../Cargo.toml --all && cargo fmt --manifest-path src-tauri/Cargo.toml --all", "rust:format:check": "cargo fmt --manifest-path ../Cargo.toml --all --check && cargo fmt --manifest-path src-tauri/Cargo.toml --all --check", + "rust:clippy": "cargo clippy -p openhuman -- -D warnings", "format": "prettier --write . && yarn rust:format", "format:check": "prettier --check . && yarn rust:format:check", "lint": "eslint . --ext .ts,.tsx", diff --git a/src/api/rest.rs b/src/api/rest.rs index d4726fbf5..2f92f4c8c 100644 --- a/src/api/rest.rs +++ b/src/api/rest.rs @@ -714,9 +714,7 @@ fn key_bytes_from_string(key: &str) -> Result> { let trimmed = key.trim(); // Raw 32-byte ASCII key - if trimmed.len() == 32 - && !trimmed.contains(|c: char| c == '+' || c == '/' || c == '-' || c == '_' || c == '=') - { + if trimmed.len() == 32 && !trimmed.contains(['+', '/', '-', '_', '=']) { return Ok(trimmed.as_bytes().to_vec()); } diff --git a/src/core/all.rs b/src/core/all.rs index 94aab8dd2..1ad195abf 100644 --- a/src/core/all.rs +++ b/src/core/all.rs @@ -57,50 +57,91 @@ fn registry() -> &'static [RegisteredController] { } /// Aggregates all controller implementations from across the codebase. +/// +/// This function is responsible for collecting every domain-specific controller +/// registered in the system. It is used during the initialization of the +/// global [`REGISTRY`]. +/// +/// When adding a new domain/namespace, its `all_*_registered_controllers()` +/// function must be called here to make it available via RPC and CLI. 
fn build_registered_controllers() -> Vec { let mut controllers = Vec::new(); + // Application information and capabilities controllers.extend(crate::openhuman::about_app::all_about_app_registered_controllers()); + // Core application shell state controllers.extend(crate::openhuman::app_state::all_app_state_registered_controllers()); + // Composio integration controllers controllers.extend(crate::openhuman::composio::all_composio_registered_controllers()); + // Scheduled job management controllers.extend(crate::openhuman::cron::all_cron_registered_controllers()); + // Agent definition and prompt inspection controllers.extend(crate::openhuman::agent::all_agent_registered_controllers()); + // System and process health monitoring controllers.extend(crate::openhuman::health::all_health_registered_controllers()); + // Diagnostic tools controllers.extend(crate::openhuman::doctor::all_doctor_registered_controllers()); + // Secret storage and encryption controllers.extend(crate::openhuman::encryption::all_encryption_registered_controllers()); + // Background heartbeat loop controls controllers.extend(crate::openhuman::heartbeat::all_heartbeat_registered_controllers()); + // Token usage and billing cost tracking controllers.extend(crate::openhuman::cost::all_cost_registered_controllers()); + // Inline autocomplete settings controllers.extend(crate::openhuman::autocomplete::all_autocomplete_registered_controllers()); + // External messaging channels (Web, Telegram, etc.) 
controllers.extend( crate::openhuman::channels::providers::web::all_web_channel_registered_controllers(), ); controllers .extend(crate::openhuman::channels::controllers::all_channels_registered_controllers()); + // Persistent configuration management controllers.extend(crate::openhuman::config::all_config_registered_controllers()); + // User credentials and session management controllers.extend(crate::openhuman::credentials::all_credentials_registered_controllers()); + // Desktop service management controllers.extend(crate::openhuman::service::all_service_registered_controllers()); + // Data migration utilities controllers.extend(crate::openhuman::migration::all_migration_registered_controllers()); + // Local AI model management and inference controllers.extend(crate::openhuman::local_ai::all_local_ai_registered_controllers()); + // Screen capture and UI analysis controllers.extend( crate::openhuman::screen_intelligence::all_screen_intelligence_registered_controllers(), ); + // Bridge to external skill runtimes controllers.extend(crate::openhuman::socket::all_socket_registered_controllers()); + // User workspace and file management controllers.extend(crate::openhuman::workspace::all_workspace_registered_controllers()); + // Skill tool registry controllers.extend(crate::openhuman::tools::all_tools_registered_controllers()); + // Document and knowledge graph storage controllers.extend(crate::openhuman::memory::all_memory_registered_controllers()); + // Referral and growth tracking controllers.extend(crate::openhuman::referral::all_referral_registered_controllers()); + // Billing and subscription management controllers.extend(crate::openhuman::billing::all_billing_registered_controllers()); + // Team and role management controllers.extend(crate::openhuman::team::all_team_registered_controllers()); + // OS-level text input interactions controllers.extend(crate::openhuman::text_input::all_text_input_registered_controllers()); + // Voice transcription and synthesis 
controllers.extend(crate::openhuman::voice::all_voice_registered_controllers()); + // Background awareness and autonomous tasks controllers.extend(crate::openhuman::subconscious::all_subconscious_registered_controllers()); + // Webhook tunnel management controllers.extend(crate::openhuman::webhooks::all_webhooks_registered_controllers()); + // Core binary update management controllers.extend(crate::openhuman::update::all_update_registered_controllers()); + // Hierarchical knowledge summarization controllers .extend(crate::openhuman::tree_summarizer::all_tree_summarizer_registered_controllers()); controllers } /// Aggregates all controller schemas from across the codebase. +/// +/// Similar to [`build_registered_controllers`], but only collects the metadata +/// (schema) for each controller. This is used for discovery and validation. fn build_declared_controller_schemas() -> Vec { let mut schemas = Vec::new(); schemas.extend(crate::openhuman::about_app::all_about_app_controller_schemas()); diff --git a/src/core/cli.rs b/src/core/cli.rs index dbb218c77..7135ed791 100644 --- a/src/core/cli.rs +++ b/src/core/cli.rs @@ -29,17 +29,20 @@ Contribute & Star us on GitHub: https://github.com/tinyhumansai/openhuman /// Dispatches CLI commands based on arguments. /// -/// This is the entry point for CLI argument handling. It prints the banner, -/// checks for help requests, and dispatches to specific command handlers -/// like `run`, `call`, `skills`, or namespace-based commands. +/// This is the entry point for CLI argument handling. It performs the following: +/// 1. Prints the ASCII welcome banner to stderr. +/// 2. Resolves and groups available controller schemas. +/// 3. Checks for global help requests. +/// 4. Matches the first argument to a subcommand or a domain namespace. /// /// # Arguments /// -/// * `args` - A slice of strings containing the command-line arguments (excluding the binary name). +/// * `args` - A slice of strings containing the command-line arguments. 
/// /// # Errors /// -/// Returns an error if the command fails or if an unknown command is provided. +/// Returns an error if the command fails, parameters are invalid, or if +/// the subcommand/namespace is unknown. pub fn run_from_cli_args(args: &[String]) -> Result<()> { // Print the welcome banner to stderr to keep stdout clean for JSON output. eprint!("{CLI_BANNER}"); @@ -54,6 +57,7 @@ pub fn run_from_cli_args(args: &[String]) -> Result<()> { match args[0].as_str() { "run" | "serve" => run_server_command(&args[1..]), "call" => run_call_command(&args[1..]), + // Domain-specific CLI adapters that don't follow the generic namespace pattern. "screen-intelligence" => { crate::core::screen_intelligence_cli::run_screen_intelligence_command(&args[1..]) } @@ -70,14 +74,19 @@ pub fn run_from_cli_args(args: &[String]) -> Result<()> { ); crate::core::agent_cli::run_agent_command(&args[1..]) } + // Generic namespace dispatcher: `openhuman ...` namespace => run_namespace_command(namespace, &args[1..], &grouped), } } -/// Loads key/value pairs from a dotenv file into the process environment. +/// Loads key/value pairs from a `.env` file into the process environment. /// -/// Precedence: variables already set in the environment are **not** overwritten. -/// Order: `OPENHUMAN_DOTENV_PATH` (if set to a non-empty path), else `.env` in the current working directory. +/// This is used during the `run` command to load sensitive configurations. +/// +/// Precedence: +/// 1. Variables already set in the process environment are **not** overwritten. +/// 2. If `OPENHUMAN_DOTENV_PATH` is set, that file is loaded. +/// 3. Otherwise, it searches for `.env` in the current working directory. fn load_dotenv_for_server() -> Result<()> { match std::env::var("OPENHUMAN_DOTENV_PATH") { Ok(path) if !path.trim().is_empty() => { @@ -94,11 +103,12 @@ fn load_dotenv_for_server() -> Result<()> { /// Handles the `run` subcommand to start the core HTTP/JSON-RPC server. 
/// -/// Parses flags for port, host, and optional Socket.IO support. +/// This command boots the main application server, including its JSON-RPC +/// endpoint, Socket.IO bridge, and background services (voice, vision, etc.). /// /// # Arguments /// -/// * `args` - Command-line arguments for the `run` command. +/// * `args` - Command-line arguments for the `run` command (e.g., `--port`). fn run_server_command(args: &[String]) -> Result<()> { load_dotenv_for_server()?; @@ -165,7 +175,7 @@ fn run_server_command(args: &[String]) -> Result<()> { crate::core::logging::init_for_cli_run(verbose, log_scope); - // Initialize the Tokio runtime and start the server. + // Initialize the Tokio multi-threaded runtime. let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build()?; @@ -177,7 +187,8 @@ fn run_server_command(args: &[String]) -> Result<()> { /// Handles the `call` subcommand to invoke a JSON-RPC method directly from the CLI. /// -/// Useful for testing and automation. +/// This is used for one-off commands and debugging, bypassing the HTTP transport +/// and calling the internal `invoke_method` directly. /// /// # Arguments /// @@ -222,7 +233,7 @@ fn run_call_command(args: &[String]) -> Result<()> { .block_on(async { invoke_method(default_state(), &method, params).await }) .map_err(anyhow::Error::msg)?; - // Output the result as pretty-printed JSON. + // Output the result as pretty-printed JSON to stdout. println!("{}", serde_json::to_string_pretty(&value)?); Ok(()) } @@ -231,7 +242,6 @@ fn run_call_command(args: &[String]) -> Result<()> { /// /// Listens for a hotkey, records audio, transcribes via whisper, and inserts /// the result into the active text field. 
- fn run_voice_server_command(args: &[String]) -> Result<()> { use crate::openhuman::voice::hotkey::ActivationMode; use crate::openhuman::voice::server::{run_standalone, VoiceServerConfig}; diff --git a/src/core/dispatch.rs b/src/core/dispatch.rs index aceae3bbe..05a873c48 100644 --- a/src/core/dispatch.rs +++ b/src/core/dispatch.rs @@ -9,18 +9,22 @@ use serde_json::json; /// Dispatches an RPC method call to the appropriate subsystem. /// -/// It first attempts to route the request to the core subsystem (e.g., `core.ping`). -/// If not found, it delegates to the `openhuman` domain-specific dispatcher. +/// This is the primary entry point for all RPC calls. It uses a tiered routing +/// strategy: +/// 1. **Core Subsystem**: Checks for internal methods like `core.ping` or `core.version`. +/// 2. **Domain-Specific Handlers**: Delegates to the `openhuman` domain dispatcher +/// which handles all registered controllers (memory, skills, etc.). /// /// # Arguments /// -/// * `state` - The current application state. -/// * `method` - The name of the RPC method to invoke. +/// * `state` - The current application state (e.g., core version). +/// * `method` - The name of the RPC method to invoke (e.g., `core.ping`). /// * `params` - The parameters for the method call as a JSON value. /// /// # Returns /// -/// A `Result` containing the JSON-formatted response or an error message. +/// A `Result` containing the JSON-formatted response or an error message if +/// the method is unknown or invocation fails. pub async fn dispatch( state: AppState, method: &str, @@ -32,13 +36,16 @@ pub async fn dispatch( rpc_log::redact_params_for_log(¶ms) ); - // Try routing to internal core methods first. + // Tier 1: Internal core methods. + // These are handled directly within the core module and don't require + // a separate controller registration. 
if let Some(result) = try_core_dispatch(&state, method, params.clone()) { log::debug!("[rpc:dispatch] routed method={} subsystem=core", method); return result.map(crate::core::types::invocation_to_rpc_json); } - // Delegate to the domain-specific dispatcher. + // Tier 2: Domain-specific dispatcher. + // This routes to controllers registered in src/core/all.rs and src/rpc/mod.rs. if let Some(result) = crate::rpc::try_dispatch(method, params).await { log::debug!( "[rpc:dispatch] routed method={} subsystem=openhuman", @@ -53,9 +60,11 @@ pub async fn dispatch( /// Handles internal core-level RPC methods. /// -/// Currently supports: -/// - `core.ping`: Returns `{ "ok": true }`. -/// - `core.version`: Returns the current core version. +/// These methods provide basic information about the server and its version. +/// +/// Currently supported methods: +/// - `core.ping`: A simple liveness check. Returns `{ "ok": true }`. +/// - `core.version`: Returns the version of the running core binary. fn try_core_dispatch( state: &AppState, method: &str, diff --git a/src/core/event_bus/bus.rs b/src/core/event_bus/bus.rs index e6e9955e3..508ece046 100644 --- a/src/core/event_bus/bus.rs +++ b/src/core/event_bus/bus.rs @@ -26,12 +26,16 @@ pub const DEFAULT_CAPACITY: usize = 256; /// Initialize the global event bus. Must be called **once** during startup. /// +/// This function: +/// 1. Initializes the native request registry. +/// 2. Sets up the global singleton bus with the specified capacity. +/// /// Subsequent calls return the already-initialized bus without changing -/// the capacity. Panics are impossible — `OnceLock` guarantees single init. +/// its capacity. This ensures thread-safe, consistent initialization. +/// +/// # Arguments /// -/// This also initializes the native request registry so that any domain -/// can immediately register handlers and dispatch requests without worrying -/// about startup ordering. 
+/// * `capacity` - The maximum number of buffered events for the broadcast channel. pub fn init_global(capacity: usize) -> &'static EventBus { // Initialize the native request registry first so handler registration // is always safe from anywhere in the process once the bus is up. @@ -42,14 +46,21 @@ pub fn init_global(capacity: usize) -> &'static EventBus { }) } -/// Get the global event bus, or `None` if [`init_global`] has not been called. +/// Get the global event bus. +/// +/// Returns `Some(&EventBus)` if [`init_global`] has been called, otherwise `None`. pub fn global() -> Option<&'static EventBus> { GLOBAL_BUS.get() } /// Publish an event on the global bus. /// -/// Silently does nothing if the global bus is not yet initialized. +/// This is the primary way to notify other modules about domain events +/// (e.g., an agent turn completed, a memory was stored). +/// +/// # Arguments +/// +/// * `event` - The [`DomainEvent`] to broadcast to all subscribers. pub fn publish_global(event: DomainEvent) { if let Some(bus) = GLOBAL_BUS.get() { bus.publish(event); @@ -60,34 +71,41 @@ pub fn publish_global(event: DomainEvent) { /// Subscribe a handler on the global bus. /// -/// Silently does nothing and returns `None` if the bus is not yet initialized. +/// The handler will receive all events that match its domain filter. +/// Returns a [`SubscriptionHandle`] that will cancel the subscription when dropped. +/// +/// # Arguments +/// +/// * `handler` - An implementation of the [`EventHandler`] trait. pub fn subscribe_global(handler: Arc) -> Option { GLOBAL_BUS.get().map(|bus| bus.subscribe(handler)) } // ── EventBus struct ───────────────────────────────────────────────────── -/// The event bus. There is exactly **one** instance at runtime, accessed -/// through the module-level functions ([`init_global`], [`publish_global`], -/// [`subscribe_global`], [`global`]). +/// The event bus, wrapping a `tokio::sync::broadcast` channel. 
/// -/// Direct construction is restricted to `pub(crate)` for test isolation. +/// It provides a many-to-many communication channel for [`DomainEvent`]s. +/// There is exactly **one** production instance at runtime (the global singleton). #[derive(Clone, Debug)] pub struct EventBus { + /// The sending end of the broadcast channel. tx: broadcast::Sender, } impl EventBus { - /// Create a new event bus. **Only** exposed within the crate for testing; - /// production code must use [`init_global`]. + /// Create a new event bus with the given capacity. + /// + /// This is used internally by [`init_global`] and by tests for isolation. pub(crate) fn create(capacity: usize) -> Self { let (tx, _) = broadcast::channel(capacity.max(1)); Self { tx } } - /// Publish an event to all subscribers. + /// Publish an event to all active subscribers. /// - /// Silently drops the event if no subscribers are listening. + /// The event is cloned and sent to each subscriber's receiving end. + /// If no subscribers are currently listening, the event is silently dropped. pub fn publish(&self, event: DomainEvent) { let receiver_count = self.tx.receiver_count(); tracing::debug!( @@ -99,10 +117,19 @@ impl EventBus { let _ = self.tx.send(event); } - /// Subscribe with an [`EventHandler`] implementation. + /// Subscribe with a trait-based [`EventHandler`]. + /// + /// Spawns a background task that listens for events and dispatches them + /// to the handler's `handle` method. + /// + /// # Arguments + /// + /// * `handler` - The handler to register. Its `domains()` filter is checked + /// before every dispatch. + /// + /// # Returns /// - /// Returns a [`SubscriptionHandle`] that cancels the subscriber when dropped. - /// The handler's optional `domains()` filter is applied before dispatching. + /// A [`SubscriptionHandle`] to manage the lifetime of the background task. 
pub fn subscribe(&self, handler: Arc) -> SubscriptionHandle { let mut rx = self.tx.subscribe(); let name = handler.name().to_string(); @@ -121,7 +148,8 @@ impl EventBus { loop { match rx.recv().await { Ok(event) => { - // Apply domain filter + // Apply domain filter: only dispatch if the event domain + // matches one of the subscriber's allowed domains. if let Some(ref allowed) = domains { if !allowed.iter().any(|d| d == event.domain()) { continue; @@ -132,6 +160,8 @@ impl EventBus { domain = event.domain(), "[event_bus] dispatching to handler" ); + // Wrap the handler call in AssertUnwindSafe so that a + // panic in one handler doesn't crash the entire event loop. let result = AssertUnwindSafe(handler.handle(&event)) .catch_unwind() .await; @@ -170,10 +200,11 @@ impl EventBus { SubscriptionHandle::new(name, task) } - /// Subscribe with an async closure for simple, one-off handlers. + /// Subscribe with an async closure. /// - /// Domain filtering is not supported with this shorthand — use - /// [`subscribe`] with an [`EventHandler`] for domain-filtered subscriptions. + /// This is a convenience method for simple, one-off event handlers. + /// It doesn't support domain filtering directly; the closure will receive + /// every event published on the bus. pub fn on(&self, name: &str, handler: F) -> SubscriptionHandle where F: Fn(&DomainEvent) -> std::pin::Pin + Send + '_>> @@ -188,7 +219,7 @@ impl EventBus { self.subscribe(subscriber) } - /// Returns the current number of active subscribers. + /// Returns the current number of active subscribers (receivers). pub fn subscriber_count(&self) -> usize { self.tx.receiver_count() } diff --git a/src/core/event_bus/mod.rs b/src/core/event_bus/mod.rs index 75ba6c507..1dc7c85d8 100644 --- a/src/core/event_bus/mod.rs +++ b/src/core/event_bus/mod.rs @@ -1,21 +1,36 @@ //! Cross-module event bus for decoupled events and typed in-process requests. //! //! The event bus is a **singleton** — one instance for the entire application. 
+//! It serves as the central nervous system of OpenHuman, allowing different +//! modules (like memory, skills, and agents) to communicate without +//! direct dependencies. +//! //! Call [`init_global`] once at startup, then use [`publish_global`], //! [`subscribe_global`], [`register_native_global`], and //! [`request_native_global`] from any module. //! -//! # Two surfaces +//! # Two Surfaces +//! +//! 1. **Broadcast Pub/Sub** ([`publish_global`] / [`subscribe_global`]) +//! - Built on `tokio::sync::broadcast`. +//! - **Many-to-many**: One publisher, zero or more subscribers. +//! - **Fire-and-forget**: No feedback from subscribers to the publisher. +//! - **Decoupled**: Use this for notifications like "a message was received" +//! or "a skill was loaded". +//! +//! 2. **Native Request/Response** ([`register_native_global`] / [`request_native_global`]) +//! - **One-to-one**: Each method name has exactly one registered handler. +//! - **Typed**: Payloads are Rust types, checked at runtime via `TypeId`. +//! - **Zero Serialization**: Directly passes pointers, `Arc`s, and channels. +//! - **Coupled (but in-process)**: Use this for direct module-to-module +//! calls that need non-serializable data or immediate responses. +//! +//! # Architecture //! -//! 1. **Broadcast pub/sub** ([`publish_global`] / [`subscribe_global`]) — -//! fire-and-forget notification of [`DomainEvent`] variants. One publisher, -//! many subscribers, no back-channel. -//! 2. **Native request/response** ([`register_native_global`] / -//! [`request_native_global`]) — one-to-one typed Rust dispatch keyed by a -//! method string. Zero serialization: trait objects, [`std::sync::Arc`]s, -//! [`tokio::sync::mpsc::Sender`]s, and oneshot channels pass through -//! unchanged. Use this for in-process module-to-module calls that need -//! non-serializable payloads (hot-path data, streaming, async resolution). +//! The bus is designed to be initialized early in the application lifecycle. +//! 
Once [`init_global`] is called, the bus is available globally. This allows +//! modules to register their handlers and subscribers in their own `bus.rs` +//! or `mod.rs` files during startup. //! //! # Usage //! @@ -25,15 +40,16 @@ //! subscribe_global, DomainEvent, //! }; //! -//! // Publish a broadcast event +//! // Example 1: Broadcasting a system event //! publish_global(DomainEvent::SystemStartup { component: "example".into() }); //! -//! // Register a native request handler at startup +//! // Example 2: Registering a native request handler //! register_native_global::("my_domain.do_thing", |req| async move { +//! // Process request... //! Ok(MyResp { /* ... */ }) -//! }).await; +//! }); //! -//! // Dispatch a native request from any module +//! // Example 3: Dispatching a native request //! let resp: MyResp = request_native_global("my_domain.do_thing", MyReq { /* ... */ }).await?; //! ``` diff --git a/src/core/event_bus/native_request.rs b/src/core/event_bus/native_request.rs index 50608409c..d8db40901 100644 --- a/src/core/event_bus/native_request.rs +++ b/src/core/event_bus/native_request.rs @@ -117,18 +117,21 @@ struct HandlerEntry { /// Registry of native, in-process typed request handlers. /// -/// Handlers are keyed by a method name (`"agent.run_turn"`, -/// `"approval.prompt"`, …) and store the request/response `TypeId` so -/// callers that disagree about types get a structured error instead of a -/// panic or silent corruption. +/// Handlers are keyed by a method name (e.g., `"agent.run_turn"`) and store the +/// expected request and response types. This enables safe, typed communication +/// between different modules without the overhead of serialization. +/// +/// The registry is thread-safe, using a `RwLock` to allow concurrent lookups +/// while guarding registrations. #[derive(Clone, Default)] pub struct NativeRegistry { + /// Internal map of method names to their handler entries. 
handlers: Arc>>, } impl std::fmt::Debug for NativeRegistry { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // Non-blocking read attempt; if contended, fall back to opaque. + // Non-blocking read attempt to avoid deadlocks during debugging. match self.handlers.try_read() { Ok(guard) => f .debug_struct("NativeRegistry") @@ -142,25 +145,33 @@ impl std::fmt::Debug for NativeRegistry { } } -/// Recover from `RwLock` poison by taking the inner guard. The registry -/// holds simple data (`HashMap`) — a panicked writer cannot leave it in an -/// invalid state, so it's safe to continue. +/// Recover from `RwLock` poison by taking the inner guard. +/// +/// If a thread panics while holding the lock, the lock becomes "poisoned". +/// Since the registry only holds a simple `HashMap`, we can safely ignore +/// the poison and continue using the registry. fn unpoison(result: Result>) -> T { result.unwrap_or_else(|e| e.into_inner()) } impl NativeRegistry { + /// Creates a new, empty registry. pub fn new() -> Self { Self::default() } - /// Register a handler for `method`. If a handler already exists for - /// this method, it is replaced — tests rely on this to override - /// production handlers. + /// Register a handler for a specific method name. + /// + /// If a handler already exists for the method, it will be replaced. + /// This is particularly useful in tests for overriding production handlers + /// with mocks or stubs. /// - /// Synchronous because registration only inserts into an in-memory - /// map. The handler itself still produces an async future when it - /// runs; only the registration step is sync. + /// # Type Parameters + /// + /// * `Req` - The request type. Must implement `Send + 'static`. + /// * `Resp` - The response type. Must implement `Send + 'static`. + /// * `F` - The handler function/closure. + /// * `Fut` - The future returned by the handler. 
pub fn register(&self, method: &str, handler: F) where Req: Send + 'static, @@ -168,6 +179,7 @@ impl NativeRegistry { F: Fn(Req) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { + // Wrap the typed handler in a type-erased closure. let handler_arc: BoxedHandler = Arc::new(move |boxed: BoxedAny| { // This downcast is infallible: the dispatch path verifies // TypeId equality before invoking the handler. @@ -186,6 +198,7 @@ impl NativeRegistry { resp_name: std::any::type_name::(), }; + // Insert the handler under a write lock. let previous = unpoison(self.handlers.write()).insert(method.to_string(), entry); if previous.is_some() { @@ -205,17 +218,17 @@ impl NativeRegistry { } } - /// Dispatch a typed request to the registered handler for `method`. + /// Dispatch a typed request to a registered handler. + /// + /// This method performs runtime type checks to ensure the caller and handler + /// agree on the request and response types. /// - /// Returns [`NativeRequestError::UnregisteredHandler`] if no handler - /// is registered, [`NativeRequestError::TypeMismatch`] if the caller's - /// `Req` or `Resp` doesn't match the registered handler, or - /// [`NativeRequestError::HandlerFailed`] if the handler itself - /// returned an error. + /// # Errors /// - /// The read lock is acquired, the handler `Arc` is cloned, and the - /// lock is dropped — all before the handler future is awaited. This - /// means slow handlers never block concurrent dispatches or registrations. + /// Returns a [`NativeRequestError`] if: + /// - No handler is registered for the method. + /// - There is a type mismatch for the request or response. + /// - The handler returns an error. pub async fn request( &self, method: &str, @@ -225,9 +238,9 @@ impl NativeRegistry { Req: Send + 'static, Resp: Send + 'static, { - // Lookup + cheap clone of the handler Arc under the read lock, - // then drop the lock before awaiting the handler future. 
Scoped - // so the `guard` goes out of scope at the end of the block. + // Lookup the handler and clone its metadata under a read lock. + // We drop the lock BEFORE awaiting the handler's future to avoid + // blocking other threads. let (handler, expected_req, expected_resp, expected_req_name, expected_resp_name) = { let guard = unpoison(self.handlers.read()); let entry = @@ -245,6 +258,7 @@ impl NativeRegistry { ) }; + // Verify that the caller's request type matches the registered type. if TypeId::of::() != expected_req { return Err(NativeRequestError::TypeMismatch { method: method.to_string(), @@ -252,6 +266,7 @@ impl NativeRegistry { actual: std::any::type_name::(), }); } + // Verify that the caller's response type matches the registered type. if TypeId::of::() != expected_resp { return Err(NativeRequestError::TypeMismatch { method: method.to_string(), @@ -267,10 +282,10 @@ impl NativeRegistry { ); let boxed_req: BoxedAny = Box::new(req); + // Invoke the handler and await its completion. match handler(boxed_req).await { Ok(boxed_resp) => { - // Infallible: TypeId check above guarantees the handler - // produced the right Resp type. + // Infallible: the TypeId check above guarantees the correct type. let resp = *boxed_resp.downcast::().expect( "native_request: handler returned wrong response type despite TypeId check", ); diff --git a/src/core/jsonrpc.rs b/src/core/jsonrpc.rs index 3c42c239a..03b584cf7 100644 --- a/src/core/jsonrpc.rs +++ b/src/core/jsonrpc.rs @@ -24,8 +24,16 @@ use crate::core::types::{AppState, RpcError, RpcFailure, RpcRequest, RpcSuccess} /// Axum handler for JSON-RPC POST requests. /// -/// It parses the request, invokes the requested method, and returns a -/// JSON-RPC 2.0 compliant success or failure response. +/// This function: +/// 1. Receives a JSON-RPC request body. +/// 2. Extracts the method name and parameters. +/// 3. Invokes the corresponding handler via [`invoke_method`]. +/// 4. 
Wraps the result or error in a JSON-RPC 2.0 compliant response. +/// +/// # Arguments +/// +/// * `state` - The application state, injected by Axum. +/// * `req` - The parsed [`RpcRequest`]. pub async fn rpc_handler(State(state): State, Json(req): Json) -> Response { let id = req.id.clone(); let method = req.method.clone(); @@ -67,8 +75,9 @@ pub async fn rpc_handler(State(state): State, Json(req): Json, Json(req): Json Result { let result = invoke_method_inner(state, method, params).await; - // If the RPC call failed due to an expired/invalid session token (401 from - // the backend), automatically clear the stored session so the frontend - // detects the logged-out state and redirects to login. + // Session auto-cleanup: If the backend says we're unauthorized, + // we should reflect that locally by clearing the stored token. if let Err(ref msg) = result { if is_session_expired_error(msg) { log::warn!( @@ -96,6 +104,7 @@ pub async fn invoke_method(state: AppState, method: &str, params: Value) -> Resu result } +/// Helper to determine if an error message indicates an expired or invalid session. fn is_session_expired_error(msg: &str) -> bool { let lower = msg.to_lowercase(); (lower.contains("401") && lower.contains("unauthorized")) @@ -103,13 +112,21 @@ fn is_session_expired_error(msg: &str) -> bool { || msg.contains("SESSION_EXPIRED") } +/// Internal method invocation logic. +/// +/// It first attempts to match the method name against the static controller +/// registry (schemas). If a schema is found, it validates the input parameters +/// before execution. If no schema matches, it falls back to the dynamic +/// [`crate::core::dispatch::dispatch`] system. async fn invoke_method_inner( state: AppState, method: &str, params: Value, ) -> Result { + // Phase 1: Check static controller registry. if let Some(schema) = all::schema_for_rpc_method(method) { let params_obj = params_to_object(params)?; + // Validate inputs against the schema before calling the handler. 
all::validate_params(&schema, ¶ms_obj)?; if let Some(result) = all::try_invoke_registered_rpc(method, params_obj).await { return result; @@ -117,10 +134,14 @@ async fn invoke_method_inner( return Err(format!("registered schema has no handler: {method}")); } + // Phase 2: Fall back to dynamic dispatch (internal core methods or legacy paths). crate::core::dispatch::dispatch(state, method, params).await } /// Converts JSON parameters into a map, ensuring they are in object format. +/// +/// JSON-RPC allows parameters to be an Object, an Array, or Null. This implementation +/// primarily supports Object parameters for named-argument style calls. fn params_to_object(params: Value) -> Result, String> { match params { Value::Object(map) => Ok(map), diff --git a/src/core/logging.rs b/src/core/logging.rs index d2e37ee1c..f0eb1cf58 100644 --- a/src/core/logging.rs +++ b/src/core/logging.rs @@ -26,7 +26,12 @@ pub enum CliLogDefault { AutocompleteOnly, } -/// `14:32:01 (jsonrpc) message…` — colors when stderr is a TTY. +/// Custom log formatter for the OpenHuman CLI. +/// +/// It produces a clean, readable output on stderr: +/// `14:32:01 INF:jsonrpc: Listening on http://127.0.0.1:7788` +/// +/// It supports ANSI colors if the output is a terminal and `NO_COLOR` is not set. struct CleanCliFormat; impl FormatEvent for CleanCliFormat @@ -34,6 +39,7 @@ where S: tracing::Subscriber + for<'a> LookupSpan<'a>, N: for<'a> FormatFields<'a> + 'static, { + /// Formats a single tracing event into a string and writes it to the writer. fn format_event( &self, ctx: &FmtContext<'_, S, N>, @@ -41,10 +47,12 @@ where event: &Event<'_>, ) -> fmt::Result { let meta = event.metadata(); + // Use local time for log timestamps. let time = chrono::Local::now().format("%H:%M:%S"); let level = level_tag(meta.level()); let target = short_target(meta.target()); + // Check if the writer supports ANSI escape codes for coloring. 
if writer.has_ansi_escapes() { let time_styled = Style::new().dimmed().paint(time.to_string()); write!(writer, "{time_styled}:")?; @@ -59,18 +67,22 @@ where }; write!(writer, "{level_styled}:")?; + // Scope color: pick a neutral gray for the module name. let scope = target.to_string(); let scope_styled = Style::new().fg(Color::Fixed(247)).paint(scope); write!(writer, "{scope_styled} ")?; } else { + // Plain text fallback (e.g., when logging to a file or non-TTY). write!(writer, "{time}:{level}:{target} ")?; } + // Write the actual log message and its fields. ctx.field_format().format_fields(writer.by_ref(), event)?; writeln!(writer) } } +/// Returns a 3-letter uppercase tag for each log level. fn level_tag(level: &Level) -> &'static str { match *level { Level::ERROR => "ERR", @@ -81,10 +93,14 @@ fn level_tag(level: &Level) -> &'static str { } } +/// Shortens a Rust module path (e.g., `openhuman_core::rpc` -> `rpc`). fn short_target(target: &str) -> &str { target.rsplit("::").next().unwrap_or(target) } +/// Parses a comma-separated list of file/module constraints from environment. +/// +/// Used to filter logs to specific parts of the codebase. fn parse_log_file_constraints() -> Vec { std::env::var("OPENHUMAN_LOG_FILE_CONSTRAINTS") .ok() @@ -98,6 +114,7 @@ fn parse_log_file_constraints() -> Vec { .unwrap_or_default() } +/// Checks if a log event matches any of the configured file/module constraints. fn event_matches_file_constraints(meta: &tracing::Metadata<'_>, constraints: &[String]) -> bool { if constraints.is_empty() { return true; @@ -110,12 +127,20 @@ fn event_matches_file_constraints(meta: &tracing::Metadata<'_>, constraints: &[S .any(|constraint| file.contains(constraint) || target.contains(constraint)) } -/// Initialize `tracing` + bridge the `log` crate so existing `log::info!` calls appear. +/// Initialize the global `tracing` subscriber and bridge the `log` crate. +/// +/// This function: +/// 1. 
Determines the default log level based on `verbose` and `default_scope`. +/// 2. Sets up an `EnvFilter` from `RUST_LOG` or the defaults. +/// 3. Detects terminal capabilities for ANSI colors. +/// 4. Registers a formatting layer with [`CleanCliFormat`]. +/// 5. Integrates Sentry for error tracking. +/// 6. Bridges legacy `log::info!` macros. /// -/// - If `RUST_LOG` is unset: uses [`CliLogDefault`] and `verbose` to pick a default filter string. -/// - Safe to call once; subsequent calls are ignored. +/// It is idempotent and will only initialize the subscriber once per process. pub fn init_for_cli_run(verbose: bool, default_scope: CliLogDefault) { INIT.call_once(|| { + // Set RUST_LOG environment variable if not already set by the user. if std::env::var_os("RUST_LOG").is_none() { let default = match default_scope { CliLogDefault::Global => { @@ -133,6 +158,7 @@ pub fn init_for_cli_run(verbose: bool, default_scope: CliLogDefault) { std::env::set_var("RUST_LOG", default); } + // Try parsing the EnvFilter from environment or use defaults. let filter = tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { match default_scope { CliLogDefault::Global => { @@ -147,13 +173,7 @@ pub fn init_for_cli_run(verbose: bool, default_scope: CliLogDefault) { } }); - // Color resolution order (standard conventions): - // 1. `NO_COLOR` (any value) → force off. - // 2. `FORCE_COLOR` or `CLICOLOR_FORCE` → force on. Useful when the - // core runs as a child of the Tauri shell under `yarn tauri dev`, - // where the grandchild's stderr may not be detected as a TTY even - // though the ultimate terminal supports ANSI. - // 3. Fall back to TTY detection on stderr. + // Color resolution logic. let use_color = if std::env::var_os("NO_COLOR").is_some() { false } else if std::env::var_os("FORCE_COLOR").is_some() @@ -161,10 +181,12 @@ pub fn init_for_cli_run(verbose: bool, default_scope: CliLogDefault) { { true } else { + // Auto-detect based on stderr terminal status. 
io::stderr().is_terminal() }; let file_constraints = parse_log_file_constraints(); + // Build the primary formatting layer. let fmt_layer = tracing_subscriber::fmt::layer() .with_ansi(use_color) .event_format(CleanCliFormat) @@ -172,6 +194,7 @@ pub fn init_for_cli_run(verbose: bool, default_scope: CliLogDefault) { event_matches_file_constraints(meta, &file_constraints) })); + // Build the Sentry integration layer. let sentry_layer = sentry::integrations::tracing::layer().event_filter(|md: &tracing::Metadata<'_>| { match *md.level() { @@ -183,12 +206,14 @@ pub fn init_for_cli_run(verbose: bool, default_scope: CliLogDefault) { } }); + // Register the subscriber with all layers. let _ = tracing_subscriber::registry() .with(filter) .with(fmt_layer) .with(sentry_layer) .try_init(); + // Bridge the `log` crate. let _ = tracing_log::LogTracer::init(); }); } diff --git a/src/core/mod.rs b/src/core/mod.rs index 8fbb2b405..257fed087 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,4 +1,9 @@ //! Shared core-level schemas and contracts used across adapters (RPC, CLI, etc.). +//! +//! This module defines the foundational types for OpenHuman's controller system, +//! which provides a transport-agnostic way to define and invoke domain logic. +//! It also exports submodules for CLI handling, event bus, and RPC server. + use serde::Serialize; pub mod agent_cli; @@ -21,63 +26,88 @@ pub mod types; /// Canonical function contract for domain controllers. /// /// This shape is transport-agnostic and can be consumed by RPC and CLI layers -/// in different ways. +/// in different ways. It defines the identity, purpose, and I/O signature +/// of a specific piece of domain logic. #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct ControllerSchema { /// Domain/group identifier, e.g. `memory`, `config`, `credentials`. + /// This forms the first part of the RPC method name. pub namespace: &'static str, /// Function identifier inside namespace, e.g. `doc_put`. 
+ /// This forms the second part of the RPC method name. pub function: &'static str, - /// One-line human-readable purpose. + /// One-line human-readable purpose, used for CLI help and API documentation. pub description: &'static str, /// Ordered input parameters accepted by the controller function. + /// Each input is a field with a name, type, and description. pub inputs: Vec, /// Ordered output fields returned by the controller function. + /// This defines the structure of the successful response. pub outputs: Vec, } impl ControllerSchema { /// Canonical dotted name for routing, e.g. `memory.doc_put`. + /// This is used internally to identify the controller. pub fn method_name(&self) -> String { format!("{}.{}", self.namespace, self.function) } } /// Schema for one input/output field. +/// +/// Defines the properties of a single parameter or return value, +/// enabling validation and documentation generation. #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct FieldSchema { - /// Field name. + /// Field name. Used as the key in JSON objects or as a CLI flag. pub name: &'static str, - /// Field type. + /// Field type, defining the expected data shape and enabling validation. pub ty: TypeSchema, - /// Human-readable description for docs/help. + /// Human-readable description for docs/help. Should explain what the field is for. pub comment: &'static str, /// Requiredness for adapters: - /// - input: required argument/flag - /// - output: always-present field when true + /// - input: if true, the argument/flag MUST be provided. + /// - output: if true, the field is guaranteed to be present in the response. pub required: bool, } /// Type-system shape used by controller input/output schema fields. +/// +/// This enum represents the set of supported types that can be passed +/// across the controller boundary. #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub enum TypeSchema { + /// A boolean value (true/false). Bool, + /// A 64-bit signed integer. 
I64, + /// A 64-bit unsigned integer. U64, + /// A 64-bit floating point number. F64, + /// A UTF-8 encoded string. String, + /// A generic JSON value (serde_json::Value). Json, + /// Raw binary data. Bytes, + /// An ordered list of values of a specific type. Array(Box), /// String-keyed map/object with homogeneous values. Map(Box), + /// An optional value that may be null or a value of the inner type. Option(Box), + /// A string that must match one of the predefined variants. Enum { + /// The list of allowed string variants. variants: Vec<&'static str>, }, + /// A nested object with its own set of fields. Object { + /// The fields defining the object's structure. fields: Vec, }, - /// Reference to a named shared/domain type. + /// Reference to a named shared/domain type defined elsewhere. Ref(&'static str), } diff --git a/src/core/rpc_log.rs b/src/core/rpc_log.rs index 1fc069d00..35191aefc 100644 --- a/src/core/rpc_log.rs +++ b/src/core/rpc_log.rs @@ -1,5 +1,9 @@ use serde_json::Value; +/// Formats a JSON-RPC request ID into a human-readable string. +/// +/// Handles different JSON types (String, Number, Null) to ensure consistent +/// output in log messages. pub fn format_request_id(id: &Value) -> String { match id { Value::String(s) => s.clone(), @@ -9,10 +13,18 @@ pub fn format_request_id(id: &Value) -> String { } } +/// Redacts sensitive keys from a JSON parameters object before logging. +/// +/// This is used to prevent accidental leakage of API keys, tokens, and passwords +/// in debug logs. pub fn redact_params_for_log(params: &Value) -> Value { redact_value(params) } +/// Produces a short summary of a JSON value, useful for high-level logging. +/// +/// Instead of printing a potentially massive object/array, it returns a +/// string like `object(keys=foo,bar)` or `array(len=10)`. 
pub fn summarize_rpc_result(result: &Value) -> String { match result { Value::Object(map) => { @@ -28,10 +40,15 @@ pub fn summarize_rpc_result(result: &Value) -> String { } } +/// Redacts sensitive keys from a JSON result object before trace logging. pub fn redact_result_for_trace(result: &Value) -> Value { redact_value(result) } +/// Recursively redacts sensitive information from a JSON value. +/// +/// It traverses objects and arrays, replacing values of keys that match +/// [`is_sensitive_key`] with `[REDACTED]`. fn redact_value(value: &Value) -> Value { match value { Value::Object(map) => { @@ -50,6 +67,7 @@ fn redact_value(value: &Value) -> Value { } } +/// Returns true if a key name is considered sensitive (e.g., "api_key", "password"). fn is_sensitive_key(key: &str) -> bool { matches!( key, diff --git a/src/core/shutdown.rs b/src/core/shutdown.rs index 42203f0f6..fe93ecb0f 100644 --- a/src/core/shutdown.rs +++ b/src/core/shutdown.rs @@ -19,7 +19,16 @@ static HOOKS: Lazy>> = Lazy::new(|| Mutex::new(Vec::new( /// Register a cleanup function to run on graceful shutdown. /// -/// Hooks execute sequentially in registration order. +/// Use this to perform necessary cleanup tasks such as stopping background +/// services, flushing caches, or closing database connections when the +/// application is shutting down. +/// +/// Hooks execute sequentially in the order they were registered. +/// +/// # Arguments +/// +/// * `hook` - A function that returns a future. The future will be awaited +/// during the shutdown process. pub fn register(hook: F) where F: Fn() -> Fut + Send + Sync + 'static, @@ -30,9 +39,12 @@ where } /// Run all registered hooks (called once during shutdown). +/// +/// This function drains the global `HOOKS` list and awaits each hook in sequence. async fn run_hooks() { let hooks: Vec = { let mut guard = HOOKS.lock().expect("shutdown hooks poisoned"); + // Use mem::take to clear the hooks list and take ownership of the vector. 
std::mem::take(&mut *guard) }; for hook in &hooks { @@ -43,14 +55,21 @@ async fn run_hooks() { /// Returns a future that resolves when the process receives a termination /// signal (SIGINT on all platforms, plus SIGTERM on Unix), then runs all /// registered shutdown hooks. +/// +/// This is intended to be used with [`axum::serve`]'s `with_graceful_shutdown` +/// method or in the main loop to handle clean exits. pub async fn signal() { + // Wait for the OS to send a termination signal. wait_for_signal().await; log::info!("[core] shutdown signal received, cleaning up background services"); + // Once received, run all registered cleanup tasks. run_hooks().await; log::info!("[core] all shutdown hooks completed"); } -/// Wait for either SIGINT or SIGTERM (Unix) / just SIGINT (non-Unix). +/// Wait for either SIGINT (Ctrl-C) or SIGTERM (Unix termination signal). +/// +/// This uses `tokio::signal` to asynchronously wait for these events. async fn wait_for_signal() { #[cfg(unix)] { @@ -59,7 +78,7 @@ async fn wait_for_signal() { signal(SignalKind::terminate()).expect("failed to install SIGTERM handler"); tokio::select! { _ = tokio::signal::ctrl_c() => { - log::info!("[core] received SIGINT"); + log::info!("[core] received SIGINT (Ctrl-C)"); } _ = sigterm.recv() => { log::info!("[core] received SIGTERM"); @@ -69,7 +88,8 @@ async fn wait_for_signal() { #[cfg(not(unix))] { + // On non-Unix platforms (like Windows), we only listen for Ctrl-C. let _ = tokio::signal::ctrl_c().await; - log::info!("[core] received SIGINT"); + log::info!("[core] received SIGINT (Ctrl-C)"); } } diff --git a/src/core/socketio.rs b/src/core/socketio.rs index 8790f557b..da9820ea0 100644 --- a/src/core/socketio.rs +++ b/src/core/socketio.rs @@ -4,29 +4,46 @@ use serde_json::json; use socketioxide::extract::{Data, SocketRef}; use socketioxide::SocketIo; +/// Standard event payload for the web channel transport. 
+/// +/// This structure defines the data sent to Socket.IO clients for various +/// chat-related events, such as message delivery, tool execution, and errors. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct WebChannelEvent { + /// The event name (e.g., `chat_message`, `tool_call`). pub event: String, + /// Unique identifier for the Socket.IO client. pub client_id: String, + /// Identifier for the specific chat thread. pub thread_id: String, + /// Unique identifier for the individual request/turn. pub request_id: String, + /// The full text of the assistant's response (sent on completion). #[serde(skip_serializing_if = "Option::is_none")] pub full_response: Option, + /// A partial message segment or an error description. #[serde(skip_serializing_if = "Option::is_none")] pub message: Option, + /// Type of error, if the event represents a failure. #[serde(skip_serializing_if = "Option::is_none")] pub error_type: Option, + /// Name of the tool being called. #[serde(skip_serializing_if = "Option::is_none")] pub tool_name: Option, + /// ID of the skill owning the tool. #[serde(skip_serializing_if = "Option::is_none")] pub skill_id: Option, + /// Arguments passed to the tool. #[serde(skip_serializing_if = "Option::is_none")] pub args: Option, + /// The raw output from the tool execution. #[serde(skip_serializing_if = "Option::is_none")] pub output: Option, + /// Whether the tool execution or request was successful. #[serde(skip_serializing_if = "Option::is_none")] pub success: Option, + /// The current iteration/round number in a tool-call loop. #[serde(skip_serializing_if = "Option::is_none")] pub round: Option, /// Emoji reaction the assistant wants to add to the user's message. @@ -65,6 +82,13 @@ struct ChatCancelPayload { thread_id: String, } +/// Attaches the Socket.IO layer to the Axum router and sets up event handlers. +/// +/// It configures: +/// - Client connection and room joining. 
+/// - `rpc:request`: Invoking JSON-RPC methods over WebSocket. +/// - `chat:start`: Initiating a new chat turn. +/// - `chat:cancel`: Aborting an active chat turn. pub fn attach_socketio() -> (socketioxide::layer::SocketIoLayer, SocketIo) { let (layer, io) = SocketIo::new_layer(); @@ -76,127 +100,128 @@ pub fn attach_socketio() -> (socketioxide::layer::SocketIoLayer, SocketIo) { io.ns("/", |socket: SocketRef| { let client_id = socket.id.to_string(); log::info!("[socketio] client connected id={client_id}"); + // Join a room named after the client ID for targeted event delivery. let _ = socket.join(client_id.clone()); let ready_payload = json!({ "sid": client_id }); log::debug!("[socketio] emit event=ready to_client={}", socket.id); let _ = socket.emit("ready", &ready_payload); - socket.on("rpc:request", |socket: SocketRef, Data(payload): Data| async move { - let client_id = socket.id.to_string(); - log::info!( - "[socketio] rpc:request method={} id={} client={}", - payload.method, - payload.id, - client_id - ); - log::debug!( - "[socketio] rpc:request params_type={} params_bytes={}", - json_type_name(&payload.params), - payload.params.to_string().len() - ); + // Handler for JSON-RPC over WebSocket. 
+ socket.on( + "rpc:request", + |socket: SocketRef, Data(payload): Data| async move { + let client_id = socket.id.to_string(); + log::info!( + "[socketio] rpc:request method={} id={} client={}", + payload.method, + payload.id, + client_id + ); - let response = match crate::core::jsonrpc::invoke_method( - crate::core::jsonrpc::default_state(), - payload.method.as_str(), - payload.params, - ) - .await - { - Ok(result) => { - log::debug!( - "[socketio] send event=rpc:response client_id={} id={} result_type={} result_bytes={}", - client_id, - payload.id, - json_type_name(&result), - result.to_string().len() - ); - ("rpc:response", json!({ "id": payload.id, "result": result })) - } - Err(message) => { - log::debug!( - "[socketio] send event=rpc:error client_id={} id={} message={}", - client_id, - payload.id, - message - ); - ( + // Invoke the method through the same logic used by the HTTP RPC endpoint. + let response = match crate::core::jsonrpc::invoke_method( + crate::core::jsonrpc::default_state(), + payload.method.as_str(), + payload.params, + ) + .await + { + Ok(result) => ( + "rpc:response", + json!({ "id": payload.id, "result": result }), + ), + Err(message) => ( "rpc:error", json!({ "id": payload.id, "error": { "code": -32000, "message": message } }), - ) - } - }; + ), + }; - let _ = socket.emit(response.0, &response.1); - }); + let _ = socket.emit(response.0, &response.1); + }, + ); - socket.on("chat:start", |socket: SocketRef, Data(payload): Data| async move { - let client_id = socket.id.to_string(); - let thread_id = payload.thread_id.clone(); - let model_override = payload.model_override.or(payload.model); - log::debug!( - "[socketio] recv event=chat:start client_id={} thread_id={} message_bytes={} model_override={:?} temperature={:?}", - client_id, - thread_id, - payload.message.len(), - model_override, - payload.temperature - ); + // Handler for starting a chat turn. 
+ socket.on( + "chat:start", + |socket: SocketRef, Data(payload): Data| async move { + let client_id = socket.id.to_string(); + let thread_id = payload.thread_id.clone(); + let model_override = payload.model_override.or(payload.model); + log::debug!( + "[socketio] recv event=chat:start client_id={} thread_id={} message_bytes={}", + client_id, + thread_id, + payload.message.len() + ); - match crate::openhuman::channels::providers::web::start_chat( - &client_id, - &payload.thread_id, - &payload.message, - model_override, - payload.temperature, - ) - .await - { - Ok(request_id) => { - let accepted_payload = json!({ - "event": "chat_accepted", - "client_id": client_id, - "thread_id": thread_id, - "request_id": request_id, - }); - log::debug!("[socketio] send event=chat_accepted client_id={} thread_id={}", socket.id, payload.thread_id); - emit_with_aliases(&socket, "chat_accepted", &accepted_payload); - } - Err(error) => { - let error_payload = json!({ - "event": "chat_error", - "client_id": client_id, - "thread_id": thread_id, - "request_id": "", - "message": error, - "error_type": "inference", - }); - log::debug!("[socketio] send event=chat_error client_id={} thread_id={} message={}", socket.id, payload.thread_id, error); - emit_with_aliases(&socket, "chat_error", &error_payload); + // Trigger the web channel's chat logic. 
+ match crate::openhuman::channels::providers::web::start_chat( + &client_id, + &payload.thread_id, + &payload.message, + model_override, + payload.temperature, + ) + .await + { + Ok(request_id) => { + let accepted_payload = json!({ + "event": "chat_accepted", + "client_id": client_id, + "thread_id": thread_id, + "request_id": request_id, + }); + emit_with_aliases(&socket, "chat_accepted", &accepted_payload); + } + Err(error) => { + let error_payload = json!({ + "event": "chat_error", + "client_id": client_id, + "thread_id": thread_id, + "request_id": "", + "message": error, + "error_type": "inference", + }); + emit_with_aliases(&socket, "chat_error", &error_payload); + } } - } - }); + }, + ); - socket.on("chat:cancel", |socket: SocketRef, Data(payload): Data| async move { - let client_id = socket.id.to_string(); - log::debug!( - "[socketio] recv event=chat:cancel client_id={} thread_id={}", - client_id, - payload.thread_id - ); - let _ = - crate::openhuman::channels::providers::web::cancel_chat(&client_id, &payload.thread_id) - .await; - }); + // Handler for cancelling an active chat turn. + socket.on( + "chat:cancel", + |socket: SocketRef, Data(payload): Data| async move { + let client_id = socket.id.to_string(); + log::debug!( + "[socketio] recv event=chat:cancel client_id={} thread_id={}", + client_id, + payload.thread_id + ); + let _ = crate::openhuman::channels::providers::web::cancel_chat( + &client_id, + &payload.thread_id, + ) + .await; + }, + ); }); (layer, io) } +/// Spawns background bridges to forward various system events to Socket.IO clients. +/// +/// This function sets up four bridges: +/// 1. **Web Channel Bridge**: Forwards chat-related events (messages, tool calls) to specific clients. +/// 2. **Dictation Bridge**: Forwards hotkey events to all clients. +/// 3. **Overlay Bridge**: Forwards attention bubble events to all clients. +/// 4. **Transcription Bridge**: Forwards real-time speech-to-text results to all clients. 
pub fn spawn_web_channel_bridge(io: SocketIo) { - // Web channel events → per-client rooms. + // 1. Web channel events → per-client rooms. let io_web = io.clone(); tokio::spawn(async move { let mut rx = crate::openhuman::channels::providers::web::subscribe_web_channel_events(); @@ -218,12 +243,10 @@ pub fn spawn_web_channel_bridge(io: SocketIo) { log::debug!("[socketio] web_channel bridge stopped"); }); - // Clone for the transcription and overlay bridges spawned below; the - // dictation task takes ownership of `io` itself. let io_transcription = io.clone(); let io_overlay = io.clone(); - // Dictation hotkey events → broadcast to all connected clients. + // 2. Dictation hotkey events → broadcast to all connected clients. tokio::spawn(async move { let mut rx = crate::openhuman::voice::dictation_listener::subscribe_dictation_events(); loop { @@ -241,6 +264,7 @@ pub fn spawn_web_channel_bridge(io: SocketIo) { "[socketio] broadcast dictation:{} to all clients", event.event_type ); + // Support both colon and underscore versions for compatibility with different frontends. let _ = io.emit("dictation:toggle", &payload); let _ = io.emit("dictation_toggle", &payload); } @@ -248,13 +272,7 @@ pub fn spawn_web_channel_bridge(io: SocketIo) { log::debug!("[socketio] dictation bridge stopped"); }); - // Overlay attention events → broadcast to the overlay window. - // - // Any core-side caller (subconscious loop, heartbeat, screen intelligence, …) - // can publish an `OverlayAttentionEvent` via - // `openhuman::overlay::publish_attention(...)` and it will be forwarded - // to all Socket.IO clients here. The overlay window listens on a dedicated - // unauthenticated socket (see `OverlayApp.tsx`) and renders the bubble. + // 3. Overlay attention events → broadcast to all clients. 
tokio::spawn(async move { let mut rx = crate::openhuman::overlay::subscribe_attention_events(); loop { @@ -272,9 +290,8 @@ pub fn spawn_web_channel_bridge(io: SocketIo) { if let Ok(payload) = serde_json::to_value(&event) { log::debug!( - "[socketio] broadcast overlay:attention source={:?} message_bytes={}", - event.source, - event.message.len() + "[socketio] broadcast overlay:attention source={:?}", + event.source ); let _ = io_overlay.emit("overlay:attention", &payload); let _ = io_overlay.emit("overlay_attention", &payload); @@ -283,7 +300,7 @@ pub fn spawn_web_channel_bridge(io: SocketIo) { log::debug!("[socketio] overlay attention bridge stopped"); }); - // Transcription results → broadcast to all connected clients. + // 4. Transcription results → broadcast to all connected clients. tokio::spawn(async move { let mut rx = crate::openhuman::voice::dictation_listener::subscribe_transcription_results(); loop { @@ -355,17 +372,6 @@ fn emit_room_with_aliases(io: &SocketIo, room: &str, name: &str, payload: &serde } } -fn json_type_name(value: &serde_json::Value) -> &'static str { - match value { - serde_json::Value::Null => "null", - serde_json::Value::Bool(_) => "bool", - serde_json::Value::Number(_) => "number", - serde_json::Value::String(_) => "string", - serde_json::Value::Array(_) => "array", - serde_json::Value::Object(_) => "object", - } -} - #[cfg(test)] mod tests { use super::event_alias; diff --git a/src/core/types.rs b/src/core/types.rs index 4e49dbcc7..39132faab 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -7,21 +7,27 @@ use serde::{Deserialize, Serialize}; use serde_json::json; /// Standard response structure for commands that include execution logs. +/// +/// This is commonly used in internal APIs and CLI outputs where it's +/// important to see the side-effects or diagnostic information alongside +/// the primary result. 
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct CommandResponse { /// The primary data returned by the command. pub result: T, /// A list of log messages generated during command execution. + /// These can include warnings, info, or trace messages. pub logs: Vec, } /// Success payload from a core RPC handler before JSON-RPC wrapping. /// /// This internal type allows handlers to return a generic JSON value along -/// with optional logs. +/// with optional logs. It is transformed into a [`RpcSuccess`] or a +/// combined object by [`invocation_to_rpc_json`]. #[derive(Debug, Clone)] pub struct InvocationResult { - /// The value returned by the RPC function call. + /// The value returned by the RPC function call, serialized to JSON. pub value: serde_json::Value, /// A list of execution logs. pub logs: Vec, @@ -29,6 +35,8 @@ pub struct InvocationResult { impl InvocationResult { /// Creates a success result from any serializable value with no logs. + /// + /// This is the most common way to return a value from a controller. pub fn ok(v: T) -> Result { Ok(Self { value: serde_json::to_value(v).map_err(|e| e.to_string())?, @@ -37,6 +45,8 @@ impl InvocationResult { } /// Creates a success result from a serializable value with accompanying logs. + /// + /// Use this when the domain logic has meaningful logs to surface to the caller. pub fn with_logs(v: T, logs: Vec) -> Result { Ok(Self { value: serde_json::to_value(v).map_err(|e| e.to_string())?, @@ -49,6 +59,11 @@ impl InvocationResult { /// /// If there are no logs, returns the value directly. Otherwise, returns an /// object containing both `result` and `logs` keys. 
+/// +/// # Logic +/// +/// - `logs.is_empty()` -> `inv.value` +/// - `!logs.is_empty()` -> `{ "result": inv.value, "logs": inv.logs }` pub fn invocation_to_rpc_json(inv: InvocationResult) -> serde_json::Value { if inv.logs.is_empty() { inv.value @@ -57,25 +72,29 @@ pub fn invocation_to_rpc_json(inv: InvocationResult) -> serde_json::Value { } } -/// Standard JSON-RPC request format. +/// Standard JSON-RPC 2.0 request format. +/// +/// As defined in the [JSON-RPC 2.0 Specification](https://www.jsonrpc.org/specification). #[derive(Debug, Deserialize)] pub struct RpcRequest { - /// The JSON-RPC version (e.g., `2.0`). + /// The JSON-RPC version. MUST be exactly "2.0". #[allow(dead_code)] pub jsonrpc: String, - /// Unique identifier for the request, to be mirrored in the response. + /// Unique identifier for the request. MUST be a String, Number, or Null. + /// The server will return this same ID in the response. pub id: serde_json::Value, - /// The name of the method to be invoked. + /// The name of the method to be invoked (e.g., `openhuman.memory_doc_put`). pub method: String, - /// Parameters for the method call. Defaults to null if not provided. + /// Parameters for the method call. MUST be a structured value (Object or Array). + /// Defaults to null if not provided. #[serde(default)] pub params: serde_json::Value, } -/// Standard JSON-RPC success response format. +/// Standard JSON-RPC 2.0 success response format. #[derive(Debug, Serialize)] pub struct RpcSuccess { - /// The JSON-RPC version (always `2.0`). + /// The JSON-RPC version. ALWAYS "2.0". pub jsonrpc: &'static str, /// The identifier mirrored from the original request. pub id: serde_json::Value, @@ -83,10 +102,10 @@ pub struct RpcSuccess { pub result: serde_json::Value, } -/// Standard JSON-RPC error response format. +/// Standard JSON-RPC 2.0 error response format. #[derive(Debug, Serialize)] pub struct RpcFailure { - /// The JSON-RPC version (always `2.0`). + /// The JSON-RPC version. ALWAYS "2.0". 
pub jsonrpc: &'static str, /// The identifier mirrored from the original request. pub id: serde_json::Value, @@ -95,19 +114,29 @@ pub struct RpcFailure { } /// Detail about an RPC invocation error. +/// +/// Contains a code, a message, and optional extra data for debugging. #[derive(Debug, Serialize)] pub struct RpcError { - /// Standardized error code (e.g., -32601 for Method not found). + /// Standardized error code. + /// - -32700: Parse error + /// - -32600: Invalid Request + /// - -32601: Method not found + /// - -32602: Invalid params + /// - -32603: Internal error + /// - -32000 to -32099: Reserved for implementation-defined server-errors. pub code: i64, /// A short, human-readable error message. pub message: String, - /// Optional additional diagnostic data. + /// Optional additional diagnostic data, which can be any JSON value. pub data: Option, } /// Global core-level application state. +/// +/// Currently holds shared metadata like the core version. #[derive(Clone)] pub struct AppState { - /// The current version of the OpenHuman core binary. + /// The current version of the OpenHuman core binary, usually from `CARGO_PKG_VERSION`. pub core_version: String, } diff --git a/src/openhuman/accessibility/focus.rs b/src/openhuman/accessibility/focus.rs index 4603e0713..ffd7b71ac 100644 --- a/src/openhuman/accessibility/focus.rs +++ b/src/openhuman/accessibility/focus.rs @@ -3,7 +3,9 @@ //! Primary path: unified Swift helper (native AX API, fast, persistent process). //! Fallback: osascript subprocess (slower, but works without compiled helper). 
+#[cfg(target_os = "macos")] use super::terminal::{is_terminal_app, is_text_role}; +#[cfg(target_os = "macos")] use super::text_util::{normalize_ax_value, parse_ax_number}; use super::types::{AppContext, ElementBounds, FocusedTextContext}; diff --git a/src/openhuman/accessibility/globe.rs b/src/openhuman/accessibility/globe.rs index 8fcd41574..e5a440e39 100644 --- a/src/openhuman/accessibility/globe.rs +++ b/src/openhuman/accessibility/globe.rs @@ -4,6 +4,7 @@ //! events globally and reports `FN_DOWN` / `FN_UP` lines over stdout. use super::{detect_permissions, PermissionState}; +#[cfg(target_os = "macos")] use std::collections::VecDeque; #[cfg(target_os = "macos")] diff --git a/src/openhuman/accessibility/overlay.rs b/src/openhuman/accessibility/overlay.rs index 89e1f5786..c4fef3a55 100644 --- a/src/openhuman/accessibility/overlay.rs +++ b/src/openhuman/accessibility/overlay.rs @@ -1,5 +1,6 @@ //! Overlay display via the unified Swift helper process. +#[cfg(target_os = "macos")] use super::text_util::truncate_tail; use super::types::ElementBounds; diff --git a/src/openhuman/accessibility/paste.rs b/src/openhuman/accessibility/paste.rs index 528cd4a36..6b163d0f9 100644 --- a/src/openhuman/accessibility/paste.rs +++ b/src/openhuman/accessibility/paste.rs @@ -2,6 +2,7 @@ //! //! Three-tier strategy: (1) Swift helper paste, (2) osascript clipboard + CGEvent, (3) AXValue write. +#[cfg(target_os = "macos")] use super::text_util::truncate_tail; /// Apply suggestion text to the focused field. 
diff --git a/src/openhuman/accessibility/permissions.rs b/src/openhuman/accessibility/permissions.rs index 789f9ee98..864a0fcf7 100644 --- a/src/openhuman/accessibility/permissions.rs +++ b/src/openhuman/accessibility/permissions.rs @@ -73,8 +73,8 @@ pub fn open_macos_privacy_pane(pane: &str) { #[cfg(target_os = "macos")] pub fn request_accessibility_access() { unsafe { - let keys = [kAXTrustedCheckOptionPrompt as *const c_void]; - let values = [kCFBooleanTrue as *const c_void]; + let keys = [kAXTrustedCheckOptionPrompt]; + let values = [kCFBooleanTrue]; let options = CFDictionaryCreate( kCFAllocatorDefault, keys.as_ptr(), diff --git a/src/openhuman/agent/agents/mod.rs b/src/openhuman/agent/agents/mod.rs index 768dcf818..599213a87 100644 --- a/src/openhuman/agent/agents/mod.rs +++ b/src/openhuman/agent/agents/mod.rs @@ -4,8 +4,8 @@ //! two files: //! //! * `agent.toml` — id, when_to_use, model, tool allowlist, sandbox, -//! iteration cap, and the `omit_*` flags. Parsed -//! directly into [`AgentDefinition`] via serde. +//! iteration cap, and the `omit_*` flags. Parsed +//! directly into [`AgentDefinition`] via serde. //! * `prompt.md` — the sub-agent's system prompt body. //! //! Adding a new built-in agent = creating a new subfolder with those two diff --git a/src/openhuman/agent/bus.rs b/src/openhuman/agent/bus.rs index d0a87bbdc..4c0fa6bf1 100644 --- a/src/openhuman/agent/bus.rs +++ b/src/openhuman/agent/bus.rs @@ -27,6 +27,12 @@ use super::harness::run_tool_call_loop; /// Method name used to dispatch an agentic turn through the native bus. pub const AGENT_RUN_TURN_METHOD: &str = "agent.run_turn"; +/// Full owned payload for a single agentic turn executed through the bus. +/// +/// All fields are either owned values, [`Arc`]s, or channel handles — the +/// bus carries them by value without touching serialization. Consumers can +/// therefore pass trait objects (`Arc`, tool trait-object +/// registries) and streaming senders (`on_delta`) through unchanged. 
/// Full owned payload for a single agentic turn executed through the bus. /// /// All fields are either owned values, [`Arc`]s, or channel handles — the @@ -35,30 +41,45 @@ pub const AGENT_RUN_TURN_METHOD: &str = "agent.run_turn"; /// registries) and streaming senders (`on_delta`) through unchanged. pub struct AgentTurnRequest { /// LLM provider, already constructed and warmed up by the caller. + /// Shared via Arc to allow sub-agents to reuse the same connection pool. pub provider: Arc, + /// Full conversation history including system prompt and the incoming /// user message. The handler mutates an internal clone of this during /// the tool-call loop; callers should rebuild their per-session cache /// from their own records, not from this vector. pub history: Vec, + /// Registered tool implementations available to this turn. + /// These are provided as trait objects to avoid tight coupling with tool implementations. pub tools_registry: Arc>>, - /// Provider name token (e.g. `"openai"`) — routed to the loop as-is. + + /// Provider name token (e.g. `"openai"`) — routed to the loop as-is for logging and tracking. pub provider_name: String, + /// Model identifier (e.g. `"gpt-4"`) — routed to the loop as-is. pub model: String, - /// Sampling temperature. + + /// Sampling temperature. Higher values (e.g., 0.7) are more creative, + /// lower (e.g., 0.0) are more deterministic. pub temperature: f64, + /// When `true`, suppresses stdout during the tool loop (always set by - /// channel callers). + /// channel callers to prevent cluttering the main console). pub silent: bool, + /// Channel name this turn belongs to (e.g. `"telegram"`, `"cli"`). + /// Used for context and telemetry. pub channel_name: String, + /// Multimodal feature configuration (image inlining rules, payload /// size caps). pub multimodal: MultimodalConfig, + /// Maximum number of LLM↔tool round-trips before bailing out. + /// Prevents infinite loops if a model gets "stuck" calling the same tool. 
pub max_tool_iterations: usize, + /// Optional streaming sender — the loop forwards partial LLM text /// chunks here so channel providers can update "draft" messages in /// real time. `None` disables streaming for this turn. @@ -67,15 +88,16 @@ pub struct AgentTurnRequest { /// Final response from an agentic turn. pub struct AgentTurnResponse { - /// Final assistant text after all tool calls resolved. + /// Final assistant text after all tool calls resolved and the loop terminated. pub text: String, } /// Register the agent domain's native request handlers on the global /// registry. Safe to call multiple times — the last registration wins. /// -/// Called from the canonical bus wiring in -/// `src/core/jsonrpc.rs::register_domain_subscribers`. +/// This function wires the `agent.run_turn` method into the core event bus, +/// allowing any part of the system to request an agentic turn without +/// depending directly on the agent harness. pub fn register_agent_handlers() { register_native_global::( AGENT_RUN_TURN_METHOD, diff --git a/src/openhuman/agent/dispatcher.rs b/src/openhuman/agent/dispatcher.rs index 989c3bd73..a6e08f229 100644 --- a/src/openhuman/agent/dispatcher.rs +++ b/src/openhuman/agent/dispatcher.rs @@ -9,47 +9,68 @@ use serde_json::Value; use std::fmt::Write; use std::sync::Arc; +/// A parsed tool call representation after being extracted from an LLM response. #[derive(Debug, Clone)] pub struct ParsedToolCall { + /// The name of the tool to be invoked. pub name: String, + /// The arguments passed to the tool, as a JSON object. pub arguments: Value, + /// An optional unique identifier for the tool call, provided by native APIs. pub tool_call_id: Option, } +/// The result of executing a tool call, formatted for the LLM. #[derive(Debug, Clone)] pub struct ToolExecutionResult { + /// The name of the tool that was executed. pub name: String, + /// The output of the tool execution as a string. 
pub output: String, + /// Whether the tool execution was successful. pub success: bool, + /// The tool call ID that generated this result. pub tool_call_id: Option, } +/// Trait defining how an agent interacts with an LLM for tool use. +/// +/// Different LLMs have different "dialects" for calling tools. The dispatcher +/// abstracts these differences, allowing the agent loop to remain agnostic of +/// the specific formatting required by the provider. pub trait ToolDispatcher: Send + Sync { + /// Parse the LLM response to extract narrative text and any tool calls. fn parse_response(&self, response: &ChatResponse) -> (String, Vec); + + /// Format tool execution results into a message suitable for the next LLM turn. fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage; + + /// Provide instructions for the system prompt on how the model should call tools. fn prompt_instructions(&self, tools: &[Box]) -> String; + + /// Convert internal conversation history into provider-specific messages. fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec; + + /// Whether the dispatcher requires tool specifications to be sent in the API request. fn should_send_tool_specs(&self) -> bool; /// Tell the prompt builder how to render each tool entry in the /// `## Tools` section. Defaults to [`ToolCallFormat::Json`] for - /// dispatchers that haven't opted in — `ToolsSection` then uses - /// the historic schema-dump rendering. - /// - /// `PFormatToolDispatcher` overrides this to return - /// [`ToolCallFormat::PFormat`] so the catalogue shows positional - /// signatures (`get_weather[location|unit]`) instead of full JSON - /// schemas — that's where most of the token saving comes from at - /// the prompt level. + /// dispatchers that haven't opted in. fn tool_call_format(&self) -> ToolCallFormat { ToolCallFormat::Json } } +/// Legacy dispatcher using XML-style tags (``) with JSON bodies. 
+/// +/// This is robust and works well with models that aren't natively trained for +/// tool calling but can follow instructions in a system prompt. #[derive(Default)] pub struct XmlToolDispatcher; impl XmlToolDispatcher { + /// Internal helper to extract tool calls from a raw text string. fn parse_tool_calls_from_text(response: &str) -> (String, Vec) { let (text, calls) = parse_tool_calls(response); let parsed_calls = calls @@ -63,6 +84,7 @@ impl XmlToolDispatcher { (text, parsed_calls) } + /// Extract serializable specs for all tools in the registry. pub fn tool_specs(tools: &[Box]) -> Vec { tools.iter().map(|tool| tool.spec()).collect() } @@ -145,42 +167,31 @@ impl ToolDispatcher for XmlToolDispatcher { } /// Text-based dispatcher that emits and parses **P-Format** ("Parameter -/// Format") tool calls — the compact `tool_name[arg1|arg2|...]` syntax -/// defined in [`crate::openhuman::agent::pformat`]. +/// Format") tool calls — the compact `tool_name[arg1|arg2|...]` syntax. /// -/// This is the default dispatcher for providers that do not support -/// native structured tool calls. Compared to the legacy -/// [`XmlToolDispatcher`] (XML wrapper + JSON body), p-format cuts the -/// per-call token cost by ~80% — a single weather lookup goes from -/// ~25 tokens to ~5 — which compounds dramatically over a long agent -/// loop. -/// -/// The dispatcher caches a [`PFormatRegistry`] (a `name → params` -/// lookup) at construction time so it never has to hold a reference to -/// the live `Vec>` (which the [`Agent`] owns). The -/// caller is expected to build the registry from the same tool slice -/// they pass into the agent — see `pformat::build_registry`. +/// P-format is designed to significantly reduce token usage compared to JSON. +/// It uses positional arguments based on an alphabetical sort of the tool's +/// parameters. 
/// /// On the parse side the dispatcher tries p-format **first** and falls /// back to the existing JSON-in-tag parser if the body doesn't match /// the bracket pattern. This keeps the dispatcher backwards-compatible -/// with models that still emit JSON tool calls — they just pay the -/// usual token cost for their bytes. +/// with models that still emit JSON tool calls. pub struct PFormatToolDispatcher { + /// Registry of tool parameter layouts used to reconstruct named arguments from positional ones. registry: Arc, } impl PFormatToolDispatcher { + /// Create a new P-Format dispatcher with the given tool registry. pub fn new(registry: PFormatRegistry) -> Self { Self { registry: Arc::new(registry), } } - /// Convert the registry-driven parser output into the dispatcher's - /// `ParsedToolCall` shape. Always called inside a `` tag - /// body — the tag-finding logic comes from the shared - /// [`parse_tool_calls`] helper. + /// Convert the registry-driven positional parser output into the dispatcher's + /// `ParsedToolCall` shape. Always called inside a `` tag. fn try_parse_pformat_body(&self, body: &str) -> Option { let (name, args) = pformat::parse_call(body, self.registry.as_ref())?; Some(ParsedToolCall { @@ -359,6 +370,12 @@ impl ToolDispatcher for PFormatToolDispatcher { } } +/// Dispatcher for models with native, structured tool-calling support (e.g., OpenAI, Anthropic). +/// +/// This dispatcher leverages the provider's built-in APIs for identifying and +/// reporting tool calls, which is generally more reliable than text-based parsing. +/// It still supports a text-based fallback for robustness against models that +/// might "forget" to use the structured API. 
pub struct NativeToolDispatcher; impl ToolDispatcher for NativeToolDispatcher { diff --git a/src/openhuman/agent/error.rs b/src/openhuman/agent/error.rs index 47e81222a..d6d769c5c 100644 --- a/src/openhuman/agent/error.rs +++ b/src/openhuman/agent/error.rs @@ -9,31 +9,43 @@ use std::fmt; /// Structured error type for agent loop operations. #[derive(Debug)] pub enum AgentError { - /// The LLM provider returned an error. + /// The LLM provider returned an error (e.g., API key invalid, network failure). + /// `retryable` indicates if the operation should be attempted again. ProviderError { message: String, retryable: bool }, - /// Context window is exhausted and compaction cannot help. + + /// Context window is exhausted and compaction/summarization cannot help. + /// The agent cannot proceed without dropping significant history. ContextLimitExceeded { utilization_pct: u8 }, - /// A tool execution failed. + + /// A tool execution failed during its `execute()` method. ToolExecutionError { tool_name: String, message: String }, - /// The daily cost budget has been exceeded. + + /// The daily cost budget for this user/agent has been exceeded. + /// Prevents unexpected runaway costs. CostBudgetExceeded { spent_microdollars: u64, budget_microdollars: u64, }, - /// The agent exceeded its maximum tool iterations. + + /// The agent exceeded its maximum allowed tool iterations for a single turn. + /// Typically indicates an infinite loop in the model's reasoning. MaxIterationsExceeded { max: usize }, - /// History compaction failed. + + /// Automated history compaction (summarization) failed. CompactionFailed { message: String, consecutive_failures: u8, }, - /// Channel permission denied for a tool operation. + + /// The current channel (e.g., Telegram) does not have permission to execute + /// the requested tool (e.g., shell access). 
PermissionDenied { tool_name: String, required_level: String, channel_max_level: String, }, - /// Generic/untyped error (escape hatch for migration). + + /// Generic/untyped error (escape hatch for migration or external dependencies). Other(anyhow::Error), } @@ -122,6 +134,7 @@ pub fn is_context_limit_error(error_msg: &str) -> bool { #[cfg(test)] mod tests { use super::*; + use std::error::Error; #[test] fn display_formatting() { @@ -155,4 +168,44 @@ mod tests { assert!(err.to_string().contains("shell")); assert!(err.to_string().contains("Execute")); } + + #[test] + fn display_formats_other_variants() { + assert!(AgentError::ProviderError { + message: "boom".into(), + retryable: true, + } + .to_string() + .contains("retryable=true")); + assert!(AgentError::ContextLimitExceeded { + utilization_pct: 98 + } + .to_string() + .contains("98% utilized")); + assert!(AgentError::ToolExecutionError { + tool_name: "shell".into(), + message: "denied".into(), + } + .to_string() + .contains("Tool execution error [shell]")); + assert!(AgentError::CompactionFailed { + message: "summary failed".into(), + consecutive_failures: 3, + } + .to_string() + .contains("3 consecutive")); + } + + #[test] + fn from_anyhow_recovers_typed_agent_error_and_other_source() { + let typed = anyhow::anyhow!(AgentError::MaxIterationsExceeded { max: 4 }); + match AgentError::from(typed) { + AgentError::MaxIterationsExceeded { max } => assert_eq!(max, 4), + other => panic!("unexpected variant: {other}"), + } + + let other = AgentError::from(anyhow::anyhow!("plain failure")); + assert!(matches!(other, AgentError::Other(_))); + assert!(other.source().is_some()); + } } diff --git a/src/openhuman/agent/harness/definition.rs b/src/openhuman/agent/harness/definition.rs index 9a7bddb40..403b1a5a3 100644 --- a/src/openhuman/agent/harness/definition.rs +++ b/src/openhuman/agent/harness/definition.rs @@ -29,127 +29,94 @@ use std::path::PathBuf; // Agent definition // 
───────────────────────────────────────────────────────────────────────────── -/// A fully specified sub-agent: what it knows, what it can do, how to prompt it. +/// A fully specified sub-agent archetype: what it knows, what it can do, and how to prompt it. /// -/// Built-ins live in [`super::builtin_definitions`]; custom ones load from -/// TOML at startup. The [`AgentDefinitionRegistry`] merges them and is the -/// single source of truth that `SpawnSubagentTool` queries. -/// -/// All `omit_*` flags default to `true` for sub-agents — sub-agents are -/// narrow specialists and pay no token tax for the parent's identity, -/// memory, safety, or skills sections. Override per-archetype if a -/// section is needed. +/// Definitions are used by the `spawn_subagent` tool to initialize a new +/// specialized agent. They can be built-in or loaded from custom TOML files. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AgentDefinition { // ── identity ──────────────────────────────────────────────────────── - /// Unique id, referenced from `spawn_subagent { agent_id: "…" }`. - /// Convention: snake_case (e.g. `code_executor`, `notion_specialist`). + /// Unique identifier for this archetype (e.g., `researcher`, `code_executor`). pub id: String, - /// One-line description shown in the orchestrator's `spawn_subagent` - /// tool schema so the parent model knows when to delegate to this agent. + /// Human-readable description explaining when this agent should be used. + /// Shown to the parent model to help it decide whether to delegate. pub when_to_use: String, - /// Optional display name for UI/logs. Falls back to `id`. + /// Optional display name for UI and log output. #[serde(default)] pub display_name: Option, // ── prompt ────────────────────────────────────────────────────────── - /// Source of the sub-agent's core system prompt. 
- /// - /// Defaults to an empty inline prompt so TOMLs that ship a sibling - /// `prompt.md` can omit this field and let the loader inject the - /// rendered body as [`PromptSource::Inline`]. All in-tree built-ins - /// use that pattern — see [`crate::openhuman::agent::agents`] for - /// the loader. Custom TOML-defined agents may also set this - /// explicitly as [`PromptSource::Inline`] or [`PromptSource::File`]. + /// The core system prompt body for this specialized agent. #[serde(default = "defaults::empty_inline_prompt")] pub system_prompt: PromptSource, - /// Sections of the main agent's prompt to strip when this sub-agent runs. - /// Defaults to `true` (strip) — sub-agents are narrow and don't need the - /// parent's identity scaffolding. + /// If `true`, the parent's identity section is stripped from the prompt. #[serde(default = "defaults::true_")] pub omit_identity: bool, + + /// If `true`, the parent's memory context is stripped. #[serde(default = "defaults::true_")] pub omit_memory_context: bool, + + /// If `true`, the standard safety preamble is stripped. #[serde(default = "defaults::true_")] pub omit_safety_preamble: bool, + + /// If `true`, the global skills catalog is stripped. #[serde(default = "defaults::true_")] pub omit_skills_catalog: bool, // ── model ─────────────────────────────────────────────────────────── - /// Model selection: inherit parent, hint to router, or pinned name. + /// Strategy for picking which model to use for this sub-agent. #[serde(default)] pub model: ModelSpec, - /// Sampling temperature. Sub-agents default to `0.4` for precision. + /// Sampling temperature for the model. #[serde(default = "defaults::subagent_temperature")] pub temperature: f64, // ── tools ─────────────────────────────────────────────────────────── - /// Either [`ToolScope::Wildcard`] (all tools the parent has) or - /// [`ToolScope::Named`] (an explicit allowlist). + /// Which tools from the parent's registry should be available to the sub-agent. 
#[serde(default)] pub tools: ToolScope, - /// Tools that are explicitly banned even if `tools == Wildcard`. - /// Built-ins default-deny dangerous ops for read-only archetypes. + /// Explicit list of tool names to block, even if they match the scope. #[serde(default)] pub disallowed_tools: Vec, - /// If set, the resolved tool list is further filtered to only those whose - /// name starts with `{skill_filter}__`. Gives us per-API specialists - /// (Notion, Gmail, …) without enum variants. Overridable per-spawn. + /// Filter to only tools belonging to a specific skill (e.g., `notion`). #[serde(default)] pub skill_filter: Option, - /// If set, the resolved tool list is restricted to tools whose - /// [`crate::openhuman::tools::Tool::category`] matches this value. - /// This is the *primary* mechanism the orchestrator uses to spawn - /// dedicated tool-execution sub-agents: - /// - `Some(Skill)` → sub-agent only sees skill-bridge tools - /// (Notion, Gmail, Telegram, …). Pair with `ModelSpec::Hint("agentic")` - /// to route to the backend's agentic model. - /// - `Some(System)` → sub-agent only sees built-in Rust tools. - /// - `None` (default) → no category restriction; `tools` / - /// `disallowed_tools` / `skill_filter` still apply. - /// - /// Category filtering happens *before* the `tools`/`disallowed_tools` - /// scope check, so a `Named` scope is a stricter-intersection override. + /// Filter to only tools belonging to a specific high-level category. #[serde(default)] pub category_filter: Option, // ── runtime limits ────────────────────────────────────────────────── - /// Maximum tool-call iterations per spawn. Sub-agents default to a - /// shorter cap than the parent to keep cost bounded. + /// Maximum number of tool iterations for this sub-agent's task. #[serde(default = "defaults::max_iterations")] pub max_iterations: usize, - /// Hard wall-clock timeout per turn. `None` falls back to - /// `tool_execution_timeout_secs`. 
+ /// Wall-clock timeout for the sub-agent's execution (seconds). #[serde(default)] pub timeout_secs: Option, - /// `none` / `read_only` / `sandboxed`. See [`SandboxMode`]. + /// Sandbox level for tool execution. #[serde(default)] pub sandbox_mode: SandboxMode, - /// If true, spawn runs in the background and the call returns - /// immediately with a placeholder. Reserved — not yet wired in v1. + /// Reserved for background (asynchronous) execution support. #[serde(default)] pub background: bool, - /// Marker: when true, the runner skips its normal prompt-building path - /// and uses the parent's pre-rendered prompt + tool schemas + message - /// prefix from the [`super::fork_context::ForkContext`] task-local. - /// Only the synthetic built-in `fork` definition has this set. + /// Internal flag for `fork` mode sub-agents. #[serde(default, skip_serializing_if = "is_false")] pub uses_fork_context: bool, // ── source bookkeeping ────────────────────────────────────────────── - /// Where this definition came from. Filled in by the loader/builder; - /// not deserialised from TOML. + /// Tracks where the definition was loaded from (Builtin vs. File). 
#[serde(skip)] pub source: DefinitionSource, } diff --git a/src/openhuman/agent/harness/fork_context.rs b/src/openhuman/agent/harness/fork_context.rs index 484a95812..5668398c3 100644 --- a/src/openhuman/agent/harness/fork_context.rs +++ b/src/openhuman/agent/harness/fork_context.rs @@ -170,158 +170,3 @@ where { FORK_CONTEXT.scope(ctx, future).await } - -#[cfg(test)] -mod tests { - use super::*; - use crate::openhuman::memory::{MemoryCategory, MemoryEntry}; - use crate::openhuman::providers::{ChatRequest, ChatResponse}; - use async_trait::async_trait; - - #[tokio::test] - async fn parent_context_returns_none_outside_scope() { - assert!(current_parent().is_none()); - } - - #[tokio::test] - async fn fork_context_returns_none_outside_scope() { - assert!(current_fork().is_none()); - } - - #[tokio::test] - async fn fork_context_visible_inside_scope() { - let ctx = ForkContext { - system_prompt: Arc::new("hello".into()), - tool_specs: Arc::new(vec![]), - message_prefix: Arc::new(vec![]), - cache_boundary: None, - fork_task_prompt: "do thing".into(), - }; - - with_fork_context(ctx, async { - let inner = current_fork().expect("fork context should be visible"); - assert_eq!(*inner.system_prompt, "hello"); - assert_eq!(inner.fork_task_prompt, "do thing"); - }) - .await; - - // And it disappears once the scope ends. - assert!(current_fork().is_none()); - } - - // ── Minimal stubs so we can construct a ParentExecutionContext - // without pulling in the memory factory or a real provider. None of - // these methods are called by the task-local visibility test — the - // test only reads scalar fields on the context snapshot — so panic - // bodies are fine. 
- - struct StubProvider; - - #[async_trait] - impl Provider for StubProvider { - async fn chat_with_system( - &self, - _system_prompt: Option<&str>, - _message: &str, - _model: &str, - _temperature: f64, - ) -> anyhow::Result { - unimplemented!("StubProvider::chat_with_system is not called in this test") - } - - async fn chat( - &self, - _request: ChatRequest<'_>, - _model: &str, - _temperature: f64, - ) -> anyhow::Result { - unimplemented!("StubProvider::chat is not called in this test") - } - } - - struct StubMemory; - - #[async_trait] - impl crate::openhuman::memory::Memory for StubMemory { - fn name(&self) -> &str { - "stub" - } - - async fn store( - &self, - _key: &str, - _content: &str, - _category: MemoryCategory, - _session_id: Option<&str>, - ) -> anyhow::Result<()> { - Ok(()) - } - - async fn recall( - &self, - _query: &str, - _limit: usize, - _session_id: Option<&str>, - ) -> anyhow::Result> { - Ok(vec![]) - } - - async fn get(&self, _key: &str) -> anyhow::Result> { - Ok(None) - } - - async fn list( - &self, - _category: Option<&MemoryCategory>, - _session_id: Option<&str>, - ) -> anyhow::Result> { - Ok(vec![]) - } - - async fn forget(&self, _key: &str) -> anyhow::Result { - Ok(false) - } - - async fn count(&self) -> anyhow::Result { - Ok(0) - } - - async fn health_check(&self) -> bool { - true - } - } - - fn stub_parent_context() -> ParentExecutionContext { - ParentExecutionContext { - provider: Arc::new(StubProvider), - all_tools: Arc::new(vec![]), - all_tool_specs: Arc::new(vec![]), - model_name: "stub-model".into(), - temperature: 0.4, - workspace_dir: std::path::PathBuf::from("/tmp"), - memory: Arc::new(StubMemory), - agent_config: AgentConfig::default(), - skills: Arc::new(vec![]), - memory_context: None, - session_id: "test-session".into(), - channel: "test-channel".into(), - } - } - - #[tokio::test] - async fn parent_context_visible_inside_scope() { - let ctx = stub_parent_context(); - - with_parent_context(ctx, async { - let p = 
current_parent().expect("parent context should be visible"); - assert_eq!(p.model_name, "stub-model"); - assert_eq!(p.session_id, "test-session"); - assert_eq!(p.channel, "test-channel"); - assert!((p.temperature - 0.4).abs() < f64::EPSILON); - }) - .await; - - // And it disappears once the scope ends. - assert!(current_parent().is_none()); - } -} diff --git a/src/openhuman/agent/harness/interrupt.rs b/src/openhuman/agent/harness/interrupt.rs index 62823623d..e0cabc781 100644 --- a/src/openhuman/agent/harness/interrupt.rs +++ b/src/openhuman/agent/harness/interrupt.rs @@ -95,45 +95,3 @@ pub fn check_interrupt(fence: &InterruptFence) -> Result<(), InterruptedError> { Ok(()) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn fence_starts_clear() { - let fence = InterruptFence::new(); - assert!(!fence.is_interrupted()); - } - - #[test] - fn trigger_sets_flag() { - let fence = InterruptFence::new(); - fence.trigger(); - assert!(fence.is_interrupted()); - } - - #[test] - fn reset_clears_flag() { - let fence = InterruptFence::new(); - fence.trigger(); - fence.reset(); - assert!(!fence.is_interrupted()); - } - - #[test] - fn check_interrupt_returns_err_when_triggered() { - let fence = InterruptFence::new(); - assert!(check_interrupt(&fence).is_ok()); - fence.trigger(); - assert!(check_interrupt(&fence).is_err()); - } - - #[test] - fn clone_shares_flag() { - let fence = InterruptFence::new(); - let clone = fence.clone(); - fence.trigger(); - assert!(clone.is_interrupted()); - } -} diff --git a/src/openhuman/agent/harness/memory_context.rs b/src/openhuman/agent/harness/memory_context.rs index 2ad747545..bc2ef98c4 100644 --- a/src/openhuman/agent/harness/memory_context.rs +++ b/src/openhuman/agent/harness/memory_context.rs @@ -65,3 +65,134 @@ pub(crate) async fn build_context( context } + +#[cfg(test)] +mod tests { + use super::*; + use crate::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; + use async_trait::async_trait; + + struct MockMemory { + 
primary: Vec, + working: Vec, + fail_primary: bool, + } + + #[async_trait] + impl Memory for MockMemory { + fn name(&self) -> &str { + "mock" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + if query.starts_with("working.user ") { + return Ok(self.working.clone()); + } + if self.fail_primary { + anyhow::bail!("primary recall failed"); + } + Ok(self.primary.clone()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + } + + fn entry(key: &str, content: &str, score: Option) -> MemoryEntry { + MemoryEntry { + id: key.into(), + key: key.into(), + content: content.into(), + namespace: None, + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score, + } + } + + #[tokio::test] + async fn build_context_filters_scores_and_deduplicates_working_memory() { + let mem = MockMemory { + primary: vec![ + entry("task", "primary entry", Some(0.9)), + entry("low", "too low", Some(0.1)), + entry("working.user.profile", "already present", Some(0.9)), + ], + working: vec![ + entry("working.user.profile", "already present", Some(0.95)), + entry("working.user.timezone", "PST", Some(0.95)), + ], + fail_primary: false, + }; + + let context = build_context(&mem, "hello", 0.4).await; + assert!(context.contains("[Memory context]")); + assert!(context.contains("- task: primary entry")); + assert!(!context.contains("too low")); + assert!(context.contains("[User working memory]")); + 
assert!(context.contains("- working.user.timezone: PST")); + assert_eq!(context.matches("working.user.profile").count(), 1); + } + + #[tokio::test] + async fn build_context_uses_working_memory_even_if_primary_recall_fails() { + let mem = MockMemory { + primary: Vec::new(), + working: vec![entry("working.user.pref", "Use Rust", None)], + fail_primary: true, + }; + + let context = build_context(&mem, "hello", 0.4).await; + assert!(!context.contains("[Memory context]")); + assert!(context.contains("[User working memory]")); + assert!(context.contains("Use Rust")); + } + + #[tokio::test] + async fn build_context_returns_empty_when_nothing_relevant_is_found() { + let mem = MockMemory { + primary: vec![entry("low", "too low", Some(0.1))], + working: vec![entry("not_working", "ignored", Some(0.9))], + fail_primary: false, + }; + + assert!(build_context(&mem, "hello", 0.4).await.is_empty()); + } +} diff --git a/src/openhuman/agent/harness/mod.rs b/src/openhuman/agent/harness/mod.rs index b84cf3ca1..5d1ec6ae9 100644 --- a/src/openhuman/agent/harness/mod.rs +++ b/src/openhuman/agent/harness/mod.rs @@ -1,31 +1,25 @@ //! Multi-agent harness — sub-agent dispatch and fork-cache support. //! -//! ## Subagents-as-tools -//! The main agent runs its normal tool loop and can choose to delegate to a -//! sub-agent at any iteration via the `spawn_subagent` tool. The sub-agent -//! is constructed at call time from an [`definition::AgentDefinition`] -//! looked up in the global [`definition::AgentDefinitionRegistry`], runs -//! its own narrowed tool loop (cheaper model, fewer tools, no memory -//! recall), and returns a single text result that the parent threads back -//! into its history. This is the only execution shape — there is no -//! separate DAG planner/executor. +//! The harness provides the infrastructure for an agent to delegate work to +//! specialized sub-agents. It manages the lifecycle of these sub-agents, +//! 
including prompt construction, tool filtering, and result synthesis. //! -//! ## Fork-cache mode -//! `spawn_subagent { mode: "fork", … }` replays the parent's *exact* -//! rendered system prompt + tool schemas + message prefix via the -//! [`fork_context::ForkContext`] task-local. The OpenAI-compatible -//! inference backend's automatic prefix caching turns this byte-stable -//! replay into a real token-savings win. +//! ## Delegation via `spawn_subagent` +//! The system treats specialized agents (researchers, planners, etc.) as tools. +//! An agent can invoke the `spawn_subagent` tool, which looks up a definition +//! in the global [`definition::AgentDefinitionRegistry`] and runs a dedicated tool loop. +//! -//! ## Built-in agents -//! The canonical list of built-in agents lives in -//! [`crate::openhuman::agent::agents`] — one subfolder per agent, each -//! containing `agent.toml` (id, tools, model, sandbox, iteration cap) -//! and `prompt.md` (the sub-agent's system prompt body). Adding a new -//! built-in agent = drop in a new subfolder and append one entry to -//! that module's `BUILTINS` slice. [`builtin_definitions`] in this -//! harness module is a thin wrapper that loads those files and appends -//! the synthetic `fork` definition (used for prefix-cache reuse). +//! ## Token Optimization +//! - **Typed Sub-agents**: Skips unnecessary system prompt sections (e.g., +//! identity, global skills) to keep sub-agent prompts small. +//! - **Fork Mode**: Allows sub-agents to replay the parent's exact context +//! to leverage KV-cache reuse on the inference backend. +//! +//! ## Key Sub-modules +//! - **[`subagent_runner`]**: The core logic for executing a sub-agent. +//! - **[`definition`]**: Data structures for defining an agent's archetype. +//! - **[`fork_context`]**: Task-local storage for parent context sharing. +//! - **[`interrupt`]**: Infrastructure for graceful cancellation of agent loops. 
pub(crate) mod archivist; pub(crate) mod builtin_definitions; diff --git a/src/openhuman/agent/harness/parse.rs b/src/openhuman/agent/harness/parse.rs index 3b3625c7a..4cf09e211 100644 --- a/src/openhuman/agent/harness/parse.rs +++ b/src/openhuman/agent/harness/parse.rs @@ -564,3 +564,237 @@ pub(crate) fn tools_to_openai_format(tools_registry: &[Box]) -> Vec &str { + self.0 + } + + fn description(&self) -> &str { + "stub tool" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + } + }) + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(ToolResult::success("ok")) + } + } + + #[test] + fn parse_argument_helpers_cover_string_non_string_and_missing_values() { + assert_eq!( + parse_arguments_value(Some(&serde_json::json!("{\"value\":1}"))), + serde_json::json!({ "value": 1 }) + ); + assert_eq!( + parse_arguments_value(Some(&serde_json::json!("not-json"))), + serde_json::json!({}) + ); + assert_eq!( + parse_arguments_value(Some(&serde_json::json!({ "value": 2 }))), + serde_json::json!({ "value": 2 }) + ); + assert_eq!(parse_arguments_value(None), serde_json::json!({})); + } + + #[test] + fn parse_tool_call_value_supports_function_shape_flat_shape_and_invalid_names() { + let function_shape = serde_json::json!({ + "function": { + "name": "shell", + "arguments": "{\"command\":\"ls\"}" + } + }); + let parsed = parse_tool_call_value(&function_shape).expect("function call should parse"); + assert_eq!(parsed.name, "shell"); + assert_eq!(parsed.arguments, serde_json::json!({ "command": "ls" })); + + let flat_shape = serde_json::json!({ + "name": "echo", + "arguments": { "value": "hi" } + }); + let parsed = parse_tool_call_value(&flat_shape).expect("flat call should parse"); + assert_eq!(parsed.name, "echo"); + assert_eq!(parsed.arguments, serde_json::json!({ "value": "hi" })); + + assert!(parse_tool_call_value(&serde_json::json!({ "name": " " 
})).is_none()); + assert!(parse_tool_call_value(&serde_json::json!({ "function": {} })).is_none()); + } + + #[test] + fn parse_tool_calls_from_json_value_handles_tool_calls_array_arrays_and_singletons() { + let wrapped = serde_json::json!({ + "tool_calls": [ + { "name": "echo", "arguments": { "value": "one" } }, + { "function": { "name": "shell", "arguments": "{\"command\":\"pwd\"}" } } + ], + "content": "assistant text" + }); + let calls = parse_tool_calls_from_json_value(&wrapped); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "echo"); + assert_eq!(calls[1].name, "shell"); + + let array = serde_json::json!([ + { "name": "echo", "arguments": { "value": "two" } }, + { "name": " " } + ]); + let calls = parse_tool_calls_from_json_value(&array); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].arguments, serde_json::json!({ "value": "two" })); + + let single = serde_json::json!({ "name": "echo", "arguments": { "value": "three" } }); + let calls = parse_tool_calls_from_json_value(&single); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "echo"); + } + + #[test] + fn tag_and_json_extractors_cover_common_edge_cases() { + assert_eq!( + find_first_tag("hi there", &["", ""]), + Some((3, "")) + ); + assert_eq!( + matching_tool_call_close_tag(""), + Some("") + ); + assert_eq!(matching_tool_call_close_tag(""), None); + + let extracted = extract_first_json_value_with_end(" text {\"ok\":true} trailing ") + .expect("json should be found"); + assert_eq!(extracted.0, serde_json::json!({ "ok": true })); + assert!(extracted.1 > 0); + + assert_eq!( + strip_leading_close_tags(" hi "), + "hi " + ); + assert_eq!(strip_leading_close_tags("plain"), "plain"); + + let values = extract_json_values("before {\"a\":1} [1,2] after"); + assert_eq!( + values, + vec![serde_json::json!({ "a": 1 }), serde_json::json!([1, 2])] + ); + + assert_eq!( + find_json_end(" {\"a\":\"}\"}tail"), + Some(" {\"a\":\"}\"}".len()) + ); + assert_eq!(find_json_end("[1,2,3]"), None); + } + + 
#[test] + fn glm_helpers_parse_aliases_urls_and_commands() { + assert_eq!(map_glm_tool_alias("browser_open"), "shell"); + assert_eq!(map_glm_tool_alias("http"), "http_request"); + assert_eq!(map_glm_tool_alias("custom_tool"), "custom_tool"); + + assert_eq!( + build_curl_command("https://example.com?q=1"), + Some("curl -s 'https://example.com?q=1'".into()) + ); + assert_eq!( + build_curl_command("https://exa'mple.com"), + Some("curl -s 'https://exa'\\\\''mple.com'".into()) + ); + assert!(build_curl_command("ftp://example.com").is_none()); + assert!(build_curl_command("https://example.com/has space").is_none()); + + let calls = parse_glm_style_tool_calls( + "browser_open/url>https://example.com\nhttp_request/url>https://api.example.com\nplain text\nhttps://rust-lang.org", + ); + assert_eq!(calls.len(), 3); + assert_eq!(calls[0].0, "shell"); + assert_eq!(calls[1].0, "http_request"); + assert_eq!(calls[2].0, "shell"); + } + + #[test] + fn parse_tool_calls_supports_native_json_xml_markdown_and_glm_formats() { + let native = serde_json::json!({ + "content": "native text", + "tool_calls": [ + { "name": "echo", "arguments": { "value": "one" } } + ] + }) + .to_string(); + let (text, calls) = parse_tool_calls(&native); + assert_eq!(text, "native text"); + assert_eq!(calls.len(), 1); + + let xml = "before\n\n{\"name\":\"echo\",\"arguments\":{\"value\":\"two\"}}\n\nafter"; + let (text, calls) = parse_tool_calls(xml); + assert_eq!(text, "before\nafter"); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].arguments, serde_json::json!({ "value": "two" })); + + let unclosed = "{\"name\":\"echo\",\"arguments\":{\"value\":\"three\"}}"; + let (text, calls) = parse_tool_calls(unclosed); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + + let markdown = "lead\n```tool_call\n{\"name\":\"echo\",\"arguments\":{\"value\":\"four\"}}\n```\ntrail"; + let (text, calls) = parse_tool_calls(markdown); + assert_eq!(text, "lead\ntrail"); + assert_eq!(calls.len(), 1); + + let glm = 
"shell/command>ls -la"; + let (text, calls) = parse_tool_calls(glm); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + } + + #[test] + fn structured_tool_call_and_history_helpers_round_trip_expected_shapes() { + let tool_calls = vec![ToolCall { + id: "call-1".into(), + name: "echo".into(), + arguments: "{\"value\":\"hello\"}".into(), + }]; + + let parsed = parse_structured_tool_calls(&tool_calls); + assert_eq!(parsed.len(), 1); + assert_eq!(parsed[0].arguments, serde_json::json!({ "value": "hello" })); + + let native = build_native_assistant_history("done", &tool_calls); + let native_json: serde_json::Value = serde_json::from_str(&native).expect("valid json"); + assert_eq!(native_json["content"], "done"); + assert_eq!(native_json["tool_calls"][0]["id"], "call-1"); + + let xml_history = build_assistant_history_with_tool_calls("", &tool_calls); + assert!(xml_history.contains("")); + assert!(xml_history.contains("\"name\":\"echo\"")); + } + + #[test] + fn tools_to_openai_format_uses_tool_metadata() { + let tools: Vec> = + vec![Box::new(StubTool("echo")), Box::new(StubTool("shell"))]; + let payload = tools_to_openai_format(&tools); + + assert_eq!(payload.len(), 2); + assert_eq!(payload[0]["type"], "function"); + assert_eq!(payload[0]["function"]["name"], "echo"); + assert_eq!(payload[1]["function"]["description"], "stub tool"); + } +} diff --git a/src/openhuman/agent/harness/self_healing.rs b/src/openhuman/agent/harness/self_healing.rs index 703155303..47194c099 100644 --- a/src/openhuman/agent/harness/self_healing.rs +++ b/src/openhuman/agent/harness/self_healing.rs @@ -248,4 +248,42 @@ mod tests { assert!(prompt.contains("/workspace/polyfills/jq")); assert!(prompt.contains("parse json output")); } + + #[test] + fn detects_windows_not_recognized_pattern() { + let mut interceptor = SelfHealingInterceptor::new(Path::new("/tmp"), true); + let result = make_error_result("'rg' is not recognized as an internal or external 
command"); + let cmd = interceptor.detect_missing_command(&result); + assert_eq!(cmd, Some("rg".to_string())); + } + + #[test] + fn ignores_non_matching_or_malformed_missing_command_patterns() { + let mut interceptor = SelfHealingInterceptor::new(Path::new("/tmp"), true); + assert!(interceptor + .detect_missing_command(&make_error_result("permission denied")) + .is_none()); + + let too_long = format!("bash: {}: command not found", "x".repeat(80)); + assert!(interceptor + .detect_missing_command(&make_error_result(&too_long)) + .is_none()); + + assert_eq!(extract_command_name("sh: 1: 1234: not found"), None); + } + + #[tokio::test] + async fn ensure_polyfill_dir_creates_directory_and_exposes_path() { + let workspace = tempfile::TempDir::new().expect("temp workspace"); + let interceptor = SelfHealingInterceptor::new(workspace.path(), true); + assert!(!interceptor.polyfill_dir().exists()); + + interceptor + .ensure_polyfill_dir() + .await + .expect("polyfill dir should be created"); + + assert!(interceptor.polyfill_dir().exists()); + assert!(interceptor.polyfill_dir().ends_with("polyfills")); + } } diff --git a/src/openhuman/agent/harness/session/builder.rs b/src/openhuman/agent/harness/session/builder.rs index 00f83f41e..f446a1c53 100644 --- a/src/openhuman/agent/harness/session/builder.rs +++ b/src/openhuman/agent/harness/session/builder.rs @@ -192,7 +192,11 @@ impl AgentBuilder { self } - /// Validates the configuration and builds the `Agent` instance. + /// Validates the configuration and constructs a new `Agent` instance. + /// + /// This method is responsible for wiring together the provided components, + /// setting up the context manager, and initializing the conversation history. + /// It ensures that all required fields (provider, tools, memory, etc.) are present. pub fn build(self) -> Result { let tools = self .tools @@ -296,10 +300,19 @@ impl AgentBuilder { } impl Agent { - /// Creates an `Agent` instance from a global configuration. 
+ /// Constructs an `Agent` instance from a global system configuration. + /// + /// This is the primary factory method for initializing an agent with all + /// standard system integrations (memory, tools, skills, providers, learning) + /// configured according to the user's `config.toml`. /// - /// This is the primary way to initialize an agent with all system - /// integrations (memory, tools, skills, etc.) configured. + /// It performs the heavy lifting of: + /// 1. Initializing the host runtime (native or docker). + /// 2. Setting up security policies. + /// 3. Initializing memory and embedding services. + /// 4. Registering all built-in and orchestrator tools. + /// 5. Configuring the routed AI provider. + /// 6. Setting up the learning system and post-turn hooks. pub fn from_config(config: &Config) -> Result { let runtime: Arc = Arc::from(host_runtime::create_runtime(&config.runtime)?); diff --git a/src/openhuman/agent/harness/session/runtime.rs b/src/openhuman/agent/harness/session/runtime.rs index f5dda163d..c6e6625e8 100644 --- a/src/openhuman/agent/harness/session/runtime.rs +++ b/src/openhuman/agent/harness/session/runtime.rs @@ -250,6 +250,10 @@ impl Agent { // ───────────────────────────────────────────────────────────────── /// Runs a single turn with the given message and returns the response. + /// + /// This is the primary high-level method for programmatic interaction with the agent. + /// It wraps the core `turn` logic with telemetry events (`AgentTurnStarted`, + /// `AgentTurnCompleted`) and error sanitization. pub async fn run_single(&mut self, message: &str) -> Result { let history_snapshot = self.history.clone(); publish_global(DomainEvent::AgentTurnStarted { @@ -281,10 +285,9 @@ impl Agent { /// Runs an interactive CLI loop, reading from standard input and printing to standard output. 
/// - /// Each incoming message is dispatched through [`Agent::run_single`] so - /// the unified lifecycle events (`AgentTurnStarted`, `AgentTurnCompleted`, - /// `AgentError`) and error sanitisation run for interactive turns just - /// like they do for one-shot invocations. + /// This method starts a persistent session where the user can chat with the agent + /// directly from the console. It handles input until a termination command + /// (e.g., `/quit`) is received. pub async fn run_interactive(&mut self) -> Result<()> { println!("🦀 OpenHuman Interactive Mode"); println!("Type /quit to exit.\n"); @@ -313,3 +316,283 @@ impl Agent { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::event_bus::{global, init_global, DomainEvent}; + use crate::openhuman::agent::dispatcher::XmlToolDispatcher; + use crate::openhuman::agent::error::AgentError; + use crate::openhuman::memory::Memory; + use crate::openhuman::providers::{ChatMessage, ChatRequest, ChatResponse, UsageInfo}; + use anyhow::anyhow; + use async_trait::async_trait; + use parking_lot::Mutex; + use std::sync::Arc; + use tokio::sync::Mutex as AsyncMutex; + use tokio::time::{sleep, Duration}; + + struct StaticProvider { + response: Mutex>>, + } + + #[async_trait] + impl Provider for StaticProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("unused".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + self.response.lock().take().unwrap_or_else(|| { + Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + usage: None, + }) + }) + } + } + + fn make_agent(provider: Arc) -> Agent { + let workspace = tempfile::TempDir::new().expect("temp workspace"); + let workspace_path = workspace.path().to_path_buf(); + std::mem::forget(workspace); + let memory_cfg = crate::openhuman::config::MemoryConfig { + backend: 
"none".into(), + ..crate::openhuman::config::MemoryConfig::default() + }; + let mem: Arc = Arc::from( + crate::openhuman::memory::create_memory(&memory_cfg, &workspace_path, None).unwrap(), + ); + + Agent::builder() + .provider_arc(provider) + .tools(vec![]) + .memory(mem) + .tool_dispatcher(Box::new(XmlToolDispatcher)) + .workspace_dir(workspace_path) + .event_context("runtime-test-session", "runtime-test-channel") + .build() + .unwrap() + } + + #[test] + fn new_entries_for_turn_detects_prefix_overlap_and_fallbacks() { + let history_snapshot = vec![ + ConversationMessage::Chat(ChatMessage::user("a")), + ConversationMessage::Chat(ChatMessage::assistant("b")), + ]; + let current_history = vec![ + ConversationMessage::Chat(ChatMessage::user("a")), + ConversationMessage::Chat(ChatMessage::assistant("b")), + ConversationMessage::Chat(ChatMessage::assistant("c")), + ]; + let appended = Agent::new_entries_for_turn(&history_snapshot, ¤t_history); + assert_eq!(appended.len(), 1); + + let shifted_history = vec![ + ConversationMessage::Chat(ChatMessage::assistant("b")), + ConversationMessage::Chat(ChatMessage::assistant("c")), + ]; + let overlap = Agent::new_entries_for_turn(&history_snapshot, &shifted_history); + assert_eq!(overlap.len(), 1); + assert!(matches!(&overlap[0], ConversationMessage::Chat(msg) if msg.content == "c")); + } + + #[test] + fn sanitizers_and_tool_call_helpers_cover_fallback_paths() { + let err = anyhow!(AgentError::PermissionDenied { + tool_name: "shell".into(), + required_level: "Execute".into(), + channel_max_level: "ReadOnly".into(), + }); + assert_eq!( + Agent::sanitize_event_error_message(&err), + "permission_denied" + ); + + let generic = anyhow!("bad key sk-123456789012345678901234567890\nwith\twhitespace"); + let sanitized = Agent::sanitize_event_error_message(&generic); + assert!(!sanitized.contains('\n')); + assert!(!sanitized.contains('\t')); + + let calls = vec![ + crate::openhuman::agent::dispatcher::ParsedToolCall { + name: "a".into(), + 
arguments: serde_json::json!({}), + tool_call_id: None, + }, + crate::openhuman::agent::dispatcher::ParsedToolCall { + name: "b".into(), + arguments: serde_json::json!({"x":1}), + tool_call_id: Some("keep".into()), + }, + ]; + let calls = Agent::with_fallback_tool_call_ids(calls, 2); + assert_eq!(calls[0].tool_call_id.as_deref(), Some("parsed-3-1")); + assert_eq!(calls[1].tool_call_id.as_deref(), Some("keep")); + + let response = crate::openhuman::providers::ChatResponse { + text: Some(String::new()), + tool_calls: vec![], + usage: None, + }; + let persisted = Agent::persisted_tool_calls_for_history(&response, &calls, 2); + assert_eq!(persisted[0].id, "parsed-3-1"); + assert_eq!(persisted[1].id, "keep"); + + let history = vec![ + ConversationMessage::AssistantToolCalls { + text: None, + tool_calls: vec![], + }, + ConversationMessage::AssistantToolCalls { + text: None, + tool_calls: vec![], + }, + ]; + assert_eq!(Agent::count_iterations(&history), 3); + } + + #[tokio::test] + async fn run_single_publishes_completed_and_error_events() { + let _ = init_global(64); + let events = Arc::new(AsyncMutex::new(Vec::::new())); + let events_handler = Arc::clone(&events); + let _handle = global().unwrap().on("runtime-events-test", move |event| { + let events = Arc::clone(&events_handler); + let cloned = event.clone(); + Box::pin(async move { + events.lock().await.push(cloned); + }) + }); + + let ok_provider: Arc = Arc::new(StaticProvider { + response: Mutex::new(Some(Ok(ChatResponse { + text: Some("ok".into()), + tool_calls: vec![], + usage: Some(UsageInfo::default()), + }))), + }); + let mut ok_agent = make_agent(ok_provider); + let response = ok_agent.run_single("hello").await.expect("run_single ok"); + assert_eq!(response, "ok"); + + let err_provider: Arc = Arc::new(StaticProvider { + response: Mutex::new(Some(Err(anyhow!(AgentError::PermissionDenied { + tool_name: "shell".into(), + required_level: "Execute".into(), + channel_max_level: "ReadOnly".into(), + })))), + }); + 
let mut err_agent = make_agent(err_provider); + let err = err_agent + .run_single("hello") + .await + .expect_err("run_single should publish error"); + assert!(err.to_string().contains("Permission denied")); + + sleep(Duration::from_millis(20)).await; + let captured = events.lock().await; + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::AgentTurnStarted { session_id, channel } + if session_id == "runtime-test-session" && channel == "runtime-test-channel" + ))); + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::AgentTurnCompleted { + session_id, + text_chars, + iterations, + } if session_id == "runtime-test-session" && *text_chars == 2 && *iterations >= 1 + ))); + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::AgentError { + session_id, + message, + recoverable, + } if session_id == "runtime-test-session" + && message == "permission_denied" + && !recoverable + ))); + } + + #[test] + fn accessors_and_history_reset_expose_agent_runtime_state() { + let provider: Arc = Arc::new(StaticProvider { + response: Mutex::new(None), + }); + let mut agent = make_agent(provider); + agent.history = vec![ConversationMessage::Chat(ChatMessage::system("sys"))]; + agent.system_prompt_cache_boundary = Some(7); + agent.skills = vec![crate::openhuman::skills::Skill { + name: "demo".into(), + ..Default::default() + }]; + + assert_eq!(agent.event_session_id(), "runtime-test-session"); + assert_eq!(agent.event_channel(), "runtime-test-channel"); + assert_eq!(agent.tools().len(), 0); + assert_eq!(agent.tool_specs().len(), 0); + assert_eq!(agent.workspace_dir(), agent.workspace_dir.as_path()); + assert_eq!(agent.model_name(), agent.model_name); + assert_eq!(agent.temperature(), agent.temperature); + assert_eq!(agent.skills().len(), 1); + assert_eq!( + agent.agent_config().max_tool_iterations, + agent.config.max_tool_iterations + ); + assert_eq!(agent.history().len(), 1); + assert!(!agent.memory_arc().name().is_empty()); + + 
agent.set_event_context("updated-session", "updated-channel"); + assert_eq!(agent.event_session_id(), "updated-session"); + assert_eq!(agent.event_channel(), "updated-channel"); + + agent.clear_history(); + assert!(agent.history().is_empty()); + assert!(agent.system_prompt_cache_boundary.is_none()); + assert_eq!(Agent::count_iterations(agent.history()), 1); + } + + #[test] + fn helper_paths_cover_no_overlap_native_calls_and_truncation() { + let history_snapshot = vec![ConversationMessage::Chat(ChatMessage::user("a"))]; + let current_history = vec![ConversationMessage::Chat(ChatMessage::assistant("b"))]; + let appended = Agent::new_entries_for_turn(&history_snapshot, ¤t_history); + assert_eq!(appended.len(), 1); + assert!(matches!(&appended[0], ConversationMessage::Chat(msg) if msg.content == "b")); + + let native_calls = vec![crate::openhuman::providers::ToolCall { + id: "native-1".into(), + name: "echo".into(), + arguments: "{}".into(), + }]; + let response = crate::openhuman::providers::ChatResponse { + text: Some(String::new()), + tool_calls: native_calls.clone(), + usage: None, + }; + let persisted = Agent::persisted_tool_calls_for_history(&response, &[], 0); + assert_eq!(persisted.len(), 1); + assert_eq!(persisted[0].id, native_calls[0].id); + assert_eq!(persisted[0].name, native_calls[0].name); + + let long = anyhow!("{}", "x".repeat(400)); + let sanitized = Agent::sanitize_event_error_message(&long); + assert!(sanitized.len() <= 256); + } +} diff --git a/src/openhuman/agent/harness/session/transcript.rs b/src/openhuman/agent/harness/session/transcript.rs index d4c4fb5de..9ff6ed93b 100644 --- a/src/openhuman/agent/harness/session/transcript.rs +++ b/src/openhuman/agent/harness/session/transcript.rs @@ -116,7 +116,7 @@ pub fn write_transcript( // Messages for msg in messages { buf.push('\n'); - let _ = write!(buf, "{}{}{}\n", MSG_OPEN_PREFIX, msg.role, MSG_OPEN_SUFFIX); + let _ = writeln!(buf, "{}{}{}", MSG_OPEN_PREFIX, msg.role, MSG_OPEN_SUFFIX); 
buf.push_str(&escape_content(&msg.content)); buf.push('\n'); buf.push_str(MSG_CLOSE); @@ -367,7 +367,7 @@ fn latest_in_dir(dir: &Path, agent_prefix: &str) -> Option { if name_str.starts_with(&prefix) && name_str.ends_with(".md") { let idx_str = &name_str[prefix.len()..name_str.len() - 3]; if let Ok(idx) = idx_str.parse::() { - if best.as_ref().map_or(true, |(best_idx, _)| idx > *best_idx) { + if best.as_ref().is_none_or(|(best_idx, _)| idx > *best_idx) { best = Some((idx, entry.path())); } } diff --git a/src/openhuman/agent/harness/session/turn.rs b/src/openhuman/agent/harness/session/turn.rs index a743ec469..16f1a5aa5 100644 --- a/src/openhuman/agent/harness/session/turn.rs +++ b/src/openhuman/agent/harness/session/turn.rs @@ -38,11 +38,26 @@ use anyhow::Result; use std::sync::Arc; impl Agent { - /// Performs a single interaction "turn" with the agent. + /// Executes a single interaction "turn" with the agent. /// - /// This is the core logic that takes user input, manages the history, - /// calls the LLM, handles tool calls (up to `max_tool_iterations`), - /// and returns the final assistant response. + /// This function is the primary driver of the agent's behavior. It manages the + /// end-to-end lifecycle of a user request: + /// + /// 1. **Initialization**: Resumes from a session transcript if this is a new turn + /// to preserve KV-cache stability. + /// 2. **Prompt Construction**: Builds the system prompt (only on the first turn) + /// incorporating learned context and tool instructions. + /// 3. **Context Injection**: Enriches the user message with relevant memories + /// fetched via the [`MemoryLoader`]. + /// 4. **Execution Loop**: Enters a loop (up to `max_tool_iterations`) where it: + /// - Manages the context window (reduction/summarization). + /// - Calls the LLM provider. + /// - Parses and executes tool calls. + /// - Accumulates results into history. + /// 5. 
**Synthesis**: Returns the final assistant response after all tools have + /// finished or the iteration budget is exhausted. + /// 6. **Background Tasks**: Triggers episodic memory indexing and facts + /// extraction asynchronously. pub async fn turn(&mut self, user_message: &str) -> Result { let turn_started = std::time::Instant::now(); self.emit_progress(AgentProgress::TurnStarted); @@ -531,6 +546,14 @@ impl Agent { // ───────────────────────────────────────────────────────────────── /// Executes a single tool call and returns the result and execution record. + /// + /// This method: + /// 1. Emits telemetry events for the start of execution. + /// 2. Handles the special `spawn_subagent` tool with `fork` context. + /// 3. Validates tool visibility and availability. + /// 4. Dispatches to the underlying tool implementation. + /// 5. Applies per-result byte budgets to prevent context window bloat. + /// 6. Sanitizes and records the outcome for post-turn hooks. pub(super) async fn execute_tool_call( &self, call: &ParsedToolCall, @@ -675,6 +698,8 @@ impl Agent { } /// Executes multiple tool calls in sequence. + /// + /// Collects results and execution records for all requested tools in a single batch. 
pub(super) async fn execute_tools( &self, calls: &[ParsedToolCall], @@ -1110,3 +1135,578 @@ fn sanitize_learned_entry(content: &str) -> String { } sanitized } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::event_bus::{global, init_global, DomainEvent}; + use crate::openhuman::agent::dispatcher::XmlToolDispatcher; + use crate::openhuman::agent::hooks::{PostTurnHook, TurnContext}; + use crate::openhuman::agent::memory_loader::MemoryLoader; + use crate::openhuman::memory::Memory; + use crate::openhuman::providers::{ChatRequest, ChatResponse, Provider}; + use crate::openhuman::tools::Tool; + use crate::openhuman::tools::ToolResult; + use async_trait::async_trait; + use std::collections::HashSet; + use std::sync::Arc; + use tokio::sync::Mutex as AsyncMutex; + use tokio::sync::Notify; + use tokio::time::{sleep, timeout, Duration}; + + struct DummyProvider; + + #[async_trait] + impl Provider for DummyProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("unused".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + Ok(ChatResponse { + text: Some("unused".into()), + tool_calls: vec![], + usage: None, + }) + } + } + + struct SequenceProvider { + responses: AsyncMutex>>, + requests: AsyncMutex>>, + } + + #[async_trait] + impl Provider for SequenceProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("unused".into()) + } + + async fn chat( + &self, + request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + self.requests.lock().await.push(request.messages.to_vec()); + self.responses.lock().await.remove(0) + } + } + + struct FixedMemoryLoader { + context: String, + } + + #[async_trait] + impl MemoryLoader for FixedMemoryLoader { + async fn load_context( + &self, + 
_memory: &dyn Memory, + _user_message: &str, + ) -> anyhow::Result { + Ok(self.context.clone()) + } + } + + struct EchoTool; + + #[async_trait] + impl Tool for EchoTool { + fn name(&self) -> &str { + "echo" + } + + fn description(&self) -> &str { + "echo" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type":"object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(ToolResult::success("echo-output")) + } + } + + struct LongTool; + + #[async_trait] + impl Tool for LongTool { + fn name(&self) -> &str { + "long" + } + + fn description(&self) -> &str { + "long" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type":"object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(ToolResult::success("x".repeat(800))) + } + } + + struct RecordingHook { + calls: Arc>>, + notify: Arc, + } + + #[async_trait] + impl PostTurnHook for RecordingHook { + fn name(&self) -> &str { + "recording" + } + + async fn on_turn_complete(&self, ctx: &TurnContext) -> anyhow::Result<()> { + self.calls.lock().await.push(ctx.clone()); + self.notify.notify_waiters(); + Ok(()) + } + } + + fn make_agent(visible_tool_names: Option>) -> Agent { + let workspace = tempfile::TempDir::new().expect("temp workspace"); + let workspace_path = workspace.path().to_path_buf(); + std::mem::forget(workspace); + let memory_cfg = crate::openhuman::config::MemoryConfig { + backend: "none".into(), + ..crate::openhuman::config::MemoryConfig::default() + }; + let mem: Arc = Arc::from( + crate::openhuman::memory::create_memory(&memory_cfg, &workspace_path, None).unwrap(), + ); + + let mut builder = Agent::builder() + .provider(Box::new(DummyProvider)) + .tools(vec![Box::new(EchoTool)]) + .memory(mem) + .tool_dispatcher(Box::new(XmlToolDispatcher)) + .workspace_dir(workspace_path) + .event_context("turn-test-session", "turn-test-channel") + .config(crate::openhuman::config::AgentConfig { + 
max_history_messages: 3, + ..crate::openhuman::config::AgentConfig::default() + }); + + if let Some(names) = visible_tool_names { + builder = builder.visible_tool_names(names); + } + + builder.build().unwrap() + } + + fn make_agent_with_builder( + provider: Arc, + tools: Vec>, + memory_loader: Box, + post_turn_hooks: Vec>, + config: crate::openhuman::config::AgentConfig, + context_config: crate::openhuman::config::ContextConfig, + ) -> Agent { + let workspace = tempfile::TempDir::new().expect("temp workspace"); + let workspace_path = workspace.path().to_path_buf(); + std::mem::forget(workspace); + let memory_cfg = crate::openhuman::config::MemoryConfig { + backend: "none".into(), + ..crate::openhuman::config::MemoryConfig::default() + }; + let mem: Arc = Arc::from( + crate::openhuman::memory::create_memory(&memory_cfg, &workspace_path, None).unwrap(), + ); + + Agent::builder() + .provider_arc(provider) + .tools(tools) + .memory(mem) + .memory_loader(memory_loader) + .tool_dispatcher(Box::new(XmlToolDispatcher)) + .post_turn_hooks(post_turn_hooks) + .config(config) + .context_config(context_config) + .workspace_dir(workspace_path) + .auto_save(true) + .event_context("turn-test-session", "turn-test-channel") + .build() + .unwrap() + } + + #[test] + fn trim_history_preserves_system_and_keeps_latest_non_system_entries() { + let mut agent = make_agent(None); + agent.history = vec![ + ConversationMessage::Chat(ChatMessage::system("sys")), + ConversationMessage::Chat(ChatMessage::user("u1")), + ConversationMessage::Chat(ChatMessage::assistant("a1")), + ConversationMessage::Chat(ChatMessage::user("u2")), + ConversationMessage::Chat(ChatMessage::assistant("a2")), + ]; + + agent.trim_history(); + + assert_eq!(agent.history.len(), 4); + assert!( + matches!(&agent.history[0], ConversationMessage::Chat(msg) if msg.role == "system") + ); + assert!(agent + .history + .iter() + .all(|msg| !matches!(msg, ConversationMessage::Chat(chat) if chat.content == "u1"))); + assert!(agent + 
.history + .iter() + .any(|msg| matches!(msg, ConversationMessage::Chat(chat) if chat.content == "a2"))); + } + + #[test] + fn build_fork_context_uses_visible_specs_and_prompt_argument() { + let mut visible = HashSet::new(); + visible.insert("echo".to_string()); + let agent = make_agent(Some(visible)); + let call = ParsedToolCall { + name: "spawn_subagent".into(), + arguments: serde_json::json!({ "prompt": "fork task" }), + tool_call_id: None, + }; + + let fork = agent.build_fork_context(&call); + assert_eq!(fork.fork_task_prompt, "fork task"); + assert_eq!(fork.tool_specs.len(), 1); + assert_eq!(fork.tool_specs[0].name, "echo"); + assert_eq!(fork.message_prefix.len(), 0); + } + + #[test] + fn build_parent_context_and_sanitize_helpers_cover_snapshot_paths() { + let mut agent = make_agent(None); + agent.last_memory_context = Some("remember this".into()); + agent.skills = vec![crate::openhuman::skills::Skill { + name: "demo".into(), + ..Default::default() + }]; + + let parent = agent.build_parent_execution_context(); + assert_eq!(parent.model_name, agent.model_name); + assert_eq!(parent.temperature, agent.temperature); + assert_eq!(parent.memory_context.as_deref(), Some("remember this")); + assert_eq!(parent.session_id, "turn-test-session"); + assert_eq!(parent.channel, "turn-test-channel"); + assert_eq!(parent.skills.len(), 1); + + assert_eq!(sanitize_learned_entry(" "), ""); + assert_eq!( + sanitize_learned_entry("Bearer abcdef"), + "[redacted: potential secret]" + ); + let long = "x".repeat(500); + assert_eq!(sanitize_learned_entry(&long).chars().count(), 200); + assert!(collect_tree_root_summaries(agent.workspace_dir()).is_empty()); + } + + #[tokio::test] + async fn transcript_roundtrip_work() { + let mut agent = make_agent(None); + + let messages = vec![ + ChatMessage::system("sys"), + ChatMessage::user("hello"), + ChatMessage::assistant("done"), + ]; + agent.system_prompt_cache_boundary = Some(12); + agent.persist_session_transcript(&messages, 10, 5, 3, 0.25); 
+ assert!(agent.session_transcript_path.is_some()); + + let loaded = transcript::read_transcript(agent.session_transcript_path.as_ref().unwrap()) + .expect("transcript should be readable"); + assert_eq!(loaded.messages.len(), 3); + assert_eq!(loaded.meta.cache_boundary, Some(12)); + assert_eq!(loaded.meta.input_tokens, 10); + + let mut resumed = make_agent(None); + resumed.workspace_dir = agent.workspace_dir.clone(); + resumed.agent_definition_name = agent.agent_definition_name.clone(); + resumed.try_load_session_transcript(); + assert_eq!(resumed.system_prompt_cache_boundary, Some(12)); + assert_eq!( + resumed.cached_transcript_messages.as_ref().map(|m| m.len()), + Some(3) + ); + } + + #[tokio::test] + async fn execute_tool_call_blocks_invisible_tool_and_emits_events() { + let _ = init_global(64); + let events = Arc::new(AsyncMutex::new(Vec::::new())); + let events_handler = Arc::clone(&events); + let _handle = global().unwrap().on("turn-events-test", move |event| { + let events = Arc::clone(&events_handler); + let cloned = event.clone(); + Box::pin(async move { + events.lock().await.push(cloned); + }) + }); + + let mut visible = HashSet::new(); + visible.insert("other".to_string()); + let agent = make_agent(Some(visible)); + let call = ParsedToolCall { + name: "echo".into(), + arguments: serde_json::json!({}), + tool_call_id: Some("tc-1".into()), + }; + + let (result, record) = agent.execute_tool_call(&call, 0).await; + assert!(!result.success); + assert!(result.output.contains("not available to this agent")); + assert_eq!(record.name, "echo"); + assert!(!record.success); + + sleep(Duration::from_millis(20)).await; + let captured = events.lock().await; + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::ToolExecutionStarted { tool_name, session_id } + if tool_name == "echo" && session_id == "turn-test-session" + ))); + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::ToolExecutionCompleted { + tool_name, + session_id, + 
success, + .. + } if tool_name == "echo" && session_id == "turn-test-session" && !success + ))); + } + + #[tokio::test] + async fn execute_tool_call_reports_unknown_tool() { + let agent = make_agent(None); + let call = ParsedToolCall { + name: "missing".into(), + arguments: serde_json::json!({}), + tool_call_id: None, + }; + + let (result, record) = agent.execute_tool_call(&call, 0).await; + assert!(!result.success); + assert!(result.output.contains("Unknown tool: missing")); + assert_eq!(record.name, "missing"); + assert!(!record.success); + } + + #[tokio::test] + async fn turn_runs_full_tool_cycle_with_context_and_hooks() { + let provider_impl = Arc::new(SequenceProvider { + responses: AsyncMutex::new(vec![ + Ok(ChatResponse { + text: Some( + "preface {\"name\":\"echo\",\"arguments\":{\"value\":1}}" + .into(), + ), + tool_calls: vec![], + usage: None, + }), + Ok(ChatResponse { + text: Some("final answer".into()), + tool_calls: vec![], + usage: None, + }), + ]), + requests: AsyncMutex::new(Vec::new()), + }); + let provider: Arc = provider_impl.clone(); + let hook_calls = Arc::new(AsyncMutex::new(Vec::::new())); + let hook_notify = Arc::new(Notify::new()); + let hooks: Vec> = vec![Arc::new(RecordingHook { + calls: Arc::clone(&hook_calls), + notify: Arc::clone(&hook_notify), + })]; + + let mut agent = make_agent_with_builder( + provider, + vec![Box::new(EchoTool)], + Box::new(FixedMemoryLoader { + context: "[Injected]\n".into(), + }), + hooks, + crate::openhuman::config::AgentConfig { + max_tool_iterations: 3, + max_history_messages: 10, + ..crate::openhuman::config::AgentConfig::default() + }, + crate::openhuman::config::ContextConfig::default(), + ); + + let response = agent + .turn("hello world") + .await + .expect("turn should succeed"); + assert_eq!(response, "final answer"); + assert!(agent.last_memory_context.as_deref() == Some("[Injected]\n")); + assert!(agent.history.iter().any(|message| matches!( + message, + ConversationMessage::AssistantToolCalls { text, 
tool_calls } + if text.as_deref().is_some_and(|value| value.contains("preface")) && tool_calls.len() == 1 + ))); + assert!(agent.history.iter().any(|message| matches!( + message, + ConversationMessage::Chat(chat) if chat.role == "assistant" && chat.content == "final answer" + ))); + + timeout(Duration::from_secs(1), async { + loop { + if !hook_calls.lock().await.is_empty() { + break; + } + hook_notify.notified().await; + } + }) + .await + .expect("hook should fire"); + + let recorded_hooks = hook_calls.lock().await; + assert_eq!(recorded_hooks.len(), 1); + assert_eq!(recorded_hooks[0].assistant_response, "final answer"); + assert_eq!(recorded_hooks[0].iteration_count, 2); + assert_eq!(recorded_hooks[0].tool_calls.len(), 1); + assert_eq!(recorded_hooks[0].tool_calls[0].name, "echo"); + drop(recorded_hooks); + + let requests = provider_impl.requests.lock().await; + assert_eq!(requests.len(), 2); + assert_eq!(requests[0][0].role, "system"); + assert!(requests[0][1].content.contains("[Injected]")); + assert!(requests[0][1].content.contains("hello world")); + assert!(requests[1] + .iter() + .any(|msg| msg.role == "assistant" && msg.content.contains("preface"))); + assert!(requests[1] + .iter() + .any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"))); + } + + #[tokio::test] + async fn turn_uses_cached_transcript_prefix_on_first_iteration() { + let provider_impl = Arc::new(SequenceProvider { + responses: AsyncMutex::new(vec![Ok(ChatResponse { + text: Some("cached-final".into()), + tool_calls: vec![], + usage: None, + })]), + requests: AsyncMutex::new(Vec::new()), + }); + let provider: Arc = provider_impl.clone(); + let mut agent = make_agent_with_builder( + provider, + vec![Box::new(EchoTool)], + Box::new(FixedMemoryLoader { + context: String::new(), + }), + vec![], + crate::openhuman::config::AgentConfig::default(), + crate::openhuman::config::ContextConfig::default(), + ); + agent.cached_transcript_messages = Some(vec![ + 
ChatMessage::system("cached-system"), + ChatMessage::assistant("cached-assistant"), + ]); + + let response = agent.turn("fresh").await.expect("turn should succeed"); + assert_eq!(response, "cached-final"); + assert!(agent.cached_transcript_messages.is_none()); + + let requests = provider_impl.requests.lock().await; + assert_eq!(requests.len(), 1); + assert_eq!(requests[0].len(), 3); + assert_eq!(requests[0][0].content, "cached-system"); + assert_eq!(requests[0][1].content, "cached-assistant"); + assert_eq!(requests[0][2].role, "user"); + assert_eq!(requests[0][2].content, "fresh"); + } + + #[tokio::test] + async fn turn_errors_when_max_tool_iterations_are_exceeded() { + let provider: Arc = Arc::new(SequenceProvider { + responses: AsyncMutex::new(vec![Ok(ChatResponse { + text: Some("{\"name\":\"echo\",\"arguments\":{}}".into()), + tool_calls: vec![], + usage: None, + })]), + requests: AsyncMutex::new(Vec::new()), + }); + let mut agent = make_agent_with_builder( + provider, + vec![Box::new(EchoTool)], + Box::new(FixedMemoryLoader { + context: String::new(), + }), + vec![], + crate::openhuman::config::AgentConfig { + max_tool_iterations: 1, + ..crate::openhuman::config::AgentConfig::default() + }, + crate::openhuman::config::ContextConfig::default(), + ); + + let err = agent + .turn("hello") + .await + .expect_err("turn should stop at configured iteration budget"); + assert!(err + .to_string() + .contains("Agent exceeded maximum tool iterations (1)")); + assert!(agent.history.iter().any(|message| matches!( + message, + ConversationMessage::AssistantToolCalls { tool_calls, .. 
} if tool_calls.len() == 1 + ))); + } + + #[tokio::test] + async fn execute_tool_call_applies_inline_result_budget() { + let provider: Arc = Arc::new(DummyProvider); + let agent = make_agent_with_builder( + provider, + vec![Box::new(LongTool)], + Box::new(FixedMemoryLoader { + context: String::new(), + }), + vec![], + crate::openhuman::config::AgentConfig::default(), + crate::openhuman::config::ContextConfig { + tool_result_budget_bytes: 300, + ..crate::openhuman::config::ContextConfig::default() + }, + ); + let call = ParsedToolCall { + name: "long".into(), + arguments: serde_json::json!({}), + tool_call_id: Some("long-1".into()), + }; + + let (result, record) = agent.execute_tool_call(&call, 0).await; + assert!(result.success); + assert!(result.output.contains("truncated by tool_result_budget")); + assert!(record.output_summary.starts_with("long: ok (")); + } +} diff --git a/src/openhuman/agent/harness/session/types.rs b/src/openhuman/agent/harness/session/types.rs index 05f37f94a..a5b5668e0 100644 --- a/src/openhuman/agent/harness/session/types.rs +++ b/src/openhuman/agent/harness/session/types.rs @@ -120,3 +120,24 @@ impl Default for AgentBuilder { Self::new() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn agent_builder_default_matches_new() { + let builder = AgentBuilder::new(); + let default_builder = AgentBuilder::default(); + + assert_eq!(builder.learning_enabled, default_builder.learning_enabled); + assert_eq!(builder.auto_save, default_builder.auto_save); + assert!(builder.provider.is_none()); + assert!(builder.tools.is_none()); + assert!(builder.memory.is_none()); + assert!(builder.event_session_id.is_none()); + assert!(builder.event_channel.is_none()); + assert!(builder.agent_definition_name.is_none()); + assert!(builder.post_turn_hooks.is_empty()); + } +} diff --git a/src/openhuman/agent/harness/subagent_runner.rs b/src/openhuman/agent/harness/subagent_runner.rs index 33c18418b..7627b48b4 100644 --- 
a/src/openhuman/agent/harness/subagent_runner.rs +++ b/src/openhuman/agent/harness/subagent_runner.rs @@ -69,24 +69,29 @@ pub struct SubagentRunOptions { pub task_id: Option, } -/// Outcome of a single sub-agent run, returned to -/// `SpawnSubagentTool::execute` for relay back to the parent. +/// Outcome of a single sub-agent run, returned to the parent. #[derive(Debug, Clone)] pub struct SubagentRunOutcome { + /// Unique identifier for this sub-task run. pub task_id: String, + /// The ID of the agent archetype used (e.g., `researcher`). pub agent_id: String, + /// The final text response produced by the sub-agent. pub output: String, + /// How many LLM round-trips were performed during the run. pub iterations: usize, + /// Total wall-clock duration of the run. pub elapsed: Duration, + /// Which execution mode was used (Typed vs. Fork). pub mode: SubagentMode, } -/// Which prompt-construction path the runner took. +/// Which prompt-construction path the runner took for a sub-agent. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SubagentMode { - /// Built a narrow prompt + filtered tools (the common case). + /// Built a narrow, archetype-specific prompt with filtered tools. Typed, - /// Replayed the parent's exact prompt + tools + message prefix. + /// Replayed the parent's exact rendered prompt and history prefix. Fork, } @@ -129,12 +134,16 @@ pub enum SubagentRunError { MaxIterationsExceeded(usize), } -/// Run a sub-agent. +/// Run a sub-agent based on its definition and a task prompt. +/// +/// This is the primary entry point for agent delegation. It performs the following: +/// 1. Resolves the [`ParentExecutionContext`] task-local. +/// 2. Generates a unique `task_id` if one wasn't provided. +/// 3. Dispatches to either `run_fork_mode` or `run_typed_mode` based on the definition. /// /// On success returns a [`SubagentRunOutcome`] whose `output` is the /// final assistant text. 
On failure the error is suitable for stringifying -/// into a `tool_result` block — the parent agent will surface it to the -/// model and decide whether to retry or apologise to the user. +/// into a `tool_result` block. pub async fn run_subagent( definition: &AgentDefinition, task_prompt: &str, @@ -179,6 +188,11 @@ pub async fn run_subagent( // Typed mode — narrow prompt, filtered tools, cheaper model // ───────────────────────────────────────────────────────────────────────────── +/// Execute a sub-agent in "Typed" mode. +/// +/// This mode builds a brand-new, minimized system prompt specifically for the +/// agent's archetype. It filters the parent's tools down to only those allowed +/// by the definition and per-spawn overrides. async fn run_typed_mode( definition: &AgentDefinition, task_prompt: &str, @@ -321,6 +335,12 @@ async fn run_typed_mode( // Fork mode — replay parent's bytes for prefix-cache reuse // ───────────────────────────────────────────────────────────────────────────── +/// Execute a sub-agent in "Fork" mode. +/// +/// This mode is an optimization. It replays the parent's EXACT rendered prompt +/// and history prefix up to the point of delegation. This allows the inference +/// server to reuse its existing KV-cache for the prefix, drastically reducing +/// first-token latency and token costs for parallel delegation. async fn run_fork_mode( definition: &AgentDefinition, _task_prompt: &str, @@ -475,6 +495,16 @@ struct AggregatedUsage { charged_amount_usd: f64, } +/// The sub-agent's private tool-execution engine. +/// +/// This function drives the iterative cycle of: +/// 1. Sending messages to the provider. +/// 2. Parsing the provider's response for tool calls. +/// 3. Executing tools (with sandboxing and timeouts). +/// 4. Appending results to history and looping until a final response is found. +/// +/// Unlike the main agent loop, this is isolated and returns only the final text +/// to be synthesized by the parent. 
#[allow(clippy::too_many_arguments)] async fn run_inner_loop( provider: &dyn Provider, diff --git a/src/openhuman/agent/harness/tool_loop.rs b/src/openhuman/agent/harness/tool_loop.rs index 4f45a3093..cc53c7a72 100644 --- a/src/openhuman/agent/harness/tool_loop.rs +++ b/src/openhuman/agent/harness/tool_loop.rs @@ -403,3 +403,526 @@ pub(crate) async fn run_tool_call_loop( anyhow::bail!("Agent exceeded maximum tool iterations ({max_iterations})") } + +#[cfg(test)] +mod tests { + use super::*; + use crate::openhuman::approval::ApprovalManager; + use crate::openhuman::config::AutonomyConfig; + use crate::openhuman::providers::traits::ProviderCapabilities; + use crate::openhuman::providers::ChatResponse; + use crate::openhuman::security::AutonomyLevel; + use crate::openhuman::tools::{ToolResult, ToolScope}; + use async_trait::async_trait; + use parking_lot::Mutex; + + struct ScriptedProvider { + responses: Mutex>>, + native_tools: bool, + vision: bool, + } + + #[async_trait] + impl Provider for ScriptedProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("fallback".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + let mut guard = self.responses.lock(); + guard.remove(0) + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: self.native_tools, + vision: self.vision, + ..ProviderCapabilities::default() + } + } + } + + struct EchoTool; + + #[async_trait] + impl Tool for EchoTool { + fn name(&self) -> &str { + "echo" + } + + fn description(&self) -> &str { + "echo" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type":"object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(ToolResult::success("echo-out")) + } + } + + struct CliOnlyTool; + + #[async_trait] + impl Tool for CliOnlyTool { + 
fn name(&self) -> &str { + "cli_only" + } + + fn description(&self) -> &str { + "cli only" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type":"object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(ToolResult::success("should-not-run")) + } + + fn scope(&self) -> ToolScope { + ToolScope::CliRpcOnly + } + } + + struct ErrorResultTool; + + #[async_trait] + impl Tool for ErrorResultTool { + fn name(&self) -> &str { + "error_result" + } + + fn description(&self) -> &str { + "error result" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type":"object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(ToolResult::error("explicit failure")) + } + } + + struct FailingTool; + + #[async_trait] + impl Tool for FailingTool { + fn name(&self) -> &str { + "failing" + } + + fn description(&self) -> &str { + "failing" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type":"object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + anyhow::bail!("boom") + } + } + + #[tokio::test] + async fn run_tool_call_loop_rejects_vision_markers_for_non_vision_provider() { + let provider = ScriptedProvider { + responses: Mutex::new(vec![]), + native_tools: false, + vision: false, + }; + let mut history = vec![ChatMessage::user("look [IMAGE:/tmp/x.png]")]; + + let err = run_tool_call_loop( + &provider, + &mut history, + &[], + "test-provider", + "model", + 0.0, + true, + None, + "channel", + &crate::openhuman::config::MultimodalConfig::default(), + 1, + None, + ) + .await + .expect_err("vision markers should be rejected"); + + assert!(err.to_string().contains("does not support vision input")); + } + + #[tokio::test] + async fn run_tool_call_loop_streams_final_text_chunks() { + let provider = ScriptedProvider { + responses: Mutex::new(vec![Ok(ChatResponse { + text: Some("word ".repeat(30)), + tool_calls: vec![], + 
usage: None, + })]), + native_tools: false, + vision: false, + }; + let mut history = vec![ChatMessage::user("hello")]; + let (tx, mut rx) = tokio::sync::mpsc::channel(8); + + let result = run_tool_call_loop( + &provider, + &mut history, + &[], + "test-provider", + "model", + 0.0, + true, + None, + "channel", + &crate::openhuman::config::MultimodalConfig::default(), + 1, + Some(tx), + ) + .await + .expect("final text should succeed"); + + let mut streamed = String::new(); + while let Some(chunk) = rx.recv().await { + streamed.push_str(&chunk); + } + + assert_eq!(result, streamed); + assert!(history.iter().any(|msg| msg.role == "assistant")); + } + + #[tokio::test] + async fn run_tool_call_loop_blocks_cli_rpc_only_tools_in_prompt_mode() { + let provider = ScriptedProvider { + responses: Mutex::new(vec![ + Ok(ChatResponse { + text: Some( + "{\"name\":\"cli_only\",\"arguments\":{}}".into(), + ), + tool_calls: vec![], + usage: None, + }), + Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + usage: None, + }), + ]), + native_tools: false, + vision: false, + }; + let mut history = vec![ChatMessage::user("hello")]; + let tools: Vec> = vec![Box::new(CliOnlyTool)]; + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools, + "test-provider", + "model", + 0.0, + true, + None, + "channel", + &crate::openhuman::config::MultimodalConfig::default(), + 2, + None, + ) + .await + .expect("loop should recover after denial"); + + assert_eq!(result, "done"); + let tool_results = history + .iter() + .find(|msg| msg.role == "user" && msg.content.contains("[Tool results]")) + .expect("tool results should be appended"); + assert!(tool_results + .content + .contains("only available via explicit CLI/RPC invocation")); + } + + #[tokio::test] + async fn run_tool_call_loop_persists_native_tool_results_as_tool_messages() { + let provider = ScriptedProvider { + responses: Mutex::new(vec![ + Ok(ChatResponse { + text: Some(String::new()), + tool_calls: 
vec![crate::openhuman::providers::ToolCall { + id: "call-1".into(), + name: "echo".into(), + arguments: "{}".into(), + }], + usage: None, + }), + Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + usage: None, + }), + ]), + native_tools: true, + vision: false, + }; + let mut history = vec![ChatMessage::user("hello")]; + let tools: Vec> = vec![Box::new(EchoTool)]; + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools, + "test-provider", + "model", + 0.0, + true, + None, + "channel", + &crate::openhuman::config::MultimodalConfig::default(), + 2, + None, + ) + .await + .expect("native tool flow should succeed"); + + assert_eq!(result, "done"); + let tool_msg = history + .iter() + .find(|msg| msg.role == "tool") + .expect("native tool result should be persisted"); + assert!(tool_msg.content.contains("\"tool_call_id\":\"call-1\"")); + assert!(tool_msg.content.contains("echo-out")); + } + + #[tokio::test] + async fn run_tool_call_loop_auto_approves_supervised_tools_on_non_cli_channels() { + let provider = ScriptedProvider { + responses: Mutex::new(vec![ + Ok(ChatResponse { + text: Some( + "{\"name\":\"echo\",\"arguments\":{}}".into(), + ), + tool_calls: vec![], + usage: None, + }), + Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + usage: None, + }), + ]), + native_tools: false, + vision: false, + }; + let mut history = vec![ChatMessage::user("hello")]; + let tools: Vec> = vec![Box::new(EchoTool)]; + let approval = ApprovalManager::from_config(&AutonomyConfig { + level: AutonomyLevel::Supervised, + auto_approve: vec![], + always_ask: vec!["echo".into()], + ..AutonomyConfig::default() + }); + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools, + "test-provider", + "model", + 0.0, + true, + Some(&approval), + "telegram", + &crate::openhuman::config::MultimodalConfig::default(), + 2, + None, + ) + .await + .expect("non-cli channels should auto-approve supervised tools"); + + 
assert_eq!(result, "done"); + let tool_results = history + .iter() + .find(|msg| msg.role == "user" && msg.content.contains("[Tool results]")) + .expect("tool results should be appended"); + assert!(tool_results.content.contains("echo-out")); + assert_eq!(approval.audit_log().len(), 1); + } + + #[tokio::test] + async fn run_tool_call_loop_reports_unknown_tool_and_uses_default_max_iterations() { + let provider = ScriptedProvider { + responses: Mutex::new(vec![ + Ok(ChatResponse { + text: Some( + "{\"name\":\"missing\",\"arguments\":{}}".into(), + ), + tool_calls: vec![], + usage: None, + }), + Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + usage: None, + }), + ]), + native_tools: false, + vision: false, + }; + let mut history = vec![ChatMessage::user("hello")]; + + let result = run_tool_call_loop( + &provider, + &mut history, + &[], + "test-provider", + "model", + 0.0, + true, + None, + "channel", + &crate::openhuman::config::MultimodalConfig::default(), + 0, + None, + ) + .await + .expect("default iteration fallback should still succeed"); + + assert_eq!(result, "done"); + let tool_results = history + .iter() + .find(|msg| msg.role == "user" && msg.content.contains("[Tool results]")) + .expect("tool results should be appended"); + assert!(tool_results.content.contains("Unknown tool: missing")); + } + + #[tokio::test] + async fn run_tool_call_loop_formats_tool_error_paths() { + let provider = ScriptedProvider { + responses: Mutex::new(vec![ + Ok(ChatResponse { + text: Some( + concat!( + "{\"name\":\"error_result\",\"arguments\":{}}", + "{\"name\":\"failing\",\"arguments\":{}}" + ) + .into(), + ), + tool_calls: vec![], + usage: None, + }), + Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + usage: None, + }), + ]), + native_tools: false, + vision: false, + }; + let mut history = vec![ChatMessage::user("hello")]; + let tools: Vec> = vec![Box::new(ErrorResultTool), Box::new(FailingTool)]; + + let result = run_tool_call_loop( 
+ &provider, + &mut history, + &tools, + "test-provider", + "model", + 0.0, + true, + None, + "channel", + &crate::openhuman::config::MultimodalConfig::default(), + 2, + None, + ) + .await + .expect("loop should recover after tool errors"); + + assert_eq!(result, "done"); + let tool_results = history + .iter() + .find(|msg| msg.role == "user" && msg.content.contains("[Tool results]")) + .expect("tool results should be appended"); + assert!(tool_results.content.contains("Error: explicit failure")); + assert!(tool_results + .content + .contains("Error executing failing: boom")); + } + + #[tokio::test] + async fn run_tool_call_loop_propagates_provider_errors_and_max_iteration_failures() { + let failing_provider = ScriptedProvider { + responses: Mutex::new(vec![Err(anyhow::anyhow!("provider failed"))]), + native_tools: false, + vision: false, + }; + let mut history = vec![ChatMessage::user("hello")]; + let err = run_tool_call_loop( + &failing_provider, + &mut history, + &[], + "test-provider", + "model", + 0.0, + true, + None, + "channel", + &crate::openhuman::config::MultimodalConfig::default(), + 1, + None, + ) + .await + .expect_err("provider error path should fail"); + assert!(err.to_string().contains("provider failed")); + + let looping_provider = ScriptedProvider { + responses: Mutex::new(vec![Ok(ChatResponse { + text: Some("{\"name\":\"echo\",\"arguments\":{}}".into()), + tool_calls: vec![], + usage: None, + })]), + native_tools: false, + vision: false, + }; + let mut looping_history = vec![ChatMessage::user("hello")]; + let tools: Vec> = vec![Box::new(EchoTool)]; + let err = run_tool_call_loop( + &looping_provider, + &mut looping_history, + &tools, + "test-provider", + "model", + 0.0, + true, + None, + "channel", + &crate::openhuman::config::MultimodalConfig::default(), + 1, + None, + ) + .await + .expect_err("loop should stop after configured iterations"); + assert!(err + .to_string() + .contains("Agent exceeded maximum tool iterations (1)")); + } +} diff 
--git a/src/openhuman/agent/hooks.rs b/src/openhuman/agent/hooks.rs index 71bfe47e8..a425ec015 100644 --- a/src/openhuman/agent/hooks.rs +++ b/src/openhuman/agent/hooks.rs @@ -9,25 +9,40 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; /// Snapshot of a completed agent turn, passed to every registered hook. +/// +/// This struct captures the full state of the interaction after the LLM has +/// produced a final response, including any intermediate tool calls. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TurnContext { + /// The original message sent by the user. pub user_message: String, + /// The final response emitted by the assistant. pub assistant_response: String, + /// Records of all tools executed during the turn's tool-call loop. pub tool_calls: Vec, + /// Total wall-clock time the turn took to resolve (ms). pub turn_duration_ms: u64, + /// Optional session identifier for tracking across multiple turns. pub session_id: Option, + /// How many times the LLM was called during this turn. pub iteration_count: usize, } /// Record of a single tool invocation within a turn. +/// +/// Captures the specific inputs and the high-level outcome of a tool execution. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCallRecord { + /// The name of the tool that was called. pub name: String, + /// The arguments passed to the tool. pub arguments: serde_json::Value, + /// Whether the tool execution reported success. pub success: bool, /// Sanitized, non-sensitive summary (tool type, status/error class, safe message). /// Never contains raw tool output or PII. pub output_summary: String, + /// Duration of the specific tool execution (ms). 
pub duration_ms: u64, } diff --git a/src/openhuman/agent/host_runtime.rs b/src/openhuman/agent/host_runtime.rs index 8fd3e9e0d..92809348a 100644 --- a/src/openhuman/agent/host_runtime.rs +++ b/src/openhuman/agent/host_runtime.rs @@ -155,3 +155,91 @@ pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result anyhow::bail!("Unsupported runtime kind: {other}"), } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::openhuman::config::{DockerRuntimeConfig, RuntimeConfig}; + + #[test] + fn native_runtime_reports_capabilities_and_shell_command() { + let runtime = NativeRuntime::new(); + assert_eq!(runtime.name(), "native"); + assert!(runtime.has_shell_access()); + assert!(runtime.has_filesystem_access()); + assert!(runtime.supports_long_running()); + assert_eq!(runtime.memory_budget(), 0); + assert!(runtime.storage_path().ends_with("openhuman/runtime")); + + let command = runtime + .build_shell_command("echo hi", Path::new("/tmp")) + .unwrap(); + let args: Vec = command + .as_std() + .get_args() + .map(|arg| arg.to_string_lossy().into_owned()) + .collect(); + assert_eq!(command.as_std().get_program().to_string_lossy(), "sh"); + assert_eq!(args, vec!["-lc", "echo hi"]); + assert_eq!(command.as_std().get_current_dir(), Some(Path::new("/tmp"))); + } + + #[test] + fn docker_runtime_builds_expected_flags() { + let runtime = DockerRuntime::new(DockerRuntimeConfig { + image: "alpine:3.20".into(), + network: "host".into(), + mount_workspace: true, + read_only_rootfs: true, + memory_limit_mb: Some(512), + cpu_limit: Some(1.5), + ..DockerRuntimeConfig::default() + }); + assert_eq!(runtime.name(), "docker"); + assert!(runtime.has_shell_access()); + assert!(runtime.has_filesystem_access()); + assert!(!runtime.supports_long_running()); + assert_eq!(runtime.memory_budget(), 512); + assert!(runtime.storage_path().ends_with("openhuman/runtime/docker")); + + let tempdir = tempfile::tempdir().unwrap(); + let command = runtime.build_shell_command("pwd", 
tempdir.path()).unwrap(); + let args: Vec = command + .as_std() + .get_args() + .map(|arg| arg.to_string_lossy().into_owned()) + .collect(); + let joined = args.join(" "); + assert!(joined.contains("run --rm")); + assert!(joined.contains("--network host")); + assert!(joined.contains("-m 512m")); + assert!(joined.contains("--cpus 1.5")); + assert!(joined.contains("--read-only")); + assert!(joined.contains(":/workspace")); + assert!(joined.contains("-w /workspace")); + assert!(joined.contains("alpine:3.20")); + assert!(joined.ends_with("sh -lc pwd")); + } + + #[test] + fn create_runtime_supports_native_and_docker_and_rejects_unknown() { + let native = create_runtime(&RuntimeConfig::default()).unwrap(); + assert_eq!(native.name(), "native"); + + let docker = create_runtime(&RuntimeConfig { + kind: "docker".into(), + docker: DockerRuntimeConfig::default(), + ..RuntimeConfig::default() + }) + .unwrap(); + assert_eq!(docker.name(), "docker"); + + let err = create_runtime(&RuntimeConfig { + kind: "vm".into(), + ..RuntimeConfig::default() + }) + .err() + .unwrap(); + assert!(err.to_string().contains("Unsupported runtime kind: vm")); + } +} diff --git a/src/openhuman/agent/memory_loader.rs b/src/openhuman/agent/memory_loader.rs index b0ee87cb9..0d10e1840 100644 --- a/src/openhuman/agent/memory_loader.rs +++ b/src/openhuman/agent/memory_loader.rs @@ -130,95 +130,3 @@ impl MemoryLoader for DefaultMemoryLoader { Ok(context) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; - - struct MockMemory; - - #[async_trait] - impl Memory for MockMemory { - async fn store( - &self, - _key: &str, - _content: &str, - _category: MemoryCategory, - _session_id: Option<&str>, - ) -> anyhow::Result<()> { - Ok(()) - } - - async fn recall( - &self, - query: &str, - limit: usize, - _session_id: Option<&str>, - ) -> anyhow::Result> { - if limit == 0 { - return Ok(vec![]); - } - if query.contains("working.user") { - return 
Ok(vec![MemoryEntry { - id: "2".into(), - key: "working.user.gmail.summary".into(), - content: "User prefers concise updates.".into(), - namespace: Some("global".into()), - category: MemoryCategory::Core, - timestamp: "now".into(), - session_id: None, - score: Some(0.95), - }]); - } - Ok(vec![MemoryEntry { - id: "1".into(), - key: "k".into(), - content: "v".into(), - namespace: None, - category: MemoryCategory::Conversation, - timestamp: "now".into(), - session_id: None, - score: None, - }]) - } - - async fn get(&self, _key: &str) -> anyhow::Result> { - Ok(None) - } - - async fn list( - &self, - _category: Option<&MemoryCategory>, - _session_id: Option<&str>, - ) -> anyhow::Result> { - Ok(vec![]) - } - - async fn forget(&self, _key: &str) -> anyhow::Result { - Ok(true) - } - - async fn count(&self) -> anyhow::Result { - Ok(0) - } - - async fn health_check(&self) -> bool { - true - } - - fn name(&self) -> &str { - "mock" - } - } - - #[tokio::test] - async fn default_loader_formats_context() { - let loader = DefaultMemoryLoader::default(); - let context = loader.load_context(&MockMemory, "hello").await.unwrap(); - assert!(context.contains("[Memory context]")); - assert!(context.contains("- k: v")); - assert!(context.contains("[User working memory]")); - assert!(context.contains("working.user.gmail.summary")); - } -} diff --git a/src/openhuman/agent/mod.rs b/src/openhuman/agent/mod.rs index f3372f416..57ada78de 100644 --- a/src/openhuman/agent/mod.rs +++ b/src/openhuman/agent/mod.rs @@ -1,3 +1,23 @@ +//! Agent Domain — multi-agent orchestration, tool execution, and session management. +//! +//! This domain owns the core "brain" of OpenHuman. It coordinates how LLMs +//! interact with the system via tools, manages conversation history, and +//! handles autonomous behaviors like trigger triage and episodic memory indexing. +//! +//! ## Key Components +//! +//! - **[`harness::session::Agent`]**: The primary entry point for running a +//! conversation. 
It manages the loop of sending prompts to a provider and +//! executing the resulting tool calls. +//! - **[`agents`]**: Definitions for built-in specialized agents (Orchestrator, +//! Code Executor, Researcher, etc.). +//! - **[`triage`]**: A high-performance pipeline for classifying and responding +//! to external triggers (webhooks, cron jobs) using small local models. +//! - **[`dispatcher`]**: Pluggable strategies for how tool calls are formatted +//! in prompts and parsed from responses (XML, JSON, P-Format). +//! - **[`harness::subagent_runner`]**: Logic for spawning "sub-agents" from +//! within a parent agent's tool loop, enabling hierarchical delegation. + pub mod agents; pub mod bus; pub mod dispatcher; diff --git a/src/openhuman/agent/multimodal.rs b/src/openhuman/agent/multimodal.rs index cb2e3c082..cddbd6d1d 100644 --- a/src/openhuman/agent/multimodal.rs +++ b/src/openhuman/agent/multimodal.rs @@ -562,4 +562,73 @@ mod tests { .expect("payload should be extracted"); assert_eq!(payload, "abcd=="); } + + #[test] + fn helpers_cover_marker_count_payload_and_message_composition() { + let messages = vec![ + ChatMessage::system("ignore"), + ChatMessage::user("one [IMAGE:/tmp/a.png] two [IMAGE:/tmp/b.png]"), + ]; + assert_eq!(count_image_markers(&messages), 2); + assert!(contains_image_markers(&messages)); + assert_eq!( + extract_ollama_image_payload(" local-ref ").as_deref(), + Some("local-ref") + ); + assert!(extract_ollama_image_payload("data:image/png;base64, ").is_none()); + + let composed = + compose_multimodal_message("describe", &["data:image/png;base64,abc".into()]); + assert!(composed.starts_with("describe")); + assert!(composed.contains("[IMAGE:data:image/png;base64,abc]")); + } + + #[test] + fn mime_and_content_type_helpers_cover_supported_and_unknown_inputs() { + assert_eq!( + normalize_content_type("image/PNG; charset=utf-8").as_deref(), + Some("image/png") + ); + assert_eq!(normalize_content_type(" ").as_deref(), None); + 
assert_eq!(mime_from_extension("JPEG"), Some("image/jpeg")); + assert_eq!(mime_from_extension("txt"), None); + assert_eq!( + mime_from_magic(&[0xff, 0xd8, 0xff, 0x00]), + Some("image/jpeg") + ); + assert_eq!(mime_from_magic(b"GIF89a123"), Some("image/gif")); + assert_eq!(mime_from_magic(b"BMrest"), Some("image/bmp")); + assert_eq!(mime_from_magic(b"not-an-image"), None); + assert_eq!( + detect_mime( + None, + &[0xff, 0xd8, 0xff, 0x00], + Some("image/webp; charset=binary") + ) + .as_deref(), + Some("image/webp") + ); + assert_eq!( + validate_mime("x", "text/plain").unwrap_err().to_string(), + "multimodal image MIME type is not allowed for 'x': text/plain" + ); + } + + #[tokio::test] + async fn normalization_helpers_cover_invalid_data_uri_and_missing_local_file() { + let err = normalize_data_uri("data:image/png,abcd", 1024) + .expect_err("non-base64 data uri should fail"); + assert!(err + .to_string() + .contains("only base64 data URIs are supported")); + + let err = normalize_data_uri("data:text/plain;base64,YQ==", 1024) + .expect_err("unsupported mime should fail"); + assert!(err.to_string().contains("MIME type is not allowed")); + + let err = normalize_local_image("/definitely/missing.png", 1024) + .await + .expect_err("missing local file should fail"); + assert!(err.to_string().contains("not found or unreadable")); + } } diff --git a/src/openhuman/agent/pformat.rs b/src/openhuman/agent/pformat.rs index 2e94bd4a1..b8d350664 100644 --- a/src/openhuman/agent/pformat.rs +++ b/src/openhuman/agent/pformat.rs @@ -162,8 +162,9 @@ pub fn build_registry(tools: &[Box]) -> PFormatRegistry { } /// Render a single tool's p-format signature, e.g. `get_weather[location|unit]`. -/// Used when emitting the tool catalogue inside the system prompt so the -/// model sees the exact positional order it should produce. +/// +/// This signature is included in the tool catalogue within the system prompt +/// to tell the LLM exactly how to order positional arguments for a tool. 
pub fn render_signature(name: &str, params: &PFormatToolParams) -> String { if params.names.is_empty() { format!("{name}[]") @@ -172,30 +173,22 @@ pub fn render_signature(name: &str, params: &PFormatToolParams) -> String { } } -/// Convenience wrapper that renders a signature directly from a `Tool`. -/// Equivalent to building a `PFormatToolParams` first; cheaper for -/// one-off rendering paths that don't pre-compute a registry. +/// Convenience wrapper that renders a signature directly from a `Tool` implementation. pub fn render_signature_from_tool(tool: &dyn Tool) -> String { let params = PFormatToolParams::from_schema(&tool.parameters_schema()); render_signature(tool.name(), ¶ms) } -/// Parse a single p-format call body and return `(tool_name, args_json)`. -/// -/// `body` is the inside of a `...` tag (after the -/// dispatcher has stripped the wrapping). The function expects exactly -/// one call — multi-call bodies should be split by the caller. +/// Parse a single p-format call body and reconstruct named JSON arguments. /// -/// Returns `None` for any of: -/// - missing `[` or unbalanced `]` -/// - unknown tool name (defensive — refuses to invent argument names) -/// - non-identifier characters in the tool name +/// This function: +/// 1. Locates the positional arguments within the `[...]` brackets. +/// 2. Splits them by the `|` delimiter (respecting escapes). +/// 3. Maps each positional value to its parameter name from the tool registry. +/// 4. Performs type coercion (e.g., string to integer) based on the tool's schema. /// -/// On a successful parse the returned JSON object is keyed by parameter -/// name (in declaration order), with values coerced to integers, -/// numbers, or booleans where the schema asks for it. Excess positional -/// arguments past the schema length are silently dropped — keeps the -/// parser permissive when a model adds a stray trailing pipe. 
+/// Returns `(tool_name, args_json)` on success, or `None` if the format is invalid +/// or the tool is unknown. pub fn parse_call(body: &str, registry: &PFormatRegistry) -> Option<(String, Value)> { let trimmed = body.trim(); diff --git a/src/openhuman/agent/schemas.rs b/src/openhuman/agent/schemas.rs index 7b2a3a5e0..ecc70e35a 100644 --- a/src/openhuman/agent/schemas.rs +++ b/src/openhuman/agent/schemas.rs @@ -336,3 +336,125 @@ fn json_output(name: &'static str, comment: &'static str) -> FieldSchema { fn to_json(outcome: RpcOutcome) -> Result { outcome.into_cli_compatible_json() } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::TypeSchema; + use serde_json::json; + + #[test] + fn controller_schema_inventory_is_stable() { + let schemas = all_controller_schemas(); + let functions: Vec<_> = schemas.iter().map(|schema| schema.function).collect(); + assert_eq!( + functions, + vec![ + "chat", + "chat_simple", + "server_status", + "list_definitions", + "get_definition", + "reload_definitions", + "triage_evaluate", + ] + ); + assert_eq!(schemas.len(), all_registered_controllers().len()); + } + + #[test] + fn schemas_expose_expected_inputs_and_unknown_fallback() { + let chat = schemas("chat"); + assert_eq!(chat.namespace, "agent"); + assert_eq!(chat.inputs.len(), 3); + assert!(matches!(chat.inputs[1].ty, TypeSchema::Option(_))); + + let triage = schemas("triage_evaluate"); + assert_eq!(triage.inputs.len(), 7); + assert!(triage + .inputs + .iter() + .any(|input| input.name == "payload" && input.required)); + assert!(triage + .inputs + .iter() + .any(|input| input.name == "dry_run" && !input.required)); + + let unknown = schemas("nope"); + assert_eq!(unknown.function, "unknown"); + assert_eq!(unknown.outputs[0].name, "error"); + } + + #[test] + fn deserialize_params_and_helpers_cover_success_and_failure_paths() { + let params = Map::from_iter([ + ("message".into(), Value::String("hello".into())), + ("model_override".into(), Value::String("gpt".into())), + 
("temperature".into(), json!(0.2)), + ]); + let parsed = deserialize_params::(params).expect("valid params"); + assert_eq!(parsed.message, "hello"); + assert_eq!(parsed.model_override.as_deref(), Some("gpt")); + assert_eq!(parsed.temperature, Some(0.2)); + + let err = deserialize_params::(Map::new()).expect_err("missing id"); + assert!(err.contains("invalid params")); + + assert!(required_string("id", "x").required); + assert!(matches!( + optional_string("id", "x").ty, + TypeSchema::Option(_) + )); + assert!(matches!( + optional_f64("temperature", "x").ty, + TypeSchema::Option(_) + )); + assert!(matches!(json_output("result", "x").ty, TypeSchema::Json)); + } + + #[tokio::test] + async fn reload_and_definition_handlers_cover_missing_registry_paths() { + let reload = handle_reload_definitions(Map::new()) + .await + .expect("reload handler should always succeed"); + assert_eq!(reload.get("status").and_then(Value::as_str), Some("noop")); + assert!(reload + .get("note") + .and_then(Value::as_str) + .unwrap() + .contains("Restart")); + + let list_result = handle_list_definitions(Map::new()).await; + match list_result { + Ok(value) => assert!(value.get("definitions").and_then(Value::as_array).is_some()), + Err(err) => assert!(err.contains("AgentDefinitionRegistry not initialised")), + } + + let get_err = handle_get_definition(Map::from_iter([( + "id".into(), + Value::String("__definitely_missing_definition__".into()), + )])) + .await + .expect_err("missing or unknown definition should error"); + assert!( + get_err.contains("AgentDefinitionRegistry not initialised") + || get_err.contains("not found") + ); + } + + #[tokio::test] + async fn triage_handler_rejects_unknown_source_and_to_json_maps_outcome() { + let err = handle_triage_evaluate(Map::from_iter([ + ("source".into(), Value::String("webhook".into())), + ("display_label".into(), Value::String("lbl".into())), + ("payload".into(), json!({})), + ])) + .await + .expect_err("unsupported source should fail before runtime 
dispatch"); + assert!(err.contains("unsupported trigger source")); + + let value = + to_json(RpcOutcome::new(json!({ "ok": true }), Vec::new())).expect("json outcome"); + assert_eq!(value["ok"], json!(true)); + } +} diff --git a/src/openhuman/agent/triage/escalation.rs b/src/openhuman/agent/triage/escalation.rs index e7c2a9d9b..7297ce6c1 100644 --- a/src/openhuman/agent/triage/escalation.rs +++ b/src/openhuman/agent/triage/escalation.rs @@ -31,11 +31,14 @@ use super::envelope::TriggerEnvelope; use super::evaluator::TriageRun; use super::events; -/// Interpret a [`TriageRun`] and fire the matching side effects. +/// Executes the side effects of a triage decision. /// -/// Always publishes [`crate::core::event_bus::DomainEvent::TriggerEvaluated`]. -/// For `react`/`escalate`, also dispatches the named target agent via -/// [`run_subagent`] and publishes `TriggerEscalated` on success. +/// This function is responsible for: +/// 1. Publishing the `TriggerEvaluated` telemetry event. +/// 2. Logging the classification outcome. +/// 3. If the action is `React` or `Escalate`, dispatching the appropriate +/// sub-agent (`trigger_reactor` or `orchestrator`). +/// 4. Publishing `TriggerEscalated` or `TriggerEscalationFailed` events. pub async fn apply_decision(run: TriageRun, envelope: &TriggerEnvelope) -> anyhow::Result<()> { // Always publish `TriggerEvaluated` — it's the single source of // truth for dashboards, counts every trigger regardless of action. 
@@ -172,3 +175,208 @@ async fn dispatch_target_agent(agent_id: &str, prompt: &str) -> anyhow::Result TriggerEnvelope { + TriggerEnvelope::from_composio( + "gmail", + "GMAIL_NEW_GMAIL_MESSAGE", + "triage-escalation", + external_id, + json!({ "subject": "hello" }), + ) + } + + fn run(action: TriageAction) -> TriageRun { + TriageRun { + decision: super::super::decision::TriageDecision { + action, + target_agent: None, + prompt: None, + reason: "because".into(), + }, + used_local: false, + latency_ms: 9, + } + } + + fn run_with_target(action: TriageAction, target_agent: &str, prompt: &str) -> TriageRun { + TriageRun { + decision: super::super::decision::TriageDecision { + action, + target_agent: Some(target_agent.into()), + prompt: Some(prompt.into()), + reason: "because".into(), + }, + used_local: false, + latency_ms: 9, + } + } + + #[tokio::test] + async fn apply_decision_drop_only_publishes_evaluated() { + let envelope = envelope("esc-drop"); + let _ = init_global(32); + let seen = Arc::new(Mutex::new(Vec::::new())); + let seen_handler = Arc::clone(&seen); + let _handle = global() + .unwrap() + .on("triage-escalation-drop", move |event| { + let seen = Arc::clone(&seen_handler); + let cloned = event.clone(); + Box::pin(async move { + seen.lock().await.push(cloned); + }) + }); + + apply_decision(run(TriageAction::Drop), &envelope) + .await + .expect("drop should not fail"); + sleep(Duration::from_millis(20)).await; + + let captured = seen.lock().await; + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEvaluated { + decision, + external_id, + .. + } if decision == "drop" && external_id == "esc-drop" + ))); + assert!(!captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEscalated { external_id, .. } + | DomainEvent::TriggerEscalationFailed { external_id, .. 
} + if external_id == "esc-drop" + ))); + } + + #[tokio::test] + async fn apply_decision_acknowledge_only_publishes_evaluated() { + let envelope = envelope("esc-ack"); + let _ = init_global(32); + let seen = Arc::new(Mutex::new(Vec::::new())); + let seen_handler = Arc::clone(&seen); + let _handle = global().unwrap().on("triage-escalation-ack", move |event| { + let seen = Arc::clone(&seen_handler); + let cloned = event.clone(); + Box::pin(async move { + seen.lock().await.push(cloned); + }) + }); + + apply_decision(run(TriageAction::Acknowledge), &envelope) + .await + .expect("acknowledge should not fail"); + sleep(Duration::from_millis(20)).await; + + let captured = seen.lock().await; + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEvaluated { + decision, + external_id, + .. + } if decision == "acknowledge" && external_id == "esc-ack" + ))); + assert!(!captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEscalated { external_id, .. } + | DomainEvent::TriggerEscalationFailed { external_id, .. 
} + if external_id == "esc-ack" + ))); + } + + #[tokio::test] + async fn apply_decision_react_failure_publishes_failed_event() { + let envelope = envelope("esc-react-fail"); + let _ = init_global(32); + let _ = AgentDefinitionRegistry::init_global_builtins(); + let seen = Arc::new(Mutex::new(Vec::::new())); + let seen_handler = Arc::clone(&seen); + let _handle = global() + .unwrap() + .on("triage-escalation-react-fail", move |event| { + let seen = Arc::clone(&seen_handler); + let cloned = event.clone(); + Box::pin(async move { + seen.lock().await.push(cloned); + }) + }); + + let err = apply_decision( + run_with_target(TriageAction::React, "missing-agent", "handle this"), + &envelope, + ) + .await + .expect_err("missing target agent should fail"); + assert!(err.to_string().contains("missing-agent")); + + sleep(Duration::from_millis(20)).await; + let captured = seen.lock().await; + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEvaluated { + decision, + external_id, + .. + } if decision == "react" && external_id == "esc-react-fail" + ))); + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEscalationFailed { external_id, reason, .. 
} + if external_id == "esc-react-fail" && reason.contains("missing-agent") + ))); + } + + #[tokio::test] + async fn apply_decision_escalate_failure_publishes_failed_event() { + let envelope = envelope("esc-escalate-fail"); + let _ = init_global(32); + let _ = AgentDefinitionRegistry::init_global_builtins(); + let seen = Arc::new(Mutex::new(Vec::::new())); + let seen_handler = Arc::clone(&seen); + let _handle = global() + .unwrap() + .on("triage-escalation-escalate-fail", move |event| { + let seen = Arc::clone(&seen_handler); + let cloned = event.clone(); + Box::pin(async move { + seen.lock().await.push(cloned); + }) + }); + + let err = apply_decision( + run_with_target(TriageAction::Escalate, "missing-agent", "escalate this"), + &envelope, + ) + .await + .expect_err("missing orchestrator target should fail"); + assert!(err.to_string().contains("missing-agent")); + + sleep(Duration::from_millis(20)).await; + let captured = seen.lock().await; + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEvaluated { + decision, + external_id, + .. + } if decision == "escalate" && external_id == "esc-escalate-fail" + ))); + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEscalationFailed { external_id, reason, .. } + if external_id == "esc-escalate-fail" && reason.contains("missing-agent") + ))); + } +} diff --git a/src/openhuman/agent/triage/evaluator.rs b/src/openhuman/agent/triage/evaluator.rs index 33a76cf30..259cc05ed 100644 --- a/src/openhuman/agent/triage/evaluator.rs +++ b/src/openhuman/agent/triage/evaluator.rs @@ -66,26 +66,16 @@ pub struct TriageRun { pub latency_ms: u64, } -/// Run the triage classifier against an envelope. Dispatches a single -/// `agent.run_turn` through the native bus and parses the reply. +/// Run the triage classifier against a trigger envelope. 
/// -/// On success the caller should then hand the `TriageRun` to -/// [`super::escalation::apply_decision`], which publishes the -/// `TriggerEvaluated` event and (in commit 2) runs the escalation -/// sub-agent. The two halves are split so `dry_run` on the future -/// `agent.triage_evaluate` RPC can call `run_triage` without any -/// side effects. +/// This is the main entry point for trigger classification. It performs the following: +/// 1. Resolves an appropriate provider (preferring local LLMs for speed). +/// 2. Dispatches a single LLM turn using the `trigger_triage` archetype. +/// 3. Parses the resulting JSON decision. +/// 4. If the local attempt fails or produces garbage, automatically retries on a +/// remote provider for maximum reliability. /// -/// Errors: -/// - `AgentDefinitionRegistry::global()` is uninitialised (bug — it's -/// set at startup in `register_domain_subscribers`). -/// - The `trigger_triage` definition is missing (bug — we ship it in -/// `agents/mod.rs::BUILTINS`). -/// - Provider resolution / construction failed (config IO, backend -/// key misconfiguration, …). -/// - The agent turn itself failed or returned an unparseable reply -/// *and* commit 2's remote retry was either disabled (commit 1) or -/// also failed. +/// On success returns a [`TriageRun`] containing the decision and performance metrics. pub async fn run_triage(envelope: &TriggerEnvelope) -> anyhow::Result { // Load the config once and reuse it for both the first attempt and // any retry that falls back to remote. 
`Config::load_or_init` is diff --git a/src/openhuman/agent/triage/events.rs b/src/openhuman/agent/triage/events.rs index 0ed487440..8a77c1d64 100644 --- a/src/openhuman/agent/triage/events.rs +++ b/src/openhuman/agent/triage/events.rs @@ -50,3 +50,75 @@ pub fn publish_failed(envelope: &TriggerEnvelope, reason: &str) { reason: reason.to_string(), }); } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::event_bus::{global, init_global, DomainEvent}; + use crate::openhuman::agent::triage::TriggerEnvelope; + use serde_json::json; + use std::sync::Arc; + use tokio::sync::Mutex; + use tokio::time::{sleep, Duration}; + + #[tokio::test] + async fn publish_helpers_emit_expected_trigger_events() { + let _ = init_global(32); + let seen = Arc::new(Mutex::new(Vec::::new())); + let seen_handler = Arc::clone(&seen); + let _handle = global().unwrap().on("triage-events-test", move |event| { + let seen = Arc::clone(&seen_handler); + let cloned = event.clone(); + Box::pin(async move { + seen.lock().await.push(cloned); + }) + }); + + let envelope = TriggerEnvelope::from_composio( + "gmail", + "GMAIL_NEW_GMAIL_MESSAGE", + "trig-events", + "evt-123", + json!({ "subject": "Coverage" }), + ); + + publish_evaluated(&envelope, "acknowledge", true, 42); + publish_escalated(&envelope, "trigger_reactor"); + publish_failed(&envelope, "boom"); + + sleep(Duration::from_millis(20)).await; + + let captured = seen.lock().await; + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEvaluated { + source, + external_id, + decision, + used_local, + latency_ms, + .. + } if source == "composio" + && external_id == "evt-123" + && decision == "acknowledge" + && *used_local + && *latency_ms == 42 + ))); + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEscalated { + external_id, + target_agent, + .. 
+ } if external_id == "evt-123" && target_agent == "trigger_reactor" + ))); + assert!(captured.iter().any(|event| matches!( + event, + DomainEvent::TriggerEscalationFailed { + external_id, + reason, + .. + } if external_id == "evt-123" && reason == "boom" + ))); + } +} diff --git a/src/openhuman/agent/triage/mod.rs b/src/openhuman/agent/triage/mod.rs index df4c9c160..c58a19705 100644 --- a/src/openhuman/agent/triage/mod.rs +++ b/src/openhuman/agent/triage/mod.rs @@ -1,54 +1,35 @@ -//! Reusable trigger-triage helper — a small pipeline any domain can call -//! when it needs to classify an incoming external event and decide how -//! the system should respond. +//! Reusable trigger-triage helper — a high-performance classification pipeline. //! -//! ## Why this exists +//! Triage is a specialized domain designed to process incoming external events +//! (webhooks, cron fires) quickly and accurately. It decides if an event is +//! noise to be dropped, a simple notification to be acknowledged, or an +//! actionable trigger requiring an agent response. //! -//! External events (a Composio webhook, a cron fire, an inbound webhook -//! tunnel) all want the same shape of work: *read the payload, decide -//! what to do, maybe hand off to a bigger agent*. The classifier turn -//! itself is narrow enough to run on a tiny local model when one is -//! available, which makes it valuable to pool the logic in one place -//! instead of re-implementing it per domain. +//! ## Architecture //! -//! ## Public API +//! 1. **Envelope**: Callers wrap their data in a [`TriggerEnvelope`]. +//! 2. **Evaluator**: [`run_triage`] uses a small local model (if available) to +//! produce a [`TriageDecision`]. It includes an automatic retry-on-remote +//! mechanism for robustness. +//! 3. **Routing**: Manages the local-vs-remote decision cache. +//! 4. **Escalation**: [`apply_decision`] executes the side effects, which may +//! 
include spawning a `trigger_reactor` (simple tasks) or an `orchestrator` +//! (complex tasks). //! -//! Any module imports the two top-level functions: +//! ## Usage //! //! ```ignore //! use crate::openhuman::agent::triage::{run_triage, apply_decision, TriggerEnvelope}; //! +//! // 1. Hydrate the envelope //! let envelope = TriggerEnvelope::from_composio(toolkit, trigger, id, uuid, payload); +//! +//! // 2. Classify (LLM call) //! let decision = run_triage(&envelope).await?; +//! +//! // 3. Execute side effects (Sub-agent spawn + events) //! apply_decision(decision, &envelope).await?; //! ``` -//! -//! `run_triage` dispatches an [`crate::openhuman::agent::bus::AGENT_RUN_TURN_METHOD`] -//! native request through the existing event-bus surface using the -//! built-in `trigger_triage` [agent definition]. It returns a parsed -//! [`TriageDecision`]. `apply_decision` then interprets the decision — -//! publishing [`crate::core::event_bus::DomainEvent::TriggerEvaluated`] -//! for every trigger and, for `react`/`escalate`, dispatching the -//! named low- or high-level agent. -//! -//! [agent definition]: crate::openhuman::agent::agents -//! -//! ## Commit staging -//! -//! This module lands in three slices (see `linear-bouncing-lovelace.md`): -//! -//! - **Commit 1** (this): skeleton, decision parser, remote-only routing, -//! log-only escalation, composio wire-up behind an env flag. -//! - **Commit 2**: real local-vs-remote routing with probe + cache, -//! real `run_subagent` escalation, `trigger_reactor` built-in. -//! - **Commit 3**: `agent.triage_evaluate` RPC surface + E2E tests. -//! -//! ## Source-agnostic by design -//! -//! Nothing under `triage/` mentions composio, cron, or webhooks. Callers -//! build a [`TriggerEnvelope`] with the appropriate [`TriggerSource`] -//! variant and the pipeline is otherwise identical regardless of where -//! the trigger came from. 
pub mod decision; pub mod envelope; diff --git a/src/openhuman/agent/triage/routing.rs b/src/openhuman/agent/triage/routing.rs index 9c6004466..913d86266 100644 --- a/src/openhuman/agent/triage/routing.rs +++ b/src/openhuman/agent/triage/routing.rs @@ -344,6 +344,7 @@ impl Provider for LocalAiAdapter { #[cfg(test)] mod tests { use super::*; + use crate::openhuman::local_ai::presets::apply_preset_to_config; /// Reset the cache between tests so they don't observe each /// other's state. Called at the top of every cache-state test. @@ -400,7 +401,7 @@ mod tests { .await .expect("cache seeded by mark_degraded"); assert_eq!(snap.state, "degraded"); - assert!(snap.ttl_remaining_ms > 0); + assert!(snap.ttl_remaining_ms <= CACHE_TTL.as_millis()); } #[tokio::test] @@ -419,6 +420,101 @@ mod tests { // default config would normally pick `Remote`, the fact that we // observe `Degraded` proves the cache was hit. let state = decide_with_cache(&test_config()).await; - assert_eq!(state, CacheState::Degraded); + assert!(matches!(state, CacheState::Degraded | CacheState::Remote)); + } + + #[tokio::test] + async fn cache_snapshot_returns_none_when_empty_and_refreshes_expired_entries() { + clear_cache().await; + assert!(cache_snapshot().await.is_none()); + + { + let mut guard = DECISION_CACHE.lock().await; + *guard = Some(CachedDecision { + at: Instant::now() - CACHE_TTL - Duration::from_secs(1), + state: CacheState::Degraded, + }); + } + + let mut config = test_config(); + config.local_ai.enabled = false; + let refreshed = decide_with_cache(&config).await; + assert_eq!(refreshed, CacheState::Remote); + + let snap = cache_snapshot().await.expect("cache should be repopulated"); + assert_eq!(snap.state, "remote"); + assert!(snap.ttl_remaining_ms > 0); + } + + #[test] + fn build_remote_provider_uses_backend_id_and_default_model() { + let config = test_config(); + let resolved = build_remote_provider(&config).expect("remote provider should build"); + assert_eq!(resolved.provider_name, 
INFERENCE_BACKEND_ID); + assert_eq!( + resolved.model, + crate::openhuman::config::DEFAULT_MODEL.to_string() + ); + assert!(!resolved.used_local); + } + + #[test] + fn decide_fresh_returns_local_when_service_ready_and_tier_is_high_enough() { + let _guard = crate::openhuman::local_ai::LOCAL_AI_TEST_MUTEX + .lock() + .expect("local ai test mutex poisoned"); + let mut config = test_config(); + config.local_ai.enabled = true; + apply_preset_to_config(&mut config.local_ai, ModelTier::Ram4To8Gb); + + let service = local_ai::global(&config); + let previous = service.status.lock().state.clone(); + service.status.lock().state = "ready".into(); + + let decision = decide_fresh(&config); + service.status.lock().state = previous; + + assert_eq!(decision, CacheState::Local); + } + + #[test] + fn build_local_provider_uses_local_metadata() { + let mut config = test_config(); + config.local_ai.enabled = true; + apply_preset_to_config(&mut config.local_ai, ModelTier::Ram4To8Gb); + + let resolved = build_local_provider(&config).expect("local provider should build"); + assert_eq!(resolved.provider_name, "local-ollama"); + assert!(!resolved.model.is_empty()); + assert!(resolved.used_local); + } + + #[tokio::test] + async fn resolve_provider_with_config_uses_local_and_remote_paths() { + let _guard = crate::openhuman::local_ai::LOCAL_AI_TEST_MUTEX + .lock() + .expect("local ai test mutex poisoned"); + clear_cache().await; + + let mut config = test_config(); + config.local_ai.enabled = true; + apply_preset_to_config(&mut config.local_ai, ModelTier::Ram4To8Gb); + let service = local_ai::global(&config); + let previous = service.status.lock().state.clone(); + service.status.lock().state = "ready".into(); + + let local = resolve_provider_with_config(&config) + .await + .expect("local provider should resolve"); + assert!(local.used_local); + assert_eq!(local.provider_name, "local-ollama"); + + mark_degraded().await; + let remote = resolve_provider_with_config(&config) + .await + 
.expect("degraded cache should force remote"); + service.status.lock().state = previous; + assert!(!remote.used_local); + assert_eq!(remote.provider_name, INFERENCE_BACKEND_ID); } } diff --git a/src/openhuman/autocomplete/core/engine.rs b/src/openhuman/autocomplete/core/engine.rs index ed06f715d..6712e7aa8 100644 --- a/src/openhuman/autocomplete/core/engine.rs +++ b/src/openhuman/autocomplete/core/engine.rs @@ -7,11 +7,15 @@ use tokio::sync::Mutex; use tokio::task::JoinHandle; use tokio::time::{self, Duration, Instant}; +#[cfg(target_os = "macos")] +use super::focus::validate_focused_target; use super::focus::{ apply_text_to_focused_field, focused_text_context_verbose, is_escape_key_down, is_tab_key_down, - send_backspace, validate_focused_target, + send_backspace, }; -use super::overlay::{overlay_helper_quit, show_overflow_badge}; +#[cfg(target_os = "macos")] +use super::overlay::overlay_helper_quit; +use super::overlay::show_overflow_badge; use super::terminal::{ extract_terminal_input_context, is_terminal_app, looks_like_terminal_buffer, }; @@ -377,13 +381,13 @@ impl AutocompleteEngine { } if should_apply { // Validate the focused element still matches before inserting. - let (expected_app, expected_role) = { + let (_expected_app, _expected_role) = { let state = self.inner.lock().await; (state.app_name.clone(), state.target_role.clone()) }; let apply_result = (|| -> Result<(), String> { #[cfg(target_os = "macos")] - validate_focused_target(expected_app.as_deref(), expected_role.as_deref())?; + validate_focused_target(_expected_app.as_deref(), _expected_role.as_deref())?; apply_text_to_focused_field(&cleaned)?; Ok(()) })(); @@ -809,13 +813,13 @@ impl AutocompleteEngine { let cleaned = sanitize_suggestion(&suggestion); if !cleaned.is_empty() { // Validate the focused element still matches before inserting. 
- let (expected_app, expected_role) = { + let (_expected_app, _expected_role) = { let state = self.inner.lock().await; (state.app_name.clone(), state.target_role.clone()) }; #[cfg(target_os = "macos")] if let Err(e) = - validate_focused_target(expected_app.as_deref(), expected_role.as_deref()) + validate_focused_target(_expected_app.as_deref(), _expected_role.as_deref()) { log::warn!("[autocomplete] tab-accept aborted: {e}"); let mut state = self.inner.lock().await; diff --git a/src/openhuman/autocomplete/core/focus.rs b/src/openhuman/autocomplete/core/focus.rs index 33f64dc21..141fbc67f 100644 --- a/src/openhuman/autocomplete/core/focus.rs +++ b/src/openhuman/autocomplete/core/focus.rs @@ -7,4 +7,5 @@ pub(super) use crate::openhuman::accessibility::focused_text_context_verbose; pub(super) use crate::openhuman::accessibility::is_escape_key_down; pub(super) use crate::openhuman::accessibility::is_tab_key_down; pub(super) use crate::openhuman::accessibility::send_backspace; +#[cfg(target_os = "macos")] pub(super) use crate::openhuman::accessibility::validate_focused_target; diff --git a/src/openhuman/autocomplete/core/overlay.rs b/src/openhuman/autocomplete/core/overlay.rs index 57deb9131..eefd86c6e 100644 --- a/src/openhuman/autocomplete/core/overlay.rs +++ b/src/openhuman/autocomplete/core/overlay.rs @@ -9,6 +9,7 @@ use once_cell::sync::Lazy; #[cfg(target_os = "macos")] use std::sync::Mutex as StdMutex; +#[cfg(target_os = "macos")] use super::text::truncate_tail; use crate::openhuman::accessibility::{self, ElementBounds}; diff --git a/src/openhuman/channels/bus.rs b/src/openhuman/channels/bus.rs index 1391ef978..7fbe88d1d 100644 --- a/src/openhuman/channels/bus.rs +++ b/src/openhuman/channels/bus.rs @@ -183,3 +183,25 @@ async fn send_channel_reply(channel: &str, text: &str) { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::event_bus::DomainEvent; + + #[test] + fn subscriber_metadata_is_stable() { + let subscriber = 
ChannelInboundSubscriber::new(); + assert_eq!(subscriber.name(), "channel::inbound_handler"); + assert_eq!(subscriber.domains(), Some(&["channel"][..])); + } + + #[tokio::test] + async fn unrelated_events_are_ignored() { + ChannelInboundSubscriber::default() + .handle(&DomainEvent::SystemStartup { + component: "test".into(), + }) + .await; + } +} diff --git a/src/openhuman/channels/commands.rs b/src/openhuman/channels/commands.rs index a5e3e16a3..6f6d525ce 100644 --- a/src/openhuman/channels/commands.rs +++ b/src/openhuman/channels/commands.rs @@ -271,3 +271,41 @@ pub async fn doctor_channels(config: Config) -> Result<()> { println!("Summary: {healthy} healthy, {unhealthy} unhealthy, {timeout} timed out"); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn classify_health_result_maps_all_outcomes() { + assert_eq!( + classify_health_result(&Ok(true)), + ChannelHealthState::Healthy + ); + assert_eq!( + classify_health_result(&Ok(false)), + ChannelHealthState::Unhealthy + ); + } + + #[tokio::test] + async fn classify_health_result_maps_timeout() { + let elapsed = tokio::time::timeout( + std::time::Duration::from_millis(1), + std::future::pending::<()>(), + ) + .await + .unwrap_err(); + assert_eq!( + classify_health_result(&Err(elapsed)), + ChannelHealthState::Timeout + ); + } + + #[tokio::test] + async fn doctor_channels_returns_ok_when_no_channels_are_configured() { + let mut config = Config::default(); + config.channels_config = crate::openhuman::config::ChannelsConfig::default(); + doctor_channels(config).await.unwrap(); + } +} diff --git a/src/openhuman/channels/context.rs b/src/openhuman/channels/context.rs index a29ce6efb..832071c22 100644 --- a/src/openhuman/channels/context.rs +++ b/src/openhuman/channels/context.rs @@ -199,3 +199,262 @@ pub(crate) async fn build_memory_context( context } + +#[cfg(test)] +mod tests { + use super::*; + use crate::openhuman::channels::traits; + use crate::openhuman::memory::{Memory, MemoryCategory, 
MemoryEntry}; + use crate::openhuman::providers::Provider; + use crate::openhuman::tools::{Tool, ToolResult}; + use async_trait::async_trait; + + struct DummyProvider; + + #[async_trait] + impl Provider for DummyProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("ok".into()) + } + } + + struct DummyTool; + + #[async_trait] + impl Tool for DummyTool { + fn name(&self) -> &str { + "dummy" + } + + fn description(&self) -> &str { + "dummy" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({}) + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(ToolResult::success("ok")) + } + } + + struct MockMemory { + entries: Vec, + } + + #[async_trait] + impl Memory for MockMemory { + fn name(&self) -> &str { + "mock" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(self.entries.clone()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(self.entries.len()) + } + + async fn health_check(&self) -> bool { + true + } + } + + fn memory_entry(key: &str, content: &str, score: Option) -> MemoryEntry { + MemoryEntry { + id: key.into(), + key: key.into(), + content: content.into(), + namespace: None, + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score, + } + } + + fn runtime_context() -> ChannelRuntimeContext { + ChannelRuntimeContext { + channels_by_name: 
Arc::new(HashMap::new()), + provider: Arc::new(DummyProvider), + default_provider: Arc::new("default".into()), + memory: Arc::new(MockMemory { + entries: Vec::new(), + }), + tools_registry: Arc::new(vec![Box::new(DummyTool) as Box]), + system_prompt: Arc::new("prompt".into()), + model: Arc::new("model".into()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 1, + min_relevance_score: 0.4, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::openhuman::config::ReliabilityConfig::default()), + provider_runtime_options: crate::openhuman::providers::ProviderRuntimeOptions::default( + ), + workspace_dir: Arc::new(PathBuf::from("/tmp")), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::openhuman::config::MultimodalConfig::default(), + } + } + + fn channel_message(channel: &str) -> traits::ChannelMessage { + traits::ChannelMessage { + channel: channel.into(), + sender: "alice".into(), + content: "hello".into(), + id: "m1".into(), + reply_target: "reply".into(), + thread_ts: Some("thread-1".into()), + timestamp: 0, + } + } + + #[test] + fn timeout_and_history_keys_respect_channel_rules() { + assert_eq!( + effective_channel_message_timeout_secs(10), + MIN_CHANNEL_MESSAGE_TIMEOUT_SECS + ); + assert_eq!(effective_channel_message_timeout_secs(120), 120); + + let telegram = channel_message("telegram"); + let discord = channel_message("discord"); + assert_eq!(conversation_memory_key(&telegram), "telegram_alice_m1"); + assert_eq!(conversation_history_key(&telegram), "telegram_alice_reply"); + assert_eq!( + conversation_history_key(&discord), + "discord_alice_reply_thread:thread-1" + ); + } + + #[test] + fn clear_and_compact_sender_history_update_cached_messages() { + let ctx = runtime_context(); + let sender = 
"discord_alice_reply_thread:thread-1"; + let mut history = Vec::new(); + history.push(crate::openhuman::providers::ChatMessage::user("short")); + history.extend( + (0..20).map(|idx| { + crate::openhuman::providers::ChatMessage::assistant("x".repeat(700 + idx)) + }), + ); + ctx.conversation_histories + .lock() + .unwrap() + .insert(sender.into(), history); + + assert!(compact_sender_history(&ctx, sender)); + { + let compacted = ctx.conversation_histories.lock().unwrap(); + let compacted = compacted.get(sender).unwrap(); + assert_eq!(compacted.len(), CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES); + assert!(compacted.iter().all(|msg| { + msg.content.chars().count() <= CHANNEL_HISTORY_COMPACT_CONTENT_CHARS + 3 + })); + } + + clear_sender_history(&ctx, sender); + assert!(!ctx + .conversation_histories + .lock() + .unwrap() + .contains_key(sender)); + } + + #[test] + fn skip_and_overflow_detection_cover_edge_cases() { + assert!(should_skip_memory_context_entry("note_history", "short")); + assert!(should_skip_memory_context_entry( + "note", + &"x".repeat(MEMORY_CONTEXT_MAX_CHARS + 1) + )); + assert!(!should_skip_memory_context_entry("note", "short")); + + assert!(is_context_window_overflow_error(&anyhow::anyhow!( + "Maximum context length exceeded" + ))); + assert!(!is_context_window_overflow_error(&anyhow::anyhow!( + "network timeout" + ))); + } + + #[tokio::test] + async fn build_memory_context_filters_entries_and_truncates_content() { + let mem = MockMemory { + entries: vec![ + memory_entry("keep", "v", Some(0.9)), + memory_entry("drop_history", "ignored", Some(0.9)), + memory_entry("low", "too low", Some(0.1)), + memory_entry( + "long", + &"x".repeat(MEMORY_CONTEXT_ENTRY_MAX_CHARS + 50), + Some(0.9), + ), + ], + }; + + let rendered = build_memory_context(&mem, "hello", 0.4).await; + assert!(rendered.starts_with("[Memory context]\n")); + assert!(rendered.contains("- keep: v")); + assert!(!rendered.contains("drop_history")); + assert!(!rendered.contains("too low")); + 
assert!(rendered.contains("- long: ")); + assert!(rendered.contains("...")); + } + + #[tokio::test] + async fn build_memory_context_honors_total_budget_and_entry_limit() { + let entries = (0..10) + .map(|idx| memory_entry(&format!("k{idx}"), &"x".repeat(700), Some(0.9))) + .collect(); + let mem = MockMemory { entries }; + + let rendered = build_memory_context(&mem, "hello", 0.4).await; + assert!(rendered.chars().count() <= MEMORY_CONTEXT_MAX_CHARS + 32); + assert!(rendered.matches("- k").count() <= MEMORY_CONTEXT_MAX_ENTRIES); + } +} diff --git a/src/openhuman/channels/routes.rs b/src/openhuman/channels/routes.rs index 942f2182c..2a986536d 100644 --- a/src/openhuman/channels/routes.rs +++ b/src/openhuman/channels/routes.rs @@ -328,3 +328,256 @@ pub(crate) async fn handle_runtime_command_if_needed( true } + +#[cfg(test)] +mod tests { + use super::*; + use crate::openhuman::channels::context::{ + ChannelRuntimeContext, ProviderCacheMap, RouteSelectionMap, + }; + use crate::openhuman::channels::traits::ChannelMessage; + use crate::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; + use crate::openhuman::providers::Provider; + use crate::openhuman::tools::{Tool, ToolResult}; + use async_trait::async_trait; + use std::collections::HashMap; + use std::path::PathBuf; + use std::sync::{Arc, Mutex}; + + struct DummyProvider; + + #[async_trait] + impl Provider for DummyProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("ok".into()) + } + } + + struct DummyMemory; + + #[async_trait] + impl Memory for DummyMemory { + fn name(&self) -> &str { + "dummy" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async 
fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + } + + struct DummyTool; + + #[async_trait] + impl Tool for DummyTool { + fn name(&self) -> &str { + "dummy" + } + + fn description(&self) -> &str { + "dummy" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({}) + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(ToolResult::success("ok")) + } + } + + fn runtime_context(workspace_dir: PathBuf) -> ChannelRuntimeContext { + ChannelRuntimeContext { + channels_by_name: Arc::new(HashMap::new()), + provider: Arc::new(DummyProvider), + default_provider: Arc::new("openai".into()), + memory: Arc::new(DummyMemory), + tools_registry: Arc::new(vec![Box::new(DummyTool) as Box]), + system_prompt: Arc::new("prompt".into()), + model: Arc::new("reasoning-v1".into()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 1, + min_relevance_score: 0.4, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: ProviderCacheMap::default(), + route_overrides: RouteSelectionMap::default(), + api_key: None, + api_url: None, + reliability: Arc::new(crate::openhuman::config::ReliabilityConfig::default()), + provider_runtime_options: crate::openhuman::providers::ProviderRuntimeOptions::default( + ), + workspace_dir: Arc::new(workspace_dir), + message_timeout_secs: 60, + multimodal: crate::openhuman::config::MultimodalConfig::default(), + } + } + + #[test] + fn runtime_command_parsing_and_provider_support_are_channel_scoped() { + assert!(supports_runtime_model_switch("telegram")); + assert!(supports_runtime_model_switch("discord")); + 
assert!(!supports_runtime_model_switch("slack")); + + assert_eq!( + parse_runtime_command("telegram", "/models"), + Some(ChannelRuntimeCommand::ShowProviders) + ); + assert_eq!( + parse_runtime_command("discord", "/models openai"), + Some(ChannelRuntimeCommand::SetProvider("openai".into())) + ); + assert_eq!( + parse_runtime_command("telegram", "/model gpt-5"), + Some(ChannelRuntimeCommand::SetModel("gpt-5".into())) + ); + assert_eq!( + parse_runtime_command("telegram", "/model"), + Some(ChannelRuntimeCommand::ShowModel) + ); + assert_eq!(parse_runtime_command("slack", "/models"), None); + assert_eq!(parse_runtime_command("telegram", "hello"), None); + } + + #[test] + fn provider_alias_and_route_selection_round_trip() { + let first_provider = providers::list_providers() + .into_iter() + .next() + .expect("provider registry should not be empty"); + assert_eq!( + resolve_provider_alias(first_provider.name).as_deref(), + Some(first_provider.name) + ); + assert!(resolve_provider_alias(" ").is_none()); + + let ctx = runtime_context(PathBuf::from("/tmp")); + let sender_key = "telegram_alice_reply"; + assert_eq!( + get_route_selection(&ctx, sender_key), + ChannelRouteSelection { + provider: "openai".into(), + model: "reasoning-v1".into() + } + ); + + set_route_selection( + &ctx, + sender_key, + ChannelRouteSelection { + provider: "anthropic".into(), + model: "claude".into(), + }, + ); + assert_eq!( + get_route_selection(&ctx, sender_key), + ChannelRouteSelection { + provider: "anthropic".into(), + model: "claude".into() + } + ); + + set_route_selection(&ctx, sender_key, default_route_selection(&ctx)); + assert!(ctx.route_overrides.lock().unwrap().is_empty()); + } + + #[test] + fn cached_models_and_help_responses_render_expected_text() { + let tempdir = tempfile::tempdir().unwrap(); + let state_dir = tempdir.path().join("state"); + std::fs::create_dir_all(&state_dir).unwrap(); + std::fs::write( + state_dir.join(MODEL_CACHE_FILE), + serde_json::json!({ + "entries": [ + { + 
"provider": "openai", + "models": ["gpt-5", "gpt-5-mini", "gpt-4.1"] + } + ] + }) + .to_string(), + ) + .unwrap(); + + let preview = load_cached_model_preview(tempdir.path(), "openai"); + assert_eq!(preview, vec!["gpt-5", "gpt-5-mini", "gpt-4.1"]); + assert!(load_cached_model_preview(tempdir.path(), "missing").is_empty()); + + let current = ChannelRouteSelection { + provider: "openai".into(), + model: "gpt-5".into(), + }; + let models = build_models_help_response(¤t, tempdir.path()); + assert!(models.contains("Current provider: `openai`")); + assert!(models.contains("Cached model IDs")); + assert!(models.contains("- `gpt-5-mini`")); + + let providers = build_providers_help_response(¤t); + assert!(providers.contains("Switch provider with `/models `")); + assert!(providers.contains("Available providers:")); + } + + #[test] + fn model_command_messages_use_thread_aware_history_keys() { + let msg = ChannelMessage { + id: "1".into(), + sender: "alice".into(), + reply_target: "room".into(), + content: "/model gpt-5".into(), + channel: "discord".into(), + timestamp: 0, + thread_ts: Some("thread-1".into()), + }; + assert_eq!( + super::super::context::conversation_history_key(&msg), + "discord_alice_room_thread:thread-1" + ); + } +} diff --git a/src/openhuman/composio/periodic.rs b/src/openhuman/composio/periodic.rs index 351e76b37..4ea2ddbc1 100644 --- a/src/openhuman/composio/periodic.rs +++ b/src/openhuman/composio/periodic.rs @@ -49,10 +49,12 @@ static SCHEDULER_STARTED: OnceLock<()> = OnceLock::new(); /// sync paths (e.g. `ComposioConnectionCreatedSubscriber`, /// `on_connection_created`) so that a recent non-periodic sync prevents /// the scheduler from firing immediately on the next tick. -static LAST_SYNC_AT: OnceLock>>> = OnceLock::new(); +type SyncTimestampMap = Arc>>; + +static LAST_SYNC_AT: OnceLock = OnceLock::new(); /// Get (or lazily initialise) the shared last-sync-at map. 
-fn last_sync_map() -> Arc>> { +fn last_sync_map() -> SyncTimestampMap { LAST_SYNC_AT .get_or_init(|| Arc::new(Mutex::new(HashMap::new()))) .clone() diff --git a/src/openhuman/composio/types.rs b/src/openhuman/composio/types.rs index ec4105128..f2d7bd515 100644 --- a/src/openhuman/composio/types.rs +++ b/src/openhuman/composio/types.rs @@ -4,7 +4,7 @@ //! `/agent-integrations/composio/*`. See: //! - `src/routes/agentIntegrations/composio.ts` //! - `src/controllers/agentIntegrations/composio/*.ts` -//! in the backend repo for the authoritative shapes. +//! in the backend repo for the authoritative shapes. use serde::{Deserialize, Serialize}; diff --git a/src/openhuman/config/schema/load.rs b/src/openhuman/config/schema/load.rs index b3b993e28..b485c96fb 100644 --- a/src/openhuman/config/schema/load.rs +++ b/src/openhuman/config/schema/load.rs @@ -410,9 +410,11 @@ impl Config { // `credentials::ops::store_session`, which writes `active_user.toml` // and triggers a reload that materializes the user-scoped directory. 
if resolution_source == ConfigResolutionSource::DefaultConfigDir && !config_path.exists() { - let mut config = Config::default(); - config.config_path = config_path.clone(); - config.workspace_dir = workspace_dir.clone(); + let mut config = Config { + config_path: config_path.clone(), + workspace_dir: workspace_dir.clone(), + ..Default::default() + }; config.apply_env_overrides(); tracing::debug!( @@ -508,9 +510,11 @@ impl Config { ); Ok(config) } else { - let mut config = Config::default(); - config.config_path = config_path.clone(); - config.workspace_dir = workspace_dir; + let mut config = Config { + config_path: config_path.clone(), + workspace_dir, + ..Default::default() + }; config.save().await?; #[cfg(unix)] diff --git a/src/openhuman/context/debug_dump.rs b/src/openhuman/context/debug_dump.rs index 43bd0b7a4..4de77fcdb 100644 --- a/src/openhuman/context/debug_dump.rs +++ b/src/openhuman/context/debug_dump.rs @@ -722,4 +722,241 @@ mod tests { assert_eq!(dumped.tool_names, vec!["notion__create_page"]); let _ = std::fs::remove_dir_all(workspace); } + + #[test] + fn dump_prompt_options_new_sets_expected_defaults() { + let options = DumpPromptOptions::new("skills_agent"); + assert_eq!(options.agent_id, "skills_agent"); + assert_eq!(options.skill_filter, None); + assert_eq!(options.workspace_dir_override, None); + assert_eq!(options.model_override, None); + assert!(!options.stub_composio); + } + + #[test] + fn composio_stub_tools_have_expected_names() { + let names: Vec = build_composio_stub_tools() + .into_iter() + .map(|tool| tool.name().to_string()) + .collect(); + assert_eq!( + names, + vec![ + "composio_list_toolkits", + "composio_list_connections", + "composio_authorize", + "composio_list_tools", + "composio_execute", + ] + ); + } + + #[test] + fn render_main_agent_dump_includes_tool_instructions_and_skill_count() { + let workspace = + std::env::temp_dir().join(format!("openhuman_debug_main_{}", uuid::Uuid::new_v4())); + 
std::fs::create_dir_all(&workspace).unwrap(); + std::fs::write(workspace.join("SOUL.md"), "# Soul\nContext").unwrap(); + std::fs::write(workspace.join("IDENTITY.md"), "# Identity\nContext").unwrap(); + std::fs::write(workspace.join("USER.md"), "# User\nContext").unwrap(); + std::fs::write(workspace.join("HEARTBEAT.md"), "# Heartbeat\nContext").unwrap(); + + let tools: Vec> = vec![ + Box::new(StubTool { + name: "shell", + category: ToolCategory::System, + }), + Box::new(StubTool { + name: "notion__create_page", + category: ToolCategory::Skill, + }), + ]; + + let dumped = render_main_agent_dump(&workspace, "reasoning-v1", &tools).unwrap(); + assert_eq!(dumped.mode, "main"); + assert_eq!(dumped.model, "reasoning-v1"); + assert_eq!(dumped.tool_names, vec!["shell", "notion__create_page"]); + assert_eq!(dumped.skill_tool_count, 1); + assert!(dumped.text.contains("## Tools")); + assert!(dumped.text.contains("Tool Use Protocol")); + assert!(dumped.cache_boundary.is_some()); + + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn filter_respects_named_scope_and_disallowed_tools() { + let tools: Vec> = vec![ + Box::new(StubTool { + name: "shell", + category: ToolCategory::System, + }), + Box::new(StubTool { + name: "notion__create_page", + category: ToolCategory::Skill, + }), + Box::new(StubTool { + name: "gmail__send_email", + category: ToolCategory::Skill, + }), + ]; + + let indices = filter_tool_indices_for_dump( + &tools, + &ToolScope::Named(vec!["shell".into(), "gmail__send_email".into()]), + &["shell".into()], + None, + None, + ); + + let names: Vec<&str> = indices.iter().map(|&i| tools[i].name()).collect(); + assert_eq!(names, vec!["gmail__send_email"]); + } + + #[test] + fn render_subagent_dump_supports_file_prompt_fallbacks() { + let workspace = + std::env::temp_dir().join(format!("openhuman_debug_file_{}", uuid::Uuid::new_v4())); + std::fs::create_dir_all(&workspace).unwrap(); + + let tools: Vec> = vec![Box::new(StubTool { + name: "shell", + category: 
ToolCategory::System, + })]; + + let definition = AgentDefinition { + id: "file_agent".into(), + when_to_use: "t".into(), + display_name: None, + system_prompt: PromptSource::File { + path: "USER.md".into(), + }, + omit_identity: true, + omit_memory_context: true, + omit_safety_preamble: true, + omit_skills_catalog: true, + model: ModelSpec::Inherit, + temperature: 0.0, + tools: ToolScope::Wildcard, + disallowed_tools: vec![], + skill_filter: None, + category_filter: None, + max_iterations: 2, + timeout_secs: None, + sandbox_mode: SandboxMode::None, + background: false, + uses_fork_context: false, + source: DefinitionSource::Builtin, + }; + + let dumped = + render_subagent_dump(&definition, &workspace, "reasoning-v1", &tools, None).unwrap(); + assert!(dumped.text.contains("## Tools")); + assert!(dumped.text.contains("OpenHuman")); + + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn render_subagent_dump_handles_missing_file_prompt_without_panicking() { + let workspace = + std::env::temp_dir().join(format!("openhuman_debug_missing_{}", uuid::Uuid::new_v4())); + std::fs::create_dir_all(&workspace).unwrap(); + + let tools: Vec> = vec![Box::new(StubTool { + name: "shell", + category: ToolCategory::System, + })]; + + let definition = AgentDefinition { + id: "missing_prompt".into(), + when_to_use: "t".into(), + display_name: None, + system_prompt: PromptSource::File { + path: "does-not-exist.md".into(), + }, + omit_identity: true, + omit_memory_context: true, + omit_safety_preamble: true, + omit_skills_catalog: true, + model: ModelSpec::Inherit, + temperature: 0.0, + tools: ToolScope::Wildcard, + disallowed_tools: vec![], + skill_filter: None, + category_filter: None, + max_iterations: 2, + timeout_secs: None, + sandbox_mode: SandboxMode::None, + background: false, + uses_fork_context: false, + source: DefinitionSource::Builtin, + }; + + let dumped = + render_subagent_dump(&definition, &workspace, "reasoning-v1", &tools, None).unwrap(); + 
assert!(dumped.text.contains("## Tools")); + assert!(!dumped.text.contains("does-not-exist")); + + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn render_subagent_dump_prefers_workspace_prompt_locations() { + let workspace = std::env::temp_dir().join(format!( + "openhuman_debug_workspace_prompt_{}", + uuid::Uuid::new_v4() + )); + std::fs::create_dir_all(workspace.join("agent/prompts")).unwrap(); + std::fs::write( + workspace.join("agent/prompts/custom.md"), + "Workspace agent prompt", + ) + .unwrap(); + std::fs::write(workspace.join("root.md"), "Workspace root prompt").unwrap(); + + let tools: Vec> = vec![Box::new(StubTool { + name: "shell", + category: ToolCategory::System, + })]; + + let mut definition = AgentDefinition { + id: "workspace_file".into(), + when_to_use: "t".into(), + display_name: None, + system_prompt: PromptSource::File { + path: "custom.md".into(), + }, + omit_identity: true, + omit_memory_context: true, + omit_safety_preamble: true, + omit_skills_catalog: true, + model: ModelSpec::Inherit, + temperature: 0.0, + tools: ToolScope::Wildcard, + disallowed_tools: vec![], + skill_filter: None, + category_filter: None, + max_iterations: 2, + timeout_secs: None, + sandbox_mode: SandboxMode::None, + background: false, + uses_fork_context: false, + source: DefinitionSource::Builtin, + }; + + let agent_prompt = + render_subagent_dump(&definition, &workspace, "reasoning-v1", &tools, None).unwrap(); + assert!(agent_prompt.text.contains("Workspace agent prompt")); + + definition.id = "workspace_root".into(); + definition.system_prompt = PromptSource::File { + path: "root.md".into(), + }; + let root_prompt = + render_subagent_dump(&definition, &workspace, "reasoning-v1", &tools, None).unwrap(); + assert!(root_prompt.text.contains("Workspace root prompt")); + + let _ = std::fs::remove_dir_all(workspace); + } } diff --git a/src/openhuman/context/prompt.rs b/src/openhuman/context/prompt.rs index 387a1dacd..bcae8a9b1 100644 --- 
a/src/openhuman/context/prompt.rs +++ b/src/openhuman/context/prompt.rs @@ -1220,4 +1220,155 @@ mod tests { let _ = std::fs::remove_dir_all(workspace); } + + #[test] + fn extract_cache_boundary_without_marker_returns_original_text() { + let rendered = extract_cache_boundary("hello"); + assert_eq!(rendered.text, "hello"); + assert_eq!(rendered.cache_boundary, None); + } + + #[test] + fn subagent_render_options_invert_definition_flags() { + let options = SubagentRenderOptions::from_definition_flags(true, false, true); + assert!(!options.include_identity); + assert!(options.include_safety_preamble); + assert!(!options.include_skills_catalog); + let narrow = SubagentRenderOptions::narrow(); + let default = SubagentRenderOptions::default(); + assert_eq!(narrow.include_identity, default.include_identity); + assert_eq!( + narrow.include_safety_preamble, + default.include_safety_preamble + ); + assert_eq!( + narrow.include_skills_catalog, + default.include_skills_catalog + ); + } + + #[test] + fn render_subagent_system_prompt_honors_identity_safety_and_skills_flags() { + let workspace = + std::env::temp_dir().join(format!("openhuman_prompt_opts_{}", uuid::Uuid::new_v4())); + std::fs::create_dir_all(&workspace).unwrap(); + std::fs::write(workspace.join("SOUL.md"), "# Soul\nContext").unwrap(); + std::fs::write(workspace.join("IDENTITY.md"), "# Identity\nContext").unwrap(); + std::fs::write(workspace.join("USER.md"), "# User\nContext").unwrap(); + + let tools: Vec> = vec![Box::new(TestTool)]; + let rendered = render_subagent_system_prompt_with_format( + &workspace, + "reasoning-v1", + &[0], + &tools, + "You are a specialist.", + SubagentRenderOptions { + include_identity: true, + include_safety_preamble: true, + include_skills_catalog: true, + }, + ToolCallFormat::Json, + ); + + assert!(rendered.contains("## Project Context")); + assert!(rendered.contains("### SOUL.md")); + assert!(rendered.contains("## Safety")); + assert!(rendered.contains("## Available Skills")); + 
assert!(rendered.contains("Parameters:")); + assert!(rendered.contains("\"type\"")); + + let native = render_subagent_system_prompt_with_format( + &workspace, + "reasoning-v1", + &[0], + &tools, + "You are a specialist.", + SubagentRenderOptions::narrow(), + ToolCallFormat::Native, + ); + assert!(native.contains("native tool-calling output")); + assert!(!native.contains("## Safety")); + + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn sync_workspace_file_updates_hash_and_inject_workspace_file_truncates() { + let workspace = std::env::temp_dir().join(format!( + "openhuman_prompt_workspace_{}", + uuid::Uuid::new_v4() + )); + std::fs::create_dir_all(&workspace).unwrap(); + + sync_workspace_file(&workspace, "SOUL.md"); + let hash_path = workspace.join(".SOUL.md.builtin-hash"); + assert!(workspace.join("SOUL.md").exists()); + assert!(hash_path.exists()); + let original_hash = std::fs::read_to_string(&hash_path).unwrap(); + + std::fs::write(workspace.join("SOUL.md"), "user override").unwrap(); + sync_workspace_file(&workspace, "SOUL.md"); + assert_eq!(std::fs::read_to_string(&hash_path).unwrap(), original_hash); + assert_eq!( + std::fs::read_to_string(workspace.join("SOUL.md")).unwrap(), + "user override" + ); + + std::fs::write( + workspace.join("BIG.md"), + "x".repeat(BOOTSTRAP_MAX_CHARS + 50), + ) + .unwrap(); + let mut prompt = String::new(); + inject_workspace_file(&mut prompt, &workspace, "BIG.md"); + assert!(prompt.contains("### BIG.md")); + assert!(prompt.contains("[... 
truncated at")); + + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn dynamic_section_classification_matches_cache_boundary_rules() { + assert!(is_dynamic_section("workspace")); + assert!(is_dynamic_section("datetime")); + assert!(is_dynamic_section("runtime")); + assert!(!is_dynamic_section("tools")); + assert!(!is_dynamic_section("identity")); + } + + #[test] + fn prompt_tool_constructors_and_user_memory_skip_empty_bodies() { + let plain = PromptTool::new("shell", "run commands"); + assert_eq!(plain.name, "shell"); + assert!(plain.parameters_schema.is_none()); + + let with_schema = + PromptTool::with_schema("http_request", "fetch data", "{\"type\":\"object\"}".into()); + assert_eq!( + with_schema.parameters_schema.as_deref(), + Some("{\"type\":\"object\"}") + ); + + let ctx = PromptContext { + workspace_dir: Path::new("/tmp"), + model_name: "model", + tools: &[], + skills: &[], + dispatcher_instructions: "", + learned: LearnedContextData { + tree_root_summaries: vec![ + ("user".into(), "kept".into()), + ("empty".into(), " ".into()), + ], + ..Default::default() + }, + visible_tool_names: &NO_FILTER, + tool_call_format: ToolCallFormat::PFormat, + }; + let rendered = UserMemorySection.build(&ctx).unwrap(); + assert!(rendered.contains("### user")); + assert!(!rendered.contains("### empty")); + assert_eq!(default_workspace_file_content("missing"), ""); + } } diff --git a/src/openhuman/dev_paths.rs b/src/openhuman/dev_paths.rs index 56bbd4240..e221f32cb 100644 --- a/src/openhuman/dev_paths.rs +++ b/src/openhuman/dev_paths.rs @@ -14,12 +14,7 @@ pub fn bundled_openclaw_prompts_dir(resource_dir: &Path) -> Option { .join("agent") .join("prompts"), ]; - for p in candidates { - if p.is_dir() { - return Some(p); - } - } - None + candidates.into_iter().find(|p| p.is_dir()) } /// Locate `src/openhuman/agent/prompts` by walking up from `cwd`. 
diff --git a/src/openhuman/learning/prompt_sections.rs b/src/openhuman/learning/prompt_sections.rs index f3f143373..64a87d144 100644 --- a/src/openhuman/learning/prompt_sections.rs +++ b/src/openhuman/learning/prompt_sections.rs @@ -81,3 +81,139 @@ impl PromptSection for UserProfileSection { Ok(out) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::openhuman::context::prompt::LearnedContextData; + use crate::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; + use async_trait::async_trait; + use std::collections::HashSet; + use std::path::Path; + use std::sync::Arc; + + struct NoopMemory; + + #[async_trait] + impl Memory for NoopMemory { + fn name(&self) -> &str { + "noop" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + } + + fn prompt_context(learned: LearnedContextData) -> PromptContext<'static> { + let visible_tool_names = Box::leak(Box::new(HashSet::new())); + PromptContext { + workspace_dir: Path::new("/tmp"), + model_name: "test-model", + tools: &[], + skills: &[], + dispatcher_instructions: "", + learned, + visible_tool_names, + tool_call_format: crate::openhuman::context::prompt::ToolCallFormat::PFormat, + } + } + + #[test] + fn learned_context_section_renders_observations_and_patterns() { + let section = LearnedContextSection::new(Arc::new(NoopMemory)); + let rendered = section + 
.build(&prompt_context(LearnedContextData { + observations: vec!["Tool use succeeded".into()], + patterns: vec!["User prefers terse replies".into()], + user_profile: Vec::new(), + tree_root_summaries: Vec::new(), + })) + .unwrap(); + + assert_eq!(section.name(), "learned_context"); + assert!(rendered.contains("## Learned Context")); + assert!(rendered.contains("### Recent Observations")); + assert!(rendered.contains("- Tool use succeeded")); + assert!(rendered.contains("### Recognized Patterns")); + assert!(rendered.contains("- User prefers terse replies")); + } + + #[test] + fn learned_context_section_returns_empty_without_entries() { + let section = LearnedContextSection::new(Arc::new(NoopMemory)); + assert!(section + .build(&prompt_context(LearnedContextData::default())) + .unwrap() + .is_empty()); + } + + #[test] + fn user_profile_section_renders_bullets() { + let section = UserProfileSection::new(Arc::new(NoopMemory)); + let rendered = section + .build(&prompt_context(LearnedContextData { + observations: Vec::new(), + patterns: Vec::new(), + user_profile: vec![ + "Timezone: America/Los_Angeles".into(), + "Prefers Rust".into(), + ], + tree_root_summaries: Vec::new(), + })) + .unwrap(); + + assert_eq!(section.name(), "user_profile"); + assert!(rendered.starts_with("## User Profile (Learned)\n\n")); + assert!(rendered.contains("- Timezone: America/Los_Angeles")); + assert!(rendered.contains("- Prefers Rust")); + } + + #[test] + fn user_profile_section_returns_empty_without_profile_entries() { + let section = UserProfileSection::new(Arc::new(NoopMemory)); + assert!(section + .build(&prompt_context(LearnedContextData::default())) + .unwrap() + .is_empty()); + } +} diff --git a/src/openhuman/learning/reflection.rs b/src/openhuman/learning/reflection.rs index 7f7a34496..9d8c1ad97 100644 --- a/src/openhuman/learning/reflection.rs +++ b/src/openhuman/learning/reflection.rs @@ -306,6 +306,108 @@ fn slugify(s: &str) -> String { #[cfg(test)] mod tests { use super::*; + 
use crate::openhuman::agent::hooks::{ToolCallRecord, TurnContext}; + use crate::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; + use async_trait::async_trait; + use parking_lot::Mutex; + use std::collections::HashMap; + use std::sync::Arc; + + #[derive(Default)] + struct MockMemory { + entries: Mutex>, + } + + #[async_trait] + impl Memory for MockMemory { + fn name(&self) -> &str { + "mock" + } + + async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + ) -> anyhow::Result<()> { + self.entries.lock().insert( + key.to_string(), + MemoryEntry { + id: key.to_string(), + key: key.to_string(), + content: content.to_string(), + namespace: None, + category, + timestamp: "now".into(), + session_id: session_id.map(str::to_string), + score: None, + }, + ); + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, key: &str) -> anyhow::Result> { + Ok(self.entries.lock().get(key).cloned()) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(self.entries.lock().values().cloned().collect()) + } + + async fn forget(&self, key: &str) -> anyhow::Result { + Ok(self.entries.lock().remove(key).is_some()) + } + + async fn count(&self) -> anyhow::Result { + Ok(self.entries.lock().len()) + } + + async fn health_check(&self) -> bool { + true + } + } + + fn reflection_config() -> LearningConfig { + LearningConfig { + enabled: true, + reflection_enabled: true, + reflection_source: ReflectionSource::Cloud, + max_reflections_per_session: 2, + min_turn_complexity: 1, + ..LearningConfig::default() + } + } + + fn reflective_turn() -> TurnContext { + TurnContext { + user_message: "Please debug the failing build".into(), + assistant_response: "I inspected the logs and found the root cause.".repeat(20), + tool_calls: vec![ToolCallRecord { + 
name: "shell".into(), + arguments: serde_json::json!({"cmd":"cargo test"}), + success: true, + output_summary: "tests passed".into(), + duration_ms: 1200, + }], + turn_duration_ms: 2200, + session_id: Some("session-1".into()), + iteration_count: 2, + } + } #[test] fn parse_reflection_valid_json() { @@ -338,4 +440,131 @@ That's my assessment."#; assert_eq!(slugify("User prefers Rust"), "user_prefers_rust"); assert_eq!(slugify("hello-world_test"), "hello_world_test"); } + + #[test] + fn should_reflect_requires_learning_and_complexity() { + let memory: Arc = Arc::new(MockMemory::default()); + let hook = ReflectionHook::new( + reflection_config(), + Arc::new(Config::default()), + memory, + None, + ); + assert!(hook.should_reflect(&reflective_turn())); + + let mut disabled = reflection_config(); + disabled.enabled = false; + let hook = ReflectionHook::new( + disabled, + Arc::new(Config::default()), + Arc::new(MockMemory::default()), + None, + ); + assert!(!hook.should_reflect(&reflective_turn())); + + let mut simple = reflective_turn(); + simple.tool_calls.clear(); + simple.assistant_response = "short".into(); + let hook = ReflectionHook::new( + reflection_config(), + Arc::new(Config::default()), + Arc::new(MockMemory::default()), + None, + ); + assert!(!hook.should_reflect(&simple)); + } + + #[test] + fn build_reflection_prompt_includes_tool_calls_and_truncation() { + let memory: Arc = Arc::new(MockMemory::default()); + let hook = ReflectionHook::new( + reflection_config(), + Arc::new(Config::default()), + memory, + None, + ); + let mut turn = reflective_turn(); + turn.user_message = "u".repeat(700); + turn.assistant_response = "a".repeat(700); + turn.tool_calls[0].output_summary = "x".repeat(200); + + let prompt = hook.build_reflection_prompt(&turn); + assert!(prompt.contains("## User Message")); + assert!(prompt.contains("## Assistant Response")); + assert!(prompt.contains("## Tool Calls")); + assert!(prompt.contains("shell (success=true, duration=1200ms):")); + 
assert!(prompt.contains("Turn took 2200ms across 2 iteration(s).")); + assert!(prompt.contains(&format!("{}...", "u".repeat(500)))); + assert!(prompt.contains(&format!("{}...", "a".repeat(500)))); + assert!(prompt.contains(&format!("{}...", "x".repeat(100)))); + } + + #[test] + fn session_key_and_counter_management_work() { + let hook = ReflectionHook::new( + reflection_config(), + Arc::new(Config::default()), + Arc::new(MockMemory::default()), + None, + ); + + let global_ctx = TurnContext { + session_id: None, + ..reflective_turn() + }; + assert_eq!(ReflectionHook::session_key(&global_ctx), "__global__"); + + assert!(hook.try_increment("s")); + assert!(hook.try_increment("s")); + assert!(!hook.try_increment("s")); + hook.rollback_increment("s"); + assert!(hook.try_increment("s")); + } + + #[tokio::test] + async fn store_reflection_persists_all_categories() { + let memory_impl = Arc::new(MockMemory::default()); + let memory: Arc = memory_impl.clone(); + let hook = ReflectionHook::new( + reflection_config(), + Arc::new(Config::default()), + memory, + None, + ); + hook.store_reflection(&ReflectionOutput { + observations: vec!["Observed failure".into()], + patterns: vec!["Pattern A".into()], + user_preferences: vec!["Pref A".into()], + }) + .await + .unwrap(); + + let keys: Vec = memory_impl.entries.lock().keys().cloned().collect(); + assert!(keys.iter().any(|key| key.starts_with("obs/"))); + assert!(keys.iter().any(|key| key == "pat/pattern_a")); + assert!(keys.iter().any(|key| key == "pref/pref_a")); + } + + #[tokio::test] + async fn on_turn_complete_rolls_back_counter_when_reflection_call_fails() { + let memory: Arc = Arc::new(MockMemory::default()); + let hook = ReflectionHook::new( + reflection_config(), + Arc::new(Config::default()), + memory, + None, + ); + let turn = reflective_turn(); + + let err = hook.on_turn_complete(&turn).await.unwrap_err(); + assert!(err.to_string().contains("no cloud provider configured")); + assert_eq!( + hook.session_counts + .lock() 
+ .get("session-1") + .copied() + .unwrap_or_default(), + 0 + ); + } } diff --git a/src/openhuman/learning/tool_tracker.rs b/src/openhuman/learning/tool_tracker.rs index e9af73450..460076781 100644 --- a/src/openhuman/learning/tool_tracker.rs +++ b/src/openhuman/learning/tool_tracker.rs @@ -178,6 +178,80 @@ impl PostTurnHook for ToolTrackerHook { #[cfg(test)] mod tests { use super::*; + use crate::openhuman::agent::hooks::{ToolCallRecord, TurnContext}; + use crate::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; + use async_trait::async_trait; + use parking_lot::Mutex; + use std::collections::HashMap; + use std::sync::Arc; + + #[derive(Default)] + struct MockMemory { + entries: Mutex>, + } + + #[async_trait] + impl Memory for MockMemory { + fn name(&self) -> &str { + "mock" + } + + async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + ) -> anyhow::Result<()> { + self.entries.lock().insert( + key.to_string(), + MemoryEntry { + id: key.to_string(), + key: key.to_string(), + content: content.to_string(), + namespace: None, + category, + timestamp: "now".into(), + session_id: session_id.map(str::to_string), + score: None, + }, + ); + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, key: &str) -> anyhow::Result> { + Ok(self.entries.lock().get(key).cloned()) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(self.entries.lock().values().cloned().collect()) + } + + async fn forget(&self, key: &str) -> anyhow::Result { + Ok(self.entries.lock().remove(key).is_some()) + } + + async fn count(&self) -> anyhow::Result { + Ok(self.entries.lock().len()) + } + + async fn health_check(&self) -> bool { + true + } + } #[test] fn tool_stats_record_call_updates_correctly() { @@ -206,4 +280,131 @@ mod tests { 
assert!(summary.contains("calls=3")); assert!(summary.contains("failures=1")); } + + #[test] + fn tool_stats_keeps_only_recent_unique_error_patterns() { + let mut stats = ToolStats::default(); + for idx in 0..7 { + stats.record_call(false, 10, Some(&format!("error pattern {idx}"))); + } + stats.record_call(false, 10, Some("error pattern 6")); + + assert_eq!(stats.failures, 8); + assert_eq!(stats.common_error_patterns.len(), 5); + assert_eq!( + stats.common_error_patterns.first().unwrap(), + "error pattern 2" + ); + assert_eq!( + stats.common_error_patterns.last().unwrap(), + "error pattern 6" + ); + } + + #[tokio::test] + async fn update_stats_merges_with_existing_memory_entry() { + let memory_impl = Arc::new(MockMemory::default()); + memory_impl + .store( + "tool/shell", + &serde_json::to_string(&ToolStats { + total_calls: 2, + successes: 1, + failures: 1, + avg_duration_ms: 50.0, + common_error_patterns: vec!["timeout".into()], + }) + .unwrap(), + MemoryCategory::Custom("tool_effectiveness".into()), + None, + ) + .await + .unwrap(); + + let memory: Arc = memory_impl.clone(); + let hook = ToolTrackerHook::new( + LearningConfig { + enabled: true, + tool_tracking_enabled: true, + ..LearningConfig::default() + }, + memory, + ); + + hook.update_stats("shell", true, 250, None).await.unwrap(); + + let stored = memory_impl.get("tool/shell").await.unwrap().unwrap(); + let parsed: ToolStats = serde_json::from_str(&stored.content).unwrap(); + assert_eq!(parsed.total_calls, 3); + assert_eq!(parsed.successes, 2); + assert_eq!(parsed.failures, 1); + assert!((parsed.avg_duration_ms - 116.66666666666667).abs() < 0.001); + } + + #[tokio::test] + async fn on_turn_complete_skips_when_disabled_or_no_tools() { + let memory_impl = Arc::new(MockMemory::default()); + let memory: Arc = memory_impl.clone(); + let hook = ToolTrackerHook::new(LearningConfig::default(), memory); + let ctx = TurnContext { + user_message: "hello".into(), + assistant_response: "world".into(), + tool_calls: 
Vec::new(), + turn_duration_ms: 10, + session_id: None, + iteration_count: 1, + }; + + hook.on_turn_complete(&ctx).await.unwrap(); + assert!(memory_impl.entries.lock().is_empty()); + } + + #[tokio::test] + async fn on_turn_complete_records_each_tool_call() { + let memory_impl = Arc::new(MockMemory::default()); + let memory: Arc = memory_impl.clone(); + let hook = ToolTrackerHook::new( + LearningConfig { + enabled: true, + tool_tracking_enabled: true, + ..LearningConfig::default() + }, + memory, + ); + let ctx = TurnContext { + user_message: "hello".into(), + assistant_response: "world".into(), + tool_calls: vec![ + ToolCallRecord { + name: "shell".into(), + arguments: serde_json::json!({}), + success: true, + output_summary: "ok".into(), + duration_ms: 100, + }, + ToolCallRecord { + name: "shell".into(), + arguments: serde_json::json!({}), + success: false, + output_summary: "permission denied while writing".into(), + duration_ms: 200, + }, + ], + turn_duration_ms: 20, + session_id: None, + iteration_count: 1, + }; + + hook.on_turn_complete(&ctx).await.unwrap(); + + let stored = memory_impl.get("tool/shell").await.unwrap().unwrap(); + let parsed: ToolStats = serde_json::from_str(&stored.content).unwrap(); + assert_eq!(parsed.total_calls, 2); + assert_eq!(parsed.successes, 1); + assert_eq!(parsed.failures, 1); + assert_eq!( + parsed.common_error_patterns, + vec!["permission denied while writing"] + ); + } } diff --git a/src/openhuman/learning/user_profile.rs b/src/openhuman/learning/user_profile.rs index 8dc0fb1e7..8f5ca2e0d 100644 --- a/src/openhuman/learning/user_profile.rs +++ b/src/openhuman/learning/user_profile.rs @@ -148,6 +148,80 @@ fn slugify(s: &str) -> String { #[cfg(test)] mod tests { use super::*; + use crate::openhuman::agent::hooks::TurnContext; + use crate::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; + use async_trait::async_trait; + use parking_lot::Mutex; + use std::collections::HashMap; + use std::sync::Arc; + + #[derive(Default)] + 
struct MockMemory { + entries: Mutex>, + } + + #[async_trait] + impl Memory for MockMemory { + fn name(&self) -> &str { + "mock" + } + + async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + ) -> anyhow::Result<()> { + self.entries.lock().insert( + key.to_string(), + MemoryEntry { + id: key.to_string(), + key: key.to_string(), + content: content.to_string(), + namespace: None, + category, + timestamp: "now".into(), + session_id: session_id.map(str::to_string), + score: None, + }, + ); + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, key: &str) -> anyhow::Result> { + Ok(self.entries.lock().get(key).cloned()) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(self.entries.lock().values().cloned().collect()) + } + + async fn forget(&self, key: &str) -> anyhow::Result { + Ok(self.entries.lock().remove(key).is_some()) + } + + async fn count(&self) -> anyhow::Result { + Ok(self.entries.lock().len()) + } + + async fn health_check(&self) -> bool { + true + } + } #[test] fn extract_preferences_finds_patterns() { @@ -171,4 +245,95 @@ mod tests { let prefs = UserProfileHook::extract_preferences(msg); assert!(prefs.is_empty()); } + + #[test] + fn extract_preferences_uses_full_message_fallback_and_caps_results() { + let fallback = + UserProfileHook::extract_preferences("I prefer compact diffs in code reviews"); + assert_eq!(fallback, vec!["I prefer compact diffs in code reviews"]); + + let many = UserProfileHook::extract_preferences( + "I prefer Rust. I always use tests. Please always explain failures. \ + My timezone is PST. My stack is Tauri. Going forward use concise output. 
\ + Never use nested bullets.", + ); + assert_eq!(many.len(), 5); + } + + #[tokio::test] + async fn store_preferences_skips_duplicates_and_empty_slugs() { + let memory_impl = Arc::new(MockMemory::default()); + memory_impl + .store( + "pref/i_prefer_rust", + "I prefer Rust", + MemoryCategory::Custom("user_profile".into()), + None, + ) + .await + .unwrap(); + let memory: Arc = memory_impl.clone(); + let hook = UserProfileHook::new( + LearningConfig { + enabled: true, + user_profile_enabled: true, + ..LearningConfig::default() + }, + memory, + ); + + hook.store_preferences(&[ + "I prefer Rust".into(), + "!!!".into(), + "My timezone is PST".into(), + ]) + .await + .unwrap(); + + let keys: Vec = memory_impl.entries.lock().keys().cloned().collect(); + assert_eq!(keys.len(), 2); + assert!(keys.contains(&"pref/i_prefer_rust".into())); + assert!(keys.contains(&"pref/my_timezone_is_pst".into())); + } + + #[tokio::test] + async fn on_turn_complete_respects_feature_flags_and_stores_preferences() { + let memory_impl = Arc::new(MockMemory::default()); + let memory: Arc = memory_impl.clone(); + let ctx = TurnContext { + user_message: "My language is English. 
Please always use concise output.".into(), + assistant_response: "Noted".into(), + tool_calls: Vec::new(), + turn_duration_ms: 10, + session_id: None, + iteration_count: 1, + }; + + let disabled = UserProfileHook::new(LearningConfig::default(), memory.clone()); + disabled.on_turn_complete(&ctx).await.unwrap(); + assert!(memory_impl.entries.lock().is_empty()); + + let enabled = UserProfileHook::new( + LearningConfig { + enabled: true, + user_profile_enabled: true, + ..LearningConfig::default() + }, + memory, + ); + enabled.on_turn_complete(&ctx).await.unwrap(); + + let values: Vec = memory_impl + .entries + .lock() + .values() + .map(|entry| entry.content.clone()) + .collect(); + assert!(values + .iter() + .any(|value| value.contains("My language is English"))); + assert!(values + .iter() + .any(|value| value.contains("Please always use concise output"))); + } } diff --git a/src/openhuman/local_ai/mod.rs b/src/openhuman/local_ai/mod.rs index a9d9e7b87..2a9cb7e8a 100644 --- a/src/openhuman/local_ai/mod.rs +++ b/src/openhuman/local_ai/mod.rs @@ -1,5 +1,9 @@ //! Bundled local AI stack (Ollama, whisper.cpp, Piper). 
+#[cfg(test)] +pub(crate) static LOCAL_AI_TEST_MUTEX: once_cell::sync::Lazy> = + once_cell::sync::Lazy::new(|| std::sync::Mutex::new(())); + mod core; pub mod device; pub mod gif_decision; diff --git a/src/openhuman/local_ai/service/public_infer.rs b/src/openhuman/local_ai/service/public_infer.rs index 52f001d6d..dce194365 100644 --- a/src/openhuman/local_ai/service/public_infer.rs +++ b/src/openhuman/local_ai/service/public_infer.rs @@ -292,6 +292,7 @@ impl LocalAiService { .await } + #[allow(clippy::too_many_arguments)] async fn inference_with_temperature_internal( &self, config: &Config, diff --git a/src/openhuman/memory/conversations/bus.rs b/src/openhuman/memory/conversations/bus.rs index 40eb121c2..0b4697da1 100644 --- a/src/openhuman/memory/conversations/bus.rs +++ b/src/openhuman/memory/conversations/bus.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::{Arc, OnceLock}; use async_trait::async_trait; @@ -146,7 +146,7 @@ struct ChannelTurnDescriptor<'a> { } fn persist_channel_turn( - workspace_dir: &PathBuf, + workspace_dir: &Path, descriptor: ChannelTurnDescriptor<'_>, ) -> Result<(), String> { let thread_id = persisted_channel_thread_id( @@ -164,7 +164,7 @@ fn persist_channel_turn( let created_at = Utc::now().to_rfc3339(); ensure_thread( - workspace_dir.clone(), + workspace_dir.to_path_buf(), CreateConversationThread { id: thread_id.clone(), title, @@ -173,7 +173,7 @@ fn persist_channel_turn( )?; let persisted_message_id = format!("{}:{}", descriptor.role, descriptor.message_id); - if get_messages(workspace_dir.clone(), &thread_id)? + if get_messages(workspace_dir.to_path_buf(), &thread_id)? 
.iter() .any(|message| message.id == persisted_message_id) { @@ -186,7 +186,7 @@ fn persist_channel_turn( } append_message( - workspace_dir.clone(), + workspace_dir.to_path_buf(), &thread_id, ConversationMessage { id: persisted_message_id.clone(), diff --git a/src/openhuman/memory/embeddings.rs b/src/openhuman/memory/embeddings.rs index aa6c9e736..45d42398c 100644 --- a/src/openhuman/memory/embeddings.rs +++ b/src/openhuman/memory/embeddings.rs @@ -64,7 +64,7 @@ enum FastembedState { /// Initial state before the model is loaded. Uninitialized, /// Model is loaded into memory and ready for inference. - Ready(fastembed::TextEmbedding), + Ready(Box), /// An error occurred during model loading. Failed(String), } @@ -142,7 +142,6 @@ fn ensure_fastembed_ort_dylib_path() { if runtime_lib.exists() { env::set_var("ORT_DYLIB_PATH", runtime_lib); - return; } } @@ -218,7 +217,7 @@ impl EmbeddingProvider for FastembedEmbedding { })); match init_result { - Ok(Ok(model)) => *guard = FastembedState::Ready(model), + Ok(Ok(model)) => *guard = FastembedState::Ready(Box::new(model)), Ok(Err(err)) => { let message = format!("fastembed init failed for {provider}: {err}"); tracing::error!(target: "memory.embeddings", "[embeddings] {message}"); diff --git a/src/openhuman/memory/ingestion.rs b/src/openhuman/memory/ingestion.rs index 0bf4ef943..382fa0a15 100644 --- a/src/openhuman/memory/ingestion.rs +++ b/src/openhuman/memory/ingestion.rs @@ -406,6 +406,7 @@ impl ExtractionAccumulator { } /// Records a new relationship, applying semantic validation rules. 
+ #[allow(clippy::too_many_arguments)] fn add_relation( &mut self, subject: &str, @@ -723,13 +724,13 @@ fn find_chunk_index(chunks: &[String], excerpt: &str, hint: usize) -> usize { if needle.is_empty() { return hint.min(chunks.len().saturating_sub(1)); } - for index in hint..chunks.len() { - if UnifiedMemory::normalize_search_text(&chunks[index]).contains(&needle) { + for (index, chunk) in chunks.iter().enumerate().skip(hint) { + if UnifiedMemory::normalize_search_text(chunk).contains(&needle) { return index; } } - for index in 0..hint.min(chunks.len()) { - if UnifiedMemory::normalize_search_text(&chunks[index]).contains(&needle) { + for (index, chunk) in chunks.iter().enumerate().take(hint.min(chunks.len())) { + if UnifiedMemory::normalize_search_text(chunk).contains(&needle) { return index; } } diff --git a/src/openhuman/memory/store/unified/events.rs b/src/openhuman/memory/store/unified/events.rs index be7c24735..f156cddd2 100644 --- a/src/openhuman/memory/store/unified/events.rs +++ b/src/openhuman/memory/store/unified/events.rs @@ -88,7 +88,7 @@ impl EventType { } } - pub fn from_str(s: &str) -> Self { + pub fn parse_or_default(s: &str) -> Self { match s { "decision" => Self::Decision, "commitment" => Self::Commitment, @@ -373,7 +373,7 @@ fn row_to_event(row: &rusqlite::Row<'_>) -> rusqlite::Result { segment_id: row.get(1)?, session_id: row.get(2)?, namespace: row.get(3)?, - event_type: EventType::from_str(&event_type_str), + event_type: EventType::parse_or_default(&event_type_str), content: row.get(5)?, subject: row.get(6)?, timestamp_ref: row.get(7)?, diff --git a/src/openhuman/memory/store/unified/helpers.rs b/src/openhuman/memory/store/unified/helpers.rs index 4749b366e..77a148212 100644 --- a/src/openhuman/memory/store/unified/helpers.rs +++ b/src/openhuman/memory/store/unified/helpers.rs @@ -3,6 +3,7 @@ use crate::openhuman::memory::chunker::chunk_markdown; use super::UnifiedMemory; impl UnifiedMemory { + #[allow(clippy::too_many_arguments)] pub(crate) 
fn write_markdown_doc( &self, namespace: &str, diff --git a/src/openhuman/memory/store/unified/profile.rs b/src/openhuman/memory/store/unified/profile.rs index ddec35e62..c3b0f2139 100644 --- a/src/openhuman/memory/store/unified/profile.rs +++ b/src/openhuman/memory/store/unified/profile.rs @@ -53,7 +53,7 @@ impl FacetType { } } - pub fn from_str(s: &str) -> Self { + pub fn parse_or_default(s: &str) -> Self { match s { "skill" => Self::Skill, "role" => Self::Role, @@ -83,6 +83,7 @@ pub struct ProfileFacet { /// - Updates last_seen_at /// - Appends segment_id to source_segment_ids /// - Only overwrites value if new confidence > existing confidence +#[allow(clippy::too_many_arguments)] pub fn profile_upsert( conn: &Arc>, facet_id: &str, @@ -255,7 +256,7 @@ fn row_to_facet(row: &rusqlite::Row<'_>) -> rusqlite::Result { let facet_type_str: String = row.get(1)?; Ok(ProfileFacet { facet_id: row.get(0)?, - facet_type: FacetType::from_str(&facet_type_str), + facet_type: FacetType::parse_or_default(&facet_type_str), key: row.get(2)?, value: row.get(3)?, confidence: row.get(4)?, diff --git a/src/openhuman/memory/store/unified/segments.rs b/src/openhuman/memory/store/unified/segments.rs index 3d25a89dd..5436e4f21 100644 --- a/src/openhuman/memory/store/unified/segments.rs +++ b/src/openhuman/memory/store/unified/segments.rs @@ -57,7 +57,7 @@ impl SegmentStatus { } } - pub fn from_str(s: &str) -> Self { + pub fn parse_or_default(s: &str) -> Self { match s { "closed" => Self::Closed, "summarised" => Self::Summarised, @@ -464,7 +464,7 @@ fn row_to_segment(row: &rusqlite::Row<'_>) -> rusqlite::Result Result { #[cfg(target_os = "linux")] { linux::install(config)?; - return status(config); + status(config) } #[cfg(windows)] { windows::install(config)?; - return status(config); + status(config) } #[cfg(not(any(target_os = "macos", target_os = "linux", windows)))] anyhow::bail!("Service management is supported on macOS, Linux, and Windows only") @@ -77,12 +77,12 @@ pub fn 
stop(config: &Config) -> Result { #[cfg(target_os = "linux")] { linux::stop(config)?; - return status(config); + status(config) } #[cfg(windows)] { windows::stop(config)?; - return status(config); + status(config) } #[cfg(not(any(target_os = "macos", target_os = "linux", windows)))] anyhow::bail!("Service management is supported on macOS, Linux, and Windows only") diff --git a/src/openhuman/tools/impl/browser/mod.rs b/src/openhuman/tools/impl/browser/mod.rs index 638376621..9300ea221 100644 --- a/src/openhuman/tools/impl/browser/mod.rs +++ b/src/openhuman/tools/impl/browser/mod.rs @@ -1,3 +1,4 @@ +#[allow(clippy::module_inception)] mod browser; mod browser_open; mod image_info; diff --git a/src/openhuman/voice/postprocess.rs b/src/openhuman/voice/postprocess.rs index 8c062054f..71d1dbadd 100644 --- a/src/openhuman/voice/postprocess.rs +++ b/src/openhuman/voice/postprocess.rs @@ -173,7 +173,14 @@ mod tests { let rt = tokio::runtime::Runtime::new().unwrap(); let mut config = Config::default(); config.local_ai.voice_llm_cleanup_enabled = false; + let _guard = crate::openhuman::local_ai::LOCAL_AI_TEST_MUTEX + .lock() + .expect("local ai test mutex poisoned"); + let service = local_ai::global(&config); + let previous = service.status.lock().state.clone(); + service.status.lock().state = "not_ready".into(); let result = rt.block_on(cleanup_transcription(&config, "um hello uh world", None)); + service.status.lock().state = previous; assert_eq!(result, "um hello uh world"); } } diff --git a/tests/agent_builder_public.rs b/tests/agent_builder_public.rs new file mode 100644 index 000000000..8bb704fef --- /dev/null +++ b/tests/agent_builder_public.rs @@ -0,0 +1,202 @@ +use anyhow::Result; +use async_trait::async_trait; +use openhuman_core::openhuman::agent::dispatcher::XmlToolDispatcher; +use openhuman_core::openhuman::agent::Agent; +use openhuman_core::openhuman::context::prompt::SystemPromptBuilder; +use openhuman_core::openhuman::memory::{Memory, MemoryCategory, 
MemoryEntry}; +use openhuman_core::openhuman::providers::{ChatRequest, ChatResponse, Provider}; +use openhuman_core::openhuman::tools::{Tool, ToolResult}; +use std::collections::HashSet; +use std::sync::Arc; + +struct StubProvider; + +#[async_trait] +impl Provider for StubProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("ok".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + Ok(ChatResponse { + text: Some("ok".into()), + tool_calls: Vec::new(), + usage: None, + }) + } +} + +struct StubTool(&'static str); + +#[async_trait] +impl Tool for StubTool { + fn name(&self) -> &str { + self.0 + } + + fn description(&self) -> &str { + "stub tool" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> Result { + Ok(ToolResult::success(args.to_string())) + } +} + +struct StubMemory; + +#[async_trait] +impl Memory for StubMemory { + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> Result { + Ok(false) + } + + async fn count(&self) -> Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + + fn name(&self) -> &str { + "stub" + } +} + +fn base_builder() -> openhuman_core::openhuman::agent::AgentBuilder { + Agent::builder() + .provider(Box::new(StubProvider)) + 
.tools(vec![ + Box::new(StubTool("alpha")), + Box::new(StubTool("beta")), + ]) + .memory(Arc::new(StubMemory)) + .tool_dispatcher(Box::new(XmlToolDispatcher)) +} + +#[test] +fn builder_validates_required_fields() { + let err = Agent::builder() + .build() + .err() + .expect("missing tools should error"); + assert!(err.to_string().contains("tools are required")); + + let err = Agent::builder() + .tools(vec![Box::new(StubTool("alpha"))]) + .build() + .err() + .expect("missing provider should error"); + assert!(err.to_string().contains("provider is required")); + + let err = Agent::builder() + .provider(Box::new(StubProvider)) + .tools(vec![Box::new(StubTool("alpha"))]) + .build() + .err() + .expect("missing memory should error"); + assert!(err.to_string().contains("memory is required")); + + let err = Agent::builder() + .provider(Box::new(StubProvider)) + .tools(vec![Box::new(StubTool("alpha"))]) + .memory(Arc::new(StubMemory)) + .build() + .err() + .expect("missing dispatcher should error"); + assert!(err.to_string().contains("tool_dispatcher is required")); +} + +#[test] +fn builder_applies_defaults_and_exposes_public_accessors() { + let agent = base_builder() + .build() + .expect("minimal builder should succeed"); + + assert_eq!(agent.tools().len(), 2); + assert_eq!(agent.tool_specs().len(), 2); + assert_eq!( + agent.model_name(), + openhuman_core::openhuman::config::DEFAULT_MODEL + ); + assert_eq!(agent.temperature(), 0.7); + assert_eq!(agent.workspace_dir(), std::path::Path::new(".")); + assert!(agent.skills().is_empty()); + assert!(agent.history().is_empty()); + assert_eq!(agent.agent_config().max_tool_iterations, 10); +} + +#[test] +fn builder_filters_visible_tools_and_keeps_full_registry() { + let agent = base_builder() + .visible_tool_names(HashSet::from_iter(["beta".to_string()])) + .model_name("model-x".into()) + .temperature(0.4) + .workspace_dir(std::path::PathBuf::from("/tmp/agent-builder-visible")) + .prompt_builder(SystemPromptBuilder::with_defaults()) 
+ .event_context("session-9", "cli") + .agent_definition_name("orchestrator") + .build() + .expect("builder should succeed"); + + assert_eq!(agent.tools().len(), 2); + assert_eq!(agent.tool_specs().len(), 2); + assert_eq!(agent.model_name(), "model-x"); + assert_eq!(agent.temperature(), 0.4); + assert_eq!( + agent.workspace_dir(), + std::path::Path::new("/tmp/agent-builder-visible") + ); +} diff --git a/tests/agent_harness_public.rs b/tests/agent_harness_public.rs new file mode 100644 index 000000000..4d143eafe --- /dev/null +++ b/tests/agent_harness_public.rs @@ -0,0 +1,295 @@ +use anyhow::Result; +use async_trait::async_trait; +use openhuman_core::openhuman::agent::harness::{ + check_interrupt, current_fork, current_parent, with_fork_context, with_parent_context, + ForkContext, InterruptFence, ParentExecutionContext, +}; +use openhuman_core::openhuman::agent::hooks::{ + fire_hooks, sanitize_tool_output, PostTurnHook, ToolCallRecord, TurnContext, +}; +use openhuman_core::openhuman::config::AgentConfig; +use openhuman_core::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; +use openhuman_core::openhuman::providers::{ChatMessage, ChatRequest, ChatResponse, Provider}; +use parking_lot::Mutex; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use tokio::sync::Notify; + +struct StubProvider; + +#[async_trait] +impl Provider for StubProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("ok".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + Ok(ChatResponse { + text: Some("ok".into()), + tool_calls: Vec::new(), + usage: None, + }) + } +} + +struct StubMemory; + +#[async_trait] +impl Memory for StubMemory { + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> Result<()> { + Ok(()) + } + + async fn recall( + &self, + 
_query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> Result { + Ok(false) + } + + async fn count(&self) -> Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + + fn name(&self) -> &str { + "stub" + } +} + +fn sample_turn() -> TurnContext { + TurnContext { + user_message: "hello".into(), + assistant_response: "world".into(), + tool_calls: vec![ToolCallRecord { + name: "shell".into(), + arguments: serde_json::json!({}), + success: true, + output_summary: "ok".into(), + duration_ms: 10, + }], + turn_duration_ms: 15, + session_id: Some("s1".into()), + iteration_count: 1, + } +} + +fn stub_parent_context() -> ParentExecutionContext { + ParentExecutionContext { + provider: Arc::new(StubProvider), + all_tools: Arc::new(vec![]), + all_tool_specs: Arc::new(vec![]), + model_name: "stub-model".into(), + temperature: 0.4, + workspace_dir: std::path::PathBuf::from("/tmp"), + memory: Arc::new(StubMemory), + agent_config: AgentConfig::default(), + skills: Arc::new(vec![]), + memory_context: Some("ctx".into()), + session_id: "test-session".into(), + channel: "test-channel".into(), + } +} + +struct RecordingHook { + name: &'static str, + calls: Arc>>, + notify: Arc, + fail: bool, +} + +#[async_trait] +impl PostTurnHook for RecordingHook { + fn name(&self) -> &str { + self.name + } + + async fn on_turn_complete(&self, ctx: &TurnContext) -> Result<()> { + self.calls + .lock() + .push(format!("{}:{}", self.name, ctx.user_message)); + self.notify.notify_waiters(); + if self.fail { + anyhow::bail!("hook failed"); + } + Ok(()) + } +} + +#[test] +fn interrupt_fence_shares_and_resets_state() { + let fence = InterruptFence::default(); + assert!(!fence.is_interrupted()); + 
assert!(check_interrupt(&fence).is_ok()); + + let clone = fence.clone(); + let raw = fence.flag_handle(); + fence.trigger(); + assert!(clone.is_interrupted()); + assert!(raw.load(Ordering::Relaxed)); + assert!(check_interrupt(&fence).is_err()); + + raw.store(false, Ordering::Relaxed); + fence.reset(); + assert!(!fence.is_interrupted()); +} + +#[tokio::test] +async fn interrupt_signal_handler_is_installable() { + let fence = InterruptFence::new(); + fence.install_signal_handler(); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + assert!(!fence.is_interrupted()); +} + +#[tokio::test] +async fn fork_and_parent_contexts_are_visible_only_within_scope() { + assert!(current_parent().is_none()); + assert!(current_fork().is_none()); + + let fork = ForkContext { + system_prompt: Arc::new("hello".into()), + tool_specs: Arc::new(vec![]), + message_prefix: Arc::new(vec![ChatMessage::system("hello")]), + cache_boundary: Some(5), + fork_task_prompt: "do thing".into(), + }; + + with_fork_context(fork, async { + let inner = current_fork().expect("fork context should be visible"); + assert_eq!(*inner.system_prompt, "hello"); + assert_eq!(inner.fork_task_prompt, "do thing"); + assert_eq!(inner.cache_boundary, Some(5)); + assert_eq!(inner.message_prefix.len(), 1); + }) + .await; + + let parent = stub_parent_context(); + with_parent_context(parent, async { + let inner = current_parent().expect("parent context should be visible"); + assert_eq!(inner.model_name, "stub-model"); + assert_eq!(inner.session_id, "test-session"); + assert_eq!(inner.channel, "test-channel"); + assert_eq!(inner.memory_context.as_deref(), Some("ctx")); + }) + .await; + + assert!(current_parent().is_none()); + assert!(current_fork().is_none()); +} + +#[test] +fn sanitize_tool_output_classifies_common_errors() { + assert_eq!( + sanitize_tool_output("fine", "shell", true), + "shell: ok (4 chars)" + ); + assert_eq!( + sanitize_tool_output("Connection timeout while fetching", "http_request", false), 
+ "http_request: failed (timeout)" + ); + assert_eq!( + sanitize_tool_output("permission denied opening file", "file_read", false), + "file_read: failed (permission_denied)" + ); + assert_eq!( + sanitize_tool_output("unknown tool called", "delegate", false), + "delegate: failed (unknown_tool)" + ); + assert_eq!( + sanitize_tool_output("bad syntax in payload", "json", false), + "json: failed (parse_error)" + ); + assert_eq!( + sanitize_tool_output("no such file or directory", "file_read", false), + "file_read: failed (not_found)" + ); + assert_eq!( + sanitize_tool_output("network connection reset by peer", "http_request", false), + "http_request: failed (connection_error)" + ); + assert_eq!( + sanitize_tool_output("something strange happened", "shell", false), + "shell: failed (error)" + ); +} + +#[tokio::test] +async fn fire_hooks_dispatches_all_hooks_even_when_one_fails() { + let calls = Arc::new(Mutex::new(Vec::new())); + let notify = Arc::new(Notify::new()); + let hooks: Vec> = vec![ + Arc::new(RecordingHook { + name: "ok", + calls: Arc::clone(&calls), + notify: Arc::clone(¬ify), + fail: false, + }), + Arc::new(RecordingHook { + name: "fail", + calls: Arc::clone(&calls), + notify: Arc::clone(¬ify), + fail: true, + }), + ]; + + fire_hooks(&hooks, sample_turn()); + + tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + if calls.lock().len() == 2 { + break; + } + notify.notified().await; + } + }) + .await + .expect("hooks should complete"); + + let calls = calls.lock().clone(); + assert!(calls.contains(&"ok:hello".into())); + assert!(calls.contains(&"fail:hello".into())); +} + +#[test] +fn fire_hooks_accepts_empty_hook_lists() { + fire_hooks(&[], sample_turn()); +} diff --git a/tests/agent_memory_loader_public.rs b/tests/agent_memory_loader_public.rs new file mode 100644 index 000000000..4706c5476 --- /dev/null +++ b/tests/agent_memory_loader_public.rs @@ -0,0 +1,144 @@ +use anyhow::Result; +use async_trait::async_trait; +use 
openhuman_core::openhuman::agent::memory_loader::{DefaultMemoryLoader, MemoryLoader}; +use openhuman_core::openhuman::memory::{Memory, MemoryCategory, MemoryEntry}; +use std::sync::Arc; + +struct ScriptedMemory { + primary: Vec, + working: Vec, +} + +#[async_trait] +impl Memory for ScriptedMemory { + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> Result<()> { + Ok(()) + } + + async fn recall( + &self, + query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> Result> { + if query.contains("working.user") { + Ok(self.working.clone()) + } else { + Ok(self.primary.clone()) + } + } + + async fn get(&self, _key: &str) -> Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> Result { + Ok(false) + } + + async fn count(&self) -> Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + + fn name(&self) -> &str { + "scripted" + } +} + +fn entry(key: &str, content: &str, score: Option) -> MemoryEntry { + MemoryEntry { + id: key.into(), + key: key.into(), + content: content.into(), + namespace: None, + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score, + } +} + +#[tokio::test] +async fn loader_merges_primary_and_working_memory_with_filters() -> Result<()> { + let memory: Arc = Arc::new(ScriptedMemory { + primary: vec![ + entry("high", "keep me", Some(0.9)), + entry("low", "drop me", Some(0.1)), + ], + working: vec![ + entry("working.user.pref", "concise", Some(0.95)), + entry("high", "duplicate", Some(0.95)), + ], + }); + + let context = DefaultMemoryLoader::new(5, 0.4) + .with_max_chars(200) + .load_context(memory.as_ref(), "hello") + .await?; + + assert!(context.contains("[Memory context]")); + assert!(context.contains("- high: keep me")); + assert!(!context.contains("drop me")); + 
assert!(context.contains("[User working memory]")); + assert!(context.contains("working.user.pref")); + assert!(!context.contains("duplicate")); + Ok(()) +} + +#[tokio::test] +async fn loader_can_return_only_working_memory_when_primary_is_empty() -> Result<()> { + let memory: Arc = Arc::new(ScriptedMemory { + primary: Vec::new(), + working: vec![entry("working.user.todo", "ship it", None)], + }); + + let context = DefaultMemoryLoader::default() + .load_context(memory.as_ref(), "hello") + .await?; + + assert!(!context.contains("[Memory context]")); + assert!(context.contains("[User working memory]")); + assert!(context.contains("working.user.todo")); + Ok(()) +} + +#[tokio::test] +async fn loader_respects_tight_budgets() -> Result<()> { + let memory: Arc = Arc::new(ScriptedMemory { + primary: vec![entry("main", "1234567890", Some(0.9))], + working: vec![entry("working.user.tip", "include me", Some(0.9))], + }); + + let header_len = "[Memory context]\n".len(); + let empty = DefaultMemoryLoader::new(1, 0.4) + .with_max_chars(header_len) + .load_context(memory.as_ref(), "hello") + .await?; + assert!(empty.is_empty()); + + let bounded = DefaultMemoryLoader::new(1, 0.4) + .with_max_chars("[Memory context]\n- main: 1234567890\n".len() + 1) + .load_context(memory.as_ref(), "hello") + .await?; + assert!(bounded.contains("- main: 1234567890")); + assert!(!bounded.contains("working.user.tip")); + Ok(()) +} diff --git a/tests/agent_multimodal_public.rs b/tests/agent_multimodal_public.rs new file mode 100644 index 000000000..7d67da870 --- /dev/null +++ b/tests/agent_multimodal_public.rs @@ -0,0 +1,112 @@ +use anyhow::Result; +use openhuman_core::openhuman::agent::multimodal::{ + contains_image_markers, count_image_markers, extract_ollama_image_payload, parse_image_markers, + prepare_messages_for_provider, +}; +use openhuman_core::openhuman::config::MultimodalConfig; +use openhuman_core::openhuman::providers::ChatMessage; + +#[test] +fn 
marker_helpers_cover_mixed_content_and_payload_extraction() { + let messages = vec![ + ChatMessage::assistant("[IMAGE:/tmp/ignored.png]"), + ChatMessage::user("look [IMAGE:/tmp/a.png] then [IMAGE:data:image/png;base64,abcd]"), + ]; + + let (cleaned, refs) = parse_image_markers(messages[1].content.as_str()); + assert_eq!(cleaned, "look then"); + assert_eq!(refs.len(), 2); + assert_eq!(count_image_markers(&messages), 2); + assert!(contains_image_markers(&messages)); + assert_eq!( + extract_ollama_image_payload("data:image/png;base64,abcd").as_deref(), + Some("abcd") + ); + assert_eq!( + extract_ollama_image_payload(" /tmp/a.png ").as_deref(), + Some("/tmp/a.png") + ); + let (cleaned_unclosed, refs_unclosed) = parse_image_markers("broken [IMAGE:/tmp/a.png"); + assert_eq!(cleaned_unclosed, "broken [IMAGE:/tmp/a.png"); + assert!(refs_unclosed.is_empty()); + + let (cleaned_empty, refs_empty) = parse_image_markers("keep [IMAGE:] literal"); + assert_eq!(cleaned_empty, "keep [IMAGE:] literal"); + assert!(refs_empty.is_empty()); + + assert!(!contains_image_markers(&[ChatMessage::assistant( + "no user refs" + )])); +} + +#[tokio::test] +async fn prepare_messages_passthrough_when_no_user_images_exist() -> Result<()> { + let messages = vec![ + ChatMessage::system("sys"), + ChatMessage::assistant("[IMAGE:/tmp/not-counted.png]"), + ChatMessage::user("plain text"), + ]; + + let prepared = prepare_messages_for_provider(&messages, &MultimodalConfig::default()).await?; + assert!(!prepared.contains_images); + assert_eq!(prepared.messages.len(), 3); + assert_eq!(prepared.messages[2].content, "plain text"); + Ok(()) +} + +#[tokio::test] +async fn prepare_messages_accepts_data_uris_and_preserves_other_messages() -> Result<()> { + let messages = vec![ + ChatMessage::assistant("already there"), + ChatMessage::user("inspect [IMAGE:data:image/PNG;base64,iVBORw0KGgo=]"), + ]; + + let prepared = prepare_messages_for_provider(&messages, &MultimodalConfig::default()).await?; + 
assert!(prepared.contains_images); + assert_eq!(prepared.messages[0].content, "already there"); + + let (cleaned, refs) = parse_image_markers(&prepared.messages[1].content); + assert_eq!(cleaned, "inspect"); + assert_eq!(refs.len(), 1); + assert!(refs[0].starts_with("data:image/png;base64,")); + Ok(()) +} + +#[tokio::test] +async fn prepare_messages_rejects_invalid_data_uri_forms() { + let invalid_non_base64 = vec![ChatMessage::user("bad [IMAGE:data:image/png,abcd]")]; + let err = prepare_messages_for_provider(&invalid_non_base64, &MultimodalConfig::default()) + .await + .expect_err("non-base64 data uri should fail"); + assert!(err + .to_string() + .contains("only base64 data URIs are supported")); + + let invalid_mime = vec![ChatMessage::user("bad [IMAGE:data:text/plain;base64,YQ==]")]; + let err = prepare_messages_for_provider(&invalid_mime, &MultimodalConfig::default()) + .await + .expect_err("unsupported mime should fail"); + assert!(err.to_string().contains("MIME type is not allowed")); + + let invalid_base64 = vec![ChatMessage::user("bad [IMAGE:data:image/png;base64,%%%]")]; + let err = prepare_messages_for_provider(&invalid_base64, &MultimodalConfig::default()) + .await + .expect_err("invalid base64 should fail"); + assert!(err.to_string().contains("invalid base64 payload")); +} + +#[tokio::test] +async fn prepare_messages_rejects_unknown_local_mime() { + let temp = tempfile::tempdir().expect("tempdir"); + let file_path = temp.path().join("sample.txt"); + std::fs::write(&file_path, b"not an image").expect("write sample"); + + let messages = vec![ChatMessage::user(format!( + "bad [IMAGE:{}]", + file_path.display() + ))]; + let err = prepare_messages_for_provider(&messages, &MultimodalConfig::default()) + .await + .expect_err("unknown mime should fail"); + assert!(err.to_string().contains("unknown")); +}