Commit: Switch Qwen2-VL notebook to GenAI (#2723)

Ticket: CVS-158716

Co-authored-by: Ekaterina Aidova <[email protected]>
Showing 3 changed files with 212 additions and 304 deletions. Only the Gradio helper script's diff is reproduced below; it replaces the `transformers` generation path with the `openvino_genai` VLM pipeline.
**Before — the `transformers`-based helper (removed; old file, 205 lines):**

```python
from pathlib import Path
import gradio as gr
import copy
import re
from threading import Thread
from transformers import TextIteratorStreamer
from qwen_vl_utils import process_vision_info


def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def _remove_image_special(text):
    text = text.replace("<ref>", "").replace("</ref>", "")
    return re.sub(r"<box>.*?(</box>|$)", "", text)


def is_video_file(filename):
    video_extensions = [".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg"]
    return any(filename.lower().endswith(ext) for ext in video_extensions)


def transform_messages(original_messages):
    transformed_messages = []
    for message in original_messages:
        new_content = []
        for item in message["content"]:
            if "image" in item:
                new_item = {"type": "image", "image": item["image"]}
            elif "text" in item:
                new_item = {"type": "text", "text": item["text"]}
            elif "video" in item:
                new_item = {"type": "video", "video": item["video"]}
            else:
                continue
            new_content.append(new_item)

        new_message = {"role": message["role"], "content": new_content}
        transformed_messages.append(new_message)

    return transformed_messages


def make_demo(model, processor):
    def call_local_model(model, processor, messages):
        messages = transform_messages(messages)

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(model.device)

        tokenizer = processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {"max_new_tokens": 512, "streamer": streamer, **inputs}

        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

    def create_predict_fn():
        def predict(_chatbot, task_history):
            chat_query = _chatbot[-1][0]
            query = task_history[-1][0]
            if len(chat_query) == 0:
                _chatbot.pop()
                task_history.pop()
                return _chatbot
            print("User: " + _parse_text(query))
            history_cp = copy.deepcopy(task_history)
            full_response = ""
            messages = []
            content = []
            for q, a in history_cp:
                if isinstance(q, (tuple, list)):
                    if is_video_file(q[0]):
                        content.append({"video": f"file://{q[0]}"})
                    else:
                        content.append({"image": f"file://{q[0]}"})
                else:
                    content.append({"text": q})
                    messages.append({"role": "user", "content": content})
                    messages.append({"role": "assistant", "content": [{"text": a}]})
                    content = []
            messages.pop()

            for response in call_local_model(model, processor, messages):
                _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))

                yield _chatbot
                full_response = _parse_text(response)

            task_history[-1] = (query, full_response)
            print("Qwen-VL-Chat: " + _parse_text(full_response))
            yield _chatbot

        return predict

    def create_regenerate_fn():
        def regenerate(_chatbot, task_history):
            if not task_history:
                return _chatbot
            item = task_history[-1]
            if item[1] is None:
                return _chatbot
            task_history[-1] = (item[0], None)
            chatbot_item = _chatbot.pop(-1)
            if chatbot_item[0] is None:
                _chatbot[-1] = (_chatbot[-1][0], None)
            else:
                _chatbot.append((chatbot_item[0], None))
            _chatbot_gen = predict(_chatbot, task_history)
            for _chatbot in _chatbot_gen:
                yield _chatbot

        return regenerate

    predict = create_predict_fn()
    regenerate = create_regenerate_fn()

    def add_text(history, task_history, text):
        task_text = text
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ""

    def add_file(history, task_history, file):
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [((file.name,), None)]
        task_history = task_history + [((file.name,), None)]
        return history, task_history

    def reset_user_input():
        return gr.update(value="")

    def reset_state(task_history):
        task_history.clear()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>Qwen2-VL OpenVINO demo</center>""")

        chatbot = gr.Chatbot(label="Qwen2-VL", elem_classes="control-height", height=500)
        query = gr.Textbox(lines=2, label="Input")
        task_history = gr.State([])

        with gr.Row():
            addfile_btn = gr.UploadButton("📁 Upload (上传文件)", file_types=["image", "video"])
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")

        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
            predict, [chatbot, task_history], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

        gr.Markdown(
            """\
<font size=2>Note: This demo is governed by the original license of Qwen2-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen2-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)"""
        )
    return demo
```

**After — the OpenVINO GenAI helper (added; new file, 131 lines):**

```python
from pathlib import Path
import gradio as gr
from PIL import Image
import numpy as np
import requests
from threading import Event, Thread
import inspect
from queue import Queue

example_image_urls = [
    (
        "https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/1d6a0188-5613-418d-a1fd-4560aae1d907",
        "bee.jpg",
    ),
    (
        "https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/6cc7feeb-0721-4b5d-8791-2576ed9d2863",
        "baklava.png",
    ),
]
for url, file_name in example_image_urls:
    if not Path(file_name).exists():
        Image.open(requests.get(url, stream=True).raw).save(file_name)


def make_demo(model):
    import openvino_genai
    import openvino as ov

    has_additonal_buttons = "undo_button" in inspect.signature(gr.ChatInterface.__init__).parameters

    def read_image(path: str) -> ov.Tensor:
        """
        Args:
            path: The path to the image.

        Returns: the ov.Tensor containing the image.
        """
        pic = Image.open(path).convert("RGB")
        image_data = np.array(pic.getdata()).reshape(1, pic.size[1], pic.size[0], 3).astype(np.byte)
        return ov.Tensor(image_data)

    class TextQueue:
        def __init__(self) -> None:
            self.text_queue = Queue()
            self.stop_signal = None
            self.stop_tokens = []

        def __call__(self, text):
            self.text_queue.put(text)

        def __iter__(self):
            return self

        def __next__(self):
            value = self.text_queue.get()
            if value == self.stop_signal or value in self.stop_tokens:
                raise StopIteration()
            else:
                return value

        def reset(self):
            self.text_queue = Queue()

        def end(self):
            self.text_queue.put(self.stop_signal)

    def bot_streaming(message, history):
        print(f"message is - {message}")
        print(f"history is - {history}")

        if not history:
            model.start_chat()
        generation_config = openvino_genai.GenerationConfig()
        generation_config.max_new_tokens = 128
        files = message["files"] if isinstance(message, dict) else message.files
        message_text = message["text"] if isinstance(message, dict) else message.text

        image = None
        if files:
            # message["files"][-1] is a Dict or just a string
            if isinstance(files[-1], dict):
                image = files[-1]["path"]
            else:
                if isinstance(files[-1], (str, Path)):
                    image = files[-1]
                else:
                    image = files[-1] if isinstance(files[-1], (list, tuple)) else files[-1].path
        if image is not None:
            image = read_image(image)
        streamer = TextQueue()
        stream_complete = Event()

        def generate_and_signal_complete():
            """
            generation function for single thread
            """
            streamer.reset()
            generation_kwargs = {"prompt": message_text, "generation_config": generation_config, "streamer": streamer}
            if image is not None:
                generation_kwargs["image"] = image
            model.generate(**generation_kwargs)
            stream_complete.set()
            streamer.end()

        t1 = Thread(target=generate_and_signal_complete)
        t1.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer

    additional_buttons = {}
    if has_additonal_buttons:
        additional_buttons = {"undo_button": None, "retry_button": None}
    demo = gr.ChatInterface(
        fn=bot_streaming,
        title="Qwen2-VL OpenVINO Demo",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
        ],
        stop_btn=None,
        multimodal=True,
        **additional_buttons,
    )
    return demo
```
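For reference, a sketch of how the notebook presumably wires this helper up after the switch. The module name, export directory, and device are assumptions for illustration; only `make_demo(model)` itself is defined in this diff, and it expects an object with `start_chat()` and `generate()` such as a GenAI `VLMPipeline`:

```python
import openvino_genai
from gradio_helper import make_demo  # assumed module name for the helper file above

# Assumed: the notebook exports or downloads the model to this directory earlier.
model = openvino_genai.VLMPipeline("Qwen2-VL-2B-Instruct-ov", "CPU")

demo = make_demo(model)
demo.launch()
```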