Module saev.visuals
There is some important notation used only in this file to dramatically shorten variable names: variables suffixed with _im refer to entire images, and variables suffixed with _p refer to patches.
Functions
def batched_idx(total_size: int, batch_size: int) ‑> collections.abc.Iterator[tuple[int, int]]
-
Iterate over (start, end) indices for total_size examples, where end - start is at most batch_size.
Args
total_size
- total number of examples
batch_size
- maximum distance between the generated indices.
Returns
A generator of (int, int) tuples that can slice up a list or a tensor.
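A minimal sketch of such a generator, consistent with the signature above (not necessarily the exact implementation):

def batched_idx(total_size: int, batch_size: int):
    # Walk the range in strides of batch_size, clamping the final chunk.
    for start in range(0, total_size, batch_size):
        yield start, min(start + batch_size, total_size)

For example, list(batched_idx(10, 4)) yields [(0, 4), (4, 8), (8, 10)].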
def dump_activations(cfg: Visuals)
-
Dump ViT activation statistics for later use.
The dataset described by cfg is processed to find the images or patches that maximally activate each SAE latent. Various tensors summarising these activations are then written to cfg.root so they can be loaded by other tools.
Args
cfg
- options controlling which activations are processed and where the resulting files are saved.
Returns
None. All data is saved to disk.
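The saved tensors can later be read back with safe_load (documented below). A hypothetical example; dump_activations controls the actual file layout under cfg.root, so the filename here is made up:

import os.path

top_values = safe_load(os.path.join(cfg.root, "top_values.pt"))  # hypothetical filename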
def gather_batched(value: jaxtyping.Float[Tensor, 'batch n dim'],
i: jaxtyping.Int[Tensor, 'batch k']) ‑> jaxtyping.Float[Tensor, 'batch k dim']
def get_new_topk(val1: jaxtyping.Float[Tensor, 'd_sae k'],
i1: jaxtyping.Int[Tensor, 'd_sae k'],
val2: jaxtyping.Float[Tensor, 'd_sae k'],
i2: jaxtyping.Int[Tensor, 'd_sae k'],
k: int) ‑> tuple[jaxtyping.Float[Tensor, 'd_sae k'], jaxtyping.Int[Tensor, 'd_sae k']]
-
Picks out the new top k values among val1 and val2. Also keeps track of i1 and i2, the indices of those values in the original dataset.
Args
val1
- top k original SAE values.
i1
- the patch indices of those original top k values.
val2
- top k incoming SAE values.
i2
- the patch indices of those incoming top k values.
k
- the number of top values (and indices) to keep.
Returns
The new top k values and their patch indices.
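A minimal sketch of the merge using plain torch ops (concatenate the running and incoming candidates, then re-select the top k per latent); variable names follow the signature above, but this is not necessarily the exact implementation:

import torch

def get_new_topk(val1, i1, val2, i2, k):
    # Stack running and incoming candidates along the k dimension: (d_sae, 2k).
    all_vals = torch.cat([val1, val2], dim=1)
    all_i = torch.cat([i1, i2], dim=1)
    # Keep the k largest values per latent, carrying their dataset indices along.
    top_vals, which = torch.topk(all_vals, k=k, dim=1)
    top_i = torch.gather(all_i, dim=1, index=which)
    return top_vals, top_i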
def get_sae_acts(vit_acts: jaxtyping.Float[Tensor, 'n d_vit'],
sae: SparseAutoencoder,
cfg: Visuals) ‑> jaxtyping.Float[Tensor, 'n d_sae']
-
Get SAE hidden layer activations for a batch of ViT activations.
Args
vit_acts
- Batch of ViT activations
sae
- Sparse autoencoder.
cfg
- Experimental config.
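A minimal sketch, assuming the SAE exposes an encode method and that cfg carries a device and a batch size; sae_batch_size, device, and encode are hypothetical names here, and the real SparseAutoencoder and Visuals interfaces may differ:

import torch

@torch.inference_mode()
def get_sae_acts(vit_acts, sae, cfg):
    acts = []
    for start, end in batched_idx(len(vit_acts), cfg.sae_batch_size):
        # Hypothetical encode() call; the real SAE API may differ.
        f_x = sae.encode(vit_acts[start:end].to(cfg.device))
        acts.append(f_x.cpu())
    return torch.cat(acts, dim=0)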
def get_topk_img(cfg: Visuals) ‑> TopKImg
-
Gets the top k images for each latent in the SAE. The top k images for latent i are sorted by
max over all images: f_x(cls)[i]
Thus, we will never have duplicate images for a given latent. But we also will not have patch-level activations (which would otherwise make a nice heatmap).
Args
cfg
- Config.
Returns
A TopKImg whose distributions field holds the first m latents' activation distributions.
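A plausible streaming pattern for computing these top-k tensors is to score each batch with get_sae_acts, then fold the batch's winners into the running top k with get_new_topk. A sketch with made-up names (d_sae, k, n_imgs, batch_size, and vit_acts are stand-ins, not the actual implementation):

import torch

top_values = torch.full((d_sae, k), -torch.inf)
top_i = torch.zeros((d_sae, k), dtype=torch.long)
for start, end in batched_idx(n_imgs, batch_size):
    acts = get_sae_acts(vit_acts[start:end], sae, cfg)  # (batch, d_sae)
    # Per-latent top k within this batch (assumes batch_size >= k).
    vals, idx = torch.topk(acts.T, k=k, dim=1)
    # Shift within-batch indices to dataset indices before merging.
    top_values, top_i = get_new_topk(top_values, top_i, vals, idx + start, k)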
def get_topk_patch(cfg: Visuals) ‑> TopKPatch
-
Gets the top k images for each latent in the SAE. The top k images for latent i are sorted by
max over all patches: f_x(patch)[i]
Thus, we could end up with duplicate images in the top k if an image has more than one patch that maximally activates an SAE latent.
Args
cfg
- Config.
Returns
A TopKPatch whose distributions field holds m randomly sampled activation distributions.
def main(cfg: Visuals)
-
TODO
document this function.
Dump top-k images to a directory.
Args
cfg
- Configuration object.
def make_img(elem: GridElement,
*,
upper: float | None = None) ‑> PIL.Image.Image
def plot_activation_distributions(cfg: Visuals,
distributions: jaxtyping.Float[Tensor, 'm n'])
def safe_load(path: str) ‑> object
def test_online_quantile_estimation(true: float, percentile: float)
Classes
class GridElement (img: PIL.Image.Image, label: str, patches: jaxtyping.Float[Tensor, 'n_patches'])
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass
class GridElement:
    img: Image.Image
    label: str
    patches: Float[Tensor, " n_patches"]
Class variables
var img : PIL.Image.Image
var label : str
var patches : jaxtyping.Float[Tensor, 'n_patches']
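For example, a GridElement can be built from any image plus its per-patch activations and rendered with make_img; the file name and patch count here are made up:

from PIL import Image
import torch

patches = torch.rand(256)  # hypothetical: one activation per patch
elem = GridElement(img=Image.open("example.jpg"), label="example", patches=patches)
heatmap = make_img(elem, upper=patches.max().item())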
class PercentileEstimator (percentile: float | int,
total: int,
lr: float = 0.001,
shape: tuple[int, ...] = ())
-
@beartype.beartype
class PercentileEstimator:
    def __init__(
        self,
        percentile: float | int,
        total: int,
        lr: float = 1e-3,
        shape: tuple[int, ...] = (),
    ):
        self.percentile = percentile
        self.total = total
        self.lr = lr

        self._estimate = torch.zeros(shape)
        self._step = 0

    def update(self, x):
        """
        Update the estimator with a new value.

        Uses a stochastic-approximation update: the estimate moves by a linearly decaying step size in the direction of sign(x - estimate), plus a constant bias 2 * percentile / 100 - 1 that makes the estimate converge to the requested percentile rather than the median.

        Arguments:
            x: The new value to incorporate into the estimation
        """
        self._step += 1

        step_size = self.lr * (self.total - self._step) / self.total

        # Is a no-op if it's already on the same device.
        self._estimate = self._estimate.to(x.device)

        self._estimate += step_size * (
            torch.sign(x - self._estimate) + 2 * self.percentile / 100 - 1.0
        )

    @property
    def estimate(self):
        return self._estimate
Instance variables
prop estimate
-
@property
def estimate(self):
    return self._estimate
Methods
def update(self, x)
-
Update the estimator with a new value.
Each call applies a stochastic-approximation update: the estimate moves by a linearly decaying step size in the direction of sign(x - estimate), plus a constant bias 2 * percentile / 100 - 1 that makes the estimate converge to the requested percentile rather than the median.
Arguments
x: The new value to incorporate into the estimation
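For example, estimating the 95th percentile of a stream of standard-normal scalars, mirroring what test_online_quantile_estimation above checks (the data here is synthetic):

import torch

est = PercentileEstimator(percentile=95, total=10_000)
for x in torch.randn(10_000):
    est.update(x)
print(est.estimate)  # approaches the true 95th percentile of N(0, 1), about 1.64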
class TopKImg (top_values: jaxtyping.Float[Tensor, 'd_sae k'],
top_i: jaxtyping.Int[Tensor, 'd_sae k'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
sparsity: jaxtyping.Float[Tensor, 'd_sae'],
distributions: jaxtyping.Float[Tensor, 'm n'],
percentiles: jaxtyping.Float[Tensor, 'd_sae'])
-
TODO
Document this class.
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class TopKImg:
    ".. todo:: Document this class."

    top_values: Float[Tensor, "d_sae k"]
    top_i: Int[Tensor, "d_sae k"]
    mean_values: Float[Tensor, " d_sae"]
    sparsity: Float[Tensor, " d_sae"]
    distributions: Float[Tensor, "m n"]
    percentiles: Float[Tensor, " d_sae"]
Class variables
var distributions : jaxtyping.Float[Tensor, 'm n']
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
var percentiles : jaxtyping.Float[Tensor, 'd_sae']
var sparsity : jaxtyping.Float[Tensor, 'd_sae']
var top_i : jaxtyping.Int[Tensor, 'd_sae k']
var top_values : jaxtyping.Float[Tensor, 'd_sae k']
class TopKPatch (top_values: jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img'],
top_i: jaxtyping.Int[Tensor, 'd_sae k'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
sparsity: jaxtyping.Float[Tensor, 'd_sae'],
distributions: jaxtyping.Float[Tensor, 'm n'],
percentiles: jaxtyping.Float[Tensor, 'd_sae'])
-
TODO
Document this class.
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class TopKPatch:
    ".. todo:: Document this class."

    top_values: Float[Tensor, "d_sae k n_patches_per_img"]
    top_i: Int[Tensor, "d_sae k"]
    mean_values: Float[Tensor, " d_sae"]
    sparsity: Float[Tensor, " d_sae"]
    distributions: Float[Tensor, "m n"]
    percentiles: Float[Tensor, " d_sae"]
Class variables
var distributions : jaxtyping.Float[Tensor, 'm n']
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
var percentiles : jaxtyping.Float[Tensor, 'd_sae']
var sparsity : jaxtyping.Float[Tensor, 'd_sae']
var top_i : jaxtyping.Int[Tensor, 'd_sae k']
var top_values : jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img']
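Because TopKPatch.top_values keeps every patch activation for each of the k images, a per-image heatmap for latent i is just a slice. A sketch where i, j, and the grid shape are stand-ins, assuming a 16x16 patch grid (256 patches, as for a 224px image with 14px patches):

# Activations of all patches of the j-th top image for latent i.
patches = topk.top_values[i, j]  # (n_patches_per_img,)
heat = patches.reshape(16, 16)   # hypothetical grid shape; depends on the ViT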