@@ -20,6 +20,7 @@ use opsqueue::{
2020 tracing:: CarrierMap ,
2121 E ,
2222} ;
23+ use pyo3_async_runtimes:: TaskLocals ;
2324use ux_serde:: u63;
2425
2526use crate :: {
@@ -361,7 +362,7 @@ impl ProducerClient {
361362 submission_id : SubmissionId ,
362363 ) -> PyResult < Bound < ' p , PyAny > > {
363364 let me = self . clone ( ) ;
364- let _tokio_active_runtime_guard = me. runtime . enter ( ) ;
365+ // let _tokio_active_runtime_guard = me.runtime.enter();
365366 pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
366367 match me. stream_completed_submission_chunks ( submission_id) . await {
367368 Ok ( iter) => {
@@ -464,7 +465,7 @@ pub type ChunksStream = BoxStream<'static, CPyResult<Vec<u8>, ChunkRetrievalErro
464465
465466#[ pyclass]
466467pub struct PyChunksIter {
467- stream : tokio:: sync:: Mutex < ChunksStream > ,
468+ stream : Arc < tokio:: sync:: Mutex < ChunksStream > > ,
468469 runtime : Arc < tokio:: runtime:: Runtime > ,
469470}
470471
@@ -477,7 +478,7 @@ impl PyChunksIter {
477478 . map_err ( CError )
478479 . boxed ( ) ;
479480 Self {
480- stream : tokio:: sync:: Mutex :: new ( stream) ,
481+ stream : Arc :: new ( tokio:: sync:: Mutex :: new ( stream) ) ,
481482 runtime : client. runtime . clone ( ) ,
482483 }
483484 }
@@ -490,10 +491,18 @@ impl PyChunksIter {
490491 }
491492
492493 fn __next__ ( mut slf : PyRefMut < ' _ , Self > ) -> Option < CPyResult < Vec < u8 > , ChunkRetrievalError > > {
494+ let py = slf. py ( ) ;
493495 let me = & mut * slf;
494496 let runtime = & mut me. runtime ;
495- let stream = & mut me. stream ;
496- runtime. block_on ( async { stream. lock ( ) . await . next ( ) . await } )
497+ let stream = me. stream . clone ( ) ;
498+ tracing:: warn!( "Grabbing another element from PyChunksIter" ) ;
499+ py. allow_threads ( move || {
500+ runtime. block_on ( async {
501+ tokio:: task:: spawn ( async move { stream. lock ( ) . await . next ( ) . await } )
502+ . await
503+ . unwrap ( )
504+ } )
505+ } )
497506 }
498507
499508 fn __aiter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
@@ -510,7 +519,7 @@ pub struct PyChunksAsyncIter {
510519impl From < PyChunksIter > for PyChunksAsyncIter {
511520 fn from ( iter : PyChunksIter ) -> Self {
512521 Self {
513- stream : Arc :: new ( iter. stream ) ,
522+ stream : iter. stream ,
514523 runtime : iter. runtime ,
515524 }
516525 }
@@ -523,15 +532,135 @@ impl PyChunksAsyncIter {
523532 }
524533
525534 fn __anext__ ( slf : PyRef < ' _ , Self > ) -> PyResult < Bound < ' _ , PyAny > > {
526- let _tokio_active_runtime_guard = slf. runtime . enter ( ) ;
535+ println ! ( "A" ) ;
536+
537+ println ! ( "B" ) ;
527538 let stream = slf. stream . clone ( ) ;
528- pyo3_async_runtimes:: tokio:: future_into_py ( slf. py ( ) , async move {
529- let res = stream. lock ( ) . await . next ( ) . await ;
530- match res {
531- None => Err ( PyStopAsyncIteration :: new_err ( ( ) ) ) ,
532- Some ( Ok ( val) ) => Ok ( Some ( val) ) ,
533- Some ( Err ( e) ) => Err ( e. into ( ) ) ,
534- }
539+ println ! ( "C" ) ;
540+
541+ let _tokio_active_runtime_guard = slf. runtime . enter ( ) ;
542+ pyo3_async_runtimes:: generic:: future_into_py :: < TokioRuntimeThatIsInScope , _ , _ > (
543+ slf. py ( ) ,
544+ async move {
545+ let res = stream. lock ( ) . await . next ( ) . await ;
546+
547+ match res {
548+ None => Err ( PyStopAsyncIteration :: new_err ( ( ) ) ) ,
549+ Some ( Ok ( val) ) => Ok ( Some ( val) ) ,
550+ Some ( Err ( e) ) => Err ( e. into ( ) ) ,
551+ }
552+ } ,
553+ )
554+ // pyo3_async_runtimes::generic::future_into_py::<TokioRuntimeThatIsInScope, _, _>(slf.py(), async move {
555+ // println!("D");
556+ // tokio::task::yield_now().await;
557+
558+ // let res = Python::with_gil(|py| py.allow_threads(async || {
559+ // tokio::task::yield_now().await;
560+ // stream.lock().await.next().await
561+ // })).await;
562+ // match res {
563+ // None => Err(PyStopAsyncIteration::new_err(())),
564+ // Some(Ok(val)) => Ok(Some(val)),
565+ // Some(Err(e)) => Err(e.into()),
566+ // }
567+ // })
568+ // let stream = slf.stream.clone();
569+ // let runtime = slf.runtime.clone();
570+ // let py = slf.py();
571+ // // future_into_py eats the `py` token but does not by itself release the GIL
572+ // // Therefore, we immediately 'reacquire' the GIL (a no-op) to explicitly call `py.allow_threads`.
573+ // pyo3_async_runtimes::tokio::future_into_py(py, async move {
574+ // // let res =
575+ // let res = Python::with_gil(|py| {
576+ // py.allow_threads(|| {
577+ // runtime.spawn(async move {stream.lock().await.next().await})
578+ // // stream.lock().await.next().await
579+ // })
580+ // })
581+ // .await.expect("Top-level spawn should not be canceled");
582+ // ;
583+
584+ // match res {
585+ // None => Err(PyStopAsyncIteration::new_err(())),
586+ // Some(Ok(val)) => Ok(Some(val)),
587+ // Some(Err(e)) => Err(e.into()),
588+ // }
589+ // })
590+
591+ // // Based on https://github.com/awestlake87/pyo3-asyncio/issues/59#issuecomment-1007680179
592+ // pyo3_async_runtimes::tokio::future_into_py(slf.py(), async move {
593+ // tokio::task::yield_now().await;
594+
595+ // // We run this on a separate thread, to reduce the chance of deadlocks.
596+ // let result = Python::with_gil(|py| py.allow_threads(move ||
597+ // tokio::task::spawn_blocking(move || {
598+ // tracing::warn!("Grabbing another element from PyChunksAsyncIter");
599+ // let res = tokio::task::LocalSet::new().block_on(&runtime, async move {
600+ // tokio::task::yield_now().await;
601+ // stream.lock().await.next().await
602+ // });
603+ // match res {
604+ // None => Err(PyStopAsyncIteration::new_err(())),
605+ // Some(Ok(val)) => Ok(Some(val)),
606+ // Some(Err(e)) => Err(e.into()),
607+ // }
608+ // })));
609+ // result.await.expect("JoinHandle should always succeed")
610+ // })
611+ }
612+ }
613+
614+ struct TokioRuntimeThatIsInScope ( ) ;
615+
616+ use once_cell:: unsync:: OnceCell as UnsyncOnceCell ;
617+
618+ tokio:: task_local! {
619+ static TASK_LOCALS : UnsyncOnceCell <TaskLocals >;
620+ }
621+
622+ impl pyo3_async_runtimes:: generic:: Runtime for TokioRuntimeThatIsInScope {
623+ type JoinError = tokio:: task:: JoinError ;
624+ type JoinHandle = tokio:: task:: JoinHandle < ( ) > ;
625+
626+ fn spawn < F > ( fut : F ) -> Self :: JoinHandle
627+ where
628+ F : std:: future:: Future < Output = ( ) > + Send + ' static ,
629+ {
630+ println ! ( "About to spawn" ) ;
631+ // Python::with_gil(|py| {
632+ // println!("reacquired GIL");
633+ // py.allow_threads(|| {
634+ // println!("Allowing threads");
635+ tokio:: task:: spawn ( async move {
636+ println ! ( "Inside spawn" ) ;
637+ fut. await ;
535638 } )
639+ // })
640+ // })
641+ }
642+ }
643+
644+ impl pyo3_async_runtimes:: generic:: ContextExt for TokioRuntimeThatIsInScope {
645+ fn scope < F , R > (
646+ locals : TaskLocals ,
647+ fut : F ,
648+ ) -> std:: pin:: Pin < Box < dyn std:: future:: Future < Output = R > + Send > >
649+ where
650+ F : std:: future:: Future < Output = R > + Send + ' static ,
651+ {
652+ let cell = UnsyncOnceCell :: new ( ) ;
653+ cell. set ( locals) . unwrap ( ) ;
654+
655+ Box :: pin ( TASK_LOCALS . scope ( cell, fut) )
656+ }
657+
658+ fn get_task_locals ( ) -> Option < TaskLocals > {
659+ TASK_LOCALS
660+ . try_with ( |c| {
661+ c. get ( )
662+ . map ( |locals| Python :: with_gil ( |py| locals. clone_ref ( py) ) )
663+ } )
664+ . unwrap_or_default ( )
536665 }
537666}
0 commit comments