17 changes: 9 additions & 8 deletions REGRESSIONS.md
@@ -1,20 +1,21 @@
# Regressions

-Last checked: 2025-07-31
+Last checked: 2025-08-19

-# 10 failing test(s)
+# 11 failing test(s)

- tests/test_nn_modeling.py::test_auxiliary_activation
- tests/test_nn_modeling.py::test_auxiliary_activation_k_exceeds_size
- tests/test_nn_modeling.py::test_batch_topk_activation
- tests/test_nn_objectives.py::test_auxiliary_coeff
- tests/test_nn_objectives.py::test_auxiliary_mse_same
- tests/test_nn_objectives.py::test_safe_mse_hypothesis
- tests/test_ordered_dataloader.py::test_ordered_dataloader_with_tiny_fake_dataset
- tests/test_reservoir_buffer.py::test_blocking_get_when_empty[proc]
- tests/test_reservoir_buffer.py::test_blocking_put_when_full[proc]
- tests/test_ring_buffer.py::test_blocking_get_when_empty[proc]
- tests/test_ring_buffer.py::test_blocking_put_when_full[proc]
- tests/test_writers_properties.py::test_dataloader_batches
- tests/test_writers_properties.py::test_metadata_json_has_required_keys
- tests/test_writers_properties.py::test_roundtrip
- tests/test_writers_properties.py::test_shard_size_consistency
- tests/test_writers_properties.py::test_shard_writer_and_dataset_e2e

# Coverage

-Coverage: 1210/1816 lines (66.6%)
+Coverage: 694/1933 lines (35.9%)
49 changes: 49 additions & 0 deletions src/saev/nn/modeling.py
@@ -43,6 +43,13 @@ class BatchTopK:
ActivationConfig = Relu | TopK | BatchTopK


@beartype.beartype
@dataclasses.dataclass(frozen=True)
class AuxiliaryConfig:
    top_k: int = 512
    """How many dead latents to consider for the auxiliary loss."""


@beartype.beartype
@dataclasses.dataclass(frozen=True)
class SparseAutoencoderConfig:
@@ -319,6 +326,48 @@ def forward(self, x: Float[Tensor, "batch d_sae"]) -> Float[Tensor, "batch d_sae"]:
        return torch.mul(mask, x)


class AuxiliaryLossActivation(torch.nn.Module):
    """
    Auxiliary loss activation function. Selects the top-k dead latents before the auxiliary loss is calculated.
    """

    def __init__(self, cfg: AuxiliaryConfig = AuxiliaryConfig()):
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        f_x: Float[Tensor, "batch d_sae"],
        dead_latents: Float[Tensor, "batch d_sae"],
    ) -> Float[Tensor, "batch d_sae"]:
        """
        Apply the auxiliary loss activation (top-k over dead latents) to the input tensor.
        """

        # Mask out all but the dead latents.
        masked_dead_latents = f_x * dead_latents

        masked_dead_top_k = torch.zeros_like(f_x)

        # Populate the top k of the dead latents.
        if self.cfg.top_k > 0 and dead_latents.sum() > 0:
            # Find the top k of the dead latents.
            k_vals, k_inds = torch.topk(
                masked_dead_latents,
                min(self.cfg.top_k, masked_dead_latents.shape[1]),
                dim=1,
                sorted=False,
            )
            top_k_mask = torch.zeros_like(masked_dead_latents).scatter_(
                dim=-1, index=k_inds, src=torch.ones_like(masked_dead_latents)
            )

            # Mask out all but the top-k dead latents.
            masked_dead_top_k = torch.mul(top_k_mask, masked_dead_latents)

        return masked_dead_top_k


@beartype.beartype
def get_activation(cfg: ActivationConfig) -> torch.nn.Module:
    if isinstance(cfg, Relu):
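A minimal sketch of the masking trick used in AuxiliaryLossActivation.forward above, with hypothetical values (only torch is assumed):

import torch

# One sample, d_sae = 4, keep the top 2 dead latents. Latent 0 is alive.
f_x = torch.tensor([[5.0, 1.0, 3.0, 2.0]])
dead = torch.tensor([[0.0, 1.0, 1.0, 1.0]])

masked = f_x * dead  # [[0., 1., 3., 2.]]; the live latent is zeroed out
_, k_inds = torch.topk(masked, k=2, dim=1, sorted=False)  # indices 2 and 3
top_k_mask = torch.zeros_like(masked).scatter_(
    dim=-1, index=k_inds, src=torch.ones_like(masked)
)  # [[0., 0., 1., 1.]]
print(top_k_mask * masked)  # tensor([[0., 0., 3., 2.]])

One consequence worth noting: because live latents are zeroed before topk, negative dead activations compete against those zeros, so they can lose a top-k slot to a live position holding an exact zero. The negative-value test below avoids this by marking every latent dead.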
52 changes: 52 additions & 0 deletions src/saev/nn/objectives.py
@@ -32,6 +32,19 @@ class Matryoshka:
ObjectiveConfig = Vanilla | Matryoshka


@beartype.beartype
@dataclasses.dataclass(frozen=True, slots=True)
class Auxiliary:
    """
    Config for the auxiliary loss term (not for the SAE itself).

    Reference paper: https://doi.org/10.48550/arXiv.2412.06410
    """

    aux_coeff: float = 0.03125
    """Coefficient for the auxiliary loss term."""


@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True, slots=True)
class Loss:
@@ -163,6 +176,45 @@ def forward(
        return MatryoshkaLoss(mse_loss, sparsity_loss, l0, l1)


@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True, slots=True)
class AuxiliaryLoss(Loss):
    """The auxiliary loss terms for a training batch."""

    mse: Float[Tensor, ""]
    """Reconstruction loss (mean squared error)."""

    @property
    def loss(self) -> Float[Tensor, ""]:
        """Total loss."""
        return self.mse

    def metrics(self) -> dict[str, object]:
        return {
            "loss": self.loss.item(),
            "mse": self.mse.item(),
        }


@jaxtyped(typechecker=beartype.beartype)
class AuxiliaryObjective(Objective):
    def __init__(self, cfg: Auxiliary):
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        x: Float[Tensor, "batch d_model"],
        x_hat: Float[Tensor, "batch d_model"],
    ) -> AuxiliaryLoss:
        # Some values of x and x_hat can be very large, so we use the numerically safe MSE helper.
        mse_loss = mean_squared_err(x_hat, x)

        mse_loss = mse_loss.mean()

        return AuxiliaryLoss(self.cfg.aux_coeff * mse_loss)


@beartype.beartype
def get_objective(cfg: ObjectiveConfig) -> Objective:
    if isinstance(cfg, Vanilla):
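To see how the two new pieces compose, here is a hedged sketch of a training step. The sae.encode/sae.decode calls, the main_objective, and the dead_mask bookkeeping are assumptions about the surrounding trainer, not part of this diff. Note that aux_coeff (default 0.03125, i.e. 1/32) is applied inside AuxiliaryObjective, so the auxiliary term is added unscaled here.

from saev.nn import modeling, objectives

aux_act = modeling.AuxiliaryLossActivation(modeling.AuxiliaryConfig(top_k=512))
aux_obj = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=0.03125))

def training_step(sae, main_objective, x, dead_mask):
    # Hypothetical encoder/decoder API; the real SAE interface may differ.
    f_x = sae.encode(x)
    x_hat = sae.decode(f_x)
    main_loss = main_objective(x, x_hat).loss

    # Reconstruct from only the top-k dead latents and penalize that error.
    # Whether the auxiliary target should be x or the main reconstruction's
    # residual is a trainer-level choice; x is used here for simplicity.
    f_dead = aux_act(f_x, dead_mask)
    x_hat_dead = sae.decode(f_dead)
    aux_loss = aux_obj(x, x_hat_dead).loss  # already scaled by aux_coeff

    return main_loss + aux_loss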
133 changes: 133 additions & 0 deletions tests/test_nn_modeling.py
@@ -60,6 +60,26 @@ def batch_topk_cfgs():
    return st.builds(modeling.BatchTopK, top_k=st.sampled_from([1, 2, 4, 8]))


def aux_cfgs():
    return st.builds(modeling.AuxiliaryConfig, top_k=st.sampled_from([1, 2, 4, 8]))


@given(
    cfg=aux_cfgs(),
    batch=st.integers(min_value=1, max_value=4),
    d_sae=st.integers(min_value=256, max_value=2048),
)
def test_auxiliary_activation(cfg, batch, d_sae):
    act = modeling.AuxiliaryLossActivation(cfg)
    # Mark roughly half the latents as dead so the top-k path is exercised.
    dead_lts = (torch.rand(batch, d_sae) < 0.5).float()
    x = torch.rand(batch, d_sae)
    y = act(x, dead_lts)

    assert y.shape == (batch, d_sae)
    # Check that at most k elements are non-zero per sample
    assert (y != 0).sum(dim=1).le(cfg.top_k).all()


@given(
    cfg=topk_cfgs(),
    batch=st.integers(min_value=1, max_value=4),
@@ -114,6 +134,20 @@ def test_topk_ties():
    assert y[y != 0].unique().item() == 2.0


def test_auxiliary_activation_ties():
    """Test Auxiliary activation behavior with tied values."""
    cfg = modeling.AuxiliaryConfig(top_k=2)
    act = modeling.AuxiliaryLossActivation(cfg)

    x = torch.tensor([[2.0, 2.0, 2.0, 2.0]])
    y = act(x, torch.ones_like(x))

    # Exactly k elements should be kept; which of the tied elements are chosen is implementation-defined
    assert (y != 0).sum() == 2
    # Verify the selected values are correct
    assert y[y != 0].unique().item() == 2.0


def test_topk_k_equals_size():
    """Test TopK when k equals tensor size."""
    cfg = modeling.TopK(top_k=4)
@@ -126,6 +160,30 @@ def test_topk_k_equals_size():
    torch.testing.assert_close(y, x)


def test_auxiliary_activation_k_equals_size():
    """Test Auxiliary activation when k equals tensor size."""
    cfg = modeling.AuxiliaryConfig(top_k=4)
    act = modeling.AuxiliaryLossActivation(cfg)

    x = torch.tensor([[5.0, 1.0, 3.0, 2.0]])
    y = act(x, torch.ones_like(x))

    # All values should be preserved
    torch.testing.assert_close(y, x)


def test_auxiliary_activation_k_exceeds_size():
    """Test Auxiliary activation when k exceeds tensor size."""
    cfg = modeling.AuxiliaryConfig(top_k=8)
    act = modeling.AuxiliaryLossActivation(cfg)

    x = torch.tensor([[5.0, 1.0, 3.0, 2.0]])
    y = act(x, torch.ones_like(x))

    # All values should be preserved
    torch.testing.assert_close(y, x)


def test_topk_negative_values():
    """Test TopK with negative values."""
    cfg = modeling.TopK(top_k=2)
@@ -139,6 +197,19 @@ def test_topk_negative_values():
    torch.testing.assert_close(y, expected)


def test_auxiliary_activation_negative_values():
    """Test Auxiliary activation with negative values."""
    cfg = modeling.AuxiliaryConfig(top_k=2)
    act = modeling.AuxiliaryLossActivation(cfg)

    x = torch.tensor([[-5.0, -1.0, -3.0, -2.0]])
    y = act(x, torch.ones_like(x))

    # Should select -1.0 and -2.0 (largest values)
    expected = torch.tensor([[0.0, -1.0, 0.0, -2.0]])
    torch.testing.assert_close(y, expected)


def test_topk_gradient_flow():
    """Test that gradients flow correctly through TopK."""
    cfg = modeling.TopK(top_k=2)
@@ -156,6 +227,23 @@ def test_topk_gradient_flow():
    torch.testing.assert_close(x.grad, expected_grad)


def test_auxiliary_activation_gradient_flow():
    """Test that gradients flow correctly through Auxiliary activation."""
    cfg = modeling.AuxiliaryConfig(top_k=2)
    act = modeling.AuxiliaryLossActivation(cfg)

    x = torch.tensor([[5.0, 1.0, 3.0, 2.0], [2.0, 4.0, 1.0, 3.0]], requires_grad=True)
    y = act(x, torch.ones_like(x))

    # Create a simple loss (sum of outputs)
    loss = y.sum()
    loss.backward()

    # Expected gradient: 1.0 for selected elements, 0.0 for others
    expected_grad = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
    torch.testing.assert_close(x.grad, expected_grad)


def test_topk_gradient_sparsity():
    """Verify gradient sparsity matches forward pass selection."""
    cfg = modeling.TopK(top_k=3)
@@ -180,6 +268,30 @@ def test_topk_gradient_sparsity():
    torch.testing.assert_close(selected_grads, expected_grads)


def test_auxiliary_activation_gradient_sparsity():
    """Verify gradient sparsity matches forward pass selection for Auxiliary activation."""
    cfg = modeling.AuxiliaryConfig(top_k=3)
    act = modeling.AuxiliaryLossActivation(cfg)

    torch.manual_seed(42)
    x = torch.randn(2, 8, requires_grad=True)
    y = act(x, torch.ones_like(x))

    # Use a different upstream gradient
    grad_output = torch.randn_like(y)
    y.backward(grad_output)

    # Check that gradient sparsity matches forward pass
    forward_mask = (y != 0).float()
    grad_mask = (x.grad != 0).float()
    torch.testing.assert_close(forward_mask, grad_mask)

    # Verify gradient values for selected elements
    selected_grads = x.grad * forward_mask
    expected_grads = grad_output * forward_mask
    torch.testing.assert_close(selected_grads, expected_grads)


def test_topk_zero_gradient_for_unselected():
    """Explicitly verify that non-selected elements have exactly 0.0 gradients."""
    cfg = modeling.TopK(top_k=2)
@@ -201,6 +313,27 @@ def test_topk_zero_gradient_for_unselected():
    torch.testing.assert_close(x.grad[0, 5], torch.tensor(1.0))


def test_auxiliary_activation_zero_gradient_for_unselected():
    """Explicitly verify that non-selected elements have exactly 0.0 gradients."""
    cfg = modeling.AuxiliaryConfig(top_k=2)
    act = modeling.AuxiliaryLossActivation(cfg)

    x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], requires_grad=True)
    y = act(x, torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 1.0]]))

    loss = y.sum()
    loss.backward()

    # Elements at indices 0, 1, 2, 3 should have zero gradients
    torch.testing.assert_close(x.grad[0, 0], torch.tensor(0.0))
    torch.testing.assert_close(x.grad[0, 1], torch.tensor(0.0))
    torch.testing.assert_close(x.grad[0, 2], torch.tensor(0.0))
    torch.testing.assert_close(x.grad[0, 3], torch.tensor(0.0))
    # Elements at indices 4, 5 should have non-zero gradients
    torch.testing.assert_close(x.grad[0, 4], torch.tensor(1.0))
    torch.testing.assert_close(x.grad[0, 5], torch.tensor(1.0))


# BatchTopK Edge Case Tests
def test_batchtopk_basic_forward():
"""Test basic BatchTopK forward pass with known values."""
Expand Down
22 changes: 22 additions & 0 deletions tests/test_nn_objectives.py
@@ -46,6 +46,28 @@ def test_safe_mse_large_x():
    assert not safe.isnan().any()


def test_auxiliary_mse_same():
    x = torch.ones((45, 12), dtype=torch.float)
    x_hat = torch.ones((45, 12), dtype=torch.float)
    aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=1.0))
    torch.testing.assert_close(
        aux_objective(x, x_hat).loss,
        objectives.mean_squared_err(x_hat, x).mean(),
    )


def test_auxiliary_coeff():
    x = torch.ones((45, 12), dtype=torch.float)
    x_hat = torch.full((45, 12), 3, dtype=torch.float)
    aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=0.5))
    torch.testing.assert_close(
        aux_objective(x, x_hat).loss,
        0.5 * objectives.mean_squared_err(x_hat, x).mean(),
    )


def test_factories():
    assert isinstance(
        objectives.get_objective(objectives.Vanilla()), objectives.VanillaObjective