@@ -36,13 +36,14 @@ class LibChatLLM:
3636 _obj2id = {}
3737 _id2obj = {}
3838
39- def __init__ (self , lib : str = '' , model_storage : str = '' ) -> None :
39+ def __init__ (self , lib : str = '' , model_storage : str = '' , init_params : list [ str ] = [] ) -> None :
4040
4141 if lib == '' :
4242 lib = os .path .dirname (os .path .abspath (sys .argv [0 ]))
4343 self ._lib_path = lib
4444 self .model_storage = os .path .abspath (model_storage if model_storage != '' else os .path .join (lib , '..' , 'quantized' ))
4545
46+ init_params = ['--ggml_dir' , lib ] + init_params
4647 lib = os .path .join (lib , 'libchatllm.' )
4748 if sys .platform == 'win32' :
4849 lib = lib + 'dll'
@@ -60,6 +61,17 @@ def __init__(self, lib: str = '', model_storage: str = '') -> None:
6061 self ._PRINTFUNC = CFUNCTYPE (None , c_void_p , c_int , c_char_p )
6162 self ._ENDFUNC = CFUNCTYPE (None , c_void_p )
6263
64+ _chatllm_append_init_param = self ._lib .chatllm_append_init_param
65+ _chatllm_append_init_param .restype = None
66+ _chatllm_append_init_param .argtypes = [c_char_p ]
67+ _chatllm_init = self ._lib .chatllm_init
68+ _chatllm_init .restype = c_int
69+ _chatllm_init .argtypes = []
70+
71+ for s in init_params :
72+ _chatllm_append_init_param (c_char_p (s .encode ()))
73+ assert _chatllm_init () == 0
74+
6375 self ._chatllm_create = self ._lib .chatllm_create
6476 self ._chatllm_append_param = self ._lib .chatllm_append_param
6577 self ._chatllm_start = self ._lib .chatllm_start
@@ -170,8 +182,12 @@ def callback_print(user_data: int, print_type: c_int, s: bytes) -> None:
170182 obj .callback_text_tokenize (txt )
171183 elif print_type == PrintType .PRINTLN_ERROR .value :
172184 raise Exception (txt )
185+ elif print_type == PrintType .PRINTLN_LOGGING .value :
186+ obj .callback_print_log (txt )
173187 elif print_type == PrintType .PRINTLN_BEAM_SEARCH .value :
174188 obj .callback_print_beam_search (txt )
189+ elif print_type == PrintType .PRINT_EVT_ASYNC_COMPLETED .value :
190+ obj .callback_async_done ()
175191 else :
176192 raise Exception (f"unhandled print_type({ print_type } ): { txt } " )
177193
@@ -372,6 +388,8 @@ def load_session(self, file_name: str) -> str:
372388
373389 def callback_print_reference (self , s : str ) -> None :
374390 self .references .append (s )
391+ def callback_print_log (self , s : str ) -> None :
392+ print (s )
375393
376394 def callback_print_beam_search (self , s : str ) -> None :
377395 l = s .split (',' , maxsplit = 1 )
0 commit comments