
Commit 5637020

ruff: C4 (#222)
* ruff: C4
* apply
1 parent e5171bf commit 5637020

29 files changed: +153 -147 lines
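
For context, "C4" is the flake8-comprehensions rule set (linked in the pyproject.toml change below), which flags collections built through redundant function calls or comprehensions. A minimal before/after sketch of the two main patterns this commit rewrites; the variable names are illustrative, and the rule codes (C408, C419) are as given in the flake8-comprehensions docs:

    # C408: a dict built through the dict() call...
    params = dict(task="regression", depth=2)
    # ...is clearer (and slightly cheaper) as a literal:
    params = {"task": "regression", "depth": 2}

    # C419: a list comprehension materialized inside all()/any()...
    ok = all([x > 0 for x in (1, 2, 3)])
    # ...can be a bare generator expression, which short-circuits
    # on the first deciding element instead of building a list:
    ok = all(x > 0 for x in (1, 2, 3))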

examples/__only_for_dev__/to_test_node.py

Lines changed: 2 additions & 2 deletions

@@ -55,7 +55,7 @@ def test_regression(
         continuous_feature_transform=continuous_feature_transform,
         normalize_continuous_features=normalize_continuous_features,
     )
-    model_config_params = dict(task="regression", depth=2, embed_categorical=embed_categorical)
+    model_config_params = {"task": "regression", "depth": 2, "embed_categorical": embed_categorical}
     model_config = NodeConfig(**model_config_params)
     # model_config_params = dict(task="regression")
     # model_config = NodeConfig(**model_config_params)
@@ -98,7 +98,7 @@ def test_classification(
         continuous_feature_transform=continuous_feature_transform,
         normalize_continuous_features=normalize_continuous_features,
     )
-    model_config_params = dict(task="classification", depth=2, embed_categorical=embed_categorical)
+    model_config_params = {"task": "classification", "depth": 2, "embed_categorical": embed_categorical}
     model_config = NodeConfig(**model_config_params)
     trainer_config = TrainerConfig(max_epochs=1, checkpoints=None, early_stopping=None)
     optimizer_config = OptimizerConfig()

examples/__only_for_dev__/to_test_regression.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@
 # batch_norm_continuous_input=True,
 # attention_pooling=True,
 # )
-model_config = CategoryEmbeddingModelConfig(task="regression", dropout=0.2, head_config=dict(layers="32-16"))
+model_config = CategoryEmbeddingModelConfig(task="regression", dropout=0.2, head_config={"layers": "32-16"})
 
 trainer_config = TrainerConfig(
 # checkpoints=None,

pyproject.toml

Lines changed: 6 additions & 0 deletions

@@ -22,6 +22,12 @@ select = [
     # "D",  # see: https://pypi.org/project/pydocstyle
     # "N",  # see: https://pypi.org/project/pep8-naming
 ]
+extend-select = [
+    "C4",  # see: https://pypi.org/project/flake8-comprehensions
+    # "SIM",  # see: https://pypi.org/project/flake8-simplify
+    # "RET",  # see: https://pypi.org/project/flake8-return
+    # "PT",  # see: https://pypi.org/project/flake8-pytest-style
+]
 ignore = [
     "E731",  # Do not assign a lambda expression, use a def
 ]
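
With C4 in extend-select, these rewrites are autofixable: assuming a recent ruff installation, running "ruff check --fix ." from the repository root should reproduce them, which presumably corresponds to the second commit bullet, "apply".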

src/pytorch_tabular/config/config.py

Lines changed: 3 additions & 3 deletions

@@ -231,7 +231,7 @@ class InferredConfig:
     def __post_init__(self):
         if self.embedding_dims is not None:
             assert all(
-                [(isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims]
+                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
             ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
@@ -468,7 +468,7 @@ class TrainerConfig:
         metadata={"help": "The number of epochs to wait until there is no further improvements in loss/metric"},
     )
     early_stopping_kwargs: Optional[Dict[str, Any]] = field(
-        default_factory=lambda: dict(),
+        default_factory=lambda: {},
         metadata={
             "help": "Additional keyword arguments for the early stopping callback."
             " See the documentation for the PyTorch Lightning EarlyStopping callback for more details."
@@ -505,7 +505,7 @@ class TrainerConfig:
         metadata={"help": "The number of best models to save"},
     )
     checkpoints_kwargs: Optional[Dict[str, Any]] = field(
-        default_factory=lambda: dict(),
+        default_factory=lambda: {},
         metadata={
             "help": "Additional keyword arguments for the checkpoints callback. See the documentation"
             " for the PyTorch Lightning ModelCheckpoint callback for more details."

src/pytorch_tabular/feature_extractor.py

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
                 continue
             batch[k] = v.to(self.tabular_model.model.device)
             if self.tabular_model.config.task == "ssl":
-                ret_value = dict(backbone_features=self.tabular_model.model.predict(batch, ret_model_output=True))
+                ret_value = {"backbone_features": self.tabular_model.model.predict(batch, ret_model_output=True)}
             else:
                 _, ret_value = self.tabular_model.model.predict(batch, ret_model_output=True)
             for k in self.extract_keys:

src/pytorch_tabular/models/base_model.py

Lines changed: 1 addition & 1 deletion

@@ -486,7 +486,7 @@ def create_plotly_histogram(self, arr, name, bin_dict=None):
         # Overlay both histograms
         fig.update_layout(
             barmode="overlay",
-            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
+            legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
         )
         # Reduce opacity to see both histograms
         fig.update_traces(opacity=0.5)

src/pytorch_tabular/models/common/layers.py

Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ def __init__(
     def forward(self, x):
         h = self.n_heads
         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
         sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
 
         attn = sim.softmax(dim=-1)
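
This change is C417 (unnecessary map usage, per the flake8-comprehensions docs): map() with a lambda can always be written as a generator expression, which drops the per-element lambda call and reads in source order. A toy sketch of the equivalence, using plain tuples rather than the attention tensors above:

    def scale(t, factor=2):
        return [x * factor for x in t]

    q, k, v = (1, 2), (3, 4), (5, 6)
    # the form this commit removed:
    q1, k1, v1 = map(lambda t: scale(t), (q, k, v))
    # the form it introduced:
    q2, k2, v2 = (scale(t) for t in (q, k, v))
    assert (q1, k1, v1) == (q2, k2, v2)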

src/pytorch_tabular/models/ft_transformer/config.py

Lines changed: 8 additions & 8 deletions

@@ -254,7 +254,7 @@ def __post_init__(self):
                 " and `out_ff_initialization` as head_config is passed."
             )
         else:
-            if any([p is not None for p in deprecated_args]):
+            if any(p is not None for p in deprecated_args):
                 warnings.warn(
                     "The `out_ff_layers`, `out_ff_activation`, `out_ff_dropoout`, and `out_ff_initialization`"
                     " arguments are deprecated and will be removed next release."
@@ -263,13 +263,13 @@ def __post_init__(self):
                 )
             # TODO: Remove this once we deprecate the old config
             # Fill the head_config using deprecated parameters
-            self.head_config = dict(
-                layers=ifnone(self.out_ff_layers, ""),
-                activation=ifnone(self.out_ff_activation, "ReLU"),
-                dropout=ifnone(self.out_ff_dropout, 0.0),
-                use_batch_norm=False,
-                initialization=ifnone(self.out_ff_initialization, "kaiming"),
-            )
+            self.head_config = {
+                "layers": ifnone(self.out_ff_layers, ""),
+                "activation": ifnone(self.out_ff_activation, "ReLU"),
+                "dropout": ifnone(self.out_ff_dropout, 0.0),
+                "use_batch_norm": False,
+                "initialization": ifnone(self.out_ff_initialization, "kaiming"),
+            }
 
         return super().__post_init__()

src/pytorch_tabular/models/mixture_density/mdn.py

Lines changed: 3 additions & 3 deletions

@@ -86,8 +86,8 @@ def _build_network(self):
     def forward(self, x: Dict):
         if isinstance(self.backbone, TabTransformerBackbone):
             if self.hparams.categorical_dim > 0:
-                x_cat = self.embed_input(dict(categorical=x["categorical"]))
-                x = self.compute_backbone(dict(categorical=x_cat, continuous=x["continuous"]))
+                x_cat = self.embed_input({"categorical": x["categorical"]})
+                x = self.compute_backbone({"categorical": x_cat, "continuous": x["continuous"]})
             else:
                 x = self.embedding_layer(x)
                 x = self.compute_backbone(x)
@@ -230,7 +230,7 @@ def validation_epoch_end(self, outputs) -> None:
                 commit=False,
             )
             if self.head.hparams.log_debug_plot:
-                fig = self.create_plotly_histogram(pi, "pi", bin_dict=dict(start=0.0, end=1.0, size=0.1))
+                fig = self.create_plotly_histogram(pi, "pi", bin_dict={"start": 0.0, "end": 1.0, "size": 0.1})
                 wandb.log(
                     {
                         "valid_pi": fig,

src/pytorch_tabular/models/tab_transformer/config.py

Lines changed: 8 additions & 8 deletions

@@ -243,7 +243,7 @@ def __post_init__(self):
                 " and `out_ff_initialization` as head_config is passed."
             )
         else:
-            if any([p is not None for p in deprecated_args]):
+            if any(p is not None for p in deprecated_args):
                 warnings.warn(
                     "The `out_ff_layers`, `out_ff_activation`, `out_ff_dropoout`, and `out_ff_initialization`"
                     " arguments are deprecated and will be removed next release."
@@ -252,13 +252,13 @@ def __post_init__(self):
                 )
             # TODO: Remove this once we deprecate the old config
             # Fill the head_config using deprecated parameters
-            self.head_config = dict(
-                layers=ifnone(self.out_ff_layers, ""),
-                activation=ifnone(self.out_ff_activation, "ReLU"),
-                dropout=ifnone(self.out_ff_dropout, 0.0),
-                use_batch_norm=False,
-                initialization=ifnone(self.out_ff_initialization, "kaiming"),
-            )
+            self.head_config = {
+                "layers": ifnone(self.out_ff_layers, ""),
+                "activation": ifnone(self.out_ff_activation, "ReLU"),
+                "dropout": ifnone(self.out_ff_dropout, 0.0),
+                "use_batch_norm": False,
+                "initialization": ifnone(self.out_ff_initialization, "kaiming"),
+            }
         return super().__post_init__()