Commit c57b25b

ctransformers: another attempt

Generalized ctransformers based on oobabooga#2892. Credits to randoentity.

1 parent 6c521ce · commit c57b25b

7 files changed: 154 additions, 9 deletions

modules/ctransformers_model.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+from ctransformers import AutoModelForCausalLM
+from ctransformers import AutoConfig
+
+from modules import shared
+from modules.callbacks import Iteratorize
+from modules.logging_colors import logger
+
+class CtransformersModel:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(self, path):
+        result = self()
+        stops = shared.settings['custom_stopping_strings']
+        stops.append("<|end|>")
+
+        # ctransformers uses -1 for random seed
+        config = AutoConfig.from_pretrained(
+            str(path),
+            stop=stops,
+            threads=shared.args.threads,
+            gpu_layers=shared.args.n_gpu_layers,
+            batch_size=shared.args.n_batch,
+            stream=not shared.args.no_stream,
+            seed=(-1 if shared.args.llama_cpp_seed == 0 else shared.args.llama_cpp_seed)
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            str(result.model_dir(path) if result.model_type_is_auto() else path),
+            model_type=(None if result.model_type_is_auto() else shared.args.model_type),
+            config=config
+        )
+        logger.info(f'Using ctransformers model_type: {self.model.model_type} for {self.model.model_path}')
+        return result, result
+
+    def model_type_is_auto(self):
+        return shared.args.model_type == "Auto" or shared.args.model_type == "None"
+
+    def model_dir(self, path):
+        if path.is_file():
+            return path.parent
+        return path
+
+    def encode(self, string, **kwargs):
+        return self.model.tokenize(string)
+
+    def decode(self, ids):
+        return self.model.detokenize(ids)
+
+
+    def generate(self, prompt, state, callback=None):
+        prompt = prompt if type(prompt) is str else prompt.decode()
+        generator = self.model._stream(
+            prompt=prompt,
+            max_new_tokens=state['max_new_tokens'],
+            temperature=state['temperature'],
+            top_p=state['top_p'],
+            top_k=state['top_k'],
+            repetition_penalty=state['repetition_penalty'],
+            threads=shared.args.threads
+        )
+
+        output = ""
+        for token in generator:
+            if callback:
+                callback(token)
+            output += token
+        return output
+
+
+    def generate_with_streaming(self, *args, **kwargs):
+        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
+            reply = ''
+            for token in generator:
+                reply += token
+                yield reply

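For orientation, a minimal sketch of how this wrapper is driven from the rest of the webui: from_pretrained returns the same object twice because the wrapper serves as both model and tokenizer. The weights path and the sampling values below are placeholders, and shared.args / shared.settings are assumed to already be populated by the usual startup code.

from pathlib import Path
from modules.ctransformers_model import CtransformersModel

# Placeholder GGML weights path under the models directory.
model, tokenizer = CtransformersModel.from_pretrained(Path('models/llama-7b.ggmlv3.q4_0.bin'))

# Sampling parameters normally come from the UI state; these values are illustrative.
state = {
    'max_new_tokens': 200,
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 40,
    'repetition_penalty': 1.15,
}

# generate_with_streaming yields the cumulative reply as tokens arrive.
for partial_reply in model.generate_with_streaming('Write a haiku about GPUs.', state):
    print(partial_reply)
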
modules/loaders.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@
8686
'compress_pos_emb',
8787
'alpha_value',
8888
'exllama_HF_info',
89+
],
90+
'ctransformers': [
91+
'n_ctx',
92+
'n_gpu_layers',
93+
'n_batch',
94+
'threads',
95+
'no_mmap',
96+
'mlock',
97+
'model_type',
98+
'llama_cpp_seed',
8999
]
90100
}
91101

@@ -238,6 +248,13 @@
238248
'add_bos_token',
239249
'skip_special_tokens',
240250
},
251+
'ctransformers': {
252+
'temperature',
253+
'top_p',
254+
'top_k',
255+
'repetition_penalty',
256+
'seed'
257+
}
241258
}
242259

243260

@@ -258,6 +275,30 @@ def blacklist_samplers(loader):
258275
else:
259276
return [gr.update(visible=True) if sampler in loaders_samplers[loader] else gr.update(visible=False) for sampler in all_samplers]
260277

278+
model_loader_type_table = {
279+
'GPTQ-for-LLaMa': [
280+
"None",
281+
"llama",
282+
"opt",
283+
"gptj"
284+
],
285+
'ctransformers': [
286+
"None",
287+
"gptj",
288+
"gpt_neox",
289+
"llama",
290+
"mpt",
291+
"dolly-v2"
292+
"replit",
293+
"starcoder",
294+
"falcon"
295+
],
296+
}
297+
298+
def model_loader_type(loader):
299+
if loader in model_loader_type_table:
300+
return model_loader_type_table[loader]
301+
return ["None"]
261302

262303
def get_gpu_memory_keys():
263304
return [k for k in shared.gradio if k.startswith('gpu_memory')]

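A quick sketch of the new model_loader_type helper's behaviour, which server.py uses below to repopulate the model_type dropdown; the return values in the comments are read off the table above:

from modules import loaders

loaders.model_loader_type('ctransformers')    # ["None", "gptj", "gpt_neox", "llama", "mpt", "dolly-v2", "replit", "starcoder", "falcon"]
loaders.model_loader_type('GPTQ-for-LLaMa')   # ["None", "llama", "opt", "gptj"]
loaders.model_loader_type('llama.cpp')        # ["None"] - any loader without a table entry falls back to ["None"]
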
modules/models.py

Lines changed: 21 additions & 1 deletion
@@ -58,7 +58,8 @@ def load_model(model_name, loader=None):
         'llamacpp_HF': llamacpp_HF_loader,
         'RWKV': RWKV_loader,
         'ExLlama': ExLlama_loader,
-        'ExLlama_HF': ExLlama_HF_loader
+        'ExLlama_HF': ExLlama_HF_loader,
+        'ctransformers': CtransformersModel_loader,
     }
 
     p = Path(model_name)

@@ -263,6 +264,25 @@ def llamacpp_HF_loader(model_name):
     return model, tokenizer
 
 
+def CtransformersModel_loader(model_name):
+    from modules.ctransformers_model import CtransformersModel
+
+    path = Path(f'{shared.args.model_dir}/{model_name}')
+    logger.info(f'ctransformers loading: {path}\n')
+    ctrans = CtransformersModel()
+    if ctrans.model_type_is_auto():
+        model_file = path
+    else:
+        if path.is_file():
+            model_file = path
+        else:
+            model_file = list(
+                Path(f'{shared.args.model_dir}/{model_name}').glob('*.bin')
+            )[0]
+    logger.info(f'ctransformers weights detected: {model_file}\n')
+    model, tokenizer = ctrans.from_pretrained(model_file)
+    return model, tokenizer
+
 def GPTQ_loader(model_name):
 
     # Monkey patch

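A hedged usage sketch of the new loader path: the folder name below is hypothetical and is resolved against shared.args.model_dir; when model_type is not "Auto"/"None" and a directory is given, the first *.bin file inside it is picked.

from modules import shared
from modules.models import load_model

# Hypothetical model folder (or single GGML .bin file) under shared.args.model_dir.
shared.model, shared.tokenizer = load_model('mpt-7b-ggml', loader='ctransformers')
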
modules/shared.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@
     'autoload_model': False,
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
-    'max_new_tokens_max': 4096,
+    'max_new_tokens_max': 8000,
     'seed': -1,
     'character': 'None',
     'name1': 'You',

modules/text_generation.py

Lines changed: 11 additions & 5 deletions
@@ -34,9 +34,10 @@ def generate_reply(*args, **kwargs):
 def get_max_prompt_length(state):
     return state['truncation_length'] - state['max_new_tokens']
 
-
+encode_llama_prompts = ['LlamaCppModel', 'RWKVModel', 'CtransformersModel']
+encode_llama_truncation = ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel']
 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
+    if shared.model.__class__.__name__ in encode_llama_prompts:
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
         return input_ids

@@ -51,7 +52,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in encode_llama_truncation or shared.args.cpu:
         return input_ids
     elif shared.args.deepspeed:
         return input_ids.to(device=local_rank)

@@ -169,7 +170,12 @@ def apply_stopping_strings(reply, all_stop_strings):
 
     return reply, stop_found
 
-
+_generate_reply_use_custom = [
+    'LlamaCppModel',
+    'RWKVModel',
+    'ExllamaModel',
+    'CtransformersModel'
+]
 def _generate_reply(question, state, stopping_strings=None, is_chat=False):
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:

@@ -178,7 +184,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
         yield ''
         return
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
+    if shared.model.__class__.__name__ in _generate_reply_use_custom:
         generate_func = generate_reply_custom
     else:
         generate_func = generate_reply_HF

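The net effect of these lists is that CtransformersModel is routed through the same non-HF code paths as llama.cpp and RWKV: encode() returns a (1, n) numpy array rather than a torch tensor, and generation goes through generate_reply_custom. A small sketch of that encode path, assuming a ctransformers model is currently loaded:

import numpy as np
from modules import shared

prompt = "Hello there"
# For loaders listed in encode_llama_prompts, the model's own tokenizer is used and the
# token ids are reshaped into the (1, n) layout the rest of the pipeline expects.
input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids))
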
requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -31,3 +31,4 @@ https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.77/llama_cpp_
 # llama-cpp-python with CUDA support
 https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.1.77+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.1.77+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
+https://github.com/jllllll/ctransformers-cuBLAS-wheels/releases/download/AVX2/ctransformers-0.2.16+cu117-py3-none-any.whl

server.py

Lines changed: 3 additions & 2 deletions
@@ -204,7 +204,7 @@ def create_model_menus():
 
     with gr.Row():
         with gr.Column():
-            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value=None)
+            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys(), value=None)
             with gr.Box():
                 with gr.Row():
                     with gr.Column():

@@ -225,7 +225,7 @@ def create_model_menus():
 
                         shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
                         shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
-                        shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
+                        shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None"], value=shared.args.model_type or "None")
                         shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
                         shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
                         shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')

@@ -267,6 +267,7 @@ def create_model_menus():
             shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
 
     shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()))
+    shared.gradio['loader'].change(fn=lambda value: gr.update(choices=loaders.model_loader_type(value)), inputs=shared.gradio['loader'], outputs=shared.gradio['model_type'])
 
     # In this event handler, the interface state is read and updated
     # with the model defaults (if any), and then the model is loaded

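The same dropdown-refresh pattern in isolation, as a standalone Gradio sketch; the widget variable names here are illustrative, not the webui's:

import gradio as gr
from modules import loaders

with gr.Blocks() as demo:
    loader = gr.Dropdown(label="Model loader", choices=list(loaders.loaders_and_params.keys()))
    model_type = gr.Dropdown(label="model_type", choices=["None"], value="None")
    # Whenever the loader selection changes, repopulate model_type with that loader's valid types.
    loader.change(fn=lambda value: gr.update(choices=loaders.model_loader_type(value)),
                  inputs=loader, outputs=model_type)

demo.launch()
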