|
| 1 | +import importlib |
1 | 2 | import os |
2 | 3 | import shutil |
3 | | -from typing import Optional, Tuple, Union |
| 4 | +from pathlib import Path |
| 5 | +from typing import Optional, Tuple |
4 | 6 |
|
5 | 7 | import torch |
6 | 8 | import tqdm |
7 | 9 | import wandb |
8 | 10 | import wandb.util as wandb_util |
| 11 | +import yaml |
9 | 12 |
|
10 | 13 | from chebai.models.base import ChebaiBaseNet |
11 | | -from chebai.models.electra import Electra |
12 | 14 | from chebai.preprocessing.datasets.base import XYBaseDataModule |
13 | 15 | from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor |
14 | 16 |
|
@@ -121,7 +123,7 @@ def evaluate_model( |
121 | 123 | save_batch_size = 128 |
122 | 124 | n_saved = 1 |
123 | 125 |
|
124 | | - print(f"") |
| 126 | + print("") |
125 | 127 | for i in tqdm.tqdm(range(0, len(data_list), batch_size)): |
126 | 128 | if not ( |
127 | 129 | skip_existing_preds |
@@ -307,5 +309,5 @@ def parse_config_file(config_path: str) -> tuple[str, dict]: |
307 | 309 | ) |
308 | 310 | os.makedirs(buffer_dir_concat, exist_ok=True) |
309 | 311 | preds, labels = load_results_from_buffer(buffer_dir, "cpu") |
310 | | - torch.save(preds, os.path.join(buffer_dir_concat, f"preds000.pt")) |
311 | | - torch.save(labels, os.path.join(buffer_dir_concat, f"labels000.pt")) |
| 312 | + torch.save(preds, os.path.join(buffer_dir_concat, "preds000.pt")) |
| 313 | + torch.save(labels, os.path.join(buffer_dir_concat, "labels000.pt")) |
0 commit comments