Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
259b318
feat: Mid Circuit Measurement (#293)
shah-rushil Sep 12, 2025
253a450
Merge branch 'main' into mcm-experimental
speller26 Sep 20, 2025
5d38025
Merge branch 'main' into mcm-experimental
rmshaffer Oct 14, 2025
9729a19
Merge branch 'main' into mcm-experimental
speller26 Feb 2, 2026
90cc443
Merge branch 'main' into mcm-experimental
speller26 Feb 3, 2026
79ccfe9
minor fixes
speller26 Feb 5, 2026
0df15af
formatting
speller26 Feb 6, 2026
eb03649
Fixed bug and incorrect test case
speller26 Feb 6, 2026
b4c2ac8
Update .gitignore
speller26 Feb 6, 2026
4bf80a1
Delete .vscode/launch.json
speller26 Feb 6, 2026
6f1ec22
More tests
speller26 Feb 6, 2026
18e1dc2
Simplified visitor
speller26 Feb 11, 2026
e3c8982
Update branched_interpreter.py
speller26 Feb 11, 2026
115612f
Consolidate branching logic into existing classes
speller26 Feb 13, 2026
4456be9
Fix measurements
speller26 Feb 13, 2026
0ae913b
fix: Fix build failures
speller26 Feb 21, 2026
61f5a46
Merge branch 'fix' into mcm-experimental
speller26 Feb 21, 2026
4a859cb
Merge branch 'main' into mcm-experimental
speller26 Feb 23, 2026
96b9b2e
Merge branch 'main' into mcm-experimental
speller26 Feb 23, 2026
cabffb7
add unit test for an edge case
yitchen-tim Feb 26, 2026
64ab4e9
add various edge case tests for reset and MCM
yitchen-tim Mar 2, 2026
6b6d207
better classical assignment handling
speller26 Mar 3, 2026
ca29dc2
Fix `add_measure`
speller26 Mar 4, 2026
1142cb2
Move default conditionals back to interpreter
speller26 Mar 4, 2026
5d05c01
Update test_branched_mcm.py
speller26 Mar 4, 2026
1ac7b48
Add more tests
speller26 Mar 4, 2026
325c353
More tests
speller26 Mar 4, 2026
96fa8ab
Prune unreachable assertions
speller26 Mar 4, 2026
4549801
formatting
speller26 Mar 4, 2026
98a7f4a
More coverage
speller26 Mar 4, 2026
29f5899
minor fixes
speller26 Mar 4, 2026
a3abdc4
More minor fixes
speller26 Mar 4, 2026
87319ce
Simplify branch methods
speller26 Mar 5, 2026
77283e2
Simplify edge cases
speller26 Mar 5, 2026
8bca38f
Revert batch_operation_strategy
speller26 Mar 5, 2026
f927119
100% test coverage
speller26 Mar 5, 2026
a68744d
Even more simplifications
speller26 Mar 6, 2026
08f7329
rename
speller26 Mar 6, 2026
8f99f71
Update gate_operations.py
speller26 Mar 6, 2026
77fe95b
fix: allow re-measurement of qubits when MCM is supported (#351)
yitchen-tim Mar 10, 2026
2bf6c35
Allow using measurements outside branching
speller26 Mar 25, 2026
26babd4
Fix qubit reuse corner case
speller26 Mar 26, 2026
7ae10d1
Update program_context.py
speller26 Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
*.swp
*.idea
*.iml
.vscode/
.kiro/
build_files.tar.gz

.ycm_extra_conf.py
Expand Down
100 changes: 100 additions & 0 deletions src/braket/default_simulator/gate_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,106 @@ def gate_type(self) -> str:
return "gphase"


class Measure(GateOperation):
"""
Measurement operation that projects the state to a specific outcome.

This is used in branched simulation to apply measurement projections
when recalculating states from instruction sequences.
"""

def __init__(self, targets: Sequence[int], result: int = -1):
super().__init__(targets=targets)
self.result = result # The measurement outcome (0 or 1)

@property
def _base_matrix(self) -> np.ndarray:
"""
Return the projection matrix for the measurement outcome.
If result is -1 (unset), return identity (no projection).
"""
if self.result == -1:
return np.eye(2)
elif self.result == 0:
# Project to |0⟩⟨0|
return np.array([[1, 0], [0, 0]], dtype=complex)
elif self.result == 1:
# Project to |1⟩⟨1|
return np.array([[0, 0], [0, 1]], dtype=complex)
else:
return np.eye(2)

def apply(self, state: np.ndarray) -> np.ndarray:
if self.result == -1:
return state

# Apply projection matrix
projected_state = state.copy()

# For single qubit measurement, we need to project the appropriate amplitudes
if len(self._targets) == 1:
qubit_idx = self._targets[0]
n_qubits = int(np.log2(len(state)))

# Create mask for the target qubit
mask = 1 << (n_qubits - qubit_idx - 1) # Big-endian indexing

# Zero out amplitudes that don't match the measurement result
for i in range(len(projected_state)):
qubit_value = (i & mask) >> (n_qubits - qubit_idx - 1)
if qubit_value != self.result:
projected_state[i] = 0

# Normalize the state
norm = np.linalg.norm(projected_state)
if norm > 0:
projected_state /= norm

return projected_state


class Reset(GateOperation):
"""
Reset operation that sets desired target to 0
"""

def __init__(self, targets: Sequence[int]):
super().__init__(targets=targets)

@property
def _base_matrix(self) -> np.ndarray:
raise NotImplementedError("Reset does not have a matrix implementation")

def apply(self, state: np.ndarray) -> np.ndarray:

# For single qubit measurement, we need to project the appropriate amplitudes
if len(self._targets) == 1:
qubit_idx = self._targets[0]
n_qubits = int(np.log2(len(state)))

# Create mask for the target qubit
mask = 1 << (n_qubits - qubit_idx - 1) # Big-endian indexing

for i in range(len(state)):
# Check if the qubit is in state 1
qubit_value = (i & mask) >> (n_qubits - qubit_idx - 1)
if qubit_value == 1:
zero_index = i & ~mask

# Transfer the amplitude (with proper scaling)
state[zero_index] += state[i]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to copy state into a separate array first, rather than mutating the input array in place? (for consistency with how the Measure operation works)


# Set the original amplitude to zero
state[i] = 0

# Normalize the state
norm = np.linalg.norm(state)
if norm > 0:
state /= norm

return state


BRAKET_GATES = {
"i": Identity,
"h": Hadamard,
Expand Down
11 changes: 8 additions & 3 deletions src/braket/default_simulator/openqasm/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
for result in results:
self.add_result(result)

def add_instruction(self, instruction: [GateOperation, KrausOperation]) -> None:
def add_instruction(self, instruction: GateOperation | KrausOperation) -> None:
"""
Add instruction to the circuit.

Expand All @@ -62,9 +62,14 @@ def add_instruction(self, instruction: [GateOperation, KrausOperation]) -> None:
self.instructions.append(instruction)
self.qubit_set |= set(instruction.targets)

def add_measure(self, target: tuple[int], classical_targets: Iterable[int] | None = None):
def add_measure(
self,
target: tuple[int],
classical_targets: Iterable[int] | None = None,
allow_remeasure: bool = False,
):
for index, qubit in enumerate(target):
if qubit in self.measured_qubits:
if not allow_remeasure and qubit in self.measured_qubits:
raise ValueError(f"Qubit {qubit} is already measured or captured.")
self.measured_qubits.append(qubit)
self.qubit_set.add(qubit)
Expand Down
117 changes: 98 additions & 19 deletions src/braket/default_simulator/openqasm/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .circuit import Circuit
from .parser.openqasm_ast import (
AccessControl,
AliasStatement,
ArrayLiteral,
ArrayReferenceType,
ArrayType,
Expand All @@ -64,15 +65,20 @@
BitstringLiteral,
BitType,
BooleanLiteral,
BoolType,
Box,
BranchingStatement,
BreakStatement,
Cast,
ClassicalArgument,
ClassicalAssignment,
ClassicalDeclaration,
Concatenation,
ConstantDeclaration,
ContinueStatement,
DiscreteSet,
FloatLiteral,
FloatType,
ForInLoop,
FunctionCall,
GateModifierName,
Expand All @@ -81,6 +87,7 @@
IndexedIdentifier,
IndexExpression,
IntegerLiteral,
IntType,
IODeclaration,
IOKeyword,
Pragma,
Expand All @@ -101,11 +108,12 @@
SizeOf,
SubroutineDefinition,
SymbolLiteral,
UintType,
UnaryExpression,
WhileLoop,
)
from .parser.openqasm_parser import parse
from .program_context import AbstractProgramContext, ProgramContext
from .program_context import AbstractProgramContext, ProgramContext, _BreakSignal, _ContinueSignal


class Interpreter:
Expand All @@ -128,6 +136,8 @@ def __init__(
):
# context keeps track of all state
self.context = context or ProgramContext()
if self.context.supports_midcircuit_measurement:
self.context.set_visitor(self.visit)
self.logger = logger or getLogger(__name__)
self._uses_advanced_language_features = False
self._warn_advanced_features = warn_advanced_features
Expand Down Expand Up @@ -196,6 +206,12 @@ def _(self, node: ClassicalDeclaration) -> None:
init_value = create_empty_array(node_type.dimensions)
elif isinstance(node_type, BitType) and node_type.size:
init_value = create_empty_array([node_type.size])
elif isinstance(node_type, (IntType, UintType)):
init_value = IntegerLiteral(value=0)
elif isinstance(node_type, FloatType):
init_value = FloatLiteral(value=0.0)
elif isinstance(node_type, BoolType):
init_value = BooleanLiteral(value=False)
else:
init_value = None
self.context.declare_variable(node.identifier.name, node_type, init_value)
Expand Down Expand Up @@ -265,7 +281,7 @@ def _(self, node: QubitDeclaration) -> None:

@visit.register
def _(self, node: QuantumReset) -> None:
raise NotImplementedError("Reset not supported")
self.context.add_reset(list(self.context.get_qubits(self.visit(node.qubits))))

@visit.register
def _(self, node: QuantumBarrier) -> None:
Expand Down Expand Up @@ -511,7 +527,6 @@ def _(self, node: Box) -> None:

@visit.register
def _(self, node: QuantumMeasurementStatement) -> None:
"""The measure is performed but the assignment is ignored"""
qubits = self.visit(node.measure)
targets = []
if node.target and isinstance(node.target, IndexedIdentifier):
Expand All @@ -528,7 +543,8 @@ def _(self, node: QuantumMeasurementStatement) -> None:
self._uses_advanced_language_features = True
targets.extend(convert_range_def_to_range(self.visit(elem)))
case _:
targets.append(elem.value)
resolved = self.visit(elem) if isinstance(elem, Identifier) else elem
targets.append(resolved.value)

if not len(targets):
targets = None
Expand All @@ -537,10 +553,26 @@ def _(self, node: QuantumMeasurementStatement) -> None:
raise ValueError(
f"Number of qubits ({len(qubits)}) does not match number of provided classical targets ({len(targets)})"
)
self.context.add_measure(qubits, targets)
if node.target and self.context.supports_midcircuit_measurement:
self.context.add_measure(qubits, targets, classical_destination=node.target)
else:
self.context.add_measure(qubits, targets)

@visit.register
def _(self, node: ClassicalAssignment) -> None:
is_branched = getattr(self.context, "_is_branched", False)
if not is_branched or len(self.context._active_path_indices) <= 1:
self._execute_classical_assignment(node)
else:
# When multiple paths are active, evaluate the rvalue per-path
# so that expressions like ``y = x`` read from the correct path.
saved_active = list(self.context._active_path_indices)
for path_idx in saved_active:
self.context._active_path_indices = [path_idx]
self._execute_classical_assignment(deepcopy(node))
self.context._active_path_indices = saved_active

def _execute_classical_assignment(self, node: ClassicalAssignment) -> None:
lvalue_name = get_identifier_name(node.lvalue)
if self.context.get_const(lvalue_name):
raise TypeError(f"Cannot update const value {lvalue_name}")
Expand All @@ -565,29 +597,76 @@ def _(self, node: BitstringLiteral) -> ArrayLiteral:
@visit.register
def _(self, node: BranchingStatement) -> None:
self._uses_advanced_language_features = True
condition = cast_to(BooleanLiteral, self.visit(node.condition))
for statement in node.if_block if condition.value else node.else_block:
self.visit(statement)
if self.context.supports_midcircuit_measurement:
self.context.handle_branching_statement(node)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only handled under BranchingStatement would cause the following use case to fail:

OPENQASM 3.0;
qubit[3] __qubits__;
bit[1] mcm;
bit __bit_1__;
__bit_1__ = measure __qubits__[1];
mcm[0] = __bit_1__;

when mcm[0] = __bit_1__; is called, __bit_1__ is not yet initialized because it's not a branch statement. It throws NameError: Identifier '__bit_1__' is not initialized.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

else:
condition = cast_to(BooleanLiteral, self.visit(node.condition))
if condition.value:
self.visit(node.if_block)
elif node.else_block:
self.visit(node.else_block)

@visit.register
def _(self, node: ForInLoop) -> None:
self._uses_advanced_language_features = True
index = self.visit(node.set_declaration)
if isinstance(index, RangeDefinition):
index_values = [IntegerLiteral(x) for x in convert_range_def_to_range(index)]
# DiscreteSet
if self.context.supports_midcircuit_measurement:
self.context.handle_for_loop(node)
else:
index_values = index.values
for i in index_values:
with self.context.enter_scope():
self.context.declare_variable(node.identifier.name, node.type, i)
self.visit(deepcopy(node.block))
index = self.visit(node.set_declaration)
if isinstance(index, RangeDefinition):
index_values = [IntegerLiteral(x) for x in convert_range_def_to_range(index)]
else:
index_values = index.values

loop_var_name = node.identifier.name
for i in index_values:
with self.context.enter_scope():
self.context.declare_variable(loop_var_name, node.type, i)
try:
self.visit(deepcopy(node.block))
except _BreakSignal:
break
except _ContinueSignal:
continue

@visit.register
def _(self, node: WhileLoop) -> None:
self._uses_advanced_language_features = True
while cast_to(BooleanLiteral, self.visit(deepcopy(node.while_condition))).value:
self.visit(deepcopy(node.block))
if self.context.supports_midcircuit_measurement:
self.context.handle_while_loop(node)
else:
while cast_to(BooleanLiteral, self.visit(node.while_condition)).value:
try:
self.visit(deepcopy(node.block))
except _BreakSignal:
break
except _ContinueSignal:
continue

@visit.register
def _(self, node: BreakStatement) -> None:
raise _BreakSignal()

@visit.register
def _(self, node: ContinueStatement) -> None:
raise _ContinueSignal()

@visit.register
def _(self, node: AliasStatement) -> None:
"""Handle alias statements (let q1 = q, let combined = q1 ++ q2)."""
alias_name = node.target.name
if isinstance(node.value, Identifier):
# Simple alias: let q1 = q
self.context.qubit_mapping[alias_name] = self.context.get_qubits(node.value)
self.context.declare_qubit_alias(alias_name, node.value)
elif isinstance(node.value, Concatenation):
# Concatenation alias: let combined = q1 ++ q2
lhs_qubits = tuple(self.context.get_qubits(node.value.lhs))
rhs_qubits = tuple(self.context.get_qubits(node.value.rhs))
self.context.qubit_mapping[alias_name] = lhs_qubits + rhs_qubits
self.context.declare_qubit_alias(alias_name, Identifier(alias_name))
else:
raise NotImplementedError(f"Alias with {type(node.value).__name__} is not supported")

@visit.register
def _(self, node: Include) -> None:
Expand Down
Loading
Loading