From dfcc3e11440cce42b9a2df1cefcaf32340fd4ef8 Mon Sep 17 00:00:00 2001 From: "Junbo Shen (Jacob)" Date: Thu, 12 Mar 2026 22:28:08 +0800 Subject: [PATCH] gradcam --- docs/api/interpret.rst | 10 + .../pyhealth.interpret.methods.gradcam.rst | 62 ++++ docs/tutorials.rst | 2 + examples/cxr/gradcam_cxr_tutorial.py | 181 ++++++++++ pyhealth/interpret/methods/__init__.py | 4 +- pyhealth/interpret/methods/gradcam.py | 311 ++++++++++++++++++ tests/core/test_gradcam.py | 226 +++++++++++++ 7 files changed, 795 insertions(+), 1 deletion(-) create mode 100644 docs/api/interpret/pyhealth.interpret.methods.gradcam.rst create mode 100644 examples/cxr/gradcam_cxr_tutorial.py create mode 100644 pyhealth/interpret/methods/gradcam.py create mode 100644 tests/core/test_gradcam.py diff --git a/docs/api/interpret.rst b/docs/api/interpret.rst index 747f05b56..8bf6d9ca7 100644 --- a/docs/api/interpret.rst +++ b/docs/api/interpret.rst @@ -67,6 +67,15 @@ New to interpretability in PyHealth? Check out these complete examples: - Test various distance kernels (cosine vs euclidean) and sample sizes - Decode attributions to human-readable medical codes and lab measurements +**Grad-CAM Example:** + +- ``examples/cxr/gradcam_cxr_tutorial.py`` - Demonstrates Grad-CAM for CNN-based medical image classification. Shows how to: + + - Choose a target convolutional layer from a PyHealth image model + - Generate class-conditional heatmaps for chest X-ray images + - Overlay the Grad-CAM heatmap on the original image for interpretation + - Run the example from a dataset path without editing the source file + These examples provide end-to-end workflows from loading data to interpreting and evaluating attributions. 
Attribution Methods @@ -82,6 +91,7 @@ Attribution Methods interpret/pyhealth.interpret.methods.integrated_gradients interpret/pyhealth.interpret.methods.shap interpret/pyhealth.interpret.methods.lime + interpret/pyhealth.interpret.methods.gradcam Visualization Utilities ----------------------- diff --git a/docs/api/interpret/pyhealth.interpret.methods.gradcam.rst b/docs/api/interpret/pyhealth.interpret.methods.gradcam.rst new file mode 100644 index 000000000..ece2ea9cf --- /dev/null +++ b/docs/api/interpret/pyhealth.interpret.methods.gradcam.rst @@ -0,0 +1,62 @@ +pyhealth.interpret.methods.gradcam +================================== + +Overview +-------- + +Grad-CAM provides class-conditional heatmaps for CNN-based image +classification models in PyHealth. It uses gradients from a target +convolutional layer to highlight which image regions contributed most to the +selected prediction. + +This method is intended for: + +- CNN image classification models +- chest X-ray and other medical imaging workflows built on PyHealth image tasks +- models that return either ``logit`` or ``y_prob`` + +Usage Notes +----------- + +1. **CNN model**: Grad-CAM requires a 4D convolutional activation map from + the target layer. +2. **Target layer**: You can pass either an ``nn.Module`` directly or a dotted + string path such as ``"model.layer4.1.conv2"``. +3. **Class selection**: If ``class_index`` is omitted, Grad-CAM uses the + predicted class. For single-output binary models, it attributes to that + scalar output. +4. **Gradients required**: Do not call ``attribute()`` inside + ``torch.no_grad()``. +5. **Return shape**: ``attribute()`` returns ``{input_key: cam}`` where the CAM + tensor has shape ``[B, H, W]``. + +Quick Start +----------- + +.. 
code-block:: python + + from pyhealth.interpret.methods import GradCAM + from pyhealth.interpret.utils import visualize_image_attr + + gradcam = GradCAM( + model, + target_layer=model.model.layer4[-1].conv2, + input_key="image", + ) + cams = gradcam.attribute(**batch) + image, heatmap, overlay = visualize_image_attr( + image=batch["image"][0], + attribution=cams["image"][0], + ) + +For a complete script example, see: +``examples/cxr/gradcam_cxr_tutorial.py`` + +API Reference +------------- + +.. autoclass:: pyhealth.interpret.methods.gradcam.GradCAM + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/tutorials.rst b/docs/tutorials.rst index fcdab84ea..774ac3575 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -208,6 +208,8 @@ These examples are located in ``examples/cxr/``. - Conformal prediction for COVID-19 CXR classification * - ``cxr/cnn_cxr.ipynb`` - CNN for chest X-ray classification (notebook) + * - ``cxr/gradcam_cxr_tutorial.py`` + - Grad-CAM for CNN-based chest X-ray classification * - ``cxr/chestxray14_binary_classification.ipynb`` - Binary classification on ChestX-ray14 dataset (notebook) * - ``cxr/chestxray14_multilabel_classification.ipynb`` diff --git a/examples/cxr/gradcam_cxr_tutorial.py b/examples/cxr/gradcam_cxr_tutorial.py new file mode 100644 index 000000000..708966d47 --- /dev/null +++ b/examples/cxr/gradcam_cxr_tutorial.py @@ -0,0 +1,181 @@ +"""Grad-CAM tutorial for CNN-based chest X-ray classification in PyHealth. + +Prerequisites: +- A local COVID-19 Radiography Database root passed with ``--root`` + +Notes: +- For meaningful class-specific visualizations, pass ``--checkpoint`` with a + trained PyHealth checkpoint. Without a checkpoint, the script still runs as a + pipeline example, but the classification head is randomly initialized. +- ``--weights DEFAULT`` may trigger a first-run torchvision download. Use + ``--weights none`` for an offline run. 
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the Grad-CAM tutorial.

    Returns:
        argparse.Namespace: Parsed CLI arguments controlling dataset location
        (``root``), optional ``checkpoint``, figure ``output`` path, runtime
        ``device``, and backbone ``weights``.
    """
    parser = argparse.ArgumentParser(
        description="Run Grad-CAM on one chest X-ray sample.",
    )
    parser.add_argument(
        "--root",
        required=True,
        help="Path to the COVID-19 Radiography Database root directory.",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="Optional checkpoint to load before inference.",
    )
    parser.add_argument(
        "--output",
        default="gradcam_cxr_overlay.png",
        help="Where to save the Grad-CAM figure.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Optional device override such as 'cpu' or 'cuda:0'.",
    )
    parser.add_argument(
        "--weights",
        choices=["DEFAULT", "none"],
        default="DEFAULT",
        help="Torchvision backbone weights to use when initializing resnet18.",
    )
    return parser.parse_args()


