diff --git a/host/src/conversion/mod.rs b/host/src/conversion/mod.rs index 936819f..e1bbfae 100644 --- a/host/src/conversion/mod.rs +++ b/host/src/conversion/mod.rs @@ -1,4 +1,4 @@ -//! Conversion routes from/to WIT types. +//! Conversion routes from/to [WIT types](crate::bindings). use std::collections::HashMap; use arrow::{ @@ -15,7 +15,7 @@ use crate::{ error::DataFusionResultExt, }; -pub mod limits; +pub(crate) mod limits; impl CheckedFrom for DataFusionError { fn checked_from( diff --git a/host/src/http.rs b/host/src/http.rs index 747a31d..4f6c360 100644 --- a/host/src/http.rs +++ b/host/src/http.rs @@ -2,12 +2,13 @@ use std::{borrow::Cow, collections::HashSet, fmt}; -pub use http::Method; +pub use http::Method as HttpMethod; use wasmtime_wasi_http::body::HyperOutgoingBody; /// Validates if an outgoing HTTP interaction is allowed. /// -/// You can implement your own business logic here or use one of the pre-built implementations in [this module](self). +/// You can implement your own business logic here or use one of the pre-built implementations, e.g. +/// [`RejectAllHttpRequests`] or [`AllowCertainHttpRequests`]. pub trait HttpRequestValidator: fmt::Debug + Send + Sync + 'static { /// Validate incoming request. /// @@ -16,7 +17,7 @@ pub trait HttpRequestValidator: fmt::Debug + Send + Sync + 'static { &self, request: &hyper::Request, use_tls: bool, - ) -> Result<(), Rejected>; + ) -> Result<(), HttpRequestRejected>; } /// Reject ALL requests. @@ -28,16 +29,16 @@ impl HttpRequestValidator for RejectAllHttpRequests { &self, _request: &hyper::Request, _use_tls: bool, - ) -> Result<(), Rejected> { - Err(Rejected) + ) -> Result<(), HttpRequestRejected> { + Err(HttpRequestRejected) } } /// A request matcher. #[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct Matcher { +pub struct HttpRequestMatcher { /// Method. - pub method: Method, + pub method: HttpMethod, /// Host. /// @@ -56,7 +57,7 @@ pub struct AllowCertainHttpRequests { /// Set of all matchers. /// /// If ANY of them matches, the request will be allowed. - matchers: HashSet, + matchers: HashSet, } impl AllowCertainHttpRequests { @@ -66,7 +67,7 @@ impl AllowCertainHttpRequests { } /// Allow given request. - pub fn allow(&mut self, matcher: Matcher) { + pub fn allow(&mut self, matcher: HttpRequestMatcher) { self.matchers.insert(matcher); } } @@ -76,10 +77,15 @@ impl HttpRequestValidator for AllowCertainHttpRequests { &self, request: &hyper::Request, use_tls: bool, - ) -> Result<(), Rejected> { - let matcher = Matcher { + ) -> Result<(), HttpRequestRejected> { + let matcher = HttpRequestMatcher { method: request.method().clone(), - host: request.uri().host().ok_or(Rejected)?.to_owned().into(), + host: request + .uri() + .host() + .ok_or(HttpRequestRejected)? + .to_owned() + .into(), port: request .uri() .port_u16() @@ -89,22 +95,22 @@ impl HttpRequestValidator for AllowCertainHttpRequests { if self.matchers.contains(&matcher) { Ok(()) } else { - Err(Rejected) + Err(HttpRequestRejected) } } } /// Reject HTTP request. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub struct Rejected; +pub struct HttpRequestRejected; -impl fmt::Display for Rejected { +impl fmt::Display for HttpRequestRejected { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("rejected") } } -impl std::error::Error for Rejected {} +impl std::error::Error for HttpRequestRejected {} #[cfg(test)] mod test { @@ -121,108 +127,108 @@ mod test { #[test] fn allow_certain() { let request_no_port = hyper::Request::builder() - .method(Method::GET) + .method(HttpMethod::GET) .uri("http://foo.bar") .body(Default::default()) .unwrap(); let request_with_port = hyper::Request::builder() - .method(Method::GET) + .method(HttpMethod::GET) .uri("http://my.universe:1337") .body(Default::default()) .unwrap(); struct Case { - matchers: Vec, - result_no_port_no_tls: Result<(), Rejected>, - result_no_port_with_tls: Result<(), Rejected>, - result_with_port_no_tls: Result<(), Rejected>, - result_with_port_with_tls: Result<(), Rejected>, + matchers: Vec, + result_no_port_no_tls: Result<(), HttpRequestRejected>, + result_no_port_with_tls: Result<(), HttpRequestRejected>, + result_with_port_no_tls: Result<(), HttpRequestRejected>, + result_with_port_with_tls: Result<(), HttpRequestRejected>, } let cases = [ Case { matchers: vec![], - result_no_port_no_tls: Err(Rejected), - result_no_port_with_tls: Err(Rejected), - result_with_port_no_tls: Err(Rejected), - result_with_port_with_tls: Err(Rejected), + result_no_port_no_tls: Err(HttpRequestRejected), + result_no_port_with_tls: Err(HttpRequestRejected), + result_with_port_no_tls: Err(HttpRequestRejected), + result_with_port_with_tls: Err(HttpRequestRejected), }, Case { - matchers: vec![Matcher { - method: Method::GET, + matchers: vec![HttpRequestMatcher { + method: HttpMethod::GET, host: "foo.bar".into(), port: 80, }], result_no_port_no_tls: Ok(()), - result_no_port_with_tls: Err(Rejected), - result_with_port_no_tls: Err(Rejected), - result_with_port_with_tls: Err(Rejected), + result_no_port_with_tls: Err(HttpRequestRejected), + result_with_port_no_tls: Err(HttpRequestRejected), + result_with_port_with_tls: Err(HttpRequestRejected), }, Case { - matchers: vec![Matcher { - method: Method::GET, + matchers: vec![HttpRequestMatcher { + method: HttpMethod::GET, host: "foo.bar".into(), port: 443, }], - result_no_port_no_tls: Err(Rejected), + result_no_port_no_tls: Err(HttpRequestRejected), result_no_port_with_tls: Ok(()), - result_with_port_no_tls: Err(Rejected), - result_with_port_with_tls: Err(Rejected), + result_with_port_no_tls: Err(HttpRequestRejected), + result_with_port_with_tls: Err(HttpRequestRejected), }, Case { - matchers: vec![Matcher { - method: Method::POST, + matchers: vec![HttpRequestMatcher { + method: HttpMethod::POST, host: "foo.bar".into(), port: 80, }], - result_no_port_no_tls: Err(Rejected), - result_no_port_with_tls: Err(Rejected), - result_with_port_no_tls: Err(Rejected), - result_with_port_with_tls: Err(Rejected), + result_no_port_no_tls: Err(HttpRequestRejected), + result_no_port_with_tls: Err(HttpRequestRejected), + result_with_port_no_tls: Err(HttpRequestRejected), + result_with_port_with_tls: Err(HttpRequestRejected), }, Case { - matchers: vec![Matcher { - method: Method::GET, + matchers: vec![HttpRequestMatcher { + method: HttpMethod::GET, host: "my.universe".into(), port: 80, }], - result_no_port_no_tls: Err(Rejected), - result_no_port_with_tls: Err(Rejected), - result_with_port_no_tls: Err(Rejected), - result_with_port_with_tls: Err(Rejected), + result_no_port_no_tls: Err(HttpRequestRejected), + result_no_port_with_tls: Err(HttpRequestRejected), + result_with_port_no_tls: Err(HttpRequestRejected), + result_with_port_with_tls: Err(HttpRequestRejected), }, Case { - matchers: vec![Matcher { - method: Method::GET, + matchers: vec![HttpRequestMatcher { + method: HttpMethod::GET, host: "my.universe".into(), port: 1337, }], - result_no_port_no_tls: Err(Rejected), - result_no_port_with_tls: Err(Rejected), + result_no_port_no_tls: Err(HttpRequestRejected), + result_no_port_with_tls: Err(HttpRequestRejected), result_with_port_no_tls: Ok(()), result_with_port_with_tls: Ok(()), }, Case { matchers: vec![ - Matcher { - method: Method::GET, + HttpRequestMatcher { + method: HttpMethod::GET, host: "foo.bar".into(), port: 80, }, - Matcher { - method: Method::POST, + HttpRequestMatcher { + method: HttpMethod::POST, host: "foo.bar".into(), port: 80, }, - Matcher { - method: Method::GET, + HttpRequestMatcher { + method: HttpMethod::GET, host: "my.universe".into(), port: 1337, }, ], result_no_port_no_tls: Ok(()), - result_no_port_with_tls: Err(Rejected), + result_no_port_with_tls: Err(HttpRequestRejected), result_with_port_no_tls: Ok(()), result_with_port_with_tls: Ok(()), }, diff --git a/host/src/lib.rs b/host/src/lib.rs index 92836d6..5a23c5e 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -40,14 +40,23 @@ use wasmtime_wasi_http::{ use crate::{ bindings::exports::datafusion_udf_wasm::udf::types as wit_types, - conversion::limits::{CheckedInto, ComplexityToken, TrustedDataLimits}, + conversion::limits::{CheckedInto, ComplexityToken}, error::{DataFusionResultExt, WasmToDataFusionResultExt, WitDataFusionResultExt}, - http::{HttpRequestValidator, RejectAllHttpRequests}, ignore_debug::IgnoreDebug, - limiter::{Limiter, StaticResourceLimits}, + limiter::Limiter, linker::link, tokio_helpers::async_in_sync_context, - vfs::{VfsCtxView, VfsLimits, VfsState, VfsView}, + vfs::{VfsCtxView, VfsState, VfsView}, +}; + +pub use crate::{ + conversion::limits::TrustedDataLimits, + http::{ + AllowCertainHttpRequests, HttpMethod, HttpRequestMatcher, HttpRequestRejected, + HttpRequestValidator, RejectAllHttpRequests, + }, + limiter::StaticResourceLimits, + vfs::limits::VfsLimits, }; // unused-crate-dependencies false positives @@ -59,14 +68,14 @@ use regex as _; use wiremock as _; mod bindings; -pub mod conversion; -pub mod error; -pub mod http; +mod conversion; +mod error; +mod http; mod ignore_debug; -pub mod limiter; +mod limiter; mod linker; mod tokio_helpers; -pub mod vfs; +mod vfs; /// State of the WASM payload. #[derive(Debug)] diff --git a/host/src/vfs/mod.rs b/host/src/vfs/mod.rs index b620dda..a0f7527 100644 --- a/host/src/vfs/mod.rs +++ b/host/src/vfs/mod.rs @@ -37,14 +37,16 @@ use wasmtime_wasi::{ }, }; -pub use crate::vfs::limits::VfsLimits; use crate::{ error::LimitExceeded, limiter::Limiter, - vfs::path::{PathSegment, PathTraversal}, + vfs::{ + limits::VfsLimits, + path::{PathSegment, PathTraversal}, + }, }; -mod limits; +pub(crate) mod limits; mod path; /// Shared version of [`VfsNode`]. diff --git a/host/tests/integration_tests/evil/complex.rs b/host/tests/integration_tests/evil/complex.rs index aec5d64..98c89ec 100644 --- a/host/tests/integration_tests/evil/complex.rs +++ b/host/tests/integration_tests/evil/complex.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field}; use datafusion_common::{DataFusionError, config::ConfigOptions}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, async_udf::AsyncScalarUDFImpl}; -use datafusion_udf_wasm_host::{WasmPermissions, conversion::limits::TrustedDataLimits}; +use datafusion_udf_wasm_host::{TrustedDataLimits, WasmPermissions}; use crate::integration_tests::evil::test_utils::{try_scalar_udfs, try_scalar_udfs_with_env}; diff --git a/host/tests/integration_tests/evil/root.rs b/host/tests/integration_tests/evil/root.rs index 176d252..e5e345d 100644 --- a/host/tests/integration_tests/evil/root.rs +++ b/host/tests/integration_tests/evil/root.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use datafusion_execution::memory_pool::UnboundedMemoryPool; -use datafusion_udf_wasm_host::{WasmPermissions, WasmScalarUdf, vfs::VfsLimits}; +use datafusion_udf_wasm_host::{VfsLimits, WasmPermissions, WasmScalarUdf}; use crate::integration_tests::evil::test_utils::{ IO_RUNTIME, MEMORY_LIMIT, component, try_scalar_udfs, try_scalar_udfs_with_env, diff --git a/host/tests/integration_tests/python/runtime/fs.rs b/host/tests/integration_tests/python/runtime/fs.rs index 5825ab3..3ed40fd 100644 --- a/host/tests/integration_tests/python/runtime/fs.rs +++ b/host/tests/integration_tests/python/runtime/fs.rs @@ -7,7 +7,7 @@ use arrow::{ use datafusion_common::config::ConfigOptions; use datafusion_execution::memory_pool::{GreedyMemoryPool, UnboundedMemoryPool}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, async_udf::AsyncScalarUDFImpl}; -use datafusion_udf_wasm_host::{WasmPermissions, WasmScalarUdf, vfs::VfsLimits}; +use datafusion_udf_wasm_host::{VfsLimits, WasmPermissions, WasmScalarUdf}; use regex::Regex; use tokio::runtime::Handle; diff --git a/host/tests/integration_tests/python/runtime/http.rs b/host/tests/integration_tests/python/runtime/http.rs index 19e7afd..08b16a8 100644 --- a/host/tests/integration_tests/python/runtime/http.rs +++ b/host/tests/integration_tests/python/runtime/http.rs @@ -11,8 +11,8 @@ use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, async_udf::AsyncScalarUDFImpl, }; use datafusion_udf_wasm_host::{ - WasmPermissions, WasmScalarUdf, - http::{AllowCertainHttpRequests, HttpRequestValidator, Matcher}, + AllowCertainHttpRequests, HttpRequestMatcher, HttpRequestValidator, WasmPermissions, + WasmScalarUdf, }; use tokio::runtime::Handle; use wasmtime_wasi_http::types::DEFAULT_FORBIDDEN_HEADERS; @@ -40,7 +40,7 @@ def perform_request(url: str) -> str: .await; let mut permissions = AllowCertainHttpRequests::new(); - permissions.allow(Matcher { + permissions.allow(HttpRequestMatcher { method: http::Method::GET, host: server.address().ip().to_string().into(), port: server.address().port(), @@ -474,8 +474,8 @@ impl Default for TestCase { } impl TestCase { - fn matcher(&self, server: &MockServer) -> Matcher { - Matcher { + fn matcher(&self, server: &MockServer) -> HttpRequestMatcher { + HttpRequestMatcher { method: self.method.try_into().unwrap(), host: server.address().ip().to_string().into(), port: server.address().port(), @@ -639,7 +639,7 @@ def perform_request(url: str) -> str: // deliberately use a runtime what we are going to throw away later to prevent tricks like `Handle::current` let udf = rt_tmp.block_on(async { let mut permissions = AllowCertainHttpRequests::new(); - permissions.allow(Matcher { + permissions.allow(HttpRequestMatcher { method: http::Method::GET, host: server.address().ip().to_string().into(), port: server.address().port(), diff --git a/host/tests/integration_tests/rust.rs b/host/tests/integration_tests/rust.rs index 714b51f..c66d8cc 100644 --- a/host/tests/integration_tests/rust.rs +++ b/host/tests/integration_tests/rust.rs @@ -12,7 +12,7 @@ use datafusion_expr::{ async_udf::AsyncScalarUDFImpl, }; use datafusion_udf_wasm_host::{ - WasmComponentPrecompiled, WasmPermissions, WasmScalarUdf, limiter::StaticResourceLimits, + StaticResourceLimits, WasmComponentPrecompiled, WasmPermissions, WasmScalarUdf, }; use tokio::{runtime::Handle, sync::OnceCell};