Skip to content

Commit 32abf99

Browse files
committed
Scan dispatches: correct handling of signed mitmot taps
Unlike MIT-SOT and SIT-SOT these can be positive or negative, depending on the order of differentiation
1 parent ebc0de0 commit 32abf99

File tree

6 files changed

+107
-45
lines changed

6 files changed

+107
-45
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def jax_args_to_inner_func_args(carry, x):
9090
chain.from_iterable(
9191
buffer[(i + np.array(taps))]
9292
for buffer, taps in zip(
93-
inner_mit_mot, info.mit_mot_in_slices, strict=True
93+
inner_mit_mot, info.normalized_mit_mot_in_slices, strict=True
9494
)
9595
)
9696
)
@@ -140,7 +140,10 @@ def inner_func_outs_to_jax_outs(
140140
new_mit_mot = [
141141
buffer.at[i + np.array(taps)].set(new_vals)
142142
for buffer, new_vals, taps in zip(
143-
old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True
143+
old_mit_mot,
144+
new_mit_mot_vals,
145+
info.normalized_mit_mot_out_slices,
146+
strict=True,
144147
)
145148
]
146149
# Discard oldest MIT-SOT and append newest value

pytensor/link/numba/dispatch/scan.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ def idx_to_str(
2727
idx_symbol: str = "i",
2828
allow_scalar=False,
2929
) -> str:
30-
if offset < 0:
31-
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
32-
elif offset > 0:
30+
assert offset >= 0
31+
if offset > 0:
3332
indices = f"{idx_symbol} + {offset}"
3433
else:
3534
indices = idx_symbol
@@ -226,33 +225,16 @@ def add_inner_in_expr(
226225
# storage array like a circular buffer, and that's why we need to track the
227226
# storage size along with the taps length/indexing offset.
228227
def add_output_storage_post_proc_stmt(
229-
outer_in_name: str, tap_sizes: tuple[int, ...], storage_size: str
228+
outer_in_name: str, max_offset: int, storage_size: str
230229
):
231-
tap_size = max(tap_sizes)
232-
233-
if op.info.as_while:
234-
# While loops need to truncate the output storage to a length given
235-
# by the number of iterations performed.
236-
output_storage_post_proc_stmts.append(
237-
dedent(
238-
f"""
239-
if i + {tap_size} < {storage_size}:
240-
{storage_size} = i + {tap_size}
241-
{outer_in_name} = {outer_in_name}[:{storage_size}]
242-
"""
243-
).strip()
244-
)
245-
246-
# Rotate the storage so that the last computed value is at the end of
247-
# the storage array.
230+
# Rotate the storage so that the last computed value is at the end of the storage array.
248231
# This is needed when the output storage array does not have a length
249232
# equal to the number of taps plus `n_steps`.
250-
# If the storage size only allows one entry, there's nothing to rotate
251233
output_storage_post_proc_stmts.append(
252234
dedent(
253235
f"""
254-
if 1 < {storage_size} < (i + {tap_size}):
255-
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
236+
if 1 < {storage_size} < (i + {max_offset}):
237+
{outer_in_name}_shift = (i + {max_offset}) % ({storage_size})
256238
if {outer_in_name}_shift > 0:
257239
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
258240
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
@@ -261,6 +243,18 @@ def add_output_storage_post_proc_stmt(
261243
).strip()
262244
)
263245

246+
if op.info.as_while:
247+
# While loops need to truncate the output storage to a length given
248+
# by the number of iterations performed.
249+
output_storage_post_proc_stmts.append(
250+
dedent(
251+
f"""
252+
elif {storage_size} > (i + {max_offset}):
253+
{outer_in_name} = {outer_in_name}[:i + {max_offset}]
254+
"""
255+
).strip()
256+
)
257+
264258
# Special in-loop statements that create (nit-sot) storage arrays after a
265259
# single iteration is performed. This is necessary because we don't know
266260
# the exact shapes of the storage arrays that need to be allocated until
@@ -288,12 +282,11 @@ def add_output_storage_post_proc_stmt(
288282
storage_size_name = f"{outer_in_name}_len"
289283
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
290284
input_taps = inner_in_names_to_input_taps[outer_in_name]
291-
tap_storage_size = -min(input_taps)
292-
assert tap_storage_size >= 0
285+
max_lookback_inp_tap = -min(0, min(input_taps))
286+
assert max_lookback_inp_tap >= 0
293287

294288
for in_tap in input_taps:
295-
tap_offset = in_tap + tap_storage_size
296-
assert tap_offset >= 0
289+
tap_offset = max_lookback_inp_tap + in_tap
297290
is_vector = outer_in_var.ndim == 1
298291
add_inner_in_expr(
299292
outer_in_name,
@@ -302,22 +295,25 @@ def add_output_storage_post_proc_stmt(
302295
vector_slice_opt=is_vector,
303296
)
304297

305-
output_taps = inner_in_names_to_output_taps.get(
306-
outer_in_name, [tap_storage_size]
307-
)
308-
inner_out_to_outer_in_stmts.extend(
309-
idx_to_str(
310-
storage_name,
311-
out_tap,
312-
size=storage_size_name,
313-
allow_scalar=True,
298+
output_taps = inner_in_names_to_output_taps.get(outer_in_name, [0])
299+
for out_tap in output_taps:
300+
tap_offset = max_lookback_inp_tap + out_tap
301+
assert tap_offset >= 0
302+
inner_out_to_outer_in_stmts.append(
303+
idx_to_str(
304+
storage_name,
305+
tap_offset,
306+
size=storage_size_name,
307+
allow_scalar=True,
308+
)
314309
)
315-
for out_tap in output_taps
316-
)
317310

318-
add_output_storage_post_proc_stmt(
319-
storage_name, output_taps, storage_size_name
320-
)
311+
if outer_in_name not in outer_in_mit_mot_names:
312+
# MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop
313+
max_offset_out_tap = max(output_taps) + max_lookback_inp_tap
314+
add_output_storage_post_proc_stmt(
315+
storage_name, max_offset_out_tap, storage_size_name
316+
)
321317

322318
else:
323319
storage_size_stmt = ""
@@ -351,7 +347,7 @@ def add_output_storage_post_proc_stmt(
351347
inner_out_to_outer_in_stmts.append(
352348
idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
353349
)
354-
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
350+
add_output_storage_post_proc_stmt(storage_name, 0, storage_size_name)
355351

356352
# In case of nit-sots we are provided the length of the array in
357353
# the iteration dimension instead of actual arrays, hence we

pytensor/scan/op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,26 @@ def n_outer_outputs(self):
288288
+ self.n_untraced_sit_sot_outs
289289
)
290290

291+
@property
292+
def normalized_mit_mot_in_slices(self) -> tuple[tuple[int, ...], ...]:
293+
"""Return mit_mot_in slices normalized as an offset from the oldest tap"""
294+
# TODO: Make this the canonical representation
295+
res = []
296+
for in_slice in self.mit_mot_in_slices:
297+
min_tap = -(min(0, min(in_slice)))
298+
res.append(tuple(tap + min_tap for tap in in_slice))
299+
return tuple(res)
300+
301+
@property
302+
def normalized_mit_mot_out_slices(self) -> tuple[tuple[int, ...], ...]:
303+
"""Return mit_mot_out slices normalized as an offset from the oldest tap"""
304+
# TODO: Make this the canonical representation
305+
res = []
306+
for out_slice in self.mit_mot_out_slices:
307+
min_tap = -(min(0, min(out_slice)))
308+
res.append(tuple(tap + min_tap for tap in out_slice))
309+
return tuple(res)
310+
291311

292312
TensorConstructorType = Callable[
293313
[Iterable[bool | int | None], str | np.generic], TensorType

tests/link/jax/test_scan.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytensor.tensor.math import gammaln, log
1616
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector
1717
from tests.link.jax.test_basic import compare_jax_and_py
18+
from tests.scan.test_basic import ScanCompatibilityTests
1819

1920

2021
jax = pytest.importorskip("jax")
@@ -626,3 +627,7 @@ def block_until_ready(*inputs, jax_fn=jax_fn):
626627
block_until_ready(*test_input_vals) # Warmup
627628

628629
benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1)
630+
631+
632+
def test_higher_order_derivatives():
633+
ScanCompatibilityTests.check_higher_order_derivative(mode="JAX")

tests/link/numba/test_scan.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.tensor.random.utils import RandomStream
1717
from tests import unittest_tools as utt
1818
from tests.link.numba.test_basic import compare_numba_and_py
19+
from tests.scan.test_basic import ScanCompatibilityTests
1920

2021

2122
@pytest.mark.parametrize(
@@ -652,3 +653,7 @@ def test_mit_sot_buffer(self, constant_n_steps, n_steps_val):
652653

653654
def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
654655
self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
656+
657+
658+
def test_higher_order_derivatives():
659+
ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA")

tests/scan/test_basic.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4082,6 +4082,9 @@ def test_grad_multiple_outs_some_disconnected_2(self):
40824082
# Also, the purpose of this test is not clear.
40834083
self._grad_mout_helper(1, None)
40844084

4085+
def test_higher_order_derivatives(self):
4086+
ScanCompatibilityTests.check_higher_order_derivative(mode=None)
4087+
40854088

40864089
@pytest.mark.parametrize(
40874090
"fn, sequences, outputs_info, non_sequences, n_steps, op_check",
@@ -4398,3 +4401,33 @@ def test_scan_mode_compatibility(scan_mode):
43984401

43994402
# Expected value computed by running correct Scan once
44004403
np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])
4404+
4405+
4406+
class ScanCompatibilityTests:
4407+
"""Collection of test of subtle required behaviors of Scan, that can be reused by different backends."""
4408+
4409+
@staticmethod
4410+
def check_higher_order_derivative(mode):
4411+
"""This tests different mit-mot taps signs"""
4412+
x = pt.dscalar("x")
4413+
4414+
# xs[-1] is equivalent to x ** 16
4415+
xs = scan(
4416+
fn=lambda xtm1: xtm1**2,
4417+
outputs_info=[x],
4418+
n_steps=4,
4419+
return_updates=False,
4420+
)
4421+
r = xs[-1]
4422+
g = grad(r, x)
4423+
gg = grad(g, x)
4424+
ggg = grad(gg, x)
4425+
4426+
fn = function([x], [r, g, gg, ggg], mode=mode)
4427+
x_test = np.array(0.95, dtype=x.type.dtype)
4428+
r_res, g_res, gg_res, _ggg_res = fn(x_test)
4429+
np.testing.assert_allclose(r_res, x_test**16)
4430+
np.testing.assert_allclose(g_res, 16 * x_test**15)
4431+
np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14)
4432+
# FIXME: All implementations of Scan seem to get this one wrong!
4433+
# np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)

0 commit comments

Comments
 (0)