Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
230a4bb
First working version of extended isolation forest training and scori…
jverbus Mar 9, 2025
39810d7
Updated rough draft code for EIF.
jverbus Mar 10, 2025
e6fbc8d
Refactor Extended Isolation Forest for clearer logic more in line wit…
jverbus Mar 13, 2025
b88ea95
Got standard isolation forest R/W working after major refactor. Still…
jverbus Mar 14, 2025
bceea39
Fixed package structure.
jverbus Mar 14, 2025
bbc25a0
WORK IN PROGRESS - Have prototype extended isolation forest read / wr…
jverbus Mar 24, 2025
9c5b224
Did linting for eif code.
jverbus Mar 25, 2025
a56c1ac
fix(EIF): align hyperplane split + path test with paper; correct inte…
jverbus Aug 30, 2025
817501c
fix(EIF): retry degenerate hyperplane splits instead of premature l…
jverbus Mar 10, 2026
306f84e
chore(EIF): remove dead code, unused imports, and fix test descriptio…
jverbus Mar 10, 2026
a9b7697
fix(EIF): validate extensionLevel at fit time instead of silent clamping
jverbus Mar 10, 2026
aaf5e7e
fix: fail fast on empty partition in shared tree training
jverbus Mar 10, 2026
b5bf855
fix: use actual tree count instead of numEstimators param in scoring
jverbus Mar 10, 2026
7ef4bf4
test(EIF): replace toString tree comparison with structural equality …
jverbus Mar 10, 2026
ad62bc4
docs: add Extended Isolation Forest documentation to README
jverbus Mar 10, 2026
645361e
Added citation info to readme.
jverbus Mar 10, 2026
b3af421
fix(EIF): use strict < for hyperplane split and stop fit() from mutat…
jverbus Mar 10, 2026
680648d
fix(EIF): match reference implementation split semantics instead of…
jverbus Mar 10, 2026
8cf44ac
docs: update benchmarks with StandardIF, ExtendedIF_0, and ExtendedIF…
jverbus Mar 10, 2026
301f643
docs: update benchmark table with reference Python results for all 13…
jverbus Mar 10, 2026
529e856
fix(EIF): persist resolved extensionLevel on trained model
jverbus Mar 10, 2026
60677a0
test(EIF): add pre-merge tests for zero-size leaves and ext=0 axis-al…
jverbus Mar 10, 2026
b5b145e
docs: soften benchmark claims and clarify EIF_0 vs StandardIF wording
jverbus Mar 10, 2026
afc614e
test(EIF): enable saved model tree structure regression test
jverbus Mar 10, 2026
8a5a216
What was done: Extracted the duplicated validateAndResolveParams meth…
jverbus Mar 11, 2026
58f0d47
refactor: extract duplicated transformSchema into Utils.validateAndTr…
jverbus Mar 11, 2026
a2f2319
chore(EIF): remove unused import, fix docstring, and align threshold …
jverbus Mar 11, 2026
52b4695
test(EIF): add tests for L2-normalized normals, invalid extensionLeve…
jverbus Mar 11, 2026
9e93557
chore(EIF): fix redundant import and stale docstrings in ExtendedIsol…
jverbus Mar 11, 2026
1d211ec
fix: address EIF review findings and harden model edge cases
jverbus Mar 11, 2026
71c5630
docs: refresh README for EIF and current build defaults
jverbus Mar 11, 2026
cd7811d
feat: add extended isolation forest with sparse hyperplane persistence
jverbus Mar 11, 2026
9ea62de
Updated readme.
jverbus Mar 11, 2026
8ae91c8
docs: update README benchmark table and references
jverbus Mar 12, 2026
522a6dc
Added scroll to results table.
jverbus Mar 12, 2026
efa2a11
updated readme
jverbus Mar 12, 2026
b6c3fee
Updated readme.
jverbus Mar 12, 2026
17a07f5
fix(EIF): use float-precision hyperplane weights for Spark 4.x Avro c…
jverbus Mar 12, 2026
9df6ed3
docs: address Copilot review feedback on PR #79
jverbus Mar 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 191 additions & 43 deletions README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ import com.linkedin.relevance.isolationforest.core.SharedTrainLogic.{
computeAndSetModelThreshold,
createSampledPartitionedDataset,
trainIsolationTrees,
validateAndResolveParams,
}
import com.linkedin.relevance.isolationforest.core.Utils.{DataPoint, ResolvedParams}
import com.linkedin.relevance.isolationforest.core.Utils.{DataPoint, validateAndTransformSchema}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.sql.types.StructType

