diff --git a/mqtt-v5-broker/src/broker.rs b/mqtt-v5-broker/src/broker.rs index 3e9ead9..68d3420 100644 --- a/mqtt-v5-broker/src/broker.rs +++ b/mqtt-v5-broker/src/broker.rs @@ -1,22 +1,50 @@ -use crate::{client::ClientMessage, tree::SubscriptionTree}; -use log::{info, warn}; +use crate::{ + client::ClientMessage, + plugin::{AuthentificationResult, Noop, Plugin}, + tree::SubscriptionTree, +}; +use log::{debug, info, warn}; use mqtt_v5::{ topic::TopicFilter, types::{ properties::{AssignedClientIdentifier, SessionExpiryInterval}, - ConnectAckPacket, ConnectPacket, ConnectReason, DisconnectReason, FinalWill, Packet, - ProtocolVersion, PublishAckPacket, PublishAckReason, PublishCompletePacket, - PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReceivedReason, - PublishReleasePacket, PublishReleaseReason, QoS, SubscribeAckPacket, SubscribeAckReason, - SubscribePacket, UnsubscribeAckPacket, UnsubscribeAckReason, UnsubscribePacket, + AuthenticatePacket, ConnectAckPacket, ConnectPacket, ConnectReason, DisconnectReason, + FinalWill, Packet, ProtocolVersion, PublishAckPacket, PublishCompletePacket, + PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReleasePacket, + PublishReleaseReason, QoS, SubscribeAckPacket, SubscribeAckReason, SubscribePacket, + UnsubscribeAckPacket, UnsubscribeAckReason, UnsubscribePacket, }, }; use std::{ collections::{hash_map::Entry, HashMap}, time::Duration, }; -use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::{ + sync::mpsc::{self, Receiver, Sender}, + time, +}; + +/// A client connected but not yet authenticated. +struct UnauthenticatedConnection { + connect_packet: ConnectPacket, + client_sender: Sender, +} + +impl UnauthenticatedConnection { + /// Construct a new UnauthenticatedSession. + fn new(connect_packet: ConnectPacket, client_sender: Sender) -> Self { + Self { connect_packet, client_sender } + } + + /// Send a `ClientMessage` to the client via the channel handle. + /// This is a fire and forget operation because upon any error on the + /// connection the `Client` will send a `BrokerMessage::Disconnect`. + async fn send(&self, message: ClientMessage) { + drop(self.client_sender.send(message).await); + } +} +#[derive(Debug)] struct Session { #[allow(unused)] pub protocol_version: ProtocolVersion, @@ -176,38 +204,71 @@ pub enum WillDisconnectLogic { DoNotSend, } +/// Unique identifier for a connection +pub type ConnectionId = u64; + +/// Client ID +pub type ClientId = String; + #[derive(Debug)] pub enum BrokerMessage { - NewClient(Box, Sender), - Publish(String, Box), - PublishAck(String, PublishAckPacket), // TODO - This can be handled by the client task - PublishRelease(String, PublishReleasePacket), // TODO - This can be handled by the client task - PublishReceived(String, PublishReceivedPacket), - PublishComplete(String, PublishCompletePacket), - PublishFinalWill(String, FinalWill), - Subscribe(String, SubscribePacket), // TODO - replace string client_id with int - Unsubscribe(String, UnsubscribePacket), // TODO - replace string client_id with int - Disconnect(String, WillDisconnectLogic), + Connect(ConnectionId, Box, Sender), + Disconnect(ConnectionId, String, WillDisconnectLogic), + Authenticate(ConnectionId, ClientId, AuthenticatePacket), + Publish(ConnectionId, ClientId, Box), + PublishAck(ConnectionId, ClientId, PublishAckPacket), // TODO - This can be handled by the client task + PublishRelease(ConnectionId, ClientId, PublishReleasePacket), // TODO - This can be handled by the client task + PublishReceived(ConnectionId, ClientId, PublishReceivedPacket), + PublishComplete(ConnectionId, ClientId, PublishCompletePacket), + PublishFinalWill(ConnectionId, ClientId, FinalWill), + Subscribe(ConnectionId, ClientId, SubscribePacket), // TODO - replace string client_id with int + Unsubscribe(ConnectionId, ClientId, UnsubscribePacket), // TODO - replace string client_id with int } -pub struct Broker { +pub struct Broker { + /// A map of client IDs to unauthenticated connections. Once a client passes + /// authentication, the connection is promoted to a session. + unauthenticated_connections: HashMap, sessions: HashMap, sender: Sender, receiver: Receiver, subscriptions: SubscriptionTree, + plugin: A, } impl Default for Broker { fn default() -> Self { - Self::new() + Broker::::new() } } -impl Broker { - pub fn new() -> Self { +impl Broker { + /// Construct a new Broker. + pub fn new() -> Broker { let (sender, receiver) = mpsc::channel(100); - Self { sessions: HashMap::new(), sender, receiver, subscriptions: SubscriptionTree::new() } + Broker { + unauthenticated_connections: HashMap::new(), + sessions: HashMap::new(), + sender, + receiver, + subscriptions: SubscriptionTree::new(), + plugin: Noop, + } + } + + /// Construct a new Broker. + pub fn with_plugin(plugin: A) -> Broker { + let (sender, receiver) = mpsc::channel(100); + + Broker { + unauthenticated_connections: HashMap::new(), + sessions: HashMap::new(), + sender, + receiver, + subscriptions: SubscriptionTree::new(), + plugin, + } } pub fn sender(&self) -> Sender { @@ -256,13 +317,91 @@ impl Broker { // If the Server accepts a connection with Clean Start set to 0 and the Server has Session State for the ClientID, // it MUST set Session Present to 1 in the CONNACK packet, otherwise it MUST set Session Present to 0 in the CONNACK packet. if new_client_clean_start { - return None; + None + } else { + existing_session } - - existing_session } async fn handle_new_client( + &mut self, + connection_id: ConnectionId, + connect_packet: ConnectPacket, + client_msg_sender: Sender, + ) { + debug!( + "Trying to authenticate client {} (connection {})", + connect_packet.client_id, connection_id + ); + match self.plugin.on_connect(&connect_packet) { + AuthentificationResult::Reason(ConnectReason::Success) => { + info!("Authentification successful for client {}", connect_packet.client_id); + self.handle_authenticated_client(connect_packet, client_msg_sender).await; + }, + AuthentificationResult::Reason(reason_code) => { + info!( + "Authentification reason code for client {} is {:?}", + connect_packet.client_id, reason_code + ); + let connect_ack = ConnectAckPacket { + session_present: false, + reason_code, + session_expiry_interval: None, + receive_maximum: None, + maximum_qos: None, + retain_available: None, + maximum_packet_size: None, + assigned_client_identifier: None, + topic_alias_maximum: None, + reason_string: None, + user_properties: Vec::with_capacity(0), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }; + + // Send a disconnect packet to the client. Ignore send errors because + // the client could already be disconnected and the rx handle of this + // channel is dropped. + debug!( + "Sending CONNACK to client ID {} with reason code {:?}", + connect_packet.client_id, reason_code + ); + client_msg_sender + .send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))) + .await + .ok(); + + debug!( + "Sending DISCONNECT to client {} with disconnect reason code {:?}", + connect_packet.client_id, + DisconnectReason::NotAuthorized + ); + client_msg_sender + .send(ClientMessage::Disconnect(DisconnectReason::NotAuthorized)) + .await + .ok(); + }, + AuthentificationResult::Packet(packet) => { + client_msg_sender + .send(ClientMessage::Packet(Packet::Authenticate(packet))) + .await + .ok(); + let client_id = connect_packet.client_id.clone(); + info!("Adding unauthenticated connection for client ID {}", client_id); + let unauthenticated_session = + UnauthenticatedConnection::new(connect_packet, client_msg_sender); + self.unauthenticated_connections.insert(connection_id, unauthenticated_session); + }, + } + } + + async fn handle_authenticated_client( &mut self, connect_packet: ConnectPacket, client_msg_sender: Sender, @@ -272,7 +411,6 @@ impl Broker { .await; let session_present = takeover_session.is_some(); - // TODO(flxo) Calling `resend_packets` after `take_over_existing_client` feels strange since a disconnect is sent in there. if let Some(existing_session) = &mut takeover_session { existing_session.resend_packets().await; } @@ -292,7 +430,23 @@ impl Broker { }, None => None, }; - + let session_expiry_duration = session_expiry_interval.map(|i| { + let duration = Duration::from_secs(i.0 as u64); + debug!( + "Client ID {} Session expiry interval is {:?}", + connect_packet.client_id, duration + ); + duration + }); + + // The user provided plugin decided when processing the Connect packet that this + // client is authenticated. This could be the case for no authentication needed or + // a username password authentication. The broker shall not wait for addition + // `Authentification` packets from the client before the client is in the authenticated + // state. + // If the client is not authenticated the `ConnectAckPacket` is sent once authentification + // succeeds or fails. + // Send conack if the auth is already successful and complete let connect_ack = ConnectAckPacket { // Variable header session_present, @@ -309,7 +463,7 @@ impl Broker { )), topic_alias_maximum: None, reason_string: None, - user_properties: vec![], + user_properties: Vec::with_capacity(0), wildcard_subscription_available: None, subscription_identifiers_available: None, shared_subscription_available: None, @@ -320,14 +474,10 @@ impl Broker { authentication_data: None, }; - // A newly connected client should have empty channel and the queuing the connect ack *must* fit in the channel. - client_msg_sender - .send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))) - .await - .expect("Failed to send Connect Acknowledgement"); - - let session_expiry_duration = - session_expiry_interval.map(|i| Duration::from_secs(i.0 as u64)); + // If the client disconnected in the meantime, the rx part of the client handle is dropped + // and a send attempt will fail. Ignore this error, because the disconnection is handled + // by a BrokerMessage::Disconnect. + client_msg_sender.send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))).await.ok(); let new_session = if let Some(existing_session) = takeover_session { let mut new_session = existing_session.into_new_session( @@ -352,14 +502,106 @@ impl Broker { self.sessions.insert(connect_packet.client_id, new_session); } - async fn handle_subscribe(&mut self, client_id: String, packet: SubscribePacket) { + /// Handle authenticate packets. Query the plugin and send a `ConnectAckPacket` or `Authenticate` + /// packet if needed. If the plugin authenticates the client (session) proceed with the session + /// setup etc. and remove the unauthenticated session. + async fn handle_authenticate( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: AuthenticatePacket, + ) { + debug!("Trying to authenticate client {} (connection {})", client_id, connection_id); + + let entry = match self.unauthenticated_connections.entry(connection_id) { + Entry::Occupied(entry) => entry, + Entry::Vacant(entry) => { + warn!("Received authenticate packet for unknown client ID {}", entry.key()); + return; + }, + }; + + match self.plugin.on_authenticate(&packet) { + AuthentificationResult::Reason(ConnectReason::Success) => { + let (client_id, UnauthenticatedConnection { client_sender, connect_packet }) = + entry.remove_entry(); + info!("Authentification successful for client ID {}", client_id); + self.handle_authenticated_client(connect_packet, client_sender).await; + }, + AuthentificationResult::Reason(reason_code) => { + info!("Authentification result for client ID {} is {:?}", entry.key(), reason_code); + let connect_ack = ConnectAckPacket { + // Variable header + session_present: false, + reason_code, + + // Properties + session_expiry_interval: None, + receive_maximum: None, + maximum_qos: None, + retain_available: None, + maximum_packet_size: None, + assigned_client_identifier: None, + topic_alias_maximum: None, + reason_string: None, + user_properties: Vec::with_capacity(0), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }; + + // If the client disconnected in the meantime, the rx part of the client handle is dropped + // and a send attempt will fail. Ignore this error, because the disconnection is handled + // by a BrokerMessage::Disconnect. + let session = entry.get(); + session.send(ClientMessage::Packet(Packet::ConnectAck(connect_ack))).await; + session.send(ClientMessage::Disconnect(DisconnectReason::NotAuthorized)).await; + }, + AuthentificationResult::Packet(packet) => { + let session = entry.get(); + session.send(ClientMessage::Packet(Packet::Authenticate(packet))).await; + }, + } + } + + async fn handle_subscribe( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: SubscribePacket, + ) { + if self.unauthenticated_connections.contains_key(&connection_id) { + warn!( + "Ignoring subscribe packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + let subscriptions = &mut self.subscriptions; if let Some(session) = self.sessions.get_mut(&client_id) { + let plugin_ack = self.plugin.on_subscribe(&packet); + // If a Server receives a SUBSCRIBE packet containing a Topic Filter that // is identical to a Non‑shared Subscription’s Topic Filter for the current // Session, then it MUST replace that existing Subscription with a new Subscription. - for topic in &packet.subscription_topics { + for (topic, plugin_code) in + packet.subscription_topics.iter().zip(&plugin_ack.reason_codes) + { + // Check only subscriptions that the plugin didn't reject. + match plugin_code { + SubscribeAckReason::GrantedQoSZero + | SubscribeAckReason::GrantedQoSOne + | SubscribeAckReason::GrantedQoSTwo => (), + _ => continue, + } + let topic = &topic.topic_filter; // Unsubscribe the old session from all topics it subscribed to. session.subscription_tokens.retain(|(session_topic, token)| { @@ -377,27 +619,28 @@ impl Broker { let granted_qos_values = packet .subscription_topics .into_iter() - .map(|topic| { - let session_subscription = SessionSubscription { - client_id: client_id.clone(), - maximum_qos: topic.maximum_qos, - }; - let token = subscriptions.insert(&topic.topic_filter, session_subscription); - - session.subscription_tokens.push((topic.topic_filter.clone(), token)); - - match topic.maximum_qos { - QoS::AtMostOnce => SubscribeAckReason::GrantedQoSZero, - QoS::AtLeastOnce => SubscribeAckReason::GrantedQoSOne, - QoS::ExactlyOnce => SubscribeAckReason::GrantedQoSTwo, - } + .zip(plugin_ack.reason_codes) + .map(|(topic, plugin_reason)| match plugin_reason { + SubscribeAckReason::GrantedQoSZero + | SubscribeAckReason::GrantedQoSOne + | SubscribeAckReason::GrantedQoSTwo => { + let session_subscription = SessionSubscription { + client_id: client_id.clone(), + maximum_qos: topic.maximum_qos, + }; + let token = subscriptions.insert(&topic.topic_filter, session_subscription); + + session.subscription_tokens.push((topic.topic_filter.clone(), token)); + plugin_reason + }, + reason => reason, }) .collect(); let subscribe_ack = SubscribeAckPacket { packet_id: packet.packet_id, - reason_string: None, - user_properties: vec![], + reason_string: plugin_ack.reason_string, + user_properties: plugin_ack.user_properties, reason_codes: granted_qos_values, }; @@ -405,7 +648,20 @@ impl Broker { } } - async fn handle_unsubscribe(&mut self, client_id: String, packet: UnsubscribePacket) { + async fn handle_unsubscribe( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: UnsubscribePacket, + ) { + if self.unauthenticated_connections.contains_key(&connection_id) { + warn!( + "Ignoring unsubscribe packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + let subscriptions = &mut self.subscriptions; if let Some(session) = self.sessions.get_mut(&client_id) { @@ -448,14 +704,29 @@ impl Broker { } } - fn handle_disconnect(&mut self, client_id: String, will_disconnect_logic: WillDisconnectLogic) { + fn handle_disconnect( + &mut self, + connection_id: ConnectionId, + client_id: String, + will_disconnect_logic: WillDisconnectLogic, + ) { + if self.unauthenticated_connections.remove(&connection_id).is_some() { + info!("Removing unauthenticated session for client ID {}", client_id); + return; + } + info!("Client ID {} disconnected", client_id); + self.plugin.on_disconnect(&client_id); + let mut disconnect_will = None; let mut session_expiry_duration = None; if let Entry::Occupied(mut session_entry) = self.sessions.entry(client_id.clone()) { - session_entry.get_mut().client_sender.take(); + { + let session = session_entry.get_mut(); + session.client_sender.take(); + } if let Some(expiry_interval) = session_entry.get().session_expiry_interval { // The Will Message MUST be published after the Network Connection is subsequently @@ -495,9 +766,13 @@ impl Broker { // Spawn a task that publishes the will after `will_send_delay_duration` tokio::spawn(async move { - tokio::time::sleep(will_send_delay_duration).await; + time::sleep(will_send_delay_duration).await; broker_sender - .send(BrokerMessage::PublishFinalWill(client_id, will)) + .send(BrokerMessage::PublishFinalWill( + connection_id, + client_id, + will, + )) .await .expect("Failed to send final will message to broker"); }); @@ -536,65 +811,95 @@ impl Broker { } } - async fn handle_publish(&mut self, client_id: String, packet: PublishPacket) { - let mut is_dup = false; - - // For QoS2, ensure this packet isn't delivered twice. So if we have an outgoing - // publish receive with the same ID, just send the publish receive again but don't forward - // the message. - match packet.qos { - QoS::AtMostOnce => {}, - QoS::AtLeastOnce => { - if let Some(session) = self.sessions.get_mut(&client_id) { - let publish_ack = PublishAckPacket { - packet_id: packet - .packet_id - .expect("Packet with QoS 1 should have a packet ID"), - reason_code: PublishAckReason::Success, - reason_string: None, - user_properties: vec![], - }; - - session.send(ClientMessage::Packet(Packet::PublishAck(publish_ack))).await; - } - }, - QoS::ExactlyOnce => { - if let Some(session) = self.sessions.get_mut(&client_id) { - let packet_id = packet.packet_id.unwrap(); - is_dup = session.outgoing_publish_receives.contains(&packet_id); + async fn handle_publish( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishPacket, + ) { + if self.unauthenticated_connections.contains_key(&connection_id) { + warn!( + "Discarding publish packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } - if !is_dup { - session.outgoing_publish_receives.push(packet_id) + if let Some(session) = self.sessions.get_mut(&client_id) { + match packet.qos { + QoS::AtMostOnce => { + if self.plugin.on_publish_received_qos0(&packet) { + self.publish_message(packet).await; + } + }, + QoS::AtLeastOnce => { + let (publish, publish_ack) = self.plugin.on_publish_received_qos1(&packet); + if let Some(publish_ack) = publish_ack { + session.send(ClientMessage::Packet(Packet::PublishAck(publish_ack))).await; + } + if publish { + self.publish_message(packet).await; } + }, + // For QoS2, ensure this packet isn't delivered twice. So if we have an outgoing + // publish receive with the same ID, just send the publish receive again but don't forward + // the message. + QoS::ExactlyOnce => { + let (mut publish, publish_rec) = self.plugin.on_publish_received_qos2(&packet); - let publish_recv = PublishReceivedPacket { - packet_id: packet - .packet_id - .expect("Packet with QoS 2 should have a packet ID"), - reason_code: PublishReceivedReason::Success, - reason_string: None, - user_properties: vec![], - }; - - session - .send(ClientMessage::Packet(Packet::PublishReceived(publish_recv))) - .await; - } - }, - } + if let Some(publish_recv) = publish_rec { + let packet_id = publish_recv.packet_id; + let is_dup = session.outgoing_publish_receives.contains(&packet_id); + + publish = publish && !is_dup; + + if !is_dup { + session.outgoing_publish_receives.push(packet_id) + } + + session + .send(ClientMessage::Packet(Packet::PublishReceived(publish_recv))) + .await; + } - if !is_dup { - self.publish_message(packet).await; + if publish { + self.publish_message(packet).await; + } + }, + } } } - fn handle_publish_ack(&mut self, client_id: String, packet: PublishAckPacket) { + fn handle_publish_ack( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishAckPacket, + ) { + if self.unauthenticated_connections.contains_key(&connection_id) { + warn!( + "Discarding publish ack packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { session.remove_outgoing_publish(packet.packet_id); } } - async fn handle_publish_release(&mut self, client_id: String, packet: PublishReleasePacket) { + async fn handle_publish_release( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishReleasePacket, + ) { + if self.unauthenticated_connections.contains_key(&connection_id) { + warn!("Discarding publish release packet from unauthenticated client ID {} on connection {}", client_id, connection_id); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { if let Some(pos) = session.outgoing_publish_receives.iter().position(|x| *x == packet.packet_id) @@ -613,7 +918,17 @@ impl Broker { } } - async fn handle_publish_received(&mut self, client_id: String, packet: PublishReceivedPacket) { + async fn handle_publish_received( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishReceivedPacket, + ) { + if self.unauthenticated_connections.contains_key(&connection_id) { + warn!("Discarding publish received packet from unauthenticated client ID {} on connection {}", client_id, connection_id); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { if let Some(pos) = session.outgoing_packets.iter().position(|p| { p.qos == QoS::ExactlyOnce @@ -636,7 +951,17 @@ impl Broker { } } - fn handle_publish_complete(&mut self, client_id: String, packet: PublishCompletePacket) { + fn handle_publish_complete( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + packet: PublishCompletePacket, + ) { + if self.unauthenticated_connections.contains_key(&connection_id) { + warn!("Discarding publish complete packet from unauthenticated client ID {} on connection {}", client_id, connection_id); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { if let Some(pos) = session.outgoing_publish_released.iter().position(|x| *x == packet.packet_id) @@ -646,7 +971,20 @@ impl Broker { } } - async fn publish_final_will(&mut self, client_id: String, final_will: FinalWill) { + async fn publish_final_will( + &mut self, + connection_id: ConnectionId, + client_id: ClientId, + final_will: FinalWill, + ) { + if self.unauthenticated_connections.contains_key(&connection_id) { + warn!( + "Discarding final will packet from unauthenticated client ID {} on connection {}", + client_id, connection_id + ); + return; + } + if let Some(session) = self.sessions.get_mut(&client_id) { if session.client_sender.is_some() { // They've reconnected, don't send out the will message. @@ -663,35 +1001,38 @@ impl Broker { pub async fn run(mut self) { while let Some(msg) = self.receiver.recv().await { match msg { - BrokerMessage::NewClient(connect_packet, client_msg_sender) => { - self.handle_new_client(*connect_packet, client_msg_sender).await; + BrokerMessage::Connect(connection_id, connect_packet, client_msg_sender) => { + self.handle_new_client(connection_id, *connect_packet, client_msg_sender).await; + }, + BrokerMessage::Disconnect(connection_id, client_id, will_disconnect_logic) => { + self.handle_disconnect(connection_id, client_id, will_disconnect_logic); }, - BrokerMessage::Subscribe(client_id, packet) => { - self.handle_subscribe(client_id, packet).await; + BrokerMessage::Authenticate(connection_id, client_id, packet) => { + self.handle_authenticate(connection_id, client_id, packet).await; }, - BrokerMessage::Unsubscribe(client_id, packet) => { - self.handle_unsubscribe(client_id, packet).await; + BrokerMessage::Subscribe(connection_id, client_id, packet) => { + self.handle_subscribe(connection_id, client_id, packet).await; }, - BrokerMessage::Disconnect(client_id, will_disconnect_logic) => { - self.handle_disconnect(client_id, will_disconnect_logic); + BrokerMessage::Unsubscribe(connection_id, client_id, packet) => { + self.handle_unsubscribe(connection_id, client_id, packet).await; }, - BrokerMessage::Publish(client_id, packet) => { - self.handle_publish(client_id, *packet).await; + BrokerMessage::Publish(connection_id, client_id, packet) => { + self.handle_publish(connection_id, client_id, *packet).await; }, - BrokerMessage::PublishAck(client_id, packet) => { - self.handle_publish_ack(client_id, packet); + BrokerMessage::PublishAck(connection_id, client_id, packet) => { + self.handle_publish_ack(connection_id, client_id, packet); }, - BrokerMessage::PublishRelease(client_id, packet) => { - self.handle_publish_release(client_id, packet).await; + BrokerMessage::PublishRelease(connection_id, client_id, packet) => { + self.handle_publish_release(connection_id, client_id, packet).await; }, - BrokerMessage::PublishReceived(client_id, packet) => { - self.handle_publish_received(client_id, packet).await; + BrokerMessage::PublishReceived(connection_id, client_id, packet) => { + self.handle_publish_received(connection_id, client_id, packet).await; }, - BrokerMessage::PublishComplete(client_id, packet) => { - self.handle_publish_complete(client_id, packet); + BrokerMessage::PublishComplete(connection_id, client_id, packet) => { + self.handle_publish_complete(connection_id, client_id, packet); }, - BrokerMessage::PublishFinalWill(client_id, final_will) => { - self.publish_final_will(client_id, final_will).await; + BrokerMessage::PublishFinalWill(connection_id, client_id, final_will) => { + self.publish_final_will(connection_id, client_id, final_will).await; }, } } @@ -703,6 +1044,7 @@ mod tests { use crate::{ broker::{Broker, BrokerMessage}, client::ClientMessage, + plugin::Noop, }; use mqtt_v5::types::{properties::*, ProtocolVersion, *}; use tokio::{ @@ -731,12 +1073,12 @@ mod tests { client_id: "TEST".to_string(), will: None, - user_name: None, - password: None, + user_name: Some("test".into()), + password: Some("test".into()), }; let _ = broker_tx - .send(BrokerMessage::NewClient(Box::new(connect_packet), sender)) + .send(BrokerMessage::Connect(0, Box::new(connect_packet), sender)) .await .unwrap(); @@ -770,6 +1112,7 @@ mod tests { let _ = broker_tx .send(BrokerMessage::Subscribe( + 0, "TEST".to_string(), SubscribePacket { packet_id: 0, @@ -802,7 +1145,7 @@ mod tests { #[test] fn simple_client_test() { - let broker = Broker::new(); + let broker = Broker::::new(); let sender = broker.sender(); let runtime = Runtime::new().unwrap(); diff --git a/mqtt-v5-broker/src/client.rs b/mqtt-v5-broker/src/client.rs index 96394ad..4988fc3 100644 --- a/mqtt-v5-broker/src/client.rs +++ b/mqtt-v5-broker/src/client.rs @@ -1,4 +1,4 @@ -use crate::broker::{BrokerMessage, WillDisconnectLogic}; +use crate::broker::{BrokerMessage, ConnectionId, WillDisconnectLogic}; use futures::{ future::{self, Either}, stream, Sink, SinkExt, Stream, StreamExt, @@ -12,7 +12,7 @@ use mqtt_v5::{ }, }; use nanoid::nanoid; -use std::{marker::Unpin, time::Duration}; +use std::{marker::Unpin, sync::atomic::AtomicU64, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::mpsc::{self, Receiver, Sender}, @@ -28,6 +28,14 @@ use mqtt_v5::{ WsUpgraderCodec, }, }; +use std::sync::atomic::Ordering; + +/// Generate a new unique connection id +fn next_connection_id() -> u64 { + static CONNECTION_ID: AtomicU64 = AtomicU64::new(0); + // This operation wraps around on overflow. + CONNECTION_ID.fetch_add(1, Ordering::Relaxed) +} type PacketResult = Result; @@ -175,6 +183,7 @@ where } struct UnconnectedClient, SI: Sink> { + connection_id: ConnectionId, packet_stream: ST, packet_sink: SI, broker_tx: Sender, @@ -184,7 +193,8 @@ impl + Unpin, SI: Sink { pub fn new(packet_stream: ST, packet_sink: SI, broker_tx: Sender) -> Self { - Self { packet_stream, packet_sink, broker_tx } + let connection_id = next_connection_id(); + Self { connection_id, packet_stream, packet_sink, broker_tx } } pub async fn handshake(mut self) -> Result, ProtocolError> { @@ -223,11 +233,18 @@ impl + Unpin, SI: Sink, SI: Sink> { - id: String, + connection_id: ConnectionId, + client_id: String, _protocol_version: ProtocolVersion, keepalive_seconds: Option, packet_stream: ST, @@ -272,7 +290,8 @@ impl + Unpin, SI: Sink, packet_stream: ST, @@ -282,7 +301,8 @@ impl + Unpin, SI: Sink, ) -> Self { Self { - id, + connection_id, + client_id, _protocol_version: protocol_version, keepalive_seconds, packet_stream, @@ -294,6 +314,7 @@ impl + Unpin, SI: Sink, @@ -329,13 +350,21 @@ impl + Unpin, SI: Sink match frame { Packet::Subscribe(packet) => { broker_tx - .send(BrokerMessage::Subscribe(client_id.clone(), packet)) + .send(BrokerMessage::Subscribe( + connection_id, + client_id.clone(), + packet, + )) .await .expect("Couldn't send Subscribe message to broker"); }, Packet::Unsubscribe(packet) => { broker_tx - .send(BrokerMessage::Unsubscribe(client_id.clone(), packet)) + .send(BrokerMessage::Unsubscribe( + connection_id, + client_id.clone(), + packet, + )) .await .expect("Couldn't send Unsubscribe message to broker"); }, @@ -351,31 +380,51 @@ impl + Unpin, SI: Sink { broker_tx - .send(BrokerMessage::PublishAck(client_id.clone(), packet)) + .send(BrokerMessage::PublishAck( + connection_id, + client_id.clone(), + packet, + )) .await .expect("Couldn't send PublishAck message to broker"); }, Packet::PublishRelease(packet) => { broker_tx - .send(BrokerMessage::PublishRelease(client_id.clone(), packet)) + .send(BrokerMessage::PublishRelease( + connection_id, + client_id.clone(), + packet, + )) .await .expect("Couldn't send PublishRelease message to broker"); }, Packet::PublishReceived(packet) => { broker_tx - .send(BrokerMessage::PublishReceived(client_id.clone(), packet)) + .send(BrokerMessage::PublishReceived( + connection_id, + client_id.clone(), + packet, + )) .await .expect("Couldn't send PublishReceive message to broker"); }, Packet::PublishComplete(packet) => { broker_tx - .send(BrokerMessage::PublishComplete(client_id.clone(), packet)) + .send(BrokerMessage::PublishComplete( + connection_id, + client_id.clone(), + packet, + )) .await .expect("Couldn't send PublishCompelte message to broker"); }, @@ -395,6 +444,7 @@ impl + Unpin, SI: Sink + Unpin, SI: Sink { + broker_tx + .send(BrokerMessage::Authenticate( + connection_id, + client_id.clone(), + packet, + )) + .await + .expect("Couldn't send Authentivate message to self"); + }, _ => {}, }, Err(err) => { @@ -416,7 +476,11 @@ impl + Unpin, SI: Sink + Unpin, SI: Sink + Unpin, SI: Sink Result<(), Box> { init_logging(); - let broker = Broker::new(); + let broker = Broker::default(); let broker_tx = broker.sender(); let broker = task::spawn(async { broker.run().await; diff --git a/mqtt-v5-broker/src/plugin.rs b/mqtt-v5-broker/src/plugin.rs new file mode 100644 index 0000000..d0a9a09 --- /dev/null +++ b/mqtt-v5-broker/src/plugin.rs @@ -0,0 +1,129 @@ +use log::{trace, warn}; +use mqtt_v5::types::{ + AuthenticatePacket, ConnectPacket, ConnectReason, PublishAckPacket, PublishAckReason, + PublishPacket, PublishReceivedPacket, PublishReceivedReason, QoS, SubscribeAckPacket, + SubscribeAckReason, SubscribePacket, +}; + +pub struct Noop; + +/// Result of a authentication attempt +pub enum AuthentificationResult { + /// Authentification reason + Reason(ConnectReason), + /// Send this auth packet to the client and wait for the response. + Packet(AuthenticatePacket), +} + +/// Broker plugin +pub trait Plugin { + /// Called on connect packet reception + fn on_connect(&mut self, packet: &ConnectPacket) -> AuthentificationResult; + + /// Called on client disconnect + fn on_disconnect(&mut self, client_id: &str); + + /// Called on authenticate packet reception + fn on_authenticate(&mut self, packet: &AuthenticatePacket) -> AuthentificationResult; + + /// Called on subscribe packets reception + fn on_subscribe(&mut self, packet: &SubscribePacket) -> SubscribeAckPacket; + + /// Called on publish packets reception for QoS 0. Return true if the packet should be published to the clients. + fn on_publish_received_qos0(&mut self, packet: &PublishPacket) -> bool; + /// Called on publish packets reception for QoS 1. Return if the packet should be published to the clients and + /// the publish ack packet to be sent to the publisher. + + fn on_publish_received_qos1( + &mut self, + packet: &PublishPacket, + ) -> (bool, Option); + + /// Called on publish packets reception for QoS 2. Return if the packet should be published to the clients and + /// the publish received packet to be sent to the publisher. + fn on_publish_received_qos2( + &mut self, + packet: &PublishPacket, + ) -> (bool, Option); +} + +/// Default noop authenticator +impl Plugin for Noop { + fn on_connect(&mut self, packet: &ConnectPacket) -> AuthentificationResult { + // Just a hacky test... + if packet.user_name.is_some() && packet.user_name == packet.password { + AuthentificationResult::Reason(ConnectReason::Success) + } else { + AuthentificationResult::Reason(ConnectReason::BadUserNameOrPassword) + } + } + + fn on_disconnect(&mut self, _: &str) {} + + fn on_authenticate(&mut self, _: &AuthenticatePacket) -> AuthentificationResult { + AuthentificationResult::Reason(ConnectReason::Success) + } + + fn on_subscribe(&mut self, packet: &SubscribePacket) -> SubscribeAckPacket { + SubscribeAckPacket { + packet_id: packet.packet_id, + reason_string: None, + user_properties: vec![], + reason_codes: packet + .subscription_topics + .iter() + .inspect(|filter| { + trace!("Granting subscribe to {}", filter.topic_filter); + }) + .map(|filter| match filter.maximum_qos { + QoS::AtMostOnce => SubscribeAckReason::GrantedQoSZero, + QoS::AtLeastOnce => SubscribeAckReason::GrantedQoSOne, + QoS::ExactlyOnce => SubscribeAckReason::GrantedQoSTwo, + }) + .collect(), + } + } + + fn on_publish_received_qos0(&mut self, packet: &PublishPacket) -> bool { + trace!("Granting QoS 0 publish on topic \"{}\"", packet.topic); + true + } + + fn on_publish_received_qos1( + &mut self, + packet: &PublishPacket, + ) -> (bool, Option) { + if let Some(packet_id) = packet.packet_id { + let ack = PublishAckPacket { + packet_id, + reason_code: PublishAckReason::Success, + reason_string: None, + user_properties: Vec::with_capacity(0), + }; + trace!("Granting QoS 1 publish on topic \"{}\"", packet.topic); + (true, Some(ack)) + } else { + warn!("Publish packet with QoS 1 without packet id"); + (false, None) + } + } + + fn on_publish_received_qos2( + &mut self, + packet: &PublishPacket, + ) -> (bool, Option) { + if let Some(packet_id) = packet.packet_id { + let ack = PublishReceivedPacket { + packet_id, + reason_code: PublishReceivedReason::Success, + reason_string: None, + user_properties: Vec::with_capacity(0), + }; + trace!("Granting QoS 2 publish on topic \"{}\"", packet.topic); + (true, Some(ack)) + } else { + warn!("Publish packet with QoS 2 without packet id"); + (false, None) + } + } +}