diff --git a/shortfin/python/shortfin_apps/llm/components/generate.py b/shortfin/python/shortfin_apps/llm/components/generate.py
index 698f779fb..9e9fea692 100644
--- a/shortfin/python/shortfin_apps/llm/components/generate.py
+++ b/shortfin/python/shortfin_apps/llm/components/generate.py
@@ -49,7 +49,11 @@ def __init__(
         self.eos_token_id = eos_token_id
 
     async def run(self):
-        exec = InferenceExecRequest(InferencePhase.PREFILL, self.input_token_ids)
+        exec = InferenceExecRequest(
+            phase=InferencePhase.PREFILL,
+            input_token_ids=self.input_token_ids,
+            rid=self.gen_req.rid,
+        )
         try:
             self.client.batcher.submit(exec)
             await exec.done
diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py
index 9e2ab7179..724f71569 100644
--- a/shortfin/python/shortfin_apps/llm/components/messages.py
+++ b/shortfin/python/shortfin_apps/llm/components/messages.py
@@ -21,12 +21,13 @@ class InferencePhase(Enum):
 class InferenceExecRequest(sf.Message):
     """Performs a prefill operation."""
 
-    def __init__(self, phase: InferencePhase, input_token_ids: list[int]):
+    def __init__(self, phase: InferencePhase, input_token_ids: list[int], rid=None):
         super().__init__()
         self.phase = phase
         self.start_position: int = 0
         self.input_token_ids = input_token_ids
         self.done = sf.VoidFuture()
+        self.rid = rid
 
     # Response control.
     # If True, return all sequence position logits. If False, return only
diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py
index 71d54c234..9ef8b5c4d 100644
--- a/shortfin/python/shortfin_apps/llm/components/service.py
+++ b/shortfin/python/shortfin_apps/llm/components/service.py
@@ -29,6 +29,9 @@
     isolation.name.lower(): isolation for isolation in sf.ProgramIsolation
 }
 
+import os
+from .service_debug_dumper import SERVICE_DEBUG_DUMPER
+
 
 class GenerateService:
     """Top level service interface for generating text against a model."""
@@ -438,7 +441,19 @@ async def run(self):
                 fn,
                 "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(args)]),
             )
-            # Invoke. Logits are of shape [bs, bsl, d].
+
+            # pre-invocation args dump
+            if os.getenv("SHORTFIN_DEBUG_LLM_SERVICE", "False").lower() in (
+                "true",
+                "yes",
+                "1",
+                "y",
+            ):
+                await SERVICE_DEBUG_DUMPER.pre_invocation_debug_dump(
+                    executor=self, local_vars=locals()
+                )
+
+            # Invoke VMFB. Logits are of shape [bs, bsl, d].
             (logits,) = await fn(*args, fiber=self.fiber)
 
             # publish cache pages
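Note: the dump above is gated entirely by the SHORTFIN_DEBUG_LLM_SERVICE environment variable and is skipped otherwise. A minimal sketch of the same truthiness check, useful for confirming which values actually enable it (the helper name debug_dump_enabled is illustrative, not part of this change):

    # Sketch only: mirrors the inline gate added to the executor's run() above.
    import os

    def debug_dump_enabled() -> bool:
        # "False" is the default, so dumping stays off unless explicitly requested.
        return os.getenv("SHORTFIN_DEBUG_LLM_SERVICE", "False").lower() in ("true", "yes", "1", "y")

    os.environ["SHORTFIN_DEBUG_LLM_SERVICE"] = "YES"  # any of "true"/"yes"/"1"/"y", case-insensitive
    assert debug_dump_enabled()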
diff --git a/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py b/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py
new file mode 100644
index 000000000..ae492eab3
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py
@@ -0,0 +1,216 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import asyncio
+import logging
+from datetime import datetime
+from pathlib import Path
+import json
+import numpy as np
+import pandas as pd
+from typing import Dict, Any
+from pprint import pformat
+
+logger = logging.getLogger(__name__)
+
+
+class ServiceDebugDumper:
+    def __init__(self):
+        """Initialize debug service with a new dump directory for this session."""
+        self.dump_id = 0
+        self.boot_timestamp = datetime.now().isoformat()
+        self.debug_data_dir = Path.home() / ".shortfin/debug/"
+        self.dump_dir = (
+            self.debug_data_dir / "llm_service_invocation_dumps" / self.boot_timestamp
+        )
+        self.dump_dir.mkdir(parents=True, exist_ok=False)
+        logger.info(
+            f"[debug_service.py] Please find debug dumps for service.py in {self.dump_dir}"
+        )
+
+    async def pre_invocation_debug_dump(
+        self, executor: "InferenceExecutorProcess", local_vars: Dict[str, Any]
+    ):
+        """Comprehensive debug dump before inference invocation."""
+        # Extract variables from locals
+        is_decode = local_vars["is_decode"]
+        device0 = local_vars["device0"]
+        fn = local_vars["fn"]
+        req_bs = local_vars["req_bs"]
+        bsl = local_vars["bsl"]
+        seq_stride = local_vars["seq_stride"]
+        block_count = local_vars["block_count"]
+        req_count = local_vars["req_count"]
+        tokens = local_vars["tokens"]
+        start_positions = local_vars.get("start_positions")
+        seq_lens = local_vars["seq_lens"]
+        seq_block_ids = local_vars["seq_block_ids"]
+        args = local_vars["args"]
+
+        phase = executor.phase
+        exec_requests = executor.exec_requests
+        model_params = executor.service.model_params
+
+        dump_path = self.dump_dir / f"{self.dump_id}"
+        dump_path.mkdir(parents=True, exist_ok=True)
+
+        # Prepare debug info dictionary
+        debug_info = {
+            "metadata": {
+                "dump_id": self.dump_id,
+                "dump_timestamp": datetime.now().isoformat(),
+                "phase": str(phase),
+                "is_decode": is_decode,
+                "device": str(device0),
+                "function": str(fn),
+            },
+            "batch_info": {
+                "request_batch_size": req_bs,
+                "block_sequence_length": int(bsl),
+                "sequence_stride": seq_stride,
+                "block_count": block_count,
+                "actual_request_count": req_count,
+            },
+            "requests": [
+                {
+                    "index": i,
+                    "start_position": req.start_position,
+                    "rid": req.rid,
+                    "input_token_ids": req.input_token_ids.tolist()
+                    if hasattr(req.input_token_ids, "tolist")
+                    else list(req.input_token_ids),
+                    "input_length": len(req.input_token_ids),
+                    "cache_pages": req.cache_page_indices(block_count),
+                }
+                for i, req in enumerate(exec_requests)
+            ],
+            "tensor_shapes": {
+                "tokens": tokens.shape,
+                **({"start_positions": start_positions.shape} if is_decode else {}),
+                "seq_lens": seq_lens.shape,
+                "seq_block_ids": seq_block_ids.shape,
+            },
+            "tensor_values": {
+                "tokens": tokens.for_transfer().items.tolist()
+                if hasattr(tokens.for_transfer().items, "tolist")
+                else list(tokens.for_transfer().items),
+                **(
+                    {
+                        "start_positions": start_positions.for_transfer().items.tolist()
+                        if hasattr(start_positions.for_transfer().items, "tolist")
+                        else list(start_positions.for_transfer().items)
+                    }
+                    if is_decode
+                    else {}
+                ),
+                "sequence_lengths": seq_lens.for_transfer().items.tolist()
+                if hasattr(seq_lens.for_transfer().items, "tolist")
+                else list(seq_lens.for_transfer().items),
+                "sequence_block_ids": seq_block_ids.for_transfer().items.tolist()
+                if hasattr(seq_block_ids.for_transfer().items, "tolist")
+                else list(seq_block_ids.for_transfer().items),
+            },
+            "model_config": {
+                "prefill_batch_sizes": model_params.prefill_batch_sizes,
+                "decode_batch_sizes": model_params.decode_batch_sizes,
+                "attn_dtype": str(model_params.attn_dtype),
+                "paged_kv_cache": {
+                    "device_block_count": model_params.paged_kv_cache.device_block_count,
+                    "block_seq_stride": model_params.paged_kv_cache.block_seq_stride,
+                    "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm,
+                },
+            },
+        }
+
+        # Save debug info as JSON
+        with open(dump_path / "info.json", "w") as f:
+            json.dump(debug_info, f, indent=2)
+
+        # Save program arguments
+        path = dump_path
+        args_np = []
+        for i, a in enumerate(args):
+            host_array = a.for_transfer()
+            host_array.copy_from(a)
+            await a.device
+            args_np.append(np.array(host_array))
+
+        # Save binary numpy arrays
+        for i, arr in enumerate(args_np):
+            np.save(path / f"{i}.npy", arr)
+
+        # Generate human-readable report
+        with open(path / "saved_program_args.txt", "w") as f:
+            for i, arr in enumerate(args_np):
+                f.write(f"\n{'='*80}\n")
+                f.write(f"{i}.npy:\n")
+                f.write(f"{'='*80}\n\n")
+
+                # Basic info
+                f.write(f"Shape: {arr.shape}\n")
+                f.write(f"Dtype: {arr.dtype}\n")
+                f.write(f"Total elements: {arr.size}\n")
+                f.write(f"Dimensions: {arr.ndim}\n\n")
+
+                # Stats
+                f.write("Statistics:\n")
+                nan_count = np.count_nonzero(np.isnan(arr))
+                inf_count = np.count_nonzero(np.isinf(arr))
+                f.write(f"- NaN count: {nan_count}\n")
+                f.write(f"- Inf count: {inf_count}\n")
+
+                if nan_count == 0 and inf_count == 0:
+                    f.write(f"- Min: {np.min(arr)}\n")
+                    f.write(f"- Max: {np.max(arr)}\n")
+                    f.write(f"- Mean: {np.mean(arr):.6f}\n")
+                    f.write(f"- Median: {np.median(arr):.6f}\n")
+                    f.write(f"- Range: {np.ptp(arr)}\n")
+                    try:
+                        mode = pd.Series(arr.flatten()).mode().iloc[0]
+                        f.write(f"- Mode: {mode}\n")
+                    except:
+                        f.write("- Mode: Unable to compute\n")
+
+                    if np.issubdtype(arr.dtype, np.number):
+                        try:
+                            hist, bins = np.histogram(arr.flatten(), bins="auto")
+                            f.write("\nHistogram:\n")
+                            f.write(
+                                "Bins: "
+                                + pformat(bins.tolist(), width=80, compact=True)
+                                + "\n"
+                            )
+                            f.write(
+                                "Counts: "
+                                + pformat(hist.tolist(), width=80, compact=True)
+                                + "\n"
+                            )
+                        except Exception as e:
+                            f.write(f"\nHistogram computation failed: {str(e)}\n")
+                else:
+                    f.write("Skipping additional statistics due to NaN/Inf values\n")
+
+                f.write("\nArray contents:\n")
+                if arr.size <= 64:
+                    formatted = pformat(arr.tolist(), width=80, compact=True)
+                    f.write(formatted + "\n")
+                else:
+                    f.write("\nFirst 5 elements:\n")
+                    f.write(
+                        pformat(arr.flatten()[:5].tolist(), width=80, compact=True)
+                        + "\n"
+                    )
+                    f.write("\nLast 5 elements:\n")
+                    f.write(
+                        pformat(arr.flatten()[-5:].tolist(), width=80, compact=True)
+                        + "\n"
+                    )
+
+        self.dump_id += 1
+
+
+# Create single instance
+SERVICE_DEBUG_DUMPER = ServiceDebugDumper()
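For reference, each invocation writes a numbered directory under ~/.shortfin/debug/llm_service_invocation_dumps/<boot timestamp>/ containing info.json, one <i>.npy per program argument, and the human-readable saved_program_args.txt. A small sketch, assuming the default paths above and at least one completed dump, of loading the newest session offline:

    # Sketch only: inspect the most recent dump session produced by ServiceDebugDumper.
    import json
    from pathlib import Path

    import numpy as np

    dump_root = Path.home() / ".shortfin/debug" / "llm_service_invocation_dumps"
    session = max(dump_root.iterdir(), key=lambda p: p.name)  # ISO timestamps sort lexicographically
    invocation = session / "0"                                # first invocation of the session

    info = json.loads((invocation / "info.json").read_text())
    print(info["metadata"]["phase"], info["batch_info"])

    # Program arguments were saved positionally as 0.npy, 1.npy, ...
    for npy in sorted(invocation.glob("*.npy"), key=lambda p: int(p.stem)):
        arr = np.load(npy)
        print(npy.name, arr.shape, arr.dtype)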