From 435beaad1f7b62b08c332757b723ac269ab73cd6 Mon Sep 17 00:00:00 2001 From: link7808 Date: Tue, 24 Mar 2026 05:04:48 -0400 Subject: [PATCH] Fix sparsity warmup for current SAE config layout --- src/saev/framework/train.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src/saev/framework/train.py b/src/saev/framework/train.py index c776ca6..0e38cf0 100644 --- a/src/saev/framework/train.py +++ b/src/saev/framework/train.py @@ -105,6 +105,20 @@ class Config: """Where to log Slurm job stdout/stderr.""" +def _get_sparsity_coeff(sae: nn.SparseAutoencoder) -> float | None: + sparsity = getattr(sae.activation.cfg, "sparsity", None) + if sparsity is None or not hasattr(sparsity, "coeff"): + return None + return float(sparsity.coeff) + + +def _set_sparsity_coeff(sae: nn.SparseAutoencoder, value: float) -> None: + sparsity = getattr(sae.activation.cfg, "sparsity", None) + if sparsity is None or not hasattr(sparsity, "coeff"): + return + object.__setattr__(sparsity, "coeff", float(value)) + + @beartype.beartype def make_saes( cfgs: list[tuple[nn.SparseAutoencoderConfig, nn.ObjectiveConfig]], @@ -288,6 +302,7 @@ def train( grouped_pgs: list[list[dict[str, object]]] = [] optimizers: list[list[torch.optim.Optimizer]] = [] lr_schedulers: list[list[saev.utils.scheduling.WarmupCosine]] = [] + sparsity_schedulers: list[saev.utils.scheduling.Warmup | None] = [] for i, (sae, cfg, param_group) in enumerate(zip(saes, cfgs, param_groups)): if cfg.optim == "adam": @@ -318,6 +333,18 @@ def train( optimizers.append(opts) grouped_pgs.append(pgs) lr_schedulers.append(scheds) + current_sparsity_coeff = _get_sparsity_coeff(sae) + if current_sparsity_coeff is None: + sparsity_schedulers.append(None) + else: + _set_sparsity_coeff(sae, 0.0) + sparsity_schedulers.append( + saev.utils.scheduling.Warmup( + 0.0, + current_sparsity_coeff, + cfg.n_sparsity_warmup, + ) + ) param_groups = grouped_pgs @@ -420,6 +447,7 @@ def train( **{f"loss/{key}": val for key, val in loss.metrics().items()}, "progress/n_patches_seen": n_patches_seen, "progress/learning_rate": current_lr, + "progress/sparsity_coeff": _get_sparsity_coeff(sae) or 0.0, "metrics/explained_variance": explained_var.item(), "metrics/dead_unit_pct": dead_pct.item(), "metrics/dictionary_coherence": coherence.item(), @@ -450,8 +478,9 @@ def train( for pg, sched in zip(pgs, scheds): pg["lr"] = sched.step() - # for objective, scheduler in zip(objectives, sparsity_schedulers): - # objective.sparsity_coeff = scheduler.step() + for sae, scheduler in zip(saes, sparsity_schedulers): + if scheduler is not None: + _set_sparsity_coeff(sae, scheduler.step()) for opts in optimizers: for opt in opts: