@@ -27,9 +27,8 @@ def idx_to_str(
2727 idx_symbol : str = "i" ,
2828 allow_scalar = False ,
2929) -> str :
30- if offset < 0 :
31- indices = f"{ idx_symbol } + { array_name } .shape[0] - { offset } "
32- elif offset > 0 :
30+ assert offset >= 0
31+ if offset > 0 :
3332 indices = f"{ idx_symbol } + { offset } "
3433 else :
3534 indices = idx_symbol
@@ -226,33 +225,16 @@ def add_inner_in_expr(
226225 # storage array like a circular buffer, and that's why we need to track the
227226 # storage size along with the taps length/indexing offset.
228227 def add_output_storage_post_proc_stmt (
229- outer_in_name : str , tap_sizes : tuple [ int , ...] , storage_size : str
228+ outer_in_name : str , max_offset : int , storage_size : str
230229 ):
231- tap_size = max (tap_sizes )
232-
233- if op .info .as_while :
234- # While loops need to truncate the output storage to a length given
235- # by the number of iterations performed.
236- output_storage_post_proc_stmts .append (
237- dedent (
238- f"""
239- if i + { tap_size } < { storage_size } :
240- { storage_size } = i + { tap_size }
241- { outer_in_name } = { outer_in_name } [:{ storage_size } ]
242- """
243- ).strip ()
244- )
245-
246- # Rotate the storage so that the last computed value is at the end of
247- # the storage array.
230+ # Rotate the storage so that the last computed value is at the end of the storage array.
248231 # This is needed when the output storage array does not have a length
249232 # equal to the number of taps plus `n_steps`.
250- # If the storage size only allows one entry, there's nothing to rotate
251233 output_storage_post_proc_stmts .append (
252234 dedent (
253235 f"""
254- if 1 < { storage_size } < (i + { tap_size } ):
255- { outer_in_name } _shift = (i + { tap_size } ) % ({ storage_size } )
236+ if 1 < { storage_size } < (i + { max_offset } ):
237+ { outer_in_name } _shift = (i + { max_offset } ) % ({ storage_size } )
256238 if { outer_in_name } _shift > 0:
257239 { outer_in_name } _left = { outer_in_name } [:{ outer_in_name } _shift]
258240 { outer_in_name } _right = { outer_in_name } [{ outer_in_name } _shift:]
@@ -261,6 +243,18 @@ def add_output_storage_post_proc_stmt(
261243 ).strip ()
262244 )
263245
246+ if op .info .as_while :
247+ # While loops need to truncate the output storage to a length given
248+ # by the number of iterations performed.
249+ output_storage_post_proc_stmts .append (
250+ dedent (
251+ f"""
252+ elif { storage_size } > (i + { max_offset } ):
253+ { outer_in_name } = { outer_in_name } [:i + { max_offset } ]
254+ """
255+ ).strip ()
256+ )
257+
264258 # Special in-loop statements that create (nit-sot) storage arrays after a
265259 # single iteration is performed. This is necessary because we don't know
266260 # the exact shapes of the storage arrays that need to be allocated until
@@ -288,12 +282,11 @@ def add_output_storage_post_proc_stmt(
288282 storage_size_name = f"{ outer_in_name } _len"
289283 storage_size_stmt = f"{ storage_size_name } = { outer_in_name } .shape[0]"
290284 input_taps = inner_in_names_to_input_taps [outer_in_name ]
291- tap_storage_size = - min (input_taps )
292- assert tap_storage_size >= 0
285+ max_lookback_inp_tap = - min (0 , min ( input_taps ) )
286+ assert max_lookback_inp_tap >= 0
293287
294288 for in_tap in input_taps :
295- tap_offset = in_tap + tap_storage_size
296- assert tap_offset >= 0
289+ tap_offset = max_lookback_inp_tap + in_tap
297290 is_vector = outer_in_var .ndim == 1
298291 add_inner_in_expr (
299292 outer_in_name ,
@@ -302,22 +295,25 @@ def add_output_storage_post_proc_stmt(
302295 vector_slice_opt = is_vector ,
303296 )
304297
305- output_taps = inner_in_names_to_output_taps .get (
306- outer_in_name , [tap_storage_size ]
307- )
308- inner_out_to_outer_in_stmts .extend (
309- idx_to_str (
310- storage_name ,
311- out_tap ,
312- size = storage_size_name ,
313- allow_scalar = True ,
298+ output_taps = inner_in_names_to_output_taps .get (outer_in_name , [0 ])
299+ for out_tap in output_taps :
300+ tap_offset = max_lookback_inp_tap + out_tap
301+ assert tap_offset >= 0
302+ inner_out_to_outer_in_stmts .append (
303+ idx_to_str (
304+ storage_name ,
305+ tap_offset ,
306+ size = storage_size_name ,
307+ allow_scalar = True ,
308+ )
314309 )
315- for out_tap in output_taps
316- )
317310
318- add_output_storage_post_proc_stmt (
319- storage_name , output_taps , storage_size_name
320- )
311+ if outer_in_name not in outer_in_mit_mot_names :
312+ # MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop
313+ max_offset_out_tap = max (output_taps ) + max_lookback_inp_tap
314+ add_output_storage_post_proc_stmt (
315+ storage_name , max_offset_out_tap , storage_size_name
316+ )
321317
322318 else :
323319 storage_size_stmt = ""
@@ -351,7 +347,7 @@ def add_output_storage_post_proc_stmt(
351347 inner_out_to_outer_in_stmts .append (
352348 idx_to_str (storage_name , 0 , size = storage_size_name , allow_scalar = True )
353349 )
354- add_output_storage_post_proc_stmt (storage_name , ( 0 ,) , storage_size_name )
350+ add_output_storage_post_proc_stmt (storage_name , 0 , storage_size_name )
355351
356352 # In case of nit-sots we are provided the length of the array in
357353 # the iteration dimension instead of actual arrays, hence we
0 commit comments