@@ -24,6 +24,8 @@ def __init__(
2424 """
2525 super ().__init__ ()
2626 if net is not None :
27+ self .net = net
28+
2729 # --- Custom net should either be a Sequential or a Linear ---
2830 if not (isinstance (net , nn .Sequential ) or isinstance (net , nn .Linear )):
2931 raise ValueError ("net must be an nn.Sequential when provided." )
@@ -43,7 +45,6 @@ def __init__(
                 # --- Extract features ---
                 self.input_dim = first.in_features
                 self.num_classes = last.out_features
-                self.net = net
             else:  # if not Sequential, it is a Linear
                 self.input_dim = net.in_features
                 self.num_classes = net.out_features
@@ -53,23 +54,8 @@ def __init__(
                 input_dim is not None and num_classes is not None
             ), "Either net or both input_dim and num_classes must be provided."
             self.net = nn.Linear(input_dim, num_classes)
-            self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)
+            self.input_dim = input_dim
+            self.num_classes = num_classes
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.net(x)
-
-    @staticmethod
-    def _get_linear_input_output_dims(module: nn.Module):
-        """
-        Returns (input_dim, output_dim) for any module containing Linear layers.
-        Works for Linear, Sequential, or nested models.
-        """
-        # Collect all Linear layers recursively
-        linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
-
-        if not linears:
-            raise ValueError("No Linear layers found in the given module.")
-
-        input_dim = linears[0].in_features
-        output_dim = linears[-1].out_features
-        return input_dim, output_dim
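
For reference, a minimal usage sketch of the two construction paths this `__init__` supports after the change. The class name `ClassifierHead` is hypothetical (the enclosing class is not shown in the diff); only the `net`, `input_dim`, and `num_classes` arguments and the derived attributes come from the code above.

```python
import torch
import torch.nn as nn

# `ClassifierHead` is a placeholder name for the class this __init__ belongs to.

# Path 1: pass a custom net whose first and last layers are Linear;
# input_dim / num_classes are read from those layers.
head = ClassifierHead(net=nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)))
assert head.input_dim == 128 and head.num_classes == 10

# Path 2: pass the dimensions directly; a single nn.Linear is built internally
# and input_dim / num_classes are stored as given (no layer introspection needed).
head = ClassifierHead(input_dim=128, num_classes=10)
logits = head(torch.randn(4, 128))  # forward simply applies self.net
```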