Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions src/qldpc/codes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import itertools
import random
import warnings
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Collection, Mapping, Sequence
from typing import Any, Iterator, cast

import galois
Expand Down Expand Up @@ -640,18 +640,42 @@ def stack(codes: Sequence[ClassicalCode]) -> ClassicalCode:
matrices = [code.matrix for code in codes]
return ClassicalCode(scipy.linalg.block_diag(*matrices), field=fields[0].order)

def punctured(self, bits: Sequence[int]) -> ClassicalCode:
def punctured(self, bits: Collection[int]) -> ClassicalCode:
"""Delete the specified bits from a code.

To delete bits from the code, we remove the corresponding columns from its generator matrix
(whose rows are code words that form a basis for the code space).
To delete bits from a code, we can remove the corresponding columns from its generator
matrix (whose rows form a basis for the code space).

Puncturing a code is equivalent to shortening its dual code (Prop. 2.5 of arXiv:2308.15746).
Shortening a code at a bit keeps only the code words that are zero at that bit.
To shorten a code at a bit, we can
(1) row-reduce the generator matrix at that bit,
(2) delete the pivot row for that bit, and then
(3) delete the column for that bit.

Since we represent codes by their parity check matrices, it's computationally cheaper for us
to modify parity check matrices (without converting to/form generator matrices).
Altogether, we puncture this code by shortening its dual code, applying the shortening steps
above to the parity check matrix of the code.
"""
assert all(0 <= bit < len(self) for bit in bits)
bits_to_keep = [bit for bit in range(len(self)) if bit not in bits]
new_generator = self.generator[:, bits_to_keep]
return ClassicalCode.from_generator(new_generator, self.field.order)

def puncture(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover
new_matrix = self.matrix.copy()
for bit in sorted(bits, reverse=True):
nonzero_rows = np.where(new_matrix[:, bit])[0]
if nonzero_rows.size:
pivot_row, rows_to_reduce = nonzero_rows[0], nonzero_rows[1:]
if rows_to_reduce.size:
if self.field.order == 2:
new_matrix[rows_to_reduce] -= new_matrix[pivot_row]
else:
prefactors = new_matrix[rows_to_reduce, bit] / new_matrix[pivot_row, bit]
rows_to_subtract = np.outer(prefactors, new_matrix[pivot_row])
new_matrix[rows_to_reduce] -= rows_to_subtract
new_matrix = np.delete(new_matrix, pivot_row, axis=0).view(self.field)
new_matrix = np.delete(new_matrix, bit, axis=1).view(self.field)
return ClassicalCode(new_matrix)

def puncture(self, bits: Collection[int]) -> ClassicalCode: # pragma: no cover
"""Deprecated alias for ClassicalCode.punctured."""
warnings.warn(
"ClassicalCode.puncture is DEPRECATED; use ClassicalCode.punctured instead",
Expand All @@ -660,23 +684,20 @@ def puncture(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover
)
return self.punctured(bits)

def shortened(self, bits: Sequence[int]) -> ClassicalCode:
def shortened(self, bits: Collection[int]) -> ClassicalCode:
"""Shorten a code to the words that are zero on the specified bits, and delete those bits.

To shorten a code on a given bit, we:
- move the bit to the first position,
- row-reduce the generator matrix into the form [ identity_matrix, other_stuff ], and
- delete the first row and column from the generator matrix.
Shortening a code is equivalent to puncturing the dual code; see the docstring for
ClassicalCode.punctured for additional information.

Altogether, to shorten this code at the given bits, we remove the corresponding columns from
its parity check matrix.
"""
assert all(0 <= bit < len(self) for bit in bits)
generator = self.generator
for bit in sorted(bits, reverse=True):
generator = np.roll(generator, -bit, axis=1).view(self.field)
generator = generator.row_reduce()[1:, 1:]
generator = np.roll(generator, bit, axis=1).view(self.field)
return ClassicalCode.from_generator(generator)
new_matrix = np.delete(self.matrix, list(bits), axis=1).view(self.field)
return ClassicalCode(new_matrix)

def shorten(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover
def shorten(self, bits: Collection[int]) -> ClassicalCode: # pragma: no cover
"""Deprecated alias for ClassicalCode.shortened."""
warnings.warn(
"ClassicalCode.shorten is DEPRECATED; use ClassicalCode.shortened instead",
Expand Down
14 changes: 11 additions & 3 deletions src/qldpc/codes/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,19 @@ def test_constructions_classical(pytestconfig: pytest.Config) -> None:
codes.ClassicalCode(codes.ClassicalCode.random(2, 2), field=3)

# construct a code from its generator matrix
code = codes.ClassicalCode.random(5, 3)
code = codes.ClassicalCode.random(6, 4, field=3)
assert code.is_equiv_to(codes.ClassicalCode.from_generator(code.generator))

# puncture a code
assert codes.ClassicalCode.from_generator(code.generator[:, 1:]) == code.punctured([0])
# puncture and shorten a code
for field in [2, 3]:
code = codes.ClassicalCode.random(6, 4, field=field)
bits_to_remove = np.random.choice(range(len(code)), size=2, replace=False)
bits_to_keep = [bit for bit in range(len(code)) if bit not in bits_to_remove]
punctured_code = code.punctured(bits_to_remove)
assert punctured_code.is_equiv_to(
codes.ClassicalCode.from_generator(code.generator[:, bits_to_keep])
)
assert punctured_code.is_equiv_to(code.dual().shortened(bits_to_remove).dual())

# shortening a repetition code yields a trivial code
code = codes.RepetitionCode(3)
Expand Down
Loading