Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 59 additions & 30 deletions qiskit_addon_utils/exp_vals/expectation_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,27 @@

import numpy as np
from qiskit.primitives import BitArray
from qiskit.quantum_info import Pauli, SparseObservable, SparsePauliOp
from qiskit.quantum_info import Pauli, PauliList, SparseObservable, SparsePauliOp

from qiskit_addon_utils.exp_vals.measurement_bases import find_measure_basis_to_observable_mapping


def executor_expectation_values(
# positional-only arguments: these canNOT be specified as keyword arguments, meaning we can
# rename them without breaking API
bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool]],
basis_dict: dict[Pauli, list[SparsePauliOp | None]],
bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool_]],
basis_mapping: dict[Pauli, list[SparsePauliOp | None]]
| tuple[Sequence[SparsePauliOp], Sequence[str | PauliList]],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the type hint for the tuple input be: tuple[Sequence[SparsePauliOp], Sequence[str] | PauliList], rather than what you have here?

/,
# positional or keyword arguments
meas_basis_axis: int | None = None,
*,
# keyword-only arguments: these can ONLY be specified as keyword arguments. Renaming them breaks
# API, but their order does not matter.
avg_axis: int | tuple[int, ...] | None = None,
measurement_flips: np.ndarray[tuple[int, ...], np.dtype[np.bool]] | None = None,
pauli_signs: np.ndarray[tuple[int, ...], np.dtype[np.bool]] | None = None,
postselect_mask: np.ndarray[tuple[int, ...], np.dtype[np.bool]] | None = None,
measurement_flips: np.ndarray[tuple[int, ...], np.dtype[np.bool_]] | None = None,
pauli_signs: np.ndarray[tuple[int, ...], np.dtype[np.bool_]] | None = None,
postselect_mask: np.ndarray[tuple[int, ...], np.dtype[np.bool_]] | None = None,
gamma_factor: float | None = None,
rescale_factors: Sequence[Sequence[Sequence[float]]] | None = None,
):
Expand All @@ -51,8 +54,8 @@ def executor_expectation_values(
bool_array: Boolean array, presumably representing data from measured qubits.
The last two axes are the number of shots and number of classical bits, respectively.
The least significant bit is assumed to be at index `0` of the bits axis.
If `meas_basis_axis` is given, that axis of `bool_array` indexes the measurement bases, with length `len(basis_dict)`.
basis_dict: This dict encodes how the data in `bool_array` should be used to estimate the desired list of Pauli observables.
If `meas_basis_axis` is given, that axis of `bool_array` indexes the measurement bases, with length `len(basis_mapping)`.
basis_mapping: This dict encodes how the data in `bool_array` should be used to estimate the desired list of Pauli observables.
The ith key is a measurement basis assumed to correspond to the ith slice of `bool_array` along the `meas_basis_axis` axis.
Each dict value is a list of length equal to the number of desired observables.
The jth element of this list is a `SparsePauliOp` assumed to be compatible (qubit-wise commuting) with the measurement-basis key.
Expand All @@ -61,8 +64,12 @@ def executor_expectation_values(
- Note the order of dict entries is relied on here for indexing; the dict keys are never used.
- Assumes each Pauli term (in dict values) is compatible with each measurement basis (in keys).
- Assumes each term in each observable appears for exactly one basis.
meas_basis_axis: Axis of bool_array that indexes measurement bases. Ordering must match ordering in `basis_dict`. If `None`,
then `len(basis_dict)` must be 1, and `bool_array` is assumed to correspond to the only measurement basis.
Alternatively, a tuple of (observables, measurement_bases) can be passed. A mapping between the measurement bases and the observables
will be computed. For each term in each observable, the first qubit-wise commuting basis from the bases list will be used as its measurement basis.
If no qubit-wise commuting basis is found for at least one of the terms in one of the observables, an error will be raised.
The number of bases in `measurement_bases` must be the same as the length of the meas_basis_axis in bool_array, and the order must match the order in bool_array.
meas_basis_axis: Axis of bool_array that indexes measurement bases. Ordering must match ordering in `basis_mapping`. If `None`,
then `len(basis_mapping)` must be 1, and `bool_array` is assumed to correspond to the only measurement basis.
avg_axis: Optional axis or axes of bool_array to average over when computing expectation values. Usually this is the "twirling" axis.
Must be nonnegative. (The shots axis, assumed to be at index -2 in the boolean array, is always averaged over).
measurement_flips: Optional boolean array used with measurement twirling. Indicates which bits were acquired with measurements preceded by bit-flip gates.
Expand All @@ -75,10 +82,10 @@ def executor_expectation_values(
number of positive samples minus the number of negative samples, computed as `1/(np.sum(~pauli_signs, axis=avg_axis) - np.sum(pauli_signs, axis=avg_axis))`.
This can fail due to division by zero if there are an equal number of positive and negative samples. Also note this rescales each expectation value
by a different factor. (TODO: allow specifying an array of gamma values).
rescale_factors: Scale factor for each Pauli term in each observable in each basis in the given ``basis_dict``.
rescale_factors: Scale factor for each Pauli term in each observable in each basis in the given ``basis_mapping``.
Typically used for readout mitigation ("TREX") correction factors.
Each item in the list corresponds to a different basis, and contains a list of lists of factors for each term in each observable related to that basis.
The order of the bases and the observables inside each basis should be the same as in `basis_dict`.
The order of the bases and the observables inside each basis should be the same as in `basis_mapping`.
For empty observables for some of the bases, keep an empty list. If `None`, scaling factor will not be applied.

Returns:
Expand All @@ -89,16 +96,17 @@ def executor_expectation_values(

Raises:
ValueError if `avg_axis` contains negative values.
ValueError if `meas_basis_axis` is `None` but `len(basis_dict) != 1`.
ValueError if the number of entries in `basis_dict` does not equal the length of `bool_array` along `meas_basis_axis`.
ValueError if `meas_basis_axis` is `None` but `len(basis_mapping) != 1`.
ValueError if the number of entries in `basis_mapping` does not equal the length of `bool_array` along `meas_basis_axis`.
ValueError if the given measurement_basis can not cover all of the terms in al lof the observables.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ValueError if the given measurement_basis can not cover all of the terms in al lof the observables.
ValueError if the given measurement_basis can not cover all of the terms in all of the observables.

"""
##### VALIDATE INPUTS:
avg_axis = _validate_avg_axis(avg_axis, len(bool_array.shape))

if meas_basis_axis is None:
if len(basis_dict) != 1:
if len(basis_mapping) != 1:
raise ValueError(
f"`meas_basis_axis` cannot be `None` unless there is only one measurement basis, but {len(basis_dict) = }. "
f"`meas_basis_axis` cannot be `None` unless there is only one measurement basis, but {len(basis_mapping) = }. "
)
bool_array = bool_array.reshape((1, *bool_array.shape))
if measurement_flips is not None:
Expand All @@ -112,18 +120,39 @@ def executor_expectation_values(
elif meas_basis_axis < 0:
raise ValueError("meas_basis_axis must be nonnegative.")

if len(basis_dict) != bool_array.shape[meas_basis_axis]:
raise ValueError(
f"{len(basis_dict) = } does not match {bool_array.shape[meas_basis_axis] = }."
)
if isinstance(basis_mapping, dict):
Copy link
Copy Markdown
Collaborator

@caleb-johnson caleb-johnson Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block of code (123-147) where we check basis mapping and get everything in order can be moved to a private method at the bottom of the module. That will clear up some space for the core code in the module. I think it'll help the readability.

if len(basis_mapping) != bool_array.shape[meas_basis_axis]:
raise ValueError(
f"{len(basis_mapping) = } does not match {bool_array.shape[meas_basis_axis] = }."
)
elif isinstance(basis_mapping, tuple):
if len(basis_mapping) != 2:
raise ValueError(
"if basis_mapping is a tuple, it must contain observables element and measurement_bases element."
)
if len(basis_mapping[1]) != bool_array.shape[meas_basis_axis]:
raise ValueError(
f"{len(basis_mapping[1]) = } does not match {bool_array.shape[meas_basis_axis] = }."
)
try:
basis_mapping = find_measure_basis_to_observable_mapping(
basis_mapping[0], basis_mapping[1]
)
except ValueError as err:
raise ValueError(
"The observables and measurement bases in `basis_mapping` do not match. "
"Please check the values of `basis_mapping` and try again."
) from err
else:
raise ValueError("basis_mapping must be either a dict or a tuple")

for i, v in enumerate(basis_dict.values()):
for i, v in enumerate(basis_mapping.values()):
if i == 0:
num_observables = len(v)
continue
if len(v) != num_observables:
raise ValueError(
f"Entry 0 in `basis_dict` indicates {num_observables} observables, but entry {i} indicates {len(v)} observables."
f"Entry 0 in `basis_mapping` indicates {num_observables} observables, but entry {i} indicates {len(v)} observables."
)

##### APPLY MEAS FLIPS:
Expand All @@ -134,8 +163,8 @@ def executor_expectation_values(
original_num_bits = bool_array.shape[-1]

# Convert SparsePauliOps to SparseObservables
basis_dict_ = {}
for basis, spo_list in basis_dict.items():
basis_mapping_ = {}
for basis, spo_list in basis_mapping.items():
diag_obs_list = []
for spo in spo_list:
if isinstance(spo, SparseObservable):
Expand All @@ -144,8 +173,8 @@ def executor_expectation_values(
diag_obs_list.append(SparseObservable.zero(original_num_bits))
else:
diag_obs_list.append(SparseObservable(spo))
basis_dict_[basis] = diag_obs_list
basis_dict = basis_dict_
basis_mapping_[basis] = diag_obs_list
basis_dict = basis_mapping_

##### POSTSELECTION:
if postselect_mask is not None:
Expand Down Expand Up @@ -225,9 +254,9 @@ def executor_expectation_values(


def _apply_postselect_mask(
bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool]],
bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool_]],
basis_dict: dict[Pauli, list[SparseObservable]],
postselect_mask: np.ndarray[tuple[int, ...], np.dtype[np.bool]],
postselect_mask: np.ndarray[tuple[int, ...], np.dtype[np.bool_]],
):
"""Applies postselection mask in preparation for computing expectation values.

Expand Down Expand Up @@ -278,9 +307,9 @@ def _validate_avg_axis(avg_axis: int | tuple[int, ...] | None, num_dims: int) ->


def _apply_pec_signs(
bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool]],
bool_array: np.ndarray[tuple[int, ...], np.dtype[np.bool_]],
Copy link
Copy Markdown
Collaborator

@caleb-johnson caleb-johnson Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was the linter throwing an error about this? Wondering why these changes are necessary here and above

basis_dict: dict[Pauli, list[SparseObservable | SparsePauliOp]],
pauli_signs: np.ndarray[tuple[int, ...], np.dtype[np.bool]],
pauli_signs: np.ndarray[tuple[int, ...], np.dtype[np.bool_]],
):
"""Applies PEC signs in preparation for computing expectation values.

Expand Down
99 changes: 96 additions & 3 deletions qiskit_addon_utils/exp_vals/measurement_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@

from __future__ import annotations

from collections.abc import Sequence

import numpy as np
from qiskit.quantum_info import Pauli, PauliList, SparsePauliOp


def get_measurement_bases(
observables: SparsePauliOp | list[SparsePauliOp],
) -> tuple[list[np.typing.NDArray[np.uint8]], dict[Pauli, list[SparsePauliOp]]]:
bases_in_int_format: bool = True,
) -> (
tuple[list[np.typing.NDArray[np.uint8]], dict[Pauli, list[SparsePauliOp]]]
| tuple[list[str], dict[Pauli, list[SparsePauliOp]]]
):
"""Choose bases to sample in order to calculate expectation values for all given observables.

Here a "basis" refers to measurement of a full-weight or high-weight Pauli, from which multiple qubit-wise commuting Paulis may be estimated.
Expand All @@ -29,9 +35,11 @@ def get_measurement_bases(

Args:
observables: The observables to calculate using the quantum computer.
bases_in_int_format: If true, return bases as an array of ints, using the samplomatic convention of: I=0, Z=1, X=2, Y=3.
otherwise, return the bases as a array of strings.

Returns:
* List of Pauli bases to sample encoded in a list of uint8 where 0=I,1=Z,2=X,3=Y.
* List of Pauli bases to sample encoded in a list of uint8 where 0=I,1=Z,2=X,3=Y or a list of strings (based on bases_in_int_format parameter).
* Dict that maps each measured basis to the relevant Paulis and their coefficients for each observable.
With the measured bases as keys, for each observable there is a SparsePauliOp representing it.
"""
Expand All @@ -56,7 +64,10 @@ def get_measurement_bases(
paulis.append(pauli)
current_basis_weight += coeff
reverser[basis][i] = SparsePauliOp(paulis, coeffs) if paulis else None
bases = _convert_basis_to_uint_representation(bases)
if bases_in_int_format:
bases = _convert_basis_to_uint_representation(bases)
else:
bases = bases.to_labels()

return bases, reverser

Expand Down Expand Up @@ -93,3 +104,85 @@ def _convert_basis_to_uint_representation(bases: PauliList) -> list[np.typing.ND
for pauli in bases
]
return bases_uint8


def _convert_to_pauli(basis):
"""Converts a basis in various formats into a Pauli object.

Can convert a string or a list of integers representing the Paulis using this convention:
0=I, 1=Z, 2=X, 3=Y

Args:
basis: the basis to convert.

Returns:
The Pauli represented as a Pauli object.

Raises:
ValueError: if the basis is in invalid format.
"""
int_mapping = {0: "I", 1: "Z", 2: "X", 3: "Y"}
if isinstance(basis, Pauli):
return basis
if isinstance(basis, str):
return Pauli(basis)
if isinstance(basis, (list, np.ndarray, tuple)) and isinstance(
basis[0], (np.unsignedinteger, int, np.integer)
):
return Pauli("".join([int_mapping[int_val] for int_val in basis]))

raise ValueError("basis must be a Pauli instance, str or a list of ints.")


def find_measure_basis_to_observable_mapping(
observables: Sequence[SparsePauliOp], measure_bases: Sequence[str | int | PauliList]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confused about the measure_bases type hint here too. Should it be Sequence[str | Sequence[int]] | PauliList?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also not sure how generally useful this function is. Maybe it could be a private method in the expectation_values module where it's used? We can just compress the docstring down to one line if we make it private. What do you think?

) -> dict[Pauli, list[SparsePauliOp | None]]:
"""Maps each term for each observable to the first basis it qubit-wise commutes with from the given measure_bases.

Each observable term must qubit-wise commute with at least one basis.

Args:
observables: list of observables.
measure_bases: list of Pauli bases that the observables are measured with.

Returns:
A dictionary mapping from basis to observables terms that commutes with them.

Raises:
ValueError: If there is an observable with a term that does not qubit-wise commute with any basis from the given measure_bases.
"""
measure_paulis = PauliList([_convert_to_pauli(basis) for basis in measure_bases])
measurement_dict: dict[Pauli, list[SparsePauliOp]] = {}
observables_elements_basis_found = []
for basis in measure_paulis:
measurement_dict[basis] = [[] for _ in range(len(observables))]

for observable_index, observable in enumerate(observables):
observables_elements_basis_found.append(np.zeros((len(observable)), dtype=np.bool_))
for basis in measure_paulis:
basis_paulis = []
basis_coeffs = []
# find the elements that commutes with this basis
for element_index, (observable_element, observable_coeff) in enumerate(
zip(observable.paulis, observable.coeffs)
):
# use only the first commuting basis found for each observable element
# TODO: enable multiple bases for each element, lowering variance in the expectation value calculation
if observables_elements_basis_found[observable_index][element_index]:
continue
commutes = (
np.dot(observable_element.z, basis.x) + np.dot(observable_element.x, basis.z)
) % 2 == 0
if commutes:
basis_paulis.append(observable_element)
basis_coeffs.append(observable_coeff)
observables_elements_basis_found[observable_index][element_index] = True
measurement_dict[basis][observable_index] = (
SparsePauliOp(basis_paulis, basis_coeffs) if basis_paulis else None
)
if any(
False in observable_elements_list
for observable_elements_list in observables_elements_basis_found
):
raise ValueError("Some observable elements do not commute with any measurement basis.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could print which observable doesn't commute with which basis here

return measurement_dict
Loading
Loading