run_ncu.py
"""
Augment kernel metadata generated by kernel_metadata metric table in inductor.
For each row in input, use NCU to profile the kernel. The corresponding output row
contains more metadata gathered by NCU.
It can be super slow to run NCU. e.g. for the 10K kernels gathered from Huggingface,
it took almost a whole day to run NCU for each unique kernels. The script thus cache
the ncu output in the file system. If the ncu output is cached, we don't run NCU again.
Example input: https://gist.github.com/shunting314/22995da0da8b66d4cf989cb7f0508399
Example output: https://gist.github.com/shunting314/cb36615e8b6e4143de2fba246db9244e
"""
import csv
import dataclasses
import itertools
import os
import re
import subprocess
import sys
from typing import List, Optional

import click

@dataclasses.dataclass
class LogLine:
    model_name: str
    kernel_name: str
    kernel_path: str
    kernel_category: str
    size_hints: List[int]
    reduction_hint: str
    line_of_code: int
    num_load: int
    num_store: int
    num_for_loop: int
    num_atomic_add: int
    num_args: int
    xnumel: int
    ynumel: int
    rnumel: int
    kernel_args_num_gb: float

    # augmented fields
    latency_us_under_profiling: Optional[float] = None
    mem_bw_under_profiling: Optional[float] = None
    ncu_mem_accessed_gb: Optional[float] = None
    ncu_duration_us: Optional[float] = None
    ncu_mem_bw_gbps: Optional[float] = None
    @staticmethod
    def is_valid_field(name: str, simplify_output: bool):
        if not simplify_output:
            return True
        # we will only dump the following fields to the output csv
        # if we pass --simplify-output
        valid_fields = (
            "model_name",
            "kernel_name",
            # "kernel_path",
            "kernel_category",
            "size_hints",
            "kernel_args_num_gb",
            "latency_us_under_profiling",
            "mem_bw_under_profiling",
            "ncu_duration_us",
            "ncu_mem_bw_gbps",
            "ncu_mem_accessed_gb",
        )
        return name in valid_fields
    @property
    def ncu_log_path(self):
        # replace the trailing ".py" of the kernel path with ".ncu"
        return self.kernel_path[:-3] + ".ncu"
    def valid_ncu_log_found(self):
        ncu_log_path = self.ncu_log_path
        if not os.path.exists(ncu_log_path):
            return False
        with open(ncu_log_path, "r") as f:
            log_content = f.read()
        return self.is_valid_ncu_output(log_content)

    @staticmethod
    def is_valid_ncu_output(ncu_output):
        # this section reports the memory bandwidth usage of the kernel
        return "Section: Memory Workload Analysis" in ncu_output
    @staticmethod
    def _parse_latency_us_and_bw_under_profiling(log_content):
        """
        Match a line like: 0.027ms 0.004GB 147.69GB/s
        """
        m = re.search(r"^([\d.]+)ms\s+[\d.]+GB\s*([\d.]+)GB/s$", log_content, re.M)
        assert m, "benchmark output not found"
        return (
            float(m.group(1)) * 1000,  # ms -> us
            float(m.group(2)),
        )
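
    # For illustration, using the hypothetical log line from the docstring above:
    #   _parse_latency_us_and_bw_under_profiling("0.027ms 0.004GB 147.69GB/s")
    #   ≈ (27.0, 147.69)   # latency in us, memory bandwidth in GB/s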
    @staticmethod
    def _parse_ncu_duration_us(log_content):
        """
        Match a line like: Duration usecond 6.21
        """
        m = re.search(r"^\s+Duration\s+([a-z]+)\s+([0-9.]+)\s*$", log_content, re.M)
        assert m, "ncu duration not found"
        unit = m.group(1)
        quantity = float(m.group(2))
        if unit == "usecond":
            return quantity
        elif unit == "msecond":
            return quantity * 1000
        elif unit == "second":
            return quantity * 1000000
        else:
            raise RuntimeError(f"Unrecognized unit {unit}")
    @staticmethod
    def _parse_ncu_mem_bw_gbps(log_content):
        """
        Match a line like: Memory Throughput Gbyte/second 4.39
        Sometimes the number contains commas, e.g.:
            Memory Throughput Gbyte/second 1,000.00
        """
        m = re.search(r"^\s+Memory Throughput\s+([KMGT])byte/second\s+([0-9.,]+)\s*$", log_content, re.M)
        assert m, "ncu memory bw not found"
        unit = m.group(1)
        quantity = float(m.group(2).replace(",", ""))
        if unit == "K":
            return quantity / 1000000.0
        elif unit == "M":
            return quantity / 1000.0
        elif unit == "G":
            return quantity
        elif unit == "T":
            return quantity * 1000  # 1000 or 1024?
        else:
            raise RuntimeError(f"Unrecognized unit for mem bw: {unit}")
    def parse_ncu_output(self):
        with open(self.ncu_log_path, "r") as f:
            log_content = f.read()
        (
            self.latency_us_under_profiling,
            self.mem_bw_under_profiling,
        ) = self._parse_latency_us_and_bw_under_profiling(log_content)
        self.ncu_duration_us = self._parse_ncu_duration_us(log_content)
        self.ncu_mem_bw_gbps = self._parse_ncu_mem_bw_gbps(log_content)
        # GB/s * us / 1e6 = GB
        self.ncu_mem_accessed_gb = self.ncu_mem_bw_gbps * self.ncu_duration_us / 1e6

    # shared counter used for progress printing across all LogLine instances
    ncu_run_count = itertools.count()
    def run_ncu(self):
        """
        NOTE: unlike in a shell, surrounding regex:triton with a pair of quotes
        (i.e. "regex:triton") causes no kernel to be found, since the quotes
        are passed verbatim to ncu. When running in a shell, the shell strips
        the quotes before ncu sees the argument.
        """
        if self.valid_ncu_log_found():
            return
        print(f"Run ncu for kernel {self.kernel_name} at {self.kernel_path}")
        # temporary workaround due to ncu sudo-privilege issues in user space
        ncu_args = f"""
            /usr/local/cuda-12.4/bin/ncu --target-processes all -k regex:triton -c 1 --set full {sys.executable} {self.kernel_path}
        """.strip().split()
        ncu_out = subprocess.check_output(ncu_args).decode()
        if not self.is_valid_ncu_output(ncu_out):
            raise RuntimeError(
                f"Invalid ncu output generated for kernel {self.kernel_path}. "
                f"Output: {ncu_out[:8192]}"
            )
        with open(self.ncu_log_path, "w") as f:
            f.write(ncu_out)
        print(f"{next(self.ncu_run_count)}: ncu output generated at {self.ncu_log_path}")

def write_line_obj_list_to_csv(line_obj_list, output_csv, simplify_output):
    assert len(line_obj_list) > 0
    sorted_line_obj_list = sorted(line_obj_list, key=lambda line: line.ncu_mem_bw_gbps)
    field_names = [
        f.name
        for f in dataclasses.fields(sorted_line_obj_list[0])
        if LogLine.is_valid_field(f.name, simplify_output)
    ]
    with open(output_csv, "w") as f:
        writer = csv.writer(f)
        writer.writerow(field_names)
        for line in sorted_line_obj_list:
            values = [getattr(line, name) for name in field_names]
            # in the simplified output, skip very short kernels (< 12 us)
            if simplify_output and line.ncu_duration_us < 12:
                continue
            writer.writerow(values)
    print(f"Output is written to {output_csv}")

@click.command(short_help="A script for parsing PT2 triton kernel metadata")
@click.option(
    "--input-csv",
    "-i",
    type=str,
    help="The input CSV file to parse",
    required=True,
)
@click.option(
    "--output-csv",
    "-o",
    type=str,
    help="The output CSV file. Each row may contain augmented fields compared to the corresponding input row",
    required=True,
)
@click.option(
    "--simplify-output",
    type=bool,
    help="If true, only dump a small number of interesting metrics into the output csv file",
    default=True,
)
def main(
    input_csv: str,
    output_csv: str,
    simplify_output: bool = True,
):
    line_obj_list = []
    with open(input_csv) as f:
        csv_reader = csv.reader(f)
        header = next(csv_reader)
        for line in csv_reader:
            try:
                line_obj = LogLine(**{k: v for k, v in zip(header, line)})
            except TypeError:
                print(f"Invalid log line {line}")
                raise

            if line_obj.kernel_category == "foreach":
                # We don't have a benchmark harness generated for foreach kernels, so skip them.
                continue

            line_obj.run_ncu()
            line_obj.parse_ncu_output()
            line_obj_list.append(line_obj)

    write_line_obj_list_to_csv(line_obj_list, output_csv, simplify_output)

if __name__ == "__main__":
    main()
    # note: click's standalone mode exits via SystemExit, so this line is
    # normally not reached
    print("bye")