Skip to content

Commit b559c6d

Browse files
authored
[Experimental][Kleidi] Add GEMM operator tests (#1638)
1 parent c1f5872 commit b559c6d

9 files changed

+1623
-47
lines changed

torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUA
1616
include(FetchContent)
1717
# KleidiAI is an open-source library that provides optimized
1818
# performance-critical routines, also known as micro-kernels, for artificial
19-
# intelligence (AI) workloads tailored for Arm® CPUs.
19+
# intelligence (AI) workloads tailored for Arm® CPUs.
2020
FetchContent_Declare(kleidiai
2121
GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
22-
GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this
22+
GIT_TAG v1.2.0)
2323
FetchContent_MakeAvailable(kleidiai)
2424

2525
# Temporarily exposing this to the parent scope until we wire

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ void kernel(
108108
activation_data,
109109
weight_data,
110110
output,
111-
/*dst_stride_row=*/n * sizeof(float),
111+
/*dst_stride_row=*/output_m_stride * sizeof(float),
112112
/*dst_stride_col=*/sizeof(float),
113113
clamp_min,
114114
clamp_max);

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void kernel(
109109
activation_data,
110110
weight_data,
111111
output,
112-
/*dst_stride_row=*/ n * sizeof(float),
112+
/*dst_stride_row=*/ output_m_stride * sizeof(float),
113113
/*dst_stride_col=*/ sizeof(float),
114114
clamp_min,
115115
clamp_max);

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ void kernel(
106106
activation_data,
107107
weight_data,
108108
output,
109-
/*dst_stride_row=*/n * sizeof(float),
109+
/*dst_stride_row=*/output_m_stride * sizeof(float),
110110
/*dst_stride_col=*/sizeof(float),
111111
clamp_min,
112112
clamp_max);

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ void kernel(
107107
activation_data,
108108
weight_data,
109109
output,
110-
/*dst_stride_row=*/n * sizeof(float),
110+
/*dst_stride_row=*/output_m_stride * sizeof(float),
111111
/*dst_stride_col=*/sizeof(float),
112112
clamp_min,
113113
clamp_max);

torchao/experimental/ops/tests/CMakeLists.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,34 @@ if(TORCHAO_BUILD_KLEIDIAI)
2525
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
2626
endif()
2727

28+
if(TORCHAO_BUILD_ARM_I8MM)
29+
add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
30+
endif()
31+
32+
if (ANDROID_ABI)
33+
# We are cross compiling, delay test discovery till runtime
34+
set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST)
35+
endif()
36+
2837
include_directories(${TORCHAO_INCLUDE_DIRS})
2938

3039
set(TORCHAO_PARALLEL_BACKEND "test_dummy")
3140
add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)
3241

3342
include(${TORCHAO_ROOT}/Utils.cmake)
43+
44+
if (ANDROID_ABI)
45+
# Given where we are today this is sufficent. But needs to be revisited.
46+
# This is also needed for native builds, but keeping it only for cross builds
47+
# for now given the hacky nature.
48+
file(GLOB DOTPROD_SRC_FILES test*.cpp)
49+
message(SRC_FILES: ${DOTPROD_SRC_FILES})
50+
set_property(SOURCE
51+
${DOTPROD_SRC_FILES}
52+
APPEND_STRING PROPERTY
53+
COMPILE_FLAGS " -march=armv8.2-a+dotprod ")
54+
endif()
55+
3456
add_executable(
3557
test_linear_8bit_act_xbit_weight
3658
test_linear_8bit_act_xbit_weight.cpp

torchao/experimental/ops/tests/build_and_run_tests.sh

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,57 @@
55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
target=${1:-"native"}
9+
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
10+
export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests
11+
812
IS_ARM64=0
13+
BUILD_ARM_I8MM=0
14+
EXTRA_ARGS=""
15+
if [[ "${target}" == "android" ]]; then
16+
if [[ -z ${ANDROID_NDK} ]]; then
17+
echo "Need to set ANDROID_NDK env variable to build for Android";
18+
exit 1;
19+
fi
20+
android_abi=arm64-v8a
21+
android_platform=28 # must be >=28 for aligned_alloc
22+
IS_ARM64=1
23+
BUILD_ARM_I8MM=1 # Hardcoded for now
24+
CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android}
25+
toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake"
26+
if [[ -z ${toolchain_file} ]]; then
27+
echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}"
28+
exit 1;
29+
fi
30+
EXTRA_ARGS="\
31+
-DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \
32+
-DANDROID_ABI=${android_abi} \
33+
-DANDROID_PLATFORM=${android_platform}
34+
"
35+
echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}"
36+
fi
37+
938
hash arch; retval=$?
1039
if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then
1140
IS_ARM64=1
1241
fi
1342

14-
export CMAKE_OUT=/tmp/cmake-out/torchao/tests
1543
cmake \
16-
-DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
44+
${EXTRA_ARGS} \
45+
-DCMAKE_BUILD_TYPE=Debug \
1746
-DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \
47+
-DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \
1848
-S . \
1949
-B ${CMAKE_OUT}
2050

2151
cmake --build ${CMAKE_OUT}
2252

