-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
771 lines (654 loc) · 40.4 KB
/
model.py
File metadata and controls
771 lines (654 loc) · 40.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
import sys
import torch
import numpy as np
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
import os.path
from sklearn.model_selection import train_test_split
from new_node import myNode
from sklearn.metrics import confusion_matrix
from getOptions import getOptions
from pptree import *
import random
import time
# Sizes of the CIFAR-10 splits used below and the number of label classes.
TOTAL_TRAIN_IMG = 50_000  # images in the training split
TOTAL_TEST_IMG = 10_000  # images in the test split
TOTAL_CLASSES = 10  # distinct class labels
# Tree class
class Tree:
    """Binary tree of ``myNode`` nodes that is trained and evaluated from checkpoint files.

    Node state is persisted under ``options.ckptDir``: ``node_<id>.pth`` holds the
    node dictionary (key ``'nodeDict'``) and ``node_cnn_<id>.pth`` holds at least a
    ``'labelMap'``. Every method reads the module-level ``options`` object, which is
    only created in the ``__main__`` block of this file.
    """

    # constructor with default values given in arguments
    def __init__(self, device, maxDepth=1, dominanceThreshold=0.95, classThreshold=1, dataNumThreshold=100, numClasses=TOTAL_CLASSES):
        """Store the tree-growing thresholds.

        Args:
            device: torch device handed to every node.
            maxDepth: depth threshold (int); 0 makes the root the only (leaf) node.
            dominanceThreshold: (float) a child whose dominant-class ratio reaches
                this value becomes a leaf labelled with that class.
            classThreshold: (int >= 1) children with 1 class (and, if >= 2, with
                2 classes) become leaves; values above 2 add no further effect.
            dataNumThreshold: (int) a child with at most this many samples becomes a leaf.
            numClasses: (int >= 1) number of classes spanned by the input samples.
        """
        assert(isinstance(maxDepth, int))
        self.maxDepth = maxDepth  # depth threshold
        assert(isinstance(dominanceThreshold, float))
        self.dominanceThreshold = dominanceThreshold  # threshold on class dominance
        assert(isinstance(classThreshold, int))
        assert(classThreshold >= 1)
        self.classThreshold = classThreshold  # threshold on number of classes in a node
        assert(isinstance(dataNumThreshold, int))
        self.dataNumThreshold = dataNumThreshold  # threshold on total input images to a node
        assert(numClasses >= 1)
        assert(isinstance(numClasses, int))
        self.numClasses = numClasses  # total no. of classes spanning the input samples
        self.nodeArray = None  # all nodes of the tree in BFS order (filled by tree_traversal)
        self.device = device

    def checkLeafNodes(self, handleLeafDict):
        """Decide which children of a freshly split node are leaves or empty.

        Args:
            handleLeafDict: statistics about the two children produced by the parent's
                ``workTrain``/``workTest``. Keys read here: "lvl", "leftDataNum",
                "rightDataNum", "noOfLeftClasses", "noOfRightClasses", "maxLeft",
                "maxRight", "maxLeftClassIndex", "maxRightClassIndex".

        Returns:
            ``(isLeafLeft, isLeafRight, isemptyNodeLeft, isemptyNodeRight,
            leftLeafClass, rightLeafClass)``; a leaf class of -1 means no single
            class was assigned.
        """
        # whether the left / right child must be made a leaf
        isLeafLeft = False
        isLeafRight = False
        # whether the left / right child received no class at all
        isemptyNodeLeft = False
        isemptyNodeRight = False
        # class index assigned to a single-class or dominated leaf; -1 = none
        leftLeafClass = -1
        rightLeafClass = -1

        ############# maxDepth #############
        # both children are forced to be leaves once "lvl" equals maxDepth
        # NOTE(review): whether "lvl" is the parent's or the children's level is set in myNode — confirm
        if handleLeafDict["lvl"] == self.maxDepth:
            isLeafLeft = True
            isLeafRight = True
            if options.verbose > 0:
                print("MAX DEPTH REACHED", self.maxDepth)

        ############# dataNumThreshold #############
        # a child with at most dataNumThreshold samples becomes a leaf
        if handleLeafDict["leftDataNum"] <= self.dataNumThreshold:
            isLeafLeft = True
            if options.verbose > 0:
                print("DATANUM THRESHOLD REACHED IN LEFT", handleLeafDict["leftDataNum"], self.dataNumThreshold)
        if handleLeafDict["rightDataNum"] <= self.dataNumThreshold:
            isLeafRight = True
            if options.verbose > 0:
                print("DATANUM THRESHOLD REACHED IN RIGHT", handleLeafDict["rightDataNum"], self.dataNumThreshold)

        ############# classThreshold #############
        # 0 classes -> empty child; 1 class -> leaf with that class;
        # 2 classes -> leaf (only when classThreshold >= 2); larger counts are not checked here
        ## LEFT CHILD
        if handleLeafDict["noOfLeftClasses"] == 0:
            isemptyNodeLeft = True
            if options.verbose > 0:
                print("CLASS THRESHOLD REACHED 0 IN LEFT", handleLeafDict["noOfLeftClasses"], self.classThreshold)
        elif (self.classThreshold >= 1) and (handleLeafDict["noOfLeftClasses"] == 1):
            isLeafLeft = True
            leftLeafClass = handleLeafDict["maxLeftClassIndex"]
            if options.verbose > 0:
                print("CLASS THRESHOLD REACHED 1 IN LEFT", handleLeafDict["noOfLeftClasses"], self.classThreshold, leftLeafClass)
        elif (self.classThreshold >= 2) and (handleLeafDict["noOfLeftClasses"] == 2):
            isLeafLeft = True
            if options.verbose > 0:
                print("CLASS THRESHOLD REACHED 2 IN LEFT", handleLeafDict["noOfLeftClasses"], self.classThreshold)
        ## RIGHT CHILD
        if handleLeafDict["noOfRightClasses"] == 0:
            isemptyNodeRight = True
            if options.verbose > 0:
                print("CLASS THRESHOLD REACHED 0 IN RIGHT", handleLeafDict["noOfRightClasses"], self.classThreshold)
        elif (self.classThreshold >= 1) and (handleLeafDict["noOfRightClasses"] == 1):
            isLeafRight = True
            rightLeafClass = handleLeafDict["maxRightClassIndex"]
            if options.verbose > 0:
                print("CLASS THRESHOLD REACHED 1 IN RIGHT", handleLeafDict["noOfRightClasses"], self.classThreshold, rightLeafClass)
        elif (self.classThreshold >= 2) and (handleLeafDict["noOfRightClasses"] == 2):
            isLeafRight = True
            if options.verbose > 0:
                print("CLASS THRESHOLD REACHED 2 IN RIGHT", handleLeafDict["noOfRightClasses"], self.classThreshold)

        ############# dominanceThreshold #############
        # a child whose dominant-class ratio reaches dominanceThreshold becomes a leaf of that class
        if handleLeafDict["maxLeft"] >= self.dominanceThreshold:
            isLeafLeft = True
            leftLeafClass = handleLeafDict["maxLeftClassIndex"]
            if options.verbose > 0:
                print("DOMINANCE THRESHOLD REACHED IN LEFT", handleLeafDict["maxLeft"], self.dominanceThreshold, leftLeafClass)
        if handleLeafDict["maxRight"] >= self.dominanceThreshold:
            isLeafRight = True
            rightLeafClass = handleLeafDict["maxRightClassIndex"]
            if options.verbose > 0:
                print("DOMINANCE THRESHOLD REACHED IN RIGHT", handleLeafDict["maxRight"], self.dominanceThreshold, rightLeafClass)
        return isLeafLeft, isLeafRight, isemptyNodeLeft, isemptyNodeRight, leftLeafClass, rightLeafClass

    def tree_traversal(self, trainInputDict, valInputDict, resumeTrain, resumeFromNodeId):
        """Grow and train the whole tree breadth-first, saving every node under options.ckptDir.

        Args:
            trainInputDict: training dictionary for the root (as built by loadNewDictionaries).
            valInputDict: validation dictionary for the root.
            resumeTrain: when true, nodes with ``nodeId < resumeFromNodeId`` are run with
                ``workTest(..., True)`` instead of being trained, so an interrupted run can
                resume from an already fully trained node.
            resumeFromNodeId: first node id that is trained again when resuming.

        Node ids are assigned in creation order starting at 1 (root); all nodes end up in
        ``self.nodeArray`` in BFS order.
        """
        if options.verbose > 0:
            print("\nTRAINING STARTS")
        # the root is a leaf only when a single-node tree is requested
        rootNode = myNode(parentId=0, nodeId=1, device=self.device, isTrain=True, level=0, parentNode=None)
        rootNode.setInput(trainInputDict=trainInputDict, valInputDict=valInputDict, numClasses=self.numClasses, giniValue=0.9, isLeaf=(self.maxDepth == 0), leafClass=-1, lchildId=-1, rchildId=-1)
        oneHotTensors = torch.zeros(len(trainInputDict["label"]), TOTAL_CLASSES)
        nodeProb = torch.ones(len(trainInputDict["label"]))

        self.nodeArray = [rootNode]
        start = 0
        end = 1  # equals len(self.nodeArray), i.e. also the id of the newest node
        while start != end:
            node = self.nodeArray[start]
            if options.verbose > 0:
                print("Running nodeId: ", node.nodeId)
            start += 1
            alreadyTrained = resumeTrain and node.nodeId < resumeFromNodeId
            if not node.isLeaf:
                if alreadyTrained:
                    lTrainDict, lValDict, rTrainDict, rValDict, giniLeftRatio, giniRightRatio, handleLeafDict = node.workTest(nodeProb, oneHotTensors, True)
                else:
                    lTrainDict, lValDict, rTrainDict, rValDict, giniLeftRatio, giniRightRatio, handleLeafDict = node.workTrain()
            elif alreadyTrained:
                node.workTest(nodeProb, oneHotTensors, True)
            else:
                node.workTrain()

            # isLeaf is read again after the node has run
            if not node.isLeaf:
                isLeafLeft, isLeafRight, isemptyNodeLeft, isemptyNodeRight, leftLeafClass, rightLeafClass = self.checkLeafNodes(handleLeafDict)
                nodeCkptPath = options.ckptDir + '/node_' + str(node.nodeId) + '.pth'
                ParentNodeDict = torch.load(nodeCkptPath)['nodeDict']
                # non-empty children are queued and linked from the parent's saved dictionary
                if not isemptyNodeLeft:
                    end += 1
                    lNode = myNode(node.nodeId, end, self.device, True, node.level + 1, node)
                    lNode.setInput(lTrainDict, lValDict, handleLeafDict["noOfLeftClasses"], giniLeftRatio, isLeafLeft, leftLeafClass, -1, -1)
                    self.nodeArray.append(lNode)
                    ParentNodeDict['lchildId'] = lNode.nodeId
                if not isemptyNodeRight:
                    end += 1
                    rNode = myNode(node.nodeId, end, self.device, True, node.level + 1, node)
                    rNode.setInput(rTrainDict, rValDict, handleLeafDict["noOfRightClasses"], giniRightRatio, isLeafRight, rightLeafClass, -1, -1)
                    self.nodeArray.append(rNode)
                    ParentNodeDict['rchildId'] = rNode.nodeId
                ParentNodeDict['giniGain'] = handleLeafDict["giniGain"]
                torch.save({
                    'nodeDict': ParentNodeDict,
                }, nodeCkptPath)

    def _clampToMaxDepth(self, nodeDict):
        """Return ``(isLeaf, lchildId, rchildId)`` for a saved node, forcing a childless leaf at level >= maxDepth."""
        if nodeDict['level'] >= self.maxDepth:
            return True, -1, -1
        return nodeDict['isLeaf'], nodeDict['lchildId'], nodeDict['rchildId']

    def _loadTestRoot(self, testInputDict):
        """Rebuild the root node (id 1) for evaluation from its training checkpoints."""
        nodeId = 1
        noOfClasses = len(torch.load(options.ckptDir + '/node_cnn_' + str(nodeId) + '.pth')['labelMap'])
        rootNodeDict = torch.load(options.ckptDir + '/node_' + str(nodeId) + '.pth')['nodeDict']
        isLeafRoot, leftChildId, rightChildId = self._clampToMaxDepth(rootNodeDict)
        rootNode = myNode(parentId=rootNodeDict['parentId'], nodeId=rootNodeDict['nodeId'], device=self.device, isTrain=False, level=rootNodeDict['level'], parentNode=None)
        rootNode.setInput(trainInputDict=testInputDict, valInputDict={}, numClasses=noOfClasses, giniValue=0.9, isLeaf=isLeafRoot, leafClass=rootNodeDict['leafClass'], lchildId=leftChildId, rchildId=rightChildId)
        return rootNode

    def _buildTestChild(self, parentNode, childId, childInputDict, giniRatio):
        """Create the evaluation node for child ``childId`` of ``parentNode`` from its checkpoint.

        Returns ``(childNode, noOfClasses)``; noOfClasses is 1 for a node saved with a
        leafClass, otherwise the size of the child's CNN label map.
        """
        childNodeDict = torch.load(options.ckptDir + '/node_' + str(childId) + '.pth')['nodeDict']
        noOfClasses = 1
        if childNodeDict['leafClass'] == -1:
            noOfClasses = len(torch.load(options.ckptDir + '/node_cnn_' + str(childId) + '.pth')['labelMap'])
        childNode = myNode(parentNode.nodeId, childId, self.device, False, childNodeDict['level'], parentNode)
        isLeafChild, leftChildId, rightChildId = self._clampToMaxDepth(childNodeDict)
        childNode.setInput(childInputDict, {}, noOfClasses, giniRatio, isLeafChild, childNodeDict['leafClass'], leftChildId, rightChildId)
        return childNode, noOfClasses

    @staticmethod
    def _oneHotInIndexOrder(indices, labels, preds):
        """Sort predictions by original sample index and return them one-hot encoded.

        Returns a numpy array with one row per prediction and TOTAL_CLASSES columns.
        Tensors are converted with ``tolist()`` so CUDA tensors are handled too.
        """
        if len(preds) == 0:
            return np.zeros((0, TOTAL_CLASSES))
        _, _, predList = zip(*sorted(zip(indices.tolist(), labels.tolist(), preds.tolist())))
        predArr = np.asarray(predList, dtype=int)
        oneHotVector = np.zeros((predArr.size, TOTAL_CLASSES))
        oneHotVector[np.arange(predArr.size), predArr] = 1
        return oneHotVector

    def testTraversal(self, testInputDict):
        """Evaluate the saved tree breadth-first on ``testInputDict``.

        Prints the confusion matrix, the final accuracy and the per-level accuracies,
        and returns the predictions one-hot encoded in original sample order.
        Results are exchanged with the nodes through ``testPred.pth`` and ``level.pth``
        in options.ckptDir.
        """
        if options.verbose > 0:
            print("\nTESTING STARTS")
        rootNode = self._loadTestRoot(testInputDict)
        oneHotTensors = torch.zeros(len(testInputDict["label"]), TOTAL_CLASSES)
        nodeProb = torch.ones(len(testInputDict["label"]))

        # predicted/actual labels and original indices; saved so the nodes can append to them
        testPredDict = {}
        testPredDict['actual'] = ((torch.rand(0)).long()).to(self.device)
        testPredDict['pred'] = ((torch.rand(0)).long()).to(self.device)
        testPredDict['index'] = ((torch.rand(0)).long()).to(self.device)
        torch.save({
            'testPredDict': testPredDict,
        }, options.ckptDir + '/testPred.pth')

        # level-wise accuracy bookkeeping shared with the nodes through level.pth
        levelPath = options.ckptDir + '/level.pth'
        LevelDict = {'levelAcc': {}, 'leafAcc': [0, 0]}
        torch.save({
            'levelDict': LevelDict,
        }, levelPath)

        prevLvl = -1
        q = [rootNode]
        start = 0
        end = 1
        while start != end:
            node = q[start]
            curLvl = node.level
            # at the start of a new level, snapshot the accumulated leaf accuracy for it
            if curLvl > prevLvl:
                LevelDict = torch.load(levelPath)['levelDict']
                LevelDict['levelAcc'][curLvl] = LevelDict['leafAcc'][:]
                torch.save({
                    'levelDict': LevelDict,
                }, levelPath)
                prevLvl = curLvl
            start += 1
            if not node.isLeaf:
                lTrainDict, rTrainDict, giniLeftRatio, giniRightRatio, noOfLeftClasses, noOfRightClasses, lChildProb, rChildProb = node.workTest(nodeProb, oneHotTensors)
            else:
                node.workTest(nodeProb, oneHotTensors)
            if not node.isLeaf:
                # a child is evaluated only if it existed in training and received test samples
                if not ((node.lchildId == -1) or (len(lTrainDict["label"]) == 0)):
                    lNode, noOfLeftClasses = self._buildTestChild(node, node.lchildId, lTrainDict, giniLeftRatio)
                    q.append(lNode)
                    end += 1
                if not ((node.rchildId == -1) or (len(rTrainDict["label"]) == 0)):
                    rNode, noOfRightClasses = self._buildTestChild(node, node.rchildId, rTrainDict, giniRightRatio)
                    q.append(rNode)
                    end += 1
                if options.verbose > 1:
                    print('Nodes sizes = ', noOfLeftClasses, noOfRightClasses)

        testPredDict = torch.load(options.ckptDir + '/testPred.pth')['testPredDict']
        testPredDict['actual'] = testPredDict['actual'].to("cpu")
        testPredDict['pred'] = testPredDict['pred'].to("cpu")
        cm = confusion_matrix(testPredDict['actual'], testPredDict['pred'])
        print(cm)
        print()
        correct = testPredDict['pred'].eq(testPredDict['actual']).sum().item()
        total = len(testPredDict['actual'])
        if total != 0:
            print('Final Acc: %.3f' % (100. * correct / total))
        else:
            print('Final Acc: 0')
        print()

        LevelDict = torch.load(levelPath)['levelDict']
        for level, levelAcc in LevelDict['levelAcc'].items():
            # levelAcc is [correct, total]; guard levels that saw no samples
            if levelAcc[1] != 0:
                print('Level %d Acc: %.3f' % (level, 100. * levelAcc[0] / levelAcc[1]))
            else:
                print('Level %d Acc: 0' % level)
        return self._oneHotInIndexOrder(testPredDict['index'], testPredDict['actual'], testPredDict['pred'])

    def printTree(self):
        """Print the saved tree with pptree: ids, leaf classes, labels, gini gain, data counts and (in test mode) accuracies."""
        if options.verbose > 0:
            print("\nPRINTING TREE STARTS")
        nodeId = 1
        rootNodeDict = torch.load(options.ckptDir + '/node_' + str(nodeId) + '.pth')['nodeDict']
        rootNode = myNode(parentId=rootNodeDict['parentId'], nodeId=rootNodeDict['nodeId'], device=self.device, isTrain=False, level=rootNodeDict['level'], parentNode=None)
        # BFS over the saved nodes; passing the parent to myNode attaches the child to it
        q = [rootNode]
        start = 0
        end = 1
        while start != end:
            node = q[start]
            start += 1
            currNodeDict = torch.load(options.ckptDir + '/node_' + str(node.nodeId) + '.pth')['nodeDict']
            # displayed attributes are stored as strings for print_tree
            node.nodeId = str(node.nodeId)
            for field in ('lchildId', 'rchildId'):
                setattr(node, field, str(currNodeDict[field]))
            node.isLeaf = currNodeDict['isLeaf']
            for field in ('level', 'leafClass', 'numClasses', 'numData', 'classLabels', 'giniGain', 'splitAcc', 'nodeAcc'):
                setattr(node, field, str(currNodeDict[field]))
            if not node.isLeaf:
                if int(node.lchildId) != -1:
                    q.append(myNode(int(node.nodeId), int(node.lchildId), self.device, False, int(node.level) + 1, node))
                    end += 1
                if int(node.rchildId) != -1:
                    q.append(myNode(int(node.nodeId), int(node.rchildId), self.device, False, int(node.level) + 1, node))
                    end += 1
            node.isLeaf = str(node.isLeaf)
        # (attribute, horizontal layout, only printed in test mode); print_tree comes from pptree
        for attr, horizontal, testOnly in (
                ("nodeId", False, False),
                ("leafClass", False, False),
                ("classLabels", True, False),
                ("splitAcc", False, True),
                ("giniGain", False, False),
                ("numData", False, False),
                ("nodeAcc", True, True)):
            if testOnly and not options.testFlg:
                continue
            print(attr)
            print_tree(rootNode, "children", attr, horizontal=horizontal)
        print()

    def DFS(self, testInputDict):
        """Evaluate the saved tree depth-first (probabilistic approach: samples are sent to every node).

        Prints the final accuracy and returns the predictions one-hot encoded in
        original sample order.
        """
        if options.verbose > 0:
            print("\nDFS STARTS")
        rootNode = self._loadTestRoot(testInputDict)
        oneHotTensors = torch.zeros(len(testInputDict["label"]), TOTAL_CLASSES)
        nodeProb = torch.ones(len(testInputDict["label"]))
        self.dfsTraversal(rootNode, nodeProb, oneHotTensors)
        # oneHotTensors has been updated in place by the nodes; its argmax is the prediction
        _, predicted = oneHotTensors.max(1)
        correct = predicted.to(self.device).eq(testInputDict["label"].to(self.device)).sum().item()
        total = len(oneHotTensors)
        if total != 0:
            print('FINAL Acc: %.3f' % (100. * correct / total))
        else:
            print('FINAL Acc: 0')
        # built from CPU/Python values, so this also works when self.device is CUDA
        return self._oneHotInIndexOrder(testInputDict["index"], testInputDict["label"], predicted)

    def dfsTraversal(self, node, nodeProb, oneHotTensors):
        """Run ``node`` and recurse depth-first into each child that exists and received samples."""
        if not node.isLeaf:
            lTrainDict, rTrainDict, giniLeftRatio, giniRightRatio, noOfLeftClasses, noOfRightClasses, lChildProb, rChildProb = node.workTest(nodeProb, oneHotTensors)
        else:
            node.workTest(nodeProb, oneHotTensors)
        if not node.isLeaf:
            if not ((node.lchildId == -1) or (len(lTrainDict["label"]) == 0)):
                lNode, _ = self._buildTestChild(node, node.lchildId, lTrainDict, giniLeftRatio)
                self.dfsTraversal(lNode, lChildProb, oneHotTensors)
            if not ((node.rchildId == -1) or (len(rTrainDict["label"]) == 0)):
                rNode, _ = self._buildTestChild(node, node.rchildId, rTrainDict, giniRightRatio)
                self.dfsTraversal(rNode, rChildProb, oneHotTensors)
