Skip to content

Commit 858177a

Browse files
Improve PyTorch Performance with Multiple Workers and Shuffling (#207)
* removed row shuffling, kept only batch shuffling. * added persistent_workers attribute to PyTorch reader unit tests. * updated PyTorch dense reader example.
1 parent 03a8602 commit 858177a

File tree

3 files changed

+145
-34
lines changed

3 files changed

+145
-34
lines changed

examples/readers/pytorch_data_api_tiledb_dense.ipynb

+141-25
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,110 @@
5252
"name": "#%%\n"
5353
}
5454
},
55-
"outputs": [],
55+
"outputs": [
56+
{
57+
"name": "stderr",
58+
"output_type": "stream",
59+
"text": [
60+
"0.3%"
61+
]
62+
},
63+
{
64+
"name": "stdout",
65+
"output_type": "stream",
66+
"text": [
67+
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
68+
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz\n"
69+
]
70+
},
71+
{
72+
"name": "stderr",
73+
"output_type": "stream",
74+
"text": [
75+
"100.0%\n"
76+
]
77+
},
78+
{
79+
"name": "stdout",
80+
"output_type": "stream",
81+
"text": [
82+
"Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw\n"
83+
]
84+
},
85+
{
86+
"name": "stderr",
87+
"output_type": "stream",
88+
"text": [
89+
"100.0%"
90+
]
91+
},
92+
{
93+
"name": "stdout",
94+
"output_type": "stream",
95+
"text": [
96+
"\n",
97+
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
98+
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz\n",
99+
"Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
100+
"\n",
101+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n"
102+
]
103+
},
104+
{
105+
"name": "stderr",
106+
"output_type": "stream",
107+
"text": [
108+
"\n",
109+
"19.9%"
110+
]
111+
},
112+
{
113+
"name": "stdout",
114+
"output_type": "stream",
115+
"text": [
116+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
117+
]
118+
},
119+
{
120+
"name": "stderr",
121+
"output_type": "stream",
122+
"text": [
123+
"100.0%\n"
124+
]
125+
},
126+
{
127+
"name": "stdout",
128+
"output_type": "stream",
129+
"text": [
130+
"Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw\n",
131+
"\n",
132+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
133+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
134+
]
135+
},
136+
{
137+
"name": "stderr",
138+
"output_type": "stream",
139+
"text": [
140+
"100.0%"
141+
]
142+
},
143+
{
144+
"name": "stdout",
145+
"output_type": "stream",
146+
"text": [
147+
"Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
148+
"\n"
149+
]
150+
},
151+
{
152+
"name": "stderr",
153+
"output_type": "stream",
154+
"text": [
155+
"\n"
156+
]
157+
}
158+
],
56159
"source": [
57160
"data_home = os.path.join(os.path.pardir, \"data\")\n",
58161
"data = torchvision.datasets.MNIST(root=data_home, train=False, download=True)"
@@ -143,7 +246,16 @@
143246
"cell_type": "code",
144247
"execution_count": 5,
145248
"metadata": {},
146-
"outputs": [],
249+
"outputs": [
250+
{
251+
"name": "stderr",
252+
"output_type": "stream",
253+
"text": [
254+
"/Users/george/PycharmProjects/TileDB-ML/.venv/lib/python3.9/site-packages/tiledb/ctx.py:448: UserWarning: tiledb.default_ctx and scope_ctx will not function correctly due to bug in IPython contextvar support. You must supply a Ctx object to each function for custom configuration options. Please consider upgrading to ipykernel >= 6!Please see https://github.com/TileDB-Inc/TileDB-Py/issues/667 for more information.\n",
255+
" warnings.warn(\n"
256+
]
257+
}
258+
],
147259
"source": [
148260
"data_dir = os.path.join(data_home, 'readers', 'pytorch', 'dense')\n",
149261
"os.makedirs(data_dir, exist_ok=True)\n",
@@ -212,14 +324,6 @@
212324
")\n",
213325
"\n"
214326
]
215-
},
216-
{
217-
"name": "stderr",
218-
"output_type": "stream",
219-
"text": [
220-
"/Users/george/PycharmProjects/TileDB-ML/.venv/lib/python3.9/site-packages/tiledb/ctx.py:410: UserWarning: tiledb.default_ctx and scope_ctx will not function correctly due to bug in IPython contextvar support. You must supply a Ctx object to each function for custom configuration options. Please consider upgrading to ipykernel >= 6!Please see https://github.com/TileDB-Inc/TileDB-Py/issues/667 for more information.\n",
221-
" warnings.warn(\n"
222-
]
223327
}
224328
],
225329
"source": [
@@ -254,7 +358,7 @@
254358
"outputs": [
255359
{
256360
"data": {
257-
"text/plain": "<matplotlib.image.AxesImage at 0x12de28bb0>"
361+
"text/plain": "<matplotlib.image.AxesImage at 0x123788ca0>"
258362
},
259363
"execution_count": 7,
260364
"metadata": {},
@@ -305,20 +409,21 @@
305409
"name": "stdout",
306410
"output_type": "stream",
307411
"text": [
308-
"Train Epoch: 1 Batch: 0 Loss: 2.299262\n",
309-
"Train Epoch: 1 Batch: 100 Loss: 2.262452\n",
310-
"Train Epoch: 1 Batch: 200 Loss: 2.162849\n",
311-
"Train Epoch: 1 Batch: 300 Loss: 1.927302\n",
312-
"Train Epoch: 1 Batch: 400 Loss: 1.646087\n",
313-
"Train Epoch: 2 Batch: 0 Loss: 1.446454\n",
314-
"Train Epoch: 2 Batch: 100 Loss: 1.314963\n",
315-
"Train Epoch: 2 Batch: 200 Loss: 1.376722\n",
316-
"Train Epoch: 2 Batch: 300 Loss: 1.400400\n",
317-
"Train Epoch: 2 Batch: 400 Loss: 1.291488\n"
412+
"Train Epoch: 1 Batch: 0 Loss: 2.304748\n",
413+
"Train Epoch: 1 Batch: 100 Loss: 2.277155\n",
414+
"Train Epoch: 1 Batch: 200 Loss: 2.203359\n",
415+
"Train Epoch: 1 Batch: 300 Loss: 1.895098\n",
416+
"Train Epoch: 1 Batch: 400 Loss: 1.497304\n",
417+
"Train Epoch: 2 Batch: 0 Loss: 1.435658\n",
418+
"Train Epoch: 2 Batch: 100 Loss: 1.305221\n",
419+
"Train Epoch: 2 Batch: 200 Loss: 0.990590\n",
420+
"Train Epoch: 2 Batch: 300 Loss: 1.103210\n",
421+
"Train Epoch: 2 Batch: 400 Loss: 0.903957\n"
318422
]
319423
}
320424
],
321425
"source": [
426+
"import multiprocessing\n",
322427
"import torch.nn as nn\n",
323428
"import torch.optim as optim\n",
324429
"\n",
@@ -349,11 +454,15 @@
349454
" img = np.clip(img,0,1)\n",
350455
" return img\n",
351456
"\n",
352-
"\n",
353457
"ctx = tiledb.Ctx({'sm.memory_budget': 1024**2})\n",
354458
"with tiledb.open(training_images, ctx=ctx) as x, tiledb.open(training_labels, ctx=ctx) as y:\n",
459+
" # Because of this issue (https://github.com/pytorch/pytorch/issues/59451#issuecomment-854883855) we avoid using multiple workers on Jupyter.\n",
355460
" train_loader = PyTorchTileDBDataLoader(\n",
356-
" ArrayParams(x, fn=do_random_noise), ArrayParams(y), batch_size=128,\n",
461+
" ArrayParams(x, fn=do_random_noise),\n",
462+
" ArrayParams(y),\n",
463+
" batch_size=128,\n",
464+
" num_workers=0,\n",
465+
" shuffle_buffer_size=256,\n",
357466
" )\n",
358467
"\n",
359468
" net = Net(shape=(28, 28))\n",
@@ -374,11 +483,18 @@
374483
" print('Train Epoch: {} Batch: {} Loss: {:.6f}'.format(\n",
375484
" epoch, batch_idx, loss.item()))"
376485
]
486+
},
487+
{
488+
"cell_type": "code",
489+
"execution_count": 8,
490+
"metadata": {},
491+
"outputs": [],
492+
"source": []
377493
}
378494
],
379495
"metadata": {
380496
"kernelspec": {
381-
"display_name": "Python 3 (ipykernel)",
497+
"display_name": "Python 3",
382498
"language": "python",
383499
"name": "python3"
384500
},
@@ -392,7 +508,7 @@
392508
"name": "python",
393509
"nbconvert_exporter": "python",
394510
"pygments_lexer": "ipython3",
395-
"version": "3.7.13"
511+
"version": "3.9.9"
396512
}
397513
},
398514
"nbformat": 4,

