(graph.getAdjacentNodes(y));
- naYX.retainAll(graph.getAdjacentNodes(x));
-
- for (int i = naYX.size()-1; i >= 0; i--) {
- Node z = naYX.get(i);
- Edge edge = graph.getEdge(y, z);
-
- if (!Edges.isUndirectedEdge(edge)) {
- naYX.remove(z);
- }
- }
-
- return naYX;
- }
-
- public Dag getFusion(){
-
- return this.outputDag;
+/**
+ * The {@code HeuristicConsensusBES} class extends the {@code ConsensusBES} class
+ * to implement a heuristic approach for backward equivalence search in directed acyclic graphs (DAGs).
+ *
+ * This class is designed to perform a consensus structure learning algorithm
+ * by applying a heuristic method for backward equivalence search (BES) with D-separation.
+ * It allows for a more efficient search by limiting the size of the conditioning set and applying a
+ * percentage threshold for determining d-separation between nodes.
+ *
+ * The constructor initializes the class with a list of input DAGs, a maximum size for the
+ * conditioning set, and a percentage threshold for d-separation.
+ * The `fusion` method applies the consensus union to compute a consensus DAG from the input DAGs,
+ * and then applies the backward equivalence search with D-separation with the specified parameters to refine the graph.
+ */
+public class HeuristicConsensusBES extends ConsensusBES{
+
+ private final int maxSize;
+ private final double percentage;
+
+ /**
+ * Constructor for HeuristicConsensusBES.
+ * This class extends ConsensusBES to implement a heuristic approach
+ * for backward equivalence search in directed acyclic graphs (DAGs).
+ *
+ * @param dags the list of input DAGs to be fused.
+ * @param maxSize the maximum size of the conditioning set for d-separation checks.
+ * @param percentage the percentage/threshold for determining d-separation between nodes.
+ */
+ public HeuristicConsensusBES(ArrayList dags, int maxSize, double percentage) {
+ super(dags);
+ this.maxSize = maxSize;
+ this.percentage = percentage;
}
-
-
- public static void main(String args[]) {
-
-
- System.out.println("Grafos de Partida: ");
-
- // (seed, n. variables, n egdes aprox, n.dags, mutation)
- RandomBN setOfBNs = new RandomBN(0, Integer.parseInt(args[0]), Integer.parseInt(args[1]),
- Integer.parseInt(args[2]),Integer.parseInt(args[3]));
- setOfBNs.setMaxInDegree(3);
- setOfBNs.setMaxOutDegree(3);
- setOfBNs.generate();
-
- for(int i = 0; i< setOfBNs.setOfRandomBNs.size(); i++){
- System.out.println("red de partida: "+i);
- System.out.println("---------------------");
- System.out.println("Grafo: ");
- System.out.println(setOfBNs.setOfRandomDags.get(i).getDegree()+" "+ setOfBNs.setOfRandomDags.get(i).getNumEdges());
-// System.out.println("Probabilidades: ");
-// System.out.println(setOfBNs.setOfRandomBNs.get(i).toString());
-// System.out.println("_____________________");
-// System.out.println("Datos Simulados");
-// System.out.println(setOfBNs.setOfSampledBNs.get(i).toString());
-
-
- }
- //
- HeuristicConsensusBES conDag= null;
- conDag = new HeuristicConsensusBES(setOfBNs.setOfRandomDags,1.0);
- conDag.fusion();
- Dag g = conDag.getFusion();
- System.out.println("grafo de partida Union: "+conDag.union.getDegree()+" "+ conDag.union.getNumEdges());
- System.out.println("grafo consenso: "+ g.getDegree() +" Complejidad de la Fusion: "+ conDag.getNumberOfInsertedEdges()+ " "+ conDag.outputDag.getNumEdges());
+ /**
+ * Executes the heuristic consensus backward equivalence search.
+ * This method first applies the ConsensusUnion to compute a consensus DAG from the input DAGs,
+ * and then applies the Backward Equivalence Search with D-separation to refine the graph, setting the maxSize and percentage parameters for an heuristic search.
+ * The resulting output DAG is stored in the outputDag attribute.
+ */
+ @Override
+ public void fusion(){
+ // 1. Apply ConsensusUnion
+ consensusUnion();
+ // 2. Apply Heuristic BES with D-separation
+ BackwardEquivalenceSearchDSep bes = new BackwardEquivalenceSearchDSep(this.getUnion(), this.getInputDags(), this.getTransformedDags());
+ bes.setMaxSize(maxSize);
+ bes.setPercentage(percentage);
+ this.outputDag = bes.applyBackwardEliminationWithDSeparation();
+ // 3. Updating numberOfInsertedEdges
+ this.numberOfInsertedEdges -= bes.getNumberOfRemovedEdges();
}
}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVoting.java b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVoting.java
index 1abe261..f09ec62 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVoting.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVoting.java
@@ -6,48 +6,142 @@
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
+import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag;
+/**
+ * The {@code HeuristicConsensusMVoting} class implements a consensus structure learning algorithm
+ * based on a heuristic majority voting strategy.
+ *
+ * Given a collection of Directed Acyclic Graphs (DAGs), this class aggregates the structures
+ * by counting the frequency of directed edges across all graphs and selecting those
+ * that meet or exceed a specified threshold.
+ *
+ * The heuristic helps reduce noise by filtering out weakly supported edges, and it resolves
+ * conflicts (such as bidirectional edges) by applying majority rules or tie-breaking strategies.
+ *
+ * Typical use cases include combining results from different structure learning algorithms,
+ * bootstrapping runs, or expert-curated networks.
+ *
+ * The resulting output is a new DAG that aims to reflect the most consistent edges found
+ * across the input graphs.
+ *
+ *
Note: Input DAGs must share the same set of nodes for the consensus to be meaningful.
+ *
+ * Example usage:
+ *
{@code
+ * List inputGraphs = Arrays.asList(dag1, dag2, dag3);
+ * HeuristicConsensusMVoting consensus = new HeuristicConsensusMVoting(inputGraphs, threshold);
+ * Dag consensusDag = consensus.getConsensusGraph();
+ * }
+ *
+ */
public class HeuristicConsensusMVoting {
- ArrayList variables = null;
- Dag outputDag = null;
- ArrayList setOfdags = null;
- double percentage = 1.0;
- double [][] weight = null;
+ /**
+ * List of variables (nodes) in the consensus DAG.
+ * This list is derived from the nodes of the input DAGs.
+ */
+ private ArrayList variables = null;
+
+ /**
+ * The resulting output DAG after applying the heuristic consensus voting.
+ * This DAG contains the edges that were selected based on the majority voting strategy.
+ */
+ private Dag outputDag = null;
+
+ /**
+ * List of input DAGs used to compute the consensus.
+ * These DAGs are expected to have the same set of nodes.
+ */
+ private ArrayList setOfdags = null;
+
+ /**
+ * Percentage threshold for edge inclusion in the consensus DAG.
+ * An edge is included if it appears in at least this percentage of the input DAGs.
+ */
+ private double percentage = 1.0;
+
+ /**
+ * Weight matrix representing the frequency of edges between pairs of nodes.
+ * The weight[i][j] indicates how many times the edge from node i to node j appears in the input DAGs.
+ */
+ private double [][] weight = null;
+ /**
+ * Constructor for HeuristicConsensusMVoting.
+ * Initializes the variables, output DAG, input DAGs, and weight matrix.
+ * @param setOfdags the list of input DAGs to be fused.
+ * @param percentage the percentage threshold for edge inclusion in the consensus DAG.
+ */
+ public HeuristicConsensusMVoting(ArrayList setOfdags, double percentage) {
+ this.variables = (ArrayList) setOfdags.get(0).getNodes();
+ this.outputDag = null;
+ this.setOfdags = setOfdags;
+ this.percentage = percentage;
+ this.weight = new double[this.variables.size()][this.variables.size()];
+ setup();
+ }
+ /**
+ * Sets up the HeuristicConsensusMVoting instance by validating the input DAGs
+ * and building the weight matrix based on the edges present in the input DAGs.
+ */
+ private void setup() {
+ // Ensuring that the input DAGs are not null and have the same set of nodes
+ validateInputDags();
+ buildWeightMatrix();
+ }
+
+ /**
+ * Validates the input DAGs to ensure they are not null and have the same set of nodes.
+ * Throws an IllegalArgumentException if any validation fails.
+ */
+ private void validateInputDags() {
+ if (this.setOfdags == null || this.setOfdags.isEmpty())
+ throw new IllegalArgumentException("Input DAGs cannot be null or empty.");
+ for(Dag dag : this.setOfdags) {
+ if (dag.getNodes().size() != this.variables.size()) {
+ throw new IllegalArgumentException("All input DAGs must have the same number of nodes.");
+ }
+ if (!dag.getNodes().containsAll(this.variables)) {
+ throw new IllegalArgumentException("All input DAGs must have the same set of nodes.");
+ }
+ }
+ }
-public HeuristicConsensusMVoting(ArrayList setOfdags, double percentage) {
- super();
- this.variables = (ArrayList) setOfdags.get(0).getNodes();
- this.outputDag = null;
- this.setOfdags = setOfdags;
- this.percentage = percentage;
- this.weight = new double[this.variables.size()][this.variables.size()];
- ArrayList pdags = new ArrayList();
+ /**
+ * Builds the weight matrix based on the edges present in the input DAGs.
+ * Each entry weight[i][j] is incremented for each directed edge from node i to node j
+ * across all input DAGs, normalized by the number of input DAGs
+ */
+ private void buildWeightMatrix() {
+ ArrayList pdags = new ArrayList<>();
for(Dag g: this.setOfdags){
- Graph pd = new EdgeListGraph(new LinkedList(g.getNodes()));
+ Graph graph = new EdgeListGraph(new LinkedList<>(g.getNodes()));
for(Edge e: g.getEdges()){
- pd.addEdge(e);
+ graph.addEdge(e);
}
- pdagToDag(pd);
- pdags.add(pd);
+ pdagToDag(graph);
+ pdags.add(graph);
}
-
+
for(Graph pd: pdags){
for(Edge e: pd.getEdges()){
- Node n1 = e.getNode1();
- Node n2 = e.getNode2();
if(e.isDirected()){
- if(e.getEndpoint1() == Endpoint.ARROW){
- this.weight[variables.indexOf(n2)][variables.indexOf(n1)]+= (double) (1.0/this.setOfdags.size());
- }else{
- this.weight[variables.indexOf(n1)][variables.indexOf(n2)]+= (double) (1.0/this.setOfdags.size());
- }
+ Node from = Edges.getDirectedEdgeTail(e);
+ Node to = Edges.getDirectedEdgeHead(e);
+ this.weight[variables.indexOf(from)][variables.indexOf(to)] += (double) (1.0/this.setOfdags.size());
+ // if(e.getEndpoint1() == Endpoint.ARROW){
+ // this.weight[variables.indexOf(n2)][variables.indexOf(n1)]+= (double) (1.0/this.setOfdags.size());
+ // }else{
+ // this.weight[variables.indexOf(n1)][variables.indexOf(n2)]+= (double) (1.0/this.setOfdags.size());
+ // }
}else{
+ Node n1 = e.getNode1();
+ Node n2 = e.getNode2();
this.weight[variables.indexOf(n2)][variables.indexOf(n1)]+= (double) (1.0/this.setOfdags.size());
this.weight[variables.indexOf(n1)][variables.indexOf(n2)]+= (double) (1.0/this.setOfdags.size());
}
@@ -55,36 +149,88 @@ public HeuristicConsensusMVoting(ArrayList setOfdags, double percentage) {
}
}
-public Dag fusion(){
-
- this.outputDag = new Dag(variables);
- boolean procced = true;
- int bestEdgei = 0;
- int bestEdgej = 0;
- double maxW = 0.0;
- while(procced){
- for(int i = 0; i= maxW){
- if((this.weight[i][j] > maxW) || ((this.weight[i][j]==maxW) && (Math.random()>0.5))){
- bestEdgei = i;
- bestEdgej = j;
- maxW = this.weight[i][j];
+ /**
+ * Performs the fusion of the input DAGs using a heuristic majority voting strategy.
+ * The method iteratively selects edges based on their weights and adds them to the output DAG
+ * until no more edges meet the specified percentage threshold.
+ * @return The resulting consensus DAG after applying the heuristic voting.
+ */
+ public Dag fusion(){
+ this.outputDag = new Dag(variables);
+
+ while(true){
+ int bestEdgei = -1; // Best edge head node index
+ int bestEdgej = -1; // Best edge tail node index
+ double maxW = 0.0; // Maximum weight found in the current iteration
+
+ // Find the best edge based on the weight matrix
+ for(int i = 0; i= maxW){
+ if((this.weight[i][j] > maxW) || ((this.weight[i][j]==maxW) && (Math.random()>0.5))){
+ bestEdgei = i;
+ bestEdgej = j;
+ maxW = this.weight[i][j];
+ }
}
}
- }
- if(maxW >= this.percentage){
- if(!this.outputDag.paths().existsDirectedPath(variables.get(bestEdgej), variables.get(bestEdgei))){
- this.outputDag.addEdge(new Edge(variables.get(bestEdgei),variables.get(bestEdgej),Endpoint.TAIL,Endpoint.ARROW));
- this.weight[bestEdgei][bestEdgej] = 0;
- }else this.weight[bestEdgei][bestEdgej] = 0;
- if(maxW==0) procced = false;
- maxW = 0.0;
- }else procced = false;
+ // Stop if no edge meets the threshold
+ if(bestEdgei == -1 || bestEdgej == -1 || maxW < percentage || maxW == 0.0)
+ break;
+
+ // Add edge if it doesn't introduce a cycle
+ Node from = variables.get(bestEdgei);
+ Node to = variables.get(bestEdgej);
+ if(!this.outputDag.paths().existsDirectedPath(to, from)){
+ this.outputDag.addEdge(new Edge(from,to,Endpoint.TAIL,Endpoint.ARROW));
+ }
+ // Mark the edge as used by setting its weight to zero
+ this.weight[bestEdgei][bestEdgej] = 0;
+ }
+
+ return this.outputDag;
+ }
+
+ /**
+ * Returns the nodes (variables) of the consensus DAG.
+ * @return A list of nodes representing the variables in the consensus DAG.
+ */
+ public ArrayList getVariables() {
+ return variables;
+ }
+
+ /**
+ * Returns the resulting consensus DAG after applying the heuristic voting.
+ * @return The output DAG containing the selected edges based on the majority voting strategy.
+ */
+ public Dag getOutputDag() {
+ return outputDag;
+ }
+
+ /**
+ * Returns the list of input DAGs used to compute the consensus.
+ * @return A list of DAGs that were used as input for the consensus voting.
+ */
+ public ArrayList getSetOfdags() {
+ return setOfdags;
+ }
+
+ /**
+ * Returns the percentage threshold used for edge inclusion in the consensus DAG.
+ * @return The percentage threshold for edge inclusion.
+ */
+ public double getPercentage() {
+ return percentage;
+ }
+ /**
+ * Returns the weight matrix representing the frequency of edges between pairs of nodes.
+ * Each entry weight[i][j] indicates the weight that the edge from node i to node j has in the input DAGs.
+ * @return The weight matrix used in the consensus voting process.
+ * @see #fusion()
+ */
+ public double[][] getWeight() {
+ return weight;
}
-
- return this.outputDag;
-}
}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java b/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java
index 1965f28..e55c0d9 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java
@@ -106,7 +106,7 @@ public int cluster() {
clustersIndexes[i][i][0] = true;
}
- dissimilarityMatrix = computeDissimilarityMatrix();
+ computeDissimilarityMatrix();
for (int a = 1; a 0 && (this.maxSize >= (this.clusterCardinalities[o1] + this.clusterCardinalities[o2]))|| level == 0){
PairWiseConsensusBES pairBNs = new PairWiseConsensusBES(this.clustersBN[o1][level],this.clustersBN[o2][level]);
PairWiseConsensusBES pairDag= (PairWiseConsensusBES) pairBNs;
- pairDag.getFusion();
+ pairDag.fusion();
return pairBNs;
}else if(this.maxSize == 0){
PairWiseConsensusBES pairBNs = new PairWiseConsensusBES(this.clustersBN[o1][level],this.clustersBN[o2][level]);
PairWiseConsensusBES pairDag= (PairWiseConsensusBES) pairBNs;
- pairDag.getFusion();
+ pairDag.fusion();
if((pairDag.getDagFusion().getNumEdges())/this.averageNEdges <= this.maxComplexityCluster|| level == 0)
return pairBNs;
else return null;
@@ -326,7 +325,7 @@ private Pair findMostSimilarClusters() {
PairWiseConsensusBES inCluster = dissimilarityMatrix[cluster][neighbor];
if(inCluster!= null){
double complexity = 0.0;
- complexity = (float) inCluster.getHammingDistance();//getNumberOfInsertedEdges();
+ complexity = (float) inCluster.calculateHammingDistance();//getNumberOfInsertedEdges();
if (indexUsed[neighbor]&&complexity hashMap=new HashMap();
-//
-// public static int[] getList(int size) {
-// Integer key=size;
-// int[] lista=hashMap.get(key);
-// if(lista==null) {
-// lista=generateList(size);
-// hashMap.put(key, lista);
-// }
-// return lista;
-// }
-
- public static int[] getList(int size) {
- return generateList(size);
- }
-
- private static int[] generateList(int size) {
+ /**
+ * MAX_SIZE is the maximum number of elements that can be included in any subset.
+ * It is set to Integer.MAX_VALUE by default, meaning no limit unless specified.
+ */
+ public static int MAX_SIZE=Integer.MAX_VALUE;
+
+ /**
+ * Generates a list of integers representing all subsets of a set of a given size.
+ * Each integer is a bitmask where the i-th bit represents the inclusion of the i-th element.
+ * @param size the size of the set for which subsets are generated
+ * @return an array of integers representing the subsets
+ */
+ public static int[] generateList(int size) {
int[] lista;
if(size==0) {
return new int[1];
}
+ // Generation of powers of numbers that are powers of 2
int[] pow2=new int[size];
pow2[0]=1;
for(int i=1;iaux[0][index]) {
aux[0][counter]=aux[0][index]+pow2[i];
@@ -50,11 +55,4 @@ private static int[] generateList(int size) {
return lista;
}
- public static void setMaxSize(int maxParents) {
- ListFabric.maxSize = maxParents;
- }
-
- public static int getMaxSize() {
- return ListFabric.maxSize;
- }
}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java
index 3b053bb..f1669bb 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java
@@ -2,54 +2,138 @@
import java.util.ArrayList;
-
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Node;
-
+/**
+ * This class implements the PairWiseConsensusBES algorithm, which measures the similarity between two DAGs by fusing them with the ConsensusBES algorithm.
+ * It first fuses the two input DAGs into a consensus DAG using the ConsensusBES class.
+ * After obtaining the consensus DAG, it calculates the Hamming distance between the fused DAG and the original input DAGs.
+ * The resulting output DAG is stored in the consensusDAG attribute, and the number of inserted edges during the fusion process can be retrieved using getNumberOfInsertedEdges.
+ */
public class PairWiseConsensusBES implements Runnable{
+ /**
+ * The first input DAG to be fused.
+ */
+ private Dag firstDag = null;
+
+ /**
+ * The second input DAG to be fused.
+ */
+ private Dag secondDag = null;
+
+ /**
+ * The resulting consensus DAG after applying the fusion process.
+ */
+ private Dag consensusDAG = null;
+
+ /**
+ * Instance of ConsensusBES used to compute the consensus DAG from the input DAGs.
+ */
+ private ConsensusBES consensusBES= null;
- private Dag b1 = null;
- private Dag b2 = null;
- private Dag conDAG = null;
- private ConsensusBES conBES= null;
+ /**
+ * Number of total edges inserted during the fusion process.
+ */
private int numberOfInsertedEdges = 0;
+
+ /**
+ * Number of edges inserted during the consensus union process.
+ */
private int numberOfUnionEdges = 0;
- public PairWiseConsensusBES(Dag b1, Dag b2) {
- super();
- this.b1 = b1;
- this.b2 = b2;
+ /**
+ * Constructor for the PairWiseConsensusBES class.
+ * It initializes the instance with two input DAGs and checks if they are valid.
+ * If the input DAGs are not valid, it throws an IllegalArgumentException.
+ * @param firstDag the first input DAG for fusion similarity.
+ * @param secondDag the second input DAG for fusion similarity.
+ */
+ public PairWiseConsensusBES(Dag firstDag, Dag secondDag) {
+ checkInput(firstDag, secondDag);
+ this.firstDag = firstDag;
+ this.secondDag = secondDag;
}
-
- public void getFusion(){
- ArrayList setOfDags = new ArrayList();
- setOfDags.add(this.b1);
- setOfDags.add(this.b2);
- conBES = new ConsensusBES(setOfDags);
- conBES.fusion();
- this.numberOfInsertedEdges = conBES.getNumberOfInsertedEdges();
- this.numberOfUnionEdges = conBES.union.getNumEdges();
- this.conDAG = conBES.getFusion();
+ /**
+ * Checks if the input DAGs are valid.
+ * Validity is determined by ensuring that the DAGs are not null, contain at least one node and one edge, and have the same set of nodes.
+ * If any of these conditions are not met, an IllegalArgumentException is thrown.
+ * @param firstDag first input DAG for fusion similarity
+ * @param secondDag second input DAG for fusion similarity
+ * @throws IllegalArgumentException if the input DAGs are not valid
+ */
+ private void checkInput(Dag firstDag, Dag secondDag) {
+ if (firstDag == null || secondDag == null) {
+ throw new IllegalArgumentException("Input DAGs cannot be null.");
+ }
+ if (firstDag.getNumNodes() == 0 || secondDag.getNumNodes() == 0) {
+ throw new IllegalArgumentException("Input DAGs must contain at least one node.");
+ }
+ if (firstDag.getNumEdges() == 0 || secondDag.getNumEdges() == 0) {
+ throw new IllegalArgumentException("Input DAGs must contain at least one edge.");
+ }
+ if (firstDag.getNodes().size() != secondDag.getNodes().size()) {
+ throw new IllegalArgumentException("Input DAGs must have the same number of nodes.");
+ }
+ if (!firstDag.getNodes().containsAll(secondDag.getNodes())) {
+ throw new IllegalArgumentException("Input DAGs must have the same set of nodes.");
+ }
}
-
+
+ /**
+ * Performs the fusion process by first applying the consensus union and then applying the Backward Equivalence Search.
+ */
+ public void fusion(){
+ // Creating a list of DAGs to be fused
+ ArrayList setOfDags = new ArrayList<>();
+ setOfDags.add(this.firstDag);
+ setOfDags.add(this.secondDag);
+ // Applying the ConsensusBES algorithm to fuse the DAGs
+ consensusBES = new ConsensusBES(setOfDags);
+ consensusBES.fusion();
+ // Retrieving the resulting DAG and the number of inserted edges
+ this.numberOfInsertedEdges = consensusBES.getNumberOfInsertedEdges();
+ this.numberOfUnionEdges = consensusBES.getUnion().getNumEdges();
+ this.consensusDAG = consensusBES.getFusionDag();
+ }
+
+ /**
+ * Returns the number of edges inserted during the fusion process.
+ * This method retrieves the number of edges that were added to the consensus DAG during the fusion process.
+ * It is useful for understanding how many edges were introduced in the consensus DAG compared to the original input DAGs.
+ * @return the number of edges inserted during the fusion process.
+ */
public int getNumberOfInsertedEdges(){
return this.numberOfInsertedEdges;
}
+ /**
+ * Returns the number of edges in the union DAG after the consensus union process.
+ * This method retrieves the number of edges that were present in the union DAG after merging the transformed input DAGs.
+ * It is useful for understanding the size of the union DAG before applying the Backward Equivalence Search.
+ * This number can be used to compare with the number of edges in the final consensus DAG after the Backward Equivalence Search.
+ *
+ * @see ConsensusBES#getUnion()
+ * @see ConsensusBES#getNumberOfInsertedEdges()
+ * @return the number of edges in the union DAG after the consensus union process.
+ */
public int getNumberOfUnionEdges(){
return this.numberOfUnionEdges;
}
- public int getHammingDistance(){
- if(this.conDAG==null) this.getFusion();
+ /**
+ * Calculates the Hamming distance between the optimum fusion DAG and the original input DAGs.
+ * @return The Hamming distance between the fused DAG and the original input DAGs.
+ */
+ public int calculateHammingDistance(){
+ if(this.consensusDAG==null) this.fusion();
int distance = 0;
- for(Edge ed: this.conDAG.getEdges()){
+ for(Edge ed: this.consensusDAG.getEdges()){
Node tail = ed.getNode1();
Node head = ed.getNode2();
- for(Dag g: conBES.setOfOutDags){
+ for(Dag g: consensusBES.getTransformedDags()){
Edge edge1 = g.getEdge(tail, head);
Edge edge2 = g.getEdge(head, tail);
if(edge1 == null && edge2==null) distance++;
@@ -58,14 +142,25 @@ public int getHammingDistance(){
return distance+this.getNumberOfInsertedEdges();
}
+ /**
+ * Returns the resulting consensus DAG after applying the fusion process.
+ * This method retrieves the final fused DAG, which represents the optimal fusion of the input DAGs.
+ * It is useful for obtaining the consensus structure after the fusion process has been completed.
+ *
+ * @see ConsensusBES#getFusionDag()
+ * @return the resulting consensus DAG after applying the whole fusion process.
+ */
public Dag getDagFusion(){
- return this.conDAG;
+ return this.consensusDAG;
}
+ /**
+ * Runs the fusion process in a thread, performing the consensus union and the Backward Equivalence Search with D-separation.
+ */
@Override
public void run() {
- this.getFusion();
+ this.fusion();
}
}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java b/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java
index ebfe47d..c240f46 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java
@@ -3,113 +3,140 @@
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
-
+import java.util.Set;
import edu.cmu.tetrad.graph.Node;
-public class PowerSet implements Enumeration {
+/**
+ * PowerSet generates all subsets of a given set of nodes, with an optional maximum size constraint.
+ * It implements Enumeration to allow iteration over the subsets.
+ */
+public class PowerSet implements Enumeration> {
+ /**
+ * List of nodes for which the power set is generated.
+ */
List nodes;
- private List subSets;
+ /**
+ * List to store the generated subsets.
+ */
+ private final List> subSets = new ArrayList<>();
+ /**
+ * Index to track the current position in the enumeration.
+ */
private int index;
- private int[] lista;
- private HashMap hashMap;
-
-
- PowerSet(List nodes,int k) {
- if(nodes.size()<=k)
- k=nodes.size();
- this.nodes=nodes;
- subSets = new ArrayList();
- index=0;
- hashMap=new HashMap();
- lista=ListFabric.getList(nodes.size());
- for (int i : lista) {
- SubSet newSubSet = new SubSet();
- String selection = Integer.toBinaryString(i);
- for (int j = selection.length() - 1; j >= 0; j--) {
- if (selection.charAt(j) == '1') {
- newSubSet.add(nodes.get(selection.length() - j - 1));
- }
- }
- if(newSubSet.size()<=k){
- subSets.add(newSubSet);
- hashMap.put(i, newSubSet);
- }
- }
- }
+ /**
+ * List of integers representing the subsets in binary form.
+ */
+ private int[] binaryList;
+
+ /**
+ * A map to store the subsets with their corresponding binary representation.
+ * The key is the binary representation of the subset, and the value is the subset itself.
+ */
+ private HashMap> subset;
+ /**
+ * Maximum size of the subsets to be generated.
+ * If set to a value less than the number of nodes, it limits the size of the subsets.
+ */
+ private int maxPow = 0;
- PowerSet(List nodes) {
- if(nodes.size()>maxPow)
- maxPow=nodes.size();
- this.nodes=nodes;
- subSets = new ArrayList();
- index=0;
- hashMap=new HashMap();
- lista=ListFabric.getList(nodes.size());
- for (int i : lista) {
- SubSet newSubSet = new SubSet();
- String selection = Integer.toBinaryString(i);
- for (int j = selection.length() - 1; j >= 0; j--) {
- if (selection.charAt(j) == '1') {
- newSubSet.add(nodes.get(selection.length() - j - 1));
- }
- }
- subSets.add(newSubSet);
- hashMap.put(i, newSubSet);
+ /**
+ * Builds a PowerSet with subsets of the given nodes, limited to a maximum size.
+ * @param nodes List of nodes to generate subsets from.
+ * @param maxSize Maximum size of the subsets to be generated. Assuring that k does not exceed the number of nodes.
+ * * If maxSize is negative, it will throw an IllegalArgumentException.
+ * @throws IllegalArgumentException if maxSize is negative.
+
+ */
+ public PowerSet(List nodes, int maxSize) {
+ if (maxSize < 0) {
+ throw new IllegalArgumentException("maxSize cannot be negative");
}
- }
-
+ if (maxSize >= nodes.size()) {
+ maxSize = nodes.size();
+ }
+ this.nodes = nodes;
+ initializeSubsets(maxSize);
+ }
+
+ /**
+ * Builds a PowerSet with all subsets of the given nodes, without size limitation.
+ *
+ * @param nodes Lista de nodos de entrada.
+ */
+ public PowerSet(List nodes) {
+ if (nodes.size() > maxPow) {
+ maxPow = nodes.size();
+ }
+ this.nodes = nodes;
+ initializeSubsets(nodes.size()); // sin límite: k = nodes.size()
+ }
+
+ /**
+ * Initializes the subsets based on the nodes and the maximum size.
+ * This method generates all possible subsets of the nodes, respecting the maximum size constraint.
+ * @param maxSize Maximum size of the subsets to be generated.
+ * If maxSize is greater than the number of nodes, it will generate subsets of all sizes.
+ * If maxSize is 0, it will only generate the empty set.
+ */
+ private void initializeSubsets(int maxSize) {
+ subset = new HashMap<>();
+ index = 0;
+ binaryList = ListFabric.generateList(nodes.size());
+
+ for (int i : binaryList) {
+ Set newSubSet = new HashSet<>();
+ String selection = Integer.toBinaryString(i);
+
+ for (int j = selection.length() - 1; j >= 0; j--) {
+ if (selection.charAt(j) == '1') {
+ int idx = selection.length() - j - 1;
+ newSubSet.add(nodes.get(idx));
+ }
+ }
+
+ if (newSubSet.size() <= maxSize) {
+ subSets.add(newSubSet);
+ subset.put(i, newSubSet);
+ }
+ }
+ }
+
+ /**
+ * Checks if there are more subsets to iterate over.
+ * @return true if there are more subsets, false otherwise.
+ */
+ @Override
public boolean hasMoreElements() {
return index nextElement() {
return subSets.get(index++);
}
+ /**
+ * Resets the index to allow re-iteration over the subsets.
+ * This method allows the enumeration to start over from the beginning.
+ */
public void resetIndex(){
this.index = 0;
}
- private static int maxPow = 0;
- public static long maxPowerSetSize() {
+ /**
+ * Returns the maximum size of the power set based on the maximum number of nodes.
+ * This method calculates the size of the power set as 2 raised to the power of the maximum number of nodes.
+ * @return The maximum size of the power set.
+ */
+ public long maxPowerSetSize() {
return (long) Math.pow(2,maxPow);
}
-
-// public void firstTest(boolean result) {
-// int numInicial=lista[index-1];
-// for(int i=0;i> hashMap=new HashMap>();
-//
-// private PowerSetFabric() {
-// }
-
- public static PowerSet getPowerSet(List nodes, int k){
- return new PowerSet(nodes,k);
-
- }
-
- public static PowerSet getPowerSet(Node x, Node y, List nodes) {
- return new PowerSet(nodes);
-// if(!usePowerSetsCache)
-// return new PowerSet(nodes);
-// PowerSet pSet=get(x,y);
-// if(pSet==null || pSet.nodes.size()!=nodes.size()) { // if(pSet==null || !pSet.t.equals(nodes)) {
-// pSet=new PowerSet(nodes);
-// put(x,y,pSet);
-// }
-// else {
-// pSet.reset(mode==MODE_FES);
-// }
-// return pSet;
- }
-//
-// private static void put(Node x, Node y, PowerSet pSet) {
-// hashMap.get(x).put(y, pSet);
-//
-// }
-//
-// private static PowerSet get(Node x, Node y) {
-// HashMap aux=hashMap.get(x);
-// if(aux==null) {
-// aux=new HashMap();
-// hashMap.put(x, aux);
-// }
-// return aux.get(y);
-// }
-//
-// public static int getMode() {
-// return mode;
-// }
-//
- /*
- public static void setMode(int mode) {
- PowerSetFabric.mode=mode;
- }
- */
-
- public static boolean isUsePowerSetsCache() {
- return usePowerSetsCache;
- }
-
- public static void setUsePowerSetsCache(boolean usePowerSetsCache) {
- PowerSetFabric.usePowerSetsCache = usePowerSetsCache;
- }
-}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/RandomBN.java b/src/main/java/es/uclm/i3a/simd/consensusBN/RandomBN.java
index 06c89a4..4840bcb 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/RandomBN.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/RandomBN.java
@@ -13,6 +13,11 @@
import edu.cmu.tetrad.bayes.MlBayesIm.InitializationMethod;
import edu.cmu.tetrad.data.*;
+/**
+ * RandomBN generates a set of random Bayesian networks (BNs) based on specified parameters. This class has been used for experiments, and is being maintained for compatibility with existing experiments.
+ * Further development should avoid using this class and instead use @see RandomGraph from Tetrad for generating random DAGs.
+ */
+@Deprecated
public class RandomBN {
int seed = 0;
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java b/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java
deleted file mode 100644
index be6018c..0000000
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package es.uclm.i3a.simd.consensusBN;
-
-import java.util.HashSet;
-
-import edu.cmu.tetrad.graph.Node;
-
-public class SubSet extends HashSet {
-
- private static final long serialVersionUID = 4569314863278L;
- public static final int TEST_NOT_EVALUATED=0;
- public static final int TEST_TRUE=1;
- public static final int TEST_FALSE=-1;
-
- public int firstTest=TEST_NOT_EVALUATED;
- public int secondTest=TEST_NOT_EVALUATED;
-
- public SubSet() {
- super();
- }
-
- public SubSet(SubSet other) {
- super(other);
- }
-}
\ No newline at end of file
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java b/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java
index 9861e99..f798234 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java
@@ -2,205 +2,133 @@
import java.util.ArrayList;
-
-import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Node;
+/**
+ * This class transforms a set of DAGs by applying the BetaToAlpha transformation to each DAG with a given alpha order.
+ */
public class TransformDags {
+ /**
+ * List of input DAGs to be transformed.
+ */
+ private final ArrayList setOfDags;
+ /**
+ * List of output DAGs after transformation.
+ */
+ private final ArrayList setOfOutputDags;
+ /**
+ * The alpha order used for the transformation.
+ */
+ private final ArrayList alpha;
+ /**
+ * The transformation objects for each DAG.
+ * Each BetaToAlpha object applies the transformation to a corresponding DAG in setOfDags using the alpha order provided.
+ */
+ private ArrayList transformers= null;
- ArrayList setOfDags = null;
- ArrayList setOfOutputDags = null;
- ArrayList alfa = null;
- ArrayList metAs= null;
- int numberOfInsertedEdges = 0;
-// int weight[][][] = null;
+ /**
+ * Number of edges inserted during the transformation process.
+ * This is used to track how many edges were added to the transformed DAGs.
+ */
+ private int numberOfInsertedEdges = 0;
- public TransformDags(ArrayList dags, ArrayList alfa){
+ /**
+ * Constructor for TransformDags.
+ * Initializes the object with a list of DAGs and an alpha order.
+ * It creates a new BetaToAlpha transformation for each DAG in the input list.
+ * Each DAG in the input list will be transformed according to this alpha order.
+ * The transformation is applied by creating a BetaToAlpha object for each DAG.
+ * The transformed DAGs will be stored in setOfOutputDags after calling the transform() method.
+ * @see BetaToAlpha
+ * @param dags List of DAGs to be transformed.
+ * @param alpha List of nodes representing the alpha order for the transformation.
+ */
+ public TransformDags(ArrayList dags, ArrayList alpha){
this.setOfDags = dags;
- this.setOfOutputDags = new ArrayList();
- this.metAs = new ArrayList();
- this.alfa = alfa;
-
+ this.setOfOutputDags = new ArrayList<>();
+ this.transformers = new ArrayList<>();
+ this.alpha = alpha;
+ // Initialize the BetaToAlpha transformation for each DAG in the input list
for (Dag i : setOfDags) {
Dag out = new Dag(i);
- this.metAs.add(new BetaToAlpha(out,alfa));
+ this.transformers.add(new BetaToAlpha(out,this.alpha));
}
}
-
+ /**
+ * Transforms the input DAGs by applying the BetaToAlpha transformation.
+ * This method iterates through each BetaToAlpha object, applies the transformation,
+ * and collects the transformed DAGs into setOfOutputDags.
+ * It also counts the total number of edges inserted during the transformations.
+ *
+ * @see BetaToAlpha#transform()
+ * @see BetaToAlpha#getNumberOfInsertedEdges()
+ * @see BetaToAlpha#getGraph()
+ * @return An ArrayList of transformed DAGs after applying the BetaToAlpha transformation.
+ */
public ArrayList transform (){
-
this.numberOfInsertedEdges = 0;
-
- for(BetaToAlpha transformDagi: this.metAs){
+ for(BetaToAlpha transformDagi: this.transformers){
transformDagi.transform();
this.numberOfInsertedEdges += transformDagi.getNumberOfInsertedEdges();
- this.setOfOutputDags.add(transformDagi.G);
+ this.setOfOutputDags.add(transformDagi.getGraph());
}
-
-
return this.setOfOutputDags;
-
}
+ /**
+ * Returns the number of edges that were inserted during the transformation process.
+ * @return The total number of edges inserted across all transformed DAGs.
+ */
public int getNumberOfInsertedEdges(){
return this.numberOfInsertedEdges;
}
+ /**
+ * Returns the list of input DAGs that were transformed.
+ * @return An ArrayList of DAGs that were provided as input to the transformation.
+ */
+ public ArrayList getSetOfDags() {
+ return this.setOfDags;
+ }
+
+ /**
+ * Returns the list of transformed output DAGs.
+ * This list contains the DAGs after applying the BetaToAlpha transformation.
+ * @return An ArrayList of transformed DAGs.
+ */
+ public ArrayList getSetOfOutputDags() {
+ return this.setOfOutputDags;
+ }
+
+ /**
+ * Returns the alpha order used for the transformation.
+ * @return An ArrayList of nodes representing the alpha order.
+ */
+ public ArrayList getAlpha() {
+ return this.alpha;
+ }
+
+ /**
+ * Returns the list of BetaToAlpha transformers used for the transformation.
+ * @return An ArrayList of BetaToAlpha objects, each corresponding to a DAG in the input list.
+ */
+ public ArrayList getTransformers() {
+ return this.transformers;
+ }
+
+ /**
+ * Sets the list of BetaToAlpha transformers.
+ * This method allows for updating the transformers used in the transformation process.
+ *
+ * @param transformers An ArrayList of BetaToAlpha objects to be set as the transformers for this TransformDags instance.
+ * Each transformer will apply the BetaToAlpha transformation to its corresponding DAG in the input list.
+ */
+ public void setTransformers(ArrayList transformers) {
+ this.transformers = transformers;
+ }
-
-
-// void computeWeight(){
-//
-// this.weight = new int[this.setOfDags.size()][this.alfa.size()][this.alfa.size()];
-//
-// for(Dag g: this.setOfOutputDags){
-// for(Node nodei : g.getNodes()){
-// List pa = g.getParents(nodei);
-// if(pa.isEmpty()) continue;
-// List anc = new ArrayList();
-// anc.add(nodei);
-// anc = g.getAncestors(anc);
-// Dag gAnc = new Dag(g.subgraph(anc));
-// // me quedo con el grafo ancestral del node_i
-// for(Node pai: pa){ // Calculo el numero de caminos desde los ancestros que se "activan" borrando cada padre.
-// int npaths = 0;
-// Dag gAncNopai = new Dag(gAnc);
-// for(Node rm : pa) if(!rm.equals(pai)) gAncNopai.removeNode(rm); // borro todos los padres menos el pa_i en el grafo ancestral.
-// for(Node nodeAc: anc){ // para cada ancestro voy mirando si hay un camino dirigido.
-// if((gAncNopai.getNodes().contains(nodeAc))&&(!nodeAc.equals(nodei)))
-// if(GraphUtils.paths().existsDirectedPath(gAncNopai,nodeAc, nodei)) npaths++;
-// }
-// // npaths tiene el numero de caminos diridos que se han activado quitando el padre pa_i
-// this.weight[this.setOfOutputDags.indexOf(g)][g.getNodes().indexOf(nodei)][g.getNodes().indexOf(pai)] = npaths;
-// }
-//
-// }
-//
-// }
-//
-// }
-
-// public Dag computeWeightDag(boolean w){
-//
-// Dag wDag = new Dag(this.alfa);
-// for(Node nodei : this.alfa){
-// for(Node nodej : this.alfa){
-// if(nodei.equals(nodej)) continue;
-// int wij = 0;
-// for(Dag g: this.setOfOutputDags){
-// int wg = this.weight[this.setOfOutputDags.indexOf(g)][g.getNodes().indexOf(nodei)][g.getNodes().indexOf(nodej)];
-// if(wg > 0 && !w) wg = 1;
-// else if (wg == 0) wg =-1;
-// wij+=wg;
-// }
-// if(wij > 0) wDag.addEdge(new Edge(nodej,nodei,Endpoint.TAIL,Endpoint.ARROW));
-// }
-// }
-// return wDag;
-// }
-//
-// public static void main(String args[]) {
-//
-// ArrayList dags = new ArrayList();
-// ArrayList alfa = new ArrayList();
-// Random aleatorio = new Random(222);
-//
-//
-// System.out.println("Grafos de Partida: ");
-// System.out.println("---------------------");
-//// Graph graph = GraphConverter.convert("X1-->X5,X2-->X3,X3-->X4,X4-->X1,X4-->X5");
-//// Dag dag = new Dag(graph);
-//
-// Dag dag = new Dag();
-//
-// dag = GraphUtils.randomDag(Integer.parseInt(args[0]), Integer.parseInt(args[1]), true);
-// BetaToAlpha mt = new BetaToAlpha(dag);
-// alfa = mt.randomAlfa(aleatorio);
-// dags.add(dag);
-// System.out.println("DAG: ---------------");
-// System.out.println(dag.toString());
-// for (int i=0 ; i < Integer.parseInt(args[2])-1 ; i++){
-// Dag newDag = GraphUtils.randomDag(dag.getNodes(),Integer.parseInt(args[1]) ,true);
-// dags.add(newDag);
-// System.out.println("DAG: ---------------");
-// System.out.println(newDag.toString());
-// }
-//
-//
-//
-// System.out.println("Orden de Consenso: " + alfa.toString());
-//
-// TransformDags setOfDags = new TransformDags(dags,alfa);
-// setOfDags.transform();
-//
-//
-//
-//
-// for(Dag d : setOfDags.setOfOutputDags){
-// System.out.println("DAG trasformado: ---------------");
-// System.out.println(d.toString());
-// }
-//
-//
-//
-// Dag union = new Dag(alfa);
-//
-// for(Node nodei: alfa){
-// for(Dag d : setOfDags.setOfOutputDags){
-// Listparent = d.getParents(nodei);
-// for(Node pa: parent){
-// if(!union.isParentOf(pa, nodei)) union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW));
-// }
-// }
-//
-// }
-//
-//
-// System.out.println("Grafo UNION: "+union.toString());
-// setOfDags.computeWeight();
-// Dag wDag = setOfDags.computeWeightDag(true);
-// System.out.println("Grafo Consenso: "+ wDag.toString());
-// Dag wDag2 = setOfDags.computeWeightDag(false);
-// System.out.println("Grafo Consenso sin pesos: "+ wDag2.toString());
-//
-//
-//
-//
-//
-//
-//
-//// Node nod = dag.getNodes().get(aleatorio.nextInt(alfa.size()));
-//// ArrayList a = new ArrayList();
-//// a.add(nod);
-//// List anc = dag.getAncestors(a);
-////
-//// System.out.println("Ancenstros de " + nod.toString()+ " "+anc.toString());
-////
-//// System.out.println("Subgraph: "+ dag.subgraph(anc));
-////
-//// List pa = dag.getParents(nod);
-//// System.out.println("padres de: "+nod.toString()+ " : "+pa.toString());
-//// Node pai = pa.get(aleatorio.nextInt(pa.size()));
-//// System.out.println("Padre elegido a borrar: "+pai.toString());
-//// pa.remove(pai);
-////
-////
-////
-//// for(Node rm: pa) anc.remove(rm);
-////
-//// Graph pp = dag.subgraph(anc);
-//// int npath = 0;
-//// for(Node ancestor: pp.getNodes()){
-//// if(!ancestor.equals(nod)){
-//// List> paths = GraphUtils.allPathsFromTo(pp, ancestor, nod);
-//// if(dag.paths().existsDirectedPath(ancestor, nod)) npath++;
-//// System.out.println("Caminos desde: "+ancestor.toString()+" a: "+nod.toString()+" : "+paths.toString());
-//// }
-//// }
-//// System.out.println(" El numero de caminos desde hacia: "+ nod+" es de: "+npath);
-// }
-//
}
diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
index d4c7a7f..0c7d970 100644
--- a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
+++ b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java
@@ -1,32 +1,54 @@
package es.uclm.i3a.simd.consensusBN;
+import java.util.ArrayDeque;
import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashSet;
+import java.util.LinkedList;
import java.util.List;
+import java.util.Set;
+import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
+import edu.cmu.tetrad.graph.Edges;
+import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
-public class Utils {
+/**
+ * Utility class providing static methods for graph operations, particularly for transforming PDAGs to DAGs and checking d-separation.
+ * This class includes methods for converting a PDAG to a DAG, checking if two nodes are d-separated in a DAG, and finding neighbors of nodes.
+ */
+public final class Utils {
- /**
- * pdagToDag Algorithm From Chickering 2002:
- * "We first consider a simple implementation of PDAG-to-DAG due to Dor and Tarsi (1992).
- * Let NX denote the neighbors of node X in a PDAG P.
- * We first create a DAG G that contains all of the directed edges from P, and no other edges.
- * We then repeat the following procedure:
- * First, select a node X in P such that:
- * (1) X has no out-going edges and
- * (2) if NX is non-empty, then NX PaX is a clique.
- * If P admits a consistent extension, the node X is guaranteed to exist.
- * Next, for each undirected edge Y X incident to X in P, insert a directed edge Y X to G.
- * Finally, remove X and all incident edges from the P and continue with the next node.
- * The algorithm terminates when all nodes have been deleted from P."
- * @param graph The graph to be transformed from PDAG to DAG.
- */
+ /**
+ * Transforms a PDAG (Partially Directed Acyclic Graph) into a DAG (Directed Acyclic Graph)
+ * using the algorithm proposed by Dor and Tarsi (1992), as presented in Chickering (2002).
+ *
+ * The algorithm proceeds as follows:
+ *
+ * - Let NX be the set of neighbors of node X in the PDAG P.
+ * - Create a new DAG G containing all the directed edges from P (and no others).
+ * - Iteratively repeat the following steps:
+ *
+ * - Select a node X such that:
+ *
+ * - (1) X has no outgoing directed edges in P, and
+ * - (2) if NX is non-empty, then NX ∪ Pa(X) forms a clique.
+ *
+ * Such a node is guaranteed to exist if P admits a consistent extension.
+ * - For each undirected edge Y—X incident to X in P, orient it as Y → X in G.
+ * - Remove node X and all its incident edges from P.
+ *
+ *
+ * - The algorithm terminates when all nodes have been removed from P.
+ *
+ *
+ * @param graph The input PDAG to be converted into a DAG.
+ */
public static void pdagToDag(Graph graph){
// First create a DAG G that contains all of the directed edges from the PDAG, and no other edges.
Graph graphAux = new EdgeListGraph(graph);
@@ -45,7 +67,7 @@ public static void pdagToDag(Graph graph){
for(Node node : nodes){
x = node;
//Checking if the node has no outgoing edges
- if(graphAux.getChildren(node).size() != 0)
+ if(!graphAux.getChildren(node).isEmpty())
continue;
//Checking if the neighbors form a clique
if(!GraphUtils.isClique(graphAux.getAdjacentNodes(x), graphAux))
@@ -62,7 +84,203 @@ public static void pdagToDag(Graph graph){
break; // We break the loop to start again with the new graphAux without the removed node.
}
nodes.remove(x);
- }while(nodes.size() > 0);
+ }while(!nodes.isEmpty());
+ }
+
+ /**
+ * Checks if two nodes in a DAG are d-separated given an empty set of conditioning nodes.
+ * @param g The DAG to check for d-separation.
+ * @param x The first node.
+ * @param y The second node.
+ * @return True if the nodes are d-separated, false otherwise.
+ */
+ public static boolean dSeparated(Dag g, Node x, Node y) {
+ return dSeparated(g, x, y, new ArrayList<>());
+ }
+
+ /**
+ * Checks if two nodes in a DAG are d-separated given a set of conditioning nodes.
+ * This method uses a defensive copy of the conditioning set to ensure immutability.
+ * It builds an induced subgraph of the DAG containing only the relevant nodes and checks if there is a path between the two nodes that does not pass through the conditioning nodes.
+ * The method first finds the relevant nodes in the DAG, builds an induced subgraph, moralizes it,
+ * converts it to an undirected graph, and finally checks if the two nodes are reachable from each other in the undirected graph after removing the conditioning nodes.
+ * @param g The DAG to check for d-separation.
+ * @param x The first node.
+ * @param y The second node.
+ * @param cond The list of conditioning nodes.
+ * @return True if the nodes are d-separated, false otherwise.
+ */
+ public static boolean dSeparated(Dag g, Node x, Node y, List cond) {
+
+ Set relevantNodes = findRelevantNodes(g, x, y, cond);
+ Graph aux = buildInducedSubgraph(g, relevantNodes);
+ moralize(aux);
+ convertToUndirected(aux);
+ aux.removeNodes(cond);
+ return !isReachable(aux, x, y);
+ }
+
+ /**
+ * Finds the relevant nodes in the DAG that are needed to check d-separation between two nodes x and y, given a set of conditioning nodes.
+ * This method performs a depth-first search starting from nodes x and y, and includes all nodes that are reachable from either x or y, as well as the conditioning nodes.
+ * The method uses a stack to explore the graph and a set to keep track of visited nodes, ensuring that
+ * all relevant nodes are included in the final set.
+ * @param g The DAG in which to find the relevant nodes.
+ * @param x The first node.
+ * @param y The second node.
+ * @param cond The list of conditioning nodes.
+ * @return A set of nodes that are relevant for checking d-separation between x and y, including the conditioning nodes.
+ */
+ private static Set findRelevantNodes(Dag g, Node x, Node y, List cond) {
+ Set visited = new HashSet<>();
+ Deque stack = new ArrayDeque<>();
+
+ stack.push(x);
+ stack.push(y);
+ for (Node c : cond) stack.push(c);
+
+ while (!stack.isEmpty()) {
+ Node current = stack.pop();
+ if (visited.add(current)) {
+ for (Node parent : g.getParents(current)) {
+ stack.push(parent);
+ }
+ }
+ }
+
+ return visited;
+ }
+
+ /**
+ * Builds an induced subgraph from the given DAG containing only the nodes specified in the set {@code nodesToKeep}.
+ * The method creates a new graph that includes all nodes in {@code nodesToKeep} and all directed edges between them that exist in the original graph.
+ * It ensures that only directed edges are included, and undirected edges are ignored.
+ * @param g The original DAG from which to build the induced subgraph.
+ * @param nodesToKeep The set of nodes to include in the induced subgraph.
+ * This set should contain the nodes that are relevant for the d-separation check.
+ * @return A new graph representing the induced subgraph containing only the specified nodes and their directed edges.
+ * This graph is a subgraph of the original DAG, containing only the nodes in {@code nodesToKeep} and the directed edges between them that exist in the original graph.
+ */
+ private static Graph buildInducedSubgraph(Dag g, Set nodesToKeep) {
+ Graph subgraph = new EdgeListGraph();
+
+ for (Node node : g.getNodes()) {
+ if (nodesToKeep.contains(node)) {
+ subgraph.addNode(node);
+ }
+ }
+
+ for (Edge e : g.getEdges()) {
+ if (!e.isDirected()) continue;
+
+ Node tail = e.getNode1();
+ Node head = e.getNode2();
+
+ if (nodesToKeep.contains(tail) && nodesToKeep.contains(head)) {
+ subgraph.addEdge(new Edge(tail, head, e.getEndpoint1(), e.getEndpoint2()));
+ }
+ }
+
+ return subgraph;
+ }
+
+ /**
+ * Moralizes the given graph by adding undirected edges between all pairs of parents of each child node.
+ * This process ensures that the resulting graph is undirected and that all parents of each child are connected.
+ * The moralization is done by iterating over each child node, retrieving its parents, and adding undirected edges between every pair of parents.
+ * This is a crucial step in preparing the graph for d-separation checks, as it ensures that the graph structure reflects the necessary connections between parents of child nodes.
+ * @param graph The graph to moralize.
+ */
+ private static void moralize(Graph graph) {
+ for (Node child : graph.getNodes()) {
+ List parents = graph.getParents(child);
+ int n = parents.size();
+ if (n <= 1) continue;
+
+ for (int i = 0; i < n - 1; i++) {
+ for (int j = i + 1; j < n; j++) {
+ Node p1 = parents.get(i);
+ Node p2 = parents.get(j);
+ if (!graph.isAdjacentTo(p1, p2)) {
+ graph.addUndirectedEdge(p1, p2);
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Converts all directed edges in the graph to undirected edges.
+ * This method iterates through all edges in the graph and changes their endpoints to TAIL,
+ * effectively removing the directionality of the edges. This is useful for certain graph operations
+ * where the direction of edges is not relevant, such as when checking connectivity or performing undirected graph algorithms.
+ * @param graph The graph to convert.
+ */
+ private static void convertToUndirected(Graph graph) {
+ for (Edge e : new ArrayList<>(graph.getEdges())) {
+ if (e.isDirected()) {
+ e.setEndpoint1(Endpoint.TAIL);
+ e.setEndpoint2(Endpoint.TAIL);
+ }
+ }
+ }
+
+ /**
+ * Checks if there is a path between two nodes in the graph.
+ * This method performs a depth-first search starting from the {@code start} node and checks
+ * if it can reach the {@code target} node. It uses a stack to explore the graph and a set to keep track of visited nodes,
+ * ensuring that it does not revisit nodes.
+ * If the target node is found during the search, it returns true; otherwise,
+ * it returns false after exhausting all possible paths.
+ * @param g The graph to search.
+ * @param start The starting node.
+ * @param target The target node.
+ * @return True if there is a path from start to target, false otherwise.
+ */
+ private static boolean isReachable(Graph g, Node start, Node target) {
+ Set visited = new HashSet<>();
+ Deque stack = new ArrayDeque<>();
+ stack.push(start);
+
+ while (!stack.isEmpty()) {
+ Node current = stack.pop();
+ if (current.equals(target)) return true;
+ if (visited.add(current)) {
+ for (Node neighbor : g.getAdjacentNodes(current)) {
+ if (!visited.contains(neighbor)) {
+ stack.push(neighbor);
+ }
+ }
+ }
+ }
+
+ return false;
+ }
+
+
+ /**
+ * Finds the nodes that are neighbors of node y and x in the graph.
+ * This method retrieves the neighbors of node y that are also adjacent to node x,
+ * ensuring that the edges between them are directed.
+ * It filters out undirected edges to ensure that only neighbors from directed edges are considered.
+ * @param x Node x to find neighbors for.
+ * @param y Node y to find neighbors for.
+ * @param graph The graph in which to find the neighbors.
+ * @return A list of nodes that are neighbors of x and y, filtered to include only neighbors from directed edges.
+ */
+ public static List findNaYX(Node x, Node y, Graph graph) {
+ List naYX = new LinkedList<>(graph.getAdjacentNodes(y));
+ naYX.retainAll(graph.getAdjacentNodes(x));
+
+ for (int i = naYX.size()-1; i >= 0; i--) {
+ Node z = naYX.get(i);
+ Edge edge = graph.getEdge(y, z);
+
+ if (!Edges.isUndirectedEdge(edge)) {
+ naYX.remove(z);
+ }
+ }
+
+ return naYX;
}
-
}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java
new file mode 100644
index 0000000..e965021
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java
@@ -0,0 +1,119 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+class AlphaOrderTest {
+
+ private Node a, b, c;
+ private Dag dag1, dag2;
+ private ArrayList dags;
+
+ @BeforeEach
+ void setup() {
+ a = new GraphNode("A");
+ b = new GraphNode("B");
+ c = new GraphNode("C");
+
+ // DAG 1: A → B → C
+ dag1 = new Dag();
+ dag1.addNode(a);
+ dag1.addNode(b);
+ dag1.addNode(c);
+ dag1.addDirectedEdge(a, b);
+ dag1.addDirectedEdge(b, c);
+
+ // DAG 2: A → B, A → C
+ dag2 = new Dag();
+ dag2.addNode(a);
+ dag2.addNode(b);
+ dag2.addNode(c);
+ dag2.addDirectedEdge(a, b);
+ dag2.addDirectedEdge(a, c);
+
+ dags = new ArrayList<>(Arrays.asList(dag1, dag2));
+ }
+
+ @Test
+ void constructorThrowsOnNullInput() {
+ assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(null));
+ }
+
+ @Test
+ void constructorThrowsOnEmptyList() {
+ assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(new ArrayList<>()));
+ }
+
+ @Test
+ void constructorThrowsOnSingleDAG() {
+ ArrayList singleDagList = new ArrayList<>();
+ singleDagList.add(dag1);
+ assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(singleDagList));
+ }
+
+ @Test
+ void constructorThrowsOnDifferentNodes() {
+ // Crear otro DAG con nodos diferentes
+ Dag dagDifferent = new Dag();
+ dagDifferent.addNode(new GraphNode("X"));
+ dagDifferent.addNode(new GraphNode("Y"));
+ dagDifferent.addDirectedEdge(dagDifferent.getNode("X"), dagDifferent.getNode("Y"));
+
+ ArrayList badList = new ArrayList<>(Arrays.asList(dag1, dagDifferent));
+ assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(badList));
+ }
+
+ @Test
+ void computeAlphaReturnsValidOrder() {
+ AlphaOrder alphaOrder = new AlphaOrder(dags);
+ alphaOrder.computeAlpha();
+ List order = alphaOrder.getOrder();
+
+ assertNotNull(order);
+ assertEquals(3, order.size());
+ assertTrue(order.contains(a));
+ assertTrue(order.contains(b));
+ assertTrue(order.contains(c));
+
+ // Optional: check for uniqueness
+ assertEquals(3, order.stream().distinct().count());
+ }
+
+ @Test
+ void computeAlphaForTwoSimpleDags(){
+ AlphaOrder alphaOrder = new AlphaOrder(dags);
+ alphaOrder.computeAlpha();
+ List order = alphaOrder.getOrder();
+
+ // Basic assertions to check the order
+ assertNotNull(order);
+ assertEquals(3, order.size());
+ assertTrue(order.contains(a));
+ assertTrue(order.contains(b));
+ assertTrue(order.contains(c));
+
+ // Check that the order is A, B, C
+ assertEquals(a, order.get(0));
+ assertEquals(b, order.get(1));
+ assertEquals(c, order.get(2));
+
+ // Check for uniqueness
+ assertEquals(3, order.stream().distinct().count());
+
+ }
+}
+
+
+
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
new file mode 100644
index 0000000..b3ca80f
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java
@@ -0,0 +1,110 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.GraphUtils;
+import edu.cmu.tetrad.graph.Node;
+
+class BackwardEquivalenceSearchDSepTest {
+
+
+ private ArrayList createRandomDagList(int copies) {
+ ArrayList setOfDags = new ArrayList<>();
+ setOfDags.addAll(GraphTestHelper.generateRandomDagList(20, copies, 50, 49, 49, 49, true, 0));
+ return setOfDags;
+ }
+
+ @Test
+ void testApplyBESdDoesNotThrow() {
+ // Setting up consensus union
+ ArrayList initialDags = createRandomDagList(3);
+ ConsensusUnion consensusUnion = new ConsensusUnion(initialDags);
+ Dag unionDag = consensusUnion.union();
+
+ // Running Backward Equivalence Search with d-separation
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, consensusUnion.getTransformedDags());
+
+ // No exceptions should be thrown during the process
+ assertDoesNotThrow(() -> {
+ Dag output = besd.applyBackwardEliminationWithDSeparation();
+ assertNotNull(output);
+ });
+ }
+
+ @Test
+ void testOutputIsDAG() {
+ // Setting up consensus union
+ ArrayList initialDags = createRandomDagList(3);
+ ConsensusUnion consensusUnion = new ConsensusUnion(initialDags);
+ Dag unionDag = consensusUnion.union();
+
+ // Running Backward Equivalence Search with d-separation
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, consensusUnion.getTransformedDags());
+ Dag outputDag = besd.applyBackwardEliminationWithDSeparation();
+
+ assertTrue(GraphUtils.isDag(outputDag), "El resultado no es un DAG válido.");
+ }
+
+ @Test
+ void testAristasSePuedenEliminar() {
+ // Creamos un grafo donde A -> B, pero en todos los DAGs está A B (no conectados)
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+
+ Dag unionDag = new Dag();
+ unionDag.addNode(a);
+ unionDag.addNode(b);
+ unionDag.addDirectedEdge(a, b);
+
+ // DAGs originales sin esa arista
+ Dag dag1 = new Dag();
+ dag1.addNode(a);
+ dag1.addNode(b);
+
+ Dag dag2 = new Dag();
+ dag2.addNode(a);
+ dag2.addNode(b);
+ // Aquí no hay aristas, A y B están desconectados
+ // sin conexión
+
+ ArrayList initialDags = new ArrayList<>();
+ initialDags.add(dag1);
+ initialDags.add(dag2);
+ AlphaOrder alphaOrder = new AlphaOrder(initialDags);
+ alphaOrder.computeAlpha();
+ ArrayList transformedDags = (new TransformDags(initialDags, alphaOrder.getOrder())).transform();
+
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
+ Dag outputDag = besd.applyBackwardEliminationWithDSeparation();
+
+ // Debe eliminar la arista A -> B por no tener soporte
+ Edge deletedEdge = outputDag.getEdge(a, b);
+ assertNull(deletedEdge, "La arista A -> B debería haberse eliminado.");
+ }
+
+ @Test
+ void testGetNumberOfInsertedEdgesReflectsChanges() {
+ ArrayList initialDags = createRandomDagList(2);
+ ConsensusUnion consensusUnion = new ConsensusUnion(initialDags);
+ Dag unionDag = consensusUnion.union();
+ ArrayList transformedDags = consensusUnion.getTransformedDags();
+ int insertedEdgesBefore = consensusUnion.getNumberOfInsertedEdges();
+
+ BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags);
+ besd.applyBackwardEliminationWithDSeparation();
+
+ int insertedEdgesAfter = insertedEdgesBefore - besd.getNumberOfRemovedEdges();
+ // En el peor de los casos no ha eliminado ninguna, pero nunca debe ser negativo
+ assertTrue(insertedEdgesAfter >= 0, "The number of inserted edges should not be negative.");
+ assertTrue(insertedEdgesAfter <= insertedEdgesBefore, "The number of inserted edges should decrease after BES.");
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java
new file mode 100644
index 0000000..e3900e4
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java
@@ -0,0 +1,109 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class BetaToAlphaTest {
+
+ private Dag dag;
+ private Node a, b, c, d;
+ private ArrayList alphaOrder;
+
+ @BeforeEach
+ void setUp() {
+ a = new GraphNode("A");
+ b = new GraphNode("B");
+ c = new GraphNode("C");
+ d = new GraphNode("D");
+
+ // DAG: A → B, A → C, B → D, C → D
+ dag = new Dag();
+ dag.addNode(a);
+ dag.addNode(b);
+ dag.addNode(c);
+ dag.addNode(d);
+ dag.addDirectedEdge(a, b);
+ dag.addDirectedEdge(a, c);
+ dag.addDirectedEdge(b, d);
+ dag.addDirectedEdge(c, d);
+
+ // Define an alpha order that requires modifying the graph
+ alphaOrder = new ArrayList<>(Arrays.asList(d, c, b, a));
+ }
+
+ @Test
+ void testTransformRespectsAlphaOrder() {
+ BetaToAlpha bta = new BetaToAlpha(dag, alphaOrder);
+ bta.transform();
+
+ // El grafo debería haber invertido al menos algunos arcos
+ assertTrue(bta.getNumberOfInsertedEdges() > 0);
+
+ // Validamos que el orden resultante es compatible con alpha
+ for (Edge edge : dag.getEdges()) {
+ Node from = edge.getNode1();
+ Node to = edge.getNode2();
+
+ int fromIndex = alphaOrder.indexOf(from);
+ int toIndex = alphaOrder.indexOf(to);
+
+ assertTrue(fromIndex < toIndex, "Edge violates alpha order: " + from + " → " + to);
+ }
+ }
+
+ @Test
+ void testRandomAlphaProducesPermutation() {
+ BetaToAlpha bta = new BetaToAlpha(dag);
+ List randomAlpha = bta.randomAlpha(new Random(42));
+
+ assertNotNull(randomAlpha);
+ assertEquals(dag.getNumNodes(), randomAlpha.size());
+
+ Set originalNodes = new HashSet<>(dag.getNodes());
+ Set shuffled = new HashSet<>(randomAlpha);
+
+ assertEquals(originalNodes, shuffled); // misma colección, diferente orden
+ }
+
+ @Test
+ void testComputeAlphaHashBuildsCorrectMap() {
+ BetaToAlpha bta = new BetaToAlpha(dag, alphaOrder);
+ bta.computeAlphaHash();
+
+ for (int i = 0; i < alphaOrder.size(); i++) {
+ Node node = alphaOrder.get(i);
+ assertEquals(i, (bta.getAlphaHash()).get(node));
+ }
+ }
+
+
+ @Test
+ public void setterAndGetterTest(){
+ ArrayList firstOrder = new ArrayList<>(Arrays.asList(a, b, c));
+ BetaToAlpha b2Alpha = new BetaToAlpha(dag, firstOrder);
+ List newOrder = new ArrayList<>(Arrays.asList(c, b, a));
+ b2Alpha.setAlphaOrder(newOrder);
+
+ List order = b2Alpha.getAlphaOrder();
+ assertNotNull(order);
+ assertEquals(3, order.size());
+ assertTrue(order.contains(c));
+ assertTrue(order.contains(b));
+ assertTrue(order.contains(a));
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
new file mode 100644
index 0000000..f704120
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java
@@ -0,0 +1,196 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class ConsensusBESTest {
+ private ArrayList inputDags;
+ private ArrayList alpha;
+
+ @BeforeEach
+ public void setUp() {
+ inputDags = new ArrayList<>();
+
+ // We use 4 nodes for the DAGs
+ Node nodeA = new GraphNode("A");
+ Node nodeB = new GraphNode("B");
+ Node nodeC = new GraphNode("C");
+ Node nodeD = new GraphNode("D");
+
+ // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D
+ Dag dag1 = new Dag();
+ dag1.addNode(nodeA);
+ dag1.addNode(nodeB);
+ dag1.addNode(nodeC);
+ dag1.addNode(nodeD);
+
+ // Adding directed edges to the DAG
+ dag1.addDirectedEdge(nodeA, nodeB);
+ dag1.addDirectedEdge(nodeA, nodeC);
+ dag1.addDirectedEdge(nodeB, nodeD);
+ dag1.addDirectedEdge(nodeC, nodeD);
+
+ // Adding the DAG to the list
+ inputDags.add(dag1);
+
+ // Create second DAG with these edges: D -> C, D -> B, C -> A, B -> A
+ Dag dag2 = new Dag();
+ dag2.addNode(nodeA);
+ dag2.addNode(nodeB);
+ dag2.addNode(nodeC);
+ dag2.addNode(nodeD);
+
+ // Adding directed edges to the second DAG
+ dag2.addDirectedEdge(nodeD, nodeC);
+ dag2.addDirectedEdge(nodeD, nodeB);
+ dag2.addDirectedEdge(nodeC, nodeA);
+ dag2.addDirectedEdge(nodeB, nodeA);
+
+ // Adding the second DAG to the list
+ inputDags.add(dag2);
+
+ // Apply AlphaOrder algorithm to these dags:
+ AlphaOrder alphaOrder = new AlphaOrder(inputDags);
+ alphaOrder.computeAlpha();
+ alpha = alphaOrder.getOrder();
+
+ }
+
+ @Test
+ public void testConsensusUnionConsistency() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ Dag expected = cu.union();
+ assertNotNull(cu);
+ assertNotNull(expected);
+
+ ConsensusBES consensusBES = new ConsensusBES(inputDags);
+ consensusBES.consensusUnion();
+
+ // Check if the union DAG is not null and has nodes
+ Dag unionDag = consensusBES.getUnion();
+ assertNotNull(unionDag);
+ assertNotNull(unionDag.getNodes());
+
+ // Check that the union DAG is equal to the expected DAG
+ assertNotNull(unionDag.getNodes());
+ assertNotNull(expected.getNodes());
+ assertEquals(expected, unionDag);
+ assertEquals(expected.getNodes().size(), unionDag.getNodes().size());
+ assertEquals(expected.getEdges().size(), unionDag.getEdges().size());
+
+ for (Node node : expected.getNodes()) {
+ assert unionDag.getNodes().contains(node);
+ }
+ for(Edge edge : expected.getEdges()) {
+ assert unionDag.getEdges().contains(edge);
+ }
+ }
+
+ @Test
+ public void testRandomBNFusion(){
+ // Creating random DAGs
+ ArrayList randomDagsList = new ArrayList<>();
+ randomDagsList.addAll(GraphTestHelper.generateRandomDagList(20, 2, 50, 19, 19, 19, true,0));
+
+ // Creating ConsensusUnion instance
+ ConsensusUnion consensusUnionOnly = new ConsensusUnion(randomDagsList);
+ Dag unionDagOnly = consensusUnionOnly.union();
+
+ assertNotNull(unionDagOnly);
+ assertTrue(unionDagOnly.getNumEdges() >= 0);
+ assertTrue(unionDagOnly.getNodes().size() == randomDagsList.get(0).getNodes().size());
+
+
+ ConsensusBES conDag = new ConsensusBES(randomDagsList);
+ conDag.fusion();
+ Dag besDag = conDag.getFusionDag();
+ Dag unionDag = conDag.getUnion();
+ ConsensusUnion consensusUnion = conDag.getConsensusUnion();
+ int totalNumberOfInsertedEdges = conDag.getNumberOfInsertedEdges();
+ int consensusNumberOfInsertedEdges = consensusUnion.getNumberOfInsertedEdges();
+
+ assertNotNull(besDag);
+ assertNotNull(unionDag);
+ assertNotNull(consensusUnion);
+ assertEquals(unionDagOnly, unionDag);
+ assertEquals(besDag.getNodes().size(), unionDag.getNodes().size());
+ assert consensusNumberOfInsertedEdges >= 0;
+ assert consensusNumberOfInsertedEdges >= totalNumberOfInsertedEdges;
+ }
+
+
+ @Test
+ void testFusionProducesDag() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ fusionAlgorithm.fusion();
+
+ Dag result = fusionAlgorithm.getFusionDag();
+ assertNotNull(result, "El DAG de salida no debe ser null.");
+ assertFalse(result.paths().existsDirectedCycle(), "El DAG resultante no debe tener ciclos.");
+ }
+
+ @Test
+ void testEdgeInsertionCountIsCorrectlyComputed() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ fusionAlgorithm.fusion();
+
+ int insertedEdges = fusionAlgorithm.getNumberOfInsertedEdges();
+ assertTrue(insertedEdges >= 0, "El número de aristas insertadas debe ser >= 0.");
+ }
+
+ @Test
+ void testFusionOrderIsValid() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ fusionAlgorithm.fusion();
+
+ List order = fusionAlgorithm.getOrderFusion();
+ assertNotNull(order, "El orden de fusión no debe ser null.");
+ assertEquals(4, order.size(), "El orden de fusión debe tener 3 nodos.");
+ }
+
+ @Test
+ void testTransformedDagsAreAccessibleAfterFusion() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ fusionAlgorithm.fusion();
+
+ ArrayList transformed = fusionAlgorithm.getTransformedDags();
+ assertEquals(2, transformed.size(), "Debe haber 2 DAGs transformados.");
+ }
+
+ @Test
+ void testGetTransformedDagsWithoutFusionThrowsException() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+
+ assertThrows(IllegalStateException.class, fusionAlgorithm::getTransformedDags,
+ "Debe lanzar una excepción si se accede a los DAGs transformados sin llamar a fusion().");
+ }
+
+ @Test
+ void testThreadExecutionWithRunMethod() {
+ ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags);
+ Thread thread = new Thread(fusionAlgorithm);
+ thread.start();
+ try {
+ thread.join();
+ } catch (InterruptedException e) {
+ fail("El hilo fue interrumpido.");
+ }
+
+ assertNotNull(fusionAlgorithm.getFusionDag(), "El DAG resultante debe existir tras ejecutar run().");
+ }
+
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java
new file mode 100644
index 0000000..09de508
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java
@@ -0,0 +1,151 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Graph;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+import edu.cmu.tetrad.graph.RandomGraph;
+import edu.cmu.tetrad.util.RandomUtil;
+
+public class ConsensusUnionTest {
+
+ private ArrayList inputDags;
+ private ArrayList alpha;
+
+ @BeforeEach
+ public void setUp() {
+ inputDags = new ArrayList<>();
+
+ // We use 4 nodes for the DAGs
+ Node nodeA = new GraphNode("A");
+ Node nodeB = new GraphNode("B");
+ Node nodeC = new GraphNode("C");
+ Node nodeD = new GraphNode("D");
+
+ // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D
+ Dag dag1 = new Dag();
+ dag1.addNode(nodeA);
+ dag1.addNode(nodeB);
+ dag1.addNode(nodeC);
+ dag1.addNode(nodeD);
+
+ // Adding directed edges to the DAG
+ dag1.addDirectedEdge(nodeA, nodeB);
+ dag1.addDirectedEdge(nodeA, nodeC);
+ dag1.addDirectedEdge(nodeB, nodeD);
+ dag1.addDirectedEdge(nodeC, nodeD);
+
+ // Adding the DAG to the list
+ inputDags.add(dag1);
+
+ // Create second DAG with these edges: D -> C, D -> B, C -> A, B -> A
+ Dag dag2 = new Dag();
+ dag2.addNode(nodeA);
+ dag2.addNode(nodeB);
+ dag2.addNode(nodeC);
+ dag2.addNode(nodeD);
+
+ // Adding directed edges to the second DAG
+ dag2.addDirectedEdge(nodeD, nodeC);
+ dag2.addDirectedEdge(nodeD, nodeB);
+ dag2.addDirectedEdge(nodeC, nodeA);
+ dag2.addDirectedEdge(nodeB, nodeA);
+
+ // Adding the second DAG to the list
+ inputDags.add(dag2);
+
+ // Apply AlphaOrder algorithm to these dags:
+ AlphaOrder alphaOrder = new AlphaOrder(inputDags);
+ alphaOrder.computeAlpha();
+ alpha = alphaOrder.getOrder();
+
+ }
+
+ @Test
+ public void testConstructorWithAlphaInitializesCorrectly() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ assertNotNull(cu);
+ }
+
+ @Test
+ public void testConstructorWithoutAlphaGeneratesAlpha() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags);
+ assertNotNull(cu);
+ }
+
+ @Test
+ public void testUnionReturnsNonNullDag() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ Dag result = cu.union();
+ assertNotNull(result);
+ }
+
+ @Test
+ public void testNumberOfInsertedEdgesIsUpdated() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ cu.union();
+ assertTrue(cu.getNumberOfInsertedEdges() >= 0);
+ }
+
+ @Test
+ public void testSetDagsUpdatesAlphaAndUnion() {
+ ConsensusUnion cu = new ConsensusUnion();
+ cu.setDags(inputDags);
+ cu.union();
+ Dag result = cu.getUnion();
+ assertNotNull(result);
+ assertTrue(result.getNumEdges() >= 1);
+ }
+
+ @Test
+ public void testRunMethodExecutesUnion() {
+ ConsensusUnion cu = new ConsensusUnion(inputDags, alpha);
+ cu.run();
+ assertNotNull(cu.getUnion());
+ }
+
+ @Test
+ public void testEmptyDagListReturnsEmptyUnion() {
+ assertThrows(IllegalArgumentException.class, () -> new ConsensusUnion(new ArrayList<>()));
+ }
+
+ @Test
+ public void testRandomBNGeneratesConsensusUnionCorrectly() {
+ ArrayList randomDagsList = new ArrayList<>();
+ int sizeRandomDags = 2;
+ int numVariables = 20;
+
+ // Creating list of shared nodes
+ ArrayList sharedNodes = new ArrayList<>();
+ for (int i = 0; i < numVariables; i++) {
+ Node node = new GraphNode("Node" + i);
+ sharedNodes.add(node);
+ }
+ // Setting seed
+ RandomUtil.getInstance().setSeed(42);
+
+ // Generating random DAGs
+ for (int i = 0; i < sizeRandomDags; i++) {
+ Dag randomDag = RandomGraph.randomDag(sharedNodes,0,50,19,19,19,true);
+ randomDagsList.add(randomDag);
+ }
+
+ // Applying ConsensusUnion
+ ConsensusUnion conDag = new ConsensusUnion(randomDagsList);
+ Graph g = conDag.union();
+
+ // Validating the resulting consensus DAG
+ assertNotNull(g);
+ assertTrue(g.getNumEdges() >= 0);
+ assertTrue(g.getNodes().size() == randomDagsList.get(0).getNodes().size());
+
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java
new file mode 100644
index 0000000..4d5d880
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java
@@ -0,0 +1,91 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+class DSeparationKeyTest {
+
+ private final Node X = new GraphNode("X");
+ private final Node Y = new GraphNode("Y");
+ private final Node Z = new GraphNode("Z");
+ private final Node W = new GraphNode("W");
+
+ @Test
+ void testConstructorAndGetters() {
+ Set zSet = new HashSet<>(Arrays.asList(Z, W));
+ DSeparationKey key = new DSeparationKey(X, Y, zSet);
+
+ assertEquals(X, key.getX());
+ assertEquals(Y, key.getY());
+ assertEquals(new HashSet<>(Arrays.asList(Z, W)), key.getConditioningSet());
+ }
+
+ @Test
+ void testEqualsSameContentDifferentOrder() {
+ Set zSet1 = new HashSet<>(Arrays.asList(Z, W));
+ Set zSet2 = new HashSet<>(Arrays.asList(W, Z));
+
+ DSeparationKey key1 = new DSeparationKey(X, Y, zSet1);
+ DSeparationKey key2 = new DSeparationKey(X, Y, zSet2);
+
+ assertEquals(key1, key2);
+ assertEquals(key1.hashCode(), key2.hashCode());
+ }
+
+ @Test
+ void testNotEqualsDifferentZ() {
+ Set zSet1 = new HashSet<>(Collections.singletonList(Z));
+ Set zSet2 = new HashSet<>(Arrays.asList(Z, W));
+
+ DSeparationKey key1 = new DSeparationKey(X, Y, zSet1);
+ DSeparationKey key2 = new DSeparationKey(X, Y, zSet2);
+
+ assertNotEquals(key1, key2);
+ }
+
+ @Test
+ void testSimmetryBetweenXandY() {
+ DSeparationKey key1 = new DSeparationKey(X, Y, Collections.emptySet());
+ DSeparationKey key2 = new DSeparationKey(Y, X, Collections.emptySet());
+
+ assertEquals(key1, key2);
+ }
+
+ @Test
+ void testEqualsSelf() {
+ DSeparationKey key = new DSeparationKey(X, Y, Collections.singleton(Z));
+ assertEquals(key, key);
+ }
+
+ @Test
+ void testNotEqualsNullOrDifferentClass() {
+ DSeparationKey key = new DSeparationKey(X, Y, Collections.singleton(Z));
+
+ assertNotEquals(null, key);
+ assertNotEquals("NotAKey", key);
+ }
+
+ @Test
+ void testKeyAsMapKey() {
+ DSeparationKey key1 = new DSeparationKey(X, Y, new HashSet<>(Arrays.asList(Z, W)));
+ DSeparationKey key2 = new DSeparationKey(X, Y, new HashSet<>(Arrays.asList(W, Z)));
+
+ Map map = new HashMap<>();
+ map.put(key1, true);
+
+ assertTrue(map.containsKey(key2));
+ assertTrue(map.get(key2));
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java
new file mode 100644
index 0000000..2f92250
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java
@@ -0,0 +1,333 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.Edges;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class DseparationTest {
+ // d-separation tests
+
+ private Node node(String name) {
+ return new GraphNode(name);
+ }
+
+ private Dag createDag(Edge... edges) {
+ Dag dag = new Dag();
+ for (Edge edge : edges) {
+ dag.addNode(edge.getNode1());
+ dag.addNode(edge.getNode2());
+ dag.addDirectedEdge(edge.getNode1(), edge.getNode2());
+ }
+ return dag;
+ }
+
+ @Test
+ public void testDirectConnection() {
+ Node A = node("A"), B = node("B");
+ Dag dag = createDag(Edges.directedEdge(A, B));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, B, conditioning));
+ }
+
+ @Test
+ public void testChainNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testChainWithCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
+
+ assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testColliderNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B));
+ List conditioning = Collections.emptyList();
+ assertTrue(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testColliderConditionedOnCollider() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B));
+
+ assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testDivergingNoCondition() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C));
+ List conditioning = Collections.emptyList();
+ assertFalse(Utils.dSeparated(dag, A, C, conditioning));
+ }
+
+ @Test
+ public void testDivergingConditionedOnCommonParent() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C));
+
+ assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B)));
+ }
+
+ @Test
+ public void testColliderConditionedOnDescendant() {
+ Node A = node("A"), B = node("B"), C = node("C"), D = node("D");
+ Dag dag = createDag(
+ Edges.directedEdge(A, B),
+ Edges.directedEdge(C, B),
+ Edges.directedEdge(B, D)
+ );
+
+ assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(D)));
+ }
+
+ @Test
+ public void testSimmetryBetweenXandY() {
+ Node A = node("A"), B = node("B"), C = node("C"), D = node("D");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, A), Edges.directedEdge(D, A));
+ List conditioning = Collections.emptyList();
+
+ // No colliders between A and B, so they are not d-separated
+ assertFalse(Utils.dSeparated(dag, A, B, conditioning));
+ assertFalse(Utils.dSeparated(dag, B, A, conditioning));
+
+ // C->A and D->A makes A a collider for C and D, and therefore C and D are d-separated from each other
+ assertTrue(Utils.dSeparated(dag, C, D, conditioning));
+ assertTrue(Utils.dSeparated(dag, D, C, conditioning));
+
+ }
+
+ @Test
+ public void testDseparationMethodsAreEquivalent() {
+ Node A = node("A"), B = node("B"), C = node("C");
+ Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C));
+
+ // Using the dSeparated method with no conditioning
+ assertFalse(Utils.dSeparated(dag, A, C));
+
+ // Using the dSeparated method with an empty conditioning set
+ assertFalse(Utils.dSeparated(dag, A, C, Collections.emptyList()));
+ }
+
+ @Test
+ public void testDseparationRule1(){
+ // Scenarios taken from https://yuyangyy.medium.com/understand-d-separation-471f9aada503
+ // Scenario 1: Z is empty, namely, we don't dondition on any variables
+ // Rule 1: If there exists a path from x to y and there is no collider on the path, then x and y are not d-separated.
+ Node x = new GraphNode("x");
+ Node r = new GraphNode("r");
+ Node s = new GraphNode("s");
+ Node t = new GraphNode("t");
+ Node u = new GraphNode("u");
+ Node v = new GraphNode("v");
+ Node y = new GraphNode("y");
+
+ Dag dag = new Dag();
+ dag.addNode(x);
+ dag.addNode(r);
+ dag.addNode(s);
+ dag.addNode(t);
+ dag.addNode(u);
+ dag.addNode(v);
+ dag.addNode(y);
+
+ // Edges: x -> r -> s -> t <- u <- v -> y
+ dag.addDirectedEdge(x, r);
+ dag.addDirectedEdge(r, s);
+ dag.addDirectedEdge(s, t);
+ dag.addDirectedEdge(u, t);
+ dag.addDirectedEdge(v, u);
+ dag.addDirectedEdge(v, y);
+
+ // Check d-separations
+ assertFalse(Utils.dSeparated(dag, x, r));
+ assertFalse(Utils.dSeparated(dag, x, s));
+ assertFalse(Utils.dSeparated(dag, x, t));
+
+ assertTrue(Utils.dSeparated(dag, x, u));
+ assertTrue(Utils.dSeparated(dag, x, v));
+ assertTrue(Utils.dSeparated(dag, x, y));
+
+ assertFalse(Utils.dSeparated(dag, u, v));
+ assertFalse(Utils.dSeparated(dag, u, y));
+ }
+
+ public void testDseparationRule2(){
+ // Scenarios taken from https://yuyangyy.medium.com/understand-d-separation-471f9aada503
+ // Scenario 2: Z is non-empty, and the colliders don't belong to Z or have no children in Z.
+ // Rule 2: If there exists a path from x to y and none of the nodes on the path belongs to Z, then x and y are not d-separated.
+ Node x = new GraphNode("x");
+ Node r = new GraphNode("r");
+ Node s = new GraphNode("s");
+ Node t = new GraphNode("t");
+ Node u = new GraphNode("u");
+ Node v = new GraphNode("v");
+ Node y = new GraphNode("y");
+
+ Dag dag = new Dag();
+ dag.addNode(x);
+ dag.addNode(r);
+ dag.addNode(s);
+ dag.addNode(t);
+ dag.addNode(u);
+ dag.addNode(v);
+ dag.addNode(y);
+
+ // Edges: x -> r -> s -> t <- u <- v -> y
+ dag.addDirectedEdge(x, r);
+ dag.addDirectedEdge(r, s);
+ dag.addDirectedEdge(s, t);
+ dag.addDirectedEdge(u, t);
+ dag.addDirectedEdge(v, u);
+ dag.addDirectedEdge(v, y);
+
+ // Creating a conditioning set Z that does not include any colliders
+ List Z = new ArrayList<>();
+ Z.add(r);
+ Z.add(v);
+
+ // Check d-separations
+ // Node x
+ assertTrue(Utils.dSeparated(dag, x, r, Z));
+ assertTrue(Utils.dSeparated(dag, x, s, Z));
+ assertTrue(Utils.dSeparated(dag, x, t, Z));
+ assertTrue(Utils.dSeparated(dag, x, u, Z));
+ assertTrue(Utils.dSeparated(dag, x, v, Z));
+ assertTrue(Utils.dSeparated(dag, x, y, Z));
+
+ // Node r
+ assertTrue(Utils.dSeparated(dag, r, s, Z));
+ assertTrue(Utils.dSeparated(dag, r, t, Z));
+ assertTrue(Utils.dSeparated(dag, r, u, Z));
+ assertTrue(Utils.dSeparated(dag, r, v, Z));
+ assertTrue(Utils.dSeparated(dag, r, y, Z));
+
+ // Node s
+ assertFalse(Utils.dSeparated(dag, s, t, Z)); // No node on the path belongs to Z, so d-separated
+ assertTrue(Utils.dSeparated(dag, s, u, Z));
+ assertTrue(Utils.dSeparated(dag, s, v, Z));
+ assertTrue(Utils.dSeparated(dag, s, y, Z));
+
+ // Node t
+ assertFalse(Utils.dSeparated(dag, t, u, Z)); // No node on the path belongs to Z, so d-separated
+ assertTrue(Utils.dSeparated(dag, t, v, Z));
+ assertTrue(Utils.dSeparated(dag, t, y, Z));
+
+ // Node u
+ assertTrue(Utils.dSeparated(dag, u, v, Z));
+ assertTrue(Utils.dSeparated(dag, u, y, Z));
+
+ // Node v
+ assertTrue(Utils.dSeparated(dag, v, y, Z));
+
+
+ }
+
+ @Test
+ public void testDseparationRule3(){
+ // Scenarios taken from https://yuyangyy.medium.com/understand-d-separation-471f9aada503
+ // Scenario 3: Z is non-empty, and there are colliders either inside Z or have children in Z.
+ // Rule 3: For colliders that fall inside Z or have children in Z, they are no longer seen as colliders.
+
+ Node x = new GraphNode("x");
+ Node r = new GraphNode("r");
+ Node s = new GraphNode("s");
+ Node t = new GraphNode("t");
+ Node u = new GraphNode("u");
+ Node v = new GraphNode("v");
+ Node y = new GraphNode("y");
+ Node w = new GraphNode("w");
+ Node p = new GraphNode("p");
+ Node q = new GraphNode("q");
+
+ Dag dag = new Dag();
+ dag.addNode(x);
+ dag.addNode(r);
+ dag.addNode(s);
+ dag.addNode(t);
+ dag.addNode(u);
+ dag.addNode(v);
+ dag.addNode(y);
+ dag.addNode(w);
+ dag.addNode(p);
+ dag.addNode(q);
+
+ // Edges: x -> r -> s -> t <- u <- v -> y + r->w, t->p, v->q
+ dag.addDirectedEdge(x, r);
+ dag.addDirectedEdge(r, s);
+ dag.addDirectedEdge(s, t);
+ dag.addDirectedEdge(u, t);
+ dag.addDirectedEdge(v, u);
+ dag.addDirectedEdge(v, y);
+ dag.addDirectedEdge(r, w);
+ dag.addDirectedEdge(t, p);
+ dag.addDirectedEdge(v, q);
+
+ // Creating a conditioning set Z that includes colliders and their children
+ List Z = new ArrayList<>();
+ Z.add(r);
+ Z.add(p);
+
+ // Check d-separations
+ // Node x
+ assertTrue(Utils.dSeparated(dag, x, s, Z));
+ assertTrue(Utils.dSeparated(dag, x, t, Z));
+ assertTrue(Utils.dSeparated(dag, x, u, Z));
+ assertTrue(Utils.dSeparated(dag, x, v, Z));
+ assertTrue(Utils.dSeparated(dag, x, y, Z));
+ assertTrue(Utils.dSeparated(dag, x, w, Z));
+ assertTrue(Utils.dSeparated(dag, x, q, Z));
+
+ // Node w
+ assertTrue(Utils.dSeparated(dag, w, s, Z));
+ assertTrue(Utils.dSeparated(dag, w, t, Z));
+ assertTrue(Utils.dSeparated(dag, w, u, Z));
+ assertTrue(Utils.dSeparated(dag, w, v, Z));
+ assertTrue(Utils.dSeparated(dag, w, y, Z));
+ assertTrue(Utils.dSeparated(dag, w, q, Z));
+
+ // Node s
+ assertFalse(Utils.dSeparated(dag, s, t, Z));
+ assertFalse(Utils.dSeparated(dag, s, u, Z));
+ assertFalse(Utils.dSeparated(dag, s, v, Z));
+ assertFalse(Utils.dSeparated(dag, s, q, Z));
+ assertFalse(Utils.dSeparated(dag, s, y, Z));
+
+ // Node t
+ assertFalse(Utils.dSeparated(dag, t, u, Z));
+ assertFalse(Utils.dSeparated(dag, t, v, Z));
+ assertFalse(Utils.dSeparated(dag, t, y, Z));
+ assertFalse(Utils.dSeparated(dag, t, q, Z));
+
+ // Node u
+ assertFalse(Utils.dSeparated(dag, u, v, Z));
+ assertFalse(Utils.dSeparated(dag, u, y, Z));
+ assertFalse(Utils.dSeparated(dag, u, q, Z));
+
+ // Node v
+ assertFalse(Utils.dSeparated(dag, v, y, Z));
+ assertFalse(Utils.dSeparated(dag, v, q, Z));
+
+ // Node y
+ assertFalse(Utils.dSeparated(dag, y, q, Z));
+
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java
new file mode 100644
index 0000000..c386850
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java
@@ -0,0 +1,140 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.EdgeListGraph;
+import edu.cmu.tetrad.graph.Endpoint;
+import edu.cmu.tetrad.graph.Graph;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+
+public class FindNaYXTest {
+
+ //find naYX tests
+ @Test
+ public void testFindNaYX_singleUndirectedCommonNeighbor() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node z = new GraphNode("Z");
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(z);
+
+ graph.addEdge(new Edge(x, z, Endpoint.TAIL, Endpoint.TAIL)); // undirected
+ graph.addEdge(new Edge(y, z, Endpoint.TAIL, Endpoint.TAIL)); // undirected
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertEquals(1, result.size());
+ assertTrue(result.contains(z));
+ }
+
+ @Test
+ public void testFindNaYX_directedEdgeShouldBeExcluded() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node z = new GraphNode("Z");
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(z);
+
+ graph.addEdge(new Edge(x, z, Endpoint.ARROW, Endpoint.TAIL)); // x → z
+ graph.addEdge(new Edge(y, z, Endpoint.ARROW, Endpoint.TAIL)); // y → z
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertEquals(0, result.size());
+ }
+
+ @Test
+ public void testFindNaYX_mixedNeighbors() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node z1 = new GraphNode("Z1"); // undirected common
+ Node z2 = new GraphNode("Z2"); // directed common
+ Node z3 = new GraphNode("Z3"); // only adjacent to x
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(z1);
+ graph.addNode(z2);
+ graph.addNode(z3);
+
+ // z1: undirected edge with both
+ graph.addEdge(new Edge(x, z1, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(y, z1, Endpoint.TAIL, Endpoint.TAIL));
+
+ // z2: directed edges with both
+ graph.addEdge(new Edge(x, z2, Endpoint.TAIL, Endpoint.ARROW));
+ graph.addEdge(new Edge(y, z2, Endpoint.TAIL, Endpoint.ARROW));
+
+ // z3: only adjacent to x
+ graph.addEdge(new Edge(x, z3, Endpoint.TAIL, Endpoint.TAIL));
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertEquals(1, result.size());
+ assertTrue(result.contains(z1));
+ assertFalse(result.contains(z2));
+ assertFalse(result.contains(z3));
+ }
+
+ @Test
+ public void testFindNaYX_multipleUndirectedCommonNeighbors() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(a);
+ graph.addNode(b);
+
+ graph.addEdge(new Edge(x, a, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(y, a, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(x, b, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(y, b, Endpoint.TAIL, Endpoint.TAIL));
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertEquals(2, result.size());
+ assertTrue(result.contains(a));
+ assertTrue(result.contains(b));
+ }
+
+ @Test
+ public void testFindNaYX_noCommonNeighbors() {
+ Graph graph = new EdgeListGraph();
+
+ Node x = new GraphNode("X");
+ Node y = new GraphNode("Y");
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+
+ graph.addNode(x);
+ graph.addNode(y);
+ graph.addNode(a);
+ graph.addNode(b);
+
+ graph.addEdge(new Edge(x, a, Endpoint.TAIL, Endpoint.TAIL));
+ graph.addEdge(new Edge(y, b, Endpoint.TAIL, Endpoint.TAIL));
+
+ List result = Utils.findNaYX(x, y, graph);
+ assertTrue(result.isEmpty());
+ }
+
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/GraphTestHelper.java b/src/test/java/es/uclm/i3a/simd/consensusBN/GraphTestHelper.java
new file mode 100644
index 0000000..01061e8
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/GraphTestHelper.java
@@ -0,0 +1,60 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+import edu.cmu.tetrad.graph.RandomGraph;
+import edu.cmu.tetrad.util.RandomUtil;
+
+public class GraphTestHelper {
+
+
+ private GraphTestHelper() {
+ // Private constructor to prevent instantiation
+ }
+
+ /**
+ * Generates a list of random DAGs sharing the same set of nodes.
+ *
+ * @param numVariables Number of variables (nodes) in each DAG.
+ * @param numDags Number of random DAGs to generate.
+ * @param maxEdges Maximum number of edges in each DAG.
+ * @param maxInDegree Maximum in-degree for each node.
+ * @param maxOutDegree Maximum out-degree for each node.
+ * @param maxDegree Maximum degree for each node.
+ * @param connected Whether the generated DAGs should be connected.
+ * @param seed Seed for random number generation.
+ * @return List of randomly generated DAGs
+ */
+ public static List generateRandomDagList(int numVariables, int numDags, int maxEdges, int maxInDegree, int maxOutDegree, int maxDegree, boolean connected, long seed) {
+ List randomDagsList = new ArrayList<>();
+
+ // Create shared nodes
+ List sharedNodes = new ArrayList<>();
+ for (int i = 0; i < numVariables; i++) {
+ sharedNodes.add(new GraphNode("Node" + i));
+ }
+
+ // Set seed
+ RandomUtil.getInstance().setSeed(seed);
+
+ // Generate DAGs
+ for (int i = 0; i < numDags; i++) {
+ Dag randomDag = RandomGraph.randomDag(
+ sharedNodes,
+ 0, // Latent variables (0 by default)
+ maxEdges,
+ maxDegree,
+ maxInDegree,
+ maxOutDegree,
+ connected
+ );
+ randomDagsList.add(randomDag);
+ }
+
+ return randomDagsList;
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVotingTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVotingTest.java
new file mode 100644
index 0000000..253c19e
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVotingTest.java
@@ -0,0 +1,111 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.GraphUtils;
+import edu.cmu.tetrad.graph.Node;
+
+public class HeuristicConsensusMVotingTest {
+
+ private Dag createSimpleDag(String from, String to) {
+ Node n1 = new GraphNode(from);
+ Node n2 = new GraphNode(to);
+ Dag dag = new Dag(Arrays.asList(n1, n2));
+ dag.addDirectedEdge(n1, n2);
+ return dag;
+ }
+
+ @Test
+ public void testFusionCreatesDAGWithExpectedEdges() {
+ Dag dag1 = createSimpleDag("A", "B");
+ Dag dag2 = createSimpleDag("A", "B");
+ ArrayList dags = new ArrayList<>(Arrays.asList(dag1, dag2));
+
+ HeuristicConsensusMVoting mvoting = new HeuristicConsensusMVoting(dags, 0.5);
+ Dag consensus = mvoting.fusion();
+
+ assertNotNull(consensus);
+ assertTrue(GraphUtils.isDag(consensus));
+ assertEquals(2, consensus.getNumNodes());
+ assertEquals(1, consensus.getNumEdges());
+ assertTrue(consensus.isParentOf(getNodeByName(consensus, "A"), getNodeByName(consensus, "B")));
+
+ // Test getters
+ assertEquals(0.5, mvoting.getPercentage());
+ assertEquals(2, mvoting.getVariables().size());
+ assertEquals(2, mvoting.getWeight().length);
+ assertEquals(2, mvoting.getWeight()[0].length);
+ assertEquals(2, mvoting.getWeight()[1].length);
+ assertEquals(0.0, mvoting.getWeight()[0][1]);
+ assertEquals(0.0, mvoting.getWeight()[1][0]);
+ assertEquals(dags, mvoting.getSetOfdags());
+ assertEquals(consensus, mvoting.getOutputDag());
+
+ }
+
+ @Test
+ public void testFusionDoesNotAddLowWeightEdges() {
+ Dag dag1 = createSimpleDag("A", "B");
+ Dag dag2 = createSimpleDag("B", "A"); // Conflicting direction
+
+ ArrayList dags = new ArrayList<>(Arrays.asList(dag1, dag2));
+ HeuristicConsensusMVoting mvoting = new HeuristicConsensusMVoting(dags, 0.75);
+ Dag consensus = mvoting.fusion();
+
+ // Expect no edge because weight is 0.5 < 0.75
+ assertEquals(0, consensus.getNumEdges());
+ }
+
+ @Test
+ public void testFusionDoesNotCreateCycle() {
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+ Node c = new GraphNode("C");
+
+ Dag dag1 = new Dag(Arrays.asList(a, b, c));
+ dag1.addDirectedEdge(a, b);
+ dag1.addDirectedEdge(b, c);
+
+ Dag dag2 = new Dag(Arrays.asList(a, b, c));
+ dag2.addDirectedEdge(a, b);
+ dag2.addDirectedEdge(b, c);
+
+ ArrayList dags = new ArrayList<>(Arrays.asList(dag1, dag2));
+ HeuristicConsensusMVoting mvoting = new HeuristicConsensusMVoting(dags, 0.5);
+ Dag consensus = mvoting.fusion();
+
+ assertTrue(GraphUtils.isDag(consensus), "La fusión no debe crear ciclos.");
+ }
+
+ @Test
+ public void testWeightMatrixCorrectlyComputed() {
+ Dag dag1 = createSimpleDag("A", "B");
+ Dag dag2 = createSimpleDag("A", "B");
+ ArrayList dags = new ArrayList<>(Arrays.asList(dag1, dag2));
+
+ HeuristicConsensusMVoting mvoting = new HeuristicConsensusMVoting(dags, 0.5);
+
+ int indexA = mvoting.getVariables().indexOf(new GraphNode("A"));
+ int indexB = mvoting.getVariables().indexOf(new GraphNode("B"));
+
+ double weightAB = mvoting.getWeight()[indexA][indexB];
+ double expectedWeight = 1.0; // Dos DAGS, misma dirección
+
+ assertEquals(expectedWeight, weightAB, 1e-6);
+ }
+
+ private Node getNodeByName(Dag dag, String name) {
+ return dag.getNodes().stream()
+ .filter(n -> n.getName().equals(name))
+ .findFirst()
+ .orElseThrow(() -> new IllegalArgumentException("Node not found: " + name));
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNsTest.java
new file mode 100644
index 0000000..462dbb0
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNsTest.java
@@ -0,0 +1,112 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+
+public class HierarchicalAgglomerativeClustererBNsTest {
+
+ private ArrayList inputDags;
+
+ @BeforeEach
+ public void setUp() {
+ int numVariables = 4; // Number of variables in the DAGs
+ int numDags = 10; // Number of DAGs to generate
+ int maxEdges = 6; // Maximum number of edges in each DAG
+ int maxInDegree = 2; // Maximum in-degree for each node
+ int maxOutDegree = 2; // Maximum out-degree for each node
+ int maxDegree = 3; // Maximum degree for each node
+ boolean connected = true; // Whether the DAGs should be connected
+ long seed = 42; // Seed for random number generation
+
+ // Generate a list of random DAGs using GraphTestHelper
+ inputDags = new ArrayList<>();
+ inputDags.addAll(GraphTestHelper.generateRandomDagList(numVariables, numDags, maxEdges, maxInDegree, maxOutDegree, maxDegree, connected, seed));
+ }
+
+
+ @Test
+ public void testConstructorAndGetSetOfBNs() {
+
+ HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 2);
+
+ assertEquals(inputDags.size(), clusterer.getSetOfBNs().size());
+ assertEquals(inputDags, clusterer.getSetOfBNs());
+ }
+
+ @Test
+ public void testClusterStopsEarlyDueToMaxSize() {
+
+ HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 1);
+ int numDagsAfterCluster = clusterer.cluster();
+
+ // Only one fusion should occur since maxSize is 1
+ assertEquals((int)inputDags.size()/2, numDagsAfterCluster);
+ }
+
+ @Test
+ public void testGetClustersOutputAtLevelZero() {
+
+ HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 2);
+ clusterer.cluster();
+
+ ArrayList output = clusterer.getClustersOutput(0);
+ assertEquals(inputDags.size(), output.size(), "En el nivel 0 debe haber tantos DAGs como en la entrada");
+ }
+
+ @Test
+ public void testGetInsertedEdges() {
+ HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 2);
+ clusterer.cluster();
+
+ int insertedEdges = clusterer.getInsertedEdges(1);
+ assertTrue(insertedEdges >= 0, "Debe haber al menos 0 enlaces insertadas en el nivel 1");
+ }
+
+ @Test
+ public void testComputeConsensusDag() {
+ HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 2);
+ int level = clusterer.cluster();
+
+ Dag consensus = clusterer.computeConsensusDag(level);
+ assertNotNull(consensus, "El DAG de consenso no debería ser null");
+ }
+
+
+
+ /*
+ @Test
+ public void testFullClusteringUntilOneCluster() {
+ // Sin restricciones de tamaño ni complejidad
+ HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 0.0);
+ int result = clusterer.cluster();
+
+ // Debe haber n-1 fusiones si todo fue bien, por lo que al final solo debe quedar un cluster
+ assertEquals(1, result, "Deberían haberse hecho n-1 fusiones");
+
+ ArrayList resultClusters = clusterer.getClustersOutput(result);
+ assertEquals(1, resultClusters.size(), "Al final solo debe quedar un cluster");
+ }
+
+ @Test
+ public void testClusteringStopsDueToComplexity() {
+ ArrayList dags = new ArrayList<>();
+ dags.addAll(GraphTestHelper.generateRandomDagList(10, 2, 30, 10, 10, 10, true, 42));
+
+ // Muy bajo el umbral de complejidad para forzar que no se fusionen
+ HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(dags, 0.01);
+ int result = clusterer.cluster();
+
+ assertTrue(result <= 1, "El clustering debería detenerse porque los DAGs fusionados son muy complejos");
+ }
+ */
+
+}
+
+
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ListFabricTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ListFabricTest.java
new file mode 100644
index 0000000..5a7744f
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ListFabricTest.java
@@ -0,0 +1,70 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+public class ListFabricTest {
+
+ @BeforeEach
+ public void setUp() {
+ // Aseguramos el valor por defecto del maxSize antes de cada test
+ ListFabric.MAX_SIZE = 2;
+ }
+
+ @Test
+ public void testGetList_size3_maxSize2() {
+ List result = Arrays.stream(ListFabric.generateList(3)).boxed().collect(Collectors.toList());
+ List expected = Arrays.asList(
+ 0b000, // []
+ 0b001, // [C]
+ 0b010, // [B]
+ 0b100, // [A]
+ 0b011, // [B,C]
+ 0b101, // [A,C]
+ 0b110 // [A,B]
+ );
+ assertEquals(expected.size(), result.size());
+ assertTrue(result.containsAll(expected));
+ }
+
+ @Test
+ public void testGetList_size0() {
+ List result = Arrays.stream(ListFabric.generateList(0)).boxed().collect(Collectors.toList());
+ assertEquals(1, result.size());
+ assertEquals(0, (int) result.get(0)); // Solo el conjunto vacío
+ }
+
+ @Test
+ public void testGetList_maxSize0() {
+ ListFabric.MAX_SIZE = 0;
+ List result = Arrays.stream(ListFabric.generateList(3)).boxed().collect(Collectors.toList());
+ assertEquals(1, result.size());
+ assertEquals(0, (int) result.get(0)); // Solo el conjunto vacío
+ }
+
+ @Test
+ public void testNoSubsetExceedsMaxSize() {
+ int size = 4;
+ List result = Arrays.stream(ListFabric.generateList(size)).boxed().collect(Collectors.toList());
+ for (int subset : result) {
+ int ones = Integer.bitCount(subset);
+ assertTrue(ones <= ListFabric.MAX_SIZE,
+ "Subset " + Integer.toBinaryString(subset) + " has " + ones + " bits set");
+ }
+ }
+
+ @Test
+ public void testSymmetry_sizeEqualsMaxSize() {
+ int size = 3;
+ ListFabric.MAX_SIZE = 3;
+ List result = Arrays.stream(ListFabric.generateList(size)).boxed().collect(Collectors.toList());
+ // All subsets should be included (2^3 = 8)
+ assertEquals(8, result.size());
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java
new file mode 100644
index 0000000..25a14b7
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java
@@ -0,0 +1,181 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.Arrays;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class PairWiseConsensusBESTest {
+
+ private Node A, B, C;
+ private Dag dag1, dag2;
+
+ @BeforeEach
+ public void setup() {
+ A = new GraphNode("A");
+ B = new GraphNode("B");
+ C = new GraphNode("C");
+
+ dag1 = new Dag();
+ dag1.addNode(A);
+ dag1.addNode(B);
+ dag1.addNode(C);
+ dag1.addDirectedEdge(A, B);
+ dag1.addDirectedEdge(B, C);
+
+ dag2 = new Dag();
+ dag2.addNode(A);
+ dag2.addNode(B);
+ dag2.addNode(C);
+ dag2.addDirectedEdge(A, C);
+ dag2.addDirectedEdge(C, B);
+ }
+
+ @Test
+ public void testFusionCreatesNonNullDag() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.fusion();
+
+ Dag fusion = pwc.getDagFusion();
+ assertNotNull(fusion, "The fusion DAG should not be null");
+ }
+
+ @Test
+ public void testGetNumberOfInsertedEdges() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.fusion();
+
+ int inserted = pwc.getNumberOfInsertedEdges();
+ assertTrue(inserted >= 0, "Inserted edges should be >= 0");
+ }
+
+ @Test
+ public void testGetNumberOfUnionEdges() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.fusion();
+
+ PairWiseConsensusBES samePwc = new PairWiseConsensusBES(dag1, dag1);
+ samePwc.fusion();
+
+ int unionEdges = pwc.getNumberOfUnionEdges();
+ assertTrue(unionEdges > 0, "Union should contain some edges");
+
+ int sameUnionEdges = samePwc.getNumberOfUnionEdges();
+ assertTrue(sameUnionEdges == dag1.getNumEdges(), "Union edges should match the number of edges in a single DAG");
+ }
+
+ @Test
+ public void testGetHammingDistance() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ int distance = pwc.calculateHammingDistance();
+ PairWiseConsensusBES samePwc = new PairWiseConsensusBES(dag1, dag1);
+ int sameDistance = samePwc.calculateHammingDistance();
+
+ assertTrue(distance >= 0, "Hamming distance should be >= 0");
+ assertTrue(sameDistance == 0, "Hamming distance for identical DAGs should be 0");
+ }
+
+ @Test
+ public void testRunCallsGetFusion() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.run();
+
+ assertNotNull(pwc.getDagFusion(), "Fusion DAG should be created after run()");
+ assertTrue(pwc.getNumberOfUnionEdges() > 0, "Union edges should be computed");
+ }
+
+ @Test
+ public void testFusionIsConsistentWithInput() {
+ PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2);
+ pwc.fusion();
+ Dag fusion = pwc.getDagFusion();
+
+ Set edges = fusion.getEdges();
+ assertNotNull(edges);
+ assertFalse(edges.isEmpty(), "Fusion DAG should contain edges");
+ }
+
+ // Exception throws test cases
+
+ @Test
+ public void testNullDags() {
+ IllegalArgumentException ex1 = assertThrows(IllegalArgumentException.class,
+ () -> new PairWiseConsensusBES(null, createValidDag()));
+ assertEquals("Input DAGs cannot be null.", ex1.getMessage());
+
+ IllegalArgumentException ex2 = assertThrows(IllegalArgumentException.class,
+ () -> new PairWiseConsensusBES(createValidDag(), null));
+ assertEquals("Input DAGs cannot be null.", ex2.getMessage());
+ }
+
+ @Test
+ public void testEmptyNodes() {
+ Dag dag1 = new Dag(); // no nodes
+ Dag dag2 = createValidDag();
+
+ IllegalArgumentException ex = assertThrows(IllegalArgumentException.class,
+ () -> new PairWiseConsensusBES(dag1, dag2));
+ assertEquals("Input DAGs must contain at least one node.", ex.getMessage());
+ }
+
+ @Test
+ public void testEmptyEdges() {
+ Node n1 = new GraphNode("X");
+ Node n2 = new GraphNode("Y");
+ Dag dag1 = new Dag(Arrays.asList(n1, n2)); // 2 nodes, no edges
+ Dag dag2 = createValidDag(); // tiene al menos un edge
+
+ IllegalArgumentException ex = assertThrows(IllegalArgumentException.class,
+ () -> new PairWiseConsensusBES(dag1, dag2));
+ assertEquals("Input DAGs must contain at least one edge.", ex.getMessage());
+ }
+
+ @Test
+ public void testDifferentNodeCounts() {
+ Dag dag1 = createValidDag();
+ Dag dag2 = createValidDag();
+ dag2.addNode(new GraphNode("Extra"));
+
+ IllegalArgumentException ex = assertThrows(IllegalArgumentException.class,
+ () -> new PairWiseConsensusBES(dag1, dag2));
+ assertEquals("Input DAGs must have the same number of nodes.", ex.getMessage());
+ }
+
+ @Test
+ public void testDifferentNodeSets() {
+ Node n1 = new GraphNode("A");
+ Node n2 = new GraphNode("B");
+ Node n3 = new GraphNode("C");
+
+ Dag dag1 = new Dag(Arrays.asList(n1, n2));
+ dag1.addDirectedEdge(n1, n2);
+
+ Dag dag2 = new Dag(Arrays.asList(n1, n3)); // C en vez de B
+ dag2.addDirectedEdge(n1, n3);
+
+ IllegalArgumentException ex = assertThrows(IllegalArgumentException.class,
+ () -> new PairWiseConsensusBES(dag1, dag2));
+ assertEquals("Input DAGs must have the same set of nodes.", ex.getMessage());
+ }
+
+ // Helper para crear un DAG válido
+ private Dag createValidDag() {
+ Node a = new GraphNode("A");
+ Node b = new GraphNode("B");
+
+ Dag dag = new Dag(Arrays.asList(a, b));
+ dag.addDirectedEdge(a, b);
+ return dag;
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/PowerSetTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/PowerSetTest.java
new file mode 100644
index 0000000..437d97c
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/PowerSetTest.java
@@ -0,0 +1,89 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class PowerSetTest {
+
+ private List nodeList;
+
+ @BeforeEach
+ public void setUp() {
+ nodeList = new ArrayList<>();
+ nodeList.add(new GraphNode("X"));
+ nodeList.add(new GraphNode("Y"));
+ nodeList.add(new GraphNode("Z"));
+ }
+
+ @Test
+ public void testPowerSetWithMaxSize() {
+ PowerSet ps = new PowerSet(nodeList, 2);
+ List> result = new ArrayList<>();
+ while (ps.hasMoreElements()) {
+ result.add(ps.nextElement());
+ }
+
+ // Verifica que no hay subconjuntos de tamaño > 2
+ for (Set subset : result) {
+ assertTrue(subset.size() <= 2, "Subset size should be <= 2");
+ }
+
+ // Comprobamos algunos subconjuntos esperados
+ HashSet expected1 = new HashSet<>();
+ expected1.add(nodeList.get(0));
+ HashSet expected2 = new HashSet<>();
+ expected2.add(nodeList.get(0));
+ expected2.add(nodeList.get(1));
+ assertTrue(result.contains(new HashSet<>(expected1)));
+ assertTrue(result.contains(new HashSet<>(expected2)));
+ }
+
+ @Test
+ public void testPowerSetWithoutMaxSize() {
+ PowerSet ps = new PowerSet(nodeList);
+ List> result = new ArrayList<>();
+ while (ps.hasMoreElements()) {
+ result.add(ps.nextElement());
+ }
+
+ // Número de subconjuntos debería ser 2^n
+ assertEquals(7, result.size());
+
+ // El conjunto vacío debería estar incluido
+ assertTrue(result.contains(new HashSet()));
+
+ // El conjunto completo no está incluido en el resultado
+ assertTrue(!result.contains(new HashSet<>(nodeList)));
+ }
+
+ @Test
+ public void testResetIndex() {
+ PowerSet ps = new PowerSet(nodeList);
+ assertTrue(ps.hasMoreElements());
+ ps.nextElement();
+ ps.resetIndex();
+ assertTrue(ps.hasMoreElements());
+ }
+
+ @Test
+ public void testMaxPowerSetSize() {
+ PowerSet powerSet = new PowerSet(nodeList); // Esto actualiza el valor de maxPow
+ assertEquals(8L, powerSet.maxPowerSetSize());
+ }
+
+ @Test
+ public void testMaxSizeIsNegativeShouldThrowException() {
+ assertThrows(IllegalArgumentException.class, () -> new PowerSet(nodeList, -1));
+ }
+}
diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java
new file mode 100644
index 0000000..86cb1d6
--- /dev/null
+++ b/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java
@@ -0,0 +1,137 @@
+package es.uclm.i3a.simd.consensusBN;
+
+import java.util.ArrayList;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import edu.cmu.tetrad.graph.Dag;
+import edu.cmu.tetrad.graph.GraphNode;
+import edu.cmu.tetrad.graph.Node;
+
+public class TransformDagsTest {
+
+ private ArrayList inputDags;
+ private ArrayList alpha;
+
+ @BeforeEach
+ public void setUp() {
+ inputDags = new ArrayList<>();
+
+ // We use 4 nodes for the DAGs
+ Node nodeA = new GraphNode("A");
+ Node nodeB = new GraphNode("B");
+ Node nodeC = new GraphNode("C");
+ Node nodeD = new GraphNode("D");
+
+ // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D
+ Dag dag1 = new Dag();
+ dag1.addNode(nodeA);
+ dag1.addNode(nodeB);
+ dag1.addNode(nodeC);
+ dag1.addNode(nodeD);
+
+ // Adding directed edges to the DAG
+ dag1.addDirectedEdge(nodeA, nodeB);
+ dag1.addDirectedEdge(nodeA, nodeC);
+ dag1.addDirectedEdge(nodeB, nodeD);
+ dag1.addDirectedEdge(nodeC, nodeD);
+
+ // Adding the DAG to the list
+ inputDags.add(dag1);
+
+ // Create second DAG with these edges: D -> C, D -> B, C -> A, B -> A
+ Dag dag2 = new Dag();
+ dag2.addNode(nodeA);
+ dag2.addNode(nodeB);
+ dag2.addNode(nodeC);
+ dag2.addNode(nodeD);
+
+ // Adding directed edges to the second DAG
+ dag2.addDirectedEdge(nodeD, nodeC);
+ dag2.addDirectedEdge(nodeD, nodeB);
+ dag2.addDirectedEdge(nodeC, nodeA);
+ dag2.addDirectedEdge(nodeB, nodeA);
+
+ // Adding the second DAG to the list
+ inputDags.add(dag2);
+
+ // Apply AlphaOrder algorithm to these dags:
+ AlphaOrder alphaOrder = new AlphaOrder(inputDags);
+ alphaOrder.computeAlpha();
+ alpha = alphaOrder.getOrder();
+
+ }
+
+ @Test
+ public void testConstructorGettersAndSetters() {
+ TransformDags transformer = new TransformDags(inputDags, alpha);
+ ArrayList empytBetas = new ArrayList<>();
+
+ assertNotNull(transformer);
+ assertEquals(0, transformer.getNumberOfInsertedEdges());
+
+ // GetSetOfOutputDags
+ ArrayList outputDags = transformer.getSetOfOutputDags();
+ assertNotNull(outputDags);
+ assertTrue(outputDags.isEmpty());
+ transformer.transform();
+ outputDags = transformer.getSetOfOutputDags();
+ assertNotNull(outputDags);
+ assertEquals(inputDags.size(), outputDags.size());
+ assertNotEquals(inputDags, outputDags);
+
+ // Testing setTransformers and getTransformers
+ ArrayList betas = transformer.getTransformers();
+ assertNotNull(betas);
+ assertTrue(!betas.isEmpty());
+ transformer.setTransformers(empytBetas);
+ assertEquals(empytBetas, transformer.getTransformers());
+ assertTrue(transformer.getTransformers().isEmpty());
+
+ // GetAlphaOrder
+ ArrayList