diff --git a/concurrency/src/tasks/gen_server.rs b/concurrency/src/tasks/gen_server.rs index 88eb743..50e6a40 100644 --- a/concurrency/src/tasks/gen_server.rs +++ b/concurrency/src/tasks/gen_server.rs @@ -35,13 +35,45 @@ impl GenServerHandle { let handle_clone = handle.clone(); // Ignore the JoinHandle for now. Maybe we'll use it in the future let _join_handle = rt::spawn(async move { - if gen_server.run(&handle, &mut rx).await.is_err() { + if gen_server.run(&handle, &mut rx, None).await.is_err() { tracing::trace!("GenServer crashed") }; }); handle_clone } + pub(crate) fn verified_new(gen_server: G) -> Result { + let (tx, mut rx) = mpsc::channel::>(); + let cancellation_token = CancellationToken::new(); + let handle = GenServerHandle { + tx, + cancellation_token, + }; + let handle_clone = handle.clone(); + + // We create a channel of single use to signal when the GenServer has started. + let (mut start_signal_tx, start_signal_rx) = std::sync::mpsc::channel(); + // Ignore the JoinHandle for now. Maybe we'll use it in the future + let join_handle = rt::spawn(async move { + if gen_server + .run(&handle, &mut rx, Some(&mut start_signal_tx)) + .await + .is_err() + { + tracing::trace!("GenServer crashed") + }; + }); + + // Wait for the GenServer to signal us that it has started + match start_signal_rx.recv() { + Ok(true) => Ok(handle_clone), + _ => { + join_handle.abort(); // Abort the task even tho we know it won't run anymore + Err(GenServerError::Initialization) + } + } + } + pub(crate) fn new_blocking(gen_server: G) -> Self { let (tx, mut rx) = mpsc::channel::>(); let cancellation_token = CancellationToken::new(); @@ -53,7 +85,7 @@ impl GenServerHandle { // Ignore the JoinHandle for now. Maybe we'll use it in the future let _join_handle = rt::spawn_blocking(|| { rt::block_on(async move { - if gen_server.run(&handle, &mut rx).await.is_err() { + if gen_server.run(&handle, &mut rx, None).await.is_err() { tracing::trace!("GenServer crashed") }; }) @@ -61,6 +93,41 @@ impl GenServerHandle { handle_clone } + pub(crate) fn verified_new_blocking(gen_server: G) -> Result { + let (tx, mut rx) = mpsc::channel::>(); + let cancellation_token = CancellationToken::new(); + let handle = GenServerHandle { + tx, + cancellation_token, + }; + let handle_clone = handle.clone(); + + // We create a channel of single use to signal when the GenServer has started. + // This channel is used in the verified method, here it's just to keep the API consistent. + // The handle is thereby returned immediately, without waiting for the GenServer to start. + let (mut start_signal_tx, start_signal_rx) = std::sync::mpsc::channel(); + let join_handle = rt::spawn_blocking(|| { + rt::block_on(async move { + if gen_server + .run(&handle, &mut rx, Some(&mut start_signal_tx)) + .await + .is_err() + { + tracing::trace!("GenServer crashed") + }; + }) + }); + + // Wait for the GenServer to signal us that it has started + match start_signal_rx.recv() { + Ok(true) => Ok(handle_clone), + _ => { + join_handle.abort(); // Abort the task even tho we know it won't run anymore + Err(GenServerError::Initialization) + } + } + } + pub fn sender(&self) -> mpsc::Sender> { self.tx.clone() } @@ -126,23 +193,43 @@ pub trait GenServer: Send + Sized + Clone { type OutMsg: Send + Sized; type Error: Debug + Send; + /// Starts the GenServer, without waiting for it to finalize its `init` process. fn start(self) -> GenServerHandle { GenServerHandle::new(self) } + /// Starts the GenServer, waiting for it to finalize its `init` process. + fn verified_start(self) -> Result, GenServerError> { + GenServerHandle::verified_new(self) + } + /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't /// happen if the task is blocking the thread. As such, for sync compute task /// or other blocking tasks need to be in their own separate thread, and the OS /// will manage them through hardware interrupts. /// Start blocking provides such thread. + /// + /// As with `start`, it doesn't wait for the GenServer to finalize its `init` process. fn start_blocking(self) -> GenServerHandle { GenServerHandle::new_blocking(self) } + /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't + /// happen if the task is blocking the thread. As such, for sync compute task + /// or other blocking tasks need to be in their own separate thread, and the OS + /// will manage them through hardware interrupts. + /// Start blocking provides such thread. + /// + /// As with `verified_start`, it waits for the GenServer to finalize its `init` process. + fn verified_start_blocking(self) -> Result, GenServerError> { + GenServerHandle::verified_new_blocking(self) + } + fn run( self, handle: &GenServerHandle, rx: &mut mpsc::Receiver>, + start_signal_tx: Option<&mut std::sync::mpsc::Sender>, ) -> impl Future> + Send { async { let init_result = self @@ -151,8 +238,26 @@ pub trait GenServer: Send + Sized + Clone { .inspect_err(|err| tracing::error!("Initialization failed: {err:?}")); let res = match init_result { - Ok(new_state) => new_state.main_loop(handle, rx).await, - Err(_) => Err(GenServerError::Initialization), + Ok(new_state) => { + // Notify that the GenServer has started successfully + // in case we have a start signal channel + if let Some(start_signal_tx) = start_signal_tx { + start_signal_tx + .send(true) + .map_err(|_| GenServerError::Initialization)?; + } + new_state.main_loop(handle, rx).await + } + Err(_) => { + // Notify that the GenServer failed to start + // in case we have a start signal channel + if let Some(start_signal_tx) = start_signal_tx { + start_signal_tx + .send(false) + .map_err(|_| GenServerError::Initialization)?; + } + Err(GenServerError::Initialization) + } }; handle.cancellation_token().cancel(); @@ -469,4 +574,40 @@ mod tests { assert!(matches!(result, Err(GenServerError::CallTimeout))); }); } + + #[derive(Clone)] + struct FailsOnInitTask; + + impl GenServer for FailsOnInitTask { + type CallMsg = (); + type CastMsg = (); + type OutMsg = (); + type Error = (); + + async fn init(self, _handle: &GenServerHandle) -> Result { + Err(()) + } + } + + #[test] + pub fn failing_on_init_task() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + // Attempt to start a GenServer that fails on initialization + let result = FailsOnInitTask.verified_start(); + assert!(matches!(result, Err(GenServerError::Initialization))); + + // Attempt to start a GenServer (in a blocking way) that fails on initialization + let result = FailsOnInitTask.verified_start_blocking(); + assert!(matches!(result, Err(GenServerError::Initialization))); + + // Other tasks should start correctly + let result = WellBehavedTask { count: 0 }.verified_start(); + assert!(result.is_ok()); + + // They also should start in blocking mode + let result = WellBehavedTask { count: 0 }.verified_start_blocking(); + assert!(result.is_ok()); + }); + } }