From 9bd2bc7b49d742947def8ba373892e5bc215083a Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 10 Dec 2025 17:52:15 +0100 Subject: [PATCH 1/3] Add attestation type detection --- src/attestation/mod.rs | 53 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/attestation/mod.rs b/src/attestation/mod.rs index 1142d30..5c0d9e4 100644 --- a/src/attestation/mod.rs +++ b/src/attestation/mod.rs @@ -66,6 +66,23 @@ impl AttestationType { AttestationType::DcapTdx => "dcap-tdx", } } + + /// Detect what platform we are on by attempting an attestation + pub async fn detect() -> Self { + // First attempt azure, if the feature is present + #[cfg(feature = "azure")] + { + if azure::create_azure_attestation([0; 64]).await.is_ok() { + return AttestationType::AzureTdx; + } + } + // Otherwise try DCAP quote - this internally checks that the quote provider is `tdx_guest` + if configfs_tsm::create_tdx_quote([0; 64]).is_ok() { + // TODO Possibly also check if it looks like we are on GCP (eg: hit metadata API) + return AttestationType::DcapTdx; + } + AttestationType::None + } } /// SCALE encode (used over the wire) @@ -99,6 +116,25 @@ pub struct AttestationGenerator { } impl AttestationGenerator { + /// Create an [AttestationGenerator] detecting the attestation type if it is specified as 'auto' + pub async fn new_with_detection( + attestation_type_string: Option, + dummy_dcap_url: Option, + ) -> Result { + let attestaton_type = if attestation_type_string.as_deref() == Some("auto") { + tracing::info!("Doing attestation type detection..."); + AttestationType::detect().await + } else { + serde_json::from_value(serde_json::Value::String( + attestation_type_string.unwrap_or("none".to_string()), + )) + .unwrap() + }; + tracing::info!("Local platform: {attestaton_type}"); + + Self::new(attestaton_type, dummy_dcap_url) + } + pub fn new( attestation_type: AttestationType, dummy_dcap_url: Option, @@ -116,6 +152,8 @@ impl AttestationGenerator { } } + /// Create an [AttestationGenerator] without a given dummy DCAP url - meaning Dummy attestation + /// type will not be possible pub fn new_not_dummy(attestation_type: AttestationType) -> Result { if attestation_type == AttestationType::Dummy { return Err(AttestationError::DummyUrl); @@ -127,6 +165,7 @@ impl AttestationGenerator { }) } + /// Create a dummy [AttestationGenerator] pub fn new_dummy(dummy_dcap_url: Option) -> Result { match dummy_dcap_url { Some(url) => { @@ -181,6 +220,9 @@ impl AttestationGenerator { } } + /// Generate a dummy attestaion by using an external service for the attestation generation + /// + /// This is for testing only async fn generate_dummy_attestation( &self, input_data: [u8; 64], @@ -351,3 +393,14 @@ pub enum AttestationError { #[error("Dummy server: {0}")] DummyServer(String), } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn attestation_detection_does_not_panic() { + // We dont enforce what platform the test is run on, only that the function does not panic + let _ = AttestationGenerator::new_with_detection(Some("auto".to_string()), None).await; + } +} From ad91b6aa778d7cbc95d36b54131fdfccaa4b077f Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 10 Dec 2025 18:01:07 +0100 Subject: [PATCH 2/3] Default to auto, not none --- src/attestation/mod.rs | 12 ++++++------ src/main.rs | 18 ++++++------------ 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/src/attestation/mod.rs b/src/attestation/mod.rs index 5c0d9e4..3877d94 100644 --- a/src/attestation/mod.rs +++ b/src/attestation/mod.rs @@ -121,14 +121,12 @@ impl AttestationGenerator { attestation_type_string: Option, dummy_dcap_url: Option, ) -> Result { - let attestaton_type = if attestation_type_string.as_deref() == Some("auto") { + let attestation_type_string = attestation_type_string.unwrap_or_else(|| "auto".to_string()); + let attestaton_type = if attestation_type_string == "auto" { tracing::info!("Doing attestation type detection..."); AttestationType::detect().await } else { - serde_json::from_value(serde_json::Value::String( - attestation_type_string.unwrap_or("none".to_string()), - )) - .unwrap() + serde_json::from_value(serde_json::Value::String(attestation_type_string))? }; tracing::info!("Local platform: {attestaton_type}"); @@ -392,6 +390,8 @@ pub enum AttestationError { DummyUrl, #[error("Dummy server: {0}")] DummyServer(String), + #[error("JSON: {0}")] + SerdeJson(#[from] serde_json::Error), } #[cfg(test)] @@ -401,6 +401,6 @@ mod tests { #[tokio::test] async fn attestation_detection_does_not_panic() { // We dont enforce what platform the test is run on, only that the function does not panic - let _ = AttestationGenerator::new_with_detection(Some("auto".to_string()), None).await; + let _ = AttestationGenerator::new_with_detection(None, None).await; } } diff --git a/src/main.rs b/src/main.rs index b5bae62..5acac23 100644 --- a/src/main.rs +++ b/src/main.rs @@ -43,7 +43,7 @@ enum CliCommand { listen_addr: SocketAddr, /// The hostname:port or ip:port of the proxy server (port defaults to 443) target_addr: String, - /// Type of attestation to present (dafaults to none) + /// Type of attestation to present (dafaults to 'auto' for automatic detection) /// If other than None, a TLS key and certicate must also be given #[arg(long, env = "CLIENT_ATTESTATION_TYPE")] client_attestation_type: Option, @@ -68,7 +68,7 @@ enum CliCommand { listen_addr: SocketAddr, /// Socket address of the target service to forward traffic to target_addr: SocketAddr, - /// Type of attestation to present (dafaults to none) + /// Type of attestation to present (dafaults to 'auto' for automatic detection) /// If other than None, a TLS key and certicate must also be given #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, @@ -177,10 +177,6 @@ async fn main() -> anyhow::Result<()> { None }; - let client_attestation_type: AttestationType = serde_json::from_value( - serde_json::Value::String(client_attestation_type.unwrap_or("none".to_string())), - )?; - let remote_tls_cert = match tls_ca_certificate { Some(remote_cert_filename) => Some( load_certs_pem(remote_cert_filename)? @@ -192,7 +188,8 @@ async fn main() -> anyhow::Result<()> { }; let client_attestation_generator = - AttestationGenerator::new(client_attestation_type, dev_dummy_dcap)?; + AttestationGenerator::new_with_detection(client_attestation_type, dev_dummy_dcap) + .await?; let client = ProxyClient::new( tls_cert_and_chain, @@ -222,12 +219,9 @@ async fn main() -> anyhow::Result<()> { let tls_cert_and_chain = load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?; - let server_attestation_type: AttestationType = serde_json::from_value( - serde_json::Value::String(server_attestation_type.unwrap_or("none".to_string())), - )?; - let local_attestation_generator = - AttestationGenerator::new(server_attestation_type, dev_dummy_dcap)?; + AttestationGenerator::new_with_detection(server_attestation_type, dev_dummy_dcap) + .await?; let server = ProxyServer::new( tls_cert_and_chain, From 36639224b1297563b617ab325560c50a2adefe29 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 10 Dec 2025 19:36:25 +0100 Subject: [PATCH 3/3] Add check for running on GCP --- src/attestation/mod.rs | 53 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/src/attestation/mod.rs b/src/attestation/mod.rs index 3877d94..bc78624 100644 --- a/src/attestation/mod.rs +++ b/src/attestation/mod.rs @@ -8,13 +8,16 @@ use parity_scale_codec::{Decode, Encode}; use serde::{Deserialize, Serialize}; use std::{ fmt::{self, Display, Formatter}, - time::{SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime, UNIX_EPOCH}, }; use thiserror::Error; use crate::attestation::{dcap::DcapVerificationError, measurements::MeasurementPolicy}; +const GCP_METADATA_API: &str = + "http://metadata.google.internal/computeMetadata/v1/project/project-id"; + /// This is the type sent over the channel to provide an attestation #[derive(Clone, Debug, Serialize, Deserialize, Encode, Decode)] pub struct AttestationExchangeMessage { @@ -68,20 +71,23 @@ impl AttestationType { } /// Detect what platform we are on by attempting an attestation - pub async fn detect() -> Self { + pub async fn detect() -> Result { // First attempt azure, if the feature is present #[cfg(feature = "azure")] { if azure::create_azure_attestation([0; 64]).await.is_ok() { - return AttestationType::AzureTdx; + return Ok(AttestationType::AzureTdx); } } // Otherwise try DCAP quote - this internally checks that the quote provider is `tdx_guest` if configfs_tsm::create_tdx_quote([0; 64]).is_ok() { - // TODO Possibly also check if it looks like we are on GCP (eg: hit metadata API) - return AttestationType::DcapTdx; + if running_on_gcp().await? { + return Ok(AttestationType::GcpTdx); + } else { + return Ok(AttestationType::DcapTdx); + } } - AttestationType::None + Ok(AttestationType::None) } } @@ -124,7 +130,7 @@ impl AttestationGenerator { let attestation_type_string = attestation_type_string.unwrap_or_else(|| "auto".to_string()); let attestaton_type = if attestation_type_string == "auto" { tracing::info!("Doing attestation type detection..."); - AttestationType::detect().await + AttestationType::detect().await? } else { serde_json::from_value(serde_json::Value::String(attestation_type_string))? }; @@ -362,6 +368,32 @@ async fn log_attestation(attestation: &AttestationExchangeMessage) { } } +/// Test whether it looks like we are running on GCP by hitting the metadata API +async fn running_on_gcp() -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Metadata-Flavor", + "Google".parse().expect("Cannot parse header"), + ); + + let client = reqwest::Client::builder() + .timeout(Duration::from_millis(200)) + .default_headers(headers) + .build()?; + + let resp = client.get(GCP_METADATA_API).send().await; + + if let Ok(r) = resp { + return Ok(r.status().is_success() + && r.headers() + .get("Metadata-Flavor") + .map(|v| v == "Google") + .unwrap_or(false)); + } + + Ok(false) +} + /// An error when generating or verifying an attestation #[derive(Error, Debug)] pub enum AttestationError { @@ -392,6 +424,8 @@ pub enum AttestationError { DummyServer(String), #[error("JSON: {0}")] SerdeJson(#[from] serde_json::Error), + #[error("HTTP client: {0}")] + Reqwest(#[from] reqwest::Error), } #[cfg(test)] @@ -403,4 +437,9 @@ mod tests { // We dont enforce what platform the test is run on, only that the function does not panic let _ = AttestationGenerator::new_with_detection(None, None).await; } + + #[tokio::test] + async fn running_on_gcp_check_does_not_panic() { + let _ = running_on_gcp().await; + } }