1+ """
2+ This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers.
3+ We mainly modified part of the code logic to adapt to the generation of our model.
4+ Please refer to these links below for more information:
5+ 1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
6+ 2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
7+ 3. transformers: https://github.com/huggingface/transformers
8+ Please run with the command `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`.
9+ Using `python path/to/web_demo.py` may cause unknown problems.
10+ """
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
from openxlab.model import download

logger = logging.get_logger(__name__)

download(model_repo='ajupyter/EmoLLM_aiwei',
         output='model')

@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005

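# `generate_interactive` is a streaming text-generation helper: it runs the decoding loop
# manually (instead of calling `model.generate`) and yields the decoded partial response
# after every new token, so the Streamlit UI can render the reply as it is being written.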
@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
    input_length = len(inputs["input_ids"][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs["input_ids"]
    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
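    # Note: the tokenized inputs are moved to the GPU unconditionally, so this function
    # assumes the model already lives on a CUDA device (see `load_model` below).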
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warning(  # pylint: disable=W4902
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
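    # The loop below mirrors the sampling loop inside `transformers`' `generate`: run a forward
    # pass, apply the logits processors/warpers, pick the next token, append it to `input_ids`,
    # and yield the decoded text generated so far (excluding the prompt and any EOS token).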
    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break

def on_btn_click():
    del st.session_state.messages


@st.cache_resource
def load_model():
    model = (
        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model, tokenizer

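# The sidebar controls below feed a `GenerationConfig` dataclass; its fields are later
# unpacked with `asdict(...)` and passed as keyword arguments to `generate_interactive`.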
def prepare_generation_config():
    with st.sidebar:
        # Use Streamlit's markdown function to add Markdown text
        st.image('assets/aiwei_logo.jpg', width=1, caption='EmoLLM-aiwei AI Logo', use_column_width=True)
        st.markdown("[访问 EmoLLM 官方repo](https://github.com/aJupyter/EmoLLM)")

        max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
        top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
        st.button("Clear Chat History", on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)

    return generation_config

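# The templates below follow the ChatML-style chat format: each turn is wrapped in
# <|im_start|>{role} ... <|im_end|> markers, and `cur_query_prompt` ends with an open
# assistant turn so the model continues from there.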
user_prompt = "<|im_start|>user\n{user}<|im_end|>\n"
robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"

def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = (
        "你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n"
    )
    total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.format(user=cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt

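# Streamlit reruns this script from top to bottom on every user interaction; the chat history
# survives reruns via `st.session_state`, and `@st.cache_resource` keeps the model loaded.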
def main():
    # torch.cuda.empty_cache()
    print("load model begin.")
    model, tokenizer = load_model()
    print("load model end.")

    user_avatar = "assets/user.png"
    robot_avatar = "assets/robot.jpeg"

    st.title("EmoLLM-温柔御姐艾薇(aiwei)")

    generation_config = prepare_generation_config()

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Display user message in chat message container
        with st.chat_message("user", avatar=user_avatar):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avatar})

        with st.chat_message("robot", avatar=robot_avatar):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                additional_eos_token_id=92542,  # also stop on the chat end-of-turn token
                **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + "▌")
            message_placeholder.markdown(cur_response)  # pylint: disable=undefined-loop-variable
        # Add robot response to chat history
        st.session_state.messages.append(
            {
                "role": "robot",
                "content": cur_response,  # pylint: disable=undefined-loop-variable
                "avatar": robot_avatar,
            }
        )
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()