Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions cpp_test/TestUnionFind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,19 @@ TEST(UfDecoder, HammingCode2){
ASSERT_EQ(decoding_syndrome,syndrome);

}
}

TEST(UfDecoderParallel, parallel_peeling){
auto pcm = ldpc::gf2codes::ring_code<ldpc::bp::BpEntry>(30);
std::vector<uint8_t> syndrome(pcm.m,0);
syndrome[0] = 1;
syndrome[1] = 1;
auto seq = UfDecoder(pcm);
auto par = UfDecoder(pcm);
par.set_omp_thread_count(4);
auto dec1 = seq.peel_decode(syndrome);
auto dec2 = par.peel_decode(syndrome);
ASSERT_EQ(dec1, dec2);
}

TEST(UfDecoder, ring_code3){
Expand Down Expand Up @@ -209,6 +221,18 @@ TEST(UfDecoder, peeling_with_boundaries_edge_case){

}

TEST(UfDecoder, parallel_peeling){
auto pcm = ldpc::gf2codes::ring_code<ldpc::bp::BpEntry>(10);
UfDecoder ufd(pcm,4);
std::vector<uint8_t> syndrome(pcm.m,0);
syndrome[0]=1;
syndrome[1]=1;
auto decoding = ufd.peel_decode(syndrome);
std::vector<uint8_t> expected(pcm.n,0);
expected[1]=1;
ASSERT_EQ(decoding,expected);
}



int main(int argc, char **argv)
Expand Down
48 changes: 48 additions & 0 deletions python_test/test_union_find_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
import scipy.sparse
import time
from ldpc.union_find_decoder import UnionFindDecoder


# Helper function to decode a batch of syndromes and return the
# average time per decode along with the decode result of the first
# syndrome (for correctness comparisons).
def _benchmark_decode(hx, syndromes, thread_count):
dec = UnionFindDecoder(hx, uf_method="")
dec.omp_thread_count = thread_count
t0 = time.perf_counter()
for syn in syndromes:
dec.decode(np.ascontiguousarray(syn))
avg_time = (time.perf_counter() - t0) / len(syndromes)
first_out = dec.decode(np.ascontiguousarray(syndromes[0]))
return avg_time, first_out


def test_union_find_parallel_benchmark():
hx = scipy.sparse.load_npz("python_test/pcms/hx_surface_20.npz").tocsr()
rng = np.random.default_rng(0)
# Using 128 samples keeps the runtime manageable while still providing
# a reasonable estimate of performance. Larger sample sizes caused the
# decoder to hang in this environment.
num_samples = 128
thread_counts = [1, 2, 4, 8]
# Higher error rates occasionally caused the decoder to hang during
# testing, so we restrict the range here.
ps = np.linspace(0.01, 0.05, 3)

results = {}
for p in ps:
errors = (rng.random((num_samples, hx.shape[1])) < p).astype(np.uint8)
syndromes = (hx.dot(errors.T) % 2).astype(np.uint8).T

avg_1, ref = _benchmark_decode(hx, syndromes, thread_counts[0])
results[(p, thread_counts[0])] = avg_1
print(f"p={p:.2f} threads={thread_counts[0]} avg_time={avg_1:.6f}s")

for t in thread_counts[1:]:
avg_t, out = _benchmark_decode(hx, syndromes, t)
results[(p, t)] = avg_t
print(f"p={p:.2f} threads={t} avg_time={avg_t:.6f}s")
assert np.array_equal(out, ref)

assert results[(p, thread_counts[-1])] <= results[(p, thread_counts[0])] * 2
50 changes: 43 additions & 7 deletions src_cpp/union_find.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
#include <robin_map.h>
#include <robin_set.h>
#include <numeric>
#include <mutex>
#ifdef _OPENMP
#include <omp.h>
#endif

#include "gf2sparse_linalg.hpp"
#include "bp.hpp"
Expand All @@ -22,6 +26,7 @@ namespace ldpc::uf {

const std::vector<double> EMPTY_DOUBLE_VECTOR = {};
tsl::robin_set<int> EMPTY_INT_ROBIN_SET = {};
inline std::mutex uf_global_mutex;

std::vector<int> sort_indices(std::vector<double> &B) {
std::vector<int> indices(B.size());
Expand Down Expand Up @@ -107,6 +112,7 @@ namespace ldpc::uf {
}

int add_bit_node_to_cluster(int bit_index) {
std::lock_guard<std::mutex> lk(uf_global_mutex);
auto bit_membership = this->global_bit_membership[bit_index];
if (bit_membership == this) return 0; //if the bit is already in the cluster terminate.
else if (bit_membership == NULL) {
Expand Down Expand Up @@ -141,6 +147,7 @@ namespace ldpc::uf {
}

void merge_with_cluster(Cluster *cl2) {
std::lock_guard<std::mutex> lk(uf_global_mutex);
for (auto bit_index: cl2->bit_nodes) {
this->bit_nodes.insert(bit_index);
this->global_bit_membership[bit_index] = this;
Expand Down Expand Up @@ -407,12 +414,14 @@ namespace ldpc::uf {
int check_count;
tsl::robin_set<int> planar_code_boundary_bits;
bool pcm_max_bit_degree_2;
int omp_thread_count;

UfDecoder(ldpc::bp::BpSparse &parity_check_matrix) : pcm(parity_check_matrix) {
UfDecoder(ldpc::bp::BpSparse &parity_check_matrix, int omp_threads = 1) : pcm(parity_check_matrix) {
this->bit_count = pcm.n;
this->check_count = pcm.m;
this->decoding.resize(this->bit_count);
this->weighted = false;
this->omp_thread_count = 1;

this->pcm_max_bit_degree_2 = true;
for (auto i = 0; i < this->pcm.n; i++) {
Expand All @@ -427,6 +436,10 @@ namespace ldpc::uf {
}
}

void set_omp_thread_count(int count) {
this->omp_thread_count = count;
}

std::vector<uint8_t> &
peel_decode(const std::vector<uint8_t> &syndrome, const std::vector<double> &bit_weights = EMPTY_DOUBLE_VECTOR,
int bits_per_step = 1) {
Expand All @@ -452,27 +465,50 @@ namespace ldpc::uf {
}

while (!invalid_clusters.empty()) {
for (auto cl: invalid_clusters) {
#pragma omp parallel for num_threads(this->omp_thread_count) schedule(static)
for (size_t i = 0; i < invalid_clusters.size(); i++) {
auto cl = invalid_clusters[i];
if (cl->active) {
cl->grow_cluster(bit_weights, bits_per_step);
}
}

invalid_clusters.clear();
for (auto cl: clusters) {
if (cl->active && cl->parity() == 1 && !cl->contains_boundary_bits) {
invalid_clusters.push_back(cl);
std::vector<std::vector<Cluster *>> local_invalid_vec(this->omp_thread_count);
#pragma omp parallel num_threads(this->omp_thread_count)
{
int tid = 0;
#ifdef _OPENMP
tid = omp_get_thread_num();
#endif
auto &local_invalid = local_invalid_vec[tid];
#pragma omp for schedule(static)
for (size_t i = 0; i < clusters.size(); i++) {
auto cl = clusters[i];
if (cl->active && cl->parity() == 1 && !cl->contains_boundary_bits) {
local_invalid.push_back(cl);
}
}
}
for (auto &vec : local_invalid_vec) {
invalid_clusters.insert(invalid_clusters.end(), vec.begin(), vec.end());
}
std::sort(invalid_clusters.begin(), invalid_clusters.end(), [](const Cluster *lhs, const Cluster *rhs) {
return lhs->bit_nodes.size() < rhs->bit_nodes.size();
});
}
for (auto cl: clusters) {
#pragma omp parallel for num_threads(this->omp_thread_count) schedule(static)
for (size_t i = 0; i < clusters.size(); i++) {
auto cl = clusters[i];
if (cl->active) {
auto erasure = cl->peel_decode(syndrome);
for (int bit: erasure) this->decoding[bit] = 1;
for (int bit: erasure) {
#pragma omp atomic write
this->decoding[bit] = 1;
}
}
}
for (auto cl: clusters) {
delete cl;
}
delete[] global_bit_membership;
Expand Down
8 changes: 7 additions & 1 deletion src_python/ldpc/union_find_decoder/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class UnionFindDecoder:
Default is False.
"""

def __cinit__(self, pcm: Union[np.ndarray, spmatrix], uf_method: str = False): ...
def __cinit__(self, pcm: Union[np.ndarray, spmatrix], uf_method: str = False, omp_thread_count: int = 1): ...

def __del__(self): ...

Expand Down Expand Up @@ -49,5 +49,11 @@ class UnionFindDecoder:
of the parity-check matrix.
"""

@property
def omp_thread_count(self) -> int: ...

@omp_thread_count.setter
def omp_thread_count(self, value: int) -> None: ...

@property
def decoding(self): ...
7 changes: 5 additions & 2 deletions src_python/ldpc/union_find_decoder/_union_find_decoder.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ ctypedef np.uint8_t uint8_t

cdef extern from "union_find.hpp" namespace "ldpc::uf":
cdef cppclass uf_decoder_cpp "ldpc::uf::UfDecoder":
uf_decoder_cpp(BpSparse& pcm) except +
uf_decoder_cpp(BpSparse& pcm, int omp_thread_count) except +
vector[uint8_t]& peel_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights, int bits_per_step)
vector[uint8_t]& matrix_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights, int bits_per_step)
void set_omp_thread_count(int count)
int omp_thread_count
vector[uint8_t] decoding

cdef const vector[double] EMPTY_DOUBLE_VECTOR "ldpc::uf::EMPTY_DOUBLE_VECTOR"
Expand All @@ -26,4 +28,5 @@ cdef class UnionFindDecoder():
cdef vector[uint8_t] _syndrome
cdef vector[double] uf_llrs
cdef bool uf_method
cdef int bits_per_step
cdef int bits_per_step
cdef int omp_thread_count
17 changes: 14 additions & 3 deletions src_python/ldpc/union_find_decoder/_union_find_decoder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ cdef class UnionFindDecoder:
Default is False.
"""

def __cinit__(self, pcm: Union[np.ndarray, spmatrix], uf_method: str = False):
def __cinit__(self, pcm: Union[np.ndarray, spmatrix], uf_method: str = False, omp_thread_count: int = 1):

self.MEMORY_ALLOCATED=False

Expand All @@ -75,7 +75,8 @@ cdef class UnionFindDecoder:
# get the parity check dimensions
self.m, self.n = pcm.shape[0], pcm.shape[1]

self.ufd = new uf_decoder_cpp(self.pcm[0])
self.ufd = new uf_decoder_cpp(self.pcm[0], omp_thread_count)
self.omp_thread_count = omp_thread_count
self._syndrome.resize(self.m) #C vector for the syndrome
self.uf_llrs.resize(self.n) #C vector for the log-likehood ratios
self.uf_method = uf_method
Expand Down Expand Up @@ -159,9 +160,19 @@ cdef class UnionFindDecoder:
out = np.zeros(self.n,dtype=DTYPE)
for i in range(self.n):
out[i] = self.ufd.decoding[i]

return out

@property
def omp_thread_count(self) -> int:
return self.ufd.omp_thread_count

@omp_thread_count.setter
def omp_thread_count(self, value: int) -> None:
if not isinstance(value, int) or value < 1:
raise TypeError("The omp_thread_count must be specified as a positive integer.")
self.ufd.set_omp_thread_count(value)

@property
def decoding(self):

Expand Down