From 50009f5ff64172d380684b1b2cff7ced666db329 Mon Sep 17 00:00:00 2001
From: tyb0807
Date: Thu, 13 Feb 2025 07:03:31 -0800
Subject: [PATCH] WaveCache.module_op already returns the MLIR module str

---------

Signed-off-by: tyb0807
---
 tests/kernel/wave/attention/chained_gemm_test.py      |  4 ++--
 tests/kernel/wave/attention/decode_attention_test.py  |  4 ++--
 tests/kernel/wave/attention/evoformer_test.py         |  2 +-
 tests/kernel/wave/attention/extend_attention_test.py  |  2 +-
 tests/kernel/wave/attention/paged_attention_test.py   |  4 ++--
 .../kernel/wave/attention/vanilla_attention_test.py   | 10 +++++-----
 tests/kernel/wave/wave_gemm_test.py                   | 12 ++++++------
 7 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/tests/kernel/wave/attention/chained_gemm_test.py b/tests/kernel/wave/attention/chained_gemm_test.py
index bb8b9609c..f4af00d83 100644
--- a/tests/kernel/wave/attention/chained_gemm_test.py
+++ b/tests/kernel/wave/attention/chained_gemm_test.py
@@ -161,7 +161,7 @@ def repeat(
     if dump_generated_mlir:
         filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     iree_ref = torch.zeros(shape[0], shape[2], shape[1], dtype=torch.float32)
     generate_iree_ref(
@@ -302,7 +302,7 @@ def repeat(
     if dump_generated_mlir:
         filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     iree_ref = torch.zeros(shape[0], shape[2], shape[1], dtype=torch.float32)
     generate_iree_ref(
diff --git a/tests/kernel/wave/attention/decode_attention_test.py b/tests/kernel/wave/attention/decode_attention_test.py
index 7cfbb1d2f..3e4851db1 100644
--- a/tests/kernel/wave/attention/decode_attention_test.py
+++ b/tests/kernel/wave/attention/decode_attention_test.py
@@ -128,9 +128,9 @@ def testFlashDecoding(
     if dump_generated_mlir:
         filename = f"wave_phase_0_kernel_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb_qk.module_op.get_asm())
+            f.write(mb_qk.module_op)
         filename = f"wave_phase_1_kernel_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb_sv.module_op.get_asm())
+            f.write(mb_sv.module_op)

     assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3)
diff --git a/tests/kernel/wave/attention/evoformer_test.py b/tests/kernel/wave/attention/evoformer_test.py
index 7b740fe7d..5dc2689e4 100644
--- a/tests/kernel/wave/attention/evoformer_test.py
+++ b/tests/kernel/wave/attention/evoformer_test.py
@@ -140,7 +140,7 @@ def testEvoformerAttentionForward(
     if dump_generated_mlir:
         filename = f"wave_evoformer_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     eps = 1e-2 if output.dtype == torch.float16 else 5e-2
     assert (
diff --git a/tests/kernel/wave/attention/extend_attention_test.py b/tests/kernel/wave/attention/extend_attention_test.py
index 8e0b6561f..485a0d3f3 100644
--- a/tests/kernel/wave/attention/extend_attention_test.py
+++ b/tests/kernel/wave/attention/extend_attention_test.py
@@ -346,7 +346,7 @@ def testExtendAttention(
     if dump_generated_mlir:
         filename = f"wave_extend_attention_kernel_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb_qk.module_op.get_asm())
+            f.write(mb_qk.module_op)

     # Run the reference implementation.
     ref_output = ref_extend_attn(
diff --git a/tests/kernel/wave/attention/paged_attention_test.py b/tests/kernel/wave/attention/paged_attention_test.py
index 39586b530..d3bb269fb 100644
--- a/tests/kernel/wave/attention/paged_attention_test.py
+++ b/tests/kernel/wave/attention/paged_attention_test.py
@@ -292,10 +292,10 @@ def testPagedFlashDecoding(
     if dump_generated_mlir:
         filename = f"wave_paged_phase_0_kernel_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb_qk.module_op.get_asm())
+            f.write(mb_qk.module_op)
         filename = f"wave_paged_phase_1_kernel_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb_sv.module_op.get_asm())
+            f.write(mb_sv.module_op)

     if not artifact_directory:
         # Run the reference implementation.
diff --git a/tests/kernel/wave/attention/vanilla_attention_test.py b/tests/kernel/wave/attention/vanilla_attention_test.py
index 006b0f6a0..15d7e4d57 100644
--- a/tests/kernel/wave/attention/vanilla_attention_test.py
+++ b/tests/kernel/wave/attention/vanilla_attention_test.py
@@ -116,7 +116,7 @@ def testAttentionPure(
     if dump_generated_mlir:
         filename = f"wave_attention_{'x'.join(map(str, input_shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3)

@@ -199,7 +199,7 @@ def testAttentionCausal(
     if dump_generated_mlir:
         filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3)

@@ -399,7 +399,7 @@ def repeat(
     if dump_generated_mlir:
         filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     if "gfx94" in config["target"]:
         assert_close(output, torch_ref, atol=2e-3, rtol=5e-3, check_dtype=False)
@@ -605,7 +605,7 @@ def repeat(
     if dump_generated_mlir:
         filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     if "gfx94" in config["target"]:
         assert_close(output, torch_ref, atol=2e-3, rtol=5e-3, check_dtype=False)
@@ -771,6 +771,6 @@ def repeat(
     if dump_generated_mlir:
         filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)
     rmse = torch.sqrt(torch.mean(torch.square(output - torch_ref)))
     assert rmse <= 0.006
diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py
index f58e06f84..1ec9e5d1f 100644
--- a/tests/kernel/wave/wave_gemm_test.py
+++ b/tests/kernel/wave/wave_gemm_test.py
@@ -205,7 +205,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     if test_dump_generated_mlir:
         filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     if run_bench:
         if dump_perf is not None:
@@ -352,7 +352,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     if test_dump_generated_mlir:
         filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     if run_bench:
         if dump_perf is not None:
@@ -501,7 +501,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
     if test_dump_generated_mlir:
         filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     if run_bench:
         if dump_perf is not None:
@@ -618,7 +618,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
     if test_dump_generated_mlir:
         filename = f"wave_gemm_{'x'.join(map(str, shape))}_f8.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     if run_bench:
         if dump_perf is not None:
@@ -733,7 +733,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     if test_dump_generated_mlir:
         filename = f"wave_gemm_{'x'.join(map(str, shape))}_f8.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     if run_bench:
         if dump_perf is not None:
@@ -841,7 +841,7 @@ def repeat(
     if test_dump_generated_mlir:
         filename = f"wave_batched_gemm_{'x'.join(map(str, shape))}.mlir"
         with open(filename, "w") as f:
-            f.write(mb.module_op.get_asm())
+            f.write(mb.module_op)

     if run_bench:
         if dump_perf is not None: