Module contrib.semseg.config
Configs for all the different subscripts in contrib.semseg.
Imports must be fast in this file, as described in saev.config.
So do not import torch, numpy, etc.
Functions
def grid(cfg: Train,
sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]
Classes
class Train (learning_rate: float = 0.0001,
weight_decay: float = 0.001,
n_epochs: int = 400,
batch_size: int = 1024,
n_workers: int = 32,
imgs: Ade20kDataset = <factory>,
eval_every: int = 100,
device: str = 'cuda',
ckpt_path: str = './checkpoints/contrib/semseg',
seed: int = 42,
log_to: str = './logs/contrib/semseg')
Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, imgs: saev.config.Ade20kDataset = <factory>, eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/semseg', seed: int = 42, log_to: str = './logs/contrib/semseg')
Expand source code
@beartype.beartype @dataclasses.dataclass(frozen=True) class Train: learning_rate: float = 1e-4 """Linear layer learning rate.""" weight_decay: float = 1e-3 """Weight decay for AdamW.""" n_epochs: int = 400 """Number of training epochs for linear layer.""" batch_size: int = 1024 """Training batch size for linear layer.""" n_workers: int = 32 """Number of dataloader workers.""" imgs: saev.config.Ade20kDataset = dataclasses.field( default_factory=saev.config.Ade20kDataset ) """Configuration for the ADE20K dataset.""" eval_every: int = 100 """How many epochs between evaluations.""" device: str = "cuda" "Hardware to train on." ckpt_path: str = os.path.join(".", "checkpoints", "contrib", "semseg") seed: int = 42 """Random seed.""" log_to: str = os.path.join(".", "logs", "contrib", "semseg")
Class variables
var batch_size : int
-
Training batch size for linear layer.
var ckpt_path : str
var device : str
-
Hardware to train on.
var eval_every : int
-
How many epochs between evaluations.
var imgs : Ade20kDataset
-
Configuration for the ADE20K dataset.
var learning_rate : float
-
Linear layer learning rate.
var log_to : str
var n_epochs : int
-
Number of training epochs for linear layer.
var n_workers : int
-
Number of dataloader workers.
var seed : int
-
Random seed.
var weight_decay : float
-
Weight decay for AdamW.
class Validation (ckpt_root: str = './checkpoints/contrib/semseg',
dump_to: str = './logs/contrib/semseg',
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
device: str = 'cuda')
Validation(ckpt_root: str = './checkpoints/contrib/semseg', dump_to: str = './logs/contrib/semseg', imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')
Expand source code
@beartype.beartype @dataclasses.dataclass(frozen=True) class Validation: ckpt_root: str = os.path.join(".", "checkpoints", "contrib", "semseg") """Root to all checkpoints to evaluate.""" dump_to: str = os.path.join(".", "logs", "contrib", "semseg") """Directory to dump results to.""" imgs: saev.config.Ade20kDataset = dataclasses.field( default_factory=lambda: saev.config.Ade20kDataset(split="validation") ) """Configuration for the ADE20K validation dataset.""" batch_size: int = 128 """Batch size for calculating F1 scores.""" n_workers: int = 32 """Number of dataloader workers.""" device: str = "cuda" "Hardware for linear probe inference."
Class variables
var batch_size : int
-
Batch size for calculating F1 scores.
var ckpt_root : str
-
Root to all checkpoints to evaluate.
var device : str
-
Hardware for linear probe inference.
var dump_to : str
-
Directory to dump results to.
var imgs : Ade20kDataset
-
Configuration for the ADE20K validation dataset.
var n_workers : int
-
Number of dataloader workers.
class Visuals (sae_ckpt: str = './checkpoints/sae.pt',
ade20k_cls: int = 29,
k: int = 32,
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
label_threshold: float = 0.9,
device: str = 'cuda')
Visuals(sae_ckpt: str = './checkpoints/sae.pt', ade20k_cls: int = 29, k: int = 32, acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda')
Expand source code
@beartype.beartype @dataclasses.dataclass(frozen=True) class Visuals: sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt") """Path to the sae.pt file.""" ade20k_cls: int = 29 """ADE20K class to probe for.""" k: int = 32 """Top K features to save.""" acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad) """Configuration for the saved ADE20K training ViT activations.""" imgs: saev.config.Ade20kDataset = dataclasses.field( default_factory=lambda: saev.config.Ade20kDataset(split="training") ) """Configuration for the ADE20K training dataset.""" batch_size: int = 128 """Batch size for calculating F1 scores.""" n_workers: int = 32 """Number of dataloader workers.""" label_threshold: float = 0.9 device: str = "cuda" "Hardware for SAE inference."
Class variables
var acts : DataLoad
-
Configuration for the saved ADE20K training ViT activations.
var ade20k_cls : int
-
ADE20K class to probe for.
var batch_size : int
-
Batch size for calculating F1 scores.
var device : str
-
Hardware for SAE inference.
var imgs : Ade20kDataset
-
Configuration for the ADE20K training dataset.
var k : int
-
Top K features to save.
var label_threshold : float
var n_workers : int
-
Number of dataloader workers.
var sae_ckpt : str
-
Path to the sae.pt file.