Module saev.data
SAEV Sharded-Activation File Protocol v1 (2025-06-17)
saev caches activations to disk rather than running ViT or LLM inference on the fly when training SAEs.
Gemma Scope makes this decision as well (see Section 3.3.2 of https://arxiv.org/pdf/2408.05147).
saev.data defines a specific protocol to support this on OSC, a supercomputer center, and to take advantage of OSC's particular disk performance characteristics.
Goal: losslessly persist very large Transformer (ViT or LLM) activations in a form that is:
- mem-mappable
- parameterized solely by the experiment configuration (Config)
- referenced by a content hash, so identical configs collide and divergent ones never do
- fast to read in random order for training, and readable (more slowly) with random access for visuals
This document is the single normative source. Any divergence in code is a bug.
1. Directory layout
<dump_to>/<HASH>/
metadata.json # UTF-8 JSON, human-readable, describes data-generating config
shards.json # UTF-8 JSON, human-readable, describes shards.
acts000000.bin # shard 0
acts000001.bin # shard 1
...
actsNNNNNN.bin # shard NNNNNN (zero-padded width=6)
HASH = sha256(json.dumps(metadata, sort_keys=True, separators=(',', ':')).encode('utf-8'))
This guards against silent config drift.
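As a sanity check, a reader can recompute the hash from metadata.json and compare it against the directory name. A minimal sketch (the helper name is ours, not part of the library):

import hashlib
import json
import os

def compute_hash(shard_dir: str) -> str:
    """Recompute the content hash for a shard directory from its metadata.json."""
    with open(os.path.join(shard_dir, "metadata.json")) as fd:
        metadata = json.load(fd)
    cfg_bytes = json.dumps(metadata, sort_keys=True, separators=(",", ":")).encode("utf-8")
    return hashlib.sha256(cfg_bytes).hexdigest()

# compute_hash("<dump_to>/<HASH>") should equal "<HASH>".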
2. JSON file schemas
2.1. metadata.json
field | type | semantic |
---|---|---|
vit_family | string | "clip" \| "siglip" \| "dinov2" |
vit_ckpt | string | model identifier (OpenCLIP, HF, etc.) |
layers | int[] | ViT residual-block indices recorded |
n_patches_per_img | int | image patches only (excludes CLS) |
cls_token | bool | true -> patch 0 is CLS, else no CLS |
d_vit | int | activation dimensionality |
n_imgs | int | total images in dataset |
max_patches_per_shard | int | logical activations per shard (see Section 3) |
data | object | opaque dataset description |
dtype | string | numpy dtype. Fixed "float32" for now. |
protocol | string | "1.0.0" for now. |
The data object is dataclasses.asdict(cfg.data), with an additional __class__ field whose value is cfg.data.__class__.__name__.
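For illustration, a metadata.json for a hypothetical CLIP ViT-L/14 run might look like the following (every value below is made up):

{
    "vit_family": "clip",
    "vit_ckpt": "ViT-L-14/openai",
    "layers": [13],
    "n_patches_per_img": 256,
    "cls_token": true,
    "d_vit": 1024,
    "n_imgs": 24000,
    "max_patches_per_shard": 2400000,
    "data": {"__class__": "ImageFolder", "root": "/path/to/images"},
    "dtype": "float32",
    "protocol": "1.0.0"
}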
2.2. shards.json
A single array of shard objects, each of which has the following fields:
field | type | semantic |
---|---|---|
name | string | shard filename (acts000000.bin). |
n_imgs | int | the number of images in the shard. |
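For example, with hypothetical image counts that match the metadata.json example above:

[
    {"name": "acts000000.bin", "n_imgs": 9338},
    {"name": "acts000001.bin", "n_imgs": 9338},
    {"name": "acts000002.bin", "n_imgs": 5324}
]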
3. Shard sizing maths
n_tokens_per_img = n_patches_per_img + (1 if cls_token else 0)
n_imgs_per_shard = floor(max_patches_per_shard / (n_tokens_per_img * len(layers)))
shape_per_shard = (
n_imgs_per_shard, len(layers), n_tokens_per_img, d_vit,
)
max_patches_per_shard is a budget (default ~2.4 M) chosen so that a shard is approximately 10 GiB of float32 data at d_vit = 1024.
The last shard will have a smaller n_imgs_per_shard; its actual value is recorded in the n_imgs field of shards.json.
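As a worked example with hypothetical parameters (256 patches per image, a CLS token, one recorded layer, d_vit = 1024, and the default budget):

n_tokens_per_img = 256 + 1                    # 257 tokens per image
n_imgs_per_shard = 2_400_000 // (257 * 1)     # 9338 images per shard
shape_per_shard = (9338, 1, 257, 1024)
bytes_per_shard = 9338 * 1 * 257 * 1024 * 4   # ~9.8 GB of float32 per full shard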
4. Data Layout and Global Indexing
The entire dataset of activations is treated as a single logical 4D tensor of shape (n_imgs, len(layers), n_tokens_per_img, d_vit). This logical tensor is C-contiguous, with axes ordered [Image, Layer, Token, Dimension].
Physically, the tensor is split along the first axis (Image) into multiple shards, where each shard is a single binary file. The number of images per shard is constant, except for the final shard, which may be smaller.
To locate an arbitrary activation vector, a reader must convert a logical coordinate (global_img_idx, layer, token_idx) into a file path and an offset within that file.
4.1 Definitions
Let the parameters from metadata.json be:
- L = len(layers)
- P = n_patches_per_img
- T = P + (1 if cls_token else 0) (total tokens per image)
- D = d_vit
- S = n_imgs_per_shard from Section 3 (equivalently, the n_imgs recorded in shards.json for any non-final shard)
4.2 Coordinate Transformations
Given a logical coordinate:
- global_img_idx: integer, with 0 <= global_img_idx < n_imgs
- layer: integer, must be an element of layers
- token_idx: integer, with 0 <= token_idx < T

The physical location is found as follows:

1. Identify the shard:
shard_idx = global_img_idx // S
img_in_shard = global_img_idx % S
The target file is acts{shard_idx:06d}.bin.

2. Identify the layer index: the stored data contains only a subset of the ViT's layers, so the logical layer must be mapped to its index in the stored layers array:
layer_idx = layers.index(layer)
A reader must raise an error if layer is not in layers.

3. Calculate the offset: the data within a shard is a 4D tensor of shape (S, L, T, D). The offset to the first byte of the desired activation vector [img_in_shard, layer_idx, token_idx] is:
offset_in_vectors = (img_in_shard * L * T) + (layer_idx * T) + token_idx
offset_in_bytes = offset_in_vectors * D * 4 (assuming 4 bytes for float32)

A reader can then seek to offset_in_bytes and read $D \times 4$ bytes to retrieve the vector. Alternatively, rather than calculating the offset by hand, readers can memory-map the shard and use NumPy indexing to get the activation vector.
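Putting the pieces together, a minimal NumPy sketch of this lookup (the helper name is ours, not the library's; the library's own random-access reader lives in saev.data.indexed):

import json
import os

import numpy as np

def read_activation(shard_root: str, global_img_idx: int, layer: int, token_idx: int) -> np.ndarray:
    """Return one activation vector of shape (d_vit,) from a sharded store."""
    with open(os.path.join(shard_root, "metadata.json")) as fd:
        metadata = json.load(fd)
    with open(os.path.join(shard_root, "shards.json")) as fd:
        shards = json.load(fd)

    L = len(metadata["layers"])
    T = metadata["n_patches_per_img"] + (1 if metadata["cls_token"] else 0)
    D = metadata["d_vit"]
    S = shards[0]["n_imgs"]  # images per shard; only the final shard differs

    shard_idx, img_in_shard = divmod(global_img_idx, S)
    layer_idx = metadata["layers"].index(layer)  # raises ValueError if layer was not recorded

    fpath = os.path.join(shard_root, f"acts{shard_idx:06d}.bin")
    n_imgs_in_shard = shards[shard_idx]["n_imgs"]  # smaller for the last shard
    acts = np.memmap(fpath, mode="r", dtype=np.float32, shape=(n_imgs_in_shard, L, T, D))
    return np.array(acts[img_in_shard, layer_idx, token_idx])  # copy out of the memmap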
4.3 Token Axis Layout
The token axis of length $T$ is ordered as follows:
* If cls_token is true:
  * Index 0: [CLS] token activation
  * Indices 1 to $P$: patch token activations
* If cls_token is false:
  * Indices 0 to $P-1$: patch token activations
The relative order of patch tokens is preserved exactly as produced by the upstream Vision Transformer.
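For instance, given one image's block of shape (len(layers), T, d_vit) read from a shard, a reader might split it like this (a sketch; img_block and metadata are assumed to come from a reader like the one above):

if metadata["cls_token"]:
    cls_act = img_block[:, 0, :]      # (len(layers), d_vit)
    patch_acts = img_block[:, 1:, :]  # (len(layers), P, d_vit)
else:
    patch_acts = img_block            # (len(layers), P, d_vit)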
5. Versioning & compatibility
- Major changes (shape reorder, dtype switch, new required JSON keys) increment the major protocol version number at the top of this document and must emit a breaking warning in loader code.
- Minor, backward-compatible additions (new optional JSON key) merely update this doc and the minor protocol version number.
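A reader-side guard for the major version might look like this (a sketch; the function name is ours, not part of the library):

import json
import os

SUPPORTED_MAJOR = 1

def check_protocol(shard_root: str) -> None:
    with open(os.path.join(shard_root, "metadata.json")) as fd:
        protocol = json.load(fd)["protocol"]
    major = int(protocol.split(".")[0])
    if major != SUPPORTED_MAJOR:
        raise RuntimeError(f"Shards use protocol {protocol}; this reader supports major version {SUPPORTED_MAJOR}.")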
That's the whole deal. No hidden invariants. Anything else you find in code that contradicts this sheet, fix the code or update the spec.
Performance
SAEs are mostly disk-bound. Gemma Scope (Google's SAE paper) aimed for 1 GB/s to keep their GPUs brrr'ing. This is pretty hard even with sequential reads, much less random access.
I run all my experiments on OSC and their scratch filesystem /fs/scratch
has sequential read speeds of around 800 MB/s and random access speeds around 22 MB/s.
I got these numbers with:
fio --name=net --filename=/fs/scratch/PAS2136/samuelstevens/cache/saev/366017a10220b85014ae0a594276b25f6ea3d756b74d1d3218da1e34ffcf32e9/acts000000.bin --rw=read --bs=1M --direct=1 --iodepth=16 --runtime=30 --time_based
and
fio --name=net --filename=/fs/scratch/PAS2136/samuelstevens/cache/saev/366017a10220b85014ae0a594276b25f6ea3d756b74d1d3218da1e34ffcf32e9/acts000000.bin --rw=randread --bs=4K --direct=1 --iodepth=16 --runtime=30 --time_based
These two commands reported, respectively:
READ: bw=796MiB/s (835MB/s), 796MiB/s-796MiB/s (835MB/s-835MB/s), io=23.3GiB (25.0GB), run=30001-30001msec
and
READ: bw=22.9MiB/s (24.0MB/s), 22.9MiB/s-22.9MiB/s (24.0MB/s-24.0MB/s), io=687MiB (721MB), run=30001-30001msec
My naive PyTorch-style dataset, which uses multiple processes to feed a dataloader, did purely random reads and was too slow: it reports around 500 examples/s.
I've implemented a dataloader in saev/data/iterable.py that tries to do sequential reads rather than random reads. It's much faster (around 4.5K examples/s) on OSC.
I know that it should be even faster: the dataset of 128M examples is 2.9 TB and my sequential disk read speed is 800 MB/s, so a full pass should take ~1 hour. At 4.5K examples/s, 128M examples instead take 7.9 hours. You can see this on a wandb run, which reports 14.6% disk utilization. Certainly that can be higher.
Not sure if this is the correct way to think about it, but 100 / 14.6 ≈ 6.8, close to that 7.9x gap between the ideal ~1 hour and the observed 7.9 hours.
Ordered Dataloader Design
The saev/data/ordered.py
module implements a high-throughput ordered dataloader that guarantees sequential data delivery.
This is useful for iterating through all patches in an image at once.
Key Design Decisions
- Single-threaded I/O in Manager Process
Initially, the dataloader used multiple worker threads for parallel I/O, similar to PyTorch's DataLoader. However, this created a fundamental ordering problem: when multiple workers read batches in parallel, they complete at different times and deliver batches out of order.
We switched to single-threaded I/O because:
- Sequential reads from memory-mapped files are already highly optimized by the OS
- The OS page cache provides excellent performance for sequential access patterns
- Eliminating multi-threading removes all batch reordering complexity
- The simpler design is more maintainable and debuggable
- Process Separation with Ring Buffer
The dataloader still uses a separate manager process connected via a multiprocessing Queue (acting as a ring buffer). This provides:
- Overlap between I/O and computation
- Configurable read-ahead via the buffer_size parameter
- Natural backpressure when computation is slower than I/O
- Process isolation for better resource management
- Shard-Aware Sequential Reading
The dataloader correctly handles the actual distribution of data across shards by:
- Reading shards.json to get the exact number of images per shard
- Maintaining cumulative offsets for efficient index-to-shard mapping (see the sketch after this list)
- Handling batches that span multiple shards without gaps or duplicates
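The index-to-shard mapping is essentially a search over cumulative image counts. A minimal sketch of that idea (shard sizes are hypothetical, and this is not the library's exact code):

import numpy as np

shard_sizes = np.array([9338, 9338, 5324])  # n_imgs per shard, as read from shards.json
cum_imgs = np.cumsum(shard_sizes)           # [9338, 18676, 24000]

def locate(global_img_idx: int) -> tuple[int, int]:
    """Map a global image index to (shard_idx, img_in_shard)."""
    shard_idx = int(np.searchsorted(cum_imgs, global_img_idx, side="right"))
    start = int(cum_imgs[shard_idx - 1]) if shard_idx > 0 else 0
    return shard_idx, global_img_idx - start

# locate(20000) == (2, 1324): image 20000 lives in acts000002.bin at position 1324.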
Performance Considerations
- Memory-mapped files: using np.memmap allows efficient access to large files without loading them entirely into memory
- Sequential access pattern: the dataloader reads data in the exact order it's stored on disk, maximizing OS cache effectiveness
- Minimal data copying: activations are copied only once, from the memory-mapped file to PyTorch tensors
- Read-ahead buffering: the configurable buffer size allows tuning the trade-off between memory usage and I/O overlap
Trade-offs
The single-threaded design trades potential parallel I/O throughput for:
- Guaranteed ordering
- Simplicity and maintainability
- Elimination of synchronization overhead
- Predictable performance characteristics
In practice, the sequential read performance is sufficient for most use cases, especially when the computation (e.g., SAE forward pass) is the bottleneck rather than I/O.
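A usage sketch (the shard path and layer value are hypothetical; see the class documentation below for all options):

import saev.data

cfg = saev.data.OrderedConfig(
    shard_root="/path/to/shards/<HASH>",
    patches="image",
    layer=13,  # must be one of the recorded layers in metadata.json
    batch_size=16_384,
)
dl = saev.data.OrderedDataLoader(cfg)
for batch in dl:
    acts = batch["act"]         # Float[Tensor, "batch d_vit"], in on-disk order
    image_i = batch["image_i"]  # image index for each row
    patch_i = batch["patch_i"]  # patch index for each row
    # ... run the SAE forward pass on acts ...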
Sub-modules
saev.data.buffers
saev.data.config
saev.data.images
saev.data.indexed
saev.data.models
saev.data.ordered
-
Ordered (sequential) dataloader for activation data …
saev.data.shuffled
saev.data.writers
Classes
class IndexedConfig (shard_root: str = './shards',
patches: Literal['cls', 'image', 'all'] = 'image',
layer: Union[int, Literal['all']] = -2,
seed: int = 17,
debug: bool = False)
-
Configuration for loading indexed activation data from disk.
Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Config:
    """Configuration for loading indexed activation data from disk."""

    shard_root: str = os.path.join(".", "shards")
    """Directory with .bin shards and a metadata.json file."""
    patches: typing.Literal["cls", "image", "all"] = "image"
    """Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches."""
    layer: int | typing.Literal["all"] = -2
    """Which ViT layer(s) to read from disk. ``-2`` selects the second-to-last layer. ``"all"`` enumerates every recorded layer."""
    seed: int = 17
    """Random seed."""
    debug: bool = False
    """Whether the dataloader process should log debug messages."""
Class variables
var debug : bool
-
Whether the dataloader process should log debug messages.
var layer : Union[int, Literal['all']]
-
Which ViT layer(s) to read from disk.
-2
selects the second-to-last layer."all"
enumerates every recorded layer. var patches : Literal['cls', 'image', 'all']
-
Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches.
var seed : int
-
Random seed.
var shard_root : str
-
Directory with .bin shards and a metadata.json file.
class OrderedConfig (shard_root: str = './shards',
patches: Literal['cls', 'image', 'all'] = 'image',
layer: Union[int, Literal['all']] = -2,
batch_size: int = 16384,
batch_timeout_s: float = 30.0,
drop_last: bool = False,
buffer_size: int = 64,
debug: bool = False)
-
Configuration for loading ordered (non-shuffled) activation data from disk.
Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Config:
    """Configuration for loading ordered (non-shuffled) activation data from disk."""

    shard_root: str = os.path.join(".", "shards")
    """Directory with .bin shards and a metadata.json file."""
    patches: typing.Literal["cls", "image", "all"] = "image"
    """Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches."""
    layer: int | typing.Literal["all"] = -2
    """Which ViT layer(s) to read from disk. ``-2`` selects the second-to-last layer. ``"all"`` enumerates every recorded layer."""
    batch_size: int = 1024 * 16
    """Batch size."""
    batch_timeout_s: float = 30.0
    """How long to wait for at least one batch."""
    drop_last: bool = False
    """Whether to drop the last batch if it's smaller than the others."""
    buffer_size: int = 64
    """Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls."""
    debug: bool = False
    """Whether the dataloader process should log debug messages."""
Class variables
var batch_size : int
-
Batch size.
var batch_timeout_s : float
-
How long to wait for at least one batch.
var buffer_size : int
-
Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls.
var debug : bool
-
Whether the dataloader process should log debug messages.
var drop_last : bool
-
Whether to drop the last batch if it's smaller than the others.
var layer : Union[int, Literal['all']]
-
Which ViT layer(s) to read from disk.
-2
selects the second-to-last layer."all"
enumerates every recorded layer. var patches : Literal['cls', 'image', 'all']
-
Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches.
var shard_root : str
-
Directory with .bin shards and a metadata.json file.
class ShuffledConfig (shard_root: str = './shards',
patches: Literal['cls', 'image', 'all'] = 'image',
layer: Union[int, Literal['all']] = -2,
batch_size: int = 16384,
drop_last: bool = False,
scale_norm: bool = False,
n_threads: int = 4,
buffer_size: int = 64,
batch_timeout_s: float = 30.0,
seed: int = 17,
debug: bool = False,
log_every_s: float = 30.0)
-
Configuration for loading shuffled activation data from disk.
Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Config:
    """Configuration for loading shuffled activation data from disk."""

    shard_root: str = os.path.join(".", "shards")
    """Directory with .bin shards and a metadata.json file."""
    patches: typing.Literal["cls", "image", "all"] = "image"
    """Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches."""
    layer: int | typing.Literal["all"] = -2
    """Which ViT layer(s) to read from disk. ``-2`` selects the second-to-last layer. ``"all"`` enumerates every recorded layer."""
    batch_size: int = 1024 * 16
    """Batch size."""
    drop_last: bool = False
    """Whether to drop the last batch if it's smaller than the others."""
    scale_norm: bool = False
    """Whether to scale norms to sqrt(D)."""

    # Performance
    n_threads: int = 4
    """Number of dataloading threads."""
    buffer_size: int = 64
    """Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls."""
    batch_timeout_s: float = 30.0
    """How long to wait for at least one batch."""

    # Diagnostics
    seed: int = 17
    """Random seed."""
    debug: bool = False
    """Whether the dataloader process should log debug messages."""
    log_every_s: float = 30.0
    """How frequently the dataloader process should log (debug) performance messages."""
Class variables
var batch_size : int
-
Batch size.
var batch_timeout_s : float
-
How long to wait for at least one batch.
var buffer_size : int
-
Number of batches to queue in the shared-memory ring buffer. Higher values add latency but improve resilience to brief stalls.
var debug : bool
-
Whether the dataloader process should log debug messages.
var drop_last : bool
-
Whether to drop the last batch if it's smaller than the others.
var layer : Union[int, Literal['all']]
-
Which ViT layer(s) to read from disk.
-2
selects the second-to-last layer."all"
enumerates every recorded layer. var log_every_s : float
-
How frequently the dataloader process should log (debug) performance messages.
var n_threads : int
-
Number of dataloading threads.
var patches : Literal['cls', 'image', 'all']
-
Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'image' indicates it will return image patches. 'all' returns all patches.
var scale_norm : bool
-
Whether to scale norms to sqrt(D).
var seed : int
-
Random seed.
var shard_root : str
-
Directory with .bin shards and a metadata.json file.
class OrderedDataLoader (cfg: Config)
-
High-throughput streaming loader that reads data from disk shards in order (no shuffling).
Expand source code
@beartype.beartype class DataLoader: """ High-throughput streaming loader that reads data from disk shards in order (no shuffling). """ @jaxtyped(typechecker=beartype.beartype) class ExampleBatch(typing.TypedDict): """Individual example.""" act: Float[Tensor, "batch d_vit"] image_i: Int[Tensor, " batch"] patch_i: Int[Tensor, " batch"] def __init__(self, cfg: Config): self.cfg = cfg if not os.path.isdir(self.cfg.shard_root): raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.") self.metadata = writers.Metadata.load(self.cfg.shard_root) self.logger = logging.getLogger("ordered.DataLoader") self.ctx = mp.get_context() self.manager_proc = None self.batch_queue = None self.stop_event = None self._n_samples = self._calculate_n_samples() @property def n_batches(self) -> int: return len(self) @property def n_samples(self) -> int: return self._n_samples @property def batch_size(self) -> int: return self.cfg.batch_size @property def drop_last(self) -> int: return self.cfg.drop_last def _start_manager(self): # Always shutdown existing manager to ensure fresh start if self.manager_proc and self.manager_proc.is_alive(): self.logger.info("Shutting down existing manager process.") self.shutdown() self.logger.info("Starting manager process.") # Create the batch queue self.batch_queue = self.ctx.Queue(maxsize=self.cfg.buffer_size) self.stop_event = self.ctx.Event() self.err_queue = self.ctx.Queue(maxsize=2) # Manager + main process self.manager_proc = self.ctx.Process( target=_manager_main, args=( self.cfg, self.metadata, self.batch_queue, self.stop_event, self.err_queue, ), daemon=True, ) self.manager_proc.start() def __iter__(self) -> collections.abc.Iterable[ExampleBatch]: """Yields batches in order.""" self._start_manager() n = 0 try: while n < self.n_samples: if not self.err_queue.empty(): who, tb = self.err_queue.get_nowait() raise RuntimeError(f"{who} crashed:\n{tb}") try: batch = self.batch_queue.get(timeout=self.cfg.batch_timeout_s) actual_batch_size = batch["act"].shape[0] # Handle drop_last if ( self.cfg.drop_last and actual_batch_size < self.cfg.batch_size and n + actual_batch_size >= self.n_samples ): break n += actual_batch_size yield self.ExampleBatch(**batch) continue except queue.Empty: self.logger.info( "Did not get a batch from manager process in %.1fs seconds.", self.cfg.batch_timeout_s, ) # If we don't continue, then we should check on the manager process. if not self.manager_proc.is_alive(): raise RuntimeError( f"Manager process died unexpectedly after {n}/{self.n_samples} samples." ) finally: self.shutdown() def shutdown(self): if ( hasattr(self, "stop_event") and self.stop_event and not self.stop_event.is_set() ): self.stop_event.set() if ( hasattr(self, "manager_proc") and self.manager_proc and self.manager_proc.is_alive() ): self.manager_proc.join(timeout=5.0) if self.manager_proc.is_alive(): self.logger.warning( "Manager process did not shut down cleanly, killing." 
) self.manager_proc.kill() self.manager_proc = None self.batch_queue = None self.stop_event = None def __del__(self): self.shutdown() def _calculate_n_samples(self) -> int: """Helper to calculate total number of examples based on config.""" match (self.cfg.patches, self.cfg.layer): case ("cls", "all"): return self.metadata.n_imgs * len(self.metadata.layers) case ("cls", int()): return self.metadata.n_imgs case ("image", int()): return self.metadata.n_imgs * self.metadata.n_patches_per_img case ("image", "all"): return ( self.metadata.n_imgs * len(self.metadata.layers) * self.metadata.n_patches_per_img ) case _: typing.assert_never((self.cfg.patches, self.cfg.layer)) def __len__(self) -> int: """Returns the number of batches in an epoch.""" if self.cfg.drop_last: return self.n_samples // self.cfg.batch_size else: return math.ceil(self.n_samples / self.cfg.batch_size)
Class variables
var ExampleBatch
-
Individual example.
Instance variables
prop batch_size : int
-
Expand source code
@property def batch_size(self) -> int: return self.cfg.batch_size
prop drop_last : int
-
Expand source code
@property def drop_last(self) -> int: return self.cfg.drop_last
prop n_batches : int
-
Expand source code
@property def n_batches(self) -> int: return len(self)
prop n_samples : int
-
Expand source code
@property def n_samples(self) -> int: return self._n_samples
Methods
def shutdown(self)
class ShuffledDataLoader (cfg: Config)
-
High-throughput streaming loader that deterministically shuffles data from disk shards.
Expand source code
@beartype.beartype class DataLoader: """ High-throughput streaming loader that deterministically shuffles data from disk shards. """ @jaxtyped(typechecker=beartype.beartype) class ExampleBatch(typing.TypedDict): """Individual example.""" act: Float[Tensor, "batch d_vit"] image_i: Int[Tensor, " batch"] patch_i: Int[Tensor, " batch"] def __init__(self, cfg: Config): self.cfg = cfg self.manager_proc = None self.reservoir = None self.stop_event = None self.logger = logging.getLogger("shuffled.DataLoader") self.ctx = mp.get_context() if not os.path.isdir(self.cfg.shard_root): raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.") if self.cfg.scale_norm: raise NotImplementedError("scale_norm not implemented.") self.metadata = writers.Metadata.load(self.cfg.shard_root) self._n_samples = self._calculate_n_samples() @property def n_batches(self) -> int: return len(self) @property def n_samples(self) -> int: return self._n_samples @property def batch_size(self) -> int: return self.cfg.batch_size @property def drop_last(self) -> int: return self.cfg.drop_last @property def manager_pid(self) -> int: if not self.manager_proc or not self.manager_proc.is_alive(): return -1 return self.manager_proc.pid def _start_manager(self): if self.manager_proc and self.manager_proc.is_alive(): return self.logger.info("Starting manager process.") # Create the shared-memory buffers self.reservoir = buffers.ReservoirBuffer( self.cfg.buffer_size * self.cfg.batch_size, (self.metadata.d_vit,), dtype=torch.float32, meta_shape=(2,), meta_dtype=torch.int32, collate_fn=torch.utils.data.default_collate, ) self.stop_event = self.ctx.Event() self.err_queue = self.ctx.Queue(maxsize=self.cfg.n_threads + 1) self.manager_proc = self.ctx.Process( target=_manager_main, args=( self.cfg, self.metadata, self.reservoir, self.stop_event, self.err_queue, ), daemon=True, ) self.manager_proc.start() def __iter__(self) -> collections.abc.Iterable[ExampleBatch]: """Yields batches.""" self._start_manager() n, b = 0, 0 try: while n < self.n_samples: need = min(self.cfg.batch_size, self.n_samples - n) if not self.err_queue.empty(): who, tb = self.err_q.get_nowait() raise RuntimeError(f"{who} crashed:\n{tb}") try: act, meta = self.reservoir.get( need, timeout=self.cfg.batch_timeout_s ) n += need b += 1 image_i, patch_i = meta.T yield self.ExampleBatch(act=act, image_i=image_i, patch_i=patch_i) continue except TimeoutError: self.logger.info( "Did not get a batch from %d worker threads in %.1fs seconds.", self.cfg.n_threads, self.cfg.batch_timeout_s, ) # If we don't continue, then we should check on the manager process. if not self.manager_proc.is_alive(): raise RuntimeError( f"Manager process died unexpectedly after {b}/{len(self)} batches." ) finally: self.shutdown() def shutdown(self): if ( hasattr(self, "stop_event") and self.stop_event and not self.stop_event.is_set() ): self.stop_event.set() if ( hasattr(self, "manager_proc") and self.manager_proc and self.manager_proc.is_alive() ): self.manager_proc.join(timeout=5.0) if self.manager_proc.is_alive(): self.logger.warning( "Manager process did not shut down cleanly, killing." 
) self.manager_proc.kill() if hasattr(self, "reservoir") and self.reservoir: self.reservoir.close() self.manager_proc = None self.reservoir = None self.stop_event = None def __del__(self): self.shutdown() def _calculate_n_samples(self) -> int: """Helper to calculate total number of examples based on config.""" match (self.cfg.patches, self.cfg.layer): case ("cls", "all"): return self.metadata.n_imgs * len(self.metadata.layers) case ("cls", int()): return self.metadata.n_imgs case ("image", int()): return self.metadata.n_imgs * self.metadata.n_patches_per_img case ("image", "all"): return ( self.metadata.n_imgs * len(self.metadata.layers) * self.metadata.n_patches_per_img ) case _: typing.assert_never((self.cfg.patches, self.cfg.layer)) def __len__(self) -> int: """Returns the number of batches in an epoch.""" return math.ceil(self.n_samples / self.cfg.batch_size)
Class variables
var ExampleBatch
-
Individual example.
Instance variables
prop batch_size : int
-
Expand source code
@property def batch_size(self) -> int: return self.cfg.batch_size
prop drop_last : int
-
Expand source code
@property def drop_last(self) -> int: return self.cfg.drop_last
prop manager_pid : int
-
Expand source code
@property def manager_pid(self) -> int: if not self.manager_proc or not self.manager_proc.is_alive(): return -1 return self.manager_proc.pid
prop n_batches : int
-
Expand source code
@property def n_batches(self) -> int: return len(self)
prop n_samples : int
-
Expand source code
@property def n_samples(self) -> int: return self._n_samples
Methods
def shutdown(self)
class Dataset (cfg: Config)
-
Dataset of activations from disk.
Expand source code
@jaxtyped(typechecker=beartype.beartype) class Dataset(torch.utils.data.Dataset): """ Dataset of activations from disk. """ class Example(typing.TypedDict): """Individual example.""" act: Float[Tensor, " d_vit"] image_i: int patch_i: int cfg: Config """Configuration; set via CLI args.""" metadata: writers.Metadata """Activations metadata; automatically loaded from disk.""" layer_index: int """Layer index into the shards if we are choosing a specific layer.""" def __init__(self, cfg: Config): self.cfg = cfg if not os.path.isdir(self.cfg.shard_root): raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.") self.metadata = writers.Metadata.load(self.cfg.shard_root) # Pick a really big number so that if you accidentally use this when you shouldn't, you get an out of bounds IndexError. self.layer_index = 1_000_000 if isinstance(self.cfg.layer, int): err_msg = f"Non-exact matches for .layer field not supported; {self.cfg.layer} not in {self.metadata.layers}." assert self.cfg.layer in self.metadata.layers, err_msg self.layer_index = self.metadata.layers.index(self.cfg.layer) def transform(self, act: Float[np.ndarray, " d_vit"]) -> Float[Tensor, " d_vit"]: act = torch.from_numpy(act.copy()) return act @property def d_vit(self) -> int: """Dimension of the underlying vision transformer's embedding space.""" return self.metadata.d_vit def __getitem__(self, i: int) -> Example: # Add bounds checking if i < 0 or i >= len(self): raise IndexError( f"Index {i} out of range for dataset of length {len(self)}" ) match (self.cfg.patches, self.cfg.layer): case ("cls", int()): img_act = self.get_img_patches(i) # Select layer's cls token. act = img_act[self.layer_index, 0, :] return self.Example(act=self.transform(act), image_i=i, patch_i=-1) case ("image", int()): # Calculate which image and patch this index corresponds to image_i = i // self.metadata.n_patches_per_img patch_i = i % self.metadata.n_patches_per_img # Calculate shard location n_imgs_per_shard = ( self.metadata.max_patches_per_shard // len(self.metadata.layers) // (self.metadata.n_patches_per_img + int(self.metadata.cls_token)) ) shard = image_i // n_imgs_per_shard img_pos_in_shard = image_i % n_imgs_per_shard acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin") shape = ( n_imgs_per_shard, len(self.metadata.layers), self.metadata.n_patches_per_img + int(self.metadata.cls_token), self.metadata.d_vit, ) acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape) # Account for CLS token offset when accessing patches patch_idx_with_cls = patch_i + int(self.metadata.cls_token) # Get the activation act = acts[img_pos_in_shard, self.layer_index, patch_idx_with_cls] return self.Example( act=self.transform(act), image_i=image_i, patch_i=patch_i, ) case _: print((self.cfg.patches, self.cfg.layer)) typing.assert_never((self.cfg.patches, self.cfg.layer)) def get_img_patches( self, i: int ) -> Float[np.ndarray, "n_layers all_patches d_vit"]: n_imgs_per_shard = ( self.metadata.max_patches_per_shard // len(self.metadata.layers) // (self.metadata.n_patches_per_img + int(self.metadata.cls_token)) ) shard = i // n_imgs_per_shard pos = i % n_imgs_per_shard acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin") shape = ( n_imgs_per_shard, len(self.metadata.layers), self.metadata.n_patches_per_img + int(self.metadata.cls_token), self.metadata.d_vit, ) acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape) # Note that this is not yet copied! 
return acts[pos] def __len__(self) -> int: """ Dataset length depends on `patches` and `layer`. """ match (self.cfg.patches, self.cfg.layer): case ("cls", "all"): # Return a CLS token from a random image and random layer. return self.metadata.n_imgs * len(self.metadata.layers) case ("cls", int()): # Return a CLS token from a random image and fixed layer. return self.metadata.n_imgs case ("image", int()): # Return a patch from a random image, fixed layer, and random patch. return self.metadata.n_imgs * (self.metadata.n_patches_per_img) case ("image", "all"): # Return a patch from a random image, random layer and random patch. return ( self.metadata.n_imgs * len(self.metadata.layers) * self.metadata.n_patches_per_img ) case _: typing.assert_never((self.cfg.patches, self.cfg.layer))
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
Class variables
var Example
-
Individual example.
var cfg : Config
-
Configuration; set via CLI args.
var layer_index : int
-
Layer index into the shards if we are choosing a specific layer.
var metadata : Metadata
-
Activations metadata; automatically loaded from disk.
Instance variables
prop d_vit : int
-
Dimension of the underlying vision transformer's embedding space.
Expand source code
@property def d_vit(self) -> int: """Dimension of the underlying vision transformer's embedding space.""" return self.metadata.d_vit
Methods
def get_img_patches(self, i: int) ‑> jaxtyping.Float[ndarray, 'n_layers all_patches d_vit']
def transform(self, act: jaxtyping.Float[ndarray, 'd_vit']) ‑> jaxtyping.Float[Tensor, 'd_vit']
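A usage sketch for random-access reads (e.g., for building visualizations); the path and layer are hypothetical, and the Dataset is assumed to take the indexed config shown above:

import saev.data

cfg = saev.data.IndexedConfig(
    shard_root="/path/to/shards/<HASH>",
    patches="image",
    layer=13,  # must be one of the recorded layers in metadata.json
)
ds = saev.data.Dataset(cfg)
example = ds[0]
act = example["act"]          # Float[Tensor, " d_vit"]
image_i = example["image_i"]  # which image this activation came from
patch_i = example["patch_i"]  # which patch within that image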
class Metadata (vit_family: Literal['clip', 'siglip', 'dinov2'],
vit_ckpt: str,
layers: tuple[int, ...],
n_patches_per_img: int,
cls_token: bool,
d_vit: int,
n_imgs: int,
max_patches_per_shard: int,
data: dict[str, object],
dtype: Literal['float32'] = 'float32',
protocol: Literal['1.0.0'] = '1.0.0')
-
Metadata(vit_family: Literal['clip', 'siglip', 'dinov2'], vit_ckpt: str, layers: tuple[int, …], n_patches_per_img: int, cls_token: bool, d_vit: int, n_imgs: int, max_patches_per_shard: int, data: dict[str, object], dtype: Literal['float32'] = 'float32', protocol: Literal['1.0.0'] = '1.0.0')
Expand source code
@beartype.beartype @dataclasses.dataclass(frozen=True) class Metadata: vit_family: typing.Literal["clip", "siglip", "dinov2"] vit_ckpt: str layers: tuple[int, ...] n_patches_per_img: int cls_token: bool d_vit: int n_imgs: int max_patches_per_shard: int data: dict[str, object] dtype: typing.Literal["float32"] = "float32" protocol: typing.Literal["1.0.0"] = "1.0.0" def __post_init__(self): # Check that at least one image per shard can fit. assert self.n_imgs_per_shard >= 1, ( "At least one image per shard must fit; increase max_patches_per_shard." ) @classmethod def from_cfg(cls, cfg: Config) -> "Metadata": return cls( cfg.vit_family, cfg.vit_ckpt, tuple(cfg.vit_layers), cfg.n_patches_per_img, cfg.cls_token, cfg.d_vit, cfg.data.n_imgs, cfg.max_patches_per_shard, {**dataclasses.asdict(cfg.data), "__class__": cfg.data.__class__.__name__}, ) @classmethod def load(cls, shard_root: str) -> "Metadata": with open(os.path.join(shard_root, "metadata.json")) as fd: dct = json.load(fd) dct["layers"] = tuple(dct.pop("layers")) return cls(**dct) def dump(self, shard_root: str): with open(os.path.join(shard_root, "metadata.json"), "w") as fd: json.dump(dataclasses.asdict(self), fd, indent=4) @property def hash(self) -> str: cfg_bytes = json.dumps( dataclasses.asdict(self), sort_keys=True, separators=(",", ":") ).encode("utf-8") return hashlib.sha256(cfg_bytes).hexdigest() @property def n_tokens_per_img(self) -> int: return self.n_patches_per_img + int(self.cls_token) @property def n_shards(self) -> int: return math.ceil(self.n_imgs / self.n_imgs_per_shard) @property def n_imgs_per_shard(self) -> int: """ Calculate the number of images per shard based on the protocol. Returns: Number of images that fit in a shard. """ n_tokens_per_img = self.n_patches_per_img + (1 if self.cls_token else 0) return self.max_patches_per_shard // (n_tokens_per_img * len(self.layers)) @property def shard_shape(self) -> tuple[int, int, int, int]: return ( self.n_imgs_per_shard, len(self.layers), self.n_tokens_per_img, self.d_vit, )
Class variables
var cls_token : bool
var d_vit : int
var data : dict[str, object]
var dtype : Literal['float32']
var layers : tuple[int, ...]
var max_patches_per_shard : int
var n_imgs : int
var n_patches_per_img : int
var protocol : Literal['1.0.0']
var vit_ckpt : str
var vit_family : Literal['clip', 'siglip', 'dinov2']
Static methods
def from_cfg(cls, cfg: Config) ‑> Metadata
def load(cls, shard_root: str) ‑> Metadata
Instance variables
prop hash : str
-
Expand source code
@property def hash(self) -> str: cfg_bytes = json.dumps( dataclasses.asdict(self), sort_keys=True, separators=(",", ":") ).encode("utf-8") return hashlib.sha256(cfg_bytes).hexdigest()
prop n_imgs_per_shard : int
-
Calculate the number of images per shard based on the protocol.
Returns
Number of images that fit in a shard.
Expand source code
@property def n_imgs_per_shard(self) -> int: """ Calculate the number of images per shard based on the protocol. Returns: Number of images that fit in a shard. """ n_tokens_per_img = self.n_patches_per_img + (1 if self.cls_token else 0) return self.max_patches_per_shard // (n_tokens_per_img * len(self.layers))
prop n_shards : int
-
Expand source code
@property def n_shards(self) -> int: return math.ceil(self.n_imgs / self.n_imgs_per_shard)
prop n_tokens_per_img : int
-
Expand source code
@property def n_tokens_per_img(self) -> int: return self.n_patches_per_img + int(self.cls_token)
prop shard_shape : tuple[int, int, int, int]
-
Expand source code
@property def shard_shape(self) -> tuple[int, int, int, int]: return ( self.n_imgs_per_shard, len(self.layers), self.n_tokens_per_img, self.d_vit, )
Methods
def dump(self, shard_root: str)