Skip to content

Commit dc2a49e

Browse files
authored
Added Feature Importance (#220)
* Added Feature Importance: added a new method in TabularModel, added a new method in BaseModel, and enabled feature importance for FTTransformer.
* Applied pre-commit changes.
* Enabled feature importance for GATE.
1 parent 4aae9a8 commit dc2a49e

File tree

4 files changed

+22
-9
lines changed

4 files changed

+22
-9
lines changed

src/pytorch_tabular/models/base_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from functools import partial
88
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
99

10+
import pandas as pd
1011
import pytorch_lightning as pl
1112
import torch
1213
import torch.nn as nn
@@ -506,6 +507,18 @@ def reset_weights(self):
506507
reset_all_weights(self.head)
507508
reset_all_weights(self.embedding_layer)
508509

510+
def feature_importance(self):
    """Return per-feature importance computed by the backbone.

    Builds a DataFrame pairing every input column (categorical first, then
    continuous — matching the backbone's feature ordering) with the value the
    backbone stored in ``feature_importance_``.

    Returns:
        pd.DataFrame: columns ``Features`` and ``importance``.

    Raises:
        ValueError: if the backbone does not expose ``feature_importance_``.
    """
    if hasattr(self.backbone, "feature_importance_"):
        importance = self.backbone.feature_importance_
        # Backbones may expose importances either as a torch tensor or as an
        # already-converted numpy array (FT-Transformer and GATE return numpy);
        # numpy arrays have no .detach(), so only convert torch tensors.
        if torch.is_tensor(importance):
            importance = importance.detach().cpu().numpy()
        importance_df = pd.DataFrame(
            {
                "Features": self.hparams.categorical_cols + self.hparams.continuous_cols,
                "importance": importance,
            }
        )
        return importance_df
    else:
        raise ValueError("Feature Importance unavailable for this model.")
521+
509522

510523
class _GenericModel(BaseModel):
511524
def __init__(

src/pytorch_tabular/models/ft_transformer/ft_transformer.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import math
66
from collections import OrderedDict
77

8-
import pandas as pd
98
import torch
109
import torch.nn as nn
1110
from omegaconf import DictConfig
@@ -116,7 +115,7 @@ def _calculate_feature_importance(self):
116115
for attn_weights in self.attention_weights_:
117116
self.local_feature_importance += attn_weights[:, :, :, -1].sum(dim=1)
118117
self.local_feature_importance = (1 / (h * L)) * self.local_feature_importance[:, :-1]
119-
self.feature_importance_ = self.local_feature_importance.mean(dim=0)
118+
self.feature_importance_ = self.local_feature_importance.mean(dim=0).detach().cpu().numpy()
120119
# self.feature_importance_count_+=attn_weights.shape[0]
121120

122121

@@ -146,12 +145,6 @@ def _build_network(self):
146145

147146
def feature_importance(self):
    """Return attention-derived feature importances for the FT-Transformer.

    Delegates DataFrame construction to ``BaseModel.feature_importance``,
    which reads ``self.backbone.feature_importance_``.

    Raises:
        ValueError: if ``attn_feature_importance`` is disabled, since the
            backbone only accumulates attention weights when it is ``True``.
    """
    if self.hparams.attn_feature_importance:
        return super().feature_importance()
    else:
        # Name the actual config flag checked above (the previous message
        # referred to a non-existent `attn_feature_weights` option).
        raise ValueError("If you want Feature Importance, `attn_feature_importance` should be `True`.")

src/pytorch_tabular/models/gate/gate_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
129129
tree_outputs = tree_outputs.permute(1, 2, 0)
130130
return tree_outputs
131131

132+
@property
def feature_importance_(self):
    """Per-feature importance for GATE: feature-mask weights summed over trees.

    Applies the configured mask function to the learned feature masks and
    aggregates along dim 0, returning the result as a numpy array.
    """
    mask_weights = self.gflus.feature_mask_function(self.gflus.feature_masks)
    aggregated = mask_weights.sum(dim=0)
    return aggregated.detach().cpu().numpy()
135+
132136

133137
class CustomHead(nn.Module):
134138
"""Custom Head for GATE.

src/pytorch_tabular/tabular_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,3 +1379,6 @@ def summary(self, max_depth: int = -1) -> None:
13791379

13801380
def __str__(self) -> str:
13811381
return self.summary()
1382+
1383+
def feature_importance(self):
    """Return the feature-importance table computed by the underlying model.

    Thin delegation to ``self.model.feature_importance()``; raises whatever
    that method raises when importance is unavailable.
    """
    return self.model.feature_importance()

0 commit comments

Comments
 (0)