The logging handler `chirho.dynamical.handlers.LogTrajectory` does not correctly handle the edge case where the `start_time` of a simulation interval (including an interruption time) coincides with one of the logging times.
Minimal failing example:
```python
import torch

from chirho.dynamical.ops import State, simulate
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.handlers.trajectory import LogTrajectory


def test_mwe_fail():
    # Logistic growth dynamics: dX/dt = X * (1 - X)
    dynamics = lambda s: State(X=s["X"] * (1 - s["X"]))
    init_state = State(X=torch.tensor(0.5))
    start_time, end_time = torch.tensor(0.0), torch.tensor(3.0)

    # Log at t = 0, 1, 2; the first logging time equals start_time.
    with TorchDiffEq(), LogTrajectory(times=torch.tensor([0.0, 1.0, 2.0])) as log:
        final_state = simulate(dynamics, init_state, start_time, end_time)

    # Fails because len(log.trajectory["X"]) == 2
    assert len(log.trajectory["X"]) == len(log.times) == 3
```
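To make the boundary condition concrete, here is a minimal, library-independent sketch of how logging times might be assigned to consecutive simulation segments. It is not ChiRho's implementation: the helpers `times_in_segment` and `log_across_segments` are hypothetical, and it assumes each segment (from `start_time`, through any interruption times, to `end_time`) is treated as half-open `[start, end)`. Under that assumption a logging time equal to a segment's start is kept, and one equal to an interruption time is logged exactly once. The `include_start=False` variant uses a strict `start < t` comparison, which reproduces the 2-instead-of-3 count seen in the failing test.

```python
from typing import List, Sequence


def times_in_segment(
    times: Sequence[float], start: float, end: float, include_start: bool = True
) -> List[float]:
    """Hypothetical helper: logging times that fall inside one simulation segment.

    include_start=True  -> segment treated as half-open [start, end)
    include_start=False -> segment treated as open (start, end), which silently
                           drops a logging time equal to start
    """
    if include_start:
        return [t for t in times if start <= t < end]
    return [t for t in times if start < t < end]


def log_across_segments(
    times: Sequence[float], boundaries: Sequence[float], include_start: bool = True
) -> List[float]:
    """Hypothetical helper: collect logging times over consecutive segments.

    `boundaries` is [start_time, *interruption_times, end_time].
    Note: with half-open segments, a logging time equal to end_time is not logged.
    """
    logged: List[float] = []
    for seg_start, seg_end in zip(boundaries[:-1], boundaries[1:]):
        logged.extend(times_in_segment(times, seg_start, seg_end, include_start))
    return logged


times = [0.0, 1.0, 2.0]
print(log_across_segments(times, [0.0, 3.0]))                       # [0.0, 1.0, 2.0]
print(log_across_segments(times, [0.0, 3.0], include_start=False))  # [1.0, 2.0]
print(log_across_segments(times, [0.0, 1.0, 3.0]))                  # [0.0, 1.0, 2.0]
```

The third call adds an interruption at `t = 1`: with half-open segments, `1.0` is picked up once, by the segment that starts there.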