Skip to content

Commit 32d49d9

Browse files
authored
Merge pull request #122 from ayasyrev/dev
4.2
2 parents c9f17d5 + d71e24f commit 32d49d9

17 files changed

+44
-72
lines changed

noxfile_cov.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import nox
22

33

4-
@nox.session(python=["3.10"])
4+
@nox.session(python=["3.11"])
55
def cov_tests(session: nox.Session) -> None:
66
args = session.posargs or ["--cov"]
77
session.install(".", "pytest", "pytest-cov", "coverage[toml]")

requirements_test.txt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,2 @@
11
pytest
2-
pytest-cov
3-
coverage[toml]
4-
flake8
5-
nox
2+
pytest-cov

requirements_test_extra.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
coverage[toml]
2+
black
3+
flake8
4+
nox
5+
isort

setup.cfg

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@ long_description_content_type = text/markdown
1010
url = https://github.com/ayasyrev/model_constructor
1111
license = apache2
1212
classifiers =
13-
Programming Language :: Python :: 3
13+
Programming Language :: Python :: 3.8
14+
Programming Language :: Python :: 3.9
15+
Programming Language :: Python :: 3.10
16+
Programming Language :: Python :: 3.11
1417
License :: OSI Approved :: Apache Software License
1518
Operating System :: OS Independent
1619

1720
[options]
1821
package_dir =
1922
= src
2023
packages = find:
21-
python_requires = >=3.7
24+
python_requires = >=3.8, <3.12
2225

2326
[options.packages.find]
2427
where = src

setup_.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

src/model_constructor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .convmixer import ConvMixer
2-
from .model_constructor import ModelConstructor, ModelCfg
2+
from .model_constructor import ModelCfg, ModelConstructor
33
from .version import __version__

src/model_constructor/activations.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# forked from https://github.com/rwightman/pytorch-image-models/timm/models/layers/activations.py
22
import torch
33
from torch import nn as nn
4-
from torch.nn import functional as F
54
from torch.nn import Mish
6-
5+
from torch.nn import functional as F
76

87
__all__ = [
98
"mish",

src/model_constructor/helpers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,18 @@ def print_set_fields(self) -> None:
140140
else:
141141
print("Nothing changed")
142142

143-
def print_changed_fields(self, show_default: bool = False, separator: str = " | ") -> None:
143+
def print_changed_fields(
144+
self, show_default: bool = False, separator: str = " | "
145+
) -> None:
144146
"""Print fields changed at init."""
145147
if self.changed_fields:
146148
default_value = ""
147149
print("Changed fields:")
148150
for field in self.changed_fields:
149151
if show_default:
150-
default_value = f"{separator}{self._get_str(self.model_fields[field].default)}"
152+
default_value = (
153+
f"{separator}{self._get_str(self.model_fields[field].default)}"
154+
)
151155
print(f"{field}: {self._get_str_value(field)}{default_value}")
152156
else:
153157
print("Nothing changed")

src/model_constructor/model_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Any, Callable, Dict, List, Optional, Union, Type
3+
from typing import Any, Callable, Dict, List, Optional, Type, Union
44

55
from pydantic import field_validator
66
from pydantic_core.core_schema import FieldValidationInfo
@@ -32,7 +32,7 @@
3232
}
3333

3434

35-
nnModule = Union[Type[nn.Module], Callable[[], nn.Module]]
35+
nnModule = Union[Type[nn.Module], Callable[[Any], nn.Module]]
3636

3737

3838
class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):

src/model_constructor/mxresnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Type
2+
23
from torch import nn
34

45
from .xresnet import XResNet

0 commit comments

Comments
 (0)