Module saev.activations

To save large numbers of activations, we run many Slurm jobs in parallel and write multiple files (shards) rather than a single file.

This module handles that additional complexity.

Conceptually, activations can be thought of as either

  1. A single [n_imgs, n_layers, (n_patches + 1), d_vit] tensor. This is a dataset.
  2. Multiple [n_imgs_per_shard, n_layers, (n_patches + 1), d_vit] tensors. This is a set of sharded activations.
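
The relationship between the two views can be sketched with plain integer arithmetic. This is only an illustration: the function name `shard_layout` and the example sizes are made up, and the per-shard image count follows the `n_patches_per_shard // n_layers // (n_patches + 1)` pattern used elsewhere in this module.

```python
import math


def shard_layout(n_imgs, n_layers, n_patches, d_vit, n_patches_per_shard):
    """Return (n_shards, shard_shape) for a hypothetical activation dump."""
    # Every image contributes n_patches patch tokens plus one [CLS] token.
    tokens_per_img = n_patches + 1
    # How many whole images fit in one shard, given the per-shard token budget.
    n_imgs_per_shard = n_patches_per_shard // n_layers // tokens_per_img
    # The last shard may be only partially filled.
    n_shards = math.ceil(n_imgs / n_imgs_per_shard)
    shard_shape = (n_imgs_per_shard, n_layers, tokens_per_img, d_vit)
    return n_shards, shard_shape


n_shards, shard_shape = shard_layout(
    n_imgs=1000, n_layers=2, n_patches=196, d_vit=768, n_patches_per_shard=40_000
)
print(n_shards, shard_shape)
```

With these made-up sizes, each shard holds 101 images, so 1000 images need 10 shards.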


def get_acts_dir(cfg: Activations) ‑> str

Return the activations directory based on the relevant values of a config. Also saves a metadata.json file to that directory for human reference.


Args:
    cfg: Config for experiment.

Returns:
    Directory where activations are dumped to and loaded from.

def get_dataloader(cfg: Activations,

Gets the dataloader for the current experiment; delegates dataloader construction to dataset-specific functions.


Args:
    cfg: Experiment config.
    img_transform: Image transform to be applied to each image.

Returns:
    A PyTorch Dataloader that yields dictionaries with 'image' keys containing image batches.

def get_dataset(cfg: ImagenetDataset | ImageFolderDataset | Ade20kDataset,

Gets the dataset for the current experiment; delegates construction to dataset-specific functions.


Args:
    cfg: Experiment config.
    img_transform: Image transform to be applied to each image.

Returns:
    A dataset that has dictionaries with 'image', 'index', 'target', and 'label' keys containing examples.

def get_default_dataloader(cfg: Activations,
img_transform: ) ‑>

Get a dataloader for a default map-style dataset.


Args:
    img_transform: Image transform to be applied to each image.

Returns:
    A PyTorch Dataloader that yields dictionaries with 'image' keys containing image batches, 'index' keys containing original dataset indices, and 'label' keys containing label batches.

def main(cfg: Activations)


Args:
    cfg: Config for activations.

def make_img_transform(vit_family: str, vit_ckpt: str) ‑> 
def make_vit(vit_family: str, vit_ckpt: str)
def setup(cfg: Activations)

Run dataset-specific setup. These setup functions can assume they are the only job running, but they should be idempotent; they should be safe (and ideally cheap) to run multiple times in a row.
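
A minimal sketch of what "idempotent and cheap to re-run" can look like in practice. The function `setup_example` and the `.setup-done` marker file are hypothetical and are not part of this module; the real setup functions may use a different mechanism.

```python
import os
import tempfile


def setup_example(root: str) -> bool:
    """Prepare `root` once. Returns True if work was done, False if it was skipped."""
    marker = os.path.join(root, ".setup-done")
    if os.path.exists(marker):
        # A previous run already finished, so a repeated call is cheap.
        return False
    os.makedirs(root, exist_ok=True)
    # ... expensive, dataset-specific preparation would go here ...
    with open(marker, "w") as fd:
        fd.write("done\n")
    return True


with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, "dataset")
    first = setup_example(root)
    second = setup_example(root)
    print(first, second)
```

Writing the marker only after the work completes means an interrupted run is retried rather than silently skipped.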

def setup_ade20k(cfg: Activations)
def setup_imagefolder(cfg: Activations)
def setup_imagenet(cfg: Activations)
def worker_fn(cfg: Activations)


Args:
    cfg: Config for activations.


class Ade20k (cfg: Ade20kDataset,
img_transform: | None = None,
seg_transform: | None = <function Ade20k.<lambda>>)

An abstract class representing a :class:Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:__getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:__len__, which is expected to return the size of the dataset by many :class:~torch.utils.data.Sampler implementations and the default options of :class:~torch.utils.data.DataLoader. Subclasses could also optionally implement :meth:__getitems__, to speed up batched sample loading. This method accepts a list of indices of the samples in a batch and returns a list of samples.

Note: :class:~torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

class Ade20k(torch.utils.data.Dataset):
    @dataclasses.dataclass
    class Sample:
        img_path: str
        seg_path: str
        label: str
        target: int

    samples: list[Sample]

    def __init__(
        self,
        cfg: config.Ade20kDataset,
        img_transform: Callable | None = None,
        seg_transform: Callable | None = lambda x: None,
    ):
        self.logger = logging.getLogger("ade20k")
        self.cfg = cfg
        self.img_dir = os.path.join(cfg.root, "images")
        self.seg_dir = os.path.join(cfg.root, "annotations")
        self.img_transform = img_transform
        self.seg_transform = seg_transform

        # Check that we have the right path.
        for subdir in ("images", "annotations"):
            if not os.path.isdir(os.path.join(cfg.root, subdir)):
                # Something is missing.
                if os.path.realpath(cfg.root).endswith(subdir):
                    self.logger.warning(
                        "The ADE20K root should contain 'images/' and 'annotations/' directories."
                    )
                raise ValueError(f"Can't find path '{os.path.join(cfg.root, subdir)}'.")

        _, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir)
        split_lookup: dict[int, str] = {
            value: key for key, value in split_mapping.items()
        }
        self.loader = torchvision.datasets.folder.default_loader

        assert cfg.split in set(split_lookup.values())

        # Load all the image paths.
        # The make_dataset() arguments are elided in the source; the directory and extensions below are assumed.
        imgs: list[str] = [
            path
            for path, s in torchvision.datasets.folder.make_dataset(
                self.img_dir,
                extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
            )
            if split_lookup[s] == cfg.split
        ]

        segs: list[str] = [
            path
            for path, s in torchvision.datasets.folder.make_dataset(
                self.seg_dir,
                extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
            )
            if split_lookup[s] == cfg.split
        ]

        # Load all the targets, classes and mappings
        with open(os.path.join(cfg.root, "sceneCategories.txt")) as fd:
            img_labels: list[str] = [line.split()[1] for line in fd.readlines()]

        label_set = sorted(set(img_labels))
        label_to_idx = {label: i for i, label in enumerate(label_set)}

        self.samples = [
            self.Sample(img_path, seg_path, label, label_to_idx[label])
            for img_path, seg_path, label in zip(imgs, segs, img_labels)
        ]
    def __getitem__(self, index: int) -> dict[str, object]:
        # Convert to dict.
        sample = dataclasses.asdict(self.samples[index])

        sample["image"] = self.loader(sample.pop("img_path"))
        if self.img_transform is not None:
            image = self.img_transform(sample.pop("image"))
            if image is not None:
                sample["image"] = image

        sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L")
        if self.seg_transform is not None:
            segmentation = self.seg_transform(sample.pop("segmentation"))
            if segmentation is not None:
                sample["segmentation"] = segmentation

        sample["index"] = index

        return sample

    def __len__(self) -> int:
        return len(self.samples)
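
The label handling in `__init__` can be illustrated in isolation. The sketch below uses three made-up lines in the two-column layout that the parsing code assumes (file stem, then scene label); the file names and labels are illustrative only.

```python
# Made-up stand-ins for lines of sceneCategories.txt.
lines = [
    "ADE_train_00000001 airport_terminal",
    "ADE_train_00000002 bedroom",
    "ADE_train_00000003 airport_terminal",
]

# The second whitespace-separated field is the scene label.
img_labels = [line.split()[1] for line in lines]

# Sorting makes the label -> integer mapping deterministic across runs.
label_set = sorted(set(img_labels))
label_to_idx = {label: i for i, label in enumerate(label_set)}

targets = [label_to_idx[label] for label in img_labels]
print(label_to_idx, targets)
```

Because the mapping is built from the sorted set of labels, two runs over the same file assign the same integer target to each label.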


Class variables

var Sample
var samples : list[Ade20k.Sample]
class Clip (vit_ckpt: str)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.


As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

Initialize internal Module state, shared by both nn.Module and ScriptModule.

class Clip(torch.nn.Module):
    def __init__(self, vit_ckpt: str):
        super().__init__()

        import open_clip

        if vit_ckpt.startswith("hf-hub:"):
            clip, _ = open_clip.create_model_from_pretrained(
                vit_ckpt, cache_dir=helpers.get_cache_dir()
            )
        else:
            arch, ckpt = vit_ckpt.split("/")
            clip, _ = open_clip.create_model_from_pretrained(
                arch, pretrained=ckpt, cache_dir=helpers.get_cache_dir()
            )

        model = clip.visual
        model.proj = None
        model.output_tokens = True  # type: ignore
        self.model = model.eval()

        assert not isinstance(self.model, open_clip.timm_model.TimmModel)

        # The assignment target is elided in the source; `self.name` is assumed.
        self.name = f"clip/{vit_ckpt}"

    def get_residuals(self) -> list[torch.nn.Module]:
        return self.model.transformer.resblocks

    def get_patches(self, cfg: config.Activations) -> slice:
        return slice(None, None, None)

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        cls, patches = self.model(batch)
        return {"cls": cls, "patches": patches}


def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch patches dim']

Define the computation performed at every call.

Should be overridden by all subclasses.


Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def get_patches(self,
cfg: Activations) ‑> slice
def get_residuals(self) ‑> list[torch.nn.modules.module.Module]
class Dataset (cfg: DataLoad)

Dataset of activations from disk.

class Dataset(torch.utils.data.Dataset):
    """
    Dataset of activations from disk.
    """

    class Example(typing.TypedDict):
        """Individual example."""

        act: Float[Tensor, " d_vit"]
        image_i: int
        patch_i: int

    cfg: config.DataLoad
    """Configuration; set via CLI args."""
    metadata: "Metadata"
    """Activations metadata; automatically loaded from disk."""
    layer_index: int
    """Layer index into the shards if we are choosing a specific layer."""
    scalar: float
    """Normalizing scalar such that ||x / scalar ||_2 ~= sqrt(d_vit)."""
    act_mean: Float[Tensor, " d_vit"]
    """Mean activation."""

    def __init__(self, cfg: config.DataLoad):
        self.cfg = cfg
        if not os.path.isdir(self.cfg.shard_root):
            raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.")

        metadata_fpath = os.path.join(self.cfg.shard_root, "metadata.json")
        self.metadata = Metadata.load(metadata_fpath)

        # Pick a really big number so that if you accidentally use this when you shouldn't, you get an out of bounds IndexError.
        self.layer_index = 1_000_000
        if isinstance(self.cfg.layer, int):
            err_msg = f"Non-exact matches for .layer field not supported; {self.cfg.layer} not in {self.metadata.layers}."
            assert self.cfg.layer in self.metadata.layers, err_msg
            self.layer_index = self.metadata.layers.index(self.cfg.layer)

        # Premptively set these values so that preprocess() doesn't freak out.
        self.scalar = 1.0
        self.act_mean = torch.zeros(self.d_vit)

        # If either of these are true, we must do this work.
        if self.cfg.scale_mean is True or self.cfg.scale_norm is True:
            # Load a random subset of samples to calculate the mean activation and mean L2 norm.
            perm = np.random.default_rng(seed=42).permutation(len(self))
            perm = perm[: cfg.n_random_samples]

            samples = [
                self[int(p)]
                for p in helpers.progress(
                    perm, every=25_000, desc="examples to calc means"
                )
            ]
            samples = torch.stack([sample["act"] for sample in samples])
            if samples.abs().max() > 1e3:
                raise ValueError(
                    f"You found an abnormally large activation {samples.abs().max().item():.5f} that will mess up your L2 mean."
                )

            # Activation mean
            if self.cfg.scale_mean:
                self.act_mean = samples.mean(axis=0)
                if (self.act_mean > 1e3).any():
                    raise ValueError(
                        "You found an abnormally large activation that is messing up your activation mean."
                    )

            # Norm
            if self.cfg.scale_norm:
                l2_mean = torch.linalg.norm(samples - self.act_mean, axis=1).mean()
                if l2_mean > 1e3:
                    raise ValueError(
                        "You found an abnormally large activation that is messing up your L2 mean."
                    )

                self.scalar = l2_mean / math.sqrt(self.d_vit)
        elif isinstance(self.cfg.scale_mean, str):
            # Load mean activations from disk
            self.act_mean = torch.load(self.cfg.scale_mean)
        elif isinstance(self.cfg.scale_norm, str):
            # Load scalar normalization from disk
            self.scalar = torch.load(self.cfg.scale_norm).item()

    def transform(self, act: Float[np.ndarray, " d_vit"]) -> Float[Tensor, " d_vit"]:
        """
        Apply a scalar normalization so the mean squared L2 norm is the same as d_vit. This is from 'Scaling Monosemanticity':

        > As a preprocessing step we apply a scalar normalization to the model activations so their average squared L2 norm is the residual stream dimension

        So we divide by self.scalar, which is the dataset's (approximate) L2 mean before normalization divided by sqrt(d_vit).
        """
        act = torch.from_numpy(act.copy())
        act = act.clamp(-self.cfg.clamp, self.cfg.clamp)
        return (act - self.act_mean) / self.scalar

    @property
    def d_vit(self) -> int:
        """Dimension of the underlying vision transformer's embedding space."""
        return self.metadata.d_vit

    def __getitem__(self, i: int) -> Example:
        match (self.cfg.patches, self.cfg.layer):
            case ("cls", int()):
                img_act = self.get_img_patches(i)
                # Select layer's cls token.
                act = img_act[self.layer_index, 0, :]
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("cls", "meanpool"):
                img_act = self.get_img_patches(i)
                # Select cls tokens from across all layers
                cls_act = img_act[:, 0, :]
                # Meanpool over the layers
                act = cls_act.mean(axis=0)
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("meanpool", int()):
                img_act = self.get_img_patches(i)
                # Select layer's patches.
                layer_act = img_act[self.layer_index, 1:, :]
                # Meanpool over the patches
                act = layer_act.mean(axis=0)
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("meanpool", "meanpool"):
                img_act = self.get_img_patches(i)
                # Select all layer's patches.
                act = img_act[:, 1:, :]
                # Meanpool over the layers and patches
                act = act.mean(axis=(0, 1))
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("patches", int()):
                n_imgs_per_shard = (
                    self.metadata.n_patches_per_shard
                    // len(self.metadata.layers)
                    // (self.metadata.n_patches_per_img + 1)
                )
                n_examples_per_shard = (
                    n_imgs_per_shard * self.metadata.n_patches_per_img
                )

                shard = i // n_examples_per_shard
                pos = i % n_examples_per_shard

                acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin")
                shape = (
                    n_imgs_per_shard,
                    len(self.metadata.layers),
                    self.metadata.n_patches_per_img + 1,
                    self.metadata.d_vit,
                )
                acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape)
                # Choose the layer and the non-CLS tokens.
                acts = acts[:, self.layer_index, 1:]

                # Choose an image within the shard and a patch within that image.
                act = acts[
                    pos // self.metadata.n_patches_per_img,
                    pos % self.metadata.n_patches_per_img,
                ]
                return self.Example(
                    act=self.transform(act),
                    # What image is this?
                    image_i=i // self.metadata.n_patches_per_img,
                    patch_i=i % self.metadata.n_patches_per_img,
                )
            case _:
                print((self.cfg.patches, self.cfg.layer))
                typing.assert_never((self.cfg.patches, self.cfg.layer))

    def get_shard_patches(self):
        raise NotImplementedError()

    def get_img_patches(
        self, i: int
    ) -> Float[np.ndarray, "n_layers all_patches d_vit"]:
        n_imgs_per_shard = (
            self.metadata.n_patches_per_shard
            // len(self.metadata.layers)
            // (self.metadata.n_patches_per_img + 1)
        )
        shard = i // n_imgs_per_shard
        pos = i % n_imgs_per_shard
        acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin")
        shape = (
            n_imgs_per_shard,
            len(self.metadata.layers),
            self.metadata.n_patches_per_img + 1,
            self.metadata.d_vit,
        )
        acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape)
        # Note that this is not yet copied!
        return acts[pos]

    def __len__(self) -> int:
        """
        Dataset length depends on `patches` and `layer`.
        """
        match (self.cfg.patches, self.cfg.layer):
            case ("cls", "all"):
                # Return a CLS token from a random image and random layer.
                return self.metadata.n_imgs * len(self.metadata.layers)
            case ("cls", int()):
                # Return a CLS token from a random image and fixed layer.
                return self.metadata.n_imgs
            case ("cls", "meanpool"):
                # Return a CLS token from a random image and meanpool over all layers.
                return self.metadata.n_imgs
            case ("meanpool", "all"):
                # Return the meanpool of all patches from a random image and random layer.
                return self.metadata.n_imgs * len(self.metadata.layers)
            case ("meanpool", int()):
                # Return the meanpool of all patches from a random image and fixed layer.
                return self.metadata.n_imgs
            case ("meanpool", "meanpool"):
                # Return the meanpool of all patches from a random image and meanpool over all layers.
                return self.metadata.n_imgs
            case ("patches", int()):
                # Return a patch from a random image, fixed layer, and random patch.
                return self.metadata.n_imgs * (self.metadata.n_patches_per_img)
            case ("patches", "meanpool"):
                # Return a patch from a random image, meanpooled over all layers, and a random patch.
                return self.metadata.n_imgs * (self.metadata.n_patches_per_img)
            case ("patches", "all"):
                # Return a patch from a random image, random layer and random patch.
                return (
                    self.metadata.n_imgs
                    * len(self.metadata.layers)
                    * self.metadata.n_patches_per_img
                )
            case _:
                typing.assert_never((self.cfg.patches, self.cfg.layer))
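
The index arithmetic in the `("patches", int())` case of `__getitem__` can be checked on its own. The helper `locate_patch` below is hypothetical (it is not part of this module) and the sizes are made up; it only mirrors the integer divisions used above.

```python
def locate_patch(i, n_patches_per_shard, n_layers, n_patches_per_img):
    """Map a flat example index to (shard, image within shard, patch within image)."""
    # Images per shard: each image stores n_patches_per_img patches plus one [CLS] token per layer.
    n_imgs_per_shard = n_patches_per_shard // n_layers // (n_patches_per_img + 1)
    # Examples exclude the [CLS] token, so each image yields n_patches_per_img examples.
    n_examples_per_shard = n_imgs_per_shard * n_patches_per_img
    shard, pos = divmod(i, n_examples_per_shard)
    img_in_shard, patch = divmod(pos, n_patches_per_img)
    return shard, img_in_shard, patch


shard, img_in_shard, patch = locate_patch(
    20_000, n_patches_per_shard=40_000, n_layers=2, n_patches_per_img=196
)
print(shard, img_in_shard, patch)
```

With 101 images per shard, the global image index `shard * 101 + img_in_shard` equals `20_000 // 196`, which matches the `image_i` value returned by `__getitem__`.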


Class variables

var Example

Individual example.

var act_mean : jaxtyping.Float[Tensor, 'd_vit']

Mean activation.

var cfg : DataLoad

Configuration; set via CLI args.

var layer_index : int

Layer index into the shards if we are choosing a specific layer.

var metadata : Metadata

Activations metadata; automatically loaded from disk.

var scalar : float

Normalizing scalar such that ||x / scalar ||_2 ~= sqrt(d_vit).

Instance variables

prop d_vit : int

Dimension of the underlying vision transformer's embedding space.

@property
def d_vit(self) -> int:
    """Dimension of the underlying vision transformer's embedding space."""
    return self.metadata.d_vit


def get_img_patches(self, i: int) ‑> jaxtyping.Float[ndarray, 'n_layers all_patches d_vit']
def get_shard_patches(self)
def transform(self, act: jaxtyping.Float[ndarray, 'd_vit']) ‑> jaxtyping.Float[Tensor, 'd_vit']

Apply a scalar normalization so the mean squared L2 norm is the same as d_vit. This is from 'Scaling Monosemanticity':

> As a preprocessing step we apply a scalar normalization to the model activations so their average squared L2 norm is the residual stream dimension

So we divide by self.scalar, which is the dataset's (approximate) L2 mean before normalization divided by sqrt(d_vit).
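
A small pure-Python sketch of the same normalization, with no clamping and made-up data. The helper names `fit_normalizer` and `normalize` are hypothetical. Following the code in `__init__`, the scalar is the mean L2 norm of the centered samples divided by sqrt(d), so the mean L2 norm after normalization is sqrt(d).

```python
import math


def fit_normalizer(samples):
    """Return (mean vector, scalar) estimated from a list of equal-length vectors."""
    d = len(samples[0])
    mean = [sum(col) / len(samples) for col in zip(*samples)]
    # L2 norm of each centered sample.
    norms = [
        math.sqrt(sum((x - m) ** 2 for x, m in zip(sample, mean)))
        for sample in samples
    ]
    scalar = (sum(norms) / len(norms)) / math.sqrt(d)
    return mean, scalar


def normalize(act, mean, scalar):
    """Center and rescale one activation vector."""
    return [(x - m) / scalar for x, m in zip(act, mean)]


samples = [[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0], [5.0, 5.0, 5.0, 5.0]]
mean, scalar = fit_normalizer(samples)
normed = [normalize(s, mean, scalar) for s in samples]
mean_norm = sum(math.sqrt(sum(x * x for x in v)) for v in normed) / len(normed)
print(mean_norm, math.sqrt(len(samples[0])))
```

On the samples used to fit the scalar, the mean norm after normalization equals sqrt(d) up to floating-point error; on held-out activations it is only approximate.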

class DinoV2 (vit_ckpt: str)

class DinoV2(torch.nn.Module):
    def __init__(self, vit_ckpt: str):
        super().__init__()

        self.model = torch.hub.load("facebookresearch/dinov2", vit_ckpt)
        # The assignment target is elided in the source; `self.name` is assumed.
        self.name = f"dinov2/{vit_ckpt}"

    def get_residuals(self) -> list[torch.nn.Module]:
        return self.model.blocks

    def get_patches(self, n_patches_per_img: int) -> slice:
        n_reg = self.model.num_register_tokens
        patches = torch.cat((
            torch.tensor([0]),  # CLS token
            torch.arange(n_reg + 1, n_reg + 1 + n_patches_per_img),  # patches
        ))
        return patches

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        dct = self.model.forward_features(batch)

        features = torch.cat(
            (dct["x_norm_clstoken"][:, None, :], dct["x_norm_patchtokens"]), axis=1
        )
        return features
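
The token selection in `get_patches` skips the register tokens that sit between the [CLS] token and the patch tokens. A list-based sketch of the same index computation follows; the helper name `dinov2_token_indices` and the sizes are made up.

```python
def dinov2_token_indices(n_reg: int, n_patches_per_img: int) -> list[int]:
    """Indices of the [CLS] token and the patch tokens, skipping n_reg register tokens."""
    cls_index = [0]
    # Register tokens occupy positions 1..n_reg, so patches start at n_reg + 1.
    patch_indices = list(range(n_reg + 1, n_reg + 1 + n_patches_per_img))
    return cls_index + patch_indices


indices = dinov2_token_indices(n_reg=4, n_patches_per_img=3)
print(indices)
```

With no register tokens (`n_reg=0`) this reduces to the contiguous range `0..n_patches_per_img`.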


def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch patches dim']

def get_patches(self, n_patches_per_img: int) ‑> slice
def get_residuals(self) ‑> list[torch.nn.modules.module.Module]
class ImageFolder (root: str | pathlib.Path,
transform: Callable | None = None,
target_transform: Callable | None = None,
loader: Callable[[str], Any] = <function default_loader>,
is_valid_file: Callable[[str], bool] | None = None,
allow_empty: bool = False)

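A generic data loader where the images are arranged in this way by default: ::

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/[...]/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/[...]/asd932_.png
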

This class inherits from :class:~torchvision.datasets.DatasetFolder so the same methods can be overridden to customize the dataset.


Args:
    root (str or pathlib.Path): Root directory path.
    transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
    target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    loader (callable, optional): A function to load an image given its path.
    is_valid_file (callable, optional): A function that takes the path of an image file and checks whether it is a valid file (used to check for corrupt files).
    allow_empty (bool, optional): If True, empty folders are considered to be valid classes. An error is raised on empty folders if False (default).

Attributes:
    classes (list): List of the class names sorted alphabetically.
    class_to_idx (dict): Dict with items (class_name, class_index).
    imgs (list): List of (image path, class_index) tuples.

class ImageFolder(torchvision.datasets.ImageFolder):
    def __getitem__(self, index: int) -> dict[str, object]:
        """
        Args:
            index: Index

        Returns:
            dict with keys 'image', 'index', 'target' and 'label'.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return {
            "image": sample,
            "target": target,
            "label": self.classes[target],
            "index": index,
        }


class Imagenet (cfg: ImagenetDataset,

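class Imagenet(torch.utils.data.Dataset):
    def __init__(self, cfg: config.ImagenetDataset, *, img_transform=None):
        import datasets

        # The first argument is elided in the source; `cfg.name` is assumed.
        self.hf_dataset = datasets.load_dataset(
            cfg.name, split=cfg.split, trust_remote_code=True
        )

        self.img_transform = img_transform
        # The object being indexed is elided in the source; the dataset's `features` mapping is assumed.
        self.labels = self.hf_dataset.features["label"].names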

    def __getitem__(self, i):
        example = self.hf_dataset[i]
        example["index"] = i

        example["image"] = example["image"].convert("RGB")
        if self.img_transform:
            example["image"] = self.img_transform(example["image"])
        example["target"] = example.pop("label")
        example["label"] = self.labels[example["target"]]

        return example

    def __len__(self) -> int:
        return len(self.hf_dataset)


class Metadata (vit_family: str,
vit_ckpt: str,
layers: tuple[int, ...],
n_patches_per_img: int,
cls_token: bool,
d_vit: int,
seed: int,
n_imgs: int,
n_patches_per_shard: int,
data: str)

@dataclasses.dataclass
class Metadata:
    vit_family: str
    vit_ckpt: str
    layers: tuple[int, ...]
    n_patches_per_img: int
    cls_token: bool
    d_vit: int
    seed: int
    n_imgs: int
    n_patches_per_shard: int
    data: str

    @classmethod
    def from_cfg(cls, cfg: config.Activations) -> "Metadata":
        return cls(

    @classmethod
    def load(cls, fpath) -> "Metadata":
        with open(fpath) as fd:
            dct = json.load(fd)
        dct["layers"] = tuple(dct.pop("layers"))
        return cls(**dct)

    def dump(self, fpath):
        with open(fpath, "w") as fd:
            json.dump(dataclasses.asdict(self), fd, indent=4)

    @property
    def hash(self) -> str:
        cfg_str = json.dumps(dataclasses.asdict(self), sort_keys=True)
        return hashlib.sha256(cfg_str.encode("utf-8")).hexdigest()
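
The dump/load round trip and the content hash can be sketched with a reduced dataclass. `TinyMetadata` is a hypothetical two-field stand-in for `Metadata`, and the field values are made up; the point is that JSON turns the `layers` tuple into a list, which is why `load` converts it back before constructing the object.

```python
import dataclasses
import hashlib
import json


@dataclasses.dataclass(frozen=True)
class TinyMetadata:
    layers: tuple[int, ...]
    d_vit: int

    @property
    def hash(self) -> str:
        # sort_keys makes the serialized form, and therefore the hash, independent of field order.
        cfg_str = json.dumps(dataclasses.asdict(self), sort_keys=True)
        return hashlib.sha256(cfg_str.encode("utf-8")).hexdigest()


original = TinyMetadata(layers=(-2, -1), d_vit=768)

# Round trip through JSON, as dump() followed by load() would do via a file.
dct = json.loads(json.dumps(dataclasses.asdict(original)))
was_list = isinstance(dct["layers"], list)
dct["layers"] = tuple(dct.pop("layers"))
restored = TinyMetadata(**dct)
print(restored == original, restored.hash == original.hash)
```

Two objects with equal fields always produce the same hash, which makes it usable as a stable identifier for a set of activation settings.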

var cls_token : bool
var d_vit : int
var data : str
var layers : tuple[int, ...]
var n_imgs : int
var n_patches_per_img : int
var n_patches_per_shard : int
var seed : int
var vit_ckpt : str
var vit_family : str

Static methods

def from_cfg(cls,
cfg: Activations) ‑> Metadata
def load(cls, fpath) ‑> Metadata

Instance variables

prop hash : str
@property
def hash(self) -> str:
    cfg_str = json.dumps(dataclasses.asdict(self), sort_keys=True)
    return hashlib.sha256(cfg_str.encode("utf-8")).hexdigest()


def dump(self, fpath)
class Moondream2 (vit_ckpt: str)

Moondream2 has 14x14 pixel patches. For a 378x378 image (as we use here), this is 27x27 patches for a total of 729, with no [CLS] token.
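
The patch count quoted above follows from simple arithmetic, sketched here with a hypothetical helper `n_patches` that assumes square images and square patches.

```python
def n_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    assert image_size % patch_size == 0, "image size must be a multiple of patch size"
    per_side = image_size // patch_size
    return per_side * per_side


moondream_patches = n_patches(378, 14)
print(moondream_patches)
```

Here 378 / 14 = 27 patches per side, and 27 x 27 = 729 patches in total.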

Initialize internal Module state, shared by both nn.Module and ScriptModule.

class Moondream2(torch.nn.Module):
    """
    Moondream2 has 14x14 pixel patches. For a 378x378 image (as we use here), this is 27x27 patches for a total of 729, with no [CLS] token.
    """

    def __init__(self, vit_ckpt: str):
        super().__init__()

        import transformers

        vit_id, revision = vit_ckpt.split(":")

        mllm = transformers.AutoModelForCausalLM.from_pretrained(
            vit_id, revision=revision, trust_remote_code=True
        )
        self.model = mllm.vision_encoder.encoder.model.visual

    def get_patches(self, cfg: config.Activations) -> slice:
        return slice(None, None, None)

    def get_residuals(self) -> list[torch.nn.Module]:
        return self.model.blocks

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        features = self.model(batch)
        return features


def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch patches dim']

def get_patches(self,
cfg: Activations) ‑> slice
def get_residuals(self) ‑> list[torch.nn.modules.module.Module]
class RecordedVisionTransformer (vit: torch.nn.modules.module.Module,
n_patches_per_img: int,
cls_token: bool,
layers: list[int])

class RecordedVisionTransformer(torch.nn.Module):
    _storage: Float[Tensor, "batch n_layers all_patches dim"] | None
    _i: int

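    def __init__(
        self,
        vit: torch.nn.Module,
        n_patches_per_img: int,
        cls_token: bool,
        layers: list[int],
    ):
        super().__init__()

        self.vit = vit

        self.n_patches_per_img = n_patches_per_img
        self.cls_token = cls_token
        self.layers = layers

        self.patches = vit.get_patches(n_patches_per_img)

        self._storage = None
        self._i = 0

        # The interpolated attribute is elided in the source; `vit.name` is assumed.
        self.logger = logging.getLogger(f"recorder({vit.name})")

        for i in self.layers:
            # The loop body is elided in the source; registering hook() as a forward hook
            # on each requested residual block is assumed from the hook() signature.
            self.vit.get_residuals()[i].register_forward_hook(self.hook)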

    def hook(
        self, module, args: tuple, output: Float[Tensor, "batch n_layers dim"]
    ) -> None:
        if self._storage is None:
            batch, _, dim = output.shape
            self._storage = self._empty_storage(batch, dim, output.device)

        if self._storage[:, self._i, 0, :].shape != output[:, 0, :].shape:
            batch, _, dim = output.shape

            old_batch, _, _, old_dim = self._storage.shape
            msg = "Output shape does not match storage shape: (batch) %d != %d or (dim) %d != %d"
            self.logger.warning(msg, old_batch, batch, old_dim, dim)

            self._storage = self._empty_storage(batch, dim, output.device)

        self._storage[:, self._i] = output[:, self.patches, :].detach()
        self._i += 1

    def _empty_storage(self, batch: int, dim: int, device: torch.device):
        n_patches_per_img = self.n_patches_per_img
        if self.cls_token:
            n_patches_per_img += 1

        return torch.zeros(
            (batch, len(self.layers), n_patches_per_img, dim), device=device
        )

    def reset(self):
        self._i = 0

    @property
    def activations(self) -> Float[Tensor, "batch n_layers all_patches dim"]:
        if self._storage is None:
            raise RuntimeError("First call forward()")
        return self._storage.cpu()

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> tuple[Float[Tensor, "batch patches dim"], Float[Tensor, "..."]]:
        # Reset the layer counter so the hooks fill storage from the first layer on every call.
        self.reset()
        result = self.vit(batch)
        return result, self.activations


Instance variables

prop activations : jaxtyping.Float[Tensor, 'batch n_layers all_patches dim']
@property
def activations(self) -> Float[Tensor, "batch n_layers all_patches dim"]:
    if self._storage is None:
        raise RuntimeError("First call forward()")
    return self._storage.cpu()


def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> tuple[jaxtyping.Float[Tensor, 'batch patches dim'], jaxtyping.Float[Tensor, '...']]

Define the computation performed at every call.

Should be overridden by all subclasses.


Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

def hook(self, module, args: tuple, output: jaxtyping.Float[Tensor, 'batch n_layers dim']) ‑> None
def reset(self)
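
The recording mechanism can be illustrated with a minimal sketch, assuming a toy stack of linear layers in place of a real ViT. `TinyRecorder`, its layer sizes, and the choice of recorded layers are hypothetical. Only `register_forward_hook` and the `[batch, n_layers, tokens, dim]` storage layout mirror the class above.

```python
import torch


class TinyRecorder(torch.nn.Module):
    """Hypothetical stand-in for RecordedVisionTransformer on a toy model."""

    def __init__(self, blocks: torch.nn.ModuleList, layers: list[int]):
        super().__init__()
        self.blocks = blocks
        self.layers = layers
        self._storage = None
        self._i = 0
        for i in layers:
            # Forward hooks are called with (module, args, output).
            blocks[i].register_forward_hook(self.hook)

    def hook(self, module, args, output):
        batch, n_tokens, dim = output.shape
        if self._storage is None or self._storage.shape[0] != batch:
            # Allocate one slot per recorded layer.
            self._storage = torch.zeros(batch, len(self.layers), n_tokens, dim)
        self._storage[:, self._i] = output.detach()
        self._i += 1

    def forward(self, x):
        self._i = 0  # start writing at the first recorded layer again
        for block in self.blocks:
            x = block(x)  # calling the module (not .forward) triggers hooks
        return x, self._storage.cpu()


blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
recorder = TinyRecorder(blocks, layers=[1, 3])
with torch.no_grad():
    out, acts = recorder(torch.randn(2, 5, 8))
```

Here `acts` holds one slice per recorded layer. Because layer 3 is the last block in this toy model, `acts[:, 1]` matches the final output.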
class ShardWriter (cfg: Activations)

ShardWriter is a stateful object that handles sharded activation writing to disk.

class ShardWriter:
    """
    ShardWriter is a stateful object that handles sharded activation writing to disk.
    """

    root: str
    shape: tuple[int, int, int, int]
    shard: int
    acts_path: str
    acts: Float[np.ndarray, "n_imgs_per_shard n_layers all_patches d_vit"] | None
    filled: int

    def __init__(self, cfg: config.Activations):
        self.logger = logging.getLogger("shard-writer")

        self.root = get_acts_dir(cfg)

        n_patches_per_img = cfg.n_patches_per_img
        if cfg.cls_token:
            n_patches_per_img += 1
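        self.n_imgs_per_shard = (
            cfg.n_patches_per_shard // len(cfg.vit_layers) // n_patches_per_img
        )
        self.shape = (
            self.n_imgs_per_shard,
            len(cfg.vit_layers),
            n_patches_per_img,
            cfg.d_vit,  # assumed config field for the ViT hidden dimension
        )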

        self.shard = -1
        self.acts = None
        # Open the first shard so that __setitem__ has a memmap to write into.
        self.next_shard()

    def __setitem__(
        self, i: slice, val: Float[Tensor, "_ n_layers all_patches d_vit"]
    ) -> None:
        assert i.step is None
        a, b = i.start, i.stop
        assert len(val) == b - a

        offset = self.n_imgs_per_shard * self.shard

        if b >= offset + self.n_imgs_per_shard:
            # We have run out of space in this mmap'ed file. Let's fill it as much as we can.
            n_fit = offset + self.n_imgs_per_shard - a
            self.acts[a - offset : a - offset + n_fit] = val[:n_fit]
            self.filled = a - offset + n_fit

            self.next_shard()

            # Recursively call __setitem__ in case we need *another* shard
            self[a + n_fit : b] = val[n_fit:]
        else:
            msg = f"0 <= {a} - {offset} <= {offset} + {self.n_imgs_per_shard}"
            assert 0 <= a - offset <= offset + self.n_imgs_per_shard, msg
            msg = f"0 <= {b} - {offset} <= {offset} + {self.n_imgs_per_shard}"
            assert 0 <= b - offset <= offset + self.n_imgs_per_shard, msg
            self.acts[a - offset : b - offset] = val
            self.filled = b - offset

    def flush(self) -> None:
        if self.acts is not None:
            self.acts.flush()

        self.acts = None

    def next_shard(self) -> None:
        self.flush()

        self.shard += 1
        self._count = 0
        self.acts_path = os.path.join(self.root, f"acts{self.shard:06}.bin")
        self.acts = np.memmap(
            self.acts_path, mode="w+", dtype=np.float32, shape=self.shape
        )
        self.filled = 0

        self.logger.info("Opened shard '%s'.", self.acts_path)
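
The index arithmetic in `__setitem__` can be checked in isolation. The sketch below is not the module's code: `imgs_per_shard` and `split_across_shards` are hypothetical helpers that compute, for a global half-open image range `[a, b)`, which shard and which local rows each piece lands in, assuming a fixed number of images per shard.

```python
def imgs_per_shard(
    n_patches_per_shard: int, n_layers: int, n_patches_per_img: int, cls_token: bool
) -> int:
    # Same integer division as ShardWriter.__init__: patches -> images.
    tokens = n_patches_per_img + (1 if cls_token else 0)
    return n_patches_per_shard // n_layers // tokens


def split_across_shards(a: int, b: int, per_shard: int) -> list[tuple[int, int, int]]:
    """Return (shard, local_start, local_stop) pieces covering global rows [a, b)."""
    pieces = []
    while a < b:
        shard = a // per_shard
        offset = shard * per_shard  # global index of the shard's first row
        stop = min(b, offset + per_shard)
        pieces.append((shard, a - offset, stop - offset))
        a = stop
    return pieces


per_shard = imgs_per_shard(2_400_000, n_layers=2, n_patches_per_img=196, cls_token=True)
pieces = split_across_shards(8, 25, per_shard=10)
```

Unlike the writer, which opens the next shard as soon as a write reaches the shard boundary, this helper only reports where rows land.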

var acts : jaxtyping.Float[ndarray, 'n_imgs_per_shard n_layers all_patches d_vit'] | None
var acts_path : str
var filled : int
var root : str
var shape : tuple[int, int, int, int]
var shard : int


def flush(self) ‑> None
def next_shard(self) ‑> None
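
Because each shard is a raw `float32` memmap with no header, a reader has to know the shape from elsewhere. A minimal round-trip sketch, assuming a small made-up shape and a temporary directory. The real shape comes from the config, and the file name here only mirrors the `acts{shard:06}.bin` pattern above.

```python
import os
import tempfile

import numpy as np

# (n_imgs_per_shard, n_layers, all_patches, d_vit); values are made up.
shape = (4, 2, 3, 8)

with tempfile.TemporaryDirectory() as root:
    path = os.path.join(root, f"acts{0:06}.bin")

    # Write three of the four rows, as a partially filled final shard would be.
    writer = np.memmap(path, mode="w+", dtype=np.float32, shape=shape)
    writer[:3] = np.arange(3 * 2 * 3 * 8, dtype=np.float32).reshape(3, 2, 3, 8)
    writer.flush()
    del writer

    file_size = os.path.getsize(path)

    # Reading requires the same dtype and shape that were used for writing.
    reader = np.memmap(path, mode="r", dtype=np.float32, shape=shape)
    first_rows = np.array(reader[:3])
    last_row = np.array(reader[3])
    del reader
```

The file always has the full shard size on disk. Rows that were never written read back as zeros, which is why the writer tracks `filled` separately.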
class Siglip (vit_ckpt: str)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.


As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Instance variable training (bool): whether this module is in training or evaluation mode.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

class Siglip(torch.nn.Module):
    def __init__(self, vit_ckpt: str):
        super().__init__()

        import open_clip

        if vit_ckpt.startswith("hf-hub:"):
            clip, _ = open_clip.create_model_from_pretrained(
                vit_ckpt, cache_dir=helpers.get_cache_dir()
            )
        else:
            arch, ckpt = vit_ckpt.split("/")
            clip, _ = open_clip.create_model_from_pretrained(
                arch, pretrained=ckpt, cache_dir=helpers.get_cache_dir()
            )
        model = clip.visual
        model.proj = None
        model.output_tokens = True  # type: ignore
        self.model = model

        assert isinstance(self.model, open_clip.timm_model.TimmModel)

    def get_residuals(self) -> list[torch.nn.Module]:
        return self.model.trunk.blocks

    def get_patches(self, cfg: config.Activations) -> slice:
        return slice(None, None, None)

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        result = self.model(batch)
        return result


def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch patches dim']

Define the computation performed at every call.

Should be overridden by all subclasses.


Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

def get_patches(self, cfg: Activations) ‑> slice
def get_residuals(self) ‑> list[torch.nn.modules.module.Module]
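
The constructor accepts two checkpoint spellings. As a sketch of just the string handling (no model is loaded), the hypothetical helper below mirrors the branch in `Siglip.__init__`. `parse_vit_ckpt` is not part of `saev` or `open_clip`, and the checkpoint strings are illustrative only.

```python
from typing import Optional


def parse_vit_ckpt(vit_ckpt: str) -> tuple[str, Optional[str]]:
    """Return (model_name, pretrained_tag) the way Siglip.__init__ splits them."""
    if vit_ckpt.startswith("hf-hub:"):
        # Hub identifiers are passed through whole, with no separate tag.
        return vit_ckpt, None
    # Otherwise exactly one "/" must separate architecture and checkpoint tag.
    arch, ckpt = vit_ckpt.split("/")
    return arch, ckpt


hub = parse_vit_ckpt("hf-hub:timm/ViT-B-16-SigLIP")
local = parse_vit_ckpt("ViT-B-16-SigLIP/webli")

try:
    parse_vit_ckpt("too/many/parts")
    bad_input_rejected = False
except ValueError:
    bad_input_rejected = True
```

A string without the `hf-hub:` prefix that contains zero or several slashes fails at the tuple unpacking with a `ValueError`, so malformed checkpoint names surface before any download starts.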