Skip to content

Commit 4510af9

Browse files
fix(ClassifHead): simplify the net validation
1 parent 6e3f392 commit 4510af9

File tree

1 file changed

+4
-18
lines changed

1 file changed

+4
-18
lines changed

torchTextClassifiers/model/components/classification_head.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(
2424
"""
2525
super().__init__()
2626
if net is not None:
27+
self.net = net
28+
2729
# --- Custom net should either be a Sequential or a Linear ---
2830
if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)):
2931
raise ValueError("net must be an nn.Sequential when provided.")
@@ -43,7 +45,6 @@ def __init__(
4345
# --- Extract features ---
4446
self.input_dim = first.in_features
4547
self.num_classes = last.out_features
46-
self.net = net
4748
else: # if not Sequential, it is a Linear
4849
self.input_dim = net.in_features
4950
self.num_classes = net.out_features
@@ -53,23 +54,8 @@ def __init__(
5354
input_dim is not None and num_classes is not None
5455
), "Either net or both input_dim and num_classes must be provided."
5556
self.net = nn.Linear(input_dim, num_classes)
56-
self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)
57+
self.input_dim = input_dim
58+
self.num_classes = num_classes
5759

5860
def forward(self, x: torch.Tensor) -> torch.Tensor:
5961
return self.net(x)
60-
61-
@staticmethod
62-
def _get_linear_input_output_dims(module: nn.Module):
63-
"""
64-
Returns (input_dim, output_dim) for any module containing Linear layers.
65-
Works for Linear, Sequential, or nested models.
66-
"""
67-
# Collect all Linear layers recursively
68-
linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
69-
70-
if not linears:
71-
raise ValueError("No Linear layers found in the given module.")
72-
73-
input_dim = linears[0].in_features
74-
output_dim = linears[-1].out_features
75-
return input_dim, output_dim

0 commit comments

Comments
 (0)