diff --git a/tests/test_rotation.py b/tests/test_rotation.py index 520fd57fc..c3556cc47 100644 --- a/tests/test_rotation.py +++ b/tests/test_rotation.py @@ -335,3 +335,94 @@ def test_batch_matches_single(self): for i in range(10): single_result = apply_fast_rotation(X[i], signs1, signs2, padded_d) np.testing.assert_allclose(batch_result[i], single_result, atol=1e-10) + + +class TestFastRotationExtended: + """Additional tests for fast rotation: round-trip, batch, and Gaussianization.""" + + @pytest.mark.parametrize("d", [64, 128, 256]) + def test_fast_rotation_round_trip(self, d): + """apply_fast_rotation_transpose(apply_fast_rotation(x)) ≈ x for power-of-2 sizes.""" + from turboquant.rotation import ( + random_rotation_fast, apply_fast_rotation, apply_fast_rotation_transpose + ) + + rng = np.random.default_rng(42) + signs1, signs2, padded_d = random_rotation_fast(d, rng) + + rng_vec = np.random.default_rng(7) + for _ in range(20): + x = rng_vec.standard_normal(d) + y = apply_fast_rotation(x, signs1, signs2, padded_d) + x_back = apply_fast_rotation_transpose(y, signs1, signs2, padded_d) + np.testing.assert_allclose( + x_back, x, atol=1e-10, + err_msg=f"Round-trip failed for d={d}" + ) + + @pytest.mark.parametrize("d", [64, 128, 256]) + def test_fast_rotation_batch_matches_single(self, d): + """apply_fast_rotation_batch applied to a batch matches apply_fast_rotation element-wise.""" + from turboquant.rotation import ( + random_rotation_fast, apply_fast_rotation, apply_fast_rotation_batch + ) + + rng = np.random.default_rng(99) + signs1, signs2, padded_d = random_rotation_fast(d, rng) + + rng_vec = np.random.default_rng(11) + X = rng_vec.standard_normal((8, d)) + + batch_result = apply_fast_rotation_batch(X, signs1, signs2, padded_d) + assert batch_result.shape == (8, d) + + for i in range(8): + single_result = apply_fast_rotation(X[i], signs1, signs2, padded_d) + np.testing.assert_allclose( + batch_result[i], single_result, atol=1e-10, + err_msg=f"Batch vs single mismatch at index {i}, d={d}" + ) + + @pytest.mark.parametrize("d", [64, 128, 256]) + def test_fast_rotation_distributes_energy(self, d): + """Post-rotation coordinates should be approximately zero-mean with variance ≈ 1/d. + + This verifies the Gaussianize property: the structured rotation (D@H@D) spreads + energy uniformly across dimensions. We apply the same rotation to many random + vectors and check that each output coordinate has mean ≈ 0 and variance ≈ 1/d. + """ + from turboquant.rotation import random_rotation_fast, apply_fast_rotation + + rng = np.random.default_rng(42) + signs1, signs2, padded_d = random_rotation_fast(d, rng) + + n_samples = 2000 + rng_vec = np.random.default_rng(55) + # Use unit vectors so norms don't dominate + X = rng_vec.standard_normal((n_samples, d)) + X = X / np.linalg.norm(X, axis=1, keepdims=True) + + rotated = np.stack([ + apply_fast_rotation(X[i], signs1, signs2, padded_d) + for i in range(n_samples) + ]) + + # Each coordinate should be approximately zero-mean + coord_means = rotated.mean(axis=0) + mean_bound = 4 * np.sqrt(1.0 / d / n_samples) + assert np.all(np.abs(coord_means) < max(mean_bound, 0.05)), ( + f"Max coordinate mean {np.max(np.abs(coord_means)):.4f} exceeds bound " + f"{max(mean_bound, 0.05):.4f} (d={d})" + ) + + # Each coordinate should have variance ≈ 1/d (energy spread uniformly) + coord_vars = rotated.var(axis=0) + expected_var = 1.0 / d + assert np.all(coord_vars < expected_var * 2.0), ( + f"Max coordinate variance {np.max(coord_vars):.6f} exceeds 2× expected " + f"{expected_var:.6f} (d={d})" + ) + assert np.all(coord_vars > expected_var * 0.3), ( + f"Min coordinate variance {np.min(coord_vars):.6f} is below 0.3× expected " + f"{expected_var:.6f} (d={d})" + )