Skip to content

Commit 36fe429

Browse files
author
codebasics
committed
prefetch
1 parent 02fd6b3 commit 36fe429

File tree

1 file changed

+280
-0
lines changed

1 file changed

+280
-0
lines changed

45_prefatch/prefetch_caching.ipynb

+280
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"<h3 align=\"center\" style='color:blue'>Optimize tensorflow pipeline performance with prefetch and caching</h3>"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 14,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"import tensorflow as tf\n",
17+
"import time"
18+
]
19+
},
20+
{
21+
"cell_type": "code",
22+
"execution_count": 15,
23+
"metadata": {
24+
"scrolled": true
25+
},
26+
"outputs": [
27+
{
28+
"data": {
29+
"text/plain": [
30+
"'2.5.0'"
31+
]
32+
},
33+
"execution_count": 15,
34+
"metadata": {},
35+
"output_type": "execute_result"
36+
}
37+
],
38+
"source": [
39+
"tf.__version__"
40+
]
41+
},
42+
{
43+
"cell_type": "markdown",
44+
"metadata": {},
45+
"source": [
46+
"<h3 style='color:purple'>Prefetch</h3>"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 16,
52+
"metadata": {},
53+
"outputs": [],
54+
"source": [
55+
"class FileDataset(tf.data.Dataset):\n",
56+
" def read_file_in_batches(num_samples):\n",
57+
" # Opening the file\n",
58+
" time.sleep(0.03)\n",
59+
"\n",
60+
" for sample_idx in range(num_samples):\n",
61+
" # Reading data (line, record) from the file\n",
62+
" time.sleep(0.015)\n",
63+
"\n",
64+
" yield (sample_idx,)\n",
65+
"\n",
66+
" def __new__(cls, num_samples=3):\n",
67+
" return tf.data.Dataset.from_generator(\n",
68+
" cls.read_file_in_batches,\n",
69+
" output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),\n",
70+
" args=(num_samples,)\n",
71+
" )"
72+
]
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": 17,
77+
"metadata": {},
78+
"outputs": [],
79+
"source": [
80+
"def benchmark(dataset, num_epochs=2):\n",
81+
" for epoch_num in range(num_epochs):\n",
82+
" for sample in dataset:\n",
83+
" # Performing a training step\n",
84+
" time.sleep(0.01)"
85+
]
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": 18,
90+
"metadata": {
91+
"scrolled": true
92+
},
93+
"outputs": [
94+
{
95+
"name": "stdout",
96+
"output_type": "stream",
97+
"text": [
98+
"304 ms ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
99+
]
100+
}
101+
],
102+
"source": [
103+
"%%timeit\n",
104+
"benchmark(FileDataset())"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 23,
110+
"metadata": {},
111+
"outputs": [
112+
{
113+
"name": "stdout",
114+
"output_type": "stream",
115+
"text": [
116+
"238 ms ± 6.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
117+
]
118+
}
119+
],
120+
"source": [
121+
"%%timeit\n",
122+
"benchmark(FileDataset().prefetch(1))"
123+
]
124+
},
125+
{
126+
"cell_type": "code",
127+
"execution_count": 19,
128+
"metadata": {},
129+
"outputs": [
130+
{
131+
"name": "stdout",
132+
"output_type": "stream",
133+
"text": [
134+
"240 ms ± 7.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
135+
]
136+
}
137+
],
138+
"source": [
139+
"%%timeit\n",
140+
"benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))"
141+
]
142+
},
143+
{
144+
"cell_type": "markdown",
145+
"metadata": {},
146+
"source": [
147+
"**As you can notice above, using prefetch improves the performance from 304 ms to 238 and 240 ms**"
148+
]
149+
},
150+
{
151+
"cell_type": "markdown",
152+
"metadata": {},
153+
"source": [
154+
"<h3 style='color:purple'>Cache</h3>"
155+
]
156+
},
157+
{
158+
"cell_type": "code",
159+
"execution_count": 30,
160+
"metadata": {},
161+
"outputs": [
162+
{
163+
"data": {
164+
"text/plain": [
165+
"[0, 1, 4, 9, 16]"
166+
]
167+
},
168+
"execution_count": 30,
169+
"metadata": {},
170+
"output_type": "execute_result"
171+
}
172+
],
173+
"source": [
174+
"dataset = tf.data.Dataset.range(5)\n",
175+
"dataset = dataset.map(lambda x: x**2)\n",
176+
"dataset = dataset.cache(\"mycache.txt\")\n",
177+
"# The first time reading through the data will generate the data using\n",
178+
"# `range` and `map`.\n",
179+
"list(dataset.as_numpy_iterator())"
180+
]
181+
},
182+
{
183+
"cell_type": "code",
184+
"execution_count": 29,
185+
"metadata": {},
186+
"outputs": [
187+
{
188+
"data": {
189+
"text/plain": [
190+
"[0, 1, 4, 9, 16]"
191+
]
192+
},
193+
"execution_count": 29,
194+
"metadata": {},
195+
"output_type": "execute_result"
196+
}
197+
],
198+
"source": [
199+
"# Subsequent iterations read from the cache.\n",
200+
"list(dataset.as_numpy_iterator())"
201+
]
202+
},
203+
{
204+
"cell_type": "code",
205+
"execution_count": 24,
206+
"metadata": {},
207+
"outputs": [],
208+
"source": [
209+
"def mapped_function(s):\n",
210+
" # Do some hard pre-processing\n",
211+
" tf.py_function(lambda: time.sleep(0.03), [], ())\n",
212+
" return s"
213+
]
214+
},
215+
{
216+
"cell_type": "code",
217+
"execution_count": 26,
218+
"metadata": {},
219+
"outputs": [
220+
{
221+
"name": "stdout",
222+
"output_type": "stream",
223+
"text": [
224+
"1.25 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
225+
]
226+
}
227+
],
228+
"source": [
229+
"%%timeit -r1 -n1\n",
230+
"benchmark(FileDataset().map(mapped_function), 5)"
231+
]
232+
},
233+
{
234+
"cell_type": "code",
235+
"execution_count": 27,
236+
"metadata": {},
237+
"outputs": [
238+
{
239+
"name": "stdout",
240+
"output_type": "stream",
241+
"text": [
242+
"528 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
243+
]
244+
}
245+
],
246+
"source": [
247+
"%%timeit -r1 -n1\n",
248+
"benchmark(FileDataset().map(mapped_function).cache(), 5)"
249+
]
250+
},
251+
{
252+
"cell_type": "markdown",
253+
"metadata": {},
254+
"source": [
255+
"**Further reading** https://www.tensorflow.org/guide/data_performance#caching"
256+
]
257+
}
258+
],
259+
"metadata": {
260+
"kernelspec": {
261+
"display_name": "Python 3",
262+
"language": "python",
263+
"name": "python3"
264+
},
265+
"language_info": {
266+
"codemirror_mode": {
267+
"name": "ipython",
268+
"version": 3
269+
},
270+
"file_extension": ".py",
271+
"mimetype": "text/x-python",
272+
"name": "python",
273+
"nbconvert_exporter": "python",
274+
"pygments_lexer": "ipython3",
275+
"version": "3.8.5"
276+
}
277+
},
278+
"nbformat": 4,
279+
"nbformat_minor": 4
280+
}

0 commit comments

Comments
 (0)