From e8d9b8308505d7ecd434d52d52a7e8d28a1c603b Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Sat, 16 Nov 2024 21:46:06 +0200 Subject: [PATCH 01/63] ZMQ skeleton --- lib/protoflow-zeromq/Cargo.toml | 3 +- lib/protoflow-zeromq/src/lib.rs | 69 +++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/lib/protoflow-zeromq/Cargo.toml b/lib/protoflow-zeromq/Cargo.toml index 00fde5ab..1cc3329d 100644 --- a/lib/protoflow-zeromq/Cargo.toml +++ b/lib/protoflow-zeromq/Cargo.toml @@ -27,6 +27,7 @@ cfg_aliases.workspace = true [dependencies] protoflow-core.workspace = true tracing = { version = "0.1", default-features = false, optional = true } -#zeromq = { version = "0.4", default-features = false } +zeromq = { version = "0.4.1", default-features = false, features = ["tokio-runtime", "all-transport"] } +tokio = { version = "1.40.0", default-features = false } [dev-dependencies] diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 0f186b2c..cea14760 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -5,3 +5,72 @@ #[doc(hidden)] pub use protoflow_core::prelude; + +extern crate std; + +use protoflow_core::{prelude::Bytes, InputPortID, OutputPortID, PortResult, PortState, Transport}; + +use zeromq::{Socket, SocketRecv, SocketSend}; + +pub struct ZMQTransport { + psock: zeromq::PubSocket, + ssock: zeromq::SubSocket, + tokio: tokio::runtime::Handle, +} + +impl ZMQTransport { + pub fn new(url: &str) -> Self { + let tokio = tokio::runtime::Handle::current(); + let mut psock = zeromq::PubSocket::new(); + tokio.block_on(psock.connect(url)).expect("psock conn"); + let mut ssock = zeromq::SubSocket::new(); + tokio.block_on(ssock.connect(url)).expect("ssock conn"); + Self { + psock, + ssock, + tokio, + } + } +} + +impl Transport for ZMQTransport { + fn input_state(&self, input: InputPortID) -> PortResult { + todo!(); + } + + fn output_state(&self, output: OutputPortID) -> PortResult { + todo!(); + } + + fn open_input(&self) -> PortResult { + todo!(); + } + + fn open_output(&self) -> PortResult { + todo!(); + } + + fn close_input(&self, input: InputPortID) -> PortResult { + todo!(); + } + + fn close_output(&self, output: OutputPortID) -> PortResult { + todo!(); + } + + fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { + todo!(); + } + + fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { + todo!(); + } + + fn recv(&self, input: InputPortID) -> PortResult> { + todo!(); + } + + fn try_recv(&self, _input: InputPortID) -> PortResult> { + todo!(); + } +} From 838bdebc82d0a1e2d01cf4729f4aaeb8f39b9052 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 18 Nov 2024 19:26:06 +0200 Subject: [PATCH 02/63] Add rudimentary `send` and `recv` --- lib/protoflow-zeromq/Cargo.toml | 8 ++- lib/protoflow-zeromq/src/lib.rs | 93 ++++++++++++++++++++++++++++----- 2 files changed, 87 insertions(+), 14 deletions(-) diff --git a/lib/protoflow-zeromq/Cargo.toml b/lib/protoflow-zeromq/Cargo.toml index 1cc3329d..c911369c 100644 --- a/lib/protoflow-zeromq/Cargo.toml +++ b/lib/protoflow-zeromq/Cargo.toml @@ -17,7 +17,7 @@ publish.workspace = true [features] default = ["all", "std"] all = ["tracing"] -std = ["protoflow-core/std", "tracing?/std"] #, "zeromq/default"] +std = ["protoflow-core/std", "tracing?/std"] tracing = ["protoflow-core/tracing", "dep:tracing"] unstable = ["protoflow-core/unstable"] @@ -27,7 +27,11 @@ cfg_aliases.workspace = true [dependencies] protoflow-core.workspace = true tracing = { version = "0.1", default-features = false, optional = true } -zeromq = { version = "0.4.1", default-features = false, features = ["tokio-runtime", "all-transport"] } +zeromq = { version = "0.4.1", default-features = false, features = [ + "tokio-runtime", + "all-transport", +] } tokio = { version = "1.40.0", default-features = false } +parking_lot = "0.12" [dev-dependencies] diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index cea14760..6ab44c85 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -8,27 +8,67 @@ pub use protoflow_core::prelude; extern crate std; -use protoflow_core::{prelude::Bytes, InputPortID, OutputPortID, PortResult, PortState, Transport}; +use protoflow_core::{ + prelude::{BTreeMap, Bytes}, + InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, +}; -use zeromq::{Socket, SocketRecv, SocketSend}; +use parking_lot::{Mutex, RwLock}; +use std::sync::mpsc::{Receiver, SyncSender}; +use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend}; pub struct ZMQTransport { - psock: zeromq::PubSocket, - ssock: zeromq::SubSocket, + sock: Mutex, tokio: tokio::runtime::Handle, + outputs: BTreeMap>, + inputs: BTreeMap>, +} + +#[derive(Debug, Clone, Default)] +enum ZmqOutputPortState { + #[default] + Open, + Connected(SyncSender), + Closed, +} + +#[derive(Debug, Default)] +enum ZmqInputPortState { + #[default] + Open, + Connected(Mutex>), + Closed, +} + +#[derive(Clone, Debug)] +enum ZmqTransportEvent { + Connect, + Message(Bytes), + Disconnect, } impl ZMQTransport { pub fn new(url: &str) -> Self { let tokio = tokio::runtime::Handle::current(); - let mut psock = zeromq::PubSocket::new(); - tokio.block_on(psock.connect(url)).expect("psock conn"); - let mut ssock = zeromq::SubSocket::new(); - tokio.block_on(ssock.connect(url)).expect("ssock conn"); + + let peer_id = PeerIdentity::new(); + let mut sock_opts = SocketOptions::default(); + sock_opts.peer_identity(peer_id); + + let mut sock = zeromq::RouterSocket::with_options(sock_opts); + tokio + .block_on(sock.connect(url)) + .expect("failed to connect"); + let sock = Mutex::new(sock); + + let outputs = BTreeMap::default(); + let inputs = BTreeMap::default(); + Self { - psock, - ssock, + sock, tokio, + outputs, + inputs, } } } @@ -63,11 +103,40 @@ impl Transport for ZMQTransport { } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { - todo!(); + let Some(output_state) = self.outputs.get(&output) else { + todo!(); + }; + + use ZmqOutputPortState::*; + match *output_state.read() { + Open => todo!(), + Closed => todo!(), + Connected(ref sender) => { + let msg = ZmqTransportEvent::Message(message); + Ok(sender.send(msg).unwrap()) + } + } } fn recv(&self, input: InputPortID) -> PortResult> { - todo!(); + let Some(input_state) = self.inputs.get(&input) else { + todo!(); + }; + + use ZmqInputPortState::*; + match *input_state.read() { + Open => todo!(), + Closed => todo!(), + Connected(ref receiver) => { + use ZmqTransportEvent::*; + let receiver = receiver.lock(); + match receiver.recv().map_err(|_| PortError::Disconnected)? { + Connect => todo!(), + Disconnect => todo!(), + Message(bytes) => Ok(Some(bytes)), + } + } + } } fn try_recv(&self, _input: InputPortID) -> PortResult> { From 8cd9b4f02ccff88a5fd16db40fc54470743a579f Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Tue, 19 Nov 2024 07:54:05 +0200 Subject: [PATCH 03/63] wip --- lib/protoflow-zeromq/src/lib.rs | 54 +++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 6ab44c85..15d521bb 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -9,7 +9,7 @@ pub use protoflow_core::prelude; extern crate std; use protoflow_core::{ - prelude::{BTreeMap, Bytes}, + prelude::{BTreeMap, Bytes, ToString}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; @@ -18,7 +18,8 @@ use std::sync::mpsc::{Receiver, SyncSender}; use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend}; pub struct ZMQTransport { - sock: Mutex, + psock: Mutex, + ssock: Mutex, tokio: tokio::runtime::Handle, outputs: BTreeMap>, inputs: BTreeMap>, @@ -52,20 +53,42 @@ impl ZMQTransport { let tokio = tokio::runtime::Handle::current(); let peer_id = PeerIdentity::new(); - let mut sock_opts = SocketOptions::default(); - sock_opts.peer_identity(peer_id); - let mut sock = zeromq::RouterSocket::with_options(sock_opts); - tokio - .block_on(sock.connect(url)) - .expect("failed to connect"); - let sock = Mutex::new(sock); + let psock = { + let peer_id = peer_id.clone(); + let mut sock_opts = SocketOptions::default(); + sock_opts.peer_identity(peer_id); + + let mut psock = zeromq::PubSocket::with_options(sock_opts); + tokio + .block_on(psock.connect(url)) + .expect("failed to connect PUB"); + Mutex::new(psock) + }; + + let ssock = { + let mut sock_opts = SocketOptions::default(); + sock_opts.peer_identity(peer_id); + + let mut ssock = zeromq::SubSocket::with_options(sock_opts); + tokio + .block_on(ssock.connect(url)) + .expect("failed to connect SUB"); + Mutex::new(ssock) + }; + + // let mut sock = zeromq::RouterSocket::with_options(sock_opts); + // tokio + // .block_on(sock.connect(url)) + // .expect("failed to connect"); + // let sock = Mutex::new(sock); let outputs = BTreeMap::default(); let inputs = BTreeMap::default(); Self { - sock, + psock, + ssock, tokio, outputs, inputs, @@ -83,11 +106,18 @@ impl Transport for ZMQTransport { } fn open_input(&self) -> PortResult { - todo!(); + let id = self.inputs.len() + 1; + InputPortID::try_from(id as isize).map_err(|e| PortError::Other(e.to_string())) } fn open_output(&self) -> PortResult { - todo!(); + let id = self.inputs.len() + 1; + let id = + OutputPortID::try_from(id as isize).map_err(|e| PortError::Other(e.to_string()))?; + self.outputs + .insert(id, RwLock::new(ZmqOutputPortState::Open)) + .unwrap(); + Ok(id) } fn close_input(&self, input: InputPortID) -> PortResult { From d3d5c6df5a1b14690a933863f92b23c827c19a3c Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 25 Nov 2024 09:49:39 +0200 Subject: [PATCH 04/63] Reset impl again --- lib/protoflow-zeromq/src/lib.rs | 50 +++------------------------------ 1 file changed, 4 insertions(+), 46 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 15d521bb..ba95ec69 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -77,12 +77,6 @@ impl ZMQTransport { Mutex::new(ssock) }; - // let mut sock = zeromq::RouterSocket::with_options(sock_opts); - // tokio - // .block_on(sock.connect(url)) - // .expect("failed to connect"); - // let sock = Mutex::new(sock); - let outputs = BTreeMap::default(); let inputs = BTreeMap::default(); @@ -106,18 +100,11 @@ impl Transport for ZMQTransport { } fn open_input(&self) -> PortResult { - let id = self.inputs.len() + 1; - InputPortID::try_from(id as isize).map_err(|e| PortError::Other(e.to_string())) + todo!(); } fn open_output(&self) -> PortResult { - let id = self.inputs.len() + 1; - let id = - OutputPortID::try_from(id as isize).map_err(|e| PortError::Other(e.to_string()))?; - self.outputs - .insert(id, RwLock::new(ZmqOutputPortState::Open)) - .unwrap(); - Ok(id) + todo!(); } fn close_input(&self, input: InputPortID) -> PortResult { @@ -133,40 +120,11 @@ impl Transport for ZMQTransport { } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { - let Some(output_state) = self.outputs.get(&output) else { - todo!(); - }; - - use ZmqOutputPortState::*; - match *output_state.read() { - Open => todo!(), - Closed => todo!(), - Connected(ref sender) => { - let msg = ZmqTransportEvent::Message(message); - Ok(sender.send(msg).unwrap()) - } - } + todo!(); } fn recv(&self, input: InputPortID) -> PortResult> { - let Some(input_state) = self.inputs.get(&input) else { - todo!(); - }; - - use ZmqInputPortState::*; - match *input_state.read() { - Open => todo!(), - Closed => todo!(), - Connected(ref receiver) => { - use ZmqTransportEvent::*; - let receiver = receiver.lock(); - match receiver.recv().map_err(|_| PortError::Disconnected)? { - Connect => todo!(), - Disconnect => todo!(), - Message(bytes) => Ok(Some(bytes)), - } - } - } + todo!(); } fn try_recv(&self, _input: InputPortID) -> PortResult> { From e7496c26abcb57a5ff5d879d476b318e455f8745 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 25 Nov 2024 09:53:59 +0200 Subject: [PATCH 05/63] Add topics --- lib/protoflow-zeromq/src/lib.rs | 37 +++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index ba95ec69..0d56dbaa 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -9,18 +9,24 @@ pub use protoflow_core::prelude; extern crate std; use protoflow_core::{ - prelude::{BTreeMap, Bytes, ToString}, + prelude::{BTreeMap, Bytes, String, ToString}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; use parking_lot::{Mutex, RwLock}; -use std::sync::mpsc::{Receiver, SyncSender}; +use std::{ + fmt::{self, Write}, + sync::mpsc::{Receiver, SyncSender}, + write, +}; use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend}; pub struct ZMQTransport { + tokio: tokio::runtime::Handle, + psock: Mutex, ssock: Mutex, - tokio: tokio::runtime::Handle, + outputs: BTreeMap>, inputs: BTreeMap>, } @@ -41,11 +47,30 @@ enum ZmqInputPortState { Closed, } +type SequenceID = u64; + #[derive(Clone, Debug)] enum ZmqTransportEvent { - Connect, - Message(Bytes), - Disconnect, + Connect(OutputPortID, InputPortID), + AckConnection(OutputPortID, InputPortID), + Message(OutputPortID, InputPortID, SequenceID, Bytes), + AckMessage(OutputPortID, InputPortID, SequenceID), + CloseOutput(OutputPortID, InputPortID), + CloseInput(InputPortID), +} + +impl ZmqTransportEvent { + fn write_topic(&self, f: &mut W) -> Result<(), fmt::Error> { + use ZmqTransportEvent::*; + match self { + Connect(o, i) => write!(f, "{}:conn:{}", i, o), + AckConnection(o, i) => write!(f, "{}:ackConn:{}", i, o), + Message(o, i, seq, _payload) => write!(f, "{}:msg:{}:{}", i, o, seq), + AckMessage(o, i, seq) => write!(f, "{}:ackMsg:{}:{}", i, o, seq), + CloseOutput(o, i) => write!(f, "{}:closeOut:{}", i, o), + CloseInput(i) => write!(f, "{}:closeIn", i), + } + } } impl ZMQTransport { From d812570442507d404d9082a1084c322f6228c28a Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 25 Nov 2024 11:15:47 +0200 Subject: [PATCH 06/63] Add port operation skeletons --- lib/protoflow-zeromq/Cargo.toml | 1 + lib/protoflow-zeromq/src/lib.rs | 181 ++++++++++++++++++++++++++++---- 2 files changed, 160 insertions(+), 22 deletions(-) diff --git a/lib/protoflow-zeromq/Cargo.toml b/lib/protoflow-zeromq/Cargo.toml index c911369c..e28a73e7 100644 --- a/lib/protoflow-zeromq/Cargo.toml +++ b/lib/protoflow-zeromq/Cargo.toml @@ -35,3 +35,4 @@ tokio = { version = "1.40.0", default-features = false } parking_lot = "0.12" [dev-dependencies] +futures-util = "0.3.31" diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 0d56dbaa..ec6f3fcd 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -9,7 +9,7 @@ pub use protoflow_core::prelude; extern crate std; use protoflow_core::{ - prelude::{BTreeMap, Bytes, String, ToString}, + prelude::{BTreeMap, Bytes, ToString}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; @@ -19,23 +19,29 @@ use std::{ sync::mpsc::{Receiver, SyncSender}, write, }; -use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend}; +use zeromq::{util::PeerIdentity, Socket, SocketOptions}; -pub struct ZMQTransport { +const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; +const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; + +pub struct ZmqTransport { tokio: tokio::runtime::Handle, psock: Mutex, ssock: Mutex, - outputs: BTreeMap>, - inputs: BTreeMap>, + outputs: RwLock>>, + inputs: RwLock>>, } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Default)] enum ZmqOutputPortState { #[default] Open, - Connected(SyncSender), + Connected( + Mutex>, + SyncSender, + ), Closed, } @@ -43,7 +49,10 @@ enum ZmqOutputPortState { enum ZmqInputPortState { #[default] Open, - Connected(Mutex>), + Connected( + Mutex>, + SyncSender, + ), Closed, } @@ -73,8 +82,14 @@ impl ZmqTransportEvent { } } -impl ZMQTransport { - pub fn new(url: &str) -> Self { +impl Default for ZmqTransport { + fn default() -> Self { + Self::new(DEFAULT_PUB_SOCKET, DEFAULT_SUB_SOCKET) + } +} + +impl ZmqTransport { + pub fn new(pub_url: &str, sub_url: &str) -> Self { let tokio = tokio::runtime::Handle::current(); let peer_id = PeerIdentity::new(); @@ -86,7 +101,7 @@ impl ZMQTransport { let mut psock = zeromq::PubSocket::with_options(sock_opts); tokio - .block_on(psock.connect(url)) + .block_on(psock.connect(pub_url)) .expect("failed to connect PUB"); Mutex::new(psock) }; @@ -97,13 +112,13 @@ impl ZMQTransport { let mut ssock = zeromq::SubSocket::with_options(sock_opts); tokio - .block_on(ssock.connect(url)) + .block_on(ssock.connect(sub_url)) .expect("failed to connect SUB"); Mutex::new(ssock) }; - let outputs = BTreeMap::default(); - let inputs = BTreeMap::default(); + let outputs = RwLock::new(BTreeMap::default()); + let inputs = RwLock::new(BTreeMap::default()); Self { psock, @@ -115,33 +130,109 @@ impl ZMQTransport { } } -impl Transport for ZMQTransport { +impl Transport for ZmqTransport { fn input_state(&self, input: InputPortID) -> PortResult { - todo!(); + use ZmqInputPortState::*; + match self.inputs.read().get(&input) { + Some(input) => match *input.read() { + Open => Ok(PortState::Open), + Connected(_, _) => Ok(PortState::Connected), + Closed => Ok(PortState::Closed), + }, + None => Err(PortError::Invalid(input.into())), + } } fn output_state(&self, output: OutputPortID) -> PortResult { - todo!(); + use ZmqOutputPortState::*; + match self.outputs.read().get(&output) { + Some(output) => match *output.read() { + Open => Ok(PortState::Open), + Connected(_, _) => Ok(PortState::Connected), + Closed => Ok(PortState::Closed), + }, + None => Err(PortError::Invalid(output.into())), + } } fn open_input(&self) -> PortResult { - todo!(); + let mut inputs = self.inputs.write(); + + let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) + .map_err(|e| PortError::Other(e.to_string()))?; + + let state = RwLock::new(ZmqInputPortState::Open); + inputs.insert(new_id, state); + + // TODO: start worker + + Ok(new_id) } fn open_output(&self) -> PortResult { - todo!(); + let mut outputs = self.outputs.write(); + + let new_id = OutputPortID::try_from(outputs.len() as isize + 1) + .map_err(|e| PortError::Other(e.to_string()))?; + + let state = RwLock::new(ZmqOutputPortState::Open); + outputs.insert(new_id, state); + + // TODO: start worker + + Ok(new_id) } fn close_input(&self, input: InputPortID) -> PortResult { - todo!(); + let inputs = self.inputs.read(); + + let Some(state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + + let mut state = state.write(); + + let ZmqInputPortState::Connected(_, _) = *state else { + return Err(PortError::Disconnected); + }; + + // TODO: send message to worker + + *state = ZmqInputPortState::Closed; + + Ok(true) } fn close_output(&self, output: OutputPortID) -> PortResult { - todo!(); + let outputs = self.outputs.read(); + + let Some(state) = outputs.get(&output) else { + return Err(PortError::Invalid(output.into())); + }; + + let mut state = state.write(); + + let ZmqOutputPortState::Connected(_, _) = *state else { + return Err(PortError::Disconnected); + }; + + // TODO: send message to worker + + *state = ZmqOutputPortState::Closed; + + Ok(true) } fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { - todo!(); + let Some(output) = self.outputs.read().get(&source) else { + return Err(PortError::Invalid(source.into())); + }; + + let Some(input) = self.inputs.read().get(&target) else { + return Err(PortError::Invalid(target.into())); + }; + + Ok(true) } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { @@ -156,3 +247,49 @@ impl Transport for ZMQTransport { todo!(); } } + +#[cfg(test)] +mod tests { + use super::*; + + use protoflow_core::System; + + use futures_util::future::TryFutureExt; + use zeromq::{PubSocket, SocketRecv, SocketSend, SubSocket}; + + fn start_zmqtransport_server(rt: &tokio::runtime::Runtime) { + // bind a *SUB* socket to the *PUB* address so that the transport can *PUB* to it + let mut pub_srv = SubSocket::new(); + rt.block_on(pub_srv.bind(DEFAULT_PUB_SOCKET)).unwrap(); + + // bind a *PUB* socket to the *SUB* address so that the transport can *SUB* to it + let mut sub_srv = PubSocket::new(); + rt.block_on(sub_srv.bind(DEFAULT_SUB_SOCKET)).unwrap(); + + // subscribe to all messages + rt.block_on(pub_srv.subscribe("")).unwrap(); + + // resend anything received on the *SUB* socket to the *PUB* socket + tokio::task::spawn(async move { + let mut pub_srv = pub_srv; + loop { + pub_srv + .recv() + .and_then(|msg| sub_srv.send(msg)) + .await + .unwrap(); + } + }); + } + + #[test] + fn implementation_matches() { + let rt = tokio::runtime::Runtime::new().unwrap(); + let _guard = rt.enter(); + + //zeromq::proxy(frontend, backend, capture) + start_zmqtransport_server(&rt); + + let _ = System::::build(|_s| { /* do nothing */ }); + } +} From 56e063f8443448383a96f7524eef221a00724797 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Tue, 26 Nov 2024 12:12:45 +0200 Subject: [PATCH 07/63] Add port worker skeletons --- lib/protoflow-zeromq/src/lib.rs | 228 +++++++++++++++++++++++++------- 1 file changed, 182 insertions(+), 46 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index ec6f3fcd..a0e72947 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -9,17 +9,18 @@ pub use protoflow_core::prelude; extern crate std; use protoflow_core::{ - prelude::{BTreeMap, Bytes, ToString}, + prelude::{BTreeMap, Bytes, ToString, Vec}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; use parking_lot::{Mutex, RwLock}; use std::{ fmt::{self, Write}, - sync::mpsc::{Receiver, SyncSender}, + format, + sync::mpsc::{sync_channel, Receiver, SyncSender}, write, }; -use zeromq::{util::PeerIdentity, Socket, SocketOptions}; +use zeromq::{util::PeerIdentity, Socket, SocketOptions, ZmqError}; const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; @@ -39,25 +40,53 @@ enum ZmqOutputPortState { #[default] Open, Connected( - Mutex>, SyncSender, + Mutex>, + InputPortID, ), Closed, } -#[derive(Debug, Default)] +impl ZmqOutputPortState { + fn state(&self) -> PortState { + use ZmqOutputPortState::*; + match self { + Open => PortState::Open, + Connected(_, _, _) => PortState::Connected, + Closed => PortState::Closed, + } + } +} + +#[derive(Debug)] enum ZmqInputPortState { - #[default] - Open, + Open( + SyncSender, + Mutex>, + ), Connected( - Mutex>, SyncSender, + Mutex>, + Vec, ), Closed, } +impl ZmqInputPortState { + fn state(&self) -> PortState { + use ZmqInputPortState::*; + match self { + Open(_, _) => PortState::Open, + Connected(_, _, _) => PortState::Connected, + Closed => PortState::Closed, + } + } +} + type SequenceID = u64; +/// ZmqTransportEvent represents the data that goes over the wire, sent from an output port over +/// the network to an input port. #[derive(Clone, Debug)] enum ZmqTransportEvent { Connect(OutputPortID, InputPortID), @@ -82,6 +111,24 @@ impl ZmqTransportEvent { } } +/// ZmqOutputPortEvent represents events that we receive from the background worker of the port. +#[derive(Clone, Debug)] +enum ZmqOutputPortEvent { + Opened, + Connected(InputPortID), + Message(Bytes), + Closed, +} + +/// ZmqInputPortEvent represents events that we receive from the background worker of the port. +#[derive(Clone, Debug)] +enum ZmqInputPortEvent { + Opened, + Connected(OutputPortID), + Message(Bytes), + Closed, +} + impl Default for ZmqTransport { fn default() -> Self { Self::new(DEFAULT_PUB_SOCKET, DEFAULT_SUB_SOCKET) @@ -128,31 +175,47 @@ impl ZmqTransport { inputs, } } + + fn subscribe_for_input_port( + &self, + input: InputPortID, + ) -> Result<(SyncSender, Receiver), ZmqError> { + // TODO: only sub to relevant events + let topic = format!("{}:", input); + self.tokio.block_on(self.ssock.lock().subscribe(&topic))?; + let (from_worker_send, from_worker_recv) = sync_channel(1); + let (to_worker_send, to_worker_recv) = sync_channel(1); + + // Input worker loop: + // 1. Receive connection attempts and respond + // 2. Receive messages and forward to channel + // 3. Receive and handle disconnects + tokio::task::spawn(async { + let (output, input) = (from_worker_send, to_worker_recv); + loop { + todo!(); + } + }); + + Ok((to_worker_send, from_worker_recv)) + } } impl Transport for ZmqTransport { fn input_state(&self, input: InputPortID) -> PortResult { - use ZmqInputPortState::*; - match self.inputs.read().get(&input) { - Some(input) => match *input.read() { - Open => Ok(PortState::Open), - Connected(_, _) => Ok(PortState::Connected), - Closed => Ok(PortState::Closed), - }, - None => Err(PortError::Invalid(input.into())), - } + self.inputs + .read() + .get(&input) + .map(|port| port.read().state()) + .ok_or(PortError::Invalid(input.into())) } fn output_state(&self, output: OutputPortID) -> PortResult { - use ZmqOutputPortState::*; - match self.outputs.read().get(&output) { - Some(output) => match *output.read() { - Open => Ok(PortState::Open), - Connected(_, _) => Ok(PortState::Connected), - Closed => Ok(PortState::Closed), - }, - None => Err(PortError::Invalid(output.into())), - } + self.outputs + .read() + .get(&output) + .map(|port| port.read().state()) + .ok_or(PortError::Invalid(output.into())) } fn open_input(&self) -> PortResult { @@ -161,10 +224,12 @@ impl Transport for ZmqTransport { let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) .map_err(|e| PortError::Other(e.to_string()))?; - let state = RwLock::new(ZmqInputPortState::Open); - inputs.insert(new_id, state); + let (sender, receiver) = self + .subscribe_for_input_port(new_id) + .map_err(|e| PortError::Other(e.to_string()))?; - // TODO: start worker + let state = RwLock::new(ZmqInputPortState::Open(sender, Mutex::new(receiver))); + inputs.insert(new_id, state); Ok(new_id) } @@ -178,8 +243,6 @@ impl Transport for ZmqTransport { let state = RwLock::new(ZmqOutputPortState::Open); outputs.insert(new_id, state); - // TODO: start worker - Ok(new_id) } @@ -190,17 +253,27 @@ impl Transport for ZmqTransport { return Err(PortError::Invalid(input.into())); }; - let mut state = state.write(); + let state = state.read(); - let ZmqInputPortState::Connected(_, _) = *state else { + let ZmqInputPortState::Connected(sender, receiver, _) = &*state else { return Err(PortError::Disconnected); }; - // TODO: send message to worker - - *state = ZmqInputPortState::Closed; + sender + .send(ZmqTransportEvent::CloseInput(input)) + .map_err(|e| PortError::Other(e.to_string()))?; - Ok(true) + loop { + let msg = receiver + .lock() + .recv() + .map_err(|e| PortError::Other(e.to_string()))?; + use ZmqInputPortEvent::*; + match msg { + Closed => break Ok(true), + _ => continue, // TODO + }; + } } fn close_output(&self, output: OutputPortID) -> PortResult { @@ -210,29 +283,92 @@ impl Transport for ZmqTransport { return Err(PortError::Invalid(output.into())); }; - let mut state = state.write(); + let state = state.write(); - let ZmqOutputPortState::Connected(_, _) = *state else { + let ZmqOutputPortState::Connected(sender, receiver, input) = &*state else { return Err(PortError::Disconnected); }; - // TODO: send message to worker - - *state = ZmqOutputPortState::Closed; + sender + .send(ZmqTransportEvent::CloseOutput(output, *input)) + .map_err(|e| PortError::Other(e.to_string()))?; - Ok(true) + loop { + let msg = receiver + .lock() + .recv() + .map_err(|e| PortError::Other(e.to_string()))?; + use ZmqOutputPortEvent::*; + match msg { + Closed => break Ok(true), + _ => continue, // TODO + } + } } fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { - let Some(output) = self.outputs.read().get(&source) else { + let outputs = self.outputs.read(); + let Some(output) = outputs.get(&source) else { return Err(PortError::Invalid(source.into())); }; - let Some(input) = self.inputs.read().get(&target) else { + let inputs = self.inputs.read(); + let Some(input) = inputs.get(&target) else { return Err(PortError::Invalid(target.into())); }; - Ok(true) + //let mut output = output.write(); + //if !output.state().is_open() { + // return Err(PortError::Invalid(source.into())); + //} + // + //let mut input = input.write(); + //if !input.state().is_open() { + // return Err(PortError::Invalid(source.into())); + //} + + // TODO: send from output, receive and respond from input + + //let (out_recv, out_send) = { + // let (from_worker_send, from_worker_recv) = sync_channel::(1); + // let (to_worker_send, to_worker_recv) = sync_channel::(1); + // + // tokio::task::spawn(async { + // let (output, input) = (from_worker_send, to_worker_recv); + // loop { + // tokio::time::sleep(Duration::from_secs(1)).await; + // } + // }); + // + // (from_worker_recv, to_worker_send) + //}; + + let (from_worker_send, from_worker_recv) = sync_channel::(1); + let (to_worker_send, to_worker_recv) = sync_channel::(1); + + // Output worker loop: + // 1. Send connection attempts + // 2. Send messages + // 2.1 Wait for ACK + // 2.2. Resend on timeout + // 3. Send disconnect events + tokio::task::spawn(async { + let (output, input) = (from_worker_send, to_worker_recv); + loop { + todo!(); + } + }); + + loop { + let msg = from_worker_recv + .recv() + .map_err(|e| PortError::Other(e.to_string()))?; + use ZmqOutputPortEvent::*; + match msg { + Connected(_) => break Ok(true), + _ => continue, // TODO + } + } } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { From ea1d54dddb7d67247d0592bb4aa955bca2cbb56b Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Tue, 26 Nov 2024 12:23:02 +0200 Subject: [PATCH 08/63] Add ZMQ socket workers for listening and sending --- lib/protoflow-zeromq/src/lib.rs | 455 ++++++++++++++++++++++++++------ 1 file changed, 372 insertions(+), 83 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index a0e72947..cb7ca70b 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -9,18 +9,20 @@ pub use protoflow_core::prelude; extern crate std; use protoflow_core::{ - prelude::{BTreeMap, Bytes, ToString, Vec}, + prelude::{Arc, BTreeMap, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; +use core::fmt::Error; use parking_lot::{Mutex, RwLock}; use std::{ - fmt::{self, Write}, format, sync::mpsc::{sync_channel, Receiver, SyncSender}, write, }; -use zeromq::{util::PeerIdentity, Socket, SocketOptions, ZmqError}; +use zeromq::{ + util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqError, ZmqMessage, +}; const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; @@ -28,20 +30,19 @@ const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; pub struct ZmqTransport { tokio: tokio::runtime::Handle, - psock: Mutex, - ssock: Mutex, + out_queue: Arc>, + sub_queue: Arc>, - outputs: RwLock>>, - inputs: RwLock>>, + outputs: Arc>>>, + inputs: Arc>>>, } -#[derive(Debug, Default)] +#[derive(Debug, Clone)] enum ZmqOutputPortState { - #[default] - Open, + Open(Arc>), Connected( - SyncSender, - Mutex>, + Arc>, + Arc>>, InputPortID, ), Closed, @@ -51,22 +52,22 @@ impl ZmqOutputPortState { fn state(&self) -> PortState { use ZmqOutputPortState::*; match self { - Open => PortState::Open, + Open(_) => PortState::Open, Connected(_, _, _) => PortState::Connected, Closed => PortState::Closed, } } } -#[derive(Debug)] +#[derive(Debug, Clone)] enum ZmqInputPortState { Open( - SyncSender, - Mutex>, + Arc>, + Arc>>, ), Connected( - SyncSender, - Mutex>, + Arc>, + Arc>>, Vec, ), Closed, @@ -98,7 +99,7 @@ enum ZmqTransportEvent { } impl ZmqTransportEvent { - fn write_topic(&self, f: &mut W) -> Result<(), fmt::Error> { + fn write_topic(&self, f: &mut W) -> Result<(), std::io::Error> { use ZmqTransportEvent::*; match self { Connect(o, i) => write!(f, "{}:conn:{}", i, o), @@ -111,12 +112,42 @@ impl ZmqTransportEvent { } } +impl From for ZmqMessage { + fn from(value: ZmqTransportEvent) -> Self { + let mut topic = Vec::new(); + value.write_topic(&mut topic).unwrap(); + + // first frame of the message is the topic + let mut msg = ZmqMessage::from(topic.clone()); + + // second frame of the message is the payload + match value { + ZmqTransportEvent::Connect(output_port_id, input_port_id) => todo!(), + ZmqTransportEvent::AckConnection(output_port_id, input_port_id) => todo!(), + ZmqTransportEvent::Message(_, _, _, bytes) => msg.push_back(bytes), + ZmqTransportEvent::AckMessage(output_port_id, input_port_id, _) => todo!(), + ZmqTransportEvent::CloseOutput(output_port_id, input_port_id) => todo!(), + ZmqTransportEvent::CloseInput(input_port_id) => todo!(), + }; + + msg + } +} + +impl TryFrom for ZmqTransportEvent { + type Error = protoflow_core::DecodeError; + + fn try_from(value: ZmqMessage) -> Result { + todo!() + } +} + /// ZmqOutputPortEvent represents events that we receive from the background worker of the port. #[derive(Clone, Debug)] enum ZmqOutputPortEvent { Opened, Connected(InputPortID), - Message(Bytes), + Ack(SequenceID), Closed, } @@ -135,6 +166,12 @@ impl Default for ZmqTransport { } } +#[derive(Clone)] +enum ZmqSubscriptionRequest { + Subscribe(String), + Unsubscribe(String), +} + impl ZmqTransport { pub fn new(pub_url: &str, sub_url: &str) -> Self { let tokio = tokio::runtime::Handle::current(); @@ -150,7 +187,7 @@ impl ZmqTransport { tokio .block_on(psock.connect(pub_url)) .expect("failed to connect PUB"); - Mutex::new(psock) + psock }; let ssock = { @@ -161,41 +198,145 @@ impl ZmqTransport { tokio .block_on(ssock.connect(sub_url)) .expect("failed to connect SUB"); - Mutex::new(ssock) + ssock }; - let outputs = RwLock::new(BTreeMap::default()); - let inputs = RwLock::new(BTreeMap::default()); + let outputs = Arc::new(RwLock::new(BTreeMap::default())); + let inputs = Arc::new(RwLock::new(BTreeMap::default())); + + let (out_queue, out_queue_recv) = sync_channel(1); + + let out_queue = Arc::new(out_queue); - Self { - psock, - ssock, + let (sub_queue, sub_queue_recv) = tokio::sync::mpsc::channel(1); + let sub_queue = Arc::new(sub_queue); + + let transport = Self { + out_queue, + sub_queue, tokio, outputs, inputs, - } + }; + + transport.start_send_worker(psock, out_queue_recv); + transport.start_recv_worker(ssock, sub_queue_recv); + + transport + } + + fn start_send_worker(&self, psock: zeromq::PubSocket, queue: Receiver) { + let tokio = self.tokio.clone(); + let mut psock = psock; + + tokio::task::spawn(async move { + loop { + let Ok(event) = queue.recv() else { + continue; + }; + + let msg = ZmqMessage::from(event); + + tokio.block_on(psock.send(msg)).expect("zmq send worker") + } + }); + } + + fn start_recv_worker( + &self, + ssock: zeromq::SubSocket, + queue: tokio::sync::mpsc::Receiver, + ) { + let mut ssock = ssock; + let mut queue = queue; + + let outputs = self.outputs.clone(); + let inputs = self.inputs.clone(); + + tokio::task::spawn(async move { + loop { + tokio::select! { + Ok(msg) = ssock.recv() => handle_zmq_msg(msg, &outputs, &inputs).unwrap(), + Some(req) = queue.recv() => { + use ZmqSubscriptionRequest::*; + match req { + Subscribe(topic) => ssock.subscribe(&topic).await.expect("zmq recv worker subscribe"), + Unsubscribe(topic) => ssock.unsubscribe(&topic).await.expect("zmq recv worker subscribe"), + }; + } + }; + } + }); } fn subscribe_for_input_port( &self, input: InputPortID, - ) -> Result<(SyncSender, Receiver), ZmqError> { + ) -> Result< + ( + Arc>, + Arc>>, + ), + ZmqError, + > { // TODO: only sub to relevant events let topic = format!("{}:", input); - self.tokio.block_on(self.ssock.lock().subscribe(&topic))?; + self.tokio + .block_on( + self.sub_queue + .send(ZmqSubscriptionRequest::Subscribe(topic)), + ) + .unwrap(); + let (from_worker_send, from_worker_recv) = sync_channel(1); let (to_worker_send, to_worker_recv) = sync_channel(1); + let to_worker_send = Arc::new(to_worker_send); + let from_worker_recv = Arc::new(Mutex::new(from_worker_recv)); + // Input worker loop: // 1. Receive connection attempts and respond // 2. Receive messages and forward to channel // 3. Receive and handle disconnects - tokio::task::spawn(async { - let (output, input) = (from_worker_send, to_worker_recv); - loop { - todo!(); - } - }); + { + let inputs = self.inputs.clone(); + + let to_worker_send = to_worker_send.clone(); + let from_worker_recv = from_worker_recv.clone(); + + tokio::task::spawn(async move { + let (output, input) = (from_worker_send, to_worker_recv); + + loop { + use ZmqTransportEvent::*; + let event = input.recv().expect("input worker recv"); + match event { + // Connection attempt + Connect(output_port_id, input_port_id) => { + let inputs = inputs.read(); + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + + let mut input = input.write(); + *input = ZmqInputPortState::Connected( + to_worker_send.clone(), + from_worker_recv.clone(), + Vec::new(), + ); + } + + // Message from output port + Message(output_port_id, input_port_id, _, bytes) => todo!(), + // Output port reports being closed + CloseInput(input_port_id) => todo!(), + + // ignore output port type events: + AckConnection(_, _) | AckMessage(_, _, _) | CloseOutput(_, _) => continue, + }; + } + }); + } Ok((to_worker_send, from_worker_recv)) } @@ -219,19 +360,25 @@ impl Transport for ZmqTransport { } fn open_input(&self) -> PortResult { - let mut inputs = self.inputs.write(); + let inputs = self.inputs.write(); let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) .map_err(|e| PortError::Other(e.to_string()))?; - let (sender, receiver) = self + let (_, receiver) = self .subscribe_for_input_port(new_id) .map_err(|e| PortError::Other(e.to_string()))?; - let state = RwLock::new(ZmqInputPortState::Open(sender, Mutex::new(receiver))); - inputs.insert(new_id, state); - - Ok(new_id) + loop { + let msg = receiver + .lock() + .recv() + .map_err(|e| PortError::Other(e.to_string()))?; + match msg { + ZmqInputPortEvent::Opened => break Ok(new_id), + _ => continue, // TODO + } + } } fn open_output(&self) -> PortResult { @@ -240,7 +387,10 @@ impl Transport for ZmqTransport { let new_id = OutputPortID::try_from(outputs.len() as isize + 1) .map_err(|e| PortError::Other(e.to_string()))?; - let state = RwLock::new(ZmqOutputPortState::Open); + let (sender, _receiver) = sync_channel(1); + let sender = Arc::new(sender); + + let state = RwLock::new(ZmqOutputPortState::Open(sender)); outputs.insert(new_id, state); Ok(new_id) @@ -308,59 +458,90 @@ impl Transport for ZmqTransport { fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { let outputs = self.outputs.read(); - let Some(output) = outputs.get(&source) else { + if outputs + .get(&source) + .is_some_and(|state| !state.read().state().is_open()) + { return Err(PortError::Invalid(source.into())); - }; - - let inputs = self.inputs.read(); - let Some(input) = inputs.get(&target) else { - return Err(PortError::Invalid(target.into())); - }; - - //let mut output = output.write(); - //if !output.state().is_open() { - // return Err(PortError::Invalid(source.into())); - //} - // - //let mut input = input.write(); - //if !input.state().is_open() { - // return Err(PortError::Invalid(source.into())); - //} - - // TODO: send from output, receive and respond from input - - //let (out_recv, out_send) = { - // let (from_worker_send, from_worker_recv) = sync_channel::(1); - // let (to_worker_send, to_worker_recv) = sync_channel::(1); - // - // tokio::task::spawn(async { - // let (output, input) = (from_worker_send, to_worker_recv); - // loop { - // tokio::time::sleep(Duration::from_secs(1)).await; - // } - // }); - // - // (from_worker_recv, to_worker_send) - //}; + } let (from_worker_send, from_worker_recv) = sync_channel::(1); let (to_worker_send, to_worker_recv) = sync_channel::(1); + let to_worker_send = Arc::new(to_worker_send); + let from_worker_recv = Arc::new(Mutex::new(from_worker_recv)); + // Output worker loop: - // 1. Send connection attempts + // 1. Send connection attempt // 2. Send messages // 2.1 Wait for ACK // 2.2. Resend on timeout // 3. Send disconnect events - tokio::task::spawn(async { - let (output, input) = (from_worker_send, to_worker_recv); - loop { - todo!(); - } - }); + { + //let output_state = output_state.clone(); + let to_worker_send = to_worker_send.clone(); + let from_worker_recv = from_worker_recv.clone(); + + let outputs = self.outputs.clone(); + + tokio::task::spawn(async move { + let (output, input) = (from_worker_send, to_worker_recv); + + // connect loop + loop { + let request = input.recv().expect("output worker recv"); + match request { + ZmqTransportEvent::AckConnection(_, input_port_id) => { + let outputs = outputs.read(); + let Some(output_state) = outputs.get(&source) else { + todo!(); + }; + let mut output_state = output_state.write(); + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(_))); + *output_state = ZmqOutputPortState::Connected( + to_worker_send, + from_worker_recv, + input_port_id, + ); + break; + } + _ => continue, // TODO: when and why would we receive other events? + } + } + + // work loop + loop { + use ZmqTransportEvent::*; + let event = input.recv().expect("output worker recv"); + match event { + AckMessage(output_port_id, input_port_id, seq_id) => { + output + .send(ZmqOutputPortEvent::Ack(seq_id)) + .expect("worker loop ack send"); + } + + CloseInput(input_port_id) => todo!(), + + AckConnection(_, _) => { + unreachable!("already connected") + } + + // ignore input port type events + Connect(_, _) | CloseOutput(_, _) | Message(_, _, _, _) => continue, // TODO + } + } + }); + } + + // send request to connect + self.out_queue + .send(ZmqTransportEvent::Connect(source, target)) + .unwrap(); + // wait for the `Connected` event loop { let msg = from_worker_recv + .lock() .recv() .map_err(|e| PortError::Other(e.to_string()))?; use ZmqOutputPortEvent::*; @@ -384,6 +565,114 @@ impl Transport for ZmqTransport { } } +fn handle_zmq_msg( + msg: ZmqMessage, + outputs: &RwLock>>, + inputs: &RwLock>>, +) -> Result<(), Error> { + let Ok(event) = ZmqTransportEvent::try_from(msg) else { + todo!(); + }; + + use ZmqTransportEvent::*; + match event { + // input ports + Connect(_, input_port_id) => { + let inputs = inputs.read(); + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + let input = input.read(); + + use ZmqInputPortState::*; + match &*input { + Closed => todo!(), + Open(sender, _) | Connected(sender, _, _) => { + sender.send(event).unwrap(); + } + }; + } + Message(output_port_id, input_port_id, _, _) => { + let inputs = inputs.read(); + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + + let input = input.read(); + let ZmqInputPortState::Connected(sender, _, ids) = &*input else { + todo!(); + }; + + // TODO: probably move to ports worker? no sense having here + if !ids.iter().any(|&id| id == output_port_id) { + todo!(); + } + + sender.send(event).unwrap(); + } + CloseOutput(_, input_port_id) => { + let inputs = inputs.read(); + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + let input = input.read(); + + use ZmqInputPortState::*; + match &*input { + Closed => todo!(), + Open(sender, _) | Connected(sender, _, _) => { + sender.send(event).unwrap(); + } + }; + } + + // output ports + AckConnection(output_port_id, _) => { + let outputs = outputs.read(); + let Some(output) = outputs.get(&output_port_id) else { + todo!(); + }; + let output = output.read(); + + let ZmqOutputPortState::Open(sender) = &*output else { + todo!(); + }; + sender.send(event).unwrap(); + } + AckMessage(output_port_id, _, _) => { + let outputs = outputs.read(); + let Some(output) = outputs.get(&output_port_id) else { + todo!(); + }; + let output = output.read(); + + let ZmqOutputPortState::Connected(sender, _, _) = &*output else { + todo!(); + }; + sender.send(event).unwrap(); + } + CloseInput(input_port_id) => { + let outputs = outputs.read(); + + for (_, state) in outputs.iter() { + let state = state.read(); + + let ZmqOutputPortState::Connected(sender, _, id) = &*state else { + todo!(); + }; + + if *id != input_port_id { + todo!(); + } + + sender.send(event.clone()).unwrap(); + } + } + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; From 00eef9bafe04ee93bbb3226007d0f2a3893f79b8 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 08:36:26 +0200 Subject: [PATCH 09/63] Add draft implementations for the public `send` and `recv` --- lib/protoflow-zeromq/src/lib.rs | 61 ++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index cb7ca70b..ff33ff7b 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -478,10 +478,10 @@ impl Transport for ZmqTransport { // 2.2. Resend on timeout // 3. Send disconnect events { - //let output_state = output_state.clone(); let to_worker_send = to_worker_send.clone(); let from_worker_recv = from_worker_recv.clone(); + let out_queue = self.out_queue.clone(); let outputs = self.outputs.clone(); tokio::task::spawn(async move { @@ -489,6 +489,11 @@ impl Transport for ZmqTransport { // connect loop loop { + // send request to connect + out_queue + .send(ZmqTransportEvent::Connect(source, target)) + .unwrap(); + let request = input.recv().expect("output worker recv"); match request { ZmqTransportEvent::AckConnection(_, input_port_id) => { @@ -509,10 +514,13 @@ impl Transport for ZmqTransport { } } - // work loop + // work loop for handling events loop { use ZmqTransportEvent::*; let event = input.recv().expect("output worker recv"); + if !matches!(event, Message(_, _, _, _)) { + unreachable!("why are we getting non-Message?"); + } match event { AckMessage(output_port_id, input_port_id, seq_id) => { output @@ -533,11 +541,6 @@ impl Transport for ZmqTransport { }); } - // send request to connect - self.out_queue - .send(ZmqTransportEvent::Connect(source, target)) - .unwrap(); - // wait for the `Connected` event loop { let msg = from_worker_recv @@ -553,11 +556,51 @@ impl Transport for ZmqTransport { } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { - todo!(); + let outputs = self.outputs.read(); + let Some(output) = outputs.get(&output) else { + return Err(PortError::Invalid(output.into())); + }; + let output = output.read(); + + let ZmqOutputPortState::Connected(sender, receiver, input_port_id) = &*output else { + return Err(PortError::Disconnected); + }; + + sender.send(message).unwrap(); + + loop { + let msg = receiver.lock().recv().unwrap(); + + use ZmqOutputPortEvent::*; + match msg { + Ack(_seq_id) => break Ok(()), + _ => continue, // TODO + } + } } fn recv(&self, input: InputPortID) -> PortResult> { - todo!(); + let inputs = self.inputs.read(); + let Some(input) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + let input = input.read(); + + let ZmqInputPortState::Connected(_, receiver, _) = &*input else { + return Err(PortError::Disconnected); + }; + + loop { + use ZmqInputPortEvent::*; + match receiver.lock().recv() { + // ignore + Ok(Opened) | Ok(Connected(_)) => continue, + + Ok(Closed) => break Ok(None), // EOS + Ok(Message(bytes)) => break Ok(Some(bytes)), + Err(e) => break Err(PortError::Other(e.to_string())), + } + } } fn try_recv(&self, _input: InputPortID) -> PortResult> { From a65edd1c83069d81e74589e5cf6531971c7a0be9 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 12:27:55 +0200 Subject: [PATCH 10/63] Add separate ports for requests from public API --- lib/protoflow-zeromq/src/lib.rs | 143 ++++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 42 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index ff33ff7b..3603230d 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -39,10 +39,14 @@ pub struct ZmqTransport { #[derive(Debug, Clone)] enum ZmqOutputPortState { - Open(Arc>), + Open(Arc>)>>), Connected( + // channels for public send, contained channel is for the ack back + Arc)>>, + // internal channels for events Arc>, Arc>>, + // id of the connected input port InputPortID, ), Closed, @@ -53,7 +57,7 @@ impl ZmqOutputPortState { use ZmqOutputPortState::*; match self { Open(_) => PortState::Open, - Connected(_, _, _) => PortState::Connected, + Connected(_, _, _, _) => PortState::Connected, Closed => PortState::Closed, } } @@ -66,8 +70,13 @@ enum ZmqInputPortState { Arc>>, ), Connected( + // channels for the public recv + Arc>, + Arc>>, + // internal channels for events Arc>, Arc>>, + // vec of the connected port ids Vec, ), Closed, @@ -78,7 +87,7 @@ impl ZmqInputPortState { use ZmqInputPortState::*; match self { Open(_, _) => PortState::Open, - Connected(_, _, _) => PortState::Connected, + Connected(_, _, _, _, _) => PortState::Connected, Closed => PortState::Closed, } } @@ -318,8 +327,15 @@ impl ZmqTransport { todo!(); }; + let (msgs_send, msgs_recv) = sync_channel(1); + + let msgs_send = Arc::new(msgs_send); + let msgs_recv = Arc::new(Mutex::new(msgs_recv)); + let mut input = input.write(); *input = ZmqInputPortState::Connected( + msgs_send, + msgs_recv, to_worker_send.clone(), from_worker_recv.clone(), Vec::new(), @@ -327,7 +343,23 @@ impl ZmqTransport { } // Message from output port - Message(output_port_id, input_port_id, _, bytes) => todo!(), + Message(output_port_id, input_port_id, _, bytes) => { + let inputs = inputs.read(); + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + + let input = input.read(); + use ZmqInputPortState::*; + match &*input { + Open(arc, arc1) => todo!(), + Closed => todo!(), + Connected(sender, _, _, _, _) => { + sender.send(ZmqInputPortEvent::Message(bytes)).unwrap() + } + }; + } + // Output port reports being closed CloseInput(input_port_id) => todo!(), @@ -405,7 +437,7 @@ impl Transport for ZmqTransport { let state = state.read(); - let ZmqInputPortState::Connected(sender, receiver, _) = &*state else { + let ZmqInputPortState::Connected(_, _, sender, receiver, _) = &*state else { return Err(PortError::Disconnected); }; @@ -435,7 +467,7 @@ impl Transport for ZmqTransport { let state = state.write(); - let ZmqOutputPortState::Connected(sender, receiver, input) = &*state else { + let ZmqOutputPortState::Connected(_, sender, receiver, input) = &*state else { return Err(PortError::Disconnected); }; @@ -471,6 +503,10 @@ impl Transport for ZmqTransport { let to_worker_send = Arc::new(to_worker_send); let from_worker_recv = Arc::new(Mutex::new(from_worker_recv)); + let (msg_req_send, msg_req_recv) = sync_channel(1); + let msg_req_send = Arc::new(msg_req_send); + let msg_req_recv = Arc::new(Mutex::new(msg_req_recv)); + // Output worker loop: // 1. Send connection attempt // 2. Send messages @@ -504,6 +540,7 @@ impl Transport for ZmqTransport { let mut output_state = output_state.write(); debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(_))); *output_state = ZmqOutputPortState::Connected( + msg_req_send, to_worker_send, from_worker_recv, input_port_id, @@ -514,30 +551,57 @@ impl Transport for ZmqTransport { } } - // work loop for handling events - loop { - use ZmqTransportEvent::*; - let event = input.recv().expect("output worker recv"); - if !matches!(event, Message(_, _, _, _)) { - unreachable!("why are we getting non-Message?"); + // work loop for sending events + tokio::task::spawn(async move { + let mut seq_id = 1; + loop { + let req = msg_req_recv.lock().recv().expect("output worker req recv"); + + let outputs = outputs.read(); + let Some(output_state) = outputs.get(&source) else { + todo!(); + }; + + let ZmqOutputPortState::Connected(ack_send, sender, _, _) = + &*output_state.read() + else { + todo!(); + }; + + sender + .send(ZmqTransportEvent::Message(source, target, seq_id, req.0)) + .unwrap(); + + seq_id += 1; } - match event { - AckMessage(output_port_id, input_port_id, seq_id) => { - output - .send(ZmqOutputPortEvent::Ack(seq_id)) - .expect("worker loop ack send"); + }); + + // work loop for handling events + tokio::task::spawn(async move { + loop { + use ZmqTransportEvent::*; + let event = input.recv().expect("output worker event recv"); + if !matches!(event, Message(_, _, _, _)) { + unreachable!("why are we getting non-Message?"); } + match event { + AckMessage(output_port_id, input_port_id, seq_id) => { + output + .send(ZmqOutputPortEvent::Ack(seq_id)) + .expect("worker loop ack send"); + } - CloseInput(input_port_id) => todo!(), + CloseInput(input_port_id) => todo!(), - AckConnection(_, _) => { - unreachable!("already connected") - } + AckConnection(_, _) => { + unreachable!("already connected") + } - // ignore input port type events - Connect(_, _) | CloseOutput(_, _) | Message(_, _, _, _) => continue, // TODO + // ignore input port type events + Connect(_, _) | CloseOutput(_, _) | Message(_, _, _, _) => continue, // TODO + } } - } + }); }); } @@ -562,21 +626,16 @@ impl Transport for ZmqTransport { }; let output = output.read(); - let ZmqOutputPortState::Connected(sender, receiver, input_port_id) = &*output else { + let ZmqOutputPortState::Connected(sender, _, _, _) = &*output else { return Err(PortError::Disconnected); }; - sender.send(message).unwrap(); + let (ack_send, ack_recv) = sync_channel(1); - loop { - let msg = receiver.lock().recv().unwrap(); + sender.send((message, ack_send)).unwrap(); - use ZmqOutputPortEvent::*; - match msg { - Ack(_seq_id) => break Ok(()), - _ => continue, // TODO - } - } + ack_recv.recv().unwrap(); + Ok(()) } fn recv(&self, input: InputPortID) -> PortResult> { @@ -586,7 +645,7 @@ impl Transport for ZmqTransport { }; let input = input.read(); - let ZmqInputPortState::Connected(_, receiver, _) = &*input else { + let ZmqInputPortState::Connected(_, receiver, _, _, _) = &*input else { return Err(PortError::Disconnected); }; @@ -630,19 +689,19 @@ fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => todo!(), - Open(sender, _) | Connected(sender, _, _) => { + Open(sender, _) | Connected(_, _, sender, _, _) => { sender.send(event).unwrap(); } }; } - Message(output_port_id, input_port_id, _, _) => { + Message(output_port_id, input_port_id, _, bytes) => { let inputs = inputs.read(); let Some(input) = inputs.get(&input_port_id) else { todo!(); }; let input = input.read(); - let ZmqInputPortState::Connected(sender, _, ids) = &*input else { + let ZmqInputPortState::Connected(sender, _, _, _, ids) = &*input else { todo!(); }; @@ -651,7 +710,7 @@ fn handle_zmq_msg( todo!(); } - sender.send(event).unwrap(); + sender.send(ZmqInputPortEvent::Message(bytes)).unwrap(); } CloseOutput(_, input_port_id) => { let inputs = inputs.read(); @@ -663,7 +722,7 @@ fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => todo!(), - Open(sender, _) | Connected(sender, _, _) => { + Open(sender, _) | Connected(_, _, sender, _, _) => { sender.send(event).unwrap(); } }; @@ -689,7 +748,7 @@ fn handle_zmq_msg( }; let output = output.read(); - let ZmqOutputPortState::Connected(sender, _, _) = &*output else { + let ZmqOutputPortState::Connected(_, sender, _, _) = &*output else { todo!(); }; sender.send(event).unwrap(); @@ -700,7 +759,7 @@ fn handle_zmq_msg( for (_, state) in outputs.iter() { let state = state.read(); - let ZmqOutputPortState::Connected(sender, _, id) = &*state else { + let ZmqOutputPortState::Connected(_, sender, _, id) = &*state else { todo!(); }; From 264414cb9e5ace6ddd8fbbfd134390b521ed2b9b Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 13:26:35 +0200 Subject: [PATCH 11/63] Separate port channels further --- lib/protoflow-zeromq/src/lib.rs | 147 ++++++++++++++++++++------------ 1 file changed, 93 insertions(+), 54 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 3603230d..10e80f46 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -30,7 +30,7 @@ const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; pub struct ZmqTransport { tokio: tokio::runtime::Handle, - out_queue: Arc>, + pub_queue: Arc>, sub_queue: Arc>, outputs: Arc>>>, @@ -42,7 +42,7 @@ enum ZmqOutputPortState { Open(Arc>)>>), Connected( // channels for public send, contained channel is for the ack back - Arc)>>, + Arc>)>>, // internal channels for events Arc>, Arc>>, @@ -52,6 +52,12 @@ enum ZmqOutputPortState { Closed, } +#[derive(Debug, Clone)] +enum ZmqOutputPortRequest { + Close, + Send(Bytes), +} + impl ZmqOutputPortState { fn state(&self) -> PortState { use ZmqOutputPortState::*; @@ -66,10 +72,14 @@ impl ZmqOutputPortState { #[derive(Debug, Clone)] enum ZmqInputPortState { Open( + // TODO: hide these Arc>, Arc>>, ), Connected( + // channels for requests from public close + Arc>)>>, + Arc>)>>>, // channels for the public recv Arc>, Arc>>, @@ -82,12 +92,17 @@ enum ZmqInputPortState { Closed, } +#[derive(Debug, Clone)] +enum ZmqInputPortRequest { + Close, +} + impl ZmqInputPortState { fn state(&self) -> PortState { use ZmqInputPortState::*; match self { Open(_, _) => PortState::Open, - Connected(_, _, _, _, _) => PortState::Connected, + Connected(_, _, _, _, _, _, _) => PortState::Connected, Closed => PortState::Closed, } } @@ -130,13 +145,14 @@ impl From for ZmqMessage { let mut msg = ZmqMessage::from(topic.clone()); // second frame of the message is the payload + use ZmqTransportEvent::*; match value { - ZmqTransportEvent::Connect(output_port_id, input_port_id) => todo!(), - ZmqTransportEvent::AckConnection(output_port_id, input_port_id) => todo!(), - ZmqTransportEvent::Message(_, _, _, bytes) => msg.push_back(bytes), - ZmqTransportEvent::AckMessage(output_port_id, input_port_id, _) => todo!(), - ZmqTransportEvent::CloseOutput(output_port_id, input_port_id) => todo!(), - ZmqTransportEvent::CloseInput(input_port_id) => todo!(), + Connect(output_port_id, input_port_id) => todo!(), + AckConnection(output_port_id, input_port_id) => todo!(), + Message(_, _, _, bytes) => msg.push_back(bytes), + AckMessage(output_port_id, input_port_id, _) => todo!(), + CloseOutput(output_port_id, input_port_id) => todo!(), + CloseInput(input_port_id) => todo!(), }; msg @@ -221,7 +237,7 @@ impl ZmqTransport { let sub_queue = Arc::new(sub_queue); let transport = Self { - out_queue, + pub_queue: out_queue, sub_queue, tokio, outputs, @@ -239,25 +255,21 @@ impl ZmqTransport { let mut psock = psock; tokio::task::spawn(async move { - loop { - let Ok(event) = queue.recv() else { - continue; - }; - - let msg = ZmqMessage::from(event); - - tokio.block_on(psock.send(msg)).expect("zmq send worker") - } + queue + .into_iter() + .map(ZmqMessage::from) + .try_for_each(|msg| tokio.block_on(psock.send(msg))) + .expect("zmq send worker") }); } fn start_recv_worker( &self, ssock: zeromq::SubSocket, - queue: tokio::sync::mpsc::Receiver, + sub_queue: tokio::sync::mpsc::Receiver, ) { let mut ssock = ssock; - let mut queue = queue; + let mut sub_queue = sub_queue; let outputs = self.outputs.clone(); let inputs = self.inputs.clone(); @@ -266,7 +278,7 @@ impl ZmqTransport { loop { tokio::select! { Ok(msg) = ssock.recv() => handle_zmq_msg(msg, &outputs, &inputs).unwrap(), - Some(req) = queue.recv() => { + Some(req) = sub_queue.recv() => { use ZmqSubscriptionRequest::*; match req { Subscribe(topic) => ssock.subscribe(&topic).await.expect("zmq recv worker subscribe"), @@ -321,12 +333,16 @@ impl ZmqTransport { let event = input.recv().expect("input worker recv"); match event { // Connection attempt - Connect(output_port_id, input_port_id) => { + Connect(_, input_port_id) => { let inputs = inputs.read(); let Some(input) = inputs.get(&input_port_id) else { todo!(); }; + let (req_send, req_recv) = sync_channel(1); + let req_send = Arc::new(req_send); + let req_recv = Arc::new(Mutex::new(req_recv)); + let (msgs_send, msgs_recv) = sync_channel(1); let msgs_send = Arc::new(msgs_send); @@ -334,6 +350,8 @@ impl ZmqTransport { let mut input = input.write(); *input = ZmqInputPortState::Connected( + req_send, + req_recv, msgs_send, msgs_recv, to_worker_send.clone(), @@ -343,7 +361,7 @@ impl ZmqTransport { } // Message from output port - Message(output_port_id, input_port_id, _, bytes) => { + Message(_, input_port_id, _, bytes) => { let inputs = inputs.read(); let Some(input) = inputs.get(&input_port_id) else { todo!(); @@ -352,16 +370,31 @@ impl ZmqTransport { let input = input.read(); use ZmqInputPortState::*; match &*input { - Open(arc, arc1) => todo!(), + Open(_, _) => todo!(), Closed => todo!(), - Connected(sender, _, _, _, _) => { + Connected(_, _, sender, _, _, _, _) => { sender.send(ZmqInputPortEvent::Message(bytes)).unwrap() } }; } // Output port reports being closed - CloseInput(input_port_id) => todo!(), + CloseInput(input_port_id) => { + let inputs = inputs.read(); + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + + let input = input.read(); + use ZmqInputPortState::*; + match &*input { + Open(_, _) => todo!(), + Closed => todo!(), + Connected(_, _, sender, _, _, _, _) => { + sender.send(ZmqInputPortEvent::Closed).unwrap() + } + }; + } // ignore output port type events: AckConnection(_, _) | AckMessage(_, _, _) | CloseOutput(_, _) => continue, @@ -437,25 +470,20 @@ impl Transport for ZmqTransport { let state = state.read(); - let ZmqInputPortState::Connected(_, _, sender, receiver, _) = &*state else { + let ZmqInputPortState::Connected(sender, _, _, _, _, _, _) = &*state else { return Err(PortError::Disconnected); }; + let (close_send, close_recv) = sync_channel(1); + sender - .send(ZmqTransportEvent::CloseInput(input)) + .send((ZmqInputPortRequest::Close, close_send)) .map_err(|e| PortError::Other(e.to_string()))?; - loop { - let msg = receiver - .lock() - .recv() - .map_err(|e| PortError::Other(e.to_string()))?; - use ZmqInputPortEvent::*; - match msg { - Closed => break Ok(true), - _ => continue, // TODO - }; - } + close_recv + .recv() + .map_err(|_| PortError::Disconnected)? + .map(|_| true) } fn close_output(&self, output: OutputPortID) -> PortResult { @@ -517,7 +545,7 @@ impl Transport for ZmqTransport { let to_worker_send = to_worker_send.clone(); let from_worker_recv = from_worker_recv.clone(); - let out_queue = self.out_queue.clone(); + let pub_queue = self.pub_queue.clone(); let outputs = self.outputs.clone(); tokio::task::spawn(async move { @@ -526,7 +554,7 @@ impl Transport for ZmqTransport { // connect loop loop { // send request to connect - out_queue + pub_queue .send(ZmqTransportEvent::Connect(source, target)) .unwrap(); @@ -551,6 +579,8 @@ impl Transport for ZmqTransport { } } + // TODO: combine these two spawns by using tokio's channels and `select!` + // work loop for sending events tokio::task::spawn(async move { let mut seq_id = 1; @@ -562,17 +592,25 @@ impl Transport for ZmqTransport { todo!(); }; - let ZmqOutputPortState::Connected(ack_send, sender, _, _) = + let ZmqOutputPortState::Connected(ack_send, sender, _, output_id) = &*output_state.read() else { todo!(); }; - sender - .send(ZmqTransportEvent::Message(source, target, seq_id, req.0)) - .unwrap(); + let resp = req.1; // TODO: respond - seq_id += 1; + match req.0 { + ZmqOutputPortRequest::Send(bytes) => { + sender + .send(ZmqTransportEvent::Message(source, target, seq_id, bytes)) + .unwrap(); + seq_id += 1; + } + ZmqOutputPortRequest::Close => sender + .send(ZmqTransportEvent::CloseOutput(source, *output_id)) + .unwrap(), + }; } }); @@ -632,10 +670,11 @@ impl Transport for ZmqTransport { let (ack_send, ack_recv) = sync_channel(1); - sender.send((message, ack_send)).unwrap(); + sender + .send((ZmqOutputPortRequest::Send(message), ack_send)) + .unwrap(); - ack_recv.recv().unwrap(); - Ok(()) + ack_recv.recv().map_err(|_| PortError::Disconnected)? } fn recv(&self, input: InputPortID) -> PortResult> { @@ -645,7 +684,7 @@ impl Transport for ZmqTransport { }; let input = input.read(); - let ZmqInputPortState::Connected(_, receiver, _, _, _) = &*input else { + let ZmqInputPortState::Connected(_, _, _, receiver, _, _, _) = &*input else { return Err(PortError::Disconnected); }; @@ -689,7 +728,7 @@ fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => todo!(), - Open(sender, _) | Connected(_, _, sender, _, _) => { + Open(sender, _) | Connected(_, _, _, _, sender, _, _) => { sender.send(event).unwrap(); } }; @@ -701,7 +740,7 @@ fn handle_zmq_msg( }; let input = input.read(); - let ZmqInputPortState::Connected(sender, _, _, _, ids) = &*input else { + let ZmqInputPortState::Connected(_, _, sender, _, _, _, ids) = &*input else { todo!(); }; @@ -722,7 +761,7 @@ fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => todo!(), - Open(sender, _) | Connected(_, _, sender, _, _) => { + Open(sender, _) | Connected(_, _, _, _, sender, _, _) => { sender.send(event).unwrap(); } }; From 03145abf788d408acbf2139b629f4b7559ddef42 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 13:33:48 +0200 Subject: [PATCH 12/63] Send event when output connects --- lib/protoflow-zeromq/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 10e80f46..220015e9 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -573,6 +573,9 @@ impl Transport for ZmqTransport { from_worker_recv, input_port_id, ); + output + .send(ZmqOutputPortEvent::Opened) + .expect("output worker connected send"); break; } _ => continue, // TODO: when and why would we receive other events? From 33431fd64e5aa55018a786b657dcff528328f33d Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 13:34:41 +0200 Subject: [PATCH 13/63] Silence unused-var warnings --- lib/protoflow-zeromq/src/lib.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 220015e9..52834dbe 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -147,12 +147,12 @@ impl From for ZmqMessage { // second frame of the message is the payload use ZmqTransportEvent::*; match value { - Connect(output_port_id, input_port_id) => todo!(), - AckConnection(output_port_id, input_port_id) => todo!(), + Connect(_, _) => todo!(), + AckConnection(_, _) => todo!(), Message(_, _, _, bytes) => msg.push_back(bytes), - AckMessage(output_port_id, input_port_id, _) => todo!(), - CloseOutput(output_port_id, input_port_id) => todo!(), - CloseInput(input_port_id) => todo!(), + AckMessage(_, _, _) => todo!(), + CloseOutput(_, _) => todo!(), + CloseInput(_) => todo!(), }; msg @@ -162,7 +162,7 @@ impl From for ZmqMessage { impl TryFrom for ZmqTransportEvent { type Error = protoflow_core::DecodeError; - fn try_from(value: ZmqMessage) -> Result { + fn try_from(_value: ZmqMessage) -> Result { todo!() } } @@ -326,7 +326,7 @@ impl ZmqTransport { let from_worker_recv = from_worker_recv.clone(); tokio::task::spawn(async move { - let (output, input) = (from_worker_send, to_worker_recv); + let (_output, input) = (from_worker_send, to_worker_recv); loop { use ZmqTransportEvent::*; @@ -595,13 +595,13 @@ impl Transport for ZmqTransport { todo!(); }; - let ZmqOutputPortState::Connected(ack_send, sender, _, output_id) = + let ZmqOutputPortState::Connected(_, sender, _, output_id) = &*output_state.read() else { todo!(); }; - let resp = req.1; // TODO: respond + let _resp = req.1; // TODO: respond match req.0 { ZmqOutputPortRequest::Send(bytes) => { @@ -626,13 +626,13 @@ impl Transport for ZmqTransport { unreachable!("why are we getting non-Message?"); } match event { - AckMessage(output_port_id, input_port_id, seq_id) => { + AckMessage(_, _, seq_id) => { output .send(ZmqOutputPortEvent::Ack(seq_id)) .expect("worker loop ack send"); } - CloseInput(input_port_id) => todo!(), + CloseInput(_) => todo!(), AckConnection(_, _) => { unreachable!("already connected") From 0f5c529cee6285ebf853386d7179c978bc562f31 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 14:01:18 +0200 Subject: [PATCH 14/63] Partially refactor input port opening --- lib/protoflow-zeromq/src/lib.rs | 98 +++++++++++++++++++++++++++------ 1 file changed, 81 insertions(+), 17 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 52834dbe..6797b8cb 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -9,7 +9,7 @@ pub use protoflow_core::prelude; extern crate std; use protoflow_core::{ - prelude::{Arc, BTreeMap, Bytes, String, ToString, Vec}, + prelude::{vec, Arc, BTreeMap, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; @@ -74,7 +74,7 @@ enum ZmqInputPortState { Open( // TODO: hide these Arc>, - Arc>>, + Arc>>, ), Connected( // channels for requests from public close @@ -85,7 +85,7 @@ enum ZmqInputPortState { Arc>>, // internal channels for events Arc>, - Arc>>, + Arc>>, // vec of the connected port ids Vec, ), @@ -405,6 +405,83 @@ impl ZmqTransport { Ok((to_worker_send, from_worker_recv)) } + + fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { + let topic = format!("{}:", input_port_id); + + let (to_worker_send, to_worker_recv) = sync_channel(1); + let to_worker_send = Arc::new(to_worker_send); + let to_worker_recv = Arc::new(Mutex::new(to_worker_recv)); + + { + let mut inputs = self.inputs.write(); + let state = ZmqInputPortState::Open(to_worker_send.clone(), to_worker_recv.clone()); + let state = RwLock::new(state); + inputs.insert(input_port_id, state); + } + + let inputs = self.inputs.clone(); + tokio::task::spawn(async move { + let input = &to_worker_recv; + + let inputs = inputs; + + loop { + let event: ZmqTransportEvent = input.lock().recv().expect("input worker recv"); + use ZmqTransportEvent::*; + match event { + Connect(output_port_id, input_port_id) => { + let inputs = inputs.read(); + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let input_state = input_state.write(); + + use ZmqInputPortState::*; + match &*input_state { + Open(_, _) => { + let (req_send, req_recv) = sync_channel(1); + let req_send = Arc::new(req_send); + let req_recv = Arc::new(Mutex::new(req_recv)); + + let (msgs_send, msgs_recv) = sync_channel(1); + + let msgs_send = Arc::new(msgs_send); + let msgs_recv = Arc::new(Mutex::new(msgs_recv)); + + let mut input_state = input_state; + + *input_state = ZmqInputPortState::Connected( + req_send, + req_recv, + msgs_send, + msgs_recv, + to_worker_send.clone(), + input.clone(), + vec![output_port_id], + ); + } + Connected(_, _, _, _, _, _, _) => todo!(), + Closed => todo!(), + } + } + AckConnection(output_port_id, input_port_id) => todo!(), + Message(output_port_id, input_port_id, _, bytes) => todo!(), + AckMessage(output_port_id, input_port_id, _) => todo!(), + CloseOutput(output_port_id, input_port_id) => todo!(), + CloseInput(input_port_id) => todo!(), + }; + } + }); + + // send sub request + self.tokio + .block_on( + self.sub_queue + .send(ZmqSubscriptionRequest::Subscribe(topic)), + ) + .map_err(|e| PortError::Other(e.to_string())) + } } impl Transport for ZmqTransport { @@ -430,20 +507,7 @@ impl Transport for ZmqTransport { let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) .map_err(|e| PortError::Other(e.to_string()))?; - let (_, receiver) = self - .subscribe_for_input_port(new_id) - .map_err(|e| PortError::Other(e.to_string()))?; - - loop { - let msg = receiver - .lock() - .recv() - .map_err(|e| PortError::Other(e.to_string()))?; - match msg { - ZmqInputPortEvent::Opened => break Ok(new_id), - _ => continue, // TODO - } - } + self.start_input_worker(new_id).map(|_| new_id) } fn open_output(&self) -> PortResult { From b19ce07b30bf8f85e408285ca72abf08354735e9 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 15:14:04 +0200 Subject: [PATCH 15/63] Finish input worker refactor --- lib/protoflow-zeromq/src/lib.rs | 240 +++++++++++++------------------- 1 file changed, 96 insertions(+), 144 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 6797b8cb..fb835e63 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -290,55 +290,43 @@ impl ZmqTransport { }); } - fn subscribe_for_input_port( - &self, - input: InputPortID, - ) -> Result< - ( - Arc>, - Arc>>, - ), - ZmqError, - > { - // TODO: only sub to relevant events - let topic = format!("{}:", input); - self.tokio - .block_on( - self.sub_queue - .send(ZmqSubscriptionRequest::Subscribe(topic)), - ) - .unwrap(); + fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { + let topic = format!("{}:", input_port_id); - let (from_worker_send, from_worker_recv) = sync_channel(1); let (to_worker_send, to_worker_recv) = sync_channel(1); - let to_worker_send = Arc::new(to_worker_send); - let from_worker_recv = Arc::new(Mutex::new(from_worker_recv)); + let to_worker_recv = Arc::new(Mutex::new(to_worker_recv)); - // Input worker loop: - // 1. Receive connection attempts and respond - // 2. Receive messages and forward to channel - // 3. Receive and handle disconnects { - let inputs = self.inputs.clone(); - - let to_worker_send = to_worker_send.clone(); - let from_worker_recv = from_worker_recv.clone(); + let mut inputs = self.inputs.write(); + let state = ZmqInputPortState::Open(to_worker_send.clone(), to_worker_recv.clone()); + let state = RwLock::new(state); + inputs.insert(input_port_id, state); + } - tokio::task::spawn(async move { - let (_output, input) = (from_worker_send, to_worker_recv); + //let sub_queue = self.sub_queue.clone(); + let pub_queue = self.pub_queue.clone(); + let inputs = self.inputs.clone(); + tokio::task::spawn(async move { + let input = &to_worker_recv; + let inputs = inputs; - loop { - use ZmqTransportEvent::*; - let event = input.recv().expect("input worker recv"); - match event { - // Connection attempt - Connect(_, input_port_id) => { - let inputs = inputs.read(); - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; + // Input worker loop: + // 1. Receive connection attempts and respond + // 2. Receive messages and forward to channel + // 3. Receive and handle disconnects + loop { + let event: ZmqTransportEvent = input.lock().recv().expect("input worker recv"); + use ZmqTransportEvent::*; + match event { + Connect(output_port_id, input_port_id) => { + let inputs = inputs.read(); + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let mut input_state = input_state.upgradable_read(); + let open = |input_state: &mut ZmqInputPortState| { let (req_send, req_recv) = sync_channel(1); let req_send = Arc::new(req_send); let req_recv = Arc::new(Mutex::new(req_recv)); @@ -348,132 +336,97 @@ impl ZmqTransport { let msgs_send = Arc::new(msgs_send); let msgs_recv = Arc::new(Mutex::new(msgs_recv)); - let mut input = input.write(); - *input = ZmqInputPortState::Connected( + *input_state = ZmqInputPortState::Connected( req_send, req_recv, msgs_send, msgs_recv, to_worker_send.clone(), - from_worker_recv.clone(), - Vec::new(), + input.clone(), + vec![output_port_id], ); - } - - // Message from output port - Message(_, input_port_id, _, bytes) => { - let inputs = inputs.read(); - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; + }; - let input = input.read(); - use ZmqInputPortState::*; - match &*input { - Open(_, _) => todo!(), - Closed => todo!(), - Connected(_, _, sender, _, _, _, _) => { - sender.send(ZmqInputPortEvent::Message(bytes)).unwrap() + use ZmqInputPortState::*; + match &*input_state { + Open(_, _) => input_state.with_upgraded(open), + Connected(_, _, _, _, _, _, connected_ids) => { + if !connected_ids.iter().any(|&id| id == output_port_id) { + input_state.with_upgraded(open) } - }; + } + Connected(_, _, _, _, _, _, _) => todo!(), + Closed => todo!(), } + } + Message(output_port_id, _, seq_id, bytes) => { + let inputs = inputs.read(); + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let input_state = input_state.read(); - // Output port reports being closed - CloseInput(input_port_id) => { - let inputs = inputs.read(); - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; - - let input = input.read(); - use ZmqInputPortState::*; - match &*input { - Open(_, _) => todo!(), - Closed => todo!(), - Connected(_, _, sender, _, _, _, _) => { - sender.send(ZmqInputPortEvent::Closed).unwrap() + use ZmqInputPortState::*; + match &*input_state { + Connected(_, _, sender, _, _, _, connected_ids) => { + if !connected_ids.iter().any(|id| *id == output_port_id) { + continue; } - }; - } - - // ignore output port type events: - AckConnection(_, _) | AckMessage(_, _, _) | CloseOutput(_, _) => continue, - }; - } - }); - } - - Ok((to_worker_send, from_worker_recv)) - } - - fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { - let topic = format!("{}:", input_port_id); - - let (to_worker_send, to_worker_recv) = sync_channel(1); - let to_worker_send = Arc::new(to_worker_send); - let to_worker_recv = Arc::new(Mutex::new(to_worker_recv)); - - { - let mut inputs = self.inputs.write(); - let state = ZmqInputPortState::Open(to_worker_send.clone(), to_worker_recv.clone()); - let state = RwLock::new(state); - inputs.insert(input_port_id, state); - } - - let inputs = self.inputs.clone(); - tokio::task::spawn(async move { - let input = &to_worker_recv; - let inputs = inputs; + sender + .send(ZmqInputPortEvent::Message(bytes)) + .expect("input worker message send"); + + pub_queue + .send(ZmqTransportEvent::AckMessage( + output_port_id, + input_port_id, + seq_id, + )) + .expect("input worker message ack"); + } - loop { - let event: ZmqTransportEvent = input.lock().recv().expect("input worker recv"); - use ZmqTransportEvent::*; - match event { - Connect(output_port_id, input_port_id) => { + Open(_, _) | Closed => todo!(), + } + } + CloseOutput(output_port_id, input_port_id) => { let inputs = inputs.read(); let Some(input_state) = inputs.get(&input_port_id) else { todo!(); }; - let input_state = input_state.write(); + let mut input_state = input_state.upgradable_read(); use ZmqInputPortState::*; - match &*input_state { - Open(_, _) => { - let (req_send, req_recv) = sync_channel(1); - let req_send = Arc::new(req_send); - let req_recv = Arc::new(Mutex::new(req_recv)); - - let (msgs_send, msgs_recv) = sync_channel(1); - - let msgs_send = Arc::new(msgs_send); - let msgs_recv = Arc::new(Mutex::new(msgs_recv)); - - let mut input_state = input_state; - - *input_state = ZmqInputPortState::Connected( - req_send, - req_recv, - msgs_send, - msgs_recv, - to_worker_send.clone(), - input.clone(), - vec![output_port_id], - ); - } - Connected(_, _, _, _, _, _, _) => todo!(), - Closed => todo!(), + let Connected(_, _, _, _, _, _, ref connected_ids) = *input_state else { + continue; + }; + + if !connected_ids.iter().any(|id| *id == output_port_id) { + continue; } + + // TODO: send unsubscription for relevant topics + //sub_queue + // .send(ZmqSubscriptionRequest::Unsubscribe("".to_string())) + // .await + // .expect("input worker closeoutput unsub"); + + input_state.with_upgraded(|state| match state { + Open(_, _) | Closed => (), + Connected(_, _, _, _, _, _, connected_ids) => { + connected_ids.retain(|&id| id != output_port_id) + } + }) } - AckConnection(output_port_id, input_port_id) => todo!(), - Message(output_port_id, input_port_id, _, bytes) => todo!(), - AckMessage(output_port_id, input_port_id, _) => todo!(), - CloseOutput(output_port_id, input_port_id) => todo!(), - CloseInput(input_port_id) => todo!(), + + // ignore, ideally we never receive these here: + AckConnection(_, _) | AckMessage(_, _, _) | CloseInput(_) => continue, }; } }); + let topic = format!("{}:", input_port_id); + // send sub request self.tokio .block_on( @@ -502,8 +455,7 @@ impl Transport for ZmqTransport { } fn open_input(&self) -> PortResult { - let inputs = self.inputs.write(); - + let inputs = self.inputs.read(); let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) .map_err(|e| PortError::Other(e.to_string()))?; From ffb39a3cd99a4d967b7de43c2718855980731f07 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 15:14:57 +0200 Subject: [PATCH 16/63] Send connection ack from input port --- lib/protoflow-zeromq/src/lib.rs | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index fb835e63..10dedbde 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -291,8 +291,6 @@ impl ZmqTransport { } fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { - let topic = format!("{}:", input_port_id); - let (to_worker_send, to_worker_recv) = sync_channel(1); let to_worker_send = Arc::new(to_worker_send); let to_worker_recv = Arc::new(Mutex::new(to_worker_recv)); @@ -326,13 +324,23 @@ impl ZmqTransport { }; let mut input_state = input_state.upgradable_read(); + use ZmqInputPortState::*; + match &*input_state { + Open(_, _) => (), + Connected(_, _, _, _, _, _, connected_ids) => { + if !connected_ids.iter().any(|&id| id == output_port_id) { + continue; + } + } + Closed => continue, + } + let open = |input_state: &mut ZmqInputPortState| { let (req_send, req_recv) = sync_channel(1); let req_send = Arc::new(req_send); let req_recv = Arc::new(Mutex::new(req_recv)); let (msgs_send, msgs_recv) = sync_channel(1); - let msgs_send = Arc::new(msgs_send); let msgs_recv = Arc::new(Mutex::new(msgs_recv)); @@ -347,17 +355,13 @@ impl ZmqTransport { ); }; - use ZmqInputPortState::*; - match &*input_state { - Open(_, _) => input_state.with_upgraded(open), - Connected(_, _, _, _, _, _, connected_ids) => { - if !connected_ids.iter().any(|&id| id == output_port_id) { - input_state.with_upgraded(open) - } - } - Connected(_, _, _, _, _, _, _) => todo!(), - Closed => todo!(), - } + pub_queue + .send(ZmqTransportEvent::AckConnection( + output_port_id, + input_port_id, + )) + .expect("input worker conn ack"); + input_state.with_upgraded(open); } Message(output_port_id, _, seq_id, bytes) => { let inputs = inputs.read(); From 42816de73d1914df7a82fadd2c6ffd82989c0a54 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 15:20:45 +0200 Subject: [PATCH 17/63] Partially refactor output port worker --- lib/protoflow-zeromq/src/lib.rs | 475 ++++++++++++++++++++------------ 1 file changed, 300 insertions(+), 175 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 10dedbde..48bc8be3 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -39,13 +39,15 @@ pub struct ZmqTransport { #[derive(Debug, Clone)] enum ZmqOutputPortState { - Open(Arc>)>>), + Open( + Arc>)>>, + Arc>, + ), Connected( - // channels for public send, contained channel is for the ack back + // channel for public send, contained channel is for the ack back Arc>)>>, - // internal channels for events + // internal channel for events Arc>, - Arc>>, // id of the connected input port InputPortID, ), @@ -62,8 +64,8 @@ impl ZmqOutputPortState { fn state(&self) -> PortState { use ZmqOutputPortState::*; match self { - Open(_) => PortState::Open, - Connected(_, _, _, _) => PortState::Connected, + Open(_, _) => PortState::Open, + Connected(_, _, _) => PortState::Connected, Closed => PortState::Closed, } } @@ -439,6 +441,142 @@ impl ZmqTransport { ) .map_err(|e| PortError::Other(e.to_string())) } + + fn start_output_worker(&self, output_port_id: OutputPortID) -> Result<(), PortError> { + let (conn_send, conn_recv) = sync_channel(1); + let conn_send = Arc::new(conn_send); + + let (to_worker_send, to_worker_recv) = sync_channel(1); + let to_worker_send = Arc::new(to_worker_send); + + { + let mut outputs = self.outputs.write(); + let state = ZmqOutputPortState::Open(conn_send, to_worker_send.clone()); + let state = RwLock::new(state); + outputs.insert(output_port_id, state); + } + + let outputs = self.outputs.clone(); + let pub_queue = self.pub_queue.clone(); + tokio::task::spawn(async move { + let Ok((input_port_id, conn_confirm)) = conn_recv.recv() else { + todo!(); + }; + + let (msg_req_send, msg_req_recv) = sync_channel(1); + let msg_req_send = Arc::new(msg_req_send); + + // Output worker loop: + // 1. Send connection attempt + // 2. Send messages + // 2.1 Wait for ACK + // 2.2. Resend on timeout + // 3. Send disconnect events + + loop { + pub_queue + .send(ZmqTransportEvent::Connect(output_port_id, input_port_id)) + .expect("output worker send connect"); + + let response = to_worker_recv.recv().expect("output worker recv conn ack"); + + use ZmqTransportEvent::*; + match response { + AckConnection(_, input_port_id) => { + let outputs = outputs.read(); + let Some(output_state) = outputs.get(&output_port_id) else { + todo!(); + }; + let mut output_state = output_state.write(); + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); + *output_state = ZmqOutputPortState::Connected( + msg_req_send, + to_worker_send, + input_port_id, + ); + + conn_confirm + .send(Ok(())) + .expect("output worker send confirm conn"); + + break; + } + _ => continue, + } + } + + // TODO: loop for to_worker_recv.recv(), i.e. events from socket + + let mut seq_id = 1; + 'outer: loop { + let (request, response_chan) = + msg_req_recv.recv().expect("output worker recv msg req"); + + match request { + ZmqOutputPortRequest::Close => { + let response = pub_queue + .send(ZmqTransportEvent::CloseOutput( + output_port_id, + input_port_id, + )) + .map_err(|e| PortError::Other(e.to_string())); + + response_chan + .send(response) + .expect("output worker respond close"); + } + ZmqOutputPortRequest::Send(bytes) => { + pub_queue + .send(ZmqTransportEvent::Message( + output_port_id, + input_port_id, + seq_id, + bytes, + )) + .expect("output worker send message"); + + loop { + let event = to_worker_recv.recv().expect("output worker event recv"); + + use ZmqTransportEvent::*; + match event { + AckMessage(_, _, ack_id) => { + if ack_id == seq_id { + break; + } + } + CloseInput(_) => { + let outputs = outputs.read(); + let Some(output_state) = outputs.get(&output_port_id) else { + todo!(); + }; + let mut output_state = output_state.write(); + debug_assert!(matches!( + *output_state, + ZmqOutputPortState::Connected(..) + )); + *output_state = ZmqOutputPortState::Closed; + + break 'outer; + } + + // ignore others, we shouldn't receive any new conn-acks + // nor should we be receiving input port events + AckConnection(_, _) + | Connect(_, _) + | Message(_, _, _, _) + | CloseOutput(_, _) => continue, + } + } + } + } + + seq_id += 1; + } + }); + + Ok(()) + } } impl Transport for ZmqTransport { @@ -462,23 +600,14 @@ impl Transport for ZmqTransport { let inputs = self.inputs.read(); let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) .map_err(|e| PortError::Other(e.to_string()))?; - self.start_input_worker(new_id).map(|_| new_id) } fn open_output(&self) -> PortResult { - let mut outputs = self.outputs.write(); - + let outputs = self.outputs.read(); let new_id = OutputPortID::try_from(outputs.len() as isize + 1) .map_err(|e| PortError::Other(e.to_string()))?; - - let (sender, _receiver) = sync_channel(1); - let sender = Arc::new(sender); - - let state = RwLock::new(ZmqOutputPortState::Open(sender)); - outputs.insert(new_id, state); - - Ok(new_id) + self.start_output_worker(new_id).map(|_| new_id) } fn close_input(&self, input: InputPortID) -> PortResult { @@ -515,169 +644,165 @@ impl Transport for ZmqTransport { let state = state.write(); - let ZmqOutputPortState::Connected(_, sender, receiver, input) = &*state else { + let ZmqOutputPortState::Connected(sender, _, _) = &*state else { return Err(PortError::Disconnected); }; + let (close_send, close_recv) = sync_channel(1); + sender - .send(ZmqTransportEvent::CloseOutput(output, *input)) + .send((ZmqOutputPortRequest::Close, close_send)) .map_err(|e| PortError::Other(e.to_string()))?; - loop { - let msg = receiver - .lock() - .recv() - .map_err(|e| PortError::Other(e.to_string()))?; - use ZmqOutputPortEvent::*; - match msg { - Closed => break Ok(true), - _ => continue, // TODO - } - } + close_recv + .recv() + .map_err(|_| PortError::Disconnected)? + .map(|_| true) } fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { - let outputs = self.outputs.read(); - if outputs - .get(&source) - .is_some_and(|state| !state.read().state().is_open()) - { - return Err(PortError::Invalid(source.into())); - } - - let (from_worker_send, from_worker_recv) = sync_channel::(1); - let (to_worker_send, to_worker_recv) = sync_channel::(1); - - let to_worker_send = Arc::new(to_worker_send); - let from_worker_recv = Arc::new(Mutex::new(from_worker_recv)); - - let (msg_req_send, msg_req_recv) = sync_channel(1); - let msg_req_send = Arc::new(msg_req_send); - let msg_req_recv = Arc::new(Mutex::new(msg_req_recv)); - - // Output worker loop: - // 1. Send connection attempt - // 2. Send messages - // 2.1 Wait for ACK - // 2.2. Resend on timeout - // 3. Send disconnect events - { - let to_worker_send = to_worker_send.clone(); - let from_worker_recv = from_worker_recv.clone(); - - let pub_queue = self.pub_queue.clone(); - let outputs = self.outputs.clone(); - - tokio::task::spawn(async move { - let (output, input) = (from_worker_send, to_worker_recv); - - // connect loop - loop { - // send request to connect - pub_queue - .send(ZmqTransportEvent::Connect(source, target)) - .unwrap(); - - let request = input.recv().expect("output worker recv"); - match request { - ZmqTransportEvent::AckConnection(_, input_port_id) => { - let outputs = outputs.read(); - let Some(output_state) = outputs.get(&source) else { - todo!(); - }; - let mut output_state = output_state.write(); - debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(_))); - *output_state = ZmqOutputPortState::Connected( - msg_req_send, - to_worker_send, - from_worker_recv, - input_port_id, - ); - output - .send(ZmqOutputPortEvent::Opened) - .expect("output worker connected send"); - break; - } - _ => continue, // TODO: when and why would we receive other events? - } - } - - // TODO: combine these two spawns by using tokio's channels and `select!` - - // work loop for sending events - tokio::task::spawn(async move { - let mut seq_id = 1; - loop { - let req = msg_req_recv.lock().recv().expect("output worker req recv"); - - let outputs = outputs.read(); - let Some(output_state) = outputs.get(&source) else { - todo!(); - }; - - let ZmqOutputPortState::Connected(_, sender, _, output_id) = - &*output_state.read() - else { - todo!(); - }; - - let _resp = req.1; // TODO: respond - - match req.0 { - ZmqOutputPortRequest::Send(bytes) => { - sender - .send(ZmqTransportEvent::Message(source, target, seq_id, bytes)) - .unwrap(); - seq_id += 1; - } - ZmqOutputPortRequest::Close => sender - .send(ZmqTransportEvent::CloseOutput(source, *output_id)) - .unwrap(), - }; - } - }); - - // work loop for handling events - tokio::task::spawn(async move { - loop { - use ZmqTransportEvent::*; - let event = input.recv().expect("output worker event recv"); - if !matches!(event, Message(_, _, _, _)) { - unreachable!("why are we getting non-Message?"); - } - match event { - AckMessage(_, _, seq_id) => { - output - .send(ZmqOutputPortEvent::Ack(seq_id)) - .expect("worker loop ack send"); - } - - CloseInput(_) => todo!(), - - AckConnection(_, _) => { - unreachable!("already connected") - } - - // ignore input port type events - Connect(_, _) | CloseOutput(_, _) | Message(_, _, _, _) => continue, // TODO - } - } - }); - }); - } - - // wait for the `Connected` event - loop { - let msg = from_worker_recv - .lock() - .recv() - .map_err(|e| PortError::Other(e.to_string()))?; - use ZmqOutputPortEvent::*; - match msg { - Connected(_) => break Ok(true), - _ => continue, // TODO - } - } + todo!(); + //let outputs = self.outputs.read(); + //if outputs + // .get(&source) + // .is_some_and(|state| !state.read().state().is_open()) + //{ + // return Err(PortError::Invalid(source.into())); + //} + // + //let (from_worker_send, from_worker_recv) = sync_channel::(1); + //let (to_worker_send, to_worker_recv) = sync_channel::(1); + // + //let to_worker_send = Arc::new(to_worker_send); + //let from_worker_recv = Arc::new(Mutex::new(from_worker_recv)); + // + //let (msg_req_send, msg_req_recv) = sync_channel(1); + //let msg_req_send = Arc::new(msg_req_send); + //let msg_req_recv = Arc::new(Mutex::new(msg_req_recv)); + // + //// Output worker loop: + //// 1. Send connection attempt + //// 2. Send messages + //// 2.1 Wait for ACK + //// 2.2. Resend on timeout + //// 3. Send disconnect events + //{ + // let to_worker_send = to_worker_send.clone(); + // let from_worker_recv = from_worker_recv.clone(); + // + // let pub_queue = self.pub_queue.clone(); + // let outputs = self.outputs.clone(); + // + // tokio::task::spawn(async move { + // let (output, input) = (from_worker_send, to_worker_recv); + // + // // connect loop + // loop { + // // send request to connect + // pub_queue + // .send(ZmqTransportEvent::Connect(source, target)) + // .unwrap(); + // + // let request = input.recv().expect("output worker recv"); + // match request { + // ZmqTransportEvent::AckConnection(_, input_port_id) => { + // let outputs = outputs.read(); + // let Some(output_state) = outputs.get(&source) else { + // todo!(); + // }; + // let mut output_state = output_state.write(); + // debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(_))); + // *output_state = ZmqOutputPortState::Connected( + // msg_req_send, + // to_worker_send, + // from_worker_recv, + // input_port_id, + // ); + // output + // .send(ZmqOutputPortEvent::Opened) + // .expect("output worker connected send"); + // break; + // } + // _ => continue, // TODO: when and why would we receive other events? + // } + // } + // + // // TODO: combine these two spawns by using tokio's channels and `select!` + // + // // work loop for sending events + // tokio::task::spawn(async move { + // let mut seq_id = 1; + // loop { + // let req = msg_req_recv.lock().recv().expect("output worker req recv"); + // + // let outputs = outputs.read(); + // let Some(output_state) = outputs.get(&source) else { + // todo!(); + // }; + // + // let ZmqOutputPortState::Connected(_, sender, _, output_id) = + // &*output_state.read() + // else { + // todo!(); + // }; + // + // let _resp = req.1; // TODO: respond + // + // match req.0 { + // ZmqOutputPortRequest::Send(bytes) => { + // sender + // .send(ZmqTransportEvent::Message(source, target, seq_id, bytes)) + // .unwrap(); + // seq_id += 1; + // } + // ZmqOutputPortRequest::Close => sender + // .send(ZmqTransportEvent::CloseOutput(source, *output_id)) + // .unwrap(), + // }; + // } + // }); + // + // // work loop for handling events + // tokio::task::spawn(async move { + // loop { + // use ZmqTransportEvent::*; + // let event = input.recv().expect("output worker event recv"); + // if !matches!(event, Message(_, _, _, _)) { + // unreachable!("why are we getting non-Message?"); + // } + // match event { + // AckMessage(_, _, seq_id) => { + // output + // .send(ZmqOutputPortEvent::Ack(seq_id)) + // .expect("worker loop ack send"); + // } + // + // CloseInput(_) => todo!(), + // + // AckConnection(_, _) => { + // unreachable!("already connected") + // } + // + // // ignore input port type events + // Connect(_, _) | CloseOutput(_, _) | Message(_, _, _, _) => continue, // TODO + // } + // } + // }); + // }); + //} + // + //// wait for the `Connected` event + //loop { + // let msg = from_worker_recv + // .lock() + // .recv() + // .map_err(|e| PortError::Other(e.to_string()))?; + // use ZmqOutputPortEvent::*; + // match msg { + // Connected(_) => break Ok(true), + // _ => continue, // TODO + // } + //} } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { @@ -687,7 +812,7 @@ impl Transport for ZmqTransport { }; let output = output.read(); - let ZmqOutputPortState::Connected(sender, _, _, _) = &*output else { + let ZmqOutputPortState::Connected(sender, _, _) = &*output else { return Err(PortError::Disconnected); }; @@ -798,7 +923,7 @@ fn handle_zmq_msg( }; let output = output.read(); - let ZmqOutputPortState::Open(sender) = &*output else { + let ZmqOutputPortState::Open(_, sender) = &*output else { todo!(); }; sender.send(event).unwrap(); @@ -810,7 +935,7 @@ fn handle_zmq_msg( }; let output = output.read(); - let ZmqOutputPortState::Connected(_, sender, _, _) = &*output else { + let ZmqOutputPortState::Connected(_, sender, _) = &*output else { todo!(); }; sender.send(event).unwrap(); @@ -821,7 +946,7 @@ fn handle_zmq_msg( for (_, state) in outputs.iter() { let state = state.read(); - let ZmqOutputPortState::Connected(_, sender, _, id) = &*state else { + let ZmqOutputPortState::Connected(_, sender, id) = &*state else { todo!(); }; From a532995b86f814e343af7f0cea54339b8b9f6ebc Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 16:22:28 +0200 Subject: [PATCH 18/63] Reimplement public `connect` method --- lib/protoflow-zeromq/src/lib.rs | 179 ++++++-------------------------- 1 file changed, 30 insertions(+), 149 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 48bc8be3..b7936a5b 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -656,153 +656,31 @@ impl Transport for ZmqTransport { close_recv .recv() - .map_err(|_| PortError::Disconnected)? + .map_err(|e| PortError::Other(e.to_string()))? .map(|_| true) } fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { - todo!(); - //let outputs = self.outputs.read(); - //if outputs - // .get(&source) - // .is_some_and(|state| !state.read().state().is_open()) - //{ - // return Err(PortError::Invalid(source.into())); - //} - // - //let (from_worker_send, from_worker_recv) = sync_channel::(1); - //let (to_worker_send, to_worker_recv) = sync_channel::(1); - // - //let to_worker_send = Arc::new(to_worker_send); - //let from_worker_recv = Arc::new(Mutex::new(from_worker_recv)); - // - //let (msg_req_send, msg_req_recv) = sync_channel(1); - //let msg_req_send = Arc::new(msg_req_send); - //let msg_req_recv = Arc::new(Mutex::new(msg_req_recv)); - // - //// Output worker loop: - //// 1. Send connection attempt - //// 2. Send messages - //// 2.1 Wait for ACK - //// 2.2. Resend on timeout - //// 3. Send disconnect events - //{ - // let to_worker_send = to_worker_send.clone(); - // let from_worker_recv = from_worker_recv.clone(); - // - // let pub_queue = self.pub_queue.clone(); - // let outputs = self.outputs.clone(); - // - // tokio::task::spawn(async move { - // let (output, input) = (from_worker_send, to_worker_recv); - // - // // connect loop - // loop { - // // send request to connect - // pub_queue - // .send(ZmqTransportEvent::Connect(source, target)) - // .unwrap(); - // - // let request = input.recv().expect("output worker recv"); - // match request { - // ZmqTransportEvent::AckConnection(_, input_port_id) => { - // let outputs = outputs.read(); - // let Some(output_state) = outputs.get(&source) else { - // todo!(); - // }; - // let mut output_state = output_state.write(); - // debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(_))); - // *output_state = ZmqOutputPortState::Connected( - // msg_req_send, - // to_worker_send, - // from_worker_recv, - // input_port_id, - // ); - // output - // .send(ZmqOutputPortEvent::Opened) - // .expect("output worker connected send"); - // break; - // } - // _ => continue, // TODO: when and why would we receive other events? - // } - // } - // - // // TODO: combine these two spawns by using tokio's channels and `select!` - // - // // work loop for sending events - // tokio::task::spawn(async move { - // let mut seq_id = 1; - // loop { - // let req = msg_req_recv.lock().recv().expect("output worker req recv"); - // - // let outputs = outputs.read(); - // let Some(output_state) = outputs.get(&source) else { - // todo!(); - // }; - // - // let ZmqOutputPortState::Connected(_, sender, _, output_id) = - // &*output_state.read() - // else { - // todo!(); - // }; - // - // let _resp = req.1; // TODO: respond - // - // match req.0 { - // ZmqOutputPortRequest::Send(bytes) => { - // sender - // .send(ZmqTransportEvent::Message(source, target, seq_id, bytes)) - // .unwrap(); - // seq_id += 1; - // } - // ZmqOutputPortRequest::Close => sender - // .send(ZmqTransportEvent::CloseOutput(source, *output_id)) - // .unwrap(), - // }; - // } - // }); - // - // // work loop for handling events - // tokio::task::spawn(async move { - // loop { - // use ZmqTransportEvent::*; - // let event = input.recv().expect("output worker event recv"); - // if !matches!(event, Message(_, _, _, _)) { - // unreachable!("why are we getting non-Message?"); - // } - // match event { - // AckMessage(_, _, seq_id) => { - // output - // .send(ZmqOutputPortEvent::Ack(seq_id)) - // .expect("worker loop ack send"); - // } - // - // CloseInput(_) => todo!(), - // - // AckConnection(_, _) => { - // unreachable!("already connected") - // } - // - // // ignore input port type events - // Connect(_, _) | CloseOutput(_, _) | Message(_, _, _, _) => continue, // TODO - // } - // } - // }); - // }); - //} - // - //// wait for the `Connected` event - //loop { - // let msg = from_worker_recv - // .lock() - // .recv() - // .map_err(|e| PortError::Other(e.to_string()))?; - // use ZmqOutputPortEvent::*; - // match msg { - // Connected(_) => break Ok(true), - // _ => continue, // TODO - // } - //} + let outputs = self.outputs.read(); + let Some(output_state) = outputs.get(&source) else { + return Err(PortError::Invalid(source.into())); + }; + + let output_state = output_state.read(); + let ZmqOutputPortState::Open(ref sender, _) = *output_state else { + return Err(PortError::Invalid(source.into())); + }; + + let (confirm_send, confirm_recv) = sync_channel(1); + + sender + .send((target, confirm_send)) + .map_err(|e| PortError::Other(e.to_string()))?; + + confirm_recv + .recv() + .map_err(|e| PortError::Other(e.to_string()))? + .map(|_| true) } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { @@ -820,9 +698,11 @@ impl Transport for ZmqTransport { sender .send((ZmqOutputPortRequest::Send(message), ack_send)) - .unwrap(); + .map_err(|e| PortError::Other(e.to_string()))?; - ack_recv.recv().map_err(|_| PortError::Disconnected)? + ack_recv + .recv() + .map_err(|e| PortError::Other(e.to_string()))? } fn recv(&self, input: InputPortID) -> PortResult> { @@ -835,16 +715,17 @@ impl Transport for ZmqTransport { let ZmqInputPortState::Connected(_, _, _, receiver, _, _, _) = &*input else { return Err(PortError::Disconnected); }; + let receiver = receiver.lock(); loop { use ZmqInputPortEvent::*; - match receiver.lock().recv() { - // ignore - Ok(Opened) | Ok(Connected(_)) => continue, - + match receiver.recv() { Ok(Closed) => break Ok(None), // EOS Ok(Message(bytes)) => break Ok(Some(bytes)), Err(e) => break Err(PortError::Other(e.to_string())), + + // ignore + Ok(Opened) | Ok(Connected(_)) => continue, } } } From 8c9fe141e786bc2be8f2c2c787019d22fe5d08a5 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 17:16:05 +0200 Subject: [PATCH 19/63] Remove unused channels from InputPortState --- lib/protoflow-zeromq/src/lib.rs | 45 ++++++++++++++------------------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index b7936a5b..eb94f503 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -20,9 +20,7 @@ use std::{ sync::mpsc::{sync_channel, Receiver, SyncSender}, write, }; -use zeromq::{ - util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqError, ZmqMessage, -}; +use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqMessage}; const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; @@ -76,18 +74,15 @@ enum ZmqInputPortState { Open( // TODO: hide these Arc>, - Arc>>, ), Connected( - // channels for requests from public close + // channel for requests from public close Arc>)>>, - Arc>)>>>, - // channels for the public recv + // channel for the public recv Arc>, Arc>>, - // internal channels for events + // internal channel for events Arc>, - Arc>>, // vec of the connected port ids Vec, ), @@ -103,8 +98,8 @@ impl ZmqInputPortState { fn state(&self) -> PortState { use ZmqInputPortState::*; match self { - Open(_, _) => PortState::Open, - Connected(_, _, _, _, _, _, _) => PortState::Connected, + Open(_) => PortState::Open, + Connected(_, _, _, _, _) => PortState::Connected, Closed => PortState::Closed, } } @@ -299,7 +294,7 @@ impl ZmqTransport { { let mut inputs = self.inputs.write(); - let state = ZmqInputPortState::Open(to_worker_send.clone(), to_worker_recv.clone()); + let state = ZmqInputPortState::Open(to_worker_send.clone()); let state = RwLock::new(state); inputs.insert(input_port_id, state); } @@ -328,8 +323,8 @@ impl ZmqTransport { use ZmqInputPortState::*; match &*input_state { - Open(_, _) => (), - Connected(_, _, _, _, _, _, connected_ids) => { + Open(_) => (), + Connected(_, _, _, _, connected_ids) => { if !connected_ids.iter().any(|&id| id == output_port_id) { continue; } @@ -348,11 +343,9 @@ impl ZmqTransport { *input_state = ZmqInputPortState::Connected( req_send, - req_recv, msgs_send, msgs_recv, to_worker_send.clone(), - input.clone(), vec![output_port_id], ); }; @@ -374,7 +367,7 @@ impl ZmqTransport { use ZmqInputPortState::*; match &*input_state { - Connected(_, _, sender, _, _, _, connected_ids) => { + Connected(_, sender, _, _, connected_ids) => { if !connected_ids.iter().any(|id| *id == output_port_id) { continue; } @@ -392,7 +385,7 @@ impl ZmqTransport { .expect("input worker message ack"); } - Open(_, _) | Closed => todo!(), + Open(_) | Closed => todo!(), } } CloseOutput(output_port_id, input_port_id) => { @@ -403,7 +396,7 @@ impl ZmqTransport { let mut input_state = input_state.upgradable_read(); use ZmqInputPortState::*; - let Connected(_, _, _, _, _, _, ref connected_ids) = *input_state else { + let Connected(_, _, _, _, ref connected_ids) = *input_state else { continue; }; @@ -418,8 +411,8 @@ impl ZmqTransport { // .expect("input worker closeoutput unsub"); input_state.with_upgraded(|state| match state { - Open(_, _) | Closed => (), - Connected(_, _, _, _, _, _, connected_ids) => { + Open(_) | Closed => (), + Connected(_, _, _, _, connected_ids) => { connected_ids.retain(|&id| id != output_port_id) } }) @@ -619,7 +612,7 @@ impl Transport for ZmqTransport { let state = state.read(); - let ZmqInputPortState::Connected(sender, _, _, _, _, _, _) = &*state else { + let ZmqInputPortState::Connected(sender, _, _, _, _) = &*state else { return Err(PortError::Disconnected); }; @@ -712,7 +705,7 @@ impl Transport for ZmqTransport { }; let input = input.read(); - let ZmqInputPortState::Connected(_, _, _, receiver, _, _, _) = &*input else { + let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input else { return Err(PortError::Disconnected); }; let receiver = receiver.lock(); @@ -757,7 +750,7 @@ fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => todo!(), - Open(sender, _) | Connected(_, _, _, _, sender, _, _) => { + Open(sender) | Connected(_, _, _, sender, _) => { sender.send(event).unwrap(); } }; @@ -769,7 +762,7 @@ fn handle_zmq_msg( }; let input = input.read(); - let ZmqInputPortState::Connected(_, _, sender, _, _, _, ids) = &*input else { + let ZmqInputPortState::Connected(_, sender, _, _, ids) = &*input else { todo!(); }; @@ -790,7 +783,7 @@ fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => todo!(), - Open(sender, _) | Connected(_, _, _, _, sender, _, _) => { + Open(sender) | Connected(_, _, _, sender, _) => { sender.send(event).unwrap(); } }; From b5f13c342be93ac87fea596eba767ca02ec0812a Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 17:37:50 +0200 Subject: [PATCH 20/63] Remove needless `Arc>` --- lib/protoflow-zeromq/src/lib.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index eb94f503..c364756d 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -290,7 +290,6 @@ impl ZmqTransport { fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { let (to_worker_send, to_worker_recv) = sync_channel(1); let to_worker_send = Arc::new(to_worker_send); - let to_worker_recv = Arc::new(Mutex::new(to_worker_recv)); { let mut inputs = self.inputs.write(); @@ -303,7 +302,7 @@ impl ZmqTransport { let pub_queue = self.pub_queue.clone(); let inputs = self.inputs.clone(); tokio::task::spawn(async move { - let input = &to_worker_recv; + let input = to_worker_recv; let inputs = inputs; // Input worker loop: @@ -311,7 +310,7 @@ impl ZmqTransport { // 2. Receive messages and forward to channel // 3. Receive and handle disconnects loop { - let event: ZmqTransportEvent = input.lock().recv().expect("input worker recv"); + let event: ZmqTransportEvent = input.recv().expect("input worker recv"); use ZmqTransportEvent::*; match event { Connect(output_port_id, input_port_id) => { From 90713cb1e6695524ac8216ff6ad5ed0cd8ed3abf Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 17:39:23 +0200 Subject: [PATCH 21/63] Simplify sub socket worker --- lib/protoflow-zeromq/src/lib.rs | 41 ++++++++++++++------------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index c364756d..f4e5204d 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -291,6 +291,10 @@ impl ZmqTransport { let (to_worker_send, to_worker_recv) = sync_channel(1); let to_worker_send = Arc::new(to_worker_send); + let (req_send, req_recv) = sync_channel(1); + let req_send = Arc::new(req_send); + let _req_recv = Arc::new(Mutex::new(req_recv)); + { let mut inputs = self.inputs.write(); let state = ZmqInputPortState::Open(to_worker_send.clone()); @@ -305,6 +309,8 @@ impl ZmqTransport { let input = to_worker_recv; let inputs = inputs; + // TODO: loop for req_recv.recv(), i.e. requests from public methods + // Input worker loop: // 1. Receive connection attempts and respond // 2. Receive messages and forward to channel @@ -332,16 +338,12 @@ impl ZmqTransport { } let open = |input_state: &mut ZmqInputPortState| { - let (req_send, req_recv) = sync_channel(1); - let req_send = Arc::new(req_send); - let req_recv = Arc::new(Mutex::new(req_recv)); - let (msgs_send, msgs_recv) = sync_channel(1); let msgs_send = Arc::new(msgs_send); let msgs_recv = Arc::new(Mutex::new(msgs_recv)); *input_state = ZmqInputPortState::Connected( - req_send, + req_send.clone(), msgs_send, msgs_recv, to_worker_send.clone(), @@ -754,23 +756,18 @@ fn handle_zmq_msg( } }; } - Message(output_port_id, input_port_id, _, bytes) => { + Message(_, input_port_id, _, _) => { let inputs = inputs.read(); let Some(input) = inputs.get(&input_port_id) else { todo!(); }; let input = input.read(); - let ZmqInputPortState::Connected(_, sender, _, _, ids) = &*input else { + let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { todo!(); }; - // TODO: probably move to ports worker? no sense having here - if !ids.iter().any(|&id| id == output_port_id) { - todo!(); - } - - sender.send(ZmqInputPortEvent::Message(bytes)).unwrap(); + sender.send(event).unwrap(); } CloseOutput(_, input_port_id) => { let inputs = inputs.read(); @@ -807,27 +804,23 @@ fn handle_zmq_msg( todo!(); }; let output = output.read(); - let ZmqOutputPortState::Connected(_, sender, _) = &*output else { todo!(); }; sender.send(event).unwrap(); } CloseInput(input_port_id) => { - let outputs = outputs.read(); - - for (_, state) in outputs.iter() { + for (_, state) in outputs.read().iter() { let state = state.read(); - - let ZmqOutputPortState::Connected(_, sender, id) = &*state else { - todo!(); + let ZmqOutputPortState::Connected(_, ref sender, ref id) = *state else { + continue; }; - if *id != input_port_id { - todo!(); + continue; + } + if let Err(_e) = sender.send(event.clone()) { + continue; // TODO } - - sender.send(event.clone()).unwrap(); } } } From 2d146b641c12d3a9a69b1ab8b8c51294cd5126c6 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 18:03:35 +0200 Subject: [PATCH 22/63] Handle events from socket in input port worker --- lib/protoflow-zeromq/src/lib.rs | 283 +++++++++++++++++++------------- 1 file changed, 172 insertions(+), 111 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index f4e5204d..4374c8da 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -241,13 +241,17 @@ impl ZmqTransport { inputs, }; - transport.start_send_worker(psock, out_queue_recv); - transport.start_recv_worker(ssock, sub_queue_recv); + transport.start_pub_socket_worker(psock, out_queue_recv); + transport.start_sub_socket_worker(ssock, sub_queue_recv); transport } - fn start_send_worker(&self, psock: zeromq::PubSocket, queue: Receiver) { + fn start_pub_socket_worker( + &self, + psock: zeromq::PubSocket, + queue: Receiver, + ) { let tokio = self.tokio.clone(); let mut psock = psock; @@ -260,7 +264,7 @@ impl ZmqTransport { }); } - fn start_recv_worker( + fn start_sub_socket_worker( &self, ssock: zeromq::SubSocket, sub_queue: tokio::sync::mpsc::Receiver, @@ -279,7 +283,7 @@ impl ZmqTransport { use ZmqSubscriptionRequest::*; match req { Subscribe(topic) => ssock.subscribe(&topic).await.expect("zmq recv worker subscribe"), - Unsubscribe(topic) => ssock.unsubscribe(&topic).await.expect("zmq recv worker subscribe"), + Unsubscribe(topic) => ssock.unsubscribe(&topic).await.expect("zmq recv worker unsubscribe"), }; } }; @@ -293,7 +297,6 @@ impl ZmqTransport { let (req_send, req_recv) = sync_channel(1); let req_send = Arc::new(req_send); - let _req_recv = Arc::new(Mutex::new(req_recv)); { let mut inputs = self.inputs.write(); @@ -305,123 +308,179 @@ impl ZmqTransport { //let sub_queue = self.sub_queue.clone(); let pub_queue = self.pub_queue.clone(); let inputs = self.inputs.clone(); - tokio::task::spawn(async move { - let input = to_worker_recv; - let inputs = inputs; - - // TODO: loop for req_recv.recv(), i.e. requests from public methods - // Input worker loop: - // 1. Receive connection attempts and respond - // 2. Receive messages and forward to channel - // 3. Receive and handle disconnects - loop { - let event: ZmqTransportEvent = input.recv().expect("input worker recv"); - use ZmqTransportEvent::*; - match event { - Connect(output_port_id, input_port_id) => { - let inputs = inputs.read(); - let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); - }; - let mut input_state = input_state.upgradable_read(); - - use ZmqInputPortState::*; - match &*input_state { - Open(_) => (), - Connected(_, _, _, _, connected_ids) => { - if !connected_ids.iter().any(|&id| id == output_port_id) { - continue; - } + fn handle_socket_event( + event: ZmqTransportEvent, + inputs: &RwLock>>, + req_send: &Arc>)>>, + to_worker_send: &Arc>, + pub_queue: &SyncSender, + input_port_id: InputPortID, + ) { + use ZmqTransportEvent::*; + match event { + Connect(output_port_id, input_port_id) => { + let inputs = inputs.read(); + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let mut input_state = input_state.upgradable_read(); + + use ZmqInputPortState::*; + match &*input_state { + Open(_) => (), + Connected(_, _, _, _, connected_ids) => { + if !connected_ids.iter().any(|&id| id == output_port_id) { + return; } - Closed => continue, } - - let open = |input_state: &mut ZmqInputPortState| { - let (msgs_send, msgs_recv) = sync_channel(1); - let msgs_send = Arc::new(msgs_send); - let msgs_recv = Arc::new(Mutex::new(msgs_recv)); - - *input_state = ZmqInputPortState::Connected( - req_send.clone(), - msgs_send, - msgs_recv, - to_worker_send.clone(), - vec![output_port_id], - ); - }; - - pub_queue - .send(ZmqTransportEvent::AckConnection( - output_port_id, - input_port_id, - )) - .expect("input worker conn ack"); - input_state.with_upgraded(open); + Closed => return, } - Message(output_port_id, _, seq_id, bytes) => { - let inputs = inputs.read(); - let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); - }; - let input_state = input_state.read(); - use ZmqInputPortState::*; - match &*input_state { - Connected(_, sender, _, _, connected_ids) => { - if !connected_ids.iter().any(|id| *id == output_port_id) { - continue; - } + let open = |input_state: &mut ZmqInputPortState| { + let (msgs_send, msgs_recv) = sync_channel(1); + let msgs_send = Arc::new(msgs_send); + let msgs_recv = Arc::new(Mutex::new(msgs_recv)); + + *input_state = ZmqInputPortState::Connected( + req_send.clone(), + msgs_send, + msgs_recv, + to_worker_send.clone(), + vec![output_port_id], + ); + }; - sender - .send(ZmqInputPortEvent::Message(bytes)) - .expect("input worker message send"); - - pub_queue - .send(ZmqTransportEvent::AckMessage( - output_port_id, - input_port_id, - seq_id, - )) - .expect("input worker message ack"); + pub_queue + .send(ZmqTransportEvent::AckConnection( + output_port_id, + input_port_id, + )) + .expect("input worker send ack-conn event"); + input_state.with_upgraded(open); + } + Message(output_port_id, _, seq_id, bytes) => { + let inputs = inputs.read(); + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let input_state = input_state.read(); + + use ZmqInputPortState::*; + match &*input_state { + Connected(_, sender, _, _, connected_ids) => { + if !connected_ids.iter().any(|id| *id == output_port_id) { + return; } - Open(_) | Closed => todo!(), + sender + .send(ZmqInputPortEvent::Message(bytes)) + .expect("input worker send message"); + + pub_queue + .send(ZmqTransportEvent::AckMessage( + output_port_id, + input_port_id, + seq_id, + )) + .expect("input worker send message ack"); } + + Open(_) | Closed => todo!(), + } + } + CloseOutput(output_port_id, input_port_id) => { + let inputs = inputs.read(); + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let mut input_state = input_state.upgradable_read(); + + use ZmqInputPortState::*; + let Connected(_, _, _, _, ref connected_ids) = *input_state else { + return; + }; + + if !connected_ids.iter().any(|id| *id == output_port_id) { + return; } - CloseOutput(output_port_id, input_port_id) => { - let inputs = inputs.read(); - let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); - }; - let mut input_state = input_state.upgradable_read(); - use ZmqInputPortState::*; - let Connected(_, _, _, _, ref connected_ids) = *input_state else { - continue; - }; + // TODO: send unsubscription for relevant topics + //sub_queue + // .send(ZmqSubscriptionRequest::Unsubscribe("".to_string())) + // .await + // .expect("input worker closeoutput unsub"); - if !connected_ids.iter().any(|id| *id == output_port_id) { - continue; + input_state.with_upgraded(|state| match state { + Open(_) | Closed => (), + Connected(_, _, _, _, connected_ids) => { + connected_ids.retain(|&id| id != output_port_id) } + }) + } - // TODO: send unsubscription for relevant topics - //sub_queue - // .send(ZmqSubscriptionRequest::Unsubscribe("".to_string())) - // .await - // .expect("input worker closeoutput unsub"); + // ignore, ideally we never receive these here: + AckConnection(_, _) | AckMessage(_, _, _) | CloseInput(_) => (), + } + } - input_state.with_upgraded(|state| match state { - Open(_) | Closed => (), - Connected(_, _, _, _, connected_ids) => { - connected_ids.retain(|&id| id != output_port_id) - } - }) - } + fn handle_input_request( + request: ZmqInputPortRequest, + response_chan: SyncSender>, + inputs: &RwLock>>, + pub_queue: &SyncSender, + input_port_id: InputPortID, + ) { + use ZmqInputPortRequest::*; + match request { + Close => { + let inputs = inputs.read(); + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let mut input_state = input_state.upgradable_read(); + + use ZmqInputPortState::*; + let Connected(_, _, _, _, _) = *input_state else { + return; + }; + + pub_queue + .send(ZmqTransportEvent::CloseInput(input_port_id)) + .expect("input worker send close event"); + + input_state.with_upgraded(|state| *state = ZmqInputPortState::Closed); + + drop(input_state); + + response_chan + .send(Ok(())) + .expect("input worker respond close") + } + } + } - // ignore, ideally we never receive these here: - AckConnection(_, _) | AckMessage(_, _, _) | CloseInput(_) => continue, - }; + tokio::task::spawn(async move { + // Input worker loop: + // 1. Receive connection attempts and respond + // 2. Receive messages and forward to channel + // 3. Receive and handle disconnects + loop { + let event = to_worker_recv + .recv() + .expect("input worker recv socket event"); + handle_socket_event( + event, + &inputs, + &req_send, + &to_worker_send, + &pub_queue, + input_port_id, + ); + + let (request, response_chan) = + req_recv.recv().expect("input worker recv port request"); + handle_input_request(request, response_chan, &inputs, &pub_queue, input_port_id); } }); @@ -470,9 +529,11 @@ impl ZmqTransport { loop { pub_queue .send(ZmqTransportEvent::Connect(output_port_id, input_port_id)) - .expect("output worker send connect"); + .expect("output worker send connect event"); - let response = to_worker_recv.recv().expect("output worker recv conn ack"); + let response = to_worker_recv + .recv() + .expect("output worker recv ack-conn event"); use ZmqTransportEvent::*; match response { @@ -491,7 +552,7 @@ impl ZmqTransport { conn_confirm .send(Ok(())) - .expect("output worker send confirm conn"); + .expect("output worker respond conn"); break; } @@ -527,7 +588,7 @@ impl ZmqTransport { seq_id, bytes, )) - .expect("output worker send message"); + .expect("output worker send message event"); loop { let event = to_worker_recv.recv().expect("output worker event recv"); From a0617da6d81edc7e3b18474e4f2b36cf17e566bd Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 18:47:40 +0200 Subject: [PATCH 23/63] Begin transition from std::sync:mpsc to tokio::sync:mpsc --- lib/protoflow-zeromq/src/lib.rs | 220 +++++++++++++++++--------------- 1 file changed, 120 insertions(+), 100 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 4374c8da..39636fcc 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -15,11 +15,8 @@ use protoflow_core::{ use core::fmt::Error; use parking_lot::{Mutex, RwLock}; -use std::{ - format, - sync::mpsc::{sync_channel, Receiver, SyncSender}, - write, -}; +use std::{format, write}; +use tokio::sync::mpsc::{channel as sync_channel, Receiver, Sender}; use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqMessage}; const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; @@ -28,8 +25,8 @@ const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; pub struct ZmqTransport { tokio: tokio::runtime::Handle, - pub_queue: Arc>, - sub_queue: Arc>, + pub_queue: Arc>, + sub_queue: Arc>, outputs: Arc>>>, inputs: Arc>>>, @@ -38,14 +35,14 @@ pub struct ZmqTransport { #[derive(Debug, Clone)] enum ZmqOutputPortState { Open( - Arc>)>>, - Arc>, + Arc>)>>, + Arc>, ), Connected( // channel for public send, contained channel is for the ack back - Arc>)>>, + Arc>)>>, // internal channel for events - Arc>, + Arc>, // id of the connected input port InputPortID, ), @@ -73,16 +70,16 @@ impl ZmqOutputPortState { enum ZmqInputPortState { Open( // TODO: hide these - Arc>, + Arc>, ), Connected( // channel for requests from public close - Arc>)>>, + Arc>)>>, // channel for the public recv - Arc>, + Arc>, Arc>>, // internal channel for events - Arc>, + Arc>, // vec of the connected port ids Vec, ), @@ -164,14 +161,14 @@ impl TryFrom for ZmqTransportEvent { } } -/// ZmqOutputPortEvent represents events that we receive from the background worker of the port. -#[derive(Clone, Debug)] -enum ZmqOutputPortEvent { - Opened, - Connected(InputPortID), - Ack(SequenceID), - Closed, -} +///// ZmqOutputPortEvent represents events that we receive from the background worker of the port. +//#[derive(Clone, Debug)] +//enum ZmqOutputPortEvent { +// Opened, +// Connected(InputPortID), +// Ack(SequenceID), +// Closed, +//} /// ZmqInputPortEvent represents events that we receive from the background worker of the port. #[derive(Clone, Debug)] @@ -226,7 +223,7 @@ impl ZmqTransport { let outputs = Arc::new(RwLock::new(BTreeMap::default())); let inputs = Arc::new(RwLock::new(BTreeMap::default())); - let (out_queue, out_queue_recv) = sync_channel(1); + let (out_queue, pub_queue) = sync_channel(1); let out_queue = Arc::new(out_queue); @@ -241,7 +238,7 @@ impl ZmqTransport { inputs, }; - transport.start_pub_socket_worker(psock, out_queue_recv); + transport.start_pub_socket_worker(psock, pub_queue); transport.start_sub_socket_worker(ssock, sub_queue_recv); transport @@ -250,35 +247,33 @@ impl ZmqTransport { fn start_pub_socket_worker( &self, psock: zeromq::PubSocket, - queue: Receiver, + pub_queue: Receiver, ) { let tokio = self.tokio.clone(); let mut psock = psock; - + let mut pub_queue = pub_queue; tokio::task::spawn(async move { - queue - .into_iter() - .map(ZmqMessage::from) - .try_for_each(|msg| tokio.block_on(psock.send(msg))) - .expect("zmq send worker") + for event in pub_queue.recv().await { + tokio + .block_on(psock.send(event.into())) + .expect("zmq pub-socket worker") + } }); } fn start_sub_socket_worker( &self, ssock: zeromq::SubSocket, - sub_queue: tokio::sync::mpsc::Receiver, + sub_queue: Receiver, ) { - let mut ssock = ssock; - let mut sub_queue = sub_queue; - let outputs = self.outputs.clone(); let inputs = self.inputs.clone(); - - tokio::task::spawn(async move { + let mut ssock = ssock; + let mut sub_queue = sub_queue; + tokio::task::spawn_local(async move { loop { tokio::select! { - Ok(msg) = ssock.recv() => handle_zmq_msg(msg, &outputs, &inputs).unwrap(), + Ok(msg) = ssock.recv() => handle_zmq_msg(msg, &outputs, &inputs).await.unwrap(), Some(req) = sub_queue.recv() => { use ZmqSubscriptionRequest::*; match req { @@ -292,14 +287,17 @@ impl ZmqTransport { } fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { - let (to_worker_send, to_worker_recv) = sync_channel(1); + let (to_worker_send, mut to_worker_recv) = sync_channel(1); let to_worker_send = Arc::new(to_worker_send); - let (req_send, req_recv) = sync_channel(1); + let (req_send, mut req_recv) = sync_channel(1); let req_send = Arc::new(req_send); { let mut inputs = self.inputs.write(); + if inputs.contains_key(&input_port_id) { + return Ok(()); // TODO + } let state = ZmqInputPortState::Open(to_worker_send.clone()); let state = RwLock::new(state); inputs.insert(input_port_id, state); @@ -309,12 +307,12 @@ impl ZmqTransport { let pub_queue = self.pub_queue.clone(); let inputs = self.inputs.clone(); - fn handle_socket_event( + async fn handle_socket_event( event: ZmqTransportEvent, inputs: &RwLock>>, - req_send: &Arc>)>>, - to_worker_send: &Arc>, - pub_queue: &SyncSender, + req_send: &Arc>)>>, + to_worker_send: &Arc>, + pub_queue: &Sender, input_port_id: InputPortID, ) { use ZmqTransportEvent::*; @@ -356,6 +354,7 @@ impl ZmqTransport { output_port_id, input_port_id, )) + .await .expect("input worker send ack-conn event"); input_state.with_upgraded(open); } @@ -375,6 +374,7 @@ impl ZmqTransport { sender .send(ZmqInputPortEvent::Message(bytes)) + .await .expect("input worker send message"); pub_queue @@ -383,6 +383,7 @@ impl ZmqTransport { input_port_id, seq_id, )) + .await .expect("input worker send message ack"); } @@ -424,11 +425,11 @@ impl ZmqTransport { } } - fn handle_input_request( + async fn handle_input_request( request: ZmqInputPortRequest, - response_chan: SyncSender>, + response_chan: Sender>, inputs: &RwLock>>, - pub_queue: &SyncSender, + pub_queue: &Sender, input_port_id: InputPortID, ) { use ZmqInputPortRequest::*; @@ -447,6 +448,7 @@ impl ZmqTransport { pub_queue .send(ZmqTransportEvent::CloseInput(input_port_id)) + .await .expect("input worker send close event"); input_state.with_upgraded(|state| *state = ZmqInputPortState::Closed); @@ -455,12 +457,13 @@ impl ZmqTransport { response_chan .send(Ok(())) + .await .expect("input worker respond close") } } } - tokio::task::spawn(async move { + tokio::task::spawn_local(async move { // Input worker loop: // 1. Receive connection attempts and respond // 2. Receive messages and forward to channel @@ -468,6 +471,7 @@ impl ZmqTransport { loop { let event = to_worker_recv .recv() + .await .expect("input worker recv socket event"); handle_socket_event( event, @@ -476,11 +480,15 @@ impl ZmqTransport { &to_worker_send, &pub_queue, input_port_id, - ); + ) + .await; - let (request, response_chan) = - req_recv.recv().expect("input worker recv port request"); - handle_input_request(request, response_chan, &inputs, &pub_queue, input_port_id); + let (request, response_chan) = req_recv + .recv() + .await + .expect("input worker recv port request"); + handle_input_request(request, response_chan, &inputs, &pub_queue, input_port_id) + .await; } }); @@ -496,14 +504,17 @@ impl ZmqTransport { } fn start_output_worker(&self, output_port_id: OutputPortID) -> Result<(), PortError> { - let (conn_send, conn_recv) = sync_channel(1); + let (conn_send, mut conn_recv) = sync_channel(1); let conn_send = Arc::new(conn_send); - let (to_worker_send, to_worker_recv) = sync_channel(1); + let (to_worker_send, mut to_worker_recv) = sync_channel(1); let to_worker_send = Arc::new(to_worker_send); { let mut outputs = self.outputs.write(); + if outputs.contains_key(&output_port_id) { + return Ok(()); // TODO + } let state = ZmqOutputPortState::Open(conn_send, to_worker_send.clone()); let state = RwLock::new(state); outputs.insert(output_port_id, state); @@ -511,12 +522,12 @@ impl ZmqTransport { let outputs = self.outputs.clone(); let pub_queue = self.pub_queue.clone(); - tokio::task::spawn(async move { - let Ok((input_port_id, conn_confirm)) = conn_recv.recv() else { + tokio::task::spawn_local(async move { + let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { todo!(); }; - let (msg_req_send, msg_req_recv) = sync_channel(1); + let (msg_req_send, mut msg_req_recv) = sync_channel(1); let msg_req_send = Arc::new(msg_req_send); // Output worker loop: @@ -529,10 +540,12 @@ impl ZmqTransport { loop { pub_queue .send(ZmqTransportEvent::Connect(output_port_id, input_port_id)) + .await .expect("output worker send connect event"); let response = to_worker_recv .recv() + .await .expect("output worker recv ack-conn event"); use ZmqTransportEvent::*; @@ -552,6 +565,7 @@ impl ZmqTransport { conn_confirm .send(Ok(())) + .await .expect("output worker respond conn"); break; @@ -564,8 +578,10 @@ impl ZmqTransport { let mut seq_id = 1; 'outer: loop { - let (request, response_chan) = - msg_req_recv.recv().expect("output worker recv msg req"); + let (request, response_chan) = msg_req_recv + .recv() + .await + .expect("output worker recv msg req"); match request { ZmqOutputPortRequest::Close => { @@ -574,10 +590,12 @@ impl ZmqTransport { output_port_id, input_port_id, )) + .await .map_err(|e| PortError::Other(e.to_string())); response_chan .send(response) + .await .expect("output worker respond close"); } ZmqOutputPortRequest::Send(bytes) => { @@ -588,10 +606,14 @@ impl ZmqTransport { seq_id, bytes, )) + .await .expect("output worker send message event"); loop { - let event = to_worker_recv.recv().expect("output worker event recv"); + let event = to_worker_recv + .recv() + .await + .expect("output worker event recv"); use ZmqTransportEvent::*; match event { @@ -678,15 +700,15 @@ impl Transport for ZmqTransport { return Err(PortError::Disconnected); }; - let (close_send, close_recv) = sync_channel(1); + let (close_send, mut close_recv) = sync_channel(1); - sender - .send((ZmqInputPortRequest::Close, close_send)) + self.tokio + .block_on(sender.send((ZmqInputPortRequest::Close, close_send))) .map_err(|e| PortError::Other(e.to_string()))?; - close_recv - .recv() - .map_err(|_| PortError::Disconnected)? + self.tokio + .block_on(close_recv.recv()) + .ok_or(PortError::Disconnected)? .map(|_| true) } @@ -703,15 +725,15 @@ impl Transport for ZmqTransport { return Err(PortError::Disconnected); }; - let (close_send, close_recv) = sync_channel(1); + let (close_send, mut close_recv) = sync_channel(1); - sender - .send((ZmqOutputPortRequest::Close, close_send)) + self.tokio + .block_on(sender.send((ZmqOutputPortRequest::Close, close_send))) .map_err(|e| PortError::Other(e.to_string()))?; - close_recv - .recv() - .map_err(|e| PortError::Other(e.to_string()))? + self.tokio + .block_on(close_recv.recv()) + .ok_or(PortError::Disconnected)? .map(|_| true) } @@ -726,15 +748,15 @@ impl Transport for ZmqTransport { return Err(PortError::Invalid(source.into())); }; - let (confirm_send, confirm_recv) = sync_channel(1); + let (confirm_send, mut confirm_recv) = sync_channel(1); - sender - .send((target, confirm_send)) + self.tokio + .block_on(sender.send((target, confirm_send))) .map_err(|e| PortError::Other(e.to_string()))?; - confirm_recv - .recv() - .map_err(|e| PortError::Other(e.to_string()))? + self.tokio + .block_on(confirm_recv.recv()) + .ok_or(PortError::Disconnected)? .map(|_| true) } @@ -749,15 +771,15 @@ impl Transport for ZmqTransport { return Err(PortError::Disconnected); }; - let (ack_send, ack_recv) = sync_channel(1); + let (ack_send, mut ack_recv) = sync_channel(1); - sender - .send((ZmqOutputPortRequest::Send(message), ack_send)) + self.tokio + .block_on(sender.send((ZmqOutputPortRequest::Send(message), ack_send))) .map_err(|e| PortError::Other(e.to_string()))?; - ack_recv - .recv() - .map_err(|e| PortError::Other(e.to_string()))? + self.tokio + .block_on(ack_recv.recv()) + .ok_or(PortError::Disconnected)? } fn recv(&self, input: InputPortID) -> PortResult> { @@ -770,17 +792,17 @@ impl Transport for ZmqTransport { let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input else { return Err(PortError::Disconnected); }; - let receiver = receiver.lock(); + let mut receiver = receiver.lock(); loop { use ZmqInputPortEvent::*; - match receiver.recv() { - Ok(Closed) => break Ok(None), // EOS - Ok(Message(bytes)) => break Ok(Some(bytes)), - Err(e) => break Err(PortError::Other(e.to_string())), + match self.tokio.block_on(receiver.recv()) { + Some(Closed) => break Ok(None), // EOS + Some(Message(bytes)) => break Ok(Some(bytes)), + None => break Err(PortError::Disconnected), // ignore - Ok(Opened) | Ok(Connected(_)) => continue, + Some(Opened) | Some(Connected(_)) => continue, } } } @@ -790,7 +812,7 @@ impl Transport for ZmqTransport { } } -fn handle_zmq_msg( +async fn handle_zmq_msg( msg: ZmqMessage, outputs: &RwLock>>, inputs: &RwLock>>, @@ -812,9 +834,7 @@ fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => todo!(), - Open(sender) | Connected(_, _, _, sender, _) => { - sender.send(event).unwrap(); - } + Open(sender) | Connected(_, _, _, sender, _) => sender.send(event).await.unwrap(), }; } Message(_, input_port_id, _, _) => { @@ -828,7 +848,7 @@ fn handle_zmq_msg( todo!(); }; - sender.send(event).unwrap(); + sender.send(event).await.unwrap(); } CloseOutput(_, input_port_id) => { let inputs = inputs.read(); @@ -841,7 +861,7 @@ fn handle_zmq_msg( match &*input { Closed => todo!(), Open(sender) | Connected(_, _, _, sender, _) => { - sender.send(event).unwrap(); + sender.send(event).await.unwrap(); } }; } @@ -857,7 +877,7 @@ fn handle_zmq_msg( let ZmqOutputPortState::Open(_, sender) = &*output else { todo!(); }; - sender.send(event).unwrap(); + sender.send(event).await.unwrap(); } AckMessage(output_port_id, _, _) => { let outputs = outputs.read(); @@ -868,7 +888,7 @@ fn handle_zmq_msg( let ZmqOutputPortState::Connected(_, sender, _) = &*output else { todo!(); }; - sender.send(event).unwrap(); + sender.send(event).await.unwrap(); } CloseInput(input_port_id) => { for (_, state) in outputs.read().iter() { @@ -879,7 +899,7 @@ fn handle_zmq_msg( if *id != input_port_id { continue; } - if let Err(_e) = sender.send(event.clone()) { + if let Err(_e) = sender.send(event.clone()).await { continue; // TODO } } From 9045d93c135650cfdac01d2d249f5067604eb9ee Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 19:07:31 +0200 Subject: [PATCH 24/63] Send InputPortEvent::Closed --- lib/protoflow-zeromq/src/lib.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 39636fcc..b8275593 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -253,7 +253,7 @@ impl ZmqTransport { let mut psock = psock; let mut pub_queue = pub_queue; tokio::task::spawn(async move { - for event in pub_queue.recv().await { + while let Some(event) = pub_queue.recv().await { tokio .block_on(psock.send(event.into())) .expect("zmq pub-socket worker") @@ -325,7 +325,7 @@ impl ZmqTransport { let mut input_state = input_state.upgradable_read(); use ZmqInputPortState::*; - match &*input_state { + let port_events = match &*input_state { Open(_) => (), Connected(_, _, _, _, connected_ids) => { if !connected_ids.iter().any(|&id| id == output_port_id) { @@ -333,14 +333,14 @@ impl ZmqTransport { } } Closed => return, - } + }; - let open = |input_state: &mut ZmqInputPortState| { + let make_connected = |input_state: &mut ZmqInputPortState| { let (msgs_send, msgs_recv) = sync_channel(1); let msgs_send = Arc::new(msgs_send); let msgs_recv = Arc::new(Mutex::new(msgs_recv)); - *input_state = ZmqInputPortState::Connected( + *input_state = Connected( req_send.clone(), msgs_send, msgs_recv, @@ -356,7 +356,8 @@ impl ZmqTransport { )) .await .expect("input worker send ack-conn event"); - input_state.with_upgraded(open); + + input_state.with_upgraded(make_connected); } Message(output_port_id, _, seq_id, bytes) => { let inputs = inputs.read(); @@ -442,7 +443,7 @@ impl ZmqTransport { let mut input_state = input_state.upgradable_read(); use ZmqInputPortState::*; - let Connected(_, _, _, _, _) = *input_state else { + let Connected(_, ref port_events, _, _, _) = *input_state else { return; }; @@ -451,9 +452,9 @@ impl ZmqTransport { .await .expect("input worker send close event"); - input_state.with_upgraded(|state| *state = ZmqInputPortState::Closed); + port_events.send(ZmqInputPortEvent::Closed).await.unwrap(); - drop(input_state); + input_state.with_upgraded(|state| *state = ZmqInputPortState::Closed); response_chan .send(Ok(())) From c927e03d364791ffc3281fceb2a8dcc6e3164dd7 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 19:11:57 +0200 Subject: [PATCH 25/63] Remove `Arc`s --- lib/protoflow-zeromq/src/lib.rs | 39 +++++++++++++-------------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index b8275593..c9f3a27e 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -25,8 +25,8 @@ const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; pub struct ZmqTransport { tokio: tokio::runtime::Handle, - pub_queue: Arc>, - sub_queue: Arc>, + pub_queue: Sender, + sub_queue: Sender, outputs: Arc>>>, inputs: Arc>>>, @@ -35,14 +35,14 @@ pub struct ZmqTransport { #[derive(Debug, Clone)] enum ZmqOutputPortState { Open( - Arc>)>>, - Arc>, + Sender<(InputPortID, Sender>)>, + Sender, ), Connected( // channel for public send, contained channel is for the ack back - Arc>)>>, + Sender<(ZmqOutputPortRequest, Sender>)>, // internal channel for events - Arc>, + Sender, // id of the connected input port InputPortID, ), @@ -70,16 +70,16 @@ impl ZmqOutputPortState { enum ZmqInputPortState { Open( // TODO: hide these - Arc>, + Sender, ), Connected( // channel for requests from public close - Arc>)>>, + Sender<(ZmqInputPortRequest, Sender>)>, // channel for the public recv - Arc>, + Sender, Arc>>, // internal channel for events - Arc>, + Sender, // vec of the connected port ids Vec, ), @@ -223,22 +223,19 @@ impl ZmqTransport { let outputs = Arc::new(RwLock::new(BTreeMap::default())); let inputs = Arc::new(RwLock::new(BTreeMap::default())); - let (out_queue, pub_queue) = sync_channel(1); - - let out_queue = Arc::new(out_queue); + let (pub_queue, pub_queue_recv) = sync_channel(1); let (sub_queue, sub_queue_recv) = tokio::sync::mpsc::channel(1); - let sub_queue = Arc::new(sub_queue); let transport = Self { - pub_queue: out_queue, + pub_queue, sub_queue, tokio, outputs, inputs, }; - transport.start_pub_socket_worker(psock, pub_queue); + transport.start_pub_socket_worker(psock, pub_queue_recv); transport.start_sub_socket_worker(ssock, sub_queue_recv); transport @@ -288,10 +285,8 @@ impl ZmqTransport { fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { let (to_worker_send, mut to_worker_recv) = sync_channel(1); - let to_worker_send = Arc::new(to_worker_send); let (req_send, mut req_recv) = sync_channel(1); - let req_send = Arc::new(req_send); { let mut inputs = self.inputs.write(); @@ -310,8 +305,8 @@ impl ZmqTransport { async fn handle_socket_event( event: ZmqTransportEvent, inputs: &RwLock>>, - req_send: &Arc>)>>, - to_worker_send: &Arc>, + req_send: &Sender<(ZmqInputPortRequest, Sender>)>, + to_worker_send: &Sender, pub_queue: &Sender, input_port_id: InputPortID, ) { @@ -337,7 +332,6 @@ impl ZmqTransport { let make_connected = |input_state: &mut ZmqInputPortState| { let (msgs_send, msgs_recv) = sync_channel(1); - let msgs_send = Arc::new(msgs_send); let msgs_recv = Arc::new(Mutex::new(msgs_recv)); *input_state = Connected( @@ -506,10 +500,8 @@ impl ZmqTransport { fn start_output_worker(&self, output_port_id: OutputPortID) -> Result<(), PortError> { let (conn_send, mut conn_recv) = sync_channel(1); - let conn_send = Arc::new(conn_send); let (to_worker_send, mut to_worker_recv) = sync_channel(1); - let to_worker_send = Arc::new(to_worker_send); { let mut outputs = self.outputs.write(); @@ -529,7 +521,6 @@ impl ZmqTransport { }; let (msg_req_send, mut msg_req_recv) = sync_channel(1); - let msg_req_send = Arc::new(msg_req_send); // Output worker loop: // 1. Send connection attempt From 9bd59b21a8d54bd6977b6d5174c178b999200316 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 19:20:14 +0200 Subject: [PATCH 26/63] Report port closure from output worker --- lib/protoflow-zeromq/src/lib.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index c9f3a27e..c4139896 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -566,8 +566,6 @@ impl ZmqTransport { } } - // TODO: loop for to_worker_recv.recv(), i.e. events from socket - let mut seq_id = 1; 'outer: loop { let (request, response_chan) = msg_req_recv @@ -626,6 +624,11 @@ impl ZmqTransport { )); *output_state = ZmqOutputPortState::Closed; + response_chan + .send(Err(PortError::Disconnected)) + .await + .expect("output worker respond msg"); + break 'outer; } From ceddf937d2e1dcb82ac6934c862328dced34081f Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 19:25:36 +0200 Subject: [PATCH 27/63] Remove unused port event enum fields --- lib/protoflow-zeromq/src/lib.rs | 37 +++++++++++---------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index c4139896..e46a61f1 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -122,7 +122,7 @@ impl ZmqTransportEvent { match self { Connect(o, i) => write!(f, "{}:conn:{}", i, o), AckConnection(o, i) => write!(f, "{}:ackConn:{}", i, o), - Message(o, i, seq, _payload) => write!(f, "{}:msg:{}:{}", i, o, seq), + Message(o, i, seq, _) => write!(f, "{}:msg:{}:{}", i, o, seq), AckMessage(o, i, seq) => write!(f, "{}:ackMsg:{}:{}", i, o, seq), CloseOutput(o, i) => write!(f, "{}:closeOut:{}", i, o), CloseInput(i) => write!(f, "{}:closeIn", i), @@ -161,20 +161,9 @@ impl TryFrom for ZmqTransportEvent { } } -///// ZmqOutputPortEvent represents events that we receive from the background worker of the port. -//#[derive(Clone, Debug)] -//enum ZmqOutputPortEvent { -// Opened, -// Connected(InputPortID), -// Ack(SequenceID), -// Closed, -//} - /// ZmqInputPortEvent represents events that we receive from the background worker of the port. #[derive(Clone, Debug)] enum ZmqInputPortEvent { - Opened, - Connected(OutputPortID), Message(Bytes), Closed, } @@ -320,7 +309,7 @@ impl ZmqTransport { let mut input_state = input_state.upgradable_read(); use ZmqInputPortState::*; - let port_events = match &*input_state { + match &*input_state { Open(_) => (), Connected(_, _, _, _, connected_ids) => { if !connected_ids.iter().any(|&id| id == output_port_id) { @@ -412,7 +401,7 @@ impl ZmqTransport { Connected(_, _, _, _, connected_ids) => { connected_ids.retain(|&id| id != output_port_id) } - }) + }); } // ignore, ideally we never receive these here: @@ -446,7 +435,10 @@ impl ZmqTransport { .await .expect("input worker send close event"); - port_events.send(ZmqInputPortEvent::Closed).await.unwrap(); + port_events + .send(ZmqInputPortEvent::Closed) + .await + .expect("input worker send port closed"); input_state.with_upgraded(|state| *state = ZmqInputPortState::Closed); @@ -789,16 +781,11 @@ impl Transport for ZmqTransport { }; let mut receiver = receiver.lock(); - loop { - use ZmqInputPortEvent::*; - match self.tokio.block_on(receiver.recv()) { - Some(Closed) => break Ok(None), // EOS - Some(Message(bytes)) => break Ok(Some(bytes)), - None => break Err(PortError::Disconnected), - - // ignore - Some(Opened) | Some(Connected(_)) => continue, - } + use ZmqInputPortEvent::*; + match self.tokio.block_on(receiver.recv()) { + Some(Closed) => Ok(None), // EOS + Some(Message(bytes)) => Ok(Some(bytes)), + None => Err(PortError::Disconnected), } } From 0e6a50b42ba6c121df083e52facbba775695ba1b Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 19:45:34 +0200 Subject: [PATCH 28/63] Refactor input port worker --- lib/protoflow-zeromq/src/lib.rs | 57 +++++++++++++++------------------ 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index e46a61f1..a1523e8c 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -319,17 +319,22 @@ impl ZmqTransport { Closed => return, }; - let make_connected = |input_state: &mut ZmqInputPortState| { - let (msgs_send, msgs_recv) = sync_channel(1); - let msgs_recv = Arc::new(Mutex::new(msgs_recv)); - - *input_state = Connected( - req_send.clone(), - msgs_send, - msgs_recv, - to_worker_send.clone(), - vec![output_port_id], - ); + let add_connection = |input_state: &mut ZmqInputPortState| match input_state { + Open(_) => { + let (msgs_send, msgs_recv) = sync_channel(1); + let msgs_recv = Arc::new(Mutex::new(msgs_recv)); + *input_state = Connected( + req_send.clone(), + msgs_send, + msgs_recv, + to_worker_send.clone(), + vec![output_port_id], + ); + } + Connected(_, _, _, _, ids) => { + ids.push(output_port_id); + } + Closed => unreachable!(), }; pub_queue @@ -340,7 +345,7 @@ impl ZmqTransport { .await .expect("input worker send ack-conn event"); - input_state.with_upgraded(make_connected); + input_state.with_upgraded(add_connection); } Message(output_port_id, _, seq_id, bytes) => { let inputs = inputs.read(); @@ -456,26 +461,14 @@ impl ZmqTransport { // 2. Receive messages and forward to channel // 3. Receive and handle disconnects loop { - let event = to_worker_recv - .recv() - .await - .expect("input worker recv socket event"); - handle_socket_event( - event, - &inputs, - &req_send, - &to_worker_send, - &pub_queue, - input_port_id, - ) - .await; - - let (request, response_chan) = req_recv - .recv() - .await - .expect("input worker recv port request"); - handle_input_request(request, response_chan, &inputs, &pub_queue, input_port_id) - .await; + tokio::select! { + Some(event) = to_worker_recv.recv() => { + handle_socket_event(event, &inputs, &req_send, &to_worker_send, &pub_queue, input_port_id).await; + } + Some((request, response_chan)) = req_recv.recv() => { + handle_input_request(request, response_chan, &inputs, &pub_queue, input_port_id).await; + } + }; } }); From 561a8e163be85581cb37ab8aba59fa878484c51b Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 19:54:17 +0200 Subject: [PATCH 29/63] Send sub and unsub topics --- lib/protoflow-zeromq/src/lib.rs | 93 ++++++++++++++++++++++++++++----- 1 file changed, 81 insertions(+), 12 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index a1523e8c..fed04ab4 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -68,10 +68,7 @@ impl ZmqOutputPortState { #[derive(Debug, Clone)] enum ZmqInputPortState { - Open( - // TODO: hide these - Sender, - ), + Open(Sender), Connected( // channel for requests from public close Sender<(ZmqInputPortRequest, Sender>)>, @@ -130,6 +127,22 @@ impl ZmqTransportEvent { } } +fn input_topics(id: InputPortID) -> Vec { + vec![ + format!("{}:conn", id), + format!("{}:msg", id), + format!("{}:closeOut", id), + ] +} + +fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { + vec![ + format!("{}:ackConn:{}", target, source), + format!("{}:ackMsg:{}:", target, source), + format!("{}:closeIn", target), + ] +} + impl From for ZmqMessage { fn from(value: ZmqTransportEvent) -> Self { let mut topic = Vec::new(); @@ -287,7 +300,22 @@ impl ZmqTransport { inputs.insert(input_port_id, state); } - //let sub_queue = self.sub_queue.clone(); + { + let mut handles = Vec::new(); + for topic in input_topics(input_port_id).into_iter() { + let handle = self + .sub_queue + .send(ZmqSubscriptionRequest::Subscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + self.tokio + .block_on(handle) + .expect("input worker send sub req"); + } + } + + let sub_queue = self.sub_queue.clone(); let pub_queue = self.pub_queue.clone(); let inputs = self.inputs.clone(); @@ -297,6 +325,7 @@ impl ZmqTransport { req_send: &Sender<(ZmqInputPortRequest, Sender>)>, to_worker_send: &Sender, pub_queue: &Sender, + sub_queue: &Sender, input_port_id: InputPortID, ) { use ZmqTransportEvent::*; @@ -395,11 +424,12 @@ impl ZmqTransport { return; } - // TODO: send unsubscription for relevant topics - //sub_queue - // .send(ZmqSubscriptionRequest::Unsubscribe("".to_string())) - // .await - // .expect("input worker closeoutput unsub"); + for topic in input_topics(input_port_id).into_iter() { + sub_queue + .send(ZmqSubscriptionRequest::Unsubscribe(topic)) + .await + .expect("input worker send unsub req"); + } input_state.with_upgraded(|state| match state { Open(_) | Closed => (), @@ -463,7 +493,7 @@ impl ZmqTransport { loop { tokio::select! { Some(event) = to_worker_recv.recv() => { - handle_socket_event(event, &inputs, &req_send, &to_worker_send, &pub_queue, input_port_id).await; + handle_socket_event(event, &inputs, &req_send, &to_worker_send, &pub_queue, &sub_queue, input_port_id).await; } Some((request, response_chan)) = req_recv.recv() => { handle_input_request(request, response_chan, &inputs, &pub_queue, input_port_id).await; @@ -498,13 +528,26 @@ impl ZmqTransport { outputs.insert(output_port_id, state); } - let outputs = self.outputs.clone(); + let sub_queue = self.sub_queue.clone(); let pub_queue = self.pub_queue.clone(); + let outputs = self.outputs.clone(); + tokio::task::spawn_local(async move { let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { todo!(); }; + { + let mut handles = Vec::new(); + for topic in output_topics(output_port_id, input_port_id).into_iter() { + let handle = sub_queue.send(ZmqSubscriptionRequest::Subscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + handle.await.expect("output worker send sub req"); + } + } + let (msg_req_send, mut msg_req_recv) = sync_channel(1); // Output worker loop: @@ -568,6 +611,18 @@ impl ZmqTransport { .await .map_err(|e| PortError::Other(e.to_string())); + { + let mut handles = Vec::new(); + for topic in output_topics(output_port_id, input_port_id).into_iter() { + let handle = + sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + handle.await.expect("output worker send unsub req"); + } + } + response_chan .send(response) .await @@ -609,6 +664,20 @@ impl ZmqTransport { )); *output_state = ZmqOutputPortState::Closed; + { + let mut handles = Vec::new(); + for topic in + output_topics(output_port_id, input_port_id).into_iter() + { + let handle = sub_queue + .send(ZmqSubscriptionRequest::Unsubscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + handle.await.expect("output worker send unsub req"); + } + } + response_chan .send(Err(PortError::Disconnected)) .await From d2a930cc8909e4e710de2966186d348952df01a9 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 20:16:49 +0200 Subject: [PATCH 30/63] Simplify input worker's inner fn signature --- lib/protoflow-zeromq/src/lib.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index fed04ab4..d703a0ee 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -149,14 +149,14 @@ impl From for ZmqMessage { value.write_topic(&mut topic).unwrap(); // first frame of the message is the topic - let mut msg = ZmqMessage::from(topic.clone()); + let mut msg = ZmqMessage::from(topic); // second frame of the message is the payload use ZmqTransportEvent::*; match value { Connect(_, _) => todo!(), AckConnection(_, _) => todo!(), - Message(_, _, _, bytes) => msg.push_back(bytes), + Message(_, _, _, bytes) => todo!(), AckMessage(_, _, _) => todo!(), CloseOutput(_, _) => todo!(), CloseInput(_) => todo!(), @@ -323,7 +323,6 @@ impl ZmqTransport { event: ZmqTransportEvent, inputs: &RwLock>>, req_send: &Sender<(ZmqInputPortRequest, Sender>)>, - to_worker_send: &Sender, pub_queue: &Sender, sub_queue: &Sender, input_port_id: InputPortID, @@ -349,7 +348,7 @@ impl ZmqTransport { }; let add_connection = |input_state: &mut ZmqInputPortState| match input_state { - Open(_) => { + Open(to_worker_send) => { let (msgs_send, msgs_recv) = sync_channel(1); let msgs_recv = Arc::new(Mutex::new(msgs_recv)); *input_state = Connected( @@ -493,7 +492,7 @@ impl ZmqTransport { loop { tokio::select! { Some(event) = to_worker_recv.recv() => { - handle_socket_event(event, &inputs, &req_send, &to_worker_send, &pub_queue, &sub_queue, input_port_id).await; + handle_socket_event(event, &inputs, &req_send, &pub_queue, &sub_queue, input_port_id).await; } Some((request, response_chan)) = req_recv.recv() => { handle_input_request(request, response_chan, &inputs, &pub_queue, input_port_id).await; From 3d8d7147f3b953ca159fcbf76a3f5a9e7112ad3d Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Fri, 29 Nov 2024 22:56:32 +0200 Subject: [PATCH 31/63] Add serialization layer --- lib/protoflow-zeromq/Cargo.toml | 3 + lib/protoflow-zeromq/build.rs | 6 + lib/protoflow-zeromq/src/lib.rs | 118 ++++++++++++++++-- lib/protoflow-zeromq/src/protoflow_zmq.rs | 70 +++++++++++ .../src/transport_event.proto | 46 +++++++ 5 files changed, 233 insertions(+), 10 deletions(-) create mode 100644 lib/protoflow-zeromq/build.rs create mode 100644 lib/protoflow-zeromq/src/protoflow_zmq.rs create mode 100644 lib/protoflow-zeromq/src/transport_event.proto diff --git a/lib/protoflow-zeromq/Cargo.toml b/lib/protoflow-zeromq/Cargo.toml index e28a73e7..8348cf6a 100644 --- a/lib/protoflow-zeromq/Cargo.toml +++ b/lib/protoflow-zeromq/Cargo.toml @@ -23,6 +23,7 @@ unstable = ["protoflow-core/unstable"] [build-dependencies] cfg_aliases.workspace = true +prost-build = "0.13.2" [dependencies] protoflow-core.workspace = true @@ -33,6 +34,8 @@ zeromq = { version = "0.4.1", default-features = false, features = [ ] } tokio = { version = "1.40.0", default-features = false } parking_lot = "0.12" +prost = "0.13.2" +prost-types = "0.13.2" [dev-dependencies] futures-util = "0.3.31" diff --git a/lib/protoflow-zeromq/build.rs b/lib/protoflow-zeromq/build.rs new file mode 100644 index 00000000..8449adc0 --- /dev/null +++ b/lib/protoflow-zeromq/build.rs @@ -0,0 +1,6 @@ +use std::io::Result; +fn main() -> Result<()> { + prost_build::Config::default() + .out_dir("src/") + .compile_protos(&["src/transport_event.proto"], &["src/"]) +} diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index d703a0ee..50c54ab7 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -6,16 +6,19 @@ #[doc(hidden)] pub use protoflow_core::prelude; +mod protoflow_zmq; + extern crate std; use protoflow_core::{ prelude::{vec, Arc, BTreeMap, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; +use protoflow_zmq::AckMessage; use core::fmt::Error; use parking_lot::{Mutex, RwLock}; -use std::{format, write}; +use std::{format, io::Read, write}; use tokio::sync::mpsc::{channel as sync_channel, Receiver, Sender}; use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqMessage}; @@ -151,17 +154,52 @@ impl From for ZmqMessage { // first frame of the message is the topic let mut msg = ZmqMessage::from(topic); + fn map_id(id: T) -> i64 + where + isize: From, + { + isize::from(id) as i64 + } + // second frame of the message is the payload + use prost::Message; + use protoflow_zmq::{event::Payload, Event}; use ZmqTransportEvent::*; - match value { - Connect(_, _) => todo!(), - AckConnection(_, _) => todo!(), - Message(_, _, _, bytes) => todo!(), - AckMessage(_, _, _) => todo!(), - CloseOutput(_, _) => todo!(), - CloseInput(_) => todo!(), + let payload = match value { + Connect(output, input) => Payload::Connect(protoflow_zmq::Connect { + output: map_id(output), + input: map_id(input), + }), + AckConnection(output, input) => Payload::AckConnection(protoflow_zmq::AckConnection { + output: map_id(output), + input: map_id(input), + }), + Message(output, input, sequence, message) => Payload::Message(protoflow_zmq::Message { + output: map_id(output), + input: map_id(input), + sequence, + message: message.to_vec(), + }), + AckMessage(output, input, sequence) => Payload::AckMessage(protoflow_zmq::AckMessage { + output: map_id(output), + input: map_id(input), + sequence, + }), + CloseOutput(output, input) => Payload::CloseOutput(protoflow_zmq::CloseOutput { + output: map_id(output), + input: map_id(input), + }), + CloseInput(input) => Payload::CloseInput(protoflow_zmq::CloseInput { + input: map_id(input), + }), }; + let bytes = Event { + payload: Some(payload), + } + .encode_to_vec(); + msg.push_back(bytes.into()); + msg } } @@ -169,8 +207,68 @@ impl From for ZmqMessage { impl TryFrom for ZmqTransportEvent { type Error = protoflow_core::DecodeError; - fn try_from(_value: ZmqMessage) -> Result { - todo!() + fn try_from(value: ZmqMessage) -> Result { + use prost::Message; + use protoflow_core::DecodeError; + use protoflow_zmq::{event::Payload, Event}; + + fn map_id(id: i64) -> Result + where + T: TryFrom, + std::borrow::Cow<'static, str>: From<>::Error>, + { + (id as isize).try_into().map_err(DecodeError::new) + } + + value + .get(1) + .ok_or_else(|| { + protoflow_core::DecodeError::new( + "message from socket contains less than two frames", + ) + }) + .and_then(|bytes| { + let event = Event::decode(bytes.as_ref())?; + + use ZmqTransportEvent::*; + Ok(match event.payload { + None => todo!(), + Some(Payload::Connect(protoflow_zmq::Connect { output, input })) => { + Connect(map_id(output)?, map_id(input)?) + } + + Some(Payload::AckConnection(protoflow_zmq::AckConnection { + output, + input, + })) => AckConnection(map_id(output)?, map_id(input)?), + + Some(Payload::Message(protoflow_zmq::Message { + output, + input, + sequence, + message, + })) => Message( + map_id(output)?, + map_id(input)?, + sequence, + Bytes::from(message), + ), + + Some(Payload::AckMessage(protoflow_zmq::AckMessage { + output, + input, + sequence, + })) => AckMessage(map_id(output)?, map_id(input)?, sequence), + + Some(Payload::CloseOutput(protoflow_zmq::CloseOutput { output, input })) => { + CloseOutput(map_id(output)?, map_id(input)?) + } + + Some(Payload::CloseInput(protoflow_zmq::CloseInput { input })) => { + CloseInput(map_id(input)?) + } + }) + }) } } diff --git a/lib/protoflow-zeromq/src/protoflow_zmq.rs b/lib/protoflow-zeromq/src/protoflow_zmq.rs new file mode 100644 index 00000000..1afabab7 --- /dev/null +++ b/lib/protoflow-zeromq/src/protoflow_zmq.rs @@ -0,0 +1,70 @@ +// This file is @generated by prost-build. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct Connect { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AckConnection { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Message { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, + #[prost(uint64, tag = "3")] + pub sequence: u64, + #[prost(bytes = "vec", tag = "4")] + pub message: ::prost::alloc::vec::Vec, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AckMessage { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, + #[prost(uint64, tag = "3")] + pub sequence: u64, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct CloseOutput { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct CloseInput { + #[prost(int64, tag = "1")] + pub input: i64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Event { + #[prost(oneof = "event::Payload", tags = "1, 2, 3, 4, 5, 6")] + pub payload: ::core::option::Option, +} +/// Nested message and enum types in `Event`. +pub mod event { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Payload { + #[prost(message, tag = "1")] + Connect(super::Connect), + #[prost(message, tag = "2")] + AckConnection(super::AckConnection), + #[prost(message, tag = "3")] + Message(super::Message), + #[prost(message, tag = "4")] + AckMessage(super::AckMessage), + #[prost(message, tag = "5")] + CloseOutput(super::CloseOutput), + #[prost(message, tag = "6")] + CloseInput(super::CloseInput), + } +} diff --git a/lib/protoflow-zeromq/src/transport_event.proto b/lib/protoflow-zeromq/src/transport_event.proto new file mode 100644 index 00000000..b88dabbe --- /dev/null +++ b/lib/protoflow-zeromq/src/transport_event.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package protoflow_zmq; + +message Connect { + int64 output = 1; + int64 input = 2; +} + +message AckConnection { + int64 output = 1; + int64 input = 2; +} + +message Message { + int64 output = 1; + int64 input = 2; + uint64 sequence = 3; + bytes message = 4; +} + +message AckMessage { + int64 output = 1; + int64 input = 2; + uint64 sequence = 3; +} + +message CloseOutput { + int64 output = 1; + int64 input = 2; +} + +message CloseInput { + int64 input = 1; +} + +message Event { + oneof payload { + Connect connect = 1; + AckConnection ack_connection = 2; + Message message = 3; + AckMessage ack_message = 4; + CloseOutput close_output = 5; + CloseInput close_input = 6; + } +} From 3bfdafdef7b042d7868f984e7053870ff0f0e36a Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 06:34:16 +0200 Subject: [PATCH 32/63] Change to tokio's mutices --- lib/protoflow-zeromq/src/lib.rs | 284 ++++++++++++++++++-------------- 1 file changed, 156 insertions(+), 128 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 50c54ab7..a419f3d4 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -14,12 +14,13 @@ use protoflow_core::{ prelude::{vec, Arc, BTreeMap, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; -use protoflow_zmq::AckMessage; use core::fmt::Error; -use parking_lot::{Mutex, RwLock}; -use std::{format, io::Read, write}; -use tokio::sync::mpsc::{channel as sync_channel, Receiver, Sender}; +use std::{format, write}; +use tokio::sync::{ + mpsc::{channel as sync_channel, Receiver, Sender}, + Mutex, RwLock, +}; use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqMessage}; const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; @@ -367,7 +368,7 @@ impl ZmqTransport { let inputs = self.inputs.clone(); let mut ssock = ssock; let mut sub_queue = sub_queue; - tokio::task::spawn_local(async move { + tokio::task::spawn(async move { loop { tokio::select! { Ok(msg) = ssock.recv() => handle_zmq_msg(msg, &outputs, &inputs).await.unwrap(), @@ -389,7 +390,7 @@ impl ZmqTransport { let (req_send, mut req_recv) = sync_channel(1); { - let mut inputs = self.inputs.write(); + let mut inputs = self.tokio.block_on(self.inputs.write()); if inputs.contains_key(&input_port_id) { return Ok(()); // TODO } @@ -428,11 +429,11 @@ impl ZmqTransport { use ZmqTransportEvent::*; match event { Connect(output_port_id, input_port_id) => { - let inputs = inputs.read(); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { todo!(); }; - let mut input_state = input_state.upgradable_read(); + let mut input_state = input_state.write().await; use ZmqInputPortState::*; match &*input_state { @@ -471,14 +472,14 @@ impl ZmqTransport { .await .expect("input worker send ack-conn event"); - input_state.with_upgraded(add_connection); + add_connection(&mut input_state); } Message(output_port_id, _, seq_id, bytes) => { - let inputs = inputs.read(); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { todo!(); }; - let input_state = input_state.read(); + let input_state = input_state.read().await; use ZmqInputPortState::*; match &*input_state { @@ -506,11 +507,11 @@ impl ZmqTransport { } } CloseOutput(output_port_id, input_port_id) => { - let inputs = inputs.read(); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { todo!(); }; - let mut input_state = input_state.upgradable_read(); + let mut input_state = input_state.write().await; use ZmqInputPortState::*; let Connected(_, _, _, _, ref connected_ids) = *input_state else { @@ -528,12 +529,12 @@ impl ZmqTransport { .expect("input worker send unsub req"); } - input_state.with_upgraded(|state| match state { + match *input_state { Open(_) | Closed => (), - Connected(_, _, _, _, connected_ids) => { + Connected(_, _, _, _, ref mut connected_ids) => { connected_ids.retain(|&id| id != output_port_id) } - }); + }; } // ignore, ideally we never receive these here: @@ -551,11 +552,11 @@ impl ZmqTransport { use ZmqInputPortRequest::*; match request { Close => { - let inputs = inputs.read(); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { todo!(); }; - let mut input_state = input_state.upgradable_read(); + let mut input_state = input_state.write().await; use ZmqInputPortState::*; let Connected(_, ref port_events, _, _, _) = *input_state else { @@ -572,7 +573,7 @@ impl ZmqTransport { .await .expect("input worker send port closed"); - input_state.with_upgraded(|state| *state = ZmqInputPortState::Closed); + *input_state = ZmqInputPortState::Closed; response_chan .send(Ok(())) @@ -582,7 +583,7 @@ impl ZmqTransport { } } - tokio::task::spawn_local(async move { + tokio::task::spawn(async move { // Input worker loop: // 1. Receive connection attempts and respond // 2. Receive messages and forward to channel @@ -616,7 +617,7 @@ impl ZmqTransport { let (to_worker_send, mut to_worker_recv) = sync_channel(1); { - let mut outputs = self.outputs.write(); + let mut outputs = self.tokio.block_on(self.outputs.write()); if outputs.contains_key(&output_port_id) { return Ok(()); // TODO } @@ -629,7 +630,7 @@ impl ZmqTransport { let pub_queue = self.pub_queue.clone(); let outputs = self.outputs.clone(); - tokio::task::spawn_local(async move { + tokio::task::spawn(async move { let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { todo!(); }; @@ -668,11 +669,11 @@ impl ZmqTransport { use ZmqTransportEvent::*; match response { AckConnection(_, input_port_id) => { - let outputs = outputs.read(); + let outputs = outputs.read().await; let Some(output_state) = outputs.get(&output_port_id) else { todo!(); }; - let mut output_state = output_state.write(); + let mut output_state = output_state.write().await; debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); *output_state = ZmqOutputPortState::Connected( msg_req_send, @@ -750,11 +751,11 @@ impl ZmqTransport { } } CloseInput(_) => { - let outputs = outputs.read(); + let outputs = outputs.read().await; let Some(output_state) = outputs.get(&output_port_id) else { todo!(); }; - let mut output_state = output_state.write(); + let mut output_state = output_state.write().await; debug_assert!(matches!( *output_state, ZmqOutputPortState::Connected(..) @@ -804,148 +805,175 @@ impl ZmqTransport { impl Transport for ZmqTransport { fn input_state(&self, input: InputPortID) -> PortResult { - self.inputs - .read() - .get(&input) - .map(|port| port.read().state()) - .ok_or(PortError::Invalid(input.into())) + self.tokio.block_on(async { + Ok(self + .inputs + .read() + .await + .get(&input) + .ok_or_else(|| PortError::Invalid(input.into()))? + .read() + .await + .state()) + }) } fn output_state(&self, output: OutputPortID) -> PortResult { - self.outputs - .read() - .get(&output) - .map(|port| port.read().state()) - .ok_or(PortError::Invalid(output.into())) + self.tokio.block_on(async { + Ok(self + .outputs + .read() + .await + .get(&output) + .ok_or_else(|| PortError::Invalid(output.into()))? + .read() + .await + .state()) + }) } fn open_input(&self) -> PortResult { - let inputs = self.inputs.read(); + let inputs = self.tokio.block_on(self.inputs.read()); let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) .map_err(|e| PortError::Other(e.to_string()))?; self.start_input_worker(new_id).map(|_| new_id) } fn open_output(&self) -> PortResult { - let outputs = self.outputs.read(); + let outputs = self.tokio.block_on(self.outputs.read()); let new_id = OutputPortID::try_from(outputs.len() as isize + 1) .map_err(|e| PortError::Other(e.to_string()))?; self.start_output_worker(new_id).map(|_| new_id) } fn close_input(&self, input: InputPortID) -> PortResult { - let inputs = self.inputs.read(); + self.tokio.block_on(async { + let inputs = self.inputs.read().await; - let Some(state) = inputs.get(&input) else { - return Err(PortError::Invalid(input.into())); - }; + let Some(state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; - let state = state.read(); + let state = state.read().await; - let ZmqInputPortState::Connected(sender, _, _, _, _) = &*state else { - return Err(PortError::Disconnected); - }; + let ZmqInputPortState::Connected(sender, _, _, _, _) = &*state else { + return Err(PortError::Disconnected); + }; - let (close_send, mut close_recv) = sync_channel(1); + let (close_send, mut close_recv) = sync_channel(1); - self.tokio - .block_on(sender.send((ZmqInputPortRequest::Close, close_send))) - .map_err(|e| PortError::Other(e.to_string()))?; + sender + .send((ZmqInputPortRequest::Close, close_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?; - self.tokio - .block_on(close_recv.recv()) - .ok_or(PortError::Disconnected)? - .map(|_| true) + close_recv + .recv() + .await + .ok_or(PortError::Disconnected)? + .map(|_| true) + }) } fn close_output(&self, output: OutputPortID) -> PortResult { - let outputs = self.outputs.read(); + self.tokio.block_on(async { + let outputs = self.outputs.read().await; - let Some(state) = outputs.get(&output) else { - return Err(PortError::Invalid(output.into())); - }; + let Some(state) = outputs.get(&output) else { + return Err(PortError::Invalid(output.into())); + }; - let state = state.write(); + let state = state.read().await; - let ZmqOutputPortState::Connected(sender, _, _) = &*state else { - return Err(PortError::Disconnected); - }; + let ZmqOutputPortState::Connected(sender, _, _) = &*state else { + return Err(PortError::Disconnected); + }; - let (close_send, mut close_recv) = sync_channel(1); + let (close_send, mut close_recv) = sync_channel(1); - self.tokio - .block_on(sender.send((ZmqOutputPortRequest::Close, close_send))) - .map_err(|e| PortError::Other(e.to_string()))?; + sender + .send((ZmqOutputPortRequest::Close, close_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?; - self.tokio - .block_on(close_recv.recv()) - .ok_or(PortError::Disconnected)? - .map(|_| true) + close_recv + .recv() + .await + .ok_or(PortError::Disconnected)? + .map(|_| true) + }) } fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { - let outputs = self.outputs.read(); - let Some(output_state) = outputs.get(&source) else { - return Err(PortError::Invalid(source.into())); - }; + self.tokio.block_on(async { + let outputs = self.outputs.read().await; + let Some(output_state) = outputs.get(&source) else { + return Err(PortError::Invalid(source.into())); + }; - let output_state = output_state.read(); - let ZmqOutputPortState::Open(ref sender, _) = *output_state else { - return Err(PortError::Invalid(source.into())); - }; + let output_state = output_state.read().await; + let ZmqOutputPortState::Open(ref sender, _) = *output_state else { + return Err(PortError::Invalid(source.into())); + }; - let (confirm_send, mut confirm_recv) = sync_channel(1); + let (confirm_send, mut confirm_recv) = sync_channel(1); - self.tokio - .block_on(sender.send((target, confirm_send))) - .map_err(|e| PortError::Other(e.to_string()))?; + sender + .send((target, confirm_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?; - self.tokio - .block_on(confirm_recv.recv()) - .ok_or(PortError::Disconnected)? - .map(|_| true) + confirm_recv + .recv() + .await + .ok_or(PortError::Disconnected)? + .map(|_| true) + }) } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { - let outputs = self.outputs.read(); - let Some(output) = outputs.get(&output) else { - return Err(PortError::Invalid(output.into())); - }; - let output = output.read(); + self.tokio.block_on(async { + let outputs = self.outputs.read().await; + let Some(output) = outputs.get(&output) else { + return Err(PortError::Invalid(output.into())); + }; + let output = output.read().await; - let ZmqOutputPortState::Connected(sender, _, _) = &*output else { - return Err(PortError::Disconnected); - }; + let ZmqOutputPortState::Connected(sender, _, _) = &*output else { + return Err(PortError::Disconnected); + }; - let (ack_send, mut ack_recv) = sync_channel(1); + let (ack_send, mut ack_recv) = sync_channel(1); - self.tokio - .block_on(sender.send((ZmqOutputPortRequest::Send(message), ack_send))) - .map_err(|e| PortError::Other(e.to_string()))?; + sender + .send((ZmqOutputPortRequest::Send(message), ack_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?; - self.tokio - .block_on(ack_recv.recv()) - .ok_or(PortError::Disconnected)? + ack_recv.recv().await.ok_or(PortError::Disconnected)? + }) } fn recv(&self, input: InputPortID) -> PortResult> { - let inputs = self.inputs.read(); - let Some(input) = inputs.get(&input) else { - return Err(PortError::Invalid(input.into())); - }; - let input = input.read(); + self.tokio.block_on(async { + let inputs = self.inputs.read().await; + let Some(input) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + let input = input.read().await; - let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input else { - return Err(PortError::Disconnected); - }; - let mut receiver = receiver.lock(); + let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input else { + return Err(PortError::Disconnected); + }; + let mut receiver = receiver.lock().await; - use ZmqInputPortEvent::*; - match self.tokio.block_on(receiver.recv()) { - Some(Closed) => Ok(None), // EOS - Some(Message(bytes)) => Ok(Some(bytes)), - None => Err(PortError::Disconnected), - } + use ZmqInputPortEvent::*; + match receiver.recv().await { + Some(Closed) => Ok(None), // EOS + Some(Message(bytes)) => Ok(Some(bytes)), + None => Err(PortError::Disconnected), + } + }) } fn try_recv(&self, _input: InputPortID) -> PortResult> { @@ -966,11 +994,11 @@ async fn handle_zmq_msg( match event { // input ports Connect(_, input_port_id) => { - let inputs = inputs.read(); + let inputs = inputs.read().await; let Some(input) = inputs.get(&input_port_id) else { todo!(); }; - let input = input.read(); + let input = input.read().await; use ZmqInputPortState::*; match &*input { @@ -979,12 +1007,12 @@ async fn handle_zmq_msg( }; } Message(_, input_port_id, _, _) => { - let inputs = inputs.read(); + let inputs = inputs.read().await; let Some(input) = inputs.get(&input_port_id) else { todo!(); }; - let input = input.read(); + let input = input.read().await; let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { todo!(); }; @@ -992,11 +1020,11 @@ async fn handle_zmq_msg( sender.send(event).await.unwrap(); } CloseOutput(_, input_port_id) => { - let inputs = inputs.read(); + let inputs = inputs.read().await; let Some(input) = inputs.get(&input_port_id) else { todo!(); }; - let input = input.read(); + let input = input.read().await; use ZmqInputPortState::*; match &*input { @@ -1009,11 +1037,11 @@ async fn handle_zmq_msg( // output ports AckConnection(output_port_id, _) => { - let outputs = outputs.read(); + let outputs = outputs.read().await; let Some(output) = outputs.get(&output_port_id) else { todo!(); }; - let output = output.read(); + let output = output.read().await; let ZmqOutputPortState::Open(_, sender) = &*output else { todo!(); @@ -1021,19 +1049,19 @@ async fn handle_zmq_msg( sender.send(event).await.unwrap(); } AckMessage(output_port_id, _, _) => { - let outputs = outputs.read(); + let outputs = outputs.read().await; let Some(output) = outputs.get(&output_port_id) else { todo!(); }; - let output = output.read(); + let output = output.read().await; let ZmqOutputPortState::Connected(_, sender, _) = &*output else { todo!(); }; sender.send(event).await.unwrap(); } CloseInput(input_port_id) => { - for (_, state) in outputs.read().iter() { - let state = state.read(); + for (_, state) in outputs.read().await.iter() { + let state = state.read().await; let ZmqOutputPortState::Connected(_, ref sender, ref id) = *state else { continue; }; From 2fc2154105dfc98b462a0a6460e953f595bd6338 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 06:38:35 +0200 Subject: [PATCH 33/63] Use await inside async block --- lib/protoflow-zeromq/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index a419f3d4..3123ebb8 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -347,13 +347,13 @@ impl ZmqTransport { psock: zeromq::PubSocket, pub_queue: Receiver, ) { - let tokio = self.tokio.clone(); let mut psock = psock; let mut pub_queue = pub_queue; tokio::task::spawn(async move { while let Some(event) = pub_queue.recv().await { - tokio - .block_on(psock.send(event.into())) + psock + .send(event.into()) + .await .expect("zmq pub-socket worker") } }); From 50b06cb1e578fbddfb56df1eb484fc6a93650b78 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 06:46:24 +0200 Subject: [PATCH 34/63] Add test for sending and receiving --- lib/protoflow-zeromq/src/lib.rs | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 3123ebb8..7b476dd7 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -1082,7 +1082,7 @@ async fn handle_zmq_msg( mod tests { use super::*; - use protoflow_core::System; + use protoflow_core::{runtimes::StdRuntime, System}; use futures_util::future::TryFutureExt; use zeromq::{PubSocket, SocketRecv, SocketSend, SubSocket}; @@ -1122,4 +1122,30 @@ mod tests { let _ = System::::build(|_s| { /* do nothing */ }); } + + #[test] + fn run_transport() { + let rt = tokio::runtime::Runtime::new().unwrap(); + let _guard = rt.enter(); + + start_zmqtransport_server(&rt); + + let transport = ZmqTransport::default(); + let runtime = StdRuntime::new(transport).unwrap(); + let system = System::new(&runtime); + + let output = system.output(); + let input = system.input(); + + system.connect(&output, &input); + + let output = rt.spawn(async move { output.send(&"Hello world!".to_string()).unwrap() }); + let input = rt.spawn(async move { input.recv().unwrap() }); + + let (output, input) = rt.block_on(async { tokio::join!(output, input) }); + + output.unwrap(); + + assert_eq!("Hello world!".to_string(), input.unwrap().unwrap()); + } } From dbbf11d8dfd26745ab065cc0e07ec76d8fcc415b Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 07:39:11 +0200 Subject: [PATCH 35/63] =?UTF-8?q?Fix=20message=20sending:=20first=20workin?= =?UTF-8?q?g=20version=20=F0=9F=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Only needed to drop a couple of locks, respond to the send requests channel and toggle one condition check. --- lib/protoflow-zeromq/Cargo.toml | 1 - lib/protoflow-zeromq/src/lib.rs | 48 +++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/lib/protoflow-zeromq/Cargo.toml b/lib/protoflow-zeromq/Cargo.toml index 8348cf6a..e9391fa2 100644 --- a/lib/protoflow-zeromq/Cargo.toml +++ b/lib/protoflow-zeromq/Cargo.toml @@ -33,7 +33,6 @@ zeromq = { version = "0.4.1", default-features = false, features = [ "all-transport", ] } tokio = { version = "1.40.0", default-features = false } -parking_lot = "0.12" prost = "0.13.2" prost-types = "0.13.2" diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 7b476dd7..1ea451fd 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -439,7 +439,7 @@ impl ZmqTransport { match &*input_state { Open(_) => (), Connected(_, _, _, _, connected_ids) => { - if !connected_ids.iter().any(|&id| id == output_port_id) { + if connected_ids.iter().any(|&id| id == output_port_id) { return; } } @@ -600,15 +600,7 @@ impl ZmqTransport { } }); - let topic = format!("{}:", input_port_id); - - // send sub request - self.tokio - .block_on( - self.sub_queue - .send(ZmqSubscriptionRequest::Subscribe(topic)), - ) - .map_err(|e| PortError::Other(e.to_string())) + Ok(()) } fn start_output_worker(&self, output_port_id: OutputPortID) -> Result<(), PortError> { @@ -693,7 +685,7 @@ impl ZmqTransport { } let mut seq_id = 1; - 'outer: loop { + 'send: loop { let (request, response_chan) = msg_req_recv .recv() .await @@ -737,7 +729,7 @@ impl ZmqTransport { .await .expect("output worker send message event"); - loop { + 'recv: loop { let event = to_worker_recv .recv() .await @@ -747,7 +739,11 @@ impl ZmqTransport { match event { AckMessage(_, _, ack_id) => { if ack_id == seq_id { - break; + response_chan + .send(Ok(())) + .await + .expect("output worker respond send"); + break 'recv; } } CloseInput(_) => { @@ -781,7 +777,7 @@ impl ZmqTransport { .await .expect("output worker respond msg"); - break 'outer; + break 'send; } // ignore others, we shouldn't receive any new conn-acks @@ -836,6 +832,8 @@ impl Transport for ZmqTransport { let inputs = self.tokio.block_on(self.inputs.read()); let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) .map_err(|e| PortError::Other(e.to_string()))?; + + drop(inputs); self.start_input_worker(new_id).map(|_| new_id) } @@ -843,6 +841,8 @@ impl Transport for ZmqTransport { let outputs = self.tokio.block_on(self.outputs.read()); let new_id = OutputPortID::try_from(outputs.len() as isize + 1) .map_err(|e| PortError::Other(e.to_string()))?; + + drop(outputs); self.start_output_worker(new_id).map(|_| new_id) } @@ -916,6 +916,10 @@ impl Transport for ZmqTransport { return Err(PortError::Invalid(source.into())); }; + let sender = sender.clone(); + drop(output_state); + drop(outputs); + let (confirm_send, mut confirm_recv) = sync_channel(1); sender @@ -1046,6 +1050,9 @@ async fn handle_zmq_msg( let ZmqOutputPortState::Open(_, sender) = &*output else { todo!(); }; + let sender = sender.clone(); + drop(output); + drop(outputs); sender.send(event).await.unwrap(); } AckMessage(output_port_id, _, _) => { @@ -1139,13 +1146,14 @@ mod tests { system.connect(&output, &input); - let output = rt.spawn(async move { output.send(&"Hello world!".to_string()).unwrap() }); - let input = rt.spawn(async move { input.recv().unwrap() }); - - let (output, input) = rt.block_on(async { tokio::join!(output, input) }); + let output = std::thread::spawn(move || output.send(&"Hello world!".to_string())); + let input = std::thread::spawn(move || input.recv()); - output.unwrap(); + output.join().expect("thread failed").expect("send failed"); - assert_eq!("Hello world!".to_string(), input.unwrap().unwrap()); + assert_eq!( + Some("Hello world!".to_string()), + input.join().expect("thread failed").expect("recv failed") + ); } } From 386f2332527e22ee124918cc60dd3a6118be585c Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 08:26:46 +0200 Subject: [PATCH 36/63] Unsubscribe from topics when input closes --- lib/protoflow-zeromq/src/lib.rs | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 1ea451fd..ec4caa20 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -423,7 +423,6 @@ impl ZmqTransport { inputs: &RwLock>>, req_send: &Sender<(ZmqInputPortRequest, Sender>)>, pub_queue: &Sender, - sub_queue: &Sender, input_port_id: InputPortID, ) { use ZmqTransportEvent::*; @@ -522,13 +521,6 @@ impl ZmqTransport { return; } - for topic in input_topics(input_port_id).into_iter() { - sub_queue - .send(ZmqSubscriptionRequest::Unsubscribe(topic)) - .await - .expect("input worker send unsub req"); - } - match *input_state { Open(_) | Closed => (), Connected(_, _, _, _, ref mut connected_ids) => { @@ -547,6 +539,7 @@ impl ZmqTransport { response_chan: Sender>, inputs: &RwLock>>, pub_queue: &Sender, + sub_queue: &Sender, input_port_id: InputPortID, ) { use ZmqInputPortRequest::*; @@ -575,6 +568,17 @@ impl ZmqTransport { *input_state = ZmqInputPortState::Closed; + { + let mut handles = Vec::new(); + for topic in input_topics(input_port_id).into_iter() { + let handle = sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + handle.await.expect("input worker send unsub req"); + } + } + response_chan .send(Ok(())) .await @@ -591,10 +595,10 @@ impl ZmqTransport { loop { tokio::select! { Some(event) = to_worker_recv.recv() => { - handle_socket_event(event, &inputs, &req_send, &pub_queue, &sub_queue, input_port_id).await; + handle_socket_event(event, &inputs, &req_send, &pub_queue, input_port_id).await; } Some((request, response_chan)) = req_recv.recv() => { - handle_input_request(request, response_chan, &inputs, &pub_queue, input_port_id).await; + handle_input_request(request, response_chan, &inputs, &pub_queue, &sub_queue, input_port_id).await; } }; } From 33fb3de525e1fb80930647f6db645da1a478a210 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 09:21:08 +0200 Subject: [PATCH 37/63] Fix deadlock when closing ports Needed to drop locks and send a `Closed` event in the case that there is a receiver waiting on the port. --- lib/protoflow-zeromq/src/lib.rs | 74 +++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index ec4caa20..749918ba 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -523,8 +523,14 @@ impl ZmqTransport { match *input_state { Open(_) | Closed => (), - Connected(_, _, _, _, ref mut connected_ids) => { - connected_ids.retain(|&id| id != output_port_id) + Connected(_, ref sender, _, _, ref mut connected_ids) => { + connected_ids.retain(|&id| id != output_port_id); + if connected_ids.is_empty() { + sender + .send(ZmqInputPortEvent::Closed) + .await + .expect("input worker publish Closed event"); + } } }; } @@ -853,17 +859,19 @@ impl Transport for ZmqTransport { fn close_input(&self, input: InputPortID) -> PortResult { self.tokio.block_on(async { let inputs = self.inputs.read().await; - - let Some(state) = inputs.get(&input) else { + let Some(input_state) = inputs.get(&input) else { return Err(PortError::Invalid(input.into())); }; - let state = state.read().await; - - let ZmqInputPortState::Connected(sender, _, _, _, _) = &*state else { + let input_state = input_state.read().await; + let ZmqInputPortState::Connected(sender, _, _, _, _) = &*input_state else { return Err(PortError::Disconnected); }; + let sender = sender.clone(); + drop(input_state); + drop(inputs); + let (close_send, mut close_recv) = sync_channel(1); sender @@ -882,17 +890,19 @@ impl Transport for ZmqTransport { fn close_output(&self, output: OutputPortID) -> PortResult { self.tokio.block_on(async { let outputs = self.outputs.read().await; - - let Some(state) = outputs.get(&output) else { + let Some(output_state) = outputs.get(&output) else { return Err(PortError::Invalid(output.into())); }; - let state = state.read().await; - - let ZmqOutputPortState::Connected(sender, _, _) = &*state else { + let output_state = output_state.read().await; + let ZmqOutputPortState::Connected(sender, _, _) = &*output_state else { return Err(PortError::Disconnected); }; + let sender = sender.clone(); + drop(output_state); + drop(outputs); + let (close_send, mut close_recv) = sync_channel(1); sender @@ -965,14 +975,18 @@ impl Transport for ZmqTransport { fn recv(&self, input: InputPortID) -> PortResult> { self.tokio.block_on(async { let inputs = self.inputs.read().await; - let Some(input) = inputs.get(&input) else { + let Some(input_state) = inputs.get(&input) else { return Err(PortError::Invalid(input.into())); }; - let input = input.read().await; - let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input else { + let input_state = input_state.read().await; + let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input_state else { return Err(PortError::Disconnected); }; + + let receiver = receiver.clone(); + drop(input_state); + drop(inputs); let mut receiver = receiver.lock().await; use ZmqInputPortEvent::*; @@ -1038,6 +1052,9 @@ async fn handle_zmq_msg( match &*input { Closed => todo!(), Open(sender) | Connected(_, _, _, sender, _) => { + let sender = sender.clone(); + drop(input); + drop(inputs); sender.send(event).await.unwrap(); } }; @@ -1079,6 +1096,8 @@ async fn handle_zmq_msg( if *id != input_port_id { continue; } + let sender = sender.clone(); + drop(state); if let Err(_e) = sender.send(event.clone()).await { continue; // TODO } @@ -1150,14 +1169,25 @@ mod tests { system.connect(&output, &input); - let output = std::thread::spawn(move || output.send(&"Hello world!".to_string())); - let input = std::thread::spawn(move || input.recv()); + let output = std::thread::spawn(move || { + let mut output = output; + output.send(&"Hello world!".to_string())?; + output.close() + }); - output.join().expect("thread failed").expect("send failed"); + let input = std::thread::spawn(move || { + let mut input = input; + + let msg = input.recv()?; + assert_eq!(Some("Hello world!".to_string()), msg); + + let msg = input.recv()?; + assert_eq!(None, msg); + + input.close() + }); - assert_eq!( - Some("Hello world!".to_string()), - input.join().expect("thread failed").expect("recv failed") - ); + output.join().expect("thread failed").unwrap(); + input.join().expect("thread failed").unwrap(); } } From 1520e2b3369d9abb5074eb72beef6f2d99fa6cf0 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 09:23:49 +0200 Subject: [PATCH 38/63] Implement `try_recv` --- lib/protoflow-zeromq/src/lib.rs | 47 ++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 749918ba..e95c5c75 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -18,7 +18,7 @@ use protoflow_core::{ use core::fmt::Error; use std::{format, write}; use tokio::sync::{ - mpsc::{channel as sync_channel, Receiver, Sender}, + mpsc::{channel as sync_channel, error::TryRecvError, Receiver, Sender}, Mutex, RwLock, }; use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqMessage}; @@ -105,8 +105,7 @@ impl ZmqInputPortState { type SequenceID = u64; -/// ZmqTransportEvent represents the data that goes over the wire, sent from an output port over -/// the network to an input port. +/// ZmqTransportEvent represents the data that goes over the wire from one port to another. #[derive(Clone, Debug)] enum ZmqTransportEvent { Connect(OutputPortID, InputPortID), @@ -286,7 +285,7 @@ impl Default for ZmqTransport { } } -#[derive(Clone)] +#[derive(Clone, Debug)] enum ZmqSubscriptionRequest { Subscribe(String), Unsubscribe(String), @@ -998,8 +997,32 @@ impl Transport for ZmqTransport { }) } - fn try_recv(&self, _input: InputPortID) -> PortResult> { - todo!(); + fn try_recv(&self, input: InputPortID) -> PortResult> { + self.tokio.block_on(async { + let inputs = self.inputs.read().await; + let Some(input_state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + + let input_state = input_state.read().await; + let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input_state else { + return Err(PortError::Disconnected); + }; + + let receiver = receiver.clone(); + drop(input_state); + drop(inputs); + let mut receiver = receiver.lock().await; + + use ZmqInputPortEvent::*; + match receiver.try_recv() { + Ok(Closed) => Ok(None), // EOS + Ok(Message(bytes)) => Ok(Some(bytes)), + Err(TryRecvError::Disconnected) => Err(PortError::Disconnected), + // TODO: what should we answer with here?: + Err(TryRecvError::Empty) => Err(PortError::RecvFailed), + } + }) } } @@ -1117,17 +1140,17 @@ mod tests { use futures_util::future::TryFutureExt; use zeromq::{PubSocket, SocketRecv, SocketSend, SubSocket}; - fn start_zmqtransport_server(rt: &tokio::runtime::Runtime) { + async fn start_zmqtransport_server() { // bind a *SUB* socket to the *PUB* address so that the transport can *PUB* to it let mut pub_srv = SubSocket::new(); - rt.block_on(pub_srv.bind(DEFAULT_PUB_SOCKET)).unwrap(); + pub_srv.bind(DEFAULT_PUB_SOCKET).await.unwrap(); // bind a *PUB* socket to the *SUB* address so that the transport can *SUB* to it let mut sub_srv = PubSocket::new(); - rt.block_on(sub_srv.bind(DEFAULT_SUB_SOCKET)).unwrap(); + sub_srv.bind(DEFAULT_SUB_SOCKET).await.unwrap(); // subscribe to all messages - rt.block_on(pub_srv.subscribe("")).unwrap(); + pub_srv.subscribe("").await.unwrap(); // resend anything received on the *SUB* socket to the *PUB* socket tokio::task::spawn(async move { @@ -1148,7 +1171,7 @@ mod tests { let _guard = rt.enter(); //zeromq::proxy(frontend, backend, capture) - start_zmqtransport_server(&rt); + rt.block_on(start_zmqtransport_server()); let _ = System::::build(|_s| { /* do nothing */ }); } @@ -1158,7 +1181,7 @@ mod tests { let rt = tokio::runtime::Runtime::new().unwrap(); let _guard = rt.enter(); - start_zmqtransport_server(&rt); + rt.block_on(start_zmqtransport_server()); let transport = ZmqTransport::default(); let runtime = StdRuntime::new(transport).unwrap(); From 9fc2a2662a0db69d9d2c6435f48e4e1d1714ac4a Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 10:26:24 +0200 Subject: [PATCH 39/63] Remove explicit `drop`s --- lib/protoflow-zeromq/src/lib.rs | 271 +++++++++++++++++--------------- 1 file changed, 148 insertions(+), 123 deletions(-) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index e95c5c75..3de57b1e 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -838,38 +838,40 @@ impl Transport for ZmqTransport { } fn open_input(&self) -> PortResult { - let inputs = self.tokio.block_on(self.inputs.read()); - let new_id = InputPortID::try_from(-(inputs.len() as isize + 1)) - .map_err(|e| PortError::Other(e.to_string()))?; + let new_id = { + let inputs = self.tokio.block_on(self.inputs.read()); + InputPortID::try_from(-(inputs.len() as isize + 1)) + .map_err(|e| PortError::Other(e.to_string()))? + }; - drop(inputs); self.start_input_worker(new_id).map(|_| new_id) } fn open_output(&self) -> PortResult { - let outputs = self.tokio.block_on(self.outputs.read()); - let new_id = OutputPortID::try_from(outputs.len() as isize + 1) - .map_err(|e| PortError::Other(e.to_string()))?; + let new_id = { + let outputs = self.tokio.block_on(self.outputs.read()); + OutputPortID::try_from(outputs.len() as isize + 1) + .map_err(|e| PortError::Other(e.to_string()))? + }; - drop(outputs); self.start_output_worker(new_id).map(|_| new_id) } fn close_input(&self, input: InputPortID) -> PortResult { self.tokio.block_on(async { - let inputs = self.inputs.read().await; - let Some(input_state) = inputs.get(&input) else { - return Err(PortError::Invalid(input.into())); - }; + let sender = { + let inputs = self.inputs.read().await; + let Some(input_state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; - let input_state = input_state.read().await; - let ZmqInputPortState::Connected(sender, _, _, _, _) = &*input_state else { - return Err(PortError::Disconnected); - }; + let input_state = input_state.read().await; + let ZmqInputPortState::Connected(sender, _, _, _, _) = &*input_state else { + return Err(PortError::Disconnected); + }; - let sender = sender.clone(); - drop(input_state); - drop(inputs); + sender.clone() + }; let (close_send, mut close_recv) = sync_channel(1); @@ -888,19 +890,19 @@ impl Transport for ZmqTransport { fn close_output(&self, output: OutputPortID) -> PortResult { self.tokio.block_on(async { - let outputs = self.outputs.read().await; - let Some(output_state) = outputs.get(&output) else { - return Err(PortError::Invalid(output.into())); - }; + let sender = { + let outputs = self.outputs.read().await; + let Some(output_state) = outputs.get(&output) else { + return Err(PortError::Invalid(output.into())); + }; - let output_state = output_state.read().await; - let ZmqOutputPortState::Connected(sender, _, _) = &*output_state else { - return Err(PortError::Disconnected); - }; + let output_state = output_state.read().await; + let ZmqOutputPortState::Connected(sender, _, _) = &*output_state else { + return Err(PortError::Disconnected); + }; - let sender = sender.clone(); - drop(output_state); - drop(outputs); + sender.clone() + }; let (close_send, mut close_recv) = sync_channel(1); @@ -919,19 +921,19 @@ impl Transport for ZmqTransport { fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { self.tokio.block_on(async { - let outputs = self.outputs.read().await; - let Some(output_state) = outputs.get(&source) else { - return Err(PortError::Invalid(source.into())); - }; + let sender = { + let outputs = self.outputs.read().await; + let Some(output_state) = outputs.get(&source) else { + return Err(PortError::Invalid(source.into())); + }; - let output_state = output_state.read().await; - let ZmqOutputPortState::Open(ref sender, _) = *output_state else { - return Err(PortError::Invalid(source.into())); - }; + let output_state = output_state.read().await; + let ZmqOutputPortState::Open(ref sender, _) = *output_state else { + return Err(PortError::Invalid(source.into())); + }; - let sender = sender.clone(); - drop(output_state); - drop(outputs); + sender.clone() + }; let (confirm_send, mut confirm_recv) = sync_channel(1); @@ -950,14 +952,18 @@ impl Transport for ZmqTransport { fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { self.tokio.block_on(async { - let outputs = self.outputs.read().await; - let Some(output) = outputs.get(&output) else { - return Err(PortError::Invalid(output.into())); - }; - let output = output.read().await; + let sender = { + let outputs = self.outputs.read().await; + let Some(output) = outputs.get(&output) else { + return Err(PortError::Invalid(output.into())); + }; + let output = output.read().await; + + let ZmqOutputPortState::Connected(sender, _, _) = &*output else { + return Err(PortError::Disconnected); + }; - let ZmqOutputPortState::Connected(sender, _, _) = &*output else { - return Err(PortError::Disconnected); + sender.clone() }; let (ack_send, mut ack_recv) = sync_channel(1); @@ -973,19 +979,20 @@ impl Transport for ZmqTransport { fn recv(&self, input: InputPortID) -> PortResult> { self.tokio.block_on(async { - let inputs = self.inputs.read().await; - let Some(input_state) = inputs.get(&input) else { - return Err(PortError::Invalid(input.into())); - }; + let receiver = { + let inputs = self.inputs.read().await; + let Some(input_state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + + let input_state = input_state.read().await; + let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input_state else { + return Err(PortError::Disconnected); + }; - let input_state = input_state.read().await; - let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input_state else { - return Err(PortError::Disconnected); + receiver.clone() }; - let receiver = receiver.clone(); - drop(input_state); - drop(inputs); let mut receiver = receiver.lock().await; use ZmqInputPortEvent::*; @@ -999,19 +1006,20 @@ impl Transport for ZmqTransport { fn try_recv(&self, input: InputPortID) -> PortResult> { self.tokio.block_on(async { - let inputs = self.inputs.read().await; - let Some(input_state) = inputs.get(&input) else { - return Err(PortError::Invalid(input.into())); - }; + let receiver = { + let inputs = self.inputs.read().await; + let Some(input_state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + + let input_state = input_state.read().await; + let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input_state else { + return Err(PortError::Disconnected); + }; - let input_state = input_state.read().await; - let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input_state else { - return Err(PortError::Disconnected); + receiver.clone() }; - let receiver = receiver.clone(); - drop(input_state); - drop(inputs); let mut receiver = receiver.lock().await; use ZmqInputPortEvent::*; @@ -1039,88 +1047,105 @@ async fn handle_zmq_msg( match event { // input ports Connect(_, input_port_id) => { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; - let input = input.read().await; + let sender = { + let inputs = inputs.read().await; + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + let input = input.read().await; - use ZmqInputPortState::*; - match &*input { - Closed => todo!(), - Open(sender) | Connected(_, _, _, sender, _) => sender.send(event).await.unwrap(), + use ZmqInputPortState::*; + match &*input { + Closed => todo!(), + Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), + } }; + + sender.send(event).await.unwrap(); } Message(_, input_port_id, _, _) => { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; + let sender = { + let inputs = inputs.read().await; + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; - let input = input.read().await; - let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { - todo!(); + let input = input.read().await; + let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { + todo!(); + }; + + sender.clone() }; sender.send(event).await.unwrap(); } CloseOutput(_, input_port_id) => { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; - let input = input.read().await; - - use ZmqInputPortState::*; - match &*input { - Closed => todo!(), - Open(sender) | Connected(_, _, _, sender, _) => { - let sender = sender.clone(); - drop(input); - drop(inputs); - sender.send(event).await.unwrap(); + let sender = { + let inputs = inputs.read().await; + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + let input = input.read().await; + + use ZmqInputPortState::*; + match &*input { + Closed => todo!(), + Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), } }; + + sender.send(event).await.unwrap(); } // output ports AckConnection(output_port_id, _) => { - let outputs = outputs.read().await; - let Some(output) = outputs.get(&output_port_id) else { - todo!(); - }; - let output = output.read().await; + let sender = { + let outputs = outputs.read().await; + let Some(output) = outputs.get(&output_port_id) else { + todo!(); + }; + let output = output.read().await; - let ZmqOutputPortState::Open(_, sender) = &*output else { - todo!(); + let ZmqOutputPortState::Open(_, sender) = &*output else { + todo!(); + }; + + sender.clone() }; - let sender = sender.clone(); - drop(output); - drop(outputs); + sender.send(event).await.unwrap(); } AckMessage(output_port_id, _, _) => { - let outputs = outputs.read().await; - let Some(output) = outputs.get(&output_port_id) else { - todo!(); - }; - let output = output.read().await; - let ZmqOutputPortState::Connected(_, sender, _) = &*output else { - todo!(); + let sender = { + let outputs = outputs.read().await; + let Some(output) = outputs.get(&output_port_id) else { + todo!(); + }; + let output = output.read().await; + let ZmqOutputPortState::Connected(_, sender, _) = &*output else { + todo!(); + }; + + sender.clone() }; + sender.send(event).await.unwrap(); } CloseInput(input_port_id) => { for (_, state) in outputs.read().await.iter() { - let state = state.read().await; - let ZmqOutputPortState::Connected(_, ref sender, ref id) = *state else { - continue; + let sender = { + let state = state.read().await; + let ZmqOutputPortState::Connected(_, ref sender, ref id) = *state else { + continue; + }; + if *id != input_port_id { + continue; + } + + sender.clone() }; - if *id != input_port_id { - continue; - } - let sender = sender.clone(); - drop(state); + if let Err(_e) = sender.send(event.clone()).await { continue; // TODO } From be6afcae658e56e334b0d03c23edeb9c3cb125b9 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 11:01:35 +0200 Subject: [PATCH 40/63] Rename protobuf file --- lib/protoflow-zeromq/src/lib.rs | 1 + lib/protoflow-zeromq/src/{protoflow_zmq.rs => protoflow.zmq.rs} | 0 lib/protoflow-zeromq/src/transport_event.proto | 2 +- 3 files changed, 2 insertions(+), 1 deletion(-) rename lib/protoflow-zeromq/src/{protoflow_zmq.rs => protoflow.zmq.rs} (100%) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 3de57b1e..c6937d13 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -6,6 +6,7 @@ #[doc(hidden)] pub use protoflow_core::prelude; +#[path = "protoflow.zmq.rs"] mod protoflow_zmq; extern crate std; diff --git a/lib/protoflow-zeromq/src/protoflow_zmq.rs b/lib/protoflow-zeromq/src/protoflow.zmq.rs similarity index 100% rename from lib/protoflow-zeromq/src/protoflow_zmq.rs rename to lib/protoflow-zeromq/src/protoflow.zmq.rs diff --git a/lib/protoflow-zeromq/src/transport_event.proto b/lib/protoflow-zeromq/src/transport_event.proto index b88dabbe..8a31eda4 100644 --- a/lib/protoflow-zeromq/src/transport_event.proto +++ b/lib/protoflow-zeromq/src/transport_event.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package protoflow_zmq; +package protoflow.zmq; message Connect { int64 output = 1; From 3f3d5a071c33816c99c9fb20072c05b3faba54d6 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 11:09:41 +0200 Subject: [PATCH 41/63] Consume response channel in output worker --- lib/protoflow-zeromq/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index c6937d13..af6259b8 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -687,6 +687,7 @@ impl ZmqTransport { .send(Ok(())) .await .expect("output worker respond conn"); + drop(conn_confirm); break; } From 93dc100f8dd195d4645225d8e73b43124d84f317 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 11:41:50 +0200 Subject: [PATCH 42/63] Add tracing --- lib/protoflow-zeromq/Cargo.toml | 1 + lib/protoflow-zeromq/src/lib.rs | 168 +++++++++++++++++++++++++++++++- 2 files changed, 168 insertions(+), 1 deletion(-) diff --git a/lib/protoflow-zeromq/Cargo.toml b/lib/protoflow-zeromq/Cargo.toml index e9391fa2..7c31ba0c 100644 --- a/lib/protoflow-zeromq/Cargo.toml +++ b/lib/protoflow-zeromq/Cargo.toml @@ -38,3 +38,4 @@ prost-types = "0.13.2" [dev-dependencies] futures-util = "0.3.31" +tracing-subscriber = "0.3.19" diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index af6259b8..880b6eeb 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -24,6 +24,9 @@ use tokio::sync::{ }; use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqMessage}; +#[cfg(feature = "tracing")] +use tracing::{trace, trace_span}; + const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; @@ -351,6 +354,13 @@ impl ZmqTransport { let mut pub_queue = pub_queue; tokio::task::spawn(async move { while let Some(event) = pub_queue.recv().await { + #[cfg(feature = "tracing")] + trace!( + target: "ZmqTransport::pub_socket", + ?event, + "sending event to socket" + ); + psock .send(event.into()) .await @@ -371,8 +381,24 @@ impl ZmqTransport { tokio::task::spawn(async move { loop { tokio::select! { - Ok(msg) = ssock.recv() => handle_zmq_msg(msg, &outputs, &inputs).await.unwrap(), + Ok(msg) = ssock.recv() => { + #[cfg(feature = "tracing")] + trace!( + target: "ZmqTransport::sub_socket", + ?msg, + "got message from socket" + ); + + handle_zmq_msg(msg, &outputs, &inputs).await.unwrap() + }, Some(req) = sub_queue.recv() => { + #[cfg(feature = "tracing")] + trace!( + target: "ZmqTransport::sub_socket", + ?req, + "got sub update request" + ); + use ZmqSubscriptionRequest::*; match req { Subscribe(topic) => ssock.subscribe(&topic).await.expect("zmq recv worker subscribe"), @@ -385,6 +411,9 @@ impl ZmqTransport { } fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::start_input_worker", ?input_port_id); + let (to_worker_send, mut to_worker_recv) = sync_channel(1); let (req_send, mut req_recv) = sync_channel(1); @@ -396,12 +425,19 @@ impl ZmqTransport { } let state = ZmqInputPortState::Open(to_worker_send.clone()); let state = RwLock::new(state); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?state, "saving new opened state")); + inputs.insert(input_port_id, state); } { let mut handles = Vec::new(); for topic in input_topics(input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?topic, "sending subscription request")); + let handle = self .sub_queue .send(ZmqSubscriptionRequest::Subscribe(topic)); @@ -425,6 +461,15 @@ impl ZmqTransport { pub_queue: &Sender, input_port_id: InputPortID, ) { + #[cfg(feature = "tracing")] + let span = trace_span!( + "ZmqTransport::start_input_worker::handle_socket_event", + ?input_port_id + ); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "got socket event")); + use ZmqTransportEvent::*; match event { Connect(output_port_id, input_port_id) => { @@ -439,6 +484,10 @@ impl ZmqTransport { Open(_) => (), Connected(_, _, _, _, connected_ids) => { if connected_ids.iter().any(|&id| id == output_port_id) { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!(?output_port_id, "output port is already connected") + }); return; } } @@ -471,9 +520,18 @@ impl ZmqTransport { .await .expect("input worker send ack-conn event"); + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?output_port_id, "sent conn-ack")); + add_connection(&mut input_state); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?input_state, "connected new port")); } Message(output_port_id, _, seq_id, bytes) => { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "Message", ?output_port_id, ?seq_id); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { todo!(); @@ -484,6 +542,10 @@ impl ZmqTransport { match &*input_state { Connected(_, sender, _, _, connected_ids) => { if !connected_ids.iter().any(|id| *id == output_port_id) { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!("got message from non-connected output port") + }); return; } @@ -500,6 +562,9 @@ impl ZmqTransport { )) .await .expect("input worker send message ack"); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("sent msg-ack")); } Open(_) | Closed => todo!(), @@ -510,6 +575,7 @@ impl ZmqTransport { let Some(input_state) = inputs.get(&input_port_id) else { todo!(); }; + let mut input_state = input_state.write().await; use ZmqInputPortState::*; @@ -518,6 +584,13 @@ impl ZmqTransport { }; if !connected_ids.iter().any(|id| *id == output_port_id) { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!( + ?output_port_id, + "output port doesn't match any connected port" + ) + }); return; } @@ -526,6 +599,8 @@ impl ZmqTransport { Connected(_, ref sender, _, _, ref mut connected_ids) => { connected_ids.retain(|&id| id != output_port_id); if connected_ids.is_empty() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("last connected port disconnected")); sender .send(ZmqInputPortEvent::Closed) .await @@ -548,6 +623,15 @@ impl ZmqTransport { sub_queue: &Sender, input_port_id: InputPortID, ) { + #[cfg(feature = "tracing")] + let span = trace_span!( + "ZmqTransport::start_input_worker::handle_input_event", + ?input_port_id + ); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?request, "got input request")); + use ZmqInputPortRequest::*; match request { Close => { @@ -577,6 +661,9 @@ impl ZmqTransport { { let mut handles = Vec::new(); for topic in input_topics(input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?topic, "sending unsubscription request")); + let handle = sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); handles.push(handle); } @@ -593,6 +680,9 @@ impl ZmqTransport { } } + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("spawning")); + tokio::task::spawn(async move { // Input worker loop: // 1. Receive connection attempts and respond @@ -614,17 +704,25 @@ impl ZmqTransport { } fn start_output_worker(&self, output_port_id: OutputPortID) -> Result<(), PortError> { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::start_output_worker", ?output_port_id); + let (conn_send, mut conn_recv) = sync_channel(1); let (to_worker_send, mut to_worker_recv) = sync_channel(1); { let mut outputs = self.tokio.block_on(self.outputs.write()); + if outputs.contains_key(&output_port_id) { return Ok(()); // TODO } let state = ZmqOutputPortState::Open(conn_send, to_worker_send.clone()); let state = RwLock::new(state); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?state, "saving new opened state")); + outputs.insert(output_port_id, state); } @@ -632,14 +730,27 @@ impl ZmqTransport { let pub_queue = self.pub_queue.clone(); let outputs = self.outputs.clone(); + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("spawning")); + tokio::task::spawn(async move { let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { todo!(); }; + #[cfg(feature = "tracing")] + let span = trace_span!( + "ZmqTransport::start_output_worker::spawn", + ?output_port_id, + ?input_port_id + ); + { let mut handles = Vec::new(); for topic in output_topics(output_port_id, input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?topic, "sending subscription request")); + let handle = sub_queue.send(ZmqSubscriptionRequest::Subscribe(topic)); handles.push(handle); } @@ -658,6 +769,9 @@ impl ZmqTransport { // 3. Send disconnect events loop { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("sending connection attempt...")); + pub_queue .send(ZmqTransportEvent::Connect(output_port_id, input_port_id)) .await @@ -668,6 +782,9 @@ impl ZmqTransport { .await .expect("output worker recv ack-conn event"); + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?response, "got response")); + use ZmqTransportEvent::*; match response { AckConnection(_, input_port_id) => { @@ -683,6 +800,9 @@ impl ZmqTransport { input_port_id, ); + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?output_state, "Connected!")); + conn_confirm .send(Ok(())) .await @@ -697,11 +817,17 @@ impl ZmqTransport { let mut seq_id = 1; 'send: loop { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "send_loop", ?seq_id); + let (request, response_chan) = msg_req_recv .recv() .await .expect("output worker recv msg req"); + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?request, "sending request")); + match request { ZmqOutputPortRequest::Close => { let response = pub_queue @@ -715,6 +841,8 @@ impl ZmqTransport { { let mut handles = Vec::new(); for topic in output_topics(output_port_id, input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?topic, "sending unsubscription request")); let handle = sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); handles.push(handle); @@ -746,10 +874,15 @@ impl ZmqTransport { .await .expect("output worker event recv"); + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "received event")); + use ZmqTransportEvent::*; match event { AckMessage(_, _, ack_id) => { if ack_id == seq_id { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?ack_id, "msg-ack matches")); response_chan .send(Ok(())) .await @@ -774,6 +907,10 @@ impl ZmqTransport { for topic in output_topics(output_port_id, input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!(?topic, "sending unsubscription request") + }); let handle = sub_queue .send(ZmqSubscriptionRequest::Unsubscribe(topic)); handles.push(handle); @@ -840,22 +977,34 @@ impl Transport for ZmqTransport { } fn open_input(&self) -> PortResult { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::open_input", "creating new input port"); + let new_id = { let inputs = self.tokio.block_on(self.inputs.read()); InputPortID::try_from(-(inputs.len() as isize + 1)) .map_err(|e| PortError::Other(e.to_string()))? }; + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::open_input", ?new_id, "created new input port"); + self.start_input_worker(new_id).map(|_| new_id) } fn open_output(&self) -> PortResult { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::open_output", "creating new output port"); + let new_id = { let outputs = self.tokio.block_on(self.outputs.read()); OutputPortID::try_from(outputs.len() as isize + 1) .map_err(|e| PortError::Other(e.to_string()))? }; + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::open_output", ?new_id, "created new output port"); + self.start_output_worker(new_id).map(|_| new_id) } @@ -922,6 +1071,9 @@ impl Transport for ZmqTransport { } fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::connect", ?source, ?target, "connecting ports"); + self.tokio.block_on(async { let sender = { let outputs = self.outputs.read().await; @@ -953,6 +1105,9 @@ impl Transport for ZmqTransport { } fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::send", ?output, "sending from output port"); + self.tokio.block_on(async { let sender = { let outputs = self.outputs.read().await; @@ -980,6 +1135,9 @@ impl Transport for ZmqTransport { } fn recv(&self, input: InputPortID) -> PortResult> { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::recv", ?input, "receiving from input port"); + self.tokio.block_on(async { let receiver = { let inputs = self.inputs.read().await; @@ -1007,6 +1165,9 @@ impl Transport for ZmqTransport { } fn try_recv(&self, input: InputPortID) -> PortResult> { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::try_recv", ?input, "receiving from input port"); + self.tokio.block_on(async { let receiver = { let inputs = self.inputs.read().await; @@ -1045,6 +1206,9 @@ async fn handle_zmq_msg( todo!(); }; + #[cfg(feature = "tracing")] + trace!(target: "handle_zmq_msg", ?event, "got event"); + use ZmqTransportEvent::*; match event { // input ports @@ -1205,6 +1369,8 @@ mod tests { #[test] fn run_transport() { + tracing_subscriber::fmt::init(); + let rt = tokio::runtime::Runtime::new().unwrap(); let _guard = rt.enter(); From 24cd68e34b889bd0bf8f1017c970777d0846b075 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 13:51:51 +0200 Subject: [PATCH 43/63] Move InputPort to it's own module --- lib/protoflow-zeromq/src/input_port.rs | 357 +++++++++++++++++++++++++ lib/protoflow-zeromq/src/lib.rs | 350 +----------------------- 2 files changed, 364 insertions(+), 343 deletions(-) create mode 100644 lib/protoflow-zeromq/src/input_port.rs diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs new file mode 100644 index 00000000..510b8c35 --- /dev/null +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -0,0 +1,357 @@ +// This is free and unencumbered software released into the public domain. + +use crate::{ZmqSubscriptionRequest, ZmqTransport, ZmqTransportEvent}; +use protoflow_core::{ + prelude::{format, vec, Arc, BTreeMap, Bytes, String, Vec}, + InputPortID, OutputPortID, PortError, PortState, +}; +use tokio::sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, RwLock, +}; + +#[cfg(feature = "tracing")] +use tracing::{trace, trace_span}; + +#[derive(Debug, Clone)] +pub enum ZmqInputPortRequest { + Close, +} + +/// ZmqInputPortEvent represents events that we receive from the background worker of the port. +#[derive(Clone, Debug)] +pub enum ZmqInputPortEvent { + Message(Bytes), + Closed, +} + +#[derive(Debug, Clone)] +pub enum ZmqInputPortState { + Open(Sender), + Connected( + // channel for requests from public close + Sender<(ZmqInputPortRequest, Sender>)>, + // channel for the public recv + Sender, + Arc>>, + // internal channel for events + Sender, + // vec of the connected port ids + Vec, + ), + Closed, +} + +impl ZmqInputPortState { + pub fn state(&self) -> PortState { + use ZmqInputPortState::*; + match self { + Open(_) => PortState::Open, + Connected(_, _, _, _, _) => PortState::Connected, + Closed => PortState::Closed, + } + } +} + +fn input_topics(id: InputPortID) -> Vec { + vec![ + format!("{}:conn", id), + format!("{}:msg", id), + format!("{}:closeOut", id), + ] +} + +pub fn start_input_worker( + transport: &ZmqTransport, + input_port_id: InputPortID, +) -> Result<(), PortError> { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::start_input_worker", ?input_port_id); + + let (to_worker_send, mut to_worker_recv) = channel(1); + + let (req_send, mut req_recv) = channel(1); + + { + let mut inputs = transport.tokio.block_on(transport.inputs.write()); + if inputs.contains_key(&input_port_id) { + return Ok(()); // TODO + } + let state = ZmqInputPortState::Open(to_worker_send.clone()); + let state = RwLock::new(state); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?state, "saving new opened state")); + + inputs.insert(input_port_id, state); + } + + { + let mut handles = Vec::new(); + for topic in input_topics(input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?topic, "sending subscription request")); + + let handle = transport + .sub_queue + .send(ZmqSubscriptionRequest::Subscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + transport + .tokio + .block_on(handle) + .expect("input worker send sub req"); + } + } + + let sub_queue = transport.sub_queue.clone(); + let pub_queue = transport.pub_queue.clone(); + let inputs = transport.inputs.clone(); + + async fn handle_socket_event( + event: ZmqTransportEvent, + inputs: &RwLock>>, + req_send: &Sender<(ZmqInputPortRequest, Sender>)>, + pub_queue: &Sender, + input_port_id: InputPortID, + ) { + #[cfg(feature = "tracing")] + let span = trace_span!( + "ZmqTransport::start_input_worker::handle_socket_event", + ?input_port_id + ); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "got socket event")); + + use ZmqTransportEvent::*; + match event { + Connect(output_port_id, input_port_id) => { + let inputs = inputs.read().await; + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let mut input_state = input_state.write().await; + + use ZmqInputPortState::*; + match &*input_state { + Open(_) => (), + Connected(_, _, _, _, connected_ids) => { + if connected_ids.iter().any(|&id| id == output_port_id) { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!(?output_port_id, "output port is already connected") + }); + return; + } + } + Closed => return, + }; + + let add_connection = |input_state: &mut ZmqInputPortState| match input_state { + Open(to_worker_send) => { + let (msgs_send, msgs_recv) = channel(1); + let msgs_recv = Arc::new(Mutex::new(msgs_recv)); + *input_state = Connected( + req_send.clone(), + msgs_send, + msgs_recv, + to_worker_send.clone(), + vec![output_port_id], + ); + } + Connected(_, _, _, _, ids) => { + ids.push(output_port_id); + } + Closed => unreachable!(), + }; + + pub_queue + .send(ZmqTransportEvent::AckConnection( + output_port_id, + input_port_id, + )) + .await + .expect("input worker send ack-conn event"); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?output_port_id, "sent conn-ack")); + + add_connection(&mut input_state); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?input_state, "connected new port")); + } + Message(output_port_id, _, seq_id, bytes) => { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "Message", ?output_port_id, ?seq_id); + + let inputs = inputs.read().await; + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let input_state = input_state.read().await; + + use ZmqInputPortState::*; + match &*input_state { + Connected(_, sender, _, _, connected_ids) => { + if !connected_ids.iter().any(|id| *id == output_port_id) { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("got message from non-connected output port")); + return; + } + + sender + .send(ZmqInputPortEvent::Message(bytes)) + .await + .expect("input worker send message"); + + pub_queue + .send(ZmqTransportEvent::AckMessage( + output_port_id, + input_port_id, + seq_id, + )) + .await + .expect("input worker send message ack"); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("sent msg-ack")); + } + + Open(_) | Closed => todo!(), + } + } + CloseOutput(output_port_id, input_port_id) => { + let inputs = inputs.read().await; + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + + let mut input_state = input_state.write().await; + + use ZmqInputPortState::*; + let Connected(_, _, _, _, ref connected_ids) = *input_state else { + return; + }; + + if !connected_ids.iter().any(|id| *id == output_port_id) { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!( + ?output_port_id, + "output port doesn't match any connected port" + ) + }); + return; + } + + match *input_state { + Open(_) | Closed => (), + Connected(_, ref sender, _, _, ref mut connected_ids) => { + connected_ids.retain(|&id| id != output_port_id); + if connected_ids.is_empty() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("last connected port disconnected")); + sender + .send(ZmqInputPortEvent::Closed) + .await + .expect("input worker publish Closed event"); + } + } + }; + } + + // ignore, ideally we never receive these here: + AckConnection(_, _) | AckMessage(_, _, _) | CloseInput(_) => (), + } + } + + async fn handle_input_request( + request: ZmqInputPortRequest, + response_chan: Sender>, + inputs: &RwLock>>, + pub_queue: &Sender, + sub_queue: &Sender, + input_port_id: InputPortID, + ) { + #[cfg(feature = "tracing")] + let span = trace_span!( + "ZmqTransport::start_input_worker::handle_input_event", + ?input_port_id + ); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?request, "got input request")); + + use ZmqInputPortRequest::*; + match request { + Close => { + let inputs = inputs.read().await; + let Some(input_state) = inputs.get(&input_port_id) else { + todo!(); + }; + let mut input_state = input_state.write().await; + + use ZmqInputPortState::*; + let Connected(_, ref port_events, _, _, _) = *input_state else { + return; + }; + + pub_queue + .send(ZmqTransportEvent::CloseInput(input_port_id)) + .await + .expect("input worker send close event"); + + port_events + .send(ZmqInputPortEvent::Closed) + .await + .expect("input worker send port closed"); + + *input_state = ZmqInputPortState::Closed; + + { + let mut handles = Vec::new(); + for topic in input_topics(input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?topic, "sending unsubscription request")); + + let handle = sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + handle.await.expect("input worker send unsub req"); + } + } + + response_chan + .send(Ok(())) + .await + .expect("input worker respond close") + } + } + } + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("spawning")); + + tokio::task::spawn(async move { + // Input worker loop: + // 1. Receive connection attempts and respond + // 2. Receive messages and forward to channel + // 3. Receive and handle disconnects + loop { + tokio::select! { + Some(event) = to_worker_recv.recv() => { + handle_socket_event(event, &inputs, &req_send, &pub_queue, input_port_id).await; + } + Some((request, response_chan)) = req_recv.recv() => { + handle_input_request(request, response_chan, &inputs, &pub_queue, &sub_queue, input_port_id).await; + } + }; + } + }); + + Ok(()) +} diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 880b6eeb..6dc17caf 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -9,6 +9,9 @@ pub use protoflow_core::prelude; #[path = "protoflow.zmq.rs"] mod protoflow_zmq; +mod input_port; +use input_port::*; + extern crate std; use protoflow_core::{ @@ -74,39 +77,6 @@ impl ZmqOutputPortState { } } -#[derive(Debug, Clone)] -enum ZmqInputPortState { - Open(Sender), - Connected( - // channel for requests from public close - Sender<(ZmqInputPortRequest, Sender>)>, - // channel for the public recv - Sender, - Arc>>, - // internal channel for events - Sender, - // vec of the connected port ids - Vec, - ), - Closed, -} - -#[derive(Debug, Clone)] -enum ZmqInputPortRequest { - Close, -} - -impl ZmqInputPortState { - fn state(&self) -> PortState { - use ZmqInputPortState::*; - match self { - Open(_) => PortState::Open, - Connected(_, _, _, _, _) => PortState::Connected, - Closed => PortState::Closed, - } - } -} - type SequenceID = u64; /// ZmqTransportEvent represents the data that goes over the wire from one port to another. @@ -134,14 +104,6 @@ impl ZmqTransportEvent { } } -fn input_topics(id: InputPortID) -> Vec { - vec![ - format!("{}:conn", id), - format!("{}:msg", id), - format!("{}:closeOut", id), - ] -} - fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { vec![ format!("{}:ackConn:{}", target, source), @@ -276,13 +238,6 @@ impl TryFrom for ZmqTransportEvent { } } -/// ZmqInputPortEvent represents events that we receive from the background worker of the port. -#[derive(Clone, Debug)] -enum ZmqInputPortEvent { - Message(Bytes), - Closed, -} - impl Default for ZmqTransport { fn default() -> Self { Self::new(DEFAULT_PUB_SOCKET, DEFAULT_SUB_SOCKET) @@ -410,299 +365,6 @@ impl ZmqTransport { }); } - fn start_input_worker(&self, input_port_id: InputPortID) -> Result<(), PortError> { - #[cfg(feature = "tracing")] - let span = trace_span!("ZmqTransport::start_input_worker", ?input_port_id); - - let (to_worker_send, mut to_worker_recv) = sync_channel(1); - - let (req_send, mut req_recv) = sync_channel(1); - - { - let mut inputs = self.tokio.block_on(self.inputs.write()); - if inputs.contains_key(&input_port_id) { - return Ok(()); // TODO - } - let state = ZmqInputPortState::Open(to_worker_send.clone()); - let state = RwLock::new(state); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?state, "saving new opened state")); - - inputs.insert(input_port_id, state); - } - - { - let mut handles = Vec::new(); - for topic in input_topics(input_port_id).into_iter() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?topic, "sending subscription request")); - - let handle = self - .sub_queue - .send(ZmqSubscriptionRequest::Subscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - self.tokio - .block_on(handle) - .expect("input worker send sub req"); - } - } - - let sub_queue = self.sub_queue.clone(); - let pub_queue = self.pub_queue.clone(); - let inputs = self.inputs.clone(); - - async fn handle_socket_event( - event: ZmqTransportEvent, - inputs: &RwLock>>, - req_send: &Sender<(ZmqInputPortRequest, Sender>)>, - pub_queue: &Sender, - input_port_id: InputPortID, - ) { - #[cfg(feature = "tracing")] - let span = trace_span!( - "ZmqTransport::start_input_worker::handle_socket_event", - ?input_port_id - ); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?event, "got socket event")); - - use ZmqTransportEvent::*; - match event { - Connect(output_port_id, input_port_id) => { - let inputs = inputs.read().await; - let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); - }; - let mut input_state = input_state.write().await; - - use ZmqInputPortState::*; - match &*input_state { - Open(_) => (), - Connected(_, _, _, _, connected_ids) => { - if connected_ids.iter().any(|&id| id == output_port_id) { - #[cfg(feature = "tracing")] - span.in_scope(|| { - trace!(?output_port_id, "output port is already connected") - }); - return; - } - } - Closed => return, - }; - - let add_connection = |input_state: &mut ZmqInputPortState| match input_state { - Open(to_worker_send) => { - let (msgs_send, msgs_recv) = sync_channel(1); - let msgs_recv = Arc::new(Mutex::new(msgs_recv)); - *input_state = Connected( - req_send.clone(), - msgs_send, - msgs_recv, - to_worker_send.clone(), - vec![output_port_id], - ); - } - Connected(_, _, _, _, ids) => { - ids.push(output_port_id); - } - Closed => unreachable!(), - }; - - pub_queue - .send(ZmqTransportEvent::AckConnection( - output_port_id, - input_port_id, - )) - .await - .expect("input worker send ack-conn event"); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?output_port_id, "sent conn-ack")); - - add_connection(&mut input_state); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?input_state, "connected new port")); - } - Message(output_port_id, _, seq_id, bytes) => { - #[cfg(feature = "tracing")] - let span = trace_span!(parent: &span, "Message", ?output_port_id, ?seq_id); - - let inputs = inputs.read().await; - let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); - }; - let input_state = input_state.read().await; - - use ZmqInputPortState::*; - match &*input_state { - Connected(_, sender, _, _, connected_ids) => { - if !connected_ids.iter().any(|id| *id == output_port_id) { - #[cfg(feature = "tracing")] - span.in_scope(|| { - trace!("got message from non-connected output port") - }); - return; - } - - sender - .send(ZmqInputPortEvent::Message(bytes)) - .await - .expect("input worker send message"); - - pub_queue - .send(ZmqTransportEvent::AckMessage( - output_port_id, - input_port_id, - seq_id, - )) - .await - .expect("input worker send message ack"); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!("sent msg-ack")); - } - - Open(_) | Closed => todo!(), - } - } - CloseOutput(output_port_id, input_port_id) => { - let inputs = inputs.read().await; - let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); - }; - - let mut input_state = input_state.write().await; - - use ZmqInputPortState::*; - let Connected(_, _, _, _, ref connected_ids) = *input_state else { - return; - }; - - if !connected_ids.iter().any(|id| *id == output_port_id) { - #[cfg(feature = "tracing")] - span.in_scope(|| { - trace!( - ?output_port_id, - "output port doesn't match any connected port" - ) - }); - return; - } - - match *input_state { - Open(_) | Closed => (), - Connected(_, ref sender, _, _, ref mut connected_ids) => { - connected_ids.retain(|&id| id != output_port_id); - if connected_ids.is_empty() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!("last connected port disconnected")); - sender - .send(ZmqInputPortEvent::Closed) - .await - .expect("input worker publish Closed event"); - } - } - }; - } - - // ignore, ideally we never receive these here: - AckConnection(_, _) | AckMessage(_, _, _) | CloseInput(_) => (), - } - } - - async fn handle_input_request( - request: ZmqInputPortRequest, - response_chan: Sender>, - inputs: &RwLock>>, - pub_queue: &Sender, - sub_queue: &Sender, - input_port_id: InputPortID, - ) { - #[cfg(feature = "tracing")] - let span = trace_span!( - "ZmqTransport::start_input_worker::handle_input_event", - ?input_port_id - ); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?request, "got input request")); - - use ZmqInputPortRequest::*; - match request { - Close => { - let inputs = inputs.read().await; - let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); - }; - let mut input_state = input_state.write().await; - - use ZmqInputPortState::*; - let Connected(_, ref port_events, _, _, _) = *input_state else { - return; - }; - - pub_queue - .send(ZmqTransportEvent::CloseInput(input_port_id)) - .await - .expect("input worker send close event"); - - port_events - .send(ZmqInputPortEvent::Closed) - .await - .expect("input worker send port closed"); - - *input_state = ZmqInputPortState::Closed; - - { - let mut handles = Vec::new(); - for topic in input_topics(input_port_id).into_iter() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?topic, "sending unsubscription request")); - - let handle = sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - handle.await.expect("input worker send unsub req"); - } - } - - response_chan - .send(Ok(())) - .await - .expect("input worker respond close") - } - } - } - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!("spawning")); - - tokio::task::spawn(async move { - // Input worker loop: - // 1. Receive connection attempts and respond - // 2. Receive messages and forward to channel - // 3. Receive and handle disconnects - loop { - tokio::select! { - Some(event) = to_worker_recv.recv() => { - handle_socket_event(event, &inputs, &req_send, &pub_queue, input_port_id).await; - } - Some((request, response_chan)) = req_recv.recv() => { - handle_input_request(request, response_chan, &inputs, &pub_queue, &sub_queue, input_port_id).await; - } - }; - } - }); - - Ok(()) - } - fn start_output_worker(&self, output_port_id: OutputPortID) -> Result<(), PortError> { #[cfg(feature = "tracing")] let span = trace_span!("ZmqTransport::start_output_worker", ?output_port_id); @@ -769,6 +431,9 @@ impl ZmqTransport { // 3. Send disconnect events loop { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "connect_loop"); + #[cfg(feature = "tracing")] span.in_scope(|| trace!("sending connection attempt...")); @@ -989,7 +654,7 @@ impl Transport for ZmqTransport { #[cfg(feature = "tracing")] trace!(target: "ZmqTransport::open_input", ?new_id, "created new input port"); - self.start_input_worker(new_id).map(|_| new_id) + start_input_worker(self, new_id).map(|_| new_id) } fn open_output(&self) -> PortResult { @@ -1361,7 +1026,6 @@ mod tests { let rt = tokio::runtime::Runtime::new().unwrap(); let _guard = rt.enter(); - //zeromq::proxy(frontend, backend, capture) rt.block_on(start_zmqtransport_server()); let _ = System::::build(|_s| { /* do nothing */ }); From b2c48446f35acb89bb078033b16a8a857b2b8c5e Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 14:37:24 +0200 Subject: [PATCH 44/63] Move OutputPort to it's own module --- lib/protoflow-zeromq/src/lib.rs | 307 +----------------------- lib/protoflow-zeromq/src/output_port.rs | 301 +++++++++++++++++++++++ 2 files changed, 312 insertions(+), 296 deletions(-) create mode 100644 lib/protoflow-zeromq/src/output_port.rs diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 6dc17caf..cd688864 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -12,6 +12,9 @@ mod protoflow_zmq; mod input_port; use input_port::*; +mod output_port; +use output_port::*; + extern crate std; use protoflow_core::{ @@ -22,8 +25,8 @@ use protoflow_core::{ use core::fmt::Error; use std::{format, write}; use tokio::sync::{ - mpsc::{channel as sync_channel, error::TryRecvError, Receiver, Sender}, - Mutex, RwLock, + mpsc::{channel, error::TryRecvError, Receiver, Sender}, + RwLock, }; use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqMessage}; @@ -43,40 +46,6 @@ pub struct ZmqTransport { inputs: Arc>>>, } -#[derive(Debug, Clone)] -enum ZmqOutputPortState { - Open( - Sender<(InputPortID, Sender>)>, - Sender, - ), - Connected( - // channel for public send, contained channel is for the ack back - Sender<(ZmqOutputPortRequest, Sender>)>, - // internal channel for events - Sender, - // id of the connected input port - InputPortID, - ), - Closed, -} - -#[derive(Debug, Clone)] -enum ZmqOutputPortRequest { - Close, - Send(Bytes), -} - -impl ZmqOutputPortState { - fn state(&self) -> PortState { - use ZmqOutputPortState::*; - match self { - Open(_, _) => PortState::Open, - Connected(_, _, _) => PortState::Connected, - Closed => PortState::Closed, - } - } -} - type SequenceID = u64; /// ZmqTransportEvent represents the data that goes over the wire from one port to another. @@ -104,14 +73,6 @@ impl ZmqTransportEvent { } } -fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { - vec![ - format!("{}:ackConn:{}", target, source), - format!("{}:ackMsg:{}:", target, source), - format!("{}:closeIn", target), - ] -} - impl From for ZmqMessage { fn from(value: ZmqTransportEvent) -> Self { let mut topic = Vec::new(); @@ -282,7 +243,7 @@ impl ZmqTransport { let outputs = Arc::new(RwLock::new(BTreeMap::default())); let inputs = Arc::new(RwLock::new(BTreeMap::default())); - let (pub_queue, pub_queue_recv) = sync_channel(1); + let (pub_queue, pub_queue_recv) = channel(1); let (sub_queue, sub_queue_recv) = tokio::sync::mpsc::channel(1); @@ -364,252 +325,6 @@ impl ZmqTransport { } }); } - - fn start_output_worker(&self, output_port_id: OutputPortID) -> Result<(), PortError> { - #[cfg(feature = "tracing")] - let span = trace_span!("ZmqTransport::start_output_worker", ?output_port_id); - - let (conn_send, mut conn_recv) = sync_channel(1); - - let (to_worker_send, mut to_worker_recv) = sync_channel(1); - - { - let mut outputs = self.tokio.block_on(self.outputs.write()); - - if outputs.contains_key(&output_port_id) { - return Ok(()); // TODO - } - let state = ZmqOutputPortState::Open(conn_send, to_worker_send.clone()); - let state = RwLock::new(state); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?state, "saving new opened state")); - - outputs.insert(output_port_id, state); - } - - let sub_queue = self.sub_queue.clone(); - let pub_queue = self.pub_queue.clone(); - let outputs = self.outputs.clone(); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!("spawning")); - - tokio::task::spawn(async move { - let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { - todo!(); - }; - - #[cfg(feature = "tracing")] - let span = trace_span!( - "ZmqTransport::start_output_worker::spawn", - ?output_port_id, - ?input_port_id - ); - - { - let mut handles = Vec::new(); - for topic in output_topics(output_port_id, input_port_id).into_iter() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?topic, "sending subscription request")); - - let handle = sub_queue.send(ZmqSubscriptionRequest::Subscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - handle.await.expect("output worker send sub req"); - } - } - - let (msg_req_send, mut msg_req_recv) = sync_channel(1); - - // Output worker loop: - // 1. Send connection attempt - // 2. Send messages - // 2.1 Wait for ACK - // 2.2. Resend on timeout - // 3. Send disconnect events - - loop { - #[cfg(feature = "tracing")] - let span = trace_span!(parent: &span, "connect_loop"); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!("sending connection attempt...")); - - pub_queue - .send(ZmqTransportEvent::Connect(output_port_id, input_port_id)) - .await - .expect("output worker send connect event"); - - let response = to_worker_recv - .recv() - .await - .expect("output worker recv ack-conn event"); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?response, "got response")); - - use ZmqTransportEvent::*; - match response { - AckConnection(_, input_port_id) => { - let outputs = outputs.read().await; - let Some(output_state) = outputs.get(&output_port_id) else { - todo!(); - }; - let mut output_state = output_state.write().await; - debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); - *output_state = ZmqOutputPortState::Connected( - msg_req_send, - to_worker_send, - input_port_id, - ); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?output_state, "Connected!")); - - conn_confirm - .send(Ok(())) - .await - .expect("output worker respond conn"); - drop(conn_confirm); - - break; - } - _ => continue, - } - } - - let mut seq_id = 1; - 'send: loop { - #[cfg(feature = "tracing")] - let span = trace_span!(parent: &span, "send_loop", ?seq_id); - - let (request, response_chan) = msg_req_recv - .recv() - .await - .expect("output worker recv msg req"); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?request, "sending request")); - - match request { - ZmqOutputPortRequest::Close => { - let response = pub_queue - .send(ZmqTransportEvent::CloseOutput( - output_port_id, - input_port_id, - )) - .await - .map_err(|e| PortError::Other(e.to_string())); - - { - let mut handles = Vec::new(); - for topic in output_topics(output_port_id, input_port_id).into_iter() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?topic, "sending unsubscription request")); - let handle = - sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - handle.await.expect("output worker send unsub req"); - } - } - - response_chan - .send(response) - .await - .expect("output worker respond close"); - } - ZmqOutputPortRequest::Send(bytes) => { - pub_queue - .send(ZmqTransportEvent::Message( - output_port_id, - input_port_id, - seq_id, - bytes, - )) - .await - .expect("output worker send message event"); - - 'recv: loop { - let event = to_worker_recv - .recv() - .await - .expect("output worker event recv"); - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?event, "received event")); - - use ZmqTransportEvent::*; - match event { - AckMessage(_, _, ack_id) => { - if ack_id == seq_id { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?ack_id, "msg-ack matches")); - response_chan - .send(Ok(())) - .await - .expect("output worker respond send"); - break 'recv; - } - } - CloseInput(_) => { - let outputs = outputs.read().await; - let Some(output_state) = outputs.get(&output_port_id) else { - todo!(); - }; - let mut output_state = output_state.write().await; - debug_assert!(matches!( - *output_state, - ZmqOutputPortState::Connected(..) - )); - *output_state = ZmqOutputPortState::Closed; - - { - let mut handles = Vec::new(); - for topic in - output_topics(output_port_id, input_port_id).into_iter() - { - #[cfg(feature = "tracing")] - span.in_scope(|| { - trace!(?topic, "sending unsubscription request") - }); - let handle = sub_queue - .send(ZmqSubscriptionRequest::Unsubscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - handle.await.expect("output worker send unsub req"); - } - } - - response_chan - .send(Err(PortError::Disconnected)) - .await - .expect("output worker respond msg"); - - break 'send; - } - - // ignore others, we shouldn't receive any new conn-acks - // nor should we be receiving input port events - AckConnection(_, _) - | Connect(_, _) - | Message(_, _, _, _) - | CloseOutput(_, _) => continue, - } - } - } - } - - seq_id += 1; - } - }); - - Ok(()) - } } impl Transport for ZmqTransport { @@ -670,7 +385,7 @@ impl Transport for ZmqTransport { #[cfg(feature = "tracing")] trace!(target: "ZmqTransport::open_output", ?new_id, "created new output port"); - self.start_output_worker(new_id).map(|_| new_id) + start_output_worker(self, new_id).map(|_| new_id) } fn close_input(&self, input: InputPortID) -> PortResult { @@ -689,7 +404,7 @@ impl Transport for ZmqTransport { sender.clone() }; - let (close_send, mut close_recv) = sync_channel(1); + let (close_send, mut close_recv) = channel(1); sender .send((ZmqInputPortRequest::Close, close_send)) @@ -720,7 +435,7 @@ impl Transport for ZmqTransport { sender.clone() }; - let (close_send, mut close_recv) = sync_channel(1); + let (close_send, mut close_recv) = channel(1); sender .send((ZmqOutputPortRequest::Close, close_send)) @@ -754,7 +469,7 @@ impl Transport for ZmqTransport { sender.clone() }; - let (confirm_send, mut confirm_recv) = sync_channel(1); + let (confirm_send, mut confirm_recv) = channel(1); sender .send((target, confirm_send)) @@ -788,7 +503,7 @@ impl Transport for ZmqTransport { sender.clone() }; - let (ack_send, mut ack_recv) = sync_channel(1); + let (ack_send, mut ack_recv) = channel(1); sender .send((ZmqOutputPortRequest::Send(message), ack_send)) diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs new file mode 100644 index 00000000..3f5d1e42 --- /dev/null +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -0,0 +1,301 @@ +// This is free and unencumbered software released into the public domain. + +use crate::{ZmqSubscriptionRequest, ZmqTransport, ZmqTransportEvent}; +use protoflow_core::{ + prelude::{format, vec, Bytes, String, ToString, Vec}, + InputPortID, OutputPortID, PortError, PortState, +}; +use tokio::sync::{ + mpsc::{channel, Sender}, + RwLock, +}; + +#[cfg(feature = "tracing")] +use tracing::{trace, trace_span}; + +#[derive(Debug, Clone)] +pub enum ZmqOutputPortState { + Open( + Sender<(InputPortID, Sender>)>, + Sender, + ), + Connected( + // channel for public send, contained channel is for the ack back + Sender<(ZmqOutputPortRequest, Sender>)>, + // internal channel for events + Sender, + // id of the connected input port + InputPortID, + ), + Closed, +} + +#[derive(Debug, Clone)] +pub enum ZmqOutputPortRequest { + Close, + Send(Bytes), +} + +impl ZmqOutputPortState { + pub fn state(&self) -> PortState { + use ZmqOutputPortState::*; + match self { + Open(_, _) => PortState::Open, + Connected(_, _, _) => PortState::Connected, + Closed => PortState::Closed, + } + } +} + +fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { + vec![ + format!("{}:ackConn:{}", target, source), + format!("{}:ackMsg:{}:", target, source), + format!("{}:closeIn", target), + ] +} + +pub fn start_output_worker( + transport: &ZmqTransport, + output_port_id: OutputPortID, +) -> Result<(), PortError> { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::start_output_worker", ?output_port_id); + + let (conn_send, mut conn_recv) = channel(1); + + let (to_worker_send, mut to_worker_recv) = channel(1); + + { + let mut outputs = transport.tokio.block_on(transport.outputs.write()); + + if outputs.contains_key(&output_port_id) { + return Ok(()); // TODO + } + let state = ZmqOutputPortState::Open(conn_send, to_worker_send.clone()); + let state = RwLock::new(state); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?state, "saving new opened state")); + + outputs.insert(output_port_id, state); + } + + let sub_queue = transport.sub_queue.clone(); + let pub_queue = transport.pub_queue.clone(); + let outputs = transport.outputs.clone(); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("spawning")); + + tokio::task::spawn(async move { + let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { + todo!(); + }; + + #[cfg(feature = "tracing")] + let span = trace_span!( + "ZmqTransport::start_output_worker::spawn", + ?output_port_id, + ?input_port_id + ); + + { + let mut handles = Vec::new(); + for topic in output_topics(output_port_id, input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?topic, "sending subscription request")); + + let handle = sub_queue.send(ZmqSubscriptionRequest::Subscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + handle.await.expect("output worker send sub req"); + } + } + + let (msg_req_send, mut msg_req_recv) = channel(1); + + // Output worker loop: + // 1. Send connection attempt + // 2. Send messages + // 2.1 Wait for ACK + // 2.2. Resend on timeout + // 3. Send disconnect events + + loop { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "connect_loop"); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("sending connection attempt...")); + + pub_queue + .send(ZmqTransportEvent::Connect(output_port_id, input_port_id)) + .await + .expect("output worker send connect event"); + + let response = to_worker_recv + .recv() + .await + .expect("output worker recv ack-conn event"); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?response, "got response")); + + use ZmqTransportEvent::*; + match response { + AckConnection(_, input_port_id) => { + let outputs = outputs.read().await; + let Some(output_state) = outputs.get(&output_port_id) else { + todo!(); + }; + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); + *output_state = + ZmqOutputPortState::Connected(msg_req_send, to_worker_send, input_port_id); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?output_state, "Connected!")); + + conn_confirm + .send(Ok(())) + .await + .expect("output worker respond conn"); + drop(conn_confirm); + + break; + } + _ => continue, + } + } + + let mut seq_id = 1; + 'send: loop { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "send_loop", ?seq_id); + + let (request, response_chan) = msg_req_recv + .recv() + .await + .expect("output worker recv msg req"); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?request, "sending request")); + + match request { + ZmqOutputPortRequest::Close => { + let response = pub_queue + .send(ZmqTransportEvent::CloseOutput( + output_port_id, + input_port_id, + )) + .await + .map_err(|e| PortError::Other(e.to_string())); + + { + let mut handles = Vec::new(); + for topic in output_topics(output_port_id, input_port_id).into_iter() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?topic, "sending unsubscription request")); + let handle = sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + handle.await.expect("output worker send unsub req"); + } + } + + response_chan + .send(response) + .await + .expect("output worker respond close"); + } + ZmqOutputPortRequest::Send(bytes) => { + pub_queue + .send(ZmqTransportEvent::Message( + output_port_id, + input_port_id, + seq_id, + bytes, + )) + .await + .expect("output worker send message event"); + + 'recv: loop { + let event = to_worker_recv + .recv() + .await + .expect("output worker event recv"); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "received event")); + + use ZmqTransportEvent::*; + match event { + AckMessage(_, _, ack_id) => { + if ack_id == seq_id { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?ack_id, "msg-ack matches")); + response_chan + .send(Ok(())) + .await + .expect("output worker respond send"); + break 'recv; + } + } + CloseInput(_) => { + let outputs = outputs.read().await; + let Some(output_state) = outputs.get(&output_port_id) else { + todo!(); + }; + let mut output_state = output_state.write().await; + debug_assert!(matches!( + *output_state, + ZmqOutputPortState::Connected(..) + )); + *output_state = ZmqOutputPortState::Closed; + + { + let mut handles = Vec::new(); + for topic in + output_topics(output_port_id, input_port_id).into_iter() + { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!(?topic, "sending unsubscription request") + }); + let handle = sub_queue + .send(ZmqSubscriptionRequest::Unsubscribe(topic)); + handles.push(handle); + } + for handle in handles.into_iter() { + handle.await.expect("output worker send unsub req"); + } + } + + response_chan + .send(Err(PortError::Disconnected)) + .await + .expect("output worker respond msg"); + + break 'send; + } + + // ignore others, we shouldn't receive any new conn-acks + // nor should we be receiving input port events + AckConnection(_, _) + | Connect(_, _) + | Message(_, _, _, _) + | CloseOutput(_, _) => continue, + } + } + } + } + + seq_id += 1; + } + }); + + Ok(()) +} From e141688706c72ed4c0edc8043d985af941777f6c Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 14:46:45 +0200 Subject: [PATCH 45/63] Move the socket helpers and messages to their own modules --- lib/protoflow-zeromq/src/event.rs | 160 +++++++++++++ lib/protoflow-zeromq/src/lib.rs | 369 +---------------------------- lib/protoflow-zeromq/src/socket.rs | 205 ++++++++++++++++ 3 files changed, 377 insertions(+), 357 deletions(-) create mode 100644 lib/protoflow-zeromq/src/event.rs create mode 100644 lib/protoflow-zeromq/src/socket.rs diff --git a/lib/protoflow-zeromq/src/event.rs b/lib/protoflow-zeromq/src/event.rs new file mode 100644 index 00000000..1264d6d4 --- /dev/null +++ b/lib/protoflow-zeromq/src/event.rs @@ -0,0 +1,160 @@ +// This is free and unencumbered software released into the public domain. + +use protoflow_core::{ + prelude::{Bytes, Vec}, + InputPortID, OutputPortID, +}; +use zeromq::ZmqMessage; + +pub type SequenceID = u64; + +/// ZmqTransportEvent represents the data that goes over the wire from one port to another. +#[derive(Clone, Debug)] +pub enum ZmqTransportEvent { + Connect(OutputPortID, InputPortID), + AckConnection(OutputPortID, InputPortID), + Message(OutputPortID, InputPortID, SequenceID, Bytes), + AckMessage(OutputPortID, InputPortID, SequenceID), + CloseOutput(OutputPortID, InputPortID), + CloseInput(InputPortID), +} + +impl ZmqTransportEvent { + fn write_topic(&self, f: &mut W) -> Result<(), std::io::Error> { + use ZmqTransportEvent::*; + match self { + Connect(o, i) => write!(f, "{}:conn:{}", i, o), + AckConnection(o, i) => write!(f, "{}:ackConn:{}", i, o), + Message(o, i, seq, _) => write!(f, "{}:msg:{}:{}", i, o, seq), + AckMessage(o, i, seq) => write!(f, "{}:ackMsg:{}:{}", i, o, seq), + CloseOutput(o, i) => write!(f, "{}:closeOut:{}", i, o), + CloseInput(i) => write!(f, "{}:closeIn", i), + } + } +} + +impl From for ZmqMessage { + fn from(value: ZmqTransportEvent) -> Self { + let mut topic = Vec::new(); + value.write_topic(&mut topic).unwrap(); + + // first frame of the message is the topic + let mut msg = ZmqMessage::from(topic); + + fn map_id(id: T) -> i64 + where + isize: From, + { + isize::from(id) as i64 + } + + // second frame of the message is the payload + use crate::protoflow_zmq::{self, event::Payload, Event}; + use prost::Message; + use ZmqTransportEvent::*; + let payload = match value { + Connect(output, input) => Payload::Connect(protoflow_zmq::Connect { + output: map_id(output), + input: map_id(input), + }), + AckConnection(output, input) => Payload::AckConnection(protoflow_zmq::AckConnection { + output: map_id(output), + input: map_id(input), + }), + Message(output, input, sequence, message) => Payload::Message(protoflow_zmq::Message { + output: map_id(output), + input: map_id(input), + sequence, + message: message.to_vec(), + }), + AckMessage(output, input, sequence) => Payload::AckMessage(protoflow_zmq::AckMessage { + output: map_id(output), + input: map_id(input), + sequence, + }), + CloseOutput(output, input) => Payload::CloseOutput(protoflow_zmq::CloseOutput { + output: map_id(output), + input: map_id(input), + }), + CloseInput(input) => Payload::CloseInput(protoflow_zmq::CloseInput { + input: map_id(input), + }), + }; + + let bytes = Event { + payload: Some(payload), + } + .encode_to_vec(); + msg.push_back(bytes.into()); + + msg + } +} + +impl TryFrom for ZmqTransportEvent { + type Error = protoflow_core::DecodeError; + + fn try_from(value: ZmqMessage) -> Result { + use crate::protoflow_zmq::{self, event::Payload, Event}; + use prost::Message; + use protoflow_core::DecodeError; + + fn map_id(id: i64) -> Result + where + T: TryFrom, + std::borrow::Cow<'static, str>: From<>::Error>, + { + (id as isize).try_into().map_err(DecodeError::new) + } + + value + .get(1) + .ok_or_else(|| { + protoflow_core::DecodeError::new( + "message from socket contains less than two frames", + ) + }) + .and_then(|bytes| { + let event = Event::decode(bytes.as_ref())?; + + use ZmqTransportEvent::*; + Ok(match event.payload { + None => todo!(), + Some(Payload::Connect(protoflow_zmq::Connect { output, input })) => { + Connect(map_id(output)?, map_id(input)?) + } + + Some(Payload::AckConnection(protoflow_zmq::AckConnection { + output, + input, + })) => AckConnection(map_id(output)?, map_id(input)?), + + Some(Payload::Message(protoflow_zmq::Message { + output, + input, + sequence, + message, + })) => Message( + map_id(output)?, + map_id(input)?, + sequence, + Bytes::from(message), + ), + + Some(Payload::AckMessage(protoflow_zmq::AckMessage { + output, + input, + sequence, + })) => AckMessage(map_id(output)?, map_id(input)?, sequence), + + Some(Payload::CloseOutput(protoflow_zmq::CloseOutput { output, input })) => { + CloseOutput(map_id(output)?, map_id(input)?) + } + + Some(Payload::CloseInput(protoflow_zmq::CloseInput { input })) => { + CloseInput(map_id(input)?) + } + }) + }) + } +} diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index cd688864..898f061d 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -15,23 +15,27 @@ use input_port::*; mod output_port; use output_port::*; +mod socket; +use socket::*; + +mod event; +use event::*; + extern crate std; use protoflow_core::{ - prelude::{vec, Arc, BTreeMap, Bytes, String, ToString, Vec}, + prelude::{Arc, BTreeMap, Bytes, ToString}, InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, }; -use core::fmt::Error; -use std::{format, write}; use tokio::sync::{ - mpsc::{channel, error::TryRecvError, Receiver, Sender}, + mpsc::{channel, error::TryRecvError, Sender}, RwLock, }; -use zeromq::{util::PeerIdentity, Socket, SocketOptions, SocketRecv, SocketSend, ZmqMessage}; +use zeromq::{util::PeerIdentity, Socket, SocketOptions}; #[cfg(feature = "tracing")] -use tracing::{trace, trace_span}; +use tracing::trace; const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; @@ -46,171 +50,12 @@ pub struct ZmqTransport { inputs: Arc>>>, } -type SequenceID = u64; - -/// ZmqTransportEvent represents the data that goes over the wire from one port to another. -#[derive(Clone, Debug)] -enum ZmqTransportEvent { - Connect(OutputPortID, InputPortID), - AckConnection(OutputPortID, InputPortID), - Message(OutputPortID, InputPortID, SequenceID, Bytes), - AckMessage(OutputPortID, InputPortID, SequenceID), - CloseOutput(OutputPortID, InputPortID), - CloseInput(InputPortID), -} - -impl ZmqTransportEvent { - fn write_topic(&self, f: &mut W) -> Result<(), std::io::Error> { - use ZmqTransportEvent::*; - match self { - Connect(o, i) => write!(f, "{}:conn:{}", i, o), - AckConnection(o, i) => write!(f, "{}:ackConn:{}", i, o), - Message(o, i, seq, _) => write!(f, "{}:msg:{}:{}", i, o, seq), - AckMessage(o, i, seq) => write!(f, "{}:ackMsg:{}:{}", i, o, seq), - CloseOutput(o, i) => write!(f, "{}:closeOut:{}", i, o), - CloseInput(i) => write!(f, "{}:closeIn", i), - } - } -} - -impl From for ZmqMessage { - fn from(value: ZmqTransportEvent) -> Self { - let mut topic = Vec::new(); - value.write_topic(&mut topic).unwrap(); - - // first frame of the message is the topic - let mut msg = ZmqMessage::from(topic); - - fn map_id(id: T) -> i64 - where - isize: From, - { - isize::from(id) as i64 - } - - // second frame of the message is the payload - use prost::Message; - use protoflow_zmq::{event::Payload, Event}; - use ZmqTransportEvent::*; - let payload = match value { - Connect(output, input) => Payload::Connect(protoflow_zmq::Connect { - output: map_id(output), - input: map_id(input), - }), - AckConnection(output, input) => Payload::AckConnection(protoflow_zmq::AckConnection { - output: map_id(output), - input: map_id(input), - }), - Message(output, input, sequence, message) => Payload::Message(protoflow_zmq::Message { - output: map_id(output), - input: map_id(input), - sequence, - message: message.to_vec(), - }), - AckMessage(output, input, sequence) => Payload::AckMessage(protoflow_zmq::AckMessage { - output: map_id(output), - input: map_id(input), - sequence, - }), - CloseOutput(output, input) => Payload::CloseOutput(protoflow_zmq::CloseOutput { - output: map_id(output), - input: map_id(input), - }), - CloseInput(input) => Payload::CloseInput(protoflow_zmq::CloseInput { - input: map_id(input), - }), - }; - - let bytes = Event { - payload: Some(payload), - } - .encode_to_vec(); - msg.push_back(bytes.into()); - - msg - } -} - -impl TryFrom for ZmqTransportEvent { - type Error = protoflow_core::DecodeError; - - fn try_from(value: ZmqMessage) -> Result { - use prost::Message; - use protoflow_core::DecodeError; - use protoflow_zmq::{event::Payload, Event}; - - fn map_id(id: i64) -> Result - where - T: TryFrom, - std::borrow::Cow<'static, str>: From<>::Error>, - { - (id as isize).try_into().map_err(DecodeError::new) - } - - value - .get(1) - .ok_or_else(|| { - protoflow_core::DecodeError::new( - "message from socket contains less than two frames", - ) - }) - .and_then(|bytes| { - let event = Event::decode(bytes.as_ref())?; - - use ZmqTransportEvent::*; - Ok(match event.payload { - None => todo!(), - Some(Payload::Connect(protoflow_zmq::Connect { output, input })) => { - Connect(map_id(output)?, map_id(input)?) - } - - Some(Payload::AckConnection(protoflow_zmq::AckConnection { - output, - input, - })) => AckConnection(map_id(output)?, map_id(input)?), - - Some(Payload::Message(protoflow_zmq::Message { - output, - input, - sequence, - message, - })) => Message( - map_id(output)?, - map_id(input)?, - sequence, - Bytes::from(message), - ), - - Some(Payload::AckMessage(protoflow_zmq::AckMessage { - output, - input, - sequence, - })) => AckMessage(map_id(output)?, map_id(input)?, sequence), - - Some(Payload::CloseOutput(protoflow_zmq::CloseOutput { output, input })) => { - CloseOutput(map_id(output)?, map_id(input)?) - } - - Some(Payload::CloseInput(protoflow_zmq::CloseInput { input })) => { - CloseInput(map_id(input)?) - } - }) - }) - } -} - impl Default for ZmqTransport { fn default() -> Self { Self::new(DEFAULT_PUB_SOCKET, DEFAULT_SUB_SOCKET) } } -#[derive(Clone, Debug)] -enum ZmqSubscriptionRequest { - Subscribe(String), - Unsubscribe(String), -} - impl ZmqTransport { pub fn new(pub_url: &str, sub_url: &str) -> Self { let tokio = tokio::runtime::Handle::current(); @@ -255,76 +100,11 @@ impl ZmqTransport { inputs, }; - transport.start_pub_socket_worker(psock, pub_queue_recv); - transport.start_sub_socket_worker(ssock, sub_queue_recv); + start_pub_socket_worker(psock, pub_queue_recv); + start_sub_socket_worker(&transport, ssock, sub_queue_recv); transport } - - fn start_pub_socket_worker( - &self, - psock: zeromq::PubSocket, - pub_queue: Receiver, - ) { - let mut psock = psock; - let mut pub_queue = pub_queue; - tokio::task::spawn(async move { - while let Some(event) = pub_queue.recv().await { - #[cfg(feature = "tracing")] - trace!( - target: "ZmqTransport::pub_socket", - ?event, - "sending event to socket" - ); - - psock - .send(event.into()) - .await - .expect("zmq pub-socket worker") - } - }); - } - - fn start_sub_socket_worker( - &self, - ssock: zeromq::SubSocket, - sub_queue: Receiver, - ) { - let outputs = self.outputs.clone(); - let inputs = self.inputs.clone(); - let mut ssock = ssock; - let mut sub_queue = sub_queue; - tokio::task::spawn(async move { - loop { - tokio::select! { - Ok(msg) = ssock.recv() => { - #[cfg(feature = "tracing")] - trace!( - target: "ZmqTransport::sub_socket", - ?msg, - "got message from socket" - ); - - handle_zmq_msg(msg, &outputs, &inputs).await.unwrap() - }, - Some(req) = sub_queue.recv() => { - #[cfg(feature = "tracing")] - trace!( - target: "ZmqTransport::sub_socket", - ?req, - "got sub update request" - ); - - use ZmqSubscriptionRequest::*; - match req { - Subscribe(topic) => ssock.subscribe(&topic).await.expect("zmq recv worker subscribe"), - Unsubscribe(topic) => ssock.unsubscribe(&topic).await.expect("zmq recv worker unsubscribe"), - }; - } - }; - } - }); - } } impl Transport for ZmqTransport { @@ -577,131 +357,6 @@ impl Transport for ZmqTransport { } } -async fn handle_zmq_msg( - msg: ZmqMessage, - outputs: &RwLock>>, - inputs: &RwLock>>, -) -> Result<(), Error> { - let Ok(event) = ZmqTransportEvent::try_from(msg) else { - todo!(); - }; - - #[cfg(feature = "tracing")] - trace!(target: "handle_zmq_msg", ?event, "got event"); - - use ZmqTransportEvent::*; - match event { - // input ports - Connect(_, input_port_id) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; - let input = input.read().await; - - use ZmqInputPortState::*; - match &*input { - Closed => todo!(), - Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), - } - }; - - sender.send(event).await.unwrap(); - } - Message(_, input_port_id, _, _) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; - - let input = input.read().await; - let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { - todo!(); - }; - - sender.clone() - }; - - sender.send(event).await.unwrap(); - } - CloseOutput(_, input_port_id) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - todo!(); - }; - let input = input.read().await; - - use ZmqInputPortState::*; - match &*input { - Closed => todo!(), - Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), - } - }; - - sender.send(event).await.unwrap(); - } - - // output ports - AckConnection(output_port_id, _) => { - let sender = { - let outputs = outputs.read().await; - let Some(output) = outputs.get(&output_port_id) else { - todo!(); - }; - let output = output.read().await; - - let ZmqOutputPortState::Open(_, sender) = &*output else { - todo!(); - }; - - sender.clone() - }; - - sender.send(event).await.unwrap(); - } - AckMessage(output_port_id, _, _) => { - let sender = { - let outputs = outputs.read().await; - let Some(output) = outputs.get(&output_port_id) else { - todo!(); - }; - let output = output.read().await; - let ZmqOutputPortState::Connected(_, sender, _) = &*output else { - todo!(); - }; - - sender.clone() - }; - - sender.send(event).await.unwrap(); - } - CloseInput(input_port_id) => { - for (_, state) in outputs.read().await.iter() { - let sender = { - let state = state.read().await; - let ZmqOutputPortState::Connected(_, ref sender, ref id) = *state else { - continue; - }; - if *id != input_port_id { - continue; - } - - sender.clone() - }; - - if let Err(_e) = sender.send(event.clone()).await { - continue; // TODO - } - } - } - } - - Ok(()) -} - #[cfg(test)] mod tests { use super::*; diff --git a/lib/protoflow-zeromq/src/socket.rs b/lib/protoflow-zeromq/src/socket.rs new file mode 100644 index 00000000..9cb0f12d --- /dev/null +++ b/lib/protoflow-zeromq/src/socket.rs @@ -0,0 +1,205 @@ +// This is free and unencumbered software released into the public domain. + +use crate::{ZmqInputPortState, ZmqOutputPortState, ZmqTransport, ZmqTransportEvent}; +use core::fmt::Error; +use protoflow_core::{ + prelude::{BTreeMap, String}, + InputPortID, OutputPortID, +}; +use tokio::sync::{mpsc::Receiver, RwLock}; +use zeromq::{SocketRecv, SocketSend, ZmqMessage}; + +#[derive(Clone, Debug)] +pub enum ZmqSubscriptionRequest { + Subscribe(String), + Unsubscribe(String), +} + +#[cfg(feature = "tracing")] +use tracing::trace; + +pub fn start_pub_socket_worker(psock: zeromq::PubSocket, pub_queue: Receiver) { + let mut psock = psock; + let mut pub_queue = pub_queue; + tokio::task::spawn(async move { + while let Some(event) = pub_queue.recv().await { + #[cfg(feature = "tracing")] + trace!( + target: "ZmqTransport::pub_socket", + ?event, + "sending event to socket" + ); + + psock + .send(event.into()) + .await + .expect("zmq pub-socket worker") + } + }); +} + +pub fn start_sub_socket_worker( + transport: &ZmqTransport, + ssock: zeromq::SubSocket, + sub_queue: Receiver, +) { + let outputs = transport.outputs.clone(); + let inputs = transport.inputs.clone(); + let mut ssock = ssock; + let mut sub_queue = sub_queue; + tokio::task::spawn(async move { + loop { + tokio::select! { + Ok(msg) = ssock.recv() => { + #[cfg(feature = "tracing")] + trace!( + target: "ZmqTransport::sub_socket", + ?msg, + "got message from socket" + ); + + handle_zmq_msg(msg, &outputs, &inputs).await.unwrap() + }, + Some(req) = sub_queue.recv() => { + #[cfg(feature = "tracing")] + trace!( + target: "ZmqTransport::sub_socket", + ?req, + "got sub update request" + ); + + use ZmqSubscriptionRequest::*; + match req { + Subscribe(topic) => ssock.subscribe(&topic).await.expect("zmq recv worker subscribe"), + Unsubscribe(topic) => ssock.unsubscribe(&topic).await.expect("zmq recv worker unsubscribe"), + }; + } + }; + } + }); +} + +async fn handle_zmq_msg( + msg: ZmqMessage, + outputs: &RwLock>>, + inputs: &RwLock>>, +) -> Result<(), Error> { + let Ok(event) = ZmqTransportEvent::try_from(msg) else { + todo!(); + }; + + #[cfg(feature = "tracing")] + trace!(target: "handle_zmq_msg", ?event, "got event"); + + use ZmqTransportEvent::*; + match event { + // input ports + Connect(_, input_port_id) => { + let sender = { + let inputs = inputs.read().await; + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + let input = input.read().await; + + use ZmqInputPortState::*; + match &*input { + Closed => todo!(), + Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), + } + }; + + sender.send(event).await.unwrap(); + } + Message(_, input_port_id, _, _) => { + let sender = { + let inputs = inputs.read().await; + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + + let input = input.read().await; + let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { + todo!(); + }; + + sender.clone() + }; + + sender.send(event).await.unwrap(); + } + CloseOutput(_, input_port_id) => { + let sender = { + let inputs = inputs.read().await; + let Some(input) = inputs.get(&input_port_id) else { + todo!(); + }; + let input = input.read().await; + + use ZmqInputPortState::*; + match &*input { + Closed => todo!(), + Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), + } + }; + + sender.send(event).await.unwrap(); + } + + // output ports + AckConnection(output_port_id, _) => { + let sender = { + let outputs = outputs.read().await; + let Some(output) = outputs.get(&output_port_id) else { + todo!(); + }; + let output = output.read().await; + + let ZmqOutputPortState::Open(_, sender) = &*output else { + todo!(); + }; + + sender.clone() + }; + + sender.send(event).await.unwrap(); + } + AckMessage(output_port_id, _, _) => { + let sender = { + let outputs = outputs.read().await; + let Some(output) = outputs.get(&output_port_id) else { + todo!(); + }; + let output = output.read().await; + let ZmqOutputPortState::Connected(_, sender, _) = &*output else { + todo!(); + }; + + sender.clone() + }; + + sender.send(event).await.unwrap(); + } + CloseInput(input_port_id) => { + for (_, state) in outputs.read().await.iter() { + let sender = { + let state = state.read().await; + let ZmqOutputPortState::Connected(_, ref sender, ref id) = *state else { + continue; + }; + if *id != input_port_id { + continue; + } + + sender.clone() + }; + + if let Err(_e) = sender.send(event.clone()).await { + continue; // TODO + } + } + } + } + + Ok(()) +} From 1ae285faebbea29f944c506a2d25cf502e0c904c Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Mon, 2 Dec 2024 14:54:32 +0200 Subject: [PATCH 46/63] Refactor topic subscriptions --- lib/protoflow-zeromq/src/input_port.rs | 44 +++++------------- lib/protoflow-zeromq/src/lib.rs | 2 +- lib/protoflow-zeromq/src/output_port.rs | 62 +++++-------------------- lib/protoflow-zeromq/src/socket.rs | 35 +++++++++++++- 4 files changed, 57 insertions(+), 86 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index 510b8c35..ec602538 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -1,6 +1,8 @@ // This is free and unencumbered software released into the public domain. -use crate::{ZmqSubscriptionRequest, ZmqTransport, ZmqTransportEvent}; +use crate::{ + subscribe_topics, unsubscribe_topics, ZmqSubscriptionRequest, ZmqTransport, ZmqTransportEvent, +}; use protoflow_core::{ prelude::{format, vec, Arc, BTreeMap, Bytes, String, Vec}, InputPortID, OutputPortID, PortError, PortState, @@ -86,29 +88,16 @@ pub fn start_input_worker( inputs.insert(input_port_id, state); } - { - let mut handles = Vec::new(); - for topic in input_topics(input_port_id).into_iter() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?topic, "sending subscription request")); - - let handle = transport - .sub_queue - .send(ZmqSubscriptionRequest::Subscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - transport - .tokio - .block_on(handle) - .expect("input worker send sub req"); - } - } - let sub_queue = transport.sub_queue.clone(); let pub_queue = transport.pub_queue.clone(); let inputs = transport.inputs.clone(); + let topics = input_topics(input_port_id); + transport + .tokio + .block_on(subscribe_topics(&topics, &sub_queue)) + .unwrap(); + async fn handle_socket_event( event: ZmqTransportEvent, inputs: &RwLock>>, @@ -311,19 +300,8 @@ pub fn start_input_worker( *input_state = ZmqInputPortState::Closed; - { - let mut handles = Vec::new(); - for topic in input_topics(input_port_id).into_iter() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?topic, "sending unsubscription request")); - - let handle = sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - handle.await.expect("input worker send unsub req"); - } - } + let topics = input_topics(input_port_id); + unsubscribe_topics(&topics, sub_queue).await.unwrap(); response_chan .send(Ok(())) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 898f061d..e6635bd1 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -90,7 +90,7 @@ impl ZmqTransport { let (pub_queue, pub_queue_recv) = channel(1); - let (sub_queue, sub_queue_recv) = tokio::sync::mpsc::channel(1); + let (sub_queue, sub_queue_recv) = channel(1); let transport = Self { pub_queue, diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs index 3f5d1e42..93fcfc62 100644 --- a/lib/protoflow-zeromq/src/output_port.rs +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -1,6 +1,6 @@ // This is free and unencumbered software released into the public domain. -use crate::{ZmqSubscriptionRequest, ZmqTransport, ZmqTransportEvent}; +use crate::{subscribe_topics, unsubscribe_topics, ZmqTransport, ZmqTransportEvent}; use protoflow_core::{ prelude::{format, vec, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortState, @@ -13,6 +13,11 @@ use tokio::sync::{ #[cfg(feature = "tracing")] use tracing::{trace, trace_span}; +#[derive(Debug, Clone)] +pub enum ZmqOutputPortRequest { + Close, + Send(Bytes), +} #[derive(Debug, Clone)] pub enum ZmqOutputPortState { Open( @@ -30,12 +35,6 @@ pub enum ZmqOutputPortState { Closed, } -#[derive(Debug, Clone)] -pub enum ZmqOutputPortRequest { - Close, - Send(Bytes), -} - impl ZmqOutputPortState { pub fn state(&self) -> PortState { use ZmqOutputPortState::*; @@ -90,7 +89,8 @@ pub fn start_output_worker( tokio::task::spawn(async move { let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { - todo!(); + // all senders have dropped, i.e. there's no connection request coming + return; }; #[cfg(feature = "tracing")] @@ -100,19 +100,8 @@ pub fn start_output_worker( ?input_port_id ); - { - let mut handles = Vec::new(); - for topic in output_topics(output_port_id, input_port_id).into_iter() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?topic, "sending subscription request")); - - let handle = sub_queue.send(ZmqSubscriptionRequest::Subscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - handle.await.expect("output worker send sub req"); - } - } + let topics = output_topics(output_port_id, input_port_id); + subscribe_topics(&topics, &sub_queue).await.unwrap(); let (msg_req_send, mut msg_req_recv) = channel(1); @@ -193,18 +182,7 @@ pub fn start_output_worker( .await .map_err(|e| PortError::Other(e.to_string())); - { - let mut handles = Vec::new(); - for topic in output_topics(output_port_id, input_port_id).into_iter() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?topic, "sending unsubscription request")); - let handle = sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - handle.await.expect("output worker send unsub req"); - } - } + unsubscribe_topics(&topics, &sub_queue).await.unwrap(); response_chan .send(response) @@ -256,23 +234,7 @@ pub fn start_output_worker( )); *output_state = ZmqOutputPortState::Closed; - { - let mut handles = Vec::new(); - for topic in - output_topics(output_port_id, input_port_id).into_iter() - { - #[cfg(feature = "tracing")] - span.in_scope(|| { - trace!(?topic, "sending unsubscription request") - }); - let handle = sub_queue - .send(ZmqSubscriptionRequest::Unsubscribe(topic)); - handles.push(handle); - } - for handle in handles.into_iter() { - handle.await.expect("output worker send unsub req"); - } - } + unsubscribe_topics(&topics, &sub_queue).await.unwrap(); response_chan .send(Err(PortError::Disconnected)) diff --git a/lib/protoflow-zeromq/src/socket.rs b/lib/protoflow-zeromq/src/socket.rs index 9cb0f12d..aaee5e82 100644 --- a/lib/protoflow-zeromq/src/socket.rs +++ b/lib/protoflow-zeromq/src/socket.rs @@ -3,10 +3,13 @@ use crate::{ZmqInputPortState, ZmqOutputPortState, ZmqTransport, ZmqTransportEvent}; use core::fmt::Error; use protoflow_core::{ - prelude::{BTreeMap, String}, + prelude::{BTreeMap, String, Vec}, InputPortID, OutputPortID, }; -use tokio::sync::{mpsc::Receiver, RwLock}; +use tokio::sync::{ + mpsc::{error::SendError, Receiver, Sender}, + RwLock, +}; use zeromq::{SocketRecv, SocketSend, ZmqMessage}; #[derive(Clone, Debug)] @@ -38,6 +41,34 @@ pub fn start_pub_socket_worker(psock: zeromq::PubSocket, pub_queue: Receiver, +) -> Result<(), SendError> { + let mut handles = Vec::with_capacity(topics.len()); + for topic in topics { + handles.push(sub_queue.send(ZmqSubscriptionRequest::Subscribe(topic.clone()))); + } + for handle in handles { + handle.await?; + } + Ok(()) +} + +pub async fn unsubscribe_topics( + topics: &[String], + sub_queue: &Sender, +) -> Result<(), SendError> { + let mut handles = Vec::with_capacity(topics.len()); + for topic in topics { + handles.push(sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic.clone()))); + } + for handle in handles { + handle.await?; + } + Ok(()) +} + pub fn start_sub_socket_worker( transport: &ZmqTransport, ssock: zeromq::SubSocket, From 3a4e836f1b297fbe4495ce1f70a007928ec8ff5a Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Tue, 3 Dec 2024 10:58:58 +0200 Subject: [PATCH 47/63] Refactor output port worker Removed panics and improved the tracing messages. --- lib/protoflow-zeromq/src/output_port.rs | 200 ++++++++++++++---------- 1 file changed, 116 insertions(+), 84 deletions(-) diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs index 93fcfc62..b0397085 100644 --- a/lib/protoflow-zeromq/src/output_port.rs +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -2,7 +2,7 @@ use crate::{subscribe_topics, unsubscribe_topics, ZmqTransport, ZmqTransportEvent}; use protoflow_core::{ - prelude::{format, vec, Bytes, String, ToString, Vec}, + prelude::{fmt, format, vec, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortState, }; use tokio::sync::{ @@ -11,13 +11,14 @@ use tokio::sync::{ }; #[cfg(feature = "tracing")] -use tracing::{trace, trace_span}; +use tracing::{debug, error, info, trace, trace_span, warn}; #[derive(Debug, Clone)] pub enum ZmqOutputPortRequest { Close, Send(Bytes), } + #[derive(Debug, Clone)] pub enum ZmqOutputPortState { Open( @@ -35,12 +36,25 @@ pub enum ZmqOutputPortState { Closed, } +impl fmt::Display for ZmqOutputPortState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use ZmqOutputPortState::*; + match *self { + Open(..) => write!(f, "Open"), + Connected(.., ref id) => { + write!(f, "Connected({:?})", isize::from(*id),) + } + Closed => write!(f, "Closed"), + } + } +} + impl ZmqOutputPortState { pub fn state(&self) -> PortState { use ZmqOutputPortState::*; match self { - Open(_, _) => PortState::Open, - Connected(_, _, _) => PortState::Connected, + Open(..) => PortState::Open, + Connected(..) => PortState::Connected, Closed => PortState::Closed, } } @@ -59,25 +73,20 @@ pub fn start_output_worker( output_port_id: OutputPortID, ) -> Result<(), PortError> { #[cfg(feature = "tracing")] - let span = trace_span!("ZmqTransport::start_output_worker", ?output_port_id); + let span = trace_span!("ZmqTransport::output_port_worker", ?output_port_id); let (conn_send, mut conn_recv) = channel(1); - let (to_worker_send, mut to_worker_recv) = channel(1); { let mut outputs = transport.tokio.block_on(transport.outputs.write()); - if outputs.contains_key(&output_port_id) { - return Ok(()); // TODO + return Err(PortError::Invalid(output_port_id.into())); } let state = ZmqOutputPortState::Open(conn_send, to_worker_send.clone()); - let state = RwLock::new(state); - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?state, "saving new opened state")); - - outputs.insert(output_port_id, state); + span.in_scope(|| trace!("saving new state: {}", state)); + outputs.insert(output_port_id, RwLock::new(state)); } let sub_queue = transport.sub_queue.clone(); @@ -90,18 +99,20 @@ pub fn start_output_worker( tokio::task::spawn(async move { let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { // all senders have dropped, i.e. there's no connection request coming + #[cfg(feature = "tracing")] + debug!(parent: &span, "no connection request"); return; }; #[cfg(feature = "tracing")] - let span = trace_span!( - "ZmqTransport::start_output_worker::spawn", - ?output_port_id, - ?input_port_id - ); + let span = trace_span!(parent: &span, "task", ?input_port_id); let topics = output_topics(output_port_id, input_port_id); - subscribe_topics(&topics, &sub_queue).await.unwrap(); + if subscribe_topics(&topics, &sub_queue).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("topic subscription failed")); + return; + } let (msg_req_send, mut msg_req_recv) = channel(1); @@ -119,15 +130,21 @@ pub fn start_output_worker( #[cfg(feature = "tracing")] span.in_scope(|| trace!("sending connection attempt...")); - pub_queue + if pub_queue .send(ZmqTransportEvent::Connect(output_port_id, input_port_id)) .await - .expect("output worker send connect event"); + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("publish channel is closed")); + return; + } - let response = to_worker_recv - .recv() - .await - .expect("output worker recv ack-conn event"); + let Some(response) = to_worker_recv.recv().await else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("all senders to worker have dropped?")); + return; + }; #[cfg(feature = "tracing")] span.in_scope(|| trace!(?response, "got response")); @@ -135,22 +152,33 @@ pub fn start_output_worker( use ZmqTransportEvent::*; match response { AckConnection(_, input_port_id) => { - let outputs = outputs.read().await; - let Some(output_state) = outputs.get(&output_port_id) else { - todo!(); + let response = match outputs.read().await.get(&output_port_id) { + None => { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + Err(PortError::Invalid(output_port_id.into())) + } + Some(output_state) => { + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); + *output_state = ZmqOutputPortState::Connected( + msg_req_send, + to_worker_send, + input_port_id, + ); + + #[cfg(feature = "tracing")] + span.in_scope(|| info!("Connected!")); + + Ok(()) + } }; - let mut output_state = output_state.write().await; - debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); - *output_state = - ZmqOutputPortState::Connected(msg_req_send, to_worker_send, input_port_id); - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?output_state, "Connected!")); - - conn_confirm - .send(Ok(())) - .await - .expect("output worker respond conn"); + if conn_confirm.send(response).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("connection confirmation channel is closed")); + // don't exit, proceed to send loop + } drop(conn_confirm); break; @@ -160,18 +188,20 @@ pub fn start_output_worker( } let mut seq_id = 1; - 'send: loop { + 'send: while let Some((request, response_chan)) = msg_req_recv.recv().await { #[cfg(feature = "tracing")] let span = trace_span!(parent: &span, "send_loop", ?seq_id); - let (request, response_chan) = msg_req_recv - .recv() - .await - .expect("output worker recv msg req"); - #[cfg(feature = "tracing")] span.in_scope(|| trace!(?request, "sending request")); + let respond = |response| async { + if response_chan.send(response).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("response channel is closed")); + } + }; + match request { ZmqOutputPortRequest::Close => { let response = pub_queue @@ -181,16 +211,11 @@ pub fn start_output_worker( )) .await .map_err(|e| PortError::Other(e.to_string())); - - unsubscribe_topics(&topics, &sub_queue).await.unwrap(); - - response_chan - .send(response) - .await - .expect("output worker respond close"); + respond(response).await; + break 'send; } ZmqOutputPortRequest::Send(bytes) => { - pub_queue + if pub_queue .send(ZmqTransportEvent::Message( output_port_id, input_port_id, @@ -198,13 +223,20 @@ pub fn start_output_worker( bytes, )) .await - .expect("output worker send message event"); + .is_err() + { + respond(Err(PortError::SendFailed)).await; + continue 'send; + } 'recv: loop { - let event = to_worker_recv - .recv() - .await - .expect("output worker event recv"); + let Some(event) = to_worker_recv.recv().await else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("all senders to worker have dropped")); + + respond(Err(PortError::Invalid(output_port_id.into()))).await; + break 'send; + }; #[cfg(feature = "tracing")] span.in_scope(|| trace!(?event, "received event")); @@ -215,41 +247,26 @@ pub fn start_output_worker( if ack_id == seq_id { #[cfg(feature = "tracing")] span.in_scope(|| trace!(?ack_id, "msg-ack matches")); - response_chan - .send(Ok(())) - .await - .expect("output worker respond send"); + respond(Ok(())).await; break 'recv; + } else { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!(?ack_id, "got msg-ack for different sequence") + }); } } CloseInput(_) => { - let outputs = outputs.read().await; - let Some(output_state) = outputs.get(&output_port_id) else { - todo!(); - }; - let mut output_state = output_state.write().await; - debug_assert!(matches!( - *output_state, - ZmqOutputPortState::Connected(..) - )); - *output_state = ZmqOutputPortState::Closed; - - unsubscribe_topics(&topics, &sub_queue).await.unwrap(); - - response_chan - .send(Err(PortError::Disconnected)) - .await - .expect("output worker respond msg"); - + // report that the input port was closed + respond(Err(PortError::Disconnected)).await; break 'send; } // ignore others, we shouldn't receive any new conn-acks // nor should we be receiving input port events - AckConnection(_, _) - | Connect(_, _) - | Message(_, _, _, _) - | CloseOutput(_, _) => continue, + AckConnection(..) | Connect(..) | Message(..) | CloseOutput(..) => { + continue 'recv + } } } } @@ -257,6 +274,21 @@ pub fn start_output_worker( seq_id += 1; } + + let outputs = outputs.read().await; + let Some(output_state) = outputs.get(&output_port_id) else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; + }; + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Connected(..))); + *output_state = ZmqOutputPortState::Closed; + + if unsubscribe_topics(&topics, &sub_queue).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("topic unsubscription failed")); + } }); Ok(()) From a9e9ffcaae714d974f608a7ad1f3327e96515517 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Tue, 3 Dec 2024 13:41:25 +0200 Subject: [PATCH 48/63] Refactor input port worker --- lib/protoflow-zeromq/src/input_port.rs | 188 ++++++++++++++++--------- 1 file changed, 122 insertions(+), 66 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index ec602538..45408145 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -4,7 +4,7 @@ use crate::{ subscribe_topics, unsubscribe_topics, ZmqSubscriptionRequest, ZmqTransport, ZmqTransportEvent, }; use protoflow_core::{ - prelude::{format, vec, Arc, BTreeMap, Bytes, String, Vec}, + prelude::{fmt, format, vec, Arc, BTreeMap, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortState, }; use tokio::sync::{ @@ -13,7 +13,7 @@ use tokio::sync::{ }; #[cfg(feature = "tracing")] -use tracing::{trace, trace_span}; +use tracing::{error, info, trace, trace_span, warn}; #[derive(Debug, Clone)] pub enum ZmqInputPortRequest { @@ -44,12 +44,29 @@ pub enum ZmqInputPortState { Closed, } +impl fmt::Display for ZmqInputPortState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use ZmqInputPortState::*; + match *self { + Open(..) => write!(f, "Open"), + Connected(.., ref ids) => { + write!( + f, + "Connected({:?})", + ids.iter().map(|id| isize::from(*id)).collect::>() + ) + } + Closed => write!(f, "Closed"), + } + } +} + impl ZmqInputPortState { pub fn state(&self) -> PortState { use ZmqInputPortState::*; match self { Open(_) => PortState::Open, - Connected(_, _, _, _, _) => PortState::Connected, + Connected(..) => PortState::Connected, Closed => PortState::Closed, } } @@ -68,24 +85,20 @@ pub fn start_input_worker( input_port_id: InputPortID, ) -> Result<(), PortError> { #[cfg(feature = "tracing")] - let span = trace_span!("ZmqTransport::start_input_worker", ?input_port_id); + let span = trace_span!("ZmqTransport::input_port_worker", ?input_port_id); let (to_worker_send, mut to_worker_recv) = channel(1); - let (req_send, mut req_recv) = channel(1); { let mut inputs = transport.tokio.block_on(transport.inputs.write()); if inputs.contains_key(&input_port_id) { - return Ok(()); // TODO + return Err(PortError::Invalid(input_port_id.into())); } let state = ZmqInputPortState::Open(to_worker_send.clone()); - let state = RwLock::new(state); - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?state, "saving new opened state")); - - inputs.insert(input_port_id, state); + span.in_scope(|| trace!("saving new state: {}", state)); + inputs.insert(input_port_id, RwLock::new(state)); } let sub_queue = transport.sub_queue.clone(); @@ -93,10 +106,15 @@ pub fn start_input_worker( let inputs = transport.inputs.clone(); let topics = input_topics(input_port_id); - transport + if transport .tokio .block_on(subscribe_topics(&topics, &sub_queue)) - .unwrap(); + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("topic subscription failed")); + return Err(PortError::Other("topic subscription failed".to_string())); + } async fn handle_socket_event( event: ZmqTransportEvent, @@ -107,7 +125,7 @@ pub fn start_input_worker( ) { #[cfg(feature = "tracing")] let span = trace_span!( - "ZmqTransport::start_input_worker::handle_socket_event", + "ZmqTransport::input_port_worker::handle_socket_event", ?input_port_id ); @@ -117,9 +135,14 @@ pub fn start_input_worker( use ZmqTransportEvent::*; match event { Connect(output_port_id, input_port_id) => { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "Connect", ?output_port_id); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; }; let mut input_state = input_state.write().await; @@ -129,9 +152,7 @@ pub fn start_input_worker( Connected(_, _, _, _, connected_ids) => { if connected_ids.iter().any(|&id| id == output_port_id) { #[cfg(feature = "tracing")] - span.in_scope(|| { - trace!(?output_port_id, "output port is already connected") - }); + span.in_scope(|| trace!("output port is already connected")); return; } } @@ -150,27 +171,32 @@ pub fn start_input_worker( vec![output_port_id], ); } - Connected(_, _, _, _, ids) => { + Connected(.., ids) => { ids.push(output_port_id); } Closed => unreachable!(), }; - pub_queue + if pub_queue .send(ZmqTransportEvent::AckConnection( output_port_id, input_port_id, )) .await - .expect("input worker send ack-conn event"); + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("publish channel is closed")); + return; + } #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?output_port_id, "sent conn-ack")); + span.in_scope(|| trace!("sent conn-ack")); add_connection(&mut input_state); #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?input_state, "connected new port")); + span.in_scope(|| info!("Connected new port: {}", input_state)); } Message(output_port_id, _, seq_id, bytes) => { #[cfg(feature = "tracing")] @@ -178,7 +204,9 @@ pub fn start_input_worker( let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; }; let input_state = input_state.read().await; @@ -191,69 +219,87 @@ pub fn start_input_worker( return; } - sender + if sender .send(ZmqInputPortEvent::Message(bytes)) .await - .expect("input worker send message"); + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("receiver for input events has closed")); + return; + } - pub_queue + if pub_queue .send(ZmqTransportEvent::AckMessage( output_port_id, input_port_id, seq_id, )) .await - .expect("input worker send message ack"); + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("publish channel is closed")); + return; + } #[cfg(feature = "tracing")] span.in_scope(|| trace!("sent msg-ack")); } - Open(_) | Closed => todo!(), + Open(_) | Closed => { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("port is not connected: {}", input_state)); + } } } CloseOutput(output_port_id, input_port_id) => { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "CloseOutput", ?output_port_id); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; }; let mut input_state = input_state.write().await; use ZmqInputPortState::*; - let Connected(_, _, _, _, ref connected_ids) = *input_state else { + let Connected(_, ref sender, _, _, ref mut connected_ids) = *input_state else { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("input port wasn't connected")); return; }; - if !connected_ids.iter().any(|id| *id == output_port_id) { + let Some(idx) = connected_ids.iter().position(|&id| id == output_port_id) else { #[cfg(feature = "tracing")] - span.in_scope(|| { - trace!( - ?output_port_id, - "output port doesn't match any connected port" - ) - }); + span.in_scope(|| trace!("output port doesn't match any connected port")); + return; + }; + + connected_ids.swap_remove(idx); + + if !connected_ids.is_empty() { return; } - match *input_state { - Open(_) | Closed => (), - Connected(_, ref sender, _, _, ref mut connected_ids) => { - connected_ids.retain(|&id| id != output_port_id); - if connected_ids.is_empty() { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!("last connected port disconnected")); - sender - .send(ZmqInputPortEvent::Closed) - .await - .expect("input worker publish Closed event"); - } - } - }; + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("last connected port disconnected")); + + if let Err(err) = sender.try_send(ZmqInputPortEvent::Closed) { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("did not send InputPortEvent::Closed: {}", err)); + } + + // TODO: Should last connection closing close the input port too? + // It does in the MPSC transport. + //*input_state = ZmqInputPortState::Closed; } // ignore, ideally we never receive these here: - AckConnection(_, _) | AckMessage(_, _, _) | CloseInput(_) => (), + AckConnection(..) | AckMessage(..) | CloseInput(_) => (), } } @@ -267,7 +313,7 @@ pub fn start_input_worker( ) { #[cfg(feature = "tracing")] let span = trace_span!( - "ZmqTransport::start_input_worker::handle_input_event", + "ZmqTransport::input_port_worker::handle_input_event", ?input_port_id ); @@ -279,7 +325,9 @@ pub fn start_input_worker( Close => { let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { - todo!(); + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; }; let mut input_state = input_state.write().await; @@ -288,25 +336,33 @@ pub fn start_input_worker( return; }; - pub_queue + if pub_queue .send(ZmqTransportEvent::CloseInput(input_port_id)) .await - .expect("input worker send close event"); + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("can't publish CloseInput event")); + // don't exit, continue to close the port + } - port_events - .send(ZmqInputPortEvent::Closed) - .await - .expect("input worker send port closed"); + if let Err(err) = port_events.try_send(ZmqInputPortEvent::Closed) { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("did not send InputPortEvent::Closed: {}", err)); + } *input_state = ZmqInputPortState::Closed; let topics = input_topics(input_port_id); - unsubscribe_topics(&topics, sub_queue).await.unwrap(); + if unsubscribe_topics(&topics, sub_queue).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("topic unsubscription failed")); + } - response_chan - .send(Ok(())) - .await - .expect("input worker respond close") + if response_chan.send(Ok(())).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("response channel is closed")); + } } } } From 7961ea0aa31ddcffc21216b23128441036923ddf Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Tue, 3 Dec 2024 14:22:03 +0200 Subject: [PATCH 49/63] Refactor socket worker --- lib/protoflow-zeromq/src/socket.rs | 98 +++++++++++++++--------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/lib/protoflow-zeromq/src/socket.rs b/lib/protoflow-zeromq/src/socket.rs index aaee5e82..726edbd1 100644 --- a/lib/protoflow-zeromq/src/socket.rs +++ b/lib/protoflow-zeromq/src/socket.rs @@ -1,10 +1,9 @@ // This is free and unencumbered software released into the public domain. use crate::{ZmqInputPortState, ZmqOutputPortState, ZmqTransport, ZmqTransportEvent}; -use core::fmt::Error; use protoflow_core::{ prelude::{BTreeMap, String, Vec}, - InputPortID, OutputPortID, + InputPortID, OutputPortID, PortError, }; use tokio::sync::{ mpsc::{error::SendError, Receiver, Sender}, @@ -19,24 +18,22 @@ pub enum ZmqSubscriptionRequest { } #[cfg(feature = "tracing")] -use tracing::trace; +use tracing::{error, trace, trace_span}; pub fn start_pub_socket_worker(psock: zeromq::PubSocket, pub_queue: Receiver) { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::pub_socket"); let mut psock = psock; let mut pub_queue = pub_queue; tokio::task::spawn(async move { while let Some(event) = pub_queue.recv().await { #[cfg(feature = "tracing")] - trace!( - target: "ZmqTransport::pub_socket", - ?event, - "sending event to socket" - ); - - psock - .send(event.into()) - .await - .expect("zmq pub-socket worker") + span.in_scope(|| trace!(?event, "sending event to socket")); + + if let Err(err) = psock.send(event.into()).await { + #[cfg(feature = "tracing")] + span.in_scope(|| error!(?err, "failed to send message")); + } } }); } @@ -74,6 +71,8 @@ pub fn start_sub_socket_worker( ssock: zeromq::SubSocket, sub_queue: Receiver, ) { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::sub_socket"); let outputs = transport.outputs.clone(); let inputs = transport.inputs.clone(); let mut ssock = ssock; @@ -83,26 +82,27 @@ pub fn start_sub_socket_worker( tokio::select! { Ok(msg) = ssock.recv() => { #[cfg(feature = "tracing")] - trace!( - target: "ZmqTransport::sub_socket", - ?msg, - "got message from socket" - ); + span.in_scope(|| trace!(?msg, "got message from socket")); - handle_zmq_msg(msg, &outputs, &inputs).await.unwrap() + if let Err(err) = handle_zmq_msg(msg, &outputs, &inputs).await { + #[cfg(feature = "tracing")] + span.in_scope(|| error!(?err, "failed to process message")); + } }, Some(req) = sub_queue.recv() => { #[cfg(feature = "tracing")] - trace!( - target: "ZmqTransport::sub_socket", - ?req, - "got sub update request" - ); + span.in_scope(|| trace!(?req, "got sub update request")); use ZmqSubscriptionRequest::*; match req { - Subscribe(topic) => ssock.subscribe(&topic).await.expect("zmq recv worker subscribe"), - Unsubscribe(topic) => ssock.unsubscribe(&topic).await.expect("zmq recv worker unsubscribe"), + Subscribe(topic) => if let Err(err) = ssock.subscribe(&topic).await { + #[cfg(feature = "tracing")] + span.in_scope(|| error!(?err, ?topic, "subscribe failed")); + }, + Unsubscribe(topic) => if let Err(err) = ssock.unsubscribe(&topic).await { + #[cfg(feature = "tracing")] + span.in_scope(|| error!(?err, ?topic, "unsubscribe failed")); + } }; } }; @@ -114,13 +114,14 @@ async fn handle_zmq_msg( msg: ZmqMessage, outputs: &RwLock>>, inputs: &RwLock>>, -) -> Result<(), Error> { - let Ok(event) = ZmqTransportEvent::try_from(msg) else { - todo!(); - }; +) -> Result<(), PortError> { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::handle_zmq_msg"); + + let event = ZmqTransportEvent::try_from(msg)?; #[cfg(feature = "tracing")] - trace!(target: "handle_zmq_msg", ?event, "got event"); + span.in_scope(|| trace!(?event, "got event")); use ZmqTransportEvent::*; match event { @@ -129,52 +130,52 @@ async fn handle_zmq_msg( let sender = { let inputs = inputs.read().await; let Some(input) = inputs.get(&input_port_id) else { - todo!(); + return Err(PortError::Invalid(input_port_id.into())); }; let input = input.read().await; use ZmqInputPortState::*; match &*input { - Closed => todo!(), - Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), + Closed => return Err(PortError::Invalid(input_port_id.into())), + Open(sender) | Connected(.., sender, _) => sender.clone(), } }; - sender.send(event).await.unwrap(); + sender.send(event).await.map_err(|_| PortError::Closed) } Message(_, input_port_id, _, _) => { let sender = { let inputs = inputs.read().await; let Some(input) = inputs.get(&input_port_id) else { - todo!(); + return Err(PortError::Invalid(input_port_id.into())); }; let input = input.read().await; let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { - todo!(); + return Err(PortError::Invalid(input_port_id.into())); }; sender.clone() }; - sender.send(event).await.unwrap(); + sender.send(event).await.map_err(|_| PortError::Closed) } CloseOutput(_, input_port_id) => { let sender = { let inputs = inputs.read().await; let Some(input) = inputs.get(&input_port_id) else { - todo!(); + return Err(PortError::Invalid(input_port_id.into())); }; let input = input.read().await; use ZmqInputPortState::*; match &*input { - Closed => todo!(), + Closed => return Err(PortError::Invalid(input_port_id.into())), Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), } }; - sender.send(event).await.unwrap(); + sender.send(event).await.map_err(|_| PortError::Closed) } // output ports @@ -182,34 +183,34 @@ async fn handle_zmq_msg( let sender = { let outputs = outputs.read().await; let Some(output) = outputs.get(&output_port_id) else { - todo!(); + return Err(PortError::Invalid(output_port_id.into())); }; let output = output.read().await; let ZmqOutputPortState::Open(_, sender) = &*output else { - todo!(); + return Err(PortError::Invalid(output_port_id.into())); }; sender.clone() }; - sender.send(event).await.unwrap(); + sender.send(event).await.map_err(|_| PortError::Closed) } AckMessage(output_port_id, _, _) => { let sender = { let outputs = outputs.read().await; let Some(output) = outputs.get(&output_port_id) else { - todo!(); + return Err(PortError::Invalid(output_port_id.into())); }; let output = output.read().await; let ZmqOutputPortState::Connected(_, sender, _) = &*output else { - todo!(); + return Err(PortError::Invalid(output_port_id.into())); }; sender.clone() }; - sender.send(event).await.unwrap(); + sender.send(event).await.map_err(|_| PortError::Closed) } CloseInput(input_port_id) => { for (_, state) in outputs.read().await.iter() { @@ -229,8 +230,7 @@ async fn handle_zmq_msg( continue; // TODO } } + Ok(()) } } - - Ok(()) } From 33528bc584635a225055ee0996adacf21a1aaaf1 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Tue, 3 Dec 2024 22:03:13 +0200 Subject: [PATCH 50/63] Remove `todo!` from message parsing --- lib/protoflow-zeromq/src/event.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/protoflow-zeromq/src/event.rs b/lib/protoflow-zeromq/src/event.rs index 1264d6d4..a6506f06 100644 --- a/lib/protoflow-zeromq/src/event.rs +++ b/lib/protoflow-zeromq/src/event.rs @@ -110,16 +110,16 @@ impl TryFrom for ZmqTransportEvent { value .get(1) .ok_or_else(|| { - protoflow_core::DecodeError::new( - "message from socket contains less than two frames", - ) + protoflow_core::DecodeError::new("message contains less than two frames") }) .and_then(|bytes| { let event = Event::decode(bytes.as_ref())?; use ZmqTransportEvent::*; Ok(match event.payload { - None => todo!(), + None => { + return Err(protoflow_core::DecodeError::new("message payload is empty")) + } Some(Payload::Connect(protoflow_zmq::Connect { output, input })) => { Connect(map_id(output)?, map_id(input)?) } From 0c34c6a2a32c7214cf292e2849815a4302789bec Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 11:06:51 +0200 Subject: [PATCH 51/63] Make open input ports closable --- lib/protoflow-zeromq/src/input_port.rs | 38 ++++++++++++++++---------- lib/protoflow-zeromq/src/lib.rs | 10 +++---- lib/protoflow-zeromq/src/socket.rs | 4 +-- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index 45408145..861670fe 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -29,14 +29,19 @@ pub enum ZmqInputPortEvent { #[derive(Debug, Clone)] pub enum ZmqInputPortState { - Open(Sender), + Open( + // channel for requests from public close + Sender<(ZmqInputPortRequest, Sender>)>, + // channel used internally for events from socket + Sender, + ), Connected( // channel for requests from public close Sender<(ZmqInputPortRequest, Sender>)>, - // channel for the public recv + // channels to send-to and receive-from the public `recv` method Sender, Arc>>, - // internal channel for events + // channel used internally for events from socket Sender, // vec of the connected port ids Vec, @@ -65,7 +70,7 @@ impl ZmqInputPortState { pub fn state(&self) -> PortState { use ZmqInputPortState::*; match self { - Open(_) => PortState::Open, + Open(..) => PortState::Open, Connected(..) => PortState::Connected, Closed => PortState::Closed, } @@ -95,7 +100,7 @@ pub fn start_input_worker( if inputs.contains_key(&input_port_id) { return Err(PortError::Invalid(input_port_id.into())); } - let state = ZmqInputPortState::Open(to_worker_send.clone()); + let state = ZmqInputPortState::Open(req_send.clone(), to_worker_send.clone()); #[cfg(feature = "tracing")] span.in_scope(|| trace!("saving new state: {}", state)); inputs.insert(input_port_id, RwLock::new(state)); @@ -148,8 +153,8 @@ pub fn start_input_worker( use ZmqInputPortState::*; match &*input_state { - Open(_) => (), - Connected(_, _, _, _, connected_ids) => { + Open(..) => (), + Connected(.., connected_ids) => { if connected_ids.iter().any(|&id| id == output_port_id) { #[cfg(feature = "tracing")] span.in_scope(|| trace!("output port is already connected")); @@ -160,7 +165,7 @@ pub fn start_input_worker( }; let add_connection = |input_state: &mut ZmqInputPortState| match input_state { - Open(to_worker_send) => { + Open(req_send, to_worker_send) => { let (msgs_send, msgs_recv) = channel(1); let msgs_recv = Arc::new(Mutex::new(msgs_recv)); *input_state = Connected( @@ -247,7 +252,7 @@ pub fn start_input_worker( span.in_scope(|| trace!("sent msg-ack")); } - Open(_) | Closed => { + Open(..) | Closed => { #[cfg(feature = "tracing")] span.in_scope(|| warn!("port is not connected: {}", input_state)); } @@ -332,8 +337,16 @@ pub fn start_input_worker( let mut input_state = input_state.write().await; use ZmqInputPortState::*; - let Connected(_, ref port_events, _, _, _) = *input_state else { + + if let Closed = *input_state { return; + } + + if let Connected(_, ref port_events, ..) = *input_state { + if let Err(err) = port_events.try_send(ZmqInputPortEvent::Closed) { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("did not send InputPortEvent::Closed: {}", err)); + } }; if pub_queue @@ -346,11 +359,6 @@ pub fn start_input_worker( // don't exit, continue to close the port } - if let Err(err) = port_events.try_send(ZmqInputPortEvent::Closed) { - #[cfg(feature = "tracing")] - span.in_scope(|| warn!("did not send InputPortEvent::Closed: {}", err)); - } - *input_state = ZmqInputPortState::Closed; let topics = input_topics(input_port_id); diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index e6635bd1..27854a8d 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -175,13 +175,13 @@ impl Transport for ZmqTransport { let Some(input_state) = inputs.get(&input) else { return Err(PortError::Invalid(input.into())); }; - let input_state = input_state.read().await; - let ZmqInputPortState::Connected(sender, _, _, _, _) = &*input_state else { - return Err(PortError::Disconnected); - }; - sender.clone() + use ZmqInputPortState::*; + match *input_state { + Open(ref sender, _) | Connected(ref sender, ..) => sender.clone(), + Closed => return Err(PortError::Disconnected), + } }; let (close_send, mut close_recv) = channel(1); diff --git a/lib/protoflow-zeromq/src/socket.rs b/lib/protoflow-zeromq/src/socket.rs index 726edbd1..f65dda83 100644 --- a/lib/protoflow-zeromq/src/socket.rs +++ b/lib/protoflow-zeromq/src/socket.rs @@ -137,7 +137,7 @@ async fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => return Err(PortError::Invalid(input_port_id.into())), - Open(sender) | Connected(.., sender, _) => sender.clone(), + Open(.., sender) | Connected(.., sender, _) => sender.clone(), } }; @@ -171,7 +171,7 @@ async fn handle_zmq_msg( use ZmqInputPortState::*; match &*input { Closed => return Err(PortError::Invalid(input_port_id.into())), - Open(sender) | Connected(_, _, _, sender, _) => sender.clone(), + Open(.., sender) | Connected(.., sender, _) => sender.clone(), } }; From c068adadd9f3ce036cbf520dd0c9b52b5ff34cb1 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 11:14:19 +0200 Subject: [PATCH 52/63] Default back to `Open` for disconnected input ports When a `Connected` input port receives a `CloseOutput` event for for the last output port ID switch back to the `Open` state. --- lib/protoflow-zeromq/src/input_port.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index 861670fe..29aa9f18 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -272,7 +272,9 @@ pub fn start_input_worker( let mut input_state = input_state.write().await; use ZmqInputPortState::*; - let Connected(_, ref sender, _, _, ref mut connected_ids) = *input_state else { + let Connected(ref req_send, ref sender, _, ref event_sender, ref mut connected_ids) = + *input_state + else { #[cfg(feature = "tracing")] span.in_scope(|| trace!("input port wasn't connected")); return; @@ -301,6 +303,8 @@ pub fn start_input_worker( // TODO: Should last connection closing close the input port too? // It does in the MPSC transport. //*input_state = ZmqInputPortState::Closed; + + *input_state = Open(req_send.clone(), event_sender.clone()) } // ignore, ideally we never receive these here: From ce2f6eec528dd6c2b07d84f1fd630126ca7a75c4 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 11:16:03 +0200 Subject: [PATCH 53/63] Fix input port worker exit The conditions for the `loop { tokio::select!( ... ) }` to exit are: 1. All the references to `req_send` channel are dropped. This channel is used by the public transport interface to close the port. It clones the `Sender` and `send().awaits` on it. Once all of those have been processed, and the port is in a `Closed` state, there are no other references to the `req_send` port than the one removed in this commit. 2. All the references to `to_worker_send` channel are dropped. When receiving a `Message`, the `SubSocket` worker clones this `Sender` from an input worker's **Connected** state. Therefore if the input worker is in the **Closed** state, all references to the `Sender` will be dropped once the outstanding messages from the socket have been sent. Thus, by dropping the reference to the `req_send` channel inside the input port worker's thread, and after it processes the outstanding sends, no other thread is available to get a reference to the `Sender` channels and the loop is able to exit cleanly --- lib/protoflow-zeromq/src/input_port.rs | 36 +++++++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index 29aa9f18..32d00b55 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -15,7 +15,7 @@ use tokio::sync::{ #[cfg(feature = "tracing")] use tracing::{error, info, trace, trace_span, warn}; -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub enum ZmqInputPortRequest { Close, } @@ -27,7 +27,7 @@ pub enum ZmqInputPortEvent { Closed, } -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub enum ZmqInputPortState { Open( // channel for requests from public close @@ -124,7 +124,6 @@ pub fn start_input_worker( async fn handle_socket_event( event: ZmqTransportEvent, inputs: &RwLock>>, - req_send: &Sender<(ZmqInputPortRequest, Sender>)>, pub_queue: &Sender, input_port_id: InputPortID, ) { @@ -139,10 +138,12 @@ pub fn start_input_worker( use ZmqTransportEvent::*; match event { - Connect(output_port_id, input_port_id) => { + Connect(output_port_id, target_id) => { #[cfg(feature = "tracing")] let span = trace_span!(parent: &span, "Connect", ?output_port_id); + debug_assert_eq!(input_port_id, target_id); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { #[cfg(feature = "tracing")] @@ -203,10 +204,12 @@ pub fn start_input_worker( #[cfg(feature = "tracing")] span.in_scope(|| info!("Connected new port: {}", input_state)); } - Message(output_port_id, _, seq_id, bytes) => { + Message(output_port_id, target_id, seq_id, bytes) => { #[cfg(feature = "tracing")] let span = trace_span!(parent: &span, "Message", ?output_port_id, ?seq_id); + debug_assert_eq!(input_port_id, target_id); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { #[cfg(feature = "tracing")] @@ -258,10 +261,12 @@ pub fn start_input_worker( } } } - CloseOutput(output_port_id, input_port_id) => { + CloseOutput(output_port_id, target_id) => { #[cfg(feature = "tracing")] let span = trace_span!(parent: &span, "CloseOutput", ?output_port_id); + debug_assert_eq!(input_port_id, target_id); + let inputs = inputs.read().await; let Some(input_state) = inputs.get(&input_port_id) else { #[cfg(feature = "tracing")] @@ -390,13 +395,30 @@ pub fn start_input_worker( loop { tokio::select! { Some(event) = to_worker_recv.recv() => { - handle_socket_event(event, &inputs, &req_send, &pub_queue, input_port_id).await; + handle_socket_event(event, &inputs, &pub_queue, input_port_id).await; } Some((request, response_chan)) = req_recv.recv() => { handle_input_request(request, response_chan, &inputs, &pub_queue, &sub_queue, input_port_id).await; } + else => break, }; } + + #[cfg(feature = "tracing")] + { + let state = match inputs.read().await.get(&input_port_id) { + Some(input) => Some(input.read().await.clone()), + None => None, + }; + span.in_scope(|| { + trace!( + events_closed = to_worker_recv.is_closed(), + requests_closed = req_recv.is_closed(), + ?state, + "exited input worker loop" + ) + }); + } }); Ok(()) From 896b09358cedce688626a0d209e5591e200e1dcf Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 13:05:51 +0200 Subject: [PATCH 54/63] Add trace for output worker exit --- lib/protoflow-zeromq/src/output_port.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs index b0397085..9c4a1a98 100644 --- a/lib/protoflow-zeromq/src/output_port.rs +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -13,13 +13,13 @@ use tokio::sync::{ #[cfg(feature = "tracing")] use tracing::{debug, error, info, trace, trace_span, warn}; -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub enum ZmqOutputPortRequest { Close, Send(Bytes), } -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub enum ZmqOutputPortState { Open( Sender<(InputPortID, Sender>)>, @@ -285,6 +285,16 @@ pub fn start_output_worker( debug_assert!(matches!(*output_state, ZmqOutputPortState::Connected(..))); *output_state = ZmqOutputPortState::Closed; + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!( + events_closed = to_worker_recv.is_closed(), + requests_closed = msg_req_recv.is_closed(), + state = ?*output_state, + "exited output worker loop" + ) + }); + if unsubscribe_topics(&topics, &sub_queue).await.is_err() { #[cfg(feature = "tracing")] span.in_scope(|| error!("topic unsubscription failed")); From 7f268647be99074976fe3f3110cda9e26c46e1e9 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 13:43:24 +0200 Subject: [PATCH 55/63] Make open output ports closable --- lib/protoflow-zeromq/src/input_port.rs | 2 +- lib/protoflow-zeromq/src/lib.rs | 30 +++++++++------- lib/protoflow-zeromq/src/output_port.rs | 48 ++++++++++++++++++++----- lib/protoflow-zeromq/src/socket.rs | 2 +- 4 files changed, 59 insertions(+), 23 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index 32d00b55..7418242a 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -30,7 +30,7 @@ pub enum ZmqInputPortEvent { #[derive(Clone, Debug)] pub enum ZmqInputPortState { Open( - // channel for requests from public close + // channel for close requests from the public `close` method Sender<(ZmqInputPortRequest, Sender>)>, // channel used internally for events from socket Sender, diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 27854a8d..34f7cac7 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -180,7 +180,7 @@ impl Transport for ZmqTransport { use ZmqInputPortState::*; match *input_state { Open(ref sender, _) | Connected(ref sender, ..) => sender.clone(), - Closed => return Err(PortError::Disconnected), + Closed => return Ok(false), // already closed } }; @@ -201,27 +201,31 @@ impl Transport for ZmqTransport { fn close_output(&self, output: OutputPortID) -> PortResult { self.tokio.block_on(async { - let sender = { + let mut close_recv = { let outputs = self.outputs.read().await; let Some(output_state) = outputs.get(&output) else { return Err(PortError::Invalid(output.into())); }; let output_state = output_state.read().await; - let ZmqOutputPortState::Connected(sender, _, _) = &*output_state else { - return Err(PortError::Disconnected); + let (close_send, close_recv) = channel(1); + + use ZmqOutputPortState::*; + match *output_state { + Open(_, ref sender, _) => sender + .send(close_send) + .await + .map_err(|e| PortError::Other(e.to_string()))?, + Connected(ref sender, ..) => sender + .send((ZmqOutputPortRequest::Close, close_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?, + Closed => return Ok(false), // already closed }; - sender.clone() + close_recv }; - let (close_send, mut close_recv) = channel(1); - - sender - .send((ZmqOutputPortRequest::Close, close_send)) - .await - .map_err(|e| PortError::Other(e.to_string()))?; - close_recv .recv() .await @@ -242,7 +246,7 @@ impl Transport for ZmqTransport { }; let output_state = output_state.read().await; - let ZmqOutputPortState::Open(ref sender, _) = *output_state else { + let ZmqOutputPortState::Open(ref sender, _, _) = *output_state else { return Err(PortError::Invalid(source.into())); }; diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs index 9c4a1a98..8376e224 100644 --- a/lib/protoflow-zeromq/src/output_port.rs +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -22,13 +22,17 @@ pub enum ZmqOutputPortRequest { #[derive(Clone, Debug)] pub enum ZmqOutputPortState { Open( + // channel for connection requests from public `connect` method Sender<(InputPortID, Sender>)>, + // channel for close requests from the public `close` method + Sender>>, + // channel used internally for events from socket Sender, ), Connected( - // channel for public send, contained channel is for the ack back + // channel for public `send` and `close` methods, contained channel is for the ack back Sender<(ZmqOutputPortRequest, Sender>)>, - // internal channel for events + // channel used internally for events from socket Sender, // id of the connected input port InputPortID, @@ -76,6 +80,7 @@ pub fn start_output_worker( let span = trace_span!("ZmqTransport::output_port_worker", ?output_port_id); let (conn_send, mut conn_recv) = channel(1); + let (close_send, mut close_recv) = channel(1); let (to_worker_send, mut to_worker_recv) = channel(1); { @@ -83,7 +88,7 @@ pub fn start_output_worker( if outputs.contains_key(&output_port_id) { return Err(PortError::Invalid(output_port_id.into())); } - let state = ZmqOutputPortState::Open(conn_send, to_worker_send.clone()); + let state = ZmqOutputPortState::Open(conn_send, close_send, to_worker_send.clone()); #[cfg(feature = "tracing")] span.in_scope(|| trace!("saving new state: {}", state)); outputs.insert(output_port_id, RwLock::new(state)); @@ -97,11 +102,38 @@ pub fn start_output_worker( span.in_scope(|| trace!("spawning")); tokio::task::spawn(async move { - let Some((input_port_id, conn_confirm)) = conn_recv.recv().await else { - // all senders have dropped, i.e. there's no connection request coming - #[cfg(feature = "tracing")] - debug!(parent: &span, "no connection request"); - return; + let (input_port_id, conn_confirm) = tokio::select! { + Some((input_port_id, conn_confirm)) = conn_recv.recv() => (input_port_id, conn_confirm), + Some(close_confirm) = close_recv.recv() => { + let response = { + if let Some(output_state) = outputs.read().await.get(&output_port_id) { + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); + *output_state = ZmqOutputPortState::Closed; + Ok(()) + } else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + Err(PortError::Invalid(output_port_id.into())) + } + }; + + let _ = close_confirm.try_send(response); + return; + } + else => { + // all senders have dropped, i.e. there's no connection request coming + + if let Some(output_state) = outputs.read().await.get(&output_port_id) { + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); + *output_state = ZmqOutputPortState::Closed; + } + + #[cfg(feature = "tracing")] + debug!(parent: &span, "no connection or close request"); + return; + } }; #[cfg(feature = "tracing")] diff --git a/lib/protoflow-zeromq/src/socket.rs b/lib/protoflow-zeromq/src/socket.rs index f65dda83..6691e652 100644 --- a/lib/protoflow-zeromq/src/socket.rs +++ b/lib/protoflow-zeromq/src/socket.rs @@ -187,7 +187,7 @@ async fn handle_zmq_msg( }; let output = output.read().await; - let ZmqOutputPortState::Open(_, sender) = &*output else { + let ZmqOutputPortState::Open(.., sender) = &*output else { return Err(PortError::Invalid(output_port_id.into())); }; From 2c9289cf5e7db16fc667066e28110540294a3cc6 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 15:15:22 +0200 Subject: [PATCH 56/63] Implement message redelivery in output port --- lib/protoflow-zeromq/src/output_port.rs | 122 +++++++++++++++--------- 1 file changed, 78 insertions(+), 44 deletions(-) diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs index 8376e224..2de704f6 100644 --- a/lib/protoflow-zeromq/src/output_port.rs +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -19,6 +19,9 @@ pub enum ZmqOutputPortRequest { Send(Bytes), } +const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(500); +const DEFAULT_MAX_RETRIES: u64 = 10; + #[derive(Clone, Debug)] pub enum ZmqOutputPortState { Open( @@ -247,57 +250,88 @@ pub fn start_output_worker( break 'send; } ZmqOutputPortRequest::Send(bytes) => { - if pub_queue - .send(ZmqTransportEvent::Message( - output_port_id, - input_port_id, - seq_id, - bytes, - )) - .await - .is_err() - { - respond(Err(PortError::SendFailed)).await; - continue 'send; - } + let msg = ZmqTransportEvent::Message( + output_port_id, + input_port_id, + seq_id, + bytes.clone(), + ); - 'recv: loop { - let Some(event) = to_worker_recv.recv().await else { - #[cfg(feature = "tracing")] - span.in_scope(|| error!("all senders to worker have dropped")); + let mut attempts = 0; + 'retry: loop { + attempts += 1; + + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "retry_loop", ?attempts); - respond(Err(PortError::Invalid(output_port_id.into()))).await; + if attempts >= DEFAULT_MAX_RETRIES { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("reached max send attempts")); + respond(Err(PortError::Disconnected)).await; break 'send; - }; + } #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?event, "received event")); - - use ZmqTransportEvent::*; - match event { - AckMessage(_, _, ack_id) => { - if ack_id == seq_id { - #[cfg(feature = "tracing")] - span.in_scope(|| trace!(?ack_id, "msg-ack matches")); - respond(Ok(())).await; - break 'recv; - } else { - #[cfg(feature = "tracing")] - span.in_scope(|| { - trace!(?ack_id, "got msg-ack for different sequence") - }); + span.in_scope(|| trace!("attempting to send message")); + + if pub_queue.send(msg.clone()).await.is_err() { + // the socket for publishing has closed, we won't be able to send any + // messages + respond(Err(PortError::Disconnected)).await; + break 'send; + } + + 'recv: loop { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "recv_loop"); + + let timeout = tokio::time::sleep(DEFAULT_TIMEOUT); + + let event = tokio::select! { + // after DEFAULT_TIMEOUT duration has passed since the last + // received event from the socket, retry + _ = timeout => continue 'retry, + event_opt = to_worker_recv.recv() => match event_opt { + Some(event) => event, + None => { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("all senders to worker have dropped")); + respond(Err(PortError::Invalid(output_port_id.into()))).await; + break 'send; + } } - } - CloseInput(_) => { - // report that the input port was closed - respond(Err(PortError::Disconnected)).await; - break 'send; - } + }; - // ignore others, we shouldn't receive any new conn-acks - // nor should we be receiving input port events - AckConnection(..) | Connect(..) | Message(..) | CloseOutput(..) => { - continue 'recv + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "received event")); + + use ZmqTransportEvent::*; + match event { + AckMessage(_, _, ack_id) => { + if ack_id == seq_id { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?ack_id, "msg-ack matches")); + respond(Ok(())).await; + break 'retry; + } else { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!(?ack_id, "got msg-ack for different sequence") + }); + continue 'recv; + } + } + CloseInput(_) => { + // report that the input port was closed + respond(Err(PortError::Disconnected)).await; + break 'send; + } + + // ignore others, we shouldn't receive any new conn-acks + // nor should we be receiving input port events + AckConnection(..) | Connect(..) | Message(..) | CloseOutput(..) => { + continue 'recv + } } } } From a994872bea8cc996227374c5e027b30d8d357afe Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 19:03:29 +0200 Subject: [PATCH 57/63] Handle message redelivery and re-acknowledgment in input worker --- lib/protoflow-zeromq/src/input_port.rs | 104 +++++++++++++++---------- 1 file changed, 63 insertions(+), 41 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index 7418242a..e84f01bb 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -1,7 +1,8 @@ // This is free and unencumbered software released into the public domain. use crate::{ - subscribe_topics, unsubscribe_topics, ZmqSubscriptionRequest, ZmqTransport, ZmqTransportEvent, + subscribe_topics, unsubscribe_topics, SequenceID, ZmqSubscriptionRequest, ZmqTransport, + ZmqTransportEvent, }; use protoflow_core::{ prelude::{fmt, format, vec, Arc, BTreeMap, Bytes, String, ToString, Vec}, @@ -44,7 +45,7 @@ pub enum ZmqInputPortState { // channel used internally for events from socket Sender, // vec of the connected port ids - Vec, + BTreeMap, ), Closed, } @@ -58,7 +59,7 @@ impl fmt::Display for ZmqInputPortState { write!( f, "Connected({:?})", - ids.iter().map(|id| isize::from(*id)).collect::>() + ids.keys().map(|id| isize::from(*id)).collect::>() ) } Closed => write!(f, "Closed"), @@ -156,7 +157,7 @@ pub fn start_input_worker( match &*input_state { Open(..) => (), Connected(.., connected_ids) => { - if connected_ids.iter().any(|&id| id == output_port_id) { + if connected_ids.contains_key(&output_port_id) { #[cfg(feature = "tracing")] span.in_scope(|| trace!("output port is already connected")); return; @@ -169,16 +170,18 @@ pub fn start_input_worker( Open(req_send, to_worker_send) => { let (msgs_send, msgs_recv) = channel(1); let msgs_recv = Arc::new(Mutex::new(msgs_recv)); + let mut connected_ids = BTreeMap::new(); + connected_ids.insert(output_port_id, 0); *input_state = Connected( req_send.clone(), msgs_send, msgs_recv, to_worker_send.clone(), - vec![output_port_id], + connected_ids, ); } Connected(.., ids) => { - ids.push(output_port_id); + ids.insert(output_port_id, 0); } Closed => unreachable!(), }; @@ -204,9 +207,9 @@ pub fn start_input_worker( #[cfg(feature = "tracing")] span.in_scope(|| info!("Connected new port: {}", input_state)); } - Message(output_port_id, target_id, seq_id, bytes) => { + Message(output_port_id, target_id, msg_seq_id, bytes) => { #[cfg(feature = "tracing")] - let span = trace_span!(parent: &span, "Message", ?output_port_id, ?seq_id); + let span = trace_span!(parent: &span, "Message", ?output_port_id, ?msg_seq_id); debug_assert_eq!(input_port_id, target_id); @@ -216,43 +219,64 @@ pub fn start_input_worker( span.in_scope(|| error!("port state not found")); return; }; - let input_state = input_state.read().await; + let mut input_state = input_state.write().await; use ZmqInputPortState::*; - match &*input_state { - Connected(_, sender, _, _, connected_ids) => { - if !connected_ids.iter().any(|id| *id == output_port_id) { + match *input_state { + Connected(_, ref sender, _, _, ref mut connected_ids) => { + let Some(&last_seen_seq_id) = connected_ids.get(&output_port_id) else { #[cfg(feature = "tracing")] span.in_scope(|| trace!("got message from non-connected output port")); return; - } + }; - if sender - .send(ZmqInputPortEvent::Message(bytes)) - .await - .is_err() - { + let send_ack = { #[cfg(feature = "tracing")] - span.in_scope(|| warn!("receiver for input events has closed")); - return; + let span = span.clone(); + + |ack_id| async move { + if pub_queue + .send(ZmqTransportEvent::AckMessage( + output_port_id, + input_port_id, + ack_id, + )) + .await + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("publish channel is closed")); + } + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?ack_id, "sent msg-ack")); + } + }; + + use std::cmp::Ordering::*; + match msg_seq_id.cmp(&last_seen_seq_id) { + // seq_id for msg is greater than last seen seq_id by one + Greater if (msg_seq_id - last_seen_seq_id == 1) => { + if sender + .send(ZmqInputPortEvent::Message(bytes)) + .await + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("receiver for input events has closed")); + return; + } + send_ack(msg_seq_id).await; + let _ = connected_ids.insert(output_port_id, msg_seq_id); + } + Equal => { + send_ack(last_seen_seq_id).await; + } + // either the seq_id is greater than the last seen seq_id by more than + // one, or somehow less than the last seen seq_id: + _ => { + send_ack(last_seen_seq_id).await; + } } - - if pub_queue - .send(ZmqTransportEvent::AckMessage( - output_port_id, - input_port_id, - seq_id, - )) - .await - .is_err() - { - #[cfg(feature = "tracing")] - span.in_scope(|| warn!("publish channel is closed")); - return; - } - - #[cfg(feature = "tracing")] - span.in_scope(|| trace!("sent msg-ack")); } Open(..) | Closed => { @@ -285,13 +309,11 @@ pub fn start_input_worker( return; }; - let Some(idx) = connected_ids.iter().position(|&id| id == output_port_id) else { + if connected_ids.remove(&output_port_id).is_none() { #[cfg(feature = "tracing")] span.in_scope(|| trace!("output port doesn't match any connected port")); return; - }; - - connected_ids.swap_remove(idx); + } if !connected_ids.is_empty() { return; From af6d207c7ea0899f8b0a90e2bbea781ed80de8bd Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 22:06:41 +0200 Subject: [PATCH 58/63] Shortcut sending data to network if input port is reachable locally --- lib/protoflow-zeromq/src/input_port.rs | 20 ++++ lib/protoflow-zeromq/src/lib.rs | 2 +- lib/protoflow-zeromq/src/output_port.rs | 24 ++++- lib/protoflow-zeromq/src/socket.rs | 124 ++++++++---------------- 4 files changed, 86 insertions(+), 84 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index e84f01bb..18544c1e 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -76,6 +76,14 @@ impl ZmqInputPortState { Closed => PortState::Closed, } } + + pub async fn event_sender(&self) -> Option> { + use ZmqInputPortState::*; + match self { + Open(_, sender) | Connected(.., sender, _) => Some(sender.clone()), + Closed => None, + } + } } fn input_topics(id: InputPortID) -> Vec { @@ -86,6 +94,18 @@ fn input_topics(id: InputPortID) -> Vec { ] } +pub async fn input_port_event_sender( + inputs: &RwLock>>, + id: InputPortID, +) -> Option> { + if let Some(input_state) = inputs.read().await.get(&id) { + let input_state = input_state.read().await; + input_state.event_sender().await + } else { + None + } +} + pub fn start_input_worker( transport: &ZmqTransport, input_port_id: InputPortID, diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 34f7cac7..105a5fa7 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -100,7 +100,7 @@ impl ZmqTransport { inputs, }; - start_pub_socket_worker(psock, pub_queue_recv); + start_pub_socket_worker(&transport, psock, pub_queue_recv); start_sub_socket_worker(&transport, ssock, sub_queue_recv); transport diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs index 2de704f6..843f6ad0 100644 --- a/lib/protoflow-zeromq/src/output_port.rs +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -2,7 +2,7 @@ use crate::{subscribe_topics, unsubscribe_topics, ZmqTransport, ZmqTransportEvent}; use protoflow_core::{ - prelude::{fmt, format, vec, Bytes, String, ToString, Vec}, + prelude::{fmt, format, vec, BTreeMap, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortState, }; use tokio::sync::{ @@ -19,7 +19,7 @@ pub enum ZmqOutputPortRequest { Send(Bytes), } -const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(500); +const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200); const DEFAULT_MAX_RETRIES: u64 = 10; #[derive(Clone, Debug)] @@ -65,6 +65,14 @@ impl ZmqOutputPortState { Closed => PortState::Closed, } } + + pub async fn event_sender(&self) -> Option> { + use ZmqOutputPortState::*; + match self { + Open(.., sender) | Connected(.., sender, _) => Some(sender.clone()), + Closed => None, + } + } } fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { @@ -75,6 +83,18 @@ fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { ] } +pub async fn output_port_event_sender( + outputs: &RwLock>>, + id: OutputPortID, +) -> Option> { + if let Some(output_state) = outputs.read().await.get(&id) { + let output_state = output_state.read().await; + output_state.event_sender().await + } else { + None + } +} + pub fn start_output_worker( transport: &ZmqTransport, output_port_id: OutputPortID, diff --git a/lib/protoflow-zeromq/src/socket.rs b/lib/protoflow-zeromq/src/socket.rs index 6691e652..cef3d2f1 100644 --- a/lib/protoflow-zeromq/src/socket.rs +++ b/lib/protoflow-zeromq/src/socket.rs @@ -1,6 +1,9 @@ // This is free and unencumbered software released into the public domain. -use crate::{ZmqInputPortState, ZmqOutputPortState, ZmqTransport, ZmqTransportEvent}; +use crate::{ + input_port_event_sender, output_port_event_sender, ZmqInputPortState, ZmqOutputPortState, + ZmqTransport, ZmqTransportEvent, +}; use protoflow_core::{ prelude::{BTreeMap, String, Vec}, InputPortID, OutputPortID, PortError, @@ -18,11 +21,17 @@ pub enum ZmqSubscriptionRequest { } #[cfg(feature = "tracing")] -use tracing::{error, trace, trace_span}; +use tracing::{debug, error, trace, trace_span, warn}; -pub fn start_pub_socket_worker(psock: zeromq::PubSocket, pub_queue: Receiver) { +pub fn start_pub_socket_worker( + transport: &ZmqTransport, + psock: zeromq::PubSocket, + pub_queue: Receiver, +) { #[cfg(feature = "tracing")] let span = trace_span!("ZmqTransport::pub_socket"); + let outputs = transport.outputs.clone(); + let inputs = transport.inputs.clone(); let mut psock = psock; let mut pub_queue = pub_queue; tokio::task::spawn(async move { @@ -30,6 +39,27 @@ pub fn start_pub_socket_worker(psock: zeromq::PubSocket, pub_queue: Receiver { + input_port_event_sender(&inputs, id).await + } + AckConnection(id, _) | AckMessage(id, ..) => { + output_port_event_sender(&outputs, id).await + } + CloseInput(..) => None, + }; + + if let Some(sender) = shortcut_sender { + #[cfg(feature = "tracing")] + span.in_scope(|| debug!("attempting to shortcut send directly to target port")); + if sender.send(event.clone()).await.is_ok() { + continue; + } + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("failed to send message with shortcut, sending to socket")); + } + if let Err(err) = psock.send(event.into()).await { #[cfg(feature = "tracing")] span.in_scope(|| error!(?err, "failed to send message")); @@ -126,89 +156,21 @@ async fn handle_zmq_msg( use ZmqTransportEvent::*; match event { // input ports - Connect(_, input_port_id) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - return Err(PortError::Invalid(input_port_id.into())); - }; - let input = input.read().await; - - use ZmqInputPortState::*; - match &*input { - Closed => return Err(PortError::Invalid(input_port_id.into())), - Open(.., sender) | Connected(.., sender, _) => sender.clone(), - } - }; - - sender.send(event).await.map_err(|_| PortError::Closed) - } - Message(_, input_port_id, _, _) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - return Err(PortError::Invalid(input_port_id.into())); - }; - - let input = input.read().await; - let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { - return Err(PortError::Invalid(input_port_id.into())); - }; - - sender.clone() - }; - - sender.send(event).await.map_err(|_| PortError::Closed) - } - CloseOutput(_, input_port_id) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - return Err(PortError::Invalid(input_port_id.into())); - }; - let input = input.read().await; - - use ZmqInputPortState::*; - match &*input { - Closed => return Err(PortError::Invalid(input_port_id.into())), - Open(.., sender) | Connected(.., sender, _) => sender.clone(), - } - }; + Connect(_, input_port_id) + | Message(_, input_port_id, _, _) + | CloseOutput(_, input_port_id) => { + let sender = input_port_event_sender(inputs, input_port_id) + .await + .ok_or_else(|| PortError::Invalid(input_port_id.into()))?; sender.send(event).await.map_err(|_| PortError::Closed) } // output ports - AckConnection(output_port_id, _) => { - let sender = { - let outputs = outputs.read().await; - let Some(output) = outputs.get(&output_port_id) else { - return Err(PortError::Invalid(output_port_id.into())); - }; - let output = output.read().await; - - let ZmqOutputPortState::Open(.., sender) = &*output else { - return Err(PortError::Invalid(output_port_id.into())); - }; - - sender.clone() - }; - - sender.send(event).await.map_err(|_| PortError::Closed) - } - AckMessage(output_port_id, _, _) => { - let sender = { - let outputs = outputs.read().await; - let Some(output) = outputs.get(&output_port_id) else { - return Err(PortError::Invalid(output_port_id.into())); - }; - let output = output.read().await; - let ZmqOutputPortState::Connected(_, sender, _) = &*output else { - return Err(PortError::Invalid(output_port_id.into())); - }; - - sender.clone() - }; + AckConnection(output_port_id, _) | AckMessage(output_port_id, _, _) => { + let sender = output_port_event_sender(outputs, output_port_id) + .await + .ok_or_else(|| PortError::Invalid(output_port_id.into()))?; sender.send(event).await.map_err(|_| PortError::Closed) } From a6c4faa5146a86c21f276dcc15b2678920f6ab96 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 22:14:23 +0200 Subject: [PATCH 59/63] Refactor port event sender access helpers --- lib/protoflow-zeromq/src/input_port.rs | 14 ++++++++------ lib/protoflow-zeromq/src/output_port.rs | 14 ++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index 18544c1e..cbd868f6 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -98,12 +98,14 @@ pub async fn input_port_event_sender( inputs: &RwLock>>, id: InputPortID, ) -> Option> { - if let Some(input_state) = inputs.read().await.get(&id) { - let input_state = input_state.read().await; - input_state.event_sender().await - } else { - None - } + inputs + .read() + .await + .get(&id)? + .read() + .await + .event_sender() + .await } pub fn start_input_worker( diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs index 843f6ad0..8d18f7d0 100644 --- a/lib/protoflow-zeromq/src/output_port.rs +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -87,12 +87,14 @@ pub async fn output_port_event_sender( outputs: &RwLock>>, id: OutputPortID, ) -> Option> { - if let Some(output_state) = outputs.read().await.get(&id) { - let output_state = output_state.read().await; - output_state.event_sender().await - } else { - None - } + outputs + .read() + .await + .get(&id)? + .read() + .await + .event_sender() + .await } pub fn start_output_worker( From c133cfd67d9ec37ff74b11337fc053a481617838 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Thu, 5 Dec 2024 12:48:13 +0200 Subject: [PATCH 60/63] Separate proto and rust files --- lib/protoflow-zeromq/build.rs | 2 +- lib/protoflow-zeromq/{src => proto}/transport_event.proto | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename lib/protoflow-zeromq/{src => proto}/transport_event.proto (100%) diff --git a/lib/protoflow-zeromq/build.rs b/lib/protoflow-zeromq/build.rs index 8449adc0..71f34fd0 100644 --- a/lib/protoflow-zeromq/build.rs +++ b/lib/protoflow-zeromq/build.rs @@ -2,5 +2,5 @@ use std::io::Result; fn main() -> Result<()> { prost_build::Config::default() .out_dir("src/") - .compile_protos(&["src/transport_event.proto"], &["src/"]) + .compile_protos(&["proto/transport_event.proto"], &["proto/"]) } diff --git a/lib/protoflow-zeromq/src/transport_event.proto b/lib/protoflow-zeromq/proto/transport_event.proto similarity index 100% rename from lib/protoflow-zeromq/src/transport_event.proto rename to lib/protoflow-zeromq/proto/transport_event.proto From dcbe449caf2fccf144a8eccd1ac21f86d347bba1 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Thu, 5 Dec 2024 21:40:28 +0200 Subject: [PATCH 61/63] Improve test reliability --- lib/protoflow-zeromq/src/event.rs | 2 +- lib/protoflow-zeromq/src/input_port.rs | 2 +- lib/protoflow-zeromq/src/lib.rs | 58 ++++++++++++++++---------- 3 files changed, 38 insertions(+), 24 deletions(-) diff --git a/lib/protoflow-zeromq/src/event.rs b/lib/protoflow-zeromq/src/event.rs index a6506f06..f5f60c19 100644 --- a/lib/protoflow-zeromq/src/event.rs +++ b/lib/protoflow-zeromq/src/event.rs @@ -9,7 +9,7 @@ use zeromq::ZmqMessage; pub type SequenceID = u64; /// ZmqTransportEvent represents the data that goes over the wire from one port to another. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum ZmqTransportEvent { Connect(OutputPortID, InputPortID), AckConnection(OutputPortID, InputPortID), diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index cbd868f6..d3cfeb97 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -22,7 +22,7 @@ pub enum ZmqInputPortRequest { } /// ZmqInputPortEvent represents events that we receive from the background worker of the port. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum ZmqInputPortEvent { Message(Bytes), Closed, diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 105a5fa7..43763dfc 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -366,33 +366,47 @@ mod tests { use super::*; use protoflow_core::{runtimes::StdRuntime, System}; + use std::time::Duration; use futures_util::future::TryFutureExt; use zeromq::{PubSocket, SocketRecv, SocketSend, SubSocket}; async fn start_zmqtransport_server() { - // bind a *SUB* socket to the *PUB* address so that the transport can *PUB* to it - let mut pub_srv = SubSocket::new(); - pub_srv.bind(DEFAULT_PUB_SOCKET).await.unwrap(); - - // bind a *PUB* socket to the *SUB* address so that the transport can *SUB* to it - let mut sub_srv = PubSocket::new(); - sub_srv.bind(DEFAULT_SUB_SOCKET).await.unwrap(); - - // subscribe to all messages - pub_srv.subscribe("").await.unwrap(); - - // resend anything received on the *SUB* socket to the *PUB* socket - tokio::task::spawn(async move { - let mut pub_srv = pub_srv; - loop { - pub_srv - .recv() - .and_then(|msg| sub_srv.send(msg)) - .await - .unwrap(); + // retry for a second + for _ in 0..20 { + // bind a *SUB* socket to the *PUB* address so that the transport can *PUB* to it + let mut pub_srv = SubSocket::new(); + if pub_srv.bind(DEFAULT_PUB_SOCKET).await.is_err() { + tokio::time::sleep(Duration::from_millis(50)).await; + continue; } - }); + + // bind a *PUB* socket to the *SUB* address so that the transport can *SUB* to it + let mut sub_srv = PubSocket::new(); + if sub_srv.bind(DEFAULT_SUB_SOCKET).await.is_err() { + tokio::time::sleep(Duration::from_millis(50)).await; + continue; + } + + // subscribe to all messages + pub_srv.subscribe("").await.unwrap(); + + // resend anything received on the *SUB* socket to the *PUB* socket + tokio::task::spawn(async move { + let mut pub_srv = pub_srv; + loop { + pub_srv + .recv() + .and_then(|msg| sub_srv.send(msg)) + .await + .unwrap(); + } + }); + + return; + } + + panic!("unable to start server for tests, are the ports 10000 and 10001 available?"); } #[test] @@ -407,7 +421,7 @@ mod tests { #[test] fn run_transport() { - tracing_subscriber::fmt::init(); + let _ = tracing_subscriber::fmt::try_init(); let rt = tokio::runtime::Runtime::new().unwrap(); let _guard = rt.enter(); From b09c7f1b5cdd210ccf2f678d111958a29912a9ab Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Sat, 7 Dec 2024 22:03:28 +0200 Subject: [PATCH 62/63] Add test for multiple outputs to single input --- lib/protoflow-zeromq/src/lib.rs | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 43763dfc..2b579c4d 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -458,4 +458,59 @@ mod tests { output.join().expect("thread failed").unwrap(); input.join().expect("thread failed").unwrap(); } + + #[test] + fn multiple_outputs_to_one_input() { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Runtime::new().unwrap(); + let _guard = rt.enter(); + + rt.block_on(start_zmqtransport_server()); + + let transport = ZmqTransport::default(); + let runtime = StdRuntime::new(transport).unwrap(); + let system = System::new(&runtime); + + let mut output1 = system.output(); + let mut output2 = system.output(); + + let mut input = system.input(); + + assert!(system.connect(&output1, &input)); + assert!(system.connect(&output2, &input)); + + output1.send(&"Hello from one!".to_string()).unwrap(); + assert_eq!(Some("Hello from one!".to_string()), input.recv().unwrap()); + + output2.send(&"Hello from two!".to_string()).unwrap(); + assert_eq!(Some("Hello from two!".to_string()), input.recv().unwrap()); + + output1.send(&"Hello from one again!".to_string()).unwrap(); + assert_eq!( + Some("Hello from one again!".to_string()), + input.recv().unwrap() + ); + + assert!(input.close().unwrap()); + assert_eq!( + Err(PortError::Disconnected), + output1.send(&"Hello from one!".to_string()) + ); + assert_eq!( + Err(PortError::Disconnected), + output2.send(&"Hello from two!".to_string()) + ); + + assert_eq!(Err(PortError::Disconnected), input.try_recv()); + + assert!( + !output1.close().unwrap(), + "closing output should return Ok(false) because input was already closed" + ); + assert!( + !output2.close().unwrap(), + "closing output should return Ok(false) because input was already closed" + ); + } } From 2d26531825bb97b532e9dba30678ea2d9fd01af0 Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Sat, 7 Dec 2024 22:03:46 +0200 Subject: [PATCH 63/63] Add test for redelivery to input worker --- lib/protoflow-zeromq/src/input_port.rs | 150 +++++++++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index d3cfeb97..15074e7b 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -467,3 +467,153 @@ pub fn start_input_worker( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + use std::time::Duration; + + #[test] + fn redelivery_is_idempotent() { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Runtime::new().unwrap(); + let _guard = rt.enter(); + + let (pub_queue, mut pub_queue_recv) = channel(1); + let (sub_queue, mut sub_queue_recv) = channel(1); + + let inputs = Arc::new(RwLock::new(BTreeMap::new())); + let outputs = Arc::new(RwLock::new(BTreeMap::new())); + + let output_id = OutputPortID::try_from(1).unwrap(); + let input_id = InputPortID::try_from(-1).unwrap(); + + let transport = ZmqTransport { + tokio: rt.handle().clone(), + pub_queue, + sub_queue, + inputs: inputs.clone(), + outputs: outputs.clone(), + }; + + // start a fake socket worker that just drops all messages + let sub_queue = tokio::task::spawn(async move { + while sub_queue_recv.recv().await.is_some() {} + Some(()) + }); + + start_input_worker(&transport, input_id).unwrap(); + + let (recv_send, recv_recv) = channel(1); + let recv_recv = Arc::new(Mutex::new(recv_recv)); + + // manually connect the port + let (req_sender, event_sender) = rt.block_on(async { + let inputs = inputs.read().await; + let mut input_state = inputs.get(&input_id).unwrap().write().await; + + let ZmqInputPortState::Open(ref req_sender, ref event_sender) = *input_state else { + panic!(""); + }; + let req_sender = req_sender.clone(); + let event_sender = event_sender.clone(); + + let mut connected_ids = BTreeMap::new(); + connected_ids.insert(output_id, 0); + + *input_state = ZmqInputPortState::Connected( + req_sender.clone(), + recv_send, + recv_recv.clone(), + event_sender.clone(), + connected_ids, + ); + + (req_sender.clone(), event_sender.clone()) + }); + + let timeout = Duration::from_secs(1); + + // send a message from the `output_id` to the worker + rt.block_on(tokio::time::timeout( + timeout, + event_sender.send(ZmqTransportEvent::Message( + output_id, + input_id, + 1, + Bytes::new(), + )), + )) + .unwrap() + .unwrap(); + + // verify that the worker tries to publish a msg-ack + assert_eq!( + Some(ZmqTransportEvent::AckMessage(output_id, input_id, 1)), + rt.block_on(pub_queue_recv.recv()) + ); + + // verify that the worker forwards a new message + assert_eq!( + Ok(Some(ZmqInputPortEvent::Message(Bytes::new()))), + rt.block_on(tokio::time::timeout(timeout, async { + recv_recv.lock().await.recv().await + })) + ); + + // send a new message with the same sequence id to the worker + rt.block_on(tokio::time::timeout( + timeout, + event_sender.send(ZmqTransportEvent::Message( + output_id, + input_id, + 1, + Bytes::new(), + )), + )) + .unwrap() + .unwrap(); + + // verify that the worker tries to publish a msg-ack + assert_eq!( + Ok(Some(ZmqTransportEvent::AckMessage(output_id, input_id, 1))), + rt.block_on(tokio::time::timeout(timeout, pub_queue_recv.recv())) + ); + + // verify that the worker *DOESN'T* forward the message + assert!(rt + .block_on(tokio::time::timeout(timeout, async { + recv_recv.lock().await.recv().await + })) + .is_err()); + + let (close_send, mut close_recv) = channel(1); + + // send a close request the worker + rt.block_on(tokio::time::timeout(timeout, async { + req_sender + .send((ZmqInputPortRequest::Close, close_send)) + .await + .unwrap(); + close_recv.recv().await.unwrap() + })) + .unwrap() + .unwrap(); + + // drop remaining references to the channels that the worker is waiting on + drop(event_sender); + drop(req_sender); + drop(transport); + + // verify that the fake socket worker also exits, implies that the worker has exited as the + // channel sender references must be dropped for the fake worker to exit. + assert_eq!( + Some(()), + rt.block_on(tokio::time::timeout(timeout, sub_queue)) + .unwrap() + .unwrap() + ); + } +}