From af5aac8b1e5e7f0a790d06bf6d10a01e2ab2b1f3 Mon Sep 17 00:00:00 2001
From: Shira Guskin <30695324+shira-g@users.noreply.github.com>
Date: Fri, 29 Nov 2024 00:21:27 -0800
Subject: [PATCH] provide better speedup prompt to speculative decoding
 notebook (#2565)

---
 .../speculative-sampling.ipynb               | 40 +++++++++++++------
 1 file changed, 28 insertions(+), 12 deletions(-)

diff --git a/notebooks/speculative-sampling/speculative-sampling.ipynb b/notebooks/speculative-sampling/speculative-sampling.ipynb
index 8c3a97b5784..a764b50017b 100644
--- a/notebooks/speculative-sampling/speculative-sampling.ipynb
+++ b/notebooks/speculative-sampling/speculative-sampling.ipynb
@@ -188,7 +188,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "id": "553148f5",
    "metadata": {},
    "outputs": [
@@ -207,9 +207,23 @@
     "pipe = ov_genai.LLMPipeline(target_model_path, device.value)\n",
     "\n",
     "config = ov_genai.GenerationConfig()\n",
-    "config.max_new_tokens = 100\n",
-    "\n",
-    "\n",
+    "config.max_new_tokens = 330\n",
+    "prompt = '''\n",
+    "\n",
+    "def prime_fib(n: int):\n",
+    "    \"\"\"\n",
+    "    prime_fib returns n-th number that is a Fibonacci number and it's also prime.\n",
+    "    >>> prime_fib(1)\n",
+    "    2\n",
+    "    >>> prime_fib(2)\n",
+    "    3\n",
+    "    >>> prime_fib(3)\n",
+    "    5\n",
+    "    >>> prime_fib(4)\n",
+    "    13\n",
+    "    >>> prime_fib(5)\n",
+    "    89\n",
+    "    \"\"\"'''\n",
     "def streamer(subword):\n",
     "    print(subword, end=\"\", flush=True)\n",
     "    # Return flag corresponds whether generation should be stopped.\n",
@@ -218,13 +232,13 @@
     "\n",
     "\n",
     "start_time = time.perf_counter()\n",
-    "pipe.generate([\"Sun is yellow because\"], config, streamer=streamer)\n",
+    "pipe.generate(prompt, config, streamer=streamer)\n",
     "end_time = time.perf_counter()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "id": "c40d9901-ceb2-4c4c-a686-303590292ab3",
    "metadata": {},
    "outputs": [
@@ -241,7 +255,7 @@
     "\n",
     "print(f\"Generation time: {end_time - start_time:.2f}s\")\n",
     "del pipe\n",
-    "gc.collect();"
+    "gc.collect()"
    ]
   },
   {
@@ -263,7 +277,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "id": "9fde1b3c",
    "metadata": {},
    "outputs": [
@@ -278,17 +292,19 @@
    "source": [
     "scheduler_config = ov_genai.SchedulerConfig()\n",
     "# cache params\n",
-    "scheduler_config.cache_size = 2\n",
+    "scheduler_config.cache_size = 0\n",
+    "scheduler_config.num_kv_blocks = 2048 // 8\n",
+    "scheduler_config.max_num_batched_tokens = 2048\n",
     "\n",
     "draft_model = ov_genai.draft_model(draft_model_path, device.value)\n",
     "\n",
     "pipe = ov_genai.LLMPipeline(target_model_path, device.value, draft_model=draft_model, scheduler_config=scheduler_config)\n",
     "\n",
     "config = ov_genai.GenerationConfig()\n",
-    "config.max_new_tokens = 100\n",
-    "config.num_assistant_tokens = 3\n",
+    "config.max_new_tokens = 330\n",
+    "config.num_assistant_tokens = 5\n",
     "start_time = time.perf_counter()\n",
-    "result = pipe.generate([\"Sun is yellow because\"], config, streamer=streamer)\n",
+    "result = pipe.generate(prompt, config, streamer=streamer)\n",
     "end_time = time.perf_counter()"
    ]
   },