Fix logic for determining the number of cache blocks (IBM#98)
When we deploy speculative decoding in production, we frequently see the
servers running out of free blocks. We have determined that this is due
to two issues:
1. The constraint on `SPECULATOR_MAX_BATCH_SIZE` is not enough to avoid
running into memory pressure due to speculation: we need to be able to
ensure that we do not speculate on batches that may have a small "size"
but a very large weight.
2. The computation of the number of KV cache blocks is incorrect in most
cases.
1. I have introduced an additional constraint: we only speculate on
batches whose weight is at most 75% of the weight limit. This should
ensure that we never speculate when we are close to the memory limit.
2. I have written new code to calculate the number of KV cache blocks.
This calculation uses the memory scaling coefficients that we learn at
startup. In particular, it uses the learned coefficients to figure out
what percentage of the memory capacity needs to be set aside for cache
blocks.
3. In the above calculation, I use the next-token coefficient rather
than the prefill coefficient, since during the next-token phase the KV
cache blocks typically comprise a relatively large percentage of total
memory consumption, and we need to be able to handle this worst case.
However, this means that during prefill steps we may not have enough
memory left over to store the auxiliary data structures we need for a
forward pass. There isn't really a clean way to handle this other than
rewriting the router logic to be block-aware, but what we can do is
recommend that the user increase the batch safety margin to a level that
ensures prefills will not run OOM. I've added a print statement to
provide this guidance.
4. I now load the speculator before learning the memory scaling model
since we also need to take that into account when measuring the amount
of free memory.
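The speculation gate described in point 1 could look roughly like the
sketch below. All names (`should_speculate`, the parameter names) are
hypothetical, not the actual identifiers in this PR; only the 75%
threshold and the existing `SPECULATOR_MAX_BATCH_SIZE` constraint come
from the description above.

```python
# Hypothetical sketch of the speculation gate: speculate only when the
# batch is small enough AND its weight is comfortably below the weight
# limit. The 0.75 factor mirrors the 75% threshold described above.
SPECULATION_WEIGHT_FRACTION = 0.75

def should_speculate(batch_weight: int, max_batch_weight: int,
                     batch_size: int, speculator_max_batch_size: int) -> bool:
    # Existing constraint on batch size, plus the new weight constraint.
    return (batch_size <= speculator_max_batch_size
            and batch_weight <= SPECULATION_WEIGHT_FRACTION * max_batch_weight)
```

A weight-based check like this catches batches that pass the size limit
but contain very long sequences.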
These changes, together with setting `BATCH_SAFETY_MARGIN=35`, seem to
result in robust behaviour for both `llama3-8b` and `granite-20b`. We no
longer need to manually set the number of KV cache blocks in the latter
case.
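The block-count calculation from point 2 could be sketched as below.
This is a minimal illustration assuming a linear memory model with a
learned next-token coefficient; the function and parameter names are
hypothetical, and the default margin mirrors `BATCH_SAFETY_MARGIN=35`.

```python
# Hypothetical sketch: estimate how many KV cache blocks fit in GPU
# memory, reserving worst-case decode-phase memory using the learned
# next-token coefficient (assumed linear in batch weight).
def num_kv_cache_blocks(
    total_memory: int,           # bytes of GPU memory available to the server
    next_token_coeff: float,     # learned bytes per unit of batch weight (decode)
    max_batch_weight: int,       # weight limit enforced by the router
    block_size_bytes: int,       # bytes per KV cache block
    safety_margin: float = 0.35, # reserved fraction, cf. BATCH_SAFETY_MARGIN=35
) -> int:
    # Worst-case memory the decode phase needs outside the KV cache.
    non_cache = next_token_coeff * max_batch_weight
    # Apply the safety margin, then give everything left to cache blocks.
    usable = total_memory * (1.0 - safety_margin) - non_cache
    return max(int(usable // block_size_bytes), 0)
```

Because the reservation is sized for the decode phase, a prefill step
with heavy auxiliary buffers can still exceed the remaining headroom,
which is why the safety-margin guidance above matters.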
n/a
---------
Signed-off-by: Thomas Parnell <[email protected]>