diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 08f494c71d..e1511ffe9a 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -37,7 +37,22 @@ jobs: pip install numpy pip install pytest USE_CPP=1 pip install . - - name: Run tests + - name: Run python tests run: | conda activate venv pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py + python torchao/experimental/tests/test_embedding_xbit_quantizer.py + - name: Run kernels/cpu/aarch64/tests + run: | + conda activate venv + pushd torchao/experimental/kernels/cpu/aarch64/tests + sh build_and_run_tests.sh + rm -rf /tmp/cmake-out + popd + - name: Run torchao/experimental/ops/tests + run: | + conda activate venv + pushd torchao/experimental/ops/tests + sh build_and_run_tests.sh + rm -rf /tmp/cmake-out + popd diff --git a/setup.py b/setup.py index 6ee93bc9ab..357e0e491f 100644 --- a/setup.py +++ b/setup.py @@ -179,7 +179,8 @@ def build_cmake(self, ext): "cmake", ext.sourcedir, "-DCMAKE_BUILD_TYPE=" + build_type, - "-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF", + # Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16 + "-DTORCHAO_BUILD_KLEIDIAI=OFF", "-DTorch_DIR=" + torch_dir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, ], diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h deleted file mode 100644 index 658a0feadc..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. 
-// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include - -#include - -namespace torchao::kernels::cpu::aarch64::kleidi { -namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { - -namespace neon_dotprod_1x4x32 { -const Ukernel get_ukernel() { - return Ukernel{ - .get_m_step = - kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_n_step = - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_mr = - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_nr = - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_kr = - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_sr = - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_dst_offset = - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_dst_size = - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .run_matmul = - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod}; -} - -size_t activation_data_size(int m, int k, int group_size) { - (void)group_size; // unused - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( - get_ukernel(), m, k); -} - -void prepare_activation_data( - void* prepared_activation_data, - int m, - int k, - int group_size, - const float* activations) { - (void)group_size; // unused - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), prepared_activation_data, m, k, activations); -} - -size_t weight_data_size(int n, int k, int group_size) { - return 
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( - get_ukernel(), n, k, group_size); -} - -void prepare_weight_data( - void* prepared_weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( - get_ukernel(), - prepared_weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -void kernel( - float32_t* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max) { - if (clamp_min == 0 && clamp_max == 0) { - clamp_min = std::numeric_limits::lowest(); - clamp_max = std::numeric_limits::max(); - } - - auto ukernel = get_ukernel(); - ukernel.run_matmul( - m, - n, - k, - group_size, - activation_data, - weight_data, - output, - /*dst_stride_row=*/output_m_stride * sizeof(float), - /*dst_stride_col=*/sizeof(float), - clamp_min, - clamp_max); -} - -size_t get_preferred_alignement() { - return 16; -} -} // namespace neon_dotprod_1x4x32 -} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p -} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h deleted file mode 100644 index 336d5a8e7f..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. 
- -#pragma once - -#include - -#include - -namespace torchao::kernels::cpu::aarch64::kleidi { -namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { -namespace neon_dotprod_1x8x32 { -const Ukernel get_ukernel() { - return Ukernel{ - .get_m_step = - kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_n_step = - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_mr = - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_nr = - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_kr = - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_sr = - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_dst_offset = - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_dst_size = - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .run_matmul = - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}; -} - -size_t activation_data_size(int m, int k, int group_size) { - (void) group_size; // unused - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k); -} - -void prepare_activation_data( - void* prepared_activation_data, - int m, - int k, - int group_size, - const float* activations) { - (void) group_size; // unused - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), - prepared_activation_data, - m, - k, - activations); -} - -size_t weight_data_size(int n, int k, int group_size) { - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size); -} - -void prepare_weight_data( - void* prepared_weight_data, - int 
n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( - get_ukernel(), - prepared_weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -void kernel( - float32_t* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max) { - if (clamp_min == 0 && clamp_max == 0) { - clamp_min = std::numeric_limits::lowest(); - clamp_max = std::numeric_limits::max(); - } - - auto ukernel = get_ukernel(); - ukernel.run_matmul( - m, - n, - k, - group_size, - activation_data, - weight_data, - output, - /*dst_stride_row=*/ output_m_stride * sizeof(float), - /*dst_stride_col=*/ sizeof(float), - clamp_min, - clamp_max); -} - -size_t get_preferred_alignement() { - return 16; -} -} // namespace neon_dotprod_1x4x32 -} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p -} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h deleted file mode 100644 index 60004704ed..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. 
- -#pragma once -#include -#include - -namespace torchao::kernels::cpu::aarch64::kleidi { -namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { -namespace neon_i8mm_8x4x32 { - -const Ukernel get_ukernel() { - return Ukernel{ - .get_m_step = - kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_n_step = - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_mr = - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_nr = - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_kr = - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_sr = - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_dst_offset = - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_dst_size = - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .run_matmul = - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm}; -} - -size_t activation_data_size(int m, int k, int group_size) { - (void)group_size; // unused - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( - get_ukernel(), m, k); -} - -void prepare_activation_data( - void* prepared_activation_data, - int m, - int k, - int group_size, - const float* activations) { - (void)group_size; // unused - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), prepared_activation_data, m, k, activations); -} - -size_t weight_data_size(int n, int k, int group_size) { - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( - get_ukernel(), n, k, group_size); -} - -void prepare_weight_data( - void* prepared_weight_data, - int n, - int k, - int group_size, - const 
int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( - get_ukernel(), - prepared_weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -void kernel( - float32_t* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max) { - if (clamp_min == 0 && clamp_max == 0) { - clamp_min = std::numeric_limits::lowest(); - clamp_max = std::numeric_limits::max(); - } - - auto ukernel = get_ukernel(); - ukernel.run_matmul( - m, - n, - k, - group_size, - activation_data, - weight_data, - output, - /*dst_stride_row=*/output_m_stride * sizeof(float), - /*dst_stride_col=*/sizeof(float), - clamp_min, - clamp_max); -} - -size_t get_preferred_alignement() { - return 16; -} -} // namespace neon_i8mm_8x4x32 -} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p -} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h deleted file mode 100644 index 90db4ae3d6..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. 
- -#pragma once -#include - -#include - -namespace torchao::kernels::cpu::aarch64::kleidi { -namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { -namespace neon_i8mm_4x8x32 { - -const Ukernel get_ukernel() { - return Ukernel{ - .get_m_step = - kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_n_step = - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_mr = - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_nr = - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_kr = - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_sr = - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_dst_offset = - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_dst_size = - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .run_matmul = - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}; -} - -size_t activation_data_size(int m, int k, int group_size) { - (void)group_size; // unused - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( - get_ukernel(), m, k); -} - -void prepare_activation_data( - void* prepared_activation_data, - int m, - int k, - int group_size, - const float* activations) { - (void)group_size; // unused - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), prepared_activation_data, m, k, activations); -} - -size_t weight_data_size(int n, int k, int group_size) { - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( - get_ukernel(), n, k, group_size); -} - -void prepare_weight_data( - void* prepared_weight_data, - int n, - int k, - int group_size, - const 
int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( - get_ukernel(), - prepared_weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -void kernel( - float32_t* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max) { - if (clamp_min == 0 && clamp_max == 0) { - clamp_min = std::numeric_limits::lowest(); - clamp_max = std::numeric_limits::max(); - } - - auto ukernel = get_ukernel(); - ukernel.run_matmul( - m, - n, - k, - group_size, - activation_data, - weight_data, - output, - /*dst_stride_row=*/output_m_stride * sizeof(float), - /*dst_stride_col=*/sizeof(float), - clamp_min, - clamp_max); -} - -size_t get_preferred_alignement() { - return 16; -} - -} // namespace neon_i8mm_4x8x32 -} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p -} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 9cde684995..9071869fce 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -14,8 +14,15 @@ #include #include +#include +#include #include +#ifdef TORCHAO_ENABLE_ARM_I8MM +#include +#include +#endif // TORCHAO_ENABLE_ARM_I8MM + #include namespace torchao::kernels::cpu::aarch64::kleidi { @@ -23,7 +30,9 @@ namespace torchao::kernels::cpu::aarch64::kleidi { // Helper functions // TODO: find a better place for these? 
-size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } +namespace internal { + +inline size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } uint16_t get_bf16_from_float(float f) { uint16_t bf16; @@ -37,46 +46,59 @@ uint16_t get_bf16_from_float(float f) { return bf16; } +// KleidiAI kernels require n is even, so we round up to next even number +// if required and pad +inline int adjust_n(int n) { return roundup(n, 2); } + +} // namespace internal + namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; -size_t activation_data_size(const Ukernel ukernel, int m, int k) { +template +size_t activation_data_size(int m, int k, int group_size) { + (void)group_size; // unused auto lhs_packing = get_lhs_packing(); - return lhs_packing.get_lhs_packed_size(m, k, ukernel.get_mr(), - ukernel.get_kr(), ukernel.get_sr()); + return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr); } -void prepare_activation_data(const Ukernel ukernel, void *activation_data, - int m, int k, const float *activations) { +template +void prepare_activation_data(void *activation_data, int m, int k, + int group_size, const float *activations) { + (void)group_size; // unused auto lhs_pack = get_lhs_packing(); - lhs_pack.run_lhs_pack(m, k, ukernel.get_mr(), ukernel.get_kr(), - ukernel.get_sr(), + lhs_pack.run_lhs_pack(m, k, mr, kr, sr, /*m_index_start=*/0, activations, /*lhs_stride=*/k * sizeof(float), activation_data); } -size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) { +template +size_t weight_data_size(int n, int k, int group_size) { auto rhs_pack = get_rhs_packing(); - return rhs_pack.get_rhs_packed_size(n, k, ukernel.get_nr(), ukernel.get_kr(), - ukernel.get_sr(), group_size, + return rhs_pack.get_rhs_packed_size(n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16); } -void prepare_weight_data(const Ukernel ukernel, void *weight_data, int n, int k, - int group_size, const 
int8_t *weight_qvals, - const float *weight_scales, const int8_t *weight_zeros, - const float *bias) { - // TODO(T204312268) - remove this constraint and pad when possible - assert(n % 2 == 0); +template +void prepare_weight_data(void *weight_data, int n, int k, int group_size, + const int8_t *weight_qvals, const float *weight_scales, + const int8_t *weight_zeros, const float *bias) { - assert(group_size % 32 == 0); - assert(k % group_size == 0); + if (group_size % 32 != 0) { + throw std::runtime_error( + "Group size must be a multiple of 32, but got group_size=" + + std::to_string(group_size)); + } + if (k % group_size != 0) { + throw std::runtime_error( + "k must be a multiple of group size, but got k=" + std::to_string(k) + + " and group_size=" + std::to_string(group_size)); + } // TODO SIMDify this size_t n_groups = n * k / group_size; - auto weight_scales_bf16 = std::vector(n_groups, 0); // We don't support weight zeros yet if (weight_zeros != nullptr) { @@ -85,18 +107,29 @@ void prepare_weight_data(const Ukernel ukernel, void *weight_data, int n, int k, } } + auto weight_scales_bf16_padded = + std::vector(internal::adjust_n(n) * k / group_size, 0); for (size_t i = 0; i < n_groups; i++) { - weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]); + weight_scales_bf16_padded[i] = + internal::get_bf16_from_float(weight_scales[i]); } // Prepack weights before packing // TODO SIMDify this - auto packed_weight_qvals = std::vector(n * k / 2, 0); + auto packed_weight_qvals_padded = + std::vector(internal::adjust_n(n) * k / 2, 0); uint8_t wzp = 8; for (size_t i = 0; i < n * k; i += 2) { const uint8_t low = static_cast(weight_qvals[i] + wzp); const uint8_t high = static_cast(weight_qvals[i + 1] + wzp); - packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF)); + packed_weight_qvals_padded[i / 2] = ((high << 4) | (low & 0xF)); + } + + auto bias_padded = std::vector(internal::adjust_n(n), 0.0); + if (bias != nullptr) { + for (size_t i = 0; i < n; i++) { + 
bias_padded[i] = bias[i]; + } } // Parameters for packing @@ -107,17 +140,68 @@ void prepare_weight_data(const Ukernel ukernel, void *weight_data, int n, int k, auto rhs_pack = get_rhs_packing(); rhs_pack.run_rhs_pack( - /*groups=*/1, n, k, ukernel.get_nr(), ukernel.get_kr(), ukernel.get_sr(), - group_size, - /*rhs=*/reinterpret_cast(packed_weight_qvals.data()), - /*rhs_stride=*/roundup(k, 2) / 2, - /*bias=*/bias, - /*scale=*/reinterpret_cast(weight_scales_bf16.data()), - /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size), + /*groups=*/1, internal::adjust_n(n), k, nr, kr, sr, group_size, + /*rhs=*/ + reinterpret_cast(packed_weight_qvals_padded.data()), + /*rhs_stride=*/internal::roundup(k, 2) / 2, + /*bias=*/reinterpret_cast(bias_padded.data()), + /*scale=*/ + reinterpret_cast(weight_scales_bf16_padded.data()), + /*scale_stride=*/sizeof(uint16_t) * + (internal::roundup(k, group_size) / group_size), /*rhs_packed=*/weight_data, /*extra_bytes=*/0, /*qparams=*/&qparams); } +size_t get_preferred_alignement() { return 16; } + +#define DEFINE_KERNEL_STRUCT(name) \ + struct name { \ + inline static kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel \ + get_ukernel() { \ + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel( \ + {.get_m_step = kai_get_m_step_##name, \ + .get_n_step = kai_get_n_step_##name, \ + .get_mr = kai_get_mr_##name, \ + .get_nr = kai_get_nr_##name, \ + .get_kr = kai_get_kr_##name, \ + .get_sr = kai_get_sr_##name, \ + .get_lhs_packed_offset = kai_get_lhs_packed_offset_##name, \ + .get_rhs_packed_offset = kai_get_rhs_packed_offset_##name, \ + .get_dst_offset = kai_get_dst_offset_##name, \ + .get_dst_size = kai_get_dst_size_##name, \ + .run_matmul = kai_run_##name}); \ + } \ + inline static void kernel(float32_t *output, int output_m_stride, int m, \ + int n, int k, int group_size, \ + const void *weight_data, \ + const void *activation_data, float clamp_min, \ + float clamp_max) { \ + if (clamp_min == 0 && clamp_max == 0) { \ + 
clamp_min = std::numeric_limits::lowest(); \ + clamp_max = std::numeric_limits::max(); \ + } \ + get_ukernel().run_matmul( \ + m, internal::adjust_n(n), k, group_size, activation_data, \ + weight_data, output, \ + /*dst_stride_row=*/output_m_stride * sizeof(float), \ + /*dst_stride_col=*/sizeof(float), /*clamp_min=*/clamp_min, \ + /*clamp_max=*/clamp_max); \ + } \ + } + +DEFINE_KERNEL_STRUCT( + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); +DEFINE_KERNEL_STRUCT( + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod); + +#ifdef TORCHAO_ENABLE_ARM_I8MM +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm); +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm); +#endif // TORCHAO_ENABLE_ARM_I8MM + +#undef DEFINE_KERNEL_STRUCT + } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p } // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 5c12d7184e..39cc76d887 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -12,8 +12,6 @@ export CMAKE_OUT=/tmp/cmake-out/torch_ao/kernel_tests target=${1:-"native"} -IS_ARM64=0 -BUILD_ARM_I8MM=0 EXTRA_ARGS="" if [[ "${target}" == "android" ]]; then if [[ -z ${ANDROID_NDK} ]]; then @@ -38,17 +36,10 @@ if [[ "${target}" == "android" ]]; then echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" fi -hash arch; retval=$? 
-if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then - IS_ARM64=1 -fi - cmake \ ${EXTRA_ARGS} \ -DCMAKE_BUILD_TYPE=Debug \ -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ - -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \ -B ${CMAKE_OUT} diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 070e7bebfb..073e612c68 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -14,15 +14,6 @@ #include #include -#ifdef TORCHAO_ENABLE_KLEIDI -#include -#include -#ifdef TORCHAO_ENABLE_ARM_I8MM -#include -#include -#endif // TORCHAO_ENABLE_ARM_I8MM -#endif // TORCHAO_ENABLE_KLEIDI - float kTol = 0.0001; template @@ -269,327 +260,4 @@ TEST( } } -#ifdef TORCHAO_ENABLE_KLEIDI -template -void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( - int m, int k, int n, int group_size) { - auto test_case = - torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: - generate(m, k, n, group_size, - /*weight_nbit=*/4, - /*has_weight_zeros*/ false, has_bias, has_clamp, - /*weight_scale_bf16_round_trip=*/true); - - using namespace torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32; - - std::vector activation_data(activation_data_size(m, k, group_size)); - - prepare_activation_data((void *)activation_data.data(), m, k, group_size, - test_case.activations.data()); - - std::vector weight_data(weight_data_size(n, k, group_size)); - - prepare_weight_data((void *)weight_data.data(), n, k, group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); - - std::vector output(m * n); - kernel(output.data(), - 
/*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - k_eq_gs_32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - large_k_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - even_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - m_clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -template -void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( - int m, int k, int n, 
int group_size) { - auto test_case = - torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: - generate(m, k, n, group_size, - /*weight_nbit=*/4, - /*has_weight_zeros=*/false, has_bias, has_clamp, - /*round_weight_scales_to_bf16=*/true); - - using namespace torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; - - std::vector activation_data(activation_data_size(m, k, group_size)); - - prepare_activation_data((void *)activation_data.data(), m, k, group_size, - test_case.activations.data()); - - std::vector weight_data(weight_data_size(n, k, group_size)); - - prepare_weight_data((void *)weight_data.data(), n, k, group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); - - std::vector output(m * n); - kernel(output.data(), - /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - k_eq_gs_32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - large_k_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - even_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); -} - 
-TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - m_clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -#ifdef TORCHAO_ENABLE_ARM_I8MM -template -void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( - int m, int k, int n, int group_size) { - auto test_case = - torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: - generate(m, k, n, group_size, - /*weight_nbit=*/4, - /*has_weight_zeros=*/false, has_bias, has_clamp, - /*round_weight_scales_to_bf16=*/true); - - using namespace torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32; - - std::vector activation_data(activation_data_size(m, k, group_size)); - - prepare_activation_data((void *)activation_data.data(), m, k, group_size, - test_case.activations.data()); - - std::vector weight_data(weight_data_size(n, k, group_size)); - - prepare_weight_data((void *)weight_data.data(), n, k, group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); - - std::vector output(m * n); - kernel(output.data(), - /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - 
/*clamp_max=*/test_case.clamp_max); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - k_eq_gs_32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - large_k_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - even_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - m_clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -template -void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( - int m, int k, int n, int group_size) { - auto test_case = - torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: - generate(m, k, n, group_size, - /*weight_nbit=*/4, - 
/*has_weight_zeros=*/false, has_bias, has_clamp, - /*round_weight_scales_to_bf16=*/true); - - using namespace torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32; - - std::vector activation_data(activation_data_size(m, k, group_size)); - - prepare_activation_data((void *)activation_data.data(), m, k, group_size, - test_case.activations.data()); - - std::vector weight_data(weight_data_size(n, k, group_size)); - - prepare_weight_data((void *)weight_data.data(), n, k, group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); - - std::vector output(m * n); - kernel(output.data(), - /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - k_eq_gs_32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - large_k_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - even_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - 
/*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - m_clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); -} -#endif // TORCHAO_ENABLE_ARM_I8MM -#endif // TORCHAO_ENABLE_KLEIDI #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h b/torchao/experimental/ops/embedding_xbit/packed_weights_header.h index 935ee3bfbd..8e47c2d1c0 100644 --- a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h +++ b/torchao/experimental/ops/embedding_xbit/packed_weights_header.h @@ -16,7 +16,7 @@ inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( int max_value_chunk_size, int version = 1) { return torchao::ops::PackedWeightsHeader( - torchao::ops::PackedWeightsFormat::embedding_xbit_universal, + torchao::ops::PackedWeightsType::embedding_xbit_universal, {version, weight_nbit, min_value_chunk_size, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt b/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt index 91fcf60621..82d9fa2cf3 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt @@ -8,6 +8,16 @@ cmake_minimum_required(VERSION 3.19) include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake) + # For some reason cpuinfo package has unused functions/variables + # TODO (T215533422): fix upstream +add_compile_options(-Wno-unused-function -Wno-unused-variable) 
+include(FetchContent) +FetchContent_Declare(cpuinfo + GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git + GIT_TAG aaac07ee499895770c89163ce0920ef8bb41ed23) +FetchContent_MakeAvailable( + cpuinfo) + find_package(Torch REQUIRED) add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT linear_8bit_act_xbit_weight.cpp @@ -15,6 +25,7 @@ add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT ) target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_aten aten_openmp) target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64) +target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE cpuinfo) target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}") target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}") target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE USE_ATEN=1) @@ -37,4 +48,5 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS) target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1) target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}") target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64) + target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE cpuinfo) endif() diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h new file mode 100644 index 0000000000..443d903dfb --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -0,0 +1,361 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once +#include +#include +#include + +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#endif // defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#if defined(TORCHAO_ENABLE_KLEIDI) +#include +#endif // TORCHAO_ENABLE_KLEIDI + +namespace torchao::ops::linear_8bit_act_xbit_weight { + +struct PackedWeightsFormat { + torchao::ops::PackedWeightsType type; + int weight_nbit; + bool has_weight_zeros; + bool has_bias; + int nr; + int kr; + int sr; + + PackedWeightsFormat(torchao::ops::PackedWeightsType type, int weight_nbit, + bool has_weight_zeros, bool has_bias, int nr, int kr, + int sr) + : type{type}, weight_nbit{weight_nbit}, + has_weight_zeros{has_weight_zeros}, has_bias{has_bias}, nr{nr}, kr{kr}, + sr{sr} {} + + static PackedWeightsFormat + from_packed_weights_header(torchao::ops::PackedWeightsHeader header) { + return PackedWeightsFormat( + header.type, header.params[0], static_cast(header.params[1]), + static_cast(header.params[2]), header.params[3], header.params[4], + header.params[5]); + } + + inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const { + return torchao::ops::PackedWeightsHeader( + type, {weight_nbit, has_weight_zeros, has_bias, nr, kr, sr}); + } +}; + +struct UKernelConfigRegistrationTable { +private: + using Key = std::pair; + struct KeyHasher { + std::size_t operator()(const Key &k) const { + return std::hash()(k.first) ^ + std::hash()(static_cast(k.second)); + } + }; + std::unordered_map registration_table_; + inline Key make_key(torchao::ops::PackedWeightsHeader header, + cpuinfo_uarch uarch) const { + return std::make_pair(header, uarch); + } + +public: + void register_ukernel_config(PackedWeightsFormat format, cpuinfo_uarch uarch, + UKernelConfig config) { + auto header = format.to_packed_weights_header(); + auto key = make_key(header, uarch); + if (registration_table_.find(key) != registration_table_.end()) { + throw std::runtime_error( + "UKernelConfig is already 
registered for this format"); + } + registration_table_[key] = config; + } + std::optional + get_ukernel_config(torchao::ops::PackedWeightsHeader header, + cpuinfo_uarch uarch) const { + auto key = make_key(header, uarch); + auto it = registration_table_.find(key); + if (it == registration_table_.end()) { + return std::nullopt; + } + return it->second; + } +}; + +template +void check_format(PackedWeightsFormat format, + torchao::ops::PackedWeightsType type) { + if (format.type != type) { + throw std::runtime_error("Kernel expects packed_weights type=" + + std::to_string(static_cast(type)) + + ", but got packed_weights with type=" + + std::to_string(static_cast(format.type))); + } + if (format.weight_nbit != weight_nbit) { + throw std::runtime_error( + "Kernel expects weight_nbit=" + std::to_string(weight_nbit) + + ", but got packed_weights with weight_nbit=" + + std::to_string(format.weight_nbit)); + } + if (format.has_weight_zeros != has_weight_zeros) { + throw std::runtime_error( + "Kernel expects has_weight_zeros=" + std::to_string(has_weight_zeros) + + ", but got packed_weights with has_weight_zeros=" + + std::to_string(format.has_weight_zeros)); + } + if (format.has_bias != has_bias) { + throw std::runtime_error( + "Kernel expects has_bias=" + std::to_string(has_bias) + + ", but got packed_weights with has_bias=" + + std::to_string(format.has_bias)); + } +} + +template +void register_ukernel_config_universal(UKernelConfigRegistrationTable &table, + PackedWeightsFormat format, + cpuinfo_uarch uarch) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + check_format( + format, + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); + + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { +#if defined(__aarch64__) || defined(__ARM_NEON) + if (cpuinfo_has_arm_neon_dot()) { + namespace kernel = torchao::kernels::cpu::aarch64::linear:: + 
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + table.register_ukernel_config( + format, uarch, + UKernelConfig{ + /*preferred_alignment*/ 16, + /*nr*/ 8, + /*weight_packing_config*/ + {/*weight_data_size_fn*/ + &kernel::weight_data_size, + /*prepare_weight_data_fn*/ + &kernel::prepare_weight_data}, + /*linear_configs*/ + {{{/*mr*/ 1, + /*activation_data_size_fn*/ + &kernel::activation_data_size, + /*prepare_activation_data_fn*/ + &kernel::prepare_activation_data, + /*kernel*/ + &kernel::kernel}}}}); + return; + } +#endif // defined(__aarch64__) || defined(__ARM_NEON) + } +} + +#if defined(TORCHAO_ENABLE_KLEIDI) +template +UKernelConfig::linear_config_type get_linear_config_kleidi() { + namespace op = torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + assert(m_step == kernel_struct::get_ukernel().get_m_step()); + assert(mr == kernel_struct::get_ukernel().get_mr()); + assert(n_step == kernel_struct::get_ukernel().get_n_step()); + assert(nr == kernel_struct::get_ukernel().get_nr()); + assert(kr == kernel_struct::get_ukernel().get_kr()); + assert(sr == kernel_struct::get_ukernel().get_sr()); + return UKernelConfig::linear_config_type{ + /*mr*/ m_step, + /*activation_data_size_fn*/ &op::activation_data_size, + /*prepare_activation_data_fn*/ &op::prepare_activation_data, + /*kernel*/ &kernel_struct::kernel}; +} + +template +UKernelConfig::weight_packing_config_type get_weight_packing_config_kleidi() { + namespace op = torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + return UKernelConfig::weight_packing_config_type( + {/*weight_data_size_fn*/ &op::weight_data_size, + /*prepare_weight_data_fn*/ &op::prepare_weight_data}); +} + +template +void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table, + PackedWeightsFormat format, + cpuinfo_uarch uarch) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + check_format( + 
format, torchao::ops::PackedWeightsType::kleidi_ai); + namespace op = torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; +#if defined(TORCHAO_ENABLE_ARM_I8MM) + if (cpuinfo_has_arm_i8mm()) { + constexpr int n_step = 8; + table.register_ukernel_config( + format, uarch, + UKernelConfig{ + /*preferred_alignment*/ op::get_preferred_alignement(), + /*nr*/ n_step, + /*weight_packing_config*/ + get_weight_packing_config_kleidi(), + /*linear_configs*/ + {{get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + /*m_step*/ 4, /*mr*/ 4, n_step, nr, kr, sr>()}}}); + return; + } +#endif // TORCHAO_ENABLE_ARM_I8MM + + if (cpuinfo_has_arm_neon_dot()) { + constexpr int n_step = 8; + table.register_ukernel_config( + format, uarch, + UKernelConfig{ + /*preferred_alignment*/ op::get_preferred_alignement(), + /*nr*/ n_step, + /*weight_packing_config*/ + get_weight_packing_config_kleidi(), + /*linear_configs*/ + {{get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + /*m_step*/ 1, /*mr*/ 1, n_step, nr, kr, sr>()}}}); + return; + } + } + + if (format.nr == 4 && format.kr == 16 && format.sr == 2) { + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + if (cpuinfo_has_arm_neon_dot()) { + constexpr int n_step = 4; + table.register_ukernel_config( + format, uarch, + UKernelConfig{ + /*preferred_alignment*/ op::get_preferred_alignement(), + /*nr*/ n_step, + /*weight_packing_config*/ + get_weight_packing_config_kleidi(), + /*linear_configs*/ + {{get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /*m_step*/ 1, /*mr*/ 1, n_step, nr, kr, sr>()}}}); + return; + } + } +} +#endif // TORCHAO_ENABLE_KLEIDI + +template +void register_ukernel_config(UKernelConfigRegistrationTable &table, + 
PackedWeightsFormat format, cpuinfo_uarch uarch) { + switch (format.type) { + case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: { + if (format.has_bias) { + register_ukernel_config_universal( + table, format, uarch); + } else { + register_ukernel_config_universal(table, format, + uarch); + } + break; + } + case torchao::ops::PackedWeightsType::kleidi_ai: { +#ifdef TORCHAO_ENABLE_KLEIDI + register_ukernel_config_kleidi(table, format, + uarch); +#endif // TORCHAO_ENABLE_KLEIDI + break; + } + default: + throw std::runtime_error( + "No registration available for packed_weights_type=" + + std::to_string(static_cast(format.type))); + } + + auto config = + table.get_ukernel_config(format.to_packed_weights_header(), uarch); + if (!config.has_value()) { + throw std::runtime_error("ukernel_config did not register"); + } +} + +// Not thread safe +template +UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) { + static UKernelConfigRegistrationTable table; + + // In future, we can populate this with the current thread's uarch + // That will require that select_ukernel_config be called in the lambda + // instead of before it on the main thread + // Note, cpuinfo_get_current_core() is not currently implemeted outside of + // linux XNNPACK often uses non-core specific logic like + // cpuinfo_get_core(0)->uarch in configs + auto uarch = cpuinfo_uarch_unknown; + auto ukernel = table.get_ukernel_config(header, uarch); + if (ukernel.has_value()) { + return ukernel.value(); + } + + auto format = PackedWeightsFormat::from_packed_weights_header(header); + register_ukernel_config(table, format, uarch); + + ukernel = table.get_ukernel_config(header, uarch); + assert(ukernel.has_value()); + return ukernel.value(); +} + +template +UKernelConfig select_ukernel_config(PackedWeightsFormat format) { + return select_ukernel_config( + format.to_packed_weights_header()); +} + +template +PackedWeightsFormat 
+select_packed_weights_format(std::optional target = std::nullopt) { +// Select KleidiAI format +#if defined(TORCHAO_ENABLE_KLEIDI) + if (!target || *target == "kleidi_ai") { + if constexpr (weight_nbit == 4 && + (!has_weight_zeros)) { // TODO: add has_bias here + return PackedWeightsFormat( + torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit, + has_weight_zeros, /*has_bias*/ true, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); + } + } +#endif // defined(TORCHAO_ENABLE_KLEIDI) + + // Select universal format + if (!target || *target == "universal") { + return PackedWeightsFormat( + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + weight_nbit, has_weight_zeros, has_bias, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); + } + + throw std::runtime_error("No packed_weights_format was selected"); +} + +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 709386998e..1c23bdbbae 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -31,7 +31,7 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( assert(nc >= 1); // Replace nc with the next number nr divides - nc = ((nc + ukernel_config.nr - 1) / ukernel_config.nr) * ukernel_config.nr; + nc = ((nc + nr - 1) / nr) * nr; tiling_params.nc_by_nr = nc / nr; return tiling_params; @@ -59,16 +59,25 @@ void pack_weight_data_operator(const UKernelConfig &ukernel_config, int nc_tile_size = std::min(nc, n - n_idx); int weight_data_offset = - (n_idx / nr) * ukernel_config.weight_data_size_fn(nr, k, group_size); + (n_idx / nr) * ukernel_config.weight_packing_config.weight_data_size_fn( + nr, k, group_size); int weight_qvals_offset = n_idx * k; int weight_scales_and_zeros_offset = (n_idx * k 
/ group_size); - int bias_offset = n_idx; - ukernel_config.prepare_weight_data_fn( + const int8_t *weight_zeros_ptr = nullptr; + if (weight_zeros != nullptr) { + weight_zeros_ptr = weight_zeros + weight_scales_and_zeros_offset; + } + const float *bias_ptr = nullptr; + if (bias != nullptr) { + bias_ptr = bias + n_idx; + } + + ukernel_config.weight_packing_config.prepare_weight_data_fn( (char *)weight_data + weight_data_offset, /*n=*/nc_tile_size, k, group_size, weight_qvals + weight_qvals_offset, - weight_scales + weight_scales_and_zeros_offset, - weight_zeros + weight_scales_and_zeros_offset, bias + bias_offset); + weight_scales + weight_scales_and_zeros_offset, weight_zeros_ptr, + bias_ptr); }); } @@ -86,7 +95,7 @@ get_default_linear_tiling_params(const UKernelConfig &ukernel_config, int m, TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1"); tiling_params.mc_by_mr = 1; - int mc = tiling_params.mc_by_mr * ukernel_config.mr; + int mc = tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr; int num_mc_panels = (m + mc - 1) / mc; int numerator = n * num_mc_panels; @@ -97,9 +106,10 @@ get_default_linear_tiling_params(const UKernelConfig &ukernel_config, int m, assert(nc >= 1); // Replace nc with next number nr divides - nc = ((nc + ukernel_config.nr - 1) / ukernel_config.nr) * ukernel_config.nr; - assert(nc % ukernel_config.nr == 0); - tiling_params.nc_by_nr = nc / ukernel_config.nr; + int nr = ukernel_config.nr; + nc = ((nc + nr - 1) / nr) * nr; + assert(nc % nr == 0); + tiling_params.nc_by_nr = nc / nr; assert(tiling_params.mc_by_mr >= 1); assert(tiling_params.nc_by_nr >= 1); @@ -112,15 +122,17 @@ inline size_t get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( const UKernelConfig &ukernel_config, const LinearTilingParams &tiling_params, int m, int k, int group_size) { - return ukernel_config.activation_data_size_fn( - tiling_params.mc_by_mr * ukernel_config.mr, k, group_size); + return 
ukernel_config.linear_configs[0].activation_data_size_fn( + tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr, k, + group_size); } inline size_t get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( const UKernelConfig &ukernel_config, const LinearTilingParams &tiling_params, int m, int k, int group_size) { - return ukernel_config.activation_data_size_fn(m, k, group_size); + return ukernel_config.linear_configs[0].activation_data_size_fn(m, k, + group_size); } inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( @@ -134,20 +146,22 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( // Ignored if has_clamp = false float clamp_min, float clamp_max) { int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.mr); - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int mc = + std::min(m, tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr); + int nc = std::min(n, tiling_params.nc_by_nr * nr); int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; size_t weight_data_size = - ukernel_config.weight_data_size_fn(nr, k, group_size); + ukernel_config.weight_packing_config.weight_data_size_fn(nr, k, + group_size); for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) { int m_idx = mc_tile_idx * mc; int mc_tile_size = std::min(mc, m - m_idx); int activations_offset = m_idx * k; - ukernel_config.prepare_activation_data_fn(activation_data_buffer, - /*m=*/mc_tile_size, k, group_size, - activations + activations_offset); + ukernel_config.linear_configs[0].prepare_activation_data_fn( + activation_data_buffer, + /*m=*/mc_tile_size, k, group_size, activations + activations_offset); torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { int nc_tile_idx = idx; @@ -157,7 +171,7 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( int output_offset = m_idx * n + n_idx; 
int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.kernel_fn( + ukernel_config.linear_configs[0].kernel_fn( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, @@ -176,17 +190,19 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( // Inputs int m, int n, int k, int group_size, const void *weight_data, const float *activations, float clamp_min, float clamp_max) { - int mr = ukernel_config.mr; + int mr = ukernel_config.linear_configs[0].mr; int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.mr); - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int mc = std::min(m, tiling_params.mc_by_mr * mr); + int nc = std::min(n, tiling_params.nc_by_nr * nr); int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; size_t weight_data_size = - ukernel_config.weight_data_size_fn(nr, k, group_size); + ukernel_config.weight_packing_config.weight_data_size_fn(nr, k, + group_size); size_t activation_data_size = - ukernel_config.activation_data_size_fn(mr, k, group_size); + ukernel_config.linear_configs[0].activation_data_size_fn(mr, k, + group_size); torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) { int mc_tile_idx = idx; @@ -195,7 +211,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( int activations_offset = m_idx * k; int activation_data_offset = (m_idx / mr) * activation_data_size; - ukernel_config.prepare_activation_data_fn( + ukernel_config.linear_configs[0].prepare_activation_data_fn( activation_data_buffer + activation_data_offset, /*m=*/mc_tile_size, k, group_size, activations + activations_offset); }); @@ -213,7 +229,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( int output_offset = m_idx * n + n_idx; int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.kernel_fn( + ukernel_config.linear_configs[0].kernel_fn( output + 
output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index 1dc69dee74..6742f88b02 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #pragma once +#include #include #include #include @@ -29,27 +30,24 @@ struct UKernelConfig { const void *activation_data, float clamp_min, float clamp_max); - activation_data_size_fn_type activation_data_size_fn{nullptr}; - // preferred_activation_data_alignment is only a preferred alignment for - // performance reasons. Integration surfaces are not required to - // respect this alignment, and the ukernel must behave correctly no matter - // how the prepared_activation_data byte-array is aligned - size_t preferred_activation_data_alignment{0}; - prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; - - weight_data_size_fn_type weight_data_size_fn{nullptr}; - // weight_data_alignment is only a preferred alignment for - // performance reasons. 
Integration surfaces are not required to - // respect this alignment, and the ukernel must behave correctly no matter - // how the prepared_weight_data byte-array is aligned - size_t preferred_weight_data_alignment{0}; - prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; - - kernel_fn_type kernel_fn{nullptr}; - int mr{0}; + struct weight_packing_config_type { + weight_data_size_fn_type weight_data_size_fn{nullptr}; + prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; + }; + struct linear_config_type { + int mr{0}; + activation_data_size_fn_type activation_data_size_fn{nullptr}; + prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; + kernel_fn_type kernel_fn{nullptr}; + }; + + // preferred_alignment for activation and weight data + // Integration surfaces are not required to respect this alignment, and the + // ukernel must behave correctly no matter how buffers are aligned + size_t preferred_alignment{0}; int nr{0}; - - torchao::ops::PackedWeightsHeader packed_weights_header; + weight_packing_config_type weight_packing_config; + std::array linear_configs; }; // Pack weight functions @@ -64,12 +62,13 @@ get_default_pack_weight_data_tiling_params(const UKernelConfig &ukernel_config, inline size_t get_packed_weight_data_size(const UKernelConfig &ukernel_config, int n, int k, int group_size) { - return ukernel_config.weight_data_size_fn(n, k, group_size); + return ukernel_config.weight_packing_config.weight_data_size_fn(n, k, + group_size); } inline size_t get_preferred_packed_weight_data_alignment( const UKernelConfig &ukernel_config) { - return ukernel_config.preferred_weight_data_alignment; + return ukernel_config.preferred_alignment; } void pack_weight_data_operator(const UKernelConfig &ukernel_config, @@ -105,7 +104,7 @@ get_activation_data_buffer_size(const UKernelConfig &ukernel_config, inline size_t get_preferred_activation_data_buffer_alignment( const UKernelConfig &ukernel_config) { - return 
ukernel_config.preferred_activation_data_alignment; + return ukernel_config.preferred_alignment; } void linear_operator(const UKernelConfig &ukernel_config, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index bc88c0b725..364dd7b668 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -12,67 +12,13 @@ #include #include +#include #include -#include #include #include namespace { -// This selects a UkernelConfig based on the packed weight header -template -inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig -get_ukernel_config(torchao::ops::PackedWeightsHeader header) { - torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config; - - switch (header.format) { -#if defined(__aarch64__) || defined(__ARM_NEON) - case torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal: - namespace ukernel - = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - - // Check packing params match the kernel - TORCHAO_CHECK(header == torchao::ops::linear_8bit_act_xbit_weight:: - get_packed_weights_header_universal( - weight_nbit, has_weight_zeros, has_bias, - /*nr=*/8, - /*kr=*/16), - "Packing params do not match what kernel supports"); - - config.packed_weights_header = header; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - 
&ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - return config; - break; -#endif // defined(__aarch64__) || defined(__ARM_NEON) - default: - TORCHAO_CHECK(false, "Unsupported packed weights format"); - } -} - -template -inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig -get_ukernel_config() { - auto header = torchao::ops::linear_8bit_act_xbit_weight:: - get_packed_weights_header_universal(weight_nbit, has_weight_zeros, - has_bias, /*nr=*/8, /*kr=*/16); - return get_ukernel_config( - header); -} - #ifdef USE_ATEN template Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales, @@ -114,8 +60,12 @@ Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales, using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto ukernel_config = get_ukernel_config(); + auto packed_weights_format = + select_packed_weights_format(); + auto packed_weights_header = packed_weights_format.to_packed_weights_header(); + auto ukernel_config = select_ukernel_config( + packed_weights_header); + auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( ukernel_config, n, /*target_panels_per_thread=*/1); @@ -124,15 +74,16 @@ Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales, get_packed_weight_data_size(ukernel_config, n, k, group_size); Tensor packed_weights = torch::empty( {static_cast(packed_weight_data_size)}, torch::kInt8); - ukernel_config.packed_weights_header.write( - packed_weights.mutable_data_ptr()); - pack_weight_data_operator( - ukernel_config, pack_weight_tiling_params, - packed_weights.mutable_data_ptr() + - torchao::ops::PackedWeightsHeader::size(), - n, k, group_size, weight_qvals.const_data_ptr(), - weight_scales.const_data_ptr(), weight_zeros_ptr, - /*bias*/ nullptr); + packed_weights_header.write(packed_weights.mutable_data_ptr()); + + // TODO: support passing in bias in future + pack_weight_data_operator(ukernel_config, 
pack_weight_tiling_params, + packed_weights.mutable_data_ptr() + + torchao::ops::PackedWeightsHeader::size(), + n, k, group_size, + weight_qvals.const_data_ptr(), + weight_scales.const_data_ptr(), + weight_zeros_ptr, /*bias*/ nullptr); return packed_weights; } @@ -181,8 +132,10 @@ Tensor pack_weights_meta(const Tensor &weight_qvals, using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto ukernel_config = get_ukernel_config(); + auto packed_weights_format = + select_packed_weights_format(); + auto ukernel_config = select_ukernel_config( + packed_weights_format); auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + @@ -278,18 +231,19 @@ linear_out_cpu(const Tensor &activations, const Tensor &packed_weights, torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); auto ukernel_config = - get_ukernel_config(header); + select_ukernel_config(header); auto linear_tiling_params = get_default_linear_tiling_params(ukernel_config, m, n, /*target_tiles_per_thread=*/5); + auto linear_scheduling_policy = LinearTileSchedulingPolicy::single_mc_parallel_nc; auto activation_data_buffer_size = get_activation_data_buffer_size( ukernel_config, linear_tiling_params, linear_scheduling_policy, m, k, group_size); + std::vector activation_data_buffer(activation_data_buffer_size); linear_operator(ukernel_config, linear_tiling_params, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h deleted file mode 100644 index d86a429461..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. 
- -#pragma once -#include -#include - -namespace torchao::ops::linear_8bit_act_xbit_weight { - -inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( - int weight_nbit, - bool has_weight_zeros, - bool has_bias, - int nr, - int kr, - int version = 1) { - return torchao::ops::PackedWeightsHeader( - torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal, - {version, - weight_nbit, - has_weight_zeros, - has_bias, - nr, - kr, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0}); -} - -} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/packed_weights_header.h b/torchao/experimental/ops/packed_weights_header.h index 7184da4b46..213ec34f7f 100644 --- a/torchao/experimental/ops/packed_weights_header.h +++ b/torchao/experimental/ops/packed_weights_header.h @@ -12,35 +12,36 @@ namespace torchao::ops { -enum class PackedWeightsFormat : uint32_t { +enum class PackedWeightsType : uint32_t { unknown = 0, linear_8bit_act_xbit_weight_universal = 1, - embedding_xbit_universal = 2 + embedding_xbit_universal = 2, + kleidi_ai = 3 }; class PackedWeightsHeader { public: using params_type = std::array; const static int magic = 6712; - PackedWeightsFormat format; + PackedWeightsType type; - // 14 bytes of format specific params + // 14 bytes of type specific params params_type params; PackedWeightsHeader( - PackedWeightsFormat format = PackedWeightsFormat::unknown, + PackedWeightsType type = PackedWeightsType::unknown, params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - : format{format}, params{params} {} + : type{type}, params{params} {} inline static constexpr int size() { - static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64); + static_assert(sizeof(magic) + sizeof(type) + sizeof(params) == 64); return 64; } inline void write(void* packed_weights) const { auto header = reinterpret_cast(packed_weights); header[0] = magic; - header[1] = static_cast(format); + header[1] = 
static_cast<int>(type); for (int i = 0; i < params.size(); i++) { header[i + 2] = params[i]; } @@ -54,11 +55,11 @@ class PackedWeightsHeader { params[i] = header[i + 2]; } return PackedWeightsHeader( - static_cast<PackedWeightsFormat>(header[1]), params); + static_cast<PackedWeightsType>(header[1]), params); } bool operator==(const PackedWeightsHeader& other) const { - if (format != other.format) { + if (type != other.type) { return false; } for (int i = 0; i < params.size(); i++) { @@ -71,3 +72,16 @@ }; } // namespace torchao::ops + +namespace std { + template <> + struct hash<torchao::ops::PackedWeightsHeader> { + std::size_t operator()(const torchao::ops::PackedWeightsHeader& f) const { + std::size_t hash = std::hash<int>()(static_cast<int>(f.type)); + for (int i = 0; i < f.params.size(); i++) { + hash ^= std::hash<int>()(f.params[i]); + } + return hash; + }; +}; +} diff --git a/torchao/experimental/ops/tests/build_and_run_tests.sh b/torchao/experimental/ops/tests/build_and_run_tests.sh index 4070b9304f..cff7ca639a 100644 --- a/torchao/experimental/ops/tests/build_and_run_tests.sh +++ b/torchao/experimental/ops/tests/build_and_run_tests.sh @@ -9,6 +9,8 @@ target=${1:-"native"} SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests +export TORCH_DIR=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib() + '/torch/share/cmake/Torch')") + IS_ARM64=0 BUILD_ARM_I8MM=0 EXTRA_ARGS="" @@ -45,6 +47,7 @@ cmake \ -DCMAKE_BUILD_TYPE=Debug \ -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ + -DTorch_DIR=${TORCH_DIR} \ -S . 
\ -B ${CMAKE_OUT} diff --git a/torchao/experimental/ops/tests/generate_tests.py b/torchao/experimental/ops/tests/generate_tests.py index 1710a90c49..160d8fa47a 100755 --- a/torchao/experimental/ops/tests/generate_tests.py +++ b/torchao/experimental/ops/tests/generate_tests.py @@ -51,6 +51,11 @@ def get_test_block(kernel): tests += add_test_string(kernel, 1, 2 * 13, 32, 32, True, False) tests += add_test_string(kernel, 1, 2 * 51, 32, 32, False, True) tests += add_test_string(kernel, 1, 2 * 111, 32, 32, False, False) + ## larger: n (odd) + tests += add_test_string(kernel, 1, 11, 32, 32, False, False) + tests += add_test_string(kernel, 1, 13, 32, 32, True, False) + tests += add_test_string(kernel, 1, 51, 32, 32, False, True) + tests += add_test_string(kernel, 1, 111, 32, 32, False, False) ## larger: k, g - must be multiple of 32 tests += add_test_string(kernel, 1, 2 * 7, 64, 32, False, False) tests += add_test_string(kernel, 1, 2 * 11, 128, 32, True, False) @@ -75,6 +80,11 @@ def get_test_block(kernel): tests += add_test_string(kernel, 17, 2 * 13, 32, 32, True, False) tests += add_test_string(kernel, 23, 2 * 51, 32, 32, False, True) tests += add_test_string(kernel, 41, 2 * 111, 32, 32, False, False) + ## larger: n (odd) + tests += add_test_string(kernel, 7, 11, 32, 32, False, False) + tests += add_test_string(kernel, 17, 13, 32, 32, True, False) + tests += add_test_string(kernel, 23, 51, 32, 32, False, True) + tests += add_test_string(kernel, 41, 111, 32, 32, False, False) ## larger: k, g - must be multiple of 32 tests += add_test_string(kernel, 19, 2 * 7, 64, 32, False, False) tests += add_test_string(kernel, 23, 2 * 11, 128, 32, True, False) diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index bcf746e00e..295b93c3a4 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ 
b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -13,40 +13,36 @@ #include #if defined(TORCHAO_ENABLE_KLEIDI) -#include -#include -#if defined(TORCHAO_ENABLE_ARM_I8MM) -#include -#include -#endif // TORCHAO_ENABLE_ARM_I8MM +#include #endif // TORCHAO_ENABLE_KLEIDI const float kTol = 1.0e-5; using namespace torchao::ops::linear_8bit_act_xbit_weight; +using namespace torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; template UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + namespace kernel = torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; + return UKernelConfig{ + /*preferred_alignment*/ 16, + /*nr*/ 8, + /*weight_packing_config*/ + {/*weight_data_size_fn*/ + &kernel::weight_data_size, + /*prepare_weight_data_fn*/ + &kernel::prepare_weight_data}, + /*linear_configs*/ + {{{/*mr*/ 1, + /*activation_data_size_fn*/ + &kernel::activation_data_size, + /*prepare_activation_data_fn*/ + &kernel::prepare_activation_data, + /*kernel*/ + &kernel::kernel}}}}; } template +UKernelConfig get_ukernel_config_kleidi() { + namespace op = torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + auto uk = kernel_struct::get_ukernel(); + assert(m_step == uk.get_m_step()); + assert(mr == uk.get_mr()); + assert(n_step == uk.get_n_step()); + assert(nr == 
uk.get_nr()); + assert(kr == uk.get_kr()); + assert(sr == uk.get_sr()); + return UKernelConfig{ + op::get_preferred_alignement(), + n_step, + {/*weight_data_size_fn*/ &op::weight_data_size, + /*prepare_weight_data_fn*/ &op::prepare_weight_data}, + {{{m_step, &op::activation_data_size, + &op::prepare_activation_data, &kernel_struct::kernel}}}}; +} template UKernelConfig get_ukernel_config_kleidi() { - UKernelConfig config; #if defined(TORCHAO_ENABLE_ARM_I8MM) if constexpr (kernel_id == i8mm_4x8x32) { - KAI_GEN_UKERNEL( - torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32); - return config; + constexpr int m_step = 4; + constexpr int mr = 4; + constexpr int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + return get_ukernel_config_kleidi< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, m_step, mr, + n_step, nr, kr, sr>(); } if constexpr (kernel_id == i8mm_8x4x32) { - KAI_GEN_UKERNEL( - torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32); - return config; + constexpr int m_step = 8; + constexpr int mr = 8; + constexpr int n_step = 4; + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + return get_ukernel_config_kleidi< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, m_step, mr, + n_step, nr, kr, sr>(); } #endif // TORCHAO_ENABLE_ARM_I8MM if constexpr (kernel_id == dotprod_1x8x32) { - KAI_GEN_UKERNEL( - torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32); - return config; + constexpr int m_step = 1; + constexpr int mr = 1; + constexpr int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + return get_ukernel_config_kleidi< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, m_step, mr, + n_step, nr, kr, sr>(); } - KAI_GEN_UKERNEL( - torchao::kernels::cpu::aarch64::kleidi:: - 
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32); - return config; + if constexpr (kernel_id == dotprod_1x4x32) { + constexpr int m_step = 1; + constexpr int mr = 1; + constexpr int n_step = 4; + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + return get_ukernel_config_kleidi< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, m_step, mr, + n_step, nr, kr, sr>(); + } + throw std::runtime_error("Unsupported kernel_id"); } #endif // TORCHAO_ENABLE_KLEIDI @@ -253,7 +278,6 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) { std::runtime_error); } -// begin /* Generated by generate_tests.py */ /* Do not modify */ @@ -340,6 +364,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn222xk32xg32) { /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + 
+TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -494,6 +552,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m41xn222xk32xg32) { /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m7xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/7, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m17xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/17, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m23xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/23, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m41xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + 
test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/41, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -610,6 +702,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn222xk32xg32) { /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/111, /*k=*/32, 
/*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -764,6 +890,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m41xn222xk32xg32) { /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m7xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/7, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m17xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/17, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m23xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/23, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m41xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/41, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = 
get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -878,6 +1038,39 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn222xk32xg32) { /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m1xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -1029,6 +1222,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m41xn222xk32xg32) { /*m=*/41, /*n=*/222, /*k=*/32, 
/*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m7xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/7, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m17xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/17, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m23xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/23, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m41xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/41, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -1144,6 +1371,39 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn222xk32xg32) { /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + 
test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m1xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -1295,6 +1555,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m41xn222xk32xg32) { /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m7xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/7, /*n=*/11, /*k=*/32, /*group_size=*/32, 
&ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m17xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/17, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m23xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/23, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m41xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/41, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight<