Module saev.nn
Neural network architectures for sparse autoencoders.
Functions
def dump(fpath: str, sae: SparseAutoencoder)
-
Save an SAE checkpoint to disk along with its configuration, using the trick from equinox.
Arguments
fpath: filepath to save checkpoint to.
sae: sparse autoencoder checkpoint to save.
def load(fpath: str, *, device: str = 'cpu') ‑> SparseAutoencoder
-
Loads a sparse autoencoder from disk.
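A minimal round-trip sketch for dump and load. The config.SparseAutoencoder constructor call is an assumption (only the fields that SparseAutoencoder.__init__ reads are shown, with example sizes); the checkpoint path is a throwaway temp file.

```python
import pathlib
import tempfile

import torch

from saev import config, nn

# Assumed config call with example sizes; other fields are left at their defaults.
cfg = config.SparseAutoencoder(d_vit=768, d_sae=768 * 16)
sae = nn.SparseAutoencoder(cfg)

with tempfile.TemporaryDirectory() as tmp:
    fpath = str(pathlib.Path(tmp) / "sae.pt")
    nn.dump(fpath, sae)                       # writes weights together with the config
    restored = nn.load(fpath, device="cpu")   # returns a SparseAutoencoder
```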
def ref_mse(x_hat: jaxtyping.Float[Tensor, '*d'], x: jaxtyping.Float[Tensor, '*d'], norm: bool = True) ‑> jaxtyping.Float[Tensor, '*d']
-
Reference implementation of mean squared error between x_hat and x.
def safe_mse(x_hat: jaxtyping.Float[Tensor, '*batch d'], x: jaxtyping.Float[Tensor, '*batch d'], norm: bool = False) ‑> jaxtyping.Float[Tensor, '*batch d']
-
Mean squared error between x_hat and x, computed so that it stays stable even when values of x and x_hat are very large. This is the variant used in SparseAutoencoder.forward.
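Inside SparseAutoencoder.forward, the reconstruction loss is computed as safe_mse(x_hat, x).mean(). A small sketch of that pattern with stand-in tensors (the norm keyword is left at its default):

```python
import torch

from saev.nn import safe_mse

x = torch.randn(32, 768) * 1e4        # stand-in activations; values can be very large
x_hat = x + torch.randn(32, 768)      # stand-in reconstruction
err = safe_mse(x_hat, x)              # element-wise error, per the '*batch d' return shape
print(err.mean())                     # scalar reconstruction loss, as in forward()
```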
Classes
class Loss (mse: jaxtyping.Float[Tensor, ''], sparsity: jaxtyping.Float[Tensor, ''], ghost_grad: jaxtyping.Float[Tensor, ''], l0: jaxtyping.Float[Tensor, ''], l1: jaxtyping.Float[Tensor, ''])
-
The composite loss terms for an autoencoder training batch.
```python
class Loss(typing.NamedTuple):
    """The composite loss terms for an autoencoder training batch."""

    mse: Float[Tensor, ""]
    """Reconstruction loss (mean squared error)."""
    sparsity: Float[Tensor, ""]
    """Sparsity loss, typically lambda * L1."""
    ghost_grad: Float[Tensor, ""]
    """Ghost gradient loss, if any."""
    l0: Float[Tensor, ""]
    """L0 magnitude of hidden activations."""
    l1: Float[Tensor, ""]
    """L1 magnitude of hidden activations."""

    @property
    def loss(self) -> Float[Tensor, ""]:
        """Total loss."""
        return self.mse + self.sparsity + self.ghost_grad
```
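A short usage sketch: the loss property sums the mse, sparsity, and ghost_grad terms, while l0 and l1 are carried along for logging. The numeric values below are arbitrary placeholders.

```python
import torch

from saev.nn import Loss

terms = Loss(
    mse=torch.tensor(0.5),
    sparsity=torch.tensor(0.1),
    ghost_grad=torch.tensor(0.0),
    l0=torch.tensor(12.0),   # mean number of active latents
    l1=torch.tensor(3.4),    # mean L1 of the latent activations
)
print(terms.loss)  # tensor(0.6000) = mse + sparsity + ghost_grad
```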
Ancestors
- builtins.tuple
Instance variables
var ghost_grad : jaxtyping.Float[Tensor, '']
-
Ghost gradient loss, if any.
var l0 : jaxtyping.Float[Tensor, '']
-
L0 magnitude of hidden activations.
var l1 : jaxtyping.Float[Tensor, '']
-
L1 magnitude of hidden activations.
prop loss : jaxtyping.Float[Tensor, '']
-
Total loss.
```python
@property
def loss(self) -> Float[Tensor, ""]:
    """Total loss."""
    return self.mse + self.sparsity + self.ghost_grad
```
var mse : jaxtyping.Float[Tensor, '']
-
Reconstruction loss (mean squared error).
var sparsity : jaxtyping.Float[Tensor, '']
-
Sparsity loss, typically lambda * L1.
class SparseAutoencoder (cfg: SparseAutoencoder)
-
Sparse auto-encoder (SAE) using L1 sparsity penalty.
```python
@jaxtyped(typechecker=beartype.beartype)
class SparseAutoencoder(torch.nn.Module):
    """
    Sparse auto-encoder (SAE) using L1 sparsity penalty.
    """

    cfg: config.SparseAutoencoder

    def __init__(self, cfg: config.SparseAutoencoder):
        super().__init__()
        self.cfg = cfg

        self.W_enc = torch.nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(cfg.d_vit, cfg.d_sae))
        )
        self.b_enc = torch.nn.Parameter(torch.zeros(cfg.d_sae))

        self.W_dec = torch.nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(cfg.d_sae, cfg.d_vit))
        )
        self.b_dec = torch.nn.Parameter(torch.zeros(cfg.d_vit))

        self.logger = logging.getLogger(f"sae(seed={cfg.seed})")

    def forward(
        self, x: Float[Tensor, "batch d_model"]
    ) -> tuple[Float[Tensor, "batch d_model"], Float[Tensor, "batch d_sae"], Loss]:
        """
        Given x, calculates the reconstructed x_hat, the intermediate activations f_x, and the loss.

        Arguments:
            x: a batch of ViT activations.
        """
        # Remove encoder bias as per Anthropic
        h_pre = (
            einops.einsum(
                x - self.b_dec, self.W_enc, "... d_vit, d_vit d_sae -> ... d_sae"
            )
            + self.b_enc
        )
        f_x = torch.nn.functional.relu(h_pre)
        x_hat = (
            einops.einsum(f_x, self.W_dec, "... d_sae, d_sae d_vit -> ... d_vit")
            + self.b_dec
        )

        # Some values of x and x_hat can be very large. We can calculate a safe MSE
        mse_loss = safe_mse(x_hat, x)
        mse_loss = mse_loss.mean()

        l0 = (f_x > 0).float().sum(axis=1).mean(axis=0)
        l1 = f_x.sum(axis=1).mean(axis=0)
        sparsity_loss = self.cfg.sparsity_coeff * l1

        # Ghost loss is included for backwards compatibility.
        ghost_loss = torch.zeros_like(mse_loss)

        return x_hat, f_x, Loss(mse_loss, sparsity_loss, ghost_loss, l0, l1)

    @torch.no_grad()
    def init_b_dec(self, vit_acts: Float[Tensor, "n d_vit"]):
        if self.cfg.n_reinit_samples <= 0:
            self.logger.info("Skipping init_b_dec.")
            return
        previous_b_dec = self.b_dec.clone().cpu()
        vit_acts = vit_acts[: self.cfg.n_reinit_samples]
        assert len(vit_acts) == self.cfg.n_reinit_samples
        mean = vit_acts.mean(axis=0)
        previous_distances = torch.norm(vit_acts - previous_b_dec, dim=-1)
        distances = torch.norm(vit_acts - mean, dim=-1)
        self.logger.info(
            "Prev dist: %.3f; new dist: %.3f",
            previous_distances.median(axis=0).values.mean().item(),
            distances.median(axis=0).values.mean().item(),
        )
        self.b_dec.data = mean.to(self.b_dec.dtype).to(self.b_dec.device)

    @torch.no_grad()
    def normalize_w_dec(self):
        """
        Set W_dec to unit-norm columns.
        """
        if self.cfg.normalize_w_dec:
            self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

    @torch.no_grad()
    def remove_parallel_grads(self):
        """
        Update grads so that they remove the parallel component
            (d_sae, d_vit) shape
        """
        if not self.cfg.remove_parallel_grads:
            return

        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_vit, d_sae d_vit -> d_sae",
        )

        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_vit -> d_sae d_vit",
        )
```
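A forward-pass sketch. The config.SparseAutoencoder constructor call below is an assumption (only d_vit and d_sae are shown, with example sizes; any other fields are left at their defaults).

```python
import torch

from saev import config, nn

# Assumed config call with example sizes.
cfg = config.SparseAutoencoder(d_vit=768, d_sae=768 * 16)
sae = nn.SparseAutoencoder(cfg)

x = torch.randn(32, 768)       # a batch of ViT activations
x_hat, f_x, loss = sae(x)      # reconstruction, sparse codes, Loss tuple
print(x_hat.shape)             # torch.Size([32, 768])
print(f_x.shape)               # torch.Size([32, 12288])
print(loss.loss.item())        # mse + sparsity (+ ghost_grad, which is zero here)
```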
Ancestors
- torch.nn.modules.module.Module
Class variables
var cfg : SparseAutoencoder
Methods
def forward(self, x: jaxtyping.Float[Tensor, 'batch d_model']) ‑> tuple[jaxtyping.Float[Tensor, 'batch d_model'], jaxtyping.Float[Tensor, 'batch d_sae'], Loss]
-
Given x, calculates the reconstructed x_hat, the intermediate activations f_x, and the loss.
Arguments
x: a batch of ViT activations.
def init_b_dec(self, vit_acts: jaxtyping.Float[Tensor, 'n d_vit'])
-
Initialize b_dec to the mean of a sample of ViT activations (the first cfg.n_reinit_samples rows); skipped when cfg.n_reinit_samples <= 0.
def normalize_w_dec(self)
-
Set W_dec to unit-norm columns.
def remove_parallel_grads(self)
-
Update W_dec.grad (shape (d_sae, d_vit)) so that the component parallel to each decoder direction is removed; a no-op unless cfg.remove_parallel_grads is set.
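A hedged sketch of how normalize_w_dec and remove_parallel_grads might fit into a single training step. The ordering follows common SAE training practice rather than anything documented here, and the config constructor call is again an assumption; both methods are no-ops unless the corresponding cfg flags are enabled.

```python
import torch

from saev import config, nn

# Assumed config call with example sizes.
cfg = config.SparseAutoencoder(d_vit=768, d_sae=768 * 16)
sae = nn.SparseAutoencoder(cfg)
opt = torch.optim.Adam(sae.parameters(), lr=1e-4)

batch = torch.randn(32, 768)   # stand-in for real ViT activations
x_hat, f_x, loss = sae(batch)
loss.loss.backward()

sae.remove_parallel_grads()    # strip grad components parallel to decoder directions
opt.step()
opt.zero_grad()
sae.normalize_w_dec()          # re-normalize decoder directions
```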