Module saev.data

SAEV Sharded-Activation File Protocol v1 (2025-06-17)

saev caches activations to disk rather than running ViT or LLM inference during SAE training. Gemma Scope makes the same decision (see Section 3.3.2 of https://arxiv.org/pdf/2408.05147). saev.data defines a specific protocol to support this on OSC, a supercomputer center, and to take advantage of OSC's particular disk performance characteristics.

Goal: losslessly persist very large Transformer (ViT or LLM) activations in a form that is:

  • mem-mappable
  • parameterized solely by the experiment configuration (Config)
  • referenced by a content hash, so identical configs collide and divergent ones never do
  • fast to read in random order for training, and readable (more slowly) with random access for visuals

This document is the single normative source. Any divergence in code is a bug.


1. Directory layout

<dump_to>/<HASH>/
    metadata.json              # UTF-8 JSON, human-readable, describes data-generating config
    shards.json                # UTF-8 JSON, human-readable, describes shards.
    acts000000.bin             # shard 0
    acts000001.bin             # shard 1
    ...
    actsNNNNNN.bin             # shard NNNNNN  (zero-padded width=6)

HASH = sha256(json.dumps(metadata, sort_keys=True, separators=(',', ':')).encode('utf-8'))

Hashing the canonical JSON serialization guards against silent config drift.
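
For concreteness, here is a minimal sketch of how a reader might verify that a cache directory matches its metadata. The helper name is hypothetical and assumes metadata.json round-trips to exactly the dict that was hashed.

import hashlib
import json
import os


def expected_hash(shard_dir: str) -> str:
    # Recompute HASH from an existing cache directory's metadata.json.
    with open(os.path.join(shard_dir, "metadata.json")) as fd:
        metadata = json.load(fd)
    cfg_bytes = json.dumps(metadata, sort_keys=True, separators=(",", ":")).encode("utf-8")
    return hashlib.sha256(cfg_bytes).hexdigest()


# The directory name should equal the recomputed hash:
# assert os.path.basename(shard_dir) == expected_hash(shard_dir)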


2. JSON file schemas

2.1. metadata.json

| field                 | type   | semantic                                      |
|-----------------------|--------|-----------------------------------------------|
| vit_family            | string | "clip" \| "siglip" \| "dinov2"                |
| vit_ckpt              | string | model identifier (OpenCLIP, HF, etc.)         |
| layers                | int[]  | ViT residual-block indices recorded           |
| n_patches_per_img     | int    | image patches only (excludes CLS)             |
| cls_token             | bool   | true -> patch 0 is CLS, else no CLS           |
| d_vit                 | int    | activation dimensionality                     |
| n_imgs                | int    | total images in dataset                       |
| max_patches_per_shard | int    | logical activations per shard (see Section 3) |
| data                  | object | opaque dataset description                    |
| dtype                 | string | numpy dtype; fixed to "float32" for now       |
| protocol              | string | "1.0.0" for now                               |

The data object is dataclasses.asdict(cfg.data), plus an additional __class__ field whose value is cfg.data.__class__.__name__.
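
In other words (a minimal sketch; cfg here is whatever experiment config produced the cache):

import dataclasses

data_obj = {**dataclasses.asdict(cfg.data), "__class__": cfg.data.__class__.__name__}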

2.2. shards.json

A single array of shard objects, each of which has the following fields:

| field  | type   | semantic                        |
|--------|--------|---------------------------------|
| name   | string | shard filename (acts000000.bin) |
| n_imgs | int    | number of images in the shard   |
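
For example, a three-shard dataset's shards.json might look like this (illustrative counts only):

[
    {"name": "acts000000.bin", "n_imgs": 9338},
    {"name": "acts000001.bin", "n_imgs": 9338},
    {"name": "acts000002.bin", "n_imgs": 1324}
]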

3. Shard sizing maths

n_tokens_per_img = n_patches_per_img + (1 if cls_token else 0)

n_imgs_per_shard = floor(max_patches_per_shard / (n_tokens_per_img * len(layers)))

shape_per_shard = (
    n_imgs_per_shard, len(layers), n_tokens_per_img, d_vit,
)

max_patches_per_shard is a budget (default ~2.4 M) chosen so that a shard is approximately 10 GiB of float32 data at d_vit = 1024.

The final shard typically holds fewer images; its actual image count is recorded in the n_imgs field of its entry in shards.json.
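
As a worked example with illustrative parameters (not normative defaults):

n_patches_per_img = 256
cls_token = True
layers = [23]                    # one recorded layer
d_vit = 1024
max_patches_per_shard = 2_400_000

n_tokens_per_img = n_patches_per_img + (1 if cls_token else 0)                 # 257
n_imgs_per_shard = max_patches_per_shard // (n_tokens_per_img * len(layers))   # 9338
shard_bytes = n_imgs_per_shard * len(layers) * n_tokens_per_img * d_vit * 4    # ~9.8 GB per full shard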


4. Data Layout and Global Indexing

The entire dataset of activations is treated as a single logical 4D tensor with the shape (n_imgs, len(layers), n_tokens_per_img, d_vit). This logical tensor is C-contiguous with axes ordered [Image, Layer, Token, Dimension].

Physically, this tensor is split along the first axis (Image) into multiple shards, where each shard is a single binary file. The number of images in each shard is constant, except for the final shard, which may be smaller.

To locate an arbitrary activation vector, a reader must convert a logical coordinate (global_img_idx, layer, token_idx) into a file path and an offset within that file.

4.1 Definitions

Let the parameters from metadata.json be:

  • L = len(layers)
  • P = n_patches_per_img
  • T = P + (1 if cls_token else 0) (Total tokens per image)
  • D = d_vit
  • S = images per full shard, i.e. n_imgs_per_shard from Section 3 (equivalently, the n_imgs of every shard in shards.json except possibly the last)

4.2 Coordinate Transformations

Given a logical coordinate:

  • global_img_idx: integer, with 0 <= global_img_idx < n_imgs
  • layer: integer, must be an element of layers
  • token_idx: integer, 0 <= token_idx < T

The physical location is found as follows:

  1. Identify Shard:

    • shard_idx = global_img_idx // S
    • img_in_shard = global_img_idx % S

    The target file is acts{shard_idx:06d}.bin.

  2. Identify Layer Index: The stored data contains only a subset of the ViT's layers, so the logical layer must be mapped to its index in the stored layers array.

    • layer_idx = layers.index(layer)

    A reader must raise an error if layer is not in layers.

  3. Calculate Offset: The data within a shard is a 4D tensor of shape (S, L, T, D). The offset to the first byte of the desired activation vector [img_in_shard, layer_idx, token_idx] is:

    • offset_in_vectors = (img_in_shard * L * T) + (layer_idx * T) + token_idx
    • offset_in_bytes = offset_in_vectors * D * 4 (4 bytes per float32 element)

A reader can then seek to offset_in_bytes and read $D \times 4$ bytes to retrieve the vector.

Alternatively, rather than calculating the offset by hand, readers can memmap the shard and use NumPy indexing to get the activation vector.
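
A minimal reader sketch of this lookup (the helper name and exact argument layout are illustrative, not part of saev; metadata fields are as defined in metadata.json):

import json
import os

import numpy as np


def read_activation(shard_root, metadata, global_img_idx, layer, token_idx):
    # Illustrative lookup following Section 4.2.
    L = len(metadata["layers"])
    T = metadata["n_patches_per_img"] + (1 if metadata["cls_token"] else 0)
    D = metadata["d_vit"]
    S = metadata["max_patches_per_shard"] // (T * L)  # images per full shard

    shard_idx, img_in_shard = divmod(global_img_idx, S)
    layer_idx = metadata["layers"].index(layer)  # raises ValueError if layer was not recorded

    # shards.json records the exact number of images in this shard.
    with open(os.path.join(shard_root, "shards.json")) as fd:
        n_imgs_in_shard = json.load(fd)[shard_idx]["n_imgs"]

    fpath = os.path.join(shard_root, f"acts{shard_idx:06d}.bin")
    acts = np.memmap(fpath, mode="r", dtype=np.float32, shape=(n_imgs_in_shard, L, T, D))
    return np.array(acts[img_in_shard, layer_idx, token_idx])  # copy the (D,) vector out of the memmap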

4.3 Token Axis Layout

The token axis of length $T$ is ordered as follows:

  • If cls_token is true:
    • Index 0: [CLS] token activation
    • Indices 1 to $P$: patch token activations
  • If cls_token is false:
    • Indices 0 to $P-1$: patch token activations

The relative order of patch tokens is preserved exactly as produced by the upstream Vision Transformer.
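
For example, to keep only the patch tokens from one image's (L, T, D) block (a two-line sketch; img_acts and cls_token are assumed, e.g. from the reader above and metadata.json):

start = 1 if cls_token else 0        # skip the [CLS] slot when present
patch_acts = img_acts[:, start:, :]  # shape (L, P, D); patch order is preserved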


5. Versioning & compatibility

  • Major changes (shape reorder, dtype switch, new required JSON keys) increment the major protocol version number at the top of this document and must emit a breaking warning in loader code.
  • Minor, backward-compatible additions (new optional JSON key) merely update this doc and the minor protocol version number.

That's the whole deal. No hidden invariants. Anything else you find in code that contradicts this sheet, fix the code or update the spec.

Performance

SAEs are mostly disk-bound. Gemma Scope (Google's SAE paper) aimed for 1 GB/s to keep their GPUs brrr'ing. This is pretty hard even with sequential reads, much less with random access.

I run all my experiments on OSC; their scratch filesystem /fs/scratch has sequential read speeds of around 800 MB/s and random-access read speeds of around 22 MB/s.

I got these numbers with:

fio --name=net --filename=/fs/scratch/PAS2136/samuelstevens/cache/saev/366017a10220b85014ae0a594276b25f6ea3d756b74d1d3218da1e34ffcf32e9/acts000000.bin --rw=read --bs=1M --direct=1 --iodepth=16 --runtime=30 --time_based

and

fio --name=net --filename=/fs/scratch/PAS2136/samuelstevens/cache/saev/366017a10220b85014ae0a594276b25f6ea3d756b74d1d3218da1e34ffcf32e9/acts000000.bin --rw=randread --bs=4K --direct=1 --iodepth=16 --runtime=30 --time_based

These two commands reported, respectively:

READ: bw=796MiB/s (835MB/s), 796MiB/s-796MiB/s (835MB/s-835MB/s), io=23.3GiB (25.0GB), run=30001-30001msec

and

READ: bw=22.9MiB/s (24.0MB/s), 22.9MiB/s-22.9MiB/s (24.0MB/s-24.0MB/s), io=687MiB (721MB), run=30001-30001msec

My naive PyTorch-style dataset, which uses multiple processes to feed a dataloader, did purely random reads and was too slow, topping out around 500 examples/s:

Performance plot showing that naive random access dataloading maxes out around 500 examples/s.

I've implemented a dataloader that tries to do sequential reads rather than random reads in saev/data/iterable.py. It's much faster (around 4.5K examples/s) on OSC.

Performance plot showing that my first attempt at a sequential dataloader maxes out around 4500 examples/s.

I know it should be even faster: the dataset of 128M examples is 2.9 TB and my sequential disk read speed is 800 MB/s, so a full pass should take ~1 hour. At 4.5K examples/s, 128M examples take about 7.9 hours. You can see this on a wandb run here, which reports 14.6% disk utilization; that can certainly be higher.

Not sure if this is the correct way to think about it, but 100 / 14.6 ≈ 6.8, which is close to 7.9 hours.
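
A quick back-of-the-envelope check of those numbers:

dataset_bytes = 2.9e12       # 2.9 TB of activations
seq_read_bps = 800e6         # 800 MB/s sequential read
n_examples = 128e6
examples_per_s = 4.5e3

ideal_hours = dataset_bytes / seq_read_bps / 3600    # ~1.0 h if fully disk-bound
actual_hours = n_examples / examples_per_s / 3600    # ~7.9 h at measured throughput
# actual_hours / ideal_hours ~ 7.9, roughly consistent with 14.6% disk utilization (1 / 0.146 ~ 6.8).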

Ordered Dataloader Design

The saev/data/ordered.py module implements a high-throughput ordered dataloader that guarantees sequential data delivery. This is useful for iterating through all patches in an image at once.

Key Design Decisions

  1. Single-threaded I/O in Manager Process

Initially, the dataloader used multiple worker threads for parallel I/O, similar to PyTorch's DataLoader. However, this created a fundamental ordering problem: when multiple workers read batches in parallel, they complete at different times and deliver batches out of order.

We switched to single-threaded I/O because:

  • Sequential reads from memory-mapped files are already highly optimized by the OS.
  • The OS page cache provides excellent performance for sequential access patterns.
  • Eliminating multi-threading removes all batch reordering complexity.
  • The simpler design is more maintainable and debuggable.

  2. Process Separation with Ring Buffer

The dataloader still uses a separate manager process connected via a multiprocessing Queue (acting as a ring buffer). This provides:

  • Overlap between I/O and computation
  • Configurable read-ahead via the buffer_size parameter
  • Natural backpressure when computation is slower than I/O
  • Process isolation for better resource management

  3. Shard-Aware Sequential Reading

The dataloader correctly handles the actual distribution of data across shards by:

  • Reading shards.json to get the exact number of images per shard
  • Maintaining cumulative offsets for efficient index-to-shard mapping
  • Handling batches that span multiple shards without gaps or duplicates
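
A minimal sketch of that cumulative-offset lookup (hypothetical helper; it uses only the shards.json fields from Section 2.2):

import bisect
import itertools


def img_to_shard(global_img_idx: int, shard_n_imgs: list[int]) -> tuple[int, int]:
    # Map a global image index to (shard_idx, img_in_shard) using exact per-shard counts.
    cumulative = [0, *itertools.accumulate(shard_n_imgs)]  # cumulative[i] = images in shards 0..i-1
    shard_idx = bisect.bisect_right(cumulative, global_img_idx) - 1
    return shard_idx, global_img_idx - cumulative[shard_idx]


# e.g. shard_n_imgs = [shard["n_imgs"] for shard in shards_json]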

Performance Considerations

  • Memory-mapped files: Using np.memmap allows efficient access to large files without loading them entirely into memory
  • Sequential access pattern: The dataloader reads data in the exact order it's stored on disk, maximizing OS cache effectiveness
  • Minimal data copying: Activations are copied only once from the memory-mapped file to PyTorch tensors
  • Read-ahead buffering: The configurable buffer size allows tuning the trade-off between memory usage and I/O overlap

Trade-offs

The single-threaded design trades potential parallel I/O throughput for:

  • Guaranteed ordering
  • Simplicity and maintainability
  • Elimination of synchronization overhead
  • Predictable performance characteristics

In practice, the sequential read performance is sufficient for most use cases, especially when the computation (e.g., SAE forward pass) is the bottleneck rather than I/O.
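
A usage sketch of the ordered loader (class names as documented on this page; the shard_root path is a placeholder for a real cache directory):

from saev.data import OrderedConfig, OrderedDataLoader

cfg = OrderedConfig(shard_root="/path/to/cache/<HASH>", batch_size=16_384)
dl = OrderedDataLoader(cfg)

for batch in dl:
    acts = batch["act"]         # (batch, d_vit) float32 activations, in on-disk order
    image_i = batch["image_i"]  # global image index for each row
    patch_i = batch["patch_i"]  # patch index within each image
    # ... feed acts through the SAE ...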

Sub-modules

saev.data.buffers
saev.data.config
saev.data.images
saev.data.indexed
saev.data.models
saev.data.ordered

Ordered (sequential) dataloader for activation data …

saev.data.shuffled
saev.data.writers

Classes

class IndexedConfig (shard_root: str = './shards',
patches: Literal['cls', 'image', 'all'] = 'image',
layer: Union[int, Literal['all']] = -2,
seed: int = 17,
debug: bool = False)

Configuration for loading indexed activation data from disk.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Config:
    """Configuration for loading indexed activation data from disk."""

    shard_root: str = os.path.join(".", "shards")
    """Directory with .bin shards and a metadata.json file."""
    patches: typing.Literal["cls", "image", "all"] = "image"
    """Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches."""
    layer: int | typing.Literal["all"] = -2
    """Which ViT layer(s) to read from disk. ``-2`` selects the second-to-last layer. ``"all"`` enumerates every recorded layer."""
    seed: int = 17
    """Random seed."""
    debug: bool = False
    """Whether the dataloader process should log debug messages."""

Class variables

var debug : bool

Whether the dataloader process should log debug messages.

var layer : Union[int, Literal['all']]

Which ViT layer(s) to read from disk. -2 selects the second-to-last layer. "all" enumerates every recorded layer.

var patches : Literal['cls', 'image', 'all']

Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches.

var seed : int

Random seed.

var shard_root : str

Directory with .bin shards and a metadata.json file.

class OrderedConfig (shard_root: str = './shards',
patches: Literal['cls', 'image', 'all'] = 'image',
layer: Union[int, Literal['all']] = -2,
batch_size: int = 16384,
batch_timeout_s: float = 30.0,
drop_last: bool = False,
buffer_size: int = 64,
debug: bool = False)

Configuration for loading ordered (non-shuffled) activation data from disk.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Config:
    """Configuration for loading ordered (non-shuffled) activation data from disk."""

    shard_root: str = os.path.join(".", "shards")
    """Directory with .bin shards and a metadata.json file."""
    patches: typing.Literal["cls", "image", "all"] = "image"
    """Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches."""
    layer: int | typing.Literal["all"] = -2
    """Which ViT layer(s) to read from disk. ``-2`` selects the second-to-last layer. ``"all"`` enumerates every recorded layer."""
    batch_size: int = 1024 * 16
    """Batch size."""
    batch_timeout_s: float = 30.0
    """How long to wait for at least one batch."""
    drop_last: bool = False
    """Whether to drop the last batch if it's smaller than the others."""
    buffer_size: int = 64
    """Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls."""
    debug: bool = False
    """Whether the dataloader process should log debug messages."""

Class variables

var batch_size : int

Batch size.

var batch_timeout_s : float

How long to wait for at least one batch.

var buffer_size : int

Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls.

var debug : bool

Whether the dataloader process should log debug messages.

var drop_last : bool

Whether to drop the last batch if it's smaller than the others.

var layer : Union[int, Literal['all']]

Which ViT layer(s) to read from disk. -2 selects the second-to-last layer. "all" enumerates every recorded layer.

var patches : Literal['cls', 'image', 'all']

Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches.

var shard_root : str

Directory with .bin shards and a metadata.json file.

class ShuffledConfig (shard_root: str = './shards',
patches: Literal['cls', 'image', 'all'] = 'image',
layer: Union[int, Literal['all']] = -2,
batch_size: int = 16384,
drop_last: bool = False,
scale_norm: bool = False,
n_threads: int = 4,
buffer_size: int = 64,
batch_timeout_s: float = 30.0,
seed: int = 17,
debug: bool = False,
log_every_s: float = 30.0)

Configuration for loading shuffled activation data from disk.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Config:
    """Configuration for loading shuffled activation data from disk."""

    shard_root: str = os.path.join(".", "shards")
    """Directory with .bin shards and a metadata.json file."""
    patches: typing.Literal["cls", "image", "all"] = "image"
    """Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches."""
    layer: int | typing.Literal["all"] = -2
    """Which ViT layer(s) to read from disk. ``-2`` selects the second-to-last layer. ``"all"`` enumerates every recorded layer."""
    batch_size: int = 1024 * 16
    """Batch size."""
    drop_last: bool = False
    """Whether to drop the last batch if it's smaller than the others."""
    scale_norm: bool = False
    """Whether to scale norms to sqrt(D)."""
    # Performance
    n_threads: int = 4
    """Number of dataloading threads."""
    buffer_size: int = 64
    """Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls."""
    batch_timeout_s: float = 30.0
    """How long to wait for at least one batch."""
    # Diagnostics
    seed: int = 17
    """Random seed."""
    debug: bool = False
    """Whether the dataloader process should log debug messages."""
    log_every_s: float = 30.0
    """How frequently the dataloader process should log (debug) performance messages."""

Class variables

var batch_size : int

Batch size.

var batch_timeout_s : float

How long to wait for at least one batch.

var buffer_size : int

Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls.

var debug : bool

Whether the dataloader process should log debug messages.

var drop_last : bool

Whether to drop the last batch if it's smaller than the others.

var layer : Union[int, Literal['all']]

Which ViT layer(s) to read from disk. -2 selects the second-to-last layer. "all" enumerates every recorded layer.

var log_every_s : float

How frequently the dataloader process should log (debug) performance messages.

var n_threads : int

Number of dataloading threads.

var patches : Literal['cls', 'image', 'all']

Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches.

var scale_norm : bool

Whether to scale norms to sqrt(D).

var seed : int

Random seed.

var shard_root : str

Directory with .bin shards and a metadata.json file.

class OrderedDataLoader (cfg: Config)

High-throughput streaming loader that reads data from disk shards in order (no shuffling).

@beartype.beartype
class DataLoader:
    """
    High-throughput streaming loader that reads data from disk shards in order (no shuffling).
    """

    @jaxtyped(typechecker=beartype.beartype)
    class ExampleBatch(typing.TypedDict):
        """Individual example."""

        act: Float[Tensor, "batch d_vit"]
        image_i: Int[Tensor, " batch"]
        patch_i: Int[Tensor, " batch"]

    def __init__(self, cfg: Config):
        self.cfg = cfg
        if not os.path.isdir(self.cfg.shard_root):
            raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.")

        self.metadata = writers.Metadata.load(self.cfg.shard_root)

        self.logger = logging.getLogger("ordered.DataLoader")
        self.ctx = mp.get_context()
        self.manager_proc = None
        self.batch_queue = None
        self.stop_event = None
        self._n_samples = self._calculate_n_samples()

    @property
    def n_batches(self) -> int:
        return len(self)

    @property
    def n_samples(self) -> int:
        return self._n_samples

    @property
    def batch_size(self) -> int:
        return self.cfg.batch_size

    @property
    def drop_last(self) -> int:
        return self.cfg.drop_last

    def _start_manager(self):
        # Always shutdown existing manager to ensure fresh start
        if self.manager_proc and self.manager_proc.is_alive():
            self.logger.info("Shutting down existing manager process.")
            self.shutdown()

        self.logger.info("Starting manager process.")

        # Create the batch queue
        self.batch_queue = self.ctx.Queue(maxsize=self.cfg.buffer_size)
        self.stop_event = self.ctx.Event()
        self.err_queue = self.ctx.Queue(maxsize=2)  # Manager + main process

        self.manager_proc = self.ctx.Process(
            target=_manager_main,
            args=(
                self.cfg,
                self.metadata,
                self.batch_queue,
                self.stop_event,
                self.err_queue,
            ),
            daemon=True,
        )
        self.manager_proc.start()

    def __iter__(self) -> collections.abc.Iterable[ExampleBatch]:
        """Yields batches in order."""
        self._start_manager()
        n = 0

        try:
            while n < self.n_samples:
                if not self.err_queue.empty():
                    who, tb = self.err_queue.get_nowait()
                    raise RuntimeError(f"{who} crashed:\n{tb}")

                try:
                    batch = self.batch_queue.get(timeout=self.cfg.batch_timeout_s)
                    actual_batch_size = batch["act"].shape[0]

                    # Handle drop_last
                    if (
                        self.cfg.drop_last
                        and actual_batch_size < self.cfg.batch_size
                        and n + actual_batch_size >= self.n_samples
                    ):
                        break

                    n += actual_batch_size
                    yield self.ExampleBatch(**batch)
                    continue
                except queue.Empty:
                    self.logger.info(
                        "Did not get a batch from manager process in %.1fs seconds.",
                        self.cfg.batch_timeout_s,
                    )

                # If we don't continue, then we should check on the manager process.
                if not self.manager_proc.is_alive():
                    raise RuntimeError(
                        f"Manager process died unexpectedly after {n}/{self.n_samples} samples."
                    )

        finally:
            self.shutdown()

    def shutdown(self):
        if (
            hasattr(self, "stop_event")
            and self.stop_event
            and not self.stop_event.is_set()
        ):
            self.stop_event.set()

        if (
            hasattr(self, "manager_proc")
            and self.manager_proc
            and self.manager_proc.is_alive()
        ):
            self.manager_proc.join(timeout=5.0)
            if self.manager_proc.is_alive():
                self.logger.warning(
                    "Manager process did not shut down cleanly, killing."
                )
                self.manager_proc.kill()

        self.manager_proc = None
        self.batch_queue = None
        self.stop_event = None

    def __del__(self):
        self.shutdown()

    def _calculate_n_samples(self) -> int:
        """Helper to calculate total number of examples based on config."""
        match (self.cfg.patches, self.cfg.layer):
            case ("cls", "all"):
                return self.metadata.n_imgs * len(self.metadata.layers)
            case ("cls", int()):
                return self.metadata.n_imgs
            case ("image", int()):
                return self.metadata.n_imgs * self.metadata.n_patches_per_img
            case ("image", "all"):
                return (
                    self.metadata.n_imgs
                    * len(self.metadata.layers)
                    * self.metadata.n_patches_per_img
                )
            case _:
                typing.assert_never((self.cfg.patches, self.cfg.layer))

    def __len__(self) -> int:
        """Returns the number of batches in an epoch."""
        if self.cfg.drop_last:
            return self.n_samples // self.cfg.batch_size
        else:
            return math.ceil(self.n_samples / self.cfg.batch_size)

Class variables

var ExampleBatch

Individual example.

Instance variables

prop batch_size : int
@property
def batch_size(self) -> int:
    return self.cfg.batch_size
prop drop_last : int
@property
def drop_last(self) -> int:
    return self.cfg.drop_last
prop n_batches : int
@property
def n_batches(self) -> int:
    return len(self)
prop n_samples : int
@property
def n_samples(self) -> int:
    return self._n_samples

Methods

def shutdown(self)
class ShuffledDataLoader (cfg: Config)

High-throughput streaming loader that deterministically shuffles data from disk shards.

@beartype.beartype
class DataLoader:
    """
    High-throughput streaming loader that deterministically shuffles data from disk shards.
    """

    @jaxtyped(typechecker=beartype.beartype)
    class ExampleBatch(typing.TypedDict):
        """Individual example."""

        act: Float[Tensor, "batch d_vit"]
        image_i: Int[Tensor, " batch"]
        patch_i: Int[Tensor, " batch"]

    def __init__(self, cfg: Config):
        self.cfg = cfg

        self.manager_proc = None
        self.reservoir = None
        self.stop_event = None

        self.logger = logging.getLogger("shuffled.DataLoader")
        self.ctx = mp.get_context()

        if not os.path.isdir(self.cfg.shard_root):
            raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.")

        if self.cfg.scale_norm:
            raise NotImplementedError("scale_norm not implemented.")

        self.metadata = writers.Metadata.load(self.cfg.shard_root)
        self._n_samples = self._calculate_n_samples()

    @property
    def n_batches(self) -> int:
        return len(self)

    @property
    def n_samples(self) -> int:
        return self._n_samples

    @property
    def batch_size(self) -> int:
        return self.cfg.batch_size

    @property
    def drop_last(self) -> int:
        return self.cfg.drop_last

    @property
    def manager_pid(self) -> int:
        if not self.manager_proc or not self.manager_proc.is_alive():
            return -1

        return self.manager_proc.pid

    def _start_manager(self):
        if self.manager_proc and self.manager_proc.is_alive():
            return

        self.logger.info("Starting manager process.")

        # Create the shared-memory buffers
        self.reservoir = buffers.ReservoirBuffer(
            self.cfg.buffer_size * self.cfg.batch_size,
            (self.metadata.d_vit,),
            dtype=torch.float32,
            meta_shape=(2,),
            meta_dtype=torch.int32,
            collate_fn=torch.utils.data.default_collate,
        )
        self.stop_event = self.ctx.Event()
        self.err_queue = self.ctx.Queue(maxsize=self.cfg.n_threads + 1)

        self.manager_proc = self.ctx.Process(
            target=_manager_main,
            args=(
                self.cfg,
                self.metadata,
                self.reservoir,
                self.stop_event,
                self.err_queue,
            ),
            daemon=True,
        )
        self.manager_proc.start()

    def __iter__(self) -> collections.abc.Iterable[ExampleBatch]:
        """Yields batches."""
        self._start_manager()
        n, b = 0, 0

        try:
            while n < self.n_samples:
                need = min(self.cfg.batch_size, self.n_samples - n)
                if not self.err_queue.empty():
                    who, tb = self.err_queue.get_nowait()
                    raise RuntimeError(f"{who} crashed:\n{tb}")

                try:
                    act, meta = self.reservoir.get(
                        need, timeout=self.cfg.batch_timeout_s
                    )
                    n += need
                    b += 1
                    image_i, patch_i = meta.T
                    yield self.ExampleBatch(act=act, image_i=image_i, patch_i=patch_i)
                    continue
                except TimeoutError:
                    self.logger.info(
                        "Did not get a batch from %d worker threads in %.1fs seconds.",
                        self.cfg.n_threads,
                        self.cfg.batch_timeout_s,
                    )

                # If we don't continue, then we should check on the manager process.
                if not self.manager_proc.is_alive():
                    raise RuntimeError(
                        f"Manager process died unexpectedly after {b}/{len(self)} batches."
                    )

        finally:
            self.shutdown()

    def shutdown(self):
        if (
            hasattr(self, "stop_event")
            and self.stop_event
            and not self.stop_event.is_set()
        ):
            self.stop_event.set()

        if (
            hasattr(self, "manager_proc")
            and self.manager_proc
            and self.manager_proc.is_alive()
        ):
            self.manager_proc.join(timeout=5.0)
            if self.manager_proc.is_alive():
                self.logger.warning(
                    "Manager process did not shut down cleanly, killing."
                )
                self.manager_proc.kill()

        if hasattr(self, "reservoir") and self.reservoir:
            self.reservoir.close()

        self.manager_proc = None
        self.reservoir = None
        self.stop_event = None

    def __del__(self):
        self.shutdown()

    def _calculate_n_samples(self) -> int:
        """Helper to calculate total number of examples based on config."""
        match (self.cfg.patches, self.cfg.layer):
            case ("cls", "all"):
                return self.metadata.n_imgs * len(self.metadata.layers)
            case ("cls", int()):
                return self.metadata.n_imgs
            case ("image", int()):
                return self.metadata.n_imgs * self.metadata.n_patches_per_img
            case ("image", "all"):
                return (
                    self.metadata.n_imgs
                    * len(self.metadata.layers)
                    * self.metadata.n_patches_per_img
                )
            case _:
                typing.assert_never((self.cfg.patches, self.cfg.layer))

    def __len__(self) -> int:
        """Returns the number of batches in an epoch."""
        return math.ceil(self.n_samples / self.cfg.batch_size)

Class variables

var ExampleBatch

Individual example.

Instance variables

prop batch_size : int
@property
def batch_size(self) -> int:
    return self.cfg.batch_size
prop drop_last : int
@property
def drop_last(self) -> int:
    return self.cfg.drop_last
prop manager_pid : int
@property
def manager_pid(self) -> int:
    if not self.manager_proc or not self.manager_proc.is_alive():
        return -1

    return self.manager_proc.pid
prop n_batches : int
@property
def n_batches(self) -> int:
    return len(self)
prop n_samples : int
@property
def n_samples(self) -> int:
    return self._n_samples

Methods

def shutdown(self)
class Dataset (cfg: Config)

Dataset of activations from disk.

@jaxtyped(typechecker=beartype.beartype)
class Dataset(torch.utils.data.Dataset):
    """
    Dataset of activations from disk.
    """

    class Example(typing.TypedDict):
        """Individual example."""

        act: Float[Tensor, " d_vit"]
        image_i: int
        patch_i: int

    cfg: Config
    """Configuration; set via CLI args."""
    metadata: writers.Metadata
    """Activations metadata; automatically loaded from disk."""
    layer_index: int
    """Layer index into the shards if we are choosing a specific layer."""

    def __init__(self, cfg: Config):
        self.cfg = cfg
        if not os.path.isdir(self.cfg.shard_root):
            raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.")

        self.metadata = writers.Metadata.load(self.cfg.shard_root)

        # Pick a really big number so that if you accidentally use this when you shouldn't, you get an out of bounds IndexError.
        self.layer_index = 1_000_000
        if isinstance(self.cfg.layer, int):
            err_msg = f"Non-exact matches for .layer field not supported; {self.cfg.layer} not in {self.metadata.layers}."
            assert self.cfg.layer in self.metadata.layers, err_msg
            self.layer_index = self.metadata.layers.index(self.cfg.layer)

    def transform(self, act: Float[np.ndarray, " d_vit"]) -> Float[Tensor, " d_vit"]:
        act = torch.from_numpy(act.copy())
        return act

    @property
    def d_vit(self) -> int:
        """Dimension of the underlying vision transformer's embedding space."""
        return self.metadata.d_vit

    def __getitem__(self, i: int) -> Example:
        # Add bounds checking
        if i < 0 or i >= len(self):
            raise IndexError(
                f"Index {i} out of range for dataset of length {len(self)}"
            )

        match (self.cfg.patches, self.cfg.layer):
            case ("cls", int()):
                img_act = self.get_img_patches(i)
                # Select layer's cls token.
                act = img_act[self.layer_index, 0, :]
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("image", int()):
                # Calculate which image and patch this index corresponds to
                image_i = i // self.metadata.n_patches_per_img
                patch_i = i % self.metadata.n_patches_per_img

                # Calculate shard location
                n_imgs_per_shard = (
                    self.metadata.max_patches_per_shard
                    // len(self.metadata.layers)
                    // (self.metadata.n_patches_per_img + int(self.metadata.cls_token))
                )

                shard = image_i // n_imgs_per_shard
                img_pos_in_shard = image_i % n_imgs_per_shard

                acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin")
                shape = (
                    n_imgs_per_shard,
                    len(self.metadata.layers),
                    self.metadata.n_patches_per_img + int(self.metadata.cls_token),
                    self.metadata.d_vit,
                )
                acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape)

                # Account for CLS token offset when accessing patches
                patch_idx_with_cls = patch_i + int(self.metadata.cls_token)

                # Get the activation
                act = acts[img_pos_in_shard, self.layer_index, patch_idx_with_cls]

                return self.Example(
                    act=self.transform(act),
                    image_i=image_i,
                    patch_i=patch_i,
                )
            case _:
                print((self.cfg.patches, self.cfg.layer))
                typing.assert_never((self.cfg.patches, self.cfg.layer))

    def get_img_patches(
        self, i: int
    ) -> Float[np.ndarray, "n_layers all_patches d_vit"]:
        n_imgs_per_shard = (
            self.metadata.max_patches_per_shard
            // len(self.metadata.layers)
            // (self.metadata.n_patches_per_img + int(self.metadata.cls_token))
        )
        shard = i // n_imgs_per_shard
        pos = i % n_imgs_per_shard
        acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin")
        shape = (
            n_imgs_per_shard,
            len(self.metadata.layers),
            self.metadata.n_patches_per_img + int(self.metadata.cls_token),
            self.metadata.d_vit,
        )
        acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape)
        # Note that this is not yet copied!
        return acts[pos]

    def __len__(self) -> int:
        """
        Dataset length depends on `patches` and `layer`.
        """
        match (self.cfg.patches, self.cfg.layer):
            case ("cls", "all"):
                # Return a CLS token from a random image and random layer.
                return self.metadata.n_imgs * len(self.metadata.layers)
            case ("cls", int()):
                # Return a CLS token from a random image and fixed layer.
                return self.metadata.n_imgs
            case ("image", int()):
                # Return a patch from a random image, fixed layer, and random patch.
                return self.metadata.n_imgs * (self.metadata.n_patches_per_img)
            case ("image", "all"):
                # Return a patch from a random image, random layer and random patch.
                return (
                    self.metadata.n_imgs
                    * len(self.metadata.layers)
                    * self.metadata.n_patches_per_img
                )
            case _:
                typing.assert_never((self.cfg.patches, self.cfg.layer))

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic

Class variables

var Example

Individual example.

var cfg : Config

Configuration; set via CLI args.

var layer_index : int

Layer index into the shards if we are choosing a specific layer.

var metadata : Metadata

Activations metadata; automatically loaded from disk.

Instance variables

prop d_vit : int

Dimension of the underlying vision transformer's embedding space.

@property
def d_vit(self) -> int:
    """Dimension of the underlying vision transformer's embedding space."""
    return self.metadata.d_vit

Methods

def get_img_patches(self, i: int) -> jaxtyping.Float[ndarray, 'n_layers all_patches d_vit']
def transform(self, act: jaxtyping.Float[ndarray, 'd_vit']) -> jaxtyping.Float[Tensor, 'd_vit']
class Metadata (vit_family: Literal['clip', 'siglip', 'dinov2'],
vit_ckpt: str,
layers: tuple[int, ...],
n_patches_per_img: int,
cls_token: bool,
d_vit: int,
n_imgs: int,
max_patches_per_shard: int,
data: dict[str, object],
dtype: Literal['float32'] = 'float32',
protocol: Literal['1.0.0'] = '1.0.0')

Metadata(vit_family: Literal['clip', 'siglip', 'dinov2'], vit_ckpt: str, layers: tuple[int, …], n_patches_per_img: int, cls_token: bool, d_vit: int, n_imgs: int, max_patches_per_shard: int, data: dict[str, object], dtype: Literal['float32'] = 'float32', protocol: Literal['1.0.0'] = '1.0.0')

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Metadata:
    vit_family: typing.Literal["clip", "siglip", "dinov2"]
    vit_ckpt: str
    layers: tuple[int, ...]
    n_patches_per_img: int
    cls_token: bool
    d_vit: int
    n_imgs: int
    max_patches_per_shard: int
    data: dict[str, object]
    dtype: typing.Literal["float32"] = "float32"
    protocol: typing.Literal["1.0.0"] = "1.0.0"

    def __post_init__(self):
        # Check that at least one image per shard can fit.
        assert self.n_imgs_per_shard >= 1, (
            "At least one image per shard must fit; increase max_patches_per_shard."
        )

    @classmethod
    def from_cfg(cls, cfg: Config) -> "Metadata":
        return cls(
            cfg.vit_family,
            cfg.vit_ckpt,
            tuple(cfg.vit_layers),
            cfg.n_patches_per_img,
            cfg.cls_token,
            cfg.d_vit,
            cfg.data.n_imgs,
            cfg.max_patches_per_shard,
            {**dataclasses.asdict(cfg.data), "__class__": cfg.data.__class__.__name__},
        )

    @classmethod
    def load(cls, shard_root: str) -> "Metadata":
        with open(os.path.join(shard_root, "metadata.json")) as fd:
            dct = json.load(fd)
        dct["layers"] = tuple(dct.pop("layers"))
        return cls(**dct)

    def dump(self, shard_root: str):
        with open(os.path.join(shard_root, "metadata.json"), "w") as fd:
            json.dump(dataclasses.asdict(self), fd, indent=4)

    @property
    def hash(self) -> str:
        cfg_bytes = json.dumps(
            dataclasses.asdict(self), sort_keys=True, separators=(",", ":")
        ).encode("utf-8")
        return hashlib.sha256(cfg_bytes).hexdigest()

    @property
    def n_tokens_per_img(self) -> int:
        return self.n_patches_per_img + int(self.cls_token)

    @property
    def n_shards(self) -> int:
        return math.ceil(self.n_imgs / self.n_imgs_per_shard)

    @property
    def n_imgs_per_shard(self) -> int:
        """
        Calculate the number of images per shard based on the protocol.

        Returns:
            Number of images that fit in a shard.
        """
        n_tokens_per_img = self.n_patches_per_img + (1 if self.cls_token else 0)
        return self.max_patches_per_shard // (n_tokens_per_img * len(self.layers))

    @property
    def shard_shape(self) -> tuple[int, int, int, int]:
        return (
            self.n_imgs_per_shard,
            len(self.layers),
            self.n_tokens_per_img,
            self.d_vit,
        )

Class variables

var cls_token : bool
var d_vit : int
var data : dict[str, object]
var dtype : Literal['float32']
var layers : tuple[int, ...]
var max_patches_per_shard : int
var n_imgs : int
var n_patches_per_img : int
var protocol : Literal['1.0.0']
var vit_ckpt : str
var vit_family : Literal['clip', 'siglip', 'dinov2']

Static methods

def from_cfg(cls, cfg: Config) -> Metadata
def load(cls, shard_root: str) -> Metadata

Instance variables

prop hash : str
@property
def hash(self) -> str:
    cfg_bytes = json.dumps(
        dataclasses.asdict(self), sort_keys=True, separators=(",", ":")
    ).encode("utf-8")
    return hashlib.sha256(cfg_bytes).hexdigest()
prop n_imgs_per_shard : int

Calculate the number of images per shard based on the protocol.

Returns

Number of images that fit in a shard.

@property
def n_imgs_per_shard(self) -> int:
    """
    Calculate the number of images per shard based on the protocol.

    Returns:
        Number of images that fit in a shard.
    """
    n_tokens_per_img = self.n_patches_per_img + (1 if self.cls_token else 0)
    return self.max_patches_per_shard // (n_tokens_per_img * len(self.layers))
prop n_shards : int
@property
def n_shards(self) -> int:
    return math.ceil(self.n_imgs / self.n_imgs_per_shard)
prop n_tokens_per_img : int
@property
def n_tokens_per_img(self) -> int:
    return self.n_patches_per_img + int(self.cls_token)
prop shard_shape : tuple[int, int, int, int]
@property
def shard_shape(self) -> tuple[int, int, int, int]:
    return (
        self.n_imgs_per_shard,
        len(self.layers),
        self.n_tokens_per_img,
        self.d_vit,
    )

Methods

def dump(self, shard_root: str)