Question - Batch Processing #718
Comments
It depends on how you're processing, but generally the maximum context length just determines how many position embeddings are precomputed, how much memory is reserved for attention scores, and so on, which overall has a minor impact. So you should be able to set a max sequence length of 16k and just not worry about it.

As for the input batches, it depends on how you're processing them. If you're making rectangular batches to forward without a cache (for sequence classification or similar, where you're interested in the logits and not generating new tokens), you just want to divide the inputs into batches of similar lengths to minimize padding.

If you're generating a completion for each sequence, though, I would use the dynamic generator. It doesn't care about input length and can efficiently batch together sequences of different lengths. It uses a flat cache divided into pages, and pages are allocated dynamically as needed, so just create a cache as large as you can, make sure max_seq_len is large enough for the longest context you might encounter, and go:

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator

model_dir = "/path/to/model"  # placeholder: directory containing the model files

config = ExLlamaV2Config(model_dir)
config.max_seq_len = 16384 # <-- max individual sequence length
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 65536, lazy = True) # <-- max total sum of sequence lengths
model.load_autosplit(cache, progress = True)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
prompts = [
"First prompt be all like",
"Let's count to 1000: " + ", ".join([str(x) for x in range(1, 990)]),
"Third prompt is short because",
...
]
outputs = generator.generate(prompts, max_new_tokens = 100, add_bos = True)
print("\n---\n".join(outputs))
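As a follow-up usage sketch: for a workload like ~100 long prompts, the same generator can also take one job per prompt and hand back completions as they finish, rather than returning everything in one call. The snippet below continues from the setup above and follows the job API used in exllamav2's dynamic generator examples (ExLlamaV2DynamicJob with generator.enqueue() and generator.iterate()); the result keys used here are best-effort assumptions and may differ between versions, so treat it as an outline rather than a drop-in recipe.

from exllamav2.generator import ExLlamaV2DynamicJob

# Enqueue one job per prompt; the generator batches them over its paged cache.
for idx, prompt in enumerate(prompts):
    generator.enqueue(
        ExLlamaV2DynamicJob(
            input_ids = tokenizer.encode(prompt, add_bos = True),
            max_new_tokens = 100,
            identifier = idx,   # echoed back in results to match outputs to prompts
        )
    )

# Drain the queue, accumulating streamed text per prompt.
completions = [""] * len(prompts)
while generator.num_remaining_jobs():
    for result in generator.iterate():
        if result["stage"] == "streaming":
            completions[result["identifier"]] += result.get("text", "")

print("\n---\n".join(completions))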
Thanks so much @turboderp ❤️, love the project and the community you guys have built around it 🙏. Quick question: would reserving a max sequence length of 16384 result in slower inference during batch processing, even for the shorter sequences? Again, can't thank you enough for your time and input on this.
Increasing max_seq_len mostly just reserves more memory up front (position embeddings, attention buffers); it shouldn't noticeably slow down inference for shorter sequences.
Thanks @turboderp, really appreciate it.
Problem
I'm processing a batch of approximately 100 prompts, each ranging from 1,200 to 14,000 tokens in length. Given that the input context length must be specified during model initialisation, I'm considering two options:
OR
Are there other avenues I should be exploring?
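As an illustration of the no-cache avenue from the answer above (rectangular batches when you only need logits and aren't generating tokens), here is a rough sketch of grouping prompts into similar-length buckets to minimize padding; the helper name and the bucket size are arbitrary choices for illustration, not anything exllamav2 prescribes.

def make_length_buckets(prompts, tokenizer, batch_size = 8):
    # Tokenize once, sort by token count, then slice neighbours into batches
    # so each rectangular batch carries as little padding as possible.
    lengths = [(i, tokenizer.encode(p).shape[-1]) for i, p in enumerate(prompts)]
    lengths.sort(key = lambda t: t[1])
    order = [i for i, _ in lengths]
    return [order[j : j + batch_size] for j in range(0, len(order), batch_size)]

# Each entry is a list of indices into the original prompt list, shortest first.
batches = make_length_buckets(prompts, tokenizer)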
Solution
Need to know best possible options
Alternatives
No response
Explanation
Need to know best possible options
Examples
No response
Additional context
No response