Add hook stacks to handle shared modules#54

Open
aduranu wants to merge 12 commits into jmschrei:main from aduranu:add-hook-stacks
Conversation


@aduranu commented Jan 16, 2026

The current version in tangermeme/deep_lift_shap.py stores inputs/outputs for a module by simple assignment:

module.input = input[0].clone().detach()
module.output = outputs.clone().detach()

Issue: if a model reuses the same module multiple times in the forward pass, the hooks overwrite the stored input/output each time the module is called, so the backward pass uses the same (last-seen) input/output pair for every call. When the shared calls happen to have matching tensor shapes this fails silently; otherwise it fails loudly with a shape-mismatch error.

Fix: Implement stacks (LIFO) instead of simple assignment:

  • Forward pass: push inputs/outputs onto stacks
  • Backward pass: pop from stacks (LIFO order naturally matches backward traversal)

Examples:

Silent failure:

  class SharedReluModel(torch.nn.Module):
      def __init__(self):
          super().__init__()  # required before registering submodules
          self.shared_relu = torch.nn.ReLU()
          self.linear1 = torch.nn.Linear(10, 10)  # Same shape in/out
          self.linear2 = torch.nn.Linear(10, 10)  # Same shape in/out

      def forward(self, X):
          X = self.shared_relu(self.linear1(X))  # 1st call: input shape [B, 10]
          X = self.shared_relu(self.linear2(X))  # 2nd call: input shape [B, 10]
          return X

Loud failure:

  class SharedReluDifferentShapes(torch.nn.Module):
      def __init__(self):
          super().__init__()  # required before registering submodules
          self.shared_relu = torch.nn.ReLU()
          self.linear1 = torch.nn.Linear(10, 20)  # 10 → 20
          self.linear2 = torch.nn.Linear(20, 5)   # 20 → 5

      def forward(self, X):
          X = self.shared_relu(self.linear1(X))  # 1st call: shape [B, 20]
          X = self.shared_relu(self.linear2(X))  # 2nd call: shape [B, 5]
          return X

The 2nd call stores [B, 5]. During the backward pass, when processing the 1st forward call, the hook expects shape [B, 20] but finds [B, 5] → tensor shape mismatch error.
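For reference, the overwrite itself can be reproduced outside tangermeme with a simple-assignment forward hook (a hypothetical minimal repro, not the library's code):

```python
import torch

def overwrite_hook(module, inputs, outputs):
    # Simple assignment, as in the current deep_lift_shap.py: each call
    # to the shared module clobbers what the previous call stored.
    module.input = inputs[0].clone().detach()
    module.output = outputs.clone().detach()

relu = torch.nn.ReLU()
relu.register_forward_hook(overwrite_hook)

relu(torch.randn(2, 20))   # stands in for linear1's output, shape [B, 20]
relu(torch.randn(2, 5))    # stands in for linear2's output, shape [B, 5]

# Only the last pair survives, so a backward pass would see [2, 5]
# where the first call needs [2, 20]:
print(relu.input.shape)    # torch.Size([2, 5])
```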

@jmschrei
Owner

Very clever -- I love it. And thanks for adding in tests.

It seems like this assumes the backward ordering of operations is always the reverse of the forward ordering, right? Is that always true for non-sequential architectures?

@jmschrei
Owner

Let me know when you're done making modifications. And would you mind doing a timing on a moderately sized model, e.g. ChromBPNet or Enformer, to make sure that the id checking doesn't significantly slow down the code?

Author

@aduranu commented Mar 2, 2026

Sorry for the delay here - logging on Enformer revealed not a timing issue but a memory issue, and I changed the approach. With the id(grad) keying, every tensor IO pair stayed alive and racked up GPU memory, because the closures were pinned until autograd.grad() completed.

This version instead stores IO pairs in a module-level dict (module._io_pairs) keyed by a forward counter. The tensor-hook closure now captures only the integer index, so it pins no GPU memory. _b_hook pops the entry from the dictionary during the backward pass, freeing memory as it goes instead of holding everything until the end of the pass.

On Enformer (~196k bp, 10 shuffles, track 2844, NVIDIA H100 NVL) this version is only a bit slower (by ~0.1 s), and peak memory usage is essentially the same as main:

  • main, batch size = 2: 6.238 s, 38.39 GB peak
  • add-hook-stacks, batch size = 2: 6.231 s, 38.38 GB peak
  • main, batch size = 4: 6.155 s, 75.58 GB peak
  • add-hook-stacks, batch size = 4: 6.224 s, 75.34 GB peak
