@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
     # 1. input @ weight_t = output (forward pass)
     # 2. grad_output @ weight = grad_input (backward pass)
     # 3. input_t @ grad_output = grad_weight (backward pass)
+    #
+    # input, weight and grad_output each have their own MX element dtype.
 
     @staticmethod
     def forward(
         ctx,
         input_hp: torch.Tensor,
         weight_hp: torch.Tensor,
-        elem_dtype: Any,
+        in_elem_dtype: Any,
+        w_elem_dtype: Any,
+        grad_elem_dtype: Any,
         block_size: int,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
-        ctx.elem_dtype = elem_dtype
+        ctx.in_elem_dtype = in_elem_dtype
+        ctx.w_elem_dtype = w_elem_dtype
+        ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
         input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])
 
-        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
-        weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
+        weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
         output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
         output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
 
@@ -51,7 +57,9 @@ def forward(
     def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp, weight_hp = ctx.saved_tensors
         weight_hp_t_c = weight_hp.t().contiguous()
-        elem_dtype = ctx.elem_dtype
+        in_elem_dtype = ctx.in_elem_dtype
+        w_elem_dtype = ctx.w_elem_dtype
+        grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
 
         grad_output_orig_shape = grad_output_hp.shape
@@ -61,24 +69,24 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])
 
         # grad_output @ weight = grad_input
-        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
-        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, grad_elem_dtype, block_size)
+        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
         grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
 
         # input_t @ grad_output = grad_weight
         grad_output_mx_dim1 = MXTensor.to_mx(
-            grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+            grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
         )
         input_t_mx_dim0_tmp = MXTensor.to_mx(
-            input_hp_r.t().contiguous(), elem_dtype, block_size
+            input_hp_r.t().contiguous(), in_elem_dtype, block_size
         )
         input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
 
-        return grad_input, grad_weight, None, None
+        return grad_input, grad_weight, None, None, None, None
 
 
 class MXLinear(torch.nn.Linear):
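A minimal usage sketch of the new signature (not part of this commit): it assumes the mx_mm defined above is in scope, and the dtype choices (torch.float8_e4m3fn for input/weight, torch.float8_e5m2 for gradients) and the block size of 32 are illustrative assumptions rather than values taken from this diff.

import torch

x = torch.randn(128, 256, requires_grad=True)  # input_hp
w = torch.randn(512, 256, requires_grad=True)  # weight_hp, shaped like nn.Linear weight

# Forward: input and weight are each quantized to their own MX element dtype;
# grad_elem_dtype is stashed on ctx and only used in backward for grad_output.
y = mx_mm.apply(
    x,
    w,
    torch.float8_e4m3fn,  # in_elem_dtype (assumed choice)
    torch.float8_e4m3fn,  # w_elem_dtype (assumed choice)
    torch.float8_e5m2,    # grad_elem_dtype (assumed choice)
    32,                   # block_size
)

# Backward must return one value per forward argument after ctx: gradients for
# input and weight, then None for the three dtypes and block_size, which is why
# the return grew from four to six values.
y.sum().backward()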