From 5ce65b787c294be0459e98698fdf572a7523e664 Mon Sep 17 00:00:00 2001 From: EuGig Date: Tue, 18 Jun 2024 18:52:55 +0200 Subject: [PATCH] fix: cast node.size to IntegerLiteral for qubit register size (#258) Fixes bug [#240] by casting node.size into IntegerLiteral --- .../openqasm/_helpers/functions.py | 10 ++- .../default_simulator/openqasm/interpreter.py | 9 ++- .../openqasm/test_interpreter.py | 74 +++++++++++++++++++ 3 files changed, 89 insertions(+), 4 deletions(-) diff --git a/src/braket/default_simulator/openqasm/_helpers/functions.py b/src/braket/default_simulator/openqasm/_helpers/functions.py index 0fe9f20c..1d5ee1cb 100644 --- a/src/braket/default_simulator/openqasm/_helpers/functions.py +++ b/src/braket/default_simulator/openqasm/_helpers/functions.py @@ -98,10 +98,16 @@ [BooleanLiteral(xv.value ^ yv.value) for xv, yv in zip(x.values, y.values)] ), getattr(BinaryOperator, "<<"): lambda x, y: ArrayLiteral( - x.values[y.value :] + [BooleanLiteral(False) for _ in range(y.value)] + x.values[len(y.values) :] + [BooleanLiteral(False) for _ in range(len(y.values))] + if isinstance(y, ArrayLiteral) + else x.values[y.value :] + [BooleanLiteral(False) for _ in range(y.value)] ), getattr(BinaryOperator, ">>"): lambda x, y: ArrayLiteral( - [BooleanLiteral(False) for _ in range(y.value)] + x.values[: len(x.values) - y.value] + [BooleanLiteral(False) for _ in range(len(y.values))] + + x.values[: len(x.values) - len(y.values)] + if isinstance(y, ArrayLiteral) + else [BooleanLiteral(False) for _ in range(y.value)] + + x.values[: len(x.values) - y.value] ), getattr(UnaryOperator, "~"): lambda x: ArrayLiteral( [BooleanLiteral(not v.value) for v in x.values] diff --git a/src/braket/default_simulator/openqasm/interpreter.py b/src/braket/default_simulator/openqasm/interpreter.py index 4989c595..0bb2eeae 100644 --- a/src/braket/default_simulator/openqasm/interpreter.py +++ b/src/braket/default_simulator/openqasm/interpreter.py @@ -247,8 +247,13 @@ def _(self, node: Identifier) -> LiteralType: @visit.register def _(self, node: QubitDeclaration) -> None: - size = self.visit(node.size).value if node.size else 1 - self.context.add_qubits(node.qubit.name, size) + size_arg = self.visit(node.size) + if isinstance(size_arg, ArrayLiteral) and size_arg: + size = "".join(str(cast_to(IntegerLiteral, qubit).value) for qubit in size_arg.values) + self.context.add_qubits(node.qubit.name, int(size, 2)) + else: + size = size_arg.value if size_arg else 1 + self.context.add_qubits(node.qubit.name, size) @visit.register def _(self, node: QuantumReset) -> None: diff --git a/test/unit_tests/braket/default_simulator/openqasm/test_interpreter.py b/test/unit_tests/braket/default_simulator/openqasm/test_interpreter.py index 9f468e15..0d5b79b7 100644 --- a/test/unit_tests/braket/default_simulator/openqasm/test_interpreter.py +++ b/test/unit_tests/braket/default_simulator/openqasm/test_interpreter.py @@ -2207,3 +2207,77 @@ def test_measure_invalid_qubit(): def test_measure_qubit_out_of_range(qasm, expected): with pytest.raises(IndexError, match=expected): Interpreter().build_circuit(qasm) + + +@pytest.mark.parametrize( + "qasm, expected", + [ + ( + """ + bit[2] b; + qubit["10"] r1; + b = measure r1; + """, + [0, 1], + ), + ( + """ + bit[3] b; + qubit["11"] r1; + b = measure r1; + """, + [0, 1, 2], + ), + ( + """ + bit[1] b; + qubit[!"1"] r1; + b = measure r1; + """, + [], + ), + ( + """ + qubit["1" ^ "0"] r1; + """, + [], + ), + ( + """ + bit[1] b; + qubit["1" != "0"] r1; + b = measure r1; + """, + [0], + ), + ( + """ + bit[1] b; + qubit["1" == "0"] r1; + b = measure r1; + """, + [], + ), + ( + """ + bit[1] b; + qubit[1] r1; + h r1["0" << "1"]; + b = measure r1; + """, + [0], + ), + ( + """ + bit[2] b; + qubit[1] r1; + h r1["0" >> "1"]; + b = measure r1; + """, + [0], + ), + ], +) +def test_circuit_from_string_literal(qasm, expected): + circ = Interpreter().build_circuit(source=qasm) + assert expected == circ.measured_qubits