Module contrib.semseg.quantitative
Functions
def argmax_logits(logits_BPC: jaxtyping.Float[Tensor, 'batch patches channels_with_null']) ‑> jaxtyping.Int[Tensor, 'batch patches']
def compute_class_results(orig_preds: jaxtyping.Int[Tensor, 'n_imgs patches'],
mod_preds: jaxtyping.Int[Tensor, 'n_imgs patches']) ‑> list[ClassResults]
def eval_auto_feat(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report
def eval_rand_feat(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report
-
Evaluates the effects of suppressing a random SAE feature.
Args
cfg
- Configuration for quantitative evaluation
sae
- Trained sparse autoencoder model
clf
- Trained classifier model
dataloader
- DataLoader providing batches of images
Returns
Report containing intervention results, including per-class changes
def eval_rand_vec(cfg: Quantitative,
sae: SparseAutoencoder,
clf: torch.nn.modules.module.Module,
dataloader) ‑> Report
-
Evaluates the effects of adding a random unit vector to the patches.
Args
cfg
- Configuration for quantitative evaluation
sae
- Trained sparse autoencoder model
clf
- Trained classifier model
dataloader
- DataLoader providing batches of images
Returns
Report containing intervention results, including per-class changes
def get_latent_lookup(cfg: Quantitative,
sae: SparseAutoencoder,
dataloader) ‑> jaxtyping.Int[Tensor, '151']
-
Dimension key:
- B: batch dimension
- P: patches per image
- D: ViT hidden dimension
- S: SAE feature dimension
- T: threshold dimension
- C: class dimension
- L: layer dimension
def get_patch_i(i: jaxtyping.Int[Tensor, 'batch width height'], n_patches_per_img: int) ‑> jaxtyping.Int[Tensor, 'batch width height']
def get_patch_mask(pixel_labels_NP: jaxtyping.UInt8[Tensor, 'n patch_px'], threshold: float) ‑> jaxtyping.Bool[Tensor, 'n']
-
Create a mask for patches where at least threshold proportion of pixels have the same label.
Args
pixel_labels_NP
- Tensor of shape [n, patch_pixels] with pixel labels
threshold
- Minimum proportion of pixels with same label
Returns
Tensor of shape [n] with True for patches that pass the threshold
def main(cfg: Quantitative)
-
Main entry point for quantitative evaluation.
def map_range(x: jaxtyping.Float[Tensor, '*batch'],
domain: tuple[float | int, float | int],
range: tuple[float | int, float | int]) ‑> jaxtyping.Float[Tensor, '*batch']
def register_hook(vit: torch.nn.modules.module.Module,
hook: Callable[[jaxtyping.Float[Tensor, '...']], jaxtyping.Float[Tensor, '...']],
layer: int,
n_patches_per_img: int)
def save(results: list[Report],
fpath: str) ‑> None
-
Save evaluation results to a CSV file.
Args
results
- List of Report objects containing evaluation results
fpath
- Path to save the CSV file
def unscaled(x: jaxtyping.Float[Tensor, '*batch'], max_obs: float | int) ‑> jaxtyping.Float[Tensor, '*batch']
-
Scale from [-10, 10] to [10 * -max_obs, 10 * max_obs].
Classes
class ClassResults (class_id: int,
class_name: str,
n_orig_patches: int,
n_changed_patches: int,
n_other_patches: int,
n_other_changed: int,
change_distribution: dict[int, int])
-
Results for a single class.
Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class ClassResults:
    """Intervention results for a single class.

    Immutable record of patch counts for one class: how many patches
    originally had this class, how many of those changed after the
    intervention, and the same two counts for patches of every other class.
    """

    class_id: int
    """Numeric identifier for the class."""
    class_name: str
    """Human-readable name of the class."""
    n_orig_patches: int
    """Original patches that were this class."""
    n_changed_patches: int
    """After intervention, how many patches changed."""
    n_other_patches: int
    """Total patches that weren't this class."""
    n_other_changed: int
    """After intervention, how many of the other patches changed."""
    # Backticks instead of angle brackets: the doc renderer treats
    # `<value>` / `<key>` as HTML tags and silently drops them.
    change_distribution: dict[int, int]
    """What classes did patches change to? Maps each destination class id
    (`key`) to the number of patches (`value`) that changed from
    `self.class_id` to that class."""
Class variables
var change_distribution : dict[int, int]
-
What classes did patches change to? Tracks how many times <value>
a patch changed from self.class_id to <key>.
var class_id : int
-
Numeric identifier for the class.
var class_name : str
-
Human-readable name of the class.
var n_changed_patches : int
-
After intervention, how many patches changed.
var n_orig_patches : int
-
Original patches that were this class.
var n_other_changed : int
-
After intervention, how many of the other patches changed.
var n_other_patches : int
-
Total patches that weren't this class.
class Report (method: str,
class_results: list[ClassResults],
intervention_scale: float)
-
Complete results from an intervention experiment.
Expand source code
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class Report:
    """Complete results from an intervention experiment.

    Bundles the per-class results of one intervention method and exposes
    aggregate statistics over them. All "change" statistics are ratios
    (changed / total), not values scaled to 0-100.
    """

    method: str
    """Which intervention method was used."""
    class_results: list[ClassResults]
    """Per-class detailed results."""
    intervention_scale: float
    """Magnitude of intervention."""

    @property
    def mean_target_change(self) -> float:
        """Fraction of target patches that changed class.

        Pooled over all classes: sum of `n_changed_patches` divided by the
        sum of `n_orig_patches`. Returns 0.0 when that denominator is not
        positive.
        """
        total_target = sum(r.n_orig_patches for r in self.class_results)
        total_changed = sum(r.n_changed_patches for r in self.class_results)
        return total_changed / total_target if total_target > 0 else 0.0

    @property
    def mean_other_change(self) -> float:
        """Fraction of non-target patches that changed class.

        Pooled over all classes: sum of `n_other_changed` divided by the
        sum of `n_other_patches`. Returns 0.0 when that denominator is not
        positive.
        """
        total_other = sum(r.n_other_patches for r in self.class_results)
        total_changed = sum(r.n_other_changed for r in self.class_results)
        return total_changed / total_other if total_other > 0 else 0.0

    @property
    def target_change_std(self) -> float:
        """Standard deviation of the per-class target change fraction.

        Classes with no original patches contribute 0.0. With an empty
        `class_results`, `np.std` of an empty array yields nan.
        """
        per_class_target_changes = np.array([
            r.n_changed_patches / r.n_orig_patches if r.n_orig_patches > 0 else 0.0
            for r in self.class_results
        ])
        return float(np.std(per_class_target_changes))

    @property
    def other_change_std(self) -> float:
        """Standard deviation of the per-class non-target change fraction.

        Classes with no non-target patches contribute 0.0. With an empty
        `class_results`, `np.std` of an empty array yields nan.
        """
        per_class_other_changes = np.array([
            r.n_other_changed / r.n_other_patches if r.n_other_patches > 0 else 0.0
            for r in self.class_results
        ])
        return float(np.std(per_class_other_changes))

    def to_csv_row(self) -> dict[str, str | float]:
        """Convert to a row for the summary CSV.

        Returns:
            A dict with the method name under "method" (a string) and the
            four aggregate statistics as floats. `intervention_scale` is
            not part of the row.
        """
        return {
            "method": self.method,
            "target_change": self.mean_target_change,
            "other_change": self.mean_other_change,
            "target_std": self.target_change_std,
            "other_std": self.other_change_std,
        }
Class variables
var class_results : list[ClassResults]
-
Per-class detailed results.
var intervention_scale : float
-
Magnitude of intervention.
var method : str
-
Which intervention method was used.
Instance variables
prop mean_other_change : float
-
Percentage of non-target patches that changed class.
Expand source code
@property
def mean_other_change(self) -> float:
    """Ratio of changed non-target patches to all non-target patches.

    Counts are pooled across every entry in `self.class_results`. Returns
    0.0 when the pooled non-target count is not positive.
    """
    pooled_other = 0
    pooled_changed = 0
    for res in self.class_results:
        pooled_other += res.n_other_patches
        pooled_changed += res.n_other_changed

    # Guard against dividing by zero when there are no non-target patches.
    if pooled_other > 0:
        return pooled_changed / pooled_other
    return 0.0
prop mean_target_change : float
-
Percentage of target patches that changed class.
Expand source code
@property
def mean_target_change(self) -> float:
    """Ratio of changed target patches to all target patches.

    Counts are pooled across every entry in `self.class_results`. Returns
    0.0 when the pooled target count is not positive.
    """
    pooled_target = 0
    pooled_changed = 0
    for res in self.class_results:
        pooled_target += res.n_orig_patches
        pooled_changed += res.n_changed_patches

    # Guard against dividing by zero when there are no target patches.
    if pooled_target > 0:
        return pooled_changed / pooled_target
    return 0.0
prop other_change_std : float
-
Standard deviation of non-target patch changes across classes.
Expand source code
@property
def other_change_std(self) -> float:
    """Spread (standard deviation) of the per-class non-target change ratio.

    Each entry in `self.class_results` contributes
    `n_other_changed / n_other_patches`, or 0.0 when it has no non-target
    patches.
    """
    ratios = []
    for res in self.class_results:
        if res.n_other_patches > 0:
            ratios.append(res.n_other_changed / res.n_other_patches)
        else:
            # No non-target patches for this class: count it as no change.
            ratios.append(0.0)
    spread = np.std(np.array(ratios))
    return float(spread)
prop target_change_std : float
-
Standard deviation of change percentage across classes.
Expand source code
@property
def target_change_std(self) -> float:
    """Spread (standard deviation) of the per-class target change ratio.

    Each entry in `self.class_results` contributes
    `n_changed_patches / n_orig_patches`, or 0.0 when it has no original
    patches.
    """
    ratios = []
    for res in self.class_results:
        if res.n_orig_patches > 0:
            ratios.append(res.n_changed_patches / res.n_orig_patches)
        else:
            # Class never appeared in the original predictions: no change.
            ratios.append(0.0)
    spread = np.std(np.array(ratios))
    return float(spread)
Methods
def to_csv_row(self) ‑> dict[str, float]
-
Convert to a row for the summary CSV.