2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
80 changes: 47 additions & 33 deletions aiter/fused_moe.py
@@ -500,24 +500,28 @@ def FinalFunc():
logger.info(
f"[fused_moe] using {'1stage' if run_1stage else '2stage'} {'default' if cfg is None else tag} for {keys} "
)
-if dtype in [dtypes.bf16, dtypes.fp16] and q_type == QuantType.per_1x32 and activation == ActivationType.Swiglu:
+if (
+    dtype in [dtypes.bf16, dtypes.fp16]
+    and q_type == QuantType.per_1x32
+    and activation == ActivationType.Swiglu
+):
return MOEMetadata(
-functools.partial(
-    cktile_moe_stage1,
-    n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1),
-    k_pad_zeros=hidden_pad // 128 * 128,
-    bias1=bias1,
-),
-functools.partial(
-    cktile_moe_stage2,
-    n_pad_zeros=hidden_pad // 64 * 64,
-    k_pad_zeros=intermediate_pad // 128 * 128,
-    bias2=bias2,
-),
-16 if token < 2048 else 32,
-ksplit,
-False,
-)
+functools.partial(
+    cktile_moe_stage1,
+    n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1),
+    k_pad_zeros=hidden_pad // 128 * 128,
+    bias1=bias1,
+),
+functools.partial(
+    cktile_moe_stage2,
+    n_pad_zeros=hidden_pad // 64 * 64,
+    k_pad_zeros=intermediate_pad // 128 * 128,
+    bias2=bias2,
+),
+16 if token < 2048 else 32,
+ksplit,
+False,
+)
if (
"ck" in kernelName1
or q_dtype_w
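Note on the padding arguments above: `x // m * m` floors x to the nearest lower multiple of m (64 along N, 128 along K), and the N padding is doubled when use_g1u1 fuses the gate and up projections into one weight. A minimal sketch of that arithmetic (the helper name is illustrative, not part of aiter):

def floor_to_multiple(x: int, m: int) -> int:
    # x // m * m rounds x down to the nearest multiple of m
    return x // m * m

assert floor_to_multiple(300, 64) == 256
assert floor_to_multiple(300, 128) == 256
assert floor_to_multiple(128, 128) == 128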
@@ -628,10 +632,12 @@ def fused_moe_2stages(
bias1,
bias2,
)
-if quant_type == QuantType.per_1x32 \
-    and dtype in [dtypes.bf16, dtypes.fp16] \
-    and w1.dtype == torch.uint8 \
-    and activation == ActivationType.Swiglu:
+if (
+    quant_type == QuantType.per_1x32
+    and dtype in [dtypes.bf16, dtypes.fp16]
+    and w1.dtype == torch.uint8
+    and activation == ActivationType.Swiglu
+):
a1 = hidden_states.to(dtype)
a1_scale = None
elif quant_type == QuantType.per_1x32:
@@ -688,9 +694,11 @@ def fused_moe_2stages(
sorted_weights=sorted_weights if doweight_stage1 else None,
)

-if quant_type == QuantType.per_1x32 \
-    and dtype in [dtypes.bf16, dtypes.fp16] \
-    and w1.dtype == torch.uint8:
+if (
+    quant_type == QuantType.per_1x32
+    and dtype in [dtypes.bf16, dtypes.fp16]
+    and w1.dtype == torch.uint8
+):
a2_scale = None
elif quant_type == QuantType.per_1x32:
a2 = a2.view(-1, inter_dim)
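This guard and the one in the previous hunk detect the same "a16w4" configuration: mxfp4 (per-1x32) weight quantization with bf16/fp16 activations kept unquantized, so no activation scales exist (the first guard additionally requires SwiGLU). A condensed sketch of the shared predicate, assuming QuantType, dtypes, and torch as imported in fused_moe.py (the helper name is hypothetical):

def is_a16w4(quant_type, dtype, w_dtype):
    # 16-bit activations with 4-bit (uint8-packed) weights: activation scales are skipped
    return (
        quant_type == QuantType.per_1x32
        and dtype in [dtypes.bf16, dtypes.fp16]
        and w_dtype == torch.uint8
    )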
@@ -882,7 +890,8 @@ def torch_moe(

return (out * topk_weight.view(B, -1, 1)).sum(dim=1).to(dtype)

-#temp workaround for swiglu
+
+# temp workaround for swiglu
def swiglu(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0):
# Clamp the input values
x_glu = x_glu.clamp(min=None, max=limit)
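The rest of swiglu is not shown in this hunk. For reference, a self-contained sketch of this clamped SwiGLU variant, assuming the common formulation in which the gated half passes through x * sigmoid(alpha * x) and the linear half is shifted by one; the completion below is an assumption, not code from this PR:

import torch

def swiglu_sketch(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0):
    # Clamp both halves to bound the dynamic range before gating
    x_glu = x_glu.clamp(min=None, max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    # x * sigmoid(alpha * x); alpha = 1.702 makes this a close GELU approximation
    gate = x_glu * torch.sigmoid(alpha * x_glu)
    return gate * (x_linear + 1)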
@@ -914,12 +923,13 @@ def torch_moe_stage1(
E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape)
if quant_type == QuantType.per_1x32:
from aiter.utility import fp4_utils
+
w1 = fp4_utils.mxfp4_to_f32(w1)
w1_scale = fp4_utils.e8m0_to_f32(w1_scale)
-if a1_scale is not None: #skip a16w4
+if a1_scale is not None: # skip a16w4
hidden_states = fp4_utils.mxfp4_to_f32(hidden_states)
a1_scale = fp4_utils.e8m0_to_f32(a1_scale)
-else: #a16w4
+else: # a16w4
hidden_states = hidden_states.to(ctype)

else:
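mxfp4_to_f32 and e8m0_to_f32 undo the OCP MX microscaling encoding: an e8m0 byte stores a power-of-two scale 2 ** (e - 127), and each fp4 (e2m1) element takes one of sixteen values. A hedged sketch of the underlying math, not aiter's actual fp4_utils implementation:

import torch

# The 8 non-negative fp4 (e2m1) magnitudes; a set sign bit negates them.
FP4_E2M1_MAGNITUDES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def e8m0_to_f32_sketch(e: torch.Tensor) -> torch.Tensor:
    # e8m0 is an exponent-only format: value = 2 ** (e - 127)
    return torch.pow(2.0, e.to(torch.float32) - 127.0)

def fp4_to_f32_sketch(nibbles: torch.Tensor) -> torch.Tensor:
    # nibbles holds 4-bit codes 0..15; bit 3 is the sign
    sign = torch.where(nibbles >= 8, -1.0, 1.0)
    return sign * FP4_E2M1_MAGNITUDES[(nibbles & 7).long()]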
@@ -957,7 +967,9 @@ def torch_moe_stage1(
hidden_states = hidden_states.view(a1_shape[0], a1_shape[1] // 32, 32)
if a1_scale is not None:
a1_scale = a1_scale[: a1_shape[0]]
-hidden_states = hidden_states * a1_scale.view(a1_shape[0], a1_shape[1] // 32, 1)
+hidden_states = hidden_states * a1_scale.view(
+    a1_shape[0], a1_shape[1] // 32, 1
+)
hidden_states = hidden_states.view(a1_shape)
else:
assert False, f"Unsupported quant_type: {quant_type}"
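The reshape-and-broadcast above applies one scale per contiguous group of 32 elements along the hidden dimension. A small worked example of the same pattern (shapes illustrative):

import torch

tokens, K = 4, 64                    # K must be a multiple of 32
x = torch.randn(tokens, K)
scale = torch.rand(tokens, K // 32)  # one scale per 32-wide block
y = x.view(tokens, K // 32, 32) * scale.view(tokens, K // 32, 1)
y = y.view(tokens, K)                # back to the original layout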
@@ -1003,7 +1015,7 @@ def torch_moe_stage2(
quant_type=QuantType.No,
w2_scale=None, # [1]
a2_scale=None, # [expert]]'
-w2_bias=None,
+w2_bias=None,
doweight=True,
):
ctype = dtypes.fp32 # compute type
@@ -1016,7 +1028,7 @@ def torch_moe_stage2(
if a2_scale is not None:
hidden_states = fp4_utils.mxfp4_to_f32(hidden_states)
a2_scale = fp4_utils.e8m0_to_f32(a2_scale)
-else: #a16w4
+else: # a16w4
hidden_states = hidden_states.to(ctype)
else:
hidden_states = hidden_states.to(ctype)
@@ -1094,12 +1106,14 @@ def cktile_moe_stage1(
token_num = hidden_states.shape[0]
_, n1, k1 = w1.shape
_, k2, n2 = w2.shape
-D = n2 if k2 == k1 else n2*2 #bit4 format
+D = n2 if k2 == k1 else n2 * 2 # bit4 format
# max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size

if w1.dtype is torch.uint32:
D = D * 8
-out = torch.empty((token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device)
+out = torch.empty(
+    (token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device
+)
# print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0]))
aiter.moe_cktile2stages_gemm1(
hidden_states,
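The width logic at the top of this hunk accounts for sub-byte weight packing: with 4-bit elements, a uint8 holds two values and a uint32 holds eight, so the logical output width D is a multiple of the packed tensor's innermost extent. A hedged sketch of that density calculation (the helper is illustrative, not aiter's):

def packed_logical_width(n_packed: int, container_bits: int, elem_bits: int = 4) -> int:
    # each container word holds container_bits // elem_bits packed elements
    return n_packed * (container_bits // elem_bits)

assert packed_logical_width(128, 8) == 256    # uint8: two fp4 values per byte
assert packed_logical_width(128, 32) == 1024  # uint32: eight fp4 values per word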
@@ -1118,7 +1132,7 @@ def cktile_moe_stage1(
block_m,
)
return out


def cktile_moe_stage2(
a2,
153 changes: 93 additions & 60 deletions aiter/ops/moe_op.py
@@ -224,77 +224,110 @@ def ck_moe_stage2(
): ...


-@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm1")
+@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm1")
def moe_cktile2stages_gemm1_ck(
-XQ : Tensor,
-WQ : Tensor,
-Y : Tensor,
-sorted_ids : Tensor,
-sorted_expert_ids : Tensor,
-max_token_ids : Tensor,
-topk : int,
-n_padded_zeros : Optional[int] = 0,
-k_padded_zeros : Optional[int] = 0,
-topk_weight : Optional[Tensor] = None,
-x_scale : Optional[Tensor] = None,
-w_scale : Optional[Tensor] = None,
-exp_bias : Optional[Tensor] = None,
-block_m : Optional[int] = 32,
+XQ: Tensor,
+WQ: Tensor,
+Y: Tensor,
+sorted_ids: Tensor,
+sorted_expert_ids: Tensor,
+max_token_ids: Tensor,
+topk: int,
+n_padded_zeros: Optional[int] = 0,
+k_padded_zeros: Optional[int] = 0,
+topk_weight: Optional[Tensor] = None,
+x_scale: Optional[Tensor] = None,
+w_scale: Optional[Tensor] = None,
+exp_bias: Optional[Tensor] = None,
+block_m: Optional[int] = 32,
) -> Tensor: ...


def moe_cktile2stages_gemm1(
-XQ : Tensor,
-WQ : Tensor,
-Y : Tensor,
-sorted_ids : Tensor,
-sorted_expert_ids : Tensor,
-max_token_ids : Tensor,
-topk : int,
-n_padded_zeros : Optional[int] = 0,
-k_padded_zeros : Optional[int] = 0,
-topk_weight : Optional[Tensor] = None,
-x_scale : Optional[Tensor] = None,
-w_scale : Optional[Tensor] = None,
-exp_bias : Optional[Tensor] = None,
-block_m : Optional[int] = 32,
+XQ: Tensor,
+WQ: Tensor,
+Y: Tensor,
+sorted_ids: Tensor,
+sorted_expert_ids: Tensor,
+max_token_ids: Tensor,
+topk: int,
+n_padded_zeros: Optional[int] = 0,
+k_padded_zeros: Optional[int] = 0,
+topk_weight: Optional[Tensor] = None,
+x_scale: Optional[Tensor] = None,
+w_scale: Optional[Tensor] = None,
+exp_bias: Optional[Tensor] = None,
+block_m: Optional[int] = 32,
):
-return moe_cktile2stages_gemm1_ck(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias, block_m)
+return moe_cktile2stages_gemm1_ck(
+    XQ,
+    WQ,
+    Y,
+    sorted_ids,
+    sorted_expert_ids,
+    max_token_ids,
+    topk,
+    n_padded_zeros,
+    k_padded_zeros,
+    topk_weight,
+    x_scale,
+    w_scale,
+    exp_bias,
+    block_m,
+)


-@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm2")
+@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm2")
def moe_cktile2stages_gemm2_ck(
-XQ : Tensor,
-WQ : Tensor,
-Y : Tensor,
-sorted_ids : Tensor,
-sorted_expert_ids : Tensor,
-max_token_ids : Tensor,
-topk : int,
-n_padded_zeros : Optional[int] = 0,
-k_padded_zeros : Optional[int] = 0,
-topk_weight : Optional[Tensor] = None,
-x_scale : Optional[Tensor] = None,
-w_scale : Optional[Tensor] = None,
-exp_bias : Optional[Tensor] = None,
-block_m : Optional[int] = 32,
+XQ: Tensor,
+WQ: Tensor,
+Y: Tensor,
+sorted_ids: Tensor,
+sorted_expert_ids: Tensor,
+max_token_ids: Tensor,
+topk: int,
+n_padded_zeros: Optional[int] = 0,
+k_padded_zeros: Optional[int] = 0,
+topk_weight: Optional[Tensor] = None,
+x_scale: Optional[Tensor] = None,
+w_scale: Optional[Tensor] = None,
+exp_bias: Optional[Tensor] = None,
+block_m: Optional[int] = 32,
) -> Tensor: ...


def moe_cktile2stages_gemm2(
-XQ : Tensor,
-WQ : Tensor,
-Y : Tensor,
-sorted_ids : Tensor,
-sorted_expert_ids : Tensor,
-max_token_ids : Tensor,
-topk : int,
-n_padded_zeros : Optional[int] = 0,
-k_padded_zeros : Optional[int] = 0,
-topk_weight : Optional[Tensor] = None,
-x_scale : Optional[Tensor] = None,
-w_scale : Optional[Tensor] = None,
-exp_bias : Optional[Tensor] = None,
-block_m : Optional[int] = 32,
+XQ: Tensor,
+WQ: Tensor,
+Y: Tensor,
+sorted_ids: Tensor,
+sorted_expert_ids: Tensor,
+max_token_ids: Tensor,
+topk: int,
+n_padded_zeros: Optional[int] = 0,
+k_padded_zeros: Optional[int] = 0,
+topk_weight: Optional[Tensor] = None,
+x_scale: Optional[Tensor] = None,
+w_scale: Optional[Tensor] = None,
+exp_bias: Optional[Tensor] = None,
+block_m: Optional[int] = 32,
):
-return moe_cktile2stages_gemm2_ck(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias, block_m)
+return moe_cktile2stages_gemm2_ck(
+    XQ,
+    WQ,
+    Y,
+    sorted_ids,
+    sorted_expert_ids,
+    max_token_ids,
+    topk,
+    n_padded_zeros,
+    k_padded_zeros,
+    topk_weight,
+    x_scale,
+    w_scale,
+    exp_bias,
+    block_m,
+)


dtype2str_dict = {