|
| 1 | +""" |
| 2 | +Create Time: 29/3/2023 |
| 3 | +Author: BierOne ([email protected]) |
| 4 | +""" |
| 5 | +import os |
| 6 | +import torch |
| 7 | +import numpy as np |
| 8 | +from typing import Any, Callable, Optional, Sequence, Tuple, Dict |
| 9 | + |
| 10 | +from benchmark_notes.utils import get_state_func, acc |
| 11 | +from benchmark_notes import nethook |
| 12 | +from benchmark_notes.instr_state import kl_grad |
| 13 | + |
| 14 | + |
def make_layer_size_dict(model, layer_names, input_shape=(1, 3, 224, 224),
                         spatial_func=None, device=None):
    """Probe `model` with a dummy forward pass and record each layer's width.

    Retains the activations of `layer_names` during one zero-tensor forward
    pass, then reads dimension 1 (channels/features) of each retained state.

    Args:
        model: network to probe (assumed to already live on `device`).
        layer_names: layer names understood by
            nethook.InstrumentedModel.retain_layers.
        input_shape: shape of the dummy input batch.
        spatial_func: optional pooling applied to >2-D states before the size
            is read (e.g. global average pooling); identity when None.
        device: device for the dummy input; defaults to the device of the
            model's first parameter, falling back to CUDA availability.
            (Previously this was hard-coded to `.cuda()`, which crashed on
            CPU-only hosts.)

    Returns:
        dict mapping layer name -> size of dim 1 of the (transformed) state.
    """
    if spatial_func is None:
        spatial_func = lambda s: s
    # Only >2-D states (e.g. conv feature maps) get the spatial pooling;
    # 2-D states (batch x features, e.g. the last layer) pass through.
    transform = lambda s: spatial_func(s) if len(s.shape) > 2 else s
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:  # parameter-less model: fall back to old rule
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    dummy_input = torch.zeros(*input_shape, device=device)
    layer_size_dict = {}
    with nethook.InstrumentedModel(model) as instr:
        instr.retain_layers(layer_names, detach=True)
        with torch.no_grad():
            _ = model(dummy_input)
        for ln in layer_names:
            b_state = instr.retained_layer(ln)
            layer_size_dict[ln] = transform(b_state).shape[1]
    return layer_size_dict