def resolve_device(device_arg: str | None) -> str:
    """Resolve the device string used for inference.

    Args:
        device_arg: Optional CLI override such as ``"cpu"`` or ``"cuda:0"``.

    Returns:
        str: ``device_arg`` when given; otherwise ``"cuda:0"`` if CUDA is
        available and ``"cpu"`` if not.
    """
    if device_arg is not None:
        return device_arg
    return "cuda:0" if torch.cuda.is_available() else "cpu"


def load_dataset(root: str) -> SampleDataset:
    """Load the COVID-19 CXR sample dataset for the tutorial.

    Args:
        root: Root directory containing the COVID-19 Radiography Database.

    Returns:
        SampleDataset: Task-applied sample dataset ready for dataloader use.

    Raises:
        SystemExit: If an ``ImportError`` mentioning ``openpyxl`` is raised
            while building the dataset.
        ImportError: Any other import failure is re-raised unchanged.
    """
    try:
        dataset = COVID19CXRDataset(root, num_workers=1)
        return dataset.set_task(num_workers=1)
    except ImportError as exc:
        if "openpyxl" in str(exc):
            raise SystemExit(
                "This example needs 'openpyxl' to read the raw metadata sheets. "
                "Install it with: pip install openpyxl"
            ) from exc
        raise


