Skip to content

Commit

Permalink
Drop redundant `.get_asm()` calls: `WaveCache.module_op` already returns the MLIR module as a string
Browse files (browse the repository at this point in the history)
---------

Signed-off-by: tyb0807 <[email protected]>
  • Loading branch information
tyb0807 committed Feb 13, 2025
1 parent 4ef04d7 commit 50009f5
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions tests/kernel/wave/attention/chained_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def repeat(
if dump_generated_mlir:
filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

iree_ref = torch.zeros(shape[0], shape[2], shape[1], dtype=torch.float32)
generate_iree_ref(
Expand Down Expand Up @@ -302,7 +302,7 @@ def repeat(
if dump_generated_mlir:
filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

iree_ref = torch.zeros(shape[0], shape[2], shape[1], dtype=torch.float32)
generate_iree_ref(
Expand Down
4 changes: 2 additions & 2 deletions tests/kernel/wave/attention/decode_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def testFlashDecoding(
if dump_generated_mlir:
filename = f"wave_phase_0_kernel_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_qk.module_op.get_asm())
f.write(mb_qk.module_op)
filename = f"wave_phase_1_kernel_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_sv.module_op.get_asm())
f.write(mb_sv.module_op)

assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3)
2 changes: 1 addition & 1 deletion tests/kernel/wave/attention/evoformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def testEvoformerAttentionForward(
if dump_generated_mlir:
filename = f"wave_evoformer_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

eps = 1e-2 if output.dtype == torch.float16 else 5e-2
assert (
Expand Down
2 changes: 1 addition & 1 deletion tests/kernel/wave/attention/extend_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def testExtendAttention(
if dump_generated_mlir:
filename = f"wave_extend_attention_kernel_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_qk.module_op.get_asm())
f.write(mb_qk.module_op)

# Run the reference implementation.
ref_output = ref_extend_attn(
Expand Down
4 changes: 2 additions & 2 deletions tests/kernel/wave/attention/paged_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,10 @@ def testPagedFlashDecoding(
if dump_generated_mlir:
filename = f"wave_paged_phase_0_kernel_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_qk.module_op.get_asm())
f.write(mb_qk.module_op)
filename = f"wave_paged_phase_1_kernel_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_sv.module_op.get_asm())
f.write(mb_sv.module_op)

if not artifact_directory:
# Run the reference implementation.
Expand Down
10 changes: 5 additions & 5 deletions tests/kernel/wave/attention/vanilla_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def testAttentionPure(
if dump_generated_mlir:
filename = f"wave_attention_{'x'.join(map(str, input_shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3)

Expand Down Expand Up @@ -199,7 +199,7 @@ def testAttentionCausal(
if dump_generated_mlir:
filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3)

Expand Down Expand Up @@ -399,7 +399,7 @@ def repeat(
if dump_generated_mlir:
filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

if "gfx94" in config["target"]:
assert_close(output, torch_ref, atol=2e-3, rtol=5e-3, check_dtype=False)
Expand Down Expand Up @@ -605,7 +605,7 @@ def repeat(
if dump_generated_mlir:
filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

if "gfx94" in config["target"]:
assert_close(output, torch_ref, atol=2e-3, rtol=5e-3, check_dtype=False)
Expand Down Expand Up @@ -771,6 +771,6 @@ def repeat(
if dump_generated_mlir:
filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)
rmse = torch.sqrt(torch.mean(torch.square(output - torch_ref)))
assert rmse <= 0.006
12 changes: 6 additions & 6 deletions tests/kernel/wave/wave_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
if test_dump_generated_mlir:
filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

if run_bench:
if dump_perf is not None:
Expand Down Expand Up @@ -352,7 +352,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
if test_dump_generated_mlir:
filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

if run_bench:
if dump_perf is not None:
Expand Down Expand Up @@ -501,7 +501,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
if test_dump_generated_mlir:
filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

if run_bench:
if dump_perf is not None:
Expand Down Expand Up @@ -618,7 +618,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
if test_dump_generated_mlir:
filename = f"wave_gemm_{'x'.join(map(str, shape))}_f8.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

if run_bench:
if dump_perf is not None:
Expand Down Expand Up @@ -733,7 +733,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
if test_dump_generated_mlir:
filename = f"wave_gemm_{'x'.join(map(str, shape))}_f8.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

if run_bench:
if dump_perf is not None:
Expand Down Expand Up @@ -841,7 +841,7 @@ def repeat(
if test_dump_generated_mlir:
filename = f"wave_batched_gemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
f.write(mb.module_op)

if run_bench:
if dump_perf is not None:
Expand Down

0 comments on commit 50009f5

Please sign in to comment.