tests/readers/test_pytorch.py

+4
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,16 @@ def test_dataloader(
3232
):
3333
def test(*all_array_params):
3434
try:
35+
persistent_workers = num_workers > 0
36+
3537
dataloader = PyTorchTileDBDataLoader(
3638
*all_array_params,
3739
shuffle_buffer_size=shuffle_buffer_size,
3840
batch_size=batch_size,
3941
num_workers=num_workers,
42+
persistent_workers=persistent_workers,
4043
)
44+
4145
except NotImplementedError:
4246
assert num_workers and (
4347
torchdata.__version__ < "0.4"

tiledb/ml/readers/pytorch.py

-9
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,8 @@ def PyTorchTileDBDataLoader(
8383

8484
# shuffle the unbatched rows if shuffle_buffer_size > 0
8585
if shuffle_buffer_size:
86-
# load the rows to be shuffled
87-
# don't batch them (batch_size=None) or collate them (collate_fn=_identity)
88-
row_loader = DataLoader(
89-
datapipe, num_workers=num_workers, batch_size=None, collate_fn=_identity
90-
)
91-
# create a new datapipe for these rows
92-
datapipe = DeferredIterableIterDataPipe(iter, row_loader)
9386
# shuffle the datapipe items
9487
datapipe = datapipe.shuffle(buffer_size=shuffle_buffer_size)
95-
# run the shuffling on this process, not on workers
96-
kwargs["num_workers"] = 0
9788

9889
# construct an appropriate collate function
9990
collator = Collator.from_schemas(*schemas)

0 commit comments

Comments
 (0)