Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion s2energy-connection/examples/communication-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ struct MemoryPairingStoreInner {
config: Arc<NodeConfig>,
server: S2NodeId,
client: S2NodeId,
// indication of whether the client has unpaired with us.
unpaired: bool,
}

#[derive(Clone)]
Expand All @@ -31,6 +33,7 @@ impl MemoryPairingStore {
config: Arc::new(NodeConfig::builder(vec![MessageVersion("v1".into())]).build()),
server: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8").into(),
client: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c7").into(),
unpaired: false,
})))
}
}
Expand All @@ -49,7 +52,11 @@ impl ServerPairingStore for MemoryPairingStore {
) -> Result<s2energy_connection::communication::PairingLookupResult<Self::Pairing<'_>>, Self::Error> {
let this = self.0.lock().unwrap();
if this.client == request.client && this.server == request.server {
Ok(PairingLookupResult::Pairing(self.clone()))
if this.unpaired {
Ok(PairingLookupResult::Unpaired)
} else {
Ok(PairingLookupResult::Pairing(self.clone()))
}
} else {
Ok(PairingLookupResult::NeverPaired)
}
Expand All @@ -71,6 +78,11 @@ impl ServerPairing for MemoryPairingStore {
self.0.lock().unwrap().token = token;
Ok(())
}

async fn unpair(self) -> Result<(), Self::Error> {
self.0.lock().unwrap().unpaired = true;
Ok(())
}
}

#[tokio::main(flavor = "current_thread")]
Expand Down
124 changes: 123 additions & 1 deletion s2energy-connection/src/communication/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use crate::{
common::negotiate_version,
communication::{
CommunicationResult, ConnectionInfo, Error, ErrorKind, NodeConfig, WebSocketTransport,
wire::{CommunicationDetails, CommunicationDetailsErrorMessage, InitiateConnectionRequest, InitiateConnectionResponse},
wire::{
CommunicationDetails, CommunicationDetailsErrorMessage, InitiateConnectionRequest, InitiateConnectionResponse, UnpairRequest,
},
},
};

Expand Down Expand Up @@ -65,6 +67,59 @@ impl Client {
}
}

/// Unpair the given pairing. The caller is responsible for deleting the pairing
/// upon success.
#[tracing::instrument(skip_all, fields(client = %pairing.client_id(), server = %pairing.server_id()), level = tracing::Level::ERROR)]
pub async fn unpair(&self, pairing: impl ClientPairing) -> CommunicationResult<()> {
let client = reqwest::Client::builder()
.tls_certs_merge(
self.additional_certificates
.iter()
.filter_map(|v| reqwest::Certificate::from_der(v).ok()),
)
.build()
.map_err(|e| Error::new(ErrorKind::TransportFailed, e))?;

let communication_url = Url::parse(pairing.communication_url().as_ref()).map_err(|e| Error::new(ErrorKind::InvalidUrl, e))?;

let version = negotiate_version(&client, communication_url.clone()).await?;

match version {
crate::common::wire::PairingVersion::V1 => {
let base_url = communication_url.join("v1/").unwrap();

let request = UnpairRequest {
client_node_id: pairing.client_id(),
server_node_id: pairing.server_id(),
};

for token in pairing.access_tokens().as_ref() {
let response = client
.post(base_url.join("unpair").unwrap())
.bearer_auth(&token.0)
.json(&request)
.send()
.await
.map_err(|e| Error::new(ErrorKind::TransportFailed, e))?;

if response.status() == StatusCode::UNAUTHORIZED {
debug!("Token was rejected by remote, assuming it is old.");
continue;
}

if response.status() != StatusCode::NO_CONTENT {
debug!(status = ?response.status(), "Unexpected status in response to initiateConnection request.");
return Err(ErrorKind::ProtocolError.into());
}

return Ok(());
}

Err(ErrorKind::NotPaired.into())
}
}
}

