
Commit 96c4fb6

update code
add class "Data3" in Data.py to support triple input; fix bug: when input with an 'int' dtype is passed to the norm_with_l2 function in data_prepare.py, it is now converted to float first to prevent overflow during the power operation
1 parent 7a03cfa commit 96c4fb6
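
As a quick illustration of the overflow the fix guards against (a minimal sketch, not part of the commit): squaring large int32 values wraps around silently, while multiplying by 1.0 first promotes them to float64, which is exactly what the `* 1.0` in norm_with_l2 does.

import numpy as np

row = np.array([50000, 60000], dtype=np.int32)
# 50000**2 = 2.5e9 exceeds the int32 maximum (~2.147e9), so the square wraps around
print(np.sum(np.square(row)))        # typically a large negative number from int32 wrap-around
# multiplying by 1.0 promotes to float64 before squaring, as in the fix
print(np.sum(np.square(row * 1.0)))  # 6100000000.0, as expected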

File tree

Data.py
data_prepare.py

2 files changed: +213 -0 lines changed

Data.py

+110
@@ -98,3 +98,113 @@ def next_batch(self, batch_size, shuffle=True):
            self._index_in_step += batch_size
            end = self._index_in_step
            return self._images[start:end], self._labels[start:end]


class Data3(object):
    def __init__(self, images, labels1, labels2):
        self._num_examples = images.shape[0]
        self._images = images
        self._labels1 = labels1
        self._labels2 = labels2
        self._steps_completed = 0
        self._index_in_step = 0

    @property
    def images(self):
        return self._images

    @property
    def labels1(self):
        return self._labels1

    @property
    def labels2(self):
        return self._labels2

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def steps_completed(self):
        return self._steps_completed

    def next_batch(self, batch_size, shuffle=True):
        """Return the next `batch_size` examples from this data set."""
        # go through all the data
        start = self._index_in_step
        # shuffle the data before the first step
        if self._steps_completed == 0 and start == 0 and shuffle:
            # index array 0..num_examples-1 with step 1
            perm0 = np.arange(self._num_examples)
            # shuffle the indices
            np.random.shuffle(perm0)
            self._images = self.images[perm0]
            self._labels1 = self.labels1[perm0]
            self._labels2 = self.labels2[perm0]
        # before entering the next step, handle the examples left over in this one
        if start + batch_size > self._num_examples:
            if start + batch_size < 2 * self._num_examples:
                # mark this step as completed
                self._steps_completed += 1
                # grab the examples remaining in this step
                rest_num_examples = self._num_examples - start
                images_rest_part = self._images[start:self._num_examples]
                labels_rest_part1 = self._labels1[start:self._num_examples]
                labels_rest_part2 = self._labels2[start:self._num_examples]
                # reshuffle the data
                if shuffle:
                    perm = np.arange(self._num_examples)
                    np.random.shuffle(perm)
                    self._images = self._images[perm]
                    self._labels1 = self._labels1[perm]
                    self._labels2 = self._labels2[perm]
                # start the next step and top up the batch
                start = 0
                self._index_in_step = batch_size - rest_num_examples
                end = self._index_in_step
                images_new_part = self._images[start:end]
                labels_new_part1 = self._labels1[start:end]
                labels_new_part2 = self._labels2[start:end]
                return np.concatenate((images_rest_part, images_new_part), axis=0), \
                       np.concatenate((labels_rest_part1, labels_new_part1), axis=0), \
                       np.concatenate((labels_rest_part2, labels_new_part2), axis=0)
            else:
                # batch_size spans the whole data set more than once: reuse the full set
                reuse_times = int(np.floor((start + batch_size) / self._num_examples) - 1)
                self._steps_completed += reuse_times + 1
                images_rest_part = self._images[start:self._num_examples]
                labels_rest_part1 = self._labels1[start:self._num_examples]
                labels_rest_part2 = self._labels2[start:self._num_examples]
                batch_images = images_rest_part
                batch_labels1 = labels_rest_part1
                batch_labels2 = labels_rest_part2
                for ind_reuse in range(reuse_times):
                    if shuffle:
                        perm = np.arange(self._num_examples)
                        np.random.shuffle(perm)
                        self._images = self._images[perm]
                        self._labels1 = self._labels1[perm]
                        self._labels2 = self._labels2[perm]
                    batch_images = np.concatenate((batch_images, self._images), axis=0)
                    batch_labels1 = np.concatenate((batch_labels1, self._labels1), axis=0)
                    batch_labels2 = np.concatenate((batch_labels2, self._labels2), axis=0)
                if (start + batch_size) % self._num_examples == 0:
                    self._index_in_step = 0
                    return batch_images, batch_labels1, batch_labels2
                else:
                    if shuffle:
                        perm = np.arange(self._num_examples)
                        np.random.shuffle(perm)
                        self._images = self._images[perm]
                        self._labels1 = self._labels1[perm]
                        self._labels2 = self._labels2[perm]
                    self._index_in_step = (start + batch_size) % self._num_examples
                    end = self._index_in_step
                    batch_images = np.concatenate((batch_images, self._images[0:end]), axis=0)
                    batch_labels1 = np.concatenate((batch_labels1, self._labels1[0:end]), axis=0)
                    batch_labels2 = np.concatenate((batch_labels2, self._labels2[0:end]), axis=0)
                    return batch_images, batch_labels1, batch_labels2
        else:
            self._index_in_step += batch_size
            end = self._index_in_step
            return self._images[start:end], self._labels1[start:end], self._labels2[start:end]
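
For reference, a minimal smoke test for the new class (illustrative only; it assumes Data.py imports numpy as np, which the diff's use of np implies). The wrap-around branch fires on the fourth call, since 10 examples don't divide evenly into batches of 3:

import numpy as np
from Data import Data3

# synthetic data: 10 examples, 4 features, two label sets (shapes are illustrative)
images = np.random.rand(10, 4).astype(np.float32)
labels1 = np.arange(10)
labels2 = np.arange(10) * 2

data = Data3(images, labels1, labels2)
for _ in range(4):
    x, y1, y2 = data.next_batch(batch_size=3)
    # each batch keeps the images and both label sets aligned
    print(x.shape, y1.shape, y2.shape)   # (3, 4) (3,) (3,)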

data_prepare.py

+103
@@ -0,0 +1,103 @@
## data_prepare.py


import os
# from libtiff import TIFF
import numpy as np
import cv2


def random_perm3(data_x, data_y, data_z):
    """
    Apply one shared random permutation to x, y and z; x and z are list objects, y is an ndarray.
    :param data_x: x data, list object
    :param data_y: label data, ndarray object
    :param data_z: x data sizes, list object
    :return: shuffled x, y, z
    """
    data_size = data_y.shape[0]
    rand_perm = np.arange(data_size)
    np.random.shuffle(rand_perm)
    random_data_x = []
    random_data_z = []
    for indices in rand_perm:  # 'list' objects cannot be fancy-indexed, so loop explicitly
        random_data_x.append(data_x[indices])
        random_data_z.append(data_z[indices])
    # random_data_x = data_x[rand_perm]
    random_data_y = data_y[rand_perm]
    # random_data_z = data_z[rand_perm]
    return random_data_x, random_data_y, random_data_z
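
A small usage sketch (illustrative, not part of the commit; it relies on the module's numpy import): the three inputs are shuffled with one shared permutation, so element i of each output still refers to the same example.

xs = [np.zeros((2, 2)), np.ones((3, 3)), np.full((4, 4), 2.0)]   # list of arrays
ys = np.array([0, 1, 2])                                          # ndarray labels
zs = [(2, 2), (3, 3), (4, 4)]                                     # list of sizes
sx, sy, sz = random_perm3(xs, ys, zs)
# sy[i] still labels sx[i], and sz[i] is still sx[i]'s size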

# def generate_dataset_multisize(source_path, file_extension='tif'):
#     """
#     load multisize source images and generate datasets
#     :param source_path: the storage path of the source images, source_path/category/image files
#     :param file_extension: image file extension, defaults to 'tif'; supports tif, png, npy
#     :return: x, y, x_shape
#     """
#     if not os.path.exists(source_path):
#         raise FileNotFoundError('file not found! : %s' % source_path)
#     number_of_categories = 0
#     for category in os.scandir(source_path):
#         if category.is_dir():
#             number_of_categories += 1
#     number_of_image_per_category = np.zeros(number_of_categories, dtype=np.int32)
#     category_name = []
#     dataset_x = []
#     dataset_x_shape = []
#     index_category = 0
#     for category in os.scandir(source_path):
#         if category.is_dir():
#             index_category += 1
#             number_of_images = 0
#             category_name.append(category.name)
#             image = []
#             for img_file in os.scandir(category.path):
#                 extension = os.path.splitext(img_file.path)[1][1:]
#                 if file_extension == extension:
#                     number_of_images += 1
#                     if extension == 'tif':
#                         tif = TIFF.open(img_file.path, mode='r')
#                         image = tif.read_image()
#                         this_x = np.reshape(np.sqrt(np.power(image[:, :, 0], 2) + np.power(image[:, :, 1], 2)),
#                                             (1, -1), order='C')
#                     elif extension == 'png':
#                         image = cv2.imread(img_file.path, -1)
#                         this_x = image
#                     elif extension == 'npy':
#                         image = np.load(img_file.path)
#                         this_x = image
#                     else:
#                         raise ValueError('''unsupported image file's extension: %s''' % file_extension)
#                     dataset_x_shape.append([image.shape[0], image.shape[1]])
#                     this_x_norml2 = (this_x * 1.0) / np.sqrt(np.sum(np.square(this_x)))
#                     dataset_x.append(this_x_norml2)
#             number_of_image_per_category[index_category-1] = number_of_images
#     # print(number_of_image_per_category)
#     dataset_y = np.zeros(
#         [sum(number_of_image_per_category), number_of_categories],
#         dtype=np.int32)
#     for index_category in range(number_of_categories):
#         dataset_y[sum(number_of_image_per_category[0:index_category]):
#                   sum(number_of_image_per_category[0:index_category+1]),
#                   index_category] = 1
#     # print(len(dataset_x))
#     return dataset_x, dataset_y, dataset_x_shape

def norm_with_l2(original_mat):
    """
    Divide each row by its L2 norm to get unit-length rows.
    Each row is a data point.
    :param original_mat: matrix to normalize, one data point per row
    :return: normalized matrix
    """
    normed_mat = np.zeros(original_mat.shape, dtype=np.float32)
    if len(original_mat.shape) == 2:
        for ind_r in range(original_mat.shape[0]):
            # multiply by 1.0 to promote int dtypes to float before squaring,
            # preventing overflow in the power operation (the bug this commit fixes)
            a = np.square(original_mat[ind_r] * 1.0)
            b = np.sum(a)
            c = np.sqrt(b)
            normed_mat[ind_r] = (original_mat[ind_r] * 1.0) / c
            # normed_mat[ind_r] = (original_mat[ind_r] * 1.0) / np.sqrt(np.sum(np.square(original_mat[ind_r]) * 1.0))
    # note: non-2-D input falls through and returns the all-zero matrix
    return normed_mat
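
And a quick sanity check (illustrative): even with int input, each row comes back with unit L2 norm.

mat = np.array([[3, 4], [6, 8]], dtype=np.int32)
out = norm_with_l2(mat)
print(out)                          # [[0.6 0.8] [0.6 0.8]]
print(np.linalg.norm(out, axis=1))  # ~[1. 1.], up to float32 rounding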
