diff --git a/proplet/Cargo.lock b/proplet/Cargo.lock index 9e6cfb06..a88cc775 100644 --- a/proplet/Cargo.lock +++ b/proplet/Cargo.lock @@ -204,6 +204,16 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "astral-tokio-tar" version = "0.5.6" @@ -1390,6 +1400,24 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "deadpool" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" +dependencies = [ + "deadpool-runtime", + "lazy_static", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" + [[package]] name = "debugid" version = "0.8.0" @@ -2179,6 +2207,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -3219,6 +3253,16 @@ dependencies = [ "libm", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "oauth2" version = "5.0.0" @@ -3930,6 +3974,7 @@ dependencies = [ "wasmtime", "wasmtime-wasi", "wasmtime-wasi-http", + "wiremock", ] [[package]] @@ -6820,6 +6865,29 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "wiremock" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031" +dependencies = [ + "assert-json-diff", + "base64 0.22.1", + "deadpool", + "futures", + "http", + "http-body-util", + "hyper", + "hyper-util", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "tokio", + "url", +] + [[package]] name = "wit-bindgen" version = "0.30.0" diff --git a/proplet/Cargo.toml b/proplet/Cargo.toml index 7c1b37a3..487b5137 100644 --- a/proplet/Cargo.toml +++ b/proplet/Cargo.toml @@ -42,6 +42,10 @@ futures-util = { version = "0.3" } # ELASTIC TEE HAL — hardware abstraction layer for TEE workloads elastic-tee-hal = { git = "https://github.com/elasticproject-eu/wasmhal", default-features = false, features = ["amd-sev"] } +[dev-dependencies] +wiremock = "0.6" +tokio = { version = "1.42", features = ["full"] } + [features] default = [] diff --git a/proplet/src/service.rs b/proplet/src/service.rs index aa9e1775..56bf67f7 100644 --- a/proplet/src/service.rs +++ b/proplet/src/service.rs @@ -1,4 +1,6 @@ use crate::config::PropletConfig; + +const WASM_FETCH_MAX_BYTES: usize = 100 * 1024 * 1024; use crate::metrics::MetricsCollector; use crate::monitoring::{system::SystemMonitor, ProcessMonitor}; use crate::mqtt::{build_topic, MqttMessage, PubSub}; @@ -6,6 +8,7 @@ use crate::runtime::{Runtime, RuntimeContext, StartConfig}; use crate::types::*; use anyhow::{Context, Result}; use reqwest::Client as HttpClient; + use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use std::time::SystemTime; @@ -482,6 +485,18 @@ impl PropletService { if req.encrypted { info!("Encrypted workload with image_url: {}", req.image_url); Vec::new() + } else if req.image_url.starts_with("http://") || req.image_url.starts_with("https://") + { + match self.fetch_wasm_from_http(&req.image_url).await { + Ok(binary) => binary, + Err(e) => { + error!("Failed to fetch wasm for task {}: {}", req.id, e); + self.running_tasks.lock().await.remove(&req.id); + self.publish_result(&req.id, Vec::new(), Some(e.to_string())) + .await?; + return Err(e); + } + } } else { info!("Requesting binary from registry: {}", req.image_url); self.request_binary_from_registry(&req.image_url).await?; @@ -977,6 +992,10 @@ impl PropletService { } } + async fn fetch_wasm_from_http(&self, url: &str) -> Result> { + fetch_wasm_from_http(&self.http_client, url).await + } + async fn try_assemble_chunks(&self, app_name: &str) -> Result>> { let mut assembly = self.chunk_assembly.lock().await; @@ -1212,3 +1231,119 @@ fn build_fl_update_envelope( "metrics": {} }) } + +async fn fetch_wasm_from_http(client: &HttpClient, url: &str) -> Result> { + use futures_util::StreamExt; + + let response = client + .get(url) + .send() + .await + .with_context(|| format!("Failed to connect to {}", url))?; + + let status = response.status(); + if status.is_client_error() || status.is_server_error() { + return Err(anyhow::anyhow!( + "HTTP {} fetching wasm from {}", + status, + url + )); + } + + if let Some(content_length) = response.content_length() { + if content_length as usize > WASM_FETCH_MAX_BYTES { + return Err(anyhow::anyhow!( + "wasm response from {} exceeds size limit ({} > {} bytes)", + url, + content_length, + WASM_FETCH_MAX_BYTES + )); + } + } + + let mut binary = Vec::new(); + let mut stream = response.bytes_stream(); + while let Some(chunk) = stream.next().await { + let chunk = chunk.with_context(|| format!("Failed to read response body from {}", url))?; + binary.extend_from_slice(&chunk); + if binary.len() > WASM_FETCH_MAX_BYTES { + return Err(anyhow::anyhow!( + "wasm response from {} exceeds size limit ({} bytes)", + url, + WASM_FETCH_MAX_BYTES + )); + } + } + + info!("Fetched wasm from {}, size: {} bytes", url, binary.len()); + Ok(binary) +} + +#[cfg(test)] +mod tests { + use super::*; + use wiremock::matchers::method; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn make_client() -> HttpClient { + HttpClient::new() + } + + #[tokio::test] + async fn test_fetch_wasm_200_ok() { + let server = MockServer::start().await; + let wasm_bytes = b"\x00asm\x01\x00\x00\x00"; + + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(200).set_body_bytes(wasm_bytes.to_vec())) + .mount(&server) + .await; + + let result = fetch_wasm_from_http(&make_client(), &server.uri()).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), wasm_bytes); + } + + #[tokio::test] + async fn test_fetch_wasm_404() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(404)) + .mount(&server) + .await; + + let result = fetch_wasm_from_http(&make_client(), &server.uri()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("404")); + } + + #[tokio::test] + async fn test_fetch_wasm_500() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(500)) + .mount(&server) + .await; + + let result = fetch_wasm_from_http(&make_client(), &server.uri()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("500")); + } + + #[tokio::test] + async fn test_fetch_wasm_streaming_exceeds_limit() { + let server = MockServer::start().await; + let over_limit_body = vec![0u8; WASM_FETCH_MAX_BYTES + 1]; + + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(200).set_body_bytes(over_limit_body)) + .mount(&server) + .await; + + let result = fetch_wasm_from_http(&make_client(), &server.uri()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("size limit")); + } +} diff --git a/proplet/src/types.rs b/proplet/src/types.rs index 80cc1270..ea9571d3 100644 --- a/proplet/src/types.rs +++ b/proplet/src/types.rs @@ -813,4 +813,75 @@ mod tests { Some(&"/opt/intel/openvino".to_string()) ); } + + #[test] + fn test_start_request_validate_success_with_http_image_url() { + let req = StartRequest { + id: "task-http".to_string(), + cli_args: vec![], + name: "http_func".to_string(), + state: 0, + file: String::new(), + image_url: "http://fileserver/app.wasm".to_string(), + inputs: vec![], + daemon: false, + env: None, + monitoring_profile: None, + encrypted: false, + kbs_resource_path: None, + mode: None, + proplet_id: None, + }; + + assert!(req.validate().is_ok()); + } + + #[test] + fn test_start_request_validate_success_with_https_image_url() { + let req = StartRequest { + id: "task-https".to_string(), + cli_args: vec![], + name: "https_func".to_string(), + state: 0, + file: String::new(), + image_url: "https://releases.example.com/app.wasm".to_string(), + inputs: vec![], + daemon: false, + env: None, + monitoring_profile: None, + encrypted: false, + kbs_resource_path: None, + mode: None, + proplet_id: None, + }; + + assert!(req.validate().is_ok()); + } + + #[test] + fn test_start_request_validate_no_source() { + let req = StartRequest { + id: "task-nosource".to_string(), + cli_args: vec![], + name: "func".to_string(), + state: 0, + file: String::new(), + image_url: String::new(), + inputs: vec![], + daemon: false, + env: None, + monitoring_profile: None, + encrypted: false, + kbs_resource_path: None, + mode: None, + proplet_id: None, + }; + + let result = req.validate(); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "either file or image_url must be provided" + ); + } }