Skip to content

[WIP]Enable MOE int4 quant on XPU #2758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions torchao/dtypes/uintx/int4_xpu_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def _linear_bf16_act_uint4_weight_float_zero_impl(input_tensor, weight_tensor, b

# groupwise int4 quantization
groupsize = weight_tensor.block_size[1]
import pdb
pdb.set_trace()
y = torch.ops.aten._weight_int4pack_mm(
act_mat, packed_weight, groupsize, scales_and_zeros
)
Expand Down Expand Up @@ -135,7 +137,6 @@ def _linear_fp_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias

# groupwise int4 quantization
groupsize = weight_tensor.block_size[1]

y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
act_mat, packed_weight, groupsize, scale, zero
)
Expand Down Expand Up @@ -247,22 +248,41 @@ def from_plain(
_layout: Layout,
):
assert isinstance(_layout, Int4XPULayout)

if TORCH_VERSION_AT_LEAST_2_8:
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
)
packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
torch.uint8
)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8

def quant_2d(int_data_2d):
if TORCH_VERSION_AT_LEAST_2_8:
packed_weight = (int_data_2d[::, 1::2] << 4 | int_data_2d[::, ::2]).to(
torch.uint8
)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8
)
return packed_weight
else:
assert False, "INT4 not supported on XPU until 2.8"

if int_data.dim() == 3: # for moe quant
num_experts = int_data.shape[0]
packed_weight_list = []
for expert in range(num_experts):
packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0))
packed_weight = torch.cat(packed_weight_list, dim=0)
scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1)
zero_point = (
zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1)
if zero_point is not None
else None
)
else:
assert False, "INT4 not supported on XPU until 2.8"
assert int_data.dim() == 2
packed_weight = quant_2d(int_data)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = (
zero_point.reshape(int_data.shape[0], -1)
if zero_point is not None
else None
)

scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
if zero_point.dtype == scale.dtype:
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

Expand All @@ -274,8 +294,8 @@ def from_plain(
None,
False,
_layout,
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous().to(torch.int8),
scale.transpose(-2, -1).contiguous(),
zero_point.transpose(-2, -1).contiguous().to(torch.int8),
)

def to(self, *args, **kwargs):
Expand Down Expand Up @@ -317,6 +337,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

if func in [aten.select.int, aten.index.Tensor]:
assert not (func is aten.select.int and args[1] != 0), (
"aten.select.int currently only has support for dim=0"
)
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
)

if func is aten.t.default:
"""we don't need to repack the weight and just rely on external
Expand Down
Loading