Skip to content

Commit 21469b7

Browse files
authored
Merge pull request #866 from qutech/issues/865_time_reversal_program_builder
Add test and implementation for TimeReversalPulseTemplate ProgramBuilder support
2 parents 6243493 + d7d8b33 commit 21469b7

File tree

6 files changed

+125
-9
lines changed

6 files changed

+125
-9
lines changed

qupulse/expressions/sympy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def _try_to_numeric(self) -> Optional[numbers.Number]:
444444
return None
445445
if isinstance(self._original_expression, ALLOWED_NUMERIC_SCALAR_TYPES):
446446
return self._original_expression
447-
expr = self._sympified_expression
447+
expr = self._sympified_expression.doit()
448448
if isinstance(expr, bool):
449449
# sympify can return bool
450450
return expr

qupulse/program/linspace.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@ class LinSpaceNode:
4444
def dependencies(self) -> Mapping[int, set]:
4545
raise NotImplementedError
4646

47+
def reversed(self, offset: int, lengths: list):
48+
"""Get the time reversed version of this linspace node. Since this is a non-local operation the arguments give
49+
the context.
50+
51+
Args:
52+
offset: Active iterations that are not reserved
53+
lengths: Lengths of the currently active iterations that have to be reversed
54+
55+
Returns:
56+
Time reversed version.
57+
"""
58+
raise NotImplementedError
59+
4760

4861
@dataclass
4962
class LinSpaceHold(LinSpaceNode):
@@ -60,13 +73,46 @@ def dependencies(self) -> Mapping[int, set]:
6073
for idx, factors in enumerate(self.factors)
6174
if factors}
6275

76+
def reversed(self, offset: int, lengths: list):
77+
if not lengths:
78+
return self
79+
# If the iteration length is `n`, the starting point is shifted by `n - 1`
80+
steps = [length - 1 for length in lengths]
81+
bases = []
82+
factors = []
83+
for ch_base, ch_factors in zip(self.bases, self.factors):
84+
if ch_factors is None or len(ch_factors) <= offset:
85+
bases.append(ch_base)
86+
factors.append(ch_factors)
87+
else:
88+
ch_reverse_base = ch_base + sum(step * factor
89+
for factor, step in zip(ch_factors[offset:], steps))
90+
reversed_factors = ch_factors[:offset] + tuple(-f for f in ch_factors[offset:])
91+
bases.append(ch_reverse_base)
92+
factors.append(reversed_factors)
93+
94+
if self.duration_factors is None or len(self.duration_factors) <= offset:
95+
duration_factors = self.duration_factors
96+
duration_base = self.duration_base
97+
else:
98+
duration_base = self.duration_base + sum((step * factor
99+
for factor, step in zip(self.duration_factors[offset:], steps)), TimeType(0))
100+
duration_factors = self.duration_factors[:offset] + tuple(-f for f in self.duration_factors[offset:])
101+
return LinSpaceHold(tuple(bases), tuple(factors), duration_base=duration_base, duration_factors=duration_factors)
102+
63103

64104
@dataclass
65105
class LinSpaceArbitraryWaveform(LinSpaceNode):
66106
"""This is just a wrapper to pipe arbitrary waveforms through the system."""
67107
waveform: Waveform
68108
channels: Tuple[ChannelID, ...]
69109

110+
def reversed(self, offset: int, lengths: list):
111+
return LinSpaceArbitraryWaveform(
112+
waveform=self.waveform.reversed(),
113+
channels=self.channels,
114+
)
115+
70116

71117
@dataclass
72118
class LinSpaceRepeat(LinSpaceNode):
@@ -81,6 +127,9 @@ def dependencies(self):
81127
dependencies.setdefault(idx, set()).update(deps)
82128
return dependencies
83129

130+
def reversed(self, offset: int, counts: list):
131+
return LinSpaceRepeat(tuple(node.reversed(offset, counts) for node in reversed(self.body)), self.count)
132+
84133

85134
@dataclass
86135
class LinSpaceIter(LinSpaceNode):
@@ -100,6 +149,12 @@ def dependencies(self):
100149
dependencies.setdefault(idx, set()).update(shortened)
101150
return dependencies
102151

152+
def reversed(self, offset: int, lengths: list):
153+
lengths.append(self.length)
154+
reversed_iter = LinSpaceIter(tuple(node.reversed(offset, lengths) for node in reversed(self.body)), self.length)
155+
lengths.pop()
156+
return reversed_iter
157+
103158

