Module saev.config

All configs for all saev jobs.

Import Times

This module should be very fast to import so that python main.py --help runs quickly. This means the top-level imports must not include heavy packages like numpy or torch. For example, TreeOfLife.n_imgs imports numpy only when it's needed, rather than at the top level.
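
For illustration, a sketch of the deferred-import pattern (the real instances are the n_imgs properties below):

@property
def n_imgs(self) -> int:
    # Deferred import: the heavy package only loads when the property is
    # accessed, so `import saev.config` (and --help) stays fast.
    import datasets

    dataset = datasets.load_dataset(self.name, split=self.split)
    return len(dataset)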

This module also contains code for expanding configs with lists into lists of configs (grid search). It might be expanded in the future to support pseudo-random sampling from distributions for random hyperparameter search, as in this file.

Functions

def expand(config: dict[str, object]) -> Iterator[dict[str, object]]

Expands dicts with (nested) lists into a list of (nested) dicts.
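
For example, a dict with two list-valued fields expands into their Cartesian product (a sketch of the intended behavior; the exact ordering may differ):

sweep = {"lr": [1e-4, 3e-4], "sae": {"exp_factor": [8, 16]}}
for dct in expand(sweep):
    print(dct)
# {'lr': 0.0001, 'sae': {'exp_factor': 8}}
# {'lr': 0.0001, 'sae': {'exp_factor': 16}}
# {'lr': 0.0003, 'sae': {'exp_factor': 8}}
# {'lr': 0.0003, 'sae': {'exp_factor': 16}}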

def grid(cfg: Train, sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]
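
A usage sketch; the second return value is assumed to collect error messages for sweep entries that fail validation:

base = Train(track=False)
cfgs, errs = grid(base, {"lr": [1e-4, 3e-4], "sae": {"exp_factor": [8, 16]}})
# cfgs: one Train config per valid combination; errs: messages for any that failed.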

Classes

class Activations (data: ImagenetDataset | ImageFolderDataset | Ade20kDataset = <factory>,
dump_to: str = './shards',
model_family: Literal['clip', 'siglip', 'dinov2', 'moondream2'] = 'clip',
model_ckpt: str = 'ViT-L-14/openai',
vit_batch_size: int = 1024,
n_workers: int = 8,
d_vit: int = 1024,
layers: list[int] = <factory>,
n_patches_per_img: int = 256,
cls_token: bool = True,
n_patches_per_shard: int = 2400000,
seed: int = 42,
ssl: bool = True,
device: str = 'cuda',
slurm: bool = False,
slurm_acct: str = 'PAS2136',
log_to: str = './logs')

Configuration for calculating and saving ViT activations.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Activations:
    """
    Configuration for calculating and saving ViT activations.
    """

    data: DatasetConfig = dataclasses.field(default_factory=ImagenetDataset)
    """Which dataset to use."""
    dump_to: str = os.path.join(".", "shards")
    """Where to write shards."""
    model_family: typing.Literal["clip", "siglip", "dinov2", "moondream2"] = "clip"
    """Which model family."""
    model_ckpt: str = "ViT-L-14/openai"
    """Specific model checkpoint."""
    vit_batch_size: int = 1024
    """Batch size for ViT inference."""
    n_workers: int = 8
    """Number of dataloader workers."""
    d_vit: int = 1024
    """Dimension of the ViT activations (depends on model)."""
    layers: list[int] = dataclasses.field(default_factory=lambda: [-2])
    """Which layers to save. By default, the second-to-last layer."""
    n_patches_per_img: int = 256
    """Number of ViT patches per image (depends on model)."""
    cls_token: bool = True
    """Whether the model has a [CLS] token."""
    n_patches_per_shard: int = 2_400_000
    """Number of activations per shard; 2.4M is approximately 10GB for 1024-dimensional 4-byte activations."""

    seed: int = 42
    """Random seed."""
    ssl: bool = True
    """Whether to use SSL."""

    # Hardware
    device: str = "cuda"
    """Which device to use."""
    slurm: bool = False
    """Whether to use `submitit` to run jobs on a Slurm cluster."""
    slurm_acct: str = "PAS2136"
    """Slurm account string."""
    log_to: str = "./logs"
    """Where to log Slurm job stdout/stderr."""

Class variables

var cls_token : bool

Whether the model has a [CLS] token.

var d_vit : int

Dimension of the ViT activations (depends on model).

var data : ImagenetDataset | ImageFolderDataset | Ade20kDataset

Which dataset to use.

var device : str

Which device to use.

var dump_to : str

Where to write shards.

var layers : list[int]

Which layers to save. By default, the second-to-last layer.

var log_to : str

Where to log Slurm job stdout/stderr.

var model_ckpt : str

Specific model checkpoint.

var model_family : Literal['clip', 'siglip', 'dinov2', 'moondream2']

Which model family.

var n_patches_per_img : int

Number of ViT patches per image (depends on model).

var n_patches_per_shard : int

Number of activations per shard; 2.4M is approximately 10GB for 1024-dimensional 4-byte activations.

var n_workers : int

Number of dataloader workers.

var seed : int

Random seed.

var slurm : bool

Whether to use submitit to run jobs on a Slurm cluster.

var slurm_acct : str

Slurm account string.

var ssl : bool

Whether to use SSL.

var vit_batch_size : int

Batch size for ViT inference.

class Ade20kDataset (root: str = './data/ade20k',
split: Literal['training', 'validation'] = 'training')

Configuration for the ADE20K dataset.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Ade20kDataset:
    """ """

    root: str = os.path.join(".", "data", "ade20k")
    """Where the class folders with images are stored."""
    split: typing.Literal["training", "validation"] = "training"
    """Data split."""

    @property
    def n_imgs(self) -> int:
        if self.split == "validation":
            return 2000
        else:
            return 20210
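
Because the split sizes are hard-coded, n_imgs is cheap to access:

ds = Ade20kDataset(split="validation")
assert ds.n_imgs == 2000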

Class variables

var root : str

Where the class folders with images are stored.

var split : Literal['training', 'validation']

Data split.

Instance variables

prop n_imgs : int

Number of images in the split: 2,000 for validation, 20,210 for training.

class DataLoad (shard_root: str = './shards',
patches: Literal['cls', 'patches', 'meanpool'] = 'patches',
layer: int | Literal['all', 'meanpool'] = -2,
clamp: float = 100000.0,
n_random_samples: int = 524288,
scale_mean: bool | str = True,
scale_norm: bool | str = True)

Configuration for loading activation data from disk.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class DataLoad:
    """
    Configuration for loading activation data from disk.
    """

    shard_root: str = os.path.join(".", "shards")
    """Directory with .bin shards and a metadata.json file."""
    patches: typing.Literal["cls", "patches", "meanpool"] = "patches"
    """Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'patches' indicates it will return all patches. 'meanpool' returns the mean of all image patches."""
    layer: int | typing.Literal["all", "meanpool"] = -2
    """.. todo: document this field."""
    clamp: float = 1e5
    """Maximum value for activations; activations will be clamped to within [-clamp, clamp]`."""
    n_random_samples: int = 2**19
    """Number of random samples used to calculate approximate dataset means at startup."""
    scale_mean: bool | str = True
    """Whether to subtract approximate dataset means from examples. If a string, manually load from the filepath."""
    scale_norm: bool | str = True
    """Whether to scale average dataset norm to sqrt(d_vit). If a string, manually load from the filepath."""

Class variables

var clamp : float

Maximum value for activations; activations will be clamped to within [-clamp, clamp].

var layer : int | Literal['all', 'meanpool']

Which ViT layer(s) to load activations for. An int (e.g., -2) selects a single saved layer; 'all' returns every saved layer; 'meanpool' averages activations across the saved layers.

var n_random_samples : int

Number of random samples used to calculate approximate dataset means at startup.

var patches : Literal['cls', 'patches', 'meanpool']

Which kinds of patches to use: 'cls' returns just the [CLS] token (if any); 'patches' returns all image patches; 'meanpool' returns the mean over all image patches.

var scale_mean : bool | str

Whether to subtract approximate dataset means from examples. If a string, manually load from the filepath.

var scale_norm : bool | str

Whether to scale average dataset norm to sqrt(d_vit). If a string, manually load from the filepath.

var shard_root : str

Directory with .bin shards and a metadata.json file.

class ImageFolderDataset (root: str = './data/split')

Configuration for a generic image folder dataset.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class ImageFolderDataset:
    """Configuration for a generic image folder dataset."""

    root: str = os.path.join(".", "data", "split")
    """Where the class folders with images are stored."""

    @property
    def n_imgs(self) -> int:
        """Number of images in the dataset. Calculated on the fly, but is non-trivial to calculate because it requires walking the directory structure. If you need to reference this number very often, cache it in a local variable."""
        n = 0
        for _, _, files in os.walk(self.root):
            n += len(files)
        return n
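
The expected layout is one subfolder per class, ImageFolder-style; for example:

data/split/
    cat/
        001.jpg
        002.jpg
    dog/
        001.jpg

Note that n_imgs counts every file under root via os.walk, so keep non-image files out of the tree.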

Class variables

var root : str

Where the class folders with images are stored.

Instance variables

prop n_imgs : int

Number of images in the dataset. Calculated on the fly, but is non-trivial to calculate because it requires walking the directory structure. If you need to reference this number very often, cache it in a local variable.

class ImagenetDataset (name: str = 'ILSVRC/imagenet-1k', split: str = 'train')

Configuration for HuggingFace Imagenet.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class ImagenetDataset:
    """Configuration for HuggingFace Imagenet."""

    name: str = "ILSVRC/imagenet-1k"
    """Dataset name on HuggingFace. Don't need to change this.."""
    split: str = "train"
    """Dataset split. For the default ImageNet-1K dataset, can either be 'train', 'validation' or 'test'."""

    @property
    def n_imgs(self) -> int:
        """Number of images in the dataset. Calculated on the fly, but is non-trivial to calculate because it requires loading the dataset. If you need to reference this number very often, cache it in a local variable."""
        import datasets

        dataset = datasets.load_dataset(
            self.name, split=self.split, trust_remote_code=True
        )
        return len(dataset)
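
Accessing n_imgs loads the dataset through HuggingFace datasets, which can be slow; cache the value if you need it repeatedly:

ds = ImagenetDataset(split="validation")
n_imgs = ds.n_imgs  # triggers datasets.load_dataset(...); cache this value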

Class variables

var name : str

Dataset name on HuggingFace. You usually don't need to change this.

var split : str

Dataset split. For the default ImageNet-1K dataset, can either be 'train', 'validation' or 'test'.

Instance variables

prop n_imgs : int

Number of images in the dataset. Calculated on the fly, but is non-trivial to calculate because it requires loading the dataset. If you need to reference this number very often, cache it in a local variable.

class SparseAutoencoder (d_vit: int = 1024,
exp_factor: int = 16,
sparsity_coeff: float = 0.0004,
n_reinit_samples: int = 524288,
ghost_grads: bool = False,
remove_parallel_grads: bool = True,
normalize_w_dec: bool = True,
seed: int = 0)

Configuration for a sparse autoencoder (SAE).

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class SparseAutoencoder:
    """Configuration for a sparse autoencoder (SAE)."""

    d_vit: int = 1024
    """Dimension of the ViT activations (depends on model)."""
    exp_factor: int = 16
    """Expansion factor for SAE."""
    sparsity_coeff: float = 4e-4
    """How much to weight sparsity loss term."""
    n_reinit_samples: int = 1024 * 16 * 32
    """Number of samples to use for SAE re-init. Anthropic proposes initializing b_dec to the geometric median of the dataset here: https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-bias. We use the regular mean."""
    ghost_grads: bool = False
    """Whether to use ghost grads."""
    remove_parallel_grads: bool = True
    """Whether to remove gradients parallel to W_dec columns (which will be ignored because we force the columns to have unit norm). See https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-optimization for the original discussion from Anthropic."""
    normalize_w_dec: bool = True
    """Whether to make sure W_dec has unit norm columns. See https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder for original citation."""
    seed: int = 0
    """Random seed."""

    @property
    def d_sae(self) -> int:
        return self.d_vit * self.exp_factor
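
With the defaults, the latent dimension works out to 1024 * 16 = 16,384:

sae_cfg = SparseAutoencoder(d_vit=1024, exp_factor=16)
assert sae_cfg.d_sae == 16_384  # d_vit * exp_factor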

Class variables

var d_vit : int

Dimension of the ViT activations (depends on model).

var exp_factor : int

Expansion factor for SAE.

var ghost_grads : bool

Whether to use ghost grads.

var n_reinit_samples : int

Number of samples to use for SAE re-init. Anthropic proposes initializing b_dec to the geometric median of the dataset here: https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-bias. We use the regular mean.

var normalize_w_dec : bool

Whether to make sure W_dec has unit norm columns. See https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder for original citation.

var remove_parallel_grads : bool

Whether to remove gradients parallel to W_dec columns (which will be ignored because we force the columns to have unit norm). See https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-optimization for the original discussion from Anthropic.

var seed : int

Random seed.

var sparsity_coeff : float

How much to weight sparsity loss term.

Instance variables

prop d_sae : int

The SAE's latent dimension: d_vit * exp_factor.

class Train (data: DataLoad = <factory>,
n_workers: int = 32,
n_patches: int = 100000000,
sae: SparseAutoencoder = <factory>,
n_sparsity_warmup: int = 0,
lr: float = 0.0004,
n_lr_warmup: int = 500,
sae_batch_size: int = 16384,
track: bool = True,
wandb_project: str = 'saev',
tag: str = '',
log_every: int = 25,
ckpt_path: str = './checkpoints',
device: Literal['cuda', 'cpu'] = 'cuda',
seed: int = 42,
slurm: bool = False,
slurm_acct: str = 'PAS2136',
log_to: str = './logs')

Configuration for training a sparse autoencoder on a vision transformer.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Train:
    """
    Configuration for training a sparse autoencoder on a vision transformer.
    """

    data: DataLoad = dataclasses.field(default_factory=DataLoad)
    """Data configuration"""
    n_workers: int = 32
    """Number of dataloader workers."""
    n_patches: int = 100_000_000
    """Number of SAE training examples."""
    sae: SparseAutoencoder = dataclasses.field(default_factory=SparseAutoencoder)
    """SAE configuration."""
    n_sparsity_warmup: int = 0
    """Number of sparsity coefficient warmup steps."""
    lr: float = 0.0004
    """Learning rate."""
    n_lr_warmup: int = 500
    """Number of learning rate warmup steps."""
    sae_batch_size: int = 1024 * 16
    """Batch size for SAE training."""

    # Logging
    track: bool = True
    """Whether to track with WandB."""
    wandb_project: str = "saev"
    """WandB project name."""
    tag: str = ""
    """Tag to add to WandB run."""
    log_every: int = 25
    """How often to log to WandB."""
    ckpt_path: str = os.path.join(".", "checkpoints")
    """Where to save checkpoints."""

    device: typing.Literal["cuda", "cpu"] = "cuda"
    """Hardware device."""
    seed: int = 42
    """Random seed."""
    slurm: bool = False
    """Whether to use `submitit` to run jobs on a Slurm cluster."""
    slurm_acct: str = "PAS2136"
    """Slurm account string."""
    log_to: str = os.path.join(".", "logs")
    """Where to log Slurm job stdout/stderr."""

Class variables

var ckpt_path : str

Where to save checkpoints.

var data : DataLoad

Data configuration

var device : Literal['cuda', 'cpu']

Hardware device.

var log_every : int

How often to log to WandB.

var log_to : str

Where to log Slurm job stdout/stderr.

var lr : float

Learning rate.

var n_lr_warmup : int

Number of learning rate warmup steps.

var n_patches : int

Number of SAE training examples.

var n_sparsity_warmup : int

Number of sparsity coefficient warmup steps.

var n_workers : int

Number of dataloader workers.

var sae : SparseAutoencoder

SAE configuration.

var sae_batch_size : int

Batch size for SAE training.

var seed : int

Random seed.

var slurm : bool

Whether to use submitit to run jobs on a Slurm cluster.

var slurm_acct : str

Slurm account string.

var tag : str

Tag to add to WandB run.

var track : bool

Whether to track with WandB.

var wandb_project : str

WandB project name.

class Visuals (ckpt: str = './checkpoints/sae.pt',
data: DataLoad = <factory>,
images: ImagenetDataset | ImageFolderDataset | Ade20kDataset = <factory>,
top_k: int = 128,
n_workers: int = 16,
topk_batch_size: int = 16384,
sae_batch_size: int = 16384,
epsilon: float = 1e-09,
sort_by: Literal['cls', 'img', 'patch'] = 'patch',
device: str = 'cuda',
dump_to: str = './data',
log_freq_range: tuple[float, float] = (-6.0, -2.0),
log_value_range: tuple[float, float] = (-1.0, 1.0),
include_latents: list[int] = <factory>,
n_distributions: int = 25,
percentile: int = 99,
n_latents: int = 400,
seed: int = 42)

Configuration for generating visuals from trained SAEs.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Visuals:
    """Configuration for generating visuals from trained SAEs."""

    ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
    """Path to the sae.pt file."""
    data: DataLoad = dataclasses.field(default_factory=DataLoad)
    """Data configuration."""
    images: DatasetConfig = dataclasses.field(default_factory=ImagenetDataset)
    """Which images to use."""
    top_k: int = 128
    """How many images per SAE feature to store."""
    n_workers: int = 16
    """Number of dataloader workers."""
    topk_batch_size: int = 1024 * 16
    """Number of examples to apply top-k op to."""
    sae_batch_size: int = 1024 * 16
    """Batch size for SAE inference."""
    epsilon: float = 1e-9
    """Value to add to avoid log(0)."""
    sort_by: typing.Literal["cls", "img", "patch"] = "patch"
    """How to find the top k images. 'cls' picks images where the SAE latents of the ViT's [CLS] token are maximized without any patch highligting. 'img' picks images that maximize the sum of an SAE latent over all patches in the image, highlighting the patches. 'patch' pickes images that maximize an SAE latent over all patches (not summed), highlighting the patches and only showing unique images."""
    device: str = "cuda"
    """Which accelerator to use."""
    dump_to: str = os.path.join(".", "data")
    """Where to save data."""
    log_freq_range: tuple[float, float] = (-6.0, -2.0)
    """Log10 frequency range for which to save images."""
    log_value_range: tuple[float, float] = (-1.0, 1.0)
    """Log10 frequency range for which to save images."""
    include_latents: list[int] = dataclasses.field(default_factory=list)
    """Latents to always include, no matter what."""
    n_distributions: int = 25
    """Number of features to save distributions for."""
    percentile: int = 99
    """Percentile to estimate for outlier detection."""
    n_latents: int = 400
    """Maximum number of latents to save images for."""
    seed: int = 42
    """Random seed."""

    @property
    def root(self) -> str:
        return os.path.join(self.dump_to, f"sort_by_{self.sort_by}")

    @property
    def top_values_fpath(self) -> str:
        return os.path.join(self.root, "top_values.pt")

    @property
    def top_img_i_fpath(self) -> str:
        return os.path.join(self.root, "top_img_i.pt")

    @property
    def top_patch_i_fpath(self) -> str:
        return os.path.join(self.root, "top_patch_i.pt")

    @property
    def mean_values_fpath(self) -> str:
        return os.path.join(self.root, "mean_values.pt")

    @property
    def sparsity_fpath(self) -> str:
        return os.path.join(self.root, "sparsity.pt")

    @property
    def distributions_fpath(self) -> str:
        return os.path.join(self.root, "distributions.pt")

    @property
    def percentiles_fpath(self) -> str:
        return os.path.join(self.root, f"percentiles_p{self.percentile}.pt")

Class variables

var ckpt : str

Path to the sae.pt file.

var data : DataLoad

Data configuration.

var device : str

Which accelerator to use.

var dump_to : str

Where to save data.

var epsilon : float

Value to add to avoid log(0).

var images : ImagenetDataset | ImageFolderDataset | Ade20kDataset

Which images to use.

var include_latents : list[int]

Latents to always include, no matter what.

var log_freq_range : tuple[float, float]

Log10 frequency range for which to save images.

var log_value_range : tuple[float, float]

Log10 value range for which to save images.

var n_distributions : int

Number of features to save distributions for.

var n_latents : int

Maximum number of latents to save images for.

var n_workers : int

Number of dataloader workers.

var percentile : int

Percentile to estimate for outlier detection.

var sae_batch_size : int

Batch size for SAE inference.

var seed : int

Random seed.

var sort_by : Literal['cls', 'img', 'patch']

How to find the top k images. 'cls' picks images where the SAE latents of the ViT's [CLS] token are maximized, without any patch highlighting. 'img' picks images that maximize the sum of an SAE latent over all patches in the image, highlighting the patches. 'patch' picks images that maximize an SAE latent over all patches (not summed), highlighting the patches and only showing unique images.

var top_k : int

How many images per SAE feature to store.

var topk_batch_size : int

Number of examples to apply top-k op to.

Instance variables

prop distributions_fpath : str

Path to distributions.pt under root.

prop mean_values_fpath : str

Path to mean_values.pt under root.

prop percentiles_fpath : str

Path to percentiles_p{percentile}.pt under root.

prop root : str

Output directory: dump_to/sort_by_{sort_by}.

prop sparsity_fpath : str

Path to sparsity.pt under root.

prop top_img_i_fpath : str

Path to top_img_i.pt under root.

prop top_patch_i_fpath : str

Path to top_patch_i.pt under root.

prop top_values_fpath : str

Path to top_values.pt under root.