Module saev.visuals
There is some important notation used only in this file to dramatically shorten variable names: variables suffixed with _im refer to entire images, and variables suffixed with _p refer to patches.
def batched_idx(total_size: int, batch_size: int) ‑> Iterator[tuple[int, int]]
Iterate over (start, end) indices for total_size examples, where end - start is at most batch_size.
- total_size: total number of examples.
- batch_size: maximum distance between the generated indices.
A generator of (int, int) tuples that can slice up a list or a tensor.
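A minimal sketch of the behavior described above. It assumes that the last batch is simply shorter when batch_size does not divide total_size; the actual implementation may differ in details.

```python
from collections.abc import Iterator


def batched_idx(total_size: int, batch_size: int) -> Iterator[tuple[int, int]]:
    """Yield (start, end) pairs covering range(total_size) in chunks of at most batch_size."""
    for start in range(0, total_size, batch_size):
        # Clip the final chunk so that end never exceeds total_size.
        yield start, min(start + batch_size, total_size)


# Usage: slice a list (a tensor works the same way) into batches.
data = list(range(10))
batches = [data[start:end] for start, end in batched_idx(len(data), 4)]
```

Because the generator yields only index pairs, the caller decides what to slice, which keeps large tensors out of the iterator itself.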
def dump_activations(cfg: Visuals)
For each SAE latent, we want to know which images have the most total "activation". That is, we keep track of each patch
def gather_batched(value: jaxtyping.Float[Tensor, 'batch n dim'],
i: jaxtyping.Int[Tensor, 'batch k']) ‑> jaxtyping.Float[Tensor, 'batch k dim']
def get_new_topk(val1: jaxtyping.Float[Tensor, 'd_sae k'],
i1: jaxtyping.Int[Tensor, 'd_sae k'],
val2: jaxtyping.Float[Tensor, 'd_sae k'],
i2: jaxtyping.Int[Tensor, 'd_sae k'],
k: int) ‑> tuple[jaxtyping.Float[Tensor, 'd_sae k'], jaxtyping.Int[Tensor, 'd_sae k']]
Picks out the new top k values among val1 and val2. Also keeps track of i1 and i2, the indices of the values in the original dataset.
- val1: top k original SAE values.
- i1: the patch indices of those original top k values.
- val2: top k incoming SAE values.
- i2: the patch indices of those incoming top k values.
- k: number of values to keep per SAE latent.
The new top k values and their patch indices.
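One plausible way to implement this merge is to concatenate both candidate sets along the k dimension, run torch.topk per latent, and reuse the returned positions to gather the matching dataset indices. This is a sketch under that assumption (jaxtyping annotations omitted), not necessarily the module's exact code.

```python
import torch


def get_new_topk(val1, i1, val2, i2, k):
    """Merge two (d_sae, k) top-k tables, keeping indices aligned with values."""
    # Candidates from both tables side by side: (d_sae, 2k).
    all_vals = torch.cat([val1, val2], dim=1)
    all_i = torch.cat([i1, i2], dim=1)
    # k largest values per latent (row); `pos` indexes into the 2k candidates.
    new_vals, pos = torch.topk(all_vals, k=k, dim=1)
    # The same positions select the corresponding dataset indices.
    new_i = torch.gather(all_i, 1, pos)
    return new_vals, new_i


# Usage: two latents, k = 2.
val1 = torch.tensor([[5.0, 1.0], [3.0, 2.0]])
i1 = torch.tensor([[10, 11], [20, 21]])
val2 = torch.tensor([[4.0, 0.5], [9.0, 8.0]])
i2 = torch.tensor([[30, 31], [40, 41]])
vals, idx = get_new_topk(val1, i1, val2, i2, k=2)
```

Here latent 0 keeps one value from each table (indices 10 and 30), while latent 1 keeps only incoming values (indices 40 and 41).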
def get_sae_acts(vit_acts: jaxtyping.Float[Tensor, 'n d_vit'],
sae: SparseAutoencoder,
cfg: Visuals) ‑> jaxtyping.Float[Tensor, 'n d_sae']
Get SAE hidden layer activations for a batch of ViT activations.
- vit_acts: batch of ViT activations.
- sae: sparse autoencoder.
- cfg: experimental config.
def get_topk_img(cfg: Visuals) ‑> TopKImg
Gets the top k images for each latent in the SAE. The top k images for latent i are the images x with the largest values of f_x(cls)[i] over all images.
Because each image contributes a single value per latent, we will never have duplicate images for a given latent. But we also will not have patch-level activations (a nice heatmap).
- cfg: Config.
A TopKImg, which also holds the first m features' activation distributions.
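A toy sketch of the image-level ranking, assuming the f_x(cls)[i] values for every image are already collected in one (n_imgs, d_sae) tensor; the data below is hypothetical, and the real function presumably processes the dataset in batches instead.

```python
import torch

# Hypothetical [CLS] activations: rows are n_imgs = 4 images, columns are d_sae = 3 latents.
acts_cls = torch.tensor([
    [0.0, 2.0, 0.1],
    [3.0, 0.0, 0.2],
    [1.0, 5.0, 0.3],
    [2.0, 1.0, 0.0],
])
k = 2
# Transpose so each row is one latent, then keep the k highest-scoring images per row.
top_values, top_i = torch.topk(acts_cls.T, k=k, dim=1)  # both (d_sae, k)
```

Each image appears at most once per row of top_i, since it contributes exactly one score per latent.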
def get_topk_patch(cfg: Visuals) ‑> TopKPatch
Gets the top k images for each latent in the SAE. The top k images for latent i are ranked by f_x(patch)[i], taken over all patches of all images.
Thus, we could end up with duplicate images in the top k if an image has more than one patch that maximally activates an SAE latent.
- cfg: Config.
A TopKPatch, which also holds m randomly sampled activation distributions.
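A toy sketch of why patch-level ranking can yield duplicate images. It assumes, hypothetically, that patch activations are stored image-major, so that a flat patch index divided by n_patches_per_img recovers the image index.

```python
import torch

# Hypothetical data: 3 images, 2 patches per image, d_sae = 1 latent.
n_patches_per_img = 2
acts_p = torch.tensor([
    [0.9],  # image 0, patch 0
    [0.8],  # image 0, patch 1
    [0.1],  # image 1, patch 0
    [0.2],  # image 1, patch 1
    [0.5],  # image 2, patch 0
    [0.0],  # image 2, patch 1
])  # (n_imgs * n_patches_per_img, d_sae)
k = 2
# Rank individual patches per latent...
top_values, top_patch_i = torch.topk(acts_p.T, k=k, dim=1)  # (d_sae, k)
# ...then map each flat patch index back to the image that contains it.
top_img_i = torch.div(top_patch_i, n_patches_per_img, rounding_mode="floor")
```

Both top patches belong to image 0, so that image fills both top-k slots for the latent.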
def main(cfg: Visuals)
Todo: document this function.
Dump top-k images to a directory.
- cfg: Configuration object.
def make_img(elem: GridElement,
upper: float | None = None) ‑> PIL.Image.Image
def plot_activation_distributions(cfg: Visuals,
distributions: jaxtyping.Float[Tensor, 'm n'])
def safe_load(path: str) ‑> object
def test_online_quantile_estimation(true: float, percentile: float)
class GridElement (img: PIL.Image.Image, label: str, patches: jaxtyping.Float[Tensor, 'n_patches'])
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass
class GridElement:
    img: Image.Image
    label: str
    patches: Float[Tensor, " n_patches"]
Class variables
var img : PIL.Image.Image
var label : str
var patches : jaxtyping.Float[Tensor, 'n_patches']
class PercentileEstimator (percentile: float | int,
total: int,
lr: float = 0.001,
shape: tuple[int, ...] = ())
@beartype.beartype
class PercentileEstimator:
    def __init__(
        self,
        percentile: float | int,
        total: int,
        lr: float = 1e-3,
        shape: tuple[int, ...] = (),
    ):
        self.percentile = percentile
        self.total = total
        self.lr = lr

        self._estimate = torch.zeros(shape)
        self._step = 0

    def update(self, x):
        """
        Update the estimator with a new value.

        Each call takes one stochastic-approximation step toward the target
        percentile p = percentile / 100: the estimate moves up by 2 * p * step_size
        when x is above it and down by 2 * (1 - p) * step_size when x is below it.
        The step size decays linearly from lr toward 0 over `total` updates.

        Arguments:
            x: The new value to incorporate into the estimation
        """
        self._step += 1

        # Linearly decaying step size.
        step_size = self.lr * (self.total - self._step) / self.total

        # Is a no-op if it's already on the same device.
        self._estimate = self._estimate.to(x.device)

        self._estimate += step_size * (
            torch.sign(x - self._estimate) + 2 * self.percentile / 100 - 1.0
        )

    @property
    def estimate(self):
        return self._estimate
Instance variables
prop estimate
@property
def estimate(self):
    return self._estimate
def update(self, x)
Update the estimator with a new value.
Each call takes one stochastic-approximation step toward the target percentile p = percentile / 100: the estimate moves up by 2p times the step size when x is above it and down by 2(1 - p) times the step size when x is below it, so it settles where a fraction p of the values lie below it. The step size decays linearly from lr toward 0 over total updates.
x: The new value to incorporate into the estimation
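The update rule can be illustrated with a pure-Python scalar sketch that mirrors PercentileEstimator.update without torch. The helper name estimate_percentile and the uniform sample data are hypothetical; the zero initial estimate and the linearly decaying step size follow the class's constructor and update code.

```python
import random


def estimate_percentile(samples, percentile, lr=1e-3):
    """Scalar version of the PercentileEstimator update rule."""
    total = len(samples)
    estimate = 0.0  # the class also starts from zeros
    for step, x in enumerate(samples, start=1):
        # Step size decays linearly toward 0 over `total` updates.
        step_size = lr * (total - step) / total
        # Sign of (x - estimate), like torch.sign: +1, 0, or -1.
        sign = (x > estimate) - (x < estimate)
        # Up by 2p * step_size above the estimate, down by 2(1 - p) * step_size below it.
        estimate += step_size * (sign + 2 * percentile / 100 - 1.0)
    return estimate


random.seed(0)
samples = [random.random() for _ in range(100_000)]  # uniform on [0, 1)
est90 = estimate_percentile(samples, percentile=90, lr=0.01)
est50 = estimate_percentile(samples, percentile=50, lr=0.01)
```

For uniform data the true 90th percentile and median are 0.9 and 0.5, so both estimates should land close to those values. The single running estimate uses constant memory, which is why it suits streaming over many activations.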
class TopKImg (top_values: jaxtyping.Float[Tensor, 'd_sae k'],
top_i: jaxtyping.Int[Tensor, 'd_sae k'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
sparsity: jaxtyping.Float[Tensor, 'd_sae'],
distributions: jaxtyping.Float[Tensor, 'm n'],
percentiles: jaxtyping.Float[Tensor, 'd_sae'])
Todo: document this class.
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class TopKImg:
    ".. todo:: Document this class."

    top_values: Float[Tensor, "d_sae k"]
    top_i: Int[Tensor, "d_sae k"]
    mean_values: Float[Tensor, " d_sae"]
    sparsity: Float[Tensor, " d_sae"]
    distributions: Float[Tensor, "m n"]
    percentiles: Float[Tensor, " d_sae"]
Class variables
var distributions : jaxtyping.Float[Tensor, 'm n']
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
var percentiles : jaxtyping.Float[Tensor, 'd_sae']
var sparsity : jaxtyping.Float[Tensor, 'd_sae']
var top_i : jaxtyping.Int[Tensor, 'd_sae k']
var top_values : jaxtyping.Float[Tensor, 'd_sae k']
class TopKPatch (top_values: jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img'],
top_i: jaxtyping.Int[Tensor, 'd_sae k'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
sparsity: jaxtyping.Float[Tensor, 'd_sae'],
distributions: jaxtyping.Float[Tensor, 'm n'],
percentiles: jaxtyping.Float[Tensor, 'd_sae'])
Todo: document this class.
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class TopKPatch:
    ".. todo:: Document this class."

    top_values: Float[Tensor, "d_sae k n_patches_per_img"]
    top_i: Int[Tensor, "d_sae k"]
    mean_values: Float[Tensor, " d_sae"]
    sparsity: Float[Tensor, " d_sae"]
    distributions: Float[Tensor, "m n"]
    percentiles: Float[Tensor, " d_sae"]
Class variables
var distributions : jaxtyping.Float[Tensor, 'm n']
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
var percentiles : jaxtyping.Float[Tensor, 'd_sae']
var sparsity : jaxtyping.Float[Tensor, 'd_sae']
var top_i : jaxtyping.Int[Tensor, 'd_sae k']
var top_values : jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img']