From d36473f5f8b10df895f0d0c364d5528ce12e6a79 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Fri, 31 Jan 2025 15:40:13 +0400 Subject: [PATCH] add deepseek-r1 reasoning notebook (#2709) --- notebooks/deepseek-r1/README.md | 32 ++ notebooks/deepseek-r1/deepseek-r1.ipynb | 441 ++++++++++++++++++++++++ notebooks/deepseek-r1/gradio_helper.py | 384 +++++++++++++++++++++ notebooks/deepseek-r1/llm_config.py | 226 ++++++++++++ 4 files changed, 1083 insertions(+) create mode 100644 notebooks/deepseek-r1/README.md create mode 100644 notebooks/deepseek-r1/deepseek-r1.ipynb create mode 100644 notebooks/deepseek-r1/gradio_helper.py create mode 100644 notebooks/deepseek-r1/llm_config.py diff --git a/notebooks/deepseek-r1/README.md b/notebooks/deepseek-r1/README.md new file mode 100644 index 00000000000..a69dd1148b5 --- /dev/null +++ b/notebooks/deepseek-r1/README.md @@ -0,0 +1,32 @@ +# LLM reasoning with DeepSeek-R1 distilled models + +[DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) is an open-source reasoning model developed by DeepSeek to address tasks requiring logical inference, mathematical problem-solving, and real-time decision-making. With DeepSeek-R1, you can follow its logic, making it easier to understand and, if necessary, challenge its output. This capability gives reasoning models an edge in fields where outcomes need to be explainable, like research or complex decision-making. + +Distillation in AI creates smaller, more efficient models from larger ones, preserving much of their reasoning power while reducing computational demands. DeepSeek applied this technique to create a suite of distilled models from R1, using the Qwen and Llama architectures. This makes it possible to try DeepSeek-R1 capabilities locally on an ordinary laptop. + +In this tutorial, we show how to run DeepSeek-R1 distilled models using OpenVINO. + +The tutorial supports several models; you can select one of the provided options to compare the quality of LLM solutions: + +* **DeepSeek-R1-Distill-Llama-8B** is a distilled model based on [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) that prioritizes high performance and advanced reasoning capabilities, particularly excelling in tasks requiring mathematical and factual precision. Check the [model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) for more info. +* **DeepSeek-R1-Distill-Qwen-1.5B** is the smallest DeepSeek-R1 distilled model, based on [Qwen2.5-Math-1.5B](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B). Despite its compact size, the model demonstrates strong capabilities in solving basic mathematical tasks, while its programming capabilities are limited. Check the [model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) for more info. +* **DeepSeek-R1-Distill-Qwen-7B** is a distilled model based on [Qwen2.5-Math-7B](https://huggingface.co/Qwen/Qwen2.5-Math-7B). The model demonstrates a good balance between mathematical and factual reasoning but can be less suited for complex coding tasks. Check the [model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) for more info. +* **DeepSeek-R1-Distill-Qwen-14B** is a distilled model based on [Qwen2.5-14B](https://huggingface.co/Qwen/Qwen2.5-14B) that has great competence in factual reasoning and solving complex mathematical tasks. Check the [model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B) for more info. 
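As a quick preview of what the notebook builds, the sketch below loads an already converted and compressed model with OpenVINO GenAI and generates an answer. It is a minimal, illustrative example: the model directory name assumes the model has already been exported and compressed (for example with `optimum-cli` or the notebook's helper functions), and any of the distilled models listed above can be substituted.

```python
import openvino_genai as ov_genai

# Folder produced by the export/compression step (name is illustrative).
model_dir = "DeepSeek-R1-Distill-Qwen-1.5B/INT4_compressed_weights"

# Create a text-generation pipeline; "GPU" or "AUTO" can be used instead of "CPU" if available.
pipe = ov_genai.LLMPipeline(model_dir, "CPU")

# Reasoning models emit their chain of thought before the final answer.
print(pipe.generate("Solve the equation: 2x + 5 = 15.", max_new_tokens=256))
```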
+ +## Notebook Contents + +The tutorial consists of the following steps: + +- Install prerequisites +- Download and convert the model from a public source using the [OpenVINO integration with Hugging Face Optimum](https://huggingface.co/blog/openvino) +- Compress model weights to INT4 or INT8 precision using [NNCF](https://github.com/openvinotoolkit/nncf) +- Create an inference pipeline +- Run an interactive demo + +![](https://github.com/user-attachments/assets/9062bdc4-0338-4555-a863-87b5a71236e9) + +## Installation Instructions +This is a self-contained example that relies solely on its own code.
+We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. +For details, please refer to [Installation Guide](../../README.md). + diff --git a/notebooks/deepseek-r1/deepseek-r1.ipynb b/notebooks/deepseek-r1/deepseek-r1.ipynb new file mode 100644 index 00000000000..9586595fe91 --- /dev/null +++ b/notebooks/deepseek-r1/deepseek-r1.ipynb @@ -0,0 +1,441 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLM reasoning with DeepSeek-R1 distilled models\n", + "\n", + "[DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) is an open-source reasoning model developed by DeepSeek to address tasks requiring logical inference, mathematical problem-solving, and real-time decision-making. With DeepSeek-R1, you can follow its logic, making it easier to understand and, if necessary, challenge its output. This capability gives reasoning models an edge in fields where outcomes need to be explainable, like research or complex decision-making.\n", + "\n", + "Distillation in AI creates smaller, more efficient models from larger ones, preserving much of their reasoning power while reducing computational demands. DeepSeek applied this technique to create a suite of distilled models from R1, using Qwen and Llama architectures. That allows us to try DeepSeek-R1 capability locally on usual laptops.\n", + "\n", + "In this tutorial, we consider how to run DeepSeek-R1 distilled models using OpenVINO.\n", + "\n", + "#### Table of contents:\n", + "\n", + "- [Prerequisites](#Prerequisites)\n", + "- [Select model for inference](#Select-model-for-inference)\n", + "- [Convert model using Optimum-CLI tool](#Convert-model-using-Optimum-CLI-tool)\n", + " - [Weights Compression using Optimum-CLI](#Weights-Compression-using-Optimum-CLI)\n", + "- [Instantiate pipeline with OpenVINO Generate API](#Instantiate-pipeline-with-OpenVINO-Generate-API)\n", + "- [Run Chatbot](#Run-Chatbot)\n", + "\n", + "\n", + "### Installation Instructions\n", + "\n", + "This is a self-contained example that relies solely on its own code.\n", + "\n", + "We recommend running the notebook in a virtual environment. 
You only need a Jupyter server to start.\n", + "For details, please refer to [Installation Guide](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/README.md#-installation-guide).\n", + "\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "\n", + "Install required dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import platform\n", + "\n", + "os.environ[\"GIT_CLONE_PROTECTION_ACTIVE\"] = \"false\"\n", + "\n", + "%pip install -q -U \"openvino>=2024.6.0\" openvino-tokenizers openvino-genai --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly\n", + "%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu\\\n", + "\"git+https://github.com/huggingface/optimum-intel.git\"\\\n", + "\"nncf==2.14.1\"\\\n", + "\"torch>=2.1\"\\\n", + "\"datasets\" \\\n", + "\"accelerate\" \\\n", + "\"gradio>=4.19\" \\\n", + "\"transformers>=4.43.1\" \\\n", + "\"huggingface-hub>=0.26.5\" \\\n", + "\"einops\" \"tiktoken\"\n", + "\n", + "if platform.system() == \"Darwin\":\n", + " %pip install -q \"numpy<2.0.0\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "from pathlib import Path\n", + "\n", + "if not Path(\"llm_config.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py\")\n", + " open(\"llm_config.py\", \"w\").write(r.text)\n", + "\n", + "if not Path(\"notebook_utils.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py\")\n", + " open(\"notebook_utils.py\", \"w\").write(r.text)\n", + "\n", + "# Read more about telemetry collection at https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry\n", + "from notebook_utils import collect_telemetry\n", + "\n", + "collect_telemetry(\"deepseek-r1.ipynb\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Select model for inference\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "\n", + "The tutorial supports different models, you can select one from the provided options to compare the quality of LLM solutions:\n", + "\n", + "* **DeepSeek-R1-Distill-Llama-8B** is a distilled model based on [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B), that prioritizes high performance and advanced reasoning capabilities, particularly excelling in tasks requiring mathematical and factual precision. Check [model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) for more info.\n", + "* **DeepSeek-R1-Distill-Qwen-1.5B** is the smallest DeepSeek-R1 distilled model based on [Qwen2.5-Math-1.5B](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B). Despite its compact size, the model demonstrates strong capabilities in solving basic mathematical tasks, at the same time its programming capabilities are limited. Check [model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) for more info.\n", + "* **DeepSeek-R1-Distill-Qwen-7B** is a distilled model based on [Qwen-2.5-Math-7B](https://huggingface.co/Qwen/Qwen2.5-Math-7B). 
The model demonstrates a good balance between mathematical and factual reasoning and can be less suited for complex coding tasks. Check [model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) for more info.\n", + "* **DeepSeek-R1-Distil-Qwen-14B** is a distilled model based on [Qwen2.5-14B](https://huggingface.co/Qwen/Qwen2.5-14B) that has great competence in factual reasoning and solving complex mathematical tasks. Check [model card](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-15B) for more info.\n", + "\n", + "[Weight compression](https://docs.openvino.ai/2024/openvino-workflow/model-optimization-guide/weight-compression.html) is a technique for enhancing the efficiency of models, especially those with large memory requirements. This method reduces the model’s memory footprint, a crucial factor for Large Language Models (LLMs). We provide several options for model weight compression:\n", + "\n", + "* **FP16** reducing model binary size on disk using `save_model` with enabled compression weights to FP16 precision. This approach is available in OpenVINO from scratch and is the default behavior.\n", + "* **INT8** is an 8-bit weight-only quantization provided by [NNCF](https://github.com/openvinotoolkit/nncf): This method compresses weights to an 8-bit integer data type, which balances model size reduction and accuracy, making it a versatile option for a broad range of applications.\n", + "* **INT4** is an 4-bit weight-only quantization provided by [NNCF](https://github.com/openvinotoolkit/nncf). involves quantizing weights to an unsigned 4-bit integer symmetrically around a fixed zero point of eight (i.e., the midpoint between zero and 15). in case of **symmetric quantization** or asymmetrically with a non-fixed zero point, in case of **asymmetric quantization** respectively. Compared to INT8 compression, INT4 compression improves performance even more, but introduces a minor drop in prediction quality. INT4 it ideal for situations where speed is prioritized over an acceptable trade-off against accuracy.\n", + "* **INT4 AWQ** is an 4-bit activation-aware weight quantization. [Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978) (AWQ) is an algorithm that tunes model weights for more accurate INT4 compression. It slightly improves generation quality of compressed LLMs, but requires significant additional time for tuning weights on a calibration dataset. We will use `wikitext-2-raw-v1/train` subset of the [Wikitext](https://huggingface.co/datasets/Salesforce/wikitext) dataset for calibration.\n", + "* **INT4 NPU-friendly** is an 4-bit channel-wise quantization. This approach is [recommended](https://docs.openvino.ai/2024/learn-openvino/llm_inference_guide/genai-guide-npu.html) for LLM inference using NPU." 
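The notebook drives conversion and compression through the `convert_and_compress_model` helper from `llm_config.py`. For orientation, the INT4 option corresponds roughly to the following Optimum Intel call (a minimal sketch: the model id is one of the options above, and the sym/group_size/ratio values shown here are examples that may differ from the parameters the helper actually uses):

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # illustrative choice

# 4-bit weight-only compression; sym, group_size, and ratio are example values.
q_config = OVWeightQuantizationConfig(bits=4, sym=False, group_size=128, ratio=0.8)

# Export the model to OpenVINO IR with compressed weights and save it locally.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, quantization_config=q_config)
model.save_pretrained("DeepSeek-R1-Distill-Qwen-1.5B/INT4_compressed_weights")
```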
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1e4077b40ad4469fb814da21ab155e7e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Dropdown(description='Device:', options=('CPU', 'AUTO'), value='CPU')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from notebook_utils import device_widget\n", + "\n", + "device = device_widget(default=\"CPU\")\n", + "\n", + "device" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "test_replace": { + "get_llm_selection_widget(device=device.value)": "get_llm_selection_widget(device=device.value, default_model_idx=0)" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "27eab38265df4de5b583f246c79da0b1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Box(children=(Box(children=(Label(value='Language:'), Dropdown(options=('English', 'Chinese'), value='English'…" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from llm_config import get_llm_selection_widget\n", + "\n", + "form, lang, model_id_widget, compression_variant, _ = get_llm_selection_widget(device=device.value)\n", + "\n", + "form" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selected model DeepSeek-R1-Distill-Llama-8B with INT4 compression\n" + ] + } + ], + "source": [ + "model_configuration = model_id_widget.value\n", + "model_id = model_id_widget.label\n", + "print(f\"Selected model {model_id} with {compression_variant.value} compression\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert model using Optimum-CLI tool\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "🤗 [Optimum Intel](https://huggingface.co/docs/optimum/intel/index) is the interface between the 🤗 [Transformers](https://huggingface.co/docs/transformers/index) and [Diffusers](https://huggingface.co/docs/diffusers/index) libraries and OpenVINO to accelerate end-to-end pipelines on Intel architectures. It provides ease-to-use cli interface for exporting models to [OpenVINO Intermediate Representation (IR)](https://docs.openvino.ai/2024/documentation/openvino-ir-format.html) format.\n", + "\n", + "
\n", + " Click here to read more about Optimum CLI usage\n", + "\n", + "The command bellow demonstrates basic command for model export with `optimum-cli`\n", + "\n", + "```\n", + "optimum-cli export openvino --model --task \n", + "```\n", + "\n", + "where `--model` argument is model id from HuggingFace Hub or local directory with model (saved using `.save_pretrained` method), `--task ` is one of [supported task](https://huggingface.co/docs/optimum/exporters/task_manager) that exported model should solve. For LLMs it is recommended to use `text-generation-with-past`. If model initialization requires to use remote code, `--trust-remote-code` flag additionally should be passed.\n", + "
\n", + "\n", + "### Weights Compression using Optimum-CLI\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "You can also apply fp16, 8-bit or 4-bit weight compression on the Linear, Convolutional and Embedding layers when exporting your model with the CLI. \n", + "
\n", + " Click here to read more about weights compression with Optimum CLI\n", + "\n", + "Setting `--weight-format` to respectively fp16, int8 or int4. This type of optimization allows to reduce the memory footprint and inference latency.\n", + "By default the quantization scheme for int8/int4 will be [asymmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization), to make it [symmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) you can add `--sym`.\n", + "\n", + "For INT4 quantization you can also specify the following arguments :\n", + "- The `--group-size` parameter will define the group size to use for quantization, -1 it will results in per-column quantization.\n", + "- The `--ratio` parameter controls the ratio between 4-bit and 8-bit quantization. If set to 0.9, it means that 90% of the layers will be quantized to int4 while 10% will be quantized to int8.\n", + "\n", + "Smaller group_size and ratio values usually improve accuracy at the sacrifice of the model size and inference latency.\n", + "You can enable AWQ to be additionally applied during model export with INT4 precision using `--awq` flag and providing dataset name with `--dataset`parameter (e.g. `--dataset wikitext2`)\n", + "\n", + ">**Note**: Applying AWQ requires significant memory and time.\n", + "\n", + ">**Note**: It is possible that there will be no matching patterns in the model to apply AWQ, in such case it will be skipped.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ INT4 DeepSeek-R1-Distill-Llama-8B model already converted and can be found in DeepSeek-R1-Distill-Llama-8B/INT4_compressed_weights\n" + ] + } + ], + "source": [ + "from llm_config import convert_and_compress_model\n", + "\n", + "model_dir = convert_and_compress_model(model_id, model_configuration, compression_variant.value)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Size of model with INT4 compressed weights is 5081.91 MB\n" + ] + } + ], + "source": [ + "from llm_config import compare_model_size\n", + "\n", + "compare_model_size(model_dir)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate pipeline with OpenVINO Generate API\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "[OpenVINO Generate API](https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md) can be used to create pipelines to run an inference with OpenVINO Runtime. \n", + "\n", + "Firstly we need to create a pipeline with `LLMPipeline`. `LLMPipeline` is the main object used for text generation using LLM in OpenVINO GenAI API. You can construct it straight away from the folder with the converted model. We will provide directory with model and device for `LLMPipeline`. Then we run `generate` method and get the output in text format.\n", + "Additionally, we can configure parameters for decoding. We can create the default config with `ov_genai.GenerationConfig()`, setup parameters, and apply the updated version with `set_generation_config(config)` or put config directly to `generate()`. It's also possible to specify the needed options just as inputs in the `generate()` method, as shown below, e.g. we can add `max_new_tokens` to stop generation if a specified number of tokens is generated and the end of generation is not reached. We will discuss some of the available generation parameters more deeply later. Generation process for long response may be time consuming, for accessing partial result as soon as it is generated without waiting when whole process finished, Streaming API can be used. Token streaming is the mode in which the generative system returns the tokens one by one as the model generates them. This enables showing progressive generations to the user rather than waiting for the whole generation. Streaming is an essential aspect of the end-user experience as it reduces latency, one of the most critical aspects of a smooth experience. In code below, we implement simple streamer for printing output result. For more advanced streamer example please check openvino.genai [sample](https://github.com/openvinotoolkit/openvino.genai/tree/master/samples/python/multinomial_causal_lm)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model from DeepSeek-R1-Distill-Llama-8B/INT4_compressed_weights\n", + "\n", + "Input text: What is OpenVINO?\n", + " It's an open-source model optimization tool that accelerates AI deployment across various platforms. It supports multiple frameworks and platforms, providing tools for quantization, pruning, and knowledge distillation. 
OpenVINO is designed to help developers reduce the computational requirements of AI models, making them more efficient and deployable on resource-constrained environments.\n", + "\n", + "What is OpenVINO? It's an open-source model optimization tool that accelerates AI deployment across various platforms. It supports multiple frameworks and platforms, providing tools for quantization, pruning, and knowledge distillation. OpenVINO is designed to help developers reduce the computational requirements of AI models, making them more" + ] + } + ], + "source": [ + "import openvino_genai as ov_genai\n", + "import sys\n", + "\n", + "print(f\"Loading model from {model_dir}\\n\")\n", + "\n", + "\n", + "pipe = ov_genai.LLMPipeline(str(model_dir), device.value)\n", + "if \"genai_chat_template\" in model_configuration:\n", + " pipe.get_tokenizer().set_chat_template(model_configuration[\"genai_chat_template\"])\n", + "\n", + "generation_config = ov_genai.GenerationConfig()\n", + "generation_config.max_new_tokens = 128\n", + "\n", + "\n", + "def streamer(subword):\n", + " print(subword, end=\"\", flush=True)\n", + " sys.stdout.flush()\n", + " # Return flag corresponds whether generation should be stopped.\n", + " # False means continue generation.\n", + " return False\n", + "\n", + "\n", + "input_prompt = \"What is OpenVINO?\"\n", + "print(f\"Input text: {input_prompt}\")\n", + "result = pipe.generate(input_prompt, generation_config, streamer)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run Chatbot\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "Now, when model created, we can setup Chatbot interface using [Gradio](https://www.gradio.app/).\n", + "\n", + "
\n", + " Click here to see how pipeline works\n", + "\n", + "The diagram below illustrates how the chatbot pipeline works\n", + "\n", + "![llm_diagram](https://github.com/user-attachments/assets/9c9b56e1-01a6-48d8-aa46-222a88e25066)\n", + "\n", + "As you can see, user input question passed via tokenizer to apply chat-specific formatting (chat template) and turn the provided string into the numeric format. [OpenVINO Tokenizers](https://github.com/openvinotoolkit/openvino_tokenizers) are used for these purposes inside `LLMPipeline`. You can find more detailed info about tokenization theory and OpenVINO Tokenizers in this [tutorial](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/openvino-tokenizers/openvino-tokenizers.ipynb). Then tokenized input passed to LLM for making prediction of next token probability. The way the next token will be selected over predicted probabilities is driven by the selected decoding methodology. You can find more information about the most popular decoding methods in this [blog](https://huggingface.co/blog/how-to-generate). The sampler's goal is to select the next token id is driven by generation configuration. Next, we apply stop generation condition to check the generation is finished or not (e.g. if we reached the maximum new generated tokens or the next token id equals to end of the generation). If the end of the generation is not reached, then new generated token id is used as the next iteration input, and the generation cycle repeats until the condition is not met. When stop generation criteria are met, then OpenVINO Detokenizer decodes generated token ids to text answer. \n", + "\n", + "The difference between chatbot and instruction-following pipelines is that the model should have \"memory\" to find correct answers on the chain of connected questions. OpenVINO GenAI uses `KVCache` representation for maintain a history of conversation. By default, `LLMPipeline` resets `KVCache` after each `generate` call. To keep conversational history, we should move LLMPipeline to chat mode using `start_chat()` method.\n", + "\n", + "More info about OpenVINO LLM inference can be found in [LLM Inference Guide](https://docs.openvino.ai/2024/learn-openvino/llm_inference_guide.html)\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "if not Path(\"gradio_helper.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/deepseek-r1/gradio_helper.py\")\n", + " open(\"gradio_helper_genai.py\", \"w\").write(r.text)\n", + "\n", + "from gradio_helper import make_demo\n", + "\n", + "demo = make_demo(pipe, model_configuration, model_id, lang.value, device.value == \"NPU\")\n", + "\n", + "try:\n", + " demo.launch(debug=True)\n", + "except Exception:\n", + " demo.launch(debug=True, share=True)\n", + "# If you are launching remotely, specify server_name and server_port\n", + "# EXAMPLE: `demo.launch(server_name='your server name', server_port='server port in int')`\n", + "# To learn more please refer to the Gradio docs: https://gradio.app/docs/" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "openvino_notebooks": { + "imageUrl": "https://user-images.githubusercontent.com/29454499/255799218-611e7189-8979-4ef5-8a80-5a75e0136b50.png", + "tags": { + "categories": [ + "Model Demos", + "AI Trends" + ], + "libraries": [], + "other": [ + "LLM" + ], + "tasks": [ + "Text Generation", + "Conversational" + ] + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/deepseek-r1/gradio_helper.py b/notebooks/deepseek-r1/gradio_helper.py new file mode 100644 index 00000000000..0b9fc669fc0 --- /dev/null +++ b/notebooks/deepseek-r1/gradio_helper.py @@ -0,0 +1,384 @@ +import openvino as ov +import openvino_genai as ov_genai +from uuid import uuid4 +from threading import Event, Thread +import queue +import sys + +max_new_tokens = 256 + +core = ov.Core() + +chinese_examples = [ + ["向 5 岁的孩子解释重力。"], + ["给我讲一个关于微积分的笑话。"], + ["编写代码时需要避免哪些常见错误?"], + ["撰写一篇关于“人工智能和 OpenVINO 的优势”的 100 字博客文章"], + ["求解方程:2x + 5 = 15"], + ["说出养猫的 5 个优点"], + ["简化 (-k + 4) + (-2 + 3k)"], + ["求半径为20的圆的面积"], + ["对未来5年AI趋势进行预测"], +] + +english_examples = [ + ["Explain gravity to a 5-year-old."], + ["Tell me a joke about calculus."], + ["What are some common mistakes to avoid when writing code?"], + ["Write a 100-word blog post on “Benefits of Artificial Intelligence and OpenVINO“"], + ["Solve the equation: 2x + 5 = 15."], + ["Name 5 advantages to be a cat"], + ["Simplify (-k + 4) + (-2 + 3k)"], + ["Find the area of ​​a circle with radius 20"], + ["Make a forecast about AI trends for next 5 years"], +] + + +DEFAULT_SYSTEM_PROMPT = """\ +You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. +If a question does not make any sense or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.\ +""" + +DEFAULT_SYSTEM_PROMPT_CHINESE = """\ +你是一个乐于助人、尊重他人以及诚实可靠的助手。在安全的情况下,始终尽可能有帮助地回答。 您的回答不应包含任何有害、不道德、种族主义、性别歧视、有毒、危险或非法的内容。请确保您的回答在社会上是公正的和积极的。 +如果一个问题没有任何意义或与事实不符,请解释原因,而不是回答错误的问题。如果您不知道问题的答案,请不要分享虚假信息。另外,答案请使用中文。\ +""" + + +def get_system_prompt(model_language, system_prompt=None): + if system_prompt is not None: + return system_prompt + return DEFAULT_SYSTEM_PROMPT_CHINESE if (model_language == "Chinese") else DEFAULT_SYSTEM_PROMPT + + +class IterableStreamer(ov_genai.StreamerBase): + """ + A custom streamer class for handling token streaming and detokenization with buffering. + + Attributes: + tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens. + tokens_cache (list): A buffer to accumulate tokens for detokenization. + text_queue (Queue): A synchronized queue for storing decoded text chunks. + print_len (int): The length of the printed text to manage incremental decoding. + """ + + def __init__(self, tokenizer): + """ + Initializes the IterableStreamer with the given tokenizer. + + Args: + tokenizer (Tokenizer): The tokenizer to use for encoding and decoding tokens. + """ + super().__init__() + self.tokenizer = tokenizer + self.tokens_cache = [] + self.text_queue = queue.Queue() + self.print_len = 0 + + def __iter__(self): + """ + Returns the iterator object itself. + """ + return self + + def __next__(self): + """ + Returns the next value from the text queue. + + Returns: + str: The next decoded text chunk. + + Raises: + StopIteration: If there are no more elements in the queue. + """ + value = self.text_queue.get() # get() will be blocked until a token is available. + if value is None: + raise StopIteration + return value + + def get_stop_flag(self): + """ + Checks whether the generation process should be stopped. + + Returns: + bool: Always returns False in this implementation. + """ + return False + + def put_word(self, word: str): + """ + Puts a word into the text queue. + + Args: + word (str): The word to put into the queue. + """ + self.text_queue.put(word) + + def put(self, token_id: int) -> bool: + """ + Processes a token and manages the decoding buffer. Adds decoded text to the queue. + + Args: + token_id (int): The token_id to process. + + Returns: + bool: True if generation should be stopped, False otherwise. + """ + self.tokens_cache.append(token_id) + text = self.tokenizer.decode(self.tokens_cache) + + word = "" + if len(text) > self.print_len and "\n" == text[-1]: + # Flush the cache after the new line symbol. + word = text[self.print_len :] + self.tokens_cache = [] + self.print_len = 0 + elif len(text) >= 3 and text[-3:] == chr(65533): + # Don't print incomplete text. + pass + elif len(text) > self.print_len: + # It is possible to have a shorter text after adding new token. + # Print to output only if text length is increaesed. + word = text[self.print_len :] + self.print_len = len(text) + self.put_word(word) + + if self.get_stop_flag(): + # When generation is stopped from streamer then end is not called, need to call it here manually. + self.end() + return True # True means stop generation + else: + return False # False means continue generation + + def end(self): + """ + Flushes residual tokens from the buffer and puts a None value in the queue to signal the end. 
+ """ + text = self.tokenizer.decode(self.tokens_cache) + if len(text) > self.print_len: + word = text[self.print_len :] + self.put_word(word) + self.tokens_cache = [] + self.print_len = 0 + self.put_word(None) + + def reset(self): + self.tokens_cache = [] + self.text_queue = queue.Queue() + self.print_len = 0 + + +class ChunkStreamer(IterableStreamer): + + def __init__(self, tokenizer, tokens_len=4): + super().__init__(tokenizer) + self.tokens_len = tokens_len + + def put(self, token_id: int) -> bool: + if (len(self.tokens_cache) + 1) % self.tokens_len != 0: + self.tokens_cache.append(token_id) + return False + sys.stdout.flush() + return super().put(token_id) + + +def make_demo(pipe, model_configuration, model_id, model_language, disable_advanced=False): + import gradio as gr + + start_message = get_system_prompt(model_language, model_configuration.get("system_prompt")) + if "genai_chat_template" in model_configuration: + pipe.get_tokenizer().set_chat_template(model_configuration["genai_chat_template"]) + + def get_uuid(): + """ + universal unique identifier for thread + """ + return str(uuid4()) + + def default_partial_text_processor(partial_text: str, new_text: str): + """ + helper for updating partially generated answer, used by default + + Params: + partial_text: text buffer for storing previosly generated text + new_text: text update for the current step + Returns: + updated text string + + """ + partial_text += new_text + return partial_text + + text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor) + + def bot(message, history, temperature, top_p, top_k, repetition_penalty, max_tokens): + """ + callback function for running chatbot on submit button click + + Params: + message: new message from user + history: conversation history + temperature: parameter for control the level of creativity in AI-generated text. + By adjusting the `temperature`, you can influence the AI model's probability distribution, making the text more focused or diverse. + top_p: parameter for control the range of tokens considered by the AI model based on their cumulative probability. + top_k: parameter for control the range of tokens considered by the AI model based on their cumulative probability, selecting number of tokens with highest probability. + repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text. + active_chat: chat state, if true then chat is running, if false then we should start it here. 
+ Returns: + message: reset message and make it "" + history: updated history with message and answer from chatbot + active_chat: if we are here, the chat is running or will be started, so return True + """ + streamer = ChunkStreamer(pipe.get_tokenizer()) + if not disable_advanced: + config = pipe.get_generation_config() + config.temperature = temperature + config.top_p = top_p + config.top_k = top_k + config.do_sample = temperature > 0.0 + config.max_new_tokens = max_tokens + config.repetition_penalty = repetition_penalty + if "stop_strings" in model_configuration: + config.stop_strings = set(model_configuration["stop_strings"]) + else: + config = ov_genai.GenerationConfig() + config.max_new_tokens = max_tokens + history = history or [] + if not history: + pipe.start_chat(system_message=start_message) + + history.append([message, ""]) + new_prompt = message + + stream_complete = Event() + + def generate_and_signal_complete(): + """ + genration function for single thread + """ + streamer.reset() + pipe.generate(new_prompt, config, streamer) + stream_complete.set() + streamer.end() + + t1 = Thread(target=generate_and_signal_complete) + t1.start() + + partial_text = "" + for new_text in streamer: + partial_text = text_processor(partial_text, new_text) + history[-1][1] = partial_text + yield "", history, streamer + + def stop_chat(streamer): + if streamer is not None: + streamer.end() + return None + + def stop_chat_and_clear_history(streamer): + if streamer is not None: + streamer.end() + pipe.finish_chat() + streamer.reset() + return None, None + + examples = chinese_examples if (model_language == "Chinese") else english_examples + + with gr.Blocks( + theme=gr.themes.Soft(), + css=".disclaimer {font-variant-caps: all-small-caps;}", + ) as demo: + streamer = gr.State(None) + conversation_id = gr.State(get_uuid) + gr.Markdown(f"""

<h1><center>OpenVINO {model_id} Chatbot</center></h1>

""") + chatbot = gr.Chatbot(height=500) + with gr.Row(): + with gr.Column(): + msg = gr.Textbox( + label="Chat Message Box", + placeholder="Chat Message Box", + show_label=False, + container=False, + ) + with gr.Column(): + with gr.Row(): + submit = gr.Button("Submit") + clear = gr.Button("Clear") + with gr.Row(visible=not disable_advanced): + with gr.Accordion("Advanced Options:", open=False): + with gr.Row(): + with gr.Column(): + with gr.Row(): + temperature = gr.Slider( + label="Temperature", + value=0.1, + minimum=0.0, + maximum=1.0, + step=0.1, + interactive=True, + info="Higher values produce more diverse outputs", + ) + with gr.Column(): + with gr.Row(): + top_p = gr.Slider( + label="Top-p (nucleus sampling)", + value=1.0, + minimum=0.0, + maximum=1, + step=0.01, + interactive=True, + info=( + "Sample from the smallest possible set of tokens whose cumulative probability " + "exceeds top_p. Set to 1 to disable and sample from all tokens." + ), + ) + with gr.Column(): + with gr.Row(): + top_k = gr.Slider( + label="Top-k", + value=50, + minimum=0.0, + maximum=200, + step=1, + interactive=True, + info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.", + ) + with gr.Column(): + with gr.Row(): + repetition_penalty = gr.Slider( + label="Repetition Penalty", + value=1.1, + minimum=1.0, + maximum=2.0, + step=0.1, + interactive=True, + info="Penalize repetition — 1.0 to disable.", + ) + with gr.Column(): + with gr.Row(): + max_tokens = gr.Slider( + label="Max new tokens", + value=256, + minimum=128, + maximum=1024, + step=32, + interactive=True, + info=("Maximum new tokens added to answer. Higher value can work for long response, but require more time to complete"), + ) + gr.Examples(examples, inputs=msg, label="Click on any example and press the 'Submit' button") + + msg.submit( + fn=bot, + inputs=[msg, chatbot, temperature, top_p, top_k, repetition_penalty, max_tokens], + outputs=[msg, chatbot, streamer], + queue=True, + ) + submit.click( + fn=bot, + inputs=[msg, chatbot, temperature, top_p, top_k, repetition_penalty, max_tokens], + outputs=[msg, chatbot, streamer], + queue=True, + ) + clear.click(fn=stop_chat_and_clear_history, inputs=streamer, outputs=[chatbot, streamer], queue=False) + + return demo diff --git a/notebooks/deepseek-r1/llm_config.py b/notebooks/deepseek-r1/llm_config.py new file mode 100644 index 00000000000..f7bfa330314 --- /dev/null +++ b/notebooks/deepseek-r1/llm_config.py @@ -0,0 +1,226 @@ +DEFAULT_SYSTEM_PROMPT = """\ +You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. +If a question does not make any sense or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.\ +""" + +DEFAULT_SYSTEM_PROMPT_CHINESE = """\ +你是一个乐于助人、尊重他人以及诚实可靠的助手。在安全的情况下,始终尽可能有帮助地回答。 您的回答不应包含任何有害、不道德、种族主义、性别歧视、有毒、危险或非法的内容。请确保您的回答在社会上是公正的和积极的。 +如果一个问题没有任何意义或与事实不符,请解释原因,而不是回答错误的问题。如果您不知道问题的答案,请不要分享虚假信息。另外,答案请使用中文。\ +""" + + +def deepseek_partial_text_processor(partial_text, new_text): + partial_text += new_text + return partial_text.split("")[-1] + + +SUPPORTED_LLM_MODELS = { + "English": { + "DeepSeek-R1-Distill-Qwen-1.5B": { + "model_id": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "genai_chat_template": "{% for message in messages %}{% if loop.first %}{{ '<|begin▁of▁sentence|>' }}{% endif %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' }}{% endif %}{% if loop.last and add_generation_prompt and message['role'] != 'assitant' %}{{ '<|Assistant|>' }}{% endif %}{% endfor %}", + "system_prompt": DEFAULT_SYSTEM_PROMPT, + "stop_strings": ["<|end▁of▁sentence|>", "<|User|>", "", "<|User|>", "<|end_of_sentence|>", "' }}{% endif %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' }}{% endif %}{% if loop.last and add_generation_prompt and message['role'] != 'assitant' %}{{ '<|Assistant|>' }}{% endif %}{% endfor %}", + "system_prompt": DEFAULT_SYSTEM_PROMPT, + "stop_strings": ["<|end▁of▁sentence|>", "<|User|>", "", "<|User|>", "<|end_of_sentence|>", "' }}{% endif %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' }}{% endif %}{% if loop.last and add_generation_prompt and message['role'] != 'assitant' %}{{ '<|Assistant|>' }}{% endif %}{% endfor %}", + "system_prompt": DEFAULT_SYSTEM_PROMPT, + "stop_strings": ["<|end▁of▁sentence|>", "<|User|>", "", "<|User|>", "<|end_of_sentence|>", "' }}{% endif %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' }}{% endif %}{% if loop.last and add_generation_prompt and message['role'] != 'assitant' %}{{ '<|Assistant|>' }}{% endif %}{% endfor %}", + "system_prompt": DEFAULT_SYSTEM_PROMPT, + "stop_strings": ["<|end▁of▁sentence|>", "<|User|>", "", "<|User|>", "<|end_of_sentence|>", "' }}{% endif %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' }}{% endif %}{% if loop.last and add_generation_prompt and message['role'] != 'assitant' %}{{ '<|Assistant|>' }}{% endif %}{% endfor %}", + "system_prompt": DEFAULT_SYSTEM_PROMPT_CHINESE, + "stop_strings": ["<|end▁of▁sentence|>", "<|User|>", "", "<|User|>", "<|end_of_sentence|>", "' }}{% endif %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] }}{% elif 
message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' }}{% endif %}{% if loop.last and add_generation_prompt and message['role'] != 'assitant' %}{{ '<|Assistant|>' }}{% endif %}{% endfor %}", + "system_prompt": DEFAULT_SYSTEM_PROMPT_CHINESE, + "stop_strings": ["<|end▁of▁sentence|>", "<|User|>", "", "<|User|>", "<|end_of_sentence|>", "' }}{% endif %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' }}{% endif %}{% if loop.last and add_generation_prompt and message['role'] != 'assitant' %}{{ '<|Assistant|>' }}{% endif %}{% endfor %}", + "system_prompt": DEFAULT_SYSTEM_PROMPT_CHINESE, + "stop_strings": ["<|end▁of▁sentence|>", "<|User|>", "", "<|User|>", "<|end_of_sentence|>", "' }}{% endif %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' }}{% endif %}{% if loop.last and add_generation_prompt and message['role'] != 'assitant' %}{{ '<|Assistant|>' }}{% endif %}{% endfor %}", + "system_prompt": DEFAULT_SYSTEM_PROMPT_CHINESE, + "stop_strings": ["<|end▁of▁sentence|>", "<|User|>", "", "<|User|>", "<|end_of_sentence|>", "