
Commit 2e77636

Commit message: update

1 parent e1c66a6 · commit 2e77636

File tree

1 file changed: +184 −0 lines

  • trustrag/modules/generator


trustrag/modules/generator/llm.py

Lines changed: 184 additions & 0 deletions
@@ -180,3 +180,187 @@ def load_model(self):
        print("load model success")

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: quincy qiang
@license: Apache Licence
@file: llm.py
@time: 2024/05/16
@contact: yanqiangmiffy@gamil.com
@software: PyCharm
@description: coding..
"""
import os
from typing import Dict, List, Any

import torch
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM
from trustrag.modules.prompt.templates import SYSTEM_PROMPT, CHAT_PROMPT_TEMPLATES


class BaseModel:
    def __init__(self, path: str = '') -> None:
        self.path = path

    def chat(self, prompt: str, history: List[dict], content: str) -> str:
        pass

    def load_model(self):
        pass


class OpenAIChat(BaseModel):
    def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None:
        super().__init__(path)
        self.model = model

    def chat(self, prompt: str, history: List[dict], content: str) -> str:
        client = OpenAI()
        client.api_key = os.getenv("OPENAI_API_KEY")
        client.base_url = os.getenv("OPENAI_BASE_URL")
        history.append({'role': 'user',
                        'content': CHAT_PROMPT_TEMPLATES['RAG_PROMPT_TEMPALTE'].format(question=prompt,
                                                                                        context=content)})
        response = client.chat.completions.create(
            model=self.model,
            messages=history,
            max_tokens=150,
            temperature=0.1
        )
        return response.choices[0].message.content


class InternLMChat(BaseModel):
    def __init__(self, path: str = '') -> None:
        super().__init__(path)
        self.load_model()

    def chat(self, prompt: str, history: List = [], content: str = '') -> str:
        prompt = CHAT_PROMPT_TEMPLATES['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
        response, history = self.model.chat(self.tokenizer, prompt, history)
        return response

    def load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16,
                                                          trust_remote_code=True).cuda()


class GLM3Chat(BaseModel):
    def __init__(self, path: str = '') -> None:
        super().__init__(path)
        self.load_model()

    def chat(self, prompt: str, history=None, content: str = '', llm_only: bool = False) -> tuple[Any, Any]:
        if history is None:
            history = []
        if llm_only:
            prompt = prompt
        else:
            prompt = CHAT_PROMPT_TEMPLATES['GLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
        response, history = self.model.chat(self.tokenizer, prompt, history, max_length=32000, num_beams=1,
                                            do_sample=True, top_p=0.8, temperature=0.2)
        return response, history

    def load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16,
                                                          trust_remote_code=True).cuda()


class GLM4Chat(BaseModel):
    def __init__(self, path: str = '') -> None:
        super().__init__(path)
        self.load_model()

    def chat(self, prompt: str, history=None, content: str = '', llm_only: bool = False) -> tuple[Any, Any]:
        if llm_only:
            prompt = prompt
        else:
            prompt = CHAT_PROMPT_TEMPLATES['GLM_PROMPT_TEMPALTE'].format(system_prompt=SYSTEM_PROMPT, question=prompt,
                                                                         context=content)
        prompt = prompt.encode("utf-8", 'ignore').decode('utf-8', 'ignore')
        print(prompt)

        inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}],
                                                    add_generation_prompt=True,
                                                    tokenize=True,
                                                    return_tensors="pt",
                                                    return_dict=True
                                                    )

        inputs = inputs.to('cuda')
        gen_kwargs = {"max_length": 5120, "do_sample": False, "top_k": 1}
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response, history = output, []
            return response, history

    def load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        ).cuda().eval()


class Qwen3Chat(BaseModel):
    def __init__(self, path: str = '') -> None:
        super().__init__(path)
        self.load_model()
        self.device = 'cuda'

    def chat(self, prompt: str, history: List = [], content: str = '', llm_only: bool = False,
             enable_thinking: bool = True) -> tuple[Any, Any]:
        if llm_only:
            prompt = prompt
        else:
            # Use a suitable prompt template; adjust as needed
            prompt = CHAT_PROMPT_TEMPLATES.get('DF_QWEN_PROMPT_TEMPLATE2', '{question}\n\n上下文:{context}').format(
                question=prompt, context=content)

        messages = [
            {"role": "user", "content": prompt}
        ]

        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking  # toggle Qwen3's thinking mode
        )

        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        # Generate text; allow a large number of new tokens
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=32768,  # allow long generations
            do_sample=False,
            top_k=10
        )

        # Keep only the newly generated part of the sequence
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        response = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        return response, history

    def load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            torch_dtype="auto",  # let transformers pick the best dtype automatically
            device_map="auto",  # automatic device placement
            trust_remote_code=True
        )
        print("Qwen3 model loaded successfully")

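Usage note: a minimal sketch of how the OpenAI-backed class added in this commit could be called. It assumes OPENAI_API_KEY (and optionally OPENAI_BASE_URL) are already set in the environment and that the trustrag package is importable; the question and context strings below are placeholders, not part of the commit.

import os
from trustrag.modules.generator.llm import OpenAIChat

# OPENAI_API_KEY (and optionally OPENAI_BASE_URL) must already be exported
assert os.getenv("OPENAI_API_KEY"), "set OPENAI_API_KEY before running"

llm = OpenAIChat(model="gpt-3.5-turbo-1106")
answer = llm.chat(
    prompt="What does the generator module do?",   # user question (placeholder)
    history=[],                                    # chat() appends the formatted RAG prompt here
    content="Retrieved passages go here.",         # retriever output (placeholder)
)
print(answer)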
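Similarly, a sketch of driving the new Qwen3Chat class with a local checkpoint; the model path is a placeholder, and a CUDA-capable GPU plus either a DF_QWEN_PROMPT_TEMPLATE2 entry in CHAT_PROMPT_TEMPLATES or the built-in fallback template are assumed.

from trustrag.modules.generator.llm import Qwen3Chat

qwen = Qwen3Chat(path="/models/Qwen3-8B")        # placeholder path to a local Qwen3 checkpoint
response, history = qwen.chat(
    prompt="Summarize the retrieved passages.",  # user question (placeholder)
    history=[],
    content="passage 1 ... passage 2 ...",       # retrieved context (placeholder)
    enable_thinking=False,                       # skip the thinking block for shorter answers
)
print(response)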