def main() -> None:
    """Run Grad-CAM on a single chest X-ray sample and save a figure.

    Steps: load the dataset, take the first batch (batch size 1), build a
    ``resnet18`` ``TorchvisionModel``, optionally load a checkpoint, predict
    the class, compute the Grad-CAM heatmap for that class, and write a
    three-panel figure (input, heatmap, overlay) to ``--output``.

    Raises:
        SystemExit: If the dataset root or the checkpoint path does not exist.
    """
    args = parse_args()
    root = Path(args.root).expanduser()
    if not root.exists():
        raise SystemExit(f"Dataset root does not exist: {root}")

    sample_dataset = load_dataset(str(root))
    loader = get_dataloader(sample_dataset, batch_size=1, shuffle=False)
    batch = next(iter(loader))

    weights = None if args.weights == "none" else "DEFAULT"
    model = TorchvisionModel(
        dataset=sample_dataset,
        model_name="resnet18",
        model_config={"weights": weights},
    )
    device = resolve_device(args.device)
    model = model.to(device)
    model.eval()

    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint).expanduser()
        if not checkpoint_path.exists():
            raise SystemExit(f"Checkpoint does not exist: {checkpoint_path}")
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a checkpoint file taken from the command line cannot
        # execute arbitrary code while being loaded. The file is passed
        # straight to load_state_dict, so a plain state dict is all we need.
        state_dict = torch.load(
            checkpoint_path,
            map_location=device,
            weights_only=True,
        )
        model.load_state_dict(state_dict)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print(
            "Warning: no checkpoint provided. The classifier head is randomly "
            "initialized, so this run is only a pipeline example."
        )

    # The prediction pass needs no gradients; Grad-CAM runs its own forward
    # pass with gradients enabled below.
    with torch.no_grad():
        y_prob = model(**batch)["y_prob"][0]

    label_vocab = sample_dataset.output_processors["disease"].label_vocab
    pred_class = int(torch.argmax(y_prob).item())
    # label_vocab maps label -> index; invert it to name the predicted class.
    id2label = {value: key for key, value in label_vocab.items()}
    pred_label = id2label[pred_class]

    gradcam = GradCAM(
        model,
        target_layer=model.model.layer4[-1].conv2,
        input_key="image",
    )
    cam = gradcam.attribute(class_index=pred_class, **batch)["image"]

    image, heatmap, overlay = visualize_image_attr(
        image=batch["image"][0],
        attribution=cam[0],
    )

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        axes[0].imshow(image, cmap="gray")
        axes[0].set_title("Input")
        axes[0].axis("off")

        axes[1].imshow(heatmap, cmap="jet")
        axes[1].set_title("Grad-CAM")
        axes[1].axis("off")

        axes[2].imshow(overlay)
        axes[2].set_title(f"Overlay: {pred_label}")
        axes[2].axis("off")

        output_path = Path(args.output).expanduser()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls to main() do not accumulate open figures.
        plt.close(fig)

    print(f"Predicted class: {pred_label}")
    print(f"Saved Grad-CAM visualization to {output_path.resolve()}")


if __name__ == "__main__":
    main()
