Skip to content

Commit

Permalink
[UKernel] Add bf16/bfp16 ukernel for peano and move tests to peano (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls authored Feb 17, 2025
1 parent fd4db47 commit e9d6615
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 155 deletions.

This file was deleted.

108 changes: 19 additions & 89 deletions build_tools/ci/cpu_comparison/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,9 @@ def _execute(self, config):
return True


class MatmulTrunci(BaseMatmul):
class MatmulScaleTrunci(BaseMatmul):
"""
A test of the form matmul(A,B) + trunci(C) where A:MxK, B:KxN and C:MxN
A test of the form matmul(A,B) + scale(C) + trunci(C) where A:MxK, B:KxN and C:MxN
"""

def __init__(
Expand All @@ -747,18 +747,17 @@ def __init__(
rhs,
expected_out,
test_params=None,
use_scaling=False,
):
super().__init__(
name=f"matmul_trunci_{M}_{N}_{K}_{input_type}_{acc_type}",
name=f"matmul_scale_trunci_{M}_{N}_{K}_{input_type}_{acc_type}",
test_params=test_params,
M=M,
N=N,
K=K,
input_type=input_type,
acc_type=acc_type,
)
self.labels.append("MatmulTrunci")
self.labels.append("MatmulScaleTrunci")

# Assertions on shapes: Check that lhs is MxK, rhs is KxN, and expected_out is MxN
assert lhs.shape == (M, K)
Expand All @@ -768,13 +767,10 @@ def __init__(
self.lhs = lhs
self.rhs = rhs
self.expected_out = expected_out
self.use_scaling = use_scaling

def _execute(self, config):
matmul_template_dir = config.file_dir / "matmul_template"
template_name = matmul_template_dir / "matmul_trunci_MxK_KxN.mlir"
if self.use_scaling:
template_name = matmul_template_dir / "matmul_trunci_scaling_MxK_KxN.mlir"
template_name = matmul_template_dir / "matmul_trunci_scaling_MxK_KxN.mlir"
self.generate(config, template_name)
filename = self.get_filename(config)
input_args = generate_inputs(
Expand Down Expand Up @@ -1581,78 +1577,10 @@ def __init__(self):
self.existing_names = []
self.tests = []

# Tests Matmul + Trunci.
# Phoenix : Ukernel + Peano.
self.register(
MatmulTrunci(
256,
128,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 128], dtype=np.int8),
32 * np.ones([256, 128], dtype=np.int8),
test_params=TestParams(
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu1_4col"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=4",
],
use_ukernel=True,
),
)
)
# Phoenix : Vectorization + Peano.
self.register(
MatmulTrunci(
256,
128,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 128], dtype=np.int8),
32 * np.ones([256, 128], dtype=np.int8),
test_params=TestParams(
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu1_4col"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=4",
],
),
)
)
# Strix : Ukernel + Chess.
self.register(
MatmulTrunci(
256,
128,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 128], dtype=np.int8),
32 * np.ones([256, 128], dtype=np.int8),
test_params=TestParams(
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu4"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=8",
],
use_chess=True,
use_ukernel=True,
),
)
)

# Tests Matmul + Trunci with Scaling.
# Phoenix : Ukernel + Peano.
self.register(
MatmulTrunci(
MatmulScaleTrunci(
256,
256,
128,
Expand All @@ -1671,12 +1599,11 @@ def __init__(self):
],
use_ukernel=True,
),
use_scaling=True,
)
)
# Phoenix : Vectorization + Peano.
self.register(
MatmulTrunci(
MatmulScaleTrunci(
256,
256,
128,
Expand All @@ -1693,12 +1620,11 @@ def __init__(self):
"--iree-amdaie-num-cols=4",
],
),
use_scaling=True,
)
)
# Strix : Ukernel + Chess.
# Strix : Ukernel + Peano.
self.register(
MatmulTrunci(
MatmulScaleTrunci(
256,
256,
128,
Expand All @@ -1714,10 +1640,10 @@ def __init__(self):
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=8",
],
use_chess=True,
use_chess=False,
use_ukernel=True,
use_chess_for_ukernel=False,
),
use_scaling=True,
)
)
# Matmul with truncf test(s):
Expand Down Expand Up @@ -1943,7 +1869,8 @@ def __init__(self):
"f32",
test_params=TestParams(
use_ukernel=True,
use_chess=True,
use_chess=False,
use_chess_for_ukernel=False,
run_on_target=["npu4"],
),
)
Expand All @@ -1958,11 +1885,12 @@ def __init__(self):
test_params=TestParams(
name_suffix="npu4_4x8",
use_ukernel=True,
use_chess=False,
use_chess_for_ukernel=False,
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=8",
],
use_chess=True,
run_on_target=["npu4"],
),
)
Expand Down Expand Up @@ -2004,7 +1932,8 @@ def __init__(self):
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=8",
],
use_chess=True,
use_chess=False,
use_chess_for_ukernel=False,
),
)
)
Expand All @@ -2024,7 +1953,8 @@ def __init__(self):
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=8",
],
use_chess=True,
use_chess=False,
use_chess_for_ukernel=False,
),
)
)
Expand Down
Loading

0 comments on commit e9d6615

Please sign in to comment.