Module contrib.semseg
Interpret and manipulate semantic segmentation models using SAEs.
Reproduce
There are two main experiments to reproduce in our preprint: first, our qualitative examples; second, our quantitative evaluation of pseudo-orthogonality.
Qualitative
You can reproduce the qualitative examples from our preprint by following these instructions.
- Train a linear probe on the semantic segmentation task using ADE20K.
- Measure linear probe baseline metrics.
- Manipulate the activations using the proposed SAE features.
- Be amazed. :)
Details can be found below.
Train a Linear Probe on Semantic Segmentation
Train a linear probe on DINOv2 activations from ADE20K. The code is tied to DINOv2 because of its patch size, but it could be extended to other ViTs.
uv run python -m contrib.semseg train \
--sweep contrib/semseg/sweep.toml \
--imgs.root /$NFS/$USER/datasets/ade20k
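Conceptually, the probe is just one linear layer applied to every frozen patch embedding. Below is a minimal sketch, assuming ViT-B/14's 768-dimensional embeddings and ADE20K's 150 classes plus a background label; the class and names are illustrative, not the module's actual API (see contrib.semseg.training for the real implementation).

```python
import torch


class LinearProbe(torch.nn.Module):
    """Per-patch linear classifier over frozen DINOv2 patch embeddings."""

    def __init__(self, d_model: int = 768, n_classes: int = 151):
        # 768 = assumed DINOv2 ViT-B/14 embedding dim;
        # 151 = assumed 150 ADE20K classes + 1 background label.
        super().__init__()
        self.head = torch.nn.Linear(d_model, n_classes)

    def forward(self, patch_acts: torch.Tensor) -> torch.Tensor:
        # patch_acts: (batch, n_patches, d_model) frozen ViT activations.
        # Returns per-patch logits: (batch, n_patches, n_classes).
        return self.head(patch_acts)
```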
Measure Linear Probe Baseline Metrics
Check which learning rate/weight decay combination is best for the linear probe.
uv run python -m contrib.semseg validate \
--imgs.root /$NFS/$USER/datasets/ade20k
Then look in ./logs/contrib/semseg for hparam-sweeps.png to see which learning rate/weight decay combination is best.
Manipulate the Activations
You need an SAE that's been trained on DINOv2's activations on ImageNet. Then you can run both the frontend server and the backend server:
Frontend:
uv run python -m http.server
Then navigate to http://localhost:8000/web/apps/semseg/.
Backend:
This is a little trickier because the backend server lives on Hugging Face Spaces and talks to a personal Cloudflare server.
[TODO]
Quantitative
We aim to measure the specificity and pseudo-orthogonality of SAE-discovered features by evaluating the impact of feature manipulation on semantic segmentation.
We train an SAE on ImageNet-1K activations from DINOv2 ViT-B/14 (hosted here on HuggingFace). Then, we train a linear probe on top of DINOv2 for ADE20K following the procedure above. We define four ways to select a feature vector for a given ADE20K class:
- A random unit vector in $d$-dimensional space.
- A random SAE feature vector.
- An automatically selected SAE feature vector.
- A manually chosen SAE feature vector.
All four are described in more detail below.
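For intuition, here is a hedged sketch of the first two strategies. The decoder shape, SAE width, and seed are illustrative stand-ins, not values from our checkpoints; the automatic and manual strategies instead pick specific decoder rows via class-feature statistics or human inspection.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
d_vit = 768      # assumed DINOv2 ViT-B/14 embedding dim
d_sae = 24_576   # hypothetical SAE width; check the actual checkpoint

# 1. Random unit vector in d-dimensional activation space.
v = rng.normal(size=d_vit)
v /= np.linalg.norm(v)

# 2. Random SAE feature vector: a random row of the decoder matrix.
#    W_dec here is random noise standing in for the trained decoder,
#    assumed to have shape (d_sae, d_vit).
W_dec = rng.normal(size=(d_sae, d_vit))
i = rng.integers(d_sae)
u = W_dec[i] / np.linalg.norm(W_dec[i])
```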
Given a feature $i$ and an ADE20K class $c$, we perform semantic segmentation inference on each image in the validation set using DINOv2 and the trained linear probe. However, we set feature $i$ to $-2\times$ its maximum observed value, following the description of manipulation in Section 3.3 of our preprint. We then maintain four counts (sketched in code after this list):
- Number of patches originally predicted as class $c$ that are now not $c$.
- Number of patches originally predicted as class $c$ that are still $c$.
- Number of patches originally predicted as not class $c$ that are now $c$.
- Number of patches originally predicted as not class $c$ that are still not $c$.
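Here is a minimal sketch of that bookkeeping, assuming before and after are integer arrays of per-patch class predictions with the same shape; the function name is illustrative, not part of the module.

```python
import numpy as np


def change_counts(before: np.ndarray, after: np.ndarray, c: int):
    """Count patch transitions relative to class c between two predictions."""
    was_c = before == c
    is_c = after == c
    c_to_other = int(np.sum(was_c & ~is_c))       # was c, now not c
    c_to_c = int(np.sum(was_c & is_c))            # was c, still c
    other_to_c = int(np.sum(~was_c & is_c))       # was not c, now c
    other_to_other = int(np.sum(~was_c & ~is_c))  # was not c, still not c
    return c_to_other, c_to_c, other_to_c, other_to_other
```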
With this, we calculate two percentages:
- Target change rate: (number of original $c$ patches that changed class) / (total number of original $c$ patches) $\times$ 100.
- Other change rate: (number of original not-$c$ patches that changed class) / (total number of original not-$c$ patches) $\times$ 100.
Ideally, we maximize target change rate and minimize other change rate. We measure mean target change rate across all classes and mean other change rate across all classes.
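Continuing the sketch above, the two rates follow directly from the counts; the max(..., 1) guard is only there to avoid dividing by zero when a class has no patches in an image.

```python
c_to_other, c_to_c, other_to_c, other_to_other = change_counts(before, after, c)

# "Changed class" is interpreted here as crossing the c / not-c boundary,
# which is what the four counts distinguish.
target_change_rate = 100 * c_to_other / max(c_to_other + c_to_c, 1)
other_change_rate = 100 * other_to_c / max(other_to_c + other_to_other, 1)
```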
uv run python -m contrib.semseg quantify \
--sae-ckpt checkpoints/public/oebd6e6i/sae.pt \
--seg-ckpt checkpoints/contrib/semseg/lr_0_001__wd_0_1/ \
--imgs.root /$NFS/$USER/datasets/ade20k/
The main entry point is contrib/semseg/__main__.py. Run uv run python -m contrib.semseg --help to see all options.
Sub-modules
- contrib.semseg.config: Configs for all the different subscripts in contrib.semseg …
- contrib.semseg.interactive
- contrib.semseg.quantitative
- contrib.semseg.training: Trains multiple linear probes in parallel on DINOv2's ADE20K activations.
- contrib.semseg.validation: Checks which checkpoints have the best validation loss, mean IoU, class-specific IoU, validation accuracy, and qualitative results …
- contrib.semseg.visuals: Propose features for manual verification.