diff --git a/build_tools/ci/cpu_comparison/matmul_template/matmul_trunci_MxK_KxN.mlir b/build_tools/ci/cpu_comparison/matmul_template/matmul_trunci_MxK_KxN.mlir
deleted file mode 100644
index b6fe8e361..000000000
--- a/build_tools/ci/cpu_comparison/matmul_template/matmul_trunci_MxK_KxN.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-// input ${M}x${K}x${TYPE1}
-// input ${K}x${N}x${TYPE1}
-
-func.func @matmul_trunci(%arg0: tensor<${M}x${K}x${TYPE1}>, %arg1: tensor<${K}x${N}x${TYPE1}>) -> tensor<${M}x${N}x${TYPE1}>
-{
-  %cst = arith.constant ${ZERO} : ${TYPE2}
-  %0 = tensor.empty() : tensor<${M}x${N}x${TYPE2}>
-  %1 = linalg.fill ins(%cst : ${TYPE2}) outs(%0 : tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
-  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<${M}x${K}x${TYPE1}>, tensor<${K}x${N}x${TYPE1}>)
-                     outs(%1: tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
-  %3 = arith.trunci %2 : tensor<${M}x${N}x${TYPE2}> to tensor<${M}x${N}x${TYPE1}>
-  return %3: tensor<${M}x${N}x${TYPE1}>
-}
diff --git a/build_tools/ci/cpu_comparison/run.py b/build_tools/ci/cpu_comparison/run.py
index 2e878c387..6b94f9d5a 100755
--- a/build_tools/ci/cpu_comparison/run.py
+++ b/build_tools/ci/cpu_comparison/run.py
@@ -731,9 +731,9 @@ def _execute(self, config):
         return True
 
 
-class MatmulTrunci(BaseMatmul):
+class MatmulScaleTrunci(BaseMatmul):
     """
-    A test of the form matmul(A,B) + trunci(C) where A:MxK, B:KxN and C:MxN
+    A test of the form matmul(A,B) + scale(C) + trunci(C) where A:MxK, B:KxN and C:MxN
     """
 
     def __init__(
@@ -747,10 +747,9 @@ def __init__(
         rhs,
         expected_out,
         test_params=None,
-        use_scaling=False,
     ):
         super().__init__(
-            name=f"matmul_trunci_{M}_{N}_{K}_{input_type}_{acc_type}",
+            name=f"matmul_scale_trunci_{M}_{N}_{K}_{input_type}_{acc_type}",
             test_params=test_params,
             M=M,
             N=N,
@@ -758,7 +757,7 @@ def __init__(
             input_type=input_type,
             acc_type=acc_type,
         )
-        self.labels.append("MatmulTrunci")
+        self.labels.append("MatmulScaleTrunci")
 
         # Assertions on shapes: Check that lhs is MxK, rhs is KxN, and expected_out is MxN
         assert lhs.shape == (M, K)
@@ -768,13 +767,10 @@ def __init__(
         self.lhs = lhs
         self.rhs = rhs
         self.expected_out = expected_out
-        self.use_scaling = use_scaling
 
     def _execute(self, config):
         matmul_template_dir = config.file_dir / "matmul_template"
-        template_name = matmul_template_dir / "matmul_trunci_MxK_KxN.mlir"
-        if self.use_scaling:
-            template_name = matmul_template_dir / "matmul_trunci_scaling_MxK_KxN.mlir"
+        template_name = matmul_template_dir / "matmul_trunci_scaling_MxK_KxN.mlir"
         self.generate(config, template_name)
         filename = self.get_filename(config)
         input_args = generate_inputs(
@@ -1581,78 +1577,10 @@ def __init__(self):
         self.existing_names = []
         self.tests = []
 
-        # Tests Matmul + Trunci.
-        # Phoenix : Ukernel + Peano.
-        self.register(
-            MatmulTrunci(
-                256,
-                128,
-                32,
-                "i8",
-                "i32",
-                1 * np.ones([256, 32], dtype=np.int8),
-                1 * np.ones([32, 128], dtype=np.int8),
-                32 * np.ones([256, 128], dtype=np.int8),
-                test_params=TestParams(
-                    tile_pipeline="pack-peel-4-level-tiling",
-                    run_on_target=["npu1_4col"],
-                    aie_compilation_flags=[
-                        "--iree-amdaie-num-rows=4",
-                        "--iree-amdaie-num-cols=4",
-                    ],
-                    use_ukernel=True,
-                ),
-            )
-        )
-        # Phoenix : Vectorization + Peano.
-        self.register(
-            MatmulTrunci(
-                256,
-                128,
-                32,
-                "i8",
-                "i32",
-                1 * np.ones([256, 32], dtype=np.int8),
-                1 * np.ones([32, 128], dtype=np.int8),
-                32 * np.ones([256, 128], dtype=np.int8),
-                test_params=TestParams(
-                    tile_pipeline="pack-peel-4-level-tiling",
-                    run_on_target=["npu1_4col"],
-                    aie_compilation_flags=[
-                        "--iree-amdaie-num-rows=4",
-                        "--iree-amdaie-num-cols=4",
-                    ],
-                ),
-            )
-        )
-        # Strix : Ukernel + Chess.
-        self.register(
-            MatmulTrunci(
-                256,
-                128,
-                32,
-                "i8",
-                "i32",
-                1 * np.ones([256, 32], dtype=np.int8),
-                1 * np.ones([32, 128], dtype=np.int8),
-                32 * np.ones([256, 128], dtype=np.int8),
-                test_params=TestParams(
-                    tile_pipeline="pack-peel-4-level-tiling",
-                    run_on_target=["npu4"],
-                    aie_compilation_flags=[
-                        "--iree-amdaie-num-rows=4",
-                        "--iree-amdaie-num-cols=8",
-                    ],
-                    use_chess=True,
-                    use_ukernel=True,
-                ),
-            )
-        )
         # Tests Matmul + Trunci with Scaling.
         # Phoenix : Ukernel + Peano.
         self.register(
-            MatmulTrunci(
+            MatmulScaleTrunci(
                 256,
                 256,
                 128,
@@ -1671,12 +1599,11 @@ def __init__(self):
                     ],
                     use_ukernel=True,
                 ),
-                use_scaling=True,
             )
         )
         # Phoenix : Vectorization + Peano.
         self.register(
-            MatmulTrunci(
+            MatmulScaleTrunci(
                 256,
                 256,
                 128,
@@ -1693,12 +1620,11 @@ def __init__(self):
                         "--iree-amdaie-num-cols=4",
                     ],
                 ),
-                use_scaling=True,
             )
         )
-        # Strix : Ukernel + Chess.
+        # Strix : Ukernel + Peano.
        self.register(
-            MatmulTrunci(
+            MatmulScaleTrunci(
                 256,
                 256,
                 128,
@@ -1714,10 +1640,10 @@ def __init__(self):
                         "--iree-amdaie-num-rows=4",
                         "--iree-amdaie-num-cols=8",
                     ],
-                    use_chess=True,
+                    use_chess=False,
                     use_ukernel=True,
+                    use_chess_for_ukernel=False,
                 ),
-                use_scaling=True,
             )
         )
         # Matmul with truncf test(s):
@@ -1943,7 +1869,8 @@ def __init__(self):
                 "f32",
                 test_params=TestParams(
                     use_ukernel=True,
-                    use_chess=True,
+                    use_chess=False,
+                    use_chess_for_ukernel=False,
                     run_on_target=["npu4"],
                 ),
             )
@@ -1958,11 +1885,12 @@ def __init__(self):
                 test_params=TestParams(
                     name_suffix="npu4_4x8",
                     use_ukernel=True,
+                    use_chess=False,
+                    use_chess_for_ukernel=False,
                     aie_compilation_flags=[
                         "--iree-amdaie-num-rows=4",
                         "--iree-amdaie-num-cols=8",
                     ],
-                    use_chess=True,
                     run_on_target=["npu4"],
                 ),
             )
