>>>> CONVENTIONS.md

# Conventions

This document outlines some programming conventions that are not caught by automated tools.

* File descriptors from `open()` are called `fd`.
* Use types where possible, including `jaxtyping` hints.
* Decorate functions with `beartype.beartype` unless they use a `jaxtyping` hint, in which case use `jaxtyped(typechecker=beartype.beartype)`.
* Variables referring to a filepath should be suffixed with `_fpath`. Directories are `_dpath`.
* Prefer `make` over `build` when naming functions that construct objects, and use `get` when constructing primitives (like string paths or config values).
* Only use `setup` for naming functions that don't return anything.

Throughout the code, variables are annotated with shape suffixes, as [recommended by Noam Shazeer](https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd). The key for these suffixes:

* B: batch size
* W: width in patches (typically 14 or 16)
* H: height in patches (typically 14 or 16)
* D: ViT activation dimension (typically 768 or 1024)
* S: SAE latent dimension (768 x 16, etc.)
* L: number of latents being manipulated at once (typically 1-5 at a time)
* C: number of classes in ADE20K (151)

For example, an activation tensor with shape (batch, width, height, d_vit) is `acts_BWHD`.

>>>> README.md

# saev - Sparse Auto-Encoders for Vision

![Coverage](docs/coverage.svg)

Sparse autoencoders (SAEs) for vision transformers (ViTs), implemented in PyTorch.

This is the codebase used for our preprint "Sparse Autoencoders for Scientifically Rigorous Interpretation of Vision Models":

* [arXiv preprint](https://arxiv.org/abs/2502.06755)
* [Huggingface Models](https://huggingface.co/collections/osunlp/sae-v-67ab8c4fdf179d117db28195)
* [API Docs](https://osu-nlp-group.github.io/saev/saev)
* [Demos](https://osu-nlp-group.github.io/saev/#demos)

## About

saev is a package for training sparse autoencoders (SAEs) on vision transformers (ViTs) in PyTorch. It also includes an interactive webapp for looking through a trained SAE's features.

Originally forked from [HugoFry](https://github.com/HugoFry/mats_sae_training_for_ViTs), who forked it from [Joseph Bloom](https://github.com/jbloomAus/SAELens).

Read [logbook.md](logbook.md) for a detailed log of my thought process.

See [related-work.md](saev/related-work.md) for a list of works training SAEs on vision models. Please open an issue or a PR if there is missing work.

## Installation

Installation is supported with [uv](https://docs.astral.sh/uv/). saev will likely work with pure pip, conda, etc., but I will not formally support it.

Clone this repository (or fork it), then from the root directory:

```bash
uv run python -m saev --help
```

This will create a virtual environment and display the CLI help.

## Using `saev`

See the [docs](https://osu-nlp-group.github.io/saev/saev) for an overview. You can ask questions about this repo using the `llms.txt` file. Example (macOS): `curl https://osu-nlp-group.github.io/saev/llms.txt | pbcopy`, then paste into [Claude](https://claude.ai) or any LLM interface of your choice.

>>>> __init__.py

"""
saev is a Python package for training sparse autoencoders (SAEs) on vision transformers (ViTs) in PyTorch.

The main entrypoint to the package is in `__main__`; use `python -m saev --help` to see the options and documentation for the script.

# Tutorials

.. include:: ./guide.md

# How-To Guides

.. include:: ./reproduce.md

# Explanations

.. include:: ./related-work.md

..
include:: ./inference.md """ import importlib.metadata import pathlib import tomllib # std-lib in Python ≥3.11 def _version_from_pyproject() -> str: """ Parse `[project].version` out of pyproject.toml that sits two directories above this file: saev/__init__.py saev/ pyproject.toml Returns "0.0.0+unknown" on any error. """ try: pp = pathlib.Path(__file__).resolve().parents[1] / "pyproject.toml" with pp.open("rb") as f: data = tomllib.load(f) return data["project"]["version"] except Exception: # key missing, file missing, bad TOML, ... return "0.0.0+unknown" try: __version__ = importlib.metadata.version("saev") # installed wheel / editable except importlib.metadata.PackageNotFoundError: __version__ = _version_from_pyproject() # running from source tree >>>> __main__.py import logging import tomllib import typing import beartype import tyro from . import config log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" logging.basicConfig(level=logging.INFO, format=log_format) logger = logging.getLogger("saev") @beartype.beartype def activations(cfg: typing.Annotated[config.Activations, tyro.conf.arg(name="")]): """ Save ViT activations for use later on. Args: cfg: Configuration for activations. """ import saev.activations saev.activations.main(cfg) @beartype.beartype def train( cfg: typing.Annotated[config.Train, tyro.conf.arg(name="")], sweep: str | None = None, ): """ Train an SAE over activations, optionally running a parallel grid search over a set of hyperparameters. Args: cfg: Baseline config for training an SAE. sweep: Path to .toml file defining the sweep parameters. """ import submitit from . import config, training if sweep is not None: with open(sweep, "rb") as fd: cfgs, errs = config.grid(cfg, tomllib.load(fd)) if errs: for err in errs: logger.warning("Error in config: %s", err) return else: cfgs = [cfg] cfgs = training.split_cfgs(cfgs) logger.info("Running %d training jobs.", len(cfgs)) if cfg.slurm: executor = submitit.SlurmExecutor(folder=cfg.log_to) executor.update_parameters( time=60, partition="preemptible", gpus_per_node=1, cpus_per_task=cfg.n_workers + 4, stderr_to_stdout=True, account=cfg.slurm_acct, ) else: executor = submitit.DebugExecutor(folder=cfg.log_to) jobs = [executor.submit(training.main, group) for group in cfgs] for job in jobs: job.result() @beartype.beartype def visuals(cfg: typing.Annotated[config.Visuals, tyro.conf.arg(name="")]): """ Save maximally activating images for each SAE latent. Args: cfg: Config """ from . import visuals visuals.main(cfg) if __name__ == "__main__": tyro.extras.subcommand_cli_from_dict({ "activations": activations, "train": train, "visuals": visuals, }) logger.info("Done.") >>>> activations.py """ To save lots of activations, we want to do things in parallel, with lots of slurm jobs, and save multiple files, rather than just one. This module handles that additional complexity. Conceptually, activations are either thought of as 1. A single [n_imgs x n_layers x (n_patches + 1), d_vit] tensor. This is a *dataset* 2. Multiple [n_imgs_per_shard, n_layers, (n_patches + 1), d_vit] tensors. This is a set of sharded activations. """ import dataclasses import hashlib import json import logging import math import os import typing from collections.abc import Callable import beartype import numpy as np import torch import torchvision.datasets from jaxtyping import Float, jaxtyped from PIL import Image from torch import Tensor from . 
import config, helpers log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" logging.basicConfig(level=logging.INFO, format=log_format) logger = logging.getLogger(__name__) ####################### # VISION TRANSFORMERS # ####################### @jaxtyped(typechecker=beartype.beartype) class RecordedVisionTransformer(torch.nn.Module): _storage: Float[Tensor, "batch n_layers all_patches dim"] | None _i: int def __init__( self, vit: torch.nn.Module, n_patches_per_img: int, cls_token: bool, layers: list[int], ): super().__init__() self.vit = vit self.n_patches_per_img = n_patches_per_img self.cls_token = cls_token self.layers = layers self.patches = vit.get_patches(n_patches_per_img) self._storage = None self._i = 0 self.logger = logging.getLogger(f"recorder({vit.name})") for i in self.layers: self.vit.get_residuals()[i].register_forward_hook(self.hook) def hook( self, module, args: tuple, output: Float[Tensor, "batch n_layers dim"] ) -> None: if self._storage is None: batch, _, dim = output.shape self._storage = self._empty_storage(batch, dim, output.device) if self._storage[:, self._i, 0, :].shape != output[:, 0, :].shape: batch, _, dim = output.shape old_batch, _, _, old_dim = self._storage.shape msg = "Output shape does not match storage shape: (batch) %d != %d or (dim) %d != %d" self.logger.warning(msg, old_batch, batch, old_dim, dim) self._storage = self._empty_storage(batch, dim, output.device) self._storage[:, self._i] = output[:, self.patches, :].detach() self._i += 1 def _empty_storage(self, batch: int, dim: int, device: torch.device): n_patches_per_img = self.n_patches_per_img if self.cls_token: n_patches_per_img += 1 return torch.zeros( (batch, len(self.layers), n_patches_per_img, dim), device=device ) def reset(self): self._i = 0 @property def activations(self) -> Float[Tensor, "batch n_layers all_patches dim"]: if self._storage is None: raise RuntimeError("First call forward()") return self._storage.cpu() def forward( self, batch: Float[Tensor, "batch 3 width height"] ) -> tuple[ Float[Tensor, "batch patches dim"], Float[Tensor, "batch n_layers all_patches dim"], ]: self.reset() result = self.vit(batch) return result, self.activations @jaxtyped(typechecker=beartype.beartype) class Clip(torch.nn.Module): def __init__(self, vit_ckpt: str): super().__init__() import open_clip if vit_ckpt.startswith("hf-hub:"): clip, _ = open_clip.create_model_from_pretrained( vit_ckpt, cache_dir=helpers.get_cache_dir() ) else: arch, ckpt = vit_ckpt.split("/") clip, _ = open_clip.create_model_from_pretrained( arch, pretrained=ckpt, cache_dir=helpers.get_cache_dir() ) model = clip.visual model.proj = None model.output_tokens = True # type: ignore self.model = model.eval() assert not isinstance(self.model, open_clip.timm_model.TimmModel) self.name = f"clip/{vit_ckpt}" def get_residuals(self) -> list[torch.nn.Module]: return self.model.transformer.resblocks def get_patches(self, cfg: config.Activations) -> slice: return slice(None, None, None) def forward( self, batch: Float[Tensor, "batch 3 width height"] ) -> Float[Tensor, "batch patches dim"]: cls, patches = self.model(batch) return {"cls": cls, "patches": patches} @jaxtyped(typechecker=beartype.beartype) class Siglip(torch.nn.Module): def __init__(self, vit_ckpt: str): super().__init__() import open_clip if vit_ckpt.startswith("hf-hub:"): clip, _ = open_clip.create_model_from_pretrained( vit_ckpt, cache_dir=helpers.get_cache_dir() ) else: arch, ckpt = vit_ckpt.split("/") clip, _ = open_clip.create_model_from_pretrained( arch, 
pretrained=ckpt, cache_dir=helpers.get_cache_dir() ) model = clip.visual model.proj = None model.output_tokens = True # type: ignore self.model = model assert isinstance(self.model, open_clip.timm_model.TimmModel) def get_residuals(self) -> list[torch.nn.Module]: return self.model.trunk.blocks def get_patches(self, cfg: config.Activations) -> slice: return slice(None, None, None) def forward( self, batch: Float[Tensor, "batch 3 width height"] ) -> Float[Tensor, "batch patches dim"]: result = self.model(batch) return result @jaxtyped(typechecker=beartype.beartype) class DinoV2(torch.nn.Module): def __init__(self, vit_ckpt: str): super().__init__() self.model = torch.hub.load("facebookresearch/dinov2", vit_ckpt) self.name = f"dinov2/{vit_ckpt}" def get_residuals(self) -> list[torch.nn.Module]: return self.model.blocks def get_patches(self, n_patches_per_img: int) -> slice: n_reg = self.model.num_register_tokens patches = torch.cat(( torch.tensor([0]), # CLS token torch.arange(n_reg + 1, n_reg + 1 + n_patches_per_img), # patches )) return patches def forward( self, batch: Float[Tensor, "batch 3 width height"] ) -> Float[Tensor, "batch patches dim"]: dct = self.model.forward_features(batch) features = torch.cat( (dct["x_norm_clstoken"][:, None, :], dct["x_norm_patchtokens"]), axis=1 ) return features @jaxtyped(typechecker=beartype.beartype) class Moondream2(torch.nn.Module): """ Moondream2 has 14x14 pixel patches. For a 378x378 image (as we use here), this is 27x27 patches for a total of 729, with no [CLS] token. """ def __init__(self, vit_ckpt: str): super().__init__() import transformers vit_id, revision = vit_ckpt.split(":") mllm = transformers.AutoModelForCausalLM.from_pretrained( vit_id, revision=revision, trust_remote_code=True ) self.model = mllm.vision_encoder.encoder.model.visual def get_patches(self, cfg: config.Activations) -> slice: return slice(None, None, None) def get_residuals(self) -> list[torch.nn.Module]: return self.model.blocks def forward( self, batch: Float[Tensor, "batch 3 width height"] ) -> Float[Tensor, "batch patches dim"]: features = self.model(batch) return features @beartype.beartype def make_vit(vit_family: str, vit_ckpt: str): if vit_family == "clip": return Clip(vit_ckpt) elif vit_family == "siglip": return Siglip(vit_ckpt) elif vit_family == "dinov2": return DinoV2(vit_ckpt) elif vit_family == "moondream2": return Moondream2(vit_ckpt) else: typing.assert_never(vit_family) @beartype.beartype def make_img_transform(vit_family: str, vit_ckpt: str) -> Callable: if vit_family == "clip" or vit_family == "siglip": import open_clip if vit_ckpt.startswith("hf-hub:"): _, img_transform = open_clip.create_model_from_pretrained( vit_ckpt, cache_dir=helpers.get_cache_dir() ) else: arch, ckpt = vit_ckpt.split("/") _, img_transform = open_clip.create_model_from_pretrained( arch, pretrained=ckpt, cache_dir=helpers.get_cache_dir() ) return img_transform elif vit_family == "dinov2": from torchvision.transforms import v2 return v2.Compose([ # TODO: I bet this should be 256, 256, which is causing localization issues in non-square images. 
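# Clarifying the TODO above: `v2.Resize(size=256)` (an int) resizes the *shorter*
# edge to 256 px while preserving aspect ratio, so the 224x224 center crop below
# discards the borders of non-square images. `v2.Resize(size=(256, 256))` would
# instead force a square resize, distorting aspect ratio but keeping every region
# of the image inside the crop.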
v2.Resize(size=256), v2.CenterCrop(size=(224, 224)), v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]), ]) elif vit_family == "moondream2": from torchvision.transforms import v2 # Assume fixed image ratio, 378x378 return v2.Compose([ v2.Resize(size=(378, 378), interpolation=v2.InterpolationMode.BICUBIC), v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ]) else: typing.assert_never(vit_family) ############### # ACTIVATIONS # ############### @jaxtyped(typechecker=beartype.beartype) class Dataset(torch.utils.data.Dataset): """ Dataset of activations from disk. """ class Example(typing.TypedDict): """Individual example.""" act: Float[Tensor, " d_vit"] image_i: int patch_i: int cfg: config.DataLoad """Configuration; set via CLI args.""" metadata: "Metadata" """Activations metadata; automatically loaded from disk.""" layer_index: int """Layer index into the shards if we are choosing a specific layer.""" scalar: float """Normalizing scalar such that ||x / scalar ||_2 ~= sqrt(d_vit).""" act_mean: Float[Tensor, " d_vit"] """Mean activation.""" def __init__(self, cfg: config.DataLoad): self.cfg = cfg if not os.path.isdir(self.cfg.shard_root): raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.") metadata_fpath = os.path.join(self.cfg.shard_root, "metadata.json") self.metadata = Metadata.load(metadata_fpath) # Pick a really big number so that if you accidentally use this when you shouldn't, you get an out of bounds IndexError. self.layer_index = 1_000_000 if isinstance(self.cfg.layer, int): err_msg = f"Non-exact matches for .layer field not supported; {self.cfg.layer} not in {self.metadata.layers}." assert self.cfg.layer in self.metadata.layers, err_msg self.layer_index = self.metadata.layers.index(self.cfg.layer) # Premptively set these values so that preprocess() doesn't freak out. self.scalar = 1.0 self.act_mean = torch.zeros(self.d_vit) # If either of these are true, we must do this work. if self.cfg.scale_mean is True or self.cfg.scale_norm is True: # Load a random subset of samples to calculate the mean activation and mean L2 norm. perm = np.random.default_rng(seed=42).permutation(len(self)) perm = perm[: cfg.n_random_samples] samples = [ self[p.item()] for p in helpers.progress( perm, every=25_000, desc="examples to calc means" ) ] samples = torch.stack([sample["act"] for sample in samples]) if samples.abs().max() > 1e3: raise ValueError( "You found an abnormally large activation {example.abs().max().item():.5f} that will mess up your L2 mean." ) # Activation mean if self.cfg.scale_mean: self.act_mean = samples.mean(axis=0) if (self.act_mean > 1e3).any(): raise ValueError( "You found an abnormally large activation that is messing up your activation mean." ) # Norm if self.cfg.scale_norm: l2_mean = torch.linalg.norm(samples - self.act_mean, axis=1).mean() if l2_mean > 1e3: raise ValueError( "You found an abnormally large activation that is messing up your L2 mean." ) self.scalar = l2_mean / math.sqrt(self.d_vit) elif isinstance(self.cfg.scale_mean, str): # Load mean activations from disk self.act_mean = torch.load(self.cfg.scale_mean) elif isinstance(self.cfg.scale_norm, str): # Load scalar normalization from disk self.scalar = torch.load(self.cfg.scale_norm).item() def transform(self, act: Float[np.ndarray, " d_vit"]) -> Float[Tensor, " d_vit"]: """ Apply a scalar normalization so the mean squared L2 norm is same as d_vit. 
This is from 'Scaling Monosemanticity': > As a preprocessing step we apply a scalar normalization to the model activations so their average squared L2 norm is the residual stream dimension So we divide by self.scalar which is the datasets (approximate) L2 mean before normalization divided by sqrt(d_vit). """ act = torch.from_numpy(act.copy()) act = act.clamp(-self.cfg.clamp, self.cfg.clamp) return (act - self.act_mean) / self.scalar @property def d_vit(self) -> int: """Dimension of the underlying vision transformer's embedding space.""" return self.metadata.d_vit @jaxtyped(typechecker=beartype.beartype) def __getitem__(self, i: int) -> Example: match (self.cfg.patches, self.cfg.layer): case ("cls", int()): img_act = self.get_img_patches(i) # Select layer's cls token. act = img_act[self.layer_index, 0, :] return self.Example(act=self.transform(act), image_i=i, patch_i=-1) case ("cls", "meanpool"): img_act = self.get_img_patches(i) # Select cls tokens from across all layers cls_act = img_act[:, 0, :] # Meanpool over the layers act = cls_act.mean(axis=0) return self.Example(act=self.transform(act), image_i=i, patch_i=-1) case ("meanpool", int()): img_act = self.get_img_patches(i) # Select layer's patches. layer_act = img_act[self.layer_index, 1:, :] # Meanpool over the patches act = layer_act.mean(axis=0) return self.Example(act=self.transform(act), image_i=i, patch_i=-1) case ("meanpool", "meanpool"): img_act = self.get_img_patches(i) # Select all layer's patches. act = img_act[:, 1:, :] # Meanpool over the layers and patches act = act.mean(axis=(0, 1)) return self.Example(act=self.transform(act), image_i=i, patch_i=-1) case ("patches", int()): n_imgs_per_shard = ( self.metadata.n_patches_per_shard // len(self.metadata.layers) // (self.metadata.n_patches_per_img + 1) ) n_examples_per_shard = ( n_imgs_per_shard * self.metadata.n_patches_per_img ) shard = i // n_examples_per_shard pos = i % n_examples_per_shard acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin") shape = ( n_imgs_per_shard, len(self.metadata.layers), self.metadata.n_patches_per_img + 1, self.metadata.d_vit, ) acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape) # Choose the layer and the non-CLS tokens. acts = acts[:, self.layer_index, 1:] # Choose a patch among n and the patches. act = acts[ pos // self.metadata.n_patches_per_img, pos % self.metadata.n_patches_per_img, ] return self.Example( act=self.transform(act), # What image is this? image_i=i // self.metadata.n_patches_per_img, patch_i=i % self.metadata.n_patches_per_img, ) case _: print((self.cfg.patches, self.cfg.layer)) typing.assert_never((self.cfg.patches, self.cfg.layer)) def get_shard_patches(self): raise NotImplementedError() def get_img_patches( self, i: int ) -> Float[np.ndarray, "n_layers all_patches d_vit"]: n_imgs_per_shard = ( self.metadata.n_patches_per_shard // len(self.metadata.layers) // (self.metadata.n_patches_per_img + 1) ) shard = i // n_imgs_per_shard pos = i % n_imgs_per_shard acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin") shape = ( n_imgs_per_shard, len(self.metadata.layers), self.metadata.n_patches_per_img + 1, self.metadata.d_vit, ) acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape) # Note that this is not yet copied! return acts[pos] def __len__(self) -> int: """ Dataset length depends on `patches` and `layer`. """ match (self.cfg.patches, self.cfg.layer): case ("cls", "all"): # Return a CLS token from a random image and random layer. 
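# Worked example of the arithmetic below, assuming hypothetical metadata with
# n_imgs=1_281_167, layers=(-2, -1), and n_patches_per_img=256:
#   ("cls", "all")     -> 1_281_167 * 2          (one CLS example per image per layer)
#   ("cls", -2)        -> 1_281_167              (one CLS example per image)
#   ("patches", -2)    -> 1_281_167 * 256        (one example per patch)
#   ("patches", "all") -> 1_281_167 * 2 * 256    (one example per patch per layer)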
return self.metadata.n_imgs * len(self.metadata.layers) case ("cls", int()): # Return a CLS token from a random image and fixed layer. return self.metadata.n_imgs case ("cls", "meanpool"): # Return a CLS token from a random image and meanpool over all layers. return self.metadata.n_imgs case ("meanpool", "all"): # Return the meanpool of all patches from a random image and random layer. return self.metadata.n_imgs * len(self.metadata.layers) case ("meanpool", int()): # Return the meanpool of all patches from a random image and fixed layer. return self.metadata.n_imgs case ("meanpool", "meanpool"): # Return the meanpool of all patches from a random image and meanpool over all layers. return self.metadata.n_imgs case ("patches", int()): # Return a patch from a random image, fixed layer, and random patch. return self.metadata.n_imgs * (self.metadata.n_patches_per_img) case ("patches", "meanpool"): # Return a patch from a random image, meanpooled over all layers, and a random patch. return self.metadata.n_imgs * (self.metadata.n_patches_per_img) case ("patches", "all"): # Return a patch from a random image, random layer and random patch. return ( self.metadata.n_imgs * len(self.metadata.layers) * self.metadata.n_patches_per_img ) case _: typing.assert_never((self.cfg.patches, self.cfg.layer)) ########## # IMAGES # ########## @beartype.beartype def setup(cfg: config.Activations): """ Run dataset-specific setup. These setup functions can assume they are the only job running, but they should be idempotent; they should be safe (and ideally cheap) to run multiple times in a row. """ if isinstance(cfg.data, config.ImagenetDataset): setup_imagenet(cfg) elif isinstance(cfg.data, config.ImageFolderDataset): setup_imagefolder(cfg) elif isinstance(cfg.data, config.Ade20kDataset): setup_ade20k(cfg) else: typing.assert_never(cfg.data) @beartype.beartype def setup_imagenet(cfg: config.Activations): assert isinstance(cfg.data, config.ImagenetDataset) @beartype.beartype def setup_imagefolder(cfg: config.Activations): assert isinstance(cfg.data, config.ImageFolderDataset) logger.info("No dataset-specific setup for ImageFolder.") @beartype.beartype def setup_ade20k(cfg: config.Activations): assert isinstance(cfg.data, config.Ade20kDataset) # url = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" # breakpoint() # 1. Check @beartype.beartype def get_dataset(cfg: config.DatasetConfig, *, img_transform): """ Gets the dataset for the current experiment; delegates construction to dataset-specific functions. Args: cfg: Experiment config. img_transform: Image transform to be applied to each image. Returns: A dataset that has dictionaries with `'image'`, `'index'`, `'target'`, and `'label'` keys containing examples. """ if isinstance(cfg, config.ImagenetDataset): return Imagenet(cfg, img_transform=img_transform) elif isinstance(cfg, config.Ade20kDataset): return Ade20k(cfg, img_transform=img_transform) elif isinstance(cfg, config.ImageFolderDataset): return ImageFolder(cfg.root, transform=img_transform) else: typing.assert_never(cfg) @beartype.beartype def get_dataloader(cfg: config.Activations, *, img_transform=None): """ Gets the dataloader for the current experiment; delegates dataloader construction to dataset-specific functions. Args: cfg: Experiment config. img_transform: Image transform to be applied to each image. Returns: A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches. 
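Example (a sketch; assumes the default CLIP checkpoint can be downloaded and that
an image-folder dataset exists at the default root):

    cfg = config.Activations(data=config.ImageFolderDataset(root="./data/split"))
    img_transform = make_img_transform(cfg.vit_family, cfg.vit_ckpt)
    dataloader = get_dataloader(cfg, img_transform=img_transform)
    batch = next(iter(dataloader))
    images = batch["image"]  # [vit_batch_size, 3, height, width]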
""" if isinstance( cfg.data, (config.ImagenetDataset, config.ImageFolderDataset, config.Ade20kDataset), ): dataloader = get_default_dataloader(cfg, img_transform=img_transform) else: typing.assert_never(cfg.data) return dataloader @beartype.beartype def get_default_dataloader( cfg: config.Activations, *, img_transform: Callable ) -> torch.utils.data.DataLoader: """ Get a dataloader for a default map-style dataset. Args: cfg: Config. img_transform: Image transform to be applied to each image. Returns: A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches, `'index'` keys containing original dataset indices and `'label'` keys containing label batches. """ dataset = get_dataset(cfg.data, img_transform=img_transform) dataloader = torch.utils.data.DataLoader( dataset=dataset, batch_size=cfg.vit_batch_size, drop_last=False, num_workers=cfg.n_workers, persistent_workers=cfg.n_workers > 0, shuffle=False, pin_memory=False, ) return dataloader @beartype.beartype class Imagenet(torch.utils.data.Dataset): def __init__(self, cfg: config.ImagenetDataset, *, img_transform=None): import datasets self.hf_dataset = datasets.load_dataset( cfg.name, split=cfg.split, trust_remote_code=True ) self.img_transform = img_transform self.labels = self.hf_dataset.info.features["label"].names def __getitem__(self, i): example = self.hf_dataset[i] example["index"] = i example["image"] = example["image"].convert("RGB") if self.img_transform: example["image"] = self.img_transform(example["image"]) example["target"] = example.pop("label") example["label"] = self.labels[example["target"]] return example def __len__(self) -> int: return len(self.hf_dataset) @beartype.beartype class ImageFolder(torchvision.datasets.ImageFolder): def __getitem__(self, index: int) -> dict[str, object]: """ Args: index: Index Returns: dict with keys 'image', 'index', 'target' and 'label'. """ path, target = self.samples[index] sample = self.loader(path) if self.transform is not None: sample = self.transform(sample) if self.target_transform is not None: target = self.target_transform(target) return { "image": sample, "target": target, "label": self.classes[target], "index": index, } @beartype.beartype class Ade20k(torch.utils.data.Dataset): @beartype.beartype @dataclasses.dataclass(frozen=True) class Sample: img_path: str seg_path: str label: str target: int samples: list[Sample] def __init__( self, cfg: config.Ade20kDataset, *, img_transform: Callable | None = None, seg_transform: Callable | None = lambda x: None, ): self.logger = logging.getLogger("ade20k") self.cfg = cfg self.img_dir = os.path.join(cfg.root, "images") self.seg_dir = os.path.join(cfg.root, "annotations") self.img_transform = img_transform self.seg_transform = seg_transform # Check that we have the right path. for subdir in ("images", "annotations"): if not os.path.isdir(os.path.join(cfg.root, subdir)): # Something is missing. if os.path.realpath(cfg.root).endswith(subdir): self.logger.warning( "The ADE20K root should contain 'images/' and 'annotations/' directories." ) raise ValueError(f"Can't find path '{os.path.join(cfg.root, subdir)}'.") _, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir) split_lookup: dict[int, str] = { value: key for key, value in split_mapping.items() } self.loader = torchvision.datasets.folder.default_loader assert cfg.split in set(split_lookup.values()) # Load all the image paths. 
imgs: list[str] = [ path for path, s in torchvision.datasets.folder.make_dataset( self.img_dir, split_mapping, extensions=torchvision.datasets.folder.IMG_EXTENSIONS, ) if split_lookup[s] == cfg.split ] segs: list[str] = [ path for path, s in torchvision.datasets.folder.make_dataset( self.seg_dir, split_mapping, extensions=torchvision.datasets.folder.IMG_EXTENSIONS, ) if split_lookup[s] == cfg.split ] # Load all the targets, classes and mappings with open(os.path.join(cfg.root, "sceneCategories.txt")) as fd: img_labels: list[str] = [line.split()[1] for line in fd.readlines()] label_set = sorted(set(img_labels)) label_to_idx = {label: i for i, label in enumerate(label_set)} self.samples = [ self.Sample(img_path, seg_path, label, label_to_idx[label]) for img_path, seg_path, label in zip(imgs, segs, img_labels) ] def __getitem__(self, index: int) -> dict[str, object]: # Convert to dict. sample = dataclasses.asdict(self.samples[index]) sample["image"] = self.loader(sample.pop("img_path")) if self.img_transform is not None: image = self.img_transform(sample.pop("image")) if image is not None: sample["image"] = image sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L") if self.seg_transform is not None: segmentation = self.seg_transform(sample.pop("segmentation")) if segmentation is not None: sample["segmentation"] = segmentation sample["index"] = index return sample def __len__(self) -> int: return len(self.samples) ######## # MAIN # ######## @beartype.beartype def main(cfg: config.Activations): """ Args: cfg: Config for activations. """ logger = logging.getLogger("dump") if not cfg.ssl: logger.warning("Ignoring SSL certs. Try not to do this!") # https://github.com/openai/whisper/discussions/734#discussioncomment-4491761 # Ideally we don't have to disable SSL but we are only downloading weights. import ssl ssl._create_default_https_context = ssl._create_unverified_context # Run any setup steps. setup(cfg) # Actually record activations. if cfg.slurm: import submitit executor = submitit.SlurmExecutor(folder=cfg.log_to) executor.update_parameters( time=24 * 60, partition="gpu", gpus_per_node=1, cpus_per_task=cfg.n_workers + 4, stderr_to_stdout=True, account=cfg.slurm_acct, ) job = executor.submit(worker_fn, cfg) logger.info("Running job '%s'.", job.job_id) job.result() else: worker_fn(cfg) @beartype.beartype def worker_fn(cfg: config.Activations): """ Args: cfg: Config for activations. """ if torch.cuda.is_available(): # This enables tf32 on Ampere GPUs which is only 8% slower than # float16 and almost as accurate as float32 # This was a default in pytorch until 1.12 torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False logger = logging.getLogger("dump") vit = make_vit(cfg.vit_family, cfg.vit_ckpt).to(cfg.device) vit = RecordedVisionTransformer( vit, cfg.n_patches_per_img, cfg.cls_token, cfg.vit_layers ) img_transform = make_img_transform(cfg.vit_family, cfg.vit_ckpt) dataloader = get_dataloader(cfg, img_transform=img_transform) writer = ShardWriter(cfg) n_batches = cfg.data.n_imgs // cfg.vit_batch_size + 1 logger.info("Dumping %d batches of %d examples.", n_batches, cfg.vit_batch_size) if cfg.device == "cuda" and not torch.cuda.is_available(): logger.warning("No CUDA device available, using CPU.") cfg = dataclasses.replace(cfg, device="cpu") vit = vit.to(cfg.device) # vit = torch.compile(vit) i = 0 # Calculate and write ViT activations. 
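# Each shard written below is a float32 memmap of shape
# [n_imgs_per_shard, len(vit_layers), n_patches_per_img (+1 if cls_token), d_vit],
# stored alongside a metadata.json under dump_to/<metadata hash>/. Sketch of the
# resulting layout (the hash is a placeholder):
#   shards/3b5a.../metadata.json
#   shards/3b5a.../acts000000.bin
#   shards/3b5a.../acts000001.bin
# The Dataset class above reads these shards back with np.memmap.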
with torch.inference_mode(): for batch in helpers.progress(dataloader, total=n_batches): images = batch.pop("image").to(cfg.device) # cache has shape [batch size, n layers, n patches + 1, d vit] out, cache = vit(images) del out writer[i : i + len(cache)] = cache i += len(cache) writer.flush() @beartype.beartype class ShardWriter: """ ShardWriter is a stateful object that handles sharded activation writing to disk. """ root: str shape: tuple[int, int, int, int] shard: int acts_path: str acts: Float[np.ndarray, "n_imgs_per_shard n_layers all_patches d_vit"] | None filled: int def __init__(self, cfg: config.Activations): self.logger = logging.getLogger("shard-writer") self.root = get_acts_dir(cfg) n_patches_per_img = cfg.n_patches_per_img if cfg.cls_token: n_patches_per_img += 1 self.n_imgs_per_shard = ( cfg.n_patches_per_shard // len(cfg.vit_layers) // n_patches_per_img ) self.shape = ( self.n_imgs_per_shard, len(cfg.vit_layers), n_patches_per_img, cfg.d_vit, ) self.shard = -1 self.acts = None self.next_shard() @jaxtyped(typechecker=beartype.beartype) def __setitem__( self, i: slice, val: Float[Tensor, "_ n_layers all_patches d_vit"] ) -> None: assert i.step is None a, b = i.start, i.stop assert len(val) == b - a offset = self.n_imgs_per_shard * self.shard if b >= offset + self.n_imgs_per_shard: # We have run out of space in this mmap'ed file. Let's fill it as much as we can. n_fit = offset + self.n_imgs_per_shard - a self.acts[a - offset : a - offset + n_fit] = val[:n_fit] self.filled = a - offset + n_fit self.next_shard() # Recursively call __setitem__ in case we need *another* shard self[a + n_fit : b] = val[n_fit:] else: msg = f"0 <= {a} - {offset} <= {offset} + {self.n_imgs_per_shard}" assert 0 <= a - offset <= offset + self.n_imgs_per_shard, msg msg = f"0 <= {b} - {offset} <= {offset} + {self.n_imgs_per_shard}" assert 0 <= b - offset <= offset + self.n_imgs_per_shard, msg self.acts[a - offset : b - offset] = val self.filled = b - offset def flush(self) -> None: if self.acts is not None: self.acts.flush() self.acts = None def next_shard(self) -> None: self.flush() self.shard += 1 self._count = 0 self.acts_path = os.path.join(self.root, f"acts{self.shard:06}.bin") self.acts = np.memmap( self.acts_path, mode="w+", dtype=np.float32, shape=self.shape ) self.filled = 0 self.logger.info("Opened shard '%s'.", self.acts_path) @beartype.beartype @dataclasses.dataclass(frozen=True) class Metadata: vit_family: str vit_ckpt: str layers: tuple[int, ...] n_patches_per_img: int cls_token: bool d_vit: int seed: int n_imgs: int n_patches_per_shard: int data: str @classmethod def from_cfg(cls, cfg: config.Activations) -> "Metadata": return cls( cfg.vit_family, cfg.vit_ckpt, tuple(cfg.vit_layers), cfg.n_patches_per_img, cfg.cls_token, cfg.d_vit, cfg.seed, cfg.data.n_imgs, cfg.n_patches_per_shard, str(cfg.data), ) @classmethod def load(cls, fpath) -> "Metadata": with open(fpath) as fd: dct = json.load(fd) dct["layers"] = tuple(dct.pop("layers")) return cls(**dct) def dump(self, fpath): with open(fpath, "w") as fd: json.dump(dataclasses.asdict(self), fd, indent=4) @property def hash(self) -> str: cfg_str = json.dumps(dataclasses.asdict(self), sort_keys=True) return hashlib.sha256(cfg_str.encode("utf-8")).hexdigest() @beartype.beartype def get_acts_dir(cfg: config.Activations) -> str: """ Return the activations directory based on the relevant values of a config. Also saves a metadata.json file to that directory for human reference. Args: cfg: Config for experiment. 
Returns: Directory to where activations should be dumped/loaded from. """ metadata = Metadata.from_cfg(cfg) acts_dir = os.path.join(cfg.dump_to, metadata.hash) os.makedirs(acts_dir, exist_ok=True) metadata.dump(os.path.join(acts_dir, "metadata.json")) return acts_dir >>>> app/__main__.py import base64 import concurrent.futures import functools import json import logging import math import pathlib import time import typing import beartype import einops.layers.torch import gradio as gr import matplotlib import numpy as np import PIL.Image import pyvips import torch from jaxtyping import Float, Int, jaxtyped from torch import Tensor from .. import activations, nn from . import data, modeling log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" logging.basicConfig(level=logging.INFO, format=log_format) logger = logging.getLogger("app") # Disable pyvips info logging logging.getLogger("pyvips").setLevel(logging.WARNING) ########### # Globals # ########### RESIZE_SIZE = 512 """Resize shorter size to this size in pixels.""" CROP_SIZE = (448, 448) """Crop size in pixels.""" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") """Hardware accelerator, if any.""" CWD = pathlib.Path(".") MODEL_LOOKUP = modeling.get_model_lookup() COLORMAP = matplotlib.colormaps.get_cmap("plasma") logger.info("Set global constants.") ########## # Models # ########## @functools.cache def load_vit( model_cfg: modeling.Config, ) -> tuple[ activations.WrappedVisionTransformer, typing.Callable, float, Float[Tensor, " d_vit"], ]: """ Returns the wrapped ViT, the vit transform, the activation scalar and the activation mean to normalize the activations. """ vit = activations.WrappedVisionTransformer(model_cfg.wrapped_cfg).to(DEVICE).eval() vit_transform = activations.make_img_transform( model_cfg.vit_family, model_cfg.vit_ckpt ) logger.info("Loaded ViT: %s.", model_cfg.key) try: # Normalizing constants acts_dataset = activations.Dataset(model_cfg.acts_cfg) logger.info("Loaded dataset norms: %s.", model_cfg.key) except RuntimeError as err: logger.warning("Error loading ViT: %s", err) return None, None, None, None return vit, vit_transform, acts_dataset.scalar.item(), acts_dataset.act_mean @beartype.beartype @functools.cache def load_sae(model_cfg: modeling.Config) -> nn.SparseAutoencoder: sae_ckpt_fpath = CWD / "checkpoints" / model_cfg.sae_ckpt / "sae.pt" sae = nn.load(sae_ckpt_fpath.as_posix()) sae.to(DEVICE).eval() logger.info("Loaded SAE: %s.", model_cfg.sae_ckpt) return sae ############ # Datasets # ############ @beartype.beartype def load_tensor(path: str | pathlib.Path) -> Tensor: return torch.load(path, weights_only=True, map_location="cpu") @beartype.beartype @functools.cache def load_tensors( model_cfg: modeling.Config, ) -> tuple[Int[Tensor, "d_sae top_k"], Float[Tensor, "d_sae top_k n_patches"]]: top_img_i = load_tensor(model_cfg.tensor_dpath / "top_img_i.pt") # TODO: For some reason, the top_values are about 4 times larger. 
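# Expected shapes, per the return annotation: top_img_i is an integer tensor of
# shape [d_sae, top_k] holding dataset indices of the top-k images per latent, and
# top_values is a float tensor of shape [d_sae, top_k, n_patches] holding the
# per-patch activations for those images. The division by 4 below is a workaround
# for the discrepancy noted in the TODO above.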
top_values = load_tensor(model_cfg.tensor_dpath / "top_values.pt") / 4 return top_img_i, top_values @beartype.beartype def get_image(example_id: str) -> list[str]: dataset, split, i_str = example_id.split("__") i = int(i_str) img_v_raw, label = data.get_img_v_raw(f"{dataset}__{split}", i) img_v_sized = data.to_sized(img_v_raw, RESIZE_SIZE, CROP_SIZE) return [data.vips_to_base64(img_v_sized), label] @jaxtyped(typechecker=beartype.beartype) def add_highlights( img_v_sized: pyvips.Image, patches: np.ndarray, *, upper: float | None = None, opacity: float = 0.9, ) -> pyvips.Image: """Add colored highlights to an image based on patch activation values. Overlays a colored highlight on each patch of the image, with intensity proportional to the activation value for that patch. Used to visualize which parts of an image most strongly activated a particular SAE latent. Args: img: The base image to highlight patches: Array of activation values, one per patch upper: Optional maximum value to normalize activations against opacity: Opacity of the highlight overlay (0-1) Returns: A new image with colored highlights overlaid on the original """ if not len(patches): return img_v_sized # Calculate patch grid dimensions grid_w = grid_h = int(math.sqrt(len(patches))) assert grid_w * grid_h == len(patches) patch_w = img_v_sized.width // grid_w patch_h = img_v_sized.height // grid_h assert patch_w == patch_h patches = np.clip(patches, a_min=0.0, a_max=upper + 1e-9) assert upper is not None colors = (COLORMAP(patches / (upper + 1e-9))[:, :3] * 256).astype(np.uint8) # Create overlay by processing each patch overlay = np.zeros((img_v_sized.width, img_v_sized.height, 4), dtype=np.uint8) for idx, (val, color) in enumerate(zip(patches, colors)): val = val / (upper + 1e-9) x = (idx % grid_w) * patch_w y = (idx // grid_w) * patch_h # Create patch overlay patch = np.zeros((patch_w, patch_h, 4), dtype=np.uint8) patch[:, :, 0:3] = color patch[:, :, 3] = int(256 * val * opacity) overlay[y : y + patch_h, x : x + patch_w, :] = patch overlay = pyvips.Image.new_from_array(overlay).copy(interpretation="srgb") return img_v_sized.addalpha().composite(overlay, "over") @beartype.beartype class Example(typing.TypedDict): """Represents an example image and its associated label. Used to store examples of SAE latent activations for visualization. """ orig_url: str """The URL or path to access the original example image.""" highlighted_url: str """The URL or path to access the SAE-highlighted image.""" label: str """The class label or description associated with this example.""" example_id: str """Unique ID to idenfify the original dataset instance.""" @beartype.beartype class SaeActivation(typing.TypedDict): """Represents the activation pattern of a single SAE latent across patches. This captures how strongly a particular SAE latent fires on different patches of an input image. """ model_cfg: modeling.Config """The model config.""" latent: int """The index of the SAE latent being measured.""" activations: list[float] """The activation values of this latent across different patches. 
Each value represents how strongly this latent fired on a particular patch.""" highlighted_url: str """The image with the colormaps applied.""" examples: list[Example] """Top examples for this latent.""" @beartype.beartype def pil_to_vips(pil_img: PIL.Image.Image) -> pyvips.Image: # Convert to numpy array np_array = np.asarray(pil_img) # Handle different formats if np_array.ndim == 2: # Grayscale return pyvips.Image.new_from_memory( np_array.tobytes(), np_array.shape[1], # width np_array.shape[0], # height 1, # bands "uchar", ) else: # RGB/RGBA return pyvips.Image.new_from_memory( np_array.tobytes(), np_array.shape[1], # width np_array.shape[0], # height np_array.shape[2], # bands "uchar", ) @beartype.beartype def vips_to_pil(vips_img: PIL.Image.Image) -> PIL.Image.Image: # Convert to numpy array np_array = vips_img.numpy() # Convert numpy array to PIL Image return PIL.Image.fromarray(np_array) @beartype.beartype class BufferInfo(typing.NamedTuple): buffer: bytes width: int height: int bands: int format: object @classmethod def from_img_v(cls, img_v: pyvips.Image) -> "BufferInfo": return cls( img_v.write_to_memory(), img_v.width, img_v.height, img_v.bands, img_v.format, ) @beartype.beartype def bufferinfo_to_base64(bufferinfo: BufferInfo) -> str: img_v = pyvips.Image.new_from_memory(*bufferinfo) buf = img_v.write_to_buffer(".webp") b64 = base64.b64encode(buf) s64 = b64.decode("utf8") return "data:image/webp;base64," + s64 @jaxtyped(typechecker=beartype.beartype) def make_sae_activation( model_cfg: modeling.Config, latent: int, acts: Float[np.ndarray, " n_patches"], img_v: pyvips.Image, top_img_i: list[int], top_values: Float[Tensor, "top_k n_patches"], pool: concurrent.futures.Executor, ) -> SaeActivation: raw_examples: list[tuple[int, pyvips.Image, Float[np.ndarray, "..."], str]] = [] seen_i_im = set() for i_im, values_p in zip(top_img_i, top_values): if i_im in seen_i_im: continue ex_img_v_raw, ex_label = data.get_img_v_raw(model_cfg.dataset_name, i_im) ex_img_v_sized = data.to_sized(ex_img_v_raw, RESIZE_SIZE, CROP_SIZE) raw_examples.append((i_im, ex_img_v_sized, values_p.numpy(), ex_label)) seen_i_im.add(i_im) # Only need 4 example images per latent. if len(seen_i_im) >= 4: break upper = top_values.max().item() futures = [] for i_im, ex_img, values_p, ex_label in raw_examples: highlighted_img = add_highlights(ex_img, values_p, upper=upper) # Submit both conversions to the thread pool orig_future = pool.submit(data.vips_to_base64, ex_img) highlight_future = pool.submit(data.vips_to_base64, highlighted_img) futures.append((i_im, orig_future, highlight_future, ex_label)) # Wait for all conversions to complete and build examples examples = [] for i_im, orig_future, highlight_future, ex_label in futures: example = Example( orig_url=orig_future.result(), highlighted_url=highlight_future.result(), label=ex_label, example_id=f"{model_cfg.dataset_name}__{i_im}", ) examples.append(example) print(model_cfg.key, latent, top_values.max(), acts.max()) # Highlight the original image. img_sized_v = data.to_sized(img_v, RESIZE_SIZE, CROP_SIZE) highlighted_img = add_highlights(img_sized_v, acts, upper=upper) highlighted_url = data.vips_to_base64(highlighted_img) return SaeActivation( model_cfg=model_cfg, latent=latent, activations=acts.tolist(), highlighted_url=highlighted_url, examples=examples, ) @beartype.beartype @torch.inference_mode def get_sae_activations( img_p: PIL.Image.Image, latents: dict[str, list[int]] ) -> dict[str, list[SaeActivation]]: """ Args: image: Image to get SAE activations for. 
latents: A lookup from model name (string) to a list of latents to report latents for (integers). Returns: A lookup from model name (string) to a list of SaeActivations, one for each latent in the `latents` argument. """ logger.info("latents: %s", json.dumps(latents)) response = {} with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool: for model_name, requested_latents in latents.items(): sae_activations = [] if not requested_latents: logger.warning( "Skipping ViT '%s' with no requested latents.", model_name ) response[model_name] = sae_activations continue model_cfg = MODEL_LOOKUP[model_name] vit, vit_transform, scalar, mean = load_vit(model_cfg) if vit is None: logger.warning("Skipping ViT '%s'", model_name) continue sae = load_sae(model_cfg) mean = mean.to(DEVICE) x = vit_transform(img_p)[None, ...].to(DEVICE) _, vit_acts_BLPD = vit(x) vit_acts_PD = ( vit_acts_BLPD[0, 0, 1:].to(DEVICE).clamp(-1e-5, 1e5) - mean ) / scalar _, f_x_PS, _ = sae(vit_acts_PD) # Ignore [CLS] token and get just the requested latents. acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches") logger.info("Got SAE activations for '%s'.", model_name) top_img_i, top_values = load_tensors(model_cfg) logger.info("Loaded top SAE activations for '%s'.", model_name) for latent in requested_latents: sae_activations.append( make_sae_activation( model_cfg, latent, acts_SP[latent].cpu().numpy(), data.pil_to_vips(img_p), top_img_i[latent].tolist(), top_values[latent], pool, ) ) response[model_name] = sae_activations return response @beartype.beartype class progress: def __init__(self, it, *, every: int = 10, desc: str = "progress", total: int = 0): """ Wraps an iterable with a logger like tqdm but doesn't use any control codes to manipulate a progress bar, which doesn't work well when your output is redirected to a file. Instead, simple logging statements are used, but it includes quality-of-life features like iteration speed and predicted time to finish. Args: it: Iterable to wrap. every: How many iterations between logging progress. desc: What to name the logger. total: If non-zero, how long the iterable is. """ self.it = it self.every = every self.logger = logging.getLogger(desc) self.total = total def __iter__(self): start = time.time() try: total = len(self) except TypeError: total = None for i, obj in enumerate(self.it): yield obj if (i + 1) % self.every == 0: now = time.time() duration_s = now - start per_min = (i + 1) / (duration_s / 60) if total is not None: pred_min = (total - (i + 1)) / per_min self.logger.info( "%d/%d (%.1f%%) | %.1f it/m (expected finish in %.1fm)", i + 1, total, (i + 1) / total * 100, per_min, pred_min, ) else: self.logger.info("%d/? | %.1f it/m", i + 1, per_min) def __len__(self) -> int: if self.total > 0: return self.total # Will throw exception. 
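# Typical usage (a sketch; names are illustrative):
#   for batch in progress(dataloader, every=100, desc="acts", total=n_batches):
#       ...
# If total is 0 and the wrapped iterable has no __len__, the len() call below
# raises TypeError, which __iter__ catches in order to fall back to the
# "%d/? | %.1f it/m" style of logging.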
return len(self.it) ############# # Interface # ############# with gr.Blocks() as demo: example_id_text = gr.Text(label="Test Example") input_image_base64 = gr.Text(label="Image in Base64") input_image_label = gr.Text(label="Image Label") get_input_image_btn = gr.Button(value="Get Input Image") get_input_image_btn.click( get_image, inputs=[example_id_text], outputs=[input_image_base64, input_image_label], api_name="get-image", postprocess=False, ) latents_json = gr.JSON(label="Latents", value={}) activations_json = gr.JSON(label="Activations", value={}) input_image = gr.Image( label="Input Image", sources=["upload", "clipboard"], type="pil", interactive=True, ) get_sae_activations_btn = gr.Button(value="Get SAE Activations") get_sae_activations_btn.click( get_sae_activations, inputs=[input_image, latents_json], outputs=[activations_json], api_name="get-sae-activations", ) if __name__ == "__main__": demo.launch() >>>> app/data.py import functools import logging import beartype import torchvision.datasets from PIL import Image from .. import activations, config logger = logging.getLogger("app.data") @functools.cache def get_datasets(): datasets = { "inat21__train_mini": torchvision.datasets.ImageFolder( root="/research/nfs_su_809/workspace/stevens.994/datasets/inat21/train_mini/" ), "imagenet__train": activations.ImageNet(config.ImagenetDataset()), } logger.info("Loaded datasets.") return datasets @beartype.beartype def get_img_raw(key: str, i: int) -> tuple[Image.Image, str]: """ Get raw image and processed label from dataset. Returns: Tuple of Image.Image and classname. """ dataset = get_datasets()[key] sample = dataset[i] # iNat21 specific: Remove taxonomy prefix label = " ".join(sample["label"].split("_")[1:]) return sample["image"], label def to_sized( img_raw: Image.Image, min_px: int, crop_px: tuple[int, int] ) -> Image.Image: """Convert raw vips image to standard model input size (resize + crop).""" # Calculate scale factor to make smallest dimension = min_px scale = min_px / min(img_raw.width, img_raw.height) # Resize maintaining aspect ratio img_raw = img_raw.resize(scale) assert min(img_raw.width, img_raw.height) == min_px # Calculate crop coordinates to center crop left = (img_raw.width - crop_px[0]) // 2 top = (img_raw.height - crop_px[1]) // 2 # Crop to final size return img_raw.crop(left, top, crop_px[0], crop_px[1]) @beartype.beartype def img_to_b64(img: Image.Image) -> str: raise NotImplementedError() >>>> app/modeling.py import dataclasses import pathlib import beartype from .. import config @beartype.beartype @dataclasses.dataclass(frozen=True) class Config: """Configuration for a Vision Transformer (ViT) and Sparse Autoencoder (SAE) model pair. Stores paths and configuration needed to load and run a specific ViT+SAE combination. """ key: str """The lookup key.""" vit_family: str """The family of ViT model, e.g. 
'clip' for CLIP models.""" vit_ckpt: str """Checkpoint identifier for the ViT model, either as HuggingFace path or model/checkpoint pair.""" sae_ckpt: str """Identifier for the SAE checkpoint to load.""" tensor_dpath: pathlib.Path """Directory containing precomputed tensors for this model combination.""" dataset_name: str """Which dataset to use.""" acts_cfg: config.DataLoad """Which activations to load for normalizing.""" @property def wrapped_cfg(self) -> config.Activations: n_patches = 196 if self.vit_family == "dinov2": n_patches = 256 return config.Activations( vit_family=self.vit_family, vit_ckpt=self.vit_ckpt, vit_layers=[-2], n_patches_per_img=n_patches, ) def get_model_lookup() -> dict[str, Config]: cfgs = [ Config( "bioclip/inat21", "clip", "hf-hub:imageomics/bioclip", "gpnn7x3p", pathlib.Path( "/research/nfs_su_809/workspace/stevens.994/saev/features/gpnn7x3p-high-freq/sort_by_patch/" ), "inat21__train_mini", config.DataLoad( shard_root="/local/scratch/stevens.994/cache/saev/50149a5a12c70d378dc38f1976d676239839b591cadbfc9af5c84268ac30a868", n_random_samples=2**16, ), ), Config( "clip/inat21", "clip", "ViT-B-16/openai", "rscsjxgd", pathlib.Path( "/research/nfs_su_809/workspace/stevens.994/saev/features/rscsjxgd-high-freq/sort_by_patch/" ), "inat21__train_mini", config.DataLoad( shard_root="/local/scratch/stevens.994/cache/saev/07aed612e3f70b93ecff46e5a3beea7b8a779f0376dcd3bddf1d5a6ffb4c8f76", n_random_samples=2**16, ), ), Config( "clip/imagenet", "clip", "ViT-B-16/openai", "usvhngx4", pathlib.Path( "/research/nfs_su_809/workspace/stevens.994/saev/features/usvhngx4-high-freq/sort_by_patch/" ), "imagenet__train", config.DataLoad( shard_root="/local/scratch/stevens.994/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8", n_random_samples=2**16, ), ), Config( "dinov2/imagenet", "dinov2", "dinov2_vitb14_reg", "oebd6e6i", pathlib.Path( "/research/nfs_su_809/workspace/stevens.994/saev/features/oebd6e6i/sort_by_patch/" ), "imagenet__train", config.DataLoad( shard_root="/local/scratch/stevens.994/cache/saev/724b1b7be995ef7212d64640fec2885737a706a33b8e5a18f7f323223bd43af1", n_random_samples=2**16, ), ), ] # TODO: figure out how to normalize the activations from the ViT using the same mean/scalar as in the sorted data. 
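# Usage sketch: the app builds this lookup once at startup
# (MODEL_LOOKUP = modeling.get_model_lookup() in app/__main__.py) and then
# resolves user-facing keys such as "bioclip/inat21" or "dinov2/imagenet" to a
# Config before loading the matching ViT and SAE.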
return {cfg.key: cfg for cfg in cfgs} >>>> colors.py # https://coolors.co/palette/001219-005f73-0a9396-94d2bd-e9d8a6-ee9b00-ca6702-bb3e03-ae2012-9b2226 BLACK_HEX = "001219" BLACK_RGB = (0, 18, 25) BLACK_RGB01 = tuple(c / 256 for c in BLACK_RGB) BLUE_HEX = "005f73" BLUE_RGB = (0, 95, 115) BLUE_RGB01 = tuple(c / 256 for c in BLUE_RGB) CYAN_HEX = "0a9396" CYAN_RGB = (10, 147, 150) CYAN_RGB01 = tuple(c / 256 for c in CYAN_RGB) SEA_HEX = "94d2bd" SEA_RGB = (148, 210, 189) SEA_RGB01 = tuple(c / 256 for c in SEA_RGB) CREAM_HEX = "e9d8a6" CREAM_RGB = (233, 216, 166) CREAM_RGB01 = tuple(c / 256 for c in CREAM_RGB) GOLD_HEX = "ee9b00" GOLD_RGB = (238, 155, 0) GOLD_RGB01 = tuple(c / 256 for c in GOLD_RGB) ORANGE_HEX = "ca6702" ORANGE_RGB = (202, 103, 2) ORANGE_RGB01 = tuple(c / 256 for c in ORANGE_RGB) RUST_HEX = "bb3e03" RUST_RGB = (187, 62, 3) RUST_RGB01 = tuple(c / 256 for c in RUST_RGB) SCARLET_HEX = "ae2012" SCARLET_RGB = (174, 32, 18) SCARLET_RGB01 = tuple(c / 256 for c in SCARLET_RGB) RED_HEX = "9b2226" RED_RGB = (155, 34, 38) RED_RGB01 = tuple(c / 256 for c in RED_RGB) ALL_HEX = [ BLACK_HEX, BLUE_HEX, CYAN_HEX, SEA_HEX, CREAM_HEX, GOLD_HEX, ORANGE_HEX, RUST_HEX, SCARLET_HEX, RED_HEX, ] ALL_RGB01 = [ BLACK_RGB01, BLUE_RGB01, CYAN_RGB01, SEA_RGB01, CREAM_RGB01, GOLD_RGB01, ORANGE_RGB01, RUST_RGB01, SCARLET_RGB01, RED_RGB01, ] >>>> config.py """ All configs for all saev jobs. ## Import Times This module should be very fast to import so that `python main.py --help` is fast. This means that the top-level imports should not include big packages like numpy, torch, etc. For example, `TreeOfLife.n_imgs` imports numpy when it's needed, rather than importing it at the top level. Also contains code for expanding configs with lists into lists of configs (grid search). Might be expanded in the future to support pseudo-random sampling from distributions to support random hyperparameter search, as in [this file](https://github.com/samuelstevens/sax/blob/main/sax/sweep.py). """ import collections.abc import dataclasses import itertools import os import typing import beartype @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class ImagenetDataset: """Configuration for HuggingFace Imagenet.""" name: str = "ILSVRC/imagenet-1k" """Dataset name on HuggingFace. Don't need to change this..""" split: str = "train" """Dataset split. For the default ImageNet-1K dataset, can either be 'train', 'validation' or 'test'.""" @property def n_imgs(self) -> int: """Number of images in the dataset. Calculated on the fly, but is non-trivial to calculate because it requires loading the dataset. If you need to reference this number very often, cache it in a local variable.""" import datasets dataset = datasets.load_dataset( self.name, split=self.split, trust_remote_code=True ) return len(dataset) @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class ImageFolderDataset: """Configuration for a generic image folder dataset.""" root: str = os.path.join(".", "data", "split") """Where the class folders with images are stored.""" @property def n_imgs(self) -> int: """Number of images in the dataset. Calculated on the fly, but is non-trivial to calculate because it requires walking the directory structure. 
If you need to reference this number very often, cache it in a local variable.""" n = 0 for _, _, files in os.walk(self.root): n += len(files) return n @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class Ade20kDataset: """ """ root: str = os.path.join(".", "data", "ade20k") """Where the class folders with images are stored.""" split: typing.Literal["training", "validation"] = "training" """Data split.""" @property def n_imgs(self) -> int: if self.split == "validation": return 2000 else: return 20210 DatasetConfig = ImagenetDataset | ImageFolderDataset | Ade20kDataset @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class Activations: """ Configuration for calculating and saving ViT activations. """ data: DatasetConfig = dataclasses.field(default_factory=ImagenetDataset) """Which dataset to use.""" dump_to: str = os.path.join(".", "shards") """Where to write shards.""" vit_family: typing.Literal["clip", "siglip", "dinov2", "moondream2"] = "clip" """Which model family.""" vit_ckpt: str = "ViT-L-14/openai" """Specific model checkpoint.""" vit_batch_size: int = 1024 """Batch size for ViT inference.""" n_workers: int = 8 """Number of dataloader workers.""" d_vit: int = 1024 """Dimension of the ViT activations (depends on model).""" vit_layers: list[int] = dataclasses.field(default_factory=lambda: [-2]) """Which layers to save. By default, the second-to-last layer.""" n_patches_per_img: int = 256 """Number of ViT patches per image (depends on model).""" cls_token: bool = True """Whether the model has a [CLS] token.""" n_patches_per_shard: int = 2_400_000 """Number of activations per shard; 2.4M is approximately 10GB for 1024-dimensional 4-byte activations.""" seed: int = 42 """Random seed.""" ssl: bool = True """Whether to use SSL.""" # Hardware device: str = "cuda" """Which device to use.""" slurm: bool = False """Whether to use `submitit` to run jobs on a Slurm cluster.""" slurm_acct: str = "PAS2136" """Slurm account string.""" log_to: str = "./logs" """Where to log Slurm job stdout/stderr.""" @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class DataLoad: """ Configuration for loading activation data from disk. """ shard_root: str = os.path.join(".", "shards") """Directory with .bin shards and a metadata.json file.""" patches: typing.Literal["cls", "patches", "meanpool"] = "patches" """Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'patches' indicates it will return all patches. 'meanpool' returns the mean of all image patches.""" layer: int | typing.Literal["all", "meanpool"] = -2 """.. todo: document this field.""" clamp: float = 1e5 """Maximum value for activations; activations will be clamped to within [-clamp, clamp]`.""" n_random_samples: int = 2**19 """Number of random samples used to calculate approximate dataset means at startup.""" scale_mean: bool | str = True """Whether to subtract approximate dataset means from examples. If a string, manually load from the filepath.""" scale_norm: bool | str = True """Whether to scale average dataset norm to sqrt(d_vit). If a string, manually load from the filepath.""" @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class Relu: d_vit: int = 1024 exp_factor: int = 16 """Expansion factor for SAE.""" n_reinit_samples: int = 1024 * 16 * 32 """Number of samples to use for SAE re-init. 
Anthropic proposes initializing b_dec to the geometric median of the dataset here: https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-bias. We use the regular mean.""" remove_parallel_grads: bool = True """Whether to remove gradients parallel to W_dec columns (which will be ignored because we force the columns to have unit norm). See https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-optimization for the original discussion from Anthropic.""" normalize_w_dec: bool = True """Whether to make sure W_dec has unit norm columns. See https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder for original citation.""" seed: int = 0 """Random seed.""" @property def d_sae(self) -> int: return self.d_vit * self.exp_factor @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class JumpRelu: """Implementation of the JumpReLU activation function for SAEs. Not implemented.""" pass SparseAutoencoder = Relu | JumpRelu @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class Vanilla: sparsity_coeff: float = 4e-4 """How much to weight sparsity loss term.""" @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class Matryoshka: """ Config for the Matryoshka loss for another arbitrary SAE class. Reference code is here: https://github.com/noanabeshima/matryoshka-saes and the original reading is https://sparselatents.com/matryoshka.html and https://arxiv.org/pdf/2503.17547. """ n_prefixes: int = 10 """Number of random length prefixes to use for loss calculation.""" Objective = Vanilla | Matryoshka @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class Train: """ Configuration for training a sparse autoencoder on a vision transformer. 
""" data: DataLoad = dataclasses.field(default_factory=DataLoad) """Data configuration""" n_workers: int = 32 """Number of dataloader workers.""" n_patches: int = 100_000_000 """Number of SAE training examples.""" sae: SparseAutoencoder = dataclasses.field(default_factory=Relu) """SAE configuration.""" objective: Objective = dataclasses.field(default_factory=Vanilla) """SAE loss configuration.""" n_sparsity_warmup: int = 0 """Number of sparsity coefficient warmup steps.""" lr: float = 0.0004 """Learning rate.""" n_lr_warmup: int = 500 """Number of learning rate warmup steps.""" sae_batch_size: int = 1024 * 16 """Batch size for SAE training.""" # Logging track: bool = True """Whether to track with WandB.""" wandb_project: str = "saev" """WandB project name.""" tag: str = "" """Tag to add to WandB run.""" log_every: int = 25 """How often to log to WandB.""" ckpt_path: str = os.path.join(".", "checkpoints") """Where to save checkpoints.""" device: typing.Literal["cuda", "cpu"] = "cuda" """Hardware device.""" seed: int = 42 """Random seed.""" slurm: bool = False """Whether to use `submitit` to run jobs on a Slurm cluster.""" slurm_acct: str = "PAS2136" """Slurm account string.""" log_to: str = os.path.join(".", "logs") """Where to log Slurm job stdout/stderr.""" @beartype.beartype @dataclasses.dataclass(frozen=True, slots=True) class Visuals: """Configuration for generating visuals from trained SAEs.""" ckpt: str = os.path.join(".", "checkpoints", "sae.pt") """Path to the sae.pt file.""" data: DataLoad = dataclasses.field(default_factory=DataLoad) """Data configuration.""" images: DatasetConfig = dataclasses.field(default_factory=ImagenetDataset) """Which images to use.""" top_k: int = 128 """How many images per SAE feature to store.""" n_workers: int = 16 """Number of dataloader workers.""" topk_batch_size: int = 1024 * 16 """Number of examples to apply top-k op to.""" sae_batch_size: int = 1024 * 16 """Batch size for SAE inference.""" epsilon: float = 1e-9 """Value to add to avoid log(0).""" sort_by: typing.Literal["cls", "img", "patch"] = "patch" """How to find the top k images. 'cls' picks images where the SAE latents of the ViT's [CLS] token are maximized without any patch highligting. 'img' picks images that maximize the sum of an SAE latent over all patches in the image, highlighting the patches. 
'patch' pickes images that maximize an SAE latent over all patches (not summed), highlighting the patches and only showing unique images.""" device: str = "cuda" """Which accelerator to use.""" dump_to: str = os.path.join(".", "data") """Where to save data.""" log_freq_range: tuple[float, float] = (-6.0, -2.0) """Log10 frequency range for which to save images.""" log_value_range: tuple[float, float] = (-1.0, 1.0) """Log10 frequency range for which to save images.""" include_latents: list[int] = dataclasses.field(default_factory=list) """Latents to always include, no matter what.""" n_distributions: int = 25 """Number of features to save distributions for.""" percentile: int = 99 """Percentile to estimate for outlier detection.""" n_latents: int = 400 """Maximum number of latents to save images for.""" seed: int = 42 """Random seed.""" @property def root(self) -> str: return os.path.join(self.dump_to, f"sort_by_{self.sort_by}") @property def top_values_fpath(self) -> str: return os.path.join(self.root, "top_values.pt") @property def top_img_i_fpath(self) -> str: return os.path.join(self.root, "top_img_i.pt") @property def top_patch_i_fpath(self) -> str: return os.path.join(self.root, "top_patch_i.pt") @property def mean_values_fpath(self) -> str: return os.path.join(self.root, "mean_values.pt") @property def sparsity_fpath(self) -> str: return os.path.join(self.root, "sparsity.pt") @property def distributions_fpath(self) -> str: return os.path.join(self.root, "distributions.pt") @property def percentiles_fpath(self) -> str: return os.path.join(self.root, f"percentiles_p{self.percentile}.pt") ########## # SWEEPS # ########## @beartype.beartype def grid(cfg: Train, sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]: cfgs, errs = [], [] for d, dct in enumerate(expand(sweep_dct)): # .sae is a nested field that cannot be naively expanded. sae_dct = dct.pop("sae") if sae_dct: sae_dct["seed"] = sae_dct.pop("seed", cfg.sae.seed) + cfg.seed + d dct["sae"] = dataclasses.replace(cfg.sae, **sae_dct) # .data is a nested field that cannot be naively expanded. data_dct = dct.pop("data") if data_dct: dct["data"] = dataclasses.replace(cfg.data, **data_dct) try: cfgs.append(dataclasses.replace(cfg, **dct, seed=cfg.seed + d)) except Exception as err: errs.append(str(err)) return cfgs, errs @beartype.beartype def expand(config: dict[str, object]) -> collections.abc.Iterator[dict[str, object]]: """ Expands dicts with (nested) lists into a list of (nested) dicts. """ yield from _expand_discrete(config) @beartype.beartype def _expand_discrete( config: dict[str, object], ) -> collections.abc.Iterator[dict[str, object]]: """ Expands any (possibly nested) list values in `config` """ if not config: yield config return key, value = config.popitem() if isinstance(value, list): # Expand for c in _expand_discrete(config): for v in value: yield {**c, key: v} elif isinstance(value, dict): # Expand for c, v in itertools.product( _expand_discrete(config), _expand_discrete(value) ): yield {**c, key: v} else: for c in _expand_discrete(config): yield {**c, key: value} >>>> extending.md >>>> guide.md # Guide to Training SAEs on Vision Models 1. Record ViT activations and save them to disk. 2. Train SAEs on the activations. 3. Visualize the learned features from the trained SAEs. 4. (your job) Propose trends and patterns in the visualized features. 5. (your job, supported by code) Construct datasets to test your hypothesized trends. 6. Confirm/reject hypotheses using `probing` package. 
`saev` helps with steps 1, 2 and 3.

.. note:: `saev` assumes you are running on NVIDIA GPUs. On a multi-GPU system, prefix your commands with `CUDA_VISIBLE_DEVICES=X` to run on GPU X.

## Record ViT Activations to Disk

To save activations to disk, we need to specify:

1. Which model we would like to use.
2. Which layers we would like to save.
3. Where on disk and how we would like to save activations.
4. Which images we want to save activations for.

The `saev.activations` module does all of this for us.

Run `uv run python -m saev activations --help` to see all the configuration.

In practice, you might run:

```sh
uv run python -m saev activations \
  --vit-family clip \
  --vit-ckpt ViT-B-32/openai \
  --d-vit 768 \
  --n-patches-per-img 49 \
  --vit-layers -2 \
  --dump-to /local/scratch/$USER/cache/saev \
  --n-patches-per-shard 2_400_000 \
  data:imagenet-dataset
```

This will save activations for the CLIP-pretrained model ViT-B/32, which has a residual stream dimension of 768 and 49 patches per image (224 / 32 = 7; 7 x 7 = 49). It will save the second-to-last layer (`--vit-layers -2`). It will write 2.4M patches per shard, and save shards to a new directory `/local/scratch/$USER/cache/saev`.

.. note:: A note on storage space: a ViT-B/16 will save 1.2M images x 197 patches/layer/image x 1 layer = ~240M activations, each of which takes up 768 floats x 4 bytes/float = 3072 bytes, for a **total of ~723GB** for the entire dataset. As you scale to larger models (ViT-L has 1024 dimensions and, with 14x14-pixel patches on 224-pixel images, 256 patches/layer/image), recorded activations will grow even larger.

This script will also save a `metadata.json` file that records the relevant metadata for these activations, which will be read by future steps. The activations will be in `.bin` files, numbered starting from 000000.

To add your own models, see the guide to extending in `saev.activations`.

## Train SAEs on Activations

To train an SAE, we need to specify:

1. Which activations to use as input.
2. SAE architectural stuff.
3. Optimization-related stuff.

The `saev.training` module handles this.

Run `uv run python -m saev train --help` to see all the configuration.

Continuing on from our example before, you might want to run something like:

```sh
uv run python -m saev train \
  --data.shard-root /local/scratch/$USER/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8 \
  --data.layer -2 \
  --data.patches patches \
  --data.no-scale-mean \
  --data.no-scale-norm \
  --sae.d-vit 768 \
  --lr 5e-4
```

`--data.*` flags describe which activations to use. `--data.shard-root` should point to a directory with `*.bin` files and the `metadata.json` file. `--data.layer` specifies the layer, and `--data.patches` says that we want to train on individual patch activations, rather than the [CLS] token activation. `--data.no-scale-mean` and `--data.no-scale-norm` mean not to scale the activation mean or L2 norm. Anthropic's and OpenAI's papers suggest normalizing these factors, but `saev` still has a bug with this, so I suggest not scaling them.

`--sae.*` flags are about the SAE itself. `--sae.d-vit` is the only one you need to change; the dimension of our ViT was 768 for a ViT-B, rather than the default of 1024 for a ViT-L.

Finally, choose a slightly larger learning rate than the default with `--lr 5e-4`.

This will train one (1) sparse autoencoder on the data. See the section on sweeps to learn how to train multiple SAEs in parallel using only a single GPU.
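If you prefer to drive training from Python rather than the CLI, the same run can be expressed with the config dataclasses directly. The snippet below is a minimal sketch that mirrors what `python -m saev train` does internally; the shard path is a placeholder for your own shard directory, and calling `saev.training.main` with a list of configs is an assumption based on how `saev.__main__` invokes it.

```py
from saev import config, training

cfg = config.Train(
    data=config.DataLoad(
        shard_root="/local/scratch/<user>/cache/saev/<shard-hash>",  # placeholder path
        layer=-2,
        patches="patches",
        scale_mean=False,
        scale_norm=False,
    ),
    sae=config.Relu(d_vit=768),
    lr=5e-4,
    track=False,  # skip WandB for a quick local smoke test
)

# __main__.py submits training.main with a list of configs that share one data-loading setup.
training.main([cfg])
```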
## Visualize the Learned Features Now that you've trained an SAE, you probably want to look at its learned features. One way to visualize an individual learned feature \(f\) is by picking out images that maximize the activation of feature \(f\). Since we train SAEs on patch-level activations, we try to find the top *patches* for each feature \(f\). Then, we pick out the images those patches correspond to and create a heatmap based on SAE activation values. .. note:: More advanced forms of visualization are possible (and valuable!), but should not be included in `saev` unless they can be applied to every SAE/dataset combination. If you have specific visualizations, please add them to `contrib/` or another location. `saev.visuals` records these maximally activating images for us. You can see all the options with `uv run python -m saev visuals --help`. The most important configuration options: 1. The SAE checkpoint that you want to use (`--ckpt`). 2. The ViT activations that you want to use (`--data.*` options, should be roughly the same as the options you used to train your SAE, like the same layer, same `--data.patches`). 3. The images that produced the ViT activations that you want to use (`images` and `--images.*` options, should be the same as what you used to generate your ViT activtions). 4. Some filtering options on which SAE latents to include (`--log-freq-range`, `--log-value-range`, `--include-latents`, `--n-latents`). Then, the script runs SAE inference on all of the ViT activations, calculates the images with maximal activation for each SAE feature, then retrieves the images from the original image dataset and highlights them for browsing later on. .. note:: Because of limitations in the SAE training process, not all SAE latents (dimensions of \(f\)) are equally interesting. Some latents are dead, some are *dense*, some only fire on two images, etc. Typically, you want neurons that fire very strongly (high value) and fairly infrequently (low frequency). You might be interested in particular, fixed latents (`--include-latents`). **I recommend using `saev.interactive.metrics` to figure out good thresholds.** So you might run: ```sh uv run python -m saev visuals \ --ckpt checkpoints/abcdefg/sae.pt \ --dump-to /nfs/$USER/saev/webapp/abcdefg \ --data.shard-root /local/scratch/$USER/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8 \ --data.layer -2 \ --data.patches patches \ images:imagenet-dataset ``` This will record the top 128 patches, and then save the unique images among those top 128 patches for each feature in the trained SAE. It will cache these best activations to disk, then start saving images to visualize later on. `saev.interactive.features` is a small web application based on [marimo](https://marimo.io/) to interactively look at these images. You can run it with `uv run marimo edit saev/interactive/features.py`. ## Sweeps > tl;dr: basically the slow part of training SAEs is loading vit activations from disk, and since SAEs are pretty small compared to other models, you can train a bunch of different SAEs in parallel on the same data using a big GPU. That way you can sweep learning rate, lambda, etc. all on one GPU. ### Why Parallel Sweeps SAE training optimizes for a unique bottleneck compared to typical ML workflows: disk I/O rather than GPU computation. When training on vision transformer activations, loading the pre-computed activation data from disk is often the slowest part of the process, not the SAE training itself. 
A single set of ImageNet activations for a vision transformer can require terabytes of storage. Reading this data repeatedly for each hyperparameter configuration would be extremely inefficient. ### Parallelized Training Architecture To address this bottleneck, we implement parallel training that allows multiple SAE configurations to train simultaneously on the same data batch:
```mermaid
flowchart TD
    A[Pre-computed ViT Activations] -->|Slow I/O| B[Memory Buffer]
    B -->|Shared Batch| C[SAE Model 1]
    B -->|Shared Batch| D[SAE Model 2]
    B -->|Shared Batch| E[SAE Model 3]
    B -->|Shared Batch| F[...]
```
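In code, the shared-batch pattern looks roughly like the sketch below. This is illustrative only, not the exact `saev.training` implementation: `dataloader` stands in for saev's activation dataloader, and the hyperparameter values are arbitrary.

```py
import torch

import saev.config as config
import saev.nn as nn

# Several small SAEs with different sparsity coefficients share one GPU.
cfgs = [config.Relu(d_vit=768, exp_factor=16, seed=s) for s in range(3)]
objectives = [nn.get_objective(config.Vanilla(sparsity_coeff=c)) for c in (1e-4, 2e-4, 3e-4)]
saes = [nn.SparseAutoencoder(cfg).to("cuda") for cfg in cfgs]
opts = [torch.optim.Adam(sae.parameters(), lr=4e-4) for sae in saes]

for acts_BD in dataloader:  # slow disk I/O happens once per batch
    acts_BD = acts_BD.to("cuda")
    for sae, objective, opt in zip(saes, objectives, opts):  # cheap GPU work, repeated per SAE
        x_hat_BD, f_x_BS = sae(acts_BD)
        loss = objective(acts_BD, f_x_BS, x_hat_BD).loss
        loss.backward()
        opt.step()
        opt.zero_grad()
```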
This approach: - Loads each batch of activations **once** from disk - Uses that same batch for multiple SAE models with different hyperparameters - Amortizes the slow I/O cost across all models in the sweep ### Running a Sweep The `train` command accepts a `--sweep` parameter that points to a TOML file defining the hyperparameter grid: ```bash uv run python -m saev train --sweep configs/my_sweep.toml ``` Here's an example sweep configuration file: ```toml [sae] sparsity_coeff = [1e-4, 2e-4, 3e-4] d_vit = 768 exp_factor = [8, 16] [data] scale_mean = true ``` This would train 6 models (3 sparsity coefficients × 2 expansion factors), each sharing the same data loading operation. ### Limitations Not all parameters can be swept in parallel. Parameters that affect data loading (like `batch_size` or dataset configuration) will cause the sweep to split into separate parallel groups. The system automatically handles this division to maximize efficiency. ## Training Metrics and Visualizations When you train a sweep of SAEs, you probably want to understand which checkpoint is best. `saev` provides some tools to help with that. First, we offer a tool to look at some basic summary statistics of all your trained checkpoints. `saev.interactive.metrics` is a [marimo](https://marimo.io/) notebook (similar to Jupyter, but more interactive) for making L0 vs MSE plots by reading runs off of WandB. However, there are some pieces of code that need to be changed for you to use it. .. todo:: Explain how to use the `saev.interactive.metrics` notebook. * Need to change your wandb username from samuelstevens to USERNAME from wandb * Tag filter * Need to run the notebook on the same machine as the original ViT shards and the shards need to be there. * Think of better ways to do model and data keys * Look at examples * run visuals before features How to run visuals faster? explain how these features are visualized >>>> helpers.py """ Useful helpers for `saev`. """ import logging import os import time import beartype @beartype.beartype def get_cache_dir() -> str: """ Get cache directory from environment variables, defaulting to the current working directory (.) Returns: A path to a cache directory (might not exist yet). """ cache_dir = "" for var in ("SAEV_CACHE", "HF_HOME", "HF_HUB_CACHE"): cache_dir = cache_dir or os.environ.get(var, "") return cache_dir or "." @beartype.beartype class progress: def __init__(self, it, *, every: int = 10, desc: str = "progress", total: int = 0): """ Wraps an iterable with a logger like tqdm but doesn't use any control codes to manipulate a progress bar, which doesn't work well when your output is redirected to a file. Instead, simple logging statements are used, but it includes quality-of-life features like iteration speed and predicted time to finish. Args: it: Iterable to wrap. every: How many iterations between logging progress. desc: What to name the logger. total: If non-zero, how long the iterable is. """ self.it = it self.every = every self.logger = logging.getLogger(desc) self.total = total def __iter__(self): start = time.time() try: total = len(self) except TypeError: total = None for i, obj in enumerate(self.it): yield obj if (i + 1) % self.every == 0: now = time.time() duration_s = now - start per_min = (i + 1) / (duration_s / 60) if total is not None: pred_min = (total - (i + 1)) / per_min self.logger.info( "%d/%d (%.1f%%) | %.1f it/m (expected finish in %.1fm)", i + 1, total, (i + 1) / total * 100, per_min, pred_min, ) else: self.logger.info("%d/? 
| %.1f it/m", i + 1, per_min) def __len__(self) -> int: if self.total > 0: return self.total # Will throw exception. return len(self.it) ################### # FLATTENED DICTS # ################### @beartype.beartype def flattened( dct: dict[str, object], *, sep: str = "." ) -> dict[str, str | int | float | bool | None]: """ Flatten a potentially nested dict to a single-level dict with `.`-separated keys. """ new = {} for key, value in dct.items(): if isinstance(value, dict): for nested_key, nested_value in flattened(value).items(): new[key + "." + nested_key] = nested_value continue new[key] = value return new @beartype.beartype def get(dct: dict[str, object], key: str, *, sep: str = ".") -> object: key = key.split(sep) key = list(reversed(key)) while len(key) > 1: popped = key.pop() dct = dct[popped] return dct[key.pop()] >>>> imaging.py import math import beartype import matplotlib import numpy as np from jaxtyping import Float, jaxtyped from PIL import Image, ImageDraw colormap = matplotlib.colormaps.get_cmap("plasma") @jaxtyped(typechecker=beartype.beartype) def add_highlights( img: Image.Image, patches: Float[np.ndarray, " n_patches"], *, upper: float | None = None, opacity: float = 0.9, ) -> Image.Image: if not len(patches): return img iw_np, ih_np = int(math.sqrt(len(patches))), int(math.sqrt(len(patches))) iw_px, ih_px = img.size pw_px, ph_px = iw_px // iw_np, ih_px // ih_np assert iw_np * ih_np == len(patches) # Create a transparent overlay overlay = Image.new("RGBA", img.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(overlay) colors = (colormap(patches / (upper + 1e-9))[:, :3] * 256).astype(np.uint8) for p, (val, color) in enumerate(zip(patches, colors)): assert upper is not None val /= upper + 1e-9 x_np, y_np = p % iw_np, p // ih_np draw.rectangle( [ (x_np * pw_px, y_np * ph_px), (x_np * pw_px + pw_px, y_np * ph_px + ph_px), ], fill=(*color, int(opacity * val * 256)), ) # Composite the original image and the overlay return Image.alpha_composite(img.convert("RGBA"), overlay) >>>> inference.md # Inference Instructions Briefly, you need to: 1. Download a checkpoint. 2. Get the code. 3. Load the checkpoint. 4. Get activations. Details are below. ## Download a Checkpoint First, download an SAE checkpoint from the [Huggingface collection](https://huggingface.co/collections/osunlp/sae-v-67ab8c4fdf179d117db28195). For instance, you can choose the SAE trained on OpenAI's CLIP ViT-B/16 with ImageNet-1K activations [here](https://huggingface.co/osunlp/SAE_CLIP_24K_ViT-B-16_IN1K). You can use `wget` if you want: ```sh wget https://huggingface.co/osunlp/SAE_CLIP_24K_ViT-B-16_IN1K/resolve/main/sae.pt ``` ## Get the Code The easiest way to do this is to clone the code: ``` git clone https://github.com/OSU-NLP-Group/saev ``` You can also install the package from git if you use uv (not sure about pip or cuda): ```sh uv add git+https://github.com/OSU-NLP-Group/saev ``` Or clone it and install it as an editable with pip, lik `pip install -e .` in your virtual environment. Then you can do things like `from saev import ...`. .. note:: If you struggle to get `saev` installed, open an issue on [GitHub](https://github.com/OSU-NLP-Group/saev) and I will figure out how to make it easier. ## Load the Checkpoint ```py import saev.nn sae = saev.nn.load("PATH_TO_YOUR_SAE_CKPT.pt") ``` Now you have a pretrained SAE. ## Get Activations This is the hardest part. We need to: 1. Pass an image into a ViT 2. Record the dense ViT activations at the same layer that the SAE was trained on. 3. 
Pass the activations into the SAE to get sparse activations. 4. Do something interesting with the sparse SAE activations. There are examples of this in the demo code: for [classification](https://huggingface.co/spaces/samuelstevens/saev-image-classification/blob/main/app.py#L318) and [semantic segmentation](https://huggingface.co/spaces/samuelstevens/saev-semantic-segmentation/blob/main/app.py#L222). If the permalinks change, you are looking for the `get_sae_latents()` functions in both files. Below is example code to do it using the `saev` package. ```py import saev.nn import saev.activations img_transform = saev.activations.make_img_transform("clip", "ViT-B-16/openai") vit = saev.activations.make_vit("clip", "ViT-B-16/openai") recorded_vit = saev.activations.RecordedVisionTransformer(vit, 196, True, [10]) img = Image.open("example.jpg") x = img_transform(img) # Add a batch dimension x = x[None, ...] _, vit_acts = recorded_vit(x) # Select the only layer in the batch and ignore the CLS token. vit_acts = vit_acts[:, 0, 1:, :] x_hat, f_x, loss = sae(vit_acts) ``` Now you have the reconstructed x (`x_hat`) and the sparse representation of all patches in the image (`f_x`). You might select the dimensions with maximal values for each patch and see what other images are maximimally activating. .. todo:: Provide documentation for how get maximally activating images. >>>> interactive/features.py import marimo __generated_with = "0.9.32" app = marimo.App(width="full") @app.cell def __(): import json import os import random import marimo as mo import matplotlib.pyplot as plt import numpy as np import polars as pl import torch import tqdm return json, mo, np, os, pl, plt, random, torch, tqdm @app.cell def __(mo, os): def make_ckpt_dropdown(): try: choices = sorted( os.listdir("/research/nfs_su_809/workspace/stevens.994/saev/features") ) except FileNotFoundError: choices = [] return mo.ui.dropdown(choices, label="Checkpoint:") ckpt_dropdown = make_ckpt_dropdown() return ckpt_dropdown, make_ckpt_dropdown @app.cell def __(ckpt_dropdown, mo): mo.hstack([ckpt_dropdown], justify="start") return @app.cell def __(ckpt_dropdown, mo): mo.stop( ckpt_dropdown.value is None, mo.md( "Run `uv run main.py webapp --help` to fill out at least one checkpoint." 
), ) webapp_dir = f"/research/nfs_su_809/workspace/stevens.994/saev/features/{ckpt_dropdown.value}/sort_by_patch" get_i, set_i = mo.state(0) return get_i, set_i, webapp_dir @app.cell def __(mo): sort_by_freq_btn = mo.ui.run_button(label="Sort by frequency") sort_by_value_btn = mo.ui.run_button(label="Sort by value") sort_by_latent_btn = mo.ui.run_button(label="Sort by latent") return sort_by_freq_btn, sort_by_latent_btn, sort_by_value_btn @app.cell def __(mo, sort_by_freq_btn, sort_by_latent_btn, sort_by_value_btn): mo.hstack( [sort_by_freq_btn, sort_by_value_btn, sort_by_latent_btn], justify="start" ) return @app.cell def __( json, mo, os, sort_by_freq_btn, sort_by_latent_btn, sort_by_value_btn, tqdm, webapp_dir, ): def get_neurons() -> list[dict]: rows = [] for name in tqdm.tqdm(list(os.listdir(f"{webapp_dir}/neurons"))): if not name.isdigit(): continue try: with open(f"{webapp_dir}/neurons/{name}/metadata.json") as fd: rows.append(json.load(fd)) except FileNotFoundError: print(f"Missing metadata.json for neuron {name}.") continue # rows.append({"neuron": int(name)}) return rows neurons = get_neurons() if sort_by_latent_btn.value: neurons = sorted(neurons, key=lambda dct: dct["neuron"]) elif sort_by_freq_btn.value: neurons = sorted(neurons, key=lambda dct: dct["log10_freq"]) elif sort_by_value_btn.value: neurons = sorted(neurons, key=lambda dct: dct["log10_value"], reverse=True) mo.md(f"Found {len(neurons)} saved neurons.") return get_neurons, neurons @app.cell def __(mo, neurons, set_i): next_button = mo.ui.button( label="Next", on_change=lambda _: set_i(lambda v: (v + 1) % len(neurons)), ) prev_button = mo.ui.button( label="Previous", on_change=lambda _: set_i(lambda v: (v - 1) % len(neurons)), ) return next_button, prev_button @app.cell def __(get_i, mo, neurons, set_i): neuron_slider = mo.ui.slider( 0, len(neurons), value=get_i(), on_change=lambda i: set_i(i), full_width=True, ) return (neuron_slider,) @app.cell def __(): return @app.cell def __( display_info, get_i, mo, neuron_slider, neurons, next_button, prev_button, ): # label = f"Neuron {neurons[get_i()]} ({get_i()}/{len(neurons)}; {get_i() / len(neurons) * 100:.2f}%)" # , display_info(**neurons[get_i()]) mo.md(f""" {mo.hstack([prev_button, next_button, display_info(**neurons[get_i()])], justify="start")} {neuron_slider} """) return @app.cell def __(): return @app.cell def __(get_i, mo, neurons): def display_info(log10_freq: float, log10_value: float, neuron: int): return mo.md( f"Neuron {neuron} ({get_i()}/{len(neurons)}; {get_i() / len(neurons) * 100:.1f}%) | Frequency: {10**log10_freq * 100:.3f}% of inputs | Mean Value: {10**log10_value:.3f}" ) return (display_info,) @app.cell def __(mo, webapp_dir): def show_img(n: int, i: int): label = "No label found." 
try: label = open(f"{webapp_dir}/neurons/{n}/{i}.txt").read().strip() label = " ".join(label.split("_")) except FileNotFoundError: return mo.md(f"*Missing image {i + 1}*") return mo.vstack([mo.image(f"{webapp_dir}/neurons/{n}/{i}.png"), mo.md(label)]) return (show_img,) @app.cell def __(get_i, mo, neurons, show_img): n = neurons[get_i()]["neuron"] mo.vstack([ mo.hstack( [ show_img(n, 0), show_img(n, 1), show_img(n, 2), show_img(n, 3), show_img(n, 4), ], widths="equal", ), mo.hstack( [ show_img(n, 5), show_img(n, 6), show_img(n, 7), show_img(n, 8), show_img(n, 9), ], widths="equal", ), mo.hstack( [ show_img(n, 10), show_img(n, 11), show_img(n, 12), show_img(n, 13), show_img(n, 14), ], widths="equal", ), mo.hstack( [ show_img(n, 15), show_img(n, 16), show_img(n, 17), show_img(n, 18), show_img(n, 19), ], widths="equal", ), mo.hstack( [ show_img(n, 20), show_img(n, 21), show_img(n, 22), show_img(n, 23), show_img(n, 24), ], widths="equal", ), ]) return (n,) @app.cell def __(os, torch, webapp_dir): sparsity_fpath = os.path.join(webapp_dir, "sparsity.pt") sparsity = torch.load(sparsity_fpath, weights_only=True, map_location="cpu") values_fpath = os.path.join(webapp_dir, "mean_values.pt") values = torch.load(values_fpath, weights_only=True, map_location="cpu") return sparsity, sparsity_fpath, values, values_fpath @app.cell def __(mo, np, plt, sparsity): def plot_hist(counts): fig, ax = plt.subplots() ax.hist(np.log10(counts.numpy() + 1e-9), bins=100) return fig mo.md(f""" Sparsity Log10 {mo.as_html(plot_hist(sparsity))} """) return (plot_hist,) @app.cell def __(mo, plot_hist, values): mo.md(f""" Mean Value Log10 {mo.as_html(plot_hist(values))} """) return @app.cell def __(np, plt, sparsity, values): def plot_dist( min_log_sparsity: float, max_log_sparsity: float, min_log_value: float, max_log_value: float, ): fig, ax = plt.subplots() log_sparsity = np.log10(sparsity.numpy() + 1e-9) log_values = np.log10(values.numpy() + 1e-9) mask = np.ones(len(log_sparsity)).astype(bool) mask[log_sparsity < min_log_sparsity] = False mask[log_sparsity > max_log_sparsity] = False mask[log_values < min_log_value] = False mask[log_values > max_log_value] = False n_shown = mask.sum() ax.scatter( log_sparsity[mask], log_values[mask], marker=".", alpha=0.1, color="tab:blue", label=f"Shown ({n_shown})", ) n_filtered = (~mask).sum() ax.scatter( log_sparsity[~mask], log_values[~mask], marker=".", alpha=0.1, color="tab:red", label=f"Filtered ({n_filtered})", ) ax.axvline(min_log_sparsity, linewidth=0.5, color="tab:red") ax.axvline(max_log_sparsity, linewidth=0.5, color="tab:red") ax.axhline(min_log_value, linewidth=0.5, color="tab:red") ax.axhline(max_log_value, linewidth=0.5, color="tab:red") ax.set_xlabel("Feature Frequency (log10)") ax.set_ylabel("Mean Activation Value (log10)") ax.legend(loc="upper right") return fig return (plot_dist,) @app.cell def __(mo, plot_dist, sparsity_slider, value_slider): mo.md(f""" Log Sparsity Range: {sparsity_slider} {sparsity_slider.value} Log Value Range: {value_slider} {value_slider.value} {mo.as_html(plot_dist(sparsity_slider.value[0], sparsity_slider.value[1], value_slider.value[0], value_slider.value[1]))} """) return @app.cell def __(mo): sparsity_slider = mo.ui.range_slider(start=-8, stop=0, step=0.1, value=[-6, -1]) return (sparsity_slider,) @app.cell def __(mo): value_slider = mo.ui.range_slider(start=-3, stop=1, step=0.1, value=[-0.75, 1.0]) return (value_slider,) @app.cell def __(): return @app.cell def __(): return @app.cell def __(): return @app.cell def __(): return 
@app.cell def __(): return if __name__ == "__main__": app.run() >>>> interactive/metrics.py import marimo __generated_with = "0.9.32" app = marimo.App(width="medium") @app.cell def __(): import json import os import altair as alt import beartype import marimo as mo import matplotlib.pyplot as plt import numpy as np import polars as pl from jaxtyping import Float, jaxtyped import wandb return Float, alt, beartype, jaxtyped, json, mo, np, os, pl, plt, wandb @app.cell def __(mo): mo.md( """ # SAE Metrics Explorer This notebook helps you analyze and compare SAE training runs from WandB. ## Setup Instructions 1. Edit the configuration cell at the top to set your WandB username and project 2. Make sure you have access to the original ViT activation shards 3. Use the filters to narrow down which models to compare ## Troubleshooting - **Missing data error**: This notebook needs access to the original ViT activation shards - **No runs found**: Check your WandB username, project name, and tag filter """ ) return @app.cell def __(): WANDB_USERNAME = "samuelstevens" WANDB_PROJECT = "saev" return WANDB_PROJECT, WANDB_USERNAME @app.cell def __(mo): tag_input = mo.ui.text(value="classification-v1.0", label="Sweep Tag:") return (tag_input,) @app.cell def __(WANDB_PROJECT, WANDB_USERNAME, mo, tag_input): mo.vstack([ mo.md( f"Look at [{WANDB_USERNAME}/{WANDB_PROJECT} on WandB](https://wandb.ai/{WANDB_USERNAME}/{WANDB_PROJECT}/table) to pick your tag." ), tag_input, ]) return @app.cell def __(alt, df, mo): chart = mo.ui.altair_chart( alt.Chart( df.select( "summary/eval/l0", "summary/losses/mse", "id", "config/sae/sparsity_coeff", "config/lr", "config/sae/d_sae", "model_key", ) ) .mark_point() .encode( x=alt.X("summary/eval/l0"), y=alt.Y("summary/losses/mse"), tooltip=["id", "config/lr"], color="config/lr:Q", # shape="config/sae/sparsity_coeff:N", shape="config/sae/d_sae:N", # shape="model_key", ) ) chart return (chart,) @app.cell def __(chart, df, mo, np, plot_dist, plt): mo.stop( len(chart.value) < 2, mo.md( "Select two or more points. Exactly one point is not supported because of a [Polars bug](https://github.com/pola-rs/polars/issues/19855)." ), ) sub_df = ( df.select( "id", "summary/eval/freqs", "summary/eval/mean_values", "summary/eval/l0", ) .join(chart.value.select("id"), on="id", how="inner") .sort(by="summary/eval/l0") .head(4) ) scatter_fig, scatter_axes = plt.subplots( ncols=len(sub_df), figsize=(12, 3), squeeze=False, sharey=True, sharex=True ) hist_fig, hist_axes = plt.subplots( ncols=len(sub_df), nrows=2, figsize=(12, 6), squeeze=False, sharey=True, sharex=True, ) # Always one row scatter_axes = scatter_axes.reshape(-1) hist_axes = hist_axes.T for (id, freqs, values, _), scatter_ax, (freq_hist_ax, values_hist_ax) in zip( sub_df.iter_rows(), scatter_axes, hist_axes ): plot_dist( freqs.astype(float), (-1.0, 1.0), values.astype(float), (-2.0, 2.0), scatter_ax, ) # ax.scatter(freqs, values, marker=".", alpha=0.03) # ax.set_yscale("log") # ax.set_xscale("log") scatter_ax.set_title(id) # Plot feature bins = np.linspace(-6, 1, 100) freq_hist_ax.hist(np.log10(freqs.astype(float)), bins=bins) freq_hist_ax.set_title(f"{id} Feat. Freq. Dist.") values_hist_ax.hist(np.log10(values.astype(float)), bins=bins) values_hist_ax.set_title(f"{id} Mean Val. 
Distribution") scatter_fig.tight_layout() hist_fig.tight_layout() return ( bins, freq_hist_ax, freqs, hist_axes, hist_fig, id, scatter_ax, scatter_axes, scatter_fig, sub_df, values, values_hist_ax, ) @app.cell def __(scatter_fig): scatter_fig return @app.cell def __(hist_fig): hist_fig return @app.cell def __(chart, df, pl): df.join(chart.value.select("id"), on="id", how="inner").sort( by="summary/eval/l0" ).select("id", pl.selectors.starts_with("config/")) return @app.cell def __(Float, beartype, jaxtyped, np): @jaxtyped(typechecker=beartype.beartype) def plot_dist( freqs: Float[np.ndarray, " d_sae"], freqs_log_range: tuple[float, float], values: Float[np.ndarray, " d_sae"], values_log_range: tuple[float, float], ax, ): log_sparsity = np.log10(freqs + 1e-9) log_values = np.log10(values + 1e-9) mask = np.ones(len(log_sparsity)).astype(bool) min_log_freq, max_log_freq = freqs_log_range mask[log_sparsity < min_log_freq] = False mask[log_sparsity > max_log_freq] = False min_log_value, max_log_value = values_log_range mask[log_values < min_log_value] = False mask[log_values > max_log_value] = False n_shown = mask.sum() ax.scatter( log_sparsity[mask], log_values[mask], marker=".", alpha=0.1, color="tab:blue", label=f"Shown ({n_shown})", ) n_filtered = (~mask).sum() ax.scatter( log_sparsity[~mask], log_values[~mask], marker=".", alpha=0.1, color="tab:red", label=f"Filtered ({n_filtered})", ) ax.axvline(min_log_freq, linewidth=0.5, color="tab:red") ax.axvline(max_log_freq, linewidth=0.5, color="tab:red") ax.axhline(min_log_value, linewidth=0.5, color="tab:red") ax.axhline(max_log_value, linewidth=0.5, color="tab:red") ax.set_xlabel("Feature Frequency (log10)") ax.set_ylabel("Mean Activation Value (log10)") return (plot_dist,) @app.cell def __( beartype, get_data_key, get_model_key, json, load_freqs, load_mean_values, mo, os, pl, tag_input, wandb, ): class MetadataAccessError(Exception): """Exception raised when metadata cannot be accessed or parsed.""" pass @beartype.beartype def find_metadata(shard_root: str): if not os.path.exists(shard_root): msg = f""" ERROR: Shard root '{shard_root}' not found. You need to either: 1. Run this notebook on the same machine where the shards are located. 2. Copy the shards to this machine at path: {shard_root} 3. 
Update the filtering criteria to only show checkpoints with available data""".strip() raise MetadataAccessError(msg) metadata_path = os.path.join(shard_root, "metadata.json") if not os.path.exists(metadata_path): raise MetadataAccessError("Missing metadata.json file") try: with open(metadata_path) as fd: return json.load(fd) except json.JSONDecodeError: raise MetadataAccessError("Malformed metadata.json file") @beartype.beartype def make_df(tag: str): filters = {} if tag: filters["config.tag"] = tag runs = wandb.Api().runs(path="samuelstevens/saev", filters=filters) rows = [] for run in mo.status.progress_bar( runs, remove_on_exit=True, title="Loading", subtitle="Parsing runs from WandB", ): row = {} row["id"] = run.id row.update(**{ f"summary/{key}": value for key, value in run.summary.items() }) try: row["summary/eval/freqs"] = load_freqs(run) except ValueError: print(f"Run {run.id} did not log eval/freqs.") continue except RuntimeError: print(f"Wandb blew up on run {run.id}.") continue try: row["summary/eval/mean_values"] = load_mean_values(run) except ValueError: print(f"Run {run.id} did not log eval/mean_values.") continue except RuntimeError: print(f"Wandb blew up on run {run.id}.") continue # config row.update(**{ f"config/data/{key}": value for key, value in run.config.pop("data").items() }) row.update(**{ f"config/sae/{key}": value for key, value in run.config.pop("sae").items() }) row.update(**{f"config/{key}": value for key, value in run.config.items()}) try: metadata = find_metadata(row["config/data/shard_root"]) except MetadataAccessError as err: print(f"Bad run {run.id}: {err}") continue row["model_key"] = get_model_key(metadata) data_key = get_data_key(metadata) if data_key is None: print(f"Bad run {run.id}: unknown data.") continue row["data_key"] = data_key row["config/d_vit"] = metadata["d_vit"] rows.append(row) if not rows: raise ValueError("No runs found.") df = pl.DataFrame(rows).with_columns( (pl.col("config/sae/d_vit") * pl.col("config/sae/exp_factor")).alias( "config/sae/d_sae" ) ) return df df = make_df(tag_input.value) return MetadataAccessError, df, find_metadata, make_df @app.cell def __(beartype): @beartype.beartype def get_model_key(metadata: dict[str, object]) -> str: family = next( metadata[key] for key in ("vit_family", "model_family") if key in metadata ) ckpt = next( metadata[key] for key in ("vit_ckpt", "model_ckpt") if key in metadata ) if family == "dinov2" and ckpt == "dinov2_vitb14_reg": return "DINOv2 ViT-B/14 (reg)" if family == "clip" and ckpt == "ViT-B-16/openai": return "CLIP ViT-B/16" if family == "clip" and ckpt == "hf-hub:imageomics/bioclip": return "BioCLIP ViT-B/16" print(f"Unknown model: {(family, ckpt)}") return ckpt @beartype.beartype def get_data_key(metadata: dict[str, object]) -> str | None: if ( "train_mini" in metadata["data"] and "ImageFolderDataset" in metadata["data"] ): return "iNat21" if "train" in metadata["data"] and "Imagenet" in metadata["data"]: return "ImageNet-1K" print(f"Unknown data: {metadata['data']}") return None return get_data_key, get_model_key @app.cell def __(Float, json, np, os): def load_freqs(run) -> Float[np.ndarray, " d_sae"]: try: for artifact in run.logged_artifacts(): if "evalfreqs" not in artifact.name: continue dpath = artifact.download() fpath = os.path.join(dpath, "eval", "freqs.table.json") print(fpath) with open(fpath) as fd: raw = json.load(fd) return np.array(raw["data"]).reshape(-1) except Exception as err: raise RuntimeError("Wandb sucks.") from err raise ValueError(f"freqs not found in run 
'{run.id}'") def load_mean_values(run) -> Float[np.ndarray, " d_sae"]: try: for artifact in run.logged_artifacts(): if "evalmean_values" not in artifact.name: continue dpath = artifact.download() fpath = os.path.join(dpath, "eval", "mean_values.table.json") print(fpath) with open(fpath) as fd: raw = json.load(fd) return np.array(raw["data"]).reshape(-1) except Exception as err: raise RuntimeError("Wandb sucks.") from err raise ValueError(f"mean_values not found in run '{run.id}'") return load_freqs, load_mean_values @app.cell def __(df): df.drop( "config/log_every", "config/slurm_acct", "config/device", "config/n_workers", "config/wandb_project", "config/track", "config/slurm", "config/log_to", "config/ckpt_path", "config/sae/ghost_grads", ) return if __name__ == "__main__": app.run() >>>> nn/__init__.py from .modeling import SparseAutoencoder, dump, load from .objectives import get_objective __all__ = ["SparseAutoencoder", "dump", "load", "get_objective"] >>>> nn/modeling.py """ Neural network architectures for sparse autoencoders. """ import dataclasses import io import json import logging import os import pathlib import subprocess import typing import beartype import einops import torch from jaxtyping import Float, jaxtyped from torch import Tensor from .. import __version__, config @jaxtyped(typechecker=beartype.beartype) class SparseAutoencoder(torch.nn.Module): """ Sparse auto-encoder (SAE) using L1 sparsity penalty. """ def __init__(self, cfg: config.SparseAutoencoder): super().__init__() self.cfg = cfg self.W_enc = torch.nn.Parameter( torch.nn.init.kaiming_uniform_(torch.empty(cfg.d_vit, cfg.d_sae)) ) self.b_enc = torch.nn.Parameter(torch.zeros(cfg.d_sae)) self.W_dec = torch.nn.Parameter( torch.nn.init.kaiming_uniform_(torch.empty(cfg.d_sae, cfg.d_vit)) ) self.b_dec = torch.nn.Parameter(torch.zeros(cfg.d_vit)) self.activation = get_activation(cfg) self.logger = logging.getLogger(f"sae(seed={cfg.seed})") def forward( self, x: Float[Tensor, "batch d_model"] ) -> tuple[Float[Tensor, "batch d_model"], Float[Tensor, "batch d_sae"]]: """ Given x, calculates the reconstructed x_hat and the intermediate activations f_x. Arguments: x: a batch of ViT activations. """ # Remove encoder bias as per Anthropic h_pre = ( einops.einsum( x - self.b_dec, self.W_enc, "... d_vit, d_vit d_sae -> ... d_sae" ) + self.b_enc ) f_x = self.activation(h_pre) x_hat = self.decode(f_x) return x_hat, f_x def decode( self, f_x: Float[Tensor, "batch d_sae"] ) -> Float[Tensor, "batch d_model"]: x_hat = ( einops.einsum(f_x, self.W_dec, "... d_sae, d_sae d_vit -> ... d_vit") + self.b_dec ) return x_hat @torch.no_grad() def init_b_dec(self, vit_acts: Float[Tensor, "n d_vit"]): if self.cfg.n_reinit_samples <= 0: self.logger.info("Skipping init_b_dec.") return previous_b_dec = self.b_dec.clone().cpu() vit_acts = vit_acts[: self.cfg.n_reinit_samples] assert len(vit_acts) == self.cfg.n_reinit_samples mean = vit_acts.mean(axis=0) previous_distances = torch.norm(vit_acts - previous_b_dec, dim=-1) distances = torch.norm(vit_acts - mean, dim=-1) self.logger.info( "Prev dist: %.3f; new dist: %.3f", previous_distances.median(axis=0).values.mean().item(), distances.median(axis=0).values.mean().item(), ) self.b_dec.data = mean.to(self.b_dec.dtype).to(self.b_dec.device) @torch.no_grad() def normalize_w_dec(self): """ Set W_dec to unit-norm columns. 
""" if self.cfg.normalize_w_dec: self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True) @torch.no_grad() def remove_parallel_grads(self): """ Update grads so that they remove the parallel component (d_sae, d_vit) shape """ if not self.cfg.remove_parallel_grads: return parallel_component = einops.einsum( self.W_dec.grad, self.W_dec.data, "d_sae d_vit, d_sae d_vit -> d_sae", ) self.W_dec.grad -= einops.einsum( parallel_component, self.W_dec.data, "d_sae, d_sae d_vit -> d_sae d_vit", ) @beartype.beartype def get_activation(cfg: config.SparseAutoencoder) -> torch.nn.Module: if isinstance(cfg, config.Relu): return torch.nn.ReLU() elif isinstance(cfg, config.JumpRelu): raise NotImplementedError() else: typing.assert_never(cfg) @beartype.beartype def dump(fpath: str, sae: SparseAutoencoder): """ Save an SAE checkpoint to disk along with configuration, using the [trick from equinox](https://docs.kidger.site/equinox/examples/serialisation). Arguments: fpath: filepath to save checkpoint to. sae: sparse autoencoder checkpoint to save. """ header = { "schema": 1, "cfg": dataclasses.asdict(sae.cfg), "cls": sae.cfg.__class__.__name__, "commit": current_git_commit() or "unknown", "lib": __version__, } os.makedirs(os.path.dirname(fpath), exist_ok=True) with open(fpath, "wb") as fd: header_str = json.dumps(header) fd.write((header_str + "\n").encode("utf-8")) torch.save(sae.state_dict(), fd) @beartype.beartype def load(fpath: str, *, device="cpu") -> SparseAutoencoder: """ Loads a sparse autoencoder from disk. """ with open(fpath, "rb") as fd: header = json.loads(fd.readline()) buffer = io.BytesIO(fd.read()) if "schema" not in header: # Original, pre-schema stuff. for keyword in ("sparsity_coeff", "ghost_grads"): header.pop(keyword) cfg = config.Relu(**header) elif header["schema"] == 1: cls = getattr(config, header["cls"]) # default for v0 cfg = cls(**header["cfg"]) else: raise ValueError(f"Unknown schema version: {header['schema']}") model = SparseAutoencoder(cfg) model.load_state_dict(torch.load(buffer, weights_only=True, map_location=device)) return model @beartype.beartype def current_git_commit() -> str | None: """ Best-effort short SHA of the repo containing *this* file. Returns `None` when * `git` executable is missing, * we’re not inside a git repo (e.g. installed wheel), * or any git call errors out. """ try: # Walk up until we either hit a .git dir or the FS root here = pathlib.Path(__file__).resolve() for parent in (here, *here.parents): if (parent / ".git").exists(): break else: # no .git found return None result = subprocess.run( ["git", "-C", str(parent), "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, check=True, ) return result.stdout.strip() or None except (FileNotFoundError, subprocess.CalledProcessError): return None >>>> nn/objectives.py import dataclasses import typing import beartype import torch from jaxtyping import Float, jaxtyped from torch import Tensor from .. 
import config @jaxtyped(typechecker=beartype.beartype) @dataclasses.dataclass(frozen=True, slots=True) class Loss: """The loss term for an autoencoder training batch.""" @property def loss(self) -> Float[Tensor, ""]: """Total loss.""" raise NotImplementedError() def metrics(self) -> dict[str, object]: raise NotImplementedError() @jaxtyped(typechecker=beartype.beartype) class Objective(torch.nn.Module): def forward( self, x: Float[Tensor, "batch d_model"], f_x: Float[Tensor, "batch d_sae"], x_hat: Float[Tensor, "batch d_model"], ) -> Loss: raise NotImplementedError() @jaxtyped(typechecker=beartype.beartype) @dataclasses.dataclass(frozen=True, slots=True) class VanillaLoss(Loss): """The vanilla loss terms for an training batch.""" mse: Float[Tensor, ""] """Reconstruction loss (mean squared error).""" sparsity: Float[Tensor, ""] """Sparsity loss, typically lambda * L1.""" l0: Float[Tensor, ""] """L0 magnitude of hidden activations.""" l1: Float[Tensor, ""] """L1 magnitude of hidden activations.""" @property def loss(self) -> Float[Tensor, ""]: """Total loss.""" return self.mse + self.sparsity def metrics(self) -> dict[str, object]: return { "loss": self.loss.item(), "mse": self.mse.item(), "l0": self.l0.item(), "l1": self.l1.item(), "sparsity": self.sparsity.item(), } @jaxtyped(typechecker=beartype.beartype) class VanillaObjective(Objective): def __init__(self, cfg: config.Vanilla): super().__init__() self.cfg = cfg def forward( self, x: Float[Tensor, "batch d_model"], f_x: Float[Tensor, "batch d_sae"], x_hat: Float[Tensor, "batch d_model"], ) -> VanillaLoss: # Some values of x and x_hat can be very large. We can calculate a safe MSE print(x_hat.shape, x.shape) mse_loss = mean_squared_err(x_hat, x) mse_loss = mse_loss.mean() l0 = (f_x > 0).float().sum(axis=1).mean(axis=0) l1 = f_x.sum(axis=1).mean(axis=0) sparsity_loss = self.cfg.sparsity_coeff * l1 return VanillaLoss(mse_loss, sparsity_loss, l0, l1) @jaxtyped(typechecker=beartype.beartype) @dataclasses.dataclass(frozen=True, slots=True) class MatryoshkaLoss(Loss): """The composite loss terms for an training batch.""" @property def loss(self) -> Float[Tensor, ""]: raise NotImplementedError() @jaxtyped(typechecker=beartype.beartype) class MatryoshkaObjective(Objective): """Torch module for calculating the matryoshka loss for an SAE.""" def __init__(self, cfg: config.Matryoshka): super().__init__() self.cfg = cfg def forward(self) -> "MatryoshkaLoss.Loss": raise NotImplementedError() @beartype.beartype def get_objective(cfg: config.Objective) -> Objective: if isinstance(cfg, config.Vanilla): return VanillaObjective(cfg) elif isinstance(cfg, config.Matryoshka): return MatryoshkaObjective(cfg) else: typing.assert_never(cfg) @jaxtyped(typechecker=beartype.beartype) def ref_mean_squared_err( x_hat: Float[Tensor, "*d"], x: Float[Tensor, "*d"], norm: bool = False ) -> Float[Tensor, "*d"]: mse_loss = torch.pow((x_hat - x.float()), 2) if norm: mse_loss /= (x**2).sum(dim=-1, keepdim=True).sqrt() return mse_loss @jaxtyped(typechecker=beartype.beartype) def mean_squared_err( x_hat: Float[Tensor, "*batch d"], x: Float[Tensor, "*batch d"], norm: bool = False ) -> Float[Tensor, "*batch d"]: upper = x.abs().max() x = x / upper x_hat = x_hat / upper mse = (x_hat - x) ** 2 # (sam): I am now realizing that we normalize by the L2 norm of x. 
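    # Why scale by `upper`: with x' = x / upper and x_hat' = x_hat / upper,
    # (x_hat' - x')^2 = (x_hat - x)^2 / upper^2, which stays finite even for very
    # large activations. Multiplying by upper^2 below recovers the plain squared
    # error; the norm-scaled branch only needs a single factor of upper because
    # ||x'|| = ||x|| / upper already absorbs the other one.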
if norm: mse /= torch.linalg.norm(x, axis=-1, keepdim=True) + 1e-12 return mse * upper return mse * upper * upper >>>> nn/test_modeling.py import hypothesis.strategies as st import pytest import torch from hypothesis import given, settings from .. import config from . import modeling def test_factories(): assert isinstance(modeling.get_activation(config.Relu()), torch.nn.ReLU) @st.composite def relu_cfgs(draw): d_vit = draw(st.sampled_from([32, 64, 128])) exp = draw(st.sampled_from([2, 4])) return config.Relu(d_vit=d_vit, exp_factor=exp) @settings(deadline=None) @given(cfg=relu_cfgs(), batch=st.integers(min_value=1, max_value=4)) def test_sae_shapes(cfg, batch): sae = modeling.SparseAutoencoder(cfg) x = torch.randn(batch, cfg.d_vit) x_hat, f = sae(x) assert x_hat.shape == (batch, cfg.d_vit) assert f.shape == (batch, cfg.d_sae) hf_ckpts = [ "osunlp/SAE_BioCLIP_24K_ViT-B-16_iNat21", "osunlp/SAE_CLIP_24K_ViT-B-16_IN1K", "osunlp/SAE_DINOv2_24K_ViT-B-14_IN1K", ] @pytest.mark.parametrize("repo_id", hf_ckpts) @pytest.mark.slow def test_load_bioclip_checkpoint(repo_id, tmp_path): pytest.importorskip("huggingface_hub") import huggingface_hub ckpt_path = huggingface_hub.hf_hub_download( repo_id=repo_id, filename="sae.pt", cache_dir=tmp_path ) model = modeling.load(ckpt_path) # Smoke-test shapes & numerics x = torch.randn(2, model.cfg.d_vit) x_hat, f_x = model(x) assert x_hat.shape == x.shape assert f_x.shape[1] == model.cfg.d_sae # reconstruction shouldn’t be exactly identical, but should have finite values assert torch.isfinite(x_hat).all() roundtrip_cases = [ config.Relu(d_vit=512, exp_factor=8, seed=0), config.Relu(d_vit=768, exp_factor=16, seed=1), config.Relu(d_vit=1024, exp_factor=32, seed=2), ] @pytest.mark.parametrize("sae_cfg", roundtrip_cases) def test_dump_load_roundtrip(tmp_path, sae_cfg): """Write → load → verify state-dict & cfg equality.""" sae = modeling.SparseAutoencoder(sae_cfg) _ = sae(torch.randn(2, sae_cfg.d_vit)) # touch all params once ckpt = tmp_path / "sae.pt" modeling.dump(str(ckpt), sae) sae_loaded = modeling.load(str(ckpt)) # configs identical assert sae_cfg == sae_loaded.cfg # tensors identical for k, v in sae.state_dict().items(): torch.testing.assert_close(v, sae_loaded.state_dict()[k]) >>>> nn/test_objectives.py """ Uses [hypothesis]() and [hypothesis-torch](https://hypothesis-torch.readthedocs.io/en/stable/compatability/) to generate test cases to compare our normalized MSE implementation to a reference MSE implementation. """ import hypothesis import hypothesis.strategies as st import hypothesis_torch import pytest import torch from .. import config from . 
import objectives def test_mse_same(): x = torch.ones((45, 12), dtype=torch.float) x_hat = torch.ones((45, 12), dtype=torch.float) expected = torch.zeros((45, 12), dtype=torch.float) actual = objectives.mean_squared_err(x_hat, x) torch.testing.assert_close(actual, expected) def test_mse_zero_x_hat(): x = torch.ones((3, 2), dtype=torch.float) x_hat = torch.zeros((3, 2), dtype=torch.float) expected = torch.ones((3, 2), dtype=torch.float) actual = objectives.mean_squared_err(x_hat, x, norm=False) torch.testing.assert_close(actual, expected) def test_mse_nonzero(): x = torch.full((3, 2), 3, dtype=torch.float) x_hat = torch.ones((3, 2), dtype=torch.float) expected = objectives.ref_mean_squared_err(x_hat, x) actual = objectives.mean_squared_err(x_hat, x) torch.testing.assert_close(actual, expected) def test_safe_mse_large_x(): x = torch.full((3, 2), 3e28, dtype=torch.float) x_hat = torch.ones((3, 2), dtype=torch.float) ref = objectives.ref_mean_squared_err(x_hat, x, norm=True) assert ref.isnan().any() safe = objectives.mean_squared_err(x_hat, x, norm=True) assert not safe.isnan().any() def test_factories(): assert isinstance( objectives.get_objective(config.Vanilla()), objectives.VanillaObjective ) # basic element generator finite32 = st.floats( min_value=-1e9, max_value=1e9, allow_nan=False, allow_infinity=False, width=32, ) tensor123 = hypothesis_torch.tensor_strategy( dtype=torch.float32, shape=(1, 2, 3), elements=finite32, layout=torch.strided, device=torch.device("cpu"), ) @st.composite def tensor_pair(draw): x_hat = draw(tensor123) x = draw(tensor123) # ensure denominator in your safe-mse is not zero hypothesis.assume(torch.linalg.norm(x, ord=2, dim=-1).max() > 1e-8) return x_hat, x @pytest.mark.slow @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], deadline=None ) @hypothesis.given(pair=tensor_pair()) def test_safe_mse_hypothesis(pair): x_hat, x = pair # both finite, same device/layout expected = objectives.ref_mean_squared_err(x_hat, x) actual = objectives.mean_squared_err(x_hat, x) torch.testing.assert_close(actual, expected) >>>> preprint.md # Notes for Preprint I'm writing a submission to ICML. The premise is that we apply sparse autoencoders to vision models like DINOv2 and CLIP to interpret and control their internal representations. ## Outline We're trying to (informally) explain our position with the following metaphor: Scientific method: observation -> hypothesis -> experiment Interpretability methods: model behavior -> proposed explanation → ? SAEs complete the cycle: model behavior -> proposed explanation -> feature intervention 1. Introduction 1.1. Understanding requires intervention - we must test hypotheses through controlled experiments (scientific method) 1.2. Current methods provide only understanding or only control, never both 1.3 Understanding and controlling vision models requires three key capabilities: (1) the ability to identify human-interpretable features (like 'fur' or 'wheels'), (2) reliable ways to manipulate these features, and (3) compatibility with existing models. 1.4 Current methods fail to meet these requirements; they either discover features that can't be manipulated, enable manipulations that aren't interpretable, or require expensive model retraining. 1.5. SAEs from NLP provide unified solution: interpretable features that can be precisely controlled to validate hypotheses. 1.6. Contributions: SAE for vision model, new understanding of differences in vision models, multiple control examples across tasks 2. 
Background & Related Work 2.1. Vision model interpretability 2.2. Model editing 2.4. SAEs in language models 3. Method 3.1. SAE architecture and training 3.2. Feature intervention framework 3.2.1. Train a (or use an existing pre-trained) task-specific head: make predictions based on [CLS] activations, use an LLM to generate captions based on an image and a prompt, etc. 3.2.2. Manipulate the vision transformer activations using the SAE-proposed features and compare outputs before/after intervention. 3.4. Evaluation metrics 4. Understanding Results 4.1. Pre-Training Modality Affects Learned Features - DINOv2 vs CLIP 4.2. Pre-Training Distritbuion Affects Learned Features - CLIP vs BioCLIP 5. Control Results - Task & Model Agnostic 5.1. Semantic Segmentation Control * Intro explaining? * Technical description of training linear semseg head on DINOv2 features. * Qualitative results (cherry picked examples, full-width figure) * [MAYBE] Description of how we automatically find ADE20K class features in SAE latent space * [MAYBE] Quantitative results (single-column table) 5.2. Image Classification Control * Birds with interpretable traits 5.3. Vision-Language Control * Counting + removing objects * Colors + changing colors * Captioning (classification 6. Discussion 6.1. Limitations 6.2. Future work ## List of Figures 1. Hook figure: Full width explanatory figure that shows an overview of how we can use SAEs to interpret vision models and then intervene on that explanation and see how model predictions change. Status: visual outline 2. CLIP vs DINOv2: Full width figure demonstrating that CLIP learns semntically abstract visual features like "human teeth" across different visual styles, while DINOv2 does not. Status: visual outline 3. CLIP vs BioCLIP: Full width figure demonstrating some difference in CLIP and BioCLIP's learned features. Status: untouched. 4. Semantic segmentation: Full width figure demonstrating that we can validate patch-level hypotheses. Status: drafted 5. Image-classification: Full width figure demonstrating how you can manipulate fine-grained classification with SAEs. Status: untouched 6. Image captioning: Full width figure. Status: untouched I also want to build some interactive dashboards and tools to demonstrate that this stuff works. 1. I want my current PCA dashboard with UMAP instead 2. Given a linear classifier of semantic segmentation features, I want to manipulate the features in a given patch, and apply the suppression to all patches to see the live changes on the segmentation mask. 3. After training an SAE on a CLS token, I can then train a linear classifier on the CLS token with ImageNet-1K, and manipulate the features directly. 4. Given a small vision-language model like Phi-3.5 or Moondream, I want to manipulate the vision embeddings (suppressing or adding one or more features) and then see how the top 5 responses change in response to the user input (non-zero temperature). 5. Given a zero-shot CLIP or SigLIP classifier, you can add subtract features from all patches, then see how the classification changes --- With respect to writing, we want to frame everything as Goal->Problem->Solution. In general, I want you to be skeptical and challenging of arguments that are not supported by evidence. Some questions that come up that are not in the outline yet: Q: Am you using standard SAEs or have you adopted the architecture? A: I am using ReLU SAEs with an L1 sparsity term and I have constrained the columns of W_dec to be unit norm to prevent shrinkage. 
Q: What datasets are you using?

A: I am using ImageNet-1K for training and testing. I am extending it to iNat2021 (train-mini, 500K images) to demonstrate that the results hold beyond ImageNet.

We're going to work together on writing this paper, so I want to give you an opportunity to ask any questions you might have.

It can be helpful to think about this project from the perspective of a top machine learning researcher, like Lucas Beyer, Yann LeCun, or Francois Chollet. What would they think about this project? What criticisms would they have? What parts would be novel or exciting to them?
>>>> related-work.md
# Related Work

Various papers and internet posts on training SAEs for vision.

## Preprints

[An X-Ray Is Worth 15 Features: Sparse Autoencoders for Interpretable Radiology Report Generation](https://arxiv.org/pdf/2410.03334)

* Haven't read this yet, but Hugo Fry is an author.

## LessWrong

[Towards Multimodal Interpretability: Learning Sparse Interpretable Features in Vision Transformers](https://www.lesswrong.com/posts/bCtbuWraqYTDtuARg/towards-multimodal-interpretability-learning-sparse-2)

* Trains a sparse autoencoder on the 22nd layer of a CLIP ViT-L/14. First public work training an SAE on a ViT. Finds interesting features, demonstrating that SAEs work with ViTs.

[Interpreting and Steering Features in Images](https://www.lesswrong.com/posts/Quqekpvx8BGMMcaem/interpreting-and-steering-features-in-images)

* Haven't read it yet.

[Case Study: Interpreting, Manipulating, and Controlling CLIP With Sparse Autoencoders](https://www.lesswrong.com/posts/iYFuZo9BMvr6GgMs5/case-study-interpreting-manipulating-and-controlling-clip)

* Follow-up to the above work; haven't read it yet.

[A Suite of Vision Sparse Autoencoders](https://www.lesswrong.com/posts/wrznNDMRmbQABAEMH/a-suite-of-vision-sparse-autoencoders)

* Trains sparse autoencoders on various layers of a CLIP ViT-L/14 (trained on LAION-2B) using TopK with k=32. The SAEs are trained on 1.2B tokens, including patch tokens (not just [CLS]). Limited evaluation.
>>>> reproduce.md
# Reproduce

To reproduce the findings from our preprint, you will need to train a couple of SAEs on various datasets, then save visual examples so you can browse them in the notebooks.

## Table of Contents

1. Save activations for ImageNet and iNat2021 for DINOv2, CLIP and BioCLIP.
2. Train SAEs on these activation datasets.
3. Pick the best SAE checkpoints for each combination.
4. Save visualizations for those best checkpoints.

## Save Activations

## Train SAEs

## Choose Best Checkpoints

## Save Visualizations

Get visuals for the iNat-trained SAEs (BioCLIP and CLIP):

```sh
uv run python -m saev visuals \
  --ckpt checkpoints/$CKPT/sae.pt \
  --dump-to /$NFS/$USER/saev-visuals/$CKPT/ \
  --log-freq-range -2.0 -1.0 \
  --log-value-range -0.75 2.0 \
  --data.shard-root /local/scratch/$USER/cache/saev/$SHARDS \
  images:image-folder-dataset \
  --images.root /$NFS/$USER/datasets/inat21/train_mini/
```

Look at these visuals in the interactive notebook.

```sh
uv run marimo edit
```

Then open [localhost:2718](http://localhost:2718) in your browser and open the `saev/interactive/features.py` file. Choose one of the checkpoints in the dropdown and click through the different neurons to find patterns in the underlying ViT.
>>>> test_activations.py
"""
Test that the cached activations are actually correct.
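The end-to-end test builds a real DINOv2 ViT, writes its recorded activations to disk with `ShardWriter`, and reads them back through `Dataset` to check that they match.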
These tests are quite slow """ import tempfile import pytest import torch from . import activations, config @pytest.mark.slow def test_dataloader_batches(): cfg = config.Activations( vit_ckpt="ViT-B-32/openai", d_vit=768, vit_layers=[-2, -1], n_patches_per_img=49, vit_batch_size=8, ) dataloader = activations.get_dataloader( cfg, img_transform=activations.make_img_transform(cfg.vit_family, cfg.vit_ckpt) ) batch = next(iter(dataloader)) assert isinstance(batch, dict) assert "image" in batch assert "index" in batch torch.testing.assert_close(batch["index"], torch.arange(8)) assert batch["image"].shape == (8, 3, 224, 224) @pytest.mark.slow def test_shard_writer_and_dataset_e2e(): with tempfile.TemporaryDirectory() as tmpdir: cfg = config.Activations( vit_family="dinov2", vit_ckpt="dinov2_vits14_reg", d_vit=384, n_patches_per_img=256, vit_layers=[-2, -1], vit_batch_size=8, n_workers=8, dump_to=tmpdir, ) vit = activations.make_vit(cfg.vit_family, cfg.vit_ckpt) vit = activations.RecordedVisionTransformer( vit, cfg.n_patches_per_img, cfg.cls_token, cfg.vit_layers ) dataloader = activations.get_dataloader( cfg, img_transform=activations.make_img_transform(cfg.vit_family, cfg.vit_ckpt), ) writer = activations.ShardWriter(cfg) dataset = activations.Dataset( config.DataLoad( shard_root=activations.get_acts_dir(cfg), patches="cls", layer=-1, scale_mean=False, scale_norm=False, ) ) i = 0 for b, batch in zip(range(10), dataloader): # Don't care about the forward pass. out, cache = vit(batch["image"]) del out writer[i : i + len(cache)] = cache i += len(cache) assert cache.shape == (cfg.vit_batch_size, len(cfg.vit_layers), 257, 384) acts = [dataset[i.item()]["act"] for i in batch["index"]] from_dataset = torch.stack(acts) torch.testing.assert_close(cache[:, -1, 0], from_dataset) print(f"Batch {b} matched.") >>>> test_config.py import pytest from . 
import config def test_expand(): cfg = {"lr": [1, 2, 3]} expected = [{"lr": 1}, {"lr": 2}, {"lr": 3}] actual = list(config.expand(cfg)) assert expected == actual def test_expand_two_fields(): cfg = {"lr": [1, 2], "wd": [3, 4]} expected = [ {"lr": 1, "wd": 3}, {"lr": 1, "wd": 4}, {"lr": 2, "wd": 3}, {"lr": 2, "wd": 4}, ] actual = list(config.expand(cfg)) assert expected == actual def test_expand_nested(): cfg = {"sae": {"dim": [1, 2, 3]}} expected = [{"sae": {"dim": 1}}, {"sae": {"dim": 2}}, {"sae": {"dim": 3}}] actual = list(config.expand(cfg)) assert expected == actual def test_expand_nested_and_unnested(): cfg = {"sae": {"dim": [1, 2]}, "lr": [3, 4]} expected = [ {"sae": {"dim": 1}, "lr": 3}, {"sae": {"dim": 1}, "lr": 4}, {"sae": {"dim": 2}, "lr": 3}, {"sae": {"dim": 2}, "lr": 4}, ] actual = list(config.expand(cfg)) assert expected == actual def test_expand_nested_and_unnested_backwards(): cfg = {"a": [False, True], "b": {"c": [False, True]}} expected = [ {"a": False, "b": {"c": False}}, {"a": False, "b": {"c": True}}, {"a": True, "b": {"c": False}}, {"a": True, "b": {"c": True}}, ] actual = list(config.expand(cfg)) assert expected == actual def test_expand_multiple(): cfg = {"a": [1, 2, 3], "b": {"c": [4, 5, 6]}} expected = [ {"a": 1, "b": {"c": 4}}, {"a": 1, "b": {"c": 5}}, {"a": 1, "b": {"c": 6}}, {"a": 2, "b": {"c": 4}}, {"a": 2, "b": {"c": 5}}, {"a": 2, "b": {"c": 6}}, {"a": 3, "b": {"c": 4}}, {"a": 3, "b": {"c": 5}}, {"a": 3, "b": {"c": 6}}, ] actual = list(config.expand(cfg)) assert expected == actual # every union alias is exhaustive: constructing an unknown class must fail @pytest.mark.parametrize( "alias, members", [ (config.SparseAutoencoder, {config.Relu, config.JumpRelu}), (config.Objective, {config.Vanilla, config.Matryoshka}), ( config.DatasetConfig, {config.ImagenetDataset, config.ImageFolderDataset, config.Ade20kDataset}, ), ], ) def test_union_is_exhaustive(alias, members): assert members == set(alias.__args__) >>>> test_training.py import torch from . 
import config, training def test_split_cfgs_on_single_key(): cfgs = [config.Train(n_workers=12), config.Train(n_workers=16)] expected = [[config.Train(n_workers=12)], [config.Train(n_workers=16)]] actual = training.split_cfgs(cfgs) assert actual == expected def test_split_cfgs_on_single_key_with_multiple_per_key(): cfgs = [ config.Train(n_patches=12), config.Train(n_patches=16), config.Train(n_patches=16), config.Train(n_patches=16), ] expected = [ [config.Train(n_patches=12)], [ config.Train(n_patches=16), config.Train(n_patches=16), config.Train(n_patches=16), ], ] actual = training.split_cfgs(cfgs) assert actual == expected def test_split_cfgs_on_multiple_keys_with_multiple_per_key(): cfgs = [ config.Train(n_patches=12, track=False), config.Train(n_patches=12, track=True), config.Train(n_patches=16, track=True), config.Train(n_patches=16, track=True), config.Train(n_patches=16, track=False), ] expected = [ [config.Train(n_patches=12, track=False)], [config.Train(n_patches=12, track=True)], [ config.Train(n_patches=16, track=True), config.Train(n_patches=16, track=True), ], [config.Train(n_patches=16, track=False)], ] actual = training.split_cfgs(cfgs) assert actual == expected def test_split_cfgs_no_bad_keys(): cfgs = [ config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=1e-4)), config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=2e-4)), config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=3e-4)), config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=4e-4)), config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=5e-4)), ] expected = [ [ config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=1e-4)), config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=2e-4)), config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=3e-4)), config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=4e-4)), config.Train(n_patches=12, objective=config.Vanilla(sparsity_coeff=5e-4)), ] ] actual = training.split_cfgs(cfgs) assert actual == expected class DummyDS(torch.utils.data.Dataset): def __init__(self, n, d): self.x = torch.randn(n, d) def __getitem__(self, i): return dict(act=self.x[i]) def __len__(self): return len(self.x) def test_one_training_step(monkeypatch): cfg = config.Train( track=False, sae_batch_size=8, data=config.DataLoad(), n_patches=64 ) # monkey-patch Dataset/loader used in activations module from . import activations monkeypatch.setattr(activations, "Dataset", lambda _: DummyDS(32, cfg.sae.d_vit)) monkeypatch.setattr(activations, "get_dataloader", lambda *a, **k: None) # not used # run a single loop from . import training ids = training.main([cfg]) # should not raise assert len(ids) == 1 def test_one_training_step_matryoshka(monkeypatch): """A minimal end-to-end training-loop smoke test for the Matryoshka objective.""" # configuration that uses Matryoshka cfg = config.Train( track=False, sae_batch_size=8, n_patches=64, # make the run fast. data=config.DataLoad(), objective=config.Matryoshka(), ) # stub out expensive I/O from . import activations monkeypatch.setattr(activations, "Dataset", lambda *_: DummyDS(32, cfg.sae.d_vit)) monkeypatch.setattr(activations, "get_dataloader", lambda *_1, **_2: None) # run one training job from saev import training ids = training.main([cfg]) assert len(ids) == 1 >>>> test_visuals.py import torch from . 
import visuals def test_gather_batched_small(): values = torch.arange(0, 64, dtype=torch.float).view(4, 2, 8) i = torch.tensor([[0], [0], [1], [1]]) actual = visuals.gather_batched(values, i) expected = torch.tensor([ [[0, 1, 2, 3, 4, 5, 6, 7]], [[16, 17, 18, 19, 20, 21, 22, 23]], [[40, 41, 42, 43, 44, 45, 46, 47]], [[56, 57, 58, 59, 60, 61, 62, 63]], ]).float() torch.testing.assert_close(actual, expected) >>>> training.py """ Trains many SAEs in parallel to amortize the cost of loading a single batch of data over many SAE training runs. """ import dataclasses import json import logging import os.path import beartype import einops import numpy as np import torch from jaxtyping import Float from torch import Tensor import wandb from . import activations, config, helpers, nn log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" logging.basicConfig(level=logging.INFO, format=log_format) logger = logging.getLogger("train") @torch.no_grad() def init_b_dec_batched(saes: torch.nn.ModuleList, dataset: activations.Dataset): n_samples = max(sae.cfg.n_reinit_samples for sae in saes) if not n_samples: return # Pick random samples using first SAE's seed. perm = np.random.default_rng(seed=saes[0].cfg.seed).permutation(len(dataset)) perm = perm[:n_samples] examples, _, _ = zip(*[ dataset[p.item()] for p in helpers.progress(perm, every=25_000, desc="examples to re-init b_dec") ]) vit_acts = torch.stack(examples) for sae in saes: sae.init_b_dec(vit_acts[: sae.cfg.n_reinit_samples]) @beartype.beartype def make_saes( cfgs: list[tuple[config.SparseAutoencoder, config.Objective]], ) -> tuple[torch.nn.ModuleList, torch.nn.ModuleList, list[dict[str, object]]]: saes, objectives, param_groups = [], [], [] for sae_cfg, obj_cfg in cfgs: sae = nn.SparseAutoencoder(sae_cfg) saes.append(sae) # Use an empty LR because our first step is warmup. param_groups.append({"params": sae.parameters(), "lr": 0.0}) objectives.append(nn.get_objective(obj_cfg)) return torch.nn.ModuleList(saes), torch.nn.ModuleList(objectives), param_groups ################## # Parallel Wandb # ################## MetricQueue = list[tuple[int, dict[str, object]]] class ParallelWandbRun: """ Inspired by https://community.wandb.ai/t/is-it-possible-to-log-to-multiple-runs-simultaneously/4387/3. """ def __init__( self, project: str, cfgs: list[config.Train], mode: str, tags: list[str] ): cfg, *cfgs = cfgs self.project = project self.cfgs = cfgs self.mode = mode self.tags = tags self.live_run = wandb.init( project=project, config=dataclasses.asdict(cfg), mode=mode, tags=tags ) self.metric_queues: list[MetricQueue] = [[] for _ in self.cfgs] def log(self, metrics: list[dict[str, object]], *, step: int): metric, *metrics = metrics self.live_run.log(metric, step=step) for queue, metric in zip(self.metric_queues, metrics): queue.append((step, metric)) def finish(self) -> list[str]: ids = [self.live_run.id] # Log the rest of the runs. 
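        # Finish the live run first so wandb flushes it, then create a fresh
        # run for every queued config and replay its buffered (step, metric)
        # pairs, so each parallel SAE still gets its own run id and history.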
self.live_run.finish() for queue, cfg in zip(self.metric_queues, self.cfgs): run = wandb.init( project=self.project, config=dataclasses.asdict(cfg), mode=self.mode, tags=self.tags + ["queued"], ) for step, metric in queue: run.log(metric, step=step) ids.append(run.id) run.finish() return ids @beartype.beartype def main(cfgs: list[config.Train]) -> list[str]: saes, objectives, run, steps = train(cfgs) # Cheap(-ish) evaluation eval_metrics = evaluate(cfgs, saes, objectives) metrics = [metric.for_wandb() for metric in eval_metrics] run.log(metrics, step=steps) ids = run.finish() for cfg, id, metric, sae in zip(cfgs, ids, eval_metrics, saes): logger.info( "Checkpoint %s has %d dense features (%.1f)", id, metric.n_dense, metric.n_dense / sae.cfg.d_sae * 100, ) logger.info( "Checkpoint %s has %d dead features (%.1f%%)", id, metric.n_dead, metric.n_dead / sae.cfg.d_sae * 100, ) logger.info( "Checkpoint %s has %d *almost* dead (<1e-7) features (%.1f)", id, metric.n_almost_dead, metric.n_almost_dead / sae.cfg.d_sae * 100, ) ckpt_fpath = os.path.join(cfg.ckpt_path, id, "sae.pt") nn.dump(ckpt_fpath, sae) logger.info("Dumped checkpoint to '%s'.", ckpt_fpath) cfg_fpath = os.path.join(cfg.ckpt_path, id, "config.json") with open(cfg_fpath, "w") as fd: json.dump(dataclasses.asdict(cfg), fd, indent=4) return ids @beartype.beartype def train( cfgs: list[config.Train], ) -> tuple[torch.nn.ModuleList, torch.nn.ModuleList, ParallelWandbRun, int]: """ Explicitly declare the optimizer, schedulers, dataloader, etc outside of `main` so that all the variables are dropped from scope and can be garbage collected. """ if len(split_cfgs(cfgs)) != 1: raise ValueError("Configs are not parallelizeable: {cfgs}.") logger.info("Parallelizing %d runs.", len(cfgs)) cfg = cfgs[0] if torch.cuda.is_available(): # This enables tf32 on Ampere GPUs which is only 8% slower than # float16 and almost as accurate as float32 # This was a default in pytorch until 1.12 torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False dataset = activations.Dataset(cfg.data) saes, objectives, param_groups = make_saes([(c.sae, c.objective) for c in cfgs]) mode = "online" if cfg.track else "disabled" tags = [cfg.tag] if cfg.tag else [] run = ParallelWandbRun(cfg.wandb_project, cfgs, mode, tags) optimizer = torch.optim.Adam(param_groups, fused=True) lr_schedulers = [Warmup(0.0, c.lr, c.n_lr_warmup) for c in cfgs] sparsity_schedulers = [ Warmup(0.0, c.objective.sparsity_coeff, c.n_sparsity_warmup) for c in cfgs ] dataloader = torch.utils.data.DataLoader( dataset, batch_size=cfg.sae_batch_size, num_workers=cfg.n_workers, shuffle=True ) dataloader = BatchLimiter(dataloader, cfg.n_patches) saes.train() saes = saes.to(cfg.device) objectives.train() objectives = objectives.to(cfg.device) global_step, n_patches_seen = 0, 0 for batch in helpers.progress(dataloader, every=cfg.log_every): acts_BD = batch["act"].to(cfg.device, non_blocking=True) for sae in saes: sae.normalize_w_dec() # Forward passes and loss calculations. 
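        # Every SAE in the parallel group sees the same activation batch; the
        # per-SAE losses are collected in a list so each objective's backward
        # pass (and gradients) stays independent of the others.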
losses = [] for sae, objective in zip(saes, objectives): x_hat, f_x = sae(acts_BD) losses.append(objective(acts_BD, f_x, x_hat)) n_patches_seen += len(acts_BD) with torch.no_grad(): if (global_step + 1) % cfg.log_every == 0: metrics = [ { **loss.metrics(), "progress/n_patches_seen": n_patches_seen, "progress/learning_rate": group["lr"], "progress/sparsity_coeff": objective.sparsity_coeff, } for loss, sae, objective, group in zip( losses, saes, objectives, optimizer.param_groups ) ] run.log(metrics, step=global_step) logger.info( ", ".join( f"{key}: {value:.5f}" for key, value in losses[0].metrics().items() ) ) for loss in losses: loss.loss.backward() for sae in saes: sae.remove_parallel_grads() optimizer.step() # Update LR and sparsity coefficients. for param_group, scheduler in zip(optimizer.param_groups, lr_schedulers): param_group["lr"] = scheduler.step() for objective, scheduler in zip(objectives, sparsity_schedulers): objective.sparsity_coeff = scheduler.step() # Don't need these anymore. optimizer.zero_grad() global_step += 1 return saes, objectives, run, global_step @beartype.beartype @dataclasses.dataclass(frozen=True) class EvalMetrics: """Results of evaluating a trained SAE on a datset.""" l0: float """Mean L0 across all examples.""" l1: float """Mean L1 across all examples.""" mse: float """Mean MSE across all examples.""" n_dead: int """Number of neurons that never fired on any example.""" n_almost_dead: int """Number of neurons that fired on fewer than `almost_dead_threshold` of examples.""" n_dense: int """Number of neurons that fired on more than `dense_threshold` of examples.""" freqs: Float[Tensor, " d_sae"] """How often each feature fired.""" mean_values: Float[Tensor, " d_sae"] """The mean value for each feature when it did fire.""" almost_dead_threshold: float """Threshold for an "almost dead" neuron.""" dense_threshold: float """Threshold for a dense neuron.""" def for_wandb(self) -> dict[str, int | float]: dct = dataclasses.asdict(self) # Store arrays as tables. dct["freqs"] = wandb.Table(columns=["freq"], data=dct["freqs"][:, None].numpy()) dct["mean_values"] = wandb.Table( columns=["mean_value"], data=dct["mean_values"][:, None].numpy() ) return {f"eval/{key}": value for key, value in dct.items()} @beartype.beartype @torch.no_grad() def evaluate( cfgs: list[config.Train], saes: torch.nn.ModuleList, objectives: torch.nn.ModuleList ) -> list[EvalMetrics]: """ Evaluates SAE quality by counting the number of dead features and the number of dense features. Also makes histogram plots to help human qualitative comparison. .. todo:: Develop automatic methods to use histogram and feature frequencies to evaluate quality with a single number. 
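    A latent counts as dead if it never fires on the evaluation split, almost
    dead if it fires on fewer than 1e-7 of examples, and dense if it fires on
    more than 1e-2 of examples (see `almost_dead_lim` and `dense_lim` below).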
""" torch.cuda.empty_cache() if len(split_cfgs(cfgs)) != 1: raise ValueError("Configs are not parallelizeable: {cfgs}.") saes.eval() cfg = cfgs[0] almost_dead_lim = 1e-7 dense_lim = 1e-2 dataset = activations.Dataset(cfg.data) dataloader = torch.utils.data.DataLoader( dataset, batch_size=cfg.sae_batch_size, num_workers=cfg.n_workers, shuffle=False ) n_fired = torch.zeros((len(cfgs), saes[0].cfg.d_sae)) values = torch.zeros((len(cfgs), saes[0].cfg.d_sae)) total_l0 = torch.zeros(len(cfgs)) total_l1 = torch.zeros(len(cfgs)) total_mse = torch.zeros(len(cfgs)) for batch in helpers.progress(dataloader, desc="eval", every=cfg.log_every): acts_BD = batch["act"].to(cfg.device, non_blocking=True) for i, (sae, objective) in enumerate(zip(saes, objectives)): x_hat_BD, f_x_BS = sae(acts_BD) loss = objective(acts_BD, f_x_BS, x_hat_BD) n_fired[i] += einops.reduce(f_x_BS > 0, "batch d_sae -> d_sae", "sum").cpu() values[i] += einops.reduce(f_x_BS, "batch d_sae -> d_sae", "sum").cpu() total_l0[i] += loss.l0.cpu() total_l1[i] += loss.l1.cpu() total_mse[i] += loss.mse.cpu() mean_values = values / n_fired freqs = n_fired / len(dataset) l0 = (total_l0 / len(dataloader)).tolist() l1 = (total_l1 / len(dataloader)).tolist() mse = (total_mse / len(dataloader)).tolist() n_dead = einops.reduce(freqs == 0, "n_saes d_sae -> n_saes", "sum").tolist() n_almost_dead = einops.reduce( freqs < almost_dead_lim, "n_saes d_sae -> n_saes", "sum" ).tolist() n_dense = einops.reduce(freqs > dense_lim, "n_saes d_sae -> n_saes", "sum").tolist() metrics = [] for row in zip(l0, l1, mse, n_dead, n_almost_dead, n_dense, freqs, mean_values): metrics.append(EvalMetrics(*row, almost_dead_lim, dense_lim)) return metrics class BatchLimiter: """ Limits the number of batches to only return `n_samples` total samples. """ def __init__(self, dataloader: torch.utils.data.DataLoader, n_samples: int): self.dataloader = dataloader self.n_samples = n_samples self.batch_size = dataloader.batch_size def __len__(self) -> int: return self.n_samples // self.batch_size def __iter__(self): self.n_seen = 0 while True: for batch in self.dataloader: yield batch # Sometimes we underestimate because the final batch in the dataloader might not be a full batch. self.n_seen += self.batch_size if self.n_seen > self.n_samples: return # We try to mitigate the above issue by ignoring the last batch if we don't have drop_last. if not self.dataloader.drop_last: self.n_seen -= self.batch_size ##################### # Parallel Training # ##################### CANNOT_PARALLELIZE = set([ "data", "n_workers", "n_patches", "sae_batch_size", "track", "wandb_project", "tag", "log_every", "ckpt_path", "device", "slurm", "slurm_acct", "log_to", "sae.exp_factor", "sae.d_vit", ]) @beartype.beartype def split_cfgs(cfgs: list[config.Train]) -> list[list[config.Train]]: """ Splits configs into groups that can be parallelized. Arguments: A list of configs from a sweep file. Returns: A list of lists, where the configs in each sublist do not differ in any keys that are in `CANNOT_PARALLELIZE`. This means that each sublist is a valid "parallel" set of configs for `train`. 
""" # Group configs by their values for CANNOT_PARALLELIZE keys groups = {} for cfg in cfgs: dct = dataclasses.asdict(cfg) # Create a key tuple from the values of CANNOT_PARALLELIZE keys key_values = [] for key in sorted(CANNOT_PARALLELIZE): key_values.append((key, make_hashable(helpers.get(dct, key)))) group_key = tuple(key_values) if group_key not in groups: groups[group_key] = [] groups[group_key].append(cfg) # Convert groups dict to list of lists return list(groups.values()) def make_hashable(obj): return json.dumps(obj, sort_keys=True) ############## # Schedulers # ############## @beartype.beartype class Scheduler: def step(self) -> float: err_msg = f"{self.__class__.__name__} must implement step()." raise NotImplementedError(err_msg) def __repr__(self) -> str: err_msg = f"{self.__class__.__name__} must implement __repr__()." raise NotImplementedError(err_msg) @beartype.beartype class Warmup(Scheduler): """ Linearly increases from `init` to `final` over `n_warmup_steps` steps. """ def __init__(self, init: float, final: float, n_steps: int): self.final = final self.init = init self.n_steps = n_steps self._step = 0 def step(self) -> float: self._step += 1 if self._step < self.n_steps: return self.init + (self.final - self.init) * (self._step / self.n_steps) return self.final def __repr__(self) -> str: return f"Warmup(init={self.init}, final={self.final}, n_steps={self.n_steps})" def _plot_example_schedules(): import matplotlib.pyplot as plt import numpy as np fig, ax = plt.subplots() n_steps = 1000 xs = np.arange(n_steps) schedule = Warmup(0.1, 0.9, 100) ys = [schedule.step() for _ in xs] ax.plot(xs, ys, label=str(schedule)) fig.tight_layout() fig.savefig("schedules.png") if __name__ == "__main__": _plot_example_schedules() >>>> visuals.py """ There is some important notation used only in this file to dramatically shorten variable names. Variables suffixed with `_im` refer to entire images, and variables suffixed with `_p` refer to patches. """ import collections.abc import dataclasses import json import logging import math import os import random import typing import beartype import einops import torch from jaxtyping import Float, Int, jaxtyped from PIL import Image from torch import Tensor from . 
import activations, config, helpers, imaging, nn logger = logging.getLogger("visuals") @beartype.beartype def safe_load(path: str) -> object: return torch.load(path, map_location="cpu", weights_only=True) @jaxtyped(typechecker=beartype.beartype) def gather_batched( value: Float[Tensor, "batch n dim"], i: Int[Tensor, "batch k"] ) -> Float[Tensor, "batch k dim"]: batch_size, n, dim = value.shape # noqa: F841 _, k = i.shape batch_i = torch.arange(batch_size, device=value.device)[:, None].expand(-1, k) return value[batch_i, i] @jaxtyped(typechecker=beartype.beartype) @dataclasses.dataclass class GridElement: img: Image.Image label: str patches: Float[Tensor, " n_patches"] @beartype.beartype def make_img(elem: GridElement, *, upper: float | None = None) -> Image.Image: # Resize to 256x256 and crop to 224x224 resize_size_px = (512, 512) resize_w_px, resize_h_px = resize_size_px crop_size_px = (448, 448) crop_w_px, crop_h_px = crop_size_px crop_coords_px = ( (resize_w_px - crop_w_px) // 2, (resize_h_px - crop_h_px) // 2, (resize_w_px + crop_w_px) // 2, (resize_h_px + crop_h_px) // 2, ) img = elem.img.resize(resize_size_px).crop(crop_coords_px) img = imaging.add_highlights(img, elem.patches.numpy(), upper=upper) return img @jaxtyped(typechecker=beartype.beartype) def get_new_topk( val1: Float[Tensor, "d_sae k"], i1: Int[Tensor, "d_sae k"], val2: Float[Tensor, "d_sae k"], i2: Int[Tensor, "d_sae k"], k: int, ) -> tuple[Float[Tensor, "d_sae k"], Int[Tensor, "d_sae k"]]: """ Picks out the new top k values among val1 and val2. Also keeps track of i1 and i2, then indices of the values in the original dataset. Args: val1: top k original SAE values. i1: the patch indices of those original top k values. val2: top k incoming SAE values. i2: the patch indices of those incoming top k values. k: k. Returns: The new top k values and their patch indices. """ all_val = torch.cat([val1, val2], dim=1) new_values, top_i = torch.topk(all_val, k=k, dim=1) all_i = torch.cat([i1, i2], dim=1) new_indices = torch.gather(all_i, 1, top_i) return new_values, new_indices @beartype.beartype def batched_idx( total_size: int, batch_size: int ) -> collections.abc.Iterator[tuple[int, int]]: """ Iterate over (start, end) indices for total_size examples, where end - start is at most batch_size. Args: total_size: total number of examples batch_size: maximum distance between the generated indices. Returns: A generator of (int, int) tuples that can slice up a list or a tensor. """ for start in range(0, total_size, batch_size): stop = min(start + batch_size, total_size) yield start, stop @jaxtyped(typechecker=beartype.beartype) def get_sae_acts( vit_acts: Float[Tensor, "n d_vit"], sae: nn.SparseAutoencoder, cfg: config.Visuals ) -> Float[Tensor, "n d_sae"]: """ Get SAE hidden layer activations for a batch of ViT activations. Args: vit_acts: Batch of ViT activations sae: Sparse autoencder. cfg: Experimental config. """ sae_acts = [] for start, end in batched_idx(len(vit_acts), cfg.sae_batch_size): _, f_x, *_ = sae(vit_acts[start:end].to(cfg.device)) sae_acts.append(f_x) sae_acts = torch.cat(sae_acts, dim=0) sae_acts = sae_acts.to(cfg.device) return sae_acts @jaxtyped(typechecker=beartype.beartype) @dataclasses.dataclass(frozen=True) class TopKImg: ".. todo:: Document this class." 
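    # Shape key, matching the suffixes used in get_topk_img: S = d_sae latents,
    # K = top_k images per latent, M = n_distributions tracked latents,
    # N = number of images in the activation dataset.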
top_values: Float[Tensor, "d_sae k"] top_i: Int[Tensor, "d_sae k"] mean_values: Float[Tensor, " d_sae"] sparsity: Float[Tensor, " d_sae"] distributions: Float[Tensor, "m n"] percentiles: Float[Tensor, " d_sae"] @beartype.beartype @torch.inference_mode() def get_topk_img(cfg: config.Visuals) -> TopKImg: """ Gets the top k images for each latent in the SAE. The top k images are for latent i are sorted by max over all images: f_x(cls)[i] Thus, we will never have duplicate images for a given latent. But we also will not have patch-level activations (a nice heatmap). Args: cfg: Config. Returns: A tuple of TopKImg and the first m features' activation distributions. """ assert cfg.sort_by == "img" assert cfg.data.patches == "cls" sae = nn.load(cfg.ckpt).to(cfg.device) dataset = activations.Dataset(cfg.data) top_values_im_SK = torch.full((sae.cfg.d_sae, cfg.top_k), -1.0, device=cfg.device) top_i_im_SK = torch.zeros( (sae.cfg.d_sae, cfg.top_k), dtype=torch.int, device=cfg.device ) sparsity_S = torch.zeros((sae.cfg.d_sae,), device=cfg.device) mean_values_S = torch.zeros((sae.cfg.d_sae,), device=cfg.device) distributions_MN = torch.zeros((cfg.n_distributions, len(dataset)), device="cpu") estimator = PercentileEstimator( cfg.percentile, len(dataset), shape=(sae.cfg.d_sae,) ) dataloader = torch.utils.data.DataLoader( dataset, batch_size=cfg.topk_batch_size, shuffle=False, num_workers=cfg.n_workers, drop_last=False, ) logger.info("Loaded SAE and data.") for batch in helpers.progress(dataloader, desc="picking top-k"): vit_acts_BD = batch["act"] sae_acts_BS = get_sae_acts(vit_acts_BD, sae, cfg) for sae_act_S in sae_acts_BS: estimator.update(sae_act_S) sae_acts_SB = einops.rearrange(sae_acts_BS, "batch d_sae -> d_sae batch") distributions_MN[:, batch["image_i"]] = sae_acts_SB[: cfg.n_distributions].to( "cpu" ) mean_values_S += einops.reduce(sae_acts_SB, "d_sae batch -> d_sae", "sum") sparsity_S += einops.reduce((sae_acts_SB > 0), "d_sae batch -> d_sae", "sum") sae_acts_SK, k = torch.topk(sae_acts_SB, k=cfg.top_k, dim=1) i_im_SK = batch["image_i"].to(cfg.device)[k] all_values_im_2SK = torch.cat((top_values_im_SK, sae_acts_SK), axis=1) top_values_im_SK, k = torch.topk(all_values_im_2SK, k=cfg.top_k, axis=1) top_i_im_SK = torch.gather(torch.cat((top_i_im_SK, i_im_SK), axis=1), 1, k) mean_values_S /= sparsity_S sparsity_S /= len(dataset) return TopKImg( top_values_im_SK, top_i_im_SK, mean_values_S, sparsity_S, distributions_MN, estimator.estimate.cpu(), ) @jaxtyped(typechecker=beartype.beartype) @dataclasses.dataclass(frozen=True) class TopKPatch: ".. todo:: Document this class." top_values: Float[Tensor, "d_sae k n_patches_per_img"] top_i: Int[Tensor, "d_sae k"] mean_values: Float[Tensor, " d_sae"] sparsity: Float[Tensor, " d_sae"] distributions: Float[Tensor, "m n"] percentiles: Float[Tensor, " d_sae"] @beartype.beartype @torch.inference_mode() def get_topk_patch(cfg: config.Visuals) -> TopKPatch: """ Gets the top k images for each latent in the SAE. The top k images are for latent i are sorted by max over all patches: f_x(patch)[i] Thus, we could end up with duplicate images in the top k, if an image has more than one patch that maximally activates an SAE latent. Args: cfg: Config. Returns: A tuple of TopKPatch and m randomly sampled activation distributions. 
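    Note: the dataloader batch size is rounded down to a whole number of images
    (`topk_batch_size // n_patches_per_img * n_patches_per_img`) and
    `drop_last` is set, so every batch reshapes cleanly into
    (d_sae, n_imgs, n_patches_per_img).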
""" assert cfg.sort_by == "patch" assert cfg.data.patches == "patches" sae = nn.load(cfg.ckpt).to(cfg.device) dataset = activations.Dataset(cfg.data) top_values_p = torch.full( (sae.cfg.d_sae, cfg.top_k, dataset.metadata.n_patches_per_img), -1.0, device=cfg.device, ) top_i_im = torch.zeros( (sae.cfg.d_sae, cfg.top_k), dtype=torch.int, device=cfg.device ) sparsity_S = torch.zeros((sae.cfg.d_sae,), device=cfg.device) mean_values_S = torch.zeros((sae.cfg.d_sae,), device=cfg.device) distributions_MN = torch.zeros((cfg.n_distributions, len(dataset)), device="cpu") estimator = PercentileEstimator( cfg.percentile, len(dataset), shape=(sae.cfg.d_sae,) ) batch_size = ( cfg.topk_batch_size // dataset.metadata.n_patches_per_img * dataset.metadata.n_patches_per_img ) n_imgs_per_batch = batch_size // dataset.metadata.n_patches_per_img dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=cfg.n_workers, # See if you can change this to false and still pass the beartype check. drop_last=True, ) logger.info("Loaded SAE and data.") for batch in helpers.progress(dataloader, desc="picking top-k"): vit_acts_BD = batch["act"] sae_acts_BS = get_sae_acts(vit_acts_BD, sae, cfg) for sae_act_S in sae_acts_BS: estimator.update(sae_act_S) sae_acts_SB = einops.rearrange(sae_acts_BS, "batch d_sae -> d_sae batch") distributions_MN[:, batch["image_i"]] = sae_acts_SB[: cfg.n_distributions].to( "cpu" ) mean_values_S += einops.reduce(sae_acts_SB, "d_sae batch -> d_sae", "sum") sparsity_S += einops.reduce((sae_acts_SB > 0), "d_sae batch -> d_sae", "sum") i_im = torch.sort(torch.unique(batch["image_i"])).values values_p = sae_acts_SB.view( sae.cfg.d_sae, len(i_im), dataset.metadata.n_patches_per_img ) # Checks that I did my reshaping correctly. assert values_p.shape[1] == i_im.shape[0] assert len(i_im) == n_imgs_per_batch _, k = torch.topk(sae_acts_SB, k=cfg.top_k, dim=1) k_im = k // dataset.metadata.n_patches_per_img values_p = gather_batched(values_p, k_im) i_im = i_im.to(cfg.device)[k_im] all_values_p = torch.cat((top_values_p, values_p), axis=1) _, k = torch.topk(all_values_p.max(axis=-1).values, k=cfg.top_k, axis=1) top_values_p = gather_batched(all_values_p, k) top_i_im = torch.gather(torch.cat((top_i_im, i_im), axis=1), 1, k) mean_values_S /= sparsity_S sparsity_S /= len(dataset) return TopKPatch( top_values_p, top_i_im, mean_values_S, sparsity_S, distributions_MN, estimator.estimate.cpu(), ) @beartype.beartype @torch.inference_mode() def dump_activations(cfg: config.Visuals): """ For each SAE latent, we want to know which images have the most total "activation". 
That is, we keep track of each patch """ if cfg.sort_by == "img": topk = get_topk_img(cfg) elif cfg.sort_by == "patch": topk = get_topk_patch(cfg) else: typing.assert_never(cfg.sort_by) os.makedirs(cfg.root, exist_ok=True) torch.save(topk.top_values, cfg.top_values_fpath) torch.save(topk.top_i, cfg.top_img_i_fpath) torch.save(topk.mean_values, cfg.mean_values_fpath) torch.save(topk.sparsity, cfg.sparsity_fpath) torch.save(topk.distributions, cfg.distributions_fpath) torch.save(topk.percentiles, cfg.percentiles_fpath) @jaxtyped(typechecker=beartype.beartype) def plot_activation_distributions( cfg: config.Visuals, distributions: Float[Tensor, "m n"] ): import matplotlib.pyplot as plt import numpy as np m, _ = distributions.shape n_rows = int(math.sqrt(m)) n_cols = n_rows fig, axes = plt.subplots( figsize=(4 * n_cols, 4 * n_rows), nrows=n_rows, ncols=n_cols, sharex=True, sharey=True, ) _, bins = np.histogram(np.log10(distributions[distributions > 0].numpy()), bins=100) percentiles = [90, 95, 99, 100] colors = ("red", "darkorange", "gold", "lime") for dist, ax in zip(distributions, axes.reshape(-1)): vals = np.log10(dist[dist > 0].numpy()) ax.hist(vals, bins=bins) if vals.size == 0: continue for i, (percentile, color) in enumerate( zip(np.percentile(vals, percentiles), colors) ): ax.axvline(percentile, color=color, label=f"{percentiles[i]}th %-ile") for i, (percentile, color) in enumerate(zip(percentiles, colors)): estimator = PercentileEstimator(percentile, len(vals)) for v in vals: estimator.update(v) ax.axvline( estimator.estimate, color=color, linestyle="--", label=f"Est. {percentiles[i]}th %-ile", ) ax.legend() fig.tight_layout() return fig @beartype.beartype @torch.inference_mode() def main(cfg: config.Visuals): """ .. todo:: document this function. Dump top-k images to a directory. Args: cfg: Configuration object. """ try: top_values = safe_load(cfg.top_values_fpath) sparsity = safe_load(cfg.sparsity_fpath) mean_values = safe_load(cfg.mean_values_fpath) top_i = safe_load(cfg.top_img_i_fpath) distributions = safe_load(cfg.distributions_fpath) _ = safe_load(cfg.percentiles_fpath) except FileNotFoundError as err: logger.warning("Need to dump files: %s", err) dump_activations(cfg) return main(cfg) d_sae, cached_topk, *rest = top_values.shape # Check that the data is at least shaped correctly. 
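    # top_values is (d_sae, k) when sorted by image and
    # (d_sae, k, n_patches_per_img) when sorted by patch, so `rest` is either
    # empty or holds the per-image patch count.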
assert cfg.top_k == cached_topk if cfg.sort_by == "img": assert len(rest) == 0 elif cfg.sort_by == "patch": assert len(rest) == 1 n_patches = rest[0] assert n_patches > 0 else: typing.assert_never(cfg.sort_by) logger.info("Loaded sorted data.") os.makedirs(cfg.root, exist_ok=True) fig_fpath = os.path.join( cfg.root, f"{cfg.n_distributions}_activation_distributions.png" ) plot_activation_distributions(cfg, distributions).savefig(fig_fpath, dpi=300) logger.info( "Saved %d activation distributions to '%s'.", cfg.n_distributions, fig_fpath ) dataset = activations.get_dataset(cfg.images, img_transform=None) min_log_freq, max_log_freq = cfg.log_freq_range min_log_value, max_log_value = cfg.log_value_range mask = ( (min_log_freq < torch.log10(sparsity)) & (torch.log10(sparsity) < max_log_freq) & (min_log_value < torch.log10(mean_values)) & (torch.log10(mean_values) < max_log_value) ) neurons = cfg.include_latents random_neurons = torch.arange(d_sae)[mask.cpu()].tolist() random.seed(cfg.seed) random.shuffle(random_neurons) neurons += random_neurons[: cfg.n_latents] for i in helpers.progress(neurons, desc="saving visuals"): neuron_dir = os.path.join(cfg.root, "neurons", str(i)) os.makedirs(neuron_dir, exist_ok=True) # Image grid elems = [] seen_i_im = set() for i_im, values_p in zip(top_i[i].tolist(), top_values[i]): if i_im in seen_i_im: continue example = dataset[i_im] if cfg.sort_by == "img": elem = GridElement(example["image"], example["label"], torch.tensor([])) elif cfg.sort_by == "patch": elem = GridElement(example["image"], example["label"], values_p) else: typing.assert_never(cfg.sort_by) elems.append(elem) seen_i_im.add(i_im) # How to scale values. upper = None if top_values[i].numel() > 0: upper = top_values[i].max().item() for j, elem in enumerate(elems): img = make_img(elem, upper=upper) img.save(os.path.join(neuron_dir, f"{j}.png")) with open(os.path.join(neuron_dir, f"{j}.txt"), "w") as fd: fd.write(elem.label + "\n") # Metadata metadata = { "neuron": i, "log10_freq": torch.log10(sparsity[i]).item(), "log10_value": torch.log10(mean_values[i]).item(), } with open(os.path.join(neuron_dir, "metadata.json"), "w") as fd: json.dump(metadata, fd) @beartype.beartype class PercentileEstimator: def __init__( self, percentile: float | int, total: int, lr: float = 1e-3, shape: tuple[int, ...] = (), ): self.percentile = percentile self.total = total self.lr = lr self._estimate = torch.zeros(shape) self._step = 0 def update(self, x): """ Update the estimator with a new value. This method maintains the marker positions using the P2 algorithm rules. When a new value arrives, it's placed in the appropriate position relative to existing markers, and marker positions are adjusted to maintain their desired percentile positions. Arguments: x: The new value to incorporate into the estimation """ self._step += 1 step_size = self.lr * (self.total - self._step) / self.total # Is a no-op if it's already on the same device. 
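        # The update below nudges the estimate by
        #   step_size * (sign(x - estimate) + 2 * percentile / 100 - 1):
        # upward moves are proportional to percentile / 100 and downward moves
        # to (1 - percentile / 100), so the estimate settles near the requested
        # percentile as step_size decays over `total` updates.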
self._estimate = self._estimate.to(x.device) self._estimate += step_size * ( torch.sign(x - self._estimate) + 2 * self.percentile / 100 - 1.0 ) @property def estimate(self): return self._estimate @beartype.beartype def test_online_quantile_estimation(true: float, percentile: float): import matplotlib.pyplot as plt import numpy as np import tqdm rng = np.random.default_rng(seed=0) n = 3_000_000 estimator = PercentileEstimator(percentile, n) dist, preds = np.zeros(n), np.zeros(n) for i in tqdm.tqdm(range(n), desc="Getting estimates."): sampled = rng.normal(true) estimator.update(sampled) dist[i] = sampled preds[i] = estimator.estimate fig, ax = plt.subplots() ax.plot(preds, label=f"Pred. {percentile * 100}th %-ile") ax.axhline( np.percentile(dist, percentile * 100), label=f"True {percentile * 100}th %-ile", color="tab:red", ) ax.legend() fig.tight_layout() fig.savefig("online_median_normal.png") if __name__ == "__main__": import tyro tyro.cli(test_online_quantile_estimation)