-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_train.py
More file actions
73 lines (57 loc) · 1.93 KB
/
main_train.py
File metadata and controls
73 lines (57 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@Author: JetKwok
@HomePage: https://FMVPJet.github.io/
@E-mail: JetKwok827@gmail.com
@Date: 2023/9/19 17:22
"""
import os
import torch
import wandb
import argparse
from omegaconf import OmegaConf
from data_loaders.my_ADSB_dl import MyADSBDataLoader
from models.ADSB_model import ADSBModel
from trainers.ADSB_trainer import Trainer
from utils.utils import process_config, prepare_device, get_logger
# Weights & Biases credentials are deliberately NOT hard-coded here.
# Provide them through the WANDB_API_KEY environment variable, or run
# `wandb login` once on the training machine.
# SECURITY: the API key that used to be committed on this line is public in
# the repository history and must be revoked/rotated on wandb.ai.
def main_train(args):
    """Run one end-to-end training job for the ADS-B model.

    Loads and post-processes the OmegaConf config named by ``args.config``,
    opens a Weights & Biases run, builds the data loader and the model
    (wrapped in ``torch.nn.DataParallel`` when ``prepare_device`` returns
    more than one device id), trains it with ``Trainer`` and closes the run.

    Args:
        args: Parsed command-line namespace. ``args.config`` is the path of
            the config file passed to ``OmegaConf.load``;
            ``args.pretrained_model_path`` is accepted but not used here.

    Raises:
        Any exception raised while loading data, building the model or
        training. The W&B run is finished with a non-zero exit code before
        the exception propagates.
    """
    config = OmegaConf.load(args.config)
    config = process_config(config)
    # The log file lives under config.log_dir, so the logger can only be
    # created after the config has been loaded and processed.
    logger = get_logger(os.path.join(config.log_dir, 'train_log.log'))
    logger.info('解析配置...')  # "Parsing configuration..."

    pretrained_path = getattr(args, 'pretrained_model_path', None)
    if pretrained_path:
        # The value is never forwarded to the model or the Trainer; warn
        # instead of silently ignoring a flag the user explicitly passed.
        logger.warning('--pretrained_model_path=%s is ignored by main_train',
                       pretrained_path)

    wandb.init(
        project="ADS-B-Detection",
        config={
            "epochs": config.num_epochs,
            "batch_size": config.batch_size,
            "lr": config.optimizer.lr,
            "decay_rate": config.optimizer.decay,
        },
    )
    try:
        logger.info('加载数据...')  # "Loading data..."
        dl = MyADSBDataLoader(config=config, logger=logger)

        logger.info('构造网络...')  # "Building the network..."
        device, device_ids = prepare_device(config=config)
        model = ADSBModel(config=config).to(device)
        if len(device_ids) > 1:
            model = torch.nn.DataParallel(model, device_ids=device_ids)

        logger.info('训练网络...')  # "Training the network..."
        trainer = Trainer(
            model=model,
            train_data=dl.get_train_data(),
            val_data=dl.get_val_data(),
            config=config,
            logger=logger,
            device=device,
        )
        # Optional: wandb.watch(model, log="all") to also log gradients.
        trainer.train()
    except BaseException:
        # Previously a failure anywhere after wandb.init skipped
        # wandb.finish() entirely; close the run explicitly with a non-zero
        # exit code so it is reported as failed, then propagate the error.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()
    logger.info('训练完成...')  # "Training finished..."
if __name__ == '__main__':
    # Command-line entry point: parse options and launch a training run.
    parser = argparse.ArgumentParser(
        description="Train the ADS-B detection model.")
    parser.add_argument(
        "--pretrained_model_path", type=str, default=None,
        help="Optional pretrained model path "
             "(accepted, but currently not used by main_train).")
    parser.add_argument(
        "--config", type=str, default="configs/config.yaml",
        help="Path to the OmegaConf training config "
             "(default: %(default)s).")
    args = parser.parse_args()
    main_train(args)