
Commit e7e6f01

precommit
1 parent 2f3d862 commit e7e6f01

File tree

2 files changed: +70 -37 lines changed

shortfin/python/shortfin_apps/llm/components/debug_service.py
shortfin/python/shortfin_apps/llm/components/service.py

shortfin/python/shortfin_apps/llm/components/debug_service.py

+65 -30
@@ -23,15 +23,22 @@
 import os
 
 # Get environment variable, default to False if not set
-SHORTFIN_DEBUG_LLM_SERVICE = os.getenv('SHORTFIN_DEBUG_LLM_SERVICE', 'False').lower() in ('true', 'yes', '1', 'y')
+SHORTFIN_DEBUG_LLM_SERVICE = os.getenv(
+    "SHORTFIN_DEBUG_LLM_SERVICE", "False"
+).lower() in ("true", "yes", "1", "y")
 if SHORTFIN_DEBUG_LLM_SERVICE:
     logger.info("DEBUG_LLM_SERVICE=True")
     dump_id = 0
     boot_timestamp = datetime.now().isoformat()
     DEBUG_DATA_DIR = Path.home() / "sfdebug"
-    DUMP_DIR_THIS_SESSION = DEBUG_DATA_DIR / f"llm_service_invocation_dumps_from_{boot_timestamp}"
+    DUMP_DIR_THIS_SESSION = (
+        DEBUG_DATA_DIR / f"llm_service_invocation_dumps_from_{boot_timestamp}"
+    )
     DUMP_DIR_THIS_SESSION.mkdir(parents=True, exist_ok=False)
-    logger.info(f"[debug_service.py] Please find debug dumps for service.py in {DUMP_DIR_THIS_SESSION}")
+    logger.info(
+        f"[debug_service.py] Please find debug dumps for service.py in {DUMP_DIR_THIS_SESSION}"
+    )
+
 
 async def pre_invocation_debug_dump(
     phase,
@@ -49,16 +56,16 @@ async def pre_invocation_debug_dump(
     seq_lens,
     seq_block_ids,
     model_params,
-    args
+    args,
 ):
     """Comprehensive debug dump before inference invocation."""
     if not SHORTFIN_DEBUG_LLM_SERVICE:
         return
-
+
     global dump_id
     dump_path = DUMP_DIR_THIS_SESSION / f"{dump_id}"
     dump_path.mkdir(parents=True, exist_ok=True)
-
+
     # Prepare debug info dictionary
     debug_info = {
         "metadata": {
@@ -67,23 +74,25 @@ async def pre_invocation_debug_dump(
             "phase": str(phase),
             "is_decode": is_decode,
             "device": str(device0),
-            "function": str(fn)
+            "function": str(fn),
         },
         "batch_info": {
             "request_batch_size": req_bs,
             "block_sequence_length": int(bsl),
             "sequence_stride": seq_stride,
             "block_count": block_count,
-            "actual_request_count": req_count
+            "actual_request_count": req_count,
         },
         "requests": [
             {
                 "index": i,
                 "start_position": req.start_position,
                 "rid": req.rid,
-                "input_token_ids": req.input_token_ids.tolist() if hasattr(req.input_token_ids, 'tolist') else list(req.input_token_ids),
+                "input_token_ids": req.input_token_ids.tolist()
+                if hasattr(req.input_token_ids, "tolist")
+                else list(req.input_token_ids),
                 "input_length": len(req.input_token_ids),
-                "cache_pages": req.cache_page_indices(block_count)
+                "cache_pages": req.cache_page_indices(block_count),
             }
             for i, req in enumerate(exec_requests)
         ],
@@ -94,10 +103,24 @@ async def pre_invocation_debug_dump(
             "seq_block_ids": seq_block_ids.shape,
         },
         "tensor_values": {
-            "tokens": tokens.for_transfer().items.tolist() if hasattr(tokens.for_transfer().items, 'tolist') else list(tokens.for_transfer().items),
-            **({"start_positions": start_positions.for_transfer().items.tolist() if hasattr(start_positions.for_transfer().items, 'tolist') else list(start_positions.for_transfer().items)} if is_decode else {}),
-            "sequence_lengths": seq_lens.for_transfer().items.tolist() if hasattr(seq_lens.for_transfer().items, 'tolist') else list(seq_lens.for_transfer().items),
-            "sequence_block_ids": seq_block_ids.for_transfer().items.tolist() if hasattr(seq_block_ids.for_transfer().items, 'tolist') else list(seq_block_ids.for_transfer().items)
+            "tokens": tokens.for_transfer().items.tolist()
+            if hasattr(tokens.for_transfer().items, "tolist")
+            else list(tokens.for_transfer().items),
+            **(
+                {
+                    "start_positions": start_positions.for_transfer().items.tolist()
+                    if hasattr(start_positions.for_transfer().items, "tolist")
+                    else list(start_positions.for_transfer().items)
+                }
+                if is_decode
+                else {}
+            ),
+            "sequence_lengths": seq_lens.for_transfer().items.tolist()
+            if hasattr(seq_lens.for_transfer().items, "tolist")
+            else list(seq_lens.for_transfer().items),
+            "sequence_block_ids": seq_block_ids.for_transfer().items.tolist()
+            if hasattr(seq_block_ids.for_transfer().items, "tolist")
+            else list(seq_block_ids.for_transfer().items),
         },
         "model_config": {
             "prefill_batch_sizes": model_params.prefill_batch_sizes,
@@ -106,9 +129,9 @@ async def pre_invocation_debug_dump(
             "paged_kv_cache": {
                 "device_block_count": model_params.paged_kv_cache.device_block_count,
                 "block_seq_stride": model_params.paged_kv_cache.block_seq_stride,
-                "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm
-            }
-        }
+                "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm,
+            },
+        },
     }
 
     # Save debug info as JSON
@@ -123,31 +146,31 @@ async def pre_invocation_debug_dump(
         host_array.copy_from(a)
         await a.device
         args_np.append(np.array(host_array))
-
+
     # Save binary numpy arrays
     for i, arr in enumerate(args_np):
         np.save(path / f"{i}.npy", arr)
-
+
    # Generate human-readable report
    with open(path / "saved_program_args.txt", "w") as f:
        for i, arr in enumerate(args_np):
            f.write(f"\n{'='*80}\n")
            f.write(f"{i}.npy:\n")
            f.write(f"{'='*80}\n\n")
-
+
            # Basic info
            f.write(f"Shape: {arr.shape}\n")
            f.write(f"Dtype: {arr.dtype}\n")
            f.write(f"Total elements: {arr.size}\n")
            f.write(f"Dimensions: {arr.ndim}\n\n")
-
+
            # Stats
            f.write("Statistics:\n")
            nan_count = np.count_nonzero(np.isnan(arr))
            inf_count = np.count_nonzero(np.isinf(arr))
            f.write(f"- NaN count: {nan_count}\n")
            f.write(f"- Inf count: {inf_count}\n")
-
+
            if nan_count == 0 and inf_count == 0:
                f.write(f"- Min: {np.min(arr)}\n")
                f.write(f"- Max: {np.max(arr)}\n")
@@ -159,26 +182,38 @@ async def pre_invocation_debug_dump(
                    f.write(f"- Mode: {mode}\n")
                except:
                    f.write("- Mode: Unable to compute\n")
-
+
                if np.issubdtype(arr.dtype, np.number):
                    try:
-                        hist, bins = np.histogram(arr.flatten(), bins='auto')
+                        hist, bins = np.histogram(arr.flatten(), bins="auto")
                        f.write("\nHistogram:\n")
-                        f.write("Bins: " + pformat(bins.tolist(), width=80, compact=True) + "\n")
-                        f.write("Counts: " + pformat(hist.tolist(), width=80, compact=True) + "\n")
+                        f.write(
+                            "Bins: "
+                            + pformat(bins.tolist(), width=80, compact=True)
+                            + "\n"
+                        )
+                        f.write(
+                            "Counts: "
+                            + pformat(hist.tolist(), width=80, compact=True)
+                            + "\n"
+                        )
                    except Exception as e:
                        f.write(f"\nHistogram computation failed: {str(e)}\n")
            else:
                f.write("Skipping additional statistics due to NaN/Inf values\n")
-
+
            f.write("\nArray contents:\n")
            if arr.size <= 64:
                formatted = pformat(arr.tolist(), width=80, compact=True)
                f.write(formatted + "\n")
            else:
                f.write("\nFirst 5 elements:\n")
-                f.write(pformat(arr.flatten()[:5].tolist(), width=80, compact=True) + "\n")
+                f.write(
+                    pformat(arr.flatten()[:5].tolist(), width=80, compact=True) + "\n"
+                )
                f.write("\nLast 5 elements:\n")
-                f.write(pformat(arr.flatten()[-5:].tolist(), width=80, compact=True) + "\n")
-
+                f.write(
+                    pformat(arr.flatten()[-5:].tolist(), width=80, compact=True) + "\n"
+                )
+
    dump_id += 1

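Note on usage (not part of the commit): the dump path above is gated entirely by the SHORTFIN_DEBUG_LLM_SERVICE environment variable, via the os.getenv(...).lower() check reformatted in the first hunk. A minimal sketch of that gating behavior follows; the truthy spellings come from the diff, while the sample values are illustrative only.

import os

# Truthy spellings accepted by debug_service.py (case-insensitive); anything else,
# including leaving the variable unset, keeps the debug dumps disabled.
TRUTHY = ("true", "yes", "1", "y")

for value in ("1", "TRUE", "yes", "Y", "0", "false", ""):
    enabled = value.lower() in TRUTHY
    print(f"SHORTFIN_DEBUG_LLM_SERVICE={value!r} -> dumps {'enabled' if enabled else 'disabled'}")

# The module reads the flag once at import time, with the same parsing:
enabled_now = os.getenv("SHORTFIN_DEBUG_LLM_SERVICE", "False").lower() in TRUTHY
print(f"current environment -> dumps {'enabled' if enabled_now else 'disabled'}")
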
shortfin/python/shortfin_apps/llm/components/service.py

+5 -7
@@ -437,6 +437,7 @@ async def run(self):
         # pre-invocation args dump
         try:
             from .debug_service import pre_invocation_debug_dump
+
             await pre_invocation_debug_dump(
                 phase=self.phase,
                 is_decode=is_decode,
@@ -453,17 +454,14 @@ async def run(self):
                 seq_lens=seq_lens,
                 seq_block_ids=seq_block_ids,
                 model_params=self.service.model_params,
-                args=args
+                args=args,
             )
         except Exception as e:
             err_msg = (
-                f"Error Type: {type(e).__name__}\n"
-                f"Error Message: {str(e)}\n"
-            )
-            logger.info(
-                f"Non-critical failure: debug logging failed due to {e}"
+                f"Error Type: {type(e).__name__}\n" f"Error Message: {str(e)}\n"
             )
-
+            logger.info(f"Non-critical failure: debug logging failed due to {e}")
+
         # invoke VMFB
         (logits,) = await fn(*args, fiber=self.fiber)

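For completeness, a hedged sketch of how one might inspect the artifacts the dump code writes: per the debug_service.py hunks above, each invocation gets its own numbered directory (the dump_id) under ~/sfdebug/llm_service_invocation_dumps_from_<boot timestamp>/, holding the program arguments as 0.npy, 1.npy, ... plus the saved_program_args.txt report. The session path below is a hypothetical placeholder; only the file layout comes from the diff.

from pathlib import Path

import numpy as np

# Hypothetical session directory; substitute the timestamped path that
# debug_service.py logs at startup ("Please find debug dumps for service.py in ...").
session_dir = Path.home() / "sfdebug" / "llm_service_invocation_dumps_from_<boot_timestamp>"

for dump_dir in sorted((p for p in session_dir.iterdir() if p.is_dir()), key=lambda p: int(p.name)):
    print(f"invocation dump {dump_dir.name}:")
    # Program arguments saved by np.save(path / f"{i}.npy", arr) in the dump code.
    for npy_file in sorted(dump_dir.glob("*.npy"), key=lambda p: int(p.stem)):
        arr = np.load(npy_file)
        print(f"  {npy_file.name}: shape={arr.shape}, dtype={arr.dtype}")
    # Human-readable report with per-array stats, histogram, and a content preview.
    report = dump_dir / "saved_program_args.txt"
    if report.exists():
        print(f"  report: {report}")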