@@ -2004,7 +1932,8 @@ def __init__(self):
                         "--iree-amdaie-num-rows=4",
                         "--iree-amdaie-num-cols=8",
                     ],
-                    use_chess=True,
+                    use_chess=False,
+                    use_chess_for_ukernel=False,
                 ),
             )
         )
@@ -2024,7 +1953,8 @@ def __init__(self):
                         "--iree-amdaie-num-rows=4",
                         "--iree-amdaie-num-cols=8",
                     ],
-                    use_chess=True,
+                    use_chess=False,
+                    use_chess_for_ukernel=False,
                 ),
             )
         )
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu4_peano.cc b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu4_peano.cc
index 8a22b73b5..d8918e468 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu4_peano.cc
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu4_peano.cc
@@ -6,8 +6,8 @@
 
 R"peano(
 
-template <typename T, int M, int N, int r>
-__attribute__((noinline)) void zero_vectorized(T *__restrict pC, unsigned offsetC)
+template <int M, int N, int r>
+void zero_vectorized(v16int32 *__restrict pC, unsigned offsetC)
 {
   v16int32 zeros = broadcast_zero_to_v16int32();
   for (unsigned i = offsetC / r; i < offsetC / r + M * N / r; i++) {
     pC[i] = zeros;
   }
 }
