Skip to content

Commit d48aa2b

Browse files
committed
Remove dead code, improve readability and create dedicated vm.
1 parent cfcee75 commit d48aa2b

File tree

2 files changed

+108
-79
lines changed

2 files changed

+108
-79
lines changed

qupulse/program/measurement.py

Lines changed: 85 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,68 @@
1-
from typing import Sequence, Mapping, Iterable, Optional, Union, ContextManager
1+
import contextlib
2+
from typing import Sequence, Mapping, Iterable, Optional, Union, ContextManager, Callable
23
from dataclasses import dataclass
4+
from functools import cached_property
35

46
import numpy
5-
from rich.measure import Measurement
67

78
from qupulse.utils.types import TimeType
89
from qupulse.program import (ProgramBuilder, Program, HardwareVoltage, HardwareTime,
910
MeasurementWindow, Waveform, RepetitionCount, SimpleExpression)
1011
from qupulse.parameter_scope import Scope
1112

1213

14+
MeasurementID = str | int
15+
16+
17+
@dataclass
18+
class LoopLabel:
19+
idx: int
20+
runtime_name: str | None
21+
count: RepetitionCount
22+
23+
1324
@dataclass
14-
class MeasurementNode:
15-
windows: Sequence[MeasurementWindow]
25+
class Measure:
26+
meas_id: MeasurementID
27+
delay: HardwareTime
28+
length: HardwareTime
29+
30+
31+
@dataclass
32+
class Wait:
1633
duration: HardwareTime
1734

1835

1936
@dataclass
20-
class MeasurementRepetition(MeasurementNode):
21-
body: MeasurementNode
22-
count: RepetitionCount
37+
class LoopJmp:
38+
idx: int
39+
40+
41+
Command = Union[LoopLabel, LoopJmp, Wait, Measure]
42+
2343

2444
@dataclass
25-
class MeasurementSequence(MeasurementNode):
26-
nodes: Sequence[tuple[HardwareTime, MeasurementNode]]
45+
class MeasurementInstructions(Program):
46+
commands: Sequence[Command]
47+
48+
@cached_property
49+
def duration(self) -> float:
50+
latest = 0.
51+
52+
def process(_, begin, length):
53+
nonlocal latest
54+
end = begin + length
55+
latest = max(latest, end)
56+
57+
vm = MeasurementVM(process)
58+
vm.execute(commands=self.commands)
59+
return latest
2760

2861

2962
@dataclass
3063
class MeasurementFrame:
3164
commands: list['Command']
32-
has_duration: bool
33-
34-
MeasurementID = str | int
65+
keep: bool
3566

3667

3768
class MeasurementBuilder(ProgramBuilder):
@@ -48,12 +79,11 @@ def _with_new_frame(self, measurements):
4879
self._frames.append(MeasurementFrame([], False))
4980
yield self
5081
frame = self._frames.pop()
51-
if not frame.has_duration:
82+
if not frame.keep:
5283
return
53-
parent = self._frames[-1]
54-
parent.has_duration = True
55-
if measurements:
56-
parent.commands.extend(map(Measure, measurements))
84+
self.measure(measurements)
85+
# measure does not keep if there are no measurements
86+
self._frames[-1].keep = True
5787
return frame.commands
5888

5989
def inner_scope(self, scope: Scope) -> Scope:
@@ -68,19 +98,19 @@ def inner_scope(self, scope: Scope) -> Scope:
6898
def hold_voltage(self, duration: HardwareTime, voltages: Mapping[str, HardwareVoltage]):
6999
"""Supports dynamic i.e. for loop generated offsets and duration"""
70100
self._frames[-1].commands.append(Wait(duration))
71-
self._frames[-1].has_duration = True
101+
self._frames[-1].keep = True
72102

73103
def play_arbitrary_waveform(self, waveform: Waveform):
74104
""""""
75105
self._frames[-1].commands.append(Wait(waveform.duration))
76-
self._frames[-1].has_duration = True
106+
self._frames[-1].keep = True
77107

