From 483da6b9a669b00205bb388efe664224a370e87e Mon Sep 17 00:00:00 2001 From: Steven Pousty Date: Wed, 13 May 2026 17:40:31 -0400 Subject: [PATCH] Fixed captumyt.py per https://github.com/pytorch/tutorials/issues/3859 Co-Authored-By: Claude Sonnet 4.6 --- beginner_source/introyt/captumyt.py | 133 +++------------------------- 1 file changed, 12 insertions(+), 121 deletions(-) diff --git a/beginner_source/introyt/captumyt.py b/beginner_source/introyt/captumyt.py index 3950792b49..a18d7338bf 100644 --- a/beginner_source/introyt/captumyt.py +++ b/beginner_source/introyt/captumyt.py @@ -82,35 +82,25 @@ - The ``captum.attr.visualization`` module (imported below as ``viz``) provides helpful functions for visualizing attributions related to images. -- **Captum Insights** is an easy-to-use API on top of Captum that - provides a visualization widget with ready-made visualizations for - image, text, and arbitrary model types. -Both of these visualization toolsets will be demonstrated in this -notebook. The first few examples will focus on computer vision use -cases, but the Captum Insights section at the end will demonstrate -visualization of attributions in a multi-model, visual -question-and-answer model. +This visualization toolset will be demonstrated throughout this notebook. Installation ------------ Before you get started, you need to have a Python environment with: -- Python version 3.6 or higher -- For the Captum Insights example, Flask 1.1 or higher and Flask-Compress - (the latest version is recommended) -- PyTorch version 1.2 or higher (the latest version is recommended) -- TorchVision version 0.6 or higher (the latest version is recommended) +- Python version 3.9 or higher +- PyTorch (the latest version is recommended) +- TorchVision (the latest version is recommended) - Captum (the latest version is recommended) -- Matplotlib version 3.3.4, since Captum currently uses a Matplotlib - function whose arguments have been renamed in later versions +- Matplotlib (the latest version is recommended) To install Captum in a virtual environment, use: .. code-block:: sh - pip install torch torchvision captum matplotlib==3.3.4 Flask-Compress + pip install torch torchvision captum matplotlib Restart this notebook in the environment you set up, and you’re ready to go! @@ -257,8 +247,12 @@ attributions_ig = integrated_gradients.attribute(input_img, target=pred_label_idx, n_steps=200) # Show the original image for comparison -_ = viz.visualize_image_attr(None, np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)), - method="original_image", title="Original Image") +original_image_np = np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)) +fig, ax = plt.subplots(figsize=(6, 6)) +ax.imshow(original_image_np) +ax.set_title("Original Image") +ax.axis('off') +plt.show() default_cmap = LinearSegmentedColormap.from_list('custom blue', [(0, '#ffffff'), @@ -385,106 +379,3 @@ # Visualizations such as this can give you novel insights into how your # hidden layers respond to your input. # - - -########################################################################## -# Visualization with Captum Insights -# ---------------------------------- -# -# Captum Insights is an interpretability visualization widget built on top -# of Captum to facilitate model understanding. Captum Insights works -# across images, text, and other features to help users understand feature -# attribution. It allows you to visualize attribution for multiple -# input/output pairs, and provides visualization tools for image, text, -# and arbitrary data. -# -# In this section of the notebook, we’ll visualize multiple image -# classification inferences with Captum Insights. -# -# First, let’s gather some image and see what the model thinks of them. -# For variety, we’ll take our cat, a teapot, and a trilobite fossil: -# - -imgs = ['img/cat.jpg', 'img/teapot.jpg', 'img/trilobite.jpg'] - -for img in imgs: - img = Image.open(img) - transformed_img = transform(img) - input_img = transform_normalize(transformed_img) - input_img = input_img.unsqueeze(0) # the model requires a dummy batch dimension - - output = model(input_img) - output = F.softmax(output, dim=1) - prediction_score, pred_label_idx = torch.topk(output, 1) - pred_label_idx.squeeze_() - predicted_label = idx_to_labels[str(pred_label_idx.item())][1] - print('Predicted:', predicted_label, '/', pred_label_idx.item(), ' (', prediction_score.squeeze().item(), ')') - - -########################################################################## -# …and it looks like our model is identifying them all correctly - but of -# course, we want to dig deeper. For that we’ll use the Captum Insights -# widget, which we configure with an ``AttributionVisualizer`` object, -# imported below. The ``AttributionVisualizer`` expects batches of data, -# so we’ll bring in Captum’s ``Batch`` helper class. And we’ll be looking -# at images specifically, so well also import ``ImageFeature``. -# -# We configure the ``AttributionVisualizer`` with the following arguments: -# -# - An array of models to be examined (in our case, just the one) -# - A scoring function, which allows Captum Insights to pull out the -# top-k predictions from a model -# - An ordered, human-readable list of classes our model is trained on -# - A list of features to look for - in our case, an ``ImageFeature`` -# - A dataset, which is an iterable object returning batches of inputs -# and labels - just like you’d use for training -# - -from captum.insights import AttributionVisualizer, Batch -from captum.insights.attr_vis.features import ImageFeature - -# Baseline is all-zeros input - this may differ depending on your data -def baseline_func(input): - return input * 0 - -# merging our image transforms from above -def full_img_transform(input): - i = Image.open(input) - i = transform(i) - i = transform_normalize(i) - i = i.unsqueeze(0) - return i - - -input_imgs = torch.cat(list(map(lambda i: full_img_transform(i), imgs)), 0) - -visualizer = AttributionVisualizer( - models=[model], - score_func=lambda o: torch.nn.functional.softmax(o, 1), - classes=list(map(lambda k: idx_to_labels[k][1], idx_to_labels.keys())), - features=[ - ImageFeature( - "Photo", - baseline_transforms=[baseline_func], - input_transforms=[], - ) - ], - dataset=[Batch(input_imgs, labels=[282,849,69])] -) - - -######################################################################### -# Note that running the cell above didn’t take much time at all, unlike -# our attributions above. That’s because Captum Insights lets you -# configure different attribution algorithms in a visual widget, after -# which it will compute and display the attributions. *That* process will -# take a few minutes. -# -# Running the cell below will render the Captum Insights widget. You can -# then choose attributions methods and their arguments, filter model -# responses based on predicted class or prediction correctness, see the -# model’s predictions with associated probabilities, and view heatmaps of -# the attribution compared with the original image. -# - -visualizer.render()