Module contrib.semseg.config
Configs for all the different subscripts in `contrib.semseg`.

Imports must be fast in this file, as described in `saev.config`, so do not import torch, numpy, etc.
Functions
def grid(cfg: Train,
sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]
Classes
class Quantitative (sae_ckpt: str = './checkpoints/sae.pt',
seg_ckpt: str = './checkpoints/contrib/semseg/best.pt',
top_values: str = './data/sort_by_patch/top_values.pt',
sparsity: str = './data/sort_by_patch/sparsity.pt',
act_mean: str = './data/contrib/semseg/dinov2_imagenet1k_mean.pt',
act_norm: float = 2.0181241035461426,
label_threshold: float = 0.9,
vit_family: Literal['clip', 'siglip', 'dinov2', 'moondream2'] = 'dinov2',
vit_ckpt: str = 'dinov2_vitb14_reg',
vit_layer: int = 11,
patch_size_px: tuple[int, int] = (14, 14),
n_patches_per_img: int = 256,
cls_token: bool = True,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
scale: float = -2.0,
top_k: int = 3,
device: str = 'cuda',
dump_to: str = './logs/contrib/semseg/quantitative',
seed: int = 42)-
Quantitative(sae_ckpt: str = './checkpoints/sae.pt', seg_ckpt: str = './checkpoints/contrib/semseg/best.pt', top_values: str = './data/sort_by_patch/top_values.pt', sparsity: str = './data/sort_by_patch/sparsity.pt', act_mean: str = './data/contrib/semseg/dinov2_imagenet1k_mean.pt', act_norm: float = 2.0181241035461426, label_threshold: float = 0.9, vit_family: Literal['clip', 'siglip', 'dinov2', 'moondream2'] = 'dinov2', vit_ckpt: str = 'dinov2_vitb14_reg', vit_layer: int = 11, patch_size_px: tuple[int, int] = (14, 14), n_patches_per_img: int = 256, cls_token: bool = True, imgs: saev.config.Ade20kDataset = <factory>,
batch_size: int = 128, n_workers: int = 32, scale: float = -2.0, top_k: int = 3, device: str = 'cuda', dump_to: str = './logs/contrib/semseg/quantitative', seed: int = 42) Expand source code
@beartype.beartype @dataclasses.dataclass(frozen=True) class Quantitative: sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt") """Path to trained SAE checkpoint.""" seg_ckpt: str = os.path.join(".", "checkpoints", "contrib", "semseg", "best.pt") """Path to trained segmentation head.""" top_values: str = os.path.join(".", "data", "sort_by_patch", "top_values.pt") """Path to top_values.pt file generated by `saev visuals`.""" sparsity: str = os.path.join(".", "data", "sort_by_patch", "sparsity.pt") """Path to sparsity.pt file generated by `saev visuals`.""" max_freq = 3e-2 """Maximum frequency. Any feature that fires more than this is ignored.""" act_mean: str = os.path.join( ".", "data", "contrib", "semseg", "dinov2_imagenet1k_mean.pt" ) """Where to load activation mean from.""" act_norm: float = 2.0181241035461426 """How much to scale activations such that average dataset norm is approximately sqrt(d_vit).""" label_threshold: float = 0.9 """Proportion of pixels that must have the same label to consider a given patch when calculating F1.""" vit_family: typing.Literal["clip", "siglip", "dinov2", "moondream2"] = "dinov2" """Which ViT family.""" vit_ckpt: str = "dinov2_vitb14_reg" """Specific ViT checkpoint.""" vit_layer: int = 11 """Vit layer to read/modify.""" patch_size_px: tuple[int, int] = (14, 14) """ViT patch size.""" n_patches_per_img: int = 256 """Number of ViT patches per image (depends on model).""" cls_token: bool = True """Whether the model has a [CLS] token.""" imgs: saev.config.Ade20kDataset = dataclasses.field( default_factory=lambda: saev.config.Ade20kDataset(split="validation") ) """Data configuration for ADE20K dataset.""" batch_size: int = 128 """Batch size for inference.""" n_workers: int = 32 """Number of dataloader workers.""" scale: float = -2.0 """Intervention scale. 
Likely needs to be larger for random-vector.""" top_k: int = 3 """Number of latents to show.""" device: str = "cuda" """Hardware for inference.""" dump_to: str = os.path.join(".", "logs", "contrib", "semseg", "quantitative") """Directory to save results to.""" seed: int = 42 """Random seed."""
Class variables
var act_mean : str
-
Where to load activation mean from.
var act_norm : float
-
How much to scale activations so that the average dataset norm is approximately sqrt(d_vit).
var batch_size : int
-
Batch size for inference.
var cls_token : bool
-
Whether the model has a [CLS] token.
var device : str
-
Hardware for inference.
var dump_to : str
-
Directory to save results to.
var imgs : Ade20kDataset
-
Data configuration for ADE20K dataset.
var label_threshold : float
-
Proportion of a patch's pixels that must share the same label for that patch to be considered when calculating F1.
var max_freq
-
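Maximum firing frequency: any feature that fires more often than this is ignored. Note: `max_freq = 3e-2` is declared without a type annotation, so it is a plain class attribute rather than a dataclass field, and it does not appear in the `Quantitative(...)` constructor signature.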
var n_patches_per_img : int
-
Number of ViT patches per image (depends on model).
var n_workers : int
-
Number of dataloader workers.
var patch_size_px : tuple[int, int]
-
ViT patch size.
var sae_ckpt : str
-
Path to trained SAE checkpoint.
var scale : float
-
Intervention scale. Likely needs to be larger for random-vector.
var seed : int
-
Random seed.
var seg_ckpt : str
-
Path to trained segmentation head.
var sparsity : str
-
Path to sparsity.pt file generated by `saev visuals`.
var top_k : int
-
Number of latents to show.
var top_values : str
-
Path to top_values.pt file generated by `saev visuals`.
var vit_ckpt : str
-
Specific ViT checkpoint.
var vit_family : Literal['clip', 'siglip', 'dinov2', 'moondream2']
-
Which ViT family.
var vit_layer : int
-
ViT layer to read/modify.
class Train (learning_rate: float = 0.0001,
weight_decay: float = 0.001,
n_epochs: int = 400,
batch_size: int = 1024,
n_workers: int = 32,
imgs: Ade20kDataset = <factory>,
eval_every: int = 100,
device: str = 'cuda',
ckpt_path: str = './checkpoints/contrib/semseg',
seed: int = 42,
log_to: str = './logs/contrib/semseg')-
Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, imgs: saev.config.Ade20kDataset = <factory>,
eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/semseg', seed: int = 42, log_to: str = './logs/contrib/semseg') Expand source code
@beartype.beartype @dataclasses.dataclass(frozen=True) class Train: learning_rate: float = 1e-4 """Linear layer learning rate.""" weight_decay: float = 1e-3 """Weight decay for AdamW.""" n_epochs: int = 400 """Number of training epochs for linear layer.""" batch_size: int = 1024 """Training batch size for linear layer.""" n_workers: int = 32 """Number of dataloader workers.""" imgs: saev.config.Ade20kDataset = dataclasses.field( default_factory=saev.config.Ade20kDataset ) """Configuration for the ADE20K dataset.""" eval_every: int = 100 """How many epochs between evaluations.""" device: str = "cuda" "Hardware to train on." ckpt_path: str = os.path.join(".", "checkpoints", "contrib", "semseg") seed: int = 42 """Random seed.""" log_to: str = os.path.join(".", "logs", "contrib", "semseg")
Class variables
var batch_size : int
-
Training batch size for linear layer.
var ckpt_path : str
var device : str
-
Hardware to train on.
var eval_every : int
-
How many epochs between evaluations.
var imgs : Ade20kDataset
-
Configuration for the ADE20K dataset.
var learning_rate : float
-
Linear layer learning rate.
var log_to : str
var n_epochs : int
-
Number of training epochs for linear layer.
var n_workers : int
-
Number of dataloader workers.
var seed : int
-
Random seed.
var weight_decay : float
-
Weight decay for AdamW.
class Validation (ckpt_root: str = './checkpoints/contrib/semseg',
dump_to: str = './logs/contrib/semseg',
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
device: str = 'cuda')-
Validation(ckpt_root: str = './checkpoints/contrib/semseg', dump_to: str = './logs/contrib/semseg', imgs: saev.config.Ade20kDataset = <factory>,
batch_size: int = 128, n_workers: int = 32, device: str = 'cuda') Expand source code
@beartype.beartype @dataclasses.dataclass(frozen=True) class Validation: ckpt_root: str = os.path.join(".", "checkpoints", "contrib", "semseg") """Root to all checkpoints to evaluate.""" dump_to: str = os.path.join(".", "logs", "contrib", "semseg") """Directory to dump results to.""" imgs: saev.config.Ade20kDataset = dataclasses.field( default_factory=lambda: saev.config.Ade20kDataset(split="validation") ) """Configuration for the ADE20K validation dataset.""" batch_size: int = 128 """Batch size for calculating F1 scores.""" n_workers: int = 32 """Number of dataloader workers.""" device: str = "cuda" "Hardware for linear probe inference."
Class variables
var batch_size : int
-
Batch size for calculating F1 scores.
var ckpt_root : str
-
Root to all checkpoints to evaluate.
var device : str
-
Hardware for linear probe inference.
var dump_to : str
-
Directory to dump results to.
var imgs : Ade20kDataset
-
Configuration for the ADE20K validation dataset.
var n_workers : int
-
Number of dataloader workers.
class Visuals (sae_ckpt: str = './checkpoints/sae.pt',
ade20k_cls: int = 29,
k: int = 32,
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
label_threshold: float = 0.9,
device: str = 'cuda')-
Visuals(sae_ckpt: str = './checkpoints/sae.pt', ade20k_cls: int = 29, k: int = 32, acts: saev.config.DataLoad = <factory>,
imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda') Expand source code
@beartype.beartype @dataclasses.dataclass(frozen=True) class Visuals: sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt") """Path to the sae.pt file.""" ade20k_cls: int = 29 """ADE20K class to probe for.""" k: int = 32 """Top K features to save.""" acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad) """Configuration for the saved ADE20K training ViT activations.""" imgs: saev.config.Ade20kDataset = dataclasses.field( default_factory=lambda: saev.config.Ade20kDataset(split="training") ) """Configuration for the ADE20K training dataset.""" batch_size: int = 128 """Batch size for calculating F1 scores.""" n_workers: int = 32 """Number of dataloader workers.""" label_threshold: float = 0.9 device: str = "cuda" "Hardware for SAE inference."
Class variables
var acts : DataLoad
-
Configuration for the saved ADE20K training ViT activations.
var ade20k_cls : int
-
ADE20K class to probe for.
var batch_size : int
-
Batch size for calculating F1 scores.
var device : str
-
Hardware for SAE inference.
var imgs : Ade20kDataset
-
Configuration for the ADE20K training dataset.
var k : int
-
Top K features to save.
var label_threshold : float
var n_workers : int
-
Number of dataloader workers.
var sae_ckpt : str
-
Path to the sae.pt file.