@@ -180,3 +180,187 @@ def load_model(self):
         print("load model success")
 
 
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: quincy qiang
+@license: Apache Licence
+@file: llm.py
+@time: 2024/05/16
+@contact: yanqiangmiffy@gamil.com
+@software: PyCharm
+@description: Chat wrappers for OpenAI, InternLM, GLM3/GLM4 and Qwen3 models.
+"""
+import os
+from typing import Any, List, Optional
+
+import torch
+from openai import OpenAI
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trustrag.modules.prompt.templates import SYSTEM_PROMPT, CHAT_PROMPT_TEMPLATES
+
+
+class BaseModel:
+    """Minimal interface shared by the chat-model wrappers below."""
+
+    def __init__(self, path: str = '') -> None:
+        self.path = path
+
+    def chat(self, prompt: str, history: List[dict], content: str) -> str:
+        pass
+
+    def load_model(self):
+        pass
+
+
+class OpenAIChat(BaseModel):
+    """Wrapper around the OpenAI chat completions API."""
+
+    def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None:
+        super().__init__(path)
+        self.model = model
+
+    def chat(self, prompt: str, history: List[dict], content: str) -> str:
+        # Read the API key and (optional) custom endpoint from the environment.
+        client = OpenAI(
+            api_key=os.getenv("OPENAI_API_KEY"),
+            base_url=os.getenv("OPENAI_BASE_URL"),
+        )
+        history.append({'role': 'user',
+                        'content': CHAT_PROMPT_TEMPLATES['RAG_PROMPT_TEMPALTE'].format(question=prompt,
+                                                                                       context=content)})
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=history,
+            max_tokens=150,
+            temperature=0.1
+        )
+        return response.choices[0].message.content
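+
+# Usage sketch (illustrative; assumes OPENAI_API_KEY / OPENAI_BASE_URL are exported and that
+# CHAT_PROMPT_TEMPLATES defines the 'RAG_PROMPT_TEMPALTE' key used above):
+#   llm = OpenAIChat(model="gpt-3.5-turbo-1106")
+#   answer = llm.chat("What is RAG?", history=[], content="RAG combines retrieval with generation.")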
+
+
+class InternLMChat(BaseModel):
+    """Wrapper around a locally loaded InternLM chat model."""
+
+    def __init__(self, path: str = '') -> None:
+        super().__init__(path)
+        self.load_model()
+
+    def chat(self, prompt: str, history: Optional[List] = None, content: str = '') -> str:
+        # Avoid a mutable default argument for the running history.
+        if history is None:
+            history = []
+        prompt = CHAT_PROMPT_TEMPLATES['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
+        response, history = self.model.chat(self.tokenizer, prompt, history)
+        return response
+
+    def load_model(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16,
+                                                          trust_remote_code=True).cuda()
+
+
+class GLM3Chat(BaseModel):
+    """Wrapper around a locally loaded ChatGLM3 model."""
+
+    def __init__(self, path: str = '') -> None:
+        super().__init__(path)
+        self.load_model()
+
+    def chat(self, prompt: str, history=None, content: str = '', llm_only: bool = False) -> tuple[Any, Any]:
+        if history is None:
+            history = []
+        if not llm_only:
+            # Wrap the question and retrieved context in the RAG prompt template.
+            prompt = CHAT_PROMPT_TEMPLATES['GLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
+        response, history = self.model.chat(self.tokenizer, prompt, history, max_length=32000, num_beams=1,
+                                            do_sample=True, top_p=0.8, temperature=0.2)
+        return response, history
+
+    def load_model(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16,
+                                                          trust_remote_code=True).cuda()
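+
+# Usage sketch (illustrative; the path is a placeholder for a local ChatGLM3 checkpoint):
+#   llm = GLM3Chat(path="/path/to/chatglm3-6b")
+#   answer, history = llm.chat("What is RAG?", content="RAG combines retrieval with generation.")
+#   follow_up, history = llm.chat("Summarize that in one sentence.", history=history, llm_only=True)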
+
+
+class GLM4Chat(BaseModel):
+    """Wrapper around a locally loaded GLM-4 model."""
+
+    def __init__(self, path: str = '') -> None:
+        super().__init__(path)
+        self.load_model()
+
+    def chat(self, prompt: str, history=None, content: str = '', llm_only: bool = False) -> tuple[Any, Any]:
+        if not llm_only:
+            prompt = CHAT_PROMPT_TEMPLATES['GLM_PROMPT_TEMPALTE'].format(system_prompt=SYSTEM_PROMPT, question=prompt,
+                                                                         context=content)
+        # Drop characters that cannot be encoded as UTF-8 before tokenization.
+        prompt = prompt.encode("utf-8", 'ignore').decode('utf-8', 'ignore')
+        print(prompt)
+
+        inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}],
+                                                    add_generation_prompt=True,
+                                                    tokenize=True,
+                                                    return_tensors="pt",
+                                                    return_dict=True
+                                                    )
+
+        inputs = inputs.to('cuda')
+        gen_kwargs = {"max_length": 5120, "do_sample": False, "top_k": 1}
+        with torch.no_grad():
+            outputs = self.model.generate(**inputs, **gen_kwargs)
+            # Keep only the newly generated tokens, i.e. strip the prompt part.
+            outputs = outputs[:, inputs['input_ids'].shape[1]:]
+            output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response, history = output, []
+        return response, history
+
+    def load_model(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.path,
+            torch_dtype=torch.bfloat16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
+        ).cuda().eval()
+
+
+class Qwen3Chat(BaseModel):
+    """Wrapper around a locally loaded Qwen3 chat model."""
+
+    def __init__(self, path: str = '') -> None:
+        super().__init__(path)
+        self.load_model()
+        self.device = 'cuda'
+
+    def chat(self, prompt: str, history: Optional[List] = None, content: str = '', llm_only: bool = False,
+             enable_thinking: bool = True) -> tuple[Any, Any]:
+        if history is None:
+            history = []
+        if not llm_only:
+            # Use the configured prompt template; adjust as needed.
+            prompt = CHAT_PROMPT_TEMPLATES.get('DF_QWEN_PROMPT_TEMPLATE2', '{question}\n\n上下文:{context}').format(
+                question=prompt, context=content)
+
+        messages = [
+            {"role": "user", "content": prompt}
+        ]
+
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=enable_thinking  # supports Qwen3's thinking mode
+        )
+
+        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+
+        # Generate text, allowing a large number of new tokens.
+        generated_ids = self.model.generate(
+            **model_inputs,
+            max_new_tokens=32768,  # support long generations
+            do_sample=False,
+            top_k=10
+        )
+
+        # Keep only the newly generated tokens.
+        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+        response = self.tokenizer.decode(output_ids, skip_special_tokens=True)
+
+        return response, history
+
+    def load_model(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.path,
+            torch_dtype="auto",  # let transformers pick the best dtype
+            device_map="auto",  # automatic device placement
+            trust_remote_code=True
+        )
+        print("Qwen3 model loaded successfully")
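+
+# Usage sketch (illustrative; the path is a placeholder for a local Qwen3 checkpoint):
+#   llm = Qwen3Chat(path="/path/to/Qwen3-8B")
+#   answer, _ = llm.chat("What is RAG?", content="RAG combines retrieval with generation.",
+#                        enable_thinking=False)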
+
+