From bddcb4de53a7f1ebc7af5307d81c0bd39f70fff6 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 13 Jul 2025 12:07:23 +0200 Subject: [PATCH] Tests that a jump at t1 is saved. --- test/test_adaptive_stepsize_controller.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 8161b9d8..21c24e4a 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -336,3 +336,28 @@ def test_implicit_solver_with_clip_controller(new: bool): max_steps=16384, saveat=diffrax.SaveAt(t1=True), ) + + +# https://github.com/patrick-kidger/diffrax/issues/663 +# `jump_ts` sets the time we step to as `prevbefore` the time provided. +# Clipping at t1 saves us! We need to clip at at least 1 ULP. +def test_jump_at_t1_with_large_t1_in_float32(): + t0 = jnp.array(0.0, dtype=jnp.float32) + t1 = jnp.array(1e3, dtype=jnp.float32) + dt0 = jnp.array(0.01, dtype=jnp.float32) + y0 = jnp.array(1, dtype=jnp.float32) + saveat = diffrax.SaveAt(ts=t1[None]) + ssc = diffrax.ClipStepSizeController( + diffrax.PIDController(atol=1e-6, rtol=1e-6), jump_ts=t1[None] + ) + sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: -y), + diffrax.Heun(), + t0=t0, + t1=t1, + dt0=dt0, + y0=y0, + stepsize_controller=ssc, + saveat=saveat, + ) + assert sol.ts == jnp.array([t1])