From 04d34d9a8d15a9dccf2b4135d6ed7ec7eb43c664 Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Fri, 29 Apr 2022 15:57:54 +0200 Subject: [PATCH 1/8] Scetch a possiblity for custom logic in the broker Some thoughts about a broker plugins --- mqtt-v5-broker/src/broker.rs | 58 +++++++++++++++++++++++++-------- mqtt-v5-broker/src/client.rs | 6 ++++ mqtt-v5-broker/src/lib.rs | 1 + mqtt-v5-broker/src/main.rs | 37 +++++++++++++++++++-- mqtt-v5-broker/src/plugin.rs | 63 ++++++++++++++++++++++++++++++++++++ 5 files changed, 150 insertions(+), 15 deletions(-) create mode 100644 mqtt-v5-broker/src/plugin.rs diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index 3e9ead9..f3554e1 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -1,14 +1,19 @@ -use crate::{client::ClientMessage, tree::SubscriptionTree}; +use crate::{ + client::ClientMessage, + plugin::{Noop, Plugin}, + tree::SubscriptionTree, +}; use log::{info, warn}; use mqtt_v5::{ topic::TopicFilter, types::{ properties::{AssignedClientIdentifier, SessionExpiryInterval}, - ConnectAckPacket, ConnectPacket, ConnectReason, DisconnectReason, FinalWill, Packet, - ProtocolVersion, PublishAckPacket, PublishAckReason, PublishCompletePacket, - PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReceivedReason, - PublishReleasePacket, PublishReleaseReason, QoS, SubscribeAckPacket, SubscribeAckReason, - SubscribePacket, UnsubscribeAckPacket, UnsubscribeAckReason, UnsubscribePacket, + AuthenticatePacket, ConnectAckPacket, ConnectPacket, ConnectReason, DisconnectReason, + FinalWill, Packet, ProtocolVersion, PublishAckPacket, PublishAckReason, + PublishCompletePacket, PublishCompleteReason, PublishPacket, PublishReceivedPacket, + PublishReceivedReason, PublishReleasePacket, PublishReleaseReason, QoS, SubscribeAckPacket, + SubscribeAckReason, SubscribePacket, UnsubscribeAckPacket, UnsubscribeAckReason, + UnsubscribePacket, }, }; use std::{ @@ -179,6 +184,7 @@ pub enum WillDisconnectLogic { #[derive(Debug)] pub enum BrokerMessage { NewClient(Box, Sender), + Authenticate(String, AuthenticatePacket), Publish(String, Box), PublishAck(String, PublishAckPacket), // TODO - This can be handled by the client task PublishRelease(String, PublishReleasePacket), // TODO - This can be handled by the client task @@ -190,24 +196,46 @@ pub enum BrokerMessage { Disconnect(String, WillDisconnectLogic), } -pub struct Broker { +pub struct Broker { sessions: HashMap, sender: Sender, receiver: Receiver, subscriptions: SubscriptionTree, + #[allow(unused)] + plugin: A, } -impl Default for Broker { +impl Default for Broker { fn default() -> Self { - Self::new() + Broker::::new() } } -impl Broker { - pub fn new() -> Self { +impl Broker { + /// Construct a new Broker. + pub fn new() -> Broker { let (sender, receiver) = mpsc::channel(100); - Self { sessions: HashMap::new(), sender, receiver, subscriptions: SubscriptionTree::new() } + Broker { + sessions: HashMap::new(), + sender, + receiver, + subscriptions: SubscriptionTree::new(), + plugin: Noop, + } + } + + /// Construct a new Broker. + pub fn with_plugin(plugin: A) -> Broker { + let (sender, receiver) = mpsc::channel(100); + + Broker { + sessions: HashMap::new(), + sender, + receiver, + subscriptions: SubscriptionTree::new(), + plugin, + } } pub fn sender(&self) -> Sender { @@ -693,6 +721,9 @@ impl Broker { BrokerMessage::PublishFinalWill(client_id, final_will) => { self.publish_final_will(client_id, final_will).await; }, + BrokerMessage::Authenticate(client_id, _) => { + warn!("Ignoring unexpected authentication message from {}", client_id); + }, } } } @@ -703,6 +734,7 @@ mod tests { use crate::{ broker::{Broker, BrokerMessage}, client::ClientMessage, + plugin::Noop, }; use mqtt_v5::types::{properties::*, ProtocolVersion, *}; use tokio::{ @@ -802,7 +834,7 @@ mod tests { #[test] fn simple_client_test() { - let broker = Broker::new(); + let broker = Broker::::new(); let sender = broker.sender(); let runtime = Runtime::new().unwrap(); diff --git a/mqtt-v5-broker/src/client.rs b/mqtt-v5-broker/src/client.rs index 96394ad..9eb11a2 100644 --- a/mqtt-v5-broker/src/client.rs +++ b/mqtt-v5-broker/src/client.rs @@ -403,6 +403,12 @@ impl + Unpin, SI: Sink { + broker_tx + .send(BrokerMessage::Authenticate(client_id.clone(), packet)) + .await + .expect("Couldn't send Authentivate message to self"); + }, _ => {}, }, Err(err) => { diff --git a/mqtt-v5-broker/src/lib.rs b/mqtt-v5-broker/src/lib.rs index 72cad42..3fc70f1 100644 --- a/mqtt-v5-broker/src/lib.rs +++ b/mqtt-v5-broker/src/lib.rs @@ -1,3 +1,4 @@ pub mod broker; pub mod client; +pub mod plugin; mod tree; diff --git a/mqtt-v5-broker/src/main.rs b/mqtt-v5-broker/src/main.rs index 88ea9ba..1a61081 100644 --- a/mqtt-v5-broker/src/main.rs +++ b/mqtt-v5-broker/src/main.rs @@ -1,10 +1,15 @@ use std::{env, io}; use futures::future::try_join_all; -use log::{debug, info}; +use log::{debug, info, trace}; +use mqtt_v5::types::{ + properties::{AuthenticationData, AuthenticationMethod}, + AuthenticatePacket, PublishPacket, SubscribePacket, +}; use mqtt_v5_broker::{ broker::{Broker, BrokerMessage}, client, + plugin::{self}, }; use tokio::{net::TcpListener, sync::mpsc::Sender, task}; @@ -44,11 +49,39 @@ fn init_logging() { } } +#[derive(Default)] +struct TracePlugin; + +impl plugin::Plugin for TracePlugin { + fn on_connect( + &mut self, + _: Option<&AuthenticationMethod>, + _: Option<&AuthenticationData>, + ) -> plugin::AuthentificationResult { + plugin::AuthentificationResult::Success + } + + fn on_authenticate(&mut self, packet: &AuthenticatePacket) -> plugin::AuthentificationResult { + trace!("Authenticate packet received: {:?}", packet); + plugin::AuthentificationResult::Success + } + + fn on_subscribe(&mut self, packet: &SubscribePacket) -> plugin::SubscribeResult { + trace!("Subscribe packet received: {:?}", packet); + plugin::SubscribeResult::Placeholder + } + + fn on_publish_received(&mut self, packet: &PublishPacket) -> plugin::PublishReceivedResult { + trace!("Publish packet received: {:?}", packet); + plugin::PublishReceivedResult::Placeholder + } +} + #[tokio::main] async fn main() -> Result<(), Box> { init_logging(); - let broker = Broker::new(); + let broker = Broker::with_plugin(TracePlugin::default()); let broker_tx = broker.sender(); let broker = task::spawn(async { broker.run().await; diff --git a/mqtt-v5-broker/src/plugin.rs b/mqtt-v5-broker/src/plugin.rs new file mode 100644 index 0000000..ec98611 --- /dev/null +++ b/mqtt-v5-broker/src/plugin.rs @@ -0,0 +1,63 @@ +use mqtt_v5::types::{ + properties::{AuthenticationData, AuthenticationMethod}, + *, +}; + +pub struct Noop; + +/// Result of a authentication attempt +pub enum AuthentificationResult { + /// Authentification is successful. + Success, + /// Authentification failed. Send connect reason code to the client (ConnectAck) + Fail(ConnectReason), + /// Send this auth packet to the client and wait for the response. + Packet(AuthenticatePacket), +} + +pub enum PublishReceivedResult { + Placeholder, +} + +pub enum SubscribeResult { + Placeholder, +} + +/// Broker plugin +pub trait Plugin { + /// Called on connect packet reception + fn on_connect( + &mut self, + method: Option<&AuthenticationMethod>, + data: Option<&AuthenticationData>, + ) -> AuthentificationResult; + + /// Called on authenticate packet reception + fn on_authenticate(&mut self, packet: &AuthenticatePacket) -> AuthentificationResult; + + fn on_subscribe(&mut self, packet: &SubscribePacket) -> SubscribeResult; + fn on_publish_received(&mut self, packet: &PublishPacket) -> PublishReceivedResult; +} + +/// Default noop authenticator +impl Plugin for Noop { + fn on_connect( + &mut self, + _: Option<&AuthenticationMethod>, + _: Option<&AuthenticationData>, + ) -> AuthentificationResult { + AuthentificationResult::Success + } + + fn on_authenticate(&mut self, _: &AuthenticatePacket) -> AuthentificationResult { + AuthentificationResult::Success + } + + fn on_subscribe(&mut self, _: &SubscribePacket) -> SubscribeResult { + SubscribeResult::Placeholder + } + + fn on_publish_received(&mut self, _: &PublishPacket) -> PublishReceivedResult { + PublishReceivedResult::Placeholder + } +} From 9077dc4c4148425addc8a50ee554b1c18e88f323 Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Mon, 2 May 2022 09:26:00 +0200 Subject: [PATCH 2/8] Add subscribe and publish to plugin trait --- mqtt-v5-broker/src/broker.rs | 124 +++++++++++++++++++---------------- mqtt-v5-broker/src/main.rs | 38 +---------- mqtt-v5-broker/src/plugin.rs | 93 ++++++++++++++++++++++---- 3 files changed, 148 insertions(+), 107 deletions(-) diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index f3554e1..cc56375 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -9,11 +9,10 @@ use mqtt_v5::{ types::{ properties::{AssignedClientIdentifier, SessionExpiryInterval}, AuthenticatePacket, ConnectAckPacket, ConnectPacket, ConnectReason, DisconnectReason, - FinalWill, Packet, ProtocolVersion, PublishAckPacket, PublishAckReason, - PublishCompletePacket, PublishCompleteReason, PublishPacket, PublishReceivedPacket, - PublishReceivedReason, PublishReleasePacket, PublishReleaseReason, QoS, SubscribeAckPacket, - SubscribeAckReason, SubscribePacket, UnsubscribeAckPacket, UnsubscribeAckReason, - UnsubscribePacket, + FinalWill, Packet, ProtocolVersion, PublishAckPacket, PublishCompletePacket, + PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReleasePacket, + PublishReleaseReason, QoS, SubscribeAckPacket, SubscribeAckReason, SubscribePacket, + UnsubscribeAckPacket, UnsubscribeAckReason, UnsubscribePacket, }, }; use std::{ @@ -384,10 +383,22 @@ impl Broker { let subscriptions = &mut self.subscriptions; if let Some(session) = self.sessions.get_mut(&client_id) { + let plugin_ack = self.plugin.on_subscribe(&packet); + // If a Server receives a SUBSCRIBE packet containing a Topic Filter that // is identical to a Non‑shared Subscription’s Topic Filter for the current // Session, then it MUST replace that existing Subscription with a new Subscription. - for topic in &packet.subscription_topics { + for (topic, plugin_code) in + packet.subscription_topics.iter().zip(&plugin_ack.reason_codes) + { + // Check only subscriptions that the plugin didn't reject. + match plugin_code { + SubscribeAckReason::GrantedQoSZero + | SubscribeAckReason::GrantedQoSOne + | SubscribeAckReason::GrantedQoSTwo => (), + _ => continue, + } + let topic = &topic.topic_filter; // Unsubscribe the old session from all topics it subscribed to. session.subscription_tokens.retain(|(session_topic, token)| { @@ -405,27 +416,28 @@ impl Broker { let granted_qos_values = packet .subscription_topics .into_iter() - .map(|topic| { - let session_subscription = SessionSubscription { - client_id: client_id.clone(), - maximum_qos: topic.maximum_qos, - }; - let token = subscriptions.insert(&topic.topic_filter, session_subscription); - - session.subscription_tokens.push((topic.topic_filter.clone(), token)); - - match topic.maximum_qos { - QoS::AtMostOnce => SubscribeAckReason::GrantedQoSZero, - QoS::AtLeastOnce => SubscribeAckReason::GrantedQoSOne, - QoS::ExactlyOnce => SubscribeAckReason::GrantedQoSTwo, - } + .zip(plugin_ack.reason_codes) + .map(|(topic, plugin_reason)| match plugin_reason { + SubscribeAckReason::GrantedQoSZero + | SubscribeAckReason::GrantedQoSOne + | SubscribeAckReason::GrantedQoSTwo => { + let session_subscription = SessionSubscription { + client_id: client_id.clone(), + maximum_qos: topic.maximum_qos, + }; + let token = subscriptions.insert(&topic.topic_filter, session_subscription); + + session.subscription_tokens.push((topic.topic_filter.clone(), token)); + plugin_reason + }, + reason => reason, }) .collect(); let subscribe_ack = SubscribeAckPacket { packet_id: packet.packet_id, - reason_string: None, - user_properties: vec![], + reason_string: plugin_ack.reason_string, + user_properties: plugin_ack.user_properties, reason_codes: granted_qos_values, }; @@ -565,55 +577,51 @@ impl Broker { } async fn handle_publish(&mut self, client_id: String, packet: PublishPacket) { - let mut is_dup = false; - - // For QoS2, ensure this packet isn't delivered twice. So if we have an outgoing - // publish receive with the same ID, just send the publish receive again but don't forward - // the message. match packet.qos { - QoS::AtMostOnce => {}, + QoS::AtMostOnce => { + if self.plugin.on_publish_received_qos0(&packet) { + self.publish_message(packet).await; + } + }, QoS::AtLeastOnce => { if let Some(session) = self.sessions.get_mut(&client_id) { - let publish_ack = PublishAckPacket { - packet_id: packet - .packet_id - .expect("Packet with QoS 1 should have a packet ID"), - reason_code: PublishAckReason::Success, - reason_string: None, - user_properties: vec![], - }; - - session.send(ClientMessage::Packet(Packet::PublishAck(publish_ack))).await; + let (publish, publish_ack) = self.plugin.on_publish_received_qos1(&packet); + if let Some(publish_ack) = publish_ack { + session.send(ClientMessage::Packet(Packet::PublishAck(publish_ack))).await; + } + if publish { + self.publish_message(packet).await; + } } }, + // For QoS2, ensure this packet isn't delivered twice. So if we have an outgoing + // publish receive with the same ID, just send the publish receive again but don't forward + // the message. QoS::ExactlyOnce => { if let Some(session) = self.sessions.get_mut(&client_id) { - let packet_id = packet.packet_id.unwrap(); - is_dup = session.outgoing_publish_receives.contains(&packet_id); + let (mut publish, publish_rec) = self.plugin.on_publish_received_qos2(&packet); + + if let Some(publish_recv) = publish_rec { + let packet_id = publish_recv.packet_id; + let is_dup = session.outgoing_publish_receives.contains(&packet_id); + + publish = publish && !is_dup; - if !is_dup { - session.outgoing_publish_receives.push(packet_id) + if !is_dup { + session.outgoing_publish_receives.push(packet_id) + } + + session + .send(ClientMessage::Packet(Packet::PublishReceived(publish_recv))) + .await; } - let publish_recv = PublishReceivedPacket { - packet_id: packet - .packet_id - .expect("Packet with QoS 2 should have a packet ID"), - reason_code: PublishReceivedReason::Success, - reason_string: None, - user_properties: vec![], - }; - - session - .send(ClientMessage::Packet(Packet::PublishReceived(publish_recv))) - .await; + if publish { + self.publish_message(packet).await; + } } }, } - - if !is_dup { - self.publish_message(packet).await; - } } fn handle_publish_ack(&mut self, client_id: String, packet: PublishAckPacket) { diff --git a/mqtt-v5-broker/src/main.rs b/mqtt-v5-broker/src/main.rs index 1a61081..e9125c5 100644 --- a/mqtt-v5-broker/src/main.rs +++ b/mqtt-v5-broker/src/main.rs @@ -1,15 +1,11 @@ use std::{env, io}; use futures::future::try_join_all; -use log::{debug, info, trace}; -use mqtt_v5::types::{ - properties::{AuthenticationData, AuthenticationMethod}, - AuthenticatePacket, PublishPacket, SubscribePacket, -}; +use log::{debug, info}; use mqtt_v5_broker::{ broker::{Broker, BrokerMessage}, client, - plugin::{self}, + plugin::Noop, }; use tokio::{net::TcpListener, sync::mpsc::Sender, task}; @@ -49,39 +45,11 @@ fn init_logging() { } } -#[derive(Default)] -struct TracePlugin; - -impl plugin::Plugin for TracePlugin { - fn on_connect( - &mut self, - _: Option<&AuthenticationMethod>, - _: Option<&AuthenticationData>, - ) -> plugin::AuthentificationResult { - plugin::AuthentificationResult::Success - } - - fn on_authenticate(&mut self, packet: &AuthenticatePacket) -> plugin::AuthentificationResult { - trace!("Authenticate packet received: {:?}", packet); - plugin::AuthentificationResult::Success - } - - fn on_subscribe(&mut self, packet: &SubscribePacket) -> plugin::SubscribeResult { - trace!("Subscribe packet received: {:?}", packet); - plugin::SubscribeResult::Placeholder - } - - fn on_publish_received(&mut self, packet: &PublishPacket) -> plugin::PublishReceivedResult { - trace!("Publish packet received: {:?}", packet); - plugin::PublishReceivedResult::Placeholder - } -} - #[tokio::main] async fn main() -> Result<(), Box> { init_logging(); - let broker = Broker::with_plugin(TracePlugin::default()); + let broker = Broker::with_plugin(Noop); let broker_tx = broker.sender(); let broker = task::spawn(async { broker.run().await; diff --git a/mqtt-v5-broker/src/plugin.rs b/mqtt-v5-broker/src/plugin.rs index ec98611..432050d 100644 --- a/mqtt-v5-broker/src/plugin.rs +++ b/mqtt-v5-broker/src/plugin.rs @@ -1,3 +1,4 @@ +use log::{trace, warn}; use mqtt_v5::types::{ properties::{AuthenticationData, AuthenticationMethod}, *, @@ -15,14 +16,6 @@ pub enum AuthentificationResult { Packet(AuthenticatePacket), } -pub enum PublishReceivedResult { - Placeholder, -} - -pub enum SubscribeResult { - Placeholder, -} - /// Broker plugin pub trait Plugin { /// Called on connect packet reception @@ -35,8 +28,25 @@ pub trait Plugin { /// Called on authenticate packet reception fn on_authenticate(&mut self, packet: &AuthenticatePacket) -> AuthentificationResult; - fn on_subscribe(&mut self, packet: &SubscribePacket) -> SubscribeResult; - fn on_publish_received(&mut self, packet: &PublishPacket) -> PublishReceivedResult; + /// Called on subscribe packets reception + fn on_subscribe(&mut self, packet: &SubscribePacket) -> SubscribeAckPacket; + + /// Called on publish packets reception for QoS 0. Return true if the packet should be published to the clients. + fn on_publish_received_qos0(&mut self, packet: &PublishPacket) -> bool; + /// Called on publish packets reception for QoS 1. Return if the packet should be published to the clients and + /// the publish ack packet to be sent to the publisher. + + fn on_publish_received_qos1( + &mut self, + packet: &PublishPacket, + ) -> (bool, Option); + + /// Called on publish packets reception for QoS 2. Return if the packet should be published to the clients and + /// the publish received packet to be sent to the publisher. + fn on_publish_received_qos2( + &mut self, + packet: &PublishPacket, + ) -> (bool, Option); } /// Default noop authenticator @@ -53,11 +63,66 @@ impl Plugin for Noop { AuthentificationResult::Success } - fn on_subscribe(&mut self, _: &SubscribePacket) -> SubscribeResult { - SubscribeResult::Placeholder + fn on_subscribe(&mut self, packet: &SubscribePacket) -> SubscribeAckPacket { + SubscribeAckPacket { + packet_id: packet.packet_id, + reason_string: None, + user_properties: vec![], + reason_codes: packet + .subscription_topics + .iter() + .inspect(|filter| { + trace!("Granting subscribe to {}", filter.topic_filter); + }) + .map(|filter| match filter.maximum_qos { + QoS::AtMostOnce => SubscribeAckReason::GrantedQoSZero, + QoS::AtLeastOnce => SubscribeAckReason::GrantedQoSOne, + QoS::ExactlyOnce => SubscribeAckReason::GrantedQoSTwo, + }) + .collect(), + } + } + + fn on_publish_received_qos0(&mut self, packet: &PublishPacket) -> bool { + trace!("Granting QoS 0 publish on topic \"{}\"", packet.topic); + true } - fn on_publish_received(&mut self, _: &PublishPacket) -> PublishReceivedResult { - PublishReceivedResult::Placeholder + fn on_publish_received_qos1( + &mut self, + packet: &PublishPacket, + ) -> (bool, Option) { + if let Some(packet_id) = packet.packet_id { + let ack = PublishAckPacket { + packet_id, + reason_code: PublishAckReason::Success, + reason_string: None, + user_properties: Vec::with_capacity(0), + }; + trace!("Granting QoS 1 publish on topic \"{}\"", packet.topic); + (true, Some(ack)) + } else { + warn!("Publish packet with QoS 1 without packet id"); + (false, None) + } + } + + fn on_publish_received_qos2( + &mut self, + packet: &PublishPacket, + ) -> (bool, Option) { + if let Some(packet_id) = packet.packet_id { + let ack = PublishReceivedPacket { + packet_id, + reason_code: PublishReceivedReason::Success, + reason_string: None, + user_properties: Vec::with_capacity(0), + }; + trace!("Granting QoS 2 publish on topic \"{}\"", packet.topic); + (true, Some(ack)) + } else { + warn!("Publish packet with QoS 2 without packet id"); + (false, None) + } } } From 520f9e241be18b2f9de1a636c8eeaa716c63599c Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Thu, 5 May 2022 15:56:40 +0200 Subject: [PATCH 3/8] Attempt to implement authentification --- mqtt-v5-broker/src/broker.rs | 321 ++++++++++++++++++++++++++++------- mqtt-v5-broker/src/main.rs | 2 +- mqtt-v5-broker/src/plugin.rs | 32 ++-- 3 files changed, 270 insertions(+), 85 deletions(-) diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index cc56375..cde6ba3 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -1,9 +1,9 @@ use crate::{ client::ClientMessage, - plugin::{Noop, Plugin}, + plugin::{AuthentificationResult, Noop, Plugin}, tree::SubscriptionTree, }; -use log::{info, warn}; +use log::{debug, info, trace, warn}; use mqtt_v5::{ topic::TopicFilter, types::{ @@ -19,8 +19,12 @@ use std::{ collections::{hash_map::Entry, HashMap}, time::Duration, }; -use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::{ + sync::mpsc::{self, Receiver, Sender}, + task, time, +}; +#[derive(Debug)] struct Session { #[allow(unused)] pub protocol_version: ProtocolVersion, @@ -28,6 +32,9 @@ struct Session { // pub shared_subscriptions: HashSet, pub client_sender: Option>, + // Authentication state + pub authenticated: bool, + // Used to unsubscribe from topics subscription_tokens: Vec<(TopicFilter, u64)>, @@ -51,12 +58,14 @@ struct Session { impl Session { pub fn new( + authenticated: bool, protocol_version: ProtocolVersion, will: Option, session_expiry_interval: Option, client_sender: Sender, ) -> Self { Self { + authenticated, protocol_version, // subscriptions: HashSet::new(), // shared_subscriptions: HashSet::new(), @@ -76,12 +85,14 @@ impl Session { /// for a client which just connected. pub fn into_new_session( self, + authenticated: bool, protocol_version: ProtocolVersion, will: Option, session_expiry_interval: Option, client_sender: Sender, ) -> Self { Self { + authenticated, protocol_version, client_sender: Some(client_sender), session_expiry_interval, @@ -182,6 +193,7 @@ pub enum WillDisconnectLogic { #[derive(Debug)] pub enum BrokerMessage { + Stats, NewClient(Box, Sender), Authenticate(String, AuthenticatePacket), Publish(String, Box), @@ -195,7 +207,7 @@ pub enum BrokerMessage { Disconnect(String, WillDisconnectLogic), } -pub struct Broker { +pub struct Broker { sessions: HashMap, sender: Sender, receiver: Receiver, @@ -204,7 +216,7 @@ pub struct Broker { plugin: A, } -impl Default for Broker { +impl Default for Broker { fn default() -> Self { Broker::::new() } @@ -212,7 +224,7 @@ impl Default for Broker { impl Broker { /// Construct a new Broker. - pub fn new() -> Broker { + pub fn new() -> Broker { let (sender, receiver) = mpsc::channel(100); Broker { @@ -228,6 +240,18 @@ impl Broker { pub fn with_plugin(plugin: A) -> Broker { let (sender, receiver) = mpsc::channel(100); + { + let sender = sender.clone(); + task::spawn(async move { + loop { + time::sleep(time::Duration::from_secs(5)).await; + if sender.send(BrokerMessage::Stats).await.is_err() { + break; + } + } + }); + } + Broker { sessions: HashMap::new(), sender, @@ -294,6 +318,69 @@ impl Broker { connect_packet: ConnectPacket, client_msg_sender: Sender, ) { + debug!("Trying to authenticate client {}", connect_packet.client_id); + let authenticated = match self.plugin.on_connect(&connect_packet) { + AuthentificationResult::Reason(ConnectReason::Success) => { + info!("Authentification successful for client {}", connect_packet.client_id); + true + }, + AuthentificationResult::Reason(reason_code) => { + info!( + "Authentification result for client {}: {:?}", + connect_packet.client_id, reason_code + ); + let connect_ack = ConnectAckPacket { + // TODO: is this correct? + session_present: false, + reason_code, + // TODO: is this correct? + session_expiry_interval: None, + receive_maximum: None, + maximum_qos: None, + retain_available: None, + maximum_packet_size: None, + assigned_client_identifier: None, + topic_alias_maximum: None, + reason_string: None, + user_properties: Vec::with_capacity(0), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }; + debug!( + "Sending CONNACK to client {} with reson code {:?}", + connect_packet.client_id, reason_code + ); + client_msg_sender + .send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))) + .await + .expect("Failed to send Connect Acknowledgement"); + + debug!( + "Sending DISCONNECT to client {} with disconnect reason {:?}", + connect_packet.client_id, + DisconnectReason::NotAuthorized + ); + client_msg_sender + .send(ClientMessage::Disconnect(DisconnectReason::NotAuthorized)) + .await + .expect("Failed to send Disconnect"); + return; + }, + AuthentificationResult::Packet(packet) => { + client_msg_sender + .send(ClientMessage::Packet(Packet::Authenticate(packet))) + .await + .expect("Failed to send Connect Acknowledgement"); + false + }, + }; + let mut takeover_session = self .take_over_existing_client(&connect_packet.client_id, connect_packet.clean_start) .await; @@ -305,8 +392,8 @@ impl Broker { } info!( - "Client ID {} connected (Version: {:?})", - connect_packet.client_id, connect_packet.protocol_version + "Client ID {} connected (Version: {:?}, Authenticated: {})", + connect_packet.client_id, connect_packet.protocol_version, authenticated ); let session_expiry_interval = match connect_packet.session_expiry_interval { @@ -319,56 +406,62 @@ impl Broker { }, None => None, }; + let session_expiry_duration = + session_expiry_interval.map(|i| Duration::from_secs(i.0 as u64)); - let connect_ack = ConnectAckPacket { - // Variable header - session_present, - reason_code: ConnectReason::Success, - - // Properties - session_expiry_interval, - receive_maximum: None, - maximum_qos: None, - retain_available: None, - maximum_packet_size: None, - assigned_client_identifier: Some(AssignedClientIdentifier( - connect_packet.client_id.clone(), - )), - topic_alias_maximum: None, - reason_string: None, - user_properties: vec![], - wildcard_subscription_available: None, - subscription_identifiers_available: None, - shared_subscription_available: None, - server_keep_alive: None, - response_information: None, - server_reference: None, - authentication_method: None, - authentication_data: None, - }; + if authenticated { + // Send conack if the auth is already successful and complete + let connect_ack = ConnectAckPacket { + // Variable header + session_present, + reason_code: ConnectReason::Success, - // A newly connected client should have empty channel and the queuing the connect ack *must* fit in the channel. - client_msg_sender - .send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))) - .await - .expect("Failed to send Connect Acknowledgement"); + // Properties + session_expiry_interval, + receive_maximum: None, + maximum_qos: None, + retain_available: None, + maximum_packet_size: None, + assigned_client_identifier: Some(AssignedClientIdentifier( + connect_packet.client_id.clone(), + )), + topic_alias_maximum: None, + reason_string: None, + user_properties: Vec::with_capacity(0), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }; - let session_expiry_duration = - session_expiry_interval.map(|i| Duration::from_secs(i.0 as u64)); + // A newly connected client should have empty channel and the queuing the connect ack *must* fit in the channel. + client_msg_sender + .send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))) + .await + .expect("Failed to send Connect Acknowledgement"); + } let new_session = if let Some(existing_session) = takeover_session { let mut new_session = existing_session.into_new_session( + authenticated, connect_packet.protocol_version, connect_packet.will, session_expiry_duration, client_msg_sender, ); - new_session.resend_packets().await; + if authenticated { + new_session.resend_packets().await; + } new_session } else { Session::new( + authenticated, connect_packet.protocol_version, connect_packet.will, session_expiry_duration, @@ -379,10 +472,73 @@ impl Broker { self.sessions.insert(connect_packet.client_id, new_session); } + async fn handle_authenticate(&mut self, client_id: String, packet: AuthenticatePacket) { + if let Some(session) = self.sessions.get_mut(&client_id) { + debug!("Trying to authenticate client {}", client_id); + let authentification = self.plugin.on_authenticate(&packet); + + let reason_code = match authentification { + AuthentificationResult::Reason(ConnectReason::Success) => { + info!("Authentification successful for client {}", client_id); + session.authenticated = true; + ConnectReason::Success + }, + AuthentificationResult::Reason(reason_code) => { + info!("Authentification result for client {}: {:?}", client_id, reason_code); + reason_code + }, + AuthentificationResult::Packet(packet) => { + session.send(ClientMessage::Packet(Packet::Authenticate(packet))).await; + return; + }, + }; + + // TODO: fill in correct values + let connect_ack = ConnectAckPacket { + session_present: false, + reason_code, + session_expiry_interval: None, + receive_maximum: None, + maximum_qos: None, + retain_available: None, + maximum_packet_size: None, + assigned_client_identifier: None, + topic_alias_maximum: None, + reason_string: None, + user_properties: Vec::with_capacity(0), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }; + debug!("Sending CONNACK to client {} with reson code {:?}", client_id, reason_code); + session.send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))).await; + + if reason_code != ConnectReason::Success { + session.resend_packets().await; + + debug!( + "Sending DISCONNECT to client {} with disconnect reason {:?}", + client_id, + DisconnectReason::NotAuthorized + ); + session.send(ClientMessage::Disconnect(DisconnectReason::NotAuthorized)).await; + } + } + } + async fn handle_subscribe(&mut self, client_id: String, packet: SubscribePacket) { let subscriptions = &mut self.subscriptions; if let Some(session) = self.sessions.get_mut(&client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } + let plugin_ack = self.plugin.on_subscribe(&packet); // If a Server receives a SUBSCRIBE packet containing a Topic Filter that @@ -449,6 +605,10 @@ impl Broker { let subscriptions = &mut self.subscriptions; if let Some(session) = self.sessions.get_mut(&client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } + for filter in &packet.topic_filters { // Unsubscribe the old session from all topics it subscribed to. session.subscription_tokens.retain(|(session_topic, token)| { @@ -495,7 +655,11 @@ impl Broker { let mut session_expiry_duration = None; if let Entry::Occupied(mut session_entry) = self.sessions.entry(client_id.clone()) { - session_entry.get_mut().client_sender.take(); + { + let session = session_entry.get_mut(); + session.client_sender.take(); + session.authenticated = false; + } if let Some(expiry_interval) = session_entry.get().session_expiry_interval { // The Will Message MUST be published after the Network Connection is subsequently @@ -535,7 +699,7 @@ impl Broker { // Spawn a task that publishes the will after `will_send_delay_duration` tokio::spawn(async move { - tokio::time::sleep(will_send_delay_duration).await; + time::sleep(will_send_delay_duration).await; broker_sender .send(BrokerMessage::PublishFinalWill(client_id, will)) .await @@ -554,6 +718,9 @@ impl Broker { for session_subscription in self.subscriptions.matching_subscribers(topic) { if let Some(session) = sessions.get_mut(&session_subscription.client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } let outgoing_packet_id = match session_subscription.maximum_qos { QoS::AtLeastOnce | QoS::ExactlyOnce => { Some(session.store_outgoing_publish( @@ -577,14 +744,17 @@ impl Broker { } async fn handle_publish(&mut self, client_id: String, packet: PublishPacket) { - match packet.qos { - QoS::AtMostOnce => { - if self.plugin.on_publish_received_qos0(&packet) { - self.publish_message(packet).await; - } - }, - QoS::AtLeastOnce => { - if let Some(session) = self.sessions.get_mut(&client_id) { + if let Some(session) = self.sessions.get_mut(&client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } + match packet.qos { + QoS::AtMostOnce => { + if self.plugin.on_publish_received_qos0(&packet) { + self.publish_message(packet).await; + } + }, + QoS::AtLeastOnce => { let (publish, publish_ack) = self.plugin.on_publish_received_qos1(&packet); if let Some(publish_ack) = publish_ack { session.send(ClientMessage::Packet(Packet::PublishAck(publish_ack))).await; @@ -592,13 +762,11 @@ impl Broker { if publish { self.publish_message(packet).await; } - } - }, - // For QoS2, ensure this packet isn't delivered twice. So if we have an outgoing - // publish receive with the same ID, just send the publish receive again but don't forward - // the message. - QoS::ExactlyOnce => { - if let Some(session) = self.sessions.get_mut(&client_id) { + }, + // For QoS2, ensure this packet isn't delivered twice. So if we have an outgoing + // publish receive with the same ID, just send the publish receive again but don't forward + // the message. + QoS::ExactlyOnce => { let (mut publish, publish_rec) = self.plugin.on_publish_received_qos2(&packet); if let Some(publish_recv) = publish_rec { @@ -619,19 +787,25 @@ impl Broker { if publish { self.publish_message(packet).await; } - } - }, + }, + } } } fn handle_publish_ack(&mut self, client_id: String, packet: PublishAckPacket) { if let Some(session) = self.sessions.get_mut(&client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } session.remove_outgoing_publish(packet.packet_id); } } async fn handle_publish_release(&mut self, client_id: String, packet: PublishReleasePacket) { if let Some(session) = self.sessions.get_mut(&client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } if let Some(pos) = session.outgoing_publish_receives.iter().position(|x| *x == packet.packet_id) { @@ -651,6 +825,9 @@ impl Broker { async fn handle_publish_received(&mut self, client_id: String, packet: PublishReceivedPacket) { if let Some(session) = self.sessions.get_mut(&client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } if let Some(pos) = session.outgoing_packets.iter().position(|p| { p.qos == QoS::ExactlyOnce && p.packet_id.map(|id| id == packet.packet_id).unwrap_or(false) @@ -674,6 +851,9 @@ impl Broker { fn handle_publish_complete(&mut self, client_id: String, packet: PublishCompletePacket) { if let Some(session) = self.sessions.get_mut(&client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } if let Some(pos) = session.outgoing_publish_released.iter().position(|x| *x == packet.packet_id) { @@ -684,6 +864,9 @@ impl Broker { async fn publish_final_will(&mut self, client_id: String, final_will: FinalWill) { if let Some(session) = self.sessions.get_mut(&client_id) { + if !session.authenticated { + todo!("Session is not authenticated"); + } if session.client_sender.is_some() { // They've reconnected, don't send out the will message. self.publish_message(final_will.into()).await; @@ -696,12 +879,21 @@ impl Broker { } } + fn stats(&self) { + trace!("sessions: {:#?}", self.sessions); + trace!("subscriptions: {:#?}", self.subscriptions); + } + pub async fn run(mut self) { while let Some(msg) = self.receiver.recv().await { match msg { + BrokerMessage::Stats => self.stats(), BrokerMessage::NewClient(connect_packet, client_msg_sender) => { self.handle_new_client(*connect_packet, client_msg_sender).await; }, + BrokerMessage::Authenticate(client_id, packet) => { + self.handle_authenticate(client_id, packet).await; + }, BrokerMessage::Subscribe(client_id, packet) => { self.handle_subscribe(client_id, packet).await; }, @@ -729,9 +921,6 @@ impl Broker { BrokerMessage::PublishFinalWill(client_id, final_will) => { self.publish_final_will(client_id, final_will).await; }, - BrokerMessage::Authenticate(client_id, _) => { - warn!("Ignoring unexpected authentication message from {}", client_id); - }, } } } @@ -771,8 +960,8 @@ mod tests { client_id: "TEST".to_string(), will: None, - user_name: None, - password: None, + user_name: Some("test".into()), + password: Some("test".into()), }; let _ = broker_tx diff --git a/mqtt-v5-broker/src/main.rs b/mqtt-v5-broker/src/main.rs index e9125c5..6c73a8f 100644 --- a/mqtt-v5-broker/src/main.rs +++ b/mqtt-v5-broker/src/main.rs @@ -49,7 +49,7 @@ fn init_logging() { async fn main() -> Result<(), Box> { init_logging(); - let broker = Broker::with_plugin(Noop); + let broker = Broker::default(); let broker_tx = broker.sender(); let broker = task::spawn(async { broker.run().await; diff --git a/mqtt-v5-broker/src/plugin.rs b/mqtt-v5-broker/src/plugin.rs index 432050d..1a3c4ce 100644 --- a/mqtt-v5-broker/src/plugin.rs +++ b/mqtt-v5-broker/src/plugin.rs @@ -1,17 +1,16 @@ use log::{trace, warn}; use mqtt_v5::types::{ - properties::{AuthenticationData, AuthenticationMethod}, - *, + AuthenticatePacket, ConnectPacket, ConnectReason, PublishAckPacket, PublishAckReason, + PublishPacket, PublishReceivedPacket, PublishReceivedReason, QoS, SubscribeAckPacket, + SubscribeAckReason, SubscribePacket, }; pub struct Noop; /// Result of a authentication attempt pub enum AuthentificationResult { - /// Authentification is successful. - Success, - /// Authentification failed. Send connect reason code to the client (ConnectAck) - Fail(ConnectReason), + /// Authentification reason + Reason(ConnectReason), /// Send this auth packet to the client and wait for the response. Packet(AuthenticatePacket), } @@ -19,11 +18,7 @@ pub enum AuthentificationResult { /// Broker plugin pub trait Plugin { /// Called on connect packet reception - fn on_connect( - &mut self, - method: Option<&AuthenticationMethod>, - data: Option<&AuthenticationData>, - ) -> AuthentificationResult; + fn on_connect(&mut self, packet: &ConnectPacket) -> AuthentificationResult; /// Called on authenticate packet reception fn on_authenticate(&mut self, packet: &AuthenticatePacket) -> AuthentificationResult; @@ -51,16 +46,17 @@ pub trait Plugin { /// Default noop authenticator impl Plugin for Noop { - fn on_connect( - &mut self, - _: Option<&AuthenticationMethod>, - _: Option<&AuthenticationData>, - ) -> AuthentificationResult { - AuthentificationResult::Success + fn on_connect(&mut self, packet: &ConnectPacket) -> AuthentificationResult { + // Just a hacky test... + if packet.user_name.is_some() && packet.user_name == packet.password { + AuthentificationResult::Reason(ConnectReason::Success) + } else { + AuthentificationResult::Reason(ConnectReason::BadUserNameOrPassword) + } } fn on_authenticate(&mut self, _: &AuthenticatePacket) -> AuthentificationResult { - AuthentificationResult::Success + AuthentificationResult::Reason(ConnectReason::Success) } fn on_subscribe(&mut self, packet: &SubscribePacket) -> SubscribeAckPacket { From e9d975e192962953cc2fa924a7264cf0d61676fe Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Mon, 23 May 2022 12:17:58 +0200 Subject: [PATCH 4/8] Remove debug stats --- mqtt-v5-broker/src/broker.rs | 23 ++--------------------- mqtt-v5-broker/src/main.rs | 1 - 2 files changed, 2 insertions(+), 22 deletions(-) diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index cde6ba3..f3cc320 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -3,7 +3,7 @@ use crate::{ plugin::{AuthentificationResult, Noop, Plugin}, tree::SubscriptionTree, }; -use log::{debug, info, trace, warn}; +use log::{debug, info, warn}; use mqtt_v5::{ topic::TopicFilter, types::{ @@ -21,7 +21,7 @@ use std::{ }; use tokio::{ sync::mpsc::{self, Receiver, Sender}, - task, time, + time, }; #[derive(Debug)] @@ -193,7 +193,6 @@ pub enum WillDisconnectLogic { #[derive(Debug)] pub enum BrokerMessage { - Stats, NewClient(Box, Sender), Authenticate(String, AuthenticatePacket), Publish(String, Box), @@ -240,18 +239,6 @@ impl Broker { pub fn with_plugin(plugin: A) -> Broker { let (sender, receiver) = mpsc::channel(100); - { - let sender = sender.clone(); - task::spawn(async move { - loop { - time::sleep(time::Duration::from_secs(5)).await; - if sender.send(BrokerMessage::Stats).await.is_err() { - break; - } - } - }); - } - Broker { sessions: HashMap::new(), sender, @@ -879,15 +866,9 @@ impl Broker { } } - fn stats(&self) { - trace!("sessions: {:#?}", self.sessions); - trace!("subscriptions: {:#?}", self.subscriptions); - } - pub async fn run(mut self) { while let Some(msg) = self.receiver.recv().await { match msg { - BrokerMessage::Stats => self.stats(), BrokerMessage::NewClient(connect_packet, client_msg_sender) => { self.handle_new_client(*connect_packet, client_msg_sender).await; }, diff --git a/mqtt-v5-broker/src/main.rs b/mqtt-v5-broker/src/main.rs index 6c73a8f..42a9561 100644 --- a/mqtt-v5-broker/src/main.rs +++ b/mqtt-v5-broker/src/main.rs @@ -5,7 +5,6 @@ use log::{debug, info}; use mqtt_v5_broker::{ broker::{Broker, BrokerMessage}, client, - plugin::Noop, }; use tokio::{net::TcpListener, sync::mpsc::Sender, task}; From d921ec3f90ff1fc958e27309dbeb711471f8363e Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Wed, 1 Jun 2022 12:08:57 +0200 Subject: [PATCH 5/8] Add client disconnect callback --- mqtt-v5-broker/src/broker.rs | 2 ++ mqtt-v5-broker/src/plugin.rs | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index f3cc320..3ae5f32 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -638,6 +638,8 @@ impl Broker { fn handle_disconnect(&mut self, client_id: String, will_disconnect_logic: WillDisconnectLogic) { info!("Client ID {} disconnected", client_id); + self.plugin.on_disconnect(&client_id); + let mut disconnect_will = None; let mut session_expiry_duration = None; diff --git a/mqtt-v5-broker/src/plugin.rs b/mqtt-v5-broker/src/plugin.rs index 1a3c4ce..d0a9a09 100644 --- a/mqtt-v5-broker/src/plugin.rs +++ b/mqtt-v5-broker/src/plugin.rs @@ -20,6 +20,9 @@ pub trait Plugin { /// Called on connect packet reception fn on_connect(&mut self, packet: &ConnectPacket) -> AuthentificationResult; + /// Called on client disconnect + fn on_disconnect(&mut self, client_id: &str); + /// Called on authenticate packet reception fn on_authenticate(&mut self, packet: &AuthenticatePacket) -> AuthentificationResult; @@ -55,6 +58,8 @@ impl Plugin for Noop { } } + fn on_disconnect(&mut self, _: &str) {} + fn on_authenticate(&mut self, _: &AuthenticatePacket) -> AuthentificationResult { AuthentificationResult::Reason(ConnectReason::Success) } From 58c9a99adcf72e5b94ba15b7b933c3f3e8633d7c Mon Sep 17 00:00:00 2001 From: Felix Obenhuber Date: Thu, 2 Jun 2022 11:17:12 +0200 Subject: [PATCH 6/8] Refactor authentification handling Collect unauthorized connections seperated from the session until the auth is complete. Introduce a unique identifier for connections. --- mqtt-v5-broker/src/broker.rs | 524 ++++++++++++++++++++++------------- mqtt-v5-broker/src/client.rs | 85 ++++-- 2 files changed, 397 insertions(+), 212 deletions(-) diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index 3ae5f32..3709aa1 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -24,6 +24,26 @@ use tokio::{ time, }; +/// Client connected but not yet authenticated. +struct UnauthenticatedSession { + connect_packet: ConnectPacket, + client_sender: Sender, +} + +impl UnauthenticatedSession { + /// Construct a new UnauthenticatedSession. + fn new(connect_packet: ConnectPacket, client_sender: Sender) -> Self { + Self { connect_packet, client_sender } + } + + /// Send a `ClientMessage` to the client via the channel handle. + /// This is a fire and forget operation because upon any error on the + /// connection the `Client` will send a `BrokerMessage::Disconnect`. + async fn send(&self, message: ClientMessage) { + drop(self.client_sender.send(message).await); + } +} + #[derive(Debug)] struct Session { #[allow(unused)] @@ -32,9 +52,6 @@ struct Session { // pub shared_subscriptions: HashSet, pub client_sender: Option>, - // Authentication state - pub authenticated: bool, - // Used to unsubscribe from topics subscription_tokens: Vec<(TopicFilter, u64)>, @@ -58,14 +75,12 @@ struct Session { impl Session { pub fn new( - authenticated: bool, protocol_version: ProtocolVersion, will: Option, session_expiry_interval: Option, client_sender: Sender, ) -> Self { Self { - authenticated, protocol_version, // subscriptions: HashSet::new(), // shared_subscriptions: HashSet::new(), @@ -85,14 +100,12 @@ impl Session { /// for a client which just connected. pub fn into_new_session( self, - authenticated: bool, protocol_version: ProtocolVersion, will: Option, session_expiry_interval: Option, client_sender: Sender, ) -> Self { Self { - authenticated, protocol_version, client_sender: Some(client_sender), session_expiry_interval, @@ -191,22 +204,31 @@ pub enum WillDisconnectLogic { DoNotSend, } +/// Unique identifier for a connection +pub type ConnectionId = String; + +/// Client ID +pub type ClientId = String; + #[derive(Debug)] pub enum BrokerMessage { - NewClient(Box, Sender), - Authenticate(String, AuthenticatePacket), - Publish(String, Box), - PublishAck(String, PublishAckPacket), // TODO - This can be handled by the client task - PublishRelease(String, PublishReleasePacket), // TODO - This can be handled by the client task - PublishReceived(String, PublishReceivedPacket), - PublishComplete(String, PublishCompletePacket), - PublishFinalWill(String, FinalWill), - Subscribe(String, SubscribePacket), // TODO - replace string client_id with int - Unsubscribe(String, UnsubscribePacket), // TODO - replace string client_id with int - Disconnect(String, WillDisconnectLogic), + Connect(ConnectionId, Box, Sender), + Disconnect(ConnectionId, String, WillDisconnectLogic), + Authenticate(ConnectionId, ClientId, AuthenticatePacket), + Publish(ConnectionId, ClientId, Box), + PublishAck(ConnectionId, ClientId, PublishAckPacket), // TODO - This can be handled by the client task + PublishRelease(ConnectionId, ClientId, PublishReleasePacket), // TODO - This can be handled by the client task + PublishReceived(ConnectionId, ClientId, PublishReceivedPacket), + PublishComplete(ConnectionId, ClientId, PublishCompletePacket), + PublishFinalWill(ConnectionId, ClientId, FinalWill), + Subscribe(ConnectionId, ClientId, SubscribePacket), // TODO - replace string client_id with int + Unsubscribe(ConnectionId, ClientId, UnsubscribePacket), // TODO - replace string client_id with int } pub struct Broker { + /// A map of client IDs to unauthenticated sessions. Once a client passes + /// authentication, the session exended is moved to the `session` map. + unauthenticated_sessions: HashMap, sessions: HashMap, sender: Sender, receiver: Receiver, @@ -227,6 +249,7 @@ impl Broker { let (sender, receiver) = mpsc::channel(100); Broker { + unauthenticated_sessions: HashMap::new(), sessions: HashMap::new(), sender, receiver, @@ -240,6 +263,7 @@ impl Broker { let (sender, receiver) = mpsc::channel(100); Broker { + unauthenticated_sessions: HashMap::new(), sessions: HashMap::new(), sender, receiver, @@ -294,33 +318,35 @@ impl Broker { // If the Server accepts a connection with Clean Start set to 0 and the Server has Session State for the ClientID, // it MUST set Session Present to 1 in the CONNACK packet, otherwise it MUST set Session Present to 0 in the CONNACK packet. if new_client_clean_start { - return None; + None + } else { + existing_session } - - existing_session } async fn handle_new_client( &mut self, + connection_id: ConnectionId, connect_packet: ConnectPacket, client_msg_sender: Sender, ) { - debug!("Trying to authenticate client {}", connect_packet.client_id); - let authenticated = match self.plugin.on_connect(&connect_packet) { + debug!( + "Trying to authenticate client {} (connection {})", + connect_packet.client_id, connection_id + ); + match self.plugin.on_connect(&connect_packet) { AuthentificationResult::Reason(ConnectReason::Success) => { info!("Authentification successful for client {}", connect_packet.client_id); - true + self.handle_authenticated_client(connect_packet, client_msg_sender).await; }, AuthentificationResult::Reason(reason_code) => { info!( - "Authentification result for client {}: {:?}", + "Authentification reason code for client {} is {:?}", connect_packet.client_id, reason_code ); let connect_ack = ConnectAckPacket { - // TODO: is this correct? session_present: false, reason_code, - // TODO: is this correct? session_expiry_interval: None, receive_maximum: None, maximum_qos: None, @@ -339,48 +365,60 @@ impl Broker { authentication_method: None, authentication_data: None, }; + + // Send a disconnect packet to the client. Ignore send errors because + // the client could already be disconnected and the rx handle of this + // channel is dropped. debug!( - "Sending CONNACK to client {} with reson code {:?}", + "Sending CONNACK to client ID {} with reason code {:?}", connect_packet.client_id, reason_code ); client_msg_sender .send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))) .await - .expect("Failed to send Connect Acknowledgement"); + .ok(); debug!( - "Sending DISCONNECT to client {} with disconnect reason {:?}", + "Sending DISCONNECT to client {} with disconnect reason code {:?}", connect_packet.client_id, DisconnectReason::NotAuthorized ); client_msg_sender .send(ClientMessage::Disconnect(DisconnectReason::NotAuthorized)) .await - .expect("Failed to send Disconnect"); - return; + .ok(); }, AuthentificationResult::Packet(packet) => { client_msg_sender .send(ClientMessage::Packet(Packet::Authenticate(packet))) .await - .expect("Failed to send Connect Acknowledgement"); - false + .ok(); + let client_id = connect_packet.client_id.clone(); + info!("Adding unauthenticated session for client ID {}", client_id); + let unauthenticated_session = + UnauthenticatedSession::new(connect_packet, client_msg_sender); + self.unauthenticated_sessions.insert(connection_id, unauthenticated_session); }, - }; + } + } + async fn handle_authenticated_client( + &mut self, + connect_packet: ConnectPacket, + client_msg_sender: Sender, + ) { let mut takeover_session = self .take_over_existing_client(&connect_packet.client_id, connect_packet.clean_start) .await; let session_present = takeover_session.is_some(); - // TODO(flxo) Calling `resend_packets` after `take_over_existing_client` feels strange since a disconnect is sent in there. if let Some(existing_session) = &mut takeover_session { existing_session.resend_packets().await; } info!( - "Client ID {} connected (Version: {:?}, Authenticated: {})", - connect_packet.client_id, connect_packet.protocol_version, authenticated + "Client ID {} connected (Version: {:?})", + connect_packet.client_id, connect_packet.protocol_version ); let session_expiry_interval = match connect_packet.session_expiry_interval { @@ -393,62 +431,68 @@ impl Broker { }, None => None, }; - let session_expiry_duration = - session_expiry_interval.map(|i| Duration::from_secs(i.0 as u64)); - - if authenticated { - // Send conack if the auth is already successful and complete - let connect_ack = ConnectAckPacket { - // Variable header - session_present, - reason_code: ConnectReason::Success, - - // Properties - session_expiry_interval, - receive_maximum: None, - maximum_qos: None, - retain_available: None, - maximum_packet_size: None, - assigned_client_identifier: Some(AssignedClientIdentifier( - connect_packet.client_id.clone(), - )), - topic_alias_maximum: None, - reason_string: None, - user_properties: Vec::with_capacity(0), - wildcard_subscription_available: None, - subscription_identifiers_available: None, - shared_subscription_available: None, - server_keep_alive: None, - response_information: None, - server_reference: None, - authentication_method: None, - authentication_data: None, - }; + let session_expiry_duration = session_expiry_interval.map(|i| { + let duration = Duration::from_secs(i.0 as u64); + debug!( + "Client ID {} Session expiry interval is {:?}", + connect_packet.client_id, duration + ); + duration + }); + + // The user provided plugin decided when processing the Connect packet that this + // client is authenticated. This could be the case for no authentication needed or + // a username password authentication. The broker shall not wait for addition + // `Authentification` packets from the client before the client is in the authenticated + // state. + // If the client is not authenticated the `ConnectAckPacket` is sent once authentification + // succeeds or fails. + // Send conack if the auth is already successful and complete + let connect_ack = ConnectAckPacket { + // Variable header + session_present, + reason_code: ConnectReason::Success, + + // Properties + session_expiry_interval, + receive_maximum: None, + maximum_qos: None, + retain_available: None, + maximum_packet_size: None, + assigned_client_identifier: Some(AssignedClientIdentifier( + connect_packet.client_id.clone(), + )), + topic_alias_maximum: None, + reason_string: None, + user_properties: Vec::with_capacity(0), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }; - // A newly connected client should have empty channel and the queuing the connect ack *must* fit in the channel. - client_msg_sender - .send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))) - .await - .expect("Failed to send Connect Acknowledgement"); - } + // If the client disconnected in the meantime, the rx part of the client handle is dropped + // and a send attempt will fail. Ignore this error, because the disconnection is handled + // by a BrokerMessage::Disconnect. + client_msg_sender.send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))).await.ok(); let new_session = if let Some(existing_session) = takeover_session { let mut new_session = existing_session.into_new_session( - authenticated, connect_packet.protocol_version, connect_packet.will, session_expiry_duration, client_msg_sender, ); - if authenticated { - new_session.resend_packets().await; - } + new_session.resend_packets().await; new_session } else { Session::new( - authenticated, connect_packet.protocol_version, connect_packet.will, session_expiry_duration, @@ -459,73 +503,90 @@ impl Broker { self.sessions.insert(connect_packet.client_id, new_session); } - async fn handle_authenticate(&mut self, client_id: String, packet: AuthenticatePacket) { - if let Some(session) = self.sessions.get_mut(&client_id) { - debug!("Trying to authenticate client {}", client_id); - let authentification = self.plugin.on_authenticate(&packet); - - let reason_code = match authentification { - AuthentificationResult::Reason(ConnectReason::Success) => { - info!("Authentification successful for client {}", client_id); - session.authenticated = true; - ConnectReason::Success - }, - AuthentificationResult::Reason(reason_code) => { - info!("Authentification result for client {}: {:?}", client_id, reason_code); - reason_code - }, - AuthentificationResult::Packet(packet) => { - session.send(ClientMessage::Packet(Packet::Authenticate(packet))).await; - return; - }, - }; + /// Handle authenticate packets. Query the plugin and send a `ConnectAckPacket` or `Authenticate` + /// packet if needed. If the plugin authenticates the client (session) proceed with the session + /// setup etc. and remove the unauthenticated session. + async fn handle_authenticate( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: AuthenticatePacket, + ) { + debug!("Trying to authenticate client {} (connection {})", client_id, connection_id); - // TODO: fill in correct values - let connect_ack = ConnectAckPacket { - session_present: false, - reason_code, - session_expiry_interval: None, - receive_maximum: None, - maximum_qos: None, - retain_available: None, - maximum_packet_size: None, - assigned_client_identifier: None, - topic_alias_maximum: None, - reason_string: None, - user_properties: Vec::with_capacity(0), - wildcard_subscription_available: None, - subscription_identifiers_available: None, - shared_subscription_available: None, - server_keep_alive: None, - response_information: None, - server_reference: None, - authentication_method: None, - authentication_data: None, - }; - debug!("Sending CONNACK to client {} with reson code {:?}", client_id, reason_code); - session.send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))).await; + let entry = match self.unauthenticated_sessions.entry(connection_id) { + Entry::Occupied(entry) => entry, + Entry::Vacant(entry) => { + warn!("Received authenticate packet for unknown client ID {}", entry.key()); + return; + }, + }; + + match self.plugin.on_authenticate(&packet) { + AuthentificationResult::Reason(ConnectReason::Success) => { + let (client_id, UnauthenticatedSession { client_sender, connect_packet }) = + entry.remove_entry(); + info!("Authentification successful for client ID {}", client_id); + self.handle_authenticated_client(connect_packet, client_sender).await; + }, + AuthentificationResult::Reason(reason_code) => { + info!("Authentification result for client ID {} is {:?}", entry.key(), reason_code); + let connect_ack = ConnectAckPacket { + // Variable header + session_present: false, + reason_code, - if reason_code != ConnectReason::Success { - session.resend_packets().await; + // Properties + session_expiry_interval: None, + receive_maximum: None, + maximum_qos: None, + retain_available: None, + maximum_packet_size: None, + assigned_client_identifier: None, + topic_alias_maximum: None, + reason_string: None, + user_properties: Vec::with_capacity(0), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }; - debug!( - "Sending DISCONNECT to client {} with disconnect reason {:?}", - client_id, - DisconnectReason::NotAuthorized - ); + // If the client disconnected in the meantime, the rx part of the client handle is dropped + // and a send attempt will fail. Ignore this error, because the disconnection is handled + // by a BrokerMessage::Disconnect. + let session = entry.get(); + session.send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))).await; session.send(ClientMessage::Disconnect(DisconnectReason::NotAuthorized)).await; - } + }, + AuthentificationResult::Packet(packet) => { + let session = entry.get(); + session.send(ClientMessage::Packet(Packet::Authenticate(packet))).await; + }, } } - async fn handle_subscribe(&mut self, client_id: String, packet: SubscribePacket) { + async fn handle_subscribe( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: SubscribePacket, + ) { + if self.unauthenticated_sessions.contains_key(&connection_id) { + warn!( + "Ignoring subscribe packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + let subscriptions = &mut self.subscriptions; if let Some(session) = self.sessions.get_mut(&client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } - let plugin_ack = self.plugin.on_subscribe(&packet); // If a Server receives a SUBSCRIBE packet containing a Topic Filter that @@ -588,14 +649,23 @@ impl Broker { } } - async fn handle_unsubscribe(&mut self, client_id: String, packet: UnsubscribePacket) { + async fn handle_unsubscribe( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: UnsubscribePacket, + ) { + if self.unauthenticated_sessions.contains_key(&connection_id) { + warn!( + "Ignoring unsubscribe packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + let subscriptions = &mut self.subscriptions; if let Some(session) = self.sessions.get_mut(&client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } - for filter in &packet.topic_filters { // Unsubscribe the old session from all topics it subscribed to. session.subscription_tokens.retain(|(session_topic, token)| { @@ -635,7 +705,17 @@ impl Broker { } } - fn handle_disconnect(&mut self, client_id: String, will_disconnect_logic: WillDisconnectLogic) { + fn handle_disconnect( + &mut self, + connection_id: ConnectionId, + client_id: String, + will_disconnect_logic: WillDisconnectLogic, + ) { + if self.unauthenticated_sessions.remove(&connection_id).is_some() { + info!("Removing unauthenticated session for client ID {}", client_id); + return; + } + info!("Client ID {} disconnected", client_id); self.plugin.on_disconnect(&client_id); @@ -647,7 +727,6 @@ impl Broker { { let session = session_entry.get_mut(); session.client_sender.take(); - session.authenticated = false; } if let Some(expiry_interval) = session_entry.get().session_expiry_interval { @@ -690,7 +769,11 @@ impl Broker { tokio::spawn(async move { time::sleep(will_send_delay_duration).await; broker_sender - .send(BrokerMessage::PublishFinalWill(client_id, will)) + .send(BrokerMessage::PublishFinalWill( + connection_id, + client_id, + will, + )) .await .expect("Failed to send final will message to broker"); }); @@ -707,9 +790,6 @@ impl Broker { for session_subscription in self.subscriptions.matching_subscribers(topic) { if let Some(session) = sessions.get_mut(&session_subscription.client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } let outgoing_packet_id = match session_subscription.maximum_qos { QoS::AtLeastOnce | QoS::ExactlyOnce => { Some(session.store_outgoing_publish( @@ -732,11 +812,21 @@ impl Broker { } } - async fn handle_publish(&mut self, client_id: String, packet: PublishPacket) { + async fn handle_publish( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishPacket, + ) { + if self.unauthenticated_sessions.contains_key(&connection_id) { + warn!( + "Discarding publish packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } match packet.qos { QoS::AtMostOnce => { if self.plugin.on_publish_received_qos0(&packet) { @@ -781,20 +871,37 @@ impl Broker { } } - fn handle_publish_ack(&mut self, client_id: String, packet: PublishAckPacket) { + fn handle_publish_ack( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishAckPacket, + ) { + if self.unauthenticated_sessions.contains_key(&connection_id) { + warn!( + "Discarding publish ack packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } session.remove_outgoing_publish(packet.packet_id); } } - async fn handle_publish_release(&mut self, client_id: String, packet: PublishReleasePacket) { + async fn handle_publish_release( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishReleasePacket, + ) { + if self.unauthenticated_sessions.contains_key(&connection_id) { + warn!("Discarding publish release packet from unauthenticated client ID {} on connection {}", client_id, connection_id); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } if let Some(pos) = session.outgoing_publish_receives.iter().position(|x| *x == packet.packet_id) { @@ -812,11 +919,18 @@ impl Broker { } } - async fn handle_publish_received(&mut self, client_id: String, packet: PublishReceivedPacket) { + async fn handle_publish_received( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishReceivedPacket, + ) { + if self.unauthenticated_sessions.contains_key(&connection_id) { + warn!("Discarding publish received packet from unauthenticated client ID {} on connection {}", client_id, connection_id); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } if let Some(pos) = session.outgoing_packets.iter().position(|p| { p.qos == QoS::ExactlyOnce && p.packet_id.map(|id| id == packet.packet_id).unwrap_or(false) @@ -838,11 +952,18 @@ impl Broker { } } - fn handle_publish_complete(&mut self, client_id: String, packet: PublishCompletePacket) { + fn handle_publish_complete( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishCompletePacket, + ) { + if self.unauthenticated_sessions.contains_key(&connection_id) { + warn!("Discarding publish complete packet from unauthenticated client ID {} on connection {}", client_id, connection_id); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } if let Some(pos) = session.outgoing_publish_released.iter().position(|x| *x == packet.packet_id) { @@ -851,11 +972,21 @@ impl Broker { } } - async fn publish_final_will(&mut self, client_id: String, final_will: FinalWill) { + async fn publish_final_will( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + final_will: FinalWill, + ) { + if self.unauthenticated_sessions.contains_key(&connection_id) { + warn!( + "Discarding final will packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { - if !session.authenticated { - todo!("Session is not authenticated"); - } if session.client_sender.is_some() { // They've reconnected, don't send out the will message. self.publish_message(final_will.into()).await; @@ -871,38 +1002,38 @@ impl Broker { pub async fn run(mut self) { while let Some(msg) = self.receiver.recv().await { match msg { - BrokerMessage::NewClient(connect_packet, client_msg_sender) => { - self.handle_new_client(*connect_packet, client_msg_sender).await; + BrokerMessage::Connect(connection_id, connect_packet, client_msg_sender) => { + self.handle_new_client(connection_id, *connect_packet, client_msg_sender).await; }, - BrokerMessage::Authenticate(client_id, packet) => { - self.handle_authenticate(client_id, packet).await; + BrokerMessage::Disconnect(connection_id, client_id, will_disconnect_logic) => { + self.handle_disconnect(connection_id, client_id, will_disconnect_logic); }, - BrokerMessage::Subscribe(client_id, packet) => { - self.handle_subscribe(client_id, packet).await; + BrokerMessage::Authenticate(connection_id, client_id, packet) => { + self.handle_authenticate(connection_id, client_id, packet).await; }, - BrokerMessage::Unsubscribe(client_id, packet) => { - self.handle_unsubscribe(client_id, packet).await; + BrokerMessage::Subscribe(connection_id, client_id, packet) => { + self.handle_subscribe(connection_id, client_id, packet).await; }, - BrokerMessage::Disconnect(client_id, will_disconnect_logic) => { - self.handle_disconnect(client_id, will_disconnect_logic); + BrokerMessage::Unsubscribe(conneciton_id, client_id, packet) => { + self.handle_unsubscribe(conneciton_id, client_id, packet).await; }, - BrokerMessage::Publish(client_id, packet) => { - self.handle_publish(client_id, *packet).await; + BrokerMessage::Publish(connection_id, client_id, packet) => { + self.handle_publish(connection_id, client_id, *packet).await; }, - BrokerMessage::PublishAck(client_id, packet) => { - self.handle_publish_ack(client_id, packet); + BrokerMessage::PublishAck(connection_id, client_id, packet) => { + self.handle_publish_ack(connection_id, client_id, packet); }, - BrokerMessage::PublishRelease(client_id, packet) => { - self.handle_publish_release(client_id, packet).await; + BrokerMessage::PublishRelease(connection_id, client_id, packet) => { + self.handle_publish_release(connection_id, client_id, packet).await; }, - BrokerMessage::PublishReceived(client_id, packet) => { - self.handle_publish_received(client_id, packet).await; + BrokerMessage::PublishReceived(connection_id, client_id, packet) => { + self.handle_publish_received(connection_id, client_id, packet).await; }, - BrokerMessage::PublishComplete(client_id, packet) => { - self.handle_publish_complete(client_id, packet); + BrokerMessage::PublishComplete(connection_id, client_id, packet) => { + self.handle_publish_complete(connection_id, client_id, packet); }, - BrokerMessage::PublishFinalWill(client_id, final_will) => { - self.publish_final_will(client_id, final_will).await; + BrokerMessage::PublishFinalWill(connection_id, client_id, final_will) => { + self.publish_final_will(connection_id, client_id, final_will).await; }, } } @@ -947,8 +1078,10 @@ mod tests { password: Some("test".into()), }; + let connection_id = nanoid::nanoid!(); + let _ = broker_tx - .send(BrokerMessage::NewClient(Box::new(connect_packet), sender)) + .send(BrokerMessage::Connect(connection_id, Box::new(connect_packet), sender)) .await .unwrap(); @@ -982,6 +1115,7 @@ mod tests { let _ = broker_tx .send(BrokerMessage::Subscribe( + "CONNECTION".to_string(), "TEST".to_string(), SubscribePacket { packet_id: 0, diff --git a/mqtt-v5-broker/src/client.rs b/mqtt-v5-broker/src/client.rs index 9eb11a2..033d0a6 100644 --- a/mqtt-v5-broker/src/client.rs +++ b/mqtt-v5-broker/src/client.rs @@ -1,4 +1,4 @@ -use crate::broker::{BrokerMessage, WillDisconnectLogic}; +use crate::broker::{BrokerMessage, ConnectionId, WillDisconnectLogic}; use futures::{ future::{self, Either}, stream, Sink, SinkExt, Stream, StreamExt, @@ -175,6 +175,7 @@ where } struct UnconnectedClient, SI: Sink> { + connection_id: ConnectionId, packet_stream: ST, packet_sink: SI, broker_tx: Sender, @@ -184,7 +185,8 @@ impl + Unpin, SI: Sink { pub fn new(packet_stream: ST, packet_sink: SI, broker_tx: Sender) -> Self { - Self { packet_stream, packet_sink, broker_tx } + let connection_id = nanoid!(); + Self { connection_id, packet_stream, packet_sink, broker_tx } } pub async fn handshake(mut self) -> Result, ProtocolError> { @@ -223,11 +225,18 @@ impl + Unpin, SI: Sink, SI: Sink> { - id: String, + connection_id: ConnectionId, + client_id: String, _protocol_version: ProtocolVersion, keepalive_seconds: Option, packet_stream: ST, @@ -272,7 +282,8 @@ impl + Unpin, SI: Sink, packet_stream: ST, @@ -282,7 +293,8 @@ impl + Unpin, SI: Sink, ) -> Self { Self { - id, + connection_id, + client_id, _protocol_version: protocol_version, keepalive_seconds, packet_stream, @@ -294,6 +306,7 @@ impl + Unpin, SI: Sink, @@ -329,13 +342,21 @@ impl + Unpin, SI: Sink match frame { Packet::Subscribe(packet) => { broker_tx - .send(BrokerMessage::Subscribe(client_id.clone(), packet)) + .send(BrokerMessage::Subscribe( + connection_id.clone(), + client_id.clone(), + packet, + )) .await .expect("Couldn't send Subscribe message to broker"); }, Packet::Unsubscribe(packet) => { broker_tx - .send(BrokerMessage::Unsubscribe(client_id.clone(), packet)) + .send(BrokerMessage::Unsubscribe( + connection_id.clone(), + client_id.clone(), + packet, + )) .await .expect("Couldn't send Unsubscribe message to broker"); }, @@ -351,31 +372,51 @@ impl + Unpin, SI: Sink { broker_tx - .send(BrokerMessage::PublishAck(client_id.clone(), packet)) + .send(BrokerMessage::PublishAck( + connection_id.clone(), + client_id.clone(), + packet, + )) .await .expect("Couldn't send PublishAck message to broker"); }, Packet::PublishRelease(packet) => { broker_tx - .send(BrokerMessage::PublishRelease(client_id.clone(), packet)) + .send(BrokerMessage::PublishRelease( + connection_id.clone(), + client_id.clone(), + packet, + )) .await .expect("Couldn't send PublishRelease message to broker"); }, Packet::PublishReceived(packet) => { broker_tx - .send(BrokerMessage::PublishReceived(client_id.clone(), packet)) + .send(BrokerMessage::PublishReceived( + connection_id.clone(), + client_id.clone(), + packet, + )) .await .expect("Couldn't send PublishReceive message to broker"); }, Packet::PublishComplete(packet) => { broker_tx - .send(BrokerMessage::PublishComplete(client_id.clone(), packet)) + .send(BrokerMessage::PublishComplete( + connection_id.clone(), + client_id.clone(), + packet, + )) .await .expect("Couldn't send PublishCompelte message to broker"); }, @@ -395,6 +436,7 @@ impl + Unpin, SI: Sink + Unpin, SI: Sink { broker_tx - .send(BrokerMessage::Authenticate(client_id.clone(), packet)) + .send(BrokerMessage::Authenticate( + connection_id.clone(), + client_id.clone(), + packet, + )) .await .expect("Couldn't send Authentivate message to self"); }, @@ -422,7 +468,11 @@ impl + Unpin, SI: Sink + Unpin, SI: Sink + Unpin, SI: Sink Date: Thu, 2 Jun 2022 11:30:42 +0200 Subject: [PATCH 7/8] Optimize connection id Replace String connection id with cheaper u64. The connection id is a atomically inremented counter in the client handler. --- mqtt-v5-broker/src/broker.rs | 12 +++++------- mqtt-v5-broker/src/client.rs | 32 ++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index 3709aa1..7b9194b 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -205,7 +205,7 @@ pub enum WillDisconnectLogic { } /// Unique identifier for a connection -pub type ConnectionId = String; +pub type ConnectionId = u64; /// Client ID pub type ClientId = String; @@ -1014,8 +1014,8 @@ impl Broker { BrokerMessage::Subscribe(connection_id, client_id, packet) => { self.handle_subscribe(connection_id, client_id, packet).await; }, - BrokerMessage::Unsubscribe(conneciton_id, client_id, packet) => { - self.handle_unsubscribe(conneciton_id, client_id, packet).await; + BrokerMessage::Unsubscribe(connection_id, client_id, packet) => { + self.handle_unsubscribe(connection_id, client_id, packet).await; }, BrokerMessage::Publish(connection_id, client_id, packet) => { self.handle_publish(connection_id, client_id, *packet).await; @@ -1078,10 +1078,8 @@ mod tests { password: Some("test".into()), }; - let connection_id = nanoid::nanoid!(); - let _ = broker_tx - .send(BrokerMessage::Connect(connection_id, Box::new(connect_packet), sender)) + .send(BrokerMessage::Connect(0, Box::new(connect_packet), sender)) .await .unwrap(); @@ -1115,7 +1113,7 @@ mod tests { let _ = broker_tx .send(BrokerMessage::Subscribe( - "CONNECTION".to_string(), + 0, "TEST".to_string(), SubscribePacket { packet_id: 0, diff --git a/mqtt-v5-broker/src/client.rs b/mqtt-v5-broker/src/client.rs index 033d0a6..4988fc3 100644 --- a/mqtt-v5-broker/src/client.rs +++ b/mqtt-v5-broker/src/client.rs @@ -12,7 +12,7 @@ use mqtt_v5::{ }, }; use nanoid::nanoid; -use std::{marker::Unpin, time::Duration}; +use std::{marker::Unpin, sync::atomic::AtomicU64, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::mpsc::{self, Receiver, Sender}, @@ -28,6 +28,14 @@ use mqtt_v5::{ WsUpgraderCodec, }, }; +use std::sync::atomic::Ordering; + +/// Generate a new unique connection id +fn next_connection_id() -> u64 { + static CONNECTION_ID: AtomicU64 = AtomicU64::new(0); + // This operation wraps around on overflow. + CONNECTION_ID.fetch_add(1, Ordering::Relaxed) +} type PacketResult = Result; @@ -185,7 +193,7 @@ impl + Unpin, SI: Sink { pub fn new(packet_stream: ST, packet_sink: SI, broker_tx: Sender) -> Self { - let connection_id = nanoid!(); + let connection_id = next_connection_id(); Self { connection_id, packet_stream, packet_sink, broker_tx } } @@ -226,7 +234,7 @@ impl + Unpin, SI: Sink + Unpin, SI: Sink { broker_tx .send(BrokerMessage::Subscribe( - connection_id.clone(), + connection_id, client_id.clone(), packet, )) @@ -353,7 +361,7 @@ impl + Unpin, SI: Sink { broker_tx .send(BrokerMessage::Unsubscribe( - connection_id.clone(), + connection_id, client_id.clone(), packet, )) @@ -373,7 +381,7 @@ impl + Unpin, SI: Sink + Unpin, SI: Sink { broker_tx .send(BrokerMessage::PublishAck( - connection_id.clone(), + connection_id, client_id.clone(), packet, )) @@ -393,7 +401,7 @@ impl + Unpin, SI: Sink { broker_tx .send(BrokerMessage::PublishRelease( - connection_id.clone(), + connection_id, client_id.clone(), packet, )) @@ -403,7 +411,7 @@ impl + Unpin, SI: Sink { broker_tx .send(BrokerMessage::PublishReceived( - connection_id.clone(), + connection_id, client_id.clone(), packet, )) @@ -413,7 +421,7 @@ impl + Unpin, SI: Sink { broker_tx .send(BrokerMessage::PublishComplete( - connection_id.clone(), + connection_id, client_id.clone(), packet, )) @@ -448,7 +456,7 @@ impl + Unpin, SI: Sink { broker_tx .send(BrokerMessage::Authenticate( - connection_id.clone(), + connection_id, client_id.clone(), packet, )) @@ -523,7 +531,7 @@ impl + Unpin, SI: Sink Date: Thu, 2 Jun 2022 11:45:04 +0200 Subject: [PATCH 8/8] Rename UnauthenticatedSession to UnauthenticatedConnection --- mqtt-v5-broker/src/broker.rs | 45 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index 7b9194b..68d3420 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -24,13 +24,13 @@ use tokio::{ time, }; -/// Client connected but not yet authenticated. -struct UnauthenticatedSession { +/// A client connected but not yet authenticated. +struct UnauthenticatedConnection { connect_packet: ConnectPacket, client_sender: Sender, } -impl UnauthenticatedSession { +impl UnauthenticatedConnection { /// Construct a new UnauthenticatedSession. fn new(connect_packet: ConnectPacket, client_sender: Sender) -> Self { Self { connect_packet, client_sender } @@ -226,14 +226,13 @@ pub enum BrokerMessage { } pub struct Broker { - /// A map of client IDs to unauthenticated sessions. Once a client passes - /// authentication, the session exended is moved to the `session` map. - unauthenticated_sessions: HashMap, + /// A map of client IDs to unauthenticated connections. Once a client passes + /// authentication, the connection is promoted to a session. + unauthenticated_connections: HashMap, sessions: HashMap, sender: Sender, receiver: Receiver, subscriptions: SubscriptionTree, - #[allow(unused)] plugin: A, } @@ -249,7 +248,7 @@ impl Broker { let (sender, receiver) = mpsc::channel(100); Broker { - unauthenticated_sessions: HashMap::new(), + unauthenticated_connections: HashMap::new(), sessions: HashMap::new(), sender, receiver, @@ -263,7 +262,7 @@ impl Broker { let (sender, receiver) = mpsc::channel(100); Broker { - unauthenticated_sessions: HashMap::new(), + unauthenticated_connections: HashMap::new(), sessions: HashMap::new(), sender, receiver, @@ -394,10 +393,10 @@ impl Broker { .await .ok(); let client_id = connect_packet.client_id.clone(); - info!("Adding unauthenticated session for client ID {}", client_id); + info!("Adding unauthenticated connection for client ID {}", client_id); let unauthenticated_session = - UnauthenticatedSession::new(connect_packet, client_msg_sender); - self.unauthenticated_sessions.insert(connection_id, unauthenticated_session); + UnauthenticatedConnection::new(connect_packet, client_msg_sender); + self.unauthenticated_connections.insert(connection_id, unauthenticated_session); }, } } @@ -514,7 +513,7 @@ impl Broker { ) { debug!("Trying to authenticate client {} (connection {})", client_id, connection_id); - let entry = match self.unauthenticated_sessions.entry(connection_id) { + let entry = match self.unauthenticated_connections.entry(connection_id) { Entry::Occupied(entry) => entry, Entry::Vacant(entry) => { warn!("Received authenticate packet for unknown client ID {}", entry.key()); @@ -524,7 +523,7 @@ impl Broker { match self.plugin.on_authenticate(&packet) { AuthentificationResult::Reason(ConnectReason::Success) => { - let (client_id, UnauthenticatedSession { client_sender, connect_packet }) = + let (client_id, UnauthenticatedConnection { client_sender, connect_packet }) = entry.remove_entry(); info!("Authentification successful for client ID {}", client_id); self.handle_authenticated_client(connect_packet, client_sender).await; @@ -576,7 +575,7 @@ impl Broker { client_id: ClientId, packet: SubscribePacket, ) { - if self.unauthenticated_sessions.contains_key(&connection_id) { + if self.unauthenticated_connections.contains_key(&connection_id) { warn!( "Ignoring subscribe packet from unauthenticated client ID {} on connection {}", client_id, connection_id @@ -655,7 +654,7 @@ impl Broker { client_id: ClientId, packet: UnsubscribePacket, ) { - if self.unauthenticated_sessions.contains_key(&connection_id) { + if self.unauthenticated_connections.contains_key(&connection_id) { warn!( "Ignoring unsubscribe packet from unauthenticated client ID {} on connection {}", client_id, connection_id @@ -711,7 +710,7 @@ impl Broker { client_id: String, will_disconnect_logic: WillDisconnectLogic, ) { - if self.unauthenticated_sessions.remove(&connection_id).is_some() { + if self.unauthenticated_connections.remove(&connection_id).is_some() { info!("Removing unauthenticated session for client ID {}", client_id); return; } @@ -818,7 +817,7 @@ impl Broker { client_id: ClientId, packet: PublishPacket, ) { - if self.unauthenticated_sessions.contains_key(&connection_id) { + if self.unauthenticated_connections.contains_key(&connection_id) { warn!( "Discarding publish packet from unauthenticated client ID {} on connection {}", client_id, connection_id @@ -877,7 +876,7 @@ impl Broker { client_id: ClientId, packet: PublishAckPacket, ) { - if self.unauthenticated_sessions.contains_key(&connection_id) { + if self.unauthenticated_connections.contains_key(&connection_id) { warn!( "Discarding publish ack packet from unauthenticated client ID {} on connection {}", client_id, connection_id @@ -896,7 +895,7 @@ impl Broker { client_id: ClientId, packet: PublishReleasePacket, ) { - if self.unauthenticated_sessions.contains_key(&connection_id) { + if self.unauthenticated_connections.contains_key(&connection_id) { warn!("Discarding publish release packet from unauthenticated client ID {} on connection {}", client_id, connection_id); return; } @@ -925,7 +924,7 @@ impl Broker { client_id: ClientId, packet: PublishReceivedPacket, ) { - if self.unauthenticated_sessions.contains_key(&connection_id) { + if self.unauthenticated_connections.contains_key(&connection_id) { warn!("Discarding publish received packet from unauthenticated client ID {} on connection {}", client_id, connection_id); return; } @@ -958,7 +957,7 @@ impl Broker { client_id: ClientId, packet: PublishCompletePacket, ) { - if self.unauthenticated_sessions.contains_key(&connection_id) { + if self.unauthenticated_connections.contains_key(&connection_id) { warn!("Discarding publish complete packet from unauthenticated client ID {} on connection {}", client_id, connection_id); return; } @@ -978,7 +977,7 @@ impl Broker { client_id: ClientId, final_will: FinalWill, ) { - if self.unauthenticated_sessions.contains_key(&connection_id) { + if self.unauthenticated_connections.contains_key(&connection_id) { warn!( "Discarding final will packet from unauthenticated client ID {} on connection {}", client_id, connection_id