Module saev.training
Trains many SAEs in parallel to amortize the cost of loading a single batch of data over many SAE training runs.
Functions
def evaluate(cfgs: list[Train],
saes: torch.nn.modules.container.ModuleList) ‑> list[EvalMetrics]
-
Evaluates SAE quality by counting the number of dead features and the number of dense features. Also makes histogram plots to aid qualitative human comparison.
TODO
Develop automatic methods to use histogram and feature frequencies to evaluate quality with a single number.
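A minimal sketch of calling evaluate() on the output of train(). The default-constructed saev.config.Train() configs are an assumption for illustration; in practice configs come from a sweep file.
import saev.config
import saev.training

# Assumption: Train() is default-constructible; real configs come from a sweep file.
cfgs = [saev.config.Train(), saev.config.Train()]
saes, run, step = saev.training.train(cfgs)
metrics = saev.training.evaluate(cfgs, saes)
for m in metrics:
    print(f"dead={m.n_dead} almost_dead={m.n_almost_dead} dense={m.n_dense} mse={m.mse:.4f}")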
def init_b_dec_batched(saes: torch.nn.modules.container.ModuleList,
dataset: Dataset)
def main(cfgs: list[Train]) ‑> list[str]
def make_hashable(obj)
def make_saes(cfgs: list[SparseAutoencoder]) ‑> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]
def split_cfgs(cfgs: list[Train]) ‑> list[list[Train]]
-
Splits configs into groups that can be parallelized.
Arguments
cfgs: A list of configs from a sweep file.
Returns
A list of lists, where the configs in each sublist do not differ in any keys that are in CANNOT_PARALLELIZE. This means that each sublist is a valid "parallel" set of configs for train().
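A hedged usage sketch follows; the lr field on Train is hypothetical and stands in for any key that is not in CANNOT_PARALLELIZE.
import dataclasses
import saev.config
import saev.training

base = saev.config.Train()  # assumption: default-constructible
# "lr" is a hypothetical Train field; only keys outside CANNOT_PARALLELIZE may vary within a group.
cfgs = [dataclasses.replace(base, lr=lr) for lr in (1e-4, 3e-4, 1e-3)]
for group in saev.training.split_cfgs(cfgs):
    run_ids = saev.training.main(group)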
def train(cfgs: list[Train]) ‑> tuple[torch.nn.modules.container.ModuleList, ParallelWandbRun, int]
-
Explicitly declare the optimizer, schedulers, dataloader, etc. outside of main() so that all the variables are dropped from scope and can be garbage collected.
Classes
class BatchLimiter (dataloader: torch.utils.data.dataloader.DataLoader, n_samples: int)
-
Limits the number of batches to only return n_samples total samples.
Expand source code
class BatchLimiter:
    """
    Limits the number of batches to only return `n_samples` total samples.
    """

    def __init__(self, dataloader: torch.utils.data.DataLoader, n_samples: int):
        self.dataloader = dataloader
        self.n_samples = n_samples
        self.batch_size = dataloader.batch_size

    def __len__(self) -> int:
        return self.n_samples // self.batch_size

    def __iter__(self):
        self.n_seen = 0
        while True:
            for batch in self.dataloader:
                yield batch

                # Sometimes we underestimate because the final batch in the dataloader might not be a full batch.
                self.n_seen += self.batch_size
                if self.n_seen > self.n_samples:
                    return

            # We try to mitigate the above issue by ignoring the last batch if we don't have drop_last.
            if not self.dataloader.drop_last:
                self.n_seen -= self.batch_size
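A minimal, self-contained usage sketch (toy tensors rather than SAE activations):
import torch
import torch.utils.data

dataset = torch.utils.data.TensorDataset(torch.randn(4096, 16))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
limited = BatchLimiter(dataloader, n_samples=1_000)
print(len(limited))  # 1_000 // 64 = 15
for (batch,) in limited:
    pass  # training step goes here; iteration stops once roughly 1_000 samples have been seen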
class EvalMetrics (l0: float,
l1: float,
mse: float,
n_dead: int,
n_almost_dead: int,
n_dense: int,
freqs: jaxtyping.Float[Tensor, 'd_sae'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
almost_dead_threshold: float,
dense_threshold: float)
-
Results of evaluating a trained SAE on a dataset.
Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class EvalMetrics:
    """Results of evaluating a trained SAE on a dataset."""

    l0: float
    """Mean L0 across all examples."""
    l1: float
    """Mean L1 across all examples."""
    mse: float
    """Mean MSE across all examples."""
    n_dead: int
    """Number of neurons that never fired on any example."""
    n_almost_dead: int
    """Number of neurons that fired on fewer than `almost_dead_threshold` of examples."""
    n_dense: int
    """Number of neurons that fired on more than `dense_threshold` of examples."""
    freqs: Float[Tensor, " d_sae"]
    """How often each feature fired."""
    mean_values: Float[Tensor, " d_sae"]
    """The mean value for each feature when it did fire."""
    almost_dead_threshold: float
    """Threshold for an "almost dead" neuron."""
    dense_threshold: float
    """Threshold for a dense neuron."""

    def for_wandb(self) -> dict[str, int | float]:
        dct = dataclasses.asdict(self)
        # Store arrays as tables.
        dct["freqs"] = wandb.Table(columns=["freq"], data=dct["freqs"][:, None].numpy())
        dct["mean_values"] = wandb.Table(
            columns=["mean_value"], data=dct["mean_values"][:, None].numpy()
        )
        return {f"eval/{key}": value for key, value in dct.items()}
Class variables
var almost_dead_threshold : float
-
Threshold for an "almost dead" neuron.
var dense_threshold : float
-
Threshold for a dense neuron.
var freqs : jaxtyping.Float[Tensor, 'd_sae']
-
How often each feature fired.
var l0 : float
-
Mean L0 across all examples.
var l1 : float
-
Mean L1 across all examples.
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
-
The mean value for each feature when it did fire.
var mse : float
-
Mean MSE across all examples.
var n_almost_dead : int
-
Number of neurons that fired on fewer than
almost_dead_threshold
of examples.
var n_dead : int
-
Number of neurons that never fired on any example.
var n_dense : int
-
Number of neurons that fired on more than
dense_threshold
of examples.
Methods
def for_wandb(self) ‑> dict[str, int | float]
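A sketch of logging the returned dict, assuming saes, run, and step come from train() and cfgs matches the configs passed to it (for_wandb prefixes every key with "eval/"):
metrics = evaluate(cfgs, saes)
run.log([m.for_wandb() for m in metrics], step=step)  # one dict per parallel run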
class ParallelWandbRun (project: str,
cfgs: list[Train],
mode: str,
tags: list[str])
-
Inspired by https://community.wandb.ai/t/is-it-possible-to-log-to-multiple-runs-simultaneously/4387/3.
Expand source code
class ParallelWandbRun:
    """
    Inspired by https://community.wandb.ai/t/is-it-possible-to-log-to-multiple-runs-simultaneously/4387/3.
    """

    def __init__(
        self, project: str, cfgs: list[config.Train], mode: str, tags: list[str]
    ):
        cfg, *cfgs = cfgs

        self.project = project
        self.cfgs = cfgs
        self.mode = mode
        self.tags = tags

        self.live_run = wandb.init(project=project, config=cfg, mode=mode, tags=tags)
        self.metric_queues: list[MetricQueue] = [[] for _ in self.cfgs]

    def log(self, metrics: list[dict[str, object]], *, step: int):
        metric, *metrics = metrics
        self.live_run.log(metric, step=step)
        for queue, metric in zip(self.metric_queues, metrics):
            queue.append((step, metric))

    def finish(self) -> list[str]:
        ids = [self.live_run.id]
        # Log the rest of the runs.
        self.live_run.finish()

        for queue, cfg in zip(self.metric_queues, self.cfgs):
            run = wandb.init(
                project=self.project,
                config=cfg,
                mode=self.mode,
                tags=self.tags + ["queued"],
            )
            for step, metric in queue:
                run.log(metric, step=step)
            ids.append(run.id)
            run.finish()

        return ids
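A hedged usage sketch; the project name, tags, and metric keys are illustrative, and the default-constructed configs are an assumption.
import saev.config

cfgs = [saev.config.Train(), saev.config.Train()]  # assumption: default-constructible
run = ParallelWandbRun(project="saev", cfgs=cfgs, mode="offline", tags=["demo"])
for step in range(3):
    run.log([{"train/loss": 1.0 / (step + 1)} for _ in cfgs], step=step)
ids = run.finish()  # one W&B run id per config; the queued runs are flushed here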
Methods
def finish(self) ‑> list[str]
def log(self, metrics: list[dict[str, object]], *, step: int)
class Scheduler
-
Expand source code
@beartype.beartype
class Scheduler:
    def step(self) -> float:
        err_msg = f"{self.__class__.__name__} must implement step()."
        raise NotImplementedError(err_msg)

    def __repr__(self) -> str:
        err_msg = f"{self.__class__.__name__} must implement __repr__()."
        raise NotImplementedError(err_msg)
Subclasses
Warmup
Methods
def step(self) ‑> float
class Warmup (init: float, final: float, n_steps: int)
-
Linearly increases from init to final over n_steps steps.
Expand source code
@beartype.beartype
class Warmup(Scheduler):
    """
    Linearly increases from `init` to `final` over `n_steps` steps.
    """

    def __init__(self, init: float, final: float, n_steps: int):
        self.final = final
        self.init = init
        self.n_steps = n_steps
        self._step = 0

    def step(self) -> float:
        self._step += 1
        if self._step < self.n_steps:
            return self.init + (self.final - self.init) * (self._step / self.n_steps)

        return self.final

    def __repr__(self) -> str:
        return f"Warmup(init={self.init}, final={self.final}, n_steps={self.n_steps})"
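A short sketch of stepping the schedule once per optimizer step; assigning the returned value to the optimizer's param groups is left to the caller.
warmup = Warmup(init=0.0, final=3e-4, n_steps=500)
for _ in range(1_000):
    lr = warmup.step()  # ramps linearly to 3e-4 over 500 steps, then stays there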
Ancestors
Scheduler
Methods
def step(self) ‑> float