Skip to content

Commit 89535cb

Browse files
committed
add focal loss
1 parent f9664f1 commit 89535cb

File tree

1 file changed

+60
-21
lines changed

1 file changed

+60
-21
lines changed

edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from spacy.tokens import Doc
88
from typing_extensions import Literal, NotRequired, TypedDict
99

10+
import edsnlp
1011
from edsnlp.core.pipeline import PipelineProtocol
1112
from edsnlp.core.torch_component import BatchInput, TorchComponent
1213
from edsnlp.pipes.base import BaseComponent
@@ -33,6 +34,52 @@
3334
)
3435

3536

37+
@edsnlp.registry.misc.register("focal_loss")
38+
class FocalLoss(nn.Module):
39+
"""
40+
Focal Loss implementation for multi-class classification.
41+
42+
Parameters
43+
----------
44+
alpha : torch.Tensor or float, optional
45+
Class weights. If None, no weighting is applied
46+
gamma : float, default=2.0
47+
Focusing parameter. Higher values give more weight to hard examples
48+
reduction : str, default='mean'
49+
Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
50+
"""
51+
52+
def __init__(
53+
self,
54+
alpha: Optional[Union[torch.Tensor, float]] = None,
55+
gamma: float = 2.0,
56+
reduction: str = "mean",
57+
):
58+
super().__init__()
59+
self.alpha = alpha
60+
self.gamma = gamma
61+
self.reduction = reduction
62+
63+
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
64+
"""
65+
Forward pass
66+
"""
67+
ce_loss = torch.nn.functional.cross_entropy(
68+
inputs, targets, weight=self.alpha, reduction="none"
69+
)
70+
71+
pt = torch.exp(-ce_loss)
72+
73+
focal_loss = (1 - pt) ** self.gamma * ce_loss
74+
75+
if self.reduction == "mean":
76+
return focal_loss.mean()
77+
elif self.reduction == "sum":
78+
return focal_loss.sum()
79+
else:
80+
return focal_loss
81+
82+
3683
class TrainableDocClassifier(
3784
TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput],
3885
BaseComponent,
@@ -49,9 +96,9 @@ def __init__(
4996
label_attr: str = "label",
5097
label2id: Optional[Dict[str, int]] = None,
5198
id2label: Optional[Dict[int, str]] = None,
52-
loss_fn=None,
99+
loss: Literal["ce", "focal"] = "ce",
53100
labels: Optional[Sequence[str]] = None,
54-
class_weights: Optional[Union[Dict[str, float], str]] = None,
101+
class_weights: Optional[Dict[str, float]] = None,
55102
hidden_size: Optional[int] = None,
56103
activation_mode: Literal["relu", "gelu", "silu"] = "relu",
57104
dropout_rate: Optional[float] = 0.0,
@@ -71,8 +118,7 @@ def __init__(
71118
super().__init__(nlp, name)
72119
self.embedding = embedding
73120

74-
self._loss_fn = loss_fn
75-
self.loss_fn = None
121+
self.loss = loss
76122

77123
if not hasattr(self.embedding, "output_size"):
78124
raise ValueError(
@@ -112,17 +158,13 @@ def _compute_class_weights(self, freq_dict: Dict[str, int]) -> torch.Tensor:
112158

113159
return weights
114160

115-
def _load_class_weights_from_file(self, filepath: str) -> Dict[str, int]:
116-
"""Load class weights from pickle file."""
117-
with open(filepath, "rb") as f:
118-
return pickle.load(f)
119-
120161
def set_extensions(self) -> None:
121162
super().set_extensions()
122163
if not Doc.has_extension(self.label_attr):
123164
Doc.set_extension(self.label_attr, default={})
124165

125166
def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
167+
print("post_init")
126168
if not self.label2id:
127169
if self.labels is not None:
128170
labels = set(self.labels)
@@ -141,22 +183,19 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
141183
self.num_classes = len(self.label2id)
142184
print("num classes:", self.num_classes)
143185
self.build_classifier()
144-
186+
print("label2id fini")
145187
weight_tensor = None
146188
if self.class_weights is not None:
147-
if isinstance(self.class_weights, str):
148-
freq_dict = self._load_class_weights_from_file(self.class_weights)
149-
weight_tensor = self._compute_class_weights(freq_dict)
150-
elif isinstance(self.class_weights, dict):
151-
weight_tensor = self._compute_class_weights(self.class_weights)
152-
189+
weight_tensor = self._compute_class_weights(self.class_weights)
153190
print(f"Using class weights: {weight_tensor}")
154-
155-
if self._loss_fn is not None:
156-
self.loss_fn = self._loss_fn
157-
else:
191+
print("weight tensor fini")
192+
if self.loss == "ce":
158193
self.loss_fn = torch.nn.CrossEntropyLoss(weight=weight_tensor)
159-
194+
elif self.loss == "focal":
195+
self.loss_fn = FocalLoss(alpha=weight_tensor, gamma=2.0, reduction="mean")
196+
else:
197+
raise ValueError(f"Unknown loss: {self.loss}")
198+
print("loss finie")
160199
super().post_init(gold_data, exclude=exclude)
161200

162201
def preprocess(self, doc: Doc) -> Dict[str, Any]:

0 commit comments

Comments
 (0)