Skip to content

Commit 3f383e1

Browse files
RattataKingkuharScottTodd
authored
[tuner] Add tuner files (#158)
Move remaining tuner files from SDXL repo to sharktank --------- Co-authored-by: Jakub Kuderski <[email protected]> Co-authored-by: Scott Todd <[email protected]>
1 parent 5a198e9 commit 3f383e1

10 files changed

+2035
-1
lines changed

tuner/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc.
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

tuner/candidate_gen_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
import pytest
12-
import candidate_gen
12+
from . import candidate_gen
1313

1414

1515
def test_get_shaped_type_element_bitwidth():

tuner/examples/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc.
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

tuner/examples/punet/.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Test files/dirs recommended by README.md.
2+
dump-mmt
3+
test-benchmark.mlir

tuner/examples/punet/README.md

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Punet Tuner
2+
3+
## Environments
4+
Follow instructions in [`/tuner/README.md`](../README.md)
5+
6+
## Shell Scripts
7+
8+
The required shell scripts can be downloaded from:
9+
[sdxl-scripts](https://github.com/nod-ai/sdxl-scripts).
10+
11+
These scripts include:
12+
1. `compile-punet-base.sh` - Used for compiling model candidates.
13+
2. `compile_candidate.sh` - Used for compiling dispatch candidates.
14+
3. `punet.sh` - Invoked by `compile_candidate.sh`.
15+
16+
Add the parent directories of these scripts to your `PATH` environment variable,
17+
so that they can be picked up by `punet_autotune.py`.
18+
19+
## Running the Tuner
20+
21+
### [Optional] Generate a tunable mlir
22+
Use
23+
[`punet.sh`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/punet.sh)
24+
to compile the sample matmul `mmt.mlir` (also available here:
25+
[`mmt_unet.mlir`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/mmt_unet.mlir)):
26+
```shell
27+
punet.sh mmt.mlir -o mmt.vmfb --iree-hal-dump-executable-files-to=dump-mmt
28+
cp ./dump-mmt/module_main_0_dispatch_0_rocm_hsaco_fb_benchmark.mlir test-benchmark.mlir
29+
```
30+
31+
### Recommended Trial Run
32+
For an initial trial to test the tuning loop, use:
33+
```shell
34+
python -m tuner.examples.punet.punet_autotune test-benchmark.mlir --num-candidates=10
35+
```
36+
37+
### Dry Run Test
38+
To perform a dry run (no GPU required), use:
39+
```shell
40+
python -m tuner.examples.punet.punet_autotune test-benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run
41+
```
42+
43+
### Basic Usage
44+
```shell
45+
python -m tuner.examples.punet.punet_autotune test-benchmark.mlir
46+
```

tuner/examples/punet/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc.
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

tuner/examples/punet/mmt.mlir

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Sample tunable dispatch: transposed matmul C = A * B^T with f16 operands
// accumulating into f32, preceded by a zero-fill of the output tensor.
!matA_0 = tensor<2048x1280xf16>
!matB_0 = tensor<10240x1280xf16>
!matC_0 = tensor<2048x10240xf32>

func.func @main_0(%arg0: !matA_0, %arg1: !matB_0) -> !matC_0 {
  // Zero of the operand element type; linalg.fill broadcasts it into C.
  %zero = arith.constant 0.000000e+00 : f16
  %init = tensor.empty() : !matC_0
  %acc = linalg.fill ins(%zero : f16) outs(%init : !matC_0) -> !matC_0
  %prod = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%acc : !matC_0) -> !matC_0
  return %prod : !matC_0
}
+191
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc.
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
"""
8+
Sample Usage:
9+
10+
python -m tuner.examples.punet.punet_autotune benchmark.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64
11+
12+
13+
Recommended Trial Run:
14+
15+
python -m tuner.examples.punet.punet_autotune benchmark.mlir --num-candidates=1
16+
17+
18+
Dry Run Test (no GPU required):
19+
20+
python -m tuner.examples.punet.punet_autotune benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run
21+
22+
"""
23+
24+
from ... import libtuner
25+
from pathlib import Path
26+
27+
28+
class PunetClient(libtuner.TuningClient):
    """Punet tuning client for libtuner.

    Supplies the shell commands and timeouts libtuner uses for the four
    phases it drives: compiling and benchmarking individual dispatches,
    then compiling and benchmarking whole-model candidates. The helper
    scripts (`compile_candidate.sh`, `compile-punet-base.sh`) must be
    reachable via PATH — see the package README.
    """

    def get_dispatch_compile_timeout_s(self) -> int:
        # Dispatch compiles are tiny; give up quickly on hangs.
        return 4

    def get_dispatch_compile_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        """Build the argv that compiles one dispatch candidate."""
        dispatch_mlir = candidate_tracker.dispatch_mlir_path
        assert dispatch_mlir is not None
        return ["compile_candidate.sh", dispatch_mlir.as_posix()]

    def get_dispatch_benchmark_timeout_s(self) -> int:
        return 15

    def get_dispatch_benchmark_command(
        self,
        candidate_tracker: libtuner.CandidateTracker,
    ) -> list[str]:
        """Build the iree-benchmark-module argv for one compiled dispatch.

        The device placeholder is substituted by libtuner when it assigns
        the job to a concrete HIP device.
        """
        dispatch_vmfb = candidate_tracker.compiled_dispatch_path
        assert dispatch_vmfb is not None
        return [
            "iree-benchmark-module",
            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
            f"--module={dispatch_vmfb.resolve()}",
            "--hip_use_streams=true",
            "--hip_allow_inline_execution=true",
            "--batch_size=1000",
            "--benchmark_repetitions=3",
            f"--benchmark_out=dispatch_{candidate_tracker.candidate_id}_bm.json",
            "--benchmark_out_format=json",
        ]

    def get_model_compile_timeout_s(self) -> int:
        # Full-model compilation is much slower than a single dispatch.
        return 300

    def get_model_compile_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        """Build the argv that compiles the full model with one tuning spec."""
        spec = candidate_tracker.spec_path
        assert spec is not None
        # The .vmfb is emitted three levels above the spec file's directory.
        out_dir = spec.resolve().parent.parent.parent
        vmfb_name = f"unet_candidate_{candidate_tracker.candidate_id}.vmfb"
        return [
            "compile-punet-base.sh",
            "iree-compile",
            "gfx942",
            f"{spec.resolve()}",
            "./punet.mlir",
            "-o",
            (out_dir / vmfb_name).as_posix(),
        ]

    def get_model_benchmark_timeout_s(self) -> int:
        return 180

    def get_model_benchmark_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        """Build the iree-benchmark-module argv for one compiled model.

        Inputs match punet's main() signature; weights are loaded from
        punet.irpa in the working directory.
        """
        model_vmfb = candidate_tracker.compiled_model_path
        assert model_vmfb is not None
        return [
            "iree-benchmark-module",
            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
            "--hip_use_streams=true",
            "--hip_allow_inline_execution=true",
            "--device_allocator=caching",
            f"--module={model_vmfb.resolve()}",
            "--parameters=model=punet.irpa",
            "--function=main",
            "--input=1x4x128x128xf16",
            "--input=1xsi32",
            "--input=2x64x2048xf16",
            "--input=2x1280xf16",
            "--input=2x6xf16",
            "--input=1xf16",
            "--benchmark_repetitions=5",
            f"--benchmark_out=model_{candidate_tracker.candidate_id}_bm.json",
            "--benchmark_out_format=json",
        ]
118+
119+
def main():
    """Run the punet tuning loop end to end.

    Phases: generate candidates -> compile dispatches -> benchmark
    dispatches -> compile model candidates -> benchmark models, with an
    optional early exit after any phase via --stop-after. Logs, results,
    and pickled trackers land in the libtuner PathConfig directories.
    """
    args = libtuner.parse_arguments()

    path_config = libtuner.PathConfig()
    path_config.base_dir.mkdir(parents=True, exist_ok=True)
    # Ensure the unified results log exists before any phase appends to it.
    path_config.output_unilog.touch()

    candidate_trackers: list[libtuner.CandidateTracker] = []
    tuning_client = PunetClient()
    stop_phase: str = args.stop_after

    print("Setup logging")
    libtuner.setup_logging(args, path_config)
    print(path_config.run_log, end="\n\n")

    # Device validation needs real GPUs, so it is skipped on dry runs.
    if not args.dry_run:
        print("Validating devices")
        libtuner.validate_devices(args.devices)
        print("Validation successful!\n")

    print("Generating candidates...")
    candidates = libtuner.generate_candidates(args, path_config, candidate_trackers)
    print(f"Stored candidates in {path_config.candidates_dir}\n")
    if stop_phase == libtuner.ExecutionPhases.generate_candidates:
        return

    print("Compiling candidates...")
    compiled_candidates = libtuner.compile_dispatches(
        args, path_config, candidates, candidate_trackers, tuning_client
    )
    print(f"Compiled files are stored in {path_config.compiled_dir}\n")
    if stop_phase == libtuner.ExecutionPhases.compile_dispatches:
        return

    print("Benchmarking compiled candidates...")
    top_candidates = libtuner.benchmark_dispatches(
        args, path_config, compiled_candidates, candidate_trackers, tuning_client
    )
    print(f"Stored results in {path_config.output_unilog}\n")
    if stop_phase == libtuner.ExecutionPhases.benchmark_dispatches:
        return

    print("Compiling top model candidates...")
    punet_candidates = libtuner.compile_models(
        args, path_config, top_candidates, candidate_trackers, tuning_client
    )
    print(f"Model candidates compiled in {path_config.base_dir}\n")
    if stop_phase == libtuner.ExecutionPhases.compile_models:
        return

    print("Benchmarking model candidates...")
    libtuner.benchmark_models(
        args, path_config, punet_candidates, candidate_trackers, tuning_client
    )
    print(f"Stored results in {path_config.output_unilog}")
    if stop_phase == libtuner.ExecutionPhases.benchmark_models:
        return

    # NOTE(review): "summerize" matches libtuner's current API spelling.
    libtuner.summerize_top_candidates(path_config, candidate_trackers)
    print(f"Stored top candidates info in {path_config.result_summary_log}\n")

    libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
    print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")

    print("Check the detailed execution logs in:")
    print(path_config.run_log)

    for tracker in candidate_trackers:
        libtuner.logging.debug(tracker)
        if args.verbose:
            print(tracker)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)