-
Notifications
You must be signed in to change notification settings - Fork 455
Expand file tree
/
Copy pathmodel.py
More file actions
131 lines (117 loc) · 4.96 KB
/
model.py
File metadata and controls
131 lines (117 loc) · 4.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
from torch import nn
from models import resnet
import torch.serialization
# NOTE(review): this rebinds the module-level ``__name__``, so this module
# reports itself as 'models.resnet' and every function defined below picks
# that value up as its ``__module__``. The motivation is not visible in this
# file -- confirm it is still required before changing it.
__name__='models.resnet'
# Allow-list handed to ``torch.serialization.safe_globals`` when
# ``generate_model`` loads a pretrained checkpoint. It holds the single
# string 'models.resnet' (the value assigned to ``__name__`` just above).
# NOTE(review): PyTorch documents safe-globals entries as callables/classes
# or ``(callable, str)`` tuples, not module-name strings -- confirm that this
# entry has the intended effect.
SAFE_GLOBALS = [__name__]
def _build_resnet(opt):
    """Instantiate the ResNet variant selected by ``opt.model_depth``."""
    assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
    # Every supported depth is exposed as ``resnet.resnet<depth>`` and takes
    # the same keyword arguments, so resolve the factory by name instead of
    # repeating the call once per depth.
    factory = getattr(resnet, f'resnet{opt.model_depth}')
    return factory(
        sample_input_W=opt.input_W,
        sample_input_H=opt.input_H,
        sample_input_D=opt.input_D,
        shortcut_type=opt.resnet_shortcut,
        no_cuda=opt.no_cuda,
        num_seg_classes=opt.n_seg_classes)


def _load_pretrained_checkpoint(path):
    """Load the checkpoint at ``path`` onto the CPU in weights-only mode."""
    import pickle

    # safe_globals() expects callables or (callable, str) tuples; plain
    # strings cannot allow-list anything, so they are not forwarded.
    allowed_globals = [g for g in SAFE_GLOBALS if not isinstance(g, str)]
    try:
        with torch.serialization.safe_globals(allowed_globals):
            return torch.load(
                path,
                weights_only=True,
                map_location='cpu'
            )
    except FileNotFoundError as e:
        raise FileNotFoundError(f"预训练模型文件不存在: {path}") from e
    except (RuntimeError, pickle.UnpicklingError) as e:
        # torch.load(weights_only=True) reports a rejected global as
        # pickle.UnpicklingError, so it has to be caught here to reach the
        # allow-list hint. NOTE(review): the message wording differs between
        # PyTorch releases; the second marker matches the
        # ``add_safe_globals`` / ``safe_globals`` advice in recent messages.
        detail = str(e)
        if "is not in the safe globals list" in detail or "safe_globals" in detail:
            raise RuntimeError(
                f"加载模型时安全检查失败!请将相关模块添加到白名单。错误详情:{e}\n"
                f"当前白名单模块:{SAFE_GLOBALS}"
            ) from e
        raise RuntimeError(f"加载预训练模型失败:{e}") from e
    except Exception as e:
        raise Exception(f"加载模型时发生未知错误:{e}") from e


def _split_parameters(model, new_layer_names):
    """Partition the model parameters into 'new' and 'base' groups by name."""
    # A parameter is "new" when its qualified name contains any of the
    # configured layer-name fragments.
    new_parameters = [
        param for pname, param in model.named_parameters()
        if any(layer_name in pname for layer_name in new_layer_names)
    ]
    # Compare by identity; a set keeps each membership test O(1).
    new_parameter_ids = {id(param) for param in new_parameters}
    base_parameters = [
        param for param in model.parameters()
        if id(param) not in new_parameter_ids
    ]
    return {'base_parameters': base_parameters,
            'new_parameters': new_parameters}


def generate_model(opt):
    """Build the network described by ``opt`` and optionally load pretrained weights.

    Args:
        opt: Options object. Attributes read here: ``model``, ``model_depth``,
            ``input_W``, ``input_H``, ``input_D``, ``resnet_shortcut``,
            ``no_cuda``, ``n_seg_classes``, ``gpu_id``, ``phase``,
            ``pretrain_path`` and ``new_layer_names``.

    Returns:
        A ``(model, parameters)`` tuple. When ``opt.phase != 'test'`` and
        ``opt.pretrain_path`` is set, ``parameters`` is a dict with the keys
        ``'base_parameters'`` and ``'new_parameters'``; otherwise it is the
        iterator returned by ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is unsupported.
        FileNotFoundError: If ``opt.pretrain_path`` does not exist.
        RuntimeError: If the checkpoint cannot be loaded, including rejection
            by the weights-only safety check.
        Exception: For any other failure while reading the checkpoint.
    """
    assert opt.model in [
        'resnet'
    ]

    if opt.model == 'resnet':
        model = _build_resnet(opt)

    if not opt.no_cuda:
        if len(opt.gpu_id) > 1:
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=opt.gpu_id)
        else:
            import os
            # Single GPU: the device is selected through the environment
            # before the model is moved to CUDA.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_id[0])
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
    net_dict = model.state_dict()

    if opt.phase != 'test' and opt.pretrain_path:
        print(f'正在加载预训练模型 {opt.pretrain_path}')
        pretrain = _load_pretrained_checkpoint(opt.pretrain_path)

        # Only copy weights whose keys exist in the freshly built network;
        # all remaining entries keep their initial values.
        pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict}
        net_dict.update(pretrain_dict)
        model.load_state_dict(net_dict)

        return model, _split_parameters(model, opt.new_layer_names)

    return model, model.parameters()