Skip to content

Commit

Permalink
test low batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshd1402 committed Jan 16, 2025
1 parent 9edbd8f commit f248983
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def accuracy(params, batch):
# For this manual SPMD example, we get the number of devices (e.g. CPU,
# GPUs or TPU cores) that we're using, and use it to reshape data minibatches.
num_devices = jax.local_device_count()
batch_size = num_devices * 5
# batch_size = num_devices * 5
batch_size = 5 # testing

train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
Expand Down

0 comments on commit f248983

Please sign in to comment.