@@ -44,6 +44,19 @@ class LinSpaceNode:
4444 def dependencies (self ) -> Mapping [int , set ]:
4545 raise NotImplementedError
4646
47+ def reversed (self , offset : int , lengths : list ):
48+ """Get the time reversed version of this linspace node. Since this is a non-local operation the arguments give
49+ the context.
50+
51+ Args:
52+ offset: Active iterations that are not reserved
53+ lengths: Lengths of the currently active iterations that have to be reversed
54+
55+ Returns:
56+ Time reversed version.
57+ """
58+ raise NotImplementedError
59+
4760
4861@dataclass
4962class LinSpaceHold (LinSpaceNode ):
@@ -60,13 +73,46 @@ def dependencies(self) -> Mapping[int, set]:
6073 for idx , factors in enumerate (self .factors )
6174 if factors }
6275
76+ def reversed (self , offset : int , lengths : list ):
77+ if not lengths :
78+ return self
79+ # If the iteration length is `n`, the starting point is shifted by `n - 1`
80+ steps = [length - 1 for length in lengths ]
81+ bases = []
82+ factors = []
83+ for ch_base , ch_factors in zip (self .bases , self .factors ):
84+ if ch_factors is None or len (ch_factors ) <= offset :
85+ bases .append (ch_base )
86+ factors .append (ch_factors )
87+ else :
88+ ch_reverse_base = ch_base + sum (step * factor
89+ for factor , step in zip (ch_factors [offset :], steps ))
90+ reversed_factors = ch_factors [:offset ] + tuple (- f for f in ch_factors [offset :])
91+ bases .append (ch_reverse_base )
92+ factors .append (reversed_factors )
93+
94+ if self .duration_factors is None or len (self .duration_factors ) <= offset :
95+ duration_factors = self .duration_factors
96+ duration_base = self .duration_base
97+ else :
98+ duration_base = self .duration_base + sum ((step * factor
99+ for factor , step in zip (self .duration_factors [offset :], steps )), TimeType (0 ))
100+ duration_factors = self .duration_factors [:offset ] + tuple (- f for f in self .duration_factors [offset :])
101+ return LinSpaceHold (tuple (bases ), tuple (factors ), duration_base = duration_base , duration_factors = duration_factors )
102+
63103
64104@dataclass
65105class LinSpaceArbitraryWaveform (LinSpaceNode ):
66106 """This is just a wrapper to pipe arbitrary waveforms through the system."""
67107 waveform : Waveform
68108 channels : Tuple [ChannelID , ...]
69109
110+ def reversed (self , offset : int , lengths : list ):
111+ return LinSpaceArbitraryWaveform (
112+ waveform = self .waveform .reversed (),
113+ channels = self .channels ,
114+ )
115+
70116
71117@dataclass
72118class LinSpaceRepeat (LinSpaceNode ):
@@ -81,6 +127,9 @@ def dependencies(self):
81127 dependencies .setdefault (idx , set ()).update (deps )
82128 return dependencies
83129
130+ def reversed (self , offset : int , counts : list ):
131+ return LinSpaceRepeat (tuple (node .reversed (offset , counts ) for node in reversed (self .body )), self .count )
132+
84133
85134@dataclass
86135class LinSpaceIter (LinSpaceNode ):
@@ -100,6 +149,12 @@ def dependencies(self):
100149 dependencies .setdefault (idx , set ()).update (shortened )
101150 return dependencies
102151
152+ def reversed (self , offset : int , lengths : list ):
153+ lengths .append (self .length )
154+ reversed_iter = LinSpaceIter (tuple (node .reversed (offset , lengths ) for node in reversed (self .body )), self .length )
155+ lengths .pop ()
156+ return reversed_iter
157+
103158
104159class LinSpaceBuilder (ProgramBuilder ):
105160 """This program builder supports efficient translation of pulse templates that use symbolic linearly
@@ -214,6 +269,14 @@ def with_iteration(self, index_name: str, rng: range,
214269 if cmds :
215270 self ._stack [- 1 ].append (LinSpaceIter (body = tuple (cmds ), length = len (rng )))
216271
272+ @contextlib .contextmanager
273+ def time_reversed (self ) -> ContextManager ['LinSpaceBuilder' ]:
274+ self ._stack .append ([])
275+ yield self
276+ inner = self ._stack .pop ()
277+ offset = len (self ._ranges )
278+ self ._stack [- 1 ].extend (node .reversed (offset , []) for node in reversed (inner ))
279+
217280 def to_program (self ) -> Optional [Sequence [LinSpaceNode ]]:
218281 if self ._root ():
219282 return self ._root ()
@@ -414,8 +477,10 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Comman
414477
415478
416479class LinSpaceVM :
417- def __init__ (self , channels : int ):
480+ def __init__ (self , channels : int ,
481+ sample_resolution : TimeType = TimeType .from_fraction (1 , 2 )):
418482 self .current_values = [np .nan ] * channels
483+ self .sample_resolution = sample_resolution
419484 self .time = TimeType (0 )
420485 self .registers = tuple ({} for _ in range (channels ))
421486
@@ -428,7 +493,20 @@ def __init__(self, channels: int):
428493
429494 def change_state (self , cmd : Union [Set , Increment , Wait , Play ]):
430495 if isinstance (cmd , Play ):
431- raise NotImplementedError ("TODO: Implement arbitrary waveform simulation" )
496+ dt = self .sample_resolution
497+ t = TimeType (0 )
498+ total_duration = cmd .waveform .duration
499+ while t <= total_duration and dt > 0 :
500+ sample_time = np .array ([float (t )])
501+ values = []
502+ for (idx , ch ) in enumerate (cmd .channels ):
503+ self .current_values [idx ] = values .append (cmd .waveform .get_sampled (channel = ch , sample_times = sample_time )[0 ])
504+ self .history .append (
505+ (self .time , self .current_values .copy ())
506+ )
507+ dt = min (total_duration - t , self .sample_resolution )
508+ self .time += dt
509+ t += dt
432510 elif isinstance (cmd , Wait ):
433511 self .history .append (
434512 (self .time , self .current_values .copy ())
0 commit comments