Commit 8e4a1e2

Merge pull request tensorflow#3093 from asimshankar/mnist
[mnist]: Use FixedLengthRecordDataset
2 parents a3669a9 + 4a36e31 commit 8e4a1e2

File tree: 2 files changed (+118, -17 lines)

official/mnist/dataset.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tf.data.Dataset interface to the MNIST dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import gzip
import numpy as np
from six.moves import urllib
import tensorflow as tf


def read32(bytestream):
  """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
  dt = np.dtype(np.uint32).newbyteorder('>')
  return np.frombuffer(bytestream.read(4), dtype=dt)[0]


def check_image_file_header(filename):
  """Validate that filename corresponds to images for the MNIST dataset."""
  with open(filename, 'rb') as f:
    magic = read32(f)
    num_images = read32(f)  # unused, but must be read to advance the stream
    rows = read32(f)
    cols = read32(f)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
    if rows != 28 or cols != 28:
      raise ValueError(
          'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
          (f.name, rows, cols))


def check_labels_file_header(filename):
  """Validate that filename corresponds to labels for the MNIST dataset."""
  with open(filename, 'rb') as f:
    magic = read32(f)
    num_items = read32(f)  # unused, but must be read to advance the stream
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))


def download(directory, filename):
  """Download (and unzip) a file from the MNIST dataset if it doesn't already exist."""
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  filepath = os.path.join(directory, filename)
  if tf.gfile.Exists(filepath):
    return filepath
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  zipped_filepath = filepath + '.gz'
  print('Downloading %s to %s' % (url, zipped_filepath))
  urllib.request.urlretrieve(url, zipped_filepath)
  with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)
  os.remove(zipped_filepath)
  return filepath


def dataset(directory, images_file, labels_file):
  """Download and parse MNIST into a tf.data.Dataset of (image, label) pairs."""
  images_file = download(directory, images_file)
  labels_file = download(directory, labels_file)

  check_image_file_header(images_file)
  check_labels_file_header(labels_file)

  def decode_image(image):
    # Normalize from [0, 255] to [0.0, 1.0]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [784])
    return image / 255.0

  def one_hot_label(label):
    label = tf.decode_raw(label, tf.uint8)  # tf.string -> tf.uint8
    label = tf.reshape(label, [])  # label is a scalar
    return tf.one_hot(label, 10)

  # Images file: 16-byte header (magic, count, rows, cols), then one
  # 28x28 = 784-byte record per image.
  images = tf.data.FixedLengthRecordDataset(
      images_file, 28 * 28, header_bytes=16).map(decode_image)
  # Labels file: 8-byte header (magic, count), then one byte per label.
  labels = tf.data.FixedLengthRecordDataset(
      labels_file, 1, header_bytes=8).map(one_hot_label)
  return tf.data.Dataset.zip((images, labels))


def train(directory):
  """tf.data.Dataset object for MNIST training data."""
  return dataset(directory, 'train-images-idx3-ubyte',
                 'train-labels-idx1-ubyte')


def test(directory):
  """tf.data.Dataset object for MNIST test data."""
  return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
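
For reference, a minimal usage sketch of the new module (not part of the commit): it assumes TF 1.x graph execution, and the data directory and batch size below are arbitrary placeholders.

# Hypothetical usage sketch: pull one batch from the new pipeline and check
# its shapes. '/tmp/mnist_data' and the batch size of 32 are placeholders.
import tensorflow as tf

import dataset  # the module added in this commit

ds = dataset.train('/tmp/mnist_data')  # downloads the MNIST files on first use
images, labels = ds.batch(32).make_one_shot_iterator().get_next()

with tf.Session() as sess:
  image_batch, label_batch = sess.run([images, labels])
  print(image_batch.shape)  # (32, 784), float32 pixels in [0.0, 1.0]
  print(label_batch.shape)  # (32, 10), one-hot labels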

official/mnist/mnist.py

Lines changed: 6 additions & 17 deletions
@@ -22,19 +22,7 @@
 import sys

 import tensorflow as tf
-from tensorflow.examples.tutorials.mnist import input_data
-
-
-def train_dataset(data_dir):
-  """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
-  data = input_data.read_data_sets(data_dir, one_hot=True).train
-  return tf.data.Dataset.from_tensor_slices((data.images, data.labels))
-
-
-def eval_dataset(data_dir):
-  """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
-  data = input_data.read_data_sets(data_dir, one_hot=True).test
-  return tf.data.Dataset.from_tensors((data.images, data.labels))
+import dataset


 class Model(object):
@@ -151,10 +139,10 @@ def train_input_fn():
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
-    dataset = train_dataset(FLAGS.data_dir)
-    dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
+    ds = dataset.train(FLAGS.data_dir)
+    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
         FLAGS.train_epochs)
-    (images, labels) = dataset.make_one_shot_iterator().get_next()
+    (images, labels) = ds.make_one_shot_iterator().get_next()
     return (images, labels)

   # Set up training hook that logs the training accuracy every 100 steps.
@@ -165,7 +153,8 @@ def train_input_fn():

   # Evaluate the model and print results
   def eval_input_fn():
-    return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
+    return dataset.test(FLAGS.data_dir).batch(
+        FLAGS.batch_size).make_one_shot_iterator().get_next()

   eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
   print()
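
For context (the rest of mnist.py is not shown in this diff), here is a self-contained sketch of how input functions like the ones above plug into tf.estimator under TF 1.x. The linear model_fn, data directory, batch size, and epoch count are placeholders, not the model that mnist.py actually defines.

# Hypothetical sketch of the surrounding tf.estimator wiring (not from the
# commit). The tiny linear model_fn stands in for the real model in mnist.py.
import tensorflow as tf

import dataset


def model_fn(features, labels, mode):
  # A deliberately small linear classifier over the 784 flattened pixels.
  logits = tf.layers.dense(features, 10)
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
  accuracy = tf.metrics.accuracy(
      labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
  return tf.estimator.EstimatorSpec(
      mode, loss=loss, eval_metric_ops={'accuracy': accuracy})


def train_input_fn():
  # Cache decoded examples, shuffle a full epoch of MNIST, batch, repeat.
  ds = dataset.train('/tmp/mnist_data')
  ds = ds.cache().shuffle(buffer_size=50000).batch(100).repeat(1)
  return ds.make_one_shot_iterator().get_next()


def eval_input_fn():
  return dataset.test('/tmp/mnist_data').batch(
      100).make_one_shot_iterator().get_next()


mnist_classifier = tf.estimator.Estimator(model_fn=model_fn)
mnist_classifier.train(input_fn=train_input_fn)
print(mnist_classifier.evaluate(input_fn=eval_input_fn))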
