From fec0eb7afb9a87e32e974abe1c13adc759190242 Mon Sep 17 00:00:00 2001 From: Sean Phillips Date: Tue, 28 Oct 2025 17:04:08 -0400 Subject: [PATCH 1/2] Fixes to logP and cov functions for GMMs --- .../components/panes/Shape3DControlPane.java | 2 +- .../javafx/handlers/ManifoldEventHandler.java | 5 -- .../javafx/javafx3d/AsteroidFieldPane.java | 2 +- .../trinity/javafx/javafx3d/Manifold3D.java | 21 ++++---- .../utils/clustering/ClusterUtils.java | 24 +++++++++ .../clustering/GaussianDistribution.java | 26 ++++----- .../utils/clustering/GaussianMixture.java | 53 ++++++++++++------- .../jhuapl/trinity/fxml/ManifoldControl.fxml | 4 +- 8 files changed, 87 insertions(+), 50 deletions(-) diff --git a/src/main/java/edu/jhuapl/trinity/javafx/components/panes/Shape3DControlPane.java b/src/main/java/edu/jhuapl/trinity/javafx/components/panes/Shape3DControlPane.java index 34b89d74..51d3d47b 100644 --- a/src/main/java/edu/jhuapl/trinity/javafx/components/panes/Shape3DControlPane.java +++ b/src/main/java/edu/jhuapl/trinity/javafx/components/panes/Shape3DControlPane.java @@ -121,7 +121,7 @@ private void buildFindClustersTab() { findClustersTab.setContent(findClusterBorderPane); componentsSpinner = new Spinner( - new SpinnerValueFactory.IntegerSpinnerValueFactory(2, 20, 5, 1)); + new SpinnerValueFactory.IntegerSpinnerValueFactory(2, 500, 5, 1)); componentsSpinner.setPrefWidth(SPINNER_PREF_WIDTH); componentsSpinner.setEditable(true); iterationsSpinner = new Spinner( diff --git a/src/main/java/edu/jhuapl/trinity/javafx/handlers/ManifoldEventHandler.java b/src/main/java/edu/jhuapl/trinity/javafx/handlers/ManifoldEventHandler.java index fba641e0..6e6ff85e 100644 --- a/src/main/java/edu/jhuapl/trinity/javafx/handlers/ManifoldEventHandler.java +++ b/src/main/java/edu/jhuapl/trinity/javafx/handlers/ManifoldEventHandler.java @@ -298,11 +298,6 @@ public void handleExport(ManifoldEvent event) { } public void handleNewManifoldData(ManifoldEvent event) { -// Platform.runLater(() -> { -// App.getAppScene().getRoot().fireEvent( -// new CommandTerminalEvent("Loading Manifold Data...", -// new Font("Consolas", 20), Color.GREEN)); -// }); // System.out.println("Loading Manifold Data..."); ManifoldData md = (ManifoldData) event.object1; //convert deserialized points to Fxyz3D point3ds diff --git a/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/AsteroidFieldPane.java b/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/AsteroidFieldPane.java index 55af956d..916edf67 100644 --- a/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/AsteroidFieldPane.java +++ b/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/AsteroidFieldPane.java @@ -757,7 +757,7 @@ private void resetAsteroids() { public Manifold3D makeHull(List labelMatchedPoints, String label, Double tolerance) { Manifold3D manifold3D = new Manifold3D( - labelMatchedPoints, true, true, true, tolerance + labelMatchedPoints, true, false, false, tolerance ); manifold3D.quickhullMeshView.setCullFace(CullFace.FRONT); // manifold3D.addEventHandler(MouseEvent.MOUSE_CLICKED, e -> { diff --git a/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java b/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java index 1e85a832..95b9c226 100644 --- a/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java +++ b/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java @@ -77,7 +77,7 @@ public class Manifold3D extends Group { public Manifold3D(List point3DList, boolean triangulate, boolean makeLines, boolean makePoints, Double tolerance) { originalPoint3Ds = point3DList; - buildHullMesh(point3DList, triangulate, makeLines, makePoints, tolerance); + buildHullMesh(point3DList, triangulate, tolerance); List fxyzPoints = new ArrayList<>(); for (int i = 0; i < hull.getNumVertices(); i++) { @@ -349,7 +349,7 @@ public void refreshMesh(List point3DList, boolean triangulate, boolean quickhullLinesTriangleMesh.getPoints().clear(); quickhullLinesTriangleMesh.getTexCoords().clear(); quickhullLinesTriangleMesh.getFaces().clear(); - buildHullMesh(point3DList, triangulate, makeLines, makePoints, tolerance); + buildHullMesh(point3DList, triangulate, tolerance); quickhullMeshView.setMesh(quickhullTriangleMesh); if (makeLines) { quickhullLinesTriangleMesh.getPoints().addAll(quickhullTriangleMesh.getPoints()); @@ -361,7 +361,7 @@ public void refreshMesh(List point3DList, boolean triangulate, boolean // makeDebugPoints(hull, artScale, false); } - private void buildHullMesh(List point3DList, boolean triangulate, boolean makeLines, boolean makePoints, Double tolerance) { + private void buildHullMesh(List point3DList, boolean triangulate, Double tolerance) { hull = new QuickHull3D(); if (null != tolerance) hull.setExplicitDistanceTolerance(tolerance); @@ -455,19 +455,21 @@ public void handle(long now) { } public void makeLines() { + boolean wasVisible = null != quickhullLinesMeshView + ? quickhullLinesMeshView.isVisible() : false; quickhullLinesTriangleMesh = new TriangleMesh(); quickhullLinesTriangleMesh.getPoints().addAll(quickhullTriangleMesh.getPoints()); quickhullLinesTriangleMesh.getTexCoords().addAll(quickhullTriangleMesh.getTexCoords()); quickhullLinesTriangleMesh.getFaces().addAll(quickhullTriangleMesh.getFaces()); quickhullLinesMeshView = new MeshView(quickhullLinesTriangleMesh); - PhongMaterial quickhullLinesMaterial = new PhongMaterial(Color.BLUE); - quickhullLinesMaterial.setSpecularColor(Color.BLUE); //fix for aarch64 Mac Ventura + PhongMaterial quickhullLinesMaterial = new PhongMaterial(Color.ALICEBLUE); + quickhullLinesMaterial.setSpecularColor(Color.ALICEBLUE); //fix for aarch64 Mac Ventura quickhullLinesMeshView.setMaterial(quickhullLinesMaterial); quickhullLinesMeshView.setDrawMode(DrawMode.LINE); quickhullLinesMeshView.setCullFace(CullFace.NONE); quickhullLinesMeshView.setMouseTransparent(true); - + quickhullLinesMeshView.setVisible(wasVisible); getChildren().add(quickhullLinesMeshView); } @@ -489,14 +491,15 @@ public void makeDebugPoints(QuickHull3D hull, float scale, boolean print) { sb.append(", "); } - Sphere sphere = new Sphere(2.5); - PhongMaterial mat = new PhongMaterial(Color.BLUE); - mat.setSpecularColor(Color.BLUE); // fix for aarch64 Mac Ventura + Sphere sphere = new Sphere(1.5); + PhongMaterial mat = new PhongMaterial(Color.ALICEBLUE); + mat.setSpecularColor(Color.ALICEBLUE); // fix for aarch64 Mac Ventura sphere.setMaterial(mat); sphere.setTranslateX(point3D.x); sphere.setTranslateY(point3D.y); sphere.setTranslateZ(point3D.z); extrasGroup.getChildren().add(sphere); + sphere.setVisible(false); Label newLabel = new Label(String.valueOf(i)); labelGroup.getChildren().addAll(newLabel); diff --git a/src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java b/src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java index cbe723c6..c78806a6 100644 --- a/src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java +++ b/src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java @@ -5,6 +5,8 @@ import static java.lang.Math.abs; import static java.lang.Math.sqrt; +import java.util.ArrayList; +import java.util.List; /** * Cherry Picked Math functions used for GMM processing. @@ -55,7 +57,29 @@ public static double squaredDistanceWithMissingValues(double[] x, double[] y) { return dist; } + public static List> extractGMMClusters( + double[][] data, + GaussianMixture gmm, + double threshold) + { + List> clusterPoints = new ArrayList<>(); + for (int i = 0; i < gmm.components.length; i++) { + clusterPoints.add(new ArrayList<>()); + } + + for (double[] x : data) { + double[] post = gmm.posteriori(x); + int k = ClusterUtils.whichMax(post); + if (post[k] >= threshold) { + clusterPoints.get(k).add(x); + } + } + + // Filter out tiny/degenerate clusters if needed + clusterPoints.removeIf(list -> list.size() < 4); + return clusterPoints; + } /** * Returns the sum of an array. * diff --git a/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java b/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java index 4283c08f..00219a8c 100644 --- a/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java +++ b/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java @@ -199,19 +199,19 @@ public RealMatrix cov() { public double scatter() { return sigmaDet; } - +public double mahalanobis2(double[] x) { + double[] v = x.clone(); + ClusterUtils.sub(v, mu); + double[] Av = sigmaInv.operate(v); + return ClusterUtils.dot(v, Av); +} public double logp(double[] x) { - if (x.length != dim) { - throw new IllegalArgumentException("Sample has different dimension."); - } - + if (x.length != dim) throw new IllegalArgumentException("Sample has different dimension."); double[] v = x.clone(); - ClusterUtils.sub(v, mu); -// double result = sigmaInv.xAx(v) / -2.0; -// double[] Ax = mv(x); - double[] Ax = sigmaInv.operate(v); - double result = ClusterUtils.dot(x, Ax) / -2.0; - return result - pdfConstant; + ClusterUtils.sub(v, mu); // v = x - μ + double[] Av = sigmaInv.operate(v); // Σ⁻¹ v + double quad = ClusterUtils.dot(v, Av); // vᵀ Σ⁻¹ v + return -0.5 * quad - pdfConstant; } public double p(double[] x) { @@ -471,7 +471,9 @@ public double logLikelihood(double[][] x) { return L; } - + public int dim() { + return dim; + } @Override public String toString() { return String.format("Gaussian(mu = %s, sigma = %s)", Arrays.toString(mu), sigma); diff --git a/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java b/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java index b2bc5c1f..59b5399d 100644 --- a/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java +++ b/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java @@ -240,32 +240,45 @@ public double[] mean() { } public RealMatrix cov() { - double w = components[0].priori(); - RealMatrix v = components[0].distribution().cov(); - - int m = v.getRowDimension(); - int n = v.getColumnDimension(); - RealMatrix cov = MatrixUtils.createRealMatrix(m, n); + double[] mu = mean(); + int d = mu.length; + RealMatrix C = MatrixUtils.createRealMatrix(d, d); - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - cov.setEntry(i, j, w * w * v.getEntry(i, j)); - } + // within-component variance + for (GaussianMixtureComponent c : components) { + double w = c.priori; + RealMatrix Sk = c.distribution.cov(); + C = C.add(Sk.scalarMultiply(w)); } - for (int k = 1; k < components.length; k++) { - w = components[k].priori(); - v = components[k].distribution().cov(); - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - cov.addToEntry(i, j, w * w * v.getEntry(i, j)); - } - } + // between-component variance + for (GaussianMixtureComponent c : components) { + double w = c.priori; + double[] mk = c.distribution.mean(); + double[] diff = mk.clone(); + ClusterUtils.sub(diff, mu); + RealMatrix outer = MatrixUtils.createColumnRealMatrix(diff) + .multiply(MatrixUtils.createRowRealMatrix(diff)); + C = C.add(outer.scalarMultiply(w)); } - - return cov; + return C; } +public boolean inDistribution(double[] x, double q) { + // choose most responsible component + double[] r = posteriori(x); + int idx = ClusterUtils.whichMax(r); + GaussianMixtureComponent c = components[idx]; + double d2 = c.distribution.mahalanobis2(x); + // chi-square threshold + org.apache.commons.math3.distribution.ChiSquaredDistribution chi = + new org.apache.commons.math3.distribution.ChiSquaredDistribution(c.distribution.dim()); + double thresh = chi.inverseCumulativeProbability(q); + return d2 <= thresh; +} +public boolean inDistributionByLogP(double[] x, double tau) { + return Math.log(p(x)) >= tau; +} public Pair maxPostProb(double[] x) { int k = components.length; double[] prob = new double[k]; diff --git a/src/main/resources/edu/jhuapl/trinity/fxml/ManifoldControl.fxml b/src/main/resources/edu/jhuapl/trinity/fxml/ManifoldControl.fxml index a001ca73..26cb0b62 100644 --- a/src/main/resources/edu/jhuapl/trinity/fxml/ManifoldControl.fxml +++ b/src/main/resources/edu/jhuapl/trinity/fxml/ManifoldControl.fxml @@ -450,9 +450,9 @@ - - From 6fcc817f998be3f9b1accb946b389ae1e3930153 Mon Sep 17 00:00:00 2001 From: samypr100 <3933065+samypr100@users.noreply.github.com> Date: Sun, 2 Nov 2025 21:32:14 -0500 Subject: [PATCH 2/2] fix: rebase conflicts --- build.gradle | 2 +- .../trinity/javafx/javafx3d/Manifold3D.java | 2 +- .../utils/clustering/ClusterUtils.java | 14 ++++--- .../clustering/GaussianDistribution.java | 18 ++++---- .../utils/clustering/GaussianMixture.java | 42 ++++++++++--------- 5 files changed, 43 insertions(+), 35 deletions(-) diff --git a/build.gradle b/build.gradle index 422753c1..c5e4ea10 100644 --- a/build.gradle +++ b/build.gradle @@ -367,7 +367,7 @@ assemble { jlink { def fileSep = System.getProperty('file.separator') def imageZipFile = layout.buildDirectory.file("${artifactNameUpper}-${project.version}.zip") - options.set(['--strip-debug', '--compress', '2', '--no-header-files', '--no-man-pages']) + options.set(['--strip-debug', '--compress', 'zip-6', '--no-header-files', '--no-man-pages']) imageZip.set(imageZipFile) launcher { def currentOS = org.gradle.internal.os.OperatingSystem.current() diff --git a/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java b/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java index 95b9c226..81b0276d 100644 --- a/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java +++ b/src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java @@ -455,7 +455,7 @@ public void handle(long now) { } public void makeLines() { - boolean wasVisible = null != quickhullLinesMeshView + boolean wasVisible = null != quickhullLinesMeshView ? quickhullLinesMeshView.isVisible() : false; quickhullLinesTriangleMesh = new TriangleMesh(); quickhullLinesTriangleMesh.getPoints().addAll(quickhullTriangleMesh.getPoints()); diff --git a/src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java b/src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java index c78806a6..81a443f3 100644 --- a/src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java +++ b/src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java @@ -3,11 +3,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static java.lang.Math.abs; -import static java.lang.Math.sqrt; import java.util.ArrayList; import java.util.List; +import static java.lang.Math.abs; +import static java.lang.Math.sqrt; + /** * Cherry Picked Math functions used for GMM processing. */ @@ -57,11 +58,11 @@ public static double squaredDistanceWithMissingValues(double[] x, double[] y) { return dist; } + public static List> extractGMMClusters( - double[][] data, - GaussianMixture gmm, - double threshold) - { + double[][] data, + GaussianMixture gmm, + double threshold) { List> clusterPoints = new ArrayList<>(); for (int i = 0; i < gmm.components.length; i++) { clusterPoints.add(new ArrayList<>()); @@ -80,6 +81,7 @@ public static List> extractGMMClusters( return clusterPoints; } + /** * Returns the sum of an array. * diff --git a/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java b/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java index 00219a8c..0e29365f 100644 --- a/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java +++ b/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java @@ -199,18 +199,20 @@ public RealMatrix cov() { public double scatter() { return sigmaDet; } -public double mahalanobis2(double[] x) { - double[] v = x.clone(); - ClusterUtils.sub(v, mu); - double[] Av = sigmaInv.operate(v); - return ClusterUtils.dot(v, Av); -} + + public double mahalanobis2(double[] x) { + double[] v = x.clone(); + ClusterUtils.sub(v, mu); + double[] Av = sigmaInv.operate(v); + return ClusterUtils.dot(v, Av); + } + public double logp(double[] x) { if (x.length != dim) throw new IllegalArgumentException("Sample has different dimension."); double[] v = x.clone(); ClusterUtils.sub(v, mu); // v = x - μ double[] Av = sigmaInv.operate(v); // Σ⁻¹ v - double quad = ClusterUtils.dot(v, Av); // vᵀ Σ⁻¹ v + double quad = ClusterUtils.dot(v, Av); // vᵀ Σ⁻¹ v return -0.5 * quad - pdfConstant; } @@ -471,9 +473,11 @@ public double logLikelihood(double[][] x) { return L; } + public int dim() { return dim; } + @Override public String toString() { return String.format("Gaussian(mu = %s, sigma = %s)", Arrays.toString(mu), sigma); diff --git a/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java b/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java index 59b5399d..1d13a000 100644 --- a/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java +++ b/src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java @@ -246,39 +246,41 @@ public RealMatrix cov() { // within-component variance for (GaussianMixtureComponent c : components) { - double w = c.priori; - RealMatrix Sk = c.distribution.cov(); + double w = c.priori(); + RealMatrix Sk = c.distribution().cov(); C = C.add(Sk.scalarMultiply(w)); } // between-component variance for (GaussianMixtureComponent c : components) { - double w = c.priori; - double[] mk = c.distribution.mean(); + double w = c.priori(); + double[] mk = c.distribution().mean(); double[] diff = mk.clone(); ClusterUtils.sub(diff, mu); RealMatrix outer = MatrixUtils.createColumnRealMatrix(diff) - .multiply(MatrixUtils.createRowRealMatrix(diff)); + .multiply(MatrixUtils.createRowRealMatrix(diff)); C = C.add(outer.scalarMultiply(w)); } return C; } -public boolean inDistribution(double[] x, double q) { - // choose most responsible component - double[] r = posteriori(x); - int idx = ClusterUtils.whichMax(r); - GaussianMixtureComponent c = components[idx]; - double d2 = c.distribution.mahalanobis2(x); - // chi-square threshold - org.apache.commons.math3.distribution.ChiSquaredDistribution chi = - new org.apache.commons.math3.distribution.ChiSquaredDistribution(c.distribution.dim()); - double thresh = chi.inverseCumulativeProbability(q); - return d2 <= thresh; -} -public boolean inDistributionByLogP(double[] x, double tau) { - return Math.log(p(x)) >= tau; -} + public boolean inDistribution(double[] x, double q) { + // choose most responsible component + double[] r = posteriori(x); + int idx = ClusterUtils.whichMax(r); + GaussianMixtureComponent c = components[idx]; + double d2 = c.distribution().mahalanobis2(x); + // chi-square threshold + org.apache.commons.math3.distribution.ChiSquaredDistribution chi = + new org.apache.commons.math3.distribution.ChiSquaredDistribution(c.distribution().dim()); + double thresh = chi.inverseCumulativeProbability(q); + return d2 <= thresh; + } + + public boolean inDistributionByLogP(double[] x, double tau) { + return Math.log(p(x)) >= tau; + } + public Pair maxPostProb(double[] x) { int k = components.length; double[] prob = new double[k];