@@ -32,71 +32,71 @@ void matmul_vectorized_i8_i32(const int8 * __restrict pA, unsigned offsetA, cons
   v64acc32 acc_C11;
 
   for (unsigned z = 0; z < rowA; z += 2) {
-      v64acc32 *__restrict pC0 = (v64acc32 *)(pC + offsetC + (z)*size_C);
-      v64acc32 *__restrict pC1 = (v64acc32 *)(pC + offsetC + ((z + 1)) * size_C);
+    v64acc32 *__restrict pC0 = (v64acc32 *)(pC + offsetC + (z)*size_C);
+    v64acc32 *__restrict pC1 = (v64acc32 *)(pC + offsetC + ((z + 1)) * size_C);
 
-      for (unsigned j = 0; j < colB; j += 2) {
-        const v64int8 *__restrict pA0 = (v64int8 *)(pA + offsetA + (z)*size_A);
-        const v64int8 *__restrict pA1 = (v64int8 *)(pA + offsetA + ((z + 1)) * size_A);
+    for (unsigned j = 0; j < colB; j += 2) {
+      const v64int8 *__restrict pA0 = (v64int8 *)(pA + offsetA + (z)*size_A);
+      const v64int8 *__restrict pA1 = (v64int8 *)(pA + offsetA + ((z + 1)) * size_A);
 
-        const v64int8 *__restrict pB0 = (v64int8 *)(pB + offsetB + (j)*colA*size_B);
-        const v64int8 *__restrict pB1 = (v64int8 *)(pB + offsetB + ((j + 1))*colA * size_B);
+      const v64int8 *__restrict pB0 = (v64int8 *)(pB + offsetB + (j)*colA*size_B);
+      const v64int8 *__restrict pB1 = (v64int8 *)(pB + offsetB + ((j + 1))*colA * size_B);
 
-        A0 = *pA0;
-        pA0 += rowA;
-        A1 = *pA1;
-        pA1 += rowA;
+      A0 = *pA0;
+      pA0 += rowA;
+      A1 = *pA1;
+      pA1 += rowA;
 
-        B0 = *pB0++;
-        B1 = *pB1++;
+      B0 = *pB0++;
+      B1 = *pB1++;
 
-        acc_C00 = *pC0;
-        acc_C01 = *(pC0 + rowA);
+      acc_C00 = *pC0;
+      acc_C01 = *(pC0 + rowA);
 
-        acc_C10 = *pC1;
-        acc_C11 = *(pC1 + rowA);
+      acc_C10 = *pC1;
+      acc_C11 = *(pC1 + rowA);
 
-        acc_C00 = mac_8x8_8x8(A0, B0, acc_C00);
-        acc_C01 = mac_8x8_8x8(A0, B1, acc_C01);
-        acc_C10 = mac_8x8_8x8(A1, B0, acc_C10);
-        acc_C11 = mac_8x8_8x8(A1, B1, acc_C11);
+      acc_C00 = mac_8x8_8x8(A0, B0, acc_C00);
+      acc_C01 = mac_8x8_8x8(A0, B1, acc_C01);
+      acc_C10 = mac_8x8_8x8(A1, B0, acc_C10);
+      acc_C11 = mac_8x8_8x8(A1, B1, acc_C11);
 
-        for (unsigned i = 1; i < colA; ++i) {
-          A0 = *pA0;
-          pA0 += rowA;
-          A1 = *pA1;
-          pA1 += rowA;
+      for (unsigned i = 1; i < colA; ++i) {
+        A0 = *pA0;
+        pA0 += rowA;
+        A1 = *pA1;
+        pA1 += rowA;
 
-          B0 = *pB0++;
-          B1 = *pB1++;
+        B0 = *pB0++;
+        B1 = *pB1++;
 
-          acc_C00 = mac_8x8_8x8(A0, B0, acc_C00);
-          acc_C01 = mac_8x8_8x8(A0, B1, acc_C01);
-          acc_C10 = mac_8x8_8x8(A1, B0, acc_C10);
-          acc_C11 = mac_8x8_8x8(A1, B1, acc_C11);
-        }
+        acc_C00 = mac_8x8_8x8(A0, B0, acc_C00);
+        acc_C01 = mac_8x8_8x8(A0, B1, acc_C01);
+        acc_C10 = mac_8x8_8x8(A1, B0, acc_C10);
+        acc_C11 = mac_8x8_8x8(A1, B1, acc_C11);
+      }
 
-        // -----
+      // -----
 
-        v64acc32 * __restrict pOut00 = pC0;
-        *pOut00 = acc_C00;
-        pC0 += rowA;
+      v64acc32 * __restrict pOut00 = pC0;
+      *pOut00 = acc_C00;
+      pC0 += rowA;
 
-        v64acc32 * __restrict pOut01 = pC0;
-        *pOut01 = acc_C01;
-        pC0 += rowA;
+      v64acc32 * __restrict pOut01 = pC0;
+      *pOut01 = acc_C01;
+      pC0 += rowA;
 
-        // -----
+      // -----
 
-        v64acc32 * __restrict pOut10 = pC1;
-        *pOut10 = acc_C10;
-        pC1 += rowA;
+      v64acc32 * __restrict pOut10 = pC1;
+      *pOut10 = acc_C10;
+      pC1 += rowA;
 
-        v64acc32 * __restrict pOut11 = pC1;
-        *pOut11 = acc_C11;
-        pC1 += rowA;
-      }
+      v64acc32 * __restrict pOut11 = pC1;
+      *pOut11 = acc_C11;
+      pC1 += rowA;
+    }
   }
 }
 
 template <unsigned m, unsigned k, unsigned n>
