Module contrib.semseg.config

Configs for all the different subscripts in contrib.semseg.

Imports in this file must be fast, as described in saev.config, so do not import torch, numpy, etc.

Functions

def grid(cfg: Train,
sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]

Classes

class Quantitative (sae_ckpt: str = './checkpoints/sae.pt',
seg_ckpt: str = './checkpoints/contrib/semseg/best.pt',
top_values: str = './data/sort_by_patch/top_values.pt',
sparsity: str = './data/sort_by_patch/sparsity.pt',
act_mean: str = './data/contrib/semseg/dinov2_imagenet1k_mean.pt',
act_norm: float = 2.0181241035461426,
label_threshold: float = 0.9,
vit_family: Literal['clip', 'siglip', 'dinov2', 'moondream2'] = 'dinov2',
vit_ckpt: str = 'dinov2_vitb14_reg',
vit_layer: int = 11,
patch_size_px: tuple[int, int] = (14, 14),
n_patches_per_img: int = 256,
cls_token: bool = True,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
scale: float = -2.0,
top_k: int = 3,
device: str = 'cuda',
dump_to: str = './logs/contrib/semseg/quantitative',
seed: int = 42)

Quantitative(sae_ckpt: str = './checkpoints/sae.pt', seg_ckpt: str = './checkpoints/contrib/semseg/best.pt', top_values: str = './data/sort_by_patch/top_values.pt', sparsity: str = './data/sort_by_patch/sparsity.pt', act_mean: str = './data/contrib/semseg/dinov2_imagenet1k_mean.pt', act_norm: float = 2.0181241035461426, label_threshold: float = 0.9, vit_family: Literal['clip', 'siglip', 'dinov2', 'moondream2'] = 'dinov2', vit_ckpt: str = 'dinov2_vitb14_reg', vit_layer: int = 11, patch_size_px: tuple[int, int] = (14, 14), n_patches_per_img: int = 256, cls_token: bool = True, imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, scale: float = -2.0, top_k: int = 3, device: str = 'cuda', dump_to: str = './logs/contrib/semseg/quantitative', seed: int = 42)

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Quantitative:
    """Configuration for the quantitative SAE-intervention evaluation on ADE20K semantic segmentation."""

    sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
    """Path to trained SAE checkpoint."""
    seg_ckpt: str = os.path.join(".", "checkpoints", "contrib", "semseg", "best.pt")
    """Path to trained segmentation head."""
    top_values: str = os.path.join(".", "data", "sort_by_patch", "top_values.pt")
    """Path to top_values.pt file generated by `saev visuals`."""
    sparsity: str = os.path.join(".", "data", "sort_by_patch", "sparsity.pt")
    """Path to sparsity.pt file generated by `saev visuals`."""
    act_mean: str = os.path.join(
        ".", "data", "contrib", "semseg", "dinov2_imagenet1k_mean.pt"
    )
    """Where to load activation mean from."""
    act_norm: float = 2.0181241035461426
    """How much to scale activations such that average dataset norm is approximately sqrt(d_vit)."""

    label_threshold: float = 0.9
    """Proportion of pixels that must have the same label to consider a given patch when calculating F1."""

    vit_family: typing.Literal["clip", "siglip", "dinov2", "moondream2"] = "dinov2"
    """Which ViT family."""
    vit_ckpt: str = "dinov2_vitb14_reg"
    """Specific ViT checkpoint."""
    vit_layer: int = 11
    """ViT layer to read/modify."""
    patch_size_px: tuple[int, int] = (14, 14)
    """ViT patch size."""

    n_patches_per_img: int = 256
    """Number of ViT patches per image (depends on model)."""
    cls_token: bool = True
    """Whether the model has a [CLS] token."""

    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="validation")
    )
    """Data configuration for ADE20K dataset."""

    batch_size: int = 128
    """Batch size for inference."""
    n_workers: int = 32
    """Number of dataloader workers."""

    scale: float = -2.0
    """Intervention scale. Likely needs to be larger for random-vector."""
    top_k: int = 3
    """Number of latents to show."""

    device: str = "cuda"
    """Hardware for inference."""
    dump_to: str = os.path.join(".", "logs", "contrib", "semseg", "quantitative")
    """Directory to save results to."""
    seed: int = 42
    """Random seed."""

    # `max_freq` used to be declared without a type annotation, which made it a plain
    # class attribute instead of a dataclass field: it could not be passed to the
    # constructor or changed with `dataclasses.replace`. It is now a real field with the
    # same default. It is declared last so the positional order of the fields above is
    # unchanged, and `Quantitative.max_freq` still evaluates to 3e-2 at class level.
    max_freq: float = 3e-2
    """Maximum frequency. Any feature that fires more than this is ignored."""

Class variables

var act_mean : str

Where to load activation mean from.

var act_norm : float

How much to scale activations such that average dataset norm is approximately sqrt(d_vit).

var batch_size : int

Batch size for inference.

var cls_token : bool

Whether the model has a [CLS] token.

var device : str

Hardware for inference.

var dump_to : str

Directory to save results to.

var imgs : Ade20kDataset

Data configuration for ADE20K dataset.

var label_threshold : float

Proportion of pixels that must have the same label to consider a given patch when calculating F1.

var max_freq

Maximum frequency. Any feature that fires more than this is ignored.

var n_patches_per_img : int

Number of ViT patches per image (depends on model).

var n_workers : int

Number of dataloader workers.

var patch_size_px : tuple[int, int]

ViT patch size.

var sae_ckpt : str

Path to trained SAE checkpoint.

var scale : float

Intervention scale. Likely needs to be larger for random-vector.

var seed : int

Random seed.

var seg_ckpt : str

