Skip to content

saev.data

IndexedConfig(shards=pathlib.Path('$SAEV_SCRATCH/saev/shards/abcdefg'), tokens='content', layer=-2, debug=False) dataclass

Configuration for loading indexed activation data from disk

Attributes:

Name Type Description
shards Path

Directory with .bin shards and a metadata.json file.

tokens Literal['special', 'content', 'all']

Which kinds of tokens to use. 'special' returns the special token(s) (if any). 'content' returns content tokens. 'all' returns both content and special tokens.

layer int | Literal['all']

Which ViT layer(s) to read from disk. -2 selects the second-to-last layer. "all" enumerates every recorded layer.

debug bool

Whether the dataloader process should log debug messages.

IndexedDataset(cfg)

Bases: Dataset

Dataset of activations from disk.

Attributes:

Name Type Description
cfg Config

Configuration set via CLI args.

md Metadata

Activations metadata; automatically loaded from disk.

layer_idx int

Layer index into the shards if we are choosing a specific layer.

Source code in src/saev/data/indexed.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def __init__(self, cfg: Config):
    """Load shard metadata and build the token index for `cfg.shards`."""
    self.cfg = cfg

    # Fail fast when the shard directory was never written.
    if not os.path.isdir(self.cfg.shards):
        raise RuntimeError(f"Activations are not saved at '{self.cfg.shards}'.")

    self.md = shards.Metadata.load(self.cfg.shards)

    # Every .bin shard must exist and be non-empty before we hand out indices.
    info = shards.ShardInfo.load(self.cfg.shards)
    info.validate(self.cfg.shards)

    # Token-level labels are optional; memory-map them when present.
    labels_path = os.path.join(self.cfg.shards, "labels.bin")
    if os.path.exists(labels_path):
        self.labels_mmap = np.memmap(
            labels_path,
            mode="r",
            dtype=np.uint8,
            shape=(self.md.n_examples, self.md.content_tokens_per_example),
        )
    else:
        self.labels_mmap = None

    self.index_map = shards.IndexMap(self.md, self.cfg.tokens, self.cfg.layer)

d_model property

Dimension of the underlying vision transformer's embedding space.

Example

Bases: TypedDict

Individual example.

__len__()

Dataset length depends on tokens and layer.

Source code in src/saev/data/indexed.py
123
124
125
126
127
def __len__(self) -> int:
    """
    Dataset length depends on `tokens` and `layer`.
    """
    # The IndexMap already accounts for the token subset and layer selection.
    return len(self.index_map)

Metadata(*, family, ckpt, layers, content_tokens_per_example, cls_token, d_model, n_examples, max_tokens_per_shard, data, dataset, pixel_agg=PixelAgg.MAJORITY, dtype='float32', protocol='2.1') dataclass

Metadata for a sharded set of transformer activations.

Parameters:

Name Type Description Default
family Literal['bird-mae', 'clip', 'dinov2', 'dinov3', 'fake-clip', 'pe-core', 'pe-spatial', 'siglip']

The transformer family.

required
ckpt str

The transformer checkpoint.

required
layers tuple[int, ...]

Which layers were saved.

required
content_tokens_per_example int

The number of content tokens per example.

required
cls_token bool

Whether the transformer has a [CLS] token as well.

required
d_model int

Model hidden dimension.

required
n_examples int

Number of examples.

required
max_tokens_per_shard int

The maximum number of tokens per shard.

required
data str

base64-encoded string of pickle.dumps(dataset).

required
dataset Path

Absolute path to the root directory of the original dataset.

required
pixel_agg PixelAgg

(only for image segmentation datasets) how the pixel-level segmentation labels were aggregated to token-level labels.

MAJORITY
dtype Literal['float32']

How activations are stored.

'float32'
protocol Literal['1.0.0', '1.1', '2.1']

Protocol version.

'2.1'

examples_per_shard property

The number of examples per shard based on the protocol.

Returns:

Type Description
int

Number of examples that fit in a shard.

hash property

First 8 bytes of a SHA256 hash of the metadata configuration.

Returns:

Type Description
str

Hexadecimal hash string uniquely identifying this configuration.

n_shards property

Total number of shards needed to store all examples.

Returns:

Type Description
int

Number of shards required.

shard_shape property

Shape of each shard file.

Returns:

Type Description
tuple[int, int, int, int]

Tuple of (examples_per_shard, n_layers, tokens_per_example, d_model).

tokens_per_example property

Total number of tokens per example including [CLS] token if present.

Returns:

Type Description
int

Number of tokens plus one if [CLS] token is included.

dump(shards_root)

