@@ -36,13 +36,14 @@ class LibChatLLM:
36
36
_obj2id = {}
37
37
_id2obj = {}
38
38
39
- def __init__ (self , lib : str = '' , model_storage : str = '' ) -> None :
39
+ def __init__ (self , lib : str = '' , model_storage : str = '' , init_params : list [ str ] = [] ) -> None :
40
40
41
41
if lib == '' :
42
42
lib = os .path .dirname (os .path .abspath (sys .argv [0 ]))
43
43
self ._lib_path = lib
44
44
self .model_storage = os .path .abspath (model_storage if model_storage != '' else os .path .join (lib , '..' , 'quantized' ))
45
45
46
+ init_params = ['--ggml_dir' , lib ] + init_params
46
47
lib = os .path .join (lib , 'libchatllm.' )
47
48
if sys .platform == 'win32' :
48
49
lib = lib + 'dll'
@@ -60,6 +61,17 @@ def __init__(self, lib: str = '', model_storage: str = '') -> None:
60
61
self ._PRINTFUNC = CFUNCTYPE (None , c_void_p , c_int , c_char_p )
61
62
self ._ENDFUNC = CFUNCTYPE (None , c_void_p )
62
63
64
+ _chatllm_append_init_param = self ._lib .chatllm_append_init_param
65
+ _chatllm_append_init_param .restype = None
66
+ _chatllm_append_init_param .argtypes = [c_char_p ]
67
+ _chatllm_init = self ._lib .chatllm_init
68
+ _chatllm_init .restype = c_int
69
+ _chatllm_init .argtypes = []
70
+
71
+ for s in init_params :
72
+ _chatllm_append_init_param (c_char_p (s .encode ()))
73
+ assert _chatllm_init () == 0
74
+
63
75
self ._chatllm_create = self ._lib .chatllm_create
64
76
self ._chatllm_append_param = self ._lib .chatllm_append_param
65
77
self ._chatllm_start = self ._lib .chatllm_start
@@ -170,8 +182,12 @@ def callback_print(user_data: int, print_type: c_int, s: bytes) -> None:
170
182
obj .callback_text_tokenize (txt )
171
183
elif print_type == PrintType .PRINTLN_ERROR .value :
172
184
raise Exception (txt )
185
+ elif print_type == PrintType .PRINTLN_LOGGING .value :
186
+ obj .callback_print_log (txt )
173
187
elif print_type == PrintType .PRINTLN_BEAM_SEARCH .value :
174
188
obj .callback_print_beam_search (txt )
189
+ elif print_type == PrintType .PRINT_EVT_ASYNC_COMPLETED .value :
190
+ obj .callback_async_done ()
175
191
else :
176
192
raise Exception (f"unhandled print_type({ print_type } ): { txt } " )
177
193
@@ -372,6 +388,8 @@ def load_session(self, file_name: str) -> str:
372
388
373
389
def callback_print_reference (self , s : str ) -> None :
374
390
self .references .append (s )
391
+ def callback_print_log (self , s : str ) -> None :
392
+ print (s )
375
393
376
394
def callback_print_beam_search (self , s : str ) -> None :
377
395
l = s .split (',' , maxsplit = 1 )
0 commit comments