diff --git a/src/attestation/mod.rs b/src/attestation/mod.rs index 1142d30..bc78624 100644 --- a/src/attestation/mod.rs +++ b/src/attestation/mod.rs @@ -8,13 +8,16 @@ use parity_scale_codec::{Decode, Encode}; use serde::{Deserialize, Serialize}; use std::{ fmt::{self, Display, Formatter}, - time::{SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime, UNIX_EPOCH}, }; use thiserror::Error; use crate::attestation::{dcap::DcapVerificationError, measurements::MeasurementPolicy}; +const GCP_METADATA_API: &str = + "http://metadata.google.internal/computeMetadata/v1/project/project-id"; + /// This is the type sent over the channel to provide an attestation #[derive(Clone, Debug, Serialize, Deserialize, Encode, Decode)] pub struct AttestationExchangeMessage { @@ -66,6 +69,26 @@ impl AttestationType { AttestationType::DcapTdx => "dcap-tdx", } } + + /// Detect what platform we are on by attempting an attestation + pub async fn detect() -> Result { + // First attempt azure, if the feature is present + #[cfg(feature = "azure")] + { + if azure::create_azure_attestation([0; 64]).await.is_ok() { + return Ok(AttestationType::AzureTdx); + } + } + // Otherwise try DCAP quote - this internally checks that the quote provider is `tdx_guest` + if configfs_tsm::create_tdx_quote([0; 64]).is_ok() { + if running_on_gcp().await? { + return Ok(AttestationType::GcpTdx); + } else { + return Ok(AttestationType::DcapTdx); + } + } + Ok(AttestationType::None) + } } /// SCALE encode (used over the wire) @@ -99,6 +122,23 @@ pub struct AttestationGenerator { } impl AttestationGenerator { + /// Create an [AttestationGenerator] detecting the attestation type if it is specified as 'auto' + pub async fn new_with_detection( + attestation_type_string: Option, + dummy_dcap_url: Option, + ) -> Result { + let attestation_type_string = attestation_type_string.unwrap_or_else(|| "auto".to_string()); + let attestaton_type = if attestation_type_string == "auto" { + tracing::info!("Doing attestation type detection..."); + AttestationType::detect().await? + } else { + serde_json::from_value(serde_json::Value::String(attestation_type_string))? + }; + tracing::info!("Local platform: {attestaton_type}"); + + Self::new(attestaton_type, dummy_dcap_url) + } + pub fn new( attestation_type: AttestationType, dummy_dcap_url: Option, @@ -116,6 +156,8 @@ impl AttestationGenerator { } } + /// Create an [AttestationGenerator] without a given dummy DCAP url - meaning Dummy attestation + /// type will not be possible pub fn new_not_dummy(attestation_type: AttestationType) -> Result { if attestation_type == AttestationType::Dummy { return Err(AttestationError::DummyUrl); @@ -127,6 +169,7 @@ impl AttestationGenerator { }) } + /// Create a dummy [AttestationGenerator] pub fn new_dummy(dummy_dcap_url: Option) -> Result { match dummy_dcap_url { Some(url) => { @@ -181,6 +224,9 @@ impl AttestationGenerator { } } + /// Generate a dummy attestaion by using an external service for the attestation generation + /// + /// This is for testing only async fn generate_dummy_attestation( &self, input_data: [u8; 64], @@ -322,6 +368,32 @@ async fn log_attestation(attestation: &AttestationExchangeMessage) { } } +/// Test whether it looks like we are running on GCP by hitting the metadata API +async fn running_on_gcp() -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Metadata-Flavor", + "Google".parse().expect("Cannot parse header"), + ); + + let client = reqwest::Client::builder() + .timeout(Duration::from_millis(200)) + .default_headers(headers) + .build()?; + + let resp = client.get(GCP_METADATA_API).send().await; + + if let Ok(r) = resp { + return Ok(r.status().is_success() + && r.headers() + .get("Metadata-Flavor") + .map(|v| v == "Google") + .unwrap_or(false)); + } + + Ok(false) +} + /// An error when generating or verifying an attestation #[derive(Error, Debug)] pub enum AttestationError { @@ -350,4 +422,24 @@ pub enum AttestationError { DummyUrl, #[error("Dummy server: {0}")] DummyServer(String), + #[error("JSON: {0}")] + SerdeJson(#[from] serde_json::Error), + #[error("HTTP client: {0}")] + Reqwest(#[from] reqwest::Error), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn attestation_detection_does_not_panic() { + // We dont enforce what platform the test is run on, only that the function does not panic + let _ = AttestationGenerator::new_with_detection(None, None).await; + } + + #[tokio::test] + async fn running_on_gcp_check_does_not_panic() { + let _ = running_on_gcp().await; + } } diff --git a/src/main.rs b/src/main.rs index d20facc..f5c5935 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,7 +46,7 @@ enum CliCommand { listen_addr: SocketAddr, /// The hostname:port or ip:port of the proxy server (port defaults to 443) target_addr: String, - /// Type of attestation to present (dafaults to none) + /// Type of attestation to present (dafaults to 'auto' for automatic detection) /// If other than None, a TLS key and certicate must also be given #[arg(long, env = "CLIENT_ATTESTATION_TYPE")] client_attestation_type: Option, @@ -71,7 +71,7 @@ enum CliCommand { listen_addr: SocketAddr, /// Socket address of the target service to forward traffic to target_addr: SocketAddr, - /// Type of attestation to present (dafaults to none) + /// Type of attestation to present (dafaults to 'auto' for automatic detection) /// If other than None, a TLS key and certicate must also be given #[arg(long, env = "SERVER_ATTESTATION_TYPE")] server_attestation_type: Option, @@ -213,10 +213,6 @@ async fn main() -> anyhow::Result<()> { None }; - let client_attestation_type: AttestationType = serde_json::from_value( - serde_json::Value::String(client_attestation_type.unwrap_or("none".to_string())), - )?; - let remote_tls_cert = match tls_ca_certificate { Some(remote_cert_filename) => Some( load_certs_pem(remote_cert_filename)? @@ -228,7 +224,8 @@ async fn main() -> anyhow::Result<()> { }; let client_attestation_generator = - AttestationGenerator::new(client_attestation_type, dev_dummy_dcap)?; + AttestationGenerator::new_with_detection(client_attestation_type, dev_dummy_dcap) + .await?; let client = ProxyClient::new( tls_cert_and_chain, @@ -258,12 +255,9 @@ async fn main() -> anyhow::Result<()> { let tls_cert_and_chain = load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?; - let server_attestation_type: AttestationType = serde_json::from_value( - serde_json::Value::String(server_attestation_type.unwrap_or("none".to_string())), - )?; - let local_attestation_generator = - AttestationGenerator::new(server_attestation_type, dev_dummy_dcap)?; + AttestationGenerator::new_with_detection(server_attestation_type, dev_dummy_dcap) + .await?; let server = ProxyServer::new( tls_cert_and_chain,