
Conversation

@falletta (Contributor) commented Feb 4, 2026

By making the trajectory reporter optional, we avoid additional computational overhead (see profiling plot in this PR).

Comment on lines 819 to 830
else:
    # Collect base properties for each system when no reporter is attached
    for sys_idx in range(sub_state.n_systems):
        atom_mask = sub_state.system_idx == sys_idx
        base_props: dict[str, torch.Tensor] = {
            "potential_energy": static_state.energy[sys_idx],
        }
        if model.compute_forces:
            base_props["forces"] = static_state.forces[atom_mask]
        if model.compute_stress:
            base_props["stress"] = static_state.stress[sys_idx]
        all_props.append(base_props)
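For context, the mask-based collection in the excerpt above can be exercised on its own; a minimal sketch, where the tensors below are illustrative stand-ins for `sub_state.system_idx` and the fields of `static_state`:

```python
import torch

# Stand-in for sub_state.system_idx: 2 systems with 3 and 2 atoms
system_idx = torch.tensor([0, 0, 0, 1, 1])
energy = torch.tensor([-1.0, -2.0])  # per-system potential energy
forces = torch.randn(5, 3)           # per-atom forces

all_props = []
for sys_idx in range(2):
    atom_mask = system_idx == sys_idx  # boolean mask for this system's atoms
    all_props.append({
        "potential_energy": energy[sys_idx],
        "forces": forces[atom_mask],
    })

print([p["forces"].shape[0] for p in all_props])  # → [3, 2]
```

Each dict in `all_props` then holds one system's properties, mirroring what the reporter would have produced per frame.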
Collaborator


Are you sure this is faster?

The profiling data doesn't really say anything about how long the report call takes, and this isn't benchmarked against the current code.

Contributor Author

@falletta falletta Feb 5, 2026


Yes, I confirm the speedup. Below are the results from running the run_torchsim_static function in 8.scaling.py (see the static PR) before and after the fix. We observe a speedup of up to 24.3%, which grows with system size. In addition, I verified that the cost associated with the trajectory reporter disappears from the profiling analysis.

Previous results:

=== Static benchmark ===
  n=1 static_time=1.928943s
  n=1 static_time=0.272846s
  n=1 static_time=0.683335s
  n=1 static_time=0.026675s
  n=10 static_time=0.281990s
  n=100 static_time=0.705871s
  n=500 static_time=1.510273s
  n=1000 static_time=1.528872s
  n=2500 static_time=3.809000s
  n=5000 static_time=7.890238s

New results:

  n=1 static_time=2.165601s
  n=1 static_time=0.271961s
  n=1 static_time=0.665016s
  n=1 static_time=0.022899s
  n=10 static_time=0.295468s
  n=100 static_time=0.692651s
  n=500 static_time=1.411905s
  n=1000 static_time=1.291772s -> 18.3% speedup
  n=2500 static_time=3.175455s -> 19.9% speedup
  n=5000 static_time=6.348887s -> 24.3% speedup
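For reproducibility, a timing harness along these lines can produce such numbers; `time_static` is a hypothetical helper (the real measurements above come from 8.scaling.py), and the CUDA synchronization is needed so GPU kernels are actually counted:

```python
import time
import torch

def time_static(fn, *args, warmup=1, repeats=3):
    """Time a callable, synchronizing CUDA (if available) so GPU work is counted."""
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(repeats):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return min(times)

# Trivial workload standing in for a run_torchsim_static call
elapsed = time_static(lambda: torch.randn(100, 100) @ torch.randn(100, 100))
print(f"static_time={elapsed:.6f}s")
```

Taking the minimum over repeats reduces noise from other processes; warmup runs absorb one-time costs such as kernel compilation, which is likely why the n=1 rows above vary so much.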

Collaborator


I'm confused: is the 24% speedup here coming just from removing the report call, or from all of the changes in the other PR too? I'm curious about the effect of the report call alone; I just want to make sure this specific optimization is well-founded. I'm not opposed to making this change.

@orionarcher
Collaborator

Upon further review, it seems the slowdown from calling report comes from the state.split() call inside report(), which doesn't need to run when there are no files to write. Rather than bypassing the reporter in static and duplicating property extraction in _collect_base_properties, what about adding a fast path inside report() itself?

When self.filenames is None, we can return early by running the prop_calculators on the batched state and splitting the result tensors by n_atoms_per_system, avoiding N SimState constructions entirely. That way static doesn't need to change, and we don't end up with two property-extraction code paths. Something like this (from CC):

  def report(self, state, step, model=None):
      if self.filenames is None:
          return self._extract_props_batched(state, step, model)

      # ... existing split-based logic unchanged ...

  def _extract_props_batched(self, state, step, model):
      sizes = state.n_atoms_per_system.tolist()
      n_sys = state.n_systems
      all_props: list[dict[str, torch.Tensor]] = [{} for _ in range(n_sys)]

      for frequency, calculators in self.prop_calculators.items():
          if frequency == 0:
              continue
          for prop_name, prop_fn in calculators.items():
              result = prop_fn(state, model)
              if result.dim() == 0:
                  result = result.unsqueeze(0)

              # infer per-atom vs per-system from shape
              if result.shape[0] == state.n_atoms:
                  splits = torch.split(result, sizes)
              else:
                  splits = [result[i] for i in range(n_sys)]

              for idx in range(n_sys):
                  sys_step = step[idx] if isinstance(step, list) else step
                  if sys_step % frequency == 0:
                      all_props[idx][prop_name] = splits[idx]

      return all_props
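The shape-based dispatch in the sketch (per-atom results split by atom counts, per-system results indexed directly) can be checked in isolation; the sizes and tensors below are made up:

```python
import torch

sizes = [3, 2]                        # stand-in for state.n_atoms_per_system
n_atoms, n_sys = sum(sizes), len(sizes)

per_atom = torch.arange(n_atoms * 3, dtype=torch.float32).reshape(n_atoms, 3)
per_system = torch.tensor([10.0, 20.0])

# Per-atom result: first dim matches n_atoms, so split by per-system sizes
splits = torch.split(per_atom, sizes)
print([s.shape[0] for s in splits])   # → [3, 2]

# Per-system result: first dim matches n_systems, so index directly
print(per_system[1].item())           # → 20.0
```

torch.split returns views, so this path avoids both the N SimState constructions and any tensor copies; the only ambiguity is a system count that happens to equal the atom count, which the real code would need to disambiguate.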
