@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
     # 1. input @ weight_t = output (forward pass)
     # 2. grad_output @ weight = grad_input (backward pass)
     # 3. input_t @ grad_output = grad_weight (backward pass)
+    #
+    # input, weight and grad_output have each their own MX element dtype.

     @staticmethod
     def forward(
         ctx,
         input_hp: torch.Tensor,
         weight_hp: torch.Tensor,
-        elem_dtype: Any,
+        in_elem_dtype: Any,
+        w_elem_dtype: Any,
+        grad_elem_dtype: Any,
         block_size: int,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
-        ctx.elem_dtype = elem_dtype
+        ctx.in_elem_dtype = in_elem_dtype
+        ctx.w_elem_dtype = w_elem_dtype
+        ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size

         # input @ weight_t = output
         input_orig_shape = input_hp.shape
         input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])

-        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
-        weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
+        weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
         output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
         output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
@@ -51,7 +57,9 @@ def forward(
     def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp, weight_hp = ctx.saved_tensors
         weight_hp_t_c = weight_hp.t().contiguous()
-        elem_dtype = ctx.elem_dtype
+        in_elem_dtype = ctx.in_elem_dtype
+        w_elem_dtype = ctx.w_elem_dtype
+        grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size

         grad_output_orig_shape = grad_output_hp.shape
@@ -61,24 +69,24 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])

         # grad_output @ weight = grad_input
-        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
-        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, grad_elem_dtype, block_size)
+        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
         grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )

         # input_t @ grad_output = grad_weight
         grad_output_mx_dim1 = MXTensor.to_mx(
-            grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+            grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
         )
         input_t_mx_dim0_tmp = MXTensor.to_mx(
-            input_hp_r.t().contiguous(), elem_dtype, block_size
+            input_hp_r.t().contiguous(), in_elem_dtype, block_size
         )
         input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)

-        return grad_input, grad_weight, None, None
+        return grad_input, grad_weight, None, None, None, None


 class MXLinear(torch.nn.Linear):
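For context, a minimal usage sketch of the updated call convention (not part of the diff): after this change the caller supplies a separate MX element dtype for the input, the weight, and the gradient. The shapes, block size, and torch.float8_* dtypes below are illustrative assumptions; the element dtypes actually accepted depend on the surrounding MXTensor implementation.

import torch

M, K, N = 128, 256, 64
block_size = 32

input_hp = torch.randn(M, K, requires_grad=True)
weight_hp = torch.randn(N, K, requires_grad=True)

# One element dtype per operand: input, weight and grad_output are quantized
# independently, e.g. e4m3 for activations/weights and e5m2 for gradients.
out = mx_mm.apply(
    input_hp,
    weight_hp,
    torch.float8_e4m3fn,  # in_elem_dtype (illustrative)
    torch.float8_e4m3fn,  # w_elem_dtype (illustrative)
    torch.float8_e5m2,    # grad_elem_dtype (illustrative)
    block_size,
)

# backward() now returns six values: gradients for input and weight, plus
# None for each of the four non-tensor arguments.
out.sum().backward()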