
Trains many SAEs in parallel to amortize the cost of loading a single batch of data over many SAE training runs.


def evaluate(cfgs: list[Train],
saes: torch.nn.modules.container.ModuleList) ‑> list[EvalMetrics]

Evaluates SAE quality by counting the number of dead features and the number of dense features. It also makes histogram plots to aid qualitative comparison by humans.


TODO: Develop automatic methods that use the histograms and feature frequencies to evaluate SAE quality with a single number.

def init_b_dec_batched(saes: torch.nn.modules.container.ModuleList,
dataset: Dataset)
def main(cfgs: list[Train]) ‑> list[str]
def make_hashable(obj)
def make_saes(cfgs: list[SparseAutoencoder]) ‑> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]
def split_cfgs(cfgs: list[Train]) ‑> list[list[Train]]

Splits configs into groups that can be parallelized.


Args:
    cfgs: A list of configs from a sweep file.


Returns:
    A list of lists, where the configs in each sublist do not differ in any keys that are in CANNOT_PARALLELIZE. This means that each sublist is a valid "parallel" set of configs for train().

def train(cfgs: list[Train]) ‑> tuple[torch.nn.modules.container.ModuleList, ParallelWandbRun, int]

Explicitly declare the optimizer, schedulers, dataloader, etc. outside of main() so that all the variables are dropped from scope and can be garbage collected.


class BatchLimiter (dataloader: torch.utils.data.DataLoader, n_samples: int)

Limits the number of batches to only return n_samples total samples.

