Skip to content

Commit 6cc3050

Browse files
committed
Repr, rich_repr
Fixes #78
1 parent 2016751 commit 6cc3050

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

src/model_constructor/model_constructor.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -247,15 +247,25 @@ class Config:
247247
arbitrary_types_allowed = True
248248
extra = "forbid"
249249

250-
def extra_repr(self) -> str:
251-
res = ""
252-
for k, v in self.dict().items():
253-
if v is not None:
254-
res += f"{k}: {v}\n"
255-
return res
250+
def _get_str_value(self, field: str) -> str:
251+
value = getattr(self, field)
252+
if isinstance(value, type):
253+
value = value.__name__
254+
elif isinstance(value, partial):
255+
value = f"{value.func.__name__} {value.keywords}"
256+
elif callable(value):
257+
value = value.__name__
258+
return value
259+
260+
def __repr__(self) -> str:
261+
return f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})"
256262

257-
def pprint(self) -> None:
258-
print(self.extra_repr())
263+
def __repr_args__(self):
264+
return [
265+
(field, str_value)
266+
for field in self.__fields__
267+
if (str_value := self._get_str_value(field))
268+
]
259269

260270

261271
class ModelConstructor(ModelCfg):
@@ -296,25 +306,17 @@ def __call__(self):
296306
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
297307
)
298308
self.init_cnn(model) # pylint: disable=too-many-function-args
299-
extra_repr = self.get_extra_repr()
309+
extra_repr = self._get_extra_repr()
300310
if extra_repr:
301311
model.extra_repr = lambda: extra_repr
302312
return model
303313

304-
def get_extra_repr(self) -> str:
314+
def _get_extra_repr(self) -> str:
305315
return " ".join(
306-
f"{field}: {self.get_str_value(field)},"
316+
f"{field}: {self._get_str_value(field)},"
307317
for field in self.__fields_set__ if field != "name"
308318
)[:-1]
309319

310-
def get_str_value(self, field: str) -> str:
311-
value = getattr(self, field)
312-
if isinstance(value, type):
313-
value = value.__name__
314-
if isinstance(value, partial):
315-
value = f"{value.func.__name__} {value.keywords}"
316-
return value
317-
318320
def __repr__(self):
319321
se_repr = self.se.__name__ if self.se else "False" # type: ignore
320322
model_name = self.name or self.__class__.__name__
@@ -328,6 +330,10 @@ def __repr__(self):
328330
f" layers: {self.layers}"
329331
)
330332

333+
def print_cfg(self) -> None:
334+
"""Print full config"""
335+
print(f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})")
336+
331337

332338
class XResNet34(ModelConstructor):
333339
layers: list[int] = [3, 4, 6, 3]

0 commit comments

Comments
 (0)