[DRAFT] use xnnpack quantization in eager/aoti #698
Closed
Changes shown are from 15 of the 28 commits.

Commits (all by metascroy):
1cef4bf  Linear class with init + forward separation
5071bec  updates
2d7c574  remove nocommits
c9c1ceb  remove nocommits
eeee411  updates
3d70e5b  updates
c13d409  add new int4 op
3121d55  pass block_size
0b93db9  clean up stuff
f9b4731  require group size be power of 2
0ac35a0  require group size be power of 2
0d759f0  split create + run
cf8f910  rename functions
4a7a40c  updates
01ee9b4  add segfault example
92bc305  update main.cpp
939e83c  formatting
70d6e59  updates
c48abac  remove cpuinfo
4108f46  remove more stuff
19f0b6e  small change
85829fa  typo
1aa7bfb  update build.sh
9190974  Merge branch 'main' into linear-with-init-forward-sep
7428eb8  switch to caffe2::threadpool
2cb2107  remove dead stuff
806ae55  clean up
3dccf9f  move guard
New file (10 lines): CMakeLists.txt for building the custom_linear shared library.

@@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.17)
project(custom_linear)

set(CMAKE_CXX_STANDARD 17)

find_package(Torch REQUIRED)

add_library(custom_linear SHARED custom_linear.cpp)
target_include_directories(custom_linear PRIVATE "${TORCHCHAT_ROOT}/..")
target_link_libraries(custom_linear PRIVATE "${TORCH_LIBRARIES}")
New file (29 lines): Python module defining _CustomLinear, which quantizes nn.Linear weights and dispatches to the custom torchchat ops.

@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
from typing import Optional

torch.ops.load_library("_custom_linear/build/libcustom_linear.dylib")

from .quantize import group_quantize_tensor_symmetric, convert_to_qc4w


class _CustomLinear(nn.Module):
    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.weight = weight
        assert bias is None

        self.group_size = 32
        w_int, s, z = group_quantize_tensor_symmetric(self.weight, self.group_size, torch.float32)
        w_packed = convert_to_qc4w(w_int)
        self.prepacked = torch.ops.torchchat.prepack.default(w_packed, s)

    def forward(self, x):
        if x.dtype != torch.float32:
            raise RuntimeError(f"x has dtype {x.dtype}, expected float32")
        assert x.shape[0] == 1
        return torch.ops.torchchat.run.default(self.prepacked, x.squeeze(0)).unsqueeze(0)


def _replace_linear_with_custom_linear(module: nn.Module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, _CustomLinear(child.weight, child.bias))
        else:
            _replace_linear_with_custom_linear(child)
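For context, a minimal usage sketch of the module above in eager mode. The import path and the toy model are illustrative assumptions, not part of this PR, and it presumes the shared library has already been built so the load_library call above succeeds.

# Hedged usage sketch; the import path and model below are assumptions for illustration only.
import torch
import torch.nn as nn

from _custom_linear import _replace_linear_with_custom_linear  # assumed import location

model = nn.Sequential(
    nn.Linear(64, 128, bias=False),
    nn.ReLU(),
    nn.Linear(128, 32, bias=False),
)
model.eval()

# Recursively swap every nn.Linear for _CustomLinear (quantizes, packs, and prepacks the weights).
_replace_linear_with_custom_linear(model)

with torch.no_grad():
    x = torch.randn(1, 16, 64)  # leading dim must be 1; forward() asserts x.shape[0] == 1
    y = model(x)                # (1, 16, 32)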
New file (5 lines): build.sh, which configures and builds the extension with CMake.

@@ -0,0 +1,5 @@
rm -rf build
mkdir build
# cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" -S . -B build
cmake -DCMAKE_PREFIX_PATH="/Users/scroy/repos/pytorch/torch/share/cmake" -DTORCHCHAT_ROOT="${PWD}/.." -S . -B build
cmake --build build
New file (189 lines): custom_linear.cpp, the XNNPACK-backed implementation of the 4-bit dynamically quantized linear, registered as torchchat::prepack, torchchat::run, and torchchat::prepack_and_run.

@@ -0,0 +1,189 @@
#include <torch/library.h>
#include <torch/script.h>
#include <ATen/native/xnnpack/Common.h>

at::native::xnnpack::Operator create_fully_connected_nc_qd8_f32_qb4w(
    at::Tensor weight,
    at::Tensor weight_scales) {

  TORCH_CHECK(weight.dim() == 2, "weight must be 2-dimensional");
  TORCH_CHECK(weight.size(1) % 2 == 0, "weight columns must be even (packed int4)");

  TORCH_CHECK(weight_scales.dim() == 2, "weight_scales must be 2-dimensional");
  TORCH_CHECK(weight.size(0) == weight_scales.size(0), "weight and weight_scales must have the same number of rows");

  const float output_min = -std::numeric_limits<float>::infinity();
  const float output_max = std::numeric_limits<float>::infinity();
  const uint8_t weight_zero_point = 8;

  auto input_channels = 2 * weight.size(1); // Multiply by 2 because weights are packed
  auto output_channels = weight.size(0);

  TORCH_CHECK((input_channels % weight_scales.size(1)) == 0, "number of columns in weight_scales should divide input_channels");
  size_t group_size = input_channels / weight_scales.size(1);
  TORCH_CHECK(group_size > 1, "inferred group_size must be > 1");
  TORCH_CHECK((group_size & (group_size - 1)) == 0, "inferred group_size must be a power of 2");

  // Create FC
  xnn_operator_t fc_op = nullptr;
  auto status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
      input_channels, /*size_t input_channels*/
      output_channels, /*size_t output_channels*/
      input_channels, /*size_t input_stride*/
      output_channels, /*size_t output_stride*/
      group_size, /*size_t block_size*/
      weight_zero_point, /*uint8_t kernel_zero_point*/
      weight_scales.const_data_ptr<float>(), /*const float* kernel_scale*/
      weight.const_data_ptr(), /*const void* kernel*/
      nullptr, /*const float* bias*/
      output_min, /*float output_min*/
      output_max, /*float output_max*/
      0, /*uint32_t flags*/
      nullptr, /*xnn_code_cache_t code_cache*/
      nullptr, /*xnn_weights_cache_t weights_cache*/
      &fc_op /*xnn_operator_t* fully_connected_op_out*/
  );
  TORCH_CHECK(status == xnn_status_success, "Operator xnn_create_fully_connected_nc_qd8_f32_qb4w failed with status ", status, ".");
  TORCH_CHECK(fc_op != nullptr);

  return at::native::xnnpack::Operator(fc_op);
}

at::native::xnnpack::Operator create_convert_nc_f32_qd8() {
  xnn_operator_t convert_op = nullptr;
  auto status = xnn_create_convert_nc_f32_qd8(
      0, /*uint32_t flags*/
      &convert_op /*xnn_operator_t* convert_op_out*/
  );
  TORCH_CHECK(status == xnn_status_success, "Operator xnn_create_convert_nc_f32_qd8 failed with status ", status, ".");
  TORCH_CHECK(convert_op != nullptr);
  return at::native::xnnpack::Operator(convert_op);
}

at::Tensor run_linear_qd8_f32_qb4w(xnn_operator_t convert_op, xnn_operator_t fc_op, int64_t output_channels, at::Tensor input) {
  TORCH_CHECK(input.dim() == 2);

  auto batch_size = input.size(0);
  auto input_channels = input.size(1);
  xnn_status status;

  // Holds output of convert
  std::vector<int8_t> output_convert(batch_size * input_channels + XNN_EXTRA_BYTES);
  std::vector<xnn_dynamic_quantization_params> quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS);

  // Run input convert
  status = xnn_reshape_convert_nc_f32_qd8(
      convert_op, /*xnn_operator_t convert_op*/
      batch_size, /*size_t batch_size*/
      input_channels, /*size_t channels*/
      input_channels, /*size_t input_stride*/
      input_channels, /*size_t output_stride*/
      nullptr /*pthreadpool_t threadpool*/
  );
  TORCH_CHECK(status == xnn_status_success, "Operator xnn_reshape_convert_nc_f32_qd8 failed with status ", status, ".");

  status = xnn_setup_convert_nc_f32_qd8(
      convert_op, /*xnn_operator_t convert_op*/
      input.const_data_ptr<float>(), /*const float* input*/
      output_convert.data(), /*int8_t* output*/
      quantization_params.data() /*struct xnn_dynamic_quantization_params* quantization_params*/
  );
  TORCH_CHECK(status == xnn_status_success, "Operator xnn_setup_convert_nc_f32_qd8 failed with status ", status, ".");

  status = xnn_run_operator(convert_op, /*threadpool=*/nullptr);
  TORCH_CHECK(status == xnn_status_success, "Running convert_op failed with status ", status, ".");

  // Holds output of linear
  auto options = torch::TensorOptions().dtype(torch::kFloat32);
  auto output_tensor = torch::empty({batch_size, output_channels}, options);

  // Run linear
  status = xnn_reshape_fully_connected_nc_qd8_f32_qb4w(
      fc_op, /*xnn_operator_t fully_connected_op*/
      batch_size, /*size_t batch_size*/
      nullptr /*pthreadpool_t threadpool*/ // TODO: set to something sensible
  );
  TORCH_CHECK(status == xnn_status_success, "Operator xnn_reshape_fully_connected_nc_qd8_f32_qb4w failed with status ", status, ".");

  status = xnn_setup_fully_connected_nc_qd8_f32_qb4w(
      fc_op, /*xnn_operator_t fully_connected_op*/
      output_convert.data(), /*const int8_t* input*/
      output_tensor.data_ptr<float>(), /*float* output*/
      quantization_params.data() /*const struct xnn_dynamic_quantization_params* quantization_params*/
  );
  TORCH_CHECK(status == xnn_status_success, "Operator xnn_setup_fully_connected_nc_qd8_f32_qb4w failed with status ", status, ".");

  status = xnn_run_operator(fc_op, /*threadpool=*/nullptr);
  TORCH_CHECK(status == xnn_status_success, "Running fc_op failed with status ", status, ".");

  return output_tensor;
}

class PrepackedContext : public torch::jit::CustomClassHolder {
 private:
  at::native::xnnpack::Operator convert_op_;
  at::native::xnnpack::Operator fc_op_;
  size_t output_channels_;

 public:
  PrepackedContext(at::native::xnnpack::Operator convert_op, at::native::xnnpack::Operator fc_op, size_t output_channels)
      : convert_op_(std::move(convert_op)), fc_op_(std::move(fc_op)), output_channels_(output_channels) {}

  xnn_operator_t convert_op() {
    return convert_op_.get();
  }

  xnn_operator_t fc_op() {
    return fc_op_.get();
  }

  size_t output_channels() {
    return output_channels_;
  }
};

c10::intrusive_ptr<PrepackedContext> prepack(at::Tensor weight, at::Tensor weight_scales) {
  auto status = xnn_initialize(/*allocator=*/nullptr);
  TORCH_CHECK(status == xnn_status_success);
  auto convert_op = create_convert_nc_f32_qd8();
  auto fc_op = create_fully_connected_nc_qd8_f32_qb4w(weight, weight_scales);
  auto output_channels = weight.size(0);

  return c10::make_intrusive<PrepackedContext>(
      at::native::xnnpack::Operator(std::move(convert_op)),
      at::native::xnnpack::Operator(std::move(fc_op)),
      output_channels
  );
}

at::Tensor run(
    c10::intrusive_ptr<PrepackedContext> prepacked_context,
    at::Tensor input) {
  return run_linear_qd8_f32_qb4w(prepacked_context->convert_op(), prepacked_context->fc_op(), prepacked_context->output_channels(), input);
}

at::Tensor prepack_and_run(
    at::Tensor weight,
    at::Tensor weight_scales,
    at::Tensor input) {
  auto prepacked_context = prepack(weight, weight_scales);
  return run(prepacked_context, input);
}

TORCH_LIBRARY(torchchat, m) {
  m.class_<PrepackedContext>("PrepackedContext");
  m.def("prepack", prepack);
  m.def("run", run);
  m.def("prepack_and_run", prepack_and_run);
}
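To make the registration above concrete, here is a small hedged sketch of driving the ops directly from Python. The shapes and library path are illustrative assumptions; the packed weight layout follows convert_to_qc4w (two int4 values per byte, zero point 8).

# Hedged sketch; assumes the library was built as in build.sh, shapes are illustrative only.
import torch

torch.ops.load_library("_custom_linear/build/libcustom_linear.dylib")

oc, ic, group_size = 8, 64, 32
w_packed = torch.randint(0, 256, (oc, ic // 2), dtype=torch.uint8)  # two int4 values per byte
scales = torch.rand(oc, ic // group_size, dtype=torch.float32)      # one scale per group

ctx = torch.ops.torchchat.prepack.default(w_packed, scales)  # creates convert_op + fc_op
x = torch.randn(4, ic, dtype=torch.float32)                  # run() expects a 2D float32 input
y = torch.ops.torchchat.run.default(ctx, x)                  # shape (4, oc)

# One-shot variant that prepacks on every call:
y2 = torch.ops.torchchat.prepack_and_run.default(w_packed, scales, x)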
New file (55 lines): main.cpp, a standalone repro that calls xnn_create_fully_connected_nc_qd8_f32_qb4w directly.

@@ -0,0 +1,55 @@
#include <torch/library.h>
#include <torch/script.h>
#include <ATen/native/xnnpack/Common.h>

int main() {
  // PR review comment on this line: "This segfaults."

  xnn_status status;
  // status = xnn_initialize(/*allocator=*/nullptr);
  // TORCH_CHECK(status == xnn_status_success);

  auto w_col = 384;
  auto input_channels = w_col * 2;
  auto output_channels = 32000;
  auto group_size = 32;
  auto n_groups = 24;
  TORCH_CHECK(n_groups * group_size == input_channels);

  // auto options = torch::TensorOptions().dtype(torch::kByte);
  // auto weight = torch::ones({output_channels, w_col}, options);
  // auto weight_scales = torch::ones({output_channels, n_groups});

  auto weight_data = std::vector<uint8_t>();
  for (int i = 0; i < output_channels * w_col; ++i) {
    weight_data.push_back(1);
  }

  auto weight_scales = std::vector<float>();
  for (int i = 0; i < output_channels * n_groups; ++i) {
    weight_data.push_back(1.0); // Note: this pushes into weight_data rather than weight_scales, so weight_scales stays empty.
  }

  const float output_min = -std::numeric_limits<float>::infinity();
  const float output_max = std::numeric_limits<float>::infinity();
  const uint8_t weight_zero_point = 8;

  xnn_operator_t fc_op = nullptr;
  status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
      input_channels, /*size_t input_channels*/
      output_channels, /*size_t output_channels*/
      input_channels, /*size_t input_stride*/
      output_channels, /*size_t output_stride*/
      group_size, /*size_t block_size*/
      weight_zero_point, /*uint8_t kernel_zero_point*/
      weight_scales.data(), /*const float* kernel_scale*/
      (void*)weight_data.data(), /*const void* kernel*/
      nullptr, /*const float* bias*/
      output_min, /*float output_min*/
      output_max, /*float output_max*/
      0, /*uint32_t flags*/
      nullptr, /*xnn_code_cache_t code_cache*/
      nullptr, /*xnn_weights_cache_t weights_cache*/
      &fc_op /*xnn_operator_t* fully_connected_op_out*/
  );

}
New file (92 lines): quantize.py, group-wise symmetric int4 quantization and qc4w packing helpers.

@@ -0,0 +1,92 @@
# https://www.internalfb.com/code/fbsource/[f1458254b3caba86fb497abbfe15c74c4e8ca38d]/fbcode/executorch/backends/xnnpack/test/ops/linear.py?lines=348

import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig


# Note: not using torchao.quantization.quant_primitives because it runs into op registration issues
def get_group_qparams_symmetric(w, n_bit, groupsize, precision):
    # needed for GPTQ with padding
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    max_val_abs = torch.max(-min_val_neg, max_val_pos)
    max_int = 2 ** (n_bit - 1) - 1
    min_int = -(2 ** (n_bit - 1))

    # max_int - min_int is just 2**n_bit - 1

    scales = max_val_abs / (float(max_int - min_int) / 2)  # This is just 2 * max(abs(x)) / (int range)
    scales = torch.max(
        scales, torch.full_like(scales, torch.finfo(torch.float32).eps)
    )
    # TODO: make sure abs(scales) is not too small?
    zeros = torch.full_like(scales, 0)
    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(
        precision
    ).reshape(w.shape[0], -1)


# Note: not using torchao.quantization.quant_primitives because it runs into op registration issues
# Does 4-bit quantization
def group_quantize_tensor_symmetric(w, group_size, precision):
    n_bit = 4
    scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision)
    max_int = 2 ** (n_bit - 1) - 1
    min_int = -(2 ** (n_bit - 1))
    # TODO: currently we don't know how to express torch.int4, we'll
    # add torch.int4 to core later
    w_int8 = torch.ops.quantized_decomposed.quantize_per_channel_group(
        w, scales, zeros, min_int, max_int, torch.int8, group_size
    )

    return w_int8, scales, zeros


# https://www.internalfb.com/code/fbsource/[f1458254b3caba86fb497abbfe15c74c4e8ca38d]/fbcode/executorch/backends/xnnpack/operators/node_visitor.py?lines=451
def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
    """
    Convert a tensor of 4-bit values (stored as int8) to a packed qc4w tensor.
    """

    import torch.nn.functional as F

    # Assert we got a properly quantized tensor.
    min, max = inp.min().item(), inp.max().item()
    assert (
        max <= 7 and min >= -8
    ), f"convert_to_qc4w: [min,max] out of [-8, 7] range, got [{min}, {max}]"

    # Assuming we have a 2d tensor
    if inp.ndim != 2:
        inp = inp.squeeze()
    assert (
        inp.ndim == 2
    ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"

    # pad ic
    if inp.shape[-1] % 2 != 0:
        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)

    # Shape after padding
    oc, ic = inp.shape
    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"

    # Adjust inp tensor for zp
    inp = inp.to(dtype=torch.uint8) + 8

    # Prepare the result tensor
    inp = inp.contiguous().view(-1)
    return (inp[1::2] << 4 | inp[::2]).view(oc, int(ic / 2))
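For intuition, a small hedged round-trip sketch of what these helpers produce. Values and shapes are illustrative; it assumes it runs in the same module so the helpers above are in scope and the quantized_decomposed ops are registered (the torch.ao imports at the top normally take care of that).

# Illustrative sketch only.
import torch

w = torch.randn(4, 64)  # (out_channels, in_channels)
group_size = 32

w_int8, scales, zeros = group_quantize_tensor_symmetric(w, group_size, torch.float32)
# w_int8: (4, 64) int8 values in [-8, 7]; scales and zeros: (4, 64 // group_size) == (4, 2)

w_packed = convert_to_qc4w(w_int8)  # (4, 32) uint8; two int4 values per byte, zero point 8

# Dequantizing the first group of each row approximately recovers w (zeros are all zero here):
w_hat = w_int8[:, :group_size].float() * scales[:, :1]
print((w[:, :group_size] - w_hat).abs().max())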