Module saev.visuals

Some notation is used only in this file to keep variable names short.

Variables suffixed with _im refer to entire images, and variables suffixed with _p refer to patches. For example, a tensor named acts_im would hold one value per image, while acts_p holds one value per patch.

Functions

def batched_idx(total_size: int, batch_size: int) ‑> collections.abc.Iterator[tuple[int, int]]

Iterate over (start, end) indices for total_size examples, where end - start is at most batch_size.

Args

total_size
total number of examples.
batch_size
maximum size of each batch; end - start never exceeds this value.

Returns

A generator of (int, int) tuples that can slice up a list or a tensor.
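
Since the contract is fully specified, a reference sketch is easy to give; this is an illustration of the documented behavior rather than the exact source:

def batched_idx(total_size: int, batch_size: int):
    # Yield (start, end) pairs covering range(total_size) in chunks of at
    # most batch_size; the final chunk may be smaller.
    for start in range(0, total_size, batch_size):
        yield start, min(start + batch_size, total_size)

# list(batched_idx(10, 4)) == [(0, 4), (4, 8), (8, 10)]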

def dump_activations(cfg: Visuals)

Dump ViT activation statistics for later use.

The dataset described by cfg is processed to find the images or patches that maximally activate each SAE latent. Various tensors summarising these activations are then written to cfg.root so they can be loaded by other tools.

Args

cfg
options controlling which activations are processed and where the resulting files are saved.

Returns

None. All data is saved to disk.
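
A minimal invocation might look like this; aside from cfg.root, the Visuals fields are not documented here, so treat the construction as a hypothetical:

cfg = Visuals(root="./visuals")  # hypothetical: only `root` is named above
dump_activations(cfg)
# cfg.root now holds the summary tensors for other tools to load.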

def gather_batched(value: jaxtyping.Float[Tensor, 'batch n dim'],
i: jaxtyping.Int[Tensor, 'batch k']) ‑> jaxtyping.Float[Tensor, 'batch k dim']
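
gather_batched carries no docstring; judging from the shapes alone, it selects k rows per batch element, i.e. out[b, j] = value[b, i[b, j]]. A sketch consistent with that signature (an assumption, not the confirmed source):

import torch

def gather_batched(value, i):
    # Expand the index tensor to [batch, k, dim] so torch.gather pulls
    # each selected row in full along dimension 1.
    batch, n, dim = value.shape
    idx = i.unsqueeze(-1).expand(-1, -1, dim)
    return value.gather(1, idx)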
def get_new_topk(val1: jaxtyping.Float[Tensor, 'd_sae k'],
i1: jaxtyping.Int[Tensor, 'd_sae k'],
val2: jaxtyping.Float[Tensor, 'd_sae k'],
i2: jaxtyping.Int[Tensor, 'd_sae k'],
k: int) ‑> tuple[jaxtyping.Float[Tensor, 'd_sae k'], jaxtyping.Int[Tensor, 'd_sae k']]

Picks out the new top k values among val1 and val2, and keeps track of i1 and i2, the indices of those values in the original dataset.

Args

val1
top k original SAE values.
i1
the patch indices of those original top k values.
val2
top k incoming SAE values.
i2
the patch indices of those incoming top k values.
k
the number of top values (and indices) to keep.

Returns

The new top k values and their patch indices.
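
A sketch of the merge, assuming a concatenate-then-topk implementation (the real source may differ):

import torch

def get_new_topk(val1, i1, val2, i2, k):
    # Pool existing and incoming candidates along the k axis, keep the k
    # largest per latent, and carry their dataset indices along.
    all_vals = torch.cat([val1, val2], dim=1)    # [d_sae, 2k]
    all_i = torch.cat([i1, i2], dim=1)           # [d_sae, 2k]
    top_vals, which = all_vals.topk(k, dim=1)    # both [d_sae, k]
    return top_vals, all_i.gather(1, which)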

def get_sae_acts(vit_acts: jaxtyping.Float[Tensor, 'n d_vit'],
sae: SparseAutoencoder,
cfg: Visuals) ‑> jaxtyping.Float[Tensor, 'n d_sae']

Get SAE hidden layer activations for a batch of ViT activations.

Args

vit_acts
Batch of ViT activations.
sae
Sparse autoencoder.
cfg
Experimental config.

Returns

The SAE's hidden-layer activations for the batch.

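A sketch of the batched encode loop; the SAE forward-pass convention and the cfg field names below are assumptions:

import torch

@torch.no_grad()
def get_sae_acts_sketch(vit_acts, sae, cfg):
    chunks = []
    for start, end in batched_idx(len(vit_acts), cfg.sae_batch_size):
        # Hypothetical: assumes the SAE forward pass returns
        # (reconstruction, hidden activations, ...) in that order.
        _, f_x, *_ = sae(vit_acts[start:end].to(cfg.device))
        chunks.append(f_x.cpu())
    return torch.cat(chunks)
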
def get_topk_img(cfg: Visuals) ‑> TopKImg

Gets the top k images for each latent in the SAE. The top k images for latent i are sorted by

max over all images: f_x(cls)[i]

Thus, we will never have duplicate images for a given latent. But we also will not have patch-level activations, so we cannot draw per-patch heatmaps.

Args

cfg
Config.

Returns

A TopKImg whose distributions field holds the first m features' activation distributions.
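
The ranking rule amounts to a per-latent top k over CLS-token activations. A shape-level sketch (illustrative sizes, not the source):

import torch

n_imgs, d_sae, k = 1_000, 1_024, 16
cls_sae_acts = torch.rand(n_imgs, d_sae)  # f_x(cls) for every image x

# One score per image per latent, so each latent's top k images are unique.
top_values, top_i = cls_sae_acts.topk(k, dim=0)  # both [k, d_sae]
top_values, top_i = top_values.T, top_i.T        # [d_sae, k], as in TopKImg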

def get_topk_patch(cfg: Visuals) ‑> TopKPatch

Gets the top k images for each latent in the SAE. The top k images for latent i are sorted by

max over all patches: f_x(patch)[i]

Thus, we can end up with duplicate images in the top k if more than one of an image's patches is among the latent's top activations.

Args

cfg
Config.

Returns

A TopKPatch whose distributions field holds m randomly sampled activation distributions.
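
Here individual patches are ranked, which is what allows duplicates, and the full patch map of each winning image is kept so heatmaps can be drawn. A shape-level sketch (illustrative sizes, not the source):

import torch

n_imgs, n_patches, d_sae, k = 256, 196, 1_024, 16
acts = torch.rand(n_imgs, n_patches, d_sae)

# Rank individual patches; one image can appear several times for a latent
# when more than one of its patches scores in the top k.
flat = acts.reshape(n_imgs * n_patches, d_sae)
top_p = flat.topk(k, dim=0).indices.T       # [d_sae, k] flat patch indices
top_i = top_p // n_patches                  # [d_sae, k] owning image index

# Keep every patch of each winning image for later heatmaps.
per_latent = acts.permute(2, 0, 1)          # [d_sae, n_imgs, n_patches]
idx = top_i.unsqueeze(-1).expand(-1, -1, n_patches)
top_values = per_latent.gather(1, idx)      # [d_sae, k, n_patches_per_img]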

def main(cfg: Visuals)

TODO

Document this function.

Dump top-k images to a directory.

Args

cfg
Configuration object.

def make_img(elem: GridElement, *, upper: float | None = None) ‑> PIL.Image.Image

def plot_activation_distributions(cfg: Visuals, distributions: jaxtyping.Float[Tensor, 'm n'])

def safe_load(path: str) ‑> object

def test_online_quantile_estimation(true: float, percentile: float)

Classes

class GridElement (img: PIL.Image.Image, label: str, patches: jaxtyping.Float[Tensor, 'n_patches'])

@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass
class GridElement:
    img: Image.Image
    label: str
    patches: Float[Tensor, " n_patches"]

Class variables

var img : PIL.Image.Image
var label : str
var patches : jaxtyping.Float[Tensor, 'n_patches']
class PercentileEstimator (percentile: float | int,
total: int,
lr: float = 0.001,
shape: tuple[int, ...] = ())
@beartype.beartype
class PercentileEstimator:
    def __init__(
        self,
        percentile: float | int,
        total: int,
        lr: float = 1e-3,
        shape: tuple[int, ...] = (),
    ):
        self.percentile = percentile
        self.total = total
        self.lr = lr

        self._estimate = torch.zeros(shape)
        self._step = 0

    def update(self, x):
        """
        Update the estimator with a new value.

        This is an online quantile estimate via stochastic approximation: the running estimate moves by a linearly decaying step in the direction of sign(x - estimate), offset by 2 * percentile / 100 - 1 so that, in expectation, it settles at the target percentile.

        Arguments:
            x: The new value to incorporate into the estimation
        """
        self._step += 1

        step_size = self.lr * (self.total - self._step) / self.total

        # Is a no-op if it's already on the same device.
        self._estimate = self._estimate.to(x.device)

        self._estimate += step_size * (
            torch.sign(x - self._estimate) + 2 * self.percentile / 100 - 1.0
        )

    @property
    def estimate(self):
        return self._estimate

Instance variables

prop estimate

The current running estimate of the requested percentile.
@property
def estimate(self):
    return self._estimate

Methods

def update(self, x)

Update the estimator with a new value.

This is an online quantile estimate via stochastic approximation: the running estimate moves by a linearly decaying step in the direction of sign(x - estimate), offset by 2 * percentile / 100 - 1 so that, in expectation, it settles at the target percentile.

Arguments

x: The new value to incorporate into the estimation
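
A usage sketch, tracking a per-latent 99th percentile over a stream of batches (sizes illustrative):

import torch

d_sae, n_steps = 1_024, 1_000
estimator = PercentileEstimator(99, total=n_steps, shape=(d_sae,))

for _ in range(n_steps):
    acts = torch.rand(d_sae)  # stand-in for one batch's SAE activations
    estimator.update(acts)

p99 = estimator.estimate  # [d_sae] running estimates of the 99th percentile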

class TopKImg (top_values: jaxtyping.Float[Tensor, 'd_sae k'],
top_i: jaxtyping.Int[Tensor, 'd_sae k'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
sparsity: jaxtyping.Float[Tensor, 'd_sae'],
distributions: jaxtyping.Float[Tensor, 'm n'],
percentiles: jaxtyping.Float[Tensor, 'd_sae'])

TODO

Document this class.

@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class TopKImg:
    ".. todo:: Document this class."

    top_values: Float[Tensor, "d_sae k"]
    top_i: Int[Tensor, "d_sae k"]
    mean_values: Float[Tensor, " d_sae"]
    sparsity: Float[Tensor, " d_sae"]
    distributions: Float[Tensor, "m n"]
    percentiles: Float[Tensor, " d_sae"]

Class variables

var distributions : jaxtyping.Float[Tensor, 'm n']
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
var percentiles : jaxtyping.Float[Tensor, 'd_sae']
var sparsity : jaxtyping.Float[Tensor, 'd_sae']
var top_i : jaxtyping.Int[Tensor, 'd_sae k']
var top_values : jaxtyping.Float[Tensor, 'd_sae k']
class TopKPatch (top_values: jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img'],
top_i: jaxtyping.Int[Tensor, 'd_sae k'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
sparsity: jaxtyping.Float[Tensor, 'd_sae'],
distributions: jaxtyping.Float[Tensor, 'm n'],
percentiles: jaxtyping.Float[Tensor, 'd_sae'])

TODO

Document this class.

@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class TopKPatch:
    ".. todo:: Document this class."

    top_values: Float[Tensor, "d_sae k n_patches_per_img"]
    top_i: Int[Tensor, "d_sae k"]
    mean_values: Float[Tensor, " d_sae"]
    sparsity: Float[Tensor, " d_sae"]
    distributions: Float[Tensor, "m n"]
    percentiles: Float[Tensor, " d_sae"]

Class variables

var distributions : jaxtyping.Float[Tensor, 'm n']
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
var percentiles : jaxtyping.Float[Tensor, 'd_sae']
var sparsity : jaxtyping.Float[Tensor, 'd_sae']
var top_i : jaxtyping.Int[Tensor, 'd_sae k']
var top_values : jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img']