Skip to content

Commit 06b613b

Browse files
committed
NAC-ME implementation
1 parent 782a4f5 commit 06b613b

36 files changed

+15813
-0
lines changed

domainbed/README.md

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
3+
## Usage
4+
This is the PyTorch implementation of our NAC-ME: https://arxiv.org/abs/2306.02879. Our experimental settings are carefully aligned with DomainBed. We adopt the improved implementation from [https://github.com/khanrc/swad/tree/main](https://github.com/khanrc/swad/tree/main).
5+
6+
We list the required packages in [environment.yml](https://github.com/BierOne/ood_coverage/tree/main/environment.yml). You can run the following command to set up the environment:
7+
```
8+
pip install -r requirements.txt
9+
```
10+
11+
12+
## How to run
13+
14+
The `train_all.py` script conducts multiple leave-one-out cross-validations over all target domains. Taking the ERM model, the PACS dataset, and the resnet18 architecture as an example, you can run the following command:
15+
16+
```sh
17+
CUDA_VISIBLE_DEVICES=1 bash scripts/run_single_train.sh ERM PACS resnet18
18+
```
19+
20+
Experiment results are reported as a table. In the table, the `coverage` row reports the out-of-domain accuracy obtained with NAC-ME model selection, and the `coverage_rc` row reports the rank correlation between NAC-ME scores and test accuracy.
21+
22+
23+
Example results:
24+
```
25+
+----------------------------+--------------+----------+----------+----------+----------+
26+
| Selection | art_painting | cartoon | photo | sketch | Avg. |
27+
+----------------------------+--------------+----------+----------+----------+----------+
28+
| oracle | 71.979% | 72.750% | 44.539% | 47.946% | 59.303% |
29+
| last | 71.608% | 70.401% | 41.674% | 47.478% | 57.790% |
30+
| training-domain validation | 68.475% | 71.588% | 40.478% | 43.381% | 55.981% |
31+
| coverage | 71.608% | 70.401% | 42.776% | 47.478% | 58.066% |
32+
| coverage_rc | 69.853% | 35.049% | 54.167% | 45.588% | 51.164% |
33+
| val_rc | 65.441% | 43.627% | 37.500% | 62.990% | 52.390% |
34+
| oracle_rc | 94.608% | 94.853% | 92.892% | 96.324% | 94.669% |
35+
| test | 100.000% | 100.000% | 100.000% | 100.000% | 100.000% |
36+
+----------------------------+--------------+----------+----------+----------+----------+
37+
```
38+
In the above example, the best model selected by NAC-ME achieves 58.066% average accuracy, which is better than the validation-selected model (55.981%). Note that this is just a simple case of NAC-ME. To achieve the full comparison, it is necessary to sweep a large number of models and datasets in such an unstable DG training scenario.
39+
40+
The NAC-ME calculation is quite similar to that of NAC-UE; please refer to the [source file](https://github.com/BierOne/ood_coverage/tree/master/domainbed/benchmark_notes/coverage.py) for more details.
41+
42+
43+
## Credits
44+
This codebase is developed based on [SWAD](https://github.com/khanrc/swad/tree/main). We extend our sincere gratitude for their generosity in providing this valuable resource.
45+

domainbed/benchmark_notes/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""
2+
Create Time: 22/3/2023
3+
Author: BierOne ([email protected])
4+
"""

domainbed/benchmark_notes/coverage.py

+265
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
"""
2+
Create Time: 29/3/2023
3+
Author: BierOne ([email protected])
4+
"""
5+
import os
6+
import torch
7+
import numpy as np
8+
from typing import Any, Callable, Optional, Sequence, Tuple, Dict
9+
10+
from benchmark_notes.utils import get_state_func, acc
11+
from benchmark_notes import nethook
12+
from benchmark_notes.instr_state import kl_grad
13+
14+
15+
def make_layer_size_dict(model, layer_names, input_shape=(1, 3, 224, 224), spatial_func=None):
    """Probe ``model`` once and record the size of dimension 1 of each layer output.

    A zero tensor of shape ``input_shape`` is fed through the model while the
    layers in ``layer_names`` are retained via ``nethook.InstrumentedModel``.

    Args:
        model: the network to probe. Its first parameter (if any) decides on
            which device the probe tensor is created.
        layer_names: names of the layers whose outputs are retained.
        input_shape: shape of the all-zero probe input.
        spatial_func: optional callable applied to retained outputs that have
            more than two dimensions before their size is read (the original
            comment mentions average pooling).

    Returns:
        dict mapping each layer name to ``shape[1]`` of its (optionally
        transformed) retained output.
    """

    def transform(state):
        # Only outputs with more than two dims go through ``spatial_func``.
        if spatial_func is not None and len(state.shape) > 2:
            return spatial_func(state)
        return state

    # Create the probe on the model's own device instead of forcing CUDA, so
    # the helper also works for CPU models and for models on a non-default GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        probe = torch.zeros(*input_shape, device=first_param.device)
    else:
        # Parameter-free model: keep the historical behaviour.
        probe = torch.zeros(*input_shape).cuda()

    layer_size_dict = {}
    with nethook.InstrumentedModel(model) as instr:
        instr.retain_layers(layer_names, detach=True)
        with torch.no_grad():
            _ = model(probe)
        for ln in layer_names:
            retained = instr.retained_layer(ln)
            layer_size_dict[ln] = transform(retained).shape[1]
    return layer_size_dict
29+
30+
class Coverage:
    """Base class for layer-wise coverage trackers.

    Subclasses must implement ``init_variable``, ``update``, ``step``,
    ``clear``, ``save`` and ``load``; the base versions only raise
    ``NotImplementedError``.
    """

    def __init__(self,
                 layer_size_dict: Dict,
                 device: Optional[Any] = None,
                 hyper: Optional[Dict] = None,
                 unpack: Optional[Callable] = None,
                 method: Optional[Any] = None,
                 **kwargs):
        """Store the configuration and let the subclass build its state.

        Args:
            layer_size_dict: mapping from layer name to layer size.
            device: target device; defaults to ``cuda:0`` when CUDA is
                available, otherwise ``cpu``.
            hyper: hyper-parameters forwarded to ``init_variable``.
            unpack: callable used by ``assess`` as ``unpack(data, device)``
                to obtain ``(x, y)`` from a loader item.
            method: identifier forwarded to ``get_state_func``.
            **kwargs: extra keyword arguments forwarded to ``get_state_func``.
        """
        if device is None:
            fallback = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            device = torch.device(fallback)
        self.device = device
        self.layer_size_dict = layer_size_dict
        self.layer_names = list(layer_size_dict)
        self.unpack = unpack
        self.coverage_dict = dict.fromkeys(self.layer_names, 0)
        self.hyper = hyper
        self.method = method
        self.method_kwargs = kwargs
        self.init_variable(hyper)

    def init_variable(self, hyper):
        """Build subclass-specific state (abstract)."""
        raise NotImplementedError

    def update(self):
        """Refresh ``coverage_dict`` (abstract)."""
        raise NotImplementedError

    def step(self, b_layer_state):
        """Consume one batch of per-layer states (abstract)."""
        raise NotImplementedError

    def clear(self):
        """Reset accumulated state (abstract)."""
        raise NotImplementedError

    def save(self, path):
        """Persist the tracker (abstract)."""
        raise NotImplementedError

    def load(self, path, method):
        """Restore the tracker (abstract)."""
        raise NotImplementedError

    def assess(self, model, data_loader, spatial_func=None, method=None, save_state=False, **kwargs):
        """Run ``model`` over ``data_loader`` and feed layer states to ``step``.

        Args:
            model: network to evaluate; moved to ``self.device`` and put in
                eval mode.
            data_loader: iterable whose items are passed to ``self.unpack``.
            spatial_func: optional callable applied to states/gradients with
                more than two dimensions before they are detached.
            method: if not ``None``, replaces ``self.method``.
            save_state: when true, per-batch states, correctness flags and
                gradients are collected on CPU and concatenated per layer.
            **kwargs: if non-empty, replaces ``self.method_kwargs``.

        Returns:
            ``(layer_output, correct / total)`` where ``layer_output`` maps
            each layer name to a 3-tuple. The tuple holds concatenated tensors
            when ``save_state`` is true and three empty lists otherwise.
        """
        if method is not None:
            self.method = method
        if kwargs:
            self.method_kwargs = kwargs
        model.to(self.device)
        model.eval()

        def transform(state):
            # Tensors with more than two dims optionally pass through
            # ``spatial_func``; everything is detached afterwards.
            if len(state.shape) > 2 and spatial_func is not None:
                return spatial_func(state).detach()
            return state.detach()

        state_func = get_state_func(self.method, **self.method_kwargs)

        layer_output = {name: ([], [], []) for name in self.layer_names}
        total = 0
        correct = 0
        last_index = len(self.layer_names) - 1
        with nethook.InstrumentedModel(model) as instr:
            # Retained outputs stay attached so kl_grad can differentiate
            # through them.
            instr.retain_layers(self.layer_names, detach=False)
            for data in data_loader:
                x, y = self.unpack(data, self.device)
                p = model(x)
                correct_num, correct_flags = acc(p, y)
                correct += correct_num
                total += x.shape[0]

                b_layer_state = {}
                for j, ln in enumerate(self.layer_names):
                    b_state = instr.retained_layer(ln)
                    # retain_graph is requested for every layer except the
                    # last one processed in this batch.
                    b_kl_grad = kl_grad(b_state, p, retain_graph=(j != last_index))
                    b_state = transform(b_state)
                    b_kl_grad = transform(b_kl_grad)
                    b_layer_state[ln] = state_func(b_state, correct_flags, b_kl_grad)

                    if save_state:
                        states, flags, kl_grads = layer_output[ln]
                        states.append(b_state.cpu())
                        flags.append(correct_flags.cpu())
                        kl_grads.append(b_kl_grad.cpu())
                self.step(b_layer_state)

        if save_state:
            for ln in layer_output:
                states, flags, kl_grads = layer_output[ln]
                layer_output[ln] = (torch.cat(states), torch.cat(flags), torch.cat(kl_grads))
        return layer_output, correct / total

    def assess_with_cache(self, layer_name, data_loader, method=None, **kwargs):
        """Feed cached ``(state, correct_flags, kl_grad)`` batches to ``step``.

        Args:
            layer_name: key under which each batch result is passed to ``step``.
            data_loader: iterable yielding ``(b_state, correct_flags, b_kl_grad)``.
            method: if not ``None``, replaces ``self.method``.
            **kwargs: if non-empty, replaces ``self.method_kwargs``.
        """
        if method is not None:
            self.method = method
        if kwargs:
            self.method_kwargs = kwargs
        state_func = get_state_func(self.method, **self.method_kwargs)

        def on_device(tensor):
            return tensor.to(self.device, non_blocking=True)

        # The same dict object is reused for every batch.
        b_layer_state = {}
        for b_state, correct_flags, b_kl_grad in data_loader:
            b_layer_state[layer_name] = state_func(
                on_device(b_state), on_device(correct_flags), on_device(b_kl_grad)
            )
            self.step(b_layer_state)

    def score(self, layer_name=None):
        """Return the stored coverage for one layer, or the whole dict.

        With exactly one tracked layer that layer is always used, whatever
        ``layer_name`` says. Otherwise a truthy ``layer_name`` selects one
        entry and a falsy one returns the full ``coverage_dict``.
        """
        key = self.layer_names[0] if len(self.layer_names) == 1 else layer_name
        return self.coverage_dict[key] if key else self.coverage_dict
130+
131+
132+
class KMNC(Coverage):
    """Coverage tracker that keeps one ``Estimator`` per layer.

    Hyper-parameters (``hyper`` dict):
        M: number of buckets.
        O: minimum number of samples required for bin filling.
    """

    def init_variable(self, hyper: Optional[Dict] = None):
        """Create one ``Estimator`` per entry of ``self.layer_size_dict``.

        Raises:
            AssertionError: if ``hyper`` is missing or lacks ``'M'``/``'O'``.
        """
        self.estimator_dict = {}
        self.current = 0

        # Checking for None first gives a clear assertion message instead of
        # an unrelated ``TypeError`` from ``'M' in None``.
        assert hyper is not None and 'M' in hyper and 'O' in hyper, \
            "hyper must be a dict providing both 'M' and 'O'"
        self.M = hyper['M']  # number of buckets
        self.O = hyper['O']  # minimum number of samples required for bin filling
        for layer_name, layer_size in self.layer_size_dict.items():
            self.estimator_dict[layer_name] = Estimator(layer_size, self.M, self.O, self.device)

    def add(self, other):
        """Accumulate the estimators of another ``KMNC`` with the same ``M`` and layers."""
        assert (self.M == other.M) and (self.layer_names == other.layer_names)
        for ln in self.layer_names:
            self.estimator_dict[ln].add(other.estimator_dict[ln])

    def clear(self):
        """Reset every per-layer estimator."""
        for ln in self.layer_names:
            self.estimator_dict[ln].clear()

    def step(self, b_layer_state):
        """Update each layer's estimator with its non-empty batch of states."""
        for ln, states in b_layer_state.items():
            # ``len`` is used on purpose: ``states`` may be a tensor, whose
            # truthiness is not defined for more than one element.
            if len(states) > 0:
                self.estimator_dict[ln].update(states)

    def update(self, **kwargs):
        """Recompute ``coverage_dict`` from the estimators and return ``score()``.

        Each entry becomes ``(thresh, t_cov, coverage)`` where ``thresh`` is the
        first threshold column of the estimator and the other two values come
        from ``Estimator.get_score(**kwargs)``.
        """
        for ln in self.layer_names:
            estimator = self.estimator_dict[ln]
            thresh = estimator.thresh[:, 0].cpu().numpy()
            t_cov, coverage = estimator.get_score(**kwargs)
            self.coverage_dict[ln] = (thresh, t_cov, coverage)
        return self.score()

    def save(self, path, prefix="cov"):
        """Write the tracker state to ``path`` with ``torch.save``.

        The file name is ``prefix`` followed by ``_{key}_{value}`` for every
        entry of ``self.method_kwargs`` and ``_states_M{M}_{method}.pkl``.
        """
        os.makedirs(path, exist_ok=True)
        for k, v in self.method_kwargs.items():
            prefix += f"_{k}_{v}"
        name = prefix + f"_states_M{self.M}_{self.method}.pkl"
        state = {
            'layer_size_dict': self.layer_size_dict,
            'hyper': self.hyper,
            'es_states': {ln: self.estimator_dict[ln].states for ln in self.layer_names},
            'method': self.method,
            'method_kwargs': self.method_kwargs,
        }
        torch.save(state, os.path.join(path, name))

    @staticmethod
    def load(path, name, device=None, r=None, unpack=None, verbose=False):
        """Rebuild a ``KMNC`` from a file written by ``save``.

        Args:
            path: directory containing the file.
            name: file name inside ``path``.
            device: device for the restored tracker (see ``Coverage``).
            r: if given and positive, overrides the stored ``hyper['O']``.
            unpack: forwarded to the constructor.
            verbose: print the file location and stored hyper-parameters.

        Returns:
            The restored ``KMNC`` instance.
        """
        # NOTE(security): torch.load unpickles the file; only load state
        # files from trusted sources.
        state = torch.load(os.path.join(path, name))
        if r is not None and r > 0:
            state['hyper']['O'] = r
        if verbose:
            print('Loading saved coverage states in %s/%s...' % (path, name))
            for k, v in state['hyper'].items():
                print("load hyper params of Coverage:", k, v)

        # Named ``restored`` so the local does not shadow the ``Coverage`` class.
        restored = KMNC(state['layer_size_dict'], device=device,
                        hyper=state['hyper'], unpack=unpack)
        for ln, es_state in state['es_states'].items():
            restored.estimator_dict[ln].load(es_state)
        try:
            restored.method = state['method']
            restored.method_kwargs = state['method_kwargs']
        except KeyError:
            # Older state files may lack these keys; keep the constructor
            # defaults. Only KeyError is expected here, so nothing else
            # (e.g. KeyboardInterrupt) is swallowed.
            print("failed to load method and method_kwargs")
        return restored
201+
202+
203+
def logspace(base=10, num=100):
    """Build a symmetric, log-spaced 1-D grid of thresholds as a tensor.

    Half of ``num`` points are spaced linearly on ``[1, sqrt(base)]`` and
    mapped through ``log_base``, giving values in ``[0, 0.5]``. Their mirror
    image ``1 - x`` supplies the upper half. The last mirrored value is nudged
    up by ``1e-2`` and a final ``1.2`` entry is appended.

    Args:
        base: logarithm base.
        num: requested grid size; ``int(num / 2)`` points per half are used.

    Returns:
        A 1-D ``torch`` tensor created from the numpy grid.
    """
    half = int(num / 2)
    linear = np.linspace(1, np.sqrt(base), num=half)
    lower = np.emath.logn(base, linear)
    upper = (1 - lower)[::-1]
    # The midpoint appears in both halves; drop it from the lower one.
    grid = np.concatenate([lower[:-1], upper])
    grid[-1] += 1e-2
    grid = np.append(grid, 1.2)
    return torch.from_numpy(grid)
211+
212+
class Estimator:
    """Per-neuron histogram over ``M`` evenly spaced thresholds in ``[0, 1]``.

    ``thresh`` has shape ``[M, neuron_num]`` (the same grid repeated for every
    neuron) and ``t_act`` has shape ``[M - 1, neuron_num]``, counting how many
    observed states fell into each half-open bin ``[thresh[i], thresh[i + 1])``.
    """

    def __init__(self, neuron_num, M=1000, O=1, device=None):
        """Allocate thresholds and zeroed bin counters.

        Args:
            neuron_num: number of neurons (columns).
            M: number of thresholds; there are ``M - 1`` bins.
            O: divisor applied to bin counts in ``get_score``; must be > 0.
            device: target device; defaults to ``cuda:0`` when CUDA is
                available, otherwise ``cpu``.
        """
        assert O > 0, 'O should > (or =) 1'
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.M = M
        self.O = O
        self.N = neuron_num
        grid = torch.linspace(0., 1., M).view(M, -1)
        self.thresh = grid.repeat(1, neuron_num).to(self.device)
        # current activations under each thresh
        self.t_act = torch.zeros(M - 1, neuron_num).to(self.device)

    def add(self, other):
        """Add the bin counts of another estimator with the same ``M`` and ``N``."""
        assert (self.M == other.M) and (self.N == other.N)
        self.t_act += other.t_act

    def update(self, states):
        """Count, per neuron, which bin every row of ``states`` falls into.

        ``states`` is broadcast as ``[num_data, 1, num_n]`` against the lower
        and upper bin edges ``[1, num_t, num_n]``. Values outside every bin
        (e.g. equal to the last threshold) are not counted.
        """
        with torch.no_grad():
            lower_edges = self.thresh[:self.M - 1].unsqueeze(0)
            upper_edges = self.thresh[1:self.M].unsqueeze(0)
            expanded = states.unsqueeze(1)
            in_bin = (expanded >= lower_edges) & (expanded < upper_edges)
            # Sum over the data axis -> [num_t, num_n].
            self.t_act += in_bin.sum(dim=0)

    def get_score(self, method="avg"):
        """Turn bin counts into coverage values.

        Each bin count is divided by ``O`` and capped at 1. Per-neuron coverage
        is the sum of these values divided by ``M``.

        Args:
            method: ``"avg"`` returns the mean over neurons; ``"norm2"``
                returns ``coverage.norm(p=1)``; any other value leaves the
                per-neuron vector unreduced.

        Returns:
            ``(t_cov, coverage)``: the capped bin values of neuron 0 with a
            trailing ``0`` appended (numpy array), and the coverage on CPU.
        """
        ratio = self.t_act / self.O
        t_score = torch.min(ratio, torch.ones_like(self.t_act))  # [num_t, num_n]
        coverage = t_score.sum(dim=0) / self.M  # [num_n]
        if method == "norm2":
            # NOTE(review): the branch is named "norm2" but uses p=1 - confirm intent.
            coverage = coverage.norm(p=1)
        elif method == "avg":
            coverage = coverage.mean()

        # Only the first neuron's column is reported, for simplicity.
        first_neuron = t_score[:, 0].cpu().numpy()
        return np.append(first_neuron, 0), coverage.cpu()

    @property
    def states(self):
        """CPU copies of ``thresh`` and ``t_act`` for serialization."""
        return {
            "thresh": self.thresh.cpu(),
            "t_act": self.t_act.cpu(),
        }

    def load(self, state_dict):
        """Restore ``thresh`` and ``t_act`` from a ``states`` dict onto ``self.device``."""
        self.thresh = state_dict["thresh"].to(self.device)
        self.t_act = state_dict["t_act"].to(self.device)

    def clear(self):
        """Replace the bin counters with fresh zeros."""
        # current activations under each thresh
        self.t_act = torch.zeros(self.M - 1, self.N).to(self.device)

0 commit comments

Comments
 (0)