Dumps a Metadata object to a metadata.json file in shards_root / hash.

Parameters:

Name Type Description Default
shards_root Path

Path to $SAEV_SCRATCH/saev/shards as described in disk-layout.md.

required
Source code in src/saev/data/shards.py
114
115
116
117
118
119
120
121
122
123
124
def dump(self, shards_root: pathlib.Path):
    """
    Write this Metadata as a metadata.json file under shards_root / self.hash.

    Args:
        shards_root: Path to $SAEV_SCRATCH/saev/shards as described in [disk-layout.md](../../developers/disk-layout.md).
    """
    assert disk.is_shards_root(shards_root)
    target = shards_root / self.hash
    target.mkdir(exist_ok=True)
    with open(target / "metadata.json", "wb") as fd:
        helpers.jdump(self, fd, option=orjson.OPT_INDENT_2)

load(shards_dir) classmethod

Loads a Metadata object from a metadata.json file in shards_dir.

Parameters:

Name Type Description Default
shards_dir Path

Path to $SAEV_SCRATCH/saev/shards/&lt;hash&gt; as described in disk-layout.md.

required
Source code in src/saev/data/shards.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
@classmethod
def load(cls, shards_dir: pathlib.Path) -> tp.Self:
    """
    Construct a Metadata from the metadata.json file inside shards_dir.

    Args:
        shards_dir: Path to $SAEV_SCRATCH/saev/shards/<hash> as described in [disk-layout.md](../../developers/disk-layout.md).
    """
    assert disk.is_shards_dir(shards_dir)

    with open(shards_dir / "metadata.json") as fd:
        raw = json.load(fd)

    # JSON cannot round-trip tuples, Paths, or enums; restore them here.
    raw["layers"] = tuple(raw["layers"])
    raw["dataset"] = pathlib.Path(raw["dataset"])
    raw["pixel_agg"] = PixelAgg(raw["pixel_agg"])
    return cls(**raw)

OrderedConfig(shards=pathlib.Path('$SAEV_SCRATCH/saev/shards/abcdefg'), tokens='content', layer=-2, batch_size=1024 * 16, batch_timeout_s=30.0, drop_last=False, buffer_size=64, debug=False, log_every_s=30.0) dataclass

Configuration for loading ordered (non-shuffled) activation data from disk

Attributes:

Name Type Description
shards Path

Directory with .bin shards and a metadata.json file.

tokens Literal['content']

Which kinds of tokens to use. Only 'content' (content tokens) is supported for ordered loading.

layer int | Literal['all']

Which ViT layer(s) to read from disk. -2 selects the second-to-last layer. "all" enumerates every recorded layer.

batch_size int

Batch size.

batch_timeout_s float

How long to wait for at least one batch.

drop_last bool

Whether to drop the last batch if it's smaller than the others.

buffer_size int

Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls.

debug bool

Whether the dataloader process should log debug messages.

log_every_s float

How frequently the dataloader process should log (debug) performance messages.

OrderedDataLoader(cfg)

High-throughput streaming loader that reads data from disk shards in order (no shuffling).

Source code in src/saev/data/ordered.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
def __init__(self, cfg: Config):
    """Validate the shard directory and set up (but do not start) the loader."""
    self.cfg = cfg
    if not os.path.isdir(self.cfg.shards):
        raise RuntimeError(f"Activations are not saved at '{self.cfg.shards}'.")

    self.md = shards.Metadata.load(self.cfg.shards)

    # Refuse to run with missing or empty .bin shards.
    shards.ShardInfo.load(self.cfg.shards).validate(self.cfg.shards)

    self.logger = logging.getLogger("ordered.DataLoader")
    self.ctx = mp.get_context()
    # Worker state is created lazily when iteration starts.
    self.manager_proc = None
    self.batch_queue = None
    self.stop_event = None
    self._n_samples = self._calculate_n_samples()
    self.logger.info(
        "Initialized ordered.DataLoader with %d samples. (debug=%s)",
        self.n_samples,
        self.cfg.debug,
    )

ExampleBatch

Bases: TypedDict

A batch of examples.

__iter__()

Yields batches in order.