class GradCAM(BaseInterpreter):
    """Compute Grad-CAM heatmaps for CNN-based image classifiers.

    Grad-CAM builds a class-conditional localization map: the channels of a
    target layer's activation map are weighted by the spatially averaged
    gradient of the selected class score, summed over channels, and passed
    through ReLU. Results are returned in a dict keyed by ``input_key``.

    Args:
        model: Model to interpret. It is handed to
            ``BaseInterpreter.__init__`` and read back as ``self.model``.
        target_layer: Target convolutional layer as an ``nn.Module`` that is
            a submodule of ``model``, or a dotted string path (for example
            ``"model.layer4.1.conv2"``).
        input_key: Batch key containing the image tensor. Default is
            ``"image"``.

    Attributes:
        input_key: Batch key of the image tensor.
        target_layer: Resolved ``nn.Module`` whose output is explained.
        last_target_class: CPU tensor holding the class index explained for
            each batch element by the most recent successful ``attribute()``
            call, or ``None`` before the first call.

    Examples:
        >>> gradcam = GradCAM(model, target_layer=model.model.layer4[-1].conv2)
        >>> batch = next(iter(test_loader))
        >>> cams = gradcam.attribute(**batch)
        >>> cams["image"].shape
        torch.Size([1, 224, 224])
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: str | nn.Module,
        input_key: str = "image",
    ) -> None:
        super().__init__(model)
        self.input_key = input_key
        self.target_layer = self._resolve_target_layer(target_layer)
        self.last_target_class: Optional[torch.Tensor] = None

    def attribute(
        self,
        class_index: Optional[int | torch.Tensor] = None,
        normalize: bool = True,
        upsample: bool = True,
        **data,
    ) -> Dict[str, torch.Tensor]:
        """Compute Grad-CAM heatmaps for a batch of images.

        The gradient of the selected class score is taken with respect to the
        target layer's activation only (``torch.autograd.grad``), so the
        ``.grad`` fields of the model parameters are neither cleared nor
        written by this call. The returned heatmaps are detached from the
        autograd graph.

        Args:
            class_index: Target class to explain. ``None`` uses the model's
                predicted class. May be an ``int``, a scalar tensor, or a 1D
                tensor (or array-like) with one index per batch element. For
                score tensors with a single output channel the only valid
                explicit value is ``0``.
            normalize: If ``True``, min-max normalize each heatmap to
                ``[0, 1]``.
            upsample: If ``True``, bilinearly resize CAMs to the spatial size
                of the input image; otherwise they keep the target layer's
                spatial size.
            **data: Batched model inputs, including the image tensor under
                ``input_key`` and any labels/metadata the model needs.

        Returns:
            Dict[str, torch.Tensor]: ``{input_key: cams}`` where ``cams`` has
            shape ``[B, H, W]``.

        Raises:
            KeyError: If ``input_key`` is missing from ``data``.
            ValueError: If the image is not a 4D tensor, the target layer
                does not output a 4D tensor, the model output lacks both
                ``logit`` and ``y_prob``, or ``class_index`` is invalid.
            RuntimeError: If gradients are unavailable (for example inside
                ``torch.no_grad()``), or the target layer's activation was
                not captured or does not influence the selected score.
        """
        if self.input_key not in data:
            raise KeyError(
                f"Expected input key '{self.input_key}' in the attribution batch."
            )
        image_tensor = data[self.input_key]
        if not torch.is_tensor(image_tensor):
            raise ValueError(
                f"Grad-CAM requires '{self.input_key}' to be a batched image tensor."
            )
        if image_tensor.dim() != 4:
            raise ValueError("Grad-CAM requires image tensors with shape [B, C, H, W].")

        captured: dict[str, torch.Tensor] = {}

        def forward_hook(_module, _inputs, output):
            if not torch.is_tensor(output):
                raise ValueError("Grad-CAM target layer must output a tensor.")
            if output.dim() != 4:
                raise ValueError(
                    "Grad-CAM requires a 4D convolutional activation map from "
                    "the target layer."
                )
            if not output.requires_grad:
                if not torch.is_grad_enabled():
                    raise RuntimeError(
                        "Grad-CAM requires gradients. Do not call attribute() inside "
                        "torch.no_grad()."
                    )
                # Grad mode is on, so the usual cause is a model whose
                # parameters up to this layer are all frozen.
                raise RuntimeError(
                    "Grad-CAM requires gradients, but the target layer output "
                    "does not require grad. Ensure the target layer or an "
                    "earlier layer has parameters with requires_grad=True."
                )
            captured["activation"] = output

        hook = self.target_layer.register_forward_hook(forward_hook)
        try:
            score_tensor = self._forward_score_tensor(data)
        finally:
            # The hook is only needed for this one forward pass; always remove
            # it so a failure cannot leave it attached to the model.
            hook.remove()

        target_indices, target_scores = self._select_target_scores(
            score_tensor=score_tensor,
            class_index=class_index,
        )

        activation = captured.get("activation")
        if activation is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations and gradients."
            )

        # Differentiate w.r.t. the activation only: no parameter .grad is
        # touched, so no zero_grad() bookkeeping is needed around the call.
        (gradient,) = torch.autograd.grad(
            outputs=target_scores.sum(),
            inputs=activation,
            allow_unused=True,
        )
        if gradient is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations and gradients."
            )

        # Detach so the returned heatmap does not keep the forward graph alive.
        cams = self._compute_cam(
            activations=activation.detach(),
            gradients=gradient.detach(),
        )

        if upsample:
            cams = F.interpolate(
                cams.unsqueeze(1),
                size=image_tensor.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)

        if normalize:
            cams = self._normalize_cam(cams)

        # Keep target indices accessible for debugging and examples.
        self.last_target_class = target_indices.detach().cpu()
        return {self.input_key: cams}

    def _resolve_target_layer(self, target_layer: str | nn.Module) -> nn.Module:
        """Return the ``nn.Module`` designated by ``target_layer``.

        Args:
            target_layer: A submodule of ``self.model`` or a dotted path in
                which purely numeric parts index into the current container
                and other parts are attribute names.

        Returns:
            nn.Module: The resolved layer.

        Raises:
            ValueError: If a module is given that is not part of the model,
                the path is empty or not a string, any path component cannot
                be resolved, or the path does not end at an ``nn.Module``.
        """
        if isinstance(target_layer, nn.Module):
            if not any(module is target_layer for module in self.model.modules()):
                raise ValueError(
                    "Grad-CAM target_layer must be a submodule of the model."
                )
            return target_layer
        if not isinstance(target_layer, str) or not target_layer:
            raise ValueError("target_layer must be a non-empty string or nn.Module.")

        current = self.model
        for part in target_layer.split("."):
            if part.isdigit():
                try:
                    current = current[int(part)]
                except (IndexError, KeyError, TypeError, AttributeError) as exc:
                    raise ValueError(
                        f"Could not resolve target layer index '{part}' in "
                        f"path '{target_layer}'."
                    ) from exc
            else:
                if not hasattr(current, part):
                    raise ValueError(
                        f"Could not resolve target layer path '{target_layer}'. "
                        f"Missing attribute '{part}'."
                    )
                current = getattr(current, part)

        if not isinstance(current, nn.Module):
            raise ValueError(
                f"Resolved target layer '{target_layer}' is not an nn.Module."
            )
        return current

    def _forward_score_tensor(self, data: dict) -> torch.Tensor:
        """Run the model once and return the class-score tensor.

        Uses the ``TorchvisionModel`` logit path when it applies; otherwise
        calls ``self.model(**data)`` and extracts ``logit``/``y_prob``.
        """
        score_tensor = self._forward_torchvision_logits(data)
        if score_tensor is not None:
            return score_tensor
        outputs = self.model(**data)
        return self._resolve_score_tensor(outputs)

    def _forward_torchvision_logits(self, data: dict) -> Optional[torch.Tensor]:
        """Return backbone logits when the model is a ``TorchvisionModel``.

        The wrapped backbone ``self.model.model`` is called directly on the
        image tensor (moved to ``self.model.device``); single-channel images
        are tiled to three channels first.

        Returns:
            Optional[torch.Tensor]: Backbone output, or ``None`` when
            ``TorchvisionModel`` cannot be imported or the model is not one.
        """
        try:
            from pyhealth.models.torchvision_model import TorchvisionModel
        except Exception:
            # Deliberately broad: any failure importing the optional wrapper
            # just means this fast path does not apply; fall back to the
            # generic dict-output path.
            return None

        if not isinstance(self.model, TorchvisionModel):
            return None

        image_tensor = data[self.input_key].to(self.model.device)
        if image_tensor.shape[1] == 1:
            image_tensor = image_tensor.repeat((1, 3, 1, 1))
        return self.model.model(image_tensor)

    @staticmethod
    def _resolve_score_tensor(outputs: dict) -> torch.Tensor:
        """Extract the class-score tensor from a model output dict.

        ``logit`` is preferred over ``y_prob`` when both are present.

        Raises:
            ValueError: If ``outputs`` is not a dict, has neither key, or the
                selected value is not a 1D/2D tensor.
        """
        if not isinstance(outputs, dict):
            raise ValueError(
                "Grad-CAM expects model outputs to be a dict containing "
                "'logit' or 'y_prob'."
            )
        if "logit" in outputs:
            score_tensor = outputs["logit"]
        elif "y_prob" in outputs:
            score_tensor = outputs["y_prob"]
        else:
            raise ValueError(
                "Grad-CAM requires model outputs to contain 'logit' or 'y_prob'."
            )
        if not torch.is_tensor(score_tensor):
            raise ValueError(
                "Grad-CAM requires 'logit' or 'y_prob' to be a torch.Tensor."
            )
        if score_tensor.dim() not in (1, 2):
            raise ValueError(
                "Grad-CAM requires classification scores shaped [B] or [B, C]."
            )
        return score_tensor

    @staticmethod
    def _select_target_scores(
        score_tensor: torch.Tensor,
        class_index: Optional[int | torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick the per-sample score to differentiate.

        Args:
            score_tensor: Scores shaped ``[B]`` or ``[B, C]``.
            class_index: ``None`` (use argmax), an ``int``, a scalar tensor,
                or a 1D tensor/array-like with ``B`` entries.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(target_indices,
            target_scores)``, both of shape ``[B]``.

        Raises:
            ValueError: If ``class_index`` cannot be interpreted, is out of
                range, has the wrong shape, or is non-zero for a
                single-output model.
        """
        if score_tensor.dim() == 1:
            score_tensor = score_tensor.unsqueeze(-1)

        batch_size = score_tensor.shape[0]
        device = score_tensor.device

        # Accept lists / NumPy integers / arrays by converting them once,
        # instead of failing later with an AttributeError on ``.to``.
        if (
            class_index is not None
            and not isinstance(class_index, int)
            and not torch.is_tensor(class_index)
        ):
            try:
                class_index = torch.as_tensor(class_index)
            except (TypeError, ValueError, RuntimeError) as exc:
                raise ValueError(
                    "class_index must be an int, a tensor, or None."
                ) from exc

        if score_tensor.shape[-1] == 1:
            if class_index is not None:
                if isinstance(class_index, int):
                    if class_index != 0:
                        raise ValueError(
                            "Single-output Grad-CAM only supports class_index=0."
                        )
                else:
                    class_index = class_index.to(device)
                    if torch.any(class_index != 0):
                        raise ValueError(
                            "Single-output Grad-CAM only supports class_index=0."
                        )
            target_indices = torch.zeros(batch_size, dtype=torch.long, device=device)
            target_scores = score_tensor.reshape(batch_size)
            return target_indices, target_scores

        num_classes = score_tensor.shape[-1]
        if class_index is None:
            target_indices = torch.argmax(score_tensor, dim=-1)
        elif isinstance(class_index, int):
            if class_index < 0 or class_index >= num_classes:
                raise ValueError(
                    f"class_index must be in [0, {num_classes - 1}] for this model."
                )
            target_indices = torch.full(
                (batch_size,),
                class_index,
                dtype=torch.long,
                device=device,
            )
        else:
            target_indices = class_index.to(device).long()
            if target_indices.dim() == 0:
                # A scalar tensor applies to every element of the batch.
                target_indices = torch.full(
                    (batch_size,),
                    int(target_indices.item()),
                    dtype=torch.long,
                    device=device,
                )
            elif target_indices.dim() != 1:
                raise ValueError("Tensor class_index must be a scalar or a 1D tensor.")
            elif target_indices.shape[0] != batch_size:
                raise ValueError(
                    "Tensor class_index must have one target per batch element."
                )
            if torch.any((target_indices < 0) | (target_indices >= num_classes)):
                raise ValueError(
                    f"class_index values must be in [0, {num_classes - 1}]."
                )

        target_scores = score_tensor.gather(1, target_indices.unsqueeze(1)).squeeze(1)
        return target_indices, target_scores

    @staticmethod
    def _compute_cam(
        activations: torch.Tensor,
        gradients: torch.Tensor,
    ) -> torch.Tensor:
        """Combine 4D activations and gradients into ``[B, H, W]`` CAMs.

        Channel weights are the gradients averaged over the two spatial
        dimensions; the weighted activations are summed over channels and
        clipped at zero with ReLU.
        """
        weights = gradients.mean(dim=(2, 3), keepdim=True)
        cams = torch.relu((weights * activations).sum(dim=1))
        return cams

    @staticmethod
    def _normalize_cam(cams: torch.Tensor) -> torch.Tensor:
        """Min-max normalize each CAM in the batch to ``[0, 1]``.

        The denominator is clamped to ``1e-8`` so a constant map yields zeros
        instead of dividing by zero.
        """
        flat = cams.flatten(start_dim=1)
        min_vals = flat.min(dim=1).values.view(-1, 1, 1)
        max_vals = flat.max(dim=1).values.view(-1, 1, 1)
        denom = (max_vals - min_vals).clamp_min(1e-8)
        return (cams - min_vals) / denom
