diff --git a/src/qldpc/codes/common.py b/src/qldpc/codes/common.py index edf2bf36..c6cc61ba 100644 --- a/src/qldpc/codes/common.py +++ b/src/qldpc/codes/common.py @@ -23,7 +23,7 @@ import itertools import random import warnings -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Collection, Mapping, Sequence from typing import Any, Iterator, cast import galois @@ -640,18 +640,42 @@ def stack(codes: Sequence[ClassicalCode]) -> ClassicalCode: matrices = [code.matrix for code in codes] return ClassicalCode(scipy.linalg.block_diag(*matrices), field=fields[0].order) - def punctured(self, bits: Sequence[int]) -> ClassicalCode: + def punctured(self, bits: Collection[int]) -> ClassicalCode: """Delete the specified bits from a code. - To delete bits from the code, we remove the corresponding columns from its generator matrix - (whose rows are code words that form a basis for the code space). + To delete bits from a code, we can remove the corresponding columns from its generator + matrix (whose rows form a basis for the code space). + + Puncturing a code is equivalent to shortening its dual code (Prop. 2.5 of arXiv:2308.15746). + Shortening a code at a bit keeps only the code words that are zero at that bit. + To shorten a code at a bit, we can + (1) row-reduce the generator matrix at that bit, + (2) delete the pivot row for that bit, and then + (3) delete the column for that bit. + + Since we represent codes by their parity check matrices, it's computationally cheaper for us + to modify parity check matrices (without converting to/form generator matrices). + Altogether, we puncture this code by shortening its dual code, applying the shortening steps + above to the parity check matrix of the code. """ assert all(0 <= bit < len(self) for bit in bits) - bits_to_keep = [bit for bit in range(len(self)) if bit not in bits] - new_generator = self.generator[:, bits_to_keep] - return ClassicalCode.from_generator(new_generator, self.field.order) - - def puncture(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover + new_matrix = self.matrix.copy() + for bit in sorted(bits, reverse=True): + nonzero_rows = np.where(new_matrix[:, bit])[0] + if nonzero_rows.size: + pivot_row, rows_to_reduce = nonzero_rows[0], nonzero_rows[1:] + if rows_to_reduce.size: + if self.field.order == 2: + new_matrix[rows_to_reduce] -= new_matrix[pivot_row] + else: + prefactors = new_matrix[rows_to_reduce, bit] / new_matrix[pivot_row, bit] + rows_to_subtract = np.outer(prefactors, new_matrix[pivot_row]) + new_matrix[rows_to_reduce] -= rows_to_subtract + new_matrix = np.delete(new_matrix, pivot_row, axis=0).view(self.field) + new_matrix = np.delete(new_matrix, bit, axis=1).view(self.field) + return ClassicalCode(new_matrix) + + def puncture(self, bits: Collection[int]) -> ClassicalCode: # pragma: no cover """Deprecated alias for ClassicalCode.punctured.""" warnings.warn( "ClassicalCode.puncture is DEPRECATED; use ClassicalCode.punctured instead", @@ -660,23 +684,20 @@ def puncture(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover ) return self.punctured(bits) - def shortened(self, bits: Sequence[int]) -> ClassicalCode: + def shortened(self, bits: Collection[int]) -> ClassicalCode: """Shorten a code to the words that are zero on the specified bits, and delete those bits. - To shorten a code on a given bit, we: - - move the bit to the first position, - - row-reduce the generator matrix into the form [ identity_matrix, other_stuff ], and - - delete the first row and column from the generator matrix. + Shortening a code is equivalent to puncturing the dual code; see the docstring for + ClassicalCode.punctured for additional information. + + Altogether, to shorten this code at the given bits, we remove the corresponding columns from + its parity check matrix. """ assert all(0 <= bit < len(self) for bit in bits) - generator = self.generator - for bit in sorted(bits, reverse=True): - generator = np.roll(generator, -bit, axis=1).view(self.field) - generator = generator.row_reduce()[1:, 1:] - generator = np.roll(generator, bit, axis=1).view(self.field) - return ClassicalCode.from_generator(generator) + new_matrix = np.delete(self.matrix, list(bits), axis=1).view(self.field) + return ClassicalCode(new_matrix) - def shorten(self, bits: Sequence[int]) -> ClassicalCode: # pragma: no cover + def shorten(self, bits: Collection[int]) -> ClassicalCode: # pragma: no cover """Deprecated alias for ClassicalCode.shortened.""" warnings.warn( "ClassicalCode.shorten is DEPRECATED; use ClassicalCode.shortened instead", diff --git a/src/qldpc/codes/common_test.py b/src/qldpc/codes/common_test.py index 98f02358..14d17909 100644 --- a/src/qldpc/codes/common_test.py +++ b/src/qldpc/codes/common_test.py @@ -67,11 +67,19 @@ def test_constructions_classical(pytestconfig: pytest.Config) -> None: codes.ClassicalCode(codes.ClassicalCode.random(2, 2), field=3) # construct a code from its generator matrix - code = codes.ClassicalCode.random(5, 3) + code = codes.ClassicalCode.random(6, 4, field=3) assert code.is_equiv_to(codes.ClassicalCode.from_generator(code.generator)) - # puncture a code - assert codes.ClassicalCode.from_generator(code.generator[:, 1:]) == code.punctured([0]) + # puncture and shorten a code + for field in [2, 3]: + code = codes.ClassicalCode.random(6, 4, field=field) + bits_to_remove = np.random.choice(range(len(code)), size=2, replace=False) + bits_to_keep = [bit for bit in range(len(code)) if bit not in bits_to_remove] + punctured_code = code.punctured(bits_to_remove) + assert punctured_code.is_equiv_to( + codes.ClassicalCode.from_generator(code.generator[:, bits_to_keep]) + ) + assert punctured_code.is_equiv_to(code.dual().shortened(bits_to_remove).dual()) # shortening a repetition code yields a trivial code code = codes.RepetitionCode(3)