
Commit 6358a71

Support mixed MX element dtype in mx_mm function.
Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients.
1 parent 8afd10e
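To make the change concrete at the call site, here is a minimal usage sketch. The dtype values (torch.float8_e4m3fn for activations and weights, torch.float8_e5m2 for gradients) and the block size of 32 are illustrative assumptions, not values fixed by this commit.

import torch
from torchao.prototype.mx_formats.mx_linear import mx_mm

# Hypothetical high-precision inputs; shapes chosen so block_size divides the inner dim.
input_hp = torch.randn(16, 64, requires_grad=True)
weight_hp = torch.randn(32, 64, requires_grad=True)

# Before this commit: out = mx_mm.apply(input_hp, weight_hp, elem_dtype, block_size)
# After this commit, each tensor role carries its own MX element dtype:
out = mx_mm.apply(
    input_hp,
    weight_hp,
    torch.float8_e4m3fn,  # in_elem_dtype (assumed supported value)
    torch.float8_e4m3fn,  # w_elem_dtype (assumed supported value)
    torch.float8_e5m2,    # grad_elem_dtype (assumed supported value)
    32,                   # block_size (assumed value)
)
out.sum().backward()      # exercises the mixed-dtype backward shown in the diff below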

1 file changed: +18 -10 lines changed

torchao/prototype/mx_formats/mx_linear.py  (+18 -10)
@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
     # 1. input @ weight_t = output (forward pass)
     # 2. grad_output @ weight = grad_input (backward pass)
     # 3. input_t @ grad_output = grad_weight (backward pass)
+    #
+    # input, weight and grad_output each have their own MX element dtype.
 
     @staticmethod
     def forward(
         ctx,
         input_hp: torch.Tensor,
         weight_hp: torch.Tensor,
-        elem_dtype: Any,
+        in_elem_dtype: Any,
+        w_elem_dtype: Any,
+        grad_elem_dtype: Any,
         block_size: int,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
-        ctx.elem_dtype = elem_dtype
+        ctx.in_elem_dtype = in_elem_dtype
+        ctx.w_elem_dtype = w_elem_dtype
+        ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
         input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])
 
-        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
-        weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
+        weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
         output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
         output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
 
@@ -51,7 +57,9 @@ def forward(
     def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp, weight_hp = ctx.saved_tensors
         weight_hp_t_c = weight_hp.t().contiguous()
-        elem_dtype = ctx.elem_dtype
+        in_elem_dtype = ctx.in_elem_dtype
+        w_elem_dtype = ctx.w_elem_dtype
+        grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
 
         grad_output_orig_shape = grad_output_hp.shape
@@ -61,24 +69,24 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])
 
         # grad_output @ weight = grad_input
-        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
-        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, grad_elem_dtype, block_size)
+        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
         grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
 
         # input_t @ grad_output = grad_weight
         grad_output_mx_dim1 = MXTensor.to_mx(
-            grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+            grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
         )
         input_t_mx_dim0_tmp = MXTensor.to_mx(
-            input_hp_r.t().contiguous(), elem_dtype, block_size
+            input_hp_r.t().contiguous(), in_elem_dtype, block_size
         )
         input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
 
-        return grad_input, grad_weight, None, None
+        return grad_input, grad_weight, None, None, None, None
 
 
 class MXLinear(torch.nn.Linear):
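Note that backward now returns six values rather than four: autograd expects one gradient slot per forward argument, so the two real gradients are followed by None for each of the three element dtypes and for block_size. The MXLinear class that starts at the end of this hunk presumably threads the three dtypes through to mx_mm.apply; below is a minimal sketch of what such a forward might look like, assuming the module stores them as attributes (the attribute names are hypothetical and not taken from this diff).

import torch

class MXLinear(torch.nn.Linear):
    def forward(self, x):
        # in_elem_dtype / w_elem_dtype / grad_elem_dtype / block_size are
        # assumed attributes set elsewhere; the real names are not visible here.
        y = mx_mm.apply(
            x,
            self.weight,
            self.in_elem_dtype,
            self.w_elem_dtype,
            self.grad_elem_dtype,
            self.block_size,
        )
        if self.bias is not None:
            y = y + self.bias
        return y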