class SimpleProbCNN(nn.Module):
    """Toy conv -> global-pool -> linear classifier exposing only ``y_prob``."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(4, num_classes)

    def forward(self, **kwargs):
        """Return ``{"y_prob": softmax(scores)}`` for ``kwargs["image"]``."""
        feature_map = self.conv(kwargs["image"])
        pooled = torch.flatten(self.pool(feature_map), start_dim=1)
        scores = self.fc(pooled)
        return {"y_prob": scores.softmax(dim=1)}


class SimpleLogitCNN(nn.Module):
    """Toy conv -> global-pool -> linear classifier exposing logits too."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(4, num_classes)

    def forward(self, **kwargs):
        """Return both ``logit`` and ``y_prob`` for ``kwargs["image"]``."""
        feature_map = self.conv(kwargs["image"])
        pooled = torch.flatten(self.pool(feature_map), start_dim=1)
        scores = self.fc(pooled)
        result = {"logit": scores}
        result["y_prob"] = scores.softmax(dim=1)
        return result


class TorchvisionLogitShim(nn.Module):
    """Wrap a bare backbone so it returns ``{"logit": ...}`` from ``image``.

    Single-channel inputs are tiled to three channels before the backbone is
    called.
    """

    def __init__(self, core_model: nn.Module):
        super().__init__()
        self.model = core_model

    def forward(self, **kwargs):
        """Run the wrapped backbone on ``kwargs["image"]``."""
        pixels = kwargs["image"]
        is_grayscale = pixels.shape[1] == 1
        if is_grayscale:
            pixels = pixels.repeat(1, 3, 1, 1)
        return {"logit": self.model(pixels)}
