Module contrib.semseg.quantitative


def argmax_logits(logits_BPC: jaxtyping.Float[Tensor, 'batch patches channels_with_null']) ‑> jaxtyping.Int[Tensor, 'batch patches']
def compute_class_results(orig_preds: jaxtyping.Int[Tensor, 'n_imgs patches'],
mod_preds: jaxtyping.Int[Tensor, 'n_imgs patches']) ‑> list[ClassResults]
def eval_auto_feat(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report
def eval_rand_feat(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report

Evaluates the effects of suppressing a random SAE feature.


Args:
    cfg: Configuration for quantitative evaluation.
    sae: Trained sparse autoencoder model.
    clf: Trained classifier model.
    dataloader: DataLoader providing batches of images.

Returns:
    Report containing intervention results, including per-class changes.

def eval_rand_vec(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report

Evaluates the effects of adding a random unit vector to the patches.


Args:
    cfg: Configuration for quantitative evaluation.
    sae: Trained sparse autoencoder model.
    clf: Trained classifier model.
    dataloader: DataLoader providing batches of images.

Returns:
    Report containing intervention results, including per-class changes.

def get_latent_lookup(cfg: Quantitative,
sae: SparseAutoencoder,
dataloader) ‑> jaxtyping.Int[Tensor, '151']

Dimension key:

  • B: batch dimension
  • P: patches per image
  • D: ViT hidden dimension
  • S: SAE feature dimension
  • T: threshold dimension
  • C: class dimension
  • L: layer dimension
def get_patch_i(i: jaxtyping.Int[Tensor, 'batch width height'], n_patches_per_img: int) ‑> jaxtyping.Int[Tensor, 'batch width height']
def get_patch_mask(pixel_labels_NP: jaxtyping.UInt8[Tensor, 'n patch_px'], threshold: float) ‑> jaxtyping.Bool[Tensor, 'n']

Create a mask selecting patches in which at least a `threshold` proportion of pixels share the same label.


Args:
    pixel_labels_NP: Tensor of shape [n, patch_pixels] with pixel labels.
    threshold: Minimum proportion of pixels that must share the same label.

Returns:
    Tensor of shape [n] with True for patches that pass the threshold.

def main(cfg: Quantitative)

Main entry point for quantitative evaluation.

def map_range(x: jaxtyping.Float[Tensor, '*batch'],
domain: tuple[float | int, float | int],
range: tuple[float | int, float | int]) ‑> jaxtyping.Float[Tensor, '*batch']
def register_hook(vit: torch.nn.modules.module.Module,
hook: Callable[[jaxtyping.Float[Tensor, '...']], jaxtyping.Float[Tensor, '...']],
layer: int,
n_patches_per_img: int)
def save(results: list[Report],
fpath: str) ‑> None

Save evaluation results to a CSV file.


Args:
    results: List of Report objects containing evaluation results.
    fpath: Path to save the CSV file.
def unscaled(x: jaxtyping.Float[Tensor, '*batch'], max_obs: float | int) ‑> jaxtyping.Float[Tensor, '*batch']

Scale from [-10, 10] to [10 * -max_obs, 10 * max_obs].


class ClassResults (class_id: int,
class_name: str,
n_orig_patches: int,
n_changed_patches: int,
n_other_patches: int,
n_other_changed: int,
change_distribution: dict[int, int])

Results for a single class.

Expand source code
class ClassResults:
    """Results for a single class.

    Counts describe how predictions changed after an intervention, both for
    patches originally predicted as this class and for all other patches.
    """

    class_id: int
    """Numeric identifier for the class."""

    class_name: str
    """Human-readable name of the class."""

    n_orig_patches: int
    """Original patches that were this class."""

    n_changed_patches: int
    """After intervention, how many of this class's original patches changed class."""

    n_other_patches: int
    """Total patches that weren't this class."""

    n_other_changed: int
    """After intervention, how many of the other patches changed."""

    change_distribution: dict[int, int]
    """What classes did patches change to? Maps each destination class id (key) to the number of patches (value) that changed from self.class_id to that class."""

Class variables

var change_distribution : dict[int, int]

What classes did patches change to? Maps each destination class id (key) to the number of patches (value) that changed from self.class_id to that class.

var class_id : int

Numeric identifier for the class.

var class_name : str

Human-readable name of the class.

var n_changed_patches : int

After intervention, how many patches changed.

var n_orig_patches : int

Original patches that were this class.

var n_other_changed : int

After intervention, how many of the other patches changed.

var n_other_patches : int

Total patches that weren't this class.

class Report (method: str,
class_results: list[ClassResults],
intervention_scale: float)

Complete results from an intervention experiment.

Expand source code
# NOTE(review): the rendered constructor signature (method, class_results,
# intervention_scale) implies a dataclass-style decorator that this listing
# does not show; confirm it against the real source.
class Report:
    """Complete results from an intervention experiment.

    The summary metrics are read-only properties, so `to_csv_row` reads them
    as plain float values rather than as bound methods.
    """

    method: str
    """Which intervention method was used."""

    class_results: list[ClassResults]
    """Per-class detailed results."""

    intervention_scale: float
    """Magnitude of intervention."""

    @property
    def mean_target_change(self) -> float:
        """Fraction (0-1) of target patches that changed class, pooled over all classes.

        Returns 0.0 when there are no target patches at all.
        """
        total_target = sum(r.n_orig_patches for r in self.class_results)
        total_changed = sum(r.n_changed_patches for r in self.class_results)
        return total_changed / total_target if total_target > 0 else 0.0

    @property
    def mean_other_change(self) -> float:
        """Fraction (0-1) of non-target patches that changed class, pooled over all classes.

        Returns 0.0 when there are no non-target patches at all.
        """
        total_other = sum(r.n_other_patches for r in self.class_results)
        total_changed = sum(r.n_other_changed for r in self.class_results)
        return total_changed / total_other if total_other > 0 else 0.0

    @property
    def target_change_std(self) -> float:
        """Standard deviation, across classes, of the per-class target-change fraction.

        A class with no original patches contributes 0.0 to the distribution.
        """
        per_class_target_changes = np.array([
            r.n_changed_patches / r.n_orig_patches if r.n_orig_patches > 0 else 0.0
            for r in self.class_results
        ])
        return float(np.std(per_class_target_changes))

    @property
    def other_change_std(self) -> float:
        """Standard deviation, across classes, of the per-class non-target-change fraction.

        A class with no non-target patches contributes 0.0 to the distribution.
        """
        per_class_other_changes = np.array([
            r.n_other_changed / r.n_other_patches if r.n_other_patches > 0 else 0.0
            for r in self.class_results
        ])
        return float(np.std(per_class_other_changes))

    def to_csv_row(self) -> dict[str, float | str]:
        """Convert to a row for the summary CSV.

        Returns:
            Mapping from column name to value. "method" is a string and the
            remaining columns are floats.
        """
        return {
            "method": self.method,
            "target_change": self.mean_target_change,
            "other_change": self.mean_other_change,
            "target_std": self.target_change_std,
            "other_std": self.other_change_std,
        }

Class variables

var class_results : list[ClassResults]

Per-class detailed results.

var intervention_scale : float

Magnitude of intervention.

var method : str

Which intervention method was used.

Instance variables

prop mean_other_change : float

Fraction (between 0 and 1) of non-target patches that changed class, pooled across all classes.

Expand source code
def mean_other_change(self) -> float:
    """Fraction (0-1) of non-target patches that changed class, pooled over all classes.

    Returns:
        sum(n_other_changed) / sum(n_other_patches) over `self.class_results`,
        or 0.0 when there are no non-target patches at all.
    """
    total_other = sum(r.n_other_patches for r in self.class_results)
    total_changed = sum(r.n_other_changed for r in self.class_results)
    return total_changed / total_other if total_other > 0 else 0.0
prop mean_target_change : float

Fraction (between 0 and 1) of target patches that changed class, pooled across all classes.

Expand source code
def mean_target_change(self) -> float:
    """Fraction (0-1) of target patches that changed class, pooled over all classes.

    Returns:
        sum(n_changed_patches) / sum(n_orig_patches) over `self.class_results`,
        or 0.0 when there are no target patches at all.
    """
    total_target = sum(r.n_orig_patches for r in self.class_results)
    total_changed = sum(r.n_changed_patches for r in self.class_results)
    return total_changed / total_target if total_target > 0 else 0.0
prop other_change_std : float

Standard deviation, across classes, of the per-class fraction of non-target patches that changed class.

Expand source code
def other_change_std(self) -> float:
    """Standard deviation, across classes, of the per-class non-target-change fraction.

    For each entry in `self.class_results` the fraction is
    n_other_changed / n_other_patches. An entry with no non-target patches
    contributes 0.0.

    Returns:
        Population standard deviation (numpy default, ddof=0) as a Python float.
    """
    per_class_other_changes = np.array([
        r.n_other_changed / r.n_other_patches if r.n_other_patches > 0 else 0.0
        for r in self.class_results
    ])
    return float(np.std(per_class_other_changes))
prop target_change_std : float

Standard deviation, across classes, of the per-class fraction of target patches that changed class.

Expand source code
def target_change_std(self) -> float:
    """Standard deviation, across classes, of the per-class target-change fraction.

    For each entry in `self.class_results` the fraction is
    n_changed_patches / n_orig_patches. An entry with no original patches
    contributes 0.0.

    Returns:
        Population standard deviation (numpy default, ddof=0) as a Python float.
    """
    per_class_target_changes = np.array([
        r.n_changed_patches / r.n_orig_patches if r.n_orig_patches > 0 else 0.0
        for r in self.class_results
    ])
    return float(np.std(per_class_target_changes))


def to_csv_row(self) ‑> dict[str, float]

Convert to a row for the summary CSV.