"""
This script is adapted from the Streamlit conversational-app tutorial and the interactive
generation code of ChatGLM2 and transformers; part of the logic was modified to fit the
generation of our model. Please refer to the links below for more information:
    1. Streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
    2. ChatGLM2: https://github.com/THUDM/ChatGLM2-6B
    3. transformers: https://github.com/huggingface/transformers
Run it with `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`.
Running it directly with `python path/to/web_demo.py` is not supported and may cause unexpected problems.
"""
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
from openxlab.model import download

logger = logging.get_logger(__name__)
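
# Download the EmoLLM-aiwei weights from OpenXLab into the local `model` directory
# (loaded by `load_model` below).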
download(model_repo='ajupyter/EmoLLM_aiwei',
         output='model')

@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005


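# Streaming generator: runs a manual decoding loop and yields the partial response
# after each new token so the Streamlit UI can render output incrementally.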
@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
    input_length = len(inputs["input_ids"][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs["input_ids"]
    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warn(  # pylint: disable=W4902
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

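    # NOTE: the `_get_*` helpers below are private transformers APIs and may change
    # between transformers releases.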
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
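    # Manual sampling loop: stops when an EOS token is produced or a stopping
    # criterion (e.g. max_length) is met.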
    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages


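# Cache the model and tokenizer across Streamlit reruns so they are only loaded once.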
@st.cache_resource
def load_model():
    model = (
        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:
        # Use Streamlit's markdown function to add Markdown text
        st.image('assets/aiwei_logo.jpg', width=1, caption='EmoLLM-aiwei AI Logo', use_column_width=True)
        st.markdown("[访问 EmoLLM 官方repo](https://github.com/aJupyter/EmoLLM)")

        max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
        top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
        st.button("Clear Chat History", on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)

    return generation_config


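# ChatML-style prompt templates; the <|im_start|>/<|im_end|> markers follow the
# InternLM2 chat format.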
user_prompt = "<|im_start|>user\n{user}<|im_end|>\n"
robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"


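# Flatten the Streamlit chat history plus the new user prompt into a single prompt
# string, prepending the system persona instruction.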
def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = (
        "你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n"
    )
    total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.format(user=cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main():
    # torch.cuda.empty_cache()
    print("load model begin.")
    model, tokenizer = load_model()
    print("load model end.")

    user_avator = "assets/user.png"
    robot_avator = "assets/robot.jpeg"

    st.title("EmoLLM-温柔御姐艾薇(aiwei)")

    generation_config = prepare_generation_config()

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Display user message in chat message container
        with st.chat_message("user", avatar=user_avator):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})

        with st.chat_message("robot", avatar=robot_avator):
            message_placeholder = st.empty()
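            # Stream the reply token by token. 92542 is assumed to be the <|im_end|>
            # token id of the underlying InternLM2 tokenizer, used as an extra EOS.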
            for cur_response in generate_interactive(
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                additional_eos_token_id=92542,
                **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + "▌")
            message_placeholder.markdown(cur_response)  # pylint: disable=undefined-loop-variable
        # Add robot response to chat history
        st.session_state.messages.append(
            {
                "role": "robot",
                "content": cur_response,  # pylint: disable=undefined-loop-variable
                "avatar": robot_avator,
            }
        )
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()