| 29 | + |
class Coverage:
    """Abstract base class for layer-wise coverage metrics.

    Wires a model (instrumented via nethook) or pre-cached layer states to a
    per-layer coverage accumulator. Subclasses supply the accumulator state
    (`init_variable`), per-batch accumulation (`step`), final scoring
    (`update`), and persistence (`save`/`load`/`clear`).
    """

    def __init__(self,
                 layer_size_dict: Dict,
                 device: Optional[Any] = None,
                 hyper: Optional[Dict] = None,
                 unpack: Optional[Callable] = None,
                 method: Optional[Any] = None,
                 **kwargs):
        """
        Args:
            layer_size_dict: layer name -> neuron count (dim 1 of its state).
            device: torch device; defaults to cuda:0 when available.
            hyper: subclass-specific hyper-parameters, forwarded to
                init_variable.
            unpack: callable(data, device) -> (x, y) used to unpack batches
                from a data loader in `assess`.
            method: identifier passed to get_state_func to pick the state
                function.
            **kwargs: extra keyword arguments for the state function.
        """
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.layer_size_dict = layer_size_dict
        self.layer_names = list(layer_size_dict.keys())
        self.unpack = unpack
        # One scalar/tuple score slot per layer; filled by subclasses in update().
        self.coverage_dict = {ln: 0 for ln in self.layer_names}
        self.hyper = hyper
        self.method = method
        self.method_kwargs = kwargs
        self.init_variable(hyper)

    def init_variable(self, hyper):
        """Initialize subclass-specific accumulator state from `hyper`."""
        raise NotImplementedError

    def update(self):
        """Recompute coverage_dict from the accumulated state."""
        raise NotImplementedError

    def step(self, b_layer_state):
        """Consume one batch of per-layer states (layer name -> tensor)."""
        raise NotImplementedError

    def clear(self):
        """Reset accumulated state."""
        raise NotImplementedError

    def save(self, path):
        """Persist accumulated state under `path`."""
        raise NotImplementedError

    def load(self, path, method):
        """Restore accumulated state from `path`."""
        raise NotImplementedError

    def assess(self, model, data_loader, spatial_func=None, method=None, save_state=False, **kwargs):
        """Run `model` over `data_loader`, feeding layer states into `step`.

        For every batch, retains each instrumented layer's activation, takes
        its KL gradient w.r.t. the prediction, transforms both, and passes
        the state-function output to `step`.

        Args:
            model: the model to instrument and evaluate.
            data_loader: iterable of batches understood by `self.unpack`.
            spatial_func: optional pooling for >2-D states; identity if None.
            method: optional override for the state-function selector.
            save_state: when True, also collect (states, flags, kl_grads)
                per layer on CPU and return them concatenated.
            **kwargs: optional override for the state-function kwargs.

        Returns:
            (layer_output, accuracy): layer_output maps layer name to the
            collected tensors when save_state is True (otherwise empty
            lists); accuracy is correct / total over the loader.
        """
        if method is not None:
            self.method = method
        if kwargs:
            self.method_kwargs = kwargs
        model.to(self.device)
        model.eval()
        if spatial_func is None:
            spatial_func = lambda s: s
        transform = lambda s: spatial_func(s).detach() if len(s.shape) > 2 else s.detach()  # avg-pool for last layer
        state_func = get_state_func(self.method, **self.method_kwargs)

        layer_output = {n: ([], [], []) for n in self.layer_names}
        total, correct = 0, 0
        with nethook.InstrumentedModel(model) as instr:
            # detach=False: activations must stay in the graph for kl_grad.
            instr.retain_layers(self.layer_names, detach=False)
            for i, data in enumerate(data_loader):
                x, y = self.unpack(data, self.device)
                p = model(x)
                correct_num, correct_flags = acc(p, y)
                correct += correct_num
                total += x.shape[0]
                b_layer_state = {}
                for j, ln in enumerate(self.layer_names):
                    # Keep the graph alive until the last layer's backward pass.
                    retain_graph = False if j == len(self.layer_names) - 1 else True
                    b_state = instr.retained_layer(ln)
                    b_kl_grad = kl_grad(b_state, p, retain_graph=retain_graph)
                    b_state, b_kl_grad = transform(b_state), transform(b_kl_grad)
                    out = state_func(b_state, correct_flags, b_kl_grad)
                    b_layer_state[ln] = out

                    if save_state:
                        layer_output[ln][0].append(b_state.cpu())
                        layer_output[ln][1].append(correct_flags.cpu())
                        layer_output[ln][2].append(b_kl_grad.cpu())
                self.step(b_layer_state)
        if save_state:
            for ln in layer_output:
                states, flags, kl_grads = layer_output[ln]
                layer_output[ln] = (torch.cat(states), torch.cat(flags), torch.cat(kl_grads))
        return layer_output, correct / total

    def assess_with_cache(self, layer_name, data_loader, method=None, **kwargs):
        """Replay cached (state, flags, kl_grad) batches for one layer.

        Args:
            layer_name: the single layer the cached tensors belong to.
            data_loader: iterable yielding (b_state, correct_flags, b_kl_grad)
                CPU tensors, e.g. as produced by `assess(save_state=True)`.
            method / **kwargs: optional state-function overrides, as in assess.
        """
        if method is not None:
            self.method = method
        if kwargs:
            self.method_kwargs = kwargs
        state_func = get_state_func(self.method, **self.method_kwargs)
        b_layer_state = {}
        for b_state, correct_flags, b_kl_grad in (data_loader):
            out = state_func(b_state.to(self.device, non_blocking=True),
                             correct_flags.to(self.device, non_blocking=True),
                             b_kl_grad.to(self.device, non_blocking=True))
            b_layer_state[layer_name] = out
            self.step(b_layer_state)

    def score(self, layer_name=None):
        """Return the coverage score for `layer_name`, or the full dict.

        With a single tracked layer, that layer is selected automatically.
        """
        if len(self.layer_names) == 1:
            layer_name = self.layer_names[0]
        if layer_name:
            return self.coverage_dict[layer_name]
        return self.coverage_dict
