Inference¶
Briefly, you need to:
- Download a checkpoint.
- Get the code.
- Load the checkpoint.
- Get activations.
Details are below.
Download a Checkpoint¶
First, download an SAE checkpoint from the Hugging Face collection.
For instance, you can choose the SAE trained on OpenAI's CLIP ViT-B/16 with ImageNet-1K activations here.
You can use wget if you want:
wget https://huggingface.co/osunlp/SAE_CLIP_24K_ViT-B-16_IN1K/resolve/main/sae.pt
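If you prefer Python over wget, here is a minimal sketch using the huggingface_hub package (assuming you have it installed); the repo id and filename are taken from the URL above:
from huggingface_hub import hf_hub_download
# Downloads sae.pt into the local Hugging Face cache and returns the local path.
ckpt_path = hf_hub_download(repo_id="osunlp/SAE_CLIP_24K_ViT-B-16_IN1K", filename="sae.pt")
print(ckpt_path)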
Get the Code¶
The easiest way to do this is to clone the code:
git clone https://github.com/OSU-NLP-Group/saev
You can also install the package from git if you use uv (pip and CUDA setups are untested):
uv add git+https://github.com/OSU-NLP-Group/saev
Or clone it and install it as an editable package with pip, like pip install -e . in your virtual environment.
Then you can do things like from saev import ....
Note
If you struggle to get saev installed, open an issue on GitHub and I will figure out how to make it easier.
Load the Checkpoint¶
import saev.nn
sae = saev.nn.load("PATH_TO_YOUR_SAE_CKPT.pt")
Now you have a pretrained SAE.
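Assuming the loaded SAE is a regular torch.nn.Module, you can move it onto a GPU and put it in eval mode; this also defines the device variable used in the example below:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
sae = sae.to(device).eval()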
Get Activations¶
This is the hardest part. We need to:
- Pass an image into a ViT.
- Record the dense ViT activations at the same layer that the SAE was trained on.
- Pass the activations into the SAE to get sparse activations.
- Do something interesting with the sparse SAE activations.
There are examples of this in the demo code: for classification and semantic segmentation.
If the permalinks change, you are looking for the get_sae_latents() functions in both files.
Below is example code to do it using the saev package.
import saev.nn
import saev.data.models
import saev.data.shards
from PIL import Image
# Load the ViT and wrap it so the dense activations at layer 10 are recorded.
vit_cls = saev.data.models.load_model_cls("clip")
vit = vit_cls("ViT-B-16/openai").to(device)
recorded_vit = saev.data.shards.RecordedTransformer(vit, 196, True, [10])
# Image transform that matches the ViT's preprocessing (196 patches).
img_transform, _ = vit_cls.make_transforms("ViT-B-16/openai", 196)
img = Image.open("example.jpg")
x = img_transform(img)
# Add a batch dimension and move to the same device as the models.
x = x[None, ...].to(device)
_, vit_acts = recorded_vit(x)
# Select the only recorded layer and drop the [CLS] token.
vit_acts = vit_acts[:, 0, 1:, :]
x_hat, f_x, loss = sae(vit_acts)
Now you have the reconstructed x (x_hat) and the sparse representation of all patches in the image (f_x).
You might select the dimensions with the maximal values for each patch and see which other images maximally activate those dimensions.
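For example, here is a rough sketch of picking the top activating SAE latents per patch with plain PyTorch, assuming f_x keeps a (batch, patches, d_sae) shape; the choice of k=5 is arbitrary:
import torch
# Top-5 SAE latents per patch, ranked by activation value.
values, latent_ids = torch.topk(f_x, k=5, dim=-1)
# Most active latents for the first patch of the first image.
print(latent_ids[0, 0])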