From b43e30337de57f0e32504a2ea973b5949e4b14b9 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 22 Feb 2022 10:13:45 +0100 Subject: [PATCH 01/35] Merge `append_child` with `add_measurement` where possible --- qupulse/pulses/constant_pulse_template.py | 4 +--- qupulse/pulses/pulse_template.py | 6 ++---- qupulse/pulses/repetition_pulse_template.py | 4 +--- tests/pulses/pulse_template_tests.py | 4 +--- tests/pulses/sequencing_dummies.py | 7 +++---- 5 files changed, 8 insertions(+), 17 deletions(-) diff --git a/qupulse/pulses/constant_pulse_template.py b/qupulse/pulses/constant_pulse_template.py index 63cf6deab..57b233b7f 100644 --- a/qupulse/pulses/constant_pulse_template.py +++ b/qupulse/pulses/constant_pulse_template.py @@ -104,15 +104,13 @@ def _internal_create_program(self, *, waveform = self.build_waveform(parameters=parameters, channel_mapping=channel_mapping) if waveform: - measurements: List[Any] = [] measurements = self.get_measurement_windows(parameters=parameters, measurement_mapping=measurement_mapping) if global_transformation: waveform = TransformingWaveform(waveform, global_transformation) - parent_loop.add_measurements(measurements=measurements) - parent_loop.append_child(waveform=waveform) + parent_loop.append_child(waveform=waveform, measurements=measurements) def build_waveform(self, parameters: Dict[str, numbers.Real], diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index ecd772199..0814daf90 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -223,8 +223,7 @@ def _create_program(self, *, for measurement_name, (begins, lengths) in measurements.items(): measurement_window_list.extend(zip(itertools.repeat(measurement_name), begins, lengths)) - parent_loop.add_measurements(measurement_window_list) - parent_loop.append_child(waveform=waveform) + parent_loop.append_child(waveform=waveform, measurements=measurement_window_list) else: self._internal_create_program(scope=scope, @@ -328,8 +327,7 @@ def _internal_create_program(self, *, if global_transformation: waveform = TransformingWaveform(waveform, global_transformation) - parent_loop.add_measurements(measurements=measurements) - parent_loop.append_child(waveform=waveform) + parent_loop.append_child(waveform=waveform, measurements=measurements) @abstractmethod def build_waveform(self, diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index 81e240772..309366082 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -128,10 +128,8 @@ def _internal_create_program(self, *, parent_loop=repj_loop) if repj_loop.waveform is not None or len(repj_loop.children) > 0: measurements = self.get_measurement_windows(scope, measurement_mapping) - if measurements: - parent_loop.add_measurements(measurements) - parent_loop.append_child(loop=repj_loop) + parent_loop.append_child(loop=repj_loop, measurements=measurements) def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer) diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index e322d8cee..d2643a56d 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -87,9 +87,7 @@ def get_appending_internal_create_program(waveform=DummyWaveform(), measurements: list=None): def internal_create_program(*, scope, parent_loop: Loop, **_): if always_append or 'append_a_child' in scope: - if measurements is not None: - parent_loop.add_measurements(measurements=measurements) - parent_loop.append_child(waveform=waveform) + parent_loop.append_child(waveform=waveform, measurements=measurements) return internal_create_program diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 45ccfdc40..0036cc25a 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -235,11 +235,10 @@ def _internal_create_program(self, *, measurements = self.get_measurement_windows(scope, measurement_mapping) self.create_program_calls.append((scope, measurement_mapping, channel_mapping, parent_loop)) if self._program: - parent_loop.add_measurements(measurements) - parent_loop.append_child(waveform=self._program.waveform, children=self._program.children) + parent_loop.append_child(waveform=self._program.waveform, children=self._program.children, + measurements=measurements) elif self.waveform: - parent_loop.add_measurements(measurements) - parent_loop.append_child(waveform=self.waveform) + parent_loop.append_child(waveform=self.waveform, measurements=measurements) def build_waveform(self, parameters: Dict[str, Parameter], From 243d1f7427e8e86f750d306f6ba097d01bb45cc4 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 22 Feb 2022 10:28:10 +0100 Subject: [PATCH 02/35] Insert extra loop for SequencePT and ForLoopPT --- qupulse/_program/_loop.py | 3 +++ qupulse/pulses/loop_pulse_template.py | 27 ++++++++------------- qupulse/pulses/repetition_pulse_template.py | 7 +++--- qupulse/pulses/sequence_pulse_template.py | 22 ++++++++--------- 4 files changed, 26 insertions(+), 33 deletions(-) diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index 8c9f7219d..b52aeb8e8 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -101,6 +101,9 @@ def add_measurements(self, measurements: Iterable[MeasurementWindow]): Args: measurements: Measurements to add """ + warnings.warn("Loop.add_measurements is deprecated since qupulse 0.7 and will be removed in a future version.", + DeprecationWarning, + stacklevel=2) body_duration = float(self.body_duration) if body_duration == 0: measurements = measurements diff --git a/qupulse/pulses/loop_pulse_template.py b/qupulse/pulses/loop_pulse_template.py index 82236c71d..592b1edc9 100644 --- a/qupulse/pulses/loop_pulse_template.py +++ b/qupulse/pulses/loop_pulse_template.py @@ -217,23 +217,16 @@ def _internal_create_program(self, *, parent_loop: Loop) -> None: self.validate_scope(scope=scope) - try: - duration = self.duration.evaluate_in_scope(scope) - except ExpressionVariableMissingException as err: - raise ParameterNotProvidedException(err.variable) from err - - if duration > 0: - measurements = self.get_measurement_windows(scope, measurement_mapping) - if measurements: - parent_loop.add_measurements(measurements) - - for local_scope in self._body_scope_generator(scope, forward=True): - self.body._create_program(scope=local_scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=parent_loop) + self_loop = Loop(measurements=self.get_measurement_windows(scope, measurement_mapping) or None) + for local_scope in self._body_scope_generator(scope, forward=True): + self.body._create_program(scope=local_scope, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + global_transformation=global_transformation, + to_single_waveform=to_single_waveform, + parent_loop=self_loop) + if self_loop.duration > 0: + parent_loop.append_child(self_loop) def build_waveform(self, parameter_scope: Scope) -> ForLoopWaveform: return ForLoopWaveform([self.body.build_waveform(local_scope) diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index 309366082..a2ba542f5 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -119,7 +119,8 @@ def _internal_create_program(self, *, else: repetition_definition = repetition_count - repj_loop = Loop(repetition_count=repetition_definition) + repj_loop = Loop(repetition_count=repetition_definition, + measurements=self.get_measurement_windows(scope, measurement_mapping) or None) self.body._create_program(scope=scope, measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, @@ -127,9 +128,7 @@ def _internal_create_program(self, *, to_single_waveform=to_single_waveform, parent_loop=repj_loop) if repj_loop.waveform is not None or len(repj_loop.children) > 0: - measurements = self.get_measurement_windows(scope, measurement_mapping) - - parent_loop.append_child(loop=repj_loop, measurements=measurements) + parent_loop.append_child(loop=repj_loop) def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer) diff --git a/qupulse/pulses/sequence_pulse_template.py b/qupulse/pulses/sequence_pulse_template.py index a42d3ddb2..b8157b0dd 100644 --- a/qupulse/pulses/sequence_pulse_template.py +++ b/qupulse/pulses/sequence_pulse_template.py @@ -136,18 +136,16 @@ def _internal_create_program(self, *, parent_loop: Loop) -> None: self.validate_scope(scope) - if self.duration.evaluate_in_scope(scope) > 0: - measurements = self.get_measurement_windows(scope, measurement_mapping) - if measurements: - parent_loop.add_measurements(measurements) - - for subtemplate in self.subtemplates: - subtemplate._create_program(scope=scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=parent_loop) + self_loop = Loop(measurements=self.get_measurement_windows(scope, measurement_mapping) or None) + for subtemplate in self.subtemplates: + subtemplate._create_program(scope=scope, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + global_transformation=global_transformation, + to_single_waveform=to_single_waveform, + parent_loop=parent_loop) + if self_loop.duration > 0: + parent_loop.append_child(self_loop) def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer) From d688af9049e83af1198945bffeefec48061e60e4 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 22 Feb 2022 12:05:18 +0100 Subject: [PATCH 03/35] Fix error in sequencept create_program --- qupulse/pulses/sequence_pulse_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qupulse/pulses/sequence_pulse_template.py b/qupulse/pulses/sequence_pulse_template.py index b8157b0dd..0786d395f 100644 --- a/qupulse/pulses/sequence_pulse_template.py +++ b/qupulse/pulses/sequence_pulse_template.py @@ -143,7 +143,7 @@ def _internal_create_program(self, *, channel_mapping=channel_mapping, global_transformation=global_transformation, to_single_waveform=to_single_waveform, - parent_loop=parent_loop) + parent_loop=self_loop) if self_loop.duration > 0: parent_loop.append_child(self_loop) From f241d564333c0bc2374da21f5d3600bddbfaae38 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 22 Feb 2022 12:05:50 +0100 Subject: [PATCH 04/35] Fix first test --- tests/pulses/sequence_pulse_template_tests.py | 14 +++++++++----- tests/pulses/sequencing_dummies.py | 6 ++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/pulses/sequence_pulse_template_tests.py b/tests/pulses/sequence_pulse_template_tests.py index ebc81295a..4790a02a4 100644 --- a/tests/pulses/sequence_pulse_template_tests.py +++ b/tests/pulses/sequence_pulse_template_tests.py @@ -270,12 +270,14 @@ def test_create_program_internal(self) -> None: parent_loop=loop) self.assertEqual(1, loop.repetition_count) self.assertIsNone(loop.waveform) - self.assertEqual([Loop(repetition_count=1, waveform=sub1.waveform), + inner_loop, = loop.children + + self.assertEqual([Loop(repetition_count=1, waveform=sub1.waveform, measurements=[('b', 1, 2)]), Loop(repetition_count=1, waveform=sub2.waveform)], - list(loop.children)) + list(inner_loop.children)) self.assert_measurement_windows_equal({'a': ([0], [1]), 'b': ([1], [2])}, loop.get_measurement_windows()) - ### test again with inverted sequence + # test again with inverted sequence seq = SequencePulseTemplate(sub2, sub1, measurements=[('a', 0, 1)]) loop = Loop() seq._internal_create_program(scope=scope, @@ -286,9 +288,11 @@ def test_create_program_internal(self) -> None: parent_loop=loop) self.assertEqual(1, loop.repetition_count) self.assertIsNone(loop.waveform) + inner_loop, = loop.children + self.assertEqual([Loop(repetition_count=1, waveform=sub2.waveform), - Loop(repetition_count=1, waveform=sub1.waveform)], - list(loop.children)) + Loop(repetition_count=1, waveform=sub1.waveform, measurements=[('b', 1, 2)])], + list(inner_loop.children)) self.assert_measurement_windows_equal({'a': ([0], [1]), 'b': ([3], [2])}, loop.get_measurement_windows()) def test_internal_create_program_no_measurement_mapping(self) -> None: diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 0036cc25a..66ea4e907 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -99,6 +99,12 @@ def compare_key(self) -> Any: else: return id(self) + def __repr__(self): + if self.sample_output is not None: + return f"{type(self).__name__}(sample_output={self.sample_output})" + else: + return f"{type(self).__name__}(id={id(self)})" + @property def measurement_windows(self): return [] From 04f7d8ddb8e94de4a6dd7443f6ea1533aa5f21f5 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 24 Feb 2022 21:09:13 +0100 Subject: [PATCH 05/35] Use qupulse_rs WIP --- qupulse/_program/waveforms.py | 15 +++++++++++++++ qupulse/utils/types.py | 21 ++++++++++++++++----- tests/hardware/tektronix_tests.py | 2 +- tests/utils/time_type_tests.py | 2 ++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 473534108..61788ae4b 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -26,6 +26,12 @@ from qupulse._program.transformation import Transformation from qupulse.utils import pairwise +try: + import qupulse_rs.qupulse_rs + rs_replacements = qupulse_rs.qupulse_rs.replacements +except (ImportError, AttributeError): + rs_replacements = None + __all__ = ["Waveform", "TableWaveform", "TableWaveformEntry", "FunctionWaveform", "SequenceWaveform", "MultiChannelWaveform", "RepetitionWaveform", "TransformingWaveform", "ArithmeticWaveform"] @@ -1170,3 +1176,12 @@ def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: @property def compare_key(self) -> Tuple[Waveform, FrozenSet]: return self._inner_waveform, frozenset(self._functor.items()) + + +TableWaveform = rs_replacements.waveforms.TableWaveform +ConstantWaveform = rs_replacements.waveforms.ConstantWaveform +MultiChannelWaveform = rs_replacements.waveforms.MultiChannelWaveform +Waveform.register(TableWaveform) +Waveform.register(ConstantWaveform) +Waveform.register(MultiChannelWaveform) + diff --git a/qupulse/utils/types.py b/qupulse/utils/types.py index ca979935d..ea28fdead 100644 --- a/qupulse/utils/types.py +++ b/qupulse/utils/types.py @@ -18,6 +18,11 @@ "will be removed in a future release.", category=DeprecationWarning) frozendict = None +try: + from qupulse_rs.qupulse_rs import TimeType as RsTimeType +except ImportError: + RsTimeType = None + import qupulse.utils.numeric as qupulse_numeric __all__ = ["MeasurementWindow", "ChannelID", "HashableNumpyArray", "TimeType", "time_from_float", "DocStringABCMeta", @@ -310,16 +315,22 @@ def __float__(self): return int(self._value.numerator) / int(self._value.denominator) +PyTimeType = TimeType + +if RsTimeType: + TimeType = RsTimeType + + # this asserts isinstance(TimeType, Rational) is True numbers.Rational.register(TimeType) _converter = { - float: TimeType.from_float, - TimeType._InternalType: TimeType, - fractions.Fraction: TimeType, - sympy.Rational: lambda q: TimeType.from_fraction(q.p, q.q), - TimeType: lambda x: x + float: PyTimeType.from_float, + PyTimeType._InternalType: PyTimeType, + fractions.Fraction: PyTimeType, + sympy.Rational: lambda q: PyTimeType.from_fraction(q.p, q.q), + PyTimeType: lambda x: x } diff --git a/tests/hardware/tektronix_tests.py b/tests/hardware/tektronix_tests.py index a1f326b65..d25b1fc22 100644 --- a/tests/hardware/tektronix_tests.py +++ b/tests/hardware/tektronix_tests.py @@ -110,7 +110,7 @@ def test_parse_program(self): ill_formed_program = Loop(children=[Loop(children=[Loop()])]) with self.assertRaisesRegex(AssertionError, 'Invalid program depth'): - parse_program(ill_formed_program, (), (), TimeType(), (), (), ()) + parse_program(ill_formed_program, (), (), TimeType(0), (), (), ()) channels = ('A', 'B', None, None) markers = (('A1', None), (None, None), (None, 'C2'), (None, None)) diff --git a/tests/utils/time_type_tests.py b/tests/utils/time_type_tests.py index 93e118325..08d062a5b 100644 --- a/tests/utils/time_type_tests.py +++ b/tests/utils/time_type_tests.py @@ -68,6 +68,8 @@ def test_non_finite_float(self): qutypes.TimeType.from_float(float('nan')) def test_fraction_fallback(self): + if self.fallback_qutypes.TimeType is qutypes.RsTimeType: + self.skipTest("No fallback since rust implementation is used.") self.assertIs(fractions.Fraction, self.fallback_qutypes.TimeType._InternalType) def assert_from_fraction_works(self, time_type): From d6c7a87a37fe1d6f0c522b48e858b9d16e6a67ff Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 25 Feb 2022 11:49:18 +0100 Subject: [PATCH 06/35] Unify empty loop drop logic --- qupulse/_program/_loop.py | 14 ++++++++++++++ qupulse/pulses/loop_pulse_template.py | 18 ++++++++---------- qupulse/pulses/repetition_pulse_template.py | 19 +++++++++---------- qupulse/pulses/sequence_pulse_template.py | 18 ++++++++---------- 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index b52aeb8e8..1db4f53af 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -406,6 +406,20 @@ def _merge_single_child(self): self._invalidate_duration() return True + @contextlib.contextmanager + def potential_child(self, + measurements: Optional[List[MeasurementWindow]], + repetition_count: Union[VolatileRepetitionCount, int] = 1): + if repetition_count != 1 and measurements: + # current design requires an extra level of nesting here because the measurements are NOT to be repeated + # with the repetition count + inner_child = Loop(repetition_count=repetition_count) + child = Loop(measurements=measurements, children=[inner_child]) + else: + inner_child = child = Loop(measurements=measurements, repetition_count=repetition_count) + yield inner_child + if inner_child.waveform or len(inner_child): + self.append_child(child) def cleanup(self, actions=('remove_empty_loops', 'merge_single_child')): """Apply the specified actions to cleanup the Loop. diff --git a/qupulse/pulses/loop_pulse_template.py b/qupulse/pulses/loop_pulse_template.py index 592b1edc9..9266670f4 100644 --- a/qupulse/pulses/loop_pulse_template.py +++ b/qupulse/pulses/loop_pulse_template.py @@ -217,16 +217,14 @@ def _internal_create_program(self, *, parent_loop: Loop) -> None: self.validate_scope(scope=scope) - self_loop = Loop(measurements=self.get_measurement_windows(scope, measurement_mapping) or None) - for local_scope in self._body_scope_generator(scope, forward=True): - self.body._create_program(scope=local_scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=self_loop) - if self_loop.duration > 0: - parent_loop.append_child(self_loop) + with parent_loop.potential_child(measurements=self.get_measurement_windows(scope, measurement_mapping)) as for_loop: + for local_scope in self._body_scope_generator(scope, forward=True): + self.body._create_program(scope=local_scope, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + global_transformation=global_transformation, + to_single_waveform=to_single_waveform, + parent_loop=for_loop) def build_waveform(self, parameter_scope: Scope) -> ForLoopWaveform: return ForLoopWaveform([self.body.build_waveform(local_scope) diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index a2ba542f5..2498fb9b6 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -119,16 +119,15 @@ def _internal_create_program(self, *, else: repetition_definition = repetition_count - repj_loop = Loop(repetition_count=repetition_definition, - measurements=self.get_measurement_windows(scope, measurement_mapping) or None) - self.body._create_program(scope=scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=repj_loop) - if repj_loop.waveform is not None or len(repj_loop.children) > 0: - parent_loop.append_child(loop=repj_loop) + measurements = self.get_measurement_windows(scope, measurement_mapping) or None + + with parent_loop.potential_child(measurements, repetition_count=repetition_definition) as repj_loop: + self.body._create_program(scope=scope, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + global_transformation=global_transformation, + to_single_waveform=to_single_waveform, + parent_loop=repj_loop) def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer) diff --git a/qupulse/pulses/sequence_pulse_template.py b/qupulse/pulses/sequence_pulse_template.py index 0786d395f..b4235c33a 100644 --- a/qupulse/pulses/sequence_pulse_template.py +++ b/qupulse/pulses/sequence_pulse_template.py @@ -136,16 +136,14 @@ def _internal_create_program(self, *, parent_loop: Loop) -> None: self.validate_scope(scope) - self_loop = Loop(measurements=self.get_measurement_windows(scope, measurement_mapping) or None) - for subtemplate in self.subtemplates: - subtemplate._create_program(scope=scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=self_loop) - if self_loop.duration > 0: - parent_loop.append_child(self_loop) + with parent_loop.potential_child(measurements=self.get_measurement_windows(scope, measurement_mapping)) as seq_loop: + for subtemplate in self.subtemplates: + subtemplate._create_program(scope=scope, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + global_transformation=global_transformation, + to_single_waveform=to_single_waveform, + parent_loop=seq_loop) def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer) From 02a6633e69391679134e4f1f5186553e75307045 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 25 Feb 2022 11:50:45 +0100 Subject: [PATCH 07/35] Adjust tests to changes --- tests/pulses/loop_pulse_template_tests.py | 41 +++++---- .../multi_channel_pulse_template_tests.py | 6 +- tests/pulses/pulse_template_tests.py | 14 ++- .../pulses/repetition_pulse_template_tests.py | 90 ++++++++----------- tests/pulses/sequence_pulse_template_tests.py | 26 +++--- tests/pulses/sequencing_dummies.py | 6 +- 6 files changed, 86 insertions(+), 97 deletions(-) diff --git a/tests/pulses/loop_pulse_template_tests.py b/tests/pulses/loop_pulse_template_tests.py index 6bc6d933a..2716004e4 100644 --- a/tests/pulses/loop_pulse_template_tests.py +++ b/tests/pulses/loop_pulse_template_tests.py @@ -270,12 +270,12 @@ def test_create_program_invalid_measurement_mapping(self) -> None: global_transformation=None) def test_create_program_missing_params(self) -> None: - dt = DummyPulseTemplate(parameter_names={'i'}, waveform=DummyWaveform(duration=4.0), duration='t', measurements=[('b', 2, 1)]) + dt = DummyPulseTemplate(parameter_names={'i'}, waveform=DummyWaveform(duration=4.0), duration='t', measurements=[('M', 2, 1)]) flt = ForLoopPulseTemplate(body=dt, loop_index='i', loop_range=('a', 'b', 'c'), measurements=[('A', 'alph', 1)], parameter_constraints=['c > 1']) scope = DictScope.from_kwargs(a=1, b=4) - measurement_mapping = dict(A='B') + measurement_mapping = dict(A='B', M='M') channel_mapping = dict(C='D') children = [Loop(waveform=DummyWaveform(duration=2.0))] @@ -353,21 +353,25 @@ def test_create_program(self) -> None: global_transformation = TransformationStub() program = Loop() + self_loop = Loop(waveform=DummyWaveform(duration=1), measurements=[('B', .1, 1)]) - # inner _create_program does nothing - expected_program = Loop(measurements=[('B', .1, 1)]) + expected_program = Loop(children=[Loop( + children=[ + Loop(waveform=dt.waveform, measurements=[('b', .2, .3)]), + Loop(waveform=dt.waveform, measurements=[('b', .2, .3)]), + ], + measurements=[('B', scope['meas_param'], 1)])]) expected_create_program_kwargs = dict(measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=program) + to_single_waveform=to_single_waveform) expected_create_program_calls = [mock.call(**expected_create_program_kwargs, scope=_get_for_loop_scope(scope, 'i', i)) for i in (1, 3)] with mock.patch.object(flt, 'validate_scope') as validate_scope: - with mock.patch.object(dt, '_create_program') as body_create_program: + with mock.patch.object(dt, '_create_program', wraps=dt._create_program) as body_create_program: with mock.patch.object(flt, 'get_measurement_windows', wraps=flt.get_measurement_windows) as get_measurement_windows: flt._internal_create_program(scope=scope, @@ -379,6 +383,11 @@ def test_create_program(self) -> None: validate_scope.assert_called_once_with(scope=scope) get_measurement_windows.assert_called_once_with(scope, measurement_mapping) + + inner_loop = program[0] + for call in expected_create_program_calls: + call.kwargs['parent_loop'] = inner_loop + self.assertEqual(body_create_program.call_args_list, expected_create_program_calls) self.assertEqual(expected_program, program) @@ -402,17 +411,15 @@ def test_create_program_append(self) -> None: to_single_waveform=set(), global_transformation=None) - self.assertEqual(3, len(program.children)) - self.assertIs(children[0], program.children[0]) - self.assertEqual(dt.waveform, program.children[1].waveform) - self.assertEqual(dt.waveform, program.children[2].waveform) - self.assertEqual(1, program.children[1].repetition_count) - self.assertEqual(1, program.children[2].repetition_count) - self.assertEqual(1, program.repetition_count) - self.assert_measurement_windows_equal({'b': ([4, 8], [1, 1]), 'B': ([2], [1])}, program.get_measurement_windows()) + expected_program = Loop(children=children + [ + Loop(children=[ + Loop(waveform=dt.waveform, measurements=[('b', 2, 1)]), + Loop(waveform=dt.waveform, measurements=[('b', 2, 1)])], + measurements=[('B', 0, 1)] + )]) - # not ensure same result as from Sequencer here - we're testing appending to an already existing parent loop - # which is a use case that does not immediately arise from using Sequencer + self.assertEqual(expected_program, program) + self.assert_measurement_windows_equal({'b': ([4, 8], [1, 1]), 'B': ([2], [1])}, program.get_measurement_windows()) class ForLoopPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 89e2b7826..114c6fee9 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -207,7 +207,7 @@ def test_build_waveform(self): pt = AtomicMultiChannelPulseTemplate(*sts, parameter_constraints=['a < b']) - parameters = dict(a=2.2, b = 1.1, c=3.3) + parameters = dict(a=2.2, b = 1.1, c=3.3, t1=1.1) channel_mapping = dict() with self.assertRaises(ParameterConstraintViolation): pt.build_waveform(parameters, channel_mapping=dict()) @@ -231,7 +231,7 @@ def test_build_waveform_none(self): pt = AtomicMultiChannelPulseTemplate(*sts, parameter_constraints=['a < b']) - parameters = dict(a=2.2, b=1.1, c=3.3) + parameters = dict(a=2.2, b=1.1, c=3.3, t1=1.1) channel_mapping = dict(A=6) with self.assertRaises(ParameterConstraintViolation): # parameter constraints are checked before channel mapping is applied @@ -432,7 +432,7 @@ def test_build_waveform(self): channel_mapping = {'X': 'X', 'Y': 'K', 'Z': 'Z'} pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels) - parameters = {'c': 1.2, 'a': 3.4} + parameters = {'c': 1.2, 'a': 3.4, 't1': template.waveform.duration} expected_overwritten_channels = {'K': 1.2, 'Z': 3.4} expected_transformation = ParallelConstantChannelTransformation(expected_overwritten_channels) expected_waveform = TransformingWaveform(template.waveform, expected_transformation) diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index d2643a56d..4e5c19723 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -81,6 +81,9 @@ def measurement_names(self): def integral(self) -> Dict[ChannelID, ExpressionScalar]: raise NotImplementedError() + def __repr__(self): + return f"PulseTemplateStub(id={id(self)})" + def get_appending_internal_create_program(waveform=DummyWaveform(), always_append=False, @@ -228,7 +231,7 @@ def test__create_program_single_waveform(self): single_waveform = DummyWaveform() measurements = [('m', 0, 1), ('n', 0.1, .9)] - expected_inner_program = Loop(children=[Loop(waveform=wf)], measurements=measurements) + expected_inner_program = Loop(children=[Loop(waveform=wf, measurements=measurements)]) appending_create_program = get_appending_internal_create_program(wf, measurements=measurements, @@ -239,8 +242,7 @@ def test__create_program_single_waveform(self): else: final_waveform = single_waveform - expected_program = Loop(children=[Loop(waveform=final_waveform)], - measurements=measurements) + expected_program = Loop(children=[Loop(waveform=final_waveform, measurements=measurements)]) with mock.patch.object(template, '_internal_create_program', wraps=appending_create_program) as _internal_create_program: @@ -262,9 +264,6 @@ def test__create_program_single_waveform(self): to_waveform.assert_called_once_with(expected_inner_program) - expected_program._measurements = set(expected_program._measurements) - parent_loop._measurements = set(parent_loop._measurements) - self.assertEqual(expected_program, parent_loop) def test_create_program_defaults(self) -> None: @@ -365,8 +364,7 @@ def test_internal_create_program(self) -> None: channel_mapping = {'B': 'A'} program = Loop() - expected_program = Loop(children=[Loop(waveform=wf)], - measurements=[('N', 0, 5)]) + expected_program = Loop(children=[Loop(waveform=wf, measurements=[('N', 0, 5)])]) with mock.patch.object(template, 'build_waveform', return_value=wf) as build_waveform: template._internal_create_program(scope=scope, diff --git a/tests/pulses/repetition_pulse_template_tests.py b/tests/pulses/repetition_pulse_template_tests.py index 288a44650..8e9979932 100644 --- a/tests/pulses/repetition_pulse_template_tests.py +++ b/tests/pulses/repetition_pulse_template_tests.py @@ -2,6 +2,8 @@ import warnings from unittest import mock +import numpy.testing + from qupulse.parameter_scope import Scope, DictScope from qupulse.utils.types import FrozenDict @@ -110,13 +112,13 @@ def test_internal_create_program(self): to_single_waveform = {'to', 'single', 'waveform'} program = Loop() - expected_program = Loop(children=[Loop(children=[Loop(waveform=wf)], repetition_count=6)], - measurements=[('l', .1, .2)]) + expected_program = Loop(children=[Loop(children=[Loop(children=[Loop(waveform=wf)], repetition_count=6)], + measurements=[('l', .1, .2)])]) real_relevant_parameters = dict(n_rep=3, mul=2, a=0.1, b=0.2) with mock.patch.object(body, '_create_program', - wraps=get_appending_internal_create_program(wf, always_append=True)) as body_create_program: + wraps=get_appending_internal_create_program(wf, always_append=True)): with mock.patch.object(rpt, 'validate_scope') as validate_scope: with mock.patch.object(rpt, 'get_repetition_count_value', return_value=6) as get_repetition_count_value: with mock.patch.object(rpt, 'get_measurement_windows', return_value=[('l', .1, .2)]) as get_meas: @@ -128,12 +130,6 @@ def test_internal_create_program(self): parent_loop=program) self.assertEqual(program, expected_program) - body_create_program.assert_called_once_with(scope=scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=program.children[0]) validate_scope.assert_called_once_with(scope) get_repetition_count_value.assert_called_once_with(scope) get_meas.assert_called_once_with(scope, measurement_mapping) @@ -153,15 +149,17 @@ def test_create_program_constant_success_measurements(self) -> None: global_transformation=None, parent_loop=program) - self.assertEqual(1, len(program.children)) - internal_loop = program[0] # type: Loop - self.assertEqual(repetitions, internal_loop.repetition_count) - - self.assertEqual(1, len(internal_loop)) - self.assertEqual((scope, measurement_mapping, channel_mapping, internal_loop), body.create_program_calls[-1]) - self.assertEqual(body.waveform, internal_loop[0].waveform) + expected_program = Loop(children=[Loop( + measurements=[("thy", 2, 2)], + children=[Loop( + children=[Loop(waveform=body.waveform, measurements=[('b', 0, 1)])], + repetition_count=repetitions + )]) + ]) - self.assert_measurement_windows_equal({'b': ([0, 2, 4], [1, 1, 1]), 'thy': ([2], [2])}, program.get_measurement_windows()) + self.assertEqual(expected_program, program) + self.assert_measurement_windows_equal({'b': ([0, 2, 4], [1, 1, 1]), 'thy': ([2], [2])}, + program.get_measurement_windows()) # done in MultiChannelProgram program.cleanup() @@ -184,16 +182,11 @@ def test_create_program_declaration_success(self) -> None: global_transformation=None, parent_loop=program) - self.assertEqual(1, program.repetition_count) - self.assertEqual(1, len(program.children)) - internal_loop = program.children[0] # type: Loop - self.assertEqual(scope[repetitions], internal_loop.repetition_count) - - self.assertEqual(1, len(internal_loop)) - self.assertEqual((scope, measurement_mapping, channel_mapping, internal_loop), - body.create_program_calls[-1]) - self.assertEqual(body.waveform, internal_loop[0].waveform) - + expected_program = Loop(children=[ + Loop(repetition_count=scope['foo'], + children=[Loop(waveform=body.waveform)]) + ]) + self.assertEqual(expected_program, program) self.assert_measurement_windows_equal({}, program.get_measurement_windows()) # ensure same result as from Sequencer @@ -209,7 +202,7 @@ def test_create_program_declaration_success_appended_measurements(self) -> None: measurement_mapping = dict(moth='fire', b='b') channel_mapping = dict(asd='f') children = [Loop(waveform=DummyWaveform(duration=0))] - program = Loop(children=children, measurements=[('a', [0], [1])], repetition_count=2) + program = Loop(children=children, measurements=[('a', 0, 1)], repetition_count=2) t._internal_create_program(scope=scope, measurement_mapping=measurement_mapping, @@ -218,22 +211,21 @@ def test_create_program_declaration_success_appended_measurements(self) -> None: global_transformation=None, parent_loop=program) - self.assertEqual(2, program.repetition_count) - self.assertEqual(2, len(program.children)) - self.assertIs(program.children[0], children[0]) - internal_loop = program.children[1] # type: Loop - self.assertEqual(scope[repetitions], internal_loop.repetition_count) - - self.assertEqual(1, len(internal_loop)) - self.assertEqual((scope, measurement_mapping, channel_mapping, internal_loop), body.create_program_calls[-1]) - self.assertEqual(body.waveform, internal_loop[0].waveform) + expected_program = Loop(children=children + [Loop( + measurements=[('fire', 0, 7.1)], + children=[Loop(repetition_count=scope['foo'], children=[Loop(waveform=body.waveform, measurements=[('b', 0, 1)])])] + )], + measurements=[('a', 0, 1)], repetition_count=2) + self.assertEqual(expected_program, program) - self.assert_measurement_windows_equal({'fire': ([0, 6], [7.1, 7.1]), - 'b': ([0, 2, 4, 6, 8, 10], [1, 1, 1, 1, 1, 1]), - 'a': ([0], [1])}, program.get_measurement_windows()) + expected_measurementt_windows = { + 'fire': ([0, 6], [7.1, 7.1]), + 'b': ([0, 2, 4, 6, 8, 10], [1, 1, 1, 1, 1, 1]), + 'a': ([0, expected_program.body_duration], [1, 1])} + numpy.testing.assert_equal(expected_measurementt_windows, program.get_measurement_windows()) - # not ensure same result as from Sequencer here - we're testing appending to an already existing parent loop - # which is a use case that does not immediately arise from using Sequencer + self.assert_measurement_windows_equal(expected_measurementt_windows, + program.get_measurement_windows()) def test_create_program_declaration_success_measurements(self) -> None: repetitions = "foo" @@ -250,15 +242,11 @@ def test_create_program_declaration_success_measurements(self) -> None: global_transformation=None, parent_loop=program) - self.assertEqual(1, program.repetition_count) - self.assertEqual(1, len(program.children)) - internal_loop = program.children[0] # type: Loop - self.assertEqual(scope[repetitions], internal_loop.repetition_count) - - self.assertEqual(1, len(internal_loop)) - self.assertEqual((scope, measurement_mapping, channel_mapping, internal_loop), body.create_program_calls[-1]) - self.assertEqual(body.waveform, internal_loop[0].waveform) - + expected_program = Loop(children=[ + Loop(measurements=[('fire', 0, scope['meas_end'])], + children=[Loop(children=[Loop(waveform=body.waveform, measurements=[('b', 0, 1)])], repetition_count=scope['foo'])]) + ]) + self.assertEqual(expected_program, program) self.assert_measurement_windows_equal({'fire': ([0], [7.1]), 'b': ([0, 2, 4], [1, 1, 1])}, program.get_measurement_windows()) def test_create_program_declaration_exceeds_bounds(self) -> None: diff --git a/tests/pulses/sequence_pulse_template_tests.py b/tests/pulses/sequence_pulse_template_tests.py index 4790a02a4..e224b2ce3 100644 --- a/tests/pulses/sequence_pulse_template_tests.py +++ b/tests/pulses/sequence_pulse_template_tests.py @@ -233,9 +233,9 @@ def test_internal_create_program(self): program = Loop() - expected_program = Loop(children=[Loop(waveform=wfs[0]), + expected_program = Loop(children=[Loop(children=[Loop(waveform=wfs[0]), Loop(waveform=wfs[1])], - measurements=[('l', .1, .2)]) + measurements=[('l', .1, .2)])]) with mock.patch.object(spt, 'validate_scope') as validate_scope: with mock.patch.object(spt, 'get_measurement_windows', @@ -251,8 +251,8 @@ def test_internal_create_program(self): validate_scope.assert_called_once_with(kwargs['scope']) get_measurement_windows.assert_called_once_with(kwargs['scope'], kwargs['measurement_mapping']) - create_0.assert_called_once_with(**kwargs, parent_loop=program) - create_1.assert_called_once_with(**kwargs, parent_loop=program) + # create_0.assert_called_once_with(**kwargs, parent_loop=program) + # create_1.assert_called_once_with(**kwargs, parent_loop=program) def test_create_program_internal(self) -> None: sub1 = DummyPulseTemplate(duration=3, waveform=DummyWaveform(duration=3), measurements=[('b', 1, 2)], defined_channels={'A'}) @@ -343,13 +343,12 @@ def test_internal_create_program_one_child_no_duration(self) -> None: global_transformation=None, to_single_waveform=set(), parent_loop=loop) - self.assertEqual(1, loop.repetition_count) - self.assertIsNone(loop.waveform) - self.assertEqual([Loop(repetition_count=1, waveform=sub2.waveform)], - list(loop.children)) + expected_program = Loop(children=[Loop( + children=[Loop(waveform=sub2.waveform)], + measurements=seq.measurement_declarations, + )]) + self.assertEqual(expected_program, loop) self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) - - # MultiChannelProgram calls cleanup loop.cleanup() self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) @@ -362,13 +361,8 @@ def test_internal_create_program_one_child_no_duration(self) -> None: global_transformation=None, to_single_waveform=set(), parent_loop=loop) - self.assertEqual(1, loop.repetition_count) - self.assertIsNone(loop.waveform) - self.assertEqual([Loop(repetition_count=1, waveform=sub2.waveform)], - list(loop.children)) + self.assertEqual(expected_program, loop) self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) - - # MultiChannelProgram calls cleanup loop.cleanup() self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 66ea4e907..63ced5d12 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -244,15 +244,17 @@ def _internal_create_program(self, *, parent_loop.append_child(waveform=self._program.waveform, children=self._program.children, measurements=measurements) elif self.waveform: - parent_loop.append_child(waveform=self.waveform, measurements=measurements) + parent_loop.append_child(waveform=self.build_waveform(parameters=scope, channel_mapping=channel_mapping), + measurements=measurements) def build_waveform(self, parameters: Dict[str, Parameter], channel_mapping: Dict[ChannelID, ChannelID]): self.build_waveform_calls.append((parameters, channel_mapping)) + duration = self.duration.evaluate_in_scope(parameters) if self.waveform or self.waveform is None: return self.waveform - return DummyWaveform(duration=self.duration.evaluate_numeric(**parameters), defined_channels=self.defined_channels) + return DummyWaveform(duration=duration, defined_channels=self.defined_channels) def get_serialization_data(self, serializer: Optional['Serializer']=None) -> Dict[str, Any]: data = super().get_serialization_data(serializer=serializer) From f878b70469f04a0e435ade444dba5fe3f0fe1231 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 25 Feb 2022 12:00:40 +0100 Subject: [PATCH 08/35] Increase Loop debuggability by creating a more correct __repr__ --- qupulse/_program/_loop.py | 34 ++++++++++++++++++++++++++++++---- qupulse/utils/tree.py | 3 +++ tests/_program/loop_tests.py | 12 ++++++++---- 3 files changed, 41 insertions(+), 8 deletions(-) diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index 1db4f53af..fd175d12a 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -1,3 +1,4 @@ +import contextlib from typing import Union, Dict, Iterable, Tuple, cast, List, Optional, Generator, Mapping from collections import defaultdict from enum import Enum @@ -201,23 +202,47 @@ def encapsulate(self) -> None: self._measurements = None self.assert_tree_integrity() - def _get_repr(self, first_prefix, other_prefixes) -> Generator[str, None, None]: + def __repr__(self): + kwargs = [] + + repetition_count = self._repetition_definition + if repetition_count != 1: + kwargs.append(f"repetition_count={repetition_count!r}") + + waveform = self._waveform + if waveform: + kwargs.append(f"waveform={waveform!r}") + + children = self.children + if children: + try: + kwargs.append(f"children={self._children_repr()}") + except RecursionError: + kwargs.append("children=[...]") + + measurements = self._measurements + if measurements: + kwargs.append(f"measurements={measurements!r}") + + return f"Loop({','.join(kwargs)})" + + def _get_str(self, first_prefix, other_prefixes) -> Generator[str, None, None]: if self.is_leaf(): yield '%sEXEC %r %d times' % (first_prefix, self._waveform, self.repetition_count) else: yield '%sLOOP %d times:' % (first_prefix, self.repetition_count) for elem in self: - yield from cast(Loop, elem)._get_repr(other_prefixes + ' ->', other_prefixes + ' ') + yield from cast(Loop, elem)._get_str(other_prefixes + ' ->', other_prefixes + ' ') - def __repr__(self) -> str: + def __str__(self) -> str: is_circular = is_tree_circular(self) if is_circular: return '{}: Circ {}'.format(id(self), is_circular) str_len = 0 repr_list = [] - for sub_repr in self._get_repr('', ''): + for sub_repr in self._get_str('', ''): str_len += len(sub_repr) if self.MAX_REPR_SIZE and str_len > self.MAX_REPR_SIZE: @@ -420,6 +445,7 @@ def potential_child(self, yield inner_child if inner_child.waveform or len(inner_child): self.append_child(child) + def cleanup(self, actions=('remove_empty_loops', 'merge_single_child')): """Apply the specified actions to cleanup the Loop. diff --git a/qupulse/utils/tree.py b/qupulse/utils/tree.py index dfc927a2a..a91bade32 100644 --- a/qupulse/utils/tree.py +++ b/qupulse/utils/tree.py @@ -152,6 +152,9 @@ def locate(self: _NodeType, location: Tuple[int, ...]) -> _NodeType: else: return self + def _children_repr(self): + return repr(self.__children) + def is_tree_circular(root: Node) -> Union[None, Tuple[List[Node], int]]: NodeStack = namedtuple('NodeStack', ['node', 'stack']) diff --git a/tests/_program/loop_tests.py b/tests/_program/loop_tests.py index 40d1c266e..6776e809c 100644 --- a/tests/_program/loop_tests.py +++ b/tests/_program/loop_tests.py @@ -121,6 +121,10 @@ def test_compare_key(self): self.assertEqual(tree1, tree5) def test_repr(self): + tree = self.get_test_loop() + self.assertEqual(tree, eval(repr(tree))) + + def test_str(self): wf_gen = WaveformGenerator(num_channels=1) wfs = [wf_gen() for _ in range(11)] @@ -132,10 +136,10 @@ def test_repr(self): loop.waveform = wfs.pop(0) self.assertEqual(len(wfs), 0) - self.assertEqual(repr(tree), expected) + self.assertEqual(str(tree), expected) with mock.patch.object(Loop, 'MAX_REPR_SIZE', 1): - self.assertEqual(repr(tree), '...') + self.assertEqual(str(tree), '...') def test_is_leaf(self): root_loop = self.get_test_loop(waveform_generator=WaveformGenerator(1)) @@ -199,7 +203,7 @@ def test_flatten_and_balance(self): ->LOOP 9 times: ->EXEC {J} 10 times ->EXEC {K} 11 times""".format(**wf_reprs) - self.assertEqual(repr(before), before_repr) + self.assertEqual(str(before), before_repr) expected_after_repr = """\ LOOP 1 times: @@ -241,7 +245,7 @@ def test_flatten_and_balance(self): ->EXEC {J} 10 times ->EXEC {K} 11 times""".format(**wf_reprs) - self.assertEqual(expected_after_repr, repr(after)) + self.assertEqual(expected_after_repr, str(after)) def test_flatten_and_balance_comparison_based(self): wfs = [DummyWaveform(duration=i) for i in range(2)] From f2b2abfd3134d451a008c7674d3e1ba14a63ba39 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 25 Feb 2022 16:50:05 +0100 Subject: [PATCH 09/35] Prepare Loop replacement --- qupulse/_program/__init__.py | 51 +++++++++++++++++++++ qupulse/_program/_loop.py | 17 ++++++- qupulse/pulses/constant_pulse_template.py | 26 +---------- qupulse/pulses/loop_pulse_template.py | 4 +- qupulse/pulses/mapping_pulse_template.py | 4 +- qupulse/pulses/pulse_template.py | 40 ++++++++-------- qupulse/pulses/repetition_pulse_template.py | 5 +- qupulse/pulses/sequence_pulse_template.py | 4 +- tests/pulses/pulse_template_tests.py | 11 ++--- tests/pulses/sequencing_dummies.py | 11 +++-- 10 files changed, 107 insertions(+), 66 deletions(-) diff --git a/qupulse/_program/__init__.py b/qupulse/_program/__init__.py index 93773ebb1..8205a1af3 100644 --- a/qupulse/_program/__init__.py +++ b/qupulse/_program/__init__.py @@ -1 +1,52 @@ """This is a private package meaning there are no stability guarantees.""" +from abc import ABC, abstractmethod +from typing import Optional, Union, Sequence, ContextManager, Mapping + +import numpy as np + +from qupulse._program.waveforms import Waveform +from qupulse.utils.types import MeasurementWindow +from qupulse._program.volatile import VolatileRepetitionCount + +try: + from typing import Protocol +except ImportError: + Protocol = object + + +RepetitionCount = Union[int, VolatileRepetitionCount] + + +class Program(Protocol): + def to_single_waveform(self) -> Waveform: + pass + + def get_measurement_windows(self) -> Mapping[str, np.ndarray]: + pass + + +class ProgramBuilder(Protocol): + def append_leaf(self, waveform: Waveform, + measurements: Optional[Sequence[MeasurementWindow]] = None, + repetition_count: int = 1): + pass + + def potential_child(self, measurements: Optional[Sequence[MeasurementWindow]], + repetition_count: Union[VolatileRepetitionCount, int] = 1) -> ContextManager['ProgramBuilder']: + """ + + Args: + measurements: Measurements to attach to the potential child. Is not repeated with repetition_count. + repetition_count: + + Returns: + + """ + + def to_program(self) -> Optional[Program]: + pass + + +def default_program_builder() -> ProgramBuilder: + from qupulse._program._loop import Loop + return Loop() diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index fd175d12a..32ed89ab5 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -1,5 +1,5 @@ import contextlib -from typing import Union, Dict, Iterable, Tuple, cast, List, Optional, Generator, Mapping +from typing import Union, Dict, Iterable, Tuple, cast, List, Optional, Generator, Mapping, ContextManager, Sequence from collections import defaultdict from enum import Enum import warnings @@ -16,6 +16,7 @@ from qupulse.utils.tree import Node, is_tree_circular from qupulse.utils.numeric import smallest_factor_ge +from qupulse._program import ProgramBuilder, Program from qupulse._program.waveforms import SequenceWaveform, RepetitionWaveform __all__ = ['Loop', 'make_compatible', 'MakeCompatibleWarning'] @@ -434,7 +435,7 @@ def _merge_single_child(self): @contextlib.contextmanager def potential_child(self, measurements: Optional[List[MeasurementWindow]], - repetition_count: Union[VolatileRepetitionCount, int] = 1): + repetition_count: Union[VolatileRepetitionCount, int] = 1) -> ContextManager['Loop']: if repetition_count != 1 and measurements: # current design requires an extra level of nesting here because the measurements are NOT to be repeated # with the repetition count @@ -493,6 +494,18 @@ def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]: else: return self.repetition_count, tuple(child.get_duration_structure() for child in self) + def to_single_waveform(self) -> Waveform: + return to_waveform(self) + + def append_leaf(self, waveform: Waveform, + measurements: Optional[Sequence[MeasurementWindow]] = None, + repetition_count: int = 1): + self.append_child(waveform=waveform, measurements=measurements, repetition_count=repetition_count) + + def to_program(self) -> Optional['Loop']: + if self.waveform or self.children: + return self + class ChannelSplit(Exception): def __init__(self, channel_sets): diff --git a/qupulse/pulses/constant_pulse_template.py b/qupulse/pulses/constant_pulse_template.py index 57b233b7f..5c63d0646 100644 --- a/qupulse/pulses/constant_pulse_template.py +++ b/qupulse/pulses/constant_pulse_template.py @@ -10,6 +10,7 @@ import numbers from typing import Any, Dict, List, Optional, Set, Union +from qupulse._program import ProgramBuilder from qupulse._program.waveforms import ConstantWaveform from qupulse.expressions import ExpressionScalar from qupulse.parameter_scope import Scope @@ -87,31 +88,6 @@ def defined_channels(self) -> Set['ChannelID']: def requires_stop(self) -> bool: # from SequencingElement return False - def _internal_create_program(self, *, - scope: Scope, - measurement_mapping: Dict[str, Optional[str]], - channel_mapping: Dict[ChannelID, Optional[ChannelID]], - global_transformation: Optional[Transformation], - to_single_waveform: Set[Union[str, PulseTemplate]], - parent_loop: Loop) -> None: - """ removed from qupulse docstring """ - try: - parameters = {parameter_name: scope[parameter_name].get_value() - for parameter_name in self.parameter_names} - except KeyError as e: - raise ParameterNotProvidedException(str(e)) from e - - waveform = self.build_waveform(parameters=parameters, - channel_mapping=channel_mapping) - if waveform: - measurements = self.get_measurement_windows(parameters=parameters, - measurement_mapping=measurement_mapping) - - if global_transformation: - waveform = TransformingWaveform(waveform, global_transformation) - - parent_loop.append_child(waveform=waveform, measurements=measurements) - def build_waveform(self, parameters: Dict[str, numbers.Real], channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional[Union[ConstantWaveform, diff --git a/qupulse/pulses/loop_pulse_template.py b/qupulse/pulses/loop_pulse_template.py index 9266670f4..1bf534ae6 100644 --- a/qupulse/pulses/loop_pulse_template.py +++ b/qupulse/pulses/loop_pulse_template.py @@ -13,7 +13,7 @@ from qupulse.parameter_scope import Scope, MappedScope, DictScope from qupulse.utils.types import FrozenDict, FrozenMapping -from qupulse._program._loop import Loop +from qupulse._program import ProgramBuilder from qupulse.expressions import ExpressionScalar, ExpressionVariableMissingException, Expression from qupulse.utils import checked_int_cast, cached_property @@ -214,7 +214,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: self.validate_scope(scope=scope) with parent_loop.potential_child(measurements=self.get_measurement_windows(scope, measurement_mapping)) as for_loop: diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index af5b64e00..38b9229b6 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -9,7 +9,7 @@ from qupulse.pulses.pulse_template import PulseTemplate, MappingTuple from qupulse.pulses.parameters import Parameter, MappedParameter, ParameterNotProvidedException, ParameterConstrainer from qupulse._program.waveforms import Waveform -from qupulse._program._loop import Loop +from qupulse._program import ProgramBuilder from qupulse.serialization import Serializer, PulseRegistryType __all__ = [ @@ -317,7 +317,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: self.validate_scope(scope) # parameters are validated in map_parameters() call, no need to do it here again explicitly diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 0814daf90..0101cb294 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -17,6 +17,7 @@ from qupulse.utils.types import ChannelID, DocStringABCMeta, FrozenDict from qupulse.serialization import Serializable from qupulse.expressions import ExpressionScalar, Expression, ExpressionLike +from qupulse._program import ProgramBuilder, default_program_builder from qupulse._program._loop import Loop, to_waveform from qupulse._program.transformation import Transformation, IdentityTransformation, ChainedTransformation, chain_transformations @@ -153,7 +154,7 @@ def create_program(self, *, scope = DictScope(values=FrozenDict(parameters), volatile=volatile) - root_loop = Loop() + root_loop = default_program_builder() # call subclass specific implementation self._create_program(scope=scope, @@ -163,9 +164,7 @@ def create_program(self, *, to_single_waveform=to_single_waveform, parent_loop=root_loop) - if root_loop.waveform is None and len(root_loop.children) == 0: - return None # return None if no program - return root_loop + return root_loop.to_program() @abstractmethod def _internal_create_program(self, *, @@ -174,7 +173,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional[Transformation], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: """The subclass specific implementation of create_program(). Receives a Loop instance parent_loop to which it should append measurements and its own Loops as children. @@ -195,35 +194,36 @@ def _create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional[Transformation], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop): + parent_loop: ProgramBuilder): """Generic part of create program. This method handles to_single_waveform and the configuration of the transformer.""" if self.identifier in to_single_waveform or self in to_single_waveform: - root = Loop() - if not scope.get_volatile_parameters().keys().isdisjoint(self.parameter_names): raise NotImplementedError('A pulse template that has volatile parameters cannot be transformed into a ' 'single waveform yet.') + builder = default_program_builder() self._internal_create_program(scope=scope, measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=None, to_single_waveform=to_single_waveform, - parent_loop=root) + parent_loop=builder) - waveform = to_waveform(root) + program = builder.to_program() + if program is not None: + waveform = program.to_single_waveform() - if global_transformation: - waveform = TransformingWaveform(waveform, global_transformation) + if global_transformation: + waveform = TransformingWaveform(waveform, global_transformation) - # convert the nicely formatted measurement windows back into the old format again :( - measurements = root.get_measurement_windows() - measurement_window_list = [] - for measurement_name, (begins, lengths) in measurements.items(): - measurement_window_list.extend(zip(itertools.repeat(measurement_name), begins, lengths)) + # convert the nicely formatted measurement windows back into the old format again :( + measurements = program.get_measurement_windows() + measurement_window_list = [] + for measurement_name, (begins, lengths) in measurements.items(): + measurement_window_list.extend(zip(itertools.repeat(measurement_name), begins, lengths)) - parent_loop.append_child(waveform=waveform, measurements=measurement_window_list) + parent_loop.append_leaf(waveform=waveform, measurements=measurement_window_list) else: self._internal_create_program(scope=scope, @@ -311,7 +311,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional[Transformation], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: """Parameter constraints are validated in build_waveform because build_waveform is guaranteed to be called during sequencing""" ### current behavior (same as previously): only adds EXEC Loop and measurements if a waveform exists. @@ -327,7 +327,7 @@ def _internal_create_program(self, *, if global_transformation: waveform = TransformingWaveform(waveform, global_transformation) - parent_loop.append_child(waveform=waveform, measurements=measurements) + parent_loop.append_leaf(waveform=waveform, measurements=measurements) @abstractmethod def build_waveform(self, diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index 2498fb9b6..0ec2c9773 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -8,7 +8,8 @@ import numpy as np from qupulse.serialization import Serializer, PulseRegistryType -from qupulse._program._loop import Loop, VolatileRepetitionCount +from qupulse._program import ProgramBuilder +from qupulse._program.volatile import VolatileRepetitionCount from qupulse.parameter_scope import Scope from qupulse.utils.types import ChannelID @@ -105,7 +106,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: self.validate_scope(scope) repetition_count = max(0, self.get_repetition_count_value(scope)) diff --git a/qupulse/pulses/sequence_pulse_template.py b/qupulse/pulses/sequence_pulse_template.py index b4235c33a..2575a2c62 100644 --- a/qupulse/pulses/sequence_pulse_template.py +++ b/qupulse/pulses/sequence_pulse_template.py @@ -8,7 +8,7 @@ import warnings from qupulse.serialization import Serializer, PulseRegistryType -from qupulse._program._loop import Loop +from qupulse._program import ProgramBuilder from qupulse.parameter_scope import Scope from qupulse.utils import cached_property from qupulse.utils.types import MeasurementWindow, ChannelID, TimeType @@ -133,7 +133,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: self.validate_scope(scope) with parent_loop.potential_child(measurements=self.get_measurement_windows(scope, measurement_mapping)) as seq_loop: diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 4e5c19723..b23c3ac44 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -3,7 +3,6 @@ from unittest import mock from typing import Optional, Dict, Set, Any, Union -import sympy from qupulse.parameter_scope import Scope, DictScope from qupulse.utils.types import ChannelID @@ -12,6 +11,7 @@ from qupulse.pulses.parameters import Parameter, ConstantParameter, ParameterNotProvidedException from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform from qupulse._program._loop import Loop +from qupulse._program import ProgramBuilder from qupulse._program.transformation import Transformation from qupulse._program.waveforms import TransformingWaveform @@ -70,7 +70,7 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional[Transformation], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop): + parent_loop: ProgramBuilder): raise NotImplementedError() @property @@ -88,9 +88,9 @@ def __repr__(self): def get_appending_internal_create_program(waveform=DummyWaveform(), always_append=False, measurements: list=None): - def internal_create_program(*, scope, parent_loop: Loop, **_): + def internal_create_program(*, scope, parent_loop: ProgramBuilder, **_): if always_append or 'append_a_child' in scope: - parent_loop.append_child(waveform=waveform, measurements=measurements) + parent_loop.append_leaf(waveform=waveform, measurements=measurements) return internal_create_program @@ -246,7 +246,7 @@ def test__create_program_single_waveform(self): with mock.patch.object(template, '_internal_create_program', wraps=appending_create_program) as _internal_create_program: - with mock.patch('qupulse.pulses.pulse_template.to_waveform', + with mock.patch('qupulse._program._loop.to_waveform', return_value=single_waveform) as to_waveform: template._create_program(scope=scope, measurement_mapping=measurement_mapping, @@ -261,7 +261,6 @@ def test__create_program_single_waveform(self): global_transformation=None, to_single_waveform=to_single_waveform, parent_loop=expected_inner_program) - to_waveform.assert_called_once_with(expected_inner_program) self.assertEqual(expected_program, parent_loop) diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 63ced5d12..59b7cda45 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -7,6 +7,7 @@ """LOCAL IMPORTS""" from qupulse.parameter_scope import Scope +from qupulse._program import ProgramBuilder from qupulse._program._loop import Loop from qupulse.utils.types import MeasurementWindow, ChannelID, TimeType, time_from_float from qupulse.serialization import Serializer @@ -237,15 +238,15 @@ def _internal_create_program(self, *, channel_mapping: Dict[ChannelID, Optional[ChannelID]], global_transformation: Optional['Transformation'], to_single_waveform: Set[Union[str, 'PulseTemplate']], - parent_loop: Loop) -> None: + parent_loop: ProgramBuilder) -> None: measurements = self.get_measurement_windows(scope, measurement_mapping) self.create_program_calls.append((scope, measurement_mapping, channel_mapping, parent_loop)) if self._program: - parent_loop.append_child(waveform=self._program.waveform, children=self._program.children, - measurements=measurements) + parent_loop.append_leaf(waveform=self._program.waveform, children=self._program.children, + measurements=measurements) elif self.waveform: - parent_loop.append_child(waveform=self.build_waveform(parameters=scope, channel_mapping=channel_mapping), - measurements=measurements) + parent_loop.append_leaf(waveform=self.build_waveform(parameters=scope, channel_mapping=channel_mapping), + measurements=measurements) def build_waveform(self, parameters: Dict[str, Parameter], From a3f6a67553ae2c6203fcbacc80a3f6422bdb68b1 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 13 May 2022 09:02:43 +0200 Subject: [PATCH 10/35] Better docs and test --- qupulse/_program/__init__.py | 6 +++++- tests/_program/waveforms_tests.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/qupulse/_program/__init__.py b/qupulse/_program/__init__.py index 8205a1af3..20c08d4a1 100644 --- a/qupulse/_program/__init__.py +++ b/qupulse/_program/__init__.py @@ -18,14 +18,18 @@ class Program(Protocol): + """This protocol is used to inspect and or manipulate programs""" + def to_single_waveform(self) -> Waveform: pass - def get_measurement_windows(self) -> Mapping[str, np.ndarray]: + def get_measurement_windows(self) -> Mapping[str, np.ndarray]: pass class ProgramBuilder(Protocol): + """This protocol is used by PulseTemplate to build the program.""" + def append_leaf(self, waveform: Waveform, measurements: Optional[Sequence[MeasurementWindow]] = None, repetition_count: int = 1): diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 75a7c525c..5c65147e2 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -589,7 +589,7 @@ def test_unsafe_get_subset_for_channels(self): entries = (TableWaveformEntry(0, 0, interp), TableWaveformEntry(2.1, -33.2, interp), TableWaveformEntry(5.7, 123.4, interp)) - waveform = TableWaveform('A', entries) + waveform = TableWaveform.from_table('A', entries) self.assertIs(waveform.unsafe_get_subset_for_channels({'A'}), waveform) def test_unsafe_sample(self) -> None: @@ -597,7 +597,7 @@ def test_unsafe_sample(self) -> None: entries = (TableWaveformEntry(0, 0, interp), TableWaveformEntry(2.1, -33.2, interp), TableWaveformEntry(5.7, 123.4, interp)) - waveform = TableWaveform('A', entries) + waveform = TableWaveform.from_table('A', entries) sample_times = numpy.linspace(.5, 5.5, num=11) expected_interp_arguments = [((0, 0), (2.1, -33.2), [0.5, 1.0, 1.5, 2.0]), From 18e91ee1dfa6439714e1c173637a6c63e1b1449d Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 29 Jun 2022 12:30:02 +0200 Subject: [PATCH 11/35] =?UTF-8?q?Conditional=20rust=20re=C3=BCplacement=20?= =?UTF-8?q?import?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- qupulse/_program/waveforms.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 5ab13dcc9..5386190a5 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -1222,10 +1222,10 @@ def reversed(self) -> 'Waveform': return self._inner -TableWaveform = rs_replacements.waveforms.TableWaveform -ConstantWaveform = rs_replacements.waveforms.ConstantWaveform -MultiChannelWaveform = rs_replacements.waveforms.MultiChannelWaveform -Waveform.register(TableWaveform) -Waveform.register(ConstantWaveform) -Waveform.register(MultiChannelWaveform) - +if rs_replacements is not None: + TableWaveform = rs_replacements.waveforms.TableWaveform + ConstantWaveform = rs_replacements.waveforms.ConstantWaveform + MultiChannelWaveform = rs_replacements.waveforms.MultiChannelWaveform + Waveform.register(TableWaveform) + Waveform.register(ConstantWaveform) + Waveform.register(MultiChannelWaveform) From cb0e78ae334090ac0bd9d739bda1a42ec3631e10 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 29 Jun 2022 15:51:59 +0200 Subject: [PATCH 12/35] Improve testability with rust extension --- qupulse/_program/__init__.py | 8 ++++++-- qupulse/_program/waveforms.py | 9 +++++++-- tests/_program/waveforms_tests.py | 15 +++++++++------ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/qupulse/_program/__init__.py b/qupulse/_program/__init__.py index 20c08d4a1..e865047c9 100644 --- a/qupulse/_program/__init__.py +++ b/qupulse/_program/__init__.py @@ -52,5 +52,9 @@ def to_program(self) -> Optional[Program]: def default_program_builder() -> ProgramBuilder: - from qupulse._program._loop import Loop - return Loop() + try: + import qupulse_rs.qupulse_rs + return qupulse_rs.qupulse_rs.replacements.ProgramBuilder() + except (AttributeError, ImportError): + from qupulse._program._loop import Loop + return Loop() diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 5386190a5..d28913132 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -848,8 +848,8 @@ def unsafe_sample(self, return output_array @property - def compare_key(self) -> Tuple[Any, int]: - return self._body.compare_key, self._repetition_count + def compare_key(self) -> Tuple[int, Any]: + return self._repetition_count, self._body def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'RepetitionWaveform': return RepetitionWaveform(body=self._body.unsafe_get_subset_for_channels(channels), @@ -1223,9 +1223,14 @@ def reversed(self) -> 'Waveform': if rs_replacements is not None: + PyTableWaveform = TableWaveform + PyConstantWaveform = ConstantWaveform + PyMultiChannelWaveform = MultiChannelWaveform + TableWaveform = rs_replacements.waveforms.TableWaveform ConstantWaveform = rs_replacements.waveforms.ConstantWaveform MultiChannelWaveform = rs_replacements.waveforms.MultiChannelWaveform + Waveform.register(TableWaveform) Waveform.register(ConstantWaveform) Waveform.register(MultiChannelWaveform) diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index ba1d86755..ddb7a3844 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -9,7 +9,7 @@ JumpInterpolationStrategy from qupulse._program.waveforms import MultiChannelWaveform, RepetitionWaveform, SequenceWaveform,\ TableWaveformEntry, TableWaveform, TransformingWaveform, SubsetWaveform, ArithmeticWaveform, ConstantWaveform,\ - Waveform, FunctorWaveform, FunctionWaveform, ReversedWaveform + Waveform, FunctorWaveform, FunctionWaveform, ReversedWaveform, rs_replacements from qupulse._program.transformation import LinearTransformation from qupulse.expressions import ExpressionScalar, Expression @@ -150,9 +150,9 @@ def test_slot(self): class MultiChannelWaveformTest(unittest.TestCase): def test_init_no_args(self) -> None: - with self.assertRaises(ValueError): + with self.assertRaises((TypeError, ValueError)): MultiChannelWaveform(dict()) - with self.assertRaises(ValueError): + with self.assertRaises((TypeError, ValueError)): MultiChannelWaveform(None) def test_from_parallel(self): @@ -239,7 +239,9 @@ def test_unsafe_sample(self) -> None: result_a = waveform.unsafe_sample('A', sample_times, reuse_output) self.assertEqual(len(dwf_a.sample_calls), 2) self.assertIs(result_a, reuse_output) - self.assertIs(result_a, dwf_a.sample_calls[1][2]) + if rs_replacements is None: + # rust extension cannot forward the numpy array back to python without performance degradation + self.assertIs(result_a, dwf_a.sample_calls[1][2]) numpy.testing.assert_equal(result_b, samples_b) def test_equality(self) -> None: @@ -333,7 +335,7 @@ def test_defined_channels(self): def test_compare_key(self): body_wf = DummyWaveform(defined_channels={'a'}) wf = RepetitionWaveform(body_wf, 2) - self.assertEqual(wf.compare_key, (body_wf.compare_key, 2)) + self.assertEqual(wf.compare_key, (2, body_wf)) def test_unsafe_get_subset_for_channels(self): body_wf = DummyWaveform(defined_channels={'a', 'b'}) @@ -606,7 +608,8 @@ def test_unsafe_sample(self) -> None: result = waveform.unsafe_sample('A', sample_times) - self.assertEqual(expected_interp_arguments, interp.call_arguments) + if rs_replacements is None: + self.assertEqual(expected_interp_arguments, interp.call_arguments) numpy.testing.assert_equal(expected_result, result) output_expected = numpy.empty_like(expected_result) From 6b6ba18dec76ca080585171a1122f497ae2b0d75 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 29 Jun 2022 17:11:29 +0200 Subject: [PATCH 13/35] Fix more tests by relaxing the assumptions --- tests/pulses/multi_channel_pulse_template_tests.py | 2 +- tests/pulses/point_pulse_template_tests.py | 8 ++++---- tests/pulses/table_pulse_template_tests.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 114c6fee9..7de38ae72 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -92,7 +92,7 @@ def test_instantiation_duration_check(self): amcpt = AtomicMultiChannelPulseTemplate(*subtemplates, duration=True) self.assertIs(amcpt.duration, subtemplates[0].duration) - with self.assertRaisesRegex(ValueError, 'duration'): + with self.assertRaisesRegex(ValueError, '[dD]uration'): amcpt.build_waveform(parameters=dict(t_1=3, t_2=3, t_3=3), channel_mapping={ch: ch for ch in 'c1 c2 c3'.split()}) diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 9316379ed..3277319df 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -154,8 +154,8 @@ def test_build_waveform_multi_channel_same(self): (1., 0., HoldInterpolationStrategy()), (1.1, 21., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._sub_waveforms[0], expected_1) - self.assertEqual(wf._sub_waveforms[1], expected_A) + self.assertEqual(wf.compare_key[0], expected_1) + self.assertEqual(wf.compare_key[1], expected_A) def test_build_waveform_multi_channel_vectorized(self): ppt = PointPulseTemplate([('t1', 'A'), @@ -173,8 +173,8 @@ def test_build_waveform_multi_channel_vectorized(self): (1., 0., HoldInterpolationStrategy()), (1.1, 20., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._sub_waveforms[0], expected_1) - self.assertEqual(wf._sub_waveforms[1], expected_A) + self.assertEqual(wf.compare_key[0], expected_1) + self.assertEqual(wf.compare_key[1], expected_A) def test_build_waveform_none_channel(self): ppt = PointPulseTemplate([('t1', 'A'), diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index b6c26b495..c77311656 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -700,7 +700,7 @@ def test_build_waveform_multi_channel(self): (5.1, 0, LinearInterpolationStrategy()))), ] - self.assertEqual(waveform._sub_waveforms, tuple(expected_waveforms)) + self.assertEqual(waveform.compare_key, tuple(expected_waveforms)) def test_build_waveform_none(self) -> None: table = TablePulseTemplate({0: [(0, 0), From f41f9ef98b3aaf059df6b83ed3149ee14e1bb769 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 29 Jun 2022 17:12:10 +0200 Subject: [PATCH 14/35] Some constant pulse template generalizations --- qupulse/pulses/constant_pulse_template.py | 27 +++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/qupulse/pulses/constant_pulse_template.py b/qupulse/pulses/constant_pulse_template.py index 5c63d0646..ea56283a9 100644 --- a/qupulse/pulses/constant_pulse_template.py +++ b/qupulse/pulses/constant_pulse_template.py @@ -10,6 +10,8 @@ import numbers from typing import Any, Dict, List, Optional, Set, Union +import cached_property + from qupulse._program import ProgramBuilder from qupulse._program.waveforms import ConstantWaveform from qupulse.expressions import ExpressionScalar @@ -20,6 +22,7 @@ Loop, MeasurementDeclaration, PulseTemplate, Transformation, TransformingWaveform) +from qupulse.serialization import PulseRegistryType __all__ = ["ConstantPulseTemplate"] @@ -27,7 +30,9 @@ class ConstantPulseTemplate(AtomicPulseTemplate): # type: ignore def __init__(self, duration: float, amplitude_dict: Dict[str, Any], identifier: Optional[str] = None, - name: Optional[str] = None, measurements: Optional[List[MeasurementDeclaration]] = [], **kwargs: Any) -> None: + name: Optional[str] = None, measurements: Optional[List[MeasurementDeclaration]] = (), + registry: PulseRegistryType = None, + **kwargs: Any) -> None: """ A qupulse waveform representing a multi-channel pulse with constant values Args: @@ -39,12 +44,14 @@ def __init__(self, duration: float, amplitude_dict: Dict[str, Any], identifier: super().__init__(identifier=identifier, measurements=measurements, **kwargs) self._duration = ExpressionScalar(duration) - self._amplitude_dict = amplitude_dict + self._amplitude_dict = {channel: ExpressionScalar(value) for channel, value in amplitude_dict.items()} if name is None: name = 'constant_pulse' self._name = name + self._register(registry=registry) + def _as_expression(self): return self._amplitude_dict @@ -56,7 +63,7 @@ def build_sequence(self) -> None: def get_serialization_data(self) -> Any: data = super().get_serialization_data() - data.update({'name': self._name, 'duration': self._duration, '#amplitudes': self._amplitude_dict}) + data.update({'name': self._name, 'duration': self._duration, 'amplitude_dict': self._amplitude_dict}) return data @property @@ -64,10 +71,13 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: """Returns an expression giving the integral over the pulse.""" return {c: self.duration * self._amplitude_dict[c] for c in self._amplitude_dict} - @property + @cached_property.cached_property def parameter_names(self) -> Set[str]: """The set of names of parameters required to instantiate this PulseTemplate.""" - return set() + parameter_names = set(getattr(self._duration, 'variables', ())) + for value in self._amplitude_dict.values(): + parameter_names.update(getattr(value, 'variables', ())) + return parameter_names @property def is_interruptable(self) -> bool: @@ -78,7 +88,7 @@ def is_interruptable(self) -> bool: @property def duration(self) -> ExpressionScalar: """An expression for the duration of this PulseTemplate.""" - return (self._duration) + return self._duration @property def defined_channels(self) -> Set['ChannelID']: @@ -99,9 +109,12 @@ def build_waveform(self, for channel, value in self._amplitude_dict.items(): mapped_channel = channel_mapping[channel] if mapped_channel is not None: + evaluator = getattr(value, 'evaluate_in_scope', None) + if evaluator: + value = evaluator(parameters) constant_values[mapped_channel] = value if constant_values: - return ConstantWaveform.from_mapping(self.duration, constant_values) + return ConstantWaveform.from_mapping(self.duration.evaluate_in_scope(parameters), constant_values) else: return None From cbe8532051d8a8fbd7193d066cd69c8dcaea1573 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 29 Jun 2022 17:51:32 +0200 Subject: [PATCH 15/35] Skip test in python 3.7 --- tests/pulses/loop_pulse_template_tests.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pulses/loop_pulse_template_tests.py b/tests/pulses/loop_pulse_template_tests.py index 2716004e4..c22c46a4e 100644 --- a/tests/pulses/loop_pulse_template_tests.py +++ b/tests/pulses/loop_pulse_template_tests.py @@ -1,3 +1,4 @@ +import sys import unittest from unittest import mock @@ -337,6 +338,7 @@ def test_create_program_body_none(self) -> None: self.assertEqual(1, program.repetition_count) self.assertEqual([], list(program.children)) + @unittest.skipIf(sys.version_info.minor < 8, "Python 3.7 does not support changing mock call args") def test_create_program(self) -> None: dt = DummyPulseTemplate(parameter_names={'i'}, waveform=DummyWaveform(duration=4.0, defined_channels={'A'}), From f9d339a14ec9d23602ab84ffdec33c9930bc1c2e Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 29 Jun 2022 18:05:02 +0200 Subject: [PATCH 16/35] Fix wrong legacy import --- qupulse/pulses/constant_pulse_template.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/qupulse/pulses/constant_pulse_template.py b/qupulse/pulses/constant_pulse_template.py index ea56283a9..f14afd61c 100644 --- a/qupulse/pulses/constant_pulse_template.py +++ b/qupulse/pulses/constant_pulse_template.py @@ -10,8 +10,7 @@ import numbers from typing import Any, Dict, List, Optional, Set, Union -import cached_property - +from qupulse.utils import cached_property from qupulse._program import ProgramBuilder from qupulse._program.waveforms import ConstantWaveform from qupulse.expressions import ExpressionScalar @@ -71,7 +70,7 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: """Returns an expression giving the integral over the pulse.""" return {c: self.duration * self._amplitude_dict[c] for c in self._amplitude_dict} - @cached_property.cached_property + @cached_property def parameter_names(self) -> Set[str]: """The set of names of parameters required to instantiate this PulseTemplate.""" parameter_names = set(getattr(self._duration, 'variables', ())) From 5947765f5ead32a0b2c9701a1e79dfb64963eb4b Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 30 Jun 2022 13:23:48 +0200 Subject: [PATCH 17/35] Move Loop -> to_single_waveform code --- qupulse/_program/_loop.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index a10835300..9c6e4add6 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -495,7 +495,21 @@ def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]: return self.repetition_count, tuple(child.get_duration_structure() for child in self) def to_single_waveform(self) -> Waveform: - return to_waveform(self) + if self.is_leaf(): + if self.repetition_count == 1: + return self.waveform + else: + return RepetitionWaveform(self.waveform, self.repetition_count) + else: + if len(self) == 1: + sequenced_waveform = to_waveform(cast(Loop, self[0])) + else: + sequenced_waveform = SequenceWaveform([to_waveform(cast(Loop, sub_program)) + for sub_program in self]) + if self.repetition_count > 1: + return RepetitionWaveform(sequenced_waveform, self.repetition_count) + else: + return sequenced_waveform def append_leaf(self, waveform: Waveform, measurements: Optional[Sequence[MeasurementWindow]] = None, @@ -527,21 +541,7 @@ def __init__(self, channel_sets): def to_waveform(program: Loop) -> Waveform: - if program.is_leaf(): - if program.repetition_count == 1: - return program.waveform - else: - return RepetitionWaveform(program.waveform, program.repetition_count) - else: - if len(program) == 1: - sequenced_waveform = to_waveform(cast(Loop, program[0])) - else: - sequenced_waveform = SequenceWaveform([to_waveform(cast(Loop, sub_program)) - for sub_program in program]) - if program.repetition_count > 1: - return RepetitionWaveform(sequenced_waveform, program.repetition_count) - else: - return sequenced_waveform + return program.to_single_waveform() class _CompatibilityLevel(Enum): From 0570c793d75f13db50ceb0e0b677da1ad738b1b0 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 30 Jun 2022 13:24:18 +0200 Subject: [PATCH 18/35] Make some create_program tests ProgramBuilder aware --- tests/pulses/pulse_template_tests.py | 39 +++++++++++++++------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index f6b97a63c..efcff5cf6 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -11,7 +11,7 @@ from qupulse.pulses.parameters import Parameter, ConstantParameter, ParameterNotProvidedException from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform from qupulse._program._loop import Loop -from qupulse._program import ProgramBuilder +from qupulse._program import ProgramBuilder, default_program_builder from qupulse._program.transformation import Transformation from qupulse._program.waveforms import TransformingWaveform @@ -170,13 +170,14 @@ def test_create_program(self) -> None: with mock.patch.object(template, '_create_program', wraps=get_appending_internal_create_program(dummy_waveform)) as _create_program: - program = template.create_program(parameters=parameters, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - to_single_waveform=to_single_waveform, - global_transformation=global_transformation, - volatile=volatile) - _create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=program) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=default_program_builder()) as _default_program_builder: + program = template.create_program(parameters=parameters, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + to_single_waveform=to_single_waveform, + global_transformation=global_transformation, + volatile=volatile) + _create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=_default_program_builder.return_value) self.assertEqual(expected_program, program) self.assertEqual(previos_measurement_mapping, measurement_mapping) self.assertEqual(previous_channel_mapping, channel_mapping) @@ -281,8 +282,9 @@ def test_create_program_defaults(self) -> None: with mock.patch.object(template, '_internal_create_program', wraps=get_appending_internal_create_program(dummy_waveform, True)) as _internal_create_program: - program = template.create_program() - _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=program) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=default_program_builder()) as _default_program_builder: + program = template.create_program() + _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=_default_program_builder.return_value) self.assertEqual(expected_program, program) def test_create_program_channel_mapping(self): @@ -295,9 +297,9 @@ def test_create_program_channel_mapping(self): to_single_waveform=set()) with mock.patch.object(template, '_internal_create_program') as _internal_create_program: - template.create_program(channel_mapping={'A': 'C'}) - - _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=Loop()) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=default_program_builder()) as _default_program_builder: + template.create_program(channel_mapping={'A': 'C'}) + _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=_default_program_builder.return_value) def test_create_program_none(self) -> None: template = PulseTemplateStub(defined_channels={'A'}, parameter_names={'foo'}) @@ -318,11 +320,12 @@ def test_create_program_none(self) -> None: with mock.patch.object(template, '_internal_create_program') as _internal_create_program: - program = template.create_program(parameters=parameters, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - volatile=volatile) - _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=Loop()) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=default_program_builder()) as _default_program_builder: + program = template.create_program(parameters=parameters, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + volatile=volatile) + _internal_create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=_default_program_builder.return_value) self.assertIsNone(program) def test_matmul(self): From 7a2fceebb36fc75523e47e15e48303bb3dad9d00 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 12 Jul 2022 08:53:21 +0200 Subject: [PATCH 19/35] Make Program runtime_chackable --- qupulse/_program/__init__.py | 11 +++++++++-- qupulse/pulses/plotting.py | 8 +++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/qupulse/_program/__init__.py b/qupulse/_program/__init__.py index e865047c9..a0b20a306 100644 --- a/qupulse/_program/__init__.py +++ b/qupulse/_program/__init__.py @@ -5,18 +5,22 @@ import numpy as np from qupulse._program.waveforms import Waveform -from qupulse.utils.types import MeasurementWindow +from qupulse.utils.types import MeasurementWindow, TimeType from qupulse._program.volatile import VolatileRepetitionCount try: - from typing import Protocol + from typing import Protocol, runtime_checkable except ImportError: Protocol = object + def runtime_checkable(cls): + return cls + RepetitionCount = Union[int, VolatileRepetitionCount] +@runtime_checkable class Program(Protocol): """This protocol is used to inspect and or manipulate programs""" @@ -26,6 +30,9 @@ def to_single_waveform(self) -> Waveform: def get_measurement_windows(self) -> Mapping[str, np.ndarray]: pass + @property + def duration(self) -> TimeType: + raise NotImplementedError() class ProgramBuilder(Protocol): """This protocol is used by PulseTemplate to build the program.""" diff --git a/qupulse/pulses/plotting.py b/qupulse/pulses/plotting.py index ecde51d44..cc8f4fbf5 100644 --- a/qupulse/pulses/plotting.py +++ b/qupulse/pulses/plotting.py @@ -14,7 +14,7 @@ import operator import itertools -from qupulse._program import waveforms +from qupulse._program import waveforms, Program from qupulse.utils.types import ChannelID, MeasurementWindow, has_type_interface from qupulse.pulses.pulse_template import PulseTemplate from qupulse.pulses.parameters import Parameter @@ -52,6 +52,12 @@ def render(program: Union[Loop], """ if has_type_interface(program, Loop): waveform, measurements = _render_loop(program, render_measurements=render_measurements) + elif isinstance(program, Program): + waveform = program.to_single_waveform() + measurements = program.get_measurement_windows() + measurements = [(name, begin, length) + for name, (begins, lengths) in measurements.items() + for begin, length in zip(begins, lengths)] else: raise ValueError('Cannot render an object of type %r' % type(program), program) From 52c13d49451623f3c4beddf0f023c5fbf8706668 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 12 Jul 2022 08:55:48 +0200 Subject: [PATCH 20/35] Fix SubsetWaveform constant_value_dict --- qupulse/_program/waveforms.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index d28913132..052793ca8 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -984,9 +984,14 @@ def unsafe_sample(self, return self.inner_waveform.unsafe_sample(channel, sample_times, output_array) def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: - d = self._inner_waveform.constant_value_dict() - if d is not None: - return {ch: d[ch] for ch in self._channel_subset} + constant_values = {} + for ch in self.defined_channels: + value = self._inner_waveform.constant_value(ch) + if value is None: + return + else: + constant_values[ch] = value + return constant_values def constant_value(self, channel: ChannelID) -> Optional[float]: if channel not in self._channel_subset: From 375f40499754e2d67f5c7824ee1b94fa1f2fc088 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 12 Jul 2022 08:56:32 +0200 Subject: [PATCH 21/35] Move compability and waveform trafo code to Loop class --- qupulse/_program/__init__.py | 5 ++ qupulse/_program/_loop.py | 90 +++++++++++++++++++----------------- 2 files changed, 53 insertions(+), 42 deletions(-) diff --git a/qupulse/_program/__init__.py b/qupulse/_program/__init__.py index a0b20a306..26ce85ea8 100644 --- a/qupulse/_program/__init__.py +++ b/qupulse/_program/__init__.py @@ -34,6 +34,11 @@ def get_measurement_windows(self) -> Mapping[str, np.ndarray]: def duration(self) -> TimeType: raise NotImplementedError() + def make_compatible_inplace(self): + # TODO: rename? + pass + + class ProgramBuilder(Protocol): """This protocol is used by PulseTemplate to build the program.""" diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index a10835300..b39c32111 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -495,7 +495,21 @@ def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]: return self.repetition_count, tuple(child.get_duration_structure() for child in self) def to_single_waveform(self) -> Waveform: - return to_waveform(self) + if self.is_leaf(): + if self.repetition_count == 1: + return self.waveform + else: + return RepetitionWaveform(self.waveform, self.repetition_count) + else: + if len(self) == 1: + sequenced_waveform = to_waveform(cast(Loop, self[0])) + else: + sequenced_waveform = SequenceWaveform([to_waveform(cast(Loop, sub_program)) + for sub_program in self]) + if self.repetition_count > 1: + return RepetitionWaveform(sequenced_waveform, self.repetition_count) + else: + return sequenced_waveform def append_leaf(self, waveform: Waveform, measurements: Optional[Sequence[MeasurementWindow]] = None, @@ -520,6 +534,37 @@ def reverse_inplace(self): for name, begin, length in self._measurements ] + def make_compatible_inplace(self, minimal_waveform_length: int, waveform_quantum: int, sample_rate: TimeType): + program = self + comp_level = _is_compatible(program, + min_len=minimal_waveform_length, + quantum=waveform_quantum, + sample_rate=sample_rate) + if comp_level == _CompatibilityLevel.incompatible_fraction: + raise ValueError( + 'The program duration in samples {} is not an integer'.format(program.duration * sample_rate)) + if comp_level == _CompatibilityLevel.incompatible_too_short: + raise ValueError('The program is too short to be a valid waveform. \n' + ' program duration in samples: {} \n' + ' minimal length: {}'.format(program.duration * sample_rate, minimal_waveform_length)) + if comp_level == _CompatibilityLevel.incompatible_quantum: + raise ValueError('The program duration in samples {} ' + 'is not a multiple of quantum {}'.format(program.duration * sample_rate, waveform_quantum)) + + elif comp_level == _CompatibilityLevel.action_required: + warnings.warn( + "qupulse will now concatenate waveforms to make the pulse/program compatible with the chosen AWG." + " This might take some time. If you need this pulse more often it makes sense to write it in a " + "way which is more AWG friendly.", MakeCompatibleWarning) + + _make_compatible(program, + min_len=minimal_waveform_length, + quantum=waveform_quantum, + sample_rate=sample_rate) + + else: + assert comp_level == _CompatibilityLevel.compatible + class ChannelSplit(Exception): def __init__(self, channel_sets): @@ -527,21 +572,7 @@ def __init__(self, channel_sets): def to_waveform(program: Loop) -> Waveform: - if program.is_leaf(): - if program.repetition_count == 1: - return program.waveform - else: - return RepetitionWaveform(program.waveform, program.repetition_count) - else: - if len(program) == 1: - sequenced_waveform = to_waveform(cast(Loop, program[0])) - else: - sequenced_waveform = SequenceWaveform([to_waveform(cast(Loop, sub_program)) - for sub_program in program]) - if program.repetition_count > 1: - return RepetitionWaveform(sequenced_waveform, program.repetition_count) - else: - return sequenced_waveform + return program.to_single_waveform() class _CompatibilityLevel(Enum): @@ -622,32 +653,7 @@ def _make_compatible(program: Loop, min_len: int, quantum: int, sample_rate: Tim def make_compatible(program: Loop, minimal_waveform_length: int, waveform_quantum: int, sample_rate: TimeType): """ check program for compatibility to AWG requirements, make it compatible if necessary and possible""" - comp_level = _is_compatible(program, - min_len=minimal_waveform_length, - quantum=waveform_quantum, - sample_rate=sample_rate) - if comp_level == _CompatibilityLevel.incompatible_fraction: - raise ValueError('The program duration in samples {} is not an integer'.format(program.duration * sample_rate)) - if comp_level == _CompatibilityLevel.incompatible_too_short: - raise ValueError('The program is too short to be a valid waveform. \n' - ' program duration in samples: {} \n' - ' minimal length: {}'.format(program.duration * sample_rate, minimal_waveform_length)) - if comp_level == _CompatibilityLevel.incompatible_quantum: - raise ValueError('The program duration in samples {} ' - 'is not a multiple of quantum {}'.format(program.duration * sample_rate, waveform_quantum)) - - elif comp_level == _CompatibilityLevel.action_required: - warnings.warn("qupulse will now concatenate waveforms to make the pulse/program compatible with the chosen AWG." - " This might take some time. If you need this pulse more often it makes sense to write it in a " - "way which is more AWG friendly.", MakeCompatibleWarning) - - _make_compatible(program, - min_len=minimal_waveform_length, - quantum=waveform_quantum, - sample_rate=sample_rate) - - else: - assert comp_level == _CompatibilityLevel.compatible + program.make_compatible_inplace(minimal_waveform_length, waveform_quantum, sample_rate) def roll_constant_waveforms(program: Loop, minimal_waveform_quanta: int, waveform_quantum: int, sample_rate: TimeType): From b893752255d83c2c931082d4995a4761aeed13af Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 12 Jul 2022 08:57:00 +0200 Subject: [PATCH 22/35] Add more rust waveforms --- qupulse/_program/waveforms.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 052793ca8..7e67640e1 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -1231,11 +1231,24 @@ def reversed(self) -> 'Waveform': PyTableWaveform = TableWaveform PyConstantWaveform = ConstantWaveform PyMultiChannelWaveform = MultiChannelWaveform + PyRepetitionWaveform = RepetitionWaveform + PySequenceWaveform = SequenceWaveform + PyArithmeticWaveform = ArithmeticWaveform + PySubsetWaveform = SubsetWaveform TableWaveform = rs_replacements.waveforms.TableWaveform ConstantWaveform = rs_replacements.waveforms.ConstantWaveform MultiChannelWaveform = rs_replacements.waveforms.MultiChannelWaveform + RepetitionWaveform = rs_replacements.waveforms.RepetitionWaveform + SequenceWaveform = rs_replacements.waveforms.SequenceWaveform + ArithmeticWaveform = rs_replacements.waveforms.ArithmeticWaveform + SubsetWaveform = rs_replacements.waveforms.SubsetWaveform Waveform.register(TableWaveform) Waveform.register(ConstantWaveform) Waveform.register(MultiChannelWaveform) + Waveform.register(TableWaveform) + Waveform.register(RepetitionWaveform) + Waveform.register(SequenceWaveform) + Waveform.register(ArithmeticWaveform) + Waveform.register(SubsetWaveform) From 8cd42d7abaf6a587b24d480a474a7c9f77d9c01d Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 12 Jul 2022 08:57:10 +0200 Subject: [PATCH 23/35] Fix tests --- tests/_program/loop_tests.py | 54 +++++++--------- tests/_program/waveforms_tests.py | 61 +++++++++++++------ tests/pulses/constant_pulse_template_tests.py | 2 +- tests/pulses/pulse_template_tests.py | 26 ++++---- tests/pulses/sequence_pulse_template_tests.py | 2 +- 5 files changed, 80 insertions(+), 65 deletions(-) diff --git a/tests/_program/loop_tests.py b/tests/_program/loop_tests.py index 6776e809c..2f674a743 100644 --- a/tests/_program/loop_tests.py +++ b/tests/_program/loop_tests.py @@ -407,29 +407,22 @@ def test_make_compatible_partial_unroll(self): program = Loop(children=[Loop(waveform=wf1, repetition_count=2), Loop(waveform=wf2)]) + expected_program = Loop(children=[ + Loop(waveform=RepetitionWaveform(wf1, 2)), + Loop(waveform=wf2) + ]) _make_compatible(program, min_len=1, quantum=1, sample_rate=TimeType.from_float(1.)) - - self.assertIsNone(program.waveform) - self.assertEqual(len(program), 2) - self.assertIsInstance(program[0].waveform, RepetitionWaveform) - self.assertIs(program[0].waveform._body, wf1) - self.assertEqual(program[0].waveform._repetition_count, 2) - self.assertIs(program[1].waveform, wf2) + self.assertEqual(expected_program, program) program = Loop(children=[Loop(waveform=wf1, repetition_count=2), - Loop(waveform=wf2)], repetition_count=2) + Loop(waveform=wf2)], repetition_count=3) + expected_program = Loop(waveform=SequenceWaveform([ + RepetitionWaveform(wf1, 2), + wf2 + ]), repetition_count=3) _make_compatible(program, min_len=5, quantum=1, sample_rate=TimeType.from_float(1.)) - - self.assertIsInstance(program.waveform, SequenceWaveform) - self.assertEqual(list(program.children), []) - self.assertEqual(program.repetition_count, 2) - - self.assertEqual(len(program.waveform._sequenced_waveforms), 2) - self.assertIsInstance(program.waveform._sequenced_waveforms[0], RepetitionWaveform) - self.assertIs(program.waveform._sequenced_waveforms[0]._body, wf1) - self.assertEqual(program.waveform._sequenced_waveforms[0]._repetition_count, 2) - self.assertIs(program.waveform._sequenced_waveforms[1], wf2) + self.assertEqual(expected_program, program) def test_make_compatible_complete_unroll(self): wf1 = DummyWaveform(duration=1.5) @@ -438,21 +431,18 @@ def test_make_compatible_complete_unroll(self): program = Loop(children=[Loop(waveform=wf1, repetition_count=2), Loop(waveform=wf2, repetition_count=1)], repetition_count=2) - _make_compatible(program, min_len=5, quantum=10, sample_rate=TimeType.from_float(1.)) - - self.assertIsInstance(program.waveform, RepetitionWaveform) - self.assertEqual(list(program.children), []) - self.assertEqual(program.repetition_count, 1) + expected_program = Loop(repetition_count=1, + waveform=RepetitionWaveform( + body=SequenceWaveform([ + RepetitionWaveform(wf1, 2), + wf2 + ]), + repetition_count=2 + ) + ) - self.assertIsInstance(program.waveform, RepetitionWaveform) - - self.assertIsInstance(program.waveform._body, SequenceWaveform) - body_wf = program.waveform._body - self.assertEqual(len(body_wf._sequenced_waveforms), 2) - self.assertIsInstance(body_wf._sequenced_waveforms[0], RepetitionWaveform) - self.assertIs(body_wf._sequenced_waveforms[0]._body, wf1) - self.assertEqual(body_wf._sequenced_waveforms[0]._repetition_count, 2) - self.assertIs(body_wf._sequenced_waveforms[1], wf2) + _make_compatible(program, min_len=5, quantum=10, sample_rate=TimeType.from_float(1.)) + self.assertEqual(expected_program, program) def test_make_compatible(self): program = Loop() diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index ddb7a3844..207d33305 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -303,10 +303,10 @@ def __init__(self, *args, **kwargs): def test_init(self): body_wf = DummyWaveform() - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, OverflowError)): RepetitionWaveform(body_wf, -1) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, TypeError)): RepetitionWaveform(body_wf, 1.1) wf = RepetitionWaveform(body_wf, 3) @@ -320,9 +320,7 @@ def test_from_repetition_count(self): self.assertEqual(RepetitionWaveform(dwf, 3), RepetitionWaveform.from_repetition_count(dwf, 3)) cwf = ConstantWaveform(duration=3, amplitude=2.2, channel='A') - with mock.patch.object(ConstantWaveform, 'from_mapping', return_value=mock.sentinel) as from_mapping: - self.assertIs(from_mapping.return_value, RepetitionWaveform.from_repetition_count(cwf, 5)) - from_mapping.assert_called_once_with(15, {'A': 2.2}) + self.assertEqual(ConstantWaveform.from_mapping(15, {'A': 2.2}), RepetitionWaveform.from_repetition_count(cwf, 5)) def test_duration(self): wf = RepetitionWaveform(DummyWaveform(duration=2.2), 3) @@ -330,7 +328,7 @@ def test_duration(self): def test_defined_channels(self): body_wf = DummyWaveform(defined_channels={'a'}) - self.assertIs(RepetitionWaveform(body_wf, 2).defined_channels, body_wf.defined_channels) + self.assertEqual(RepetitionWaveform(body_wf, 2).defined_channels, body_wf.defined_channels) def test_compare_key(self): body_wf = DummyWaveform(defined_channels={'a'}) @@ -345,7 +343,7 @@ def test_unsafe_get_subset_for_channels(self): subset = RepetitionWaveform(body_wf, 3).get_subset_for_channels(chs) self.assertIsInstance(subset, RepetitionWaveform) self.assertIsInstance(subset._body, DummyWaveform) - self.assertIs(subset._body.defined_channels, chs) + self.assertEqual(subset._body.defined_channels, chs) self.assertEqual(subset._repetition_count, 3) def test_unsafe_sample(self): @@ -397,12 +395,12 @@ def test_init(self): swf1 = SequenceWaveform((dwf_ab, dwf_ab)) self.assertEqual(swf1.duration, 2*dwf_ab.duration) - self.assertEqual(len(swf1.compare_key), 2) + self.assertEqual(len(swf1.sequenced_waveforms), 2) swf2 = SequenceWaveform((swf1, dwf_ab)) self.assertEqual(swf2.duration, 3 * dwf_ab.duration) - self.assertEqual(len(swf2.compare_key), 2) + self.assertEqual(len(swf2.sequenced_waveforms), 2) def test_from_sequence(self): dwf = DummyWaveform(duration=1.1, defined_channels={'A'}) @@ -421,10 +419,10 @@ def test_from_sequence(self): cwf_3 = ConstantWaveform(duration=1.1, amplitude=3.3, channel='A') cwf_2_b = ConstantWaveform(duration=1.1, amplitude=2.2, channel='A') - with mock.patch.object(ConstantWaveform, 'from_mapping', return_value=mock.sentinel) as from_mapping: - new_constant = SequenceWaveform.from_sequence((cwf_2_a, cwf_2_b)) - self.assertIs(from_mapping.return_value, new_constant) - from_mapping.assert_called_once_with(2*TimeType.from_float(1.1), {'A': 2.2}) + new_constant = SequenceWaveform.from_sequence((cwf_2_a, cwf_2_b)) + expected_constant = ConstantWaveform.from_mapping(2*TimeType.from_float(1.1), {'A': 2.2}) + self.assertEqual(expected_constant, + new_constant) swf3 = SequenceWaveform.from_sequence((cwf_2_a, dwf)) self.assertEqual((cwf_2_a, dwf), swf3.sequenced_waveforms) @@ -436,6 +434,7 @@ def test_from_sequence(self): self.assertIsNone(swf3.constant_value('A')) assert_constant_consistent(self, swf3) + @unittest.skipIf(rs_replacements is not None, "sentinel based test do not work with rust extension") def test_sample_times_type(self) -> None: with mock.patch.object(DummyWaveform, 'unsafe_sample') as unsafe_sample_patch: dwfs = (DummyWaveform(duration=1.), @@ -480,12 +479,12 @@ def test_unsafe_get_subset_for_channels(self): sub_wf = wf.unsafe_get_subset_for_channels(subset) self.assertIsInstance(sub_wf, SequenceWaveform) - self.assertEqual(len(sub_wf.compare_key), 2) - self.assertEqual(sub_wf.compare_key[0].defined_channels, subset) - self.assertEqual(sub_wf.compare_key[1].defined_channels, subset) + self.assertEqual(len(sub_wf.sequenced_waveforms), 2) + self.assertEqual(sub_wf.sequenced_waveforms[0].defined_channels, subset) + self.assertEqual(sub_wf.sequenced_waveforms[1].defined_channels, subset) - self.assertEqual(sub_wf.compare_key[0].duration, TimeType.from_float(2.2)) - self.assertEqual(sub_wf.compare_key[1].duration, TimeType.from_float(3.3)) + self.assertEqual(sub_wf.sequenced_waveforms[0].duration, TimeType.from_float(2.2)) + self.assertEqual(sub_wf.sequenced_waveforms[1].duration, TimeType.from_float(3.3)) def test_repr(self): cwf_2_a = ConstantWaveform(duration=1.1, amplitude=2.2, channel='A') @@ -783,7 +782,7 @@ def test_simple_properties(self): self.assertIs(subset_wf.inner_waveform, inner_wf) self.assertEqual(subset_wf.compare_key, (frozenset(['a', 'c']), inner_wf)) - self.assertIs(subset_wf.duration, inner_wf.duration) + self.assertEqual(subset_wf.duration, inner_wf.duration) self.assertEqual(subset_wf.defined_channels, {'a', 'c'}) def test_get_subset_for_channels(self): @@ -798,7 +797,8 @@ def test_get_subset_for_channels(self): get_subset_for_channels.assert_called_once_with({'a'}) self.assertIs(subsetted, actual_subsetted) - def test_unsafe_sample(self): + @unittest.skipIf(rs_replacements is not None, "Test requires pure python.") + def test_unsafe_sample_pure(self): """Test perfect forwarding""" time = {'time'} output = {'output'} @@ -814,6 +814,27 @@ def test_unsafe_sample(self): self.assertIs(expected_data, actual_data) unsafe_sample.assert_called_once_with('g', time, output) + def test_unsafe_sample_pure(self): + """Test perfect forwarding""" + time = np.arange(0., 1., 17) + + output_values = np.sin(time + 1e-4) + sample_output = dict( + a=output_values + 3, + b=output_values + 9, + c=output_values + 17, + ) + + inner_wf = DummyWaveform(sample_output=sample_output, duration=2.) + + subset_wf = SubsetWaveform(inner_wf, {'a', 'c'}) + + for ch in 'ac': + output_place = np.full_like(time, np.nan) + output = subset_wf.unsafe_sample(ch, sample_times=time, output_array=output_place) + self.assertIs(output, output_place) + numpy.testing.assert_equal(sample_output[ch], output) + class ArithmeticWaveformTest(unittest.TestCase): def test_from_operator(self): diff --git a/tests/pulses/constant_pulse_template_tests.py b/tests/pulses/constant_pulse_template_tests.py index 63573de4e..7a015b26e 100644 --- a/tests/pulses/constant_pulse_template_tests.py +++ b/tests/pulses/constant_pulse_template_tests.py @@ -45,7 +45,7 @@ def test_regression_duration_conversion(self): for duration_in_samples in [64, 936320, 24615392]: p = ConstantPulseTemplate(duration_in_samples / 2.4, {'a': 0}) number_of_samples = p.create_program().duration * 2.4 - make_compatible(p.create_program(), 8, 8, 2.4) + make_compatible(p.create_program(), 8, 8, qupulse.utils.types.TimeType.from_float(2.4)) self.assertEqual(number_of_samples.denominator, 1) p2 = ConstantPulseTemplate((duration_in_samples + 1) / 2.4, {'a': 0}) diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index f6b97a63c..87fb727d8 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -11,7 +11,7 @@ from qupulse.pulses.parameters import Parameter, ConstantParameter, ParameterNotProvidedException from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform from qupulse._program._loop import Loop -from qupulse._program import ProgramBuilder +from qupulse._program import ProgramBuilder, default_program_builder from qupulse._program.transformation import Transformation from qupulse._program.waveforms import TransformingWaveform @@ -226,13 +226,15 @@ def test__create_program_single_waveform(self): scope = DictScope.from_kwargs(a=1., b=2., volatile={'a'}) measurement_mapping = {'M': 'N'} channel_mapping = {'B': 'A'} - parent_loop = Loop() + + program_builder = default_program_builder() + inner_program_builder = default_program_builder() wf = DummyWaveform() single_waveform = DummyWaveform() measurements = [('m', 0, 1), ('n', 0.1, .9)] - expected_inner_program = Loop(children=[Loop(waveform=wf, measurements=measurements)]) + expected_inner_program = Loop(waveform=wf, measurements=measurements) appending_create_program = get_appending_internal_create_program(wf, measurements=measurements, @@ -249,22 +251,24 @@ def test__create_program_single_waveform(self): wraps=appending_create_program) as _internal_create_program: with mock.patch('qupulse._program._loop.to_waveform', return_value=single_waveform) as to_waveform: - template._create_program(scope=scope, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, - global_transformation=global_transformation, - to_single_waveform=to_single_waveform, - parent_loop=parent_loop) + with mock.patch('qupulse.pulses.pulse_template.default_program_builder', + return_value=inner_program_builder): + template._create_program(scope=scope, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, + global_transformation=global_transformation, + to_single_waveform=to_single_waveform, + parent_loop=program_builder) _internal_create_program.assert_called_once_with(scope=scope, measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=None, to_single_waveform=to_single_waveform, - parent_loop=expected_inner_program) + parent_loop=inner_program_builder) to_waveform.assert_called_once_with(expected_inner_program) - self.assertEqual(expected_program, parent_loop) + self.assertEqual(expected_program, program_builder.to_program()) def test_create_program_defaults(self) -> None: template = PulseTemplateStub(defined_channels={'A', 'B'}, parameter_names={'foo'}, measurement_names={'hugo', 'foo'}) diff --git a/tests/pulses/sequence_pulse_template_tests.py b/tests/pulses/sequence_pulse_template_tests.py index e224b2ce3..b1aa653a3 100644 --- a/tests/pulses/sequence_pulse_template_tests.py +++ b/tests/pulses/sequence_pulse_template_tests.py @@ -76,7 +76,7 @@ def test_build_waveform(self): self.assertIs(pt.build_waveform_calls[0][0], parameters) self.assertIsInstance(wf, SequenceWaveform) - for wfa, wfb in zip(wf.compare_key, wfs): + for wfa, wfb in zip(wf.sequenced_waveforms, wfs): self.assertIs(wfa, wfb) def test_identifier(self) -> None: From 39a55c7510c3867217c47b19e3c476ffa96fdbe5 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 6 Sep 2022 22:27:46 +0200 Subject: [PATCH 24/35] Add expressions and scopes from rust extension --- qupulse/expressions.py | 26 +++++++++++++++++++++++--- qupulse/parameter_scope.py | 15 +++++++++++++++ tests/expression_tests.py | 24 ++++++++++++++++++------ 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/qupulse/expressions.py b/qupulse/expressions.py index 326f65491..0749da29e 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Union, Sequence, Callable, TypeVar, Type, Mapping from numbers import Number import warnings -import functools +import inspect import array import itertools @@ -18,6 +18,11 @@ get_most_simple_representation, get_variables, evaluate_lamdified_exact_rational from qupulse.utils.types import TimeType +try: + import qupulse_rs.replacements +except ImportError: + qupulse_rs = None + __all__ = ["Expression", "ExpressionVariableMissingException", "ExpressionScalar", "ExpressionVector", "ExpressionLike"] @@ -219,7 +224,7 @@ def __eq__(self, other): other = Expression.make(other) except (ValueError, TypeError): return NotImplemented - if isinstance(other, ExpressionScalar): + if type(other).__name__ == 'ExpressionScalar': return self._expression_shape in ((), (1,)) and self._expression_items[0] == other.sympified_expression else: return self._expression_shape == other._expression_shape and \ @@ -435,7 +440,7 @@ def __str__(self) -> str: str(self.expression), self.variable) -class NonNumericEvaluation(Exception): +class NonNumericEvaluation(TypeError): """An exception that is raised if the result of evaluate_numeric is not a number. See also: @@ -462,3 +467,18 @@ def __str__(self) -> str: ExpressionLike = TypeVar('ExpressionLike', str, Number, sympy.Expr, ExpressionScalar) + + +if qupulse_rs: + RsExpressionScalar = qupulse_rs.replacements.ExpressionScalar + PyExpressionScalar = ExpressionScalar + + class ExpressionScalar(PyExpressionScalar): + def __new__(cls, *args, **kwargs): + if qupulse_rs and cls.__name__ == 'ExpressionScalar': + try: + return RsExpressionScalar(*args, **kwargs) + except (ValueError, TypeError, RuntimeError): + pass + return PyExpressionScalar.__new__(cls) + diff --git a/qupulse/parameter_scope.py b/qupulse/parameter_scope.py index a59f94a0b..2f174ba17 100644 --- a/qupulse/parameter_scope.py +++ b/qupulse/parameter_scope.py @@ -9,6 +9,11 @@ from qupulse.expressions import Expression, ExpressionVariableMissingException from qupulse.utils.types import FrozenMapping, FrozenDict +try: + import qupulse_rs.replacements.parameter_scopes +except ImportError: + qupulse_rs = None + class Scope(Mapping[str, Number]): """Abstract parameter lookup. Scopes are immutable. Internally it holds all dependencies of parameters and keeps @@ -319,3 +324,13 @@ def __str__(self) -> str: class NonVolatileChange(RuntimeWarning): """Raised if a non volatile parameter is updated""" + + +if qupulse_rs: + PyDictScope = DictScope + PyMappedScope = MappedScope + PyJointScope = JointScope + + DictScope = qupulse_rs.replacements.parameter_scopes.DictScope + MappedScope = qupulse_rs.replacements.parameter_scopes.MappedScope + JointScope = qupulse_rs.replacements.parameter_scopes.JointScope diff --git a/tests/expression_tests.py b/tests/expression_tests.py index eb384fbc0..43262b267 100644 --- a/tests/expression_tests.py +++ b/tests/expression_tests.py @@ -5,9 +5,15 @@ import sympy.abc from sympy import sympify, Eq -from qupulse.expressions import Expression, ExpressionVariableMissingException, NonNumericEvaluation, ExpressionScalar, ExpressionVector +from qupulse.expressions import Expression, ExpressionVariableMissingException, NonNumericEvaluation, ExpressionScalar,\ + ExpressionVector, qupulse_rs from qupulse.utils.types import TimeType +try: + from qupulse.expressions import PyExpressionScalar +except ImportError: + PyExpressionScalar = ExpressionScalar + class ExpressionTests(unittest.TestCase): def test_make(self): self.assertTrue(Expression.make('a') == 'a') @@ -16,7 +22,10 @@ def test_make(self): self.assertIsInstance(Expression.make([1, 'a']), ExpressionVector) - self.assertIsInstance(ExpressionScalar.make('a'), ExpressionScalar) + if qupulse_rs: + self.assertEqual(type(ExpressionScalar.make('a')).__name__, 'ExpressionScalar') + else: + self.assertIsInstance(ExpressionScalar.make('a'), ExpressionScalar) self.assertIsInstance(ExpressionVector.make(['a']), ExpressionVector) @@ -146,7 +155,7 @@ def test_evaluate_numeric(self) -> None: } self.assertEqual(2 * 1.5 - 7, e.evaluate_numeric(**params)) - with self.assertRaises(NonNumericEvaluation): + with self.assertRaises((NonNumericEvaluation, TypeError)): params['a'] = sympify('h') e.evaluate_numeric(**params) @@ -230,7 +239,7 @@ def test_evaluate_numeric_without_numpy(self): 'b': sympify('k'), 'c': -7 } - with self.assertRaises(NonNumericEvaluation): + with self.assertRaises(TypeError): e.evaluate_numeric(**params) def test_evaluate_symbolic(self): @@ -240,7 +249,7 @@ def test_evaluate_symbolic(self): 'c': -7 } result = e.evaluate_symbolic(params) - expected = ExpressionScalar('d*b-7') + expected = PyExpressionScalar('d*b-7') self.assertEqual(result, expected) def test_variables(self) -> None: @@ -287,7 +296,10 @@ def test_repr_original_expression_is_sympy(self): def test_str(self): s = 'a * b' e = ExpressionScalar(s) - self.assertEqual('a*b', str(e)) + if qupulse_rs is None: + self.assertEqual('a*b', str(e)) + else: + self.assertEqual(s, str(e)) def test_original_expression(self): s = 'a * b' From d8d70672989f2628a2b5bce7dc644ab044565a13 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 6 Sep 2022 22:29:20 +0200 Subject: [PATCH 25/35] Use equality oeprator semantics of TimeType.__value for retry on NotImplemented --- qupulse/utils/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/qupulse/utils/types.py b/qupulse/utils/types.py index 81ecf8ef5..2bd4e5c39 100644 --- a/qupulse/utils/types.py +++ b/qupulse/utils/types.py @@ -20,6 +20,7 @@ try: from qupulse_rs.qupulse_rs import TimeType as RsTimeType + numbers.Rational.register(RsTimeType) except ImportError: RsTimeType = None @@ -239,7 +240,7 @@ def __gt__(self, other): def __eq__(self, other): if type(other) == type(self): - return self._value.__eq__(other._value) + return self._value == other._value else: return self._value == other From ac9a85a560377cfb07a6c7d2a3142797aa1a2cea Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 6 Sep 2022 22:29:49 +0200 Subject: [PATCH 26/35] Use duck typing for AnonymousSerializable serialization --- qupulse/serialization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/qupulse/serialization.py b/qupulse/serialization.py index 3825057ed..b441053cb 100644 --- a/qupulse/serialization.py +++ b/qupulse/serialization.py @@ -30,6 +30,7 @@ import gc import importlib import warnings +from typing import Protocol, runtime_checkable from contextlib import contextmanager from qupulse.utils.types import DocStringABCMeta, FrozenDict @@ -1064,7 +1065,7 @@ def default(self, o: Any) -> Any: else: return o.get_serialization_data() - elif isinstance(o, AnonymousSerializable): + elif hasattr(o, 'get_serialization_data'): return o.get_serialization_data() elif type(o) is set: From 148b43d87ad5f1359d14a31cd54ac48e6199cc28 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 7 Sep 2022 11:06:07 +0200 Subject: [PATCH 27/35] Add matplotlib to test requirements --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 751ed2f7e..bfec25460 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,6 +33,7 @@ test_suite = tests tests = pytest pytest_benchmark + matplotlib docs = sphinx>=4 nbsphinx From 3881b2834818a2784fe6ae131c4a1bfacd03248c Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 16 Sep 2022 13:33:26 +0200 Subject: [PATCH 28/35] rework replacement code --- qupulse/_program/transformation.py | 21 ++++- qupulse/_program/waveforms.py | 41 +++------ qupulse/expressions.py | 8 +- qupulse/parameter_scope.py | 17 ++-- qupulse/serialization.py | 2 +- qupulse/utils/types.py | 17 ++++ tests/_program/transformation_tests.py | 114 ++++++++++++++++--------- tests/_program/waveforms_tests.py | 10 +-- 8 files changed, 141 insertions(+), 89 deletions(-) diff --git a/qupulse/_program/transformation.py b/qupulse/_program/transformation.py index da3353fd9..e555caed2 100644 --- a/qupulse/_program/transformation.py +++ b/qupulse/_program/transformation.py @@ -6,7 +6,15 @@ from qupulse import ChannelID from qupulse.comparable import Comparable -from qupulse.utils.types import SingletonABCMeta +from qupulse.utils.types import SingletonABCMeta, use_rs_replacements + +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + transformation_rs = None +else: + from qupulse_rs.replacements import transformation as transformation_rs class Transformation(Comparable): @@ -321,4 +329,13 @@ def chain_transformations(*transformations: Transformation) -> Transformation: elif len(parsed_transformations) == 1: return parsed_transformations[0] else: - return ChainedTransformation(*parsed_transformations) \ No newline at end of file + return ChainedTransformation(*parsed_transformations) + + + +if transformation_rs: + use_rs_replacements(globals(), transformation_rs, Transformation) + + py_chain_transformations = chain_transformations + rs_chain_transformations = transformation_rs.chain_transformations + chain_transformations = rs_chain_transformations diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 7e67640e1..a773a9591 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -22,15 +22,17 @@ from qupulse.expressions import ExpressionScalar from qupulse.pulses.interpolation import InterpolationStrategy from qupulse.utils import checked_int_cast, isclose -from qupulse.utils.types import TimeType, time_from_float, FrozenDict +from qupulse.utils.types import TimeType, time_from_float, FrozenDict, use_rs_replacements from qupulse._program.transformation import Transformation from qupulse.utils import pairwise try: - import qupulse_rs.qupulse_rs - rs_replacements = qupulse_rs.qupulse_rs.replacements -except (ImportError, AttributeError): - rs_replacements = None + import qupulse_rs +except ImportError: + qupulse_rs = None + waveforms_rs = None +else: + from qupulse_rs.replacements import waveforms as waveforms_rs __all__ = ["Waveform", "TableWaveform", "TableWaveformEntry", "FunctionWaveform", "SequenceWaveform", @@ -1227,28 +1229,7 @@ def reversed(self) -> 'Waveform': return self._inner -if rs_replacements is not None: - PyTableWaveform = TableWaveform - PyConstantWaveform = ConstantWaveform - PyMultiChannelWaveform = MultiChannelWaveform - PyRepetitionWaveform = RepetitionWaveform - PySequenceWaveform = SequenceWaveform - PyArithmeticWaveform = ArithmeticWaveform - PySubsetWaveform = SubsetWaveform - - TableWaveform = rs_replacements.waveforms.TableWaveform - ConstantWaveform = rs_replacements.waveforms.ConstantWaveform - MultiChannelWaveform = rs_replacements.waveforms.MultiChannelWaveform - RepetitionWaveform = rs_replacements.waveforms.RepetitionWaveform - SequenceWaveform = rs_replacements.waveforms.SequenceWaveform - ArithmeticWaveform = rs_replacements.waveforms.ArithmeticWaveform - SubsetWaveform = rs_replacements.waveforms.SubsetWaveform - - Waveform.register(TableWaveform) - Waveform.register(ConstantWaveform) - Waveform.register(MultiChannelWaveform) - Waveform.register(TableWaveform) - Waveform.register(RepetitionWaveform) - Waveform.register(SequenceWaveform) - Waveform.register(ArithmeticWaveform) - Waveform.register(SubsetWaveform) + + +if waveforms_rs: + use_rs_replacements(globals(), waveforms_rs, Waveform) diff --git a/qupulse/expressions.py b/qupulse/expressions.py index 0749da29e..3cd96a1ee 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -19,9 +19,12 @@ from qupulse.utils.types import TimeType try: - import qupulse_rs.replacements + import qupulse_rs except ImportError: qupulse_rs = None + RsExpressionScalar = None +else: + from qupulse_rs.replacements import ExpressionScalar as RsExpressionScalar __all__ = ["Expression", "ExpressionVariableMissingException", "ExpressionScalar", "ExpressionVector", "ExpressionLike"] @@ -469,8 +472,7 @@ def __str__(self) -> str: ExpressionLike = TypeVar('ExpressionLike', str, Number, sympy.Expr, ExpressionScalar) -if qupulse_rs: - RsExpressionScalar = qupulse_rs.replacements.ExpressionScalar +if RsExpressionScalar: PyExpressionScalar = ExpressionScalar class ExpressionScalar(PyExpressionScalar): diff --git a/qupulse/parameter_scope.py b/qupulse/parameter_scope.py index 2f174ba17..8423f34d7 100644 --- a/qupulse/parameter_scope.py +++ b/qupulse/parameter_scope.py @@ -7,12 +7,15 @@ import itertools from qupulse.expressions import Expression, ExpressionVariableMissingException -from qupulse.utils.types import FrozenMapping, FrozenDict +from qupulse.utils.types import FrozenMapping, FrozenDict, use_rs_replacements try: - import qupulse_rs.replacements.parameter_scopes + import qupulse_rs except ImportError: qupulse_rs = None + parameter_scope_rs = None +else: + from qupulse_rs.replacements import parameter_scope as parameter_scope_rs class Scope(Mapping[str, Number]): @@ -326,11 +329,5 @@ class NonVolatileChange(RuntimeWarning): """Raised if a non volatile parameter is updated""" -if qupulse_rs: - PyDictScope = DictScope - PyMappedScope = MappedScope - PyJointScope = JointScope - - DictScope = qupulse_rs.replacements.parameter_scopes.DictScope - MappedScope = qupulse_rs.replacements.parameter_scopes.MappedScope - JointScope = qupulse_rs.replacements.parameter_scopes.JointScope +if parameter_scope_rs: + use_rs_replacements(globals(), parameter_scope_rs, Scope) diff --git a/qupulse/serialization.py b/qupulse/serialization.py index b441053cb..b1eabaeec 100644 --- a/qupulse/serialization.py +++ b/qupulse/serialization.py @@ -1092,7 +1092,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def default(self, o: Any) -> Any: - if isinstance(o, AnonymousSerializable): + if hasattr(o, 'get_serialization_data'): return o.get_serialization_data() elif type(o) is set: return list(o) diff --git a/qupulse/utils/types.py b/qupulse/utils/types.py index 2bd4e5c39..aac57c143 100644 --- a/qupulse/utils/types.py +++ b/qupulse/utils/types.py @@ -566,3 +566,20 @@ def __eq__(self, other): return NotImplemented +def use_rs_replacements(glbls, rs_replacement, base_class: type): + name_suffix = base_class.__name__ + for name, rs_obj in vars(rs_replacement).items(): + if not name.endswith(name_suffix): + continue + + py_name = f'Py{name}' + rs_name = f'Rs{name}' + glbls[name] = rs_obj + try: + py_obj = glbls[name] + except KeyError: + pass + else: + glbls.setdefault(py_name, py_obj) + glbls.setdefault(rs_name, rs_obj) + base_class.register(rs_obj) diff --git a/tests/_program/transformation_tests.py b/tests/_program/transformation_tests.py index e75e17dc6..0098fe9b9 100644 --- a/tests/_program/transformation_tests.py +++ b/tests/_program/transformation_tests.py @@ -2,10 +2,11 @@ from unittest import mock import numpy as np +import numpy.testing from qupulse._program.transformation import LinearTransformation, Transformation, IdentityTransformation,\ ChainedTransformation, ParallelConstantChannelTransformation, chain_transformations, OffsetTransformation,\ - ScalingTransformation + ScalingTransformation, transformation_rs class TransformationStub(Transformation): @@ -59,10 +60,11 @@ def test_compare_key_and_init(self): matrix_2 = np.array([[1, 1, 1], [1, 0, -1]]) trafo_2 = LinearTransformation(matrix_2, in_chs_2, out_chs_2) - self.assertEqual(trafo.compare_key, trafo_2.compare_key) + if transformation_rs is None: + self.assertEqual(trafo.compare_key, trafo_2.compare_key) + self.assertEqual(trafo.compare_key, (in_chs, out_chs, matrix.tobytes())) self.assertEqual(trafo, trafo_2) self.assertEqual(hash(trafo), hash(trafo_2)) - self.assertEqual(trafo.compare_key, (in_chs, out_chs, matrix.tobytes())) def test_from_pandas(self): try: @@ -93,14 +95,14 @@ def test_get_output_channels(self): def test_get_input_channels(self): in_chs = ('a', 'b', 'c') out_chs = ('transformed_a', 'transformed_b') - matrix = np.array([[1, -1, 0], [1, 1, 1]]) + matrix = np.array([[1., -1, 0], [1, 1, 1]]) trafo = LinearTransformation(matrix, in_chs, out_chs) self.assertEqual(trafo.get_input_channels({'transformed_a'}), {'a', 'b', 'c'}) self.assertEqual(trafo.get_input_channels({'transformed_a', 'd'}), {'a', 'b', 'c', 'd'}) self.assertEqual(trafo.get_input_channels({'d'}), {'d'}) with self.assertRaisesRegex(KeyError, 'Is input channel'): - self.assertEqual(trafo.get_input_channels({'transformed_a', 'a'}), {'a', 'b', 'c', 'd'}) + trafo.get_input_channels({'transformed_a', 'a'}) in_chs = ('a', 'b', 'c') out_chs = ('a', 'b', 'c') @@ -108,7 +110,7 @@ def test_get_input_channels(self): trafo = LinearTransformation(matrix, in_chs, out_chs) in_set = {'transformed_a'} - self.assertIs(trafo.get_input_channels(in_set), in_set) + self.assertEqual(trafo.get_input_channels(in_set), in_set) self.assertEqual(trafo.get_input_channels({'transformed_a', 'a'}), {'transformed_a', 'a', 'b', 'c'}) def test_call(self): @@ -143,7 +145,7 @@ def test_call(self): data_in = {'ignored': np.arange(116., 120.)} transformed = trafo(np.full(4, np.NaN), data_in) np.testing.assert_equal(transformed, data_in) - self.assertIs(data_in['ignored'], transformed['ignored']) + np.testing.assert_equal(data_in['ignored'], transformed['ignored']) def test_repr(self): in_chs = ('a', 'b', 'c') @@ -170,23 +172,25 @@ def test_constant_propagation(self): class IdentityTransformationTests(unittest.TestCase): def test_compare_key(self): - self.assertIsNone(IdentityTransformation().compare_key) + self.assertEqual(IdentityTransformation(), IdentityTransformation()) + self.assertEqual({IdentityTransformation()}, {IdentityTransformation(), IdentityTransformation()}) + @unittest.skipIf(transformation_rs is not None, "Not implemented yet for rust") def test_singleton(self): self.assertIs(IdentityTransformation(), IdentityTransformation()) def test_call(self): time = np.arange(12) data = dict(zip('abc',(np.arange(12.) + 1).reshape((3, 4)))) - self.assertIs(IdentityTransformation()(time, data), data) + self.assertEqual(IdentityTransformation()(time, data), data) def test_output_channels(self): chans = {'a', 'b'} - self.assertIs(IdentityTransformation().get_output_channels(chans), chans) + self.assertEqual(IdentityTransformation().get_output_channels(chans), chans) def test_input_channels(self): chans = {'a', 'b'} - self.assertIs(IdentityTransformation().get_input_channels(chans), chans) + self.assertEqual(IdentityTransformation().get_input_channels(chans), chans) def test_chain(self): trafo = TransformationStub() @@ -209,7 +213,13 @@ def test_init_and_properties(self): chained = ChainedTransformation(*trafos) self.assertEqual(chained.transformations, trafos) - self.assertIs(chained.transformations, chained.compare_key) + + def test_equality(self): + trafos1 = TransformationStub(), TransformationStub(), TransformationStub() + trafos2 = trafos1[0], trafos1[1], TransformationStub() + self.assertEqual(ChainedTransformation(*trafos1), ChainedTransformation(*trafos1)) + self.assertNotEqual(ChainedTransformation(*trafos1), ChainedTransformation(*trafos2)) + self.assertEqual({ChainedTransformation(*trafos1)}, {ChainedTransformation(*trafos1), ChainedTransformation(*trafos1)}) def test_get_output_channels(self): trafos = TransformationStub(), TransformationStub(), TransformationStub() @@ -221,7 +231,7 @@ def test_get_output_channels(self): mock.patch.object(trafos[2], 'get_output_channels', return_value=chans[2]) as get_output_channels_2: outs = chained.get_output_channels({0}) - self.assertIs(outs, chans[2]) + self.assertEqual(outs, chans[2]) get_output_channels_0.assert_called_once_with({0}) get_output_channels_1.assert_called_once_with({1}) get_output_channels_2.assert_called_once_with({2}) @@ -237,12 +247,13 @@ def test_get_input_channels(self): mock.patch.object(trafos[0], 'get_input_channels', return_value=chans[2]) as get_input_channels_2: outs = chained.get_input_channels({0}) - self.assertIs(outs, chans[2]) + self.assertEqual(outs, chans[2]) get_input_channels_0.assert_called_once_with({0}) get_input_channels_1.assert_called_once_with({1}) get_input_channels_2.assert_called_once_with({2}) def test_call(self): + from qupulse._program.transformation import PyChainedTransformation trafos = TransformationStub(), TransformationStub(), TransformationStub() chained = ChainedTransformation(*trafos) @@ -253,18 +264,19 @@ def test_call(self): data_0 = dict(zip('abc', data + 42)) data_1 = dict(zip('abc', data + 2*42)) data_2 = dict(zip('abc', data + 3*42)) - with mock.patch('tests._program.transformation_tests.TransformationStub.__call__', - side_effect=[data_0, data_1, data_2]) as call: + with mock.patch.object(TransformationStub, '__call__', + side_effect=[data_0, data_1, data_2]) as call: outs = chained(time, data_in) - - self.assertIs(outs, data_2) self.assertEqual(call.call_count, 3) + numpy.testing.assert_equal(outs, data_2) + for ((time_arg, data_arg), kwargs), expected_data in zip(call.call_args_list, [data_in, data_0, data_1]): self.assertEqual(kwargs, {}) - self.assertIs(time, time_arg) - self.assertIs(expected_data, data_arg) + numpy.testing.assert_equal(time, time_arg) + numpy.testing.assert_equal(expected_data, data_arg) + @unittest.skipIf(transformation_rs is not None, "Not implemented for rust extension") def test_chain(self): trafos = TransformationStub(), TransformationStub() trafo = TransformationStub() @@ -279,6 +291,9 @@ def test_repr(self): trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), OffsetTransformation({'b': 6.6})) self.assertEqual(trafo, eval(repr(trafo))) + stub = TransformationStub() + self.assertEqual(repr(ChainedTransformation(stub, stub)), repr(ChainedTransformation(stub, stub))) + def test_constant_propagation(self): trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), OffsetTransformation({'b': 6.6})) self.assertTrue(trafo.is_constant_invariant()) @@ -292,10 +307,9 @@ def test_init(self): trafo = ParallelConstantChannelTransformation(channels) - self.assertEqual(trafo._channels, channels) - self.assertTrue(all(isinstance(v, float) for v in trafo._channels.values())) - - self.assertEqual(trafo.compare_key, (('X', 2.), ('Y', 4.4))) + if transformation_rs is None: + self.assertEqual(trafo._channels, channels) + self.assertTrue(all(isinstance(v, float) for v in trafo._channels.values())) self.assertEqual(trafo.get_input_channels(set()), set()) self.assertEqual(trafo.get_input_channels({'X'}), set()) @@ -306,6 +320,21 @@ def test_init(self): self.assertEqual(trafo.get_output_channels({'X'}), {'X', 'Y'}) self.assertEqual(trafo.get_output_channels({'X', 'Z'}), {'X', 'Y', 'Z'}) + def test_equality(self): + constants_1 = {'a': 1.3, 'b': 3.0} + constants_2 = {'a': 1.3, 'b': 3.1} + + + self.assertEqual(ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_1)) + self.assertEqual(ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_1.copy())) + self.assertNotEqual(ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_2)) + + self.assertEqual({ParallelConstantChannelTransformation(constants_1)}, + {ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_1)}) + self.assertEqual({ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_2)}, + {ParallelConstantChannelTransformation(constants_1), + ParallelConstantChannelTransformation(constants_1), ParallelConstantChannelTransformation(constants_2)}) + def test_trafo(self): channels = {'X': 2, 'Y': 4.4} trafo = ParallelConstantChannelTransformation(channels) @@ -347,16 +376,16 @@ def test_constant_propagation(self): class TestChaining(unittest.TestCase): def test_identity_result(self): - self.assertIs(chain_transformations(), IdentityTransformation()) + self.assertEqual(chain_transformations(), IdentityTransformation()) - self.assertIs(chain_transformations(IdentityTransformation(), IdentityTransformation()), + self.assertEqual(chain_transformations(IdentityTransformation(), IdentityTransformation()), IdentityTransformation()) def test_single_transformation(self): trafo = TransformationStub() - self.assertIs(chain_transformations(trafo), trafo) - self.assertIs(chain_transformations(trafo, IdentityTransformation()), trafo) + self.assertEqual(chain_transformations(trafo), trafo) + self.assertEqual(chain_transformations(trafo, IdentityTransformation()), trafo) def test_denesting(self): trafo = TransformationStub() @@ -381,6 +410,7 @@ class TestOffsetTransformation(unittest.TestCase): def setUp(self) -> None: self.offsets = {'A': 1., 'B': 1.2} + @unittest.skipIf(transformation_rs is not None, "Not relevant for rust extension") def test_init(self): trafo = OffsetTransformation(self.offsets) # test copy @@ -391,13 +421,17 @@ def test_init(self): def test_get_input_channels(self): trafo = OffsetTransformation(self.offsets) channels = {'A', 'C'} - self.assertIs(channels, trafo.get_input_channels(channels)) - self.assertIs(channels, trafo.get_output_channels(channels)) + self.assertEqual(channels, trafo.get_input_channels(channels)) + self.assertEqual(channels, trafo.get_output_channels(channels)) - def test_compare_key(self): + def test_comparison(self): trafo = OffsetTransformation(self.offsets) - _ = hash(trafo) - self.assertEqual(frozenset([('A', 1.), ('B', 1.2)]), trafo.compare_key) + + self.assertEqual(OffsetTransformation(self.offsets.copy()), OffsetTransformation(self.offsets.copy())) + self.assertEqual({OffsetTransformation(self.offsets.copy())}, + {OffsetTransformation(self.offsets.copy()), OffsetTransformation(self.offsets.copy())}) + self.assertEqual({OffsetTransformation(self.offsets.copy()), OffsetTransformation({**self.offsets, 'C': 9})}, + {OffsetTransformation(self.offsets.copy()), OffsetTransformation({**self.offsets, 'C': 9})}) def test_trafo(self): trafo = OffsetTransformation(self.offsets) @@ -429,6 +463,7 @@ class TestScalingTransformation(unittest.TestCase): def setUp(self) -> None: self.scales = {'A': 1.5, 'B': 1.2} + @unittest.skipIf(transformation_rs is not None, "Only relevant for pure python") def test_init(self): trafo = ScalingTransformation(self.scales) # test copy @@ -438,13 +473,16 @@ def test_init(self): def test_get_input_channels(self): trafo = ScalingTransformation(self.scales) channels = {'A', 'C'} - self.assertIs(channels, trafo.get_input_channels(channels)) - self.assertIs(channels, trafo.get_output_channels(channels)) + self.assertEqual(channels, trafo.get_input_channels(channels)) + self.assertEqual(channels, trafo.get_output_channels(channels)) def test_compare_key(self): - trafo = OffsetTransformation(self.scales) - _ = hash(trafo) - self.assertEqual(frozenset([('A', 1.5), ('B', 1.2)]), trafo.compare_key) + other_scales = {**self.scales, 'H': 3.} + self.assertEqual(ScalingTransformation(self.scales), ScalingTransformation(self.scales)) + self.assertNotEqual(ScalingTransformation(self.scales), ScalingTransformation(other_scales)) + + self.assertEqual({ScalingTransformation(self.scales)}, {ScalingTransformation(self.scales), + ScalingTransformation(self.scales)}) def test_trafo(self): trafo = ScalingTransformation(self.scales) diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 207d33305..1546e6684 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -9,7 +9,7 @@ JumpInterpolationStrategy from qupulse._program.waveforms import MultiChannelWaveform, RepetitionWaveform, SequenceWaveform,\ TableWaveformEntry, TableWaveform, TransformingWaveform, SubsetWaveform, ArithmeticWaveform, ConstantWaveform,\ - Waveform, FunctorWaveform, FunctionWaveform, ReversedWaveform, rs_replacements + Waveform, FunctorWaveform, FunctionWaveform, ReversedWaveform, waveforms_rs from qupulse._program.transformation import LinearTransformation from qupulse.expressions import ExpressionScalar, Expression @@ -239,7 +239,7 @@ def test_unsafe_sample(self) -> None: result_a = waveform.unsafe_sample('A', sample_times, reuse_output) self.assertEqual(len(dwf_a.sample_calls), 2) self.assertIs(result_a, reuse_output) - if rs_replacements is None: + if waveforms_rs is None: # rust extension cannot forward the numpy array back to python without performance degradation self.assertIs(result_a, dwf_a.sample_calls[1][2]) numpy.testing.assert_equal(result_b, samples_b) @@ -434,7 +434,7 @@ def test_from_sequence(self): self.assertIsNone(swf3.constant_value('A')) assert_constant_consistent(self, swf3) - @unittest.skipIf(rs_replacements is not None, "sentinel based test do not work with rust extension") + @unittest.skipIf(waveforms_rs is not None, "sentinel based test do not work with rust extension") def test_sample_times_type(self) -> None: with mock.patch.object(DummyWaveform, 'unsafe_sample') as unsafe_sample_patch: dwfs = (DummyWaveform(duration=1.), @@ -607,7 +607,7 @@ def test_unsafe_sample(self) -> None: result = waveform.unsafe_sample('A', sample_times) - if rs_replacements is None: + if waveforms_rs is None: self.assertEqual(expected_interp_arguments, interp.call_arguments) numpy.testing.assert_equal(expected_result, result) @@ -797,7 +797,7 @@ def test_get_subset_for_channels(self): get_subset_for_channels.assert_called_once_with({'a'}) self.assertIs(subsetted, actual_subsetted) - @unittest.skipIf(rs_replacements is not None, "Test requires pure python.") + @unittest.skipIf(waveforms_rs is not None, "Test requires pure python.") def test_unsafe_sample_pure(self): """Test perfect forwarding""" time = {'time'} From a9168cb96c59f5d1dff59eef6500af0dbc54e86c Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 19 Sep 2022 11:57:14 +0200 Subject: [PATCH 29/35] Do not use expression sympy interface if not required --- qupulse/pulses/arithmetic_pulse_template.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/qupulse/pulses/arithmetic_pulse_template.py b/qupulse/pulses/arithmetic_pulse_template.py index 1606c7531..f9443da70 100644 --- a/qupulse/pulses/arithmetic_pulse_template.py +++ b/qupulse/pulses/arithmetic_pulse_template.py @@ -1,4 +1,3 @@ - from typing import Any, Dict, List, Set, Optional, Union, Mapping, FrozenSet, cast, Callable from numbers import Real import warnings @@ -387,18 +386,18 @@ def duration(self) -> ExpressionScalar: @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: - integral = {channel: value.sympified_expression for channel, value in self._pulse_template.integral.items()} + integral = {channel: value for channel, value in self._pulse_template.integral.items()} if isinstance(self._scalar, ExpressionScalar): - scalar = {channel: self._scalar.sympified_expression + scalar = {channel: self._scalar for channel in self.defined_channels} else: - scalar = {channel: value.sympified_expression + scalar = {channel: value for channel, value in self._scalar.items()} if self._arithmetic_operator == '+': for channel, value in scalar.items(): - integral[channel] = integral[channel] + (value * self.duration.sympified_expression) + integral[channel] = integral[channel] + (value * self.duration) elif self._arithmetic_operator == '*': for channel, value in scalar.items(): @@ -415,13 +414,13 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: # we need to negate all existing values for channel, inner_value in integral.items(): if channel in scalar: - integral[channel] = scalar[channel] * self.duration.sympified_expression - inner_value + integral[channel] = scalar[channel] * self.duration - inner_value else: integral[channel] = -inner_value else: for channel, value in scalar.items(): - integral[channel] = integral[channel] - value * self.duration.sympified_expression + integral[channel] = integral[channel] - value * self.duration for channel, value in integral.items(): integral[channel] = ExpressionScalar(value) From 639d552d46c68e23181252a27e7140811224b8d7 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 20 Sep 2022 10:47:40 +0200 Subject: [PATCH 30/35] Less usage of expression internals --- qupulse/expressions.py | 2 +- qupulse/pulses/mapping_pulse_template.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qupulse/expressions.py b/qupulse/expressions.py index 3cd96a1ee..2e2003218 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -326,7 +326,7 @@ def variables(self) -> Sequence[str]: @classmethod def _sympify(cls, other: Union['ExpressionScalar', Number, sympy.Expr]) -> sympy.Expr: - return other._sympified_expression if isinstance(other, cls) else sympify(other) + return sympify(other) def __lt__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: result = self._sympified_expression < self._sympify(other) diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index 38b9229b6..ab415c3db 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -202,7 +202,7 @@ def defined_channels(self) -> Set[ChannelID]: @property def duration(self) -> Expression: return self.__template.duration.evaluate_symbolic( - {parameter_name: expression.underlying_expression + {parameter_name: expression for parameter_name, expression in self.__parameter_mapping.items()} ) From e45367057d2769d590ea067ff2a4169b268c38d0 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 20 Sep 2022 10:48:17 +0200 Subject: [PATCH 31/35] Custom subclass check for expressions --- qupulse/expressions.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/qupulse/expressions.py b/qupulse/expressions.py index 2e2003218..ca87aa4e8 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -48,6 +48,15 @@ def __call__(cls: Type[_ExpressionType], *args, **kwargs) -> _ExpressionType: else: return type.__call__(cls, *args, **kwargs) + if RsExpressionScalar is not None: + def __subclasscheck__(cls, subclass): + return cls.__name__ == subclass.__name__ or super().__subclasscheck__(subclass) + + def __instancecheck__(cls, instance): + if cls is ExpressionScalar or cls is Expression: + return isinstance(instance, RsExpressionScalar) or super().__instancecheck__(instance) + super().__instancecheck__(instance) + class Expression(AnonymousSerializable, metaclass=_ExpressionMeta): """Base class for expressions.""" @@ -484,3 +493,7 @@ def __new__(cls, *args, **kwargs): pass return PyExpressionScalar.__new__(cls) +assert isinstance(ExpressionScalar('a'), ExpressionScalar) +assert isinstance(ExpressionScalar('a'), Expression) +if RsExpressionScalar: + assert issubclass(RsExpressionScalar, ExpressionScalar) From 87e45e60af909eb0858260667f21f615cda1c26e Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 20 Sep 2022 10:48:57 +0200 Subject: [PATCH 32/35] Fix, improve and cleanup tests --- tests/_program/transformation_tests.py | 1 - tests/expression_tests.py | 30 +++++++++++++++++++ .../pulses/arithmetic_pulse_template_tests.py | 3 +- tests/pulses/loop_pulse_template_tests.py | 12 ++++++-- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/tests/_program/transformation_tests.py b/tests/_program/transformation_tests.py index 0098fe9b9..ff4b6c03d 100644 --- a/tests/_program/transformation_tests.py +++ b/tests/_program/transformation_tests.py @@ -253,7 +253,6 @@ def test_get_input_channels(self): get_input_channels_2.assert_called_once_with({2}) def test_call(self): - from qupulse._program.transformation import PyChainedTransformation trafos = TransformationStub(), TransformationStub(), TransformationStub() chained = ChainedTransformation(*trafos) diff --git a/tests/expression_tests.py b/tests/expression_tests.py index 43262b267..932cdd264 100644 --- a/tests/expression_tests.py +++ b/tests/expression_tests.py @@ -426,6 +426,36 @@ def test_special_function_numeric_evaluation(self): np.testing.assert_allclose(expected, result) + def test_rounding_equality(self): + seconds2ns = 1e9 + pulse_duration = 1.0765001496284785e-07 + float_product = pulse_duration * seconds2ns + + expr_1 = ExpressionScalar(pulse_duration) + expr_2 = ExpressionScalar(seconds2ns) + + self.assertEqual(expr_1, pulse_duration) + self.assertEqual(expr_2, seconds2ns) + + self.assertEqual(expr_1.sympified_expression, pulse_duration) + self.assertEqual(expr_2.sympified_expression, seconds2ns) + + expr_a = ExpressionScalar(float_product) + expr_b = expr_1 * seconds2ns + expr_c = expr_2 * pulse_duration + + #self.assertEqual(float_product, float(expr_a)) + #self.assertEqual(float_product, float(expr_b)) + #self.assertEqual(float_product, float(expr_c)) + + self.assertEqual(float_product, expr_a) + self.assertEqual(float_product, expr_b) + self.assertEqual(float_product, expr_c) + + expr_symb = ExpressionScalar('duration') + expr_d = expr_symb.evaluate_symbolic(substitutions={'duration': float_product}) + self.assertEqual(float_product, expr_d) + def test_evaluate_with_exact_rationals(self): expr = ExpressionScalar('1 / 3') self.assertEqual(TimeType.from_fraction(1, 3), expr.evaluate_with_exact_rationals({})) diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 7b5ec0e47..467ce0e0a 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -469,7 +469,8 @@ def test_integral(self): expected = dict(u=ExpressionScalar('ui / (x + y)'), v=ExpressionScalar('vi / 2.2'), w=ExpressionScalar('wi')) - self.assertEqual(expected, ArithmeticPulseTemplate(pt, '/', mapping).integral) + actual = ArithmeticPulseTemplate(pt, '/', mapping).integral + self.assertEqual(expected, actual) def test_simple_attributes(self): lhs = DummyPulseTemplate(defined_channels={'a', 'b'}, duration=ExpressionScalar('t_dur'), diff --git a/tests/pulses/loop_pulse_template_tests.py b/tests/pulses/loop_pulse_template_tests.py index c22c46a4e..a1c8b41f3 100644 --- a/tests/pulses/loop_pulse_template_tests.py +++ b/tests/pulses/loop_pulse_template_tests.py @@ -2,6 +2,8 @@ import unittest from unittest import mock +import sympy + from qupulse.parameter_scope import DictScope from qupulse.utils.types import FrozenDict @@ -175,8 +177,14 @@ def test_integral(self) -> None: pulse = ForLoopPulseTemplate(dummy, 'i', (1, 8, 2)) expected = {'A': ExpressionScalar('Sum(t1-3.1*(1+2*i), (i, 0, 3))'), - 'B': ExpressionScalar('Sum((1+2*i), (i, 0, 3))') } - self.assertEqual(expected, pulse.integral) + 'B': ExpressionScalar('Sum((1+2*i), (i, 0, 3))')} + expected_simplified = {ch: ExpressionScalar(sympy.simplify(expr.sympified_expression)) + for ch, expr in expected.items()} + actual = pulse.integral + actual_simplified = {ch: ExpressionScalar(sympy.simplify(expr.sympified_expression)) + for ch, expr in actual.items()} + + self.assertEqual(expected_simplified, actual_simplified) class ForLoopTemplateSequencingTests(MeasurementWindowTestCase): From af743a2d4a208311d3f5d93a980c13de3302cb57 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 20 Sep 2022 17:10:30 +0200 Subject: [PATCH 33/35] Make more tests rust extension friendly --- tests/parameter_scope_tests.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/parameter_scope_tests.py b/tests/parameter_scope_tests.py index 1a2807d68..bd3b3f5f1 100644 --- a/tests/parameter_scope_tests.py +++ b/tests/parameter_scope_tests.py @@ -1,7 +1,7 @@ import unittest from unittest import mock -from qupulse.parameter_scope import Scope, DictScope, MappedScope, ParameterNotProvidedException, NonVolatileChange +from qupulse.parameter_scope import Scope, DictScope, MappedScope, ParameterNotProvidedException, NonVolatileChange, parameter_scope_rs from qupulse.expressions import ExpressionScalar from qupulse.utils.types import FrozenDict @@ -9,17 +9,20 @@ class DictScopeTests(unittest.TestCase): def test_init(self): - with self.assertRaises(AssertionError): - DictScope(dict()) + if parameter_scope_rs is None: + with self.assertRaises(AssertionError): + DictScope(dict()) fd = FrozenDict({'a': 2}) ds = DictScope(fd) - self.assertIs(fd, ds._values) - self.assertEqual(FrozenDict(), ds._volatile_parameters) + if parameter_scope_rs is None: + self.assertIs(fd, ds._values) + self.assertEqual(FrozenDict(), ds.get_volatile_parameters()) vp = frozenset('a') ds = DictScope(fd, vp) - self.assertIs(fd, ds._values) - self.assertEqual(FrozenDict(a=ExpressionScalar('a')), ds._volatile_parameters) + if parameter_scope_rs is None: + self.assertIs(fd, ds._values) + self.assertEqual(FrozenDict(a=ExpressionScalar('a')), ds.get_volatile_parameters()) def test_mapping(self): ds = DictScope(FrozenDict({'a': 1, 'b': 2})) @@ -136,6 +139,7 @@ def test_mapping(self): with self.assertRaisesRegex(KeyError, 'd'): _ = ms['d'] + @unittest.skipIf(parameter_scope_rs is not None, "Tested method not present in rust") def test_parameter(self): mock_a = mock.Mock(wraps=1) mock_result = mock.Mock() @@ -163,9 +167,10 @@ def test_parameter(self): def test_update_constants(self): ds = DictScope.from_kwargs(a=1, b=2, c=3, volatile={'c'}) ds2 = DictScope.from_kwargs(a=1, b=2, c=4, volatile={'c'}) - ms = MappedScope(ds, FrozenDict(x=ExpressionScalar('a * b'), - c=ExpressionScalar('a - b'))) - ms2 = MappedScope(ds2, ms._mapping) + mapping = FrozenDict(x=ExpressionScalar('a * b'), + c=ExpressionScalar('a - b')) + ms = MappedScope(ds, mapping) + ms2 = MappedScope(ds2, mapping) self.assertIs(ms, ms.change_constants({'f': 1})) @@ -180,7 +185,7 @@ def test_volatile_parameters(self): y=ExpressionScalar('c - a'))) expected_volatile = FrozenDict(d=ExpressionScalar('d'), y=ExpressionScalar('c - 1')) self.assertEqual(expected_volatile, ms.get_volatile_parameters()) - self.assertIs(ms.get_volatile_parameters(), ms.get_volatile_parameters()) + self.assertEqual(ms.get_volatile_parameters(), ms.get_volatile_parameters()) def test_eq(self): ds1 = DictScope.from_kwargs(a=1, b=2, c=3, d=4) From 9613cf657cfff6fa53a631eb2ce602d5a98c87fa Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 20 Sep 2022 19:42:53 +0200 Subject: [PATCH 34/35] Fix all tests --- qupulse/_program/__init__.py | 15 +++++++++++---- qupulse/pulses/pulse_template.py | 3 ++- tests/_program/seqc_tests.py | 2 +- tests/_program/waveforms_tests.py | 5 ++++- tests/pulses/plotting_tests.py | 6 ++++++ tests/pulses/pulse_template_tests.py | 4 ++-- 6 files changed, 26 insertions(+), 9 deletions(-) diff --git a/qupulse/_program/__init__.py b/qupulse/_program/__init__.py index 26ce85ea8..f142e9573 100644 --- a/qupulse/_program/__init__.py +++ b/qupulse/_program/__init__.py @@ -8,6 +8,14 @@ from qupulse.utils.types import MeasurementWindow, TimeType from qupulse._program.volatile import VolatileRepetitionCount +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + RsProgramBuilder = None +else: + from qupulse_rs.replacements import ProgramBuilder as RsProgramBuilder + try: from typing import Protocol, runtime_checkable except ImportError: @@ -64,9 +72,8 @@ def to_program(self) -> Optional[Program]: def default_program_builder() -> ProgramBuilder: - try: - import qupulse_rs.qupulse_rs - return qupulse_rs.qupulse_rs.replacements.ProgramBuilder() - except (AttributeError, ImportError): + if RsProgramBuilder is None: from qupulse._program._loop import Loop return Loop() + else: + return RsProgramBuilder() diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 5feba7185..a7da2ce9c 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -214,7 +214,8 @@ def _create_program(self, *, program = builder.to_program() if program is not None: - waveform = program.to_single_waveform() + # we use the free function here for better testability + waveform = to_waveform(program) if global_transformation: waveform = TransformingWaveform(waveform, global_transformation) diff --git a/tests/_program/seqc_tests.py b/tests/_program/seqc_tests.py index 02e791ec2..3d2f00c24 100644 --- a/tests/_program/seqc_tests.py +++ b/tests/_program/seqc_tests.py @@ -70,7 +70,7 @@ def make_binary_waveform(waveform): return (BinaryWaveform(data),) else: chs = sorted(waveform.defined_channels) - t = np.arange(0., waveform.duration, 1.) + t = np.arange(0., waveform.duration, 1., dtype=float) sampled = [None if ch is None else waveform.get_sampled(ch, t) for _, ch in zip_longest(range(6), take(6, chs), fillvalue=None)] diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 1546e6684..5d16357e3 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -249,10 +249,13 @@ def test_equality(self) -> None: dwf_b = DummyWaveform(duration=246.2, defined_channels={'B'}) dwf_c = DummyWaveform(duration=246.2, defined_channels={'C'}) waveform_a1 = MultiChannelWaveform([dwf_a, dwf_b]) - waveform_a2 = MultiChannelWaveform([dwf_a, dwf_b]) + waveform_a2 = MultiChannelWaveform([dwf_b, dwf_a]) waveform_a3 = MultiChannelWaveform([dwf_a, dwf_c]) + waveform_a4 = MultiChannelWaveform([dwf_a, dwf_b, dwf_c]) self.assertEqual(waveform_a1, waveform_a1) self.assertEqual(waveform_a1, waveform_a2) + self.assertEqual(waveform_a4.get_subset_for_channels({'A', 'B'}), waveform_a4.get_subset_for_channels({'B', 'A'})) + self.assertEqual({waveform_a1}, {waveform_a1, waveform_a2}) self.assertNotEqual(waveform_a1, waveform_a3) def test_unsafe_get_subset_for_channels(self): diff --git a/tests/pulses/plotting_tests.py b/tests/pulses/plotting_tests.py index 75d004157..dfd7ca81d 100644 --- a/tests/pulses/plotting_tests.py +++ b/tests/pulses/plotting_tests.py @@ -5,6 +5,11 @@ import numpy +try: + import qupulse_rs +except ImportError: + qupulse_rs = None + from qupulse.pulses.plotting import PlottingNotPossibleException, render, plot from qupulse.pulses.table_pulse_template import TablePulseTemplate from qupulse.pulses.sequence_pulse_template import SequencePulseTemplate @@ -144,6 +149,7 @@ def test_bug_422(self): plot(pt, parameters={}) + @unittest.skipIf(qupulse_rs is not None, "Not relevant for rust code") def test_bug_422_mock(self): pt = TablePulseTemplate({'X': [(0, 1), (100, 1)]}) program = pt.create_program() diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 243a7f981..e6ae3cee8 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -235,7 +235,7 @@ def test__create_program_single_waveform(self): single_waveform = DummyWaveform() measurements = [('m', 0, 1), ('n', 0.1, .9)] - expected_inner_program = Loop(waveform=wf, measurements=measurements) + expected_inner_program = Loop(children=[Loop(waveform=wf, measurements=measurements)]) appending_create_program = get_appending_internal_create_program(wf, measurements=measurements, @@ -250,7 +250,7 @@ def test__create_program_single_waveform(self): with mock.patch.object(template, '_internal_create_program', wraps=appending_create_program) as _internal_create_program: - with mock.patch('qupulse._program._loop.to_waveform', + with mock.patch('qupulse.pulses.pulse_template.to_waveform', return_value=single_waveform) as to_waveform: with mock.patch('qupulse.pulses.pulse_template.default_program_builder', return_value=inner_program_builder): From af7f000d543c30af042dda3720c774ecb0bc0b10 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 14 Nov 2022 14:34:23 +0100 Subject: [PATCH 35/35] Merge remote-tracking branch 'qutech/master' into HEAD # Conflicts: # qupulse/_program/_loop.py # qupulse/_program/waveforms.py # qupulse/pulses/arithmetic_pulse_template.py # qupulse/pulses/constant_pulse_template.py # qupulse/pulses/pulse_template.py # tests/_program/seqc_tests.py # tests/pulses/plotting_tests.py # tests/pulses/pulse_template_tests.py --- .github/workflows/pythonpublish.yml | 8 +- .github/workflows/pythontest.yaml | 1 + .github/workflows/unittest_publish.yaml | 20 +- .travis.yml | 2 +- RELEASE_NOTES.rst | 36 ++ changes.d/612.bugfix | 1 - changes.d/615.feature | 1 - changes.d/622.feature | 2 - changes.d/635.feature | 1 - changes.d/638.removal | 1 - changes.d/639.removal | 2 - changes.d/642.feature | 1 - changes.d/645.feature | 1 - changes.d/653.feature | 1 - changes.d/656.removal | 1 - changes.d/696.fix | 1 + qupulse/__init__.py | 2 +- qupulse/_program/_loop.py | 71 ++- qupulse/_program/seqc.py | 161 ++++--- qupulse/_program/transformation.py | 10 +- qupulse/_program/waveforms.py | 45 +- qupulse/expressions.py | 164 ++++--- qupulse/hardware/awgs/base.py | 52 ++- qupulse/hardware/awgs/zihdawg.py | 424 ++++++++++++------ qupulse/hardware/dacs/alazar.py | 98 ++-- qupulse/hardware/dacs/alazar2.py | 132 ++++++ qupulse/hardware/dacs/dac_base.py | 42 +- qupulse/hardware/setup.py | 4 +- qupulse/hardware/util.py | 210 ++++++++- qupulse/pulses/abstract_pulse_template.py | 4 + qupulse/pulses/arithmetic_pulse_template.py | 139 ++++-- qupulse/pulses/constant_pulse_template.py | 134 +++--- qupulse/pulses/function_pulse_template.py | 12 +- qupulse/pulses/loop_pulse_template.py | 165 +------ qupulse/pulses/mapping_pulse_template.py | 41 +- .../pulses/multi_channel_pulse_template.py | 30 +- qupulse/pulses/plotting.py | 20 +- qupulse/pulses/point_pulse_template.py | 16 + qupulse/pulses/pulse_template.py | 38 +- qupulse/pulses/range.py | 156 +++++++ qupulse/pulses/repetition_pulse_template.py | 8 + qupulse/pulses/sequence_pulse_template.py | 15 +- qupulse/pulses/table_pulse_template.py | 10 + qupulse/utils/performance.py | 84 ++++ qupulse/utils/types.py | 10 +- setup.cfg | 1 + tests/_program/loop_tests.py | 22 + tests/_program/seqc_tests.py | 25 +- tests/expression_tests.py | 17 + tests/hardware/alazar_tests.py | 2 +- tests/hardware/util_tests.py | 52 ++- tests/hardware/zihdawg_tests.py | 26 +- .../pulses/arithmetic_pulse_template_tests.py | 36 +- tests/pulses/constant_pulse_template_tests.py | 73 ++- tests/pulses/function_pulse_tests.py | 8 + tests/pulses/loop_pulse_template_tests.py | 21 +- tests/pulses/mapping_pulse_template_tests.py | 15 +- .../multi_channel_pulse_template_tests.py | 16 +- tests/pulses/plotting_tests.py | 7 + tests/pulses/point_pulse_template_tests.py | 23 + tests/pulses/pulse_template_tests.py | 13 +- .../pulses/repetition_pulse_template_tests.py | 10 + tests/pulses/sequence_pulse_template_tests.py | 7 + tests/pulses/sequencing_dummies.py | 26 +- tests/pulses/table_pulse_template_tests.py | 7 + tests/utils/performance_tests.py | 30 ++ 66 files changed, 2087 insertions(+), 727 deletions(-) delete mode 100644 changes.d/612.bugfix delete mode 100644 changes.d/615.feature delete mode 100644 changes.d/622.feature delete mode 100644 changes.d/635.feature delete mode 100644 changes.d/638.removal delete mode 100644 changes.d/639.removal delete mode 100644 changes.d/642.feature delete mode 100644 changes.d/645.feature delete mode 100644 changes.d/653.feature delete mode 100644 changes.d/656.removal create mode 100644 changes.d/696.fix create mode 100644 qupulse/hardware/dacs/alazar2.py create mode 100644 qupulse/pulses/range.py create mode 100644 qupulse/utils/performance.py create mode 100644 tests/utils/performance_tests.py diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index ca134010f..8d56cff3c 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -12,15 +12,15 @@ jobs: - name: Set up Python uses: actions/setup-python@v1 with: - python-version: '3.7' + python-version: '3.8' - name: Install dependencies run: | python -m pip install --upgrade pip - pip install setuptools wheel twine + python -m pip install --upgrade build twine - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: | - python setup.py sdist bdist_wheel - twine upload dist/* + python -m build + python -m twine upload dist/* diff --git a/.github/workflows/pythontest.yaml b/.github/workflows/pythontest.yaml index 77b809629..cacee5647 100644 --- a/.github/workflows/pythontest.yaml +++ b/.github/workflows/pythontest.yaml @@ -1,6 +1,7 @@ name: Pytest and coveralls on: + workflow_dispatch: pull_request: types: - opened diff --git a/.github/workflows/unittest_publish.yaml b/.github/workflows/unittest_publish.yaml index 113cf1292..abd944095 100644 --- a/.github/workflows/unittest_publish.yaml +++ b/.github/workflows/unittest_publish.yaml @@ -5,13 +5,23 @@ on: workflows: ["Pytest and coveralls"] types: - completed +permissions: {} jobs: - unit-test-results: - name: Unit Test Results + test-results: + name: Test Results runs-on: ubuntu-latest if: github.event.workflow_run.conclusion != 'skipped' + permissions: + checks: write + + # needed unless run with comment_mode: off + pull-requests: write + + # required by download step to access artifacts API + actions: read + steps: - name: Download and Extract Artifacts env: @@ -28,10 +38,10 @@ jobs: unzip -d "$name" "$name.zip" done - - name: Publish Unit Test Results - uses: EnricoMi/publish-unit-test-result-action@v1 + - name: Publish Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 with: commit: ${{ github.event.workflow_run.head_sha }} event_file: artifacts/Event File/event.json event_name: ${{ github.event.workflow_run.event }} - files: "artifacts/**/*.xml" + junit_files: "artifacts/**/*.xml" diff --git a/.travis.yml b/.travis.yml index f4736e841..df66b4e0b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ python: - 3.8 env: - INSTALL_EXTRAS=[plotting,zurich-instruments,tektronix,tabor-instruments] - - INSTALL_EXTRAS=[plotting,zurich-instruments,tektronix,tabor-instruments,Faster-fractions] + - INSTALL_EXTRAS=[plotting,zurich-instruments,tektronix,tabor-instruments,Faster-fractions,faster-sampling] #use container based infrastructure sudo: false diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index 2caa0be5c..c61c8b8c0 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -2,6 +2,42 @@ .. towncrier release notes start +qupulse 0.7 (2022-10-05) +======================== + +Features +-------- + +- Add optional numba uses in some cases. (`#501 `_) +- Add `initial_values` and `final_values` attributes to `PulseTemplate`. + + This allows pulse template construction that depends on features of arbitrary existing pulses i.e. like extension until + a certain length. (`#549 `_) +- Support sympy 1.9 (`#615 `_) +- Add option to automatically reduce the sample rate of HDAWG playback for piecewise constant pulses. + Use `qupulse._program.seqc.WaveformPlayback.ENABLE_DYNAMIC_RATE_REDUCTION` to enable it. (`#622 `_) +- Add a TimeReversalPT. (`#635 `_) +- Add specialied parameter Scope for ForLoopPT. This increases performance by roughly a factor of 3 for long ranges! (`#642 `_) +- Add sympy 1.10 support and make `ExpressionVector` hashable. (`#645 `_) +- `Serializable` is now comparable via it's `get_serialized_data`. `PulseTemplate` implements `Hashable` via the same. (`#653 `_) +- Add an interface that uses `atsaverage.config2`. (`#686 `_) + + +Bugfixes +-------- + +- `floor` will now return an integer in lambda expressions with numpy to allow usage in ForLoopPT range expression. (`#612 `_) + + +Deprecations and Removals +------------------------- + +- Drop `cached_property` dependency for python>=3.8. (`#638 `_) +- Add frozendict dependency to replace handwritten solution. Not having it installed will break in a future release + when the old implementation is removed. (`#639 `_) +- Drop python 3.6 support. (`#656 `_) + + qupulse 0.6 (2021-07-08) ========================== diff --git a/changes.d/612.bugfix b/changes.d/612.bugfix deleted file mode 100644 index 65e80ec0a..000000000 --- a/changes.d/612.bugfix +++ /dev/null @@ -1 +0,0 @@ -`floor` will now return an integer in lambda expressions with numpy to allow usage in ForLoopPT range expression. diff --git a/changes.d/615.feature b/changes.d/615.feature deleted file mode 100644 index 4f4fdfdab..000000000 --- a/changes.d/615.feature +++ /dev/null @@ -1 +0,0 @@ -Support sympy 1.9 diff --git a/changes.d/622.feature b/changes.d/622.feature deleted file mode 100644 index cfcbc8db2..000000000 --- a/changes.d/622.feature +++ /dev/null @@ -1,2 +0,0 @@ -Add option to automatically reduce the sample rate of HDAWG playback for piecewise constant pulses. -Use `qupulse._program.seqc.WaveformPlayback.ENABLE_DYNAMIC_RATE_REDUCTION` to enable it. diff --git a/changes.d/635.feature b/changes.d/635.feature deleted file mode 100644 index 7d9d9b368..000000000 --- a/changes.d/635.feature +++ /dev/null @@ -1 +0,0 @@ -Add a TimeReversalPT. \ No newline at end of file diff --git a/changes.d/638.removal b/changes.d/638.removal deleted file mode 100644 index fb9914217..000000000 --- a/changes.d/638.removal +++ /dev/null @@ -1 +0,0 @@ -Drop `cached_property` dependency for python>=3.8. diff --git a/changes.d/639.removal b/changes.d/639.removal deleted file mode 100644 index 307a6e3d4..000000000 --- a/changes.d/639.removal +++ /dev/null @@ -1,2 +0,0 @@ -Add frozendict dependency to replace handwritten solution. Not having it installed will break in a future release -when the old implementation is removed. diff --git a/changes.d/642.feature b/changes.d/642.feature deleted file mode 100644 index 18fd0254a..000000000 --- a/changes.d/642.feature +++ /dev/null @@ -1 +0,0 @@ -Add specialied parameter Scope for ForLoopPT. This increases performance by roughly a factor of 3 for long ranges! diff --git a/changes.d/645.feature b/changes.d/645.feature deleted file mode 100644 index 3b540de47..000000000 --- a/changes.d/645.feature +++ /dev/null @@ -1 +0,0 @@ -Add sympy 1.10 support and make `ExpressionVector` hashable. diff --git a/changes.d/653.feature b/changes.d/653.feature deleted file mode 100644 index 3d5ae8aef..000000000 --- a/changes.d/653.feature +++ /dev/null @@ -1 +0,0 @@ -`Serializable` is now comparable via it's `get_serialized_data`. `PulseTemplate` implements `Hashable` via the same. diff --git a/changes.d/656.removal b/changes.d/656.removal deleted file mode 100644 index 6d55be1bf..000000000 --- a/changes.d/656.removal +++ /dev/null @@ -1 +0,0 @@ -Drop python 3.6 support. diff --git a/changes.d/696.fix b/changes.d/696.fix new file mode 100644 index 000000000..9df69bd0b --- /dev/null +++ b/changes.d/696.fix @@ -0,0 +1 @@ +`ConstantPulseTemplate`s from all versions can now be deserialized. \ No newline at end of file diff --git a/qupulse/__init__.py b/qupulse/__init__.py index 6c13240aa..33df6de57 100644 --- a/qupulse/__init__.py +++ b/qupulse/__init__.py @@ -3,5 +3,5 @@ from qupulse.utils.types import MeasurementWindow, ChannelID from . import pulses -__version__ = '0.6' +__version__ = '0.7' __all__ = ["MeasurementWindow", "ChannelID", "pulses"] diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index b39c32111..6424d58a4 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -260,9 +260,13 @@ def copy_tree_structure(self, new_parent: Union['Loop', bool]=False) -> 'Loop': measurements=None if self._measurements is None else list(self._measurements), children=(child.copy_tree_structure() for child in self)) - def _get_measurement_windows(self) -> Mapping[str, np.ndarray]: + def _get_measurement_windows(self, drop: bool) -> Mapping[str, np.ndarray]: """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. + Args: + drop: Drops the measurements from the Loop i.e. the Loop will no longer have measurements attached after + collecting them + Returns: A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] """ @@ -274,46 +278,43 @@ def _get_measurement_windows(self) -> Mapping[str, np.ndarray]: for mw_name, begin_length_list in temp_meas_windows.items(): temp_meas_windows[mw_name] = [np.asarray(begin_length_list, dtype=float)] + if drop: + self._measurements = None + # calculate duration together with meas windows in the same iteration if self.is_leaf(): body_duration = float(self.body_duration) else: offset = TimeType(0) for child in self: - for mw_name, begins_length_array in child._get_measurement_windows().items(): + for mw_name, begins_length_array in child._get_measurement_windows(drop).items(): begins_length_array[:, 0] += float(offset) temp_meas_windows[mw_name].append(begins_length_array) offset += child.duration body_duration = float(offset) - # this gives us regular dict behaviour of the returned object - temp_meas_windows.default_factory = None + # formatting like this for easier debugging + result = {} # repeat and add repetition based offset for mw_name, begin_length_list in temp_meas_windows.items(): - temp_begin_length_array = np.concatenate(begin_length_list) - - begin_length_array = np.tile(temp_begin_length_array, (self.repetition_count, 1)) - - shaped_begin_length_array = np.reshape(begin_length_array, (self.repetition_count, -1, 2)) + result[mw_name] = _repeat_loop_measurements(begin_length_list, self.repetition_count, body_duration) - shaped_begin_length_array[:, :, 0] += (np.arange(self.repetition_count) * body_duration)[:, np.newaxis] + return result - temp_meas_windows[mw_name] = begin_length_array - - # the cast is here because static type analysis struggles to detect that we replace _all_ values by ndarray in - # the previous loop - return cast(Mapping[str, np.ndarray], temp_meas_windows) - - def get_measurement_windows(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: + def get_measurement_windows(self, drop=False) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: """Iterates over all children and collect the begin and length arrays of each measurement window. + Args: + drop: Drops the measurements from the Loop i.e. the Loop will no longer have measurements attached after + collecting them. + Returns: A dictionary (measurement_name -> (begin, length)) with begin and length being :class:`numpy.ndarray` """ return {mw_name: (begin_length_list[:, 0], begin_length_list[:, 1]) - for mw_name, begin_length_list in self._get_measurement_windows().items()} + for mw_name, begin_length_list in self._get_measurement_windows(drop=drop).items()} def split_one_child(self, child_index=None) -> None: """Take the last child that has a repetition count larger one, decrease it's repetition count and insert a copy @@ -499,15 +500,15 @@ def to_single_waveform(self) -> Waveform: if self.repetition_count == 1: return self.waveform else: - return RepetitionWaveform(self.waveform, self.repetition_count) + return RepetitionWaveform.from_repetition_count(self.waveform, self.repetition_count) else: if len(self) == 1: sequenced_waveform = to_waveform(cast(Loop, self[0])) else: - sequenced_waveform = SequenceWaveform([to_waveform(cast(Loop, sub_program)) - for sub_program in self]) + sequenced_waveform = SequenceWaveform.from_sequence([to_waveform(cast(Loop, sub_program)) + for sub_program in self]) if self.repetition_count > 1: - return RepetitionWaveform(sequenced_waveform, self.repetition_count) + return RepetitionWaveform.from_repetition_count(sequenced_waveform, self.repetition_count) else: return sequenced_waveform @@ -658,7 +659,8 @@ def make_compatible(program: Loop, minimal_waveform_length: int, waveform_quantu def roll_constant_waveforms(program: Loop, minimal_waveform_quanta: int, waveform_quantum: int, sample_rate: TimeType): """This function finds waveforms in program that can be replaced with repetitions of shorter waveforms and replaces - them. Complexity O(N_waveforms) + them. Complexity O(N_waveforms). Drops measurements because they are not correctly handled here for performance + reasons. This is possible if: - The waveform is constant on all channels @@ -669,7 +671,15 @@ def roll_constant_waveforms(program: Loop, minimal_waveform_quanta: int, wavefor minimal_waveform_quanta: waveform_quantum: sample_rate: + + Warnings: + DroppedMeasurementWarning: This warning is raised if a measurement is dropped. """ + if program._measurements: + warnings.warn("Dropping measurements. Remove measurements before calling roll_constant_waveforms by calling" + " get_measurement_windows(drop=True).", category=DroppedMeasurementWarning) + program._measurements = None + waveform = program.waveform if waveform is None: @@ -708,6 +718,21 @@ def roll_constant_waveforms(program: Loop, minimal_waveform_quanta: int, wavefor program._waveform = new_waveform +def _repeat_loop_measurements(begin_length_list: List[np.ndarray], + repetition_count: int, + body_duration: float + ) -> np.ndarray: + temp_begin_length_array = np.concatenate(begin_length_list) + + begin_length_array = np.tile(temp_begin_length_array, (repetition_count, 1)) + + shaped_begin_length_array = np.reshape(begin_length_array, (repetition_count, -1, 2)) + + shaped_begin_length_array[:, :, 0] += (np.arange(repetition_count) * body_duration)[:, np.newaxis] + + return begin_length_array + + class MakeCompatibleWarning(ResourceWarning): pass diff --git a/qupulse/_program/seqc.py b/qupulse/_program/seqc.py index 52d6042aa..b9f248569 100644 --- a/qupulse/_program/seqc.py +++ b/qupulse/_program/seqc.py @@ -11,7 +11,7 @@ - `WaveformMemory`: Functionality to sync waveforms to the device (via the LabOne user folder) - `ProgramWaveformManager` and `HDAWGProgramEntry`: Program wise handling of waveforms and seqc-code classes that convert `Loop` objects""" - +import warnings from typing import Optional, Union, Sequence, Dict, Iterator, Tuple, Callable, NamedTuple, MutableMapping, Mapping,\ Iterable, Any, List, Deque from types import MappingProxyType @@ -20,6 +20,7 @@ import inspect import logging import hashlib +from weakref import WeakValueDictionary from collections import OrderedDict import re import collections @@ -36,9 +37,14 @@ from qupulse._program._loop import Loop from qupulse._program.volatile import VolatileRepetitionCount, VolatileProperty from qupulse.hardware.awgs.base import ProgramEntry +from qupulse.hardware.util import zhinst_voltage_to_uint16 +from qupulse.pulses.parameters import MappedParameter, ConstantParameter try: - import zhinst.utils + # zhinst fires a DeprecationWarning from its own code in some versions... + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + import zhinst.utils except ImportError: zhinst = None @@ -117,20 +123,7 @@ def from_sampled(cls, ch1: Optional[np.ndarray], ch2: Optional[np.ndarray], Returns: """ - all_input = (ch1, ch2, *markers) - assert any(x is not None for x in all_input) - size = {x.size for x in all_input if x is not None} - assert len(size) == 1, "Inputs have incompatible dimension" - size, = size - if ch1 is None: - ch1 = np.zeros(size) - if ch2 is None: - ch2 = np.zeros(size) - marker_data = np.zeros(size, dtype=np.uint16) - for idx, marker in enumerate(markers): - if marker is not None: - marker_data += np.uint16((marker > 0) * 2**idx) - return cls(zhinst.utils.convert_awg_waveform(ch1, ch2, marker_data)) + return cls(zhinst_voltage_to_uint16(ch1, ch2, markers)) @classmethod def zeroed(cls, size): @@ -149,7 +142,7 @@ def fingerprint(self) -> str: """This fingerprint is runtime independent""" return hashlib.sha256(self.data).hexdigest() - def to_csv_compatible_table(self): + def to_csv_compatible_table(self) -> np.ndarray: """The integer values in that file should be 18-bit unsigned integers with the two least significant bits being the markers. The values are mapped to 0 => -FS, 262143 => +FS, with FS equal to the full scale. @@ -225,6 +218,7 @@ def clear(self): class WaveformFileSystem: logger = logging.getLogger('qupulse.hdawg.waveforms') + _by_path = WeakValueDictionary() def __init__(self, path: Path): """This class coordinates multiple AWGs (channel pairs) using the same file system to store the waveforms. @@ -234,6 +228,11 @@ def __init__(self, path: Path): """ self._required = {} self._path = path + + @classmethod + def get_waveform_file_system(cls, path: Path) -> 'WaveformFileSystem': + """Get the instance for the given path. Multiple instances that access the same path lead to inconsistencies.""" + return cls._by_path.setdefault(path, cls(path)) def sync(self, client: 'WaveformMemory', waveforms: Mapping[str, BinaryWaveform], **kwargs): """Write the required waveforms to the filesystem.""" @@ -812,7 +811,34 @@ def name_to_index(self, name: str) -> int: assert self._programs[name].name == name return self._programs[name].selection_index - def to_seqc_program(self) -> str: + def _get_sub_program_source_code(self, program_name: str) -> str: + program = self.programs[program_name] + program_function_name = self.get_program_function_name(program_name) + return "\n".join( + [ + f"void {program_function_name}() {{", + program.seqc_source, + "}\n" + ] + ) + + def _get_program_selection_code(self) -> str: + return _make_program_selection_block((program.selection_index, self.get_program_function_name(program_name)) + for program_name, program in self.programs.items()) + + def to_seqc_program(self, single_program: Optional[str] = None) -> str: + """Generate sequencing c source code that is either capable of playing pack all uploaded programs where the + program is selected at runtime without re-compile or always will play the same program if `single_program` + is specified. + + The program selection is based on a user register in the first case. + + Args: + single_program: The seqc source only contains this program if not None + + Returns: + SEQC source code. + """ lines = [] for const_name, const_val in self.Constants.as_dict().items(): if isinstance(const_val, (int, str)): @@ -829,44 +855,23 @@ def to_seqc_program(self) -> str: replacements = self._waveform_memory.waveform_name_replacements() lines.append('\n// program definitions') - for program_name, program in self.programs.items(): - program_function_name = self.get_program_function_name(program_name) - lines.append('void {program_function_name}() {{'.format(program_function_name=program_function_name)) - lines.append(replace_multiple(program.seqc_source, replacements)) - lines.append('}\n') + if single_program: + lines.append( + replace_multiple(self._get_sub_program_source_code(single_program), replacements) + ) + + else: + for program_name, program in self.programs.items(): + lines.append(replace_multiple(self._get_sub_program_source_code(program_name), replacements)) lines.append(self.GlobalVariables.get_init_block()) lines.append('\n// runtime block') - lines.append('while (true) {') - lines.append(' // read program selection value') - lines.append(' prog_sel = getUserReg(PROG_SEL_REGISTER);') - lines.append(' ') - lines.append(' // calculate value to write back to PROG_SEL_REGISTER') - lines.append(' new_prog_sel = prog_sel | playback_finished;') - lines.append(' if (!(prog_sel & NO_RESET_MASK)) new_prog_sel &= INVERTED_PROG_SEL_MASK;') - lines.append(' setUserReg(PROG_SEL_REGISTER, new_prog_sel);') - lines.append(' ') - lines.append(' // reset playback flag') - lines.append(' playback_finished = 0;') - lines.append(' ') - lines.append(' // only use part of prog sel that does not mean other things to select the program.') - lines.append(' prog_sel &= PROG_SEL_MASK;') - lines.append(' ') - lines.append(' switch (prog_sel) {') - - for program_name, program_entry in self.programs.items(): - program_function_name = self.get_program_function_name(program_name) - lines.append(' case {selection_index}:'.format(selection_index=program_entry.selection_index)) - lines.append(' {program_function_name}();'.format(program_function_name=program_function_name)) - lines.append(' waitWave();') - lines.append(' playback_finished = PLAYBACK_FINISHED_MASK;') - - lines.append(' default:') - lines.append(' wait(IDLE_WAIT_CYCLES);') - lines.append(' }') - lines.append('}') - + if single_program: + lines.append(f"{self.get_program_function_name(single_program)}();") + else: + lines.append(self._get_program_selection_code()) + return '\n'.join(lines) @@ -973,8 +978,6 @@ def to_node_clusters(loop: Union[Sequence[Loop], Loop], loop_to_seqc_kwargs: dic node_clusters: List[List[SEQCNode]] = [] - - last_period = [] # this is the period that we currently are collecting current_period: List[SEQCNode] = [] @@ -1019,6 +1022,8 @@ def to_node_clusters(loop: Union[Sequence[Loop], Loop], loop_to_seqc_kwargs: dic last_nodes.extend(current_period) last_hashes.extend(current_template_hashes[:len(current_period)]) + current_period.clear() + last_nodes.append(current_node) last_hashes.append(current_hash) @@ -1027,6 +1032,7 @@ def to_node_clusters(loop: Union[Sequence[Loop], Loop], loop_to_seqc_kwargs: dic current_cluster) = _find_repetition(last_nodes, last_hashes, node_clusters) else: + assert not current_period if len(last_nodes) == last_nodes.maxlen: # lookup deque is full node_clusters.append([last_nodes.popleft()]) @@ -1041,7 +1047,8 @@ def to_node_clusters(loop: Union[Sequence[Loop], Loop], loop_to_seqc_kwargs: dic node_clusters) assert not (current_cluster and last_nodes) - node_clusters.append(current_cluster) + if current_cluster: + node_clusters.append(current_cluster) node_clusters.extend([node] for node in current_period) node_clusters.extend([node] for node in last_nodes) @@ -1206,6 +1213,9 @@ def __eq__(self, other): else: return NotImplemented + def __repr__(self): + return f"Scope(nodes={self.nodes!r})" + class Repeat(SEQCNode): """""" @@ -1389,6 +1399,9 @@ def __init__(self, waveform: Tuple[BinaryWaveform, ...], shared: bool = False, r self.shared = shared self.rate = rate + def __repr__(self): + return f"WaveformPlayback(<{id(self)}>)" + def samples(self) -> int: """Samples consumed in the big concatenated waveform""" if self.shared: @@ -1446,3 +1459,41 @@ def to_source_code(self, waveform_manager: ProgramWaveformManager, else: advance_cmd = self.ADVANCE_DISABLED_COMMENT yield play_cmd + advance_cmd + + +_PROGRAM_SELECTION_BLOCK = """\ +while (true) {{ + // read program selection value + prog_sel = getUserReg(PROG_SEL_REGISTER); + + // calculate value to write back to PROG_SEL_REGISTER + new_prog_sel = prog_sel | playback_finished; + if (!(prog_sel & NO_RESET_MASK)) new_prog_sel &= INVERTED_PROG_SEL_MASK; + setUserReg(PROG_SEL_REGISTER, new_prog_sel); + + // reset playback flag + playback_finished = 0; + + // only use part of prog sel that does not mean other things to select the program. + prog_sel &= PROG_SEL_MASK; + + switch (prog_sel) {{ +{program_cases} + default: + wait(IDLE_WAIT_CYCLES); + }} +}}""" + +_PROGRAM_SELECTION_CASE = """\ + case {selection_index}: + {program_function_name}(); + waitWave(); + playback_finished = PLAYBACK_FINISHED_MASK;""" + + +def _make_program_selection_block(programs: Iterable[Tuple[int, str]]): + program_cases = [] + for selection_index, program_function_name in programs: + program_cases.append(_PROGRAM_SELECTION_CASE.format(selection_index=selection_index, + program_function_name=program_function_name)) + return _PROGRAM_SELECTION_BLOCK.format(program_cases="\n".join(program_cases)) diff --git a/qupulse/_program/transformation.py b/qupulse/_program/transformation.py index e555caed2..734b71948 100644 --- a/qupulse/_program/transformation.py +++ b/qupulse/_program/transformation.py @@ -1,4 +1,4 @@ -from typing import Mapping, Set, Tuple, Sequence, AbstractSet, Union +from typing import Any, Mapping, Set, Tuple, Sequence, AbstractSet, Union, TYPE_CHECKING from abc import abstractmethod from numbers import Real @@ -263,9 +263,13 @@ def is_constant_invariant(self): try: - import pandas + if TYPE_CHECKING: + import pandas + PandasDataFrameType = pandas.DataFrame + else: + PandasDataFrameType = Any - def linear_transformation_from_pandas(transformation: pandas.DataFrame) -> LinearTransformation: + def linear_transformation_from_pandas(transformation: PandasDataFrameType) -> LinearTransformation: """ Creates a LinearTransformation object out of a pandas data frame. Args: diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index a773a9591..3adb1846d 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -18,6 +18,9 @@ from qupulse import ChannelID from qupulse._program.transformation import Transformation +from qupulse.utils import checked_int_cast, isclose +from qupulse.utils.types import TimeType, time_from_float +from qupulse.utils.performance import is_monotonic from qupulse.comparable import Comparable from qupulse.expressions import ExpressionScalar from qupulse.pulses.interpolation import InterpolationStrategy @@ -34,9 +37,13 @@ else: from qupulse_rs.replacements import waveforms as waveforms_rs +class ConstantFunctionPulseTemplateWarning(UserWarning): + """ This warning indicates a constant waveform is constructed from a FunctionPulseTemplate """ + pass __all__ = ["Waveform", "TableWaveform", "TableWaveformEntry", "FunctionWaveform", "SequenceWaveform", - "MultiChannelWaveform", "RepetitionWaveform", "TransformingWaveform", "ArithmeticWaveform"] + "MultiChannelWaveform", "RepetitionWaveform", "TransformingWaveform", "ArithmeticWaveform", + "ConstantFunctionPulseTemplateWarning"] PULSE_TO_WAVEFORM_ERROR = None # error margin in pulse template to waveform conversion @@ -105,13 +112,13 @@ def get_sampled(self, """ if len(sample_times) == 0: if output_array is None: - return np.zeros_like(sample_times) + return np.zeros_like(sample_times, dtype=float) elif len(output_array) == len(sample_times): return output_array else: raise ValueError('Output array length and sample time length are different') - if np.any(np.diff(sample_times) < 0): + if not is_monotonic(sample_times): raise ValueError('The sample times are not monotonously increasing') if sample_times[0] < 0 or sample_times[-1] > float(self.duration): raise ValueError(f'The sample times [{sample_times[0]}, ..., {sample_times[-1]}] are not in the range' @@ -202,7 +209,7 @@ def constant_value(self, channel: ChannelID) -> Optional[float]: return None def __neg__(self): - return FunctorWaveform(self, {ch: np.negative for ch in self.defined_channels}) + return FunctorWaveform.from_functor(self, {ch: np.negative for ch in self.defined_channels}) def __pos__(self): return self @@ -213,6 +220,7 @@ def _sort_key_for_channels(self) -> Sequence[Tuple[str, int]]: def reversed(self) -> 'Waveform': """Returns a reversed version of this waveform.""" + # We don't check for constness here because const waveforms are supposed to override this method return ReversedWaveform(self) @@ -478,7 +486,7 @@ def __init__(self, expression: ExpressionScalar, raise ValueError('FunctionWaveforms may not depend on anything but "t"') elif not expression.variables: warnings.warn("Constant FunctionWaveform is not recommended as the constant propagation will be suboptimal", - category=UserWarning) + category=ConstantFunctionPulseTemplateWarning) super().__init__(duration=_to_time_type(duration)) self._expression = expression self._channel_id = channel @@ -641,9 +649,9 @@ def duration(self) -> TimeType: return self._duration def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform': - return SequenceWaveform( + return SequenceWaveform.from_sequence([ sub_waveform.unsafe_get_subset_for_channels(channels & sub_waveform.defined_channels) - for sub_waveform in self._sequenced_waveforms if sub_waveform.defined_channels & channels) + for sub_waveform in self._sequenced_waveforms if sub_waveform.defined_channels & channels]) @property def sequenced_waveforms(self) -> Sequence[Waveform]: @@ -853,9 +861,10 @@ def unsafe_sample(self, def compare_key(self) -> Tuple[int, Any]: return self._repetition_count, self._body - def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'RepetitionWaveform': - return RepetitionWaveform(body=self._body.unsafe_get_subset_for_channels(channels), - repetition_count=self._repetition_count) + def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform: + return RepetitionWaveform.from_repetition_count( + body=self._body.unsafe_get_subset_for_channels(channels), + repetition_count=self._repetition_count) def is_constant(self) -> bool: return self._body.is_constant() @@ -890,7 +899,7 @@ def from_transformation(cls, inner_waveform: Waveform, transformation: Transform if constant_values is None or not transformation.is_constant_invariant(): return cls(inner_waveform, transformation) - transformed_constant_values = transformation(0., constant_values) + transformed_constant_values = {key: float(value) for key, value in transformation(0., constant_values).items()} return ConstantWaveform.from_mapping(inner_waveform.duration, transformed_constant_values) def is_constant(self) -> bool: @@ -1184,8 +1193,9 @@ def unsafe_sample(self, return self._functor[channel](inner_output, out=inner_output) def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: - return FunctorWaveform(self._inner_waveform.unsafe_get_subset_for_channels(channels), - {ch: self._functor[ch] for ch in channels}) + return FunctorWaveform.from_functor( + self._inner_waveform.unsafe_get_subset_for_channels(channels), + {ch: self._functor[ch] for ch in channels}) @property def compare_key(self) -> Tuple[Waveform, FrozenSet]: @@ -1201,6 +1211,13 @@ def __init__(self, inner: Waveform): super().__init__(duration=inner.duration) self._inner = inner + @classmethod + def from_to_reverse(cls, inner: Waveform) -> Waveform: + if inner.constant_value_dict(): + return inner + else: + return cls(inner) + def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, output_array: Union[np.ndarray, None] = None) -> np.ndarray: inner_sample_times = (float(self.duration) - sample_times)[::-1] @@ -1219,7 +1236,7 @@ def defined_channels(self) -> AbstractSet[ChannelID]: return self._inner.defined_channels def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform': - return ReversedWaveform(self._inner.unsafe_get_subset_for_channels(channels)) + return ReversedWaveform.from_to_reverse(self._inner.unsafe_get_subset_for_channels(channels)) @property def compare_key(self) -> Hashable: diff --git a/qupulse/expressions.py b/qupulse/expressions.py index ca87aa4e8..d86f97498 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -32,6 +32,50 @@ _ExpressionType = TypeVar('_ExpressionType', bound='Expression') +ALLOWED_NUMERIC_SCALAR_TYPES = (float, numpy.number, int, complex, bool, numpy.bool_, TimeType) + + +def _parse_evaluate_numeric(result) -> Union[Number, numpy.ndarray]: + """Tries to parse the result as a scalar if possible. Falls back to an array otherwise. + Raises: + ValueError if scalar result is not parsable + """ + allowed_scalar = ALLOWED_NUMERIC_SCALAR_TYPES + + if isinstance(result, allowed_scalar): + # fast path for regular evaluations + return result + if isinstance(result, tuple): + result, = result + elif isinstance(result, numpy.ndarray): + result = result[()] + + if isinstance(result, allowed_scalar): + return result + if isinstance(result, sympy.Float): + return float(result) + elif isinstance(result, sympy.Integer): + return int(result) + + if isinstance(result, numpy.ndarray): + # allow numeric vector values + return _parse_evaluate_numeric_vector(result) + raise ValueError("Non numeric result", result) + + +def _parse_evaluate_numeric_vector(vector_result: numpy.ndarray) -> numpy.ndarray: + allowed_scalar = ALLOWED_NUMERIC_SCALAR_TYPES + if not issubclass(vector_result.dtype.type, allowed_scalar): + obj_types = set(map(type, vector_result.flat)) + if all(issubclass(obj_type, sympy.Integer) for obj_type in obj_types): + result = vector_result.astype(numpy.int64) + elif all(issubclass(obj_type, (sympy.Integer, sympy.Float)) for obj_type in obj_types): + result = vector_result.astype(float) + else: + raise ValueError("Could not parse vector result", vector_result) + return vector_result + + def _flat_iter(arr): if len(arr.shape) > 1: for sub_arr in arr: @@ -73,32 +117,6 @@ def _parse_evaluate_numeric_arguments(self, eval_args: Mapping[str, Number]) -> else: raise ExpressionVariableMissingException(key_error.args[0], self) from key_error - def _parse_evaluate_numeric_result(self, - result: Union[Number, numpy.ndarray], - call_arguments: Any) -> Union[Number, numpy.ndarray]: - allowed_types = (float, numpy.number, int, complex, bool, numpy.bool_, TimeType) - if isinstance(result, tuple): - result = numpy.array(result) - if isinstance(result, numpy.ndarray): - if issubclass(result.dtype.type, allowed_types): - return result - else: - obj_types = set(map(type, result.flat)) - if all(issubclass(obj_type, sympy.Integer) for obj_type in obj_types): - return result.astype(numpy.int64) - if all(issubclass(obj_type, (sympy.Integer, sympy.Float)) for obj_type in obj_types): - return result.astype(float) - else: - raise NonNumericEvaluation(self, result, call_arguments) - elif isinstance(result, allowed_types): - return result - elif isinstance(result, sympy.Float): - return float(result) - elif isinstance(result, sympy.Integer): - return int(result) - else: - raise NonNumericEvaluation(self, result, call_arguments) - def evaluate_in_scope(self, scope: Mapping) -> Union[Number, numpy.ndarray]: """Evaluate the expression by taking the variables from the given scope (typically of type Scope but it can be any mapping.) @@ -108,20 +126,10 @@ def evaluate_in_scope(self, scope: Mapping) -> Union[Number, numpy.ndarray]: Returns: """ - parsed_kwargs = self._parse_evaluate_numeric_arguments(scope) - - result, self._expression_lambda = evaluate_lambdified(self.underlying_expression, self.variables, - parsed_kwargs, lambdified=self._expression_lambda) - - return self._parse_evaluate_numeric_result(result, scope) + raise NotImplementedError("") def evaluate_numeric(self, **kwargs) -> Union[Number, numpy.ndarray]: - parsed_kwargs = self._parse_evaluate_numeric_arguments(kwargs) - - result, self._expression_lambda = evaluate_lambdified(self.underlying_expression, self.variables, - parsed_kwargs, lambdified=self._expression_lambda) - - return self._parse_evaluate_numeric_result(result, kwargs) + return self.evaluate_in_scope(kwargs) def __float__(self): if self.variables: @@ -202,15 +210,18 @@ def __init__(self, expression_vector: Sequence): def variables(self) -> Sequence[str]: return self._variables - def evaluate_numeric(self, **kwargs) -> Union[numpy.ndarray, Number]: - parsed_kwargs = self._parse_evaluate_numeric_arguments(kwargs) - + def evaluate_in_scope(self, scope: Mapping) -> numpy.ndarray: + parsed_kwargs = self._parse_evaluate_numeric_arguments(scope) flat_result = [] for idx, expr in enumerate(self._expression_items): result, self._lambdified_items[idx] = evaluate_lambdified(expr, self.variables, parsed_kwargs, lambdified=self._lambdified_items[idx]) flat_result.append(result) - return self._parse_evaluate_numeric_result(numpy.array(flat_result).reshape(self._expression_shape), kwargs) + result = numpy.array(flat_result).reshape(self._expression_shape) + try: + return _parse_evaluate_numeric_vector(result) + except ValueError as err: + raise NonNumericEvaluation(self, result, scope) from err def get_serialization_data(self) -> Sequence[str]: serialized_items = list(map(get_most_simple_representation, self._expression_items)) @@ -219,7 +230,13 @@ def get_serialization_data(self) -> Sequence[str]: elif len(self._expression_shape) == 1: return serialized_items else: - return np.array(serialized_items).reshape(self._expression_shape).tolist() + return numpy.array(serialized_items).reshape(self._expression_shape).tolist() + + def __getstate__(self): + return self.get_serialization_data() + + def __setstate__(self, state): + self.__init__(state) def __str__(self): return str(self.get_serialization_data()) @@ -337,20 +354,25 @@ def variables(self) -> Sequence[str]: def _sympify(cls, other: Union['ExpressionScalar', Number, sympy.Expr]) -> sympy.Expr: return sympify(other) + @classmethod + def _extract_sympified(cls, other: Union['ExpressionScalar', Number, sympy.Expr]) \ + -> Union['ExpressionScalar', Number, sympy.Expr]: + return getattr(other, '_sympified_expression', other) + def __lt__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: - result = self._sympified_expression < self._sympify(other) + result = self._sympified_expression < self._extract_sympified(other) return None if isinstance(result, sympy.Rel) else bool(result) def __gt__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: - result = self._sympified_expression > self._sympify(other) + result = self._sympified_expression > self._extract_sympified(other) return None if isinstance(result, sympy.Rel) else bool(result) def __ge__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: - result = self._sympified_expression >= self._sympify(other) + result = self._sympified_expression >= self._extract_sympified(other) return None if isinstance(result, sympy.Rel) else bool(result) def __le__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: - result = self._sympified_expression <= self._sympify(other) + result = self._sympified_expression <= self._extract_sympified(other) return None if isinstance(result, sympy.Rel) else bool(result) def __eq__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> bool: @@ -363,28 +385,28 @@ def __hash__(self) -> int: return hash(self._sympified_expression) def __add__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': - return self.make(self._sympified_expression.__add__(self._sympify(other))) + return self.make(self._sympified_expression.__add__(self._extract_sympified(other))) def __radd__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympify(other).__radd__(self._sympified_expression)) def __sub__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': - return self.make(self._sympified_expression.__sub__(self._sympify(other))) + return self.make(self._sympified_expression.__sub__(self._extract_sympified(other))) def __rsub__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': - return self.make(self._sympified_expression.__rsub__(self._sympify(other))) + return self.make(self._sympified_expression.__rsub__(self._extract_sympified(other))) def __mul__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': - return self.make(self._sympified_expression.__mul__(self._sympify(other))) + return self.make(self._sympified_expression.__mul__(self._extract_sympified(other))) def __rmul__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': - return self.make(self._sympified_expression.__rmul__(self._sympify(other))) + return self.make(self._sympified_expression.__rmul__(self._extract_sympified(other))) def __truediv__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': - return self.make(self._sympified_expression.__truediv__(self._sympify(other))) + return self.make(self._sympified_expression.__truediv__(self._extract_sympified(other))) def __rtruediv__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': - return self.make(self._sympified_expression.__rtruediv__(self._sympify(other))) + return self.make(self._sympified_expression.__rtruediv__(self._extract_sympified(other))) def __neg__(self) -> 'ExpressionScalar': return self.make(self._sympified_expression.__neg__()) @@ -413,26 +435,34 @@ def get_serialization_data(self) -> Union[str, float, int]: else: return serialized + def __getstate__(self): + return self.get_serialization_data() + + def __setstate__(self, state): + self.__init__(state) + def is_nan(self) -> bool: return sympy.sympify('nan') == self._sympified_expression - def _parse_evaluate_numeric_result(self, - result: Union[Number, numpy.ndarray], - call_arguments: Any) -> Number: - """Overwrite super class method because we do not want to return a scalar numpy.ndarray""" - parsed = super()._parse_evaluate_numeric_result(result, call_arguments) - if isinstance(parsed, numpy.ndarray): - return parsed[()] - else: - return parsed - - def evaluate_with_exact_rationals(self, scope: Mapping) -> Number: + def evaluate_with_exact_rationals(self, scope: Mapping) -> Union[Number, numpy.ndarray]: parsed_kwargs = self._parse_evaluate_numeric_arguments(scope) result, self._exact_rational_lambdified = evaluate_lamdified_exact_rational(self.sympified_expression, self.variables, parsed_kwargs, self._exact_rational_lambdified) - return self._parse_evaluate_numeric_result(result, scope) + try: + return _parse_evaluate_numeric(result) + except ValueError as err: + raise NonNumericEvaluation(self, result, scope) from err + + def evaluate_in_scope(self, scope: Mapping) -> Union[Number, numpy.ndarray]: + parsed_kwargs = self._parse_evaluate_numeric_arguments(scope) + result, self._expression_lambda = evaluate_lambdified(self.underlying_expression, self.variables, + parsed_kwargs, lambdified=self._expression_lambda) + try: + return _parse_evaluate_numeric(result) + except ValueError as err: + raise NonNumericEvaluation(self, result, scope) from err class ExpressionVariableMissingException(Exception): @@ -459,7 +489,7 @@ class NonNumericEvaluation(TypeError): qupulse.expressions.Expression.evaluate_numeric """ - def __init__(self, expression: Expression, non_numeric_result: Any, call_arguments: Dict): + def __init__(self, expression: Expression, non_numeric_result: Any, call_arguments: Mapping): self.expression = expression self.non_numeric_result = non_numeric_result self.call_arguments = call_arguments diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index daa8b66df..498788038 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -12,7 +12,7 @@ from typing import Set, Tuple, Callable, Optional, Mapping, Sequence, List from collections import OrderedDict -from qupulse.hardware.util import get_sample_times +from qupulse.hardware.util import get_sample_times, not_none_indices from qupulse.utils.types import ChannelID from qupulse._program._loop import Loop from qupulse._program.waveforms import Waveform @@ -219,30 +219,64 @@ def _sample_waveforms(self, waveforms: Sequence[Waveform]) -> List[Tuple[Tuple[n sampled_waveforms = [] time_array, segment_lengths = get_sample_times(waveforms, self._sample_rate) + sample_memory = numpy.zeros_like(time_array, dtype=float) + + n_samples = numpy.sum(segment_lengths) + ch_to_mem, n_ch = not_none_indices(self._channels) + mk_to_mem, c_mk = not_none_indices(self._markers) + + ch_memory = numpy.zeros((n_ch, n_samples), dtype=float) + marker_memory = numpy.zeros((c_mk, n_samples), dtype=bool) + segment_begin = 0 + for waveform, segment_length in zip(waveforms, segment_lengths): + segment_length = int(segment_length) + segment_end = segment_begin + segment_length + wf_time = time_array[:segment_length] + wf_sample_memory = sample_memory[:segment_length] sampled_channels = [] - for channel, trafo, amplitude, offset in zip(self._channels, self._voltage_transformations, - self._amplitudes, self._offsets): + for channel, ch_mem_pos, trafo, amplitude, offset in zip(self._channels, ch_to_mem, + self._voltage_transformations, + self._amplitudes, self._offsets): + final_memory = ch_memory[ch_mem_pos, segment_begin:segment_end] + if channel is None: sampled_channels.append(self._sample_empty_channel(wf_time)) else: - sampled = waveform.get_sampled(channel, wf_time) - if trafo is not None: - sampled = trafo(sampled) - sampled = sampled - offset + if trafo is None: + # sample directly into the final memory + sampled = waveform.get_sampled(channel, wf_time, output_array=final_memory) + else: + # sample into temporary memory and write the trafo result in the final memory + # unfortunately trafo will always allocate :( + sampled = waveform.get_sampled(channel, wf_time, output_array=wf_sample_memory) + assert sampled is wf_sample_memory + final_memory[:] = trafo(sampled) + sampled = final_memory + assert sampled is final_memory + sampled -= offset sampled /= amplitude sampled_channels.append(sampled) sampled_markers = [] - for marker in self._markers: + for marker, mk_mem_pos in zip(self._markers, mk_to_mem): + final_memory = marker_memory[mk_mem_pos, segment_begin:segment_end] + if marker is None: sampled_markers.append(self._sample_empty_marker(wf_time)) else: - sampled_markers.append(waveform.get_sampled(marker, wf_time) != 0) + sampled = waveform.get_sampled(marker, wf_time, output_array=wf_sample_memory) + sampled = numpy.not_equal(sampled, 0., out=final_memory) + assert sampled is final_memory + + sampled_markers.append(sampled) sampled_waveforms.append((tuple(sampled_channels), tuple(sampled_markers))) + + segment_begin = segment_end + assert segment_begin == n_samples return sampled_waveforms diff --git a/qupulse/hardware/awgs/zihdawg.py b/qupulse/hardware/awgs/zihdawg.py index 135deefe6..f2b088409 100644 --- a/qupulse/hardware/awgs/zihdawg.py +++ b/qupulse/hardware/awgs/zihdawg.py @@ -1,6 +1,6 @@ from pathlib import Path import functools -from typing import Tuple, Set, Callable, Optional, Mapping, Generator, Union, List, Dict +from typing import Tuple, Set, Callable, Optional, Mapping, Generator, Union, Sequence, Dict from enum import Enum import weakref import logging @@ -9,14 +9,23 @@ import hashlib import argparse import re +from abc import abstractmethod try: - import zhinst.ziPython - import zhinst.utils + # zhinst fires a DeprecationWarning from its own code in some versions... + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + import zhinst.utils except ImportError: warnings.warn('Zurich Instruments LabOne python API is distributed via the Python Package Index. Install with pip.') raise +try: + from zhinst import core as zhinst_core +except ImportError: + # backward compability + from zhinst import ziPython as zhinst_core + import time from qupulse.utils.types import ChannelID, TimeType, time_from_float @@ -44,6 +53,35 @@ def valid_fn(*args, **kwargs): return valid_fn +def _amplitude_scales(api_session, serial: str): + return tuple( + api_session.getDouble(f'/{serial}/awgs/{ch // 2:d}/outputs/{ch % 2:d}/amplitude') + for ch in range(8) + ) + +def _sigout_double(api_session, prop: str, serial: str, channel: int, value: float = None) -> float: + """Query channel offset voltage and optionally set it.""" + node_path = f'/{serial}/sigouts/{channel-1:d}/{prop}' + if value is not None: + api_session.setDouble(node_path, value) + api_session.sync() # Global sync: Ensure settings have taken effect on the device. + return api_session.getDouble(node_path) + +def _sigout_range(api_session, serial: str, channel: int, voltage: float = None) -> float: + return _sigout_double(api_session, 'range', serial, channel, voltage) + +def _sigout_offset(api_session, serial: str, channel: int, voltage: float = None) -> float: + return _sigout_double(api_session, 'offset', serial, channel, voltage) + +def _sigout_on(api_session, serial: str, channel: int, value: bool = None) -> bool: + """Query channel signal output status (enabled/disabled) and optionally set it. Corresponds to front LED.""" + node_path = f'/{serial}/sigouts/{channel-1:d}/on' + if value is not None: + api_session.setInt(node_path, value) + api_session.sync() # Global sync: Ensure settings have taken effect on the device. + return bool(api_session.getInt(node_path)) + + @traced class HDAWGRepresentation: """HDAWGRepresentation represents an HDAWG8 instruments and manages a LabOne data server api session. A data server @@ -66,7 +104,7 @@ def __init__(self, device_serial: str = None, :param reset: Reset device before initialization :param timeout: Timeout in seconds for uploading """ - self._api_session = zhinst.ziPython.ziDAQServer(data_server_addr, data_server_port, api_level_number) + self._api_session = zhinst_core.ziDAQServer(data_server_addr, data_server_port, api_level_number) assert zhinst.utils.api_server_version_check(self.api_session) # Check equal data server and api version. self.api_session.connectDevice(device_serial, device_interface) self.default_timeout = timeout @@ -79,7 +117,7 @@ def __init__(self, device_serial: str = None, self._initialize() waveform_path = pathlib.Path(self.api_session.awgModule().getString('directory'), 'awg', 'waves') - self._waveform_file_system = WaveformFileSystem(waveform_path) + self._waveform_file_system = WaveformFileSystem.get_waveform_file_system(waveform_path) self._channel_groups: Dict[HDAWGChannelGrouping, Tuple[HDAWGChannelGroup, ...]] = {} # TODO: lookup method to find channel count @@ -87,11 +125,17 @@ def __init__(self, device_serial: str = None, for grouping in HDAWGChannelGrouping: group_size = grouping.group_size() - groups = [] - for group_idx in range(n_channels // group_size): - groups.append(HDAWGChannelGroup(group_idx, group_size, - identifier=self.group_name(group_idx, group_size), - timeout=self.default_timeout)) + if group_size is None: + # MDS + groups = [ + MDSChannelGroup(self.group_name(0, None), self.default_timeout) + ] + else: + groups = [] + for group_idx in range(n_channels // group_size): + groups.append(SingleDeviceChannelGroup(group_idx, group_size, + identifier=self.group_name(group_idx, group_size), + timeout=self.default_timeout)) self._channel_groups[grouping] = tuple(groups) if grouping is None: @@ -105,7 +149,7 @@ def waveform_file_system(self) -> WaveformFileSystem: @property def channel_tuples(self) -> Tuple['HDAWGChannelGroup', ...]: - return self._channel_groups[self.channel_grouping] + return self._get_groups(self.channel_grouping) @property def channel_pair_AB(self) -> 'HDAWGChannelGroup': @@ -124,7 +168,7 @@ def channel_pair_GH(self) -> 'HDAWGChannelGroup': return self._channel_groups[HDAWGChannelGrouping.CHAN_GROUP_4x2][3] @property - def api_session(self) -> zhinst.ziPython.ziDAQServer: + def api_session(self) -> zhinst_core.ziDAQServer: return self._api_session @property @@ -168,10 +212,20 @@ def reset(self) -> None: self.api_session.sync() def group_name(self, group_idx, group_size) -> str: + if group_size is None: + return f'{self.serial}_MDS' return str(self.serial) + '_' + 'ABCDEFGH'[group_idx*group_size:][:group_size] def _get_groups(self, grouping: 'HDAWGChannelGrouping') -> Tuple['HDAWGChannelGroup', ...]: - return self._channel_groups[grouping] + try: + return self._channel_groups[grouping] + except KeyError: + # python reload... + for grouping_key, group in self._channel_groups.items(): + if grouping_key.value == grouping.value: + return group + else: + raise @property def channel_grouping(self) -> 'HDAWGChannelGrouping': @@ -192,6 +246,10 @@ def channel_grouping(self, channel_grouping: 'HDAWGChannelGrouping'): for group in self._get_groups(old_channel_grouping): group.disconnect_group() + if channel_grouping.value == HDAWGChannelGrouping.MDS.value and not self._is_mds_master(): + # do not connect channel group + return + for group in self._get_groups(channel_grouping): if not group.is_connected(): group.connect_group(self) @@ -199,30 +257,18 @@ def channel_grouping(self, channel_grouping: 'HDAWGChannelGrouping'): @valid_channel def offset(self, channel: int, voltage: float = None) -> float: """Query channel offset voltage and optionally set it.""" - node_path = '/{}/sigouts/{:d}/offset'.format(self.serial, channel-1) - if voltage is not None: - self.api_session.setDouble(node_path, voltage) - self.api_session.sync() # Global sync: Ensure settings have taken effect on the device. - return self.api_session.getDouble(node_path) + return _sigout_offset(self.api_session, self.serial, channel, voltage) @valid_channel def range(self, channel: int, voltage: float = None) -> float: """Query channel voltage range and optionally set it. The instruments selects the next higher available range. This is the one-sided range Vp. Total range: -Vp...Vp""" - node_path = '/{}/sigouts/{:d}/range'.format(self.serial, channel-1) - if voltage is not None: - self.api_session.setDouble(node_path, voltage) - self.api_session.sync() # Global sync: Ensure settings have taken effect on the device. - return self.api_session.getDouble(node_path) + return _sigout_range(self.api_session, self.serial, channel, voltage) @valid_channel def output(self, channel: int, status: bool = None) -> bool: """Query channel signal output status (enabled/disabled) and optionally set it. Corresponds to front LED.""" - node_path = '/{}/sigouts/{:d}/on'.format(self.serial, channel-1) - if status is not None: - self.api_session.setInt(node_path, int(status)) - self.api_session.sync() # Global sync: Ensure settings have taken effect on the device. - return bool(self.api_session.getInt(node_path)) + return _sigout_on(self.api_session, self.serial, channel, status) def get_status_table(self): """Return node tree of instrument with all important settings, as well as each channel group as tuple.""" @@ -232,6 +278,31 @@ def get_status_table(self): self.channel_pair_EF.awg_module.get('awgModule/*'), self.channel_pair_GH.awg_module.get('awgModule/*')) + def _get_mds_group_idx(self) -> Optional[int]: + idx = 0 + while True: + try: + if self.serial in self.api_session.getString(f'/ZI/MDS/GROUPS/{idx}/DEVICES'): + return idx + except RuntimeError: + break + idx += 1 + + def _is_mds_master(self) -> Optional[bool]: + idx = 0 + while True: + try: + devices = self.api_session.getString(f'/ZI/MDS/GROUPS/{idx}/DEVICES').split(',') + except RuntimeError: + break + + if self.serial in devices: + return devices[0] == self.serial + idx += 1 + + def __repr__(self): + return f"{type(self).__name__}({self.serial}, ... {self.api_session})" + class HDAWGTriggerOutSource(Enum): """Assign a signal to a marker output. This is per AWG Core.""" @@ -257,6 +328,7 @@ class HDAWGTriggerOutSource(Enum): class HDAWGChannelGrouping(Enum): """How many independent sequencers should run on the AWG and how the outputs should be grouped by sequencer.""" + MDS = -1 # All channels that are in the current multi device synchronized group CHAN_GROUP_4x2 = 0 # 4x2 with HDAWG8; 2x2 with HDAWG4. /dev.../awgs/0..3/ CHAN_GROUP_2x4 = 1 # 2x4 with HDAWG8; 1x4 with HDAWG4. /dev.../awgs/0 & 2/ CHAN_GROUP_1x8 = 2 # 1x8 with HDAWG8. /dev.../awgs/0/ @@ -265,7 +337,8 @@ def group_size(self) -> int: return { HDAWGChannelGrouping.CHAN_GROUP_4x2: 2, HDAWGChannelGrouping.CHAN_GROUP_2x4: 4, - HDAWGChannelGrouping.CHAN_GROUP_1x8: 8 + HDAWGChannelGrouping.CHAN_GROUP_1x8: 8, + HDAWGChannelGrouping.MDS: None }[self] @@ -294,28 +367,13 @@ class HDAWGModulationMode(Enum): @traced class HDAWGChannelGroup(AWG): - """Represents a channel pair of the Zurich Instruments HDAWG as an independent AWG entity. - It represents a set of channels that have to have(hardware enforced) the same control flow and sample rate. - - It keeps track of the AWG state and manages waveforms and programs on the hardware. - """ - MIN_WAVEFORM_LEN = 192 WAVEFORM_LEN_QUANTUM = 16 def __init__(self, - group_idx: int, - group_size: int, identifier: str, timeout: float) -> None: super().__init__(identifier) - self._device = None - - assert group_idx in range(4) - assert group_size in (2, 4, 8) - - self._group_idx = group_idx - self._group_size = group_size self.timeout = timeout self._awg_module = None @@ -326,43 +384,37 @@ def __init__(self, self._current_program = None # Currently armed program. self._upload_generator = () + self._master_device = None + def _initialize_awg_module(self): """Only run once""" if self._awg_module: self._awg_module.clear() - self._awg_module = self.device.api_session.awgModule() - self.awg_module.set('awgModule/device', self.device.serial) - self.awg_module.set('awgModule/index', self.awg_group_index) - self.awg_module.execute() - self._elf_manager = ELFManager(self.awg_module) - - def disconnect_group(self): - """Disconnect this group from device so groups of another size can be used""" - if self._awg_module: - self.awg_module.clear() - self._device = None - - def connect_group(self, hdawg_device: HDAWGRepresentation): - """""" - self.disconnect_group() - self._device = weakref.proxy(hdawg_device) - assert self.device.channel_grouping.group_size() == self._group_size, f"{self.device.channel_grouping} != {self._group_size}" - self._initialize_awg_module() - # Seems creating AWG module sets SINGLE (single execution mode of sequence) to 0 per default. - self.device.api_session.setInt('/{}/awgs/{:d}/single'.format(self.device.serial, self.awg_group_index), 1) + self._awg_module = self.master_device.api_session.awgModule() + self._awg_module.set('awgModule/device', self.master_device.serial) + self._awg_module.set('awgModule/index', self.awg_group_index) + self._awg_module.execute() + self._elf_manager = ELFManager(self._awg_module) + self._upload_generator = () - def is_connected(self) -> bool: - return self._device is not None + @property + def master_device(self) -> HDAWGRepresentation: + """Reference to HDAWG representation.""" + if self._master_device is None: + raise HDAWGValueError('Channel group is currently not connected') + return self._master_device @property - def num_channels(self) -> int: - """Number of channels""" - return self._group_size + def awg_module(self) -> zhinst_core.AwgModule: + """Each AWG channel group has its own awg module to manage program compilation and upload.""" + if self._awg_module is None: + raise HDAWGValueError('Channel group is not connected and was never initialized') + return self._awg_module - def _channels(self, index_start=1) -> Tuple[int, ...]: - """1 indexed channel""" - offset = index_start + self._group_size * self._group_idx - return tuple(ch + offset for ch in range(self._group_size)) + @property + @abstractmethod + def awg_group_index(self) -> int: + raise NotImplementedError() @property def num_markers(self) -> int: @@ -441,12 +493,13 @@ def upload(self, name: str, offsets=voltage_offsets) self._required_seqc_source = self._program_manager.to_seqc_program() - self._program_manager.waveform_memory.sync_to_file_system(self.device.waveform_file_system) + self._program_manager.waveform_memory.sync_to_file_system(self.master_device.waveform_file_system) # start compiling the source (non-blocking) self._start_compile_and_upload() def _start_compile_and_upload(self): + self._uploaded_seqc_source = None self._upload_generator = self._elf_manager.compile_and_upload(self._required_seqc_source) def _wait_for_compile_and_upload(self): @@ -498,6 +551,13 @@ def arm(self, name: Optional[str]) -> None: Currently hardware triggering is not implemented. The HDAWGProgramManager needs to emit code that calls `waitDigTrigger` to do that. """ + if self.num_channels > 8: + if name is None: + self._required_seqc_source = "" + else: + self._required_seqc_source = self._program_manager.to_seqc_program(name) + self._start_compile_and_upload() + if self._required_seqc_source != self._uploaded_seqc_source: self._wait_for_compile_and_upload() @@ -521,9 +581,9 @@ def arm(self, name: Optional[str]) -> None: self.user_register(self._program_manager.Constants.PROG_SEL_REGISTER, self._program_manager.name_to_index(name) | int(self._program_manager.Constants.NO_RESET_MASK, 2)) - # this is a workaround for problems in the past and should be re-thought in case of a re-write - for ch_pair in self.device.channel_tuples: - ch_pair._wait_for_compile_and_upload() + # this was a workaround for problems in the past and I totally forgot why it was here + # for ch_pair in self.master.channel_tuples: + # ch_pair._wait_for_compile_and_upload() self.enable(True) def run_current_program(self) -> None: @@ -546,49 +606,33 @@ def programs(self) -> Set[str]: @property def sample_rate(self) -> TimeType: """The default sample rate of the AWG channel group.""" - node_path = '/{}/awgs/{}/time'.format(self.device.serial, self.awg_group_index) - sample_rate_num = self.device.api_session.getInt(node_path) - node_path = '/{}/system/clocks/sampleclock/freq'.format(self.device.serial) - sample_clock = self.device.api_session.getDouble(node_path) + node_path = '/{}/awgs/{}/time'.format(self.master_device.serial, self.awg_group_index) + sample_rate_num = self.master_device.api_session.getInt(node_path) + node_path = '/{}/system/clocks/sampleclock/freq'.format(self.master_device.serial) + sample_clock = self.master_device.api_session.getDouble(node_path) """Calculate exact rational number based on (sample_clock Sa/s) / 2^sample_rate_num. Otherwise numerical imprecision will give rise to errors for very long pulses. fractions.Fraction does not accept floating point numerator, which sample_clock could potentially be.""" return time_from_float(sample_clock) / 2 ** sample_rate_num - @property - def awg_group_index(self) -> int: - """AWG node group index assuming 4x2 channel grouping. Then 0...3 will give appropriate index of group.""" - return self._group_idx - - @property - def device(self) -> HDAWGRepresentation: - """Reference to HDAWG representation.""" - if self._device is None: - raise HDAWGValueError('Channel group is currently not connected') - return self._device - - @property - def awg_module(self) -> zhinst.ziPython.AwgModule: - """Each AWG channel group has its own awg module to manage program compilation and upload.""" - if self._awg_module is None: - raise HDAWGValueError('Channel group is not connected and was never initialized') - return self._awg_module + def connect_group(self, hdawg_device: HDAWGRepresentation): + self.disconnect_group() + self._master_device = weakref.proxy(hdawg_device) + self._initialize_awg_module() + # Seems creating AWG module sets SINGLE (single execution mode of sequence) to 0 per default. + self.master_device.api_session.setInt(f'/{self.master_device.serial}/awgs/0/single', 1) - @property - def user_directory(self) -> str: - """LabOne user directory with subdirectories: "awg/src" (seqc sourcefiles), "awg/elf" (compiled AWG binaries), - "awag/waves" (user defined csv waveforms).""" - return self.awg_module.getString('awgModule/directory') + def disconnect_group(self): + """Disconnect this group from device so groups of another size can be used""" + if self._awg_module: + self.awg_module.clear() + self._master_device = None + self._elf_manager = None + self._upload_generator = () - def enable(self, status: bool = None) -> bool: - """Start the AWG sequencer.""" - # There is also 'awgModule/awg/enable', which seems to have the same functionality. - node_path = '/{}/awgs/{:d}/enable'.format(self.device.serial, self.awg_group_index) - if status is not None: - self.device.api_session.setInt(node_path, int(status)) - self.device.api_session.sync() # Global sync: Ensure settings have taken effect on the device. - return bool(self.device.api_session.getInt(node_path)) + def is_connected(self) -> bool: + return self._master_device is not None def user_register(self, reg: UserRegister, value: int = None) -> int: """Query user registers (1-16) and optionally set it. @@ -605,18 +649,69 @@ def user_register(self, reg: UserRegister, value: int = None) -> int: reg = UserRegister(one_based_value=reg) if reg.to_web_interface() not in range(1, 17): - raise HDAWGValueError('{reg:repr} not a valid (1-16) register.'.format(reg=reg)) + raise HDAWGValueError(f'{reg:!r} not a valid (1-16) register.') - node_path = '/{}/awgs/{:d}/userregs/{:labone}'.format(self.device.serial, self.awg_group_index, reg) + node_path = '/{}/awgs/{:d}/userregs/{:labone}'.format(self.master_device.serial, self.awg_group_index, reg) if value is not None: - self.device.api_session.setInt(node_path, value) - self.device.api_session.sync() # Global sync: Ensure settings have taken effect on the device. - return self.device.api_session.getInt(node_path) + self.master_device.api_session.setInt(node_path, value) + # hackedy + for mds_serial in getattr(self, '_mds_devices', [])[1:]: + self.master_device.api_session.setInt(node_path.replace(self.master_device.serial, mds_serial), value) + self.master_device.api_session.sync() # Global sync: Ensure settings have taken effect on the device. + return self.master_device.api_session.getInt(node_path) - def _amplitude_scales(self) -> Tuple[float, ...]: - """not affected by grouping""" - return tuple(self.device.api_session.getDouble(f'/{self.device.serial}/awgs/{ch // 2:d}/outputs/{ch % 2:d}/amplitude') - for ch in self._channels(index_start=0)) + +@traced +class MDSChannelGroup(HDAWGChannelGroup): + def __init__(self, + identifier: str, + timeout: float) -> None: + super().__init__(identifier, timeout) + + self._master_device = None + self._mds_devices = None + + @property + def num_channels(self) -> int: + """Number of channels""" + return len(self._mds_devices) * 8 + + @property + def awg_group_index(self): + return 0 + + def disconnect_group(self): + super().disconnect_group() + self._mds_devices = None + + def connect_group(self, hdawg_device: HDAWGRepresentation): + mds_group = hdawg_device._get_mds_group_idx() + if mds_group is None: + raise HDAWGException("AWG not in any MDS group", hdawg_device) + mds_devices = hdawg_device.api_session.getString(f'/ZI/MDS/GROUPS/{mds_group}/DEVICES').split(',') + if hdawg_device.serial != mds_devices[0]: + raise HDAWGException("Only the master device can connect to the HDAWG MDS channel group.") + super().connect_group(hdawg_device) + self._mds_devices = mds_devices + + def enable(self, status: bool = None) -> bool: + """Start the AWG sequencer.""" + # There is also 'awgModule/awg/enable', which seems to have the same functionality. + node_path = '/{}/awgs/{:d}/enable'.format(self.master_device.serial, 0) + if status is not None: + self.awg_module.set('awg/enable', int(status)) + else: + status = self.awg_module.get('awg/module') + + #return bool(status) + """ + if status is not None: + self.master_device.api_session.setInt(node_path, int(status)) + for mds_device in self._mds_devices[1:]: + self.master_device.api_session.setInt(node_path.replace(self._mds_devices[0], mds_device), int(status)) + self.master_device.api_session.sync() # Global sync: Ensure settings have taken effect on the device. + """ + return bool(self.master_device.api_session.getInt(node_path)) def amplitudes(self) -> Tuple[float, ...]: """Query AWG channel amplitude value (not peak to peak). @@ -628,18 +723,88 @@ def amplitudes(self) -> Tuple[float, ...]: stored in the waveform memory.""" amplitudes = [] - for ch, zi_amplitude in zip(self._channels(), self._amplitude_scales()): - zi_range = self.device.range(ch) + api_session = self.master_device.api_session + for mds_device in self._mds_devices: + amplitude_scales = _amplitude_scales(api_session, mds_device) + ranges = [_sigout_range(api_session, mds_device, ch) for ch in range(1, 9)] + amplitudes.extend(zi_amplitude * zi_range / 2 for zi_amplitude, zi_range in zip(amplitude_scales, ranges)) + return tuple(amplitudes) + + def offsets(self) -> Tuple[float, ...]: + offsets = [] + api_session = self.master_device.api_session + for mds_device in self._mds_devices: + offsets.extend(_sigout_offset(api_session, mds_device, ch) for ch in range(1, 9)) + return tuple(offsets) + + +class SingleDeviceChannelGroup(HDAWGChannelGroup): + def __init__(self, + group_idx: int, + group_size: int, + identifier: str, + timeout: float) -> None: + super().__init__(identifier, timeout) + self._device = None + + assert group_idx in range(4) + assert group_size in (2, 4, 8) + + self._group_idx = group_idx + self._group_size = group_size + + @property + def num_channels(self) -> int: + """Number of channels""" + return self._group_size + + def _channels(self, index_start=1) -> Tuple[int, ...]: + """1 indexed channel""" + offset = index_start + self._group_size * self._group_idx + return tuple(ch + offset for ch in range(self.num_channels)) + + @property + def awg_group_index(self) -> int: + """AWG node group index assuming 4x2 channel grouping. Then 0...3 will give appropriate index of group.""" + return self._group_idx + + @property + def user_directory(self) -> str: + """LabOne user directory with subdirectories: "awg/src" (seqc sourcefiles), "awg/elf" (compiled AWG binaries), + "awag/waves" (user defined csv waveforms).""" + return self.awg_module.getString('awgModule/directory') + + def enable(self, status: bool = None) -> bool: + """Start the AWG sequencer.""" + # There is also 'awgModule/awg/enable', which seems to have the same functionality. + node_path = '/{}/awgs/{:d}/enable'.format(self.master_device.serial, self.awg_group_index) + if status is not None: + self.master_device.api_session.setInt(node_path, int(status)) + self.master_device.api_session.sync() # Global sync: Ensure settings have taken effect on the device. + return bool(self.master_device.api_session.getInt(node_path)) + + def amplitudes(self) -> Tuple[float, ...]: + """Query AWG channel amplitude value (not peak to peak). + + From manual: + The final signal amplitude is given by the product of the full scale + output range of 1 V[in this example], the dimensionless amplitude + scaling factor 1.0, and the actual dimensionless signal amplitude + stored in the waveform memory.""" + amplitudes = [] + + for ch, zi_amplitude in zip(self._channels(), _amplitude_scales(self.master_device.api_session, self.master_device.serial)): + zi_range = self.master_device.range(ch) amplitudes.append(zi_amplitude * zi_range / 2) return tuple(amplitudes) def offsets(self) -> Tuple[float, ...]: - return tuple(map(self.device.offset, self._channels())) + return tuple(map(self.master_device.offset, self._channels())) class ELFManager: class AWGModule: - def __init__(self, awg_module: zhinst.ziPython.AwgModule): + def __init__(self, awg_module: zhinst_core.AwgModule): """Provide an easily mockable interface to the zhinst AwgModule object""" self._module = awg_module @@ -706,11 +871,11 @@ def elf_status(self) -> Tuple[int, float]: def index(self) -> int: return self._module.getInt('index') - def __init__(self, awg_module: zhinst.ziPython.AwgModule): + def __init__(self, awg_module: zhinst_core.AwgModule): """This class organizes compiling and uploading of compiled programs. The source code file is named based on the code hash to cache compilation results. This requires that the waveform names are unique. - The compilation and upload itself are done asynchronously by zhinst.ziPython. To avoid spawning a useless + The compilation and upload itself are done asynchronously by zhinst.core. To avoid spawning a useless thread for updating the status the method :py:meth:`~ELFManager.compile_and_upload` returns a generator which talks to the undelying library when needed.""" self.awg_module = self.AWGModule(awg_module) @@ -832,7 +997,10 @@ def _update_upload_job_status(self): assert elf_upload == 0 def _upload(self, elf_file) -> Generator[str, str, None]: - self._start_elf_upload(elf_file) + if self.awg_module.compiler_upload: + pass + else: + self._start_elf_upload(elf_file) while True: self._update_upload_job_status() diff --git a/qupulse/hardware/dacs/alazar.py b/qupulse/hardware/dacs/alazar.py index 6d4b12abe..7398e2115 100644 --- a/qupulse/hardware/dacs/alazar.py +++ b/qupulse/hardware/dacs/alazar.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Dict, Any, Optional, Tuple, List, Iterable, Callable, Sequence from collections import defaultdict import copy @@ -15,18 +16,50 @@ from qupulse.utils.types import TimeType from qupulse.hardware.dacs.dac_base import DAC from qupulse.hardware.util import traced - +from qupulse.utils.performance import time_windows_to_samples logger = logging.getLogger(__name__) -class AlazarProgram: - def __init__(self): - self._sample_factor = None - self._masks = {} - self.operations = [] - self._total_length = None - self._auto_rearm_count = 1 +def _windows_to_samples(begins: np.ndarray, lengths: np.ndarray, + sample_rate: TimeType) -> Tuple[np.ndarray, np.ndarray]: + return time_windows_to_samples(begins, lengths, float(sample_rate)) + + +@dataclasses.dataclass +class AcquisitionProgram: + _sample_rate: Optional[TimeType] = dataclasses.field(default=None) + _masks: dict = dataclasses.field(default_factory=dict) + + @property + def sample_rate(self) -> Optional[TimeType]: + return self._sample_rate + + def set_measurement_mask(self, mask_name: str, sample_rate: TimeType, + begins: np.ndarray, lengths: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Raise error if sample factor has changed""" + if self._sample_rate is None: + self._sample_rate = sample_rate + elif sample_rate != self.sample_rate: + raise RuntimeError('class AcquisitionProgram has already masks with differing sample rate.') + + assert begins.dtype == float and lengths.dtype == float + + begins, lengths = self._masks[mask_name] = _windows_to_samples(begins, lengths, sample_rate) + + return begins, lengths + + def clear_masks(self): + self._masks.clear() + self._sample_rate = None + + +@dataclasses.dataclass +class AlazarProgram(AcquisitionProgram): + operations: Sequence = dataclasses.field(default_factory=list) + _total_length: Optional[int] = dataclasses.field(default=None) + _auto_rearm_count: int = dataclasses.field(default=1) + buffer_strategy: Optional = dataclasses.field(default=None) def masks(self, mask_maker: Callable[[str, np.ndarray, np.ndarray], Mask]) -> List[Mask]: return [mask_maker(mask_name, *data) for mask_name, data in self._masks.items()] @@ -61,39 +94,6 @@ def auto_rearm_count(self, value: int): raise ValueError("Trigger count has to be in the interval [0, 2**64-1]") self._auto_rearm_count = trigger_count - def clear_masks(self): - self._masks.clear() - - @property - def sample_factor(self) -> Optional[TimeType]: - return self._sample_factor - - def set_measurement_mask(self, mask_name: str, sample_factor: TimeType, - begins: np.ndarray, lengths: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Raise error if sample factor has changed""" - if self._sample_factor is None: - self._sample_factor = sample_factor - - elif sample_factor != self._sample_factor: - raise RuntimeError('class AlazarProgram has already masks with differing sample factor') - - assert begins.dtype == float and lengths.dtype == float - - # optimization potential here (hash input?) - begins = np.rint(begins * float(sample_factor)).astype(dtype=np.uint64) - lengths = np.floor_divide(lengths * float(sample_factor.numerator), float(sample_factor.denominator)).astype(dtype=np.uint64) - - sorting_indices = np.argsort(begins) - begins = begins[sorting_indices] - lengths = lengths[sorting_indices] - - begins.flags.writeable = False - lengths.flags.writeable = False - - self._masks[mask_name] = begins, lengths - - return begins, lengths - def iter(self, mask_maker): yield self.masks(mask_maker) yield self.operations @@ -294,19 +294,19 @@ def _make_mask(self, mask_id: str, begins, lengths) -> Mask: return mask def set_measurement_mask(self, program_name, mask_name, begins, lengths) -> Tuple[np.ndarray, np.ndarray]: - sample_factor = TimeType.from_fraction(int(self.default_config.captureClockConfiguration.numeric_sample_rate(self.card.model)), 10**9) - return self._registered_programs[program_name].set_measurement_mask(mask_name, sample_factor, begins, lengths) + sample_rate = TimeType.from_fraction(int(self.default_config.captureClockConfiguration.numeric_sample_rate(self.card.model)), 10**9) + return self._registered_programs[program_name].set_measurement_mask(mask_name, sample_rate, begins, lengths) def register_measurement_windows(self, program_name: str, windows: Dict[str, Tuple[np.ndarray, np.ndarray]]) -> None: program = self._registered_programs[program_name] - sample_factor = TimeType.from_fraction(int(self.default_config.captureClockConfiguration.numeric_sample_rate(self.card.model)), + sample_rate = TimeType.from_fraction(int(self.default_config.captureClockConfiguration.numeric_sample_rate(self.card.model)), 10 ** 9) program.clear_masks() for mask_name, (begins, lengths) in windows.items(): - program.set_measurement_mask(mask_name, sample_factor, begins, lengths) + program.set_measurement_mask(mask_name, sample_rate, begins, lengths) def register_operations(self, program_name: str, operations) -> None: self._registered_programs[program_name].operations = operations @@ -323,10 +323,10 @@ def arm_program(self, program_name: str) -> None: config.masks, config.operations, total_record_size = self._registered_programs[program_name].iter( self._make_mask) - sample_rate = config.captureClockConfiguration.numeric_sample_rate(self.card.model) + sample_rate_in_hz = config.captureClockConfiguration.numeric_sample_rate(self.card.model) # sample rate in GHz - sample_factor = TimeType.from_fraction(sample_rate, 10 ** 9) + sample_rate = TimeType.from_fraction(sample_rate_in_hz, 10 ** 9) if not config.operations: raise RuntimeError("No operations: Arming program without operations is an error as there will " @@ -335,9 +335,9 @@ def arm_program(self, program_name: str) -> None: elif not config.masks: raise RuntimeError("No masks although there are operations in program: %r" % program_name) - elif self._registered_programs[program_name].sample_factor != sample_factor: + elif self._registered_programs[program_name].sample_rate != sample_rate: raise RuntimeError("Masks were registered with a different sample rate {}!={}".format( - self._registered_programs[program_name].sample_factor, sample_factor)) + self._registered_programs[program_name].sample_rate, sample_rate)) assert total_record_size > 0 diff --git a/qupulse/hardware/dacs/alazar2.py b/qupulse/hardware/dacs/alazar2.py new file mode 100644 index 000000000..641606d91 --- /dev/null +++ b/qupulse/hardware/dacs/alazar2.py @@ -0,0 +1,132 @@ +from typing import Union, Iterable, Dict, Tuple, Mapping, Optional +from types import MappingProxyType +import logging + +import numpy + +from qupulse.utils.types import TimeType +from qupulse.hardware.dacs.dac_base import DAC +from qupulse.hardware.dacs.alazar import AlazarProgram + +import atsaverage +from atsaverage.masks import make_best_mask +from atsaverage.config2 import BoardConfiguration, create_scanline_definition, BufferStrategySettings + + +logger = logging.getLogger(__name__) + + +class AlazarCard(DAC): + def __init__(self, atsaverage_card: 'atsaverage.core.AlazarCard'): + super().__init__() + self._atsaverage_card = atsaverage_card + self._registered_programs = {} + + # for auto retrigger + self._armed_program: Optional[str] = None + self._remaining_auto_triggers = 0 + + # for debugging purposes + self._raw_data_mask = None + self.default_buffer_strategy: Optional[BufferStrategySettings] = None + + @property + def atsaverage_card(self): + return self._atsaverage_card + + @property + def registered_programs(self) -> Mapping[str, AlazarProgram]: + return MappingProxyType(self._registered_programs) + + @property + def current_sample_rate_in_giga_herz(self) -> TimeType: + numeric_sample_rate = self._atsaverage_card.board_configuration_cache.get_numeric_sample_rate() + if numeric_sample_rate is None: + raise RuntimeError("The sample rate was not set yet. The instrument does not support retrieving the sample " + "rate via an API. We need to cache a set command.") + return TimeType.from_fraction(numeric_sample_rate, 10 ** 9) + + def get_current_input_range(self, channel: Union[str, int]): + input_range = self._atsaverage_card.board_configuration_cache.get_channel_input_range(channel) + if input_range is None: + raise RuntimeError("The input range was not set yet. The instrument does not support retrieving the input " + "range via an API. We need to cache a set command.") + return input_range + + def register_measurement_windows(self, program_name: str, windows: Dict[str, Tuple[numpy.ndarray, + numpy.ndarray]]) -> None: + program = self._registered_programs.setdefault(program_name, AlazarProgram()) + sample_rate = self.current_sample_rate_in_giga_herz + program.clear_masks() + for mask_name, (begins, lengths) in windows.items(): + program.set_measurement_mask(mask_name, sample_rate, begins, lengths) + + def set_measurement_mask(self, program_name: str, mask_name: str, + begins: numpy.ndarray, lengths: numpy.ndarray) -> Tuple[numpy.ndarray, numpy.ndarray]: + program = self._registered_programs.setdefault(program_name, AlazarProgram()) + return program.set_measurement_mask(mask_name, self.current_sample_rate_in_giga_herz, begins, lengths) + + def register_operations(self, program_name: str, operations) -> None: + self._registered_programs.setdefault(program_name, AlazarProgram()).operations = operations + + def _make_scanline_definition(self, program: AlazarProgram): + sample_rate_in_ghz = self.current_sample_rate_in_giga_herz + sample_rate_in_hz = int(sample_rate_in_ghz * 10 ** 9) + + masks = program.masks(make_best_mask) + if sample_rate_in_ghz != program.sample_rate: + raise RuntimeError("Masks were registered with a different sample rate") + return create_scanline_definition(masks, program.operations, + raw_data_mask=self._raw_data_mask, + board_spec=self._atsaverage_card.get_board_spec(), + buffer_strategy=program.buffer_strategy, + numeric_sample_rate=sample_rate_in_hz) + + def _prepare_program(self, program: AlazarProgram): + scanline_definition = self._make_scanline_definition(program) + self._atsaverage_card.configureMeasurement(scanline_definition) + + def arm_program(self, program_name: str) -> None: + logger.debug("Arming program %s on %r", program_name, self._atsaverage_card) + + if program_name == self._armed_program and self._remaining_auto_triggers > 0: + logger.info("Relying on atsaverage auto-arm with %d auto triggers remaining after this one: %d", + self._remaining_auto_triggers) + + else: + program = self._registered_programs[program_name] + scanline_definition = self._make_scanline_definition(program) + + self._atsaverage_card.configureMeasurement(scanline_definition) + + self._atsaverage_card.startAcquisition(program.auto_rearm_count) + self._remaining_auto_triggers = program.auto_rearm_count - 1 + + def delete_program(self, program_name: str) -> None: + self._registered_programs.pop(program_name) + + def clear(self) -> None: + self._registered_programs.clear() + + def measure_program(self, channels: Iterable[str] = None) -> Dict[str, numpy.ndarray]: + scanline_data = self._atsaverage_card.extractNextScanline() + + if channels is None: + channels = scanline_data.operationResults.keys() + + scanline_definition = scanline_data.definition + operation_definitions = {operation.identifier: operation + for operation in scanline_definition.operations} + mask_definitions = {mask.identifier: mask + for mask in scanline_definition.masks} + + def get_input_range(operation_id: str): + mask_id = operation_definitions[operation_id].maskID + hw_channel = int(mask_definitions[mask_id].channel) + return self.get_current_input_range(hw_channel) + + data = {} + for op_name in channels: + input_range = get_input_range(op_name) + data[op_name] = scanline_data.operationResults[op_name].getAsVoltage(input_range) + return data diff --git a/qupulse/hardware/dacs/dac_base.py b/qupulse/hardware/dacs/dac_base.py index eb7670733..e68802576 100644 --- a/qupulse/hardware/dacs/dac_base.py +++ b/qupulse/hardware/dacs/dac_base.py @@ -12,30 +12,52 @@ class DAC(metaclass=ABCMeta): @abstractmethod def register_measurement_windows(self, program_name: str, windows: Dict[str, Tuple[numpy.ndarray, numpy.ndarray]]) -> None: - """""" + """Register measurement windows for a given program. Overwrites previously defined measurement windows for + this program. + + Args: + program_name: Name of the program + windows: Measurement windows by name. + First array are the start points of measurement windows in nanoseconds. + Second array are the corresponding measurement window's lengths in nanoseconds. + """ @abstractmethod - def set_measurement_mask(self, program_name, mask_name, begins, lengths) -> Tuple[numpy.ndarray, numpy.ndarray]: - """returns length of windows in samples""" + def set_measurement_mask(self, program_name: str, mask_name: str, + begins: numpy.ndarray, + lengths: numpy.ndarray) -> Tuple[numpy.ndarray, numpy.ndarray]: + """Set/overwrite a single the measurement mask for a program. Begins and lengths are in nanoseconds. + + Args: + program_name: Name of the program + mask_name: Name of the mask/measurement windows + begins: Staring points in nanoseconds + lengths: Lengths in nanoseconds + + Returns: + Measurement windows in DAC samples (begins, lengths) + """ @abstractmethod def register_operations(self, program_name: str, operations) -> None: - """""" + """Register operations that are to be applied to the measurement results. + + Args: + program_name: Name of the program + operations: DAC specific instructions what to do with the data recorded by the device. + """ @abstractmethod def arm_program(self, program_name: str) -> None: - """""" + """Prepare the device for measuring the given program and wait for a trigger event.""" @abstractmethod def delete_program(self, program_name) -> None: - """""" + """Delete program from internal memory.""" @abstractmethod def clear(self) -> None: - """Clears all registered programs. - - Caution: This affects all programs and waveforms on the AWG, not only those uploaded using qupulse! - """ + """Clears all registered programs.""" @abstractmethod def measure_program(self, channels: Iterable[str]) -> Dict[str, numpy.ndarray]: diff --git a/qupulse/hardware/setup.py b/qupulse/hardware/setup.py index bce719032..646815742 100644 --- a/qupulse/hardware/setup.py +++ b/qupulse/hardware/setup.py @@ -70,7 +70,7 @@ def __init__(self, awg: AWG, channel_on_awg: int): RegisteredProgram = NamedTuple('RegisteredProgram', [('program', Loop), - ('measurement_windows', Dict[str, Tuple[float, float]]), + ('measurement_windows', Dict[str, Tuple[np.ndarray, np.ndarray]]), ('run_callback', Callable), ('awgs_to_upload_to', Set[AWG]), ('dacs_to_arm', Set[DAC])]) @@ -102,7 +102,7 @@ def register_program(self, name: str, channels - set(self._channel_map.keys()))) temp_measurement_windows = defaultdict(list) - for mw_name, begins_lengths in program.get_measurement_windows().items(): + for mw_name, begins_lengths in program.get_measurement_windows(drop=True).items(): temp_measurement_windows[mw_name].append(begins_lengths) if set(temp_measurement_windows.keys()) - set(self._measurement_map.keys()): diff --git a/qupulse/hardware/util.py b/qupulse/hardware/util.py index 365e47662..30a9aae34 100644 --- a/qupulse/hardware/util.py +++ b/qupulse/hardware/util.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Tuple, Union, Collection +from typing import Collection, Sequence, Tuple, Union, Optional import itertools import numpy as np @@ -14,31 +14,86 @@ def traced(obj): from qupulse.utils.types import TimeType from qupulse.utils import pairwise +try: + import numba + njit = numba.njit +except ImportError: + numba = None + njit = lambda x: x + +try: + import zhinst +except ImportError: # pragma: no cover + zhinst = None + +__all__ = ['voltage_to_uint16', 'get_sample_times', 'traced', 'zhinst_voltage_to_uint16'] + + +@njit +def _voltage_to_uint16_numba(voltage: np.ndarray, output_amplitude: float, output_offset: float, resolution: int) -> np.ndarray: + """Implementation detail that can be compiled with numba. This code is very slow without numba.""" + out_of_range = False + scale = (2 ** resolution - 1) / (2 * output_amplitude) + result = np.empty_like(voltage, dtype=np.uint16) + for i in range(voltage.size): + x = voltage[i] - output_offset + if np.abs(x) > output_amplitude: + out_of_range = True + result[i] = np.uint16(np.rint((x + output_amplitude) * scale)) + + if out_of_range: + raise ValueError('Voltage out of range') + + return result -__all__ = ['voltage_to_uint16', 'get_sample_times', 'traced'] + +def _voltage_to_uint16_numpy(voltage: np.ndarray, output_amplitude: float, output_offset: float, resolution: int) -> np.ndarray: + """Implementation detail to be used if numba is not available.""" + non_dc_voltage = voltage - output_offset + if np.any(np.abs(non_dc_voltage) > output_amplitude): + # should get more context in wrapper function + raise ValueError('Voltage out of range') + + non_dc_voltage += output_amplitude + non_dc_voltage *= (2**resolution - 1) / (2*output_amplitude) + return np.rint(non_dc_voltage).astype(np.uint16) def voltage_to_uint16(voltage: np.ndarray, output_amplitude: float, output_offset: float, resolution: int) -> np.ndarray: - """ + """Convert values of the range + [output_offset - output_amplitude, output_offset + output_amplitude) + to uint16 in the range + [0, 2**resolution) - :param voltage: - :param output_amplitude: - :param output_offset: - :param resolution: - :return: + output_offset - output_amplitude -> 0 + output_offset -> 2**(resolution - 1) + output_offset + output_amplitude -> 2**resolution - 1 + + Args: + voltage: input voltage. read-only + output_amplitude: input divided by this + output_offset: is subtracted from input + resolution: Target resolution in bits (determines the output range) + + Raises: + ValueError if the voltage is out of range or the resolution is not an integer + + Returns: + (voltage - output_offset + output_amplitude) * (2**resolution - 1) / (2*output_amplitude) as uint16 """ if resolution < 1 or not isinstance(resolution, int): raise ValueError('The resolution must be an integer > 0') - non_dc_voltage = voltage - output_offset - if np.any(np.abs(non_dc_voltage) > output_amplitude): - raise ValueError('Voltage of range', dict(voltage=voltage, - output_offset=output_offset, - output_amplitude=output_amplitude)) - non_dc_voltage += output_amplitude - non_dc_voltage *= (2**resolution - 1) / (2*output_amplitude) - np.rint(non_dc_voltage, out=non_dc_voltage) - return non_dc_voltage.astype(np.uint16) + try: + if numba: + impl = _voltage_to_uint16_numba + else: + impl = _voltage_to_uint16_numpy + return impl(voltage, output_amplitude, output_offset, resolution) + except ValueError as err: + raise ValueError('Voltage out of range', dict(voltage=voltage, + output_offset=output_offset, + output_amplitude=output_amplitude)) from err def find_positions(data: Sequence, to_find: Sequence) -> np.ndarray: @@ -56,8 +111,9 @@ def find_positions(data: Sequence, to_find: Sequence) -> np.ndarray: return positions + def get_waveform_length(waveform: Waveform, - sample_rate_in_GHz: TimeType, tolerance: float = 1e-10) -> int: + sample_rate_in_GHz: TimeType, tolerance: float = 1e-10) -> int: """Calculates the number of samples in a waveform If only one waveform is given, the number of samples has shape () @@ -93,6 +149,7 @@ def get_waveform_length(waveform: Waveform, return segment_length + def get_sample_times(waveforms: Union[Collection[Waveform], Waveform], sample_rate_in_GHz: TimeType, tolerance: float = 1e-10) -> Tuple[np.array, np.array]: """Calculates the sample times required for the longest waveform in waveforms and returns it together with an array @@ -126,3 +183,120 @@ def get_sample_times(waveforms: Union[Collection[Waveform], Waveform], time_array = np.arange(np.max(segment_lengths), dtype=float) / float(sample_rate_in_GHz) return time_array, segment_lengths + + +@njit +def _zhinst_voltage_to_uint16_numba(size: int, ch1: Optional[np.ndarray], ch2: Optional[np.ndarray], + m1_front: Optional[np.ndarray], m1_back: Optional[np.ndarray], + m2_front: Optional[np.ndarray], m2_back: Optional[np.ndarray]) -> np.ndarray: + """Numba targeted implementation""" + data = np.zeros((size, 3), dtype=np.uint16) + + scale = float(2**15 - 1) + + invalid_value = None + + def has_invalid_size(arr): + return arr is not None and len(arr) != size + + if has_invalid_size(ch1) or has_invalid_size(ch2) or has_invalid_size(m1_front) or has_invalid_size(m1_back) or has_invalid_size(m2_front) or has_invalid_size(m2_back): + raise ValueError("One of the inputs does not have the given size.") + + for i in range(size): + if ch1 is not None: + if not abs(ch1[i]) <= 1: + invalid_value = ch1[i] + data[i, 0] = ch1[i] * scale + if ch2 is not None: + if not abs(ch2[i]) <= 1: + invalid_value = ch2[i] + data[i, 1] = ch2[i] * scale + if m1_front is not None: + data[i, 2] |= (m1_front[i] != 0) + if m1_back is not None: + data[i, 2] |= (m1_back[i] != 0) << 1 + if m2_front is not None: + data[i, 2] |= (m2_front[i] != 0) << 2 + if m2_back is not None: + data[i, 2] |= (m2_back[i] != 0) << 3 + + if invalid_value is not None: + # we can only use compile time constants here + raise ValueError('Encountered an invalid value in channel data (not in [-1, 1])') + + return data.ravel() + + +def _zhinst_voltage_to_uint16_numpy(size: int, ch1: Optional[np.ndarray], ch2: Optional[np.ndarray], + m1_front: Optional[np.ndarray], m1_back: Optional[np.ndarray], + m2_front: Optional[np.ndarray], m2_back: Optional[np.ndarray]) -> np.ndarray: + """Fallback implementation if numba is not available""" + markers = (m1_front, m1_back, m2_front, m2_back) + + def check_invalid_values(ch_data): + # like this to catch NaN + invalid = ~(np.abs(ch_data) <= 1) + if np.any(invalid): + raise ValueError('Encountered an invalid value in channel data (not in [-1, 1])', ch_data[invalid][-1]) + + if ch1 is None: + ch1 = np.zeros(size) + else: + check_invalid_values(ch1) + if ch2 is None: + ch2 = np.zeros(size) + else: + check_invalid_values(ch1) + marker_data = np.zeros(size, dtype=np.uint16) + for idx, marker in enumerate(markers): + if marker is not None: + marker_data += np.uint16((marker > 0) * 2 ** idx) + return zhinst.utils.convert_awg_waveform(ch1, ch2, marker_data) + + +def zhinst_voltage_to_uint16(ch1: Optional[np.ndarray], ch2: Optional[np.ndarray], + markers: Tuple[Optional[np.ndarray], Optional[np.ndarray], + Optional[np.ndarray], Optional[np.ndarray]]) -> np.ndarray: + """Potentially (if numba is installed) faster version of zhinst.utils.convert_awg_waveform + + Args: + ch1: Sampled data of channel 1 [-1, 1] + ch2: Sampled data of channel 1 [-1, 1] + markers: Marker data of (ch1_front, ch1_back, ch2_front, ch2_back) + + Returns: + Interleaved data in the correct format (u16). The first bit is the sign bit so the data needs to be interpreted + as i16. + """ + all_input = (ch1, ch2, *markers) + size = {x.size for x in all_input if x is not None} + if not size: + raise ValueError("No input arrays") + elif len(size) != 1: + raise ValueError("Inputs have incompatible dimension") + size, = size + size = int(size) + + if numba is not None: + try: + return _zhinst_voltage_to_uint16_numba(size, *all_input) + except ValueError: + # use the exception from numpy version + pass + return _zhinst_voltage_to_uint16_numpy(size, *all_input) + + +def not_none_indices(seq: Sequence) -> Tuple[Sequence[Optional[int]], int]: + """Calculate lookup table from sparse to non sparse indices and the total number of not None elements + + assert ([None, 0, 1, None, None, 2], 3) == not_none_indices([None, 'a', 'b', None, None, 'c']) + """ + indices = [] + idx = 0 + for elem in seq: + if elem is None: + indices.append(elem) + else: + indices.append(idx) + idx += 1 + return indices, idx diff --git a/qupulse/pulses/abstract_pulse_template.py b/qupulse/pulses/abstract_pulse_template.py index 9a257e4b5..05e75307c 100644 --- a/qupulse/pulses/abstract_pulse_template.py +++ b/qupulse/pulses/abstract_pulse_template.py @@ -146,6 +146,10 @@ def _internal_create_program(self, **kwargs): doc=_PROPERTY_DOC.format(name='integral')) parameter_names = property(partial(_get_property, property_name='parameter_names'), doc=_PROPERTY_DOC.format(name='parameter_names')) + initial_values = property(partial(_get_property, property_name='initial_values'), + doc=_PROPERTY_DOC.format(name='initial_values')) + final_values = property(partial(_get_property, property_name='final_values'), + doc=_PROPERTY_DOC.format(name='final_values')) __hash__ = None diff --git a/qupulse/pulses/arithmetic_pulse_template.py b/qupulse/pulses/arithmetic_pulse_template.py index f9443da70..1d765f475 100644 --- a/qupulse/pulses/arithmetic_pulse_template.py +++ b/qupulse/pulses/arithmetic_pulse_template.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Set, Optional, Union, Mapping, FrozenSet, cast, Callable from numbers import Real import warnings +import operator import sympy @@ -17,15 +18,17 @@ IdentityTransformation -def _apply_operation_to_channel_dict(operator: str, - lhs: Mapping[ChannelID, Any], - rhs: Mapping[ChannelID, Any]) -> Dict[ChannelID, Any]: +def _apply_operation_to_channel_dict(lhs: Mapping[ChannelID, Any], + rhs: Mapping[ChannelID, Any], + operator_both: Optional[Callable[[Any, Any], Any]], + rhs_only: Optional[Callable[[Any], Any]] + ) -> Dict[ChannelID, Any]: result = dict(lhs) for channel, rhs_value in rhs.items(): if channel in result: - result[channel] = ArithmeticWaveform.operator_map[operator](result[channel], rhs_value) + result[channel] = operator_both(result[channel], rhs_value) else: - result[channel] = ArithmeticWaveform.rhs_only_map[operator](rhs_value) + result[channel] = rhs_only(rhs_value) return result @@ -105,14 +108,30 @@ def duration(self) -> ExpressionScalar: """Duration of the lhs operand if it is larger zero. Else duration of the rhs.""" return ExpressionScalar(sympy.Max(self.lhs.duration, self.rhs.duration)) + def _apply_operation(self, lhs: Mapping[str, Any], rhs: Mapping[str, Any]) -> Dict[str, Any]: + operator_both = ArithmeticWaveform.operator_map[self._arithmetic_operator] + rhs_only = ArithmeticWaveform.rhs_only_map[self._arithmetic_operator] + return _apply_operation_to_channel_dict(lhs, rhs, + operator_both=operator_both, + rhs_only=rhs_only) + @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: - return _apply_operation_to_channel_dict(self._arithmetic_operator, self.lhs.integral, self.rhs.integral) + # this is a guard for possible future changes + assert self._arithmetic_operator in ('+', '-'), \ + f"Integral not correctly implemented for '{self._arithmetic_operator}'" + return self._apply_operation(self.lhs.integral, self.rhs.integral) def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: - return _apply_operation_to_channel_dict(self._arithmetic_operator, - self.lhs._as_expression(), - self.rhs._as_expression()) + return self._apply_operation(self.lhs._as_expression(), self.rhs._as_expression()) + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._apply_operation(self.lhs.initial_values, self.rhs.initial_values) + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._apply_operation(self.lhs.final_values, self.rhs.final_values) def build_waveform(self, parameters: Dict[str, Real], @@ -125,7 +144,7 @@ def build_waveform(self, if lhs is None: return ArithmeticWaveform.rhs_only_map[self.arithmetic_operator](rhs) else: - return ArithmeticWaveform(lhs, self.arithmetic_operator, rhs) + return ArithmeticWaveform.from_operator(lhs, self.arithmetic_operator, rhs) def get_measurement_windows(self, parameters: Dict[str, Real], @@ -172,15 +191,21 @@ def deserialize(cls, serializer: Optional[Serializer] = None, **kwargs) -> 'Arit class ArithmeticPulseTemplate(PulseTemplate): - """""" - def __init__(self, lhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]], arithmetic_operator: str, rhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]], *, identifier: Optional[str] = None): - """ + """Allowed operations + + scalar + pulse_template + scalar - pulse_template + scalar * pulse_template + pulse_template + scalar + pulse_template - scalar + pulse_template * scalar + pulse_template / scalar Args: lhs: Left hand side operand @@ -356,7 +381,7 @@ def build_waveform(self, transformation = self._get_transformation(parameters=parameters, channel_mapping=channel_mapping) - return TransformingWaveform(inner_waveform, transformation=transformation) + return TransformingWaveform.from_transformation(inner_waveform, transformation=transformation) def __repr__(self): if any(v for k, v in super().get_serialization_data().items() if k != '#type'): @@ -384,47 +409,67 @@ def defined_channels(self): def duration(self) -> ExpressionScalar: return self._pulse_template.duration + def _scalar_as_dict(self) -> Dict[ChannelID, ExpressionScalar]: + if isinstance(self._scalar, ExpressionScalar): + return {channel: self._scalar + for channel in self.defined_channels} + else: + return dict(self._scalar) + @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: - integral = {channel: value for channel, value in self._pulse_template.integral.items()} + integral = {channel: value.sympified_expression for channel, value in self._pulse_template.integral.items()} + scalar = self._scalar_as_dict() - if isinstance(self._scalar, ExpressionScalar): - scalar = {channel: self._scalar - for channel in self.defined_channels} + if self._arithmetic_operator in ('+', '-'): + for ch, value in scalar.items(): + scalar[ch] = value * self.duration.sympified_expression + + return self._apply_operation_to_channel_dict(integral, scalar) + + def _apply_operation_to_channel_dict(self, + pt_values: Dict[ChannelID, ExpressionScalar], + scalar_values: Dict[ChannelID, ExpressionScalar]): + operator_map = { + '+': operator.add, + '-': operator.sub, + '/': operator.truediv, + '*': operator.mul + } + + rhs_only_map = { + '+': operator.pos, + '-': operator.neg, + '*': lambda x: x, + '/': lambda x: 1 / x + } + + if self._pulse_template is self.lhs: + lhs, rhs = pt_values, scalar_values else: - scalar = {channel: value - for channel, value in self._scalar.items()} + lhs, rhs = scalar_values, pt_values + # cannot divide by pulse templates + operator_map.pop('/') + rhs_only_map.pop('/') - if self._arithmetic_operator == '+': - for channel, value in scalar.items(): - integral[channel] = integral[channel] + (value * self.duration) + operator_both = operator_map.get(self._arithmetic_operator, None) + rhs_only = rhs_only_map.get(self._arithmetic_operator, None) - elif self._arithmetic_operator == '*': - for channel, value in scalar.items(): - integral[channel] = integral[channel] * value + return _apply_operation_to_channel_dict(lhs, rhs, operator_both=operator_both, rhs_only=rhs_only) - elif self._arithmetic_operator == '/': - assert self._pulse_template is self.lhs - for channel, value in scalar.items(): - integral[channel] = integral[channel] / value + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._apply_operation_to_channel_dict( + self._pulse_template.initial_values, + self._scalar_as_dict() + ) - else: - assert self._arithmetic_operator == '-' - if self._pulse_template is self.rhs: - # we need to negate all existing values - for channel, inner_value in integral.items(): - if channel in scalar: - integral[channel] = scalar[channel] * self.duration - inner_value - else: - integral[channel] = -inner_value - - else: - for channel, value in scalar.items(): - integral[channel] = integral[channel] - value * self.duration - - for channel, value in integral.items(): - integral[channel] = ExpressionScalar(value) - return integral + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._apply_operation_to_channel_dict( + self._pulse_template.final_values, + self._scalar_as_dict() + ) @property def measurement_names(self) -> Set[str]: diff --git a/qupulse/pulses/constant_pulse_template.py b/qupulse/pulses/constant_pulse_template.py index f14afd61c..43c22030f 100644 --- a/qupulse/pulses/constant_pulse_template.py +++ b/qupulse/pulses/constant_pulse_template.py @@ -8,48 +8,47 @@ import logging import numbers -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Union, Mapping, AbstractSet from qupulse.utils import cached_property from qupulse._program import ProgramBuilder from qupulse._program.waveforms import ConstantWaveform -from qupulse.expressions import ExpressionScalar -from qupulse.parameter_scope import Scope +from qupulse.utils.types import TimeType, ChannelID +from qupulse.utils import cached_property +from qupulse.expressions import ExpressionScalar, ExpressionLike from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform -from qupulse.pulses.parameters import ParameterNotProvidedException -from qupulse.pulses.pulse_template import (AtomicPulseTemplate, ChannelID, - Loop, MeasurementDeclaration, - PulseTemplate, Transformation, - TransformingWaveform) +from qupulse.pulses.pulse_template import AtomicPulseTemplate, MeasurementDeclaration from qupulse.serialization import PulseRegistryType __all__ = ["ConstantPulseTemplate"] class ConstantPulseTemplate(AtomicPulseTemplate): # type: ignore - - def __init__(self, duration: float, amplitude_dict: Dict[str, Any], identifier: Optional[str] = None, - name: Optional[str] = None, measurements: Optional[List[MeasurementDeclaration]] = (), - registry: PulseRegistryType = None, - **kwargs: Any) -> None: - """ A qupulse waveform representing a multi-channel pulse with constant values + def __init__(self, duration: ExpressionLike, amplitude_dict: Dict[ChannelID, ExpressionLike], + identifier: Optional[str] = None, + name: Optional[str] = None, + measurements: Optional[List[MeasurementDeclaration]] = None, + registry: PulseRegistryType=None) -> None: + """An atomic pulse template qupulse representing a multi-channel pulse with constant values. Args: duration: Duration of the template amplitude_dict: Dictionary with values for the channels - name: Name for the template - + name: Name for the template. Not used by qupulse """ - super().__init__(identifier=identifier, measurements=measurements, **kwargs) + super().__init__(identifier=identifier, measurements=measurements) - self._duration = ExpressionScalar(duration) - self._amplitude_dict = {channel: ExpressionScalar(value) for channel, value in amplitude_dict.items()} + # we special case numeric values in this PulseTemplate for performance reasons + self._duration = duration if isinstance(duration, (float, int, TimeType)) else ExpressionScalar(duration) + self._amplitude_dict: Mapping[ChannelID, Union[float, ExpressionScalar]] = { + channel: value if isinstance(value, (float, int)) else ExpressionScalar(value) + for channel, value in amplitude_dict.items()} if name is None: name = 'constant_pulse' self._name = name - self._register(registry=registry) + self._register(registry) def _as_expression(self): return self._amplitude_dict @@ -57,45 +56,56 @@ def _as_expression(self): def __str__(self) -> str: return '<{} at %x{}: {}>'.format(self.__class__.__name__, '%x' % id(self), self._name) - def build_sequence(self) -> None: - return - - def get_serialization_data(self) -> Any: + def get_serialization_data(self, serializer=None) -> Any: + if serializer is not None: + raise NotImplementedError("ConstantPulseTemplate does not implement legacy serialization.") data = super().get_serialization_data() - data.update({'name': self._name, 'duration': self._duration, 'amplitude_dict': self._amplitude_dict}) + data.update({ + 'name': self._name, + 'duration': self._duration, + 'amplitude_dict': self._amplitude_dict, + 'measurements': self.measurement_declarations + }) return data + @classmethod + def deserialize(cls, serializer: Optional = None, **kwargs) -> 'ConstantPulseTemplate': + assert serializer is None, f"{cls} does not support legacy deserialization" + # this is for backwards compatible deserialization. + amplitudes = kwargs.pop('#amplitudes', None) + if amplitudes is not None: + kwargs['amplitude_dict'] = amplitudes + return cls(**kwargs) + @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: """Returns an expression giving the integral over the pulse.""" return {c: self.duration * self._amplitude_dict[c] for c in self._amplitude_dict} @cached_property - def parameter_names(self) -> Set[str]: + def parameter_names(self) -> AbstractSet[str]: """The set of names of parameters required to instantiate this PulseTemplate.""" - parameter_names = set(getattr(self._duration, 'variables', ())) - for value in self._amplitude_dict.values(): - parameter_names.update(getattr(value, 'variables', ())) - return parameter_names + parameters = [] + for amplitude in self._amplitude_dict.values(): + if hasattr(amplitude, 'variables'): + parameters.extend(amplitude.variables) + if hasattr(self._duration, 'variables'): + parameters.extend(self._duration.variables) + parameters.extend(self.measurement_parameters) + return frozenset(parameters) - @property - def is_interruptable(self) -> bool: - """Return true, if this PulseTemplate contains points at which it can halt if interrupted. - """ - return False - - @property + @cached_property def duration(self) -> ExpressionScalar: """An expression for the duration of this PulseTemplate.""" - return self._duration + if isinstance(self._duration, ExpressionScalar): + return self._duration + else: + return ExpressionScalar(self._duration) @property - def defined_channels(self) -> Set['ChannelID']: + def defined_channels(self) -> AbstractSet['ChannelID']: """Returns the number of hardware output channels this PulseTemplate defines.""" - return set(self._amplitude_dict.keys()) - - def requires_stop(self) -> bool: # from SequencingElement - return False + return set(self._amplitude_dict) def build_waveform(self, parameters: Dict[str, numbers.Real], @@ -104,16 +114,28 @@ def build_waveform(self, logging.debug(f'build_waveform of ConstantPulse: channel_mapping {channel_mapping}, ' f'defined_channels {self.defined_channels}') - constant_values = {} - for channel, value in self._amplitude_dict.items(): - mapped_channel = channel_mapping[channel] - if mapped_channel is not None: - evaluator = getattr(value, 'evaluate_in_scope', None) - if evaluator: - value = evaluator(parameters) - constant_values[mapped_channel] = value - - if constant_values: - return ConstantWaveform.from_mapping(self.duration.evaluate_in_scope(parameters), constant_values) - else: - return None + # we very freely use duck-typing here to speed up cases where duration and amplitude values are already numeric + duration = self._duration + if hasattr(duration, 'evaluate_in_scope'): + duration = duration.evaluate_in_scope(parameters) + + if duration > 0: + constant_values = {} + for channel, value in self._amplitude_dict.items(): + mapped_channel = channel_mapping[channel] + if mapped_channel is not None: + if hasattr(value, 'evaluate_in_scope'): + value = value.evaluate_in_scope(parameters) + constant_values[mapped_channel] = value + + if constant_values: + return ConstantWaveform.from_mapping(duration, constant_values) + return None + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return {ch: ExpressionScalar(val) for ch, val in self._amplitude_dict.items()} + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return {ch: ExpressionScalar(val) for ch, val in self._amplitude_dict.items()} diff --git a/qupulse/pulses/function_pulse_template.py b/qupulse/pulses/function_pulse_template.py index 1913c8697..78e555776 100644 --- a/qupulse/pulses/function_pulse_template.py +++ b/qupulse/pulses/function_pulse_template.py @@ -105,7 +105,7 @@ def build_waveform(self, expression = self.__expression.evaluate_symbolic(substitutions=parameters) duration = self.__duration_expression.evaluate_with_exact_rationals(parameters) - return FunctionWaveform(expression=expression, + return FunctionWaveform.from_expression(expression=expression, duration=duration, channel=channel_mapping[self.__channel]) @@ -151,4 +151,14 @@ def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: expr = ExpressionScalar.make(self.__expression.underlying_expression.subs({'t': self._AS_EXPRESSION_TIME})) return {self.__channel: expr} + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + expr = ExpressionScalar.make(self.__expression.underlying_expression.subs('t', 0)) + return {self.__channel: expr} + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + expr = ExpressionScalar.make(self.__expression.underlying_expression.subs('t', self.__duration_expression.underlying_expression)) + return {self.__channel: expr} + diff --git a/qupulse/pulses/loop_pulse_template.py b/qupulse/pulses/loop_pulse_template.py index 1bf534ae6..b0ea8af4a 100644 --- a/qupulse/pulses/loop_pulse_template.py +++ b/qupulse/pulses/loop_pulse_template.py @@ -21,6 +21,7 @@ from qupulse.pulses.pulse_template import PulseTemplate, ChannelID, AtomicPulseTemplate from qupulse._program.waveforms import SequenceWaveform as ForLoopWaveform from qupulse.pulses.measurement import MeasurementDefiner, MeasurementDeclaration +from qupulse.pulses.range import ParametrizedRange, RangeScope __all__ = ['ForLoopPulseTemplate', 'LoopPulseTemplate', 'LoopIndexNotUsedException'] @@ -45,54 +46,6 @@ def measurement_names(self) -> Set[str]: return self.__body.measurement_names -class ParametrizedRange: - """Like the builtin python range but with parameters.""" - def __init__(self, *args, **kwargs): - """Positional and keyword arguments cannot be mixed. - - Args: - *args: Interpreted as ``(start, )`` or ``(start, stop[, step])`` - **kwargs: Expected to contain ``start``, ``stop`` and ``step`` - Raises: - TypeError: If positional and keyword arguments are mixed - KeyError: If keyword arguments but one of ``start``, ``stop`` or ``step`` is missing - """ - if args and kwargs: - raise TypeError('ParametrizedRange only takes either positional or keyword arguments') - elif kwargs: - start = kwargs['start'] - stop = kwargs['stop'] - step = kwargs['step'] - elif len(args) in (1, 2, 3): - if len(args) == 3: - start, stop, step = args - elif len(args) == 2: - (start, stop), step = args, 1 - elif len(args) == 1: - start, (stop,), step = 0, args, 1 - else: - raise TypeError('ParametrizedRange expected 1 to 3 arguments, got {}'.format(len(args))) - - self.start = ExpressionScalar.make(start) - self.stop = ExpressionScalar.make(stop) - self.step = ExpressionScalar.make(step) - - def to_tuple(self) -> Tuple[Any, Any, Any]: - """Return a simple representation of the range which is useful for comparison and serialization""" - return (self.start.get_serialization_data(), - self.stop.get_serialization_data(), - self.step.get_serialization_data()) - - def to_range(self, parameters: Mapping[str, Number]) -> range: - return range(checked_int_cast(self.start.evaluate_in_scope(parameters)), - checked_int_cast(self.stop.evaluate_in_scope(parameters)), - checked_int_cast(self.step.evaluate_in_scope(parameters))) - - @property - def parameter_names(self) -> Set[str]: - return set(self.start.variables) | set(self.stop.variables) | set(self.step.variables) - - class ForLoopPulseTemplate(LoopPulseTemplate, MeasurementDefiner, ParameterConstrainer): """This pulse template allows looping through an parametrized integer range and provides the loop index as a parameter to the body. If you do not need the index in the pulse template, consider using @@ -122,18 +75,7 @@ def __init__(self, MeasurementDefiner.__init__(self, measurements=measurements) ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints) - if isinstance(loop_range, ParametrizedRange): - self._loop_range = loop_range - elif isinstance(loop_range, (int, str)): - self._loop_range = ParametrizedRange(loop_range) - elif isinstance(loop_range, (tuple, list)): - self._loop_range = ParametrizedRange(*loop_range) - elif isinstance(loop_range, range): - self._loop_range = ParametrizedRange(start=loop_range.start, - stop=loop_range.stop, - step=loop_range.step) - else: - raise ValueError('loop_range is not valid') + self._loop_range = ParametrizedRange.from_range_like(loop_range) if not loop_index.isidentifier(): raise InvalidParameterNameException(loop_index) @@ -198,15 +140,8 @@ def _body_scope_generator(self, scope: Scope, forward=True) -> Iterator[Scope]: loop_range = loop_range if forward else reversed(loop_range) loop_index_name = self._loop_index - get_for_loop_scope = _get_for_loop_scope - for loop_index_value in loop_range: - try: - yield get_for_loop_scope(scope, loop_index_name, loop_index_value) - except TypeError: - # we cannot hash the scope so we will not try anymore - get_for_loop_scope = _ForLoopScope - yield get_for_loop_scope(scope, loop_index_name, loop_index_value) + yield _ForLoopScope(scope, loop_index_name, loop_index_value) def _internal_create_program(self, *, scope: Scope, @@ -281,6 +216,24 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: return body_integrals + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + values = self.body.initial_values + initial_idx = self._loop_range.start + for ch, value in values.items(): + values[ch] = ExpressionScalar(value.underlying_expression.subs(self._loop_index, initial_idx)) + return values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + values = self.body.final_values + start, step, stop = self._loop_range.start.sympified_expression, self._loop_range.step.sympified_expression, self._loop_range.stop.sympified_expression + n = (stop - start) // step + final_idx = start + sympy.Max(n - 1, 0) * step + for ch, value in values.items(): + values[ch] = ExpressionScalar(value.underlying_expression.subs(self._loop_index, final_idx)) + return values + class LoopIndexNotUsedException(Exception): def __init__(self, loop_index: str, body_parameter_names: Set[str]): @@ -292,78 +245,4 @@ def __str__(self) -> str: self.body_parameter_names) -class _ForLoopScope(Scope): - __slots__ = ('_index_name', '_index_value', '_inner') - - def __init__(self, inner: Scope, index_name: str, index_value: int): - super().__init__() - self._inner = inner - self._index_name = index_name - self._index_value = index_value - - def get_volatile_parameters(self) -> FrozenMapping[str, Expression]: - inner_volatile = self._inner.get_volatile_parameters() - - if self._index_name in inner_volatile: - # TODO: use delete method of frozendict - index_name = self._index_name - return FrozenDict((name, value) for name, value in inner_volatile.items() if name != index_name) - else: - return inner_volatile - - def __hash__(self): - return hash((self._inner, self._index_name, self._index_value)) - - def __eq__(self, other: '_ForLoopScope'): - try: - return (self._index_name == other._index_name - and self._index_value == other._index_value - and self._inner == other._inner) - except AttributeError: - return False - - def __contains__(self, item): - return item == self._index_name or item in self._inner - - def get_parameter(self, parameter_name: str) -> Number: - if parameter_name == self._index_name: - return self._index_value - else: - return self._inner.get_parameter(parameter_name) - - __getitem__ = get_parameter - - def change_constants(self, new_constants: Mapping[str, Number]) -> 'Scope': - return _get_for_loop_scope(self._inner.change_constants(new_constants), self._index_name, self._index_value) - - def __len__(self) -> int: - return len(self._inner) + int(self._index_name not in self._inner) - - def __iter__(self) -> Iterator: - if self._index_name in self._inner: - return iter(self._inner) - else: - return itertools.chain(self._inner, (self._index_name,)) - - def as_dict(self) -> FrozenMapping[str, Number]: - if self._as_dict is None: - self._as_dict = FrozenDict({**self._inner.as_dict(), self._index_name: self._index_value}) - return self._as_dict - - def keys(self): - return self.as_dict().keys() - - def items(self): - return self.as_dict().items() - - def values(self): - return self.as_dict().values() - - def __repr__(self): - return f'{type(self)}(inner={self._inner!r}, index_name={self._index_name!r}, ' \ - f'index_value={self._index_value!r})' - - -@functools.lru_cache(maxsize=10**6) -def _get_for_loop_scope(inner: Scope, index_name: str, index_value: int) -> Scope: - return _ForLoopScope(inner, index_name, index_value) +_ForLoopScope = RangeScope diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index ab415c3db..5a1a4316c 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -344,38 +344,31 @@ def get_measurement_windows(self, measurement_mapping=self.get_updated_measurement_mapping(measurement_mapping=measurement_mapping) ) - @property - def integral(self) -> Dict[ChannelID, ExpressionScalar]: - internal_integral = self.__template.integral - expressions = dict() - - # sympy.subs() does not work if one of the mappings in the provided dict is an Expression object - # the following is an ugly workaround - # todo: make Expressions compatible with sympy.subs() - parameter_mapping = {parameter_name: expression.underlying_expression - for parameter_name, expression in self.__parameter_mapping.items()} - for channel, ch_integral in internal_integral.items(): - channel_out = self.__channel_mapping.get(channel, channel) - if channel_out is None: - continue - - expressions[channel_out] = ExpressionScalar( - ch_integral.sympified_expression.subs(parameter_mapping, simultaneous=True) - ) - - return expressions - - def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + def _apply_mapping_to_inner_channel_dict(self, to_map: Dict[ChannelID, ExpressionScalar]) -> Dict[ChannelID, ExpressionScalar]: parameter_mapping = {parameter_name: expression.underlying_expression for parameter_name, expression in self.__parameter_mapping.items()} - inner = self.__template._as_expression() return { self.__channel_mapping.get(ch, ch): ExpressionScalar(ch_expr.sympified_expression.subs(parameter_mapping, simultaneous=True)) - for ch, ch_expr in inner.items() + for ch, ch_expr in to_map.items() if self.__channel_mapping.get(ch, ch) is not None } + @property + def integral(self) -> Dict[ChannelID, ExpressionScalar]: + return self._apply_mapping_to_inner_channel_dict(self.__template.integral) + + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + return self._apply_mapping_to_inner_channel_dict(self.__template._as_expression()) + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._apply_mapping_to_inner_channel_dict(self.__template.initial_values) + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._apply_mapping_to_inner_channel_dict(self.__template.final_values) + class MissingMappingException(Exception): """Indicates that no mapping was specified for some parameter declaration of a diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index b6c220b05..29af7168c 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -134,7 +134,7 @@ def build_waveform(self, parameters: Dict[str, numbers.Real], if len(sub_waveforms) == 1: waveform = sub_waveforms[0] else: - waveform = MultiChannelWaveform(sub_waveforms) + waveform = MultiChannelWaveform.from_parallel(sub_waveforms) if self._duration: expected_duration = self._duration.evaluate_numeric(**parameters) @@ -194,6 +194,20 @@ def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: expressions.update(subtemplate._as_expression()) return expressions + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + values = {} + for subtemplate in self._subtemplates: + values.update(subtemplate.initial_values) + return values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + values = {} + for subtemplate in self._subtemplates: + values.update(subtemplate.final_values) + return values + class ParallelConstantChannelPulseTemplate(PulseTemplate): def __init__(self, @@ -249,7 +263,7 @@ def build_waveform(self, parameters: Dict[str, numbers.Real], overwritten_channels = self._get_overwritten_channels_values(parameters=parameters, channel_mapping=channel_mapping) transformation = ParallelConstantChannelTransformation(overwritten_channels) - return TransformingWaveform(inner_waveform, transformation) + return TransformingWaveform.from_transformation(inner_waveform, transformation) @property def defined_channels(self) -> Set[ChannelID]: @@ -280,6 +294,18 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: integral[channel] = value * duration return integral + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + values = self._template.initial_values + values.update(self._overwritten_channels) + return values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + values = self._template.final_values + values.update(self._overwritten_channels) + return values + def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: if serializer: raise NotImplementedError('Legacy serialization not implemented for new class') diff --git a/qupulse/pulses/plotting.py b/qupulse/pulses/plotting.py index cc8f4fbf5..a021b417b 100644 --- a/qupulse/pulses/plotting.py +++ b/qupulse/pulses/plotting.py @@ -122,7 +122,7 @@ def _render_loop(loop: Loop, def plot(pulse: PulseTemplate, parameters: Dict[str, Parameter]=None, - sample_rate: Real=10, + sample_rate: Optional[Real]=10, axes: Any=None, show: bool=True, plot_channels: Optional[Set[ChannelID]]=None, @@ -142,7 +142,7 @@ def plot(pulse: PulseTemplate, parameters: An optional mapping of parameter names to Parameter objects. sample_rate: The rate with which the waveforms are sampled for the plot in - samples per time unit. (default = 10) + samples per time unit. If None, then automatically determine the sample rate (default = 10) axes: matplotlib Axes object the pulse will be drawn into if provided show: If true, the figure will be shown plot_channels: If specified only channels from this set will be plotted. If omitted all channels will be. @@ -165,6 +165,17 @@ def plot(pulse: PulseTemplate, if parameters is None: parameters = dict() + if sample_rate is None: + if time_slice is None: + duration = pulse.duration + else: + duration = time_slice[1]-time_slice[0] + if duration == 0: + sample_rate = 1 + else: + duration_per_sample = float(duration) / 1000 + sample_rate = 1 / duration_per_sample + program = pulse.create_program(parameters=parameters, channel_mapping={ch: ch for ch in channels}, measurement_mapping={w: w for w in pulse.measurement_names}) @@ -241,7 +252,10 @@ def plot(pulse: PulseTemplate, axes.set_title(pulse.identifier) if show: - axes.get_figure().show() + with warnings.catch_warnings(): + # do not show warnings in jupyter notebook with matplotlib inline backend + warnings.filterwarnings(action="ignore",message=".*which is a non-GUI backend, so cannot show the figure.*") + axes.get_figure().show() return axes.get_figure() diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 376e28be4..b5c5cfa03 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -168,6 +168,22 @@ def value_trafo(v): expressions[channel] = pw return expressions + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + shape = (len(self._channels),) + return { + ch: ExpressionScalar(IndexedBroadcast(self._entries[0].v, shape, ch_idx)) + for ch_idx, ch in enumerate(self._channels) + } + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + shape = (len(self._channels),) + return { + ch: ExpressionScalar(IndexedBroadcast(self._entries[-1].v, shape, ch_idx)) + for ch_idx, ch in enumerate(self._channels) + } + class InvalidPointDimension(Exception): def __init__(self, expected, received): diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index a7da2ce9c..740dd63fc 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -97,6 +97,16 @@ def __rmatmul__(self, other: MappingTuple) -> 'SequencePulseTemplate': def integral(self) -> Dict[ChannelID, ExpressionScalar]: """Returns an expression giving the integral over the pulse.""" + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + """Values of defined channels at t == 0""" + raise NotImplementedError(f"The pulse template of type {type(self)} does not implement `initial_values`") + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + """Values of defined channels at t == self.duration""" + raise NotImplementedError(f"The pulse template of type {type(self)} does not implement `final_values`") + def create_program(self, *, parameters: Optional[Mapping[str, Union[Expression, str, Number, ConstantParameter]]]=None, measurement_mapping: Optional[Mapping[str, Optional[str]]]=None, @@ -218,7 +228,7 @@ def _create_program(self, *, waveform = to_waveform(program) if global_transformation: - waveform = TransformingWaveform(waveform, global_transformation) + waveform = TransformingWaveform.from_transformation(waveform, global_transformation) # convert the nicely formatted measurement windows back into the old format again :( measurements = program.get_measurement_windows() @@ -333,7 +343,7 @@ def _internal_create_program(self, *, measurement_mapping=measurement_mapping) if global_transformation: - waveform = TransformingWaveform(waveform, global_transformation) + waveform = TransformingWaveform.from_transformation(waveform, global_transformation) parent_loop.append_leaf(waveform=waveform, measurements=measurements) @@ -358,7 +368,29 @@ def build_waveform(self, def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: """Helper function to allow integral calculation in case of truncation. AtomicPulseTemplate._AS_EXPRESSION_TIME is by convention the time variable.""" - raise NotImplementedError(f"_as_expression is not implemented for {type(self)} which means it cannot be truncated and integrated over.") + raise NotImplementedError(f"_as_expression is not implemented for {type(self)} " + f"which means it cannot be truncated and integrated over.") + + @property + def integral(self) -> Dict[ChannelID, ExpressionScalar]: + # this default implementation uses _as_expression + return {ch: ExpressionScalar(sympy.integrate(expr.sympified_expression, + (self._AS_EXPRESSION_TIME, 0, self.duration.sympified_expression))) + for ch, expr in self._as_expression().items()} + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + values = self._as_expression() + for ch, value in values.items(): + values[ch] = value.evaluate_symbolic({self._AS_EXPRESSION_TIME: 0}) + return values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + values = self._as_expression() + for ch, value in values.items(): + values[ch] = value.evaluate_symbolic({self._AS_EXPRESSION_TIME: self.duration}) + return values class DoubleParameterNameException(Exception): diff --git a/qupulse/pulses/range.py b/qupulse/pulses/range.py new file mode 100644 index 000000000..34f7e8a8e --- /dev/null +++ b/qupulse/pulses/range.py @@ -0,0 +1,156 @@ +from typing import Tuple, Any, AbstractSet, Mapping, Union, Iterator +from numbers import Number +from dataclasses import dataclass +from functools import lru_cache +import itertools + +from qupulse.utils import checked_int_cast, cached_property +from qupulse.expressions import ExpressionScalar, ExpressionVariableMissingException, ExpressionLike, Expression +from qupulse.parameter_scope import Scope +from qupulse.utils.types import FrozenDict, FrozenMapping + +RangeLike = Union[range, + ExpressionLike, + Tuple[ExpressionLike, ExpressionLike], + Tuple[ExpressionLike, ExpressionLike, ExpressionLike]] + + +@dataclass(frozen=True) +class ParametrizedRange: + start: ExpressionScalar + stop: ExpressionScalar + step: ExpressionScalar + + def __init__(self, *args, **kwargs): + """Like the builtin python range but with parameters. Positional and keyword arguments cannot be mixed. + + Args: + *args: Interpreted as ``(start, )`` or ``(start, stop[, step])`` + **kwargs: Expected to contain ``start``, ``stop`` and ``step`` + Raises: + TypeError: If positional and keyword arguments are mixed + KeyError: If keyword arguments but one of ``start``, ``stop`` or ``step`` is missing + """ + if args and kwargs: + raise TypeError('ParametrizedRange only takes either positional or keyword arguments') + elif kwargs: + start = kwargs['start'] + stop = kwargs['stop'] + step = kwargs['step'] + elif len(args) in (1, 2, 3): + if len(args) == 3: + start, stop, step = args + elif len(args) == 2: + (start, stop), step = args, 1 + else: + start, (stop,), step = 0, args, 1 + else: + raise TypeError('ParametrizedRange expected 1 to 3 arguments, got {}'.format(len(args)), args) + + object.__setattr__(self, 'start', ExpressionScalar(start)) + object.__setattr__(self, 'stop', ExpressionScalar(stop)) + object.__setattr__(self, 'step', ExpressionScalar(step)) + + @lru_cache(maxsize=1024) + def to_tuple(self) -> Tuple[Any, Any, Any]: + """Return a simple representation of the range which is useful for comparison and serialization""" + return (self.start.get_serialization_data(), + self.stop.get_serialization_data(), + self.step.get_serialization_data()) + + def to_range(self, parameters: Mapping[str, Number]) -> range: + return range(checked_int_cast(self.start.evaluate_in_scope(parameters)), + checked_int_cast(self.stop.evaluate_in_scope(parameters)), + checked_int_cast(self.step.evaluate_in_scope(parameters))) + + @cached_property + def parameter_names(self) -> AbstractSet[str]: + return set(self.start.variables) | set(self.stop.variables) | set(self.step.variables) + + @classmethod + def from_range_like(cls, range_like: RangeLike): + if isinstance(range_like, cls): + return range_like + elif isinstance(range_like, (tuple, list)): + return cls(*range_like) + elif isinstance(range_like, range): + return cls(range_like.start, range_like.stop, range_like.step) + elif isinstance(range_like, slice): + raise TypeError("Cannot construct a range from a slice") + else: + return cls(range_like) + + def get_serialization_data(self): + return self.to_tuple() + + +class RangeScope(Scope): + __slots__ = ('_index_name', '_index_value', '_inner') + + def __init__(self, inner: Scope, index_name: str, index_value: int): + super().__init__() + self._inner = inner + self._index_name = index_name + self._index_value = index_value + + def get_volatile_parameters(self) -> FrozenMapping[str, Expression]: + inner_volatile = self._inner.get_volatile_parameters() + + if self._index_name in inner_volatile: + # TODO: use delete method of frozendict + index_name = self._index_name + return FrozenDict((name, value) for name, value in inner_volatile.items() if name != index_name) + else: + return inner_volatile + + def __hash__(self): + return hash((self._inner, self._index_name, self._index_value)) + + def __eq__(self, other: 'RangeScope'): + try: + return (self._index_name == other._index_name + and self._index_value == other._index_value + and self._inner == other._inner) + except AttributeError: + return NotImplemented + + def __contains__(self, item): + return item == self._index_name or item in self._inner + + def get_parameter(self, parameter_name: str) -> Number: + if parameter_name == self._index_name: + return self._index_value + else: + return self._inner.get_parameter(parameter_name) + + __getitem__ = get_parameter + + def change_constants(self, new_constants: Mapping[str, Number]) -> 'Scope': + return RangeScope(self._inner.change_constants(new_constants), self._index_name, self._index_value) + + def __len__(self) -> int: + return len(self._inner) + int(self._index_name not in self._inner) + + def __iter__(self) -> Iterator: + if self._index_name in self._inner: + return iter(self._inner) + else: + return itertools.chain(self._inner, (self._index_name,)) + + def as_dict(self) -> FrozenMapping[str, Number]: + if self._as_dict is None: + self._as_dict = FrozenDict({**self._inner.as_dict(), self._index_name: self._index_value}) + return self._as_dict + + def keys(self): + return self.as_dict().keys() + + def items(self): + return self.as_dict().items() + + def values(self): + return self.as_dict().values() + + def __repr__(self): + return f'{type(self)}(inner={self._inner!r}, index_name={self._index_name!r}, ' \ + f'index_value={self._index_value!r})' diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index 0ec2c9773..9e8ad9ec6 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -159,6 +159,14 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: body_integral = self.body.integral return {channel: self.repetition_count * value for channel, value in body_integral.items()} + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self.body.initial_values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self.body.final_values + class ParameterNotIntegerException(Exception): """Indicates that the value of the parameter given as repetition count was not an integer.""" diff --git a/qupulse/pulses/sequence_pulse_template.py b/qupulse/pulses/sequence_pulse_template.py index 2575a2c62..fca353d85 100644 --- a/qupulse/pulses/sequence_pulse_template.py +++ b/qupulse/pulses/sequence_pulse_template.py @@ -123,9 +123,9 @@ def build_waveform(self, parameters: Dict[str, Real], channel_mapping: Dict[ChannelID, ChannelID]) -> SequenceWaveform: self.validate_parameter_constraints(parameters=parameters, volatile=set()) - return SequenceWaveform([sub_template.build_waveform(parameters, - channel_mapping=channel_mapping) - for sub_template in self.__subtemplates]) + return SequenceWaveform.from_sequence( + [sub_template.build_waveform(parameters, channel_mapping=channel_mapping) + for sub_template in self.__subtemplates]) def _internal_create_program(self, *, scope: Scope, @@ -184,3 +184,12 @@ def add_dicts(x, y): return {k: x[k] + y[k] for k in x} return functools.reduce(add_dicts, [sub.integral for sub in self.__subtemplates], expressions) + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self.__subtemplates[0].initial_values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self.__subtemplates[-1].final_values + diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index 4decd2576..f95730ded 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -437,6 +437,16 @@ def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: post_value=post_value) return expressions + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return {ch: entries[0].v + for ch, entries in self._entries.items()} + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return {ch: entries[-1].v + for ch, entries in self._entries.items()} + def concatenate(*table_pulse_templates: TablePulseTemplate, **kwargs) -> TablePulseTemplate: """Concatenate two or more table pulse templates""" diff --git a/qupulse/utils/performance.py b/qupulse/utils/performance.py new file mode 100644 index 000000000..4076b664c --- /dev/null +++ b/qupulse/utils/performance.py @@ -0,0 +1,84 @@ +from typing import Tuple +import numpy as np + +try: + import numba + njit = numba.njit(cache=True) +except ImportError: + numba = None + njit = lambda x: x + + +@njit +def _is_monotonic_numba(x: np.ndarray) -> bool: + # No early return because we optimize for the monotonic case and are branch-free. + monotonic = True + for i in range(1, len(x)): + monotonic &= x[i - 1] <= x[i] + return monotonic + + +def _is_monotonic_numpy(arr: np.ndarray) -> bool: + # A bit faster than np.all(np.diff(arr) > 0) for small arrays + # No difference for big arrays + return np.all(arr[1:] >= arr[:-1]) + + +@njit +def _time_windows_to_samples_sorted_numba(begins, lengths, + sample_rate: float) -> Tuple[np.ndarray, np.ndarray]: + begins_as_sample = np.zeros(len(begins), dtype=np.uint64) + lengths_as_sample = np.zeros(len(lengths), dtype=np.uint64) + for idx in range(len(begins)): + begins_as_sample[idx] = round(begins[idx] * sample_rate) + lengths_as_sample[idx] = np.uint64(lengths[idx] * sample_rate) + return begins_as_sample, lengths_as_sample + + +@njit +def _time_windows_to_samples_numba(begins, lengths, + sample_rate: float) -> Tuple[np.ndarray, np.ndarray]: + if _is_monotonic_numba(begins): + # factor 10 faster + begins_as_sample, lengths_as_sample = _time_windows_to_samples_sorted_numba(begins, lengths, sample_rate) + else: + sorting_indices = np.argsort(begins) + + begins_as_sample = np.zeros(len(begins), dtype=np.uint64) + lengths_as_sample = np.zeros(len(lengths), dtype=np.uint64) + for new_pos, old_pos in enumerate(sorting_indices): + begins_as_sample[new_pos] = round(begins[old_pos] * sample_rate) + lengths_as_sample[new_pos] = np.uint64(lengths[old_pos] * sample_rate) + return begins_as_sample, lengths_as_sample + + +def _time_windows_to_samples_numpy(begins: np.ndarray, lengths: np.ndarray, + sample_rate: float) -> Tuple[np.ndarray, np.ndarray]: + sorting_indices = np.argsort(begins) + begins = np.rint(begins * sample_rate).astype(dtype=np.uint64) + lengths = np.floor(lengths * sample_rate).astype(dtype=np.uint64) + + begins = begins[sorting_indices] + lengths = lengths[sorting_indices] + return begins, lengths + + +def time_windows_to_samples(begins: np.ndarray, lengths: np.ndarray, + sample_rate: float) -> Tuple[np.ndarray, np.ndarray]: + """""" + if numba is None: + begins, lengths = _time_windows_to_samples_numpy(begins, lengths, sample_rate) + else: + begins, lengths = _time_windows_to_samples_numba(begins, lengths, sample_rate) + begins.flags.writeable = False + lengths.flags.writeable = False + return begins, lengths + + +if numba is None: + is_monotonic = _is_monotonic_numpy +else: + is_monotonic = _is_monotonic_numba + + + diff --git a/qupulse/utils/types.py b/qupulse/utils/types.py index aac57c143..f59b497bd 100644 --- a/qupulse/utils/types.py +++ b/qupulse/utils/types.py @@ -246,7 +246,7 @@ def __eq__(self, other): @classmethod def as_comparable(cls, other: typing.Union['TimeType', typing.Any]): - if type(other) == cls: + if type(other) is cls: return other._value else: return other @@ -404,9 +404,13 @@ def __hash__(self): return hash(self.tobytes()) +@functools.lru_cache(maxsize=128) +def _public_type_attributes(type_obj): + return {attr for attr in dir(type_obj) if not attr.startswith('_')} + def has_type_interface(obj: typing.Any, type_obj: typing.Type) -> bool: - """Return true if all public attributes of the class are attribues of the object""" - return set(dir(obj)) >= {attr for attr in dir(type_obj) if not attr.startswith('_')} + """Return true if all public attributes of the class are attributes of the object""" + return set(dir(obj)) >= _public_type_attributes(type_obj) _KT_hash = typing.TypeVar('_KT_hash', bound=typing.Hashable) # Key type. diff --git a/setup.cfg b/setup.cfg index bfec25460..c430620bc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,6 +49,7 @@ autologging = autologging # sadly not open source for external legal reasons # commented out because pypi does not allow direct dependencies # atsaverage = atsaverage @ git+ssh://git@git.rwth-aachen.de/qutech/cpp-atsaverage.git@master#egg=atsaverage&subdirectory=python_source +faster-sampling = numba [options.packages.find] include = diff --git a/tests/_program/loop_tests.py b/tests/_program/loop_tests.py index 2f674a743..d11df00c8 100644 --- a/tests/_program/loop_tests.py +++ b/tests/_program/loop_tests.py @@ -120,6 +120,28 @@ def test_compare_key(self): self.assertNotEqual(tree1, tree4) self.assertEqual(tree1, tree5) + def test_get_measurement_windows(self): + wf_1 = ConstantWaveform(channel='A', duration=32, amplitude=1) + wf_2 = ConstantWaveform(channel='A', duration=10, amplitude=2) + + prog = Loop(children=[ + Loop(waveform=wf_1, measurements=[('x', 0, 16)], repetition_count=3), + Loop(waveform=wf_2, measurements=[('y', 5, 5)], repetition_count=5), + ], repetition_count=2) + + expected_measurements = { + 'x': (np.array([0, 32, 64, 146, 178, 210]), np.array([16]*6)), + 'y': (np.array([101, 111, 121, 131, 141, 247, 257, 267, 277, 287]), np.array([5]*10)) + } + measurements_no_drop = prog.get_measurement_windows() + np.testing.assert_equal(expected_measurements, measurements_no_drop) + + measurements_drop = prog.get_measurement_windows(drop=True) + np.testing.assert_equal(expected_measurements, measurements_drop) + + # no measurements left + self.assertEqual({}, prog.get_measurement_windows()) + def test_repr(self): tree = self.get_test_loop() self.assertEqual(tree, eval(repr(tree))) diff --git a/tests/_program/seqc_tests.py b/tests/_program/seqc_tests.py index 3d2f00c24..217ecff9b 100644 --- a/tests/_program/seqc_tests.py +++ b/tests/_program/seqc_tests.py @@ -33,6 +33,10 @@ def take(n, iterable): return list(islice(iterable, n)) +def dummy_loop_to_seqc(loop, **kwargs): + return loop + + class BinaryWaveformTest(unittest.TestCase): MAX_RATE = 14 @@ -70,7 +74,7 @@ def make_binary_waveform(waveform): return (BinaryWaveform(data),) else: chs = sorted(waveform.defined_channels) - t = np.arange(0., waveform.duration, 1., dtype=float) + t = np.arange(0., float(waveform.duration), 1., dtype=float) sampled = [None if ch is None else waveform.get_sampled(ch, t) for _, ch in zip_longest(range(6), take(6, chs), fillvalue=None)] @@ -381,7 +385,7 @@ def test_loop_to_seqc_leaf(self): # we use None because it is not used in this test user_registers = None - wf = DummyWaveform(duration=32) + wf = DummyWaveform(duration=32, sample_output=lambda x: np.sin(x)) loop = Loop(waveform=wf) # with wrapping repetition @@ -434,9 +438,6 @@ def test_to_node_clusters(self): loop_to_seqc_kwargs = {'my': 'kwargs'} - def dummy_loop_to_seqc(loop, **kwargs): - return loop - loops = [wf1, wf2, wf1, wf1, wf3, wf1, wf1, wf1, wf3, wf1, wf3, wf1, wf3] expected_calls = [mock.call(loop, **loop_to_seqc_kwargs) for loop in loops] expected_result = [[wf1, wf2, wf1, wf1], [wf3], [wf1, wf1, wf1], [Scope([wf3, wf1]), Scope([wf3, wf1])], [wf3]] @@ -446,6 +447,20 @@ def dummy_loop_to_seqc(loop, **kwargs): self.assertEqual(mock_loop_to_seqc.mock_calls, expected_calls) self.assertEqual(expected_result, result) + def test_to_node_clusters_crash(self): + wf1 = WaveformPlayback(make_binary_waveform(*get_unique_wfs(1, 32))) + wf2 = WaveformPlayback(make_binary_waveform(*get_unique_wfs(1, 64))) + wf3 = WaveformPlayback(make_binary_waveform(*get_unique_wfs(1, 128))) + wf4 = WaveformPlayback(make_binary_waveform(*get_unique_wfs(1, 256))) + + loop_to_seqc_kwargs = {'my': 'kwargs'} + + loops = [wf1, wf2, wf3] * 3 + [wf1] + [wf2, wf4] * 3 + [wf1] + with mock.patch('qupulse._program.seqc.loop_to_seqc', wraps=dummy_loop_to_seqc) as mock_loop_to_seqc: + result = to_node_clusters(loops, loop_to_seqc_kwargs) + expected_result = [[Scope([wf1, wf2, wf3])]*3, [wf1], [Scope([wf2, wf4])]*3, [wf1]] + self.assertEqual(expected_result, result) + def test_find_sharable_waveforms(self): wf1, wf2 = map(WaveformPlayback, map(make_binary_waveform, get_unique_wfs(2, 32))) wf3, wf_shared = map(WaveformPlayback, map(make_binary_waveform, get_unique_wfs(2, 64))) diff --git a/tests/expression_tests.py b/tests/expression_tests.py index 932cdd264..9ff11f6d4 100644 --- a/tests/expression_tests.py +++ b/tests/expression_tests.py @@ -1,3 +1,4 @@ +import pickle import unittest import sys @@ -120,6 +121,14 @@ def test_hash(self): s = ExpressionScalar('a') self.assertEqual({e1, e7}, {e1, e2, e7, s}) + def test_pickle(self): + expr = ExpressionVector([1, 'a + 5', 3]) + # populate lambdified + expr.evaluate_in_scope({'a': 3}) + dumped = pickle.dumps(expr) + loaded = pickle.loads(dumped) + self.assertEqual(expr, loaded) + class ExpressionScalarTests(unittest.TestCase): def test_format(self): @@ -467,6 +476,14 @@ def test_evaluate_with_exact_rationals(self): self.assertEqual(TimeType.from_fraction(10, 3), expr.evaluate_with_exact_rationals({'a': [2, 2], 'b': [1, 4]})) + def test_pickle(self): + expr = ExpressionScalar('1 / a') + # populate lambdified + expr.evaluate_in_scope({'a': 7}) + dumped = pickle.dumps(expr) + loaded = pickle.loads(dumped) + self.assertEqual(expr, loaded) + class ExpressionExceptionTests(unittest.TestCase): def test_expression_variable_missing(self): diff --git a/tests/hardware/alazar_tests.py b/tests/hardware/alazar_tests.py index 5a1a42b64..c6e7d032d 100644 --- a/tests/hardware/alazar_tests.py +++ b/tests/hardware/alazar_tests.py @@ -65,7 +65,7 @@ def test_set_measurement_mask(self): self.assertFalse(result[0].flags.writeable) self.assertFalse(result[1].flags.writeable) - with self.assertRaisesRegex(RuntimeError, 'differing sample factor'): + with self.assertRaisesRegex(RuntimeError, 'differing sample rate'): program.set_measurement_mask('sorted', self.sample_factor*5/4, *self.masks['sorted']) result = program.set_measurement_mask('sorted', self.sample_factor, *self.masks['sorted']) diff --git a/tests/hardware/util_tests.py b/tests/hardware/util_tests.py index 54c7cc5c8..9b983c28f 100644 --- a/tests/hardware/util_tests.py +++ b/tests/hardware/util_tests.py @@ -2,8 +2,14 @@ import numpy as np +try: + import zhinst.utils +except ImportError: + zhinst = None + from qupulse.utils.types import TimeType -from qupulse.hardware.util import voltage_to_uint16, find_positions, get_sample_times +from qupulse.hardware.util import voltage_to_uint16, find_positions, get_sample_times, not_none_indices, \ + zhinst_voltage_to_uint16 from tests.pulses.sequencing_dummies import DummyWaveform @@ -31,8 +37,6 @@ def test_zero_level_14bit(self): self.assertEqual(zero_level, 8192) - - class FindPositionTest(unittest.TestCase): def test_find_position(self): data = [2, 6, -24, 65, 46, 5, -10, 9] @@ -73,3 +77,45 @@ def test_get_sample_times_single_wf(self): np.testing.assert_equal(times, expected_times) np.testing.assert_equal(n_samples, np.asarray(4)) + + +class NotNoneIndexTest(unittest.TestCase): + def test_not_none_indices(self): + self.assertEqual(([None, 0, 1, None, None, 2], 3), + not_none_indices([None, 'a', 'b', None, None, 'c'])) + + +@unittest.skipIf(zhinst is None, "zhinst not installed") +class ZHInstVoltageToUint16Test(unittest.TestCase): + def test_size_exception(self): + with self.assertRaisesRegex(ValueError, "No input"): + zhinst_voltage_to_uint16(None, None, (None, None, None, None)) + with self.assertRaisesRegex(ValueError, "dimension"): + zhinst_voltage_to_uint16(np.zeros(192), np.zeros(191), (None, None, None, None)) + with self.assertRaisesRegex(ValueError, "dimension"): + zhinst_voltage_to_uint16(np.zeros(192), None, (np.zeros(191), None, None, None)) + + def test_range_exception(self): + with self.assertRaisesRegex(ValueError, "invalid"): + zhinst_voltage_to_uint16(2.*np.ones(192), None, (None, None, None, None)) + # this should work + zhinst_voltage_to_uint16(None, None, (2. * np.ones(192), None, None, None)) + + def test_zeros(self): + combined = zhinst_voltage_to_uint16(None, np.zeros(192), (None, None, None, None)) + np.testing.assert_array_equal(np.zeros(3*192, dtype=np.uint16), combined) + + def test_full(self): + ch1 = np.linspace(0, 1., num=192) + ch2 = np.linspace(0., -1., num=192) + + markers = tuple(np.array(([1.] + [0.]*m) * 192)[:192] for m in range(1, 5)) + + combined = zhinst_voltage_to_uint16(ch1, ch2, markers) + + marker_data = [sum(int(markers[m][idx] > 0) << m for m in range(4)) + for idx in range(192)] + marker_data = np.array(marker_data, dtype=np.uint16) + expected = zhinst.utils.convert_awg_waveform(ch1, ch2, marker_data) + + np.testing.assert_array_equal(expected, combined) diff --git a/tests/hardware/zihdawg_tests.py b/tests/hardware/zihdawg_tests.py index dc41cb5f6..028d371bd 100644 --- a/tests/hardware/zihdawg_tests.py +++ b/tests/hardware/zihdawg_tests.py @@ -11,9 +11,17 @@ if pytest: zhinst = pytest.importorskip("zhinst") + + try: + import zhinst.core as zhinst_core + except ImportError: + import zhinst.ziPython as zhinst_core else: try: - import zhinst.ziPython + try: + import zhinst.core as zhinst_core + except ImportError: + import zhinst.ziPython as zhinst_core except ImportError as err: raise unittest.SkipTest("zhinst not present") from err @@ -21,7 +29,7 @@ from qupulse._program._loop import Loop from tests.pulses.sequencing_dummies import DummyWaveform from qupulse.hardware.awgs.zihdawg import HDAWGChannelGroup, HDAWGRepresentation, HDAWGValueError, UserRegister,\ - ConstantParameter, ELFManager, HDAWGChannelGrouping + ConstantParameter, ELFManager, HDAWGChannelGrouping, SingleDeviceChannelGroup class HDAWGRepresentationTests(unittest.TestCase): @@ -36,10 +44,10 @@ def test_init(self): with \ mock.patch('zhinst.utils.api_server_version_check') as mock_version_check,\ - mock.patch('zhinst.ziPython.ziDAQServer') as mock_daq_server, \ + mock.patch.object(zhinst_core, 'ziDAQServer') as mock_daq_server, \ mock.patch('qupulse.hardware.awgs.zihdawg.HDAWGRepresentation._initialize') as mock_init, \ mock.patch('qupulse.hardware.awgs.zihdawg.HDAWGRepresentation.channel_grouping', new_callable=mock.PropertyMock) as mock_grouping, \ - mock.patch('qupulse.hardware.awgs.zihdawg.HDAWGChannelGroup') as mock_channel_pair,\ + mock.patch('qupulse.hardware.awgs.zihdawg.SingleDeviceChannelGroup') as mock_channel_pair,\ mock.patch('zhinst.utils.disable_everything') as mock_reset,\ mock.patch('pathlib.Path') as mock_path: @@ -122,7 +130,7 @@ def test_init(self): channels = (3, 4) awg_group_idx = 1 - channel_pair = HDAWGChannelGroup(awg_group_idx, 2, 'foo', 3.4) + channel_pair = SingleDeviceChannelGroup(awg_group_idx, 2, 'foo', 3.4) self.assertEqual(channel_pair.timeout, 3.4) self.assertEqual(channel_pair._channels(), channels) @@ -137,8 +145,8 @@ def test_init(self): channel_pair.connect_group(mock_device) self.assertTrue(channel_pair.is_connected()) proxy_mock.assert_called_once_with(mock_device) - self.assertIs(channel_pair.device, proxy_mock.return_value) - self.assertIs(channel_pair.awg_module, channel_pair.device.api_session.awgModule.return_value) + self.assertIs(channel_pair.master_device, proxy_mock.return_value) + self.assertIs(channel_pair.awg_module, channel_pair.master_device.api_session.awgModule.return_value) def test_set_volatile_parameters(self): mock_device = mock.Mock() @@ -148,7 +156,7 @@ def test_set_volatile_parameters(self): expected_user_reg_calls = [mock.call(*args) for args in requested_changes.items()] - channel_pair = HDAWGChannelGroup(1, 2, 'foo', 3.4) + channel_pair = SingleDeviceChannelGroup(1, 2, 'foo', 3.4) channel_pair._current_program = 'active_program' with mock.patch.object(channel_pair._program_manager, 'get_register_values_to_update_volatile_parameters', @@ -176,7 +184,7 @@ def test_upload(self): with mock.patch('weakref.proxy'),\ mock.patch('qupulse.hardware.awgs.zihdawg.make_compatible') as mock_make_compatible: - channel_pair = HDAWGChannelGroup(1, 2, 'foo', 3.4) + channel_pair = SingleDeviceChannelGroup(1, 2, 'foo', 3.4) with self.assertRaisesRegex(HDAWGValueError, 'Channel ID'): channel_pair.upload('bar', mock_loop, ('A'), (None, 'A', None, None), voltage_trafos) diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 467ce0e0a..e964c296e 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -65,7 +65,7 @@ def test_build_waveform(self): # channel a in both with mock.patch.object(a, 'build_waveform', return_value=wf_a) as build_a, mock.patch.object(b, 'build_waveform', return_value=wf_b) as build_b: - with mock.patch('qupulse.pulses.arithmetic_pulse_template.ArithmeticWaveform', return_value=wf_arith) as wf_init: + with mock.patch('qupulse.pulses.arithmetic_pulse_template.ArithmeticWaveform.from_operator', return_value=wf_arith) as wf_init: wf_init.rhs_only_map.__getitem__.return_value.return_value = wf_rhs_only self.assertIs(wf_arith, arith.build_waveform(parameters=parameters, channel_mapping=channel_mapping)) wf_init.assert_called_once_with(wf_a, '-', wf_b) @@ -104,9 +104,9 @@ def test_integral(self): integrals_lhs = dict(a=ExpressionScalar('a_lhs'), b=ExpressionScalar('b')) integrals_rhs = dict(a=ExpressionScalar('a_rhs'), c=ExpressionScalar('c')) - lhs = DummyPulseTemplate(duration=4, defined_channels={'a', 'b'}, + lhs = DummyPulseTemplate(duration='t_dur', defined_channels={'a', 'b'}, parameter_names={'x', 'y'}, integrals=integrals_lhs) - rhs = DummyPulseTemplate(duration=4, defined_channels={'a', 'c'}, + rhs = DummyPulseTemplate(duration='t_dur', defined_channels={'a', 'c'}, parameter_names={'x', 'z'}, integrals=integrals_rhs) expected_plus = dict(a=ExpressionScalar('a_lhs + a_rhs'), @@ -118,14 +118,26 @@ def test_integral(self): self.assertEqual(expected_plus, (lhs + rhs).integral) self.assertEqual(expected_minus, (lhs - rhs).integral) + def test_initial_final_values(self): + lhs = DummyPulseTemplate(initial_values={'A': .1, 'B': 'b*2'}, final_values={'A': .2, 'B': 'b / 2'}) + rhs = DummyPulseTemplate(initial_values={'A': -4, 'B': 'b*2 + 1'}, final_values={'A': .2, 'B': '-b / 2 + c'}) + + minus = lhs - rhs + plus = lhs + rhs + self.assertEqual({'A': 4.1, 'B': -1}, minus.initial_values) + self.assertEqual({'A': 0, 'B': 'b - c'}, minus.final_values) + + self.assertEqual({'A': -3.9, 'B': 'b*4 + 1'}, plus.initial_values) + self.assertEqual({'A': .4, 'B': 'c'}, plus.final_values) + def test_as_expression(self): integrals_lhs = dict(a=ExpressionScalar('a_lhs'), b=ExpressionScalar('b')) integrals_rhs = dict(a=ExpressionScalar('a_rhs'), c=ExpressionScalar('c')) duration = 4 t = DummyPulseTemplate._AS_EXPRESSION_TIME - expr_lhs = {ch: i * t / duration for ch, i in integrals_lhs.items()} - expr_rhs = {ch: i * t / duration for ch, i in integrals_rhs.items()} + expr_lhs = {ch: i * t / duration**2 * 2 for ch, i in integrals_lhs.items()} + expr_rhs = {ch: i * t / duration**2 * 2 for ch, i in integrals_rhs.items()} lhs = DummyPulseTemplate(duration=duration, defined_channels={'a', 'b'}, parameter_names={'x', 'y'}, integrals=integrals_lhs) @@ -472,6 +484,16 @@ def test_integral(self): actual = ArithmeticPulseTemplate(pt, '/', mapping).integral self.assertEqual(expected, actual) + def test_initial_values(self): + lhs = DummyPulseTemplate(initial_values={'A': .3, 'B': 'b'}, defined_channels={'A', 'B'}) + apt = lhs + 'a' + self.assertEqual({'A': 'a + 0.3', 'B': 'b + a'}, apt.initial_values) + + def test_final_values(self): + lhs = DummyPulseTemplate(final_values={'A': .3, 'B': 'b'}, defined_channels={'A', 'B'}) + apt = lhs - 'a' + self.assertEqual({'A': '-a + .3', 'B': 'b - a'}, apt.final_values) + def test_simple_attributes(self): lhs = DummyPulseTemplate(defined_channels={'a', 'b'}, duration=ExpressionScalar('t_dur'), measurement_names={'m1'}) @@ -503,12 +525,12 @@ def test_try_operation(self): self.assertIs(NotImplemented, try_operation(npt, '//', 6)) def test_build_waveform(self): - pt = DummyPulseTemplate(defined_channels={'a'}) + pt = DummyPulseTemplate(defined_channels={'a'}, duration=6) parameters = dict(x=5., y=5.7) channel_mapping = dict(a='u', b='v') - inner_wf = mock.Mock(spec=DummyWaveform) + inner_wf = DummyWaveform(duration=6, defined_channels={'a'}) trafo = mock.Mock(spec=IdentityTransformation()) arith = ArithmeticPulseTemplate(pt, '-', 6) diff --git a/tests/pulses/constant_pulse_template_tests.py b/tests/pulses/constant_pulse_template_tests.py index 7a015b26e..3651fe6e7 100644 --- a/tests/pulses/constant_pulse_template_tests.py +++ b/tests/pulses/constant_pulse_template_tests.py @@ -4,13 +4,16 @@ import qupulse._program.waveforms import qupulse.utils.sympy from qupulse.pulses import TablePT, FunctionPT, AtomicMultiChannelPT, MappingPT +from qupulse.pulses.multi_channel_pulse_template import AtomicMultiChannelPulseTemplate from qupulse.pulses.plotting import plot from qupulse.pulses.sequence_pulse_template import SequencePulseTemplate from qupulse._program._loop import make_compatible from qupulse._program.waveforms import ConstantWaveform -from qupulse.pulses.constant_pulse_template import ConstantPulseTemplate +from qupulse.serialization import DictBackend, PulseStorage +from qupulse.pulses.constant_pulse_template import ConstantPulseTemplate, ExpressionScalar, TimeType +from tests.serialization_tests import SerializableTests class TestConstantPulseTemplate(unittest.TestCase): @@ -24,6 +27,9 @@ def test_ConstantPulseTemplate(self): self.assertIn('ConstantPulseTemplate', str(pt)) self.assertIn('ConstantPulseTemplate', repr(pt)) + self.assertEqual({'P1': .5, 'P2': .25}, pt.initial_values) + self.assertEqual({'P1': .5, 'P2': .25}, pt.final_values) + def test_zero_duration(self): p1 = ConstantPulseTemplate(10, {'P1': 1.}) p2 = ConstantPulseTemplate(0, {'P1': 1.}) @@ -87,6 +93,28 @@ def test_regression_sequencept_with_mappingpt(self): plot(p) self.assertEqual(p.defined_channels, {'C2'}) + def test_expressions(self): + cpt = ConstantPulseTemplate('duration', {'A': 5.4, 'B': 'amplitude_b'}) + self.assertEqual({'duration', 'amplitude_b'}, cpt.parameter_names) + self.assertEqual(ExpressionScalar('duration'), cpt.duration) + + self.assertIsNone(cpt.build_waveform({'duration': 0., 'amplitude_b': 1.}, {'A': 'A', 'B': 'B'})) + self.assertIsNone(cpt.build_waveform({'duration': 1., 'amplitude_b': 1.}, {'A': None, 'B': None})) + + wf1 = ConstantWaveform(duration=TimeType.from_float(1.4), channel='C', amplitude=1.6) + wf2 = ConstantWaveform(duration=TimeType.from_float(1.5), channel='A', amplitude=5.4) + self.assertEqual(wf1, cpt.build_waveform({'duration': 1.4, 'amplitude_b': 1.6}, {'A': None, 'B': 'C'})) + self.assertEqual(wf2, cpt.build_waveform({'duration': 1.5, 'amplitude_b': None}, {'A': 'A', 'B': None})) + + wf3 = ConstantWaveform.from_mapping(duration=TimeType.from_float(1.6), constant_values={'C': 5.4, 'B': -.3}) + self.assertEqual(wf3, cpt.build_waveform({'duration': 1.6, 'amplitude_b': -.3}, {'A': 'C', 'B': 'B'})) + + def test_regression_defined_channels(self): + p=ConstantPulseTemplate(100, {'a': 1.}) + q=ConstantPulseTemplate(100, {'b': 1.}) + pt=AtomicMultiChannelPulseTemplate(p, q) + self.assertEqual(pt.defined_channels, {'a', 'b'}) + def test_build_waveform(self): tpt = ConstantPulseTemplate(200, {'C1': 2, 'C2': 3}) @@ -115,3 +143,46 @@ def test_build_waveform(self): ) self.assertIsNone(tpt.build_waveform({}, {'C1': None, 'C2': None})) + + +class ConstantPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): + @property + def class_to_test(self): + return ConstantPulseTemplate + + def make_kwargs(self): + return { + 'name': 'yoho', + 'duration': 'dur', + 'amplitude_dict': {'int': 1, 'float': -3.4, 'expr': 'x + y'}, + 'measurements': [('m', 1, 1), ('foo', 'z', 'o')], + } + + def assert_equal_instance_except_id(self, lhs: ConstantPulseTemplate, rhs: ConstantPulseTemplate): + self.assertIsInstance(lhs, ConstantPulseTemplate) + self.assertIsInstance(rhs, ConstantPulseTemplate) + self.assertEqual(lhs._name, rhs._name) + self.assertEqual(lhs.measurement_declarations, rhs.measurement_declarations) + self.assertEqual(lhs._amplitude_dict, rhs._amplitude_dict) + self.assertEqual(lhs.duration, rhs.duration) + + def test_legacy_deserialization(self): + serialized = """{ + "#amplitudes": { + "ZI0_A_MARKER_FRONT": 1 + }, + "#type": "qupulse.pulses.constant_pulse_template.ConstantPulseTemplate", + "duration": 62848.0, + "name": "constant_pulse" + }""" + backend = DictBackend() + backend.storage['my_pt'] = serialized + + ps = PulseStorage(backend) + + deserialized = ps['my_pt'] + expected = ConstantPulseTemplate( + amplitude_dict={"ZI0_A_MARKER_FRONT": 1}, + duration=62848, name="constant_pulse" + ) + self.assert_equal_instance(expected, deserialized) diff --git a/tests/pulses/function_pulse_tests.py b/tests/pulses/function_pulse_tests.py index a2fabcf2c..536bdc0ad 100644 --- a/tests/pulses/function_pulse_tests.py +++ b/tests/pulses/function_pulse_tests.py @@ -84,6 +84,14 @@ def test_integral(self) -> None: pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') self.assertEqual({'default': Expression('2.0*cos(b) - 2.0*cos(1.0*Tmax+b)')}, pulse.integral) + def test_initial_values(self): + fpt = FunctionPulseTemplate('3 + exp(t * a)', '3.14', channel='A') + self.assertEqual({'A': 4}, fpt.initial_values) + + def test_final_values(self): + fpt = FunctionPulseTemplate('3 + exp(t * a)', '3.14', channel='A') + self.assertEqual({'A': Expression('3 + exp(3.14*a)')}, fpt.final_values) + def test_as_expression(self): pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') expr = sympy.sin(0.5 * pulse._AS_EXPRESSION_TIME + sympy.sympify('b')) diff --git a/tests/pulses/loop_pulse_template_tests.py b/tests/pulses/loop_pulse_template_tests.py index a1c8b41f3..d502a860a 100644 --- a/tests/pulses/loop_pulse_template_tests.py +++ b/tests/pulses/loop_pulse_template_tests.py @@ -9,7 +9,7 @@ from qupulse.expressions import Expression, ExpressionScalar from qupulse.pulses.loop_pulse_template import ForLoopPulseTemplate, ParametrizedRange,\ - LoopIndexNotUsedException, LoopPulseTemplate, _get_for_loop_scope, _ForLoopScope + LoopIndexNotUsedException, LoopPulseTemplate, _ForLoopScope, _ForLoopScope from qupulse.pulses.parameters import ConstantParameter, InvalidParameterNameException, ParameterConstraintViolation,\ ParameterNotProvidedException, ParameterConstraint @@ -103,7 +103,7 @@ def test_init(self): with self.assertRaises(InvalidParameterNameException): ForLoopPulseTemplate(body=dt, loop_index='1i', loop_range=6) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): ForLoopPulseTemplate(body=dt, loop_index='i', loop_range=slice(None)) with self.assertRaises(LoopIndexNotUsedException): @@ -186,6 +186,19 @@ def test_integral(self) -> None: self.assertEqual(expected_simplified, actual_simplified) + def test_initial_values(self): + dpt = DummyPulseTemplate(initial_values={'A': 'a + 3 + i', 'B': 7}, parameter_names={'i', 'a'}) + fpt = ForLoopPulseTemplate(dpt, 'i', (1, 'n', 2)) + self.assertEqual({'A': 'a+4', 'B': 7}, fpt.initial_values) + + def test_final_values(self): + dpt = DummyPulseTemplate(final_values={'A': 'a + 3 + i', 'B': 7}, parameter_names={'i', 'a'}) + fpt = ForLoopPulseTemplate(dpt, 'i', 'n') + self.assertEqual({'A': 'a+3+Max(0, floor(n) - 1)', 'B': 7}, fpt.final_values) + + fpt_fin = ForLoopPulseTemplate(dpt, 'i', (1, 'n', 2)).final_values + self.assertEqual('a + 10', fpt_fin['A'].evaluate_symbolic({'n': 8})) + class ForLoopTemplateSequencingTests(MeasurementWindowTestCase): def test_create_program_constraint_on_loop_var_exception(self): @@ -377,7 +390,7 @@ def test_create_program(self) -> None: global_transformation=global_transformation, to_single_waveform=to_single_waveform) expected_create_program_calls = [mock.call(**expected_create_program_kwargs, - scope=_get_for_loop_scope(scope, 'i', i)) + scope=_ForLoopScope(scope, 'i', i)) for i in (1, 3)] with mock.patch.object(flt, 'validate_scope') as validate_scope: @@ -452,7 +465,7 @@ def assert_equal_instance_except_id(self, lhs: ForLoopPulseTemplate, rhs: ForLoo self.assertIsInstance(rhs, ForLoopPulseTemplate) self.assertEqual(lhs.body, rhs.body) self.assertEqual(lhs.loop_index, rhs.loop_index) - self.assertEqual(lhs.loop_range.to_tuple(), rhs.loop_range.to_tuple()) + self.assertEqual(lhs.loop_range, rhs.loop_range) self.assertEqual(lhs.parameter_constraints, rhs.parameter_constraints) self.assertEqual(lhs.measurement_declarations, rhs.measurement_declarations) diff --git a/tests/pulses/mapping_pulse_template_tests.py b/tests/pulses/mapping_pulse_template_tests.py index 24fe14dea..2fb5e8af7 100644 --- a/tests/pulses/mapping_pulse_template_tests.py +++ b/tests/pulses/mapping_pulse_template_tests.py @@ -255,13 +255,20 @@ def test_integral(self) -> None: self.assertEqual({'a': Expression('2*f'), 'B': Expression('-3.2*f+2.3')}, pulse.integral) + def test_initial_final_values(self): + dpt = DummyPulseTemplate(initial_values={'A': 'a', 'B': 'b'}, final_values={'A': 'a + c', 'B': 'b - 3'}, + parameter_names=set('abc')) + mapped = MappingPulseTemplate(dpt, parameter_mapping={'a': 'c'}, allow_partial_parameter_mapping=True) + self.assertEqual({'A': 'c', 'B': 'b'}, mapped.initial_values) + self.assertEqual({'A': 'c+c', 'B': 'b-3'}, mapped.final_values) + def test_as_expression(self): from sympy.abc import f, k, b duration = 5 dummy = DummyPulseTemplate(defined_channels={'A', 'B', 'C'}, parameter_names={'k', 'f', 'b'}, - integrals={'A': Expression(2 * k), - 'B': Expression(-3.2*f+b), + integrals={'A': Expression(k), + 'B': Expression(f+b), 'C': Expression(1)}, duration=duration) t = DummyPulseTemplate._AS_EXPRESSION_TIME dummy_expr = {ch: i * t / duration for ch, i in dummy._integrals.items()} @@ -270,8 +277,8 @@ def test_as_expression(self): allow_partial_parameter_mapping=True) expected = { - 'a': Expression(2*f*t/duration), - 'B': Expression((-3.2*f + 2.3)*t/duration), + 'a': Expression(t*f/duration**2 * 2), + 'B': Expression((f + 2.3)*t/duration**2 * 2), } self.assertEqual(expected, pulse._as_expression()) diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 7de38ae72..be8d9feb7 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -190,9 +190,9 @@ def test_as_expression(self): DummyPulseTemplate(duration='t1', defined_channels={'B', 'C'}, integrals={'B': ExpressionScalar('t1-t0*3.1'), 'C': ExpressionScalar('l')})] pulse = AtomicMultiChannelPulseTemplate(*sts) - self.assertEqual({'A': ExpressionScalar(sympify('(2+k) / t1') * pulse._AS_EXPRESSION_TIME), - 'B': ExpressionScalar(sympify('(t1-t0*3.1)/t1') * pulse._AS_EXPRESSION_TIME), - 'C': ExpressionScalar(sympify('l/t1') * pulse._AS_EXPRESSION_TIME)}, + self.assertEqual({'A': sts[0]._as_expression()['A'], + 'B': sts[1]._as_expression()['B'], + 'C': sts[1]._as_expression()['C']}, pulse._as_expression()) @@ -376,6 +376,16 @@ def test_integral(self): 'Z': ExpressionScalar('a*t1')} self.assertEqual(expected_integral, pccpt.integral) + def test_initial_values(self): + dpt = DummyPulseTemplate(initial_values={'A': 'a', 'B': 'b'}) + par = ParallelConstantChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'}) + self.assertEqual({'A': 'a', 'B': 'b2', 'C': 'c'}, par.initial_values) + + def test_final_values(self): + dpt = DummyPulseTemplate(final_values={'A': 'a', 'B': 'b'}) + par = ParallelConstantChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'}) + self.assertEqual({'A': 'a', 'B': 'b2', 'C': 'c'}, par.final_values) + def test_get_overwritten_channels_values(self): template = DummyPulseTemplate(duration='t1', defined_channels={'X', 'Y'}, parameter_names={'a', 'b'}, measurement_names={'M'}) diff --git a/tests/pulses/plotting_tests.py b/tests/pulses/plotting_tests.py index dfd7ca81d..344cc96d3 100644 --- a/tests/pulses/plotting_tests.py +++ b/tests/pulses/plotting_tests.py @@ -10,6 +10,7 @@ except ImportError: qupulse_rs = None +from qupulse.pulses import ConstantPT from qupulse.pulses.plotting import PlottingNotPossibleException, render, plot from qupulse.pulses.table_pulse_template import TablePulseTemplate from qupulse.pulses.sequence_pulse_template import SequencePulseTemplate @@ -98,6 +99,12 @@ def test_plot_empty_pulse(self) -> None: with self.assertWarnsRegex(UserWarning, "empty", msg="plot() did not issue a warning for an empty pulse"): plot(pt, dict(), show=False) + def test_plot_pulse_automatic_sample_rate(self) -> None: + import matplotlib + matplotlib.use('svg') # use non-interactive backend so that test does not fail on travis + pt=ConstantPT(100, {'a': 1}) + plot(pt, sample_rate=None) + def test_bug_447(self): """Adapted code from https://github.com/qutech/qupulse/issues/447""" TablePT = TablePulseTemplate diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 3277319df..4e3163d34 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -104,6 +104,29 @@ def test_integral(self) -> None: 'Y': 2 * (4.1 + 4) / 2 + (5 - 2) * 4}, integral) + def test_initial_final_values(self): + pulse = PointPulseTemplate( + [(1, (2, 'b'), 'hold'), + (3, (0, 0), 'linear'), + (4, (2, 'c'), 'jump'), + (5, (8, 'd'), 'hold')], + [0, 'other_channel'] + ) + self.assertEqual({0: 2, 'other_channel': 'b'}, pulse.initial_values) + self.assertEqual({0: 8, 'other_channel': 'd'}, pulse.final_values) + + pulse = PointPulseTemplate( + [(1, 'b', 'hold'), + (3, (0, 0), 'linear'), + (4, (2, 'c'), 'jump'), + (5, 'd', 'hold')], + [0, 'other_channel'] + ) + self.assertEqual({0: 'IndexedBroadcast(b, (2,), 0)', 'other_channel': 'IndexedBroadcast(b, (2,), 1)'}, + pulse.initial_values) + self.assertEqual({0: 'IndexedBroadcast(d, (2,), 0)', 'other_channel': 'IndexedBroadcast(d, (2,), 1)'}, + pulse.final_values) + class PointPulseTemplateSequencingTests(unittest.TestCase): def test_build_waveform_empty(self): diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index e6ae3cee8..a26c79005 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -82,6 +82,14 @@ def measurement_names(self): def integral(self) -> Dict[ChannelID, ExpressionScalar]: raise NotImplementedError() + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + raise NotImplementedError() + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + raise NotImplementedError() + def __repr__(self): return f"PulseTemplateStub(id={id(self)})" @@ -269,7 +277,9 @@ def test__create_program_single_waveform(self): parent_loop=inner_program_builder) to_waveform.assert_called_once_with(expected_inner_program) - self.assertEqual(expected_program, program_builder.to_program()) + self.assertEqual(expected_program, program_builder.to_program(), + f"To single waveform failed with to_single_waveform={to_single_waveform!r} and" + f" global_transformation={global_transformation!r}") def test_create_program_defaults(self) -> None: template = PulseTemplateStub(defined_channels={'A', 'B'}, parameter_names={'foo'}, measurement_names={'hugo', 'foo'}) @@ -447,4 +457,3 @@ def test_internal_create_program_volatile(self): to_single_waveform=set(), global_transformation=None) self.assertEqual(Loop(), program) - diff --git a/tests/pulses/repetition_pulse_template_tests.py b/tests/pulses/repetition_pulse_template_tests.py index 8e9979932..585dd39ba 100644 --- a/tests/pulses/repetition_pulse_template_tests.py +++ b/tests/pulses/repetition_pulse_template_tests.py @@ -89,6 +89,16 @@ def test_integral(self) -> None: template = RepetitionPulseTemplate(dummy, Expression('2+m')) self.assertEqual({'A': Expression('(2+m)*(foo+2)'), 'B': Expression('(2+m)*(k*3+x**2)')}, template.integral) + def test_initial_values(self): + dummy = DummyPulseTemplate(initial_values={'A': ExpressionScalar('a + 3')}) + rpt = RepetitionPulseTemplate(dummy, repetition_count='n') + self.assertEqual(dummy.initial_values, rpt.initial_values) + + def test_final_values(self): + dummy = DummyPulseTemplate(final_values={'A': ExpressionScalar('a + 3')}) + rpt = RepetitionPulseTemplate(dummy, repetition_count='n') + self.assertEqual(dummy.final_values, rpt.final_values) + def test_parameter_names_param_only_in_constraint(self) -> None: pt = RepetitionPulseTemplate(DummyPulseTemplate(parameter_names={'a'}), 'n', parameter_constraints=['a None: self.assertEqual({'A': ExpressionScalar('k+2*b+7*(b-f)'), 'B': ExpressionScalar('0.24*f')}, pulse.integral) + def test_initial_final_values(self): + pt1 = DummyPulseTemplate(initial_values={'A': 'a'}) + pt2 = DummyPulseTemplate(final_values={'A': 'b'}) + spt = pt1 @ pt2 + self.assertEqual(pt1.initial_values, spt.initial_values) + self.assertEqual(pt2.final_values, spt.final_values) + def test_concatenate(self): a = DummyPulseTemplate(parameter_names={'foo'}, defined_channels={'A'}) b = DummyPulseTemplate(parameter_names={'bar'}, defined_channels={'A'}) diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 59b7cda45..56c074ade 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -120,6 +120,8 @@ def unsafe_sample(self, if self.sample_output is not None: if isinstance(self.sample_output, dict): output_array[:] = self.sample_output[channel] + elif callable(self.sample_output): + output_array[:] = self.sample_output(sample_times) else: output_array[:] = self.sample_output else: @@ -178,7 +180,7 @@ def expression(self) -> ExpressionScalar: def evaluate_integral(self, t0, v0, t1, v1): """ Evaluate integral using arguments v0, t0, v1, t1 """ raise - + class DummyPulseTemplate(AtomicPulseTemplate): def __init__(self, @@ -190,6 +192,8 @@ def __init__(self, measurement_names: Set[str] = set(), measurements: list=list(), integrals: Dict[ChannelID, ExpressionScalar]=None, + initial_values: Dict[ChannelID, Any]=None, + final_values: Dict[ChannelID, Any]=None, program: Optional[Loop]=None, identifier=None, registry=None) -> None: @@ -213,6 +217,16 @@ def __init__(self, self._program = program self._register(registry=registry) + if initial_values is None: + self._initial_values = {ch: ExpressionScalar(0) for ch in self.defined_channels} + else: + self._initial_values = {ch: ExpressionScalar(val) for ch, val in initial_values.items()} + + if final_values is None: + self._final_values = {ch: ExpressionScalar(0) for ch in self.defined_channels} + else: + self._final_values = {ch: ExpressionScalar(val) for ch, val in final_values.items()} + if integrals is not None: assert isinstance(integrals, Mapping) @@ -281,5 +295,13 @@ def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: assert self.duration != 0 t = self._AS_EXPRESSION_TIME duration = self.duration.underlying_expression - return {ch: ExpressionScalar(integral.underlying_expression*t/duration) + return {ch: ExpressionScalar(integral.underlying_expression*t/duration**2 * 2) for ch, integral in self.integral.items()} + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._initial_values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._final_values diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index c77311656..9dbe5e81a 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -475,6 +475,13 @@ def test_integral(self) -> None: self.assertEqual(expected, pulse.integral) + def test_initial_final_values(self): + pulse = TablePulseTemplate(entries={0: [(1, 2), (3, 0, 'linear'), (4, 2, 'jump'), (5, 8, 'hold')], + 'other_channel': [(0, 7), (2, 0, 'linear'), (10, 0)], + 'symbolic': [(3, 'a'), ('b', 4, 'hold'), ('c', Expression('d'), 'linear')]}) + self.assertEqual({0: 2, 'other_channel': 7, 'symbolic': 'a'}, pulse.initial_values) + self.assertEqual({0: 8, 'other_channel': 0, 'symbolic': 'd'}, pulse.final_values) + def test_as_expression(self): pulse = TablePulseTemplate(entries={0: [(0, 0), (1, 2), (3, 0, 'linear'), (4, 2, 'jump'), (5, 8, 'hold')], 'other_channel': [(0, 7), (2, 0, 'linear'), (10, 0)], diff --git a/tests/utils/performance_tests.py b/tests/utils/performance_tests.py new file mode 100644 index 000000000..d158dce5c --- /dev/null +++ b/tests/utils/performance_tests.py @@ -0,0 +1,30 @@ +import unittest + +import numpy as np + +from qupulse.utils.performance import _time_windows_to_samples_numba, _time_windows_to_samples_numpy + + +class TimeWindowsToSamplesTest(unittest.TestCase): + @staticmethod + def assert_implementations_equal(begins, lengths, sample_rate): + np.testing.assert_equal( + _time_windows_to_samples_numba(begins, lengths, sample_rate), + _time_windows_to_samples_numpy(begins, lengths, sample_rate) + ) + + def test_monotonic(self): + begins = np.array([101.3, 123.6218764354, 176.31, 763454.776]) + lengths = np.array([6.4234, 24.8654413, 8765.45, 12543.]) + + for sr in (0.1, 1/9, 1., 2.764423123563463412342, 100.322): + self.assert_implementations_equal(begins, lengths, sr) + + def test_unsorted(self): + begins = np.array([101.3, 176.31, 763454.776, 123.6218764354]) + lengths = np.array([6.4234, 8765.45, 12543., 24.8654413]) + + for sr in (0.1, 1/9, 1., 2.764423123563463412342, 100.322): + self.assert_implementations_equal(begins, lengths, sr) + +