Module saev.training

Trains many SAEs in parallel to amortize the cost of loading a single batch of data over many SAE training runs.
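
A minimal usage sketch, assuming a list of saev.config.Train configs is already in hand (e.g. parsed from a sweep file); the return-value comment is an inference from ParallelWandbRun.finish(), not a documented guarantee:

import saev.config
import saev.training

# cfgs: a list of saev.config.Train configs, e.g. parsed from a sweep file.
cfgs: list[saev.config.Train] = ...

# Returns a list of strings, presumably the W&B run ids collected from
# ParallelWandbRun.finish().
run_ids = saev.training.main(cfgs)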

Functions

def evaluate(cfgs: list[Train],
saes: torch.nn.modules.container.ModuleList,
objectives: torch.nn.modules.container.ModuleList) ‑> list[EvalMetrics]

Evaluates SAE quality by counting dead and dense features and recording loss metrics. Also makes histogram plots to help human qualitative comparison.

The metrics computed are mean L0/L1/MSE losses, the number of dead, almost dead, and dense neurons, plus per-feature firing frequencies and mean values. A list of EvalMetrics is returned, one for each SAE.
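
The sketch below shows how these loss terms are commonly computed per batch; it is an illustration, not this module's exact implementation (x: inputs, x_hat: SAE reconstructions, f: feature activations):

import torch
from jaxtyping import Float
from torch import Tensor

def batch_metrics(
    x: Float[Tensor, "batch d_model"],
    x_hat: Float[Tensor, "batch d_model"],
    f: Float[Tensor, "batch d_sae"],
) -> tuple[float, float, float]:
    l0 = (f > 0).float().sum(dim=-1).mean()  # mean number of active features per example
    l1 = f.abs().sum(dim=-1).mean()          # mean L1 norm of the feature activations
    mse = ((x - x_hat) ** 2).mean()          # mean squared reconstruction error
    return l0.item(), l1.item(), mse.item()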

def init_b_dec_batched(saes: torch.nn.modules.container.ModuleList,
dataset: Dataset)
def main(cfgs: list[Train]) ‑> list[str]
def make_hashable(obj)
def make_saes(cfgs: list[tuple[Relu | JumpReluVanilla | Matryoshka]]) ‑> tuple[torch.nn.modules.container.ModuleList, torch.nn.modules.container.ModuleList, list[dict[str, object]]]
def split_cfgs(cfgs: list[Train]) ‑> list[list[Train]]

Splits configs into groups that can be parallelized.

Arguments

cfgs: A list of configs from a sweep file.

Returns

A list of lists, where the configs in each sublist do not differ in any keys that are in CANNOT_PARALLELIZE. This means that each sublist is a valid "parallel" set of configs for train().
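
A brief sketch of how the groups are typically consumed; the example hyperparameter named in the comment is hypothetical, not a field taken from this module:

groups = split_cfgs(cfgs)  # cfgs: list[saev.config.Train] from a sweep file

for group in groups:
    # Every config in `group` agrees on all keys listed in CANNOT_PARALLELIZE and
    # differs only in keys that can be swept in parallel (hypothetically, e.g. the
    # sparsity coefficient).
    train(group)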

def train(cfgs: list[Train]) ‑> tuple[torch.nn.modules.container.ModuleList, torch.nn.modules.container.ModuleList, ParallelWandbRun, int]

Explicitly declares the optimizer, schedulers, dataloader, etc. outside of main() so that these variables fall out of scope when train() returns and can be garbage collected.

Classes

class BatchLimiter (dataloader: torch.utils.data.dataloader.DataLoader, n_samples: int)

Limits the number of batches to only return n_samples total samples.

class BatchLimiter:
    """
    Limits the number of batches to only return `n_samples` total samples.
    """

    def __init__(self, dataloader: torch.utils.data.DataLoader, n_samples: int):
        self.dataloader = dataloader
        self.n_samples = n_samples
        self.batch_size = dataloader.batch_size

    def __len__(self) -> int:
        return self.n_samples // self.batch_size

    def __iter__(self):
        self.n_seen = 0
        while True:
            for batch in self.dataloader:
                yield batch

                # We assume every batch is full, so if the final batch is smaller than
                # batch_size, n_seen overshoots the number of samples actually yielded.
                self.n_seen += self.batch_size
                if self.n_seen > self.n_samples:
                    return

            # Mitigate the overshoot: when the dataloader keeps its (possibly partial)
            # final batch, give back one batch's worth of credit before the next pass.
            if not self.dataloader.drop_last:
                self.n_seen -= self.batch_size
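
A small, self-contained usage example (plain torch tensors, not this module's activation dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1_000, 16))
dataloader = DataLoader(dataset, batch_size=64, drop_last=False)

# Yield batches until roughly 10_000 samples have been seen, cycling through
# the underlying dataloader as many times as needed.
limited = BatchLimiter(dataloader, n_samples=10_000)
for (acts,) in limited:
    ...  # training step on `acts`
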
class EvalMetrics (l0: float,
l1: float,
mse: float,
n_dead: int,
n_almost_dead: int,
n_dense: int,
freqs: jaxtyping.Float[Tensor, 'd_sae'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
almost_dead_threshold: float,
dense_threshold: float)

Results of evaluating a trained SAE on a dataset.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class EvalMetrics:
    """Results of evaluating a trained SAE on a datset."""

    l0: float
    """Mean L0 across all examples."""
    l1: float
    """Mean L1 across all examples."""
    mse: float
    """Mean MSE across all examples."""
    n_dead: int
    """Number of neurons that never fired on any example."""
    n_almost_dead: int
    """Number of neurons that fired on fewer than `almost_dead_threshold` of examples."""
    n_dense: int
    """Number of neurons that fired on more than `dense_threshold` of examples."""

    freqs: Float[Tensor, " d_sae"]
    """How often each feature fired."""
    mean_values: Float[Tensor, " d_sae"]
    """The mean value for each feature when it did fire."""

    almost_dead_threshold: float
    """Threshold for an "almost dead" neuron."""
    dense_threshold: float
    """Threshold for a dense neuron."""

    def for_wandb(self) -> dict[str, int | float]:
        dct = dataclasses.asdict(self)
        # Store arrays as tables.
        dct["freqs"] = wandb.Table(columns=["freq"], data=dct["freqs"][:, None].numpy())
        dct["mean_values"] = wandb.Table(
            columns=["mean_value"], data=dct["mean_values"][:, None].numpy()
        )
        return {f"eval/{key}": value for key, value in dct.items()}

Class variables

var almost_dead_threshold : float

Threshold for an "almost dead" neuron.

var dense_threshold : float

Threshold for a dense neuron.

var freqs : jaxtyping.Float[Tensor, 'd_sae']

How often each feature fired.

var l0 : float

Mean L0 across all examples.

var l1 : float

Mean L1 across all examples.

var mean_values : jaxtyping.Float[Tensor, 'd_sae']

The mean value for each feature when it did fire.

var mse : float

Mean MSE across all examples.

var n_almost_dead : int

Number of neurons that fired on fewer than almost_dead_threshold of examples.

var n_dead : int

Number of neurons that never fired on any example.

var n_dense : int

Number of neurons that fired on more than dense_threshold of examples.

Methods

def for_wandb(self) ‑> dict[str, int | float]
class ParallelWandbRun (project: str,
cfgs: list[Train],
mode: str,
tags: list[str])

Inspired by https://community.wandb.ai/t/is-it-possible-to-log-to-multiple-runs-simultaneously/4387/3.

class ParallelWandbRun:
    """
    Inspired by https://community.wandb.ai/t/is-it-possible-to-log-to-multiple-runs-simultaneously/4387/3.
    """

    def __init__(
        self, project: str, cfgs: list[config.Train], mode: str, tags: list[str]
    ):
        cfg, *cfgs = cfgs
        self.project = project
        self.cfgs = cfgs
        self.mode = mode
        self.tags = tags

        self.live_run = wandb.init(
            project=project, config=dataclasses.asdict(cfg), mode=mode, tags=tags
        )

        self.metric_queues: list[MetricQueue] = [[] for _ in self.cfgs]

    def log(self, metrics: list[dict[str, object]], *, step: int):
        metric, *metrics = metrics
        self.live_run.log(metric, step=step)
        for queue, metric in zip(self.metric_queues, metrics):
            queue.append((step, metric))

    def finish(self) -> list[str]:
        ids = [self.live_run.id]
        # Log the rest of the runs.
        self.live_run.finish()

        for queue, cfg in zip(self.metric_queues, self.cfgs):
            run = wandb.init(
                project=self.project,
                config=dataclasses.asdict(cfg),
                mode=self.mode,
                tags=self.tags + ["queued"],
            )
            for step, metric in queue:
                run.log(metric, step=step)
            ids.append(run.id)
            run.finish()

        return ids
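
A hedged usage sketch; assumes cfgs is a non-empty list of saev.config.Train (e.g. one parallel group from split_cfgs()), and the metric name and losses are placeholders:

run = ParallelWandbRun("saev", cfgs, mode="offline", tags=["example"])

for step in range(1_000):
    # One metrics dict per config, in the same order as `cfgs`. The first dict
    # goes to the live run; the rest are queued and replayed by finish().
    metrics = [{"train/loss": float(loss)} for loss in losses]
    run.log(metrics, step=step)

run_ids = run.finish()  # returns one W&B run id per config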

Methods

def finish(self) ‑> list[str]
def log(self, metrics: list[dict[str, object]], *, step: int)
class Scheduler
@beartype.beartype
class Scheduler:
    def step(self) -> float:
        err_msg = f"{self.__class__.__name__} must implement step()."
        raise NotImplementedError(err_msg)

    def __repr__(self) -> str:
        err_msg = f"{self.__class__.__name__} must implement __repr__()."
        raise NotImplementedError(err_msg)

Subclasses

Warmup

Methods

def step(self) ‑> float
class Warmup (init: float, final: float, n_steps: int)

Linearly increases from init to final over n_steps steps.

@beartype.beartype
class Warmup(Scheduler):
    """
    Linearly increases from `init` to `final` over `n_steps` steps.
    """

    def __init__(self, init: float, final: float, n_steps: int):
        self.final = final
        self.init = init
        self.n_steps = n_steps
        self._step = 0

    def step(self) -> float:
        self._step += 1
        if self._step < self.n_steps:
            return self.init + (self.final - self.init) * (self._step / self.n_steps)

        return self.final

    def __repr__(self) -> str:
        return f"Warmup(init={self.init}, final={self.final}, n_steps={self.n_steps})"

Ancestors

Scheduler

Methods

def step(self) ‑> float