From 894b07eab09bb11af1a726236a013182c372a5ab Mon Sep 17 00:00:00 2001 From: Jordan Hafer <42755763+jjhafer@users.noreply.github.com> Date: Tue, 17 Mar 2026 21:25:29 -0400 Subject: [PATCH] feat: add Rust API to create requests Consumers would like to create requests from the Rust side of Tauri apps. --- README.md | 62 ++ examples/tauri-app/src-tauri/Cargo.toml | 1 + examples/tauri-app/src-tauri/src/lib.rs | 59 ++ examples/tauri-app/src-tauri/tauri.conf.json | 2 + examples/tauri-app/src/index.html | 15 + examples/tauri-app/src/main.js | 32 + src/client.rs | 633 ++++++++++++++----- src/config.rs | 13 + src/error.rs | 2 +- src/lib.rs | 4 + src/request.rs | 298 +++++++++ src/response.rs | 207 ++++++ 12 files changed, 1183 insertions(+), 145 deletions(-) create mode 100644 src/request.rs create mode 100644 src/response.rs diff --git a/README.md b/README.md index db4977a..841761d 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ requests from Tauri applications. support * Binary request and response bodies * Runtime allowlist management from Rust + * Rust backend API -- make HTTP requests from Rust + through the same security pipeline as the frontend ## Installation @@ -268,9 +270,69 @@ try { See [`HttpErrorCode`](guest-js/errors.ts) for the full list of error codes and descriptions. +### Rust Backend Requests + +The plugin exposes a Rust API for making HTTP requests +from backend code through the same security pipeline +(domain allowlist, private IP blocking, redirect +validation, body size limits, retry) as the frontend. + +```rust +use tauri::Manager; +use tauri_plugin_http_client::HttpClientExt; + +#[tauri::command] +async fn fetch_data( + app: tauri::AppHandle, +) -> Result { + let resp = app.http_client() + .get("https://api.example.com/data") + .header("Accept", "application/json") + .timeout(std::time::Duration::from_secs(10)) + .send() + .await + .map_err(|e| e.to_string())?; + + resp.text() + .map(|s| s.to_string()) + .map_err(|e| e.to_string()) +} +``` + +> `send()` returns +> `tauri_plugin_http_client::error::Error`, which +> provides `is_retryable()` for retry decisions and can +> be matched on specific variants (e.g., +> `Error::DomainNotAllowed`). Response body methods like +> `text()` return standard library errors. + +Available builder methods: + + * `get(url)` / `post(url)` -- convenience starters + * `request(method, url)` -- arbitrary HTTP method + * `.header(key, val)` -- add a header (repeatable) + * `.body(bytes)` -- set the request body + * `.timeout(duration)` -- per-request timeout + * `.max_retries(n)` -- per-request retry cap + * `.send()` -- execute through the security pipeline + +The response provides native `reqwest` types: + + * `status()` -- `reqwest::StatusCode` + * `headers()` -- `&reqwest::header::HeaderMap` + * `url()` -- `&url::Url` (final URL after redirects) + * `redirected()` -- `bool` + * `body()` / `into_body()` -- `&[u8]` / `Vec` + * `text()` -- `Result<&str, std::str::Utf8Error>` + * `retry_count()` -- number of retries performed + ## Security +Both the TypeScript frontend and Rust backend API share +the same security pipeline. All protections below apply +equally to both paths. + ### Domain Allowlist The allowlist has two tiers: diff --git a/examples/tauri-app/src-tauri/Cargo.toml b/examples/tauri-app/src-tauri/Cargo.toml index af3638c..ce8d064 100644 --- a/examples/tauri-app/src-tauri/Cargo.toml +++ b/examples/tauri-app/src-tauri/Cargo.toml @@ -9,6 +9,7 @@ name = "tauri_app_lib" crate-type = ["staticlib", "cdylib", "rlib"] [dependencies] +serde = { version = "=1.0.228", features = ["derive"] } tauri = { version = "=2.10.3", features = [] } tauri-plugin-http-client = { path = "../../../" } diff --git a/examples/tauri-app/src-tauri/src/lib.rs b/examples/tauri-app/src-tauri/src/lib.rs index df7961a..4829a26 100644 --- a/examples/tauri-app/src-tauri/src/lib.rs +++ b/examples/tauri-app/src-tauri/src/lib.rs @@ -1,5 +1,63 @@ +use std::collections::HashMap; use std::time::Duration; +use serde::Serialize; +use tauri_plugin_http_client::HttpClientExt; + +/// Response shape returned to the frontend from the Rust backend request. +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct BackendResponse { + status: u16, + status_text: String, + headers: HashMap, + url: String, + redirected: bool, + body: String, +} + +/// Demonstrates the Rust backend API: makes an HTTP request through the +/// plugin's security pipeline from Rust, then returns the result to the +/// frontend via IPC. +#[tauri::command] +async fn fetch_from_rust( + app: tauri::AppHandle, + url: String, +) -> Result { + let resp = app + .http_client() + .get(&url) + .header("Accept", "application/json") + .header("X-Requested-From", "rust-backend") + .send() + .await + .map_err(|e| e.to_string())?; + + // Note: multi-value headers (e.g. Set-Cookie) are collapsed to last value. + let mut headers = HashMap::new(); + + for (name, value) in resp.headers() { + if let Ok(v) = value.to_str() { + headers.insert(name.as_str().to_string(), v.to_string()); + } + } + + let status_text = resp + .status() + .canonical_reason() + .unwrap_or("") + .to_string(); + + Ok(BackendResponse { + status: resp.status().as_u16(), + status_text, + headers, + url: resp.url().to_string(), + redirected: resp.redirected(), + body: resp.text().map(|s| s.to_string()).unwrap_or_default(), + }) +} + #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { tauri::Builder::default() @@ -10,6 +68,7 @@ pub fn run() { .max_response_body_size(5 * 1024 * 1024) .build(), ) + .invoke_handler(tauri::generate_handler![fetch_from_rust]) .run(tauri::generate_context!()) .expect("error while running tauri application"); } diff --git a/examples/tauri-app/src-tauri/tauri.conf.json b/examples/tauri-app/src-tauri/tauri.conf.json index 3827709..64635e8 100644 --- a/examples/tauri-app/src-tauri/tauri.conf.json +++ b/examples/tauri-app/src-tauri/tauri.conf.json @@ -3,6 +3,8 @@ "version": "0.1.0", "identifier": "com.silvermine.httpClientExample", "build": { + "beforeDevCommand": "npm run dev", + "beforeBuildCommand": "npm run build", "devUrl": "http://localhost:5173", "frontendDist": "../dist" }, diff --git a/examples/tauri-app/src/index.html b/examples/tauri-app/src/index.html index ff2fd0e..41aa57f 100644 --- a/examples/tauri-app/src/index.html +++ b/examples/tauri-app/src/index.html @@ -88,6 +88,21 @@

