Skip to content

Sweeps

Hyperparameter sweeps in saev train multiple SAE configurations in parallel on a single GPU, amortizing the cost of loading activation data from disk across all models. Furthermore, sweeps make it easy to train multiple SAEs with one command across multiple GPUs using Slurm.

Quick Start

Create a Python file defining your sweep:

# sweeps/my_sweep.py

def make_cfgs() -> list[dict]:
    cfgs = []

    # Grid search over learning rate and sparsity
    for lr in [3e-4, 1e-3, 3e-3]:
        for sparsity in [4e-4, 8e-4, 1.6e-3]:
            cfg = {
                "lr": lr,
                "objective": {"sparsity_coeff": sparsity},
            }
            cfgs.append(cfg)

    return cfgs

Run the sweep:

uv run train.py --sweep sweeps/my_sweep.py \
  --train-data.layer 23 \
  --val-data.layer 23

This trains 9 SAEs (3 learning rates x 3 sparsity coefficients) in parallel.

Why Parallel Sweeps?

SAE training is bottlenecked by disk I/O, not GPU computation. Loading terabytes of pre-computed ViT activations from disk is the slowest part. By training multiple SAE configurations on the same batch simultaneously, we amortize the I/O cost:

┌────────────────────────┐
│ ViT Activations (disk) │
└───────────┬────────────┘
            │ (slow I/O, once per batch)
            ▼
      ┌──────────┐
      │  Batch   │
      └─────┬────┘
            ├─────────┬─────────┬─────────┐
            ▼         ▼         ▼         ▼
         SAE #1    SAE #2    SAE #3     ...
        (lr=3e-4) (lr=1e-3) (lr=3e-3)

Sweep Configuration

Python-Based Sweeps

Python sweeps give you full control over config generation. Your sweep file must define a make_cfgs() function that returns a list of dicts.

Grid search example:

def make_cfgs():
    cfgs = []

    for lr in [1e-4, 3e-4, 1e-3]:
        for exp_factor in [8, 16, 32]:
            cfg = {
                "lr": lr,
                "sae": {"exp_factor": exp_factor},
            }
            cfgs.append(cfg)

    return cfgs

Paired parameters (not a grid):

def make_cfgs():
    cfgs = []

    # Grid over lr x sparsity
    for lr in [3e-4, 1e-3, 3e-3]:
        for sparsity in [4e-4, 8e-4, 1.6e-3]:
            # Paired layers (train and val use same layer)
            for layer in [6, 7, 8, 9, 10, 11]:
                cfg = {
                    "lr": lr,
                    "objective": {"sparsity_coeff": sparsity},
                    "train_data": {"layer": layer},
                    "val_data": {"layer": layer},
                }
                cfgs.append(cfg)

    return cfgs

This generates 54 configs (3 x 3 x 6) where each train/val pair uses the same layer, avoiding the 162 configs you'd get from a full grid (3 x 3 x 6 x 6).

Conditional sweeps:

def make_cfgs():
    cfgs = []

    for exp_factor in [8, 16, 32]:
        # Use different LR for different expansion factors
        lrs = [1e-3, 3e-3] if exp_factor <= 16 else [3e-4, 1e-3]

        for lr in lrs:
            cfg = {
                "lr": lr,
                "sae": {"exp_factor": exp_factor},
            }
            cfgs.append(cfg)

    return cfgs

Command-Line Overrides

Command-line arguments override sweep parameters with deep merging. The precedence order is: CLI > Sweep > Default.

uv run train.py --sweep sweeps/my_sweep.py \
  --lr 5e-4  # Overrides all LRs in the sweep

Override nested config fields with dotted notation:

uv run train.py --sweep sweeps/my_sweep.py \
  --train-data.layer 23 \
  --val-data.layer 23 \
  --sae.exp-factor 16

Deep merging means that when you override a nested field, only that specific field is replaced—other fields in the nested config are preserved from the sweep or default values.

Parallel Groups

Not all parameters can vary within a parallel sweep. Parameters that affect data loading (like train_data, n_train, device) must be identical across all configs in a parallel group.

When configs differ in these parameters, they're automatically split into separate Slurm jobs:

def make_cfgs():
    cfgs = []

    # These will run in 2 separate jobs
    for layer in [6, 12]:  # Different data loading
        for lr in [1e-4, 3e-4]:  # Can parallelize
            cfg = {
                "lr": lr,
                "train_data": {"layer": layer},
            }
            cfgs.append(cfg)

    return cfgs

This creates 2 parallel groups: - Job 1: layer=6, lr=[1e-4, 3e-4] - Job 2: layer=12, lr=[1e-4, 3e-4]

Implementation detail

See CANNOT_PARALLELIZE in train.py for the full list of parameters that split parallel groups. The split_cfgs() function handles grouping automatically.

Module Loading

Your sweep file is executed as a Python module, so you can use imports and helper functions:

def make_cfgs():
    cfgs = []

    # You can use helper functions
    base_layers = list(range(6, 24, 2))

    for layer in base_layers:
        for lr in [1e-4, 3e-4]:
            cfg = {
                "lr": lr,
                "train_data": {"layer": layer, "n_threads": 8},
                "val_data": {"layer": layer, "n_threads": 8},
                "sae": {"exp_factor": 16, "d_vit": 1024},
            }
            cfgs.append(cfg)

    return cfgs

Import mechanics

The sweep file is loaded with importlib.import_module(), so it must be importable as a Python module. Place sweep files in a location where Python can find them (typically the project root or a sweeps/ subdirectory).

Slurm Integration

When running with --slurm-acct, each parallel group becomes a separate Slurm job:

uv run train.py --sweep sweeps/large.py \
  --slurm-acct PAS2136 \
  --slurm-partition nextgen \
  --n-hours 24

The system automatically: - Groups configs that can parallelize - Submits one Slurm job per group - Waits for all jobs to complete - Reports results

Seed Management

Seeds are automatically incremented for each config to ensure reproducibility:

# Base config has seed=42
# Sweep generates 9 configs with seeds: 42, 43, 44, ..., 50

Override the base seed on the command line:

uv run train.py --sweep sweeps/my_sweep.py --seed 100

Examples

Simple grid:

# sweeps/simple.py
def make_cfgs():
    return [
        {"lr": lr, "objective": {"sparsity_coeff": sp}}
        for lr in [1e-4, 3e-4, 1e-3]
        for sp in [4e-4, 8e-4, 1.6e-3]
    ]

Layer sweep with paired train/val:

# sweeps/layers.py
def make_cfgs():
    cfgs = []

    for layer in range(6, 24, 2):  # Layers 6, 8, 10, ..., 22
        for lr in [3e-4, 1e-3]:
            cfg = {
                "lr": lr,
                "train_data": {"layer": layer},
                "val_data": {"layer": layer},
            }
            cfgs.append(cfg)

    return cfgs

Architecture sweep:

# sweeps/architecture.py
def make_cfgs():
    cfgs = []

    architectures = [
        ("small", 8, 1e-3),
        ("medium", 16, 5e-4),
        ("large", 32, 3e-4),
    ]

    for name, exp_factor, lr in architectures:
        cfg = {
            "lr": lr,
            "sae": {"exp_factor": exp_factor},
            "tag": name,
        }
        cfgs.append(cfg)

    return cfgs