Module saev.nn

Neural network architectures for sparse autoencoders.

Functions

def dump(fpath: str, sae: SparseAutoencoder)

Save an SAE checkpoint to disk along with its configuration, using the trick from equinox.

Arguments

fpath: filepath to save the checkpoint to.
sae: the sparse autoencoder to save.

def load(fpath: str, *, device: str = 'cpu') -> SparseAutoencoder

Loads a sparse autoencoder from disk.
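
A hedged round-trip sketch (the checkpoint path is hypothetical and assumed to exist):

import saev.nn

sae = saev.nn.load("checkpoints/sae.pt", device="cpu")   # restore weights + config on CPU
saev.nn.dump("checkpoints/sae-copy.pt", sae)             # write them back out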

def ref_mse(x_hat: jaxtyping.Float[Tensor, '*d'], x: jaxtyping.Float[Tensor, '*d'], norm: bool = True) -> jaxtyping.Float[Tensor, '*d']

def safe_mse(x_hat: jaxtyping.Float[Tensor, '*batch d'], x: jaxtyping.Float[Tensor, '*batch d'], norm: bool = False) -> jaxtyping.Float[Tensor, '*batch d']
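
The bodies of ref_mse and safe_mse are not shown on this page. As a rough, hedged illustration only (not the module's actual implementation), an unreduced squared error computed in float32, with the norm flag assumed to rescale by the target's magnitude, might look like this:

import torch
from torch import Tensor


def mse_sketch(x_hat: Tensor, x: Tensor, norm: bool = False) -> Tensor:
    # Upcast to float32 so squaring very large activations does not overflow.
    err = (x_hat.float() - x.float()) ** 2
    if norm:
        # Assumption: rescale by each example's squared target norm.
        err = err / (x.float() ** 2).sum(dim=-1, keepdim=True).clamp_min(1e-8)
    return err.to(x_hat.dtype)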

Classes

class Loss (mse: jaxtyping.Float[Tensor, ''], sparsity: jaxtyping.Float[Tensor, ''], ghost_grad: jaxtyping.Float[Tensor, ''], l0: jaxtyping.Float[Tensor, ''], l1: jaxtyping.Float[Tensor, ''])

The composite loss terms for an autoencoder training batch.

class Loss(typing.NamedTuple):
    """The composite loss terms for an autoencoder training batch."""

    mse: Float[Tensor, ""]
    """Reconstruction loss (mean squared error)."""
    sparsity: Float[Tensor, ""]
    """Sparsity loss, typically lambda * L1."""
    ghost_grad: Float[Tensor, ""]
    """Ghost gradient loss, if any."""
    l0: Float[Tensor, ""]
    """L0 magnitude of hidden activations."""
    l1: Float[Tensor, ""]
    """L1 magnitude of hidden activations."""

    @property
    def loss(self) -> Float[Tensor, ""]:
        """Total loss."""
        return self.mse + self.sparsity + self.ghost_grad
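
A minimal usage sketch (not part of the generated docs): Loss is a plain NamedTuple of scalar tensors, and the loss property sums mse, sparsity, and ghost_grad; l0 and l1 are tracked but not added to the total.

import torch
from saev.nn import Loss

loss = Loss(
    mse=torch.tensor(0.5),
    sparsity=torch.tensor(0.1),
    ghost_grad=torch.tensor(0.0),
    l0=torch.tensor(12.0),
    l1=torch.tensor(3.4),
)
print(loss.loss)  # tensor(0.6000) = mse + sparsity + ghost_grad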

Ancestors

  • builtins.tuple

Instance variables

var ghost_grad : jaxtyping.Float[Tensor, '']

Ghost gradient loss, if any.

var l0 : jaxtyping.Float[Tensor, '']

L0 magnitude of hidden activations.

var l1 : jaxtyping.Float[Tensor, '']

L1 magnitude of hidden activations.

prop loss : jaxtyping.Float[Tensor, '']

Total loss.

@property
def loss(self) -> Float[Tensor, ""]:
    """Total loss."""
    return self.mse + self.sparsity + self.ghost_grad

var mse : jaxtyping.Float[Tensor, '']

Reconstruction loss (mean squared error).

var sparsity : jaxtyping.Float[Tensor, '']

Sparsity loss, typically lambda * L1.

class SparseAutoencoder (cfg: config.SparseAutoencoder)

Sparse auto-encoder (SAE) using L1 sparsity penalty.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

@jaxtyped(typechecker=beartype.beartype)
class SparseAutoencoder(torch.nn.Module):
    """
    Sparse auto-encoder (SAE) using L1 sparsity penalty.
    """

    cfg: config.SparseAutoencoder

    def __init__(self, cfg: config.SparseAutoencoder):
        super().__init__()

        self.cfg = cfg

        self.W_enc = torch.nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(cfg.d_vit, cfg.d_sae))
        )
        self.b_enc = torch.nn.Parameter(torch.zeros(cfg.d_sae))

        self.W_dec = torch.nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(cfg.d_sae, cfg.d_vit))
        )
        self.b_dec = torch.nn.Parameter(torch.zeros(cfg.d_vit))

        self.logger = logging.getLogger(f"sae(seed={cfg.seed})")

    def forward(
        self, x: Float[Tensor, "batch d_model"]
    ) -> tuple[Float[Tensor, "batch d_model"], Float[Tensor, "batch d_sae"], Loss]:
        """
        Given x, calculates the reconstructed x_hat, the intermediate activations f_x, and the loss.

        Arguments:
            x: a batch of ViT activations.
        """

        # Remove encoder bias as per Anthropic
        h_pre = (
            einops.einsum(
                x - self.b_dec, self.W_enc, "... d_vit, d_vit d_sae -> ... d_sae"
            )
            + self.b_enc
        )
        f_x = torch.nn.functional.relu(h_pre)

        x_hat = (
            einops.einsum(f_x, self.W_dec, "... d_sae, d_sae d_vit -> ... d_vit")
            + self.b_dec
        )
        # Some values of x and x_hat can be very large, so we compute a numerically safe MSE.
        mse_loss = safe_mse(x_hat, x)

        mse_loss = mse_loss.mean()
        l0 = (f_x > 0).float().sum(axis=1).mean(axis=0)
        l1 = f_x.sum(axis=1).mean(axis=0)
        sparsity_loss = self.cfg.sparsity_coeff * l1
        # Ghost loss is included for backwards compatibility.
        ghost_loss = torch.zeros_like(mse_loss)

        return x_hat, f_x, Loss(mse_loss, sparsity_loss, ghost_loss, l0, l1)

    @torch.no_grad()
    def init_b_dec(self, vit_acts: Float[Tensor, "n d_vit"]):
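        # Re-initialize b_dec to the mean of a sample of ViT activations; the log
        # message below compares median distances to b_dec before and after.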
        if self.cfg.n_reinit_samples <= 0:
            self.logger.info("Skipping init_b_dec.")
            return
        previous_b_dec = self.b_dec.clone().cpu()
        vit_acts = vit_acts[: self.cfg.n_reinit_samples]
        assert len(vit_acts) == self.cfg.n_reinit_samples
        mean = vit_acts.mean(axis=0)
        previous_distances = torch.norm(vit_acts - previous_b_dec, dim=-1)
        distances = torch.norm(vit_acts - mean, dim=-1)
        self.logger.info(
            "Prev dist: %.3f; new dist: %.3f",
            previous_distances.median(axis=0).values.mean().item(),
            distances.median(axis=0).values.mean().item(),
        )
        self.b_dec.data = mean.to(self.b_dec.dtype).to(self.b_dec.device)

    @torch.no_grad()
    def normalize_w_dec(self):
        """
        Set W_dec to unit-norm columns.
        """
        if self.cfg.normalize_w_dec:
            self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

    @torch.no_grad()
    def remove_parallel_grads(self):
        """
        Remove the component of W_dec.grad that is parallel to each row of W_dec
        (both have shape (d_sae, d_vit)), so gradient updates leave the row norms
        unchanged to first order.
        """
        if not self.cfg.remove_parallel_grads:
            return

        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_vit, d_sae d_vit -> d_sae",
        )

        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_vit -> d_sae d_vit",
        )

Ancestors

  • torch.nn.modules.module.Module

Class variables

var cfg : config.SparseAutoencoder

Methods

def forward(self, x: jaxtyping.Float[Tensor, 'batch d_model']) -> tuple[jaxtyping.Float[Tensor, 'batch d_model'], jaxtyping.Float[Tensor, 'batch d_sae'], Loss]

Given x, calculates the reconstructed x_hat, the intermediate activations f_x, and the loss.

Arguments

x: a batch of ViT activations.
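
A hedged usage sketch (the checkpoint path is hypothetical; the input width is read off W_enc, whose shape is (d_vit, d_sae)):

import torch
import saev.nn

sae = saev.nn.load("checkpoints/sae.pt", device="cpu")
x = torch.randn(32, sae.W_enc.shape[0])          # (batch, d_vit) ViT activations
x_hat, f_x, loss = sae(x)
print(loss.mse.item(), loss.l0.item(), loss.loss.item())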

def init_b_dec(self, vit_acts: jaxtyping.Float[Tensor, 'n d_vit'])

Initialize b_dec to the mean of a sample of ViT activations; does nothing if cfg.n_reinit_samples <= 0.

def normalize_w_dec(self)

Set W_dec to unit-norm columns.

def remove_parallel_grads(self)

Remove the component of W_dec.grad that is parallel to each row of W_dec (both have shape (d_sae, d_vit)), so gradient updates leave the row norms unchanged to first order.
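
A hedged sketch of how these hooks are commonly ordered around an optimizer step (the training loop itself is not part of this module; the optimizer, learning rate, checkpoint path, and stand-in batch below are assumptions):

import torch
import saev.nn

sae = saev.nn.load("checkpoints/sae.pt", device="cpu")        # hypothetical checkpoint
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)       # assumed optimizer
batch = torch.randn(32, sae.W_enc.shape[0])                   # stand-in ViT activations

x_hat, f_x, loss = sae(batch)
loss.loss.backward()
sae.remove_parallel_grads()   # drop the gradient component parallel to W_dec rows
optimizer.step()
optimizer.zero_grad()
sae.normalize_w_dec()         # re-project W_dec rows back to unit norm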