From 48e4ca19f96cd967701cbd8c8544d27a18911019 Mon Sep 17 00:00:00 2001 From: Jordan Hafer <42755763+jjhafer@users.noreply.github.com> Date: Thu, 12 Mar 2026 14:55:13 -0400 Subject: [PATCH 1/3] feat: add http request support Includes the following features: * A dynamic 'allowlist' which can be configured at initialization and during an app session. * Retry support. See README for more information. --- .gitignore | 3 + README.md | 427 +++ build.rs | 5 + guest-js/errors.test.ts | 123 + guest-js/errors.ts | 121 + guest-js/headers.test.ts | 197 ++ guest-js/headers.ts | 168 + guest-js/http-client.test.ts | 444 +++ guest-js/http-client.ts | 179 ++ guest-js/index.ts | 10 + guest-js/tsconfig.json | 17 + guest-js/types.ts | 56 + .../autogenerated/commands/abort_request.toml | 13 + permissions/autogenerated/commands/fetch.toml | 13 + permissions/autogenerated/reference.md | 70 + permissions/default.toml | 3 + permissions/schemas/schema.json | 330 ++ src/allowlist.rs | 1431 +++++++++ src/client.rs | 2740 +++++++++++++++++ src/commands.rs | 66 + src/config.rs | 305 ++ src/error.rs | 304 ++ src/lib.rs | 461 +++ src/types.rs | 154 + tsconfig.src.json | 5 +- 25 files changed, 7644 insertions(+), 1 deletion(-) create mode 100644 build.rs create mode 100644 guest-js/errors.test.ts create mode 100644 guest-js/errors.ts create mode 100644 guest-js/headers.test.ts create mode 100644 guest-js/headers.ts create mode 100644 guest-js/http-client.test.ts create mode 100644 guest-js/http-client.ts create mode 100644 guest-js/index.ts create mode 100644 guest-js/tsconfig.json create mode 100644 guest-js/types.ts create mode 100644 permissions/autogenerated/commands/abort_request.toml create mode 100644 permissions/autogenerated/commands/fetch.toml create mode 100644 permissions/autogenerated/reference.md create mode 100644 permissions/default.toml create mode 100644 permissions/schemas/schema.json create mode 100644 src/allowlist.rs create mode 100644 src/client.rs create mode 100644 src/commands.rs create mode 100644 src/config.rs create mode 100644 src/error.rs create mode 100644 src/lib.rs create mode 100644 src/types.rs diff --git a/.gitignore b/.gitignore index df62726..d66a361 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ Thumbs.db # Tauri /gen/ + +# Example app artifacts +examples/*/package-lock.json diff --git a/README.md b/README.md index 6dd069b..fd79ef9 100644 --- a/README.md +++ b/README.md @@ -7,5 +7,432 @@ HTTP client plugin for Tauri 2.x apps. This plugin provides a cross-platform interface for creating HTTP requests from Tauri applications. + +## Features + + * Domain allowlist with wildcard support (secure by + default) + * Anti-SSRF protections (rejects IPs, userinfo, + non-HTTP schemes) + * Redirect validation against the allowlist on every + hop + * Automatic retry with exponential backoff and jitter + * Abort in-flight requests via `AbortController` + * Case-insensitive `HttpHeaders` with multi-value + support + * Binary request and response bodies (base64 over IPC) + * Runtime allowlist management from Rust + + +## Installation + +### 1. Install the npm package + +```bash +npm install @silvermine/tauri-plugin-http-client +``` + +Peer dependency: `@tauri-apps/api >= 2.9.1` + +### 2. Add the Cargo dependency + +In `src-tauri/Cargo.toml`: + +```toml +[dependencies] +tauri-plugin-http-client = { + git = "https://github.com/silvermine/tauri-plugin-http-client.git" +} +``` + +### 3. Register the plugin + +In `src-tauri/src/lib.rs`: + +```rust +use std::time::Duration; + +fn main() { + tauri::Builder::default() + .plugin( + tauri_plugin_http_client::Builder::new() + .allowed_domains([ + "api.example.com", + "*.cdn.example.com", + ]) + .default_timeout(Duration::from_secs(30)) + .build(), + ) + .run(tauri::generate_context!()) + .expect("error running tauri application"); +} +``` + +### 4. Add permissions + +In `src-tauri/capabilities/default.json`, add the plugin +permission: + +```json +{ + "permissions": [ + "http-client:default" + ] +} +``` + +This grants access to both the `fetch` and `abort_request` +IPC commands. + + +## Usage + +### Basic Requests + +```typescript +import { request } from '@silvermine/tauri-plugin-http-client'; + +// GET +const resp = await request('https://api.example.com/items'); +const items = resp.json(); + +// POST with JSON body +const resp = await request('https://api.example.com/items', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: { name: 'New item', quantity: 3 }, +}); +``` + +> When passing an object as `body`, it is JSON-stringified +> automatically. You must set `Content-Type` yourself. + +### Reading Responses + +```typescript +const resp = await request('https://api.example.com/data'); + +// Body accessors +const text: string = resp.text(); +const data: MyType = resp.json(); +const bytes: Uint8Array = resp.bytes(); + +// Metadata +resp.status; // 200 +resp.statusText; // "OK" +resp.ok; // true (status 200-299) +resp.url; // final URL after redirects +resp.redirected; // true if redirected +resp.retryCount; // number of retries before success +``` + +### Custom Headers + +```typescript +import { + request, + HttpHeaders, +} from '@silvermine/tauri-plugin-http-client'; + +// Using HttpHeaders class +const headers = new HttpHeaders(); +headers.set('Authorization', 'Bearer tok_123'); +headers.set('Accept', 'application/json'); + +const resp = await request('https://api.example.com/me', { + headers, +}); + +// Or use a plain object +const resp = await request('https://api.example.com/me', { + headers: { + 'Authorization': 'Bearer tok_123', + 'Accept': 'application/json', + }, +}); + +// Reading response headers (case-insensitive) +resp.headers.get('content-type'); // first value +resp.headers.getAll('set-cookie'); // all values +resp.headers.has('x-request-id'); // boolean +``` + +### Aborting Requests + +```typescript +import { + request, + HttpClientError, + HttpErrorCode, +} from '@silvermine/tauri-plugin-http-client'; + +const controller = new AbortController(); + +// Cancel after 5 seconds +setTimeout(() => controller.abort(), 5000); + +try { + const resp = await request( + 'https://api.example.com/large-export', + { signal: controller.signal }, + ); +} catch (err) { + if ( + err instanceof HttpClientError + && err.code === HttpErrorCode.ABORTED + ) { + console.log('Request was aborted'); + } +} +``` + +### Timeouts + +```typescript +// Per-request timeout in milliseconds +const resp = await request('https://api.example.com/slow', { + timeout: 60000, +}); +``` + +This overrides the plugin-level `default_timeout`. + +### Binary Data + +```typescript +// Sending binary +const payload = new Uint8Array([0x00, 0x01, 0x02]); + +await request('https://api.example.com/upload', { + method: 'POST', + headers: { 'Content-Type': 'application/octet-stream' }, + body: payload, +}); + +// Receiving binary +const resp = await request( + 'https://api.example.com/image.png', +); +const bytes = resp.bytes(); +const blob = new Blob([bytes], { type: 'image/png' }); +``` + +### Retries + +Per-request retry override (capped at the plugin-level max): + +```typescript +const resp = await request('https://api.example.com/data', { + maxRetries: 5, +}); + +// Disable retry for a specific request +const resp = await request('https://api.example.com/data', { + maxRetries: 0, +}); +``` + +See [Retry Configuration](#retry-configuration) for +plugin-level setup. + +### Error Handling + +```typescript +import { + request, + HttpClientError, + HttpErrorCode, +} from '@silvermine/tauri-plugin-http-client'; + +try { + const resp = await request('https://blocked.example.com'); +} catch (err) { + if (err instanceof HttpClientError) { + switch (err.code) { + case HttpErrorCode.DOMAIN_NOT_ALLOWED: + // URL not in allowlist + break; + case HttpErrorCode.TIMEOUT: + // Request timed out + break; + case HttpErrorCode.ABORTED: + // Cancelled via AbortController + break; + default: + console.error(err.code, err.message); + } + } +} +``` + +See [`HttpErrorCode`](guest-js/errors.ts) for the full +list of error codes and descriptions. + + +## Security + +### Domain Allowlist + +The allowlist has two tiers: + + * **Init-time patterns** -- set via + `allowed_domains()` in the builder. Supports exact + domains (`api.example.com`) and wildcards + (`*.example.com`). These cannot be removed at + runtime. + * **Runtime patterns** -- added from Rust via + `HttpClientExt`. Exact domains only (wildcards are + rejected). Can be added and removed at any time. + +An empty allowlist blocks all requests (secure by default). + +The total number of patterns (init + runtime) is capped at +`max_allowlist_size` (default: 128). + +### Anti-SSRF Protections + + * Rejects IP addresses (IPv4, IPv6, decimal, octal, + hex encodings) + * Rejects `userinfo@` in URLs + * Only allows `http` and `https` schemes + * Validates every redirect hop against the allowlist + +### Forbidden Headers + +Certain transport-layer and security-prefix headers are +blocked from both per-request and default headers: + + * `Host`, `Connection`, `Keep-Alive`, + `Transfer-Encoding`, `TE`, `Upgrade`, `Trailer` + * Any header starting with `Sec-` or `Proxy-` + +Default headers are validated at plugin init; per-request +headers are validated before each request. Blocked headers +produce a `FORBIDDEN_HEADER` error. + + +## Rust Configuration + +### Builder Options + +| Method | Type | Default | Description | +| --- | --- |---|---| +| `allowed_domains` | `impl IntoIterator>` | `[]` | Domain patterns | +| `default_timeout` | `Duration` | None | Request timeout | +| `max_redirects` | `usize` | `10` | Max redirect hops | +| `max_response_body_size` | `usize` | 10 MB | Body size limit | +| `max_allowlist_size` | `usize` | `128` | Pattern cap | +| `user_agent` | `String` | None | Custom User-Agent | +| `default_headers` | `HashMap` | `{}` | Default headers | +| `retry` | `RetryConfig` | disabled | Retry settings | +| `max_retries` | `u32` | -- | Convenience for retry | + +Full example: + +```rust +use std::collections::HashMap; +use std::time::Duration; +use tauri_plugin_http_client::config::RetryConfig; + +let plugin = tauri_plugin_http_client::Builder::new() + .allowed_domains([ + "api.example.com", + "*.cdn.example.com", + ]) + .default_timeout(Duration::from_secs(30)) + .max_redirects(5) + .max_response_body_size(5 * 1024 * 1024) + .max_allowlist_size(64) + .user_agent("my-app/1.0".into()) + .default_headers(HashMap::from([ + ("X-App-Version".into(), "1.0".into()), + ])) + .retry(RetryConfig::default()) + .build(); +``` + +### Retry Configuration + +Retry is disabled by default. Enable it with +`RetryConfig::default()` or a custom config: + +| Field | Type | Default | Description | +| --- | --- |---|---| +| `max_retries` | `u32` | `3` | Max attempts after initial | +| `initial_backoff` | `Duration` | 200 ms | First retry delay | +| `max_backoff` | `Duration` | 10 s | Backoff cap | +| `retryable_status_codes` | `Vec` | 408, 429, 500, 502, 503, 504 | Status codes to retry | +| `max_retry_after` | `Duration` | 60 s | Cap for Retry-After | +| `retryable_methods` | `Option>` | GET, HEAD, PUT, DELETE, OPTIONS | Methods to retry | + +Key behaviors: + + * Exponential backoff with jitter + (`initial_backoff * 2^(attempt-1)`) + * Honors `Retry-After` headers (capped at + `max_retry_after`) + * POST and PATCH excluded by default (not idempotent) + * Set `retryable_methods` to `None` to retry all + methods + * Timeout is per-attempt, not total + * Security errors are never retried + +```rust +use std::time::Duration; +use tauri_plugin_http_client::config::RetryConfig; + +let retry = RetryConfig { + max_retries: 5, + initial_backoff: Duration::from_millis(500), + max_backoff: Duration::from_secs(30), + retryable_methods: None, // retry all methods + ..RetryConfig::default() +}; + +let plugin = tauri_plugin_http_client::Builder::new() + .allowed_domains(["api.example.com"]) + .retry(retry) + .build(); +``` + +### Runtime Allowlist Management + +Use the `HttpClientExt` trait to manage domains from Rust: + +```rust +use tauri::Manager; +use tauri_plugin_http_client::HttpClientExt; + +#[tauri::command] +fn connect_service( + app: tauri::AppHandle, + domain: String, +) -> Result<(), String> { + app.add_allowed_domain(domain) + .map_err(|e| e.to_string()) +} +``` + +Available methods: + + * `add_allowed_domain(domain)` -- add one domain + * `add_allowed_domains(domains)` -- add multiple + domains + * `remove_allowed_domain(domain)` -- remove one + (returns whether it existed) + * `remove_allowed_domains(domains)` -- remove multiple + (returns count removed) + * `remove_all_runtime_domains()` -- clear all runtime + domains + +Wildcards are rejected at runtime +(`WILDCARD_NOT_ALLOWED_AT_RUNTIME`). Init-time patterns +cannot be removed. + + +## License + +[MIT](./LICENSE) + [ci-badge]: https://img.shields.io/github/actions/workflow/status/silvermine/tauri-plugin-http-client/ci.yml [ci-url]: https://github.com/silvermine/tauri-plugin-http-client/actions/workflows/ci.yml diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..69714ff --- /dev/null +++ b/build.rs @@ -0,0 +1,5 @@ +const COMMANDS: &[&str] = &["fetch", "abort_request"]; + +fn main() { + tauri_plugin::Builder::new(COMMANDS).build(); +} diff --git a/guest-js/errors.test.ts b/guest-js/errors.test.ts new file mode 100644 index 0000000..9f689b1 --- /dev/null +++ b/guest-js/errors.test.ts @@ -0,0 +1,123 @@ +import { describe, it, expect } from 'vitest'; +import { HttpClientError, HttpErrorCode, parseError } from './errors'; + +describe('HttpClientError', () => { + + it('has correct name, code, and message', () => { + const error = new HttpClientError(HttpErrorCode.TIMEOUT, 'request timed out'); + + expect(error.name).toBe('HttpClientError'); + expect(error.code).toBe(HttpErrorCode.TIMEOUT); + expect(error.message).toBe('request timed out'); + expect(error).toBeInstanceOf(Error); + }); + +}); + +describe('parseError', () => { + + it('passes through HttpClientError instances', () => { + const original = new HttpClientError(HttpErrorCode.ABORTED, 'aborted'); + + expect(parseError(original)).toBe(original); + }); + + it('parses structured JSON string from Rust', () => { + const json = JSON.stringify({ code: 'DOMAIN_NOT_ALLOWED', message: 'domain not allowed: evil.com' }), + error = parseError(json); + + expect(error.code).toBe(HttpErrorCode.DOMAIN_NOT_ALLOWED); + expect(error.message).toBe('domain not allowed: evil.com'); + }); + + it('parses structured object from Rust', () => { + const obj = { code: 'TIMEOUT', message: 'request timed out' }, + error = parseError(obj); + + expect(error.code).toBe(HttpErrorCode.TIMEOUT); + expect(error.message).toBe('request timed out'); + }); + + it('handles plain string errors', () => { + const error = parseError('something went wrong'); + + expect(error.code).toBe(HttpErrorCode.ERROR); + expect(error.message).toBe('something went wrong'); + }); + + it('handles unknown error codes gracefully', () => { + const obj = { code: 'UNKNOWN_CODE', message: 'some error' }, + error = parseError(obj); + + expect(error.code).toBe(HttpErrorCode.ERROR); + expect(error.message).toBe('some error'); + }); + + it('handles null/undefined', () => { + const error = parseError(null); + + expect(error.code).toBe(HttpErrorCode.ERROR); + expect(error.message).toBe('unknown error'); + }); + + it('handles objects with only message', () => { + const error = parseError({ message: 'just a message' }); + + expect(error.code).toBe(HttpErrorCode.ERROR); + expect(error.message).toBe('just a message'); + }); + + it('handles undefined', () => { + const error = parseError(undefined); + + expect(error.code).toBe(HttpErrorCode.ERROR); + expect(error.message).toBe('unknown error'); + }); + + it('handles plain Error instances', () => { + const error = parseError(new Error('regular error')); + + expect(error.code).toBe(HttpErrorCode.ERROR); + expect(error.message).toBe('regular error'); + }); + + it('handles number input', () => { + const error = parseError(42); + + expect(error.code).toBe(HttpErrorCode.ERROR); + expect(error.message).toBe('unknown error'); + }); + + it('parses all known error codes from Rust', () => { + const codes = [ + 'DOMAIN_NOT_ALLOWED', 'SCHEME_NOT_ALLOWED', 'IP_ADDRESS_NOT_ALLOWED', + 'INVALID_URL', 'TIMEOUT', 'CONNECTION_ERROR', 'REQUEST_ERROR', + 'ABORTED', 'RESPONSE_TOO_LARGE', 'REDIRECT_BLOCKED', + 'ALLOWLIST_SIZE_EXCEEDED', 'WILDCARD_NOT_ALLOWED_AT_RUNTIME', + 'INVALID_DOMAIN_PATTERN', 'FORBIDDEN_HEADER', + ]; + + for (const code of codes) { + const error = parseError({ code, message: 'test' }); + + expect(error.code).toBe(code); + } + }); + + it('parses FORBIDDEN_HEADER error from Rust', () => { + const json = JSON.stringify({ code: 'FORBIDDEN_HEADER', message: 'forbidden header: host' }), + error = parseError(json); + + expect(error.code).toBe(HttpErrorCode.FORBIDDEN_HEADER); + expect(error.message).toBe('forbidden header: host'); + }); + + it('handles JSON string with only message field', () => { + const json = JSON.stringify({ message: 'partial error' }), + error = parseError(json); + + // No code field, so should get plain string treatment since code check fails + expect(error.code).toBe(HttpErrorCode.ERROR); + }); + +}); diff --git a/guest-js/errors.ts b/guest-js/errors.ts new file mode 100644 index 0000000..d2ef89a --- /dev/null +++ b/guest-js/errors.ts @@ -0,0 +1,121 @@ +/** + * Machine-readable error codes returned by the HTTP client plugin. + * + * Use these with {@link HttpClientError.code} for programmatic error + * handling. + */ +export enum HttpErrorCode { + + /** URL host is not in the domain allowlist. */ + DOMAIN_NOT_ALLOWED = 'DOMAIN_NOT_ALLOWED', + + /** Non-HTTP(S) scheme (e.g. `ftp://`, `file://`). */ + SCHEME_NOT_ALLOWED = 'SCHEME_NOT_ALLOWED', + + /** IP address used instead of a domain name. */ + IP_ADDRESS_NOT_ALLOWED = 'IP_ADDRESS_NOT_ALLOWED', + + /** Malformed URL. */ + INVALID_URL = 'INVALID_URL', + + /** Request timed out. */ + TIMEOUT = 'TIMEOUT', + + /** TCP or TLS connection failure. */ + CONNECTION_ERROR = 'CONNECTION_ERROR', + + /** Other request-level error. */ + REQUEST_ERROR = 'REQUEST_ERROR', + + /** Request cancelled via `AbortController`. */ + ABORTED = 'ABORTED', + + /** Response body exceeds the configured size limit. */ + RESPONSE_TOO_LARGE = 'RESPONSE_TOO_LARGE', + + /** A redirect targeted a domain not in the allowlist. */ + REDIRECT_BLOCKED = 'REDIRECT_BLOCKED', + + /** Adding a domain would exceed the allowlist size cap. */ + ALLOWLIST_SIZE_EXCEEDED = 'ALLOWLIST_SIZE_EXCEEDED', + + /** Wildcard patterns cannot be added at runtime. */ + WILDCARD_NOT_ALLOWED_AT_RUNTIME = 'WILDCARD_NOT_ALLOWED_AT_RUNTIME', + + /** Domain pattern is malformed. */ + INVALID_DOMAIN_PATTERN = 'INVALID_DOMAIN_PATTERN', + + /** A forbidden header was provided (e.g. Host). */ + FORBIDDEN_HEADER = 'FORBIDDEN_HEADER', + + /** Unclassified error. */ + ERROR = 'ERROR', + +} + +const KNOWN_CODES = new Set(Object.values(HttpErrorCode)); + +/** + * Error thrown by the HTTP client plugin. + * + * Contains a machine-readable {@link code} for programmatic error handling. + */ +export class HttpClientError extends Error { + + public readonly code: HttpErrorCode; + + public constructor(code: HttpErrorCode, message: string) { + super(message); + this.name = 'HttpClientError'; + this.code = code; + } + +} + +/** + * Parses the structured `{code, message}` error from the Rust backend + * into an `HttpClientError`. + */ +export function parseError(err: unknown): HttpClientError { + if (err instanceof HttpClientError) { + return err; + } + + // Tauri invoke errors come as strings or objects + let code = HttpErrorCode.ERROR, + message = 'unknown error'; + + if (typeof err === 'string') { + try { + const parsed = JSON.parse(err) as { code?: string; message?: string }; + + if (parsed.code && parsed.message) { + code = toErrorCode(parsed.code); + message = parsed.message; + } else { + message = err; + } + } catch{ + message = err; + } + } else if (err && typeof err === 'object') { + const obj = err as Record; + + if (typeof obj.code === 'string' && typeof obj.message === 'string') { + code = toErrorCode(obj.code); + message = obj.message; + } else if (typeof obj.message === 'string') { + message = obj.message; + } + } + + return new HttpClientError(code, message); +} + +function toErrorCode(raw: string): HttpErrorCode { + if (KNOWN_CODES.has(raw)) { + return raw as HttpErrorCode; + } + + return HttpErrorCode.ERROR; +} diff --git a/guest-js/headers.test.ts b/guest-js/headers.test.ts new file mode 100644 index 0000000..888f29b --- /dev/null +++ b/guest-js/headers.test.ts @@ -0,0 +1,197 @@ +import { describe, it, expect } from 'vitest'; +import { HttpHeaders } from './headers'; + +describe('HttpHeaders', () => { + + it('constructs from Record', () => { + const headers = new HttpHeaders({ 'Content-Type': 'application/json' }); + + expect(headers.get('content-type')).toBe('application/json'); + }); + + it('constructs from Record', () => { + const headers = new HttpHeaders({ 'Set-Cookie': [ 'a=1', 'b=2' ] }); + + expect(headers.getAll('set-cookie')).toEqual([ 'a=1', 'b=2' ]); + }); + + it('get() is case-insensitive', () => { + const headers = new HttpHeaders({ 'Content-Type': 'text/html' }); + + expect(headers.get('CONTENT-TYPE')).toBe('text/html'); + expect(headers.get('content-type')).toBe('text/html'); + expect(headers.get('Content-Type')).toBe('text/html'); + }); + + it('get() returns null for missing headers', () => { + const headers = new HttpHeaders(); + + expect(headers.get('x-missing')).toBeNull(); + }); + + it('set() replaces existing values', () => { + const headers = new HttpHeaders({ 'X-Foo': [ 'a', 'b' ] }); + + headers.set('X-Foo', 'c'); + expect(headers.getAll('x-foo')).toEqual([ 'c' ]); + }); + + it('append() adds to existing values', () => { + const headers = new HttpHeaders({ 'X-Foo': 'a' }); + + headers.append('X-Foo', 'b'); + expect(headers.getAll('x-foo')).toEqual([ 'a', 'b' ]); + }); + + it('append() creates new header if not present', () => { + const headers = new HttpHeaders(); + + headers.append('X-New', 'val'); + expect(headers.get('x-new')).toBe('val'); + }); + + it('has() checks existence case-insensitively', () => { + const headers = new HttpHeaders({ 'Authorization': 'Bearer token' }); + + expect(headers.has('authorization')).toBe(true); + expect(headers.has('AUTHORIZATION')).toBe(true); + expect(headers.has('x-missing')).toBe(false); + }); + + it('delete() removes headers case-insensitively', () => { + const headers = new HttpHeaders({ 'X-Remove': 'val' }); + + headers.delete('X-REMOVE'); + expect(headers.has('x-remove')).toBe(false); + }); + + it('forEach() iterates all name-value pairs', () => { + const headers = new HttpHeaders({ 'a': [ '1', '2' ], 'b': '3' }), + pairs: Array<[string, string]> = []; + + headers.forEach((value, name) => { + pairs.push([ name, value ]); + }); + + expect(pairs).toEqual([ + [ 'a', '1' ], + [ 'a', '2' ], + [ 'b', '3' ], + ]); + }); + + it('entries() returns flattened name-value pairs', () => { + const headers = new HttpHeaders({ 'a': [ '1', '2' ] }), + entries = Array.from(headers.entries()); + + expect(entries).toEqual([ + [ 'a', '1' ], + [ 'a', '2' ], + ]); + }); + + it('is iterable via Symbol.iterator', () => { + const headers = new HttpHeaders({ 'x-test': 'val' }), + entries = Array.from(headers); + + expect(entries).toEqual([ [ 'x-test', 'val' ] ]); + }); + + it('keys() returns header names', () => { + const headers = new HttpHeaders({ 'A': '1', 'B': '2' }), + keys = Array.from(headers.keys()); + + expect(keys).toEqual([ 'a', 'b' ]); + }); + + it('values() returns all values across all headers', () => { + const headers = new HttpHeaders({ 'a': [ 'first', 'second' ], 'b': 'only' }), + values = Array.from(headers.values()); + + expect(values).toEqual([ 'first', 'second', 'only' ]); + }); + + it('toRecord() joins multi-value headers with comma per RFC 9110', () => { + const headers = new HttpHeaders({ 'a': [ '1', '2' ], 'b': '3' }); + + expect(headers.toRecord()).toEqual({ a: '1, 2', b: '3' }); + }); + + it('handles empty initialization', () => { + const headers = new HttpHeaders(); + + expect(headers.get('anything')).toBeNull(); + expect(Array.from(headers.entries())).toEqual([]); + }); + + it('delete() on non-existent header is a no-op', () => { + const headers = new HttpHeaders({ 'X-Keep': 'val' }); + + headers.delete('X-NonExistent'); + + expect(headers.has('x-keep')).toBe(true); + expect(headers.has('x-nonexistent')).toBe(false); + }); + + it('toRecord() joins multi-value headers for IPC (lossy conversion)', () => { + const headers = new HttpHeaders(); + + headers.set('Accept', 'text/html'); + headers.append('Accept', 'application/json'); + + const record = headers.toRecord(); + + // Multi-value headers are joined with ", " for the IPC bridge + expect(record.accept).toBe('text/html, application/json'); + }); + + it('toMultiRecord() preserves all values as arrays', () => { + const headers = new HttpHeaders({ 'Set-Cookie': [ 'a=1; Path=/', 'b=2; Path=/' ], 'content-type': 'text/html' }); + + expect(headers.toMultiRecord()).toEqual({ + 'set-cookie': [ 'a=1; Path=/', 'b=2; Path=/' ], + 'content-type': [ 'text/html' ], + }); + }); + + it('toMultiRecord() wraps single-value headers in an array', () => { + const headers = new HttpHeaders({ 'Content-Type': 'text/html' }); + + expect(headers.toMultiRecord()).toEqual({ 'content-type': [ 'text/html' ] }); + }); + + it('toMultiRecord() does not split Set-Cookie values containing commas', () => { + // Set-Cookie values can contain commas (e.g. in Expires dates). + // toMultiRecord() preserves each cookie as a distinct entry. + const headers = new HttpHeaders({ + 'Set-Cookie': [ 'a=1; Expires=Thu, 01 Jan 2099 00:00:00 GMT', 'b=2' ], + }); + + expect(headers.toMultiRecord()['set-cookie']).toEqual([ + 'a=1; Expires=Thu, 01 Jan 2099 00:00:00 GMT', + 'b=2', + ]); + }); + + it('toMultiRecord() returns copies, not references to internal arrays', () => { + const headers = new HttpHeaders({ 'x-foo': [ 'a', 'b' ] }), + record = headers.toMultiRecord(); + + record['x-foo'].push('c'); + expect(headers.getAll('x-foo')).toEqual([ 'a', 'b' ]); + }); + + it('toMultiRecord() returns empty object for empty headers', () => { + const headers = new HttpHeaders(); + + expect(headers.toMultiRecord()).toEqual({}); + }); + + it('getAll() returns a copy — mutations do not affect internal state', () => { + const headers = new HttpHeaders({ 'X-Foo': [ 'a', 'b' ] }); + + headers.getAll('x-foo').push('c'); + expect(headers.getAll('x-foo')).toEqual([ 'a', 'b' ]); + }); + +}); diff --git a/guest-js/headers.ts b/guest-js/headers.ts new file mode 100644 index 0000000..142ddbd --- /dev/null +++ b/guest-js/headers.ts @@ -0,0 +1,168 @@ +/** + * Case-insensitive HTTP headers collection with multi-value support. + * + * Header names are normalized to lowercase for consistent access. + */ +export class HttpHeaders implements Iterable<[string, string]> { + + private _map: Map = new Map(); + + public constructor(init?: Record) { + if (init) { + for (const [ name, value ] of Object.entries(init)) { + if (Array.isArray(value)) { + this._map.set(name.toLowerCase(), [ ...value ]); + } else { + this._map.set(name.toLowerCase(), [ value ]); + } + } + } + } + + /** + * Returns the first value for the given header name, or `null` if not present. + */ + public get(name: string): string | null { + const values = this._map.get(name.toLowerCase()); + + return values ? values[0] : null; + } + + /** + * Returns all values for the given header name. + */ + public getAll(name: string): string[] { + const values = this._map.get(name.toLowerCase()); + + return values ? [ ...values ] : []; + } + + /** + * Sets the header to a single value, replacing any existing values. + */ + public set(name: string, value: string): void { + this._map.set(name.toLowerCase(), [ value ]); + } + + /** + * Appends a value to the header (creates the header if it doesn't exist). + */ + public append(name: string, value: string): void { + const key = name.toLowerCase(), + existing = this._map.get(key); + + if (existing) { + existing.push(value); + } else { + this._map.set(key, [ value ]); + } + } + + /** + * Returns `true` if the header exists. + */ + public has(name: string): boolean { + return this._map.has(name.toLowerCase()); + } + + /** + * Removes the header. + */ + public delete(name: string): void { + this._map.delete(name.toLowerCase()); + } + + /** + * Iterates over all headers, calling `fn` for each name-value pair. + * Multi-value headers call `fn` once per value. + */ + public forEach(fn: (value: string, name: string, headers: HttpHeaders) => void): void { + for (const [ name, values ] of this._map) { + for (const value of values) { + fn(value, name, this); + } + } + } + + /** + * Returns an iterator of `[name, value]` pairs. Multi-value headers + * produce one entry per value. + */ + public entries(): IterableIterator<[string, string]> { + const pairs: Array<[string, string]> = []; + + for (const [ name, values ] of this._map) { + for (const value of values) { + pairs.push([ name, value ]); + } + } + + return pairs[Symbol.iterator](); + } + + /** + * Returns an iterator of header names. + */ + public keys(): IterableIterator { + return this._map.keys(); + } + + /** + * Returns an iterator of all values across all headers. + * Multi-value headers produce one entry per value. + */ + public values(): IterableIterator { + const vals: string[] = []; + + for (const values of this._map.values()) { + for (const value of values) { + vals.push(value); + } + } + + return vals[Symbol.iterator](); + } + + public [Symbol.iterator](): Iterator<[string, string]> { + return this.entries(); + } + + /** + * Converts headers to a plain `Record`, joining + * multi-value headers with `", "` per RFC 9110 Section 5.3. + * + * Used when serializing headers for the IPC bridge. + * + * **Note:** This conversion is lossy for headers that must not be + * combined, most notably `Set-Cookie` (RFC 9110 Section 5.3 explicitly + * excludes it from the combining rule). Use `toMultiRecord()` when you + * need to preserve each value separately. + */ + public toRecord(): Record { + const record: Record = {}; + + for (const [ name, values ] of this._map) { + record[name] = values.join(', '); + } + + return record; + } + + /** + * Converts headers to a `Record`, preserving all + * values for every header as a separate array entry. + * + * Prefer this over `toRecord()` when the caller needs to handle headers + * that must not be combined, such as `Set-Cookie`. + */ + public toMultiRecord(): Record { + const record: Record = {}; + + for (const [ name, values ] of this._map) { + record[name] = [ ...values ]; + } + + return record; + } + +} diff --git a/guest-js/http-client.test.ts b/guest-js/http-client.test.ts new file mode 100644 index 0000000..debcd29 --- /dev/null +++ b/guest-js/http-client.test.ts @@ -0,0 +1,444 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import type { RawFetchResponse } from './types'; +import type { HttpClientError as HttpClientErrorType } from './errors'; + +// Mock @tauri-apps/api/core before importing the module under test +const mockInvoke = vi.fn(); + +vi.mock('@tauri-apps/api/core', () => { + return { invoke: mockInvoke }; +}); + +// Import after mock setup +const { request } = await import('./http-client'); + +const { HttpClientError, HttpErrorCode } = await import('./errors'); + +const { HttpHeaders } = await import('./headers'); + +function makeRawResponse(overrides?: Partial): RawFetchResponse { + return { + status: 200, + statusText: 'OK', + headers: { 'content-type': [ 'application/json' ] }, + body: '{"key":"value"}', + bodyEncoding: 'utf8', + url: 'https://api.example.com/data', + redirected: false, + retryCount: 0, + ...overrides, + }; +} + +describe('request()', () => { + + beforeEach(() => { + mockInvoke.mockReset(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('sends a basic GET request and returns a response', async () => { + const raw = makeRawResponse(); + + mockInvoke.mockResolvedValueOnce(raw); + + const resp = await request('https://api.example.com/data'); + + expect(mockInvoke).toHaveBeenCalledWith('plugin:http-client|fetch', { + request: { url: 'https://api.example.com/data' }, + }); + expect(resp.status).toBe(200); + expect(resp.statusText).toBe('OK'); + expect(resp.ok).toBe(true); + expect(resp.url).toBe('https://api.example.com/data'); + expect(resp.redirected).toBe(false); + expect(resp.headers.get('content-type')).toBe('application/json'); + }); + + it('returns text body correctly', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'hello world', bodyEncoding: 'utf8' })); + + const resp = await request('https://example.com'); + + expect(resp.text()).toBe('hello world'); + }); + + it('parses JSON body correctly', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + const resp = await request('https://example.com'), + data = resp.json<{ key: string }>(); + + expect(data.key).toBe('value'); + }); + + it('decodes base64 body to bytes', async () => { + // "hello" in base64 + mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'aGVsbG8=', bodyEncoding: 'base64' })); + + const resp = await request('https://example.com'), + bytes = resp.bytes(); + + expect(bytes).toBeInstanceOf(Uint8Array); + expect(new TextDecoder().decode(bytes)).toBe('hello'); + }); + + it('sends POST with string body', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com', { method: 'POST', body: 'payload' }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.method).toBe('POST'); + expect(payload.body).toBe('payload'); + expect(payload.bodyEncoding).toBe('utf8'); + }); + + it('sends POST with object body as JSON', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com', { method: 'POST', body: { foo: 'bar' } }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.body).toBe('{"foo":"bar"}'); + expect(payload.bodyEncoding).toBe('utf8'); + }); + + it('sends POST with Uint8Array body as base64', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + // Construct Uint8Array directly (not via TextEncoder) to avoid + // cross-realm instanceof issues in jsdom test environment. + const encoded = new TextEncoder().encode('binary data'), + bytes = new Uint8Array(encoded); + + await request('https://example.com', { method: 'POST', body: bytes }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.bodyEncoding).toBe('base64'); + expect(atob(payload.body)).toBe('binary data'); + }); + + it('sends headers from Record', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com', { + headers: { 'Authorization': 'Bearer token' }, + }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.headers).toEqual({ 'Authorization': 'Bearer token' }); + }); + + it('sends headers from HttpHeaders instance', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + const headers = new HttpHeaders(); + + headers.set('Authorization', 'Bearer token'); + + await request('https://example.com', { headers }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.headers).toEqual({ 'authorization': 'Bearer token' }); + }); + + it('sends timeout', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com', { timeout: 5000 }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.timeoutMs).toBe(5000); + }); + + it('ok is false for non-2xx status', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ status: 404, statusText: 'Not Found' })); + + const resp = await request('https://example.com'); + + expect(resp.ok).toBe(false); + expect(resp.status).toBe(404); + }); + + it('throws HttpClientError on invoke failure', async () => { + mockInvoke.mockRejectedValueOnce(JSON.stringify({ + code: 'DOMAIN_NOT_ALLOWED', + message: 'domain not allowed: evil.com', + })); + + try { + await request('https://evil.com'); + expect.fail('should have thrown'); + } catch(err) { + expect(err).toBeInstanceOf(HttpClientError); + expect((err as HttpClientErrorType).code).toBe(HttpErrorCode.DOMAIN_NOT_ALLOWED); + } + }); + + it('throws ABORTED when signal is already aborted', async () => { + const controller = new AbortController(); + + controller.abort(); + + try { + await request('https://example.com', { signal: controller.signal }); + expect.fail('should have thrown'); + } catch(err) { + expect(err).toBeInstanceOf(HttpClientError); + expect((err as HttpClientErrorType).code).toBe(HttpErrorCode.ABORTED); + } + }); + + it('includes requestId when signal is provided', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + const controller = new AbortController(); + + await request('https://example.com', { signal: controller.signal }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.requestId).toBeDefined(); + expect(typeof payload.requestId).toBe('string'); + }); + + it('calls abort_request when signal fires', async () => { + // Make invoke hang until we abort + let resolveInvoke: ((v: RawFetchResponse) => void) | undefined; + + const invokePromise = new Promise((resolve) => { + resolveInvoke = resolve; + }); + + mockInvoke.mockImplementation((cmd: string) => { + if (cmd === 'plugin:http-client|fetch') { + return invokePromise; + } + // abort_request call + return Promise.resolve(true); + }); + + const controller = new AbortController(), + reqPromise = request('https://example.com', { signal: controller.signal }); + + // Abort the request + controller.abort(); + + // Resolve the fetch so the promise settles + if (resolveInvoke) { + resolveInvoke(makeRawResponse()); + } + + const resp = await reqPromise; + + expect(resp.status).toBe(200); + + // Check that abort_request was called + const abortCall = mockInvoke.mock.calls.find((c: unknown[]) => { + return c[0] === 'plugin:http-client|abort_request'; + }); + + expect(abortCall).toBeDefined(); + }); + + it('decodes base64 body to text via text()', async () => { + // "héllo" in UTF-8 then base64 + const bytes = new TextEncoder().encode('héllo'), + b64 = btoa(String.fromCharCode(...bytes)); + + mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: b64, bodyEncoding: 'base64' })); + + const resp = await request('https://example.com'); + + expect(resp.text()).toBe('héllo'); + }); + + it('converts utf8 body to bytes via bytes()', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'hello', bodyEncoding: 'utf8' })); + + const resp = await request('https://example.com'), + bytes = resp.bytes(); + + expect(ArrayBuffer.isView(bytes)).toBe(true); + expect(new TextDecoder().decode(bytes)).toBe('hello'); + }); + + it('caches text() return value', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'cached', bodyEncoding: 'utf8' })); + + const resp = await request('https://example.com'); + + expect(resp.text()).toBe('cached'); + // Second call should return same value (cached) + expect(resp.text()).toBe('cached'); + }); + + it('caches bytes() return value', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'aGVsbG8=', bodyEncoding: 'base64' })); + + const resp = await request('https://example.com'); + + const first = resp.bytes(), + second = resp.bytes(); + + // Should be the exact same reference + expect(first).toBe(second); + }); + + it('removes abort listener after successful request', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + const controller = new AbortController(), + removeSpy = vi.spyOn(controller.signal, 'removeEventListener'); + + await request('https://example.com', { signal: controller.signal }); + + expect(removeSpy).toHaveBeenCalledWith('abort', expect.any(Function)); + }); + + it('does not include requestId without signal', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com'); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.requestId).toBeUndefined(); + }); + + it('omits undefined optional fields from payload', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com'); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload).toEqual({ url: 'https://example.com' }); + }); + + it('ok boundary: 199 is not ok, 200 is ok, 299 is ok, 300 is not ok', async () => { + for (const [ status, expectedOk ] of [ [ 199, false ], [ 200, true ], [ 299, true ], [ 300, false ] ] as [number, boolean][]) { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ status })); + + const resp = await request('https://example.com'); + + expect(resp.ok).toBe(expectedOk); + } + }); + + it('handles redirected response', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ + redirected: true, + url: 'https://api.example.com/final', + })); + + const resp = await request('https://api.example.com/start'); + + expect(resp.redirected).toBe(true); + expect(resp.url).toBe('https://api.example.com/final'); + }); + + it('passes maxRetries in payload', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com', { maxRetries: 5 }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.maxRetries).toBe(5); + }); + + it('omits maxRetries when undefined', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com'); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.maxRetries).toBeUndefined(); + }); + + it('exposes retryCount from response', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ retryCount: 2 })); + + const resp = await request('https://example.com'); + + expect(resp.retryCount).toBe(2); + }); + + it('retryCount is 0 when no retries occurred', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + const resp = await request('https://example.com'); + + expect(resp.retryCount).toBe(0); + }); + + it('sends empty string body correctly', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + await request('https://example.com', { method: 'POST', body: '' }); + + const payload = mockInvoke.mock.calls[0][1].request; + + expect(payload.body).toBe(''); + expect(payload.bodyEncoding).toBe('utf8'); + }); + + it('sends multi-value headers via HttpHeaders as comma-joined string', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + const headers = new HttpHeaders(); + + headers.set('Accept', 'text/html'); + headers.append('Accept', 'application/json'); + + await request('https://example.com', { headers }); + + const payload = mockInvoke.mock.calls[0][1].request; + + // HttpHeaders.toRecord() joins multi-value with ", " + expect(payload.headers.accept).toBe('text/html, application/json'); + }); + + it('generateRequestId produces unique IDs under rapid calls', async () => { + const ids: Set = new Set(); + + // Fire 10 rapid requests, each generating a unique requestId + for (let i = 0; i < 10; i++) { + mockInvoke.mockResolvedValueOnce(makeRawResponse()); + + const controller = new AbortController(); + + await request('https://example.com', { signal: controller.signal }); + + const payload = mockInvoke.mock.calls[i][1].request; + + ids.add(payload.requestId); + } + + // All IDs should be unique + expect(ids.size).toBe(10); + }); + + it('json() throws on invalid JSON body', async () => { + mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'not valid json', bodyEncoding: 'utf8' })); + + const resp = await request('https://example.com'); + + const fn = (): unknown => { return resp.json(); }; + + expect(fn).toThrow(); + }); + +}); diff --git a/guest-js/http-client.ts b/guest-js/http-client.ts new file mode 100644 index 0000000..cc1321e --- /dev/null +++ b/guest-js/http-client.ts @@ -0,0 +1,179 @@ +import { invoke } from '@tauri-apps/api/core'; +import { HttpHeaders } from './headers'; +import { parseError } from './errors'; +import type { BodyEncoding, RequestOptions, HttpResponse, RawFetchRequest, RawFetchResponse } from './types'; + +let requestCounter = 0; + +// Generates a unique ID used as both a tracking key in Rust's InFlightRequests +// map and a cancellation token for the abort_request IPC command. The counter +// handles rapid calls within the same millisecond. +function generateRequestId(): string { + requestCounter += 1; + return `req-${Date.now()}-${requestCounter}`; +} + +/** + * Sends an HTTP request through the Tauri HTTP client plugin. + * + * All URL validation and security checks happen in the Rust backend. + * + * @param url - The URL to request + * @param options - Optional request configuration + * @returns A response object with status, headers, and body accessors + * @throws {HttpClientError} On network errors, security violations, or abort + */ +export async function request(url: string, options?: RequestOptions): Promise { + let requestId: string | undefined, + abortHandler: (() => void) | undefined; + + if (options?.signal) { + requestId = generateRequestId(); + + if (options.signal.aborted) { + throw parseError({ code: 'ABORTED', message: 'request aborted' }); + } + + abortHandler = (): void => { + // Fire-and-forget abort to Rust + invoke('plugin:http-client|abort_request', { requestId }).catch(() => { + // Ignore errors from abort (request may have already completed) + }); + }; + + options.signal.addEventListener('abort', abortHandler, { once: true }); + } + + try { + const payload = buildPayload(url, requestId, options), + raw: RawFetchResponse = await invoke('plugin:http-client|fetch', { request: payload }); + + return wrapResponse(raw); + } catch(err: unknown) { + throw parseError(err); + } finally { + if (abortHandler && options?.signal) { + options.signal.removeEventListener('abort', abortHandler); + } + } +} + +function buildPayload(url: string, requestId: string | undefined, options?: RequestOptions): RawFetchRequest { + const payload: RawFetchRequest = { url }; + + if (options?.method) { + payload.method = options.method; + } + + if (options?.headers) { + if (options.headers instanceof HttpHeaders) { + payload.headers = options.headers.toRecord(); + } else { + payload.headers = options.headers; + } + } + + if (options?.body !== undefined) { + const encoded: { body: string; encoding: BodyEncoding } = encodeBody(options.body); + + payload.body = encoded.body; + payload.bodyEncoding = encoded.encoding; + } + + if (options?.timeout !== undefined) { + payload.timeoutMs = options.timeout; + } + + if (requestId) { + payload.requestId = requestId; + } + + if (options?.maxRetries !== undefined) { + payload.maxRetries = options.maxRetries; + } + + return payload; +} + +function encodeBody(body: string | Uint8Array | Record): { body: string; encoding: BodyEncoding } { + if (typeof body === 'string') { + return { body, encoding: 'utf8' }; + } + + if (body instanceof Uint8Array) { + return { body: uint8ArrayToBase64(body), encoding: 'base64' }; + } + + return { body: JSON.stringify(body), encoding: 'utf8' }; +} + +// Manual loop + btoa/atob for broad WebView compatibility (avoids relying +// on Uint8Array.toBase64 which is not available in all runtimes). +function uint8ArrayToBase64(bytes: Uint8Array): string { + let binary = ''; + + for (let i = 0; i < bytes.length; i++) { + binary += String.fromCharCode(bytes[i]); + } + + return btoa(binary); +} + +function base64ToUint8Array(base64: string): Uint8Array { + const binary = atob(base64), + bytes = new Uint8Array(binary.length); + + for (let i = 0; i < binary.length; i++) { + bytes[i] = binary.charCodeAt(i); + } + + return bytes; +} + +function wrapResponse(raw: RawFetchResponse): HttpResponse { + const headers = new HttpHeaders(raw.headers); + + // Cache decoded body values + let textValue: string | undefined, + bytesValue: Uint8Array | undefined; + + return { + status: raw.status, + statusText: raw.statusText, + headers, + url: raw.url, + redirected: raw.redirected, + ok: raw.status >= 200 && raw.status < 300, // mirrors fetch() Response.ok + retryCount: raw.retryCount, + + text(): string { + if (textValue === undefined) { + if (raw.bodyEncoding === 'base64') { + const bytes = base64ToUint8Array(raw.body); + + textValue = new TextDecoder().decode(bytes); + } else { + textValue = raw.body; + } + } + + return textValue; + }, + + json(): T { + return JSON.parse(this.text()) as T; + }, + + bytes(): Uint8Array { + if (bytesValue === undefined) { + if (raw.bodyEncoding === 'base64') { + bytesValue = base64ToUint8Array(raw.body); + } else { + bytesValue = new TextEncoder().encode(raw.body); + } + } + + return bytesValue; + }, + }; +} diff --git a/guest-js/index.ts b/guest-js/index.ts new file mode 100644 index 0000000..68fc0f6 --- /dev/null +++ b/guest-js/index.ts @@ -0,0 +1,10 @@ +export { request } from './http-client'; +export { HttpHeaders } from './headers'; +export { HttpClientError, HttpErrorCode } from './errors'; +export type { + BodyEncoding, + RequestOptions, + HttpResponse, + HttpMethod, + BodyInit, +} from './types'; diff --git a/guest-js/tsconfig.json b/guest-js/tsconfig.json new file mode 100644 index 0000000..29209e8 --- /dev/null +++ b/guest-js/tsconfig.json @@ -0,0 +1,17 @@ +{ + "compilerOptions": { + "target": "es2021", + "module": "esnext", + "moduleResolution": "bundler", + "skipLibCheck": true, + "strict": true, + "noUnusedLocals": true, + "noImplicitAny": true, + "declaration": true, + "outDir": "../dist-js", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true + }, + "include": ["index.ts"], + "exclude": ["node_modules", "../dist-js"] +} diff --git a/guest-js/types.ts b/guest-js/types.ts new file mode 100644 index 0000000..d382c31 --- /dev/null +++ b/guest-js/types.ts @@ -0,0 +1,56 @@ +import type { HttpHeaders } from './headers'; + +export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD' | 'OPTIONS'; +export type BodyEncoding = 'utf8' | 'base64'; +export type BodyInit = string | Uint8Array | Record; + +export interface RequestOptions { + method?: HttpMethod; + headers?: Record | HttpHeaders; + body?: BodyInit; + timeout?: number; + signal?: AbortSignal; + + /** Per-request retry override. `0` disables retry for this request. + * Capped at the plugin-level max configured in Rust. */ + maxRetries?: number; +} + +export interface HttpResponse { + readonly status: number; + readonly statusText: string; + readonly headers: HttpHeaders; + readonly url: string; + readonly redirected: boolean; + readonly ok: boolean; + + /** Number of retry attempts before this response (0 = no retries). */ + readonly retryCount: number; + text(): string; + json(): T; + bytes(): Uint8Array; +} + +/** Raw IPC response from the Rust backend. */ +export interface RawFetchResponse { + status: number; + statusText: string; + headers: Record; + body: string; + bodyEncoding: BodyEncoding; + url: string; + redirected: boolean; + retryCount: number; +} + +/** Raw IPC request payload sent to the Rust backend. */ +export interface RawFetchRequest { + url: string; + method?: string; + headers?: Record; + body?: string; + bodyEncoding?: BodyEncoding; + timeoutMs?: number; + requestId?: string; + maxRetries?: number; +} diff --git a/permissions/autogenerated/commands/abort_request.toml b/permissions/autogenerated/commands/abort_request.toml new file mode 100644 index 0000000..8d5ffdb --- /dev/null +++ b/permissions/autogenerated/commands/abort_request.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-abort-request" +description = "Enables the abort_request command without any pre-configured scope." +commands.allow = ["abort_request"] + +[[permission]] +identifier = "deny-abort-request" +description = "Denies the abort_request command without any pre-configured scope." +commands.deny = ["abort_request"] diff --git a/permissions/autogenerated/commands/fetch.toml b/permissions/autogenerated/commands/fetch.toml new file mode 100644 index 0000000..c4e068a --- /dev/null +++ b/permissions/autogenerated/commands/fetch.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-fetch" +description = "Enables the fetch command without any pre-configured scope." +commands.allow = ["fetch"] + +[[permission]] +identifier = "deny-fetch" +description = "Denies the fetch command without any pre-configured scope." +commands.deny = ["fetch"] diff --git a/permissions/autogenerated/reference.md b/permissions/autogenerated/reference.md new file mode 100644 index 0000000..5da6b64 --- /dev/null +++ b/permissions/autogenerated/reference.md @@ -0,0 +1,70 @@ +## Default Permission + +Default permissions for the HTTP client plugin + +#### This default permission set includes the following: + +- `allow-fetch` +- `allow-abort-request` + +## Permission Table + + + + + + + + + + + + + + + + + + + + + + + + + + + +
IdentifierDescription
+ +`http-client:allow-abort-request` + + + +Enables the abort_request command without any pre-configured scope. + +
+ +`http-client:deny-abort-request` + + + +Denies the abort_request command without any pre-configured scope. + +
+ +`http-client:allow-fetch` + + + +Enables the fetch command without any pre-configured scope. + +
+ +`http-client:deny-fetch` + + + +Denies the fetch command without any pre-configured scope. + +
diff --git a/permissions/default.toml b/permissions/default.toml new file mode 100644 index 0000000..b9e628e --- /dev/null +++ b/permissions/default.toml @@ -0,0 +1,3 @@ +[default] +description = "Default permissions for the HTTP client plugin" +permissions = ["allow-fetch", "allow-abort-request"] diff --git a/permissions/schemas/schema.json b/permissions/schemas/schema.json new file mode 100644 index 0000000..8500524 --- /dev/null +++ b/permissions/schemas/schema.json @@ -0,0 +1,330 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "PermissionFile", + "description": "Permission file that can define a default permission, a set of permissions or a list of inlined permissions.", + "type": "object", + "properties": { + "default": { + "description": "The default permission set for the plugin", + "anyOf": [ + { + "$ref": "#/definitions/DefaultPermission" + }, + { + "type": "null" + } + ] + }, + "set": { + "description": "A list of permissions sets defined", + "type": "array", + "items": { + "$ref": "#/definitions/PermissionSet" + } + }, + "permission": { + "description": "A list of inlined permissions", + "default": [], + "type": "array", + "items": { + "$ref": "#/definitions/Permission" + } + } + }, + "definitions": { + "DefaultPermission": { + "description": "The default permission set of the plugin.\n\nWorks similarly to a permission with the \"default\" identifier.", + "type": "object", + "required": [ + "permissions" + ], + "properties": { + "version": { + "description": "The version of the permission.", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 1.0 + }, + "description": { + "description": "Human-readable description of what the permission does. Tauri convention is to use `

` headings in markdown content for Tauri documentation generation purposes.", + "type": [ + "string", + "null" + ] + }, + "permissions": { + "description": "All permissions this set contains.", + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "PermissionSet": { + "description": "A set of direct permissions grouped together under a new name.", + "type": "object", + "required": [ + "description", + "identifier", + "permissions" + ], + "properties": { + "identifier": { + "description": "A unique identifier for the permission.", + "type": "string" + }, + "description": { + "description": "Human-readable description of what the permission does.", + "type": "string" + }, + "permissions": { + "description": "All permissions this set contains.", + "type": "array", + "items": { + "$ref": "#/definitions/PermissionKind" + } + } + } + }, + "Permission": { + "description": "Descriptions of explicit privileges of commands.\n\nIt can enable commands to be accessible in the frontend of the application.\n\nIf the scope is defined it can be used to fine grain control the access of individual or multiple commands.", + "type": "object", + "required": [ + "identifier" + ], + "properties": { + "version": { + "description": "The version of the permission.", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 1.0 + }, + "identifier": { + "description": "A unique identifier for the permission.", + "type": "string" + }, + "description": { + "description": "Human-readable description of what the permission does. Tauri internal convention is to use `

` headings in markdown content for Tauri documentation generation purposes.", + "type": [ + "string", + "null" + ] + }, + "commands": { + "description": "Allowed or denied commands when using this permission.", + "default": { + "allow": [], + "deny": [] + }, + "allOf": [ + { + "$ref": "#/definitions/Commands" + } + ] + }, + "scope": { + "description": "Allowed or denied scoped when using this permission.", + "allOf": [ + { + "$ref": "#/definitions/Scopes" + } + ] + }, + "platforms": { + "description": "Target platforms this permission applies. By default all platforms are affected by this permission.", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Target" + } + } + } + }, + "Commands": { + "description": "Allowed and denied commands inside a permission.\n\nIf two commands clash inside of `allow` and `deny`, it should be denied by default.", + "type": "object", + "properties": { + "allow": { + "description": "Allowed command.", + "default": [], + "type": "array", + "items": { + "type": "string" + } + }, + "deny": { + "description": "Denied command, which takes priority.", + "default": [], + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "Scopes": { + "description": "An argument for fine grained behavior control of Tauri commands.\n\nIt can be of any serde serializable type and is used to allow or prevent certain actions inside a Tauri command. The configured scope is passed to the command and will be enforced by the command implementation.\n\n## Example\n\n```json { \"allow\": [{ \"path\": \"$HOME/**\" }], \"deny\": [{ \"path\": \"$HOME/secret.txt\" }] } ```", + "type": "object", + "properties": { + "allow": { + "description": "Data that defines what is allowed by the scope.", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Value" + } + }, + "deny": { + "description": "Data that defines what is denied by the scope. This should be prioritized by validation logic.", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Value" + } + } + } + }, + "Value": { + "description": "All supported ACL values.", + "anyOf": [ + { + "description": "Represents a null JSON value.", + "type": "null" + }, + { + "description": "Represents a [`bool`].", + "type": "boolean" + }, + { + "description": "Represents a valid ACL [`Number`].", + "allOf": [ + { + "$ref": "#/definitions/Number" + } + ] + }, + { + "description": "Represents a [`String`].", + "type": "string" + }, + { + "description": "Represents a list of other [`Value`]s.", + "type": "array", + "items": { + "$ref": "#/definitions/Value" + } + }, + { + "description": "Represents a map of [`String`] keys to [`Value`]s.", + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/Value" + } + } + ] + }, + "Number": { + "description": "A valid ACL number.", + "anyOf": [ + { + "description": "Represents an [`i64`].", + "type": "integer", + "format": "int64" + }, + { + "description": "Represents a [`f64`].", + "type": "number", + "format": "double" + } + ] + }, + "Target": { + "description": "Platform target.", + "oneOf": [ + { + "description": "MacOS.", + "type": "string", + "enum": [ + "macOS" + ] + }, + { + "description": "Windows.", + "type": "string", + "enum": [ + "windows" + ] + }, + { + "description": "Linux.", + "type": "string", + "enum": [ + "linux" + ] + }, + { + "description": "Android.", + "type": "string", + "enum": [ + "android" + ] + }, + { + "description": "iOS.", + "type": "string", + "enum": [ + "iOS" + ] + } + ] + }, + "PermissionKind": { + "type": "string", + "oneOf": [ + { + "description": "Enables the abort_request command without any pre-configured scope.", + "type": "string", + "const": "allow-abort-request", + "markdownDescription": "Enables the abort_request command without any pre-configured scope." + }, + { + "description": "Denies the abort_request command without any pre-configured scope.", + "type": "string", + "const": "deny-abort-request", + "markdownDescription": "Denies the abort_request command without any pre-configured scope." + }, + { + "description": "Enables the fetch command without any pre-configured scope.", + "type": "string", + "const": "allow-fetch", + "markdownDescription": "Enables the fetch command without any pre-configured scope." + }, + { + "description": "Denies the fetch command without any pre-configured scope.", + "type": "string", + "const": "deny-fetch", + "markdownDescription": "Denies the fetch command without any pre-configured scope." + }, + { + "description": "Default permissions for the HTTP client plugin\n#### This default permission set includes:\n\n- `allow-fetch`\n- `allow-abort-request`", + "type": "string", + "const": "default", + "markdownDescription": "Default permissions for the HTTP client plugin\n#### This default permission set includes:\n\n- `allow-fetch`\n- `allow-abort-request`" + } + ] + } + } +} \ No newline at end of file diff --git a/src/allowlist.rs b/src/allowlist.rs new file mode 100644 index 0000000..80bea49 --- /dev/null +++ b/src/allowlist.rs @@ -0,0 +1,1431 @@ +use std::collections::HashSet; +use std::net::IpAddr; + +use url::Url; + +use crate::error::{Error, Result}; + +/// A parsed domain pattern for allowlist matching. +#[derive(Debug, Clone, PartialEq, Eq)] +enum DomainPattern { + /// Matches an exact domain (e.g. `api.example.com`). + Exact(String), + /// Matches any subdomain of the base (e.g. `*.example.com` matches + /// `api.example.com` and `deep.sub.example.com`, but not `example.com`). + WildcardSubdomain(String), +} + +impl DomainPattern { + fn parse(pattern: &str) -> Self { + let normalized = pattern.to_lowercase(); + + if let Some(base) = normalized.strip_prefix("*.") { + DomainPattern::WildcardSubdomain(base.to_string()) + } else { + DomainPattern::Exact(normalized) + } + } + + fn matches(&self, host: &str) -> bool { + match self { + DomainPattern::Exact(domain) => host == domain, + DomainPattern::WildcardSubdomain(base) => { + // The byte check ensures "notexample.com" doesn't match "*.example.com": + // the character immediately before the suffix must be a dot separator. + host.ends_with(base) + && host.len() > base.len() + && host.as_bytes()[host.len() - base.len() - 1] == b'.' + } + } + } +} + +/// Domain allowlist that validates URLs against configured domain patterns. +/// +/// An empty allowlist blocks all requests (secure by default). +/// +/// The allowlist uses a two-tier storage model: +/// +/// - **Config-time patterns** (`init_patterns`): Set at construction via +/// [`new`](Self::new). Immutable after creation. Supports both exact and +/// wildcard patterns. +/// - **Runtime patterns** (`runtime_patterns`): Added and removed at runtime +/// via [`add_patterns`](Self::add_patterns) and [`remove_patterns`](Self::remove_patterns). +/// Exact domains only (wildcards rejected). Stored as normalized lowercase +/// strings in a `HashSet` for O(1) operations and natural deduplication. +/// +/// Config-time patterns cannot be removed — they represent the app developer's +/// security policy and are structurally immutable. +#[derive(Debug, Clone)] +pub struct DomainAllowlist { + /// Config-time patterns: immutable after construction, supports wildcards. + init_patterns: Vec, + /// Runtime patterns: mutable (add/remove), exact domains only. + runtime_patterns: HashSet, +} + +impl DomainAllowlist { + /// Creates a new allowlist from raw domain pattern strings. + /// + /// These patterns become config-time patterns and cannot be removed at + /// runtime. Both exact and wildcard patterns are supported. + /// + /// Supported pattern formats: + /// - `"api.example.com"` - exact domain match + /// - `"*.example.com"` - any subdomain of `example.com` + /// + /// # Errors + /// + /// Returns [`Error::InvalidDomainPattern`] if any pattern is empty, + /// contains control characters, URL-reserved characters, or is a bare `*`. + pub fn new(raw_patterns: Vec) -> Result { + for pattern in &raw_patterns { + // For wildcard patterns, validate the base domain after the `*.` prefix + if let Some(base) = pattern.strip_prefix("*.") { + validate_domain_pattern(base)?; + } else { + validate_domain_pattern(pattern)?; + } + } + + let init_patterns = raw_patterns + .iter() + .map(|p| DomainPattern::parse(p)) + .collect(); + + Ok(Self { + init_patterns, + runtime_patterns: HashSet::new(), + }) + } + + /// Validates a URL string through the full security pipeline. + /// + /// # Validation Steps + /// + /// 1. Parse URL + /// 2. Reject non-HTTP(S) schemes + /// 3. Reject URLs with userinfo + /// 4. Reject backslash in authority + /// 5. Reject IP addresses + /// 6. Normalize host and match against allowlist + pub fn validate_url(&self, url_str: &str) -> Result { + // Reject backslash in URL before parsing (parser may normalize it) + if url_str.contains('\\') { + return Err(Error::InvalidUrl( + "backslash not allowed in url".to_string(), + )); + } + + let url = Url::parse(url_str).map_err(|e| Error::InvalidUrl(e.to_string()))?; + + self.validate_parsed_url(&url)?; + + Ok(url) + } + + /// Validates an already-parsed URL (used for redirect hop validation). + pub fn validate_parsed_url(&self, url: &Url) -> Result<()> { + match url.scheme() { + "http" | "https" => {} + scheme => { + return Err(Error::SchemeNotAllowed(scheme.to_string())); + } + } + + if !url.username().is_empty() || url.password().is_some() { + return Err(Error::UserinfoNotAllowed); + } + + let host = url + .host_str() + .ok_or_else(|| Error::InvalidUrl("missing host".to_string()))?; + + // Check the parsed Host enum for definitive IPv4/IPv6 detection + if let Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) = url.host() { + return Err(Error::IpAddressNotAllowed); + } + + // The url crate's Host enum doesn't catch decimal, octal, or hex + // IP representations — those parse as domain strings. Catch them here. + if host.parse::().is_ok() || is_ip_like(host) { + return Err(Error::IpAddressNotAllowed); + } + + let normalized_host = host.to_lowercase(); + let normalized_host = normalized_host.trim_end_matches('.'); + + if !self.is_domain_allowed(normalized_host) { + return Err(Error::DomainNotAllowed(normalized_host.to_string())); + } + + Ok(()) + } + + fn is_domain_allowed(&self, host: &str) -> bool { + self.init_patterns.iter().any(|p| p.matches(host)) || self.runtime_patterns.contains(host) + } + + /// Adds exact domain patterns to the runtime allowlist. + /// + /// Only exact domain patterns are accepted (e.g. `"api.example.com"`). + /// Wildcard patterns (`"*.example.com"`) are rejected to limit the blast + /// radius of runtime mutations. Wildcards should be configured at build + /// time via [`Builder::allowed_domains`](crate::Builder::allowed_domains). + /// + /// Duplicate patterns are silently accepted (the `HashSet` deduplicates). + /// + /// # Errors + /// + /// Returns [`Error::WildcardNotAllowedAtRuntime`] if any pattern starts with `*.`. + /// No patterns are added if any pattern is invalid (atomic operation). + pub(crate) fn add_patterns(&mut self, raw_patterns: Vec) -> Result<()> { + // Validate all patterns before mutating (atomic batch) + for pattern in &raw_patterns { + validate_domain_pattern(pattern)?; + + if pattern.starts_with("*.") { + return Err(Error::WildcardNotAllowedAtRuntime(pattern.clone())); + } + } + + for pattern in &raw_patterns { + self.runtime_patterns.insert(pattern.to_lowercase()); + } + + Ok(()) + } + + /// Removes exact domain patterns from the runtime allowlist. + /// + /// Only runtime-added patterns can be removed. Config-time patterns + /// (set via [`new`](Self::new)) are structurally immutable and cannot + /// be removed — attempts to remove them are silently ignored (idempotent). + /// + /// Note: if a config-time wildcard pattern (e.g. `*.example.com`) covers + /// a runtime domain being removed, the domain will still be allowed via + /// the config-time pattern. Use [`is_runtime_domain`](Self::is_runtime_domain) + /// to inspect runtime membership. + /// + /// # Errors + /// + /// Returns [`Error::WildcardNotAllowedAtRuntime`] if any pattern starts with `*.`. + /// No patterns are removed if any pattern is invalid (atomic operation). + pub(crate) fn remove_patterns(&mut self, raw_patterns: &[String]) -> Result { + for pattern in raw_patterns { + if pattern.starts_with("*.") { + return Err(Error::WildcardNotAllowedAtRuntime(pattern.clone())); + } + } + + let mut removed = 0; + + for pattern in raw_patterns { + if self.runtime_patterns.remove(&pattern.to_lowercase()) { + removed += 1; + } + } + + Ok(removed) + } + + /// Removes all runtime-added patterns, preserving config-time patterns. + /// + /// Returns the number of patterns removed. + pub(crate) fn remove_all_runtime_patterns(&mut self) -> usize { + let count = self.runtime_patterns.len(); + + self.runtime_patterns.clear(); + count + } + + /// Returns `true` if the given domain is in the runtime allowlist. + /// + /// Config-time patterns are not checked. This is useful for inspecting + /// whether a domain was dynamically added at runtime. + pub fn is_runtime_domain(&self, domain: &str) -> bool { + self.runtime_patterns.contains(&domain.to_lowercase()) + } + + /// Returns the number of runtime-added patterns. + pub fn runtime_pattern_count(&self) -> usize { + self.runtime_patterns.len() + } + + /// Returns the number of config-time patterns. + pub fn config_pattern_count(&self) -> usize { + self.init_patterns.len() + } + + /// Returns `true` if the allowlist has no patterns (blocks all requests). + pub fn is_empty(&self) -> bool { + self.init_patterns.is_empty() && self.runtime_patterns.is_empty() + } + + /// Returns the total number of patterns in the allowlist (config + runtime). + pub fn pattern_count(&self) -> usize { + self.init_patterns.len() + self.runtime_patterns.len() + } +} + +/// Detects IP-like hostnames that might bypass simple `IpAddr` parsing. +/// +/// Catches decimal IPs (e.g. `2130706433` for `127.0.0.1`), octal notation +/// (e.g. `0177.0.0.1`), hex notation (e.g. `0x7f.0.0.1`), and bracket-wrapped IPv6. +fn is_ip_like(host: &str) -> bool { + let host = host.trim_start_matches('[').trim_end_matches(']'); + + // Pure numeric (decimal IP encoding) + if host.chars().all(|c| c.is_ascii_digit()) && !host.is_empty() { + return true; + } + + if host.starts_with("0x") || host.starts_with("0X") { + return true; + } + + // Dotted segments that look like octal or hex + let segments: Vec<&str> = host.split('.').collect(); + + if segments.len() >= 2 + && segments.iter().all(|s| { + !s.is_empty() + && (s.chars().all(|c| c.is_ascii_digit()) || s.starts_with("0x") || s.starts_with("0X")) + }) + { + return true; + } + + false +} + +/// Checks if a resolved IP address is in a private/reserved range. +/// +/// Used for anti-DNS-rebinding protection after DNS resolution. +pub fn is_private_ip(ip: &IpAddr) -> bool { + match ip { + IpAddr::V4(v4) => { + v4.is_loopback() // 127.0.0.0/8 + || v4.is_private() // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 + || v4.is_link_local() // 169.254.0.0/16 + || v4.is_unspecified() // 0.0.0.0 + || v4.is_broadcast() // 255.255.255.255 + } + IpAddr::V6(v6) => { + v6.is_loopback() // ::1 + || v6.is_unspecified() // :: + || is_ipv6_unique_local(v6) // fc00::/7 + || is_ipv6_link_local(v6) // fe80::/10 + || is_ipv4_mapped_private(v6) + } + } +} + +/// fc00::/7 — top 7 bits must be `1111110`. Mask with 0xfe00 (7 ones + +/// 9 zeros in a 16-bit segment) and compare against 0xfc00. +fn is_ipv6_unique_local(v6: &std::net::Ipv6Addr) -> bool { + (v6.segments()[0] & 0xfe00) == 0xfc00 +} + +/// fe80::/10 — top 10 bits must be `1111111010`. Mask with 0xffc0 (10 ones + +/// 6 zeros) and compare against 0xfe80. +fn is_ipv6_link_local(v6: &std::net::Ipv6Addr) -> bool { + (v6.segments()[0] & 0xffc0) == 0xfe80 +} + +fn is_ipv4_mapped_private(v6: &std::net::Ipv6Addr) -> bool { + if let Some(v4) = v6.to_ipv4_mapped() { + is_private_ip(&IpAddr::V4(v4)) + } else { + false + } +} + +/// Validates that a domain pattern string is well-formed. +/// +/// Rejects: +/// - Empty or whitespace-only strings +/// - Patterns containing control characters (`\n`, `\r`, `\t`, etc.) +/// - Patterns longer than 253 characters (DNS max) +/// - Patterns containing URL-reserved characters (`:/?#@`) +/// +/// This does NOT check whether the pattern is a wildcard — that is handled +/// separately in [`DomainAllowlist::add_patterns`]. +pub fn validate_domain_pattern(pattern: &str) -> Result<()> { + if pattern != pattern.trim() { + return Err(Error::InvalidDomainPattern( + "pattern must not have leading or trailing whitespace".to_string(), + )); + } + + if pattern.is_empty() { + return Err(Error::InvalidDomainPattern( + "pattern must not be empty or whitespace-only".to_string(), + )); + } + + if pattern == "*" { + return Err(Error::InvalidDomainPattern( + "bare '*' is not supported; use '*.domain.com' for subdomain wildcards".to_string(), + )); + } + + if pattern.chars().any(|c| c.is_control()) { + return Err(Error::InvalidDomainPattern( + "pattern must not contain control characters".to_string(), + )); + } + + if pattern.len() > 253 { + return Err(Error::InvalidDomainPattern(format!( + "pattern length {} exceeds maximum of 253 characters", + pattern.len(), + ))); + } + + const RESERVED: &[char] = &[':', '/', '?', '#', '@']; + + if pattern.contains(RESERVED) { + return Err(Error::InvalidDomainPattern( + "pattern must not contain URL-reserved characters (:/?#@)".to_string(), + )); + } + + Ok(()) +} + +/// Convenience function: creates patterns for both `domain` and `*.domain`. +pub fn allow_domain_with_subdomains(domain: &str) -> Vec { + vec![domain.to_string(), format!("*.{domain}")] +} + +/// Returns a closure that matches a domain if it equals `base_domain` or is a +/// subdomain of it (e.g. `api.example.com` for base `example.com`). +/// +/// The returned closure is `Send + Sync + 'static`, suitable for use in custom +/// validation logic or middleware. +/// +/// # Examples +/// +/// ```no_run +/// use tauri_plugin_http_client::allowlist::subdomain_validator; +/// +/// let matches = subdomain_validator("example.com"); +/// assert!(matches("example.com")); +/// assert!(matches("api.example.com")); +/// assert!(!matches("notexample.com")); +/// ``` +pub fn subdomain_validator(base_domain: &str) -> impl Fn(&str) -> bool + Send + Sync + 'static { + let base = base_domain.to_lowercase(); + + move |domain: &str| { + let d = domain.to_lowercase(); + + if d == base { + return true; + } + + d.ends_with(&base) && d.len() > base.len() && d.as_bytes()[d.len() - base.len() - 1] == b'.' + } +} + +/// Returns a closure that matches a domain if it exactly equals one of the +/// provided domains (case-insensitive). +/// +/// # Examples +/// +/// ```no_run +/// use tauri_plugin_http_client::allowlist::exact_domains_validator; +/// +/// let matches = exact_domains_validator(&["api.example.com", "cdn.example.com"]); +/// assert!(matches("api.example.com")); +/// assert!(matches("CDN.Example.COM")); +/// assert!(!matches("other.example.com")); +/// ``` +pub fn exact_domains_validator(domains: &[&str]) -> impl Fn(&str) -> bool + Send + Sync + 'static { + let set: HashSet = domains.iter().map(|d| d.to_lowercase()).collect(); + + move |domain: &str| set.contains(&domain.to_lowercase()) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn allowlist(patterns: &[&str]) -> DomainAllowlist { + DomainAllowlist::new(patterns.iter().map(|s| s.to_string()).collect()).unwrap() + } + + // --- Pattern matching --- + + #[test] + fn test_exact_match() { + let al = allowlist(&["api.example.com"]); + + assert!(al.validate_url("https://api.example.com/path").is_ok()); + assert!(al.validate_url("https://other.example.com/path").is_err()); + } + + #[test] + fn test_wildcard_subdomain_match() { + let al = allowlist(&["*.example.com"]); + + assert!(al.validate_url("https://api.example.com/path").is_ok()); + assert!(al.validate_url("https://deep.sub.example.com/path").is_ok()); + // Wildcard does NOT match the base domain itself + assert!(al.validate_url("https://example.com/path").is_err()); + } + + #[test] + fn test_case_insensitive_matching() { + let al = allowlist(&["API.Example.COM"]); + + assert!(al.validate_url("https://api.example.com/path").is_ok()); + } + + #[test] + fn test_empty_allowlist_blocks_all() { + let al = allowlist(&[]); + + assert!(al.validate_url("https://example.com").is_err()); + assert!(al.is_empty()); + } + + #[test] + fn test_allow_domain_with_subdomains_helper() { + let patterns = allow_domain_with_subdomains("example.com"); + let al = DomainAllowlist::new(patterns).unwrap(); + + assert!(al.validate_url("https://example.com/path").is_ok()); + assert!(al.validate_url("https://api.example.com/path").is_ok()); + assert!(al.validate_url("https://evil.com").is_err()); + } + + // --- Scheme validation --- + + #[test] + fn test_http_scheme_allowed() { + let al = allowlist(&["example.com"]); + + assert!(al.validate_url("http://example.com").is_ok()); + assert!(al.validate_url("https://example.com").is_ok()); + } + + #[test] + fn test_non_http_scheme_rejected() { + let al = allowlist(&["example.com"]); + + let err = al.validate_url("ftp://example.com").unwrap_err(); + + assert!(matches!(err, Error::SchemeNotAllowed(_))); + } + + #[test] + fn test_javascript_scheme_rejected() { + let al = allowlist(&["example.com"]); + + // `url` crate may fail to parse this, which is also acceptable + let result = al.validate_url("javascript:alert(1)"); + + assert!(result.is_err()); + } + + // --- Userinfo rejection --- + + #[test] + fn test_userinfo_rejected() { + let al = allowlist(&["example.com"]); + + let err = al + .validate_url("https://user:pass@example.com") + .unwrap_err(); + + assert!(matches!(err, Error::UserinfoNotAllowed)); + } + + #[test] + fn test_username_only_rejected() { + let al = allowlist(&["example.com"]); + + let err = al.validate_url("https://user@example.com").unwrap_err(); + + assert!(matches!(err, Error::UserinfoNotAllowed)); + } + + // --- IP address rejection --- + + #[test] + fn test_ipv4_rejected() { + let al = allowlist(&["example.com"]); + + let err = al.validate_url("https://127.0.0.1/path").unwrap_err(); + + assert!(matches!(err, Error::IpAddressNotAllowed)); + } + + #[test] + fn test_ipv6_rejected() { + let al = allowlist(&["example.com"]); + + let err = al.validate_url("https://[::1]/path").unwrap_err(); + + assert!(matches!(err, Error::IpAddressNotAllowed)); + } + + #[test] + fn test_decimal_ip_rejected() { + let al = allowlist(&["example.com"]); + + // 2130706433 = 127.0.0.1 in decimal + let err = al.validate_url("https://2130706433/path").unwrap_err(); + + assert!(matches!(err, Error::IpAddressNotAllowed)); + } + + #[test] + fn test_hex_ip_rejected() { + let al = allowlist(&["example.com"]); + + let err = al.validate_url("https://0x7f000001/path").unwrap_err(); + + assert!(matches!(err, Error::IpAddressNotAllowed)); + } + + // --- Backslash rejection --- + + #[test] + fn test_backslash_rejected() { + let al = allowlist(&["example.com"]); + + let err = al + .validate_url("https://example.com\\@evil.com") + .unwrap_err(); + + assert!(matches!(err, Error::InvalidUrl(_))); + } + + // --- Trailing dot normalization --- + + #[test] + fn test_trailing_dot_normalized() { + let al = allowlist(&["example.com"]); + + assert!(al.validate_url("https://example.com./path").is_ok()); + } + + // --- Private IP detection --- + + #[test] + fn test_private_ipv4_ranges() { + assert!(is_private_ip(&"127.0.0.1".parse().unwrap())); + assert!(is_private_ip(&"10.0.0.1".parse().unwrap())); + assert!(is_private_ip(&"172.16.0.1".parse().unwrap())); + assert!(is_private_ip(&"192.168.1.1".parse().unwrap())); + assert!(is_private_ip(&"169.254.1.1".parse().unwrap())); + assert!(is_private_ip(&"0.0.0.0".parse().unwrap())); + + assert!(!is_private_ip(&"8.8.8.8".parse().unwrap())); + assert!(!is_private_ip(&"1.1.1.1".parse().unwrap())); + } + + #[test] + fn test_private_ipv6_ranges() { + assert!(is_private_ip(&"::1".parse().unwrap())); + assert!(is_private_ip(&"fc00::1".parse().unwrap())); + assert!(is_private_ip(&"fd00::1".parse().unwrap())); + assert!(is_private_ip(&"fe80::1".parse().unwrap())); + + assert!(!is_private_ip(&"2001:db8::1".parse().unwrap())); + } + + #[test] + fn test_ipv4_mapped_ipv6_private() { + // ::ffff:127.0.0.1 + assert!(is_private_ip(&"::ffff:127.0.0.1".parse().unwrap())); + assert!(!is_private_ip(&"::ffff:8.8.8.8".parse().unwrap())); + } + + // --- is_ip_like edge cases --- + + #[test] + fn test_is_ip_like_dotted_octal() { + // Octal representation: 0177.0.0.1 = 127.0.0.1 + assert!(is_ip_like("0177.0.0.1")); + } + + #[test] + fn test_is_ip_like_dotted_hex() { + assert!(is_ip_like("0x7f.0x0.0x0.0x1")); + } + + #[test] + fn test_is_ip_like_pure_decimal() { + assert!(is_ip_like("2130706433")); + } + + #[test] + fn test_is_ip_like_hex_prefix() { + assert!(is_ip_like("0x7f000001")); + } + + #[test] + fn test_is_ip_like_rejects_normal_domains() { + assert!(!is_ip_like("example.com")); + assert!(!is_ip_like("api.example.com")); + assert!(!is_ip_like("my-domain.org")); + } + + #[test] + fn test_is_ip_like_does_not_flag_real_domains() { + // Domains with IP-like segments but containing non-numeric chars + assert!(!is_ip_like("192-168-1-1.example.com")); + assert!(!is_ip_like("ip-10-0-0-1.ec2.internal")); + assert!(!is_ip_like("host123.example.com")); + } + + #[test] + fn test_is_ip_like_empty_string_safe() { + assert!(!is_ip_like("")); + } + + #[test] + fn test_is_ip_like_bracketed_ipv6() { + // After bracket stripping, "::1" is not caught by is_ip_like + // (it contains colons, not digits-only) - but the url crate + // would parse it as IPv6 and catch it in validate_parsed_url + assert!(!is_ip_like("[::1]")); + } + + // --- Wildcard edge cases --- + + #[test] + fn test_wildcard_does_not_match_partial_suffix() { + let al = allowlist(&["*.example.com"]); + + // "notexample.com" ends with "example.com" but should not match + assert!(al.validate_url("https://notexample.com").is_err()); + } + + #[test] + fn test_multiple_patterns() { + let al = allowlist(&["api.example.com", "*.cdn.example.com"]); + + assert!(al.validate_url("https://api.example.com/path").is_ok()); + assert!( + al.validate_url("https://img.cdn.example.com/pic.png") + .is_ok() + ); + assert!(al.validate_url("https://example.com").is_err()); + } + + // --- URL parsing edge cases --- + + #[test] + fn test_data_scheme_rejected() { + let al = allowlist(&["example.com"]); + let result = al.validate_url("data:text/html,

hello

"); + + assert!(result.is_err()); + } + + #[test] + fn test_url_with_port_allowed() { + let al = allowlist(&["example.com"]); + + assert!(al.validate_url("https://example.com:8443/path").is_ok()); + } + + #[test] + fn test_url_with_query_and_fragment() { + let al = allowlist(&["example.com"]); + + assert!( + al.validate_url("https://example.com/path?key=val#frag") + .is_ok() + ); + } + + #[test] + fn test_empty_url_rejected() { + let al = allowlist(&["example.com"]); + + assert!(al.validate_url("").is_err()); + } + + // --- Private IP edge cases --- + + #[test] + fn test_broadcast_is_private() { + assert!(is_private_ip(&"255.255.255.255".parse().unwrap())); + } + + #[test] + fn test_ipv6_unspecified_is_private() { + assert!(is_private_ip(&"::".parse().unwrap())); + } + + #[test] + fn test_172_boundary_values() { + // 172.16.0.0 - 172.31.255.255 is private + assert!(is_private_ip(&"172.16.0.1".parse().unwrap())); + assert!(is_private_ip(&"172.31.255.254".parse().unwrap())); + // 172.32.0.0 is NOT private + assert!(!is_private_ip(&"172.32.0.1".parse().unwrap())); + } + + // --- Dynamic allowlist (add_patterns) --- + + #[test] + fn test_add_patterns_allows_new_domain() { + let mut al = allowlist(&["api.example.com"]); + + assert!(al.validate_url("https://new.example.com").is_err()); + + al.add_patterns(vec!["new.example.com".to_string()]) + .unwrap(); + + assert!(al.validate_url("https://new.example.com").is_ok()); + // Original pattern still works + assert!(al.validate_url("https://api.example.com").is_ok()); + } + + #[test] + fn test_add_patterns_rejects_wildcards() { + let mut al = allowlist(&["api.example.com"]); + + let result = al.add_patterns(vec!["*.cdn.example.com".to_string()]); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::WildcardNotAllowedAtRuntime(_) + )); + // Allowlist unchanged after rejection + assert_eq!(al.pattern_count(), 1); + } + + #[test] + fn test_add_patterns_rejects_wildcard_in_batch() { + let mut al = allowlist(&[]); + + // If any pattern is a wildcard, none should be added + let result = al.add_patterns(vec![ + "good.example.com".to_string(), + "*.bad.example.com".to_string(), + ]); + + assert!(result.is_err()); + assert!(al.is_empty()); + } + + #[test] + fn test_add_patterns_to_empty_allowlist() { + let mut al = allowlist(&[]); + + assert!(al.is_empty()); + assert!(al.validate_url("https://example.com").is_err()); + + al.add_patterns(vec!["example.com".to_string()]).unwrap(); + + assert!(!al.is_empty()); + assert!(al.validate_url("https://example.com").is_ok()); + } + + #[test] + fn test_add_patterns_duplicates_accepted() { + let mut al = allowlist(&["example.com"]); + + al.add_patterns(vec!["example.com".to_string()]).unwrap(); + + // HashSet deduplicates runtime patterns; init pattern still counted + // init: ["example.com"], runtime: {"example.com"} = 2 total + assert_eq!(al.pattern_count(), 2); + assert!(al.validate_url("https://example.com").is_ok()); + } + + #[test] + fn test_add_patterns_empty_vec_is_noop() { + let mut al = allowlist(&["example.com"]); + let count_before = al.pattern_count(); + + al.add_patterns(vec![]).unwrap(); + + assert_eq!(al.pattern_count(), count_before); + } + + #[test] + fn test_add_patterns_case_insensitive() { + let mut al = allowlist(&[]); + + al.add_patterns(vec!["API.Example.COM".to_string()]) + .unwrap(); + + assert!(al.validate_url("https://api.example.com").is_ok()); + } + + #[test] + fn test_pattern_count() { + let al = allowlist(&["a.com", "b.com", "*.c.com"]); + + assert_eq!(al.pattern_count(), 3); + } + + #[test] + fn test_domain_pattern_partial_eq() { + assert_eq!( + DomainPattern::parse("example.com"), + DomainPattern::parse("example.com") + ); + assert_eq!( + DomainPattern::parse("*.example.com"), + DomainPattern::parse("*.example.com") + ); + // Same inner string but different variant + assert_ne!( + DomainPattern::parse("example.com"), + DomainPattern::parse("*.example.com") + ); + } + + // --- Dynamic allowlist (remove_patterns) --- + + #[test] + fn test_remove_patterns_removes_runtime_domain() { + let mut al = allowlist(&[]); + + al.add_patterns(vec!["api.example.com".to_string()]) + .unwrap(); + assert!(al.validate_url("https://api.example.com").is_ok()); + + let removed = al + .remove_patterns(&["api.example.com".to_string()]) + .unwrap(); + + assert_eq!(removed, 1); + assert!(al.validate_url("https://api.example.com").is_err()); + } + + #[test] + fn test_remove_patterns_does_not_affect_init_patterns() { + let mut al = allowlist(&["api.example.com"]); + + // Attempt to remove a config-time domain — should be a no-op + let removed = al + .remove_patterns(&["api.example.com".to_string()]) + .unwrap(); + + assert_eq!(removed, 0); + assert!(al.validate_url("https://api.example.com").is_ok()); + } + + #[test] + fn test_remove_patterns_idempotent() { + let mut al = allowlist(&[]); + + // Removing a domain that was never added + let removed = al + .remove_patterns(&["nonexistent.com".to_string()]) + .unwrap(); + + assert_eq!(removed, 0); + } + + #[test] + fn test_remove_patterns_rejects_wildcards() { + let mut al = allowlist(&[]); + + al.add_patterns(vec!["api.example.com".to_string()]) + .unwrap(); + + let result = al.remove_patterns(&["*.example.com".to_string()]); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::WildcardNotAllowedAtRuntime(_) + )); + // Runtime pattern should still be present (atomic rejection) + assert_eq!(al.runtime_pattern_count(), 1); + } + + #[test] + fn test_remove_patterns_partial_match() { + let mut al = allowlist(&[]); + + al.add_patterns(vec![ + "a.example.com".to_string(), + "b.example.com".to_string(), + ]) + .unwrap(); + + let removed = al.remove_patterns(&["a.example.com".to_string()]).unwrap(); + + assert_eq!(removed, 1); + assert!(al.validate_url("https://a.example.com").is_err()); + assert!(al.validate_url("https://b.example.com").is_ok()); + } + + #[test] + fn test_remove_patterns_case_insensitive() { + let mut al = allowlist(&[]); + + al.add_patterns(vec!["api.example.com".to_string()]) + .unwrap(); + + let removed = al + .remove_patterns(&["API.Example.COM".to_string()]) + .unwrap(); + + assert_eq!(removed, 1); + assert!(al.validate_url("https://api.example.com").is_err()); + } + + #[test] + fn test_remove_all_runtime_patterns() { + let mut al = allowlist(&["init.example.com"]); + + al.add_patterns(vec![ + "a.example.com".to_string(), + "b.example.com".to_string(), + ]) + .unwrap(); + + let removed = al.remove_all_runtime_patterns(); + + assert_eq!(removed, 2); + assert!(al.validate_url("https://a.example.com").is_err()); + assert!(al.validate_url("https://b.example.com").is_err()); + // Config-time pattern preserved + assert!(al.validate_url("https://init.example.com").is_ok()); + } + + #[test] + fn test_remove_all_runtime_patterns_empty() { + let mut al = allowlist(&["init.example.com"]); + + let removed = al.remove_all_runtime_patterns(); + + assert_eq!(removed, 0); + } + + #[test] + fn test_runtime_pattern_count() { + let mut al = allowlist(&["init.example.com"]); + + assert_eq!(al.runtime_pattern_count(), 0); + + al.add_patterns(vec![ + "a.example.com".to_string(), + "b.example.com".to_string(), + ]) + .unwrap(); + + assert_eq!(al.runtime_pattern_count(), 2); + } + + #[test] + fn test_add_then_remove_then_add_again() { + let mut al = allowlist(&[]); + + al.add_patterns(vec!["api.example.com".to_string()]) + .unwrap(); + assert!(al.validate_url("https://api.example.com").is_ok()); + + al.remove_patterns(&["api.example.com".to_string()]) + .unwrap(); + assert!(al.validate_url("https://api.example.com").is_err()); + + al.add_patterns(vec!["api.example.com".to_string()]) + .unwrap(); + assert!(al.validate_url("https://api.example.com").is_ok()); + } + + #[test] + fn test_pattern_count_reflects_both_init_and_runtime() { + let mut al = allowlist(&["a.com", "*.b.com"]); + + assert_eq!(al.pattern_count(), 2); + assert_eq!(al.config_pattern_count(), 2); + assert_eq!(al.runtime_pattern_count(), 0); + + al.add_patterns(vec!["c.com".to_string(), "d.com".to_string()]) + .unwrap(); + + assert_eq!(al.pattern_count(), 4); + assert_eq!(al.config_pattern_count(), 2); + assert_eq!(al.runtime_pattern_count(), 2); + } + + #[test] + fn test_is_empty_after_removing_all_runtime() { + let mut al = allowlist(&[]); + + al.add_patterns(vec!["api.example.com".to_string()]) + .unwrap(); + assert!(!al.is_empty()); + + al.remove_all_runtime_patterns(); + assert!(al.is_empty()); + + // With init patterns, still not empty + let mut al2 = allowlist(&["init.example.com"]); + + al2.add_patterns(vec!["api.example.com".to_string()]) + .unwrap(); + al2.remove_all_runtime_patterns(); + assert!(!al2.is_empty()); + } + + #[test] + fn test_remove_patterns_empty_vec_is_noop() { + let mut al = allowlist(&[]); + + al.add_patterns(vec!["api.example.com".to_string()]) + .unwrap(); + + let removed = al.remove_patterns(&[]).unwrap(); + + assert_eq!(removed, 0); + assert_eq!(al.runtime_pattern_count(), 1); + } + + #[test] + fn test_is_runtime_domain() { + let mut al = allowlist(&["init.example.com"]); + + al.add_patterns(vec!["runtime.example.com".to_string()]) + .unwrap(); + + assert!(al.is_runtime_domain("runtime.example.com")); + assert!(!al.is_runtime_domain("init.example.com")); + assert!(!al.is_runtime_domain("nonexistent.com")); + } + + #[test] + fn test_add_patterns_deduplicates_in_hashset() { + let mut al = allowlist(&[]); + + al.add_patterns(vec![ + "api.example.com".to_string(), + "api.example.com".to_string(), + ]) + .unwrap(); + + assert_eq!(al.runtime_pattern_count(), 1); + } + + // --- validate_domain_pattern --- + + #[test] + fn test_validate_domain_pattern_valid() { + assert!(validate_domain_pattern("example.com").is_ok()); + assert!(validate_domain_pattern("api.example.com").is_ok()); + assert!(validate_domain_pattern("*.example.com").is_ok()); + assert!(validate_domain_pattern("my-domain.org").is_ok()); + } + + #[test] + fn test_validate_domain_pattern_empty() { + let err = validate_domain_pattern("").unwrap_err(); + + assert!(matches!(err, Error::InvalidDomainPattern(_))); + } + + #[test] + fn test_validate_domain_pattern_whitespace_only() { + assert!(validate_domain_pattern(" ").is_err()); + assert!(validate_domain_pattern("\t").is_err()); + } + + #[test] + fn test_validate_domain_pattern_leading_trailing_whitespace_rejected() { + let err = validate_domain_pattern(" example.com ").unwrap_err(); + + assert!(matches!(err, Error::InvalidDomainPattern(_))); + assert!(err.to_string().contains("whitespace")); + } + + #[test] + fn test_validate_domain_pattern_leading_whitespace_rejected() { + assert!(validate_domain_pattern(" example.com").is_err()); + } + + #[test] + fn test_validate_domain_pattern_trailing_whitespace_rejected() { + assert!(validate_domain_pattern("example.com ").is_err()); + } + + #[test] + fn test_validate_domain_pattern_tab_whitespace_rejected() { + assert!(validate_domain_pattern("\texample.com").is_err()); + } + + #[test] + fn test_validate_domain_pattern_control_characters() { + assert!(validate_domain_pattern("example\n.com").is_err()); + assert!(validate_domain_pattern("example\r.com").is_err()); + assert!(validate_domain_pattern("example\t.com").is_err()); + assert!(validate_domain_pattern("example\0.com").is_err()); + } + + #[test] + fn test_validate_domain_pattern_too_long() { + let long_pattern = "a".repeat(254); + let err = validate_domain_pattern(&long_pattern).unwrap_err(); + + assert!(matches!(err, Error::InvalidDomainPattern(_))); + + // Exactly 253 is fine + let max_pattern = "a".repeat(253); + + assert!(validate_domain_pattern(&max_pattern).is_ok()); + } + + #[test] + fn test_validate_domain_pattern_url_reserved_chars() { + assert!(validate_domain_pattern("example.com:8080").is_err()); + assert!(validate_domain_pattern("example.com/path").is_err()); + assert!(validate_domain_pattern("example.com?q=1").is_err()); + assert!(validate_domain_pattern("example.com#frag").is_err()); + assert!(validate_domain_pattern("user@example.com").is_err()); + } + + #[test] + fn test_add_patterns_validates_before_mutating() { + let mut al = allowlist(&[]); + + // Batch with one valid, one invalid: none should be added + let result = al.add_patterns(vec![ + "good.example.com".to_string(), + "bad\n.example.com".to_string(), + ]); + + assert!(result.is_err()); + assert!(al.is_empty()); + } + + #[test] + fn test_validate_domain_pattern_bare_wildcard_rejected() { + let err = validate_domain_pattern("*").unwrap_err(); + + assert!(matches!(err, Error::InvalidDomainPattern(_))); + assert!(err.to_string().contains("bare '*'")); + } + + #[test] + fn test_new_rejects_wildcard_with_empty_base() { + // "*."" strips to "", which fails the empty check + let result = DomainAllowlist::new(vec!["*.".to_string()]); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::InvalidDomainPattern(_) + )); + } + + #[test] + fn test_new_rejects_wildcard_with_invalid_base() { + // The base domain after "*." contains a URL-reserved character + let result = DomainAllowlist::new(vec!["*.com:443".to_string()]); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::InvalidDomainPattern(_) + )); + } + + #[test] + fn test_validate_url_rejects_octal_ip() { + let al = allowlist(&["example.com"]); + + // Octal IP for 127.0.0.1 + let result = al.validate_url("https://0177.0.0.1/path"); + + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), Error::IpAddressNotAllowed)); + } + + #[test] + fn test_init_patterns_validated() { + // Valid patterns work + assert!(DomainAllowlist::new(vec!["example.com".to_string()]).is_ok()); + assert!(DomainAllowlist::new(vec!["*.example.com".to_string()]).is_ok()); + + // Invalid patterns rejected + assert!(DomainAllowlist::new(vec!["".to_string()]).is_err()); + assert!(DomainAllowlist::new(vec!["*".to_string()]).is_err()); + assert!(DomainAllowlist::new(vec!["example\n.com".to_string()]).is_err()); + assert!(DomainAllowlist::new(vec!["example.com:443".to_string()]).is_err()); + } + + #[test] + fn test_init_patterns_atomic_validation() { + // If any pattern is invalid, none should be accepted + let result = DomainAllowlist::new(vec![ + "good.example.com".to_string(), + "bad\n.example.com".to_string(), + ]); + + assert!(result.is_err()); + } + + #[test] + fn test_add_patterns_rejects_empty_pattern() { + let mut al = allowlist(&[]); + + let result = al.add_patterns(vec!["".to_string()]); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::InvalidDomainPattern(_) + )); + } + + #[test] + fn test_add_patterns_rejects_pattern_with_colon() { + let mut al = allowlist(&[]); + + let result = al.add_patterns(vec!["example.com:443".to_string()]); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::InvalidDomainPattern(_) + )); + } + + // --- subdomain_validator --- + + #[test] + fn test_subdomain_validator_matches_base_domain() { + let matches = subdomain_validator("example.com"); + + assert!(matches("example.com")); + } + + #[test] + fn test_subdomain_validator_matches_subdomain() { + let matches = subdomain_validator("example.com"); + + assert!(matches("api.example.com")); + assert!(matches("deep.sub.example.com")); + } + + #[test] + fn test_subdomain_validator_rejects_non_subdomain() { + let matches = subdomain_validator("example.com"); + + assert!(!matches("notexample.com")); + assert!(!matches("evil.com")); + } + + #[test] + fn test_subdomain_validator_case_insensitive() { + let matches = subdomain_validator("Example.COM"); + + assert!(matches("example.com")); + assert!(matches("API.EXAMPLE.COM")); + } + + // --- exact_domains_validator --- + + #[test] + fn test_exact_domains_validator_matches() { + let matches = exact_domains_validator(&["api.example.com", "cdn.example.com"]); + + assert!(matches("api.example.com")); + assert!(matches("cdn.example.com")); + } + + #[test] + fn test_exact_domains_validator_rejects_non_match() { + let matches = exact_domains_validator(&["api.example.com"]); + + assert!(!matches("other.example.com")); + assert!(!matches("example.com")); + } + + #[test] + fn test_exact_domains_validator_case_insensitive() { + let matches = exact_domains_validator(&["api.example.com"]); + + assert!(matches("API.Example.COM")); + } + + #[test] + fn test_exact_domains_validator_empty_list() { + let matches = exact_domains_validator(&[]); + + assert!(!matches("anything.com")); + } + + // --- validate_parsed_url direct tests --- + + #[test] + fn test_validate_parsed_url_rejects_non_http_scheme() { + let al = allowlist(&["example.com"]); + let url = Url::parse("ftp://example.com/file").unwrap(); + + assert!(matches!( + al.validate_parsed_url(&url).unwrap_err(), + Error::SchemeNotAllowed(_) + )); + } + + #[test] + fn test_validate_parsed_url_rejects_userinfo() { + let al = allowlist(&["example.com"]); + let url = Url::parse("https://user:pass@example.com").unwrap(); + + assert!(matches!( + al.validate_parsed_url(&url).unwrap_err(), + Error::UserinfoNotAllowed + )); + } + + #[test] + fn test_validate_parsed_url_rejects_ip_address() { + let al = allowlist(&["example.com"]); + let url = Url::parse("https://127.0.0.1/path").unwrap(); + + assert!(matches!( + al.validate_parsed_url(&url).unwrap_err(), + Error::IpAddressNotAllowed + )); + } + + #[test] + fn test_validate_parsed_url_allows_matching_domain() { + let al = allowlist(&["example.com"]); + let url = Url::parse("https://example.com/path").unwrap(); + + assert!(al.validate_parsed_url(&url).is_ok()); + } + + #[test] + fn test_validate_parsed_url_rejects_disallowed_domain() { + let al = allowlist(&["example.com"]); + let url = Url::parse("https://evil.com/path").unwrap(); + + assert!(matches!( + al.validate_parsed_url(&url).unwrap_err(), + Error::DomainNotAllowed(_) + )); + } + + // --- Port-based access --- + + #[test] + fn test_any_port_on_allowed_domain_is_accessible() { + let al = allowlist(&["example.com"]); + + // All ports should be allowed — port is not part of the domain check + assert!(al.validate_url("https://example.com:443/path").is_ok()); + assert!(al.validate_url("https://example.com:8443/path").is_ok()); + assert!(al.validate_url("http://example.com:8080/path").is_ok()); + assert!(al.validate_url("http://example.com:3000/path").is_ok()); + } +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..972c4f5 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,2740 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use base64::Engine; +use futures_util::StreamExt; +use parking_lot::RwLock; +use reqwest::redirect; + +use crate::allowlist::{DomainAllowlist, is_private_ip}; +use crate::config::{HttpClientConfig, RetryConfig}; +use crate::error::{Error, Result}; +use crate::types::{FetchRequest, FetchResponse}; + +/// Headers that are always forbidden in per-request and default headers. +/// +/// - `host`: Prevents virtual-host routing attacks that bypass the domain allowlist. +/// - `connection`, `keep-alive`, `upgrade`: Hop-by-hop headers managed by reqwest; +/// caller-supplied values interfere with connection pooling and HTTP/2 multiplexing. +/// - `transfer-encoding`, `te`, `trailer`: Body framing headers managed by reqwest; +/// caller-supplied values enable request smuggling attacks. +/// +/// Cookie/Authorization headers are intentionally excluded — this plugin's security +/// model is domain restriction, not credential management. Applications legitimately +/// need to set these for session-based APIs and auth tokens. +const FORBIDDEN_HEADERS: &[&str] = &[ + "host", + "connection", + "keep-alive", + "transfer-encoding", + "te", + "upgrade", + "trailer", +]; + +/// Header name prefixes that are always forbidden. +/// +/// - `sec-`: Browser-set security headers (`Sec-Fetch-*`, `Sec-CH-*`). +/// Setting these from Rust/JS would misrepresent the request context. +/// - `proxy-`: Proxy control headers (`Proxy-Authorization`, `Proxy-Connection`). +/// These affect proxy routing in ways outside the plugin's security model. +const FORBIDDEN_HEADER_PREFIXES: &[&str] = &["sec-", "proxy-"]; + +/// Validates that a header name is not in the forbidden list. +/// +/// Returns `Err(Error::ForbiddenHeader)` if the name matches any entry in +/// [`FORBIDDEN_HEADERS`] (case-insensitive exact match) or any prefix in +/// [`FORBIDDEN_HEADER_PREFIXES`] (case-insensitive starts-with). +pub(crate) fn validate_header_name(name: &str) -> Result<()> { + let lower = name.to_ascii_lowercase(); + + if FORBIDDEN_HEADERS.contains(&lower.as_str()) { + return Err(Error::ForbiddenHeader(lower)); + } + + for prefix in FORBIDDEN_HEADER_PREFIXES { + if lower.starts_with(prefix) { + return Err(Error::ForbiddenHeader(lower)); + } + } + + Ok(()) +} + +/// Core HTTP client state shared across all requests. +/// +/// Uses `Arc` internally so cloning is cheap. `reqwest::Client` also +/// uses internal `Arc`, making this safe to share as Tauri managed state. +/// +/// The allowlist is wrapped in `Arc>` to support +/// runtime mutations via [`add_allowed_domain`](Self::add_allowed_domain), +/// [`add_allowed_domains`](Self::add_allowed_domains), +/// [`remove_allowed_domain`](Self::remove_allowed_domain), +/// [`remove_allowed_domains`](Self::remove_allowed_domains), and +/// [`remove_all_runtime_domains`](Self::remove_all_runtime_domains). +/// The same `Arc` is shared with the redirect policy closure, ensuring +/// both always see the current allowlist state. +#[derive(Clone)] +pub struct HttpClientState { + client: reqwest::Client, + allowlist: Arc>, + config: Arc, +} + +/// Tracks in-flight requests for abort support. +/// +/// Maps request IDs to their `AbortHandle`, allowing cancellation from +/// the TypeScript guest via the `abort_request` command. +#[derive(Clone)] +pub struct InFlightRequests(Arc>>); + +impl Default for InFlightRequests { + fn default() -> Self { + Self(Arc::new(tokio::sync::RwLock::new(HashMap::new()))) + } +} + +impl InFlightRequests { + pub fn new() -> Self { + Self::default() + } + + /// Registers an abort handle for a request ID. + pub async fn register(&self, request_id: String, handle: tokio::task::AbortHandle) { + self.0.write().await.insert(request_id, handle); + } + + /// Removes a request ID from tracking (called on completion). + pub async fn remove(&self, request_id: &str) { + self.0.write().await.remove(request_id); + } +} + +impl HttpClientState { + /// Builds a new `HttpClientState` with a shared allowlist, reqwest client, and config. + /// + /// The `allowlist` Arc should be the same instance passed to + /// [`build_redirect_policy`], ensuring both the request validation path + /// and the redirect policy always read the same allowlist state. + pub fn new( + client: reqwest::Client, + allowlist: Arc>, + config: HttpClientConfig, + ) -> Self { + Self { + client, + allowlist, + config: Arc::new(config), + } + } + + /// Validates a URL against the current allowlist. + pub fn validate_url(&self, url: &str) -> Result { + self.allowlist.read().validate_url(url) + } + + /// Returns `true` if the allowlist has no patterns (blocks all requests). + pub fn is_allowlist_empty(&self) -> bool { + self.allowlist.read().is_empty() + } + + /// Adds a single domain pattern to the allowlist at runtime. + /// + /// This is a convenience wrapper around [`add_allowed_domains`](Self::add_allowed_domains). + pub fn add_allowed_domain(&self, domain: impl Into) -> Result<()> { + self.add_allowed_domains(vec![domain.into()]) + } + + /// Adds domain patterns to the allowlist at runtime. + /// + /// # Errors + /// + /// Returns [`Error::AllowlistSizeExceeded`] if adding the domains would + /// exceed the configured `max_allowlist_size` cap. + /// + /// Returns [`Error::WildcardNotAllowedAtRuntime`] if any pattern starts + /// with `*.`. Wildcard patterns should be configured at build time. + /// No patterns are added if any pattern is invalid (atomic operation). + pub fn add_allowed_domains( + &self, + domains: impl IntoIterator>, + ) -> Result<()> { + let domains: Vec = domains.into_iter().map(Into::into).collect(); + let mut al = self.allowlist.write(); + let current = al.pattern_count(); + let limit = self.config.max_allowlist_size; + + if current + domains.len() > limit { + return Err(Error::AllowlistSizeExceeded { + count: current + domains.len(), + limit, + }); + } + + tracing::info!( + patterns = ?domains, + "adding domains to allowlist" + ); + + al.add_patterns(domains)?; + + Ok(()) + } + + /// Removes a single domain pattern from the runtime allowlist. + /// + /// Returns `true` if the domain was found and removed, `false` if it was + /// not present in the runtime allowlist (config-time domains are not affected). + /// + /// This is a convenience wrapper around [`remove_allowed_domains`](Self::remove_allowed_domains). + pub fn remove_allowed_domain(&self, domain: impl Into) -> Result { + let count = self.remove_allowed_domains(vec![domain.into()])?; + + Ok(count > 0) + } + + /// Removes domain patterns from the runtime allowlist. + /// + /// Only runtime-added patterns are affected. Config-time patterns (set via + /// [`Builder::allowed_domains`](crate::Builder::allowed_domains)) cannot be + /// removed — attempts to remove them are silently ignored. + /// + /// Returns the number of patterns actually removed. + /// + /// # Allowlist Consistency + /// + /// Removal uses eventual consistency: in-flight requests may complete their + /// current hop, but the redirect policy will block subsequent hops to the + /// removed domain. The failure mode is always "fail secure" — a removed + /// domain produces `DomainNotAllowed`, never silent access. + /// + /// # Errors + /// + /// Returns [`Error::WildcardNotAllowedAtRuntime`] if any pattern starts + /// with `*.`. No patterns are removed if any pattern is invalid (atomic operation). + pub fn remove_allowed_domains( + &self, + domains: impl IntoIterator>, + ) -> Result { + let domains: Vec = domains.into_iter().map(Into::into).collect(); + let mut al = self.allowlist.write(); + + let removed = al.remove_patterns(&domains)?; + + if removed > 0 { + tracing::info!( + patterns = ?domains, + removed, + remaining = al.pattern_count(), + "removed domains from allowlist" + ); + } + + Ok(removed) + } + + /// Removes all runtime-added domain patterns, preserving config-time patterns. + /// + /// Returns the number of patterns removed. Useful for revoking all + /// session-scoped domain grants (e.g., on user logout). + pub fn remove_all_runtime_domains(&self) -> usize { + let mut al = self.allowlist.write(); + let removed = al.remove_all_runtime_patterns(); + + if removed > 0 { + tracing::info!( + removed, + remaining = al.pattern_count(), + "removed all runtime domains from allowlist" + ); + } + + removed + } + + /// Executes an HTTP request through the full validation and execution pipeline, + /// with optional automatic retry on transient failures. + /// + /// # Pipeline (per attempt) + /// + /// 1. Validate URL through allowlist security checks + /// 2. Build the reqwest request (method, headers, body, timeout) + /// 3. Execute with custom redirect policy + /// 4. Read response body with size limit enforcement + /// 5. Detect text vs binary content and encode accordingly + /// 6. Return structured response + /// + /// # Retry Behavior + /// + /// When retry is enabled (`RetryConfig::max_retries > 0`), transient errors + /// (connection failures, timeouts) and retryable status codes (default: + /// 429, 500, 502, 503, 504) trigger automatic retries with exponential + /// backoff and jitter. Security errors are never retried. + /// + /// The URL is re-validated against the allowlist on every attempt. If the + /// allowlist changes between retries (e.g., a domain is removed), the + /// subsequent attempt fails with `DomainNotAllowed` (fail-secure). + /// + /// Timeout is per-attempt: a request with `max_retries: 3` and a 10s + /// timeout could take up to ~43s (4 attempts + backoff delays). + /// + /// When retries are exhausted, the last response is returned (including + /// 5xx responses — these are valid HTTP responses, not transport errors). + /// Intermediate retryable responses are fully read and discarded; the + /// caller only sees the final attempt's response body. + pub async fn execute(&self, req: FetchRequest) -> Result { + let method = parse_method(req.method.as_deref().unwrap_or("GET"))?; + let body_bytes = match req.body { + Some(ref b) => Some(decode_request_body(b, req.body_encoding.as_deref())?), + None => None, + }; + let timeout = req + .timeout_ms + .map(Duration::from_millis) + .or(self.config.default_timeout); + + let max_retries = self.resolve_max_retries(&req); + let max_attempts = max_retries + 1; + let retry_config = &self.config.retry; + let method_retryable = retry_config.is_retryable_method(method.as_str()); + + // Parse and validate URL once before the retry loop. The URL string + // doesn't change between retries, so re-parsing is unnecessary. + let url = self.allowlist.read().validate_url(&req.url)?; + + let mut last_result: Option> = None; + let mut attempt: u32 = 0; + + // Bounded: returns when attempt + 1 >= max_attempts (should_retry = false) + loop { + assert!( + attempt < max_attempts, + "retry loop exceeded max_attempts ({max_attempts}); this is a bug" + ); + + if attempt > 0 { + let backoff = calculate_backoff(retry_config, attempt, last_result.as_ref()); + + tracing::debug!( + attempt, + max_attempts, + backoff_ms = backoff.as_millis() as u64, + url = %req.url, + "retrying request" + ); + + tokio::time::sleep(backoff).await; + + // SECURITY: Re-validate the parsed URL on every retry attempt. + // The allowlist is mutable at runtime — a domain removed between + // retries must cause immediate failure (fail-secure). We use + // validate_parsed_url (not validate_url) to avoid redundant + // string parsing since the URL itself hasn't changed. + self.allowlist.read().validate_parsed_url(&url)?; + } + + let result = self + .execute_once(&url, &method, &req.headers, body_bytes.as_deref(), timeout) + .await; + + let should_retry = attempt + 1 < max_attempts && method_retryable; + + match result { + Ok(resp) if should_retry && retry_config.is_retryable_status(resp.status) => { + last_result = Some(Ok(resp)); + attempt += 1; + } + Err(e) if should_retry && e.is_retryable() => { + last_result = Some(Err(e)); + attempt += 1; + } + Ok(mut resp) => { + resp.retry_count = attempt; + return Ok(resp); + } + Err(e) => return Err(e), + } + } + } + + /// Executes a single HTTP request attempt through the full pipeline. + /// + /// This is the inner implementation called by [`execute`](Self::execute) + /// on each attempt. It assumes URL validation has already been performed. + async fn execute_once( + &self, + url: &url::Url, + method: &reqwest::Method, + headers: &Option>, + body: Option<&[u8]>, + timeout: Option, + ) -> Result { + let mut builder = self.client.request(method.clone(), url.clone()); + + for (key, value) in &self.config.default_headers { + builder = builder.header(key.as_str(), value.as_str()); + } + + if let Some(headers) = headers { + for (key, value) in headers { + validate_header_name(key)?; + builder = builder.header(key.as_str(), value.as_str()); + } + } + + if let Some(body) = body { + builder = builder.body(body.to_vec()); + } + + if let Some(t) = timeout { + builder = builder.timeout(t); + } + + let response = builder.send().await?; + + // Track if we were redirected + let final_url = response.url().clone(); + let redirected = final_url.as_str() != url.as_str(); + + // Anti-DNS-rebinding: verify the resolved address is not private. + // NOTE: remote_addr() returns None through proxies, neutralizing this + // check. This is acceptable — proxy environments provide their own + // network-layer protections. The domain allowlist remains the primary + // security boundary. + if !self.config.allow_private_ip { + if let Some(remote_addr) = response.remote_addr() { + let ip = remote_addr.ip(); + + if is_private_ip(&ip) { + return Err(Error::DomainNotAllowed(format!( + "resolved to private ip address: {ip}" + ))); + } + } else { + tracing::warn!( + url = %final_url, + "remote_addr() returned None; DNS rebinding check skipped" + ); + } + } + + let status = response.status(); + let status_text = status.canonical_reason().unwrap_or("").to_string(); + + // Collect response headers (multi-value support) + let mut response_headers: HashMap> = HashMap::new(); + + // Extract content type and Retry-After before consuming the response + // via bytes_stream(), since we need them later. + let content_type = response + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + + for (name, value) in response.headers() { + let name = name.as_str().to_string(); + + if let Ok(v) = value.to_str() { + response_headers + .entry(name) + .or_default() + .push(v.to_string()); + } + } + + let body_bytes = self.read_body_with_limit(response).await?; + + let is_text = is_text_content_type(&content_type); + let (body, body_encoding) = if is_text { + ( + String::from_utf8_lossy(&body_bytes).into_owned(), + "utf8".to_string(), + ) + } else { + ( + base64::engine::general_purpose::STANDARD.encode(&body_bytes), + "base64".to_string(), + ) + }; + + Ok(FetchResponse { + status: status.as_u16(), + status_text, + headers: response_headers, + body, + body_encoding, + url: final_url.to_string(), + redirected, + retry_count: 0, // Set by execute() after the loop + }) + } + + /// Reads the response body as a stream, enforcing the configured size limit. + /// + /// Unlike buffering the entire body before checking, this aborts as soon as + /// accumulated bytes exceed the limit, preventing memory exhaustion. + async fn read_body_with_limit(&self, response: reqwest::Response) -> Result> { + let limit = self.config.max_response_body_size; + + if let Some(len) = response.content_length() + && len > limit as u64 + { + return Err(Error::ResponseTooLarge { + size: len.try_into().unwrap_or(usize::MAX), + limit, + }); + } + + // Pre-allocate using the (already validated) Content-Length when available. + // This is safe because we rejected values exceeding the limit above. + let capacity = response.content_length().unwrap_or(0) as usize; + let mut body = Vec::with_capacity(capacity); + let mut stream = response.bytes_stream(); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(Error::Request)?; + + body.extend_from_slice(&chunk); + + if body.len() > limit { + return Err(Error::ResponseTooLarge { + size: body.len(), + limit, + }); + } + } + + Ok(body) + } + + /// Resolves the effective max retries for a request, capping per-request + /// overrides at the plugin-level configuration ceiling. + fn resolve_max_retries(&self, req: &FetchRequest) -> u32 { + match req.max_retries { + Some(n) => n.min(self.config.retry.max_retries), + None => self.config.retry.max_retries, + } + } + + /// Aborts an in-flight request by ID. + /// + /// Returns `true` if a request with the given ID was found and aborted, + /// `false` if no matching request was in flight. + pub async fn abort(in_flight: &InFlightRequests, request_id: &str) -> bool { + // Drop the write lock before aborting to avoid holding it during task cancellation + let handle = { + let mut map = in_flight.0.write().await; + + map.remove(request_id) + }; + + if let Some(handle) = handle { + handle.abort(); + true + } else { + false + } + } +} + +fn parse_method(method: &str) -> Result { + match method.to_uppercase().as_str() { + "GET" => Ok(reqwest::Method::GET), + "POST" => Ok(reqwest::Method::POST), + "PUT" => Ok(reqwest::Method::PUT), + "DELETE" => Ok(reqwest::Method::DELETE), + "PATCH" => Ok(reqwest::Method::PATCH), + "HEAD" => Ok(reqwest::Method::HEAD), + "OPTIONS" => Ok(reqwest::Method::OPTIONS), + other => Err(Error::Other(format!("unsupported http method: {other}"))), + } +} + +fn decode_request_body(body: &str, encoding: Option<&str>) -> Result> { + match encoding.unwrap_or("utf8") { + "base64" => base64::engine::general_purpose::STANDARD + .decode(body) + .map_err(|e| Error::Other(format!("invalid base64 body: {e}"))), + _ => Ok(body.as_bytes().to_vec()), + } +} + +/// Determines if a Content-Type value represents text content. +/// +/// Uses substring matching on the full content-type so that structured +/// types like `application/vnd.api+json` are correctly detected as text. +fn is_text_content_type(content_type: &str) -> bool { + let ct = content_type.to_lowercase(); + + ct.starts_with("text/") + || ct.contains("json") + || ct.contains("xml") + || ct.contains("javascript") + || ct.contains("html") + || ct.contains("css") + || ct.contains("svg") + || ct.contains("yaml") + || ct.contains("toml") + || ct.contains("csv") + || ct.contains("form-urlencoded") +} + +/// Calculates the backoff duration for a retry attempt using exponential +/// backoff with jitter. +/// +/// For responses with `Retry-After` headers, the header value is used instead +/// of the calculated backoff (capped at `max_retry_after`). +/// +/// Jitter uses "equal jitter": `base/2 + random(0, base/2)`, which provides +/// decorrelation without excessive variance. +fn calculate_backoff( + config: &RetryConfig, + attempt: u32, + last_result: Option<&Result>, +) -> Duration { + // Check for Retry-After header on the last response + if let Some(Ok(resp)) = last_result + && let Some(retry_after) = parse_retry_after_from_response(resp) + { + return retry_after.min(config.max_retry_after); + } + + // Exponential backoff: initial * 2^(attempt-1) + let exponent = (attempt - 1).min(31); // prevent overflow + let base_ms = config.initial_backoff.as_millis() as u64; + let calculated_ms = base_ms.saturating_mul(1u64 << exponent); + let capped_ms = calculated_ms.min(config.max_backoff.as_millis() as u64); + + // Equal jitter: base/2 + random(0, base/2) + let half = capped_ms / 2; + let jitter = if half > 0 { + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .subsec_nanos() as u64; + + nanos % half + } else { + 0 + }; + + Duration::from_millis(half + jitter) +} + +/// Extracts a `Retry-After` duration from a response's headers. +/// +/// Supports both delta-seconds format (`Retry-After: 120`) and ignores +/// HTTP-date format (too complex to parse without a date library). +fn parse_retry_after_from_response(resp: &FetchResponse) -> Option { + let values = resp.headers.get("retry-after")?; + let value = values.first()?; + + value.trim().parse::().ok().map(Duration::from_secs) +} + +/// Builds a custom redirect policy that validates each redirect hop against the allowlist. +/// +/// This is the #1 security requirement: prevents SSRF via open redirects. +/// +/// The `allowlist` Arc should be the same instance used by `HttpClientState`, +/// ensuring the redirect policy always reads the current allowlist (including +/// any domains added at runtime via [`HttpClientState::add_allowed_domains`]). +pub fn build_redirect_policy( + allowlist: Arc>, + max_redirects: usize, +) -> redirect::Policy { + build_redirect_policy_inner(allowlist, max_redirects, false) +} + +/// Inner implementation shared by [`build_redirect_policy`] and test helpers. +/// +/// When `allow_private_ip` is `true`, the DNS rebinding check on redirect +/// targets is skipped. This is only used in integration tests where the test +/// server resolves to localhost. +fn build_redirect_policy_inner( + allowlist: Arc>, + max_redirects: usize, + allow_private_ip: bool, +) -> redirect::Policy { + redirect::Policy::custom(move |attempt| { + if attempt.previous().len() >= max_redirects { + // stop() returns the last redirect response as a non-error result, + // so the caller sees a 3xx status rather than a reqwest error. + return attempt.stop(); + } + + let url = attempt.url().clone(); + + // Validate each redirect hop against the current allowlist + if let Err(_e) = allowlist.read().validate_parsed_url(&url) { + return attempt.error(RedirectBlockedError(url.to_string())); + } + + // Check for private IP in redirect target. + // If DNS resolution fails (Err), we skip the private IP check. This is + // safe because: (1) the allowlist domain check above is the primary guard, + // and (2) the actual connection will fail downstream on DNS failure. + if !allow_private_ip && let Ok(addrs) = url.socket_addrs(|| None) { + for addr in &addrs { + if is_private_ip(&addr.ip()) { + return attempt.error(RedirectBlockedError(format!( + "redirect to private ip: {}", + addr.ip() + ))); + } + } + } + + attempt.follow() + }) +} + +/// Custom error type for redirect policy violations. +#[derive(Debug)] +pub(crate) struct RedirectBlockedError(pub(crate) String); + +impl std::fmt::Display for RedirectBlockedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "redirect to disallowed domain: {}", self.0) + } +} + +impl std::error::Error for RedirectBlockedError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_method() { + assert_eq!(parse_method("GET").unwrap(), reqwest::Method::GET); + assert_eq!(parse_method("post").unwrap(), reqwest::Method::POST); + assert_eq!(parse_method("Put").unwrap(), reqwest::Method::PUT); + assert_eq!(parse_method("DELETE").unwrap(), reqwest::Method::DELETE); + assert_eq!(parse_method("PATCH").unwrap(), reqwest::Method::PATCH); + assert_eq!(parse_method("HEAD").unwrap(), reqwest::Method::HEAD); + assert_eq!(parse_method("OPTIONS").unwrap(), reqwest::Method::OPTIONS); + assert!(parse_method("INVALID").is_err()); + } + + #[test] + fn test_decode_request_body_utf8() { + let body = decode_request_body("hello", Some("utf8")).unwrap(); + + assert_eq!(body, b"hello"); + } + + #[test] + fn test_decode_request_body_base64() { + let body = decode_request_body("aGVsbG8=", Some("base64")).unwrap(); + + assert_eq!(body, b"hello"); + } + + #[test] + fn test_decode_request_body_invalid_base64() { + assert!(decode_request_body("not valid base64!!!", Some("base64")).is_err()); + } + + #[test] + fn test_decode_request_body_default_encoding() { + let body = decode_request_body("hello", None).unwrap(); + + assert_eq!(body, b"hello"); + } + + #[test] + fn test_is_text_content_type() { + assert!(is_text_content_type("text/plain")); + assert!(is_text_content_type("text/html; charset=utf-8")); + assert!(is_text_content_type("application/json")); + assert!(is_text_content_type("application/xml")); + assert!(is_text_content_type("application/javascript")); + assert!(is_text_content_type("text/css")); + assert!(is_text_content_type("image/svg+xml")); + assert!(is_text_content_type("application/x-www-form-urlencoded")); + + assert!(!is_text_content_type("application/octet-stream")); + assert!(!is_text_content_type("image/png")); + assert!(!is_text_content_type("application/pdf")); + } + + #[test] + fn test_is_text_content_type_yaml_toml_csv() { + assert!(is_text_content_type("application/yaml")); + assert!(is_text_content_type("application/toml")); + assert!(is_text_content_type("text/csv")); + } + + #[test] + fn test_is_text_content_type_empty_string() { + assert!(!is_text_content_type("")); + } + + #[test] + fn test_yaml_content_type_is_text() { + assert!(is_text_content_type("application/x-yaml")); + } + + #[test] + fn test_content_type_with_charset_detected() { + assert!(is_text_content_type("text/plain; charset=iso-8859-1")); + } + + #[test] + fn test_parse_method_case_insensitive() { + assert_eq!(parse_method("get").unwrap(), reqwest::Method::GET); + assert_eq!(parse_method("Get").unwrap(), reqwest::Method::GET); + assert_eq!(parse_method("gEt").unwrap(), reqwest::Method::GET); + } + + #[test] + fn test_parse_method_unsupported_returns_error_with_method_name() { + let err = parse_method("TRACE").unwrap_err(); + let msg = err.to_string(); + + assert!( + msg.contains("TRACE"), + "error should contain method name: {msg}" + ); + } + + #[test] + fn test_decode_request_body_unknown_encoding_treated_as_utf8() { + let body = decode_request_body("hello", Some("unknown")).unwrap(); + + assert_eq!(body, b"hello"); + } + + // --- InFlightRequests / abort tests --- + + #[tokio::test] + async fn test_in_flight_register_and_remove() { + let in_flight = InFlightRequests::new(); + let handle = tokio::spawn(async { 42 }); + + in_flight + .register("req-1".to_string(), handle.abort_handle()) + .await; + in_flight.remove("req-1").await; + + // After removal, the map should be empty + let map = in_flight.0.read().await; + + assert!(map.is_empty()); + } + + #[tokio::test] + async fn test_in_flight_remove_nonexistent_is_noop() { + let in_flight = InFlightRequests::new(); + + // Should not panic + in_flight.remove("nonexistent").await; + } + + #[tokio::test] + async fn test_abort_registered_request_returns_true() { + let in_flight = InFlightRequests::new(); + + let handle = tokio::spawn(async { + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + in_flight + .register("req-1".to_string(), handle.abort_handle()) + .await; + + let aborted = HttpClientState::abort(&in_flight, "req-1").await; + + assert!(aborted); + assert!(handle.await.is_err()); + } + + #[tokio::test] + async fn test_abort_unknown_request_returns_false() { + let in_flight = InFlightRequests::new(); + + let aborted = HttpClientState::abort(&in_flight, "nonexistent").await; + + assert!(!aborted); + } + + #[tokio::test] + async fn test_abort_then_abort_again_returns_false() { + let in_flight = InFlightRequests::new(); + + let handle = tokio::spawn(async { + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + in_flight + .register("req-1".to_string(), handle.abort_handle()) + .await; + HttpClientState::abort(&in_flight, "req-1").await; + + let aborted_again = HttpClientState::abort(&in_flight, "req-1").await; + + assert!(!aborted_again); + } + + #[tokio::test] + async fn test_concurrent_register_and_abort_no_deadlock() { + let in_flight = InFlightRequests::new(); + + // Register multiple requests and abort them concurrently + for i in 0..10 { + let handle = tokio::spawn(async { + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + in_flight + .register(format!("req-{i}"), handle.abort_handle()) + .await; + } + + let in_flight_clone = in_flight.clone(); + let abort_handles: Vec<_> = (0..10) + .map(|i| { + let inf = in_flight_clone.clone(); + + tokio::spawn(async move { HttpClientState::abort(&inf, &format!("req-{i}")).await }) + }) + .collect(); + + for handle in abort_handles { + let result = handle.await.unwrap(); + + assert!(result); + } + } + + #[test] + fn test_redirect_blocked_error_display() { + let err = RedirectBlockedError("https://evil.com".to_string()); + + assert_eq!( + err.to_string(), + "redirect to disallowed domain: https://evil.com" + ); + } + + #[test] + fn test_http_client_state_accessors() { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(vec!["example.com".to_string()]).unwrap(), + )); + let client = reqwest::Client::new(); + let config = HttpClientConfig::default(); + + let state = HttpClientState::new(client, allowlist, config); + + assert!(state.validate_url("https://example.com").is_ok()); + assert!(state.validate_url("https://evil.com").is_err()); + assert!(!state.is_allowlist_empty()); + } + + #[test] + fn test_in_flight_requests_default() { + let in_flight = InFlightRequests::default(); + let in_flight2 = InFlightRequests::new(); + + assert!(!std::ptr::eq(&in_flight, &in_flight2)); + } + + // --- Wiremock-based integration tests --- + + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + /// Helper to build an HttpClientState with a given allowlist and custom config, + /// using the redirect policy. Does NOT allow private IPs (for testing the + /// DNS rebinding check itself). + fn build_test_state( + domains: &[&str], + max_redirects: usize, + max_body_size: usize, + ) -> HttpClientState { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(domains.iter().map(|s| s.to_string()).collect()).unwrap(), + )); + let policy = build_redirect_policy(Arc::clone(&allowlist), max_redirects); + let client = reqwest::Client::builder().redirect(policy).build().unwrap(); + let config = HttpClientConfig { + max_redirects, + max_response_body_size: max_body_size, + ..Default::default() + }; + + HttpClientState::new(client, allowlist, config) + } + + /// Helper to build an HttpClientState for integration tests using wiremock + /// on localhost. Bypasses the DNS rebinding check since wiremock resolves + /// to 127.0.0.1 (a private IP). The domain allowlist remains enforced. + fn build_localhost_test_state(max_redirects: usize, max_body_size: usize) -> HttpClientState { + build_localhost_test_state_with_config(HttpClientConfig { + max_redirects, + max_response_body_size: max_body_size, + allow_private_ip: true, + ..Default::default() + }) + } + + fn build_localhost_test_state_with_config(config: HttpClientConfig) -> HttpClientState { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(vec!["localhost".to_string()]).unwrap(), + )); + let policy = build_redirect_policy_inner( + Arc::clone(&allowlist), + config.max_redirects, + config.allow_private_ip, + ); + let client = reqwest::Client::builder().redirect(policy).build().unwrap(); + + HttpClientState::new(client, allowlist, config) + } + + /// Converts a wiremock server URI (http://127.0.0.1:PORT) to use localhost + /// so the URL passes domain allowlist validation. + fn localhost_url(server: &MockServer, path: &str) -> String { + let uri = server.uri(); + let port = uri.rsplit(':').next().unwrap(); + + format!("http://localhost:{port}{path}") + } + + fn make_request(url: &str) -> FetchRequest { + FetchRequest { + url: url.to_string(), + method: None, + headers: None, + body: None, + body_encoding: None, + timeout_ms: None, + request_id: None, + max_retries: None, + } + } + + // --- Redirect policy tests --- + // + // These tests use build_redirect_policy (or build_redirect_policy_inner) + // to test the actual plugin redirect logic, not reqwest's default policy. + // Tests that go through execute() use localhost_url() and + // build_localhost_test_state() to bypass the IP-address URL validation + // while still exercising the full pipeline. + + #[tokio::test] + async fn test_redirect_to_allowed_domain_succeeds() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/a")) + .respond_with( + ResponseTemplate::new(302).insert_header("Location", localhost_url(&server, "/b")), + ) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/b")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/a")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert!(resp.redirected); + assert_eq!(resp.body, "ok"); + } + + #[tokio::test] + async fn test_redirect_to_disallowed_domain_blocked() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/redirect")) + .respond_with( + ResponseTemplate::new(302).insert_header("Location", "https://evil.example.com/pwned"), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/redirect")); + let result = state.execute(req).await; + + assert!(result.is_err()); + assert!( + matches!(result.unwrap_err(), Error::RedirectBlocked(_)), + "should block redirect to disallowed domain" + ); + } + + #[tokio::test] + async fn test_redirect_chain_exceeds_max_hops() { + let server = MockServer::start().await; + + // Create a chain of 4 redirects (max is 3) + for i in 0..4 { + Mock::given(method("GET")) + .and(path(format!("/hop{i}"))) + .respond_with(ResponseTemplate::new(302).insert_header( + "Location", + localhost_url(&server, &format!("/hop{}", i + 1)), + )) + .mount(&server) + .await; + } + + Mock::given(method("GET")) + .and(path("/hop4")) + .respond_with(ResponseTemplate::new(200).set_body_string("final")) + .mount(&server) + .await; + + let state = build_localhost_test_state(3, 10_000_000); + let req = make_request(&localhost_url(&server, "/hop0")); + let resp = state.execute(req).await.unwrap(); + + // With max_redirects=3, the 4th redirect is stopped and the 3xx is returned + assert!( + resp.status >= 300 && resp.status < 400, + "should return redirect status when max hops exceeded, got: {}", + resp.status + ); + } + + #[tokio::test] + async fn test_redirect_within_same_domain_succeeds() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/start")) + .respond_with( + ResponseTemplate::new(302).insert_header("Location", localhost_url(&server, "/end")), + ) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/end")) + .respond_with(ResponseTemplate::new(200).set_body_string("final")) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/start")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert_eq!(resp.body, "final"); + assert!(resp.redirected); + } + + #[tokio::test] + async fn test_zero_max_redirects_blocks_all() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/start")) + .respond_with( + ResponseTemplate::new(302).insert_header("Location", localhost_url(&server, "/end")), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(0, 10_000_000); + let req = make_request(&localhost_url(&server, "/start")); + let resp = state.execute(req).await.unwrap(); + + // With max_redirects=0, the redirect is not followed + assert!(resp.status >= 300 && resp.status < 400); + assert!(!resp.redirected); + } + + // --- Body size limit tests --- + + #[tokio::test] + async fn test_body_within_limit_succeeds() { + let server = MockServer::start().await; + let body = "x".repeat(100); + + Mock::given(method("GET")) + .and(path("/small")) + .respond_with(ResponseTemplate::new(200).set_body_string(&body)) + .mount(&server) + .await; + + // Use 1000 byte limit, bypass allowlist by constructing state carefully + let state = build_test_state(&["localhost"], 10, 1000); + let client = reqwest::Client::new(); + let resp = client + .get(format!("{}/small", server.uri())) + .send() + .await + .unwrap(); + let result = state.read_body_with_limit(resp).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 100); + } + + #[tokio::test] + async fn test_content_length_exceeds_limit_early_reject() { + let server = MockServer::start().await; + let body = "x".repeat(2000); + + Mock::given(method("GET")) + .and(path("/big")) + .respond_with(ResponseTemplate::new(200).set_body_string(&body)) + .mount(&server) + .await; + + let state = build_test_state(&["localhost"], 10, 100); + let client = reqwest::Client::new(); + let resp = client + .get(format!("{}/big", server.uri())) + .send() + .await + .unwrap(); + let result = state.read_body_with_limit(resp).await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + + assert!( + matches!(err, Error::ResponseTooLarge { .. }), + "expected ResponseTooLarge, got: {err:?}" + ); + } + + #[tokio::test] + async fn test_chunked_body_exceeds_limit_aborts_midstream() { + let server = MockServer::start().await; + + // wiremock sends the body; with a small limit, streaming read will abort + let body = "x".repeat(500); + + Mock::given(method("GET")) + .and(path("/chunked")) + .respond_with(ResponseTemplate::new(200).set_body_string(&body)) + .mount(&server) + .await; + + let state = build_test_state(&["localhost"], 10, 100); + let client = reqwest::Client::new(); + let resp = client + .get(format!("{}/chunked", server.uri())) + .send() + .await + .unwrap(); + let result = state.read_body_with_limit(resp).await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::ResponseTooLarge { .. } + )); + } + + #[tokio::test] + async fn test_body_exactly_at_limit_succeeds() { + let server = MockServer::start().await; + let body = "x".repeat(100); + + Mock::given(method("GET")) + .and(path("/exact")) + .respond_with(ResponseTemplate::new(200).set_body_string(&body)) + .mount(&server) + .await; + + let state = build_test_state(&["localhost"], 10, 100); + let client = reqwest::Client::new(); + let resp = client + .get(format!("{}/exact", server.uri())) + .send() + .await + .unwrap(); + let result = state.read_body_with_limit(resp).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 100); + } + + #[tokio::test] + async fn test_empty_body_succeeds() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/empty")) + .respond_with(ResponseTemplate::new(200)) + .mount(&server) + .await; + + let state = build_test_state(&["localhost"], 10, 100); + let client = reqwest::Client::new(); + let resp = client + .get(format!("{}/empty", server.uri())) + .send() + .await + .unwrap(); + let result = state.read_body_with_limit(resp).await; + + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + } + + // --- DNS rebinding tests --- + + #[tokio::test] + async fn test_execute_rejects_ip_address_in_url() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/data")) + .respond_with(ResponseTemplate::new(200).set_body_string("secret")) + .mount(&server) + .await; + + // wiremock serves on 127.0.0.1, which is an IP literal in the URL. + // validate_url rejects IP addresses before the request is even sent, + // so the DNS rebinding check never runs. This test verifies that the + // IP address guard works at the URL validation layer. + let state = build_test_state(&["localhost"], 10, 10_000_000); + let req = make_request(&format!("{}/data", server.uri())); + let result = state.execute(req).await; + + assert!(result.is_err()); + assert!( + matches!(result.unwrap_err(), Error::IpAddressNotAllowed), + "should reject IP address in URL before sending request" + ); + } + + #[tokio::test] + async fn test_dns_rebinding_rejects_localhost_when_private_ip_check_enabled() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/data")) + .respond_with(ResponseTemplate::new(200).set_body_string("secret")) + .mount(&server) + .await; + + // Use the non-localhost state (private IP check enabled). + // localhost resolves to 127.0.0.1 which should be rejected. + let state = build_test_state(&["localhost"], 10, 10_000_000); + let req = make_request(&localhost_url(&server, "/data")); + let result = state.execute(req).await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + + assert!( + matches!(err, Error::DomainNotAllowed(ref msg) if msg.contains("private ip")), + "should reject private IP from DNS resolution, got: {err:?}" + ); + } + + // --- Dynamic allowlist tests --- + + #[test] + fn test_add_allowed_domain_validates_new_url() { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(vec!["api.example.com".to_string()]).unwrap(), + )); + let client = reqwest::Client::new(); + let config = HttpClientConfig::default(); + let state = HttpClientState::new(client, allowlist, config); + + assert!(state.validate_url("https://new.example.com").is_err()); + + state.add_allowed_domain("new.example.com").unwrap(); + + assert!(state.validate_url("https://new.example.com").is_ok()); + assert!(state.validate_url("https://api.example.com").is_ok()); + } + + #[test] + fn test_add_allowed_domains_batch() { + let allowlist = Arc::new(RwLock::new(DomainAllowlist::new(vec![]).unwrap())); + let client = reqwest::Client::new(); + let config = HttpClientConfig::default(); + let state = HttpClientState::new(client, allowlist, config); + + assert!(state.is_allowlist_empty()); + + state + .add_allowed_domains(["a.example.com", "b.example.com"]) + .unwrap(); + + assert!(!state.is_allowlist_empty()); + assert!(state.validate_url("https://a.example.com").is_ok()); + assert!(state.validate_url("https://b.example.com").is_ok()); + } + + #[test] + fn test_add_allowed_domains_cap_exceeded() { + let allowlist = Arc::new(RwLock::new(DomainAllowlist::new(vec![]).unwrap())); + let client = reqwest::Client::new(); + let config = HttpClientConfig { + max_allowlist_size: 2, + ..Default::default() + }; + let state = HttpClientState::new(client, allowlist, config); + + state.add_allowed_domains(["a.com", "b.com"]).unwrap(); + + let result = state.add_allowed_domain("c.com"); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::AllowlistSizeExceeded { count: 3, limit: 2 } + )); + } + + #[test] + fn test_redirect_policy_sees_dynamically_added_domain() { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(vec!["initial.example.com".to_string()]).unwrap(), + )); + let policy_allowlist = Arc::clone(&allowlist); + + // Before adding: validate_parsed_url should fail for the new domain + let new_url = url::Url::parse("https://added.example.com/path").unwrap(); + + assert!(allowlist.read().validate_parsed_url(&new_url).is_err()); + + // Add domain through the shared allowlist + allowlist + .write() + .add_patterns(vec!["added.example.com".to_string()]) + .unwrap(); + + // The same Arc (as redirect policy would use) now sees the new domain + assert!( + policy_allowlist + .read() + .validate_parsed_url(&new_url) + .is_ok() + ); + } + + // --- Dynamic allowlist removal tests --- + + #[test] + fn test_remove_allowed_domain_blocks_url() { + let allowlist = Arc::new(RwLock::new(DomainAllowlist::new(vec![]).unwrap())); + let client = reqwest::Client::new(); + let config = HttpClientConfig::default(); + let state = HttpClientState::new(client, allowlist, config); + + state.add_allowed_domain("api.example.com").unwrap(); + assert!(state.validate_url("https://api.example.com").is_ok()); + + let removed = state.remove_allowed_domain("api.example.com").unwrap(); + + assert!(removed); + assert!(state.validate_url("https://api.example.com").is_err()); + } + + #[test] + fn test_remove_allowed_domains_batch() { + let allowlist = Arc::new(RwLock::new(DomainAllowlist::new(vec![]).unwrap())); + let client = reqwest::Client::new(); + let config = HttpClientConfig::default(); + let state = HttpClientState::new(client, allowlist, config); + + state + .add_allowed_domains(["a.example.com", "b.example.com", "c.example.com"]) + .unwrap(); + + let removed = state + .remove_allowed_domains(["a.example.com", "c.example.com"]) + .unwrap(); + + assert_eq!(removed, 2); + assert!(state.validate_url("https://a.example.com").is_err()); + assert!(state.validate_url("https://b.example.com").is_ok()); + assert!(state.validate_url("https://c.example.com").is_err()); + } + + #[test] + fn test_remove_all_runtime_domains() { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(vec!["init.example.com".to_string()]).unwrap(), + )); + let client = reqwest::Client::new(); + let config = HttpClientConfig::default(); + let state = HttpClientState::new(client, allowlist, config); + + state + .add_allowed_domains(["a.example.com", "b.example.com"]) + .unwrap(); + + let removed = state.remove_all_runtime_domains(); + + assert_eq!(removed, 2); + assert!(state.validate_url("https://a.example.com").is_err()); + assert!(state.validate_url("https://b.example.com").is_err()); + // Config-time domain preserved + assert!(state.validate_url("https://init.example.com").is_ok()); + } + + #[test] + fn test_redirect_policy_sees_removal() { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(vec!["init.example.com".to_string()]).unwrap(), + )); + let policy_allowlist = Arc::clone(&allowlist); + + // Add a runtime domain + allowlist + .write() + .add_patterns(vec!["dynamic.example.com".to_string()]) + .unwrap(); + + let dynamic_url = url::Url::parse("https://dynamic.example.com/path").unwrap(); + + assert!( + policy_allowlist + .read() + .validate_parsed_url(&dynamic_url) + .is_ok() + ); + + // Remove the runtime domain + allowlist + .write() + .remove_patterns(&["dynamic.example.com".to_string()]) + .unwrap(); + + // The same Arc (as redirect policy would use) now rejects the domain + assert!( + policy_allowlist + .read() + .validate_parsed_url(&dynamic_url) + .is_err() + ); + } + + // --- Retry logic tests --- + + #[test] + fn test_resolve_max_retries_uses_config_default() { + let state = build_test_state_with_retry( + &["example.com"], + RetryConfig { + max_retries: 3, + ..RetryConfig::default() + }, + ); + let req = make_request("https://example.com"); + + assert_eq!(state.resolve_max_retries(&req), 3); + } + + #[test] + fn test_resolve_max_retries_per_request_override() { + let state = build_test_state_with_retry( + &["example.com"], + RetryConfig { + max_retries: 5, + ..RetryConfig::default() + }, + ); + let mut req = make_request("https://example.com"); + + req.max_retries = Some(2); + + assert_eq!(state.resolve_max_retries(&req), 2); + } + + #[test] + fn test_resolve_max_retries_per_request_capped_at_config() { + let state = build_test_state_with_retry( + &["example.com"], + RetryConfig { + max_retries: 3, + ..RetryConfig::default() + }, + ); + let mut req = make_request("https://example.com"); + + req.max_retries = Some(10); + + assert_eq!(state.resolve_max_retries(&req), 3); + } + + #[test] + fn test_resolve_max_retries_per_request_zero_disables() { + let state = build_test_state_with_retry( + &["example.com"], + RetryConfig { + max_retries: 3, + ..RetryConfig::default() + }, + ); + let mut req = make_request("https://example.com"); + + req.max_retries = Some(0); + + assert_eq!(state.resolve_max_retries(&req), 0); + } + + fn build_test_state_with_retry(domains: &[&str], retry: RetryConfig) -> HttpClientState { + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(domains.iter().map(|s| s.to_string()).collect()).unwrap(), + )); + let client = reqwest::Client::new(); + let config = HttpClientConfig { + retry, + ..Default::default() + }; + + HttpClientState::new(client, allowlist, config) + } + + #[test] + fn test_calculate_backoff_first_retry() { + let config = RetryConfig { + initial_backoff: Duration::from_millis(200), + max_backoff: Duration::from_secs(10), + ..RetryConfig::default() + }; + + let backoff = calculate_backoff(&config, 1, None); + + // Equal jitter: result should be between 0 and 200ms + assert!(backoff <= Duration::from_millis(200)); + } + + #[test] + fn test_calculate_backoff_exponential_growth() { + let config = RetryConfig { + initial_backoff: Duration::from_millis(200), + max_backoff: Duration::from_secs(60), + ..RetryConfig::default() + }; + + // attempt=1: base=200ms, attempt=2: base=400ms, attempt=3: base=800ms + let b1 = calculate_backoff(&config, 1, None); + let b3 = calculate_backoff(&config, 3, None); + + // b3 should have a higher ceiling (800ms) than b1 (200ms) + // Due to jitter, we can only check the ceiling + assert!(b1 <= Duration::from_millis(200)); + assert!(b3 <= Duration::from_millis(800)); + } + + #[test] + fn test_calculate_backoff_capped_at_max() { + let config = RetryConfig { + initial_backoff: Duration::from_millis(1000), + max_backoff: Duration::from_millis(2000), + ..RetryConfig::default() + }; + + // attempt=5: 1000 * 2^4 = 16000ms, capped to 2000ms + let backoff = calculate_backoff(&config, 5, None); + + assert!(backoff <= Duration::from_millis(2000)); + } + + #[test] + fn test_calculate_backoff_with_retry_after_header() { + let config = RetryConfig::default(); + let resp = FetchResponse { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["5".to_string()])]), + body: String::new(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); + + assert_eq!(backoff, Duration::from_secs(5)); + } + + #[test] + fn test_calculate_backoff_retry_after_capped() { + let config = RetryConfig { + max_retry_after: Duration::from_secs(10), + ..RetryConfig::default() + }; + let resp = FetchResponse { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["999".to_string()])]), + body: String::new(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); + + assert_eq!(backoff, Duration::from_secs(10)); + } + + #[test] + fn test_parse_retry_after_valid_seconds() { + let resp = FetchResponse { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["120".to_string()])]), + body: String::new(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + assert_eq!( + parse_retry_after_from_response(&resp), + Some(Duration::from_secs(120)) + ); + } + + #[test] + fn test_parse_retry_after_missing_header() { + let resp = FetchResponse { + status: 503, + status_text: "Service Unavailable".to_string(), + headers: HashMap::new(), + body: String::new(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + assert_eq!(parse_retry_after_from_response(&resp), None); + } + + #[test] + fn test_parse_retry_after_non_numeric_ignored() { + let resp = FetchResponse { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([( + "retry-after".to_string(), + vec!["Wed, 21 Oct 2025 07:28:00 GMT".to_string()], + )]), + body: String::new(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + assert_eq!(parse_retry_after_from_response(&resp), None); + } + + #[test] + fn test_parse_retry_after_zero_seconds() { + let resp = FetchResponse { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["0".to_string()])]), + body: String::new(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + assert_eq!( + parse_retry_after_from_response(&resp), + Some(Duration::from_secs(0)) + ); + } + + #[test] + fn test_calculate_backoff_with_retry_after_zero() { + let config = RetryConfig::default(); + let resp = FetchResponse { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["0".to_string()])]), + body: String::new(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); + + assert_eq!(backoff, Duration::from_secs(0)); + } + + #[test] + fn test_calculate_backoff_no_overflow_on_high_attempt() { + let config = RetryConfig { + initial_backoff: Duration::from_millis(200), + max_backoff: Duration::from_secs(10), + ..RetryConfig::default() + }; + + // Very high attempt number should not panic + let backoff = calculate_backoff(&config, 100, None); + + assert!(backoff <= Duration::from_secs(10)); + } + + #[tokio::test] + async fn test_retry_on_500_then_success() { + let server = MockServer::start().await; + + // First call returns 500, second returns 200 + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(500).set_body_string("error")) + .up_to_n_times(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 2, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + retryable_methods: None, + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + let req = make_request(&localhost_url(&server, "/api")); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert_eq!(resp.body, "ok"); + assert_eq!(resp.retry_count, 1); + } + + #[tokio::test] + async fn test_retry_exhausted_returns_last_response() { + let server = MockServer::start().await; + + // All calls return 503 + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(503).set_body_string("unavailable")) + .expect(3) // initial + 2 retries + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 2, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + let req = make_request(&localhost_url(&server, "/api")); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 503); + assert_eq!(resp.retry_count, 2); + } + + #[tokio::test] + async fn test_post_not_retried_by_default() { + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(500).set_body_string("error")) + .expect(1) // Should only be called once (POST not retryable) + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 3, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + let mut req = make_request(&localhost_url(&server, "/api")); + + req.method = Some("POST".to_string()); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 500); + assert_eq!(resp.retry_count, 0); + } + + #[tokio::test] + async fn test_retry_disabled_by_default() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(503).set_body_string("unavailable")) + .expect(1) // Should only be called once (no retry, default is disabled) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/api")); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 503); + assert_eq!(resp.retry_count, 0); + } + + // --- Full execute() pipeline tests --- + // + // These tests exercise the complete request pipeline through execute(), + // verifying behavior that was previously untestable due to the IP + // validation blocking wiremock requests. + + #[tokio::test] + async fn test_execute_happy_path_get() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/data")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string(r#"{"hello":"world"}"#) + .insert_header("Content-Type", "application/json"), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/data")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert_eq!(resp.status_text, "OK"); + assert_eq!(resp.body, r#"{"hello":"world"}"#); + assert_eq!(resp.body_encoding, "utf8"); + assert!(!resp.redirected); + assert_eq!(resp.retry_count, 0); + } + + #[tokio::test] + async fn test_execute_with_request_headers() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api")) + .and(wiremock::matchers::header("X-Custom", "test-value")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let mut req = make_request(&localhost_url(&server, "/api")); + + req.headers = Some(HashMap::from([( + "X-Custom".to_string(), + "test-value".to_string(), + )])); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + } + + #[tokio::test] + async fn test_execute_rejects_host_header() { + let state = build_localhost_test_state(10, 10_000_000); + let server = MockServer::start().await; + let mut req = make_request(&localhost_url(&server, "/api")); + + req.headers = Some(HashMap::from([( + "Host".to_string(), + "evil.com".to_string(), + )])); + + let err = state.execute(req).await.unwrap_err(); + + assert!(matches!(err, Error::ForbiddenHeader(_))); + } + + #[tokio::test] + async fn test_execute_rejects_host_header_case_insensitive() { + let state = build_localhost_test_state(10, 10_000_000); + let server = MockServer::start().await; + let mut req = make_request(&localhost_url(&server, "/api")); + + req.headers = Some(HashMap::from([( + "hOsT".to_string(), + "evil.com".to_string(), + )])); + + let err = state.execute(req).await.unwrap_err(); + + assert!(matches!(err, Error::ForbiddenHeader(_))); + } + + #[tokio::test] + async fn test_execute_default_headers_applied() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api")) + .and(wiremock::matchers::header("X-Default", "default-value")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + default_headers: HashMap::from([("X-Default".to_string(), "default-value".to_string())]), + allow_private_ip: true, + ..Default::default() + }); + let req = make_request(&localhost_url(&server, "/api")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + } + + #[tokio::test] + async fn test_execute_per_request_headers_supplement_defaults() { + let server = MockServer::start().await; + + // reqwest appends per-request headers rather than replacing defaults, + // so both default and per-request headers are sent. Verify both arrive. + Mock::given(method("GET")) + .and(path("/api")) + .and(wiremock::matchers::header("X-Default", "default-value")) + .and(wiremock::matchers::header("X-Request", "request-value")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + default_headers: HashMap::from([("X-Default".to_string(), "default-value".to_string())]), + allow_private_ip: true, + ..Default::default() + }); + let mut req = make_request(&localhost_url(&server, "/api")); + + req.headers = Some(HashMap::from([( + "X-Request".to_string(), + "request-value".to_string(), + )])); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + } + + #[tokio::test] + async fn test_execute_text_body_encoding() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/text")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string("plain text response") + .insert_header("Content-Type", "text/plain"), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/text")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.body, "plain text response"); + assert_eq!(resp.body_encoding, "utf8"); + } + + #[tokio::test] + async fn test_execute_binary_body_encoding() { + let server = MockServer::start().await; + let binary_data: Vec = vec![0x89, 0x50, 0x4E, 0x47]; // PNG header bytes + + Mock::given(method("GET")) + .and(path("/image")) + .respond_with( + ResponseTemplate::new(200) + .set_body_bytes(binary_data.clone()) + .insert_header("Content-Type", "image/png"), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/image")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.body_encoding, "base64"); + + // Decode and verify + let decoded = base64::engine::general_purpose::STANDARD + .decode(&resp.body) + .unwrap(); + + assert_eq!(decoded, binary_data); + } + + #[tokio::test] + async fn test_execute_post_with_body() { + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/submit")) + .and(wiremock::matchers::body_string(r#"{"key":"value"}"#)) + .respond_with(ResponseTemplate::new(201).set_body_string("created")) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let mut req = make_request(&localhost_url(&server, "/submit")); + + req.method = Some("POST".to_string()); + req.body = Some(r#"{"key":"value"}"#.to_string()); + req.body_encoding = Some("utf8".to_string()); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 201); + assert_eq!(resp.body, "created"); + } + + #[tokio::test] + async fn test_execute_response_headers_collected() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/headers")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string("ok") + .insert_header("X-Custom-Response", "header-value"), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/headers")); + let resp = state.execute(req).await.unwrap(); + + let custom_header = resp.headers.get("x-custom-response").unwrap(); + + assert_eq!(custom_header, &vec!["header-value".to_string()]); + } + + #[tokio::test] + async fn test_execute_domain_not_allowed_rejected() { + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request("https://evil.com/steal"); + let result = state.execute(req).await; + + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), Error::DomainNotAllowed(_))); + } + + #[tokio::test] + async fn test_execute_body_size_limit_through_pipeline() { + let server = MockServer::start().await; + let body = "x".repeat(200); + + Mock::given(method("GET")) + .and(path("/big")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string(&body) + .insert_header("Content-Type", "text/plain"), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 100); // 100 byte limit + let req = make_request(&localhost_url(&server, "/big")); + let result = state.execute(req).await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::ResponseTooLarge { .. } + )); + } + + #[tokio::test] + async fn test_execute_redirect_sets_redirected_flag() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/a")) + .respond_with( + ResponseTemplate::new(301).insert_header("Location", localhost_url(&server, "/b")), + ) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/b")) + .respond_with(ResponseTemplate::new(200).set_body_string("final")) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/a")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert!(resp.redirected); + assert!(resp.url.contains("/b")); + } + + #[tokio::test] + async fn test_execute_non_redirected_request_has_false_flag() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/direct")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/direct")); + let resp = state.execute(req).await.unwrap(); + + assert!(!resp.redirected); + } + + #[tokio::test] + async fn test_execute_empty_string_body() { + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/empty")) + .and(wiremock::matchers::body_string("")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let mut req = make_request(&localhost_url(&server, "/empty")); + + req.method = Some("POST".to_string()); + req.body = Some(String::new()); + req.body_encoding = Some("utf8".to_string()); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + } + + #[tokio::test] + async fn test_execute_from_utf8_lossy_on_binary_with_text_content_type() { + let server = MockServer::start().await; + + // Send binary data (invalid UTF-8) with text/plain content type + let binary_body: Vec = vec![0x48, 0x65, 0x6C, 0xFF, 0x6F]; // "Hel\xFFo" + + Mock::given(method("GET")) + .and(path("/lossy")) + .respond_with( + ResponseTemplate::new(200) + .set_body_bytes(binary_body) + .insert_header("Content-Type", "text/plain"), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/lossy")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.body_encoding, "utf8"); + // The invalid byte 0xFF should be replaced with U+FFFD + assert!( + resp.body.contains('\u{FFFD}'), + "expected replacement character in lossy UTF-8 conversion, got: {:?}", + resp.body + ); + } + + #[tokio::test] + async fn test_execute_timeout_is_retryable() { + let server = MockServer::start().await; + + // First request times out (delay > timeout), second succeeds + Mock::given(method("GET")) + .and(path("/slow")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string("slow") + .set_delay(Duration::from_secs(5)), + ) + .up_to_n_times(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/slow")) + .respond_with(ResponseTemplate::new(200).set_body_string("fast")) + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 1, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + + let mut req = make_request(&localhost_url(&server, "/slow")); + + req.timeout_ms = Some(100); // 100ms timeout + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert_eq!(resp.body, "fast"); + assert_eq!(resp.retry_count, 1); + } + + #[tokio::test] + async fn test_retry_with_custom_retryable_status_codes() { + let server = MockServer::start().await; + + // First returns 418, second returns 200 + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(418).set_body_string("teapot")) + .up_to_n_times(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 1, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + retryable_status_codes: vec![418], // Custom: retry on 418 + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + let req = make_request(&localhost_url(&server, "/api")); + + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert_eq!(resp.retry_count, 1); + } + + #[tokio::test] + async fn test_retry_revalidates_allowlist_between_attempts() { + let server = MockServer::start().await; + + // First call returns 500 (triggers retry), second would return 200 + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(500).set_body_string("error")) + .up_to_n_times(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(vec!["localhost".to_string()]).unwrap(), + )); + + // Add a runtime domain that we'll remove to test revalidation + allowlist + .write() + .add_patterns(vec!["localhost".to_string()]) + .unwrap(); + + let policy = build_redirect_policy_inner(Arc::clone(&allowlist), 10, true); + let client = reqwest::Client::builder().redirect(policy).build().unwrap(); + let config = HttpClientConfig { + retry: RetryConfig { + max_retries: 2, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + retryable_methods: None, + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }; + let state = HttpClientState::new(client, allowlist, config); + + // The request should succeed because "localhost" is an init_pattern + // (cannot be removed). This test verifies the revalidation path runs + // without error when the allowlist hasn't changed. + let req = make_request(&localhost_url(&server, "/api")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert_eq!(resp.retry_count, 1); + } + + // --- Redirect to IP address (integration) --- + + #[tokio::test] + async fn test_execute_redirect_to_ip_address_blocked() { + let server = MockServer::start().await; + + // Redirect to an IP address URL — should be blocked by validate_parsed_url + Mock::given(method("GET")) + .and(path("/redir")) + .respond_with( + ResponseTemplate::new(301).insert_header("Location", "http://127.0.0.1:9999/evil"), + ) + .mount(&server) + .await; + + let state = build_localhost_test_state(10, 10_000_000); + let req = make_request(&localhost_url(&server, "/redir")); + let result = state.execute(req).await; + + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), Error::RedirectBlocked(_))); + } + + // --- max_redirects stop behavior --- + + #[tokio::test] + async fn test_execute_max_redirects_returns_3xx_not_error() { + let server = MockServer::start().await; + + // Set up a redirect chain longer than max_redirects + Mock::given(method("GET")) + .and(path("/a")) + .respond_with( + ResponseTemplate::new(301).insert_header("Location", localhost_url(&server, "/b")), + ) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/b")) + .respond_with( + ResponseTemplate::new(302).insert_header("Location", localhost_url(&server, "/c")), + ) + .mount(&server) + .await; + + // /c would redirect again, but max_redirects=2 should stop at /b's response + Mock::given(method("GET")) + .and(path("/c")) + .respond_with( + ResponseTemplate::new(301).insert_header("Location", localhost_url(&server, "/d")), + ) + .mount(&server) + .await; + + // max_redirects=2: follows /a -> /b -> /c, stops at /c's 301 + let state = build_localhost_test_state(2, 10_000_000); + let req = make_request(&localhost_url(&server, "/a")); + let resp = state.execute(req).await.unwrap(); + + // Should get the 3xx response (stop behavior), not an error + assert!( + resp.status >= 300 && resp.status < 400, + "expected 3xx status from stop(), got {}", + resp.status + ); + assert!(resp.redirected); + } + + // --- Retry-After end-to-end --- + + #[tokio::test] + async fn test_retry_honors_retry_after_header_end_to_end() { + let server = MockServer::start().await; + + // First returns 429 with Retry-After, second returns 200 + Mock::given(method("GET")) + .and(path("/rate-limited")) + .respond_with( + ResponseTemplate::new(429) + .set_body_string("rate limited") + .insert_header("Retry-After", "1"), + ) + .up_to_n_times(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/rate-limited")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 1, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_secs(5), + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + let req = make_request(&localhost_url(&server, "/rate-limited")); + let resp = state.execute(req).await.unwrap(); + + assert_eq!(resp.status, 200); + assert_eq!(resp.retry_count, 1); + } + + // --- Per-request max_retries override --- + + #[tokio::test] + async fn test_per_request_max_retries_override_through_execute() { + let server = MockServer::start().await; + + // Returns 500 twice, then 200 + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(500).set_body_string("error")) + .up_to_n_times(2) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .mount(&server) + .await; + + // Config allows up to 5 retries, but request asks for only 1 + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 5, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + let mut req = make_request(&localhost_url(&server, "/api")); + + req.max_retries = Some(1); + + let resp = state.execute(req).await.unwrap(); + + // With max_retries=1, we get 2 attempts: first returns 500, second returns 500. + // Since retries are exhausted, we get the last 500 response. + assert_eq!(resp.status, 500); + assert_eq!(resp.retry_count, 1); + } + + // --- Security errors skip retry loop --- + + #[tokio::test] + async fn test_security_error_not_retried() { + // DomainNotAllowed should fail immediately, not be retried + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 3, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + + let req = make_request("https://evil.com/steal"); + let result = state.execute(req).await; + + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), Error::DomainNotAllowed(_))); + // If it were retried, this test would take noticeable time due to backoff. + // The near-instant completion proves the retry loop was bypassed. + } + + #[tokio::test] + async fn test_forbidden_header_error_not_retried() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api")) + .respond_with(ResponseTemplate::new(200).set_body_string("ok")) + .expect(0) // Should never reach the server + .mount(&server) + .await; + + let state = build_localhost_test_state_with_config(HttpClientConfig { + retry: RetryConfig { + max_retries: 3, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + ..RetryConfig::default() + }, + allow_private_ip: true, + ..Default::default() + }); + let mut req = make_request(&localhost_url(&server, "/api")); + + req.headers = Some(HashMap::from([( + "Host".to_string(), + "evil.com".to_string(), + )])); + + let result = state.execute(req).await; + + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), Error::ForbiddenHeader(_))); + } + + // --- validate_header_name tests --- + + #[test] + fn test_validate_header_name_allows_normal_headers() { + assert!(validate_header_name("authorization").is_ok()); + assert!(validate_header_name("content-type").is_ok()); + assert!(validate_header_name("accept").is_ok()); + assert!(validate_header_name("x-custom-header").is_ok()); + assert!(validate_header_name("user-agent").is_ok()); + assert!(validate_header_name("accept-encoding").is_ok()); + assert!(validate_header_name("cookie").is_ok()); + } + + #[test] + fn test_validate_header_name_blocks_host() { + let result = validate_header_name("host"); + + assert!(matches!(result, Err(Error::ForbiddenHeader(ref h)) if h == "host")); + } + + #[test] + fn test_validate_header_name_blocks_host_case_insensitive() { + assert!(validate_header_name("HOST").is_err()); + assert!(validate_header_name("Host").is_err()); + assert!(validate_header_name("hOsT").is_err()); + } + + #[test] + fn test_validate_header_name_blocks_connection() { + assert!(validate_header_name("connection").is_err()); + assert!(validate_header_name("Connection").is_err()); + } + + #[test] + fn test_validate_header_name_blocks_keep_alive() { + assert!(validate_header_name("keep-alive").is_err()); + assert!(validate_header_name("Keep-Alive").is_err()); + } + + #[test] + fn test_validate_header_name_blocks_transfer_encoding() { + assert!(validate_header_name("transfer-encoding").is_err()); + assert!(validate_header_name("Transfer-Encoding").is_err()); + } + + #[test] + fn test_validate_header_name_blocks_te() { + assert!(validate_header_name("te").is_err()); + assert!(validate_header_name("TE").is_err()); + } + + #[test] + fn test_validate_header_name_blocks_upgrade() { + assert!(validate_header_name("upgrade").is_err()); + assert!(validate_header_name("Upgrade").is_err()); + } + + #[test] + fn test_validate_header_name_blocks_trailer() { + assert!(validate_header_name("trailer").is_err()); + assert!(validate_header_name("Trailer").is_err()); + } + + #[test] + fn test_validate_header_name_allows_x_forwarded_headers() { + // X-Forwarded-* headers are application-layer, not transport-layer. + // Blocking them would break legitimate use cases (e.g., proxy context + // forwarding). The risk requires specific server misconfiguration. + assert!(validate_header_name("x-forwarded-for").is_ok()); + assert!(validate_header_name("X-Forwarded-For").is_ok()); + assert!(validate_header_name("x-forwarded-host").is_ok()); + assert!(validate_header_name("x-real-ip").is_ok()); + } + + #[test] + fn test_validate_header_name_blocks_sec_prefix() { + assert!(validate_header_name("sec-fetch-site").is_err()); + assert!(validate_header_name("sec-fetch-mode").is_err()); + assert!(validate_header_name("sec-ch-ua").is_err()); + assert!(validate_header_name("Sec-Fetch-Dest").is_err()); + } + + #[test] + fn test_validate_header_name_blocks_proxy_prefix() { + assert!(validate_header_name("proxy-authorization").is_err()); + assert!(validate_header_name("proxy-connection").is_err()); + assert!(validate_header_name("Proxy-Authenticate").is_err()); + } + + #[test] + fn test_validate_header_name_error_contains_lowercased_name() { + let result = validate_header_name("Transfer-Encoding"); + + match result { + Err(Error::ForbiddenHeader(name)) => assert_eq!(name, "transfer-encoding"), + _ => panic!("expected ForbiddenHeader error"), + } + } + + #[test] + fn test_validate_header_name_sec_prefix_not_blocked_without_dash() { + // "sec" alone is not in FORBIDDEN_HEADERS and doesn't start with "sec-" + assert!(validate_header_name("sec").is_ok()); + } +} diff --git a/src/commands.rs b/src/commands.rs new file mode 100644 index 0000000..ec0c406 --- /dev/null +++ b/src/commands.rs @@ -0,0 +1,66 @@ +use tauri::{AppHandle, Runtime, State}; + +use crate::client::{HttpClientState, InFlightRequests}; +use crate::error::Result; +use crate::types::{FetchRequest, FetchResponse}; + +/// Executes an HTTP request through the plugin's security and execution pipeline. +/// +/// This is the primary IPC command invoked by the TypeScript guest. +#[tauri::command] +pub(crate) async fn fetch( + _app: AppHandle, + state: State<'_, HttpClientState>, + in_flight: State<'_, InFlightRequests>, + request: FetchRequest, +) -> Result { + let request_id = request.request_id.clone(); + + if let Some(ref id) = request_id { + // Spawn as a trackable task for abort support. + // NOTE: There is a race between spawn and register — the task could + // complete before register() is called. The double-remove below + // mitigates this by ensuring cleanup even if the spawned task's + // internal remove races with register. + let state_ref = state.inner().clone(); + let id_clone = id.clone(); + let in_flight_clone = in_flight.inner().clone(); + + let handle = tokio::spawn(async move { + let result = state_ref.execute(request).await; + + // Clean up tracking regardless of outcome + in_flight_clone.remove(&id_clone).await; + + result + }); + + in_flight.register(id.clone(), handle.abort_handle()).await; + + let result = match handle.await { + Ok(result) => result, + Err(e) if e.is_cancelled() => Err(crate::error::Error::Aborted), + Err(e) => Err(crate::error::Error::Other(format!("task panicked: {e}"))), + }; + + // Ensure cleanup even if spawn's internal remove raced with register + in_flight.remove(id).await; + + result + } else { + state.execute(request).await + } +} + +/// Cancels an in-flight request by its request ID. +/// +/// Returns `true` if a matching request was found and aborted, +/// `false` if no request with the given ID was in flight. +#[tauri::command] +pub(crate) async fn abort_request( + _app: AppHandle, + in_flight: State<'_, InFlightRequests>, + request_id: String, +) -> Result { + Ok(HttpClientState::abort(&in_flight, &request_id).await) +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..aebf58b --- /dev/null +++ b/src/config.rs @@ -0,0 +1,305 @@ +use std::collections::HashMap; +use std::time::Duration; + +/// Configuration for the HTTP client plugin, set during plugin initialization. +/// +/// All fields have sensible defaults. The configuration is immutable after +/// plugin setup. +pub struct HttpClientConfig { + pub default_timeout: Option, + pub max_redirects: usize, + pub max_response_body_size: usize, + pub max_allowlist_size: usize, + pub user_agent: Option, + pub default_headers: HashMap, + pub retry: RetryConfig, + /// Disables the DNS rebinding check (private IP rejection after resolution). + /// + /// When `true`, requests and redirects to private IPs (127.0.0.1, etc.) + /// are allowed. Only intended for integration tests where the test server + /// resolves to localhost. + /// + /// **Default: `false`.** Do not enable in production — the DNS rebinding + /// check is a defense-in-depth layer against SSRF. + pub(crate) allow_private_ip: bool, +} + +impl Default for HttpClientConfig { + fn default() -> Self { + Self { + default_timeout: None, + max_redirects: 10, + max_response_body_size: 10 * 1024 * 1024, // 10MB + max_allowlist_size: 128, + user_agent: None, + default_headers: HashMap::new(), + retry: RetryConfig::disabled(), + allow_private_ip: false, + } + } +} + +/// Configuration for automatic request retries. +/// +/// Disabled by default (`max_retries: 0`). Enable via +/// [`Builder::retry`](crate::Builder::retry) or +/// [`Builder::max_retries`](crate::Builder::max_retries). +/// +/// When enabled, only transient errors (connection failures, timeouts) and +/// configurable status codes trigger retries. Security errors are never +/// retried. +/// +/// Timeout is per-attempt: a request with `max_retries: 3` and a 10s timeout +/// could take up to ~43s (4 attempts + backoff delays). +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// Maximum retry attempts (not counting the initial request). 0 = disabled. + pub max_retries: u32, + /// Base delay before the first retry. Default: 200ms. + /// Subsequent retries use exponential backoff: `initial_backoff * 2^(attempt-1)`. + pub initial_backoff: Duration, + /// Maximum backoff duration (caps exponential growth). Default: 10s. + pub max_backoff: Duration, + /// HTTP status codes that trigger a retry. Default: `[408, 429, 500, 502, 503, 504]`. + /// + /// - **408 Request Timeout**: Server closed an idle connection (RFC 9110 §15.5.9). + /// Explicitly transient — the server is inviting the client to retry. + /// - **429 Too Many Requests**: Rate limited (RFC 6585). Retried with + /// `Retry-After` header support when present. + /// - **500 Internal Server Error**: Often transient in practice (OOM, database + /// pool exhaustion, deployment blips). Safe to retry for idempotent methods; + /// the `retryable_methods` guard prevents duplicate side effects on mutations. + /// - **502 Bad Gateway**: Upstream sent an invalid response. Classic transient + /// infrastructure failure in load-balanced environments. + /// - **503 Service Unavailable**: Server explicitly overloaded or in maintenance + /// (RFC 9110 §15.6.4). The most unambiguously retriable status code. + /// - **504 Gateway Timeout**: Upstream didn't respond in time. Transient by nature. + /// + /// Notably excluded: **501** (server doesn't support the method — permanent), + /// **505** (HTTP version not supported — permanent), **511** (captive portal). + pub retryable_status_codes: Vec, + /// Maximum duration to wait when honoring a `Retry-After` header. + /// Values exceeding this cap are clamped. Default: 60s. + pub max_retry_after: Duration, + /// HTTP methods eligible for retry. Default: `["GET", "HEAD", "PUT", "DELETE", "OPTIONS"]` + /// — the idempotent methods defined by RFC 9110 §9.2.2. + /// + /// PUT and DELETE are idempotent: repeating them produces the same server + /// state as a single execution. POST and PATCH are excluded because they + /// are not idempotent — retrying them risks duplicate side effects (e.g., + /// creating duplicate resources or applying a patch twice). + /// + /// Set to `None` to retry all methods regardless of idempotency. Only do + /// this if you know all endpoints handle duplicate requests safely (e.g., + /// via idempotency keys). + pub retryable_methods: Option>, +} + +impl RetryConfig { + /// Returns a `RetryConfig` with retry disabled (`max_retries: 0`). + pub fn disabled() -> Self { + Self { + max_retries: 0, + ..Self::default() + } + } + + /// Returns `true` if the given status code is in the retryable set. + pub fn is_retryable_status(&self, status: u16) -> bool { + self.retryable_status_codes.contains(&status) + } + + /// Returns `true` if the given HTTP method is eligible for retry. + pub fn is_retryable_method(&self, method: &str) -> bool { + match &self.retryable_methods { + None => true, + Some(methods) => methods.iter().any(|m| m.eq_ignore_ascii_case(method)), + } + } +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + initial_backoff: Duration::from_millis(200), + max_backoff: Duration::from_secs(10), + retryable_status_codes: vec![408, 429, 500, 502, 503, 504], + max_retry_after: Duration::from_secs(60), + retryable_methods: Some(vec![ + "GET".to_string(), + "HEAD".to_string(), + "PUT".to_string(), + "DELETE".to_string(), + "OPTIONS".to_string(), + ]), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_default_values_correct() { + let config = HttpClientConfig::default(); + + assert!(config.default_timeout.is_none()); + assert_eq!(config.max_redirects, 10); + assert_eq!(config.max_allowlist_size, 128); + assert!(config.user_agent.is_none()); + assert!(config.default_headers.is_empty()); + } + + #[test] + fn test_config_default_max_response_body_size_is_10mb() { + let config = HttpClientConfig::default(); + + assert_eq!(config.max_response_body_size, 10 * 1024 * 1024); + } + + #[test] + fn test_config_default_retry_is_disabled() { + let config = HttpClientConfig::default(); + + assert_eq!(config.retry.max_retries, 0); + } + + #[test] + fn test_retry_config_default_values() { + let config = RetryConfig::default(); + + assert_eq!(config.max_retries, 3); + assert_eq!(config.initial_backoff, Duration::from_millis(200)); + assert_eq!(config.max_backoff, Duration::from_secs(10)); + assert_eq!( + config.retryable_status_codes, + vec![408, 429, 500, 502, 503, 504] + ); + assert_eq!(config.max_retry_after, Duration::from_secs(60)); + assert_eq!( + config.retryable_methods, + Some(vec![ + "GET".to_string(), + "HEAD".to_string(), + "PUT".to_string(), + "DELETE".to_string(), + "OPTIONS".to_string() + ]) + ); + } + + #[test] + fn test_retry_config_disabled() { + let config = RetryConfig::disabled(); + + assert_eq!(config.max_retries, 0); + // Other fields inherit defaults + assert_eq!(config.initial_backoff, Duration::from_millis(200)); + } + + #[test] + fn test_is_retryable_status() { + let config = RetryConfig::default(); + + assert!(config.is_retryable_status(408)); + assert!(config.is_retryable_status(429)); + assert!(config.is_retryable_status(500)); + assert!(config.is_retryable_status(502)); + assert!(config.is_retryable_status(503)); + assert!(config.is_retryable_status(504)); + assert!(!config.is_retryable_status(200)); + assert!(!config.is_retryable_status(400)); + assert!(!config.is_retryable_status(401)); + assert!(!config.is_retryable_status(403)); + assert!(!config.is_retryable_status(404)); + assert!(!config.is_retryable_status(501)); + } + + /// 408 Request Timeout is retried because it represents a server-side idle + /// connection timeout (RFC 9110 §15.5.9) — the server is explicitly inviting + /// the client to resend the request. + #[test] + fn test_is_retryable_status_408_request_timeout() { + let config = RetryConfig::default(); + + assert!(config.is_retryable_status(408)); + } + + #[test] + fn test_is_retryable_method_default() { + let config = RetryConfig::default(); + + assert!(config.is_retryable_method("GET")); + assert!(config.is_retryable_method("HEAD")); + assert!(config.is_retryable_method("PUT")); + assert!(config.is_retryable_method("DELETE")); + assert!(config.is_retryable_method("OPTIONS")); + assert!(config.is_retryable_method("get")); // case-insensitive + assert!(config.is_retryable_method("put")); // case-insensitive + assert!(config.is_retryable_method("delete")); // case-insensitive + assert!(!config.is_retryable_method("POST")); + assert!(!config.is_retryable_method("PATCH")); + } + + /// POST and PATCH are not idempotent (RFC 9110 §9.2.2) — retrying them + /// risks duplicate side effects (e.g., creating duplicate resources or + /// applying a patch twice). They are excluded from retries by default. + #[test] + fn test_non_idempotent_methods_not_retried_by_default() { + let config = RetryConfig::default(); + + assert!(!config.is_retryable_method("POST")); + assert!(!config.is_retryable_method("PATCH")); + } + + #[test] + fn test_is_retryable_method_none_allows_all() { + let config = RetryConfig { + retryable_methods: None, + ..RetryConfig::default() + }; + + assert!(config.is_retryable_method("GET")); + assert!(config.is_retryable_method("POST")); + assert!(config.is_retryable_method("PUT")); + assert!(config.is_retryable_method("DELETE")); + assert!(config.is_retryable_method("PATCH")); + } + + #[test] + fn test_is_retryable_status_empty_list() { + let config = RetryConfig { + retryable_status_codes: vec![], + ..RetryConfig::default() + }; + + assert!(!config.is_retryable_status(429)); + assert!(!config.is_retryable_status(503)); + } + + #[test] + fn test_is_retryable_method_empty_list() { + let config = RetryConfig { + retryable_methods: Some(vec![]), + ..RetryConfig::default() + }; + + assert!(!config.is_retryable_method("GET")); + assert!(!config.is_retryable_method("POST")); + } + + #[test] + fn test_custom_retryable_status_codes() { + let config = RetryConfig { + retryable_status_codes: vec![418, 503], + ..RetryConfig::default() + }; + + assert!(config.is_retryable_status(418)); + assert!(config.is_retryable_status(503)); + assert!(!config.is_retryable_status(500)); // Not in custom list + assert!(!config.is_retryable_status(429)); // Not in custom list + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..b939760 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,304 @@ +use std::error::Error as StdError; + +use serde::{Serialize, ser::Serializer}; + +pub type Result = std::result::Result; + +/// Structured error response sent to the TypeScript guest via IPC. +/// +/// Follows the `{code, message}` pattern used by `tauri-plugin-sqlite`. +#[derive(Serialize)] +struct ErrorResponse { + code: String, + message: String, +} + +/// All error types that can occur during HTTP client operations. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("domain not allowed: {0}")] + DomainNotAllowed(String), + + #[error("url scheme not allowed: {0}")] + SchemeNotAllowed(String), + + #[error("ip addresses not allowed in urls")] + IpAddressNotAllowed, + + #[error("invalid url: {0}")] + InvalidUrl(String), + + #[error("request error: {0}")] + Request(reqwest::Error), + + #[error("request aborted")] + Aborted, + + #[error("response too large: {size} bytes exceeds limit of {limit} bytes")] + ResponseTooLarge { size: usize, limit: usize }, + + #[error("redirect to disallowed domain: {0}")] + RedirectBlocked(String), + + #[error("url must not contain userinfo")] + UserinfoNotAllowed, + + #[error("allowlist size exceeded: {count} patterns exceeds limit of {limit}")] + AllowlistSizeExceeded { count: usize, limit: usize }, + + #[error("wildcard patterns not allowed at runtime: {0}")] + WildcardNotAllowedAtRuntime(String), + + #[error("invalid domain pattern: {0}")] + InvalidDomainPattern(String), + + #[error("forbidden header: {0}")] + ForbiddenHeader(String), + + #[error("{0}")] + Other(String), +} + +impl From for Error { + fn from(e: reqwest::Error) -> Self { + if e.is_redirect() { + // Walk the error source chain to find our RedirectBlockedError + let mut source = StdError::source(&e); + + while let Some(err) = source { + if let Some(blocked) = err.downcast_ref::() { + return Error::RedirectBlocked(blocked.0.clone()); + } + source = err.source(); + } + } + + Error::Request(e) + } +} + +impl Error { + /// Returns `true` if this error represents a transient failure that may + /// succeed on retry (connection errors and timeouts). + /// + /// Security errors (`DomainNotAllowed`, `IpAddressNotAllowed`, etc.) are + /// never retryable — they indicate policy violations that will not change + /// between attempts. + pub(crate) fn is_retryable(&self) -> bool { + matches!(self, Error::Request(e) if e.is_timeout() || e.is_connect()) + } + + fn code(&self) -> &str { + match self { + Error::DomainNotAllowed(_) => "DOMAIN_NOT_ALLOWED", + Error::SchemeNotAllowed(_) => "SCHEME_NOT_ALLOWED", + Error::IpAddressNotAllowed => "IP_ADDRESS_NOT_ALLOWED", + Error::InvalidUrl(_) => "INVALID_URL", + Error::Request(e) => { + if e.is_timeout() { + "TIMEOUT" + } else if e.is_connect() { + "CONNECTION_ERROR" + } else { + "REQUEST_ERROR" + } + } + Error::Aborted => "ABORTED", + Error::ResponseTooLarge { .. } => "RESPONSE_TOO_LARGE", + Error::RedirectBlocked(_) => "REDIRECT_BLOCKED", + // Surfaced as INVALID_URL to avoid leaking internal validation details + Error::UserinfoNotAllowed => "INVALID_URL", + Error::AllowlistSizeExceeded { .. } => "ALLOWLIST_SIZE_EXCEEDED", + Error::WildcardNotAllowedAtRuntime(_) => "WILDCARD_NOT_ALLOWED_AT_RUNTIME", + Error::InvalidDomainPattern(_) => "INVALID_DOMAIN_PATTERN", + Error::ForbiddenHeader(_) => "FORBIDDEN_HEADER", + Error::Other(_) => "ERROR", + } + } +} + +impl Serialize for Error { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + let resp = ErrorResponse { + code: self.code().to_string(), + message: self.to_string(), + }; + + resp.serialize(serializer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_serialization_domain_not_allowed() { + let error = Error::DomainNotAllowed("evil.com".to_string()); + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "DOMAIN_NOT_ALLOWED"); + assert_eq!(json["message"], "domain not allowed: evil.com"); + } + + #[test] + fn test_error_serialization_scheme_not_allowed() { + let error = Error::SchemeNotAllowed("ftp".to_string()); + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "SCHEME_NOT_ALLOWED"); + assert_eq!(json["message"], "url scheme not allowed: ftp"); + } + + #[test] + fn test_error_serialization_ip_address_not_allowed() { + let error = Error::IpAddressNotAllowed; + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "IP_ADDRESS_NOT_ALLOWED"); + } + + #[test] + fn test_error_serialization_invalid_url() { + let error = Error::InvalidUrl("not a url".to_string()); + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "INVALID_URL"); + } + + #[test] + fn test_error_serialization_aborted() { + let error = Error::Aborted; + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "ABORTED"); + assert_eq!(json["message"], "request aborted"); + } + + #[test] + fn test_error_serialization_response_too_large() { + let error = Error::ResponseTooLarge { + size: 20_000_000, + limit: 10_000_000, + }; + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "RESPONSE_TOO_LARGE"); + assert!( + json["message"] + .as_str() + .unwrap() + .contains("20000000 bytes exceeds limit of 10000000 bytes") + ); + } + + #[test] + fn test_error_serialization_redirect_blocked() { + let error = Error::RedirectBlocked("evil.com".to_string()); + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "REDIRECT_BLOCKED"); + } + + #[test] + fn test_error_serialization_userinfo_not_allowed() { + let error = Error::UserinfoNotAllowed; + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "INVALID_URL"); + } + + #[test] + fn test_error_serialization_allowlist_size_exceeded() { + let error = Error::AllowlistSizeExceeded { + count: 150, + limit: 128, + }; + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "ALLOWLIST_SIZE_EXCEEDED"); + assert!( + json["message"] + .as_str() + .unwrap() + .contains("150 patterns exceeds limit of 128") + ); + } + + #[test] + fn test_error_serialization_wildcard_not_allowed_at_runtime() { + let error = Error::WildcardNotAllowedAtRuntime("*.evil.com".to_string()); + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "WILDCARD_NOT_ALLOWED_AT_RUNTIME"); + assert!(json["message"].as_str().unwrap().contains("*.evil.com")); + } + + #[test] + fn test_error_serialization_invalid_domain_pattern() { + let error = Error::InvalidDomainPattern("pattern contains invalid characters".to_string()); + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "INVALID_DOMAIN_PATTERN"); + assert!( + json["message"] + .as_str() + .unwrap() + .contains("pattern contains invalid characters") + ); + } + + #[test] + fn test_error_serialization_forbidden_header() { + let error = Error::ForbiddenHeader("host".to_string()); + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "FORBIDDEN_HEADER"); + assert!(json["message"].as_str().unwrap().contains("host")); + } + + #[test] + fn test_error_serialization_other() { + let error = Error::Other("something went wrong".to_string()); + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], "ERROR"); + assert_eq!(json["message"], "something went wrong"); + } + + #[test] + fn test_is_retryable_security_errors_never_retryable() { + assert!(!Error::DomainNotAllowed("evil.com".to_string()).is_retryable()); + assert!(!Error::SchemeNotAllowed("ftp".to_string()).is_retryable()); + assert!(!Error::IpAddressNotAllowed.is_retryable()); + assert!(!Error::InvalidUrl("bad".to_string()).is_retryable()); + assert!(!Error::Aborted.is_retryable()); + assert!(!Error::RedirectBlocked("evil.com".to_string()).is_retryable()); + assert!(!Error::UserinfoNotAllowed.is_retryable()); + assert!( + !Error::AllowlistSizeExceeded { + count: 200, + limit: 128 + } + .is_retryable() + ); + assert!(!Error::WildcardNotAllowedAtRuntime("*.evil.com".to_string()).is_retryable()); + assert!(!Error::InvalidDomainPattern("bad".to_string()).is_retryable()); + assert!(!Error::ForbiddenHeader("host".to_string()).is_retryable()); + assert!(!Error::Other("fail".to_string()).is_retryable()); + } + + #[test] + fn test_is_retryable_response_too_large_not_retryable() { + assert!( + !Error::ResponseTooLarge { + size: 20_000_000, + limit: 10_000_000 + } + .is_retryable() + ); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..947f96a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,461 @@ +//! Tauri plugin providing HTTP request capabilities with a domain allowlist +//! for security. +//! +//! # Overview +//! +//! This plugin exposes a `request(url, options)` API to the Tauri webview, +//! backed by Rust's `reqwest` crate. Every request is validated against a +//! domain allowlist that can be configured at plugin initialization and +//! modified at runtime from Rust via [`HttpClientExt`], preventing +//! unauthorized network access from the frontend. +//! +//! # Usage +//! +//! ```no_run +//! use std::time::Duration; +//! +//! tauri::Builder::default() +//! .plugin( +//! tauri_plugin_http_client::Builder::new() +//! .allowed_domains([ +//! "api.example.com", +//! "*.cdn.example.com", +//! ]) +//! .default_timeout(Duration::from_secs(30)) +//! .build() +//! ); +//! ``` + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use parking_lot::RwLock; +use tauri::{ + Manager, Runtime, + plugin::{self, TauriPlugin}, +}; + +pub mod allowlist; +pub mod client; +mod commands; +pub mod config; +pub mod error; +pub mod types; + +use allowlist::DomainAllowlist; +use client::{HttpClientState, InFlightRequests, build_redirect_policy}; +use config::{HttpClientConfig, RetryConfig}; + +/// Plugin builder for configuring the HTTP client before initialization. +/// +/// # Examples +/// +/// ```no_run +/// use std::time::Duration; +/// +/// let plugin = tauri_plugin_http_client::Builder::new() +/// .allowed_domains(["api.example.com"]) +/// .default_timeout(Duration::from_secs(30)) +/// .max_redirects(5) +/// .build(); +/// ``` +pub struct Builder { + allowed_domains: Vec, + default_timeout: Option, + max_redirects: Option, + max_response_body_size: Option, + max_allowlist_size: Option, + user_agent: Option, + default_headers: Option>, + retry: Option, +} + +impl Builder { + /// Creates a new builder with no allowed domains (blocks all requests). + pub fn new() -> Self { + Self { + allowed_domains: Vec::new(), + default_timeout: None, + max_redirects: None, + max_response_body_size: None, + max_allowlist_size: None, + user_agent: None, + default_headers: None, + retry: None, + } + } + + /// Sets the list of allowed domain patterns. + /// + /// Supported formats: + /// - `"api.example.com"` - exact match + /// - `"*.example.com"` - any subdomain of `example.com` + /// + /// An empty list blocks all requests (secure by default). + pub fn allowed_domains(mut self, domains: impl IntoIterator>) -> Self { + self.allowed_domains = domains.into_iter().map(Into::into).collect(); + self + } + + /// Sets the default request timeout. Can be overridden per-request. + pub fn default_timeout(mut self, timeout: Duration) -> Self { + self.default_timeout = Some(timeout); + self + } + + /// Sets the maximum number of redirects to follow (default: 10). + pub fn max_redirects(mut self, max: usize) -> Self { + self.max_redirects = Some(max); + self + } + + /// Sets the maximum response body size in bytes (default: 10MB). + pub fn max_response_body_size(mut self, max: usize) -> Self { + self.max_response_body_size = Some(max); + self + } + + /// Sets the maximum total number of patterns in the allowlist (default: 128). + /// + /// This cap applies to all patterns, including those set at init time and + /// those added at runtime. It prevents runaway additions from buggy code. + pub fn max_allowlist_size(mut self, max: usize) -> Self { + self.max_allowlist_size = Some(max); + self + } + + /// Sets a custom User-Agent header for all requests. + pub fn user_agent(mut self, ua: impl Into) -> Self { + self.user_agent = Some(ua.into()); + self + } + + /// Sets default headers applied to all requests. + /// + /// Per-request headers override these defaults. + pub fn default_headers( + mut self, + headers: impl IntoIterator, impl Into)>, + ) -> Self { + self.default_headers = Some( + headers + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(), + ); + self + } + + /// Enables automatic retry with the provided configuration. + /// + /// By default, retry is disabled. Pass [`RetryConfig::default()`] for + /// sensible defaults (3 retries, 200ms initial backoff, exponential + /// backoff with jitter, only idempotent methods). + /// + /// # Examples + /// + /// ```no_run + /// use tauri_plugin_http_client::config::RetryConfig; + /// + /// let plugin = tauri_plugin_http_client::Builder::new() + /// .allowed_domains(["api.example.com"]) + /// .retry(RetryConfig::default()) + /// .build(); + /// ``` + pub fn retry(mut self, config: RetryConfig) -> Self { + self.retry = Some(config); + self + } + + /// Convenience method to enable retry with a specific max retry count + /// and default settings for all other retry parameters. + /// + /// Equivalent to `retry(RetryConfig { max_retries: n, ..Default::default() })`. + pub fn max_retries(mut self, n: u32) -> Self { + self.retry = Some(RetryConfig { + max_retries: n, + ..RetryConfig::default() + }); + self + } + + /// Builds the Tauri plugin with the configured settings. + pub fn build(self) -> TauriPlugin { + let allowed_domains = self.allowed_domains; + let default_timeout = self.default_timeout; + let max_redirects = self.max_redirects.unwrap_or(10); + let max_response_body_size = self.max_response_body_size.unwrap_or(10 * 1024 * 1024); + let max_allowlist_size = self.max_allowlist_size.unwrap_or(128); + let user_agent = self.user_agent; + let default_headers = self.default_headers.unwrap_or_default(); + let retry = self.retry.unwrap_or_else(RetryConfig::disabled); + + plugin::Builder::new("http-client") + .invoke_handler(tauri::generate_handler![ + commands::fetch, + commands::abort_request, + ]) + .setup(move |app, _api| { + // Fail fast on forbidden default headers — these are developer + // configuration errors that should surface at plugin init, not + // silently affect every request. + for key in default_headers.keys() { + client::validate_header_name(key).map_err(|e| e.to_string())?; + } + + // Create shared allowlist: same Arc used by both HttpClientState + // and the redirect policy closure. + let allowlist = Arc::new(RwLock::new( + DomainAllowlist::new(allowed_domains).map_err(|e| e.to_string())?, + )); + + let redirect_policy = build_redirect_policy(Arc::clone(&allowlist), max_redirects); + + let mut client_builder = reqwest::Client::builder().redirect(redirect_policy); + + if let Some(ref ua) = user_agent { + client_builder = client_builder.user_agent(ua.clone()); + } + + let client = client_builder.build().map_err(|e| e.to_string())?; + + let config = HttpClientConfig { + default_timeout, + max_redirects, + max_response_body_size, + max_allowlist_size, + user_agent, + default_headers, + retry, + allow_private_ip: false, + }; + + let state = HttpClientState::new(client, allowlist, config); + + app.manage(state); + app.manage(InFlightRequests::new()); + + Ok(()) + }) + .build() + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +/// Convenience function that creates a plugin with default settings. +/// +/// The default configuration has an empty allowlist, which blocks all requests. +/// Use [`Builder`] for custom configuration. +pub fn init() -> TauriPlugin { + Builder::new().build() +} + +/// Extension trait providing access to the HTTP client state from any Tauri manager. +/// +/// # Examples +/// +/// ```no_run +/// use tauri::Manager; +/// use tauri_plugin_http_client::HttpClientExt; +/// +/// // In a Tauri command or setup hook: +/// // let state = app.http_client(); +/// // app.add_allowed_domains(["new-api.example.com"]).unwrap(); +/// ``` +pub trait HttpClientExt { + /// Returns a reference to the HTTP client state. + fn http_client(&self) -> &HttpClientState; + + /// Adds a single domain pattern to the allowlist at runtime. + /// + /// See [`HttpClientState::add_allowed_domain`] for details. + fn add_allowed_domain(&self, domain: impl Into) -> error::Result<()>; + + /// Adds multiple domain patterns to the allowlist at runtime. + /// + /// See [`HttpClientState::add_allowed_domains`] for details. + fn add_allowed_domains( + &self, + domains: impl IntoIterator>, + ) -> error::Result<()>; + + /// Removes a single domain from the runtime allowlist. + /// + /// See [`HttpClientState::remove_allowed_domain`] for details. + fn remove_allowed_domain(&self, domain: impl Into) -> error::Result; + + /// Removes multiple domains from the runtime allowlist. + /// + /// See [`HttpClientState::remove_allowed_domains`] for details. + fn remove_allowed_domains( + &self, + domains: impl IntoIterator>, + ) -> error::Result; + + /// Removes all runtime-added domains, preserving config-time domains. + /// + /// See [`HttpClientState::remove_all_runtime_domains`] for details. + fn remove_all_runtime_domains(&self) -> usize; +} + +impl> HttpClientExt for T { + fn http_client(&self) -> &HttpClientState { + self.state::().inner() + } + + fn add_allowed_domain(&self, domain: impl Into) -> error::Result<()> { + self.http_client().add_allowed_domain(domain) + } + + fn add_allowed_domains( + &self, + domains: impl IntoIterator>, + ) -> error::Result<()> { + self.http_client().add_allowed_domains(domains) + } + + fn remove_allowed_domain(&self, domain: impl Into) -> error::Result { + self.http_client().remove_allowed_domain(domain) + } + + fn remove_allowed_domains( + &self, + domains: impl IntoIterator>, + ) -> error::Result { + self.http_client().remove_allowed_domains(domains) + } + + fn remove_all_runtime_domains(&self) -> usize { + self.http_client().remove_all_runtime_domains() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_builder_default_has_empty_allowlist() { + let builder = Builder::new(); + + assert!(builder.allowed_domains.is_empty()); + assert!(builder.default_timeout.is_none()); + assert!(builder.max_redirects.is_none()); + assert!(builder.max_response_body_size.is_none()); + assert!(builder.max_allowlist_size.is_none()); + assert!(builder.user_agent.is_none()); + assert!(builder.default_headers.is_none()); + assert!(builder.retry.is_none()); + } + + #[test] + fn test_builder_default_impl_matches_new() { + let from_new = Builder::new(); + let from_default = Builder::default(); + + assert_eq!(from_new.allowed_domains, from_default.allowed_domains); + assert_eq!(from_new.default_timeout, from_default.default_timeout); + assert_eq!(from_new.max_redirects, from_default.max_redirects); + assert_eq!( + from_new.max_response_body_size, + from_default.max_response_body_size + ); + assert_eq!(from_new.max_allowlist_size, from_default.max_allowlist_size); + assert_eq!(from_new.user_agent, from_default.user_agent); + assert_eq!(from_new.default_headers, from_default.default_headers); + assert!(from_new.retry.is_none()); + assert!(from_default.retry.is_none()); + } + + #[test] + fn test_builder_setters() { + let builder = Builder::new() + .allowed_domains(["example.com"]) + .default_timeout(Duration::from_secs(30)) + .max_redirects(5) + .max_response_body_size(1024) + .max_allowlist_size(64) + .user_agent("test-agent") + .default_headers([("x-key", "val")]); + + assert_eq!(builder.allowed_domains, vec!["example.com"]); + assert_eq!(builder.default_timeout, Some(Duration::from_secs(30))); + assert_eq!(builder.max_redirects, Some(5)); + assert_eq!(builder.max_response_body_size, Some(1024)); + assert_eq!(builder.max_allowlist_size, Some(64)); + assert_eq!(builder.user_agent, Some("test-agent".to_string())); + assert_eq!( + builder.default_headers, + Some(HashMap::from([("x-key".to_string(), "val".to_string())])) + ); + } + + #[test] + fn test_builder_allowed_domains_accepts_vec_of_strings() { + let domains: Vec = vec!["a.example.com".to_string(), "b.example.com".to_string()]; + let builder = Builder::new().allowed_domains(domains); + + assert_eq!( + builder.allowed_domains, + vec!["a.example.com", "b.example.com"] + ); + } + + #[test] + fn test_builder_allowed_domains_accepts_empty_array() { + let builder = Builder::new().allowed_domains(std::iter::empty::()); + + assert!(builder.allowed_domains.is_empty()); + } + + #[test] + fn test_builder_allowed_domains_accepts_filtered_iterator() { + let all = vec!["keep.example.com", "skip.example.com", "keep2.example.com"]; + let builder = + Builder::new().allowed_domains(all.into_iter().filter(|d| d.starts_with("keep"))); + + assert_eq!( + builder.allowed_domains, + vec!["keep.example.com", "keep2.example.com"] + ); + } + + #[test] + fn test_builder_retry_setter() { + let builder = Builder::new().retry(RetryConfig::default()); + + assert!(builder.retry.is_some()); + assert_eq!(builder.retry.as_ref().unwrap().max_retries, 3); + } + + #[test] + fn test_builder_default_headers_accepts_hashmap() { + let headers: HashMap = + HashMap::from([("x-key".to_string(), "val".to_string())]); + let builder = Builder::new().default_headers(headers); + + assert_eq!( + builder.default_headers, + Some(HashMap::from([("x-key".to_string(), "val".to_string())])) + ); + } + + #[test] + fn test_builder_max_retries_convenience() { + let builder = Builder::new().max_retries(5); + + assert!(builder.retry.is_some()); + assert_eq!(builder.retry.as_ref().unwrap().max_retries, 5); + // Other fields should be defaults + assert_eq!( + builder.retry.as_ref().unwrap().initial_backoff, + Duration::from_millis(200) + ); + } +} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..5f7dea0 --- /dev/null +++ b/src/types.rs @@ -0,0 +1,154 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +/// Request payload sent from the TypeScript guest to the Rust backend via IPC. +/// +/// All URL parsing and validation happens exclusively in Rust to avoid +/// JS/Rust URL parsing differentials. +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FetchRequest { + pub url: String, + pub method: Option, + pub headers: Option>, + pub body: Option, + pub body_encoding: Option, + pub timeout_ms: Option, + pub request_id: Option, + /// Per-request retry override. `None` uses plugin config default. + /// `Some(0)` disables retry for this request. Capped at the plugin-level + /// `RetryConfig::max_retries` — the frontend cannot exceed the configured ceiling. + pub max_retries: Option, +} + +/// Response payload sent from the Rust backend to the TypeScript guest via IPC. +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct FetchResponse { + pub status: u16, + pub status_text: String, + pub headers: HashMap>, + pub body: String, + pub body_encoding: String, + pub url: String, + pub redirected: bool, + /// Number of retry attempts that occurred before this response (0 = no retries). + pub retry_count: u32, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fetch_request_deserializes_camel_case() { + let json = serde_json::json!({ + "url": "https://example.com", + "method": "POST", + "headers": {"content-type": "application/json"}, + "body": "hello", + "bodyEncoding": "utf8", + "timeoutMs": 5000, + "requestId": "req-1" + }); + + let req: FetchRequest = serde_json::from_value(json).unwrap(); + + assert_eq!(req.url, "https://example.com"); + assert_eq!(req.method.as_deref(), Some("POST")); + assert_eq!( + req.headers.as_ref().unwrap().get("content-type").unwrap(), + "application/json" + ); + assert_eq!(req.body.as_deref(), Some("hello")); + assert_eq!(req.body_encoding.as_deref(), Some("utf8")); + assert_eq!(req.timeout_ms, Some(5000)); + assert_eq!(req.request_id.as_deref(), Some("req-1")); + assert!(req.max_retries.is_none()); + } + + #[test] + fn test_fetch_request_minimal() { + let json = serde_json::json!({"url": "https://example.com"}); + let req: FetchRequest = serde_json::from_value(json).unwrap(); + + assert_eq!(req.url, "https://example.com"); + assert!(req.method.is_none()); + assert!(req.headers.is_none()); + assert!(req.body.is_none()); + assert!(req.body_encoding.is_none()); + assert!(req.timeout_ms.is_none()); + assert!(req.request_id.is_none()); + assert!(req.max_retries.is_none()); + } + + #[test] + fn test_fetch_response_serializes_camel_case() { + let resp = FetchResponse { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::from([("content-type".to_string(), vec!["text/html".to_string()])]), + body: "hello".to_string(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + let json = serde_json::to_value(&resp).unwrap(); + + assert_eq!(json["status"], 200); + assert_eq!(json["statusText"], "OK"); + assert_eq!(json["body"], "hello"); + assert_eq!(json["bodyEncoding"], "utf8"); + assert_eq!(json["url"], "https://example.com"); + assert_eq!(json["redirected"], false); + assert_eq!(json["retryCount"], 0); + assert!(json["headers"]["content-type"].is_array()); + } + + #[test] + fn test_fetch_request_with_max_retries() { + let json = serde_json::json!({ + "url": "https://example.com", + "maxRetries": 5 + }); + + let req: FetchRequest = serde_json::from_value(json).unwrap(); + + assert_eq!(req.max_retries, Some(5)); + } + + #[test] + fn test_fetch_request_missing_url_fails_deserialization() { + let json = serde_json::json!({"method": "GET"}); + let result = serde_json::from_value::(json); + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + + assert!( + err_msg.contains("url"), + "error should mention missing 'url' field: {err_msg}" + ); + } + + #[test] + fn test_fetch_response_retry_count_serializes() { + let resp = FetchResponse { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::new(), + body: "ok".to_string(), + body_encoding: "utf8".to_string(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 2, + }; + + let json = serde_json::to_value(&resp).unwrap(); + + assert_eq!(json["retryCount"], 2); + } +} diff --git a/tsconfig.src.json b/tsconfig.src.json index 5ba2890..a0ae008 100644 --- a/tsconfig.src.json +++ b/tsconfig.src.json @@ -1,7 +1,10 @@ { "extends": "@silvermine/typescript-config/tsconfig.esm.json", - "include": [], + "include": [ + "./guest-js/*" + ], "exclude": [ + "./guest-js/*.test.ts", "./tests/*", "**/node_modules/*" ], From 6dbead80f571132665517d8500c588725fed67b5 Mon Sep 17 00:00:00 2001 From: Jordan Hafer <42755763+jjhafer@users.noreply.github.com> Date: Thu, 12 Mar 2026 14:57:31 -0400 Subject: [PATCH 2/3] sub(feat): add sample tauri app --- examples/tauri-app/.gitignore | 6 + examples/tauri-app/README.md | 65 +++ examples/tauri-app/package.json | 18 + examples/tauri-app/src-tauri/Cargo.toml | 16 + examples/tauri-app/src-tauri/build.rs | 3 + .../src-tauri/capabilities/default.json | 9 + examples/tauri-app/src-tauri/src/lib.rs | 15 + examples/tauri-app/src-tauri/src/main.rs | 5 + examples/tauri-app/src-tauri/tauri.conf.json | 22 + examples/tauri-app/src/index.html | 95 +++++ examples/tauri-app/src/main.js | 294 ++++++++++++++ examples/tauri-app/src/styles.css | 378 ++++++++++++++++++ examples/tauri-app/vite.config.js | 9 + 13 files changed, 935 insertions(+) create mode 100644 examples/tauri-app/.gitignore create mode 100644 examples/tauri-app/README.md create mode 100644 examples/tauri-app/package.json create mode 100644 examples/tauri-app/src-tauri/Cargo.toml create mode 100644 examples/tauri-app/src-tauri/build.rs create mode 100644 examples/tauri-app/src-tauri/capabilities/default.json create mode 100644 examples/tauri-app/src-tauri/src/lib.rs create mode 100644 examples/tauri-app/src-tauri/src/main.rs create mode 100644 examples/tauri-app/src-tauri/tauri.conf.json create mode 100644 examples/tauri-app/src/index.html create mode 100644 examples/tauri-app/src/main.js create mode 100644 examples/tauri-app/src/styles.css create mode 100644 examples/tauri-app/vite.config.js diff --git a/examples/tauri-app/.gitignore b/examples/tauri-app/.gitignore new file mode 100644 index 0000000..15cc938 --- /dev/null +++ b/examples/tauri-app/.gitignore @@ -0,0 +1,6 @@ +node_modules/ +dist/ +src-tauri/target/ +src-tauri/gen/ +*.log +.DS_Store diff --git a/examples/tauri-app/README.md b/examples/tauri-app/README.md new file mode 100644 index 0000000..1c407f5 --- /dev/null +++ b/examples/tauri-app/README.md @@ -0,0 +1,65 @@ +# HTTP Client Plugin Example + +A minimal Tauri v2 application demonstrating the `@silvermine/tauri-plugin-http-client` plugin. + +## Prerequisites + +- [Node.js](https://nodejs.org/) (v18+) +- [Rust](https://rustup.rs/) (1.89+, edition 2024) +- [Tauri prerequisites](https://v2.tauri.app/start/prerequisites/) for your platform + +## Setup + +From the plugin root, build the TypeScript package first: + +```sh +npm install +npm run build +``` + +Then set up the example app: + +```sh +cd examples/tauri-app +npm install +``` + +## Running + +### Desktop (macOS, Linux, Windows) + +```sh +npm run tauri dev +``` + +### Android + +```sh +npx tauri android init +npx tauri android dev +``` + +### iOS + +```sh +npx tauri ios init +npx tauri ios dev +``` + +## What It Demonstrates + +The app has six demo panels: + +1. **GET Request** -- Fetch JSON from httpbin.org and display the response +2. **POST Request** -- Send a JSON body with explicit `Content-Type` header +3. **Custom Headers** -- Use the `HttpHeaders` API to set, get, and iterate headers +4. **Abort Request** -- Cancel an in-flight request using `AbortController` +5. **Error Handling** -- Trigger and display `HttpClientError` codes (blocked domain, timeout, invalid URL) +6. **Binary Response** -- Fetch a PNG image using `response.bytes()` and display it + +## Notes + +- The plugin is referenced via relative path (`file:../../` in package.json, `path = "../../../"` in Cargo.toml). This is for development only. +- The `src-tauri/gen/` directory is generated by `tauri android init` / `tauri ios init` and is gitignored. +- The allowed domain is `httpbin.org` only. Requests to other domains will fail with `DOMAIN_NOT_ALLOWED`. +- Abort is best-effort; a request may complete before the cancellation signal is processed. diff --git a/examples/tauri-app/package.json b/examples/tauri-app/package.json new file mode 100644 index 0000000..a6eb11b --- /dev/null +++ b/examples/tauri-app/package.json @@ -0,0 +1,18 @@ +{ + "name": "tauri-plugin-http-client-example", + "private": true, + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "tauri": "tauri" + }, + "dependencies": { + "@silvermine/tauri-plugin-http-client": "file:../../" + }, + "devDependencies": { + "@tauri-apps/api": "2.10.1", + "@tauri-apps/cli": "2.10.1", + "vite": "6.3.5" + } +} diff --git a/examples/tauri-app/src-tauri/Cargo.toml b/examples/tauri-app/src-tauri/Cargo.toml new file mode 100644 index 0000000..af3638c --- /dev/null +++ b/examples/tauri-app/src-tauri/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "tauri-app" +version = "0.1.0" +edition = "2024" +rust-version = "1.89" + +[lib] +name = "tauri_app_lib" +crate-type = ["staticlib", "cdylib", "rlib"] + +[dependencies] +tauri = { version = "=2.10.3", features = [] } +tauri-plugin-http-client = { path = "../../../" } + +[build-dependencies] +tauri-build = { version = "=2.5.6", features = [] } diff --git a/examples/tauri-app/src-tauri/build.rs b/examples/tauri-app/src-tauri/build.rs new file mode 100644 index 0000000..eabc6af --- /dev/null +++ b/examples/tauri-app/src-tauri/build.rs @@ -0,0 +1,3 @@ +fn main() { + tauri_build::build(); +} diff --git a/examples/tauri-app/src-tauri/capabilities/default.json b/examples/tauri-app/src-tauri/capabilities/default.json new file mode 100644 index 0000000..5c27c9e --- /dev/null +++ b/examples/tauri-app/src-tauri/capabilities/default.json @@ -0,0 +1,9 @@ +{ + "identifier": "default", + "description": "Default capability for the example app", + "windows": ["main"], + "permissions": [ + "core:default", + "http-client:default" + ] +} diff --git a/examples/tauri-app/src-tauri/src/lib.rs b/examples/tauri-app/src-tauri/src/lib.rs new file mode 100644 index 0000000..df7961a --- /dev/null +++ b/examples/tauri-app/src-tauri/src/lib.rs @@ -0,0 +1,15 @@ +use std::time::Duration; + +#[cfg_attr(mobile, tauri::mobile_entry_point)] +pub fn run() { + tauri::Builder::default() + .plugin( + tauri_plugin_http_client::Builder::new() + .allowed_domains(["httpbin.org", "*.httpbin.org"]) + .default_timeout(Duration::from_secs(30)) + .max_response_body_size(5 * 1024 * 1024) + .build(), + ) + .run(tauri::generate_context!()) + .expect("error while running tauri application"); +} diff --git a/examples/tauri-app/src-tauri/src/main.rs b/examples/tauri-app/src-tauri/src/main.rs new file mode 100644 index 0000000..87d4e89 --- /dev/null +++ b/examples/tauri-app/src-tauri/src/main.rs @@ -0,0 +1,5 @@ +#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] + +fn main() { + tauri_app_lib::run(); +} diff --git a/examples/tauri-app/src-tauri/tauri.conf.json b/examples/tauri-app/src-tauri/tauri.conf.json new file mode 100644 index 0000000..3827709 --- /dev/null +++ b/examples/tauri-app/src-tauri/tauri.conf.json @@ -0,0 +1,22 @@ +{ + "productName": "HTTP Client Example", + "version": "0.1.0", + "identifier": "com.silvermine.httpClientExample", + "build": { + "devUrl": "http://localhost:5173", + "frontendDist": "../dist" + }, + "app": { + "withGlobalTauri": false, + "windows": [ + { + "title": "HTTP Client Plugin Demo", + "width": 900, + "height": 700 + } + ] + }, + "bundle": { + "active": false + } +} diff --git a/examples/tauri-app/src/index.html b/examples/tauri-app/src/index.html new file mode 100644 index 0000000..ff2fd0e --- /dev/null +++ b/examples/tauri-app/src/index.html @@ -0,0 +1,95 @@ + + + + + + HTTP Client Plugin Demo + + + +

HTTP Client Plugin Demo

+

+ Demonstrates @silvermine/tauri-plugin-http-client capabilities. + Allowed domain: httpbin.org +

+ +
+ + +
+

GET Request

+
+ + +
+
Response will appear here...
+
+ + +
+

POST Request

+
+ + + +
+
Response will appear here...
+
+ + +
+

Custom Headers

+

+ Sends custom headers via the HttpHeaders API and displays + multi-value response headers. Note: request headers are single-value per + name via toRecord(). +

+
+ +
+
Response will appear here...
+
+ + +
+

Abort Request

+

+ Starts a 10-second delayed request, then aborts it. + Abort is best-effort; the request may complete before cancellation takes effect. +

+
+ + +
+
Click "Start Delayed Request" to begin...
+
+ + +
+

Error Handling

+

Triggers intentional errors to demonstrate HttpClientError codes.

+
+ + + +
+
Click a button to trigger an error...
+
+ + +
+

Binary Response

+

Fetches an image using response.bytes() and displays it as a blob URL.

+
+ +
+
+ Click "Fetch Image" to load a PNG from httpbin.org... +
+
+ +
+ + + + diff --git a/examples/tauri-app/src/main.js b/examples/tauri-app/src/main.js new file mode 100644 index 0000000..6726fd5 --- /dev/null +++ b/examples/tauri-app/src/main.js @@ -0,0 +1,294 @@ +import { request, HttpHeaders, HttpClientError } from '@silvermine/tauri-plugin-http-client'; + +// --- Helpers --- + +function $(id) { + return document.getElementById(id); +} + +function clearElement(el) { + while (el.firstChild) { + el.removeChild(el.firstChild); + } +} + +function setLoading(el, message) { + el.textContent = message || 'Loading'; + el.classList.remove('error'); + el.classList.add('loading'); +} + +function setResult(el, text) { + el.classList.remove('error', 'loading'); + el.textContent = text; +} + +function setError(el, err) { + el.classList.remove('loading'); + el.classList.add('error'); + el.textContent = formatError(err); +} + +function formatResponse(resp) { + const headers = {}; + + resp.headers.forEach(function(value, name) { + const existing = headers[name]; + + if (existing) { + headers[name] = existing + ', ' + value; + } else { + headers[name] = value; + } + }); + + return JSON.stringify( + { + status: resp.status, + statusText: resp.statusText, + ok: resp.ok, + redirected: resp.redirected, + url: resp.url, + headers: headers, + body: tryParseJSON(resp.text()), + }, + null, + 2, + ); +} + +function tryParseJSON(text) { + try { + return JSON.parse(text); + } catch(_e) { + return text; + } +} + +function formatError(err) { + if (err instanceof HttpClientError) { + return 'HttpClientError\n code: ' + err.code + '\n message: ' + err.message; + } + + return String(err); +} + +// --- GET Request --- + +$('get-send').addEventListener('click', async function() { + const output = $('get-output'), + btn = $('get-send'); + + btn.disabled = true; + setLoading(output, 'Sending GET request'); + + try { + const resp = await request($('get-url').value); + + setResult(output, formatResponse(resp)); + } catch(err) { + setError(output, err); + } finally { + btn.disabled = false; + } +}); + +// --- POST Request --- + +$('post-send').addEventListener('click', async function() { + const output = $('post-output'), + btn = $('post-send'); + + btn.disabled = true; + setLoading(output, 'Sending POST request'); + + try { + // Explicitly set Content-Type when sending JSON. + // encodeBody() serializes objects to JSON strings but does not + // auto-set the Content-Type header. + const headers = new HttpHeaders(); + + headers.set('Content-Type', 'application/json'); + + const body = JSON.parse($('post-body').value), + resp = await request($('post-url').value, { + method: 'POST', + headers: headers, + body: body, + }); + + setResult(output, formatResponse(resp)); + } catch(err) { + setError(output, err); + } finally { + btn.disabled = false; + } +}); + +// --- Headers --- + +$('headers-send').addEventListener('click', async function() { + const output = $('headers-output'), + btn = $('headers-send'); + + btn.disabled = true; + setLoading(output, 'Sending request with custom headers'); + + try { + const headers = new HttpHeaders(); + + headers.set('X-Custom-Header', 'hello-from-tauri'); + headers.set('Accept', 'application/json'); + + const resp = await request('https://httpbin.org/headers', { headers: headers }); + + // Demonstrate multi-value response header access + let result = '--- Request sent with custom headers ---\n'; + + result += 'X-Custom-Header: ' + headers.get('x-custom-header') + '\n\n'; + result += '--- Response ---\n'; + result += 'Status: ' + resp.status + '\n\n'; + + result += '--- Response Headers (multi-value access) ---\n'; + + resp.headers.forEach(function(value, name) { + result += name + ': ' + value + '\n'; + }); + + result += '\n--- Response Body ---\n'; + result += JSON.stringify(resp.json(), null, 2); + + setResult(output, result); + } catch(err) { + setError(output, err); + } finally { + btn.disabled = false; + } +}); + +// --- Abort --- + +let abortController = null; + +$('abort-start').addEventListener('click', async function() { + const output = $('abort-output'), + cancelBtn = $('abort-cancel'), + startBtn = $('abort-start'); + + abortController = new AbortController(); + setLoading(output, 'Request started (10s delay)... click Abort to cancel'); + cancelBtn.disabled = false; + startBtn.disabled = true; + + try { + const resp = await request('https://httpbin.org/delay/10', { + signal: abortController.signal, + }); + + setResult(output, 'Request completed (was not aborted):\n' + formatResponse(resp)); + } catch(err) { + setError(output, err); + } finally { + cancelBtn.disabled = true; + startBtn.disabled = false; + abortController = null; + } +}); + +$('abort-cancel').addEventListener('click', function() { + if (abortController) { + abortController.abort(); + } +}); + +// --- Error Handling --- + +$('err-domain').addEventListener('click', async function() { + const output = $('err-output'), + btn = $('err-domain'); + + btn.disabled = true; + setLoading(output, 'Requesting blocked domain'); + + try { + await request('https://evil.com/steal-data'); + setResult(output, 'Unexpected success'); + } catch(err) { + setError(output, err); + } finally { + btn.disabled = false; + } +}); + +$('err-timeout').addEventListener('click', async function() { + const output = $('err-output'), + btn = $('err-timeout'); + + btn.disabled = true; + setLoading(output, 'Requesting with 1s timeout against 10s delay'); + + try { + await request('https://httpbin.org/delay/10', { timeout: 1000 }); + setResult(output, 'Unexpected success'); + } catch(err) { + setError(output, err); + } finally { + btn.disabled = false; + } +}); + +$('err-invalid').addEventListener('click', async function() { + const output = $('err-output'), + btn = $('err-invalid'); + + btn.disabled = true; + setLoading(output, 'Sending invalid URL'); + + try { + await request('not-a-valid-url'); + setResult(output, 'Unexpected success'); + } catch(err) { + setError(output, err); + } finally { + btn.disabled = false; + } +}); + +// --- Binary Response --- + +$('binary-fetch').addEventListener('click', async function() { + const output = $('binary-output'), + btn = $('binary-fetch'); + + btn.disabled = true; + clearElement(output); + + const loadingSpan = document.createElement('span'); + + loadingSpan.className = 'loading'; + loadingSpan.textContent = 'Loading image'; + output.appendChild(loadingSpan); + + try { + const resp = await request('https://httpbin.org/image/png'), + bytes = resp.bytes(), + blob = new Blob([bytes], { type: 'image/png' }), + url = URL.createObjectURL(blob), + img = document.createElement('img'); + + img.src = url; + img.alt = 'Image from httpbin.org'; + + clearElement(output); + output.classList.remove('error', 'loading'); + output.appendChild(document.createTextNode( + 'Status: ' + resp.status + ' | Size: ' + bytes.length + ' bytes\n', + )); + output.appendChild(img); + } catch(err) { + clearElement(output); + setError(output, err); + } finally { + btn.disabled = false; + } +}); diff --git a/examples/tauri-app/src/styles.css b/examples/tauri-app/src/styles.css new file mode 100644 index 0000000..ee37b3a --- /dev/null +++ b/examples/tauri-app/src/styles.css @@ -0,0 +1,378 @@ +/* + * HTTP Client Plugin Demo — Stylesheet + * + * Color palette: warm neutrals (stone tones) with teal accent. + * Designed for 960x720 desktop window and mobile webviews. + * 3-space indentation per project convention. + */ + +/* --- Reset & base --- */ + +*, +*::before, +*::after { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +:root { + --color-bg: #f7f6f4; + --color-surface: #ffffff; + --color-text: #2c2926; + --color-text-subdued: #78716c; + --color-border: #d6d3d1; + --color-border-subtle: #e7e5e4; + + --color-accent: #0d7377; + --color-accent-hover: #0a5c5f; + --color-accent-text: #ffffff; + + --color-danger: #b91c1c; + --color-danger-hover: #991b1b; + --color-danger-bg: #fef2f2; + --color-danger-text: #dc2626; + + --color-output-bg: #1c1917; + --color-output-text: #d6d3d1; + --color-output-accent: #5eead4; + + --color-code-bg: #e7e5e4; + + --radius-sm: 4px; + --radius-md: 6px; + --radius-lg: 8px; + + --shadow-panel: 0 1px 2px rgba(0, 0, 0, 0.06), 0 1px 3px rgba(0, 0, 0, 0.08); + + --font-sans: -apple-system, BlinkMacSystemFont, "Segoe UI", system-ui, sans-serif; + --font-mono: "SF Mono", "Cascadia Code", "Fira Code", Menlo, Consolas, monospace; + + --transition-fast: 120ms ease; +} + +html { + font-size: 16px; + -webkit-text-size-adjust: 100%; +} + +body { + font-family: var(--font-sans); + background: var(--color-bg); + color: var(--color-text); + line-height: 1.5; + padding: 1.25rem; + min-height: 100vh; +} + +/* --- App header --- */ + +.app-header { + max-width: 960px; + margin: 0 auto 1.25rem; +} + +.app-header h1 { + font-size: 1.25rem; + font-weight: 700; + letter-spacing: -0.01em; + line-height: 1.25; +} + +.subtitle { + color: var(--color-text-subdued); + font-size: 0.8125rem; + margin-top: 0.25rem; + line-height: 1.4; +} + +.subtitle .separator { + display: inline-block; + width: 3px; + height: 3px; + border-radius: 50%; + background: var(--color-text-subdued); + vertical-align: middle; + margin: 0 0.5rem; + opacity: 0.6; +} + +/* --- Panel grid --- */ + +.panels { + display: grid; + grid-template-columns: 1fr; + gap: 0.875rem; + max-width: 960px; + margin: 0 auto; +} + +@media (min-width: 640px) { + .panels { + grid-template-columns: repeat(2, 1fr); + } +} + +/* --- Panel --- */ + +.panel { + background: var(--color-surface); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-lg); + padding: 1rem; + box-shadow: var(--shadow-panel); + display: flex; + flex-direction: column; +} + +.panel h2 { + font-size: 0.875rem; + font-weight: 600; + letter-spacing: 0.01em; + text-transform: uppercase; + color: var(--color-accent); + margin-bottom: 0.5rem; +} + +.hint { + font-size: 0.8125rem; + color: var(--color-text-subdued); + margin-bottom: 0.5rem; + line-height: 1.4; +} + +/* --- Controls --- */ + +.controls { + display: flex; + flex-wrap: wrap; + gap: 0.5rem; + align-items: flex-start; + margin-bottom: 0.75rem; +} + +.controls input[type="text"] { + flex: 1; + min-width: 0; + padding: 0.4rem 0.6rem; + border: 1px solid var(--color-border); + border-radius: var(--radius-sm); + font-size: 0.8125rem; + font-family: var(--font-mono); + color: var(--color-text); + background: var(--color-surface); + transition: border-color var(--transition-fast); +} + +.controls input[type="text"]:focus { + outline: none; + border-color: var(--color-accent); + box-shadow: 0 0 0 2px rgba(13, 115, 119, 0.15); +} + +.controls textarea { + width: 100%; + padding: 0.4rem 0.6rem; + border: 1px solid var(--color-border); + border-radius: var(--radius-sm); + font-size: 0.8125rem; + font-family: var(--font-mono); + color: var(--color-text); + background: var(--color-surface); + resize: vertical; + transition: border-color var(--transition-fast); +} + +.controls textarea:focus { + outline: none; + border-color: var(--color-accent); + box-shadow: 0 0 0 2px rgba(13, 115, 119, 0.15); +} + +/* --- Buttons --- */ + +.btn { + padding: 0.4rem 0.75rem; + border: 1px solid var(--color-border); + border-radius: var(--radius-sm); + background: var(--color-surface); + color: var(--color-text); + cursor: pointer; + font-size: 0.8125rem; + font-family: var(--font-sans); + font-weight: 500; + line-height: 1.4; + white-space: nowrap; + transition: + background var(--transition-fast), + border-color var(--transition-fast), + color var(--transition-fast), + box-shadow var(--transition-fast); + -webkit-tap-highlight-color: transparent; +} + +.btn:hover { + background: var(--color-bg); + border-color: var(--color-border); +} + +.btn:active { + transform: translateY(0.5px); +} + +.btn:focus-visible { + outline: 2px solid var(--color-accent); + outline-offset: 1px; +} + +.btn:disabled { + opacity: 0.45; + cursor: not-allowed; + transform: none; +} + +.btn-primary { + background: var(--color-accent); + border-color: var(--color-accent); + color: var(--color-accent-text); +} + +.btn-primary:hover:not(:disabled) { + background: var(--color-accent-hover); + border-color: var(--color-accent-hover); +} + +.btn-danger { + background: var(--color-surface); + border-color: var(--color-danger); + color: var(--color-danger); +} + +.btn-danger:hover:not(:disabled) { + background: var(--color-danger-bg); +} + +/* --- Output area --- */ + +.output { + background: var(--color-output-bg); + color: var(--color-output-text); + padding: 0.75rem; + border-radius: var(--radius-md); + font-size: 0.75rem; + font-family: var(--font-mono); + overflow-x: auto; + white-space: pre-wrap; + word-break: break-word; + min-height: 3rem; + max-height: 260px; + overflow-y: auto; + flex: 1; + line-height: 1.5; +} + +.output img { + max-width: 200px; + display: block; + margin-top: 0.5rem; + border-radius: var(--radius-sm); +} + +/* Scrollbar styling for output areas */ +.output::-webkit-scrollbar { + width: 6px; + height: 6px; +} + +.output::-webkit-scrollbar-track { + background: transparent; +} + +.output::-webkit-scrollbar-thumb { + background: rgba(255, 255, 255, 0.15); + border-radius: 3px; +} + +.output::-webkit-scrollbar-thumb:hover { + background: rgba(255, 255, 255, 0.25); +} + +/* --- State classes --- */ + +.error { + color: #f87171; +} + +.loading { + color: var(--color-output-accent); +} + +/* Loading dot animation */ +.loading::after { + content: ''; + animation: dots 1.2s steps(4, end) infinite; +} + +@keyframes dots { + 0% { content: ''; } + 25% { content: '.'; } + 50% { content: '..'; } + 75% { content: '...'; } +} + +/* --- Inline code --- */ + +code { + background: var(--color-code-bg); + padding: 0.1rem 0.35rem; + border-radius: var(--radius-sm); + font-size: 0.85em; + font-family: var(--font-mono); +} + +/* --- Responsive adjustments --- */ + +/* Small screens: tighter spacing */ +@media (max-width: 639px) { + body { + padding: 1rem; + } + + .app-header { + margin-bottom: 1rem; + } + + .panel { + padding: 0.875rem; + } + + .output { + max-height: 200px; + } + + .subtitle .separator { + display: none; + } + + .subtitle code:first-child { + display: block; + margin-bottom: 0.15rem; + } +} + +/* Tall enough for comfortable two-column layout */ +@media (min-width: 640px) and (min-height: 600px) { + .panels { + gap: 0.75rem; + } +} + +/* Respect reduced motion */ +@media (prefers-reduced-motion: reduce) { + *, + *::before, + *::after { + animation-duration: 0.01ms !important; + transition-duration: 0.01ms !important; + } +} diff --git a/examples/tauri-app/vite.config.js b/examples/tauri-app/vite.config.js new file mode 100644 index 0000000..56c8424 --- /dev/null +++ b/examples/tauri-app/vite.config.js @@ -0,0 +1,9 @@ +import { defineConfig } from 'vite'; + +export default defineConfig({ + root: 'src', + build: { + outDir: '../dist', + emptyOutDir: true, + }, +}); From 3424843bb1f894d0c07dda7b0e59a6ea8cefb613 Mon Sep 17 00:00:00 2001 From: Jordan Hafer <42755763+jjhafer@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:28:21 -0400 Subject: [PATCH 3/3] sub(feat): use binary framing for response bodies Base64-encoding response bodies added ~33% size overhead and unnecessary CPU work on every response. Binary framing sends raw bytes directly over IPC, removing that cost while keeping a structured metadata header. Switch the fetch command from returning a JSON-serialized FetchResponse to a binary-framed tauri::ipc::Response, eliminating base64 encoding overhead for response bodies. The frame format is: [4-byte BE metadata length][metadata JSON][body bytes] --- README.md | 11 +- guest-js/errors.ts | 5 + guest-js/http-client.test.ts | 381 ++++++++++++++++++++++++++---- guest-js/http-client.ts | 141 +++++++++--- guest-js/types.ts | 10 +- src/client.rs | 434 +++++++++++++++-------------------- src/commands.rs | 273 +++++++++++++++++++++- src/types.rs | 119 ++++++---- 8 files changed, 981 insertions(+), 393 deletions(-) diff --git a/README.md b/README.md index fd79ef9..db4977a 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ requests from Tauri applications. * Abort in-flight requests via `AbortController` * Case-insensitive `HttpHeaders` with multi-value support - * Binary request and response bodies (base64 over IPC) + * Binary request and response bodies * Runtime allowlist management from Rust @@ -330,7 +330,6 @@ produce a `FORBIDDEN_HEADER` error. Full example: ```rust -use std::collections::HashMap; use std::time::Duration; use tauri_plugin_http_client::config::RetryConfig; @@ -343,10 +342,10 @@ let plugin = tauri_plugin_http_client::Builder::new() .max_redirects(5) .max_response_body_size(5 * 1024 * 1024) .max_allowlist_size(64) - .user_agent("my-app/1.0".into()) - .default_headers(HashMap::from([ - ("X-App-Version".into(), "1.0".into()), - ])) + .user_agent("my-app/1.0") + .default_headers([ + ("X-App-Version", "1.0"), + ]) .retry(RetryConfig::default()) .build(); ``` diff --git a/guest-js/errors.ts b/guest-js/errors.ts index d2ef89a..73ff536 100644 --- a/guest-js/errors.ts +++ b/guest-js/errors.ts @@ -48,6 +48,9 @@ export enum HttpErrorCode { /** A forbidden header was provided (e.g. Host). */ FORBIDDEN_HEADER = 'FORBIDDEN_HEADER', + /** IPC response could not be decoded (malformed binary frame). */ + PROTOCOL_ERROR = 'PROTOCOL_ERROR', + /** Unclassified error. */ ERROR = 'ERROR', @@ -66,6 +69,8 @@ export class HttpClientError extends Error { public constructor(code: HttpErrorCode, message: string) { super(message); + // Required for instanceof when transpiled to ES5 + Object.setPrototypeOf(this, new.target.prototype); this.name = 'HttpClientError'; this.code = code; } diff --git a/guest-js/http-client.test.ts b/guest-js/http-client.test.ts index debcd29..060485d 100644 --- a/guest-js/http-client.test.ts +++ b/guest-js/http-client.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import type { RawFetchResponse } from './types'; +import type { FetchResponseMetadata } from './types'; import type { HttpClientError as HttpClientErrorType } from './errors'; // Mock @tauri-apps/api/core before importing the module under test @@ -10,24 +10,42 @@ vi.mock('@tauri-apps/api/core', () => { }); // Import after mock setup -const { request } = await import('./http-client'); +const { request, decodeIpcResult } = await import('./http-client'); const { HttpClientError, HttpErrorCode } = await import('./errors'); const { HttpHeaders } = await import('./headers'); -function makeRawResponse(overrides?: Partial): RawFetchResponse { - return { +// ---- helpers ---------------------------------------------------------------- + +/** + * Builds a binary-framed ArrayBuffer matching the IPC format: + * [4-byte BE metadata length][metadata JSON][body bytes] + */ +function makeBinaryFrame(metadata: Partial, bodyText: string): ArrayBuffer; +function makeBinaryFrame(metadata: Partial, bodyBytes: Uint8Array): ArrayBuffer; +function makeBinaryFrame(metadata: Partial, body: string | Uint8Array): ArrayBuffer { + const meta: FetchResponseMetadata = { status: 200, statusText: 'OK', headers: { 'content-type': [ 'application/json' ] }, - body: '{"key":"value"}', - bodyEncoding: 'utf8', url: 'https://api.example.com/data', redirected: false, retryCount: 0, - ...overrides, + ...metadata, }; + + const metaBytes = new TextEncoder().encode(JSON.stringify(meta)), + bodyBytes = typeof body === 'string' ? new TextEncoder().encode(body) : body, + buf = new ArrayBuffer(4 + metaBytes.length + bodyBytes.length), + view = new DataView(buf), + u8 = new Uint8Array(buf); + + view.setUint32(0, metaBytes.length); + u8.set(metaBytes, 4); + u8.set(bodyBytes, 4 + metaBytes.length); + + return buf; } describe('request()', () => { @@ -41,9 +59,7 @@ describe('request()', () => { }); it('sends a basic GET request and returns a response', async () => { - const raw = makeRawResponse(); - - mockInvoke.mockResolvedValueOnce(raw); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '{"key":"value"}')); const resp = await request('https://api.example.com/data'); @@ -59,7 +75,7 @@ describe('request()', () => { }); it('returns text body correctly', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'hello world', bodyEncoding: 'utf8' })); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, 'hello world')); const resp = await request('https://example.com'); @@ -67,7 +83,7 @@ describe('request()', () => { }); it('parses JSON body correctly', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '{"key":"value"}')); const resp = await request('https://example.com'), data = resp.json<{ key: string }>(); @@ -75,9 +91,8 @@ describe('request()', () => { expect(data.key).toBe('value'); }); - it('decodes base64 body to bytes', async () => { - // "hello" in base64 - mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'aGVsbG8=', bodyEncoding: 'base64' })); + it('decodes binary body to bytes', async () => { + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, new Uint8Array([ 0x68, 0x65, 0x6C, 0x6C, 0x6F ]))); const resp = await request('https://example.com'), bytes = resp.bytes(); @@ -87,7 +102,7 @@ describe('request()', () => { }); it('sends POST with string body', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com', { method: 'POST', body: 'payload' }); @@ -99,7 +114,7 @@ describe('request()', () => { }); it('sends POST with object body as JSON', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com', { method: 'POST', body: { foo: 'bar' } }); @@ -110,7 +125,7 @@ describe('request()', () => { }); it('sends POST with Uint8Array body as base64', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); // Construct Uint8Array directly (not via TextEncoder) to avoid // cross-realm instanceof issues in jsdom test environment. @@ -126,7 +141,7 @@ describe('request()', () => { }); it('sends headers from Record', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com', { headers: { 'Authorization': 'Bearer token' }, @@ -138,7 +153,7 @@ describe('request()', () => { }); it('sends headers from HttpHeaders instance', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); const headers = new HttpHeaders(); @@ -152,7 +167,7 @@ describe('request()', () => { }); it('sends timeout', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com', { timeout: 5000 }); @@ -162,7 +177,7 @@ describe('request()', () => { }); it('ok is false for non-2xx status', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ status: 404, statusText: 'Not Found' })); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({ status: 404, statusText: 'Not Found' }, '')); const resp = await request('https://example.com'); @@ -200,7 +215,7 @@ describe('request()', () => { }); it('includes requestId when signal is provided', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); const controller = new AbortController(); @@ -214,9 +229,9 @@ describe('request()', () => { it('calls abort_request when signal fires', async () => { // Make invoke hang until we abort - let resolveInvoke: ((v: RawFetchResponse) => void) | undefined; + let resolveInvoke: ((v: ArrayBuffer) => void) | undefined; - const invokePromise = new Promise((resolve) => { + const invokePromise = new Promise((resolve) => { resolveInvoke = resolve; }); @@ -236,7 +251,7 @@ describe('request()', () => { // Resolve the fetch so the promise settles if (resolveInvoke) { - resolveInvoke(makeRawResponse()); + resolveInvoke(makeBinaryFrame({}, '')); } const resp = await reqPromise; @@ -251,20 +266,16 @@ describe('request()', () => { expect(abortCall).toBeDefined(); }); - it('decodes base64 body to text via text()', async () => { - // "héllo" in UTF-8 then base64 - const bytes = new TextEncoder().encode('héllo'), - b64 = btoa(String.fromCharCode(...bytes)); - - mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: b64, bodyEncoding: 'base64' })); + it('decodes non-ASCII body to text via text()', async () => { + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, new TextEncoder().encode('héllo'))); const resp = await request('https://example.com'); expect(resp.text()).toBe('héllo'); }); - it('converts utf8 body to bytes via bytes()', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'hello', bodyEncoding: 'utf8' })); + it('converts body to bytes via bytes()', async () => { + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, 'hello')); const resp = await request('https://example.com'), bytes = resp.bytes(); @@ -274,7 +285,7 @@ describe('request()', () => { }); it('caches text() return value', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'cached', bodyEncoding: 'utf8' })); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, 'cached')); const resp = await request('https://example.com'); @@ -284,7 +295,7 @@ describe('request()', () => { }); it('caches bytes() return value', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'aGVsbG8=', bodyEncoding: 'base64' })); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, 'hello')); const resp = await request('https://example.com'); @@ -296,7 +307,7 @@ describe('request()', () => { }); it('removes abort listener after successful request', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); const controller = new AbortController(), removeSpy = vi.spyOn(controller.signal, 'removeEventListener'); @@ -307,7 +318,7 @@ describe('request()', () => { }); it('does not include requestId without signal', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com'); @@ -317,7 +328,7 @@ describe('request()', () => { }); it('omits undefined optional fields from payload', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com'); @@ -328,7 +339,7 @@ describe('request()', () => { it('ok boundary: 199 is not ok, 200 is ok, 299 is ok, 300 is not ok', async () => { for (const [ status, expectedOk ] of [ [ 199, false ], [ 200, true ], [ 299, true ], [ 300, false ] ] as [number, boolean][]) { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ status })); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({ status }, '')); const resp = await request('https://example.com'); @@ -337,10 +348,10 @@ describe('request()', () => { }); it('handles redirected response', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({ redirected: true, url: 'https://api.example.com/final', - })); + }, '')); const resp = await request('https://api.example.com/start'); @@ -349,7 +360,7 @@ describe('request()', () => { }); it('passes maxRetries in payload', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com', { maxRetries: 5 }); @@ -359,7 +370,7 @@ describe('request()', () => { }); it('omits maxRetries when undefined', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com'); @@ -369,7 +380,7 @@ describe('request()', () => { }); it('exposes retryCount from response', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ retryCount: 2 })); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({ retryCount: 2 }, '')); const resp = await request('https://example.com'); @@ -377,7 +388,7 @@ describe('request()', () => { }); it('retryCount is 0 when no retries occurred', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); const resp = await request('https://example.com'); @@ -385,7 +396,7 @@ describe('request()', () => { }); it('sends empty string body correctly', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); await request('https://example.com', { method: 'POST', body: '' }); @@ -396,7 +407,7 @@ describe('request()', () => { }); it('sends multi-value headers via HttpHeaders as comma-joined string', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); const headers = new HttpHeaders(); @@ -416,7 +427,7 @@ describe('request()', () => { // Fire 10 rapid requests, each generating a unique requestId for (let i = 0; i < 10; i++) { - mockInvoke.mockResolvedValueOnce(makeRawResponse()); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, '')); const controller = new AbortController(); @@ -432,7 +443,7 @@ describe('request()', () => { }); it('json() throws on invalid JSON body', async () => { - mockInvoke.mockResolvedValueOnce(makeRawResponse({ body: 'not valid json', bodyEncoding: 'utf8' })); + mockInvoke.mockResolvedValueOnce(makeBinaryFrame({}, 'not valid json')); const resp = await request('https://example.com'); @@ -441,4 +452,274 @@ describe('request()', () => { expect(fn).toThrow(); }); + // ---- binary frame path -------------------------------------------------- + + it('decodes binary frame: basic status, headers, text body', async () => { + const buf = makeBinaryFrame( + { status: 200, statusText: 'OK', headers: { 'content-type': [ 'text/plain' ] }, url: 'https://example.com', redirected: false, retryCount: 0 }, + 'hello binary' + ); + + mockInvoke.mockResolvedValueOnce(buf); + + const resp = await request('https://example.com'); + + expect(resp.status).toBe(200); + expect(resp.statusText).toBe('OK'); + expect(resp.ok).toBe(true); + expect(resp.url).toBe('https://example.com'); + expect(resp.redirected).toBe(false); + expect(resp.headers.get('content-type')).toBe('text/plain'); + expect(resp.text()).toBe('hello binary'); + }); + + it('decodes binary frame: binary body via bytes()', async () => { + const binaryBody = new Uint8Array([ 0x89, 0x50, 0x4E, 0x47 ]), + buf = makeBinaryFrame({ status: 200, statusText: 'OK', headers: {}, url: 'https://example.com', redirected: false, retryCount: 0 }, binaryBody); + + mockInvoke.mockResolvedValueOnce(buf); + + const resp = await request('https://example.com'), + bytes = resp.bytes(); + + expect(bytes).toBeInstanceOf(Uint8Array); + expect(Array.from(bytes)).toEqual([ 0x89, 0x50, 0x4E, 0x47 ]); + }); + + it('decodes binary frame: empty body', async () => { + const buf = makeBinaryFrame({ status: 204, statusText: 'No Content', headers: {}, url: 'https://example.com', redirected: false, retryCount: 0 }, ''); + + mockInvoke.mockResolvedValueOnce(buf); + + const resp = await request('https://example.com'); + + expect(resp.status).toBe(204); + expect(resp.text()).toBe(''); + expect(resp.bytes().length).toBe(0); + }); + + it('decodes binary frame: retryCount and redirected flags', async () => { + const buf = makeBinaryFrame( + { status: 200, statusText: 'OK', headers: {}, url: 'https://example.com/final', redirected: true, retryCount: 2 }, + 'body' + ); + + mockInvoke.mockResolvedValueOnce(buf); + + const resp = await request('https://example.com'); + + expect(resp.retryCount).toBe(2); + expect(resp.redirected).toBe(true); + expect(resp.url).toBe('https://example.com/final'); + }); + + it('decodes binary frame: multi-value headers', async () => { + const buf = makeBinaryFrame( + { status: 200, statusText: 'OK', headers: { 'set-cookie': [ 'a=1', 'b=2' ] }, url: 'https://example.com', redirected: false, retryCount: 0 }, + '' + ); + + mockInvoke.mockResolvedValueOnce(buf); + + const resp = await request('https://example.com'); + + expect(resp.headers.getAll('set-cookie')).toEqual([ 'a=1', 'b=2' ]); + }); + + it('bytes() caches the result — subsequent calls return the same reference', async () => { + const buf = makeBinaryFrame({ status: 200, statusText: 'OK', headers: {}, url: 'https://example.com', redirected: false, retryCount: 0 }, 'hello'); + + mockInvoke.mockResolvedValueOnce(buf); + + const resp = await request('https://example.com'), + first = resp.bytes(); + + first[0] = 0xFF; + + const second = resp.bytes(); + + // second call returns the cached copy which was already mutated + expect(second[0]).toBe(0xFF); + // but both calls return the same reference (cached) + expect(first).toBe(second); + }); + +}); + +// ---- decodeIpcResult() unit tests ------------------------------------------- + +describe('decodeIpcResult()', () => { + + it('decodes a binary ArrayBuffer frame', () => { + const buf = makeBinaryFrame( + { status: 201, statusText: 'Created', headers: { 'x-id': [ '42' ] }, url: 'https://x.com', redirected: false, retryCount: 1 }, + 'created' + ); + + const decoded = decodeIpcResult(buf); + + expect(decoded.metadata.status).toBe(201); + expect(decoded.metadata.statusText).toBe('Created'); + expect(decoded.metadata.url).toBe('https://x.com'); + expect(decoded.metadata.redirected).toBe(false); + expect(decoded.metadata.retryCount).toBe(1); + expect(decoded.metadata.headers['x-id']).toEqual([ '42' ]); + expect(new TextDecoder().decode(decoded.body)).toBe('created'); + }); + + it('normalizes Android number-array to ArrayBuffer and decodes frame', () => { + const originalBuf = makeBinaryFrame( + { status: 200, statusText: 'OK', headers: {}, url: 'https://android.example.com', redirected: false, retryCount: 0 }, + 'android body' + ); + + // Simulate what Tauri delivers on Android ≥ 1 KB: number array + const numberArray = Array.from(new Uint8Array(originalBuf)); + + const decoded = decodeIpcResult(numberArray); + + expect(decoded.metadata.status).toBe(200); + expect(decoded.metadata.url).toBe('https://android.example.com'); + expect(new TextDecoder().decode(decoded.body)).toBe('android body'); + }); + + it('binary frame: empty body', () => { + const buf = makeBinaryFrame({ status: 204, statusText: 'No Content', headers: {}, url: 'https://x.com', redirected: false, retryCount: 0 }, ''); + + const decoded = decodeIpcResult(buf); + + expect(decoded.metadata.status).toBe(204); + expect(decoded.body.length).toBe(0); + }); + + it('binary frame: body containing arbitrary bytes including null bytes', () => { + const bodyBytes = new Uint8Array([ 0x00, 0x01, 0xFF, 0xFE, 0x00 ]), + buf = makeBinaryFrame({ status: 200, statusText: 'OK', headers: {}, url: 'https://x.com', redirected: false, retryCount: 0 }, bodyBytes); + + const decoded = decodeIpcResult(buf); + + expect(Array.from(decoded.body)).toEqual([ 0x00, 0x01, 0xFF, 0xFE, 0x00 ]); + }); + + // ---- parser error cases (PROTOCOL_ERROR) ------------------------------------ + + it('throws PROTOCOL_ERROR when buffer is less than 4 bytes', () => { + const buf = new ArrayBuffer(3); + + expect(() => { return decodeIpcResult(buf); }).toThrow(expect.objectContaining({ + code: HttpErrorCode.PROTOCOL_ERROR, + })); + }); + + it('throws PROTOCOL_ERROR when metadata length exceeds buffer', () => { + // Frame claims metaLen = 9999 but buffer is only 10 bytes + const buf = new ArrayBuffer(10), + view = new DataView(buf); + + view.setUint32(0, 9999); + + expect(() => { return decodeIpcResult(buf); }).toThrow(expect.objectContaining({ + code: HttpErrorCode.PROTOCOL_ERROR, + })); + }); + + it('throws PROTOCOL_ERROR when metadata length is 0', () => { + const buf = new ArrayBuffer(8), + view = new DataView(buf); + + view.setUint32(0, 0); + + expect(() => { return decodeIpcResult(buf); }).toThrow(expect.objectContaining({ + code: HttpErrorCode.PROTOCOL_ERROR, + })); + }); + + it('throws PROTOCOL_ERROR when metadata bytes are invalid UTF-8', () => { + // 4-byte header + 4 bytes of invalid UTF-8 sequence + const buf = new ArrayBuffer(8), + view = new DataView(buf), + u8 = new Uint8Array(buf); + + view.setUint32(0, 4); + // Lone continuation bytes — invalid UTF-8 + u8[4] = 0x80; + u8[5] = 0x81; + u8[6] = 0x82; + u8[7] = 0x83; + + expect(() => { return decodeIpcResult(buf); }).toThrow(expect.objectContaining({ + code: HttpErrorCode.PROTOCOL_ERROR, + })); + }); + + it('throws PROTOCOL_ERROR when metadata is not valid JSON', () => { + const badJson = 'not { valid json', + badJsonBytes = new TextEncoder().encode(badJson), + buf = new ArrayBuffer(4 + badJsonBytes.length), + view = new DataView(buf), + u8 = new Uint8Array(buf); + + view.setUint32(0, badJsonBytes.length); + u8.set(badJsonBytes, 4); + + expect(() => { return decodeIpcResult(buf); }).toThrow(expect.objectContaining({ + code: HttpErrorCode.PROTOCOL_ERROR, + })); + }); + + it('throws PROTOCOL_ERROR when metadata is missing required fields', () => { + const incompleteJson = JSON.stringify({ status: 200 }), // missing statusText, url, etc. + jsonBytes = new TextEncoder().encode(incompleteJson), + buf = new ArrayBuffer(4 + jsonBytes.length), + view = new DataView(buf), + u8 = new Uint8Array(buf); + + view.setUint32(0, jsonBytes.length); + u8.set(jsonBytes, 4); + + expect(() => { return decodeIpcResult(buf); }).toThrow(expect.objectContaining({ + code: HttpErrorCode.PROTOCOL_ERROR, + })); + }); + + it('throws PROTOCOL_ERROR and request() wraps it as HttpClientError', async () => { + // Buffer too small — simulates a corrupted IPC response + const tinyBuf = new ArrayBuffer(2); + + mockInvoke.mockResolvedValueOnce(tinyBuf); + + try { + await request('https://example.com'); + expect.fail('should have thrown'); + } catch(err) { + expect(err).toBeInstanceOf(HttpClientError); + expect((err as HttpClientErrorType).code).toBe(HttpErrorCode.PROTOCOL_ERROR); + } + }); + + it('valid frame: does not throw (sanity check)', () => { + const buf = makeBinaryFrame( + { status: 200, statusText: 'OK', headers: {}, url: 'https://x.com', redirected: false, retryCount: 0 }, + 'hello' + ); + + // Sanity check: valid frame should NOT throw + expect(() => { return decodeIpcResult(buf); }).not.toThrow(); + }); + + it('throws PROTOCOL_ERROR when metadata headers field is null', () => { + const metaJson = JSON.stringify({ status: 200, statusText: 'OK', headers: null, url: 'https://x.com', redirected: false, retryCount: 0 }), + metaBytes = new TextEncoder().encode(metaJson), + buf = new ArrayBuffer(4 + metaBytes.length), + view = new DataView(buf), + u8 = new Uint8Array(buf); + + view.setUint32(0, metaBytes.length); + u8.set(metaBytes, 4); + + expect(() => { return decodeIpcResult(buf); }).toThrow(expect.objectContaining({ + code: HttpErrorCode.PROTOCOL_ERROR, + })); + }); + }); diff --git a/guest-js/http-client.ts b/guest-js/http-client.ts index cc1321e..60c85d3 100644 --- a/guest-js/http-client.ts +++ b/guest-js/http-client.ts @@ -1,7 +1,7 @@ import { invoke } from '@tauri-apps/api/core'; import { HttpHeaders } from './headers'; -import { parseError } from './errors'; -import type { BodyEncoding, RequestOptions, HttpResponse, RawFetchRequest, RawFetchResponse } from './types'; +import { HttpClientError, HttpErrorCode, parseError } from './errors'; +import type { BodyEncoding, RequestOptions, HttpResponse, RawFetchRequest, FetchResponseMetadata } from './types'; let requestCounter = 0; @@ -46,9 +46,9 @@ export async function request(url: string, options?: RequestOptions): Promise; + + return typeof o.status === 'number' + && typeof o.statusText === 'string' + && typeof o.headers === 'object' && o.headers !== null + && typeof o.url === 'string' + && typeof o.redirected === 'boolean' + && typeof o.retryCount === 'number'; } -function wrapResponse(raw: RawFetchResponse): HttpResponse { - const headers = new HttpHeaders(raw.headers); +/** + * Decodes the binary frame format used on desktop platforms: + * `[4-byte BE metadata length][metadata JSON bytes][body bytes]` + * + * Throws `HttpClientError` with code `PROTOCOL_ERROR` if the frame is + * malformed (truncated, invalid UTF-8 metadata, invalid JSON, or missing + * required fields). + */ +function decodeBinaryFrame(buf: ArrayBuffer): DecodedResponse { + if (buf.byteLength < 4) { + throw new HttpClientError(HttpErrorCode.PROTOCOL_ERROR, `binary frame too small: ${buf.byteLength} bytes`); + } + + const view = new DataView(buf), + metaLen = view.getUint32(0); + + if (metaLen === 0) { + throw new HttpClientError(HttpErrorCode.PROTOCOL_ERROR, 'binary frame metadata length is 0'); + } + + if (4 + metaLen > buf.byteLength) { + throw new HttpClientError( + HttpErrorCode.PROTOCOL_ERROR, + `binary frame metadata length ${metaLen} exceeds buffer size ${buf.byteLength}` + ); + } + + const metaBytes = new Uint8Array(buf, 4, metaLen); + + let metaJson: string; + + try { + metaJson = new TextDecoder('utf-8', { fatal: true }).decode(metaBytes); + } catch{ + throw new HttpClientError(HttpErrorCode.PROTOCOL_ERROR, 'binary frame metadata is not valid UTF-8'); + } + + let parsed: unknown; + + try { + parsed = JSON.parse(metaJson); + } catch{ + throw new HttpClientError(HttpErrorCode.PROTOCOL_ERROR, 'binary frame metadata is not valid JSON'); + } + + if (!isValidFetchMetadata(parsed)) { + throw new HttpClientError(HttpErrorCode.PROTOCOL_ERROR, 'binary frame metadata is missing required fields'); + } + + const metadata = parsed as FetchResponseMetadata, + body = new Uint8Array(buf, 4 + metaLen); + + return { metadata, body }; +} + +function wrapResponse(decoded: DecodedResponse): HttpResponse { + const { metadata, body } = decoded, + headers = new HttpHeaders(metadata.headers); // Cache decoded body values let textValue: string | undefined, bytesValue: Uint8Array | undefined; return { - status: raw.status, - statusText: raw.statusText, + status: metadata.status, + statusText: metadata.statusText, headers, - url: raw.url, - redirected: raw.redirected, - ok: raw.status >= 200 && raw.status < 300, // mirrors fetch() Response.ok - retryCount: raw.retryCount, + url: metadata.url, + redirected: metadata.redirected, + ok: metadata.status >= 200 && metadata.status < 300, // mirrors fetch() Response.ok + retryCount: metadata.retryCount, text(): string { if (textValue === undefined) { - if (raw.bodyEncoding === 'base64') { - const bytes = base64ToUint8Array(raw.body); - - textValue = new TextDecoder().decode(bytes); - } else { - textValue = raw.body; - } + textValue = new TextDecoder().decode(body); } return textValue; @@ -166,11 +250,8 @@ function wrapResponse(raw: RawFetchResponse): HttpResponse { bytes(): Uint8Array { if (bytesValue === undefined) { - if (raw.bodyEncoding === 'base64') { - bytesValue = base64ToUint8Array(raw.body); - } else { - bytesValue = new TextEncoder().encode(raw.body); - } + // Slice to own a non-shared copy (body is a view into the IPC buffer) + bytesValue = body.slice(); } return bytesValue; diff --git a/guest-js/types.ts b/guest-js/types.ts index d382c31..5fff525 100644 --- a/guest-js/types.ts +++ b/guest-js/types.ts @@ -31,13 +31,15 @@ export interface HttpResponse { bytes(): Uint8Array; } -/** Raw IPC response from the Rust backend. */ -export interface RawFetchResponse { +/** + * Response metadata from the Rust backend, carried in the binary frame header. + * Field names match the camelCase serialization of `FetchResponseMetadata` in + * Rust's `types.rs`. + */ +export interface FetchResponseMetadata { status: number; statusText: string; headers: Record; - body: string; - bodyEncoding: BodyEncoding; url: string; redirected: boolean; retryCount: number; diff --git a/src/client.rs b/src/client.rs index 972c4f5..7cff2a9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,7 +10,7 @@ use reqwest::redirect; use crate::allowlist::{DomainAllowlist, is_private_ip}; use crate::config::{HttpClientConfig, RetryConfig}; use crate::error::{Error, Result}; -use crate::types::{FetchRequest, FetchResponse}; +use crate::types::{BodyEncoding, ExecuteResult, FetchRequest, FetchResponseMetadata}; /// Headers that are always forbidden in per-request and default headers. /// @@ -283,10 +283,10 @@ impl HttpClientState { /// 5xx responses — these are valid HTTP responses, not transport errors). /// Intermediate retryable responses are fully read and discarded; the /// caller only sees the final attempt's response body. - pub async fn execute(&self, req: FetchRequest) -> Result { + pub(crate) async fn execute(&self, req: FetchRequest) -> Result { let method = parse_method(req.method.as_deref().unwrap_or("GET"))?; let body_bytes = match req.body { - Some(ref b) => Some(decode_request_body(b, req.body_encoding.as_deref())?), + Some(ref b) => Some(decode_request_body(b, req.body_encoding.as_ref())?), None => None, }; let timeout = req @@ -303,15 +303,26 @@ impl HttpClientState { // doesn't change between retries, so re-parsing is unnecessary. let url = self.allowlist.read().validate_url(&req.url)?; - let mut last_result: Option> = None; + let mut last_result: Option> = None; let mut attempt: u32 = 0; // Bounded: returns when attempt + 1 >= max_attempts (should_retry = false) loop { - assert!( - attempt < max_attempts, - "retry loop exceeded max_attempts ({max_attempts}); this is a bug" - ); + if attempt >= max_attempts { + // Invariant: structurally unreachable — should_retry prevents + // attempt from reaching max_attempts. This is a defensive + // check against an impossible condition, not normal error + // handling. If triggered, the retry loop has a logic bug. + tracing::error!( + attempt, + max_attempts, + url = %req.url, + "retry loop exceeded max_attempts; this is a bug" + ); + return Err(Error::Other(format!( + "retry loop exceeded max_attempts ({max_attempts}); this is a bug" + ))); + } if attempt > 0 { let backoff = calculate_backoff(retry_config, attempt, last_result.as_ref()); @@ -341,8 +352,10 @@ impl HttpClientState { let should_retry = attempt + 1 < max_attempts && method_retryable; match result { - Ok(resp) if should_retry && retry_config.is_retryable_status(resp.status) => { - last_result = Some(Ok(resp)); + Ok(ref resp) + if should_retry && retry_config.is_retryable_status(resp.metadata.status) => + { + last_result = Some(result); attempt += 1; } Err(e) if should_retry && e.is_retryable() => { @@ -350,7 +363,7 @@ impl HttpClientState { attempt += 1; } Ok(mut resp) => { - resp.retry_count = attempt; + resp.metadata.retry_count = attempt; return Ok(resp); } Err(e) => return Err(e), @@ -362,6 +375,9 @@ impl HttpClientState { /// /// This is the inner implementation called by [`execute`](Self::execute) /// on each attempt. It assumes URL validation has already been performed. + /// + /// Returns raw body bytes and metadata. Encoding for IPC transfer + /// (binary framing or JSON with base64) happens at the command layer. async fn execute_once( &self, url: &url::Url, @@ -369,7 +385,7 @@ impl HttpClientState { headers: &Option>, body: Option<&[u8]>, timeout: Option, - ) -> Result { + ) -> Result { let mut builder = self.client.request(method.clone(), url.clone()); for (key, value) in &self.config.default_headers { @@ -425,15 +441,6 @@ impl HttpClientState { // Collect response headers (multi-value support) let mut response_headers: HashMap> = HashMap::new(); - // Extract content type and Retry-After before consuming the response - // via bytes_stream(), since we need them later. - let content_type = response - .headers() - .get(reqwest::header::CONTENT_TYPE) - .and_then(|v| v.to_str().ok()) - .unwrap_or("") - .to_string(); - for (name, value) in response.headers() { let name = name.as_str().to_string(); @@ -447,28 +454,16 @@ impl HttpClientState { let body_bytes = self.read_body_with_limit(response).await?; - let is_text = is_text_content_type(&content_type); - let (body, body_encoding) = if is_text { - ( - String::from_utf8_lossy(&body_bytes).into_owned(), - "utf8".to_string(), - ) - } else { - ( - base64::engine::general_purpose::STANDARD.encode(&body_bytes), - "base64".to_string(), - ) - }; - - Ok(FetchResponse { - status: status.as_u16(), - status_text, - headers: response_headers, - body, - body_encoding, - url: final_url.to_string(), - redirected, - retry_count: 0, // Set by execute() after the loop + Ok(ExecuteResult { + metadata: FetchResponseMetadata { + status: status.as_u16(), + status_text, + headers: response_headers, + url: final_url.to_string(), + redirected, + retry_count: 0, // Set by execute() after the loop + }, + body: body_bytes, }) } @@ -553,35 +548,15 @@ fn parse_method(method: &str) -> Result { } } -fn decode_request_body(body: &str, encoding: Option<&str>) -> Result> { - match encoding.unwrap_or("utf8") { - "base64" => base64::engine::general_purpose::STANDARD +fn decode_request_body(body: &str, encoding: Option<&BodyEncoding>) -> Result> { + match encoding.unwrap_or(&BodyEncoding::Utf8) { + BodyEncoding::Utf8 => Ok(body.as_bytes().to_vec()), + BodyEncoding::Base64 => base64::engine::general_purpose::STANDARD .decode(body) .map_err(|e| Error::Other(format!("invalid base64 body: {e}"))), - _ => Ok(body.as_bytes().to_vec()), } } -/// Determines if a Content-Type value represents text content. -/// -/// Uses substring matching on the full content-type so that structured -/// types like `application/vnd.api+json` are correctly detected as text. -fn is_text_content_type(content_type: &str) -> bool { - let ct = content_type.to_lowercase(); - - ct.starts_with("text/") - || ct.contains("json") - || ct.contains("xml") - || ct.contains("javascript") - || ct.contains("html") - || ct.contains("css") - || ct.contains("svg") - || ct.contains("yaml") - || ct.contains("toml") - || ct.contains("csv") - || ct.contains("form-urlencoded") -} - /// Calculates the backoff duration for a retry attempt using exponential /// backoff with jitter. /// @@ -593,7 +568,7 @@ fn is_text_content_type(content_type: &str) -> bool { fn calculate_backoff( config: &RetryConfig, attempt: u32, - last_result: Option<&Result>, + last_result: Option<&Result>, ) -> Duration { // Check for Retry-After header on the last response if let Some(Ok(resp)) = last_result @@ -628,8 +603,8 @@ fn calculate_backoff( /// /// Supports both delta-seconds format (`Retry-After: 120`) and ignores /// HTTP-date format (too complex to parse without a date library). -fn parse_retry_after_from_response(resp: &FetchResponse) -> Option { - let values = resp.headers.get("retry-after")?; +fn parse_retry_after_from_response(resp: &ExecuteResult) -> Option { + let values = resp.metadata.headers.get("retry-after")?; let value = values.first()?; value.trim().parse::().ok().map(Duration::from_secs) @@ -722,21 +697,21 @@ mod tests { #[test] fn test_decode_request_body_utf8() { - let body = decode_request_body("hello", Some("utf8")).unwrap(); + let body = decode_request_body("hello", Some(&BodyEncoding::Utf8)).unwrap(); assert_eq!(body, b"hello"); } #[test] fn test_decode_request_body_base64() { - let body = decode_request_body("aGVsbG8=", Some("base64")).unwrap(); + let body = decode_request_body("aGVsbG8=", Some(&BodyEncoding::Base64)).unwrap(); assert_eq!(body, b"hello"); } #[test] fn test_decode_request_body_invalid_base64() { - assert!(decode_request_body("not valid base64!!!", Some("base64")).is_err()); + assert!(decode_request_body("not valid base64!!!", Some(&BodyEncoding::Base64)).is_err()); } #[test] @@ -746,44 +721,6 @@ mod tests { assert_eq!(body, b"hello"); } - #[test] - fn test_is_text_content_type() { - assert!(is_text_content_type("text/plain")); - assert!(is_text_content_type("text/html; charset=utf-8")); - assert!(is_text_content_type("application/json")); - assert!(is_text_content_type("application/xml")); - assert!(is_text_content_type("application/javascript")); - assert!(is_text_content_type("text/css")); - assert!(is_text_content_type("image/svg+xml")); - assert!(is_text_content_type("application/x-www-form-urlencoded")); - - assert!(!is_text_content_type("application/octet-stream")); - assert!(!is_text_content_type("image/png")); - assert!(!is_text_content_type("application/pdf")); - } - - #[test] - fn test_is_text_content_type_yaml_toml_csv() { - assert!(is_text_content_type("application/yaml")); - assert!(is_text_content_type("application/toml")); - assert!(is_text_content_type("text/csv")); - } - - #[test] - fn test_is_text_content_type_empty_string() { - assert!(!is_text_content_type("")); - } - - #[test] - fn test_yaml_content_type_is_text() { - assert!(is_text_content_type("application/x-yaml")); - } - - #[test] - fn test_content_type_with_charset_detected() { - assert!(is_text_content_type("text/plain; charset=iso-8859-1")); - } - #[test] fn test_parse_method_case_insensitive() { assert_eq!(parse_method("get").unwrap(), reqwest::Method::GET); @@ -802,13 +739,6 @@ mod tests { ); } - #[test] - fn test_decode_request_body_unknown_encoding_treated_as_utf8() { - let body = decode_request_body("hello", Some("unknown")).unwrap(); - - assert_eq!(body, b"hello"); - } - // --- InFlightRequests / abort tests --- #[tokio::test] @@ -1049,9 +979,9 @@ mod tests { let req = make_request(&localhost_url(&server, "/a")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert!(resp.redirected); - assert_eq!(resp.body, "ok"); + assert_eq!(resp.metadata.status, 200); + assert!(resp.metadata.redirected); + assert_eq!(resp.body, b"ok"); } #[tokio::test] @@ -1105,9 +1035,9 @@ mod tests { // With max_redirects=3, the 4th redirect is stopped and the 3xx is returned assert!( - resp.status >= 300 && resp.status < 400, + resp.metadata.status >= 300 && resp.metadata.status < 400, "should return redirect status when max hops exceeded, got: {}", - resp.status + resp.metadata.status ); } @@ -1133,9 +1063,9 @@ mod tests { let req = make_request(&localhost_url(&server, "/start")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert_eq!(resp.body, "final"); - assert!(resp.redirected); + assert_eq!(resp.metadata.status, 200); + assert_eq!(resp.body, b"final"); + assert!(resp.metadata.redirected); } #[tokio::test] @@ -1155,8 +1085,8 @@ mod tests { let resp = state.execute(req).await.unwrap(); // With max_redirects=0, the redirect is not followed - assert!(resp.status >= 300 && resp.status < 400); - assert!(!resp.redirected); + assert!(resp.metadata.status >= 300 && resp.metadata.status < 400); + assert!(!resp.metadata.redirected); } // --- Body size limit tests --- @@ -1652,15 +1582,16 @@ mod tests { #[test] fn test_calculate_backoff_with_retry_after_header() { let config = RetryConfig::default(); - let resp = FetchResponse { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["5".to_string()])]), - body: String::new(), - body_encoding: "utf8".to_string(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, + let resp = ExecuteResult { + metadata: FetchResponseMetadata { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["5".to_string()])]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: Vec::new(), }; let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); @@ -1674,15 +1605,16 @@ mod tests { max_retry_after: Duration::from_secs(10), ..RetryConfig::default() }; - let resp = FetchResponse { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["999".to_string()])]), - body: String::new(), - body_encoding: "utf8".to_string(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, + let resp = ExecuteResult { + metadata: FetchResponseMetadata { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["999".to_string()])]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: Vec::new(), }; let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); @@ -1692,15 +1624,16 @@ mod tests { #[test] fn test_parse_retry_after_valid_seconds() { - let resp = FetchResponse { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["120".to_string()])]), - body: String::new(), - body_encoding: "utf8".to_string(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, + let resp = ExecuteResult { + metadata: FetchResponseMetadata { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["120".to_string()])]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: Vec::new(), }; assert_eq!( @@ -1711,15 +1644,16 @@ mod tests { #[test] fn test_parse_retry_after_missing_header() { - let resp = FetchResponse { - status: 503, - status_text: "Service Unavailable".to_string(), - headers: HashMap::new(), - body: String::new(), - body_encoding: "utf8".to_string(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, + let resp = ExecuteResult { + metadata: FetchResponseMetadata { + status: 503, + status_text: "Service Unavailable".to_string(), + headers: HashMap::new(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: Vec::new(), }; assert_eq!(parse_retry_after_from_response(&resp), None); @@ -1727,18 +1661,19 @@ mod tests { #[test] fn test_parse_retry_after_non_numeric_ignored() { - let resp = FetchResponse { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([( - "retry-after".to_string(), - vec!["Wed, 21 Oct 2025 07:28:00 GMT".to_string()], - )]), - body: String::new(), - body_encoding: "utf8".to_string(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, + let resp = ExecuteResult { + metadata: FetchResponseMetadata { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([( + "retry-after".to_string(), + vec!["Wed, 21 Oct 2025 07:28:00 GMT".to_string()], + )]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: Vec::new(), }; assert_eq!(parse_retry_after_from_response(&resp), None); @@ -1746,15 +1681,16 @@ mod tests { #[test] fn test_parse_retry_after_zero_seconds() { - let resp = FetchResponse { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["0".to_string()])]), - body: String::new(), - body_encoding: "utf8".to_string(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, + let resp = ExecuteResult { + metadata: FetchResponseMetadata { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["0".to_string()])]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: Vec::new(), }; assert_eq!( @@ -1766,15 +1702,16 @@ mod tests { #[test] fn test_calculate_backoff_with_retry_after_zero() { let config = RetryConfig::default(); - let resp = FetchResponse { - status: 429, - status_text: "Too Many Requests".to_string(), - headers: HashMap::from([("retry-after".to_string(), vec!["0".to_string()])]), - body: String::new(), - body_encoding: "utf8".to_string(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, + let resp = ExecuteResult { + metadata: FetchResponseMetadata { + status: 429, + status_text: "Too Many Requests".to_string(), + headers: HashMap::from([("retry-after".to_string(), vec!["0".to_string()])]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: Vec::new(), }; let backoff = calculate_backoff(&config, 1, Some(&Ok(resp))); @@ -1829,9 +1766,9 @@ mod tests { let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert_eq!(resp.body, "ok"); - assert_eq!(resp.retry_count, 1); + assert_eq!(resp.metadata.status, 200); + assert_eq!(resp.body, b"ok"); + assert_eq!(resp.metadata.retry_count, 1); } #[tokio::test] @@ -1860,8 +1797,8 @@ mod tests { let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 503); - assert_eq!(resp.retry_count, 2); + assert_eq!(resp.metadata.status, 503); + assert_eq!(resp.metadata.retry_count, 2); } #[tokio::test] @@ -1891,8 +1828,8 @@ mod tests { let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 500); - assert_eq!(resp.retry_count, 0); + assert_eq!(resp.metadata.status, 500); + assert_eq!(resp.metadata.retry_count, 0); } #[tokio::test] @@ -1911,8 +1848,8 @@ mod tests { let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 503); - assert_eq!(resp.retry_count, 0); + assert_eq!(resp.metadata.status, 503); + assert_eq!(resp.metadata.retry_count, 0); } // --- Full execute() pipeline tests --- @@ -1939,12 +1876,11 @@ mod tests { let req = make_request(&localhost_url(&server, "/data")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert_eq!(resp.status_text, "OK"); - assert_eq!(resp.body, r#"{"hello":"world"}"#); - assert_eq!(resp.body_encoding, "utf8"); - assert!(!resp.redirected); - assert_eq!(resp.retry_count, 0); + assert_eq!(resp.metadata.status, 200); + assert_eq!(resp.metadata.status_text, "OK"); + assert_eq!(resp.body, br#"{"hello":"world"}"#); + assert!(!resp.metadata.redirected); + assert_eq!(resp.metadata.retry_count, 0); } #[tokio::test] @@ -1968,7 +1904,7 @@ mod tests { let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); + assert_eq!(resp.metadata.status, 200); } #[tokio::test] @@ -2022,7 +1958,7 @@ mod tests { let req = make_request(&localhost_url(&server, "/api")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); + assert_eq!(resp.metadata.status, 200); } #[tokio::test] @@ -2053,7 +1989,7 @@ mod tests { let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); + assert_eq!(resp.metadata.status, 200); } #[tokio::test] @@ -2074,8 +2010,7 @@ mod tests { let req = make_request(&localhost_url(&server, "/text")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.body, "plain text response"); - assert_eq!(resp.body_encoding, "utf8"); + assert_eq!(resp.body, b"plain text response"); } #[tokio::test] @@ -2097,14 +2032,8 @@ mod tests { let req = make_request(&localhost_url(&server, "/image")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.body_encoding, "base64"); - - // Decode and verify - let decoded = base64::engine::general_purpose::STANDARD - .decode(&resp.body) - .unwrap(); - - assert_eq!(decoded, binary_data); + // Body is now raw bytes (no base64 encoding at the execute layer) + assert_eq!(resp.body, binary_data); } #[tokio::test] @@ -2123,12 +2052,12 @@ mod tests { req.method = Some("POST".to_string()); req.body = Some(r#"{"key":"value"}"#.to_string()); - req.body_encoding = Some("utf8".to_string()); + req.body_encoding = Some(BodyEncoding::Utf8); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 201); - assert_eq!(resp.body, "created"); + assert_eq!(resp.metadata.status, 201); + assert_eq!(resp.body, b"created"); } #[tokio::test] @@ -2149,7 +2078,7 @@ mod tests { let req = make_request(&localhost_url(&server, "/headers")); let resp = state.execute(req).await.unwrap(); - let custom_header = resp.headers.get("x-custom-response").unwrap(); + let custom_header = resp.metadata.headers.get("x-custom-response").unwrap(); assert_eq!(custom_header, &vec!["header-value".to_string()]); } @@ -2212,9 +2141,9 @@ mod tests { let req = make_request(&localhost_url(&server, "/a")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert!(resp.redirected); - assert!(resp.url.contains("/b")); + assert_eq!(resp.metadata.status, 200); + assert!(resp.metadata.redirected); + assert!(resp.metadata.url.contains("/b")); } #[tokio::test] @@ -2231,7 +2160,7 @@ mod tests { let req = make_request(&localhost_url(&server, "/direct")); let resp = state.execute(req).await.unwrap(); - assert!(!resp.redirected); + assert!(!resp.metadata.redirected); } #[tokio::test] @@ -2250,15 +2179,15 @@ mod tests { req.method = Some("POST".to_string()); req.body = Some(String::new()); - req.body_encoding = Some("utf8".to_string()); + req.body_encoding = Some(BodyEncoding::Utf8); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); + assert_eq!(resp.metadata.status, 200); } #[tokio::test] - async fn test_execute_from_utf8_lossy_on_binary_with_text_content_type() { + async fn test_execute_returns_raw_bytes_for_invalid_utf8() { let server = MockServer::start().await; // Send binary data (invalid UTF-8) with text/plain content type @@ -2268,7 +2197,7 @@ mod tests { .and(path("/lossy")) .respond_with( ResponseTemplate::new(200) - .set_body_bytes(binary_body) + .set_body_bytes(binary_body.clone()) .insert_header("Content-Type", "text/plain"), ) .mount(&server) @@ -2278,13 +2207,8 @@ mod tests { let req = make_request(&localhost_url(&server, "/lossy")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.body_encoding, "utf8"); - // The invalid byte 0xFF should be replaced with U+FFFD - assert!( - resp.body.contains('\u{FFFD}'), - "expected replacement character in lossy UTF-8 conversion, got: {:?}", - resp.body - ); + // Body is raw bytes — no lossy UTF-8 conversion at the execute layer + assert_eq!(resp.body, binary_body); } #[tokio::test] @@ -2326,9 +2250,9 @@ mod tests { let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert_eq!(resp.body, "fast"); - assert_eq!(resp.retry_count, 1); + assert_eq!(resp.metadata.status, 200); + assert_eq!(resp.body, b"fast"); + assert_eq!(resp.metadata.retry_count, 1); } #[tokio::test] @@ -2364,8 +2288,8 @@ mod tests { let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert_eq!(resp.retry_count, 1); + assert_eq!(resp.metadata.status, 200); + assert_eq!(resp.metadata.retry_count, 1); } #[tokio::test] @@ -2417,8 +2341,8 @@ mod tests { let req = make_request(&localhost_url(&server, "/api")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert_eq!(resp.retry_count, 1); + assert_eq!(resp.metadata.status, 200); + assert_eq!(resp.metadata.retry_count, 1); } // --- Redirect to IP address (integration) --- @@ -2483,11 +2407,11 @@ mod tests { // Should get the 3xx response (stop behavior), not an error assert!( - resp.status >= 300 && resp.status < 400, + resp.metadata.status >= 300 && resp.metadata.status < 400, "expected 3xx status from stop(), got {}", - resp.status + resp.metadata.status ); - assert!(resp.redirected); + assert!(resp.metadata.redirected); } // --- Retry-After end-to-end --- @@ -2527,8 +2451,8 @@ mod tests { let req = make_request(&localhost_url(&server, "/rate-limited")); let resp = state.execute(req).await.unwrap(); - assert_eq!(resp.status, 200); - assert_eq!(resp.retry_count, 1); + assert_eq!(resp.metadata.status, 200); + assert_eq!(resp.metadata.retry_count, 1); } // --- Per-request max_retries override --- @@ -2570,8 +2494,8 @@ mod tests { // With max_retries=1, we get 2 attempts: first returns 500, second returns 500. // Since retries are exhausted, we get the last 500 response. - assert_eq!(resp.status, 500); - assert_eq!(resp.retry_count, 1); + assert_eq!(resp.metadata.status, 500); + assert_eq!(resp.metadata.retry_count, 1); } // --- Security errors skip retry loop --- diff --git a/src/commands.rs b/src/commands.rs index ec0c406..8a4eb0d 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -2,18 +2,23 @@ use tauri::{AppHandle, Runtime, State}; use crate::client::{HttpClientState, InFlightRequests}; use crate::error::Result; -use crate::types::{FetchRequest, FetchResponse}; +use crate::types::{ExecuteResult, FetchRequest}; /// Executes an HTTP request through the plugin's security and execution pipeline. /// /// This is the primary IPC command invoked by the TypeScript guest. +/// +/// Returns `tauri::ipc::Response` using binary framing: +/// `[4-byte BE metadata length][metadata JSON][body bytes]` sent as +/// `InvokeResponseBody::Raw`, delivered to the TypeScript guest as an +/// `ArrayBuffer`. #[tauri::command] pub(crate) async fn fetch( _app: AppHandle, state: State<'_, HttpClientState>, in_flight: State<'_, InFlightRequests>, request: FetchRequest, -) -> Result { +) -> Result { let request_id = request.request_id.clone(); if let Some(ref id) = request_id { @@ -46,9 +51,9 @@ pub(crate) async fn fetch( // Ensure cleanup even if spawn's internal remove raced with register in_flight.remove(id).await; - result + result.and_then(pack_ipc_response) } else { - state.execute(request).await + state.execute(request).await.and_then(pack_ipc_response) } } @@ -64,3 +69,263 @@ pub(crate) async fn abort_request( ) -> Result { Ok(HttpClientState::abort(&in_flight, &request_id).await) } + +/// Packs an `ExecuteResult` into a binary-framed IPC response. +/// +/// Frame format: `[4-byte BE metadata length][metadata JSON][body bytes]` +/// sent as `InvokeResponseBody::Raw`. +fn pack_ipc_response(result: ExecuteResult) -> Result { + let metadata_json = serde_json::to_vec(&result.metadata) + .map_err(|e| crate::error::Error::Other(format!("metadata serialization failed: {e}")))?; + + let meta_len = u32::try_from(metadata_json.len()) + .map_err(|_| crate::error::Error::Other("metadata too large for IPC frame".into()))?; + let mut buf = Vec::with_capacity(4 + metadata_json.len() + result.body.len()); + + buf.extend_from_slice(&meta_len.to_be_bytes()); + buf.extend_from_slice(&metadata_json); + buf.extend_from_slice(&result.body); + + Ok(tauri::ipc::Response::new(buf)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::FetchResponseMetadata; + use std::collections::HashMap; + + #[test] + fn test_pack_ipc_response_binary_framing_structure() { + let result = ExecuteResult { + metadata: FetchResponseMetadata { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::from([( + "content-type".to_string(), + vec!["application/json".to_string()], + )]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: b"hello world".to_vec(), + }; + + let resp = pack_ipc_response(result).unwrap(); + let buf = extract_raw_bytes(resp); + + // First 4 bytes are big-endian metadata length + let meta_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + + assert!(meta_len > 0, "metadata length should be non-zero"); + assert_eq!(buf.len(), 4 + meta_len + b"hello world".len()); + + // Metadata should be valid JSON + let meta_json: serde_json::Value = serde_json::from_slice(&buf[4..4 + meta_len]).unwrap(); + + assert_eq!(meta_json["status"], 200); + assert_eq!(meta_json["statusText"], "OK"); + assert_eq!(meta_json["url"], "https://example.com"); + assert_eq!(meta_json["redirected"], false); + assert_eq!(meta_json["retryCount"], 0); + + // Body bytes follow metadata + assert_eq!(&buf[4 + meta_len..], b"hello world"); + } + + #[test] + fn test_pack_ipc_response_empty_body() { + let result = ExecuteResult { + metadata: FetchResponseMetadata { + status: 204, + status_text: "No Content".to_string(), + headers: HashMap::new(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: Vec::new(), + }; + + let resp = pack_ipc_response(result).unwrap(); + let buf = extract_raw_bytes(resp); + + let meta_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + + // No body bytes after metadata + assert_eq!(buf.len(), 4 + meta_len); + } + + #[test] + fn test_pack_ipc_response_binary_body_preserved() { + let binary_data: Vec = vec![0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]; + + let result = ExecuteResult { + metadata: FetchResponseMetadata { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::from([("content-type".to_string(), vec!["image/png".to_string()])]), + url: "https://example.com/image.png".to_string(), + redirected: false, + retry_count: 0, + }, + body: binary_data.clone(), + }; + + let resp = pack_ipc_response(result).unwrap(); + let buf = extract_raw_bytes(resp); + + let meta_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + + // Binary body is preserved exactly + assert_eq!(&buf[4 + meta_len..], &binary_data); + } + + #[test] + fn test_pack_ipc_response_retry_count_in_metadata() { + let result = ExecuteResult { + metadata: FetchResponseMetadata { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::new(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 3, + }, + body: b"ok".to_vec(), + }; + + let resp = pack_ipc_response(result).unwrap(); + let buf = extract_raw_bytes(resp); + + let meta_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + let meta_json: serde_json::Value = serde_json::from_slice(&buf[4..4 + meta_len]).unwrap(); + + assert_eq!(meta_json["retryCount"], 3); + } + + #[test] + fn test_pack_ipc_response_redirected_flag_in_metadata() { + let result = ExecuteResult { + metadata: FetchResponseMetadata { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::new(), + url: "https://example.com/final".to_string(), + redirected: true, + retry_count: 0, + }, + body: b"redirected".to_vec(), + }; + + let resp = pack_ipc_response(result).unwrap(); + let buf = extract_raw_bytes(resp); + + let meta_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + let meta_json: serde_json::Value = serde_json::from_slice(&buf[4..4 + meta_len]).unwrap(); + + assert_eq!(meta_json["redirected"], true); + } + + #[test] + fn test_pack_ipc_response_headers_in_metadata() { + let result = ExecuteResult { + metadata: FetchResponseMetadata { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::from([ + ( + "content-type".to_string(), + vec!["application/json".to_string()], + ), + ("x-request-id".to_string(), vec!["abc-123".to_string()]), + ]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: b"{}".to_vec(), + }; + + let resp = pack_ipc_response(result).unwrap(); + let buf = extract_raw_bytes(resp); + + let meta_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + let meta_json: serde_json::Value = serde_json::from_slice(&buf[4..4 + meta_len]).unwrap(); + + assert!(meta_json["headers"]["content-type"].is_array()); + assert_eq!(meta_json["headers"]["content-type"][0], "application/json"); + assert_eq!(meta_json["headers"]["x-request-id"][0], "abc-123"); + } + + #[test] + fn test_pack_ipc_response_large_body() { + // Verify framing works with bodies larger than u8::MAX + let large_body = vec![0xABu8; 1024]; + + let result = ExecuteResult { + metadata: FetchResponseMetadata { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::new(), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }, + body: large_body.clone(), + }; + + let resp = pack_ipc_response(result).unwrap(); + let buf = extract_raw_bytes(resp); + + let meta_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + + assert_eq!(buf.len(), 4 + meta_len + 1024); + assert_eq!(&buf[4 + meta_len..], &large_body); + } + + #[test] + fn test_pack_ipc_response_metadata_camel_case() { + let result = ExecuteResult { + metadata: FetchResponseMetadata { + status: 404, + status_text: "Not Found".to_string(), + headers: HashMap::new(), + url: "https://example.com/missing".to_string(), + redirected: false, + retry_count: 1, + }, + body: b"not found".to_vec(), + }; + + let resp = pack_ipc_response(result).unwrap(); + let buf = extract_raw_bytes(resp); + + let meta_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + let meta_json: serde_json::Value = serde_json::from_slice(&buf[4..4 + meta_len]).unwrap(); + + // Verify camelCase field names + assert!(meta_json.get("statusText").is_some()); + assert!(meta_json.get("retryCount").is_some()); + // Verify no snake_case field names + assert!(meta_json.get("status_text").is_none()); + assert!(meta_json.get("retry_count").is_none()); + } + + /// Extracts the raw bytes from a `tauri::ipc::Response` for test assertions. + /// + /// Uses the `IpcResponse` trait to get the `InvokeResponseBody`, then + /// matches on the `Raw` variant. Panics if the response is JSON (which + /// would indicate a bug in the desktop `pack_ipc_response`). + fn extract_raw_bytes(resp: tauri::ipc::Response) -> Vec { + use tauri::ipc::IpcResponse; + + match resp.body().expect("response body should be Ok") { + tauri::ipc::InvokeResponseBody::Raw(bytes) => bytes, + tauri::ipc::InvokeResponseBody::Json(json) => { + panic!("expected Raw bytes, got Json: {json}") + } + } + } +} diff --git a/src/types.rs b/src/types.rs index 5f7dea0..86ca774 100644 --- a/src/types.rs +++ b/src/types.rs @@ -2,6 +2,19 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; +/// Encoding used to transport the request body over IPC. +/// +/// The TypeScript guest constrains this to `'utf8' | 'base64'`; the Rust +/// enum mirrors that constraint so serde rejects unknown values at +/// deserialization time. +#[derive(Debug, PartialEq, Deserialize)] +pub enum BodyEncoding { + #[serde(rename = "utf8")] + Utf8, + #[serde(rename = "base64")] + Base64, +} + /// Request payload sent from the TypeScript guest to the Rust backend via IPC. /// /// All URL parsing and validation happens exclusively in Rust to avoid @@ -13,7 +26,7 @@ pub struct FetchRequest { pub method: Option, pub headers: Option>, pub body: Option, - pub body_encoding: Option, + pub body_encoding: Option, pub timeout_ms: Option, pub request_id: Option, /// Per-request retry override. `None` uses plugin config default. @@ -22,19 +35,27 @@ pub struct FetchRequest { pub max_retries: Option, } -/// Response payload sent from the Rust backend to the TypeScript guest via IPC. +/// HTTP response metadata without the body, serialized as JSON in the binary +/// framing protocol (`[4-byte BE length][metadata JSON][body bytes]`). #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] -pub struct FetchResponse { - pub status: u16, - pub status_text: String, - pub headers: HashMap>, - pub body: String, - pub body_encoding: String, - pub url: String, - pub redirected: bool, +pub(crate) struct FetchResponseMetadata { + pub(crate) status: u16, + pub(crate) status_text: String, + pub(crate) headers: HashMap>, + pub(crate) url: String, + pub(crate) redirected: bool, /// Number of retry attempts that occurred before this response (0 = no retries). - pub retry_count: u32, + pub(crate) retry_count: u32, +} + +/// Internal result from the HTTP execution pipeline, carrying raw body bytes +/// and response metadata. Converted to a binary-framed IPC response at the +/// command layer. +#[derive(Debug)] +pub(crate) struct ExecuteResult { + pub metadata: FetchResponseMetadata, + pub body: Vec, } #[cfg(test)] @@ -62,7 +83,7 @@ mod tests { "application/json" ); assert_eq!(req.body.as_deref(), Some("hello")); - assert_eq!(req.body_encoding.as_deref(), Some("utf8")); + assert_eq!(req.body_encoding, Some(BodyEncoding::Utf8)); assert_eq!(req.timeout_ms, Some(5000)); assert_eq!(req.request_id.as_deref(), Some("req-1")); assert!(req.max_retries.is_none()); @@ -83,31 +104,6 @@ mod tests { assert!(req.max_retries.is_none()); } - #[test] - fn test_fetch_response_serializes_camel_case() { - let resp = FetchResponse { - status: 200, - status_text: "OK".to_string(), - headers: HashMap::from([("content-type".to_string(), vec!["text/html".to_string()])]), - body: "hello".to_string(), - body_encoding: "utf8".to_string(), - url: "https://example.com".to_string(), - redirected: false, - retry_count: 0, - }; - - let json = serde_json::to_value(&resp).unwrap(); - - assert_eq!(json["status"], 200); - assert_eq!(json["statusText"], "OK"); - assert_eq!(json["body"], "hello"); - assert_eq!(json["bodyEncoding"], "utf8"); - assert_eq!(json["url"], "https://example.com"); - assert_eq!(json["redirected"], false); - assert_eq!(json["retryCount"], 0); - assert!(json["headers"]["content-type"].is_array()); - } - #[test] fn test_fetch_request_with_max_retries() { let json = serde_json::json!({ @@ -135,20 +131,55 @@ mod tests { } #[test] - fn test_fetch_response_retry_count_serializes() { - let resp = FetchResponse { + fn test_fetch_response_metadata_serializes_camel_case() { + let meta = FetchResponseMetadata { + status: 200, + status_text: "OK".to_string(), + headers: HashMap::from([("content-type".to_string(), vec!["text/html".to_string()])]), + url: "https://example.com".to_string(), + redirected: false, + retry_count: 0, + }; + + let json = serde_json::to_value(&meta).unwrap(); + + assert_eq!(json["status"], 200); + assert_eq!(json["statusText"], "OK"); + assert_eq!(json["url"], "https://example.com"); + assert_eq!(json["redirected"], false); + assert_eq!(json["retryCount"], 0); + assert!(json["headers"]["content-type"].is_array()); + // Metadata has no body or bodyEncoding fields + assert!(json.get("body").is_none()); + assert!(json.get("bodyEncoding").is_none()); + } + + #[test] + fn test_fetch_response_metadata_retry_count_serializes() { + let meta = FetchResponseMetadata { status: 200, status_text: "OK".to_string(), headers: HashMap::new(), - body: "ok".to_string(), - body_encoding: "utf8".to_string(), url: "https://example.com".to_string(), redirected: false, - retry_count: 2, + retry_count: 3, }; - let json = serde_json::to_value(&resp).unwrap(); + let json = serde_json::to_value(&meta).unwrap(); + + assert_eq!(json["retryCount"], 3); + } - assert_eq!(json["retryCount"], 2); + #[test] + fn test_fetch_request_invalid_body_encoding_fails_deserialization() { + let json = serde_json::json!({ + "url": "https://example.com", + "body": "hello", + "bodyEncoding": "gzip" + }); + + let result = serde_json::from_value::(json); + + assert!(result.is_err()); } }