@@ -1,8 +1,8 @@
-from typing import Any
-from queue import Queue
 import openvino as ov
+import openvino_genai as ov_genai
 from uuid import uuid4
 from threading import Event, Thread
+import queue
 
 max_new_tokens = 256
 
@@ -62,33 +62,137 @@ def get_system_prompt(model_language):
     )
 
 
-class TextQueue:
-    def __init__(self) -> None:
-        self.text_queue = Queue()
-        self.stop_signal = None
-        self.stop_tokens = []
+class IterableStreamer(ov_genai.StreamerBase):
+    """
+    A custom streamer class for handling token streaming and detokenization with buffering.
 
-    def __call__(self, text) -> Any:
-        self.text_queue.put(text)
+    Attributes:
+        tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens.
+        tokens_cache (list): A buffer to accumulate tokens for detokenization.
+        text_queue (Queue): A synchronized queue for storing decoded text chunks.
+        print_len (int): The length of the printed text to manage incremental decoding.
+    """
+
+    def __init__(self, tokenizer):
+        """
+        Initializes the IterableStreamer with the given tokenizer.
+
+        Args:
+            tokenizer (Tokenizer): The tokenizer to use for encoding and decoding tokens.
+        """
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.tokens_cache = []
+        self.text_queue = queue.Queue()
+        self.print_len = 0
 
     def __iter__(self):
+        """
+        Returns the iterator object itself.
+        """
         return self
 
     def __next__(self):
-        value = self.text_queue.get()
-        if value == self.stop_signal or value in self.stop_tokens:
-            raise StopIteration()
+        """
+        Returns the next value from the text queue.
+
+        Returns:
+            str: The next decoded text chunk.
+
+        Raises:
+            StopIteration: If there are no more elements in the queue.
+        """
+        value = self.text_queue.get()  # get() will be blocked until a token is available.
+        if value is None:
+            raise StopIteration
+        return value
+
+    def get_stop_flag(self):
+        """
+        Checks whether the generation process should be stopped.
+
+        Returns:
+            bool: Always returns False in this implementation.
+        """
+        return False
+
+    def put_word(self, word: str):
+        """
+        Puts a word into the text queue.
+
+        Args:
+            word (str): The word to put into the queue.
+        """
+        self.text_queue.put(word)
+
+    def put(self, token_id: int) -> bool:
+        """
+        Processes a token and manages the decoding buffer. Adds decoded text to the queue.
+
+        Args:
+            token_id (int): The token_id to process.
+
+        Returns:
+            bool: True if generation should be stopped, False otherwise.
+        """
+        self.tokens_cache.append(token_id)
+        text = self.tokenizer.decode(self.tokens_cache)
+
+        word = ""
+        if len(text) > self.print_len and "\n" == text[-1]:
+            # Flush the cache after the new line symbol.
+            word = text[self.print_len :]
+            self.tokens_cache = []
+            self.print_len = 0
+        elif len(text) >= 3 and text[-1] == chr(65533):
+            # Don't print incomplete text (it still ends with a Unicode replacement character).
+            pass
+        elif len(text) > self.print_len:
+            # It is possible to have a shorter text after adding new token.
+            # Print to output only if text length is increased.
+            word = text[self.print_len :]
+            self.print_len = len(text)
+        self.put_word(word)
+
+        if self.get_stop_flag():
+            # When generation is stopped from streamer then end is not called, need to call it here manually.
+            self.end()
+            return True  # True means stop generation
         else:
-            return value
+            return False  # False means continue generation
+
+    def end(self):
+        """
+        Flushes residual tokens from the buffer and puts a None value in the queue to signal the end.
+        """
+        text = self.tokenizer.decode(self.tokens_cache)
+        if len(text) > self.print_len:
+            word = text[self.print_len :]
+            self.put_word(word)
+            self.tokens_cache = []
+            self.print_len = 0
+        self.put_word(None)
 
     def reset(self):
-        self.text_queue = Queue()
+        self.tokens_cache = []
+        self.text_queue = queue.Queue()
+        self.print_len = 0
 
-    def end(self):
-        self.text_queue.put(self.stop_signal)
 
+class ChunkStreamer(IterableStreamer):
+
+    def __init__(self, tokenizer, tokens_len=4):
+        super().__init__(tokenizer)
+        self.tokens_len = tokens_len
 
-def make_demo(pipe, model_configuration, model_id, model_language):
+    def put(self, token_id: int) -> bool:
+        if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
+            self.tokens_cache.append(token_id)
+            return False
+        return super().put(token_id)
+
+
+def make_demo(pipe, model_configuration, model_id, model_language, disable_advanced=False):
     import gradio as gr
 
     max_new_tokens = 256
@@ -135,14 +239,18 @@ def bot(message, history, temperature, top_p, top_k, repetition_penalty):
           history: updated history with message and answer from chatbot
           active_chat: if we are here, the chat is running or will be started, so return True
         """
-        streamer = TextQueue()
-        config = pipe.get_generation_config()
-        config.temperature = temperature
-        config.top_p = top_p
-        config.top_k = top_k
-        config.do_sample = temperature > 0.0
-        config.max_new_tokens = max_new_tokens
-        config.repetition_penalty = repetition_penalty
+        streamer = ChunkStreamer(pipe.get_tokenizer())
+        if not disable_advanced:
+            config = pipe.get_generation_config()
+            config.temperature = temperature
+            config.top_p = top_p
+            config.top_k = top_k
+            config.do_sample = temperature > 0.0
+            config.max_new_tokens = max_new_tokens
+            config.repetition_penalty = repetition_penalty
+        else:
+            config = ov_genai.GenerationConfig()
+            config.max_new_tokens = max_new_tokens
         history = history or []
         if not history:
             pipe.start_chat(system_message=start_message)
@@ -204,7 +312,7 @@ def stop_chat_and_clear_history(streamer):
                     submit = gr.Button("Submit")
                     stop = gr.Button("Stop")
                     clear = gr.Button("Clear")
-        with gr.Row():
+        with gr.Row(visible=not disable_advanced):
             with gr.Accordion("Advanced Options:", open=False):
                 with gr.Row():
                     with gr.Column():
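The new streamer is meant to be driven from a worker thread while the Gradio callback iterates it: generate() calls streamer.put(token_id) for each generated token, and streamer.end() enqueues the None sentinel that makes __next__ raise StopIteration. A minimal sketch of that pattern outside the demo, assuming the ChunkStreamer defined above and a model already exported for openvino_genai into a hypothetical model_dir (the directory name and prompt are placeholders, not part of this change):

from threading import Thread

import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("model_dir", "CPU")  # hypothetical export directory
streamer = ChunkStreamer(pipe.get_tokenizer(), tokens_len=4)

config = ov_genai.GenerationConfig()
config.max_new_tokens = 256


def worker():
    # generate() feeds token ids into streamer.put() and calls streamer.end() when finished,
    # which unblocks the loop below and eventually stops the iteration.
    pipe.generate("What is OpenVINO?", config, streamer)


Thread(target=worker).start()

for chunk in streamer:  # blocks on the internal queue until the next decoded chunk arrives
    print(chunk, end="", flush=True)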