53+
echo "Successfully built tests."
54+
55+
if [[ "${target}" != "native" ]]; then
56+
echo "Skip running tests when cross compiling.";
57+
exit 0;
58+
fi
59+
2360
# Run
2461
${CMAKE_OUT}/test_linear_8bit_act_xbit_weight
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
4+
# Simple script to generate test cases for the torchao ops
5+
from string import Template
6+
7+
8+
def add_test_string(kernel, m, n, k, g, has_bias, has_clamp):
9+
name = f"m{m}xn{n}xk{k}xg{g}{'_bias' if has_bias else ''}{'_clamp' if has_clamp else ''}"
10+
d = {
11+
"name": name,
12+
"kernel": kernel,
13+
"m": m,
14+
"n": n,
15+
"k": k,
16+
"g": g,
17+
"has_bias": "true" if has_bias else "false",
18+
"has_clamp": "true" if has_clamp else "false",
19+
}
20+
21+
test_template = Template(
22+
"""
23+
TEST(test_linear_8bit_act_xbit_weight, Kleidi_${kernel}_${name}) {
24+
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
25+
test_linear_8bit_act_xbit_weight<
26+
4 /*weight_nbit*/,
27+
false /*has_weight_zeros*/,
28+
$has_bias /*has_bias*/,
29+
$has_clamp /*has_clamp*/,
30+
true /*has_kleidi*/>(
31+
/*m=*/$m, /*n=*/$n, /*k=*/$k, /*group_size=*/$g, &ukernel_config);
32+
}
33+
"""
34+
)
35+
36+
return [test_template.safe_substitute(d)]
37+
38+
39+
def get_test_block(kernel):
40+
# Assuming given kleidi kernel can run with all these test cases
41+
tests = []
42+
# GEMV, m == 1
43+
## subtile
44+
tests += add_test_string(kernel, 1, 2 * 1, 32, 32, False, False)
45+
tests += add_test_string(kernel, 1, 2 * 2, 32, 32, False, False)
46+
tests += add_test_string(kernel, 1, 2 * 3, 32, 32, True, False)
47+
tests += add_test_string(kernel, 1, 2 * 2, 32, 32, True, True)
48+
tests += add_test_string(kernel, 1, 2 * 3, 32, 32, False, True)
49+
## larger: n - must be multiple of 2
50+
tests += add_test_string(kernel, 1, 2 * 11, 32, 32, False, False)
51+
tests += add_test_string(kernel, 1, 2 * 13, 32, 32, True, False)
52+
tests += add_test_string(kernel, 1, 2 * 51, 32, 32, False, True)
53+
tests += add_test_string(kernel, 1, 2 * 111, 32, 32, False, False)
54+
## larger: k, g - must be multiple of 32
55+
tests += add_test_string(kernel, 1, 2 * 7, 64, 32, False, False)
56+
tests += add_test_string(kernel, 1, 2 * 11, 128, 32, True, False)
57+
tests += add_test_string(kernel, 1, 2 * 13, 64, 64, False, True)
58+
tests += add_test_string(kernel, 1, 2 * 17, 128, 64, False, False)
59+
60+
# GEMM, m > 1
61+
## subtile
62+
tests += add_test_string(kernel, 2, 2 * 1, 32, 32, False, False)
63+
tests += add_test_string(kernel, 2, 2 * 2, 32, 32, False, False)
64+
tests += add_test_string(kernel, 3, 2 * 3, 32, 32, True, False)
65+
tests += add_test_string(kernel, 4, 2 * 4, 32, 32, True, True)
66+
tests += add_test_string(kernel, 3, 2 * 3, 32, 32, False, True)
67+
## larger: m
68+
tests += add_test_string(kernel, 31, 2 * 1, 32, 32, False, False)
69+
tests += add_test_string(kernel, 32, 2 * 2, 32, 32, False, False)
70+
tests += add_test_string(kernel, 33, 2 * 3, 32, 32, True, False)
71+
tests += add_test_string(kernel, 34, 2 * 4, 32, 32, True, True)
72+
tests += add_test_string(kernel, 35, 2 * 3, 32, 32, False, True)
73+
## larger: n - must be multiple of 2
74+
tests += add_test_string(kernel, 7, 2 * 11, 32, 32, False, False)
75+
tests += add_test_string(kernel, 17, 2 * 13, 32, 32, True, False)
76+
tests += add_test_string(kernel, 23, 2 * 51, 32, 32, False, True)
77+
tests += add_test_string(kernel, 41, 2 * 111, 32, 32, False, False)
78+
## larger: k, g - must be multiple of 32
79+
tests += add_test_string(kernel, 19, 2 * 7, 64, 32, False, False)
80+
tests += add_test_string(kernel, 23, 2 * 11, 128, 32, True, False)
81+
tests += add_test_string(kernel, 29, 2 * 13, 64, 64, False, True)
82+
tests += add_test_string(kernel, 101, 2 * 17, 128, 64, False, False)
83+
84+
return "".join(tests)
85+
86+
87+
def main():
88+
kleidi_template = Template(
89+
"""
90+
/*****************/
91+
// ${kernel} tests
92+
/*****************/
93+
${prologue}
94+
${tests}
95+
${epilogue}
96+
"""
97+
)
98+
99+
kleidi_kernels = [
100+
"dotprod_1x4x32",
101+
"dotprod_1x8x32",
102+
"i8mm_4x8x32",
103+
"i8mm_8x4x32",
104+
]
105+
106+
print("/* Generated by generate_tests.py */")
107+
print("/* Do not modify */")
108+
print()
109+
print("#if defined(TORCHAO_ENABLE_KLEIDI)")
110+
for kernel in kleidi_kernels:
111+
prologue, epilogue = "", ""
112+
if "i8mm" in kernel:
113+
prologue = "#if defined(TORCHAO_ENABLE_ARM_I8MM)"
114+
epilogue = "#endif // TORCHAO_ENABLE_ARM_I8MM"
115+
tests = get_test_block(kernel)
116+
d = {
117+
"prologue": prologue,
118+
"kernel": kernel,
119+
"tests": tests,
120+
"epilogue": epilogue,
121+
}
122+
123+
print(kleidi_template.safe_substitute(d))
124+
print("#endif // TORCHAO_ENABLE_KLEIDI")
125+
126+
127+
if __name__ == "__main__":
128+
main()

0 commit comments

Comments
 (0)