
Commit 5e808a1

add npu support in llm chatbot (#2607)
CVS-158814
1 parent 908962a commit 5e808a1

File tree

3 files changed: +616 -120 lines changed


notebooks/llm-chatbot/gradio_helper_genai.py

Lines changed: 134 additions & 26 deletions
@@ -1,8 +1,8 @@
-from typing import Any
-from queue import Queue
 import openvino as ov
+import openvino_genai as ov_genai
 from uuid import uuid4
 from threading import Event, Thread
+import queue
 
 max_new_tokens = 256
 
@@ -62,33 +62,137 @@ def get_system_prompt(model_language):
     )
 
 
-class TextQueue:
-    def __init__(self) -> None:
-        self.text_queue = Queue()
-        self.stop_signal = None
-        self.stop_tokens = []
+class IterableStreamer(ov_genai.StreamerBase):
+    """
+    A custom streamer class for handling token streaming and detokenization with buffering.
 
-    def __call__(self, text) -> Any:
-        self.text_queue.put(text)
+    Attributes:
+        tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens.
+        tokens_cache (list): A buffer to accumulate tokens for detokenization.
+        text_queue (Queue): A synchronized queue for storing decoded text chunks.
+        print_len (int): The length of the printed text to manage incremental decoding.
+    """
+
+    def __init__(self, tokenizer):
+        """
+        Initializes the IterableStreamer with the given tokenizer.
+
+        Args:
+            tokenizer (Tokenizer): The tokenizer to use for encoding and decoding tokens.
+        """
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.tokens_cache = []
+        self.text_queue = queue.Queue()
+        self.print_len = 0
 
     def __iter__(self):
+        """
+        Returns the iterator object itself.
+        """
         return self
 
     def __next__(self):
-        value = self.text_queue.get()
-        if value == self.stop_signal or value in self.stop_tokens:
-            raise StopIteration()
+        """
+        Returns the next value from the text queue.
+
+        Returns:
+            str: The next decoded text chunk.
+
+        Raises:
+            StopIteration: If there are no more elements in the queue.
+        """
+        value = self.text_queue.get()  # get() will be blocked until a token is available.
+        if value is None:
+            raise StopIteration
+        return value
+
+    def get_stop_flag(self):
+        """
+        Checks whether the generation process should be stopped.
+
+        Returns:
+            bool: Always returns False in this implementation.
+        """
+        return False
+
+    def put_word(self, word: str):
+        """
+        Puts a word into the text queue.
+
+        Args:
+            word (str): The word to put into the queue.
+        """
+        self.text_queue.put(word)
+
+    def put(self, token_id: int) -> bool:
+        """
+        Processes a token and manages the decoding buffer. Adds decoded text to the queue.
+
+        Args:
+            token_id (int): The token_id to process.
+
+        Returns:
+            bool: True if generation should be stopped, False otherwise.
+        """
+        self.tokens_cache.append(token_id)
+        text = self.tokenizer.decode(self.tokens_cache)
+
+        word = ""
+        if len(text) > self.print_len and "\n" == text[-1]:
+            # Flush the cache after the new line symbol.
+            word = text[self.print_len :]
+            self.tokens_cache = []
+            self.print_len = 0
+        elif len(text) >= 3 and text[-3:] == chr(65533):
+            # Don't print incomplete text.
+            pass
+        elif len(text) > self.print_len:
+            # It is possible to have a shorter text after adding new token.
+            # Print to output only if text length is increaesed.
+            word = text[self.print_len :]
+            self.print_len = len(text)
+        self.put_word(word)
+
+        if self.get_stop_flag():
+            # When generation is stopped from streamer then end is not called, need to call it here manually.
+            self.end()
+            return True  # True means stop generation
         else:
-            return value
+            return False  # False means continue generation
+
+    def end(self):
+        """
+        Flushes residual tokens from the buffer and puts a None value in the queue to signal the end.
+        """
+        text = self.tokenizer.decode(self.tokens_cache)
+        if len(text) > self.print_len:
+            word = text[self.print_len :]
+            self.put_word(word)
+            self.tokens_cache = []
+            self.print_len = 0
+        self.put_word(None)
 
     def reset(self):
-        self.text_queue = Queue()
+        self.tokens_cache = []
+        self.text_queue = queue.Queue()
+        self.print_len = 0
 
-    def end(self):
-        self.text_queue.put(self.stop_signal)
 
+class ChunkStreamer(IterableStreamer):
+
+    def __init__(self, tokenizer, tokens_len=4):
+        super().__init__(tokenizer)
+        self.tokens_len = tokens_len
 
-def make_demo(pipe, model_configuration, model_id, model_language):
+    def put(self, token_id: int) -> bool:
+        if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
+            self.tokens_cache.append(token_id)
+            return False
+        return super().put(token_id)
+
+
+def make_demo(pipe, model_configuration, model_id, model_language, disable_advanced=False):
     import gradio as gr
 
     max_new_tokens = 256
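For reference, the streamer added above is designed to be driven from a background generation thread while the UI thread iterates over it for text chunks. The following is a minimal usage sketch, not part of this commit; it assumes an already-constructed ov_genai.LLMPipeline named pipe and a prompt string, and the import path is illustrative.

from threading import Thread

import openvino_genai as ov_genai

from gradio_helper_genai import ChunkStreamer  # class added in this commit

# Assumed to exist already: pipe = ov_genai.LLMPipeline(model_dir, device); prompt = "..."
streamer = ChunkStreamer(pipe.get_tokenizer(), tokens_len=4)

config = ov_genai.GenerationConfig()
config.max_new_tokens = 256

# Run generation in a worker thread so the streamer can be iterated without blocking it.
worker = Thread(target=lambda: pipe.generate(prompt, config, streamer))
worker.start()

partial_text = ""
for chunk in streamer:  # __next__ blocks on the internal queue; stops when end() pushes None
    partial_text += chunk
worker.join()
print(partial_text)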
@@ -135,14 +239,18 @@ def bot(message, history, temperature, top_p, top_k, repetition_penalty):
         history: updated history with message and answer from chatbot
         active_chat: if we are here, the chat is running or will be started, so return True
         """
-        streamer = TextQueue()
-        config = pipe.get_generation_config()
-        config.temperature = temperature
-        config.top_p = top_p
-        config.top_k = top_k
-        config.do_sample = temperature > 0.0
-        config.max_new_tokens = max_new_tokens
-        config.repetition_penalty = repetition_penalty
+        streamer = ChunkStreamer(pipe.get_tokenizer())
+        if not disable_advanced:
+            config = pipe.get_generation_config()
+            config.temperature = temperature
+            config.top_p = top_p
+            config.top_k = top_k
+            config.do_sample = temperature > 0.0
+            config.max_new_tokens = max_new_tokens
+            config.repetition_penalty = repetition_penalty
+        else:
+            config = ov_genai.GenerationConfig()
+            config.max_new_tokens = max_new_tokens
         history = history or []
         if not history:
             pipe.start_chat(system_message=start_message)
@@ -204,7 +312,7 @@ def stop_chat_and_clear_history(streamer):
                     submit = gr.Button("Submit")
                     stop = gr.Button("Stop")
                     clear = gr.Button("Clear")
-        with gr.Row():
+        with gr.Row(visible=not disable_advanced):
             with gr.Accordion("Advanced Options:", open=False):
                 with gr.Row():
                     with gr.Column():
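The new disable_advanced flag hides the advanced sampling controls and makes bot() fall back to a default ov_genai.GenerationConfig, which appears to be the intended behavior for the NPU configuration this commit adds. Below is a hypothetical call, not taken from the commit; pipe, model_configuration, model_id and model_language are assumed to come from the notebook, and the model path is a placeholder.

import openvino_genai as ov_genai

from gradio_helper_genai import make_demo

pipe = ov_genai.LLMPipeline("path/to/exported/model", "NPU")  # placeholder path and device

demo = make_demo(
    pipe,
    model_configuration,  # assumed: model config dict from the notebook's selection widgets
    model_id,             # assumed: selected model name
    model_language,       # assumed: "English", "Chinese" or "Japanese"
    disable_advanced=True,  # hide sampling sliders; bot() then uses a default GenerationConfig
)
demo.launch()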
