@@ -23,15 +23,22 @@
 import os
 
 # Get environment variable, default to False if not set
-SHORTFIN_DEBUG_LLM_SERVICE = os.getenv('SHORTFIN_DEBUG_LLM_SERVICE', 'False').lower() in ('true', 'yes', '1', 'y')
+SHORTFIN_DEBUG_LLM_SERVICE = os.getenv(
+    "SHORTFIN_DEBUG_LLM_SERVICE", "False"
+).lower() in ("true", "yes", "1", "y")
 
 if SHORTFIN_DEBUG_LLM_SERVICE:
     logger.info("DEBUG_LLM_SERVICE=True")
     dump_id = 0
     boot_timestamp = datetime.now().isoformat()
     DEBUG_DATA_DIR = Path.home() / "sfdebug"
-    DUMP_DIR_THIS_SESSION = DEBUG_DATA_DIR / f"llm_service_invocation_dumps_from_{boot_timestamp}"
+    DUMP_DIR_THIS_SESSION = (
+        DEBUG_DATA_DIR / f"llm_service_invocation_dumps_from_{boot_timestamp}"
+    )
     DUMP_DIR_THIS_SESSION.mkdir(parents=True, exist_ok=False)
-    logger.info(f"[debug_service.py] Please find debug dumps for service.py in {DUMP_DIR_THIS_SESSION}")
+    logger.info(
+        f"[debug_service.py] Please find debug dumps for service.py in {DUMP_DIR_THIS_SESSION}"
+    )
+
 
 async def pre_invocation_debug_dump(
     phase,
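
Side note: the flag parsing in the hunk above is case-insensitive, and any value outside the listed strings (including an unset variable) reads as off. A standalone sketch of the same check; the "YES" value here is only illustrative:

    import os

    os.environ["SHORTFIN_DEBUG_LLM_SERVICE"] = "YES"
    enabled = os.getenv("SHORTFIN_DEBUG_LLM_SERVICE", "False").lower() in (
        "true", "yes", "1", "y"
    )
    print(enabled)  # True; "0", "no", "off", or an unset variable all read as False
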
@@ -49,16 +56,16 @@ async def pre_invocation_debug_dump(
     seq_lens,
     seq_block_ids,
     model_params,
-    args
+    args,
 ):
     """Comprehensive debug dump before inference invocation."""
     if not SHORTFIN_DEBUG_LLM_SERVICE:
         return
-
+
     global dump_id
     dump_path = DUMP_DIR_THIS_SESSION / f"{dump_id}"
     dump_path.mkdir(parents=True, exist_ok=True)
-
+
     # Prepare debug info dictionary
     debug_info = {
         "metadata": {
@@ -67,23 +74,25 @@ async def pre_invocation_debug_dump(
             "phase": str(phase),
             "is_decode": is_decode,
             "device": str(device0),
-            "function": str(fn)
+            "function": str(fn),
         },
         "batch_info": {
             "request_batch_size": req_bs,
             "block_sequence_length": int(bsl),
             "sequence_stride": seq_stride,
             "block_count": block_count,
-            "actual_request_count": req_count
+            "actual_request_count": req_count,
         },
         "requests": [
             {
                 "index": i,
                 "start_position": req.start_position,
                 "rid": req.rid,
-                "input_token_ids": req.input_token_ids.tolist() if hasattr(req.input_token_ids, 'tolist') else list(req.input_token_ids),
+                "input_token_ids": req.input_token_ids.tolist()
+                if hasattr(req.input_token_ids, "tolist")
+                else list(req.input_token_ids),
                 "input_length": len(req.input_token_ids),
-                "cache_pages": req.cache_page_indices(block_count)
+                "cache_pages": req.cache_page_indices(block_count),
             }
             for i, req in enumerate(exec_requests)
         ],
@@ -94,10 +103,24 @@ async def pre_invocation_debug_dump(
             "seq_block_ids": seq_block_ids.shape,
         },
         "tensor_values": {
-            "tokens": tokens.for_transfer().items.tolist() if hasattr(tokens.for_transfer().items, 'tolist') else list(tokens.for_transfer().items),
-            **({"start_positions": start_positions.for_transfer().items.tolist() if hasattr(start_positions.for_transfer().items, 'tolist') else list(start_positions.for_transfer().items)} if is_decode else {}),
-            "sequence_lengths": seq_lens.for_transfer().items.tolist() if hasattr(seq_lens.for_transfer().items, 'tolist') else list(seq_lens.for_transfer().items),
-            "sequence_block_ids": seq_block_ids.for_transfer().items.tolist() if hasattr(seq_block_ids.for_transfer().items, 'tolist') else list(seq_block_ids.for_transfer().items)
+            "tokens": tokens.for_transfer().items.tolist()
+            if hasattr(tokens.for_transfer().items, "tolist")
+            else list(tokens.for_transfer().items),
+            **(
+                {
+                    "start_positions": start_positions.for_transfer().items.tolist()
+                    if hasattr(start_positions.for_transfer().items, "tolist")
+                    else list(start_positions.for_transfer().items)
+                }
+                if is_decode
+                else {}
+            ),
+            "sequence_lengths": seq_lens.for_transfer().items.tolist()
+            if hasattr(seq_lens.for_transfer().items, "tolist")
+            else list(seq_lens.for_transfer().items),
+            "sequence_block_ids": seq_block_ids.for_transfer().items.tolist()
+            if hasattr(seq_block_ids.for_transfer().items, "tolist")
+            else list(seq_block_ids.for_transfer().items),
         },
         "model_config": {
             "prefill_batch_sizes": model_params.prefill_batch_sizes,
@@ -106,9 +129,9 @@ async def pre_invocation_debug_dump(
             "paged_kv_cache": {
                 "device_block_count": model_params.paged_kv_cache.device_block_count,
                 "block_seq_stride": model_params.paged_kv_cache.block_seq_stride,
-                "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm
-            }
-        }
+                "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm,
+            },
+        },
     }
 
     # Save debug info as JSON
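
Side note: the repeated `items.tolist() if hasattr(items, "tolist") else list(items)` expression in the hunks above normalizes both numpy-style arrays and plain iterables into JSON-serializable lists. A hypothetical helper, not part of this change, that factors the pattern out:

    def _to_list(items):
        # Hypothetical helper (not in this PR): arrays expose .tolist(),
        # anything else falls back to the list() constructor.
        return items.tolist() if hasattr(items, "tolist") else list(items)
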
@@ -123,31 +146,31 @@ async def pre_invocation_debug_dump(
         host_array.copy_from(a)
         await a.device
         args_np.append(np.array(host_array))
-
+
     # Save binary numpy arrays
     for i, arr in enumerate(args_np):
         np.save(path / f"{i}.npy", arr)
-
+
     # Generate human-readable report
     with open(path / "saved_program_args.txt", "w") as f:
         for i, arr in enumerate(args_np):
             f.write(f"\n{'=' * 80}\n")
             f.write(f"{i}.npy:\n")
             f.write(f"{'=' * 80}\n\n")
-
+
             # Basic info
             f.write(f"Shape: {arr.shape}\n")
             f.write(f"Dtype: {arr.dtype}\n")
             f.write(f"Total elements: {arr.size}\n")
             f.write(f"Dimensions: {arr.ndim}\n\n")
-
+
             # Stats
             f.write("Statistics:\n")
             nan_count = np.count_nonzero(np.isnan(arr))
             inf_count = np.count_nonzero(np.isinf(arr))
             f.write(f"- NaN count: {nan_count}\n")
             f.write(f"- Inf count: {inf_count}\n")
-
+
             if nan_count == 0 and inf_count == 0:
                 f.write(f"- Min: {np.min(arr)}\n")
                 f.write(f"- Max: {np.max(arr)}\n")
@@ -159,26 +182,38 @@ async def pre_invocation_debug_dump(
                     f.write(f"- Mode: {mode}\n")
                 except:
                     f.write("- Mode: Unable to compute\n")
-
+
                 if np.issubdtype(arr.dtype, np.number):
                     try:
-                        hist, bins = np.histogram(arr.flatten(), bins='auto')
+                        hist, bins = np.histogram(arr.flatten(), bins="auto")
                         f.write("\nHistogram:\n")
-                        f.write("Bins: " + pformat(bins.tolist(), width=80, compact=True) + "\n")
-                        f.write("Counts: " + pformat(hist.tolist(), width=80, compact=True) + "\n")
+                        f.write(
+                            "Bins: "
+                            + pformat(bins.tolist(), width=80, compact=True)
+                            + "\n"
+                        )
+                        f.write(
+                            "Counts: "
+                            + pformat(hist.tolist(), width=80, compact=True)
+                            + "\n"
+                        )
                     except Exception as e:
                         f.write(f"\nHistogram computation failed: {str(e)}\n")
             else:
                 f.write("Skipping additional statistics due to NaN/Inf values\n")
-
+
             f.write("\nArray contents:\n")
             if arr.size <= 64:
                 formatted = pformat(arr.tolist(), width=80, compact=True)
                 f.write(formatted + "\n")
             else:
                 f.write("\nFirst 5 elements:\n")
-                f.write(pformat(arr.flatten()[:5].tolist(), width=80, compact=True) + "\n")
+                f.write(
+                    pformat(arr.flatten()[:5].tolist(), width=80, compact=True) + "\n"
+                )
                 f.write("\nLast 5 elements:\n")
-                f.write(pformat(arr.flatten()[-5:].tolist(), width=80, compact=True) + "\n")
-
+                f.write(
+                    pformat(arr.flatten()[-5:].tolist(), width=80, compact=True) + "\n"
+                )
+
     dump_id += 1
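
Side note: each invocation ends up as a numbered directory holding the `{i}.npy` argument arrays plus the `saved_program_args.txt` report. A minimal loading sketch, assuming the default `~/sfdebug` location established above; the session and invocation selection here are illustrative:

    from pathlib import Path

    import numpy as np

    # Pick the newest session directory; names embed the ISO boot timestamp,
    # so lexicographic order is chronological order.
    sessions = sorted(
        (Path.home() / "sfdebug").glob("llm_service_invocation_dumps_from_*")
    )
    latest = sessions[-1]

    # Load the program arguments saved for invocation 0, in argument order.
    arrays = [
        np.load(p)
        for p in sorted((latest / "0").glob("*.npy"), key=lambda p: int(p.stem))
    ]
    for i, arr in enumerate(arrays):
        print(i, arr.shape, arr.dtype)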