From f60809087dfb83ed513cfbcdd177f48c4d61d6dd Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Thu, 10 Jul 2025 10:29:57 +0200
Subject: [PATCH 01/32] Adding tetrad and junit 5 dependencies
---
pom.xml | 34 +++++++++++++++++++++++++++-------
1 file changed, 27 insertions(+), 7 deletions(-)
diff --git a/pom.xml b/pom.xml
index 3e3466f..8f30b34 100644
--- a/pom.xml
+++ b/pom.xml
@@ -55,12 +55,12 @@
-
- io.github.cmu-phil
- tetrad-lib
-
- 7.6.4
-
+
+ io.github.cmu-phil
+ tetrad-lib
+
+ 7.6.4
+
+
+
+ org.junit.jupiter
+ junit-jupiter
+ 5.10.0
+ test
+
@@ -123,10 +130,23 @@
+
org.apache.maven.plugins
maven-surefire-plugin
- 2.22.2
+ 3.1.2
+
+
+ org.junit.platform
+ junit-platform-engine
+ 1.10.0
+
+
+ org.junit.jupiter
+ junit-jupiter-engine
+ 5.10.0
+
+
From 17d57dd9ad2b8b45c269a2f4f61fe0fcc6aca349 Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Thu, 10 Jul 2025 10:30:24 +0200
Subject: [PATCH 02/32] Cleaning and testing AlphaOrder
---
.../uclm/i3a/simd/consensusBN/AlphaOrder.java | 355 +++++++++---------
.../i3a/simd/consensusBN/AlphaOrderTest.java | 119 ++++++
2 files changed, 296 insertions(+), 178 deletions(-)
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java b/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java
index 55e7e2c..7865c15 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java
@@ -3,152 +3,157 @@
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
+
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
-import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Node;
+/**
+ * This class implements a heuristic to compute an ancestral order of nodes for a set of DAGs.
+ * The heuristic is based on finding the best sink node in each iteration for the set of DAGs,
+ * removing it from the DAGs, and repeating the process until all nodes are ordered.
+ */
public class AlphaOrder {
- ArrayList setOfDags = null;
- ArrayList alpha = null;
- ArrayList setOfauxG = null;
-// ArrayList dpaths = null;
+ /**
+ * The set of DAGs to compute the ancestral order from.
+ */
+ private final ArrayList setOfDags;
+ /**
+ * The computed ancestral order of nodes.
+ */
+ private ArrayList alpha;
+ /**
+ * A set of auxiliary DAGs used during the computation.
+ */
+ private final ArrayList setOfauxG;
+ /**
+ * Constructor for the AlphaOrder class.
+ * Initializes the set of DAGs and creates a copy of each DAG to work with.
+ * @param dags the list of DAGs from which to compute the ancestral order.
+ * This constructor creates a deep copy of each DAG to avoid modifying the original DAGs during
+ * the computation of the ancestral order.
+ */
public AlphaOrder(ArrayList dags){
-
+ // Check if the dags are valid
+ checkExceptions(dags);
+
+ // Initialize the class variables
this.setOfDags = dags;
- this.alpha = new ArrayList();
- this.setOfauxG = new ArrayList();
-// this.dpaths = new ArrayList();
-
+ this.alpha = new ArrayList<>();
+ this.setOfauxG = new ArrayList<>();
for (Dag i : setOfDags) {
Dag aux_G = new Dag(i);
setOfauxG.add(aux_G);
-// dpaths.add(computeDirectedPathFromTo(aux_G));
}
-
}
-
- public int[][] computeDirectedPathFromTo(Dag graph) {
- LinkedList dpathNewEdges = new LinkedList();
- dpathNewEdges.clear();
- dpathNewEdges.addAll(graph.getEdges());
- List dpathNodes = null;
- dpathNodes = graph.getNodes();
-
- int numNodes = dpathNodes.size();
- int [][] dpath = new int[numNodes][numNodes];
+
+ /**
+ * Checks for exceptions in the input set of DAGs.
+ * Throws an IllegalArgumentException if the set is null, empty, or contains DAGs with different nodes.
+ * Also checks that the size of the set is greater than 1.
+ * @param setOfDags the set of DAGs to check for exceptions.
+ */
+ private void checkExceptions(ArrayList setOfDags) {
+ // Check if setOfDags is null
+ if(setOfDags == null) {
+ throw new IllegalArgumentException("The set of DAGs is null.");
+ }
+
+ // Check if all DAGs have the same nodes
+ if (setOfDags.isEmpty()) {
+ throw new IllegalArgumentException("The set of DAGs is empty.");
+ }
+ // Check that the size is greater than 1
+ if(setOfDags.size() <= 1) {
+ throw new IllegalArgumentException("The set of DAGs has only one DAG.");
+ }
- while (!dpathNewEdges.isEmpty()) {
- Edge edge = dpathNewEdges.removeFirst();
- Node _nodeT = Edges.getDirectedEdgeTail(edge);
- Node _nodeH = Edges.getDirectedEdgeHead(edge);
- int _indexT = dpathNodes.indexOf(_nodeT);
- int _indexH = dpathNodes.indexOf(_nodeH);
- dpath[_indexT][_indexH] = 1;
- int dPathT = 0;
- int dPathH = 0;
- int mindPath = 0;
- for (int i = 0; i < dpathNodes.size(); i++) {
- dPathT = dpath[i][_indexT];
- if (dpath[i][_indexT] >= 1) {
- dPathH = dpath[i][_indexH];
- if(dPathH == 0) dpath[i][_indexH] = dPathT+1;
- else{
- mindPath = Math.min(dPathH, dPathT+1);
- dpath[i][_indexH]=mindPath;
- }
- }
- dPathH = dpath[_indexH][i];
- if(dpath[_indexH][i] >= 1){
- dPathT = dpath[_indexT][i];
- if(dPathT ==0) dpath[_indexT][i] = dPathH+1;
- else{
- mindPath = Math.min(dPathT, dPathH+1);
- dpath[_indexT][i] = mindPath;
- }
-
- }
+ // Check that all DAGs have the same nodes
+ List firstDagNodes = setOfDags.get(0).getNodes();
+ for (Dag dag : setOfDags) {
+ if (!dag.getNodes().equals(firstDagNodes)) {
+ throw new IllegalArgumentException("All DAGs must have the same nodes. Dag " + dag + " has different nodes than the rest of DAGs.");
}
}
- return dpath;
- }
+ }
+
+ /**
+ * Returns the nodes of the first DAG in the set, since all DAGs are assumed to have the same nodes.
+ * @return
+ */
public List getNodes(){
return(setOfDags.get(0).getNodes());
}
- // heursitica para orden de conceso basada en el numero de caminos dirigidos. (Es muy mala no se utiliza)
-
- public void computeAlphaH1(){
-
- List nodes = setOfDags.get(0).getNodes();
- LinkedList alpha = new LinkedList();
-
- while(nodes.size()>0){
- int index_alpha = computeNextH1(nodes);
- Node node_alpha = nodes.get(index_alpha);
- alpha.addFirst(node_alpha);
- for(Dag g: this.setOfauxG){
- removeNode(g,node_alpha);
- //int[][] newDpaths = computeDirectedPathFromTo(g);
-// this.dpaths.set(this.setOfauxG.indexOf(g), newDpaths);
- }
- nodes.remove(node_alpha);
- }
- this.alpha = new ArrayList(alpha);
- }
-
- // heuistica para encontrar un orden de conceso. Se basa en los enlaces que generaria seguir una secuencia creada desde los nodos sumideros hacia arriba.
-
-public void computeAlphaH2(){
+ /**
+ * This method computes the heuristic to find an ancestral order of nodes of consensus. It is based on the number of edges that would be added on a sequence created from the sink nodes upwards.
+ * It iteratively finds the node with the minimum number of changes (inversions and additions of edges) and adds it to the beginning of the order.
+ * */
+ public void computeAlpha(){
+ // Get nodes and initialize the alpha list
List nodes = setOfDags.get(0).getNodes();
- LinkedList alpha = new LinkedList();
+ LinkedList alpha_aux = new LinkedList<>();
- while(nodes.size()>0){
- int index_alpha = computeNextH2(nodes);
- Node node_alpha = nodes.get(index_alpha);
- alpha.addFirst(node_alpha);
+ while(!nodes.isEmpty()){
+ int index_alpha = computeNextSink(nodes);
+ Node nodeAlpha = nodes.get(index_alpha);
+ alpha_aux.addFirst(nodeAlpha);
for(Dag g: this.setOfauxG){
- removeNode(g,node_alpha);
+ removeNode(g,nodeAlpha);
}
- nodes.remove(node_alpha);
+ nodes.remove(nodeAlpha);
}
- this.alpha = new ArrayList(alpha);
+ this.alpha = new ArrayList<>(alpha_aux);
}
-
- int computeNextH2(List nodes){
+ /**
+ * Gets the following node in the order based on the minimum number of changes (inversions and additions of edges) that would be required to create a sequence from the sink nodes upwards.
+ * @param nodes Remaining nodes to be ordered.
+ * @return index of the node that should be added next to the order.
+ */
+ private int computeNextSink(List nodes){
- int changes = 0;
+ // Setting up variables to count changes
+ int changes;
int inversion = 0;
int addition = 0;
int indexNode = 0;
int min = Integer.MAX_VALUE;
-
+
+ // Iterate through each node to find the one with the minimum changes for the list of DAGs.
for(int i=0; i inserted = new ArrayList();
+ // Checking total amount of inversions. We add -1 to give relevance to nodes that are already sinks.
List children = g.getChildren(nodei);
inversion += (children.size()-1);
+
+ // Checking edge additions from parents of each child to nodei and from parents of nodei to children.
+ ArrayList inserted = new ArrayList<>();
List paX = g.getParents(nodei);
for(Node child: children){
List paY = g.getParents(child);
+ // For each parent of nodei, check if it has an edge to the child
for(Node nodep: paX){
- if(g.getEdge(nodep, child)==null){
- addition++;
- }
+ if(g.getEdge(nodep, child)==null){
+ addition++;
+ }
}
+ // For each parent of the child, check if it has an edge to nodei
for(Node nodec: paY){
if(!nodec.equals(nodei)){
+ // If there is no edge between nodec and nodei, we consider adding it
if((g.getEdge(nodec,nodei)==null) && (g.getEdge(nodei,nodec)==null)){
Edge toBeInserted = new Edge(nodec,nodei,Endpoint.CIRCLE,Endpoint.CIRCLE);
boolean contains = false;
+ // Checking if we have already added this edge to the list of inserted edges
+ // to avoid counting it multiple times.
for(Edge e: inserted){
if((e.getNode1().equals(nodec) && (e.getNode2().equals(nodei))) ||
((e.getNode1().equals(nodei) && (e.getNode2().equals(nodec))))){
@@ -156,6 +161,7 @@ int computeNextH2(List nodes){
break;
}
}
+ // Checkin if there is a new edge addition, we update the counter and the list of inserted edges if so.
if(!contains){
addition++;
inserted.add(toBeInserted);
@@ -165,117 +171,110 @@ int computeNextH2(List nodes){
}
}
}
+ // Calculate total changes for the current node
changes = inversion + addition;
+ // If the current node has less changes than the minimum found so far, we update the minimum and the index of the node
+ // to be added to the order.
if(changes < min){
min = changes;
indexNode = i;
}
- changes = 0;
+ // Resetting changes for the next iteration
inversion = 0;
addition = 0;
}
return indexNode;
}
- void removeNode(Dag g, Node node_alpha){
+ /**
+ * Removes a node from the DAG and updates the edges according to a new node added to the alpha order.
+ * It removes a sink node and updates the edges to maintain the directed paths in the DAG.
+ * This is done each iteration of the heuristic to compute the alpha order.
+ * @param g the DAG from which the node is to be removed.
+ * @param nodeAlpha the node to be removed from the DAG.
+ */
+ private void removeNode(Dag g, Node nodeAlpha){
- List children = g.getChildren(node_alpha);
+ List children = g.getChildren(nodeAlpha);
while(!children.isEmpty()){
- int i=0;
- Node child;
- boolean seguir = false;
- do{
- child = children.get(i++);
- g.removeEdge(node_alpha, child);
- seguir=false;
- if(g.paths().existsDirectedPath(node_alpha,child)){
- seguir=true;
- g.addEdge(new Edge(node_alpha,child,Endpoint.TAIL, Endpoint.ARROW));
- }
- }while(seguir);
-
- List paX = g.getParents(node_alpha);
- List paY = g.getParents(child);
- paY.remove(node_alpha);
- g.addEdge(new Edge(child,node_alpha,Endpoint.TAIL, Endpoint.ARROW));
- for(Node nodep: paX){
- Edge pay = g.getEdge(nodep, child);
- if(pay == null)
- g.addEdge(new Edge(nodep,child,Endpoint.TAIL,Endpoint.ARROW));
+ // 1. Select a child that prevents a cycle when nodeAlpha <- child is added.
+ Node child = selectChild(g, nodeAlpha, children);
- }
- for(Node nodep : paY){
- Edge paz = g.getEdge(nodep,node_alpha);
- if(paz == null)
- g.addEdge(new Edge(nodep,node_alpha,Endpoint.TAIL,Endpoint.ARROW));
- }
+ // 2. Cover the edge nodeAlpha -> child by adding edges from parents of nodeAlpha to child and from parents of child to nodeAlpha. Last of all we revert the edge nodeAlpha -> child.
+ // This is done to maintain the directed paths in the DAG.
+ coverEdge(g, nodeAlpha, child);
+ // 3. Delete the child from the list of children of nodeAlpha, as it has been processed.
children.remove(child);
}
- g.removeNode(node_alpha);
+ // Finally, remove the nodeAlpha from the DAG.
+ g.removeNode(nodeAlpha);
}
+ /**
+ * Selects a child node from the list of children of nodeAlpha that does not create a cycle when an edge from nodeAlpha to the child is added (nodeAlpha <- child).
+ * @param g the DAG from which the child is to be selected.
+ * @param nodeAlpha the node from the alpha order heuristic.
+ * @param children the remaining children of nodeAlpha in the DAG.
+ * @return the selected child node that does not create a cycle when an edge from nodeAlpha to the child is added.
+ */
+ private Node selectChild(Dag g, Node nodeAlpha, List children) {
+ int i=0;
+ Node child;
+ boolean endCondition;
+ do{
+ child = children.get(i++);
+ g.removeEdge(nodeAlpha, child);
+ endCondition=false;
+ if(g.paths().existsDirectedPath(nodeAlpha,child)){
+ endCondition=true;
+ g.addEdge(new Edge(nodeAlpha,child,Endpoint.TAIL, Endpoint.ARROW));
+ }
+ }while(endCondition);
+ return child;
+ }
- int computeNextH1(List nodes){
+ /**
+ * Covers the edge from nodeAlpha to child by adding edges from parents of nodeAlpha to child and from parents of child to nodeAlpha.
+ * This is done to maintain the directed paths in the DAG after removing nodeAlpha.
+ * @param g the DAG where the edge is to be covered.
+ * @param nodeAlpha the node from the alpha order heuristic.
+ * @param child the child node selected from the list of children of nodeAlpha.
+ */
+ private void coverEdge(Dag g, Node nodeAlpha, Node child) {
+ // Getting the parents of nodeAlpha and child.
+ List paX = g.getParents(nodeAlpha);
+ List paY = g.getParents(child);
+ paY.remove(nodeAlpha);
- int min = Integer.MAX_VALUE;
- int minIndex = 0;
-
- for(int i=0 ; i< nodes.size(); i++){
- int weightNodei = 0;
- //for(Dag dag : this.setOfauxG){
- // int[][] dpath = this.dpaths.get(this.setOfauxG.indexOf(dag));
- // for(int j=0 ; j child.
+ g.addEdge(new Edge(child,nodeAlpha,Endpoint.TAIL, Endpoint.ARROW));
}
+
+
- public ArrayList getOrder(){
-
+ /**
+ * Returns the computed ancestral order of nodes.
+ * @return an ArrayList of nodes representing the ancestral order of the DAGs after applying the alpha order heuristic.
+ */
+ public ArrayList getOrder(){
return this.alpha;
}
-
-
- public static void main(String args[]) {
-
-// ArrayList dags = new ArrayList();
-// ArrayList alfa = new ArrayList();
-//
-//
-// System.out.println("Grafos de Partida: ");
-// System.out.println("---------------------");
-//// Graph graph = GraphConverter.convert("X1-->X5,X2-->X3,X3-->X4,X4-->X1,X4-->X5");
-//// Dag dag = new Dag(graph);
-//
-// Dag dag = new Dag();
-// // dag = GraphUtils.randomDag(Integer.parseInt(args[0]), Integer.parseInt(args[1]), true);
-// dags.add(dag);
-// System.out.println("DAG: ---------------");
-// System.out.println(dag.toString());
-// for (int i=0 ; i < Integer.parseInt(args[2])-1 ; i++){
-// // Dag newDag = GraphUtils.randomDag(dag.getNodes(),Integer.parseInt(args[1]) ,true);
-// dags.add(newDag);
-// System.out.println("DAG: ---------------");
-// System.out.println(newDag.toString());
-// }
-//
-// AlphaOrder order = new AlphaOrder(dags);
-// order.computeAlphaH2();
-// alfa = order.getOrder();
-//
-// System.out.println("Orden de Consenso: " + alfa.toString());
-
-
- }
-
-
}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java
new file mode 100644
index 0000000..e965021
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java
@@ -0,0 +1,119 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+class AlphaOrderTest {
+
+ private Node a, b, c;
+ private Dag dag1, dag2;
+ private ArrayList dags;
+
+ @BeforeEach
+ void setup() {
+ a = new GraphNode("A");
+ b = new GraphNode("B");
+ c = new GraphNode("C");
+
+ // DAG 1: A → B → C
+ dag1 = new Dag();
+ dag1.addNode(a);
+ dag1.addNode(b);
+ dag1.addNode(c);
+ dag1.addDirectedEdge(a, b);
+ dag1.addDirectedEdge(b, c);
+
+ // DAG 2: A → B, A → C
+ dag2 = new Dag();
+ dag2.addNode(a);
+ dag2.addNode(b);
+ dag2.addNode(c);
+ dag2.addDirectedEdge(a, b);
+ dag2.addDirectedEdge(a, c);
+
+ dags = new ArrayList<>(Arrays.asList(dag1, dag2));
+ }
+
+ @Test
+ void constructorThrowsOnNullInput() {
+ assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(null));
+ }
+
+ @Test
+ void constructorThrowsOnEmptyList() {
+ assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(new ArrayList<>()));
+ }
+
+ @Test
+ void constructorThrowsOnSingleDAG() {
+ ArrayList singleDagList = new ArrayList<>();
+ singleDagList.add(dag1);
+ assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(singleDagList));
+ }
+
+ @Test
+ void constructorThrowsOnDifferentNodes() {
+ // Crear otro DAG con nodos diferentes
+ Dag dagDifferent = new Dag();
+ dagDifferent.addNode(new GraphNode("X"));
+ dagDifferent.addNode(new GraphNode("Y"));
+ dagDifferent.addDirectedEdge(dagDifferent.getNode("X"), dagDifferent.getNode("Y"));
+
+ ArrayList badList = new ArrayList<>(Arrays.asList(dag1, dagDifferent));
+ assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(badList));
+ }
+
+ @Test
+ void computeAlphaReturnsValidOrder() {
+ AlphaOrder alphaOrder = new AlphaOrder(dags);
+ alphaOrder.computeAlpha();
+ List order = alphaOrder.getOrder();
+
+ assertNotNull(order);
+ assertEquals(3, order.size());
+ assertTrue(order.contains(a));
+ assertTrue(order.contains(b));
+ assertTrue(order.contains(c));
+
+ // Optional: check for uniqueness
+ assertEquals(3, order.stream().distinct().count());
+ }
+
+ @Test
+ void computeAlphaForTwoSimpleDags(){
+ AlphaOrder alphaOrder = new AlphaOrder(dags);
+ alphaOrder.computeAlpha();
+ List order = alphaOrder.getOrder();
+
+ // Basic assertions to check the order
+ assertNotNull(order);
+ assertEquals(3, order.size());
+ assertTrue(order.contains(a));
+ assertTrue(order.contains(b));
+ assertTrue(order.contains(c));
+
+ // Check that the order is A, B, C
+ assertEquals(a, order.get(0));
+ assertEquals(b, order.get(1));
+ assertEquals(c, order.get(2));
+
+ // Check for uniqueness
+ assertEquals(3, order.stream().distinct().count());
+
+ }
+}
+
+
+
From 04a11306319f438b314f34e877390d26ac976e74 Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Sat, 12 Jul 2025 19:04:16 +0200
Subject: [PATCH 03/32] Cleaning and testing BetaToAlpha
---
.../i3a/simd/consensusBN/BetaToAlpha.java | 383 +++++++++++-------
.../i3a/simd/consensusBN/BetaToAlphaTest.java | 93 +++++
2 files changed, 323 insertions(+), 153 deletions(-)
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BetaToAlpha.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BetaToAlpha.java
index 6398bf8..86e3856 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/BetaToAlpha.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BetaToAlpha.java
@@ -4,48 +4,90 @@
import java.util.HashMap;
import java.util.List;
import java.util.Random;
-import edu.cmu.tetrad.graph.Node;
+
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Endpoint;
+import edu.cmu.tetrad.graph.Node;
-
+/**
+ * BetaToAlpha is a class that transforms a directed acyclic graph (DAG) into an I-map minimal with respect to a specified alpha order.
+ * It constructs a compatible beta order and modifies the graph accordingly.
+ * The transformation respects the alpha order, ensuring that the resulting graph is consistent with it.
+ */
public class BetaToAlpha {
- Dag G = null;
- ArrayList beta = new ArrayList();
- ArrayList alfa = new ArrayList();
- HashMap alfaHash= new HashMap();
- Dag G_aux = null;
- int numberOfInsertedEdges = 0;
+ /**
+ * The directed acyclic graph (DAG) to be transformed.
+ */
+ private final Dag G;
+
+ /**
+ * The beta order derived from the alpha order.
+ */
+ private List beta;
+
+ /**
+ * The alpha order that the graph should respect. In consensusBN, this alpha order has been created using the AlphaOrder class.
+ * If null, a random order will be created.
+ */
+ private List alpha;
- public BetaToAlpha(Dag G, ArrayList alfa){
+ /**
+ * A hash map to store the index of each node in the alpha order for quick access.
+ */
+ private final HashMap alphaHash= new HashMap<>();
+
+ /**
+ * The auxiliary graph used during the transformation process.
+ */
+ private Dag G_aux = null;
- this.alfa = alfa;
+ /**
+ * The number of edges inserted during the transformation process.
+ */
+ int numberOfInsertedEdges = 0;
+
+ /**
+ * Constructor for BetaToAlpha that initializes the graph and alpha order.
+ * @param G the directed acyclic graph (DAG) to be transformed.
+ * @param alpha the alpha order that the graph should respect.
+ */
+ public BetaToAlpha(Dag G, ArrayList alpha){
+ this.alpha = alpha;
this.G = G;
this.beta = null;
- for(int i= 0; i< alfa.size(); i++){
- Node n = alfa.get(i);
- alfaHash.put(n, i);
+ for(int i= 0; i< alpha.size(); i++){
+ Node n = alpha.get(i);
+ alphaHash.put(n, i);
}
}
+ /**
+ * Constructor for BetaToAlpha that initializes the graph without a specified alpha order.
+ * A random alpha order will need to be created.
+ * @param G
+ */
public BetaToAlpha(Dag G){
-
- this.alfa = null;
+ this.alpha = null;
this.G = G;
this.beta = null;
-
}
- void computeAlfaHash(){
+ /**
+ * Computes the alpha hash map if it is not already computed.
+ * This method populates the alphaHash with the index of each node in the alpha order.
+ * It is called before any transformation to ensure that the alpha order is respected.
+ * If the alpha order is null, it will not compute the hash.
+ */
+ public void computeAlphaHash(){
- if(this.alfa !=null){
- if(alfaHash.isEmpty()){
- for(int i= 0; i< alfa.size(); i++){
- Node n = alfa.get(i);
- alfaHash.put(n, i);
+ if(this.alpha !=null){
+ if(alphaHash.isEmpty()){
+ for(int i= 0; i< alpha.size(); i++){
+ Node n = alpha.get(i);
+ alphaHash.put(n, i);
}
}
}
@@ -55,10 +97,15 @@ void computeAlfaHash(){
// Only to test the methods, to build a random order.
- public ArrayList randomAlfa (Random aleatorio){
+ /**
+ * Builds a random alpha order from the nodes of the graph. This is used for test purposes to ensure that the transformation can handle different orders.
+ * @param aleatorio the random number generator to use for shuffling the nodes.
+ * @return a list of nodes representing a random alpha order.
+ */
+ public List randomAlfa (Random aleatorio){
List nodes = this.G.getNodes();
- this.alfa = new ArrayList();
+ this.alpha = new ArrayList<>();
int[] index = new int[nodes.size()];
@@ -76,82 +123,121 @@ public ArrayList randomAlfa (Random aleatorio){
}
for (int i = 0; i< nodes.size(); i++){
- this.alfa.add(i, nodes.get(index[i]));
+ this.alpha.add(i, nodes.get(index[i]));
}
- this.computeAlfaHash();
- return this.alfa;
+ this.computeAlphaHash();
+ return this.alpha;
}
-
+ /**
+ * Transforms the graph G into an I-map minimal with respect to the alpha order.
+ */
public void transform(){
+ // 1. Create a compatible beta order with the alfa order for the DAG G.
+ buildBetaOrder();
+
+ // 2. Transform graph G into an I-map minimal with alpha order
+ transformWithBeta();
+
+ }
+
+ /**
+ * Builds the beta order that best respects the alpha order for the given graph G.
+ * This method constructs a beta order by identifying sink nodes and arranging them in a way that minimizes the number of edges that violate the alpha order.
+ * It uses a greedy approach to select the next node based on its position in the alpha order.
+ * The beta order is constructed such that it is as close as possible to the alpha order while ensuring that the resulting graph is still a DAG.
+ *
+ * This method modifies the G_aux graph to reflect the current state of the transformation.
+ * It also initializes the beta list with the first sink node and iteratively adds nodes to the beta order based on their relationships in the graph.
+ */
+ private void buildBetaOrder() {
this.G_aux = new Dag(this.G);
- this.beta = new ArrayList();
+ this.beta = new ArrayList<>();
+ List parents;
+
+ // Compute the sink nodes and add the first one to beta.
ArrayList sinkNodes = getSinkNodes(this.G_aux);
this.beta.add(sinkNodes.get(0));
- List pa = G_aux.getParents(sinkNodes.get(0));
+ parents = G_aux.getParents(sinkNodes.get(0));
this.G_aux.removeNode(sinkNodes.get(0));
sinkNodes.remove(0);
+
// Compute the new sink nodes
- for(Node nodep: pa){
- List chld = G_aux.getChildren(nodep);
- if (chld.size() == 0) sinkNodes.add(nodep);
- }
+ updateSinkNodes(sinkNodes, parents);
- // Construct beta order as closer as possible to alfa.
-
+ // Construct beta order as close as possible to alpha.
while (this.G_aux.getNumNodes()>0){
- // sinkNodes = getSinkNodes(this.G_aux);
+ // Select fist sink node
Node sink = sinkNodes.get(0);
- pa = G_aux.getParents(sink);
+ parents = G_aux.getParents(sink);
this.G_aux.removeNode(sink);
sinkNodes.remove(0);
// Compute the new sink nodes
- for(Node nodep: pa){
- List chld = G_aux.getChildren(nodep);
- if (chld.size() == 0) sinkNodes.add(nodep);
- }
+ updateSinkNodes(sinkNodes, parents);
- int index_alfa_sink = this.alfaHash.get(sink); //this.alfa.indexOf(sink);
- boolean ok = true;
- int i = 0;
-
- while(ok){
-
- Node nodej = this.beta.get(i);
- int index_alfa_nodej = this.alfaHash.get(nodej); //this.alfa.indexOf(nodej);
-
- if (index_alfa_nodej > index_alfa_sink){ ok = false; break;}
- if (this.G.getParents(nodej).contains(sink)){ ok = false; break;}
- if (i == this.beta.size()-1){ ok = false; break;}
- i++;
+ // Compute the index to insert the sink node in beta.
+ int insertIndex = 0;
+ for (; insertIndex < beta.size(); insertIndex++) {
+ Node current = beta.get(insertIndex);
+ if (alphaHash.get(current) > alphaHash.get(sink)) break;
+ if (G.getParents(current).contains(sink)) break;
}
-
- this.beta.add(i,sink);
+ beta.add(insertIndex, sink);
}
+ }
+/* FUTURE IDEA: SELECT BEST SINK NODE FROM ALPHA ORDER.
+ private Node selectBestSinkNode(List sinkNodes) {
+ return sinkNodes.stream()
+ .min(Comparator.comparingInt(alfaHash::get))
+ .orElse(sinkNodes.get(0));
+ }
+*/
+ /**
+ * Updates the sink nodes list based on the current list of candidates.
+ * This method checks each candidate node to see if it has any children in the auxiliary graph G_aux.
+ * If a candidate node has no children, it is added to the sink nodes list.
+ * This is used to maintain the integrity of the beta order during the transformation process.
+ *
+ * @param sinkNodes the list of current sink nodes to be updated.
+ * @param candidates the list of candidate nodes to check for children.
+ */
+ private void updateSinkNodes(ArrayList sinkNodes, List candidates) {
+ // Compute the new sink nodes
+ for(Node node: candidates){
+ List chld = G_aux.getChildren(node);
+ if (chld.isEmpty())
+ sinkNodes.add(node);
+ }
+ }
+
+ /**
+ * Transforms the graph G into an I-map minimal with respect to the alpha order.
+ * This method rearranges the edges in the graph based on the beta order derived from the alpha order.
+ * It ensures that the resulting graph respects the alpha order by checking the relationships between nodes and adjusting edges accordingly.
+ * The transformation modifies the graph in place and updates the beta list to reflect the new order of nodes.
+ */
+ private void transformWithBeta() {
+ ArrayList orderedNodes = new ArrayList<>();
+ // Setting the first node in the orderedNodes list.
+ orderedNodes.add(this.beta.remove(0));
- // transform graph G into an I-map minimal with alpha order
-
- ArrayList aux_beta = new ArrayList();
- aux_beta.add(this.beta.get(0));
- this.beta.remove(0);
-
- while(this.beta.size()>0){ // check each variable from the sink nodes.
-
- aux_beta.add(this.beta.get(0));
+ while(!this.beta.isEmpty()){
+ // Setting the next node in the orderedNodes list.
+ orderedNodes.add(this.beta.get(0));
this.beta.remove(0);
- int i = aux_beta.size();
- boolean ok = true;
+ int i = orderedNodes.size();
+ boolean changed = true;
- while (ok){
-
+ while (changed){
if(i==1) break;
- ok = false;
- Node nodeY = aux_beta.get(i-1);
- Node nodeZ = aux_beta.get(i-2);
-
-// if ((nodeZ != null) && (this.alfa.indexOf(nodeZ) > this.alfa.indexOf(nodeY))){
- if ((nodeZ != null) && (this.alfaHash.get(nodeZ) > this.alfaHash.get(nodeY))){
+ changed = false;
+ // Getting the last two nodes in the ordered list
+ Node nodeY = orderedNodes.get(i-1);
+ Node nodeZ = orderedNodes.get(i-2);
+
+ // Check if there is an edge from nodeZ to nodeY, if so, cover it.
+ if ((nodeZ != null) && (this.alphaHash.get(nodeZ) > this.alphaHash.get(nodeY))){
if(this.G.getEdge(nodeZ, nodeY) != null){
List paZ = this.G.getParents(nodeZ);
List paY = this.G.getParents(nodeY);
@@ -173,96 +259,87 @@ public void transform(){
}
}
}
- ok = true;
- aux_beta.remove(nodeY);
- aux_beta.add(i-2,nodeY);
+ changed = true;
+ orderedNodes.remove(nodeY);
+ orderedNodes.add(i-2,nodeY);
i--;
}
}
}
-
- this.beta = aux_beta;
-
+ this.beta = orderedNodes;
}
-
+ /**
+ * Returns the number of edges that were inserted during the transformation process.
+ * This method is useful for understanding how many modifications were made to the original graph to achieve the desired alpha order.
+ * @return
+ */
public int getNumberOfInsertedEdges(){
-
return this.numberOfInsertedEdges;
}
- ArrayList getSinkNodes(Dag g){
-
- ArrayList sourcesNodes = new ArrayList();
+ /**
+ * Retrieves the sink nodes from the given directed acyclic graph (DAG).
+ * A sink node is defined as a node that does not have any children in the graph.
+ * This method iterates through all nodes in the graph and checks their children to determine if they are sink nodes.
+ *
+ * @param g the directed acyclic graph (DAG) from which to retrieve sink nodes.
+ * @return an ArrayList of sink nodes that do not have any children in the graph.
+ */
+ private ArrayList getSinkNodes(Dag g){
+ // Get nodes from DAG
+ ArrayList sinkNodes = new ArrayList<>();
List nodes = g.getNodes();
-
- for (Node nodei : nodes){
- if(g.getChildren(nodei).isEmpty()) sourcesNodes.add(nodei);
+ // Check which nodes don't have children and add them to sinkNodes
+ for (Node node : nodes){
+ if(g.getChildren(node).isEmpty()){
+ sinkNodes.add(node);
+ }
}
- return sourcesNodes;
-
+ return sinkNodes;
}
-
-
-
-// public static void main(String args[]) {
-//
-// //Graph graph = GraphConverter.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
-// Graph graph = GraphConverter.convert("X2-->X1,X3-->X1,X1-->X4,X5-->X4,X4-->X6");
-// Dag dag = new Dag(graph);
-//
-// Dag dag2 = GraphUtils.randomDag(dag.getNodes(), 7, true);
-//// BayesPm bayesPm = new BayesPm(dag, 3, 3);
-//// MlBayesIm bayesIm = new MlBayesIm(bayesPm);
-////
-//// Element element = BayesXmlRenderer.getElement(bayesIm);
-//// System.out.println("Started with this bayesIm: " + bayesIm);
-//// System.out.println("\nGot this XML for it:");
-//// Document xmldoc = new Document(element);
-//// Serializer serializer = new Serializer(System.out);
-//// serializer.setLineSeparator("\n");
-//// serializer.setIndent(2);
-//// try {
-//// serializer.write(xmldoc);
-//// }
-//// catch (IOException e) {
-//// throw new RuntimeException(e);
-//// }
-//
-//
-// System.out.println(GraphUtils.graphToDot(dag));
-//
-//
-//// System.out.println("Dag Inicial: "+ dag.toString());
-//
-// Random aleatorio = new Random(150);
-// BetaToAlpha mt = new BetaToAlpha(dag);
-// mt.randomAlfa (aleatorio);
-// mt.transform();
-//// System.out.println(mt.G.toString()+" Alfa: "+mt.alfa.toString()+" Beta: "+ mt.beta.toString() );
-//
-// System.out.println(GraphUtils.graphToDot(mt.G));
-//
-//
-//
-//// System.out.println("Dag Inicial: "+ dag2.toString());
-//
-// System.out.println(GraphUtils.graphToDot(dag2));
-//
-// BetaToAlpha mt2 = new BetaToAlpha(dag2);
-// Random aleat2 = new Random(150);
-// mt2.randomAlfa(aleat2);
-// mt2.transform();
-//
-//// System.out.println(mt2.G.toString()+" Alfa: "+mt2.alfa.toString()+" Beta: "+ mt2.beta.toString() );
-//
-// System.out.println(GraphUtils.graphToDot(mt2.G));
-//
-//
-//
-// }
-
-
+ /**
+ * Returns the alpha hash map that contains the index of each node in the alpha order.
+ * This map is used to quickly access the position of nodes in the alpha order during the transformation process.
+ * It is particularly useful for ensuring that the resulting graph respects the specified alpha order.
+ * @return the alpha hash map where keys are nodes and values are their indices in the alpha order.
+ */
+ public HashMap getAlphaHash() {
+ return alphaHash;
+ }
+
+ /**
+ * Sets the alpha order for the transformation.
+ * This method allows the user to specify a new alpha order for the graph transformation.
+ * It updates the alpha field and recomputes the alpha hash map to reflect the new order.
+ * @param alpha the new alpha order to be set for the transformation.
+ */
+ public void setAlphaOrder(List alpha) {
+ this.alpha = alpha;
+ this.computeAlphaHash();
+ }
+
+ /**
+ * Returns the alpha order that the graph should respect.
+ * @return the alpha order as a list of nodes, or null if no alpha order has been set.
+ */
+ public List getAlphaOrder() {
+ return alpha;
}
+ /**
+ * Returns the directed acyclic graph (DAG) that has been transformed.
+ * This method provides access to the modified graph after the transformation has been applied.
+ * The graph will be an I-map minimal with respect to the specified alpha order.
+ *
+ * @see BetaToAlpha#transform()
+ * @return the transformed directed acyclic graph (DAG) as a Dag object.
+ */
+ public Dag getGraph() {
+ return G;
+ }
+
+
+}
+
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java
new file mode 100644
index 0000000..7d421ad
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java
@@ -0,0 +1,93 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class BetaToAlphaTest {
+
+ private Dag dag;
+ private Node a, b, c, d;
+ private ArrayList alphaOrder;
+
+ @BeforeEach
+ void setUp() {
+ a = new GraphNode("A");
+ b = new GraphNode("B");
+ c = new GraphNode("C");
+ d = new GraphNode("D");
+
+ // DAG: A → B, A → C, B → D, C → D
+ dag = new Dag();
+ dag.addNode(a);
+ dag.addNode(b);
+ dag.addNode(c);
+ dag.addNode(d);
+ dag.addDirectedEdge(a, b);
+ dag.addDirectedEdge(a, c);
+ dag.addDirectedEdge(b, d);
+ dag.addDirectedEdge(c, d);
+
+ // Define an alpha order that requires modifying the graph
+ alphaOrder = new ArrayList<>(Arrays.asList(d, c, b, a));
+ }
+
+ @Test
+ void testTransformRespectsAlphaOrder() {
+ BetaToAlpha bta = new BetaToAlpha(dag, alphaOrder);
+ bta.transform();
+
+ // El grafo debería haber invertido al menos algunos arcos
+ assertTrue(bta.getNumberOfInsertedEdges() > 0);
+
+ // Validamos que el orden resultante es compatible con alpha
+ for (Edge edge : dag.getEdges()) {
+ Node from = edge.getNode1();
+ Node to = edge.getNode2();
+
+ int fromIndex = alphaOrder.indexOf(from);
+ int toIndex = alphaOrder.indexOf(to);
+
+ assertTrue(fromIndex < toIndex, "Edge violates alpha order: " + from + " → " + to);
+ }
+ }
+
+ @Test
+ void testRandomAlphaProducesPermutation() {
+ BetaToAlpha bta = new BetaToAlpha(dag);
+ List randomAlpha = bta.randomAlfa(new Random(42));
+
+ assertNotNull(randomAlpha);
+ assertEquals(dag.getNumNodes(), randomAlpha.size());
+
+ Set originalNodes = new HashSet<>(dag.getNodes());
+ Set shuffled = new HashSet<>(randomAlpha);
+
+ assertEquals(originalNodes, shuffled); // misma colección, diferente orden
+ }
+
+ @Test
+ void testComputeAlphaHashBuildsCorrectMap() {
+ BetaToAlpha bta = new BetaToAlpha(dag, alphaOrder);
+ bta.computeAlphaHash();
+
+ for (int i = 0; i < alphaOrder.size(); i++) {
+ Node node = alphaOrder.get(i);
+ assertEquals(i, (bta.getAlphaHash()).get(node));
+ }
+ }
+}
From f12c2475336ce6de1c1867c7bd49f01ff4b8cb21 Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Mon, 14 Jul 2025 12:05:31 +0200
Subject: [PATCH 04/32] Cleaning and testing TransformDags
---
.../i3a/simd/consensusBN/TransformDags.java | 278 +++++++-----------
.../simd/consensusBN/TransformDagsTest.java | 104 +++++++
2 files changed, 207 insertions(+), 175 deletions(-)
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java b/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java
index 9861e99..f798234 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java
@@ -2,205 +2,133 @@
import java.util.ArrayList;
-
-import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Node;
+/**
+ * This class transforms a set of DAGs by applying the BetaToAlpha transformation to each DAG with a given alpha order.
+ */
public class TransformDags {
+ /**
+ * List of input DAGs to be transformed.
+ */
+ private final ArrayList setOfDags;
+ /**
+ * List of output DAGs after transformation.
+ */
+ private final ArrayList setOfOutputDags;
+ /**
+ * The alpha order used for the transformation.
+ */
+ private final ArrayList alpha;
+ /**
+ * The transformation objects for each DAG.
+ * Each BetaToAlpha object applies the transformation to a corresponding DAG in setOfDags using the alpha order provided.
+ */
+ private ArrayList transformers= null;
- ArrayList setOfDags = null;
- ArrayList setOfOutputDags = null;
- ArrayList alfa = null;
- ArrayList metAs= null;
- int numberOfInsertedEdges = 0;
-// int weight[][][] = null;
+ /**
+ * Number of edges inserted during the transformation process.
+ * This is used to track how many edges were added to the transformed DAGs.
+ */
+ private int numberOfInsertedEdges = 0;
- public TransformDags(ArrayList dags, ArrayList alfa){
+ /**
+ * Constructor for TransformDags.
+ * Initializes the object with a list of DAGs and an alpha order.
+ * It creates a new BetaToAlpha transformation for each DAG in the input list.
+ * Each DAG in the input list will be transformed according to this alpha order.
+ * The transformation is applied by creating a BetaToAlpha object for each DAG.
+ * The transformed DAGs will be stored in setOfOutputDags after calling the transform() method.
+ * @see BetaToAlpha
+ * @param dags List of DAGs to be transformed.
+ * @param alpha List of nodes representing the alpha order for the transformation.
+ */
+ public TransformDags(ArrayList dags, ArrayList alpha){
this.setOfDags = dags;
- this.setOfOutputDags = new ArrayList();
- this.metAs = new ArrayList();
- this.alfa = alfa;
-
+ this.setOfOutputDags = new ArrayList<>();
+ this.transformers = new ArrayList<>();
+ this.alpha = alpha;
+ // Initialize the BetaToAlpha transformation for each DAG in the input list
for (Dag i : setOfDags) {
Dag out = new Dag(i);
- this.metAs.add(new BetaToAlpha(out,alfa));
+ this.transformers.add(new BetaToAlpha(out,this.alpha));
}
}
-
+ /**
+ * Transforms the input DAGs by applying the BetaToAlpha transformation.
+ * This method iterates through each BetaToAlpha object, applies the transformation,
+ * and collects the transformed DAGs into setOfOutputDags.
+ * It also counts the total number of edges inserted during the transformations.
+ *
+ * @see BetaToAlpha#transform()
+ * @see BetaToAlpha#getNumberOfInsertedEdges()
+ * @see BetaToAlpha#getGraph()
+ * @return An ArrayList of transformed DAGs after applying the BetaToAlpha transformation.
+ */
public ArrayList transform (){
-
this.numberOfInsertedEdges = 0;
-
- for(BetaToAlpha transformDagi: this.metAs){
+ for(BetaToAlpha transformDagi: this.transformers){
transformDagi.transform();
this.numberOfInsertedEdges += transformDagi.getNumberOfInsertedEdges();
- this.setOfOutputDags.add(transformDagi.G);
+ this.setOfOutputDags.add(transformDagi.getGraph());
}
-
-
return this.setOfOutputDags;
-
}
+ /**
+ * Returns the number of edges that were inserted during the transformation process.
+ * @return The total number of edges inserted across all transformed DAGs.
+ */
public int getNumberOfInsertedEdges(){
return this.numberOfInsertedEdges;
}
+ /**
+ * Returns the list of input DAGs that were transformed.
+ * @return An ArrayList of DAGs that were provided as input to the transformation.
+ */
+ public ArrayList getSetOfDags() {
+ return this.setOfDags;
+ }
+
+ /**
+ * Returns the list of transformed output DAGs.
+ * This list contains the DAGs after applying the BetaToAlpha transformation.
+ * @return An ArrayList of transformed DAGs.
+ */
+ public ArrayList getSetOfOutputDags() {
+ return this.setOfOutputDags;
+ }
+
+ /**
+ * Returns the alpha order used for the transformation.
+ * @return An ArrayList of nodes representing the alpha order.
+ */
+ public ArrayList getAlpha() {
+ return this.alpha;
+ }
+
+ /**
+ * Returns the list of BetaToAlpha transformers used for the transformation.
+ * @return An ArrayList of BetaToAlpha objects, each corresponding to a DAG in the input list.
+ */
+ public ArrayList getTransformers() {
+ return this.transformers;
+ }
+
+ /**
+ * Sets the list of BetaToAlpha transformers.
+ * This method allows for updating the transformers used in the transformation process.
+ *
+ * @param transformers An ArrayList of BetaToAlpha objects to be set as the transformers for this TransformDags instance.
+ * Each transformer will apply the BetaToAlpha transformation to its corresponding DAG in the input list.
+ */
+ public void setTransformers(ArrayList transformers) {
+ this.transformers = transformers;
+ }
-
-
-// void computeWeight(){
-//
-// this.weight = new int[this.setOfDags.size()][this.alfa.size()][this.alfa.size()];
-//
-// for(Dag g: this.setOfOutputDags){
-// for(Node nodei : g.getNodes()){
-// List pa = g.getParents(nodei);
-// if(pa.isEmpty()) continue;
-// List anc = new ArrayList();
-// anc.add(nodei);
-// anc = g.getAncestors(anc);
-// Dag gAnc = new Dag(g.subgraph(anc));
-// // me quedo con el grafo ancestral del node_i
-// for(Node pai: pa){ // Calculo el numero de caminos desde los ancestros que se "activan" borrando cada padre.
-// int npaths = 0;
-// Dag gAncNopai = new Dag(gAnc);
-// for(Node rm : pa) if(!rm.equals(pai)) gAncNopai.removeNode(rm); // borro todos los padres menos el pa_i en el grafo ancestral.
-// for(Node nodeAc: anc){ // para cada ancestro voy mirando si hay un camino dirigido.
-// if((gAncNopai.getNodes().contains(nodeAc))&&(!nodeAc.equals(nodei)))
-// if(GraphUtils.paths().existsDirectedPath(gAncNopai,nodeAc, nodei)) npaths++;
-// }
-// // npaths tiene el numero de caminos diridos que se han activado quitando el padre pa_i
-// this.weight[this.setOfOutputDags.indexOf(g)][g.getNodes().indexOf(nodei)][g.getNodes().indexOf(pai)] = npaths;
-// }
-//
-// }
-//
-// }
-//
-// }
-
-// public Dag computeWeightDag(boolean w){
-//
-// Dag wDag = new Dag(this.alfa);
-// for(Node nodei : this.alfa){
-// for(Node nodej : this.alfa){
-// if(nodei.equals(nodej)) continue;
-// int wij = 0;
-// for(Dag g: this.setOfOutputDags){
-// int wg = this.weight[this.setOfOutputDags.indexOf(g)][g.getNodes().indexOf(nodei)][g.getNodes().indexOf(nodej)];
-// if(wg > 0 && !w) wg = 1;
-// else if (wg == 0) wg =-1;
-// wij+=wg;
-// }
-// if(wij > 0) wDag.addEdge(new Edge(nodej,nodei,Endpoint.TAIL,Endpoint.ARROW));
-// }
-// }
-// return wDag;
-// }
-//
-// public static void main(String args[]) {
-//
-// ArrayList dags = new ArrayList();
-// ArrayList alfa = new ArrayList();
-// Random aleatorio = new Random(222);
-//
-//
-// System.out.println("Grafos de Partida: ");
-// System.out.println("---------------------");
-//// Graph graph = GraphConverter.convert("X1-->X5,X2-->X3,X3-->X4,X4-->X1,X4-->X5");
-//// Dag dag = new Dag(graph);
-//
-// Dag dag = new Dag();
-//
-// dag = GraphUtils.randomDag(Integer.parseInt(args[0]), Integer.parseInt(args[1]), true);
-// BetaToAlpha mt = new BetaToAlpha(dag);
-// alfa = mt.randomAlfa(aleatorio);
-// dags.add(dag);
-// System.out.println("DAG: ---------------");
-// System.out.println(dag.toString());
-// for (int i=0 ; i < Integer.parseInt(args[2])-1 ; i++){
-// Dag newDag = GraphUtils.randomDag(dag.getNodes(),Integer.parseInt(args[1]) ,true);
-// dags.add(newDag);
-// System.out.println("DAG: ---------------");
-// System.out.println(newDag.toString());
-// }
-//
-//
-//
-// System.out.println("Orden de Consenso: " + alfa.toString());
-//
-// TransformDags setOfDags = new TransformDags(dags,alfa);
-// setOfDags.transform();
-//
-//
-//
-//
-// for(Dag d : setOfDags.setOfOutputDags){
-// System.out.println("DAG trasformado: ---------------");
-// System.out.println(d.toString());
-// }
-//
-//
-//
-// Dag union = new Dag(alfa);
-//
-// for(Node nodei: alfa){
-// for(Dag d : setOfDags.setOfOutputDags){
-// Listparent = d.getParents(nodei);
-// for(Node pa: parent){
-// if(!union.isParentOf(pa, nodei)) union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW));
-// }
-// }
-//
-// }
-//
-//
-// System.out.println("Grafo UNION: "+union.toString());
-// setOfDags.computeWeight();
-// Dag wDag = setOfDags.computeWeightDag(true);
-// System.out.println("Grafo Consenso: "+ wDag.toString());
-// Dag wDag2 = setOfDags.computeWeightDag(false);
-// System.out.println("Grafo Consenso sin pesos: "+ wDag2.toString());
-//
-//
-//
-//
-//
-//
-//
-//// Node nod = dag.getNodes().get(aleatorio.nextInt(alfa.size()));
-//// ArrayList a = new ArrayList();
-//// a.add(nod);
-//// List anc = dag.getAncestors(a);
-////
-//// System.out.println("Ancenstros de " + nod.toString()+ " "+anc.toString());
-////
-//// System.out.println("Subgraph: "+ dag.subgraph(anc));
-////
-//// List pa = dag.getParents(nod);
-//// System.out.println("padres de: "+nod.toString()+ " : "+pa.toString());
-//// Node pai = pa.get(aleatorio.nextInt(pa.size()));
-//// System.out.println("Padre elegido a borrar: "+pai.toString());
-//// pa.remove(pai);
-////
-////
-////
-//// for(Node rm: pa) anc.remove(rm);
-////
-//// Graph pp = dag.subgraph(anc);
-//// int npath = 0;
-//// for(Node ancestor: pp.getNodes()){
-//// if(!ancestor.equals(nod)){
-//// List> paths = GraphUtils.allPathsFromTo(pp, ancestor, nod);
-//// if(dag.paths().existsDirectedPath(ancestor, nod)) npath++;
-//// System.out.println("Caminos desde: "+ancestor.toString()+" a: "+nod.toString()+" : "+paths.toString());
-//// }
-//// }
-//// System.out.println(" El numero de caminos desde hacia: "+ nod+" es de: "+npath);
-// }
-//
}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java
new file mode 100644
index 0000000..b399ed3
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java
@@ -0,0 +1,104 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class TransformDagsTest {
+
+ private ArrayList inputDags;
+ private ArrayList alpha;
+
+ @BeforeEach
+ public void setUp() {
+ inputDags = new ArrayList<>();
+
+ // We use 4 nodes for the DAGs
+ Node nodeA = new GraphNode("A");
+ Node nodeB = new GraphNode("B");
+ Node nodeC = new GraphNode("C");
+ Node nodeD = new GraphNode("D");
+
+ // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D
+ Dag dag1 = new Dag();
+ dag1.addNode(nodeA);
+ dag1.addNode(nodeB);
+ dag1.addNode(nodeC);
+ dag1.addNode(nodeD);
+
+ // Adding directed edges to the DAG
+ dag1.addDirectedEdge(nodeA, nodeB);
+ dag1.addDirectedEdge(nodeA, nodeC);
+ dag1.addDirectedEdge(nodeB, nodeD);
+ dag1.addDirectedEdge(nodeC, nodeD);
+
+ // Adding the DAG to the list
+ inputDags.add(dag1);
+
+ // Create second DAG with these edges: D -> C, D -> B, C -> A, B -> A
+ Dag dag2 = new Dag();
+ dag2.addNode(nodeA);
+ dag2.addNode(nodeB);
+ dag2.addNode(nodeC);
+ dag2.addNode(nodeD);
+
+ // Adding directed edges to the second DAG
+ dag2.addDirectedEdge(nodeD, nodeC);
+ dag2.addDirectedEdge(nodeD, nodeB);
+ dag2.addDirectedEdge(nodeC, nodeA);
+ dag2.addDirectedEdge(nodeB, nodeA);
+
+ // Adding the second DAG to the list
+ inputDags.add(dag2);
+
+ // Apply AlphaOrder algorithm to these dags:
+ AlphaOrder alphaOrder = new AlphaOrder(inputDags);
+ alphaOrder.computeAlpha();
+ alpha = alphaOrder.getOrder();
+
+ }
+
+ @Test
+ public void testConstructorInitializesCorrectly() {
+ TransformDags transformer = new TransformDags(inputDags, alpha);
+
+ assertNotNull(transformer);
+ assertEquals(0, transformer.getNumberOfInsertedEdges());
+ }
+
+ @Test
+ public void testTransformReturnsCorrectSize() {
+ TransformDags transformer = new TransformDags(inputDags, alpha);
+ ArrayList result = transformer.transform();
+
+ assertNotNull(result);
+ assertEquals(inputDags.size(), result.size());
+ }
+
+ @Test
+ public void testTransformUpdatesNumberOfInsertedEdges() {
+ TransformDags transformer = new TransformDags(inputDags, alpha);
+ transformer.transform();
+
+ // No sabemos cuántas aristas se insertan exactamente sin saber cómo funciona BetaToAlpha,
+ // pero al menos podemos comprobar que el valor no es negativo.
+ assertTrue(transformer.getNumberOfInsertedEdges() >= 0);
+ }
+
+ @Test
+ public void testEmptyDagListReturnsEmptyOutput() {
+ TransformDags transformer = new TransformDags(new ArrayList<>(), alpha);
+ ArrayList result = transformer.transform();
+
+ assertTrue(result.isEmpty());
+ assertEquals(0, transformer.getNumberOfInsertedEdges());
+ }
+}
From 71e5cbec9067b9d7b069c9d351e0144288f339ec Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Tue, 15 Jul 2025 20:16:32 +0200
Subject: [PATCH 05/32] Cleaning and testing ConsensusUnion and updating pom
dependencies
---
pom.xml | 28 ++++
.../i3a/simd/consensusBN/ConsensusUnion.java | 131 ++++++++++++-----
.../simd/consensusBN/ConsensusUnionTest.java | 138 ++++++++++++++++++
3 files changed, 257 insertions(+), 40 deletions(-)
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java
diff --git a/pom.xml b/pom.xml
index 8f30b34..2af8d6d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -75,6 +75,13 @@
5.10.0
test
+
+
+ org.apache.commons
+ commons-math3
+ 3.6.1
+
+
@@ -149,6 +156,27 @@
+
+
+ org.jacoco
+ jacoco-maven-plugin
+ 0.8.10
+
+
+
+ prepare-agent
+
+
+
+ report
+ test
+
+ report
+
+
+
+
+
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java
index 8c38fd9..ba63e1b 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java
@@ -6,59 +6,119 @@
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Endpoint;
-import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
-
+/**
+ * This class implements the Consensus Union algorithm which applies a fusion between multiple Directed Acyclic Graphs (DAGs).
+ * It constructs a consensus DAG by merging the input DAGs based on a specified order of nodes (alpha).
+ * The alpha order is computed with the AlphaOrder class which implements a Greedy Heuristic Order (GHO) search, achieving a good order to transform the input DAGs.
+ * Once each DAG is transformed, the union method creates a new DAG that contains all the edges from the input DAGs, ensuring that the resulting graph is acyclic.
+ * The number of edges inserted during the union process can be retrieved using getNumberOfInsertedEdges.
+ *
+ * This class is also runnable, allowing it to be executed in a separate thread.
+ */
public class ConsensusUnion implements Runnable{
- ArrayList alpha = null;
- Dag outputDag = null;
- AlphaOrder heuristic = null;
- TransformDags imaps2alpha = null;
- ArrayList setOfdags = null;
+ /**
+ * The alpha order of nodes in the consensus DAG.
+ * This order is used to transform the input DAGs into a compatible I-Maps before merging.
+ * It is computed using the AlphaOrder class.
+ *
+ * @see AlphaOrder
+ */
+ private ArrayList alpha;
+ /**
+ * The AlphaOrder heuristic used to compute the alpha order.
+ */
+ private AlphaOrder heuristic = null;
+
+ /**
+ * The TransformDags instance that transforms the input DAGs based on the alpha order.
+ */
+ private TransformDags imaps2alpha;
+
+ /**
+ * List of input DAGs to be merged.
+ */
+ private ArrayList setOfdags = null;
+
+ /**
+ * The output DAG resulting from the union of the transformed input DAGs.
+ */
Dag union = null;
+
+ /**
+ * Number of edges inserted during the consensus union process.
+ */
int numberOfInsertedEdges = 0;
-
+ /**
+ * Constructor for ConsensusUnion that initializes the union process with a list of DAGs and an alpha order.
+ * @param dags the list of input DAGs to be merged.
+ * @param order the alpha order of nodes to be used for transforming the input DAGs.
+ */
public ConsensusUnion(ArrayList dags, ArrayList order){
this.setOfdags = dags;
this.alpha = order;
-
}
-
-
+ /**
+ * Constructor for ConsensusUnion that initializes the union process with a list of DAGs and uses the AlphaOrder object to generate an alpha order.
+ * @see AlphaOrder
+ * @param dags the list of input DAGs to be merged.
+ */
public ConsensusUnion(ArrayList dags){
this.setOfdags = dags;
this.heuristic = new AlphaOrder(this.setOfdags);
-
}
+ /**
+ * Default constructor for ConsensusUnion that initializes an empty union.
+ * This constructor can be used when the DAGs are set later using the setDags method.
+ */
public ConsensusUnion(){
this.setOfdags = null;
}
+ /**
+ * Returns the number of edges inserted during the union process.
+ * This value is updated after the union method is called.
+ * @return the number of edges inserted in the consensus DAG.
+ */
public int getNumberOfInsertedEdges(){
return this.numberOfInsertedEdges;
}
+ /**
+ * Performs the union of the input DAGs based on the alpha order. If no alpha order is set, it computes it first.
+ * The method transforms each input DAG according to the alpha order and then merges them into a single consensus DAG.
+ * The resulting DAG contains all edges from the transformed input DAGs, ensuring that it remains acyclic.
+ *
+ * @throws IllegalStateException if the alpha order is not set before calling this method.
+ * @throws IllegalArgumentException if the input DAGs are null or empty.
+ * @throws NullPointerException if the alpha order is null.
+ * @return the resulting consensus DAG after merging the transformed input DAGs.
+ * @see AlphaOrder
+ * @see TransformDags
+ */
public Dag union(){
+ // Computing Alpha Order if not set, using the Greedy Heuristic Order (GHO)
if(this.alpha == null){
-
- this.heuristic.computeAlphaH2();
- this.alpha = this.heuristic.alpha;
+ this.heuristic.computeAlpha();
+ this.alpha = this.heuristic.getOrder();
}
+ // Transforming each DAG with the alpha order
this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha);
this.imaps2alpha.transform();
this.numberOfInsertedEdges = this.imaps2alpha.getNumberOfInsertedEdges();
+ // Applying a union of the edges of the transformed DAGs
this.union = new Dag(this.alpha);
for(Node nodei: this.alpha){
- for(Dag d : this.imaps2alpha.setOfOutputDags){
+ for(Dag d : this.imaps2alpha.getSetOfOutputDags()){
Listparent = d.getParents(nodei);
for(Node pa: parent){
if(!this.union.isParentOf(pa, nodei)) this.union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW));
@@ -70,43 +130,34 @@ public Dag union(){
}
+ /**
+ * Returns the resulting consensus DAG after the union process.
+ * This method should be called after the union method to ensure that the union has been performed.
+ * @return the consensus DAG resulting from the union of the input DAGs.
+ */
public Dag getUnion(){
return this.union;
}
+ /**
+ * sets the list of input DAGs for the ConsensusUnion instance and applies the AlphaOrder heuristic to compute the alpha order.
+ * This method also updates the alpha order and transforms the input DAGs accordingly.
+ * @param dags
+ */
void setDags(ArrayList dags){
this.setOfdags = dags;
this.heuristic = new AlphaOrder(this.setOfdags);
- this.heuristic.computeAlphaH2();
- this.alpha = this.heuristic.alpha;
+ this.heuristic.computeAlpha();
+ this.alpha = this.heuristic.getOrder();
this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha);
this.imaps2alpha.transform();
}
-
-
-
- public static void main(String args[]) {
-
-
- System.out.println("Grafos de Partida: ");
-
- // (seed, n. variables, n egdes aprox, n. dags, mutation)
- RandomBN setOfDags = new RandomBN(0, Integer.parseInt(args[0]), Integer.parseInt(args[1]),
- Integer.parseInt(args[2]),Integer.parseInt(args[3]));
- setOfDags.generate();
-//
- for( Dag g: setOfDags.setOfRandomDags) System.out.print(g);
- ConsensusUnion conDag= new ConsensusUnion();
- conDag.setDags(setOfDags.setOfRandomDags);
- Graph g = conDag.union();
- System.out.println("grafo consenso: "+ g);
-
- }
-
-
+ /**
+ * Runs the ConsensusUnion process in a separate thread.
+ */
@Override
public void run() {
this.union = this.union();
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java
new file mode 100644
index 0000000..e97c87f
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java
@@ -0,0 +1,138 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Graph;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class ConsensusUnionTest {
+
+ private ArrayList inputDags;
+ private ArrayList alpha;
+
+ @BeforeEach
+ public void setUp() {
+ inputDags = new ArrayList<>();
+
+ // We use 4 nodes for the DAGs
+ Node nodeA = new GraphNode("A");
+ Node nodeB = new GraphNode("B");
+ Node nodeC = new GraphNode("C");
+ Node nodeD = new GraphNode("D");
+
+ // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D
+ Dag dag1 = new Dag();
+ dag1.addNode(nodeA);
+ dag1.addNode(nodeB);
+ dag1.addNode(nodeC);
+ dag1.addNode(nodeD);
+
+ // Adding directed edges to the DAG
+ dag1.addDirectedEdge(nodeA, nodeB);
+ dag1.addDirectedEdge(nodeA, nodeC);
+ dag1.addDirectedEdge(nodeB, nodeD);
+ dag1.addDirectedEdge(nodeC, nodeD);
+
+ // Adding the DAG to the list
+ inputDags.add(dag1);
+
+ // Create second DAG with these edges: D -> C, D -> B, C -> A, B -> A
+ Dag dag2 = new Dag();
+ dag2.addNode(nodeA);
+ dag2.addNode(nodeB);
+ dag2.addNode(nodeC);
+ dag2.addNode(nodeD);
+
+ // Adding directed edges to the second DAG
+ dag2.addDirectedEdge(nodeD, nodeC);
+ dag2.addDirectedEdge(nodeD, nodeB);
+ dag2.addDirectedEdge(nodeC, nodeA);
+ dag2.addDirectedEdge(nodeB, nodeA);
+
+ // Adding the second DAG to the list
+ inputDags.add(dag2);
+
+ // Apply AlphaOrder algorithm to these dags:
+ AlphaOrder alphaOrder = new AlphaOrder(inputDags);
+ alphaOrder.computeAlpha();
+ alpha = alphaOrder.getOrder();
+
+ }
+
+ @Test
+ public void testConstructorWithAlphaInitializesCorrectly() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ assertNotNull(cu);
+ }
+
+ @Test
+ public void testConstructorWithoutAlphaGeneratesAlpha() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags);
+ assertNotNull(cu);
+ }
+
+ @Test
+ public void testUnionReturnsNonNullDag() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ Dag result = cu.union();
+ assertNotNull(result);
+ }
+
+ @Test
+ public void testNumberOfInsertedEdgesIsUpdated() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ cu.union();
+ assertTrue(cu.getNumberOfInsertedEdges() >= 0);
+ }
+
+ @Test
+ public void testSetDagsUpdatesAlphaAndUnion() {
+ ConsensusUnion cu = new ConsensusUnion();
+ cu.setDags(inputDags);
+ cu.union();
+ Dag result = cu.getUnion();
+ assertNotNull(result);
+ assertTrue(result.getNumEdges() >= 1);
+ }
+
+ @Test
+ public void testRunMethodExecutesUnion() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ cu.run();
+ assertNotNull(cu.getUnion());
+ }
+
+ @Test
+ public void testEmptyDagListReturnsEmptyUnion() {
+ assertThrows(IllegalArgumentException.class, () -> new ConsensusUnion(new ArrayList<>()));
+ }
+
+ @Test
+ public void testRandomBNGeneratesConsensusUnionCorrectly() {
+
+ //System.out.println("Grafos de Partida: ");
+
+ // (seed, n. variables, n egdes aprox, n. dags, mutation)
+ RandomBN setOfDags = new RandomBN(0, 20, 50,
+ 4,3);
+ setOfDags.generate();
+
+ //for( Dag g: setOfDags.setOfRandomDags) System.out.print(g);
+ ConsensusUnion conDag= new ConsensusUnion(setOfDags.setOfRandomDags);
+ Graph g = conDag.union();
+ //System.out.println("grafo consenso: "+ g);
+
+ assertNotNull(g);
+ assertTrue(g.getNumEdges() >= 0);
+ assertTrue(g.getNodes().size() == setOfDags.setOfRandomDags.get(0).getNodes().size());
+
+ }
+}
From 90688a85198d135c94260f1931b58141e9a9e210 Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Wed, 16 Jul 2025 11:35:12 +0200
Subject: [PATCH 06/32] Reformatring consensusBES and creating a new class for
BES, consistency checked
---
.../BackwardEquivalenceSearchDSep.java | 366 ++++++++++++++++++
.../i3a/simd/consensusBN/ConsensusBES.java | 41 +-
.../simd/consensusBN/ConsensusBESTest.java | 124 ++++++
3 files changed, 520 insertions(+), 11 deletions(-)
create mode 100644 src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
new file mode 100644
index 0000000..079ce49
--- /dev/null
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
@@ -0,0 +1,366 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.EdgeListGraph;
+import edu.cmu.tetrad.graph.Edges;
+import edu.cmu.tetrad.graph.Endpoint;
+import edu.cmu.tetrad.graph.Graph;
+import edu.cmu.tetrad.graph.Node;
+import edu.cmu.tetrad.search.utils.GraphSearchUtils;
+import edu.cmu.tetrad.search.utils.MeekRules;
+import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag;
+
+public class BackwardEquivalenceSearchDSep {
+
+ private final Graph graph;
+ private final ArrayList transformedDags;
+ private final ArrayList initialDags;
+ private Dag outputDag;
+ private final Map localScore = new HashMap<>();
+ private int numberOfInsertedEdges = 0;
+
+
+ public BackwardEquivalenceSearchDSep(Dag union, ArrayListinitialDags, ArrayList transformedDags) {
+ this.graph = new EdgeListGraph(new LinkedList<>(union.getNodes()));
+ for (Edge edge : union.getEdges()) {
+ graph.addEdge(edge);
+ }
+ this.initialDags = initialDags;
+ this.transformedDags = transformedDags;
+ }
+
+ public Dag applyBackwardEliminationWithDSeparation(){
+ // Implement the BESd algorithm logic here
+ // This is a placeholder for the actual BESd algorithm implementation
+ // The algorithm should modify the graph based on the BESd logic
+ rebuildPattern(graph);
+ Node x, y;
+ Set t = new HashSet<>();
+ double score = 0;
+ double bestScore = score;
+ do {
+ x = y = null;
+ Set edges1 = graph.getEdges();
+ List edges = new ArrayList<>();
+
+ for (Edge edge : edges1) {
+ Node _x = edge.getNode1();
+ Node _y = edge.getNode2();
+
+ if (Edges.isUndirectedEdge(edge)) {
+ edges.add(Edges.directedEdge(_x, _y));
+ edges.add(Edges.directedEdge(_y, _x));
+ } else {
+ edges.add(edge);
+ }
+ }
+ for (Edge edge : edges) {
+
+ Node _x = Edges.getDirectedEdgeTail(edge);
+ Node _y = Edges.getDirectedEdgeHead(edge);
+
+ List hNeighbors = getHNeighbors(_x, _y, graph);
+// List> hSubsets = powerSet(hNeighbors);
+ PowerSet hSubsets= PowerSetFabric.getPowerSet(_x,_y,hNeighbors);
+
+ while(hSubsets.hasMoreElements()) {
+ SubSet hSubset=hSubsets.nextElement();
+ double deleteEval = deleteEval(_x, _y, hSubset, graph);
+ if (!(deleteEval >= 1.0)) deleteEval = 0.0;
+ double evalScore = score + deleteEval;
+
+ //System.out.println("Attempt removing " + _x + "-->" + _y + "(" +evalScore + ") "+ hSubset.toString());
+
+ if (!(evalScore > bestScore)) {
+ continue;
+ }
+
+ // INICIO TEST 1
+ List naYXH = findNaYX(_x, _y, graph);
+ naYXH.removeAll(hSubset);
+ if (!isClique(naYXH, graph)) {
+// hSubsets.firstTest(true); // Si pasa para H entonces pasa para cualquier H' | H' contiene H
+ continue;
+ }
+ // FIN TEST 1
+
+ bestScore = evalScore;
+ x = _x;
+ y = _y;
+ t = hSubset;
+ }
+
+ }
+ if (x != null) {
+ System.out.println(" ");
+ System.out.println("DELETE " + graph.getEdge(x, y) + t.toString() + " (" +bestScore + ")");
+ System.out.println(" ");
+ delete(x, y, t, graph);
+ rebuildPattern(graph);
+ int deletedEdges = 0;
+ for(int g = 0; g *IMPORTANT!* *It assumes all colliders are oriented, as well as
+ * arrows dictated by time order.*
+ *
+ * ELIMINADO BACKGROUND KNOWLEDGE
+ */
+ private void pdag(Graph graph) {
+ MeekRules rules = new MeekRules();
+ rules.setMeekPreventCycles(true);
+ rules.orientImplied(graph);
+ }
+
+ private static List getHNeighbors(Node x, Node y, Graph graph) {
+ List hNeighbors = new LinkedList<>(graph.getAdjacentNodes(y));
+ hNeighbors.retainAll(graph.getAdjacentNodes(x));
+
+ for (int i = hNeighbors.size() - 1; i >= 0; i--) {
+ Node z = hNeighbors.get(i);
+ Edge edge = graph.getEdge(y, z);
+ if (!Edges.isUndirectedEdge(edge)) {
+ hNeighbors.remove(z);
+ }
+ }
+
+ return hNeighbors;
+ }
+
+ private static void delete(Node x, Node y, Set subset, Graph graph) {
+ graph.removeEdges(x, y);
+
+ for (Node aSubset : subset) {
+ if (!graph.isParentOf(aSubset, x) && !graph.isParentOf(x, aSubset)) {
+ graph.removeEdge(x, aSubset);
+ graph.addDirectedEdge(x, aSubset);
+ }
+ graph.removeEdge(y, aSubset);
+ graph.addDirectedEdge(y, aSubset);
+ }
+ }
+
+ private double deleteEval(Node x, Node y, SubSet h, Graph graph){
+
+ Set set1 = new HashSet(findNaYX(x, y, graph));
+ set1.removeAll(h);
+ set1.addAll(graph.getParents(y));
+ set1.remove(x);
+ return scoreGraphChangeDelete(y, x, set1); // calcular si y esta d-separado de x dado el set1 en cada grafo.
+
+ }
+
+ private static List findNaYX(Node x, Node y, Graph graph) {
+ List naYX = new LinkedList<>(graph.getAdjacentNodes(y));
+ naYX.retainAll(graph.getAdjacentNodes(x));
+
+ for (int i = naYX.size()-1; i >= 0; i--) {
+ Node z = naYX.get(i);
+ Edge edge = graph.getEdge(y, z);
+
+ if (!Edges.isUndirectedEdge(edge)) {
+ naYX.remove(z);
+ }
+ }
+
+ return naYX;
+ }
+
+ private double scoreGraphChangeDelete(Node y, Node x, Set set){
+
+ String key = y.getName()+x.getName()+set.toString();
+ Double val = this.localScore.get(key);
+ if(val == null){
+ double eval = 0.0;
+ LinkedList conditioning = new LinkedList<>();
+ conditioning.addAll(set);
+ for(Dag g: this.initialDags){
+ if(!dSeparated(g,y, x, conditioning)) return 0.0;
+ }
+ eval = 1.0; //eval / (double) this.setOfdags.size();
+ val = eval;
+ this.localScore.put(key, val);
+ return eval;
+ }else{
+ return val.doubleValue();
+ }
+ }
+
+ boolean dSeparated(Dag g, Node x, Node y, LinkedList cond){
+
+ LinkedList open = new LinkedList();
+ HashMap close = new HashMap();
+ open.add(x);
+ open.add(y);
+ open.addAll(cond);
+ while (open.size() != 0){
+ Node a = open.getFirst();
+ open.remove(a);
+ close.put(a.toString(),a);
+ List pa =g.getParents(a);
+ for(Node p : pa){
+ if(close.get(p.toString()) == null){
+ if(!open.contains(p)) open.addLast(p);
+ }
+ }
+ }
+
+ Graph aux = new EdgeListGraph();
+
+ for (Node node : g.getNodes()) aux.addNode(node);
+ Node nodeT, nodeH;
+ for (Edge e : g.getEdges()){
+ if(!e.isDirected()) continue;
+ nodeT = e.getNode1();
+ nodeH = e.getNode2();
+ if((close.get(nodeH.toString())!=null)&&(close.get(nodeT.toString())!=null)){
+ Edge newEdge = new Edge(e.getNode1(),e.getNode2(),e.getEndpoint1(),e.getEndpoint2());
+ aux.addEdge(newEdge);
+ }
+ }
+
+ close = new HashMap();
+ for(Edge e: aux.getEdges()){
+ if(e.isDirected()){
+ Node h;
+ if(e.getEndpoint1()==Endpoint.ARROW){
+ h = e.getNode1();
+ }else h = e.getNode2();
+ if(close.get(h.toString())==null){
+ close.put(h.toString(),h);
+ List pa = aux.getParents(h);
+ if(pa.size()>1){
+ for(int i = 0 ; i< pa.size() - 1; i++)
+ for(int j = i+1; j < pa.size(); j++){
+ Node p1 = pa.get(i);
+ Node p2 = pa.get(j);
+ boolean found = false;
+ for(Edge edge : aux.getEdges()){
+ if(edge.getNode1().equals(p1)&&(edge.getNode2().equals(p2))){
+ found = true;
+ break;
+ }
+ if(edge.getNode2().equals(p1)&&(edge.getNode1().equals(p2))){
+ found = true;
+ break;
+ }
+ }
+ if(!found) aux.addUndirectedEdge(p1, p2);
+ }
+ }
+
+ }
+ }
+ }
+
+ for(Edge e: aux.getEdges()){
+ if(e.isDirected()){
+ e.setEndpoint1(Endpoint.TAIL);
+ e.setEndpoint2(Endpoint.TAIL);
+ }
+ }
+
+ aux.removeNodes(cond);
+
+ open = new LinkedList();
+ close = new HashMap();
+ open.add(x);
+ while (open.size() != 0){
+ Node a = open.getFirst();
+ if(a.equals(y)) return false;
+ open.remove(a);
+ close.put(a.toString(),a);
+ List pa =aux.getAdjacentNodes(a);
+ for(Node p : pa){
+ if(close.get(p.toString()) == null){
+ if(!open.contains(p)) open.addLast(p);
+ }
+ }
+ }
+
+ return true;
+ }
+
+
+ public Dag getFusion(){
+
+ return this.outputDag;
+ }
+
+ public List getOrderFusion(){
+ return this.getFusion().paths().getValidOrder(this.getFusion().getNodes(),true);
+ }
+
+
+ private static boolean isClique(List set, Graph graph) {
+ List setv = new LinkedList(set);
+ for (int i = 0; i < setv.size() - 1; i++) {
+ for (int j = i + 1; j < setv.size(); j++) {
+ if (!graph.isAdjacentTo(setv.get(i), setv.get(j))) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+
+
+}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
index 9514eb2..33b5d9d 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
@@ -18,7 +18,6 @@
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag;
-//import experimentosFusion.RandomBN;
@@ -26,27 +25,32 @@ public class ConsensusBES implements Runnable {
ArrayList alpha = null;
Dag outputDag = null;
- AlphaOrder heuristic = null;
- TransformDags imaps2alpha = null;
+ //AlphaOrder heuristic = null;
+ //TransformDags imaps2alpha = null;
+ ConsensusUnion consensusUnion;
ArrayList setOfdags = null;
ArrayList setOfOutDags = null;
Dag union = null;
int numberOfInsertedEdges = 0;
- Map localScore = new HashMap();
+ Map localScore = new HashMap<>();
public ConsensusBES(ArrayList dags){
this.setOfdags = dags;
+ this.consensusUnion = new ConsensusUnion(this.setOfdags);
+
+ /*
this.heuristic = new AlphaOrder(this.setOfdags);
- this.heuristic.computeAlphaH2();
- this.alpha = this.heuristic.alpha;
+ this.heuristic.computeAlpha();
+ this.alpha = this.heuristic.getOrder();
this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha);
this.imaps2alpha.transform();
this.numberOfInsertedEdges = imaps2alpha.getNumberOfInsertedEdges();
- this.setOfOutDags = imaps2alpha.setOfOutputDags;
+ this.setOfOutDags = imaps2alpha.getSetOfOutputDags();
+ */
}
@@ -54,11 +58,14 @@ public int getNumberOfInsertedEdges(){
return this.numberOfInsertedEdges;
}
- private void consensusUnion(){
-
+ public void consensusUnion(){
+ this.union = this.consensusUnion.union();
+ this.setOfOutDags = this.consensusUnion.getTransformedDags();
+
+ /*
this.union = new Dag(this.alpha);
for(Node nodei: this.alpha){
- for(Dag d : this.imaps2alpha.setOfOutputDags){
+ for(Dag d : this.imaps2alpha.getSetOfOutputDags()){
Listparent = d.getParents(nodei);
for(Node pa: parent){
if(!this.union.isParentOf(pa, nodei)){
@@ -68,6 +75,7 @@ private void consensusUnion(){
}
}
+ */
// for(Edge e: this.union.getEdges()){
// for(Dag d : this.imaps2alpha.setOfOutputDags){
// if((d.getEdge(e.getNode1(), e.getNode2())==null) && (d.getEdge(e.getNode2(), e.getNode1())==null))
@@ -77,8 +85,19 @@ private void consensusUnion(){
// }
}
+
+ public Dag getUnion() {
+ return this.union;
+ }
// private methods for searching
+ public void fusion2(){
+ // 1. Apply ConsensusUnion to the set of dags
+ consensusUnion();
+ // 2. Apply Backward Equivalence Search with D-separation
+ BackwardEquivalenceSearchDSep bes = new BackwardEquivalenceSearchDSep(this.union, this.setOfdags, this.setOfOutDags);
+ this.outputDag = bes.applyBackwardEliminationWithDSeparation();
+ }
public void fusion(){
@@ -99,7 +118,7 @@ public void fusion(){
//SearchGraphUtils.dagToPdag(graph);
rebuildPattern(graph);
Node x, y;
- Set t = new HashSet();
+ Set t = new HashSet<>();
do {
x = y = null;
Set edges1 = graph.getEdges();
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
new file mode 100644
index 0000000..bfcdca2
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
@@ -0,0 +1,124 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class ConsensusBESTest {
+ private ArrayList inputDags;
+ private ArrayList alpha;
+
+ @BeforeEach
+ public void setUp() {
+ inputDags = new ArrayList<>();
+
+ // We use 4 nodes for the DAGs
+ Node nodeA = new GraphNode("A");
+ Node nodeB = new GraphNode("B");
+ Node nodeC = new GraphNode("C");
+ Node nodeD = new GraphNode("D");
+
+ // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D
+ Dag dag1 = new Dag();
+ dag1.addNode(nodeA);
+ dag1.addNode(nodeB);
+ dag1.addNode(nodeC);
+ dag1.addNode(nodeD);
+
+ // Adding directed edges to the DAG
+ dag1.addDirectedEdge(nodeA, nodeB);
+ dag1.addDirectedEdge(nodeA, nodeC);
+ dag1.addDirectedEdge(nodeB, nodeD);
+ dag1.addDirectedEdge(nodeC, nodeD);
+
+ // Adding the DAG to the list
+ inputDags.add(dag1);
+
+ // Create second DAG with these edges: D -> C, D -> B, C -> A, B -> A
+ Dag dag2 = new Dag();
+ dag2.addNode(nodeA);
+ dag2.addNode(nodeB);
+ dag2.addNode(nodeC);
+ dag2.addNode(nodeD);
+
+ // Adding directed edges to the second DAG
+ dag2.addDirectedEdge(nodeD, nodeC);
+ dag2.addDirectedEdge(nodeD, nodeB);
+ dag2.addDirectedEdge(nodeC, nodeA);
+ dag2.addDirectedEdge(nodeB, nodeA);
+
+ // Adding the second DAG to the list
+ inputDags.add(dag2);
+
+ // Apply AlphaOrder algorithm to these dags:
+ AlphaOrder alphaOrder = new AlphaOrder(inputDags);
+ alphaOrder.computeAlpha();
+ alpha = alphaOrder.getOrder();
+
+ }
+
+ @Test
+ public void testConsensusUnionConsistency() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ Dag expected = cu.union();
+ assertNotNull(cu);
+ assertNotNull(expected);
+
+ ConsensusBES consensusBES = new ConsensusBES(inputDags);
+ consensusBES.consensusUnion();
+
+ // Check if the union DAG is not null and has nodes
+ Dag unionDag = consensusBES.getUnion();
+ assertNotNull(unionDag);
+ assertNotNull(unionDag.getNodes());
+
+ // Check that the union DAG is equal to the expected DAG
+ assertNotNull(unionDag.getNodes());
+ assertNotNull(expected.getNodes());
+ assertEquals(expected, unionDag);
+ assertEquals(expected.getNodes().size(), unionDag.getNodes().size());
+ assertEquals(expected.getEdges().size(), unionDag.getEdges().size());
+
+ for (Node node : expected.getNodes()) {
+ assert unionDag.getNodes().contains(node);
+ }
+ for(Edge edge : expected.getEdges()) {
+ assert unionDag.getEdges().contains(edge);
+ }
+ }
+
+ @Test
+ public void testConsensusBESConsistency() {
+ ConsensusBES consensusBES1 = new ConsensusBES(inputDags);
+ consensusBES1.fusion();
+ Dag outputDag1 = consensusBES1.getFusion();
+ assertNotNull(outputDag1);
+
+ ConsensusBES consensusBES2 = new ConsensusBES(inputDags);
+ consensusBES2.fusion2();
+ Dag outputDag2 = consensusBES2.getFusion();
+ assertNotNull(outputDag2);
+
+ // Check that both outputs are the same
+ assertNotNull(outputDag1);
+ assertNotNull(outputDag2);
+ assertEquals(outputDag1, outputDag2);
+ assertEquals(outputDag1.getNodes().size(), outputDag2.getNodes().size());
+ assertEquals(outputDag1.getEdges().size(), outputDag2.getEdges().size());
+
+ for (Node node : outputDag1.getNodes()) {
+ assert outputDag2.getNodes().contains(node);
+ }
+ for(Edge edge : outputDag1.getEdges()) {
+ assert outputDag2.getEdges().contains(edge);
+ }
+ }
+}
From d5f2f58ed10b258ee7843cce7817cbfb1c20b45e Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Wed, 16 Jul 2025 13:05:17 +0200
Subject: [PATCH 07/32] Cleaning and testing ConsensusBES
---
.../BackwardEquivalenceSearchDSep.java | 36 +-
.../i3a/simd/consensusBN/ConsensusBES.java | 566 ++++--------------
.../i3a/simd/consensusBN/ConsensusUnion.java | 8 +
.../BackwardEquivalenceSearchDSepTest.java | 112 ++++
.../simd/consensusBN/ConsensusBESTest.java | 111 +++-
5 files changed, 346 insertions(+), 487 deletions(-)
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
index 079ce49..d30b810 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
@@ -329,6 +329,7 @@ boolean dSeparated(Dag g, Node x, Node y, LinkedList cond){
close.put(a.toString(),a);
List pa =aux.getAdjacentNodes(a);
for(Node p : pa){
+ if(p == null) continue;
if(close.get(p.toString()) == null){
if(!open.contains(p)) open.addLast(p);
}
@@ -338,28 +339,23 @@ boolean dSeparated(Dag g, Node x, Node y, LinkedList cond){
return true;
}
-
- public Dag getFusion(){
-
- return this.outputDag;
- }
-
- public List getOrderFusion(){
- return this.getFusion().paths().getValidOrder(this.getFusion().getNodes(),true);
- }
- private static boolean isClique(List set, Graph graph) {
- List setv = new LinkedList(set);
- for (int i = 0; i < setv.size() - 1; i++) {
- for (int j = i + 1; j < setv.size(); j++) {
- if (!graph.isAdjacentTo(setv.get(i), setv.get(j))) {
- return false;
- }
- }
- }
- return true;
- }
+ private static boolean isClique(List set, Graph graph) {
+ List setv = new LinkedList(set);
+ for (int i = 0; i < setv.size() - 1; i++) {
+ for (int j = i + 1; j < setv.size(); j++) {
+ if (!graph.isAdjacentTo(setv.get(i), setv.get(j))) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ public int getNumberOfInsertedEdges() {
+ return this.numberOfInsertedEdges;
+ }
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
index 33b5d9d..d8514ef 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
@@ -2,486 +2,164 @@
import java.util.ArrayList;
import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedList;
import java.util.List;
import java.util.Map;
-import java.util.Set;
import edu.cmu.tetrad.graph.Dag;
-import edu.cmu.tetrad.graph.Edge;
-import edu.cmu.tetrad.graph.EdgeListGraph;
-import edu.cmu.tetrad.graph.Edges;
-import edu.cmu.tetrad.graph.Endpoint;
-import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
-import edu.cmu.tetrad.search.utils.MeekRules;
-import edu.cmu.tetrad.search.utils.GraphSearchUtils;
-import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag;
-
-
+/**
+ * This class implements the Optimal Fusion GES^h_d algorithm, which applies a Consensus Union followed by a Backward Equivalence Search (BES) with D-separation.
+ * The algorithm first computes a consensus DAG from a set of input DAGs using the ConsensusUnion class.
+ * After obtaining the consensus DAG, it applies the Backward Equivalence Search with D-separation to refine the graph, achieving the optimal fusion BN.
+ * The resulting output DAG is stored in the outputDag attribute.
+ */
public class ConsensusBES implements Runnable {
- ArrayList alpha = null;
- Dag outputDag = null;
- //AlphaOrder heuristic = null;
- //TransformDags imaps2alpha = null;
- ConsensusUnion consensusUnion;
- ArrayList setOfdags = null;
- ArrayList setOfOutDags = null;
- Dag union = null;
+ /**
+ * Final output DAG after applying the Consensus Union and Backward Equivalence Search with D-separation.
+ * This DAG represents the optimal fusion of the input DAGs.
+ * It is computed by first merging the input DAGs into a consensus DAG and then refining it using the BES with D-separation.
+ *
+ * @see ConsensusUnion
+ * @see BackwardEquivalenceSearchDSepTest
+ */
+ private Dag outputDag;
+
+ /**
+ * Instance of ConsensusUnion used to compute the consensus DAG from the input DAGs.
+ * This instance is initialized with the set of input DAGs and computes the alpha order of nodes using AlphaOrder heuristic (Greedy Heuristic Order).
+ *
+ * @see ConsensusUnion
+ * @see AlphaOrder
+ */
+ private final ConsensusUnion consensusUnion;
+
+ /**
+ * List of input DAGs to be fused using the ConsensusBES algorithm.
+ */
+ private final ArrayList inputDags;
+
+ /**
+ * List of transformed DAGs after applying the alpha order to the input DAGs.
+ * @see BetaToAlpha
+ * @see TransformDags
+ */
+ private ArrayList transformedDags;
+
+ /**
+ * Resulting DAG afther applying the Consensus Union algorithm.
+ * This DAG contains the union of all edges from the transformed input DAGs, ensuring that the resulting graph is acyclic.
+ * The number of edges inserted during the union process can be retrieved using getNumberOfInsertedEdges.
+ */
+ private Dag union = null;
+
+ /**
+ * Number of edges inserted during the consensus union process and the Backward Equivalence Search process.
+ */
int numberOfInsertedEdges = 0;
- Map localScore = new HashMap<>();
-
-
+ /**
+ * Local score map used to store the scores of graph changes during the Backward Equivalence Search.
+ * The key is a string representation of the nodes and their conditioning set, and the value is the score associated with that configuration.
+ */
+ private final Map localScore = new HashMap<>();
+
+ /**
+ * Constructor for ConsensusBES that initializes the union process with a list of DAGs.
+ * It creates an instance of ConsensusUnion to compute the consensus DAG.
+ * @param dags the list of input DAGs to be merged.
+ */
public ConsensusBES(ArrayList dags){
- this.setOfdags = dags;
- this.consensusUnion = new ConsensusUnion(this.setOfdags);
-
- /*
- this.heuristic = new AlphaOrder(this.setOfdags);
-
- this.heuristic.computeAlpha();
- this.alpha = this.heuristic.getOrder();
- this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha);
-
- this.imaps2alpha.transform();
- this.numberOfInsertedEdges = imaps2alpha.getNumberOfInsertedEdges();
- this.setOfOutDags = imaps2alpha.getSetOfOutputDags();
- */
- }
-
-
- public int getNumberOfInsertedEdges(){
- return this.numberOfInsertedEdges;
+ this.inputDags = dags;
+ this.consensusUnion = new ConsensusUnion(this.inputDags);
}
+ /**
+ * Performs the consensus union operation by calling the union method of the ConsensusUnion instance.
+ * This method initializes the union process, transforming the input DAGs based on the alpha order and merging them into a single consensus DAG.
+ * After the union, it retrieves the transformed DAGs and updates the number of inserted edges.
+ */
public void consensusUnion(){
this.union = this.consensusUnion.union();
- this.setOfOutDags = this.consensusUnion.getTransformedDags();
-
- /*
- this.union = new Dag(this.alpha);
- for(Node nodei: this.alpha){
- for(Dag d : this.imaps2alpha.getSetOfOutputDags()){
- Listparent = d.getParents(nodei);
- for(Node pa: parent){
- if(!this.union.isParentOf(pa, nodei)){
- this.union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW));
- }
- }
- }
-
- }
- */
-// for(Edge e: this.union.getEdges()){
-// for(Dag d : this.imaps2alpha.setOfOutputDags){
-// if((d.getEdge(e.getNode1(), e.getNode2())==null) && (d.getEdge(e.getNode2(), e.getNode1())==null))
-// this.numberOfInsertedEdges++;
-//
-// }
-// }
-
- }
-
- public Dag getUnion() {
- return this.union;
+ this.transformedDags = this.consensusUnion.getTransformedDags();
+ this.numberOfInsertedEdges += consensusUnion.getNumberOfInsertedEdges();
}
- // private methods for searching
- public void fusion2(){
+ /**
+ * Applies the fusion process by first performing the consensus union and then applying the Backward Equivalence Search with D-separation.
+ * This method modifies the outputDag attribute to contain the final fused DAG after applying both steps.
+ */
+ public void fusion(){
// 1. Apply ConsensusUnion to the set of dags
consensusUnion();
// 2. Apply Backward Equivalence Search with D-separation
- BackwardEquivalenceSearchDSep bes = new BackwardEquivalenceSearchDSep(this.union, this.setOfdags, this.setOfOutDags);
+ BackwardEquivalenceSearchDSep bes = new BackwardEquivalenceSearchDSep(this.union, this.inputDags, this.transformedDags);
this.outputDag = bes.applyBackwardEliminationWithDSeparation();
+ // 3. Updating numberOfInsertedEdges
+ this.numberOfInsertedEdges += bes.getNumberOfInsertedEdges();
}
-
- public void fusion(){
-
- // System.out.println("\n** BACKWARD ELIMINATION SEARCH (BES)");
- //PowerSetFabric.setMode(PowerSetFabric.MODE_BES);
- double score = 0;
- double bestScore = score;
- Graph graph = null;
-
- consensusUnion();
-
- graph = new EdgeListGraph(new LinkedList<>(this.union.getNodes()));
- for(Edge e: this.union.getEdges()){
- graph.addEdge(e);
- }
-
- //SearchGraphUtils.dagToPdag(graph);
- rebuildPattern(graph);
- Node x, y;
- Set t = new HashSet<>();
- do {
- x = y = null;
- Set edges1 = graph.getEdges();
- List edges = new ArrayList();
-
- for (Edge edge : edges1) {
- Node _x = edge.getNode1();
- Node _y = edge.getNode2();
-
- if (Edges.isUndirectedEdge(edge)) {
- edges.add(Edges.directedEdge(_x, _y));
- edges.add(Edges.directedEdge(_y, _x));
- } else {
- edges.add(edge);
- }
- }
- for (Edge edge : edges) {
-
- Node _x = Edges.getDirectedEdgeTail(edge);
- Node _y = Edges.getDirectedEdgeHead(edge);
-
- List hNeighbors = getHNeighbors(_x, _y, graph);
-// List> hSubsets = powerSet(hNeighbors);
- PowerSet hSubsets= PowerSetFabric.getPowerSet(_x,_y,hNeighbors);
-
- while(hSubsets.hasMoreElements()) {
- SubSet hSubset=hSubsets.nextElement();
- double deleteEval = deleteEval(_x, _y, hSubset, graph);
- if (!(deleteEval >= 1.0)) deleteEval = 0.0;
- double evalScore = score + deleteEval;
-
- //System.out.println("Attempt removing " + _x + "-->" + _y + "(" +evalScore + ") "+ hSubset.toString());
-
- if (!(evalScore > bestScore)) {
- continue;
- }
-
- // INICIO TEST 1
- List naYXH = findNaYX(_x, _y, graph);
- naYXH.removeAll(hSubset);
- if (!isClique(naYXH, graph)) {
-// hSubsets.firstTest(true); // Si pasa para H entonces pasa para cualquier H' | H' contiene H
- continue;
- }
- // FIN TEST 1
-
- bestScore = evalScore;
- x = _x;
- y = _y;
- t = hSubset;
- }
-
- }
- if (x != null) {
- System.out.println(" ");
- System.out.println("DELETE " + graph.getEdge(x, y) + t.toString() + " (" +bestScore + ")");
- System.out.println(" ");
- delete(x, y, t, graph);
- rebuildPattern(graph);
- int deletedEdges = 0;
- for(int g = 0; g subset, Graph graph) {
- graph.removeEdges(x, y);
-
- for (Node aSubset : subset) {
- if (!graph.isParentOf(aSubset, x) && !graph.isParentOf(x, aSubset)) {
- graph.removeEdge(x, aSubset);
- graph.addDirectedEdge(x, aSubset);
- }
- graph.removeEdge(y, aSubset);
- graph.addDirectedEdge(y, aSubset);
- }
- }
-
-
- private void rebuildPattern(Graph graph) {
- GraphSearchUtils.basicCpdag(graph);
- pdag(graph);
- }
-
- /**
- * Fully direct a graph with background knowledge. I am not sure how to
- * adapt Chickering's suggested algorithm above (dagToPdag) to incorporate
- * background knowledge, so I am also implementing this algorithm based on
- * Meek's 1995 UAI paper. Notice it is the same implemented in PcSearch.
- *
*IMPORTANT!* *It assumes all colliders are oriented, as well as
- * arrows dictated by time order.*
- *
- * ELIMINADO BACKGROUND KNOWLEDGE
- */
- private void pdag(Graph graph) {
- MeekRules rules = new MeekRules();
- rules.setMeekPreventCycles(true);
- rules.orientImplied(graph);
- }
-
-
- private static boolean isClique(List set, Graph graph) {
- List setv = new LinkedList(set);
- for (int i = 0; i < setv.size() - 1; i++) {
- for (int j = i + 1; j < setv.size(); j++) {
- if (!graph.isAdjacentTo(setv.get(i), setv.get(j))) {
- return false;
- }
- }
- }
- return true;
- }
-
- private static List getHNeighbors(Node x, Node y, Graph graph) {
- List hNeighbors = new LinkedList(graph.getAdjacentNodes(y));
- hNeighbors.retainAll(graph.getAdjacentNodes(x));
-
- for (int i = hNeighbors.size() - 1; i >= 0; i--) {
- Node z = hNeighbors.get(i);
- Edge edge = graph.getEdge(y, z);
- if (!Edges.isUndirectedEdge(edge)) {
- hNeighbors.remove(z);
- }
- }
-
- return hNeighbors;
- }
-
-
- double deleteEval(Node x, Node y, SubSet h, Graph graph){
-
- Set set1 = new HashSet(findNaYX(x, y, graph));
- set1.removeAll(h);
- set1.addAll(graph.getParents(y));
- set1.remove(x);
- return scoreGraphChangeDelete(y, x, set1); // calcular si y esta d-separado de x dado el set1 en cada grafo.
-
- }
-
- double scoreGraphChangeDelete(Node y, Node x, Set set){
-
- String key = y.getName()+x.getName()+set.toString();
- Double val = this.localScore.get(key);
- if(val == null){
- double eval = 0.0;
- LinkedList conditioning = new LinkedList();
- conditioning.addAll(set);
- for(Dag g: this.setOfdags){
- if(!dSeparated(g,y, x, conditioning)) return 0.0;
- }
- eval = 1.0; //eval / (double) this.setOfdags.size();
- val = eval;
- this.localScore.put(key, val);
- return eval;
- }else{
- return val.doubleValue();
- }
- }
-
-
- boolean dSeparated(Dag g, Node x, Node y, LinkedList cond){
-
- LinkedList open = new LinkedList();
- HashMap close = new HashMap();
- open.add(x);
- open.add(y);
- open.addAll(cond);
- while (open.size() != 0){
- Node a = open.getFirst();
- open.remove(a);
- close.put(a.toString(),a);
- List pa =g.getParents(a);
- for(Node p : pa){
- if(close.get(p.toString()) == null){
- if(!open.contains(p)) open.addLast(p);
- }
- }
- }
-
- Graph aux = new EdgeListGraph();
-
- for (Node node : g.getNodes()) aux.addNode(node);
- Node nodeT, nodeH;
- for (Edge e : g.getEdges()){
- if(!e.isDirected()) continue;
- nodeT = e.getNode1();
- nodeH = e.getNode2();
- if((close.get(nodeH.toString())!=null)&&(close.get(nodeT.toString())!=null)){
- Edge newEdge = new Edge(e.getNode1(),e.getNode2(),e.getEndpoint1(),e.getEndpoint2());
- aux.addEdge(newEdge);
- }
- }
-
- close = new HashMap();
- for(Edge e: aux.getEdges()){
- if(e.isDirected()){
- Node h;
- if(e.getEndpoint1()==Endpoint.ARROW){
- h = e.getNode1();
- }else h = e.getNode2();
- if(close.get(h.toString())==null){
- close.put(h.toString(),h);
- List pa = aux.getParents(h);
- if(pa.size()>1){
- for(int i = 0 ; i< pa.size() - 1; i++)
- for(int j = i+1; j < pa.size(); j++){
- Node p1 = pa.get(i);
- Node p2 = pa.get(j);
- boolean found = false;
- for(Edge edge : aux.getEdges()){
- if(edge.getNode1().equals(p1)&&(edge.getNode2().equals(p2))){
- found = true;
- break;
- }
- if(edge.getNode2().equals(p1)&&(edge.getNode1().equals(p2))){
- found = true;
- break;
- }
- }
- if(!found) aux.addUndirectedEdge(p1, p2);
- }
- }
-
- }
- }
- }
-
- for(Edge e: aux.getEdges()){
- if(e.isDirected()){
- e.setEndpoint1(Endpoint.TAIL);
- e.setEndpoint2(Endpoint.TAIL);
- }
- }
-
- aux.removeNodes(cond);
-
- open = new LinkedList();
- close = new HashMap();
- open.add(x);
- while (open.size() != 0){
- Node a = open.getFirst();
- if(a.equals(y)) return false;
- open.remove(a);
- close.put(a.toString(),a);
- List pa =aux.getAdjacentNodes(a);
- for(Node p : pa){
- if(close.get(p.toString()) == null){
- if(!open.contains(p)) open.addLast(p);
- }
- }
- }
-
- return true;
- }
-
-
-
-
- private static List findNaYX(Node x, Node y, Graph graph) {
- List naYX = new LinkedList(graph.getAdjacentNodes(y));
- naYX.retainAll(graph.getAdjacentNodes(x));
-
- for (int i = naYX.size()-1; i >= 0; i--) {
- Node z = naYX.get(i);
- Edge edge = graph.getEdge(y, z);
-
- if (!Edges.isUndirectedEdge(edge)) {
- naYX.remove(z);
- }
- }
-
- return naYX;
- }
-
+ /**
+ * Returns the output DAG after applying the Consensus Union and Backward Equivalence Search with D-separation.
+ * This method retrieves the final fused DAG, which represents the optimal fusion of the input DAGs.
+ * @return the resulting output DAG after the fusion process.
+ */
public Dag getFusion(){
-
return this.outputDag;
}
+ /**
+ * Returns a valid ancestral order of the nodes in the fused DAG.
+ * @return
+ */
public List getOrderFusion(){
return this.getFusion().paths().getValidOrder(this.getFusion().getNodes(),true);
}
-
- public static void main(String args[]) {
-
-
- System.out.println("Grafos de Partida: ");
-
- // (seed, n. variables, n egdes max, n.dags, mutation(n. de operaciones))
- RandomBN setOfBNs = new RandomBN(0, Integer.parseInt(args[0]), Integer.parseInt(args[1]),
- Integer.parseInt(args[2]), Integer.parseInt(args[3]));
- setOfBNs.setMaxInDegree(4);
- setOfBNs.setMaxOutDegree(4);
- setOfBNs.generate();
+ /**
+ * Returns the number of edges inserted during the consensus union and removed in the Backward Equivalence Search with D-separation.
+ * @return
+ */
+ public int getNumberOfInsertedEdges(){
+ return this.numberOfInsertedEdges;
+ }
- for (int i = 0; i < setOfBNs.setOfRandomBNs.size(); i++) {
- System.out.println("red de partida: " + i);
- System.out.println("---------------------");
- System.out.println("Grafo: ");
- System.out.println(setOfBNs.setOfRandomDags.get(i).toString());
-// System.out.println("Probabilidades: ");
-// System.out.println(setOfBNs.setOfRandomBNs.get(i).toString());
-// System.out.println("_____________________");
-// System.out.println("Datos Simulados");
-// System.out.println(setOfBNs.setOfSampledBNs.get(i).toString());
+ /**
+ * Returns the union DAG resulting from the consensus union process.
+ * @return the union DAG after merging the transformed input DAGs.
+ */
+ public Dag getUnion() {
+ return this.union;
+ }
-//
-// }
-// //
- ConsensusBES conDag= null;
-//
- conDag = new ConsensusBES(setOfBNs.setOfRandomDags);
- conDag.fusion();
- Dag g = conDag.getFusion();
- System.out.println("grafo consenso: "+ g +" Complejidad de la Fusion: "+ conDag.getNumberOfInsertedEdges()
- + " "+ conDag.union.getNumEdges());
- System.out.println("Orden Inicial Heu: "+conDag.alpha.toString());
- System.out.println("Orden de consenso: "+conDag.getOrderFusion().toString());
-//
-//// HierarchicalAgglomerativeClustererBNs Cfusion = new HierarchicalAgglomerativeClustererBNs(setOfBNs.setOfRandomDags,0.50);
-//// int l = Cfusion.cluster();
-//// System.out.println("Nivel de Fusion: "+l);
-//// System.out.println(Cfusion.computeConsensusDag(l).toString());
-// }
-//
+ /**
+ * Returns the ConsensusUnion instance used in this ConsensusBES.
+ * This instance contains the logic for merging the input DAGs and computing the alpha order.
+ * @return the ConsensusUnion instance associated with this ConsensusBES.
+ */
+ public ConsensusUnion getConsensusUnion() {
+ return this.consensusUnion;
+ }
+
+ /**
+ * Returns the list of transformed DAGs after applying the alpha order to the input DAGs.
+ * This method retrieves the transformed DAGs that were used in the consensus union process.
+ * @return the list of transformed DAGs.
+ */
+ public ArrayList getTransformedDags() {
+ if (this.transformedDags != null) {
+ return this.transformedDags;
+ } else {
+ throw new IllegalStateException("Transformed DAGs have not been initialized. Please call fusion() first.");
}
}
-
+
+ /**
+ * Runs the ConsensusBES algorithm in a thread, performing the consensus union and the Backward Equivalence Search with D-separation.
+ */
@Override
public void run() {
-
+ this.fusion();
}
}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java
index ba63e1b..c0dd4ae 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java
@@ -162,6 +162,14 @@ void setDags(ArrayList dags){
public void run() {
this.union = this.union();
}
+
+ public ArrayList getTransformedDags() {
+ if (this.imaps2alpha != null) {
+ return this.imaps2alpha.getSetOfOutputDags();
+ } else {
+ throw new IllegalStateException("TransformDags has not been initialized. Please call union() first.");
+ }
+ }
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
new file mode 100644
index 0000000..7078928
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
@@ -0,0 +1,112 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.GraphUtils;
+import edu.cmu.tetrad.graph.Node;
+
+class BackwardEquivalenceSearchDSepTest {
+
+ private Dag createSimpleDag() {
+ // A -> B -> C
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+ Node c = new GraphNode("C");
+
+ Dag dag = new Dag();
+ dag.addNode(a);
+ dag.addNode(b);
+ dag.addNode(c);
+
+ dag.addDirectedEdge(a, b);
+ dag.addDirectedEdge(b, c);
+
+ return dag;
+ }
+
+ private ArrayList createDagList(int copies) {
+ ArrayList list = new ArrayList<>();
+ for (int i = 0; i < copies; i++) {
+ list.add(createSimpleDag());
+ }
+ return list;
+ }
+
+ @Test
+ void testApplyBESdDoesNotThrow() {
+ Dag unionDag = createSimpleDag();
+ ArrayList initialDags = createDagList(3);
+ ArrayList transformedDags = createDagList(3);
+
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
+
+ assertDoesNotThrow(() -> {
+ Dag output = besd.applyBackwardEliminationWithDSeparation();
+ assertNotNull(output);
+ });
+ }
+
+ @Test
+ void testOutputIsDAG() {
+ Dag unionDag = createSimpleDag();
+ ArrayList initialDags = createDagList(2);
+ ArrayList transformedDags = createDagList(2);
+
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
+ Dag outputDag = besd.applyBackwardEliminationWithDSeparation();
+
+ assertTrue(GraphUtils.isDag(outputDag), "El resultado no es un DAG válido.");
+ }
+
+ @Test
+ void testAristasSePuedenEliminar() {
+ // Creamos un grafo donde A -> B, pero en todos los DAGs está A B (no conectados)
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+
+ Dag unionDag = new Dag();
+ unionDag.addNode(a);
+ unionDag.addNode(b);
+ unionDag.addDirectedEdge(a, b);
+
+ // DAGs originales sin esa arista
+ Dag dag1 = new Dag();
+ dag1.addNode(a);
+ dag1.addNode(b);
+ // sin conexión
+
+ ArrayList initialDags = new ArrayList<>();
+ initialDags.add(dag1);
+ ArrayList transformedDags = new ArrayList<>();
+ transformedDags.add(dag1);
+
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
+ Dag outputDag = besd.applyBackwardEliminationWithDSeparation();
+
+ // Debe eliminar la arista A -> B por no tener soporte
+ Edge deletedEdge = outputDag.getEdge(a, b);
+ assertNull(deletedEdge, "La arista A -> B debería haberse eliminado.");
+ }
+
+ @Test
+ void testGetNumberOfInsertedEdgesReflectsChanges() {
+ Dag unionDag = createSimpleDag();
+ ArrayList dags = createDagList(2);
+
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, dags, dags);
+ besd.applyBackwardEliminationWithDSeparation();
+
+ int insertedEdges = besd.getNumberOfInsertedEdges();
+ // En el peor de los casos no ha eliminado ninguna, pero nunca debe ser negativo
+ assertTrue(insertedEdges >= 0, "El número de aristas insertadas no puede ser negativo.");
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
index bfcdca2..da9a321 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
@@ -1,9 +1,14 @@
package es.uclm.i3a.simd.consensusBN;
import java.util.ArrayList;
+import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -96,29 +101,89 @@ public void testConsensusUnionConsistency() {
}
@Test
- public void testConsensusBESConsistency() {
- ConsensusBES consensusBES1 = new ConsensusBES(inputDags);
- consensusBES1.fusion();
- Dag outputDag1 = consensusBES1.getFusion();
- assertNotNull(outputDag1);
-
- ConsensusBES consensusBES2 = new ConsensusBES(inputDags);
- consensusBES2.fusion2();
- Dag outputDag2 = consensusBES2.getFusion();
- assertNotNull(outputDag2);
-
- // Check that both outputs are the same
- assertNotNull(outputDag1);
- assertNotNull(outputDag2);
- assertEquals(outputDag1, outputDag2);
- assertEquals(outputDag1.getNodes().size(), outputDag2.getNodes().size());
- assertEquals(outputDag1.getEdges().size(), outputDag2.getEdges().size());
-
- for (Node node : outputDag1.getNodes()) {
- assert outputDag2.getNodes().contains(node);
- }
- for(Edge edge : outputDag1.getEdges()) {
- assert outputDag2.getEdges().contains(edge);
+ public void testRandomBNFusion(){
+ // (seed, n. variables, n egdes max, n.dags, mutation(n. de operaciones))
+ RandomBN setOfDags = new RandomBN(0, 20, 50,
+ 4,3);
+ setOfDags.setMaxInDegree(4);
+ setOfDags.setMaxOutDegree(4);
+ setOfDags.generate();
+
+ ConsensusBES conDag = new ConsensusBES(setOfDags.setOfRandomDags);
+ conDag.fusion();
+ Dag besDag = conDag.getFusion();
+ Dag unionDag = conDag.getUnion();
+ ConsensusUnion consensusUnion = conDag.getConsensusUnion();
+ int totalNumberOfInsertedEdges = conDag.getNumberOfInsertedEdges();
+ int consensusNumberOfInsertedEdges = consensusUnion.getNumberOfInsertedEdges();
+
+ assertNotNull(besDag);
+ assertNotNull(unionDag);
+ assertNotNull(consensusUnion);
+ assertEquals(besDag.getNodes().size(), unionDag.getNodes().size());
+ assert consensusNumberOfInsertedEdges >= 0;
+ assert consensusNumberOfInsertedEdges >= totalNumberOfInsertedEdges;
+ }
+
+
+ @Test
+ void testFusionProducesDag() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ fusionAlgorithm.fusion();
+
+ Dag result = fusionAlgorithm.getFusion();
+ assertNotNull(result, "El DAG de salida no debe ser null.");
+ assertFalse(result.paths().existsDirectedCycle(), "El DAG resultante no debe tener ciclos.");
+ }
+
+ @Test
+ void testEdgeInsertionCountIsCorrectlyComputed() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ fusionAlgorithm.fusion();
+
+ int insertedEdges = fusionAlgorithm.getNumberOfInsertedEdges();
+ assertTrue(insertedEdges >= 0, "El número de aristas insertadas debe ser >= 0.");
+ }
+
+ @Test
+ void testFusionOrderIsValid() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ fusionAlgorithm.fusion();
+
+ List order = fusionAlgorithm.getOrderFusion();
+ assertNotNull(order, "El orden de fusión no debe ser null.");
+ assertEquals(4, order.size(), "El orden de fusión debe tener 3 nodos.");
+ }
+
+ @Test
+ void testTransformedDagsAreAccessibleAfterFusion() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ fusionAlgorithm.fusion();
+
+ ArrayList transformed = fusionAlgorithm.getTransformedDags();
+ assertEquals(2, transformed.size(), "Debe haber 2 DAGs transformados.");
+ }
+
+ @Test
+ void testGetTransformedDagsWithoutFusionThrowsException() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+
+ assertThrows(IllegalStateException.class, fusionAlgorithm::getTransformedDags,
+ "Debe lanzar una excepción si se accede a los DAGs transformados sin llamar a fusion().");
+ }
+
+ @Test
+ void testThreadExecutionWithRunMethod() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ Thread thread = new Thread(fusionAlgorithm);
+ thread.start();
+ try {
+ thread.join();
+ } catch (InterruptedException e) {
+ fail("El hilo fue interrumpido.");
}
+
+ assertNotNull(fusionAlgorithm.getFusion(), "El DAG resultante debe existir tras ejecutar run().");
}
+
}
From 48868e83da4a366e2650c75f24b2a1b612f166f2 Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Wed, 16 Jul 2025 14:18:40 +0200
Subject: [PATCH 08/32] Cleaning applyBackwardEliminationWithDSeparation in
BackwardEquivalenceSearchDSep
---
.../BackwardEquivalenceSearchDSep.java | 242 ++++++++++++------
1 file changed, 165 insertions(+), 77 deletions(-)
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
index d30b810..e122886 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
@@ -39,91 +39,75 @@ public BackwardEquivalenceSearchDSep(Dag union, ArrayListinitialDags, Array
}
public Dag applyBackwardEliminationWithDSeparation(){
- // Implement the BESd algorithm logic here
- // This is a placeholder for the actual BESd algorithm implementation
- // The algorithm should modify the graph based on the BESd logic
- rebuildPattern(graph);
- Node x, y;
- Set t = new HashSet<>();
- double score = 0;
- double bestScore = score;
+ double score = 0;
+ EdgeCandidate bestCandidate;
+
+ // Creating a pdag from the graph
+ rebuildPattern(graph);
+
+ // While there are edges to delete, search for the best edge to delete
do {
- x = y = null;
- Set edges1 = graph.getEdges();
- List edges = new ArrayList<>();
-
- for (Edge edge : edges1) {
- Node _x = edge.getNode1();
- Node _y = edge.getNode2();
-
- if (Edges.isUndirectedEdge(edge)) {
- edges.add(Edges.directedEdge(_x, _y));
- edges.add(Edges.directedEdge(_y, _x));
- } else {
- edges.add(edge);
- }
- }
- for (Edge edge : edges) {
-
- Node _x = Edges.getDirectedEdgeTail(edge);
- Node _y = Edges.getDirectedEdgeHead(edge);
+ // Make sure that any undirected edge is transformed into two directed edges
+ List edges = cleanUndirectedEdges();
+
+ // Find the best edge to delete
+ bestCandidate = calculateBestCandidateEdge(edges, score);
+ /* for (Edge edge : edges) {
+ // Getting candidate edge to delete
+ Node candidateTail = Edges.getDirectedEdgeTail(edge);
+ Node candidateHead = Edges.getDirectedEdgeHead(edge);
- List hNeighbors = getHNeighbors(_x, _y, graph);
-// List> hSubsets = powerSet(hNeighbors);
- PowerSet hSubsets= PowerSetFabric.getPowerSet(_x,_y,hNeighbors);
+ List hNeighbors = getHNeighbors(candidateTail, candidateHead, graph);
+ PowerSet hSubsets= PowerSetFabric.getPowerSet(candidateTail,candidateHead,hNeighbors);
while(hSubsets.hasMoreElements()) {
SubSet hSubset=hSubsets.nextElement();
- double deleteEval = deleteEval(_x, _y, hSubset, graph);
- if (!(deleteEval >= 1.0)) deleteEval = 0.0;
- double evalScore = score + deleteEval;
- //System.out.println("Attempt removing " + _x + "-->" + _y + "(" +evalScore + ") "+ hSubset.toString());
-
- if (!(evalScore > bestScore)) {
- continue;
- }
-
- // INICIO TEST 1
- List naYXH = findNaYX(_x, _y, graph);
+ // Checking if {naYXH} \ {hSubset} is a clique
+ List naYXH = findNaYX(candidateTail, candidateHead, graph);
naYXH.removeAll(hSubset);
if (!isClique(naYXH, graph)) {
-// hSubsets.firstTest(true); // Si pasa para H entonces pasa para cualquier H' | H' contiene H
+ continue;
+ }
+
+ // Calculating the score of the candidate edge deletion
+ double deleteEval = deleteEval(candidateTail, candidateHead, hSubset, graph);
+
+ // Setting limit for deleteEval
+ if (!(deleteEval >= 1.0)) deleteEval = 0.0;
+
+ // If the score is not better than the best score, continue
+ double evalScore = score + deleteEval;
+ if (!(evalScore > bestScore)) {
continue;
}
- // FIN TEST 1
+ // Updating variables for the best edge deletion
bestScore = evalScore;
- x = _x;
- y = _y;
- t = hSubset;
+ bestTail = candidateTail;
+ bestHead = candidateHead;
+ bestSetParents = hSubset;
}
+ } */
+ //
+ if (bestCandidate != null) {
+ score = executeEdgeDeletion(bestCandidate);
}
- if (x != null) {
- System.out.println(" ");
- System.out.println("DELETE " + graph.getEdge(x, y) + t.toString() + " (" +bestScore + ")");
- System.out.println(" ");
- delete(x, y, t, graph);
- rebuildPattern(graph);
- int deletedEdges = 0;
- for(int g = 0; g cleanUndirectedEdges() {
+ Set edges1 = graph.getEdges();
+ List edges = new ArrayList<>();
+
+ for (Edge edge : edges1) {
+ Node _x = edge.getNode1();
+ Node _y = edge.getNode2();
+
+ if (Edges.isUndirectedEdge(edge)) {
+ edges.add(Edges.directedEdge(_x, _y));
+ edges.add(Edges.directedEdge(_y, _x));
+ } else {
+ edges.add(edge);
+ }
+ }
+ return edges;
+ }
+
+ private EdgeCandidate calculateBestCandidateEdge(List edges, double score){
+ double bestScore = score;
+ EdgeCandidate bestCandidate = null;
+ for(Edge edge : edges){
+ // Getting candidate edge to delete
+ Node candidateTail = Edges.getDirectedEdgeTail(edge);
+ Node candidateHead = Edges.getDirectedEdgeHead(edge);
+
+ List hNeighbors = getHNeighbors(candidateTail, candidateHead, graph);
+ PowerSet hSubsets= PowerSetFabric.getPowerSet(candidateTail,candidateHead,hNeighbors);
+
+ while(hSubsets.hasMoreElements()) {
+ // Getting a subset of hNeighbors
+ SubSet hSubset=hSubsets.nextElement();
+
+ // Checking if {naYXH} \ {hSubset} is a clique
+ List naYXH = findNaYX(candidateTail, candidateHead, graph);
+ naYXH.removeAll(hSubset);
+ if (!isClique(naYXH, graph)) {
+ continue;
+ }
+
+ // Calculating the score of the candidate edge deletion
+ double deleteEval = deleteEval(candidateTail, candidateHead, hSubset, graph);
+
+ // Setting limit for deleteEval
+ if (!(deleteEval >= 1.0)) deleteEval = 0.0;
+
+ // If the score is not better than the best score, continue
+ double evalScore = score + deleteEval;
+ if (!(evalScore > bestScore)) {
+ continue;
+ }
+
+ // Updating best candidate edge
+ bestCandidate = new EdgeCandidate(candidateTail, candidateHead, hSubset);
+ bestCandidate.score = evalScore;
+
+ // Updating score for the best edge deletion
+ bestScore = evalScore;
+ }
+ }
+ return bestCandidate;
+ }
+
+ private double executeEdgeDeletion(EdgeCandidate bestCandidate) {
+ Node bestTail;
+ Node bestHead;
+ Set bestSetParents;
+ double score;
+ double bestScore;
+ bestTail = bestCandidate.tail;
+ bestHead = bestCandidate.head;
+ bestSetParents = bestCandidate.conditioningSet;
+ bestScore = bestCandidate.score;
+
+ // Applying delete
+ System.out.println(" ");
+ System.out.println("DELETE " + graph.getEdge(bestTail, bestHead) + bestSetParents.toString() + " (" +bestScore + ")");
+ System.out.println(" ");
+ delete(bestTail, bestHead, bestSetParents, graph);
+
+ // Rebuilding the pattern after deleting the edge
+ rebuildPattern(graph);
+
+ // Updating the number of inserted edges
+ int deletedEdges = 0;
+ for(int g = 0; g conditioningSet;
+ public double score;
+ public EdgeCandidate(Node tail, Node head, Set conditioningSet) {
+ this.tail = tail;
+ this.head = head;
+ this.conditioningSet = conditioningSet;
+ }
+ }
+
}
From 3d8b651f7874b3c9403f34a6e335896847560f2c Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Thu, 17 Jul 2025 14:56:20 +0200
Subject: [PATCH 09/32] Cleaning and testing d-separation method and BESd
algorithm
---
.../BackwardEquivalenceSearchDSep.java | 267 +++++++++---------
.../es/uclm/i3a/simd/consensusBN/Utils.java | 109 +++++++
.../uclm/i3a/simd/consensusBN/UtilsTest.java | 102 +++++++
3 files changed, 348 insertions(+), 130 deletions(-)
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
index e122886..5331c50 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
@@ -1,6 +1,8 @@
package es.uclm.i3a.simd.consensusBN;
+import java.util.ArrayDeque;
import java.util.ArrayList;
+import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
@@ -19,16 +21,76 @@
import edu.cmu.tetrad.search.utils.MeekRules;
import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag;
+/**
+ * This class implements the Backward Equivalence Search with D-Separation
+ * algorithm for consensus Bayesian networks. It uses an implementation of
+ * second phase of the Greedy Equivalence Search (GES) algorithm, the Backward
+ * Equivalence Search (BES), to refine a consensus DAG by removing edges while
+ * ensuring that the resulting graph remains a Directed Acyclic Graph (DAG).
+ * Since no data is available, the algorithm relies on D-separation to
+ * determine whether two nodes are conditionally independent given a set of
+ * other nodes. For this, the algorithm uses the set of input DAGs to check
+ * whether the deletion of an edge maintains the d-separation condition.
+ */
public class BackwardEquivalenceSearchDSep {
-
+ /**
+ * The graph representing the consensus DAG after applying the Backward
+ * Equivalence Search with D-separation.
+ * This graph is built from the union of the transformed input DAGs and is
+ * refined by removing edges based on d-separation checks.
+ *
+ * @see ConsensusUnion
+ * @see TransformDags
+ */
private final Graph graph;
+
+ /**
+ * List of initial DAGs used to check how many edges are deleted.
+ */
private final ArrayList transformedDags;
+
+ /**
+ * List of initial DAGs used to check the d-separation condition.
+ * This list is used to verify whether the deletion of an edge maintains the
+ * d-separation condition across all input DAGs.
+ *
+ * @see Utils#dSeparated(Dag, Node, Node, List)
+ */
private final ArrayList initialDags;
+
+ /**
+ * The output DAG after applying the Backward Equivalence Search with D-separation.
+ * This DAG is the final result after removing edges from the consensus DAG
+ * while ensuring that the d-separation condition is maintained using the input DAGs.
+ *
+ * @see Utils#dSeparated(Dag, Node, Node, List)
+ */
private Dag outputDag;
+
+ /**
+ * A map to store the local scores for edge deletions.
+ * This map is used to cache the scores of edge deletions to avoid redundant calculations.
+ * The key is a string representation of the edge and its conditioning set, and the value is the score.
+ */
private final Map localScore = new HashMap<>();
- private int numberOfInsertedEdges = 0;
+ /**
+ * Number of edges inserted during the consensus union and backward equivalence search process.
+ * This variable keeps track of the total number of edges that were added to the consensus DAG
+ * during the union of transformed input DAGs and the subsequent edge deletions.
+ *
+ * @see ConsensusUnion#getNumberOfInsertedEdges()
+ * @see BackwardEquivalenceSearchDSep#applyBackwardEliminationWithDSeparation()
+ */
+ private int numberOfInsertedEdges = 0;
+ /**
+ * Constructor for BackwardEquivalenceSearchDSep that initializes the properties for the search with a union DAG and lists of initial and transformed DAGs.
+ *
+ * @param union The resulting union DAG from the ConsensusUnion process.
+ * @param initialDags List of initial DAGs used to check the d-separation condition.
+ * @param transformedDags List of transformed DAGs after applying the alpha order.
+ */
public BackwardEquivalenceSearchDSep(Dag union, ArrayListinitialDags, ArrayList transformedDags) {
this.graph = new EdgeListGraph(new LinkedList<>(union.getNodes()));
for (Edge edge : union.getEdges()) {
@@ -38,6 +100,12 @@ public BackwardEquivalenceSearchDSep(Dag union, ArrayListinitialDags, Array
this.transformedDags = transformedDags;
}
+ /**
+ * Applies the Backward Equivalence Search with D-separation to the consensus DAG.
+ * This method iteratively removes edges from the consensus DAG while ensuring that the d-separation condition is maintained across all input DAGs.
+ * It returns the final output DAG after all possible edge deletions.
+ * @return The output DAG after applying the Backward Equivalence Search with D-separation.
+ */
public Dag applyBackwardEliminationWithDSeparation(){
double score = 0;
EdgeCandidate bestCandidate;
@@ -103,35 +171,24 @@ public Dag applyBackwardEliminationWithDSeparation(){
return outputDag;
}
- private void createOutputDag() {
- // Rebuild the pattern to ensure the final graph is a DAG
- pdagToDag(graph);
-
- // Rebuild the output DAG from the final graph
- this.outputDag = new Dag();
- for (Node node : graph.getNodes()) this.outputDag.addNode(node);
- Node nodeT, nodeH;
- for (Edge e : graph.getEdges()){
- if(!e.isDirected()) continue;
- Endpoint endpoint1 = e.getEndpoint1();
- if (endpoint1.equals(Endpoint.ARROW)){
- nodeT = e.getNode1();
- nodeH = e.getNode2();
- }else{
- nodeT = e.getNode2();
- nodeH = e.getNode1();
- }
- if(!this.outputDag.paths().existsDirectedPath(nodeT, nodeH)) this.outputDag.addEdge(e);
- }
- }
-
-
+ /**
+ * Rebuilds the input graph to ensure it is a valid pattern.
+ * This method applies the Meek rules to orient the edges and ensure that the graph is a valid pattern.
+ * It also converts the graph to a PDAG (Partially Directed Acyclic Graph)
+ * @param graph The graph to validate and rebuild as a PDAG.
+ */
private void rebuildPattern(Graph graph) {
GraphSearchUtils.basicCpdag(graph);
pdag(graph);
}
-
+
+ /**
+ * Cleans the undirected edges in the graph by converting them to directed edges.
+ * This method iterates through the edges of the graph and transforms undirected edges into two directed edges,
+ * ensuring that the resulting graph maintains only directed edges.
+ * @return
+ */
private List cleanUndirectedEdges() {
Set edges1 = graph.getEdges();
List edges = new ArrayList<>();
@@ -150,6 +207,14 @@ private List cleanUndirectedEdges() {
return edges;
}
+ /**
+ * Calculates the best candidate edge for deletion based on the current score and the edges available.
+ * This method evaluates each edge and its possible conditioning sets to find the edge that, when deleted,
+ * results in the highest score improvement while maintaining the d-separation condition.
+ * @param edges List of edges to consider for deletion.
+ * @param score The current score before any edge deletion.
+ * @return An EdgeCandidate object representing the best edge to delete, or null if no suitable edge is found.
+ */
private EdgeCandidate calculateBestCandidateEdge(List edges, double score){
double bestScore = score;
EdgeCandidate bestCandidate = null;
@@ -195,6 +260,14 @@ private EdgeCandidate calculateBestCandidateEdge(List edges, double score)
return bestCandidate;
}
+ /**
+ * Executes the deletion of the best candidate edge from the graph.
+ * This method removes the edge from the graph and updates the local score map.
+ * It also rebuilds the pattern after the deletion and updates the number of inserted edges.
+ * @param bestCandidate The best candidate edge to delete, containing the tail, head, conditioning set, and score.
+ * @return The score after the edge deletion is executed.
+ * This score reflects the new state of the graph after the edge has been removed.
+ */
private double executeEdgeDeletion(EdgeCandidate bestCandidate) {
Node bestTail;
Node bestHead;
@@ -227,17 +300,46 @@ private double executeEdgeDeletion(EdgeCandidate bestCandidate) {
return score;
}
+ /**
+ * Creates the output DAG from the final graph after applying the Backward Equivalence Search.
+ * This method ensures that the final graph is a valid DAG by removing any cycles and undirected edges.
+ * It converts the graph from a PDAG to a DAG and rebuilds the output DAG from the final graph.
+ * The output DAG contains all nodes and directed edges, ensuring that it is acyclic.
+ *
+ * @see Utils#pdagToDag(Graph)
+ * @see Dag
+ */
+ private void createOutputDag() {
+ // Rebuild the pattern to ensure the final graph is a DAG
+ pdagToDag(graph);
+
+ // Rebuild the output DAG from the final graph
+ this.outputDag = new Dag();
+ for (Node node : graph.getNodes()) this.outputDag.addNode(node);
+ Node nodeT, nodeH;
+ for (Edge e : graph.getEdges()){
+ if(!e.isDirected()) continue;
+ Endpoint endpoint1 = e.getEndpoint1();
+ if (endpoint1.equals(Endpoint.ARROW)){
+ nodeT = e.getNode1();
+ nodeH = e.getNode2();
+ }else{
+ nodeT = e.getNode2();
+ nodeH = e.getNode1();
+ }
+ if(!this.outputDag.paths().existsDirectedPath(nodeT, nodeH)) this.outputDag.addEdge(e);
+ }
+ }
+
/**
- * Fully direct a graph with background knowledge. I am not sure how to
- * adapt Chickering's suggested algorithm above (dagToPdag) to incorporate
- * background knowledge, so I am also implementing this algorithm based on
- * Meek's 1995 UAI paper. Notice it is the same implemented in PcSearch.
- * *IMPORTANT!* *It assumes all colliders are oriented, as well as
+ * Transforms a dag into a pdag assuming that all colliders are oriented, as well as
* arrows dictated by time order.*
- *
- * ELIMINADO BACKGROUND KNOWLEDGE
+ * @param graph The graph to transform into a PDAG.
+ * @see MeekRules
+ * @see GraphSearchUtils#basicCpdag(Graph)
+ *
*/
private void pdag(Graph graph) {
MeekRules rules = new MeekRules();
@@ -245,6 +347,7 @@ private void pdag(Graph graph) {
rules.orientImplied(graph);
}
+
private static List getHNeighbors(Node x, Node y, Graph graph) {
List hNeighbors = new LinkedList<>(graph.getAdjacentNodes(y));
hNeighbors.retainAll(graph.getAdjacentNodes(x));
@@ -308,7 +411,7 @@ private double scoreGraphChangeDelete(Node y, Node x, Set set){
LinkedList conditioning = new LinkedList<>();
conditioning.addAll(set);
for(Dag g: this.initialDags){
- if(!dSeparated(g,y, x, conditioning)) return 0.0;
+ if(!Utils.dSeparated(g,y, x, conditioning)) return 0.0;
}
eval = 1.0; //eval / (double) this.setOfdags.size();
val = eval;
@@ -319,103 +422,7 @@ private double scoreGraphChangeDelete(Node y, Node x, Set set){
}
}
- boolean dSeparated(Dag g, Node x, Node y, LinkedList cond){
-
- LinkedList open = new LinkedList();
- HashMap close = new HashMap();
- open.add(x);
- open.add(y);
- open.addAll(cond);
- while (open.size() != 0){
- Node a = open.getFirst();
- open.remove(a);
- close.put(a.toString(),a);
- List pa =g.getParents(a);
- for(Node p : pa){
- if(close.get(p.toString()) == null){
- if(!open.contains(p)) open.addLast(p);
- }
- }
- }
- Graph aux = new EdgeListGraph();
-
- for (Node node : g.getNodes()) aux.addNode(node);
- Node nodeT, nodeH;
- for (Edge e : g.getEdges()){
- if(!e.isDirected()) continue;
- nodeT = e.getNode1();
- nodeH = e.getNode2();
- if((close.get(nodeH.toString())!=null)&&(close.get(nodeT.toString())!=null)){
- Edge newEdge = new Edge(e.getNode1(),e.getNode2(),e.getEndpoint1(),e.getEndpoint2());
- aux.addEdge(newEdge);
- }
- }
-
- close = new HashMap();
- for(Edge e: aux.getEdges()){
- if(e.isDirected()){
- Node h;
- if(e.getEndpoint1()==Endpoint.ARROW){
- h = e.getNode1();
- }else h = e.getNode2();
- if(close.get(h.toString())==null){
- close.put(h.toString(),h);
- List pa = aux.getParents(h);
- if(pa.size()>1){
- for(int i = 0 ; i< pa.size() - 1; i++)
- for(int j = i+1; j < pa.size(); j++){
- Node p1 = pa.get(i);
- Node p2 = pa.get(j);
- boolean found = false;
- for(Edge edge : aux.getEdges()){
- if(edge.getNode1().equals(p1)&&(edge.getNode2().equals(p2))){
- found = true;
- break;
- }
- if(edge.getNode2().equals(p1)&&(edge.getNode1().equals(p2))){
- found = true;
- break;
- }
- }
- if(!found) aux.addUndirectedEdge(p1, p2);
- }
- }
-
- }
- }
- }
-
- for(Edge e: aux.getEdges()){
- if(e.isDirected()){
- e.setEndpoint1(Endpoint.TAIL);
- e.setEndpoint2(Endpoint.TAIL);
- }
- }
-
- aux.removeNodes(cond);
-
- open = new LinkedList();
- close = new HashMap();
- open.add(x);
- while (open.size() != 0){
- Node a = open.getFirst();
- if(a.equals(y)) return false;
- open.remove(a);
- close.put(a.toString(),a);
- List pa =aux.getAdjacentNodes(a);
- for(Node p : pa){
- if(p == null) continue;
- if(close.get(p.toString()) == null){
- if(!open.contains(p)) open.addLast(p);
- }
- }
- }
-
- return true;
- }
-
-
private static boolean isClique(List set, Graph graph) {
List setv = new LinkedList(set);
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
index d4c7a7f..0831f93 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
@@ -1,10 +1,16 @@
package es.uclm.i3a.simd.consensusBN;
+import java.util.ArrayDeque;
import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashSet;
import java.util.List;
+import java.util.Set;
+import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
+import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
@@ -64,5 +70,108 @@ public static void pdagToDag(Graph graph){
nodes.remove(x);
}while(nodes.size() > 0);
}
+
+
+ public static boolean dSeparated(Dag g, Node x, Node y, List cond) {
+
+ Set relevantNodes = findRelevantNodes(g, x, y, cond);
+ Graph aux = buildInducedSubgraph(g, relevantNodes);
+ moralize(aux);
+ convertToUndirected(aux);
+ aux.removeNodes(cond);
+ return !isReachable(aux, x, y);
+ }
+
+ private static Set findRelevantNodes(Dag g, Node x, Node y, List cond) {
+ Set visited = new HashSet<>();
+ Deque stack = new ArrayDeque<>();
+
+ stack.push(x);
+ stack.push(y);
+ for (Node c : cond) stack.push(c);
+
+ while (!stack.isEmpty()) {
+ Node current = stack.pop();
+ if (visited.add(current)) {
+ for (Node parent : g.getParents(current)) {
+ stack.push(parent);
+ }
+ }
+ }
+
+ return visited;
+ }
+
+ private static Graph buildInducedSubgraph(Dag g, Set nodesToKeep) {
+ Graph subgraph = new EdgeListGraph();
+
+ for (Node node : g.getNodes()) {
+ if (nodesToKeep.contains(node)) {
+ subgraph.addNode(node);
+ }
+ }
+
+ for (Edge e : g.getEdges()) {
+ if (!e.isDirected()) continue;
+
+ Node tail = e.getNode1();
+ Node head = e.getNode2();
+
+ if (nodesToKeep.contains(tail) && nodesToKeep.contains(head)) {
+ subgraph.addEdge(new Edge(tail, head, e.getEndpoint1(), e.getEndpoint2()));
+ }
+ }
+
+ return subgraph;
+ }
+
+ private static void moralize(Graph graph) {
+ for (Node child : graph.getNodes()) {
+ List parents = graph.getParents(child);
+ int n = parents.size();
+ if (n <= 1) continue;
+
+ for (int i = 0; i < n - 1; i++) {
+ for (int j = i + 1; j < n; j++) {
+ Node p1 = parents.get(i);
+ Node p2 = parents.get(j);
+ if (!graph.isAdjacentTo(p1, p2)) {
+ graph.addUndirectedEdge(p1, p2);
+ }
+ }
+ }
+ }
+ }
+
+ private static void convertToUndirected(Graph graph) {
+ for (Edge e : new ArrayList<>(graph.getEdges())) {
+ if (e.isDirected()) {
+ e.setEndpoint1(Endpoint.TAIL);
+ e.setEndpoint2(Endpoint.TAIL);
+ }
+ }
+ }
+
+ private static boolean isReachable(Graph g, Node start, Node target) {
+ Set visited = new HashSet<>();
+ Deque stack = new ArrayDeque<>();
+ stack.push(start);
+
+ while (!stack.isEmpty()) {
+ Node current = stack.pop();
+ if (current.equals(target)) return true;
+ if (visited.add(current)) {
+ for (Node neighbor : g.getAdjacentNodes(current)) {
+ if (!visited.contains(neighbor)) {
+ stack.push(neighbor);
+ }
+ }
+ }
+ }
+
+ return false;
+ }
+
+
}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
new file mode 100644
index 0000000..55f44d6
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
@@ -0,0 +1,102 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.Edges;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+
+public class UtilsTest {
+ // d-separation tests
+
+ private Node node(String name) {
+ return new GraphNode(name);
+ }
+
+ private Dag createDag(Edge... edges) {
+ Dag dag = new Dag();
+ for (Edge edge : edges) {
+ dag.addNode(edge.getNode1());
+ dag.addNode(edge.getNode2());
+ dag.addDirectedEdge(edge.getNode1(), edge.getNode2());
+ }
+ return dag;
+ }
+
+ @Test
+ public void testDirectConnection() {
+ Node A = node("A"), B = node("B");
+ Dag dag = createDag(Edges.directedEdge(A, B));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, B, conditioning));
+ }
+
+ @Test
+ public void testChainNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testChainWithCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
+
+ assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testColliderNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B));
+ List conditioning = Collections.emptyList();
+ assertTrue(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testColliderConditionedOnCollider() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B));
+
+ assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testDivergingNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testDivergingConditionedOnCommonParent() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C));
+
+ assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testColliderConditionedOnDescendant() {
+ Node A = node("A"), B = node("B"), C = node("C"), D = node("D");
+ Dag dag = createDag(
+ Edges.directedEdge(A, B),
+ Edges.directedEdge(C, B),
+ Edges.directedEdge(B, D)
+ );
+
+ assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(D)));
+ }
+
+}
From 6888065e6b667a97f3e66b41c010890c189b679a Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Fri, 18 Jul 2025 10:05:46 +0200
Subject: [PATCH 10/32] Moving findNaXY to Utils as static and cleaning BESd
---
.../uclm/i3a/simd/consensusBN/AlphaOrder.java | 1 -
.../BackwardEquivalenceSearchDSep.java | 80 ++++++-----
.../consensusBN/HeuristicConsensusBES.java | 29 +---
.../consensusBN/PairWiseConsensusBES.java | 5 +-
.../es/uclm/i3a/simd/consensusBN/Utils.java | 28 ++++
.../BackwardEquivalenceSearchDSepTest.java | 80 +++++------
.../uclm/i3a/simd/consensusBN/UtilsTest.java | 124 ++++++++++++++++++
7 files changed, 250 insertions(+), 97 deletions(-)
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java b/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java
index 7865c15..35f99ac 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java
@@ -253,7 +253,6 @@ private void coverEdge(Dag g, Node nodeAlpha, Node child) {
Edge pay = g.getEdge(nodep, child);
if(pay == null)
g.addEdge(new Edge(nodep,child,Endpoint.TAIL,Endpoint.ARROW));
-
}
// Adding edges from parents of child to nodeAlpha.
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
index 5331c50..06f12a0 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
@@ -231,7 +231,7 @@ private EdgeCandidate calculateBestCandidateEdge(List edges, double score)
SubSet hSubset=hSubsets.nextElement();
// Checking if {naYXH} \ {hSubset} is a clique
- List naYXH = findNaYX(candidateTail, candidateHead, graph);
+ List naYXH = Utils.findNaYX(candidateTail, candidateHead, graph);
naYXH.removeAll(hSubset);
if (!isClique(naYXH, graph)) {
continue;
@@ -348,6 +348,16 @@ private void pdag(Graph graph) {
}
+ /**
+ * Finds all neighbors of node x that are adjacent to node y in the graph.
+ * This method retrieves the neighbors of node y that are also adjacent to node x,
+ * ensuring that the edges between them are undirected.
+ * It filters out undirected edges to ensure that only neighbors from directed edges are considered.
+ * @param x Node x to find neighbors for.
+ * @param y Node y to find neighbors for.
+ * @param graph The graph in which to find the neighbors.
+ * @return A list of nodes that are neighbors of x and y, filtered to include only neighbors from directed edges.
+ */
private static List getHNeighbors(Node x, Node y, Graph graph) {
List hNeighbors = new LinkedList<>(graph.getAdjacentNodes(y));
hNeighbors.retainAll(graph.getAdjacentNodes(x));
@@ -363,45 +373,53 @@ private static List getHNeighbors(Node x, Node y, Graph graph) {
return hNeighbors;
}
- private static void delete(Node x, Node y, Set subset, Graph graph) {
- graph.removeEdges(x, y);
+ /**
+ * Applies the delete operation from Chickering 2002 for the edge x->y in the graph, and updates the edges
+ * connecting x and y to the nodes in the provided subset. This is done to ensure that the same dependency structure is maintained
+ * while removing the edge between x and y.
+ * @param tailNode The tail node of the edge to be deleted.
+ * @param headNode The head node of the edge to be deleted.
+ * @param subset The set of nodes that will be connected to the tail and head nodes after the deletion.
+ * @param graph The graph from which the edge is deleted and the connections are updated.
+ */
+ private static void delete(Node tailNode, Node headNode, Set subset, Graph graph) {
+ graph.removeEdges(tailNode, headNode);
for (Node aSubset : subset) {
- if (!graph.isParentOf(aSubset, x) && !graph.isParentOf(x, aSubset)) {
- graph.removeEdge(x, aSubset);
- graph.addDirectedEdge(x, aSubset);
+ if (!graph.isParentOf(aSubset, tailNode) && !graph.isParentOf(tailNode, aSubset)) {
+ graph.removeEdge(tailNode, aSubset);
+ graph.addDirectedEdge(tailNode, aSubset);
}
- graph.removeEdge(y, aSubset);
- graph.addDirectedEdge(y, aSubset);
+ graph.removeEdge(headNode, aSubset);
+ graph.addDirectedEdge(headNode, aSubset);
}
}
- private double deleteEval(Node x, Node y, SubSet h, Graph graph){
-
- Set set1 = new HashSet(findNaYX(x, y, graph));
- set1.removeAll(h);
- set1.addAll(graph.getParents(y));
- set1.remove(x);
- return scoreGraphChangeDelete(y, x, set1); // calcular si y esta d-separado de x dado el set1 en cada grafo.
+ /**
+ * Evaluates the impact of deleting an edge from the graph based on d-separation.
+ *
+ * This method computes a score for deleting the edge from {@code x} to {@code y},
+ * taking into account a conditioning set of nodes {@code conditioningSet}. It uses
+ * structural information from the graph to assess whether {@code y} is d-separated
+ * from {@code x} given the constructed conditioning set.
+ *
+ * @param x The source node of the edge to be deleted.
+ * @param y The target node of the edge to be deleted.
+ * @param conditioningSet The set of nodes used as conditioning variables (Z) for d-separation.
+ * @param graph The graph in which the change is being evaluated.
+ * @return The score resulting from deleting the edge, based on the given context.
+ */
+ private double deleteEval(Node x, Node y, SubSet conditioningSet, Graph graph){
+ // Setup the conditioning set for d-separation by removing the conditioning nodes from the naYX set, adding the parents of y and removing x.
+ Set finalConditioningSet = new HashSet<>(Utils.findNaYX(x, y, graph));
+ finalConditioningSet.removeAll(conditioningSet);
+ finalConditioningSet.addAll(graph.getParents(y));
+ finalConditioningSet.remove(x);
+ // Check if y is d-separated from x given the final conditioning set in each graph.
+ return scoreGraphChangeDelete(y, x, finalConditioningSet);
}
- private static List findNaYX(Node x, Node y, Graph graph) {
- List naYX = new LinkedList<>(graph.getAdjacentNodes(y));
- naYX.retainAll(graph.getAdjacentNodes(x));
-
- for (int i = naYX.size()-1; i >= 0; i--) {
- Node z = naYX.get(i);
- Edge edge = graph.getEdge(y, z);
-
- if (!Edges.isUndirectedEdge(edge)) {
- naYX.remove(z);
- }
- }
-
- return naYX;
- }
-
private double scoreGraphChangeDelete(Node y, Node x, Set set){
String key = y.getName()+x.getName()+set.toString();
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java
index adddc2a..8d2a496 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java
@@ -40,12 +40,12 @@ public class HeuristicConsensusBES {
public HeuristicConsensusBES(ArrayList dags, double percentage){
this.setOfdags = dags;
this.heuristic = new AlphaOrder(this.setOfdags);
- this.heuristic.computeAlphaH2();
- this.alpha = this.heuristic.alpha;
+ this.heuristic.computeAlpha();
+ this.alpha = this.heuristic.getOrder();
this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha);
this.imaps2alpha.transform();
this.numberOfInsertedEdges = imaps2alpha.getNumberOfInsertedEdges();
- this.setOfOutDags = imaps2alpha.setOfOutputDags;
+ this.setOfOutDags = imaps2alpha.getSetOfOutputDags();
this.percentage = percentage;
}
@@ -58,7 +58,7 @@ private void consensusUnion(){
this.union = new Dag(this.alpha);
for(Node nodei: this.alpha){
- for(Dag d : this.imaps2alpha.setOfOutputDags){
+ for(Dag d : this.imaps2alpha.getSetOfOutputDags()){
Listparent = d.getParents(nodei);
for(Node pa: parent){
if(!this.union.isParentOf(pa, nodei)) this.union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW));
@@ -128,7 +128,7 @@ public void fusion(){
}
// INICIO TEST 1
- List naYXH = findNaYX(_x, _y, graph);
+ List naYXH = Utils.findNaYX(_x, _y, graph);
naYXH.removeAll(hSubset);
if (!isClique(naYXH, graph)) {
// hSubsets.firstTest(true); // Si pasa para H entonces pasa para cualquier H' | H' contiene H
@@ -250,7 +250,7 @@ private static List getHNeighbors(Node x, Node y, Graph graph) {
double deleteEval(Node x, Node y, SubSet h, Graph graph){
- Set set1 = new HashSet(findNaYX(x, y, graph));
+ Set set1 = new HashSet(Utils.findNaYX(x, y, graph));
set1.removeAll(h);
set1.addAll(graph.getParents(y));
set1.remove(x);
@@ -374,23 +374,6 @@ boolean dSeparated(Dag g, Node x, Node y, LinkedList cond){
return true;
}
-
-
- private static List findNaYX(Node x, Node y, Graph graph) {
- List naYX = new LinkedList(graph.getAdjacentNodes(y));
- naYX.retainAll(graph.getAdjacentNodes(x));
-
- for (int i = naYX.size()-1; i >= 0; i--) {
- Node z = naYX.get(i);
- Edge edge = graph.getEdge(y, z);
-
- if (!Edges.isUndirectedEdge(edge)) {
- naYX.remove(z);
- }
- }
-
- return naYX;
- }
public Dag getFusion(){
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java
index 3b053bb..bbd639e 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java
@@ -2,7 +2,6 @@
import java.util.ArrayList;
-
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Node;
@@ -31,7 +30,7 @@ public void getFusion(){
conBES = new ConsensusBES(setOfDags);
conBES.fusion();
this.numberOfInsertedEdges = conBES.getNumberOfInsertedEdges();
- this.numberOfUnionEdges = conBES.union.getNumEdges();
+ this.numberOfUnionEdges = conBES.getUnion().getNumEdges();
this.conDAG = conBES.getFusion();
}
@@ -49,7 +48,7 @@ public int getHammingDistance(){
for(Edge ed: this.conDAG.getEdges()){
Node tail = ed.getNode1();
Node head = ed.getNode2();
- for(Dag g: conBES.setOfOutDags){
+ for(Dag g: conBES.getTransformedDags()){
Edge edge1 = g.getEdge(tail, head);
Edge edge2 = g.getEdge(head, tail);
if(edge1 == null && edge2==null) distance++;
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
index 0831f93..3b3fa02 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
@@ -4,12 +4,14 @@
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
+import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
+import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
@@ -173,5 +175,31 @@ private static boolean isReachable(Graph g, Node start, Node target) {
}
+ /**
+ * Finds the nodes that are neighbors of node y and x in the graph.
+ * This method retrieves the neighbors of node y that are also adjacent to node x,
+ * ensuring that the edges between them are directed.
+ * It filters out undirected edges to ensure that only neighbors from directed edges are considered.
+ * @param x Node x to find neighbors for.
+ * @param y Node y to find neighbors for.
+ * @param graph The graph in which to find the neighbors.
+ * @return A list of nodes that are neighbors of x and y, filtered to include only neighbors from directed edges.
+ */
+ public static List findNaYX(Node x, Node y, Graph graph) {
+ List naYX = new LinkedList<>(graph.getAdjacentNodes(y));
+ naYX.retainAll(graph.getAdjacentNodes(x));
+
+ for (int i = naYX.size()-1; i >= 0; i--) {
+ Node z = naYX.get(i);
+ Edge edge = graph.getEdge(y, z);
+
+ if (!Edges.isUndirectedEdge(edge)) {
+ naYX.remove(z);
+ }
+ }
+
+ return naYX;
+ }
+
}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
index 7078928..1e8c626 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
@@ -16,39 +16,28 @@
class BackwardEquivalenceSearchDSepTest {
- private Dag createSimpleDag() {
- // A -> B -> C
- Node a = new GraphNode("A");
- Node b = new GraphNode("B");
- Node c = new GraphNode("C");
-
- Dag dag = new Dag();
- dag.addNode(a);
- dag.addNode(b);
- dag.addNode(c);
- dag.addDirectedEdge(a, b);
- dag.addDirectedEdge(b, c);
+ private ArrayList createRandomDagList(int copies) {
+ RandomBN setOfDags = new RandomBN(0, 20, 50,
+ copies,3);
+ setOfDags.setMaxInDegree(4);
+ setOfDags.setMaxOutDegree(4);
+ setOfDags.generate();
- return dag;
- }
-
- private ArrayList createDagList(int copies) {
- ArrayList list = new ArrayList<>();
- for (int i = 0; i < copies; i++) {
- list.add(createSimpleDag());
- }
- return list;
+ return setOfDags.setOfRandomDags;
}
@Test
void testApplyBESdDoesNotThrow() {
- Dag unionDag = createSimpleDag();
- ArrayList initialDags = createDagList(3);
- ArrayList transformedDags = createDagList(3);
-
- BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
-
+ // Setting up consensus union
+ ArrayList initialDags = createRandomDagList(3);
+ ConsensusUnion consensusUnion = new ConsensusUnion(initialDags);
+ Dag unionDag = consensusUnion.union();
+
+ // Running Backward Equivalence Search with d-separation
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, consensusUnion.getTransformedDags());
+
+ // No exceptions should be thrown during the process
assertDoesNotThrow(() -> {
Dag output = besd.applyBackwardEliminationWithDSeparation();
assertNotNull(output);
@@ -57,11 +46,13 @@ void testApplyBESdDoesNotThrow() {
@Test
void testOutputIsDAG() {
- Dag unionDag = createSimpleDag();
- ArrayList initialDags = createDagList(2);
- ArrayList transformedDags = createDagList(2);
-
- BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
+ // Setting up consensus union
+ ArrayList initialDags = createRandomDagList(3);
+ ConsensusUnion consensusUnion = new ConsensusUnion(initialDags);
+ Dag unionDag = consensusUnion.union();
+
+ // Running Backward Equivalence Search with d-separation
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, consensusUnion.getTransformedDags());
Dag outputDag = besd.applyBackwardEliminationWithDSeparation();
assertTrue(GraphUtils.isDag(outputDag), "El resultado no es un DAG válido.");
@@ -82,12 +73,19 @@ void testAristasSePuedenEliminar() {
Dag dag1 = new Dag();
dag1.addNode(a);
dag1.addNode(b);
+
+ Dag dag2 = new Dag();
+ dag2.addNode(a);
+ dag2.addNode(b);
+ // Aquí no hay aristas, A y B están desconectados
// sin conexión
ArrayList initialDags = new ArrayList<>();
initialDags.add(dag1);
- ArrayList transformedDags = new ArrayList<>();
- transformedDags.add(dag1);
+ initialDags.add(dag2);
+ AlphaOrder alphaOrder = new AlphaOrder(initialDags);
+ alphaOrder.computeAlpha();
+ ArrayList transformedDags = (new TransformDags(initialDags, alphaOrder.getOrder())).transform();
BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
Dag outputDag = besd.applyBackwardEliminationWithDSeparation();
@@ -99,14 +97,18 @@ void testAristasSePuedenEliminar() {
@Test
void testGetNumberOfInsertedEdgesReflectsChanges() {
- Dag unionDag = createSimpleDag();
- ArrayList dags = createDagList(2);
+ ArrayList initialDags = createRandomDagList(2);
+ ConsensusUnion consensusUnion = new ConsensusUnion(initialDags);
+ Dag unionDag = consensusUnion.union();
+ ArrayList transformedDags = consensusUnion.getTransformedDags();
+ int insertedEdgesBefore = consensusUnion.getNumberOfInsertedEdges();
- BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, dags, dags);
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
besd.applyBackwardEliminationWithDSeparation();
- int insertedEdges = besd.getNumberOfInsertedEdges();
+ int insertedEdgesAfter = besd.getNumberOfInsertedEdges();
// En el peor de los casos no ha eliminado ninguna, pero nunca debe ser negativo
- assertTrue(insertedEdges >= 0, "El número de aristas insertadas no puede ser negativo.");
+ assertTrue(insertedEdgesAfter >= 0, "The number of inserted edges should not be negative.");
+ assertTrue(insertedEdgesAfter <= insertedEdgesBefore, "The number of inserted edges should decrease after BES.");
}
}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
index 55f44d6..89b43ad 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
@@ -3,13 +3,17 @@
import java.util.Collections;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.Test;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
+import edu.cmu.tetrad.graph.Endpoint;
+import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
@@ -99,4 +103,124 @@ public void testColliderConditionedOnDescendant() {
assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(D)));
}
+ //find naYX tests
+
+ @Test
+ public void testFindNaYX_singleUndirectedCommonNeighbor() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node z = new GraphNode("Z");
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(z);
+
+ graph.addEdge(new Edge(x, z, Endpoint.TAIL, Endpoint.TAIL)); // undirected
+ graph.addEdge(new Edge(y, z, Endpoint.TAIL, Endpoint.TAIL)); // undirected
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertEquals(1, result.size());
+ assertTrue(result.contains(z));
+ }
+
+ @Test
+ public void testFindNaYX_directedEdgeShouldBeExcluded() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node z = new GraphNode("Z");
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(z);
+
+ graph.addEdge(new Edge(x, z, Endpoint.ARROW, Endpoint.TAIL)); // x → z
+ graph.addEdge(new Edge(y, z, Endpoint.ARROW, Endpoint.TAIL)); // y → z
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertEquals(0, result.size());
+ }
+
+ @Test
+ public void testFindNaYX_mixedNeighbors() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node z1 = new GraphNode("Z1"); // undirected common
+ Node z2 = new GraphNode("Z2"); // directed common
+ Node z3 = new GraphNode("Z3"); // only adjacent to x
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(z1);
+ graph.addNode(z2);
+ graph.addNode(z3);
+
+ // z1: undirected edge with both
+ graph.addEdge(new Edge(x, z1, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(y, z1, Endpoint.TAIL, Endpoint.TAIL));
+
+ // z2: directed edges with both
+ graph.addEdge(new Edge(x, z2, Endpoint.TAIL, Endpoint.ARROW));
+ graph.addEdge(new Edge(y, z2, Endpoint.TAIL, Endpoint.ARROW));
+
+ // z3: only adjacent to x
+ graph.addEdge(new Edge(x, z3, Endpoint.TAIL, Endpoint.TAIL));
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertEquals(1, result.size());
+ assertTrue(result.contains(z1));
+ assertFalse(result.contains(z2));
+ assertFalse(result.contains(z3));
+ }
+
+ @Test
+ public void testFindNaYX_multipleUndirectedCommonNeighbors() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(a);
+ graph.addNode(b);
+
+ graph.addEdge(new Edge(x, a, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(y, a, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(x, b, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(y, b, Endpoint.TAIL, Endpoint.TAIL));
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertEquals(2, result.size());
+ assertTrue(result.contains(a));
+ assertTrue(result.contains(b));
+ }
+
+ @Test
+ public void testFindNaYX_noCommonNeighbors() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(a);
+ graph.addNode(b);
+
+ graph.addEdge(new Edge(x, a, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(y, b, Endpoint.TAIL, Endpoint.TAIL));
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertTrue(result.isEmpty());
+ }
}
From f5873773f7501ea178b42b39f3a2d39522ffd2ef Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Fri, 18 Jul 2025 12:06:48 +0200
Subject: [PATCH 11/32] Adding DSeparationKey for local map optimization
---
.../BackwardEquivalenceSearchDSep.java | 34 +++----
.../i3a/simd/consensusBN/DSeparationKey.java | 60 ++++++++++++
.../simd/consensusBN/DSeparationKeyTest.java | 91 +++++++++++++++++++
.../uclm/i3a/simd/consensusBN/UtilsTest.java | 17 ++++
4 files changed, 185 insertions(+), 17 deletions(-)
create mode 100644 src/main/java/es/uclm/i3a/simd/consensusBN/DSeparationKey.java
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
index 06f12a0..a7f08c5 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
@@ -72,7 +72,7 @@ public class BackwardEquivalenceSearchDSep {
* This map is used to cache the scores of edge deletions to avoid redundant calculations.
* The key is a string representation of the edge and its conditioning set, and the value is the score.
*/
- private final Map localScore = new HashMap<>();
+ private final Map localScore = new HashMap<>();
/**
* Number of edges inserted during the consensus union and backward equivalence search process.
@@ -420,24 +420,24 @@ private double deleteEval(Node x, Node y, SubSet conditioningSet, Graph graph){
return scoreGraphChangeDelete(y, x, finalConditioningSet);
}
- private double scoreGraphChangeDelete(Node y, Node x, Set set){
-
- String key = y.getName()+x.getName()+set.toString();
- Double val = this.localScore.get(key);
- if(val == null){
- double eval = 0.0;
- LinkedList conditioning = new LinkedList<>();
- conditioning.addAll(set);
- for(Dag g: this.initialDags){
- if(!Utils.dSeparated(g,y, x, conditioning)) return 0.0;
+ private double scoreGraphChangeDelete(Node y, Node x, Set conditioningSet) {
+ DSeparationKey key = new DSeparationKey(y, x, conditioningSet);
+ Double cached = localScore.get(key);
+
+ if (cached != null) {
+ return cached;
+ }
+
+ // Evaluamos d-separación en todos los DAGs
+ for (Dag g : this.initialDags) {
+ if (!Utils.dSeparated(g, y, x, new ArrayList<>(conditioningSet))) {
+ localScore.put(key, 0.0);
+ return 0.0;
}
- eval = 1.0; //eval / (double) this.setOfdags.size();
- val = eval;
- this.localScore.put(key, val);
- return eval;
- }else{
- return val.doubleValue();
}
+
+ localScore.put(key, 1.0);
+ return 1.0;
}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/DSeparationKey.java b/src/main/java/es/uclm/i3a/simd/consensusBN/DSeparationKey.java
new file mode 100644
index 0000000..4e8cbb0
--- /dev/null
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/DSeparationKey.java
@@ -0,0 +1,60 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Objects;
+import java.util.Set;
+
+import edu.cmu.tetrad.graph.Node;
+
+public class DSeparationKey {
+ private final Node y;
+ private final Node x;
+ private final Set conditioningSet;
+
+ public DSeparationKey(Node x, Node y, Set conditioningSet) {
+ // Since D-separation is symmetric, we ensure a consistent order for x and y
+ if (x.getName().compareTo(y.getName()) <= 0) {
+ this.x = x;
+ this.y = y;
+ } else {
+ this.x = y;
+ this.y = x;
+ }
+ this.conditioningSet = new HashSet<>(conditioningSet); // copia defensiva
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) return true;
+ if (!(obj instanceof DSeparationKey)) return false;
+
+ DSeparationKey other = (DSeparationKey) obj;
+ return y.equals(other.y)
+ && x.equals(other.x)
+ && conditioningSet.equals(other.conditioningSet);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(y, x, conditioningSet);
+ }
+
+
+ public Node getY() {
+ return this.y;
+ }
+
+
+ public Node getX() {
+ return this.x;
+ }
+
+
+ public Set getConditioningSet() {
+ return Collections.unmodifiableSet(this.conditioningSet);
+ }
+
+
+}
+
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java
new file mode 100644
index 0000000..4d5d880
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java
@@ -0,0 +1,91 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+class DSeparationKeyTest {
+
+ private final Node X = new GraphNode("X");
+ private final Node Y = new GraphNode("Y");
+ private final Node Z = new GraphNode("Z");
+ private final Node W = new GraphNode("W");
+
+ @Test
+ void testConstructorAndGetters() {
+ Set zSet = new HashSet<>(Arrays.asList(Z, W));
+ DSeparationKey key = new DSeparationKey(X, Y, zSet);
+
+ assertEquals(X, key.getX());
+ assertEquals(Y, key.getY());
+ assertEquals(new HashSet<>(Arrays.asList(Z, W)), key.getConditioningSet());
+ }
+
+ @Test
+ void testEqualsSameContentDifferentOrder() {
+ Set zSet1 = new HashSet<>(Arrays.asList(Z, W));
+ Set zSet2 = new HashSet<>(Arrays.asList(W, Z));
+
+ DSeparationKey key1 = new DSeparationKey(X, Y, zSet1);
+ DSeparationKey key2 = new DSeparationKey(X, Y, zSet2);
+
+ assertEquals(key1, key2);
+ assertEquals(key1.hashCode(), key2.hashCode());
+ }
+
+ @Test
+ void testNotEqualsDifferentZ() {
+ Set zSet1 = new HashSet<>(Collections.singletonList(Z));
+ Set zSet2 = new HashSet<>(Arrays.asList(Z, W));
+
+ DSeparationKey key1 = new DSeparationKey(X, Y, zSet1);
+ DSeparationKey key2 = new DSeparationKey(X, Y, zSet2);
+
+ assertNotEquals(key1, key2);
+ }
+
+ @Test
+ void testSimmetryBetweenXandY() {
+ DSeparationKey key1 = new DSeparationKey(X, Y, Collections.emptySet());
+ DSeparationKey key2 = new DSeparationKey(Y, X, Collections.emptySet());
+
+ assertEquals(key1, key2);
+ }
+
+ @Test
+ void testEqualsSelf() {
+ DSeparationKey key = new DSeparationKey(X, Y, Collections.singleton(Z));
+ assertEquals(key, key);
+ }
+
+ @Test
+ void testNotEqualsNullOrDifferentClass() {
+ DSeparationKey key = new DSeparationKey(X, Y, Collections.singleton(Z));
+
+ assertNotEquals(null, key);
+ assertNotEquals("NotAKey", key);
+ }
+
+ @Test
+ void testKeyAsMapKey() {
+ DSeparationKey key1 = new DSeparationKey(X, Y, new HashSet<>(Arrays.asList(Z, W)));
+ DSeparationKey key2 = new DSeparationKey(X, Y, new HashSet<>(Arrays.asList(W, Z)));
+
+ Map map = new HashMap<>();
+ map.put(key1, true);
+
+ assertTrue(map.containsKey(key2));
+ assertTrue(map.get(key2));
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
index 89b43ad..50657cf 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
@@ -103,6 +103,23 @@ public void testColliderConditionedOnDescendant() {
assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(D)));
}
+ // Asegurar que esto funciona correctamente!!!!!!!
+ @Test
+ public void testSimmetryBetweenXandY() {
+ Node A = node("A"), B = node("B"), C = node("C"), D = node("D");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, A), Edges.directedEdge(D, A));
+ List conditioning = Collections.emptyList();
+
+ // No colliders between A and B, so they are not d-separated
+ assertFalse(Utils.dSeparated(dag, A, B, conditioning));
+ assertFalse(Utils.dSeparated(dag, B, A, conditioning));
+
+ // C->A and D->A makes A a collider for C and D, and therefore C and D are d-separated from each other
+ assertTrue(Utils.dSeparated(dag, C, D, conditioning));
+ assertTrue(Utils.dSeparated(dag, D, C, conditioning));
+
+ }
+
//find naYX tests
@Test
From ae6def1f14fdda2fe3af843beddd91df5e44f935 Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Fri, 18 Jul 2025 12:45:11 +0200
Subject: [PATCH 12/32] Cleaned and tested BackwardEquivalenceSearchDSep
---
.../BackwardEquivalenceSearchDSep.java | 87 ++++++++++++-------
.../i3a/simd/consensusBN/ConsensusBES.java | 2 +-
.../BackwardEquivalenceSearchDSepTest.java | 2 +-
3 files changed, 60 insertions(+), 31 deletions(-)
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
index a7f08c5..4b3334f 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
@@ -1,8 +1,6 @@
package es.uclm.i3a.simd.consensusBN;
-import java.util.ArrayDeque;
import java.util.ArrayList;
-import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
@@ -16,6 +14,7 @@
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
+import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.search.utils.MeekRules;
@@ -75,14 +74,14 @@ public class BackwardEquivalenceSearchDSep {
private final Map localScore = new HashMap<>();
/**
- * Number of edges inserted during the consensus union and backward equivalence search process.
- * This variable keeps track of the total number of edges that were added to the consensus DAG
- * during the union of transformed input DAGs and the subsequent edge deletions.
+ * Number of edges removed during the backward equivalence search process.
+ * This variable keeps track of the total number of edges that are inserted (deleted) during the
+ * Backward Equivalence Search with D-separation process.
*
* @see ConsensusUnion#getNumberOfInsertedEdges()
* @see BackwardEquivalenceSearchDSep#applyBackwardEliminationWithDSeparation()
*/
- private int numberOfInsertedEdges = 0;
+ private int numberOfRemovedEdges = 0;
/**
* Constructor for BackwardEquivalenceSearchDSep that initializes the properties for the search with a union DAG and lists of initial and transformed DAGs.
@@ -233,7 +232,7 @@ private EdgeCandidate calculateBestCandidateEdge(List edges, double score)
// Checking if {naYXH} \ {hSubset} is a clique
List naYXH = Utils.findNaYX(candidateTail, candidateHead, graph);
naYXH.removeAll(hSubset);
- if (!isClique(naYXH, graph)) {
+ if (!GraphUtils.isClique(naYXH, graph)) {
continue;
}
@@ -293,7 +292,7 @@ private double executeEdgeDeletion(EdgeCandidate bestCandidate) {
for(int g = 0; g conditioningSet) {
+ /**
+ * Checks if the deletion of an edge from {@code x} to {@code y} maintains the d-separation condition
+ * across all initial DAGs. If the edge deletion maintains d-separation, it returns a score of 1.0,
+ * otherwise it returns 0.0.
+ *
+ * This method uses a local score map to cache results for efficiency, avoiding redundant calculations
+ * for the same edge and conditioning set.
+ * @param x The tail node of the edge to be deleted.
+ * @param y The head node of the edge to be deleted.
+ * @param conditioningSet The set of nodes used as conditioning variables (Z) for d-separation.
+ * @return A score of 1.0 if the edge deletion maintains d-separation, otherwise 0.0.
+ *
+ * @see Utils#dSeparated(Dag, Node, Node, List)
+ * @see DSeparationKey
+ *
+ * This method is crucial for ensuring that the edge deletion does not violate the d-separation condition,
+ * which is essential for maintaining the integrity of the Bayesian network structure.
+ */
+ private double scoreGraphChangeDelete(Node x, Node y, Set conditioningSet) {
+ // Check if the edge deletion has already been evaluated and cached
DSeparationKey key = new DSeparationKey(y, x, conditioningSet);
Double cached = localScore.get(key);
-
if (cached != null) {
return cached;
}
- // Evaluamos d-separación en todos los DAGs
+ // Evaluating the d-separation condition across all initial DAGs
for (Dag g : this.initialDags) {
- if (!Utils.dSeparated(g, y, x, new ArrayList<>(conditioningSet))) {
+ if (!Utils.dSeparated(g, x, y, new ArrayList<>(conditioningSet))) {
localScore.put(key, 0.0);
return 0.0;
}
@@ -439,29 +456,41 @@ private double scoreGraphChangeDelete(Node y, Node x, Set conditioningSet)
localScore.put(key, 1.0);
return 1.0;
}
-
-
-
- private static boolean isClique(List set, Graph graph) {
- List setv = new LinkedList(set);
- for (int i = 0; i < setv.size() - 1; i++) {
- for (int j = i + 1; j < setv.size(); j++) {
- if (!graph.isAdjacentTo(setv.get(i), setv.get(j))) {
- return false;
- }
- }
- }
- return true;
- }
-
- public int getNumberOfInsertedEdges() {
- return this.numberOfInsertedEdges;
+ /**
+ * Returns the number of edges that were inserted during the consensus union and backward equivalence search process.
+ * @return
+ */
+ public int getNumberOfRemovedEdges() {
+ return this.numberOfRemovedEdges;
}
+ /**
+ * Class representing a candidate edge for deletion in the Backward Equivalence Search.
+ * This class encapsulates the tail and head nodes of the edge, the conditioning set used for d-separation,
+ * and the score associated with the edge deletion.
+ *
+ * @see BackwardEquivalenceSearchDSep#applyBackwardEliminationWithDSeparation()
+ * @see Utils#dSeparated(Dag, Node, Node, List)
+ */
private class EdgeCandidate {
+ /**
+ * The tail node of the edge candidate.
+ */
public final Node tail;
+
+ /**
+ * The head node of the edge candidate.
+ */
public final Node head;
+
+ /**
+ * The conditioning set used for d-separation in the edge candidate.
+ */
public final Set conditioningSet;
+
+ /**
+ * The score associated with the edge candidate deletion.
+ */
public double score;
public EdgeCandidate(Node tail, Node head, Set conditioningSet) {
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
index d8514ef..c893d85 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
@@ -97,7 +97,7 @@ public void fusion(){
BackwardEquivalenceSearchDSep bes = new BackwardEquivalenceSearchDSep(this.union, this.inputDags, this.transformedDags);
this.outputDag = bes.applyBackwardEliminationWithDSeparation();
// 3. Updating numberOfInsertedEdges
- this.numberOfInsertedEdges += bes.getNumberOfInsertedEdges();
+ this.numberOfInsertedEdges -= bes.getNumberOfRemovedEdges();
}
/**
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
index 1e8c626..f5d6a6a 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
@@ -106,7 +106,7 @@ void testGetNumberOfInsertedEdgesReflectsChanges() {
BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
besd.applyBackwardEliminationWithDSeparation();
- int insertedEdgesAfter = besd.getNumberOfInsertedEdges();
+ int insertedEdgesAfter = insertedEdgesBefore - besd.getNumberOfRemovedEdges();
// En el peor de los casos no ha eliminado ninguna, pero nunca debe ser negativo
assertTrue(insertedEdgesAfter >= 0, "The number of inserted edges should not be negative.");
assertTrue(insertedEdgesAfter <= insertedEdgesBefore, "The number of inserted edges should decrease after BES.");
From 6d08d809c0437d92318860f7bfbece04d5d6fc81 Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Fri, 18 Jul 2025 14:47:09 +0200
Subject: [PATCH 13/32] Splitting UtilsTest into two
---
.../es/uclm/i3a/simd/consensusBN/Utils.java | 4 +
.../i3a/simd/consensusBN/DseparationTest.java | 333 ++++++++++++++++++
.../{UtilsTest.java => FindNaYXTest.java} | 109 +-----
3 files changed, 340 insertions(+), 106 deletions(-)
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java
rename src/test/java/es/uclm/i3a/simd/consensusBN/{UtilsTest.java => FindNaYXTest.java} (53%)
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
index 3b3fa02..751303e 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
@@ -74,6 +74,10 @@ public static void pdagToDag(Graph graph){
}
+ public static boolean dSeparated(Dag g, Node x, Node y) {
+ return dSeparated(g, x, y, new ArrayList<>());
+ }
+
public static boolean dSeparated(Dag g, Node x, Node y, List cond) {
Set relevantNodes = findRelevantNodes(g, x, y, cond);
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java
new file mode 100644
index 0000000..2f92250
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java
@@ -0,0 +1,333 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.Edges;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class DseparationTest {
+ // d-separation tests
+
+ private Node node(String name) {
+ return new GraphNode(name);
+ }
+
+ private Dag createDag(Edge... edges) {
+ Dag dag = new Dag();
+ for (Edge edge : edges) {
+ dag.addNode(edge.getNode1());
+ dag.addNode(edge.getNode2());
+ dag.addDirectedEdge(edge.getNode1(), edge.getNode2());
+ }
+ return dag;
+ }
+
+ @Test
+ public void testDirectConnection() {
+ Node A = node("A"), B = node("B");
+ Dag dag = createDag(Edges.directedEdge(A, B));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, B, conditioning));
+ }
+
+ @Test
+ public void testChainNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testChainWithCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
+
+ assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testColliderNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B));
+ List conditioning = Collections.emptyList();
+ assertTrue(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testColliderConditionedOnCollider() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B));
+
+ assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testDivergingNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testDivergingConditionedOnCommonParent() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C));
+
+ assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testColliderConditionedOnDescendant() {
+ Node A = node("A"), B = node("B"), C = node("C"), D = node("D");
+ Dag dag = createDag(
+ Edges.directedEdge(A, B),
+ Edges.directedEdge(C, B),
+ Edges.directedEdge(B, D)
+ );
+
+ assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(D)));
+ }
+
+ @Test
+ public void testSimmetryBetweenXandY() {
+ Node A = node("A"), B = node("B"), C = node("C"), D = node("D");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, A), Edges.directedEdge(D, A));
+ List conditioning = Collections.emptyList();
+
+ // No colliders between A and B, so they are not d-separated
+ assertFalse(Utils.dSeparated(dag, A, B, conditioning));
+ assertFalse(Utils.dSeparated(dag, B, A, conditioning));
+
+ // C->A and D->A makes A a collider for C and D, and therefore C and D are d-separated from each other
+ assertTrue(Utils.dSeparated(dag, C, D, conditioning));
+ assertTrue(Utils.dSeparated(dag, D, C, conditioning));
+
+ }
+
+ @Test
+ public void testDseparationMethodsAreEquivalent() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
+
+ // Using the dSeparated method with no conditioning
+ assertFalse(Utils.dSeparated(dag, A, C));
+
+ // Using the dSeparated method with an empty conditioning set
+ assertFalse(Utils.dSeparated(dag, A, C, Collections.emptyList()));
+ }
+
+ @Test
+ public void testDseparationRule1(){
+ // Scenarios taken from https://yuyangyy.medium.com/understand-d-separation-471f9aada503
+ // Scenario 1: Z is empty, namely, we don't dondition on any variables
+ // Rule 1: If there exists a path from x to y and there is no collider on the path, then x and y are not d-separated.
+ Node x = new GraphNode("x");
+ Node r = new GraphNode("r");
+ Node s = new GraphNode("s");
+ Node t = new GraphNode("t");
+ Node u = new GraphNode("u");
+ Node v = new GraphNode("v");
+ Node y = new GraphNode("y");
+
+ Dag dag = new Dag();
+ dag.addNode(x);
+ dag.addNode(r);
+ dag.addNode(s);
+ dag.addNode(t);
+ dag.addNode(u);
+ dag.addNode(v);
+ dag.addNode(y);
+
+ // Edges: x -> r -> s -> t <- u <- v -> y
+ dag.addDirectedEdge(x, r);
+ dag.addDirectedEdge(r, s);
+ dag.addDirectedEdge(s, t);
+ dag.addDirectedEdge(u, t);
+ dag.addDirectedEdge(v, u);
+ dag.addDirectedEdge(v, y);
+
+ // Check d-separations
+ assertFalse(Utils.dSeparated(dag, x, r));
+ assertFalse(Utils.dSeparated(dag, x, s));
+ assertFalse(Utils.dSeparated(dag, x, t));
+
+ assertTrue(Utils.dSeparated(dag, x, u));
+ assertTrue(Utils.dSeparated(dag, x, v));
+ assertTrue(Utils.dSeparated(dag, x, y));
+
+ assertFalse(Utils.dSeparated(dag, u, v));
+ assertFalse(Utils.dSeparated(dag, u, y));
+ }
+
+ public void testDseparationRule2(){
+ // Scenarios taken from https://yuyangyy.medium.com/understand-d-separation-471f9aada503
+ // Scenario 2: Z is non-empty, and the colliders don't belong to Z or have no children in Z.
+ // Rule 2: If there exists a path from x to y and none of the nodes on the path belongs to Z, then x and y are not d-separated.
+ Node x = new GraphNode("x");
+ Node r = new GraphNode("r");
+ Node s = new GraphNode("s");
+ Node t = new GraphNode("t");
+ Node u = new GraphNode("u");
+ Node v = new GraphNode("v");
+ Node y = new GraphNode("y");
+
+ Dag dag = new Dag();
+ dag.addNode(x);
+ dag.addNode(r);
+ dag.addNode(s);
+ dag.addNode(t);
+ dag.addNode(u);
+ dag.addNode(v);
+ dag.addNode(y);
+
+ // Edges: x -> r -> s -> t <- u <- v -> y
+ dag.addDirectedEdge(x, r);
+ dag.addDirectedEdge(r, s);
+ dag.addDirectedEdge(s, t);
+ dag.addDirectedEdge(u, t);
+ dag.addDirectedEdge(v, u);
+ dag.addDirectedEdge(v, y);
+
+ // Creating a conditioning set Z that does not include any colliders
+ List Z = new ArrayList<>();
+ Z.add(r);
+ Z.add(v);
+
+ // Check d-separations
+ // Node x
+ assertTrue(Utils.dSeparated(dag, x, r, Z));
+ assertTrue(Utils.dSeparated(dag, x, s, Z));
+ assertTrue(Utils.dSeparated(dag, x, t, Z));
+ assertTrue(Utils.dSeparated(dag, x, u, Z));
+ assertTrue(Utils.dSeparated(dag, x, v, Z));
+ assertTrue(Utils.dSeparated(dag, x, y, Z));
+
+ // Node r
+ assertTrue(Utils.dSeparated(dag, r, s, Z));
+ assertTrue(Utils.dSeparated(dag, r, t, Z));
+ assertTrue(Utils.dSeparated(dag, r, u, Z));
+ assertTrue(Utils.dSeparated(dag, r, v, Z));
+ assertTrue(Utils.dSeparated(dag, r, y, Z));
+
+ // Node s
+ assertFalse(Utils.dSeparated(dag, s, t, Z)); // No node on the path belongs to Z, so d-separated
+ assertTrue(Utils.dSeparated(dag, s, u, Z));
+ assertTrue(Utils.dSeparated(dag, s, v, Z));
+ assertTrue(Utils.dSeparated(dag, s, y, Z));
+
+ // Node t
+ assertFalse(Utils.dSeparated(dag, t, u, Z)); // No node on the path belongs to Z, so d-separated
+ assertTrue(Utils.dSeparated(dag, t, v, Z));
+ assertTrue(Utils.dSeparated(dag, t, y, Z));
+
+ // Node u
+ assertTrue(Utils.dSeparated(dag, u, v, Z));
+ assertTrue(Utils.dSeparated(dag, u, y, Z));
+
+ // Node v
+ assertTrue(Utils.dSeparated(dag, v, y, Z));
+
+
+ }
+
+ @Test
+ public void testDseparationRule3(){
+ // Scenarios taken from https://yuyangyy.medium.com/understand-d-separation-471f9aada503
+ // Scenario 3: Z is non-empty, and there are colliders either inside Z or have children in Z.
+ // Rule 3: For colliders that fall inside Z or have children in Z, they are no longer seen as colliders.
+
+ Node x = new GraphNode("x");
+ Node r = new GraphNode("r");
+ Node s = new GraphNode("s");
+ Node t = new GraphNode("t");
+ Node u = new GraphNode("u");
+ Node v = new GraphNode("v");
+ Node y = new GraphNode("y");
+ Node w = new GraphNode("w");
+ Node p = new GraphNode("p");
+ Node q = new GraphNode("q");
+
+ Dag dag = new Dag();
+ dag.addNode(x);
+ dag.addNode(r);
+ dag.addNode(s);
+ dag.addNode(t);
+ dag.addNode(u);
+ dag.addNode(v);
+ dag.addNode(y);
+ dag.addNode(w);
+ dag.addNode(p);
+ dag.addNode(q);
+
+ // Edges: x -> r -> s -> t <- u <- v -> y + r->w, t->p, v->q
+ dag.addDirectedEdge(x, r);
+ dag.addDirectedEdge(r, s);
+ dag.addDirectedEdge(s, t);
+ dag.addDirectedEdge(u, t);
+ dag.addDirectedEdge(v, u);
+ dag.addDirectedEdge(v, y);
+ dag.addDirectedEdge(r, w);
+ dag.addDirectedEdge(t, p);
+ dag.addDirectedEdge(v, q);
+
+ // Creating a conditioning set Z that includes colliders and their children
+ List Z = new ArrayList<>();
+ Z.add(r);
+ Z.add(p);
+
+ // Check d-separations
+ // Node x
+ assertTrue(Utils.dSeparated(dag, x, s, Z));
+ assertTrue(Utils.dSeparated(dag, x, t, Z));
+ assertTrue(Utils.dSeparated(dag, x, u, Z));
+ assertTrue(Utils.dSeparated(dag, x, v, Z));
+ assertTrue(Utils.dSeparated(dag, x, y, Z));
+ assertTrue(Utils.dSeparated(dag, x, w, Z));
+ assertTrue(Utils.dSeparated(dag, x, q, Z));
+
+ // Node w
+ assertTrue(Utils.dSeparated(dag, w, s, Z));
+ assertTrue(Utils.dSeparated(dag, w, t, Z));
+ assertTrue(Utils.dSeparated(dag, w, u, Z));
+ assertTrue(Utils.dSeparated(dag, w, v, Z));
+ assertTrue(Utils.dSeparated(dag, w, y, Z));
+ assertTrue(Utils.dSeparated(dag, w, q, Z));
+
+ // Node s
+ assertFalse(Utils.dSeparated(dag, s, t, Z));
+ assertFalse(Utils.dSeparated(dag, s, u, Z));
+ assertFalse(Utils.dSeparated(dag, s, v, Z));
+ assertFalse(Utils.dSeparated(dag, s, q, Z));
+ assertFalse(Utils.dSeparated(dag, s, y, Z));
+
+ // Node t
+ assertFalse(Utils.dSeparated(dag, t, u, Z));
+ assertFalse(Utils.dSeparated(dag, t, v, Z));
+ assertFalse(Utils.dSeparated(dag, t, y, Z));
+ assertFalse(Utils.dSeparated(dag, t, q, Z));
+
+ // Node u
+ assertFalse(Utils.dSeparated(dag, u, v, Z));
+ assertFalse(Utils.dSeparated(dag, u, y, Z));
+ assertFalse(Utils.dSeparated(dag, u, q, Z));
+
+ // Node v
+ assertFalse(Utils.dSeparated(dag, v, y, Z));
+ assertFalse(Utils.dSeparated(dag, v, q, Z));
+
+ // Node y
+ assertFalse(Utils.dSeparated(dag, y, q, Z));
+
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java
similarity index 53%
rename from src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
rename to src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java
index 50657cf..c386850 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/UtilsTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java
@@ -1,6 +1,5 @@
package es.uclm.i3a.simd.consensusBN;
-import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -8,120 +7,17 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.Test;
-import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
-import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
-public class UtilsTest {
- // d-separation tests
-
- private Node node(String name) {
- return new GraphNode(name);
- }
-
- private Dag createDag(Edge... edges) {
- Dag dag = new Dag();
- for (Edge edge : edges) {
- dag.addNode(edge.getNode1());
- dag.addNode(edge.getNode2());
- dag.addDirectedEdge(edge.getNode1(), edge.getNode2());
- }
- return dag;
- }
-
- @Test
- public void testDirectConnection() {
- Node A = node("A"), B = node("B");
- Dag dag = createDag(Edges.directedEdge(A, B));
- List conditioning = Collections.emptyList();
- assertFalse(Utils.dSeparated(dag, A, B, conditioning));
- }
-
- @Test
- public void testChainNoCondition() {
- Node A = node("A"), B = node("B"), C = node("C");
- Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
- List conditioning = Collections.emptyList();
- assertFalse(Utils.dSeparated(dag, A, C, conditioning));
- }
-
- @Test
- public void testChainWithCondition() {
- Node A = node("A"), B = node("B"), C = node("C");
- Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
-
- assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
- }
-
- @Test
- public void testColliderNoCondition() {
- Node A = node("A"), B = node("B"), C = node("C");
- Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B));
- List conditioning = Collections.emptyList();
- assertTrue(Utils.dSeparated(dag, A, C, conditioning));
- }
-
- @Test
- public void testColliderConditionedOnCollider() {
- Node A = node("A"), B = node("B"), C = node("C");
- Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B));
-
- assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
- }
-
- @Test
- public void testDivergingNoCondition() {
- Node A = node("A"), B = node("B"), C = node("C");
- Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C));
- List conditioning = Collections.emptyList();
- assertFalse(Utils.dSeparated(dag, A, C, conditioning));
- }
-
- @Test
- public void testDivergingConditionedOnCommonParent() {
- Node A = node("A"), B = node("B"), C = node("C");
- Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C));
-
- assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
- }
-
- @Test
- public void testColliderConditionedOnDescendant() {
- Node A = node("A"), B = node("B"), C = node("C"), D = node("D");
- Dag dag = createDag(
- Edges.directedEdge(A, B),
- Edges.directedEdge(C, B),
- Edges.directedEdge(B, D)
- );
-
- assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(D)));
- }
-
- // Asegurar que esto funciona correctamente!!!!!!!
- @Test
- public void testSimmetryBetweenXandY() {
- Node A = node("A"), B = node("B"), C = node("C"), D = node("D");
- Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, A), Edges.directedEdge(D, A));
- List conditioning = Collections.emptyList();
-
- // No colliders between A and B, so they are not d-separated
- assertFalse(Utils.dSeparated(dag, A, B, conditioning));
- assertFalse(Utils.dSeparated(dag, B, A, conditioning));
-
- // C->A and D->A makes A a collider for C and D, and therefore C and D are d-separated from each other
- assertTrue(Utils.dSeparated(dag, C, D, conditioning));
- assertTrue(Utils.dSeparated(dag, D, C, conditioning));
-
- }
-
+public class FindNaYXTest {
+
//find naYX tests
-
@Test
public void testFindNaYX_singleUndirectedCommonNeighbor() {
Graph graph = new EdgeListGraph();
@@ -240,4 +136,5 @@ public void testFindNaYX_noCommonNeighbors() {
List result = Utils.findNaYX(x, y, graph);
assertTrue(result.isEmpty());
}
+
}
From 9966f77fe9bbb311f6cfc7acfb97f2856e62f82a Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Fri, 18 Jul 2025 14:47:38 +0200
Subject: [PATCH 14/32] Cleaning and Testing PairWiseConsensusBES
---
.../i3a/simd/consensusBN/ConsensusBES.java | 4 +-
...HierarchicalAgglomerativeClustererBNs.java | 10 +-
.../consensusBN/PairWiseConsensusBES.java | 148 +++++++++++++++---
.../simd/consensusBN/ConsensusBESTest.java | 6 +-
.../consensusBN/PairWiseConsensusBESTest.java | 105 +++++++++++++
5 files changed, 237 insertions(+), 36 deletions(-)
create mode 100644 src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
index c893d85..83e7ab0 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java
@@ -105,7 +105,7 @@ public void fusion(){
* This method retrieves the final fused DAG, which represents the optimal fusion of the input DAGs.
* @return the resulting output DAG after the fusion process.
*/
- public Dag getFusion(){
+ public Dag getFusionDag(){
return this.outputDag;
}
@@ -114,7 +114,7 @@ public Dag getFusion(){
* @return
*/
public List getOrderFusion(){
- return this.getFusion().paths().getValidOrder(this.getFusion().getNodes(),true);
+ return this.getFusionDag().paths().getValidOrder(this.getFusionDag().getNodes(),true);
}
/**
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java b/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java
index 1965f28..dee7bb2 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java
@@ -179,7 +179,7 @@ public Dag computeConsensusDag(int level){
for(int j = 0; j< this.setOfBNs.size(); j++){
for(int k = 0; k< this.setOfBNs.size(); k++){
if(this.clustersIndexes[cluster][j][level]&&this.clustersIndexes[cluster][k][level]&&(j!=k)){
- distance[j]+=this.initialpairwisedistance[j][k].getHammingDistance();//getNumberOfInsertedEdges();
+ distance[j]+=this.initialpairwisedistance[j][k].calculateHammingDistance();//getNumberOfInsertedEdges();
}
}
if(clustersIndexes[cluster][j][level]&&distance[j] 0 && (this.maxSize >= (this.clusterCardinalities[o1] + this.clusterCardinalities[o2]))|| level == 0){
PairWiseConsensusBES pairBNs = new PairWiseConsensusBES(this.clustersBN[o1][level],this.clustersBN[o2][level]);
PairWiseConsensusBES pairDag= (PairWiseConsensusBES) pairBNs;
- pairDag.getFusion();
+ pairDag.fusion();
return pairBNs;
}else if(this.maxSize == 0){
PairWiseConsensusBES pairBNs = new PairWiseConsensusBES(this.clustersBN[o1][level],this.clustersBN[o2][level]);
PairWiseConsensusBES pairDag= (PairWiseConsensusBES) pairBNs;
- pairDag.getFusion();
+ pairDag.fusion();
if((pairDag.getDagFusion().getNumEdges())/this.averageNEdges <= this.maxComplexityCluster|| level == 0)
return pairBNs;
else return null;
@@ -326,7 +326,7 @@ private Pair findMostSimilarClusters() {
PairWiseConsensusBES inCluster = dissimilarityMatrix[cluster][neighbor];
if(inCluster!= null){
double complexity = 0.0;
- complexity = (float) inCluster.getHammingDistance();//getNumberOfInsertedEdges();
+ complexity = (float) inCluster.calculateHammingDistance();//getNumberOfInsertedEdges();
if (indexUsed[neighbor]&&complexity setOfDags = new ArrayList();
- setOfDags.add(this.b1);
- setOfDags.add(this.b2);
- conBES = new ConsensusBES(setOfDags);
- conBES.fusion();
- this.numberOfInsertedEdges = conBES.getNumberOfInsertedEdges();
- this.numberOfUnionEdges = conBES.getUnion().getNumEdges();
- this.conDAG = conBES.getFusion();
+ /**
+ * Checks if the input DAGs are valid.
+ * Validity is determined by ensuring that the DAGs are not null, contain at least one node and one edge, and have the same set of nodes.
+ * If any of these conditions are not met, an IllegalArgumentException is thrown.
+ * @param firstDag first input DAG
+ * @param secondDag second input DAG
+ * @throws IllegalArgumentException if the input DAGs are not valid
+ */
+ private void checkInput(Dag firstDag, Dag secondDag) {
+ if (firstDag == null || secondDag == null) {
+ throw new IllegalArgumentException("Input DAGs cannot be null.");
+ }
+ if (firstDag.getNumNodes() == 0 || secondDag.getNumNodes() == 0) {
+ throw new IllegalArgumentException("Input DAGs must contain at least one node.");
+ }
+ if (firstDag.getNumEdges() == 0 || secondDag.getNumEdges() == 0) {
+ throw new IllegalArgumentException("Input DAGs must contain at least one edge.");
+ }
+ if (firstDag.getNodes().size() != secondDag.getNodes().size()) {
+ throw new IllegalArgumentException("Input DAGs must have the same number of nodes.");
+ }
+ if (!firstDag.getNodes().containsAll(secondDag.getNodes())) {
+ throw new IllegalArgumentException("Input DAGs must have the same set of nodes.");
+ }
}
-
+
+ /**
+ * Performs the fusion process by first applying the consensus union and then applying the Backward Equivalence Search.
+ */
+ public void fusion(){
+ // Creating a list of DAGs to be fused
+ ArrayList setOfDags = new ArrayList<>();
+ setOfDags.add(this.firstDag);
+ setOfDags.add(this.secondDag);
+ // Applying the ConsensusBES algorithm to fuse the DAGs
+ consensusBES = new ConsensusBES(setOfDags);
+ consensusBES.fusion();
+ // Retrieving the resulting DAG and the number of inserted edges
+ this.numberOfInsertedEdges = consensusBES.getNumberOfInsertedEdges();
+ this.numberOfUnionEdges = consensusBES.getUnion().getNumEdges();
+ this.consensusDAG = consensusBES.getFusionDag();
+ }
+
+ /**
+ * Returns the number of edges inserted during the fusion process.
+ * This method retrieves the number of edges that were added to the consensus DAG during the fusion process.
+ * It is useful for understanding how many edges were introduced in the consensus DAG compared to the original input DAGs.
+ * @return
+ */
public int getNumberOfInsertedEdges(){
return this.numberOfInsertedEdges;
}
+ /**
+ * Returns the number of edges in the union DAG after the consensus union process.
+ * This method retrieves the number of edges that were present in the union DAG after merging the transformed input DAGs.
+ * It is useful for understanding the size of the union DAG before applying the Backward Equivalence Search.
+ * This number can be used to compare with the number of edges in the final consensus DAG after the Backward Equivalence Search.
+ *
+ * @see ConsensusBES#getUnion()
+ * @see ConsensusBES#getNumberOfInsertedEdges()
+ * @return
+ */
public int getNumberOfUnionEdges(){
return this.numberOfUnionEdges;
}
- public int getHammingDistance(){
- if(this.conDAG==null) this.getFusion();
+ /**
+ * Calculates the Hamming distance between the optimum fusion DAG and the original input DAGs.
+ * @return The Hamming distance between the fused DAG and the original input DAGs.
+ */
+ public int calculateHammingDistance(){
+ if(this.consensusDAG==null) this.fusion();
int distance = 0;
- for(Edge ed: this.conDAG.getEdges()){
+ for(Edge ed: this.consensusDAG.getEdges()){
Node tail = ed.getNode1();
Node head = ed.getNode2();
- for(Dag g: conBES.getTransformedDags()){
+ for(Dag g: consensusBES.getTransformedDags()){
Edge edge1 = g.getEdge(tail, head);
Edge edge2 = g.getEdge(head, tail);
if(edge1 == null && edge2==null) distance++;
@@ -57,14 +142,25 @@ public int getHammingDistance(){
return distance+this.getNumberOfInsertedEdges();
}
+ /**
+ * Returns the resulting consensus DAG after applying the fusion process.
+ * This method retrieves the final fused DAG, which represents the optimal fusion of the input DAGs.
+ * It is useful for obtaining the consensus structure after the fusion process has been completed.
+ *
+ * @see ConsensusBES#getFusionDag()
+ * @return
+ */
public Dag getDagFusion(){
- return this.conDAG;
+ return this.consensusDAG;
}
+ /**
+ * Runs the fusion process in a thread, performing the consensus union and the Backward Equivalence Search with D-separation.
+ */
@Override
public void run() {
- this.getFusion();
+ this.fusion();
}
}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
index da9a321..528b4d8 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
@@ -111,7 +111,7 @@ public void testRandomBNFusion(){
ConsensusBES conDag = new ConsensusBES(setOfDags.setOfRandomDags);
conDag.fusion();
- Dag besDag = conDag.getFusion();
+ Dag besDag = conDag.getFusionDag();
Dag unionDag = conDag.getUnion();
ConsensusUnion consensusUnion = conDag.getConsensusUnion();
int totalNumberOfInsertedEdges = conDag.getNumberOfInsertedEdges();
@@ -131,7 +131,7 @@ void testFusionProducesDag() {
ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
fusionAlgorithm.fusion();
- Dag result = fusionAlgorithm.getFusion();
+ Dag result = fusionAlgorithm.getFusionDag();
assertNotNull(result, "El DAG de salida no debe ser null.");
assertFalse(result.paths().existsDirectedCycle(), "El DAG resultante no debe tener ciclos.");
}
@@ -183,7 +183,7 @@ void testThreadExecutionWithRunMethod() {
fail("El hilo fue interrumpido.");
}
- assertNotNull(fusionAlgorithm.getFusion(), "El DAG resultante debe existir tras ejecutar run().");
+ assertNotNull(fusionAlgorithm.getFusionDag(), "El DAG resultante debe existir tras ejecutar run().");
}
}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java
new file mode 100644
index 0000000..0c4c1cd
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java
@@ -0,0 +1,105 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class PairWiseConsensusBESTest {
+
+ private Node A, B, C;
+ private Dag dag1, dag2;
+
+ @BeforeEach
+ public void setup() {
+ A = new GraphNode("A");
+ B = new GraphNode("B");
+ C = new GraphNode("C");
+
+ dag1 = new Dag();
+ dag1.addNode(A);
+ dag1.addNode(B);
+ dag1.addNode(C);
+ dag1.addDirectedEdge(A, B);
+ dag1.addDirectedEdge(B, C);
+
+ dag2 = new Dag();
+ dag2.addNode(A);
+ dag2.addNode(B);
+ dag2.addNode(C);
+ dag2.addDirectedEdge(A, C);
+ dag2.addDirectedEdge(C, B);
+ }
+
+ @Test
+ public void testFusionCreatesNonNullDag() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.fusion();
+
+ Dag fusion = pwc.getDagFusion();
+ assertNotNull(fusion, "The fusion DAG should not be null");
+ }
+
+ @Test
+ public void testGetNumberOfInsertedEdges() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.fusion();
+
+ int inserted = pwc.getNumberOfInsertedEdges();
+ assertTrue(inserted >= 0, "Inserted edges should be >= 0");
+ }
+
+ @Test
+ public void testGetNumberOfUnionEdges() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.fusion();
+
+ PairWiseConsensusBES samePwc = new PairWiseConsensusBES(dag1, dag1);
+ samePwc.fusion();
+
+ int unionEdges = pwc.getNumberOfUnionEdges();
+ assertTrue(unionEdges > 0, "Union should contain some edges");
+
+ int sameUnionEdges = samePwc.getNumberOfUnionEdges();
+ assertTrue(sameUnionEdges == dag1.getNumEdges(), "Union edges should match the number of edges in a single DAG");
+ }
+
+ @Test
+ public void testGetHammingDistance() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ int distance = pwc.calculateHammingDistance();
+ PairWiseConsensusBES samePwc = new PairWiseConsensusBES(dag1, dag1);
+ int sameDistance = samePwc.calculateHammingDistance();
+
+ assertTrue(distance >= 0, "Hamming distance should be >= 0");
+ assertTrue(sameDistance == 0, "Hamming distance for identical DAGs should be 0");
+ }
+
+ @Test
+ public void testRunCallsGetFusion() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.run();
+
+ assertNotNull(pwc.getDagFusion(), "Fusion DAG should be created after run()");
+ assertTrue(pwc.getNumberOfUnionEdges() > 0, "Union edges should be computed");
+ }
+
+ @Test
+ public void testFusionIsConsistentWithInput() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.fusion();
+ Dag fusion = pwc.getDagFusion();
+
+ Set edges = fusion.getEdges();
+ assertNotNull(edges);
+ assertFalse(edges.isEmpty(), "Fusion DAG should contain edges");
+ }
+}
From a5626e41b963db04c668834e6a5463fdeb98e146 Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Fri, 18 Jul 2025 14:53:56 +0200
Subject: [PATCH 15/32] Improving coverage of BetaToAlpha
---
.../i3a/simd/consensusBN/BetaToAlphaTest.java | 18 +++++++++++++++++-
1 file changed, 17 insertions(+), 1 deletion(-)
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java
index 7d421ad..06504f7 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java
@@ -1,9 +1,9 @@
package es.uclm.i3a.simd.consensusBN;
import java.util.ArrayList;
-import java.util.List;
import java.util.Arrays;
import java.util.HashSet;
+import java.util.List;
import java.util.Random;
import java.util.Set;
@@ -90,4 +90,20 @@ void testComputeAlphaHashBuildsCorrectMap() {
assertEquals(i, (bta.getAlphaHash()).get(node));
}
}
+
+
+ @Test
+ public void setterAndGetterTest(){
+ ArrayList firstOrder = new ArrayList<>(Arrays.asList(a, b, c));
+ BetaToAlpha b2Alpha = new BetaToAlpha(dag, firstOrder);
+ List newOrder = new ArrayList<>(Arrays.asList(c, b, a));
+ b2Alpha.setAlphaOrder(newOrder);
+
+ List order = b2Alpha.getAlphaOrder();
+ assertNotNull(order);
+ assertEquals(3, order.size());
+ assertTrue(order.contains(c));
+ assertTrue(order.contains(b));
+ assertTrue(order.contains(a));
+ }
}
From 15220698bcf606f4ca106b6f5dc43b82c404fcdb Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Mon, 21 Jul 2025 09:31:50 +0200
Subject: [PATCH 16/32] Removing SubSet for HashSet
---
.../BackwardEquivalenceSearchDSep.java | 10 +++---
.../consensusBN/HeuristicConsensusBES.java | 4 +--
.../uclm/i3a/simd/consensusBN/PowerSet.java | 35 ++++++++++---------
.../es/uclm/i3a/simd/consensusBN/SubSet.java | 24 -------------
4 files changed, 25 insertions(+), 48 deletions(-)
delete mode 100644 src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
index 4b3334f..1764b8b 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java
@@ -128,7 +128,7 @@ public Dag applyBackwardEliminationWithDSeparation(){
PowerSet hSubsets= PowerSetFabric.getPowerSet(candidateTail,candidateHead,hNeighbors);
while(hSubsets.hasMoreElements()) {
- SubSet hSubset=hSubsets.nextElement();
+ HashSet hSubset=hSubsets.nextElement();
// Checking if {naYXH} \ {hSubset} is a clique
List naYXH = findNaYX(candidateTail, candidateHead, graph);
@@ -226,8 +226,8 @@ private EdgeCandidate calculateBestCandidateEdge(List edges, double score)
PowerSet hSubsets= PowerSetFabric.getPowerSet(candidateTail,candidateHead,hNeighbors);
while(hSubsets.hasMoreElements()) {
- // Getting a subset of hNeighbors
- SubSet hSubset=hSubsets.nextElement();
+ // Getting a HashSet of hNeighbors
+ HashSet hSubset=hSubsets.nextElement();
// Checking if {naYXH} \ {hSubset} is a clique
List naYXH = Utils.findNaYX(candidateTail, candidateHead, graph);
@@ -374,7 +374,7 @@ private static List getHNeighbors(Node x, Node y, Graph graph) {
/**
* Applies the delete operation from Chickering 2002 for the edge x->y in the graph, and updates the edges
- * connecting x and y to the nodes in the provided subset. This is done to ensure that the same dependency structure is maintained
+ * connecting x and y to the nodes in the provided HashSet. This is done to ensure that the same dependency structure is maintained
* while removing the edge between x and y.
* @param tailNode The tail node of the edge to be deleted.
* @param headNode The head node of the edge to be deleted.
@@ -408,7 +408,7 @@ private static void delete(Node tailNode, Node headNode, Set subset, Graph
* @param graph The graph in which the change is being evaluated.
* @return The score resulting from deleting the edge, based on the given context.
*/
- private double deleteEval(Node x, Node y, SubSet conditioningSet, Graph graph){
+ private double deleteEval(Node x, Node y, HashSet conditioningSet, Graph graph){
// Setup the conditioning set for d-separation by removing the conditioning nodes from the naYX set, adding the parents of y and removing x.
Set finalConditioningSet = new HashSet<>(Utils.findNaYX(x, y, graph));
finalConditioningSet.removeAll(conditioningSet);
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java
index 8d2a496..cbf672f 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java
@@ -115,7 +115,7 @@ public void fusion(){
// List> hSubsets = powerSet(hNeighbors);
PowerSet hSubsets= PowerSetFabric.getPowerSet(_x,_y,hNeighbors);
while(hSubsets.hasMoreElements()) {
- SubSet hSubset=hSubsets.nextElement();
+ HashSet hSubset=hSubsets.nextElement();
if(hSubset.size() > maxSize) break;
double deleteEval = deleteEval(_x, _y, hSubset, graph);
if (!(deleteEval >= this.percentage)) deleteEval = 0.0;
@@ -248,7 +248,7 @@ private static List getHNeighbors(Node x, Node y, Graph graph) {
}
- double deleteEval(Node x, Node y, SubSet h, Graph graph){
+ double deleteEval(Node x, Node y, HashSet h, Graph graph){
Set set1 = new HashSet(Utils.findNaYX(x, y, graph));
set1.removeAll(h);
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java b/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java
index ebfe47d..cfe55ff 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java
@@ -3,29 +3,30 @@
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import edu.cmu.tetrad.graph.Node;
-public class PowerSet implements Enumeration {
+public class PowerSet implements Enumeration> {
List nodes;
- private List subSets;
+ private List> subSets;
private int index;
private int[] lista;
- private HashMap hashMap;
+ private HashMap> hashMap;
PowerSet(List nodes,int k) {
if(nodes.size()<=k)
k=nodes.size();
this.nodes=nodes;
- subSets = new ArrayList();
+ subSets = new ArrayList>();
index=0;
- hashMap=new HashMap();
+ hashMap=new HashMap>();
lista=ListFabric.getList(nodes.size());
for (int i : lista) {
- SubSet newSubSet = new SubSet();
+ HashSet newSubSet = new HashSet();
String selection = Integer.toBinaryString(i);
for (int j = selection.length() - 1; j >= 0; j--) {
if (selection.charAt(j) == '1') {
@@ -44,12 +45,12 @@ public class PowerSet implements Enumeration {
if(nodes.size()>maxPow)
maxPow=nodes.size();
this.nodes=nodes;
- subSets = new ArrayList();
+ subSets = new ArrayList>();
index=0;
- hashMap=new HashMap();
+ hashMap=new HashMap>();
lista=ListFabric.getList(nodes.size());
for (int i : lista) {
- SubSet newSubSet = new SubSet();
+ HashSet newSubSet = new HashSet();
String selection = Integer.toBinaryString(i);
for (int j = selection.length() - 1; j >= 0; j--) {
if (selection.charAt(j) == '1') {
@@ -65,7 +66,7 @@ public boolean hasMoreElements() {
return index nextElement() {
return subSets.get(index++);
}
@@ -83,9 +84,9 @@ public static long maxPowerSetSize() {
// for(int i=0;i.TEST_TRUE;
// else
-// hashMap.get(lista[i]).firstTest=SubSet.TEST_FALSE;
+// hashMap.get(lista[i]).firstTest=HashSet.TEST_FALSE;
// }
// }
// }
@@ -95,9 +96,9 @@ public static long maxPowerSetSize() {
// for(int i=0;i.TEST_TRUE;
// else
-// hashMap.get(lista[i]).secondTest=SubSet.TEST_FALSE;
+// hashMap.get(lista[i]).secondTest=HashSet.TEST_FALSE;
// }
// }
// }
@@ -105,11 +106,11 @@ public static long maxPowerSetSize() {
// public void reset(boolean isFordwardSearch) {
// index=0;
// for(int i=0;i aux=subSets.get(i);
// if(isFordwardSearch)
-// aux.secondTest=SubSet.TEST_NOT_EVALUATED;
+// aux.secondTest=HashSet.TEST_NOT_EVALUATED;
// else
-// aux.firstTest=SubSet.TEST_NOT_EVALUATED;
+// aux.firstTest=HashSet.TEST_NOT_EVALUATED;
// }
// }
}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java b/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java
deleted file mode 100644
index be6018c..0000000
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package es.uclm.i3a.simd.consensusBN;
-
-import java.util.HashSet;
-
-import edu.cmu.tetrad.graph.Node;
-
-public class SubSet extends HashSet {
-
- private static final long serialVersionUID = 4569314863278L;
- public static final int TEST_NOT_EVALUATED=0;
- public static final int TEST_TRUE=1;
- public static final int TEST_FALSE=-1;
-
- public int firstTest=TEST_NOT_EVALUATED;
- public int secondTest=TEST_NOT_EVALUATED;
-
- public SubSet() {
- super();
- }
-
- public SubSet(SubSet other) {
- super(other);
- }
-}
\ No newline at end of file
From 8f8974b0236a4cfe502ecd65f4a1206a6f5a93fd Mon Sep 17 00:00:00 2001
From: JLaborda <15078416+JLaborda@users.noreply.github.com>
Date: Mon, 21 Jul 2025 09:42:48 +0200
Subject: [PATCH 17/32] Improving coverage of TransformDags
---
.../simd/consensusBN/TransformDagsTest.java | 37 ++++++++++++++++++-
1 file changed, 35 insertions(+), 2 deletions(-)
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java
index b399ed3..86cb1d6 100644
--- a/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java
@@ -3,6 +3,7 @@
import java.util.ArrayList;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.BeforeEach;
@@ -67,11 +68,42 @@ public void setUp() {
}
@Test
- public void testConstructorInitializesCorrectly() {
+ public void testConstructorGettersAndSetters() {
TransformDags transformer = new TransformDags(inputDags, alpha);
-
+ ArrayList empytBetas = new ArrayList<>();
+
assertNotNull(transformer);
assertEquals(0, transformer.getNumberOfInsertedEdges());
+
+ // GetSetOfOutputDags
+ ArrayList outputDags = transformer.getSetOfOutputDags();
+ assertNotNull(outputDags);
+ assertTrue(outputDags.isEmpty());
+ transformer.transform();
+ outputDags = transformer.getSetOfOutputDags();
+ assertNotNull(outputDags);
+ assertEquals(inputDags.size(), outputDags.size());
+ assertNotEquals(inputDags, outputDags);
+
+ // Testing setTransformers and getTransformers
+ ArrayList betas = transformer.getTransformers();
+ assertNotNull(betas);
+ assertTrue(!betas.isEmpty());
+ transformer.setTransformers(empytBetas);
+ assertEquals(empytBetas, transformer.getTransformers());
+ assertTrue(transformer.getTransformers().isEmpty());
+
+ // GetAlphaOrder
+ ArrayList