-
Notifications
You must be signed in to change notification settings - Fork 31
feat: Classical control flow with mid circuit measurements #347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
259b318
253a450
5d38025
9729a19
90cc443
79ccfe9
0df15af
eb03649
b4c2ac8
4bf80a1
6f1ec22
18e1dc2
e3c8982
115612f
4456be9
0ae913b
61f5a46
4a859cb
96b9b2e
cabffb7
64ab4e9
6b6d207
ca29dc2
1142cb2
5d05c01
1ac7b48
325c353
96fa8ab
4549801
98a7f4a
29f5899
a3abdc4
87319ce
77283e2
8bca38f
f927119
a68744d
08f7329
8f99f71
77fe95b
2bf6c35
26babd4
7ae10d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,8 @@ | |
| *.swp | ||
| *.idea | ||
| *.iml | ||
| .vscode/ | ||
| .kiro/ | ||
| build_files.tar.gz | ||
|
|
||
| .ycm_extra_conf.py | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,7 @@ | |
| from .circuit import Circuit | ||
| from .parser.openqasm_ast import ( | ||
| AccessControl, | ||
| AliasStatement, | ||
| ArrayLiteral, | ||
| ArrayReferenceType, | ||
| ArrayType, | ||
|
|
@@ -64,15 +65,20 @@ | |
| BitstringLiteral, | ||
| BitType, | ||
| BooleanLiteral, | ||
| BoolType, | ||
| Box, | ||
| BranchingStatement, | ||
| BreakStatement, | ||
| Cast, | ||
| ClassicalArgument, | ||
| ClassicalAssignment, | ||
| ClassicalDeclaration, | ||
| Concatenation, | ||
| ConstantDeclaration, | ||
| ContinueStatement, | ||
| DiscreteSet, | ||
| FloatLiteral, | ||
| FloatType, | ||
| ForInLoop, | ||
| FunctionCall, | ||
| GateModifierName, | ||
|
|
@@ -81,6 +87,7 @@ | |
| IndexedIdentifier, | ||
| IndexExpression, | ||
| IntegerLiteral, | ||
| IntType, | ||
| IODeclaration, | ||
| IOKeyword, | ||
| Pragma, | ||
|
|
@@ -101,11 +108,12 @@ | |
| SizeOf, | ||
| SubroutineDefinition, | ||
| SymbolLiteral, | ||
| UintType, | ||
| UnaryExpression, | ||
| WhileLoop, | ||
| ) | ||
| from .parser.openqasm_parser import parse | ||
| from .program_context import AbstractProgramContext, ProgramContext | ||
| from .program_context import AbstractProgramContext, ProgramContext, _BreakSignal, _ContinueSignal | ||
|
|
||
|
|
||
| class Interpreter: | ||
|
|
@@ -128,6 +136,8 @@ def __init__( | |
| ): | ||
| # context keeps track of all state | ||
| self.context = context or ProgramContext() | ||
| if self.context.supports_midcircuit_measurement: | ||
| self.context.set_visitor(self.visit) | ||
| self.logger = logger or getLogger(__name__) | ||
| self._uses_advanced_language_features = False | ||
| self._warn_advanced_features = warn_advanced_features | ||
|
|
@@ -196,6 +206,12 @@ def _(self, node: ClassicalDeclaration) -> None: | |
| init_value = create_empty_array(node_type.dimensions) | ||
| elif isinstance(node_type, BitType) and node_type.size: | ||
| init_value = create_empty_array([node_type.size]) | ||
| elif isinstance(node_type, (IntType, UintType)): | ||
| init_value = IntegerLiteral(value=0) | ||
| elif isinstance(node_type, FloatType): | ||
| init_value = FloatLiteral(value=0.0) | ||
| elif isinstance(node_type, BoolType): | ||
| init_value = BooleanLiteral(value=False) | ||
| else: | ||
| init_value = None | ||
| self.context.declare_variable(node.identifier.name, node_type, init_value) | ||
|
|
@@ -265,7 +281,7 @@ def _(self, node: QubitDeclaration) -> None: | |
|
|
||
| @visit.register | ||
| def _(self, node: QuantumReset) -> None: | ||
| raise NotImplementedError("Reset not supported") | ||
| self.context.add_reset(list(self.context.get_qubits(self.visit(node.qubits)))) | ||
|
|
||
| @visit.register | ||
| def _(self, node: QuantumBarrier) -> None: | ||
|
|
@@ -511,7 +527,6 @@ def _(self, node: Box) -> None: | |
|
|
||
| @visit.register | ||
| def _(self, node: QuantumMeasurementStatement) -> None: | ||
| """The measure is performed but the assignment is ignored""" | ||
| qubits = self.visit(node.measure) | ||
| targets = [] | ||
| if node.target and isinstance(node.target, IndexedIdentifier): | ||
|
|
@@ -528,7 +543,8 @@ def _(self, node: QuantumMeasurementStatement) -> None: | |
| self._uses_advanced_language_features = True | ||
| targets.extend(convert_range_def_to_range(self.visit(elem))) | ||
| case _: | ||
| targets.append(elem.value) | ||
| resolved = self.visit(elem) if isinstance(elem, Identifier) else elem | ||
| targets.append(resolved.value) | ||
|
|
||
| if not len(targets): | ||
| targets = None | ||
|
|
@@ -537,10 +553,26 @@ def _(self, node: QuantumMeasurementStatement) -> None: | |
| raise ValueError( | ||
| f"Number of qubits ({len(qubits)}) does not match number of provided classical targets ({len(targets)})" | ||
| ) | ||
| self.context.add_measure(qubits, targets) | ||
| if node.target and self.context.supports_midcircuit_measurement: | ||
| self.context.add_measure(qubits, targets, classical_destination=node.target) | ||
| else: | ||
| self.context.add_measure(qubits, targets) | ||
|
|
||
| @visit.register | ||
| def _(self, node: ClassicalAssignment) -> None: | ||
| is_branched = getattr(self.context, "_is_branched", False) | ||
| if not is_branched or len(self.context._active_path_indices) <= 1: | ||
| self._execute_classical_assignment(node) | ||
| else: | ||
| # When multiple paths are active, evaluate the rvalue per-path | ||
| # so that expressions like ``y = x`` read from the correct path. | ||
| saved_active = list(self.context._active_path_indices) | ||
| for path_idx in saved_active: | ||
| self.context._active_path_indices = [path_idx] | ||
| self._execute_classical_assignment(deepcopy(node)) | ||
| self.context._active_path_indices = saved_active | ||
|
|
||
| def _execute_classical_assignment(self, node: ClassicalAssignment) -> None: | ||
| lvalue_name = get_identifier_name(node.lvalue) | ||
| if self.context.get_const(lvalue_name): | ||
| raise TypeError(f"Cannot update const value {lvalue_name}") | ||
|
|
@@ -565,29 +597,76 @@ def _(self, node: BitstringLiteral) -> ArrayLiteral: | |
| @visit.register | ||
| def _(self, node: BranchingStatement) -> None: | ||
| self._uses_advanced_language_features = True | ||
| condition = cast_to(BooleanLiteral, self.visit(node.condition)) | ||
| for statement in node.if_block if condition.value else node.else_block: | ||
| self.visit(statement) | ||
| if self.context.supports_midcircuit_measurement: | ||
| self.context.handle_branching_statement(node) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only handled under when
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed! |
||
| else: | ||
| condition = cast_to(BooleanLiteral, self.visit(node.condition)) | ||
| if condition.value: | ||
| self.visit(node.if_block) | ||
| elif node.else_block: | ||
| self.visit(node.else_block) | ||
|
|
||
| @visit.register | ||
| def _(self, node: ForInLoop) -> None: | ||
| self._uses_advanced_language_features = True | ||
| index = self.visit(node.set_declaration) | ||
| if isinstance(index, RangeDefinition): | ||
| index_values = [IntegerLiteral(x) for x in convert_range_def_to_range(index)] | ||
| # DiscreteSet | ||
| if self.context.supports_midcircuit_measurement: | ||
| self.context.handle_for_loop(node) | ||
| else: | ||
| index_values = index.values | ||
| for i in index_values: | ||
| with self.context.enter_scope(): | ||
| self.context.declare_variable(node.identifier.name, node.type, i) | ||
| self.visit(deepcopy(node.block)) | ||
| index = self.visit(node.set_declaration) | ||
| if isinstance(index, RangeDefinition): | ||
| index_values = [IntegerLiteral(x) for x in convert_range_def_to_range(index)] | ||
| else: | ||
| index_values = index.values | ||
|
|
||
| loop_var_name = node.identifier.name | ||
| for i in index_values: | ||
| with self.context.enter_scope(): | ||
| self.context.declare_variable(loop_var_name, node.type, i) | ||
| try: | ||
| self.visit(deepcopy(node.block)) | ||
| except _BreakSignal: | ||
| break | ||
| except _ContinueSignal: | ||
| continue | ||
|
|
||
| @visit.register | ||
| def _(self, node: WhileLoop) -> None: | ||
| self._uses_advanced_language_features = True | ||
| while cast_to(BooleanLiteral, self.visit(deepcopy(node.while_condition))).value: | ||
| self.visit(deepcopy(node.block)) | ||
| if self.context.supports_midcircuit_measurement: | ||
| self.context.handle_while_loop(node) | ||
| else: | ||
| while cast_to(BooleanLiteral, self.visit(node.while_condition)).value: | ||
| try: | ||
| self.visit(deepcopy(node.block)) | ||
| except _BreakSignal: | ||
| break | ||
| except _ContinueSignal: | ||
| continue | ||
|
|
||
| @visit.register | ||
| def _(self, node: BreakStatement) -> None: | ||
| raise _BreakSignal() | ||
|
|
||
| @visit.register | ||
| def _(self, node: ContinueStatement) -> None: | ||
| raise _ContinueSignal() | ||
|
|
||
| @visit.register | ||
| def _(self, node: AliasStatement) -> None: | ||
| """Handle alias statements (let q1 = q, let combined = q1 ++ q2).""" | ||
speller26 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| alias_name = node.target.name | ||
| if isinstance(node.value, Identifier): | ||
| # Simple alias: let q1 = q | ||
| self.context.qubit_mapping[alias_name] = self.context.get_qubits(node.value) | ||
| self.context.declare_qubit_alias(alias_name, node.value) | ||
| elif isinstance(node.value, Concatenation): | ||
| # Concatenation alias: let combined = q1 ++ q2 | ||
| lhs_qubits = tuple(self.context.get_qubits(node.value.lhs)) | ||
| rhs_qubits = tuple(self.context.get_qubits(node.value.rhs)) | ||
| self.context.qubit_mapping[alias_name] = lhs_qubits + rhs_qubits | ||
| self.context.declare_qubit_alias(alias_name, Identifier(alias_name)) | ||
| else: | ||
| raise NotImplementedError(f"Alias with {type(node.value).__name__} is not supported") | ||
|
|
||
| @visit.register | ||
| def _(self, node: Include) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you want to copy
stateinto a separate array first, rather than mutating the input array in place? (for consistency with how theMeasureoperation works)