Module contrib.semseg.quantitative

Functions

def argmax_logits(logits_BPC: jaxtyping.Float[Tensor, 'batch patches channels_with_null']) ‑> jaxtyping.Int[Tensor, 'batch patches']
def compute_class_results(orig_preds: jaxtyping.Int[Tensor, 'n_imgs patches'],
mod_preds: jaxtyping.Int[Tensor, 'n_imgs patches']) ‑> list[ClassResults]
def eval_auto_feat(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report
def eval_rand_feat(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report

Evaluates the effects of suppressing a random SAE feature.

Args

cfg
Configuration for quantitative evaluation
sae
Trained sparse autoencoder model
clf
Trained classifier model
dataloader
DataLoader providing batches of images

Returns

Report containing intervention results, including per-class changes

def eval_rand_vec(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report

Evaluates the effects of adding a random unit vector to the patches.

Args

cfg
Configuration for quantitative evaluation
sae
Trained sparse autoencoder model
clf
Trained classifier model
dataloader
DataLoader providing batches of images

Returns

Report containing intervention results, including per-class changes

def get_latent_lookup(cfg: Quantitative,
sae: SparseAutoencoder,
dataloader) ‑> jaxtyping.Int[Tensor, '151']

Dimension key:

  • B: batch dimension
  • P: patches per image
  • D: ViT hidden dimension
  • S: SAE feature dimension
  • T: threshold dimension
  • C: class dimension
  • L: layer dimension
def get_patch_i(i: jaxtyping.Int[Tensor, 'batch width height'], n_patches_per_img: int) ‑> jaxtyping.Int[Tensor, 'batch width height']
def get_patch_mask(pixel_labels_NP: jaxtyping.UInt8[Tensor, 'n patch_px'], threshold: float) ‑> jaxtyping.Bool[Tensor, 'n']

Create a mask selecting patches in which at least a `threshold` proportion of the pixels share the same label.

Args

pixel_labels_NP
Tensor of shape [n, patch_pixels] with pixel labels
threshold
Minimum proportion of pixels with same label

Returns

Tensor of shape [n] with True for patches that pass the threshold

def main(cfg: Quantitative)

Main entry point for quantitative evaluation.

def map_range(x: jaxtyping.Float[Tensor, '*batch'],
domain: tuple[float | int, float | int],
range: tuple[float | int, float | int]) ‑> jaxtyping.Float[Tensor, '*batch']
def register_hook(vit: torch.nn.modules.module.Module,
hook: Callable[[jaxtyping.Float[Tensor, '...']], jaxtyping.Float[Tensor, '...']],
layer: int,
n_patches_per_img: int)
def save(results: list[Report],
fpath: str) ‑> None

Save evaluation results to a CSV file.

Args

results
List of Report objects containing evaluation results
fpath
Path to save the CSV file
def unscaled(x: jaxtyping.Float[Tensor, '*batch'], max_obs: float | int) ‑> jaxtyping.Float[Tensor, '*batch']

Scale from [-10, 10] to [10 * -max_obs, 10 * max_obs].

Classes

class ClassResults (class_id: int,
class_name: str,
n_orig_patches: int,
n_changed_patches: int,
n_other_patches: int,
n_other_changed: int,
change_distribution: dict[int, int])

Results for a single class.

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class ClassResults:
    """Results for a single class.

    Patch counts are split into the patches originally assigned this class and
    all other patches, recording how many of each changed after intervention.
    """

    class_id: int
    """Numeric identifier for the class."""

    class_name: str
    """Human-readable name of the class."""

    n_orig_patches: int
    """Original patches that were this class."""

    n_changed_patches: int
    """After intervention, how many of this class's original patches changed class."""

    n_other_patches: int
    """Total patches that weren't this class."""

    n_other_changed: int
    """After intervention, how many of the other patches changed."""

    change_distribution: dict[int, int]
    """What classes did patches change to? Maps each destination class ID (key) to the number of times (value) a patch changed from ``self.class_id`` to that class."""

Class variables

var change_distribution : dict[int, int]

What classes did patches change to? Maps each destination class ID (key) to the number of times (value) a patch changed from self.class_id to that class.

var class_id : int

Numeric identifier for the class.

var class_name : str

Human-readable name of the class.

var n_changed_patches : int

After intervention, how many patches changed.

var n_orig_patches : int

Original patches that were this class.

var n_other_changed : int

After intervention, how many of the other patches changed.

var n_other_patches : int

Total patches that weren't this class.

class Report (method: str,
class_results: list[ClassResults],
intervention_scale: float)

Complete results from an intervention experiment.

Expand source code
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class Report:
    """Complete results from an intervention experiment.

    The ``mean_*`` properties pool patch counts over all classes. The ``*_std``
    properties measure the spread of the per-class change fractions. All of
    them return 0.0 when there is nothing to measure.
    """

    method: str
    """Which intervention method was used."""

    class_results: list[ClassResults]
    """Per-class detailed results."""

    intervention_scale: float
    """Magnitude of intervention."""

    @property
    def mean_target_change(self) -> float:
        """Fraction (0 to 1) of target patches, pooled over all classes, that changed class.

        Returns 0.0 when no class has any original patches.
        """
        total_target = sum(r.n_orig_patches for r in self.class_results)
        total_changed = sum(r.n_changed_patches for r in self.class_results)
        return total_changed / total_target if total_target > 0 else 0.0

    @property
    def mean_other_change(self) -> float:
        """Fraction (0 to 1) of non-target patches, pooled over all classes, that changed class.

        Returns 0.0 when there are no non-target patches.
        """
        total_other = sum(r.n_other_patches for r in self.class_results)
        total_changed = sum(r.n_other_changed for r in self.class_results)
        return total_changed / total_other if total_other > 0 else 0.0

    @property
    def target_change_std(self) -> float:
        """Standard deviation, across classes, of the fraction of target patches that changed.

        A class with no original patches contributes 0.0. Returns 0.0 when
        there are no class results, matching ``mean_target_change``.
        """
        # np.std of an empty array is NaN (with a RuntimeWarning); guard it.
        if not self.class_results:
            return 0.0
        per_class_target_changes = np.array([
            r.n_changed_patches / r.n_orig_patches if r.n_orig_patches > 0 else 0.0
            for r in self.class_results
        ])
        return float(np.std(per_class_target_changes))

    @property
    def other_change_std(self) -> float:
        """Standard deviation, across classes, of the fraction of non-target patches that changed.

        A class with no non-target patches contributes 0.0. Returns 0.0 when
        there are no class results, matching ``mean_other_change``.
        """
        # np.std of an empty array is NaN (with a RuntimeWarning); guard it.
        if not self.class_results:
            return 0.0
        per_class_other_changes = np.array([
            r.n_other_changed / r.n_other_patches if r.n_other_patches > 0 else 0.0
            for r in self.class_results
        ])
        return float(np.std(per_class_other_changes))

    def to_csv_row(self) -> dict[str, str | float]:
        """Convert to a row for the summary CSV.

        ``"method"`` holds the method name (a string). The remaining columns are
        the float summary statistics. ``intervention_scale`` and the per-class
        details are not included.
        """
        return {
            "method": self.method,
            "target_change": self.mean_target_change,
            "other_change": self.mean_other_change,
            "target_std": self.target_change_std,
            "other_std": self.other_change_std,
        }

Class variables

var class_results : list[ClassResults]

Per-class detailed results.

var intervention_scale : float

Magnitude of intervention.

var method : str

Which intervention method was used.

Instance variables

prop mean_other_change : float

Fraction (0 to 1) of non-target patches, pooled over all classes, that changed class.

Expand source code
@property
def mean_other_change(self) -> float:
    """Fraction (0 to 1) of non-target patches, pooled over all classes, that changed class.

    Returns 0.0 when there are no non-target patches at all.
    """
    other_total = 0
    other_changed = 0
    for res in self.class_results:
        other_total += res.n_other_patches
        other_changed += res.n_other_changed
    if other_total <= 0:
        return 0.0
    return other_changed / other_total
prop mean_target_change : float

Fraction (0 to 1) of target patches, pooled over all classes, that changed class.

Expand source code
@property
def mean_target_change(self) -> float:
    """Fraction (0 to 1) of target patches, pooled over all classes, that changed class.

    Returns 0.0 when no class has any original patches.
    """
    changed = sum(res.n_changed_patches for res in self.class_results)
    target = sum(res.n_orig_patches for res in self.class_results)
    if target > 0:
        return changed / target
    return 0.0
prop other_change_std : float

Standard deviation, across classes, of the fraction of non-target patches that changed.

Expand source code
@property
def other_change_std(self) -> float:
    """Standard deviation, across classes, of the fraction of non-target patches that changed.

    A class with no non-target patches contributes 0.0. Returns 0.0 when there
    are no class results, matching ``mean_other_change``.
    """
    # np.std of an empty array is NaN (with a RuntimeWarning); guard it.
    if not self.class_results:
        return 0.0
    per_class_other_changes = np.array([
        r.n_other_changed / r.n_other_patches if r.n_other_patches > 0 else 0.0
        for r in self.class_results
    ])
    return float(np.std(per_class_other_changes))
prop target_change_std : float

Standard deviation, across classes, of the fraction of target patches that changed.

Expand source code
@property
def target_change_std(self) -> float:
    """Standard deviation, across classes, of the fraction of target patches that changed.

    A class with no original patches contributes 0.0. Returns 0.0 when there
    are no class results, matching ``mean_target_change``.
    """
    # np.std of an empty array is NaN (with a RuntimeWarning); guard it.
    if not self.class_results:
        return 0.0
    per_class_target_changes = np.array([
        r.n_changed_patches / r.n_orig_patches if r.n_orig_patches > 0 else 0.0
        for r in self.class_results
    ])
    return float(np.std(per_class_target_changes))

Methods

def to_csv_row(self) ‑> dict[str, float]

Convert to a row for the summary CSV.