Skip to content

feature: add Result on GenServer start #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 145 additions & 4 deletions concurrency/src/tasks/gen_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,45 @@ impl<G: GenServer> GenServerHandle<G> {
let handle_clone = handle.clone();
// Ignore the JoinHandle for now. Maybe we'll use it in the future
let _join_handle = rt::spawn(async move {
if gen_server.run(&handle, &mut rx).await.is_err() {
if gen_server.run(&handle, &mut rx, None).await.is_err() {
tracing::trace!("GenServer crashed")
};
});
handle_clone
}

pub(crate) fn verified_new(gen_server: G) -> Result<Self, GenServerError> {
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
let cancellation_token = CancellationToken::new();
let handle = GenServerHandle {
tx,
cancellation_token,
};
let handle_clone = handle.clone();

// We create a channel of single use to signal when the GenServer has started.
let (mut start_signal_tx, start_signal_rx) = std::sync::mpsc::channel();
// Ignore the JoinHandle for now. Maybe we'll use it in the future
let join_handle = rt::spawn(async move {
if gen_server
.run(&handle, &mut rx, Some(&mut start_signal_tx))
.await
.is_err()
{
tracing::trace!("GenServer crashed")
};
});

// Wait for the GenServer to signal us that it has started
match start_signal_rx.recv() {
Ok(true) => Ok(handle_clone),
_ => {
join_handle.abort(); // Abort the task even tho we know it won't run anymore
Err(GenServerError::Initialization)
}
}
}

pub(crate) fn new_blocking(gen_server: G) -> Self {
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
let cancellation_token = CancellationToken::new();
Expand All @@ -53,14 +85,49 @@ impl<G: GenServer> GenServerHandle<G> {
// Ignore the JoinHandle for now. Maybe we'll use it in the future
let _join_handle = rt::spawn_blocking(|| {
rt::block_on(async move {
if gen_server.run(&handle, &mut rx).await.is_err() {
if gen_server.run(&handle, &mut rx, None).await.is_err() {
tracing::trace!("GenServer crashed")
};
})
});
handle_clone
}

pub(crate) fn verified_new_blocking(gen_server: G) -> Result<Self, GenServerError> {
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
let cancellation_token = CancellationToken::new();
let handle = GenServerHandle {
tx,
cancellation_token,
};
let handle_clone = handle.clone();

// We create a channel of single use to signal when the GenServer has started.
// This channel is used in the verified method, here it's just to keep the API consistent.
// The handle is thereby returned immediately, without waiting for the GenServer to start.
let (mut start_signal_tx, start_signal_rx) = std::sync::mpsc::channel();
let join_handle = rt::spawn_blocking(|| {
rt::block_on(async move {
if gen_server
.run(&handle, &mut rx, Some(&mut start_signal_tx))
.await
.is_err()
{
tracing::trace!("GenServer crashed")
};
})
});

// Wait for the GenServer to signal us that it has started
match start_signal_rx.recv() {
Ok(true) => Ok(handle_clone),
_ => {
join_handle.abort(); // Abort the task even tho we know it won't run anymore
Err(GenServerError::Initialization)
}
}
}

pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
self.tx.clone()
}
Expand Down Expand Up @@ -126,23 +193,43 @@ pub trait GenServer: Send + Sized + Clone {
type OutMsg: Send + Sized;
type Error: Debug + Send;

/// Starts the GenServer, without waiting for it to finalize its `init` process.
fn start(self) -> GenServerHandle<Self> {
GenServerHandle::new(self)
}

/// Starts the GenServer, waiting for it to finalize its `init` process.
fn verified_start(self) -> Result<GenServerHandle<Self>, GenServerError> {
GenServerHandle::verified_new(self)
}

/// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
/// happen if the task is blocking the thread. As such, for sync compute task
/// or other blocking tasks need to be in their own separate thread, and the OS
/// will manage them through hardware interrupts.
/// Start blocking provides such thread.
///
/// As with `start`, it doesn't wait for the GenServer to finalize its `init` process.
fn start_blocking(self) -> GenServerHandle<Self> {
GenServerHandle::new_blocking(self)
}

/// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
/// happen if the task is blocking the thread. As such, for sync compute task
/// or other blocking tasks need to be in their own separate thread, and the OS
/// will manage them through hardware interrupts.
/// Start blocking provides such thread.
///
/// As with `verified_start`, it waits for the GenServer to finalize its `init` process.
fn verified_start_blocking(self) -> Result<GenServerHandle<Self>, GenServerError> {
GenServerHandle::verified_new_blocking(self)
}

fn run(
self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
start_signal_tx: Option<&mut std::sync::mpsc::Sender<bool>>,
) -> impl Future<Output = Result<(), GenServerError>> + Send {
async {
let init_result = self
Expand All @@ -151,8 +238,26 @@ pub trait GenServer: Send + Sized + Clone {
.inspect_err(|err| tracing::error!("Initialization failed: {err:?}"));

let res = match init_result {
Ok(new_state) => new_state.main_loop(handle, rx).await,
Err(_) => Err(GenServerError::Initialization),
Ok(new_state) => {
// Notify that the GenServer has started successfully
// in case we have a start signal channel
if let Some(start_signal_tx) = start_signal_tx {
start_signal_tx
.send(true)
.map_err(|_| GenServerError::Initialization)?;
}
new_state.main_loop(handle, rx).await
}
Err(_) => {
// Notify that the GenServer failed to start
// in case we have a start signal channel
if let Some(start_signal_tx) = start_signal_tx {
start_signal_tx
.send(false)
.map_err(|_| GenServerError::Initialization)?;
}
Err(GenServerError::Initialization)
}
};

handle.cancellation_token().cancel();
Expand Down Expand Up @@ -469,4 +574,40 @@ mod tests {
assert!(matches!(result, Err(GenServerError::CallTimeout)));
});
}

#[derive(Clone)]
struct FailsOnInitTask;

impl GenServer for FailsOnInitTask {
type CallMsg = ();
type CastMsg = ();
type OutMsg = ();
type Error = ();

async fn init(self, _handle: &GenServerHandle<Self>) -> Result<Self, Self::Error> {
Err(())
}
}

#[test]
pub fn failing_on_init_task() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
// Attempt to start a GenServer that fails on initialization
let result = FailsOnInitTask.verified_start();
assert!(matches!(result, Err(GenServerError::Initialization)));

// Attempt to start a GenServer (in a blocking way) that fails on initialization
let result = FailsOnInitTask.verified_start_blocking();
assert!(matches!(result, Err(GenServerError::Initialization)));

// Other tasks should start correctly
let result = WellBehavedTask { count: 0 }.verified_start();
assert!(result.is_ok());

// They also should start in blocking mode
let result = WellBehavedTask { count: 0 }.verified_start_blocking();
assert!(result.is_ok());
});
}
}