-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_validation.py
49 lines (38 loc) · 1.15 KB
/
generate_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torchvision
import numpy as np
from utils import cifar
# Fixed seed so every run reproduces the same permutation, and therefore the
# same train/validation split. Both CIFAR-10 and CIFAR-100 reuse this P.
np.random.seed(0)
m = 50000  # expected number of training images; checked in split_train_val
P = np.random.permutation(m)
n = 1000  # number of training images held out as the validation set


def split_train_val(dataset, perm, num_val):
    """Move ``num_val`` training samples into a new ``'val'`` split, in place.

    The first ``num_val`` entries of ``perm`` select the validation samples.
    The remaining entries select (and reorder) the samples kept for training.

    Args:
        dataset: dict with a ``'train'`` entry holding ``'data'`` (indexed
            with an integer array, e.g. a numpy array) and ``'labels'`` (any
            integer-indexable sequence). It is mutated in place.
        perm: integer index array with one entry per training sample.
        num_val: number of samples to hold out for validation.

    Returns:
        The same ``dataset`` object, now with:
        - ``'train'`` restricted to ``perm[num_val:]``;
        - a new ``'val'`` split built from ``perm[:num_val]``;
        - ``'split'`` set to ``num_val``;
        - ``'permutation'`` set to ``perm``.

    Raises:
        ValueError: if the training data or labels do not have exactly
            ``len(perm)`` entries. A silent mismatch would either drop
            samples or misalign labels with images.
    """
    train = dataset['train']
    data, labels = train['data'], train['labels']
    if len(data) != len(perm) or len(labels) != len(perm):
        raise ValueError(
            f"permutation has {len(perm)} entries but the training set has "
            f"{len(data)} samples and {len(labels)} labels"
        )

    val_idx, train_idx = perm[:num_val], perm[num_val:]
    # Read every slice from the original arrays before overwriting 'train'.
    val_data = data[val_idx]
    val_labels = [labels[i] for i in val_idx]
    train['data'] = data[train_idx]
    train['labels'] = [labels[i] for i in train_idx]

    dataset['val'] = {
        'data': val_data,
        'labels': val_labels,
    }
    dataset['split'] = num_val
    dataset['permutation'] = perm
    return dataset


def main():
    """Build and save the CIFAR-10 and CIFAR-100 validation-split files."""
    # CIFAR-10: uses cifar()'s default num_classes, exactly as before.
    torch.save(split_train_val(cifar('../cifar-data'), P, n),
               'cifar10_validation_split.pth')
    # CIFAR-100
    torch.save(split_train_val(cifar('../cifar-data', num_classes=100), P, n),
               'cifar100_validation_split.pth')


if __name__ == "__main__":
    main()