import torch
import numpy as np

from inverter_util import RelevancePropagator
from utils import pprint, Flatten


class InnvestigateModel(torch.nn.Module):
    """
    ATTENTION:
    Currently, innvestigating a network only works if all
    layers that have to be inverted are specified explicitly
    and registered as a module. If, for example,
    only the functional max_poolnd is used, the inversion will not work.
    """

    def __init__(self, the_model, lrp_exponent=1, beta=.5, epsilon=1e-6,
                 method="e-rule"):
        """
        Model wrapper for pytorch models to 'innvestigate' them
        with layer-wise relevance propagation (LRP) as introduced by Bach et al.
        (https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140).
        Given a class-level probability produced by the model under consideration,
        the LRP algorithm attributes this probability to the nodes in each layer.
        This allows for visualizing the relevance of input pixels for the resulting
        class probability.

        Args:
            the_model: Pytorch model, e.g. a pytorch.nn.Sequential consisting of
                different layers. Not all layers are supported yet.
            lrp_exponent: Exponent for rescaling the importance values per node
                in a layer when using the e-rule method.
            beta: Beta value allows for placing more (large beta) emphasis on
                nodes that positively contribute to the activation of a given node
                in the subsequent layer. A low beta value places more emphasis
                on inhibitory neurons in a layer. Only relevant for method 'b-rule'.
            epsilon: Stabilizing term to avoid numerical instabilities if the norm
                (denominator for distributing the relevance) is close to zero.
            method: Different rules for the LRP algorithm: the b-rule allows for placing
                more or less focus on positive / negative contributions, whereas
                the e-rule treats them equally. For more information,
                see the paper linked above.
        """
        super(InnvestigateModel, self).__init__()
        self.model = the_model
        self.device = torch.device("cpu", 0)
        self.prediction = None
        self.r_values_per_layer = None
        self.only_max_score = None
        # Initialize the 'Relevance Propagator' with the chosen rule.
        # This will be used to back-propagate the relevance values
        # through the layers in the innvestigate method.
        self.inverter = RelevancePropagator(lrp_exponent=lrp_exponent,
                                            beta=beta, method=method, epsilon=epsilon,
                                            device=self.device)

        # Parsing the individual model layers
        self.register_hooks(self.model)
        if method == "b-rule" and float(beta) in (-1., 0):
            which = "positive" if beta == -1 else "negative"
            which_opp = "negative" if beta == -1 else "positive"
            print("WARNING: With the chosen beta value, "
                  "only " + which + " contributions "
                  "will be taken into account.\nHence, "
                  "if in any layer only " + which_opp +
                  " contributions exist, the "
                  "overall relevance will not be conserved.\n")

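    # Construction sketch (illustrative only; `my_classifier` is a hypothetical
    # model, not defined in this module). With the default e-rule, lrp_exponent
    # rescales the per-node importances; with the b-rule, the positive/negative
    # split is governed by beta, e.g.:
    #
    #   inn_model = InnvestigateModel(my_classifier, lrp_exponent=2, method="e-rule")
    #   inn_model = InnvestigateModel(my_classifier, beta=0.5, method="b-rule")
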
    def cuda(self, device=None):
        self.device = torch.device("cuda", device)
        self.inverter.device = self.device
        return super(InnvestigateModel, self).cuda(device)

    def cpu(self):
        self.device = torch.device("cpu", 0)
        self.inverter.device = self.device
        return super(InnvestigateModel, self).cpu()

    def register_hooks(self, parent_module):
        """
        Recursively unrolls a model and registers the required
        hooks to save all the necessary values for LRP in the forward pass.

        Args:
            parent_module: Model to unroll and register hooks for.

        Returns:
            None

        """
        for mod in parent_module.children():
            if list(mod.children()):
                self.register_hooks(mod)
                continue
            mod.register_forward_hook(
                self.inverter.get_layer_fwd_hook(mod))
            if isinstance(mod, torch.nn.ReLU):
                mod.register_backward_hook(
                    self.relu_hook_function
                )

    @staticmethod
    def relu_hook_function(module, grad_in, grad_out):
        """
        If there is a negative gradient, change it to zero.
        """
        return (torch.clamp(grad_in[0], min=0.0),)

    def __call__(self, in_tensor):
        """
        The innvestigate wrapper returns the same prediction as the
        original model, but wraps the model call method in the evaluate
        method to save the last prediction.

        Args:
            in_tensor: Model input to pass through the pytorch model.

        Returns:
            Model output.
        """
        return self.evaluate(in_tensor)

    def evaluate(self, in_tensor):
        """
        Evaluates the model on a new input. The registered forward hooks will
        save all the data that is necessary to compute the relevance per neuron per layer.

        Args:
            in_tensor: New input for which to predict an output.

        Returns:
            Model prediction
        """
        # Reset module list. In case the structure changes dynamically,
        # the module list is tracked for every forward pass.
        self.inverter.reset_module_list()
        self.prediction = self.model(in_tensor)
        return self.prediction

    def get_r_values_per_layer(self):
        """
        Returns the list of relevance tensors from the last innvestigate call,
        ordered from the output layer back to the input layer, or None if no
        relevances have been calculated yet.
        """
        if self.r_values_per_layer is None:
            pprint("No relevances have been calculated yet, returning None in"
                   " get_r_values_per_layer.")
        return self.r_values_per_layer

    def innvestigate(self, in_tensor=None, rel_for_class=None):
        """
        Method for 'innvestigating' the model with the LRP rule chosen at
        the initialization of the InnvestigateModel.

        Args:
            in_tensor: Input for which to evaluate the LRP algorithm.
                If input is None, the last evaluation is used.
                If no evaluation has been performed since initialization,
                an error is raised.
            rel_for_class (int): Index of the class for which the relevance
                distribution is to be analyzed. If None, the 'winning' class
                is used for indexing.

        Returns:
            Model output and relevances of nodes in the input layer.
            In order to get relevance distributions in other layers, use
            the get_r_values_per_layer method.
        """
        if self.r_values_per_layer is not None:
            for elt in self.r_values_per_layer:
                del elt
            self.r_values_per_layer = None

        with torch.no_grad():
            # Check if innvestigation can be performed.
            if in_tensor is None and self.prediction is None:
                raise RuntimeError("Model needs to be evaluated at least "
                                   "once before an innvestigation can be "
                                   "performed. Please evaluate model first "
                                   "or call innvestigate with a new input to "
                                   "evaluate.")

            # Evaluate the model anew if a new input is supplied.
            if in_tensor is not None:
                self.evaluate(in_tensor)

            # If no class index is specified, analyze for class
            # with highest prediction.
            if rel_for_class is None:
                # Default behaviour is innvestigating the output
                # on an arg-max-basis, if no class is specified.
                org_shape = self.prediction.size()
                # Make sure shape is just a 1D vector per batch example.
                self.prediction = self.prediction.view(org_shape[0], -1)
                max_v, _ = torch.max(self.prediction, dim=1, keepdim=True)
                only_max_score = torch.zeros_like(self.prediction).to(self.device)
                only_max_score[max_v == self.prediction] = self.prediction[max_v == self.prediction]
                relevance_tensor = only_max_score.view(org_shape)
                self.prediction = self.prediction.view(org_shape)

            else:
                org_shape = self.prediction.size()
                self.prediction = self.prediction.view(org_shape[0], -1)
                only_max_score = torch.zeros_like(self.prediction).to(self.device)
                only_max_score[:, rel_for_class] += self.prediction[:, rel_for_class]
                relevance_tensor = only_max_score.view(org_shape)
                self.prediction = self.prediction.view(org_shape)

            # We have to iterate through the model backwards.
            # The module list is computed for every forward pass
            # by the model inverter.
            rev_model = self.inverter.module_list[::-1]
            relevance = relevance_tensor.detach()
            del relevance_tensor
            # List to save relevance distributions per layer
            r_values_per_layer = [relevance]
            for layer in rev_model:
                # Compute layer specific backwards-propagation of relevance values
                relevance = self.inverter.compute_propagated_relevance(layer, relevance)
                r_values_per_layer.append(relevance.cpu())

            self.r_values_per_layer = r_values_per_layer

            del relevance
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
            return self.prediction, r_values_per_layer[-1]

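    # Usage sketch for innvestigate (illustrative; `inn_model` and `batch` are
    # hypothetical names, not defined in this module). Passing rel_for_class
    # attributes relevance for a fixed class index instead of the arg-max class,
    # and the per-layer distributions can be read back afterwards:
    #
    #   prediction, input_relevance = inn_model.innvestigate(in_tensor=batch,
    #                                                        rel_for_class=3)
    #   relevances = inn_model.get_r_values_per_layer()
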
    def forward(self, in_tensor):
        return self.model.forward(in_tensor)

    def extra_repr(self):
        r"""Set the extra representation of the module

        To print customized extra information, you should re-implement
        this method in your own modules. Both single-line and multi-line
        strings are acceptable.
        """
        return self.model.extra_repr()
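

# Minimal end-to-end sketch, assuming an MNIST-shaped input and the Flatten
# helper imported above. The toy model, shapes, and variable names below are
# illustrative assumptions and are not part of the original module.
if __name__ == "__main__":
    toy_model = torch.nn.Sequential(
        Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 10),
    )
    inn_model = InnvestigateModel(toy_model, lrp_exponent=1, method="e-rule")
    dummy_batch = torch.randn(4, 1, 28, 28)
    # Forward pass plus relevance attribution for the arg-max class per example.
    prediction, input_relevance = inn_model.innvestigate(in_tensor=dummy_batch)
    print(prediction.shape, input_relevance.shape)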