Skip to content

Commit 723abe8

Browse files
fix(ClassifHead): enable Sequential as custom net
1 parent 42b9559 commit 723abe8

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

torchTextClassifiers/model/components/classification_head.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,43 @@ def __init__(
1111
num_classes: Optional[int] = None,
1212
net: Optional[nn.Module] = None,
1313
):
14+
"""
15+
Classification head for text classification tasks.
16+
It is a nn.Module that can either be a simple Linear layer or a custom neural network module.
17+
18+
Args:
19+
input_dim (int, optional): Dimension of the input features. Required if net is not provided.
20+
num_classes (int, optional): Number of output classes. Required if net is not provided.
21+
net (nn.Module, optional): Custom neural network module to be used as the classification head.
22+
If provided, input_dim and num_classes are inferred from this module.
23+
Should be either an nn.Sequential with first and last layers being Linears or nn.Linear.
24+
"""
1425
super().__init__()
1526
if net is not None:
16-
self.net = net
17-
self.input_dim = net.in_features
18-
self.num_classes = net.out_features
27+
# --- Custom net should either be a Sequential or a Linear ---
28+
if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)):
29+
raise ValueError("net must be an nn.Sequential when provided.")
30+
31+
# --- If Sequential, Check first and last layers are Linear ---
32+
33+
if isinstance(net, nn.Sequential):
34+
first = net[0]
35+
last = net[-1]
36+
37+
if not isinstance(first, nn.Linear):
38+
raise TypeError(f"First layer must be nn.Linear, got {type(first).__name__}.")
39+
40+
if not isinstance(last, nn.Linear):
41+
raise TypeError(f"Last layer must be nn.Linear, got {type(last).__name__}.")
42+
43+
# --- Extract features ---
44+
self.input_dim = first.in_features
45+
self.num_classes = last.out_features
46+
self.net = net
47+
else: # if not Sequential, it is a Linear
48+
self.input_dim = net.in_features
49+
self.num_classes = net.out_features
50+
1951
else:
2052
assert (
2153
input_dim is not None and num_classes is not None

0 commit comments

Comments
 (0)