diff --git a/qiskit_addon_utils/exp_vals/expectation_values.py b/qiskit_addon_utils/exp_vals/expectation_values.py index 005e505..7666481 100644 --- a/qiskit_addon_utils/exp_vals/expectation_values.py +++ b/qiskit_addon_utils/exp_vals/expectation_values.py @@ -18,14 +18,17 @@ import numpy as np from qiskit.primitives import BitArray -from qiskit.quantum_info import Pauli, SparseObservable, SparsePauliOp +from qiskit.quantum_info import Pauli, PauliList, SparseObservable, SparsePauliOp + +from qiskit_addon_utils.exp_vals.measurement_bases import find_measure_basis_to_observable_mapping def executor_expectation_values( # positional-only arguments: these canNOT be specified as keyword arguments, meaning we can # rename them without breaking API - bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool]], - basis_dict: dict[Pauli, list[SparsePauliOp | None]], + bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool_]], + basis_mapping: dict[Pauli, list[SparsePauliOp | None]] + | tuple[Sequence[SparsePauliOp], Sequence[str | PauliList]], /, # positional or keyword arguments meas_basis_axis: int | None = None, @@ -33,9 +36,9 @@ def executor_expectation_values( # keyword-only arguments: these can ONLY be specified as keyword arguments. Renaming them breaks # API, but their order does not matter. avg_axis: int | tuple[int, ...] | None = None, - measurement_flips: np.ndarray[tuple[int, ...], np.dtype[np.bool]] | None = None, - pauli_signs: np.ndarray[tuple[int, ...], np.dtype[np.bool]] | None = None, - postselect_mask: np.ndarray[tuple[int, ...], np.dtype[np.bool]] | None = None, + measurement_flips: np.ndarray[tuple[int, ...], np.dtype[np.bool_]] | None = None, + pauli_signs: np.ndarray[tuple[int, ...], np.dtype[np.bool_]] | None = None, + postselect_mask: np.ndarray[tuple[int, ...], np.dtype[np.bool_]] | None = None, gamma_factor: float | None = None, rescale_factors: Sequence[Sequence[Sequence[float]]] | None = None, ): @@ -51,8 +54,8 @@ def executor_expectation_values( bool_array: Boolean array, presumably representing data from measured qubits. The last two axes are the number of shots and number of classical bits, respectively. The least significant bit is assumed to be at index `0` of the bits axis. - If `meas_basis_axis` is given, that axis of `bool_array` indexes the measurement bases, with length `len(basis_dict)`. - basis_dict: This dict encodes how the data in `bool_array` should be used to estimate the desired list of Pauli observables. + If `meas_basis_axis` is given, that axis of `bool_array` indexes the measurement bases, with length `len(basis_mapping)`. + basis_mapping: This dict encodes how the data in `bool_array` should be used to estimate the desired list of Pauli observables. The ith key is a measurement basis assumed to correspond to the ith slice of `bool_array` along the `meas_basis_axis` axis. Each dict value is a list of length equal to the number of desired observables. The jth element of this list is a `SparsePauliOp` assumed to be compatible (qubit-wise commuting) with the measurement-basis key. @@ -61,8 +64,12 @@ def executor_expectation_values( - Note the order of dict entries is relied on here for indexing; the dict keys are never used. - Assumes each Pauli term (in dict values) is compatible with each measurement basis (in keys). - Assumes each term in each observable appears for exactly one basis. - meas_basis_axis: Axis of bool_array that indexes measurement bases. Ordering must match ordering in `basis_dict`. If `None`, - then `len(basis_dict)` must be 1, and `bool_array` is assumed to correspond to the only measurement basis. + Alternatively, a tuple of (observables, measurement_bases) can be passed. A mapping between the measurement bases and the observables + will be computed. For each term in each observable, the first qubit-wise commuting basis from the bases list will be used as its measurement basis. + If no qubit-wise commuting basis is found for at least one of the terms in one of the observables, an error will be raised. + The number of bases in `measurement_bases` must be the same as the length of the meas_basis_axis in bool_array, and the order must match the order in bool_array. + meas_basis_axis: Axis of bool_array that indexes measurement bases. Ordering must match ordering in `basis_mapping`. If `None`, + then `len(basis_mapping)` must be 1, and `bool_array` is assumed to correspond to the only measurement basis. avg_axis: Optional axis or axes of bool_array to average over when computing expectation values. Usually this is the "twirling" axis. Must be nonnegative. (The shots axis, assumed to be at index -2 in the boolean array, is always averaged over). measurement_flips: Optional boolean array used with measurement twirling. Indicates which bits were acquired with measurements preceded by bit-flip gates. @@ -75,10 +82,10 @@ def executor_expectation_values( number of positive samples minus the number of negative samples, computed as `1/(np.sum(~pauli_signs, axis=avg_axis) - np.sum(pauli_signs, axis=avg_axis))`. This can fail due to division by zero if there are an equal number of positive and negative samples. Also note this rescales each expectation value by a different factor. (TODO: allow specifying an array of gamma values). - rescale_factors: Scale factor for each Pauli term in each observable in each basis in the given ``basis_dict``. + rescale_factors: Scale factor for each Pauli term in each observable in each basis in the given ``basis_mapping``. Typically used for readout mitigation ("TREX") correction factors. Each item in the list corresponds to a different basis, and contains a list of lists of factors for each term in each observable related to that basis. - The order of the bases and the observables inside each basis should be the same as in `basis_dict`. + The order of the bases and the observables inside each basis should be the same as in `basis_mapping`. For empty observables for some of the bases, keep an empty list. If `None`, scaling factor will not be applied. Returns: @@ -89,16 +96,17 @@ def executor_expectation_values( Raises: ValueError if `avg_axis` contains negative values. - ValueError if `meas_basis_axis` is `None` but `len(basis_dict) != 1`. - ValueError if the number of entries in `basis_dict` does not equal the length of `bool_array` along `meas_basis_axis`. + ValueError if `meas_basis_axis` is `None` but `len(basis_mapping) != 1`. + ValueError if the number of entries in `basis_mapping` does not equal the length of `bool_array` along `meas_basis_axis`. + ValueError if the given measurement_basis can not cover all of the terms in al lof the observables. """ ##### VALIDATE INPUTS: avg_axis = _validate_avg_axis(avg_axis, len(bool_array.shape)) if meas_basis_axis is None: - if len(basis_dict) != 1: + if len(basis_mapping) != 1: raise ValueError( - f"`meas_basis_axis` cannot be `None` unless there is only one measurement basis, but {len(basis_dict) = }. " + f"`meas_basis_axis` cannot be `None` unless there is only one measurement basis, but {len(basis_mapping) = }. " ) bool_array = bool_array.reshape((1, *bool_array.shape)) if measurement_flips is not None: @@ -112,18 +120,39 @@ def executor_expectation_values( elif meas_basis_axis < 0: raise ValueError("meas_basis_axis must be nonnegative.") - if len(basis_dict) != bool_array.shape[meas_basis_axis]: - raise ValueError( - f"{len(basis_dict) = } does not match {bool_array.shape[meas_basis_axis] = }." - ) + if isinstance(basis_mapping, dict): + if len(basis_mapping) != bool_array.shape[meas_basis_axis]: + raise ValueError( + f"{len(basis_mapping) = } does not match {bool_array.shape[meas_basis_axis] = }." + ) + elif isinstance(basis_mapping, tuple): + if len(basis_mapping) != 2: + raise ValueError( + "if basis_mapping is a tuple, it must contain observables element and measurement_bases element." + ) + if len(basis_mapping[1]) != bool_array.shape[meas_basis_axis]: + raise ValueError( + f"{len(basis_mapping[1]) = } does not match {bool_array.shape[meas_basis_axis] = }." + ) + try: + basis_mapping = find_measure_basis_to_observable_mapping( + basis_mapping[0], basis_mapping[1] + ) + except ValueError as err: + raise ValueError( + "The observables and measurement bases in `basis_mapping` do not match. " + "Please check the values of `basis_mapping` and try again." + ) from err + else: + raise ValueError("basis_mapping must be either a dict or a tuple") - for i, v in enumerate(basis_dict.values()): + for i, v in enumerate(basis_mapping.values()): if i == 0: num_observables = len(v) continue if len(v) != num_observables: raise ValueError( - f"Entry 0 in `basis_dict` indicates {num_observables} observables, but entry {i} indicates {len(v)} observables." + f"Entry 0 in `basis_mapping` indicates {num_observables} observables, but entry {i} indicates {len(v)} observables." ) ##### APPLY MEAS FLIPS: @@ -134,8 +163,8 @@ def executor_expectation_values( original_num_bits = bool_array.shape[-1] # Convert SparsePauliOps to SparseObservables - basis_dict_ = {} - for basis, spo_list in basis_dict.items(): + basis_mapping_ = {} + for basis, spo_list in basis_mapping.items(): diag_obs_list = [] for spo in spo_list: if isinstance(spo, SparseObservable): @@ -144,8 +173,8 @@ def executor_expectation_values( diag_obs_list.append(SparseObservable.zero(original_num_bits)) else: diag_obs_list.append(SparseObservable(spo)) - basis_dict_[basis] = diag_obs_list - basis_dict = basis_dict_ + basis_mapping_[basis] = diag_obs_list + basis_dict = basis_mapping_ ##### POSTSELECTION: if postselect_mask is not None: @@ -225,9 +254,9 @@ def executor_expectation_values( def _apply_postselect_mask( - bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool]], + bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool_]], basis_dict: dict[Pauli, list[SparseObservable]], - postselect_mask: np.ndarray[tuple[int, ...], np.dtype[np.bool]], + postselect_mask: np.ndarray[tuple[int, ...], np.dtype[np.bool_]], ): """Applies postselection mask in preparation for computing expectation values. @@ -278,9 +307,9 @@ def _validate_avg_axis(avg_axis: int | tuple[int, ...] | None, num_dims: int) -> def _apply_pec_signs( - bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool]], + bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool_]], basis_dict: dict[Pauli, list[SparseObservable | SparsePauliOp]], - pauli_signs: np.ndarray[tuple[int, ...], np.dtype[np.bool]], + pauli_signs: np.ndarray[tuple[int, ...], np.dtype[np.bool_]], ): """Applies PEC signs in preparation for computing expectation values. diff --git a/qiskit_addon_utils/exp_vals/measurement_bases.py b/qiskit_addon_utils/exp_vals/measurement_bases.py index 1b1e619..d61d100 100644 --- a/qiskit_addon_utils/exp_vals/measurement_bases.py +++ b/qiskit_addon_utils/exp_vals/measurement_bases.py @@ -14,13 +14,19 @@ from __future__ import annotations +from collections.abc import Sequence + import numpy as np from qiskit.quantum_info import Pauli, PauliList, SparsePauliOp def get_measurement_bases( observables: SparsePauliOp | list[SparsePauliOp], -) -> tuple[list[np.typing.NDArray[np.uint8]], dict[Pauli, list[SparsePauliOp]]]: + bases_in_int_format: bool = True, +) -> ( + tuple[list[np.typing.NDArray[np.uint8]], dict[Pauli, list[SparsePauliOp]]] + | tuple[list[str], dict[Pauli, list[SparsePauliOp]]] +): """Choose bases to sample in order to calculate expectation values for all given observables. Here a "basis" refers to measurement of a full-weight or high-weight Pauli, from which multiple qubit-wise commuting Paulis may be estimated. @@ -29,9 +35,11 @@ def get_measurement_bases( Args: observables: The observables to calculate using the quantum computer. + bases_in_int_format: If true, return bases as an array of ints, using the samplomatic convention of: I=0, Z=1, X=2, Y=3. + otherwise, return the bases as a array of strings. Returns: - * List of Pauli bases to sample encoded in a list of uint8 where 0=I,1=Z,2=X,3=Y. + * List of Pauli bases to sample encoded in a list of uint8 where 0=I,1=Z,2=X,3=Y or a list of strings (based on bases_in_int_format parameter). * Dict that maps each measured basis to the relevant Paulis and their coefficients for each observable. With the measured bases as keys, for each observable there is a SparsePauliOp representing it. """ @@ -56,7 +64,10 @@ def get_measurement_bases( paulis.append(pauli) current_basis_weight += coeff reverser[basis][i] = SparsePauliOp(paulis, coeffs) if paulis else None - bases = _convert_basis_to_uint_representation(bases) + if bases_in_int_format: + bases = _convert_basis_to_uint_representation(bases) + else: + bases = bases.to_labels() return bases, reverser @@ -93,3 +104,85 @@ def _convert_basis_to_uint_representation(bases: PauliList) -> list[np.typing.ND for pauli in bases ] return bases_uint8 + + +def _convert_to_pauli(basis): + """Converts a basis in various formats into a Pauli object. + + Can convert a string or a list of integers representing the Paulis using this convention: + 0=I, 1=Z, 2=X, 3=Y + + Args: + basis: the basis to convert. + + Returns: + The Pauli represented as a Pauli object. + + Raises: + ValueError: if the basis is in invalid format. + """ + int_mapping = {0: "I", 1: "Z", 2: "X", 3: "Y"} + if isinstance(basis, Pauli): + return basis + if isinstance(basis, str): + return Pauli(basis) + if isinstance(basis, (list, np.ndarray, tuple)) and isinstance( + basis[0], (np.unsignedinteger, int, np.integer) + ): + return Pauli("".join([int_mapping[int_val] for int_val in basis])) + + raise ValueError("basis must be a Pauli instance, str or a list of ints.") + + +def find_measure_basis_to_observable_mapping( + observables: Sequence[SparsePauliOp], measure_bases: Sequence[str | int | PauliList] +) -> dict[Pauli, list[SparsePauliOp | None]]: + """Maps each term for each observable to the first basis it qubit-wise commutes with from the given measure_bases. + + Each observable term must qubit-wise commute with at least one basis. + + Args: + observables: list of observables. + measure_bases: list of Pauli bases that the observables are measured with. + + Returns: + A dictionary mapping from basis to observables terms that commutes with them. + + Raises: + ValueError: If there is an observable with a term that does not qubit-wise commute with any basis from the given measure_bases. + """ + measure_paulis = PauliList([_convert_to_pauli(basis) for basis in measure_bases]) + measurement_dict: dict[Pauli, list[SparsePauliOp]] = {} + observables_elements_basis_found = [] + for basis in measure_paulis: + measurement_dict[basis] = [[] for _ in range(len(observables))] + + for observable_index, observable in enumerate(observables): + observables_elements_basis_found.append(np.zeros((len(observable)), dtype=np.bool_)) + for basis in measure_paulis: + basis_paulis = [] + basis_coeffs = [] + # find the elements that commutes with this basis + for element_index, (observable_element, observable_coeff) in enumerate( + zip(observable.paulis, observable.coeffs) + ): + # use only the first commuting basis found for each observable element + # TODO: enable multiple bases for each element, lowering variance in the expectation value calculation + if observables_elements_basis_found[observable_index][element_index]: + continue + commutes = ( + np.dot(observable_element.z, basis.x) + np.dot(observable_element.x, basis.z) + ) % 2 == 0 + if commutes: + basis_paulis.append(observable_element) + basis_coeffs.append(observable_coeff) + observables_elements_basis_found[observable_index][element_index] = True + measurement_dict[basis][observable_index] = ( + SparsePauliOp(basis_paulis, basis_coeffs) if basis_paulis else None + ) + if any( + False in observable_elements_list + for observable_elements_list in observables_elements_basis_found + ): + raise ValueError("Some observable elements do not commute with any measurement basis.") + return measurement_dict diff --git a/test/exp_vals/test_expectation_values.py b/test/exp_vals/test_expectation_values.py index 1f4d207..44c43c8 100644 --- a/test/exp_vals/test_expectation_values.py +++ b/test/exp_vals/test_expectation_values.py @@ -403,7 +403,7 @@ def test_basis_dict_length_mismatch(self): basis_dict, meas_basis_axis=0, ) - self.assertIn("len(basis_dict)", str(context.exception)) + self.assertIn("len(basis_mapping)", str(context.exception)) self.assertIn("does not match", str(context.exception)) def test_inconsistent_observable_counts(self): @@ -428,7 +428,7 @@ def test_inconsistent_observable_counts(self): basis_dict, meas_basis_axis=0, ) - self.assertIn("`basis_dict` indicates 2 observables, but entry", str(context.exception)) + self.assertIn("`basis_mapping` indicates 2 observables, but entry", str(context.exception)) def test_measurement_flips_shape_mismatch(self): """Test that measurement_flips with wrong shape causes issues.""" @@ -854,3 +854,208 @@ def test_exp_val_measFlips(self): seed=None, ) self.assertTrue(np.allclose(evs, target_evs)) + + +class TestExecutorExpectationValuesTupleInput(unittest.TestCase): + """Test executor_expectation_values with tuple input for basis_mapping.""" + + def test_tuple_input_single_observable_string_bases(self): + """Test tuple input with a single observable and string measurement bases.""" + # Create a simple observable + observable = SparsePauliOp("ZZ", coeffs=[1.0]) + observables = [observable] + + # Define measurement bases as strings + measurement_bases = ["ZZ"] + + # Create bool_array with shape (num_bases, num_shots, num_qubits) + num_shots = 100 + num_qubits = 2 + bool_array = np.random.randint(0, 2, size=(1, num_shots, num_qubits), dtype=bool) + + # Test with tuple input + result = executor_expectation_values( + bool_array, + (observables, measurement_bases), + meas_basis_axis=0, + ) + + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], tuple) + self.assertEqual(len(result[0]), 2) # (mean, variance) + + def test_tuple_input_single_observable_paulilist_bases(self): + """Test tuple input with a single observable and PauliList measurement bases.""" + # Create a simple observable + observable = SparsePauliOp("XX", coeffs=[1.0]) + observables = [observable] + + # Define measurement bases as PauliList + measurement_bases = PauliList(["XX"]) + + # Create bool_array with shape (num_bases, num_shots, num_qubits) + num_shots = 100 + num_qubits = 2 + bool_array = np.random.randint(0, 2, size=(1, num_shots, num_qubits), dtype=bool) + + # Test with tuple input + result = executor_expectation_values( + bool_array, + (observables, measurement_bases), + meas_basis_axis=0, + ) + + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], tuple) + self.assertEqual(len(result[0]), 2) # (mean, variance) + + def test_tuple_input_multiple_observables_multiple_bases(self): + """Test tuple input with multiple observables and multiple measurement bases.""" + # Create multiple observables + obs1 = SparsePauliOp("ZZ", coeffs=[1.0]) + obs2 = SparsePauliOp("XX", coeffs=[1.0]) + observables = [obs1, obs2] + + # Define measurement bases as strings + measurement_bases = ["ZZ", "XX"] + + # Create bool_array with shape (num_bases, num_shots, num_qubits) + num_shots = 100 + num_qubits = 2 + bool_array = np.random.randint(0, 2, size=(2, num_shots, num_qubits), dtype=bool) + + # Test with tuple input + result = executor_expectation_values( + bool_array, + (observables, measurement_bases), + meas_basis_axis=0, + ) + + self.assertEqual(len(result), 2) + for res in result: + self.assertIsInstance(res, tuple) + self.assertEqual(len(res), 2) # (mean, variance) + + def test_tuple_input_observable_with_multiple_terms(self): + """Test tuple input with an observable containing multiple terms.""" + # Create an observable with multiple terms + observable = SparsePauliOp(["ZZ", "XX"], coeffs=[1.0, 0.5]) + observables = [observable] + + # Define measurement bases - need two bases for the two terms + measurement_bases = ["ZZ", "XX"] + + # Create bool_array with shape (num_bases, num_shots, num_qubits) + num_shots = 100 + num_qubits = 2 + bool_array = np.random.randint(0, 2, size=(2, num_shots, num_qubits), dtype=bool) + + # Test with tuple input + result = executor_expectation_values( + bool_array, + (observables, measurement_bases), + meas_basis_axis=0, + ) + + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], tuple) + self.assertEqual(len(result[0]), 2) # (mean, variance) + + def test_tuple_input_with_identity_terms(self): + """Test tuple input with observables containing identity terms.""" + # Create observables with identity terms + obs1 = SparsePauliOp("ZI", coeffs=[1.0]) + obs2 = SparsePauliOp("IZ", coeffs=[1.0]) + observables = [obs1, obs2] + + # Define measurement bases - identities are measured as Z + measurement_bases = ["ZZ", "ZZ"] + + # Create bool_array with shape (num_bases, num_shots, num_qubits) + num_shots = 100 + num_qubits = 2 + bool_array = np.random.randint(0, 2, size=(2, num_shots, num_qubits), dtype=bool) + + # Test with tuple input + result = executor_expectation_values( + bool_array, + (observables, measurement_bases), + meas_basis_axis=0, + ) + + self.assertEqual(len(result), 2) + for res in result: + self.assertIsInstance(res, tuple) + self.assertEqual(len(res), 2) # (mean, variance) + + def test_tuple_input_mismatched_bases_length(self): + """Test that mismatched number of bases raises ValueError.""" + observable = SparsePauliOp("ZZ", coeffs=[1.0]) + observables = [observable] + + # Define measurement bases with wrong length + measurement_bases = ["ZZ", "XX"] # 2 bases + + # Create bool_array with only 1 basis + num_shots = 100 + num_qubits = 2 + bool_array = np.random.randint(0, 2, size=(1, num_shots, num_qubits), dtype=bool) + + # Should raise ValueError due to mismatch + with self.assertRaises(ValueError) as context: + executor_expectation_values( + bool_array, + (observables, measurement_bases), + meas_basis_axis=0, + ) + self.assertIn("does not match", str(context.exception)) + + def test_tuple_input_incompatible_observable_and_basis(self): + """Test that incompatible observable and measurement basis raises ValueError.""" + # Create an observable that doesn't qubit-wise commute with the measurement basis + # XZ doesn't qubit-wise commute with ZX (X doesn't commute with Z on first qubit) + observable = SparsePauliOp("XZ", coeffs=[1.0]) + observables = [observable] + + # Define a measurement basis that doesn't qubit-wise commute with the observable + measurement_bases = ["ZX"] # ZX doesn't qubit-wise commute with XZ + + # Create bool_array + num_shots = 100 + num_qubits = 2 + bool_array = np.random.randint(0, 2, size=(1, num_shots, num_qubits), dtype=bool) + + # Should raise ValueError because no compatible basis found + with self.assertRaises(ValueError) as context: + executor_expectation_values( + bool_array, + (observables, measurement_bases), + meas_basis_axis=0, + ) + self.assertIn("The observables and measurement bases in", str(context.exception)) + + def test_tuple_input_with_avg_axis(self): + """Test tuple input with averaging over additional axes.""" + observable = SparsePauliOp("ZZ", coeffs=[1.0]) + observables = [observable] + measurement_bases = ["ZZ"] + + # Create bool_array with extra dimension for averaging + num_shots = 100 + num_qubits = 2 + bool_array = np.random.randint(0, 2, size=(1, 3, num_shots, num_qubits), dtype=bool) + + # Test with tuple input and avg_axis + result = executor_expectation_values( + bool_array, + (observables, measurement_bases), + meas_basis_axis=0, + avg_axis=1, + ) + + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], tuple) + self.assertEqual(len(result[0]), 2) # (mean, variance) + + +# Made with Bob diff --git a/test/exp_vals/test_measurement_bases.py b/test/exp_vals/test_measurement_bases.py new file mode 100644 index 0000000..a6dc156 --- /dev/null +++ b/test/exp_vals/test_measurement_bases.py @@ -0,0 +1,401 @@ +# This code is a Qiskit project. +# +# (C) Copyright IBM 2025. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Tests for the measurement_bases module.""" + +import unittest + +import numpy as np +import pytest +from qiskit.quantum_info import Pauli, PauliList, SparsePauliOp +from qiskit_addon_utils.exp_vals.measurement_bases import ( + _convert_basis_to_uint_representation, + _convert_to_pauli, + _meas_basis_for_pauli_group, + find_measure_basis_to_observable_mapping, + get_measurement_bases, +) + + +class TestGetMeasurementBases(unittest.TestCase): + """Tests for get_measurement_bases function.""" + + def test_single_observable_single_pauli(self): + """Test with a single observable containing a single Pauli term.""" + obs = SparsePauliOp("ZZZ", 1.0) + bases, reverser = get_measurement_bases(obs, bases_in_int_format=True) + + self.assertEqual(len(bases), 1) + self.assertEqual(len(reverser), 1) + np.testing.assert_array_equal(bases[0], np.array([1, 1, 1], dtype=np.uint8)) + + # Check reverser structure + basis_pauli = next(iter(reverser.keys())) + self.assertEqual(basis_pauli, Pauli("ZZZ")) + self.assertEqual(len(reverser[basis_pauli]), 1) + self.assertIsInstance(reverser[basis_pauli][0], SparsePauliOp) + + def test_single_observable_multiple_paulis(self): + """Test with a single observable containing multiple Pauli terms.""" + obs = SparsePauliOp(["ZZI", "IZZ", "ZIZ"], [1.0, 2.0, 3.0]) + bases, _ = get_measurement_bases(obs, bases_in_int_format=True) + + # All Z-type Paulis should commute and be in one basis + self.assertEqual(len(bases), 1) + np.testing.assert_array_equal(bases[0], np.array([1, 1, 1], dtype=np.uint8)) + + def test_multiple_observables(self): + """Test with multiple observables.""" + obs1 = SparsePauliOp("ZZZ", 1.0) + obs2 = SparsePauliOp("XXX", 2.0) + bases, reverser = get_measurement_bases([obs1, obs2], bases_in_int_format=True) + + # Z and X don't commute qubit-wise, so we need 2 bases + self.assertEqual(len(bases), 2) + self.assertEqual(len(reverser), 2) + + # Check that each basis maps to a list with 2 elements (one per observable) + for _, obs_list in reverser.items(): + self.assertEqual(len(obs_list), 2) + + def test_bases_string_format(self): + """Test with bases_in_int_format=False to get string format.""" + obs = SparsePauliOp("XYZ", 1.0) + bases, _ = get_measurement_bases(obs, bases_in_int_format=False) + + self.assertEqual(len(bases), 1) + self.assertIsInstance(bases[0], str) + self.assertEqual(bases[0], "XYZ") + + def test_commuting_paulis_grouped(self): + """Test that commuting Paulis are grouped into the same basis.""" + obs = SparsePauliOp(["ZII", "IZI", "IIZ"], [1.0, 1.0, 1.0]) + bases, _ = get_measurement_bases(obs, bases_in_int_format=True) + + # All should be in one basis since they commute qubit-wise + self.assertEqual(len(bases), 1) + + def test_non_commuting_paulis_separate_bases(self): + """Test that non-commuting Paulis get separate bases.""" + obs = SparsePauliOp(["ZI", "XI"], [1.0, 1.0]) + bases, _ = get_measurement_bases(obs, bases_in_int_format=True) + + # These don't commute qubit-wise, so need separate bases + self.assertEqual(len(bases), 2) + + def test_identity_terms(self): + """Test handling of identity terms.""" + obs = SparsePauliOp(["III", "ZZZ"], [1.0, 2.0]) + bases, _ = get_measurement_bases(obs, bases_in_int_format=True) + + # Identity commutes with everything + self.assertGreaterEqual(len(bases), 1) + + def test_empty_observable_list(self): + """Test with an empty list of observables.""" + # Empty list causes sum() to return 0, which doesn't have .unique() method + # This is expected behavior - function requires at least one observable + with pytest.raises(AttributeError): + _, _ = get_measurement_bases([], bases_in_int_format=True) + + def test_reverser_none_values(self): + """Test that reverser contains None for observables without terms in a basis.""" + obs1 = SparsePauliOp("ZZ", 1.0) + obs2 = SparsePauliOp("XX", 2.0) + _, reverser = get_measurement_bases([obs1, obs2], bases_in_int_format=True) + + # Each basis should have one observable with terms and one with None + for _, obs_list in reverser.items(): + non_none_count = sum(1 for obs in obs_list if obs is not None) + self.assertEqual(non_none_count, 1) + + +class TestMeasBasisForPauliGroup(unittest.TestCase): + """Tests for _meas_basis_for_pauli_group function.""" + + def test_single_z_pauli(self): + """Test with a single Z Pauli.""" + group = PauliList(["ZII"]) + basis = _meas_basis_for_pauli_group(group) + self.assertEqual(basis, Pauli("ZII")) + + def test_single_x_pauli(self): + """Test with a single X Pauli.""" + group = PauliList(["XII"]) + basis = _meas_basis_for_pauli_group(group) + self.assertEqual(basis, Pauli("XII")) + + def test_single_y_pauli(self): + """Test with a single Y Pauli.""" + group = PauliList(["YII"]) + basis = _meas_basis_for_pauli_group(group) + self.assertEqual(basis, Pauli("YII")) + + def test_multiple_z_paulis(self): + """Test with multiple Z Paulis.""" + group = PauliList(["ZII", "IZI", "IIZ"]) + basis = _meas_basis_for_pauli_group(group) + self.assertEqual(basis, Pauli("ZZZ")) + + def test_mixed_paulis(self): + """Test with mixed Pauli types.""" + group = PauliList(["ZI", "IX"]) + basis = _meas_basis_for_pauli_group(group) + self.assertEqual(basis, Pauli("ZX")) + + def test_identity_in_group(self): + """Test with identity in the group.""" + group = PauliList(["III", "ZII"]) + basis = _meas_basis_for_pauli_group(group) + self.assertEqual(basis, Pauli("ZII")) + + def test_overlapping_paulis(self): + """Test with overlapping Pauli positions.""" + group = PauliList(["ZZI", "ZIZ"]) + basis = _meas_basis_for_pauli_group(group) + self.assertEqual(basis, Pauli("ZZZ")) + + +class TestConvertBasisToUintRepresentation(unittest.TestCase): + """Tests for _convert_basis_to_uint_representation function.""" + + def test_single_identity(self): + """Test conversion of identity.""" + bases = PauliList(["I"]) + result = _convert_basis_to_uint_representation(bases) + self.assertEqual(len(result), 1) + np.testing.assert_array_equal(result[0], np.array([0], dtype=np.uint8)) + + def test_single_z(self): + """Test conversion of Z.""" + bases = PauliList(["Z"]) + result = _convert_basis_to_uint_representation(bases) + np.testing.assert_array_equal(result[0], np.array([1], dtype=np.uint8)) + + def test_single_x(self): + """Test conversion of X.""" + bases = PauliList(["X"]) + result = _convert_basis_to_uint_representation(bases) + np.testing.assert_array_equal(result[0], np.array([2], dtype=np.uint8)) + + def test_single_y(self): + """Test conversion of Y.""" + bases = PauliList(["Y"]) + result = _convert_basis_to_uint_representation(bases) + np.testing.assert_array_equal(result[0], np.array([3], dtype=np.uint8)) + + def test_multi_qubit_pauli(self): + """Test conversion of multi-qubit Pauli.""" + bases = PauliList(["IXYZ"]) + result = _convert_basis_to_uint_representation(bases) + # Note: reversed order (little-endian) - IXYZ becomes Z,Y,X,I + np.testing.assert_array_equal(result[0], np.array([1, 3, 2, 0], dtype=np.uint8)) + + def test_multiple_bases(self): + """Test conversion of multiple bases.""" + bases = PauliList(["ZZ", "XX", "YY"]) + result = _convert_basis_to_uint_representation(bases) + self.assertEqual(len(result), 3) + np.testing.assert_array_equal(result[0], np.array([1, 1], dtype=np.uint8)) + np.testing.assert_array_equal(result[1], np.array([2, 2], dtype=np.uint8)) + np.testing.assert_array_equal(result[2], np.array([3, 3], dtype=np.uint8)) + + def test_dtype_is_uint8(self): + """Test that output dtype is uint8.""" + bases = PauliList(["XYZ"]) + result = _convert_basis_to_uint_representation(bases) + self.assertEqual(result[0].dtype, np.uint8) + + +class TestConvertToPauli(unittest.TestCase): + """Tests for _convert_to_pauli function.""" + + def test_pauli_input(self): + """Test with Pauli object as input.""" + pauli = Pauli("XYZ") + result = _convert_to_pauli(pauli) + self.assertEqual(result, pauli) + + def test_string_input(self): + """Test with string as input.""" + result = _convert_to_pauli("XYZ") + self.assertEqual(result, Pauli("XYZ")) + + def test_list_of_ints_input(self): + """Test with list of integers as input.""" + result = _convert_to_pauli([2, 3, 1]) # X, Y, Z + self.assertEqual(result, Pauli("XYZ")) + + def test_numpy_array_input(self): + """Test with numpy array as input.""" + result = _convert_to_pauli(np.array([0, 1, 2, 3], dtype=np.uint8)) + # Array [0,1,2,3] maps to I,Z,X,Y + self.assertEqual(result, Pauli("IZXY")) + + def test_tuple_input(self): + """Test with tuple as input.""" + result = _convert_to_pauli((1, 2, 3)) + self.assertEqual(result, Pauli("ZXY")) + + def test_identity_conversion(self): + """Test conversion of identity.""" + result = _convert_to_pauli([0, 0, 0]) + self.assertEqual(result, Pauli("III")) + + def test_invalid_input_type(self): + """Test with invalid input type.""" + with pytest.raises( + ValueError, match="basis must be a Pauli instance, str or a list of ints" + ): + _convert_to_pauli({"invalid": "type"}) + + def test_invalid_list_content(self): + """Test with list containing non-integers.""" + with pytest.raises((ValueError, KeyError)): + _convert_to_pauli(["X", "Y", "Z"]) + + +class TestFindMeasureBasisToObservableMapping(unittest.TestCase): + """Tests for find_measure_basis_to_observable_mapping function.""" + + def test_single_observable_single_basis(self): + """Test with single observable and single basis.""" + obs = SparsePauliOp("ZZZ", 1.0) + bases = ["ZZZ"] + result = find_measure_basis_to_observable_mapping([obs], bases) + + self.assertEqual(len(result), 1) + basis_pauli = Pauli("ZZZ") + self.assertIn(basis_pauli, result) + self.assertEqual(len(result[basis_pauli]), 1) + self.assertIsInstance(result[basis_pauli][0], SparsePauliOp) + + def test_multiple_observables_single_basis(self): + """Test with multiple observables and single basis.""" + obs1 = SparsePauliOp("ZZI", 1.0) + obs2 = SparsePauliOp("IZZ", 2.0) + bases = ["ZZZ"] + result = find_measure_basis_to_observable_mapping([obs1, obs2], bases) + + basis_pauli = Pauli("ZZZ") + self.assertEqual(len(result[basis_pauli]), 2) + + def test_single_observable_multiple_bases(self): + """Test with single observable and multiple bases.""" + obs = SparsePauliOp(["ZI", "XI"], [1.0, 2.0]) + bases = ["ZI", "XI"] + result = find_measure_basis_to_observable_mapping([obs], bases) + + self.assertEqual(len(result), 2) + # Each term should be mapped to its commuting basis + for _, obs_list in result.items(): + self.assertEqual(len(obs_list), 1) + if obs_list[0] is not None: + self.assertEqual(len(obs_list[0].paulis), 1) + + def test_basis_as_int_list(self): + """Test with basis as list of integers.""" + obs = SparsePauliOp("ZZ", 1.0) + bases = [[1, 1]] # ZZ in int format + result = find_measure_basis_to_observable_mapping([obs], bases) + + self.assertEqual(len(result), 1) + + def test_basis_as_pauli_list(self): + """Test with basis as PauliList.""" + obs = SparsePauliOp("XX", 1.0) + bases = PauliList(["XX"]) + result = find_measure_basis_to_observable_mapping([obs], bases) + + self.assertEqual(len(result), 1) + + def test_observable_term_not_commuting_with_any_basis(self): + """Test error when observable term doesn't commute with any basis.""" + obs = SparsePauliOp("XY", 1.0) + bases = ["ZZ"] # Doesn't commute with XY + + with pytest.raises( + ValueError, match="Some observable elements do not commute with any measurement basis" + ): + find_measure_basis_to_observable_mapping([obs], bases) + + def test_none_observable_in_result(self): + """Test that None is returned for observables without terms in a basis.""" + obs1 = SparsePauliOp("ZZ", 1.0) + obs2 = SparsePauliOp("XX", 2.0) + bases = ["ZZ", "XX"] + result = find_measure_basis_to_observable_mapping([obs1, obs2], bases) + + # Each basis should have one observable with terms and one None + for _, obs_list in result.items(): + self.assertEqual(len(obs_list), 2) + none_count = sum(1 for obs in obs_list if obs is None) + self.assertEqual(none_count, 1) + + def test_first_commuting_basis_used(self): + """Test that only the first commuting basis is used for each term.""" + obs = SparsePauliOp("ZZ", 1.0) + bases = ["ZZ", "ZI", "IZ"] # All commute with ZZ + result = find_measure_basis_to_observable_mapping([obs], bases) + + # The observable should only be in the first basis + first_basis = Pauli("ZZ") + self.assertIsNotNone(result[first_basis][0]) + + # Other bases should have None + second_basis = Pauli("ZI") + third_basis = Pauli("IZ") + self.assertIsNone(result[second_basis][0]) + self.assertIsNone(result[third_basis][0]) + + def test_complex_observable_with_multiple_terms(self): + """Test with complex observable containing multiple terms.""" + obs = SparsePauliOp(["ZZI", "IZZ", "XII"], [1.0, 2.0, 3.0]) + bases = ["ZZZ", "XXX"] + result = find_measure_basis_to_observable_mapping([obs], bases) + + # ZZI and IZZ should map to ZZZ, XII should map to XXX + z_basis = Pauli("ZZZ") + x_basis = Pauli("XXX") + + self.assertIsNotNone(result[z_basis][0]) + self.assertEqual(len(result[z_basis][0].paulis), 2) + + self.assertIsNotNone(result[x_basis][0]) + self.assertEqual(len(result[x_basis][0].paulis), 1) + + def test_identity_terms(self): + """Test handling of identity terms.""" + obs = SparsePauliOp(["III", "ZZZ"], [1.0, 2.0]) + bases = ["ZZZ"] + result = find_measure_basis_to_observable_mapping([obs], bases) + + # Both terms should commute with ZZZ + basis_pauli = Pauli("ZZZ") + self.assertIsNotNone(result[basis_pauli][0]) + self.assertEqual(len(result[basis_pauli][0].paulis), 2) + + def test_empty_observables(self): + """Test with empty observables list.""" + bases = ["ZZ"] + result = find_measure_basis_to_observable_mapping([], bases) + + # Should have entries for each basis but with empty lists + self.assertEqual(len(result), 1) + basis_pauli = Pauli("ZZ") + self.assertEqual(len(result[basis_pauli]), 0) + + +if __name__ == "__main__": + unittest.main() + +# Made with Bob