@@ -11,11 +11,43 @@ def __init__(
1111 num_classes : Optional [int ] = None ,
1212 net : Optional [nn .Module ] = None ,
1313 ):
14+ """
15+ Classification head for text classification tasks.
16+ It is an nn.Module that can either be a simple Linear layer or a custom neural network module.
17+
18+ Args:
19+ input_dim (int, optional): Dimension of the input features. Required if net is not provided.
20+ num_classes (int, optional): Number of output classes. Required if net is not provided.
21+ net (nn.Module, optional): Custom neural network module to be used as the classification head.
22+ If provided, input_dim and num_classes are inferred from this module.
23+ Must be either an nn.Linear, or an nn.Sequential whose first and last layers are nn.Linear.
24+ """
1425 super ().__init__ ()
1526 if net is not None :
16- self .net = net
17- self .input_dim = net .in_features
18- self .num_classes = net .out_features
27+ # --- Custom net should either be a Sequential or a Linear ---
28+ if not (isinstance (net , nn .Sequential ) or isinstance (net , nn .Linear )):
29+ raise ValueError ("net must be an nn.Sequential or nn.Linear when provided." )
30+
31+ # --- If Sequential, Check first and last layers are Linear ---
32+
33+ if isinstance (net , nn .Sequential ):
34+ first = net [0 ]
35+ last = net [- 1 ]
36+
37+ if not isinstance (first , nn .Linear ):
38+ raise TypeError (f"First layer must be nn.Linear, got { type (first ).__name__ } ." )
39+
40+ if not isinstance (last , nn .Linear ):
41+ raise TypeError (f"Last layer must be nn.Linear, got { type (last ).__name__ } ." )
42+
43+ # --- Extract features ---
44+ self .input_dim = first .in_features
45+ self .num_classes = last .out_features
46+ self .net = net
47+ else : # if not Sequential, it is a Linear
48+ self .input_dim = net .in_features
49+ self .num_classes = net .out_features
50+
1951 else :
2052 assert (
2153 input_dim is not None and num_classes is not None
0 commit comments