Skip to content

Add BP on top of MoMatching #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
MarcSerraPeralta opened this issue Mar 31, 2025 · 1 comment
Open

Add BP on top of MoMatching #9

MarcSerraPeralta opened this issue Mar 31, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

@MarcSerraPeralta
Copy link
Owner

Old implementation:

import numpy as np
import stim
from pymatching import Matching
from ldpc.bp_decoder import BpDecoder

from .greedy_algorithm import greedy_algorithm
from .util import (
    comb_probs_to_w,
    dem_to_hplc,
    dem_only_errors,
    valid_decomposition,
    get_edges_dict,
    standardize_edge,
)


class BeliefSoMatching:
    """
    Decodes the observables (from a logical measurement) in a logical Clifford circuit
    run on a surface code. It runs belief-propagation on the full hypergraph
    to update the probabilities of the subgraphs.

    If BP converges, its correction is used directly. Otherwise, the BP
    posterior error probabilities are used to re-weight one matching subgraph
    per logical observable, which is decoded with ``pymatching.Matching``.
    """

    # Index (in place of an index into ``dem_errors``) given to the placeholder
    # boundary errors added for inactive detectors. Assumed to be larger than
    # any real error index in the DEM.
    _PLACEHOLDER_ERROR_IND = 999_999_999

    def __init__(
        self,
        dem: stim.DetectorErrorModel,
        circuit: stim.Circuit,
        logicals: list[list[str]],
        stab_coords: dict[str, tuple[float | int, float | int, float | int]],
        detector_frame: str,
        ignore_decomposition_failures: bool = False,
        max_iter: int = 20,
        bp_method: str = "product_sum",
        **kargs_bp,
    ):
        """
        Initializes ``BeliefSoMatching``.

        Parameters
        ----------
        dem
            Detector error model.
        circuit
            Logical circuit with only MZ, RZ, MX, RX, S, H, X, Z, Y, I, CNOT gates.
            Circuit must start with all qubits being reset and end with all qubits
            being measured. TICKs represent QEC cycles.
            Conditional gates based on outcomes are not allowed.
            Qubits can only perform a single operation in between QEC cycles.
            The next operation of a measurement must be a reset.
            It can be a ``stim.Circuit`` or a ``np.ndarray``
            (see ``somatching.greedy_algorithm.get_ops``).
        logicals
            Definition of the logicals as done in the circuit.
            E.g. if one has defined L0 = Z0*Z1, then the ``logicals``
            should be ``[["Z0", "Z1"]]``. They must be ordered following
            the logical observable indices in the circuit.
        stab_coords
            Dictionary with keys corresponding to Z0, X0, Z1, X1... (?-stabs for
            each logical) in the detector error model and the values being the
            coordinates of all stabilizers associated with that logical qubit.
            The observable IDs must also match with the qubit indices from ``circuit``.
        detector_frame
            Frame used when defining the detectors. Must be either ``"pre-gate"``
            or ``"post-gate"``.
        ignore_decomposition_failures
            Ignore hyperedge decomposition failures when building the DEM subgraphs
            for ``pymatching.Matching``. By default ``False``. If the circuit distance
            is lower than the distance of the surface code(s), set to ``True``.
        max_iter
            Maximum number of iterations for belief propagation.
        bp_method
            Belief-propagation method, for more information see ``ldpc.bp_decoder.BpDecoder``.
        kargs_bp
            Extra arguments for ``ldpc.bp_decoder.BpDecoder``.

        Raises
        ------
        ValueError
            If some detector in ``dem`` has no coordinates.
        """
        det_to_coords = dem.get_detector_coordinates()
        if any(c == [] for c in det_to_coords.values()):
            raise ValueError("All detectors must have coordinates.")
        # NOTE(review): detectors sharing identical coordinates collapse onto
        # a single key of this inverse mapping.
        coords_to_det = {tuple(v): k for k, v in det_to_coords.items()}

        self.dem = dem
        self.dem_errors = dem_only_errors(self.dem)

        # for generating the subgraphs
        self.circuit = circuit
        self.logicals = logicals
        self.stab_coords = stab_coords
        self.detector_frame = detector_frame
        self.det_to_coords = det_to_coords
        self.coords_to_det = coords_to_det

        # for BP
        self.h, self.p, self.l, _ = dem_to_hplc(dem)
        self.bp_decoder = BpDecoder(
            self.h,
            error_channel=self.p,
            input_vector_type="syndrome",
            max_iter=max_iter,
            bp_method=bp_method,
            **kargs_bp,
        )

        # for running Matching in the subgraphs
        self.ignore_decomposition_failures = ignore_decomposition_failures
        self.h_sub = []
        self.l_sub = []
        self.e_sub_to_h_supp = []

        self._prepare_decoder()

    def _prepare_decoder(self):
        """
        Prepares all the variables required for running ``self.decode``
        and ``self.decode_batch``: for each logical, the check matrix
        (``self.h_sub``), the logical matrix (``self.l_sub``) and, for each
        subgraph edge, the indices of the ``dem_errors`` whose probabilities
        are combined into that edge (``self.e_sub_to_h_supp``).
        """
        self.h_sub = []
        self.l_sub = []
        self.e_sub_to_h_supp = []

        for k, logical in enumerate(self.logicals):
            tracks = greedy_algorithm(
                self.circuit,
                detector_frame=self.detector_frame,
                r_start=999_999_999,
                t_start=get_initial_tracks(logical, self.circuit.num_qubits),
            )
            active_dets, inactive_dets = get_active_inactive_dets(
                tracks, stab_coords=self.stab_coords, coords_to_det=self.coords_to_det
            )
            active_e_sub, active_h_sub = split_active_errors(
                self.dem_errors, active_dets, log_obs_id=k
            )
            # Connect every inactive detector to the boundary so that pymatching
            # does not complain that no perfect matching could be found because
            # some nodes are not connected.
            for det in inactive_dets:
                instr = stim.DemInstruction(
                    "error", args=[0.5], targets=[stim.target_relative_detector_id(det)]
                )
                active_e_sub.append((self._PLACEHOLDER_ERROR_IND, instr))

            primitive_dem = get_primitive_dem(active_e_sub)
            h_sub_decom = get_hyperedge_decomposition(
                primitive_dem,
                active_h_sub,
                ignore_decomposition_failures=self.ignore_decomposition_failures,
            )
            e_sub_to_h_supp = get_e_sub_to_h_supp(active_e_sub, h_sub_decom)
            h_sub, _, l_sub, _ = dem_to_hplc(primitive_dem)

            self.h_sub.append(h_sub)
            self.l_sub.append(l_sub)
            self.e_sub_to_h_supp.append(e_sub_to_h_supp)

    def decode(self, defects: np.ndarray) -> np.ndarray:
        """Decodes a single shot.

        Parameters
        ----------
        defects
            Detector outcomes of one shot, ordered as the detectors in ``self.dem``.

        Returns
        -------
        logical_correction
            Integer array with the predicted flip (0 or 1) of each logical observable.

        Raises
        ------
        ValueError
            If BP does not converge and its log-probability ratios contain NaN.
        """
        corr = self.bp_decoder.decode(defects)
        if self.bp_decoder.converge:
            # BP found a correction consistent with the syndrome: use it directly.
            return (self.l @ corr) % 2

        llrs = self.bp_decoder.log_prob_ratios
        if np.isnan(llrs).any():
            raise ValueError("Returned `log_prob_ratios` from BP are NaN.")
        # posterior probability of each error from its log-likelihood ratio
        p_h = 1 / (1 + np.exp(llrs))

        # integer dtype, consistent with the converged branch above
        logical_correction = np.zeros(len(self.logicals), dtype=int)
        for k, _ in enumerate(self.logicals):
            weights = [
                # NOTE(review): placeholder errors are created with probability
                # 0.5 (weight 0 under -log(p/(1-p))) but receive weight 0.5
                # here; confirm which one is intended.
                0.5
                if self._PLACEHOLDER_ERROR_IND in sup
                else comb_probs_to_w(p_h[sup])
                for sup in self.e_sub_to_h_supp[k]
            ]
            mwpm = Matching.from_check_matrix(
                check_matrix=self.h_sub[k],
                weights=weights,
                faults_matrix=self.l_sub[k],
                use_virtual_boundary_node=True,
            )
            prediction = mwpm.decode(defects)
            logical_correction[k] = prediction[k]

        return logical_correction

    def decode_batch(self, defects: np.ndarray) -> np.ndarray:
        """Decodes several shots by running ``self.decode`` on each row of ``defects``.

        Returns a boolean array of shape ``(len(defects), len(self.logicals))``.
        """
        logical_correction = np.zeros((len(defects), len(self.logicals)), dtype=bool)
        for i, shot in enumerate(defects):
            logical_correction[i] = self.decode(shot)
        return logical_correction


