@@ -247,15 +247,25 @@ class Config:
247
247
arbitrary_types_allowed = True
248
248
extra = "forbid"
249
249
250
- def extra_repr (self ) -> str :
251
- res = ""
252
- for k , v in self .dict ().items ():
253
- if v is not None :
254
- res += f"{ k } : { v } \n "
255
- return res
250
+ def _get_str_value (self , field : str ) -> str :
251
+ value = getattr (self , field )
252
+ if isinstance (value , type ):
253
+ value = value .__name__
254
+ elif isinstance (value , partial ):
255
+ value = f"{ value .func .__name__ } { value .keywords } "
256
+ elif callable (value ):
257
+ value = value .__name__
258
+ return value
259
+
260
+ def __repr__ (self ) -> str :
261
+ return f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )"
256
262
257
- def pprint (self ) -> None :
258
- print (self .extra_repr ())
263
+ def __repr_args__ (self ):
264
+ return [
265
+ (field , str_value )
266
+ for field in self .__fields__
267
+ if (str_value := self ._get_str_value (field ))
268
+ ]
259
269
260
270
261
271
class ModelConstructor (ModelCfg ):
@@ -296,25 +306,17 @@ def __call__(self):
296
306
OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
297
307
)
298
308
self .init_cnn (model ) # pylint: disable=too-many-function-args
299
- extra_repr = self .get_extra_repr ()
309
+ extra_repr = self ._get_extra_repr ()
300
310
if extra_repr :
301
311
model .extra_repr = lambda : extra_repr
302
312
return model
303
313
304
- def get_extra_repr (self ) -> str :
314
+ def _get_extra_repr (self ) -> str :
305
315
return " " .join (
306
- f"{ field } : { self .get_str_value (field )} ,"
316
+ f"{ field } : { self ._get_str_value (field )} ,"
307
317
for field in self .__fields_set__ if field != "name"
308
318
)[:- 1 ]
309
319
310
- def get_str_value (self , field : str ) -> str :
311
- value = getattr (self , field )
312
- if isinstance (value , type ):
313
- value = value .__name__
314
- if isinstance (value , partial ):
315
- value = f"{ value .func .__name__ } { value .keywords } "
316
- return value
317
-
318
320
def __repr__ (self ):
319
321
se_repr = self .se .__name__ if self .se else "False" # type: ignore
320
322
model_name = self .name or self .__class__ .__name__
@@ -328,6 +330,10 @@ def __repr__(self):
328
330
f" layers: { self .layers } "
329
331
)
330
332
333
+ def print_cfg (self ) -> None :
334
+ """Print full config"""
335
+ print (f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )" )
336
+
331
337
332
338
class XResNet34 (ModelConstructor ):
333
339
layers : list [int ] = [3 , 4 , 6 , 3 ]
0 commit comments