diff --git a/.changelog/pr-182-admin-access-control.md b/.changelog/pr-182-admin-access-control.md new file mode 100644 index 00000000..50b30236 --- /dev/null +++ b/.changelog/pr-182-admin-access-control.md @@ -0,0 +1,5 @@ +--- +tidx: patch +--- + +Hardened view administration by failing closed for trusted CIDR checks, rejecting malformed CIDR configuration, hot-reloading active trusted CIDRs, and requiring an explicit admin mutation header. diff --git a/src/api/mod.rs b/src/api/mod.rs index 4c1dc552..d1243f0a 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -3,10 +3,11 @@ mod views; use std::collections::HashMap; use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; +use std::sync::{Arc, RwLock as StdRwLock}; use tokio::sync::RwLock; +use anyhow::{Result as AnyhowResult, anyhow}; use axum::{ Json, Router, extract::{Query, State}, @@ -41,6 +42,7 @@ pub struct ChainClickHouseConfig { } pub type SharedClickHouseConfigs = Arc>>; +pub type SharedTrustedCidrs = Arc>>; #[derive(Clone)] pub struct AppState { @@ -54,7 +56,7 @@ pub struct AppState { /// ClickHouse engines for OLAP queries (per chain) pub clickhouse_engines: SharedClickHouseEngines, /// Parsed trusted CIDRs for admin operations - pub trusted_cidrs: Arc>, + pub trusted_cidrs: SharedTrustedCidrs, } impl AppState { @@ -70,28 +72,41 @@ impl AppState { /// Check if an IP address is in the trusted CIDRs pub fn is_trusted_ip(&self, addr: &SocketAddr) -> bool { - if self.trusted_cidrs.is_empty() { - return true; - } let ip = addr.ip(); self.trusted_cidrs - .iter() - .any(|(network, prefix)| ip_in_cidr(&ip, network, *prefix)) + .read() + .map(|cidrs| { + cidrs + .iter() + .any(|(network, prefix)| ip_in_cidr(&ip, network, *prefix)) + }) + .unwrap_or(false) } } /// Parse CIDR strings into (network, prefix_len) tuples -pub fn parse_cidrs(cidrs: &[String]) -> Vec<(IpAddr, u8)> { +pub fn parse_cidrs(cidrs: &[String]) -> AnyhowResult> { cidrs .iter() - .filter_map(|cidr| { - let parts: Vec<&str> = cidr.split('/').collect(); - if parts.len() != 2 { - return None; + .map(|cidr| { + let (ip, prefix) = cidr + .split_once('/') + .ok_or_else(|| anyhow!("Invalid CIDR '{cidr}': missing prefix"))?; + let ip: IpAddr = ip + .parse() + .map_err(|e| anyhow!("Invalid CIDR '{cidr}': invalid IP address: {e}"))?; + let prefix: u8 = prefix + .parse() + .map_err(|e| anyhow!("Invalid CIDR '{cidr}': invalid prefix: {e}"))?; + match ip { + IpAddr::V4(_) if prefix > 32 => { + Err(anyhow!("Invalid CIDR '{cidr}': IPv4 prefix exceeds 32")) + } + IpAddr::V6(_) if prefix > 128 => { + Err(anyhow!("Invalid CIDR '{cidr}': IPv6 prefix exceeds 128")) + } + _ => Ok((ip, prefix)), } - let ip: IpAddr = parts[0].parse().ok()?; - let prefix: u8 = parts[1].parse().ok()?; - Some((ip, prefix)) }) .collect() } @@ -131,7 +146,7 @@ pub fn router( pools: HashMap, default_chain_id: u64, broadcaster: Arc, -) -> Router<()> { +) -> AnyhowResult> { router_with_options( pools, default_chain_id, @@ -147,8 +162,8 @@ pub fn router_with_options( broadcaster: Arc, clickhouse_configs: HashMap, http_config: &HttpConfig, -) -> Router<()> { - let trusted_cidrs = Arc::new(parse_cidrs(&http_config.trusted_cidrs)); +) -> AnyhowResult> { + let trusted_cidrs = Arc::new(StdRwLock::new(parse_cidrs(&http_config.trusted_cidrs)?)); let state = AppState { pools: Arc::new(RwLock::new(pools)), @@ -159,7 +174,7 @@ pub fn router_with_options( trusted_cidrs, }; - build_router(state) + Ok(build_router(state)) } pub fn router_shared( @@ -168,10 +183,8 @@ pub fn router_shared( broadcaster: Arc, clickhouse_configs: SharedClickHouseConfigs, clickhouse_engines: SharedClickHouseEngines, - trusted_cidrs: Vec, + trusted_cidrs: SharedTrustedCidrs, ) -> Router<()> { - let trusted_cidrs = Arc::new(parse_cidrs(&trusted_cidrs)); - let state = AppState { pools, default_chain_id, @@ -186,7 +199,7 @@ pub fn router_shared( fn build_router(state: AppState) -> Router<()> { let cors = CorsLayer::new() - .allow_methods([Method::GET, Method::POST, Method::DELETE, Method::OPTIONS]) + .allow_methods([Method::GET, Method::OPTIONS]) .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]) .allow_origin(tower_http::cors::Any); @@ -707,7 +720,7 @@ mod tests { "10.0.0.0/8".to_string(), "192.168.1.0/24".to_string(), ]; - let parsed = parse_cidrs(&cidrs); + let parsed = parse_cidrs(&cidrs).unwrap(); assert_eq!(parsed.len(), 3); assert_eq!(parsed[0], ("100.64.0.0".parse().unwrap(), 10)); assert_eq!(parsed[1], ("10.0.0.0".parse().unwrap(), 8)); @@ -721,8 +734,48 @@ mod tests { "100.64.0.0".to_string(), // Missing prefix "100.64.0.0/abc".to_string(), // Invalid prefix ]; - let parsed = parse_cidrs(&cidrs); - assert_eq!(parsed.len(), 0); + assert!(parse_cidrs(&cidrs).is_err()); + assert!(parse_cidrs(&["100.64.0.0/33".to_string()]).is_err()); + assert!(parse_cidrs(&["fd7a:115c:a1e0::/129".to_string()]).is_err()); + } + + #[test] + fn test_router_with_options_rejects_invalid_trusted_cidr() { + let http_config = HttpConfig { + trusted_cidrs: vec!["100.64.0.0/33".to_string()], + ..Default::default() + }; + + let result = router_with_options( + HashMap::new(), + 0, + Arc::new(Broadcaster::new()), + HashMap::new(), + &http_config, + ); + + assert!(result.is_err()); + } + + #[test] + fn test_trusted_ip_fails_closed_when_empty() { + let state = AppState { + pools: Arc::new(RwLock::new(HashMap::new())), + default_chain_id: 0, + broadcaster: Arc::new(Broadcaster::new()), + clickhouse_configs: Arc::new(RwLock::new(HashMap::new())), + clickhouse_engines: Arc::new(RwLock::new(HashMap::new())), + trusted_cidrs: Arc::new(std::sync::RwLock::new(Vec::new())), + }; + let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap(); + assert!(!state.is_trusted_ip(&addr)); + } + + #[test] + fn test_http_config_default_trusts_only_loopback() { + let parsed = parse_cidrs(&HttpConfig::default().trusted_cidrs).unwrap(); + assert!(parsed.contains(&("127.0.0.1".parse().unwrap(), 32))); + assert!(parsed.contains(&("::1".parse().unwrap(), 128))); } #[test] diff --git a/src/api/views.rs b/src/api/views.rs index 43559315..6bcf239c 100644 --- a/src/api/views.rs +++ b/src/api/views.rs @@ -3,6 +3,7 @@ use axum::{ Json, extract::{ConnectInfo, Path, Query, State}, + http::HeaderMap, }; use serde::{Deserialize, Serialize}; use std::net::SocketAddr; @@ -10,6 +11,32 @@ use std::net::SocketAddr; use super::{ApiError, AppState}; use crate::query::EventSignature; +const ADMIN_MUTATION_HEADER: &str = "x-tidx-admin"; + +fn require_admin_mutation( + headers: &HeaderMap, + state: &AppState, + addr: &SocketAddr, +) -> Result<(), ApiError> { + if !state.is_trusted_ip(addr) { + return Err(ApiError::Forbidden( + "Mutations only allowed from trusted IPs".to_string(), + )); + } + + if headers + .get(ADMIN_MUTATION_HEADER) + .and_then(|value| value.to_str().ok()) + != Some("1") + { + return Err(ApiError::Forbidden( + "Missing admin mutation header".to_string(), + )); + } + + Ok(()) +} + /// Validate view name (alphanumeric + underscore only) fn is_valid_view_name(name: &str) -> bool { is_valid_identifier(name) @@ -170,14 +197,10 @@ pub struct CreateViewResponse { pub async fn create_view( State(state): State, ConnectInfo(addr): ConnectInfo, + headers: HeaderMap, Json(req): Json, ) -> Result, ApiError> { - // Check trusted IP access - if !state.is_trusted_ip(&addr) { - return Err(ApiError::Forbidden( - "Mutations only allowed from trusted IPs".to_string(), - )); - } + require_admin_mutation(&headers, &state, &addr)?; // Validate view name if !is_valid_view_name(&req.name) { @@ -320,15 +343,11 @@ pub struct DeleteViewResponse { pub async fn delete_view( State(state): State, ConnectInfo(addr): ConnectInfo, + headers: HeaderMap, Path(name): Path, Query(params): Query, ) -> Result, ApiError> { - // Check trusted IP access - if !state.is_trusted_ip(&addr) { - return Err(ApiError::Forbidden( - "Mutations only allowed from trusted IPs".to_string(), - )); - } + require_admin_mutation(&headers, &state, &addr)?; // Validate view name if !is_valid_view_name(&name) { @@ -446,7 +465,12 @@ pub async fn get_view( #[cfg(test)] mod tests { use super::*; + use crate::broadcast::Broadcaster; use insta::assert_snapshot; + use std::collections::HashMap; + use std::net::IpAddr; + use std::sync::{Arc, RwLock as StdRwLock}; + use tokio::sync::RwLock; #[test] fn test_valid_view_name() { @@ -460,6 +484,35 @@ mod tests { assert!(!is_valid_view_name("my view")); // Has space } + fn test_state_with_trusted_localhost() -> AppState { + let trusted_cidrs = vec![("127.0.0.1".parse::().unwrap(), 32)]; + AppState { + pools: Arc::new(RwLock::new(HashMap::new())), + default_chain_id: 0, + broadcaster: Arc::new(Broadcaster::new()), + clickhouse_configs: Arc::new(RwLock::new(HashMap::new())), + clickhouse_engines: Arc::new(RwLock::new(HashMap::new())), + trusted_cidrs: Arc::new(StdRwLock::new(trusted_cidrs)), + } + } + + #[test] + fn test_requires_admin_mutation_header() { + let state = test_state_with_trusted_localhost(); + let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap(); + let headers = HeaderMap::new(); + assert!(require_admin_mutation(&headers, &state, &addr).is_err()); + } + + #[test] + fn test_accepts_admin_mutation_header_from_trusted_ip() { + let state = test_state_with_trusted_localhost(); + let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap(); + let mut headers = HeaderMap::new(); + headers.insert(ADMIN_MUTATION_HEADER, "1".parse().unwrap()); + assert!(require_admin_mutation(&headers, &state, &addr).is_ok()); + } + // ======================================================================== // Helper to generate full SQL from signature + user query // ======================================================================== diff --git a/src/cli/up.rs b/src/cli/up.rs index 9aa38deb..9e0952a8 100644 --- a/src/cli/up.rs +++ b/src/cli/up.rs @@ -124,7 +124,8 @@ pub async fn run(args: Args) -> Result<()> { let (chain_tx, mut chain_rx) = tokio::sync::mpsc::channel::(16); if !args.no_watch { - let watcher = ConfigWatcher::new(args.config.clone(), &config, chain_tx); + let watcher = ConfigWatcher::new(args.config.clone(), &config, chain_tx)?; + let trusted_cidrs = watcher.trusted_cidrs(); watcher.start()?; if config.http.enabled && default_chain_id != 0 { @@ -136,7 +137,7 @@ pub async fn run(args: Args) -> Result<()> { broadcaster.clone(), Arc::clone(&clickhouse_configs), Arc::clone(&clickhouse_engines), - config.http.trusted_cidrs.clone(), + trusted_cidrs, ); info!(addr = %addr, "Starting HTTP API server (hot-reload enabled)"); @@ -201,7 +202,7 @@ pub async fn run(args: Args) -> Result<()> { broadcaster.clone(), clickhouse_configs.read().await.clone(), &config.http, - ); + )?; info!(addr = %addr, "Starting HTTP API server"); diff --git a/src/config.rs b/src/config.rs index cb469d9e..3fcaffe1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -35,7 +35,7 @@ pub struct HttpConfig { pub bind: String, /// Trusted CIDRs for admin operations (e.g., `100.64.0.0/10` for Tailscale) - #[serde(default)] + #[serde(default = "default_trusted_cidrs")] pub trusted_cidrs: Vec, } @@ -45,7 +45,7 @@ impl Default for HttpConfig { enabled: true, port: 8080, bind: "0.0.0.0".to_string(), - trusted_cidrs: Vec::new(), + trusted_cidrs: default_trusted_cidrs(), } } } @@ -82,6 +82,10 @@ fn default_bind() -> String { "0.0.0.0".to_string() } +fn default_trusted_cidrs() -> Vec { + vec!["127.0.0.1/32".to_string(), "::1/128".to_string()] +} + fn default_metrics_port() -> u16 { 9090 } diff --git a/src/config/watcher.rs b/src/config/watcher.rs index cba8b02c..194253b1 100644 --- a/src/config/watcher.rs +++ b/src/config/watcher.rs @@ -8,6 +8,7 @@ use tokio::sync::{RwLock, mpsc}; use tracing::{error, info, warn}; use super::{ChainConfig, Config, HttpConfig}; +use crate::api::{SharedTrustedCidrs, parse_cidrs}; pub type SharedHttpConfig = Arc>; @@ -19,6 +20,7 @@ pub struct NewChainEvent { pub struct ConfigWatcher { config_path: PathBuf, http_config: SharedHttpConfig, + trusted_cidrs: SharedTrustedCidrs, chain_tx: mpsc::Sender, known_chain_ids: Arc>>, } @@ -28,25 +30,32 @@ impl ConfigWatcher { config_path: PathBuf, initial_config: &Config, chain_tx: mpsc::Sender, - ) -> Self { + ) -> Result { let known_chain_ids: HashSet = initial_config.chains.iter().map(|c| c.chain_id).collect(); + let trusted_cidrs = parse_cidrs(&initial_config.http.trusted_cidrs)?; - Self { + Ok(Self { config_path, http_config: Arc::new(RwLock::new(initial_config.http.clone())), + trusted_cidrs: Arc::new(std::sync::RwLock::new(trusted_cidrs)), chain_tx, known_chain_ids: Arc::new(RwLock::new(known_chain_ids)), - } + }) } pub fn http_config(&self) -> SharedHttpConfig { Arc::clone(&self.http_config) } + pub fn trusted_cidrs(&self) -> SharedTrustedCidrs { + Arc::clone(&self.trusted_cidrs) + } + pub fn start(self) -> Result<()> { let config_path = self.config_path.clone(); let http_config = self.http_config.clone(); + let trusted_cidrs = self.trusted_cidrs.clone(); let chain_tx = self.chain_tx.clone(); let known_chain_ids = self.known_chain_ids.clone(); @@ -87,6 +96,7 @@ impl ConfigWatcher { if let Err(e) = reload_config( &config_path, &http_config, + &trusted_cidrs, &chain_tx, &known_chain_ids, ).await { @@ -105,16 +115,25 @@ impl ConfigWatcher { async fn reload_config( config_path: &PathBuf, http_config: &SharedHttpConfig, + trusted_cidrs: &SharedTrustedCidrs, chain_tx: &mpsc::Sender, known_chain_ids: &Arc>>, ) -> Result<()> { let new_config = Config::load(config_path)?; + let new_trusted_cidrs = parse_cidrs(&new_config.http.trusted_cidrs)?; { let mut http = http_config.write().await; *http = new_config.http.clone(); } + { + let mut cidrs = trusted_cidrs + .write() + .map_err(|_| anyhow::anyhow!("trusted CIDR lock poisoned"))?; + *cidrs = new_trusted_cidrs; + } + let mut known = known_chain_ids.write().await; for chain in &new_config.chains { if !known.contains(&chain.chain_id) { @@ -135,3 +154,26 @@ async fn reload_config( info!("Config reloaded"); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::PrometheusConfig; + + #[test] + fn test_new_rejects_invalid_trusted_cidr() { + let config = Config { + http: HttpConfig { + trusted_cidrs: vec!["100.64.0.0/33".to_string()], + ..Default::default() + }, + prometheus: PrometheusConfig::default(), + chains: vec![], + }; + let (chain_tx, _chain_rx) = mpsc::channel(1); + + let result = ConfigWatcher::new(PathBuf::from("config.toml"), &config, chain_tx); + + assert!(result.is_err()); + } +} diff --git a/tests/api_live_test.rs b/tests/api_live_test.rs index 32308589..8be8b21e 100644 --- a/tests/api_live_test.rs +++ b/tests/api_live_test.rs @@ -31,6 +31,7 @@ async fn make_test_service( { let mut svc: IntoMakeServiceWithConnectInfo = api::router(pools, chain_id, broadcaster) + .unwrap() .into_make_service_with_connect_info::(); svc.call(SocketAddr::from(([127, 0, 0, 1], 0))) .await diff --git a/tests/status_test.rs b/tests/status_test.rs index 8f70f47b..b1c6bf63 100644 --- a/tests/status_test.rs +++ b/tests/status_test.rs @@ -33,6 +33,7 @@ async fn make_test_service( { let mut svc: IntoMakeServiceWithConnectInfo = api::router(pools, chain_id, broadcaster) + .unwrap() .into_make_service_with_connect_info::(); svc.call(SocketAddr::from(([127, 0, 0, 1], 0))) .await @@ -193,6 +194,7 @@ async fn test_cli_proxy_via_http_server() { metrics::update_sink_watermark("postgres", "txs", 1_000_000); let router = api::router(pools, chain_id, broadcaster) + .unwrap() .into_make_service_with_connect_info::(); // Bind to a random available port