/// Establish a new connection with the server end of the given pairing.
#[tracing::instrument(skip_all, fields(client = %pairing.client_id(), server = %pairing.server_id()), level = tracing::Level::ERROR)]
pub async fn connect(&self, mut pairing: impl ClientPairing) -> CommunicationResult<ConnectionInfo> {
Expand Down Expand Up @@ -260,6 +315,7 @@ mod tests {
token: Arc<Mutex<AccessToken>>,
last_request: Arc<Mutex<Option<PairingLookup>>>,
config: NodeConfig,
deleted: Arc<Mutex<bool>>,
}

impl TestPairingStore {
Expand All @@ -268,6 +324,7 @@ mod tests {
token: Arc::new(Mutex::new(token)),
last_request: Arc::default(),
config,
deleted: Arc::default(),
}
}
}
Expand Down Expand Up @@ -301,6 +358,11 @@ mod tests {
*self.token.lock().unwrap() = dbg!(token);
Ok(())
}

async fn unpair(self) -> Result<(), Self::Error> {
*self.deleted.lock().unwrap() = true;
Ok(())
}
}

struct TestPairing {
Expand Down Expand Up @@ -367,6 +429,66 @@ mod tests {
(http_server_handle, server)
}

#[tokio::test]
async fn unpair() {
let store = TestPairingStore::new(
AccessToken("testtoken".into()),
NodeConfig::builder(vec![MessageVersion("v1".into())]).build(),
);
let (handle, _) = setup_server(store.clone(), None, Router::new()).await;

let addr = handle.listening().await.unwrap();
let client = Client::new(
ClientConfig {
additional_certificates: vec![CertificateDer::from_pem_slice(include_bytes!("../../testdata/root.pem")).unwrap()],
endpoint_description: None,
},
Arc::new(NodeConfig::builder(vec![MessageVersion("v1".into())]).build()),
);

let pairing = TestPairing {
client: UUID_A.into(),
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
};

assert!(client.unpair(&pairing).await.is_ok());
assert!(*store.deleted.lock().unwrap());
}

#[tokio::test]
async fn unpair_unauthorized() {
let (handle, _) = setup_server(
TestPairingStore::new(
AccessToken("testtoken".into()),
NodeConfig::builder(vec![MessageVersion("v1".into())]).build(),
),
None,
Router::new(),
)
.await;

let addr = handle.listening().await.unwrap();
let client = Client::new(
ClientConfig {
additional_certificates: vec![CertificateDer::from_pem_slice(include_bytes!("../../testdata/root.pem")).unwrap()],
endpoint_description: None,
},
Arc::new(NodeConfig::builder(vec![MessageVersion("v1".into())]).build()),
);

let pairing = TestPairing {
client: UUID_A.into(),
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("invalidtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
};

let error = client.unpair(&pairing).await.unwrap_err();
assert_eq!(error.kind(), ErrorKind::NotPaired);
}

#[tokio::test]
async fn succesfull_communication() {
let store = TestPairingStore::new(
Expand Down
6 changes: 6 additions & 0 deletions s2energy-connection/src/communication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@
//! # async fn set_access_token(&mut self, token: AccessToken) -> Result<(), Self::Error> {
//! # unimplemented!()
//! # }
//! # async fn unpair(self) -> Result<(), Self::Error> {
//! # unimplemented!()
//! # }
//! # }
//! # impl ServerPairingStore for SomeStorageProvider {
//! # type Error = Infallible;
Expand Down Expand Up @@ -177,6 +180,9 @@
//! # async fn set_access_token(&mut self, token: AccessToken) -> Result<(), Self::Error> {
//! # unimplemented!()
//! # }
//! # async fn unpair(self) -> Result<(), Self::Error> {
//! # unimplemented!()
//! # }
//! # }
//! # impl ServerPairingStore for SomeStorageProvider {
//! # type Error = Infallible;
Expand Down
Loading
Loading