
Commit d666551 (0 parents)

Author: Moritz Böhle
Commit message: first commit

11 files changed: +1785 -0 lines

MNIST example.ipynb (+937 lines; large diff not rendered by default)

README.md (+1 line, new file)

# Pytorch-LRP
Binary file (8.19 KB), not shown.
Binary file (12.2 KB), not shown.

__pycache__/mnist_test.cpython-36.pyc (2.38 KB), binary file not shown.

__pycache__/utils.cpython-36.pyc (1.56 KB), binary file not shown.

innvestigator.py (+234 lines, new file)
import torch
import numpy as np

from inverter_util import RelevancePropagator
from utils import pprint, Flatten


class InnvestigateModel(torch.nn.Module):
    """
    ATTENTION:
        Currently, innvestigating a network only works if all
        layers that have to be inverted are specified explicitly
        and registered as a module. If, for example,
        only the functional max_poolnd is used, the inversion will not work.
    """

    def __init__(self, the_model, lrp_exponent=1, beta=.5, epsilon=1e-6,
                 method="e-rule"):
        """
        Model wrapper for pytorch models to 'innvestigate' them
        with layer-wise relevance propagation (LRP) as introduced by Bach et al.
        (https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140).
        Given a class-level probability produced by the model under consideration,
        the LRP algorithm attributes this probability to the nodes in each layer.
        This allows for visualizing the relevance of input pixels for the resulting
        class probability.

        Args:
            the_model: Pytorch model, e.g. a pytorch.nn.Sequential consisting of
                different layers. Not all layers are supported yet.
            lrp_exponent: Exponent for rescaling the importance values per node
                in a layer when using the e-rule method.
            beta: Beta value allows for placing more (large beta) emphasis on
                nodes that positively contribute to the activation of a given node
                in the subsequent layer. A low beta value allows for placing more
                emphasis on inhibitory neurons in a layer. Only relevant for
                method 'b-rule'.
            epsilon: Stabilizing term to avoid numerical instabilities if the norm
                (denominator for distributing the relevance) is close to zero.
            method: Different rules for the LRP algorithm; the b-rule allows for
                placing more or less focus on positive / negative contributions,
                whereas the e-rule treats them equally. For more information, see
                the paper linked above and the sketch of both rules below this
                method.
        """
        super(InnvestigateModel, self).__init__()
        self.model = the_model
        self.device = torch.device("cpu", 0)
        self.prediction = None
        self.r_values_per_layer = None
        self.only_max_score = None
        # Initialize the 'Relevance Propagator' with the chosen rule.
        # This will be used to back-propagate the relevance values
        # through the layers in the innvestigate method.
        self.inverter = RelevancePropagator(lrp_exponent=lrp_exponent,
                                            beta=beta, method=method, epsilon=epsilon,
                                            device=self.device)

        # Parsing the individual model layers
        self.register_hooks(self.model)
        if method == "b-rule" and float(beta) in (-1., 0):
            which = "positive" if beta == -1 else "negative"
            which_opp = "negative" if beta == -1 else "positive"
            print("WARNING: With the chosen beta value, "
                  "only " + which + " contributions "
                  "will be taken into account.\nHence, "
                  "if in any layer only " + which_opp +
                  " contributions exist, the "
                  "overall relevance will not be conserved.\n")
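
    # For orientation, a rough sketch of the two rules named in the docstring
    # above; the actual implementation lives in inverter_util.RelevancePropagator
    # (imported at the top of this file) and may differ in detail.
    #
    #   e-rule (epsilon rule of Bach et al.): the relevance R_j of a neuron j in
    #   the next layer is redistributed to its inputs i in proportion to their
    #   contributions z_ij = a_i * w_ij,
    #
    #       R_i = sum_j z_ij / (sum_i' z_i'j + eps * sign(sum_i' z_i'j)) * R_j,
    #
    #   where eps stabilizes near-zero denominators and lrp_exponent rescales
    #   the per-node contributions before normalization.
    #
    #   b-rule: positive and negative contributions are normalized separately
    #   and weighted against each other via beta (see the warning printed in
    #   __init__ for the degenerate values beta = -1 and beta = 0).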

    def cuda(self, device=None):
        self.device = torch.device("cuda", device)
        self.inverter.device = self.device
        return super(InnvestigateModel, self).cuda(device)

    def cpu(self):
        self.device = torch.device("cpu", 0)
        self.inverter.device = self.device
        return super(InnvestigateModel, self).cpu()

    def register_hooks(self, parent_module):
        """
        Recursively unrolls a model and registers the required
        hooks to save all the necessary values for LRP in the forward pass.

        Args:
            parent_module: Model to unroll and register hooks for.

        Returns:
            None
        """
        for mod in parent_module.children():
            if list(mod.children()):
                self.register_hooks(mod)
                continue
            mod.register_forward_hook(
                self.inverter.get_layer_fwd_hook(mod))
            if isinstance(mod, torch.nn.ReLU):
                mod.register_backward_hook(
                    self.relu_hook_function
                )

    @staticmethod
    def relu_hook_function(module, grad_in, grad_out):
        """
        If there is a negative gradient, change it to zero.
        """
        return (torch.clamp(grad_in[0], min=0.0),)

    def __call__(self, in_tensor):
        """
        The innvestigate wrapper returns the same prediction as the
        original model, but wraps the model call method in the evaluate
        method to save the last prediction.

        Args:
            in_tensor: Model input to pass through the pytorch model.

        Returns:
            Model output.
        """
        return self.evaluate(in_tensor)

    def evaluate(self, in_tensor):
        """
        Evaluates the model on a new input. The registered forward hooks will
        save all the data that is necessary to compute the relevance per neuron
        per layer.

        Args:
            in_tensor: New input for which to predict an output.

        Returns:
            Model prediction
        """
        # Reset module list. In case the structure changes dynamically,
        # the module list is tracked for every forward pass.
        self.inverter.reset_module_list()
        self.prediction = self.model(in_tensor)
        return self.prediction

    def get_r_values_per_layer(self):
        if self.r_values_per_layer is None:
            pprint("No relevances have been calculated yet, returning None in"
                   " get_r_values_per_layer.")
        return self.r_values_per_layer

    def innvestigate(self, in_tensor=None, rel_for_class=None):
        """
        Method for 'innvestigating' the model with the LRP rule chosen at
        the initialization of the InnvestigateModel.

        Args:
            in_tensor: Input for which to evaluate the LRP algorithm.
                If input is None, the last evaluation is used.
                If no evaluation has been performed since initialization,
                an error is raised.
            rel_for_class (int): Index of the class for which the relevance
                distribution is to be analyzed. If None, the 'winning' class
                is used for indexing.

        Returns:
            Model output and relevances of nodes in the input layer.
            In order to get relevance distributions in other layers, use
            the get_r_values_per_layer method.
        """
        if self.r_values_per_layer is not None:
            for elt in self.r_values_per_layer:
                del elt
            self.r_values_per_layer = None

        with torch.no_grad():
            # Check if innvestigation can be performed.
            if in_tensor is None and self.prediction is None:
                raise RuntimeError("Model needs to be evaluated at least "
                                   "once before an innvestigation can be "
                                   "performed. Please evaluate model first "
                                   "or call innvestigate with a new input to "
                                   "evaluate.")

            # Evaluate the model anew if a new input is supplied.
            if in_tensor is not None:
                self.evaluate(in_tensor)

            # If no class index is specified, analyze for the class
            # with the highest prediction.
            if rel_for_class is None:
                # Default behaviour is innvestigating the output
                # on an arg-max basis, if no class is specified.
                org_shape = self.prediction.size()
                # Make sure the shape is just a 1D vector per batch example.
                self.prediction = self.prediction.view(org_shape[0], -1)
                max_v, _ = torch.max(self.prediction, dim=1, keepdim=True)
                only_max_score = torch.zeros_like(self.prediction).to(self.device)
                only_max_score[max_v == self.prediction] = self.prediction[max_v == self.prediction]
                relevance_tensor = only_max_score.view(org_shape)
                # Restore the original output shape.
                self.prediction = self.prediction.view(org_shape)

            else:
                org_shape = self.prediction.size()
                self.prediction = self.prediction.view(org_shape[0], -1)
                only_max_score = torch.zeros_like(self.prediction).to(self.device)
                only_max_score[:, rel_for_class] += self.prediction[:, rel_for_class]
                relevance_tensor = only_max_score.view(org_shape)
                self.prediction = self.prediction.view(org_shape)

            # We have to iterate through the model backwards.
            # The module list is computed for every forward pass
            # by the model inverter.
            rev_model = self.inverter.module_list[::-1]
            relevance = relevance_tensor.detach()
            del relevance_tensor
            # List to save relevance distributions per layer
            r_values_per_layer = [relevance]
            for layer in rev_model:
                # Compute layer-specific backward propagation of relevance values
                relevance = self.inverter.compute_propagated_relevance(layer, relevance)
                r_values_per_layer.append(relevance.cpu())

            self.r_values_per_layer = r_values_per_layer

            del relevance
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
            return self.prediction, r_values_per_layer[-1]

    def forward(self, in_tensor):
        return self.model.forward(in_tensor)

    def extra_repr(self):
        r"""Set the extra representation of the module.

        To print customized extra information, you should re-implement
        this method in your own modules. Both single-line and multi-line
        strings are acceptable.
        """
        return self.model.extra_repr()
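
For orientation, here is a minimal usage sketch of the wrapper above. The tiny fully connected network, the input sizes, and the variable names are illustrative assumptions and not part of this commit; the end-to-end example presumably lives in MNIST example.ipynb, whose diff is not rendered above. The sketch assumes innvestigator.py and its imports (inverter_util, utils) are importable and that RelevancePropagator supports Linear and ReLU layers.

usage_sketch.py (illustrative, not part of this commit)

import torch

from innvestigator import InnvestigateModel

# Illustrative model: every layer that has to be inverted is registered as a
# sub-module, as required by the class docstring (no purely functional pooling).
model = torch.nn.Sequential(
    torch.nn.Linear(28 * 28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# Wrap the (ideally trained) model with the LRP rule of choice; the keyword
# arguments are the ones documented in InnvestigateModel.__init__ above.
inn_model = InnvestigateModel(model, lrp_exponent=1, beta=0.5,
                              epsilon=1e-6, method="e-rule")

batch = torch.randn(4, 28 * 28)  # placeholder batch of flattened inputs

# Run the model and attribute the winning class score back to the input nodes;
# pass rel_for_class=<int> instead to analyze a fixed class index.
prediction, input_relevance = inn_model.innvestigate(in_tensor=batch)

# Relevance distributions for every layer, ordered from the output layer
# back to the input layer.
r_per_layer = inn_model.get_r_values_per_layer()
print(prediction.shape, input_relevance.shape, len(r_per_layer))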

0 commit comments
