
sol.ts contains wrong values in some cases #663

@terhorst

When passing both saveat= and jump_ts=, the values returned in sol.ts can be wrong (specifically, inf). Here is a reproducible example:

import jax.numpy as jnp
import jax
import diffrax as dfx

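# Linear ODE dy/dt = Q @ y
Q = jnp.array([[-1., 1.], [0., 0.]])
y0 = jnp.array([0.5, 0.5])

def A(t, y, _):
    return Q @ y

solver = dfx.Kvaerno3()

def f(t1):
    # Save at t1 twice: once via t1=True and once via ts=[t1]
    saveat = dfx.SaveAt(t1=True, ts=[t1])
    # Mark t=0 and t=t1 as jump times for the step-size controller
    ssc = dfx.PIDController(atol=1e-6, rtol=1e-6, jump_ts=jnp.array([0., t1]))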

    res = dfx.diffeqsolve(
        dfx.ODETerm(A),
        solver=solver,
        y0=y0,
        t0=0.,
        t1=t1,
        dt0=0.01,
        stepsize_controller=ssc,
        saveat=saveat
    )
    return res.ts

for t1 in [1e0, 1e3]:
    print(t1, f(t1))


# 1.0 [1. 1.]
# 1000.0 [1000.   inf]

The expected output is [t1, t1] for every t1. However, for some values of t1 the saved times are instead [t1, inf]. The behavior seems to depend on the magnitude of t1, so I suspect a floating-point bug along the lines of jnp.nextafter.
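
The magnitude dependence is at least consistent with float32 spacing: the gap between t1 and the next representable float32 grows with t1. The sketch below is illustrative only. It does not reproduce the issue and makes no claim about where (or whether) diffrax uses nextafter internally. It just prints that gap for the two values of t1 used in the reproduction.

```python
import jax.numpy as jnp

# Illustrative only: this does not reproduce the issue. It shows that the gap
# between a float32 value and the next representable float32 grows with magnitude.
gaps = {}
for t1 in [1e0, 1e3]:
    t = jnp.asarray(t1, dtype=jnp.float32)
    # Next float32 above t, moving towards +inf
    up = jnp.nextafter(t, jnp.asarray(jnp.inf, dtype=jnp.float32))
    gaps[t1] = float(up - t)
    print(t1, gaps[t1])
```

For t1 = 1.0 the gap is 2**-23 (about 1.2e-7); for t1 = 1000.0 it is 2**-14 (about 6.1e-5). This only shows the magnitude effect; it does not confirm that nextafter is the cause.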

This occurs on HEAD as well as in the latest release.

Might be related to #607.