/**
* Used to train an isolation forest model. It extends the spark.ml Estimator class.
Expand All @@ -28,7 +28,7 @@ class IsolationForest(override val uid: String)
with DefaultParamsWritable
with Logging {

def this() = this(Identifiable.randomUID("standard-isolation-forest"))
def this() = this(Identifiable.randomUID("isolation-forest"))

override def copy(extra: ParamMap): IsolationForest =

Expand All @@ -55,7 +55,7 @@ class IsolationForest(override val uid: String)

// Validate $(maxFeatures) and $(maxSamples) against input dataset and determine the values
// actually used to train the model: numFeatures and numSamples
val resolvedParams = validateAndResolveParams(dataset)
val resolvedParams = validateAndResolveParams(dataset, $(maxFeatures), $(maxSamples))

// Bag and flatten the data, then repartition it so that each partition corresponds to one
// isolation tree.
Expand Down Expand Up @@ -86,6 +86,7 @@ class IsolationForest(override val uid: String)
isolationTrees,
resolvedParams.numSamples,
resolvedParams.numFeatures,
resolvedParams.totalNumFeatures,
)
.setParent(this),
)
Expand All @@ -103,103 +104,8 @@ class IsolationForest(override val uid: String)
isolationForestModel
}

/**
 * Private helper to validate parameters and figure out how many features and samples we'll use.
 *
 * A maxFeatures / maxSamples value greater than 1.0 is treated as an absolute count; a value in
 * (0, 1] is treated as a fraction of what the input dataset provides.
 *
 * @param dataset
 *   The input dataset.
 * @return
 *   A ResolvedParams instance containing the resolved values.
 */
private def validateAndResolveParams(dataset: Dataset[DataPoint]): ResolvedParams = {

  val maxFeaturesParam = $(maxFeatures)
  val maxSamplesParam = $(maxSamples)

  // Resolve the per-tree feature count against the width of the input feature vectors.
  val totalNumFeatures = dataset.head().features.length
  val numFeatures =
    if (maxFeaturesParam > 1.0) math.floor(maxFeaturesParam).toInt
    else math.floor(maxFeaturesParam * totalNumFeatures).toInt
  logInfo(
    s"User specified number of features used to train each tree over total number of" +
      s" features: ${numFeatures} / ${totalNumFeatures}",
  )
  require(
    numFeatures > 0,
    s"parameter maxFeatures given invalid value ${maxFeaturesParam}" +
      s" specifying the use of ${numFeatures} features, but >0 features are required.",
  )
  require(
    numFeatures <= totalNumFeatures,
    s"parameter maxFeatures given invalid value" +
      s" ${maxFeaturesParam} specifying the use of ${numFeatures} features, but only" +
      s" ${totalNumFeatures} features are available.",
  )

  // Resolve the per-tree sample count against the size of the input dataset.
  val totalNumSamples = dataset.count()
  val numSamples =
    if (maxSamplesParam > 1.0) math.floor(maxSamplesParam).toInt
    else math.floor(maxSamplesParam * totalNumSamples).toInt
  logInfo(
    s"User specified number of samples used to train each tree over total number of" +
      s" samples: ${numSamples} / ${totalNumSamples}",
  )
  require(
    numSamples > 0,
    s"parameter maxSamples given invalid value ${maxSamplesParam}" +
      s" specifying the use of ${numSamples} samples, but >0 samples are required.",
  )
  require(
    numSamples <= totalNumSamples,
    s"parameter maxSamples given invalid value" +
      s" ${maxSamplesParam} specifying the use of ${numSamples} samples, but only" +
      s" ${totalNumSamples} samples are in the input dataset.",
  )

  ResolvedParams(numFeatures, totalNumFeatures, numSamples, totalNumSamples)
}

