Made trajectory reporter optional in static function #441
base: main
Conversation
torch_sim/runners.py
Outdated
```python
else:
    # Collect base properties for each system when no reporter is attached
    for sys_idx in range(sub_state.n_systems):
        atom_mask = sub_state.system_idx == sys_idx
        base_props: dict[str, torch.Tensor] = {
            "potential_energy": static_state.energy[sys_idx],
        }
        if model.compute_forces:
            base_props["forces"] = static_state.forces[atom_mask]
        if model.compute_stress:
            base_props["stress"] = static_state.stress[sys_idx]
        all_props.append(base_props)
```
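For context, the branch above relies on boolean masking of flat per-atom tensors by `system_idx`. Here is a minimal, standalone sketch of that pattern with made-up shapes; it does not use the actual `SimState`/`StaticState` objects:

```python
import torch

# Hypothetical flat batch: 3 systems with 2, 3, and 1 atoms respectively
system_idx = torch.tensor([0, 0, 1, 1, 1, 2])  # per-atom system labels
forces = torch.randn(6, 3)                     # per-atom property, shape (n_atoms, 3)
energy = torch.randn(3)                        # per-system property, shape (n_systems,)

all_props = []
for sys_idx in range(3):
    atom_mask = system_idx == sys_idx          # boolean mask over atoms
    all_props.append(
        {
            "potential_energy": energy[sys_idx],  # scalar for this system
            "forces": forces[atom_mask],          # (n_atoms_in_system, 3)
        }
    )

print([p["forces"].shape for p in all_props])
# [torch.Size([2, 3]), torch.Size([3, 3]), torch.Size([1, 3])]
```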
Are you sure this is faster?
The profiling data doesn't really say anything about how long the report call takes, and this isn't benchmarked against the current code.
Yes, I confirm the speedup. Below are the results from running the `run_torchsim_static` function in `8.scaling.py` (see the static PR) before and after the fix. We observe a speedup of up to 24.3%, which grows with system size. In addition, I verified that the cost associated with the trajectory reporter disappears from the profiling analysis.
Previous results:

```
=== Static benchmark ===
n=1 static_time=1.928943s
n=1 static_time=0.272846s
n=1 static_time=0.683335s
n=1 static_time=0.026675s
n=10 static_time=0.281990s
n=100 static_time=0.705871s
n=500 static_time=1.510273s
n=1000 static_time=1.528872s
n=2500 static_time=3.809000s
n=5000 static_time=7.890238s
```
New results:

```
n=1 static_time=2.165601s
n=1 static_time=0.271961s
n=1 static_time=0.665016s
n=1 static_time=0.022899s
n=10 static_time=0.295468s
n=100 static_time=0.692651s
n=500 static_time=1.411905s
n=1000 static_time=1.291772s -> 18.3% speedup
n=2500 static_time=3.175455s -> 19.9% speedup
n=5000 static_time=6.348887s -> 24.3% speedup
```
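For reference, the timing pattern behind numbers like these is just a wall-clock wrapper around the static call. The sketch below is a generic illustration, not the actual `8.scaling.py` script; `run_static` and `systems_by_size` are placeholders:

```python
import time

def benchmark_static(run_static, systems_by_size):
    """Placeholder harness: time one static evaluation per batch size."""
    print("=== Static benchmark ===")
    for n, systems in systems_by_size.items():
        start = time.perf_counter()
        run_static(systems)  # e.g. a call into the static runner under test
        # On GPU, a torch.cuda.synchronize() here would make the timing exact
        elapsed = time.perf_counter() - start
        print(f"n={n} static_time={elapsed:.6f}s")
```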
I'm confused: is the 24% speedup here coming just from removing the `report` call, or from all of the changes in the other PR too? I'm curious about the effect the `report` call alone has; I just want to make sure that specific optimizations are well-founded. I'm not opposed to making this change.
Upon further review, it seems the slowdown stems from the `report` call itself. When `self.filenames` is None, the reporter could bypass the existing split-based logic and extract properties for all systems in a single batched pass, for example:

```python
def report(self, state, step, model=None):
    if self.filenames is None:
        return self._extract_props_batched(state, step, model)
    # ... existing split-based logic unchanged ...

def _extract_props_batched(self, state, step, model):
    sizes = state.n_atoms_per_system.tolist()
    n_sys = state.n_systems
    all_props: list[dict[str, torch.Tensor]] = [{} for _ in range(n_sys)]
    for frequency, calculators in self.prop_calculators.items():
        if frequency == 0:
            continue
        for prop_name, prop_fn in calculators.items():
            result = prop_fn(state, model)
            if result.dim() == 0:
                result = result.unsqueeze(0)
            # Infer per-atom vs. per-system properties from the leading dimension
            if result.shape[0] == state.n_atoms:
                splits = torch.split(result, sizes)
            else:
                splits = [result[i] for i in range(n_sys)]
            for idx in range(n_sys):
                sys_step = step[idx] if isinstance(step, list) else step
                if sys_step % frequency == 0:
                    all_props[idx][prop_name] = splits[idx]
    return all_props
```
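To make the shape-based dispatch concrete, here is a small standalone illustration with dummy tensors (no actual `TrajectoryReporter` involved): a flat per-atom result is split with `torch.split`, while a per-system result is simply indexed.

```python
import torch

sizes = [2, 3, 1]                    # atoms per system
n_atoms, n_sys = sum(sizes), len(sizes)

per_atom = torch.randn(n_atoms, 3)   # e.g. forces: leading dim == n_atoms
per_system = torch.randn(n_sys)      # e.g. potential energy: one value per system

def split_result(result):
    # Mirrors the leading-dimension check in _extract_props_batched above
    if result.shape[0] == n_atoms:
        return list(torch.split(result, sizes))
    return [result[i] for i in range(n_sys)]

print([c.shape for c in split_result(per_atom)])
# [torch.Size([2, 3]), torch.Size([3, 3]), torch.Size([1, 3])]
print([c.shape for c in split_result(per_system)])
# [torch.Size([]), torch.Size([]), torch.Size([])]
```

One caveat of inferring the property type from the leading dimension: if a batch ever had `n_atoms == n_systems`, the two cases would be ambiguous, so the heuristic implicitly assumes that never happens.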
By making the trajectory reporter optional, we avoid the additional computational overhead it introduces (see the profiling plot in this PR).