Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions host/src/conversion/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! Conversion routes from/to WIT types.
//! Conversion routes from/to [WIT types](crate::bindings).
use std::collections::HashMap;

use arrow::{
Expand All @@ -15,7 +15,7 @@ use crate::{
error::DataFusionResultExt,
};

pub mod limits;
pub(crate) mod limits;

impl CheckedFrom<wit_types::DataFusionError> for DataFusionError {
fn checked_from(
Expand Down
126 changes: 66 additions & 60 deletions host/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

use std::{borrow::Cow, collections::HashSet, fmt};

pub use http::Method;
pub use http::Method as HttpMethod;
use wasmtime_wasi_http::body::HyperOutgoingBody;

/// Validates if an outgoing HTTP interaction is allowed.
///
/// You can implement your own business logic here or use one of the pre-built implementations in [this module](self).
/// You can implement your own business logic here or use one of the pre-built implementations, e.g.
/// [`RejectAllHttpRequests`] or [`AllowCertainHttpRequests`].
pub trait HttpRequestValidator: fmt::Debug + Send + Sync + 'static {
/// Validate an outgoing request.
///
Expand All @@ -16,7 +17,7 @@ pub trait HttpRequestValidator: fmt::Debug + Send + Sync + 'static {
&self,
request: &hyper::Request<HyperOutgoingBody>,
use_tls: bool,
) -> Result<(), Rejected>;
) -> Result<(), HttpRequestRejected>;
}

/// Reject ALL requests.
Expand All @@ -28,16 +29,16 @@ impl HttpRequestValidator for RejectAllHttpRequests {
&self,
_request: &hyper::Request<HyperOutgoingBody>,
_use_tls: bool,
) -> Result<(), Rejected> {
Err(Rejected)
) -> Result<(), HttpRequestRejected> {
Err(HttpRequestRejected)
}
}