/**
 * Validates the input schema and returns it extended with the output columns. The input
 * DataFrame must carry a $(featuresCol) of vector type and must not already contain the
 * $(predictionCol) or $(scoreCol) columns, because those are created during fitting.
 *
 * @param schema
 *   The schema of the DataFrame containing the data to be fit.
 * @return
 *   The input schema with the additional $(predictionCol) and $(scoreCol) columns appended.
 */
override def transformSchema(schema: StructType): StructType = {

  val fieldNames = schema.fieldNames

  // The features column must exist and hold ML vectors.
  require(
    fieldNames.contains($(featuresCol)),
    s"Input column ${$(featuresCol)} does not exist.",
  )
  require(
    schema($(featuresCol)).dataType == VectorType,
    s"Input column ${$(featuresCol)} is not of required type ${VectorType}",
  )

  // The output columns are created by this estimator, so pre-existing ones are a caller error.
  require(
    !fieldNames.contains($(predictionCol)),
    s"Output column ${$(predictionCol)} already exists.",
  )
  require(
    !fieldNames.contains($(scoreCol)),
    s"Output column ${$(scoreCol)} already exists.",
  )

  val appendedFields = Seq(
    StructField($(predictionCol), DoubleType, nullable = false),
    StructField($(scoreCol), DoubleType, nullable = false),
  )
  StructType(schema.fields ++ appendedFields)
}
// Schema validation/augmentation is shared with the model; delegate to the core.Utils helper.
override def transformSchema(schema: StructType): StructType =
  validateAndTransformSchema(schema, $(featuresCol), $(predictionCol), $(scoreCol))
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package com.linkedin.relevance.isolationforest

import com.linkedin.relevance.isolationforest.core.Utils.{DataPoint, avgPathLength}
import com.linkedin.relevance.isolationforest.core.{
IsolationForestModelReadWrite,
IsolationForestParamsBase,
import com.linkedin.relevance.isolationforest.core.Utils.{
DataPoint,
avgPathLength,
validateFeatureVectorSize,
validateAndTransformSchema,
}
import com.linkedin.relevance.isolationforest.core.IsolationForestParamsBase
import org.apache.spark.ml.Model
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset}

