-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathanalysis.py
More file actions
226 lines (193 loc) · 7.17 KB
/
analysis.py
File metadata and controls
226 lines (193 loc) · 7.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# make a bunch of mongodb queries here
import pandas as pd
from array_graph import ArrayGraph
from align import get_best_timestamp, Alignment
from train_models import InfoMan, get_configs, get_parser
from helpers import validate_graph, change_names
from logger import MasterLog
class Conn:
    """Thin placeholder for a MongoDB connection handle.

    The `db` attribute starts out empty; callers are expected to attach a
    live database handle to it themselves.
    """

    def __init__(self):
        # Start disconnected — nothing is opened here.
        self.db = None
def precision_recall_edge_stats():
    """Collect per-edge precision/recall for train vs. test graphs.

    For every configured dataset, loads the aligned train and test graphs
    at the best checkpoint, records the graph's alpha, and emits one row
    per edge per split (linked through `paired`) so train/test points can
    be matched up in plots. Writes plots/alpha.csv and plots/confu.csv.
    """
    alpha_rows = {"dataset": [],
                  "alpha": []}
    edge_rows = {"dataset": [],
                 "split": [],
                 "precision": [],
                 "recall": [],
                 "paired": []}
    for dataset, arg_string in get_configs().items():
        args = get_parser().parse_args(arg_string.split())
        args.shuffle = False
        info = InfoMan(args, timestamp=get_best_timestamp(args),
                       no_model=True, no_dataset=True)
        # Load the train-split alignment graph at the best checkpoint.
        info.split = "train"
        train_graph: ArrayGraph = Alignment(info).load_graph()
        # Load the matching test-split graph from the same checkpoint.
        info.split = "test"
        test_graph: ArrayGraph = Alignment(info).load_graph()
        alpha_rows["dataset"].append(dataset)
        alpha_rows["alpha"].append(train_graph.alpha)
        for idx, (u, v) in enumerate(train_graph.g.edges):
            dat = train_graph.g[u][v]
            # Sanity check: stored edge stats must agree with the accessors.
            assert dat["precision"] == train_graph.get_precision(u, v)
            assert dat["recall"] == train_graph.get_recall(u, v)
            # One row for the train edge, one for the same edge evaluated
            # on the test graph; `paired` ties the two rows together.
            rows = (("train", dat["precision"], dat["recall"]),
                    ("test",
                     test_graph.get_precision(u, v),
                     test_graph.get_recall(u, v)))
            for split, prec, rec in rows:
                edge_rows["dataset"].append(dataset)
                edge_rows["precision"].append(prec)
                edge_rows["recall"].append(rec)
                edge_rows["split"].append(split)
                edge_rows["paired"].append(idx)
    pd.DataFrame(alpha_rows).to_csv("plots/alpha.csv")
    pd.DataFrame(edge_rows).to_csv("plots/confu.csv")
def alignment_stats(uni=False, intervention_type=0, version="alignment_results7"):
    """Build a train/test alignment-accuracy table from logged Mongo results.

    For each configured dataset, fetches the alignment_stats document with
    the highest `counterfactual_aligned` count (train and test separately),
    normalizes the counterfactual counts by the number of originally-correct
    examples, and writes the interleaved multi-column table to
    plots/align{_uni}.csv (also printed as LaTeX).

    uni: if truthy, restrict the Mongo query to unidirectional results.
    intervention_type: matched against the `type` field of logged documents.
    version: MasterLog version/collection to read from.
    Returns the assembled pandas DataFrame (indexed by dataset).
    """
    master = MasterLog("general", version)
    col = master.collection
    train_best = {}
    valid_best = {}
    df = {"dataset": [],
          }
    # First pass is train (vali=False), second is test (vali="test").
    for vali in (False, "test"):
        best = valid_best if vali else train_best
        for dataset in get_configs():
            keys = {"prepend": "alignment_stats",
                    "dataset": dataset,
                    "validation": vali,
                    "contradictions": {"$exists": 1},
                    "type": intervention_type}
            if uni:
                keys["unidirectional"] = uni
            res = col.find(keys)
            # make sure to change the version "alignment_results7" to your version so you don't load incorrectly
            # Highest counterfactual_aligned first; take the top document.
            res = res.sort("counterfactual_aligned", -1)
            r = next(iter(res))
            best[dataset] = r
        # Collapse the split marker to a bool: it becomes the second element
        # of every column key created below.
        vali = True if vali == "test" else False
        #
        df.update({("aligned", vali): [],
                   ("unchanged", vali): [],
                   ("changed unaligned", vali): [],
                   ("original correct", vali): [],
                   ("contradictions", vali): [], })
        for dataset, docu in best.items():
            oc = docu["original_correct"]
            # Dataset names are appended only once (during the test pass).
            if vali:
                df["dataset"].append(dataset)
            # Normalize counterfactual counts by originally-correct examples.
            df[("aligned", vali)].append(docu["counterfactual_aligned"] / oc)
            df[("unchanged", vali)].append(docu["counterfactual_unchanged"] / oc)
            df[("changed unaligned", vali)].append(docu["counterfactual_changed_unaligned"] / oc)
            df[("original correct", vali)].append(oc)
            df[("contradictions", vali)].append(docu["contradictions"])
    # Interleave columns so each train (False) column is followed by its
    # test (True) counterpart.
    # NOTE(review): the plain "dataset" key passes through this loop only
    # because "dataset"[1] == 'a' is not False — works, but fragile.
    reord = {"dataset": df["dataset"]}
    for col in df:
        if col[1] is False:
            reord[col] = df[col]
            trcol = (col[0], True)
            reord[trcol] = df[trcol]
    df = pd.DataFrame(reord)
    df = df.set_index('dataset')
    df = df.round(5)
    df.columns = pd.MultiIndex.from_tuples(df.columns, names=['Accuracy', 'Valid'])
    df.to_csv(f"plots/align{'_uni' if uni else ''}.csv")
    print(df.to_latex())
    return df