Source code in src/saev/data/ordered.py
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def __iter__(self) -> collections.abc.Iterable[ExampleBatch]:
    """Yields batches in order.

    Pulls batches from the manager process's queue until `n_samples` samples
    have been yielded, surfacing worker errors and detecting a dead manager
    process. Always shuts the manager down on exit.
    """
    self._start_manager()
    n = 0  # samples yielded so far

    try:
        while n < self.n_samples:
            # Surface any error a worker pushed before blocking on the queue.
            if not self.err_queue.empty():
                who, tb = self.err_queue.get_nowait()
                raise RuntimeError(f"{who} crashed:\n{tb}")

            try:
                batch = self.batch_queue.get(timeout=self.cfg.batch_timeout_s)
                actual_batch_size = batch["act"].shape[0]

                # Handle drop_last: skip a final batch smaller than batch_size.
                if (
                    self.cfg.drop_last
                    and actual_batch_size < self.cfg.batch_size
                    and n + actual_batch_size >= self.n_samples
                ):
                    break

                n += actual_batch_size
                yield self.ExampleBatch(**batch)
                continue
            except queue.Empty:
                # Fixed log message: was "%.1fs seconds", which rendered as
                # e.g. "30.0s seconds".
                self.logger.info(
                    "Did not get a batch from manager process in %.1f seconds.",
                    self.cfg.batch_timeout_s,
                )
            except FileNotFoundError:
                # The queue's underlying pipe vanished; manager likely exited.
                self.logger.info("Manager process (probably) closed.")
                continue

            # If we don't continue, then we should check on the manager process.
            if not self.manager_proc.is_alive():
                raise RuntimeError(
                    f"Manager process died unexpectedly after {n}/{self.n_samples} samples."
                )

    finally:
        self.shutdown()

__len__()

Returns the number of batches in an epoch.

Source code in src/saev/data/ordered.py
371
372
373
374
375
376
def __len__(self) -> int:
    """Number of batches per epoch; a trailing partial batch counts unless drop_last."""
    full, rem = divmod(self.n_samples, self.cfg.batch_size)
    if self.cfg.drop_last or rem == 0:
        return full
    return full + 1

PixelAgg

Bases: Enum

How to aggregate pixel-level segmentation labels to token-level labels (only for image segmentation datasets).

ShuffledConfig(shards=pathlib.Path('$SAEV_SCRATCH/saev/shards/abcdefg'), tokens='content', layer=-1, batch_size=1024 * 16, drop_last=False, scale_norm=False, ignore_labels=list(), n_threads=4, buffer_size=64, min_buffer_fill=0.0, batch_timeout_s=30.0, seed=17, debug=False, log_every_s=30.0, use_tmpdir=False) dataclass

Configuration for loading shuffled activation data from disk.

Attributes:

Name Type Description
shards Path

Directory with .bin shards and a metadata.json file.

tokens Literal['special', 'content', 'all']

Which subset of tokens to use. 'special' indicates the special tokens (if any). 'content' indicates it will return content tokens. 'all' returns all tokens.

batch_size = 1024 * 16 class-attribute instance-attribute

Batch size.

batch_timeout_s = 30.0 class-attribute instance-attribute

How long to wait for at least one batch.

buffer_size = 64 class-attribute instance-attribute

Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls.

debug = False class-attribute instance-attribute

Whether the dataloader process should log debug messages.

drop_last = False class-attribute instance-attribute

Whether to drop the last batch if it's smaller than the others.

ignore_labels = dataclasses.field(default_factory=list) class-attribute instance-attribute

If provided, exclude tokens with these label values. None means no filtering. Common use: ignore_labels=[0] to exclude background.

layer = -1 class-attribute instance-attribute

Which transformer layer(s) to read from disk. -1 is the default, but must be changed. "all" enumerates every recorded layer.

log_every_s = 30.0 class-attribute instance-attribute

How frequently the dataloader process should log (debug) performance messages.

min_buffer_fill = 0.0 class-attribute instance-attribute

Fraction of the reservoir that must be populated before yielding batches.

n_threads = 4 class-attribute instance-attribute

Number of dataloading threads.

scale_norm = False class-attribute instance-attribute

Whether to scale norms to sqrt(D).

seed = 17 class-attribute instance-attribute

Random seed.

use_tmpdir = False class-attribute instance-attribute

If True and $TMPDIR is set, copy shards to local storage before training to avoid Infiniband congestion.

ShuffledDataLoader(cfg)

High-throughput streaming loader that deterministically shuffles data from disk shards.

