
Commit c696e16

feat: Update Aiwei configuration.
2 parents: 41d2b8c + ddd8f79

File tree

5 files changed: +488 additions, −2 deletions

README.md

+1-1
@@ -60,7 +60,7 @@
 - Assessment and diagnostic tools: effectively promoting mental health requires scientific instruments to assess an individual's psychological state and to diagnose possible mental health problems.

 ### Recent Updates
-
+- 【2024.2.23】Released `温柔御姐心理医生艾薇` (Aiwei, the gentle big-sister counselor), based on InternLM2_7B_chat_qlora. [Get the model weights](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_aiwei), [config file](xtuner_config/aiwei-internlm2_chat_7b_qlora.py)
 - 【2024.2.23】Updated [several fine-tuning configs](/xtuner_config/) and added [data_pro.json](/datasets/data_pro.json) (more samples, broader and richer scenarios) and [aiwei.json](/datasets/aiwei.json) (dedicated to the gentle big-sister role-play persona, with Emoji); `温柔御姐心理医生艾薇` coming soon
 - 【2024.2.18】[Open-sourced a full-parameter fine-tuned version based on Qwen1_5-0_5B-Chat](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary); friends with limited compute can give it a try~
 - 【2024.2.6】EmoLLM has reached 18.7k downloads on the [**OpenXLab**](https://openxlab.org.cn/models/detail/jujimeizuo/EmoLLM_Model) platform, everyone is welcome to try it!
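
As a side note to the aiwei release announced above, here is a minimal sketch (not part of this commit) of fetching the published weights and loading them locally. It mirrors the openxlab download call used in web_demo-aiwei.py further down; the local directory name "model" and the bfloat16/GPU placement simply follow that demo.

# Minimal sketch (assumption: openxlab and transformers are installed and a CUDA GPU is available).
# Mirrors the calls in web_demo-aiwei.py below; "model" is the output directory that demo expects.
import torch
from openxlab.model import download
from transformers import AutoModelForCausalLM, AutoTokenizer

download(model_repo='ajupyter/EmoLLM_aiwei', output='model')

tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
model = (
    AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
    .to(torch.bfloat16)
    .cuda()
)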

app.py

+2-1
@@ -1,2 +1,3 @@
 import os
-os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
+# os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
+os.system('streamlit run web_demo-aiwei.py --server.address=0.0.0.0 --server.port 7860')

assets/aiwei_logo.jpg

79.6 KB

web_demo-aiwei.py

+267
@@ -0,0 +1,267 @@
"""
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers.
We mainly modified part of the code logic to adapt to the generation of our model.
Please refer to these links below for more information:
    1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
    2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
    3. transformers: https://github.com/huggingface/transformers
Please run with the command `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`.
Using `python path/to/web_demo.py` may cause unknown problems.
"""
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
from openxlab.model import download

logger = logging.get_logger(__name__)

# Download the fine-tuned aiwei weights from OpenXLab into the local "model" directory.
download(model_repo='ajupyter/EmoLLM_aiwei',
         output='model')


@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005


@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
    input_length = len(inputs["input_ids"][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs["input_ids"]
    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warn(  # pylint: disable=W4902
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                UserWarning,
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages


@st.cache_resource
def load_model():
    model = (
        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:
        # Sidebar header: logo image and a Markdown link, rendered with Streamlit
        st.image('assets/aiwei_logo.jpg', width=1, caption='EmoLLM-aiwei AI Logo', use_column_width=True)
        st.markdown("[访问 EmoLLM 官方repo](https://github.com/aJupyter/EmoLLM)")

        max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
        top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
        st.button("Clear Chat History", on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)

    return generation_config


user_prompt = "<|im_start|>user\n{user}<|im_end|>\n"
robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"


def combine_history(prompt):
    messages = st.session_state.messages
    # System prompt (Chinese): you are Aiwei, a gentle, psychology-savvy "big sister";
    # answer with professional knowledge in a warm, playful tone and sprinkle in some Emoji.
    meta_instruction = (
        "你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n"
    )
    total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.format(user=cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main():
    # torch.cuda.empty_cache()
    print("load model begin.")
    model, tokenizer = load_model()
    print("load model end.")

    user_avator = "assets/user.png"
    robot_avator = "assets/robot.jpeg"

    st.title("EmoLLM-温柔御姐艾薇(aiwei)")

    generation_config = prepare_generation_config()

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Display user message in chat message container
        with st.chat_message("user", avatar=user_avator):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})

        with st.chat_message("robot", avatar=robot_avator):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                additional_eos_token_id=92542,
                **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + "▌")
            message_placeholder.markdown(cur_response)  # pylint: disable=undefined-loop-variable
        # Add robot response to chat history
        st.session_state.messages.append(
            {
                "role": "robot",
                "content": cur_response,  # pylint: disable=undefined-loop-variable
                "avatar": robot_avator,
            }
        )
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
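
For quick testing outside Streamlit, a minimal sketch (not part of this commit) of querying the same weights directly with transformers' generate(). It reuses the chat template, sampling settings, and the extra end-of-turn token id (92542) from web_demo-aiwei.py above, and assumes the weights have already been downloaded to ./model; the user message is only a placeholder.

# Minimal sketch: one-shot query against the aiwei weights without the Streamlit UI.
# Assumes ./model already contains the downloaded weights (see download() above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True).to(torch.bfloat16).cuda()

# Same persona system prompt as combine_history() in the demo (abridged here); the user turn is made up.
system = "你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇……"
user = "最近总是睡不好,怎么办?"
prompt = (
    f"<s><|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"
)

inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
output = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=True,
    top_p=0.8,
    temperature=0.8,
    repetition_penalty=1.005,
    eos_token_id=[tokenizer.eos_token_id, 92542],  # 92542 is the extra end-of-turn id used by the demo
)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))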