78108
def measure(self, measurements: Optional[Sequence[MeasurementWindow]]):
79109
"""Unconditionally add given measurements relative to the current position."""
80110
if measurements:
81111
commands = self._frames[-1].commands
82112
commands.extend(Measure(*meas) for meas in measurements)
83-
self._frames[-1].has_duration = True
113+
self._frames[-1].keep = True
84114

85115
def with_repetition(self, repetition_count: RepetitionCount,
86116
measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']:
@@ -92,10 +122,11 @@ def with_repetition(self, repetition_count: RepetitionCount,
92122

93123
self._label_counter += 1
94124
label_idx = self._label_counter
95-
parent.commands.append(LoopLabel(idx=label_idx, runtime_name=None, count=RepetitionCount))
125+
parent.commands.append(LoopLabel(idx=label_idx, runtime_name=None, count=repetition_count))
96126
parent.commands.extend(new_commands)
97127
parent.commands.append(LoopJmp(idx=label_idx))
98128

129+
@contextlib.contextmanager
99130
def with_sequence(self,
100131
measurements: Optional[Sequence[MeasurementWindow]] = None) -> ContextManager['ProgramBuilder']:
101132
"""
@@ -112,6 +143,7 @@ def with_sequence(self,
112143
parent = self._frames[-1]
113144
parent.commands.extend(new_commands)
114145

146+
@contextlib.contextmanager
115147
def new_subprogram(self, global_transformation: 'Transformation' = None) -> ContextManager['ProgramBuilder']:
116148
"""Create a context managed program builder whose contents are translated into a single waveform upon exit if
117149
it is not empty."""
@@ -136,43 +168,16 @@ def time_reversed(self) -> ContextManager['ProgramBuilder']:
136168
self._frames.append(MeasurementFrame([], False))
137169
yield self
138170
frame = self._frames.pop()
139-
if not frame.has_duration:
171+
if not frame.keep:
140172
return
141173

142-
self._frames[-1].has_duration = True
174+
self._frames[-1].keep = True
143175
self._frames[-1].commands.extend(_reversed_commands(frame.commands))
144176

145177
def to_program(self) -> Optional[Program]:
146178
"""Further addition of new elements might fail after finalizing the program."""
147-
if self._frames[0].has_duration:
148-
return self._frames[0].commands
149-
150-
151-
@dataclass
152-
class LoopLabel:
153-
idx: int
154-
runtime_name: str | None
155-
count: RepetitionCount
156-
157-
158-
@dataclass
159-
class Measure:
160-
meas_id: MeasurementID
161-
delay: HardwareTime
162-
length: HardwareTime
163-
164-
165-
@dataclass
166-
class Wait:
167-
duration: HardwareTime
168-
169-
170-
@dataclass
171-
class LoopJmp:
172-
idx: int
173-
174-
175-
Command = Union[LoopLabel, LoopJmp, Wait, Measure]
179+
if self._frames[0].keep:
180+
return MeasurementInstructions(self._frames[0].commands)
176181

177182

178183
def _reversed_commands(cmds: Sequence[Command]) -> Sequence[Command]:
@@ -202,30 +207,26 @@ def _reversed_commands(cmds: Sequence[Command]) -> Sequence[Command]:
202207
return reversed_cmds
203208

204209

205-
def to_table(commands: Sequence[Command]) -> dict[str, numpy.ndarray]:
206-
time = TimeType(0)
207-
208-
memory = {}
209-
counts = [None]
210+
class MeasurementVM:
211+
"""A VM that is capable of executing the measurement commands"""
210212

211-
tables = {}
213+
def __init__(self, callback: Callable[[str, float, float], None]):
214+
self._time = TimeType(0)
215+
self._memory = {}
216+
self._counts = {}
217+
self._callback = callback
212218

213-
def eval_hardware_time(t: HardwareTime):
219+
def _eval_hardware_time(self, t: HardwareTime):
214220
if isinstance(t, SimpleExpression):
215221
value = t.base
216222
for (factor_name, factor_val) in t.offsets.items():
217-
count = counts[memory[factor_name]]
223+
count = self._counts[self._memory[factor_name]]
218224
value += factor_val * count
219225
return value
220226
else:
221227
return t
222228

223-
def execute(sequence: Sequence[Command]) -> int:
224-
nonlocal time
225-
nonlocal tables
226-
nonlocal memory
227-
nonlocal counts
228-
229+
def _execute_after_label(self, sequence: Sequence[Command]) -> int:
229230
skip = 0
230231
for idx, cmd in enumerate(sequence):
231232
if idx < skip:
@@ -234,23 +235,30 @@ def execute(sequence: Sequence[Command]) -> int:
234235
return idx
235236
elif isinstance(cmd, LoopLabel):
236237
if cmd.runtime_name:
237-
memory[cmd.runtime_name] = cmd.idx
238-
if cmd.idx == len(counts):
239-
counts.append(0)
240-
assert cmd.idx < len(counts)
238+
self._memory[cmd.runtime_name] = cmd.idx
241239

242240
for iter_val in range(cmd.count):
243-
counts[cmd.idx] = iter_val
244-
pos = execute(sequence[idx + 1:])
241+
self._counts[cmd.idx] = iter_val
242+
pos = self._execute_after_label(sequence[idx + 1:])
245243
skip = idx + pos + 2
244+
246245
elif isinstance(cmd, Measure):
247-
meas_time = float(eval_hardware_time(cmd.delay) + time)
248-
meas_len = float(eval_hardware_time(cmd.length))
249-
tables.setdefault(cmd.meas_id, []).append((meas_time, meas_len))
246+
meas_time = float(self._eval_hardware_time(cmd.delay) + self._time)
247+
meas_len = float(self._eval_hardware_time(cmd.length))
248+
self._callback(cmd.meas_id, meas_time, meas_len)
249+
250250
elif isinstance(cmd, Wait):
251-
time += eval_hardware_time(cmd.duration)
251+
self._time += self._eval_hardware_time(cmd.duration)
252+
253+
def execute(self, commands: Sequence[Command]):
254+
self._execute_after_label(commands)
255+
256+
257+
def to_table(commands: Sequence[Command]) -> dict[str, numpy.ndarray]:
258+
tables = {}
252259

253-
execute(commands)
260+
vm = MeasurementVM(lambda name, begin, length: tables.setdefault(name, []).append((begin, length)))
261+
vm.execute(commands)
254262
return {
255263
name: numpy.array(values) for name, values in tables.items()
256264
}

tests/program/measurement_test.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def setUp(self):
2626

2727
def test_commands(self):
2828
builder = MeasurementBuilder()
29-
commands = self.pulse_template.create_program(program_builder=builder)
30-
self.assertEqual(self.commands, commands)
29+
instructions = self.pulse_template.create_program(program_builder=builder)
30+
self.assertEqual(self.commands, instructions.commands)
3131

3232
def test_table(self):
3333
table = to_table(self.commands)
@@ -37,3 +37,24 @@ def test_table(self):
3737
np.testing.assert_array_equal(self.table_b, tab_b)
3838

3939

40+
class ComplexPulse(TestCase):
41+
def setUp(self):
42+
hold = ConstantPT(10 ** 6, {'a': 1}, measurements=[('A', 10, 100), ('B', '1 + ii * 2 + jj', '3 + ii + jj')])
43+
dyn_hold = ConstantPT('10 ** 6 - 4 * ii', {'a': 1}, measurements=[('A', 10, 100), ('B', '1 + ii * 2 + jj', '3 + ii + jj')])
44+
45+
self.pulse_template = SequencePT(
46+
hold.with_repetition(2).with_iteration('ii', 100).with_repetition(2).with_iteration('jj', 200),
47+
measurements=[('A', 1, 100)]
48+
).with_repetition(2)
49+
50+
self.commands = []
51+
52+
def test_commands(self):
53+
builder = MeasurementBuilder()
54+
commands = self.pulse_template.create_program(program_builder=builder)
55+
to_table(commands.commands)
56+
raise NotImplementedError("TODO")
57+
58+
def test_table(self):
59+
60+
raise NotImplementedError("TODO")

0 commit comments

Comments
 (0)