@@ -115,14 +115,158 @@ void matmul_vectorized_8x8x8_i8_i8_i32(const int8 *__restrict pA,
       (pA, offsetA, pB, offsetB, pC, offsetC);
 }
 
+v64bfp16ebs8 load_v64bf16_as_bfp16(const bfloat16 *__restrict p) {
+  v32bfloat16 v0 = *(v32bfloat16 *)(p);
+  v32bfloat16 v1 = *(v32bfloat16 *)(p + 32);
+  v32accfloat accum0 = ups(v0);
+  v32accfloat accum1 = ups(v1);
+  v64accfloat accum = concat(accum0, accum1);
+  return to_v64bfp16ebs8(accum);
+}
+
+v64bfp16ebs8 load_v64bf16_as_bfp16_T(const bfloat16 *__restrict p) {
+  v32bfloat16 v0 = *(v32bfloat16 *)(p);
+  v32bfloat16 v1 = *(v32bfloat16 *)(p + 32);
+  v32bfloat16 v0_shuffed = shuffle(v0, 29);
+  v32bfloat16 v1_shuffed = shuffle(v1, 29);
+  v32bfloat16 v_shuffed_lo = shuffle(v0_shuffed, v1_shuffed, 14);
+  v32bfloat16 v_shuffed_hi = shuffle(v0_shuffed, v1_shuffed, 15);
+  v32accfloat accum0 = ups(v_shuffed_lo);
+  v32accfloat accum1 = ups(v_shuffed_hi);
+  v64accfloat accum = concat(accum0, accum1);
+  return to_v64bfp16ebs8(accum);
+}
+
+template <int M, int N, int r>
+void zero_vectorized(v16float *__restrict pC, unsigned offsetC)
+{
+  v16float zeros = broadcast_zero_to_v16float();
+  for (unsigned i = offsetC / r; i < offsetC / r + M * N / r; i++) {
+    pC[i] = zeros;
+  }
+}
+
+
+template <unsigned rowA, unsigned colA, unsigned colB, unsigned L0_M, unsigned L0_K, unsigned L0_N>
+void matmul_vectorized_bf16_f32(const bfloat16 * __restrict pA, unsigned offsetA, const bfloat16 * __restrict pB,
+                                unsigned offsetB, float * __restrict pC, unsigned offsetC)
+{
+  const unsigned size_A = L0_M * L0_K;
+  const unsigned size_B = L0_K * L0_N;
+  const unsigned size_C = L0_M * L0_N;
+
+  v64bfp16ebs8 A0;
+  v64bfp16ebs8 A1;
+  v64bfp16ebs8 B0;
+  v64bfp16ebs8 B1;
+  v64accfloat acc_C00;
+  v64accfloat acc_C01;
+  v64accfloat acc_C10;
+  v64accfloat acc_C11;
+
+  for (unsigned z = 0; z < rowA; z += 2) {
+    v64accfloat *__restrict pC0 = (v64accfloat *)(pC + offsetC + (z)*size_C);
+    v64accfloat *__restrict pC1 = (v64accfloat *)(pC + offsetC + ((z + 1)) * size_C);
+
+    for (unsigned j = 0; j < colB; j += 2) {
+      const bfloat16 *__restrict pA0 = (bfloat16 *)(pA + offsetA + (z)*size_A);
+      const bfloat16 *__restrict pA1 = (bfloat16 *)(pA + offsetA + ((z + 1)) * size_A);
+
+      const bfloat16 *__restrict pB0 = (bfloat16 *)(pB + offsetB + (j)*colA*size_B);
+      const bfloat16 *__restrict pB1 = (bfloat16 *)(pB + offsetB + ((j + 1))*colA * size_B);
+
+      A0 = load_v64bf16_as_bfp16(pA0);
+      pA0 += rowA * size_A;
+      A1 = load_v64bf16_as_bfp16(pA1);
+      pA1 += rowA * size_A;
+
+      B0 = load_v64bf16_as_bfp16_T(pB0);
+      pB0 += size_B;
+      B1 = load_v64bf16_as_bfp16_T(pB1);
+      pB1 += size_B;
+
+      acc_C00 = *pC0;
+      acc_C01 = *(pC0 + rowA);
+
+      acc_C10 = *pC1;
+      acc_C11 = *(pC1 + rowA);
+
+      acc_C00 = mac_8x8_8x8T(A0, B0, acc_C00);
+      acc_C01 = mac_8x8_8x8T(A0, B1, acc_C01);
+      acc_C10 = mac_8x8_8x8T(A1, B0, acc_C10);
+      acc_C11 = mac_8x8_8x8T(A1, B1, acc_C11);
+
+
+      for (unsigned i = 1; i < colA; ++i) {
+        A0 = load_v64bf16_as_bfp16(pA0);
+        pA0 += rowA * size_A;
+        A1 = load_v64bf16_as_bfp16(pA1);
+        pA1 += rowA * size_A;
+
+        B0 = load_v64bf16_as_bfp16_T(pB0);
+        pB0 += size_B;
+        B1 = load_v64bf16_as_bfp16_T(pB1);
+        pB1 += size_B;
+
+        acc_C00 = mac_8x8_8x8T(A0, B0, acc_C00);
+        acc_C01 = mac_8x8_8x8T(A0, B1, acc_C01);
+        acc_C10 = mac_8x8_8x8T(A1, B0, acc_C10);
+        acc_C11 = mac_8x8_8x8T(A1, B1, acc_C11);
+      }
+
+      // -----
+
+      v64accfloat * __restrict pOut00 = pC0;
+      *pOut00 = acc_C00;
+      pC0 += rowA;
+
+      v64accfloat * __restrict pOut01 = pC0;
+      *pOut01 = acc_C01;
+      pC0 += rowA;
+
+      // -----
+
+      v64accfloat * __restrict pOut10 = pC1;
+      *pOut10 = acc_C10;
+      pC1 += rowA;
+
+      v64accfloat * __restrict pOut11 = pC1;
+      *pOut11 = acc_C11;
+      pC1 += rowA;
+    }
+  }
+}
+
+template <unsigned m, unsigned k, unsigned n>
+void matmul_vectorized_8x8x8_bf16_bf16_f32(const bfloat16 *__restrict pA,
+                                           unsigned offsetA,
+                                           const bfloat16 *__restrict pB,
+                                           unsigned offsetB, float *__restrict pC,
+                                           unsigned offsetC) {
+  constexpr int r = 8;
+  constexpr int s = 8;
+  constexpr int t = 8;
+  static_assert(m / r > 0);
+  static_assert(k / s > 0);
+  static_assert(n / t > 0);
+  return matmul_vectorized_bf16_f32<m / r, k / s, n / t, r, s, t>
+      (pA, offsetA, pB, offsetB, pC, offsetC);
+}
+
 extern "C" {
 
 #define matmul_combos_i8(X, M, N, K)                                          \
   X(int8, i8, int8, i8, int32, i32, M, N, K, 8, 8, 8)
 
-#define zero_fill_combos(X, M, N)                                             \
+#define zero_fill_combos_i32(X, M, N)                                         \
   X(v16int32, i32, M, N, 16)
 
+#define matmul_combos_bfp16(X, M, N, K)                                       \
+  X(bfloat16, bf16, bfloat16, bf16, float, f32, M, N, K, 8, 8, 8)
+
+#define zero_fill_combos_f32(X, M, N)                                         \
+  X(v16float, f32, M, N, 16)
+
 #define matmul_vectorized_c_func(lhs_ctype_in, lhs_mlir_type_in,              \
                                  rhs_ctype_in, rhs_mlir_type_in,              \
                                  acc_ctype_out, acc_mlir_type_out, M, N, K, r, s, t) \
