This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit acfbc40

fix ret when ignore_prompt (#278)
Signed-off-by: Yu, Zhentao <[email protected]>
1 parent: cfc40ab

File tree

2 files changed, +46 -3 lines changed


docs/continuous_batching.md (+45 -2)

@@ -16,7 +16,7 @@ We only support multi-batch inference in concatenating & splitting input sequenc
 
 The code example is like:
 ```python
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer
 from neural_speed import Model
 
 model_name = "meta-llama/Llama-2-7b-hf"
@@ -32,7 +32,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, pa
 # if the tokenizer has no pad_token, you can specify it.
 tokenizer.pad_token = tokenizer.eos_token
 pad_token_id = tokenizer.pad_token_id
-inputs = tokenizer(ps, padding=True, return_tensors='pt').input_ids
+inputs = tokenizer(prompts, padding=True, return_tensors='pt').input_ids
 
 model = Model()
 model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8")
@@ -46,6 +46,49 @@ for a in ans:
 ```
 > Note: Not every model supports multi-batching inference and most of them are under construction, please refer to [Supported Models](#supported-models).
 
+You can use below codes to get the `token/second` metric if you care about the throughput of batching inference.
+```python
+from transformers import AutoTokenizer
+from neural_speed import Model
+
+model_name = "meta-llama/Llama-2-7b-hf"
+prompts = [
+    "Tell me an interesting fact about llamas.",
+    "What is the best way to cook a steak?",
+    "Are you familiar with the Special Theory of Relativity and can you explain it to me?",
+    "Recommend some interesting books to read.",
+    "What is the best way to learn a new language?",
+]
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
+# if the tokenizer has no pad_token, you can specify it.
+tokenizer.pad_token = tokenizer.eos_token
+pad_token_id = tokenizer.pad_token_id
+inputs = tokenizer(prompts, padding=True, return_tensors='pt').input_ids
+
+model = Model()
+model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8")
+# greedy search example, top_k_top_p sampling and beam_search also supported
+# do not forget to pass pad_token_id
+# warmup
+outputs = model.generate(inputs,
+                         max_new_tokens=4,
+                         do_sample=False,
+                         pad_token=pad_token_id,
+                         ignore_prompt=True,
+                         max_request_num=bs)
+t0 = time.time()
+outputs = model.generate(inputs,
+                         max_new_tokens=128,
+                         do_sample=False,
+                         pad_token=pad_token_id,
+                         ignore_prompt=True,
+                         max_request_num=bs)
+duration = time.time() - t0
+total_tokens = sum([len(a) for a in outputs])
+print("throughput is {} token/second.".format(total_tokens / duration))
+```
+
 ## Server
 We supply a corresponding [script](../scripts/python_api_example_for_model_server.py) for server usage.
 You can modify the `max_request_num` for setting the maximum bearable requests.
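One caveat about the throughput example added in this hunk: it calls `time.time()` and passes `max_request_num=bs`, but neither `time` nor `bs` is defined in the lines shown. A minimal preamble that would make the snippet self-contained, assuming `bs` is simply the batch size (the number of prompts), looks like:

```python
import time  # used for the t0 / duration measurement around model.generate

# assumption: bs is the batch size passed as max_request_num,
# i.e. the number of prompts batched together
bs = len(prompts)
```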

neural_speed/__init__.py (+1 -1)

@@ -361,7 +361,7 @@ def generate(self,
             self.model.reinit()
             self.generate_round = 0
 
-        ret = [[]]
+        ret = [[] for _ in range(input_ids.shape[0])]
         if self.generate_round == 0 and not ignore_prompt:
             ret = input_ids.tolist()
 
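For context on the one-line fix above: when `ignore_prompt=True`, the `ret = input_ids.tolist()` branch is skipped, so `ret` must start out with one (initially empty) list per input row; `[[]]` only provides a single list, which breaks batched calls. A minimal sketch of the difference, using a hypothetical batch of three prompts:

```python
batch_size = 3  # stands in for input_ids.shape[0] with three prompts in one batch

old_ret = [[]]                              # a single slot, regardless of batch size
new_ret = [[] for _ in range(batch_size)]   # one independent slot per input row

# with ignore_prompt=True the prompt tokens are never copied into ret,
# so each row's generated tokens need their own pre-allocated list
for row in range(batch_size):
    new_ret[row].append(0)                  # 0 is just a placeholder token id

print(len(old_ret), len(new_ret))           # -> 1 3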
