From 254daa074c4a354a4a87b9f4fd71baf1bd0daf48 Mon Sep 17 00:00:00 2001 From: Jake Beattie Date: Sun, 17 Aug 2025 16:53:22 -0400 Subject: [PATCH 1/3] Implemented auxiliary loss objective Needs testing still --- src/saev/nn/modeling.py | 42 ++++++++++++++++++++++++++++++++ src/saev/nn/objectives.py | 50 ++++++++++++++++++++++++++++++++++++++ tests/test_nn_modeling.py | 19 +++++++++++++++ train.py | 51 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 160 insertions(+), 2 deletions(-) diff --git a/src/saev/nn/modeling.py b/src/saev/nn/modeling.py index 5b480523..af9ea689 100644 --- a/src/saev/nn/modeling.py +++ b/src/saev/nn/modeling.py @@ -43,6 +43,13 @@ class BatchTopK: ActivationConfig = Relu | TopK | BatchTopK +@beartype.beartype +@dataclasses.dataclass(frozen=True) +class AuxiliaryConfig: + top_k: int = 512 + """How many dead latents to consider for auxiliary loss.""" + + @beartype.beartype @dataclasses.dataclass(frozen=True) class SparseAutoencoderConfig: @@ -318,6 +325,41 @@ def forward(self, x: Float[Tensor, "batch d_sae"]) -> Float[Tensor, "batch d_sae return torch.mul(mask, x) +class AuxiliaryLossActivation(torch.nn.Module): + """ + Auxiliary loss activation function. Used to take the top-k dead latents before calculating the auxiliary loss. + """ + + def __init__(self, cfg: AuxiliaryConfig = AuxiliaryConfig()): + super().__init__() + self.cfg = cfg + + def forward(self, f_x: Float[Tensor, "batch d_sae"], dead_latents: Float[Tensor, "batch d_sae"]) -> Float[Tensor, "batch d_sae"]: + """ + Apply auxiliary loss activation (top-k of dead latents) to the input tensor. + """ + + # First, mask out all but dead latents + f_x = f_x * dead_latents + + masked_dead_top_k = torch.zeros_like(f_x) + + # Now, populate top k of the dead latents + if self.cfg.top_k > 0 and dead_latents.sum() > 0: + # First, mask out dead latents + masked_dead_latents = f_x * dead_latents + + # Find top k of dead latents + k_vals, k_inds = torch.topk(masked_dead_latents, self.cfg.top_k, dim=1, sorted=False) + top_k_mask = torch.zeros_like(masked_dead_latents).scatter_( + dim=-1, index=k_inds, src=torch.ones_like(masked_dead_latents) + ) + + # Mask out all but top k dead latents + masked_dead_top_k = torch.mul(top_k_mask, f_x) + + return masked_dead_top_k + @beartype.beartype def get_activation(cfg: ActivationConfig) -> torch.nn.Module: diff --git a/src/saev/nn/objectives.py b/src/saev/nn/objectives.py index 3394a477..7159d57b 100644 --- a/src/saev/nn/objectives.py +++ b/src/saev/nn/objectives.py @@ -32,6 +32,18 @@ class Matryoshka: ObjectiveConfig = Vanilla | Matryoshka +@beartype.beartype +@dataclasses.dataclass(frozen=True, slots=True) +class Auxiliary: + """ + Config for the Auxiliary loss (not for the SAE itself, but for auxiliary loss). + + Reference paper is https://doi.org/10.48550/arXiv.2412.06410 + """ + + aux_coeff: float = 0.03125 + """Coefficient for the auxiliary loss term.""" + @jaxtyped(typechecker=beartype.beartype) @dataclasses.dataclass(frozen=True, slots=True) class Loss: @@ -163,6 +175,44 @@ def forward( return MatryoshkaLoss(mse_loss, sparsity_loss, l0, l1) +@jaxtyped(typechecker=beartype.beartype) +@dataclasses.dataclass(frozen=True, slots=True) +class AuxiliaryLoss(Loss): + """The vanilla loss terms for an training batch.""" + + mse: Float[Tensor, ""] + """Reconstruction loss (mean squared error).""" + + @property + def loss(self) -> Float[Tensor, ""]: + """Total loss.""" + return self.mse + + def metrics(self) -> dict[str, object]: + return { + "loss": self.loss.item(), + "mse": self.mse.item(), + } + + +@jaxtyped(typechecker=beartype.beartype) +class AuxiliaryObjective(Objective): + def __init__(self, cfg: Auxiliary): + super().__init__() + self.cfg = cfg + + def forward( + self, + x: Float[Tensor, "batch d_model"], + x_hat: Float[Tensor, "batch d_model"], + ) -> VanillaLoss: + # Some values of x and x_hat can be very large. We can calculate a safe MSE + mse_loss = mean_squared_err(x_hat, x) + + mse_loss = mse_loss.mean() + + return AuxiliaryLoss(self.cfg.aux_coeff * mse_loss) + @beartype.beartype def get_objective(cfg: ObjectiveConfig) -> Objective: if isinstance(cfg, Vanilla): diff --git a/tests/test_nn_modeling.py b/tests/test_nn_modeling.py index cf954b36..26cd1370 100644 --- a/tests/test_nn_modeling.py +++ b/tests/test_nn_modeling.py @@ -60,6 +60,25 @@ def batch_topk_cfgs(): return st.builds(modeling.BatchTopK, top_k=st.sampled_from([1, 2, 4, 8])) +def aux_cfgs(): + return st.builds(modeling.AuxiliaryConfig, top_k=st.sampled_from([1, 2, 4, 8])) + + +@given( + cfg=aux_cfgs(), + batch=st.integers(min_value=1, max_value=4), + d_sae=st.integers(min_value=256, max_value=2048), +) +def test_auxiliary_activation(cfg, batch, d_sae): + act = modeling.get_activation(cfg) + x = torch.rand(batch, d_sae) + y = act(x) + + assert y.shape == (batch, d_sae) + # Check that only k elements are non-zero per sample + assert (y != 0).sum(dim=1).eq(cfg.top_k).all() + + @given( cfg=topk_cfgs(), batch=st.integers(min_value=1, max_value=4), diff --git a/train.py b/train.py index 141a438e..2c8e54a9 100644 --- a/train.py +++ b/train.py @@ -63,6 +63,14 @@ class Config: default_factory=nn.objectives.Vanilla ) """SAE loss configuration.""" + auxiliary_loss: bool = False + """Auxiliary Loss configuration.""" + auxiliary_loss_coeff: float = 0.03125 + """Coefficient for the auxiliary loss term.""" + tokens_until_dead: int = 10_000_000 + """Number of tokens for feature to not fire after which a feature is considered dead.""" + dead_top_k: int = 512 + """Number of dead features to reconstruct from.""" n_sparsity_warmup: int = 0 """Number of sparsity coefficient warmup steps.""" lr: float = 0.0004 @@ -211,10 +219,35 @@ def train( objectives.train() objectives = objectives.to(cfg.device) + + aux_activations = [ + nn.modeling.AuxiliaryLossActivation( + nn.modeling.AuxiliaryConfig(c.dead_top_k) + ) + for c in cfgs + ] + + aux_objectives = [ + nn.objectives.AuxiliaryObjective( + nn.objectives.Auxiliary(c.auxiliary_loss_coeff) + ) + for c in cfgs + ] + global_step, n_patches_seen = 0, 0 p_dataloader, p_children, last_rb, last_t = None, None, 0, time.time() + iterations_dead = [torch.zeros( + (s.cfg.d_sae), dtype=torch.float, device=cfg.device) + for s in saes + ] + + dead_latents = [ + torch.zeros((s.cfg.d_sae), dtype=torch.float, device=cfg.device) + for s in saes + ] + for batch in helpers.progress(dataloader, every=cfg.log_every): if p_dataloader is None: p_dataloader = psutil.Process(dataloader.manager_pid) @@ -227,15 +260,29 @@ def train( losses = [] x_hats = [] f_xs = [] - for sae, objective in zip(saes, objectives): + for sae, objective, aux_activation, aux_objective, iters_dead, dead_lts in zip(saes, objectives, aux_activations, aux_objectives, iterations_dead, dead_latents): if isinstance(objective, nn.objectives.MatryoshkaObjective): # Specific case has to be given for Matryoshka SAEs since we need to decode several times with varying prefix lengths x_hat, f_x = sae.matryoshka_forward(acts_BD, cfg.n_prefixes) else: x_hat, f_x = sae(acts_BD) + x_hats.append(x_hat) f_xs.append(f_x) - losses.append(objective(acts_BD, f_x, x_hat)) + if cfg.auxiliary_loss: + # Auxiliary loss is a separate term from the main objective, so we add it separately. + aux_f_x = aux_activation(f_x, dead_latents=dead_lts) + aux_x_hat = sae.decode(aux_f_x) + + losses.append(objective(acts_BD, f_x, x_hat) + + aux_objective(acts_BD, aux_x_hat)) + else: + losses.append(objective(acts_BD, f_x, x_hat)) + + # Count if feature was dead this iteration, update dead latents mask + iters_dead += ((f_x.abs() > 1e-8).sum(0) == 0).float() * acts_BD.shape[0] + iters_dead[(f_x.abs() > 1e-8).sum(0) != 0] = 0 + dead_lts = (iters_dead > cfg.tokens_until_dead).sum(0).float() n_patches_seen += len(acts_BD) From 86033456fdfd07a16acd69950b11dd2212f0d1e2 Mon Sep 17 00:00:00 2001 From: Jake Beattie Date: Tue, 2 Sep 2025 11:02:51 -0400 Subject: [PATCH 2/3] Implementing tests for auxiliary loss --- REGRESSIONS.md | 17 +++--- src/saev/nn/modeling.py | 13 +++- src/saev/nn/objectives.py | 4 +- tests/test_nn_modeling.py | 117 +++++++++++++++++++++++++++++++++++- tests/test_nn_objectives.py | 22 +++++++ train.py | 29 +++++---- 6 files changed, 176 insertions(+), 26 deletions(-) diff --git a/REGRESSIONS.md b/REGRESSIONS.md index c189038d..faccd255 100644 --- a/REGRESSIONS.md +++ b/REGRESSIONS.md @@ -1,15 +1,16 @@ # Regressions -Last checked: 2025-07-31 +Last checked: 2025-08-19 -# 10 failing test(s) +# 11 failing test(s) +- tests/test_nn_modeling.py::test_auxiliary_activation +- tests/test_nn_modeling.py::test_auxiliary_activation_k_exceeds_size +- tests/test_nn_modeling.py::test_batch_topk_activation +- tests/test_nn_objectives.py::test_auxiliary_coeff +- tests/test_nn_objectives.py::test_auxiliary_mse_same - tests/test_nn_objectives.py::test_safe_mse_hypothesis -- tests/test_ordered_dataloader.py::test_ordered_dataloader_with_tiny_fake_dataset -- tests/test_reservoir_buffer.py::test_blocking_get_when_empty[proc] -- tests/test_reservoir_buffer.py::test_blocking_put_when_full[proc] -- tests/test_ring_buffer.py::test_blocking_get_when_empty[proc] -- tests/test_ring_buffer.py::test_blocking_put_when_full[proc] +- tests/test_writers_properties.py::test_dataloader_batches - tests/test_writers_properties.py::test_metadata_json_has_required_keys - tests/test_writers_properties.py::test_roundtrip - tests/test_writers_properties.py::test_shard_size_consistency @@ -17,4 +18,4 @@ Last checked: 2025-07-31 # Coverage -Coverage: 1210/1816 lines (66.6%) +Coverage: 694/1933 lines (35.9%) diff --git a/src/saev/nn/modeling.py b/src/saev/nn/modeling.py index af9ea689..7329209d 100644 --- a/src/saev/nn/modeling.py +++ b/src/saev/nn/modeling.py @@ -325,6 +325,7 @@ def forward(self, x: Float[Tensor, "batch d_sae"]) -> Float[Tensor, "batch d_sae return torch.mul(mask, x) + class AuxiliaryLossActivation(torch.nn.Module): """ Auxiliary loss activation function. Used to take the top-k dead latents before calculating the auxiliary loss. @@ -334,7 +335,11 @@ def __init__(self, cfg: AuxiliaryConfig = AuxiliaryConfig()): super().__init__() self.cfg = cfg - def forward(self, f_x: Float[Tensor, "batch d_sae"], dead_latents: Float[Tensor, "batch d_sae"]) -> Float[Tensor, "batch d_sae"]: + def forward( + self, + f_x: Float[Tensor, "batch d_sae"], + dead_latents: Float[Tensor, "batch d_sae"], + ) -> Float[Tensor, "batch d_sae"]: """ Apply auxiliary loss activation (top-k of dead latents) to the input tensor. """ @@ -350,11 +355,13 @@ def forward(self, f_x: Float[Tensor, "batch d_sae"], dead_latents: Float[Tensor, masked_dead_latents = f_x * dead_latents # Find top k of dead latents - k_vals, k_inds = torch.topk(masked_dead_latents, self.cfg.top_k, dim=1, sorted=False) + k_vals, k_inds = torch.topk( + masked_dead_latents, self.cfg.top_k, dim=1, sorted=False + ) top_k_mask = torch.zeros_like(masked_dead_latents).scatter_( dim=-1, index=k_inds, src=torch.ones_like(masked_dead_latents) ) - + # Mask out all but top k dead latents masked_dead_top_k = torch.mul(top_k_mask, f_x) diff --git a/src/saev/nn/objectives.py b/src/saev/nn/objectives.py index 7159d57b..ee945952 100644 --- a/src/saev/nn/objectives.py +++ b/src/saev/nn/objectives.py @@ -37,13 +37,14 @@ class Matryoshka: class Auxiliary: """ Config for the Auxiliary loss (not for the SAE itself, but for auxiliary loss). - + Reference paper is https://doi.org/10.48550/arXiv.2412.06410 """ aux_coeff: float = 0.03125 """Coefficient for the auxiliary loss term.""" + @jaxtyped(typechecker=beartype.beartype) @dataclasses.dataclass(frozen=True, slots=True) class Loss: @@ -213,6 +214,7 @@ def forward( return AuxiliaryLoss(self.cfg.aux_coeff * mse_loss) + @beartype.beartype def get_objective(cfg: ObjectiveConfig) -> Objective: if isinstance(cfg, Vanilla): diff --git a/tests/test_nn_modeling.py b/tests/test_nn_modeling.py index 26cd1370..67af6ec9 100644 --- a/tests/test_nn_modeling.py +++ b/tests/test_nn_modeling.py @@ -101,7 +101,7 @@ def test_topk_activation(cfg, batch, d_sae): def test_batch_topk_activation(cfg, batch, d_sae): act = modeling.get_activation(cfg) x = torch.randn(batch, d_sae) - y = act(x) + y = act(x, torch.ones_like(x)) assert y.shape == (batch, d_sae) # Check that only k elements are non-zero per sample assert (y != 0).sum(dim=1).eq(cfg.top_k).all() @@ -133,6 +133,20 @@ def test_topk_ties(): assert y[y != 0].unique().item() == 2.0 +def test_auxiliary_activation_ties(): + """Test Auxiliary activation behavior with tied values.""" + cfg = modeling.AuxiliaryConfig(top_k=2) + act = modeling.AuxiliaryLossActivation(cfg) + + x = torch.tensor([[2.0, 2.0, 2.0, 2.0]]) + y = act(x, torch.ones_like(x)) + + # Should select first k elements in case of ties + assert (y != 0).sum() == 2 + # Verify the selected values are correct + assert y[y != 0].unique().item() == 2.0 + + def test_topk_k_equals_size(): """Test TopK when k equals tensor size.""" cfg = modeling.TopK(top_k=4) @@ -145,6 +159,30 @@ def test_topk_k_equals_size(): torch.testing.assert_close(y, x) +def test_auxiliary_activation_k_equals_size(): + """Test Auxiliary activation when k equals tensor size.""" + cfg = modeling.AuxiliaryConfig(top_k=4) + act = modeling.AuxiliaryLossActivation(cfg) + + x = torch.tensor([[5.0, 1.0, 3.0, 2.0]]) + y = act(x, torch.ones_like(x)) + + # All values should be preserved + torch.testing.assert_close(y, x) + + +def test_auxiliary_activation_k_exceeds_size(): + """Test Auxiliary activation when k exceeds tensor size.""" + cfg = modeling.AuxiliaryConfig(top_k=8) + act = modeling.AuxiliaryLossActivation(cfg) + + x = torch.tensor([[5.0, 1.0, 3.0, 2.0]]) + y = act(x, torch.ones_like(x)) + + # All values should be preserved + torch.testing.assert_close(y, x) + + def test_topk_negative_values(): """Test TopK with negative values.""" cfg = modeling.TopK(top_k=2) @@ -158,6 +196,19 @@ def test_topk_negative_values(): torch.testing.assert_close(y, expected) +def test_auxiliary_activation_negative_values(): + """Test Auxiliary activation with negative values.""" + cfg = modeling.AuxiliaryConfig(top_k=2) + act = modeling.AuxiliaryLossActivation(cfg) + + x = torch.tensor([[-5.0, -1.0, -3.0, -2.0]]) + y = act(x, torch.ones_like(x)) + + # Should select -1.0 and -2.0 (largest values) + expected = torch.tensor([[0.0, -1.0, 0.0, -2.0]]) + torch.testing.assert_close(y, expected) + + def test_topk_gradient_flow(): """Test that gradients flow correctly through TopK.""" cfg = modeling.TopK(top_k=2) @@ -175,6 +226,23 @@ def test_topk_gradient_flow(): torch.testing.assert_close(x.grad, expected_grad) +def test_auxiliary_activation_gradient_flow(): + """Test that gradients flow correctly through Auxiliary activation.""" + cfg = modeling.AuxiliaryConfig(top_k=2) + act = modeling.AuxiliaryLossActivation(cfg) + + x = torch.tensor([[5.0, 1.0, 3.0, 2.0], [2.0, 4.0, 1.0, 3.0]], requires_grad=True) + y = act(x, torch.ones_like(x)) + + # Create a simple loss (sum of outputs) + loss = y.sum() + loss.backward() + + # Expected gradient: 1.0 for selected elements, 0.0 for others + expected_grad = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]) + torch.testing.assert_close(x.grad, expected_grad) + + def test_topk_gradient_sparsity(): """Verify gradient sparsity matches forward pass selection.""" cfg = modeling.TopK(top_k=3) @@ -199,6 +267,30 @@ def test_topk_gradient_sparsity(): torch.testing.assert_close(selected_grads, expected_grads) +def test_auxiliary_activation_gradient_sparsity(): + """Verify gradient sparsity matches forward pass selection for Auxiliary activation.""" + cfg = modeling.AuxiliaryConfig(top_k=3) + act = modeling.AuxiliaryLossActivation(cfg) + + torch.manual_seed(42) + x = torch.randn(2, 8, requires_grad=True) + y = act(x, torch.ones_like(x)) + + # Use a different upstream gradient + grad_output = torch.randn_like(y) + y.backward(grad_output) + + # Check that gradient sparsity matches forward pass + forward_mask = (y != 0).float() + grad_mask = (x.grad != 0).float() + torch.testing.assert_close(forward_mask, grad_mask) + + # Verify gradient values for selected elements + selected_grads = x.grad * forward_mask + expected_grads = grad_output * forward_mask + torch.testing.assert_close(selected_grads, expected_grads) + + def test_topk_zero_gradient_for_unselected(): """Explicitly verify that non-selected elements have exactly 0.0 gradients.""" cfg = modeling.TopK(top_k=2) @@ -220,6 +312,29 @@ def test_topk_zero_gradient_for_unselected(): torch.testing.assert_close(x.grad[0, 5], torch.tensor(1.0)) +def test_auxiliary_activation_zero_gradient_for_unselected(): + """Explicitly verify that non-selected elements have exactly 0.0 gradients.""" + cfg = modeling.AuxiliaryConfig(top_k=2) + act = modeling.AuxiliaryLossActivation(cfg) + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], requires_grad=True) + y = act(x, torch.ones_like(x)) + + loss = y.sum() + loss.backward() + + # Elements at indices 0, 1, 2, 3 should have zero gradients + torch.testing.assert_close(x.grad[0, 0], torch.tensor(0.0)) + torch.testing.assert_close(x.grad[0, 1], torch.tensor(0.0)) + torch.testing.assert_close(x.grad[0, 2], torch.tensor(0.0)) + torch.testing.assert_close(x.grad[0, 3], torch.tensor(0.0)) + # Elements at indices 4, 5 should have non-zero gradients + torch.testing.assert_close(x.grad[0, 4], torch.tensor(1.0)) + torch.testing.assert_close(x.grad[0, 5], torch.tensor(1.0)) + + # TODO: Add portion of this test that looks at values masked out by the dead latent mask. + + # BatchTopK Edge Case Tests def test_batchtopk_basic_forward(): """Test basic BatchTopK forward pass with known values.""" diff --git a/tests/test_nn_objectives.py b/tests/test_nn_objectives.py index 6aefdeb6..102b8221 100644 --- a/tests/test_nn_objectives.py +++ b/tests/test_nn_objectives.py @@ -46,6 +46,28 @@ def test_safe_mse_large_x(): assert not safe.isnan().any() +def test_auxiliary_mse_same(): + x = torch.zeros((45, 12), dtype=torch.float) + x_hat = torch.ones((45, 12), dtype=torch.float) + aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=1)) + torch.testing.assert_close( + aux_objective(x_hat, x), + objectives.mean_squared_err(x_hat, x), + dtype=torch.float, + ) + + +def test_auxiliary_coeff(): + x = torch.zeros((45, 12), dtype=torch.float) + x_hat = torch.ones((45, 12), dtype=torch.float) + aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=0.5)) + torch.testing.assert_close( + aux_objective(x_hat, x), + 0.5 * objectives.mean_squared_err(x_hat, x), + dtype=torch.float, + ) + + def test_factories(): assert isinstance( objectives.get_objective(objectives.Vanilla()), objectives.VanillaObjective diff --git a/train.py b/train.py index 2c8e54a9..9a2e6098 100644 --- a/train.py +++ b/train.py @@ -28,13 +28,13 @@ import psutil import torch import tyro -import wandb from jaxtyping import Float from torch import Tensor import saev.data.shuffled import saev.utils.scheduling import saev.utils.wandb +import wandb from saev import helpers, nn logger = logging.getLogger("train.py") @@ -219,11 +219,8 @@ def train( objectives.train() objectives = objectives.to(cfg.device) - aux_activations = [ - nn.modeling.AuxiliaryLossActivation( - nn.modeling.AuxiliaryConfig(c.dead_top_k) - ) + nn.modeling.AuxiliaryLossActivation(nn.modeling.AuxiliaryConfig(c.dead_top_k)) for c in cfgs ] @@ -238,14 +235,12 @@ def train( p_dataloader, p_children, last_rb, last_t = None, None, 0, time.time() - iterations_dead = [torch.zeros( - (s.cfg.d_sae), dtype=torch.float, device=cfg.device) - for s in saes + iterations_dead = [ + torch.zeros((s.cfg.d_sae), dtype=torch.float, device=cfg.device) for s in saes ] dead_latents = [ - torch.zeros((s.cfg.d_sae), dtype=torch.float, device=cfg.device) - for s in saes + torch.zeros((s.cfg.d_sae), dtype=torch.float, device=cfg.device) for s in saes ] for batch in helpers.progress(dataloader, every=cfg.log_every): @@ -260,7 +255,14 @@ def train( losses = [] x_hats = [] f_xs = [] - for sae, objective, aux_activation, aux_objective, iters_dead, dead_lts in zip(saes, objectives, aux_activations, aux_objectives, iterations_dead, dead_latents): + for sae, objective, aux_activation, aux_objective, iters_dead, dead_lts in zip( + saes, + objectives, + aux_activations, + aux_objectives, + iterations_dead, + dead_latents, + ): if isinstance(objective, nn.objectives.MatryoshkaObjective): # Specific case has to be given for Matryoshka SAEs since we need to decode several times with varying prefix lengths x_hat, f_x = sae.matryoshka_forward(acts_BD, cfg.n_prefixes) @@ -274,8 +276,9 @@ def train( aux_f_x = aux_activation(f_x, dead_latents=dead_lts) aux_x_hat = sae.decode(aux_f_x) - losses.append(objective(acts_BD, f_x, x_hat) + - aux_objective(acts_BD, aux_x_hat)) + losses.append( + objective(acts_BD, f_x, x_hat) + aux_objective(acts_BD, aux_x_hat) + ) else: losses.append(objective(acts_BD, f_x, x_hat)) From dd7827d8d056743cb1cf34f869bb7a5e2249b2a0 Mon Sep 17 00:00:00 2001 From: Jake Beattie Date: Sun, 14 Sep 2025 15:34:34 -0400 Subject: [PATCH 3/3] Fixed tests, bugs with aux loss --- src/saev/nn/modeling.py | 2 +- tests/test_nn_modeling.py | 13 ++++++------- tests/test_nn_objectives.py | 20 ++++++++++---------- train.py | 2 +- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/saev/nn/modeling.py b/src/saev/nn/modeling.py index 7329209d..8d00351d 100644 --- a/src/saev/nn/modeling.py +++ b/src/saev/nn/modeling.py @@ -356,7 +356,7 @@ def forward( # Find top k of dead latents k_vals, k_inds = torch.topk( - masked_dead_latents, self.cfg.top_k, dim=1, sorted=False + masked_dead_latents, min(self.cfg.top_k, masked_dead_latents.shape[1]), dim=1, sorted=False ) top_k_mask = torch.zeros_like(masked_dead_latents).scatter_( dim=-1, index=k_inds, src=torch.ones_like(masked_dead_latents) diff --git a/tests/test_nn_modeling.py b/tests/test_nn_modeling.py index 67af6ec9..8b504685 100644 --- a/tests/test_nn_modeling.py +++ b/tests/test_nn_modeling.py @@ -70,13 +70,14 @@ def aux_cfgs(): d_sae=st.integers(min_value=256, max_value=2048), ) def test_auxiliary_activation(cfg, batch, d_sae): - act = modeling.get_activation(cfg) + act = modeling.AuxiliaryLossActivation(cfg) + dead_lts = torch.zeros((batch, d_sae)) x = torch.rand(batch, d_sae) - y = act(x) + y = act(x, dead_lts) assert y.shape == (batch, d_sae) # Check that only k elements are non-zero per sample - assert (y != 0).sum(dim=1).eq(cfg.top_k).all() + assert (y != 0).sum(dim=1).le(cfg.top_k).all() @given( @@ -101,7 +102,7 @@ def test_topk_activation(cfg, batch, d_sae): def test_batch_topk_activation(cfg, batch, d_sae): act = modeling.get_activation(cfg) x = torch.randn(batch, d_sae) - y = act(x, torch.ones_like(x)) + y = act(x) assert y.shape == (batch, d_sae) # Check that only k elements are non-zero per sample assert (y != 0).sum(dim=1).eq(cfg.top_k).all() @@ -318,7 +319,7 @@ def test_auxiliary_activation_zero_gradient_for_unselected(): act = modeling.AuxiliaryLossActivation(cfg) x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], requires_grad=True) - y = act(x, torch.ones_like(x)) + y = act(x, torch.tensor([0, 0, 0, 0, 1, 1])) loss = y.sum() loss.backward() @@ -332,8 +333,6 @@ def test_auxiliary_activation_zero_gradient_for_unselected(): torch.testing.assert_close(x.grad[0, 4], torch.tensor(1.0)) torch.testing.assert_close(x.grad[0, 5], torch.tensor(1.0)) - # TODO: Add portion of this test that looks at values masked out by the dead latent mask. - # BatchTopK Edge Case Tests def test_batchtopk_basic_forward(): diff --git a/tests/test_nn_objectives.py b/tests/test_nn_objectives.py index 102b8221..b8ca9637 100644 --- a/tests/test_nn_objectives.py +++ b/tests/test_nn_objectives.py @@ -47,24 +47,24 @@ def test_safe_mse_large_x(): def test_auxiliary_mse_same(): - x = torch.zeros((45, 12), dtype=torch.float) + x = torch.ones((45, 12), dtype=torch.float) x_hat = torch.ones((45, 12), dtype=torch.float) - aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=1)) + aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=1.0)) torch.testing.assert_close( - aux_objective(x_hat, x), - objectives.mean_squared_err(x_hat, x), - dtype=torch.float, + aux_objective(x_hat, x).loss, + objectives.mean_squared_err(x_hat, x).mean(), ) def test_auxiliary_coeff(): - x = torch.zeros((45, 12), dtype=torch.float) - x_hat = torch.ones((45, 12), dtype=torch.float) + x = torch.ones((45, 12), dtype=torch.float) + x_hat = torch.full((45, 12), 3, dtype=torch.float) aux_objective = objectives.AuxiliaryObjective(objectives.Auxiliary(aux_coeff=0.5)) + print(aux_objective(x_hat, x).loss) + print(0.5 * objectives.mean_squared_err(x_hat, x).mean()) torch.testing.assert_close( - aux_objective(x_hat, x), - 0.5 * objectives.mean_squared_err(x_hat, x), - dtype=torch.float, + aux_objective(x_hat, x).loss, + 0.5 * objectives.mean_squared_err(x_hat, x).mean(), ) diff --git a/train.py b/train.py index 9a2e6098..8f8ec49f 100644 --- a/train.py +++ b/train.py @@ -274,7 +274,7 @@ def train( if cfg.auxiliary_loss: # Auxiliary loss is a separate term from the main objective, so we add it separately. aux_f_x = aux_activation(f_x, dead_latents=dead_lts) - aux_x_hat = sae.decode(aux_f_x) + aux_x_hat = torch.matmul(sae.W_dec, aux_f_x) losses.append( objective(acts_BD, f_x, x_hat) + aux_objective(acts_BD, aux_x_hat)