Skip to content

Commit 97a87f9

Browse files
chrismattmanntfboyd
authored andcommitted
Fix for TF-models tensorflow#7216: CIFAR-10 tutorial for multi-GPU fails because full shape isn't passed to prefetch_queue contributed by mattmann. (tensorflow#7217)
1 parent 712f473 commit 97a87f9

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tutorials/image/cifar10/cifar10_multi_gpu_train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def train():
163163

164164
# Get images and labels for CIFAR-10.
165165
images, labels = cifar10.distorted_inputs()
166+
images = tf.reshape(images, [cifar10.FLAGS.batch_size, 24, 24, 3])
167+
labels = tf.reshape(labels, [cifar10.FLAGS.batch_size])
166168
batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
167169
[images, labels], capacity=2 * FLAGS.num_gpus)
168170
# Calculate the gradients for each model tower.

0 commit comments

Comments
 (0)