Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pyo3-log = "^0.13"
thiserror = "^2"

rayon = "^1.10"
rayon-cancel = "^1.0"

hashbrown = "^0.16"
hex = "^0.4"
Expand Down
26 changes: 13 additions & 13 deletions src/accel/als/explicit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,24 @@ pub(super) fn train_explicit_matrix<'py>(
let other_py = other.readonly();
let other = other_py.as_array();

let mut progress = ProgressHandle::from_input(progress);
let progress = ProgressHandle::from_input(progress);
debug!(
"beginning explicit ALS training half with {} rows",
other.nrows()
);

let frob: f32 = py.detach(|| {
this.outer_iter_mut()
.into_par_iter()
.enumerate()
.map(|(i, row)| {
let f = train_row_solve(&solver, &matrix, i, row, &other, reg);
progress.tick();
f
})
.sum()
});
progress.shutdown(py)?;
let frob: f32 = progress.process_iter(
py,
this.outer_iter_mut().into_par_iter().enumerate(),
|iter| {
Ok(iter
.map(|(i, row)| {
let f = train_row_solve(&solver, &matrix, i, row, &other, reg);
f
})
.sum())
},
)?;

Ok(frob.sqrt())
}
Expand Down
26 changes: 13 additions & 13 deletions src/accel/als/implicit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,23 @@ pub(super) fn train_implicit_matrix<'py>(
let otor_py = otor.readonly();
let otor = otor_py.as_array();

let mut progress = ProgressHandle::from_input(progress);
let progress = ProgressHandle::from_input(progress);
debug!(
"beginning implicit ALS training half with {} rows",
other.nrows()
);
let frob: f32 = py.detach(|| {
this.outer_iter_mut()
.into_par_iter()
.enumerate()
.map(|(i, row)| {
let f = train_row_solve(&solver, &matrix, i, row, &other, &otor);
progress.tick();
f
})
.sum()
});
progress.shutdown(py)?;
let frob: f32 = progress.process_iter(
py,
this.outer_iter_mut().into_par_iter().enumerate(),
|iter| {
Ok(iter
.map(|(i, row)| {
let f = train_row_solve(&solver, &matrix, i, row, &other, &otor);
f
})
.sum())
},
)?;

Ok(frob.sqrt())
}
Expand Down
87 changes: 46 additions & 41 deletions src/accel/data/cooc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

//! Support for counting co-occurrences.

use std::sync::atomic::Ordering;

use arrow::{
array::{make_array, Array, ArrayData, Int32Array, RecordBatch},
pyarrow::PyArrowType,
Expand Down Expand Up @@ -47,20 +49,18 @@ pub fn count_cooc<'py>(
return Err(PyValueError::new_err("array length mismatch"));
}

let mut pb = ProgressHandle::new(progress);

let out = py.detach(|| {
let groups = checked_array_ref::<Int32Array>("groups", "Int32", &groups)?;
let items = checked_array_ref::<Int32Array>("items", "Int32", &items)?;
let pb = ProgressHandle::new(progress);

if ordered {
let groups = checked_array_ref::<Int32Array>("groups", "Int32", &groups)?;
let items = checked_array_ref::<Int32Array>("items", "Int32", &items)?;
let out = if ordered {
py.detach(|| {
count_cooc_sequential::<AsymmetricPairCounter>(&pb, groups, items, n_groups, n_items)
} else {
let ctr = SymmetricPairCounter::with_diagonal(n_items, diagonal);
count_cooc_parallel(ctr, groups, items, n_groups, &pb)
}
});
pb.shutdown(py)?;
})
} else {
let ctr = SymmetricPairCounter::with_diagonal(n_items, diagonal);
count_cooc_parallel(py, ctr, groups, items, n_groups, &pb)
};
let out = out?;
debug!(
"finished counting {} co-occurrances",
Expand Down Expand Up @@ -99,17 +99,13 @@ pub fn dense_cooc<'py>(
return Err(PyValueError::new_err("array length mismatch"));
}

let mut pb = ProgressHandle::new(progress);
let pb = ProgressHandle::new(progress);

let out = py.detach(|| {
let groups = checked_array_ref::<Int32Array>("groups", "Int32", &groups)?;
let items = checked_array_ref::<Int32Array>("items", "Int32", &items)?;
let groups = checked_array_ref::<Int32Array>("groups", "Int32", &groups)?;
let items = checked_array_ref::<Int32Array>("items", "Int32", &items)?;

let ctr = DensePairCounter::with_diagonal(n_items, diagonal);
count_cooc_parallel(ctr, groups, items, n_groups, &pb)
});
pb.shutdown(py)?;
let out = out?;
let ctr = DensePairCounter::with_diagonal(n_items, diagonal);
let out = count_cooc_parallel(py, ctr, groups, items, n_groups, &pb)?;
debug!("finished counting co-occurrances");

