Module contrib.semseg.config

Configs for all the different subscripts in contrib.semseg.

Imports in this file must be fast (see `saev.config`), so do not import heavy libraries such as torch or numpy.

Functions

def grid(cfg: Train,
sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]

Classes

class Train (learning_rate: float = 0.0001,
weight_decay: float = 0.001,
n_epochs: int = 400,
batch_size: int = 1024,
n_workers: int = 32,
imgs: Ade20kDataset = <factory>,
eval_every: int = 100,
device: str = 'cuda',
ckpt_path: str = './checkpoints/contrib/semseg',
seed: int = 42,
log_to: str = './logs/contrib/semseg')

Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, imgs: saev.config.Ade20kDataset = <factory>, eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/semseg', seed: int = 42, log_to: str = './logs/contrib/semseg')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Train:
    """Configuration for training the linear layer in `contrib.semseg`."""

    learning_rate: float = 1e-4
    """Linear layer learning rate."""
    weight_decay: float = 1e-3
    """Weight decay for AdamW."""
    n_epochs: int = 400
    """Number of training epochs for linear layer."""
    batch_size: int = 1024
    """Training batch size for linear layer."""
    n_workers: int = 32
    """Number of dataloader workers."""
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=saev.config.Ade20kDataset
    )
    """Configuration for the ADE20K dataset."""
    eval_every: int = 100
    """How many epochs between evaluations."""
    device: str = "cuda"
    """Hardware to train on."""
    ckpt_path: str = os.path.join(".", "checkpoints", "contrib", "semseg")
    """Checkpoint directory; defaults to the same root that `Validation.ckpt_root` reads from."""
    seed: int = 42
    """Random seed."""
    log_to: str = os.path.join(".", "logs", "contrib", "semseg")
    """Log directory; defaults to the same directory as `Validation.dump_to`."""

Class variables

var batch_size : int

Training batch size for linear layer.

var ckpt_path : str
var device : str

Hardware to train on.

var eval_every : int

How many epochs between evaluations.

var imgs : Ade20kDataset

Configuration for the ADE20K dataset.

var learning_rate : float

Linear layer learning rate.

var log_to : str
var n_epochs : int

Number of training epochs for linear layer.

var n_workers : int

Number of dataloader workers.

var seed : int

Random seed.

var weight_decay : float

Weight decay for AdamW.

class Validation (ckpt_root: str = './checkpoints/contrib/semseg',
dump_to: str = './logs/contrib/semseg',
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
device: str = 'cuda')

Validation(ckpt_root: str = './checkpoints/contrib/semseg', dump_to: str = './logs/contrib/semseg', imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Validation:
    """Configuration for evaluating saved checkpoints on the ADE20K validation split."""

    ckpt_root: str = os.path.join(".", "checkpoints", "contrib", "semseg")
    """Root to all checkpoints to evaluate."""
    dump_to: str = os.path.join(".", "logs", "contrib", "semseg")
    """Directory to dump results to."""
    # A factory (not a shared instance) so every Validation gets its own
    # dataset config with the split overridden to "validation".
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="validation"),
    )
    """Configuration for the ADE20K validation dataset."""
    batch_size: int = 128
    """Batch size for calculating F1 scores."""
    n_workers: int = 32
    """Number of dataloader workers."""
    device: str = "cuda"
    """Hardware for linear probe inference."""

Class variables

var batch_size : int

Batch size for calculating F1 scores.

var ckpt_root : str

Root to all checkpoints to evaluate.

var device : str

Hardware for linear probe inference.

var dump_to : str

Directory to dump results to.

var imgs : Ade20kDataset

Configuration for the ADE20K validation dataset.

var n_workers : int

Number of dataloader workers.

class Visuals (sae_ckpt: str = './checkpoints/sae.pt',
ade20k_cls: int = 29,
k: int = 32,
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
label_threshold: float = 0.9,
device: str = 'cuda')

Visuals(sae_ckpt: str = './checkpoints/sae.pt', ade20k_cls: int = 29, k: int = 32, acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Visuals:
    """Configuration for the visuals script: an SAE checkpoint, one target ADE20K class, and how many top features to save."""

    sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
    """Path to the sae.pt file."""
    ade20k_cls: int = 29
    """ADE20K class to probe for."""
    k: int = 32
    """Top K features to save."""
    acts: saev.config.DataLoad = dataclasses.field(
        default_factory=saev.config.DataLoad,
    )
    """Configuration for the saved ADE20K training ViT activations."""
    # Factory so each instance gets its own config with split="training".
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="training"),
    )
    """Configuration for the ADE20K training dataset."""
    batch_size: int = 128
    """Batch size for calculating F1 scores."""
    n_workers: int = 32
    """Number of dataloader workers."""
    # NOTE(review): undocumented; the 0.9 default suggests a cutoff in [0, 1]
    # applied when assigning labels -- confirm against the consuming script.
    label_threshold: float = 0.9
    device: str = "cuda"
    """Hardware for SAE inference."""

Class variables

var acts : DataLoad

Configuration for the saved ADE20K training ViT activations.

var ade20k_cls : int

ADE20K class to probe for.

var batch_size : int

Batch size for calculating F1 scores.

var device : str

Hardware for SAE inference.

var imgs : Ade20kDataset

Configuration for the ADE20K training dataset.

var k : int

Top K features to save.

var label_threshold : float
var n_workers : int

Number of dataloader workers.

var sae_ckpt : str

Path to the sae.pt file.