Skip to content

Commit af5aac8

Browse files
authored
provide better speedup prompt to speculative decoding notebook (#2565)
1 parent 605d3f2 commit af5aac8

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

notebooks/speculative-sampling/speculative-sampling.ipynb

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@
188188
},
189189
{
190190
"cell_type": "code",
191-
"execution_count": 4,
191+
"execution_count": null,
192192
"id": "553148f5",
193193
"metadata": {},
194194
"outputs": [
@@ -207,9 +207,23 @@
207207
"pipe = ov_genai.LLMPipeline(target_model_path, device.value)\n",
208208
"\n",
209209
"config = ov_genai.GenerationConfig()\n",
210-
"config.max_new_tokens = 100\n",
211-
"\n",
212-
"\n",
210+
"config.max_new_tokens = 330\n",
211+
"prompt = '''<s>\n",
212+
"\n",
213+
"def prime_fib(n: int):\n",
214+
" \"\"\"\n",
215+
" prime_fib returns n-th number that is a Fibonacci number and it's also prime.\n",
216+
" >>> prime_fib(1)\n",
217+
" 2\n",
218+
" >>> prime_fib(2)\n",
219+
" 3\n",
220+
" >>> prime_fib(3)\n",
221+
" 5\n",
222+
" >>> prime_fib(4)\n",
223+
" 13\n",
224+
" >>> prime_fib(5)\n",
225+
" 89\n",
226+
" \"\"\"'''\n",
213227
"def streamer(subword):\n",
214228
" print(subword, end=\"\", flush=True)\n",
215229
" # Return flag corresponds whether generation should be stopped.\n",
@@ -218,13 +232,13 @@
218232
"\n",
219233
"\n",
220234
"start_time = time.perf_counter()\n",
221-
"pipe.generate([\"Sun is yellow because\"], config, streamer=streamer)\n",
235+
"pipe.generate(prompt, config, streamer=streamer)\n",
222236
"end_time = time.perf_counter()"
223237
]
224238
},
225239
{
226240
"cell_type": "code",
227-
"execution_count": 5,
241+
"execution_count": null,
228242
"id": "c40d9901-ceb2-4c4c-a686-303590292ab3",
229243
"metadata": {},
230244
"outputs": [
@@ -241,7 +255,7 @@
241255
"\n",
242256
"print(f\"Generation time: {end_time - start_time:.2f}s\")\n",
243257
"del pipe\n",
244-
"gc.collect();"
258+
"gc.collect()"
245259
]
246260
},
247261
{
@@ -263,7 +277,7 @@
263277
},
264278
{
265279
"cell_type": "code",
266-
"execution_count": 6,
280+
"execution_count": null,
267281
"id": "9fde1b3c",
268282
"metadata": {},
269283
"outputs": [
@@ -278,17 +292,19 @@
278292
"source": [
279293
"scheduler_config = ov_genai.SchedulerConfig()\n",
280294
"# cache params\n",
281-
"scheduler_config.cache_size = 2\n",
295+
"scheduler_config.cache_size = 0\n",
296+
"scheduler_config.num_kv_blocks = 2048 // 8\n",
297+
"scheduler_config.max_num_batched_tokens = 2048\n",
282298
"\n",
283299
"draft_model = ov_genai.draft_model(draft_model_path, device.value)\n",
284300
"\n",
285301
"pipe = ov_genai.LLMPipeline(target_model_path, device.value, draft_model=draft_model, scheduler_config=scheduler_config)\n",
286302
"\n",
287303
"config = ov_genai.GenerationConfig()\n",
288-
"config.max_new_tokens = 100\n",
289-
"config.num_assistant_tokens = 3\n",
304+
"config.max_new_tokens = 330\n",
305+
"config.num_assistant_tokens = 5\n",
290306
"start_time = time.perf_counter()\n",
291-
"result = pipe.generate([\"Sun is yellow because\"], config, streamer=streamer)\n",
307+
"result = pipe.generate(prompt, config, streamer=streamer)\n",
292308
"end_time = time.perf_counter()"
293309
]
294310
},

0 commit comments

Comments
 (0)