
LogTrajectory does not handle edge cases correctly #396

@eb8680

Description


The logging handler `chirho.dynamical.handlers.LogTrajectory` mishandles the edge case in which the `start_time` of a simulation interval (including an interruption time) coincides with one of the logging times. In that case the logged trajectory ends up with fewer entries than `times`.

Minimal failing example:

```python
def test_mwe_fail():
    import torch

    from chirho.dynamical.ops import State, simulate
    from chirho.dynamical.handlers.solver import TorchDiffEq
    from chirho.dynamical.handlers.trajectory import LogTrajectory

    # Logistic growth: dX/dt = X * (1 - X)
    dynamics = lambda s: State(X=s["X"] * (1 - s["X"]))
    init_state = State(X=torch.tensor(0.5))
    start_time, end_time = torch.tensor(0.), torch.tensor(3.)

    # The first logging time (0.) coincides with start_time.
    with TorchDiffEq(), LogTrajectory(times=torch.tensor([0., 1., 2.])) as log:
        final_state = simulate(dynamics, init_state, start_time, end_time)

    assert len(log.trajectory["X"]) == len(log.times) == 3  # fails because len(X) == 2
```
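The expected behaviour can be stated as an invariant: when a simulation over `[start_time, end_time]` is split into segments at interruption times, each logging time should be recorded in exactly one segment, including a time equal to a segment's start. The sketch below illustrates that invariant in plain Python. `partition_log_times` is a hypothetical helper, not part of ChiRho's API, and the half-open convention (`[start, end)` for every segment, with the last one also closed on the right) is an assumption about what a fix could look like.

```python
from typing import List, Sequence


def partition_log_times(
    times: Sequence[float], boundaries: Sequence[float]
) -> List[List[float]]:
    """Hypothetical helper (not ChiRho API): assign each logging time to one segment.

    boundaries = [start_time, interruption_1, ..., end_time]. Segment i covers
    [boundaries[i], boundaries[i + 1]); the final segment also includes end_time.
    """
    n_segments = len(boundaries) - 1
    segments: List[List[float]] = []
    for i in range(n_segments):
        lo, hi = boundaries[i], boundaries[i + 1]
        is_last = i == n_segments - 1
        # `lo <= t` (not `lo < t`) keeps a logging time equal to the segment start,
        # while `t < hi` stops a boundary time from being logged twice.
        in_segment = [
            t for t in times if lo <= t and (t < hi or (is_last and t == hi))
        ]
        segments.append(in_segment)
    return segments


# Setup from the failing example: no interruptions, start_time = 0.0.
print(partition_log_times([0.0, 1.0, 2.0], [0.0, 3.0]))  # [[0.0, 1.0, 2.0]]
# Hypothetical interruption at t = 1.0, which is itself a logging time.
print(partition_log_times([0.0, 1.0, 2.0], [0.0, 1.0, 3.0]))  # [[0.0], [1.0, 2.0]]
```

Replacing `lo <= t` with `lo < t` in this sketch drops 0.0 from the first segment and leaves only two logged times, which is consistent with the `len(X) == 2` symptom above; whether that is the actual cause inside `LogTrajectory` is not established here.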