/**
Expand All @@ -28,16 +29,35 @@ import org.apache.spark.sql.{DataFrame, Dataset}
* cases, a given isolation tree may not have any nodes using some of these features, e.g., a
* shallow tree where the number of features in the training data exceeds the number of nodes in
* the tree.
* @param totalNumFeatures
* The total number of input features seen during training, or
* [[IsolationForestModel.UnknownTotalNumFeatures]] for legacy loaded models that predate this
* metadata.
*/
class IsolationForestModel(
class IsolationForestModel private[isolationforest] (
override val uid: String,
val isolationTrees: Array[IsolationTree],
private val numSamples: Int,
private val numFeatures: Int,
private val totalNumFeatures: Int,
) extends Model[IsolationForestModel]
with IsolationForestParamsBase
with MLWritable {

/**
 * Backward-compatible auxiliary constructor for callers that predate the totalNumFeatures
 * field; it records [[IsolationForestModel.UnknownTotalNumFeatures]] as a sentinel value.
 */
def this(
  uid: String,
  isolationTrees: Array[IsolationTree],
  numSamples: Int,
  numFeatures: Int,
) =
  this(
    uid,
    isolationTrees,
    numSamples,
    numFeatures,
    IsolationForestModel.UnknownTotalNumFeatures,
  )

require(numSamples > 0, s"parameter numSamples must be >0, but given invalid value ${numSamples}")
final def getNumSamples: Int = numSamples

Expand All @@ -47,6 +67,19 @@ class IsolationForestModel(
)
final def getNumFeatures: Int = numFeatures

require(
totalNumFeatures == IsolationForestModel.UnknownTotalNumFeatures || totalNumFeatures > 0,
s"parameter totalNumFeatures must be >0 or UnknownTotalNumFeatures, but given invalid value ${totalNumFeatures}",
)
require(
totalNumFeatures == IsolationForestModel.UnknownTotalNumFeatures || numFeatures <= totalNumFeatures,
s"parameter numFeatures must be <= totalNumFeatures, but given invalid values" +
s" numFeatures=${numFeatures}, totalNumFeatures=${totalNumFeatures}",
)
final def getTotalNumFeatures: Int = totalNumFeatures
final def hasKnownTotalNumFeatures: Boolean =
totalNumFeatures != IsolationForestModel.UnknownTotalNumFeatures

// The outlierScoreThreshold needs to be a mutable variable because it is not known when an
// IsolationForestModel instance is created.
private var outlierScoreThreshold: Double = -1
Expand All @@ -64,8 +97,9 @@ class IsolationForestModel(

override def copy(extra: ParamMap): IsolationForestModel = {

val isolationForestCopy = new IsolationForestModel(uid, isolationTrees, numSamples, numFeatures)
.setParent(this.parent)
val isolationForestCopy =
new IsolationForestModel(uid, isolationTrees, numSamples, numFeatures, totalNumFeatures)
.setParent(this.parent)
isolationForestCopy.setOutlierScoreThreshold(outlierScoreThreshold)
copyValues(isolationForestCopy, extra)
}
Expand All @@ -81,15 +115,26 @@ class IsolationForestModel(
*/
override def transform(data: Dataset[_]): DataFrame = {

require(
numSamples >= 2,
s"Cannot score with numSamples=$numSamples; expected numSamples >= 2.",
)
require(
isolationTrees.nonEmpty,
"Cannot score with an empty IsolationForestModel.",
)
transformSchema(data.schema, logging = true)

val avgPath = avgPathLength(numSamples)
val broadcastIsolationTrees = data.sparkSession.sparkContext.broadcast(isolationTrees)

val calculatePathLength = (features: Vector) => {
if (hasKnownTotalNumFeatures) {
validateFeatureVectorSize(features, totalNumFeatures)
}
val pathLength = broadcastIsolationTrees.value
.map(y => y.calculatePathLength(DataPoint(features.toArray.map(x => x.toFloat))))
.sum / $(numEstimators)
.sum / broadcastIsolationTrees.value.length
Math.pow(2, -pathLength / avgPath)
}
val transformUDF = udf(calculatePathLength)
Expand All @@ -105,44 +150,8 @@ class IsolationForestModel(
dataWithScoresAndPrediction
}

/**
 * Checks that an incoming schema can be scored and produces the schema of the scored output.
 * The schema must contain a vector-typed $(featuresCol) and must not yet define the
 * $(predictionCol) or $(scoreCol) columns, which this model appends.
 *
 * @param schema
 *   The schema of the DataFrame containing the data to be fit.
 * @return
 *   The input schema with the additional $(predictionCol) and $(scoreCol) columns appended.
 */
override def transformSchema(schema: StructType): StructType = {

  // Feature column: must be present and of ML vector type.
  require(
    schema.fieldNames.contains($(featuresCol)),
    s"Input column ${$(featuresCol)} does not exist.",
  )
  require(
    schema($(featuresCol)).dataType == VectorType,
    s"Input column ${$(featuresCol)} is not of required type ${VectorType}",
  )

  // Output columns: must not collide with anything already in the schema.
  require(
    !schema.fieldNames.contains($(predictionCol)),
    s"Output column ${$(predictionCol)} already exists.",
  )
  require(
    !schema.fieldNames.contains($(scoreCol)),
    s"Output column ${$(scoreCol)} already exists.",
  )

  // StructType.add returns a new StructType with the field appended.
  schema
    .add(StructField($(predictionCol), DoubleType, nullable = false))
    .add(StructField($(scoreCol), DoubleType, nullable = false))
}
// Delegates to the shared core.Utils helper so model and estimator validate schemas identically.
override def transformSchema(schema: StructType): StructType =
  validateAndTransformSchema(schema, $(featuresCol), $(predictionCol), $(scoreCol))

/**
* Returns an IsolationForestModelWriter instance that can be used to write the isolation forest
Expand All @@ -159,6 +168,8 @@ class IsolationForestModel(
*/
case object IsolationForestModel extends MLReadable[IsolationForestModel] {

private[isolationforest] val UnknownTotalNumFeatures: Int = -1

/**
* Returns an IsolationForestModelReader instance that can be used to read a saved isolation
* forest from disk.
Expand Down
Loading
Loading