From e6cc2006458288abe37cbe06102c960cda71fd48 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 11 Oct 2025 17:54:40 +0200 Subject: [PATCH 1/5] Fix string error on SymbolicInputVariables with updates PyTensor Variables cannot be called `bool` upon --- pytensor/compile/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/compile/io.py b/pytensor/compile/io.py index 9ce0421235..ff3531c343 100644 --- a/pytensor/compile/io.py +++ b/pytensor/compile/io.py @@ -95,7 +95,7 @@ def __init__( self.implicit = implicit def __str__(self): - if self.update: + if self.update is not None: return f"In({self.variable} -> {self.update})" else: return f"In({self.variable})" From 5d8377288a0b49946f44520f4f401bf9ea44d24b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 13 Oct 2025 12:05:56 +0200 Subject: [PATCH 2/5] Benchmark scan in JAX backend Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- tests/link/jax/test_scan.py | 331 ++++++++++++++++++++++++++---------- 1 file changed, 239 insertions(+), 92 deletions(-) diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index 4ee95ab527..25b0854096 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -4,7 +4,7 @@ import pytest import pytensor.tensor as pt -from pytensor import function, shared +from pytensor import function, ifelse, shared from pytensor.compile import get_mode from pytensor.configdefaults import config from pytensor.scan import until @@ -12,7 +12,7 @@ from pytensor.scan.op import Scan from pytensor.tensor import random from pytensor.tensor.math import gammaln, log -from pytensor.tensor.type import dmatrix, dvector, lscalar, matrix, scalar, vector +from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector from tests.link.jax.test_basic import compare_jax_and_py @@ -189,96 +189,6 @@ def test_scan_while(): compare_jax_and_py([], [xs], []) -def test_scan_SEIR(): - """Test a scan implementation of a SEIR model. - - SEIR model definition: - S[t+1] = S[t] - B[t] - E[t+1] = E[t] +B[t] - C[t] - I[t+1] = I[t+1] + C[t] - D[t] - - B[t] ~ Binom(S[t], beta) - C[t] ~ Binom(E[t], gamma) - D[t] ~ Binom(I[t], delta) - """ - - def binomln(n, k): - return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1) - - def binom_log_prob(n, p, value): - return binomln(n, value) + value * log(p) + (n - value) * log(1 - p) - - # sequences - at_C = vector("C_t", dtype="int32", shape=(8,)) - at_D = vector("D_t", dtype="int32", shape=(8,)) - # outputs_info (initial conditions) - st0 = lscalar("s_t0") - et0 = lscalar("e_t0") - it0 = lscalar("i_t0") - logp_c = scalar("logp_c") - logp_d = scalar("logp_d") - # non_sequences - beta = scalar("beta") - gamma = scalar("gamma") - delta = scalar("delta") - - # TODO: Use random streams when their JAX conversions are implemented. 
- # trng = pytensor.tensor.random.RandomStream(1234) - - def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): - # bt0 = trng.binomial(n=st0, p=beta) - bt0 = st0 * beta - bt0 = bt0.astype(st0.dtype) - - logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype) - logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype) - - st1 = st0 - bt0 - et1 = et0 + bt0 - ct0 - it1 = it0 + ct0 - dt0 - return st1, et1, it1, logp_c1, logp_d1 - - (st, et, it, logp_c_all, logp_d_all), _ = scan( - fn=seir_one_step, - sequences=[at_C, at_D], - outputs_info=[st0, et0, it0, logp_c, logp_d], - non_sequences=[beta, gamma, delta], - ) - st.name = "S_t" - et.name = "E_t" - it.name = "I_t" - logp_c_all.name = "C_t_logp" - logp_d_all.name = "D_t_logp" - - s0, e0, i0 = 100, 50, 25 - logp_c0 = np.array(0.0, dtype=config.floatX) - logp_d0 = np.array(0.0, dtype=config.floatX) - beta_val, gamma_val, delta_val = ( - np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753] - ) - C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32) - D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32) - - test_input_vals = [ - C, - D, - s0, - e0, - i0, - logp_c0, - logp_d0, - beta_val, - gamma_val, - delta_val, - ] - compare_jax_and_py( - [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], - [st, et, it, logp_c_all, logp_d_all], - test_input_vals, - jax_mode="JAX", - ) - - def test_scan_mitsot_with_nonseq(): a_pt = scalar("a") @@ -420,3 +330,240 @@ def test_dynamic_sequence_length(): assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1 np.testing.assert_allclose(f([]), []) np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4])) + + +def SEIR_model_logp(): + """Setup a Scan implementation of a SEIR model. 
+ + SEIR model definition: + S[t+1] = S[t] - B[t] + E[t+1] = E[t] +B[t] - C[t] + I[t+1] = I[t+1] + C[t] - D[t] + + B[t] ~ Binom(S[t], beta) + C[t] ~ Binom(E[t], gamma) + D[t] ~ Binom(I[t], delta) + """ + + def binomln(n, k): + return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1) + + def binom_log_prob(n, p, value): + return binomln(n, value) + value * log(p) + (n - value) * log(1 - p) + + # sequences + C_t = vector("C_t", dtype="int32", shape=(1200,)) + D_t = vector("D_t", dtype="int32", shape=(1200,)) + # outputs_info (initial conditions) + st0 = scalar("s_t0") + et0 = scalar("e_t0") + it0 = scalar("i_t0") + # non_sequences + beta = scalar("beta") + gamma = scalar("gamma") + delta = scalar("delta") + + def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta): + # bt0 = trng.binomial(n=st0, p=beta) + bt0 = st0 * beta + bt0 = bt0.astype(st0.dtype) + + logp_c1 = binom_log_prob(et0, gamma, ct0) + logp_d1 = binom_log_prob(it0, delta, dt0) + + st1 = st0 - bt0 + et1 = et0 + bt0 - ct0 + it1 = it0 + ct0 - dt0 + return st1, et1, it1, logp_c1, logp_d1 + + (st, et, it, logp_c_all, logp_d_all), _ = scan( + fn=seir_one_step, + sequences=[C_t, D_t], + outputs_info=[st0, et0, it0, None, None], + non_sequences=[beta, gamma, delta], + ) + st.name = "S_t" + et.name = "E_t" + it.name = "I_t" + logp_c_all.name = "C_t_logp" + logp_d_all.name = "D_t_logp" + + st0_val, et0_val, it0_val = np.array(100.0), np.array(50.0), np.array(25.0) + beta_val, gamma_val, delta_val = ( + np.array(0.277792), + np.array(0.135330), + np.array(0.108753), + ) + C_t_val = np.array([3, 5, 8, 13, 21, 26, 10, 3] * 150, dtype=np.int32) + D_t_val = np.array([1, 2, 3, 7, 9, 11, 5, 1] * 150, dtype=np.int32) + assert C_t_val.shape == D_t_val.shape == C_t.type.shape == D_t.type.shape + + test_input_vals = [ + C_t_val, + D_t_val, + st0_val, + et0_val, + it0_val, + beta_val, + gamma_val, + delta_val, + ] + + loss_graph = logp_c_all.sum() + logp_d_all.sum() + + return dict( + graph_inputs=[C_t, D_t, st0, et0, it0, beta, gamma, delta], + differentiable_vars=[st0, et0, it0, beta, gamma, delta], + test_input_vals=test_input_vals, + loss_graph=loss_graph, + ) + + +def cyclical_reduction(): + """Setup a Scan implementation of the cyclical reduction algorithm. 
+ + This solves the matrix equation A @ X @ X + B @ X + C = 0 for X + + Adapted from https://github.com/jessegrabowski/gEconpy/blob/da495b22ac383cb6cb5dec15f305506aebef7302/gEconpy/solvers/cycle_reduction.py#L187 + """ + + def stabilize(x, jitter=1e-16): + return x + jitter * pt.eye(x.shape[0]) + + def step(A0, A1, A2, A1_hat, norm, step_num, tol): + def cycle_step(A0, A1, A2, A1_hat, _norm, step_num): + tmp = pt.dot( + pt.vertical_stack(A0, A2), + pt.linalg.solve( + stabilize(A1), + pt.horizontal_stack(A0, A2), + assume_a="gen", + check_finite=False, + ), + ) + + n = A0.shape[0] + idx_0 = pt.arange(n) + idx_1 = idx_0 + n + A1 = A1 - tmp[idx_0, :][:, idx_1] - tmp[idx_1, :][:, idx_0] + A0 = -tmp[idx_0, :][:, idx_0] + A2 = -tmp[idx_1, :][:, idx_1] + A1_hat = A1_hat - tmp[idx_1, :][:, idx_0] + + A0_L1_norm = pt.linalg.norm(A0, ord=1) + + return A0, A1, A2, A1_hat, A0_L1_norm, step_num + 1 + + return ifelse( + norm < tol, + (A0, A1, A2, A1_hat, norm, step_num), + cycle_step(A0, A1, A2, A1_hat, norm, step_num), + ) + + A = pt.matrix("A", shape=(20, 20)) + B = pt.matrix("B", shape=(20, 20)) + C = pt.matrix("C", shape=(20, 20)) + + norm = np.array(1e9, dtype="float64") + step_num = pt.zeros((), dtype="int32") + max_iter = 100 + tol = 1e-7 + + (*_, A1_hat, norm, _n_steps), _ = scan( + step, + outputs_info=[A, B, C, B, norm, step_num], + non_sequences=[tol], + n_steps=max_iter, + ) + A1_hat = A1_hat[-1] + + T = -pt.linalg.solve(stabilize(A1_hat), A, assume_a="gen", check_finite=False) + + rng = np.random.default_rng(sum(map(ord, "cycle_reduction"))) + n = A.type.shape[0] + A_test = rng.standard_normal(size=(n, n)) + C_test = rng.standard_normal(size=(n, n)) + # B must be invertible, so we make it symmetric positive-definite + B_rand = rng.standard_normal(size=(n, n)) + B_test = B_rand @ B_rand.T + np.eye(n) * 1e-3 + + return dict( + graph_inputs=[A, B, C], + differentiable_vars=[A, B, C], + test_input_vals=[A_test, B_test, C_test], + loss_graph=pt.sum(T), + ) + + +@pytest.mark.parametrize("gradient_backend", ["PYTENSOR", "JAX"]) +@pytest.mark.parametrize("mode", ("0forward", "1backward", "2both")) +@pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp]) +def test_scan_benchmark(model, mode, gradient_backend, benchmark): + if gradient_backend == "PYTENSOR" and mode in ("1backward", "2both"): + pytest.skip("PYTENSOR backend does not support backward mode yet") + + model_dict = model() + graph_inputs = model_dict["graph_inputs"] + differentiable_vars = model_dict["differentiable_vars"] + loss_graph = model_dict["loss_graph"] + test_input_vals = model_dict["test_input_vals"] + + if gradient_backend == "PYTENSOR": + backward_loss = pt.grad( + loss_graph, + wrt=differentiable_vars, + ) + + match mode: + # TODO: Restore original test separately + case "0forward": + graph_outputs = [loss_graph] + case "1backward": + graph_outputs = backward_loss + case "2both": + graph_outputs = [loss_graph, *backward_loss] + case _: + raise ValueError(f"Unknown mode: {mode}") + + jax_fn, _ = compare_jax_and_py( + graph_inputs, + graph_outputs, + test_input_vals, + jax_mode="JAX", + ) + jax_fn.trust_input = True + + else: # gradient_backend == "JAX" + import jax + + loss_fn_tuple = function(graph_inputs, loss_graph, mode="JAX").vm.jit_fn + + def loss_fn(*args): + return loss_fn_tuple(*args)[0] + + match mode: + case "0forward": + jax_fn = jax.jit(loss_fn_tuple) + case "1backward": + jax_fn = jax.jit( + jax.grad(loss_fn, argnums=tuple(range(len(graph_inputs))[2:])) + ) + case "2both": + value_and_grad_fn = 
jax.value_and_grad( + loss_fn, argnums=tuple(range(len(graph_inputs))[2:]) + ) + + @jax.jit + def jax_fn(*args): + loss, grads = value_and_grad_fn(*args) + return loss, *grads + + case _: + raise ValueError(f"Unknown mode: {mode}") + + def block_until_ready(*inputs, jax_fn=jax_fn): + return [o.block_until_ready() for o in jax_fn(*inputs)] + + block_until_ready(*test_input_vals) # Warmup + + benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1) From 7b3cf03a89ecab5ac379b99ef0878e6c054b9a06 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 10 Oct 2025 18:23:39 +0200 Subject: [PATCH 3/5] Reimplement JAX Scan dispatcher with MIT-MOT support Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pytensor/link/jax/dispatch/scan.py | 429 +++++++++++++++++------------ pytensor/scan/op.py | 11 + tests/link/jax/test_scan.py | 65 ++++- 3 files changed, 320 insertions(+), 185 deletions(-) diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index 3c3080765c..ad70bdc36d 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -1,202 +1,289 @@ -import jax +from itertools import chain + import jax.numpy as jnp +import numpy as np +from jax._src.lax.control_flow import fori_loop from pytensor.compile.mode import JAX, get_mode from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.scan.op import Scan +def call_inner_func_with_indexed_buffers( + info, + scan_inner_func, + i, + sequences, + mit_mot_buffers, + mit_sot_buffers, + sit_sot_buffers, + shareds, + non_sequences, +): + sequence_vals = [seq[i] for seq in sequences] + + # chain.from_iterable is used flatten the first dimension of each indexed buffer + # [buf1[[idx0, idx1]], buf2[[idx0, idx1]]] -> [buf1[idx0], buf1[idx1], buf2[idx0], buf2[idx1]] + # Benchmarking suggests unpacking advanced indexing on all taps is faster than basic index one tap at a time + mit_mot_vals = list( + chain.from_iterable( + buffer[(i + np.array(in_taps))] + for buffer, in_taps in zip( + mit_mot_buffers, info.mit_mot_in_slices, strict=True + ) + ) + ) + mit_sot_vals = list( + chain.from_iterable( + # Convert negative taps (-2, -1) to positive indices (0, 1) + buffer[((i + (np.array(in_taps) - min(in_taps))) % buffer.shape[0])] + for buffer, in_taps in zip( + mit_sot_buffers, info.mit_sot_in_slices, strict=True + ) + ) + ) + sit_sot_vals = [buffer[i % buffer.shape[0]] for buffer in sit_sot_buffers] + + return scan_inner_func( + *sequence_vals, + *mit_mot_vals, + *mit_sot_vals, + *sit_sot_vals, + *shareds, + *non_sequences, + ) + + +def update_buffers(buffers, update_vals, indices, may_roll: bool = True): + return tuple( + buffer.at[(index % buffer.shape[0]) if may_roll else index].set(update_val) + for buffer, update_val, index in zip(buffers, update_vals, indices, strict=True) + ) + + +def align_buffers(buffers, n_steps, max_taps): + return [ + jnp.roll( + buffer, + shift=jnp.where( + # Only needs rolling if last write position is beyond the buffer length + (n_steps + max_tap) > buffer.shape[0], + # Roll left by the amount of overflow + -((n_steps + max_tap + 1) % buffer.shape[0]), + 0, + ), + axis=0, + ) + for buffer, max_tap in zip(buffers, max_taps, strict=True) + ] + + @jax_funcify.register(Scan) -def jax_funcify_Scan(op: Scan, **kwargs): +def jax_funcify_Scan(op: Scan, node, **kwargs): + op = op # Need to bind to a local variable info = op.info if info.as_while: raise NotImplementedError("While Scan cannot yet be converted 
to JAX") - if info.n_mit_mot: - raise NotImplementedError( - "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX" - ) - # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode) rewriter = ( get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer ) rewriter(op.fgraph) - scan_inner_func = jax_funcify(op.fgraph, **kwargs) - - def scan(*outer_inputs): - # Extract JAX scan inputs - outer_inputs = list(outer_inputs) - n_steps = outer_inputs[0] # JAX `length` - seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs` - - mit_sot_init = [] - for tap, seq in zip( - op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True - ): - init_slice = seq[: abs(min(tap))] - mit_sot_init.append(init_slice) - - sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)] - - init_carry = ( - mit_sot_init, - sit_sot_init, - op.outer_shared(outer_inputs), - op.outer_non_seqs(outer_inputs), - ) # JAX `init` - - def jax_args_to_inner_func_args(carry, x): - """Convert JAX scan arguments into format expected by scan_inner_func. - - scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs) - """ - - # `carry` contains all inner taps, shared terms, and non_seqs - ( - inner_mit_sot, - inner_sit_sot, - inner_shared, - inner_non_seqs, - ) = carry - - # `x` contains the inner sequences - inner_seqs = x - - mit_sot_flatten = [] - for array, index in zip( - inner_mit_sot, op.info.mit_sot_in_slices, strict=True - ): - mit_sot_flatten.extend(array[jnp.array(index)]) - - inner_scan_inputs = [ - *inner_seqs, - *mit_sot_flatten, - *inner_sit_sot, - *inner_shared, - *inner_non_seqs, - ] - - return inner_scan_inputs - - def inner_func_outs_to_jax_outs( - old_carry, - inner_scan_outs, - ): - """Convert inner_scan_func outputs into format expected by JAX scan. + # TODO: Use scan name from Op when available + scan_inner_func = jax_funcify(op.fgraph, fgraph_name="scan_inner_func", **kwargs) + + def scan(*outer_inputs, op=op, node=node): + n_steps = outer_inputs[0] + sequences = op.outer_seqs(outer_inputs) + has_empty_sequences = any(seq.shape[0] == 0 for seq in sequences) + init_mit_mot_buffers = op.outer_mitmot(outer_inputs) + init_mit_sot_buffers = op.outer_mitsot(outer_inputs) + init_sit_sot_buffers = op.outer_sitsot(outer_inputs) + nit_sot_buffer_lens = op.outer_nitsot(outer_inputs) + # Shareds are special-cased SIT-SOTs that are not traced, but updated at each step. + # Only last value is returned. It's a hack for special types (like RNG) that can't be "concatenated" over time. 
+ init_shareds = op.outer_shared(outer_inputs) + non_sequences = op.outer_non_seqs(outer_inputs) + assert ( + 1 + + len(sequences) + + len(init_mit_mot_buffers) + + len(init_mit_sot_buffers) + + len(init_sit_sot_buffers) + + len(nit_sot_buffer_lens) + + len(init_shareds) + + len(non_sequences) + ) == len(outer_inputs) + + # Initialize NIT-SOT buffers + if nit_sot_buffer_lens: + if has_empty_sequences: + # In this case we cannot call the inner function to infer the shapes of the nit_sot outputs + # So we must rely on static shapes of the outputs (if available) + nit_sot_core_shapes = [ + n.type.shape for n in op.inner_nitsot_outs(op.fgraph.outputs) + ] + if any(d is None for shape in nit_sot_core_shapes for d in shape): + raise ValueError( + "Scan with NIT-SOT outputs (None in outputs_info) cannot have 0 steps unless the output shapes are statically known)\n" + f"The static shapes of the NIT-SOT outputs for this Scan {node.op} are: {nit_sot_core_shapes}." + ) - old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys) - """ - ( - inner_mit_sot, - _inner_sit_sot, - inner_shared, - inner_non_seqs, - ) = old_carry - - inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs) - inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs) - inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs) - inner_shared_outs = op.inner_shared_outs(inner_scan_outs) - - # Replace the oldest mit_sot tap by the newest value - inner_mit_sot_new = [ - jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0) - for old_mit_sot, new_val in zip( - inner_mit_sot, inner_mit_sot_outs, strict=True + else: + # Otherwise, call the function once to get the shapes and dtypes of the nit_sot outputs + buffer_vals = call_inner_func_with_indexed_buffers( + info, + scan_inner_func, + 0, + sequences, + init_mit_mot_buffers, + init_mit_sot_buffers, + init_sit_sot_buffers, + init_shareds, + non_sequences, ) + nit_sot_core_shapes = [ + n.shape for n in op.inner_nitsot_outs(buffer_vals) + ] + nit_sot_dtypes = [ + n.type.dtype for n in op.inner_nitsot_outs(op.fgraph.outputs) ] - - # Nothing needs to be done with sit_sot - inner_sit_sot_new = inner_sit_sot_outs - - inner_shared_new = inner_shared - # Replace old shared inputs by new shared outputs - inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs - - new_carry = ( - inner_mit_sot_new, - inner_sit_sot_new, - inner_shared_new, - inner_non_seqs, + init_nit_sot_buffers = tuple( + jnp.empty( + (nit_sot_buffer_len, *nit_sot_core_shape), + dtype=nit_sot_dtype, + ) + for nit_sot_buffer_len, nit_sot_core_shape, nit_sot_dtype in zip( + nit_sot_buffer_lens, + nit_sot_core_shapes, + nit_sot_dtypes, + strict=True, + ) ) + else: + init_nit_sot_buffers = () + + if has_empty_sequences: + # fori_loop still gets called with n_steps=0, which would raise an IndexError, we return early here + init_vals = ( + *init_mit_mot_buffers, + *init_mit_sot_buffers, + *init_sit_sot_buffers, + *init_nit_sot_buffers, + *init_shareds, + ) + return init_vals[0] if len(init_vals) == 1 else init_vals - # Shared variables and non_seqs are not traced - traced_outs = [ - *inner_mit_sot_outs, - *inner_sit_sot_outs, - *inner_nit_sot_outs, - ] - - return new_carry, traced_outs - - def jax_inner_func(carry, x): - inner_args = jax_args_to_inner_func_args(carry, x) - inner_scan_outs = list(scan_inner_func(*inner_args)) - new_carry, traced_outs = inner_func_outs_to_jax_outs(carry, inner_scan_outs) - return new_carry, traced_outs + def body_fun(i, prev_vals): + ( + mit_mot_buffers, 
+ mit_sot_buffers, + sit_sot_buffers, + nit_sot_buffers, + shareds, + ) = prev_vals + + next_vals = call_inner_func_with_indexed_buffers( + info, + scan_inner_func, + i, + sequences, + mit_mot_buffers, + mit_sot_buffers, + sit_sot_buffers, + shareds, + non_sequences, + ) + # For MIT-MOT buffers, we want to store at the positions indicated by the output taps + mit_mot_updated_buffers = update_buffers( + mit_mot_buffers, + op.inner_mitmot_outs_grouped(next_vals), + # Taps are positive, we stack them to obtain advanced indices + indices=[i + jnp.stack(taps) for taps in info.mit_mot_out_slices], + # MIT-MOT buffers never roll, as they are never truncated + may_roll=False, + ) + # For regular buffers, we want to store at the position after the last reading + mit_sot_updated_buffers = update_buffers( + mit_sot_buffers, + op.inner_mitsot_outs(next_vals), + indices=[i - min(taps) for taps in info.mit_sot_in_slices], + ) + sit_sot_updated_buffers = update_buffers( + sit_sot_buffers, + op.inner_sitsot_outs(next_vals), + # Taps are always -1 for SIT-SOT, so we just use i + 1 + indices=[i + 1 for _ in sit_sot_buffers], + ) + nit_sot_updated_buffers = update_buffers( + nit_sot_buffers, + op.inner_nitsot_outs(next_vals), + # Taps are always 0 for NIT-SOT, so we just use i + indices=[i for _ in nit_sot_buffers], + ) + shareds_update_vals = op.inner_shared_outs(next_vals) + + return ( + mit_mot_updated_buffers, + mit_sot_updated_buffers, + sit_sot_updated_buffers, + nit_sot_updated_buffers, + shareds_update_vals, + ) - # Extract PyTensor scan outputs - final_carry, traces = jax.lax.scan( - jax_inner_func, init_carry, seqs, length=n_steps + ( + updated_mit_mot_buffers, + updated_mit_sot_buffers, + updated_sit_sot_buffers, + updated_nit_sot_buffers, + updated_shareds, + ) = fori_loop( + 0, + n_steps, + body_fun, + init_val=( + init_mit_mot_buffers, + init_mit_sot_buffers, + init_sit_sot_buffers, + init_nit_sot_buffers, + init_shareds, + ), ) - def get_partial_traces(traces): - """Convert JAX scan traces to PyTensor traces. + # Roll the output buffers to match PyTensor Scan semantics + # MIT-MOT buffers are never truncated, so no rolling is needed + aligned_mit_mot_buffers = updated_mit_mot_buffers + aligned_mit_sot_buffers = align_buffers( + updated_mit_sot_buffers, + n_steps, + # (-3, -1) -> max is 2 + max_taps=[-min(taps) - 1 for taps in info.mit_sot_in_slices], + ) - We need to: - 1. Prepend initial states to JAX output traces - 2. 
Slice final traces if Scan was instructed to only keep a portion - """ + aligned_sit_sot_buffers = align_buffers( + updated_sit_sot_buffers, + n_steps, + max_taps=[0 for _ in updated_sit_sot_buffers], + ) + aligned_nit_sot_buffers = align_buffers( + updated_nit_sot_buffers, + n_steps, + max_taps=[0 for _ in updated_nit_sot_buffers], + ) - init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot - buffers = ( - op.outer_mitsot(outer_inputs) - + op.outer_sitsot(outer_inputs) - + op.outer_nitsot(outer_inputs) + all_outputs = tuple( + chain.from_iterable( + ( + aligned_mit_mot_buffers, + aligned_mit_sot_buffers, + aligned_sit_sot_buffers, + aligned_nit_sot_buffers, + updated_shareds, + ) ) - partial_traces = [] - for init_state, trace, buffer in zip( - init_states, traces, buffers, strict=True - ): - if init_state is not None: - # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer - trace = jnp.atleast_1d(trace) - init_state = jnp.expand_dims( - init_state, range(trace.ndim - init_state.ndim) - ) - full_trace = jnp.concatenate([init_state, trace], axis=0) - buffer_size = buffer.shape[0] - else: - # NIT-SOT: Buffer is just the number of entries that should be returned - full_trace = jnp.atleast_1d(trace) - buffer_size = buffer - - partial_trace = full_trace[-buffer_size:] - partial_traces.append(partial_trace) - - return partial_traces - - def get_shared_outs(final_carry): - """Retrive last state of shared_outs from final_carry. - - These outputs cannot be traced in PyTensor Scan - """ - ( - _inner_out_mit_sot, - _inner_out_sit_sot, - inner_out_shared, - _inner_in_non_seqs, - ) = final_carry - - shared_outs = inner_out_shared[: info.n_shared_outs] - return list(shared_outs) - - scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry) - - if len(scan_outs_final) == 1: - scan_outs_final = scan_outs_final[0] - return scan_outs_final + ) + return all_outputs[0] if len(all_outputs) == 1 else all_outputs return scan diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 80cfa0fcf3..eda97560b3 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -307,6 +307,17 @@ def inner_mitmot_outs(self, list_outputs): n_taps = sum(len(x) for x in self.info.mit_mot_out_slices) return list_outputs[:n_taps] + def inner_mitmot_outs_grouped(self, list_outputs): + # Like inner_mitmot_outs but returns a list of lists, one per mitmot + # Instead of a flat list + n_taps = [len(x) for x in self.info.mit_mot_out_slices] + grouped_outs = [] + offset = 0 + for nt in n_taps: + grouped_outs.append(list_outputs[offset : offset + nt]) + offset += nt + return grouped_outs + def outer_mitmot_outs(self, list_outputs): return list_outputs[: self.info.n_mit_mot] diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index 25b0854096..0b4cba30cb 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -7,6 +7,8 @@ from pytensor import function, ifelse, shared from pytensor.compile import get_mode from pytensor.configdefaults import config +from pytensor.graph import Apply, Op +from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.scan import until from pytensor.scan.basic import scan from pytensor.scan.op import Scan @@ -98,16 +100,26 @@ def test_scan_nit_sot(view): assert len(scan_nodes) == 1 -@pytest.mark.xfail(raises=NotImplementedError) def test_scan_mit_mot(): - xs = pt.vector("xs", shape=(10,)) - ys, _ = scan( - lambda xtm2, xtm1: (xtm2 + xtm1), - outputs_info=[{"initial": xs, "taps": [-2, -1]}], + def 
step(xtm1, ytm3, ytm1, rho): + return (xtm1 + ytm1) * rho, ytm3 * (1 - rho) + ytm1 * rho + + rho = pt.scalar("rho", dtype="float64") + x0 = pt.vector("xs", shape=(2,)) + y0 = pt.vector("ys", shape=(3,)) + [outs, _], _ = scan( + step, + outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}], + non_sequences=[rho], n_steps=10, ) - grads_wrt_xs = pt.grad(ys.sum(), wrt=xs) - compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)]) + grads = pt.grad(outs.sum(), wrt=[x0, y0, rho]) + compare_jax_and_py( + [x0, y0, rho], + grads, + [np.arange(2), np.array([0.5, 0.5, 0.5]), np.array(0.95)], + jax_mode=get_mode("JAX"), + ) def test_scan_update(): @@ -323,13 +335,41 @@ def test_default_mode_excludes_incompatible_rewrites(): def test_dynamic_sequence_length(): - x = pt.tensor("x", shape=(None,)) - out, _ = scan(lambda x: x + 1, sequences=[x]) + class IncWithoutStaticShape(Op): + def make_node(self, x): + x = pt.as_tensor_variable(x) + return Apply(self, [x], [pt.tensor(shape=(None,) * x.type.ndim)]) + + def perform(self, node, inputs, outputs): + outputs[0][0] = inputs[0] + 1 + + @jax_funcify.register(IncWithoutStaticShape) + def _(op, **kwargs): + return lambda x: x + 1 + + inc_without_static_shape = IncWithoutStaticShape() + x = pt.tensor("x", shape=(None, 3)) + + out, _ = scan( + lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x] + ) f = function([x], out, mode=get_mode("JAX").excluding("scan")) assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1 - np.testing.assert_allclose(f([]), []) - np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4])) + np.testing.assert_allclose(f([[1, 2, 3]]), np.array([[2, 3, 4]])) + + with pytest.raises(ValueError): + f(np.zeros((0, 3))) + + # But should be fine with static shape + out2, _ = scan( + lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape), + outputs_info=[None], + sequences=[x], + ) + f2 = function([x], out2, mode=get_mode("JAX").excluding("scan")) + np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]])) + np.testing.assert_allclose(f2(np.zeros((0, 3))), np.empty((0, 3))) def SEIR_model_logp(): @@ -499,9 +539,6 @@ def cycle_step(A0, A1, A2, A1_hat, _norm, step_num): @pytest.mark.parametrize("mode", ("0forward", "1backward", "2both")) @pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp]) def test_scan_benchmark(model, mode, gradient_backend, benchmark): - if gradient_backend == "PYTENSOR" and mode in ("1backward", "2both"): - pytest.skip("PYTENSOR backend does not support backward mode yet") - model_dict = model() graph_inputs = model_dict["graph_inputs"] differentiable_vars = model_dict["differentiable_vars"] From 8e529ebd6e29e2eea46706930311b62e3a54b376 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 13 Oct 2025 13:32:23 +0200 Subject: [PATCH 4/5] Only MIT-MOT require working on buffers directly Using JAX Scan machinery to create MIT-SOT, SIT-SOT, and NIT-SOT buffers for us seems to be more performant than working directly on the pre-allocated buffers and reading/writing at every iteration. There is no machinery to work with MIT-MOT directly (just like in PyTensor user-facing Scan). 
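
As a rough illustration of the trade-off (a minimal sketch, not the dispatcher code; `step`, `xs`, and `buffer` are made-up names for the example), letting `jax.lax.scan` stack the per-step outputs versus writing into a pre-allocated buffer inside a `fori_loop` looks like:

    import jax
    import jax.numpy as jnp

    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # carry forward and emit one trace entry

    xs = jnp.arange(5.0)

    # MIT-SOT / SIT-SOT / NIT-SOT after this commit: JAX builds the trace itself.
    _, trace = jax.lax.scan(step, jnp.array(0.0), xs)

    # MIT-MOT (and everything in the previous commit): write into a buffer we allocate.
    buffer = jnp.zeros(xs.shape[0] + 1)  # initial state plus one slot per step

    def body_fun(i, buf):
        return buf.at[i + 1].set(buf[i] + xs[i])

    buffer = jax.lax.fori_loop(0, xs.shape[0], body_fun, buffer)

Both compute the same cumulative sums here (the buffer additionally keeps the initial state); the benchmarks added earlier in this series suggest the scan-based variant JITs and differentiates better, presumably because XLA gets to choose the trace layout itself.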
--- pytensor/link/jax/dispatch/scan.py | 453 +++++++++++++---------------- tests/link/jax/test_scan.py | 10 +- 2 files changed, 207 insertions(+), 256 deletions(-) diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index ad70bdc36d..313ebd5100 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -2,85 +2,26 @@ import jax.numpy as jnp import numpy as np -from jax._src.lax.control_flow import fori_loop +from jax._src.lax.control_flow import scan as jax_scan from pytensor.compile.mode import JAX, get_mode from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.scan.op import Scan -def call_inner_func_with_indexed_buffers( - info, - scan_inner_func, - i, - sequences, - mit_mot_buffers, - mit_sot_buffers, - sit_sot_buffers, - shareds, - non_sequences, -): - sequence_vals = [seq[i] for seq in sequences] - - # chain.from_iterable is used flatten the first dimension of each indexed buffer - # [buf1[[idx0, idx1]], buf2[[idx0, idx1]]] -> [buf1[idx0], buf1[idx1], buf2[idx0], buf2[idx1]] - # Benchmarking suggests unpacking advanced indexing on all taps is faster than basic index one tap at a time - mit_mot_vals = list( - chain.from_iterable( - buffer[(i + np.array(in_taps))] - for buffer, in_taps in zip( - mit_mot_buffers, info.mit_mot_in_slices, strict=True - ) - ) - ) - mit_sot_vals = list( - chain.from_iterable( - # Convert negative taps (-2, -1) to positive indices (0, 1) - buffer[((i + (np.array(in_taps) - min(in_taps))) % buffer.shape[0])] - for buffer, in_taps in zip( - mit_sot_buffers, info.mit_sot_in_slices, strict=True - ) - ) - ) - sit_sot_vals = [buffer[i % buffer.shape[0]] for buffer in sit_sot_buffers] - - return scan_inner_func( - *sequence_vals, - *mit_mot_vals, - *mit_sot_vals, - *sit_sot_vals, - *shareds, - *non_sequences, - ) - - -def update_buffers(buffers, update_vals, indices, may_roll: bool = True): - return tuple( - buffer.at[(index % buffer.shape[0]) if may_roll else index].set(update_val) - for buffer, update_val, index in zip(buffers, update_vals, indices, strict=True) - ) - - -def align_buffers(buffers, n_steps, max_taps): - return [ - jnp.roll( - buffer, - shift=jnp.where( - # Only needs rolling if last write position is beyond the buffer length - (n_steps + max_tap) > buffer.shape[0], - # Roll left by the amount of overflow - -((n_steps + max_tap + 1) % buffer.shape[0]), - 0, - ), - axis=0, - ) - for buffer, max_tap in zip(buffers, max_taps, strict=True) - ] - - @jax_funcify.register(Scan) -def jax_funcify_Scan(op: Scan, node, **kwargs): - op = op # Need to bind to a local variable +def jax_funcify_Scan(op: Scan, **kwargs): + # Note: This implementation is different from the internal PyTensor Scan op. + # In particular, we don't make use of the provided buffers for recurring outputs (MIT-SOT, SIT-SOT) + # These buffers include the initial state and enough space to store as many intermediate results as needed. + # Instead, we let JAX scan recreate the concatenated buffer itself from the values computed in each iteration, + # and then prepend the initial_state and/or truncate results we don't need at the end. + # Likewise, we allow JAX to stack NIT-SOT outputs itself, instead of writing to an empty buffer with the final size. + # In contrast, MIT-MOT behave like PyTensor Scan. We read from and write to the original buffer as we iterate. + # Hopefully, JAX can do the same sort of memory optimizations as PyTensor does. 
+ # Performance-wise, the benchmarks show this approach is better, specially when auto-diffing through JAX. + # For an implementation that is closer to the internal PyTensor Scan, check intermediate commit in + # https://github.com/pymc-devs/pytensor/pull/1651 info = op.info if info.as_while: @@ -91,199 +32,207 @@ def jax_funcify_Scan(op: Scan, node, **kwargs): get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer ) rewriter(op.fgraph) - # TODO: Use scan name from Op when available - scan_inner_func = jax_funcify(op.fgraph, fgraph_name="scan_inner_func", **kwargs) - - def scan(*outer_inputs, op=op, node=node): - n_steps = outer_inputs[0] - sequences = op.outer_seqs(outer_inputs) - has_empty_sequences = any(seq.shape[0] == 0 for seq in sequences) - init_mit_mot_buffers = op.outer_mitmot(outer_inputs) - init_mit_sot_buffers = op.outer_mitsot(outer_inputs) - init_sit_sot_buffers = op.outer_sitsot(outer_inputs) - nit_sot_buffer_lens = op.outer_nitsot(outer_inputs) - # Shareds are special-cased SIT-SOTs that are not traced, but updated at each step. - # Only last value is returned. It's a hack for special types (like RNG) that can't be "concatenated" over time. - init_shareds = op.outer_shared(outer_inputs) - non_sequences = op.outer_non_seqs(outer_inputs) - assert ( - 1 - + len(sequences) - + len(init_mit_mot_buffers) - + len(init_mit_sot_buffers) - + len(init_sit_sot_buffers) - + len(nit_sot_buffer_lens) - + len(init_shareds) - + len(non_sequences) - ) == len(outer_inputs) - - # Initialize NIT-SOT buffers - if nit_sot_buffer_lens: - if has_empty_sequences: - # In this case we cannot call the inner function to infer the shapes of the nit_sot outputs - # So we must rely on static shapes of the outputs (if available) - nit_sot_core_shapes = [ - n.type.shape for n in op.inner_nitsot_outs(op.fgraph.outputs) - ] - if any(d is None for shape in nit_sot_core_shapes for d in shape): - raise ValueError( - "Scan with NIT-SOT outputs (None in outputs_info) cannot have 0 steps unless the output shapes are statically known)\n" - f"The static shapes of the NIT-SOT outputs for this Scan {node.op} are: {nit_sot_core_shapes}." - ) + scan_inner_func = jax_funcify(op.fgraph, **kwargs) + + def scan(*outer_inputs): + # Extract JAX scan inputs + # JAX doesn't want some inputs to be tuple, but later lists (e.g., from list-comprehensions). + # We convert everything to list, so that it remains a list after slicing. 
+ outer_inputs = list(outer_inputs) + n_steps = outer_inputs[0] # JAX `length` + seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs` + + # MIT-MOT don't have a concept of "initial state" + # The whole buffer is meaningful at the start of the Scan + mit_mot_init = op.outer_mitmot(outer_inputs) + + # For MIT-SOT and SIT-SOT, extract the initial states from the outer input buffers + mit_sot_init = [ + buff[: -min(tap)] + for buff, tap in zip( + op.outer_mitsot(outer_inputs), op.info.mit_sot_in_slices, strict=True + ) + ] + sit_sot_init = [buff[0] for buff in op.outer_sitsot(outer_inputs)] - else: - # Otherwise, call the function once to get the shapes and dtypes of the nit_sot outputs - buffer_vals = call_inner_func_with_indexed_buffers( - info, - scan_inner_func, - 0, - sequences, - init_mit_mot_buffers, - init_mit_sot_buffers, - init_sit_sot_buffers, - init_shareds, - non_sequences, - ) - nit_sot_core_shapes = [ - n.shape for n in op.inner_nitsot_outs(buffer_vals) - ] - nit_sot_dtypes = [ - n.type.dtype for n in op.inner_nitsot_outs(op.fgraph.outputs) - ] - init_nit_sot_buffers = tuple( - jnp.empty( - (nit_sot_buffer_len, *nit_sot_core_shape), - dtype=nit_sot_dtype, + init_carry = ( + 0, # loop counter, needed for indexing MIT-MOT + mit_mot_init, + mit_sot_init, + sit_sot_init, + op.outer_shared(outer_inputs), + op.outer_non_seqs(outer_inputs), + ) # JAX `init` + + def jax_args_to_inner_func_args(carry, x): + """Convert JAX scan arguments into format expected by scan_inner_func. + + scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs) + """ + + # `carry` contains all inner taps, shared terms, and non_seqs + ( + i, + inner_mit_mot, + inner_mit_sot, + inner_sit_sot, + inner_shared, + inner_non_seqs, + ) = carry + + # `x` contains the inner sequences + inner_seqs = x + + # chain.from_iterable is used to flatten the first dimension of each indexed buffer + # [buf1[[idx0, idx1]], buf2[[idx0, idx1]]] -> [buf1[idx0], buf1[idx1], buf2[idx0], buf2[idx1]] + # Benchmarking suggests unpacking advanced indexing on all taps is faster than basic index one tap at a time + mit_mot_flatten = list( + chain.from_iterable( + buffer[(i + np.array(taps))] + for buffer, taps in zip( + inner_mit_mot, info.mit_mot_in_slices, strict=True + ) ) - for nit_sot_buffer_len, nit_sot_core_shape, nit_sot_dtype in zip( - nit_sot_buffer_lens, - nit_sot_core_shapes, - nit_sot_dtypes, - strict=True, + ) + mit_sot_flatten = list( + chain.from_iterable( + buffer[np.array(taps)] + for buffer, taps in zip( + inner_mit_sot, info.mit_sot_in_slices, strict=True + ) ) ) - else: - init_nit_sot_buffers = () - - if has_empty_sequences: - # fori_loop still gets called with n_steps=0, which would raise an IndexError, we return early here - init_vals = ( - *init_mit_mot_buffers, - *init_mit_sot_buffers, - *init_sit_sot_buffers, - *init_nit_sot_buffers, - *init_shareds, + + return ( + *inner_seqs, + *mit_mot_flatten, + *mit_sot_flatten, + *inner_sit_sot, + *inner_shared, + *inner_non_seqs, ) - return init_vals[0] if len(init_vals) == 1 else init_vals - def body_fun(i, prev_vals): + def inner_func_outs_to_jax_outs( + old_carry, + inner_scan_outs, + ): + """Convert inner_scan_func outputs into format expected by JAX scan. 
+ + old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys) + """ ( - mit_mot_buffers, - mit_sot_buffers, - sit_sot_buffers, - nit_sot_buffers, - shareds, - ) = prev_vals - - next_vals = call_inner_func_with_indexed_buffers( - info, - scan_inner_func, i, - sequences, - mit_mot_buffers, - mit_sot_buffers, - sit_sot_buffers, - shareds, - non_sequences, - ) - # For MIT-MOT buffers, we want to store at the positions indicated by the output taps - mit_mot_updated_buffers = update_buffers( - mit_mot_buffers, - op.inner_mitmot_outs_grouped(next_vals), - # Taps are positive, we stack them to obtain advanced indices - indices=[i + jnp.stack(taps) for taps in info.mit_mot_out_slices], - # MIT-MOT buffers never roll, as they are never truncated - may_roll=False, - ) - # For regular buffers, we want to store at the position after the last reading - mit_sot_updated_buffers = update_buffers( - mit_sot_buffers, - op.inner_mitsot_outs(next_vals), - indices=[i - min(taps) for taps in info.mit_sot_in_slices], - ) - sit_sot_updated_buffers = update_buffers( - sit_sot_buffers, - op.inner_sitsot_outs(next_vals), - # Taps are always -1 for SIT-SOT, so we just use i + 1 - indices=[i + 1 for _ in sit_sot_buffers], - ) - nit_sot_updated_buffers = update_buffers( - nit_sot_buffers, - op.inner_nitsot_outs(next_vals), - # Taps are always 0 for NIT-SOT, so we just use i - indices=[i for _ in nit_sot_buffers], + old_mit_mot, + old_mit_sot, + _old_sit_sot, + _old_shared, + inner_non_seqs, + ) = old_carry + + new_mit_mot_vals = op.inner_mitmot_outs_grouped(inner_scan_outs) + new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs) + new_sit_sot = op.inner_sitsot_outs(inner_scan_outs) + new_nit_sot = op.inner_nitsot_outs(inner_scan_outs) + new_shared = op.inner_shared_outs(inner_scan_outs) + + # New carry for next step + # Update MIT-MOT buffer at positions indicated by output taps + new_mit_mot = [ + buffer.at[i + np.array(taps)].set(new_vals) + for buffer, new_vals, taps in zip( + old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True + ) + ] + # Discard oldest MIT-SOT and append newest value + new_mit_sot = [ + jnp.concatenate([old_buffer[1:], new_val[None, ...]], axis=0) + for old_buffer, new_val in zip( + old_mit_sot, new_mit_sot_vals, strict=True + ) + ] + # For SIT-SOT, and shared just pass along the new value + # Non-sequences remain unchanged + new_carry = ( + i + 1, + new_mit_mot, + new_mit_sot, + new_sit_sot, + new_shared, + inner_non_seqs, ) - shareds_update_vals = op.inner_shared_outs(next_vals) - return ( - mit_mot_updated_buffers, - mit_sot_updated_buffers, - sit_sot_updated_buffers, - nit_sot_updated_buffers, - shareds_update_vals, - ) + # Select new MIT-SOT, SIT-SOT, and NIT-SOT for tracing + traced_outs = [ + *new_mit_sot_vals, + *new_sit_sot, + *new_nit_sot, + ] + return new_carry, traced_outs + + def jax_inner_func(carry, x): + inner_args = jax_args_to_inner_func_args(carry, x) + inner_scan_outs = list(scan_inner_func(*inner_args)) + new_carry, traced_outs = inner_func_outs_to_jax_outs(carry, inner_scan_outs) + return new_carry, traced_outs + + # Extract PyTensor scan outputs ( - updated_mit_mot_buffers, - updated_mit_sot_buffers, - updated_sit_sot_buffers, - updated_nit_sot_buffers, - updated_shareds, - ) = fori_loop( - 0, - n_steps, - body_fun, - init_val=( - init_mit_mot_buffers, - init_mit_sot_buffers, - init_sit_sot_buffers, - init_nit_sot_buffers, - init_shareds, + ( + _final_i, + final_mit_mot, + _final_mit_sot, + _final_sit_sot, + final_shared, + 
_final_non_seqs, ), - ) - - # Roll the output buffers to match PyTensor Scan semantics - # MIT-MOT buffers are never truncated, so no rolling is needed - aligned_mit_mot_buffers = updated_mit_mot_buffers - aligned_mit_sot_buffers = align_buffers( - updated_mit_sot_buffers, - n_steps, - # (-3, -1) -> max is 2 - max_taps=[-min(taps) - 1 for taps in info.mit_sot_in_slices], - ) - - aligned_sit_sot_buffers = align_buffers( - updated_sit_sot_buffers, - n_steps, - max_taps=[0 for _ in updated_sit_sot_buffers], - ) - aligned_nit_sot_buffers = align_buffers( - updated_nit_sot_buffers, - n_steps, - max_taps=[0 for _ in updated_nit_sot_buffers], - ) - - all_outputs = tuple( - chain.from_iterable( - ( - aligned_mit_mot_buffers, - aligned_mit_sot_buffers, - aligned_sit_sot_buffers, - aligned_nit_sot_buffers, - updated_shareds, - ) + traces, + ) = jax_scan(jax_inner_func, init_carry, seqs, length=n_steps) + + def get_partial_traces(traces): + """Convert JAX scan traces to PyTensor traces. + + We need to: + 1. Prepend initial states to JAX output traces + 2. Slice final traces if Scan was instructed to only keep a portion + """ + + init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot + buffers = ( + op.outer_mitsot(outer_inputs) + + op.outer_sitsot(outer_inputs) + + op.outer_nitsot(outer_inputs) ) - ) - return all_outputs[0] if len(all_outputs) == 1 else all_outputs + partial_traces = [] + for init_state, trace, buffer in zip( + init_states, traces, buffers, strict=True + ): + if init_state is not None: + # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer + trace = jnp.atleast_1d(trace) + init_state = jnp.expand_dims( + init_state, range(trace.ndim - init_state.ndim) + ) + full_trace = jnp.concatenate([init_state, trace], axis=0) + buffer_size = buffer.shape[0] + else: + # NIT-SOT: Buffer is just the number of entries that should be returned + full_trace = jnp.atleast_1d(trace) + buffer_size = buffer + + partial_trace = full_trace[-buffer_size:] + partial_traces.append(partial_trace) + + return partial_traces + + scan_outs_final = [ + *final_mit_mot, + *get_partial_traces(traces), + *final_shared, + ] + + if len(scan_outs_final) == 1: + scan_outs_final = scan_outs_final[0] + return scan_outs_final return scan diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index 0b4cba30cb..ff9f4893af 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -8,7 +8,6 @@ from pytensor.compile import get_mode from pytensor.configdefaults import config from pytensor.graph import Apply, Op -from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.scan import until from pytensor.scan.basic import scan from pytensor.scan.op import Scan @@ -335,6 +334,9 @@ def test_default_mode_excludes_incompatible_rewrites(): def test_dynamic_sequence_length(): + # Imported here to not trigger import of JAX in non-JAX CI jobs + from pytensor.link.jax.dispatch.basic import jax_funcify + class IncWithoutStaticShape(Op): def make_node(self, x): x = pt.as_tensor_variable(x) @@ -358,10 +360,10 @@ def _(op, **kwargs): assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1 np.testing.assert_allclose(f([[1, 2, 3]]), np.array([[2, 3, 4]])) - with pytest.raises(ValueError): - f(np.zeros((0, 3))) + # This works if we use JAX scan internally, but not if we use a fori_loop with a buffer allocated by us + np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3))) - # But should be fine with static shape + # With known 
static shape we should always manage, regardless of the internal implementation out2, _ = scan( lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape), outputs_info=[None], From be7087198bca8334641f1c9cc3d4c13a4739174a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 13 Oct 2025 16:42:49 +0200 Subject: [PATCH 5/5] Optimize partial trace definition --- pytensor/link/jax/dispatch/scan.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index 313ebd5100..3082c6481a 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -209,18 +209,28 @@ def get_partial_traces(traces): ): if init_state is not None: # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer - trace = jnp.atleast_1d(trace) - init_state = jnp.expand_dims( - init_state, range(trace.ndim - init_state.ndim) - ) - full_trace = jnp.concatenate([init_state, trace], axis=0) buffer_size = buffer.shape[0] + if trace.shape[0] > buffer_size: + # Trace is longer than buffer, keep just the last `buffer.shape[0]` entries + partial_trace = trace[-buffer_size:] + else: + # Trace is shorter than buffer, this happens when we keep the initial_state + if init_state.ndim < buffer.ndim: + init_state = init_state[None] + if ( + n_init_needed := buffer_size - trace.shape[0] + ) < init_state.shape[0]: + # We may not need to keep all the initial states + init_state = init_state[-n_init_needed:] + partial_trace = jnp.concatenate([init_state, trace], axis=0) else: # NIT-SOT: Buffer is just the number of entries that should be returned - full_trace = jnp.atleast_1d(trace) buffer_size = buffer + partial_trace = ( + trace[-buffer_size:] if trace.shape[0] > buffer else trace + ) - partial_trace = full_trace[-buffer_size:] + assert partial_trace.shape[0] == buffer_size partial_traces.append(partial_trace) return partial_traces
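
For reference, the recurring-output semantics that the branches above implement can be sketched roughly as follows (plain NumPy, hypothetical names and sizes, not the dispatcher code itself): the outer buffer holds the initial state(s) followed by room for the kept steps, the JAX trace holds one entry per step, and the result is the last `buffer_size` entries of `concat(init_state, trace)`, computed without materializing the full concatenation when the trace alone already covers the buffer.

    import numpy as np

    def partial_trace(init_state, trace, buffer_size):
        if trace.shape[0] >= buffer_size:
            # The per-step trace already fills the buffer
            # (Scan was instructed to keep only part of the output).
            return trace[-buffer_size:]
        # Otherwise keep just enough of the initial state to pad the trace out.
        n_init_needed = buffer_size - trace.shape[0]
        return np.concatenate([init_state[-n_init_needed:], trace], axis=0)

    init_state = np.array([0.0, 1.0])   # e.g. a MIT-SOT with taps (-2, -1)
    trace = np.arange(2.0, 12.0)        # ten values produced by the loop body
    assert partial_trace(init_state, trace, buffer_size=12).shape == (12,)
    assert partial_trace(init_state, trace, buffer_size=3).shape == (3,)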