# bbes.py
from __future__ import print_function
import numpy as np
import Queue
import math
import collections
import warnings
import pickle
import time
import sys
from UnionFind import UnionFind # from David Eppstein's PADS library
def node_name(v):
return chr(97 + v)
def node_name_list(S):
ret = ""
for v, val in enumerate(S):
if val:
ret += node_name(v)
return ret
def bin2mask(d, bin):
mask = np.zeros(d, dtype=bool)
for i in range(d):
if bin & (1 << i):
mask[i] = True
return mask
def mask2bin(mask):
# self.powersof2 = np.ones(self.d, dtype=int)
# for i in range(1, self.d):
# self.powersof2[i] = self.powersof2[i-1] << 1
# ...
# int(self.powersof2.dot(mask))
bin = 0
for i, val in enumerate(mask):
if val:
bin += (1 << i)
return bin
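# Usage sketch (illustration only): bin2mask/mask2bin encode a subset of nodes
# as a bit mask, with bit i standing for node i. For d = 4, the integer
# 0b0101 selects nodes {a, c}.
def bin2mask_mask2bin_EXAMPLE():
    d = 4
    mask = bin2mask(d, 0b0101)     # array([ True, False,  True, False])
    print(node_name_list(mask))    # "ac"
    print(mask2bin(mask))          # 5: the round trip recovers the integer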
class EquivalenceClass:
"""A Markov equivalence class (set of DAGs) as point in the search space"""
def __init__(self, pdag, essential_mB=None, children=None):
if essential_mB is not None:
# with precomputed search space
self.example_mB = pdag # so should not contain undir edges here
self.num_edges = np.count_nonzero(self.example_mB)
(self.skeleton, self.vstructs) = skeleton_vstructs_normalizer(self.example_mB)
self.essential_mB = essential_mB # AND of mB's
self.children = children
self.children_identifying_constraint = None # computed later
else:
# without precomputed search space: less data
self.pdag = pdag
self.is_pdag_completed = False # otherwise, it is minimal
self.num_edges = np.count_nonzero(np.logical_or(pdag, pdag.T)) / 2
def get_cpdag(self):
if hasattr(self, 'essential_mB'):
# precomputed case
d = self.example_mB.shape[0]
return np.logical_and(np.logical_and(self.skeleton, np.logical_not
(np.eye(d, dtype=bool))),
np.logical_not(self.essential_mB.T))
# non-precomputed case
if not self.is_pdag_completed:
self.pdag = complete_pdag(self.pdag)
self.is_pdag_completed = True
return self.pdag
def __str__(self):
if hasattr(self, 'essential_mB'):
d = self.example_mB.shape[0]
ret = ""
for v in range(d):
for w in range(d):
if v == w:
ret += " O "
elif not self.skeleton[v,w]:
ret += " . "
elif self.vstructs[v,w]:
ret += "<-- "
elif self.vstructs[w,v]:
ret += "--> "
elif self.essential_mB[v,w] and not self.essential_mB[w,v]:
ret += "{-- " # direction inferred indirectly
elif self.essential_mB[w,v] and not self.essential_mB[v,w]:
ret += "--} " # direction inferred indirectly
else:
ret += "--- "
if v < d - 1: # no \n at end, as a print will usually add it
ret += "\n"
#ret += str(self.children) + "\n"
return ret
else:
d = self.pdag.shape[0]
ret = ""
for v in range(d):
for w in range(d):
if v == w:
ret += " O "
elif not self.pdag[v,w] and not self.pdag[w,v]:
ret += " . "
elif self.pdag[v,w] and not self.pdag[w,v]:
ret += "<-- "
elif self.pdag[w,v] and not self.pdag[v,w]:
ret += "--> "
else:
ret += "--- "
if v < d - 1:
# no terminating \n, as a print will usually add it
ret += "\n"
#ret += str(self.children) + "\n"
if not self.is_pdag_completed:
ret += "(may not be complete)\n"
return ret
def saturated_class(d):
pdag = np.logical_not(np.eye(d, dtype=bool))
ret = EquivalenceClass(pdag)
ret.is_pdag_completed = True
return ret
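# Usage sketch (illustration only): the saturated class for d = 3 imposes no
# constraints, so its CPDAG is the complete undirected graph ("---" at every
# off-diagonal position).
def saturated_class_EXAMPLE():
    sat = saturated_class(3)
    print(sat)            # complete undirected graph on {a, b, c}
    print(sat.num_edges)  # 3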
class Constraint:
"""A conditional independence constraint"""
def __init__(self, v, w, S, imposed_by=None):
# NOTE: order of v and w may matter when using constraints to represent
# Delete operators (because only one of the two might be defined; if
# both are defined and valid, they represent the same operator).
self.v = v
self.w = w
self.S = S # mask (i.e. np.array of bools)
self.imposed_by = imposed_by
def has_same_adjacency(self, other):
if self.v == other.v and self.w == other.w:
return True
if self.v == other.w and self.w == other.v:
return True
return False
def is_same_constraint(self, other):
if not self.has_same_adjacency(other):
return False
return np.all(self.S == other.S)
def __str__(self):
ret = "{0} _||_ {1}".format(node_name(self.v), node_name(self.w))
if np.any(self.S):
ret += " |"
for i in range(self.S.size):
if self.S[i]:
ret += " {0}".format(node_name(i))
return ret
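# Usage sketch (illustration only): a Constraint stores the pair (v, w) plus a
# boolean mask S for the conditioning set. With d = 4 this prints
# "a _||_ d | b c".
def constraint_str_EXAMPLE():
    print(Constraint(0, 3, bin2mask(4, 0b0110)))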
def transitive_reflexive_closure(mB):
# (from aelsem)
d = mB.shape[0]
trans_refl_closure = np.logical_or(np.eye(d, dtype=bool), mB)
prev_num = np.count_nonzero(trans_refl_closure)
# O(log n) calls to numpy matrix multiplication probably faster than
# O(n^3) loop in Python
while True:
#print(trans_refl_closure)
trans_refl_closure = np.linalg.matrix_power(trans_refl_closure, 2)
num = np.count_nonzero(trans_refl_closure)
if num == prev_num:
break
prev_num = num
return trans_refl_closure
def has_cycles(mB):
# (from aelsem)
# test for cycles of length 2 or more
d = mB.shape[0]
trans_edges = np.logical_xor(transitive_reflexive_closure(mB),
np.eye(d, dtype=bool))
return np.any(np.logical_and(trans_edges, trans_edges.T))
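# Usage sketch (illustration only): has_cycles() flags directed cycles of
# length >= 2. As everywhere in this file, mB[v, w] == True means v <-- w.
def has_cycles_EXAMPLE():
    d = 3
    chain = np.zeros((d, d), dtype=bool)
    chain[1, 0] = chain[2, 1] = True     # a --> b --> c
    cycle = chain.copy()
    cycle[0, 2] = True                   # adding c --> a closes a cycle
    print(has_cycles(chain), has_cycles(cycle))   # False True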
def hash_model((skel, vstr)):
return skel.tostring() + vstr.tostring()
def skeleton_vstructs_normalizer(mB):
d = mB.shape[0]
skel = np.logical_or(np.eye(d, dtype=bool), np.logical_or(mB, mB.T))
# mark arrows v<--w for which also v<--x...w (and x != w: uses that skel
# is True on diagonal)
vstr = np.logical_and(mB, mB.dot(np.logical_not(skel)))
return (skel, vstr)
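# Usage sketch (illustration only): for the collider a --> b <-- c, the
# normalizer returns the skeleton (with a True diagonal) and marks both arrows
# into b as v-structure arrows, i.e. vstr[1, 0] and vstr[1, 2] are True.
def skeleton_vstructs_normalizer_EXAMPLE():
    mB = np.zeros((3, 3), dtype=bool)
    mB[1, 0] = mB[1, 2] = True           # b <-- a and b <-- c
    skel, vstr = skeleton_vstructs_normalizer(mB)
    print(skel.astype(int))
    print(vstr.astype(int))              # only row b (index 1) is nonzero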
def complete_pdag(pdag):
# TODO: other algorithms exist for this task that are in principle more
# efficient; see e.g. Chickering2002
d = pdag.shape[0]
# set diagonal to False (skeleton_vstructs_normalizer sets it to True
# on skel)
pdag = np.logical_and(pdag, np.logical_not(np.eye(d, dtype=bool)))
skel = np.logical_or(pdag, pdag.T)
mB = np.logical_and(pdag, np.logical_not(pdag.T))
mU = np.logical_and(pdag, pdag.T)
undir_count = np.sum(mU)
while undir_count > 0:
# Keep applying Meek's orientation rules until nothing changes
# Rule 1: v---w with v...x-->w becomes v<--w
mB_add = np.logical_and(mU, np.dot(np.logical_not(skel), mB.T))
np.logical_or(mB, mB_add, out=mB)
np.logical_and(mU, np.logical_not(mB_add), out=mU)
np.logical_and(mU, np.logical_not(mB_add).T, out=mU)
# Rule 2: v---w with v<--x<--w becomes v<--w
mB_add = np.logical_and(mU, np.dot(mB, mB))
np.logical_or(mB, mB_add, out=mB)
np.logical_and(mU, np.logical_not(mB_add), out=mU)
np.logical_and(mU, np.logical_not(mB_add).T, out=mU)
for x in range(d):
# Rule 3: v---w with v<--x---w and v<--y---w but x...y becomes v<--w
valid_y = np.logical_not(skel[x,:])
valid_y[x] = False
mB_add = np.logical_and(np.logical_and(mU, np.outer(mB[:,x],
mU[x,:])),
np.dot(np.logical_and(mB, valid_y),
mU))
np.logical_or(mB, mB_add, out=mB)
np.logical_and(mU, np.logical_not(mB_add), out=mU)
np.logical_and(mU, np.logical_not(mB_add).T, out=mU)
# Rule 4: v---w with v...x---w and v<--y*-*w and y<--x becomes v<--w
mB_add = np.logical_and(np.logical_and(mU, np.outer(np.logical_not(skel[:,x]), mU[x,:])),
np.dot(np.logical_and(mB, mB[:,x]),
mU))
np.logical_or(mB, mB_add, out=mB)
np.logical_and(mU, np.logical_not(mB_add), out=mU)
np.logical_and(mU, np.logical_not(mB_add).T, out=mU)
new_undir_count = np.sum(mU)
if new_undir_count == undir_count:
break
undir_count = new_undir_count
pdag = np.logical_or(mB, mU)
return pdag
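# Usage sketch (illustration only), showing Meek's rule 1: in the pattern
# a --> b --- c with a and c non-adjacent, the undirected edge must be
# oriented b --> c, since b <-- c would create a new v-structure at b.
def complete_pdag_rule1_EXAMPLE():
    pdag = np.zeros((3, 3), dtype=bool)
    pdag[1, 0] = True                    # b <-- a
    pdag[1, 2] = pdag[2, 1] = True       # b --- c
    cpdag = complete_pdag(pdag)
    print(cpdag.astype(int))             # only [1, 0] and [2, 1] remain True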
def complete_pdag_fixed_orientations(orig_cpdag, req_adjs):
# Using v-structures and Meek's orientation rules, find what orientations
# are fixed in a set of classes defined by an original cpdag with some
# edges not required (i.e. might be deleted).
# * Return value fixed_orient will be a subset of the cpdag's mB part.
# * A directed edge that is not required can also be marked as "fixed",
# which means that it has that orientation in all classes where the
# two nodes are adjacent.
d = orig_cpdag.shape[0]
orig_adj = np.logical_or(orig_cpdag, orig_cpdag.T)
orig_mU = np.logical_and(orig_cpdag, orig_cpdag.T)
sure_nonadj = np.logical_not(np.logical_or(orig_adj,
np.eye(d, dtype=bool)))
orig_mB = np.logical_and(orig_cpdag, np.logical_not(orig_cpdag.T))
req_mB = np.logical_and(np.logical_not(orig_cpdag.T), req_adjs)
# If an arrow v-->w participates in a guaranteed v-structure (x-->w
# required, v and x definitely not adjacent), then its orientation is fixed
# in all cpdags where v and w are adjacent. (If v-->w is required, then
# x-->w will be marked as fixed by the same rule; if it isn't, it's still
# correct to mark v-->w as fixed. This is the converse of Chickering2002's
# Lemma 28; see notes on printout for proof.)
fixed_orient = np.logical_and(orig_mB, #w<--v
(req_mB) # w<--x req
.dot(sure_nonadj)) # x...v sure
#orig_skel = np.logical_or(orig_cpdag, orig_pdag.T)
#orig_mB = np.logical_and(orig_cpdag, np.logical_not(orig_cpdag.T))
#mU = np.logical_and(pdag, pdag.T)
fixed_count = np.sum(fixed_orient)
while True:
# Keep applying Meek's orientation rules until nothing changes
# Rule 1: v<--w with v...x [sure] and x-->w [req,fixed] becomes fixed
fix_add = np.logical_and(orig_mB,
np.dot(sure_nonadj,
np.logical_and(fixed_orient.T,
req_adjs)))
np.logical_or(fixed_orient, fix_add, out=fixed_orient)
# Rule 2: v<--w with v<--x [req,fixed] and x<--w [fixed] becomes fixed
# (x<--w not required: if absent, v<--w is still fixed because it is
# in a v-structure)
fix_add = np.logical_and(orig_mB,
np.dot(np.logical_and(fixed_orient, req_adjs),
fixed_orient))
np.logical_or(fixed_orient, fix_add, out=fixed_orient)
for x in range(d):
# Rule 3: v<--w with v<--x [req,fixed], x---w and
# v<--y [req, fixed], y---w but x...y [sure] becomes fixed
valid_y = sure_nonadj[x,:]
fix_add = np.logical_and(np.logical_and(orig_mB,
np.outer(np.logical_and(fixed_orient[:,x], req_adjs[:,x]),
orig_mU[x,:])),
np.dot(np.logical_and(np.logical_and(fixed_orient, req_adjs), valid_y),
orig_mU))
np.logical_or(fixed_orient, fix_add, out=fixed_orient)
# Rule 4: v---w with v...x---w and v<--y*-*w and y<--x becomes v<--w
# TODO
#fix_add = np.logical_and(np.logical_and(mU, np.outer(np.logical_not(skel[:,x]), mU[x,:])),
# np.dot(np.logical_and(mB, mB[:,x]),
# mU))
#np.logical_or(fixed_orient, fix_add, out=fixed_orient)
new_fixed_count = np.sum(fixed_orient)
if new_fixed_count == fixed_count:
break
fixed_count = new_fixed_count
return fixed_orient
def complete_pdag_simpleTEST():
rule3 = np.ones((4,4), dtype=bool)
rule3[2,3] = rule3[3,2] = False
rule3[2,0] = rule3[3,0] = False
print("Rule 3:")
print(rule3)
print(complete_pdag(rule3))
rule4 = np.ones((4,4), dtype=bool)
rule4[2,3] = rule4[3,2] = False
rule4[1,2] = rule4[3,1] = False
print("Rule 4:")
print(rule4)
print(complete_pdag(rule4))
rule4[0,1] = False
print("Rule 4b:")
print(rule4)
print(complete_pdag(rule4))
rule4[0,1] = True
rule4[1,0] = False
print("Rule 4c:")
print(rule4)
print(complete_pdag(rule4))
def complete_pdag_bigTEST(eq_classes):
# Tested for d=4 and d=5: correct
print("Testing complete_pdag()")
d = eq_classes[0].skeleton.shape[0]
for eq_class in eq_classes:
#pattern = np.logical_or(eq_class.vstructs, eq_class.skeleton)
pattern = np.logical_and(eq_class.skeleton, np.logical_not(eq_class.vstructs.T))
cpdag = complete_pdag(pattern)
#cpdag_expected = np.logical_and(np.logical_or(eq_class.essential_mB,
# eq_class.skeleton),
# np.logical_not(np.eye(d, dtype=bool)))
cpdag_expected = np.logical_and(np.logical_and(np.logical_not(eq_class.essential_mB).T,
eq_class.skeleton),
np.logical_not(np.eye(d, dtype=bool)))
if np.any(cpdag != cpdag_expected):
print("Error")
print(eq_class)
print(cpdag)
print(cpdag_expected)
return
def generate_valid_delete_operators(cpdag):
# TODO: this can probably be made more efficient using the original class's
# list of delete operators
# TODO: only generate operators up to some conditioning set size
d = cpdag.shape[0]
ret = []
for v in range(d-1):
for w in range(v+1, d):
if cpdag[v,w] == False and cpdag[w,v] == False:
continue
# opt_vw = NA_YX; req_vw = Pa_Y \ X
if cpdag[w,v] == True:
# v --> w or v --- w
# can take x=v, y=w
opt_vw = np.logical_and(np.logical_or(cpdag[v,:],
cpdag[:,v]), # na *-* x
np.logical_and(cpdag[w,:],
cpdag[:,w])) # na --- y
req_vw = np.logical_and(cpdag[w,:], np.logical_not(cpdag[:,w]))
req_vw[v] = False
opt_vw_list = np.nonzero(opt_vw)[0]
n = len(opt_vw_list)
for opt_mask in range(1 << n):
S = req_vw.copy()
for i in range(n):
if (opt_mask & (1 << i)):
S[opt_vw_list[i]] = True
new_constraint = Constraint(v, w, S)
if is_delete_operator_valid(cpdag, new_constraint):
ret.append(new_constraint)
if cpdag[v,w] == True:
# w --> v or w --- v
# can take x=w, y=v
opt_wv = np.logical_and(np.logical_or(cpdag[w,:],
cpdag[:,w]), # na *-* x
np.logical_and(cpdag[v,:],
cpdag[:,v])) # na --- y
req_wv = np.logical_and(cpdag[v,:], np.logical_not(cpdag[:,v]))
req_wv[w] = False
opt_wv_list = np.nonzero(opt_wv)[0]
n = len(opt_wv_list)
for opt_mask in range(1 << n):
S = req_wv.copy()
for i in range(n):
if (opt_mask & (1 << i)):
S[opt_wv_list[i]] = True
if cpdag[w,v] == True:
# TODO: testing if this can be simplified:
if not (np.all(req_vw == req_wv)
and np.all(opt_vw == opt_wv)):
print("WARNING: Different set of existing Delete operators for {0}---{1} vs {1}---{0}"
.format(node_name(v), node_name(w)))
print("Required nodes:")
print(req_vw)
print(req_wv)
print("Optional nodes:")
print(opt_vw)
print(opt_wv)
# v --- w, so pay attention to avoid duplicate operators
duplicate = True
if np.any(np.logical_and(req_vw, np.logical_not(S))):
duplicate = False
if np.any(np.logical_and(S, np.logical_not(np.logical_or(opt_vw, req_vw)))):
duplicate = False
if duplicate:
continue
new_constraint = Constraint(w, v, S)
if is_delete_operator_valid(cpdag, new_constraint):
ret.append(new_constraint)
return ret
def is_delete_operator_valid(cpdag, constraint):
d = cpdag.shape[0]
x = constraint.v
y = constraint.w
# Chickering2002: valid iff NA_YX \ H is a clique
# my S = Chickering's NA(Y,X) \ H cup Pa_Y (where NA(Y,X) and Pa_Y disjoint)
# So: NA_YX \ H = S \ Pa_Y
clique = np.logical_and(constraint.S,
np.logical_not(np.logical_and
(cpdag[y,:],
np.logical_not(cpdag[:,y]))))
# add diagonal
cpdag_diag = np.logical_or(cpdag, np.eye(d, dtype=bool))
return np.all(np.logical_or(cpdag_diag, cpdag_diag.T)[clique,:][:,clique])
def apply_delete_operator(cpdag, constraint):
# Returns the *minimal* PDAG (i.e. a pattern) representing the equivalence
# class obtained by taking a CPDAG, and applying the Delete operator
# corresponding to the given constraint.
pdag = cpdag.copy()
# Assumes that constraint specifies a valid and defined Delete operator
# in the sense of Chickering2002, for x=v and y=w (i.e. order matters).
x = constraint.v
y = constraint.w
# my S = Chickering's NA(Y,X) \ H cup Pa_Y (where NA(Y,X) and Pa_Y disjoint)
# So: H = NA(Y,X) \ S
NA_YX = np.logical_and(np.logical_or(pdag[x,:], pdag[:,x]), # na *-* x
np.logical_and(pdag[y,:], pdag[:,y])) # na --- y
H = np.logical_and(NA_YX, np.logical_not(constraint.S))
# (if pdag has True on diagonal [shouldn't happen], then that won't hurt)
pdag[x,y] = pdag[y,x] = False
pdag[x,H] = pdag[y,H] = False
# This PDAG may have directed arrows that could be oriented the other way
# around in some equivalent DAG (a PDAG just needs to have the right
# skeleton and v-structures).
# Generalization of skeleton_vstructs_normalizer():
mB = np.logical_and(pdag, np.logical_not(pdag.T))
nonadj = np.logical_not(np.logical_or(np.logical_or(pdag, pdag.T),
np.eye(pdag.shape[0], dtype=bool)))
vstr = np.logical_and(mB, mB.dot(nonadj))
pdag = np.logical_and(np.logical_or(pdag, pdag.T),
np.logical_not(vstr.T))
return pdag
def generate_delete_operators_TEST():
# check output on cpdag that imposes only a _||_ b | c
d = 4
saturated_cpdag = saturated_class(d).get_cpdag()
constraint = Constraint(0, 1, bin2mask(d, 4))
pdag = apply_delete_operator(saturated_cpdag, constraint)
cpdag = complete_pdag(pdag)
res = generate_valid_delete_operators(cpdag)
for del_op in res:
print(del_op)
def partition_graphs(graphs, normalizer, verbose=False):
classes = {}
ret = []
for graph in graphs:
which_class = normalizer(graph)
#print(graph)
#print("normalizes to")
#print(which_class)
#print(".")
if hash_model(which_class) in classes:
classes[hash_model(which_class)].append(graph)
else:
classes[hash_model(which_class)] = [graph]
ret.append(which_class)
if verbose:
print("Identified", len(ret), "distinct classes using",
normalizer.__name__, file=sys.stderr)
for i, which_class in enumerate(ret):
ret[i] = classes[hash_model(which_class)]
return ret
def generate_all_DAGs(d):
graphs = []
n = d * (d-1) / 2
N = 3 ** n
print("Generating DAGs:", file=sys.stderr)
for mask in range(N):
if (mask & 255) == 0:
print("\r{0:0.2f}% - generated {1} DAGs"
.format(100.0 * mask / N, len(graphs)),
end='', file=sys.stderr)
mB = np.zeros((d,d), dtype=bool)
for i in range(1, d):
for j in range(0, i):
mask, curmask = divmod(mask, 3)
if curmask == 1:
mB[i,j] = True
elif curmask == 2:
mB[j,i] = True
if not has_cycles(mB):
graphs.append(mB)
print("\r100.00%; generated", len(graphs), "DAGs", file=sys.stderr)
return graphs
def compute_essential_mB(graphs):
ret = graphs[0]
for graph in graphs:
ret = np.logical_and(ret, graph)
return ret
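# Usage sketch (illustration only): the essential arrows of a class are those
# shared by all of its member DAGs. For the chain a - b - c (no v-structure)
# the three member DAGs disagree on every arrow, so the result is all-False.
def compute_essential_mB_EXAMPLE():
    d = 3
    dags = []
    for edges in [[(1, 0), (2, 1)],      # a --> b --> c
                  [(0, 1), (1, 2)],      # a <-- b <-- c
                  [(0, 1), (2, 1)]]:     # a <-- b --> c
        mB = np.zeros((d, d), dtype=bool)
        for v, w in edges:
            mB[v, w] = True
        dags.append(mB)
    print(np.any(compute_essential_mB(dags)))   # False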
def is_d_connected_dfs(mB, pos, w, S, vis):
# Made modifications to deal with CPDAGs as input
#print("At ", pos)
(v, dir) = pos
if v == w:
return True
if (dir == 0 and not S[v]) or (dir == 1 and S[v]):
# traverse backward (dir=0) along an arrow
#print(mB)
#print(v)
#print(mB[v,:])
#print(vis[:,0])
next_vs_mask = np.logical_and(mB[v,:], np.logical_not(vis[:,0]))
if dir == 1:
            # we can't continue along an undirected edge when dir == 1 (i.e. S[v] holds)
next_vs_mask = np.logical_and(next_vs_mask, np.logical_not(mB[:,v]))
#print(next_vs_mask.shape)
tmp = np.logical_or(vis[:,0], next_vs_mask)
#print(tmp.shape)
vis[:,0] = tmp
#print(np.nonzero(next_vs_mask))
#print(np.nonzero(next_vs_mask)[0])
for next_v in np.nonzero(next_vs_mask)[0]:
#print("visiting ", next_v)
if is_d_connected_dfs(mB, (next_v, 0), w, S, vis):
return True
if not S[v]:
# traverse forward (dir=1) along an arrow
next_vs_mask = np.logical_and(mB[:,v], np.logical_not(vis[:,1]))
next_vs_mask = np.logical_and(next_vs_mask, np.logical_not(mB[v,:]))
vis[:,1] = np.logical_or(vis[:,1], next_vs_mask)
for next_v in np.nonzero(next_vs_mask)[0]:
if is_d_connected_dfs(mB, (next_v, 1), w, S, vis):
return True
return False
def is_d_separated(mB, v, w, S):
# mB can also be a cpdag
if S[v] or S[w]:
return True
d = mB.shape[0]
# vis[v,0]: reachable by path ending in tail
# vis[v,1]: reachable by path ending in head
vis = np.zeros((d,2), dtype=bool)
pos = (v,0)
vis[pos] = True
is_d_connected_dfs(mB, pos, w, S, vis)
if vis[w,0] or vis[w,1]:
return False
return True
def is_d_separated_TEST():
returned = []
expected = []
d = 4
S = np.zeros(d, dtype=bool)
S1 = np.copy(S)
S1[1] = True
S2 = np.copy(S)
S2[2] = True
S12 = np.copy(S)
S12[1] = S12[2] = True
# 0 <-- 1 --> 2 --> 3
mB = np.zeros((d,d), dtype=bool)
mB[0,1] = mB[2,1] = mB[3,2] = True
returned.append(is_d_separated(mB, 0, 3, S))
expected.append(False)
returned.append(is_d_separated(mB, 0, 3, S1))
expected.append(True)
returned.append(is_d_separated(mB, 0, 3, S2))
expected.append(True)
returned.append(is_d_separated(mB, 0, 3, S12))
expected.append(True)
# 0 --> 1 <-- 3 with 1 --> 2
mB = np.zeros((d,d), dtype=bool)
mB[1,0] = mB[2,1] = mB[1,3] = True
returned.append(is_d_separated(mB, 0, 3, S))
expected.append(True)
returned.append(is_d_separated(mB, 0, 3, S1))
expected.append(False)
returned.append(is_d_separated(mB, 0, 3, S2))
expected.append(False)
returned.append(is_d_separated(mB, 0, 3, S12))
expected.append(False)
if returned != expected:
print("There were incorrect results in is_d_separated:")
print("Return values:")
print(returned)
print("Expected:")
print(expected)
else:
print("All tests passed for is_d_separated")
def lemma_34_TEST(d):
# Test Lemma 34 from Chickering: if two edges in a clique of size three
# in a CPDAG are undirected, then so is the third.
solver = BBES(d, 0)
equivalence_classes = solver.search_space['equivalence_classes']
for eq_class in equivalence_classes:
cpdag = eq_class.get_cpdag()
for v in range(d-1):
for w in range(v+1, d):
if cpdag[v,w] and cpdag[w,v]:
for x in range(d):
num = 1
if cpdag[v,x] and cpdag[x,v]:
num += 1
if cpdag[w,x] and cpdag[x,w]:
num += 1
if ((cpdag[v,x] or cpdag[x,v])
and (cpdag[w,x] or cpdag[x,w])
and num == 2):
print("ERROR: clique {0}-{1}-{2} violates Chickering's Lemma 34 in the following CPDAG:"
.format(v, w, x))
print(eq_class)
print("Test of Lemma 34 complete")
# Test completed with no errors on d=4,5,6
def canonical_constraint(mB_parent_class, mB_child_class):
# Find the constraint that a child class introduces compared to a parent
# class, such that the difference in likelihood of those classes equals the
# difference between a class imposing only this constraint and the saturated
# class.
# This is not necessarily the "simplest" constraint (but it doesn't need
# to be; it mostly needs to be easy to compute). E.g. to separate
# parent a --> b <-- c from child a --- b c, the constraint is
# b _||_ c | a.
d = mB_parent_class.shape[0]
error_msg = None
coefficients = collections.Counter()
for v in range(d):
parents = mB_parent_class[v,:]
bin_parents = mask2bin(parents)
bin_parents_and_v = bin_parents + (1 << v)
coefficients[bin_parents_and_v] += 1
coefficients[bin_parents] -= 1
for v in range(d):
parents = mB_child_class[v,:]
bin_parents = mask2bin(parents)
bin_parents_and_v = bin_parents + (1 << v)
# as above, with + and - reversed
coefficients[bin_parents_and_v] -= 1
coefficients[bin_parents] += 1
positive_bins = []
for bin, coefficient in coefficients.items():
if coefficient <= 0:
continue
# there should be two entries with positive coefficient (namely 1):
# S and S+{v}+{w}
if coefficient > 1:
error_msg = "canonical_constraint encountered cluster with coefficient 2"
break
positive_bins.append(bin)
if len(positive_bins) != 2:
error_msg = "canonical_constraint encountered {0} clusters with coefficient 1 (expected 2)".format(len(positive_bins))
else:
positive_bins.sort()
S = bin2mask(d, positive_bins[0])
Svw = bin2mask(d, positive_bins[1])
if np.any(np.logical_and(S, np.logical_not(Svw))):
error_msg = "canonical_constraint encountered pair of coef-1 clusters that are not subset-related"
elif np.sum(np.logical_and(Svw, np.logical_not(S))) != 2:
error_msg = "canonical_constraint encountered pair of coef-1 clusters that differ by other than 2 elements"
else:
vw_list = np.nonzero(np.logical_and(Svw, np.logical_not(S)))[0]
v = vw_list[0]
w = vw_list[1]
return Constraint(v, w, S)
    # Something unexpected occurred:
print("canonical_constraint - mB_parent_class:", file=sys.stderr)
print(mB_parent_class, file=sys.stderr)
print("canonical_constraint - mB_child_class:", file=sys.stderr)
print(mB_child_class, file=sys.stderr)
raise ValueError(error_msg)
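# Usage sketch (illustration only), reproducing the example in the comment
# above: parent DAG a --> b <-- c versus child DAG a --> b (c isolated) gives
# the canonical constraint b _||_ c | a.
def canonical_constraint_EXAMPLE():
    d = 3
    parent = np.zeros((d, d), dtype=bool)
    parent[1, 0] = parent[1, 2] = True   # b <-- a, b <-- c
    child = np.zeros((d, d), dtype=bool)
    child[1, 0] = True                   # b <-- a
    print(canonical_constraint(parent, child))   # b _||_ c | a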
def find_constraint_in_list(constraints, constraint):
# Binary search using that constraints in list are sorted on (v, w, S_bin)
v = constraint.v
w = constraint.w
S_bin = mask2bin(constraint.S)
lo = 0
hi = len(constraints)
while lo < hi:
mid = lo + (hi - lo) / 2
if (constraints[mid].v > v
or (constraints[mid].v == v and constraints[mid].w > w)
or (constraints[mid].v == v and constraints[mid].w == w
and mask2bin(constraints[mid].S) >= S_bin)):
# constraints[mid] >= constraint
hi = mid
else:
lo = mid + 1
if lo >= len(constraints):
raise ValueError("constraint not found")
return lo
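# Usage sketch (illustration only): the binary search assumes the list is
# sorted on (v, w, mask2bin(S)), which is the order in which
# search_space_precompute() generates the constraints. A tiny hand-built list
# for d = 3:
def find_constraint_in_list_EXAMPLE():
    d = 3
    lst = [Constraint(0, 1, bin2mask(d, 0)), Constraint(0, 1, bin2mask(d, 4)),
           Constraint(0, 2, bin2mask(d, 0)), Constraint(0, 2, bin2mask(d, 2)),
           Constraint(1, 2, bin2mask(d, 0)), Constraint(1, 2, bin2mask(d, 1))]
    print(find_constraint_in_list(lst, Constraint(0, 2, bin2mask(d, 2))))   # 3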
def search_space_precompute(filename, d):
all_graphs = generate_all_DAGs(d)
graphs_by_class = partition_graphs(all_graphs, skeleton_vstructs_normalizer,
verbose=True)
# sort by number of edges
print("Sorting equivalence classes by number of edges", file=sys.stderr)
graphs_by_class = sorted(graphs_by_class,
key=lambda graphs: np.count_nonzero(graphs[0]))
print("Computing EquivalenceClass list", file=sys.stderr)
equivalence_classes = [EquivalenceClass(graphs[0],
compute_essential_mB(graphs),
[])
for graphs in graphs_by_class]
first_class_with_num_params = np.zeros(equivalence_classes[-1].num_edges + 2, dtype=int)
cur_num_params = -1
for i, eq_class in enumerate(equivalence_classes):
num_params_here = eq_class.num_edges
if num_params_here > cur_num_params:
cur_num_params = num_params_here
first_class_with_num_params[cur_num_params] = i
first_class_with_num_params[-1] = len(equivalence_classes)
#print(first_class_with_num_params)
# compute children of each equivalence class
print("Computing child classes of each equivalence class:", file=sys.stderr)
index_of_skel_vstr = {hash_model(skeleton_vstructs_normalizer(eq_class.example_mB)):index for index, eq_class in enumerate(equivalence_classes)}
for parent_index, parent_graphs in enumerate(graphs_by_class):
children = set()
for parent_graph in parent_graphs:
for v in range(d):
for w in range(d):
if parent_graph[v,w]:
child_graph = parent_graph.copy()
child_graph[v,w] = False
child_index = index_of_skel_vstr[hash_model(skeleton_vstructs_normalizer(child_graph))]
children.add(child_index)
equivalence_classes[parent_index].children = sorted(children)
print("\r{0}: {1:0.2f}%"
.format(parent_index + 1,
100.0 * (parent_index + 1) / len(equivalence_classes)),
end='', file=sys.stderr)
print(file=sys.stderr)
num_constraints = d * (d-1) / 2 * (1 << (d-2))
num_computed = 0
constraints = []
print("Computing constraints:", file=sys.stderr)
# For n=6, this takes about 5 hours
# file sizes (with example_mB not set to bool in first version):
# n=3: 5 KB -> 4 KB
# n=4: 91 KB -> 70 KB
# n=5: 5.6 MB -> 4.1 MB
# n=6: 1.01 GB -> .77 GB
for v in range(d-1):
for w in range(v+1, d):
for S_bin in range(1 << d):
if (S_bin & (1 << v)) or (S_bin & (1 << w)):
continue
num_computed += 1
S = np.zeros(d, dtype=bool)
for i in range(d):
if S_bin & (1 << i):
S[i] = True
imposed_by = np.zeros(len(equivalence_classes), dtype=bool)
for i, eq_class in enumerate(equivalence_classes):
if is_d_separated(eq_class.example_mB, v, w, S):
imposed_by[i] = True
constraint = Constraint(v, w, S, imposed_by)
constraints.append(constraint)
#print("Constraint", constraint, "is imposed by:")
#print(imposed_by)
print("\r{0}: {1:0.2f}%"
.format(num_computed,
100.0 * num_computed / num_constraints),
sep='', end='', file=sys.stderr)
print(file=sys.stderr)
print("Computing identifying constraint for each parent-child pair:",
file=sys.stderr)
#unique_constraints_TEST(equivalence_classes, constraints,
# verbose=False, write=True)
for parent_i, parent_class in enumerate(equivalence_classes):
children_identifying_constraint = []
for child_list_i, child_i in enumerate(parent_class.children):
child_class = equivalence_classes[child_i]
likelihood_ratio_constraint = canonical_constraint(parent_class.example_mB, child_class.example_mB)
constraint_i = find_constraint_in_list(constraints, likelihood_ratio_constraint)
children_identifying_constraint.append(constraint_i)
parent_class.children_identifying_constraint = children_identifying_constraint
print("\r{0}: {1:0.2f}%"
.format(parent_i + 1,
100.0 * (parent_i + 1) / len(equivalence_classes)),
end='', file=sys.stderr)
print(file=sys.stderr)
if False:
# Print a list of all classes, with their children + id-constraints
for i, eq_class in enumerate(equivalence_classes):
print(i)
print(eq_class)
for j, child_i in enumerate(eq_class.children):
print("Child {0} with identifying constraint {1}"
.format(child_i,
constraints[eq_class.children_identifying_constraint[j]]))
search_space = dict(equivalence_classes=equivalence_classes,
constraints=constraints,
first_class_with_num_params=first_class_with_num_params)
print("Saving results to file {0}".format(filename),
file=sys.stderr)
with open(filename, "wb") as pickle_file:
pickle.dump(search_space, pickle_file, pickle.HIGHEST_PROTOCOL)
return search_space
def search_space_prepare(d):
filename = ("DAG_classes_{0}.pickle".format(d))
try:
pickle_file = open(filename, "rb")
search_space = pickle.load(pickle_file)
except IOError:
print("Precomputing search space data; will store it in {0}"
.format(filename), file=sys.stderr)
search_space = search_space_precompute(filename, d)
else:
pickle_file.close()
return search_space
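# Usage sketch (illustration only): typical entry point for the precomputed
# search space. The first call for a given d computes everything and pickles
# it to DAG_classes_<d>.pickle (slow for d >= 6); later calls just reload it.
def search_space_prepare_EXAMPLE():
    space = search_space_prepare(4)
    print(len(space['equivalence_classes']))   # 185 classes for d = 4
    print(len(space['constraints']))           # 6 pairs * 4 subsets = 24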
def get_descendant_classes(eq_classes, ancestor_i):
# Boolean vector indicating the descendants of an equivalence class,
# including that class itself
desc = np.zeros(len(eq_classes), dtype=bool)
desc[ancestor_i] = True
for i in range(ancestor_i, -1, -1):
if not desc[i]:
continue
for child_i in eq_classes[i].children:
desc[child_i] = True
return desc
def unique_constraints_TEST(eq_classes, constraints, verbose=False, write=False):
# For a given parent and child class, a "unique constraint" is one that is
# imposed by the child class, but not by the parent or any of the parent's
# other children. We call such a constraint "bad" if it is imposed by
# some descendant of the parent class other than the child class or one of
# its descendants. This function tests whether the precomputed "identifying
# constraint" is unique and not bad; if no id-constraint was precomputed
# yet, just check if a non-bad unique constraint exists (and if write=True,
# also write such constraints to eq_classes).
    # Fully tested for d=4,5 (success); for d=6, only tested up to 33.44%
print("unique_constraints_TEST:", file=sys.stderr)
for parent_i, parent_class in enumerate(eq_classes):
children_identifying_constraint = []
constraints_parent = np.zeros(len(constraints), dtype=bool)
for j in range(len(constraints)):
constraints_parent[j] = constraints[j].imposed_by[parent_i]
descendants_of_parent = get_descendant_classes(eq_classes, parent_i)
constraints_of_children = []
#constraints_union = np.zeros(len(constraints), dtype=bool)
#constraints_multiple = np.zeros(len(constraints), dtype=bool)
constraints_union = constraints_parent
constraints_multiple = constraints_parent
for child_i in parent_class.children:
#print(parent_i, child_i)
constraints_child = np.zeros(len(constraints), dtype=bool)
for j in range(len(constraints)):
constraints_child[j] = constraints[j].imposed_by[child_i]
constraints_of_children.append(constraints_child)
if np.any(np.logical_and(constraints_parent,
np.logical_not(constraints_child))):
print("Error: parent {0} imposes a constraint not imposed by child {1}".format(parent_i, child_i))
constraints_multiple = (np.logical_or
(constraints_multiple,
np.logical_and(constraints_union,
constraints_child)))
constraints_union = np.logical_or(constraints_union,
constraints_child)
for child_list_i, child_i in enumerate(parent_class.children):
unique_constraints_mask = np.logical_and(constraints_of_children[child_list_i], np.logical_not(constraints_multiple))
unique_constraints_list = np.nonzero(unique_constraints_mask)[0]
num_unique_constraints = len(unique_constraints_list)
if num_unique_constraints == 0:
print("Warning: child {0} does not impose any unique constraints among the children of {1}".format(child_i, parent_i))
continue
bad_constraints = np.zeros(len(constraints), dtype=bool)
descendants_of_child = get_descendant_classes(eq_classes, child_i)
for desc_i in np.nonzero(np.logical_and(descendants_of_parent, np.logical_not(descendants_of_child)))[0]:
if desc_i == parent_i:
continue
num_bad_constraints = 0
for constraint_i in unique_constraints_list:
if constraints[constraint_i].imposed_by[desc_i]:
num_bad_constraints += 1
bad_constraints[constraint_i] = True
if num_bad_constraints:
# This actually occurs (in the example I saw, the constraint
# was between a different pair of nodes than the removed
# edge).
if verbose:
print("Warning: {0} of the {1} unique constraints imposed by child {2} of parent {3} are also imposed by class {4}, which descends from {3} but not from {2}".format(num_bad_constraints, num_unique_constraints, child_i, parent_i, desc_i))
print("Parent:")
print(eq_classes[parent_i])
print("Child:")
print(eq_classes[child_i])
print("Descendant:")
print(eq_classes[desc_i])
identifying_constraint = None
if sum(bad_constraints) == num_unique_constraints:
print("Warning: *all* unique constraints imposed by child {0} of parent {1} are invalidated by some descendant of {1}".format(child_i, parent_i))
else:
identifying_constraint = np.nonzero(np.logical_and(unique_constraints_mask, np.logical_not(bad_constraints)))[0][0]
children_identifying_constraint.append(identifying_constraint)
if parent_class.children_identifying_constraint is not None:
precomputed_constraint = parent_class.children_identifying_constraint[child_list_i]
if not unique_constraints_mask[precomputed_constraint]:
print("Error: precomputed constraint is not unique")
elif bad_constraints[precomputed_constraint]:
print("Error: precomputed constraint is bad")
else:
print("Error: no precomputed constraint")
if write:
parent_class.children_identifying_constraint = children_identifying_constraint
print("\r{0}: {1:0.2f}%"
.format(parent_i + 1,
100.0 * (parent_i + 1) / len(eq_classes)),
end='', file=sys.stderr)
print(file=sys.stderr)
def delete_operator_commutativity_TEST(d):
# Test succeeded for d=4,5,6
# Proof: Follows from ChickeringMeek2015 (SGES paper),
# Lemma 2 ('The Deletion Lemma')
solver = BBES(d, 0)
equivalence_classes = solver.search_space['equivalence_classes']
constraints = solver.search_space['constraints']
for parent_i, parent_class in enumerate(equivalence_classes):
print(parent_i, "/", len(equivalence_classes), end='\r')
is_desc = get_descendant_classes(equivalence_classes, parent_i)
for child_sub_i, child_i in enumerate(parent_class.children):
child_constraint_i = parent_class.children_identifying_constraint[child_sub_i]
child_constraint = constraints[child_constraint_i]
for desc_i, desc_class in enumerate(equivalence_classes):
if not is_desc[desc_i]:
continue
if child_constraint.imposed_by[desc_i]:
# descendant of parent, and imposing id_constraint of child
for constraint_i, constraint in enumerate(constraints):
if constraint.imposed_by[child_i] and not constraint.imposed_by[desc_i]:
print("ERROR: for parent class")
print(parent_class)
print(", the constraint", constraint, "is imposed by child class")
                            print(equivalence_classes[child_i])
print("but not by descendant class")
print(desc_class)
print("delete_operator_commutativity_TEST DONE")
class BBES_state:
"""State in the BBES algorithm: set of equivalence classes"""
def __init__(self, solver, superclass, classes_included=None,
required_connections=None):
self.solver = solver # BBES class object using this state. TODO could be class variable if Solver is a singleton.
self.superclass = superclass
self.max_params = superclass.num_edges
self.classes_included = classes_included
if classes_included is not None:
self.max_loglik = solver.compute_loglik(superclass.example_mB)
self.min_params = self.compute_min_params_bruteforce() # requires max_params
else:
self.required_connections = required_connections
self.max_loglik = None
self.min_params = None
self.key = None
self.visited = False # is set to True when its children are compared
#self.state_TEST()
def is_singleton(self):
return self.min_params == self.max_params
def do_branch(self, child_id_constraint, child_superclass=None):
# return a new state that imposes a constraint, and modify self to
# not impose that constraint
#child_superclass_i = self.superclass.children[child_i]
#child_id_constraint = self.solver.search_space['constraints'][self.superclass.children_identifying_constraint[child_i]]
if self.solver.without_precomp == 1:
# without precomputed search space: