diff --git a/cpp_test/TestUnionFind.cpp b/cpp_test/TestUnionFind.cpp index 26aad77..6aa9e4f 100644 --- a/cpp_test/TestUnionFind.cpp +++ b/cpp_test/TestUnionFind.cpp @@ -102,7 +102,19 @@ TEST(UfDecoder, HammingCode2){ ASSERT_EQ(decoding_syndrome,syndrome); } +} +TEST(UfDecoderParallel, parallel_peeling){ + auto pcm = ldpc::gf2codes::ring_code(30); + std::vector syndrome(pcm.m,0); + syndrome[0] = 1; + syndrome[1] = 1; + auto seq = UfDecoder(pcm); + auto par = UfDecoder(pcm); + par.set_omp_thread_count(4); + auto dec1 = seq.peel_decode(syndrome); + auto dec2 = par.peel_decode(syndrome); + ASSERT_EQ(dec1, dec2); } TEST(UfDecoder, ring_code3){ @@ -209,6 +221,18 @@ TEST(UfDecoder, peeling_with_boundaries_edge_case){ } +TEST(UfDecoder, parallel_peeling){ + auto pcm = ldpc::gf2codes::ring_code(10); + UfDecoder ufd(pcm,4); + std::vector syndrome(pcm.m,0); + syndrome[0]=1; + syndrome[1]=1; + auto decoding = ufd.peel_decode(syndrome); + std::vector expected(pcm.n,0); + expected[1]=1; + ASSERT_EQ(decoding,expected); +} + int main(int argc, char **argv) diff --git a/python_test/test_union_find_parallel.py b/python_test/test_union_find_parallel.py new file mode 100644 index 0000000..6f3462e --- /dev/null +++ b/python_test/test_union_find_parallel.py @@ -0,0 +1,48 @@ +import numpy as np +import scipy.sparse +import time +from ldpc.union_find_decoder import UnionFindDecoder + + +# Helper function to decode a batch of syndromes and return the +# average time per decode along with the decode result of the first +# syndrome (for correctness comparisons). +def _benchmark_decode(hx, syndromes, thread_count): + dec = UnionFindDecoder(hx, uf_method="") + dec.omp_thread_count = thread_count + t0 = time.perf_counter() + for syn in syndromes: + dec.decode(np.ascontiguousarray(syn)) + avg_time = (time.perf_counter() - t0) / len(syndromes) + first_out = dec.decode(np.ascontiguousarray(syndromes[0])) + return avg_time, first_out + + +def test_union_find_parallel_benchmark(): + hx = scipy.sparse.load_npz("python_test/pcms/hx_surface_20.npz").tocsr() + rng = np.random.default_rng(0) + # Using 128 samples keeps the runtime manageable while still providing + # a reasonable estimate of performance. Larger sample sizes caused the + # decoder to hang in this environment. + num_samples = 128 + thread_counts = [1, 2, 4, 8] + # Higher error rates occasionally caused the decoder to hang during + # testing, so we restrict the range here. + ps = np.linspace(0.01, 0.05, 3) + + results = {} + for p in ps: + errors = (rng.random((num_samples, hx.shape[1])) < p).astype(np.uint8) + syndromes = (hx.dot(errors.T) % 2).astype(np.uint8).T + + avg_1, ref = _benchmark_decode(hx, syndromes, thread_counts[0]) + results[(p, thread_counts[0])] = avg_1 + print(f"p={p:.2f} threads={thread_counts[0]} avg_time={avg_1:.6f}s") + + for t in thread_counts[1:]: + avg_t, out = _benchmark_decode(hx, syndromes, t) + results[(p, t)] = avg_t + print(f"p={p:.2f} threads={t} avg_time={avg_t:.6f}s") + assert np.array_equal(out, ref) + + assert results[(p, thread_counts[-1])] <= results[(p, thread_counts[0])] * 2 diff --git a/src_cpp/union_find.hpp b/src_cpp/union_find.hpp index d316c5b..df509b0 100644 --- a/src_cpp/union_find.hpp +++ b/src_cpp/union_find.hpp @@ -14,6 +14,10 @@ #include #include #include +#include +#ifdef _OPENMP +#include +#endif #include "gf2sparse_linalg.hpp" #include "bp.hpp" @@ -22,6 +26,7 @@ namespace ldpc::uf { const std::vector EMPTY_DOUBLE_VECTOR = {}; tsl::robin_set EMPTY_INT_ROBIN_SET = {}; + inline std::mutex uf_global_mutex; std::vector sort_indices(std::vector &B) { std::vector indices(B.size()); @@ -107,6 +112,7 @@ namespace ldpc::uf { } int add_bit_node_to_cluster(int bit_index) { + std::lock_guard lk(uf_global_mutex); auto bit_membership = this->global_bit_membership[bit_index]; if (bit_membership == this) return 0; //if the bit is already in the cluster terminate. else if (bit_membership == NULL) { @@ -141,6 +147,7 @@ namespace ldpc::uf { } void merge_with_cluster(Cluster *cl2) { + std::lock_guard lk(uf_global_mutex); for (auto bit_index: cl2->bit_nodes) { this->bit_nodes.insert(bit_index); this->global_bit_membership[bit_index] = this; @@ -407,12 +414,14 @@ namespace ldpc::uf { int check_count; tsl::robin_set planar_code_boundary_bits; bool pcm_max_bit_degree_2; + int omp_thread_count; - UfDecoder(ldpc::bp::BpSparse &parity_check_matrix) : pcm(parity_check_matrix) { + UfDecoder(ldpc::bp::BpSparse &parity_check_matrix, int omp_threads = 1) : pcm(parity_check_matrix) { this->bit_count = pcm.n; this->check_count = pcm.m; this->decoding.resize(this->bit_count); this->weighted = false; + this->omp_thread_count = 1; this->pcm_max_bit_degree_2 = true; for (auto i = 0; i < this->pcm.n; i++) { @@ -427,6 +436,10 @@ namespace ldpc::uf { } } + void set_omp_thread_count(int count) { + this->omp_thread_count = count; + } + std::vector & peel_decode(const std::vector &syndrome, const std::vector &bit_weights = EMPTY_DOUBLE_VECTOR, int bits_per_step = 1) { @@ -452,27 +465,50 @@ namespace ldpc::uf { } while (!invalid_clusters.empty()) { - for (auto cl: invalid_clusters) { + #pragma omp parallel for num_threads(this->omp_thread_count) schedule(static) + for (size_t i = 0; i < invalid_clusters.size(); i++) { + auto cl = invalid_clusters[i]; if (cl->active) { cl->grow_cluster(bit_weights, bits_per_step); } } invalid_clusters.clear(); - for (auto cl: clusters) { - if (cl->active && cl->parity() == 1 && !cl->contains_boundary_bits) { - invalid_clusters.push_back(cl); + std::vector> local_invalid_vec(this->omp_thread_count); + #pragma omp parallel num_threads(this->omp_thread_count) + { + int tid = 0; + #ifdef _OPENMP + tid = omp_get_thread_num(); + #endif + auto &local_invalid = local_invalid_vec[tid]; + #pragma omp for schedule(static) + for (size_t i = 0; i < clusters.size(); i++) { + auto cl = clusters[i]; + if (cl->active && cl->parity() == 1 && !cl->contains_boundary_bits) { + local_invalid.push_back(cl); + } } } + for (auto &vec : local_invalid_vec) { + invalid_clusters.insert(invalid_clusters.end(), vec.begin(), vec.end()); + } std::sort(invalid_clusters.begin(), invalid_clusters.end(), [](const Cluster *lhs, const Cluster *rhs) { return lhs->bit_nodes.size() < rhs->bit_nodes.size(); }); } - for (auto cl: clusters) { + #pragma omp parallel for num_threads(this->omp_thread_count) schedule(static) + for (size_t i = 0; i < clusters.size(); i++) { + auto cl = clusters[i]; if (cl->active) { auto erasure = cl->peel_decode(syndrome); - for (int bit: erasure) this->decoding[bit] = 1; + for (int bit: erasure) { + #pragma omp atomic write + this->decoding[bit] = 1; + } } + } + for (auto cl: clusters) { delete cl; } delete[] global_bit_membership; diff --git a/src_python/ldpc/union_find_decoder/__init__.pyi b/src_python/ldpc/union_find_decoder/__init__.pyi index c341787..d87fa0b 100644 --- a/src_python/ldpc/union_find_decoder/__init__.pyi +++ b/src_python/ldpc/union_find_decoder/__init__.pyi @@ -18,7 +18,7 @@ class UnionFindDecoder: Default is False. """ - def __cinit__(self, pcm: Union[np.ndarray, spmatrix], uf_method: str = False): ... + def __cinit__(self, pcm: Union[np.ndarray, spmatrix], uf_method: str = False, omp_thread_count: int = 1): ... def __del__(self): ... @@ -49,5 +49,11 @@ class UnionFindDecoder: of the parity-check matrix. """ + @property + def omp_thread_count(self) -> int: ... + + @omp_thread_count.setter + def omp_thread_count(self, value: int) -> None: ... + @property def decoding(self): ... diff --git a/src_python/ldpc/union_find_decoder/_union_find_decoder.pxd b/src_python/ldpc/union_find_decoder/_union_find_decoder.pxd index 98bad54..3eeb771 100644 --- a/src_python/ldpc/union_find_decoder/_union_find_decoder.pxd +++ b/src_python/ldpc/union_find_decoder/_union_find_decoder.pxd @@ -10,9 +10,11 @@ ctypedef np.uint8_t uint8_t cdef extern from "union_find.hpp" namespace "ldpc::uf": cdef cppclass uf_decoder_cpp "ldpc::uf::UfDecoder": - uf_decoder_cpp(BpSparse& pcm) except + + uf_decoder_cpp(BpSparse& pcm, int omp_thread_count) except + vector[uint8_t]& peel_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights, int bits_per_step) vector[uint8_t]& matrix_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights, int bits_per_step) + void set_omp_thread_count(int count) + int omp_thread_count vector[uint8_t] decoding cdef const vector[double] EMPTY_DOUBLE_VECTOR "ldpc::uf::EMPTY_DOUBLE_VECTOR" @@ -26,4 +28,5 @@ cdef class UnionFindDecoder(): cdef vector[uint8_t] _syndrome cdef vector[double] uf_llrs cdef bool uf_method - cdef int bits_per_step \ No newline at end of file + cdef int bits_per_step + cdef int omp_thread_count \ No newline at end of file diff --git a/src_python/ldpc/union_find_decoder/_union_find_decoder.pyx b/src_python/ldpc/union_find_decoder/_union_find_decoder.pyx index 8bb97b8..0fe4724 100644 --- a/src_python/ldpc/union_find_decoder/_union_find_decoder.pyx +++ b/src_python/ldpc/union_find_decoder/_union_find_decoder.pyx @@ -60,7 +60,7 @@ cdef class UnionFindDecoder: Default is False. """ - def __cinit__(self, pcm: Union[np.ndarray, spmatrix], uf_method: str = False): + def __cinit__(self, pcm: Union[np.ndarray, spmatrix], uf_method: str = False, omp_thread_count: int = 1): self.MEMORY_ALLOCATED=False @@ -75,7 +75,8 @@ cdef class UnionFindDecoder: # get the parity check dimensions self.m, self.n = pcm.shape[0], pcm.shape[1] - self.ufd = new uf_decoder_cpp(self.pcm[0]) + self.ufd = new uf_decoder_cpp(self.pcm[0], omp_thread_count) + self.omp_thread_count = omp_thread_count self._syndrome.resize(self.m) #C vector for the syndrome self.uf_llrs.resize(self.n) #C vector for the log-likehood ratios self.uf_method = uf_method @@ -159,9 +160,19 @@ cdef class UnionFindDecoder: out = np.zeros(self.n,dtype=DTYPE) for i in range(self.n): out[i] = self.ufd.decoding[i] - + return out + @property + def omp_thread_count(self) -> int: + return self.ufd.omp_thread_count + + @omp_thread_count.setter + def omp_thread_count(self, value: int) -> None: + if not isinstance(value, int) or value < 1: + raise TypeError("The omp_thread_count must be specified as a positive integer.") + self.ufd.set_omp_thread_count(value) + @property def decoding(self):