Module contrib.semprobe.config
Classes
class Negatives (dump_to: str = './data/contrib/semprobe/test',
imgs: ImagenetDataset | ImageFolderDataset | Ade20kDataset = <factory>,
classes: list[str] = <factory>,
n_imgs: int = 20,
skip: list[str] = <factory>,
seed: int = 42)-
Negatives(dump_to: str = './data/contrib/semprobe/test', imgs: saev.config.ImagenetDataset | saev.config.ImageFolderDataset | saev.config.Ade20kDataset = <factory>, classes: list[str] = <factory>, n_imgs: int = 20, skip: list[str] = <factory>, seed: int = 42) Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Negatives:
    """Configuration for sampling negative example images.

    Instances are frozen (fields cannot be reassigned after construction) and
    the class is wrapped by ``beartype`` for runtime type checking. List-valued
    and dataset-valued defaults are produced by ``default_factory``, so every
    instance gets its own fresh object rather than sharing one.
    """

    # os.curdir is "." on every platform os.path supports, so this is the same
    # relative path as before.
    dump_to: str = os.path.join(os.curdir, "data", "contrib", "semprobe", "test")
    """Where to save negative samples."""

    imgs: saev.config.DatasetConfig = dataclasses.field(
        default_factory=saev.config.ImagenetDataset
    )
    """Where to sample images from."""

    classes: list[str] = dataclasses.field(
        default_factory=lambda: ["brazil", "cool", "germany", "crash"]
    )
    """Which classes to randomly sample."""

    n_imgs: int = 20
    """Number of negative images."""

    # `list` builds a new empty list per instance, exactly like `lambda: []`.
    skip: list[str] = dataclasses.field(default_factory=list)
    """Which images to skip."""

    seed: int = 42
    """Random seed."""
Class variables
var classes : list[str]
-
Which classes to randomly sample.
var dump_to : str
-
Where to save negative samples.
var imgs : ImagenetDataset | ImageFolderDataset | Ade20kDataset
-
Where to sample images from.
var n_imgs : int
-
Number of negative images.
var seed : int
-
Random seed.
var skip : list[str]
-
Which images to skip.
class Score (sae_ckpt: str = './checkpoints/abcdefg/sae.pt',
batch_size: int = 2048,
n_workers: int = 32,
thresholds: list[float] = <factory>,
top_k: int = 5,
imgs: ImageFolderDataset = <factory>,
acts: DataLoad = <factory>,
dump_to: str = './logs/contrib/semprobe',
include_latents: list[int] = <factory>,
device: str = 'cuda')-
Score(sae_ckpt: str = './checkpoints/abcdefg/sae.pt', batch_size: int = 2048, n_workers: int = 32, thresholds: list[float] = <factory>, top_k: int = 5, imgs: saev.config.ImageFolderDataset = <factory>, acts: saev.config.DataLoad = <factory>, dump_to: str = './logs/contrib/semprobe', include_latents: list[int] = <factory>, device: str = 'cuda') Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Score:
    """Configuration for the semprobe ``Score`` step.

    Groups together the SAE checkpoint to load, inference settings (batch size,
    dataloader workers, device), the curated images and their SAE activations,
    the activation thresholds, how many top features to analyze by hand, any
    extra latents to include, and where to write results.

    Instances are frozen and the class is wrapped by ``beartype`` for runtime
    type checking. Mutable and dataset-valued defaults come from
    ``default_factory``, so no two instances share them.
    """

    # os.curdir is "." on every platform os.path supports.
    sae_ckpt: str = os.path.join(os.curdir, "checkpoints", "abcdefg", "sae.pt")
    """Path to SAE checkpoint"""

    batch_size: int = 2048
    """Batch size for SAE inference."""

    n_workers: int = 32
    """Number of dataloader workers."""

    thresholds: list[float] = dataclasses.field(
        default_factory=lambda: [0.0, 1.0, 3.0, 10.0, 30.0, 100.0]
    )
    """Threshold(s) for feature activation."""

    top_k: int = 5
    """Number of top features to manually analyze."""

    imgs: saev.config.ImageFolderDataset = dataclasses.field(
        default_factory=saev.config.ImageFolderDataset
    )
    """Where curated examples are stored"""

    acts: saev.config.DataLoad = dataclasses.field(
        default_factory=saev.config.DataLoad
    )
    """SAE activations for the curated examples."""

    dump_to: str = os.path.join(os.curdir, "logs", "contrib", "semprobe")
    """Where to save results/visualizations."""

    include_latents: list[int] = dataclasses.field(default_factory=list)
    """Latents to manually include."""

    device: str = "cuda"
    """Hardware device."""
Class variables
var acts : DataLoad
-
SAE activations for the curated examples.
var batch_size : int
-
Batch size for SAE inference.
var device : str
-
Hardware device.
var dump_to : str
-
Where to save results/visualizations.
var imgs : ImageFolderDataset
-
Where curated examples are stored
var include_latents : list[int]
-
Latents to manually include.
var n_workers : int
-
Number of dataloader workers.
var sae_ckpt : str
-
Path to SAE checkpoint
var thresholds : list[float]
-
Threshold(s) for feature activation.
var top_k : int
-
Number of top features to manually analyze.