Skip to content
This repository was archived by the owner on Feb 24, 2025. It is now read-only.

Commit 1d25833

Browse files
committed
Add support for MNIST
1 parent f0a4246 commit 1d25833

File tree

1 file changed

+44
-11
lines changed

1 file changed

+44
-11
lines changed

dataset_tool.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pickle
1414
import sys
1515
import tarfile
16+
import gzip
1617
import zipfile
1718
from pathlib import Path
1819
from typing import Callable, Optional, Tuple, Union
@@ -165,6 +166,36 @@ def iterate_images():
165166

166167
#----------------------------------------------------------------------------
167168

169+
def open_mnist(images_gz: str, *, max_images: Optional[int]):
    """Open the MNIST training set and return ``(num_images, iterator)``.

    The labels archive is located by replacing the
    ``-images-idx3-ubyte.gz`` suffix of *images_gz* with
    ``-labels-idx1-ubyte.gz``, so both files must sit side by side.

    Args:
        images_gz: Path to the gzip-compressed image file.
        max_images: Optional cap on the number of images to emit;
            ``None`` means no cap.

    Returns:
        A tuple ``(max_idx, iterator)`` where the iterator yields dicts
        of the form ``dict(img=<32x32 uint8 array>, label=<int>)``.

    Raises:
        AssertionError: If the labels path cannot be derived from
            *images_gz*, or the decoded data does not have the expected
            shape, dtype, or value range.
    """
    labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
    # If nothing was replaced we would read the images file as labels.
    assert labels_gz != images_gz

    # Skip the fixed-size header that precedes the raw bytes in each file
    # (16 bytes for images, 8 bytes for labels).
    with gzip.open(images_gz, 'rb') as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16)
    with gzip.open(labels_gz, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    # Zero-pad each 28x28 image by 2 pixels per side to get 32x32.
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        # Bound the loop up front so that a limit of zero (or less) yields
        # nothing; checking the limit only after yielding would always emit
        # at least one image and disagree with the reported count.
        for idx in range(max_idx):
            yield dict(img=images[idx], label=int(labels[idx]))

    return max_idx, iterate_images()
196+
197+
#----------------------------------------------------------------------------
198+
168199
def make_transform(
169200
transform: Optional[str],
170201
output_width: Optional[int],
@@ -225,10 +256,11 @@ def open_dataset(source, *, max_images: Optional[int]):
225256
else:
226257
return open_image_folder(source, max_images=max_images)
227258
elif os.path.isfile(source):
228-
if source.endswith('cifar-10-python.tar.gz'):
259+
if os.path.basename(source) == 'cifar-10-python.tar.gz':
229260
return open_cifar10(source, max_images=max_images)
230-
ext = file_ext(source)
231-
if ext == 'zip':
261+
elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
262+
return open_mnist(source, max_images=max_images)
263+
elif file_ext(source) == 'zip':
232264
return open_image_zip(source, max_images=max_images)
233265
else:
234266
assert False, 'unknown archive type'
@@ -293,17 +325,18 @@ def convert_dataset(
293325
The input dataset format is guessed from the --source argument:
294326
295327
\b
296-
--source *_lmdb/ - Load LSUN dataset
297-
--source cifar-10-python.tar.gz - Load CIFAR-10 dataset
298-
--source path/ - Recursively load all images from path/
299-
--source dataset.zip - Recursively load all images from dataset.zip
328+
--source *_lmdb/ Load LSUN dataset
329+
--source cifar-10-python.tar.gz Load CIFAR-10 dataset
330+
--source train-images-idx3-ubyte.gz Load MNIST dataset
331+
--source path/ Recursively load all images from path/
332+
--source dataset.zip Recursively load all images from dataset.zip
300333
301-
The output dataset format can be either an image folder or a zip archive. Specifying
302-
the output format and path:
334+
The output dataset format can be either an image folder or a zip archive.
335+
Specifying the output format and path:
303336
304337
\b
305-
--dest /path/to/dir - Save output files under /path/to/dir
306-
--dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive
338+
--dest /path/to/dir Save output files under /path/to/dir
339+
--dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
307340
308341
Images within the dataset archive will be stored as uncompressed PNG.
309342

0 commit comments

Comments
 (0)