Skip to content

Commit 0ab80e4

Browse files
committed
Scan dispatches: correct handling of signed MIT-MOT taps
Unlike MIT-SOT and SIT-SOT taps, MIT-MOT taps can be positive or negative, depending on the order of differentiation.
1 parent 42e8490 commit 0ab80e4

File tree

6 files changed

+107
-45
lines changed

6 files changed

+107
-45
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def jax_args_to_inner_func_args(carry, x):
9090
chain.from_iterable(
9191
buffer[(i + np.array(taps))]
9292
for buffer, taps in zip(
93-
inner_mit_mot, info.mit_mot_in_slices, strict=True
93+
inner_mit_mot, info.normalized_mit_mot_in_slices, strict=True
9494
)
9595
)
9696
)
@@ -140,7 +140,10 @@ def inner_func_outs_to_jax_outs(
140140
new_mit_mot = [
141141
buffer.at[i + np.array(taps)].set(new_vals)
142142
for buffer, new_vals, taps in zip(
143-
old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True
143+
old_mit_mot,
144+
new_mit_mot_vals,
145+
info.normalized_mit_mot_out_slices,
146+
strict=True,
144147
)
145148
]
146149
# Discard oldest MIT-SOT and append newest value

pytensor/link/numba/dispatch/scan.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ def idx_to_str(
2727
idx_symbol: str = "i",
2828
allow_scalar=False,
2929
) -> str:
30-
if offset < 0:
31-
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
32-
elif offset > 0:
30+
assert offset >= 0
31+
if offset > 0:
3332
indices = f"{idx_symbol} + {offset}"
3433
else:
3534
indices = idx_symbol
@@ -226,33 +225,16 @@ def add_inner_in_expr(
226225
# storage array like a circular buffer, and that's why we need to track the
227226
# storage size along with the taps length/indexing offset.
228227
def add_output_storage_post_proc_stmt(
229-
outer_in_name: str, tap_sizes: tuple[int, ...], storage_size: str
228+
outer_in_name: str, max_offset: int, storage_size: str
230229
):
231-
tap_size = max(tap_sizes)
232-
233-
if op.info.as_while:
234-
# While loops need to truncate the output storage to a length given
235-
# by the number of iterations performed.
236-
output_storage_post_proc_stmts.append(
237-
dedent(
238-
f"""
239-
if i + {tap_size} < {storage_size}:
240-
{storage_size} = i + {tap_size}
241-
{outer_in_name} = {outer_in_name}[:{storage_size}]
242-
"""
243-
).strip()
244-
)
245-
246-
# Rotate the storage so that the last computed value is at the end of
247-
# the storage array.
230+
# Rotate the storage so that the last computed value is at the end of the storage array.
248231
# This is needed when the output storage array does not have a length
249232
# equal to the number of taps plus `n_steps`.
250-
# If the storage size only allows one entry, there's nothing to rotate
251233
output_storage_post_proc_stmts.append(
252234
dedent(
253235
f"""
254-
if 1 < {storage_size} < (i + {tap_size}):
255-
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
236+
if 1 < {storage_size} < (i + {max_offset}):
237+
{outer_in_name}_shift = (i + {max_offset}) % ({storage_size})
256238
if {outer_in_name}_shift > 0:
257239
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
258240
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
@@ -261,6 +243,18 @@ def add_output_storage_post_proc_stmt(
261243
).strip()
262244
)
263245

246+
if op.info.as_while:
247+
# While loops need to truncate the output storage to a length given
248+
# by the number of iterations performed.
249+
output_storage_post_proc_stmts.append(
250+
dedent(
251+
f"""
252+
elif {storage_size} > (i + {max_offset}):
253+
{outer_in_name} = {outer_in_name}[:i + {max_offset}]
254+
"""
255+
).strip()
256+
)
257+
264258
# Special in-loop statements that create (nit-sot) storage arrays after a
265259
# single iteration is performed. This is necessary because we don't know
266260
# the exact shapes of the storage arrays that need to be allocated until
@@ -288,12 +282,11 @@ def add_output_storage_post_proc_stmt(
288282
storage_size_name = f"{outer_in_name}_len"
289283
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
290284
input_taps = inner_in_names_to_input_taps[outer_in_name]
291-
tap_storage_size = -min(input_taps)
292-
assert tap_storage_size >= 0
285+
max_lookback_inp_tap = -min(0, min(input_taps))
286+
assert max_lookback_inp_tap >= 0
293287

294288
for in_tap in input_taps:
295-
tap_offset = in_tap + tap_storage_size
296-
assert tap_offset >= 0
289+
tap_offset = max_lookback_inp_tap + in_tap
297290
is_vector = outer_in_var.ndim == 1
298291
add_inner_in_expr(
299292
outer_in_name,
@@ -302,22 +295,25 @@ def add_output_storage_post_proc_stmt(
302295
vector_slice_opt=is_vector,
303296
)
304297

305-
output_taps = inner_in_names_to_output_taps.get(
306-
outer_in_name, [tap_storage_size]
307-
)
308-
inner_out_to_outer_in_stmts.extend(
309-
idx_to_str(
310-
storage_name,
311-
out_tap,
312-
size=storage_size_name,
313-
allow_scalar=True,
298+
output_taps = inner_in_names_to_output_taps.get(outer_in_name, [0])
299+
for out_tap in output_taps:
300+
tap_offset = max_lookback_inp_tap + out_tap
301+
assert tap_offset >= 0
302+
inner_out_to_outer_in_stmts.append(
303+
idx_to_str(
304+
storage_name,
305+
tap_offset,
306+
size=storage_size_name,
307+
allow_scalar=True,
308+
)
314309
)
315-
for out_tap in output_taps
316-
)
317310

318-
add_output_storage_post_proc_stmt(
319-
storage_name, output_taps, storage_size_name
320-
)
311+
if outer_in_name not in outer_in_mit_mot_names:
312+
# MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop
313+
max_offset_out_tap = max(output_taps) + max_lookback_inp_tap
314+
add_output_storage_post_proc_stmt(
315+
storage_name, max_offset_out_tap, storage_size_name
316+
)
321317

322318
else:
323319
storage_size_stmt = ""
@@ -351,7 +347,7 @@ def add_output_storage_post_proc_stmt(
351347
inner_out_to_outer_in_stmts.append(
352348
idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
353349
)
354-
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
350+
add_output_storage_post_proc_stmt(storage_name, 0, storage_size_name)
355351

356352
# In case of nit-sots we are provided the length of the array in
357353
# the iteration dimension instead of actual arrays, hence we

pytensor/scan/op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,26 @@ def n_outer_outputs(self):
287287
+ self.n_untraced_sit_sot_outs
288288
)
289289

290+
@property
291+
def normalized_mit_mot_in_slices(self) -> tuple[tuple[int, ...], ...]:
292+
"""Return mit_mot_in slices normalized as an offset from the oldest tap"""
293+
# TODO: Make this the canonical representation
294+
res = []
295+
for in_slice in self.mit_mot_in_slices:
296+
min_tap = -(min(0, min(in_slice)))
297+
res.append(tuple(tap + min_tap for tap in in_slice))
298+
return tuple(res)
299+
300+
@property
301+
def normalized_mit_mot_out_slices(self) -> tuple[tuple[int, ...], ...]:
302+
"""Return mit_mot_out slices normalized as an offset from the oldest tap"""
303+
# TODO: Make this the canonical representation
304+
res = []
305+
for out_slice in self.mit_mot_out_slices:
306+
min_tap = -(min(0, min(out_slice)))
307+
res.append(tuple(tap + min_tap for tap in out_slice))
308+
return tuple(res)
309+
290310

291311
TensorConstructorType = Callable[
292312
[Iterable[bool | int | None], str | np.generic], TensorType

tests/link/jax/test_scan.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytensor.tensor.math import gammaln, log
1616
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector
1717
from tests.link.jax.test_basic import compare_jax_and_py
18+
from tests.scan.test_basic import ScanCompatibilityTests
1819

1920

2021
jax = pytest.importorskip("jax")
@@ -626,3 +627,7 @@ def block_until_ready(*inputs, jax_fn=jax_fn):
626627
block_until_ready(*test_input_vals) # Warmup
627628

628629
benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1)
630+
631+
632+
def test_higher_order_derivatives():
633+
ScanCompatibilityTests.check_higher_order_derivative(mode="JAX")

tests/link/numba/test_scan.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.tensor.random.utils import RandomStream
1717
from tests import unittest_tools as utt
1818
from tests.link.numba.test_basic import compare_numba_and_py
19+
from tests.scan.test_basic import ScanCompatibilityTests
1920

2021

2122
@pytest.mark.parametrize(
@@ -652,3 +653,7 @@ def test_mit_sot_buffer(self, constant_n_steps, n_steps_val):
652653

653654
def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
654655
self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
656+
657+
658+
def test_higher_order_derivatives():
659+
ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA")

tests/scan/test_basic.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4080,6 +4080,9 @@ def test_grad_multiple_outs_some_disconnected_2(self):
40804080
# Also, the purpose of this test is not clear.
40814081
self._grad_mout_helper(1, None)
40824082

4083+
def test_higher_order_derivatives(self):
4084+
ScanCompatibilityTests.check_higher_order_derivative(mode=None)
4085+
40834086

40844087
@pytest.mark.parametrize(
40854088
"fn, sequences, outputs_info, non_sequences, n_steps, op_check",
@@ -4308,3 +4311,33 @@ def test_return_updates_api_change():
43084311

43094312
with pytest.raises(ValueError, match=err_msg):
43104313
scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)
4314+
4315+
4316+
class ScanCompatibilityTests:
4317+
"""Collection of test of subtle required behaviors of Scan, that can be reused by different backends."""
4318+
4319+
@staticmethod
4320+
def check_higher_order_derivative(mode):
4321+
"""This tests different mit-mot taps signs"""
4322+
x = pt.scalar("x")
4323+
4324+
# xs[-1] is equivalent to x ** 16
4325+
xs = scan(
4326+
fn=lambda xtm1: xtm1**2,
4327+
outputs_info=[x],
4328+
n_steps=4,
4329+
return_updates=False,
4330+
)
4331+
r = xs[-1]
4332+
g = grad(r, x)
4333+
gg = grad(g, x)
4334+
ggg = grad(gg, x)
4335+
4336+
fn = function([x], [r, g, gg, ggg], mode=mode)
4337+
x_test = np.array(0.95, dtype=x.type.dtype)
4338+
r_res, g_res, gg_res, _ggg_res = fn(x_test)
4339+
np.testing.assert_allclose(r_res, x_test**16)
4340+
np.testing.assert_allclose(g_res, 16 * x_test**15)
4341+
np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14)
4342+
# FIXME: All implementations of Scan seem to get this one wrong!
4343+
# np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)

0 commit comments

Comments
 (0)