@@ -135,13 +279,23 @@ extern "C" {
 
 #define zero_vectorized_c_func(ctype_out, mlir_type_out, M, N, r)             \
   void zero_##mlir_type_out##_##M##x##N(ctype_out *c_out, unsigned offsetC) { \
-    zero_vectorized<ctype_out, M, N, r>(c_out, offsetC);                      \
+    zero_vectorized<M, N, r>(c_out, offsetC);                                 \
   }
 
 matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 32)
 matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 64)
 
-zero_fill_combos(zero_vectorized_c_func, 32, 32)
+zero_fill_combos_i32(zero_vectorized_c_func, 32, 32)
+
+matmul_combos_bfp16(matmul_vectorized_c_func, 16, 8, 32)
+matmul_combos_bfp16(matmul_vectorized_c_func, 16, 8, 64)
+matmul_combos_bfp16(matmul_vectorized_c_func, 16, 16, 32)
+matmul_combos_bfp16(matmul_vectorized_c_func, 32, 32, 32)
+matmul_combos_bfp16(matmul_vectorized_c_func, 32, 32, 64)
+
+zero_fill_combos_f32(zero_vectorized_c_func, 16, 8)
+zero_fill_combos_f32(zero_vectorized_c_func, 16, 16)
+zero_fill_combos_f32(zero_vectorized_c_func, 32, 32)
 
 } // extern "C"
 
 )peano"