/// A request matcher.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Matcher {
pub struct HttpRequestMatcher {
/// Method.
pub method: Method,
pub method: HttpMethod,

/// Host.
///
Expand All @@ -56,7 +57,7 @@ pub struct AllowCertainHttpRequests {
/// Set of all matchers.
///
/// If ANY of them matches, the request will be allowed.
matchers: HashSet<Matcher>,
matchers: HashSet<HttpRequestMatcher>,
}

impl AllowCertainHttpRequests {
Expand All @@ -66,7 +67,7 @@ impl AllowCertainHttpRequests {
}

/// Add `matcher` to the set of allowed requests.
///
/// The matcher is inserted into the internal `HashSet`, so adding a matcher that is equal to
/// one already present has no additional effect.
pub fn allow(&mut self, matcher: Matcher) {
pub fn allow(&mut self, matcher: HttpRequestMatcher) {
self.matchers.insert(matcher);
}
}
Expand All @@ -76,10 +77,15 @@ impl HttpRequestValidator for AllowCertainHttpRequests {
&self,
request: &hyper::Request<HyperOutgoingBody>,
use_tls: bool,
) -> Result<(), Rejected> {
let matcher = Matcher {
) -> Result<(), HttpRequestRejected> {
let matcher = HttpRequestMatcher {
method: request.method().clone(),
host: request.uri().host().ok_or(Rejected)?.to_owned().into(),
host: request
.uri()
.host()
.ok_or(HttpRequestRejected)?
.to_owned()
.into(),
port: request
.uri()
.port_u16()
Expand All @@ -89,22 +95,22 @@ impl HttpRequestValidator for AllowCertainHttpRequests {
if self.matchers.contains(&matcher) {
Ok(())
} else {
Err(Rejected)
Err(HttpRequestRejected)
}
}
}

/// Error signalling that an HTTP request was rejected.
///
/// Returned by [`HttpRequestValidator`] implementations such as [`RejectAllHttpRequests`] and
/// [`AllowCertainHttpRequests`]. It is a unit struct and carries no details about why the
/// request was rejected.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Rejected;
pub struct HttpRequestRejected;

impl fmt::Display for Rejected {
impl fmt::Display for HttpRequestRejected {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// Fixed message: the type has no fields to include.
f.write_str("rejected")
}
}

impl std::error::Error for Rejected {}
impl std::error::Error for HttpRequestRejected {}

#[cfg(test)]
mod test {
Expand All @@ -121,108 +127,108 @@ mod test {
#[test]
fn allow_certain() {
let request_no_port = hyper::Request::builder()
.method(Method::GET)
.method(HttpMethod::GET)
.uri("http://foo.bar")
.body(Default::default())
.unwrap();

let request_with_port = hyper::Request::builder()
.method(Method::GET)
.method(HttpMethod::GET)
.uri("http://my.universe:1337")
.body(Default::default())
.unwrap();

struct Case {
matchers: Vec<Matcher>,
result_no_port_no_tls: Result<(), Rejected>,
result_no_port_with_tls: Result<(), Rejected>,
result_with_port_no_tls: Result<(), Rejected>,
result_with_port_with_tls: Result<(), Rejected>,
matchers: Vec<HttpRequestMatcher>,
result_no_port_no_tls: Result<(), HttpRequestRejected>,
result_no_port_with_tls: Result<(), HttpRequestRejected>,
result_with_port_no_tls: Result<(), HttpRequestRejected>,
result_with_port_with_tls: Result<(), HttpRequestRejected>,
}

let cases = [
Case {
matchers: vec![],
result_no_port_no_tls: Err(Rejected),
result_no_port_with_tls: Err(Rejected),
result_with_port_no_tls: Err(Rejected),
result_with_port_with_tls: Err(Rejected),
result_no_port_no_tls: Err(HttpRequestRejected),
result_no_port_with_tls: Err(HttpRequestRejected),
result_with_port_no_tls: Err(HttpRequestRejected),
result_with_port_with_tls: Err(HttpRequestRejected),
},
Case {
matchers: vec![Matcher {
method: Method::GET,
matchers: vec![HttpRequestMatcher {
method: HttpMethod::GET,
host: "foo.bar".into(),
port: 80,
}],
result_no_port_no_tls: Ok(()),
result_no_port_with_tls: Err(Rejected),
result_with_port_no_tls: Err(Rejected),
result_with_port_with_tls: Err(Rejected),
result_no_port_with_tls: Err(HttpRequestRejected),
result_with_port_no_tls: Err(HttpRequestRejected),
result_with_port_with_tls: Err(HttpRequestRejected),
},
Case {
matchers: vec![Matcher {
method: Method::GET,
matchers: vec![HttpRequestMatcher {
method: HttpMethod::GET,
host: "foo.bar".into(),
port: 443,
}],
result_no_port_no_tls: Err(Rejected),
result_no_port_no_tls: Err(HttpRequestRejected),
result_no_port_with_tls: Ok(()),
result_with_port_no_tls: Err(Rejected),
result_with_port_with_tls: Err(Rejected),
result_with_port_no_tls: Err(HttpRequestRejected),
result_with_port_with_tls: Err(HttpRequestRejected),
},
Case {
matchers: vec![Matcher {
method: Method::POST,
matchers: vec![HttpRequestMatcher {
method: HttpMethod::POST,
host: "foo.bar".into(),
port: 80,
}],
result_no_port_no_tls: Err(Rejected),
result_no_port_with_tls: Err(Rejected),
result_with_port_no_tls: Err(Rejected),
result_with_port_with_tls: Err(Rejected),
result_no_port_no_tls: Err(HttpRequestRejected),
result_no_port_with_tls: Err(HttpRequestRejected),
result_with_port_no_tls: Err(HttpRequestRejected),
result_with_port_with_tls: Err(HttpRequestRejected),
},
Case {
matchers: vec![Matcher {
method: Method::GET,
matchers: vec![HttpRequestMatcher {
method: HttpMethod::GET,
host: "my.universe".into(),
port: 80,
}],
result_no_port_no_tls: Err(Rejected),
result_no_port_with_tls: Err(Rejected),
result_with_port_no_tls: Err(Rejected),
result_with_port_with_tls: Err(Rejected),
result_no_port_no_tls: Err(HttpRequestRejected),
result_no_port_with_tls: Err(HttpRequestRejected),
result_with_port_no_tls: Err(HttpRequestRejected),
result_with_port_with_tls: Err(HttpRequestRejected),
},
Case {
matchers: vec![Matcher {
method: Method::GET,
matchers: vec![HttpRequestMatcher {
method: HttpMethod::GET,
host: "my.universe".into(),
port: 1337,
}],
result_no_port_no_tls: Err(Rejected),
result_no_port_with_tls: Err(Rejected),
result_no_port_no_tls: Err(HttpRequestRejected),
result_no_port_with_tls: Err(HttpRequestRejected),
result_with_port_no_tls: Ok(()),
result_with_port_with_tls: Ok(()),
},
Case {
matchers: vec![
Matcher {
method: Method::GET,
HttpRequestMatcher {
method: HttpMethod::GET,
host: "foo.bar".into(),
port: 80,
},
Matcher {
method: Method::POST,
HttpRequestMatcher {
method: HttpMethod::POST,
host: "foo.bar".into(),
port: 80,
},
Matcher {
method: Method::GET,
HttpRequestMatcher {
method: HttpMethod::GET,
host: "my.universe".into(),
port: 1337,
},
],
result_no_port_no_tls: Ok(()),
result_no_port_with_tls: Err(Rejected),
result_no_port_with_tls: Err(HttpRequestRejected),
result_with_port_no_tls: Ok(()),
result_with_port_with_tls: Ok(()),
},
Expand Down
27 changes: 18 additions & 9 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,23 @@ use wasmtime_wasi_http::{

use crate::{
bindings::exports::datafusion_udf_wasm::udf::types as wit_types,
conversion::limits::{CheckedInto, ComplexityToken, TrustedDataLimits},
conversion::limits::{CheckedInto, ComplexityToken},
error::{DataFusionResultExt, WasmToDataFusionResultExt, WitDataFusionResultExt},
http::{HttpRequestValidator, RejectAllHttpRequests},
ignore_debug::IgnoreDebug,
limiter::{Limiter, StaticResourceLimits},
limiter::Limiter,
linker::link,
tokio_helpers::async_in_sync_context,
vfs::{VfsCtxView, VfsLimits, VfsState, VfsView},
vfs::{VfsCtxView, VfsState, VfsView},
};

pub use crate::{
conversion::limits::TrustedDataLimits,
http::{
AllowCertainHttpRequests, HttpMethod, HttpRequestMatcher, HttpRequestRejected,
HttpRequestValidator, RejectAllHttpRequests,
},
limiter::StaticResourceLimits,
vfs::limits::VfsLimits,
};

// unused-crate-dependencies false positives
Expand All @@ -59,14 +68,14 @@ use regex as _;
use wiremock as _;

mod bindings;
pub mod conversion;
pub mod error;
pub mod http;
mod conversion;
mod error;
mod http;
mod ignore_debug;
pub mod limiter;
mod limiter;
mod linker;
mod tokio_helpers;
pub mod vfs;
mod vfs;

/// State of the WASM payload.
#[derive(Debug)]
Expand Down
8 changes: 5 additions & 3 deletions host/src/vfs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ use wasmtime_wasi::{
},
};

pub use crate::vfs::limits::VfsLimits;
use crate::{
error::LimitExceeded,
limiter::Limiter,
vfs::path::{PathSegment, PathTraversal},
vfs::{
limits::VfsLimits,
path::{PathSegment, PathTraversal},
},
};

mod limits;
pub(crate) mod limits;
mod path;

/// Shared version of [`VfsNode`].
Expand Down
2 changes: 1 addition & 1 deletion host/tests/integration_tests/evil/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{DataFusionError, config::ConfigOptions};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, async_udf::AsyncScalarUDFImpl};
use datafusion_udf_wasm_host::{WasmPermissions, conversion::limits::TrustedDataLimits};
use datafusion_udf_wasm_host::{TrustedDataLimits, WasmPermissions};

use crate::integration_tests::evil::test_utils::{try_scalar_udfs, try_scalar_udfs_with_env};

Expand Down
2 changes: 1 addition & 1 deletion host/tests/integration_tests/evil/root.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use datafusion_execution::memory_pool::UnboundedMemoryPool;
use datafusion_udf_wasm_host::{WasmPermissions, WasmScalarUdf, vfs::VfsLimits};
use datafusion_udf_wasm_host::{VfsLimits, WasmPermissions, WasmScalarUdf};

use crate::integration_tests::evil::test_utils::{
IO_RUNTIME, MEMORY_LIMIT, component, try_scalar_udfs, try_scalar_udfs_with_env,
Expand Down
2 changes: 1 addition & 1 deletion host/tests/integration_tests/python/runtime/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use arrow::{
use datafusion_common::config::ConfigOptions;
use datafusion_execution::memory_pool::{GreedyMemoryPool, UnboundedMemoryPool};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, async_udf::AsyncScalarUDFImpl};
use datafusion_udf_wasm_host::{WasmPermissions, WasmScalarUdf, vfs::VfsLimits};
use datafusion_udf_wasm_host::{VfsLimits, WasmPermissions, WasmScalarUdf};
use regex::Regex;
use tokio::runtime::Handle;

Expand Down
Loading