Path to trained segmentation head.

var sparsity : str

Path to sparsity.pt file generated by saev visuals.

var top_k : int

Number of latents to show.

var top_values : str

Path to top_values.pt file generated by saev visuals.

var vit_ckpt : str

Specific ViT checkpoint.

var vit_family : Literal['clip', 'siglip', 'dinov2', 'moondream2']

Which ViT family.

var vit_layer : int

ViT layer to read/modify.

class Train (learning_rate: float = 0.0001,
weight_decay: float = 0.001,
n_epochs: int = 400,
batch_size: int = 1024,
n_workers: int = 32,
imgs: Ade20kDataset = <factory>,
eval_every: int = 100,
device: str = 'cuda',
ckpt_path: str = './checkpoints/contrib/semseg',
seed: int = 42,
log_to: str = './logs/contrib/semseg')

Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, imgs: saev.config.Ade20kDataset = <factory>, eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/semseg', seed: int = 42, log_to: str = './logs/contrib/semseg')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Train:
    """Settings for fitting the linear segmentation layer on ADE20K."""

    # --- optimization ---
    learning_rate: float = 0.0001
    """Learning rate of the linear layer."""
    weight_decay: float = 0.001
    """AdamW weight decay."""
    n_epochs: int = 400
    """How many epochs to train the linear layer for."""
    batch_size: int = 1024
    """Batch size used while training the linear layer."""

    # --- data loading ---
    n_workers: int = 32
    """How many dataloader workers to use."""
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=saev.config.Ade20kDataset,
    )
    """ADE20K dataset configuration."""

    # --- bookkeeping ---
    eval_every: int = 100
    """Number of epochs between two evaluations."""
    device: str = "cuda"
    """Device to train on."""
    ckpt_path: str = os.path.join(".", "checkpoints", "contrib", "semseg")
    """Checkpoint directory (same default as `Validation.ckpt_root`)."""
    seed: int = 42
    """Random seed."""
    log_to: str = os.path.join(".", "logs", "contrib", "semseg")
    """Log directory."""

Class variables

var batch_size : int

Training batch size for linear layer.

var ckpt_path : str
var device : str

Hardware to train on.

var eval_every : int

How many epochs between evaluations.

var imgs : Ade20kDataset

Configuration for the ADE20K dataset.

var learning_rate : float

Linear layer learning rate.

var log_to : str
var n_epochs : int

Number of training epochs for linear layer.

var n_workers : int

Number of dataloader workers.

var seed : int

Random seed.

var weight_decay : float

Weight decay for AdamW.

class Validation (ckpt_root: str = './checkpoints/contrib/semseg',
dump_to: str = './logs/contrib/semseg',
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
device: str = 'cuda')

Validation(ckpt_root: str = './checkpoints/contrib/semseg', dump_to: str = './logs/contrib/semseg', imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Validation:
    """Settings for scoring saved linear-probe checkpoints on the ADE20K validation split."""

    # --- locations ---
    ckpt_root: str = os.path.join(".", "checkpoints", "contrib", "semseg")
    """Directory containing every checkpoint to evaluate."""
    dump_to: str = os.path.join(".", "logs", "contrib", "semseg")
    """Directory where results are written."""

    # --- data ---
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="validation"),
    )
    """ADE20K validation dataset configuration."""
    batch_size: int = 128
    """Batch size used when computing F1 scores."""
    n_workers: int = 32
    """How many dataloader workers to use."""

    # --- hardware ---
    device: str = "cuda"
    """Device used for linear-probe inference."""

Class variables

var batch_size : int

Batch size for calculating F1 scores.

var ckpt_root : str

Root to all checkpoints to evaluate.

var device : str

Hardware for linear probe inference.

var dump_to : str

Directory to dump results to.

var imgs : Ade20kDataset

Configuration for the ADE20K validation dataset.

var n_workers : int

Number of dataloader workers.

class Visuals (sae_ckpt: str = './checkpoints/sae.pt',
ade20k_cls: int = 29,
k: int = 32,
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
label_threshold: float = 0.9,
device: str = 'cuda')

Visuals(sae_ckpt: str = './checkpoints/sae.pt', ade20k_cls: int = 29, k: int = 32, acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Visuals:
    """Settings for ranking SAE features against one ADE20K class."""

    # --- what to probe ---
    sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
    """Location of the sae.pt checkpoint."""
    ade20k_cls: int = 29
    """ADE20K class index to probe for."""
    k: int = 32
    """How many top features to save."""

    # --- data ---
    acts: saev.config.DataLoad = dataclasses.field(
        default_factory=saev.config.DataLoad,
    )
    """Where the saved ViT activations for ADE20K training images come from."""
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="training"),
    )
    """ADE20K training dataset configuration."""
    batch_size: int = 128
    """Batch size used when computing F1 scores."""
    n_workers: int = 32
    """How many dataloader workers to use."""

    # NOTE(review): assumed to mean the same as `Quantitative.label_threshold` -- confirm
    # against the script that consumes this config.
    label_threshold: float = 0.9
    """Fraction of a patch's pixels that must share one label for the patch to be considered."""

    # --- hardware ---
    device: str = "cuda"
    """Device used for SAE inference."""

Class variables

var acts : DataLoad

Configuration for the saved ADE20K training ViT activations.

var ade20k_cls : int

ADE20K class to probe for.

var batch_size : int

Batch size for calculating F1 scores.

var device : str

Hardware for SAE inference.

var imgs : Ade20kDataset

Configuration for the ADE20K training dataset.

var k : int

Top K features to save.

var label_threshold : float
var n_workers : int

Number of dataloader workers.

var sae_ckpt : str

Path to the sae.pt file.