Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions spd/app/backend/optim_cis.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,10 @@ def optimize_ci_values(
p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac,
p_anneal_final_p=config.imp_min_config.p_anneal_final_p,
p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac,
coeff_warmup_frac=config.imp_min_config.coeff_warmup_frac,
coeff_peak_multiplier=config.imp_min_config.coeff_peak_multiplier,
coeff_anneal_start_frac=config.imp_min_config.coeff_anneal_start_frac,
coeff_anneal_end_frac=config.imp_min_config.coeff_anneal_end_frac,
)

recon_loss = compute_recon_loss(recon_out, config.loss_config, target_out, device)
Expand Down Expand Up @@ -542,6 +546,10 @@ def importance_minimality_loss_per_element(
p_anneal_start_frac: float,
p_anneal_final_p: float | None,
p_anneal_end_frac: float,
coeff_warmup_frac: float,
coeff_peak_multiplier: float,
coeff_anneal_start_frac: float,
coeff_anneal_end_frac: float,
) -> Float[Tensor, " N"]:
"""Compute importance minimality loss independently for each batch element."""
losses = []
Expand All @@ -557,6 +565,10 @@ def importance_minimality_loss_per_element(
p_anneal_start_frac=p_anneal_start_frac,
p_anneal_final_p=p_anneal_final_p,
p_anneal_end_frac=p_anneal_end_frac,
coeff_warmup_frac=coeff_warmup_frac,
coeff_peak_multiplier=coeff_peak_multiplier,
coeff_anneal_start_frac=coeff_anneal_start_frac,
coeff_anneal_end_frac=coeff_anneal_end_frac,
)
)
return torch.stack(losses)
Expand Down Expand Up @@ -725,6 +737,10 @@ def optimize_ci_values_batched(
p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac,
p_anneal_final_p=config.imp_min_config.p_anneal_final_p,
p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac,
coeff_warmup_frac=config.imp_min_config.coeff_warmup_frac,
coeff_peak_multiplier=config.imp_min_config.coeff_peak_multiplier,
coeff_anneal_start_frac=config.imp_min_config.coeff_anneal_start_frac,
coeff_anneal_end_frac=config.imp_min_config.coeff_anneal_end_frac,
)

