@@ -23,6 +23,7 @@ use opsqueue::{
2323use ux_serde:: u63;
2424
2525use crate :: {
26+ async_util,
2627 common:: { run_unless_interrupted, start_runtime, SubmissionId , SubmissionStatus } ,
2728 errors:: { self , CError , CPyResult , FatalPythonException } ,
2829} ;
@@ -217,17 +218,19 @@ impl ProducerClient {
217218 py. allow_threads ( || {
218219 let prefix = uuid:: Uuid :: now_v7 ( ) . to_string ( ) ;
219220 tracing:: debug!( "Uploading submission chunks to object store subfolder {prefix}..." ) ;
220- let chunk_count = Python :: with_gil ( |py| {
221- self . block_unless_interrupted ( async {
222- let chunk_contents = chunk_contents. bind ( py) ;
223- let stream = futures:: stream:: iter ( chunk_contents)
224- . map ( |item| item. and_then ( |item| item. extract ( ) ) . map_err ( Into :: into) ) ;
221+ let chunk_count = self . block_unless_interrupted ( async {
222+ let chunk_contents = std:: iter:: from_fn ( move || {
223+ Python :: with_gil ( |py|
224+ chunk_contents. bind ( py) . clone ( ) . next ( )
225+ . map ( |item| item. and_then (
226+ |item| item. extract ( ) ) . map_err ( Into :: into) ) )
227+ } ) ;
228+ let stream = futures:: stream:: iter ( chunk_contents) ;
225229 self . object_store_client
226230 . store_chunks ( & prefix, ChunkType :: Input , stream)
227231 . await
228232 . map_err ( |e| CError ( R ( L ( e) ) ) )
229- } )
230- } ) ?;
233+ } ) ?;
231234 let chunk_count = chunk:: ChunkIndex :: from ( chunk_count) ;
232235 tracing:: debug!( "Finished uploading to object store. {prefix} contains {chunk_count} chunks" ) ;
233236
@@ -360,15 +363,18 @@ impl ProducerClient {
360363 ) -> PyResult < Bound < ' p , PyAny > > {
361364 let me = self . clone ( ) ;
362365 let _tokio_active_runtime_guard = me. runtime . enter ( ) ;
363- pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
364- match me. stream_completed_submission_chunks ( submission_id) . await {
365- Ok ( iter) => {
366- let async_iter = PyChunksAsyncIter :: from ( iter) ;
367- Ok ( async_iter)
366+ async_util:: future_into_py (
367+ py,
368+ async_util:: async_allow_threads ( Box :: pin ( async move {
369+ match me. stream_completed_submission_chunks ( submission_id) . await {
370+ Ok ( iter) => {
371+ let async_iter = PyChunksAsyncIter :: from ( iter) ;
372+ Ok ( async_iter)
373+ }
374+ Err ( e) => PyResult :: Err ( e. into ( ) ) ,
368375 }
369- Err ( e) => PyResult :: Err ( e. into ( ) ) ,
370- }
371- } )
376+ } ) ) ,
377+ )
372378 }
373379}
374380
@@ -462,7 +468,7 @@ pub type ChunksStream = BoxStream<'static, CPyResult<Vec<u8>, ChunkRetrievalErro
462468
463469#[ pyclass]
464470pub struct PyChunksIter {
465- stream : tokio:: sync:: Mutex < ChunksStream > ,
471+ stream : Arc < tokio:: sync:: Mutex < ChunksStream > > ,
466472 runtime : Arc < tokio:: runtime:: Runtime > ,
467473}
468474
@@ -475,7 +481,7 @@ impl PyChunksIter {
475481 . map_err ( CError )
476482 . boxed ( ) ;
477483 Self {
478- stream : tokio:: sync:: Mutex :: new ( stream) ,
484+ stream : Arc :: new ( tokio:: sync:: Mutex :: new ( stream) ) ,
479485 runtime : client. runtime . clone ( ) ,
480486 }
481487 }
@@ -487,11 +493,21 @@ impl PyChunksIter {
487493 slf
488494 }
489495
490- fn __next__ ( mut slf : PyRefMut < ' _ , Self > ) -> Option < CPyResult < Vec < u8 > , ChunkRetrievalError > > {
491- let me = & mut * slf;
492- let runtime = & mut me. runtime ;
493- let stream = & mut me. stream ;
494- runtime. block_on ( async { stream. lock ( ) . await . next ( ) . await } )
496+ fn __next__ ( & self , py : Python < ' _ > ) -> Option < CPyResult < Vec < u8 > , ChunkRetrievalError > > {
497+ // The only time we need the GIL is when turning the result back.
498+ // By unlocking here, we reduce the chance of deadlocks.
499+ py. allow_threads ( move || {
500+ let runtime = self . runtime . clone ( ) ;
501+ let stream = self . stream . clone ( ) ;
502+ runtime. block_on ( async {
503+ // We lock the stream in a separate Tokio task
504+ // that explicitly runs on the runtime thread rather than on the main Python thread.
505+ // This reduces the possibility for deadlocks even further.
506+ tokio:: task:: spawn ( async move { stream. lock ( ) . await . next ( ) . await } )
507+ . await
508+ . expect ( "Top-level spawn to succeed" )
509+ } )
510+ } )
495511 }
496512
497513 fn __aiter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
@@ -508,7 +524,7 @@ pub struct PyChunksAsyncIter {
508524impl From < PyChunksIter > for PyChunksAsyncIter {
509525 fn from ( iter : PyChunksIter ) -> Self {
510526 Self {
511- stream : Arc :: new ( iter. stream ) ,
527+ stream : iter. stream ,
512528 runtime : iter. runtime ,
513529 }
514530 }
@@ -520,16 +536,27 @@ impl PyChunksAsyncIter {
520536 slf
521537 }
522538
523- fn __anext__ ( slf : PyRef < ' _ , Self > ) -> PyResult < Bound < ' _ , PyAny > > {
524- let _tokio_active_runtime_guard = slf. runtime . enter ( ) ;
525- let stream = slf. stream . clone ( ) ;
526- pyo3_async_runtimes:: tokio:: future_into_py ( slf. py ( ) , async move {
527- let res = stream. lock ( ) . await . next ( ) . await ;
528- match res {
529- None => Err ( PyStopAsyncIteration :: new_err ( ( ) ) ) ,
530- Some ( Ok ( val) ) => Ok ( Some ( val) ) ,
531- Some ( Err ( e) ) => Err ( e. into ( ) ) ,
532- }
533- } )
539+ fn __anext__ < ' py > ( & self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
540+ let stream = self . stream . clone ( ) ;
541+ let _tokio_active_runtime_guard = self . runtime . enter ( ) ;
542+
543+ async_util:: future_into_py (
544+ py,
545+ // The only time we need the GIL is when turning the result into Python datatypes.
546+ // By unlocking here, we reduce the chance of deadlocks.
547+ async_util:: async_allow_threads ( Box :: pin ( async move {
548+ // We lock the stream in a separate Tokio task
549+ // that explicitly runs on the runtime thread rather than on the main Python thread.
550+ // This reduces the possibility for deadlocks even further.
551+ let res = tokio:: task:: spawn ( async move { stream. lock ( ) . await . next ( ) . await } )
552+ . await
553+ . expect ( "Top-level spawn to succeed" ) ;
554+ match res {
555+ None => Err ( PyStopAsyncIteration :: new_err ( ( ) ) ) ,
556+ Some ( Ok ( val) ) => Ok ( Some ( val) ) ,
557+ Some ( Err ( e) ) => Err ( e. into ( ) ) ,
558+ }
559+ } ) ) ,
560+ )
534561 }
535562}
0 commit comments