diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py
index 5ec2f15b5..7bef9614b 100644
--- a/spd/app/backend/optim_cis.py
+++ b/spd/app/backend/optim_cis.py
@@ -408,6 +408,10 @@ def optimize_ci_values(
             p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac,
             p_anneal_final_p=config.imp_min_config.p_anneal_final_p,
             p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac,
+            coeff_warmup_frac=config.imp_min_config.coeff_warmup_frac,
+            coeff_peak_multiplier=config.imp_min_config.coeff_peak_multiplier,
+            coeff_anneal_start_frac=config.imp_min_config.coeff_anneal_start_frac,
+            coeff_anneal_end_frac=config.imp_min_config.coeff_anneal_end_frac,
         )
 
         recon_loss = compute_recon_loss(recon_out, config.loss_config, target_out, device)
@@ -542,6 +546,10 @@ def importance_minimality_loss_per_element(
     p_anneal_start_frac: float,
     p_anneal_final_p: float | None,
     p_anneal_end_frac: float,
+    coeff_warmup_frac: float,
+    coeff_peak_multiplier: float,
+    coeff_anneal_start_frac: float,
+    coeff_anneal_end_frac: float,
 ) -> Float[Tensor, " N"]:
     """Compute importance minimality loss independently for each batch element."""
     losses = []
@@ -557,6 +565,10 @@ def importance_minimality_loss_per_element(
                 p_anneal_start_frac=p_anneal_start_frac,
                 p_anneal_final_p=p_anneal_final_p,
                 p_anneal_end_frac=p_anneal_end_frac,
+                coeff_warmup_frac=coeff_warmup_frac,
+                coeff_peak_multiplier=coeff_peak_multiplier,
+                coeff_anneal_start_frac=coeff_anneal_start_frac,
+                coeff_anneal_end_frac=coeff_anneal_end_frac,
             )
         )
     return torch.stack(losses)
@@ -725,6 +737,10 @@ def optimize_ci_values_batched(
             p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac,
             p_anneal_final_p=config.imp_min_config.p_anneal_final_p,
             p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac,
+            coeff_warmup_frac=config.imp_min_config.coeff_warmup_frac,
+            coeff_peak_multiplier=config.imp_min_config.coeff_peak_multiplier,
+            coeff_anneal_start_frac=config.imp_min_config.coeff_anneal_start_frac,
+            coeff_anneal_end_frac=config.imp_min_config.coeff_anneal_end_frac,
         )
 
         recon_losses = compute_recon_loss_batched(
diff --git a/spd/configs.py b/spd/configs.py
index efbfb3bb4..4d7bd5d68 100644
--- a/spd/configs.py
+++ b/spd/configs.py
@@ -340,6 +340,10 @@ class ImportanceMinimalityLossConfig(LossMetricConfig):
     p_anneal_final_p: NonNegativeFloat | None = None
     p_anneal_end_frac: Probability = 1.0
     eps: NonNegativeFloat = 1e-12
+    coeff_warmup_frac: Probability = 0.0
+    coeff_peak_multiplier: NonNegativeFloat = 1.0
+    coeff_anneal_start_frac: Probability = 1.0
+    coeff_anneal_end_frac: Probability = 1.0
 
     @model_validator(mode="before")
     @classmethod
@@ -357,6 +361,22 @@ def migrate_old_fields(cls, data: dict[str, Any]) -> dict[str, Any]:
             data["beta"] = 0.0
         return data
 
+    @model_validator(mode="after")
+    def validate_scheduling_fracs(self) -> "ImportanceMinimalityLossConfig":
+        assert self.coeff_warmup_frac <= self.coeff_anneal_start_frac, (
+            f"coeff_warmup_frac ({self.coeff_warmup_frac}) must be <= "
+            f"coeff_anneal_start_frac ({self.coeff_anneal_start_frac})"
+        )
+        assert self.coeff_anneal_end_frac >= self.coeff_anneal_start_frac, (
+            f"coeff_anneal_end_frac ({self.coeff_anneal_end_frac}) must be >= "
+            f"coeff_anneal_start_frac ({self.coeff_anneal_start_frac})"
+        )
+        assert self.p_anneal_end_frac >= self.p_anneal_start_frac, (
+            f"p_anneal_end_frac ({self.p_anneal_end_frac}) must be >= "
+            f"p_anneal_start_frac ({self.p_anneal_start_frac})"
+        )
+        return self
+
 
 class UniformKSubsetRoutingConfig(BaseConfig):
     type: Literal["uniform_k_subset"] = "uniform_k_subset"
diff --git a/spd/losses.py b/spd/losses.py
index a35654bee..c9c0e3a0d 100644
--- a/spd/losses.py
+++ b/spd/losses.py
@@ -75,6 +75,10 @@ def compute_losses(
                 p_anneal_start_frac=cfg.p_anneal_start_frac,
                 p_anneal_final_p=cfg.p_anneal_final_p,
                 p_anneal_end_frac=cfg.p_anneal_end_frac,
+                coeff_warmup_frac=cfg.coeff_warmup_frac,
+                coeff_peak_multiplier=cfg.coeff_peak_multiplier,
+                coeff_anneal_start_frac=cfg.coeff_anneal_start_frac,
+                coeff_anneal_end_frac=cfg.coeff_anneal_end_frac,
             )
         case UnmaskedReconLossConfig():
             loss = unmasked_recon_loss(
diff --git a/spd/metrics/importance_minimality_loss.py b/spd/metrics/importance_minimality_loss.py
index e0c332781..1497ed1c1 100644
--- a/spd/metrics/importance_minimality_loss.py
+++ b/spd/metrics/importance_minimality_loss.py
@@ -32,11 +32,6 @@ def _get_linear_annealed_p(
     if p_anneal_final_p is None or p_anneal_start_frac >= 1.0:
         return initial_p
 
-    assert p_anneal_end_frac >= p_anneal_start_frac, (
-        f"p_anneal_end_frac ({p_anneal_end_frac}) must be >= "
-        f"p_anneal_start_frac ({p_anneal_start_frac})"
-    )
-
     if current_frac_of_training < p_anneal_start_frac:
         return initial_p
     elif current_frac_of_training >= p_anneal_end_frac:
@@ -49,6 +44,42 @@
     return initial_p + (p_anneal_final_p - initial_p) * progress
 
 
+def _get_coeff_multiplier(
+    current_frac_of_training: float,
+    coeff_warmup_frac: float,
+    coeff_peak_multiplier: float,
+    coeff_anneal_start_frac: float,
+    coeff_anneal_end_frac: float,
+) -> float:
+    """Calculate coefficient multiplier with warmup and annealing.
+
+    Schedule:
+    - [0, coeff_warmup_frac): linearly ramp 0 → coeff_peak_multiplier
+    - [coeff_warmup_frac, coeff_anneal_start_frac): constant coeff_peak_multiplier
+    - [coeff_anneal_start_frac, coeff_anneal_end_frac): linearly ramp coeff_peak_multiplier → 1.0
+    - [coeff_anneal_end_frac, 1.0]: constant 1.0
+    """
+    # Warmup phase
+    if current_frac_of_training < coeff_warmup_frac:
+        if coeff_warmup_frac == 0.0:
+            return coeff_peak_multiplier
+        return coeff_peak_multiplier * current_frac_of_training / coeff_warmup_frac
+
+    # Constant phase between warmup and anneal
+    if current_frac_of_training < coeff_anneal_start_frac:
+        return coeff_peak_multiplier
+
+    # Past anneal end
+    if current_frac_of_training >= coeff_anneal_end_frac:
+        return 1.0
+
+    # Anneal phase: linear interpolation coeff_peak_multiplier → 1.0
+    progress = (current_frac_of_training - coeff_anneal_start_frac) / (
+        coeff_anneal_end_frac - coeff_anneal_start_frac
+    )
+    return coeff_peak_multiplier + (1.0 - coeff_peak_multiplier) * progress
+
+
 def _importance_minimality_loss_update(
     ci_upper_leaky: dict[str, Float[Tensor, "... C"]],
     pnorm: float,
@@ -123,6 +154,10 @@ def importance_minimality_loss(
     p_anneal_start_frac: float,
     p_anneal_final_p: float | None,
     p_anneal_end_frac: float,
+    coeff_warmup_frac: float,
+    coeff_peak_multiplier: float,
+    coeff_anneal_start_frac: float,
+    coeff_anneal_end_frac: float,
 ) -> Float[Tensor, ""]:
     """Compute importance minimality loss."""
 
@@ -137,12 +172,20 @@
     )
     dist_state = get_distributed_state()
    world_size = dist_state.world_size if dist_state is not None else 1
-    return _importance_minimality_loss_compute(
+    loss = _importance_minimality_loss_compute(
         per_component_sums=per_component_sums,
         n_examples=n_examples,
         beta=beta,
         world_size=world_size,
     )
+    coeff_multiplier = _get_coeff_multiplier(
+        current_frac_of_training=current_frac_of_training,
+        coeff_warmup_frac=coeff_warmup_frac,
+        coeff_peak_multiplier=coeff_peak_multiplier,
+        coeff_anneal_start_frac=coeff_anneal_start_frac,
+        coeff_anneal_end_frac=coeff_anneal_end_frac,
+    )
+    return loss * coeff_multiplier
 
 
 class ImportanceMinimalityLoss(Metric):
diff --git a/tests/metrics/test_importance_minimality_loss.py b/tests/metrics/test_importance_minimality_loss.py
index 647a00b1f..04728eb42 100644
--- a/tests/metrics/test_importance_minimality_loss.py
+++ b/tests/metrics/test_importance_minimality_loss.py
@@ -23,6 +23,10 @@ def test_basic_l1_norm(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         expected = torch.tensor(8.0)
         assert torch.allclose(result, expected)
@@ -41,6 +45,10 @@ def test_basic_l2_norm(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         expected = torch.tensor(13.0)
         assert torch.allclose(result, expected)
@@ -61,6 +69,10 @@ def test_epsilon_stability(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         expected = (0.0 + eps) ** 0.5 + (1.0 + eps) ** 0.5
         assert torch.allclose(result, torch.tensor(expected))
@@ -77,6 +89,10 @@ def test_p_annealing_before_start(self: object) -> None:
             p_anneal_start_frac=0.5,
             p_anneal_final_p=1.0,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # Should use p=2: 2^2 = 4
         expected = torch.tensor(4.0)
@@ -96,6 +112,10 @@ def test_p_annealing_during(self: object) -> None:
             p_anneal_start_frac=0.0,
             p_anneal_final_p=1.0,
             p_anneal_end_frac=0.5,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # 2^1.5 = 2.828...
         expected = torch.tensor(2.0**1.5)
@@ -113,6 +133,10 @@ def test_p_annealing_after_end(self: object) -> None:
             p_anneal_start_frac=0.0,
             p_anneal_final_p=1.0,
             p_anneal_end_frac=0.5,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # Should use p=1: 2^1 = 2
         expected = torch.tensor(2.0)
@@ -130,6 +154,10 @@ def test_no_annealing_when_final_p_none(self: object) -> None:
             p_anneal_start_frac=0.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=0.5,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # Should use p=2: 2^2 = 4
         expected = torch.tensor(4.0)
@@ -150,6 +178,10 @@ def test_multiple_layers_aggregation(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # layer1: per_component_mean = [1, 1], sum = 2
         # layer2: per_component_mean = [2, 2], sum = 4
@@ -175,6 +207,10 @@ def test_beta_zero_simple_sum(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         expected = torch.tensor(5.0)
         assert torch.allclose(result, expected)
@@ -211,6 +247,10 @@ def test_beta_logarithmic_penalty(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
        )
         loss_beta_1 = importance_minimality_loss(
             ci_upper_leaky=ci_upper_leaky,
@@ -221,6 +261,10 @@
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
 
         assert torch.allclose(loss_beta_0, torch.tensor(expected_beta_0))
@@ -240,6 +284,10 @@ def test_beta_edge_cases(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         assert torch.isfinite(result_small)
         assert result_small >= 0
@@ -255,5 +303,9 @@
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         assert torch.isfinite(result_large)
diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py
index d7acba607..1f173b8b4 100644
--- a/tests/test_spd_losses.py
+++ b/tests/test_spd_losses.py
@@ -173,6 +173,10 @@ def test_basic_l1_norm(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         expected = torch.tensor(8.0)
         assert torch.allclose(result, expected)
@@ -191,6 +195,10 @@ def test_basic_l2_norm(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         expected = torch.tensor(13.0)
         assert torch.allclose(result, expected)
@@ -211,6 +219,10 @@ def test_epsilon_stability(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         expected = (0.0 + eps) ** 0.5 + (1.0 + eps) ** 0.5
         assert torch.allclose(result, torch.tensor(expected))
@@ -227,6 +239,10 @@ def test_p_annealing_before_start(self: object) -> None:
             p_anneal_start_frac=0.5,
             p_anneal_final_p=1.0,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # Should use p=2: 2^2 = 4
         expected = torch.tensor(4.0)
@@ -246,6 +262,10 @@ def test_p_annealing_during(self: object) -> None:
             p_anneal_start_frac=0.0,
             p_anneal_final_p=1.0,
             p_anneal_end_frac=0.5,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # 2^1.5 = 2.828...
         expected = torch.tensor(2.0**1.5)
@@ -263,6 +283,10 @@ def test_p_annealing_after_end(self: object) -> None:
             p_anneal_start_frac=0.0,
             p_anneal_final_p=1.0,
             p_anneal_end_frac=0.5,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # Should use p=1: 2^1 = 2
         expected = torch.tensor(2.0)
@@ -280,6 +304,10 @@ def test_no_annealing_when_final_p_none(self: object) -> None:
             p_anneal_start_frac=0.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=0.5,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # Should use p=2: 2^2 = 4
         expected = torch.tensor(4.0)
@@ -300,6 +328,10 @@ def test_multiple_layers_aggregation(self: object) -> None:
             p_anneal_start_frac=1.0,
             p_anneal_final_p=None,
             p_anneal_end_frac=1.0,
+            coeff_warmup_frac=0.0,
+            coeff_peak_multiplier=1.0,
+            coeff_anneal_start_frac=1.0,
+            coeff_anneal_end_frac=1.0,
         )
         # layer1: per_component_mean = [1, 1], sum = 2
         # layer2: per_component_mean = [2, 2], sum = 4
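
Usage notes (illustrative sketches, not part of the patch):

The schedule implemented by `_get_coeff_multiplier` is piecewise linear: ramp from 0 up to `coeff_peak_multiplier` during warmup, hold at the peak, then anneal back to 1.0, after which the configured loss coefficient applies unscaled. A minimal standalone walk-through with assumed example values (warmup over the first 10% of training, a 5x peak held until 50%, annealed away by 80%); it imports the private helper purely for demonstration:

    from spd.metrics.importance_minimality_loss import _get_coeff_multiplier

    # Assumed example schedule, not taken from any real config.
    schedule = dict(
        coeff_warmup_frac=0.1,
        coeff_peak_multiplier=5.0,
        coeff_anneal_start_frac=0.5,
        coeff_anneal_end_frac=0.8,
    )

    for frac in (0.0, 0.05, 0.1, 0.3, 0.65, 0.8, 1.0):
        m = _get_coeff_multiplier(current_frac_of_training=frac, **schedule)
        print(f"{frac:.2f} -> {m:.2f}")
    # 0.00 -> 0.00   start of warmup
    # 0.05 -> 2.50   halfway through warmup
    # 0.10 -> 5.00   peak reached
    # 0.30 -> 5.00   constant phase
    # 0.65 -> 3.00   halfway through anneal: 5.0 + (1.0 - 5.0) * 0.5
    # 0.80 -> 1.00   anneal finished
    # 1.00 -> 1.00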
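The new config defaults (`coeff_warmup_frac=0.0`, `coeff_peak_multiplier=1.0`, `coeff_anneal_start_frac=1.0`, `coeff_anneal_end_frac=1.0`) make the schedule a no-op, which is why every existing test above only gains those four keyword arguments with exactly those values and keeps its expected loss unchanged. A quick check of that claim:

    from spd.metrics.importance_minimality_loss import _get_coeff_multiplier

    # With the defaults, the multiplier is 1.0 at every point in training.
    assert all(
        _get_coeff_multiplier(
            current_frac_of_training=frac,
            coeff_warmup_frac=0.0,
            coeff_peak_multiplier=1.0,
            coeff_anneal_start_frac=1.0,
            coeff_anneal_end_frac=1.0,
        )
        == 1.0
        for frac in (0.0, 0.25, 0.5, 0.75, 1.0)
    )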
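On the config side, `validate_scheduling_fracs` enforces a well-ordered schedule (`coeff_warmup_frac <= coeff_anneal_start_frac <= coeff_anneal_end_frac`) and also absorbs the `p_anneal_start_frac <= p_anneal_end_frac` check previously asserted inside `_get_linear_annealed_p`, so bad settings now fail when the config is built rather than mid-run. A hedged sketch of that behavior; it assumes any fields inherited from `LossMetricConfig` (such as a loss coefficient) either have defaults or would be supplied alongside:

    import pytest
    from pydantic import ValidationError

    from spd.configs import ImportanceMinimalityLossConfig

    # Well-ordered schedule: constructs fine.
    ImportanceMinimalityLossConfig(
        coeff_warmup_frac=0.1,
        coeff_peak_multiplier=5.0,
        coeff_anneal_start_frac=0.5,
        coeff_anneal_end_frac=0.8,
    )

    # Warmup ending after annealing starts violates
    # coeff_warmup_frac <= coeff_anneal_start_frac; pydantic surfaces the
    # validator's AssertionError as a ValidationError.
    with pytest.raises(ValidationError):
        ImportanceMinimalityLossConfig(
            coeff_warmup_frac=0.6,
            coeff_anneal_start_frac=0.5,
        )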