Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions tests/test_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,94 @@ def test_batch_matches_single(self):
for i in range(10):
single_result = apply_fast_rotation(X[i], signs1, signs2, padded_d)
np.testing.assert_allclose(batch_result[i], single_result, atol=1e-10)


class TestFastRotationExtended:
"""Additional tests for fast rotation: round-trip, batch, and Gaussianization."""

@pytest.mark.parametrize("d", [64, 128, 256])
def test_fast_rotation_round_trip(self, d):
"""apply_fast_rotation_transpose(apply_fast_rotation(x)) ≈ x for power-of-2 sizes."""
from turboquant.rotation import (
random_rotation_fast, apply_fast_rotation, apply_fast_rotation_transpose
)

rng = np.random.default_rng(42)
signs1, signs2, padded_d = random_rotation_fast(d, rng)

rng_vec = np.random.default_rng(7)
for _ in range(20):
x = rng_vec.standard_normal(d)
y = apply_fast_rotation(x, signs1, signs2, padded_d)
x_back = apply_fast_rotation_transpose(y, signs1, signs2, padded_d)
np.testing.assert_allclose(
x_back, x, atol=1e-10,
err_msg=f"Round-trip failed for d={d}"
)

@pytest.mark.parametrize("d", [64, 128, 256])
def test_fast_rotation_batch_matches_single(self, d):
"""apply_fast_rotation_batch applied to a batch matches apply_fast_rotation element-wise."""
from turboquant.rotation import (
random_rotation_fast, apply_fast_rotation, apply_fast_rotation_batch
)

rng = np.random.default_rng(99)
signs1, signs2, padded_d = random_rotation_fast(d, rng)

rng_vec = np.random.default_rng(11)
X = rng_vec.standard_normal((8, d))

batch_result = apply_fast_rotation_batch(X, signs1, signs2, padded_d)
assert batch_result.shape == (8, d)

for i in range(8):
single_result = apply_fast_rotation(X[i], signs1, signs2, padded_d)
np.testing.assert_allclose(
batch_result[i], single_result, atol=1e-10,
err_msg=f"Batch vs single mismatch at index {i}, d={d}"
)

@pytest.mark.parametrize("d", [64, 128, 256])
def test_fast_rotation_distributes_energy(self, d):
"""Post-rotation coordinates should be approximately zero-mean with variance ≈ 1/d.

This verifies the Gaussianize property: the structured rotation (D@H@D) spreads
energy uniformly across dimensions. We apply the same rotation to many random
vectors and check that each output coordinate has mean ≈ 0 and variance ≈ 1/d.
"""
from turboquant.rotation import random_rotation_fast, apply_fast_rotation

rng = np.random.default_rng(42)
signs1, signs2, padded_d = random_rotation_fast(d, rng)

n_samples = 2000
rng_vec = np.random.default_rng(55)
# Use unit vectors so norms don't dominate
X = rng_vec.standard_normal((n_samples, d))
X = X / np.linalg.norm(X, axis=1, keepdims=True)

rotated = np.stack([
apply_fast_rotation(X[i], signs1, signs2, padded_d)
for i in range(n_samples)
])

# Each coordinate should be approximately zero-mean
coord_means = rotated.mean(axis=0)
mean_bound = 4 * np.sqrt(1.0 / d / n_samples)
assert np.all(np.abs(coord_means) < max(mean_bound, 0.05)), (
f"Max coordinate mean {np.max(np.abs(coord_means)):.4f} exceeds bound "
f"{max(mean_bound, 0.05):.4f} (d={d})"
)

# Each coordinate should have variance ≈ 1/d (energy spread uniformly)
coord_vars = rotated.var(axis=0)
expected_var = 1.0 / d
assert np.all(coord_vars < expected_var * 2.0), (
f"Max coordinate variance {np.max(coord_vars):.6f} exceeds 2× expected "
f"{expected_var:.6f} (d={d})"
)
assert np.all(coord_vars > expected_var * 0.3), (
f"Min coordinate variance {np.min(coord_vars):.6f} is below 0.3× expected "
f"{expected_var:.6f} (d={d})"
)