Shortfin LLM Debug Ergonomics: one flag to dump them all (nod-ai#668)
Thanks @stbaione for originally writing a lot of the debug code in here.
I organized his debug logging & dumping into a separate file & made the
output easier to deal with.
renxida authored and IanNod committed Dec 17, 2024
1 parent 70aad72 commit 4c01c4d
Showing 4 changed files with 239 additions and 3 deletions.
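
In practice the new behavior is controlled by a single environment variable, SHORTFIN_DEBUG_LLM_SERVICE, checked in the service.py hunk below. A minimal launch sketch follows; only the variable name and its accepted truthy values come from this commit, while the server entrypoint shown is an assumption.

# Sketch: enable the per-invocation debug dumps added by this commit.
# SHORTFIN_DEBUG_LLM_SERVICE and its accepted values ("true", "yes", "1", "y",
# case-insensitive) come from the diff; the entrypoint below is an assumption.
import os
import subprocess

env = dict(os.environ, SHORTFIN_DEBUG_LLM_SERVICE="1")
subprocess.run(
    ["python", "-m", "shortfin_apps.llm.server"],  # assumed server module
    env=env,
    check=True,
)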
6 changes: 5 additions & 1 deletion shortfin/python/shortfin_apps/llm/components/generate.py
@@ -49,7 +49,11 @@ def __init__(
         self.eos_token_id = eos_token_id
 
     async def run(self):
-        exec = InferenceExecRequest(InferencePhase.PREFILL, self.input_token_ids)
+        exec = InferenceExecRequest(
+            phase=InferencePhase.PREFILL,
+            input_token_ids=self.input_token_ids,
+            rid=self.gen_req.rid,
+        )
         try:
             self.client.batcher.submit(exec)
             await exec.done
3 changes: 2 additions & 1 deletion shortfin/python/shortfin_apps/llm/components/messages.py
@@ -21,12 +21,13 @@ class InferencePhase(Enum):
 class InferenceExecRequest(sf.Message):
     """Performs a prefill operation."""
 
-    def __init__(self, phase: InferencePhase, input_token_ids: list[int]):
+    def __init__(self, phase: InferencePhase, input_token_ids: list[int], rid=None):
         super().__init__()
         self.phase = phase
         self.start_position: int = 0
         self.input_token_ids = input_token_ids
         self.done = sf.VoidFuture()
+        self.rid = rid
 
         # Response control.
         # If True, return all sequence position logits. If False, return only
17 changes: 16 additions & 1 deletion shortfin/python/shortfin_apps/llm/components/service.py
@@ -29,6 +29,9 @@
     isolation.name.lower(): isolation for isolation in sf.ProgramIsolation
 }
 
+import os
+from .service_debug_dumper import SERVICE_DEBUG_DUMPER
+
 
 class GenerateService:
     """Top level service interface for generating text against a model."""
@@ -438,7 +441,19 @@ async def run(self):
                 fn,
                 "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(args)]),
             )
-            # Invoke. Logits are of shape [bs, bsl, d].
+
+            # pre-invocation args dump
+            if os.getenv("SHORTFIN_DEBUG_LLM_SERVICE", "False").lower() in (
+                "true",
+                "yes",
+                "1",
+                "y",
+            ):
+                await SERVICE_DEBUG_DUMPER.pre_invocation_debug_dump(
+                    executor=self, local_vars=locals()
+                )
+
+            # Invoke VMFB. Logits are of shape [bs, bsl, d].
             (logits,) = await fn(*args, fiber=self.fiber)
 
             # publish cache pages
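For reference, the gate above treats only a handful of strings as truthy. A standalone sketch of the same check is shown here; the helper name _llm_service_debug_enabled is hypothetical, while the variable name, default, and accepted values are taken from the hunk above.

import os


def _llm_service_debug_enabled() -> bool:
    """Mirror of the service.py gate: anything outside this set disables dumps."""
    return os.getenv("SHORTFIN_DEBUG_LLM_SERVICE", "False").lower() in (
        "true",
        "yes",
        "1",
        "y",
    )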
216 changes: 216 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py
@@ -0,0 +1,216 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import asyncio
import logging
from datetime import datetime
from pathlib import Path
import json
import numpy as np
import pandas as pd
from typing import Dict, Any
from pprint import pformat

logger = logging.getLogger(__name__)


class ServiceDebugDumper:
    def __init__(self):
        """Initialize debug service with a new dump directory for this session."""
        self.dump_id = 0
        self.boot_timestamp = datetime.now().isoformat()
        self.debug_data_dir = Path.home() / ".shortfin/debug/"
        self.dump_dir = (
            self.debug_data_dir / "llm_service_invocation_dumps" / self.boot_timestamp
        )
        self.dump_dir.mkdir(parents=True, exist_ok=False)
        logger.info(
            f"[debug_service.py] Please find debug dumps for service.py in {self.dump_dir}"
        )

    async def pre_invocation_debug_dump(
        self, executor: "InferenceExecutorProcess", local_vars: Dict[str, Any]
    ):
        """Comprehensive debug dump before inference invocation."""
        # Extract variables from locals
        is_decode = local_vars["is_decode"]
        device0 = local_vars["device0"]
        fn = local_vars["fn"]
        req_bs = local_vars["req_bs"]
        bsl = local_vars["bsl"]
        seq_stride = local_vars["seq_stride"]
        block_count = local_vars["block_count"]
        req_count = local_vars["req_count"]
        tokens = local_vars["tokens"]
        start_positions = local_vars.get("start_positions")
        seq_lens = local_vars["seq_lens"]
        seq_block_ids = local_vars["seq_block_ids"]
        args = local_vars["args"]

        phase = executor.phase
        exec_requests = executor.exec_requests
        model_params = executor.service.model_params

        dump_path = self.dump_dir / f"{self.dump_id}"
        dump_path.mkdir(parents=True, exist_ok=True)

        # Prepare debug info dictionary
        debug_info = {
            "metadata": {
                "dump_id": self.dump_id,
                "dump_timestamp": datetime.now().isoformat(),
                "phase": str(phase),
                "is_decode": is_decode,
                "device": str(device0),
                "function": str(fn),
            },
            "batch_info": {
                "request_batch_size": req_bs,
                "block_sequence_length": int(bsl),
                "sequence_stride": seq_stride,
                "block_count": block_count,
                "actual_request_count": req_count,
            },
            "requests": [
                {
                    "index": i,
                    "start_position": req.start_position,
                    "rid": req.rid,
                    "input_token_ids": req.input_token_ids.tolist()
                    if hasattr(req.input_token_ids, "tolist")
                    else list(req.input_token_ids),
                    "input_length": len(req.input_token_ids),
                    "cache_pages": req.cache_page_indices(block_count),
                }
                for i, req in enumerate(exec_requests)
            ],
            "tensor_shapes": {
                "tokens": tokens.shape,
                **({"start_positions": start_positions.shape} if is_decode else {}),
                "seq_lens": seq_lens.shape,
                "seq_block_ids": seq_block_ids.shape,
            },
            "tensor_values": {
                "tokens": tokens.for_transfer().items.tolist()
                if hasattr(tokens.for_transfer().items, "tolist")
                else list(tokens.for_transfer().items),
                **(
                    {
                        "start_positions": start_positions.for_transfer().items.tolist()
                        if hasattr(start_positions.for_transfer().items, "tolist")
                        else list(start_positions.for_transfer().items)
                    }
                    if is_decode
                    else {}
                ),
                "sequence_lengths": seq_lens.for_transfer().items.tolist()
                if hasattr(seq_lens.for_transfer().items, "tolist")
                else list(seq_lens.for_transfer().items),
                "sequence_block_ids": seq_block_ids.for_transfer().items.tolist()
                if hasattr(seq_block_ids.for_transfer().items, "tolist")
                else list(seq_block_ids.for_transfer().items),
            },
            "model_config": {
                "prefill_batch_sizes": model_params.prefill_batch_sizes,
                "decode_batch_sizes": model_params.decode_batch_sizes,
                "attn_dtype": str(model_params.attn_dtype),
                "paged_kv_cache": {
                    "device_block_count": model_params.paged_kv_cache.device_block_count,
                    "block_seq_stride": model_params.paged_kv_cache.block_seq_stride,
                    "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm,
                },
            },
        }

        # Save debug info as JSON
        with open(dump_path / "info.json", "w") as f:
            json.dump(debug_info, f, indent=2)

        # Save program arguments
        path = dump_path
        args_np = []
        for i, a in enumerate(args):
            host_array = a.for_transfer()
            host_array.copy_from(a)
            await a.device
            args_np.append(np.array(host_array))

        # Save binary numpy arrays
        for i, arr in enumerate(args_np):
            np.save(path / f"{i}.npy", arr)

        # Generate human-readable report
        with open(path / "saved_program_args.txt", "w") as f:
            for i, arr in enumerate(args_np):
                f.write(f"\n{'='*80}\n")
                f.write(f"{i}.npy:\n")
                f.write(f"{'='*80}\n\n")

                # Basic info
                f.write(f"Shape: {arr.shape}\n")
                f.write(f"Dtype: {arr.dtype}\n")
                f.write(f"Total elements: {arr.size}\n")
                f.write(f"Dimensions: {arr.ndim}\n\n")

                # Stats
                f.write("Statistics:\n")
                nan_count = np.count_nonzero(np.isnan(arr))
                inf_count = np.count_nonzero(np.isinf(arr))
                f.write(f"- NaN count: {nan_count}\n")
                f.write(f"- Inf count: {inf_count}\n")

                if nan_count == 0 and inf_count == 0:
                    f.write(f"- Min: {np.min(arr)}\n")
                    f.write(f"- Max: {np.max(arr)}\n")
                    f.write(f"- Mean: {np.mean(arr):.6f}\n")
                    f.write(f"- Median: {np.median(arr):.6f}\n")
                    f.write(f"- Range: {np.ptp(arr)}\n")
                    try:
                        mode = pd.Series(arr.flatten()).mode().iloc[0]
                        f.write(f"- Mode: {mode}\n")
                    except:
                        f.write("- Mode: Unable to compute\n")

                    if np.issubdtype(arr.dtype, np.number):
                        try:
                            hist, bins = np.histogram(arr.flatten(), bins="auto")
                            f.write("\nHistogram:\n")
                            f.write(
                                "Bins: "
                                + pformat(bins.tolist(), width=80, compact=True)
                                + "\n"
                            )
                            f.write(
                                "Counts: "
                                + pformat(hist.tolist(), width=80, compact=True)
                                + "\n"
                            )
                        except Exception as e:
                            f.write(f"\nHistogram computation failed: {str(e)}\n")
                else:
                    f.write("Skipping additional statistics due to NaN/Inf values\n")

                f.write("\nArray contents:\n")
                if arr.size <= 64:
                    formatted = pformat(arr.tolist(), width=80, compact=True)
                    f.write(formatted + "\n")
                else:
                    f.write("\nFirst 5 elements:\n")
                    f.write(
                        pformat(arr.flatten()[:5].tolist(), width=80, compact=True)
                        + "\n"
                    )
                    f.write("\nLast 5 elements:\n")
                    f.write(
                        pformat(arr.flatten()[-5:].tolist(), width=80, compact=True)
                        + "\n"
                    )

        self.dump_id += 1


# Create single instance
SERVICE_DEBUG_DUMPER = ServiceDebugDumper()
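
Each gated invocation therefore produces a numbered directory under ~/.shortfin/debug/llm_service_invocation_dumps/<boot timestamp>/ containing info.json, one <i>.npy file per program argument, and a human-readable saved_program_args.txt. Below is a minimal sketch for inspecting a dump offline; picking the latest session and the first invocation is illustrative and not part of the commit.

# Sketch: load one invocation dump written by ServiceDebugDumper.
import json
from pathlib import Path

import numpy as np

dump_root = Path.home() / ".shortfin/debug/llm_service_invocation_dumps"
session = sorted(dump_root.iterdir())[-1]  # most recent boot session
invocation = session / "0"                 # first invocation dump of that session

with open(invocation / "info.json") as f:
    info = json.load(f)
print(info["metadata"]["phase"], info["batch_info"])

# Program arguments are saved as 0.npy, 1.npy, ... alongside the text report.
for npy in sorted(invocation.glob("*.npy")):
    arr = np.load(npy)
    print(npy.name, arr.shape, arr.dtype)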
