
Commit dd1bbb3

ctransformers: another attempt
Generalized ctransformers, based on oobabooga#2892. Credits to randoentity.
1 parent 8dbaa20 commit dd1bbb3

File tree: 7 files changed, +183 -41 lines

modules/ctransformers_model.py
modules/loaders.py
modules/models.py
modules/shared.py
modules/text_generation.py
modules/ui_model_menu.py
requirements.txt

modules/ctransformers_model.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+from ctransformers import AutoModelForCausalLM
+from ctransformers import AutoConfig
+
+from modules import shared
+from modules.callbacks import Iteratorize
+from modules.logging_colors import logger
+
+class CtransformersModel:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(self, path):
+        result = self()
+        stops = shared.settings['custom_stopping_strings']
+        stops.append("<|end|>")
+
+        # ctransformers uses -1 for random seed
+        config = AutoConfig.from_pretrained(
+            str(path),
+            stop=stops,
+            threads=shared.args.threads,
+            gpu_layers=shared.args.n_gpu_layers,
+            batch_size=shared.args.n_batch,
+            stream=not shared.args.no_stream,
+            seed=(-1 if shared.args.llama_cpp_seed == 0 else shared.args.llama_cpp_seed)
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            str(result.model_dir(path) if result.model_type_is_auto() else path),
+            model_type=(None if result.model_type_is_auto() else shared.args.model_type),
+            config=config
+        )
+        logger.info(f'Using ctransformers model_type: {self.model.model_type} for {self.model.model_path}')
+        return result, result
+
+    def model_type_is_auto(self):
+        return shared.args.model_type == "Auto" or shared.args.model_type == "None"
+
+    def model_dir(self, path):
+        if path.is_file():
+            return path.parent
+        return path
+
+    def encode(self, string, **kwargs):
+        return self.model.tokenize(string)
+
+    def decode(self, ids):
+        return self.model.detokenize(ids)
+
+
+    def generate(self, prompt, state, callback=None):
+        prompt = prompt if type(prompt) is str else prompt.decode()
+        generator = self.model._stream(
+            prompt=prompt,
+            max_new_tokens=state['max_new_tokens'],
+            temperature=state['temperature'],
+            top_p=state['top_p'],
+            top_k=state['top_k'],
+            repetition_penalty=state['repetition_penalty'],
+            threads=shared.args.threads
+        )
+
+        output = ""
+        for token in generator:
+            if callback:
+                callback(token)
+            output += token
+        return output
+
+
+    def generate_with_streaming(self, *args, **kwargs):
+        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
+            reply = ''
+            for token in generator:
+                reply += token
+                yield reply
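
For orientation, a minimal usage sketch of the new wrapper outside the web UI. It is not part of the commit: the model path is hypothetical, and it assumes shared.args and shared.settings have already been populated the way the UI and CLI flags normally populate them.

    # Illustrative sketch only -- not part of this commit.
    from pathlib import Path

    from modules.ctransformers_model import CtransformersModel

    # from_pretrained() returns the same object twice: the wrapper acts as both
    # model and tokenizer (encode/decode delegate to ctransformers).
    model, tokenizer = CtransformersModel.from_pretrained(
        Path('models/llama-7b.ggmlv3.q4_0.bin')  # hypothetical GGML file
    )

    state = {
        'max_new_tokens': 200,
        'temperature': 0.7,
        'top_p': 0.9,
        'top_k': 40,
        'repetition_penalty': 1.1,
    }

    # Blocking call: returns the full completion as one string.
    print(model.generate("Q: What is ctransformers?\nA:", state))

    # Streaming call: yields the reply as it grows.
    for partial in model.generate_with_streaming("Q: What is ctransformers?\nA:", state):
        print(partial)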

modules/loaders.py

Lines changed: 76 additions & 33 deletions
@@ -1,10 +1,43 @@
 import functools
+from collections import OrderedDict
 
 import gradio as gr
 
 from modules import shared
 
-loaders_and_params = {
+loaders_and_params = OrderedDict({
+    'Transformers': [
+        'cpu_memory',
+        'gpu_memory',
+        'trust_remote_code',
+        'load_in_8bit',
+        'bf16',
+        'cpu',
+        'disk',
+        'auto_devices',
+        'load_in_4bit',
+        'use_double_quant',
+        'quant_type',
+        'compute_dtype',
+        'trust_remote_code',
+        'alpha_value',
+        'compress_pos_emb',
+        'transformers_info'
+    ],
+    'ExLlama_HF': [
+        'gpu_split',
+        'max_seq_len',
+        'alpha_value',
+        'compress_pos_emb',
+        'exllama_HF_info',
+    ],
+    'ExLlama': [
+        'gpu_split',
+        'max_seq_len',
+        'alpha_value',
+        'compress_pos_emb',
+        'exllama_info',
+    ],
     'AutoGPTQ': [
         'triton',
         'no_inject_fused_attention',
@@ -59,39 +92,17 @@
         'cpu',
         'llamacpp_HF_info',
     ],
-    'Transformers': [
-        'cpu_memory',
-        'gpu_memory',
-        'trust_remote_code',
-        'load_in_8bit',
-        'bf16',
-        'cpu',
-        'disk',
-        'auto_devices',
-        'load_in_4bit',
-        'use_double_quant',
-        'quant_type',
-        'compute_dtype',
-        'trust_remote_code',
-        'alpha_value',
-        'compress_pos_emb',
-        'transformers_info'
-    ],
-    'ExLlama': [
-        'gpu_split',
-        'max_seq_len',
-        'alpha_value',
-        'compress_pos_emb',
-        'exllama_info',
-    ],
-    'ExLlama_HF': [
-        'gpu_split',
-        'max_seq_len',
-        'alpha_value',
-        'compress_pos_emb',
-        'exllama_HF_info',
+    'ctransformers': [
+        'n_ctx',
+        'n_gpu_layers',
+        'n_batch',
+        'threads',
+        'no_mmap',
+        'mlock',
+        'model_type',
+        'llama_cpp_seed',
     ]
-}
+})
 
 loaders_samplers = {
     'Transformers': {
@@ -256,6 +267,13 @@
         'skip_special_tokens',
         'auto_max_new_tokens',
     },
+    'ctransformers': {
+        'temperature',
+        'top_p',
+        'top_k',
+        'repetition_penalty',
+        'seed'
+    }
 }
 
 
@@ -276,6 +294,31 @@ def blacklist_samplers(loader):
     else:
         return [gr.update(visible=True) if sampler in loaders_samplers[loader] else gr.update(visible=False) for sampler in all_samplers]
 
+model_loader_type_table = {
+    'GPTQ-for-LLaMa': [
+        "None",
+        "llama",
+        "opt",
+        "gptj"
+    ],
+    'ctransformers': [
+        "None",
+        "gpt2",
+        "gptj",
+        "gptneox",
+        "llama",
+        "mpt",
+        "dollyv2",
+        "replit",
+        "starcoder",
+        "falcon"
+    ],
+}
+
+def model_loader_type(loader):
+    if loader in model_loader_type_table:
+        return model_loader_type_table[loader]
+    return ["None"]
 
 def get_gpu_memory_keys():
     return [k for k in shared.gradio if k.startswith('gpu_memory')]
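
As a quick illustration of the intended lookup (not part of the commit), the new helper narrows the model_type choices per loader and falls back to ["None"] for loaders without a table entry:

    # Illustrative sketch only -- assumes the web UI's modules package is importable.
    from modules import loaders

    print(loaders.model_loader_type('ctransformers'))
    # ['None', 'gpt2', 'gptj', 'gptneox', 'llama', 'mpt', 'dollyv2', 'replit', 'starcoder', 'falcon']

    print(loaders.model_loader_type('ExLlama'))
    # ['None']  (no entry in model_loader_type_table, so the default applies)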

modules/models.py

Lines changed: 21 additions & 1 deletion
@@ -58,7 +58,8 @@ def load_model(model_name, loader=None):
         'llamacpp_HF': llamacpp_HF_loader,
         'RWKV': RWKV_loader,
         'ExLlama': ExLlama_loader,
-        'ExLlama_HF': ExLlama_HF_loader
+        'ExLlama_HF': ExLlama_HF_loader,
+        'ctransformers': CtransformersModel_loader,
     }
 
     p = Path(model_name)
@@ -268,6 +269,25 @@ def llamacpp_HF_loader(model_name):
     return model, tokenizer
 
 
+def CtransformersModel_loader(model_name):
+    from modules.ctransformers_model import CtransformersModel
+
+    path = Path(f'{shared.args.model_dir}/{model_name}')
+    logger.info(f'ctransformers loading: {path}\n')
+    ctrans = CtransformersModel()
+    if ctrans.model_type_is_auto():
+        model_file = path
+    else:
+        if path.is_file():
+            model_file = path
+        else:
+            model_file = list(
+                Path(f'{shared.args.model_dir}/{model_name}').glob('*.bin')
+            )[0]
+    logger.info(f'ctransformers weights detected: {model_file}\n')
+    model, tokenizer = ctrans.from_pretrained(model_file)
+    return model, tokenizer
+
 def GPTQ_loader(model_name):
 
     # Monkey patch
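
For context, a minimal sketch of the call path this adds (not part of the commit; the model name is hypothetical and shared.args would normally be populated by the CLI flags):

    # Illustrative sketch only.
    from modules.models import load_model

    # load_model() dispatches 'ctransformers' to CtransformersModel_loader,
    # which resolves the GGML file under shared.args.model_dir.
    model, tokenizer = load_model('llama-7b.ggmlv3.q4_0.bin', loader='ctransformers')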

modules/shared.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
     'autoload_model': False,
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
-    'max_new_tokens_max': 4096,
+    'max_new_tokens_max': 8000,
     'auto_max_new_tokens': False,
     'seed': -1,
     'negative_prompt': '',

modules/text_generation.py

Lines changed: 3 additions & 4 deletions
@@ -41,7 +41,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
         yield ''
         return
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel']:
         generate_func = generate_reply_custom
     else:
         generate_func = generate_reply_HF
@@ -88,9 +88,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
 
         yield reply
 
-
 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel']:
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
     else:
@@ -104,7 +103,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel'] or shared.args.cpu:
         return input_ids
     elif shared.args.deepspeed:
         return input_ids.to(device=local_rank)
modules/ui_model_menu.py

Lines changed: 3 additions & 2 deletions
@@ -63,7 +63,7 @@ def create_ui():
 
     with gr.Row():
         with gr.Column():
-            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value=None)
+            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys(), value=None)
             with gr.Box():
                 with gr.Row():
                     with gr.Column():
@@ -84,7 +84,7 @@ def create_ui():
 
                         shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
                         shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
-                        shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
+                        shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None"], value=shared.args.model_type or "None")
                        shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
                         shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
                         shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
@@ -128,6 +128,7 @@ def create_ui():
 
 def create_event_handlers():
     shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()))
+    shared.gradio['loader'].change(fn=lambda value: gr.update(choices=loaders.model_loader_type(value)), inputs=shared.gradio['loader'], outputs=shared.gradio['model_type'])
 
     # In this event handler, the interface state is read and updated
     # with the model defaults (if any), and then the model is loaded
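
Taken together with the loaders.py change, the effect is that the model_type dropdown is repopulated whenever a different loader is selected. A condensed, self-contained sketch of just these two widgets (not part of the commit; the real UI wires many more components):

    # Illustrative sketch only.
    import gradio as gr
    from modules import loaders

    with gr.Blocks() as demo:
        loader = gr.Dropdown(label="Model loader", choices=list(loaders.loaders_and_params.keys()), value=None)
        model_type = gr.Dropdown(label="model_type", choices=["None"], value="None")

        # Selecting a loader refreshes model_type from model_loader_type_table,
        # e.g. the ctransformers backends, or just ["None"] for other loaders.
        loader.change(
            fn=lambda value: gr.update(choices=loaders.model_loader_type(value)),
            inputs=loader,
            outputs=model_type,
        )

    demo.launch()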

requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -40,3 +40,6 @@ https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/text
 # GPTQ-for-LLaMa
 https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
+
+# ctransformers
+https://github.com/jllllll/ctransformers-cuBLAS-wheels/releases/download/AVX2/ctransformers-0.2.20+cu117-py3-none-any.whl
