Package saev
saev is a Python package for training sparse autoencoders (SAEs) on vision transformers (ViTs) in PyTorch.
The main entrypoint to the package is in __main__; use python -m saev --help to see the options and documentation for the script.
Guide to Training SAEs on Vision Models
1. Record ViT activations and save them to disk.
2. Train SAEs on the activations.
3. Visualize the learned features from the trained SAEs.
4. (your job) Propose trends and patterns in the visualized features.
5. (your job, supported by code) Construct datasets to test your hypothesized trends.
6. Confirm/reject hypotheses using the probing package.
saev helps with steps 1, 2 and 3.
Note: saev assumes you are running on NVIDIA GPUs. On a multi-GPU system, prefix your commands with CUDA_VISIBLE_DEVICES=X to run on GPU X.
Record ViT Activations to Disk
To save activations to disk, we need to specify:
- Which model we would like to use.
- Which layers we would like to save.
- Where on disk and how we would like to save activations.
- Which images we want to save activations for.
The saev.activations module does all of this for us. Run uv run python -m saev activations --help to see all the configuration.
In practice, you might run:
uv run python -m saev activations \
--model-group clip \
--model-ckpt ViT-B-32/openai \
--d-vit 768 \
--n-patches-per-img 49 \
--layers -2 \
--dump-to /local/scratch/$USER/cache/saev \
--n-patches-per-shard 2_400_000 \
data:imagenet-dataset
This will save activations for the CLIP-pretrained model ViT-B/32, which has a residual stream dimension of 768 and 49 patches per image (224 / 32 = 7; 7 x 7 = 49). It will save the second-to-last layer (--layers -2). It will write 2.4M patches per shard, and save shards to a new directory /local/scratch/$USER/cache/saev.
Note on storage space: A ViT-B/16 will save 1.2M images x 197 patches/layer/image x 1 layer = ~240M activations, each of which takes up 768 floats x 4 bytes/float = 3072 bytes, for a total of 723GB for the entire dataset. As you scale to larger models (ViT-L has 1024 dimensions, and 14x14-pixel patches mean 256 patches/layer/image at 224px), recorded activations grow even larger.
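If you want to redo this arithmetic for your own model before dumping activations, it is just a product of a handful of terms; a minimal sketch (the numbers mirror the ViT-B/16 example above):

```python
# Back-of-the-envelope storage estimate for recorded activations.
n_images = 1_200_000      # approximate size of the ImageNet train split
n_tokens_per_image = 197  # 14 x 14 patches + 1 [CLS] token for ViT-B/16 at 224px
n_layers = 1              # number of layers being recorded
d_vit = 768               # residual stream dimension for ViT-B
bytes_per_float = 4       # activations are stored as float32

n_activations = n_images * n_tokens_per_image * n_layers
total_bytes = n_activations * d_vit * bytes_per_float
print(f"{n_activations / 1e6:.0f}M activations, {total_bytes / 1e9:.0f} GB")
# Prints roughly 236M activations and ~726 GB, close to the figures quoted above.
```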
This script will also save a metadata.json file that records the relevant metadata for these activations, which will be read by future steps. The activations will be in .bin files, numbered starting from 000000.
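If you want to peek at what was written, the sketch below shows one way to do it. It assumes each .bin file is a flat float32 dump whose last dimension is the ViT width; the authoritative layout is whatever saev.activations writes and metadata.json describes, so treat this as illustrative only.

```python
import json
import pathlib

import numpy as np

# Illustrative only: assumes each shard is a flat float32 dump whose last
# dimension is the ViT width. Check metadata.json for the real layout.
shard_root = pathlib.Path("/local/scratch/me/cache/saev/<run-hash>")  # hypothetical path
metadata = json.loads((shard_root / "metadata.json").read_text())
print(metadata)  # the recorded config for these activations

d_vit = 768  # assumed ViT-B width; prefer reading this from the metadata
acts = np.fromfile(shard_root / "000000.bin", dtype=np.float32).reshape(-1, d_vit)
print(acts.shape, acts.dtype)
```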
To add your own models, see the guide to extending in saev.activations.
Train SAEs on Activations
To train an SAE, we need to specify:
- Which activations to use as input.
- SAE architectural stuff.
- Optimization-related stuff.
The saev.training module handles this. Run uv run python -m saev train --help to see all the configuration.
Continuing on from our example before, you might want to run something like:
uv run python -m saev train \
--data.shard-root /local/scratch/$USER/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8 \
--data.layer -2 \
--data.patches patches \
--data.no-scale-mean \
--data.no-scale-norm \
--sae.d-vit 768 \
--lr 5e-4
The --data.* flags describe which activations to use. --data.shard-root should point to a directory with *.bin files and the metadata.json file. --data.layer specifies the layer, and --data.patches says that we want to train on individual patch activations, rather than the [CLS] token activation. --data.no-scale-mean and --data.no-scale-norm mean not to scale the activation mean or L2 norm. Anthropic's and OpenAI's papers suggest normalizing these factors, but saev still has a bug with this, so I suggest not scaling them.
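For reference, the normalization those papers describe amounts to roughly the following. This is a sketch of the idea, not saev's implementation (which is what the flags above disable):

```python
import torch


def normalize_activations(acts: torch.Tensor) -> torch.Tensor:
    # Sketch of the preprocessing suggested in the Anthropic/OpenAI SAE papers,
    # not saev's implementation: center the activations, then rescale so the
    # average L2 norm is sqrt(d).
    d = acts.shape[-1]
    centered = acts - acts.mean(dim=0, keepdim=True)
    scale = (d ** 0.5) / centered.norm(dim=-1).mean()
    return centered * scale
```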
The --sae.* flags are about the SAE itself. --sae.d-vit is the only one you need to change; the dimension of our ViT was 768 for a ViT-B, rather than the default of 1024 for a ViT-L. Finally, choose a slightly larger learning rate than the default with --lr 5e-4.
This will train one (1) sparse autoencoder on the data. See the section on sweeps to learn how to train multiple SAEs in parallel using only a single GPU.
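The trick behind those sweeps (implemented in saev.training) is to amortize data loading: each batch of activations is read from disk once and then reused by many SAEs with different hyperparameters. A minimal sketch of that pattern, using hypothetical names rather than saev's actual API:

```python
import torch


def train_parallel(saes: list[torch.nn.Module], dataloader, n_steps: int, sparsity_coeff: float = 4e-4):
    # Hypothetical sketch of the amortization idea, not saev's actual API:
    # every batch of ViT activations is loaded once, then fed to every SAE.
    optimizers = [torch.optim.Adam(sae.parameters(), lr=5e-4) for sae in saes]
    for step, acts in zip(range(n_steps), dataloader):
        acts = acts.cuda(non_blocking=True)
        for sae, opt in zip(saes, optimizers):
            x_hat, codes = sae(acts)  # assumed to return (reconstruction, sparse codes)
            mse = (x_hat - acts).pow(2).mean()
            l1 = codes.abs().sum(dim=-1).mean()
            loss = mse + sparsity_coeff * l1  # illustrative sparsity penalty
            opt.zero_grad()
            loss.backward()
            opt.step()
```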
Visualize the Learned Features
Now that you've trained an SAE, you probably want to look at its learned features. One way to visualize an individual learned feature f is by picking out images that maximize the activation of feature f. Since we train SAEs on patch-level activations, we try to find the top patches for each feature f. Then, we pick out the images those patches correspond to and create a heatmap based on SAE activation values.
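Concretely, "top patches for feature f" just means ranking every stored patch by the value of SAE latent f. The core operation is a top-k over the code matrix; a sketch (the real saev.visuals pipeline streams over shards instead of holding everything in memory):

```python
import torch


def top_patches(codes: torch.Tensor, k: int = 128) -> torch.Tensor:
    # codes: (n_patches_total, d_sae) SAE codes for every patch in the dataset
    # (hypothetical layout). Column f of the result holds the indices of the k
    # patches that most strongly activate feature f; map those indices back to
    # (image, patch) pairs to build the heatmaps described above.
    _, indices = codes.topk(k, dim=0)  # (k, d_sae)
    return indices
```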
Note: More advanced forms of visualization are possible (and valuable!), but should not be included in saev unless they can be applied to every SAE/dataset combination. If you have specific visualizations, please add them to contrib/ or another location.
saev.visuals records these maximally activating images for us. You can see all the options with uv run python -m saev visuals --help.
So you might run:
uv run python -m saev visuals \
--ckpt checkpoints/abcdefg/sae.pt \
--dump-to /nfs/$USER/saev/webapp/abcdefg \
--data.shard-root /local/scratch/$USER/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8 \
--data.layer -2 \
--data.patches patches \
images:imagenet-dataset
This will record the top 128 patches, and then save the unique images among those top 128 patches for each feature in the trained SAE. It will cache these best activations to disk, then start saving images to visualize later on.
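The heatmap for a single image and feature is just that image's per-patch activations reshaped to the patch grid and upsampled to image resolution; a sketch, assuming a square patch grid (e.g. 7x7 for ViT-B/32 at 224px):

```python
import torch
import torch.nn.functional as F


def patch_heatmap(patch_acts: torch.Tensor, image_size: int = 224) -> torch.Tensor:
    # patch_acts: (n_patches,) activations of one SAE feature on one image
    # (assumes n_patches is a perfect square, e.g. 49 for ViT-B/32 at 224px).
    side = int(patch_acts.numel() ** 0.5)
    grid = patch_acts.reshape(1, 1, side, side)
    heat = F.interpolate(grid, size=(image_size, image_size), mode="bilinear", align_corners=False)
    return heat[0, 0]  # (image_size, image_size); overlay on the original image
```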
saev.interactive.features is a small web application based on marimo to interactively look at these images. You can run it with uv run marimo edit saev/interactive/features.py.
Sweeps
TODO
Explain how to run grid sweeps.
Training Metrics and Visualizations
TODO
Explain how to use the saev.interactive.metrics notebook.
Related Work
Various papers and internet posts on training SAEs for vision.
Preprints
- An X-Ray Is Worth 15 Features: Sparse Autoencoders for Interpretable Radiology Report Generation
  - Haven't read this yet, but Hugo Fry is an author.
LessWrong
- Towards Multimodal Interpretability: Learning Sparse Interpretable Features in Vision Transformers
  - Trains a sparse autoencoder on the 22nd layer of a CLIP ViT-L/14. First public work training an SAE on a ViT. Finds interesting features, demonstrating that SAEs work with ViTs.
- Interpreting and Steering Features in Images
  - Haven't read it yet.
- Case Study: Interpreting, Manipulating, and Controlling CLIP With Sparse Autoencoders
  - Follow-up to the above work; haven't read it yet.
- A Suite of Vision Sparse Autoencoders
  - Trains a sparse autoencoder on various layers using TopK with k=32 on a CLIP ViT-L/14 trained on LAION-2B. The SAE is trained on 1.2B tokens, including patch tokens (not just [CLS]). Limited evaluation.
Reproduce
To reproduce our findings from our preprint, you will need to train a couple SAEs on various datasets, then save visual examples so you can browse them in the notebooks.
Table of Contents
- Save activations for ImageNet and iNat2021 for DINOv2, CLIP and BioCLIP.
- Train SAEs on these activation datasets.
- Pick the best SAE checkpoints for each combination.
- Save visualizations for those best checkpoints.
Save Activations
Train SAEs
Choose Best Checkpoints
Save Visualizations
Get visuals for the iNat-trained SAEs (BioCLIP and CLIP):
uv run python -m saev visuals \
--ckpt checkpoints/$CKPT/sae.pt \
--dump-to /$NFS/$USER/saev-visuals/$CKPT/ \
--log-freq-range -2.0 -1.0 \
--log-value-range -0.75 2.0 \
--data.shard-root /local/scratch/$USER/cache/saev/$SHARDS \
images:image-folder-dataset \
--images.root /$NFS/$USER/datasets/inat21/train_mini/
Look at these visuals in the interactive notebook.
uv run marimo edit
Then open localhost:2718 in your browser and open the saev/interactive/features.py file.
Choose one of the checkpoints in the dropdown and click through the different neurons to find patterns in the underlying ViT.
Sub-modules
- saev.activations: To save lots of activations, we want to do things in parallel, with lots of slurm jobs, and save multiple files, rather than just one …
- saev.app
- saev.config: All configs for all saev jobs …
- saev.helpers: Useful helpers for saev.
- saev.imaging
- saev.interactive
- saev.nn: Neural network architectures for sparse autoencoders.
- saev.test_activations: Test that the cached activations are actually correct. These tests are quite slow.
- saev.test_config
- saev.test_nn: Uses hypothesis and hypothesis-torch to generate test cases to compare our …
- saev.test_training
- saev.test_visuals
- saev.training: Trains many SAEs in parallel to amortize the cost of loading a single batch of data over many SAE training runs.
- saev.visuals: There is some important notation used only in this file to dramatically shorten variable names …