Skip to content

Commit 6e68b14

Browse files
committed
Numba Scan: zero out unwritten buffers
1 parent 0ab80e4 commit 6e68b14

File tree

3 files changed

+39
-16
lines changed

3 files changed

+39
-16
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 11 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -254,6 +254,17 @@ def add_output_storage_post_proc_stmt(
254254
"""
255255
).strip()
256256
)
257+
else:
258+
# And regular loops should zero out unused entries of the output buffer
259+
# These show up with truncated gradients of while loops
260+
output_storage_post_proc_stmts.append(
261+
dedent(
262+
f"""
263+
elif {storage_size} > (i + {max_offset}):
264+
{outer_in_name}[i + {max_offset}:] = 0
265+
"""
266+
).strip()
267+
)
257268

258269
# Special in-loop statements that create (nit-sot) storage arrays after a
259270
# single iteration is performed. This is necessary because we don't know

tests/link/numba/test_scan.py

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -657,3 +657,7 @@ def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark
657657

658658
def test_higher_order_derivatives():
659659
ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA")
660+
661+
662+
def test_grad_until_and_truncate_sequence_taps():
663+
ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode="NUMBA")

tests/scan/test_basic.py

Lines changed: 24 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -2619,22 +2619,7 @@ def test_grad_until_and_truncate(self):
26192619
utt.assert_allclose(pytensor_gradient, self.numpy_gradient)
26202620

26212621
def test_grad_until_and_truncate_sequence_taps(self):
2622-
n = 3
2623-
r = scan(
2624-
lambda x, y, u: (x * y, until(y > u)),
2625-
sequences=dict(input=self.x, taps=[-2, 0]),
2626-
non_sequences=[self.threshold],
2627-
truncate_gradient=n,
2628-
return_updates=False,
2629-
)
2630-
g = grad(r.sum(), self.x)
2631-
f = function([self.x, self.threshold], [r, g])
2632-
_pytensor_output, pytensor_gradient = f(self.seq, 6)
2633-
2634-
# Gradient computed by hand:
2635-
numpy_grad = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
2636-
numpy_grad = numpy_grad.astype(config.floatX)
2637-
utt.assert_allclose(pytensor_gradient, numpy_grad)
2622+
ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode=None)
26382623

26392624

26402625
def test_mintap_onestep():
@@ -4341,3 +4326,26 @@ def check_higher_order_derivative(mode):
43414326
np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14)
43424327
# FIXME: All implementations of Scan seem to get this one wrong!
43434328
# np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
4329+
4330+
4331+
@staticmethod
4332+
def check_grad_until_and_truncate_sequence_taps(mode):
4333+
"""Test case where we need special behavior of zeroing out sequences in Scan"""
4334+
x = pt.vector("x")
4335+
threshold = pt.scalar(name="threshold", dtype="int64")
4336+
4337+
r = scan(
4338+
lambda x, y, u: (x * y, until(y > u)),
4339+
sequences=dict(input=x, taps=[-2, 0]),
4340+
non_sequences=[threshold],
4341+
truncate_gradient=3,
4342+
return_updates=False,
4343+
)
4344+
g = grad(r.sum(), x)
4345+
f = function([x, threshold], [r, g], mode=mode)
4346+
_, grad_res = f(np.arange(15, dtype=x.dtype), 6)
4347+
4348+
# Gradient computed by hand:
4349+
grad_expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
4350+
grad_expected = grad_expected.astype(config.floatX)
4351+
np.testing.assert_allclose(grad_res, grad_expected)

0 commit comments

Comments (0)