diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index a225f4fba..59d80c916 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1442,7 +1442,7 @@ def _generate_frame_wf_defcal_declarations( ) -> str | None: """Generates the header where frames, waveforms and defcals are declared. - It also adds any FreeParameter of the calibrations to the circuit parameter set. + It also adds any FreeParameter that is not a gate argument to the circuit parameter set. Args: gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): The @@ -1467,21 +1467,12 @@ def _generate_frame_wf_defcal_declarations( for key, calibration in gate_definitions.items(): gate, qubits = key - - # Ignoring parametric gates - # Corresponding defcals with fixed arguments have been added - # in _get_frames_waveforms_from_instrs - if isinstance(gate, Parameterizable) and any( - not isinstance(parameter, float | int | complex) - for parameter in gate.parameters - ): - continue - gate_name = gate._qasm_name arguments = gate.parameters if isinstance(gate, Parameterizable) else [] for param in calibration.parameters: - self._parameters.add(param) + if param not in arguments: + self._parameters.add(param) arguments = [ param._to_oqpy_expression() if isinstance(param, FreeParameter) else param for param in arguments @@ -1512,78 +1503,8 @@ def _get_frames_waveforms_from_instrs( for waveform in instruction.operator.pulse_sequence._waveforms.values(): _validate_uniqueness(waveforms, waveform) waveforms[waveform.id] = waveform - # this will change with full parametric calibration support - elif isinstance(instruction.operator, Parameterizable): - fixed_argument_calibrations = self._add_fixed_argument_calibrations( - gate_definitions, instruction - ) - gate_definitions |= fixed_argument_calibrations return frames, waveforms - def _add_fixed_argument_calibrations( - self, - gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], - instruction: Instruction, - ) -> dict[tuple[Gate, QubitSet], PulseSequence]: - """Adds calibrations with arguments set to the instruction parameter values - - Given the collection of parameters in instruction.operator, this function looks for matching - parametric calibrations that have free parameters. If such a calibration is found and the - number N of its free parameters equals the number of instruction parameters, we can bind - the arguments of the calibration and add it to the calibration dictionary. - - If N is smaller, it is probably impossible to assign the instruction parameter values to the - corresponding calibration parameters so we raise an error. - If N=0, we ignore it as it will not be removed by _generate_frame_wf_defcal_declarations. - - Args: - gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence]): a dictionary of - calibrations - instruction (Instruction): a Circuit instruction - - Returns: - dict[tuple[Gate, QubitSet], PulseSequence]: additional calibrations - - Raises: - NotImplementedError: in two cases: (i) if the instruction contains unbound parameters - and the calibration dictionary contains a parametric calibration applicable to this - instructions; (ii) if the calibration is defined with a partial number of unbound - parameters. - """ - additional_calibrations = {} - for key, calibration in gate_definitions.items(): - gate = key[0] - target = key[1] - if target != instruction.target: - continue - if isinstance(gate, type(instruction.operator)) and len( - instruction.operator.parameters - ) == len(gate.parameters): - free_parameter_number = sum( - isinstance(p, FreeParameterExpression) for p in gate.parameters - ) - if free_parameter_number == 0: - continue - if free_parameter_number < len(gate.parameters): - raise NotImplementedError( - "Calibrations with a partial number of fixed parameters are not supported." - ) - if any( - isinstance(p, FreeParameterExpression) for p in instruction.operator.parameters - ): - raise NotImplementedError( - "Parametric calibrations cannot be attached with parametric circuits." - ) - bound_key = ( - type(instruction.operator)(*instruction.operator.parameters), - instruction.target, - ) - additional_calibrations[bound_key] = calibration(**{ - p.name if isinstance(p, FreeParameterExpression) else p: v - for p, v in zip(gate.parameters, instruction.operator.parameters, strict=True) - }) - return additional_calibrations - def to_unitary(self) -> np.ndarray: """Returns the unitary matrix representation of the entire circuit. diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 6480ab8e9..afff095cb 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -149,25 +149,6 @@ def pulse_sequence_2(predefined_frame_1): ) -@pytest.fixture -def pulse_sequence_3(predefined_frame_1): - return ( - PulseSequence() - .shift_phase( - predefined_frame_1, - FreeParameter("alpha"), - ) - .shift_phase( - predefined_frame_1, - FreeParameter("beta"), - ) - .play( - predefined_frame_1, - DragGaussianWaveform(length=3e-3, sigma=0.4, beta=0.2, id="drag_gauss_wf"), - ) - ) - - @pytest.fixture def gate_calibrations(pulse_sequence, pulse_sequence_2): calibration_key = (gates.Z(), QubitSet([0, 1])) @@ -1215,15 +1196,21 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) q[0];", "rx(0.3) q[1];", "b[0] = measure q[0];", "b[1] = measure q[1];", - ]), + ]), inputs={}, ), ), @@ -1242,10 +1229,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) $0;", "rx(0.3) $4;", "b[0] = measure $0;", @@ -1271,10 +1264,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) $0;", "#pragma braket verbatim", "box{", @@ -1295,6 +1294,7 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): OpenQasmProgram( source="\n".join([ "OPENQASM 3.0;", + "qubit[5] q;", "cal {", " waveform drag_gauss_wf = drag_gaussian" + "(3.0ms, 400.0ms, 0.2, 1, false);", @@ -1303,14 +1303,20 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "rx(0.15) $0;", - "rx(0.3) $4;", - "#pragma braket noise bit_flip(0.2) $3", - "#pragma braket result expectation i($0)", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "rx(0.15) q[0];", + "rx(0.3) q[4];", + "#pragma braket noise bit_flip(0.2) q[3]", + "#pragma braket result expectation i(q[0])", ]), inputs={}, ), @@ -1332,10 +1338,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "rx(0.15) q[0];", "rx(theta) q[1];", "b[0] = measure q[0];", @@ -1363,10 +1375,16 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal rx(0.15) $0 {", + "defcal rx(float theta) $0 {", " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", "negctrl @ rx(0.15) q[2], q[0];", "ctrl(2) @ rx(0.3) q[2], q[3], q[1];", "ctrl(2) @ cnot q[2], q[3], q[4], q[0];", @@ -1384,27 +1402,37 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL), OpenQasmProgram( source="\n".join([ - "OPENQASM 3.0;", - "bit[7] b;", - "qubit[7] q;", - "cal {", - " waveform drag_gauss_wf = drag_gaussian" - + "(3.0ms, 400.0ms, 0.2, 1, false);", - "}", - "defcal z $0, $1 {", - " set_frequency(predefined_frame_1, 6000000.0);", - " play(predefined_frame_1, drag_gauss_wf);", - "}", - "cnot q[0], q[1];", - "cnot q[3], q[2];", - "ctrl @ cnot q[5], q[6], q[4];", - "b[0] = measure q[0];", - "b[1] = measure q[1];", - "b[2] = measure q[2];", - "b[3] = measure q[3];", - "b[4] = measure q[4];", - "b[5] = measure q[5];", - "b[6] = measure q[6];", + "OPENQASM 3.0;", + "bit[7] b;", + "qubit[7] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian" + + "(3.0ms, 400.0ms, 0.2, 1, false);", + "}", + "defcal z $0, $1 {", + " set_frequency(predefined_frame_1, 6000000.0);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "defcal rx(float theta) $0 {", + " set_frequency(predefined_frame_1, 6000000.0);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "cnot q[0], q[1];", + "cnot q[3], q[2];", + "ctrl @ cnot q[5], q[6], q[4];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + "b[2] = measure q[2];", + "b[3] = measure q[3];", + "b[4] = measure q[4];", + "b[5] = measure q[5];", + "b[6] = measure q[6];", ]), inputs={}, ), @@ -1425,10 +1453,14 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): " set_frequency(predefined_frame_1, 6000000.0);", " play(predefined_frame_1, drag_gauss_wf);", "}", - "defcal ms(-0.1, -0.2, -0.3) $0, $1 {", - " shift_phase(predefined_frame_1, -0.1);", - " set_phase(predefined_frame_1, -0.3);", - " shift_phase(predefined_frame_1, -0.2);", + "defcal rx(float theta) $0 {", + " set_frequency(predefined_frame_1, 6000000.0);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "defcal ms(float alpha, float beta, float gamma) $0, $1 {", + " shift_phase(predefined_frame_1, alpha);", + " set_phase(predefined_frame_1, gamma);", + " shift_phase(predefined_frame_1, beta);", " play(predefined_frame_1, drag_gauss_wf);", "}", "inv @ pow(2.5) @ h q[0];", @@ -1469,7 +1501,7 @@ def test_circuit_to_ir_openqasm_with_gate_calibrations( @pytest.mark.parametrize( - "circuit, calibration_key, expected_ir", + "circuit, calibration_key, input_variables, expected_ir, input_values", [ ( Circuit().rx(0, 0.2), @@ -1496,20 +1528,29 @@ def test_circuit_to_ir_openqasm_with_gate_calibrations( ), ], ) -def test_circuit_with_parametric_defcal(circuit, calibration_key, expected_ir, pulse_sequence_3): +def test_parametric_circuit_with_parametric_defcal( + circuit, calibration_key, input_variables, expected_ir, input_values, pulse_sequence_2 +): serialization_properties = OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL) gate_calibrations = GateCalibrations({ - calibration_key: pulse_sequence_3, + calibration_key: pulse_sequence_2, }) - assert ( - circuit.to_ir( - ir_type=IRType.OPENQASM, - serialization_properties=serialization_properties, - gate_definitions=gate_calibrations.pulse_sequences, - ) - == expected_ir + assert circuit.to_ir( + ir_type=IRType.OPENQASM, + serialization_properties=serialization_properties, + gate_definitions=gate_calibrations.pulse_sequences, + ) == OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + *[f"input float {parameter};" for parameter in circuit.parameters], + *expected_ir, + ] + ), + inputs=input_values, ) + assert circuit.parameters == {FreeParameter(name) for name in input_variables} def test_parametric_circuit_with_fixed_argument_defcal(pulse_sequence): @@ -1625,10 +1666,10 @@ def foo( "cal {", " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", "}", - "defcal foo(-0.2) $0 {", + "defcal foo(float beta) $0 {", " shift_phase(predefined_frame_1, -0.1);", " set_phase(predefined_frame_1, -0.3);", - " shift_phase(predefined_frame_1, -0.2);", + " shift_phase(predefined_frame_1, beta);", " play(predefined_frame_1, drag_gauss_wf);", "}", "foo(-0.2) q[0];",