Skip to content

Commit 17cb1e9

Browse files
committed
Numba Scan: prevent alias of outputs
Also simplified test. Shared variables aren't needed for the test and clobber it
1 parent 6e68b14 commit 17cb1e9

File tree

3 files changed

+76
-41
lines changed

3 files changed

+76
-41
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
from numba import types
66
from numba.extending import overload
77

8-
from pytensor import In
9-
from pytensor.compile.function.types import add_supervisor_to_fgraph
8+
from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
9+
from pytensor.compile.io import In, Out
1010
from pytensor.compile.mode import NUMBA, get_mode
1111
from pytensor.link.numba.cache import compile_numba_function_src
1212
from pytensor.link.numba.dispatch import basic as numba_basic
1313
from pytensor.link.numba.dispatch.basic import (
14-
create_arg_string,
1514
create_tuple_string,
1615
numba_funcify_and_cache_key,
1716
register_funcify_and_cache_key,
@@ -89,14 +88,15 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
8988
if outer_mitsot.type.shape[0] == abs(min(taps))
9089
]
9190
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
91+
input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs]
9292
add_supervisor_to_fgraph(
9393
fgraph=fgraph,
94-
input_specs=[
95-
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
96-
],
94+
input_specs=input_specs,
9795
accept_inplace=True,
9896
)
9997
rewriter(fgraph)
98+
output_specs = [Out(x, borrow=False) for x in fgraph.outputs]
99+
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
100100

101101
scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph)
102102

tests/link/numba/test_scan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,7 @@ def test_higher_order_derivatives():
661661

662662
def test_grad_until_and_truncate_sequence_taps():
663663
ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode="NUMBA")
664+
665+
666+
def test_aliased_inner_outputs():
667+
ScanCompatibilityTests.check_aliased_inner_outputs(static_shape=True, mode="NUMBA")

tests/scan/test_basic.py

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3179,40 +3179,9 @@ def onestep(x, x_tm4):
31793179
f = function([seq], results[1])
31803180
assert np.all(exp_out == f(inp))
31813181

3182-
def test_shared_borrow(self):
3183-
"""
3184-
This tests two things. The first is a bug occurring when scan wrongly
3185-
used the borrow flag. The second thing it that Scan's infer_shape()
3186-
method will be able to remove the Scan node from the graph in this
3187-
case.
3188-
"""
3189-
3190-
inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
3191-
exp_out = np.zeros((10, 1)).astype(config.floatX)
3192-
exp_out[4:] = inp[:-4]
3193-
3194-
def onestep(x, x_tm4):
3195-
return x, x_tm4
3196-
3197-
seq = matrix()
3198-
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
3199-
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
3200-
results = scan(
3201-
fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
3202-
)
3203-
sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
3204-
updates = {sharedvar: results[0][-1:]}
3205-
3206-
f = function([seq], results[1], updates=updates)
3207-
3208-
# This fails if scan uses wrongly the borrow flag
3209-
assert np.all(exp_out == f(inp))
3210-
3211-
# This fails if Scan's infer_shape() is unable to remove the Scan
3212-
# node from the graph.
3213-
f_infershape = function([seq], results[1].shape, mode="FAST_RUN")
3214-
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
3215-
assert len(scan_nodes_infershape) == 0
3182+
@pytest.mark.parametrize("static_shape", (True, False))
3183+
def test_aliased_inner_outputs(self, static_shape):
3184+
ScanCompatibilityTests.check_aliased_inner_outputs(static_shape, mode=None)
32163185

32173186
def test_memory_reuse_with_outputs_as_inputs(self):
32183187
"""
@@ -4327,7 +4296,6 @@ def check_higher_order_derivative(mode):
43274296
# FIXME: All implementations of Scan seem to get this one wrong!
43284297
# np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
43294298

4330-
43314299
@staticmethod
43324300
def check_grad_until_and_truncate_sequence_taps(mode):
43334301
"""Test case where we need special behavior of zeroing out sequences in Scan"""
@@ -4349,3 +4317,66 @@ def check_grad_until_and_truncate_sequence_taps(mode):
43494317
grad_expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
43504318
grad_expected = grad_expected.astype(config.floatX)
43514319
np.testing.assert_allclose(grad_res, grad_expected)
4320+
4321+
@staticmethod
4322+
def check_aliased_inner_outputs(static_shape, mode):
4323+
"""
4324+
This tests two things. The first is a bug occurring when scan wrongly
4325+
used the borrow flag. The second thing it that Scan's infer_shape()
4326+
method will be able to remove the Scan node from the graph in this
4327+
case.
4328+
4329+
Here is pure python equivalent of the problem we want to avoid:
4330+
```python
4331+
def scan(seq, initval):
4332+
# Due to memory optimization we override values of mitsot as we iterate
4333+
# That's why mitsot has shape (4, 1) and not (14, 1)
4334+
mitsot = np.zeros((4, 1))
4335+
mitsot[:4] = initval
4336+
nitsot = np.zeros((10, 1))
4337+
for i, s in enumerate(seq):
4338+
# Incorrect results
4339+
mitsot[(i+4) % 4], nitsot[i] = s, mitsot[i % 4]
4340+
# Correct results
4341+
# mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4].copy()
4342+
4343+
return mitsot[(i + 4) % 4: (i+4 + 1) % 4], nitsot
4344+
4345+
scan(np.arange(10), np.zeros((4, 1)))
4346+
```
4347+
"""
4348+
4349+
def onestep(seq, seq_tm4):
4350+
# Recurring output is just each value of seq
4351+
# And we further map the tap -4 as a new output
4352+
return seq, seq_tm4
4353+
4354+
# Outer tensors must be atleast matrix, so that they we have vectors in the inner loop
4355+
# Otherwise we would be working with scalars and memory alias wouldn't be a concern
4356+
seq = matrix(shape=(10, 1) if static_shape else (None, None), name="seq")
4357+
init = matrix(shape=(4, 1) if static_shape else (None, None), name="init")
4358+
outputs_info = [{"initial": init, "taps": [-4]}, None]
4359+
[out_seq, out_seq_tm4] = scan(
4360+
fn=onestep,
4361+
sequences=seq,
4362+
outputs_info=outputs_info,
4363+
return_updates=False,
4364+
)
4365+
4366+
f = function([seq, init], [out_seq[-1].ravel(), out_seq_tm4.ravel()], mode=mode)
4367+
4368+
seq_test_val = np.arange(10, dtype=config.floatX)[:, None]
4369+
init_test_val = np.zeros((4, 1), dtype=config.floatX)
4370+
4371+
res0, res1 = f(seq_test_val, init_test_val)
4372+
expected_res0 = np.array([9], dtype=config.floatX)
4373+
expected_res1 = np.zeros(10, dtype=config.floatX)
4374+
expected_res1[4:] = np.arange(6)
4375+
np.testing.assert_array_equal(res0, expected_res0)
4376+
np.testing.assert_array_equal(res1, expected_res1)
4377+
4378+
# This fails if Scan's infer_shape() is unable to remove the Scan
4379+
# node from the graph.
4380+
f_infershape = function([seq, init], out_seq_tm4[1].shape)
4381+
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
4382+
assert len(scan_nodes_infershape) == 0

0 commit comments

Comments
 (0)