Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/compile/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
self.implicit = implicit

def __str__(self):
    """Return a readable description of this input.

    The update expression is included only when one is set. The check is an
    explicit ``is not None`` test rather than a truthiness test, so an update
    object that evaluates as falsy (or that does not support ``bool()``) is
    still reported.
    """
    if self.update is not None:
        return f"In({self.variable} -> {self.update})"
    return f"In({self.variable})"
Expand Down
202 changes: 124 additions & 78 deletions pytensor/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import jax
from itertools import chain

import jax.numpy as jnp
import numpy as np
from jax._src.lax.control_flow import scan as jax_scan

from pytensor.compile.mode import JAX, get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
Expand All @@ -8,16 +11,22 @@

@jax_funcify.register(Scan)
def jax_funcify_Scan(op: Scan, **kwargs):
# Note: This implementation is different from the internal PyTensor Scan op.
# In particular, we don't make use of the provided buffers for recurring outputs (MIT-SOT, SIT-SOT)
# These buffers include the initial state and enough space to store as many intermediate results as needed.
# Instead, we let JAX scan recreate the concatenated buffer itself from the values computed in each iteration,
# and then prepend the initial_state and/or truncate results we don't need at the end.
# Likewise, we allow JAX to stack NIT-SOT outputs itself, instead of writing to an empty buffer with the final size.
# In contrast, MIT-MOT behave like PyTensor Scan. We read from and write to the original buffer as we iterate.
# Hopefully, JAX can do the same sort of memory optimizations as PyTensor does.
# Performance-wise, the benchmarks show this approach is better, especially when auto-diffing through JAX.
# For an implementation that is closer to the internal PyTensor Scan, check intermediate commit in
# https://github.com/pymc-devs/pytensor/pull/1651
info = op.info

if info.as_while:
raise NotImplementedError("While Scan cannot yet be converted to JAX")

if info.n_mit_mot:
raise NotImplementedError(
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
)

# Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode)
rewriter = (
get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
Expand All @@ -27,20 +36,28 @@ def jax_funcify_Scan(op: Scan, **kwargs):

def scan(*outer_inputs):
# Extract JAX scan inputs
# JAX complains when some inputs start out as tuples but later become lists (e.g., from list comprehensions).
# We convert everything to list, so that it remains a list after slicing.
outer_inputs = list(outer_inputs)
n_steps = outer_inputs[0] # JAX `length`
seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs`

mit_sot_init = []
for tap, seq in zip(
op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True
):
init_slice = seq[: abs(min(tap))]
mit_sot_init.append(init_slice)
# MIT-MOT don't have a concept of "initial state"
# The whole buffer is meaningful at the start of the Scan
mit_mot_init = op.outer_mitmot(outer_inputs)

sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
# For MIT-SOT and SIT-SOT, extract the initial states from the outer input buffers
mit_sot_init = [
buff[: -min(tap)]
for buff, tap in zip(
op.outer_mitsot(outer_inputs), op.info.mit_sot_in_slices, strict=True
)
]
sit_sot_init = [buff[0] for buff in op.outer_sitsot(outer_inputs)]

init_carry = (
0, # loop counter, needed for indexing MIT-MOT
mit_mot_init,
mit_sot_init,
sit_sot_init,
op.outer_shared(outer_inputs),
Expand All @@ -50,11 +67,13 @@ def scan(*outer_inputs):
def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func.

scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs)
"""

# `carry` contains all inner taps, shared terms, and non_seqs
(
i,
inner_mit_mot,
inner_mit_sot,
inner_sit_sot,
inner_shared,
Expand All @@ -64,69 +83,89 @@ def jax_args_to_inner_func_args(carry, x):
# `x` contains the inner sequences
inner_seqs = x

mit_sot_flatten = []
for array, index in zip(
inner_mit_sot, op.info.mit_sot_in_slices, strict=True
):
mit_sot_flatten.extend(array[jnp.array(index)])
# chain.from_iterable is used to flatten the first dimension of each indexed buffer
# [buf1[[idx0, idx1]], buf2[[idx0, idx1]]] -> [buf1[idx0], buf1[idx1], buf2[idx0], buf2[idx1]]
# Benchmarking suggests unpacking advanced indexing on all taps is faster than basic indexing one tap at a time
mit_mot_flatten = list(
chain.from_iterable(
buffer[(i + np.array(taps))]
for buffer, taps in zip(
inner_mit_mot, info.mit_mot_in_slices, strict=True
)
)
)
mit_sot_flatten = list(
chain.from_iterable(
buffer[np.array(taps)]
for buffer, taps in zip(
inner_mit_sot, info.mit_sot_in_slices, strict=True
)
)
)

inner_scan_inputs = [
return (
*inner_seqs,
*mit_mot_flatten,
*mit_sot_flatten,
*inner_sit_sot,
*inner_shared,
*inner_non_seqs,
]

return inner_scan_inputs
)

def inner_func_outs_to_jax_outs(
old_carry,
inner_scan_outs,
):
"""Convert inner_scan_func outputs into format expected by JAX scan.

old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
"""
(
inner_mit_sot,
_inner_sit_sot,
inner_shared,
i,
old_mit_mot,
old_mit_sot,
_old_sit_sot,
_old_shared,
inner_non_seqs,
) = old_carry

inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
inner_shared_outs = op.inner_shared_outs(inner_scan_outs)

# Replace the oldest mit_sot tap by the newest value
inner_mit_sot_new = [
jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
for old_mit_sot, new_val in zip(
inner_mit_sot, inner_mit_sot_outs, strict=True
new_mit_mot_vals = op.inner_mitmot_outs_grouped(inner_scan_outs)
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
new_shared = op.inner_shared_outs(inner_scan_outs)

# New carry for next step
# Update MIT-MOT buffer at positions indicated by output taps
new_mit_mot = [
buffer.at[i + np.array(taps)].set(new_vals)
for buffer, new_vals, taps in zip(
old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True
)
]

# Nothing needs to be done with sit_sot
inner_sit_sot_new = inner_sit_sot_outs

inner_shared_new = inner_shared
# Replace old shared inputs by new shared outputs
inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs

# Discard oldest MIT-SOT and append newest value
new_mit_sot = [
jnp.concatenate([old_buffer[1:], new_val[None, ...]], axis=0)
for old_buffer, new_val in zip(
old_mit_sot, new_mit_sot_vals, strict=True
)
]
# For SIT-SOT, and shared just pass along the new value
# Non-sequences remain unchanged
new_carry = (
inner_mit_sot_new,
inner_sit_sot_new,
inner_shared_new,
i + 1,
new_mit_mot,
new_mit_sot,
new_sit_sot,
new_shared,
inner_non_seqs,
)

# Shared variables and non_seqs are not traced
# Select new MIT-SOT, SIT-SOT, and NIT-SOT for tracing
traced_outs = [
*inner_mit_sot_outs,
*inner_sit_sot_outs,
*inner_nit_sot_outs,
*new_mit_sot_vals,
*new_sit_sot,
*new_nit_sot,
]

return new_carry, traced_outs
Expand All @@ -138,9 +177,17 @@ def jax_inner_func(carry, x):
return new_carry, traced_outs

# Extract PyTensor scan outputs
final_carry, traces = jax.lax.scan(
jax_inner_func, init_carry, seqs, length=n_steps
)
(
(
_final_i,
final_mit_mot,
_final_mit_sot,
_final_sit_sot,
final_shared,
_final_non_seqs,
),
traces,
) = jax_scan(jax_inner_func, init_carry, seqs, length=n_steps)

def get_partial_traces(traces):
"""Convert JAX scan traces to PyTensor traces.
Expand All @@ -162,38 +209,37 @@ def get_partial_traces(traces):
):
if init_state is not None:
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
trace = jnp.atleast_1d(trace)
init_state = jnp.expand_dims(
init_state, range(trace.ndim - init_state.ndim)
)
full_trace = jnp.concatenate([init_state, trace], axis=0)
buffer_size = buffer.shape[0]
if trace.shape[0] > buffer_size:
# Trace is longer than buffer, keep just the last `buffer.shape[0]` entries
partial_trace = trace[-buffer_size:]
else:
# Trace is shorter than buffer, this happens when we keep the initial_state
if init_state.ndim < buffer.ndim:
init_state = init_state[None]
if (
n_init_needed := buffer_size - trace.shape[0]
) < init_state.shape[0]:
# We may not need to keep all the initial states
init_state = init_state[-n_init_needed:]
partial_trace = jnp.concatenate([init_state, trace], axis=0)
else:
# NIT-SOT: Buffer is just the number of entries that should be returned
full_trace = jnp.atleast_1d(trace)
buffer_size = buffer
partial_trace = (
trace[-buffer_size:] if trace.shape[0] > buffer else trace
)

partial_trace = full_trace[-buffer_size:]
assert partial_trace.shape[0] == buffer_size
partial_traces.append(partial_trace)

return partial_traces

def get_shared_outs(final_carry):
"""Retrive last state of shared_outs from final_carry.

These outputs cannot be traced in PyTensor Scan
"""
(
_inner_out_mit_sot,
_inner_out_sit_sot,
inner_out_shared,
_inner_in_non_seqs,
) = final_carry

shared_outs = inner_out_shared[: info.n_shared_outs]
return list(shared_outs)

scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
scan_outs_final = [
*final_mit_mot,
*get_partial_traces(traces),
*final_shared,
]

if len(scan_outs_final) == 1:
scan_outs_final = scan_outs_final[0]
Expand Down
11 changes: 11 additions & 0 deletions pytensor/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,17 @@ def inner_mitmot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
return list_outputs[:n_taps]

def inner_mitmot_outs_grouped(self, list_outputs):
    """Return the leading MIT-MOT entries of ``list_outputs``, one group per MIT-MOT.

    Like ``inner_mitmot_outs``, but returns a list of lists (one per entry of
    ``self.info.mit_mot_out_slices``) instead of a single flat list.

    Parameters
    ----------
    list_outputs
        Sliceable sequence whose leading entries are consumed group by group.

    Returns
    -------
    list
        One slice of ``list_outputs`` per entry of ``mit_mot_out_slices``;
        the i-th slice has as many items as the i-th entry has taps.
    """
    grouped_outs = []
    offset = 0
    for out_taps in self.info.mit_mot_out_slices:
        # Each group consumes as many consecutive outputs as it has output taps
        stop = offset + len(out_taps)
        grouped_outs.append(list_outputs[offset:stop])
        offset = stop
    return grouped_outs

def outer_mitmot_outs(self, list_outputs):
    """Return the first ``self.info.n_mit_mot`` entries of ``list_outputs``.

    NOTE(review): presumably these leading entries are the outer MIT-MOT
    outputs of the Scan; confirm against the output ordering used by callers.
    """
    n_mit_mot = self.info.n_mit_mot
    return list_outputs[:n_mit_mot]

Expand Down
Loading