From deac775566d62246248a48fd3df58d6b5c03b729 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Wed, 1 Apr 2026 16:05:04 -0700 Subject: [PATCH] Add Opts::after_connect callback --- src/conn/mod.rs | 41 ++++++++++++++++++++++++++++++++++++++++- src/opts/mod.rs | 47 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 81a2483..b732370 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -997,6 +997,10 @@ impl Conn { } async fn run_init_commands(&mut self) -> Result<()> { + if let Some(callback) = self.inner.opts.after_connect() { + callback.clone()(self).await?; + } + let mut init = self.inner.opts.init().to_vec(); while let Some(query) = init.pop() { @@ -1382,8 +1386,13 @@ impl Conn { #[cfg(test)] mod test { + use std::sync::Arc; + use bytes::Bytes; - use futures_util::stream::{self, StreamExt}; + use futures_util::{ + stream::{self, StreamExt}, + FutureExt, + }; use mysql_common::constants::{MariadbCapabilities, MAX_PAYLOAD_LEN}; use rand::RngExt as _; use tokio::{io::AsyncWriteExt, net::TcpListener}; @@ -1587,6 +1596,36 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_execute_after_connect_callback_on_new_connection() -> super::Result<()> { + let opts = OptsBuilder::from_opts(get_opts()).after_connect(Arc::new(|conn| { + async move { + conn.query_drop("SET @a = 42").await?; + conn.query_drop("SET @b = 'foo'").await?; + Ok(()) + } + .boxed() + })); + let mut conn = Conn::new(opts).await?; + let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?; + conn.disconnect().await?; + assert_eq!(result, vec![(42, "foo".into())]); + Ok(()) + } + + #[tokio::test] + async fn should_propagate_after_connect_callback_error() -> super::Result<()> { + let opts = OptsBuilder::from_opts(get_opts()).after_connect(Arc::new(|_conn| { + async move { Err(Error::Other("rejected".into())) }.boxed() + })); + let e = Conn::new(opts).await.unwrap_err(); + match e { + Error::Other(e) => assert_eq!(e.to_string(), "rejected"), + e => panic!("expected error from after_connect(), got {e:?}"), + } + Ok(()) + } + #[tokio::test] async fn should_execute_setup_queries_on_reset() -> super::Result<()> { let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]); diff --git a/src/opts/mod.rs b/src/opts/mod.rs index c8082f6..e2e4237 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -562,6 +562,25 @@ pub(crate) struct InnerOpts { address: HostPortOrUrl, } +#[derive(Clone)] +pub(crate) struct AfterConnectCallback( + Arc Fn(&'a mut crate::Conn) -> crate::BoxFuture<'a, ()> + Send + Sync>, +); + +impl Eq for AfterConnectCallback {} + +impl PartialEq for AfterConnectCallback { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl fmt::Debug for AfterConnectCallback { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("AfterConnectCallback").finish() + } +} + /// Mysql connection options. /// /// Build one with [`OptsBuilder`]. @@ -595,6 +614,9 @@ pub(crate) struct MysqlOpts { /// (defaults to `wait_timeout`). conn_ttl: Option, + /// Callback to execute once a new connection is established. + after_connect: Option, + /// Commands to execute once new connection is established. init: Vec, @@ -784,7 +806,18 @@ impl Opts { self.inner.mysql_opts.db_name.as_ref().map(AsRef::as_ref) } - /// Commands to execute once new connection is established. + /// Callback to execute after opening a new connection to the database. Runs + /// before the [`init`][Self::init] queries. + /// + /// If this returns an error, the connection attempt will also fail. + pub fn after_connect( + &self, + ) -> Option<&Arc Fn(&'a mut crate::Conn) -> crate::BoxFuture<'a, ()> + Send + Sync>> + { + self.inner.mysql_opts.after_connect.as_ref().map(|cb| &cb.0) + } + + /// Commands to execute once new a connection is established. pub fn init(&self) -> &[String] { self.inner.mysql_opts.init.as_ref() } @@ -1143,6 +1176,7 @@ impl Default for MysqlOpts { user: None, pass: None, db_name: None, + after_connect: None, init: vec![], setup: vec![], tcp_keepalive: None, @@ -1358,6 +1392,17 @@ impl OptsBuilder { self } + /// Defines a callback that runs after connection. See [`Opts::after_connect`]. + pub fn after_connect( + mut self, + callback: Arc< + dyn for<'a> Fn(&'a mut crate::Conn) -> crate::BoxFuture<'a, ()> + Send + Sync, + >, + ) -> Self { + self.opts.after_connect = Some(AfterConnectCallback(callback)); + self + } + /// Defines initial queries. See [`Opts::init`]. pub fn init>(mut self, init: Vec) -> Self { self.opts.init = init.into_iter().map(Into::into).collect();