From 9e1f187318f6a2225cad2b12807de6d1a74f3d49 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Fri, 20 Mar 2026 17:46:40 -0400 Subject: [PATCH 1/3] accel: implement background progress with rayon-cancel --- Cargo.lock | 10 +++++++++ Cargo.toml | 1 + src/accel/als/explicit.rs | 23 ++++++++++---------- src/accel/progress.rs | 46 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 67 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 984b27e2d..d4d75382a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -625,6 +625,7 @@ dependencies = [ "rand", "rand_pcg", "rayon", + "rayon-cancel", "rustc-hash", "serde", "serde_json", @@ -1022,6 +1023,15 @@ dependencies = [ "rayon-core", ] +[[package]] +name = "rayon-cancel" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4f327c66965c769eb84538cdb86b63ce33046f2ea0153d4d29606b241c5143" +dependencies = [ + "rayon", +] + [[package]] name = "rayon-core" version = "1.13.0" diff --git a/Cargo.toml b/Cargo.toml index 9e866feb6..94ec4bd26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ pyo3-log = "^0.13" thiserror = "^2" rayon = "^1.10" +rayon-cancel = "^1.0" hashbrown = "^0.16" hex = "^0.4" diff --git a/src/accel/als/explicit.rs b/src/accel/als/explicit.rs index 222643953..434271fda 100644 --- a/src/accel/als/explicit.rs +++ b/src/accel/als/explicit.rs @@ -46,17 +46,18 @@ pub(super) fn train_explicit_matrix<'py>( other.nrows() ); - let frob: f32 = py.detach(|| { - this.outer_iter_mut() - .into_par_iter() - .enumerate() - .map(|(i, row)| { - let f = train_row_solve(&solver, &matrix, i, row, &other, reg); - progress.tick(); - f - }) - .sum() - }); + let frob: f32 = progress.process_iter( + py, + this.outer_iter_mut().into_par_iter().enumerate(), + |iter| { + Ok(iter + .map(|(i, row)| { + let f = train_row_solve(&solver, &matrix, i, row, &other, reg); + f + }) + .sum()) + }, + )?; progress.shutdown(py)?; Ok(frob.sqrt()) diff --git a/src/accel/progress.rs b/src/accel/progress.rs index 6842127ae..9e372ad3c 100644 --- a/src/accel/progress.rs +++ b/src/accel/progress.rs @@ -6,11 +6,14 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Condvar, Mutex, MutexGuard}; -use std::thread::{spawn, JoinHandle}; +use std::thread::{self, spawn, JoinHandle}; use std::time::Duration; +use log::*; use pyo3::exceptions::PyRuntimeError; use pyo3::{intern, prelude::*, types::PyDict}; +use rayon::iter::ParallelIterator; +use rayon_cancel::CancelAdapter; const UPDTATE_TIMEOUT: Duration = Duration::from_millis(200); @@ -30,7 +33,7 @@ struct ProgressData { /// /// This method applies internal throttling to reduce the number of calls /// to the Python progress bar. -pub(crate) struct ProgressHandle { +pub struct ProgressHandle { data: Option>, handle: Option>, } @@ -103,6 +106,45 @@ impl ProgressHandle { Ok(()) } } + + /// Process an iterator, with progress, thread-detach, and interrupt checks. + pub fn process_iter<'py, I, R, F>(&self, py: Python<'py>, iter: I, proc: F) -> PyResult + where + I: ParallelIterator + Send, + R: Send, + F: FnOnce(CancelAdapter) -> PyResult + Send, + { + let adapter = CancelAdapter::new(iter); + let counter = adapter.counter(); + let cancel = adapter.canceller(); + let caller = thread::current(); + + thread::scope(move |scope| { + let handle = scope.spawn(move || { + let result = proc(adapter); + caller.unpark(); + result + }); + + let mut count = counter.get(); + while !handle.is_finished() { + py.detach(|| thread::park_timeout(UPDTATE_TIMEOUT)); + if let Err(e) = py.check_signals() { + cancel.cancel(); + return Err(e); + } + let n = counter.get(); + debug!("counter: {} / {}", n, count); + self.advance(n - count); + count = n; + } + + match handle.join() { + Ok(r) => r, + Err(_) => Err(PyRuntimeError::new_err("worker thread panicked")), + } + }) + } } impl Clone for ProgressHandle { From 75a988eb924bf9d765feda8858970ee9bfe2b853 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Fri, 20 Mar 2026 18:09:59 -0400 Subject: [PATCH 2/3] accel: simplify PB and use iter processing --- src/accel/als/explicit.rs | 3 +- src/accel/als/implicit.rs | 26 +++--- src/accel/data/cooc.rs | 84 +++++++++--------- src/accel/knn/item_train.rs | 55 ++++++------ src/accel/progress.rs | 166 ++++------------------------------- src/accel/sparse/consumer.rs | 29 +++--- 6 files changed, 114 insertions(+), 249 deletions(-) diff --git a/src/accel/als/explicit.rs b/src/accel/als/explicit.rs index 434271fda..2e3bf1176 100644 --- a/src/accel/als/explicit.rs +++ b/src/accel/als/explicit.rs @@ -40,7 +40,7 @@ pub(super) fn train_explicit_matrix<'py>( let other_py = other.readonly(); let other = other_py.as_array(); - let mut progress = ProgressHandle::from_input(progress); + let progress = ProgressHandle::from_input(progress); debug!( "beginning explicit ALS training half with {} rows", other.nrows() @@ -58,7 +58,6 @@ pub(super) fn train_explicit_matrix<'py>( .sum()) }, )?; - progress.shutdown(py)?; Ok(frob.sqrt()) } diff --git a/src/accel/als/implicit.rs b/src/accel/als/implicit.rs index 9a41a6705..9175c00a4 100644 --- a/src/accel/als/implicit.rs +++ b/src/accel/als/implicit.rs @@ -43,23 +43,23 @@ pub(super) fn train_implicit_matrix<'py>( let otor_py = otor.readonly(); let otor = otor_py.as_array(); - let mut progress = ProgressHandle::from_input(progress); + let progress = ProgressHandle::from_input(progress); debug!( "beginning implicit ALS training half with {} rows", other.nrows() ); - let frob: f32 = py.detach(|| { - this.outer_iter_mut() - .into_par_iter() - .enumerate() - .map(|(i, row)| { - let f = train_row_solve(&solver, &matrix, i, row, &other, &otor); - progress.tick(); - f - }) - .sum() - }); - progress.shutdown(py)?; + let frob: f32 = progress.process_iter( + py, + this.outer_iter_mut().into_par_iter().enumerate(), + |iter| { + Ok(iter + .map(|(i, row)| { + let f = train_row_solve(&solver, &matrix, i, row, &other, &otor); + f + }) + .sum()) + }, + )?; Ok(frob.sqrt()) } diff --git a/src/accel/data/cooc.rs b/src/accel/data/cooc.rs index 5129c5fd9..275c6d56f 100644 --- a/src/accel/data/cooc.rs +++ b/src/accel/data/cooc.rs @@ -47,20 +47,18 @@ pub fn count_cooc<'py>( return Err(PyValueError::new_err("array length mismatch")); } - let mut pb = ProgressHandle::new(progress); + let pb = ProgressHandle::new(progress); - let out = py.detach(|| { - let groups = checked_array_ref::("groups", "Int32", &groups)?; - let items = checked_array_ref::("items", "Int32", &items)?; - - if ordered { + let groups = checked_array_ref::("groups", "Int32", &groups)?; + let items = checked_array_ref::("items", "Int32", &items)?; + let out = if ordered { + py.detach(|| { count_cooc_sequential::(&pb, groups, items, n_groups, n_items) - } else { - let ctr = SymmetricPairCounter::with_diagonal(n_items, diagonal); - count_cooc_parallel(ctr, groups, items, n_groups, &pb) - } - }); - pb.shutdown(py)?; + }) + } else { + let ctr = SymmetricPairCounter::with_diagonal(n_items, diagonal); + count_cooc_parallel(py, ctr, groups, items, n_groups, &pb) + }; let out = out?; debug!( "finished counting {} co-occurrances", @@ -99,17 +97,13 @@ pub fn dense_cooc<'py>( return Err(PyValueError::new_err("array length mismatch")); } - let mut pb = ProgressHandle::new(progress); + let pb = ProgressHandle::new(progress); - let out = py.detach(|| { - let groups = checked_array_ref::("groups", "Int32", &groups)?; - let items = checked_array_ref::("items", "Int32", &items)?; + let groups = checked_array_ref::("groups", "Int32", &groups)?; + let items = checked_array_ref::("items", "Int32", &items)?; - let ctr = DensePairCounter::with_diagonal(n_items, diagonal); - count_cooc_parallel(ctr, groups, items, n_groups, &pb) - }); - pb.shutdown(py)?; - let out = out?; + let ctr = DensePairCounter::with_diagonal(n_items, diagonal); + let out = count_cooc_parallel(py, ctr, groups, items, n_groups, &pb)?; debug!("finished counting co-occurrances"); Ok(out.to_pyarray(py)) @@ -135,15 +129,22 @@ fn count_cooc_sequential( let end = g_ptrs[i + 1]; let items = &ivals[start..end]; count_items(&mut counts, items); - pb.advance(items.len()); + if i % 100 == 0 { + Python::attach(|py| { + pb.advance(py, items.len()); + }) + } } - pb.flush(); + Python::attach(|py| { + pb.advance(py, items.len()); + }); // assemble the result Ok(counts.finish()) } -fn count_cooc_parallel( +fn count_cooc_parallel<'py, PC: ConcurrentPairCounter>( + py: Python<'py>, counts: PC, groups: &Int32Array, items: &Int32Array, @@ -153,26 +154,27 @@ fn count_cooc_parallel( let gvals = groups.values(); let ivals = items.values(); - let g_ptrs = compute_group_pointers(n_groups, gvals)?; + let g_ptrs = py.detach(|| compute_group_pointers(n_groups, gvals))?; debug!("pass 2: counting groups"); - (0..n_groups).into_par_iter().for_each(|i| { - let start = g_ptrs[i]; - let end = g_ptrs[i + 1]; - let items = &ivals[start..end]; - let n = items.len(); - - for i in 0..n { - let ri = items[i as usize]; - for j in i..n { - let ci = items[j as usize]; - counts.crecord(ri, ci); + // TODO: fix progress update + pb.process_iter(py, (0..n_groups).into_par_iter(), |iter| { + iter.for_each(|i| { + let start = g_ptrs[i]; + let end = g_ptrs[i + 1]; + let items = &ivals[start..end]; + let n = items.len(); + + for i in 0..n { + let ri = items[i as usize]; + for j in i..n { + let ci = items[j as usize]; + counts.crecord(ri, ci); + } } - } - - pb.advance(items.len()); - }); - pb.flush(); + }); + Ok(()) + })?; // assemble the result Ok(counts.finish()) diff --git a/src/accel/knn/item_train.rs b/src/accel/knn/item_train.rs index 0961c5179..0f86b1fcc 100644 --- a/src/accel/knn/item_train.rs +++ b/src/accel/knn/item_train.rs @@ -31,38 +31,33 @@ pub fn compute_similarities<'py>( progress: Bound<'py, PyAny>, ) -> PyResult>> { let (nu, ni) = shape; - let mut progress = ProgressHandle::from_input(progress); - - let res = py.detach(|| { - // extract the data - debug!("preparing {}x{} training", nu, ni); - debug!( - "resolving user-item matrix (type: {:#?})", - ui_ratings.0.data_type() - ); - let ui_mat = CSRMatrix::from_arrow(make_array(ui_ratings.0))?; - debug!("resolving item-user matrix"); - let iu_mat = CSRMatrix::from_arrow(make_array(iu_ratings.0))?; - assert_eq!(ui_mat.len(), nu); - assert_eq!(ui_mat.n_cols, ni); - assert_eq!(iu_mat.len(), ni); - assert_eq!(iu_mat.n_cols, nu); - - // let's compute! - let range = 0..ni; - debug!("computing similarity rows"); - let collector = ArrowCSRConsumer::with_progress(ni, &progress); - let chunks = range - .into_par_iter() + let progress = ProgressHandle::from_input(progress); + + // extract the data + debug!("preparing {}x{} training", nu, ni); + debug!( + "resolving user-item matrix (type: {:#?})", + ui_ratings.0.data_type() + ); + let ui_mat = CSRMatrix::from_arrow(make_array(ui_ratings.0))?; + debug!("resolving item-user matrix"); + let iu_mat = CSRMatrix::from_arrow(make_array(iu_ratings.0))?; + assert_eq!(ui_mat.len(), nu); + assert_eq!(ui_mat.n_cols, ni); + assert_eq!(iu_mat.len(), ni); + assert_eq!(iu_mat.n_cols, nu); + + // let's compute! + let range = 0..ni; + debug!("computing similarity rows"); + let collector = ArrowCSRConsumer::new(ni); + let chunks = progress.process_iter(py, range.into_par_iter(), |iter| { + Ok(iter .map(|row| sim_row(row, &ui_mat, &iu_mat, min_sim, save_nbrs)) - .drive(collector); - - Ok(chunks.iter().map(|a| a.into_data().into()).collect()) - }); - - progress.shutdown(py)?; + .drive_unindexed(collector)) + })?; - res + Ok(chunks.iter().map(|a| a.into_data().into()).collect()) } fn sim_row( diff --git a/src/accel/progress.rs b/src/accel/progress.rs index 9e372ad3c..571551ba4 100644 --- a/src/accel/progress.rs +++ b/src/accel/progress.rs @@ -4,38 +4,23 @@ // Licensed under the MIT license, see LICENSE.md for details. // SPDX-License-Identifier: MIT -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Condvar, Mutex, MutexGuard}; -use std::thread::{self, spawn, JoinHandle}; +use std::thread; use std::time::Duration; -use log::*; use pyo3::exceptions::PyRuntimeError; use pyo3::{intern, prelude::*, types::PyDict}; use rayon::iter::ParallelIterator; use rayon_cancel::CancelAdapter; -const UPDTATE_TIMEOUT: Duration = Duration::from_millis(200); - -struct ProgressThreadState { - running: bool, - last_count: usize, -} - -struct ProgressData { - pb: Py, - count: AtomicUsize, - state: Mutex, - condition: Condvar, -} +const UPDATE_TIMEOUT: Duration = Duration::from_millis(200); /// Thin Rust wrapper around a LensKit progress bar. /// /// This method applies internal throttling to reduce the number of calls /// to the Python progress bar. pub struct ProgressHandle { - data: Option>, - handle: Option>, + pb: Option>, + count: usize, } impl ProgressHandle { @@ -49,59 +34,24 @@ impl ProgressHandle { } pub fn new(pb: Option>) -> Self { - pb.map(|pb| { - let data = Arc::new(ProgressData { - pb, - count: AtomicUsize::new(0), - state: Mutex::new(ProgressThreadState { - running: true, - last_count: 0, - }), - condition: Condvar::new(), - }); - let d2 = data.clone(); - let handle = spawn(move || d2.background_update()); - ProgressHandle { - data: Some(data), - handle: Some(handle), - } - }) - .unwrap_or_else(Self::null) + ProgressHandle { pb, count: 0 } } - pub fn null() -> Self { - ProgressHandle { - data: None, - handle: None, - } + pub fn tick<'py>(&self, py: Python<'py>) { + self.advance(py, 1); } - pub fn tick(&self) { - self.advance(1); - } - - pub fn advance(&self, n: usize) { - if let Some(data) = &self.data { - data.count.fetch_add(n, Ordering::Relaxed); - } + pub fn advance<'py>(&self, py: Python<'py>, n: usize) { + self.update(py, self.count + n); } - /// Force an update of the progress bar. - pub fn flush(&self) { - if let Some(data) = &self.data { - data.ping(); - } - } - - pub fn shutdown<'py>(&mut self, py: Python<'py>) -> PyResult<()> { - if let Some(data) = self.data.take() { - data.shutdown(); - } - if let Some(h) = self.handle.take() { - py.detach(|| { - h.join() - .map_err(|_e| PyRuntimeError::new_err(format!("progress thread panicked"))) - }) + pub fn update<'py>(&self, py: Python<'py>, complete: usize) -> PyResult<()> { + if let Some(pb) = &self.pb { + let pb = pb.bind(py); + let kwargs = PyDict::new(py); + kwargs.set_item(intern!(py, "completed"), complete)?; + pb.call_method(intern!(py, "update"), (), Some(&kwargs))?; + Ok(()) } else { Ok(()) } @@ -126,17 +76,14 @@ impl ProgressHandle { result }); - let mut count = counter.get(); while !handle.is_finished() { - py.detach(|| thread::park_timeout(UPDTATE_TIMEOUT)); + py.detach(|| thread::park_timeout(UPDATE_TIMEOUT)); if let Err(e) = py.check_signals() { cancel.cancel(); return Err(e); } let n = counter.get(); - debug!("counter: {} / {}", n, count); - self.advance(n - count); - count = n; + self.update(py, n); } match handle.join() { @@ -146,80 +93,3 @@ impl ProgressHandle { }) } } - -impl Clone for ProgressHandle { - fn clone(&self) -> Self { - ProgressHandle { - data: self.data.clone(), - handle: None, - } - } -} - -impl Drop for ProgressHandle { - fn drop(&mut self) { - if let Some(data) = self.data.take() { - data.shutdown(); - } - } -} - -impl ProgressData { - fn background_update(&self) { - let mut state = self.acquire_state(); - while state.running { - state = self.wait(state); - - let count = self.count.load(Ordering::Relaxed); - - if count > state.last_count { - // drop lock so we don't deadlock with the Python GIL - drop(state); - - // send update to Python - Python::try_attach(|py| { - let kwargs = PyDict::new(py); - kwargs.set_item(intern!(py, "completed"), count)?; - self.pb - .call_method(py, intern!(py, "update"), (), Some(&kwargs))?; - Ok::<(), PyErr>(()) - }) - .unwrap_or(Ok(())) - .expect("progress update failed"); - - // re-acquire lock, update last count, and loop. - // updating the last count is safe because we are the only thread that does so. - state = self.acquire_state(); - state.last_count = count; - } - } - } - - /// Acquire the state lock. - fn acquire_state<'a>(&'a self) -> MutexGuard<'a, ProgressThreadState> { - self.state.lock().expect("poisoned lock") - } - - /// Wait for a wakeup - fn wait<'a>( - &'a self, - state: MutexGuard<'a, ProgressThreadState>, - ) -> MutexGuard<'a, ProgressThreadState> { - // wait to be notified, or for timeout - let (s2, _res) = self - .condition - .wait_timeout(state, UPDTATE_TIMEOUT) - .expect("poisoned lock"); - s2 - } - - fn shutdown(&self) { - let mut state = self.acquire_state(); - state.running = false; - self.ping(); - } - - fn ping(&self) { - self.condition.notify_all(); - } -} diff --git a/src/accel/sparse/consumer.rs b/src/accel/sparse/consumer.rs index d419effdf..0d53c201f 100644 --- a/src/accel/sparse/consumer.rs +++ b/src/accel/sparse/consumer.rs @@ -11,9 +11,7 @@ use arrow::{ buffer::OffsetBuffer, }; use arrow_schema::{DataType, Field, Fields}; -use rayon::iter::plumbing::{Consumer, Folder, Reducer}; - -use crate::progress::ProgressHandle; +use rayon::iter::plumbing::{Consumer, Folder, Reducer, UnindexedConsumer}; use super::SparseIndexType; @@ -32,15 +30,11 @@ pub struct ArrowCSRConsumer { struct CSRState { dimension: usize, - progress: ProgressHandle, } impl CSRState { - fn new(dim: usize, progress: ProgressHandle) -> Self { - CSRState { - dimension: dim, - progress: progress, - } + fn new(dim: usize) -> Self { + CSRState { dimension: dim } } } @@ -60,11 +54,7 @@ impl ArrowCSRConsumer { #[allow(dead_code)] pub(crate) fn new(dim: usize) -> Self { - Self::from_state(CSRState::new(dim, ProgressHandle::null())) - } - - pub(crate) fn with_progress(dim: usize, progress: &ProgressHandle) -> Self { - Self::from_state(CSRState::new(dim, progress.clone())) + Self::from_state(CSRState::new(dim)) } } @@ -90,6 +80,16 @@ impl Consumer for ArrowCSRConsumer { } } +impl UnindexedConsumer for ArrowCSRConsumer { + fn split_off_left(&self) -> Self { + ArrowCSRConsumer::from_state_ref(self.state.clone()) + } + + fn to_reducer(&self) -> Self::Reducer { + ArrowCSRConsumer::from_state_ref(self.state.clone()) + } +} + impl Folder for ArrowCSRConsumer { type Result = CSRResult; @@ -100,7 +100,6 @@ impl Folder for ArrowCSRConsumer { self.val_bld.append_value(s); } self.lengths.push(len); - self.state.progress.tick(); self } From 7c36a1d485cff6e7acb31df898574053a64ef46a Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Fri, 20 Mar 2026 18:14:16 -0400 Subject: [PATCH 3/3] accel: implement progress-wrapper with custom counts --- src/accel/data/cooc.rs | 9 ++++--- src/accel/progress.rs | 60 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/src/accel/data/cooc.rs b/src/accel/data/cooc.rs index 275c6d56f..56faaf4e2 100644 --- a/src/accel/data/cooc.rs +++ b/src/accel/data/cooc.rs @@ -6,6 +6,8 @@ //! Support for counting co-occurrences. +use std::sync::atomic::Ordering; + use arrow::{ array::{make_array, Array, ArrayData, Int32Array, RecordBatch}, pyarrow::PyArrowType, @@ -131,12 +133,12 @@ fn count_cooc_sequential( count_items(&mut counts, items); if i % 100 == 0 { Python::attach(|py| { - pb.advance(py, items.len()); + let _ = pb.advance(py, items.len()); }) } } Python::attach(|py| { - pb.advance(py, items.len()); + let _ = pb.advance(py, items.len()); }); // assemble the result @@ -158,7 +160,7 @@ fn count_cooc_parallel<'py, PC: ConcurrentPairCounter>( debug!("pass 2: counting groups"); // TODO: fix progress update - pb.process_iter(py, (0..n_groups).into_par_iter(), |iter| { + pb.process_iter_with_counter(py, (0..n_groups).into_par_iter(), |iter, counter| { iter.for_each(|i| { let start = g_ptrs[i]; let end = g_ptrs[i + 1]; @@ -172,6 +174,7 @@ fn count_cooc_parallel<'py, PC: ConcurrentPairCounter>( counts.crecord(ri, ci); } } + counter.fetch_add(items.len(), Ordering::Relaxed); }); Ok(()) })?; diff --git a/src/accel/progress.rs b/src/accel/progress.rs index 571551ba4..bb413936b 100644 --- a/src/accel/progress.rs +++ b/src/accel/progress.rs @@ -4,6 +4,7 @@ // Licensed under the MIT license, see LICENSE.md for details. // SPDX-License-Identifier: MIT +use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread; use std::time::Duration; @@ -37,14 +38,12 @@ impl ProgressHandle { ProgressHandle { pb, count: 0 } } - pub fn tick<'py>(&self, py: Python<'py>) { - self.advance(py, 1); - } - - pub fn advance<'py>(&self, py: Python<'py>, n: usize) { - self.update(py, self.count + n); + /// Advance the progress bar by the specified amount. + pub fn advance<'py>(&self, py: Python<'py>, n: usize) -> PyResult<()> { + self.update(py, self.count + n) } + /// Update the current completed total of the progress bar. pub fn update<'py>(&self, py: Python<'py>, complete: usize) -> PyResult<()> { if let Some(pb) = &self.pb { let pb = pb.bind(py); @@ -83,7 +82,54 @@ impl ProgressHandle { return Err(e); } let n = counter.get(); - self.update(py, n); + if let Err(e) = self.update(py, n) { + cancel.cancel(); + return Err(e); + } + } + + match handle.join() { + Ok(r) => r, + Err(_) => Err(PyRuntimeError::new_err("worker thread panicked")), + } + }) + } + + /// Process an iterator, with progress, thread-detach, and interrupt checks. + pub fn process_iter_with_counter<'py, I, R, F>( + &self, + py: Python<'py>, + iter: I, + proc: F, + ) -> PyResult + where + I: ParallelIterator + Send, + R: Send, + F: FnOnce(CancelAdapter, &AtomicUsize) -> PyResult + Send, + { + let adapter = CancelAdapter::new(iter); + let cancel = adapter.canceller(); + let caller = thread::current(); + let rc = AtomicUsize::new(0); + + thread::scope(|scope| { + let handle = scope.spawn(|| { + let result = proc(adapter, &rc); + caller.unpark(); + result + }); + + while !handle.is_finished() { + py.detach(|| thread::park_timeout(UPDATE_TIMEOUT)); + if let Err(e) = py.check_signals() { + cancel.cancel(); + return Err(e); + } + let n = rc.load(Ordering::Relaxed); + if let Err(e) = self.update(py, n) { + cancel.cancel(); + return Err(e); + } } match handle.join() {