@@ -13,6 +13,7 @@
 import pickle
 import sys
 import tarfile
+import gzip
 import zipfile
 from pathlib import Path
 from typing import Callable, Optional, Tuple, Union
@@ -165,6 +166,36 @@ def iterate_images():
 
 #----------------------------------------------------------------------------
 
+def open_mnist(images_gz: str, *, max_images: Optional[int]):
+    labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
+    assert labels_gz != images_gz
+    images = []
+    labels = []
+
+    with gzip.open(images_gz, 'rb') as f:
+        images = np.frombuffer(f.read(), np.uint8, offset=16)
+    with gzip.open(labels_gz, 'rb') as f:
+        labels = np.frombuffer(f.read(), np.uint8, offset=8)
+
+    images = images.reshape(-1, 28, 28)
+    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
+    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
+    assert labels.shape == (60000,) and labels.dtype == np.uint8
+    assert np.min(images) == 0 and np.max(images) == 255
+    assert np.min(labels) == 0 and np.max(labels) == 9
+
+    max_idx = maybe_min(len(images), max_images)
+
+    def iterate_images():
+        for idx, img in enumerate(images):
+            yield dict(img=img, label=int(labels[idx]))
+            if idx >= max_idx-1:
+                break
+
+    return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
 def make_transform(
     transform: Optional[str],
     output_width: Optional[int],
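Side note on the offsets used above: offset=16 skips the IDX image header (big-endian magic number, image count, rows, cols) and offset=8 skips the label header (magic number, item count). A small standalone sketch that reads those headers explicitly instead of skipping them, assuming the standard MNIST training pair sits in the working directory:

import gzip
import struct
import numpy as np

# Parse the IDX headers that open_mnist skips via the offset= arguments.
with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f:
    magic, count, rows, cols = struct.unpack('>IIII', f.read(16))  # 16-byte big-endian header
    assert magic == 2051 and (rows, cols) == (28, 28)
    images = np.frombuffer(f.read(), np.uint8).reshape(count, rows, cols)

with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f:
    magic, count = struct.unpack('>II', f.read(8))  # 8-byte big-endian header
    assert magic == 2049
    labels = np.frombuffer(f.read(), np.uint8)

print(images.shape, labels.shape)  # expected: (60000, 28, 28) (60000,)
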
@@ -225,10 +256,11 @@ def open_dataset(source, *, max_images: Optional[int]):
         else:
             return open_image_folder(source, max_images=max_images)
     elif os.path.isfile(source):
-        if source.endswith('cifar-10-python.tar.gz'):
+        if os.path.basename(source) == 'cifar-10-python.tar.gz':
             return open_cifar10(source, max_images=max_images)
-        ext = file_ext(source)
-        if ext == 'zip':
+        elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
+            return open_mnist(source, max_images=max_images)
+        elif file_ext(source) == 'zip':
             return open_image_zip(source, max_images=max_images)
         else:
             assert False, 'unknown archive type'
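For context, open_mnist follows the same contract as the other open_* loaders that open_dataset dispatches to: it returns the number of images together with a generator of dict(img=..., label=...) records, where each img is a 32x32 uint8 array after padding. A rough usage sketch (the path is hypothetical):

num_images, iterator = open_mnist('downloads/train-images-idx3-ubyte.gz', max_images=8)
for item in iterator:
    print(item['img'].shape, item['img'].dtype, item['label'])  # (32, 32) uint8, label 0-9
assert num_images == 8
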
@@ -293,17 +325,18 @@ def convert_dataset(
     The input dataset format is guessed from the --source argument:
 
     \b
-    --source *_lmdb/ - Load LSUN dataset
-    --source cifar-10-python.tar.gz - Load CIFAR-10 dataset
-    --source path/ - Recursively load all images from path/
-    --source dataset.zip - Recursively load all images from dataset.zip
+    --source *_lmdb/                    Load LSUN dataset
+    --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
+    --source train-images-idx3-ubyte.gz Load MNIST dataset
+    --source path/                      Recursively load all images from path/
+    --source dataset.zip                Recursively load all images from dataset.zip
 
-    The output dataset format can be either an image folder or a zip archive. Specifying
-    the output format and path:
+    The output dataset format can be either an image folder or a zip archive.
+    Specifying the output format and path:
 
     \b
-    --dest /path/to/dir - Save output files under /path/to/dir
-    --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive
+    --dest /path/to/dir                 Save output files under /path/to/dir
+    --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip
 
     Images within the dataset archive will be stored as uncompressed PNG.
 
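With the help text updated, converting MNIST end to end should follow the same pattern as the other sources. A sketch of the invocation, assuming the tool lives in dataset_tool.py (the script name and the destination path are illustrative; only the --source and --dest options come from the help text above):

import subprocess

# Hypothetical end-to-end conversion of the MNIST training set into a zip archive of PNGs.
subprocess.run([
    'python', 'dataset_tool.py',
    '--source=train-images-idx3-ubyte.gz',
    '--dest=mnist-32x32.zip',
], check=True)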