Skip to content

Commit 70fc520

Browse files
authored
Add LUT support to linear kernel
Differential Revision: D71357417. Pull Request resolved: #1990.
1 parent 934d11e commit 70fc520

File tree

5 files changed

+462
-25
lines changed

5 files changed

+462
-25
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h

+65-2
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,69 @@ void pack_weights(
114114
bias);
115115
}
116116

117+
template <int weight_nbit, int nr, int kr, int sr>
118+
void pack_weights_with_lut(
119+
// Output
120+
void* packed_weights,
121+
// Inputs
122+
int n,
123+
int k,
124+
int group_size,
125+
const int8_t* weight_qval_idxs,
126+
int n_luts,
127+
const int8_t* luts,
128+
const float* weight_scales,
129+
// weight_zeros not packed if nullptr
130+
const int8_t* weight_zeros,
131+
// bias not packed if nullptr
132+
const float* bias) {
133+
torchao::kernels::cpu::aarch64::linear::
134+
channelwise_8bit_activation_groupwise_lowbit_weight::weight_packing::
135+
pack_weights_with_lut<weight_nbit, nr, kr, sr>(
136+
packed_weights,
137+
n,
138+
k,
139+
group_size,
140+
weight_qval_idxs,
141+
n_luts,
142+
luts,
143+
weight_scales,
144+
weight_zeros,
145+
bias);
146+
}
147+
148+
inline size_t packed_weights_with_lut_size(
149+
int n,
150+
int k,
151+
int group_size,
152+
int weight_nbit,
153+
bool has_weight_zeros,
154+
bool has_bias,
155+
int nr,
156+
int kr,
157+
int sr) {
158+
(void)kr; // unused
159+
(void)sr; // unused
160+
return weight_packing::packed_weights_with_lut_size(
161+
n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr);
162+
}
163+
164+
inline size_t packed_weights_with_lut_offset(
165+
int n_idx,
166+
int k,
167+
int group_size,
168+
int weight_nbit,
169+
bool has_weight_zeros,
170+
bool has_bias,
171+
int nr,
172+
int kr,
173+
int sr) {
174+
assert(n_idx % nr == 0);
175+
auto packed_weights_size_nr_cols = packed_weights_with_lut_size(
176+
nr, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr);
177+
return (n_idx / nr) * packed_weights_size_nr_cols;
178+
}
179+
117180
template <int weight_nbit>
118181
void kernel_1x1x32_f32_neondot(
119182
// Outputs
@@ -182,7 +245,7 @@ void kernel_1x4x16_f32_neondot(
182245
has_clamp);
183246
}
184247

185-
template <int weight_nbit>
248+
template <int weight_nbit, bool has_lut>
186249
void kernel_1x8x16_f32_neondot(
187250
// Outputs
188251
float32_t* output,
@@ -200,7 +263,7 @@ void kernel_1x8x16_f32_neondot(
200263
bool has_weight_zeros,
201264
bool has_bias,
202265
bool has_clamp) {
203-
kernel::kernel_1x8x16_f32_neondot<weight_nbit>(
266+
kernel::kernel_1x8x16_f32_neondot<weight_nbit, has_lut>(
204267
output,
205268
output_m_stride,
206269
m,

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h

+38-11
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ vec_clamp(float32x4_t x, float32x4_t vec_min, float32x4_t vec_max) {
5858
// Roughly inspired by
5959
// https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c?ref_type=heads
6060

61-
template <int weight_nbit>
61+
template <int weight_nbit, bool has_lut>
6262
void kernel_1x8x16_f32_neondot(
6363
// Outputs
6464
float32_t* output,
@@ -79,6 +79,11 @@ void kernel_1x8x16_f32_neondot(
7979
assert(k % group_size == 0);
8080
assert(group_size % 16 == 0);
8181

82+
int8x16_t lut;
83+
if constexpr (!has_lut) {
84+
(void)lut; // unused
85+
}
86+
8287
constexpr int bytes_per_128_weight_values = 16 * weight_nbit;
8388

8489
auto activation_data_byte_ptr = (char*)(activation_data);
@@ -99,6 +104,11 @@ void kernel_1x8x16_f32_neondot(
99104
// Weights and activations are padded when prepared, so the
100105
// reads are legal, even if on a partial tile
101106
for (int n_idx = 0; n_idx < n; n_idx += 8) {
107+
if constexpr (has_lut) {
108+
lut = vld1q_s8((int8_t*)weight_data_byte_ptr);
109+
weight_data_byte_ptr += 16;
110+
}
111+
102112
// Set activation_ptr to start of activation qvals for row m_idx
103113
activation_ptr = activation_data_byte_ptr;
104114
float32x4_t res_0123 = vdupq_n_f32(0.0);
@@ -167,16 +177,33 @@ void kernel_1x8x16_f32_neondot(
167177
// Each chunk is 64 values of unpacked data (4 cols x 16 vals/col).
168178
// This comes out to (64 * weight_nbit / 8) bits = 8 * weight_nbit
169179
// bytes of bitpacked data
170-
torchao::bitpacking::vec_unpack_128_lowbit_values<weight_nbit>(
171-
weight_q_cols01_0,
172-
weight_q_cols23_0,
173-
weight_q_cols45_0,
174-
weight_q_cols67_0,
175-
weight_q_cols01_1,
176-
weight_q_cols23_1,
177-
weight_q_cols45_1,
178-
weight_q_cols67_1,
179-
(uint8_t*)weight_data_byte_ptr);
180+
181+
if constexpr (has_lut) {
182+
torchao::bitpacking::vec_unpack_128_lowbit_values_with_lut<
183+
weight_nbit>(
184+
weight_q_cols01_0,
185+
weight_q_cols23_0,
186+
weight_q_cols45_0,
187+
weight_q_cols67_0,
188+
weight_q_cols01_1,
189+
weight_q_cols23_1,
190+
weight_q_cols45_1,
191+
weight_q_cols67_1,
192+
(uint8_t*)weight_data_byte_ptr,
193+
lut);
194+
} else {
195+
torchao::bitpacking::vec_unpack_128_lowbit_values<weight_nbit>(
196+
weight_q_cols01_0,
197+
weight_q_cols23_0,
198+
weight_q_cols45_0,
199+
weight_q_cols67_0,
200+
weight_q_cols01_1,
201+
weight_q_cols23_1,
202+
weight_q_cols45_1,
203+
weight_q_cols67_1,
204+
(uint8_t*)weight_data_byte_ptr);
205+
}
206+
180207
weight_data_byte_ptr += bytes_per_128_weight_values;
181208

182209
// Load 16 activation values

0 commit comments

Comments
 (0)