# Dynamic generator

Here is a little writeup about the dynamic generator introduced with ExLlamaV2 v0.1.0. I'll try to
keep this updated with any changes. At the time of writing, the latest release of ExLlamaV2 is
**v0.1.4**.


## Static batching

There are a number of approaches to arranging uneven batches in a key/value cache, and most of them
have drawbacks.

The "base" and "streaming" generators in ExLlama use left-padding, which is arguably the least bad
solution when we want a static, pre-allocated cache and we must pass an uneven batch to the model as a
rectangular tensor of input IDs. It aligns all sequences to the right, and each forward pass gives us
last-token logits for the entire batch. The downside, of course, is that the padding is wasted space,
and attention functions that support left padding (which do *not* currently include Flash Attention)
will still waste time attending to the padding tokens before discarding the masked weights.

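To make the left-padded arrangement concrete, here is a minimal sketch in plain PyTorch (an illustration
of the general technique, not ExLlama's internal code) of aligning an uneven batch to the right along with
the matching attention mask:

```python
import torch

# Toy batch of token IDs with uneven lengths
seqs = [[5, 6, 7, 8, 9], [1, 2, 3], [4, 5]]
pad_id = 0  # arbitrary placeholder ID; its value doesn't matter since it's masked out

max_len = max(len(s) for s in seqs)

# Left-pad every sequence so the batch becomes a rectangular tensor, aligned to the right
input_ids = torch.tensor([[pad_id] * (max_len - len(s)) + s for s in seqs])

# Matching mask: 0 over the padding, 1 over the real tokens
attn_mask = torch.tensor([[0] * (max_len - len(s)) + [1] * len(s) for s in seqs])

print(input_ids)
# tensor([[5, 6, 7, 8, 9],
#         [0, 0, 1, 2, 3],
#         [0, 0, 0, 4, 5]])
```

Since every sequence now ends in the last column, the logits at position -1 cover the whole batch in one
pass, at the cost of storing the padding and, without a kernel that can skip it, attending over it.
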
Padding on the right is another option. It removes the need for an attention mask but introduces
another problem of having to start inference with the shortest sequence, generating at a lower effective
batch size until all sequences in the batch are the same length. Again, the padding also wastes space.

Flash Attention does support an unpadded cache. The feature seems geared more toward training than
inference, however, since it doesn't allow for gaps in the key/value inputs. I.e., we can perform
attention on an uneven batch, but then we have to reshape the entire cache to grow each individual
sequence by one token for the next forward pass.


## Continuous batching

Starting with version 2.5.7, Flash Attention supports paged attention. This allows continuous batching:
since the cache is indexed by a block table, we can shape it however we want, leaving space for
sequences to grow without resorting to padding.


The only space wasted is whatever it takes to align sequences to page boundaries. (The page size is
fixed at 256 tokens currently.)

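As a rough illustration of the bookkeeping involved (the numbers and the dictionary below are made up for
the example; the real block table is managed inside the generator):

```python
PAGE_SIZE = 256  # current fixed page size

def pages_needed(prompt_len: int, max_new_tokens: int) -> int:
    # Pages required to hold the prompt plus room for the new tokens,
    # rounded up to the next page boundary (ceiling division)
    return -(-(prompt_len + max_new_tokens) // PAGE_SIZE)

# Toy block table: each sequence maps to a list of page indices in the cache,
# and the pages don't have to be contiguous
block_table = {
    0: [0, 1, 2],
    1: [3, 7],
    2: [4, 5, 6, 8],
}

print(pages_needed(900, 200))  # 5 pages, i.e. 1280 tokens of cache for an 1100-token sequence
```

The unused tail of that last page is the alignment waste mentioned above.
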
## Dynamic generator

Despite its name, the dynamic generator in ExLlama does what would usually be referred to as continuous
batching, although it could also be thought of as a "hybrid approach." Not that we should get hung up on
definitions.

The important point is that the paged attention model lets us add or remove sequences from a batch at
any point by merely updating the block index, without having to reshape the entire cache.

The generator works on a job queue. Initially it will start as many jobs as it can fit in the cache,
and as soon as a job finishes, those pages are freed to make room for the next job in the queue (which
may have been part of the original queue or added along the way).

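Schematically, the queue behaves something like the simplified sketch below. This is only an illustration
of the idea, not the generator's actual code, and `job.pages_needed` is a hypothetical stand-in for however
much cache a given job requires:

```python
from collections import deque

free_pages = set(range(256))   # e.g. 256 pages for a 64k-token cache
pending = deque()              # jobs waiting for room in the cache
active = {}                    # job -> list of cache pages it currently holds

def try_activate():
    # Start queued jobs for as long as enough free pages remain
    while pending and pending[0].pages_needed <= len(free_pages):
        job = pending.popleft()
        active[job] = [free_pages.pop() for _ in range(job.pages_needed)]

def finish(job):
    # Freed pages go straight back into the pool; the next job can use them
    # even if they're scattered, since the block table needn't be contiguous
    free_pages.update(active.pop(job))
    try_activate()
```
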
Of course, freeing pages is likely to leave gaps in the cache:


Paged attention helps out here, since there's no requirement for the block index to be contiguous. The
next job that's activated can simply fill whatever gaps have opened up.


The cache can also be defragmented, but for smoother operation the generator only defragments
unreferenced pages, and only when the queue is empty.

## Deduplication

The block table offers another great benefit: multiple sequences can index the same pages. A very common
situation in LLM inference involves multiple prompts repeating a shared prefix. This could be a long system
prompt, for instance.

So whenever the generator activates a new job, in the likely event that this job starts with a token
sequence whose keys/values are already present in the cache, it can reference or reuse the existing data
to save prefill time and VRAM.

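The amount of reuse is easy to reason about: only whole pages can be shared, so two jobs can reference as
many common pages as fit inside their shared token prefix. A small illustrative calculation (not the
generator's code):

```python
PAGE_SIZE = 256

def shared_pages(ids_a: list[int], ids_b: list[int]) -> int:
    # Length of the common token prefix of the two sequences
    common = 0
    for a, b in zip(ids_a, ids_b):
        if a != b:
            break
        common += 1
    # Only whole pages can be referenced by both jobs
    return common // PAGE_SIZE

shared_prefix = list(range(2000))           # stand-in for a 2000-token system prompt or story
job_a = shared_prefix + [9001, 9002, 9003]  # unique question A
job_b = shared_prefix + [7001, 7002]        # unique question B

print(shared_pages(job_a, job_b))  # 7 pages (1792 tokens) stored once and referenced by both jobs
```
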
Consider the case of starting a number of jobs that share a long prompt but with slight variation at the
end. There's a concrete example [here](../examples/inference_dedup.py) in which an instruct model is asked
multiple independent questions about a short story. Each prompt in the batch includes the entire story
followed by a unique question, which the generator automatically reduces to something like this:


The referenced example uses a cache that's too short to hold all the jobs at once. This is not a problem,
since as soon as enough jobs finish to leave room for a pending job to become active, the newly activated
job can still reference pages from the existing, ongoing jobs.

This also greatly simplifies the batching API. Once you've configured the generator to the capabilities
of your hardware (most notably available VRAM for the cache), you won't need to manage this feature at all.

```python
# Generate a story
story = generator.generate(
    prompt = "Once upon a time",
    max_new_tokens = 1000,
)

# Generate 100 stories
stories = generator.generate(
    prompt = ["Once upon a time"] * 100,
    max_new_tokens = 1000,
)
```

The latter generation may run at a batch size of 100 if the generator is configured to allow that (most
notably if the cache is large enough), or it may maintain a batch size of 17 throughout if that's all it can
fit, but either way it will return the requested 100 completions. (There are other modes it can run in,
too, to support streaming and so on; see below.)

The [MMLU](../eval/mmlu.py) and [HumanEval](../eval/humaneval.py) scripts in the repo queue up thousands of
jobs in this manner and simply call the generator in a loop, collecting the output as each job completes.


## Prompt caching

We're not limited to reusing keys/values from jobs that are currently active. The generator tries to avoid
overwriting the most recently used pages, so if the next job to arrive shares a prefix with any of the
most recently finished jobs, the cached data can still be reused.

The most immediate benefit is the same one the original streaming generator already provided: in a chatbot
application, unless you're editing the past, each successive round can reuse the cache for the entire
context up to the user's most recent prompt.

The dynamic generator extends this to include as many past contexts as there is room for in the cache. If
you're alternating between two contexts and the cache is large enough, you won't have to forget one context
to make room for the other.

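For example, a minimal chat loop might look like the sketch below, using `generator.generate` as introduced
in the next section. The prompt format is generic rather than any particular model's template, and the
slicing assumes the returned string includes the prompt:

```python
context = "This is a conversation between a user and a helpful assistant.\n\n"

for user_message in ["Hi! Please remember the number 42.",
                     "What number did I ask you to remember?"]:

    context += f"User: {user_message}\nAssistant:"

    # Each round only needs to prefill the tokens added since the last round;
    # the pages holding the earlier context are still in the cache.
    full_text = generator.generate(
        prompt = context,
        max_new_tokens = 200,
        stop_conditions = ["\nUser:"],
        add_bos = True
    )

    reply = full_text[len(context):]  # assumption: the output includes the prompt
    context += reply + "\n"
    print(reply.strip())
```
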
## How to

For a quick introduction to the API, let's start by loading Llama3-8B with a 64k-token cache:

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer

model_dir = "/mnt/str/models/llama3-8b-exl2/4.0bpw/"
config = ExLlamaV2Config(model_dir)
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 65536, lazy = True)
model.load_autosplit(cache, progress = True)
tokenizer = ExLlamaV2Tokenizer(config)
```

This is the same procedure as usual. Note, however, that the model and cache are initialized without a
batch size, i.e. with an implicit batch size of one. The dynamic generator will still allow batching.

We will be limited by the native max sequence length of the model, which is 8k in this case, but we
could extend it, as before, by updating `config.max_seq_len` and `config.scale_alpha_value` etc. before
loading the model. In any case, the model's max sequence length only gives the maximum length of any *one*
sequence, so in this example, the 64k-token cache could hold 8 full-length sequences at 8k tokens each,
or 64 sequences of 1k tokens, etc.

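As a sketch, extending the native context before loading might look like this. The numbers are arbitrary
examples, and the appropriate `scale_alpha_value` depends on the model and the target length:

```python
config = ExLlamaV2Config(model_dir)
config.max_seq_len = 32768        # raise the per-sequence limit above the native 8k
config.scale_alpha_value = 4.0    # example NTK/alpha scaling factor to go with it
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 65536, lazy = True)
model.load_autosplit(cache, progress = True)
```
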
We could replace `ExLlamaV2Cache` with `ExLlamaV2Cache_Q4` to run in Q4 cache mode as usual. FP8 support
has not been adapted to the dynamic generator, since it performs worse in every respect than Q4, and besides,
there will be a Q8 mode in **v0.1.5** which will be more accurate still, with a similar footprint to FP8.
Q3 and Q6 modes are planned as well.

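The Q4 swap looks like this; the arguments are the same as above, only the cache class changes:

```python
from exllamav2 import ExLlamaV2Cache_Q4

# Q4 cache: same interface as the FP16 cache, at a fraction of the VRAM
cache = ExLlamaV2Cache_Q4(model, max_seq_len = 65536, lazy = True)
model.load_autosplit(cache, progress = True)
```
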
Next, to create the generator:

```python
from exllamav2.generator import ExLlamaV2DynamicGenerator

generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
)
```

There are a number of configuration options we could apply here relating to speculative decoding, the max
_allowed_ batch size and more. See the docstring for details. For now let's trust the defaults and do a
single completion:

```python
output = generator.generate(
    prompt = "Five good reasons to adopt a cat:",
    max_new_tokens = 200,
    add_bos = True
)

print(output)
```

Again, refer to the docstring for a complete list of optional arguments. A few of the important ones:

- `encode_special_tokens`: enable if you want special tokens in the prompt to be encoded as tokens rather than text
- `decode_special_tokens`: by default, special tokens in the output are not decoded and just become empty strings
- `stop_conditions`: a list of token IDs and/or strings that will end the sequence before `max_new_tokens` is reached
- `gen_settings`: sampler settings, as an `ExLlamaV2Sampler.Settings` object

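To illustrate the special-token options, here's a small sketch. The prompt formatting is just an example of
writing a control token (Llama3's `<|begin_of_text|>`) directly in the prompt text; adjust it to whatever
template your model actually expects:

```python
output = generator.generate(
    prompt = "<|begin_of_text|>Five good reasons to adopt a cat:",
    max_new_tokens = 200,
    encode_special_tokens = True,   # "<|begin_of_text|>" is encoded as one control token, not as text
    decode_special_tokens = True,   # special tokens in the output stay visible instead of becoming ""
    stop_conditions = [tokenizer.eos_token_id]
)
```
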
To generate multiple completions with dynamic batching, simply pass a list of strings as the prompt. Let's
also add sampler settings and a stop condition:

```python
from exllamav2.generator import ExLlamaV2Sampler

prompts = [
    "Five good reasons to adopt a cat:",
    "Here's why dogs are awful:",
    "Cats are better than dogs because"
]

gen_settings = ExLlamaV2Sampler.Settings(
    temperature = 0.9,
    top_p = 0.8,
    token_repetition_penalty = 1.025
)

outputs = generator.generate(
    prompt = prompts,
    max_new_tokens = 200,
    stop_conditions = [tokenizer.eos_token_id],
    gen_settings = gen_settings,
    add_bos = True
)

for o in outputs:
    print(o)
```

`gen_settings` can be one `ExLlamaV2Sampler.Settings` object to apply to all the jobs, or a list if you want
individual settings for each.

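For instance, to give each of the three prompts above its own temperature (`per_job_settings` is just an
illustrative variable name):

```python
per_job_settings = [
    ExLlamaV2Sampler.Settings(temperature = t, top_p = 0.8)
    for t in (0.6, 0.9, 1.2)
]

outputs = generator.generate(
    prompt = prompts,                    # the three prompts from above
    max_new_tokens = 200,
    gen_settings = per_job_settings,     # one Settings object per prompt
    add_bos = True
)
```
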
### Streaming mode

The `generate` function internally just creates a number of jobs and runs a streaming loop, collecting all the
outputs before returning all the completions in a batch. For more flexibility you can create jobs directly and
collect results as they're produced, either token by token or at the end of each job.

First, let's create some jobs. For added flexibility, the job constructor takes input IDs rather than a text
prompt, but other than that the arguments are largely the same as for `generate`:

```python
from exllamav2.generator import ExLlamaV2DynamicJob

for idx, prompt in enumerate(prompts):
    job = ExLlamaV2DynamicJob(
        input_ids = tokenizer.encode(prompt, add_bos = True),
        max_new_tokens = 200,
        stop_conditions = [tokenizer.eos_token_id],
        gen_settings = gen_settings,
        identifier = idx
    )
    generator.enqueue(job)
```

`identifier` is an optional, user-defined object that we can attach to each job. It will be returned with all
outputs pertaining to that job. We can use it for any purpose, but the idea is to have an easy way to link the
output back to the originating prompt, since the generator schedules the overall workload itself.

The jobs are now enqueued and the generator should be ready to start. To use it, simply call `iterate` in a loop
until the generator is out of things to do:

```python
# Somewhere to store the streaming results
collected_outputs = [""] * len(prompts)

while generator.num_remaining_jobs():
    results = generator.iterate()

    # iterate() always returns a list of zero or more result dicts
    for result in results:

        # Find out which job this result pertains to
        idx = result["identifier"]

        # The text key will only be present during the streaming stage and may be an empty string
        text_chunk = result.get("text", "")

        # Stream just the first sequence to the console (could get confusing otherwise)
        if idx == 0:
            print(text_chunk, end = "")

        # Collect all the outputs
        collected_outputs[idx] += text_chunk

print()
for idx, o in enumerate(collected_outputs):
    print(idx, o)
```

`iterate` also returns metrics and various other details, and can include logits, token probabilities and more,
depending on the job settings. See the docstring for details.

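As a hedged sketch of what that might look like, the snippet below pulls a few per-job details out of the
same loop when a job finishes. The key names are assumptions to verify against the docstring, so `.get()` is
used throughout:

```python
# Drop-in addition to the body of the iterate() loop above
for result in results:
    if result.get("eos"):
        print("job finished:", result.get("identifier"))
        print("  stop reason:", result.get("eos_reason"))
        print("  new tokens: ", result.get("new_tokens"))
        print("  gen time:   ", result.get("time_generate"), "seconds")
```
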
## To be continued

I may expand on this in the future, maybe with some performance benchmarks or something. Here's a cat: 🐈