diff --git a/.ci/ignore_treon_mac.txt b/.ci/ignore_treon_mac.txt index 588138742d5..57dd1f325a5 100644 --- a/.ci/ignore_treon_mac.txt +++ b/.ci/ignore_treon_mac.txt @@ -47,4 +47,5 @@ 281-kosmos2-multimodal-large-language-model 279-mobilevlm-language-assistant 283-photo-maker +284-openvoice 404-style-transfer-webcam \ No newline at end of file diff --git a/.ci/spellcheck/.pyspelling.wordlist.txt b/.ci/spellcheck/.pyspelling.wordlist.txt index cb133e67760..4843ea57155 100644 --- a/.ci/spellcheck/.pyspelling.wordlist.txt +++ b/.ci/spellcheck/.pyspelling.wordlist.txt @@ -780,4 +780,11 @@ ZavyChromaXL Zongyuan ZeroScope zeroscope -xformers \ No newline at end of file +xformers +OpenVoice +BaseSpeakerTTS +ToneColorConverter +nn +lang +OpenVoiceBaseClass +processings \ No newline at end of file diff --git a/README.md b/README.md index 6d8fc316043..a2f1e9f948b 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ Check out the latest notebooks that show how to optimize and deploy popular mode | [DepthAnything](notebooks/280-depth-anything)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/openvinotoolkit/openvino_notebooks/HEAD?filepath=notebooks%2F280-depth-anythingh%2F280-depth-anything.ipynb)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openvinotoolkit/openvino_notebooks/blob/main/notebooks/280-depth-anything/280-depth-anything.ipynb) | Monocular Depth estimation with DepthAnything and OpenVINO | | | [Kosmos-2: Grounding Multimodal Large Language Models](notebooks/281-kosmos2-multimodal-large-language-model)
| Kosmos-2: Grounding Multimodal Large Language Model and OpenVINO™ | | | [PhotoMaker](notebooks/283-photo-maker)
| Text-to-image generation using PhotoMaker and OpenVINO | | +| [OpenVoice](notebooks/284-openvoice)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/openvinotoolkit/openvino_notebooks/HEAD?filepath=notebooks%2F284-openvoice%2F284-openvoice.ipynb)[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openvinotoolkit/openvino_notebooks/blob/main/notebooks/284-openvoice/284-openvoice.ipynb) | Instant voice tone transfer and speech generation in various languages with OpenVoice and OpenVINO | | ## Table of Contents diff --git a/notebooks/284-openvoice/284-openvoice.ipynb b/notebooks/284-openvoice/284-openvoice.ipynb new file mode 100644 index 00000000000..656fab71115 --- /dev/null +++ b/notebooks/284-openvoice/284-openvoice.ipynb @@ -0,0 +1,1040 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Voice tone cloning with OpenVoice and OpenVINO" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "OpenVoice is a versatile instant voice cloning approach that transfers voice tone and generates speech in various languages from just a brief audio snippet of the source speaker. OpenVoice has three main features: (i) high-quality tone color replication with multiple languages and accents; (ii) fine-grained control over voice styles, including emotions and accents, as well as other parameters such as rhythm, pauses, and intonation; and (iii) zero-shot cross-lingual voice cloning, eliminating the need for the generated speech and the reference speech to be part of a massive-speaker multilingual training dataset.\n", + "\n", + "![image](https://github.com/openvinotoolkit/openvino_notebooks/assets/5703039/ca7eab80-148d-45b0-84e8-a5a279846b51)\n", + "\n", + "More details about the model can be found on the [project web page](https://research.myshell.ai/open-voice), in the [paper](https://arxiv.org/abs/2312.01479), and in the official [repository](https://github.com/myshell-ai/OpenVoice).\n", + "\n", + "This notebook provides an example of converting the [PyTorch OpenVoice model](https://github.com/myshell-ai/OpenVoice) to OpenVINO IR.
In this tutorial we will explore how to convert and run OpenVoice using OpenVINO.\n", + "#### Table of contents:\n", + "- [Clone repository and install requirements](#Clone-repository-and-install-requirements)\n", + "- [Download checkpoints and load PyTorch model](#Download-checkpoints-and-load-PyTorch-model)\n", + "- [Convert Models to OpenVINO IR](#Convert-Models-to-OpenVINO-IR)\n", + "- [Inference](#Inference)\n", + " - [Select inference device](#Select-inference-device)\n", + " - [Select reference tone](#Select-reference-tone)\n", + " - [Run inference](#Run-inference)\n", + "- [Run OpenVoice Gradio online app](#Run-OpenVoice-Gradio-online-app)\n", + "- [Cleanup](#Cleanup)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clone repository and install requirements\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "repo_dir = Path(\"OpenVoice\")\n", + "\n", + "if not repo_dir.exists():\n", + " !git clone https://github.com/myshell-ai/OpenVoice\n", + "\n", + "# append to sys.path so that modules from the repo can be imported\n", + "sys.path.append(str(repo_dir))\n", + "\n", + "%pip install -q \\\n", + "\"librosa>=0.8.1\" \\\n", + "\"wavmark>=0.0.3\" \\\n", + "\"faster-whisper>=0.9.0\" \\\n", + "\"pydub>=0.25.1\" \\\n", + "\"whisper-timestamped>=1.14.2\" \\\n", + "\"tqdm\" \\\n", + "\"inflect>=7.0.0\" \\\n", + "\"unidecode>=1.3.7\" \\\n", + "\"eng_to_ipa>=0.0.2\" \\\n", + "\"pypinyin>=0.50.0\" \\\n", + "\"cn2an>=0.5.22\" \\\n", + "\"jieba>=0.42.1\" \\\n", + "\"langid>=1.1.6\" \\\n", + "\"gradio>=4.15\" \\\n", + "\"ipywebrtc\" \\\n", + "\"ffmpeg-downloader\"\n", + "\n", + "!ffdl install -y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download checkpoints and load PyTorch model\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import openvino as ov\n", + "import ipywidgets as widgets\n", + "from IPython.display import Audio\n", + "\n", + "core = ov.Core()\n", + "\n", + "from api import BaseSpeakerTTS, ToneColorConverter, OpenVoiceBaseClass\n", + "import se_extractor" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "CKPT_BASE_PATH = 'checkpoints'\n", + "\n", + "en_suffix = f'{CKPT_BASE_PATH}/base_speakers/EN'\n", + "zh_suffix = f'{CKPT_BASE_PATH}/base_speakers/ZH'\n", + "converter_suffix = f'{CKPT_BASE_PATH}/converter'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To keep the notebook lightweight, the model for Chinese speech is not activated by default. To turn it on, set the `enable_chinese_lang` flag to True." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "enable_chinese_lang = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def download_from_hf_hub(filename, local_dir='./'):\n", + " from huggingface_hub import hf_hub_download\n", + " os.makedirs(local_dir, exist_ok=True)\n", + " hf_hub_download(repo_id=\"myshell-ai/OpenVoice\", filename=filename, local_dir=local_dir)\n", + "\n", + "download_from_hf_hub(f'{converter_suffix}/checkpoint.pth')\n", + "download_from_hf_hub(f'{converter_suffix}/config.json')\n", + "download_from_hf_hub(f'{en_suffix}/checkpoint.pth')\n", + "download_from_hf_hub(f'{en_suffix}/config.json')\n", + "\n", + "download_from_hf_hub(f'{en_suffix}/en_default_se.pth')\n", + "download_from_hf_hub(f'{en_suffix}/en_style_se.pth')\n", + "\n", + "if enable_chinese_lang:\n", + " download_from_hf_hub(f'{zh_suffix}/checkpoint.pth')\n", + " download_from_hf_hub(f'{zh_suffix}/config.json')\n", + " download_from_hf_hub(f'{zh_suffix}/zh_default_se.pth')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pt_device = \"cpu\"\n", + "\n", + "en_base_speaker_tts = BaseSpeakerTTS(f'{en_suffix}/config.json', device=pt_device)\n", + "en_base_speaker_tts.load_ckpt(f'{en_suffix}/checkpoint.pth')\n", + "\n", + "tone_color_converter = ToneColorConverter(f'{converter_suffix}/config.json', device=pt_device)\n", + "tone_color_converter.load_ckpt(f'{converter_suffix}/checkpoint.pth')\n", + "\n", + "if enable_chinese_lang:\n", + " zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_suffix}/config.json', device=pt_device)\n", + " zh_base_speaker_tts.load_ckpt(f'{zh_suffix}/checkpoint.pth')\n", + "else:\n", + " zh_base_speaker_tts = None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert Models to OpenVINO IR\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are two models in OpenVoice: `BaseSpeakerTTS`, which is responsible for speech generation, and `ToneColorConverter`, which imposes an arbitrary voice tone on the original speech. To convert them to OpenVINO IR format, we first need a suitable `torch.nn.Module` object. Instead of `self.forward`, `BaseSpeakerTTS` and `ToneColorConverter` use the custom `infer` and `convert_voice` methods as their main entry points, so we need to wrap them with a custom class inherited from `torch.nn.Module`."
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "class OVOpenVoiceBase(torch.nn.Module):\n", + " \"\"\"\n", + " Base class for both the TTS and the voice tone conversion model: the constructor is the same for both of them.\n", + " \"\"\"\n", + " def __init__(self, voice_model: OpenVoiceBaseClass):\n", + " super().__init__()\n", + " self.voice_model = voice_model\n", + " for par in voice_model.model.parameters():\n", + " par.requires_grad = False\n", + " \n", + "class OVOpenVoiceTTS(OVOpenVoiceBase):\n", + " \"\"\"\n", + " The constructor of this class accepts a BaseSpeakerTTS object for speech generation and wraps its 'infer' method with forward.\n", + " \"\"\"\n", + " def get_example_input(self):\n", + " stn_tst = self.voice_model.get_text('this is original text', self.voice_model.hps, False)\n", + " x_tst = stn_tst.unsqueeze(0)\n", + " x_tst_lengths = torch.LongTensor([stn_tst.size(0)])\n", + " speaker_id = torch.LongTensor([1])\n", + " noise_scale = torch.tensor(0.667)\n", + " length_scale = torch.tensor(1.0)\n", + " noise_scale_w = torch.tensor(0.6)\n", + " return (x_tst, x_tst_lengths, speaker_id, noise_scale, length_scale, noise_scale_w)\n", + "\n", + " def forward(self, x, x_lengths, sid, noise_scale, length_scale, noise_scale_w):\n", + " return self.voice_model.model.infer(x, x_lengths, sid, noise_scale, length_scale, noise_scale_w)\n", + " \n", + "class OVOpenVoiceConverter(OVOpenVoiceBase):\n", + " \"\"\"\n", + " The constructor of this class accepts a ToneColorConverter object for voice tone conversion and wraps its 'voice_conversion' method with forward.\n", + " \"\"\"\n", + " def get_example_input(self):\n", + " y = torch.randn([1, 513, 238], dtype=torch.float32)\n", + " y_lengths = torch.LongTensor([y.size(-1)])\n", + " target_se = torch.randn(*(1, 256, 1))\n", + " source_se = torch.randn(*(1, 256, 1))\n", + " tau = torch.tensor(0.3)\n", + " return (y, y_lengths, source_se, target_se, tau)\n", + " \n", + " def forward(self, y, y_lengths, sid_src, sid_tgt, tau):\n", + " return self.voice_model.model.voice_conversion(y, y_lengths, sid_src, sid_tgt, tau)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Convert the models to OpenVINO IR and save them to the `IRS_PATH` folder for future use.
If IRs already exist skip conversion and read them directly" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "IRS_PATH = 'openvino_irs/'\n", + "EN_TTS_IR = f'{IRS_PATH}/openvoice_en_tts.xml'\n", + "ZH_TTS_IR = f'{IRS_PATH}/openvoice_zh_tts.xml'\n", + "VOICE_CONVERTER_IR = f'{IRS_PATH}/openvoice_tone_conversion.xml'\n", + "\n", + "paths = [EN_TTS_IR, VOICE_CONVERTER_IR]\n", + "models = [OVOpenVoiceTTS(en_base_speaker_tts), OVOpenVoiceConverter(tone_color_converter)]\n", + "if enable_chinese_lang:\n", + " models.append(OVOpenVoiceTTS(zh_base_speaker_tts))\n", + " paths.append(ZH_TTS_IR)\n", + "ov_models = []\n", + "\n", + "for model, path in zip(models, paths):\n", + " if not os.path.exists(path):\n", + " ov_model = ov.convert_model(model, example_input=model.get_example_input())\n", + " ov.save_model(ov_model, path)\n", + " else:\n", + " ov_model = core.read_model(path)\n", + " ov_models.append(ov_model)\n", + "\n", + "ov_en_tts, ov_voice_conversion = ov_models[:2]\n", + "if enable_chinese_lang:\n", + " ov_zh_tts = ov_models[-1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Select inference device\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d56bb6a1d7b149c5ac539806c993ac16", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Dropdown(description='Device:', index=2, options=('CPU', 'GPU', 'AUTO'), value='AUTO')" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "core = ov.Core()\n", + "device = widgets.Dropdown(\n", + " options=core.available_devices + [\"AUTO\"],\n", + " value='AUTO',\n", + " description='Device:',\n", + " disabled=False,\n", + ")\n", + "device" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Select reference tone\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First of all, select the reference tone of voice to which the generated text will be converted: your can select from existing ones, record your own by selecting `record_manually` or upload you own file by `load_manually`" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "691de76a2eae42eeaf2be2bdd1d17343", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Dropdown(description='reference voice from which tone color will be copied', options=('demo_speaker0.mp3', 'de…" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "REFERENCE_VOICES_PATH = f'{repo_dir}/resources/'\n", + "reference_speakers = [\n", + " *[path for path in os.listdir(REFERENCE_VOICES_PATH) if os.path.splitext(path)[-1] == '.mp3'],\n", + " 'record_manually',\n", + " 'load_manually',\n", + "]\n", + "\n", + "ref_speaker = widgets.Dropdown(\n", + " options=reference_speakers,\n", + " value=reference_speakers[0],\n", + " description=\"reference voice from which tone color will be copied\",\n", + " disabled=False,\n", + ")\n", + "\n", + "ref_speaker" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + 
"outputs": [], + "source": [ + "OUTPUT_DIR = 'outputs/'\n", + "os.makedirs(OUTPUT_DIR, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "ref_speaker_path = f'{REFERENCE_VOICES_PATH}/{ref_speaker.value}'\n", + "allowed_audio_types = '.mp4,.mp3,.wav,.wma,.aac,.m4a,.m4b,.webm'\n", + "\n", + "if ref_speaker.value == 'record_manually':\n", + " ref_speaker_path = f'{OUTPUT_DIR}/custom_example_sample.webm'\n", + " from ipywebrtc import AudioRecorder, CameraStream\n", + " camera = CameraStream(constraints={'audio': True,'video':False})\n", + " recorder = AudioRecorder(stream=camera, filename=ref_speaker_path, autosave=True)\n", + " display(recorder)\n", + "elif ref_speaker.value == 'load_manually':\n", + " upload_ref = widgets.FileUpload(accept=allowed_audio_types, multiple=False, description='Select audio with reference voice')\n", + " display(upload_ref)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Play the reference voice sample before cloning it's tone to another speech" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def save_audio(voice_source: widgets.FileUpload, out_path: str):\n", + " with open(out_path, 'wb') as output_file:\n", + " assert len(voice_source.value) > 0, 'Please select audio file'\n", + " output_file.write(voice_source.value[0]['content'])\n", + "\n", + "if ref_speaker.value == 'load_manually':\n", + " ref_speaker_path = f'{OUTPUT_DIR}/{upload_ref.value[0].name}'\n", + " save_audio(upload_ref, ref_speaker_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Audio(ref_speaker_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load speaker embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# ffmpeg is neeeded to load mp3 and manually recorded webm files\n", + "import ffmpeg_downloader as ffdl\n", + "delimiter = ':' if sys.platform != 'win32' else ';'\n", + "os.environ['PATH'] = os.environ['PATH'] + f\"{delimiter}{ffdl.ffmpeg_dir}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "en_source_default_se = torch.load(f'{en_suffix}/en_default_se.pth')\n", + "en_source_style_se = torch.load(f'{en_suffix}/en_style_se.pth')\n", + "zh_source_se = torch.load(f'{zh_suffix}/zh_default_se.pth') if enable_chinese_lang else None\n", + "\n", + "target_se, audio_name = se_extractor.get_se(ref_speaker_path, tone_color_converter, target_dir=OUTPUT_DIR, vad=True) # ffmpeg must be installed" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Replace original infer methods of `OpenVoiceBaseClass` with optimized OpenVINO inference.\n", + "\n", + "There are pre and post processings that are not traceable and could not be offloaded to OpenVINO, instead of writing such processing ourselves we will rely on the already existing ones. We just replace infer and voice conversion functions of `OpenVoiceBaseClass` so that the the most computationally expensive part is done in OpenVINO." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def get_pathched_infer(ov_model: ov.Model, device: str) -> callable:\n", + " compiled_model = core.compile_model(ov_model, device)\n", + " \n", + " def infer_impl(x, x_lengths, sid, noise_scale, length_scale, noise_scale_w):\n", + " ov_output = compiled_model((x, x_lengths, sid, noise_scale, length_scale, noise_scale_w))\n", + " return (torch.tensor(ov_output[0]), )\n", + " return infer_impl\n", + "\n", + "def get_patched_voice_conversion(ov_model: ov.Model, device: str) -> callable:\n", + " compiled_model = core.compile_model(ov_model, device)\n", + "\n", + " def voice_conversion_impl(y, y_lengths, sid_src, sid_tgt, tau):\n", + " ov_output = compiled_model((y, y_lengths, sid_src, sid_tgt, tau))\n", + " return (torch.tensor(ov_output[0]), )\n", + " return voice_conversion_impl\n", + "\n", + "en_base_speaker_tts.model.infer = get_pathched_infer(ov_en_tts, device.value)\n", + "tone_color_converter.model.voice_conversion = get_patched_voice_conversion(ov_voice_conversion, device.value)\n", + "if enable_chinese_lang:\n", + " zh_base_speaker_tts.model.infer = get_pathched_infer(ov_zh_tts, device.value)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run inference\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5dabde936a4e480ca0cfb5734d3db92a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Dropdown(description='Voice source', options=('use TTS', 'choose_manually'), value='use TTS')" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voice_source = widgets.Dropdown(\n", + " options=['use TTS', 'choose_manually'],\n", + " value='use TTS',\n", + " description=\"Voice source\",\n", + " disabled=False,\n", + ")\n", + "\n", + "voice_source" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "if voice_source.value == 'choose_manually':\n", + " upload_orig_voice = widgets.FileUpload(accept=allowed_audio_types, multiple=False, description='audo whose tone will be replaced')\n", + " display(upload_orig_voice)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if voice_source.value == 'choose_manually':\n", + " orig_voice_path = f'{OUTPUT_DIR}/{upload_orig_voice.value[0].name}'\n", + " save_audio(upload_orig_voice, orig_voice_path)\n", + " source_se, _ = se_extractor.get_se(orig_voice_path, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)\n", + "else:\n", + " text = \"\"\"\n", + " OpenVINO toolkit is a comprehensive toolkit for quickly developing applications and solutions that solve \n", + " a variety of tasks including emulation of human vision, automatic speech recognition, natural language processing, \n", + " recommendation systems, and many others.\n", + " \"\"\"\n", + " source_se = en_source_default_se\n", + " orig_voice_path = f'{OUTPUT_DIR}/tmp.wav'\n", + " en_base_speaker_tts.tts(text, orig_voice_path, speaker='default', language='English')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And finally, run voice tone conversion with OpenVINO optimized model" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": 
[ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7b54d91a7274453f9b3246ff134413d0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "FloatSlider(value=0.3, description='tau', max=2.0, min=0.01, step=0.01)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tau_slider = widgets.FloatSlider(\n", + " value=0.3,\n", + " min=0.01,\n", + " max=2.0,\n", + " step=0.01,\n", + " description='tau',\n", + " disabled=False,\n", + " readout_format='.2f',\n", + ")\n", + "tau_slider" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "resulting_voice_path = f'{OUTPUT_DIR}/output_with_cloned_voice_tone.wav'\n", + "\n", + "tone_color_converter.convert(\n", + " audio_src_path=orig_voice_path, \n", + " src_se=source_se, \n", + " tgt_se=target_se, \n", + " output_path=resulting_voice_path, \n", + " tau=tau_slider.value,\n", + " message=\"@MyShell\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Audio(orig_voice_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Audio(resulting_voice_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run OpenVoice Gradio online app\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use [Gradio](https://www.gradio.app/) app to run TTS and voice tone conversion online." 
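, + "\n", + "The cell below builds the full demo around the OpenVINO-patched models. As a stripped-down illustration of the structure only (the component names and the `synthesize` placeholder here are illustrative and are not part of the notebook), a Gradio app follows this pattern:\n", + "\n", + "```python\n", + "import gradio as gr\n", + "\n", + "def synthesize(text, reference_audio):\n", + "    # placeholder: run TTS and tone color conversion here\n", + "    return reference_audio\n", + "\n", + "with gr.Blocks() as sketch_demo:\n", + "    text_in = gr.Textbox(label='Text prompt')\n", + "    ref_in = gr.Audio(label='Reference voice', type='filepath')\n", + "    audio_out = gr.Audio(label='Cloned voice')\n", + "    gr.Button('Send').click(synthesize, inputs=[text_in, ref_in], outputs=[audio_out])\n", + "\n", + "# sketch_demo.launch()\n", + "```"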
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "import gradio as gr\n", + "import langid\n", + "\n", + "supported_languages = ['zh', 'en']\n", + "\n", + "def build_predict(output_dir, tone_color_converter, en_tts_model, zh_tts_model, en_source_default_se, en_source_style_se, zh_source_se):\n", + " def predict(prompt, style, audio_file_pth, agree):\n", + " return predict_impl(prompt, style, audio_file_pth, agree, output_dir, tone_color_converter, en_tts_model, zh_tts_model, en_source_default_se, en_source_style_se, zh_source_se)\n", + " return predict\n", + "\n", + "def predict_impl(prompt, style, audio_file_pth, agree, output_dir, tone_color_converter, en_tts_model, zh_tts_model, en_source_default_se, en_source_style_se, zh_source_se):\n", + " text_hint = ''\n", + " if not agree:\n", + " text_hint += '[ERROR] Please accept the Terms & Condition!\\n'\n", + " gr.Warning(\"Please accept the Terms & Condition!\")\n", + " return (\n", + " text_hint,\n", + " None,\n", + " None,\n", + " )\n", + "\n", + " language_predicted = langid.classify(prompt)[0].strip() \n", + " print(f\"Detected language:{language_predicted}\")\n", + "\n", + " if language_predicted not in supported_languages:\n", + " text_hint += f\"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\\n\"\n", + " gr.Warning(\n", + " f\"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\"\n", + " )\n", + "\n", + " return (\n", + " text_hint,\n", + " None,\n", + " )\n", + " \n", + " if language_predicted == \"zh\":\n", + " tts_model = zh_tts_model\n", + " if zh_tts_model is None:\n", + " gr.Warning(\"TTS model for Chinece language was not loaded please set 'enable_chinese_lang=True`\")\n", + " return (\n", + " text_hint,\n", + " None,\n", + " )\n", + " source_se = zh_source_se\n", + " language = 'Chinese'\n", + " if style not in ['default']:\n", + " text_hint += f\"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\\n\"\n", + " gr.Warning(f\"The style {style} is not supported for Chinese, which should be in ['default']\")\n", + " return (\n", + " text_hint,\n", + " None,\n", + " )\n", + "\n", + " else:\n", + " tts_model = en_tts_model\n", + " if style == 'default':\n", + " source_se = en_source_default_se\n", + " else:\n", + " source_se = en_source_style_se\n", + " language = 'English'\n", + " supported_styles = ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n", + " if style not in supported_styles:\n", + " text_hint += f\"[ERROR] The style {style} is not supported for English, which should be in {*supported_styles,}\\n\"\n", + " gr.Warning(f\"The style {style} is not supported for English, which should be in {*supported_styles,}\")\n", + " return (\n", + " text_hint,\n", + " None,\n", + " )\n", + "\n", + " speaker_wav = audio_file_pth\n", + "\n", + " if len(prompt) < 2:\n", + " text_hint += \"[ERROR] Please give a longer prompt text \\n\"\n", + " gr.Warning(\"Please give a longer prompt text\")\n", + " return (\n", + " text_hint,\n", + " None,\n", + " )\n", + " if len(prompt) > 200:\n", + " text_hint += \"[ERROR] Text length limited to 200 characters for this demo, please try shorter text. 
You can clone our open-source repo and try for your usage \\n\"\n", + " gr.Warning(\n", + " \"Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo for your usage\"\n", + " )\n", + " return (\n", + " text_hint,\n", + " None,\n", + " )\n", + " \n", + " # note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference\n", + " try:\n", + " target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)\n", + " except Exception as e:\n", + " text_hint += f\"[ERROR] Get target tone color error {str(e)} \\n\"\n", + " gr.Warning(\n", + " \"[ERROR] Get target tone color error {str(e)} \\n\"\n", + " )\n", + " return (\n", + " text_hint,\n", + " None,\n", + " )\n", + "\n", + " src_path = f'{output_dir}/tmp.wav'\n", + " tts_model.tts(prompt, src_path, speaker=style, language=language)\n", + "\n", + " save_path = f'{output_dir}/output.wav'\n", + " encode_message = \"@MyShell\"\n", + " tone_color_converter.convert(\n", + " audio_src_path=src_path, \n", + " src_se=source_se, \n", + " tgt_se=target_se, \n", + " output_path=save_path,\n", + " message=encode_message)\n", + "\n", + " text_hint += 'Get response successfully \\n'\n", + "\n", + " return (\n", + " text_hint,\n", + " src_path,\n", + " save_path,\n", + " )\n", + "\n", + "description = \"\"\"\n", + " # OpenVoice accelerated by OpenVINO:\n", + " \n", + " a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set.\n", + "\"\"\"\n", + "\n", + "content = \"\"\"\n", + "
\n", + "If the generated voice does not sound like the reference voice, please refer to this QnA. For multi-lingual & cross-lingual examples, please refer to this jupyter notebook.\n", + "This online demo mainly supports English. The default style also supports Chinese. But OpenVoice can adapt to any other language as long as a base speaker is provided.\n", + "
\n", + "\"\"\"\n", + "wrapped_markdown_content = f\"
{content}
\"\n", + "\n", + "\n", + "examples = [\n", + " [\n", + " \"今天天气真好,我们一起出去吃饭吧。\",\n", + " 'default',\n", + " \"OpenVoice/resources/demo_speaker1.mp3\",\n", + " True,\n", + " ],[\n", + " \"This audio is generated by open voice with a half-performance model.\",\n", + " 'whispering',\n", + " \"OpenVoice/resources/demo_speaker2.mp3\",\n", + " True,\n", + " ],\n", + " [\n", + " \"He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.\",\n", + " 'sad',\n", + " \"OpenVoice/resources/demo_speaker0.mp3\",\n", + " True,\n", + " ],\n", + "]\n", + "\n", + "def get_demo(output_dir, tone_color_converter, en_tts_model, zh_tts_model, en_source_default_se, en_source_style_se, zh_source_se):\n", + " with gr.Blocks(analytics_enabled=False) as demo:\n", + "\n", + " with gr.Row():\n", + " gr.Markdown(description)\n", + " with gr.Row():\n", + " gr.HTML(wrapped_markdown_content)\n", + "\n", + " with gr.Row():\n", + " with gr.Column():\n", + " input_text_gr = gr.Textbox(\n", + " label=\"Text Prompt\",\n", + " info=\"One or two sentences at a time is better. Up to 200 text characters.\",\n", + " value=\"He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.\",\n", + " )\n", + " style_gr = gr.Dropdown(\n", + " label=\"Style\",\n", + " info=\"Select a style of output audio for the synthesised speech. (Chinese only support 'default' now)\",\n", + " choices=['default', 'whispering', 'cheerful', 'terrified', 'angry', 'sad', 'friendly'],\n", + " max_choices=1,\n", + " value=\"default\",\n", + " )\n", + " ref_gr = gr.Audio(\n", + " label=\"Reference Audio\",\n", + " type=\"filepath\",\n", + " value=\"OpenVoice/resources/demo_speaker2.mp3\",\n", + " )\n", + " tos_gr = gr.Checkbox(\n", + " label=\"Agree\",\n", + " value=False,\n", + " info=\"I agree to the terms of the cc-by-nc-4.0 license-: https://github.com/myshell-ai/OpenVoice/blob/main/LICENSE\",\n", + " )\n", + "\n", + " tts_button = gr.Button(\"Send\", elem_id=\"send-btn\", visible=True)\n", + "\n", + "\n", + " with gr.Column():\n", + " out_text_gr = gr.Text(label=\"Info\")\n", + " audio_orig_gr = gr.Audio(label=\"Synthesised Audio\", autoplay=False)\n", + " audio_gr = gr.Audio(label=\"Audio with cloned voice\", autoplay=True)\n", + " # ref_audio_gr = gr.Audio(label=\"Reference Audio Used\")\n", + " predict = build_predict(\n", + " output_dir, \n", + " tone_color_converter, \n", + " en_tts_model, \n", + " zh_tts_model, \n", + " en_source_default_se, \n", + " en_source_style_se, \n", + " zh_source_se\n", + " )\n", + "\n", + " gr.Examples(examples,\n", + " label=\"Examples\",\n", + " inputs=[input_text_gr, style_gr, ref_gr, tos_gr],\n", + " outputs=[out_text_gr, audio_gr],\n", + " fn=predict,\n", + " cache_examples=False,)\n", + " tts_button.click(predict, [input_text_gr, style_gr, ref_gr, tos_gr], outputs=[out_text_gr, audio_orig_gr, audio_gr])\n", + " return demo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demo = get_demo(OUTPUT_DIR, tone_color_converter, en_base_speaker_tts, zh_base_speaker_tts, en_source_default_se, en_source_style_se, zh_source_se)\n", + "demo.queue(max_size=2)\n", + "\n", + "try:\n", + " demo.launch(debug=True, height=1000)\n", + "except Exception:\n", + " demo.launch(share=True, debug=True, height=1000)\n", + "# if you are launching remotely, specify server_name 
and server_port\n", + "# demo.launch(server_name='your server name', server_port='server port in int')\n", + "# Read more in the docs: https://gradio.app/docs/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleanup\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import shutil\n", + "# shutil.rmtree(CKPT_BASE_PATH)\n", + "# shutil.rmtree(IRS_PATH)\n", + "# shutil.rmtree(OUTPUT_DIR)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "openvino_notebooks": { + "imageUrl": "", + "tags": { + "categories": [ + "AI Trends", + "Convert", + "Optimize", + "Model Demos", + "Live Demos" + ], + "libraries": [], + "other": [], + "tasks": [ + "Text-to-Audio" + ] + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/284-openvoice/README.md b/notebooks/284-openvoice/README.md new file mode 100644 index 00000000000..b36a59010bf --- /dev/null +++ b/notebooks/284-openvoice/README.md @@ -0,0 +1,30 @@ +# Voice tone cloning with OpenVoice and OpenVINO + +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/openvinotoolkit/openvino_notebooks/HEAD?filepath=notebooks%2F284-openvoice%2F284-openvoice.ipynb) +[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openvinotoolkit/openvino_notebooks/blob/main/notebooks/284-openvoice/284-openvoice.ipynb) + + +![image](https://github.com/openvinotoolkit/openvino_notebooks/assets/5703039/ca7eab80-148d-45b0-84e8-a5a279846b51) + +[OpenVoice](https://github.com/myshell-ai/OpenVoice) is a versatile instant voice cloning approach that transfers voice tone and generates speech in various languages from just a brief audio snippet of the source speaker. OpenVoice has three main features: (i) high-quality tone color replication with multiple languages and accents; (ii) fine-grained control over voice styles, including emotions and accents, as well as other parameters such as rhythm, pauses, and intonation; and (iii) zero-shot cross-lingual voice cloning, eliminating the need for the generated speech and the reference speech to be part of a massive-speaker multilingual training dataset. + +More details about the model can be found on the [project web page](https://research.myshell.ai/open-voice), in the [paper](https://arxiv.org/abs/2312.01479), and in the official [repository](https://github.com/myshell-ai/OpenVoice). + +In this tutorial we will explore how to convert and run OpenVoice using OpenVINO. + +## Notebook Contents + +This notebook demonstrates voice tone cloning with [OpenVoice](https://github.com/myshell-ai/OpenVoice) using OpenVINO. + +The tutorial consists of the following steps: +- Install prerequisites +- Load the PyTorch model +- Convert the model to OpenVINO Intermediate Representation (IR) format +- Run OpenVINO model inference on a single example +- Launch the interactive demo + +## Installation Instructions + +This is a self-contained example that relies solely on its own code.
+We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. +For details, please refer to [Installation Guide](../../README.md).
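+
+At its core, the conversion and inference flow demonstrated by the notebook follows the pattern below. This is a simplified, self-contained sketch that uses a toy stand-in module instead of the actual OpenVoice wrappers defined in the notebook:
+
+```python
+import torch
+import openvino as ov
+
+
+class TinyModel(torch.nn.Module):
+    # toy stand-in for the wrapped OpenVoice models
+    def forward(self, x):
+        return torch.relu(x)
+
+
+core = ov.Core()
+
+# convert the PyTorch module and save the IR for later reuse
+ov_model = ov.convert_model(TinyModel(), example_input=torch.randn(1, 8))
+ov.save_model(ov_model, "tiny_model.xml")
+
+# read the IR back, compile it on a device, and run inference
+compiled = core.compile_model(core.read_model("tiny_model.xml"), "AUTO")
+print(compiled((torch.randn(1, 8),))[0].shape)
+```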