Expand source code
class BatchLimiter:
    """Limits the number of batches to only return `n_samples` total samples.

    The wrapped dataloader is re-iterated as many times as needed. Progress is
    counted in units of `dataloader.batch_size`, so the total number of samples
    actually yielded is rounded up to a whole number of batches.
    """

    def __init__(self, dataloader: torch.utils.data.DataLoader, n_samples: int):
        self.dataloader = dataloader
        self.n_samples = n_samples
        self.batch_size = dataloader.batch_size

    def __len__(self) -> int:
        # ceil(n_samples / batch_size): the number of batches __iter__ yields when
        # every batch is counted as full. Clamped so n_samples <= 0 gives 0.
        return max(0, -(-self.n_samples // self.batch_size))

    def __iter__(self):
        self.n_seen = 0
        if self.n_samples <= 0:
            # Nothing requested; without this guard one batch would still be yielded.
            return
        while True:
            n_batches_this_pass = 0
            for batch in self.dataloader:
                n_batches_this_pass += 1
                yield batch

                # Each batch is counted as a full `batch_size`, which over-counts a
                # partial final batch.
                self.n_seen += self.batch_size
                # `>=` (not `>`): once n_seen reaches n_samples we have yielded enough.
                if self.n_seen >= self.n_samples:
                    return

            if n_batches_this_pass == 0:
                # An empty dataloader would otherwise make `while True` spin forever.
                return

            # Without drop_last the final batch of a pass may be partial, so we stop
            # counting it. Skip this when the pass had a single batch: discounting it
            # would cancel all progress and loop forever.
            if not self.dataloader.drop_last and n_batches_this_pass > 1:
                self.n_seen -= self.batch_size
class EvalMetrics (l0: float,
l1: float,
mse: float,
n_dead: int,
n_almost_dead: int,
n_dense: int,
freqs: jaxtyping.Float[Tensor, 'd_sae'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
almost_dead_threshold: float,
dense_threshold: float)

Results of evaluating a trained SAE on a dataset.

Expand source code
@dataclasses.dataclass
class EvalMetrics:
    """Results of evaluating a trained SAE on a dataset."""

    l0: float
    """Mean L0 across all examples."""
    l1: float
    """Mean L1 across all examples."""
    mse: float
    """Mean MSE across all examples."""
    n_dead: int
    """Number of neurons that never fired on any example."""
    n_almost_dead: int
    """Number of neurons that fired on fewer than `almost_dead_threshold` of examples."""
    n_dense: int
    """Number of neurons that fired on more than `dense_threshold` of examples."""

    freqs: Float[Tensor, " d_sae"]
    """How often each feature fired."""
    mean_values: Float[Tensor, " d_sae"]
    """The mean value for each feature when it did fire."""

    almost_dead_threshold: float
    """Threshold for an "almost dead" neuron."""
    dense_threshold: float
    """Threshold for a dense neuron."""

    def for_wandb(self) -> dict[str, object]:
        """Flatten these metrics into a dict suitable for `wandb.log`.

        Scalar fields are passed through unchanged. The per-feature `freqs` and
        `mean_values` tensors become single-column `wandb.Table`s. Every key is
        prefixed with "eval/".
        """
        # asdict requires this class to be a dataclass (see decorator).
        dct = dataclasses.asdict(self)
        # Store arrays as tables. `.cpu()` is a no-op for CPU tensors and lets
        # `.numpy()` succeed if the tensors live on an accelerator.
        dct["freqs"] = wandb.Table(
            columns=["freq"], data=dct["freqs"][:, None].cpu().numpy()
        )
        dct["mean_values"] = wandb.Table(
            columns=["mean_value"], data=dct["mean_values"][:, None].cpu().numpy()
        )
        return {f"eval/{key}": value for key, value in dct.items()}

Class variables

var almost_dead_threshold : float

Threshold for an "almost dead" neuron.

var dense_threshold : float

Threshold for a dense neuron.

var freqs : jaxtyping.Float[Tensor, 'd_sae']

How often each feature fired.

var l0 : float

Mean L0 across all examples.

var l1 : float

Mean L1 across all examples.

var mean_values : jaxtyping.Float[Tensor, 'd_sae']

The mean value for each feature when it did fire.

var mse : float

Mean MSE across all examples.

var n_almost_dead : int

Number of neurons that fired on fewer than almost_dead_threshold of examples.

var n_dead : int

Number of neurons that never fired on any example.

var n_dense : int

Number of neurons that fired on more than dense_threshold of examples.


def for_wandb(self) ‑> dict[str, int | float]
class ParallelWandbRun (project: str,
cfgs: list[Train],
mode: str,
tags: list[str])
Expand source code
class ParallelWandbRun:
    """Log metrics for several configs while keeping only one wandb run live.

    The first config gets a live run that receives metrics immediately. Metrics
    for every other config are buffered in memory. When `finish()` is called,
    they are replayed into one new run per config, tagged "queued".
    """

    def __init__(
        self, project: str, cfgs: list[config.Train], mode: str, tags: list[str]
    ):
        if not cfgs:
            raise ValueError("ParallelWandbRun needs at least one config.")
        cfg, *cfgs = cfgs
        self.project = project
        # Configs whose metrics are queued: every config except the first.
        self.cfgs = cfgs
        self.mode = mode
        self.tags = tags

        self.live_run = wandb.init(project=project, config=cfg, mode=mode, tags=tags)

        # One list of (step, metrics) pairs per queued config, aligned with self.cfgs.
        self.metric_queues: list[MetricQueue] = [[] for _ in self.cfgs]

    def log(self, metrics: list[dict[str, object]], *, step: int):
        """Record one metrics dict per config at `step`.

        `metrics[0]` goes straight to the live run. `metrics[i]` for i >= 1 is
        queued for `self.cfgs[i - 1]`.
        """
        metric, *metrics = metrics
        self.live_run.log(metric, step=step)
        # strict=True: a length mismatch would otherwise silently drop metrics.
        for queue, metric in zip(self.metric_queues, metrics, strict=True):
            queue.append((step, metric))

    def finish(self) -> list[str]:
        """Finish the live run, replay each queue into its own run, and return run ids.

        The returned ids are in config order: the live run first, then the queued runs.
        """
        ids = [self.live_run.id]
        # Close the live run before starting others so each queued config gets a
        # fresh run regardless of wandb's reinit behaviour.
        self.live_run.finish()

        # Log the rest of the runs.
        for queue, cfg in zip(self.metric_queues, self.cfgs):
            run = wandb.init(
                project=self.project,
                config=cfg,
                mode=self.mode,
                tags=self.tags + ["queued"],
            )
            for step, metric in queue:
                run.log(metric, step=step)
            ids.append(run.id)
            run.finish()

        return ids


def finish(self) ‑> list[str]
def log(self, metrics: list[dict[str, object]], *, step: int)
class Scheduler
Expand source code
class Scheduler:
    """Base class for scalar schedules.

    Concrete subclasses must override both `step` and `__repr__`. The base
    implementations only raise `NotImplementedError` naming the subclass.
    """

    def step(self) -> float:
        """Advance the schedule by one step and return its current value."""
        name = type(self).__name__
        raise NotImplementedError(f"{name} must implement step().")

    def __repr__(self) -> str:
        name = type(self).__name__
        raise NotImplementedError(f"{name} must implement __repr__().")



def step(self) ‑> float
class Warmup (init: float, final: float, n_steps: int)

Linearly increases from init to final over n_steps steps.

Expand source code
class Warmup(Scheduler):
    """Linearly increases from `init` to `final` over `n_steps` steps.

    The k-th call to `step()` returns `init + (final - init) * k / n_steps`
    while k < n_steps. From then on it returns `final`.
    """

    def __init__(self, init: float, final: float, n_steps: int):
        self.final = final
        self.init = init
        self.n_steps = n_steps
        self._step = 0

    def step(self) -> float:
        self._step += 1
        if self._step < self.n_steps:
            return self.init + (self.final - self.init) * (self._step / self.n_steps)

        # Warmup is over (this also covers n_steps <= 0): hold at the final value.
        return self.final

    def __repr__(self) -> str:
        return f"Warmup(init={self.init}, final={self.final}, n_steps={self.n_steps})"



def step(self) ‑> float