Source code in src/saev/data/shuffled.py
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
def __init__(self, cfg: Config):
    """Validate config and shard files; workers are started lazily on iteration."""
    self.cfg = cfg

    # Filled in when iteration starts; None until then.
    self.manager_proc = None
    self.reservoir = None
    self.stop_event = None
    self._last_reservoir_fill: float | None = None
    self._logged_effective_capacity = False

    self.logger = logging.getLogger("shuffled.DataLoader")
    self.ctx = mp.get_context()

    if not os.path.isdir(self.cfg.shards):
        raise RuntimeError(f"Activations are not saved at '{self.cfg.shards}'.")

    # Optionally stage shards on node-local storage to reduce network I/O.
    self._shards_path = (
        _copy_shards_to_tmpdir(self.cfg.shards, self.logger)
        if self.cfg.use_tmpdir
        else self.cfg.shards
    )

    if self.cfg.scale_norm:
        raise NotImplementedError("scale_norm not implemented.")

    self.metadata = shards.Metadata.load(self._shards_path)

    # Refuse to run with missing or empty .bin shards.
    shards.ShardInfo.load(self._shards_path).validate(self._shards_path)

    self._n_samples = self._calculate_n_samples()

    # Label filtering needs labels.bin; only its existence is verified here.
    # NOTE(review): labels_mmap stays None — labels presumably get opened by
    # the worker processes; confirm.
    self.labels_mmap = None
    if self.cfg.ignore_labels:
        labels_path = os.path.join(self._shards_path, "labels.bin")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(
                f"ignore_labels filtering requested but labels.bin not found at {labels_path}"
            )

ExampleBatch

Bases: TypedDict

A batch of examples.

__iter__()

Yields batches.

Source code in src/saev/data/shuffled.py
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
def __iter__(self) -> collections.abc.Iterator[ExampleBatch]:
    """Yields batches.

    Blocks until the reservoir can satisfy each request, surfacing worker
    errors and detecting a dead manager process. Always shuts workers down
    on exit.
    """
    self._start_manager()
    n, b = 0, 0  # samples and batches yielded so far

    try:
        while n < self.n_samples:
            # The last batch may be smaller than batch_size.
            need = min(self.cfg.batch_size, self.n_samples - n)
            remaining_samples = self.n_samples - n
            self._wait_for_min_buffer_fill(remaining_samples)
            # Surface any error a worker pushed before blocking on the reservoir.
            if not self.err_queue.empty():
                who, tb = self.err_queue.get_nowait()
                raise RuntimeError(f"{who} crashed:\n{tb}")

            try:
                act, meta = self.reservoir.get(
                    need, timeout=self.cfg.batch_timeout_s
                )
                # Assumes reservoir.get returns exactly `need` samples on success.
                n += need
                b += 1
                example_idx, token_idx = meta.T
                yield self.ExampleBatch(
                    act=act, example_idx=example_idx, token_idx=token_idx
                )
                continue
            except TimeoutError:
                # Fixed log messages: both said "%.1fs seconds", which rendered
                # as e.g. "30.0s seconds".
                if self.cfg.ignore_labels:
                    self.logger.info(
                        "Did not get a batch from %d worker threads in %.1f seconds. This can happen when filtering out many labels.",
                        self.cfg.n_threads,
                        self.cfg.batch_timeout_s,
                    )
                else:
                    self.logger.info(
                        "Did not get a batch from %d worker threads in %.1f seconds.",
                        self.cfg.n_threads,
                        self.cfg.batch_timeout_s,
                    )

            # If we don't continue, then we should check on the manager process.
            if not self.manager_proc.is_alive():
                raise RuntimeError(
                    f"Manager process died unexpectedly after {b}/{len(self)} batches."
                )

    finally:
        self.shutdown()

__len__()

Returns the number of batches in an epoch.

Source code in src/saev/data/shuffled.py
697
698
699
def __len__(self) -> int:
    """Number of batches per epoch; the final batch may be partial."""
    full, rem = divmod(self.n_samples, self.cfg.batch_size)
    return full + bool(rem)

make_ordered_config(shuffled_cfg, **overrides)

Create an OrderedConfig from a ShuffledConfig, with optional overrides.

Defaults come from shuffled_cfg for fields present in OrderedConfig, and overrides take precedence. Unknown override fields raise TypeError from the OrderedConfig constructor, mirroring dataclasses.replace.

Source code in src/saev/data/__init__.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
@beartype.beartype
def make_ordered_config(
    shuffled_cfg: ShuffledConfig, **overrides: object
) -> OrderedConfig:
    """Create an `OrderedConfig` from a `ShuffledConfig`, with optional overrides.

    Defaults come from `shuffled_cfg` for fields present in `OrderedConfig`, and `overrides` take precedence. Unknown override fields raise `TypeError` from the `OrderedConfig` constructor, mirroring `dataclasses.replace`.
    """
    # Copy every OrderedConfig field that ShuffledConfig also defines;
    # overrides win on key collisions.
    params: dict[str, object] = {
        f.name: getattr(shuffled_cfg, f.name)
        for f in dataclasses.fields(OrderedConfig)
        if hasattr(shuffled_cfg, f.name)
    }
    return OrderedConfig(**{**params, **overrides})