You can use this package for "custom" deep learning models, for example Object Detection or Semantic Segmentation.
Methods like GradCAM were designed for, and were originally mostly applied to, classification models, and specifically CNN classification models.
However, you can also use this package with newer architectures like Vision Transformers, and on non-classification tasks like Object Detection or Semantic Segmentation.

To be able to adapt to non-standard cases, we have two concepts:
- The reshape transform: how do we convert the activations to represent spatial images?
- The model targets: what exactly should the explainability method try to explain?
## The reshape transform
In a CNN, the intermediate activations are multi-channel images with dimensions channels x rows x cols,
and the various explainability methods work with these to produce a new image.

In other architectures, like the Vision Transformer, the shape might be different, for example (rows x cols + 1) x channels, or something else.
The reshape transform converts the activations back into a multi-channel image, for example by removing the class token in a Vision Transformer.
For examples, check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/reshape_transforms.py).
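A minimal sketch of such a reshape transform for a ViT-style backbone, assuming a 14x14 grid of patch tokens plus a leading class token (the function name and default sizes here are illustrative, not the library's API):

```python
import torch

def vit_reshape_transform(tensor, height=14, width=14):
    # ViT activations have shape (batch, tokens, channels), where
    # tokens = height * width + 1 (the extra token is the class token).
    # Drop the class token and rearrange into an image-like grid.
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    # Move channels to the second dimension: (batch, channels, height, width)
    result = result.transpose(2, 3).transpose(1, 2)
    return result
```

The output then has the familiar CNN-style layout, so the CAM methods can treat it like any multi-channel image.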
## Model Targets
The model target is just a callable that takes the model output and extracts the specific scalar output we want to explain.

For classification tasks, the model target will typically be the output from a specific category.
The `targets` parameter passed to the CAM method can then use `ClassifierOutputTarget`:
```python
targets = [ClassifierOutputTarget(281)]
```
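To make the callable pattern concrete, here is a hedged re-implementation sketch of what such a classifier target does: given the model's output logits, it picks out the score of a single category. This is a custom class written for illustration, not the library's own code.

```python
import torch

class MyClassifierOutputTarget:
    # Illustrative sketch of a classification model target:
    # select one category's score from the model output.
    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        if model_output.dim() == 1:
            # Single, unbatched output vector
            return model_output[self.category]
        # Batched output: pick the category column for every item
        return model_output[:, self.category]
```

The CAM method calls each target on the model output and back-propagates from the scalar it returns.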
However, in more advanced cases, you might want a different behaviour.
Check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py) for more examples.
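As one illustration of a non-classification target, the sketch below reduces a semantic segmentation output to a scalar by summing one category's scores inside a binary mask. The class name, tensor shapes, and mask convention here are assumptions for illustration, not the library's API:

```python
import torch

class SegmentationTarget:
    """Illustrative model target for semantic segmentation:
    sums the scores of one category inside a binary mask,
    producing a single scalar the CAM method can explain."""

    def __init__(self, category, mask):
        self.category = category
        self.mask = mask  # 0/1 tensor with shape (rows, cols)

    def __call__(self, model_output):
        # model_output is assumed to have shape (categories, rows, cols)
        return (model_output[self.category, :, :] * self.mask).sum()
```

Any callable following this contract (model output in, scalar out) can be passed in the `targets` list.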
# Tutorials
Here you can find detailed examples of how to use this for various custom use cases like object detection:
These point to the new documentation jupyter-book for fast rendering.