Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions tests/test_outlier.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,93 @@ def test_deterministic(self):
r1 = oq1.dequantize(c1)
r2 = oq2.dequantize(c2)
np.testing.assert_allclose(r1, r2, atol=1e-15)


class TestCalibrate:
    """Exercise the data-driven channel split done by OutlierTurboQuant.calibrate()."""

    def test_calibrate_finds_known_outlier_channels(self):
        """Channels given a much larger RMS must end up in the outlier set."""
        from turboquant.outlier import OutlierTurboQuant

        d = 128
        outlier_channels = [10, 20, 30]

        # N(0, 1) baseline; the chosen channels are scaled 20x so their RMS
        # (~20) sits far above the 3 x median (~3) threshold.
        gen = np.random.default_rng(42)
        samples = gen.standard_normal((500, d))
        samples[:, outlier_channels] *= 20.0

        oq = OutlierTurboQuant(d=d, target_bits=2.5, seed=7)

        # Snapshot the constructor's fixed split, then recalibrate.
        split_before = set(oq.outlier_idx.tolist())
        oq.calibrate(samples)
        split_after = set(oq.outlier_idx.tolist())

        # Every amplified channel has to be flagged.
        for ch in outlier_channels:
            assert ch in split_after, (
                f"Channel {ch} (amplified 20×) not identified as outlier after calibration"
            )

        # Calibration must actually have moved the split.
        assert split_after != split_before, (
            "calibrate() produced the same channel split as the fixed default — "
            "expected a different split for data with injected outlier channels"
        )

        # The two groups together still account for every channel.
        assert oq.n_outlier + oq.n_normal == d

    def test_calibrate_no_outliers(self):
        """Equal-variance data should yield no (or almost no) outlier channels."""
        from turboquant.outlier import OutlierTurboQuant

        d = 64
        # Every channel is drawn from the same N(0, 1) distribution.
        samples = np.random.default_rng(99).standard_normal((1000, d))

        oq = OutlierTurboQuant(d=d, target_bits=2.5, seed=1)
        oq.calibrate(samples)

        # All per-channel RMS values are close to 1, so the 3 x median
        # threshold (~3) should be exceeded by essentially nothing; a small
        # tolerance covers sampling noise.
        outlier_fraction = oq.n_outlier / d
        assert outlier_fraction < 0.1, (
            f"Expected <10% outliers on uniform data, got {outlier_fraction:.1%} "
            f"({oq.n_outlier}/{d})"
        )

    def test_calibrate_preserves_default_without_call(self):
        """A freshly constructed quantizer keeps the fixed leading-channel split."""
        from turboquant.outlier import OutlierTurboQuant

        oq = OutlierTurboQuant(d=128, target_bits=2.5, seed=42)

        # Without calibrate(), the outliers are simply the first n_outlier channels.
        np.testing.assert_array_equal(oq.outlier_idx, np.arange(oq.n_outlier))

    def test_calibrate_updates_counts(self):
        """n_outlier / n_normal must track the index arrays after calibrate()."""
        from turboquant.outlier import OutlierTurboQuant

        d = 64
        samples = np.random.default_rng(5).standard_normal((200, d))
        samples[:, 0] *= 50.0  # a single, very strong outlier channel

        oq = OutlierTurboQuant(d=d, target_bits=3.5, seed=3)
        oq.calibrate(samples)

        assert oq.n_outlier + oq.n_normal == d
        assert len(oq.outlier_idx) == oq.n_outlier
        assert len(oq.normal_idx) == oq.n_normal
43 changes: 43 additions & 0 deletions turboquant/outlier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
- 3.5-bit: 64/128 outlier at 4b + 64/128 normal at 3b = (64×4 + 64×3)/128 = 3.5
"""

import logging
import numpy as np
from dataclasses import dataclass

from turboquant.polar_quant import PolarQuant
from turboquant.qjl import QJL

logger = logging.getLogger(__name__)


@dataclass
class OutlierCompressedVector:
Expand Down Expand Up @@ -93,6 +96,46 @@ def __init__(self, d: int, target_bits: float, seed: int = 42):
# QJL on full residual
self.qjl = QJL(d, seed=seed + 1000)

def calibrate(self, calibration_vectors: np.ndarray) -> None:
    """Update outlier/normal channel split using calibration data.

    Computes the per-channel RMS over the calibration samples and marks as
    outliers every channel whose RMS is strictly greater than 3× the median
    RMS. Overwrites ``self.outlier_idx``, ``self.normal_idx``,
    ``self.n_outlier`` and ``self.n_normal`` accordingly.

    The existing fixed-split behaviour (channels 0..n_outlier-1 as outliers)
    is used when this method is never called.

    The squares are computed in float64 regardless of the input dtype.
    Squaring in the input dtype overflows for float16 activations larger
    than ~256 (every RMS becomes ``inf`` and nothing is detected) and wraps
    around for small integer dtypes.

    Args:
        calibration_vectors: 2D array-like of shape (n_samples, d) used to
            estimate per-channel activation magnitudes.

    Raises:
        AssertionError: if the input is not 2D or its second dimension
            differs from ``self.d``.
    """
    # asarray is a no-op for ndarrays and lets plain nested sequences through.
    vectors = np.asarray(calibration_vectors)

    assert vectors.ndim == 2, (
        f"calibration_vectors must be 2D (n_samples, d), got {vectors.ndim}D"
    )
    assert vectors.shape[1] == self.d, (
        f"calibration_vectors.shape[1]={vectors.shape[1]} != d={self.d}"
    )

    # Per-channel RMS: sqrt(mean(x^2)) over samples. dtype=float64 makes the
    # ufunc square in double precision so low-precision inputs cannot
    # overflow or wrap before the mean is taken.
    mean_square = np.mean(np.square(vectors, dtype=np.float64), axis=0)  # (d,)
    per_channel_rms = np.sqrt(mean_square)

    median_rms = np.median(per_channel_rms)
    threshold = 3.0 * median_rms

    outlier_mask = per_channel_rms > threshold
    n_found = int(outlier_mask.sum())

    logger.info(
        "calibrate(): found %d outlier channels out of %d (threshold=%.4f, median_rms=%.4f)",
        n_found, self.d, threshold, median_rms,
    )

    # NOTE(review): only the index arrays and counts are updated here. Any
    # sub-quantizers that __init__ sized from the old n_outlier / n_normal are
    # not rebuilt in this method — confirm quantize()/dequantize() cope with
    # the new split.
    self.outlier_idx = np.where(outlier_mask)[0]
    self.normal_idx = np.where(~outlier_mask)[0]
    self.n_outlier = len(self.outlier_idx)
    self.n_normal = len(self.normal_idx)

def quantize(self, x: np.ndarray) -> OutlierCompressedVector:
"""Quantize with outlier channel split."""
single = x.ndim == 1
Expand Down