Module saev.app.modeling

Functions

def get_model_lookup() -> dict[str, Config]

Returns a dictionary of Config objects indexed by their lookup key.
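The returned dictionary is keyed by strings. The sketch below assumes those strings are the Config.key values and shows one way such a lookup could be built and queried. Entry, build_lookup, get_entry, and every key and checkpoint string are hypothetical stand-ins, not part of saev.app.modeling. Entry keeps only a few of Config's fields, so the sketch runs without the real config.DataLoad type.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Entry:
    """Hypothetical stand-in for Config, keeping only a few of its fields."""

    key: str
    vit_family: str
    vit_ckpt: str


def build_lookup(entries: list[Entry]) -> dict[str, Entry]:
    """Index entries by their `key` field, rejecting duplicate keys."""
    lookup: dict[str, Entry] = {}
    for entry in entries:
        if entry.key in lookup:
            raise ValueError(f"duplicate key: {entry.key!r}")
        lookup[entry.key] = entry
    return lookup


def get_entry(lookup: dict[str, Entry], key: str) -> Entry:
    """Fetch one entry, listing the valid keys if `key` is unknown."""
    try:
        return lookup[key]
    except KeyError:
        raise KeyError(f"unknown model {key!r}; choose from {sorted(lookup)}") from None


# Hypothetical keys and checkpoint strings, for illustration only.
lookup = build_lookup([
    Entry(key="clip-demo", vit_family="clip", vit_ckpt="example/clip-checkpoint"),
    Entry(key="dinov2-demo", vit_family="dinov2", vit_ckpt="example/dinov2-checkpoint"),
])
print(get_entry(lookup, "dinov2-demo").vit_family)  # dinov2
```

Keying the dictionary by the same string stored in each entry means a caller (for example, a web app receiving a model name) can go from a user-facing identifier to a full model configuration in one lookup.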

Classes

class Config (key: str,
vit_family: str,
vit_ckpt: str,
sae_ckpt: str,
tensor_dpath: pathlib.Path,
dataset_name: str,
acts_cfg: DataLoad)

Configuration for a Vision Transformer (ViT) and Sparse Autoencoder (SAE) model pair.

Stores paths and configuration needed to load and run a specific ViT+SAE combination.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Config:
    """Configuration for a Vision Transformer (ViT) and Sparse Autoencoder (SAE) model pair.

    Stores paths and configuration needed to load and run a specific ViT+SAE combination.
    """

    key: str
    """The lookup key."""

    vit_family: str
    """The ViT model family, e.g. 'clip' for CLIP models."""

    vit_ckpt: str
    """Checkpoint identifier for the ViT model, either a HuggingFace path or a model/checkpoint pair."""

    sae_ckpt: str
    """Identifier for the SAE checkpoint to load."""

    tensor_dpath: pathlib.Path
    """Directory containing precomputed tensors for this model combination."""

    dataset_name: str
    """Which dataset to use."""

    acts_cfg: config.DataLoad
    """Which activations to load for normalizing."""

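    @property
    def wrapped_cfg(self) -> config.Activations:
        """Wrap this model's ViT settings in a config.Activations."""
        # 196 patches per image by default; the "dinov2" family uses 256.
        n_patches = 196
        if self.vit_family == "dinov2":
            n_patches = 256

        return config.Activations(
            vit_family=self.vit_family,
            vit_ckpt=self.vit_ckpt,
            vit_layers=[-2],  # second-to-last layer
            n_patches_per_img=n_patches,
        )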
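Config is declared with `@dataclasses.dataclass(frozen=True)`. The standard-library sketch below shows what that buys: fields cannot be reassigned after construction, variants are derived with `dataclasses.replace`, and equal instances hash equally. PairConfig is a hypothetical stand-in with a subset of Config's fields and without the `@beartype.beartype` decorator. The real Config only hashes successfully if its `acts_cfg` value is itself hashable.

```python
import dataclasses
import pathlib


@dataclasses.dataclass(frozen=True)
class PairConfig:
    """Hypothetical stand-in mirroring a subset of Config's fields."""

    key: str
    vit_family: str
    tensor_dpath: pathlib.Path


cfg = PairConfig(key="demo", vit_family="clip", tensor_dpath=pathlib.Path("/tmp/demo"))

# frozen=True makes attribute assignment raise instead of mutating the object.
try:
    cfg.vit_family = "dinov2"  # type: ignore[misc]
except dataclasses.FrozenInstanceError:
    print("cannot mutate a frozen config")

# To vary a field, derive a new instance; the original is left unchanged.
dino_cfg = dataclasses.replace(cfg, key="demo-dinov2", vit_family="dinov2")
print(cfg.vit_family, dino_cfg.vit_family)  # clip dinov2

# frozen=True with the default eq=True also generates __hash__, so equal
# configs collapse to one element in a set (as long as every field is hashable).
print(len({cfg, dataclasses.replace(cfg)}))  # 1
```

Immutability suits a lookup table of model pairs: a Config fetched by key cannot be modified by one caller in a way that silently affects another.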

Class variables

var acts_cfg : DataLoad

Which activations to load for normalizing.

var dataset_name : str

Which dataset to use.

var key : str

The lookup key.

var sae_ckpt : str

Identifier for the SAE checkpoint to load.

var tensor_dpath : pathlib.Path

Directory containing precomputed tensors for this model combination.

var vit_ckpt : str

Checkpoint identifier for the ViT model, either a HuggingFace path or a model/checkpoint pair.

var vit_family : str

The ViT model family, e.g. 'clip' for CLIP models.

Instance variables

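prop wrapped_cfg : Activations

Builds a config.Activations from vit_family and vit_ckpt. It reads layer -2 (the second-to-last layer) and uses 196 patches per image, or 256 when vit_family is "dinov2".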
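The patch-count rule in wrapped_cfg can be checked in isolation. ActivationsSketch and wrap are hypothetical stand-ins for config.Activations and the property, covering only the four fields the property sets. The grid_patches helper is also an assumption: it shows that 196 and 256 are consistent with square 224×224 inputs split into 16-pixel patches (14 × 14) and 14-pixel patches (16 × 16), but the source does not state the input resolution or patch sizes.

```python
import dataclasses


@dataclasses.dataclass
class ActivationsSketch:
    """Hypothetical stand-in for config.Activations (only the fields wrapped_cfg sets)."""

    vit_family: str
    vit_ckpt: str
    vit_layers: list[int]
    n_patches_per_img: int


def wrap(vit_family: str, vit_ckpt: str) -> ActivationsSketch:
    """Mirror of Config.wrapped_cfg: 'dinov2' gets 256 patches, everything else 196."""
    n_patches = 256 if vit_family == "dinov2" else 196
    return ActivationsSketch(
        vit_family=vit_family,
        vit_ckpt=vit_ckpt,
        vit_layers=[-2],  # second-to-last layer, as in wrapped_cfg
        n_patches_per_img=n_patches,
    )


def grid_patches(image_size: int, patch_size: int) -> int:
    """Hypothetical helper: patches in a square image tiled by square patches."""
    return (image_size // patch_size) ** 2


print(wrap("clip", "example/clip-checkpoint").n_patches_per_img)  # 196
print(wrap("dinov2", "example/dinov2-checkpoint").n_patches_per_img)  # 256
print(grid_patches(224, 16), grid_patches(224, 14))  # 196 256
```

The family check is an exact string comparison, so a value such as "DINOv2" falls back to 196 patches.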