def interleaving_stats():
    """Compare 'original correct' accuracy before vs. after interleaving.

    For each configured dataset, reads the best-scoring initial
    ("initial_avgs") and final ("final") test-split documents from the
    convergence_push logs, then writes a before/after table to
    plots/interleave.csv (also printed as LaTeX).
    """
    per_dataset = {}
    for dataset in get_configs():
        master = MasterLog(dataset, "convergence_push")
        col = master.collection
        # Best initial test-split averages (highest "original correct").
        res = col.find({"prepend": "initial_avgs",
                        "split": "test"})
        res = res.sort("original correct", -1)
        ini = next(iter(res))
        # Best final test-split results.
        res = col.find({"split": "test",
                        "prepend": "final"})
        res = res.sort("original correct", -1)
        fin = next(iter(res))
        # BUG FIX: the value must be a dict keyed by split — the processing
        # loop below calls .items() on it. The original stored a bare
        # (ini, fin) tuple, which raised AttributeError at runtime.
        per_dataset[dataset] = {"test": (ini, fin)}
    # Reduce each (initial, final) document pair to its accuracy scores.
    processed = {}
    for ds, split_docs in per_dataset.items():
        processed[ds] = {}
        for split, (ini_doc, fin_doc) in split_docs.items():
            processed[ds][split] = (ini_doc['original_correct'],
                                    fin_doc['original_correct'])
    mul = {"dataset": [],
           "before": [],
           "after": []}
    for ds in processed:
        bef, aft = processed[ds]['test']
        mul["dataset"].append(ds)
        mul["before"].append(bef)
        mul["after"].append(aft)
    df = pd.DataFrame(mul)
    df.to_csv("plots/interleave.csv")
    print(df.to_latex())
def inter_layer_analysis():
    """Inspect edges crossing the first layer boundary of the mnist
    inter-layer alignment graph.

    Loads the train-split graph at the best checkpoint with inter_layer
    enabled, validates it against the valid-split graph, prints the overall
    accuracies, then prints per-edge stats for every edge whose endpoints
    straddle the first boundary.
    """
    configs = get_configs()
    args = get_parser().parse_args(configs["mnist"].split())
    args.shuffle = False
    best_timestamp = get_best_timestamp(args)
    info = InfoMan(args, timestamp=best_timestamp)
    info.split = "train"
    info.args["inter_layer"] = True
    info.load_checkpoint()
    align = Alignment(info)
    graph = align.load_graph()
    info.split = "valid"
    valid_graph = align.load_graph()
    precisions, edges, train_acc, valid_acc = validate_graph(graph, valid_graph)
    print(f"{train_acc=}, {valid_acc=}")
    # Unit-index boundaries between layers (presumably the mnist layer cut
    # points — TODO confirm against the model definition). Hoisted out of
    # the loop: the original rebuilt this list every iteration and never
    # read it, and also left a stray `pass` in the body.
    lo, mid, hi = 10, 138, 266
    for e in graph.g.edges:
        a, b = e
        # NOTE(review): lower bound for b uses strict `<` (not `<=`),
        # unlike a's — kept as originally written; confirm intent.
        if lo <= a[1] < mid and mid < b[1] < hi:
            print(f"{e}:{graph.g[a][b]}. Train/valid acc: {precisions[e]}")
def alignment_table():
    """Print the unidirectional alignment results as a LaTeX table.

    Pulls the stats via alignment_stats, renames columns for presentation,
    converts the accuracy ratios to percentages (leaving raw contradiction
    counts and any beta column untouched), and prints the LaTeX rendering.
    Removed a large block of commented-out dead code that combined results
    from a second run ("rerun2"); recover it from version control if needed.
    """
    df = alignment_stats(True, 0, "alignment_results7")
    df = change_names(df)
    # Ratio columns become percentages; count-like columns stay raw.
    no_scale = (('beta', ''), ('contradictions', False), ('contradictions', True))
    for col in df:
        if col not in no_scale:
            df[col] = df[col] * 100
    print(df.to_latex(float_format="%.2f"))
    print("DONE")
# Script entry point: regenerate the edge-stats CSVs and print the
# alignment LaTeX table.
if __name__ == '__main__':
    precision_recall_edge_stats()
    alignment_table()
    # inter_graph_analysis()
    # NOTE(review): no `inter_graph_analysis` is defined in this file; the
    # closest match is `inter_layer_analysis` — confirm before re-enabling.