# loads the train, validation, and test dictionaries
def loadNewDictionaries(seed_val=0, train_num=TOTAL_TRAIN_IMG, test_num=TOTAL_TEST_IMG, useValidation=False):
    """Load CIFAR-10 and return the (train, validation, test) input dictionaries.

    Each dictionary has the keys "data", "label" and "index", where "index" is
    ``torch.arange(len(label))``.

    Args:
        seed_val: seed for torch, numpy and random. When ``train_num < TOTAL_TRAIN_IMG``
            it decides which stratified subset of training images is drawn, so each
            tree of an ensemble can use a different subset.
        train_num: number of training images to keep; they are returned as one batch.
        test_num: number of test images, taken in dataset order.
        useValidation: if True (and ``train_num < TOTAL_TRAIN_IMG``), the validation
            dictionary gets one batch of at most ``100 * TOTAL_CLASSES`` held-out
            training images; otherwise it is empty.

    Returns:
        ``(trainDict, valDict, testDict)``.
    """
    # seed everything for this tree so its training subset is reproducible
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed(seed_val)
    np.random.seed(seed_val)
    random.seed(seed_val)
    torch.backends.cudnn.deterministic = True

    # dataset and checkpoint directories (makedirs also creates missing parents)
    os.makedirs('data/', exist_ok=True)
    os.makedirs(options.ckptDir + '/', exist_ok=True)

    if options.verbose > 0:
        print('==> Preparing data...')
    # training images get a random horizontal flip; no normalization is applied
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    class_labels = trainset.targets

    # validation set stays empty unless it is requested below
    valData = torch.empty(0)
    valLabels = torch.empty(0)
    valIndices = torch.arange(0, len(valLabels), step=1, dtype=torch.long)

    # num_workers=0 in every loader helps reproduce previously obtained results
    if train_num < TOTAL_TRAIN_IMG:
        # Stratified random subset of exactly train_num images. The size must be compared
        # directly: int(1 - train_num / TOTAL_TRAIN_IMG) truncates to 0 for every partial size.
        train_idx, valid_idx = train_test_split(
            np.arange(len(class_labels)),
            train_size=train_num,
            shuffle=True,
            stratify=class_labels)
        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_num, sampler=train_sampler, num_workers=0)
        if useValidation:
            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)
            valid_loader = torch.utils.data.DataLoader(trainset, batch_size=min(100 * TOTAL_CLASSES, TOTAL_TRAIN_IMG - train_num), sampler=valid_sampler, num_workers=0)
            valBatch = next(iter(valid_loader))
            valData = valBatch[0].clone().detach()
            valLabels = valBatch[1].clone().detach()
            valIndices = torch.arange(0, len(valLabels), step=1, dtype=torch.long)
    else:
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_num, num_workers=0)

    # the whole training subset is taken as a single batch
    trainBatch = next(iter(train_loader))
    trainData = trainBatch[0].clone().detach()
    trainLabels = trainBatch[1].clone().detach()
    trainIndices = torch.arange(0, len(trainLabels), step=1, dtype=torch.long)

    # the test split uses seed 0 for every tree so all trees are compared on the same data
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_num, shuffle=False, num_workers=0)
    testBatch = next(iter(testloader))
    testData = testBatch[0].clone().detach()
    testLabels = testBatch[1].clone().detach()
    testIndices = torch.arange(0, len(testLabels), step=1, dtype=torch.long)

    return ({"data": trainData, "label": trainLabels, "index": trainIndices},
            {"data": valData, "label": valLabels, "index": valIndices},
            {"data": testData, "label": testLabels, "index": testIndices})
