-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
56 lines (44 loc) · 1.3 KB
/
main.py
File metadata and controls
56 lines (44 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
# @Time : 2022/9/9 17:19
# @Author : LuMing
# @File : main.py
# @Software: PyCharm
# @Comment :
import numpy as np
import bpnn
import data_load
from bpnn import Bpnn
def binarize(labels, num_classes=10):
    """One-hot encode integer class labels.

    Args:
        labels: iterable of class ids; each is compared with ``==`` against
            ``0 .. num_classes - 1``.
        num_classes: length of each one-hot vector (default 10, the value the
            original implementation hard-coded).

    Returns:
        A list with one entry per label. Each entry is a list of
        ``num_classes`` ints that is 1 at the position equal to the label and
        0 elsewhere. A label outside ``range(num_classes)`` yields an all-zero
        vector.
    """
    return [[1 if i == label else 0 for i in range(num_classes)] for label in labels]
def unidimensional(images):
    """Flatten each image to 1-D and scale its values by 1/255.

    Args:
        images: iterable of array-likes (anything ``np.array`` accepts).

    Returns:
        A list of 1-D float numpy arrays, one per input image, each equal to
        the row-major flattening of the image divided by 255.
    """
    # reshape(-1) gives the same 1-D result as the former reshape(1, -1)[0].
    return [np.array(image).reshape(-1) / 255 for image in images]
def main():
    """Load the data set, train a Bpnn on it, and evaluate it on the test split.

    Training labels are one-hot encoded with ``binarize``. Both image sets are
    flattened and scaled by 1/255 with ``unidimensional``. Test labels are
    passed to ``Bpnn.test`` without encoding.
    """
    # Load the four data sets in a fixed order: train images, train labels,
    # test images, test labels.
    raw_train_x = data_load.load_train_images()
    raw_train_y = data_load.load_train_labels()
    raw_test_x = data_load.load_test_images()
    test_y = data_load.load_test_labels()

    # Preprocess: one-hot training targets and flattened/scaled images.
    # NOTE(review): test labels are left un-encoded. Presumably Bpnn.test
    # compares its predictions against integer class ids; confirm this.
    train_y = binarize(raw_train_y)
    train_x = unidimensional(raw_train_x)
    test_x = unidimensional(raw_test_x)

    input_size = 28 * 28  # one input per pixel; assumes 28x28 images (confirm)
    hidden_size = 280
    output_size = 10  # same width as binarize's one-hot vectors
    network = Bpnn(input_size, hidden_size, output_size)

    # Train on at most the first 60000 samples, then evaluate.
    network.train(train_x[:60000], train_y[:60000])
    network.test(test_x, test_y)


if __name__ == "__main__":
    main()