Binary Response

+ +
+

Rust Backend Request

+

+ Invokes a Tauri command that makes an HTTP request from Rust using + app.http_client().get(url).send() — the same security + pipeline as the frontend, but from backend code. +

+
+ + +
+
Response will appear here...
+
+ diff --git a/examples/tauri-app/src/main.js b/examples/tauri-app/src/main.js index 6726fd5..808eab9 100644 --- a/examples/tauri-app/src/main.js +++ b/examples/tauri-app/src/main.js @@ -1,4 +1,5 @@ import { request, HttpHeaders, HttpClientError } from '@silvermine/tauri-plugin-http-client'; +import { invoke } from '@tauri-apps/api/core'; // --- Helpers --- @@ -292,3 +293,34 @@ $('binary-fetch').addEventListener('click', async function() { btn.disabled = false; } }); + +// --- Rust Backend Request --- + +$('rust-send').addEventListener('click', async function() { + const output = $('rust-output'), + btn = $('rust-send'); + + btn.disabled = true; + setLoading(output, 'Sending request via Rust backend'); + + try { + const resp = await invoke('fetch_from_rust', { url: $('rust-url').value }); + + setResult(output, JSON.stringify( + { + status: resp.status, + statusText: resp.statusText, + headers: resp.headers, + url: resp.url, + redirected: resp.redirected, + body: tryParseJSON(resp.body), + }, + null, + 2, + )); + } catch(err) { + setError(output, err); + } finally { + btn.disabled = false; + } +}); diff --git a/src/client.rs b/src/client.rs index 7cff2a9..70ec27f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,6 +7,8 @@ use futures_util::StreamExt; use parking_lot::RwLock; use reqwest::redirect; +use reqwest::header::HeaderMap; + use crate::allowlist::{DomainAllowlist, is_private_ip}; use crate::config::{HttpClientConfig, RetryConfig}; use crate::error::{Error, Result}; @@ -41,6 +43,62 @@ const FORBIDDEN_HEADERS: &[&str] = &[ /// These affect proxy routing in ways outside the plugin's security model. const FORBIDDEN_HEADER_PREFIXES: &[&str] = &["sec-", "proxy-"]; +/// Internal response type shared between the IPC and backend paths. +/// +/// Carries native `reqwest` types so that each path can convert to its own +/// public type without intermediate re-encoding. +struct RawResponse { + status: reqwest::StatusCode, + headers: HeaderMap, + url: url::Url, + redirected: bool, + body: Vec, +} + +impl RawResponse { + /// Converts to the IPC-oriented [`ExecuteResult`] (string headers, + /// numeric status). + fn into_execute_result(self, retry_count: u32) -> ExecuteResult { + let mut response_headers: HashMap> = HashMap::new(); + + for (name, value) in &self.headers { + let name = name.as_str().to_string(); + + if let Ok(v) = value.to_str() { + response_headers + .entry(name) + .or_default() + .push(v.to_string()); + } + } + + ExecuteResult { + metadata: FetchResponseMetadata { + status: self.status.as_u16(), + status_text: self.status.canonical_reason().unwrap_or("").to_string(), + headers: response_headers, + url: self.url.to_string(), + redirected: self.redirected, + retry_count, + }, + body: self.body, + } + } + + /// Converts to the backend-oriented [`Response`](crate::response::Response) + /// (native `reqwest` types). + fn into_response(self, retry_count: u32) -> crate::response::Response { + crate::response::Response::new( + self.status, + self.headers, + self.url, + self.redirected, + self.body, + retry_count, + ) + } +} + /// Validates that a header name is not in the forbidden list. /// /// Returns `Err(Error::ForbiddenHeader)` if the name matches any entry in @@ -262,8 +320,7 @@ impl HttpClientState { /// 2. Build the reqwest request (method, headers, body, timeout) /// 3. Execute with custom redirect policy /// 4. Read response body with size limit enforcement - /// 5. Detect text vs binary content and encode accordingly - /// 6. Return structured response + /// 5. Return structured response /// /// # Retry Behavior /// @@ -293,17 +350,98 @@ impl HttpClientState { .timeout_ms .map(Duration::from_millis) .or(self.config.default_timeout); - let max_retries = self.resolve_max_retries(&req); - let max_attempts = max_retries + 1; - let retry_config = &self.config.retry; - let method_retryable = retry_config.is_retryable_method(method.as_str()); // Parse and validate URL once before the retry loop. The URL string // doesn't change between retries, so re-parsing is unnecessary. let url = self.allowlist.read().validate_url(&req.url)?; - let mut last_result: Option> = None; + // Build a merged HeaderMap: default headers first, then per-request + // headers (which override defaults). Validates forbidden headers. + let merged_headers = self.build_ipc_header_map(&req.headers)?; + + let (raw, attempt) = self + .execute_with_retry( + &url, + &method, + &merged_headers, + body_bytes.as_deref(), + timeout, + max_retries, + ) + .await?; + + Ok(raw.into_execute_result(attempt)) + } + + /// Executes an HTTP request from Rust backend code, returning a native + /// [`Response`](crate::response::Response) with `reqwest` types. + /// + /// This is the backend counterpart to [`execute`](Self::execute) (which + /// serves the IPC/frontend path). It follows the same security pipeline: + /// domain allowlist, private IP blocking, redirect validation, streaming + /// body limits, and retry. + /// + /// Called by [`RequestBuilder::send`](crate::request::RequestBuilder::send). + pub(crate) async fn execute_backend( + &self, + url: &str, + method: reqwest::Method, + headers: &[(String, String)], + body: Option<&[u8]>, + timeout: Option, + max_retries: Option, + ) -> Result { + let timeout = timeout.or(self.config.default_timeout); + let max_retries = match max_retries { + Some(n) => n.min(self.config.retry.max_retries), + None => self.config.retry.max_retries, + }; + + // Build a merged HeaderMap: default headers first, then per-request + // headers (which override defaults). Validates forbidden headers. + let merged_headers = self.build_backend_header_map(headers)?; + + let url = self.validate_url_for_request(url)?; + + let (raw, attempt) = self + .execute_with_retry(&url, &method, &merged_headers, body, timeout, max_retries) + .await?; + + Ok(raw.into_response(attempt)) + } + + /// Shared retry loop used by both [`execute`](Self::execute) (IPC path) + /// and [`execute_backend`](Self::execute_backend) (Rust backend path). + /// + /// Expects the URL to be pre-validated and headers pre-merged by the + /// caller. Returns the final [`RawResponse`] and the attempt count + /// (0 = succeeded on first try, 1+ = number of retries performed). + /// + /// # Retry Behavior + /// + /// When retry is enabled (`max_retries > 0`), transient errors + /// (connection failures, timeouts) and retryable status codes trigger + /// automatic retries with exponential backoff and jitter. Security + /// errors are never retried. + /// + /// The URL is re-validated against the allowlist on every retry attempt. + /// If the allowlist changes between retries (e.g., a domain is removed), + /// the subsequent attempt fails with `DomainNotAllowed` (fail-secure). + async fn execute_with_retry( + &self, + url: &url::Url, + method: &reqwest::Method, + headers: &HeaderMap, + body: Option<&[u8]>, + timeout: Option, + max_retries: u32, + ) -> Result<(RawResponse, u32)> { + let max_attempts = max_retries + 1; + let retry_config = &self.config.retry; + let method_retryable = retry_config.is_retryable_method(method.as_str()); + + let mut last_result: Option> = None; let mut attempt: u32 = 0; // Bounded: returns when attempt + 1 >= max_attempts (should_retry = false) @@ -316,7 +454,7 @@ impl HttpClientState { tracing::error!( attempt, max_attempts, - url = %req.url, + url = %url, "retry loop exceeded max_attempts; this is a bug" ); return Err(Error::Other(format!( @@ -331,7 +469,7 @@ impl HttpClientState { attempt, max_attempts, backoff_ms = backoff.as_millis() as u64, - url = %req.url, + url = %url, "retrying request" ); @@ -342,18 +480,16 @@ impl HttpClientState { // retries must cause immediate failure (fail-secure). We use // validate_parsed_url (not validate_url) to avoid redundant // string parsing since the URL itself hasn't changed. - self.allowlist.read().validate_parsed_url(&url)?; + self.revalidate_parsed_url(url)?; } - let result = self - .execute_once(&url, &method, &req.headers, body_bytes.as_deref(), timeout) - .await; + let result = self.execute_once(url, method, headers, body, timeout).await; let should_retry = attempt + 1 < max_attempts && method_retryable; match result { Ok(ref resp) - if should_retry && retry_config.is_retryable_status(resp.metadata.status) => + if should_retry && retry_config.is_retryable_status(resp.status.as_u16()) => { last_result = Some(result); attempt += 1; @@ -362,42 +498,98 @@ impl HttpClientState { last_result = Some(Err(e)); attempt += 1; } - Ok(mut resp) => { - resp.metadata.retry_count = attempt; - return Ok(resp); - } + Ok(raw) => return Ok((raw, attempt)), Err(e) => return Err(e), } } } + /// Validates a URL string against the domain allowlist, returning a parsed + /// [`Url`](url::Url). + /// + /// In `#[cfg(test)]` builds with `skip_url_validation` enabled, performs + /// only basic URL parsing without allowlist or scheme checks. + fn validate_url_for_request(&self, url: &str) -> Result { + #[cfg(test)] + if self.config.skip_url_validation { + return url::Url::parse(url).map_err(|e| Error::InvalidUrl(e.to_string())); + } + + self.allowlist.read().validate_url(url) + } + + /// Re-validates an already-parsed URL against the domain allowlist. + /// + /// Used on retry attempts to enforce fail-secure behavior when the + /// allowlist changes between retries. Skips validation in + /// `#[cfg(test)]` builds with `skip_url_validation` enabled. + fn revalidate_parsed_url(&self, url: &url::Url) -> Result<()> { + #[cfg(test)] + if self.config.skip_url_validation { + return Ok(()); + } + + self.allowlist.read().validate_parsed_url(url) + } + + /// Builds a merged [`HeaderMap`] from the plugin's default headers and + /// per-request backend headers (slice of tuples). + /// + /// Default headers are applied first; per-request headers override them. + /// All header names are validated against the forbidden-header list. + fn build_backend_header_map(&self, request_headers: &[(String, String)]) -> Result { + merge_headers( + &self.config.default_headers, + request_headers + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())), + ) + } + + /// Starts building a GET request through the plugin's security pipeline. + /// + /// Returns a [`RequestBuilder`](crate::request::RequestBuilder) that can + /// be customized with headers, timeout, and retry settings before sending. + pub fn get(&self, url: impl Into) -> crate::request::RequestBuilder<'_> { + crate::request::RequestBuilder::new(self, reqwest::Method::GET, url.into()) + } + + /// Starts building a POST request through the plugin's security pipeline. + pub fn post(&self, url: impl Into) -> crate::request::RequestBuilder<'_> { + crate::request::RequestBuilder::new(self, reqwest::Method::POST, url.into()) + } + + /// Starts building a request with an arbitrary HTTP method. + pub fn request( + &self, + method: reqwest::Method, + url: impl Into, + ) -> crate::request::RequestBuilder<'_> { + crate::request::RequestBuilder::new(self, method, url.into()) + } + /// Executes a single HTTP request attempt through the full pipeline. /// /// This is the inner implementation called by [`execute`](Self::execute) - /// on each attempt. It assumes URL validation has already been performed. + /// and [`execute_backend`](Self::execute_backend) on each attempt. It + /// assumes URL validation has already been performed and headers have + /// been pre-validated and merged into a [`HeaderMap`]. /// - /// Returns raw body bytes and metadata. Encoding for IPC transfer - /// (binary framing or JSON with base64) happens at the command layer. + /// Returns a [`RawResponse`] carrying native `reqwest` types. Conversion + /// to the caller's public type (`ExecuteResult` for IPC, `Response` for + /// backend) happens after the retry loop. async fn execute_once( &self, url: &url::Url, method: &reqwest::Method, - headers: &Option>, + headers: &HeaderMap, body: Option<&[u8]>, timeout: Option, - ) -> Result { + ) -> Result { let mut builder = self.client.request(method.clone(), url.clone()); - for (key, value) in &self.config.default_headers { - builder = builder.header(key.as_str(), value.as_str()); - } - - if let Some(headers) = headers { - for (key, value) in headers { - validate_header_name(key)?; - builder = builder.header(key.as_str(), value.as_str()); - } - } + // Clone required: reqwest's builder.headers() takes HeaderMap by value. + builder = builder.headers(headers.clone()); if let Some(body) = body { builder = builder.body(body.to_vec()); @@ -436,33 +628,14 @@ impl HttpClientState { } let status = response.status(); - let status_text = status.canonical_reason().unwrap_or("").to_string(); - - // Collect response headers (multi-value support) - let mut response_headers: HashMap> = HashMap::new(); - - for (name, value) in response.headers() { - let name = name.as_str().to_string(); - - if let Ok(v) = value.to_str() { - response_headers - .entry(name) - .or_default() - .push(v.to_string()); - } - } - + let resp_headers = response.headers().clone(); let body_bytes = self.read_body_with_limit(response).await?; - Ok(ExecuteResult { - metadata: FetchResponseMetadata { - status: status.as_u16(), - status_text, - headers: response_headers, - url: final_url.to_string(), - redirected, - retry_count: 0, // Set by execute() after the loop - }, + Ok(RawResponse { + status, + headers: resp_headers, + url: final_url, + redirected, body: body_bytes, }) } @@ -514,6 +687,24 @@ impl HttpClientState { } } + /// Builds a merged [`HeaderMap`] from the plugin's default headers and + /// optional per-request IPC headers. + /// + /// Default headers are applied first; per-request headers override them. + /// All header names are validated against the forbidden-header list. + fn build_ipc_header_map( + &self, + request_headers: &Option>, + ) -> Result { + let empty = HashMap::new(); + let headers = request_headers.as_ref().unwrap_or(&empty); + + merge_headers( + &self.config.default_headers, + headers.iter().map(|(k, v)| (k.as_str(), v.as_str())), + ) + } + /// Aborts an in-flight request by ID. /// /// Returns `true` if a request with the given ID was found and aborted, @@ -535,6 +726,94 @@ impl HttpClientState { } } +/// Test-only constructor for `HttpClientState`. +/// +/// Bypasses domain allowlist validation and private IP checks, allowing +/// tests to use local mock servers on `127.0.0.1`. +/// +/// **Only available in `#[cfg(test)]` builds.** +/// +/// # Testing downstream crates +/// +/// `for_testing()` is `#[cfg(test)]`-gated and is not available to downstream +/// consumers. Downstream crates that use the backend API (`get`, `post`, +/// `request`) via [`tauri::State`] should test against a real +/// `HttpClientState` registered by the plugin. Set up the plugin with +/// `Builder::new().allowed_domains(["localhost"]).build()` and start your mock +/// server on `localhost` (or a DNS-resolvable hostname that is allowlisted). +/// +/// If your mock server binds to `127.0.0.1` rather than `localhost`, configure +/// it to also listen on `localhost` so the domain allowlist check passes. +#[cfg(test)] +impl HttpClientState { + /// Creates an `HttpClientState` suitable for unit/integration tests with + /// local mock servers. + /// + /// Skips all URL validation (scheme, IP-literal, domain allowlist) and + /// allows requests to private/loopback IP addresses. This enables tests + /// using any local mock server bound to `127.0.0.1`. + /// + /// # Behavioral differences from production + /// + /// - **No redirect validation**: Uses reqwest's default redirect policy + /// (follow up to 10 hops) instead of [`build_redirect_policy`] — redirects + /// are not validated against the domain allowlist. + /// - **No default header validation**: Default headers are not checked + /// against the forbidden-header list at construction time (they are still + /// validated per-request via [`merge_headers`]). + /// - **No custom redirect limit**: Uses reqwest's default (10) instead of + /// `HttpClientConfig::max_redirects`. + pub fn for_testing() -> Self { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(Vec::::new()).expect("empty allowlist is always valid"), + )); + let client = reqwest::Client::builder() + .build() + .expect("default reqwest client should build"); + let config = HttpClientConfig { + allow_private_ip: true, + #[cfg(test)] + skip_url_validation: true, + ..HttpClientConfig::default() + }; + + Self::new(client, allowlist, config) + } +} + +/// Builds a merged [`HeaderMap`] from default headers and per-request headers. +/// +/// Default headers are applied first; per-request headers override defaults +/// with the same name. All per-request header names are validated against the +/// forbidden-header list. +fn merge_headers<'a>( + defaults: &HashMap, + per_request: impl Iterator, +) -> Result { + let mut map = HeaderMap::new(); + + for (key, value) in defaults { + map.insert( + reqwest::header::HeaderName::from_bytes(key.as_bytes()) + .map_err(|e| Error::Other(format!("invalid default header name '{key}': {e}")))?, + reqwest::header::HeaderValue::from_str(value) + .map_err(|e| Error::Other(format!("invalid default header value for '{key}': {e}")))?, + ); + } + + for (key, value) in per_request { + validate_header_name(key)?; + map.insert( + reqwest::header::HeaderName::from_bytes(key.as_bytes()) + .map_err(|e| Error::Other(format!("invalid header name '{key}': {e}")))?, + reqwest::header::HeaderValue::from_str(value) + .map_err(|e| Error::Other(format!("invalid header value for '{key}': {e}")))?, + ); + } + + Ok(map) +} + fn parse_method(method: &str) -> Result { match method.to_uppercase().as_str() { "GET" => Ok(reqwest::Method::GET), @@ -568,7 +847,7 @@ fn decode_request_body(body: &str, encoding: Option<&BodyEncoding>) -> Result>, + last_result: Option<&Result>, ) -> Duration { // Check for Retry-After header on the last response if let Some(Ok(resp)) = last_result @@ -603,11 +882,16 @@ fn calculate_backoff( /// /// Supports both delta-seconds format (`Retry-After: 120`) and ignores /// HTTP-date format (too complex to parse without a date library). -fn parse_retry_after_from_response(resp: &ExecuteResult) -> Option { - let values = resp.metadata.headers.get("retry-after")?; - let value = values.first()?; - - value.trim().parse::().ok().map(Duration::from_secs) +fn parse_retry_after_from_response(resp: &RawResponse) -> Option { + let value = resp.headers.get("retry-after")?; + + value + .to_str() + .ok()? + .trim() + .parse::() + .ok() + .map(Duration::from_secs) } /// Builds a custom redirect policy that validates each redirect hop against the allowlist. @@ -683,6 +967,18 @@ impl std::error::Error for RedirectBlockedError {} mod tests { use super::*; + /// Helper to build a `RawResponse` for tests that need to exercise + /// `calculate_backoff` and `parse_retry_after_from_response`. + fn raw_response_with_headers(headers: HeaderMap) -> RawResponse { + RawResponse { + status: reqwest::StatusCode::TOO_MANY_REQUESTS, + headers, + url: url::Url::parse("https://example.com").unwrap(), + redirected: false, + body: Vec::new(), + } + } + #[test] fn test_parse_method() { assert_eq!(parse_method("GET").unwrap(), reqwest::Method::GET); @@ -1582,17 +1878,9 @@ mod tests { #[test] fn test_calculate_backoff_with_retry_after_header() { let config = RetryConfig::default(); - let resp = ExecuteResult { - metadata: FetchResponseMetadata { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["5".to_string()])]), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, - }, - body: Vec::new(), - }; + let mut headers = HeaderMap::new(); + headers.insert("retry-after", "5".parse().unwrap()); + let resp = raw_response_with_headers(headers); let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); @@ -1605,17 +1893,9 @@ mod tests { max_retry_after: Duration::from_secs(10), ..RetryConfig::default() }; - let resp = ExecuteResult { - metadata: FetchResponseMetadata { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["999".to_string()])]), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, - }, - body: Vec::new(), - }; + let mut headers = HeaderMap::new(); + headers.insert("retry-after", "999".parse().unwrap()); + let resp = raw_response_with_headers(headers); let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); @@ -1624,17 +1904,9 @@ mod tests { #[test] fn test_parse_retry_after_valid_seconds() { - let resp = ExecuteResult { - metadata: FetchResponseMetadata { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["120".to_string()])]), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, - }, - body: Vec::new(), - }; + let mut headers = HeaderMap::new(); + headers.insert("retry-after", "120".parse().unwrap()); + let resp = raw_response_with_headers(headers); assert_eq!( parse_retry_after_from_response(&resp), @@ -1644,54 +1916,28 @@ mod tests { #[test] fn test_parse_retry_after_missing_header() { - let resp = ExecuteResult { - metadata: FetchResponseMetadata { - status: 503, - status_text: "Service Unavailable".to_string(), - headers: HashMap::new(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, - }, - body: Vec::new(), - }; + let resp = raw_response_with_headers(HeaderMap::new()); assert_eq!(parse_retry_after_from_response(&resp), None); } #[test] fn test_parse_retry_after_non_numeric_ignored() { - let resp = ExecuteResult { - metadata: FetchResponseMetadata { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([( - "retry-after".to_string(), - vec!["Wed, 21 Oct 2025 07:28:00 GMT".to_string()], - )]), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, - }, - body: Vec::new(), - }; + let mut headers = HeaderMap::new(); + headers.insert( + "retry-after", + "Wed, 21 Oct 2025 07:28:00 GMT".parse().unwrap(), + ); + let resp = raw_response_with_headers(headers); assert_eq!(parse_retry_after_from_response(&resp), None); } #[test] fn test_parse_retry_after_zero_seconds() { - let resp = ExecuteResult { - metadata: FetchResponseMetadata { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["0".to_string()])]), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, - }, - body: Vec::new(), - }; + let mut headers = HeaderMap::new(); + headers.insert("retry-after", "0".parse().unwrap()); + let resp = raw_response_with_headers(headers); assert_eq!( parse_retry_after_from_response(&resp), @@ -1702,17 +1948,9 @@ mod tests { #[test] fn test_calculate_backoff_with_retry_after_zero() { let config = RetryConfig::default(); - let resp = ExecuteResult { - metadata: FetchResponseMetadata { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["0".to_string()])]), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, - }, - body: Vec::new(), - }; + let mut headers = HeaderMap::new(); + headers.insert("retry-after", "0".parse().unwrap()); + let resp = raw_response_with_headers(headers); let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); @@ -2661,4 +2899,111 @@ mod tests { // "sec" alone is not in FORBIDDEN_HEADERS and doesn't start with "sec-" assert!(validate_header_name("sec").is_ok()); } + + // --- merge_headers tests --- + + #[test] + fn test_merge_headers_empty_defaults_empty_request() { + let defaults = HashMap::new(); + let map = merge_headers(&defaults, std::iter::empty()).unwrap(); + + assert!(map.is_empty()); + } + + #[test] + fn test_merge_headers_defaults_only() { + let defaults = HashMap::from([("x-app".to_string(), "myapp".to_string())]); + let map = merge_headers(&defaults, std::iter::empty()).unwrap(); + + assert_eq!(map.get("x-app").unwrap(), "myapp"); + } + + #[test] + fn test_merge_headers_request_only() { + let defaults = HashMap::new(); + let map = merge_headers(&defaults, vec![("accept", "application/json")].into_iter()).unwrap(); + + assert_eq!(map.get("accept").unwrap(), "application/json"); + } + + #[test] + fn test_merge_headers_request_overrides_default() { + let defaults = HashMap::from([("accept".to_string(), "text/html".to_string())]); + let map = merge_headers(&defaults, vec![("accept", "application/json")].into_iter()).unwrap(); + + assert_eq!(map.get("accept").unwrap(), "application/json"); + } + + #[test] + fn test_merge_headers_forbidden_header_rejected() { + let defaults = HashMap::new(); + let result = merge_headers(&defaults, vec![("host", "evil.com")].into_iter()); + + assert!(result.is_err()); + } + + #[test] + fn test_merge_headers_defaults_not_validated_against_forbidden() { + // Default headers are validated at plugin init, not in merge_headers. + // merge_headers only validates per-request headers. + let defaults = HashMap::from([("x-custom".to_string(), "value".to_string())]); + let map = merge_headers(&defaults, std::iter::empty()).unwrap(); + + assert_eq!(map.get("x-custom").unwrap(), "value"); + } + + #[test] + fn test_merge_headers_multiple_request_headers() { + let defaults = HashMap::new(); + let map = merge_headers( + &defaults, + vec![ + ("accept", "application/json"), + ("authorization", "Bearer token"), + ] + .into_iter(), + ) + .unwrap(); + + assert_eq!(map.get("accept").unwrap(), "application/json"); + assert_eq!(map.get("authorization").unwrap(), "Bearer token"); + } + + // --- for_testing tests --- + + #[test] + fn test_for_testing_creates_valid_state() { + let state = HttpClientState::for_testing(); + + assert!(state.config.allow_private_ip); + assert!(state.config.skip_url_validation); + assert!(state.is_allowlist_empty()); + } + + // --- validate_url_for_request / revalidate_parsed_url tests --- + + #[test] + fn test_validate_url_for_request_skips_allowlist_in_test() { + let state = HttpClientState::for_testing(); + let url = state.validate_url_for_request("http://127.0.0.1:8080/test"); + + assert!(url.is_ok()); + assert_eq!(url.unwrap().as_str(), "http://127.0.0.1:8080/test"); + } + + #[test] + fn test_validate_url_for_request_rejects_invalid_url() { + let state = HttpClientState::for_testing(); + let result = state.validate_url_for_request("not a url"); + + assert!(result.is_err()); + } + + #[test] + fn test_revalidate_parsed_url_skips_in_test() { + let state = HttpClientState::for_testing(); + let url = url::Url::parse("http://127.0.0.1:8080/test").unwrap(); + + assert!(state.revalidate_parsed_url(&url).is_ok()); + } } diff --git a/src/config.rs b/src/config.rs index aebf58b..88c6c2b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -22,6 +22,16 @@ pub struct HttpClientConfig { /// **Default: `false`.** Do not enable in production — the DNS rebinding /// check is a defense-in-depth layer against SSRF. pub(crate) allow_private_ip: bool, + + /// Disables URL validation against the domain allowlist. + /// + /// When `true`, requests bypass `validate_url()` / `validate_parsed_url()` + /// entirely. Only intended for integration tests using local mock servers + /// where the URL contains an IP literal that the allowlist would reject. + /// + /// **Default: `false`.** Only available in `#[cfg(test)]` builds. + #[cfg(test)] + pub(crate) skip_url_validation: bool, } impl Default for HttpClientConfig { @@ -35,6 +45,9 @@ impl Default for HttpClientConfig { default_headers: HashMap::new(), retry: RetryConfig::disabled(), allow_private_ip: false, + + #[cfg(test)] + skip_url_validation: false, } } } diff --git a/src/error.rs b/src/error.rs index b939760..eae2aed 100644 --- a/src/error.rs +++ b/src/error.rs @@ -84,7 +84,7 @@ impl Error { /// Security errors (`DomainNotAllowed`, `IpAddressNotAllowed`, etc.) are /// never retryable — they indicate policy violations that will not change /// between attempts. - pub(crate) fn is_retryable(&self) -> bool { + pub fn is_retryable(&self) -> bool { matches!(self, Error::Request(e) if e.is_timeout() || e.is_connect()) } diff --git a/src/lib.rs b/src/lib.rs index 947f96a..ea2d145 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,8 @@ pub mod client; mod commands; pub mod config; pub mod error; +pub mod request; +pub mod response; pub mod types; use allowlist::DomainAllowlist; @@ -229,6 +231,8 @@ impl Builder { default_headers, retry, allow_private_ip: false, + #[cfg(test)] + skip_url_validation: false, }; let state = HttpClientState::new(client, allowlist, config); diff --git a/src/request.rs b/src/request.rs new file mode 100644 index 0000000..c014476 --- /dev/null +++ b/src/request.rs @@ -0,0 +1,298 @@ +//! Builder for backend (non-IPC) HTTP requests. +//! +//! Provides a fluent API for constructing requests that go through the full +//! plugin security pipeline (domain allowlist, private IP blocking, redirect +//! validation, streaming body limits, retry). +//! +//! Created by [`HttpClientState::get`](crate::client::HttpClientState::get), +//! [`HttpClientState::post`](crate::client::HttpClientState::post), or +//! [`HttpClientState::request`](crate::client::HttpClientState::request). +//! +//! # Examples +//! +//! ```no_run +//! use tauri_plugin_http_client::client::HttpClientState; +//! +//! async fn example(http: &HttpClientState) { +//! let response = http.get("https://api.example.com/data") +//! .header("Accept", "application/json") +//! .send() +//! .await +//! .unwrap(); +//! +//! let status = response.status(); +//! let body = response.text().unwrap(); +//! } +//! ``` + +use std::time::Duration; + +use crate::client::HttpClientState; +use crate::error::Result; +use crate::response::Response; + +/// Fluent builder for Rust-side HTTP requests through the plugin security pipeline. +/// +/// Created by [`HttpClientState::get`](crate::client::HttpClientState::get), +/// [`HttpClientState::post`](crate::client::HttpClientState::post), or +/// [`HttpClientState::request`](crate::client::HttpClientState::request). +/// Call [`send`](RequestBuilder::send) to execute the request. +/// +/// All requests go through the same security pipeline as IPC-initiated requests: +/// domain allowlist validation, private IP blocking, redirect policy enforcement, +/// streaming body limits, and configurable retry. +pub struct RequestBuilder<'a> { + state: &'a HttpClientState, + url: String, + method: reqwest::Method, + headers: Vec<(String, String)>, + body: Option>, + timeout: Option, + max_retries: Option, +} + +impl<'a> RequestBuilder<'a> { + pub(crate) fn new(state: &'a HttpClientState, method: reqwest::Method, url: String) -> Self { + Self { + state, + url, + method, + headers: Vec::new(), + body: None, + timeout: None, + max_retries: None, + } + } + + /// Adds a header to the request. + /// + /// Per-request headers override default headers configured at plugin init. + /// Forbidden headers (e.g., `Host`, `Connection`) are rejected at send time. + pub fn header(mut self, key: impl Into, val: impl Into) -> Self { + self.headers.push((key.into(), val.into())); + self + } + + /// Sets the request body. + pub fn body(mut self, body: impl Into>) -> Self { + self.body = Some(body.into()); + self + } + + /// Sets a per-request timeout, overriding the plugin's default timeout. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + /// Sets the maximum number of retries for this request. + /// + /// Capped at the plugin-level `RetryConfig::max_retries` ceiling. + /// `Some(0)` disables retry for this request. + pub fn max_retries(mut self, n: u32) -> Self { + self.max_retries = Some(n); + self + } + + /// Execute the request through the full security pipeline. + /// + /// Validates the URL against the domain allowlist, applies private IP + /// blocking, enforces redirect policy, streams the response body with + /// size limits, and retries on transient failures. + pub async fn send(self) -> Result { + self + .state + .execute_backend( + &self.url, + self.method, + &self.headers, + self.body.as_deref(), + self.timeout, + self.max_retries, + ) + .await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_state() -> HttpClientState { + HttpClientState::for_testing() + } + + #[test] + fn test_builder_defaults() { + let state = test_state(); + let builder = RequestBuilder::new(&state, reqwest::Method::GET, "https://example.com".into()); + + assert_eq!(builder.url, "https://example.com"); + assert_eq!(builder.method, reqwest::Method::GET); + assert!(builder.headers.is_empty()); + assert!(builder.body.is_none()); + assert!(builder.timeout.is_none()); + assert!(builder.max_retries.is_none()); + } + + #[test] + fn test_header_accumulates() { + let state = test_state(); + let builder = state + .get("https://example.com") + .header("Accept", "application/json") + .header("X-Custom", "value"); + + assert_eq!(builder.headers.len(), 2); + assert_eq!( + builder.headers[0], + ("Accept".to_string(), "application/json".to_string()) + ); + assert_eq!( + builder.headers[1], + ("X-Custom".to_string(), "value".to_string()) + ); + } + + #[test] + fn test_body_set() { + let state = test_state(); + let builder = state.post("https://example.com").body(b"payload".to_vec()); + + assert_eq!(builder.body.as_deref(), Some(b"payload".as_slice())); + } + + #[test] + fn test_body_from_string() { + let state = test_state(); + let builder = state + .post("https://example.com") + .body("text payload".as_bytes().to_vec()); + + assert_eq!(builder.body.as_deref(), Some(b"text payload".as_slice())); + } + + #[test] + fn test_timeout_set() { + let state = test_state(); + let builder = state + .get("https://example.com") + .timeout(Duration::from_secs(5)); + + assert_eq!(builder.timeout, Some(Duration::from_secs(5))); + } + + #[test] + fn test_max_retries_set() { + let state = test_state(); + let builder = state.get("https://example.com").max_retries(3); + + assert_eq!(builder.max_retries, Some(3)); + } + + #[test] + fn test_max_retries_zero_disables() { + let state = test_state(); + let builder = state.get("https://example.com").max_retries(0); + + assert_eq!(builder.max_retries, Some(0)); + } + + #[test] + fn test_method_preserved() { + let state = test_state(); + + let get = state.get("https://example.com"); + + assert_eq!(get.method, reqwest::Method::GET); + + let post = state.post("https://example.com"); + + assert_eq!(post.method, reqwest::Method::POST); + + let put = state.request(reqwest::Method::PUT, "https://example.com"); + + assert_eq!(put.method, reqwest::Method::PUT); + } + + #[test] + fn test_chaining() { + let state = test_state(); + let builder = state + .post("https://example.com") + .header("Content-Type", "application/json") + .body(b"{\"key\":\"value\"}".to_vec()) + .timeout(Duration::from_secs(30)) + .max_retries(2); + + assert_eq!(builder.url, "https://example.com"); + assert_eq!(builder.method, reqwest::Method::POST); + assert_eq!(builder.headers.len(), 1); + assert!(builder.body.is_some()); + assert_eq!(builder.timeout, Some(Duration::from_secs(30))); + assert_eq!(builder.max_retries, Some(2)); + } + + #[tokio::test] + async fn test_send_happy_path() { + let server = wiremock::MockServer::start().await; + + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/test")) + .respond_with(wiremock::ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = test_state(); + let resp = state + .get(format!("{}/test", server.uri())) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::OK); + assert_eq!(resp.text().unwrap(), "ok"); + } + + #[tokio::test] + async fn test_send_post_with_body_and_headers() { + let server = wiremock::MockServer::start().await; + + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path("/submit")) + .and(wiremock::matchers::header( + "content-type", + "application/json", + )) + .respond_with(wiremock::ResponseTemplate::new(201).set_body_string("created")) + .mount(&server) + .await; + + let state = test_state(); + let resp = state + .post(format!("{}/submit", server.uri())) + .header("content-type", "application/json") + .body(b"{\"name\":\"test\"}".to_vec()) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::CREATED); + assert_eq!(resp.text().unwrap(), "created"); + } + + #[tokio::test] + async fn test_send_forbidden_header_rejected() { + let state = test_state(); + let result = state + .get("http://127.0.0.1:1234") + .header("host", "evil.com") + .send() + .await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + + assert!(matches!(err, crate::error::Error::ForbiddenHeader(_))); + } +} diff --git a/src/response.rs b/src/response.rs new file mode 100644 index 0000000..a699495 --- /dev/null +++ b/src/response.rs @@ -0,0 +1,207 @@ +//! Native response type for backend (non-IPC) HTTP requests. + +/// Native Rust response type for backend (non-IPC) callers. +/// +/// Unlike `ExecuteResult` (the IPC-oriented type in `types.rs`), this type exposes +/// native `reqwest` types (`StatusCode`, `HeaderMap`, `Url`) rather than +/// IPC-serializable primitives. Created by +/// [`RequestBuilder::send`](crate::request::RequestBuilder::send). +#[derive(Debug)] +pub struct Response { + status: reqwest::StatusCode, + headers: reqwest::header::HeaderMap, + url: url::Url, + redirected: bool, + body: Vec, + retry_count: u32, +} + +impl Response { + pub(crate) fn new( + status: reqwest::StatusCode, + headers: reqwest::header::HeaderMap, + url: url::Url, + redirected: bool, + body: Vec, + retry_count: u32, + ) -> Self { + Self { + status, + headers, + url, + redirected, + body, + retry_count, + } + } + + /// Returns the HTTP status code. + pub fn status(&self) -> reqwest::StatusCode { + self.status + } + + /// Returns the response headers. + pub fn headers(&self) -> &reqwest::header::HeaderMap { + &self.headers + } + + /// Returns the final URL after any redirects. + pub fn url(&self) -> &url::Url { + &self.url + } + + /// Returns `true` if the response was the result of a redirect. + pub fn redirected(&self) -> bool { + self.redirected + } + + /// Returns the response body as a byte slice. + pub fn body(&self) -> &[u8] { + &self.body + } + + /// Consumes the response and returns the body bytes. + pub fn into_body(self) -> Vec { + self.body + } + + /// Returns the number of retry attempts before this response (0 = no retries). + pub fn retry_count(&self) -> u32 { + self.retry_count + } + + /// Convenience: decode body as UTF-8 string. + pub fn text(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(&self.body) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use reqwest::header::{HeaderMap, HeaderValue}; + + fn sample_response() -> Response { + let mut headers = HeaderMap::new(); + + headers.insert("content-type", HeaderValue::from_static("application/json")); + + Response::new( + reqwest::StatusCode::OK, + headers, + url::Url::parse("https://example.com/data").unwrap(), + false, + b"hello world".to_vec(), + 0, + ) + } + + #[test] + fn test_status() { + let resp = sample_response(); + + assert_eq!(resp.status(), reqwest::StatusCode::OK); + } + + #[test] + fn test_headers() { + let resp = sample_response(); + + assert_eq!( + resp.headers().get("content-type").unwrap(), + "application/json" + ); + } + + #[test] + fn test_url() { + let resp = sample_response(); + + assert_eq!(resp.url().as_str(), "https://example.com/data"); + } + + #[test] + fn test_redirected_false() { + let resp = sample_response(); + + assert!(!resp.redirected()); + } + + #[test] + fn test_redirected_true() { + let resp = Response::new( + reqwest::StatusCode::OK, + HeaderMap::new(), + url::Url::parse("https://example.com/final").unwrap(), + true, + Vec::new(), + 0, + ); + + assert!(resp.redirected()); + } + + #[test] + fn test_body() { + let resp = sample_response(); + + assert_eq!(resp.body(), b"hello world"); + } + + #[test] + fn test_into_body() { + let resp = sample_response(); + + assert_eq!(resp.into_body(), b"hello world"); + } + + #[test] + fn test_retry_count() { + let resp = Response::new( + reqwest::StatusCode::OK, + HeaderMap::new(), + url::Url::parse("https://example.com").unwrap(), + false, + Vec::new(), + 3, + ); + + assert_eq!(resp.retry_count(), 3); + } + + #[test] + fn test_text_valid_utf8() { + let resp = sample_response(); + + assert_eq!(resp.text().unwrap(), "hello world"); + } + + #[test] + fn test_text_invalid_utf8() { + let resp = Response::new( + reqwest::StatusCode::OK, + HeaderMap::new(), + url::Url::parse("https://example.com").unwrap(), + false, + vec![0xFF, 0xFE], + 0, + ); + + assert!(resp.text().is_err()); + } + + #[test] + fn test_empty_body() { + let resp = Response::new( + reqwest::StatusCode::NO_CONTENT, + HeaderMap::new(), + url::Url::parse("https://example.com").unwrap(), + false, + Vec::new(), + 0, + ); + + assert!(resp.body().is_empty()); + assert_eq!(resp.text().unwrap(), ""); + } +}