Skip to content

Pass additional kwargs to model call function#44

Open
jvierstra wants to merge 1 commit intojmschrei:mainfrom
vierstralab:main
Open

Pass additional kwargs to model call function#44
jvierstra wants to merge 1 commit intojmschrei:mainfrom
vierstralab:main

Conversation

@jvierstra
Copy link
Copy Markdown

@jvierstra jvierstra commented Aug 2, 2025

Allow passing of additional keyword arguments to model forward (call) function in deep_lift_shap. This allows one to selectively modify outputs for interpretation (such as applying an exponential transform on prediction logins) and properly compute contributions considering non-linearities.

@jmschrei
Copy link
Copy Markdown
Owner

jmschrei commented Aug 2, 2025

I'm not sure I understand -- can't you use the args param for that?

If you want to modify the outputs of a model before running deep_lift_shap on it I'd recommend creating a short wrapper. It's a few extra lines at the beginning but way cleaner and more flexible, e.g.

from bpnet.chrombpnet import _Exp


class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.exp = _Exp()

    def forward(self, X):
        return self.exp(self.model(X))

where you use _Exp from https://github.com/jmschrei/bpnet-lite/blob/master/bpnetlite/chrombpnet.py#L20 and then also register it when you use DeepLiftShap.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants