
Commit 9905650

tdoublep authored and Xaenalt committed
Fix logic for determining the number of cache blocks (IBM#98)
When we deploy spec decoding in prod., we are frequently seeing the servers run out of free blocks. We have determined that this is due to two issues:

1. The constraint on `SPECULATOR_MAX_BATCH_SIZE` is not enough to avoid running into memory pressure due to speculation: we also need to ensure that we do not speculate on batches that have a small "size" but a very large weight.
2. The computation of the number of cache blocks is very wrong in most cases.

This change addresses both:

1. I have introduced an additional constraint that we only speculate on batches with weight up to 75% of the weight limit. This should ensure that we never speculate when we are close to the memory limits.
2. I have written new code to calculate the number of KV cache blocks. This calculation uses the memory scaling coefficients that we learn at startup; in particular, it uses the learned coefficients to figure out what percentage of the memory capacity needs to be set aside for cache blocks.
3. In the above calculation, I use the next-token coefficient rather than the prefill coefficient, since during the next-token phase the KV cache blocks typically comprise a relatively large percentage of the total memory consumption, and we need to be able to handle this worst case. However, this means that during prefill steps we may not have enough memory left over to store the auxiliary data structures needed for a forward pass. There isn't really a clean way to handle this other than rewriting the router logic to be block-aware, but what we can do is recommend that the user increase the batch safety margin to a level that ensures prefills will not run OOM. I've added a print statement to provide this guidance.
4. I now load the speculator before learning the memory scaling model, since we also need to take it into account when measuring the amount of free memory.

These changes, together with setting `BATCH_SAFETY_MARGIN=35`, seem to result in robust behaviour for both `llama3-8b` and `granite-20b`. We no longer need to manually set the number of KV cache blocks in the latter case.

---------

Signed-off-by: Thomas Parnell <[email protected]>
1 parent e908eec commit 9905650

6 files changed: +95 −35 lines

server/text_generation_server/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,7 @@ def get_model(
     dtype_str: str,
     quantize: Optional[str],
     max_sequence_length: Optional[int],
+    memory_scaling_model: Optional[int] = None,
 ) -> Model:
     dtype = get_torch_dtype(dtype_str)
     model_path = get_model_path(model_name, revision)
@@ -74,6 +75,7 @@ def get_model(
         dtype, quantize,
         model_config,
         max_sequence_length=max_sequence_length,
+        memory_scaling_model=memory_scaling_model,
     )

     if FLASH_ATTENTION:

server/text_generation_server/models/custom_modeling/paged_llama_modeling.py

Lines changed: 3 additions & 0 deletions
@@ -434,6 +434,9 @@ def __init__(self, config, weights):
             weights=weights,
         )

+    def get_kv_cache_block_size(self, block_size: int) -> int:
+        return block_size * self.model.num_key_value_heads * self.model.head_size * 2
+
     def get_input_embeddings(self) -> nn.Module:
         return self.model.embed_tokens
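To get a sense of the quantity this helper feeds into, here is a rough back-of-the-envelope sketch (not part of the commit) of the per-block KV cache footprint, using illustrative llama3-8b-like values; the real numbers come from the model config and dtype:

# Sketch only: per-block KV cache footprint implied by get_kv_cache_block_size.
# All config values below are assumed for illustration.
block_size = 16            # tokens per paged-attention cache block
num_key_value_heads = 8    # assumed GQA KV-head count
head_size = 128            # assumed head dimension
num_hidden_layers = 32     # assumed number of layers
dtype_size = 2             # bytes per element (fp16/bf16)

elements_per_layer = block_size * num_key_value_heads * head_size * 2   # K and V
cache_block_bytes = dtype_size * num_hidden_layers * elements_per_layer
print(cache_block_bytes)   # 2097152 bytes, i.e. 2 MiB per 16-token block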

server/text_generation_server/models/custom_modeling/paged_santacoder_modeling.py

Lines changed: 3 additions & 0 deletions
@@ -407,6 +407,9 @@ def __init__(self, config, weights):
             config, prefix="transformer.wte", weights=weights
         )

+    def get_kv_cache_block_size(self, block_size: int) -> int:
+        return block_size * self.transformer.head_size * 2
+
     def get_input_embeddings(self) -> nn.Module:
         return self.transformer.wte
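For comparison with the llama version above: this model uses multi-query attention, so there is a single shared key/value head and no head-count factor in the formula. A quick sketch with an assumed head size:

# Sketch only: per-layer elements in one cache block under multi-query attention.
block_size, head_size = 16, 128                    # head_size assumed for illustration
elements_per_layer = block_size * head_size * 2    # 4096 (K and V), vs 32768 in the GQA sketch above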

server/text_generation_server/models/paged_causal_lm.py

Lines changed: 52 additions & 30 deletions
@@ -18,16 +18,14 @@
 from text_generation_server.utils.token_types import TokenInfo, InputTokens
 from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser, get_token_info, get_input_tokens_info
 from text_generation_server.utils.paged import (
+    load_speculator,
     prepare_inputs_without_speculation,
     prepare_inputs_with_speculation,
     process_outputs_with_speculation,
     prepare_inputs_for_prefill
 )
 from text_generation_server.inference_engine import get_inference_engine_class

-# HF name or path to speculator model (None means no speculation will be used)
-SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)
-
 # we will only do speculation if the batch size is <= this parameter
 SPECULATOR_MAX_BATCH_SIZE = int(os.getenv("SPECULATOR_MAX_BATCH_SIZE", "16"))

@@ -277,6 +275,7 @@ def __init__(
         quantize: Optional[str],
         model_config: Union[Any] = None,
         max_sequence_length: Optional[int] = None,
+        memory_scaling_model: Optional["MemoryScalingModel"] = None,
     ):
         model_path = get_model_path(model_name, revision)

@@ -300,27 +299,41 @@ def __init__(

         from fms_extras.utils.cache.paged import PagedKVCacheManager

-        if SPECULATOR_NAME is not None:
-            from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
-            speculator_revision = os.getenv("SPECULATOR_REVISION", None)
-            speculator_model_path = get_model_path(SPECULATOR_NAME, speculator_revision)
-            print_rank_n(f"Loading speculator model from: {speculator_model_path}")
+        # load speculator
+        self.speculator = load_speculator(self.device, dtype)
+
+        if self.speculator is not None:
             print_rank_n(f"Speculation will be enabled up to batch size {SPECULATOR_MAX_BATCH_SIZE}")
-            kwargs = {
-                "pretrained_model_name_or_path": speculator_model_path,
-                "local_files_only": True,
-                "torch_dtype": dtype,
-            }
-            with self.device:
-                self.speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
-            self.speculator.to(device=self.device)
-        else:
-            self.speculator = None
+
+        block_size = 16

         if KV_CACHE_MANAGER_NUM_GPU_BLOCKS is not None:
             total_num_gpu_blocks = int(KV_CACHE_MANAGER_NUM_GPU_BLOCKS)
         else:
-            total_num_gpu_blocks = None
+            # Firstly, let's compute the size of a cache block in bytes
+            kv_cache_block_size = self.model.get_kv_cache_block_size(block_size)
+            total_size = model_config.num_hidden_layers * kv_cache_block_size
+            dtype_size = torch.tensor([], dtype=dtype).element_size()
+            cache_block_size = dtype_size * total_size
+            # We then use our memory scaling model to determine the fraction of the prefill memory
+            # usage that is due to cache blocks (as opposed to the other stuff needed for forward):
+            pf_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.linear_fit_params[0]
+            # We can then do the same for the next token (decoding) step:
+            nt_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]
+            # In general we know that the next token phase can use many more cache blocks
+            # relative to the prefill phase (i.e., nt_cache_block_ratio > pf_cache_block_ratio).
+            # Thus, we need to allocate enough cache blocks to handle the more extreme case:
+            total_num_gpu_blocks = int(nt_cache_block_ratio * memory_scaling_model.free_memory // cache_block_size)
+            # This creates an issue though, because if we then try to perform a large prefill, while we
+            # will certainly have enough cache blocks available, we may not have enough memory leftover
+            # to allocate the other data structures needed during a forward pass.
+            # To overcome this, we can set the batch_safety_margin a bit higher to ensure that:
+            #   free_memory * (1.0-batch_safety_margin/100-0.05) * (1.0-pf_cache_block_ratio) <
+            #   free_memory * (1.0-nt_cache_block_ratio)
+            # This should ensure that our prefill batches can never get so big as to cause OOM.
+            recommend_safety_margin = 5 + int(100*(1.0 - (1.0 - nt_cache_block_ratio)/(1.0 - pf_cache_block_ratio)))
+            if memory_scaling_model.safety_margin < recommend_safety_margin:
+                print(f"WARN: We recommend increasing the value of BATCH_SAFETY_MARGIN to: {recommend_safety_margin}")

         self.kv_cache_manager = PagedKVCacheManager(
             model_config.num_hidden_layers,
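Aside (not part of the diff): plugging some made-up but plausible numbers into the hunk above shows how the learned coefficients turn into a block budget and a recommended safety margin. Here `linear_fit_params[0]`, `next_token_params[1]` and `free_memory` are assumed stand-ins for the learned memory scaling model:

# Worked example of the block-count calculation above; all inputs are illustrative.
block_size = 16
cache_block_size = 2_097_152           # bytes per block (e.g. the 2 MiB llama-like figure above)
linear_fit_params_0 = 262_144          # assumed prefill bytes-per-token coefficient
next_token_params_1 = 204_800          # assumed next-token bytes-per-token coefficient
free_memory = 68_719_476_736           # assumed free memory after loading weights (64 GiB)

pf_cache_block_ratio = cache_block_size / block_size / linear_fit_params_0   # 0.5
nt_cache_block_ratio = cache_block_size / block_size / next_token_params_1   # 0.64

total_num_gpu_blocks = int(nt_cache_block_ratio * free_memory // cache_block_size)
print(total_num_gpu_blocks)            # 20971 blocks -> roughly 41 GiB reserved for KV cache

recommend_safety_margin = 5 + int(100 * (1.0 - (1.0 - nt_cache_block_ratio) / (1.0 - pf_cache_block_ratio)))
print(recommend_safety_margin)         # 33 -> warn if BATCH_SAFETY_MARGIN is set below this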
@@ -331,8 +344,14 @@ def __init__(
             dtype=dtype,
             device=self.device,
             total_num_gpu_blocks=total_num_gpu_blocks,
+            block_size=block_size,
         )

+        self.memory_scaling_model = memory_scaling_model
+
+        # log number of free blocks at init
+        print("[PagedKVCacheManager] number of free blocks: %d" % (len(self.kv_cache_manager.free_blocks)))
+
     @property
     def batch_type(self) -> Type[PagedCausalLMBatch]:
         return self._batch_type
@@ -410,12 +429,18 @@ def _prefill(
         )

         t0 = time.time_ns()
-        output = self.model(
-            input_ids,
-            position_ids=position_ids,
-            cache_data=cache_data,
-            return_embeds=True,
-        )
+        try:
+            output = self.model(
+                input_ids,
+                position_ids=position_ids,
+                cache_data=cache_data,
+                return_embeds=True,
+            )
+        except:
+            # if something goes wrong during forward, we still need to set the sequence ids
+            # TODO it would be better to fix the forward method to avoid possibility of partial failures
+            batch.sequence_ids = cache_data.sequence_ids
+            raise
         t_forward_ns = time.time_ns()-t0
         logits, embeds = output

@@ -600,10 +625,7 @@ def generate_token(
             )
         else:
             bsize = batch.input_ids.shape[0]
-
-            tokens_remaining = 0
-            for i in range(len(batch.total_lengths)):
-                tokens_remaining += batch.total_lengths[i] - batch.input_lengths[i]
+            weight = sum(batch.total_lengths) * self.memory_scaling_model.next_token_params[1]

             spec_ind = []
             for i, sample in enumerate(batch.next_token_chooser.do_sample):
@@ -615,7 +637,7 @@ def generate_token(
                 len(spec_ind) > 0 and
                 bsize <= SPECULATOR_MAX_BATCH_SIZE and
                 batch.next_token_chooser.repetition_processor is None and
-                tokens_remaining < 0.25*len(self.kv_cache_manager.free_blocks)*self.kv_cache_manager.block_size
+                (weight/self.memory_scaling_model.weight_limit) <= 0.75
             )

             if speculate:

server/text_generation_server/server.py

Lines changed: 8 additions & 5 deletions
@@ -56,7 +56,7 @@ async def func_with_log(*args, **kwargs):


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
-    def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModelPB):
+    def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModel):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
@@ -81,7 +81,7 @@ async def ModelInfo(self, request: generate_pb2.ModelInfoRequest, context) -> ge
             if isinstance(self.model, Seq2SeqLM) else ModelInfoResponse.ModelType.CAUSAL_LM,
             eos_token=self.model.config.eos_token_id,
             batch_padding=not isinstance(self.model, FlashCausalLM),
-            memory_scaling_model=self.memory_scaling_model,
+            memory_scaling_model=self.memory_scaling_model.as_pb(),
         )

     @log_rpc_handler_errors
@@ -234,8 +234,9 @@ def _free_paged_sequences(self, batch: "Batch", completed_ids: Optional[List[int
             ]
         else:
             return
-        self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

+        if sequence_ids_to_free is not None:
+            self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

 def serve(
     model_name: str,
@@ -276,6 +277,8 @@ async def serve_inner(
         proc.start()
         memory_scaling_model_ext = q_out.get()
         proc.join()
+    else:
+        memory_scaling_model_ext = None

     unix_socket_template = "unix://{}-{}"
     world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -307,7 +310,7 @@ async def serve_inner(
         torch.cuda.set_per_process_memory_fraction(cuda_process_memory_fraction)

     model = get_model(
-        model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
+        model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length, memory_scaling_model_ext,
     )

     device = model.engine.get_device()
@@ -415,7 +418,7 @@ def estimate_memory():

     server = aio.server()
     generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
-        TextGenerationService(model, Cache(), server_urls, memory_scaling_model.as_pb()), server
+        TextGenerationService(model, Cache(), server_urls, memory_scaling_model), server
     )
     # SERVICE_NAMES = (
     #     generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,

server/text_generation_server/utils/paged.py

Lines changed: 27 additions & 0 deletions
@@ -5,12 +5,37 @@

 from fms_extras.models.speculator import flatten_batch, apply_index_map

+# HF name or path to speculator model (None means no speculation will be used)
+SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)
+
+# speculator revision
+SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)
+
 # number of candidates during speculation
 SPECULATOR_N_CANDIDATES = os.getenv("SPECULATOR_N_CANDIDATES", None)

 # number of candidates per head
 SPECULATOR_TOP_K_TOKENS_PER_HEAD = os.getenv("SPECULATOR_TOP_K_TOKENS_PER_HEAD", None)

+def load_speculator(device, dtype):
+
+    if SPECULATOR_NAME is not None:
+        from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
+        from text_generation_server.utils.hub import get_model_path
+        from text_generation_server.utils import print_rank_n
+        speculator_model_path = get_model_path(SPECULATOR_NAME, SPECULATOR_REVISION)
+        print_rank_n(f"Loading speculator model from: {speculator_model_path}")
+        kwargs = {
+            "pretrained_model_name_or_path": speculator_model_path,
+            "local_files_only": True,
+            "torch_dtype": dtype,
+        }
+        with device:
+            speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
+        speculator.to(device=device)
+        return speculator
+    else:
+        return None

 def fit_memory_scaling_model(
     model_name: str,
@@ -38,6 +63,8 @@ def fit_memory_scaling_model(
         model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
     )

+    speculator = load_speculator(model.device, model.dtype)
+
     memory_scaling_model = Estimator.build_from_env(
         model,
         batch_safety_margin,
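For context (not part of the diff), speculation is configured entirely through environment variables; a hypothetical setup might look as follows. Because `fit_memory_scaling_model` now calls `load_speculator` before building the estimator, the speculator's weights are included when free memory is measured:

# Hypothetical configuration via environment variables; the speculator id is a placeholder.
import os

os.environ["SPECULATOR_NAME"] = "my-org/my-mlp-speculator"   # placeholder HF name or local path
os.environ["SPECULATOR_REVISION"] = "main"
os.environ["SPECULATOR_MAX_BATCH_SIZE"] = "16"
os.environ["BATCH_SAFETY_MARGIN"] = "35"   # value the commit message reports working well

# load_speculator(device, dtype) returns None when SPECULATOR_NAME is unset,
# so both fit_memory_scaling_model and PagedCausalLM simply skip speculation in that case.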