class TestGradCAMToyCNN(unittest.TestCase):
    """Grad-CAM behavior on small hand-written CNNs with dict outputs."""

    def setUp(self):
        torch.manual_seed(7)
        self.batch = {"image": torch.randn(2, 3, 32, 32)}

    def test_gradcam_forward_backward_shape_and_batch_support(self):
        model = SimpleLogitCNN()
        gradcam = GradCAM(model, target_layer=model.conv)

        attributions = gradcam.attribute(**self.batch)

        self.assertIn("image", attributions)
        self.assertEqual(attributions["image"].shape, (2, 32, 32))

    def test_gradcam_default_target_uses_prediction(self):
        model = SimpleProbCNN()
        expected = model(**self.batch)["y_prob"].argmax(dim=1)
        gradcam = GradCAM(model, target_layer=model.conv)

        gradcam.attribute(**self.batch)

        self.assertTrue(torch.equal(gradcam.last_target_class, expected.cpu()))

    def test_gradcam_explicit_target(self):
        model = SimpleLogitCNN(num_classes=3)
        gradcam = GradCAM(model, target_layer=model.conv)

        gradcam.attribute(class_index=1, **self.batch)

        self.assertTrue(torch.equal(gradcam.last_target_class, torch.tensor([1, 1])))

    def test_gradcam_bad_layer_path(self):
        model = SimpleLogitCNN()
        with self.assertRaises(ValueError):
            GradCAM(model, target_layer="missing.layer")

    def test_gradcam_non_spatial_layer_error(self):
        # The linear head outputs a 2D tensor, which Grad-CAM must reject.
        model = SimpleLogitCNN()
        gradcam = GradCAM(model, target_layer="fc")
        with self.assertRaises(ValueError):
            gradcam.attribute(**self.batch)

    def test_gradcam_normalization(self):
        model = SimpleLogitCNN()
        gradcam = GradCAM(model, target_layer=model.conv)

        cam = gradcam.attribute(normalize=True, **self.batch)["image"]

        self.assertTrue(torch.all(cam >= 0))
        self.assertTrue(torch.all(cam <= 1))

    def test_gradcam_y_prob_fallback(self):
        # SimpleProbCNN returns no "logit" key, so "y_prob" must be used.
        model = SimpleProbCNN()
        gradcam = GradCAM(model, target_layer="conv")

        attributions = gradcam.attribute(**self.batch)

        self.assertEqual(attributions["image"].shape, (2, 32, 32))

    def test_gradcam_missing_input_key(self):
        model = SimpleLogitCNN()
        gradcam = GradCAM(model, target_layer=model.conv, input_key="xray")

        with self.assertRaises(KeyError):
            gradcam.attribute(**self.batch)

    def test_gradcam_invalid_class_index_raises_value_error(self):
        model = SimpleLogitCNN(num_classes=3)
        gradcam = GradCAM(model, target_layer=model.conv)

        with self.assertRaises(ValueError):
            gradcam.attribute(class_index=5, **self.batch)

    def test_gradcam_invalid_class_index_tensor_shape_raises_value_error(self):
        # One target for a batch of two is a shape mismatch.
        model = SimpleLogitCNN(num_classes=3)
        gradcam = GradCAM(model, target_layer=model.conv)

        with self.assertRaises(ValueError):
            gradcam.attribute(class_index=torch.tensor([1]), **self.batch)

    def test_gradcam_no_grad_context_raises_runtime_error(self):
        model = SimpleLogitCNN()
        gradcam = GradCAM(model, target_layer=model.conv)

        with self.assertRaises(RuntimeError):
            with torch.no_grad():
                gradcam.attribute(**self.batch)