# prints the final accuracy of the ensemble formed
def getEnsembleAcc(sumOneHotVector, testInputDict):
    """Print (and return) the accuracy of the ensemble's accumulated votes.

    The predicted class of each test sample is the column with the largest
    accumulated score in ``sumOneHotVector``; it is compared against the
    ground-truth labels in ``testInputDict['label']``.

    Parameters
    ----------
    sumOneHotVector : numpy.ndarray
        2-D array of summed per-class scores, one row per test sample.
    testInputDict : dict
        Dictionary holding the ground-truth labels under the key ``'label'``
        as a torch tensor with one entry per row of ``sumOneHotVector``.

    Returns
    -------
    float
        Accuracy in percent; 0.0 when there are no test labels.
    """
    labels = testInputDict['label']
    total = len(labels)
    if total == 0:
        # nothing to score; checked first so an empty label tensor never has to
        # broadcast against a non-empty prediction array
        print('Ensemble Final Acc: 0')
        print()
        return 0.0
    predArr = np.argmax(sumOneHotVector, axis=1)
    # predictions come from numpy (CPU), so compare against CPU labels even if the
    # label tensor was moved to another device
    correct = torch.from_numpy(predArr).eq(labels.cpu()).sum().item()
    acc = 100. * correct / total
    print('Ensemble Final Acc: %.3f' % acc)
    print()
    return acc
