From 811704cc5d4b3b16281526204ce36d8d9a586580 Mon Sep 17 00:00:00 2001 From: Rusty Bee <145002912+rustybee42@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:18:45 +0100 Subject: [PATCH 1/2] refactor/fix: overhaul BeeSerde and msg buffers * Serializer now requires a preallocated buffer in the form of a `&mut [u8]`. This moves the responsibility to allocate the required memory to the user. Before, it used a dynamically growing buffer (which would reallocate internally if needed). This change is crucial for future RDMA support. * Since we now no longer use `BytesMut`, remove the bytes crate. * The buffers handled by the store have a fixed size of 4 MiB. This is apparently the maximum size of a BeeMsg (see `WORKER_BUF(IN|OUT)_SIZE` in `Worker.h`). C++ server code also uses this fixed size. * Generalize `msg_feature_flags` to a `Header` that can be modified from within the `Serializable` implementation and can be read out from within the `Deserializable` implementation. * Collect all BeeMsg (de)serialization functions in one module and provide functions for header, body and both combined. The split is required because different actions need to be taken depending on where the data comes from / goes to. This also provides an easy interface for potential external users to handle BeeMsgs. * Remove the MsgBuf struct; instead, just pass a `&mut [u8]` into the dispatcher. 
* Add documentation * Various small code cleanups in BeeSerde and other locations --- Cargo.lock | 1 - Cargo.toml | 1 - mgmtd/src/context.rs | 2 +- mgmtd/src/db/import_v7.rs | 10 +- mgmtd/src/lib.rs | 3 +- shared/Cargo.toml | 1 - shared/src/bee_msg.rs | 168 ++++++++++++++- shared/src/bee_msg/header.rs | 66 ------ shared/src/bee_serde.rs | 371 +++++++++++++++++--------------- shared/src/conn.rs | 14 +- shared/src/conn/incoming.rs | 43 ++-- shared/src/conn/msg_buf.rs | 180 ---------------- shared/src/conn/msg_dispatch.rs | 31 +-- shared/src/conn/outgoing.rs | 90 +++++--- shared/src/conn/store.rs | 13 +- shared/src/journald_logger.rs | 3 +- 16 files changed, 489 insertions(+), 508 deletions(-) delete mode 100644 shared/src/bee_msg/header.rs delete mode 100644 shared/src/conn/msg_buf.rs diff --git a/Cargo.lock b/Cargo.lock index d2712e9..915ffaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1019,7 +1019,6 @@ version = "0.0.0" dependencies = [ "anyhow", "bee_serde_derive", - "bytes", "log", "pnet_datalink", "protobuf", diff --git a/Cargo.toml b/Cargo.toml index b405115..5824be9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ publish = false [workspace.dependencies] anyhow = "1" -bytes = "1" clap = { version = "4", features = ["derive"] } env_logger = "0" itertools = "0" diff --git a/mgmtd/src/context.rs b/mgmtd/src/context.rs index 0662d38..ed7321a 100644 --- a/mgmtd/src/context.rs +++ b/mgmtd/src/context.rs @@ -4,8 +4,8 @@ use crate::bee_msg::dispatch_request; use crate::license::LicenseVerifier; use crate::{ClientPulledStateNotification, StaticInfo}; use anyhow::Result; -use shared::conn::Pool; use shared::conn::msg_dispatch::*; +use shared::conn::outgoing::Pool; use shared::run_state::WeakRunStateHandle; use shared::types::{NodeId, NodeType}; use std::ops::Deref; diff --git a/mgmtd/src/db/import_v7.rs b/mgmtd/src/db/import_v7.rs index b5cd42d..9c6676d 100644 --- a/mgmtd/src/db/import_v7.rs +++ b/mgmtd/src/db/import_v7.rs @@ -82,7 +82,7 @@ targetStates=1 fn 
check_target_states(f: &Path) -> Result<()> { let s = std::fs::read(f)?; - let mut des = Deserializer::new(&s, 0); + let mut des = Deserializer::new(&s); let states = des.map( false, |des| TargetId::deserialize(des), @@ -184,7 +184,7 @@ struct ReadNodesResult { fn read_nodes(f: &Path) -> Result { let s = std::fs::read(f)?; - let mut des = Deserializer::new(&s, 0); + let mut des = Deserializer::new(&s); let version = des.u32()?; let root_id = des.u32()?; let root_mirrored = des.u8()?; @@ -286,7 +286,7 @@ fn storage_targets(tx: &Transaction, targets_path: &Path) -> Result<()> { fn storage_pools(tx: &Transaction, f: &Path) -> Result<()> { let s = std::fs::read(f)?; - let mut des = Deserializer::new(&s, 0); + let mut des = Deserializer::new(&s); // Serialized as size_t, which should usually be 64 bit. let count = des.i64()?; let mut used_aliases = vec![]; @@ -415,7 +415,7 @@ fn quota_default_limits(tx: &Transaction, f: &Path, pool_id: PoolId) -> Result<( Err(err) => return Err(err.into()), }; - let mut des = Deserializer::new(&s, 0); + let mut des = Deserializer::new(&s); let user_inode_limit = des.u64()?; let user_space_limit = des.u64()?; let group_inode_limit = des.u64()?; @@ -497,7 +497,7 @@ fn quota_limits( Err(err) => return Err(err.into()), }; - let mut des = Deserializer::new(&s, 0); + let mut des = Deserializer::new(&s); let limits = des.seq(false, |des| QuotaEntry::deserialize(des))?; des.finish()?; diff --git a/mgmtd/src/lib.rs b/mgmtd/src/lib.rs index 418eba0..9b9c353 100644 --- a/mgmtd/src/lib.rs +++ b/mgmtd/src/lib.rs @@ -20,7 +20,8 @@ use db::node_nic::ReplaceNic; use license::LicenseVerifier; use shared::NetworkAddr; use shared::bee_msg::target::RefreshTargetStates; -use shared::conn::{Pool, incoming}; +use shared::conn::incoming; +use shared::conn::outgoing::Pool; use shared::run_state::{self, RunStateControl}; use shared::types::{AuthSecret, MGMTD_UID, NicType, NodeId, NodeType}; use sqlite::TransactionExt; diff --git a/shared/Cargo.toml 
b/shared/Cargo.toml index f7fee5d..1990f5d 100644 --- a/shared/Cargo.toml +++ b/shared/Cargo.toml @@ -11,7 +11,6 @@ publish.workspace = true bee_serde_derive = { path = "../bee_serde_derive" } anyhow = { workspace = true } -bytes = { workspace = true } log = { workspace = true } pnet_datalink = "0" protobuf = { workspace = true, optional = true } diff --git a/shared/src/bee_msg.rs b/shared/src/bee_msg.rs index e995a94..b986949 100644 --- a/shared/src/bee_msg.rs +++ b/shared/src/bee_msg.rs @@ -2,12 +2,11 @@ use crate::bee_serde::*; use crate::types::*; -use anyhow::Result; +use anyhow::{Context, Result, anyhow}; use bee_serde_derive::BeeSerde; use std::collections::{HashMap, HashSet}; pub mod buddy_group; -pub mod header; pub mod misc; pub mod node; pub mod quota; @@ -41,3 +40,168 @@ impl OpsErr { pub const AGAIN: Self = Self(22); pub const UNKNOWN_POOL: Self = Self(30); } + +/// The BeeMsg header +#[derive(Clone, Debug, PartialEq, Eq, BeeSerde)] +pub struct Header { + /// Total length of the serialized message, including the header itself + msg_len: u32, + /// Sometimes used for additional message specific payload and/or serialization info + pub msg_feature_flags: u16, + /// Sometimes used for additional message specific payload and/or serialization info + pub msg_compat_feature_flags: u8, + /// Sometimes used for additional message specific payload and/or serialization info + pub msg_flags: u8, + /// Fixed value to identify a BeeMsg header (see MSG_PREFIX below) + msg_prefix: u64, + /// Uniquely identifies the message type as defined in the C++ codebase in NetMessageTypes.h + msg_id: MsgId, + /// Sometimes used for additional message specific payload and/or serialization info + pub msg_target_id: TargetId, + /// Sometimes used for additional message specific payload and/or serialization info + pub msg_user_id: u32, + /// Mirroring related information + pub msg_seq: u64, + /// Mirroring related information + pub msg_seq_done: u64, +} + +impl Header { + /// The 
serialized length of the header + pub const LEN: usize = 40; + /// Fixed value for identifying BeeMsges. In theory, this has some kind of version modifier + /// (thus the + 0), but it is unused + #[allow(clippy::identity_op)] + pub const MSG_PREFIX: u64 = (0x42474653 << 32) + 0; + + /// The total length of the serialized message + pub fn msg_len(&self) -> usize { + self.msg_len as usize + } + + /// The messages id + pub fn msg_id(&self) -> MsgId { + self.msg_id + } +} + +impl Default for Header { + fn default() -> Self { + Self { + msg_len: 0, + msg_feature_flags: 0, + msg_compat_feature_flags: 0, + msg_flags: 0, + msg_prefix: Self::MSG_PREFIX, + msg_id: 0, + msg_target_id: 0, + msg_user_id: 0, + msg_seq: 0, + msg_seq_done: 0, + } + } +} + +/// Serializes a BeeMsg body into the provided buffer. +/// +/// The data is written from the beginning of the slice, it's up to the caller to pass the correct +/// sub slice if space for the header should be reserved. +/// +/// # Return value +/// Returns the number of bytes written and the header modified by serialization function. +pub fn serialize_body(msg: &M, buf: &mut [u8]) -> Result<(usize, Header)> { + let mut ser = Serializer::new(buf); + msg.serialize(&mut ser) + .context("BeeMsg body serialization failed")?; + + Ok((ser.bytes_written(), ser.finish())) +} + +/// Serializes a BeeMsg header into the provided buffer. +/// +/// # Return value +/// Returns the number of bytes written. +pub fn serialize_header(header: &Header, buf: &mut [u8]) -> Result { + let mut ser_header = Serializer::new(buf); + header + .serialize(&mut ser_header) + .context("BeeMsg header serialization failed")?; + + Ok(ser_header.bytes_written()) +} + +/// Serializes a complete BeeMsg (header + body) into the provided buffer. +/// +/// # Return value +/// Returns the number of bytes written. 
+pub fn serialize(msg: &M, buf: &mut [u8]) -> Result { + let (written, mut header) = serialize_body(msg, &mut buf[Header::LEN..])?; + + header.msg_len = (written + Header::LEN) as u32; + header.msg_id = M::ID; + + let _ = serialize_header(&header, &mut buf[0..Header::LEN])?; + + Ok(header.msg_len()) +} + +/// Deserializes a BeeMsg header from the provided buffer. +/// +/// # Return value +/// Returns the deserialized header. +pub fn deserialize_header(buf: &[u8]) -> Result
{ + const CTX: &str = "BeeMsg header deserialization failed"; + + let header_buf = buf + .get(..Header::LEN) + .ok_or_else(|| { + anyhow!( + "Header buffer must be at least {} bytes big, got {}", + Header::LEN, + buf.len() + ) + }) + .context(CTX)?; + + let mut des = Deserializer::new(header_buf); + let header = Header::deserialize(&mut des).context(CTX)?; + des.finish().context(CTX)?; + + if header.msg_prefix != Header::MSG_PREFIX { + return Err(anyhow!( + "Invalid BeeMsg prefix: Must be {}, got {}", + Header::MSG_PREFIX, + header.msg_prefix + )) + .context(CTX); + } + + Ok(header) +} + +/// Deserializes a BeeMsg body from the provided buffer. +/// +/// The data is read from the beginning of the slice, it's up to the caller to pass the correct +/// sub slice if space for the header should be excluded from the source. +/// +/// # Return value +/// Returns the deserialized message. +pub fn deserialize_body(header: &Header, buf: &[u8]) -> Result { + const CTX: &str = "BeeMsg body deserialization failed"; + + let mut des = Deserializer::with_header(&buf[0..(header.msg_len() - Header::LEN)], header); + let des_msg = M::deserialize(&mut des).context(CTX)?; + des.finish().context(CTX)?; + + Ok(des_msg) +} + +/// Deserializes a complete BeeMsg (header + body) from the provided buffer. +/// +/// # Return value +/// Returns the deserialized message. 
+pub fn deserialize(buf: &[u8]) -> Result { + let header = deserialize_header(&buf[0..Header::LEN])?; + let msg = deserialize_body(&header, &buf[Header::LEN..])?; + Ok(msg) +} diff --git a/shared/src/bee_msg/header.rs b/shared/src/bee_msg/header.rs deleted file mode 100644 index f704c95..0000000 --- a/shared/src/bee_msg/header.rs +++ /dev/null @@ -1,66 +0,0 @@ -/// Defines the BeeGFS message header -use super::*; -use crate::bee_serde::Deserializer; -use anyhow::bail; - -/// The BeeGFS message header -#[derive(Clone, Debug, Default, PartialEq, Eq, BeeSerde)] -pub struct Header { - /// Total length of the message, including the header. - /// - /// This determines the amount of bytes read and written from and to sockets. - pub msg_len: u32, - /// Sometimes used for additional message specific payload and/or serialization info - pub msg_feature_flags: u16, - pub msg_compat_feature_flags: u8, - pub msg_flags: u8, - /// Fixed value - pub msg_prefix: u64, - /// Uniquely identifies the message type as defined in the C++ codebase in NetMessageTypes.h - pub msg_id: MsgId, - pub msg_target_id: TargetId, - pub msg_user_id: u32, - pub msg_seq: u64, - pub msg_seq_done: u64, -} - -impl Header { - pub const LEN: usize = 40; - pub const DATA_VERSION: u64 = 0; - pub const MSG_PREFIX: u64 = (0x42474653 << 32) + Self::DATA_VERSION; - - /// Creates a new BeeGFS message header - /// - /// `msg_feature_flags` has to be set depending on the message. 
- pub fn new(body_len: usize, msg_id: MsgId, msg_feature_flags: u16) -> Self { - Self { - msg_len: (body_len + Self::LEN) as u32, - msg_feature_flags, - msg_compat_feature_flags: 0, - msg_flags: 0, - msg_prefix: Self::MSG_PREFIX, - msg_id, - msg_target_id: 0, - msg_user_id: u32::MAX, - msg_seq: 0, - msg_seq_done: 0, - } - } - - /// Deserializes the given buffer into a header - pub fn from_buf(buf: &[u8]) -> Result { - if buf.len() != Self::LEN { - bail!("Header buffer has an unexpected size of {}", buf.len()); - } - - let mut des = Deserializer::new(buf, 0); - let des_header = Header::deserialize(&mut des)?; - des.finish()?; - Ok(des_header) - } - - /// The expected total message length this header belongs to - pub fn msg_len(&self) -> usize { - self.msg_len as usize - } -} diff --git a/shared/src/bee_serde.rs b/shared/src/bee_serde.rs index 2c9f50d..14e24c5 100644 --- a/shared/src/bee_serde.rs +++ b/shared/src/bee_serde.rs @@ -1,105 +1,100 @@ -//! BeeGFS compatible network message (de-)serialization +//! BeeSerde, the BeeGFS network message (and some on disk data) (de-)serialization +use crate::bee_msg::Header; use anyhow::{Result, bail}; -use bytes::{Buf, BufMut, BytesMut}; +use std::borrow::Cow; use std::collections::HashMap; use std::hash::Hash; use std::marker::PhantomData; use std::mem::size_of; +// SERIALIZATION + +/// Makes a type BeeSerde serializable pub trait Serializable { fn serialize(&self, ser: &mut Serializer<'_>) -> Result<()>; } -pub trait Deserializable { - fn deserialize(des: &mut Deserializer<'_>) -> Result - where - Self: Sized; -} - -/// Provides conversion functionality to and from BeeSerde serializable types. -/// -/// Mainly meant for enums that need to be converted in to a raw integer type, which also might -/// differ between messages. The generic parameter allows implementing it for multiple types. 
-pub trait BeeSerdeConversion: Sized { - fn into_bee_serde(self) -> S; - fn try_from_bee_serde(value: S) -> Result; -} - -/// Interface for serialization helpers to be used with the `bee_serde` derive macro -/// -/// Serialization helpers are meant to control the `bee_serde` macro in case a value in the -/// message struct shall be serialized as a different type or in case it doesn't have its own -/// [BeeSerde] implementation. Also necessary for maps and sequences since the serializer can't -/// know on its own whether to include collection size or not (it's totally message dependent). -/// -/// # Example -/// -/// ```ignore -/// #[derive(Debug, BeeSerde)] -/// pub struct ExampleMsg { -/// // Serializer doesn't know by itself whether or not C/C++ BeeGFS serializer expects sequence -/// // size included or not - in this case it is not -/// #[bee_serde(as = Seq)] -/// int_sequence: Vec, -/// } -/// ``` -pub trait BeeSerdeHelper { - fn serialize_as(data: &In, ser: &mut Serializer<'_>) -> Result<()>; - fn deserialize_as(des: &mut Deserializer<'_>) -> Result; -} - -/// Serializes one BeeGFS message into a provided buffer +/// Serializes one `impl Serializable` into a target buffer +#[derive(Debug)] pub struct Serializer<'a> { /// The target buffer - target_buf: &'a mut BytesMut, - /// BeeGFS message feature flags obtained, used for conditional serialization by certain - /// messages. To be set by the serialization function. - pub msg_feature_flags: u16, - /// The number of bytes written to the buffer - bytes_written: usize, + target_buf: &'a mut [u8], + /// The position of the write cursor in the buffer. This equals to the number of bytes written. + write_pos: usize, + /// BeeMsg header, some fields are used for conditional serialization by certain + /// messages. To be set by the serialization function (except msg_len and msg_id). 
+ // + // Note that in an ideal world, this would be generic and opaque as core bee_serde doesn't need + // to know about the type of this serialization metadata. But since it would require + // carrying the type everywhere (would make the code more complicated overall) we don't do + // it and accept a little coupling of core bee_serde to the BeeMsg header. It's almost only + // used for BeeMsg anyway (the header itself and v7 data import are the exceptions). + pub header: Header, } macro_rules! fn_serialize_primitive { - ($P:ident, $put_f:ident) => { + ($P:ident) => { pub fn $P(&mut self, v: $P) -> Result<()> { - self.target_buf.$put_f(v); - self.bytes_written += size_of::<$P>(); - Ok(()) + self.bytes(&v.to_le_bytes()) } }; } impl<'a> Serializer<'a> { - /// Creates a new Serializer object - /// - /// `msg_feature_flags` can be accessed from the (de-)serialization definition and is used for - /// conditional serialization on some messages. - /// `msg_feature_flags` is supposed to be obtained from the message definition, and is used - /// for conditional serialization by certain messages. - pub fn new(target_buf: &'a mut BytesMut) -> Self { + /// Creates a new Serializer object, writing into the given buffer. The buffer must be big + /// enough to take all the data. 
+ pub fn with_header(buf: &'a mut [u8], header: Header) -> Self { Self { - target_buf, - msg_feature_flags: 0, - bytes_written: 0, + target_buf: buf, + write_pos: 0, + header, } } - fn_serialize_primitive!(u8, put_u8); - fn_serialize_primitive!(i8, put_i8); - fn_serialize_primitive!(u16, put_u16_le); - fn_serialize_primitive!(i16, put_i16_le); - fn_serialize_primitive!(u32, put_u32_le); - fn_serialize_primitive!(i32, put_i32_le); - fn_serialize_primitive!(u64, put_u64_le); - fn_serialize_primitive!(i64, put_i64_le); - fn_serialize_primitive!(u128, put_u128_le); - fn_serialize_primitive!(i128, put_i128_le); - - /// Serialize the given slice as bytes as expected by BeeGFS + /// Creates a new Serializer object with default header used as metadata. Meant for data that + /// does not do conditional serialization based on these fields (e.g. non BeeMsg or the + /// header itself). + pub fn new(buf: &'a mut [u8]) -> Self { + Self::with_header(buf, Header::default()) + } + + /// Finishes the serialization by consuming the serializer and returning the header + /// that might be set by certain BeeMsgs. + pub fn finish(self) -> Header { + self.header + } + + fn_serialize_primitive!(u8); + fn_serialize_primitive!(i8); + fn_serialize_primitive!(u16); + fn_serialize_primitive!(i16); + fn_serialize_primitive!(u32); + fn_serialize_primitive!(i32); + fn_serialize_primitive!(u64); + fn_serialize_primitive!(i64); + fn_serialize_primitive!(u128); + fn_serialize_primitive!(i128); + + /// Serialize the given slice as bytes. This is also the base operation for the other ops. 
pub fn bytes(&mut self, v: &[u8]) -> Result<()> { - self.target_buf.put(v); - self.bytes_written += v.len(); + match self + .target_buf + .get_mut(self.write_pos..(self.write_pos + v.len())) + { + Some(ref mut sub) => { + sub.clone_from_slice(v); + self.write_pos += v.len(); + } + None => { + bail!( + "Tried to write {} bytes but target buffer only has {} left", + v.len(), + self.target_buf.len() - self.write_pos + ); + } + } + Ok(()) } @@ -142,7 +137,7 @@ impl<'a> Serializer<'a> { include_total_size: bool, f: impl Fn(&mut Self, T) -> Result<()>, ) -> Result<()> { - let before = self.bytes_written; + let before = self.write_pos; // For the total size and length of the sequence we insert placeholders to be replaced // later when the values are known @@ -151,14 +146,14 @@ impl<'a> Serializer<'a> { // `BytesMut` and not the generic `BufMut` - the latter doesn't allow random access to // already written data let size_pos = if include_total_size { - let size_pos = self.bytes_written; + let size_pos = self.write_pos; self.u32(0xFFFFFFFFu32)?; size_pos } else { 0 }; - let count_pos = self.bytes_written; + let count_pos = self.write_pos; self.u32(0xFFFFFFFFu32)?; let mut count = 0u32; @@ -172,15 +167,13 @@ impl<'a> Serializer<'a> { // the placeholders in the beginning of the sequence with the actual values if include_total_size { - let written = (self.bytes_written - before) as u32; - for (p, b) in written.to_le_bytes().iter().enumerate() { - self.target_buf[size_pos + p] = *b; - } + let written = (self.write_pos - before) as u32; + self.target_buf[size_pos..(size_pos + size_of::())] + .clone_from_slice(&written.to_le_bytes()); } - for (p, b) in count.to_le_bytes().iter().enumerate() { - self.target_buf[count_pos + p] = *b; - } + self.target_buf[count_pos..(count_pos + size_of::())] + .clone_from_slice(&count.to_le_bytes()); Ok(()) } @@ -218,39 +211,56 @@ impl<'a> Serializer<'a> { Ok(()) } - /// The amount of bytes written to the buffer (so far) + /// The amount of bytes 
written to the buffer pub fn bytes_written(&self) -> usize { - self.bytes_written + self.write_pos } } -/// Deserializes one BeeGFS message from the given buffer +// DESERIALIZATION + +/// Makes a type BeeSerde deserializable +pub trait Deserializable { + fn deserialize(des: &mut Deserializer<'_>) -> Result + where + Self: Sized; +} + +/// Deserializes one `impl Deserializable` object from a source buffer pub struct Deserializer<'a> { /// The source buffer source_buf: &'a [u8], - /// BeeGFS message feature flags obtained from the message definition, used for - /// conditional deserialization by certain messages. - pub msg_feature_flags: u16, + /// BeeMsg header, used for conditional deserialization by certain messages. Can be + /// accessed from the deserialization definition. + pub header: Cow<'a, Header>, } macro_rules! fn_deserialize_primitive { - ($P:ident, $get_f:ident) => { + ($P:ident) => { pub fn $P(&mut self) -> Result<$P> { - self.check_remaining(size_of::<$P>())?; - Ok(self.source_buf.$get_f()) + let b = self.take(size_of::<$P>())?; + Ok($P::from_le_bytes(b.try_into()?)) } }; } impl<'a> Deserializer<'a> { - /// Creates a new Deserializer object - /// - /// `msg_feature_flags` is supposed to be obtained from the message definition, and is used - /// for conditional serialization by certain messages. - pub fn new(source_buf: &'a [u8], msg_feature_flags: u16) -> Self { + /// Creates a new Deserializer object with the given header used as metadata. Meant for BeeMsg - + /// they sometimes do conditional deserialization based on these fields. + pub fn with_header(buf: &'a [u8], header: &'a Header) -> Self { + Self { + source_buf: buf, + header: Cow::Borrowed(header), + } + } + + /// Creates a new Deserializer object with default header used as metadata. Meant for data that + /// does not do conditional deserialization based on these fields (e.g. non BeeMsg or the + /// header itself). 
+ pub fn new(buf: &'a [u8]) -> Self { Self { - source_buf, - msg_feature_flags, + source_buf: buf, + header: Cow::Owned(Header::default()), } } @@ -265,25 +275,20 @@ impl<'a> Deserializer<'a> { Ok(()) } - fn_deserialize_primitive!(u8, get_u8); - fn_deserialize_primitive!(i8, get_i8); - fn_deserialize_primitive!(u16, get_u16_le); - fn_deserialize_primitive!(i16, get_i16_le); - fn_deserialize_primitive!(u32, get_u32_le); - fn_deserialize_primitive!(i32, get_i32_le); - fn_deserialize_primitive!(u64, get_u64_le); - fn_deserialize_primitive!(i64, get_i64_le); - fn_deserialize_primitive!(u128, get_u128_le); - fn_deserialize_primitive!(i128, get_i128_le); + fn_deserialize_primitive!(u8); + fn_deserialize_primitive!(i8); + fn_deserialize_primitive!(u16); + fn_deserialize_primitive!(i16); + fn_deserialize_primitive!(u32); + fn_deserialize_primitive!(i32); + fn_deserialize_primitive!(u64); + fn_deserialize_primitive!(i64); + fn_deserialize_primitive!(u128); + fn_deserialize_primitive!(i128); /// Deserialize a block of bytes as expected by BeeGFS pub fn bytes(&mut self, len: usize) -> Result> { - let mut v = vec![0; len]; - - self.check_remaining(len)?; - self.source_buf.copy_to_slice(&mut v); - - Ok(v) + Ok(self.take(len)?.to_owned()) } /// Deserialize a BeeGFS serialized c string @@ -293,11 +298,7 @@ impl<'a> Deserializer<'a> { /// don't. pub fn cstr(&mut self, align_to: usize) -> Result> { let len = self.u32()? as usize; - - let mut v = vec![0; len]; - - self.check_remaining(len)?; - self.source_buf.copy_to_slice(&mut v); + let v = self.take(len)?.to_owned(); let terminator: u8 = self.u8()?; if terminator != 0 { @@ -387,28 +388,61 @@ impl<'a> Deserializer<'a> { /// /// The opposite of fill_zeroes() in serialization. 
pub fn skip(&mut self, n: usize) -> Result<()> { - self.check_remaining(n)?; - self.source_buf.advance(n); - + self.take(n)?; Ok(()) } - /// Ensures that the source buffer has at least `n` bytes left - /// - /// Meant to check that there are enough bytes left before calling `Bytes` functions that would - /// panic otherwise (which we wan't to avoid) - fn check_remaining(&self, n: usize) -> Result<()> { - if self.source_buf.remaining() < n { - bail!( - "Unexpected end of source buffer. Needed at least {}, got {}", - n, - self.source_buf.remaining() - ); + /// Takes the next n bytes from the source buffer, checking that there are enough left. + fn take(&mut self, n: usize) -> Result<&[u8]> { + match self.source_buf.split_at_checked(n) { + Some((taken, rest)) => { + self.source_buf = rest; + Ok(taken) + } + None => { + bail!( + "Unexpected end of source buffer. Needed at least {n}, got {}", + self.source_buf.len() + ); + } } - Ok(()) } } +// HELPER / CONVENIENCE FUNCTIONS + +/// Provides conversion functionality to and from BeeSerde serializable types. +/// +/// Mainly meant for enums that need to be converted in to a raw integer type, which also might +/// differ between messages. The generic parameter allows implementing it for multiple types. +pub trait BeeSerdeConversion: Sized { + fn into_bee_serde(self) -> S; + fn try_from_bee_serde(value: S) -> Result; +} + +/// Interface for serialization helpers to be used with the `bee_serde` derive macro +/// +/// Serialization helpers are meant to control the `bee_serde` macro in case a value in the +/// message struct shall be serialized as a different type or in case it doesn't have its own +/// [BeeSerde] implementation. Also necessary for maps and sequences since the serializer can't +/// know on its own whether to include collection size or not (it's totally message dependent). 
+/// +/// # Example +/// +/// ```ignore +/// #[derive(Debug, BeeSerde)] +/// pub struct ExampleMsg { +/// // Serializer doesn't know by itself whether or not C/C++ BeeGFS serializer expects sequence +/// // size included or not - in this case it is not +/// #[bee_serde(as = Seq)] +/// int_sequence: Vec, +/// } +/// ``` +pub trait BeeSerdeHelper { + fn serialize_as(data: &In, ser: &mut Serializer<'_>) -> Result<()>; + fn deserialize_as(des: &mut Deserializer<'_>) -> Result; +} + /// Serialize an arbitrary type as Integer /// /// Note: Can potentially be used for non-integers, but is not practical due to the [Copy] @@ -530,7 +564,7 @@ mod test { #[test] fn primitives() { - let mut buf = BytesMut::new(); + let mut buf = vec![0; 1 + 1 + 2 + 2 + 4 + 4 + 8 + 8]; let mut ser = Serializer::new(&mut buf); ser.u8(123).unwrap(); @@ -542,10 +576,7 @@ mod test { ser.u64(0xAABBCCDDEEFF1122u64).unwrap(); ser.i64(-0x1ABBCCDDEEFF1122i64).unwrap(); - // 1 + 2 + 2 + 4 + 4 + 8 - assert_eq!(1 + 1 + 2 + 2 + 4 + 4 + 8 + 8, ser.bytes_written); - - let mut des = Deserializer::new(&buf, 0); + let mut des = Deserializer::new(&buf); assert_eq!(123, des.u8().unwrap()); assert_eq!(-123, des.i8().unwrap()); assert_eq!(22222, des.u16().unwrap()); @@ -561,17 +592,14 @@ mod test { #[test] fn bytes() { - let bytes: Vec = vec![0, 1, 2, 3, 4, 5]; - - let mut buf = BytesMut::new(); + let bytes = vec![0, 1, 2, 3, 4, 5]; + let mut buf = vec![0; 12]; let mut ser = Serializer::new(&mut buf); ser.bytes(&bytes).unwrap(); ser.bytes(&bytes).unwrap(); - assert_eq!(12, ser.bytes_written); - - let mut des = Deserializer::new(&buf, 0); + let mut des = Deserializer::new(&buf); assert_eq!(bytes, des.bytes(6).unwrap()); assert_eq!(bytes, des.bytes(6).unwrap()); @@ -581,23 +609,17 @@ mod test { #[test] fn cstr() { let str: Vec = "text".into(); - - let mut buf = BytesMut::new(); + // alignment applies to string length + null byte terminator + // Last one with align_to = 5 is intended and correct: Wrote 9 bytes, 9 % 
align_to = 1, + // align_to - 1 = 4 + let mut buf = vec![0; (4 + 4 + 1) + (4 + 4 + 1) + (4 + 4 + 1 + 4)]; let mut ser = Serializer::new(&mut buf); ser.cstr(&str, 0).unwrap(); ser.cstr(&str, 4).unwrap(); ser.cstr(&str, 5).unwrap(); - assert_eq!( - // alignment applies to string length + null byte terminator - // Last one with align_to = 5 is intended and correct: Wrote 9 bytes, 9 % align_to = 1, - // align_to - 1 = 4 - (4 + 4 + 1) + (4 + 4 + 1) + (4 + 4 + 1 + 4), - ser.bytes_written - ); - - let mut des = Deserializer::new(&buf, 0); + let mut des = Deserializer::new(&buf); assert_eq!(str, des.cstr(0).unwrap()); assert_eq!(str, des.cstr(4).unwrap()); assert_eq!(str, des.cstr(5).unwrap()); @@ -683,22 +705,19 @@ mod test { c2: HashMap::from([(18, vec!["aaa".into(), "bbbbb".into()])]), }; - let mut buf = BytesMut::new(); - - let mut ser = Serializer::new(&mut buf); - - s.serialize(&mut ser).unwrap(); - - assert_eq!( + let mut buf = vec![ + 0; 1 + 8 + (8 + 3 * 8) + (4 + 2 + 8) + (8 + (4 + 2 + 4)) - + (8 + (2 + (4 + (4 + 3 + 1) + (4 + 5 + 1)))), - ser.bytes_written - ); + + (8 + (2 + (4 + (4 + 3 + 1) + (4 + 5 + 1)))) + ]; - let mut des = Deserializer::new(&buf, 0); + let mut ser = Serializer::new(&mut buf); + s.serialize(&mut ser).unwrap(); + + let mut des = Deserializer::new(&buf); let s2 = S::deserialize(&mut des).unwrap(); @@ -708,23 +727,19 @@ mod test { #[test] fn wrong_buffer_len() { - let bytes: Vec = vec![0, 1, 2, 3, 4, 5]; + let mut buf = vec![0, 1, 2, 3, 4, 5]; - let mut buf = BytesMut::new(); let mut ser = Serializer::new(&mut buf); - ser.bytes(&bytes).unwrap(); + // Write too much + ser.u64(123).unwrap_err(); - let mut des = Deserializer::new(&buf, 0); + let mut des = Deserializer::new(&buf); des.bytes(5).unwrap(); - // Some buffer left des.finish().unwrap_err(); - // Consume too much des.bytes(2).unwrap_err(); - des.bytes(1).unwrap(); - // Complete buffer consumed des.finish().unwrap(); } diff --git a/shared/src/conn.rs b/shared/src/conn.rs index 
840f7fe..559da65 100644 --- a/shared/src/conn.rs +++ b/shared/src/conn.rs @@ -2,11 +2,17 @@ mod async_queue; pub mod incoming; -mod msg_buf; pub mod msg_dispatch; -mod outgoing; +pub mod outgoing; mod store; mod stream; -pub use self::msg_buf::MsgBuf; -pub use outgoing::*; +/// Fixed length of the stream / TCP message buffers. +/// Must match the `WORKER_BUF(IN|OUT)_SIZE` value in `Worker.h` in the C++ +/// codebase. +const TCP_BUF_LEN: usize = 4 * 1024 * 1024; + +/// Fixed length of the datagram / UDP message buffers. +/// Must match the `DGRAMMR_(RECV|SEND)BUF_SIZE` value in `DatagramListener.*` in the C/C++ +/// codebase. Must be smaller than TCP_BUF_LEN; +const UDP_BUF_LEN: usize = 65536; diff --git a/shared/src/conn/incoming.rs b/shared/src/conn/incoming.rs index 5e1fd55..33cf52e 100644 --- a/shared/src/conn/incoming.rs +++ b/shared/src/conn/incoming.rs @@ -1,12 +1,12 @@ //! Handle incoming TCP and UDP connections and BeeMsgs. -use super::msg_buf::MsgBuf; use super::msg_dispatch::{DispatchRequest, SocketRequest, StreamRequest}; use super::stream::Stream; -use crate::bee_msg::Msg; +use super::*; use crate::bee_msg::misc::AuthenticateChannel; +use crate::bee_msg::{Header, Msg, deserialize_header}; use crate::run_state::RunStateHandle; -use anyhow::{Result, bail}; +use anyhow::{Context, Result, bail}; use std::io::{self, ErrorKind}; use std::net::SocketAddr; use std::sync::Arc; @@ -86,7 +86,7 @@ async fn stream_loop( log::debug!("Accepted incoming stream from {:?}", stream.addr()); // Use one owned buffer for reading into and writing from. - let mut buf = MsgBuf::default(); + let mut buf = vec![0; TCP_BUF_LEN]; loop { // Wait for available data or shutdown signal @@ -134,28 +134,41 @@ async fn stream_loop( /// handler and sending back a response using the [`StreamRequest`] handle. 
async fn read_stream( stream: &mut Stream, - buf: &mut MsgBuf, + buf: &mut [u8], dispatch: &impl DispatchRequest, stream_authentication_required: bool, ) -> Result<()> { - buf.read_from_stream(stream).await?; + // Read header + stream.read_exact(&mut buf[0..Header::LEN]).await?; + + let header = deserialize_header(&buf[0..Header::LEN])?; // check authentication if stream_authentication_required && !stream.authenticated - && buf.msg_id() != AuthenticateChannel::ID + && header.msg_id() != AuthenticateChannel::ID { bail!( "Stream is not authenticated and received message with id {}", - buf.msg_id() + header.msg_id() ); } + // Read body + stream + .read_exact(&mut buf[Header::LEN..header.msg_len()]) + .await?; + // Forward to the dispatcher. The dispatcher is responsible for deserializing, dispatching to // msg handlers and sending a response using the [`StreamRequest`] handle. dispatch - .dispatch_request(StreamRequest { stream, buf }) - .await?; + .dispatch_request(StreamRequest { + stream, + buf, + header: &header, + }) + .await + .context("Stream msg dispatch failed")?; Ok(()) } @@ -210,8 +223,11 @@ async fn recv_datagram(sock: Arc, msg_handler: impl DispatchRequest) // message spawns a new task (below) and we don't know how long the processing takes, we cannot // reuse Buffers like the TCP reader does. // A separate buffer pool could potentially be used to avoid allocating new buffers every time. 
- let mut buf = MsgBuf::default(); - let peer_addr = buf.recv_from_socket(&sock).await?; + let mut buf = vec![0; UDP_BUF_LEN]; + + let (_, peer_addr) = sock.recv_from(&mut buf).await?; + + let header = deserialize_header(&buf[0..Header::LEN])?; // Request shall be handled in a separate task, so the next datagram can be processed // immediately @@ -219,7 +235,8 @@ async fn recv_datagram(sock: Arc, msg_handler: impl DispatchRequest) let req = SocketRequest { sock, peer_addr, - msg_buf: &mut buf, + buf: &mut buf, + header: &header, }; // Forward to the dispatcher diff --git a/shared/src/conn/msg_buf.rs b/shared/src/conn/msg_buf.rs deleted file mode 100644 index bab3bff..0000000 --- a/shared/src/conn/msg_buf.rs +++ /dev/null @@ -1,180 +0,0 @@ -//! Reusable buffer for serialized BeeGFS messages -//! -//! This buffer provides the memory and the functionality to (de-)serialize BeeGFS messages from / -//! into and read / write / send / receive from / to streams and UDP sockets. -//! -//! They are meant to be used in two steps: -//! Serialize a message first, then write or send it to the wire. -//! OR -//! Read or receive data from the wire, then deserialize it into a message. -//! -//! # Example: Reading from stream -//! 1. `.read_from_stream()` to read in the data from stream into the buffer -//! 2. `.deserialize_msg()` to deserialize the message from the buffer -//! -//! # Important -//! If receiving data failed part way or didn't happen at all before calling `deserialize_msg`, the -//! buffer is in an invalid state. Deserializing will then most likely fail, or worse, succeed and -//! provide old or garbage data. The same applies for the opposite direction. It's up to the user to -//! make sure the buffer is used the appropriate way. 
-use super::stream::Stream; -use crate::bee_msg::header::Header; -use crate::bee_msg::{Msg, MsgId}; -use crate::bee_serde::{Deserializable, Deserializer, Serializable, Serializer}; -use anyhow::{Context, Result, bail}; -use bytes::BytesMut; -use std::net::SocketAddr; -use std::sync::Arc; -use tokio::net::UdpSocket; - -/// Fixed length of the datagrams to send and receive via UDP. -/// -/// Must match DGRAMMGR_*BUF_SIZE in `AbstractDatagramListener.h` (common) and `DatagramListener.h` -/// (client_module). -const DATAGRAM_LEN: usize = 65536; - -/// Reusable buffer for serialized BeeGFS messages -/// -/// See module level documentation for more information. -#[derive(Debug, Default)] -pub struct MsgBuf { - buf: BytesMut, - header: Box
, -} - -impl MsgBuf { - /// Serializes a BeeGFS message into the buffer - pub fn serialize_msg(&mut self, msg: &M) -> Result<()> { - self.buf.truncate(0); - - if self.buf.capacity() < Header::LEN { - self.buf.reserve(Header::LEN); - } - - // We need to serialize the body first since we need its total length for the header. - // Therefore, the body part (which comes AFTER the header) is split off to be passed as a - // separate BytesMut to the serializer. - let mut body = self.buf.split_off(Header::LEN); - - // Catching serialization errors to ensure buffer is unsplit afterwards in all cases - let res = (|| { - // Serialize body - let mut ser_body = Serializer::new(&mut body); - msg.serialize(&mut ser_body) - .context("BeeMsg body serialization failed")?; - - // Create and serialize header - let header = Header::new(ser_body.bytes_written(), M::ID, ser_body.msg_feature_flags); - let mut ser_header = Serializer::new(&mut self.buf); - header - .serialize(&mut ser_header) - .context("BeeMsg header serialization failed")?; - - *self.header = header; - - Ok(()) as Result<_> - })(); - - // Put header and body back together - self.buf.unsplit(body); - - res - } - - /// Deserializes the BeeGFS message present in the buffer - /// - /// # Panic - /// The function will panic if the buffer has not been filled with data before (e.g. 
by - /// reading from stream or receiving from a socket) - pub fn deserialize_msg(&self) -> Result { - const ERR_CTX: &str = "BeeMsg body deserialization failed"; - - let mut des = Deserializer::new(&self.buf[Header::LEN..], self.header.msg_feature_flags); - let des_msg = M::deserialize(&mut des).context(ERR_CTX)?; - des.finish().context(ERR_CTX)?; - Ok(des_msg) - } - - /// Reads a BeeGFS message from a stream into the buffer - pub(super) async fn read_from_stream(&mut self, stream: &mut Stream) -> Result<()> { - if self.buf.len() < Header::LEN { - self.buf.resize(Header::LEN, 0); - } - - stream.read_exact(&mut self.buf[0..Header::LEN]).await?; - let header = Header::from_buf(&self.buf[0..Header::LEN]) - .context("BeeMsg header deserialization failed")?; - let msg_len = header.msg_len(); - - if self.buf.len() != msg_len { - self.buf.resize(msg_len, 0); - } - - stream - .read_exact(&mut self.buf[Header::LEN..msg_len]) - .await?; - - *self.header = header; - - Ok(()) - } - - /// Writes the BeeGFS message from the buffer to a stream - /// - /// # Panic - /// The function will panic if the buffer has not been filled with data before (e.g. by - /// serializing a message) - pub(super) async fn write_to_stream(&self, stream: &mut Stream) -> Result<()> { - stream - .write_all(&self.buf[0..self.header.msg_len()]) - .await?; - Ok(()) - } - - /// Receives a BeeGFS message from a UDP socket into the buffer - pub(super) async fn recv_from_socket(&mut self, sock: &Arc) -> Result { - if self.buf.len() != DATAGRAM_LEN { - self.buf.resize(DATAGRAM_LEN, 0); - } - - match sock.recv_from(&mut self.buf).await { - Ok(n) => { - let header = Header::from_buf(&self.buf[0..Header::LEN])?; - self.buf.truncate(header.msg_len()); - *self.header = header; - Ok(n.1) - } - Err(err) => Err(err.into()), - } - } - - /// Sends the BeeGFS message in the buffer to a UDP socket - /// - /// # Panic - /// The function will panic if the buffer has not been filled with data before (e.g. 
by - /// serializing a message) - pub(super) async fn send_to_socket( - &self, - sock: &UdpSocket, - peer_addr: &SocketAddr, - ) -> Result<()> { - if self.buf.len() > DATAGRAM_LEN { - bail!( - "Datagram to be sent to {peer_addr:?} exceeds maximum length of {DATAGRAM_LEN} \ - bytes" - ); - } - - sock.send_to(&self.buf, peer_addr).await?; - Ok(()) - } - - /// The [MsgID] of the serialized BeeGFS message in the buffer - /// - /// # Panic - /// The function will panic if the buffer has not been filled with data before (e.g. by - /// reading from stream or receiving from a socket) - pub fn msg_id(&self) -> MsgId { - self.header.msg_id - } -} diff --git a/shared/src/conn/msg_dispatch.rs b/shared/src/conn/msg_dispatch.rs index 66b6e01..154f811 100644 --- a/shared/src/conn/msg_dispatch.rs +++ b/shared/src/conn/msg_dispatch.rs @@ -1,8 +1,7 @@ //! Facilities for dispatching TCP and UDP messages to their message handlers -use super::msg_buf::MsgBuf; use super::stream::Stream; -use crate::bee_msg::{Msg, MsgId}; +use crate::bee_msg::{Header, Msg, MsgId, deserialize_body, serialize}; use crate::bee_serde::{Deserializable, Serializable}; use anyhow::Result; use std::fmt::Debug; @@ -34,13 +33,14 @@ pub trait Request: Send + Sync { #[derive(Debug)] pub struct StreamRequest<'a> { pub(super) stream: &'a mut Stream, - pub(super) buf: &'a mut MsgBuf, + pub(super) buf: &'a mut [u8], + pub header: &'a Header, } impl Request for StreamRequest<'_> { async fn respond(self, msg: &M) -> Result<()> { - self.buf.serialize_msg(msg)?; - self.buf.write_to_stream(self.stream).await + let msg_len = serialize(msg, self.buf)?; + self.stream.write_all(&self.buf[0..msg_len]).await } fn authenticate_connection(&mut self) { @@ -58,11 +58,11 @@ impl Request for StreamRequest<'_> { } fn deserialize_msg(&self) -> Result { - self.buf.deserialize_msg() + deserialize_body(self.header, &self.buf[Header::LEN..]) } fn msg_id(&self) -> MsgId { - self.buf.msg_id() + self.header.msg_id() } } @@ -71,16 +71,17 @@ impl 
Request for StreamRequest<'_> { pub struct SocketRequest<'a> { pub(crate) sock: Arc, pub(crate) peer_addr: SocketAddr, - pub(crate) msg_buf: &'a mut MsgBuf, + pub(crate) buf: &'a mut [u8], + pub header: &'a Header, } impl Request for SocketRequest<'_> { async fn respond(self, msg: &M) -> Result<()> { - self.msg_buf.serialize_msg(msg)?; - - self.msg_buf - .send_to_socket(&self.sock, &self.peer_addr) - .await + let msg_len = serialize(msg, self.buf)?; + self.sock + .send_to(&self.buf[0..msg_len], &self.peer_addr) + .await?; + Ok(()) } fn authenticate_connection(&mut self) { @@ -92,10 +93,10 @@ impl Request for SocketRequest<'_> { } fn deserialize_msg(&self) -> Result { - self.msg_buf.deserialize_msg() + deserialize_body(self.header, &self.buf[Header::LEN..]) } fn msg_id(&self) -> MsgId { - self.msg_buf.msg_id() + self.header.msg_id() } } diff --git a/shared/src/conn/outgoing.rs b/shared/src/conn/outgoing.rs index 7b6dc5d..0ac949f 100644 --- a/shared/src/conn/outgoing.rs +++ b/shared/src/conn/outgoing.rs @@ -1,9 +1,9 @@ //! 
Outgoing communication functionality -use super::msg_buf::MsgBuf; use super::store::Store; -use crate::bee_msg::Msg; use crate::bee_msg::misc::AuthenticateChannel; +use crate::bee_msg::{Header, Msg, deserialize_body, deserialize_header, serialize}; use crate::bee_serde::{Deserializable, Serializable}; +use crate::conn::TCP_BUF_LEN; use crate::conn::store::StoredStream; use crate::conn::stream::Stream; use crate::types::{AuthSecret, Uid}; @@ -50,27 +50,27 @@ impl Pool { ) -> Result { log::trace!("REQUEST to {:?}: {:?}", node_uid, msg); - let mut buf = self.store.pop_buf().unwrap_or_default(); + let mut buf = self.store.pop_buf_or_create(); - buf.serialize_msg(msg)?; - self.comm_stream(node_uid, &mut buf, true).await?; - let resp = buf.deserialize_msg()?; + let msg_len = serialize(msg, &mut buf)?; + let resp_header = self.comm_stream(node_uid, &mut buf, msg_len, true).await?; + let resp_msg = deserialize_body(&resp_header, &buf[Header::LEN..])?; self.store.push_buf(buf); - log::trace!("RESPONSE RECEIVED from {:?}: {:?}", node_uid, resp); + log::trace!("RESPONSE RECEIVED from {:?}: {:?}", node_uid, resp_msg); - Ok(resp) + Ok(resp_msg) } /// Sends a [Msg] to a node and does **not** receive a response. pub async fn send(&self, node_uid: Uid, msg: &M) -> Result<()> { log::trace!("SEND to {:?}: {:?}", node_uid, msg); - let mut buf = self.store.pop_buf().unwrap_or_default(); + let mut buf = self.store.pop_buf_or_create(); - buf.serialize_msg(msg)?; - self.comm_stream(node_uid, &mut buf, false).await?; + let msg_len = serialize(msg, &mut buf)?; + self.comm_stream(node_uid, &mut buf, msg_len, false).await?; self.store.push_buf(buf); @@ -92,16 +92,19 @@ impl Pool { async fn comm_stream( &self, node_uid: Uid, - buf: &mut MsgBuf, + buf: &mut [u8], + send_len: usize, expect_response: bool, - ) -> Result<()> { + ) -> Result
{ + debug_assert_eq!(buf.len(), TCP_BUF_LEN); + // 1. Pop open streams until communication succeeds or none are left while let Some(stream) = self.store.try_pop_stream(node_uid) { match self - .write_and_read_stream(buf, stream, expect_response) + .write_and_read_stream(buf, stream, send_len, expect_response) .await { - Ok(_) => return Ok(()), + Ok(header) => return Ok(header), Err(err) => { // If the stream doesn't work anymore, just discard it and try the next one log::debug!("Communication using existing stream to {node_uid:?} failed: {err}") @@ -129,22 +132,27 @@ impl Pool { if let Some(auth_secret) = self.auth_secret { // The provided buffer contains the actual message to be sent later - // obtain an additional one for the auth message - let mut auth_buf = self.store.pop_buf().unwrap_or_default(); - auth_buf.serialize_msg(&AuthenticateChannel { auth_secret })?; - auth_buf - .write_to_stream(stream.as_mut()) + let mut auth_buf = self.store.pop_buf_or_create(); + let msg_len = + serialize(&AuthenticateChannel { auth_secret }, &mut auth_buf)?; + + stream + .as_mut() + .write_all(&auth_buf[0..msg_len]) .await .with_context(err_context)?; + self.store.push_buf(auth_buf); } // Communication using the newly opened stream should usually not fail. If // it does, abort. It might be better to just try the next address though. - self.write_and_read_stream(buf, stream, expect_response) + let resp_header = self + .write_and_read_stream(buf, stream, send_len, expect_response) .await .with_context(err_context)?; - return Ok(()); + return Ok(resp_header); } // If connecting failed, try the next address Err(err) => log::debug!("Connecting to {node_uid:?} via {addr} failed: {err}"), @@ -158,31 +166,44 @@ impl Pool { // 3. 
Wait for an already open stream becoming available let stream = self.store.pop_stream(node_uid).await?; - self.write_and_read_stream(buf, stream, expect_response) + let resp_header = self + .write_and_read_stream(buf, stream, send_len, expect_response) .await .with_context(|| { format!("Communication using existing stream to {node_uid:?} failed") })?; - Ok(()) + Ok(resp_header) } /// Writes data to the given stream, optionally receives a response and pushes the stream to /// the store async fn write_and_read_stream( &self, - buf: &mut MsgBuf, + buf: &mut [u8], mut stream: StoredStream, + send_len: usize, expect_response: bool, - ) -> Result<()> { - buf.write_to_stream(stream.as_mut()).await?; - - if expect_response { - buf.read_from_stream(stream.as_mut()).await?; - } + ) -> Result
{ + stream.as_mut().write_all(&buf[0..send_len]).await?; + + let header = if expect_response { + // Read header + stream.as_mut().read_exact(&mut buf[0..Header::LEN]).await?; + let header = deserialize_header(&buf[0..Header::LEN])?; + + // Read body + stream + .as_mut() + .read_exact(&mut buf[Header::LEN..header.msg_len()]) + .await?; + header + } else { + Header::default() + }; self.store.push_stream(stream); - Ok(()) + Ok(header) } pub async fn broadcast_datagram( @@ -190,8 +211,9 @@ impl Pool { peers: impl IntoIterator, msg: &M, ) -> Result<()> { - let mut buf = self.store.pop_buf().unwrap_or_default(); - buf.serialize_msg(msg)?; + let mut buf = self.store.pop_buf_or_create(); + + let msg_len = serialize(msg, &mut buf)?; for node_uid in peers { let Some(addrs) = self.store.get_node_addrs(node_uid) else { @@ -199,7 +221,7 @@ impl Pool { }; for addr in addrs.iter() { - buf.send_to_socket(&self.udp_socket, addr).await?; + self.udp_socket.send_to(&buf[0..msg_len], addr).await?; } } diff --git a/shared/src/conn/store.rs b/shared/src/conn/store.rs index 55e1f4c..fe0231d 100644 --- a/shared/src/conn/store.rs +++ b/shared/src/conn/store.rs @@ -3,8 +3,8 @@ //! //! Also provides a permit system to limit outgoing connections to a defined maximum. 
+use super::TCP_BUF_LEN; use super::async_queue::AsyncQueue; -use crate::conn::MsgBuf; use crate::conn::stream::Stream; use crate::types::Uid; use anyhow::{Result, anyhow}; @@ -23,7 +23,7 @@ const TIMEOUT: Duration = Duration::from_secs(2); pub struct Store { #[allow(clippy::type_complexity)] streams: Mutex>, Arc)>>, - bufs: Mutex>, + bufs: Mutex>>, addrs: RwLock>>, connection_limit: usize, } @@ -112,12 +112,17 @@ impl Store { } /// Pop a message buffer from the store - pub fn pop_buf(&self) -> Option { + pub fn pop_buf(&self) -> Option> { self.bufs.lock().unwrap().pop_front() } + /// Pop a message buffer from the store or create a new one suitable for stream / TCP messages + pub fn pop_buf_or_create(&self) -> Vec { + self.pop_buf().unwrap_or_else(|| vec![0; TCP_BUF_LEN]) + } + /// Push back a message buffer to the store - pub fn push_buf(&self, buf: MsgBuf) { + pub fn push_buf(&self, buf: Vec) { self.bufs.lock().unwrap().push_back(buf); } diff --git a/shared/src/journald_logger.rs b/shared/src/journald_logger.rs index f1b50b5..1da217c 100644 --- a/shared/src/journald_logger.rs +++ b/shared/src/journald_logger.rs @@ -1,6 +1,5 @@ //! Journald logger implementation for the `log` interface -use bytes::BufMut; use log::{Level, LevelFilter, Log, Metadata, Record}; use std::os::unix::net::UnixDatagram; @@ -37,7 +36,7 @@ impl Log for JournaldLogger { .into_bytes(); buf.reserve(msg.len() + 8 + 1); - buf.put_u64_le(msg.len() as u64); + buf.extend((msg.len() as u64).to_le_bytes()); buf.extend(msg); buf.extend(b"\n"); From 9f98fe466de6e6c56c94ceaeb8aaab2ecb7ea3f4 Mon Sep 17 00:00:00 2001 From: Rusty Bee <145002912+rustybee42@users.noreply.github.com> Date: Fri, 18 Jul 2025 09:05:56 +0200 Subject: [PATCH 2/2] fix: don't break udp socket when receiving an invalid message, improve logging Found by test_empty_udp_packets_server integration test. 
The test needs to be changed as well to adapt to the corrected log output (it was not really correct before) --- shared/src/conn/incoming.rs | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/shared/src/conn/incoming.rs b/shared/src/conn/incoming.rs index 33cf52e..fd970cc 100644 --- a/shared/src/conn/incoming.rs +++ b/shared/src/conn/incoming.rs @@ -199,8 +199,7 @@ pub fn recv_udp( // Do the actual work res = recv_datagram(sock.clone(), dispatch.clone()) => { if let Err(err) = res { - log::error!("Error in UDP socket {sock:?}: {err:#}"); - break; + log::error!("Error on receiving datagram using UDP socket {:?}: {err:#}", sock.local_addr()); } } @@ -227,20 +226,26 @@ async fn recv_datagram(sock: Arc, msg_handler: impl DispatchRequest) let (_, peer_addr) = sock.recv_from(&mut buf).await?; - let header = deserialize_header(&buf[0..Header::LEN])?; - // Request shall be handled in a separate task, so the next datagram can be processed // immediately tokio::spawn(async move { - let req = SocketRequest { - sock, - peer_addr, - buf: &mut buf, - header: &header, - }; + if let Err(err) = async { + let header = deserialize_header(&buf[0..Header::LEN])?; + + let req = SocketRequest { + sock, + peer_addr, + buf: &mut buf, + header: &header, + }; - // Forward to the dispatcher - if let Err(err) = msg_handler.dispatch_request(req).await { + // Forward to the dispatcher + msg_handler.dispatch_request(req).await?; + + Ok::<(), anyhow::Error>(()) + } + .await + { log::error!("Error while handling datagram from {peer_addr:?}: {err:#}"); } });