class TestGradCAMTorchvisionModel(unittest.TestCase):
    """Grad-CAM smoke tests against the PyHealth ``TorchvisionModel`` wrapper."""

    @classmethod
    def setUpClass(cls):
        cls.temp_dir = tempfile.mkdtemp()
        # Registered immediately so the directory is removed even when the
        # rest of setUpClass fails (tearDownClass is skipped in that case).
        cls.addClassCleanup(shutil.rmtree, cls.temp_dir, ignore_errors=True)

        # Fixed seed keeps the synthetic images identical across runs.
        rng = np.random.default_rng(0)
        cls.samples = []
        for i in range(4):
            # Build tiny synthetic grayscale PNGs on disk for wrapper-level tests.
            img_path = os.path.join(cls.temp_dir, f"img_{i}.png")
            pixels = rng.integers(0, 256, size=(64, 64), dtype=np.uint8)
            # A 2D uint8 array is read as an 8-bit grayscale ("L") image.
            Image.fromarray(pixels).save(img_path)
            cls.samples.append(
                {
                    "patient_id": f"p{i // 2}",
                    "visit_id": f"v{i}",
                    "image": img_path,
                    "label": i % 2,
                }
            )

        cls.dataset = create_sample_dataset(
            samples=cls.samples,
            input_schema={"image": ("image", {"image_size": 64, "mode": "L"})},
            output_schema={"label": "binary"},
            dataset_name="gradcam_image_smoke",
        )

    def test_gradcam_torchvisionmodel_smoke(self):
        model = TorchvisionModel(
            dataset=self.dataset,
            model_name="resnet18",
            model_config={"weights": None},
        )
        model.eval()

        batch = next(iter(get_dataloader(self.dataset, batch_size=1, shuffle=False)))
        gradcam = GradCAM(model, target_layer="model.layer4.1.conv2")

        attributions = gradcam.attribute(**batch)

        self.assertIn("image", attributions)
        self.assertEqual(attributions["image"].shape, (1, 64, 64))
        self.assertTrue(torch.all(attributions["image"] >= 0))
        self.assertTrue(torch.all(attributions["image"] <= 1))

    def test_gradcam_torchvisionmodel_matches_direct_logit_path(self):
        model = TorchvisionModel(
            dataset=self.dataset,
            model_name="resnet18",
            model_config={"weights": None},
        )
        model.eval()

        batch = next(iter(get_dataloader(self.dataset, batch_size=1, shuffle=False)))
        wrapper_cam = GradCAM(
            model,
            target_layer="model.layer4.1.conv2",
        ).attribute(
            **batch
        )["image"]
        # The shim exposes the same backbone through the generic dict-output
        # path, so both routes must produce the same heatmap.
        shim_cam = GradCAM(
            TorchvisionLogitShim(model.model),
            target_layer="model.layer4.1.conv2",
        ).attribute(image=batch["image"])["image"]

        self.assertTrue(torch.allclose(wrapper_cam, shim_cam, atol=1e-5, rtol=1e-4))


if __name__ == "__main__":
    unittest.main()