[DRAFT] use xnnpack quantization in eager/aoti #698

Closed · wants to merge 28 commits

Changes from 15 commits
10 changes: 10 additions & 0 deletions _custom_linear/CMakeLists.txt
@@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.17)
project(custom_linear)

set(CMAKE_CXX_STANDARD 17)

find_package(Torch REQUIRED)

add_library(custom_linear SHARED custom_linear.cpp)
target_include_directories(custom_linear PRIVATE "${TORCHCHAT_ROOT}/..")
target_link_libraries(custom_linear PRIVATE "${TORCH_LIBRARIES}")
29 changes: 29 additions & 0 deletions _custom_linear/_custom_linear.py
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
from typing import Optional

from .quantize import group_quantize_tensor_symmetric, convert_to_qc4w

torch.ops.load_library("_custom_linear/build/libcustom_linear.dylib")

class _CustomLinear(nn.Module):
    """Linear layer backed by the custom XNNPACK qd8-f32-qb4w op."""

    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.weight = weight
        assert bias is None

        # Quantize to symmetric 4-bit values in groups of 32 and prepack for XNNPACK.
        self.group_size = 32
        w_int, s, z = group_quantize_tensor_symmetric(self.weight, self.group_size, torch.float32)
        w_packed = convert_to_qc4w(w_int)
        self.prepacked = torch.ops.torchchat.prepack.default(w_packed, s)

    def forward(self, x):
        if x.dtype != torch.float32:
            raise RuntimeError(f"x has dtype {x.dtype}, expected float32")
        # Expect (1, seq_len, in_features); the op consumes the 2-d (seq_len, in_features) view.
        assert x.shape[0] == 1
        return torch.ops.torchchat.run.default(self.prepacked, x.squeeze(0)).unsqueeze(0)


def _replace_linear_with_custom_linear(module: nn.Module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, _CustomLinear(child.weight, child.bias))
        else:
            _replace_linear_with_custom_linear(child)
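A hypothetical usage sketch (not part of the diff) showing how the replacement pass might be applied to a toy module. It assumes the working directory is the repo root so the relative dylib path above resolves, and that every Linear is bias-free with in_features divisible by the group size of 32.

# Hypothetical sketch: swap every nn.Linear for _CustomLinear and run one forward pass.
import torch
import torch.nn as nn
from _custom_linear._custom_linear import _replace_linear_with_custom_linear

model = nn.Sequential(
    nn.Linear(64, 128, bias=False),
    nn.ReLU(),
    nn.Linear(128, 32, bias=False),
).eval()
_replace_linear_with_custom_linear(model)

x = torch.randn(1, 16, 64)            # (batch=1, seq_len, in_features), float32
with torch.no_grad():
    y = model(x)
print(y.shape)                        # torch.Size([1, 16, 32])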
5 changes: 5 additions & 0 deletions _custom_linear/build.sh
@@ -0,0 +1,5 @@
rm -rf build
mkdir build
# cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" -S . -B build
cmake -DCMAKE_PREFIX_PATH="/Users/scroy/repos/pytorch/torch/share/cmake" -DTORCHCHAT_ROOT="${PWD}/.." -S . -B build
cmake --build build
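For a machine-independent prefix path, the value hard-coded above can be replaced by the one reported by the installed torch package, which is what the commented-out invocation computes. A minimal sketch:

# Minimal sketch: print the CMake prefix path of the installed torch package,
# suitable for passing as -DCMAKE_PREFIX_PATH to the cmake invocation above.
import torch.utils
print(torch.utils.cmake_prefix_path)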
189 changes: 189 additions & 0 deletions _custom_linear/custom_linear.cpp
@@ -0,0 +1,189 @@
#include <torch/library.h>
#include <torch/script.h>
#include <ATen/native/xnnpack/Common.h>

at::native::xnnpack::Operator create_fully_connected_nc_qd8_f32_qb4w(
at::Tensor weight,
at::Tensor weight_scales) {

TORCH_CHECK(weight.dim() == 2, "weight must be 2-dimensional");
TORCH_CHECK(weight.size(1) % 2 == 0, "weight columns must be even (packed int4)");

TORCH_CHECK(weight_scales.dim() == 2, "weight_scales must be 2-dimensional");
TORCH_CHECK(weight.size(0) == weight_scales.size(0), "weight and weight_scale must have same number of rows");

const float output_min = -std::numeric_limits<float>::infinity();
const float output_max = std::numeric_limits<float>::infinity();
const uint8_t weight_zero_point = 8;

auto input_channels = 2*weight.size(1); // Multiply by 2 because weights are packed
auto output_channels = weight.size(0);


TORCH_CHECK((input_channels % weight_scales.size(1)) == 0, "number of columns in weight_scales should divide input_channels");
size_t group_size = input_channels / weight_scales.size(1);
TORCH_CHECK(group_size > 1, "inferred group_size must be > 1");
TORCH_CHECK((group_size&(group_size-1)) == 0, "inferred group_size must be a power of 2");

// Create FC
xnn_operator_t fc_op = nullptr;
auto status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
input_channels, /*size_t input_channels*/
output_channels, /*size_t output_channels*/
input_channels, /*size_t input_stride*/
output_channels, /*size_t output_stride*/
group_size, /*size_t block_size*/
weight_zero_point, /*uint8_t kernel_zero_point*/
weight_scales.const_data_ptr<float>(), /*const float* kernel_scale*/
weight.const_data_ptr(), /*const void* kernel*/
nullptr, /*const float* bias*/
output_min, /*float output_min*/
output_max, /*float output_max*/
0, /*uint32_t flags*/
nullptr, /*xnn_code_cache_t code_cache*/
nullptr, /*xnn_weights_cache_t weights_cache*/
&fc_op /*xnn_operator_t* fully_connected_op_out*/
);
TORCH_CHECK(status == xnn_status_success, "Operator xnn_create_fully_connected_nc_qd8_f32_qb4w failed with status ", status, ".");
TORCH_CHECK(fc_op != nullptr);

return at::native::xnnpack::Operator(fc_op);
}


at::native::xnnpack::Operator create_convert_nc_f32_qd8() {
xnn_operator_t convert_op = nullptr;
auto status = xnn_create_convert_nc_f32_qd8(
0, /*uint32_t flags*/
&convert_op /*xnn_operator_t* convert_op_out*/
);
TORCH_CHECK(status == xnn_status_success, "Operator xnn_create_convert_nc_f32_qd8 failed with status ", status, ".");
TORCH_CHECK(convert_op != nullptr);
return at::native::xnnpack::Operator(convert_op);
}

at::Tensor run_linear_qd8_f32_qb4w(xnn_operator_t convert_op, xnn_operator_t fc_op, int64_t output_channels, at::Tensor input) {
TORCH_CHECK(input.dim() == 2);

auto batch_size = input.size(0);
auto input_channels = input.size(1);
xnn_status status;

// Holds output of convert
std::vector<int8_t> output_convert(batch_size * input_channels + XNN_EXTRA_BYTES);
std::vector<xnn_dynamic_quantization_params> quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS);

// Run input convert
status = xnn_reshape_convert_nc_f32_qd8(
convert_op, /*xnn_operator_t convert_op*/
batch_size, /*size_t batch_size*/
input_channels, /*size_t channels*/
input_channels, /*size_t input_stride*/
input_channels, /*size_t output_stride*/
nullptr /*pthreadpool_t threadpool*/
);
TORCH_CHECK(status == xnn_status_success, "Operator xnn_reshape_convert_nc_f32_qd8 failed with status ", status, ".");

status = xnn_setup_convert_nc_f32_qd8(
convert_op, /*xnn_operator_t convert_op*/
input.const_data_ptr<float>(), /*const float* input*/
output_convert.data(), /*int8_t* output*/
quantization_params.data() /*struct xnn_dynamic_quantization_params* quantization_params*/
);
TORCH_CHECK(status == xnn_status_success, "Operator xnn_setup_convert_nc_f32_qd8 failed with status ", status, ".");

status = xnn_run_operator(convert_op, /*threadpool=*/nullptr);
TORCH_CHECK(status == xnn_status_success, "Running convert_op failed with status ", status, ".");


// Holds output of linear
auto options = torch::TensorOptions().dtype(torch::kFloat32);
auto output_tensor = torch::empty({batch_size, output_channels}, options);

// Run linear
status = xnn_reshape_fully_connected_nc_qd8_f32_qb4w(
fc_op, /*xnn_operator_t fully_connected_op*/
batch_size, /*size_t batch_size*/
nullptr /*pthreadpool_t threadpool*/ // TODO: set to something sensible
);
TORCH_CHECK(status == xnn_status_success, "Operator xnn_reshape_fully_connected_nc_qd8_f32_qb4w failed with status ", status, ".");

status = xnn_setup_fully_connected_nc_qd8_f32_qb4w(
fc_op, /*xnn_operator_t fully_connected_op*/
output_convert.data(), /*const int8_t* input*/
output_tensor.data_ptr<float>(), /*float* output*/
quantization_params.data() /*const struct xnn_dynamic_quantization_params* quantization_params*/
);
TORCH_CHECK(status == xnn_status_success, "Operator xnn_setup_fully_connected_nc_qd8_f32_qb4w failed with status ", status, ".");


status = xnn_run_operator(fc_op, /*threadpool=*/nullptr);
TORCH_CHECK(status == xnn_status_success, "Running fc_op failed with status ", status, ".");

return output_tensor;
}





class PrepackedContext : public torch::jit::CustomClassHolder {
private:
at::native::xnnpack::Operator convert_op_;
at::native::xnnpack::Operator fc_op_;
size_t output_channels_;

public:
PrepackedContext(at::native::xnnpack::Operator convert_op, at::native::xnnpack::Operator fc_op, size_t output_channels) :
convert_op_(std::move(convert_op)), fc_op_(std::move(fc_op)), output_channels_(output_channels) {}
xnn_operator_t convert_op() {
return convert_op_.get();
}

xnn_operator_t fc_op() {
return fc_op_.get();
}

size_t output_channels() {
return output_channels_;
}

};

c10::intrusive_ptr<PrepackedContext> prepack(at::Tensor weight, at::Tensor weight_scales) {
auto status = xnn_initialize(/*allocator=*/nullptr);
TORCH_CHECK(status == xnn_status_success, "xnn_initialize failed with status ", status, ".");
auto convert_op = create_convert_nc_f32_qd8();
auto fc_op = create_fully_connected_nc_qd8_f32_qb4w(weight, weight_scales);
auto output_channels = weight.size(0);

return c10::make_intrusive<PrepackedContext>(
    std::move(convert_op),
    std::move(fc_op),
    output_channels
);
}


at::Tensor run(
c10::intrusive_ptr<PrepackedContext> prepacked_context,
at::Tensor input) {
return run_linear_qd8_f32_qb4w(prepacked_context->convert_op(), prepacked_context->fc_op(), prepacked_context->output_channels(), input);
}

at::Tensor prepack_and_run(
at::Tensor weight,
at::Tensor weight_scales,
at::Tensor input) {

auto prepacked_context = prepack(weight, weight_scales);
return run(prepacked_context, input);
}


TORCH_LIBRARY(torchchat, m) {
m.class_<PrepackedContext>("PrepackedContext");
m.def("prepack", prepack);
m.def("run", run);
m.def("prepack_and_run", prepack_and_run);
}
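For reference, a hypothetical end-to-end sketch (not part of the diff) of driving the registered ops from Python. It assumes the dylib builds against a PyTorch checkout whose bundled XNNPACK exposes the qb4w operators, and that the quantize.py helpers below are importable as _custom_linear.quantize.

# Hypothetical sketch: quantize a float weight, pack it, and call the custom op directly.
import torch
torch.ops.load_library("_custom_linear/build/libcustom_linear.dylib")
from _custom_linear.quantize import group_quantize_tensor_symmetric, convert_to_qc4w

out_features, in_features, group_size = 32, 64, 32
w = torch.randn(out_features, in_features, dtype=torch.float32)
w_int, scales, _ = group_quantize_tensor_symmetric(w, group_size, torch.float32)
w_packed = convert_to_qc4w(w_int)        # (out_features, in_features // 2), uint8

x = torch.randn(4, in_features, dtype=torch.float32)   # 2-d activations for the qd8 path
y = torch.ops.torchchat.prepack_and_run(w_packed, scales, x)
print(y.shape)                           # torch.Size([4, 32])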
55 changes: 55 additions & 0 deletions _custom_linear/main.cpp
@@ -0,0 +1,55 @@
#include <torch/library.h>
#include <torch/script.h>
#include <ATen/native/xnnpack/Common.h>

int main() {
// Author's note from the PR discussion: this program segfaults.


xnn_status status;
// status = xnn_initialize(/*allocator=*/nullptr);
// TORCH_CHECK(status == xnn_status_success);

auto w_col = 384;
auto input_channels = w_col*2;
auto output_channels = 32000;
auto group_size = 32;
auto n_groups = 24;
TORCH_CHECK(n_groups * group_size == input_channels);

// auto options = torch::TensorOptions().dtype(torch::kByte);
// auto weight = torch::ones({output_channels, w_col}, options);
// auto weight_scales = torch::ones({output_channels, n_groups});

auto weight_data = std::vector<uint8_t>();
for (int i = 0; i < output_channels * w_col; ++i) {
weight_data.push_back(1);
}

auto weight_scales = std::vector<float>();
for (int i = 0; i < output_channels * n_groups; ++i) {
weight_scales.push_back(1.0f);
}

const float output_min = -std::numeric_limits<float>::infinity();
const float output_max = std::numeric_limits<float>::infinity();
const uint8_t weight_zero_point = 8;

xnn_operator_t fc_op = nullptr;
status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
input_channels, /*size_t input_channels*/
output_channels, /*size_t output_channels*/
input_channels, /*size_t input_stride*/
output_channels, /*size_t output_stride*/
group_size, /*size_t block_size*/
weight_zero_point, /*uint8_t kernel_zero_point*/
weight_scales.data(), /*const float* kernel_scale*/
(void*)weight_data.data(), /*const void* kernel*/
nullptr, /*const float* bias*/
output_min, /*float output_min*/
output_max, /*float output_max*/
0, /*uint32_t flags*/
nullptr, /*xnn_code_cache_t code_cache*/
nullptr, /*xnn_weights_cache_t weights_cache*/
&fc_op /*xnn_operator_t* fully_connected_op_out*/
);

}
92 changes: 92 additions & 0 deletions _custom_linear/quantize.py
@@ -0,0 +1,92 @@
# https://www.internalfb.com/code/fbsource/[f1458254b3caba86fb497abbfe15c74c4e8ca38d]/fbcode/executorch/backends/xnnpack/test/ops/linear.py?lines=348

import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig

# Note: not using torchao.quantization.quant_primitives because it runs into op registration issues
def get_group_qparams_symmetric(w, n_bit, groupsize, precision):
    # needed for GPTQ with padding
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    max_val_abs = torch.max(-min_val_neg, max_val_pos)
    max_int = 2 ** (n_bit - 1) - 1
    min_int = -(2 ** (n_bit - 1))

    # max_int - min_int is just 2**n_bit - 1
    scales = max_val_abs / (float(max_int - min_int) / 2)  # i.e. 2 * max(abs(x)) / (int range)
    scales = torch.max(
        scales, torch.full_like(scales, torch.finfo(torch.float32).eps)
    )
    # TODO: make sure abs(scales) is not too small?
    zeros = torch.full_like(scales, 0)
    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(
        precision
    ).reshape(w.shape[0], -1)
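As a quick numeric illustration of the formula above (per-group scale = 2 * max|w| / (2**n_bit - 1), zero point 0), a hedged example not included in the PR:

# Hedged example: one row, one group of four values; the scale is 2 * 3.0 / 15 = 0.4.
import torch
from _custom_linear.quantize import get_group_qparams_symmetric

w = torch.tensor([[0.5, -1.5, 0.25, 3.0]])
scales, zeros = get_group_qparams_symmetric(w, n_bit=4, groupsize=4, precision=torch.float32)
print(scales)   # tensor([[0.4000]])
print(zeros)    # tensor([[0.]])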

# Note: not using torchao.quantization.quant_primitives because it runs into op registration issues
# Performs symmetric 4-bit groupwise quantization
def group_quantize_tensor_symmetric(w, group_size, precision):
    n_bit = 4
    scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision)
    max_int = 2 ** (n_bit - 1) - 1
    min_int = -(2 ** (n_bit - 1))
    # TODO: currently we don't know how to express torch.int4, we'll
    # add torch.int4 to core later
    w_int8 = torch.ops.quantized_decomposed.quantize_per_channel_group(
        w, scales, zeros, min_int, max_int, torch.int8, group_size
    )

    return w_int8, scales, zeros


# https://www.internalfb.com/code/fbsource/[f1458254b3caba86fb497abbfe15c74c4e8ca38d]/fbcode/executorch/backends/xnnpack/operators/node_visitor.py?lines=451
def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
    """
    Convert a channelwise-quantized int4-range tensor into XNNPACK's packed qc4w
    layout: two 4-bit values per byte, offset by a zero point of 8.
    """

    import torch.nn.functional as F

    # Assert we got a properly quantized tensor.
    min_val, max_val = inp.min().item(), inp.max().item()
    assert (
        max_val <= 7 and min_val >= -8
    ), f"convert_to_qc4w: [min, max] out of [-8, 7] range, got [{min_val}, {max_val}]"

    # Assuming we have a 2d tensor
    if inp.ndim != 2:
        inp = inp.squeeze()
    assert (
        inp.ndim == 2
    ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"

    # pad ic
    if inp.shape[-1] % 2 != 0:
        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)

    # Shape after padding
    oc, ic = inp.shape
    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"

    # Adjust inp tensor for zp
    inp = inp.to(dtype=torch.uint8) + 8

    # Pack pairs of columns into single bytes: column 2k+1 goes in the high nibble,
    # column 2k in the low nibble.
    inp = inp.contiguous().view(-1)
    return (inp[1::2] << 4 | inp[::2]).view(oc, int(ic / 2))
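A small sanity check (not part of the diff), assuming quantize.py is importable as _custom_linear.quantize: since column 2k lands in the low nibble and column 2k+1 in the high nibble of output byte k, offset by the zero point of 8, unpacking should recover the input exactly.

# Hedged round-trip check: pack int4-range values and unpack them again.
import torch
from _custom_linear.quantize import convert_to_qc4w

w_int = torch.randint(-8, 8, (4, 8), dtype=torch.int8)   # values already in [-8, 7]
packed = convert_to_qc4w(w_int)                          # shape (4, 4), dtype uint8

low = (packed & 0x0F).to(torch.int16) - 8    # even input columns
high = (packed >> 4).to(torch.int16) - 8     # odd input columns
unpacked = torch.stack([low, high], dim=-1).reshape(4, 8)
assert torch.equal(unpacked, w_int.to(torch.int16))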