Commit dd1d866

✨: refactor local_chat and fix message slice bug in server
1 parent 43fc7f4 commit dd1d866

File tree

13 files changed: +548 -404 lines

.flake8

+4
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 120
+extend-select = B950
+extend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391

Makefile

+21
@@ -0,0 +1,21 @@
+flake_find:
+	cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' -
+format:
+	@cd ktransformers && black .
+	@black setup.py
+dev_install:
+	# clear build dirs
+	rm -rf build
+	rm -rf *.egg-info
+	rm -rf ktransformers/ktransformers_ext/build
+	rm -rf ktransformers/ktransformers_ext/cuda/build
+	rm -rf ktransformers/ktransformers_ext/cuda/dist
+	rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info
+
+	# install ktransformers
+	echo "Installing python dependencies from requirements.txt"
+	pip install -r requirements-local_chat.txt
+
+	echo "Installing ktransformers"
+	KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation
+	echo "Installation completed successfully"

ktransformers/configs/config.yaml

+10 -4

@@ -7,7 +7,7 @@ log:
 
 server:
   ip: 0.0.0.0
-  port: 12456
+  port: 10002
 
 db:
   type: "sqllite"
@@ -24,10 +24,13 @@ model:
   type: ktransformers
 
   name: DeepSeek-Coder-V2-Instruct
-  path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/
-  gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/
+  # path: /mnt/data/model/DeepSeek-Coder-V2-Instruct/
+  path: deepseek-ai/DeepSeek-V2-Lite-Chat
+  # gguf_path: /mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH/
+  gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF
 
   device: cuda:0
+  cache_lens: 8192
 
 web:
   mount: False
@@ -50,4 +53,7 @@ long_context:
   head_select_mode: SHARED
   preselect_block_count: 32
   layer_step: 1
-  token_step: 100
+  token_step:
+
+local_chat:
+  prompt_file: "./ktransformers/p.txt"
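
The new local_chat.prompt_file key feeds the refactored chat loop below: when the user submits an empty prompt, the file's contents are used instead. A minimal sketch of how such a key comes out of this YAML (assumes PyYAML; the repo's own Config class does its own loading):

import yaml

with open("ktransformers/configs/config.yaml") as f:
    cfg = yaml.safe_load(f)

# "./ktransformers/p.txt" per the diff above; an empty or missing value
# makes local_chat fall back to its built-in quicksort prompt.
prompt_file = (cfg.get("local_chat") or {}).get("prompt_file")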

ktransformers/local_chat.py

+38 -98

@@ -1,33 +1,23 @@
 """
-Description : 
+Description :
 Author : Boxin Zhang, Azure-Tang
 Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 """
 
+import asyncio
 import os
 import platform
 import sys
 
+from ktransformers.server.args import ArgumentParser
+
 project_dir = os.path.dirname(os.path.dirname(__file__))
 sys.path.insert(0, project_dir)
-import torch
-import logging
-from transformers import (
-    AutoTokenizer,
-    AutoConfig,
-    AutoModelForCausalLM,
-    GenerationConfig,
-    TextStreamer,
-)
-import json
-import fire
-from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate
 from ktransformers.server.config.config import Config
 
 custom_models = {
@@ -37,9 +27,7 @@
     "MixtralForCausalLM": MixtralForCausalLM,
 }
 
-ktransformer_rules_dir = (
-    os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
-)
+ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
 default_optimize_rules = {
     "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
@@ -48,75 +36,28 @@
 }
 
 
-def local_chat(
-    model_path: str | None = None,
-    optimize_rule_path: str = None,
-    gguf_path: str | None = None,
-    max_new_tokens: int = 1000,
-    cpu_infer: int = Config().cpu_infer,
-    use_cuda_graph: bool = True,
-    prompt_file : str | None = None,
-    mode: str = "normal",
-):
-
-
-    torch.set_grad_enabled(False)
-
-    Config().cpu_infer = cpu_infer
-
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-    if mode == 'long_context':
-        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
-        torch.set_default_dtype(torch.float16)
+def local_chat():
+    config = Config()
+    arg_parser = ArgumentParser(config)
+    # initialize messages
+    arg_parser.parse_args()
+    if config.backend_type == "transformers":
+        from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
+    elif config.backend_type == "exllamav2":
+        from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
+    elif config.backend_type == "ktransformers":
+        from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
     else:
-        torch.set_default_dtype(config.torch_dtype)
-
-    with torch.device("meta"):
-        if config.architectures[0] in custom_models:
-            print("using custom modeling_xxx.py.")
-            if (
-                "Qwen2Moe" in config.architectures[0]
-            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.
-                config._attn_implementation = "flash_attention_2"
-            if "Llama" in config.architectures[0]:
-                config._attn_implementation = "eager"
-            if "Mixtral" in config.architectures[0]:
-                config._attn_implementation = "flash_attention_2"
-
-            model = custom_models[config.architectures[0]](config)
-        else:
-            model = AutoModelForCausalLM.from_config(
-                config, trust_remote_code=True, attn_implementation="flash_attention_2"
-            )
-
-    if optimize_rule_path is None:
-        if config.architectures[0] in default_optimize_rules:
-            print("using default_optimize_rule for", config.architectures[0])
-            optimize_rule_path = default_optimize_rules[config.architectures[0]]
-        else:
-            optimize_rule_path = input(
-                "please input the path of your rule file(yaml file containing optimize rules):"
-            )
-
-    if gguf_path is None:
-        gguf_path = input(
-            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
-        )
-    optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
-
-    model.generation_config = GenerationConfig.from_pretrained(model_path)
-    if model.generation_config.pad_token_id is None:
-        model.generation_config.pad_token_id = model.generation_config.eos_token_id
-    model.eval()
-    logging.basicConfig(level=logging.INFO)
+        raise NotImplementedError(f"{config.backend_type} not implemented")
+    interface = BackendInterface(config)
 
     system = platform.system()
     if system == "Windows":
         os.system("cls")
     else:
         os.system("clear")
-
+    # add a history chat content
+    his_content = []
     while True:
         content = input("Chat: ")
         if content.startswith('"""'):  # prefix """
@@ -132,28 +73,27 @@ def local_chat(
                     break
                 else:
                     content += line + "\n"
-
         if content == "":
-            if prompt_file != None:
-                content = open(prompt_file, "r").read()
-            else:
+            if config.prompt_file == None or config.prompt_file == "":
                 content = "Please write a piece of quicksort code in C++."
+            else:
+                content = open(config.prompt_file, "r").read()
         elif os.path.isfile(content):
             content = open(content, "r").read()
-        messages = [{"role": "user", "content": content}]
-        input_tensor = tokenizer.apply_chat_template(
-            messages, add_generation_prompt=True, return_tensors="pt"
-        )
-        if mode == 'long_context':
-            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
-            "please change max_seq_len in ~/.ktransformers/config.yaml"
-        torch.set_default_dtype(
-            torch.bfloat16
-        )  # TODO: Remove this, replace dtype using config
-        generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
-        )
+        messages = his_content + [{"role": "user", "content": content}]
+
+        async def async_inference(messages):
+            generated = ""
+            async for token in interface.inference(messages, "local_chat"):
+                generated += token
+            return generated
+
+        generated = asyncio.run(async_inference(messages))
+        his_content += [
+            {"role": "user", "content": content},
+            {"role": "assitant", "content": generated},
+        ]
 
 
 if __name__ == "__main__":
-    fire.Fire(local_chat)
+    local_chat()
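
The refactor swaps the direct tokenizer/prefill_and_generate path for a backend-agnostic async token stream that asyncio.run drains once per turn. A minimal, self-contained sketch of that consumption pattern (the inference generator here is a stand-in for the real BackendInterface.inference, whose call shape is taken from the diff above):

import asyncio

async def inference(messages, thread_id):
    # Stand-in backend: yields tokens one at a time, as the real interface does.
    for token in ["Quick", "sort", " in", " C++", "..."]:
        yield token

async def async_inference(messages):
    generated = ""
    async for token in inference(messages, "local_chat"):
        generated += token
    return generated

print(asyncio.run(async_inference([{"role": "user", "content": "hi"}])))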

ktransformers/server/api/ollama/completions.py

+2
@@ -135,6 +135,8 @@ class OllamaShowResponse(BaseModel):
     details: OllamaShowDetial
     model_info: OllamaModelInfo
 
+    class Config:
+        protected_namespaces = ()
 
 
 

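Context for the two added lines: Pydantic reserves the model_ prefix as a protected namespace, so a field named model_info triggers a UserWarning at class-definition time unless the check is cleared, which is exactly what the inner Config does. A minimal reproduction (assumes pydantic v2):

from pydantic import BaseModel

class Warns(BaseModel):
    model_info: str  # emits a "conflict with protected namespace" warning

class Silent(BaseModel):
    model_info: str

    class Config:
        protected_namespaces = ()  # clear the protected "model_" namespace
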
ktransformers/server/args.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import argparse
2+
from ktransformers.server.backend.args import ConfigArgs, default_args
3+
4+
5+
class ArgumentParser:
6+
def __init__(self, cfg):
7+
self.cfg = cfg
8+
9+
def parse_args(self):
10+
parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
11+
parser.add_argument("--host", type=str, default=self.cfg.server_ip)
12+
parser.add_argument("--port", type=int, default=self.cfg.server_port)
13+
parser.add_argument("--ssl_keyfile", type=str)
14+
parser.add_argument("--ssl_certfile", type=str)
15+
parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
16+
parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
17+
parser.add_argument("--model_dir", type=str, default=self.cfg.model_dir)
18+
parser.add_argument(
19+
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
20+
)
21+
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
22+
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
23+
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
24+
parser.add_argument("--type", type=str, default=self.cfg.backend_type)
25+
26+
# model configs
27+
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
28+
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
29+
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
30+
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
31+
parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
32+
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
33+
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
34+
parser.add_argument("--healing", type=bool, default=self.cfg.healing)
35+
parser.add_argument("--ban_strings", type=list, default=self.cfg.ban_strings, required=False)
36+
parser.add_argument("--gpu_split", type=str, default=self.cfg.gpu_split, required=False)
37+
parser.add_argument("--length", type=int, default=self.cfg.length, required=False)
38+
parser.add_argument("--rope_scale", type=float, default=self.cfg.rope_scale, required=False)
39+
parser.add_argument("--rope_alpha", type=float, default=self.cfg.rope_alpha, required=False)
40+
parser.add_argument("--no_flash_attn", type=bool, default=self.cfg.no_flash_attn)
41+
parser.add_argument("--low_mem", type=bool, default=self.cfg.low_mem)
42+
parser.add_argument("--experts_per_token", type=int, default=self.cfg.experts_per_token, required=False)
43+
parser.add_argument("--load_q4", type=bool, default=self.cfg.load_q4)
44+
parser.add_argument("--fast_safetensors", type=bool, default=self.cfg.fast_safetensors)
45+
parser.add_argument("--draft_model_dir", type=str, default=self.cfg.draft_model_dir, required=False)
46+
parser.add_argument("--no_draft_scale", type=bool, default=self.cfg.no_draft_scale)
47+
parser.add_argument("--modes", type=bool, default=self.cfg.modes)
48+
parser.add_argument("--mode", type=str, default=self.cfg.mode)
49+
parser.add_argument("--username", type=str, default=self.cfg.username)
50+
parser.add_argument("--botname", type=str, default=self.cfg.botname)
51+
parser.add_argument("--system_prompt", type=str, default=self.cfg.system_prompt, required=False)
52+
parser.add_argument("--temperature", type=float, default=self.cfg.temperature)
53+
parser.add_argument("--smoothing_factor", type=float, default=self.cfg.smoothing_factor)
54+
parser.add_argument("--dynamic_temperature", type=str, default=self.cfg.dynamic_temperature, required=False)
55+
parser.add_argument("--top_k", type=int, default=self.cfg.top_k)
56+
parser.add_argument("--top_p", type=float, default=self.cfg.top_p)
57+
parser.add_argument("--top_a", type=float, default=self.cfg.top_a)
58+
parser.add_argument("--skew", type=float, default=self.cfg.skew)
59+
parser.add_argument("--typical", type=float, default=self.cfg.typical)
60+
parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
61+
parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
62+
parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
63+
parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
64+
parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
65+
parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
66+
parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
67+
parser.add_argument("--cache_q4", type=bool, default=self.cfg.cache_q4)
68+
parser.add_argument("--ngram_decoding", type=bool, default=self.cfg.ngram_decoding)
69+
parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings)
70+
parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia)
71+
parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
72+
parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
73+
74+
# log configs
75+
# log level: debug, info, warn, error, crit
76+
parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)
77+
parser.add_argument("--log_file", type=str, default=self.cfg.log_file)
78+
parser.add_argument("--log_level", type=str, default=self.cfg.log_level)
79+
parser.add_argument("--backup_count", type=int, default=self.cfg.backup_count)
80+
81+
# db configs
82+
parser.add_argument("--db_type", type=str, default=self.cfg.db_type)
83+
parser.add_argument("--db_host", type=str, default=self.cfg.db_host)
84+
parser.add_argument("--db_port", type=str, default=self.cfg.db_port)
85+
parser.add_argument("--db_name", type=str, default=self.cfg.db_name)
86+
parser.add_argument("--db_pool_size", type=int, default=self.cfg.db_pool_size)
87+
parser.add_argument("--db_database", type=str, default=self.cfg.db_database)
88+
89+
# user config
90+
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
91+
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
92+
93+
# web config
94+
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
95+
96+
# file config
97+
parser.add_argument("--file_upload_dir", type=str, default=self.cfg.file_upload_dir)
98+
parser.add_argument("--assistant_store_dir", type=str, default=self.cfg.assistant_store_dir)
99+
# local chat
100+
parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
101+
102+
args = parser.parse_args()
103+
# set config from args
104+
for key, value in vars(args).items():
105+
if value is not None and hasattr(self.cfg, key):
106+
setattr(self.cfg, key, value)
107+
# we add the name not match args individually
108+
self.cfg.model_device = args.device
109+
self.cfg.mount_web = args.web
110+
self.cfg.server_ip = args.host
111+
self.cfg.server_port = args.port
112+
self.cfg.backend_type = args.type
113+
return args
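
The write-back at the end of parse_args is the core of the refactor: any flag whose name matches a config attribute is reflected onto the config object, and the few whose names differ (--host maps to server_ip, --web to mount_web, --type to backend_type, and so on) are set by hand. One caveat worth knowing about the code above: argparse's type=bool treats any non-empty string as truthy, since bool("False") is True. A minimal sketch of the same reflection pattern (DemoConfig is hypothetical, not the repo's Config class):

import argparse

class DemoConfig:
    server_ip = "0.0.0.0"
    cpu_infer = 10

cfg = DemoConfig()
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=cfg.server_ip)
parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
args = parser.parse_args([])  # empty argv for the demo

# Copy every flag that names a config attribute, then map mismatched names by hand.
for key, value in vars(args).items():
    if value is not None and hasattr(cfg, key):
        setattr(cfg, key, value)
cfg.server_ip = args.host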
