Module contrib.semprobe.config

Classes

class Negatives (dump_to: str = './data/contrib/semprobe/test',
imgs: ImagenetDataset | ImageFolderDataset | Ade20kDataset = <factory>,
classes: list[str] = <factory>,
n_imgs: int = 20,
skip: list[str] = <factory>,
seed: int = 42)

Negatives(dump_to: str = './data/contrib/semprobe/test', imgs: saev.config.ImagenetDataset | saev.config.ImageFolderDataset | saev.config.Ade20kDataset = <factory>, classes: list[str] = <factory>, n_imgs: int = 20, skip: list[str] = <factory>, seed: int = 42)

Expand source code
# Configuration for sampling negative images: the source dataset, the classes
# to draw from, how many images to take, which to skip, and where to save them.
# The dataclass is frozen, so fields cannot be reassigned after construction.
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Negatives:
    # --- Output location ---
    dump_to: str = os.path.join(".", "data", "contrib", "semprobe", "test")
    """Where to save negative samples."""

    # --- Sampling source and selection ---
    # Defaults to a freshly constructed ImagenetDataset config per instance.
    imgs: saev.config.DatasetConfig = dataclasses.field(
        default_factory=saev.config.ImagenetDataset,
    )
    """Where to sample images from."""
    classes: list[str] = dataclasses.field(
        default_factory=lambda: ["brazil", "cool", "germany", "crash"],
    )
    """Which classes to randomly sample."""
    n_imgs: int = 20
    """Number of negative images."""
    # Empty by default; a factory is used so instances never share one list.
    skip: list[str] = dataclasses.field(default_factory=list)
    """Which images to skip."""

    # --- Reproducibility ---
    seed: int = 42
    """Random seed."""

Class variables

var classes : list[str]

Which classes to randomly sample.

var dump_to : str

Where to save negative samples.

var imgs : ImagenetDataset | ImageFolderDataset | Ade20kDataset

Where to sample images from.

var n_imgs : int

Number of negative images.

var seed : int

Random seed.

var skip : list[str]

Which images to skip.

class Score (sae_ckpt: str = './checkpoints/abcdefg/sae.pt',
batch_size: int = 2048,
n_workers: int = 32,
thresholds: list[float] = <factory>,
top_k: int = 5,
imgs: ImageFolderDataset = <factory>,
acts: DataLoad = <factory>,
dump_to: str = './logs/contrib/semprobe',
include_latents: list[int] = <factory>,
device: str = 'cuda')

Score(sae_ckpt: str = './checkpoints/abcdefg/sae.pt', batch_size: int = 2048, n_workers: int = 32, thresholds: list[float] = <factory>, top_k: int = 5, imgs: saev.config.ImageFolderDataset = <factory>, acts: saev.config.DataLoad = <factory>, dump_to: str = './logs/contrib/semprobe', include_latents: list[int] = <factory>, device: str = 'cuda')

Expand source code
# Configuration for scoring: which SAE checkpoint to load, how to batch and
# load data, which activation thresholds to use, and where to write results.
# The dataclass is frozen, so fields cannot be reassigned after construction.
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Score:
    # --- Model ---
    sae_ckpt: str = os.path.join(".", "checkpoints", "abcdefg", "sae.pt")
    """Path to SAE checkpoint"""

    # --- Inference / data loading ---
    batch_size: int = 2048
    """Batch size for SAE inference."""
    n_workers: int = 32
    """Number of dataloader workers."""

    # --- Analysis ---
    thresholds: list[float] = dataclasses.field(
        default_factory=lambda: [0.0, 1.0, 3.0, 10.0, 30.0, 100.0],
    )
    """Threshold(s) for feature activation."""
    top_k: int = 5
    """Number of top features to manually analyze."""

    # --- Inputs ---
    # Both defaults are built by calling the config class with no arguments.
    imgs: saev.config.ImageFolderDataset = dataclasses.field(
        default_factory=saev.config.ImageFolderDataset,
    )
    """Where curated examples are stored"""
    acts: saev.config.DataLoad = dataclasses.field(
        default_factory=saev.config.DataLoad,
    )
    """SAE activations for the curated examples."""

    # --- Outputs ---
    dump_to: str = os.path.join(".", "logs", "contrib", "semprobe")
    """Where to save results/visualizations."""

    # Empty by default; a factory is used so instances never share one list.
    include_latents: list[int] = dataclasses.field(default_factory=list)
    """Latents to manually include."""

    # --- Hardware ---
    device: str = "cuda"
    """Hardware device."""

Class variables

var acts : DataLoad

SAE activations for the curated examples.

var batch_size : int

Batch size for SAE inference.

var device : str

Hardware device.

var dump_to : str

Where to save results/visualizations.

var imgs : ImageFolderDataset

Where curated examples are stored

var include_latents : list[int]

Latents to manually include.

var n_workers : int

Number of dataloader workers.

var sae_ckpt : str

Path to SAE checkpoint

var thresholds : list[float]

Threshold(s) for feature activation.

var top_k : int

Number of top features to manually analyze.