Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion src/attestation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ use parity_scale_codec::{Decode, Encode};
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display, Formatter},
time::{SystemTime, UNIX_EPOCH},
time::{Duration, SystemTime, UNIX_EPOCH},
};

use thiserror::Error;

use crate::attestation::{dcap::DcapVerificationError, measurements::MeasurementPolicy};

const GCP_METADATA_API: &str =
"http://metadata.google.internal/computeMetadata/v1/project/project-id";

/// This is the type sent over the channel to provide an attestation
#[derive(Clone, Debug, Serialize, Deserialize, Encode, Decode)]
pub struct AttestationExchangeMessage {
Expand Down Expand Up @@ -66,6 +69,26 @@ impl AttestationType {
AttestationType::DcapTdx => "dcap-tdx",
}
}

/// Detect what platform we are on by attempting an attestation
pub async fn detect() -> Result<Self, AttestationError> {
// First attempt azure, if the feature is present
#[cfg(feature = "azure")]
{
if azure::create_azure_attestation([0; 64]).await.is_ok() {
return Ok(AttestationType::AzureTdx);
}
}
// Otherwise try DCAP quote - this internally checks that the quote provider is `tdx_guest`
if configfs_tsm::create_tdx_quote([0; 64]).is_ok() {
if running_on_gcp().await? {
return Ok(AttestationType::GcpTdx);
} else {
return Ok(AttestationType::DcapTdx);
}
}
Ok(AttestationType::None)
}
}

/// SCALE encode (used over the wire)
Expand Down Expand Up @@ -99,6 +122,23 @@ pub struct AttestationGenerator {
}

impl AttestationGenerator {
/// Create an [AttestationGenerator] detecting the attestation type if it is specified as 'auto'
pub async fn new_with_detection(
attestation_type_string: Option<String>,
dummy_dcap_url: Option<String>,
) -> Result<Self, AttestationError> {
let attestation_type_string = attestation_type_string.unwrap_or_else(|| "auto".to_string());
let attestaton_type = if attestation_type_string == "auto" {
tracing::info!("Doing attestation type detection...");
AttestationType::detect().await?
} else {
serde_json::from_value(serde_json::Value::String(attestation_type_string))?
};
tracing::info!("Local platform: {attestaton_type}");

Self::new(attestaton_type, dummy_dcap_url)
}

pub fn new(
attestation_type: AttestationType,
dummy_dcap_url: Option<String>,
Expand All @@ -116,6 +156,8 @@ impl AttestationGenerator {
}
}

/// Create an [AttestationGenerator] without a given dummy DCAP url - meaning Dummy attestation
/// type will not be possible
pub fn new_not_dummy(attestation_type: AttestationType) -> Result<Self, AttestationError> {
if attestation_type == AttestationType::Dummy {
return Err(AttestationError::DummyUrl);
Expand All @@ -127,6 +169,7 @@ impl AttestationGenerator {
})
}

/// Create a dummy [AttestationGenerator]
pub fn new_dummy(dummy_dcap_url: Option<String>) -> Result<Self, AttestationError> {
match dummy_dcap_url {
Some(url) => {
Expand Down Expand Up @@ -181,6 +224,9 @@ impl AttestationGenerator {
}
}

/// Generate a dummy attestaion by using an external service for the attestation generation
///
/// This is for testing only
async fn generate_dummy_attestation(
&self,
input_data: [u8; 64],
Expand Down Expand Up @@ -322,6 +368,32 @@ async fn log_attestation(attestation: &AttestationExchangeMessage) {
}
}

/// Test whether it looks like we are running on GCP by hitting the metadata API
async fn running_on_gcp() -> Result<bool, AttestationError> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Metadata-Flavor",
"Google".parse().expect("Cannot parse header"),
);

let client = reqwest::Client::builder()
.timeout(Duration::from_millis(200))
.default_headers(headers)
.build()?;

let resp = client.get(GCP_METADATA_API).send().await;

if let Ok(r) = resp {
return Ok(r.status().is_success()
&& r.headers()
.get("Metadata-Flavor")
.map(|v| v == "Google")
.unwrap_or(false));
}

Ok(false)
}

/// An error when generating or verifying an attestation
#[derive(Error, Debug)]
pub enum AttestationError {
Expand Down Expand Up @@ -350,4 +422,24 @@ pub enum AttestationError {
DummyUrl,
#[error("Dummy server: {0}")]
DummyServer(String),
#[error("JSON: {0}")]
SerdeJson(#[from] serde_json::Error),
#[error("HTTP client: {0}")]
Reqwest(#[from] reqwest::Error),
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn attestation_detection_does_not_panic() {
// We dont enforce what platform the test is run on, only that the function does not panic
let _ = AttestationGenerator::new_with_detection(None, None).await;
}

#[tokio::test]
async fn running_on_gcp_check_does_not_panic() {
let _ = running_on_gcp().await;
}
}
18 changes: 6 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ enum CliCommand {
listen_addr: SocketAddr,
/// The hostname:port or ip:port of the proxy server (port defaults to 443)
target_addr: String,
/// Type of attestation to present (dafaults to none)
/// Type of attestation to present (dafaults to 'auto' for automatic detection)
/// If other than None, a TLS key and certicate must also be given
#[arg(long, env = "CLIENT_ATTESTATION_TYPE")]
client_attestation_type: Option<String>,
Expand All @@ -71,7 +71,7 @@ enum CliCommand {
listen_addr: SocketAddr,
/// Socket address of the target service to forward traffic to
target_addr: SocketAddr,
/// Type of attestation to present (dafaults to none)
/// Type of attestation to present (dafaults to 'auto' for automatic detection)
/// If other than None, a TLS key and certicate must also be given
#[arg(long, env = "SERVER_ATTESTATION_TYPE")]
server_attestation_type: Option<String>,
Expand Down Expand Up @@ -213,10 +213,6 @@ async fn main() -> anyhow::Result<()> {
None
};

let client_attestation_type: AttestationType = serde_json::from_value(
serde_json::Value::String(client_attestation_type.unwrap_or("none".to_string())),
)?;

let remote_tls_cert = match tls_ca_certificate {
Some(remote_cert_filename) => Some(
load_certs_pem(remote_cert_filename)?
Expand All @@ -228,7 +224,8 @@ async fn main() -> anyhow::Result<()> {
};

let client_attestation_generator =
AttestationGenerator::new(client_attestation_type, dev_dummy_dcap)?;
AttestationGenerator::new_with_detection(client_attestation_type, dev_dummy_dcap)
.await?;

let client = ProxyClient::new(
tls_cert_and_chain,
Expand Down Expand Up @@ -258,12 +255,9 @@ async fn main() -> anyhow::Result<()> {
let tls_cert_and_chain =
load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?;

let server_attestation_type: AttestationType = serde_json::from_value(
serde_json::Value::String(server_attestation_type.unwrap_or("none".to_string())),
)?;

let local_attestation_generator =
AttestationGenerator::new(server_attestation_type, dev_dummy_dcap)?;
AttestationGenerator::new_with_detection(server_attestation_type, dev_dummy_dcap)
.await?;

let server = ProxyServer::new(
tls_cert_and_chain,
Expand Down