From 579141d81885a63682070c061a4a4e0d6d0a89be Mon Sep 17 00:00:00 2001 From: Christoph Rohde <44606665+CodebyCR@users.noreply.github.com> Date: Thu, 12 Mar 2026 14:19:37 +0100 Subject: [PATCH 1/2] Implement L2 normalization extension for MLXArray Adds an extension method to MLXArray for L2 normalization. --- Source/MLX/MLXArray+Normalizer.swift | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 Source/MLX/MLXArray+Normalizer.swift diff --git a/Source/MLX/MLXArray+Normalizer.swift b/Source/MLX/MLXArray+Normalizer.swift new file mode 100644 index 00000000..9f75fa95 --- /dev/null +++ b/Source/MLX/MLXArray+Normalizer.swift @@ -0,0 +1,28 @@ +// Copyright © 2026 Apple Inc. + +import MLX + +extension MLXArray { + + /// Returns a new array normalized by its L2 norm along the specified axis. + /// + /// This operation scales the vectors along `axis` to unit length. If the + /// norm is smaller than `eps`, it is clamped to `eps` to ensure numerical + /// stability and prevent division by zero. + /// + /// - Parameters: + /// - axis: The axis along which to compute the norm. Defaults to `-1`. + /// - eps: A small epsilon value to prevent division by zero. Defaults to `1e-12`. + /// - Returns: An `MLXArray` with the same shape as the original, normalized along `axis`. + /// + /// - Complexity: O(n), where n is the total number of elements in the array. + public func l2Normalized(axis: Int = -1, eps: Float = 1e-12) -> MLXArray { + // 'self' represents the current MLXArray instance. + // We compute the norm along the specified axis. + let norm = MLXLinalg.norm(self, ord: 2, axis: axis, keepDims: true) + + // We use MLX.maximum to clamp the divisor. + // This is more stable than adding eps to the norm. + return self / MLX.maximum(norm, MLXArray(eps)) + } +} From 7d0b96b8e1d0fa72631ef65b0c73efcd8df8bb57 Mon Sep 17 00:00:00 2001 From: Christoph Rohde Date: Sun, 29 Mar 2026 21:43:29 +0200 Subject: [PATCH 2/2] Add MLXArrayL2NormalizationTests --- Source/MLX/MLXArray+Normalizer.swift | 2 - Tests/MLXTests/MLXArray+NormlizerTests.swift | 78 ++++++++++++++++++++ 2 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 Tests/MLXTests/MLXArray+NormlizerTests.swift diff --git a/Source/MLX/MLXArray+Normalizer.swift b/Source/MLX/MLXArray+Normalizer.swift index 9f75fa95..132d7ce6 100644 --- a/Source/MLX/MLXArray+Normalizer.swift +++ b/Source/MLX/MLXArray+Normalizer.swift @@ -1,7 +1,5 @@ // Copyright © 2026 Apple Inc. -import MLX - extension MLXArray { /// Returns a new array normalized by its L2 norm along the specified axis. diff --git a/Tests/MLXTests/MLXArray+NormlizerTests.swift b/Tests/MLXTests/MLXArray+NormlizerTests.swift new file mode 100644 index 00000000..8fbdbfd6 --- /dev/null +++ b/Tests/MLXTests/MLXArray+NormlizerTests.swift @@ -0,0 +1,78 @@ +// Copyright © 2026 Apple Inc. + +import Foundation +import XCTest + +@testable import MLX + +public final class MLXArrayL2NormalizationTests: XCTestCase { + + /// Tests standard L2 normalization for a 1D vector. + /// Magnitude is exactly 5.0, result should be unit length (1.0). + func testL2NormalizationStandard() { + let rawArray: [Float] = [3.0, 4.0] + let array = MLXArray(rawArray, [2]) + let normalized = array.l2Normalized() + + let rawExpected: [Float] = [0.6, 0.8] + let expected = MLXArray(rawExpected, [2]) + + // Use allClose for floating point comparison in MLX + XCTAssertTrue(allClose(normalized, expected).item(Bool.self)) + + // Verify Magnitude: Must be 1.0 + let magnitude = MLXLinalg.norm(normalized, ord: 2).item(Float.self) + XCTAssertEqual(magnitude, 1.0, accuracy: 1e-6) + } + + /// Tests normalization along a specific axis in a 2D matrix. + func testL2NormalizationAlongAxis() { + // 2x2 Matrix: [[3, 4], [0, 1]] + let rawArray: [Float] = [3.0, 4.0, 0.0, 1.0] + let array = MLXArray(rawArray, [2, 2]) + + // Normalize along the last axis (rows) + let normalized = array.l2Normalized(axis: -1) + + // Row 1: [0.6, 0.8], Row 2: [0.0, 1.0] + let rawExpected: [Float] = [0.6, 0.8, 0.0, 1.0] + let expected = MLXArray(rawExpected, [2, 2]) + + XCTAssertTrue(allClose(normalized, expected).item(Bool.self)) + } + + /// CRITICAL: Tests behavior with a zero vector to ensure numerical stability via epsilon. + func testL2NormalizationZeroVector() { + let eps: Float = 1e-8 + let rawArray: [Float] = [0.0, 0.0] + let array = MLXArray(rawArray, [2]) + let normalized = array.l2Normalized(eps: eps) + + // Since Norm (0) < eps, we divide by eps. + // 0 / eps remains 0, preventing NaN. + let rawExpected: [Float] = [0.0, 0.0] + let expected = MLXArray(rawExpected, [2]) + + XCTAssertTrue(allClose(normalized, expected).item(Bool.self)) + + // Magnitude should be 0.0, not NaN! + let magnitude = MLXLinalg.norm(normalized, ord: 2).item(Float.self) + XCTAssertFalse(magnitude.isNaN, "Resulting magnitude should not be NaN") + XCTAssertEqual(magnitude, 0.0) + } + + /// Tests values that are smaller than the provided epsilon. + func testL2NormalizationUnderEpsilon() { + let eps: Float = 1e-3 + let rawArray: [Float] = [1e-5, 1e-5] + let array = MLXArray(rawArray, [2]) // Norm is approx 1.41 * 1e-5 + let normalized = array.l2Normalized(eps: eps) + + // The norm is smaller than eps, so the divisor is clamped to eps (0.001) + let expectedValue = Float(1e-5) / eps + let expected = MLXArray([expectedValue, expectedValue], [2]) + + XCTAssertTrue(allClose(normalized, expected).item(Bool.self)) + } + +}