Add hook stacks to handle shared modules #54
Conversation
Very clever -- I love it. And thanks for adding in tests. It seems like this assumes the backward ordering of operations is always the reverse of the forward ordering, right? Is that always true for non-sequential architectures?
Let me know when you're done making modifications. And would you mind doing a timing on a moderately sized model, e.g. ChromBPNet or Enformer, to make sure that the id checking doesn't significantly slow down the code?
Sorry for the delay here - logging on Enformer revealed not a timing issue but a memory issue, and I changed the approach. This version instead stores IO pairs in a module-level dict. On Enformer (~196k bp, 10 shuffles, track 2844, NVIDIA H100 NVL) this version is only a bit slower (by 0.1s), but the memory usage is basically the same as main:
The current version in `tangermeme/deep_lift_shap.py` stores inputs/outputs for a module by simple assignment.

Issue: if a model reuses the same module multiple times in the forward pass, the hooks overwrite the input/output values each time the module is called, so the backward pass uses the same (last-seen) input/output pair for every call to that module. If the shared module's calls have matching tensor shapes, it fails silently (and fails loudly otherwise).
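A hypothetical, stripped-down sketch of the overwrite problem (illustrative only; `forward_hook`, `captured`, and the `"dense"` module id are made-up names, not the actual tangermeme code):

```python
# Hypothetical simplification of the assignment-based hooks
# (illustrative only, not the actual tangermeme code).
captured = {}

def forward_hook(module_id, inputs, outputs):
    # Simple assignment: a second call to the same module
    # silently overwrites the first call's IO pair.
    captured[module_id] = (inputs, outputs)

# The same shared module is called twice in one forward pass:
forward_hook("dense", ("x1",), "y1")   # 1st call
forward_hook("dense", ("x2",), "y2")   # 2nd call overwrites the 1st

# Only the last-seen pair survives for the backward pass.
print(captured["dense"])
```

Both backward steps would then be computed against the second call's IO pair.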
Fix: Implement stacks (LIFO) instead of simple assignment:
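A minimal sketch of the stack idea (assumed names, not the actual tangermeme implementation): each module id maps to a list used as a LIFO, pushed in forward hooks and popped in backward hooks.

```python
# Hypothetical sketch of the stack-based fix (not the actual
# tangermeme implementation): each module id maps to a LIFO list.
from collections import defaultdict

io_stacks = defaultdict(list)

def forward_hook(module_id, inputs, outputs):
    # Push: every forward call appends its own IO pair,
    # so repeated calls to a shared module no longer collide.
    io_stacks[module_id].append((inputs, outputs))

def backward_hook(module_id):
    # Pop: the backward pass visits a module's calls in reverse
    # forward order, so popping pairs each backward step with
    # the IO from the matching forward call.
    return io_stacks[module_id].pop()

# Shared module called twice during the forward pass:
forward_hook("dense", ("x1",), "y1")
forward_hook("dense", ("x2",), "y2")

# The backward pass retrieves the pairs in reverse order:
second = backward_hook("dense")   # (("x2",), "y2")
first = backward_hook("dense")    # (("x1",), "y1")
```

This is also why the fix depends on backward ordering being the reverse of forward ordering, as raised above.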
Examples:
Silent failure:
Loud failure:
The 2nd call stores [B, 5]. During the backward pass, when processing the 1st forward call, it expects shape [B, 20] but finds [B, 5] → tensor shape mismatch error.
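The loud-failure case can be sketched with shapes alone (shapes written as tuples, `B` = batch size; a made-up illustration, not the actual tangermeme code):

```python
# Hypothetical shape-only sketch of the loud failure (shapes as
# tuples, "B" = batch size; not the actual tangermeme code).
B = 8
captured = {}

# Forward: the same module sees two calls with different widths.
captured["dense"] = (B, 20)   # 1st call stores shape [B, 20]
captured["dense"] = (B, 5)    # 2nd call overwrites it with [B, 5]

# Backward: processing the 1st forward call expects [B, 20] ...
expected = (B, 20)
found = captured["dense"]     # ... but finds [B, 5]

if found != expected:
    # In the real model this surfaces as a tensor shape
    # mismatch error inside the backward computation.
    print(f"shape mismatch: expected {expected}, found {found}")
```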