-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcutmix_keras.py
74 lines (60 loc) · 2.44 KB
/
cutmix_keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import random
import numpy as np
class CutMixImageDataGenerator():
def __init__(self, generator1, generator2, img_size, batch_size, p):
self.batch_index = 0
self.samples = generator1.samples
self.class_indices = generator1.class_indices
self.generator1 = generator1
self.generator2 = generator2
self.img_size = img_size
self.batch_size = batch_size
self.p = p
def reset_index(self): # Ordering Reset (If Shuffle is True, Shuffle Again)
self.generator1._set_index_array()
self.generator2._set_index_array()
def reset(self):
self.batch_index = 0
self.generator1.reset()
self.generator2.reset()
self.reset_index()
def get_steps_per_epoch(self):
quotient, remainder = divmod(self.samples, self.batch_size)
return (quotient + 1) if remainder else quotient
def __len__(self):
self.get_steps_per_epoch()
def __next__(self):
if self.batch_index == 0: self.reset()
crt_idx = self.batch_index * self.batch_size
if self.samples > crt_idx + self.batch_size:
self.batch_index += 1
else: # If current index over number of samples
self.batch_index = 0
reshape_size = self.batch_size
last_step_start_idx = (self.get_steps_per_epoch()-1) * self.batch_size
if crt_idx == last_step_start_idx:
reshape_size = self.samples - last_step_start_idx
X_1, y_1 = self.generator1.next()
X_2, y_2 = self.generator2.next()
if random.random() >= self.p:
return X_1, y_1
cut_ratio = np.random.beta(a=1, b=1, size=reshape_size)
cut_ratio = np.clip(cut_ratio, 0.2, 0.8)
label_ratio = cut_ratio.reshape(reshape_size, 1)
cut_img = X_2
X = X_1
for i in range(reshape_size):
cut_size = int((self.img_size-1) * cut_ratio[i])
y1 = random.randint(0, (self.img_size-1) - cut_size)
x1 = random.randint(0, (self.img_size-1) - cut_size)
y2 = y1 + cut_size
x2 = x1 + cut_size
cut_arr = cut_img[i][y1:y2, x1:x2]
cutmix_img = X_1[i]
cutmix_img[y1:y2, x1:x2] = cut_arr
X[i] = cutmix_img
y = y_1 * (1 - (label_ratio ** 2)) + y_2 * (label_ratio ** 2)
return X, y
def __iter__(self):
while True:
yield next(self)