def get_initial_tracks(logical: list[str], num_qubits: int) -> np.ndarray:
    """Returns initial track indices for ``greedy_algorithm``.

    Every (qubit, Pauli) slot starts on track 2. The slots named in
    ``logical`` (e.g. ``"Z3"``) are moved to track 1. Slot ``2*q`` holds X
    and slot ``2*q + 1`` holds Z for qubit ``q``.
    """
    pauli_offset = {"X": 0, "Z": 1}
    tracks = [2 for _ in range(2 * num_qubits)]
    for op in logical:
        pauli, qubit = op[0], int(op[1:])
        tracks[2 * qubit + pauli_offset[pauli]] = 1
    return np.array(tracks)


def get_active_inactive_dets(
    tracks: np.ndarray,
    stab_coords: dict,
    coords_to_det: dict,
) -> tuple[set[int], set[int]]:
    """Splits the detectors into those on track 1 ("active") and the rest.

    For each time slice ``t`` of ``tracks`` and each slot ``k`` on track 1, the
    stabilizers of label ``X{k//2}`` (even ``k``) or ``Z{k//2}`` (odd ``k``)
    are looked up in ``coords_to_det`` with coordinates ``(*xy, t)``. The last
    slice, which is the final logical measurement, is looked up at ``t - 0.5``.

    Returns ``(active_dets, other_dets)``. The inactive set is taken from
    ``range(len(coords_to_det))``, which assumes the detectors are numbered
    ``0..N-1``.
    """
    last_slice = len(tracks) - 1
    active = set()
    for t, row in enumerate(tracks):
        time_coord = t - 0.5 if t == last_slice else t
        for k, track in enumerate(row):
            if track != 1:
                continue
            label = ("Z" if k % 2 == 1 else "X") + str(k // 2)
            active.update(
                coords_to_det[(*map(float, xy), float(time_coord))]
                for xy in stab_coords[label]
            )
    inactive = set(range(len(coords_to_det))) - active
    return active, inactive


def split_active_errors(
    dem_errors: stim.DetectorErrorModel,
    active_dets: set[int],
    log_obs_id: int,
) -> tuple[
    list[tuple[int, stim.DemInstruction]], list[tuple[int, stim.DemInstruction]]
]:
    """Restricts each error to the active detectors and to ``log_obs_id``, then
    splits the restricted errors into edges and hyperedges.

    Parameters
    ----------
    dem_errors
        DEM containing only ``error`` instructions.
    active_dets
        Detector indices kept in the restriction.
    log_obs_id
        The only logical observable kept in the restriction.

    Returns
    -------
    active_edges
        ``(ind, instr)`` for restricted errors triggering 1 or 2 active
        detectors, with ``ind`` the index of the error in ``dem_errors``.
    active_hyperedges
        Same, for restricted errors triggering more than 2 active detectors.

    Errors triggering no active detector are dropped.
    """
    active_edges = []
    active_hyperedges = []
    for ind, dem_instr in enumerate(dem_errors):
        targets = dem_instr.targets_copy()
        # sorted so the restricted instruction has a deterministic target order
        dets = sorted(
            {t.val for t in targets if t.is_relative_detector_id()}.intersection(
                active_dets
            )
        )
        if not dets:
            continue  # inactive error: skip before building any instruction

        logs = sorted(
            {t.val for t in targets if t.is_logical_observable_id()}.intersection(
                [log_obs_id]
            )
        )
        restricted = stim.DemInstruction(
            type="error",
            targets=[stim.target_relative_detector_id(d) for d in dets]
            + [stim.target_logical_observable_id(o) for o in logs],
            args=dem_instr.args_copy(),
        )

        if len(dets) <= 2:
            active_edges.append((ind, restricted))
        else:
            active_hyperedges.append((ind, restricted))

    return active_edges, active_hyperedges


def get_primitive_dem(
    active_edges: list[tuple[int, stim.DemInstruction]],
) -> stim.DetectorErrorModel:
    """Builds a DEM from the instructions of ``active_edges``, in order.

    The first element of each pair (the original error index) is ignored.
    """
    dem = stim.DetectorErrorModel()
    for edge_instr in (instr for _, instr in active_edges):
        dem.append(edge_instr)
    return dem


def get_hyperedge_decomposition(
    primitive_dem: stim.DetectorErrorModel,
    h_sub: list[tuple[int, stim.DemInstruction]],
    ignore_decomposition_failures: bool = False,
) -> dict[int, list[int]]:
    """Decomposes each hyperedge of ``h_sub`` into edges of ``primitive_dem``.

    The detectors triggered by a hyperedge are decoded with ``Matching`` on
    ``primitive_dem``. The matched edges form its decomposition.

    Returns
    -------
    dict
        ``{h_ind: list[e_sub_ind]}`` with ``e_sub_ind`` given by the index in
        ``primitive_dem`` and ``h_ind`` given by the index in ``dem_errors``.

    Raises
    ------
    ValueError
        If the matched edges do not reproduce the detectors and observables of
        the hyperedge and ``ignore_decomposition_failures`` is ``False``.
    """
    matcher = Matching(primitive_dem)
    n_dets = primitive_dem.num_detectors
    edge_to_ind = get_edges_dict(
        primitive_dem, ignore_decomposition_failures=ignore_decomposition_failures
    )

    result = {}
    for h_ind, h_instr in h_sub:
        triggered = [
            t.val for t in h_instr.targets_copy() if t.is_relative_detector_id()
        ]
        syndrome = np.zeros(n_dets, dtype=bool)
        syndrome[np.array(triggered, dtype=int)] = True

        matched_edges = matcher.decode_to_edges_array(syndrome)
        e_sub_inds = [edge_to_ind[standardize_edge(e)] for e in matched_edges]

        if not ignore_decomposition_failures:
            edge_instrs = [primitive_dem[i] for i in e_sub_inds]
            if not valid_decomposition(h_instr, edge_instrs):
                raise ValueError(
                    f"Invalid decomposition found:\n{h_instr}\n{edge_instrs}"
                )
        result[h_ind] = e_sub_inds

    return result


def get_e_sub_to_h_supp(
    active_e_sub: list[tuple[int, stim.DemInstruction]],
    h_sub_decom: dict[int, list[int]],
) -> list[np.ndarray]:
    """Returns list where the index i corresponds to e_sub[i] or primitive_dem[i]
    and the corresponding array gives the indices of the errors in dem_errors
    whose probabilities need to be combined to obtain the probabilitiy of e_sub[i].

    Note that h_sub_decom has the decomposition of h_sub in terms of the edges
    in primitive_dem, which does not correspond to the same indices in dem_errors.
    """
    e_sub_to_h_supp = [[] for _ in active_e_sub]
    # add edges support on themselves
    for e_sub_ind, (e_ind, _) in enumerate(active_e_sub):
        e_sub_to_h_supp[e_sub_ind].append(e_ind)

    # add support of the hyperedges on the edges
    for h_ind, decom_sub in h_sub_decom.items():
        for e_sub_ind in decom_sub:
            e_sub_to_h_supp[e_sub_ind].append(h_ind)

    e_sub_to_h_supp = [np.array(sup, dtype=int) for sup in e_sub_to_h_supp]

    return e_sub_to_h_supp
@MarcSerraPeralta MarcSerraPeralta added the enhancement New feature or request label Mar 31, 2025
@MarcSerraPeralta
Copy link
Owner Author

The code in util.py is also needed:

import math
import numpy as np
import stim
from scipy.sparse import csc_matrix


def comb_probs_to_w(probs: np.ndarray) -> float:
    """Returns the matching weight ``-log(p / (1 - p))`` of the XOR of independent errors.

    ``p`` is the probability that an odd number of the errors with
    probabilities ``probs`` occur. It is clipped to ``[1e-14, 1 - 1e-14]``
    so that ``math.log`` stays finite.
    """
    combined = probs[0]
    for other in probs[1:]:
        # probability that exactly one of the two (combined, other) happens
        combined = combined * (1 - other) + (1 - combined) * other

    tolerance = 1e-14
    combined = min(max(combined, tolerance), 1 - tolerance)
    return -math.log(combined / (1 - combined))


def dem_to_hplc(
    dem: stim.DetectorErrorModel,
) -> tuple[csc_matrix, np.ndarray, csc_matrix, np.ndarray]:
    """Returns the detector-error matrix, error probabilities, logicals-error matrix,
    and the detector coordinates given a ``stim.DetectorErrorModel``.
    It keeps the ordering of the errors in ``dem.flattened()`` and the order of
    the detectors when building the output.

    Parameters
    ----------
    dem
        Detector error model (DEM).

    Returns
    -------
    det_err_matrix : scipy.sparse.csc_matrix(D, E)
        Detector-error matrix which relates the error mechanisms and the detectors
        they trigger. ``D`` is the number of detectors and ``E`` the number
        of error mechanisms.
    err_probs : np.ndarray(E)
        Probabilities for each error mechanism.
    log_err_matrix : scipy.sparse.csc_matrix(L, E)
        Logicals-error matrix which relates the error mechanisms and the logical
        observables that they flip. ``L`` is the number of logical observables.
    coords : np.ndarray(D, C)
        Coordinates associated with each detector, with ``C`` the number of coordinates.
        If no coordinates are present in ``dem``, a zero-filled array of shape
        ``(D,)`` is returned.

    Raises
    ------
    TypeError
        If ``dem`` is not a ``stim.DetectorErrorModel``.
    ValueError
        If an unsupported instruction or target is found, or if only some
        detectors have coordinates.
    """
    if not isinstance(dem, stim.DetectorErrorModel):
        raise TypeError(
            f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given."
        )

    det_err_list = []
    err_probs_list = []
    log_err_list = []
    coords_dict = {}

    for instr in dem.flattened():
        if instr.type == "error":
            # get information
            p = instr.args_copy()[0]
            dets, logs = [], []
            for t in instr.targets_copy():
                if t.is_relative_detector_id():
                    dets.append(t.val)
                elif t.is_logical_observable_id():
                    logs.append(t.val)
                elif t.is_separator():
                    pass  # decomposition separators carry no detector/observable
                else:
                    raise ValueError(f"{t} is not implemented.")
            det_err_list.append(dets)
            err_probs_list.append(p)
            log_err_list.append(logs)
        elif instr.type == "detector":
            det = instr.targets_copy()[0].val
            coords_dict[det] = instr.args_copy()
        elif instr.type == "logical_observable":
            pass
        else:
            raise ValueError(f"{instr} is not implemented.")

    det_err_matrix = _list_to_csc_matrix(
        det_err_list, shape=(dem.num_detectors, len(det_err_list))
    )
    log_err_matrix = _list_to_csc_matrix(
        log_err_list, shape=(dem.num_observables, len(log_err_list))
    )
    err_probs = np.array(err_probs_list)
    # zeros (not np.empty) so the "no coordinates" result is deterministic
    coords = np.zeros(shape=(dem.num_detectors,))
    if coords_dict:
        if dem.num_detectors != len(coords_dict):
            raise ValueError(
                "Either all the detectors have coordinates or none,"
                " but not all of them have."
            )
        coords = np.array([coords_dict[i] for i in range(dem.num_detectors)])

    return det_err_matrix, err_probs, log_err_matrix, coords


def _list_to_csc_matrix(my_list: list[list[int]], shape: tuple[int, int]) -> csc_matrix:
    """Builds a 0/1 ``csc_matrix`` whose column ``i`` has ones at rows ``my_list[i]``.

    Parameters
    ----------
    my_list
        For every column, the row indices holding a one.
    shape
        Shape of the ``csc_matrix``.

    Returns
    -------
    matrix
        The described ``csc_matrix`` with ``uint8`` entries.

    Raises
    ------
    ValueError
        If ``shape`` has fewer columns than ``len(my_list)``.
    """
    if len(my_list) > shape[1]:
        raise ValueError(
            "The shape of the csc_matrix is not large enough to accomodate all the data."
        )

    rows = np.array([r for column in my_list for r in column], dtype=int)
    cols = np.array(
        [c for c, column in enumerate(my_list) for _ in column], dtype=int
    )
    # uint8: smallest integer size (per the original note, bool operations do not work)
    values = np.ones(len(rows), dtype=np.uint8)
    return csc_matrix((values, (rows, cols)), shape=shape)


def dem_only_errors(dem: stim.DetectorErrorModel) -> stim.DetectorErrorModel:
    """Returns a new DEM with only the ``error`` instructions of ``dem.flattened()``, in order."""
    errors_only = stim.DetectorErrorModel()
    for error_instr in (i for i in dem.flattened() if i.type == "error"):
        errors_only.append(error_instr)
    return errors_only


def valid_decomposition(
    h_instr: stim.DemInstruction,
    e_instrs: list[stim.DemInstruction],
) -> bool:
    """Checks whether the edges ``e_instrs`` reproduce the hyperedge ``h_instr``.

    The detectors and the logical observables of the edges are combined by
    symmetric difference (XOR). The result must equal those of ``h_instr``.
    """

    def _dets_and_logs(instr):
        targets = instr.targets_copy()
        return (
            {t.val for t in targets if t.is_relative_detector_id()},
            {t.val for t in targets if t.is_logical_observable_id()},
        )

    target_dets, target_logs = _dets_and_logs(h_instr)
    xor_dets, xor_logs = set(), set()
    for instr in e_instrs:
        dets, logs = _dets_and_logs(instr)
        xor_dets ^= dets
        xor_logs ^= logs

    return target_dets == xor_dets and target_logs == xor_logs


def get_edges_dict(
    primitive_dem: stim.DetectorErrorModel,
    ignore_decomposition_failures: bool = False,
) -> dict[tuple[int, int], int]:
    """Maps each edge of ``primitive_dem`` to the index of its first instruction.

    An edge key is the sorted tuple of the detectors it triggers. A
    single-detector error is paired with the boundary node ``-1``.

    Raises
    ------
    ValueError
        If two instructions trigger the same detectors but flip different
        logical observables and ``ignore_decomposition_failures`` is ``False``.
    """
    first_index = {}
    logs_at_index = {}
    for position, instr in enumerate(primitive_dem):
        targets = instr.targets_copy()
        det_ids = [t.val for t in targets if t.is_relative_detector_id()]
        obs_ids = tuple(sorted(t.val for t in targets if t.is_logical_observable_id()))
        if len(det_ids) == 1:
            det_ids.append(-1)  # boundary node
        key = tuple(sorted(det_ids))

        previous = first_index.get(key)
        if previous is None:
            first_index[key] = position
            logs_at_index[position] = obs_ids
        elif logs_at_index[previous] != obs_ids and not ignore_decomposition_failures:
            raise ValueError(
                "Edges triggering same detector but with different logical effect have been found."
            )
        # otherwise: repeated instruction or ignored failure, keep the first index

    return first_index


def standardize_edge(pymatching_edge: np.ndarray) -> tuple[int, int]:
    """Returns the edge's node indices as a sorted tuple of Python ints (boundary ``-1`` first)."""
    nodes = pymatching_edge.tolist()
    nodes.sort()
    return tuple(nodes)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

1 participant