Module saev.activations
To save lots of activations, we run many Slurm jobs in parallel and write multiple files, rather than a single one.
This module handles that additional complexity.
Conceptually, activations are thought of as either:
- A single [n_imgs x n_layers x (n_patches + 1), d_vit] tensor. This is a dataset.
- Multiple [n_imgs_per_shard, n_layers, (n_patches + 1), d_vit] tensors. This is a set of sharded activations.
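For intuition, here is a minimal sketch of how a global image index maps into the shard files. The sizes are made up for illustration; the real arithmetic lives in Metadata, ShardWriter, and Dataset.get_img_patches below.

import numpy as np

# Hypothetical sizes for illustration only.
n_layers, n_patches_per_img, d_vit = 2, 196, 768
n_patches_per_shard = 2_400_000
n_imgs_per_shard = n_patches_per_shard // n_layers // (n_patches_per_img + 1)

def img_acts(shard_root: str, i: int) -> np.ndarray:
    # Which shard file holds image i, and where inside that shard?
    shard, pos = divmod(i, n_imgs_per_shard)
    shape = (n_imgs_per_shard, n_layers, n_patches_per_img + 1, d_vit)
    acts = np.memmap(
        f"{shard_root}/acts{shard:06}.bin", mode="c", dtype=np.float32, shape=shape
    )
    return acts[pos]  # [n_layers, n_patches + 1, d_vit]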
Functions
def get_acts_dir(cfg: Activations) ‑> str
-
Return the activations directory based on the relevant values of a config. Also saves a metadata.json file to that directory for human reference.
Args
cfg
- Config for experiment.
Returns
Directory to where activations should be dumped/loaded from.
def get_dataloader(cfg: Activations, *, img_transform=None)
-
Gets the dataloader for the current experiment; delegates dataloader construction to dataset-specific functions.
Args
cfg
- Experiment config.
img_transform
- Image transform to be applied to each image.
Returns
A PyTorch Dataloader that yields dictionaries with 'image' keys containing image batches.
def get_dataset(cfg: ImagenetDataset | ImageFolderDataset | Ade20kDataset, *, img_transform)
-
Gets the dataset for the current experiment; delegates construction to dataset-specific functions.
Args
cfg
- Experiment config.
img_transform
- Image transform to be applied to each image.
Returns
A dataset that has dictionaries with
'image'
,'index'
,'target'
, and'label'
keys containing examples. def get_default_dataloader(cfg: Activations,
*,
img_transform:) ‑> torch.utils.data.dataloader.DataLoader -
Get a dataloader for a default map-style dataset.
Args
cfg
- Config.
img_transform
- Image transform to be applied to each image.
Returns
A PyTorch Dataloader that yields dictionaries with 'image' keys containing image batches, 'index' keys containing original dataset indices, and 'label' keys containing label batches.
def main(cfg: Activations)
-
Args
cfg
- Config for activations.
def make_img_transform(model_family: str, model_ckpt: str)
def make_vit(cfg: Activations)
def setup(cfg: Activations)
-
Run dataset-specific setup. These setup functions can assume they are the only job running, but they should be idempotent; they should be safe (and ideally cheap) to run multiple times in a row.
def setup_ade20k(cfg: Activations)
def setup_imagefolder(cfg: Activations)
def setup_imagenet(cfg: Activations)
def worker_fn(cfg: Activations)
-
Args
cfg
- Config for activations.
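A hypothetical end-to-end invocation. The field names are taken from the config.Activations usages in the source below; the values are placeholders, and the split between main and worker_fn (dispatch vs. actual recording) is assumed from the module intro.

from saev import activations, config

cfg = config.Activations(
    model_family="clip",           # placeholder values throughout
    model_ckpt="ViT-B-16/openai",
    layers=[6, 11],
    n_patches_per_img=196,
)
activations.setup(cfg)   # idempotent dataset preparation
activations.main(cfg)    # records ViT activations and writes shards to disk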
Classes
class Ade20k (cfg: Ade20kDataset, *, img_transform: collections.abc.Callable | None = None, seg_transform: collections.abc.Callable | None = <function Ade20k.<lambda>>)
-
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__, which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader. Subclasses could also optionally implement __getitems__ to speed up batched sample loading; this method accepts a list of sample indices for a batch and returns the list of samples.
Note
torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
Expand source code
@beartype.beartype
class Ade20k(torch.utils.data.Dataset):
    @beartype.beartype
    @dataclasses.dataclass(frozen=True)
    class Sample:
        img_path: str
        seg_path: str
        label: str
        target: int

    samples: list[Sample]

    def __init__(
        self,
        cfg: config.Ade20kDataset,
        *,
        img_transform: Callable | None = None,
        seg_transform: Callable | None = lambda x: None,
    ):
        self.logger = logging.getLogger("ade20k")
        self.cfg = cfg
        self.img_dir = os.path.join(cfg.root, "images")
        self.seg_dir = os.path.join(cfg.root, "annotations")
        self.img_transform = img_transform
        self.seg_transform = seg_transform

        # Check that we have the right path.
        for subdir in ("images", "annotations"):
            if not os.path.isdir(os.path.join(cfg.root, subdir)):
                # Something is missing.
                if os.path.realpath(cfg.root).endswith(subdir):
                    self.logger.warning(
                        "The ADE20K root should contain 'images/' and 'annotations/' directories."
                    )
                raise ValueError(f"Can't find path '{os.path.join(cfg.root, subdir)}'.")

        _, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir)
        split_lookup: dict[int, str] = {
            value: key for key, value in split_mapping.items()
        }
        self.loader = torchvision.datasets.folder.default_loader

        assert cfg.split in set(split_lookup.values())

        # Load all the image paths.
        imgs: list[str] = [
            path
            for path, s in torchvision.datasets.folder.make_dataset(
                self.img_dir,
                split_mapping,
                extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
            )
            if split_lookup[s] == cfg.split
        ]
        segs: list[str] = [
            path
            for path, s in torchvision.datasets.folder.make_dataset(
                self.seg_dir,
                split_mapping,
                extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
            )
            if split_lookup[s] == cfg.split
        ]

        # Load all the targets, classes and mappings
        with open(os.path.join(cfg.root, "sceneCategories.txt")) as fd:
            img_labels: list[str] = [line.split()[1] for line in fd.readlines()]

        label_set = sorted(set(img_labels))
        label_to_idx = {label: i for i, label in enumerate(label_set)}

        self.samples = [
            self.Sample(img_path, seg_path, label, label_to_idx[label])
            for img_path, seg_path, label in zip(imgs, segs, img_labels)
        ]

    def __getitem__(self, index: int) -> dict[str, object]:
        # Convert to dict.
        sample = dataclasses.asdict(self.samples[index])

        sample["image"] = self.loader(sample.pop("img_path"))
        if self.img_transform is not None:
            image = self.img_transform(sample.pop("image"))
            if image is not None:
                sample["image"] = image

        sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L")
        if self.seg_transform is not None:
            segmentation = self.seg_transform(sample.pop("segmentation"))
            if segmentation is not None:
                sample["segmentation"] = segmentation

        sample["index"] = index

        return sample

    def __len__(self) -> int:
        return len(self.samples)
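A hypothetical usage sketch; the root path, split name, and transform are placeholders.

from torchvision.transforms import v2
from saev import activations, config

cfg = config.Ade20kDataset(root="/data/ade20k", split="training")
dataset = activations.Ade20k(cfg, img_transform=v2.PILToTensor())
sample = dataset[0]
# sample has "image", "segmentation", "label", "target", and "index" keys.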
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
Class variables
var Sample
var samples : list[Ade20k.Sample]
class Clip (model_ckpt: str)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class Clip(torch.nn.Module):
    def __init__(self, model_ckpt: str):
        super().__init__()
        import open_clip

        if model_ckpt.startswith("hf-hub:"):
            clip, _ = open_clip.create_model_from_pretrained(
                model_ckpt, cache_dir=helpers.get_cache_dir()
            )
        else:
            arch, ckpt = model_ckpt.split("/")
            clip, _ = open_clip.create_model_from_pretrained(
                arch, pretrained=ckpt, cache_dir=helpers.get_cache_dir()
            )

        model = clip.visual
        model.proj = None
        model.output_tokens = True  # type: ignore
        self.model = model.eval()

        assert not isinstance(self.model, open_clip.timm_model.TimmModel)

    def get_residuals(self) -> list[torch.nn.Module]:
        return self.model.transformer.resblocks

    def get_patches(self, cfg: config.Activations) -> slice:
        return slice(None, None, None)

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        cls, patches = self.model(batch)
        return {"cls": cls, "patches": patches}
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch patches dim']
-
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
def get_patches(self, cfg: Activations) ‑> slice
def get_residuals(self) ‑> list[torch.nn.modules.module.Module]
class Dataset (cfg: DataLoad)
-
Dataset of activations from disk.
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class Dataset(torch.utils.data.Dataset):
    """
    Dataset of activations from disk.
    """

    class Example(typing.TypedDict):
        """Individual example."""

        act: Float[Tensor, " d_vit"]
        image_i: int
        patch_i: int

    cfg: config.DataLoad
    """Configuration; set via CLI args."""
    metadata: "Metadata"
    """Activations metadata; automatically loaded from disk."""
    layer_index: int
    """Layer index into the shards if we are choosing a specific layer."""
    scalar: float
    """Normalizing scalar such that ||x / scalar ||_2 ~= sqrt(d_vit)."""
    act_mean: Float[Tensor, " d_vit"]
    """Mean activation."""

    def __init__(self, cfg: config.DataLoad):
        self.cfg = cfg
        if not os.path.isdir(self.cfg.shard_root):
            raise RuntimeError(f"Activations are not saved at '{self.cfg.shard_root}'.")

        metadata_fpath = os.path.join(self.cfg.shard_root, "metadata.json")
        self.metadata = Metadata.load(metadata_fpath)

        # Pick a really big number so that if you accidentally use this when you shouldn't, you get an out of bounds IndexError.
        self.layer_index = 1_000_000
        if isinstance(self.cfg.layer, int):
            err_msg = f"Non-exact matches for .layer field not supported; {self.cfg.layer} not in {self.metadata.layers}."
            assert self.cfg.layer in self.metadata.layers, err_msg
            self.layer_index = self.metadata.layers.index(self.cfg.layer)

        # Preemptively set these values so that preprocess() doesn't freak out.
        self.scalar = 1.0
        self.act_mean = torch.zeros(self.d_vit)

        # If either of these are true, we must do this work.
        if self.cfg.scale_mean is True or self.cfg.scale_norm is True:
            # Load a random subset of samples to calculate the mean activation and mean L2 norm.
            perm = np.random.default_rng(seed=42).permutation(len(self))
            perm = perm[: cfg.n_random_samples]

            samples = [
                self[p.item()]
                for p in helpers.progress(
                    perm, every=25_000, desc="examples to calc means"
                )
            ]
            samples = torch.stack([sample["act"] for sample in samples])
            if samples.abs().max() > 1e3:
                raise ValueError(
                    f"You found an abnormally large activation {samples.abs().max().item():.5f} that will mess up your L2 mean."
                )

            # Activation mean
            if self.cfg.scale_mean:
                self.act_mean = samples.mean(axis=0)
                if (self.act_mean > 1e3).any():
                    raise ValueError(
                        "You found an abnormally large activation that is messing up your activation mean."
                    )

            # Norm
            if self.cfg.scale_norm:
                l2_mean = torch.linalg.norm(samples - self.act_mean, axis=1).mean()
                if l2_mean > 1e3:
                    raise ValueError(
                        "You found an abnormally large activation that is messing up your L2 mean."
                    )
                self.scalar = l2_mean / math.sqrt(self.d_vit)
        elif isinstance(self.cfg.scale_mean, str):
            # Load mean activations from disk
            self.act_mean = torch.load(self.cfg.scale_mean)
        elif isinstance(self.cfg.scale_norm, str):
            # Load scalar normalization from disk
            self.scalar = torch.load(self.cfg.scale_norm).item()

    def transform(self, act: Float[np.ndarray, " d_vit"]) -> Float[Tensor, " d_vit"]:
        """
        Apply a scalar normalization so the mean squared L2 norm is the same as d_vit. This is from 'Scaling Monosemanticity':

        > As a preprocessing step we apply a scalar normalization to the model activations so their average squared L2 norm is the residual stream dimension

        So we divide by self.scalar, which is the dataset's (approximate) mean L2 norm before normalization, divided by sqrt(d_vit).
        """
        act = torch.from_numpy(act.copy())
        act = act.clamp(-self.cfg.clamp, self.cfg.clamp)
        return (act - self.act_mean) / self.scalar

    @property
    def d_vit(self) -> int:
        """Dimension of the underlying vision transformer's embedding space."""
        return self.metadata.d_vit

    @jaxtyped(typechecker=beartype.beartype)
    def __getitem__(self, i: int) -> Example:
        match (self.cfg.patches, self.cfg.layer):
            case ("cls", int()):
                img_act = self.get_img_patches(i)
                # Select layer's cls token.
                act = img_act[self.layer_index, 0, :]
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("cls", "meanpool"):
                img_act = self.get_img_patches(i)
                # Select cls tokens from across all layers
                cls_act = img_act[:, 0, :]
                # Meanpool over the layers
                act = cls_act.mean(axis=0)
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("meanpool", int()):
                img_act = self.get_img_patches(i)
                # Select layer's patches.
                layer_act = img_act[self.layer_index, 1:, :]
                # Meanpool over the patches
                act = layer_act.mean(axis=0)
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("meanpool", "meanpool"):
                img_act = self.get_img_patches(i)
                # Select all layers' patches.
                act = img_act[:, 1:, :]
                # Meanpool over the layers and patches
                act = act.mean(axis=(0, 1))
                return self.Example(act=self.transform(act), image_i=i, patch_i=-1)
            case ("patches", int()):
                n_imgs_per_shard = (
                    self.metadata.n_patches_per_shard
                    // len(self.metadata.layers)
                    // (self.metadata.n_patches_per_img + 1)
                )
                n_examples_per_shard = (
                    n_imgs_per_shard * self.metadata.n_patches_per_img
                )

                shard = i // n_examples_per_shard
                pos = i % n_examples_per_shard

                acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin")
                shape = (
                    n_imgs_per_shard,
                    len(self.metadata.layers),
                    self.metadata.n_patches_per_img + 1,
                    self.metadata.d_vit,
                )
                acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape)
                # Choose the layer and the non-CLS tokens.
                acts = acts[:, self.layer_index, 1:]

                # Choose the image within the shard and the patch within the image.
                act = acts[
                    pos // self.metadata.n_patches_per_img,
                    pos % self.metadata.n_patches_per_img,
                ]
                return self.Example(
                    act=self.transform(act),
                    # What image is this?
                    image_i=i // self.metadata.n_patches_per_img,
                    patch_i=i % self.metadata.n_patches_per_img,
                )
            case _:
                print((self.cfg.patches, self.cfg.layer))
                typing.assert_never((self.cfg.patches, self.cfg.layer))

    def get_shard_patches(self):
        raise NotImplementedError()

    def get_img_patches(
        self, i: int
    ) -> Float[np.ndarray, "n_layers all_patches d_vit"]:
        n_imgs_per_shard = (
            self.metadata.n_patches_per_shard
            // len(self.metadata.layers)
            // (self.metadata.n_patches_per_img + 1)
        )
        shard = i // n_imgs_per_shard
        pos = i % n_imgs_per_shard
        acts_fpath = os.path.join(self.cfg.shard_root, f"acts{shard:06}.bin")
        shape = (
            n_imgs_per_shard,
            len(self.metadata.layers),
            self.metadata.n_patches_per_img + 1,
            self.metadata.d_vit,
        )
        acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape)
        # Note that this is not yet copied!
        return acts[pos]

    def __len__(self) -> int:
        """
        Dataset length depends on `patches` and `layer`.
        """
        match (self.cfg.patches, self.cfg.layer):
            case ("cls", "all"):
                # Return a CLS token from a random image and random layer.
                return self.metadata.n_imgs * len(self.metadata.layers)
            case ("cls", int()):
                # Return a CLS token from a random image and fixed layer.
                return self.metadata.n_imgs
            case ("cls", "meanpool"):
                # Return a CLS token from a random image, meanpooled over all layers.
                return self.metadata.n_imgs
            case ("meanpool", "all"):
                # Return the meanpool of all patches from a random image and random layer.
                return self.metadata.n_imgs * len(self.metadata.layers)
            case ("meanpool", int()):
                # Return the meanpool of all patches from a random image and fixed layer.
                return self.metadata.n_imgs
            case ("meanpool", "meanpool"):
                # Return the meanpool of all patches from a random image, meanpooled over all layers.
                return self.metadata.n_imgs
            case ("patches", int()):
                # Return a patch from a random image and fixed layer.
                return self.metadata.n_imgs * self.metadata.n_patches_per_img
            case ("patches", "meanpool"):
                # Return a patch from a random image, meanpooled over all layers.
                return self.metadata.n_imgs * self.metadata.n_patches_per_img
            case ("patches", "all"):
                # Return a patch from a random image, random layer, and random patch.
                return (
                    self.metadata.n_imgs
                    * len(self.metadata.layers)
                    * self.metadata.n_patches_per_img
                )
            case _:
                typing.assert_never((self.cfg.patches, self.cfg.layer))
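A hypothetical loading sketch; shard_root is a placeholder, and the (patches, layer) combinations follow the cases in __getitem__ and __len__ above.

import torch.utils.data
from saev import activations, config

cfg = config.DataLoad(shard_root="/scratch/acts", patches="patches", layer=11)
dataset = activations.Dataset(cfg)
loader = torch.utils.data.DataLoader(dataset, batch_size=4096, shuffle=True)
batch = next(iter(loader))
batch["act"].shape  # (4096, dataset.d_vit)
batch["image_i"]    # which image each activation came from
batch["patch_i"]    # which patch within that image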
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
Class variables
var Example
-
Individual example.
var act_mean : jaxtyping.Float[Tensor, 'd_vit']
-
Mean activation.
var cfg : DataLoad
-
Configuration; set via CLI args.
var layer_index : int
-
Layer index into the shards if we are choosing a specific layer.
var metadata : Metadata
-
Activations metadata; automatically loaded from disk.
var scalar : float
-
Normalizing scalar such that ||x / scalar ||_2 ~= sqrt(d_vit).
Instance variables
prop d_vit : int
-
Dimension of the underlying vision transformer's embedding space.
Expand source code
@property
def d_vit(self) -> int:
    """Dimension of the underlying vision transformer's embedding space."""
    return self.metadata.d_vit
Methods
def get_img_patches(self, i: int) ‑> jaxtyping.Float[ndarray, 'n_layers all_patches d_vit']
def get_shard_patches(self)
def transform(self, act: jaxtyping.Float[ndarray, 'd_vit']) ‑> jaxtyping.Float[Tensor, 'd_vit']
-
Apply a scalar normalization so the mean squared L2 norm is the same as d_vit. This is from 'Scaling Monosemanticity':
As a preprocessing step we apply a scalar normalization to the model activations so their average squared L2 norm is the residual stream dimension
So we divide by self.scalar, which is the dataset's (approximate) mean L2 norm before normalization, divided by sqrt(d_vit).
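A quick numeric check of this scheme, with random data standing in for real activations:

import math
import torch

d_vit = 768
acts = torch.randn(10_000, d_vit) * 3.0  # stand-in activations
act_mean = acts.mean(dim=0)
l2_mean = torch.linalg.norm(acts - act_mean, dim=1).mean()
scalar = l2_mean / math.sqrt(d_vit)

normalized = (acts - act_mean) / scalar
normalized.norm(dim=1).pow(2).mean()  # approximately d_vit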
class DinoV2 (model_ckpt: str)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class DinoV2(torch.nn.Module):
    def __init__(self, model_ckpt: str):
        super().__init__()
        self.model = torch.hub.load("facebookresearch/dinov2", model_ckpt)

    def get_residuals(self) -> list[torch.nn.Module]:
        return self.model.blocks

    def get_patches(self, cfg: config.Activations) -> slice:
        n_reg = self.model.num_register_tokens
        patches = torch.cat((
            torch.tensor([0]),  # CLS token
            torch.arange(n_reg + 1, n_reg + 1 + cfg.n_patches_per_img),  # patches
        ))
        return patches

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        dct = self.model.forward_features(batch)

        features = torch.cat(
            (dct["x_norm_clstoken"][:, None, :], dct["x_norm_patchtokens"]), axis=1
        )
        return features
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch patches dim']
-
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
def get_patches(self, cfg: Activations) ‑> slice
def get_residuals(self) ‑> list[torch.nn.modules.module.Module]
class ImageFolder (root: str | pathlib.Path, transform: Callable | None = None, target_transform: Callable | None = None, loader: Callable[[str], Any] = <function default_loader>, is_valid_file: Callable[[str], bool] | None = None, allow_empty: bool = False)
-
A generic data loader where the images are arranged in this way by default::
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
This class inherits from torchvision.datasets.DatasetFolder, so the same methods can be overridden to customize the dataset.
Args
root (str or pathlib.Path)
- Root directory path.
transform (callable, optional)
- A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
target_transform (callable, optional)
- A function/transform that takes in the target and transforms it.
loader (callable, optional)
- A function to load an image given its path.
is_valid_file (callable, optional)
- A function that takes the path of an image file and checks if the file is valid (used to filter out corrupt files).
allow_empty (bool, optional)
- If True, empty folders are considered to be valid classes. An error is raised on empty folders if False (default).
Attributes
classes (list)
- List of the class names sorted alphabetically.
class_to_idx (dict)
- Dict with items (class_name, class_index).
imgs (list)
- List of (image path, class_index) tuples.
Expand source code
@beartype.beartype
class ImageFolder(torchvision.datasets.ImageFolder):
    def __getitem__(self, index: int) -> dict[str, object]:
        """
        Args:
            index: Index

        Returns:
            dict with keys 'image', 'index', 'target' and 'label'.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return {
            "image": sample,
            "target": target,
            "label": self.classes[target],
            "index": index,
        }
Ancestors
- torchvision.datasets.folder.ImageFolder
- torchvision.datasets.folder.DatasetFolder
- torchvision.datasets.vision.VisionDataset
- torch.utils.data.dataset.Dataset
- typing.Generic
class Imagenet (cfg: ImagenetDataset, *, img_transform=None)
-
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__, which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader. Subclasses could also optionally implement __getitems__ to speed up batched sample loading; this method accepts a list of sample indices for a batch and returns the list of samples.
Note
torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
Expand source code
@beartype.beartype
class Imagenet(torch.utils.data.Dataset):
    def __init__(self, cfg: config.ImagenetDataset, *, img_transform=None):
        import datasets

        self.hf_dataset = datasets.load_dataset(
            cfg.name, split=cfg.split, trust_remote_code=True
        )

        self.img_transform = img_transform
        self.labels = self.hf_dataset.info.features["label"].names

    def __getitem__(self, i):
        example = self.hf_dataset[i]
        example["index"] = i

        example["image"] = example["image"].convert("RGB")
        if self.img_transform:
            example["image"] = self.img_transform(example["image"])
        example["target"] = example.pop("label")
        example["label"] = self.labels[example["target"]]

        return example

    def __len__(self) -> int:
        return len(self.hf_dataset)
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
class Metadata (model_family: str, model_ckpt: str, layers: tuple[int, ...], n_patches_per_img: int, cls_token: bool, d_vit: int, seed: int, n_imgs: int, n_patches_per_shard: int, data: str)
-
Metadata(model_family: str, model_ckpt: str, layers: tuple[int, …], n_patches_per_img: int, cls_token: bool, d_vit: int, seed: int, n_imgs: int, n_patches_per_shard: int, data: str)
Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Metadata:
    model_family: str
    model_ckpt: str
    layers: tuple[int, ...]
    n_patches_per_img: int
    cls_token: bool
    d_vit: int
    seed: int
    n_imgs: int
    n_patches_per_shard: int
    data: str

    @classmethod
    def from_cfg(cls, cfg: config.Activations) -> "Metadata":
        return cls(
            cfg.model_family,
            cfg.model_ckpt,
            tuple(cfg.layers),
            cfg.n_patches_per_img,
            cfg.cls_token,
            cfg.d_vit,
            cfg.seed,
            cfg.data.n_imgs,
            cfg.n_patches_per_shard,
            str(cfg.data),
        )

    @classmethod
    def load(cls, fpath) -> "Metadata":
        with open(fpath) as fd:
            dct = json.load(fd)
        dct["layers"] = tuple(dct.pop("layers"))
        return cls(**dct)

    def dump(self, fpath):
        with open(fpath, "w") as fd:
            json.dump(dataclasses.asdict(self), fd, indent=4)

    @property
    def hash(self) -> str:
        cfg_str = json.dumps(dataclasses.asdict(self), sort_keys=True)
        return hashlib.sha256(cfg_str.encode("utf-8")).hexdigest()
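For example (the path is a placeholder):

meta = Metadata.load("/scratch/acts/metadata.json")
meta.hash  # deterministic sha256 over the config fields

# Shard geometry implied by the metadata, matching ShardWriter below:
n_tokens = meta.n_patches_per_img + (1 if meta.cls_token else 0)
n_imgs_per_shard = meta.n_patches_per_shard // len(meta.layers) // n_tokens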
Class variables
var cls_token : bool
var d_vit : int
var data : str
var layers : tuple[int, ...]
var model_ckpt : str
var model_family : str
var n_imgs : int
var n_patches_per_img : int
var n_patches_per_shard : int
var seed : int
Static methods
def from_cfg(cls, cfg: Activations) ‑> Metadata
def load(cls, fpath) ‑> Metadata
Instance variables
prop hash : str
-
Expand source code
@property
def hash(self) -> str:
    cfg_str = json.dumps(dataclasses.asdict(self), sort_keys=True)
    return hashlib.sha256(cfg_str.encode("utf-8")).hexdigest()
Methods
def dump(self, fpath)
class Moondream2 (model_ckpt: str)
-
Moondream2 has 14x14 pixel patches. For a 378x378 image (as we use here), this is 27x27 patches for a total of 729, with no [CLS] token.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class Moondream2(torch.nn.Module):
    """
    Moondream2 has 14x14 pixel patches. For a 378x378 image (as we use here), this is 27x27 patches for a total of 729, with no [CLS] token.
    """

    def __init__(self, model_ckpt: str):
        super().__init__()
        import transformers

        model_id, revision = model_ckpt.split(":")
        mllm = transformers.AutoModelForCausalLM.from_pretrained(
            model_id, revision=revision, trust_remote_code=True
        )
        self.model = mllm.vision_encoder.encoder.model.visual

    def get_patches(self, cfg: config.Activations) -> slice:
        return slice(None, None, None)

    def get_residuals(self) -> list[torch.nn.Module]:
        return self.model.blocks

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        features = self.model(batch)
        return features
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch patches dim']
-
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
def get_patches(self, cfg: Activations) ‑> slice
def get_residuals(self) ‑> list[torch.nn.modules.module.Module]
class Recorder (cfg: Activations, vit: torch.nn.modules.module.Module)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class Recorder(torch.nn.Module):
    cfg: config.Activations
    _storage: Float[Tensor, "batch n_layers all_patches dim"] | None
    _i: int

    def __init__(self, cfg: config.Activations, vit: torch.nn.Module):
        super().__init__()
        self.cfg = cfg
        self.patches = vit.get_patches(cfg)

        self._storage = None
        self._i = 0

        self.logger = logging.getLogger(
            f"recorder({cfg.model_family}:{cfg.model_ckpt})"
        )

    def register(self, modules: list[torch.nn.Module]):
        for i in self.cfg.layers:
            modules[i].register_forward_hook(self.hook)
        return self

    def hook(
        self, module, args: tuple, output: Float[Tensor, "batch n_layers dim"]
    ) -> None:
        if self._storage is None:
            batch, _, dim = output.shape
            self._storage = self._empty_storage(batch, dim, output.device)

        if self._storage[:, self._i, 0, :].shape != output[:, 0, :].shape:
            batch, _, dim = output.shape
            old_batch, _, _, old_dim = self._storage.shape
            msg = "Output shape does not match storage shape: (batch) %d != %d or (dim) %d != %d"
            self.logger.warning(msg, old_batch, batch, old_dim, dim)

            self._storage = self._empty_storage(batch, dim, output.device)

        self._storage[:, self._i] = output[:, self.patches, :].detach()
        self._i += 1

    def _empty_storage(self, batch: int, dim: int, device: torch.device):
        n_patches_per_img = self.cfg.n_patches_per_img
        if self.cfg.cls_token:
            n_patches_per_img += 1

        return torch.zeros(
            (batch, len(self.cfg.layers), n_patches_per_img, dim), device=device
        )

    def reset(self):
        self._i = 0

    @property
    def activations(self) -> Float[Tensor, "batch n_layers all_patches dim"]:
        if self._storage is None:
            raise RuntimeError("First call model()")
        return self._storage.cpu()
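A sketch of the recording flow; cfg and batch are assumed to exist, and WrappedVisionTransformer below packages this exact pattern.

vit = make_vit(cfg)
recorder = Recorder(cfg, vit).register(vit.get_residuals())

recorder.reset()
_ = vit(batch)               # forward hooks fire on the configured layers
acts = recorder.activations  # [batch, n_layers, all_patches, dim], on CPU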
Ancestors
- torch.nn.modules.module.Module
Class variables
var cfg : Activations
Instance variables
prop activations : jaxtyping.Float[Tensor, 'batch n_layers all_patches dim']
-
Expand source code
@property
def activations(self) -> Float[Tensor, "batch n_layers all_patches dim"]:
    if self._storage is None:
        raise RuntimeError("First call model()")
    return self._storage.cpu()
Methods
def hook(self, module, args: tuple, output: jaxtyping.Float[Tensor, 'batch n_layers dim']) ‑> None
def register(self, modules: list[torch.nn.modules.module.Module])
def reset(self)
class ShardWriter (cfg: Activations)
-
ShardWriter is a stateful object that handles sharded activation writing to disk.
Expand source code
@beartype.beartype
class ShardWriter:
    """
    ShardWriter is a stateful object that handles sharded activation writing to disk.
    """

    root: str
    shape: tuple[int, int, int, int]
    shard: int
    acts_path: str
    acts: Float[np.ndarray, "n_imgs_per_shard n_layers all_patches d_vit"] | None
    filled: int

    def __init__(self, cfg: config.Activations):
        self.logger = logging.getLogger("shard-writer")

        self.root = get_acts_dir(cfg)

        n_patches_per_img = cfg.n_patches_per_img
        if cfg.cls_token:
            n_patches_per_img += 1
        self.n_imgs_per_shard = (
            cfg.n_patches_per_shard // len(cfg.layers) // n_patches_per_img
        )
        self.shape = (
            self.n_imgs_per_shard,
            len(cfg.layers),
            n_patches_per_img,
            cfg.d_vit,
        )

        self.shard = -1
        self.acts = None
        self.next_shard()

    @jaxtyped(typechecker=beartype.beartype)
    def __setitem__(
        self, i: slice, val: Float[Tensor, "_ n_layers all_patches d_vit"]
    ) -> None:
        assert i.step is None
        a, b = i.start, i.stop
        assert len(val) == b - a

        offset = self.n_imgs_per_shard * self.shard

        if b >= offset + self.n_imgs_per_shard:
            # We have run out of space in this mmap'ed file. Let's fill it as much as we can.
            n_fit = offset + self.n_imgs_per_shard - a
            self.acts[a - offset : a - offset + n_fit] = val[:n_fit]
            self.filled = a - offset + n_fit

            self.next_shard()

            # Recursively call __setitem__ in case we need *another* shard
            self[a + n_fit : b] = val[n_fit:]
        else:
            msg = f"0 <= {a} - {offset} <= {offset} + {self.n_imgs_per_shard}"
            assert 0 <= a - offset <= offset + self.n_imgs_per_shard, msg
            msg = f"0 <= {b} - {offset} <= {offset} + {self.n_imgs_per_shard}"
            assert 0 <= b - offset <= offset + self.n_imgs_per_shard, msg
            self.acts[a - offset : b - offset] = val
            self.filled = b - offset

    def flush(self) -> None:
        if self.acts is not None:
            self.acts.flush()

        self.acts = None

    def next_shard(self) -> None:
        self.flush()

        self.shard += 1
        self._count = 0
        self.acts_path = os.path.join(self.root, f"acts{self.shard:06}.bin")
        self.acts = np.memmap(
            self.acts_path, mode="w+", dtype=np.float32, shape=self.shape
        )
        self.filled = 0

        self.logger.info("Opened shard '%s'.", self.acts_path)
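A hypothetical writing loop against the __setitem__ contract above; wrapped_vit and dataloader are assumed to exist, and ShardWriter rolls over to the next shard file automatically when one fills up.

writer = ShardWriter(cfg)
offset = 0
for images in dataloader:
    _, acts = wrapped_vit(images)  # acts: [b, n_layers, all_patches, d_vit]
    writer[offset : offset + len(acts)] = acts
    offset += len(acts)
writer.flush()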
Class variables
var acts : jaxtyping.Float[ndarray, 'n_imgs_per_shard n_layers all_patches d_vit'] | None
var acts_path : str
var filled : int
var root : str
var shape : tuple[int, int, int, int]
var shard : int
Methods
def flush(self) ‑> None
def next_shard(self) ‑> None
class Siglip (model_ckpt: str)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class Siglip(torch.nn.Module):
    def __init__(self, model_ckpt: str):
        super().__init__()
        import open_clip

        if model_ckpt.startswith("hf-hub:"):
            clip, _ = open_clip.create_model_from_pretrained(
                model_ckpt, cache_dir=helpers.get_cache_dir()
            )
        else:
            arch, ckpt = model_ckpt.split("/")
            clip, _ = open_clip.create_model_from_pretrained(
                arch, pretrained=ckpt, cache_dir=helpers.get_cache_dir()
            )

        model = clip.visual
        model.proj = None
        model.output_tokens = True  # type: ignore
        self.model = model

        assert isinstance(self.model, open_clip.timm_model.TimmModel)

    def get_residuals(self) -> list[torch.nn.Module]:
        return self.model.trunk.blocks

    def get_patches(self, cfg: config.Activations) -> slice:
        return slice(None, None, None)

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        result = self.model(batch)
        return result
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch patches dim']
-
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
def get_patches(self, cfg: Activations) ‑> slice
def get_residuals(self) ‑> list[torch.nn.modules.module.Module]
class WrappedVisionTransformer (cfg: Activations)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class WrappedVisionTransformer(torch.nn.Module):
    def __init__(self, cfg: config.Activations):
        super().__init__()
        self.vit = make_vit(cfg)
        self.recorder = Recorder(cfg, self.vit).register(self.vit.get_residuals())

    def forward(
        self, batch: Float[Tensor, "batch 3 width height"]
    ) -> tuple[Float[Tensor, "batch patches dim"], Float[Tensor, "..."]]:
        self.recorder.reset()
        result = self.vit(batch)
        return result, self.recorder.activations
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> tuple[jaxtyping.Float[Tensor, 'batch patches dim'], jaxtyping.Float[Tensor, '...']]
-
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.