1- from typing import Sequence , Mapping , Iterable , Optional , Union , ContextManager
1+ import contextlib
2+ from typing import Sequence , Mapping , Iterable , Optional , Union , ContextManager , Callable
23from dataclasses import dataclass
4+ from functools import cached_property
35
46import numpy
5- from rich .measure import Measurement
67
78from qupulse .utils .types import TimeType
89from qupulse .program import (ProgramBuilder , Program , HardwareVoltage , HardwareTime ,
910 MeasurementWindow , Waveform , RepetitionCount , SimpleExpression )
1011from qupulse .parameter_scope import Scope
1112
1213
14+ MeasurementID = str | int
15+
16+
17+ @dataclass
18+ class LoopLabel :
19+ idx : int
20+ runtime_name : str | None
21+ count : RepetitionCount
22+
23+
1324@dataclass
14- class MeasurementNode :
15- windows : Sequence [MeasurementWindow ]
25+ class Measure :
26+ meas_id : MeasurementID
27+ delay : HardwareTime
28+ length : HardwareTime
29+
30+
31+ @dataclass
32+ class Wait :
1633 duration : HardwareTime
1734
1835
1936@dataclass
20- class MeasurementRepetition (MeasurementNode ):
21- body : MeasurementNode
22- count : RepetitionCount
37+ class LoopJmp :
38+ idx : int
39+
40+
41+ Command = Union [LoopLabel , LoopJmp , Wait , Measure ]
42+
2343
2444@dataclass
25- class MeasurementSequence (MeasurementNode ):
26- nodes : Sequence [tuple [HardwareTime , MeasurementNode ]]
45+ class MeasurementInstructions (Program ):
46+ commands : Sequence [Command ]
47+
48+ @cached_property
49+ def duration (self ) -> float :
50+ latest = 0.
51+
52+ def process (_ , begin , length ):
53+ nonlocal latest
54+ end = begin + length
55+ latest = max (latest , end )
56+
57+ vm = MeasurementVM (process )
58+ vm .execute (commands = self .commands )
59+ return latest
2760
2861
2962@dataclass
3063class MeasurementFrame :
3164 commands : list ['Command' ]
32- has_duration : bool
33-
34- MeasurementID = str | int
65+ keep : bool
3566
3667
3768class MeasurementBuilder (ProgramBuilder ):
@@ -48,12 +79,11 @@ def _with_new_frame(self, measurements):
4879 self ._frames .append (MeasurementFrame ([], False ))
4980 yield self
5081 frame = self ._frames .pop ()
51- if not frame .has_duration :
82+ if not frame .keep :
5283 return
53- parent = self ._frames [- 1 ]
54- parent .has_duration = True
55- if measurements :
56- parent .commands .extend (map (Measure , measurements ))
84+ self .measure (measurements )
85+ # measure does not keep if there are no measurements
86+ self ._frames [- 1 ].keep = True
5787 return frame .commands
5888
5989 def inner_scope (self , scope : Scope ) -> Scope :
@@ -68,19 +98,19 @@ def inner_scope(self, scope: Scope) -> Scope:
6898 def hold_voltage (self , duration : HardwareTime , voltages : Mapping [str , HardwareVoltage ]):
6999 """Supports dynamic i.e. for loop generated offsets and duration"""
70100 self ._frames [- 1 ].commands .append (Wait (duration ))
71- self ._frames [- 1 ].has_duration = True
101+ self ._frames [- 1 ].keep = True
72102
73103 def play_arbitrary_waveform (self , waveform : Waveform ):
74104 """"""
75105 self ._frames [- 1 ].commands .append (Wait (waveform .duration ))
76- self ._frames [- 1 ].has_duration = True
106+ self ._frames [- 1 ].keep = True
77107
78108 def measure (self , measurements : Optional [Sequence [MeasurementWindow ]]):
79109 """Unconditionally add given measurements relative to the current position."""
80110 if measurements :
81111 commands = self ._frames [- 1 ].commands
82112 commands .extend (Measure (* meas ) for meas in measurements )
83- self ._frames [- 1 ].has_duration = True
113+ self ._frames [- 1 ].keep = True
84114
85115 def with_repetition (self , repetition_count : RepetitionCount ,
86116 measurements : Optional [Sequence [MeasurementWindow ]] = None ) -> Iterable ['ProgramBuilder' ]:
@@ -92,10 +122,11 @@ def with_repetition(self, repetition_count: RepetitionCount,
92122
93123 self ._label_counter += 1
94124 label_idx = self ._label_counter
95- parent .commands .append (LoopLabel (idx = label_idx , runtime_name = None , count = RepetitionCount ))
125+ parent .commands .append (LoopLabel (idx = label_idx , runtime_name = None , count = repetition_count ))
96126 parent .commands .extend (new_commands )
97127 parent .commands .append (LoopJmp (idx = label_idx ))
98128
129+ @contextlib .contextmanager
99130 def with_sequence (self ,
100131 measurements : Optional [Sequence [MeasurementWindow ]] = None ) -> ContextManager ['ProgramBuilder' ]:
101132 """
@@ -112,6 +143,7 @@ def with_sequence(self,
112143 parent = self ._frames [- 1 ]
113144 parent .commands .extend (new_commands )
114145
146+ @contextlib .contextmanager
115147 def new_subprogram (self , global_transformation : 'Transformation' = None ) -> ContextManager ['ProgramBuilder' ]:
116148 """Create a context managed program builder whose contents are translated into a single waveform upon exit if
117149 it is not empty."""
@@ -136,43 +168,16 @@ def time_reversed(self) -> ContextManager['ProgramBuilder']:
136168 self ._frames .append (MeasurementFrame ([], False ))
137169 yield self
138170 frame = self ._frames .pop ()
139- if not frame .has_duration :
171+ if not frame .keep :
140172 return
141173
142- self ._frames [- 1 ].has_duration = True
174+ self ._frames [- 1 ].keep = True
143175 self ._frames [- 1 ].commands .extend (_reversed_commands (frame .commands ))
144176
145177 def to_program (self ) -> Optional [Program ]:
146178 """Further addition of new elements might fail after finalizing the program."""
147- if self ._frames [0 ].has_duration :
148- return self ._frames [0 ].commands
149-
150-
151- @dataclass
152- class LoopLabel :
153- idx : int
154- runtime_name : str | None
155- count : RepetitionCount
156-
157-
158- @dataclass
159- class Measure :
160- meas_id : MeasurementID
161- delay : HardwareTime
162- length : HardwareTime
163-
164-
165- @dataclass
166- class Wait :
167- duration : HardwareTime
168-
169-
170- @dataclass
171- class LoopJmp :
172- idx : int
173-
174-
175- Command = Union [LoopLabel , LoopJmp , Wait , Measure ]
179+ if self ._frames [0 ].keep :
180+ return MeasurementInstructions (self ._frames [0 ].commands )
176181
177182
178183def _reversed_commands (cmds : Sequence [Command ]) -> Sequence [Command ]:
@@ -202,30 +207,26 @@ def _reversed_commands(cmds: Sequence[Command]) -> Sequence[Command]:
202207 return reversed_cmds
203208
204209
205- def to_table (commands : Sequence [Command ]) -> dict [str , numpy .ndarray ]:
206- time = TimeType (0 )
207-
208- memory = {}
209- counts = [None ]
210+ class MeasurementVM :
211+ """A VM that is capable of executing the measurement commands"""
210212
211- tables = {}
213+ def __init__ (self , callback : Callable [[str , float , float ], None ]):
214+ self ._time = TimeType (0 )
215+ self ._memory = {}
216+ self ._counts = {}
217+ self ._callback = callback
212218
213- def eval_hardware_time ( t : HardwareTime ):
219+ def _eval_hardware_time ( self , t : HardwareTime ):
214220 if isinstance (t , SimpleExpression ):
215221 value = t .base
216222 for (factor_name , factor_val ) in t .offsets .items ():
217- count = counts [ memory [factor_name ]]
223+ count = self . _counts [ self . _memory [factor_name ]]
218224 value += factor_val * count
219225 return value
220226 else :
221227 return t
222228
223- def execute (sequence : Sequence [Command ]) -> int :
224- nonlocal time
225- nonlocal tables
226- nonlocal memory
227- nonlocal counts
228-
229+ def _execute_after_label (self , sequence : Sequence [Command ]) -> int :
229230 skip = 0
230231 for idx , cmd in enumerate (sequence ):
231232 if idx < skip :
@@ -234,23 +235,30 @@ def execute(sequence: Sequence[Command]) -> int:
234235 return idx
235236 elif isinstance (cmd , LoopLabel ):
236237 if cmd .runtime_name :
237- memory [cmd .runtime_name ] = cmd .idx
238- if cmd .idx == len (counts ):
239- counts .append (0 )
240- assert cmd .idx < len (counts )
238+ self ._memory [cmd .runtime_name ] = cmd .idx
241239
242240 for iter_val in range (cmd .count ):
243- counts [cmd .idx ] = iter_val
244- pos = execute (sequence [idx + 1 :])
241+ self . _counts [cmd .idx ] = iter_val
242+ pos = self . _execute_after_label (sequence [idx + 1 :])
245243 skip = idx + pos + 2
244+
246245 elif isinstance (cmd , Measure ):
247- meas_time = float (eval_hardware_time (cmd .delay ) + time )
248- meas_len = float (eval_hardware_time (cmd .length ))
249- tables .setdefault (cmd .meas_id , []).append ((meas_time , meas_len ))
246+ meas_time = float (self ._eval_hardware_time (cmd .delay ) + self ._time )
247+ meas_len = float (self ._eval_hardware_time (cmd .length ))
248+ self ._callback (cmd .meas_id , meas_time , meas_len )
249+
250250 elif isinstance (cmd , Wait ):
251- time += eval_hardware_time (cmd .duration )
251+ self ._time += self ._eval_hardware_time (cmd .duration )
252+
253+ def execute (self , commands : Sequence [Command ]):
254+ self ._execute_after_label (commands )
255+
256+
257+ def to_table (commands : Sequence [Command ]) -> dict [str , numpy .ndarray ]:
258+ tables = {}
252259
253- execute (commands )
260+ vm = MeasurementVM (lambda name , begin , length : tables .setdefault (name , []).append ((begin , length )))
261+ vm .execute (commands )
254262 return {
255263 name : numpy .array (values ) for name , values in tables .items ()
256264 }
0 commit comments