# main function
if __name__ == '__main__':
    # we get the options from arguments (refer <getOptions.py> for all the available options and their default values and execute <python3 model.py --h> for more info.)
    options = getOptions(sys.argv[1:])
    # stores the options in the list (printed only in verbose mode)
    L = [ "options.trainFlg:" + str(options.trainFlg), " options.testFlg:" + str(options.testFlg), " options.maxDepth:" + str(options.maxDepth),
          " options.ckptDir:" + options.ckptDir, " options.cnnOut:" + str(options.cnnOut),
          " options.mlpFC1:" + str(options.mlpFC1), " options.mlpFC2:" + str(options.mlpFC2),
          " options.cnnLR:" + str(options.cnnLR), " options.mlpLR:" + str(options.mlpLR),
          " options.cnnEpochs:" + str(options.cnnEpochs), " options.mlpEpochs:" + str(options.mlpEpochs),
          " options.cnnSchEpochs:" + str(options.cnnSchEpochs), " options.mlpSchEpochs:" + str(options.mlpSchEpochs),
          " options.cnnSchFactor:" + str(options.cnnSchFactor), " options.mlpSchFactor:" + str(options.mlpSchFactor),
          " options.cnnBatches:" + str(options.cnnBatches), " options.mlpBatches:" + str(options.mlpBatches),
          " options.caseNum:" + str(options.caseNum), " options.optionNum:" + str(options.optionNum),
          " options.ensemble:" + str(options.ensemble), " options.probabilistic:" + str(options.probabilistic),
          " options.verbose:" + str(options.verbose) ]
    if options.verbose > 1:
        # printing the options
        print(L)
    # At least one tree is required: options.ensemble is a divisor for trainData_size below,
    # and with no trees <tree> and <testInputDict> would never be filled in before use.
    if options.ensemble < 1:
        sys.exit("options.ensemble must be at least 1, got " + str(options.ensemble))
    start = time.time()
    testInputDict = {} # initialising test input dictionary
    # trainData_size stores the no. of training data samples taken per Tree in the Ensemble:
    # an equal share of the training images, but never fewer than 100*TOTAL_CLASSES
    trainData_size = int(max(TOTAL_TRAIN_IMG/options.ensemble, 100*TOTAL_CLASSES))
    # round down to a multiple of TOTAL_CLASSES
    trainData_size -= (trainData_size%TOTAL_CLASSES)
    if options.verbose > 1:
        print("trainData_size:", trainData_size)
    sumOneHotVector = np.zeros((TOTAL_TEST_IMG,TOTAL_CLASSES),dtype=float) # initialising the Sum of One Hot Vectors of each Tree in the Ensemble
    # concatOneHotVector = np.zeros((TOTAL_TEST_IMG,0),dtype=int) # initialising the Concat. Output of One Hot Vectors of each Tree in the Ensemble
    # the device is the same for every tree, so it is chosen once before the loop
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device = "cuda" if GPU is available, else "cpu"
    # iterating over all the trees in the ensemble
    for treeIndx in range(0, options.ensemble):
        # we load train, validation and test set dictionaries of the given dataset (currently CIFAR-10)
        trainInputDict, valInputDict, testInputDict = loadNewDictionaries(treeIndx, trainData_size) # treeIndx here provides the different seed values so as to create differently trained trees
        if options.verbose > 1:
            print("len(trainInputDict[\"data\"]): ",len(trainInputDict["data"]), ", len(valInputDict[\"data\"]): ",len(valInputDict["data"]), ", len(testInputDict[\"data\"]): ",len(testInputDict["data"]))
        ## This below piece of code can be used for debugging by taking those test samples that belong to only a single particular class (classIDToBeTaken)
        ## If using this, then Replace <testInputDict> with <newTestInputDict>
        ## Toggle: the triple-quoted string below disables the code; prefix its opening ''' with # to enable it
        '''
        newTestInputDict = {}
        newTestInputDict["data"] = torch.rand(0,3,32,32) # change this according to the dimensions of the input image
        newTestInputDict["label"] = (torch.rand(0)).long()
        newTestInputDict["index"] = (torch.rand(0)).long()
        cnt = 0
        sz = 100 # no. of test samples to be taken
        classIDToBeTaken = 1
        for k,v in enumerate(testInputDict["label"]):
            if (v.item() == classIDToBeTaken) and cnt<sz:
                cnt += 1
                newTestInputDict["data"] = torch.cat((newTestInputDict["data"],testInputDict["data"][k].view(1,3,32,32)),0) # change this according to the dimensions of the input image
                newTestInputDict["label"] = torch.cat((newTestInputDict["label"],v.view(1)),0)
                newTestInputDict["index"] = torch.cat((newTestInputDict["index"],testInputDict["index"][k].view(1)),0)
        if options.verbose > 1:
            print(newTestInputDict["data"].shape, newTestInputDict["label"].shape, newTestInputDict["index"].shape)
        # '''
        # creating the main <tree> object for further processing
        tree = Tree(device, maxDepth=options.maxDepth, dominanceThreshold=0.95, classThreshold = 1, dataNumThreshold = 100, numClasses = TOTAL_CLASSES)
        # if Training Mode is On
        if options.trainFlg:
            # Toggle: commenting the opening ''' line below activates the resume code and turns the
            # fresh-training lines into a string (closed by the ''' inside the trailing "# '''" line)
            ''' ## if require resuming Training from some previously "fully" trained node, then comment this line, it will automatically comment the other below part of code
            resumeFromNodeId = 4
            tree.tree_traversal(trainInputDict, valInputDict, resumeTrain=True, resumeFromNodeId=resumeFromNodeId)
            '''
            resumeFromNodeId = -1 # if freshly training the tree from the root node
            tree.tree_traversal(trainInputDict, valInputDict, resumeTrain=False, resumeFromNodeId=resumeFromNodeId)
            # '''
        # if Testing Mode is On
        if options.testFlg:
            # both branches below assign oneHotVector (the one Hot encoding output of the predicted labels)
            # if probabilistic method isn't used
            if not options.probabilistic:
                ''' ## just comment this line in order to use <newTestInputDict>
                oneHotVector = tree.testTraversal(newTestInputDict) # if the earlier mentioned <newTestInputDict> is created
                '''
                oneHotVector = tree.testTraversal(testInputDict)
                # '''
            # if probabilistic method is used
            else:
                ''' ## just comment this line in order to use <newTestInputDict>
                oneHotVector = tree.DFS(newTestInputDict) # if the earlier mentioned <newTestInputDict> is created
                '''
                oneHotVector = tree.DFS(testInputDict)
                # '''
            sumOneHotVector += oneHotVector # summing oneHotTensors of all trees in the ensemble
            # concatOneHotVector = np.concatenate((concatOneHotVector, oneHotVector), axis=1) # concatenating oneHotTensors of all trees in the ensemble
        # NOTE(review): placed inside the per-tree loop so every tree of the ensemble is printed;
        # the original indentation was not visible here — confirm.
        if options.trainFlg or options.testFlg:
            tree.printTree() # printing the built Tree
    if options.testFlg:
        getEnsembleAcc(sumOneHotVector, testInputDict) # get final accuracy of the ensemble formed
    end = time.time()
    print("Time Taken by whole program is ", float(end-start)/60.0, " minutes.")