Skip to content

Commit

Permalink
provide better speedup prompt to speculative decoding notebook (#2565)
Browse files Browse the repository at this point in the history
  • Loading branch information
shira-g authored Nov 29, 2024
1 parent 605d3f2 commit af5aac8
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions notebooks/speculative-sampling/speculative-sampling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "553148f5",
"metadata": {},
"outputs": [
Expand All @@ -207,9 +207,23 @@
"pipe = ov_genai.LLMPipeline(target_model_path, device.value)\n",
"\n",
"config = ov_genai.GenerationConfig()\n",
"config.max_new_tokens = 100\n",
"\n",
"\n",
"config.max_new_tokens = 330\n",
"prompt = '''<s>\n",
"\n",
"def prime_fib(n: int):\n",
" \"\"\"\n",
" prime_fib returns n-th number that is a Fibonacci number and it's also prime.\n",
" >>> prime_fib(1)\n",
" 2\n",
" >>> prime_fib(2)\n",
" 3\n",
" >>> prime_fib(3)\n",
" 5\n",
" >>> prime_fib(4)\n",
" 13\n",
" >>> prime_fib(5)\n",
" 89\n",
" \"\"\"'''\n",
"def streamer(subword):\n",
" print(subword, end=\"\", flush=True)\n",
" # The return flag indicates whether generation should be stopped.\n",
Expand All @@ -218,13 +232,13 @@
"\n",
"\n",
"start_time = time.perf_counter()\n",
"pipe.generate([\"Sun is yellow because\"], config, streamer=streamer)\n",
"pipe.generate(prompt, config, streamer=streamer)\n",
"end_time = time.perf_counter()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "c40d9901-ceb2-4c4c-a686-303590292ab3",
"metadata": {},
"outputs": [
Expand All @@ -241,7 +255,7 @@
"\n",
"print(f\"Generation time: {end_time - start_time:.2f}s\")\n",
"del pipe\n",
"gc.collect();"
"gc.collect()"
]
},
{
Expand All @@ -263,7 +277,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "9fde1b3c",
"metadata": {},
"outputs": [
Expand All @@ -278,17 +292,19 @@
"source": [
"scheduler_config = ov_genai.SchedulerConfig()\n",
"# cache params\n",
"scheduler_config.cache_size = 2\n",
"scheduler_config.cache_size = 0\n",
"scheduler_config.num_kv_blocks = 2048 // 8\n",
"scheduler_config.max_num_batched_tokens = 2048\n",
"\n",
"draft_model = ov_genai.draft_model(draft_model_path, device.value)\n",
"\n",
"pipe = ov_genai.LLMPipeline(target_model_path, device.value, draft_model=draft_model, scheduler_config=scheduler_config)\n",
"\n",
"config = ov_genai.GenerationConfig()\n",
"config.max_new_tokens = 100\n",
"config.num_assistant_tokens = 3\n",
"config.max_new_tokens = 330\n",
"config.num_assistant_tokens = 5\n",
"start_time = time.perf_counter()\n",
"result = pipe.generate([\"Sun is yellow because\"], config, streamer=streamer)\n",
"result = pipe.generate(prompt, config, streamer=streamer)\n",
"end_time = time.perf_counter()"
]
},
Expand Down

0 comments on commit af5aac8

Please sign in to comment.