Commit da83f1f

Readme
1 parent 7ff4d2e commit da83f1f

File tree: 1 file changed, +41 −24 lines


README.md

Lines changed: 41 additions & 24 deletions
@@ -111,7 +111,7 @@ This can be useful if you're not sure what layer will perform best.
 
 ----------
 
-# Using from code as a library
+# Usage examples
 
 ```python
 from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
@@ -124,34 +124,27 @@ target_layers = [model.layer4[-1]]
 input_tensor = # Create an input tensor image for your model..
 # Note: input_tensor can be a batch tensor with several images!
 
-# Construct the CAM object once, and then re-use it on many images:
-cam = GradCAM(model=model, target_layers=target_layers)
-
-# You can also use it within a with statement, to make sure it is freed,
-# In case you need to re-create it inside an outer loop:
-# with GradCAM(model=model, target_layers=target_layers) as cam:
-# ...
 
 # We have to specify the target we want to generate
 # the Class Activation Maps for.
-# If targets is None, the highest scoring category
-# will be used for every image in the batch.
-# Here we use ClassifierOutputTarget, but you can define your own custom targets
-# That are, for example, combinations of categories, or specific outputs in a non standard model.
-
 targets = [ClassifierOutputTarget(281)]
 
-# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
-grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
+# Construct the CAM object once, and then re-use it on many images.
+with GradCAM(model=model, target_layers=target_layers) as cam:
 
-# In this example grayscale_cam has only one image in the batch:
-grayscale_cam = grayscale_cam[0, :]
-visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+    # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
+    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
 
-# You can also get the model outputs without having to re-inference
-model_outputs = cam.outputs
+    # In this example grayscale_cam has only one image in the batch:
+    grayscale_cam = grayscale_cam[0, :]
+    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+
+    # You can also get the model outputs without having to redo inference
+    model_outputs = cam.outputs
 ```
 
+Cam.py has a more detailed usage example.
+
 ----------
 
 # Metrics and evaluating the explanations
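As an aside on what `show_cam_on_image` produces in the snippet above: it blends the grayscale CAM, rendered as a colored heatmap, over the input image. A minimal NumPy sketch of that blending idea — not the library's actual implementation, which uses proper OpenCV colormaps, and `overlay_cam` is a hypothetical name — might look like this:

```python
import numpy as np

def overlay_cam(rgb_img, grayscale_cam, alpha=0.5):
    """Blend a [0, 1] grayscale CAM over a [0, 1] float RGB image.

    Toy stand-in for show_cam_on_image: maps the CAM to a crude
    blue-to-red heatmap, then alpha-blends it with the image.
    """
    heatmap = np.stack(
        [grayscale_cam,                 # red grows with activation
         np.zeros_like(grayscale_cam),  # no green channel
         1.0 - grayscale_cam],          # blue shrinks with activation
        axis=-1,
    )
    blended = (1 - alpha) * rgb_img + alpha * heatmap
    # Normalize to [0, 255] and convert to an 8-bit image
    return np.uint8(255 * blended / blended.max())

rgb_img = np.random.rand(224, 224, 3).astype(np.float32)
grayscale_cam = np.random.rand(224, 224).astype(np.float32)
visualization = overlay_cam(rgb_img, grayscale_cam)
print(visualization.shape, visualization.dtype)  # (224, 224, 3) uint8
```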
@@ -179,18 +179,42 @@ from pytorch_grad_cam.metrics.road import ROADMostRelevantFirstAverage,
 cam_metric = ROADCombined(percentiles=[20, 40, 60, 80])
 scores = cam_metric(input_tensor, grayscale_cams, targets, model)
 ```
+
 ----------
 
 
 # Advanced use cases and tutorials:
 
-You can use this package for "custom" deep learning models, for example Object Detection or Semantic Segmentation.
+Methods like GradCAM were designed for, and were originally mostly applied to, classification models,
+and specifically CNN classification models.
+However, you can also use this package on newer architectures like Vision Transformers, and on non-classification tasks like Object Detection or Semantic Segmentation.
+
+To be able to adapt to non-standard cases, we have two concepts:
+- The reshape transform - how do we convert activations to represent spatial images?
+- The model targets - what exactly should the explainability method try to explain?
+
+## The reshape transform
+In a CNN the intermediate activations in the model are a multi-channel image with the dimensions channels x rows x cols,
+and the various explainability methods work with these to produce a new image.
+
+In the case of another architecture, like the Vision Transformer, the shape might be different, like (rows x cols + 1) x channels, or something else.
+The reshape transform converts the activations back into a multi-channel image, for example by removing the class token in a vision transformer.
+For examples, check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/reshape_transforms.py).
+
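The reshape transform idea described above can be sketched in a few lines. This is an illustrative version only — NumPy arrays stand in for torch tensors, a 14x14-patch ViT is assumed, and `vit_reshape_transform` is a hypothetical name; the library's own versions live in `reshape_transforms.py`:

```python
import numpy as np

def vit_reshape_transform(activations, height=14, width=14):
    """Turn ViT token activations (batch, tokens, channels) into a
    multi-channel image (batch, channels, rows, cols).

    Assumes tokens = height * width + 1, with the class token first.
    """
    spatial = activations[:, 1:, :]  # drop the class token
    spatial = spatial.reshape(spatial.shape[0], height, width, -1)
    return spatial.transpose(0, 3, 1, 2)  # channels-first, like a CNN

tokens = np.random.rand(2, 14 * 14 + 1, 768)
print(vit_reshape_transform(tokens).shape)  # (2, 768, 14, 14)
```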
+## Model Targets
+The model target is just a callable that takes the model output and filters it down to the specific scalar output we want to explain.
+
+For classification tasks, the model target will typically be the output for a specific category.
+The `targets` parameter passed to the CAM method can then use `ClassifierOutputTarget`:
+```python
+targets = [ClassifierOutputTarget(281)]
+```
 
+However, in more advanced cases, you might want another behaviour.
+Check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py) for more examples.
 
-You will have to define objects that you can then pass to the CAM algorithms:
-1. A reshape_transform, that aggregates the layer outputs into 2D tensors that will be displayed.
-2. Model Targets, that define what target do you want to compute the visualizations for, for example a specific category, or a list of bounding boxes.
 
+# Tutorials
 Here you can find detailed examples of how to use this for various custom use cases like object detection:
 
 These point to the new documentation jupyter-book for fast rendering.
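To illustrate the custom-target idea the diff describes: a target is just a callable on the model output, so combinations of categories are easy to express. A hypothetical sketch (`SumOfCategoriesTarget` is not a library class, and NumPy stands in for torch tensors):

```python
import numpy as np

class SumOfCategoriesTarget:
    """Hypothetical custom target: explain the combined score of
    several categories rather than a single one (e.g. 'cat' + 'dog')."""

    def __init__(self, categories):
        self.categories = categories

    def __call__(self, model_output):
        # model_output is the per-class score vector for one image
        return sum(model_output[c] for c in self.categories)

scores = np.array([0.1, 2.0, 0.5, 3.0])
target = SumOfCategoriesTarget([1, 3])
print(target(scores))  # 5.0
```

The real targets in `model_targets.py` follow the same callable shape, which is why they can be mixed freely in the `targets` list.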