Ok(out.to_pyarray(py))
Expand All @@ -135,15 +131,22 @@ fn count_cooc_sequential<PC: PairCounter>(
let end = g_ptrs[i + 1];
let items = &ivals[start..end];
count_items(&mut counts, items);
pb.advance(items.len());
if i % 100 == 0 {
Python::attach(|py| {
let _ = pb.advance(py, items.len());
})
}
}
pb.flush();
Python::attach(|py| {
let _ = pb.advance(py, items.len());
});

// assemble the result
Ok(counts.finish())
}

fn count_cooc_parallel<PC: ConcurrentPairCounter>(
fn count_cooc_parallel<'py, PC: ConcurrentPairCounter>(
py: Python<'py>,
counts: PC,
groups: &Int32Array,
items: &Int32Array,
Expand All @@ -153,26 +156,28 @@ fn count_cooc_parallel<PC: ConcurrentPairCounter>(
let gvals = groups.values();
let ivals = items.values();

let g_ptrs = compute_group_pointers(n_groups, gvals)?;
let g_ptrs = py.detach(|| compute_group_pointers(n_groups, gvals))?;

debug!("pass 2: counting groups");
(0..n_groups).into_par_iter().for_each(|i| {
let start = g_ptrs[i];
let end = g_ptrs[i + 1];
let items = &ivals[start..end];
let n = items.len();

for i in 0..n {
let ri = items[i as usize];
for j in i..n {
let ci = items[j as usize];
counts.crecord(ri, ci);
// TODO: fix progress update
pb.process_iter_with_counter(py, (0..n_groups).into_par_iter(), |iter, counter| {
iter.for_each(|i| {
let start = g_ptrs[i];
let end = g_ptrs[i + 1];
let items = &ivals[start..end];
let n = items.len();

for i in 0..n {
let ri = items[i as usize];
for j in i..n {
let ci = items[j as usize];
counts.crecord(ri, ci);
}
}
}

pb.advance(items.len());
});
pb.flush();
counter.fetch_add(items.len(), Ordering::Relaxed);
});
Ok(())
})?;

// assemble the result
Ok(counts.finish())
Expand Down
55 changes: 25 additions & 30 deletions src/accel/knn/item_train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,38 +31,33 @@ pub fn compute_similarities<'py>(
progress: Bound<'py, PyAny>,
) -> PyResult<Vec<PyArrowType<ArrayData>>> {
let (nu, ni) = shape;
let mut progress = ProgressHandle::from_input(progress);

let res = py.detach(|| {
// extract the data
debug!("preparing {}x{} training", nu, ni);
debug!(
"resolving user-item matrix (type: {:#?})",
ui_ratings.0.data_type()
);
let ui_mat = CSRMatrix::from_arrow(make_array(ui_ratings.0))?;
debug!("resolving item-user matrix");
let iu_mat = CSRMatrix::from_arrow(make_array(iu_ratings.0))?;
assert_eq!(ui_mat.len(), nu);
assert_eq!(ui_mat.n_cols, ni);
assert_eq!(iu_mat.len(), ni);
assert_eq!(iu_mat.n_cols, nu);

// let's compute!
let range = 0..ni;
debug!("computing similarity rows");
let collector = ArrowCSRConsumer::with_progress(ni, &progress);
let chunks = range
.into_par_iter()
let progress = ProgressHandle::from_input(progress);

// extract the data
debug!("preparing {}x{} training", nu, ni);
debug!(
"resolving user-item matrix (type: {:#?})",
ui_ratings.0.data_type()
);
let ui_mat = CSRMatrix::from_arrow(make_array(ui_ratings.0))?;
debug!("resolving item-user matrix");
let iu_mat = CSRMatrix::from_arrow(make_array(iu_ratings.0))?;
assert_eq!(ui_mat.len(), nu);
assert_eq!(ui_mat.n_cols, ni);
assert_eq!(iu_mat.len(), ni);
assert_eq!(iu_mat.n_cols, nu);

// let's compute!
let range = 0..ni;
debug!("computing similarity rows");
let collector = ArrowCSRConsumer::new(ni);
let chunks = progress.process_iter(py, range.into_par_iter(), |iter| {
Ok(iter
.map(|row| sim_row(row, &ui_mat, &iu_mat, min_sim, save_nbrs))
.drive(collector);

Ok(chunks.iter().map(|a| a.into_data().into()).collect())
});

progress.shutdown(py)?;
.drive_unindexed(collector))
})?;

res
Ok(chunks.iter().map(|a| a.into_data().into()).collect())
}

fn sim_row(
Expand Down
Loading
Loading