
Commit 89f8218

Merge pull request #487 from kvcache-ai/clean_pr
clean PR code and disable flashinfer
2 parents: cf4da5f + a529518

File tree: 3 files changed, +13 -23 lines


ktransformers/operators/attention.py (+7, -17)

@@ -58,18 +58,10 @@ def __init__(self,
     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
             kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
-            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
-            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
-            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
-                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
-            self.q_absorb.weight.data = q_absorb
-            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
-                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
-            self.out_absorb.weight.data = out_absorb
-            #del self.orig_module.kv_b_proj
-        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
-        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-        return q_absorb, out_absorb
+            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
+            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
+
+        return self.q_absorb, self.out_absorb

     def forward_chunck(
         self,
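Note on the hunk above: the rewritten get_absorbed() no longer builds two nn.Linear modules and copies weights into them; it simply caches per-head views of kv_b_proj.weight and returns them. Below is a minimal sketch of how absorbed projections of this shape are typically consumed in MLA-style attention ("matrix absorption": fold q_nope into the latent space and attend directly over the compressed KV cache). All tensors are random stand-ins, shapes follow the comments in forward_chunck further down, and RoPE and softmax scaling are omitted; this is illustrative, not the repository's actual forward path.

import torch

# Illustrative sizes only; names follow the comments in forward_chunck below.
bsz, num_heads, q_len, kv_len = 1, 8, 4, 16
qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 512

q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)   # simplified, unpaged layout

# What get_absorbed() now returns: per-head views of kv_b_proj.weight.
q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)

# Absorb the no-RoPE query into the latent space, attend over the compressed
# KV cache directly, then project the latent context back to head dimension.
q_latent = torch.einsum("bhqd,hdr->bhqr", q_nope, q_absorb)       # [bsz, heads, q_len, rank]
scores = torch.einsum("bhqr,bkr->bhqk", q_latent, compressed_kv)  # [bsz, heads, q_len, kv_len]
weights = scores.softmax(dim=-1)
ctx_latent = torch.einsum("bhqk,bkr->bhqr", weights, compressed_kv)
out = torch.einsum("bhqr,hvr->bhqv", ctx_latent, out_absorb)      # [bsz, heads, q_len, v_head_dim]

Because the returned tensors are views, they share storage with kv_b_proj and avoid the extra weight copies the old code made.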
@@ -105,7 +97,7 @@ def forward_chunck(
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )

@@ -129,8 +121,6 @@ def forward_chunck(
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

         q_absorb, out_absorb = self.get_absorbed()
-        # if hasattr(self.orig_module, 'kv_b_proj'):
-        #     del self.orig_module.kv_b_proj

         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]

@@ -227,7 +217,7 @@ def forward_linux_triton(
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
@@ -379,7 +369,7 @@ def forward_linux_flashinfer(
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
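The three reworded ValueError messages all point at the same requirement from transformers v4.36: an attention module that uses the newer cache classes must be constructed with a layer index so the cache can route updates to the right layer. A small illustration using the stock transformers DynamicCache API (not KTransformers' paged cache; sizes are arbitrary):

import torch
from transformers.cache_utils import DynamicCache

layer_idx = 0                  # passed to the attention class at construction time
past_key_value = DynamicCache()

# [bsz, num_heads, seq_len, head_dim]; arbitrary demo sizes
k = torch.randn(1, 8, 4, 64)
v = torch.randn(1, 8, 4, 64)

# update() routes the new keys/values to the per-layer slot; with layer_idx=None
# there is nothing to index by, which is exactly what the ValueError guards against.
keys, values = past_key_value.update(k, v, layer_idx)
print(keys.shape)   # torch.Size([1, 8, 4, 64])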

ktransformers/operators/flashinfer_wrapper.py (+1, -1)

@@ -9,7 +9,7 @@

 try:
     import flashinfer
-    flashinfer_enabled = True
+    flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
     print("found flashinfer")

 except ImportError:
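With this one-line change the flag is forced off even when the import succeeds, so any caller that branches on flashinfer_enabled (for example to pick forward_linux_flashinfer over the Triton path) takes the fallback until a newer flashinfer release is integrated. For reference, a sketch of the complete module-level guard; the except-branch body and the pick_attention_path helper are assumptions, not code from the repository:

try:
    import flashinfer
    flashinfer_enabled = False  # disabled now, TODO:use new version of flashinfer and enable
    print("found flashinfer")
except ImportError:
    # the real except body lies outside this hunk; a safe default is assumed here
    flashinfer_enabled = False

# Hypothetical call site: with the flag forced to False, the flashinfer path is skipped.
def pick_attention_path():
    return "forward_linux_flashinfer" if flashinfer_enabled else "forward_linux_triton"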

ktransformers/server/backend/interfaces/transformers.py (+5, -5)

@@ -381,13 +381,13 @@ async def inference(self, local_messages, thread_id: str):

         self.profiler.create_and_start_timer("prefill")

-
+        if Config().user_force_think:
+            think = '<think>\n'
+            print(think, end="",flush=True)
+            yield think
+
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
             # output think token after prefill done
-            if Config().user_force_think:
-                think = '<think>\n'
-                print(think, end="",flush=True)
-                yield think
             if t is not None:
                 print(t, end="",flush=True)
                 yield t
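The moved block changes when the forced '<think>\n' marker is emitted: previously it was yielded inside the prefill loop, i.e. once per prefill chunk; now it is yielded exactly once, before any prefill output streams out. A stripped-down async-generator sketch of the new ordering; inference_sketch and prefill_chunks are hypothetical stand-ins for the server's real interface:

import asyncio

async def inference_sketch(prefill_chunks, user_force_think=True):
    # Emit the think marker once, up front, instead of on every prefill iteration.
    if user_force_think:
        yield "<think>\n"
    for t in prefill_chunks:          # stands in for self.prefill(...)
        if t is not None:
            yield t

async def main():
    async for piece in inference_sketch(["Hello", ", ", "world"]):
        print(piece, end="", flush=True)

asyncio.run(main())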
