Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion megatron/core/inference/contexts/dynamic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2879,10 +2879,14 @@ def add_request(
num_matched_blocks = len(matched_block_ids)
effective_kv_offset = req.finished_chunk_token_count + prefix_skip_tokens

# Track prefix cache hits.
# Track prefix cache hits. num_cached_tokens accumulates across prefill
# chunks: each chunk matches a disjoint block range (start advances with
# finished_chunk_token_count), so a long cached prefix is discovered
# incrementally and must be summed, not overwritten.
if num_matched_blocks > 0:
self.prefix_cache_hits += 1
self.prefix_cache_blocks_matched += num_matched_blocks
req.num_cached_tokens += num_matched_blocks * self.block_size_tokens

# Slice tokens to skip matched prefix
this_round_tokens = req.remaining_prompt_tokens[prefix_skip_tokens:prefill_chunk_length]
Expand Down
2 changes: 2 additions & 0 deletions megatron/core/inference/inference_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ class DynamicInferenceRequest(InferenceRequest):
# Prefix caching fields
block_size_tokens: Optional[int] = None # Block size for hash computation
enable_prefix_caching: bool = False # Whether prefix caching is enabled
num_cached_tokens: int = 0 # Tokens served from prefix cache (set by context on first match)

# Computed field - not passed by caller
precomputed_block_hashes: List[int] = field(default_factory=list)
Expand Down Expand Up @@ -740,6 +741,7 @@ def merge_lists(key):
block_size_tokens=self.requests[0].block_size_tokens,
enable_prefix_caching=self.requests[0].enable_prefix_caching,
precomputed_block_hashes=self.requests[0].precomputed_block_hashes,
num_cached_tokens=self.requests[0].num_cached_tokens,
)

return request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ async def chat_completions():
choices = []
total_completion_tokens = 0
prompt_tokens_counts = []
cached_tokens_counts = []

request_idx = 0
for result_item in batch_results:
Expand All @@ -671,6 +672,7 @@ async def chat_completions():
text_output = result["generated_text"]
prompt_tokens_count = len(prompt_tokens_out) if prompt_tokens_out is not None else 0
prompt_tokens_counts.append(prompt_tokens_count)
cached_tokens_counts.append(result.get("num_cached_tokens", 0))

logprobs_content = None
if sampling_params.return_log_probs:
Expand Down Expand Up @@ -791,6 +793,7 @@ async def chat_completions():
request_idx += 1

prompt_token_count = max(prompt_tokens_counts) if prompt_tokens_counts else 0
cached_token_count = max(cached_tokens_counts) if cached_tokens_counts else 0
response = {
"id": f"chatcmpl-{uuid.uuid4().hex}",
"created": int(time.time()),
Expand All @@ -801,6 +804,9 @@ async def chat_completions():
"prompt_tokens": prompt_token_count,
"completion_tokens": total_completion_tokens,
"total_tokens": prompt_token_count + total_completion_tokens,
"prompt_tokens_details": {
"cached_tokens": cached_token_count,
},
},
}

Expand Down