@@ -98,3 +98,113 @@ def next_batch(self, batch_size, shuffle=True):
98
98
self ._index_in_step += batch_size
99
99
end = self ._index_in_step
100
100
return self ._images [start :end ], self ._labels [start :end ]
101
+
102
+
103
+ class Data3 (object ):
104
+ def __init__ (self , images , labels1 , labels2 ):
105
+ self ._num_examples = images .shape [0 ]
106
+ self ._images = images
107
+ self ._labels1 = labels1
108
+ self ._labels2 = labels2
109
+ self ._steps_completed = 0
110
+ self ._index_in_step = 0
111
+
112
+ @property
113
+ def images (self ):
114
+ return self ._images
115
+
116
+ @property
117
+ def labels1 (self ):
118
+ return self ._labels1
119
+
120
+ @property
121
+ def labels2 (self ):
122
+ return self ._labels2
123
+
124
+ @property
125
+ def num_examples (self ):
126
+ return self ._num_examples
127
+
128
+ @property
129
+ def steps_completed (self ):
130
+ return self ._steps_completed
131
+
132
+ def next_batch (self , batch_size , shuffle = True ):
133
+ """Return the next `batch_size` examples from this data set."""
134
+ "go through all the data"
135
+ start = self ._index_in_step
136
+ # 对第一个step进行打乱
137
+ if self ._steps_completed == 0 and start == 0 and shuffle :
138
+ # 返回一个array对象且间隔为1
139
+ perm0 = np .arange (self ._num_examples )
140
+ # 打乱列表
141
+ np .random .shuffle (perm0 )
142
+ self ._images = self .images [perm0 ]
143
+ self ._labels1 = self .labels1 [perm0 ]
144
+ self ._labels2 = self .labels2 [perm0 ]
145
+ # 进入下一个step之前,有余下数据的处理
146
+ if start + batch_size > self ._num_examples :
147
+ if start + batch_size < 2 * self ._num_examples :
148
+ # 完成一个step的标志位
149
+ self ._steps_completed += 1
150
+ # 得到该step余下的数据
151
+ rest_num_examples = self ._num_examples - start
152
+ images_rest_part = self ._images [start :self ._num_examples ]
153
+ labels_rest_part1 = self ._labels1 [start :self ._num_examples ]
154
+ labels_rest_part2 = self ._labels2 [start :self ._num_examples ]
155
+ # 对数据进行打乱
156
+ if shuffle :
157
+ perm = np .arange (self ._num_examples )
158
+ np .random .shuffle (perm )
159
+ self ._images = self ._images [perm ]
160
+ self ._labels1 = self ._labels1 [perm ]
161
+ self ._labels2 = self ._labels2 [perm ]
162
+ # 开始下一个step,并凑齐一个batch
163
+ start = 0
164
+ self ._index_in_step = batch_size - rest_num_examples
165
+ end = self ._index_in_step
166
+ images_new_part = self ._images [start :end ]
167
+ labels_new_part1 = self ._labels1 [start :end ]
168
+ labels_new_part2 = self ._labels2 [start :end ]
169
+ return np .concatenate ((images_rest_part , images_new_part ), axis = 0 ), \
170
+ np .concatenate ((labels_rest_part1 , labels_new_part1 ), axis = 0 ), \
171
+ np .concatenate ((labels_rest_part2 , labels_new_part2 ), axis = 0 )
172
+ else :
173
+ reuse_times = np .int (np .floor ((start + batch_size ) / self ._num_examples ) - 1 )
174
+ self ._steps_completed += reuse_times + 1
175
+ images_rest_part = self ._images [start :self ._num_examples ]
176
+ labels_rest_part1 = self ._labels1 [start :self ._num_examples ]
177
+ labels_rest_part2 = self ._labels2 [start :self ._num_examples ]
178
+ batch_images = images_rest_part
179
+ batch_labels1 = labels_rest_part1
180
+ batch_labels2 = labels_rest_part2
181
+ for ind_resuse in range (reuse_times ):
182
+ if shuffle :
183
+ perm = np .arange (self ._num_examples )
184
+ np .random .shuffle (perm )
185
+ self ._images = self ._images [perm ]
186
+ self ._labels1 = self ._labels1 [perm ]
187
+ self ._labels2 = self ._labels2 [perm ]
188
+ batch_images = np .concatenate ((batch_images , self ._images ), axis = 0 )
189
+ batch_labels1 = np .concatenate ((batch_labels1 , self ._labels1 ), axis = 0 )
190
+ batch_labels2 = np .concatenate ((batch_labels2 , self ._labels2 ), axis = 0 )
191
+ if (start + batch_size ) % self ._num_examples == 0 :
192
+ self ._index_in_step = 0
193
+ return batch_images , batch_labels1 , batch_labels2
194
+ else :
195
+ if shuffle :
196
+ perm = np .arange (self ._num_examples )
197
+ np .random .shuffle (perm )
198
+ self ._images = self ._images [perm ]
199
+ self ._labels1 = self ._labels1 [perm ]
200
+ self ._labels2 = self ._labels2 [perm ]
201
+ self ._index_in_step = (start + batch_size ) % self ._num_examples
202
+ end = self ._index_in_step
203
+ batch_images = np .concatenate ((batch_images , self ._images [0 :end ]), axis = 0 )
204
+ batch_labels1 = np .concatenate ((batch_labels1 , self ._labels1 [0 :end ]), axis = 0 )
205
+ batch_labels2 = np .concatenate ((batch_labels2 , self ._labels2 [0 :end ]), axis = 0 )
206
+ return batch_images , batch_labels1 , batch_labels2
207
+ else :
208
+ self ._index_in_step += batch_size
209
+ end = self ._index_in_step
210
+ return self ._images [start :end ], self ._labels1 [start :end ], self ._labels2 [start :end ]
0 commit comments