recon_losses = compute_recon_loss_batched(
Expand Down
20 changes: 20 additions & 0 deletions spd/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ class ImportanceMinimalityLossConfig(LossMetricConfig):
p_anneal_final_p: NonNegativeFloat | None = None
p_anneal_end_frac: Probability = 1.0
eps: NonNegativeFloat = 1e-12
coeff_warmup_frac: Probability = 0.0
coeff_peak_multiplier: NonNegativeFloat = 1.0
coeff_anneal_start_frac: Probability = 1.0
coeff_anneal_end_frac: Probability = 1.0

@model_validator(mode="before")
@classmethod
Expand All @@ -357,6 +361,22 @@ def migrate_old_fields(cls, data: dict[str, Any]) -> dict[str, Any]:
data["beta"] = 0.0
return data

@model_validator(mode="after")
def validate_scheduling_fracs(self) -> "ImportanceMinimalityLossConfig":
    """Check ordering of the p-anneal and coefficient-schedule fractions.

    The coefficient schedule requires
    ``coeff_warmup_frac <= coeff_anneal_start_frac <= coeff_anneal_end_frac``
    and the p-anneal schedule requires
    ``p_anneal_start_frac <= p_anneal_end_frac``.

    Raises:
        ValueError: if any of the fraction orderings above is violated.
    """
    # Raise ValueError instead of using `assert`: asserts are stripped under
    # `python -O`, silently disabling validation. Pydantic converts ValueError
    # raised in an after-validator into a ValidationError for the caller.
    if self.coeff_warmup_frac > self.coeff_anneal_start_frac:
        raise ValueError(
            f"coeff_warmup_frac ({self.coeff_warmup_frac}) must be <= "
            f"coeff_anneal_start_frac ({self.coeff_anneal_start_frac})"
        )
    if self.coeff_anneal_end_frac < self.coeff_anneal_start_frac:
        raise ValueError(
            f"coeff_anneal_end_frac ({self.coeff_anneal_end_frac}) must be >= "
            f"coeff_anneal_start_frac ({self.coeff_anneal_start_frac})"
        )
    if self.p_anneal_end_frac < self.p_anneal_start_frac:
        raise ValueError(
            f"p_anneal_end_frac ({self.p_anneal_end_frac}) must be >= "
            f"p_anneal_start_frac ({self.p_anneal_start_frac})"
        )
    return self


class UniformKSubsetRoutingConfig(BaseConfig):
type: Literal["uniform_k_subset"] = "uniform_k_subset"
Expand Down
4 changes: 4 additions & 0 deletions spd/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def compute_losses(
p_anneal_start_frac=cfg.p_anneal_start_frac,
p_anneal_final_p=cfg.p_anneal_final_p,
p_anneal_end_frac=cfg.p_anneal_end_frac,
coeff_warmup_frac=cfg.coeff_warmup_frac,
coeff_peak_multiplier=cfg.coeff_peak_multiplier,
coeff_anneal_start_frac=cfg.coeff_anneal_start_frac,
coeff_anneal_end_frac=cfg.coeff_anneal_end_frac,
)
case UnmaskedReconLossConfig():
loss = unmasked_recon_loss(
Expand Down
55 changes: 49 additions & 6 deletions spd/metrics/importance_minimality_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ def _get_linear_annealed_p(
if p_anneal_final_p is None or p_anneal_start_frac >= 1.0:
return initial_p

assert p_anneal_end_frac >= p_anneal_start_frac, (
f"p_anneal_end_frac ({p_anneal_end_frac}) must be >= "
f"p_anneal_start_frac ({p_anneal_start_frac})"
)

if current_frac_of_training < p_anneal_start_frac:
return initial_p
elif current_frac_of_training >= p_anneal_end_frac:
Expand All @@ -49,6 +44,42 @@ def _get_linear_annealed_p(
return initial_p + (p_anneal_final_p - initial_p) * progress


def _get_coeff_multiplier(
current_frac_of_training: float,
coeff_warmup_frac: float,
coeff_peak_multiplier: float,
coeff_anneal_start_frac: float,
coeff_anneal_end_frac: float,
) -> float:
"""Calculate coefficient multiplier with warmup and annealing.

Schedule:
- [0, coeff_warmup_frac): linearly ramp 0 → coeff_peak_multiplier
- [coeff_warmup_frac, coeff_anneal_start_frac): constant coeff_peak_multiplier
- [coeff_anneal_start_frac, coeff_anneal_end_frac): linearly ramp coeff_peak_multiplier → 1.0
- [coeff_anneal_end_frac, 1.0]: constant 1.0
"""
# Warmup phase
if current_frac_of_training < coeff_warmup_frac:
if coeff_warmup_frac == 0.0:
return coeff_peak_multiplier
return coeff_peak_multiplier * current_frac_of_training / coeff_warmup_frac

# Constant phase between warmup and anneal
if current_frac_of_training < coeff_anneal_start_frac:
return coeff_peak_multiplier

# Past anneal end
if current_frac_of_training >= coeff_anneal_end_frac:
return 1.0

# Anneal phase: linear interpolation coeff_peak_multiplier → 1.0
progress = (current_frac_of_training - coeff_anneal_start_frac) / (
coeff_anneal_end_frac - coeff_anneal_start_frac
)
return coeff_peak_multiplier + (1.0 - coeff_peak_multiplier) * progress


def _importance_minimality_loss_update(
ci_upper_leaky: dict[str, Float[Tensor, "... C"]],
pnorm: float,
Expand Down Expand Up @@ -123,6 +154,10 @@ def importance_minimality_loss(
p_anneal_start_frac: float,
p_anneal_final_p: float | None,
p_anneal_end_frac: float,
coeff_warmup_frac: float,
coeff_peak_multiplier: float,
coeff_anneal_start_frac: float,
coeff_anneal_end_frac: float,
) -> Float[Tensor, ""]:
"""Compute importance minimality loss."""

Expand All @@ -137,12 +172,20 @@ def importance_minimality_loss(
)
dist_state = get_distributed_state()
world_size = dist_state.world_size if dist_state is not None else 1
return _importance_minimality_loss_compute(
loss = _importance_minimality_loss_compute(
per_component_sums=per_component_sums,
n_examples=n_examples,
beta=beta,
world_size=world_size,
)
coeff_multiplier = _get_coeff_multiplier(
current_frac_of_training=current_frac_of_training,
coeff_warmup_frac=coeff_warmup_frac,
coeff_peak_multiplier=coeff_peak_multiplier,
coeff_anneal_start_frac=coeff_anneal_start_frac,
coeff_anneal_end_frac=coeff_anneal_end_frac,
)
return loss * coeff_multiplier


class ImportanceMinimalityLoss(Metric):
Expand Down
52 changes: 52 additions & 0 deletions tests/metrics/test_importance_minimality_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def test_basic_l1_norm(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
expected = torch.tensor(8.0)
assert torch.allclose(result, expected)
Expand All @@ -41,6 +45,10 @@ def test_basic_l2_norm(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
expected = torch.tensor(13.0)
assert torch.allclose(result, expected)
Expand All @@ -61,6 +69,10 @@ def test_epsilon_stability(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
expected = (0.0 + eps) ** 0.5 + (1.0 + eps) ** 0.5
assert torch.allclose(result, torch.tensor(expected))
Expand All @@ -77,6 +89,10 @@ def test_p_annealing_before_start(self: object) -> None:
p_anneal_start_frac=0.5,
p_anneal_final_p=1.0,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
# Should use p=2: 2^2 = 4
expected = torch.tensor(4.0)
Expand All @@ -96,6 +112,10 @@ def test_p_annealing_during(self: object) -> None:
p_anneal_start_frac=0.0,
p_anneal_final_p=1.0,
p_anneal_end_frac=0.5,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
# 2^1.5 = 2.828...
expected = torch.tensor(2.0**1.5)
Expand All @@ -113,6 +133,10 @@ def test_p_annealing_after_end(self: object) -> None:
p_anneal_start_frac=0.0,
p_anneal_final_p=1.0,
p_anneal_end_frac=0.5,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
# Should use p=1: 2^1 = 2
expected = torch.tensor(2.0)
Expand All @@ -130,6 +154,10 @@ def test_no_annealing_when_final_p_none(self: object) -> None:
p_anneal_start_frac=0.0,
p_anneal_final_p=None,
p_anneal_end_frac=0.5,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
# Should use p=2: 2^2 = 4
expected = torch.tensor(4.0)
Expand All @@ -150,6 +178,10 @@ def test_multiple_layers_aggregation(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
# layer1: per_component_mean = [1, 1], sum = 2
# layer2: per_component_mean = [2, 2], sum = 4
Expand All @@ -175,6 +207,10 @@ def test_beta_zero_simple_sum(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
expected = torch.tensor(5.0)
assert torch.allclose(result, expected)
Expand Down Expand Up @@ -211,6 +247,10 @@ def test_beta_logarithmic_penalty(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
loss_beta_1 = importance_minimality_loss(
ci_upper_leaky=ci_upper_leaky,
Expand All @@ -221,6 +261,10 @@ def test_beta_logarithmic_penalty(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)

assert torch.allclose(loss_beta_0, torch.tensor(expected_beta_0))
Expand All @@ -240,6 +284,10 @@ def test_beta_edge_cases(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
assert torch.isfinite(result_small)
assert result_small >= 0
Expand All @@ -255,5 +303,9 @@ def test_beta_edge_cases(self: object) -> None:
p_anneal_start_frac=1.0,
p_anneal_final_p=None,
p_anneal_end_frac=1.0,
coeff_warmup_frac=0.0,
coeff_peak_multiplier=1.0,
coeff_anneal_start_frac=1.0,
coeff_anneal_end_frac=1.0,
)
assert torch.isfinite(result_large)
Loading