Skip to content

Commit fad8da6

Browse files
committed
calling sync from async
1 parent 5e7c548 commit fad8da6

File tree

1 file changed

+120
-26
lines changed

1 file changed

+120
-26
lines changed

src/lib.rs

Lines changed: 120 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,11 @@ mod rustpython_runner;
1313
use once_cell::sync::Lazy;
1414
use serde_json::Value;
1515
use std::path::{Path, PathBuf};
16+
use std::sync::mpsc as std_mpsc;
1617
use std::thread;
1718
use thiserror::Error;
18-
use tokio::runtime::Runtime;
1919
use tokio::sync::{mpsc, oneshot};
20-
21-
/// A lazily-initialized global Tokio runtime for synchronous functions.
22-
static SYNC_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
23-
Runtime::new().expect("Failed to create a new Tokio runtime for sync functions")
24-
});
20+
use tokio::runtime::{Builder, Runtime};
2521

2622
#[derive(Debug)]
2723
pub(crate) enum CmdType {
@@ -41,8 +37,26 @@ pub(crate) struct PyCommand {
4137
responder: oneshot::Sender<Result<Value, String>>,
4238
}
4339

40+
/// A boxed, send-able future that resolves to a PyRunnerResult.
41+
type Task = Box<dyn FnOnce(&Runtime) -> Result<Value, PyRunnerError> + Send>;
42+
43+
/// A lazily-initialized worker thread for handling synchronous function calls.
44+
/// This thread has its own private Tokio runtime to safely block on async operations
45+
/// without interfering with any existing runtime the user might be in.
46+
static SYNC_WORKER: Lazy<std_mpsc::Sender<Task>> = Lazy::new(|| {
47+
let (tx, rx) = std_mpsc::channel::<Task>();
48+
49+
thread::spawn(move || {
50+
let rt = Runtime::new().expect("Failed to create Tokio runtime for sync worker");
51+
// When the sender (tx) is dropped, rx.recv() will return an Err, ending the loop.
52+
while let Ok(task) = rx.recv() {
53+
let _ = task(&rt); // The result is sent back via a channel inside the task.
54+
}
55+
});
56+
tx
57+
});
4458
/// Custom error types for the `PyRunner`.
45-
#[derive(Error, Debug)]
59+
#[derive(Error, Debug, Clone)]
4660
pub enum PyRunnerError {
4761
#[error("Failed to send command to Python thread. The thread may have panicked.")]
4862
SendCommandFailed,
@@ -96,13 +110,15 @@ impl PyRunner {
96110
thread::spawn(move || {
97111
#[cfg(all(feature = "pyo3", not(feature = "rustpython")))]
98112
{
99-
use tokio::runtime::Builder;
100113
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
101114
rt.block_on(pyo3_runner::python_thread_main(receiver));
102115
}
103116

104117
#[cfg(feature = "rustpython")]
105-
rustpython_runner::python_thread_main(receiver);
118+
{
119+
let rt = Builder::new_current_thread().enable_all().build().unwrap();
120+
rt.block_on(rustpython_runner::python_thread_main(receiver));
121+
}
106122
});
107123

108124
Self { sender }
@@ -131,6 +147,32 @@ impl PyRunner {
131147
.map_err(PyRunnerError::PyError)
132148
}
133149

150+
/// A private helper function to encapsulate the logic of sending a command
151+
/// and receiving a response synchronously.
152+
fn send_command_sync(&self, cmd_type: CmdType) -> Result<Value, PyRunnerError> {
153+
let (tx, rx) = std_mpsc::channel();
154+
let sender = self.sender.clone();
155+
156+
let cmd_type_clone = cmd_type; // Clone is implicit as CmdType is Copy
157+
let task = Box::new(move |rt: &Runtime| {
158+
let result = rt.block_on(async {
159+
// This is the async `send_command` logic, but we can't call it
160+
// directly because of `&self` lifetime issues inside the closure.
161+
let (responder, receiver) = oneshot::channel();
162+
let cmd = PyCommand { cmd_type: cmd_type_clone, responder };
163+
sender.send(cmd).await.map_err(|_| PyRunnerError::SendCommandFailed)?;
164+
receiver.await.map_err(|_| PyRunnerError::ReceiveResultFailed.clone())?
165+
.map_err(PyRunnerError::PyError)
166+
});
167+
if tx.send(result.clone()).is_err() {
168+
return Err(PyRunnerError::SendCommandFailed);
169+
}
170+
result
171+
});
172+
173+
SYNC_WORKER.send(task).map_err(|_| PyRunnerError::SendCommandFailed)?;
174+
rx.recv().map_err(|_| PyRunnerError::ReceiveResultFailed)?
175+
}
134176
/// Asynchronously executes a block of Python code.
135177
///
136178
/// * `code`: A string slice containing the Python code to execute.
@@ -148,9 +190,10 @@ impl PyRunner {
148190
///
149191
/// * `code`: A string slice containing the Python code to execute.
150192
///
151-
/// **Note:** Calling this from an existing async runtime can lead to panics.
193+
/// **Note:** This function is safe to call from any context (sync or async).
152194
pub fn run_sync(&self, code: &str) -> Result<(), PyRunnerError> {
153-
SYNC_RUNTIME.block_on(self.run(code))
195+
self.send_command_sync(CmdType::RunCode(code.into()))
196+
.map(|_| ())
154197
}
155198

156199
/// Asynchronously runs a python file.
@@ -169,9 +212,10 @@ impl PyRunner {
169212
///
170213
/// * `file`: Absolute path to a python file to execute.
171214
///
172-
/// **Note:** Calling this from an existing async runtime can lead to panics.
215+
/// **Note:** This function is safe to call from any context (sync or async).
173216
pub fn run_file_sync(&self, file: &Path) -> Result<(), PyRunnerError> {
174-
SYNC_RUNTIME.block_on(self.run_file(file))
217+
self.send_command_sync(CmdType::RunFile(file.to_path_buf()))
218+
.map(|_| ())
175219
}
176220

177221
/// Asynchronously evaluates a single Python expression.
@@ -191,9 +235,9 @@ impl PyRunner {
191235
///
192236
/// * `code`: A string slice containing the Python expression to evaluate.
193237
///
194-
/// **Note:** Calling this from an existing async runtime can lead to panics.
238+
/// **Note:** This function is safe to call from any context (sync or async).
195239
pub fn eval_sync(&self, code: &str) -> Result<Value, PyRunnerError> {
196-
SYNC_RUNTIME.block_on(self.eval(code))
240+
self.send_command_sync(CmdType::EvalCode(code.into()))
197241
}
198242

199243
/// Asynchronously reads a variable from the Python interpreter's global scope.
@@ -213,9 +257,9 @@ impl PyRunner {
213257
///
214258
/// * `var_name`: The name of the variable to read.
215259
///
216-
/// **Note:** Calling this from an existing async runtime can lead to panics.
260+
/// **Note:** This function is safe to call from any context (sync or async).
217261
pub fn read_variable_sync(&self, var_name: &str) -> Result<Value, PyRunnerError> {
218-
SYNC_RUNTIME.block_on(self.read_variable(var_name))
262+
self.send_command_sync(CmdType::ReadVariable(var_name.into()))
219263
}
220264

221265
/// Asynchronously calls a Python function in the interpreter's global scope.
@@ -246,15 +290,17 @@ impl PyRunner {
246290
/// * `name`: The name of the function to call.
247291
/// * `args`: A vector of `serde_json::Value` to pass as arguments to the function.
248292
///
249-
/// **Note:** Calling this from an existing async runtime can lead to panics.
250-
/// This is for calling from a non-async context.
293+
/// **Note:** This function is safe to call from any context (sync or async).
251294
#[cfg(feature = "pyo3")]
252295
pub fn call_function_sync(
253296
&self,
254297
name: &str,
255298
args: Vec<Value>,
256299
) -> Result<Value, PyRunnerError> {
257-
SYNC_RUNTIME.block_on(self.call_function(name, args))
300+
self.send_command_sync(CmdType::CallFunction {
301+
name: name.into(),
302+
args,
303+
})
258304
}
259305

260306
/// Asynchronously calls an async Python function in the interpreter's global scope.
@@ -284,14 +330,17 @@ impl PyRunner {
284330
/// * `name`: The name of the function to call.
285331
/// * `args`: A vector of `serde_json::Value` to pass as arguments to the function.
286332
///
287-
/// **Note:** Calling this from an existing async runtime can lead to panics.
333+
/// **Note:** This function is safe to call from any context (sync or async).
288334
#[cfg(feature = "pyo3")]
289335
pub fn call_async_function_sync(
290336
&self,
291337
name: &str,
292338
args: Vec<Value>,
293339
) -> Result<Value, PyRunnerError> {
294-
SYNC_RUNTIME.block_on(self.call_async_function(name, args))
340+
self.send_command_sync(CmdType::CallAsyncFunction {
341+
name: name.into(),
342+
args,
343+
})
295344
}
296345

297346
/// Stops the Python execution thread gracefully.
@@ -306,9 +355,9 @@ impl PyRunner {
306355
/// This is a blocking wrapper around `stop`. It is intended for use in
307356
/// synchronous applications.
308357
///
309-
/// **Note:** Calling this from an existing async runtime can lead to panics.
358+
/// **Note:** This function is safe to call from any context (sync or async).
310359
pub fn stop_sync(&self) -> Result<(), PyRunnerError> {
311-
SYNC_RUNTIME.block_on(self.stop())
360+
self.send_command_sync(CmdType::Stop).map(|_| ())
312361
}
313362

314363
/// Set python venv environment folder (does not change interpreter)
@@ -352,9 +401,37 @@ impl PyRunner {
352401
///
353402
/// * `venv_path`: Path to the venv directory.
354403
///
355-
/// **Note:** Calling this from an existing async runtime can lead to panics.
404+
/// **Note:** This function is safe to call from any context (sync or async).
356405
pub fn set_venv_sync(&self, venv_path: &Path) -> Result<(), PyRunnerError> {
357-
SYNC_RUNTIME.block_on(self.set_venv(venv_path))
406+
if !venv_path.is_dir() {
407+
return Err(PyRunnerError::PyError(format!(
408+
"Could not find venv directory {}",
409+
venv_path.display()
410+
)));
411+
}
412+
let set_venv_code = include_str!("set_venv.py");
413+
self.run_sync(&set_venv_code)?;
414+
415+
let site_packages = if cfg!(target_os = "windows") {
416+
venv_path.join("Lib").join("site-packages")
417+
} else {
418+
let version_code = "f\"python{sys.version_info.major}.{sys.version_info.minor}\"";
419+
let py_version = self.eval_sync(version_code)?;
420+
venv_path
421+
.join("lib")
422+
.join(py_version.as_str().unwrap())
423+
.join("site-packages")
424+
};
425+
#[cfg(all(feature = "pyo3", not(feature = "rustpython")))]
426+
let with_pth = "True";
427+
#[cfg(feature = "rustpython")]
428+
let with_pth = "False";
429+
430+
self.run_sync(&format!(
431+
"add_venv_libs_to_syspath({}, {})",
432+
print_path_for_python(&site_packages),
433+
with_pth
434+
))
358435
}
359436
}
360437

@@ -393,6 +470,23 @@ z = x + y"#;
393470
assert_eq!(z_val, Value::Number(30.into()));
394471
}
395472

473+
474+
#[tokio::test]
475+
async fn test_run_sync_from_async() {
476+
let executor = PyRunner::new();
477+
let code = r#"
478+
x = 10
479+
y = 20
480+
z = x + y"#;
481+
482+
let result_module = executor.run(code).await;
483+
484+
assert!(result_module.is_ok());
485+
486+
let z_val = executor.read_variable_sync("z").unwrap();
487+
488+
assert_eq!(z_val, Value::Number(30.into()));
489+
}
396490

397491
#[tokio::test]
398492
async fn test_run_with_function() {

0 commit comments

Comments
 (0)