diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..9afb123 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI - Maven Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout repo + uses: actions/checkout@v4 + + - name: Set up Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }} + gpg-passphrase: ${{ secrets.GPG_PASSPHRASE }} + + - name: Cache Maven packages + uses: actions/cache@v4 + with: + path: ~/.m2 + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + + - name: Build, Test, and Check Coverage with Maven (80% coverage to pass) + run: mvn clean verify -Dgpg.passphrase=${{ secrets.GPG_PASSPHRASE }} \ No newline at end of file diff --git a/.github/workflows/maven-publish.yml b/.github/workflows/maven-publish.yml new file mode 100644 index 0000000..1ea1dcf --- /dev/null +++ b/.github/workflows/maven-publish.yml @@ -0,0 +1,46 @@ +# This workflow will build a package using Maven and then publish it to GitHub packages when a release is created +# For more information see: https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#apache-maven-with-a-settings-path + +name: Maven Package + +on: + release: + types: [created] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + + - name: Build with Maven + run: mvn -B package --file pom.xml + + - name: Publish to GitHub Packages Apache Maven + run: mvn deploy + env: + GITHUB_TOKEN: ${{ github.token }} # GITHUB_TOKEN is the default env for the password + + - name: Set up Apache Maven Central + uses: actions/setup-java@v4 + with: # running setup-java again overwrites the 
settings.xml + distribution: 'temurin' + java-version: '17' + server-id: maven # Value of the distributionManagement/repository/id field of the pom.xml + server-username: MAVEN_USERNAME # env variable for username in deploy + server-password: MAVEN_CENTRAL_TOKEN # env variable for token in deploy + gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }} # Value of the GPG private key to import + gpg-passphrase: MAVEN_GPG_PASSPHRASE # env variable for GPG private key passphrase + + - name: Publish to Apache Maven Central + run: mvn deploy + env: + MAVEN_USERNAME: ${{ secrets.CENTRAL_TOKEN_USERNAME }} + MAVEN_CENTRAL_TOKEN: ${{ secrets.CENTRAL_TOKEN_PASSWORD }} + MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} \ No newline at end of file diff --git a/README.md b/README.md index 9ee7a1b..bd873ec 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # consensusBN - Bayesian Network Fusion +[![CI](https://github.com/UCLM-SIMD/consensusBN/actions/workflows/ci.yml/badge.svg)](https://github.com/UCLM-SIMD/consensusBN/actions/workflows/ci.yml) ![Java](https://img.shields.io/badge/Java-8%2B-blue) ![Maven](https://img.shields.io/badge/Maven-3.6%2B-orange) [![License](https://img.shields.io/badge/license-MIT-green)](LICENSE) @@ -8,7 +9,7 @@ `consensusBN` is a Java-based library for Bayesian Network Fusion. This project allows users to combine multiple Bayesian networks into a single consensus network, leveraging the power of consensus-based modeling techniques. The project is supported by a published paper [(link)](https://www.sciencedirect.com/science/article/abs/pii/S156625352030364X), titled "Efficient and accurate structural fusion of Bayesian networks." 
-![Bayesian Network Fusion](assets/bn_fusion.png) +![Bayesian Network Fusion](assets/bn_fusion.jpg) ## Features diff --git a/assets/bn_fusion.jpg b/assets/bn_fusion.jpg new file mode 100644 index 0000000..e7a2b52 Binary files /dev/null and b/assets/bn_fusion.jpg differ diff --git a/pom.xml b/pom.xml index 3e3466f..131401d 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ 4.0.0 - org.albacete.simd + io.github.jlaborda consensusBN 1.0.0 jar @@ -55,12 +55,12 @@ - - io.github.cmu-phil - tetrad-lib - - 7.6.4 - + + io.github.cmu-phil + tetrad-lib + + 7.6.4 + + + + org.junit.jupiter + junit-jupiter + 5.10.0 + test + + + + org.apache.commons + commons-math3 + 3.6.1 + + @@ -92,17 +106,38 @@ + + + org.apache.maven.plugins + maven-gpg-plugin + 3.1.0 + + + sign-artifacts + verify + + sign + + + + + + --pinentry-mode + loopback + + + + org.apache.maven.plugins maven-source-plugin - 3.3.0 + 3.2.1 attach-sources - verify - jar-no-fork + jar @@ -112,7 +147,7 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.2.0 + 3.3.1 attach-javadocs @@ -123,23 +158,99 @@ + org.apache.maven.plugins maven-surefire-plugin - 2.22.2 + 3.1.2 + + + org.junit.platform + junit-platform-engine + 1.10.0 + + + org.junit.jupiter + junit-jupiter-engine + 5.10.0 + + + + + + + org.jacoco + jacoco-maven-plugin + 0.8.10 + + + **/RandomBN.class + + + + + + prepare-agent + + + + report + verify + + report + + + ${project.build.directory}/jacoco-report + + + + check + + check + + + + + PACKAGE + + + LINE + COVEREDRATIO + 0.80 + + + + + + + + + + + + + org.sonatype.central + central-publishing-maven-plugin + 0.8.0 + true + + central + true + + diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java b/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java index 55e7e2c..aaf7b73 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/AlphaOrder.java @@ -3,152 +3,157 @@ import java.util.ArrayList; import java.util.LinkedList; import 
java.util.List; + import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.Edges; import edu.cmu.tetrad.graph.Endpoint; import edu.cmu.tetrad.graph.Node; +/** + * This class implements a heuristic to compute an ancestral order of nodes for a set of DAGs. + * The heuristic is based on finding the best sink node in each iteration for the set of DAGs, + * removing it from the DAGs, and repeating the process until all nodes are ordered. + */ public class AlphaOrder { - ArrayList setOfDags = null; - ArrayList alpha = null; - ArrayList setOfauxG = null; -// ArrayList dpaths = null; + /** + * The set of DAGs to compute the ancestral order from. + */ + private final ArrayList setOfDags; + /** + * The computed ancestral order of nodes. + */ + private ArrayList alpha; + /** + * A set of auxiliary DAGs used during the computation. + */ + private final ArrayList setOfauxG; + /** + * Constructor for the AlphaOrder class. + * Initializes the set of DAGs and creates a copy of each DAG to work with. + * @param dags the list of DAGs from which to compute the ancestral order. + * This constructor creates a deep copy of each DAG to avoid modifying the original DAGs during + * the computation of the ancestral order. 
+ */ public AlphaOrder(ArrayList dags){ - + // Check if the dags are valid + checkExceptions(dags); + + // Initialize the class variables this.setOfDags = dags; - this.alpha = new ArrayList(); - this.setOfauxG = new ArrayList(); -// this.dpaths = new ArrayList(); - + this.alpha = new ArrayList<>(); + this.setOfauxG = new ArrayList<>(); for (Dag i : setOfDags) { Dag aux_G = new Dag(i); setOfauxG.add(aux_G); -// dpaths.add(computeDirectedPathFromTo(aux_G)); } - } - - public int[][] computeDirectedPathFromTo(Dag graph) { - LinkedList dpathNewEdges = new LinkedList(); - dpathNewEdges.clear(); - dpathNewEdges.addAll(graph.getEdges()); - List dpathNodes = null; - dpathNodes = graph.getNodes(); - - int numNodes = dpathNodes.size(); - int [][] dpath = new int[numNodes][numNodes]; + + /** + * Checks for exceptions in the input set of DAGs. + * Throws an IllegalArgumentException if the set is null, empty, or contains DAGs with different nodes. + * Also checks that the size of the set is greater than 1. + * @param setOfDags the set of DAGs to check for exceptions. 
+ */ + private void checkExceptions(ArrayList setOfDags) { + // Check if setOfDags is null + if(setOfDags == null) { + throw new IllegalArgumentException("The set of DAGs is null."); + } + + // Check if all DAGs have the same nodes + if (setOfDags.isEmpty()) { + throw new IllegalArgumentException("The set of DAGs is empty."); + } + // Check that the size is greater than 1 + if(setOfDags.size() <= 1) { + throw new IllegalArgumentException("The set of DAGs has only one DAG."); + } - while (!dpathNewEdges.isEmpty()) { - Edge edge = dpathNewEdges.removeFirst(); - Node _nodeT = Edges.getDirectedEdgeTail(edge); - Node _nodeH = Edges.getDirectedEdgeHead(edge); - int _indexT = dpathNodes.indexOf(_nodeT); - int _indexH = dpathNodes.indexOf(_nodeH); - dpath[_indexT][_indexH] = 1; - int dPathT = 0; - int dPathH = 0; - int mindPath = 0; - for (int i = 0; i < dpathNodes.size(); i++) { - dPathT = dpath[i][_indexT]; - if (dpath[i][_indexT] >= 1) { - dPathH = dpath[i][_indexH]; - if(dPathH == 0) dpath[i][_indexH] = dPathT+1; - else{ - mindPath = Math.min(dPathH, dPathT+1); - dpath[i][_indexH]=mindPath; - } - } - dPathH = dpath[_indexH][i]; - if(dpath[_indexH][i] >= 1){ - dPathT = dpath[_indexT][i]; - if(dPathT ==0) dpath[_indexT][i] = dPathH+1; - else{ - mindPath = Math.min(dPathT, dPathH+1); - dpath[_indexT][i] = mindPath; - } - - } + // Check that all DAGs have the same nodes + List firstDagNodes = setOfDags.get(0).getNodes(); + for (Dag dag : setOfDags) { + if (!dag.getNodes().equals(firstDagNodes)) { + throw new IllegalArgumentException("All DAGs must have the same nodes. Dag " + dag + " has different nodes than the rest of DAGs."); } } - return dpath; - } + } + + /** + * Returns the nodes of the first DAG in the set, since all DAGs are assumed to have the same nodes. + * @return the nodes of the first DAG. + */ public List getNodes(){ return(setOfDags.get(0).getNodes()); } - // heursitica para orden de conceso basada en el numero de caminos dirigidos. 
(Es muy mala no se utiliza) - - public void computeAlphaH1(){ - - List nodes = setOfDags.get(0).getNodes(); - LinkedList alpha = new LinkedList(); - - while(nodes.size()>0){ - int index_alpha = computeNextH1(nodes); - Node node_alpha = nodes.get(index_alpha); - alpha.addFirst(node_alpha); - for(Dag g: this.setOfauxG){ - removeNode(g,node_alpha); - //int[][] newDpaths = computeDirectedPathFromTo(g); -// this.dpaths.set(this.setOfauxG.indexOf(g), newDpaths); - } - nodes.remove(node_alpha); - } - this.alpha = new ArrayList(alpha); - } - - // heuistica para encontrar un orden de conceso. Se basa en los enlaces que generaria seguir una secuencia creada desde los nodos sumideros hacia arriba. - -public void computeAlphaH2(){ + /** + * This method computes the heuristic to find an ancestral order of nodes of consensus. It is based on the number of edges that would be added on a sequence created from the sink nodes upwards. + * It iteratively finds the node with the minimum number of changes (inversions and additions of edges) and adds it to the beginning of the order. 
+ * */ + public void computeAlpha(){ + // Get nodes and initialize the alpha list List nodes = setOfDags.get(0).getNodes(); - LinkedList alpha = new LinkedList(); + LinkedList alpha_aux = new LinkedList<>(); - while(nodes.size()>0){ - int index_alpha = computeNextH2(nodes); - Node node_alpha = nodes.get(index_alpha); - alpha.addFirst(node_alpha); + while(!nodes.isEmpty()){ + int index_alpha = computeNextSink(nodes); + Node nodeAlpha = nodes.get(index_alpha); + alpha_aux.addFirst(nodeAlpha); for(Dag g: this.setOfauxG){ - removeNode(g,node_alpha); + removeNode(g,nodeAlpha); } - nodes.remove(node_alpha); + nodes.remove(nodeAlpha); } - this.alpha = new ArrayList(alpha); + this.alpha = new ArrayList<>(alpha_aux); } - - int computeNextH2(List nodes){ + /** + * Gets the following node in the order based on the minimum number of changes (inversions and additions of edges) that would be required to create a sequence from the sink nodes upwards. + * @param nodes Remaining nodes to be ordered. + * @return index of the node that should be added next to the order. + */ + private int computeNextSink(List nodes){ - int changes = 0; + // Setting up variables to count changes + int changes; int inversion = 0; int addition = 0; int indexNode = 0; int min = Integer.MAX_VALUE; - + + // Iterate through each node to find the one with the minimum changes for the list of DAGs. for(int i=0; i inserted = new ArrayList(); + // Checking total amount of inversions. We add -1 to give relevance to nodes that are already sinks. List children = g.getChildren(nodei); inversion += (children.size()-1); + + // Checking edge additions from parents of each child to nodei and from parents of nodei to children. 
+ ArrayList inserted = new ArrayList<>(); List paX = g.getParents(nodei); for(Node child: children){ List paY = g.getParents(child); + // For each parent of nodei, check if it has an edge to the child for(Node nodep: paX){ - if(g.getEdge(nodep, child)==null){ - addition++; - } + if(g.getEdge(nodep, child)==null){ + addition++; + } } + // For each parent of the child, check if it has an edge to nodei for(Node nodec: paY){ if(!nodec.equals(nodei)){ + // If there is no edge between nodec and nodei, we consider adding it if((g.getEdge(nodec,nodei)==null) && (g.getEdge(nodei,nodec)==null)){ Edge toBeInserted = new Edge(nodec,nodei,Endpoint.CIRCLE,Endpoint.CIRCLE); boolean contains = false; + // Checking if we have already added this edge to the list of inserted edges + // to avoid counting it multiple times. for(Edge e: inserted){ if((e.getNode1().equals(nodec) && (e.getNode2().equals(nodei))) || ((e.getNode1().equals(nodei) && (e.getNode2().equals(nodec))))){ @@ -156,6 +161,7 @@ int computeNextH2(List nodes){ break; } } + // Checkin if there is a new edge addition, we update the counter and the list of inserted edges if so. if(!contains){ addition++; inserted.add(toBeInserted); @@ -165,117 +171,109 @@ int computeNextH2(List nodes){ } } } + // Calculate total changes for the current node changes = inversion + addition; + // If the current node has less changes than the minimum found so far, we update the minimum and the index of the node + // to be added to the order. if(changes < min){ min = changes; indexNode = i; } - changes = 0; + // Resetting changes for the next iteration inversion = 0; addition = 0; } return indexNode; } - void removeNode(Dag g, Node node_alpha){ + /** + * Removes a node from the DAG and updates the edges according to a new node added to the alpha order. + * It removes a sink node and updates the edges to maintain the directed paths in the DAG. + * This is done each iteration of the heuristic to compute the alpha order. 
+ * @param g the DAG from which the node is to be removed. + * @param nodeAlpha the node to be removed from the DAG. + */ + private void removeNode(Dag g, Node nodeAlpha){ - List children = g.getChildren(node_alpha); + List children = g.getChildren(nodeAlpha); while(!children.isEmpty()){ - int i=0; - Node child; - boolean seguir = false; - do{ - child = children.get(i++); - g.removeEdge(node_alpha, child); - seguir=false; - if(g.paths().existsDirectedPath(node_alpha,child)){ - seguir=true; - g.addEdge(new Edge(node_alpha,child,Endpoint.TAIL, Endpoint.ARROW)); - } - }while(seguir); + // 1. Select a child that prevents a cycle when nodeAlpha <- child is added. + Node child = selectChild(g, nodeAlpha, children); - List paX = g.getParents(node_alpha); - List paY = g.getParents(child); - paY.remove(node_alpha); - g.addEdge(new Edge(child,node_alpha,Endpoint.TAIL, Endpoint.ARROW)); - for(Node nodep: paX){ - Edge pay = g.getEdge(nodep, child); - if(pay == null) - g.addEdge(new Edge(nodep,child,Endpoint.TAIL,Endpoint.ARROW)); - - } - for(Node nodep : paY){ - Edge paz = g.getEdge(nodep,node_alpha); - if(paz == null) - g.addEdge(new Edge(nodep,node_alpha,Endpoint.TAIL,Endpoint.ARROW)); - } + // 2. Cover the edge nodeAlpha -> child by adding edges from parents of nodeAlpha to child and from parents of child to nodeAlpha. Last of all we revert the edge nodeAlpha -> child. + // This is done to maintain the directed paths in the DAG. + coverEdge(g, nodeAlpha, child); + // 3. Delete the child from the list of children of nodeAlpha, as it has been processed. children.remove(child); } - g.removeNode(node_alpha); + // Finally, remove the nodeAlpha from the DAG. + g.removeNode(nodeAlpha); } + /** + * Selects a child node from the list of children of nodeAlpha that does not create a cycle when an edge from nodeAlpha to the child is added (nodeAlpha <- child). + * @param g the DAG from which the child is to be selected. + * @param nodeAlpha the node from the alpha order heuristic. 
+ * @param children the remaining children of nodeAlpha in the DAG. + * @return the selected child node that does not create a cycle when an edge from nodeAlpha to the child is added. + */ + private Node selectChild(Dag g, Node nodeAlpha, List children) { + int i=0; + Node child; + boolean endCondition; + do{ + child = children.get(i++); + g.removeEdge(nodeAlpha, child); + endCondition=false; + if(g.paths().existsDirectedPath(nodeAlpha,child)){ + endCondition=true; + g.addEdge(new Edge(nodeAlpha,child,Endpoint.TAIL, Endpoint.ARROW)); + } + }while(endCondition); + return child; + } - int computeNextH1(List nodes){ - - int min = Integer.MAX_VALUE; - int minIndex = 0; + /** + * Covers the edge from nodeAlpha to child by adding edges from parents of nodeAlpha to child and from parents of child to nodeAlpha. + * This is done to maintain the directed paths in the DAG after removing nodeAlpha. + * @param g the DAG where the edge is to be covered. + * @param nodeAlpha the node from the alpha order heuristic. + * @param child the child node selected from the list of children of nodeAlpha. + */ + private void coverEdge(Dag g, Node nodeAlpha, Node child) { + // Getting the parents of nodeAlpha and child. + List paX = g.getParents(nodeAlpha); + List paY = g.getParents(child); + paY.remove(nodeAlpha); - for(int i=0 ; i< nodes.size(); i++){ - int weightNodei = 0; - //for(Dag dag : this.setOfauxG){ - // int[][] dpath = this.dpaths.get(this.setOfauxG.indexOf(dag)); - // for(int j=0 ; j child. + g.addEdge(new Edge(child,nodeAlpha,Endpoint.TAIL, Endpoint.ARROW)); } + + - public ArrayList getOrder(){ - + /** + * Returns the computed ancestral order of nodes. + * @return an ArrayList of nodes representing the ancestral order of the DAGs after applying the alpha order heuristic. 
+ */ + public ArrayList getOrder(){ return this.alpha; } - - - public static void main(String args[]) { - -// ArrayList dags = new ArrayList(); -// ArrayList alfa = new ArrayList(); -// -// -// System.out.println("Grafos de Partida: "); -// System.out.println("---------------------"); -//// Graph graph = GraphConverter.convert("X1-->X5,X2-->X3,X3-->X4,X4-->X1,X4-->X5"); -//// Dag dag = new Dag(graph); -// -// Dag dag = new Dag(); -// // dag = GraphUtils.randomDag(Integer.parseInt(args[0]), Integer.parseInt(args[1]), true); -// dags.add(dag); -// System.out.println("DAG: ---------------"); -// System.out.println(dag.toString()); -// for (int i=0 ; i < Integer.parseInt(args[2])-1 ; i++){ -// // Dag newDag = GraphUtils.randomDag(dag.getNodes(),Integer.parseInt(args[1]) ,true); -// dags.add(newDag); -// System.out.println("DAG: ---------------"); -// System.out.println(newDag.toString()); -// } -// -// AlphaOrder order = new AlphaOrder(dags); -// order.computeAlphaH2(); -// alfa = order.getOrder(); -// -// System.out.println("Orden de Consenso: " + alfa.toString()); - - - } - - } diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java new file mode 100644 index 0000000..d651022 --- /dev/null +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSep.java @@ -0,0 +1,564 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetrad.graph.Endpoint; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.Node; +import 
edu.cmu.tetrad.search.utils.GraphSearchUtils; +import edu.cmu.tetrad.search.utils.MeekRules; +import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag; + +/** + * This class implements the Backward Equivalence Search with D-Separation + * algorithm for consensus Bayesian networks. It uses an implementation of + * second phase of the Greedy Equivalence Search (GES) algorithm, the Backward + * Equivalence Search (BES), to refine a consensus DAG by removing edges while + * ensuring that the resulting graph remains a Directed Acyclic Graph (DAG). + * Since no data is available, the algorithm relies on D-separation to + * determine whether two nodes are conditionally independent given a set of + * other nodes. For this, the algorithm uses the set of input DAGs to check + * whether the deletion of an edge maintains the d-separation condition. + */ +public class BackwardEquivalenceSearchDSep { + /** + * The graph representing the consensus DAG after applying the Backward + * Equivalence Search with D-separation. + * This graph is built from the union of the transformed input DAGs and is + * refined by removing edges based on d-separation checks. + * + * @see ConsensusUnion + * @see TransformDags + */ + private final Graph graph; + + /** + * List of initial DAGs used to check how many edges are deleted. + */ + private final ArrayList transformedDags; + + /** + * List of initial DAGs used to check the d-separation condition. + * This list is used to verify whether the deletion of an edge maintains the + * d-separation condition across all input DAGs. + * + * @see Utils#dSeparated(Dag, Node, Node, List) + */ + private final ArrayList initialDags; + + /** + * The output DAG after applying the Backward Equivalence Search with D-separation. + * This DAG is the final result after removing edges from the consensus DAG + * while ensuring that the d-separation condition is maintained using the input DAGs. 
+ * + * @see Utils#dSeparated(Dag, Node, Node, List) + */ + private Dag outputDag; + + /** + * A map to store the local scores for edge deletions. + * This map is used to cache the scores of edge deletions to avoid redundant calculations. + * The key is a string representation of the edge and its conditioning set, and the value is the score. + */ + private final Map localScore = new HashMap<>(); + + /** + * Number of edges removed during the backward equivalence search process. + * This variable keeps track of the total number of edges that are inserted (deleted) during the + * Backward Equivalence Search with D-separation process. + * + * @see ConsensusUnion#getNumberOfInsertedEdges() + * @see BackwardEquivalenceSearchDSep#applyBackwardEliminationWithDSeparation() + */ + private int numberOfRemovedEdges = 0; + + /** + * Percentage threshold for edge deletion. By default, it is set to 1.0, set to another value for an heuristic search + */ + private double percentage = 1.0; + + /** + * Maximum size of the conditioning set for edge deletion. Set to Integer.MAX_VALUE by default. Another value can be set for an heuristic search. + */ + private int maxSize = Integer.MAX_VALUE; + + /** + * Constructor for BackwardEquivalenceSearchDSep that initializes the properties for the search with a union DAG and lists of initial and transformed DAGs. + * + * @param union The resulting union DAG from the ConsensusUnion process. + * @param initialDags List of initial DAGs used to check the d-separation condition. + * @param transformedDags List of transformed DAGs after applying the alpha order. 
+ */ + public BackwardEquivalenceSearchDSep(Dag union, ArrayListinitialDags, ArrayList transformedDags) { + this.graph = new EdgeListGraph(new LinkedList<>(union.getNodes())); + for (Edge edge : union.getEdges()) { + graph.addEdge(edge); + } + this.initialDags = initialDags; + this.transformedDags = transformedDags; + } + + /** + * Applies the Backward Equivalence Search with D-separation to the consensus DAG. + * This method iteratively removes edges from the consensus DAG while ensuring that the d-separation condition is maintained across all input DAGs. + * It returns the final output DAG after all possible edge deletions. + * @return The output DAG after applying the Backward Equivalence Search with D-separation. + */ + public Dag applyBackwardEliminationWithDSeparation(){ + double score = 0; + EdgeCandidate bestCandidate; + + // Creating a pdag from the graph + rebuildPattern(graph); + + // While there are edges to delete, search for the best edge to delete + do { + // Make sure that any undirected edge is transformed into two directed edges + List edges = cleanUndirectedEdges(); + + // Find the best edge to delete + bestCandidate = calculateBestCandidateEdge(edges, score); + /* for (Edge edge : edges) { + // Getting candidate edge to delete + Node candidateTail = Edges.getDirectedEdgeTail(edge); + Node candidateHead = Edges.getDirectedEdgeHead(edge); + + List hNeighbors = getHNeighbors(candidateTail, candidateHead, graph); + PowerSet hSubsets= PowerSetFabric.getPowerSet(candidateTail,candidateHead,hNeighbors); + + while(hSubsets.hasMoreElements()) { + HashSet hSubset=hSubsets.nextElement(); + + // Checking if {naYXH} \ {hSubset} is a clique + List naYXH = findNaYX(candidateTail, candidateHead, graph); + naYXH.removeAll(hSubset); + if (!isClique(naYXH, graph)) { + continue; + } + + // Calculating the score of the candidate edge deletion + double deleteEval = deleteEval(candidateTail, candidateHead, hSubset, graph); + + // Setting limit for deleteEval + if 
(!(deleteEval >= 1.0)) deleteEval = 0.0; + + // If the score is not better than the best score, continue + double evalScore = score + deleteEval; + if (!(evalScore > bestScore)) { + continue; + } + + // Updating variables for the best edge deletion + bestScore = evalScore; + bestTail = candidateTail; + bestHead = candidateHead; + bestSetParents = hSubset; + } + + } */ + // + if (bestCandidate != null) { + score = executeEdgeDeletion(bestCandidate); + } + } while (bestCandidate != null); + + // Rebuild the pattern to ensure the final graph is a DAG + createOutputDag(); + + return outputDag; + } + + /** + * Rebuilds the input graph to ensure it is a valid pattern. + * This method applies the Meek rules to orient the edges and ensure that the graph is a valid pattern. + * It also converts the graph to a PDAG (Partially Directed Acyclic Graph) + * @param graph The graph to validate and rebuild as a PDAG. + */ + private void rebuildPattern(Graph graph) { + GraphSearchUtils.basicCpdag(graph); + pdag(graph); + } + + + /** + * Cleans the undirected edges in the graph by converting them to directed edges. + * This method iterates through the edges of the graph and transforms undirected edges into two directed edges, + * ensuring that the resulting graph maintains only directed edges. + * @return + */ + private List cleanUndirectedEdges() { + Set edges1 = graph.getEdges(); + List edges = new ArrayList<>(); + + for (Edge edge : edges1) { + Node _x = edge.getNode1(); + Node _y = edge.getNode2(); + + if (Edges.isUndirectedEdge(edge)) { + edges.add(Edges.directedEdge(_x, _y)); + edges.add(Edges.directedEdge(_y, _x)); + } else { + edges.add(edge); + } + } + return edges; + } + + /** + * Calculates the best candidate edge for deletion based on the current score and the edges available. + * This method evaluates each edge and its possible conditioning sets to find the edge that, when deleted, + * results in the highest score improvement while maintaining the d-separation condition. 
+ * @param edges List of edges to consider for deletion. + * @param score The current score before any edge deletion. + * @return An EdgeCandidate object representing the best edge to delete, or null if no suitable edge is found. + */ + private EdgeCandidate calculateBestCandidateEdge(List edges, double score){ + double bestScore = score; + EdgeCandidate bestCandidate = null; + for(Edge edge : edges){ + // Getting candidate edge to delete + Node candidateTail = Edges.getDirectedEdgeTail(edge); + Node candidateHead = Edges.getDirectedEdgeHead(edge); + + List hNeighbors = getHNeighbors(candidateTail, candidateHead, graph); + PowerSet hSubsets= new PowerSet(hNeighbors);//PowerSetFabric.getPowerSet(candidateTail,candidateHead,hNeighbors); + + while(hSubsets.hasMoreElements()) { + // Getting a HashSet of hNeighbors + Set hSubset=hSubsets.nextElement(); + + // Checking size of hSubset + if (hSubset.size() > maxSize) { + break; // Skip to next edge if the size exceeds the maximum allowed size + } + + // Checking if {naYXH} \ {hSubset} is a clique + List naYXH = Utils.findNaYX(candidateTail, candidateHead, graph); + naYXH.removeAll(hSubset); + if (!GraphUtils.isClique(naYXH, graph)) { + continue; + } + + // Calculating the score of the candidate edge deletion + double deleteEval = deleteEval(candidateTail, candidateHead, hSubset, graph); + + // Setting limit for deleteEval + if (deleteEval < percentage) deleteEval = 0.0; + + // If the score is not better than the best score, continue + double evalScore = score + deleteEval; + if (!(evalScore > bestScore)) { + continue; + } + + // Updating best candidate edge + bestCandidate = new EdgeCandidate(candidateTail, candidateHead, hSubset); + bestCandidate.score = evalScore; + + // Updating score for the best edge deletion + bestScore = evalScore; + } + } + return bestCandidate; + } + + /** + * Executes the deletion of the best candidate edge from the graph. 
+ * This method removes the edge from the graph and updates the local score map. + * It also rebuilds the pattern after the deletion and updates the number of inserted edges. + * @param bestCandidate The best candidate edge to delete, containing the tail, head, conditioning set, and score. + * @return The score after the edge deletion is executed. + * This score reflects the new state of the graph after the edge has been removed. + */ + private double executeEdgeDeletion(EdgeCandidate bestCandidate) { + Node bestTail; + Node bestHead; + Set bestSetParents; + double score; + double bestScore; + bestTail = bestCandidate.tail; + bestHead = bestCandidate.head; + bestSetParents = bestCandidate.conditioningSet; + bestScore = bestCandidate.score; + + // Applying delete + //System.out.println(" "); + //System.out.println("DELETE " + graph.getEdge(bestTail, bestHead) + bestSetParents.toString() + " (" +bestScore + ")"); + //System.out.println(" "); + delete(bestTail, bestHead, bestSetParents, graph); + + // Rebuilding the pattern after deleting the edge + rebuildPattern(graph); + + // Updating the number of inserted edges + int deletedEdges = 0; + for(int g = 0; g getHNeighbors(Node x, Node y, Graph graph) { + List hNeighbors = new LinkedList<>(graph.getAdjacentNodes(y)); + hNeighbors.retainAll(graph.getAdjacentNodes(x)); + + for (int i = hNeighbors.size() - 1; i >= 0; i--) { + Node z = hNeighbors.get(i); + Edge edge = graph.getEdge(y, z); + if (!Edges.isUndirectedEdge(edge)) { + hNeighbors.remove(z); + } + } + + return hNeighbors; + } + + /** + * Applies the delete operation from Chickering 2002 for the edge x->y in the graph, and updates the edges + * connecting x and y to the nodes in the provided HashSet. This is done to ensure that the same dependency structure is maintained + * while removing the edge between x and y. + * @param tailNode The tail node of the edge to be deleted. + * @param headNode The head node of the edge to be deleted. 
+ * @param subset The set of nodes that will be connected to the tail and head nodes after the deletion. + * @param graph The graph from which the edge is deleted and the connections are updated. + */ + private static void delete(Node tailNode, Node headNode, Set subset, Graph graph) { + graph.removeEdges(tailNode, headNode); + + for (Node aSubset : subset) { + if (!graph.isParentOf(aSubset, tailNode) && !graph.isParentOf(tailNode, aSubset)) { + graph.removeEdge(tailNode, aSubset); + graph.addDirectedEdge(tailNode, aSubset); + } + graph.removeEdge(headNode, aSubset); + graph.addDirectedEdge(headNode, aSubset); + } + } + + /** + * Evaluates the impact of deleting an edge from the graph based on d-separation. + * + * This method computes a score for deleting the edge from {@code x} to {@code y}, + * taking into account a conditioning set of nodes {@code conditioningSet}. It uses + * structural information from the graph to assess whether {@code y} is d-separated + * from {@code x} given the constructed conditioning set. + * + * @param x The source node of the edge to be deleted. + * @param y The target node of the edge to be deleted. + * @param conditioningSet The set of nodes used as conditioning variables (Z) for d-separation. + * @param graph The graph in which the change is being evaluated. + * @return The score resulting from deleting the edge, based on the given context. + */ + private double deleteEval(Node x, Node y, Set conditioningSet, Graph graph){ + // Setup the conditioning set for d-separation by removing the conditioning nodes from the naYX set, adding the parents of y and removing x. + Set finalConditioningSet = new HashSet<>(Utils.findNaYX(x, y, graph)); + finalConditioningSet.removeAll(conditioningSet); + finalConditioningSet.addAll(graph.getParents(y)); + finalConditioningSet.remove(x); + + // Check if y is d-separated from x given the final conditioning set in each graph. 
+ return scoreGraphChangeDelete(y, x, finalConditioningSet); + } + + /** + * Checks if the deletion of an edge from {@code x} to {@code y} maintains the d-separation condition + * across all initial DAGs. If the edge deletion maintains d-separation, it returns a score of 1.0, + * otherwise it returns 0.0. + * + * This method uses a local score map to cache results for efficiency, avoiding redundant calculations + * for the same edge and conditioning set. + * @param x The tail node of the edge to be deleted. + * @param y The head node of the edge to be deleted. + * @param conditioningSet The set of nodes used as conditioning variables (Z) for d-separation. + * @return A score of 1.0 if the edge deletion maintains d-separation, otherwise 0.0. + * + * @see Utils#dSeparated(Dag, Node, Node, List) + * @see DSeparationKey + * + * This method is crucial for ensuring that the edge deletion does not violate the d-separation condition, + * which is essential for maintaining the integrity of the Bayesian network structure. + */ + private double scoreGraphChangeDelete(Node x, Node y, Set conditioningSet) { + // Check if the edge deletion has already been evaluated and cached + DSeparationKey key = new DSeparationKey(y, x, conditioningSet); + Double cached = localScore.get(key); + if (cached != null) { + return cached; + } + + // Evaluating the d-separation condition across all initial DAGs + double eval = 0.0; + for (Dag g : this.initialDags) { + if (Utils.dSeparated(g, x, y, new ArrayList<>(conditioningSet))) { + eval++; + } + } + eval = eval / (double) this.initialDags.size(); + + localScore.put(key, eval); + return eval; + } + /** + * Returns the number of edges that were inserted during the consensus union and backward equivalence search process. + * @return The number of edges that were removed during the backward equivalence search process. 
+ */ + public int getNumberOfRemovedEdges() { + return this.numberOfRemovedEdges; + } + /** + * Sets the percentage threshold for edge deletion. + * This method allows the user to specify a percentage threshold for edge deletion, for an heuristic search. + * The percentage must be between 0.0 and 1.0, where 0.0 means no edges are deleted and 1.0 means all edges are considered for deletion. + * If the percentage is outside this range, an IllegalArgumentException is thrown. + * @param percentage The percentage threshold for edge deletion, must be between 0.0 and 1.0. + * @throws IllegalArgumentException if the percentage is not between 0.0 and 1.0. + */ + public void setPercentage(double percentage) { + if(percentage < 0.0 || percentage > 1.0) { + throw new IllegalArgumentException("Percentage must be between 0.0 and 1.0"); + } + this.percentage = percentage; + } + + /** + * Sets the maximum size of the conditioning set for edge deletion. + * This method allows the user to specify a maximum size for the conditioning set used in edge deletion for an heuristic search. + * The maximum size must be a non-negative integer. If it is negative, an IllegalArgumentException is thrown. + * @param maxSize The maximum size of the conditioning set. + */ + public void setMaxSize(int maxSize) { + if(maxSize < 0) { + throw new IllegalArgumentException("Max size must be a non-negative integer"); + } + this.maxSize = maxSize; + } + + /** + * Returns the percentage threshold for edge deletion. + * This method retrieves the current percentage threshold set for edge deletion. + * @return The percentage threshold for edge deletion. + */ + public double getPercentage() { + return this.percentage; + } + + /** + * Returns the maximum size of the conditioning set for edge deletion. + * This method retrieves the current maximum size set for the conditioning set used in edge deletion. + * @return The maximum size of the conditioning set for edge deletion. 
+ */ + public int getMaxSize() { + return this.maxSize; + } + + /** + * Class representing a candidate edge for deletion in the Backward Equivalence Search. + * This class encapsulates the tail and head nodes of the edge, the conditioning set used for d-separation, + * and the score associated with the edge deletion. + * + * @see BackwardEquivalenceSearchDSep#applyBackwardEliminationWithDSeparation() + * @see Utils#dSeparated(Dag, Node, Node, List) + */ + private class EdgeCandidate { + /** + * The tail node of the edge candidate. + */ + public final Node tail; + + /** + * The head node of the edge candidate. + */ + public final Node head; + + /** + * The conditioning set used for d-separation in the edge candidate. + */ + public final Set conditioningSet; + + /** + * The score associated with the edge candidate deletion. + */ + public double score; + + public EdgeCandidate(Node tail, Node head, Set conditioningSet) { + this.tail = tail; + this.head = head; + this.conditioningSet = conditioningSet; + } + + } + +} diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/BetaToAlpha.java b/src/main/java/es/uclm/i3a/simd/consensusBN/BetaToAlpha.java index 6398bf8..2131af6 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/BetaToAlpha.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/BetaToAlpha.java @@ -4,48 +4,90 @@ import java.util.HashMap; import java.util.List; import java.util.Random; -import edu.cmu.tetrad.graph.Node; + import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.Edge; import edu.cmu.tetrad.graph.Endpoint; +import edu.cmu.tetrad.graph.Node; - +/** + * BetaToAlpha is a class that transforms a directed acyclic graph (DAG) into an I-map minimal with respect to a specified alpha order. + * It constructs a compatible beta order and modifies the graph accordingly. + * The transformation respects the alpha order, ensuring that the resulting graph is consistent with it. 
+ */ public class BetaToAlpha { - Dag G = null; - ArrayList beta = new ArrayList(); - ArrayList alfa = new ArrayList(); - HashMap alfaHash= new HashMap(); - Dag G_aux = null; - int numberOfInsertedEdges = 0; + /** + * The directed acyclic graph (DAG) to be transformed. + */ + private final Dag dag; + + /** + * The beta order derived from the alpha order. + */ + private List beta; + + /** + * The alpha order that the graph should respect. In consensusBN, this alpha order has been created using the AlphaOrder class. + * If null, a random order will be created. + */ + private List alpha; + + /** + * A hash map to store the index of each node in the alpha order for quick access. + */ + private final HashMap alphaHash= new HashMap<>(); - public BetaToAlpha(Dag G, ArrayList alfa){ + /** + * The auxiliary graph used during the transformation process. + */ + private Dag G_aux = null; - this.alfa = alfa; - this.G = G; + /** + * The number of edges inserted during the transformation process. + */ + int numberOfInsertedEdges = 0; + + /** + * Constructor for BetaToAlpha that initializes the graph and alpha order. + * @param dag the directed acyclic graph (DAG) to be transformed. + * @param alpha the alpha order that the graph should respect. + */ + public BetaToAlpha(Dag dag, ArrayList alpha){ + this.alpha = alpha; + this.dag = dag; this.beta = null; - for(int i= 0; i< alfa.size(); i++){ - Node n = alfa.get(i); - alfaHash.put(n, i); + for(int i= 0; i< alpha.size(); i++){ + Node n = alpha.get(i); + alphaHash.put(n, i); } } - public BetaToAlpha(Dag G){ - - this.alfa = null; - this.G = G; + /** + * Constructor for BetaToAlpha that initializes the graph without a specified alpha order. + * A random alpha order will be created instead. + * @param dag the directed acyclic graph (DAG) to be transformed. 
+ */ + public BetaToAlpha(Dag dag){ + this.alpha = null; + this.dag = dag; this.beta = null; - } - void computeAlfaHash(){ + /** + * Computes the alpha hash map if it is not already computed. + * This method populates the alphaHash with the index of each node in the alpha order. + * It is called before any transformation to ensure that the alpha order is respected. + * If the alpha order is null, it will not compute the hash. + */ + public void computeAlphaHash(){ - if(this.alfa !=null){ - if(alfaHash.isEmpty()){ - for(int i= 0; i< alfa.size(); i++){ - Node n = alfa.get(i); - alfaHash.put(n, i); + if(this.alpha !=null){ + if(alphaHash.isEmpty()){ + for(int i= 0; i< alpha.size(); i++){ + Node n = alpha.get(i); + alphaHash.put(n, i); } } } @@ -55,10 +97,15 @@ void computeAlfaHash(){ // Only to test the methods, to build a random order. - public ArrayList randomAlfa (Random aleatorio){ - - List nodes = this.G.getNodes(); - this.alfa = new ArrayList(); + /** + * Builds a random alpha order from the nodes of the graph. This is used for test purposes to ensure that the transformation can handle different orders. + * @param randomGenerator the random number generator to use for shuffling the nodes. + * @return a list of nodes representing a random alpha order. 
+ */ + public List randomAlpha (Random randomGenerator){ + + List nodes = this.dag.getNodes(); + this.alpha = new ArrayList<>(); int[] index = new int[nodes.size()]; @@ -67,202 +114,232 @@ public ArrayList randomAlfa (Random aleatorio){ } for (int j = 0; j < nodes.size(); j++){ - - int indi = aleatorio.nextInt(nodes.size()); - int indj = aleatorio.nextInt(nodes.size()); + int indi = randomGenerator.nextInt(nodes.size()); + int indj = randomGenerator.nextInt(nodes.size()); int sw = index[indi]; index[indi] = index[indj]; index[indj] = sw; } for (int i = 0; i< nodes.size(); i++){ - this.alfa.add(i, nodes.get(index[i])); + this.alpha.add(i, nodes.get(index[i])); } - this.computeAlfaHash(); - return this.alfa; + this.computeAlphaHash(); + return this.alpha; } - + /** + * Transforms the graph G into an I-map minimal with respect to the alpha order. + */ public void transform(){ - this.G_aux = new Dag(this.G); - this.beta = new ArrayList(); + // 1. Create a compatible beta order with the alfa order for the DAG G. + buildBetaOrder(); + + // 2. Transform graph G into an I-map minimal with alpha order + transformWithBeta(); + + } + + /** + * Builds the beta order that best respects the alpha order for the given graph G. + * This method constructs a beta order by identifying sink nodes and arranging them in a way that minimizes the number of edges that violate the alpha order. + * It uses a greedy approach to select the next node based on its position in the alpha order. + * The beta order is constructed such that it is as close as possible to the alpha order while ensuring that the resulting graph is still a DAG. + * + * This method modifies the G_aux graph to reflect the current state of the transformation. + * It also initializes the beta list with the first sink node and iteratively adds nodes to the beta order based on their relationships in the graph. 
+ */ + private void buildBetaOrder() { + this.G_aux = new Dag(this.dag); + this.beta = new ArrayList<>(); + List parents; + + // Compute the sink nodes and add the first one to beta. ArrayList sinkNodes = getSinkNodes(this.G_aux); this.beta.add(sinkNodes.get(0)); - List pa = G_aux.getParents(sinkNodes.get(0)); + parents = G_aux.getParents(sinkNodes.get(0)); this.G_aux.removeNode(sinkNodes.get(0)); sinkNodes.remove(0); + // Compute the new sink nodes - for(Node nodep: pa){ - List chld = G_aux.getChildren(nodep); - if (chld.size() == 0) sinkNodes.add(nodep); - } + updateSinkNodes(sinkNodes, parents); - // Construct beta order as closer as possible to alfa. - + // Construct beta order as close as possible to alpha. while (this.G_aux.getNumNodes()>0){ - // sinkNodes = getSinkNodes(this.G_aux); + // Select fist sink node Node sink = sinkNodes.get(0); - pa = G_aux.getParents(sink); + parents = G_aux.getParents(sink); this.G_aux.removeNode(sink); sinkNodes.remove(0); // Compute the new sink nodes - for(Node nodep: pa){ - List chld = G_aux.getChildren(nodep); - if (chld.size() == 0) sinkNodes.add(nodep); - } + updateSinkNodes(sinkNodes, parents); - int index_alfa_sink = this.alfaHash.get(sink); //this.alfa.indexOf(sink); - boolean ok = true; - int i = 0; - - while(ok){ - - Node nodej = this.beta.get(i); - int index_alfa_nodej = this.alfaHash.get(nodej); //this.alfa.indexOf(nodej); - - if (index_alfa_nodej > index_alfa_sink){ ok = false; break;} - if (this.G.getParents(nodej).contains(sink)){ ok = false; break;} - if (i == this.beta.size()-1){ ok = false; break;} - i++; + // Compute the index to insert the sink node in beta. + int insertIndex = 0; + for (; insertIndex < beta.size(); insertIndex++) { + Node current = beta.get(insertIndex); + if (alphaHash.get(current) > alphaHash.get(sink)) break; + if (dag.getParents(current).contains(sink)) break; } - - this.beta.add(i,sink); + beta.add(insertIndex, sink); } + } +/* FUTURE IDEA: SELECT BEST SINK NODE FROM ALPHA ORDER. 
+ private Node selectBestSinkNode(List sinkNodes) { + return sinkNodes.stream() + .min(Comparator.comparingInt(alfaHash::get)) + .orElse(sinkNodes.get(0)); + } +*/ + /** + * Updates the sink nodes list based on the current list of candidates. + * This method checks each candidate node to see if it has any children in the auxiliary graph G_aux. + * If a candidate node has no children, it is added to the sink nodes list. + * This is used to maintain the integrity of the beta order during the transformation process. + * + * @param sinkNodes the list of current sink nodes to be updated. + * @param candidates the list of candidate nodes to check for children. + */ + private void updateSinkNodes(ArrayList sinkNodes, List candidates) { + // Compute the new sink nodes + for(Node node: candidates){ + List chld = G_aux.getChildren(node); + if (chld.isEmpty()) + sinkNodes.add(node); + } + } + + /** + * Transforms the graph G into an I-map minimal with respect to the alpha order. + * This method rearranges the edges in the graph based on the beta order derived from the alpha order. + * It ensures that the resulting graph respects the alpha order by checking the relationships between nodes and adjusting edges accordingly. + * The transformation modifies the graph in place and updates the beta list to reflect the new order of nodes. + */ + private void transformWithBeta() { + ArrayList orderedNodes = new ArrayList<>(); + // Setting the first node in the orderedNodes list. + orderedNodes.add(this.beta.remove(0)); - // transform graph G into an I-map minimal with alpha order - - ArrayList aux_beta = new ArrayList(); - aux_beta.add(this.beta.get(0)); - this.beta.remove(0); - - while(this.beta.size()>0){ // check each variable from the sink nodes. - - aux_beta.add(this.beta.get(0)); + while(!this.beta.isEmpty()){ + // Setting the next node in the orderedNodes list. 
+ orderedNodes.add(this.beta.get(0)); this.beta.remove(0); - int i = aux_beta.size(); - boolean ok = true; + int i = orderedNodes.size(); + boolean changed = true; - while (ok){ - + while (changed){ if(i==1) break; - ok = false; - Node nodeY = aux_beta.get(i-1); - Node nodeZ = aux_beta.get(i-2); - -// if ((nodeZ != null) && (this.alfa.indexOf(nodeZ) > this.alfa.indexOf(nodeY))){ - if ((nodeZ != null) && (this.alfaHash.get(nodeZ) > this.alfaHash.get(nodeY))){ - if(this.G.getEdge(nodeZ, nodeY) != null){ - List paZ = this.G.getParents(nodeZ); - List paY = this.G.getParents(nodeY); + changed = false; + // Getting the last two nodes in the ordered list + Node nodeY = orderedNodes.get(i-1); + Node nodeZ = orderedNodes.get(i-2); + + // Check if there is an edge from nodeZ to nodeY, if so, cover it. + if ((nodeZ != null) && (this.alphaHash.get(nodeZ) > this.alphaHash.get(nodeY))){ + if(this.dag.getEdge(nodeZ, nodeY) != null){ + List paZ = this.dag.getParents(nodeZ); + List paY = this.dag.getParents(nodeY); paY.remove(nodeZ); - this.G.removeEdge(nodeZ, nodeY); - this.G.addEdge(new Edge(nodeY,nodeZ,Endpoint.TAIL, Endpoint.ARROW)); + this.dag.removeEdge(nodeZ, nodeY); + this.dag.addEdge(new Edge(nodeY,nodeZ,Endpoint.TAIL, Endpoint.ARROW)); for(Node nodep: paZ){ - Edge pay = this.G.getEdge(nodep, nodeY); + Edge pay = this.dag.getEdge(nodep, nodeY); if(pay == null){ - this.G.addEdge(new Edge(nodep,nodeY,Endpoint.TAIL,Endpoint.ARROW)); + this.dag.addEdge(new Edge(nodep,nodeY,Endpoint.TAIL,Endpoint.ARROW)); this.numberOfInsertedEdges++; } } for(Node nodep : paY){ - Edge paz = this.G.getEdge(nodep,nodeZ); + Edge paz = this.dag.getEdge(nodep,nodeZ); if(paz == null){ - this.G.addEdge(new Edge(nodep,nodeZ,Endpoint.TAIL,Endpoint.ARROW)); + this.dag.addEdge(new Edge(nodep,nodeZ,Endpoint.TAIL,Endpoint.ARROW)); this.numberOfInsertedEdges++; } } } - ok = true; - aux_beta.remove(nodeY); - aux_beta.add(i-2,nodeY); + changed = true; + orderedNodes.remove(nodeY); + 
orderedNodes.add(i-2,nodeY); i--; } } } - - this.beta = aux_beta; - + this.beta = orderedNodes; } - + /** + * Returns the number of edges that were inserted during the transformation process. + * This method is useful for understanding how many modifications were made to the original graph to achieve the desired alpha order. + * @return the number of edges that were inserted during the transformation process. + * @see BetaToAlpha#transform() + */ public int getNumberOfInsertedEdges(){ - return this.numberOfInsertedEdges; } - ArrayList getSinkNodes(Dag g){ - - ArrayList sourcesNodes = new ArrayList(); - List nodes = g.getNodes(); - - for (Node nodei : nodes){ - if(g.getChildren(nodei).isEmpty()) sourcesNodes.add(nodei); + /** + * Retrieves the sink nodes from the given directed acyclic graph (DAG). + * A sink node is defined as a node that does not have any children in the graph. + * This method iterates through all nodes in the graph and checks their children to determine if they are sink nodes. + * + * @param dagGraph the directed acyclic graph (DAG) from which to retrieve sink nodes. + * @return an ArrayList of sink nodes that do not have any children in the graph. 
+ */ + private ArrayList getSinkNodes(Dag dagGraph){ + // Get nodes from DAG + ArrayList sinkNodes = new ArrayList<>(); + List nodes = dagGraph.getNodes(); + // Check which nodes don't have children and add them to sinkNodes + for (Node node : nodes){ + if(dagGraph.getChildren(node).isEmpty()){ + sinkNodes.add(node); + } } - return sourcesNodes; - + return sinkNodes; } - - - -// public static void main(String args[]) { -// -// //Graph graph = GraphConverter.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); -// Graph graph = GraphConverter.convert("X2-->X1,X3-->X1,X1-->X4,X5-->X4,X4-->X6"); -// Dag dag = new Dag(graph); -// -// Dag dag2 = GraphUtils.randomDag(dag.getNodes(), 7, true); -//// BayesPm bayesPm = new BayesPm(dag, 3, 3); -//// MlBayesIm bayesIm = new MlBayesIm(bayesPm); -//// -//// Element element = BayesXmlRenderer.getElement(bayesIm); -//// System.out.println("Started with this bayesIm: " + bayesIm); -//// System.out.println("\nGot this XML for it:"); -//// Document xmldoc = new Document(element); -//// Serializer serializer = new Serializer(System.out); -//// serializer.setLineSeparator("\n"); -//// serializer.setIndent(2); -//// try { -//// serializer.write(xmldoc); -//// } -//// catch (IOException e) { -//// throw new RuntimeException(e); -//// } -// -// -// System.out.println(GraphUtils.graphToDot(dag)); -// -// -//// System.out.println("Dag Inicial: "+ dag.toString()); -// -// Random aleatorio = new Random(150); -// BetaToAlpha mt = new BetaToAlpha(dag); -// mt.randomAlfa (aleatorio); -// mt.transform(); -//// System.out.println(mt.G.toString()+" Alfa: "+mt.alfa.toString()+" Beta: "+ mt.beta.toString() ); -// -// System.out.println(GraphUtils.graphToDot(mt.G)); -// -// -// -//// System.out.println("Dag Inicial: "+ dag2.toString()); -// -// System.out.println(GraphUtils.graphToDot(dag2)); -// -// BetaToAlpha mt2 = new BetaToAlpha(dag2); -// Random aleat2 = new Random(150); -// mt2.randomAlfa(aleat2); -// mt2.transform(); -// -//// 
System.out.println(mt2.G.toString()+" Alfa: "+mt2.alfa.toString()+" Beta: "+ mt2.beta.toString() ); -// -// System.out.println(GraphUtils.graphToDot(mt2.G)); -// -// -// -// } - - + /** + * Returns the alpha hash map that contains the index of each node in the alpha order. + * This map is used to quickly access the position of nodes in the alpha order during the transformation process. + * It is particularly useful for ensuring that the resulting graph respects the specified alpha order. + * @return the alpha hash map where keys are nodes and values are their indices in the alpha order. + */ + public HashMap getAlphaHash() { + return alphaHash; + } + + /** + * Sets the alpha order for the transformation. + * This method allows the user to specify a new alpha order for the graph transformation. + * It updates the alpha field and recomputes the alpha hash map to reflect the new order. + * @param alpha the new alpha order to be set for the transformation. + */ + public void setAlphaOrder(List alpha) { + this.alpha = alpha; + this.computeAlphaHash(); + } + + /** + * Returns the alpha order that the graph should respect. + * @return the alpha order as a list of nodes, or null if no alpha order has been set. + */ + public List getAlphaOrder() { + return alpha; } + /** + * Returns the directed acyclic graph (DAG) that has been transformed. + * This method provides access to the modified graph after the transformation has been applied. + * The graph will be an I-map minimal with respect to the specified alpha order. + * + * @see BetaToAlpha#transform() + * @return the transformed directed acyclic graph (DAG) as a Dag object. 
+ */ + public Dag getGraph() { + return dag; + } + + +} + diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java index 9514eb2..951566f 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusBES.java @@ -1,468 +1,166 @@ package es.uclm.i3a.simd.consensusBN; import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; import java.util.List; -import java.util.Map; -import java.util.Set; import edu.cmu.tetrad.graph.Dag; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.EdgeListGraph; -import edu.cmu.tetrad.graph.Edges; -import edu.cmu.tetrad.graph.Endpoint; -import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.search.utils.MeekRules; -import edu.cmu.tetrad.search.utils.GraphSearchUtils; -import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag; -//import experimentosFusion.RandomBN; - - +/** + * This class implements the Optimal Fusion GES^h_d algorithm, which applies a Consensus Union followed by a Backward Equivalence Search (BES) with D-separation. + * The algorithm first computes a consensus DAG from a set of input DAGs using the ConsensusUnion class. + * After obtaining the consensus DAG, it applies the Backward Equivalence Search with D-separation to refine the graph, achieving the optimal fusion BN. + * The resulting output DAG is stored in the outputDag attribute. + */ public class ConsensusBES implements Runnable { - ArrayList alpha = null; - Dag outputDag = null; - AlphaOrder heuristic = null; - TransformDags imaps2alpha = null; - ArrayList setOfdags = null; - ArrayList setOfOutDags = null; - Dag union = null; + /** + * Final output DAG after applying the Consensus Union and Backward Equivalence Search with D-separation. + * This DAG represents the optimal fusion of the input DAGs. 
+ * It is computed by first merging the input DAGs into a consensus DAG and then refining it using the BES with D-separation. + * @see ConsensusUnion + * @see BackwardEquivalenceSearchDSep + */ + protected Dag outputDag; + + /** + * Instance of ConsensusUnion used to compute the consensus DAG from the input DAGs. + * This instance is initialized with the set of input DAGs and computes the alpha order of nodes using AlphaOrder heuristic (Greedy Heuristic Order). + * + * @see ConsensusUnion + * @see AlphaOrder + */ + private final ConsensusUnion consensusUnion; + + /** + * List of input DAGs to be fused using the ConsensusBES algorithm. + */ + private final ArrayList inputDags; + + /** + * List of transformed DAGs after applying the alpha order to the input DAGs. + * @see BetaToAlpha + * @see TransformDags + */ + private ArrayList transformedDags; + + /** + * Resulting DAG afther applying the Consensus Union algorithm. + * This DAG contains the union of all edges from the transformed input DAGs, ensuring that the resulting graph is acyclic. + * The number of edges inserted during the union process can be retrieved using getNumberOfInsertedEdges. + */ + private Dag union = null; + + /** + * Number of edges inserted during the consensus union process and the Backward Equivalence Search process. + */ int numberOfInsertedEdges = 0; - Map localScore = new HashMap(); - - + + /** + * Constructor for ConsensusBES that initializes the union process with a list of DAGs. + * It creates an instance of ConsensusUnion to compute the consensus DAG. + * @param dags the list of input DAGs to be merged. 
+ */ public ConsensusBES(ArrayList dags){ - this.setOfdags = dags; - this.heuristic = new AlphaOrder(this.setOfdags); - - this.heuristic.computeAlphaH2(); - this.alpha = this.heuristic.alpha; - this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha); - - this.imaps2alpha.transform(); - this.numberOfInsertedEdges = imaps2alpha.getNumberOfInsertedEdges(); - this.setOfOutDags = imaps2alpha.setOfOutputDags; - } - - - public int getNumberOfInsertedEdges(){ - return this.numberOfInsertedEdges; + this.inputDags = dags; + this.consensusUnion = new ConsensusUnion(this.inputDags); } - private void consensusUnion(){ - - this.union = new Dag(this.alpha); - for(Node nodei: this.alpha){ - for(Dag d : this.imaps2alpha.setOfOutputDags){ - Listparent = d.getParents(nodei); - for(Node pa: parent){ - if(!this.union.isParentOf(pa, nodei)){ - this.union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW)); - } - } - } - - } -// for(Edge e: this.union.getEdges()){ -// for(Dag d : this.imaps2alpha.setOfOutputDags){ -// if((d.getEdge(e.getNode1(), e.getNode2())==null) && (d.getEdge(e.getNode2(), e.getNode1())==null)) -// this.numberOfInsertedEdges++; -// -// } -// } - + /** + * Performs the consensus union operation by calling the union method of the ConsensusUnion instance. + * This method initializes the union process, transforming the input DAGs based on the alpha order and merging them into a single consensus DAG. + * After the union, it retrieves the transformed DAGs and updates the number of inserted edges. + */ + public void consensusUnion(){ + this.union = this.consensusUnion.union(); + this.transformedDags = this.consensusUnion.getTransformedDags(); + this.numberOfInsertedEdges += consensusUnion.getNumberOfInsertedEdges(); } - // private methods for searching - - + /** + * Applies the fusion process by first performing the consensus union and then applying the Backward Equivalence Search with D-separation. 
+ * This method modifies the outputDag attribute to contain the final fused DAG after applying both steps. + */ public void fusion(){ - - // System.out.println("\n** BACKWARD ELIMINATION SEARCH (BES)"); - //PowerSetFabric.setMode(PowerSetFabric.MODE_BES); - double score = 0; - double bestScore = score; - Graph graph = null; - + // 1. Apply ConsensusUnion to the set of dags consensusUnion(); - - graph = new EdgeListGraph(new LinkedList<>(this.union.getNodes())); - for(Edge e: this.union.getEdges()){ - graph.addEdge(e); - } - - //SearchGraphUtils.dagToPdag(graph); - rebuildPattern(graph); - Node x, y; - Set t = new HashSet(); - do { - x = y = null; - Set edges1 = graph.getEdges(); - List edges = new ArrayList(); - - for (Edge edge : edges1) { - Node _x = edge.getNode1(); - Node _y = edge.getNode2(); - - if (Edges.isUndirectedEdge(edge)) { - edges.add(Edges.directedEdge(_x, _y)); - edges.add(Edges.directedEdge(_y, _x)); - } else { - edges.add(edge); - } - } - for (Edge edge : edges) { - - Node _x = Edges.getDirectedEdgeTail(edge); - Node _y = Edges.getDirectedEdgeHead(edge); - - List hNeighbors = getHNeighbors(_x, _y, graph); -// List> hSubsets = powerSet(hNeighbors); - PowerSet hSubsets= PowerSetFabric.getPowerSet(_x,_y,hNeighbors); - - while(hSubsets.hasMoreElements()) { - SubSet hSubset=hSubsets.nextElement(); - double deleteEval = deleteEval(_x, _y, hSubset, graph); - if (!(deleteEval >= 1.0)) deleteEval = 0.0; - double evalScore = score + deleteEval; - - //System.out.println("Attempt removing " + _x + "-->" + _y + "(" +evalScore + ") "+ hSubset.toString()); - - if (!(evalScore > bestScore)) { - continue; - } - - // INICIO TEST 1 - List naYXH = findNaYX(_x, _y, graph); - naYXH.removeAll(hSubset); - if (!isClique(naYXH, graph)) { -// hSubsets.firstTest(true); // Si pasa para H entonces pasa para cualquier H' | H' contiene H - continue; - } - // FIN TEST 1 - - bestScore = evalScore; - x = _x; - y = _y; - t = hSubset; - } - - } - if (x != null) { - 
System.out.println(" "); - System.out.println("DELETE " + graph.getEdge(x, y) + t.toString() + " (" +bestScore + ")"); - System.out.println(" "); - delete(x, y, t, graph); - rebuildPattern(graph); - int deletedEdges = 0; - for(int g = 0; g subset, Graph graph) { - graph.removeEdges(x, y); - - for (Node aSubset : subset) { - if (!graph.isParentOf(aSubset, x) && !graph.isParentOf(x, aSubset)) { - graph.removeEdge(x, aSubset); - graph.addDirectedEdge(x, aSubset); - } - graph.removeEdge(y, aSubset); - graph.addDirectedEdge(y, aSubset); - } - } - - - private void rebuildPattern(Graph graph) { - GraphSearchUtils.basicCpdag(graph); - pdag(graph); - } - - /** - * Fully direct a graph with background knowledge. I am not sure how to - * adapt Chickering's suggested algorithm above (dagToPdag) to incorporate - * background knowledge, so I am also implementing this algorithm based on - * Meek's 1995 UAI paper. Notice it is the same implemented in PcSearch. - *

*IMPORTANT!* *It assumes all colliders are oriented, as well as - * arrows dictated by time order.* - * - * ELIMINADO BACKGROUND KNOWLEDGE - */ - private void pdag(Graph graph) { - MeekRules rules = new MeekRules(); - rules.setMeekPreventCycles(true); - rules.orientImplied(graph); - } - - - private static boolean isClique(List set, Graph graph) { - List setv = new LinkedList(set); - for (int i = 0; i < setv.size() - 1; i++) { - for (int j = i + 1; j < setv.size(); j++) { - if (!graph.isAdjacentTo(setv.get(i), setv.get(j))) { - return false; - } - } - } - return true; - } - - private static List getHNeighbors(Node x, Node y, Graph graph) { - List hNeighbors = new LinkedList(graph.getAdjacentNodes(y)); - hNeighbors.retainAll(graph.getAdjacentNodes(x)); - - for (int i = hNeighbors.size() - 1; i >= 0; i--) { - Node z = hNeighbors.get(i); - Edge edge = graph.getEdge(y, z); - if (!Edges.isUndirectedEdge(edge)) { - hNeighbors.remove(z); - } - } - - return hNeighbors; - } - - - double deleteEval(Node x, Node y, SubSet h, Graph graph){ - - Set set1 = new HashSet(findNaYX(x, y, graph)); - set1.removeAll(h); - set1.addAll(graph.getParents(y)); - set1.remove(x); - return scoreGraphChangeDelete(y, x, set1); // calcular si y esta d-separado de x dado el set1 en cada grafo. 
- - } - - double scoreGraphChangeDelete(Node y, Node x, Set set){ - - String key = y.getName()+x.getName()+set.toString(); - Double val = this.localScore.get(key); - if(val == null){ - double eval = 0.0; - LinkedList conditioning = new LinkedList(); - conditioning.addAll(set); - for(Dag g: this.setOfdags){ - if(!dSeparated(g,y, x, conditioning)) return 0.0; - } - eval = 1.0; //eval / (double) this.setOfdags.size(); - val = eval; - this.localScore.put(key, val); - return eval; - }else{ - return val.doubleValue(); - } - } - - - boolean dSeparated(Dag g, Node x, Node y, LinkedList cond){ - - LinkedList open = new LinkedList(); - HashMap close = new HashMap(); - open.add(x); - open.add(y); - open.addAll(cond); - while (open.size() != 0){ - Node a = open.getFirst(); - open.remove(a); - close.put(a.toString(),a); - List pa =g.getParents(a); - for(Node p : pa){ - if(close.get(p.toString()) == null){ - if(!open.contains(p)) open.addLast(p); - } - } - } - - Graph aux = new EdgeListGraph(); - - for (Node node : g.getNodes()) aux.addNode(node); - Node nodeT, nodeH; - for (Edge e : g.getEdges()){ - if(!e.isDirected()) continue; - nodeT = e.getNode1(); - nodeH = e.getNode2(); - if((close.get(nodeH.toString())!=null)&&(close.get(nodeT.toString())!=null)){ - Edge newEdge = new Edge(e.getNode1(),e.getNode2(),e.getEndpoint1(),e.getEndpoint2()); - aux.addEdge(newEdge); - } - } - - close = new HashMap(); - for(Edge e: aux.getEdges()){ - if(e.isDirected()){ - Node h; - if(e.getEndpoint1()==Endpoint.ARROW){ - h = e.getNode1(); - }else h = e.getNode2(); - if(close.get(h.toString())==null){ - close.put(h.toString(),h); - List pa = aux.getParents(h); - if(pa.size()>1){ - for(int i = 0 ; i< pa.size() - 1; i++) - for(int j = i+1; j < pa.size(); j++){ - Node p1 = pa.get(i); - Node p2 = pa.get(j); - boolean found = false; - for(Edge edge : aux.getEdges()){ - if(edge.getNode1().equals(p1)&&(edge.getNode2().equals(p2))){ - found = true; - break; - } - 
if(edge.getNode2().equals(p1)&&(edge.getNode1().equals(p2))){ - found = true; - break; - } - } - if(!found) aux.addUndirectedEdge(p1, p2); - } - } - - } - } - } - - for(Edge e: aux.getEdges()){ - if(e.isDirected()){ - e.setEndpoint1(Endpoint.TAIL); - e.setEndpoint2(Endpoint.TAIL); - } - } - - aux.removeNodes(cond); - - open = new LinkedList(); - close = new HashMap(); - open.add(x); - while (open.size() != 0){ - Node a = open.getFirst(); - if(a.equals(y)) return false; - open.remove(a); - close.put(a.toString(),a); - List pa =aux.getAdjacentNodes(a); - for(Node p : pa){ - if(close.get(p.toString()) == null){ - if(!open.contains(p)) open.addLast(p); - } - } - } - - return true; + // 2. Apply Backward Equivalence Search with D-separation + BackwardEquivalenceSearchDSep bes = new BackwardEquivalenceSearchDSep(this.union, this.inputDags, this.transformedDags); + this.outputDag = bes.applyBackwardEliminationWithDSeparation(); + // 3. Updating numberOfInsertedEdges + this.numberOfInsertedEdges -= bes.getNumberOfRemovedEdges(); } - - - - private static List findNaYX(Node x, Node y, Graph graph) { - List naYX = new LinkedList(graph.getAdjacentNodes(y)); - naYX.retainAll(graph.getAdjacentNodes(x)); - - for (int i = naYX.size()-1; i >= 0; i--) { - Node z = naYX.get(i); - Edge edge = graph.getEdge(y, z); - - if (!Edges.isUndirectedEdge(edge)) { - naYX.remove(z); - } - } - - return naYX; - } - - public Dag getFusion(){ - + /** + * Returns the output DAG after applying the Consensus Union and Backward Equivalence Search with D-separation. + * This method retrieves the final fused DAG, which represents the optimal fusion of the input DAGs. + * @return the resulting output DAG after the fusion process. + */ + public Dag getFusionDag(){ return this.outputDag; } + /** + * Returns a valid ancestral order of the nodes in the fused DAG. + * @return a list of nodes representing an ancestral order of the resulting DAG. 
+ */ public List getOrderFusion(){ - return this.getFusion().paths().getValidOrder(this.getFusion().getNodes(),true); + return this.getFusionDag().paths().getValidOrder(this.getFusionDag().getNodes(),true); } - - public static void main(String args[]) { - - - System.out.println("Grafos de Partida: "); - - // (seed, n. variables, n egdes max, n.dags, mutation(n. de operaciones)) - RandomBN setOfBNs = new RandomBN(0, Integer.parseInt(args[0]), Integer.parseInt(args[1]), - Integer.parseInt(args[2]), Integer.parseInt(args[3])); - setOfBNs.setMaxInDegree(4); - setOfBNs.setMaxOutDegree(4); - setOfBNs.generate(); + /** + * Returns the number of edges inserted during the consensus union and removed in the Backward Equivalence Search with D-separation. + * @return the number of edges inserted during the consensus union and removed in the Backward Equivalence Search with D-separation. + */ + public int getNumberOfInsertedEdges(){ + return this.numberOfInsertedEdges; + } - for (int i = 0; i < setOfBNs.setOfRandomBNs.size(); i++) { - System.out.println("red de partida: " + i); - System.out.println("---------------------"); - System.out.println("Grafo: "); - System.out.println(setOfBNs.setOfRandomDags.get(i).toString()); -// System.out.println("Probabilidades: "); -// System.out.println(setOfBNs.setOfRandomBNs.get(i).toString()); -// System.out.println("_____________________"); -// System.out.println("Datos Simulados"); -// System.out.println(setOfBNs.setOfSampledBNs.get(i).toString()); + /** + * Returns the union DAG resulting from the consensus union process. + * @return the union DAG after merging the transformed input DAGs. 
+ */ + public Dag getUnion() { + return this.union; + } -// -// } -// // - ConsensusBES conDag= null; -// - conDag = new ConsensusBES(setOfBNs.setOfRandomDags); - conDag.fusion(); - Dag g = conDag.getFusion(); - System.out.println("grafo consenso: "+ g +" Complejidad de la Fusion: "+ conDag.getNumberOfInsertedEdges() - + " "+ conDag.union.getNumEdges()); - System.out.println("Orden Inicial Heu: "+conDag.alpha.toString()); - System.out.println("Orden de consenso: "+conDag.getOrderFusion().toString()); -// -//// HierarchicalAgglomerativeClustererBNs Cfusion = new HierarchicalAgglomerativeClustererBNs(setOfBNs.setOfRandomDags,0.50); -//// int l = Cfusion.cluster(); -//// System.out.println("Nivel de Fusion: "+l); -//// System.out.println(Cfusion.computeConsensusDag(l).toString()); -// } -// + /** + * Returns the ConsensusUnion instance used in this ConsensusBES. + * This instance contains the logic for merging the input DAGs and computing the alpha order. + * @return the ConsensusUnion instance associated with this ConsensusBES. + */ + public ConsensusUnion getConsensusUnion() { + return this.consensusUnion; + } + + /** + * Returns the list of transformed DAGs after applying the alpha order to the input DAGs. + * This method retrieves the transformed DAGs that were used in the consensus union process. + * @return the list of transformed DAGs. + */ + public ArrayList getTransformedDags() { + if (this.transformedDags != null) { + return this.transformedDags; + } else { + throw new IllegalStateException("Transformed DAGs have not been initialized. Please call fusion() first."); } } + /** + * Returns the list of input DAGs used in this ConsensusBES. + * This method retrieves the original DAGs that were provided to the ConsensusBES constructor. + * @return the list of input DAGs. 
+ */ + public ArrayList getInputDags() { + return this.inputDags; + } + + /** + * Runs the ConsensusBES algorithm in a thread, performing the consensus union and the Backward Equivalence Search with D-separation. + */ @Override public void run() { - + this.fusion(); } } diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java index 8c38fd9..81a3b7e 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/ConsensusUnion.java @@ -6,59 +6,119 @@ import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.Edge; import edu.cmu.tetrad.graph.Endpoint; -import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; - +/** + * This class implements the Consensus Union algorithm which applies a fusion between multiple Directed Acyclic Graphs (DAGs). + * It constructs a consensus DAG by merging the input DAGs based on a specified order of nodes (alpha). + * The alpha order is computed with the AlphaOrder class which implements a Greedy Heuristic Order (GHO) search, achieving a good order to transform the input DAGs. + * Once each DAG is transformed, the union method creates a new DAG that contains all the edges from the input DAGs, ensuring that the resulting graph is acyclic. + * The number of edges inserted during the union process can be retrieved using getNumberOfInsertedEdges. + * + * This class is also runnable, allowing it to be executed in a separate thread. + */ public class ConsensusUnion implements Runnable{ - ArrayList alpha = null; - Dag outputDag = null; - AlphaOrder heuristic = null; - TransformDags imaps2alpha = null; - ArrayList setOfdags = null; + /** + * The alpha order of nodes in the consensus DAG. + * This order is used to transform the input DAGs into a compatible I-Maps before merging. + * It is computed using the AlphaOrder class. 
+ * + * @see AlphaOrder + */ + private ArrayList alpha; + /** + * The AlphaOrder heuristic used to compute the alpha order. + */ + private AlphaOrder heuristic = null; + + /** + * The TransformDags instance that transforms the input DAGs based on the alpha order. + */ + private TransformDags imaps2alpha; + + /** + * List of input DAGs to be merged. + */ + private ArrayList setOfdags = null; + + /** + * The output DAG resulting from the union of the transformed input DAGs. + */ Dag union = null; + + /** + * Number of edges inserted during the consensus union process. + */ int numberOfInsertedEdges = 0; - + /** + * Constructor for ConsensusUnion that initializes the union process with a list of DAGs and an alpha order. + * @param dags the list of input DAGs to be merged. + * @param order the alpha order of nodes to be used for transforming the input DAGs. + */ public ConsensusUnion(ArrayList dags, ArrayList order){ this.setOfdags = dags; this.alpha = order; - } - - + /** + * Constructor for ConsensusUnion that initializes the union process with a list of DAGs and uses the AlphaOrder object to generate an alpha order. + * @see AlphaOrder + * @param dags the list of input DAGs to be merged. + */ public ConsensusUnion(ArrayList dags){ this.setOfdags = dags; this.heuristic = new AlphaOrder(this.setOfdags); - } + /** + * Default constructor for ConsensusUnion that initializes an empty union. + * This constructor can be used when the DAGs are set later using the setDags method. + */ public ConsensusUnion(){ this.setOfdags = null; } + /** + * Returns the number of edges inserted during the union process. + * This value is updated after the union method is called. + * @return the number of edges inserted in the consensus DAG. + */ public int getNumberOfInsertedEdges(){ return this.numberOfInsertedEdges; } + /** + * Performs the union of the input DAGs based on the alpha order. If no alpha order is set, it computes it first. 
+ * The method transforms each input DAG according to the alpha order and then merges them into a single consensus DAG. + * The resulting DAG contains all edges from the transformed input DAGs, ensuring that it remains acyclic. + * + * @throws IllegalStateException if the alpha order is not set before calling this method. + * @throws IllegalArgumentException if the input DAGs are null or empty. + * @throws NullPointerException if the alpha order is null. + * @return the resulting consensus DAG after merging the transformed input DAGs. + * @see AlphaOrder + * @see TransformDags + */ public Dag union(){ + // Computing Alpha Order if not set, using the Greedy Heuristic Order (GHO) if(this.alpha == null){ - - this.heuristic.computeAlphaH2(); - this.alpha = this.heuristic.alpha; + this.heuristic.computeAlpha(); + this.alpha = this.heuristic.getOrder(); } + // Transforming each DAG with the alpha order this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha); this.imaps2alpha.transform(); this.numberOfInsertedEdges = this.imaps2alpha.getNumberOfInsertedEdges(); + // Applying a union of the edges of the transformed DAGs this.union = new Dag(this.alpha); for(Node nodei: this.alpha){ - for(Dag d : this.imaps2alpha.setOfOutputDags){ + for(Dag d : this.imaps2alpha.getSetOfOutputDags()){ Listparent = d.getParents(nodei); for(Node pa: parent){ if(!this.union.isParentOf(pa, nodei)) this.union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW)); @@ -70,47 +130,50 @@ public Dag union(){ } + /** + * Returns the resulting consensus DAG after the union process. + * This method should be called after the union method to ensure that the union has been performed. + * @return the consensus DAG resulting from the union of the input DAGs. + */ public Dag getUnion(){ return this.union; } + /** + * sets the list of input DAGs for the ConsensusUnion instance and applies the AlphaOrder heuristic to compute the alpha order. 
+ * This method also updates the alpha order and transforms the input DAGs accordingly. + * @param dags + */ void setDags(ArrayList dags){ this.setOfdags = dags; this.heuristic = new AlphaOrder(this.setOfdags); - this.heuristic.computeAlphaH2(); - this.alpha = this.heuristic.alpha; + this.heuristic.computeAlpha(); + this.alpha = this.heuristic.getOrder(); this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha); this.imaps2alpha.transform(); } - - - - public static void main(String args[]) { - - - System.out.println("Grafos de Partida: "); - - // (seed, n. variables, n egdes aprox, n. dags, mutation) - RandomBN setOfDags = new RandomBN(0, Integer.parseInt(args[0]), Integer.parseInt(args[1]), - Integer.parseInt(args[2]),Integer.parseInt(args[3])); - setOfDags.generate(); -// - for( Dag g: setOfDags.setOfRandomDags) System.out.print(g); - ConsensusUnion conDag= new ConsensusUnion(); - conDag.setDags(setOfDags.setOfRandomDags); - Graph g = conDag.union(); - System.out.println("grafo consenso: "+ g); - - } - - + /** + * Runs the ConsensusUnion process in a separate thread. + */ @Override public void run() { this.union = this.union(); } + + /** + * Returns the list of transformed DAGs after applying the alpha order to the input DAGs with TransformDags. + * @return the list of transformed DAGs. + */ + public ArrayList getTransformedDags() { + if (this.imaps2alpha != null) { + return this.imaps2alpha.getSetOfOutputDags(); + } else { + throw new IllegalStateException("TransformDags has not been initialized. 
Please call union() first."); + } + } diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/DSeparationKey.java b/src/main/java/es/uclm/i3a/simd/consensusBN/DSeparationKey.java new file mode 100644 index 0000000..3c6d186 --- /dev/null +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/DSeparationKey.java @@ -0,0 +1,113 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +import edu.cmu.tetrad.graph.Node; + +/** + * This class represents a key for D-separation checks in a Bayesian network. + * It encapsulates two nodes (x and y) and a set of conditioning nodes. + * The key is used to efficiently check if two nodes are d-separated given a conditioning set. + * The equality and hashCode methods ensure that keys with the same nodes and conditioning set are treated as equal. + */ +public class DSeparationKey { + /** + * The node x in the D-separation key. + * This node is one of the two nodes being checked for d-separation. + */ + private final Node x; + + /** + * The node y in the D-separation key. + * This node is the other node being checked for d-separation. + */ + private final Node y; + + /** + * The set of conditioning nodes in the D-separation key. + * This set contains nodes that are conditioned on when checking for d-separation between x and y. + * It is stored as a defensive copy to ensure immutability. + */ + private final Set conditioningSet; + + /** + * Constructor for DSeparationKey that initializes the key with two nodes and a set of conditioning nodes. + * The nodes x and y are stored in a consistent order to ensure that the key is symmetric. + * The conditioning set is stored as a defensive copy to prevent external modifications. + * @param x the first node in the D-separation key. + * @param y the second node in the D-separation key. + * @param conditioningSet the set of conditioning nodes in the D-separation key. 
+ */ + public DSeparationKey(Node x, Node y, Set conditioningSet) { + // Since D-separation is symmetric, we ensure a consistent order for x and y + if (x.getName().compareTo(y.getName()) <= 0) { + this.x = x; + this.y = y; + } else { + this.x = y; + this.y = x; + } + this.conditioningSet = new HashSet<>(conditioningSet); // copia defensiva + } + + /** + * Checks if this D-separation key is equal to another object. + * Two keys are considered equal if they have the same nodes (x and y) and + * the same set of conditioning nodes. + * The equality is symmetric, meaning the order of x and y does not matter. + * @param obj the object to compare with this D-separation key. + * @return true if the other object is a DSeparationKey with the same nodes and conditioning set, false otherwise. + */ + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (!(obj instanceof DSeparationKey)) return false; + + DSeparationKey other = (DSeparationKey) obj; + return y.equals(other.y) + && x.equals(other.x) + && conditioningSet.equals(other.conditioningSet); + } + + /** + * Returns the hash code for this D-separation key. + * The hash code is computed based on the nodes x and y, and the conditioning set. + * This ensures that two keys that are equal will have the same hash code. + * @return the hash code for this D-separation key. + */ + @Override + public int hashCode() { + return Objects.hash(y, x, conditioningSet); + } + + /** + * Returns the node y in the D-separation key. + * @return the node y in the D-separation key. + */ + public Node getY() { + return this.y; + } + + /** + * Returns the node x in the D-separation key. + * @return the node x in the D-separation key. + */ + public Node getX() { + return this.x; + } + + /** + * Returns the set of conditioning nodes in the D-separation key. + * This set is unmodifiable to prevent external modifications. + * @return the set of conditioning nodes in the D-separation key. 
+ */ + public Set getConditioningSet() { + return Collections.unmodifiableSet(this.conditioningSet); + } + + +} + diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java index adddc2a..cd0d4d0 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusBES.java @@ -1,435 +1,59 @@ package es.uclm.i3a.simd.consensusBN; import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; import edu.cmu.tetrad.graph.Dag; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.EdgeListGraph; -import edu.cmu.tetrad.graph.Edges; -import edu.cmu.tetrad.graph.Endpoint; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.search.utils.MeekRules; -import edu.cmu.tetrad.search.utils.GraphSearchUtils; -import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag; - - - -public class HeuristicConsensusBES { - - ArrayList alpha = null; - Dag outputDag = null; - AlphaOrder heuristic = null; - TransformDags imaps2alpha = null; - ArrayList setOfdags = null; - ArrayList setOfOutDags = null; - Dag union = null; - int numberOfInsertedEdges = 0; - double percentage = 1.0; - int maxSize = 10; - - Map localScore = new HashMap(); - - public HeuristicConsensusBES(ArrayList dags, double percentage){ - this.setOfdags = dags; - this.heuristic = new AlphaOrder(this.setOfdags); - this.heuristic.computeAlphaH2(); - this.alpha = this.heuristic.alpha; - this.imaps2alpha = new TransformDags(this.setOfdags,this.alpha); - this.imaps2alpha.transform(); - this.numberOfInsertedEdges = imaps2alpha.getNumberOfInsertedEdges(); - this.setOfOutDags = imaps2alpha.setOfOutputDags; - this.percentage = percentage; - } - - - public int getNumberOfInsertedEdges(){ - return 
this.numberOfInsertedEdges; - } - - private void consensusUnion(){ - - this.union = new Dag(this.alpha); - for(Node nodei: this.alpha){ - for(Dag d : this.imaps2alpha.setOfOutputDags){ - Listparent = d.getParents(nodei); - for(Node pa: parent){ - if(!this.union.isParentOf(pa, nodei)) this.union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW)); - } - } - - } - - } - - // private methods for searching - - - - public void fusion(){ - - // System.out.println("\n** BACKWARD ELIMINATION SEARCH (BES)"); - //PowerSetFabric.setMode(PowerSetFabric.MODE_BES); - double score = 0; - double bestScore = score; - Graph graph = null; - - consensusUnion(); - graph = new EdgeListGraph(new LinkedList(this.union.getNodes())); - for(Edge e: this.union.getEdges()){ - graph.addEdge(e); - } - - //SearchGraphUtils.dagToPdag(graph); - rebuildPattern(graph); - Node x, y; - Set t = new HashSet(); - do { - x = y = null; - Set edges1 = graph.getEdges(); - List edges = new ArrayList(); - - for (Edge edge : edges1) { - Node _x = edge.getNode1(); - Node _y = edge.getNode2(); - - if (Edges.isUndirectedEdge(edge)) { - edges.add(Edges.directedEdge(_x, _y)); - edges.add(Edges.directedEdge(_y, _x)); - } else { - edges.add(edge); - } - } - for (Edge edge : edges) { - Node _x = Edges.getDirectedEdgeTail(edge); - Node _y = Edges.getDirectedEdgeHead(edge); - - List hNeighbors = getHNeighbors(_x, _y, graph); -// List> hSubsets = powerSet(hNeighbors); - PowerSet hSubsets= PowerSetFabric.getPowerSet(_x,_y,hNeighbors); - while(hSubsets.hasMoreElements()) { - SubSet hSubset=hSubsets.nextElement(); - if(hSubset.size() > maxSize) break; - double deleteEval = deleteEval(_x, _y, hSubset, graph); - if (!(deleteEval >= this.percentage)) deleteEval = 0.0; - double evalScore = score + deleteEval; - - // System.out.println("Attempt removing " + _x + "-->" + _y + "(" +evalScore + ") "+ hSubset.toString()); - - if (!(evalScore > bestScore)) { - continue; - } - - // INICIO TEST 1 - List naYXH = findNaYX(_x, _y, 
graph); - naYXH.removeAll(hSubset); - if (!isClique(naYXH, graph)) { -// hSubsets.firstTest(true); // Si pasa para H entonces pasa para cualquier H' | H' contiene H - continue; - } - // FIN TEST 1 - - bestScore = evalScore; - x = _x; - y = _y; - t = hSubset; - break; - } - - } - if (x != null) { - - //System.out.println("DELETE " + graph.getEdge(x, y) + t.toString() + " (" +bestScore + ")"); - - delete(x, y, t, graph); - rebuildPattern(graph); - this.numberOfInsertedEdges--; -// if(graph.existsDirectedCycle()){ - -// System.out.println("Hay un ciclo: "+x.toString()+" "+y.toString()); -// System.out.println("Grafo: "+graph.toString()); -// System.exit(0); -// } - score = bestScore; - } - } while (x != null); - -// System.out.println("Pdag: "+ graph.toString()); - pdagToDag(graph); -// System.out.println("PdagToDag"+graph.toString()); - this.outputDag = new Dag(); - for (Node node : graph.getNodes()) this.outputDag.addNode(node); - Node nodeT, nodeH; - for (Edge e : graph.getEdges()){ - if(!e.isDirected()) continue; - Endpoint endpoint1 = e.getEndpoint1(); - if (endpoint1.equals(Endpoint.ARROW)){ - nodeT = e.getNode1(); - nodeH = e.getNode2(); - }else{ - nodeT = e.getNode2(); - nodeH = e.getNode1(); - } - if(!this.outputDag.paths().existsDirectedPath(nodeT, nodeH)) this.outputDag.addEdge(e); - } -// System.out.println("DAG: "+this.outputDag.toString()); - } - - - - private static void delete(Node x, Node y, Set subset, Graph graph) { - graph.removeEdges(x, y); - - for (Node aSubset : subset) { - if (!graph.isParentOf(aSubset, x) && !graph.isParentOf(x, aSubset)) { - graph.removeEdge(x, aSubset); - graph.addDirectedEdge(x, aSubset); - } - graph.removeEdge(y, aSubset); - graph.addDirectedEdge(y, aSubset); - } - } - - - private void rebuildPattern(Graph graph) { - GraphSearchUtils.basicCpdag(graph); - pdag(graph); - } - - /** - * Fully direct a graph with background knowledge. 
I am not sure how to - * adapt Chickering's suggested algorithm above (dagToPdag) to incorporate - * background knowledge, so I am also implementing this algorithm based on - * Meek's 1995 UAI paper. Notice it is the same implemented in PcSearch. - *

*IMPORTANT!* *It assumes all colliders are oriented, as well as - * arrows dictated by time order.* - * - * ELIMINADO BACKGROUND KNOWLEDGE - */ - private void pdag(Graph graph) { - MeekRules rules = new MeekRules(); - rules.setMeekPreventCycles(true); - rules.orientImplied(graph); - } - - - private static boolean isClique(List set, Graph graph) { - List setv = new LinkedList(set); - for (int i = 0; i < setv.size() - 1; i++) { - for (int j = i + 1; j < setv.size(); j++) { - if (!graph.isAdjacentTo(setv.get(i), setv.get(j))) { - return false; - } - } - } - return true; - } - - private static List getHNeighbors(Node x, Node y, Graph graph) { - List hNeighbors = new LinkedList(graph.getAdjacentNodes(y)); - hNeighbors.retainAll(graph.getAdjacentNodes(x)); - - for (int i = hNeighbors.size() - 1; i >= 0; i--) { - Node z = hNeighbors.get(i); - Edge edge = graph.getEdge(y, z); - if (!Edges.isUndirectedEdge(edge)) { - hNeighbors.remove(z); - } - } - - return hNeighbors; - } - - - double deleteEval(Node x, Node y, SubSet h, Graph graph){ - - Set set1 = new HashSet(findNaYX(x, y, graph)); - set1.removeAll(h); - set1.addAll(graph.getParents(y)); - set1.remove(x); - return scoreGraphChangeDelete(y, x, set1); // calcular si y esta d-separado de x dado el set1 en cada grafo. 
- - } - - double scoreGraphChangeDelete(Node y, Node x, Set set){ - - String key = y.getName()+x.getName()+set.toString(); - Double val = this.localScore.get(key); - if(val == null){ - double eval = 0.0; - LinkedList conditioning = new LinkedList(); - conditioning.addAll(set); - for(Dag g: this.setOfdags){ - if(dSeparated(g, y, x, conditioning)) ++eval; - } - eval = eval / (double) this.setOfdags.size(); - val = eval; - this.localScore.put(key, val); - return eval; - }else{ - return val.doubleValue(); - } - } - - - - boolean dSeparated(Dag g, Node x, Node y, LinkedList cond){ - - LinkedList open = new LinkedList(); - HashMap close = new HashMap(); - open.add(x); - open.add(y); - open.addAll(cond); - while (open.size() != 0){ - Node a = open.getFirst(); - open.remove(a); - close.put(a.toString(),a); - List pa =g.getParents(a); - for(Node p : pa){ - if(close.get(p.toString()) == null){ - if(!open.contains(p)) open.addLast(p); - } - } - } - - Graph aux = new EdgeListGraph(); - - for (Node node : g.getNodes()) aux.addNode(node); - Node nodeT, nodeH; - for (Edge e : g.getEdges()){ - if(!e.isDirected()) continue; - nodeT = e.getNode1(); - nodeH = e.getNode2(); - if((close.get(nodeH.toString())!=null)&&(close.get(nodeT.toString())!=null)){ - Edge newEdge = new Edge(e.getNode1(),e.getNode2(),e.getEndpoint1(),e.getEndpoint2()); - aux.addEdge(newEdge); - } - } - - close = new HashMap(); - for(Edge e: aux.getEdges()){ - if(e.isDirected()){ - Node h; - if(e.getEndpoint1()==Endpoint.ARROW){ - h = e.getNode1(); - }else h = e.getNode2(); - if(close.get(h.toString())==null){ - close.put(h.toString(),h); - List pa = aux.getParents(h); - if(pa.size()>1){ - for(int i = 0 ; i< pa.size() - 1; i++) - for(int j = i+1; j < pa.size(); j++){ - Node p1 = pa.get(i); - Node p2 = pa.get(j); - boolean found = false; - for(Edge edge : aux.getEdges()){ - if(edge.getNode1().equals(p1)&&(edge.getNode2().equals(p2))){ - found = true; - break; - } - 
if(edge.getNode2().equals(p1)&&(edge.getNode1().equals(p2))){ - found = true; - break; - } - } - if(!found) aux.addUndirectedEdge(p1, p2); - } - } - - } - } - } - - for(Edge e: aux.getEdges()){ - if(e.isDirected()){ - e.setEndpoint1(Endpoint.TAIL); - e.setEndpoint2(Endpoint.TAIL); - } - } - - aux.removeNodes(cond); - - open = new LinkedList(); - close = new HashMap(); - open.add(x); - while (open.size() != 0){ - Node a = open.getFirst(); - if(a.equals(y)) return false; - open.remove(a); - close.put(a.toString(),a); - List pa =aux.getAdjacentNodes(a); - for(Node p : pa){ - if(close.get(p.toString()) == null){ - if(!open.contains(p)) open.addLast(p); - } - } - } - - return true; - } - - - private static List findNaYX(Node x, Node y, Graph graph) { - List naYX = new LinkedList(graph.getAdjacentNodes(y)); - naYX.retainAll(graph.getAdjacentNodes(x)); - - for (int i = naYX.size()-1; i >= 0; i--) { - Node z = naYX.get(i); - Edge edge = graph.getEdge(y, z); - - if (!Edges.isUndirectedEdge(edge)) { - naYX.remove(z); - } - } - - return naYX; - } - - public Dag getFusion(){ - - return this.outputDag; +/** + * The {@code HeuristicConsensusBES} class extends the {@code ConsensusBES} class + * to implement a heuristic approach for backward equivalence search in directed acyclic graphs (DAGs). + *

+ * This class is designed to perform a consensus structure learning algorithm + * by applying a heuristic method for backward equivalence search (BES) with D-separation. + * It allows for a more efficient search by limiting the size of the conditioning set and applying a + * percentage threshold for determining d-separation between nodes. + *

+ * The constructor initializes the class with a list of input DAGs, a maximum size for the + * conditioning set, and a percentage threshold for d-separation. + * The `fusion` method applies the consensus union to compute a consensus DAG from the input DAGs, + * and then applies the backward equivalence search with D-separation with the specified parameters to refine the graph. + */ +public class HeuristicConsensusBES extends ConsensusBES{ + + private final int maxSize; + private final double percentage; + + /** + * Constructor for HeuristicConsensusBES. + * This class extends ConsensusBES to implement a heuristic approach + * for backward equivalence search in directed acyclic graphs (DAGs). + * + * @param dags the list of input DAGs to be fused. + * @param maxSize the maximum size of the conditioning set for d-separation checks. + * @param percentage the percentage/threshold for determining d-separation between nodes. + */ + public HeuristicConsensusBES(ArrayList dags, int maxSize, double percentage) { + super(dags); + this.maxSize = maxSize; + this.percentage = percentage; } - - - public static void main(String args[]) { - - - System.out.println("Grafos de Partida: "); - - // (seed, n. 
variables, n egdes aprox, n.dags, mutation) - RandomBN setOfBNs = new RandomBN(0, Integer.parseInt(args[0]), Integer.parseInt(args[1]), - Integer.parseInt(args[2]),Integer.parseInt(args[3])); - setOfBNs.setMaxInDegree(3); - setOfBNs.setMaxOutDegree(3); - setOfBNs.generate(); - - for(int i = 0; i< setOfBNs.setOfRandomBNs.size(); i++){ - System.out.println("red de partida: "+i); - System.out.println("---------------------"); - System.out.println("Grafo: "); - System.out.println(setOfBNs.setOfRandomDags.get(i).getDegree()+" "+ setOfBNs.setOfRandomDags.get(i).getNumEdges()); -// System.out.println("Probabilidades: "); -// System.out.println(setOfBNs.setOfRandomBNs.get(i).toString()); -// System.out.println("_____________________"); -// System.out.println("Datos Simulados"); -// System.out.println(setOfBNs.setOfSampledBNs.get(i).toString()); - - - } - // - HeuristicConsensusBES conDag= null; - conDag = new HeuristicConsensusBES(setOfBNs.setOfRandomDags,1.0); - conDag.fusion(); - Dag g = conDag.getFusion(); - System.out.println("grafo de partida Union: "+conDag.union.getDegree()+" "+ conDag.union.getNumEdges()); - System.out.println("grafo consenso: "+ g.getDegree() +" Complejidad de la Fusion: "+ conDag.getNumberOfInsertedEdges()+ " "+ conDag.outputDag.getNumEdges()); + /** + * Executes the heuristic consensus backward equivalence search. + * This method first applies the ConsensusUnion to compute a consensus DAG from the input DAGs, + * and then applies the Backward Equivalence Search with D-separation to refine the graph, setting the maxSize and percentage parameters for an heuristic search. + * The resulting output DAG is stored in the outputDag attribute. + */ + @Override + public void fusion(){ + // 1. Apply ConsensusUnion + consensusUnion(); + // 2. 
Apply Heuristic BES with D-separation + BackwardEquivalenceSearchDSep bes = new BackwardEquivalenceSearchDSep(this.getUnion(), this.getInputDags(), this.getTransformedDags()); + bes.setMaxSize(maxSize); + bes.setPercentage(percentage); + this.outputDag = bes.applyBackwardEliminationWithDSeparation(); + // 3. Updating numberOfInsertedEdges + this.numberOfInsertedEdges -= bes.getNumberOfRemovedEdges(); } } diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVoting.java b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVoting.java index 1abe261..f09ec62 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVoting.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVoting.java @@ -6,48 +6,142 @@ import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.Edge; import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Edges; import edu.cmu.tetrad.graph.Endpoint; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; import static es.uclm.i3a.simd.consensusBN.Utils.pdagToDag; +/** + * The {@code HeuristicConsensusMVoting} class implements a consensus structure learning algorithm + * based on a heuristic majority voting strategy. + *

+ * Given a collection of Directed Acyclic Graphs (DAGs), this class aggregates the structures + * by counting the frequency of directed edges across all graphs and selecting those + * that meet or exceed a specified threshold. + *

+ * The heuristic helps reduce noise by filtering out weakly supported edges, and it resolves + * conflicts (such as bidirectional edges) by applying majority rules or tie-breaking strategies. + *

+ * Typical use cases include combining results from different structure learning algorithms, + * bootstrapping runs, or expert-curated networks. + *

+ * The resulting output is a new DAG that aims to reflect the most consistent edges found + * across the input graphs. + * + *

Note: Input DAGs must share the same set of nodes for the consensus to be meaningful. + * + * Example usage: + *

{@code
+ * List<Dag> inputGraphs = Arrays.asList(dag1, dag2, dag3);
+ * HeuristicConsensusMVoting consensus = new HeuristicConsensusMVoting(inputGraphs, threshold);
+ * Dag consensusDag = consensus.fusion();
+ * }
+ * + */ public class HeuristicConsensusMVoting { - ArrayList variables = null; - Dag outputDag = null; - ArrayList setOfdags = null; - double percentage = 1.0; - double [][] weight = null; + /** + * List of variables (nodes) in the consensus DAG. + * This list is derived from the nodes of the input DAGs. + */ + private ArrayList variables = null; + + /** + * The resulting output DAG after applying the heuristic consensus voting. + * This DAG contains the edges that were selected based on the majority voting strategy. + */ + private Dag outputDag = null; + + /** + * List of input DAGs used to compute the consensus. + * These DAGs are expected to have the same set of nodes. + */ + private ArrayList setOfdags = null; + + /** + * Percentage threshold for edge inclusion in the consensus DAG. + * An edge is included if it appears in at least this percentage of the input DAGs. + */ + private double percentage = 1.0; + + /** + * Weight matrix representing the frequency of edges between pairs of nodes. + * The weight[i][j] indicates how many times the edge from node i to node j appears in the input DAGs. + */ + private double [][] weight = null; + /** + * Constructor for HeuristicConsensusMVoting. + * Initializes the variables, output DAG, input DAGs, and weight matrix. + * @param setOfdags the list of input DAGs to be fused. + * @param percentage the percentage threshold for edge inclusion in the consensus DAG. + */ + public HeuristicConsensusMVoting(ArrayList setOfdags, double percentage) { + this.variables = (ArrayList) setOfdags.get(0).getNodes(); + this.outputDag = null; + this.setOfdags = setOfdags; + this.percentage = percentage; + this.weight = new double[this.variables.size()][this.variables.size()]; + setup(); + } + /** + * Sets up the HeuristicConsensusMVoting instance by validating the input DAGs + * and building the weight matrix based on the edges present in the input DAGs. 
+ */ + private void setup() { + // Ensuring that the input DAGs are not null and have the same set of nodes + validateInputDags(); + buildWeightMatrix(); + } + + /** + * Validates the input DAGs to ensure they are not null and have the same set of nodes. + * Throws an IllegalArgumentException if any validation fails. + */ + private void validateInputDags() { + if (this.setOfdags == null || this.setOfdags.isEmpty()) + throw new IllegalArgumentException("Input DAGs cannot be null or empty."); + for(Dag dag : this.setOfdags) { + if (dag.getNodes().size() != this.variables.size()) { + throw new IllegalArgumentException("All input DAGs must have the same number of nodes."); + } + if (!dag.getNodes().containsAll(this.variables)) { + throw new IllegalArgumentException("All input DAGs must have the same set of nodes."); + } + } + } -public HeuristicConsensusMVoting(ArrayList setOfdags, double percentage) { - super(); - this.variables = (ArrayList) setOfdags.get(0).getNodes(); - this.outputDag = null; - this.setOfdags = setOfdags; - this.percentage = percentage; - this.weight = new double[this.variables.size()][this.variables.size()]; - ArrayList pdags = new ArrayList(); + /** + * Builds the weight matrix based on the edges present in the input DAGs. 
+ * Each entry weight[i][j] is incremented for each directed edge from node i to node j + * across all input DAGs, normalized by the number of input DAGs + */ + private void buildWeightMatrix() { + ArrayList pdags = new ArrayList<>(); for(Dag g: this.setOfdags){ - Graph pd = new EdgeListGraph(new LinkedList(g.getNodes())); + Graph graph = new EdgeListGraph(new LinkedList<>(g.getNodes())); for(Edge e: g.getEdges()){ - pd.addEdge(e); + graph.addEdge(e); } - pdagToDag(pd); - pdags.add(pd); + pdagToDag(graph); + pdags.add(graph); } - + for(Graph pd: pdags){ for(Edge e: pd.getEdges()){ - Node n1 = e.getNode1(); - Node n2 = e.getNode2(); if(e.isDirected()){ - if(e.getEndpoint1() == Endpoint.ARROW){ - this.weight[variables.indexOf(n2)][variables.indexOf(n1)]+= (double) (1.0/this.setOfdags.size()); - }else{ - this.weight[variables.indexOf(n1)][variables.indexOf(n2)]+= (double) (1.0/this.setOfdags.size()); - } + Node from = Edges.getDirectedEdgeTail(e); + Node to = Edges.getDirectedEdgeHead(e); + this.weight[variables.indexOf(from)][variables.indexOf(to)] += (double) (1.0/this.setOfdags.size()); + // if(e.getEndpoint1() == Endpoint.ARROW){ + // this.weight[variables.indexOf(n2)][variables.indexOf(n1)]+= (double) (1.0/this.setOfdags.size()); + // }else{ + // this.weight[variables.indexOf(n1)][variables.indexOf(n2)]+= (double) (1.0/this.setOfdags.size()); + // } }else{ + Node n1 = e.getNode1(); + Node n2 = e.getNode2(); this.weight[variables.indexOf(n2)][variables.indexOf(n1)]+= (double) (1.0/this.setOfdags.size()); this.weight[variables.indexOf(n1)][variables.indexOf(n2)]+= (double) (1.0/this.setOfdags.size()); } @@ -55,36 +149,88 @@ public HeuristicConsensusMVoting(ArrayList setOfdags, double percentage) { } } -public Dag fusion(){ - - this.outputDag = new Dag(variables); - boolean procced = true; - int bestEdgei = 0; - int bestEdgej = 0; - double maxW = 0.0; - while(procced){ - for(int i = 0; i= maxW){ - if((this.weight[i][j] > maxW) || ((this.weight[i][j]==maxW) && 
(Math.random()>0.5))){ - bestEdgei = i; - bestEdgej = j; - maxW = this.weight[i][j]; + /** + * Performs the fusion of the input DAGs using a heuristic majority voting strategy. + * The method iteratively selects edges based on their weights and adds them to the output DAG + * until no more edges meet the specified percentage threshold. + * @return The resulting consensus DAG after applying the heuristic voting. + */ + public Dag fusion(){ + this.outputDag = new Dag(variables); + + while(true){ + int bestEdgei = -1; // Best edge head node index + int bestEdgej = -1; // Best edge tail node index + double maxW = 0.0; // Maximum weight found in the current iteration + + // Find the best edge based on the weight matrix + for(int i = 0; i= maxW){ + if((this.weight[i][j] > maxW) || ((this.weight[i][j]==maxW) && (Math.random()>0.5))){ + bestEdgei = i; + bestEdgej = j; + maxW = this.weight[i][j]; + } } } - } - if(maxW >= this.percentage){ - if(!this.outputDag.paths().existsDirectedPath(variables.get(bestEdgej), variables.get(bestEdgei))){ - this.outputDag.addEdge(new Edge(variables.get(bestEdgei),variables.get(bestEdgej),Endpoint.TAIL,Endpoint.ARROW)); - this.weight[bestEdgei][bestEdgej] = 0; - }else this.weight[bestEdgei][bestEdgej] = 0; - if(maxW==0) procced = false; - maxW = 0.0; - }else procced = false; + // Stop if no edge meets the threshold + if(bestEdgei == -1 || bestEdgej == -1 || maxW < percentage || maxW == 0.0) + break; + + // Add edge if it doesn't introduce a cycle + Node from = variables.get(bestEdgei); + Node to = variables.get(bestEdgej); + if(!this.outputDag.paths().existsDirectedPath(to, from)){ + this.outputDag.addEdge(new Edge(from,to,Endpoint.TAIL,Endpoint.ARROW)); + } + // Mark the edge as used by setting its weight to zero + this.weight[bestEdgei][bestEdgej] = 0; + } + + return this.outputDag; + } + + /** + * Returns the nodes (variables) of the consensus DAG. + * @return A list of nodes representing the variables in the consensus DAG. 
+ */ + public ArrayList getVariables() { + return variables; + } + + /** + * Returns the resulting consensus DAG after applying the heuristic voting. + * @return The output DAG containing the selected edges based on the majority voting strategy. + */ + public Dag getOutputDag() { + return outputDag; + } + + /** + * Returns the list of input DAGs used to compute the consensus. + * @return A list of DAGs that were used as input for the consensus voting. + */ + public ArrayList getSetOfdags() { + return setOfdags; + } + + /** + * Returns the percentage threshold used for edge inclusion in the consensus DAG. + * @return The percentage threshold for edge inclusion. + */ + public double getPercentage() { + return percentage; + } + /** + * Returns the weight matrix representing the frequency of edges between pairs of nodes. + * Each entry weight[i][j] indicates the weight that the edge from node i to node j has in the input DAGs. + * @return The weight matrix used in the consensus voting process. 
+ * @see #fusion() + */ + public double[][] getWeight() { + return weight; } - - return this.outputDag; -} } diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java b/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java index 1965f28..e55c0d9 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNs.java @@ -106,7 +106,7 @@ public int cluster() { clustersIndexes[i][i][0] = true; } - dissimilarityMatrix = computeDissimilarityMatrix(); + computeDissimilarityMatrix(); for (int a = 1; a 0 && (this.maxSize >= (this.clusterCardinalities[o1] + this.clusterCardinalities[o2]))|| level == 0){ PairWiseConsensusBES pairBNs = new PairWiseConsensusBES(this.clustersBN[o1][level],this.clustersBN[o2][level]); PairWiseConsensusBES pairDag= (PairWiseConsensusBES) pairBNs; - pairDag.getFusion(); + pairDag.fusion(); return pairBNs; }else if(this.maxSize == 0){ PairWiseConsensusBES pairBNs = new PairWiseConsensusBES(this.clustersBN[o1][level],this.clustersBN[o2][level]); PairWiseConsensusBES pairDag= (PairWiseConsensusBES) pairBNs; - pairDag.getFusion(); + pairDag.fusion(); if((pairDag.getDagFusion().getNumEdges())/this.averageNEdges <= this.maxComplexityCluster|| level == 0) return pairBNs; else return null; @@ -326,7 +325,7 @@ private Pair findMostSimilarClusters() { PairWiseConsensusBES inCluster = dissimilarityMatrix[cluster][neighbor]; if(inCluster!= null){ double complexity = 0.0; - complexity = (float) inCluster.getHammingDistance();//getNumberOfInsertedEdges(); + complexity = (float) inCluster.calculateHammingDistance();//getNumberOfInsertedEdges(); if (indexUsed[neighbor]&&complexity hashMap=new HashMap(); -// -// public static int[] getList(int size) { -// Integer key=size; -// int[] lista=hashMap.get(key); -// if(lista==null) { -// lista=generateList(size); -// 
hashMap.put(key, lista); -// } -// return lista; -// } - - public static int[] getList(int size) { - return generateList(size); - } - - private static int[] generateList(int size) { + /** + * MAX_SIZE is the maximum number of elements that can be included in any subset. + * It is set to Integer.MAX_VALUE by default, meaning no limit unless specified. + */ + public static int MAX_SIZE=Integer.MAX_VALUE; + + /** + * Generates a list of integers representing all subsets of a set of a given size. + * Each integer is a bitmask where the i-th bit represents the inclusion of the i-th element. + * @param size the size of the set for which subsets are generated + * @return an array of integers representing the subsets + */ + public static int[] generateList(int size) { int[] lista; if(size==0) { return new int[1]; } + // Generation of powers of numbers that are powers of 2 int[] pow2=new int[size]; pow2[0]=1; for(int i=1;iaux[0][index]) { aux[0][counter]=aux[0][index]+pow2[i]; @@ -50,11 +55,4 @@ private static int[] generateList(int size) { return lista; } - public static void setMaxSize(int maxParents) { - ListFabric.maxSize = maxParents; - } - - public static int getMaxSize() { - return ListFabric.maxSize; - } } diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java b/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java index 3b053bb..f1669bb 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBES.java @@ -2,54 +2,138 @@ import java.util.ArrayList; - import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.Edge; import edu.cmu.tetrad.graph.Node; - +/** + * This class implements the PairWiseConsensusBES algorithm, which measures the similarity between two DAGs by fusing them with the ConsensusBES algorithm. + * It first fuses the two input DAGs into a consensus DAG using the ConsensusBES class. 
+ * After obtaining the consensus DAG, it calculates the Hamming distance between the fused DAG and the original input DAGs. + * The resulting output DAG is stored in the consensusDAG attribute, and the number of inserted edges during the fusion process can be retrieved using getNumberOfInsertedEdges. + */ public class PairWiseConsensusBES implements Runnable{ + /** + * The first input DAG to be fused. + */ + private Dag firstDag = null; + + /** + * The second input DAG to be fused. + */ + private Dag secondDag = null; + + /** + * The resulting consensus DAG after applying the fusion process. + */ + private Dag consensusDAG = null; + + /** + * Instance of ConsensusBES used to compute the consensus DAG from the input DAGs. + */ + private ConsensusBES consensusBES= null; - private Dag b1 = null; - private Dag b2 = null; - private Dag conDAG = null; - private ConsensusBES conBES= null; + /** + * Number of total edges inserted during the fusion process. + */ private int numberOfInsertedEdges = 0; + + /** + * Number of edges inserted during the consensus union process. + */ private int numberOfUnionEdges = 0; - public PairWiseConsensusBES(Dag b1, Dag b2) { - super(); - this.b1 = b1; - this.b2 = b2; + /** + * Constructor for the PairWiseConsensusBES class. + * It initializes the instance with two input DAGs and checks if they are valid. + * If the input DAGs are not valid, it throws an IllegalArgumentException. + * @param firstDag the first input DAG for fusion similarity. + * @param secondDag the second input DAG for fusion similarity. 
+ */ + public PairWiseConsensusBES(Dag firstDag, Dag secondDag) { + checkInput(firstDag, secondDag); + this.firstDag = firstDag; + this.secondDag = secondDag; } - - public void getFusion(){ - ArrayList setOfDags = new ArrayList(); - setOfDags.add(this.b1); - setOfDags.add(this.b2); - conBES = new ConsensusBES(setOfDags); - conBES.fusion(); - this.numberOfInsertedEdges = conBES.getNumberOfInsertedEdges(); - this.numberOfUnionEdges = conBES.union.getNumEdges(); - this.conDAG = conBES.getFusion(); + /** + * Checks if the input DAGs are valid. + * Validity is determined by ensuring that the DAGs are not null, contain at least one node and one edge, and have the same set of nodes. + * If any of these conditions are not met, an IllegalArgumentException is thrown. + * @param firstDag first input DAG for fusion similarity + * @param secondDag second input DAG for fusion similarity + * @throws IllegalArgumentException if the input DAGs are not valid + */ + private void checkInput(Dag firstDag, Dag secondDag) { + if (firstDag == null || secondDag == null) { + throw new IllegalArgumentException("Input DAGs cannot be null."); + } + if (firstDag.getNumNodes() == 0 || secondDag.getNumNodes() == 0) { + throw new IllegalArgumentException("Input DAGs must contain at least one node."); + } + if (firstDag.getNumEdges() == 0 || secondDag.getNumEdges() == 0) { + throw new IllegalArgumentException("Input DAGs must contain at least one edge."); + } + if (firstDag.getNodes().size() != secondDag.getNodes().size()) { + throw new IllegalArgumentException("Input DAGs must have the same number of nodes."); + } + if (!firstDag.getNodes().containsAll(secondDag.getNodes())) { + throw new IllegalArgumentException("Input DAGs must have the same set of nodes."); + } } - + + /** + * Performs the fusion process by first applying the consensus union and then applying the Backward Equivalence Search. 
+ */ + public void fusion(){ + // Creating a list of DAGs to be fused + ArrayList setOfDags = new ArrayList<>(); + setOfDags.add(this.firstDag); + setOfDags.add(this.secondDag); + // Applying the ConsensusBES algorithm to fuse the DAGs + consensusBES = new ConsensusBES(setOfDags); + consensusBES.fusion(); + // Retrieving the resulting DAG and the number of inserted edges + this.numberOfInsertedEdges = consensusBES.getNumberOfInsertedEdges(); + this.numberOfUnionEdges = consensusBES.getUnion().getNumEdges(); + this.consensusDAG = consensusBES.getFusionDag(); + } + + /** + * Returns the number of edges inserted during the fusion process. + * This method retrieves the number of edges that were added to the consensus DAG during the fusion process. + * It is useful for understanding how many edges were introduced in the consensus DAG compared to the original input DAGs. + * @return the number of edges inserted during the fusion process. + */ public int getNumberOfInsertedEdges(){ return this.numberOfInsertedEdges; } + /** + * Returns the number of edges in the union DAG after the consensus union process. + * This method retrieves the number of edges that were present in the union DAG after merging the transformed input DAGs. + * It is useful for understanding the size of the union DAG before applying the Backward Equivalence Search. + * This number can be used to compare with the number of edges in the final consensus DAG after the Backward Equivalence Search. + * + * @see ConsensusBES#getUnion() + * @see ConsensusBES#getNumberOfInsertedEdges() + * @return the number of edges in the union DAG after the consensus union process. + */ public int getNumberOfUnionEdges(){ return this.numberOfUnionEdges; } - public int getHammingDistance(){ - if(this.conDAG==null) this.getFusion(); + /** + * Calculates the Hamming distance between the optimum fusion DAG and the original input DAGs. + * @return The Hamming distance between the fused DAG and the original input DAGs. 
+ */ + public int calculateHammingDistance(){ + if(this.consensusDAG==null) this.fusion(); int distance = 0; - for(Edge ed: this.conDAG.getEdges()){ + for(Edge ed: this.consensusDAG.getEdges()){ Node tail = ed.getNode1(); Node head = ed.getNode2(); - for(Dag g: conBES.setOfOutDags){ + for(Dag g: consensusBES.getTransformedDags()){ Edge edge1 = g.getEdge(tail, head); Edge edge2 = g.getEdge(head, tail); if(edge1 == null && edge2==null) distance++; @@ -58,14 +142,25 @@ public int getHammingDistance(){ return distance+this.getNumberOfInsertedEdges(); } + /** + * Returns the resulting consensus DAG after applying the fusion process. + * This method retrieves the final fused DAG, which represents the optimal fusion of the input DAGs. + * It is useful for obtaining the consensus structure after the fusion process has been completed. + * + * @see ConsensusBES#getFusionDag() + * @return the resulting consensus DAG after applying the whole fusion process. + */ public Dag getDagFusion(){ - return this.conDAG; + return this.consensusDAG; } + /** + * Runs the fusion process in a thread, performing the consensus union and the Backward Equivalence Search with D-separation. + */ @Override public void run() { - this.getFusion(); + this.fusion(); } } diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java b/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java index ebfe47d..c240f46 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/PowerSet.java @@ -3,113 +3,140 @@ import java.util.ArrayList; import java.util.Enumeration; import java.util.HashMap; +import java.util.HashSet; import java.util.List; - +import java.util.Set; import edu.cmu.tetrad.graph.Node; -public class PowerSet implements Enumeration { +/** + * PowerSet generates all subsets of a given set of nodes, with an optional maximum size constraint. + * It implements Enumeration to allow iteration over the subsets. 
+ */ +public class PowerSet implements Enumeration> { + /** + * List of nodes for which the power set is generated. + */ List nodes; - private List subSets; + /** + * List to store the generated subsets. + */ + private final List> subSets = new ArrayList<>(); + /** + * Index to track the current position in the enumeration. + */ private int index; - private int[] lista; - private HashMap hashMap; - - - PowerSet(List nodes,int k) { - if(nodes.size()<=k) - k=nodes.size(); - this.nodes=nodes; - subSets = new ArrayList(); - index=0; - hashMap=new HashMap(); - lista=ListFabric.getList(nodes.size()); - for (int i : lista) { - SubSet newSubSet = new SubSet(); - String selection = Integer.toBinaryString(i); - for (int j = selection.length() - 1; j >= 0; j--) { - if (selection.charAt(j) == '1') { - newSubSet.add(nodes.get(selection.length() - j - 1)); - } - } - if(newSubSet.size()<=k){ - subSets.add(newSubSet); - hashMap.put(i, newSubSet); - } - } - } + /** + * List of integers representing the subsets in binary form. + */ + private int[] binaryList; + + /** + * A map to store the subsets with their corresponding binary representation. + * The key is the binary representation of the subset, and the value is the subset itself. + */ + private HashMap> subset; + /** + * Maximum size of the subsets to be generated. + * If set to a value less than the number of nodes, it limits the size of the subsets. 
+ */ + private int maxPow = 0; - PowerSet(List nodes) { - if(nodes.size()>maxPow) - maxPow=nodes.size(); - this.nodes=nodes; - subSets = new ArrayList(); - index=0; - hashMap=new HashMap(); - lista=ListFabric.getList(nodes.size()); - for (int i : lista) { - SubSet newSubSet = new SubSet(); - String selection = Integer.toBinaryString(i); - for (int j = selection.length() - 1; j >= 0; j--) { - if (selection.charAt(j) == '1') { - newSubSet.add(nodes.get(selection.length() - j - 1)); - } - } - subSets.add(newSubSet); - hashMap.put(i, newSubSet); + /** + * Builds a PowerSet with subsets of the given nodes, limited to a maximum size. + * @param nodes List of nodes to generate subsets from. + * @param maxSize Maximum size of the subsets to be generated. Assuring that k does not exceed the number of nodes. + * * If maxSize is negative, it will throw an IllegalArgumentException. + * @throws IllegalArgumentException if maxSize is negative. + + */ + public PowerSet(List nodes, int maxSize) { + if (maxSize < 0) { + throw new IllegalArgumentException("maxSize cannot be negative"); } - } - + if (maxSize >= nodes.size()) { + maxSize = nodes.size(); + } + this.nodes = nodes; + initializeSubsets(maxSize); + } + + /** + * Builds a PowerSet with all subsets of the given nodes, without size limitation. + * + * @param nodes Lista de nodos de entrada. + */ + public PowerSet(List nodes) { + if (nodes.size() > maxPow) { + maxPow = nodes.size(); + } + this.nodes = nodes; + initializeSubsets(nodes.size()); // sin límite: k = nodes.size() + } + + /** + * Initializes the subsets based on the nodes and the maximum size. + * This method generates all possible subsets of the nodes, respecting the maximum size constraint. + * @param maxSize Maximum size of the subsets to be generated. + * If maxSize is greater than the number of nodes, it will generate subsets of all sizes. + * If maxSize is 0, it will only generate the empty set. 
+ */ + private void initializeSubsets(int maxSize) { + subset = new HashMap<>(); + index = 0; + binaryList = ListFabric.generateList(nodes.size()); + + for (int i : binaryList) { + Set newSubSet = new HashSet<>(); + String selection = Integer.toBinaryString(i); + + for (int j = selection.length() - 1; j >= 0; j--) { + if (selection.charAt(j) == '1') { + int idx = selection.length() - j - 1; + newSubSet.add(nodes.get(idx)); + } + } + + if (newSubSet.size() <= maxSize) { + subSets.add(newSubSet); + subset.put(i, newSubSet); + } + } + } + + /** + * Checks if there are more subsets to iterate over. + * @return true if there are more subsets, false otherwise. + */ + @Override public boolean hasMoreElements() { return index nextElement() { return subSets.get(index++); } + /** + * Resets the index to allow re-iteration over the subsets. + * This method allows the enumeration to start over from the beginning. + */ public void resetIndex(){ this.index = 0; } - private static int maxPow = 0; - public static long maxPowerSetSize() { + /** + * Returns the maximum size of the power set based on the maximum number of nodes. + * This method calculates the size of the power set as 2 raised to the power of the maximum number of nodes. + * @return The maximum size of the power set. 
+ */ + public long maxPowerSetSize() { return (long) Math.pow(2,maxPow); } - -// public void firstTest(boolean result) { -// int numInicial=lista[index-1]; -// for(int i=0;i> hashMap=new HashMap>(); -// -// private PowerSetFabric() { -// } - - public static PowerSet getPowerSet(List nodes, int k){ - return new PowerSet(nodes,k); - - } - - public static PowerSet getPowerSet(Node x, Node y, List nodes) { - return new PowerSet(nodes); -// if(!usePowerSetsCache) -// return new PowerSet(nodes); -// PowerSet pSet=get(x,y); -// if(pSet==null || pSet.nodes.size()!=nodes.size()) { // if(pSet==null || !pSet.t.equals(nodes)) { -// pSet=new PowerSet(nodes); -// put(x,y,pSet); -// } -// else { -// pSet.reset(mode==MODE_FES); -// } -// return pSet; - } -// -// private static void put(Node x, Node y, PowerSet pSet) { -// hashMap.get(x).put(y, pSet); -// -// } -// -// private static PowerSet get(Node x, Node y) { -// HashMap aux=hashMap.get(x); -// if(aux==null) { -// aux=new HashMap(); -// hashMap.put(x, aux); -// } -// return aux.get(y); -// } -// -// public static int getMode() { -// return mode; -// } -// - /* - public static void setMode(int mode) { - PowerSetFabric.mode=mode; - } - */ - - public static boolean isUsePowerSetsCache() { - return usePowerSetsCache; - } - - public static void setUsePowerSetsCache(boolean usePowerSetsCache) { - PowerSetFabric.usePowerSetsCache = usePowerSetsCache; - } -} diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/RandomBN.java b/src/main/java/es/uclm/i3a/simd/consensusBN/RandomBN.java index 06c89a4..4840bcb 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/RandomBN.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/RandomBN.java @@ -13,6 +13,11 @@ import edu.cmu.tetrad.bayes.MlBayesIm.InitializationMethod; import edu.cmu.tetrad.data.*; +/** + * RandomBN generates a set of random Bayesian networks (BNs) based on specified parameters. 
This class has been used for experiments, and is being maintained for compatibility with existing experiments. + * Further development should avoid using this class and instead use @see RandomGraph from Tetrad for generating random DAGs. + */ +@Deprecated public class RandomBN { int seed = 0; diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java b/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java deleted file mode 100644 index be6018c..0000000 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/SubSet.java +++ /dev/null @@ -1,24 +0,0 @@ -package es.uclm.i3a.simd.consensusBN; - -import java.util.HashSet; - -import edu.cmu.tetrad.graph.Node; - -public class SubSet extends HashSet { - - private static final long serialVersionUID = 4569314863278L; - public static final int TEST_NOT_EVALUATED=0; - public static final int TEST_TRUE=1; - public static final int TEST_FALSE=-1; - - public int firstTest=TEST_NOT_EVALUATED; - public int secondTest=TEST_NOT_EVALUATED; - - public SubSet() { - super(); - } - - public SubSet(SubSet other) { - super(other); - } -} \ No newline at end of file diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java b/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java index 9861e99..f798234 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/TransformDags.java @@ -2,205 +2,133 @@ import java.util.ArrayList; - -import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.Node; +/** + * This class transforms a set of DAGs by applying the BetaToAlpha transformation to each DAG with a given alpha order. + */ public class TransformDags { + /** + * List of input DAGs to be transformed. + */ + private final ArrayList setOfDags; + /** + * List of output DAGs after transformation. + */ + private final ArrayList setOfOutputDags; + /** + * The alpha order used for the transformation. 
+ */ + private final ArrayList alpha; + /** + * The transformation objects for each DAG. + * Each BetaToAlpha object applies the transformation to a corresponding DAG in setOfDags using the alpha order provided. + */ + private ArrayList transformers= null; - ArrayList setOfDags = null; - ArrayList setOfOutputDags = null; - ArrayList alfa = null; - ArrayList metAs= null; - int numberOfInsertedEdges = 0; -// int weight[][][] = null; + /** + * Number of edges inserted during the transformation process. + * This is used to track how many edges were added to the transformed DAGs. + */ + private int numberOfInsertedEdges = 0; - public TransformDags(ArrayList dags, ArrayList alfa){ + /** + * Constructor for TransformDags. + * Initializes the object with a list of DAGs and an alpha order. + * It creates a new BetaToAlpha transformation for each DAG in the input list. + * Each DAG in the input list will be transformed according to this alpha order. + * The transformation is applied by creating a BetaToAlpha object for each DAG. + * The transformed DAGs will be stored in setOfOutputDags after calling the transform() method. + * @see BetaToAlpha + * @param dags List of DAGs to be transformed. + * @param alpha List of nodes representing the alpha order for the transformation. + */ + public TransformDags(ArrayList dags, ArrayList alpha){ this.setOfDags = dags; - this.setOfOutputDags = new ArrayList(); - this.metAs = new ArrayList(); - this.alfa = alfa; - + this.setOfOutputDags = new ArrayList<>(); + this.transformers = new ArrayList<>(); + this.alpha = alpha; + // Initialize the BetaToAlpha transformation for each DAG in the input list for (Dag i : setOfDags) { Dag out = new Dag(i); - this.metAs.add(new BetaToAlpha(out,alfa)); + this.transformers.add(new BetaToAlpha(out,this.alpha)); } } - + /** + * Transforms the input DAGs by applying the BetaToAlpha transformation. 
+ * This method iterates through each BetaToAlpha object, applies the transformation, + * and collects the transformed DAGs into setOfOutputDags. + * It also counts the total number of edges inserted during the transformations. + * + * @see BetaToAlpha#transform() + * @see BetaToAlpha#getNumberOfInsertedEdges() + * @see BetaToAlpha#getGraph() + * @return An ArrayList of transformed DAGs after applying the BetaToAlpha transformation. + */ public ArrayList transform (){ - this.numberOfInsertedEdges = 0; - - for(BetaToAlpha transformDagi: this.metAs){ + for(BetaToAlpha transformDagi: this.transformers){ transformDagi.transform(); this.numberOfInsertedEdges += transformDagi.getNumberOfInsertedEdges(); - this.setOfOutputDags.add(transformDagi.G); + this.setOfOutputDags.add(transformDagi.getGraph()); } - - return this.setOfOutputDags; - } + /** + * Returns the number of edges that were inserted during the transformation process. + * @return The total number of edges inserted across all transformed DAGs. + */ public int getNumberOfInsertedEdges(){ return this.numberOfInsertedEdges; } + /** + * Returns the list of input DAGs that were transformed. + * @return An ArrayList of DAGs that were provided as input to the transformation. + */ + public ArrayList getSetOfDags() { + return this.setOfDags; + } + + /** + * Returns the list of transformed output DAGs. + * This list contains the DAGs after applying the BetaToAlpha transformation. + * @return An ArrayList of transformed DAGs. + */ + public ArrayList getSetOfOutputDags() { + return this.setOfOutputDags; + } + + /** + * Returns the alpha order used for the transformation. + * @return An ArrayList of nodes representing the alpha order. + */ + public ArrayList getAlpha() { + return this.alpha; + } + + /** + * Returns the list of BetaToAlpha transformers used for the transformation. + * @return An ArrayList of BetaToAlpha objects, each corresponding to a DAG in the input list. 
+ */ + public ArrayList getTransformers() { + return this.transformers; + } + + /** + * Sets the list of BetaToAlpha transformers. + * This method allows for updating the transformers used in the transformation process. + * + * @param transformers An ArrayList of BetaToAlpha objects to be set as the transformers for this TransformDags instance. + * Each transformer will apply the BetaToAlpha transformation to its corresponding DAG in the input list. + */ + public void setTransformers(ArrayList transformers) { + this.transformers = transformers; + } - - -// void computeWeight(){ -// -// this.weight = new int[this.setOfDags.size()][this.alfa.size()][this.alfa.size()]; -// -// for(Dag g: this.setOfOutputDags){ -// for(Node nodei : g.getNodes()){ -// List pa = g.getParents(nodei); -// if(pa.isEmpty()) continue; -// List anc = new ArrayList(); -// anc.add(nodei); -// anc = g.getAncestors(anc); -// Dag gAnc = new Dag(g.subgraph(anc)); -// // me quedo con el grafo ancestral del node_i -// for(Node pai: pa){ // Calculo el numero de caminos desde los ancestros que se "activan" borrando cada padre. -// int npaths = 0; -// Dag gAncNopai = new Dag(gAnc); -// for(Node rm : pa) if(!rm.equals(pai)) gAncNopai.removeNode(rm); // borro todos los padres menos el pa_i en el grafo ancestral. -// for(Node nodeAc: anc){ // para cada ancestro voy mirando si hay un camino dirigido. 
-// if((gAncNopai.getNodes().contains(nodeAc))&&(!nodeAc.equals(nodei))) -// if(GraphUtils.paths().existsDirectedPath(gAncNopai,nodeAc, nodei)) npaths++; -// } -// // npaths tiene el numero de caminos diridos que se han activado quitando el padre pa_i -// this.weight[this.setOfOutputDags.indexOf(g)][g.getNodes().indexOf(nodei)][g.getNodes().indexOf(pai)] = npaths; -// } -// -// } -// -// } -// -// } - -// public Dag computeWeightDag(boolean w){ -// -// Dag wDag = new Dag(this.alfa); -// for(Node nodei : this.alfa){ -// for(Node nodej : this.alfa){ -// if(nodei.equals(nodej)) continue; -// int wij = 0; -// for(Dag g: this.setOfOutputDags){ -// int wg = this.weight[this.setOfOutputDags.indexOf(g)][g.getNodes().indexOf(nodei)][g.getNodes().indexOf(nodej)]; -// if(wg > 0 && !w) wg = 1; -// else if (wg == 0) wg =-1; -// wij+=wg; -// } -// if(wij > 0) wDag.addEdge(new Edge(nodej,nodei,Endpoint.TAIL,Endpoint.ARROW)); -// } -// } -// return wDag; -// } -// -// public static void main(String args[]) { -// -// ArrayList dags = new ArrayList(); -// ArrayList alfa = new ArrayList(); -// Random aleatorio = new Random(222); -// -// -// System.out.println("Grafos de Partida: "); -// System.out.println("---------------------"); -//// Graph graph = GraphConverter.convert("X1-->X5,X2-->X3,X3-->X4,X4-->X1,X4-->X5"); -//// Dag dag = new Dag(graph); -// -// Dag dag = new Dag(); -// -// dag = GraphUtils.randomDag(Integer.parseInt(args[0]), Integer.parseInt(args[1]), true); -// BetaToAlpha mt = new BetaToAlpha(dag); -// alfa = mt.randomAlfa(aleatorio); -// dags.add(dag); -// System.out.println("DAG: ---------------"); -// System.out.println(dag.toString()); -// for (int i=0 ; i < Integer.parseInt(args[2])-1 ; i++){ -// Dag newDag = GraphUtils.randomDag(dag.getNodes(),Integer.parseInt(args[1]) ,true); -// dags.add(newDag); -// System.out.println("DAG: ---------------"); -// System.out.println(newDag.toString()); -// } -// -// -// -// System.out.println("Orden de Consenso: " + 
alfa.toString()); -// -// TransformDags setOfDags = new TransformDags(dags,alfa); -// setOfDags.transform(); -// -// -// -// -// for(Dag d : setOfDags.setOfOutputDags){ -// System.out.println("DAG trasformado: ---------------"); -// System.out.println(d.toString()); -// } -// -// -// -// Dag union = new Dag(alfa); -// -// for(Node nodei: alfa){ -// for(Dag d : setOfDags.setOfOutputDags){ -// Listparent = d.getParents(nodei); -// for(Node pa: parent){ -// if(!union.isParentOf(pa, nodei)) union.addEdge(new Edge(pa,nodei,Endpoint.TAIL,Endpoint.ARROW)); -// } -// } -// -// } -// -// -// System.out.println("Grafo UNION: "+union.toString()); -// setOfDags.computeWeight(); -// Dag wDag = setOfDags.computeWeightDag(true); -// System.out.println("Grafo Consenso: "+ wDag.toString()); -// Dag wDag2 = setOfDags.computeWeightDag(false); -// System.out.println("Grafo Consenso sin pesos: "+ wDag2.toString()); -// -// -// -// -// -// -// -//// Node nod = dag.getNodes().get(aleatorio.nextInt(alfa.size())); -//// ArrayList a = new ArrayList(); -//// a.add(nod); -//// List anc = dag.getAncestors(a); -//// -//// System.out.println("Ancenstros de " + nod.toString()+ " "+anc.toString()); -//// -//// System.out.println("Subgraph: "+ dag.subgraph(anc)); -//// -//// List pa = dag.getParents(nod); -//// System.out.println("padres de: "+nod.toString()+ " : "+pa.toString()); -//// Node pai = pa.get(aleatorio.nextInt(pa.size())); -//// System.out.println("Padre elegido a borrar: "+pai.toString()); -//// pa.remove(pai); -//// -//// -//// -//// for(Node rm: pa) anc.remove(rm); -//// -//// Graph pp = dag.subgraph(anc); -//// int npath = 0; -//// for(Node ancestor: pp.getNodes()){ -//// if(!ancestor.equals(nod)){ -//// List> paths = GraphUtils.allPathsFromTo(pp, ancestor, nod); -//// if(dag.paths().existsDirectedPath(ancestor, nod)) npath++; -//// System.out.println("Caminos desde: "+ancestor.toString()+" a: "+nod.toString()+" : "+paths.toString()); -//// } -//// } -//// System.out.println(" El 
numero de caminos desde hacia: "+ nod+" es de: "+npath); -// } -// } diff --git a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java index d4c7a7f..0c7d970 100644 --- a/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java +++ b/src/main/java/es/uclm/i3a/simd/consensusBN/Utils.java @@ -1,32 +1,54 @@ package es.uclm.i3a.simd.consensusBN; +import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Deque; +import java.util.HashSet; +import java.util.LinkedList; import java.util.List; +import java.util.Set; +import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.Edge; import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetrad.graph.Endpoint; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.Node; -public class Utils { +/** + * Utility class providing static methods for graph operations, particularly for transforming PDAGs to DAGs and checking d-separation. + * This class includes methods for converting a PDAG to a DAG, checking if two nodes are d-separated in a DAG, and finding neighbors of nodes. + */ +public final class Utils { - /** - * pdagToDag Algorithm From Chickering 2002: - * "We first consider a simple implementation of PDAG-to-DAG due to Dor and Tarsi (1992). - * Let NX denote the neighbors of node X in a PDAG P. - * We first create a DAG G that contains all of the directed edges from P, and no other edges. - * We then repeat the following procedure: - * First, select a node X in P such that: - * (1) X has no out-going edges and - * (2) if NX is non-empty, then NX PaX is a clique. - * If P admits a consistent extension, the node X is guaranteed to exist. - * Next, for each undirected edge Y X incident to X in P, insert a directed edge Y X to G. - * Finally, remove X and all incident edges from the P and continue with the next node. 
- * The algorithm terminates when all nodes have been deleted from P." - * @param graph The graph to be transformed from PDAG to DAG. - */ + /** + * Transforms a PDAG (Partially Directed Acyclic Graph) into a DAG (Directed Acyclic Graph) + * using the algorithm proposed by Dor and Tarsi (1992), as presented in Chickering (2002). + * + *

The algorithm proceeds as follows: + *

    + *
  1. Let NX be the set of neighbors of node X in the PDAG P.
  2. + *
  3. Create a new DAG G containing all the directed edges from P (and no others).
  4. + *
  5. Iteratively repeat the following steps: + *
      + *
    1. Select a node X such that: + *
        + *
      • (1) X has no outgoing directed edges in P, and
      • + *
      • (2) if NX is non-empty, then NX ∪ Pa(X) forms a clique.
      • + *
      + * Such a node is guaranteed to exist if P admits a consistent extension.
    2. + *
    3. For each undirected edge Y—X incident to X in P, orient it as Y → X in G.
    4. + *
    5. Remove node X and all its incident edges from P.
    6. + *
    + *
  6. + *
  7. The algorithm terminates when all nodes have been removed from P.
  8. + *
+ * + * @param graph The input PDAG to be converted into a DAG. + */ public static void pdagToDag(Graph graph){ // First create a DAG G that contains all of the directed edges from the PDAG, and no other edges. Graph graphAux = new EdgeListGraph(graph); @@ -45,7 +67,7 @@ public static void pdagToDag(Graph graph){ for(Node node : nodes){ x = node; //Checking if the node has no outgoing edges - if(graphAux.getChildren(node).size() != 0) + if(!graphAux.getChildren(node).isEmpty()) continue; //Checking if the neighbors form a clique if(!GraphUtils.isClique(graphAux.getAdjacentNodes(x), graphAux)) @@ -62,7 +84,203 @@ public static void pdagToDag(Graph graph){ break; // We break the loop to start again with the new graphAux without the removed node. } nodes.remove(x); - }while(nodes.size() > 0); + }while(!nodes.isEmpty()); + } + + /** + * Checks if two nodes in a DAG are d-separated given an empty set of conditioning nodes. + * @param g The DAG to check for d-separation. + * @param x The first node. + * @param y The second node. + * @return True if the nodes are d-separated, false otherwise. + */ + public static boolean dSeparated(Dag g, Node x, Node y) { + return dSeparated(g, x, y, new ArrayList<>()); + } + + /** + * Checks if two nodes in a DAG are d-separated given a set of conditioning nodes. + * This method uses a defensive copy of the conditioning set to ensure immutability. + * It builds an induced subgraph of the DAG containing only the relevant nodes and checks if there is a path between the two nodes that does not pass through the conditioning nodes. + * The method first finds the relevant nodes in the DAG, builds an induced subgraph, moralizes it, + * converts it to an undirected graph, and finally checks if the two nodes are reachable from each other in the undirected graph after removing the conditioning nodes. + * @param g The DAG to check for d-separation. + * @param x The first node. + * @param y The second node. 
+ * @param cond The list of conditioning nodes. + * @return True if the nodes are d-separated, false otherwise. + */ + public static boolean dSeparated(Dag g, Node x, Node y, List cond) { + + Set relevantNodes = findRelevantNodes(g, x, y, cond); + Graph aux = buildInducedSubgraph(g, relevantNodes); + moralize(aux); + convertToUndirected(aux); + aux.removeNodes(cond); + return !isReachable(aux, x, y); + } + + /** + * Finds the relevant nodes in the DAG that are needed to check d-separation between two nodes x and y, given a set of conditioning nodes. + * This method performs a depth-first search starting from nodes x and y, and includes all nodes that are reachable from either x or y, as well as the conditioning nodes. + * The method uses a stack to explore the graph and a set to keep track of visited nodes, ensuring that + * all relevant nodes are included in the final set. + * @param g The DAG in which to find the relevant nodes. + * @param x The first node. + * @param y The second node. + * @param cond The list of conditioning nodes. + * @return A set of nodes that are relevant for checking d-separation between x and y, including the conditioning nodes. + */ + private static Set findRelevantNodes(Dag g, Node x, Node y, List cond) { + Set visited = new HashSet<>(); + Deque stack = new ArrayDeque<>(); + + stack.push(x); + stack.push(y); + for (Node c : cond) stack.push(c); + + while (!stack.isEmpty()) { + Node current = stack.pop(); + if (visited.add(current)) { + for (Node parent : g.getParents(current)) { + stack.push(parent); + } + } + } + + return visited; + } + + /** + * Builds an induced subgraph from the given DAG containing only the nodes specified in the set {@code nodesToKeep}. + * The method creates a new graph that includes all nodes in {@code nodesToKeep} and all directed edges between them that exist in the original graph. + * It ensures that only directed edges are included, and undirected edges are ignored. 
+ * @param g The original DAG from which to build the induced subgraph. + * @param nodesToKeep The set of nodes to include in the induced subgraph. + * This set should contain the nodes that are relevant for the d-separation check. + * @return A new graph representing the induced subgraph containing only the specified nodes and their directed edges. + * This graph is a subgraph of the original DAG, containing only the nodes in {@code nodesToKeep} and the directed edges between them that exist in the original graph. + */ + private static Graph buildInducedSubgraph(Dag g, Set nodesToKeep) { + Graph subgraph = new EdgeListGraph(); + + for (Node node : g.getNodes()) { + if (nodesToKeep.contains(node)) { + subgraph.addNode(node); + } + } + + for (Edge e : g.getEdges()) { + if (!e.isDirected()) continue; + + Node tail = e.getNode1(); + Node head = e.getNode2(); + + if (nodesToKeep.contains(tail) && nodesToKeep.contains(head)) { + subgraph.addEdge(new Edge(tail, head, e.getEndpoint1(), e.getEndpoint2())); + } + } + + return subgraph; + } + + /** + * Moralizes the given graph by adding undirected edges between all pairs of parents of each child node. + * This process ensures that the resulting graph is undirected and that all parents of each child are connected. + * The moralization is done by iterating over each child node, retrieving its parents, and adding undirected edges between every pair of parents. + * This is a crucial step in preparing the graph for d-separation checks, as it ensures that the graph structure reflects the necessary connections between parents of child nodes. + * @param graph The graph to moralize. 
+ */ + private static void moralize(Graph graph) { + for (Node child : graph.getNodes()) { + List parents = graph.getParents(child); + int n = parents.size(); + if (n <= 1) continue; + + for (int i = 0; i < n - 1; i++) { + for (int j = i + 1; j < n; j++) { + Node p1 = parents.get(i); + Node p2 = parents.get(j); + if (!graph.isAdjacentTo(p1, p2)) { + graph.addUndirectedEdge(p1, p2); + } + } + } + } + } + + /** + * Converts all directed edges in the graph to undirected edges. + * This method iterates through all edges in the graph and changes their endpoints to TAIL, + * effectively removing the directionality of the edges. This is useful for certain graph operations + * where the direction of edges is not relevant, such as when checking connectivity or performing undirected graph algorithms. + * @param graph The graph to convert. + */ + private static void convertToUndirected(Graph graph) { + for (Edge e : new ArrayList<>(graph.getEdges())) { + if (e.isDirected()) { + e.setEndpoint1(Endpoint.TAIL); + e.setEndpoint2(Endpoint.TAIL); + } + } + } + + /** + * Checks if there is a path between two nodes in the graph. + * This method performs a depth-first search starting from the {@code start} node and checks + * if it can reach the {@code target} node. It uses a stack to explore the graph and a set to keep track of visited nodes, + * ensuring that it does not revisit nodes. + * If the target node is found during the search, it returns true; otherwise, + * it returns false after exhausting all possible paths. + * @param g The graph to search. + * @param start The starting node. + * @param target The target node. + * @return True if there is a path from start to target, false otherwise. 
+ */ + private static boolean isReachable(Graph g, Node start, Node target) { + Set visited = new HashSet<>(); + Deque stack = new ArrayDeque<>(); + stack.push(start); + + while (!stack.isEmpty()) { + Node current = stack.pop(); + if (current.equals(target)) return true; + if (visited.add(current)) { + for (Node neighbor : g.getAdjacentNodes(current)) { + if (!visited.contains(neighbor)) { + stack.push(neighbor); + } + } + } + } + + return false; + } + + + /** + * Finds the nodes that are neighbors of node y and x in the graph. + * This method retrieves the neighbors of node y that are also adjacent to node x, + * ensuring that the edges between them are directed. + * It filters out undirected edges to ensure that only neighbors from directed edges are considered. + * @param x Node x to find neighbors for. + * @param y Node y to find neighbors for. + * @param graph The graph in which to find the neighbors. + * @return A list of nodes that are neighbors of x and y, filtered to include only neighbors from directed edges. 
+ */ + public static List findNaYX(Node x, Node y, Graph graph) { + List naYX = new LinkedList<>(graph.getAdjacentNodes(y)); + naYX.retainAll(graph.getAdjacentNodes(x)); + + for (int i = naYX.size()-1; i >= 0; i--) { + Node z = naYX.get(i); + Edge edge = graph.getEdge(y, z); + + if (!Edges.isUndirectedEdge(edge)) { + naYX.remove(z); + } + } + + return naYX; } - } diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java new file mode 100644 index 0000000..e965021 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/AlphaOrderTest.java @@ -0,0 +1,119 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + +class AlphaOrderTest { + + private Node a, b, c; + private Dag dag1, dag2; + private ArrayList dags; + + @BeforeEach + void setup() { + a = new GraphNode("A"); + b = new GraphNode("B"); + c = new GraphNode("C"); + + // DAG 1: A → B → C + dag1 = new Dag(); + dag1.addNode(a); + dag1.addNode(b); + dag1.addNode(c); + dag1.addDirectedEdge(a, b); + dag1.addDirectedEdge(b, c); + + // DAG 2: A → B, A → C + dag2 = new Dag(); + dag2.addNode(a); + dag2.addNode(b); + dag2.addNode(c); + dag2.addDirectedEdge(a, b); + dag2.addDirectedEdge(a, c); + + dags = new ArrayList<>(Arrays.asList(dag1, dag2)); + } + + @Test + void constructorThrowsOnNullInput() { + assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(null)); + } + + @Test + void constructorThrowsOnEmptyList() { + 
assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(new ArrayList<>())); + } + + @Test + void constructorThrowsOnSingleDAG() { + ArrayList singleDagList = new ArrayList<>(); + singleDagList.add(dag1); + assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(singleDagList)); + } + + @Test + void constructorThrowsOnDifferentNodes() { + // Crear otro DAG con nodos diferentes + Dag dagDifferent = new Dag(); + dagDifferent.addNode(new GraphNode("X")); + dagDifferent.addNode(new GraphNode("Y")); + dagDifferent.addDirectedEdge(dagDifferent.getNode("X"), dagDifferent.getNode("Y")); + + ArrayList badList = new ArrayList<>(Arrays.asList(dag1, dagDifferent)); + assertThrows(IllegalArgumentException.class, () -> new AlphaOrder(badList)); + } + + @Test + void computeAlphaReturnsValidOrder() { + AlphaOrder alphaOrder = new AlphaOrder(dags); + alphaOrder.computeAlpha(); + List order = alphaOrder.getOrder(); + + assertNotNull(order); + assertEquals(3, order.size()); + assertTrue(order.contains(a)); + assertTrue(order.contains(b)); + assertTrue(order.contains(c)); + + // Optional: check for uniqueness + assertEquals(3, order.stream().distinct().count()); + } + + @Test + void computeAlphaForTwoSimpleDags(){ + AlphaOrder alphaOrder = new AlphaOrder(dags); + alphaOrder.computeAlpha(); + List order = alphaOrder.getOrder(); + + // Basic assertions to check the order + assertNotNull(order); + assertEquals(3, order.size()); + assertTrue(order.contains(a)); + assertTrue(order.contains(b)); + assertTrue(order.contains(c)); + + // Check that the order is A, B, C + assertEquals(a, order.get(0)); + assertEquals(b, order.get(1)); + assertEquals(c, order.get(2)); + + // Check for uniqueness + assertEquals(3, order.stream().distinct().count()); + + } +} + + + diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java new file mode 100644 index 
0000000..b3ca80f --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BackwardEquivalenceSearchDSepTest.java @@ -0,0 +1,110 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.Node; + +class BackwardEquivalenceSearchDSepTest { + + + private ArrayList createRandomDagList(int copies) { + ArrayList setOfDags = new ArrayList<>(); + setOfDags.addAll(GraphTestHelper.generateRandomDagList(20, copies, 50, 49, 49, 49, true, 0)); + return setOfDags; + } + + @Test + void testApplyBESdDoesNotThrow() { + // Setting up consensus union + ArrayList initialDags = createRandomDagList(3); + ConsensusUnion consensusUnion = new ConsensusUnion(initialDags); + Dag unionDag = consensusUnion.union(); + + // Running Backward Equivalence Search with d-separation + BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, consensusUnion.getTransformedDags()); + + // No exceptions should be thrown during the process + assertDoesNotThrow(() -> { + Dag output = besd.applyBackwardEliminationWithDSeparation(); + assertNotNull(output); + }); + } + + @Test + void testOutputIsDAG() { + // Setting up consensus union + ArrayList initialDags = createRandomDagList(3); + ConsensusUnion consensusUnion = new ConsensusUnion(initialDags); + Dag unionDag = consensusUnion.union(); + + // Running Backward Equivalence Search with d-separation + BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, consensusUnion.getTransformedDags()); + Dag 
outputDag = besd.applyBackwardEliminationWithDSeparation(); + + assertTrue(GraphUtils.isDag(outputDag), "El resultado no es un DAG válido."); + } + + @Test + void testAristasSePuedenEliminar() { + // Creamos un grafo donde A -> B, pero en todos los DAGs está A B (no conectados) + Node a = new GraphNode("A"); + Node b = new GraphNode("B"); + + Dag unionDag = new Dag(); + unionDag.addNode(a); + unionDag.addNode(b); + unionDag.addDirectedEdge(a, b); + + // DAGs originales sin esa arista + Dag dag1 = new Dag(); + dag1.addNode(a); + dag1.addNode(b); + + Dag dag2 = new Dag(); + dag2.addNode(a); + dag2.addNode(b); + // Aquí no hay aristas, A y B están desconectados + // sin conexión + + ArrayList initialDags = new ArrayList<>(); + initialDags.add(dag1); + initialDags.add(dag2); + AlphaOrder alphaOrder = new AlphaOrder(initialDags); + alphaOrder.computeAlpha(); + ArrayList transformedDags = (new TransformDags(initialDags, alphaOrder.getOrder())).transform(); + + BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags); + Dag outputDag = besd.applyBackwardEliminationWithDSeparation(); + + // Debe eliminar la arista A -> B por no tener soporte + Edge deletedEdge = outputDag.getEdge(a, b); + assertNull(deletedEdge, "La arista A -> B debería haberse eliminado."); + } + + @Test + void testGetNumberOfInsertedEdgesReflectsChanges() { + ArrayList initialDags = createRandomDagList(2); + ConsensusUnion consensusUnion = new ConsensusUnion(initialDags); + Dag unionDag = consensusUnion.union(); + ArrayList transformedDags = consensusUnion.getTransformedDags(); + int insertedEdgesBefore = consensusUnion.getNumberOfInsertedEdges(); + + BackwardEquivalenceSearchDSep besd = new BackwardEquivalenceSearchDSep(unionDag, initialDags, transformedDags); + besd.applyBackwardEliminationWithDSeparation(); + + int insertedEdgesAfter = insertedEdgesBefore - besd.getNumberOfRemovedEdges(); + // En el peor de los casos no ha eliminado ninguna, pero 
nunca debe ser negativo + assertTrue(insertedEdgesAfter >= 0, "The number of inserted edges should not be negative."); + assertTrue(insertedEdgesAfter <= insertedEdgesBefore, "The number of inserted edges should decrease after BES."); + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java new file mode 100644 index 0000000..e3900e4 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/BetaToAlphaTest.java @@ -0,0 +1,109 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + +public class BetaToAlphaTest { + + private Dag dag; + private Node a, b, c, d; + private ArrayList alphaOrder; + + @BeforeEach + void setUp() { + a = new GraphNode("A"); + b = new GraphNode("B"); + c = new GraphNode("C"); + d = new GraphNode("D"); + + // DAG: A → B, A → C, B → D, C → D + dag = new Dag(); + dag.addNode(a); + dag.addNode(b); + dag.addNode(c); + dag.addNode(d); + dag.addDirectedEdge(a, b); + dag.addDirectedEdge(a, c); + dag.addDirectedEdge(b, d); + dag.addDirectedEdge(c, d); + + // Define an alpha order that requires modifying the graph + alphaOrder = new ArrayList<>(Arrays.asList(d, c, b, a)); + } + + @Test + void testTransformRespectsAlphaOrder() { + BetaToAlpha bta = new BetaToAlpha(dag, alphaOrder); + bta.transform(); + + // El grafo debería haber invertido al menos algunos arcos + assertTrue(bta.getNumberOfInsertedEdges() > 0); 
+ + // Validamos que el orden resultante es compatible con alpha + for (Edge edge : dag.getEdges()) { + Node from = edge.getNode1(); + Node to = edge.getNode2(); + + int fromIndex = alphaOrder.indexOf(from); + int toIndex = alphaOrder.indexOf(to); + + assertTrue(fromIndex < toIndex, "Edge violates alpha order: " + from + " → " + to); + } + } + + @Test + void testRandomAlphaProducesPermutation() { + BetaToAlpha bta = new BetaToAlpha(dag); + List randomAlpha = bta.randomAlpha(new Random(42)); + + assertNotNull(randomAlpha); + assertEquals(dag.getNumNodes(), randomAlpha.size()); + + Set originalNodes = new HashSet<>(dag.getNodes()); + Set shuffled = new HashSet<>(randomAlpha); + + assertEquals(originalNodes, shuffled); // misma colección, diferente orden + } + + @Test + void testComputeAlphaHashBuildsCorrectMap() { + BetaToAlpha bta = new BetaToAlpha(dag, alphaOrder); + bta.computeAlphaHash(); + + for (int i = 0; i < alphaOrder.size(); i++) { + Node node = alphaOrder.get(i); + assertEquals(i, (bta.getAlphaHash()).get(node)); + } + } + + + @Test + public void setterAndGetterTest(){ + ArrayList firstOrder = new ArrayList<>(Arrays.asList(a, b, c)); + BetaToAlpha b2Alpha = new BetaToAlpha(dag, firstOrder); + List newOrder = new ArrayList<>(Arrays.asList(c, b, a)); + b2Alpha.setAlphaOrder(newOrder); + + List order = b2Alpha.getAlphaOrder(); + assertNotNull(order); + assertEquals(3, order.size()); + assertTrue(order.contains(c)); + assertTrue(order.contains(b)); + assertTrue(order.contains(a)); + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java new file mode 100644 index 0000000..f704120 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusBESTest.java @@ -0,0 +1,196 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + +public class ConsensusBESTest { + private ArrayList inputDags; + private ArrayList alpha; + + @BeforeEach + public void setUp() { + inputDags = new ArrayList<>(); + + // We use 4 nodes for the DAGs + Node nodeA = new GraphNode("A"); + Node nodeB = new GraphNode("B"); + Node nodeC = new GraphNode("C"); + Node nodeD = new GraphNode("D"); + + // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D + Dag dag1 = new Dag(); + dag1.addNode(nodeA); + dag1.addNode(nodeB); + dag1.addNode(nodeC); + dag1.addNode(nodeD); + + // Adding directed edges to the DAG + dag1.addDirectedEdge(nodeA, nodeB); + dag1.addDirectedEdge(nodeA, nodeC); + dag1.addDirectedEdge(nodeB, nodeD); + dag1.addDirectedEdge(nodeC, nodeD); + + // Adding the DAG to the list + inputDags.add(dag1); + + // Create second DAG with these edges: D -> C, D -> B, C -> A, B -> A + Dag dag2 = new Dag(); + dag2.addNode(nodeA); + dag2.addNode(nodeB); + dag2.addNode(nodeC); + dag2.addNode(nodeD); + + // Adding directed edges to the second DAG + dag2.addDirectedEdge(nodeD, nodeC); + dag2.addDirectedEdge(nodeD, nodeB); + dag2.addDirectedEdge(nodeC, nodeA); + dag2.addDirectedEdge(nodeB, nodeA); + + // Adding the second DAG to the list + inputDags.add(dag2); + + // Apply AlphaOrder algorithm to these dags: + AlphaOrder alphaOrder = new AlphaOrder(inputDags); + alphaOrder.computeAlpha(); + alpha = alphaOrder.getOrder(); + + } + + @Test + public void testConsensusUnionConsistency() { + ConsensusUnion cu = new 
ConsensusUnion(inputDags, alpha); + Dag expected = cu.union(); + assertNotNull(cu); + assertNotNull(expected); + + ConsensusBES consensusBES = new ConsensusBES(inputDags); + consensusBES.consensusUnion(); + + // Check if the union DAG is not null and has nodes + Dag unionDag = consensusBES.getUnion(); + assertNotNull(unionDag); + assertNotNull(unionDag.getNodes()); + + // Check that the union DAG is equal to the expected DAG + assertNotNull(unionDag.getNodes()); + assertNotNull(expected.getNodes()); + assertEquals(expected, unionDag); + assertEquals(expected.getNodes().size(), unionDag.getNodes().size()); + assertEquals(expected.getEdges().size(), unionDag.getEdges().size()); + + for (Node node : expected.getNodes()) { + assert unionDag.getNodes().contains(node); + } + for(Edge edge : expected.getEdges()) { + assert unionDag.getEdges().contains(edge); + } + } + + @Test + public void testRandomBNFusion(){ + // Creating random DAGs + ArrayList randomDagsList = new ArrayList<>(); + randomDagsList.addAll(GraphTestHelper.generateRandomDagList(20, 2, 50, 19, 19, 19, true,0)); + + // Creating ConsensusUnion instance + ConsensusUnion consensusUnionOnly = new ConsensusUnion(randomDagsList); + Dag unionDagOnly = consensusUnionOnly.union(); + + assertNotNull(unionDagOnly); + assertTrue(unionDagOnly.getNumEdges() >= 0); + assertTrue(unionDagOnly.getNodes().size() == randomDagsList.get(0).getNodes().size()); + + + ConsensusBES conDag = new ConsensusBES(randomDagsList); + conDag.fusion(); + Dag besDag = conDag.getFusionDag(); + Dag unionDag = conDag.getUnion(); + ConsensusUnion consensusUnion = conDag.getConsensusUnion(); + int totalNumberOfInsertedEdges = conDag.getNumberOfInsertedEdges(); + int consensusNumberOfInsertedEdges = consensusUnion.getNumberOfInsertedEdges(); + + assertNotNull(besDag); + assertNotNull(unionDag); + assertNotNull(consensusUnion); + assertEquals(unionDagOnly, unionDag); + assertEquals(besDag.getNodes().size(), unionDag.getNodes().size()); + assert 
consensusNumberOfInsertedEdges >= 0; + assert consensusNumberOfInsertedEdges >= totalNumberOfInsertedEdges; + } + + + @Test + void testFusionProducesDag() { + ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags); + fusionAlgorithm.fusion(); + + Dag result = fusionAlgorithm.getFusionDag(); + assertNotNull(result, "El DAG de salida no debe ser null."); + assertFalse(result.paths().existsDirectedCycle(), "El DAG resultante no debe tener ciclos."); + } + + @Test + void testEdgeInsertionCountIsCorrectlyComputed() { + ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags); + fusionAlgorithm.fusion(); + + int insertedEdges = fusionAlgorithm.getNumberOfInsertedEdges(); + assertTrue(insertedEdges >= 0, "El número de aristas insertadas debe ser >= 0."); + } + + @Test + void testFusionOrderIsValid() { + ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags); + fusionAlgorithm.fusion(); + + List order = fusionAlgorithm.getOrderFusion(); + assertNotNull(order, "El orden de fusión no debe ser null."); + assertEquals(4, order.size(), "El orden de fusión debe tener 4 nodos."); + } + + @Test + void testTransformedDagsAreAccessibleAfterFusion() { + ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags); + fusionAlgorithm.fusion(); + + ArrayList transformed = fusionAlgorithm.getTransformedDags(); + assertEquals(2, transformed.size(), "Debe haber 2 DAGs transformados."); + } + + @Test + void testGetTransformedDagsWithoutFusionThrowsException() { + ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags); + + assertThrows(IllegalStateException.class, fusionAlgorithm::getTransformedDags, + "Debe lanzar una excepción si se accede a los DAGs transformados sin llamar a fusion()."); + } + + @Test + void testThreadExecutionWithRunMethod() { + ConsensusBES fusionAlgorithm = new ConsensusBES(inputDags); + Thread thread = new Thread(fusionAlgorithm); + thread.start(); + try { + thread.join(); + } catch (InterruptedException e) { + fail("El hilo fue interrumpido."); + } + + 
assertNotNull(fusionAlgorithm.getFusionDag(), "El DAG resultante debe existir tras ejecutar run()."); + } + +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java new file mode 100644 index 0000000..09de508 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ConsensusUnionTest.java @@ -0,0 +1,151 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.RandomGraph; +import edu.cmu.tetrad.util.RandomUtil; + +public class ConsensusUnionTest { + + private ArrayList inputDags; + private ArrayList alpha; + + @BeforeEach + public void setUp() { + inputDags = new ArrayList<>(); + + // We use 4 nodes for the DAGs + Node nodeA = new GraphNode("A"); + Node nodeB = new GraphNode("B"); + Node nodeC = new GraphNode("C"); + Node nodeD = new GraphNode("D"); + + // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D + Dag dag1 = new Dag(); + dag1.addNode(nodeA); + dag1.addNode(nodeB); + dag1.addNode(nodeC); + dag1.addNode(nodeD); + + // Adding directed edges to the DAG + dag1.addDirectedEdge(nodeA, nodeB); + dag1.addDirectedEdge(nodeA, nodeC); + dag1.addDirectedEdge(nodeB, nodeD); + dag1.addDirectedEdge(nodeC, nodeD); + + // Adding the DAG to the list + inputDags.add(dag1); + + // Create second DAG with these edges: D -> C, D -> B, C -> A, B -> A + Dag dag2 = new Dag(); + dag2.addNode(nodeA); + dag2.addNode(nodeB); + dag2.addNode(nodeC); + dag2.addNode(nodeD); + + // Adding directed edges to the second DAG + 
dag2.addDirectedEdge(nodeD, nodeC); + dag2.addDirectedEdge(nodeD, nodeB); + dag2.addDirectedEdge(nodeC, nodeA); + dag2.addDirectedEdge(nodeB, nodeA); + + // Adding the second DAG to the list + inputDags.add(dag2); + + // Apply AlphaOrder algorithm to these dags: + AlphaOrder alphaOrder = new AlphaOrder(inputDags); + alphaOrder.computeAlpha(); + alpha = alphaOrder.getOrder(); + + } + + @Test + public void testConstructorWithAlphaInitializesCorrectly() { + ConsensusUnion cu = new ConsensusUnion(inputDags, alpha); + assertNotNull(cu); + } + + @Test + public void testConstructorWithoutAlphaGeneratesAlpha() { + ConsensusUnion cu = new ConsensusUnion(inputDags); + assertNotNull(cu); + } + + @Test + public void testUnionReturnsNonNullDag() { + ConsensusUnion cu = new ConsensusUnion(inputDags, alpha); + Dag result = cu.union(); + assertNotNull(result); + } + + @Test + public void testNumberOfInsertedEdgesIsUpdated() { + ConsensusUnion cu = new ConsensusUnion(inputDags, alpha); + cu.union(); + assertTrue(cu.getNumberOfInsertedEdges() >= 0); + } + + @Test + public void testSetDagsUpdatesAlphaAndUnion() { + ConsensusUnion cu = new ConsensusUnion(); + cu.setDags(inputDags); + cu.union(); + Dag result = cu.getUnion(); + assertNotNull(result); + assertTrue(result.getNumEdges() >= 1); + } + + @Test + public void testRunMethodExecutesUnion() { + ConsensusUnion cu = new ConsensusUnion(inputDags, alpha); + cu.run(); + assertNotNull(cu.getUnion()); + } + + @Test + public void testEmptyDagListReturnsEmptyUnion() { + assertThrows(IllegalArgumentException.class, () -> new ConsensusUnion(new ArrayList<>())); + } + + @Test + public void testRandomBNGeneratesConsensusUnionCorrectly() { + ArrayList randomDagsList = new ArrayList<>(); + int sizeRandomDags = 2; + int numVariables = 20; + + // Creating list of shared nodes + ArrayList sharedNodes = new ArrayList<>(); + for (int i = 0; i < numVariables; i++) { + Node node = new GraphNode("Node" + i); + sharedNodes.add(node); + } + // Setting 
seed + RandomUtil.getInstance().setSeed(42); + + // Generating random DAGs + for (int i = 0; i < sizeRandomDags; i++) { + Dag randomDag = RandomGraph.randomDag(sharedNodes,0,50,19,19,19,true); + randomDagsList.add(randomDag); + } + + // Applying ConsensusUnion + ConsensusUnion conDag = new ConsensusUnion(randomDagsList); + Graph g = conDag.union(); + + // Validating the resulting consensus DAG + assertNotNull(g); + assertTrue(g.getNumEdges() >= 0); + assertTrue(g.getNodes().size() == randomDagsList.get(0).getNodes().size()); + + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java new file mode 100644 index 0000000..4d5d880 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/DSeparationKeyTest.java @@ -0,0 +1,91 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + +class DSeparationKeyTest { + + private final Node X = new GraphNode("X"); + private final Node Y = new GraphNode("Y"); + private final Node Z = new GraphNode("Z"); + private final Node W = new GraphNode("W"); + + @Test + void testConstructorAndGetters() { + Set zSet = new HashSet<>(Arrays.asList(Z, W)); + DSeparationKey key = new DSeparationKey(X, Y, zSet); + + assertEquals(X, key.getX()); + assertEquals(Y, key.getY()); + assertEquals(new HashSet<>(Arrays.asList(Z, W)), key.getConditioningSet()); + } + + @Test + void testEqualsSameContentDifferentOrder() { + Set zSet1 = new HashSet<>(Arrays.asList(Z, W)); + Set zSet2 = new HashSet<>(Arrays.asList(W, 
Z)); + + DSeparationKey key1 = new DSeparationKey(X, Y, zSet1); + DSeparationKey key2 = new DSeparationKey(X, Y, zSet2); + + assertEquals(key1, key2); + assertEquals(key1.hashCode(), key2.hashCode()); + } + + @Test + void testNotEqualsDifferentZ() { + Set zSet1 = new HashSet<>(Collections.singletonList(Z)); + Set zSet2 = new HashSet<>(Arrays.asList(Z, W)); + + DSeparationKey key1 = new DSeparationKey(X, Y, zSet1); + DSeparationKey key2 = new DSeparationKey(X, Y, zSet2); + + assertNotEquals(key1, key2); + } + + @Test + void testSimmetryBetweenXandY() { + DSeparationKey key1 = new DSeparationKey(X, Y, Collections.emptySet()); + DSeparationKey key2 = new DSeparationKey(Y, X, Collections.emptySet()); + + assertEquals(key1, key2); + } + + @Test + void testEqualsSelf() { + DSeparationKey key = new DSeparationKey(X, Y, Collections.singleton(Z)); + assertEquals(key, key); + } + + @Test + void testNotEqualsNullOrDifferentClass() { + DSeparationKey key = new DSeparationKey(X, Y, Collections.singleton(Z)); + + assertNotEquals(null, key); + assertNotEquals("NotAKey", key); + } + + @Test + void testKeyAsMapKey() { + DSeparationKey key1 = new DSeparationKey(X, Y, new HashSet<>(Arrays.asList(Z, W))); + DSeparationKey key2 = new DSeparationKey(X, Y, new HashSet<>(Arrays.asList(W, Z))); + + Map map = new HashMap<>(); + map.put(key1, true); + + assertTrue(map.containsKey(key2)); + assertTrue(map.get(key2)); + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java new file mode 100644 index 0000000..2f92250 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/DseparationTest.java @@ -0,0 +1,333 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Test; + 
+import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + +public class DseparationTest { + // d-separation tests + + private Node node(String name) { + return new GraphNode(name); + } + + private Dag createDag(Edge... edges) { + Dag dag = new Dag(); + for (Edge edge : edges) { + dag.addNode(edge.getNode1()); + dag.addNode(edge.getNode2()); + dag.addDirectedEdge(edge.getNode1(), edge.getNode2()); + } + return dag; + } + + @Test + public void testDirectConnection() { + Node A = node("A"), B = node("B"); + Dag dag = createDag(Edges.directedEdge(A, B)); + List conditioning = Collections.emptyList(); + assertFalse(Utils.dSeparated(dag, A, B, conditioning)); + } + + @Test + public void testChainNoCondition() { + Node A = node("A"), B = node("B"), C = node("C"); + Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C)); + List conditioning = Collections.emptyList(); + assertFalse(Utils.dSeparated(dag, A, C, conditioning)); + } + + @Test + public void testChainWithCondition() { + Node A = node("A"), B = node("B"), C = node("C"); + Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C)); + + assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B))); + } + + @Test + public void testColliderNoCondition() { + Node A = node("A"), B = node("B"), C = node("C"); + Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B)); + List conditioning = Collections.emptyList(); + assertTrue(Utils.dSeparated(dag, A, C, conditioning)); + } + + @Test + public void testColliderConditionedOnCollider() { + Node A = node("A"), B = node("B"), C = node("C"); + Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, B)); + + assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(B))); + } + + @Test + public void testDivergingNoCondition() { + Node A = node("A"), B = node("B"), C = node("C"); + Dag 
dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C)); + List conditioning = Collections.emptyList(); + assertFalse(Utils.dSeparated(dag, A, C, conditioning)); + } + + @Test + public void testDivergingConditionedOnCommonParent() { + Node A = node("A"), B = node("B"), C = node("C"); + Dag dag = createDag(Edges.directedEdge(B, A), Edges.directedEdge(B, C)); + + assertTrue(Utils.dSeparated(dag, A, C, Collections.singletonList(B))); + } + + @Test + public void testColliderConditionedOnDescendant() { + Node A = node("A"), B = node("B"), C = node("C"), D = node("D"); + Dag dag = createDag( + Edges.directedEdge(A, B), + Edges.directedEdge(C, B), + Edges.directedEdge(B, D) + ); + + assertFalse(Utils.dSeparated(dag, A, C, Collections.singletonList(D))); + } + + @Test + public void testSimmetryBetweenXandY() { + Node A = node("A"), B = node("B"), C = node("C"), D = node("D"); + Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(C, A), Edges.directedEdge(D, A)); + List conditioning = Collections.emptyList(); + + // No colliders between A and B, so they are not d-separated + assertFalse(Utils.dSeparated(dag, A, B, conditioning)); + assertFalse(Utils.dSeparated(dag, B, A, conditioning)); + + // C->A and D->A makes A a collider for C and D, and therefore C and D are d-separated from each other + assertTrue(Utils.dSeparated(dag, C, D, conditioning)); + assertTrue(Utils.dSeparated(dag, D, C, conditioning)); + + } + + @Test + public void testDseparationMethodsAreEquivalent() { + Node A = node("A"), B = node("B"), C = node("C"); + Dag dag = createDag(Edges.directedEdge(A, B), Edges.directedEdge(B, C)); + + // Using the dSeparated method with no conditioning + assertFalse(Utils.dSeparated(dag, A, C)); + + // Using the dSeparated method with an empty conditioning set + assertFalse(Utils.dSeparated(dag, A, C, Collections.emptyList())); + } + + @Test + public void testDseparationRule1(){ + // Scenarios taken from 
https://yuyangyy.medium.com/understand-d-separation-471f9aada503 + // Scenario 1: Z is empty, namely, we don't condition on any variables + // Rule 1: If there exists a path from x to y and there is no collider on the path, then x and y are not d-separated. + Node x = new GraphNode("x"); + Node r = new GraphNode("r"); + Node s = new GraphNode("s"); + Node t = new GraphNode("t"); + Node u = new GraphNode("u"); + Node v = new GraphNode("v"); + Node y = new GraphNode("y"); + + Dag dag = new Dag(); + dag.addNode(x); + dag.addNode(r); + dag.addNode(s); + dag.addNode(t); + dag.addNode(u); + dag.addNode(v); + dag.addNode(y); + + // Edges: x -> r -> s -> t <- u <- v -> y + dag.addDirectedEdge(x, r); + dag.addDirectedEdge(r, s); + dag.addDirectedEdge(s, t); + dag.addDirectedEdge(u, t); + dag.addDirectedEdge(v, u); + dag.addDirectedEdge(v, y); + + // Check d-separations + assertFalse(Utils.dSeparated(dag, x, r)); + assertFalse(Utils.dSeparated(dag, x, s)); + assertFalse(Utils.dSeparated(dag, x, t)); + + assertTrue(Utils.dSeparated(dag, x, u)); + assertTrue(Utils.dSeparated(dag, x, v)); + assertTrue(Utils.dSeparated(dag, x, y)); + + assertFalse(Utils.dSeparated(dag, u, v)); + assertFalse(Utils.dSeparated(dag, u, y)); + } + + @Test + public void testDseparationRule2(){ + // Scenarios taken from https://yuyangyy.medium.com/understand-d-separation-471f9aada503 + // Scenario 2: Z is non-empty, and the colliders don't belong to Z or have no children in Z. + // Rule 2: If there exists a path from x to y and none of the nodes on the path belongs to Z, then x and y are not d-separated. 
+ Node x = new GraphNode("x"); + Node r = new GraphNode("r"); + Node s = new GraphNode("s"); + Node t = new GraphNode("t"); + Node u = new GraphNode("u"); + Node v = new GraphNode("v"); + Node y = new GraphNode("y"); + + Dag dag = new Dag(); + dag.addNode(x); + dag.addNode(r); + dag.addNode(s); + dag.addNode(t); + dag.addNode(u); + dag.addNode(v); + dag.addNode(y); + + // Edges: x -> r -> s -> t <- u <- v -> y + dag.addDirectedEdge(x, r); + dag.addDirectedEdge(r, s); + dag.addDirectedEdge(s, t); + dag.addDirectedEdge(u, t); + dag.addDirectedEdge(v, u); + dag.addDirectedEdge(v, y); + + // Creating a conditioning set Z that does not include any colliders + List Z = new ArrayList<>(); + Z.add(r); + Z.add(v); + + // Check d-separations + // Node x + assertTrue(Utils.dSeparated(dag, x, r, Z)); + assertTrue(Utils.dSeparated(dag, x, s, Z)); + assertTrue(Utils.dSeparated(dag, x, t, Z)); + assertTrue(Utils.dSeparated(dag, x, u, Z)); + assertTrue(Utils.dSeparated(dag, x, v, Z)); + assertTrue(Utils.dSeparated(dag, x, y, Z)); + + // Node r + assertTrue(Utils.dSeparated(dag, r, s, Z)); + assertTrue(Utils.dSeparated(dag, r, t, Z)); + assertTrue(Utils.dSeparated(dag, r, u, Z)); + assertTrue(Utils.dSeparated(dag, r, v, Z)); + assertTrue(Utils.dSeparated(dag, r, y, Z)); + + // Node s + assertFalse(Utils.dSeparated(dag, s, t, Z)); // No node on the path belongs to Z, so d-separated + assertTrue(Utils.dSeparated(dag, s, u, Z)); + assertTrue(Utils.dSeparated(dag, s, v, Z)); + assertTrue(Utils.dSeparated(dag, s, y, Z)); + + // Node t + assertFalse(Utils.dSeparated(dag, t, u, Z)); // No node on the path belongs to Z, so d-separated + assertTrue(Utils.dSeparated(dag, t, v, Z)); + assertTrue(Utils.dSeparated(dag, t, y, Z)); + + // Node u + assertTrue(Utils.dSeparated(dag, u, v, Z)); + assertTrue(Utils.dSeparated(dag, u, y, Z)); + + // Node v + assertTrue(Utils.dSeparated(dag, v, y, Z)); + + + } + + @Test + public void testDseparationRule3(){ + // Scenarios taken from 
https://yuyangyy.medium.com/understand-d-separation-471f9aada503 + // Scenario 3: Z is non-empty, and there are colliders either inside Z or have children in Z. + // Rule 3: For colliders that fall inside Z or have children in Z, they are no longer seen as colliders. + + Node x = new GraphNode("x"); + Node r = new GraphNode("r"); + Node s = new GraphNode("s"); + Node t = new GraphNode("t"); + Node u = new GraphNode("u"); + Node v = new GraphNode("v"); + Node y = new GraphNode("y"); + Node w = new GraphNode("w"); + Node p = new GraphNode("p"); + Node q = new GraphNode("q"); + + Dag dag = new Dag(); + dag.addNode(x); + dag.addNode(r); + dag.addNode(s); + dag.addNode(t); + dag.addNode(u); + dag.addNode(v); + dag.addNode(y); + dag.addNode(w); + dag.addNode(p); + dag.addNode(q); + + // Edges: x -> r -> s -> t <- u <- v -> y + r->w, t->p, v->q + dag.addDirectedEdge(x, r); + dag.addDirectedEdge(r, s); + dag.addDirectedEdge(s, t); + dag.addDirectedEdge(u, t); + dag.addDirectedEdge(v, u); + dag.addDirectedEdge(v, y); + dag.addDirectedEdge(r, w); + dag.addDirectedEdge(t, p); + dag.addDirectedEdge(v, q); + + // Creating a conditioning set Z that includes colliders and their children + List Z = new ArrayList<>(); + Z.add(r); + Z.add(p); + + // Check d-separations + // Node x + assertTrue(Utils.dSeparated(dag, x, s, Z)); + assertTrue(Utils.dSeparated(dag, x, t, Z)); + assertTrue(Utils.dSeparated(dag, x, u, Z)); + assertTrue(Utils.dSeparated(dag, x, v, Z)); + assertTrue(Utils.dSeparated(dag, x, y, Z)); + assertTrue(Utils.dSeparated(dag, x, w, Z)); + assertTrue(Utils.dSeparated(dag, x, q, Z)); + + // Node w + assertTrue(Utils.dSeparated(dag, w, s, Z)); + assertTrue(Utils.dSeparated(dag, w, t, Z)); + assertTrue(Utils.dSeparated(dag, w, u, Z)); + assertTrue(Utils.dSeparated(dag, w, v, Z)); + assertTrue(Utils.dSeparated(dag, w, y, Z)); + assertTrue(Utils.dSeparated(dag, w, q, Z)); + + // Node s + assertFalse(Utils.dSeparated(dag, s, t, Z)); + assertFalse(Utils.dSeparated(dag, s, u, 
Z)); + assertFalse(Utils.dSeparated(dag, s, v, Z)); + assertFalse(Utils.dSeparated(dag, s, q, Z)); + assertFalse(Utils.dSeparated(dag, s, y, Z)); + + // Node t + assertFalse(Utils.dSeparated(dag, t, u, Z)); + assertFalse(Utils.dSeparated(dag, t, v, Z)); + assertFalse(Utils.dSeparated(dag, t, y, Z)); + assertFalse(Utils.dSeparated(dag, t, q, Z)); + + // Node u + assertFalse(Utils.dSeparated(dag, u, v, Z)); + assertFalse(Utils.dSeparated(dag, u, y, Z)); + assertFalse(Utils.dSeparated(dag, u, q, Z)); + + // Node v + assertFalse(Utils.dSeparated(dag, v, y, Z)); + assertFalse(Utils.dSeparated(dag, v, q, Z)); + + // Node y + assertFalse(Utils.dSeparated(dag, y, q, Z)); + + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java new file mode 100644 index 0000000..c386850 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/FindNaYXTest.java @@ -0,0 +1,140 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Endpoint; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + + +public class FindNaYXTest { + + //find naYX tests + @Test + public void testFindNaYX_singleUndirectedCommonNeighbor() { + Graph graph = new EdgeListGraph(); + + Node x = new GraphNode("X"); + Node y = new GraphNode("Y"); + Node z = new GraphNode("Z"); + + graph.addNode(x); + graph.addNode(y); + graph.addNode(z); + + graph.addEdge(new Edge(x, z, Endpoint.TAIL, Endpoint.TAIL)); // undirected + graph.addEdge(new Edge(y, z, Endpoint.TAIL, Endpoint.TAIL)); // undirected + + List result = Utils.findNaYX(x, y, graph); 
+ assertEquals(1, result.size()); + assertTrue(result.contains(z)); + } + + @Test + public void testFindNaYX_directedEdgeShouldBeExcluded() { + Graph graph = new EdgeListGraph(); + + Node x = new GraphNode("X"); + Node y = new GraphNode("Y"); + Node z = new GraphNode("Z"); + + graph.addNode(x); + graph.addNode(y); + graph.addNode(z); + + graph.addEdge(new Edge(x, z, Endpoint.ARROW, Endpoint.TAIL)); // x → z + graph.addEdge(new Edge(y, z, Endpoint.ARROW, Endpoint.TAIL)); // y → z + + List result = Utils.findNaYX(x, y, graph); + assertEquals(0, result.size()); + } + + @Test + public void testFindNaYX_mixedNeighbors() { + Graph graph = new EdgeListGraph(); + + Node x = new GraphNode("X"); + Node y = new GraphNode("Y"); + Node z1 = new GraphNode("Z1"); // undirected common + Node z2 = new GraphNode("Z2"); // directed common + Node z3 = new GraphNode("Z3"); // only adjacent to x + + graph.addNode(x); + graph.addNode(y); + graph.addNode(z1); + graph.addNode(z2); + graph.addNode(z3); + + // z1: undirected edge with both + graph.addEdge(new Edge(x, z1, Endpoint.TAIL, Endpoint.TAIL)); + graph.addEdge(new Edge(y, z1, Endpoint.TAIL, Endpoint.TAIL)); + + // z2: directed edges with both + graph.addEdge(new Edge(x, z2, Endpoint.TAIL, Endpoint.ARROW)); + graph.addEdge(new Edge(y, z2, Endpoint.TAIL, Endpoint.ARROW)); + + // z3: only adjacent to x + graph.addEdge(new Edge(x, z3, Endpoint.TAIL, Endpoint.TAIL)); + + List result = Utils.findNaYX(x, y, graph); + assertEquals(1, result.size()); + assertTrue(result.contains(z1)); + assertFalse(result.contains(z2)); + assertFalse(result.contains(z3)); + } + + @Test + public void testFindNaYX_multipleUndirectedCommonNeighbors() { + Graph graph = new EdgeListGraph(); + + Node x = new GraphNode("X"); + Node y = new GraphNode("Y"); + Node a = new GraphNode("A"); + Node b = new GraphNode("B"); + + graph.addNode(x); + graph.addNode(y); + graph.addNode(a); + graph.addNode(b); + + graph.addEdge(new Edge(x, a, Endpoint.TAIL, Endpoint.TAIL)); + 
graph.addEdge(new Edge(y, a, Endpoint.TAIL, Endpoint.TAIL)); + graph.addEdge(new Edge(x, b, Endpoint.TAIL, Endpoint.TAIL)); + graph.addEdge(new Edge(y, b, Endpoint.TAIL, Endpoint.TAIL)); + + List result = Utils.findNaYX(x, y, graph); + assertEquals(2, result.size()); + assertTrue(result.contains(a)); + assertTrue(result.contains(b)); + } + + @Test + public void testFindNaYX_noCommonNeighbors() { + Graph graph = new EdgeListGraph(); + + Node x = new GraphNode("X"); + Node y = new GraphNode("Y"); + Node a = new GraphNode("A"); + Node b = new GraphNode("B"); + + graph.addNode(x); + graph.addNode(y); + graph.addNode(a); + graph.addNode(b); + + graph.addEdge(new Edge(x, a, Endpoint.TAIL, Endpoint.TAIL)); + graph.addEdge(new Edge(y, b, Endpoint.TAIL, Endpoint.TAIL)); + + List result = Utils.findNaYX(x, y, graph); + assertTrue(result.isEmpty()); + } + +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/GraphTestHelper.java b/src/test/java/es/uclm/i3a/simd/consensusBN/GraphTestHelper.java new file mode 100644 index 0000000..01061e8 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/GraphTestHelper.java @@ -0,0 +1,60 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; +import java.util.List; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.RandomGraph; +import edu.cmu.tetrad.util.RandomUtil; + +public class GraphTestHelper { + + + private GraphTestHelper() { + // Private constructor to prevent instantiation + } + + /** + * Generates a list of random DAGs sharing the same set of nodes. + * + * @param numVariables Number of variables (nodes) in each DAG. + * @param numDags Number of random DAGs to generate. + * @param maxEdges Maximum number of edges in each DAG. + * @param maxInDegree Maximum in-degree for each node. + * @param maxOutDegree Maximum out-degree for each node. + * @param maxDegree Maximum degree for each node. 
+ * @param connected Whether the generated DAGs should be connected. + * @param seed Seed for random number generation. + * @return List of randomly generated DAGs + */ + public static List generateRandomDagList(int numVariables, int numDags, int maxEdges, int maxInDegree, int maxOutDegree, int maxDegree, boolean connected, long seed) { + List randomDagsList = new ArrayList<>(); + + // Create shared nodes + List sharedNodes = new ArrayList<>(); + for (int i = 0; i < numVariables; i++) { + sharedNodes.add(new GraphNode("Node" + i)); + } + + // Set seed + RandomUtil.getInstance().setSeed(seed); + + // Generate DAGs + for (int i = 0; i < numDags; i++) { + Dag randomDag = RandomGraph.randomDag( + sharedNodes, + 0, // Latent variables (0 by default) + maxEdges, + maxDegree, + maxInDegree, + maxOutDegree, + connected + ); + randomDagsList.add(randomDag); + } + + return randomDagsList; + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVotingTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVotingTest.java new file mode 100644 index 0000000..253c19e --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/HeuristicConsensusMVotingTest.java @@ -0,0 +1,111 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.Node; + +public class HeuristicConsensusMVotingTest { + + private Dag createSimpleDag(String from, String to) { + Node n1 = new GraphNode(from); + Node n2 = new GraphNode(to); + Dag dag = new Dag(Arrays.asList(n1, n2)); + dag.addDirectedEdge(n1, n2); + return dag; + } + + @Test + public void 
testFusionCreatesDAGWithExpectedEdges() { + Dag dag1 = createSimpleDag("A", "B"); + Dag dag2 = createSimpleDag("A", "B"); + ArrayList dags = new ArrayList<>(Arrays.asList(dag1, dag2)); + + HeuristicConsensusMVoting mvoting = new HeuristicConsensusMVoting(dags, 0.5); + Dag consensus = mvoting.fusion(); + + assertNotNull(consensus); + assertTrue(GraphUtils.isDag(consensus)); + assertEquals(2, consensus.getNumNodes()); + assertEquals(1, consensus.getNumEdges()); + assertTrue(consensus.isParentOf(getNodeByName(consensus, "A"), getNodeByName(consensus, "B"))); + + // Test getters + assertEquals(0.5, mvoting.getPercentage()); + assertEquals(2, mvoting.getVariables().size()); + assertEquals(2, mvoting.getWeight().length); + assertEquals(2, mvoting.getWeight()[0].length); + assertEquals(2, mvoting.getWeight()[1].length); + assertEquals(0.0, mvoting.getWeight()[0][1]); + assertEquals(0.0, mvoting.getWeight()[1][0]); + assertEquals(dags, mvoting.getSetOfdags()); + assertEquals(consensus, mvoting.getOutputDag()); + + } + + @Test + public void testFusionDoesNotAddLowWeightEdges() { + Dag dag1 = createSimpleDag("A", "B"); + Dag dag2 = createSimpleDag("B", "A"); // Conflicting direction + + ArrayList dags = new ArrayList<>(Arrays.asList(dag1, dag2)); + HeuristicConsensusMVoting mvoting = new HeuristicConsensusMVoting(dags, 0.75); + Dag consensus = mvoting.fusion(); + + // Expect no edge because weight is 0.5 < 0.75 + assertEquals(0, consensus.getNumEdges()); + } + + @Test + public void testFusionDoesNotCreateCycle() { + Node a = new GraphNode("A"); + Node b = new GraphNode("B"); + Node c = new GraphNode("C"); + + Dag dag1 = new Dag(Arrays.asList(a, b, c)); + dag1.addDirectedEdge(a, b); + dag1.addDirectedEdge(b, c); + + Dag dag2 = new Dag(Arrays.asList(a, b, c)); + dag2.addDirectedEdge(a, b); + dag2.addDirectedEdge(b, c); + + ArrayList dags = new ArrayList<>(Arrays.asList(dag1, dag2)); + HeuristicConsensusMVoting mvoting = new HeuristicConsensusMVoting(dags, 0.5); + Dag consensus 
= mvoting.fusion(); + + assertTrue(GraphUtils.isDag(consensus), "La fusión no debe crear ciclos."); + } + + @Test + public void testWeightMatrixCorrectlyComputed() { + Dag dag1 = createSimpleDag("A", "B"); + Dag dag2 = createSimpleDag("A", "B"); + ArrayList dags = new ArrayList<>(Arrays.asList(dag1, dag2)); + + HeuristicConsensusMVoting mvoting = new HeuristicConsensusMVoting(dags, 0.5); + + int indexA = mvoting.getVariables().indexOf(new GraphNode("A")); + int indexB = mvoting.getVariables().indexOf(new GraphNode("B")); + + double weightAB = mvoting.getWeight()[indexA][indexB]; + double expectedWeight = 1.0; // Dos DAGS, misma dirección + + assertEquals(expectedWeight, weightAB, 1e-6); + } + + private Node getNodeByName(Dag dag, String name) { + return dag.getNodes().stream() + .filter(n -> n.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Node not found: " + name)); + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNsTest.java new file mode 100644 index 0000000..462dbb0 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/HierarchicalAgglomerativeClustererBNsTest.java @@ -0,0 +1,112 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; + +public class HierarchicalAgglomerativeClustererBNsTest { + + private ArrayList inputDags; + + @BeforeEach + public void setUp() { + int numVariables = 4; // Number of variables in the DAGs + int numDags = 10; // Number of DAGs to generate + int maxEdges = 6; // Maximum number of edges in each DAG + int maxInDegree = 2; // 
Maximum in-degree for each node + int maxOutDegree = 2; // Maximum out-degree for each node + int maxDegree = 3; // Maximum degree for each node + boolean connected = true; // Whether the DAGs should be connected + long seed = 42; // Seed for random number generation + + // Generate a list of random DAGs using GraphTestHelper + inputDags = new ArrayList<>(); + inputDags.addAll(GraphTestHelper.generateRandomDagList(numVariables, numDags, maxEdges, maxInDegree, maxOutDegree, maxDegree, connected, seed)); + } + + + @Test + public void testConstructorAndGetSetOfBNs() { + + HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 2); + + assertEquals(inputDags.size(), clusterer.getSetOfBNs().size()); + assertEquals(inputDags, clusterer.getSetOfBNs()); + } + + @Test + public void testClusterStopsEarlyDueToMaxSize() { + + HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 1); + int numDagsAfterCluster = clusterer.cluster(); + + // Only one fusion should occur since maxSize is 1 + assertEquals((int)inputDags.size()/2, numDagsAfterCluster); + } + + @Test + public void testGetClustersOutputAtLevelZero() { + + HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 2); + clusterer.cluster(); + + ArrayList output = clusterer.getClustersOutput(0); + assertEquals(inputDags.size(), output.size(), "En el nivel 0 debe haber tantos DAGs como en la entrada"); + } + + @Test + public void testGetInsertedEdges() { + HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 2); + clusterer.cluster(); + + int insertedEdges = clusterer.getInsertedEdges(1); + assertTrue(insertedEdges >= 0, "Debe haber al menos 0 enlaces insertadas en el nivel 1"); + } + + @Test + public void testComputeConsensusDag() { + HierarchicalAgglomerativeClustererBNs clusterer = new 
HierarchicalAgglomerativeClustererBNs(inputDags, 2); + int level = clusterer.cluster(); + + Dag consensus = clusterer.computeConsensusDag(level); + assertNotNull(consensus, "El DAG de consenso no debería ser null"); + } + + + + /* + @Test + public void testFullClusteringUntilOneCluster() { + // Sin restricciones de tamaño ni complejidad + HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(inputDags, 0.0); + int result = clusterer.cluster(); + + // Debe haber n-1 fusiones si todo fue bien, por lo que al final solo debe quedar un cluster + assertEquals(1, result, "Deberían haberse hecho n-1 fusiones"); + + ArrayList resultClusters = clusterer.getClustersOutput(result); + assertEquals(1, resultClusters.size(), "Al final solo debe quedar un cluster"); + } + + @Test + public void testClusteringStopsDueToComplexity() { + ArrayList dags = new ArrayList<>(); + dags.addAll(GraphTestHelper.generateRandomDagList(10, 2, 30, 10, 10, 10, true, 42)); + + // Muy bajo el umbral de complejidad para forzar que no se fusionen + HierarchicalAgglomerativeClustererBNs clusterer = new HierarchicalAgglomerativeClustererBNs(dags, 0.01); + int result = clusterer.cluster(); + + assertTrue(result <= 1, "El clustering debería detenerse porque los DAGs fusionados son muy complejos"); + } + */ + +} + + diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/ListFabricTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/ListFabricTest.java new file mode 100644 index 0000000..5a7744f --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/ListFabricTest.java @@ -0,0 +1,70 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ListFabricTest { + + @BeforeEach + 
public void setUp() { + // Aseguramos el valor por defecto del maxSize antes de cada test + ListFabric.MAX_SIZE = 2; + } + + @Test + public void testGetList_size3_maxSize2() { + List result = Arrays.stream(ListFabric.generateList(3)).boxed().collect(Collectors.toList()); + List expected = Arrays.asList( + 0b000, // [] + 0b001, // [C] + 0b010, // [B] + 0b100, // [A] + 0b011, // [B,C] + 0b101, // [A,C] + 0b110 // [A,B] + ); + assertEquals(expected.size(), result.size()); + assertTrue(result.containsAll(expected)); + } + + @Test + public void testGetList_size0() { + List result = Arrays.stream(ListFabric.generateList(0)).boxed().collect(Collectors.toList()); + assertEquals(1, result.size()); + assertEquals(0, (int) result.get(0)); // Solo el conjunto vacío + } + + @Test + public void testGetList_maxSize0() { + ListFabric.MAX_SIZE = 0; + List result = Arrays.stream(ListFabric.generateList(3)).boxed().collect(Collectors.toList()); + assertEquals(1, result.size()); + assertEquals(0, (int) result.get(0)); // Solo el conjunto vacío + } + + @Test + public void testNoSubsetExceedsMaxSize() { + int size = 4; + List result = Arrays.stream(ListFabric.generateList(size)).boxed().collect(Collectors.toList()); + for (int subset : result) { + int ones = Integer.bitCount(subset); + assertTrue(ones <= ListFabric.MAX_SIZE, + "Subset " + Integer.toBinaryString(subset) + " has " + ones + " bits set"); + } + } + + @Test + public void testSymmetry_sizeEqualsMaxSize() { + int size = 3; + ListFabric.MAX_SIZE = 3; + List result = Arrays.stream(ListFabric.generateList(size)).boxed().collect(Collectors.toList()); + // All subsets should be included (2^3 = 8) + assertEquals(8, result.size()); + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java new file mode 100644 index 0000000..25a14b7 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/PairWiseConsensusBESTest.java @@ 
-0,0 +1,181 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.Arrays; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + +public class PairWiseConsensusBESTest { + + private Node A, B, C; + private Dag dag1, dag2; + + @BeforeEach + public void setup() { + A = new GraphNode("A"); + B = new GraphNode("B"); + C = new GraphNode("C"); + + dag1 = new Dag(); + dag1.addNode(A); + dag1.addNode(B); + dag1.addNode(C); + dag1.addDirectedEdge(A, B); + dag1.addDirectedEdge(B, C); + + dag2 = new Dag(); + dag2.addNode(A); + dag2.addNode(B); + dag2.addNode(C); + dag2.addDirectedEdge(A, C); + dag2.addDirectedEdge(C, B); + } + + @Test + public void testFusionCreatesNonNullDag() { + PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2); + pwc.fusion(); + + Dag fusion = pwc.getDagFusion(); + assertNotNull(fusion, "The fusion DAG should not be null"); + } + + @Test + public void testGetNumberOfInsertedEdges() { + PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2); + pwc.fusion(); + + int inserted = pwc.getNumberOfInsertedEdges(); + assertTrue(inserted >= 0, "Inserted edges should be >= 0"); + } + + @Test + public void testGetNumberOfUnionEdges() { + PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2); + pwc.fusion(); + + PairWiseConsensusBES samePwc = new PairWiseConsensusBES(dag1, dag1); + samePwc.fusion(); + + int unionEdges = pwc.getNumberOfUnionEdges(); + assertTrue(unionEdges > 0, "Union should contain some edges"); + + int 
sameUnionEdges = samePwc.getNumberOfUnionEdges(); + assertTrue(sameUnionEdges == dag1.getNumEdges(), "Union edges should match the number of edges in a single DAG"); + } + + @Test + public void testGetHammingDistance() { + PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2); + int distance = pwc.calculateHammingDistance(); + PairWiseConsensusBES samePwc = new PairWiseConsensusBES(dag1, dag1); + int sameDistance = samePwc.calculateHammingDistance(); + + assertTrue(distance >= 0, "Hamming distance should be >= 0"); + assertTrue(sameDistance == 0, "Hamming distance for identical DAGs should be 0"); + } + + @Test + public void testRunCallsGetFusion() { + PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2); + pwc.run(); + + assertNotNull(pwc.getDagFusion(), "Fusion DAG should be created after run()"); + assertTrue(pwc.getNumberOfUnionEdges() > 0, "Union edges should be computed"); + } + + @Test + public void testFusionIsConsistentWithInput() { + PairWiseConsensusBES pwc = new PairWiseConsensusBES(dag1, dag2); + pwc.fusion(); + Dag fusion = pwc.getDagFusion(); + + Set edges = fusion.getEdges(); + assertNotNull(edges); + assertFalse(edges.isEmpty(), "Fusion DAG should contain edges"); + } + + // Exception throws test cases + + @Test + public void testNullDags() { + IllegalArgumentException ex1 = assertThrows(IllegalArgumentException.class, + () -> new PairWiseConsensusBES(null, createValidDag())); + assertEquals("Input DAGs cannot be null.", ex1.getMessage()); + + IllegalArgumentException ex2 = assertThrows(IllegalArgumentException.class, + () -> new PairWiseConsensusBES(createValidDag(), null)); + assertEquals("Input DAGs cannot be null.", ex2.getMessage()); + } + + @Test + public void testEmptyNodes() { + Dag dag1 = new Dag(); // no nodes + Dag dag2 = createValidDag(); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> new PairWiseConsensusBES(dag1, dag2)); + assertEquals("Input DAGs must contain at least one 
node.", ex.getMessage()); + } + + @Test + public void testEmptyEdges() { + Node n1 = new GraphNode("X"); + Node n2 = new GraphNode("Y"); + Dag dag1 = new Dag(Arrays.asList(n1, n2)); // 2 nodes, no edges + Dag dag2 = createValidDag(); // tiene al menos un edge + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> new PairWiseConsensusBES(dag1, dag2)); + assertEquals("Input DAGs must contain at least one edge.", ex.getMessage()); + } + + @Test + public void testDifferentNodeCounts() { + Dag dag1 = createValidDag(); + Dag dag2 = createValidDag(); + dag2.addNode(new GraphNode("Extra")); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> new PairWiseConsensusBES(dag1, dag2)); + assertEquals("Input DAGs must have the same number of nodes.", ex.getMessage()); + } + + @Test + public void testDifferentNodeSets() { + Node n1 = new GraphNode("A"); + Node n2 = new GraphNode("B"); + Node n3 = new GraphNode("C"); + + Dag dag1 = new Dag(Arrays.asList(n1, n2)); + dag1.addDirectedEdge(n1, n2); + + Dag dag2 = new Dag(Arrays.asList(n1, n3)); // C en vez de B + dag2.addDirectedEdge(n1, n3); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> new PairWiseConsensusBES(dag1, dag2)); + assertEquals("Input DAGs must have the same set of nodes.", ex.getMessage()); + } + + // Helper para crear un DAG válido + private Dag createValidDag() { + Node a = new GraphNode("A"); + Node b = new GraphNode("B"); + + Dag dag = new Dag(Arrays.asList(a, b)); + dag.addDirectedEdge(a, b); + return dag; + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/PowerSetTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/PowerSetTest.java new file mode 100644 index 0000000..437d97c --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/PowerSetTest.java @@ -0,0 +1,89 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; 
+import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + +public class PowerSetTest { + + private List nodeList; + + @BeforeEach + public void setUp() { + nodeList = new ArrayList<>(); + nodeList.add(new GraphNode("X")); + nodeList.add(new GraphNode("Y")); + nodeList.add(new GraphNode("Z")); + } + + @Test + public void testPowerSetWithMaxSize() { + PowerSet ps = new PowerSet(nodeList, 2); + List> result = new ArrayList<>(); + while (ps.hasMoreElements()) { + result.add(ps.nextElement()); + } + + // Verifica que no hay subconjuntos de tamaño > 2 + for (Set subset : result) { + assertTrue(subset.size() <= 2, "Subset size should be <= 2"); + } + + // Comprobamos algunos subconjuntos esperados + HashSet expected1 = new HashSet<>(); + expected1.add(nodeList.get(0)); + HashSet expected2 = new HashSet<>(); + expected2.add(nodeList.get(0)); + expected2.add(nodeList.get(1)); + assertTrue(result.contains(new HashSet<>(expected1))); + assertTrue(result.contains(new HashSet<>(expected2))); + } + + @Test + public void testPowerSetWithoutMaxSize() { + PowerSet ps = new PowerSet(nodeList); + List> result = new ArrayList<>(); + while (ps.hasMoreElements()) { + result.add(ps.nextElement()); + } + + // Número de subconjuntos debería ser 2^n - 1 (el conjunto completo se excluye) + assertEquals(7, result.size()); + + // El conjunto vacío debería estar incluido + assertTrue(result.contains(new HashSet())); + + // El conjunto completo no está incluido en el resultado + assertTrue(!result.contains(new HashSet<>(nodeList))); + } + + @Test + public void testResetIndex() { + PowerSet ps = new PowerSet(nodeList); + assertTrue(ps.hasMoreElements()); + ps.nextElement(); + ps.resetIndex(); + assertTrue(ps.hasMoreElements()); + } 
+ + @Test + public void testMaxPowerSetSize() { + PowerSet powerSet = new PowerSet(nodeList); // Esto actualiza el valor de maxPow + assertEquals(8L, powerSet.maxPowerSetSize()); + } + + @Test + public void testMaxSizeIsNegativeShouldThrowException() { + assertThrows(IllegalArgumentException.class, () -> new PowerSet(nodeList, -1)); + } +} diff --git a/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java b/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java new file mode 100644 index 0000000..86cb1d6 --- /dev/null +++ b/src/test/java/es/uclm/i3a/simd/consensusBN/TransformDagsTest.java @@ -0,0 +1,137 @@ +package es.uclm.i3a.simd.consensusBN; + +import java.util.ArrayList; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import edu.cmu.tetrad.graph.Dag; +import edu.cmu.tetrad.graph.GraphNode; +import edu.cmu.tetrad.graph.Node; + +public class TransformDagsTest { + + private ArrayList inputDags; + private ArrayList alpha; + + @BeforeEach + public void setUp() { + inputDags = new ArrayList<>(); + + // We use 4 nodes for the DAGs + Node nodeA = new GraphNode("A"); + Node nodeB = new GraphNode("B"); + Node nodeC = new GraphNode("C"); + Node nodeD = new GraphNode("D"); + + // Create first DAG with these edges: A -> B, A -> C, B -> D, C -> D + Dag dag1 = new Dag(); + dag1.addNode(nodeA); + dag1.addNode(nodeB); + dag1.addNode(nodeC); + dag1.addNode(nodeD); + + // Adding directed edges to the DAG + dag1.addDirectedEdge(nodeA, nodeB); + dag1.addDirectedEdge(nodeA, nodeC); + dag1.addDirectedEdge(nodeB, nodeD); + dag1.addDirectedEdge(nodeC, nodeD); + + // Adding the DAG to the list + inputDags.add(dag1); + + // Create second DAG with these edges: D -> C, D -> B, C -> A, B 
-> A + Dag dag2 = new Dag(); + dag2.addNode(nodeA); + dag2.addNode(nodeB); + dag2.addNode(nodeC); + dag2.addNode(nodeD); + + // Adding directed edges to the second DAG + dag2.addDirectedEdge(nodeD, nodeC); + dag2.addDirectedEdge(nodeD, nodeB); + dag2.addDirectedEdge(nodeC, nodeA); + dag2.addDirectedEdge(nodeB, nodeA); + + // Adding the second DAG to the list + inputDags.add(dag2); + + // Apply AlphaOrder algorithm to these dags: + AlphaOrder alphaOrder = new AlphaOrder(inputDags); + alphaOrder.computeAlpha(); + alpha = alphaOrder.getOrder(); + + } + + @Test + public void testConstructorGettersAndSetters() { + TransformDags transformer = new TransformDags(inputDags, alpha); + ArrayList empytBetas = new ArrayList<>(); + + assertNotNull(transformer); + assertEquals(0, transformer.getNumberOfInsertedEdges()); + + // GetSetOfOutputDags + ArrayList outputDags = transformer.getSetOfOutputDags(); + assertNotNull(outputDags); + assertTrue(outputDags.isEmpty()); + transformer.transform(); + outputDags = transformer.getSetOfOutputDags(); + assertNotNull(outputDags); + assertEquals(inputDags.size(), outputDags.size()); + assertNotEquals(inputDags, outputDags); + + // Testing setTransformers and getTransformers + ArrayList betas = transformer.getTransformers(); + assertNotNull(betas); + assertTrue(!betas.isEmpty()); + transformer.setTransformers(empytBetas); + assertEquals(empytBetas, transformer.getTransformers()); + assertTrue(transformer.getTransformers().isEmpty()); + + // GetAlphaOrder + ArrayList alphaOrder = transformer.getAlpha(); + assertNotNull(alphaOrder); + assertEquals(alpha, alphaOrder); + + // GetSetOfDags + ArrayList dags = transformer.getSetOfDags(); + assertNotNull(dags); + assertEquals(inputDags, dags); + + + } + + @Test + public void testTransformReturnsCorrectSize() { + TransformDags transformer = new TransformDags(inputDags, alpha); + ArrayList result = transformer.transform(); + + assertNotNull(result); + assertEquals(inputDags.size(), result.size()); + } 
+ + @Test + public void testTransformUpdatesNumberOfInsertedEdges() { + TransformDags transformer = new TransformDags(inputDags, alpha); + transformer.transform(); + + // No sabemos cuántas aristas se insertan exactamente sin saber cómo funciona BetaToAlpha, + // pero al menos podemos comprobar que el valor no es negativo. + assertTrue(transformer.getNumberOfInsertedEdges() >= 0); + } + + @Test + public void testEmptyDagListReturnsEmptyOutput() { + TransformDags transformer = new TransformDags(new ArrayList<>(), alpha); + ArrayList result = transformer.transform(); + + assertTrue(result.isEmpty()); + assertEquals(0, transformer.getNumberOfInsertedEdges()); + } + +}