From 528632561f143d98618a1afbcb40bfc05456c354 Mon Sep 17 00:00:00 2001 From: brett Date: Wed, 1 Apr 2026 16:22:18 -0400 Subject: [PATCH] feat: add calibrate() to OutlierTurboQuant for data-driven channel split MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the outlier/inlier channel split was set at construction time and never adjusted. calibrate(calibration_vectors) now computes per-channel RMS, flags channels whose RMS exceeds 3× the median as outliers, and updates the split on the compressor — matching the dynamic-threshold approach described in the LLM.int8() and SmoothQuant literature. Co-Authored-By: Claude Sonnet 4.6 --- tests/test_outlier.py | 90 +++++++++++++++++++++++++++++++++++++++++++ turboquant/outlier.py | 43 +++++++++++++++++++++ 2 files changed, 133 insertions(+) diff --git a/tests/test_outlier.py b/tests/test_outlier.py index 3f2df2196..e8714ec8b 100644 --- a/tests/test_outlier.py +++ b/tests/test_outlier.py @@ -110,3 +110,93 @@ def test_deterministic(self): r1 = oq1.dequantize(c1) r2 = oq2.dequantize(c2) np.testing.assert_allclose(r1, r2, atol=1e-15) + + +class TestCalibrate: + """Tests for OutlierTurboQuant.calibrate() data-driven channel split.""" + + def test_calibrate_finds_known_outlier_channels(self): + """calibrate() should identify channels with artificially large RMS as outliers.""" + from turboquant.outlier import OutlierTurboQuant + + d = 128 + rng = np.random.default_rng(42) + + # Build calibration data with clear outliers in channels 10, 20, 30 + n_samples = 500 + calib = rng.standard_normal((n_samples, d)) # baseline ~ N(0,1) + outlier_channels = [10, 20, 30] + for ch in outlier_channels: + calib[:, ch] *= 20.0 # RMS ~ 20 >> 3 * median ≈ 3 + + oq = OutlierTurboQuant(d=d, target_bits=2.5, seed=7) + + # Before calibration: fixed split (channels 0..n_outlier-1) + default_outlier_idx = set(oq.outlier_idx.tolist()) + + oq.calibrate(calib) + + calibrated_outlier_idx = set(oq.outlier_idx.tolist()) + + # All injected outlier channels should now be classified as outliers + for ch in outlier_channels: + assert ch in calibrated_outlier_idx, ( + f"Channel {ch} (amplified 20×) not identified as outlier after calibration" + ) + + # The calibrated split should differ from the fixed default + assert calibrated_outlier_idx != default_outlier_idx, ( + "calibrate() produced the same channel split as the fixed default — " + "expected a different split for data with injected outlier channels" + ) + + # Consistency check + assert oq.n_outlier + oq.n_normal == d + + def test_calibrate_no_outliers(self): + """calibrate() on uniform data should find zero or very few outlier channels.""" + from turboquant.outlier import OutlierTurboQuant + + d = 64 + rng = np.random.default_rng(99) + # All channels have equal variance — no outliers expected + calib = rng.standard_normal((1000, d)) + + oq = OutlierTurboQuant(d=d, target_bits=2.5, seed=1) + oq.calibrate(calib) + + # With uniform data, per-channel RMS should all be close to 1. + # 3× median ≈ 3 threshold means essentially no channel exceeds it. + # Allow a small fraction due to sampling variance. + outlier_fraction = oq.n_outlier / d + assert outlier_fraction < 0.1, ( + f"Expected <10% outliers on uniform data, got {outlier_fraction:.1%} " + f"({oq.n_outlier}/{d})" + ) + + def test_calibrate_preserves_default_without_call(self): + """Without calling calibrate(), the fixed split is unchanged.""" + from turboquant.outlier import OutlierTurboQuant + + d = 128 + oq = OutlierTurboQuant(d=d, target_bits=2.5, seed=42) + + # Default: first n_outlier channels + expected_outlier = np.arange(oq.n_outlier) + np.testing.assert_array_equal(oq.outlier_idx, expected_outlier) + + def test_calibrate_updates_counts(self): + """After calibrate(), n_outlier and n_normal should reflect new split.""" + from turboquant.outlier import OutlierTurboQuant + + d = 64 + rng = np.random.default_rng(5) + calib = rng.standard_normal((200, d)) + calib[:, 0] *= 50.0 # one strong outlier + + oq = OutlierTurboQuant(d=d, target_bits=3.5, seed=3) + oq.calibrate(calib) + + assert oq.n_outlier + oq.n_normal == d + assert len(oq.outlier_idx) == oq.n_outlier + assert len(oq.normal_idx) == oq.n_normal diff --git a/turboquant/outlier.py b/turboquant/outlier.py index b3f11986d..e42170886 100644 --- a/turboquant/outlier.py +++ b/turboquant/outlier.py @@ -9,12 +9,15 @@ - 3.5-bit: 64/128 outlier at 4b + 64/128 normal at 3b = (64×4 + 64×3)/128 = 3.5 """ +import logging import numpy as np from dataclasses import dataclass from turboquant.polar_quant import PolarQuant from turboquant.qjl import QJL +logger = logging.getLogger(__name__) + @dataclass class OutlierCompressedVector: @@ -93,6 +96,46 @@ def __init__(self, d: int, target_bits: float, seed: int = 42): # QJL on full residual self.qjl = QJL(d, seed=seed + 1000) + def calibrate(self, calibration_vectors: np.ndarray) -> None: + """Update outlier/normal channel split using calibration data. + + Computes per-channel RMS over the calibration samples and identifies + outlier channels dynamically: channels where RMS > 3× median RMS. + Updates self.outlier_idx and self.normal_idx accordingly. + + The existing fixed-split behaviour (channels 0..n_outlier-1 as outliers) + is used when this method is never called. + + Args: + calibration_vectors: 2D array of shape (n_samples, d) used to + estimate per-channel activation magnitudes. + """ + assert calibration_vectors.ndim == 2, ( + f"calibration_vectors must be 2D (n_samples, d), got {calibration_vectors.ndim}D" + ) + assert calibration_vectors.shape[1] == self.d, ( + f"calibration_vectors.shape[1]={calibration_vectors.shape[1]} != d={self.d}" + ) + + # Per-channel RMS: sqrt(mean(x^2)) over samples + per_channel_rms = np.sqrt(np.mean(calibration_vectors ** 2, axis=0)) # (d,) + + median_rms = np.median(per_channel_rms) + threshold = 3.0 * median_rms + + outlier_mask = per_channel_rms > threshold + n_found = int(outlier_mask.sum()) + + logger.info( + "calibrate(): found %d outlier channels out of %d (threshold=%.4f, median_rms=%.4f)", + n_found, self.d, threshold, median_rms, + ) + + self.outlier_idx = np.where(outlier_mask)[0] + self.normal_idx = np.where(~outlier_mask)[0] + self.n_outlier = len(self.outlier_idx) + self.n_normal = len(self.normal_idx) + def quantize(self, x: np.ndarray) -> OutlierCompressedVector: """Quantize with outlier channel split.""" single = x.ndim == 1