| 
 | 1 | +"""  | 
 | 2 | +Create Time: 29/3/2023  | 
 | 3 | +Author: BierOne ([email protected])  | 
 | 4 | +"""  | 
 | 5 | +import os  | 
 | 6 | +import torch  | 
 | 7 | +import numpy as np  | 
 | 8 | +from typing import Any, Callable, Optional, Sequence, Tuple, Dict  | 
 | 9 | + | 
 | 10 | +from benchmark_notes.utils import get_state_func, acc  | 
 | 11 | +from benchmark_notes import nethook  | 
 | 12 | +from benchmark_notes.instr_state import kl_grad  | 
 | 13 | + | 
 | 14 | + | 
 | 15 | +def make_layer_size_dict(model, layer_names, input_shape=(1, 3, 224, 224), spatial_func=None):  | 
 | 16 | +    if spatial_func is None:  | 
 | 17 | +        spatial_func = lambda s: s  | 
 | 18 | +    transform = lambda s: spatial_func(s) if len(s.shape) > 2 else s  # avg-pool for last layer  | 
 | 19 | +    input = torch.zeros(*input_shape).cuda()  | 
 | 20 | +    layer_size_dict = {}  | 
 | 21 | +    with nethook.InstrumentedModel(model) as instr:  | 
 | 22 | +        instr.retain_layers(layer_names, detach=True)  | 
 | 23 | +        with torch.no_grad():  | 
 | 24 | +            _ = model(input)  | 
 | 25 | +            for ln in layer_names:  | 
 | 26 | +                b_state = instr.retained_layer(ln)  | 
 | 27 | +                layer_size_dict[ln] = transform(b_state).shape[1]  | 
 | 28 | +    return layer_size_dict  | 
 | 29 | + | 
 | 30 | +class Coverage:  | 
 | 31 | +    def __init__(self,  | 
 | 32 | +                 layer_size_dict: Dict,  | 
 | 33 | +                 device: Optional[Any] = None,  | 
 | 34 | +                 hyper: Optional[Dict] = None,  | 
 | 35 | +                 unpack: Optional[Callable] = None,  | 
 | 36 | +                 method: Optional[Any] = None,  | 
 | 37 | +                 **kwargs):  | 
 | 38 | +        if device is None:  | 
 | 39 | +            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  | 
 | 40 | +        self.device = device  | 
 | 41 | +        self.layer_size_dict = layer_size_dict  | 
 | 42 | +        self.layer_names = list(layer_size_dict.keys())  | 
 | 43 | +        self.unpack = unpack  | 
 | 44 | +        self.coverage_dict = {ln: 0 for ln in self.layer_names}  | 
 | 45 | +        self.hyper = hyper  | 
 | 46 | +        self.method = method  | 
 | 47 | +        self.method_kwargs = kwargs  | 
 | 48 | +        self.init_variable(hyper)  | 
 | 49 | + | 
 | 50 | +    def init_variable(self, hyper):  | 
 | 51 | +        raise NotImplementedError  | 
 | 52 | + | 
 | 53 | +    def update(self):  | 
 | 54 | +        raise NotImplementedError  | 
 | 55 | + | 
 | 56 | +    def step(self, b_layer_state):  | 
 | 57 | +        raise NotImplementedError  | 
 | 58 | + | 
 | 59 | +    def clear(self):  | 
 | 60 | +        raise NotImplementedError  | 
 | 61 | + | 
 | 62 | +    def save(self, path):  | 
 | 63 | +        raise NotImplementedError  | 
 | 64 | + | 
 | 65 | +    def load(self, path, method):  | 
 | 66 | +        raise NotImplementedError  | 
 | 67 | + | 
 | 68 | +    def assess(self, model, data_loader, spatial_func=None, method=None, save_state=False, **kwargs):  | 
 | 69 | +        if method is not None:  | 
 | 70 | +            self.method = method  | 
 | 71 | +        if kwargs:  | 
 | 72 | +            self.method_kwargs = kwargs  | 
 | 73 | +        model.to(self.device)  | 
 | 74 | +        model.eval()  | 
 | 75 | +        if spatial_func is None:  | 
 | 76 | +            spatial_func = lambda s: s  | 
 | 77 | +        transform = lambda s: spatial_func(s).detach() if len(s.shape) > 2 else s.detach()  # avg-pool for last layer  | 
 | 78 | +        state_func = get_state_func(self.method, **self.method_kwargs)  | 
 | 79 | + | 
 | 80 | +        layer_output = {n: ([], [], []) for n in self.layer_names}  | 
 | 81 | +        total, correct = 0, 0  | 
 | 82 | +        with nethook.InstrumentedModel(model) as instr:  | 
 | 83 | +            instr.retain_layers(self.layer_names, detach=False)  | 
 | 84 | +            for i, data in enumerate(data_loader):  | 
 | 85 | +                x, y = self.unpack(data, self.device)  | 
 | 86 | +                p = model(x)  | 
 | 87 | +                correct_num, correct_flags = acc(p, y)  | 
 | 88 | +                correct += correct_num  | 
 | 89 | +                total += x.shape[0]  | 
 | 90 | +                b_layer_state = {}  | 
 | 91 | +                for j, ln in enumerate(self.layer_names):  | 
 | 92 | +                    retain_graph = False if j == len(self.layer_names) - 1 else True  | 
 | 93 | +                    b_state = instr.retained_layer(ln)  | 
 | 94 | +                    b_kl_grad = kl_grad(b_state, p, retain_graph=retain_graph)  | 
 | 95 | +                    b_state, b_kl_grad = transform(b_state), transform(b_kl_grad)  | 
 | 96 | +                    out = state_func(b_state, correct_flags, b_kl_grad)  | 
 | 97 | +                    b_layer_state[ln] = out  | 
 | 98 | + | 
 | 99 | +                    if save_state:  | 
 | 100 | +                        layer_output[ln][0].append(b_state.cpu())  | 
 | 101 | +                        layer_output[ln][1].append(correct_flags.cpu())  | 
 | 102 | +                        layer_output[ln][2].append(b_kl_grad.cpu())  | 
 | 103 | +                self.step(b_layer_state)  | 
 | 104 | +        if save_state:  | 
 | 105 | +            for ln in layer_output:  | 
 | 106 | +                states, flags, kl_grads = layer_output[ln]  | 
 | 107 | +                layer_output[ln] = (torch.cat(states), torch.cat(flags), torch.cat(kl_grads))  | 
 | 108 | +        return layer_output, correct / total  | 
 | 109 | + | 
 | 110 | +    def assess_with_cache(self, layer_name, data_loader, method=None, **kwargs):  | 
 | 111 | +        if method is not None:  | 
 | 112 | +            self.method = method  | 
 | 113 | +        if kwargs:  | 
 | 114 | +            self.method_kwargs = kwargs  | 
 | 115 | +        state_func = get_state_func(self.method, **self.method_kwargs)  | 
 | 116 | +        b_layer_state = {}  | 
 | 117 | +        for b_state, correct_flags, b_kl_grad in (data_loader):  | 
 | 118 | +            out = state_func(b_state.to(self.device, non_blocking=True),  | 
 | 119 | +                             correct_flags.to(self.device, non_blocking=True),  | 
 | 120 | +                             b_kl_grad.to(self.device, non_blocking=True))  | 
 | 121 | +            b_layer_state[layer_name] = out  | 
 | 122 | +            self.step(b_layer_state)  | 
 | 123 | + | 
 | 124 | +    def score(self, layer_name=None):  | 
 | 125 | +        if len(self.layer_names) == 1:  | 
 | 126 | +            layer_name = self.layer_names[0]  | 
 | 127 | +        if layer_name:  | 
 | 128 | +            return self.coverage_dict[layer_name]  | 
 | 129 | +        return self.coverage_dict  | 
 | 130 | + | 
 | 131 | + | 
 | 132 | +class KMNC(Coverage):  | 
 | 133 | +    def init_variable(self, hyper: Optional[Dict] = None):  | 
 | 134 | +        self.estimator_dict = {}  | 
 | 135 | +        self.current = 0  | 
 | 136 | + | 
 | 137 | +        assert ('M' in hyper and 'O' in hyper)  | 
 | 138 | +        self.M = hyper['M']  # number of buckets  | 
 | 139 | +        self.O = hyper['O']  # minimum number of samples required for bin filling  | 
 | 140 | +        for (layer_name, layer_size) in self.layer_size_dict.items():  | 
 | 141 | +            self.estimator_dict[layer_name] = Estimator(layer_size, self.M, self.O, self.device)  | 
 | 142 | + | 
 | 143 | +    def add(self, other):  | 
 | 144 | +        # check if other is a KMNC object  | 
 | 145 | +        assert (self.M == other.M) and (self.layer_names == other.layer_names)  | 
 | 146 | +        for ln in self.layer_names:  | 
 | 147 | +            self.estimator_dict[ln].add(other.estimator_dict[ln])  | 
 | 148 | + | 
 | 149 | +    def clear(self):  | 
 | 150 | +        for ln in self.layer_names:  | 
 | 151 | +            self.estimator_dict[ln].clear()  | 
 | 152 | + | 
 | 153 | +    def step(self, b_layer_state):  | 
 | 154 | +        for (ln, states) in b_layer_state.items():  | 
 | 155 | +            if len(states) > 0:  | 
 | 156 | +                # print(states.shape)  | 
 | 157 | +                self.estimator_dict[ln].update(states)  | 
 | 158 | + | 
 | 159 | +    def update(self, **kwargs):  | 
 | 160 | +        for ln in self.layer_names:  | 
 | 161 | +            thresh = self.estimator_dict[ln].thresh[:, 0].cpu().numpy()  | 
 | 162 | +            t_cov, coverage = self.estimator_dict[ln].get_score(**kwargs)  | 
 | 163 | +            self.coverage_dict[ln] = (thresh, t_cov, coverage)  | 
 | 164 | +        return self.score()  | 
 | 165 | + | 
 | 166 | +    def save(self, path, prefix="cov"):  | 
 | 167 | +        os.makedirs(path, exist_ok=True)  | 
 | 168 | +        for k, v in self.method_kwargs.items():  | 
 | 169 | +            prefix += f"_{k}_{v}"  | 
 | 170 | +        name = prefix + f"_states_M{self.M}_{self.method}.pkl"  | 
 | 171 | +        # print('Saving recorded coverage states in %s/%s...' % (path, name))  | 
 | 172 | +        state = {  | 
 | 173 | +            'layer_size_dict': self.layer_size_dict,  | 
 | 174 | +            'hyper': self.hyper,  | 
 | 175 | +            'es_states': {ln: self.estimator_dict[ln].states for ln in self.layer_names},  | 
 | 176 | +            'method': self.method,  | 
 | 177 | +            'method_kwargs': self.method_kwargs  | 
 | 178 | +        }  | 
 | 179 | +        torch.save(state, os.path.join(path, name))  | 
 | 180 | + | 
 | 181 | +    @staticmethod  | 
 | 182 | +    def load(path, name, device=None, r=None, unpack=None, verbose=False):  | 
 | 183 | +        state = torch.load(os.path.join(path, name))  | 
 | 184 | +        if r is not None and r > 0:  | 
 | 185 | +            state['hyper']['O'] = r  | 
 | 186 | +        if verbose:  | 
 | 187 | +            print('Loading saved coverage states in %s/%s...' % (path, name))  | 
 | 188 | +            for k, v in state['hyper'].items():  | 
 | 189 | +                print("load hyper params of Coverage:", k, v)  | 
 | 190 | +        Coverage = KMNC(state['layer_size_dict'], device=device,  | 
 | 191 | +                        hyper=state['hyper'], unpack=unpack)  | 
 | 192 | +        for k, v in state['es_states'].items():  | 
 | 193 | +            Coverage.estimator_dict[k].load(v)  | 
 | 194 | +        try:  | 
 | 195 | +            Coverage.method = state['method']  | 
 | 196 | +            Coverage.method_kwargs = state['method_kwargs']  | 
 | 197 | +        except:  | 
 | 198 | +            print("failed to load method and method_kwargs")  | 
 | 199 | +            pass  | 
 | 200 | +        return Coverage  | 
 | 201 | + | 
 | 202 | + | 
 | 203 | +def logspace(base=10, num=100):  | 
 | 204 | +    num = int(num / 2)  | 
 | 205 | +    x = np.linspace(1, np.sqrt(base), num=num)  | 
 | 206 | +    x_l = np.emath.logn(base, x)  | 
 | 207 | +    x_r = (1 - x_l)[::-1]  | 
 | 208 | +    x = np.concatenate([x_l[:-1], x_r])  | 
 | 209 | +    x[-1] += 1e-2  | 
 | 210 | +    return torch.from_numpy(np.append(x, 1.2))  | 
 | 211 | + | 
 | 212 | +class Estimator(object):  | 
 | 213 | +    def __init__(self, neuron_num, M=1000, O=1, device=None):  | 
 | 214 | +        assert O > 0, 'O should > (or =) 1'  | 
 | 215 | +        if device is None:  | 
 | 216 | +            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  | 
 | 217 | +        self.device = device  | 
 | 218 | +        self.M, self.O, self.N = M, O, neuron_num  | 
 | 219 | +        self.thresh = torch.linspace(0., 1., M).view(M, -1).repeat(1, neuron_num).to(self.device)  | 
 | 220 | +        # self.thresh = logspace(1e3, M).view(M, -1).repeat(1, neuron_num).to(self.device)  | 
 | 221 | +        self.t_act = torch.zeros(M - 1, neuron_num).to(self.device)  # current activations under each thresh  | 
 | 222 | + | 
 | 223 | +    def add(self, other):  | 
 | 224 | +        # check if other is an Estimator object  | 
 | 225 | +        assert (self.M == other.M) and (self.N == other.N)  | 
 | 226 | +        self.t_act += other.t_act  | 
 | 227 | + | 
 | 228 | +    def update(self, states):  | 
 | 229 | +        # bmax, bmin = states.max(dim=0)[0], states.min(dim=0)[0]  # [num_neuron]  | 
 | 230 | +        # if (bmax > self.upper).any() or (bmin < self.lower).any():  | 
 | 231 | + | 
 | 232 | +        # thresh -> [num_t, num_n] -> [1, num_t, num_n] ->compare-> [num_data, num_t, num_n]  | 
 | 233 | +        # states -> [num_data, num_n] -> [num_data, 1, num_n] ->compare-> ...  | 
 | 234 | +        with torch.no_grad():  | 
 | 235 | +            # print(states.shape)  | 
 | 236 | +            b_act = (states.unsqueeze(1) >= self.thresh[:self.M - 1].unsqueeze(0)) & \  | 
 | 237 | +                    (states.unsqueeze(1) < self.thresh[1:self.M].unsqueeze(0))  | 
 | 238 | +            b_act = b_act.sum(dim=0)  # [num_t, num_n]  | 
 | 239 | +            self.t_act += b_act  # current activations under each thresh  | 
 | 240 | + | 
 | 241 | +    def get_score(self, method="avg"):  | 
 | 242 | +        t_score = torch.min(self.t_act / self.O, torch.ones_like(self.t_act))  # [num_t, num_n]  | 
 | 243 | +        coverage = (t_score.sum(dim=0)) / self.M  # [num_n]  | 
 | 244 | +        if method == "norm2":  | 
 | 245 | +            coverage = coverage.norm(p=1)  | 
 | 246 | +        elif method == "avg":  | 
 | 247 | +            coverage = coverage.mean()  | 
 | 248 | + | 
 | 249 | +        # t_cov = t_score.mean(dim=1).cpu().numpy() # for simplicity  | 
 | 250 | +        t_cov = t_score[:, 0].cpu().numpy()  # for simplicity  | 
 | 251 | +        return np.append(t_cov, 0), coverage.cpu()  | 
 | 252 | + | 
 | 253 | +    @property  | 
 | 254 | +    def states(self):  | 
 | 255 | +        return {  | 
 | 256 | +            "thresh": self.thresh.cpu(),  | 
 | 257 | +            "t_act": self.t_act.cpu()  | 
 | 258 | +        }  | 
 | 259 | + | 
 | 260 | +    def load(self, state_dict):  | 
 | 261 | +        self.thresh = state_dict["thresh"].to(self.device)  | 
 | 262 | +        self.t_act = state_dict["t_act"].to(self.device)  | 
 | 263 | + | 
 | 264 | +    def clear(self):  | 
 | 265 | +        self.t_act = torch.zeros(self.M - 1, self.N).to(self.device)  # current activations under each thresh  | 
0 commit comments