Module saev.training
Trains many SAEs in parallel to amortize the cost of loading a single batch of data over many SAE training runs.
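For intuition, a minimal toy sketch of the amortization idea (not the library's actual loop; the models and objective below are made up stand-ins):

import torch

# Made-up toy, not saev's classes: two tiny "SAEs" sharing one stream of
# activation batches.
saes = torch.nn.ModuleList([torch.nn.Linear(16, 64), torch.nn.Linear(16, 64)])
acts_batches = [torch.randn(8, 16) for _ in range(10)]

for acts in acts_batches:               # load the batch once...
    for sae in saes:                    # ...and amortize it over every SAE
        loss = (sae(acts) ** 2).mean()  # placeholder objective
        loss.backward()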
Functions
def evaluate(cfgs: list[Train],
saes: torch.nn.modules.container.ModuleList,
objectives: torch.nn.modules.container.ModuleList) ‑> list[EvalMetrics]
Evaluates SAE quality by counting dead and dense features and recording loss metrics. Also makes histogram plots to aid qualitative comparison by humans.
The metrics computed are mean L0/L1/MSE losses; the number of dead, almost dead, and dense neurons; plus per-feature firing frequencies and mean values. A list of EvalMetrics is returned, one for each SAE.
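As an illustrative sketch (not the library's implementation; the threshold values and tensor size are made up), the dead/almost-dead/dense counts follow directly from the per-feature firing frequencies:

import torch

# Made-up stand-ins: in evaluate() these come from running the SAEs.
freqs = torch.rand(1024)  # per-feature firing frequency, shape (d_sae,)
almost_dead_threshold, dense_threshold = 1e-7, 1e-2

n_dead = int((freqs == 0).sum())                    # never fired
n_almost_dead = int((freqs < almost_dead_threshold).sum())
n_dense = int((freqs > dense_threshold).sum())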
def init_b_dec_batched(saes: torch.nn.modules.container.ModuleList, dataset: Dataset)

def main(cfgs: list[Train]) ‑> list[str]
def make_hashable(obj)
def make_saes(cfgs: list[tuple[Relu | JumpRelu, Vanilla | Matryoshka]]) ‑> tuple[torch.nn.modules.container.ModuleList, torch.nn.modules.container.ModuleList, list[dict[str, object]]]
def split_cfgs(cfgs: list[Train]) ‑> list[list[Train]]
Splits configs into groups that can be parallelized.

Arguments
cfgs: A list of configs from a sweep file.

Returns
A list of lists, where the configs in each sublist do not differ in any keys that are in CANNOT_PARALLELIZE. This means that each sublist is a valid "parallel" set of configs for train().
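A sketch of the intended call pattern (illustrative; main() presumably does something similar):

# Illustrative: `cfgs` is a list of Train configs parsed from a sweep file.
for group in split_cfgs(cfgs):
    # The tuple fields follow train()'s signature below.
    saes, objectives, run, step = train(group)
    metrics = evaluate(group, saes, objectives)  # one EvalMetrics per SAE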
def train(cfgs: list[Train]) ‑> tuple[torch.nn.modules.container.ModuleList, torch.nn.modules.container.ModuleList, ParallelWandbRun, int]
Explicitly declares the optimizer, schedulers, dataloader, etc. outside of main() so that all the variables are dropped from scope and can be garbage collected.
Classes
class BatchLimiter (dataloader: torch.utils.data.dataloader.DataLoader, n_samples: int)
Limits the number of batches to only return n_samples total samples.
class BatchLimiter: """ Limits the number of batches to only return `n_samples` total samples. """ def __init__(self, dataloader: torch.utils.data.DataLoader, n_samples: int): self.dataloader = dataloader self.n_samples = n_samples self.batch_size = dataloader.batch_size def __len__(self) -> int: return self.n_samples // self.batch_size def __iter__(self): self.n_seen = 0 while True: for batch in self.dataloader: yield batch # Sometimes we underestimate because the final batch in the dataloader might not be a full batch. self.n_seen += self.batch_size if self.n_seen > self.n_samples: return # We try to mitigate the above issue by ignoring the last batch if we don't have drop_last. if not self.dataloader.drop_last: self.n_seen -= self.batch_size
class EvalMetrics (l0: float,
l1: float,
mse: float,
n_dead: int,
n_almost_dead: int,
n_dense: int,
freqs: jaxtyping.Float[Tensor, 'd_sae'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
almost_dead_threshold: float,
dense_threshold: float)
Results of evaluating a trained SAE on a dataset.
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class EvalMetrics:
    """Results of evaluating a trained SAE on a dataset."""

    l0: float
    """Mean L0 across all examples."""
    l1: float
    """Mean L1 across all examples."""
    mse: float
    """Mean MSE across all examples."""

    n_dead: int
    """Number of neurons that never fired on any example."""
    n_almost_dead: int
    """Number of neurons that fired on fewer than `almost_dead_threshold` of examples."""
    n_dense: int
    """Number of neurons that fired on more than `dense_threshold` of examples."""

    freqs: Float[Tensor, " d_sae"]
    """How often each feature fired."""
    mean_values: Float[Tensor, " d_sae"]
    """The mean value for each feature when it did fire."""

    almost_dead_threshold: float
    """Threshold for an "almost dead" neuron."""
    dense_threshold: float
    """Threshold for a dense neuron."""

    def for_wandb(self) -> dict[str, int | float]:
        dct = dataclasses.asdict(self)
        # Store arrays as tables.
        dct["freqs"] = wandb.Table(columns=["freq"], data=dct["freqs"][:, None].numpy())
        dct["mean_values"] = wandb.Table(
            columns=["mean_value"], data=dct["mean_values"][:, None].numpy()
        )
        return {f"eval/{key}": value for key, value in dct.items()}
Class variables
var almost_dead_threshold : float
Threshold for an "almost dead" neuron.

var dense_threshold : float
Threshold for a dense neuron.

var freqs : jaxtyping.Float[Tensor, 'd_sae']
How often each feature fired.

var l0 : float
Mean L0 across all examples.

var l1 : float
Mean L1 across all examples.

var mean_values : jaxtyping.Float[Tensor, 'd_sae']
The mean value for each feature when it did fire.

var mse : float
Mean MSE across all examples.

var n_almost_dead : int
Number of neurons that fired on fewer than almost_dead_threshold of examples.

var n_dead : int
Number of neurons that never fired on any example.

var n_dense : int
Number of neurons that fired on more than dense_threshold of examples.
Methods
def for_wandb(self) ‑> dict[str, int | float]
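A hypothetical sketch of packaging eval results for a wandb log call (all values below are made up):

import torch

metrics = EvalMetrics(
    l0=12.3, l1=45.6, mse=0.01,  # made-up loss values
    n_dead=100, n_almost_dead=250, n_dense=5,
    freqs=torch.rand(1024),  # one entry per SAE feature
    mean_values=torch.rand(1024),
    almost_dead_threshold=1e-7,
    dense_threshold=1e-2,
)
# Keys are prefixed with "eval/"; freqs and mean_values become wandb.Tables.
# `wandb_run` is a hypothetical, already-initialized wandb run handle.
wandb_run.log(metrics.for_wandb(), step=5_000)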
class ParallelWandbRun (project: str,
cfgs: list[Train],
mode: str,
tags: list[str])
Inspired by https://community.wandb.ai/t/is-it-possible-to-log-to-multiple-runs-simultaneously/4387/3.
class ParallelWandbRun:
    """
    Inspired by https://community.wandb.ai/t/is-it-possible-to-log-to-multiple-runs-simultaneously/4387/3.
    """

    def __init__(
        self, project: str, cfgs: list[config.Train], mode: str, tags: list[str]
    ):
        cfg, *cfgs = cfgs

        self.project = project
        self.cfgs = cfgs
        self.mode = mode
        self.tags = tags

        self.live_run = wandb.init(
            project=project, config=dataclasses.asdict(cfg), mode=mode, tags=tags
        )

        self.metric_queues: list[MetricQueue] = [[] for _ in self.cfgs]

    def log(self, metrics: list[dict[str, object]], *, step: int):
        metric, *metrics = metrics
        self.live_run.log(metric, step=step)

        for queue, metric in zip(self.metric_queues, metrics):
            queue.append((step, metric))

    def finish(self) -> list[str]:
        ids = [self.live_run.id]
        # Log the rest of the runs.
        self.live_run.finish()

        for queue, cfg in zip(self.metric_queues, self.cfgs):
            run = wandb.init(
                project=self.project,
                config=dataclasses.asdict(cfg),
                mode=self.mode,
                tags=self.tags + ["queued"],
            )
            for step, metric in queue:
                run.log(metric, step=step)
            ids.append(run.id)
            run.finish()

        return ids
Methods
def finish(self) ‑> list[str]
def log(self, metrics: list[dict[str, object]], *, step: int)
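A hypothetical usage sketch (the project name, loop bound, and metric values are illustrative; `cfgs` is assumed to be a list of config.Train objects):

# The first config logs live; the rest are queued and replayed on finish().
run = ParallelWandbRun("saev", cfgs, mode="online", tags=["sweep"])
for step in range(1_000):
    metrics = [{"train/loss": 0.0} for _ in cfgs]  # one dict per config, in order
    run.log(metrics, step=step)
ids = run.finish()  # one wandb run id per config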
class Scheduler
@beartype.beartype
class Scheduler:
    def step(self) -> float:
        err_msg = f"{self.__class__.__name__} must implement step()."
        raise NotImplementedError(err_msg)

    def __repr__(self) -> str:
        err_msg = f"{self.__class__.__name__} must implement __repr__()."
        raise NotImplementedError(err_msg)
Subclasses
Warmup
Methods
def step(self) ‑> float
class Warmup (init: float, final: float, n_steps: int)
Linearly increases from init to final over n_steps steps.
@beartype.beartype
class Warmup(Scheduler):
    """
    Linearly increases from `init` to `final` over `n_steps` steps.
    """

    def __init__(self, init: float, final: float, n_steps: int):
        self.final = final
        self.init = init
        self.n_steps = n_steps
        self._step = 0

    def step(self) -> float:
        self._step += 1
        if self._step < self.n_steps:
            return self.init + (self.final - self.init) * (self._step / self.n_steps)

        return self.final

    def __repr__(self) -> str:
        return f"Warmup(init={self.init}, final={self.final}, n_steps={self.n_steps})"
Ancestors
Scheduler
Methods
def step(self) ‑> float
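A hypothetical usage sketch, warming a learning rate over the first steps of training (the model, optimizer, and values are made up):

import torch

model = torch.nn.Linear(4, 4)  # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0)

# Made-up schedule: ramp the lr from 0 to 4e-4 over the first 500 steps,
# then hold it at 4e-4.
schedule = Warmup(init=0.0, final=4e-4, n_steps=500)
for _ in range(2_000):
    lr = schedule.step()
    for group in optimizer.param_groups:
        group["lr"] = lr
    # ...forward/backward/optimizer.step() goes here...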