Module saev.visuals

There is some important notation used only in this file to dramatically shorten variable names.

Variables suffixed with _im refer to entire images, and variables suffixed with _p refer to patches.

Functions

def batched_idx(total_size: int, batch_size: int) ‑> Iterator[tuple[int, int]]

Iterate over (start, end) indices for total_size examples, where end - start is at most batch_size.

Args

total_size
total number of examples
batch_size
maximum size of each generated (start, end) slice; end - start is at most batch_size.

Returns

A generator of (int, int) tuples that can slice up a list or a tensor.
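The described behavior can be sketched in a few lines (a hedged re-implementation of the behavior described above, not the module's actual source):

```python
from typing import Iterator


def batched_idx(total_size: int, batch_size: int) -> Iterator[tuple[int, int]]:
    # Yield (start, end) pairs covering [0, total_size); the last pair may be shorter.
    for start in range(0, total_size, batch_size):
        yield (start, min(start + batch_size, total_size))


print(list(batched_idx(10, 4)))  # → [(0, 4), (4, 8), (8, 10)]
```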

def dump_activations(cfg: Visuals)

For each SAE latent, we want to know which images have the most total "activation". That is, we keep track of each patch's activation rather than only an image-level summary.

def gather_batched(value: jaxtyping.Float[Tensor, 'batch n dim'],
i: jaxtyping.Int[Tensor, 'batch k']) ‑> jaxtyping.Float[Tensor, 'batch k dim']
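Reading the shapes, gather_batched selects, for each batch element, the k rows named by i. A plain-Python sketch of that indexing (the real function operates on torch tensors):

```python
def gather_batched(value, i):
    # value: [batch][n][dim] nested lists; i: [batch][k] row indices.
    # For each batch element b, pick out row value[b][j] for each j in i[b].
    return [[value[b][j] for j in idxs] for b, idxs in enumerate(i)]


value = [[[0.0], [1.0], [2.0]]]  # batch=1, n=3, dim=1
print(gather_batched(value, [[2, 0]]))  # → [[[2.0], [0.0]]]
```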
def get_new_topk(val1: jaxtyping.Float[Tensor, 'd_sae k'],
i1: jaxtyping.Int[Tensor, 'd_sae k'],
val2: jaxtyping.Float[Tensor, 'd_sae k'],
i2: jaxtyping.Int[Tensor, 'd_sae k'],
k: int) ‑> tuple[jaxtyping.Float[Tensor, 'd_sae k'], jaxtyping.Int[Tensor, 'd_sae k']]

Picks out the new top k values among val1 and val2. Also keeps track of i1 and i2, the indices of the values in the original dataset.

Args

val1
top k original SAE values.
i1
the patch indices of those original top k values.
val2
top k incoming SAE values.
i2
the patch indices of those incoming top k values.
k
the number of top values to keep.

Returns

The new top k values and their patch indices.
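The merge can be sketched for a single latent row in plain Python (the real function does this for all d_sae rows at once on torch tensors; merge_topk is a made-up name for the sketch):

```python
def merge_topk(val1, i1, val2, i2, k):
    # Combine both candidate lists, sort by value descending, and keep
    # the k largest values together with their original patch indices.
    merged = sorted(zip(val1 + val2, i1 + i2), reverse=True)[:k]
    vals, idxs = zip(*merged)
    return list(vals), list(idxs)


print(merge_topk([5.0, 3.0], [0, 1], [4.0, 2.0], [7, 8], 2))
# → ([5.0, 4.0], [0, 7])
```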

def get_sae_acts(vit_acts: jaxtyping.Float[Tensor, 'n d_vit'],
sae: SparseAutoencoder,
cfg: Visuals) ‑> jaxtyping.Float[Tensor, 'n d_sae']

Get SAE hidden layer activations for a batch of ViT activations.

Args

vit_acts
Batch of ViT activations
sae
Sparse autoencoder.
cfg
Experimental config.
def get_topk_img(cfg: Visuals) ‑> TopKImg

Gets the top k images for each latent in the SAE. The top k images for latent i are sorted by

max over all images: f_x(cls)[i]

Thus, we will never have duplicate images for a given latent. But we also will not have patch-level activations (and thus no per-patch heatmap).

Args

cfg
Config.

Returns

A tuple of TopKImg and the first m features' activation distributions.

def get_topk_patch(cfg: Visuals) ‑> TopKPatch

Gets the top k images for each latent in the SAE. The top k images for latent i are sorted by

max over all patches: f_x(patch)[i]

Thus, we could end up with duplicate images in the top k, if an image has more than one patch that maximally activates an SAE latent.

Args

cfg
Config.

Returns

A tuple of TopKPatch and m randomly sampled activation distributions.
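A toy example of the duplicate behavior described above: score every (image, patch) pair for one latent and keep the top pairs (an illustrative sketch with made-up numbers, not the module's code):

```python
# Per-patch activations for one latent: 2 images x 3 patches.
acts_p = [[0.9, 0.8, 0.1], [0.5, 0.2, 0.3]]

# Rank all (activation, image, patch) triples and keep the top 2.
pairs = [(a, im, p) for im, row in enumerate(acts_p) for p, a in enumerate(row)]
top2 = sorted(pairs, reverse=True)[:2]

print([(im, p) for _, im, p in top2])  # → [(0, 0), (0, 1)]  (image 0 appears twice)
```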

def main(cfg: Visuals)

TODO

document this function.

Dump top-k images to a directory.

Args

cfg
Configuration object.
def make_img(elem: GridElement,
*,
upper: float | None = None) ‑> PIL.Image.Image
def plot_activation_distributions(cfg: Visuals,
distributions: jaxtyping.Float[Tensor, 'm n'])
def safe_load(path: str) ‑> object
def test_online_quantile_estimation(true: float, percentile: float)

Classes

class GridElement (img: PIL.Image.Image, label: str, patches: jaxtyping.Float[Tensor, 'n_patches'])

GridElement(img: PIL.Image.Image, label: str, patches: jaxtyping.Float[Tensor, 'n_patches'])

Expand source code
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass
class GridElement:
    img: Image.Image
    label: str
    patches: Float[Tensor, " n_patches"]

Class variables

var img : PIL.Image.Image
var label : str
var patches : jaxtyping.Float[Tensor, 'n_patches']
class PercentileEstimator (percentile: float | int,
total: int,
lr: float = 0.001,
shape: tuple[int, ...] = ())
Expand source code
@beartype.beartype
class PercentileEstimator:
    def __init__(
        self,
        percentile: float | int,
        total: int,
        lr: float = 1e-3,
        shape: tuple[int, ...] = (),
    ):
        self.percentile = percentile
        self.total = total
        self.lr = lr

        self._estimate = torch.zeros(shape)
        self._step = 0

    def update(self, x):
        """
        Update the estimator with a new value.

        This method nudges a single running estimate with a signed stochastic step (a Robbins-Monro-style quantile update).
        When a new value arrives, the estimate moves up or down by a decaying step size, biased so that over the stream it converges toward the desired percentile.

        Arguments:
            x: The new value to incorporate into the estimation
        """
        self._step += 1

        step_size = self.lr * (self.total - self._step) / self.total

        # Is a no-op if it's already on the same device.
        self._estimate = self._estimate.to(x.device)

        self._estimate += step_size * (
            torch.sign(x - self._estimate) + 2 * self.percentile / 100 - 1.0
        )

    @property
    def estimate(self):
        return self._estimate

Instance variables

prop estimate
Expand source code
@property
def estimate(self):
    return self._estimate

Methods

def update(self, x)

Update the estimator with a new value.

This method nudges a single running estimate with a signed stochastic step (a Robbins-Monro-style quantile update). When a new value arrives, the estimate moves up or down by a decaying step size, biased so that over the stream it converges toward the desired percentile.

Arguments

x: The new value to incorporate into the estimation
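A scalar sketch of the update rule above, run on synthetic data. This is a pure-Python re-implementation for illustration (the real class works on torch tensors), and it assumes values roughly in [0, 1]; with data on a much larger scale the default lr would move the estimate too slowly.

```python
import random


class ScalarPercentileEstimator:
    # Single-float version of the update rule shown in the source above.
    def __init__(self, percentile: float, total: int, lr: float = 1e-3):
        self.percentile, self.total, self.lr = percentile, total, lr
        self.estimate, self.step = 0.0, 0

    def update(self, x: float):
        self.step += 1
        # Step size decays linearly to zero over the expected number of updates.
        step_size = self.lr * (self.total - self.step) / self.total
        sign = 1.0 if x > self.estimate else (-1.0 if x < self.estimate else 0.0)
        self.estimate += step_size * (sign + 2 * self.percentile / 100 - 1.0)


rng = random.Random(0)
est = ScalarPercentileEstimator(percentile=50, total=10_000)
for _ in range(10_000):
    est.update(rng.random())
print(est.estimate)  # should land close to the true median of Uniform(0, 1), i.e. 0.5
```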

class TopKImg (top_values: jaxtyping.Float[Tensor, 'd_sae k'],
top_i: jaxtyping.Int[Tensor, 'd_sae k'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
sparsity: jaxtyping.Float[Tensor, 'd_sae'],
distributions: jaxtyping.Float[Tensor, 'm n'],
percentiles: jaxtyping.Float[Tensor, 'd_sae'])

TODO

Document this class.

Expand source code
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class TopKImg:
    ".. todo:: Document this class."

    top_values: Float[Tensor, "d_sae k"]
    top_i: Int[Tensor, "d_sae k"]
    mean_values: Float[Tensor, " d_sae"]
    sparsity: Float[Tensor, " d_sae"]
    distributions: Float[Tensor, "m n"]
    percentiles: Float[Tensor, " d_sae"]

Class variables

var distributions : jaxtyping.Float[Tensor, 'm n']
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
var percentiles : jaxtyping.Float[Tensor, 'd_sae']
var sparsity : jaxtyping.Float[Tensor, 'd_sae']
var top_i : jaxtyping.Int[Tensor, 'd_sae k']
var top_values : jaxtyping.Float[Tensor, 'd_sae k']
class TopKPatch (top_values: jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img'],
top_i: jaxtyping.Int[Tensor, 'd_sae k'],
mean_values: jaxtyping.Float[Tensor, 'd_sae'],
sparsity: jaxtyping.Float[Tensor, 'd_sae'],
distributions: jaxtyping.Float[Tensor, 'm n'],
percentiles: jaxtyping.Float[Tensor, 'd_sae'])

TODO

Document this class.

Expand source code
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class TopKPatch:
    ".. todo:: Document this class."

    top_values: Float[Tensor, "d_sae k n_patches_per_img"]
    top_i: Int[Tensor, "d_sae k"]
    mean_values: Float[Tensor, " d_sae"]
    sparsity: Float[Tensor, " d_sae"]
    distributions: Float[Tensor, "m n"]
    percentiles: Float[Tensor, " d_sae"]

Class variables

var distributions : jaxtyping.Float[Tensor, 'm n']
var mean_values : jaxtyping.Float[Tensor, 'd_sae']
var percentiles : jaxtyping.Float[Tensor, 'd_sae']
var sparsity : jaxtyping.Float[Tensor, 'd_sae']
var top_i : jaxtyping.Int[Tensor, 'd_sae k']
var top_values : jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img']