diff --git a/s2energy-connection/examples/communication-server.rs b/s2energy-connection/examples/communication-server.rs index 6a9a228..7bfb2d5 100644 --- a/s2energy-connection/examples/communication-server.rs +++ b/s2energy-connection/examples/communication-server.rs @@ -19,6 +19,8 @@ struct MemoryPairingStoreInner { config: Arc, server: S2NodeId, client: S2NodeId, + // indication of whether the client has unpaired with us. + unpaired: bool, } #[derive(Clone)] @@ -31,6 +33,7 @@ impl MemoryPairingStore { config: Arc::new(NodeConfig::builder(vec![MessageVersion("v1".into())]).build()), server: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8").into(), client: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c7").into(), + unpaired: false, }))) } } @@ -49,7 +52,11 @@ impl ServerPairingStore for MemoryPairingStore { ) -> Result>, Self::Error> { let this = self.0.lock().unwrap(); if this.client == request.client && this.server == request.server { - Ok(PairingLookupResult::Pairing(self.clone())) + if this.unpaired { + Ok(PairingLookupResult::Unpaired) + } else { + Ok(PairingLookupResult::Pairing(self.clone())) + } } else { Ok(PairingLookupResult::NeverPaired) } @@ -71,6 +78,11 @@ impl ServerPairing for MemoryPairingStore { self.0.lock().unwrap().token = token; Ok(()) } + + async fn unpair(self) -> Result<(), Self::Error> { + self.0.lock().unwrap().unpaired = true; + Ok(()) + } } #[tokio::main(flavor = "current_thread")] diff --git a/s2energy-connection/src/communication/client.rs b/s2energy-connection/src/communication/client.rs index 19aadb6..08a8f42 100644 --- a/s2energy-connection/src/communication/client.rs +++ b/s2energy-connection/src/communication/client.rs @@ -11,7 +11,9 @@ use crate::{ common::negotiate_version, communication::{ CommunicationResult, ConnectionInfo, Error, ErrorKind, NodeConfig, WebSocketTransport, - wire::{CommunicationDetails, CommunicationDetailsErrorMessage, InitiateConnectionRequest, InitiateConnectionResponse}, + wire::{ + CommunicationDetails, CommunicationDetailsErrorMessage, InitiateConnectionRequest, InitiateConnectionResponse, UnpairRequest, + }, }, }; @@ -65,6 +67,59 @@ impl Client { } } + /// Unpair the given pairing. The caller is responsible for deleting the pairing + /// upon success. + #[tracing::instrument(skip_all, fields(client = %pairing.client_id(), server = %pairing.server_id()), level = tracing::Level::ERROR)] + pub async fn unpair(&self, pairing: impl ClientPairing) -> CommunicationResult<()> { + let client = reqwest::Client::builder() + .tls_certs_merge( + self.additional_certificates + .iter() + .filter_map(|v| reqwest::Certificate::from_der(v).ok()), + ) + .build() + .map_err(|e| Error::new(ErrorKind::TransportFailed, e))?; + + let communication_url = Url::parse(pairing.communication_url().as_ref()).map_err(|e| Error::new(ErrorKind::InvalidUrl, e))?; + + let version = negotiate_version(&client, communication_url.clone()).await?; + + match version { + crate::common::wire::PairingVersion::V1 => { + let base_url = communication_url.join("v1/").unwrap(); + + let request = UnpairRequest { + client_node_id: pairing.client_id(), + server_node_id: pairing.server_id(), + }; + + for token in pairing.access_tokens().as_ref() { + let response = client + .post(base_url.join("unpair").unwrap()) + .bearer_auth(&token.0) + .json(&request) + .send() + .await + .map_err(|e| Error::new(ErrorKind::TransportFailed, e))?; + + if response.status() == StatusCode::UNAUTHORIZED { + debug!("Token was rejected by remote, assuming it is old."); + continue; + } + + if response.status() != StatusCode::NO_CONTENT { + debug!(status = ?response.status(), "Unexpected status in response to initiateConnection request."); + return Err(ErrorKind::ProtocolError.into()); + } + + return Ok(()); + } + + Err(ErrorKind::NotPaired.into()) + } + } + } + /// Establish a new connection with the server end of the given pairing. #[tracing::instrument(skip_all, fields(client = %pairing.client_id(), server = %pairing.server_id()), level = tracing::Level::ERROR)] pub async fn connect(&self, mut pairing: impl ClientPairing) -> CommunicationResult { @@ -260,6 +315,7 @@ mod tests { token: Arc>, last_request: Arc>>, config: NodeConfig, + deleted: Arc>, } impl TestPairingStore { @@ -268,6 +324,7 @@ mod tests { token: Arc::new(Mutex::new(token)), last_request: Arc::default(), config, + deleted: Arc::default(), } } } @@ -301,6 +358,11 @@ mod tests { *self.token.lock().unwrap() = dbg!(token); Ok(()) } + + async fn unpair(self) -> Result<(), Self::Error> { + *self.deleted.lock().unwrap() = true; + Ok(()) + } } struct TestPairing { @@ -367,6 +429,66 @@ mod tests { (http_server_handle, server) } + #[tokio::test] + async fn unpair() { + let store = TestPairingStore::new( + AccessToken("testtoken".into()), + NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + ); + let (handle, _) = setup_server(store.clone(), None, Router::new()).await; + + let addr = handle.listening().await.unwrap(); + let client = Client::new( + ClientConfig { + additional_certificates: vec![CertificateDer::from_pem_slice(include_bytes!("../../testdata/root.pem")).unwrap()], + endpoint_description: None, + }, + Arc::new(NodeConfig::builder(vec![MessageVersion("v1".into())]).build()), + ); + + let pairing = TestPairing { + client: UUID_A.into(), + server: UUID_B.into(), + tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])), + url: format!("https://localhost:{}/", addr.port()), + }; + + assert!(client.unpair(&pairing).await.is_ok()); + assert!(*store.deleted.lock().unwrap()); + } + + #[tokio::test] + async fn unpair_unauthorized() { + let (handle, _) = setup_server( + TestPairingStore::new( + AccessToken("testtoken".into()), + NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + ), + None, + Router::new(), + ) + .await; + + let addr = handle.listening().await.unwrap(); + let client = Client::new( + ClientConfig { + additional_certificates: vec![CertificateDer::from_pem_slice(include_bytes!("../../testdata/root.pem")).unwrap()], + endpoint_description: None, + }, + Arc::new(NodeConfig::builder(vec![MessageVersion("v1".into())]).build()), + ); + + let pairing = TestPairing { + client: UUID_A.into(), + server: UUID_B.into(), + tokens: Arc::new(Mutex::new(vec![AccessToken("invalidtoken".into())])), + url: format!("https://localhost:{}/", addr.port()), + }; + + let error = client.unpair(&pairing).await.unwrap_err(); + assert_eq!(error.kind(), ErrorKind::NotPaired); + } + #[tokio::test] async fn succesfull_communication() { let store = TestPairingStore::new( diff --git a/s2energy-connection/src/communication/mod.rs b/s2energy-connection/src/communication/mod.rs index 101dced..b481fd4 100644 --- a/s2energy-connection/src/communication/mod.rs +++ b/s2energy-connection/src/communication/mod.rs @@ -125,6 +125,9 @@ //! # async fn set_access_token(&mut self, token: AccessToken) -> Result<(), Self::Error> { //! # unimplemented!() //! # } +//! # async fn unpair(self) -> Result<(), Self::Error> { +//! # unimplemented!() +//! # } //! # } //! # impl ServerPairingStore for SomeStorageProvider { //! # type Error = Infallible; @@ -177,6 +180,9 @@ //! # async fn set_access_token(&mut self, token: AccessToken) -> Result<(), Self::Error> { //! # unimplemented!() //! # } +//! # async fn unpair(self) -> Result<(), Self::Error> { +//! # unimplemented!() +//! # } //! # } //! # impl ServerPairingStore for SomeStorageProvider { //! # type Error = Infallible; diff --git a/s2energy-connection/src/communication/server.rs b/s2energy-connection/src/communication/server.rs index 1d1dc60..2616cd3 100644 --- a/s2energy-connection/src/communication/server.rs +++ b/s2energy-connection/src/communication/server.rs @@ -20,7 +20,7 @@ use crate::{ ConnectionInfo, NodeConfig, WebSocketTransport, wire::{ CommunicationDetails, CommunicationDetailsErrorMessage, CommunicationToken, InitiateConnectionRequest, - InitiateConnectionResponse, WebSocketCommunicationDetails, + InitiateConnectionResponse, UnpairRequest, WebSocketCommunicationDetails, }, }, }; @@ -76,6 +76,9 @@ pub trait ServerPairing: Send { /// Change the stored access token for this pairing. fn set_access_token(&mut self, token: AccessToken) -> impl Future> + Send; + + /// Remove this pairing from the store. + fn unpair(self) -> impl Future> + Send; } /// Configuration for the S2 connection server. @@ -194,11 +197,40 @@ fn select_overlap(primary: &[T], secondary: &[T]) -> Option { fn v1_router() -> Router> { Router::new() + .route("/unpair", post(v1_unpair)) .route("/initiateConnection", post(v1_initiate_connection)) .route("/confirmAccessToken", post(v1_confirm_access_token)) .route("/websocket", get(v1_websocket)) } +#[tracing::instrument(skip_all, level = tracing::Level::INFO)] +async fn v1_unpair( + State(state): State>, + token: AccessToken, + Json(request): Json, +) -> StatusCode { + let lookup = PairingLookup { + client: request.client_node_id, + server: request.server_node_id, + }; + + let pairing = match state.store.lookup(lookup.clone()).await { + Ok(PairingLookupResult::Pairing(pairing)) => pairing, + Ok(PairingLookupResult::NeverPaired | PairingLookupResult::Unpaired) => return StatusCode::UNAUTHORIZED, + Err(_) => return StatusCode::INTERNAL_SERVER_ERROR, + }; + + if pairing.access_token().as_ref() != &token { + return StatusCode::UNAUTHORIZED; + } + + if pairing.unpair().await.is_err() { + return StatusCode::INTERNAL_SERVER_ERROR; + } + + StatusCode::NO_CONTENT +} + #[tracing::instrument(skip_all, level = tracing::Level::INFO)] async fn v1_initiate_connection( State(state): State>, @@ -422,7 +454,7 @@ mod tests { server::{Expiring, PendingWebsocket, Session}, wire::{ CommunicationDetails, CommunicationDetailsErrorMessage, CommunicationToken, InitiateConnectionRequest, - InitiateConnectionResponse, + InitiateConnectionResponse, UnpairRequest, }, }, }; @@ -446,6 +478,10 @@ mod tests { async fn set_access_token(&mut self, _token: AccessToken) -> Result<(), Self::Error> { unimplemented!() } + + async fn unpair(self) -> Result<(), Self::Error> { + unimplemented!() + } } impl ServerPairingStore for NoneStore { @@ -468,6 +504,7 @@ mod tests { #[derive(Debug, Clone)] struct TestStore { token: Arc>, + deleted: Arc>, config: NodeConfig, } @@ -486,6 +523,11 @@ mod tests { *self.token.lock().unwrap() = token; Ok(()) } + + async fn unpair(self) -> Result<(), Self::Error> { + *self.deleted.lock().unwrap() = true; + Ok(()) + } } impl ServerPairingStore for TestStore { @@ -525,6 +567,117 @@ mod tests { assert_eq!(body, b"[\"v1\"]".as_slice()); } + #[tokio::test] + async fn unpair() { + let store = TestStore { + token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), + config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), + }; + let server = Server::new( + ServerConfig { + base_url: "localhost".into(), + endpoint_description: None, + }, + store.clone(), + ); + + let response = server + .get_router() + .oneshot( + http::Request::post("/v1/unpair") + .header(http::header::AUTHORIZATION, "Bearer testtoken") + .header(http::header::CONTENT_TYPE, "application/json") + .body(Body::from( + serde_json::to_vec(&UnpairRequest { + client_node_id: UUID_A.into(), + server_node_id: UUID_B.into(), + }) + .unwrap(), + )) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); + + assert!(*store.deleted.lock().unwrap()); + } + + #[tokio::test] + async fn unpair_nonexisting() { + let store = TestStore { + token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), + config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), + }; + let server = Server::new( + ServerConfig { + base_url: "localhost".into(), + endpoint_description: None, + }, + store.clone(), + ); + + let response = server + .get_router() + .oneshot( + http::Request::post("/v1/unpair") + .header(http::header::AUTHORIZATION, "Bearer testtoken") + .header(http::header::CONTENT_TYPE, "application/json") + .body(Body::from( + serde_json::to_vec(&UnpairRequest { + client_node_id: UUID_B.into(), + server_node_id: UUID_A.into(), + }) + .unwrap(), + )) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + assert!(!*store.deleted.lock().unwrap()); + } + + #[tokio::test] + async fn unpair_invalid_token() { + let store = TestStore { + token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), + config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), + }; + let server = Server::new( + ServerConfig { + base_url: "localhost".into(), + endpoint_description: None, + }, + store.clone(), + ); + + let response = server + .get_router() + .oneshot( + http::Request::post("/v1/unpair") + .header(http::header::AUTHORIZATION, "Bearer invalidtoken") + .header(http::header::CONTENT_TYPE, "application/json") + .body(Body::from( + serde_json::to_vec(&UnpairRequest { + client_node_id: UUID_A.into(), + server_node_id: UUID_B.into(), + }) + .unwrap(), + )) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + assert!(!*store.deleted.lock().unwrap()); + } + #[tokio::test] async fn initiate_communication() { let server = Server::new( @@ -535,6 +688,7 @@ mod tests { TestStore { token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), }, ); @@ -649,6 +803,7 @@ mod tests { TestStore { token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), }, ); @@ -686,6 +841,7 @@ mod tests { TestStore { token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), }, ); @@ -726,6 +882,7 @@ mod tests { TestStore { token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), }, ); @@ -761,6 +918,7 @@ mod tests { let store = TestStore { token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), }; let server = Server::new( ServerConfig { @@ -815,6 +973,7 @@ mod tests { TestStore { token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), }, ); server.app_state.pending_tokens.lock().unwrap().insert( @@ -870,6 +1029,10 @@ mod tests { async fn set_access_token(&mut self, _token: AccessToken) -> Result<(), Self::Error> { Err(std::io::ErrorKind::Other.into()) } + + async fn unpair(self) -> Result<(), Self::Error> { + unimplemented!() + } } impl ServerPairingStore for NoStoreStore { @@ -943,6 +1106,7 @@ mod tests { TestStore { token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), }, ); server.app_state.pending_websockets.lock().unwrap().insert( @@ -1018,6 +1182,7 @@ mod tests { TestStore { token: Arc::new(Mutex::new(AccessToken("testtoken".into()))), config: NodeConfig::builder(vec![MessageVersion("v1".into())]).build(), + deleted: Arc::default(), }, ); server.app_state.pending_websockets.lock().unwrap().insert( diff --git a/s2energy-connection/src/communication/wire.rs b/s2energy-connection/src/communication/wire.rs index d90f14f..25ba3c8 100644 --- a/s2energy-connection/src/communication/wire.rs +++ b/s2energy-connection/src/communication/wire.rs @@ -63,6 +63,14 @@ pub(crate) struct InitiateConnectionResponse { pub(crate) endpoint_description: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct UnpairRequest { + #[serde(rename = "clientS2NodeId")] + pub(crate) client_node_id: S2NodeId, + #[serde(rename = "serverS2NodeId")] + pub(crate) server_node_id: S2NodeId, +} + /// One-time access token for secure access to the S2 message communication channel. It must be renewed every time a client wants to access /// the S2 message communication channel by calling the requestToken endpoint. This token is valid for one time login, with a maximum 5 /// years, and should have a minimum length of 32 bytes.