| 130 | + |
| 131 | + |
class KMNC(Coverage):
    """K-Multisection Neuron Coverage over retained layer states.

    Keeps one `Estimator` per instrumented layer; each estimator splits
    [0, 1] into M-1 buckets per neuron and counts how many samples land in
    each bucket (a bucket is fully covered once it holds >= O samples).
    """

    def init_variable(self, hyper: Optional[Dict] = None):
        """Create one Estimator per layer from hyper-parameters M and O."""
        self.estimator_dict = {}
        self.current = 0

        assert ('M' in hyper and 'O' in hyper)
        self.M = hyper['M']  # number of buckets
        self.O = hyper['O']  # minimum number of samples required for bin filling
        for (layer_name, layer_size) in self.layer_size_dict.items():
            self.estimator_dict[layer_name] = Estimator(layer_size, self.M, self.O, self.device)

    def add(self, other):
        """Merge another KMNC's per-layer bucket counts into this one."""
        # `other` must track the same layers with the same number of buckets.
        assert (self.M == other.M) and (self.layer_names == other.layer_names)
        for ln in self.layer_names:
            self.estimator_dict[ln].add(other.estimator_dict[ln])

    def clear(self):
        """Reset every layer's bucket counts to zero."""
        for ln in self.layer_names:
            self.estimator_dict[ln].clear()

    def step(self, b_layer_state):
        """Feed one batch of per-layer states into the estimators."""
        for (ln, states) in b_layer_state.items():
            if len(states) > 0:  # skip layers that produced an empty batch
                self.estimator_dict[ln].update(states)

    def update(self, **kwargs):
        """Recompute (thresholds, coverage curve, coverage score) per layer.

        Returns:
            self.score() — the per-layer tuple, or the full dict.
        """
        for ln in self.layer_names:
            thresh = self.estimator_dict[ln].thresh[:, 0].cpu().numpy()
            t_cov, coverage = self.estimator_dict[ln].get_score(**kwargs)
            self.coverage_dict[ln] = (thresh, t_cov, coverage)
        return self.score()

    def save(self, path, prefix="cov"):
        """Serialize estimator states plus configuration under `path`.

        The file name encodes the method kwargs, M, and the method name.
        """
        os.makedirs(path, exist_ok=True)
        for k, v in self.method_kwargs.items():
            prefix += f"_{k}_{v}"
        name = prefix + f"_states_M{self.M}_{self.method}.pkl"
        state = {
            'layer_size_dict': self.layer_size_dict,
            'hyper': self.hyper,
            'es_states': {ln: self.estimator_dict[ln].states for ln in self.layer_names},
            'method': self.method,
            'method_kwargs': self.method_kwargs
        }
        torch.save(state, os.path.join(path, name))

    @staticmethod
    def load(path, name, device=None, r=None, unpack=None, verbose=False):
        """Rebuild a KMNC instance from a file produced by `save`.

        Args:
            path, name: directory and file name of the saved state.
            device: target device for the rebuilt estimators.
            r: if given and > 0, overrides the saved O hyper-parameter.
            unpack: batch-unpacking callable for the new instance.
            verbose: when True, print the loaded hyper-parameters.

        Returns:
            A KMNC instance with restored estimator states.
        """
        state = torch.load(os.path.join(path, name))
        if r is not None and r > 0:
            state['hyper']['O'] = r
        if verbose:
            print('Loading saved coverage states in %s/%s...' % (path, name))
            for k, v in state['hyper'].items():
                print("load hyper params of Coverage:", k, v)
        # Local renamed from `Coverage` — the old name shadowed the base class.
        cov = KMNC(state['layer_size_dict'], device=device,
                   hyper=state['hyper'], unpack=unpack)
        for k, v in state['es_states'].items():
            cov.estimator_dict[k].load(v)
        try:
            cov.method = state['method']
            cov.method_kwargs = state['method_kwargs']
        except KeyError:  # older checkpoints may lack these fields
            print("failed to load method and method_kwargs")
        return cov
