7
7
from spacy .tokens import Doc
8
8
from typing_extensions import Literal , NotRequired , TypedDict
9
9
10
+ import edsnlp
10
11
from edsnlp .core .pipeline import PipelineProtocol
11
12
from edsnlp .core .torch_component import BatchInput , TorchComponent
12
13
from edsnlp .pipes .base import BaseComponent
33
34
)
34
35
35
36
37
@edsnlp.registry.misc.register("focal_loss")
class FocalLoss(nn.Module):
    """
    Focal Loss for multi-class classification.

    Scales the per-sample cross-entropy loss by ``(1 - p_t) ** gamma`` so
    that well-classified examples contribute less, focusing training on
    hard examples.

    Parameters
    ----------
    alpha : torch.Tensor or float, optional
        Class weights forwarded to ``cross_entropy`` as ``weight=``.
        If None, no weighting is applied.
    gamma : float, default=2.0
        Focusing parameter. Higher values give more weight to hard examples.
    reduction : str, default="mean"
        Reduction applied to the output: 'none' | 'mean' | 'sum'.

    Raises
    ------
    ValueError
        If ``reduction`` is not one of 'none', 'mean' or 'sum'.
    """

    def __init__(
        self,
        alpha: Optional[Union[torch.Tensor, float]] = None,
        gamma: float = 2.0,
        reduction: str = "mean",
    ):
        super().__init__()
        # Fail fast on a typo here: without this check an invalid value
        # would silently fall through to the "no reduction" branch of
        # forward() and return an unreduced per-sample loss tensor.
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(
                f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the focal loss.

        Parameters
        ----------
        inputs : torch.Tensor
            Class logits, shape ``(batch, num_classes)``.
        targets : torch.Tensor
            Ground-truth class indices, shape ``(batch,)``.

        Returns
        -------
        torch.Tensor
            Scalar loss when reduction is 'mean' or 'sum'; per-sample
            losses of shape ``(batch,)`` when reduction is 'none'.
        """
        # Per-sample CE with reduction deferred, so the focal modulation
        # can be applied element-wise before aggregating.
        ce_loss = torch.nn.functional.cross_entropy(
            inputs, targets, weight=self.alpha, reduction="none"
        )

        # p_t: softmax probability assigned to the true class
        # (exp(-CE) since CE = -log(p_t)).
        pt = torch.exp(-ce_loss)

        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss
81
+
82
+
36
83
class TrainableDocClassifier (
37
84
TorchComponent [DocClassifierBatchOutput , DocClassifierBatchInput ],
38
85
BaseComponent ,
@@ -49,9 +96,9 @@ def __init__(
49
96
label_attr : str = "label" ,
50
97
label2id : Optional [Dict [str , int ]] = None ,
51
98
id2label : Optional [Dict [int , str ]] = None ,
52
- loss_fn = None ,
99
+ loss : Literal [ "ce" , "focal" ] = "ce" ,
53
100
labels : Optional [Sequence [str ]] = None ,
54
- class_weights : Optional [Union [ Dict [str , float ], str ]] = None ,
101
+ class_weights : Optional [Dict [str , float ]] = None ,
55
102
hidden_size : Optional [int ] = None ,
56
103
activation_mode : Literal ["relu" , "gelu" , "silu" ] = "relu" ,
57
104
dropout_rate : Optional [float ] = 0.0 ,
@@ -71,8 +118,7 @@ def __init__(
71
118
super ().__init__ (nlp , name )
72
119
self .embedding = embedding
73
120
74
- self ._loss_fn = loss_fn
75
- self .loss_fn = None
121
+ self .loss = loss
76
122
77
123
if not hasattr (self .embedding , "output_size" ):
78
124
raise ValueError (
@@ -112,17 +158,13 @@ def _compute_class_weights(self, freq_dict: Dict[str, int]) -> torch.Tensor:
112
158
113
159
return weights
114
160
115
- def _load_class_weights_from_file (self , filepath : str ) -> Dict [str , int ]:
116
- """Load class weights from pickle file."""
117
- with open (filepath , "rb" ) as f :
118
- return pickle .load (f )
119
-
120
161
def set_extensions(self) -> None:
    """Register the ``Doc`` extension used to store the document label."""
    super().set_extensions()
    # Already registered (e.g. a second pipeline instance): nothing to do.
    if Doc.has_extension(self.label_attr):
        return
    Doc.set_extension(self.label_attr, default={})
124
165
125
166
def post_init (self , gold_data : Iterable [Doc ], exclude : Set [str ]):
167
+ print ("post_init" )
126
168
if not self .label2id :
127
169
if self .labels is not None :
128
170
labels = set (self .labels )
@@ -141,22 +183,19 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
141
183
self .num_classes = len (self .label2id )
142
184
print ("num classes:" , self .num_classes )
143
185
self .build_classifier ()
144
-
186
+ print ( "label2id fini" )
145
187
weight_tensor = None
146
188
if self .class_weights is not None :
147
- if isinstance (self .class_weights , str ):
148
- freq_dict = self ._load_class_weights_from_file (self .class_weights )
149
- weight_tensor = self ._compute_class_weights (freq_dict )
150
- elif isinstance (self .class_weights , dict ):
151
- weight_tensor = self ._compute_class_weights (self .class_weights )
152
-
189
+ weight_tensor = self ._compute_class_weights (self .class_weights )
153
190
print (f"Using class weights: { weight_tensor } " )
154
-
155
- if self ._loss_fn is not None :
156
- self .loss_fn = self ._loss_fn
157
- else :
191
+ print ("weight tensor fini" )
192
+ if self .loss == "ce" :
158
193
self .loss_fn = torch .nn .CrossEntropyLoss (weight = weight_tensor )
159
-
194
+ elif self .loss == "focal" :
195
+ self .loss_fn = FocalLoss (alpha = weight_tensor , gamma = 2.0 , reduction = "mean" )
196
+ else :
197
+ raise ValueError (f"Unknown loss: { self .loss } " )
198
+ print ("loss finie" )
160
199
super ().post_init (gold_data , exclude = exclude )
161
200
162
201
def preprocess (self , doc : Doc ) -> Dict [str , Any ]:
0 commit comments