104159
class LinSpaceBuilder(ProgramBuilder):
105160
"""This program builder supports efficient translation of pulse templates that use symbolic linearly
@@ -214,6 +269,14 @@ def with_iteration(self, index_name: str, rng: range,
214269
if cmds:
215270
self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng)))
216271

272+
@contextlib.contextmanager
273+
def time_reversed(self) -> ContextManager['LinSpaceBuilder']:
274+
self._stack.append([])
275+
yield self
276+
inner = self._stack.pop()
277+
offset = len(self._ranges)
278+
self._stack[-1].extend(node.reversed(offset, []) for node in reversed(inner))
279+
217280
def to_program(self) -> Optional[Sequence[LinSpaceNode]]:
218281
if self._root():
219282
return self._root()
@@ -414,8 +477,10 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Comman
414477

415478

416479
class LinSpaceVM:
417-
def __init__(self, channels: int):
480+
def __init__(self, channels: int,
481+
sample_resolution: TimeType = TimeType.from_fraction(1, 2)):
418482
self.current_values = [np.nan] * channels
483+
self.sample_resolution = sample_resolution
419484
self.time = TimeType(0)
420485
self.registers = tuple({} for _ in range(channels))
421486

@@ -428,7 +493,20 @@ def __init__(self, channels: int):
428493

429494
def change_state(self, cmd: Union[Set, Increment, Wait, Play]):
430495
if isinstance(cmd, Play):
431-
raise NotImplementedError("TODO: Implement arbitrary waveform simulation")
496+
dt = self.sample_resolution
497+
t = TimeType(0)
498+
total_duration = cmd.waveform.duration
499+
while t <= total_duration and dt > 0:
500+
sample_time = np.array([float(t)])
501+
values = []
502+
for (idx, ch) in enumerate(cmd.channels):
503+
self.current_values[idx] = values.append(cmd.waveform.get_sampled(channel=ch, sample_times=sample_time)[0])
504+
self.history.append(
505+
(self.time, self.current_values.copy())
506+
)
507+
dt = min(total_duration - t, self.sample_resolution)
508+
self.time += dt
509+
t += dt
432510
elif isinstance(cmd, Wait):
433511
self.history.append(
434512
(self.time, self.current_values.copy())

qupulse/program/waveforms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,3 +1277,6 @@ def compare_key(self) -> Hashable:
12771277

12781278
def reversed(self) -> 'Waveform':
12791279
return self._inner
1280+
1281+
def __repr__(self):
1282+
return f"ReversedWaveform(inner={self._inner!r})"

tests/expressions/expression_tests.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,10 @@ def test_special_function_numeric_evaluation(self):
428428

429429
np.testing.assert_allclose(expected, result)
430430

431+
def test_try_to_numeric(self):
432+
expr = ExpressionScalar('Sum(9, (x, 0, 5), (y, 0, 7))')
433+
self.assertEqual(expr._try_to_numeric(), 9*6*8)
434+
431435
def test_evaluate_with_exact_rationals(self):
432436
expr = ExpressionScalar('1 / 3')
433437
self.assertEqual(TimeType.from_fraction(1, 3), expr.evaluate_with_exact_rationals({}))

tests/program/linspace_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def assert_vm_output_almost_equal(test: TestCase, expected, actual):
1414
test.assertEqual(t_e, t_a, f"Differing times in {idx} element")
1515
test.assertEqual(len(vals_e), len(vals_a), f"Differing channel count in {idx} element")
1616
for ch, (val_e, val_a) in enumerate(zip(vals_e, vals_a)):
17-
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} element channel {ch}")
17+
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} of {len(expected)} element channel {ch}")
1818

1919

2020
class SingleRampTest(TestCase):

tests/pulses/time_reversal_pulse_template_tests.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from qupulse.pulses.time_reversal_pulse_template import TimeReversalPulseTemplate
88
from qupulse.utils.types import TimeType
99
from qupulse.expressions import ExpressionScalar
10-
10+
from qupulse.program.loop import LoopBuilder
11+
from qupulse.program.linspace import LinSpaceBuilder, LinSpaceVM, to_increment_commands
1112
from tests.pulses.sequencing_dummies import DummyPulseTemplate
1213
from tests.serialization_tests import SerializableTests
13-
14+
from tests.program.linspace_tests import assert_vm_output_almost_equal
1415

1516
class TimeReversalPulseTemplateTests(unittest.TestCase):
1617
def test_simple_properties(self):
@@ -29,19 +30,49 @@ def test_simple_properties(self):
2930

3031
self.assertEqual(reversed_pt.identifier, 'reverse')
3132

32-
def test_time_reversal_program(self):
33+
def test_time_reversal_loop(self):
3334
inner = ConstantPT(4, {'a': 3}) @ FunctionPT('sin(t)', 5, channel='a')
3435
manual_reverse = FunctionPT('sin(5 - t)', 5, channel='a') @ ConstantPT(4, {'a': 3})
3536
time_reversed = TimeReversalPulseTemplate(inner)
3637

37-
program = time_reversed.create_program()
38-
manual_program = manual_reverse.create_program()
38+
program = time_reversed.create_program(program_builder=LoopBuilder())
39+
manual_program = manual_reverse.create_program(program_builder=LoopBuilder())
3940

4041
t, data, _ = render(program, 9 / 10)
4142
_, manual_data, _ = render(manual_program, 9 / 10)
4243

4344
np.testing.assert_allclose(data['a'], manual_data['a'])
4445

46+
def test_time_reversal_linspace(self):
47+
constant_pt = ConstantPT(4, {'a': '3.0 + x * 1.0 + y * -0.3'})
48+
function_pt = FunctionPT('sin(t)', 5, channel='a')
49+
reversed_function_pt = function_pt.with_time_reversal()
50+
51+
inner = (constant_pt @ function_pt).with_iteration('x', 6)
52+
inner_manual = (reversed_function_pt @ constant_pt).with_iteration('x', (5, -1, -1))
53+
54+
outer = inner.with_time_reversal().with_iteration('y', 8)
55+
outer_man = inner_manual.with_iteration('y', 8)
56+
57+
self.assertEqual(outer.duration, outer_man.duration)
58+
59+
program = outer.create_program(program_builder=LinSpaceBuilder(channels=('a',)))
60+
manual_program = outer_man.create_program(program_builder=LinSpaceBuilder(channels=('a',)))
61+
62+
commands = to_increment_commands(program)
63+
manual_commands = to_increment_commands(manual_program)
64+
self.assertEqual(commands, manual_commands)
65+
66+
manual_vm = LinSpaceVM(1)
67+
manual_vm.set_commands(manual_commands)
68+
manual_vm.run()
69+
70+
vm = LinSpaceVM(1)
71+
vm.set_commands(commands)
72+
vm.run()
73+
74+
assert_vm_output_almost_equal(self, manual_vm.history, vm.history)
75+
4576

4677
class TimeReversalPulseTemplateSerializationTests(unittest.TestCase, SerializableTests):
4778
@property

0 commit comments

Comments
 (0)