Fix time collision bug in LogTrajectory#397
Conversation
…into sw-time-collision
…into sw-time-collision
| @@ -27,6 +31,19 @@ def __enter__(self) -> "LogTrajectory[T]": | |||
| self.trajectory: State[T] = State() | |||
There was a problem hiding this comment.
@eb8680 , do you think it makes sense to move this line into _pyro_simulate instead of __enter__. This would change the behavior from "concatenate multiple simulate calls trajectories together" to "store only the trajectory of the final simulate call". Or alternatively, we could give simulate an optional name argument if you want to store multiple trajectories.
There was a problem hiding this comment.
Yes, re-initializing self.trajectory in _pyro_simulate seems more correct than the current behavior.
eb8680
left a comment
There was a problem hiding this comment.
Can you add a test that checks the case where a StaticInterruption's time coincides with a logging time, and also rerun the dynamical_intro notebook from scratch to verify that it no longer fails because of this bug?
I believe this existing test covers the test case? https://github.com/BasisResearch/chirho/blob/sw-time-collision/tests/dynamical/test_log_trajectory.py#L42 Happy to add more if that doesn't cover what you're looking for. I'll rerun the notebook in this PR. Would you like me to make the change we discussed in this comment in this PR or separately?
|
…into sw-time-collision
…causal_pyro into sw-time-collision
…plied after all solves
|
@eb8680 , I edited this PR to address my suggestion here: #397 (comment). It is now ready for review again. I have also strengthened the tests a bit. To ensure that handlers commute after this PR we now have the following strict order enforced regardless of the order handlers are applied: I reran the dynamical system notebook, confirming that it works with these changes (although, requiring a minor change to the name of a trace address in the final cell because of a previous PR). I'm excluding the revised run from this PR to avoid merge conflicts with #377, which I expect to be completed by EOD or early tomorrow. Now that #308 is merged, I'll add a follow up PR after #377 to add the dynamical systems notebook to CI. |
eb8680
left a comment
There was a problem hiding this comment.
Making sure LogTrajectory, Solver and BatchObservation commute is the right idea, but it seems like there's a simpler way to achieve that:
|
|
||
| def _pyro_post_simulate(self, msg: dict) -> None: | ||
| self.trajectory = observe(self.trajectory, self.observation) | ||
| def _pyro_simulate(self, msg: dict) -> None: |
There was a problem hiding this comment.
I don't think you should need to change StaticBatchObservation at all, other than adding a single super()._pyro_post_simulate call:
def _pyro_post_simulate(self, msg: dict) -> None:
super()._pyro_post_simulate(msg) # update self.trajectory
self.trajectory = observe(self.trajectory, self.observation)
| # LogTrajectory's simulate_point will log only timepoints that are greater than the start_time of each | ||
| # simulate_point call, which can occur multiple times in a single simulate call when there | ||
| # are interruptions. | ||
| self.trajectory = append( |
There was a problem hiding this comment.
I think this change is right, and is already enough to make LogTrajectory and Solver commutative. In particular, since append is associative and Solver doesn't change initial_state or start_time in its _pyro_simulate method, as long as no outside code needs to access self.trajectory before the simulate call finishes, it shouldn't matter if LogTrajectory's _pyro_simulate runs before or after Solver's.
chirho/dynamical/internals/solver.py
Outdated
|
|
||
| @typing.final | ||
| @staticmethod | ||
| def _pyro_post_simulate(msg: dict) -> None: |
There was a problem hiding this comment.
This move from _pyro_simulate to _pyro_post_simulate seems unnecessary, per my comment on LogTrajectory.
I agree. The key challenge here is actually in the assignment of I'll try and find a workaround to this problem that's a bit less verbose and touches fewer pieces. |
Hmm, what about adding an extra boolean state variable to def __init__(self, ...):
self._needs_reset: bool = False
self.trajectory = State()
...
def __enter__(self):
self._needs_reset = False
self.trajectory = State()
...
def _pyro_simulate(self, msg: dict) -> None:
if self._needs_reset:
self.trajectory = State()
self._needs_reset = False
...
def _pyro_post_simulate(self, msg: dict) -> None:
if not self._needs_reset:
self._needs_reset = True
... |
|
@eb8680 , I've made this PR a bit simpler, but in the process it exposed an error that was occurring with a test using The (now skipped) test fails because |
My solution (that I finished before seeing this) is in the spirit of this suggestion, but a bit less explicit and a bit more concise. I'm really happy either way. |
eb8680
left a comment
There was a problem hiding this comment.
LGTM, thanks for fixing this!
This PR addresses #396 by adding the initial state to the trajectory if the
start_timeofsimulateis equal to the first element of the (sorted)logging_timesargument of theLogTrajectoryhandler. This works because theLogTrajectoryhandler excludes time collisions on thestart_timeand includes time collisions on theend_timeof eachsimulate_pointcall, which was necessary to not "double include" interruption times that collided with an element in thelogging_timesand thus show up as thestart_timeandend_timeof twosimulate_pointcalls.In addition, this PR adds a slight modification of the test described in #396 to the test suite.