| 201 | + |
| 202 | + |
def logspace(base=10, num=100):
    """Build a symmetric, log-spaced threshold grid on [0, ~1.2].

    The left half is log_base of points evenly spaced between 1 and
    sqrt(base), covering [0, 0.5]; the right half is its mirror image on
    [0.5, 1]. The final grid point is nudged just past 1.0 and a 1.2
    sentinel is appended, so values equal to 1.0 still fall in a bucket.
    """
    half = int(num / 2)
    lin = np.linspace(1, np.sqrt(base), num=half)
    left = np.emath.logn(base, lin)      # ascends 0 .. 0.5
    right = (1 - left)[::-1]             # mirrored: ascends 0.5 .. 1
    grid = np.concatenate([left[:-1], right])
    grid[-1] += 1e-2                     # push the top edge past 1.0
    return torch.from_numpy(np.append(grid, 1.2))
| 211 | + |
class Estimator(object):
    """Per-layer bucket estimator backing KMNC.

    Discretizes [0, 1] into M threshold points (M - 1 buckets) per neuron
    and counts, per bucket, how many samples produced a state inside it.
    """

    def __init__(self, neuron_num, M=1000, O=1, device=None):
        """
        Args:
            neuron_num: number of neurons (channels) tracked.
            M: number of threshold points, giving M - 1 buckets.
            O: minimum sample count for a bucket to count as fully covered.
            device: torch device; defaults to cuda:0 when available.
        """
        assert O > 0, 'O should > (or =) 1'
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.M, self.O, self.N = M, O, neuron_num
        # thresh: [M, N] — evenly spaced bucket edges in [0, 1], one column per neuron.
        self.thresh = torch.linspace(0., 1., M).view(M, -1).repeat(1, neuron_num).to(self.device)
        # t_act: [M-1, N] — running count of activations per bucket.
        self.t_act = torch.zeros(M - 1, neuron_num).to(self.device)

    def add(self, other):
        """Merge another Estimator's counts into this one (same M and N)."""
        assert (self.M == other.M) and (self.N == other.N)
        self.t_act += other.t_act

    def update(self, states):
        """Accumulate bucket hits for one batch of states.

        Args:
            states: [num_data, num_neuron] tensor; values outside [0, 1)
                fall into no bucket and are silently ignored.
        """
        # thresh -> [num_t, num_n] -> [1, num_t, num_n] ->compare-> [num_data, num_t, num_n]
        # states -> [num_data, num_n] -> [num_data, 1, num_n] ->compare-> ...
        with torch.no_grad():
            b_act = (states.unsqueeze(1) >= self.thresh[:self.M - 1].unsqueeze(0)) & \
                    (states.unsqueeze(1) < self.thresh[1:self.M].unsqueeze(0))
            # Sum over the batch dim, then fold into the running counts.
            self.t_act += b_act.sum(dim=0)  # [num_t, num_n]

    def get_score(self, method="avg"):
        """Return (per-threshold coverage curve, scalar coverage).

        Each bucket scores min(count / O, 1); per-neuron coverage is the
        bucket-score sum divided by M, then reduced by `method` ("avg" for
        the mean, "norm2" for the Euclidean norm).

        Returns:
            (numpy curve of length M — first neuron's bucket scores plus a
            trailing 0 — and the reduced coverage as a CPU tensor).
        """
        t_score = torch.min(self.t_act / self.O, torch.ones_like(self.t_act))  # [num_t, num_n]
        coverage = (t_score.sum(dim=0)) / self.M  # [num_n]
        if method == "norm2":
            # Fixed: was norm(p=1), which contradicted the method's name.
            coverage = coverage.norm(p=2)
        elif method == "avg":
            coverage = coverage.mean()

        t_cov = t_score[:, 0].cpu().numpy()  # first neuron only, for simplicity
        return np.append(t_cov, 0), coverage.cpu()

    @property
    def states(self):
        """CPU snapshot of the serializable state (thresholds and counts)."""
        return {
            "thresh": self.thresh.cpu(),
            "t_act": self.t_act.cpu()
        }

    def load(self, state_dict):
        """Restore from a `states` snapshot, moving tensors to this device."""
        self.thresh = state_dict["thresh"].to(self.device)
        self.t_act = state_dict["t_act"].to(self.device)

    def clear(self):
        """Reset all bucket counts to zero."""
        self.t_act = torch.zeros(self.M - 1, self.N).to(self.device)
0 commit comments