
feat(torch.compile): partial PET application #1029

Draft
HaoZeke wants to merge 1 commit into metatensor:main from HaoZeke:petPartialCompile

Conversation

@HaoZeke (Member) commented Feb 2, 2026

Bare-minimal changes to "use" torch.compile. @sirmarcel has a much nicer UPET port, which can be both compiled and exported.

Of course, the main blocker is that metatomic needs TorchScript at the moment.

Not something to be merged, just discussed.
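
For context, a minimal, self-contained sketch of the "partial application" idea on a toy module. The class, layer sizes, and `gnn_layers` attribute are illustrative stand-ins, not the actual PET code:

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in for a model whose per-layer blocks dominate the runtime."""

    def __init__(self, width: int = 32, n_layers: int = 3):
        super().__init__()
        self.gnn_layers = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.SiLU())
            for _ in range(n_layers)
        )
        self.readout = nn.Linear(width, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.gnn_layers:
            x = layer(x)
        return self.readout(x)


model = ToyModel()

# Partial application: only the hot submodules are swapped for compiled
# counterparts; the outer forward() stays plain eager Python.
for i, layer in enumerate(model.gnn_layers):
    model.gnn_layers[i] = torch.compile(layer, mode="reduce-overhead")

out = model(torch.randn(8, 32))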

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Maintainer/Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?
  • GPU tests passed (maintainer comment: "cscs-ci run")?

📚 Documentation preview 📚: https://metatrain--1029.org.readthedocs.build/en/1029/

Co-authored-by: sirmarcel <sirmarcel@users.noreply.github.com>
self._compile_enabled = self.hypers.get("compile", False)
self._compiled = False

def _maybe_compile(self) -> None:
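
For illustration, a rough sketch of how this flag and hook could fit together, assuming `_maybe_compile` is invoked at the start of `forward`; the surrounding class is hypothetical, only the attribute names mirror the diff, and the body of `_maybe_compile` reuses the loop from the second hunk below:

import torch
import torch.nn as nn


class LazyCompiled(nn.Module):
    # Hypothetical container: `hypers` and `gnn_layers` mirror the names in
    # the diff, but this is not the actual PET module.
    def __init__(self, hypers: dict):
        super().__init__()
        self.hypers = hypers
        self.gnn_layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(2))
        self._compile_enabled = self.hypers.get("compile", False)
        self._compiled = False

    def _maybe_compile(self) -> None:
        # Compile lazily, on the first forward call, and only once.
        if self._compile_enabled and not self._compiled:
            for i, layer in enumerate(self.gnn_layers):
                self.gnn_layers[i] = torch.compile(layer, mode="reduce-overhead")
            self._compiled = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._maybe_compile()
        for layer in self.gnn_layers:
            x = layer(x)
        return x


m = LazyCompiled({"compile": True})
y = m(torch.randn(4, 16))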
Member


So this does not do a full torch.compile of the model, only in-place replacement of some modules, correct?

Do you know if full-model compilation can be done?

Member Author


> So this does not do a full torch.compile of the model, only in-place replacement of some modules, correct?

Yup.

> Do you know if full-model compilation can be done?

Not without a rewrite (@sirmarcel has one).
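
For comparison, full-model compilation and export on a toy module would look roughly like this (illustrative only, not the PET interface; torch.export needs a reasonably recent PyTorch):

import torch
import torch.nn as nn
from torch.export import export

model = nn.Sequential(nn.Linear(32, 32), nn.SiLU(), nn.Linear(32, 1))
example = torch.randn(8, 32)

# Whole-model compilation: one Dynamo graph over the entire forward pass.
compiled = torch.compile(model, fullgraph=True)
out = compiled(example)

# torch.export captures a single exportable graph, which is what an
# AOT/export path would consume instead of TorchScript.
exported = export(model, (example,))
out2 = exported.module()(example)

The hunk below, by contrast, only swaps out the per-layer GNN modules in place.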

if self._compile_enabled and not self._compiled:
    # Compile the GNN layers - this is the computational bottleneck
    for i, layer in enumerate(self.gnn_layers):
        self.gnn_layers[i] = torch.compile(layer, mode="reduce-overhead")
Member


Does this work with fullgraph=True? That is what we would need for torch.export down the line.
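
One way to check, sketched on a stand-in layer; whether the actual PET layers trace without graph breaks is exactly the open question:

import torch
import torch.nn as nn

layer = nn.Sequential(nn.Linear(32, 32), nn.SiLU())  # stand-in for one GNN layer

# With fullgraph=True, Dynamo raises on any graph break instead of silently
# falling back to eager, so a single call works as a smoke test.
compiled_layer = torch.compile(layer, fullgraph=True, mode="reduce-overhead")
try:
    compiled_layer(torch.randn(8, 32))
    print("captured as a single graph")
except Exception as err:  # e.g. torch._dynamo.exc.Unsupported
    print(f"graph break or unsupported op: {err}")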

