From f8d48bf132c43161ccbf56388cb01d7bbd1751d5 Mon Sep 17 00:00:00 2001 From: "Michael A. Perlin" Date: Sun, 15 Feb 2026 13:48:05 -0500 Subject: [PATCH 1/2] speed up ClassicalCode.punctured and ClassicalCode.shortened --- src/qldpc/codes/common.py | 63 ++++++++++++++++++++++------------ src/qldpc/codes/common_test.py | 15 ++++++-- 2 files changed, 54 insertions(+), 24 deletions(-) diff --git a/src/qldpc/codes/common.py b/src/qldpc/codes/common.py index edf2bf36..c6cc61ba 100644 --- a/src/qldpc/codes/common.py +++ b/src/qldpc/codes/common.py @@ -23,7 +23,7 @@ import itertools import random import warnings -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Collection, Mapping, Sequence from typing import Any, Iterator, cast import galois @@ -640,18 +640,42 @@ def stack(codes: Sequence[ClassicalCode]) -> ClassicalCode: matrices = [code.matrix for code in codes] return ClassicalCode(scipy.linalg.block_diag(*matrices), field=fields[0].order) - def punctured(self, bits: Sequence[int]) -> ClassicalCode: + def punctured(self, bits: Collection[int]) -> ClassicalCode: """Delete the specified bits from a code. - To delete bits from the code, we remove the corresponding columns from its generator matrix - (whose rows are code words that form a basis for the code space). + To delete bits from a code, we can remove the corresponding columns from its generator + matrix (whose rows form a basis for the code space). + + Puncturing a code is equivalent to shortening its dual code (Prop. 2.5 of arXiv:2308.15746). + Shortening a code at a bit keeps only the code words that are zero at that bit. + To shorten a code at a bit, we can + (1) row-reduce the generator matrix at that bit, + (2) delete the pivot row for that bit, and then + (3) delete the column for that bit. + + Since we represent codes by their parity check matrices, it's computationally cheaper for us + to modify parity check matrices (without converting to/form generator matrices). + Altogether, we puncture this code by shortening its dual code, applying the shortening steps + above to the parity check matrix of the code. """ assert all(0 <= bit < len(self) for bit in bits) - bits_to_keep = [bit for bit in range(len(self)) if bit not in bits] - new_generator = self.generator[:, bits_to_keep] - return ClassicalCode.from_generator(new_generator, self.field.order) - - def puncture(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover + new_matrix = self.matrix.copy() + for bit in sorted(bits, reverse=True): + nonzero_rows = np.where(new_matrix[:, bit])[0] + if nonzero_rows.size: + pivot_row, rows_to_reduce = nonzero_rows[0], nonzero_rows[1:] + if rows_to_reduce.size: + if self.field.order == 2: + new_matrix[rows_to_reduce] -= new_matrix[pivot_row] + else: + prefactors = new_matrix[rows_to_reduce, bit] / new_matrix[pivot_row, bit] + rows_to_subtract = np.outer(prefactors, new_matrix[pivot_row]) + new_matrix[rows_to_reduce] -= rows_to_subtract + new_matrix = np.delete(new_matrix, pivot_row, axis=0).view(self.field) + new_matrix = np.delete(new_matrix, bit, axis=1).view(self.field) + return ClassicalCode(new_matrix) + + def puncture(self, bits: Collection[int]) -> ClassicalCode: # pragma: no cover """Deprecated alias for ClassicalCode.punctured.""" warnings.warn( "ClassicalCode.puncture is DEPRECATED; use ClassicalCode.punctured instead", @@ -660,23 +684,20 @@ def puncture(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover ) return self.punctured(bits) - def shortened(self, bits: Sequence[int]) -> ClassicalCode: + def shortened(self, bits: Collection[int]) -> ClassicalCode: """Shorten a code to the words that are zero on the specified bits, and delete those bits. - To shorten a code on a given bit, we: - - move the bit to the first position, - - row-reduce the generator matrix into the form [ identity_matrix, other_stuff ], and - - delete the first row and column from the generator matrix. + Shortening a code is equivalent to puncturing the dual code; see the docstring for + ClassicalCode.punctured for additional information. + + Altogether, to shorten this code at the given bits, we remove the corresponding columns from + its parity check matrix. """ assert all(0 <= bit < len(self) for bit in bits) - generator = self.generator - for bit in sorted(bits, reverse=True): - generator = np.roll(generator, -bit, axis=1).view(self.field) - generator = generator.row_reduce()[1:, 1:] - generator = np.roll(generator, bit, axis=1).view(self.field) - return ClassicalCode.from_generator(generator) + new_matrix = np.delete(self.matrix, list(bits), axis=1).view(self.field) + return ClassicalCode(new_matrix) - def shorten(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover + def shorten(self, bits: Collection[int]) -> ClassicalCode: # pragma: no cover """Deprecated alias for ClassicalCode.shortened.""" warnings.warn( "ClassicalCode.shorten is DEPRECATED; use ClassicalCode.shortened instead", diff --git a/src/qldpc/codes/common_test.py b/src/qldpc/codes/common_test.py index 98f02358..f5b2296d 100644 --- a/src/qldpc/codes/common_test.py +++ b/src/qldpc/codes/common_test.py @@ -67,11 +67,20 @@ def test_constructions_classical(pytestconfig: pytest.Config) -> None: codes.ClassicalCode(codes.ClassicalCode.random(2, 2), field=3) # construct a code from its generator matrix - code = codes.ClassicalCode.random(5, 3) + code = codes.ClassicalCode.random(6, 4, field=3) assert code.is_equiv_to(codes.ClassicalCode.from_generator(code.generator)) - # puncture a code - assert codes.ClassicalCode.from_generator(code.generator[:, 1:]) == code.punctured([0]) + # puncture and shorten a code + for field in [2, 3]: + code = codes.ClassicalCode.random(6, 4, field=field) + bits_to_remove = np.random.choice(range(len(code)), size=2, replace=False) + bits_to_keep = [bit for bit in range(len(code)) if bit not in bits_to_remove] + assert code.punctured(bits_to_remove).is_equiv_to( + codes.ClassicalCode.from_generator(code.generator[:, bits_to_keep]) + ) + # assert code.shortened(bits_to_remove).is_equiv_to( + # codes.ClassicalCode.from_generator(code.generator[:, bits_to_keep]) + # ) # shortening a repetition code yields a trivial code code = codes.RepetitionCode(3) From 08a33e3ac5c6a69a0be713f3889e737271c6337f Mon Sep 17 00:00:00 2001 From: "Michael A. Perlin" Date: Sun, 15 Feb 2026 13:49:45 -0500 Subject: [PATCH 2/2] add test --- src/qldpc/codes/common_test.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/qldpc/codes/common_test.py b/src/qldpc/codes/common_test.py index f5b2296d..14d17909 100644 --- a/src/qldpc/codes/common_test.py +++ b/src/qldpc/codes/common_test.py @@ -75,12 +75,11 @@ def test_constructions_classical(pytestconfig: pytest.Config) -> None: code = codes.ClassicalCode.random(6, 4, field=field) bits_to_remove = np.random.choice(range(len(code)), size=2, replace=False) bits_to_keep = [bit for bit in range(len(code)) if bit not in bits_to_remove] - assert code.punctured(bits_to_remove).is_equiv_to( + punctured_code = code.punctured(bits_to_remove) + assert punctured_code.is_equiv_to( codes.ClassicalCode.from_generator(code.generator[:, bits_to_keep]) ) - # assert code.shortened(bits_to_remove).is_equiv_to( - # codes.ClassicalCode.from_generator(code.generator[:, bits_to_keep]) - # ) + assert punctured_code.is_equiv_to(code.dual().shortened(bits_to_remove).dual()) # shortening a repetition code yields a trivial code code = codes.RepetitionCode(3)