 # LICENSE file in the root directory of this source tree.
 
 """
-Defines the UX for converting a model to use mx weights
-
-For now, this is a module swap for speed of iteration.
-
-Eventually we plan to move this to a tensor subclass weight wrapper for
-inference, and to a tensor subclass weight wrapper + module hooks for training.
+Defines the prototype UX for converting a model to use mx weights
 """
 
+from typing import Any
+
 import torch
 import torch.nn.functional as F
 
 from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
 
 
 @torch._dynamo.allow_in_graph
-class NoopFwToMXBw(torch.autograd.Function):
-    """
-    Forward: no-op
-    Backward: cast grad to MX
-    """
+class mx_mm(torch.autograd.Function):
+    # There are three gemms in a forward + backward of a Linear layer:
+    #
+    # 1. input @ weight_t = output (forward pass)
+    # 2. grad_output @ weight = grad_input (backward pass)
+    # 3. input_t @ grad_output = grad_weight (backward pass)
 
     @staticmethod
-    def forward(ctx, x, elem_dtype, block_size):
+    def forward(
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        elem_dtype: Any,
+        block_size: int,
+    ):
+        ctx.save_for_backward(input_hp, weight_hp)
         ctx.elem_dtype = elem_dtype
         ctx.block_size = block_size
-        return x
+
+        # input @ weight_t = output
+        input_orig_shape = input_hp.shape
+        input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])
+
+        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
+        weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+        output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
+        output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
+
+        return output
 
     @staticmethod
-    def backward(ctx, g):
-        scale, data = to_mx(g, ctx.elem_dtype, ctx.block_size)
-        return (
-            MXTensor(scale, data, ctx.elem_dtype, ctx.block_size, g.dtype),
-            None,
-            None,
+    def backward(ctx, grad_output_hp: torch.Tensor):
+        input_hp, weight_hp = ctx.saved_tensors
+        weight_hp_t_c = weight_hp.t().contiguous()
+        elem_dtype = ctx.elem_dtype
+        block_size = ctx.block_size
+
+        grad_output_orig_shape = grad_output_hp.shape
+        grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
+
+        input_hp_orig_shape = input_hp.shape
+        input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])
+
+        # grad_output @ weight = grad_input
+        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
+        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+        grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
+        grad_input = grad_input.reshape(
+            *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
 
+        # input_t @ grad_output = grad_weight
+        grad_output_mx_dim1 = MXTensor.to_mx(
+            grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+        )
+        input_t_mx_dim0_tmp = MXTensor.to_mx(
+            input_hp_r.t().contiguous(), elem_dtype, block_size
+        )
+        input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        # TODO: debug why fp4 here leads to incorrect shapes
+        grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
+
+        return grad_input, grad_weight, None, None
+
 
 
 class MXLinear(torch.nn.Linear):
     """
@@ -59,16 +103,26 @@ def from_float(cls, mod, elem_dtype, block_size):
         return mod
 
     def forward(self, x):
-        x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size)
-        w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size)
-        y = F.linear(x_mx, w_mx, self.bias)
-        y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
+        if torch.is_autocast_enabled():
+            # special case autocast
+            autocast_dtype = torch.get_autocast_dtype("cuda")
+            x = x.to(autocast_dtype)
+            w = self.weight.to(autocast_dtype)
+        else:
+            w = self.weight
+
+        y = mx_mm.apply(x, w, self.elem_dtype, self.block_size)
+        if self.bias is not None:
+            y = y + self.bias
         return y
 
 
 class MXInferenceLinear(torch.nn.Linear):
     """
     Inference version of MXLinear, with the weight pre-quantized to MX.
+
+    Note: this is weight-only quantization, with the gemm being executed
+    in high precision.
     """
 
     @classmethod
@@ -84,8 +138,8 @@ def from_float(cls, mod, elem_dtype, block_size):
         # TODO(future PR): set to new_mod.weight directly, will need to work
         # through some errors
         new_mod.weight_mx = MXTensor.to_mx(
-            mod.weight.t().contiguous(), elem_dtype, block_size=block_size
-        ).t()
+            mod.weight, elem_dtype, block_size=block_size
+        )
         new_mod.bias = mod.bias
         new_mod.elem_dtype = elem_dtype
         return new_mod
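
For context, a minimal usage sketch of the two conversion paths after this change. The `elem_dtype` value (`torch.float8_e4m3fn`), `block_size=32`, and the `torchao.prototype.mx_formats.mx_linear` import path are illustrative assumptions, not taken from this diff.

```python
# Hypothetical usage sketch; dtype, block_size, and import path are assumptions.
import copy

import torch

from torchao.prototype.mx_formats.mx_linear import MXInferenceLinear, MXLinear

m_hp = torch.nn.Linear(256, 512, bias=True, device="cuda", dtype=torch.bfloat16)

# Training path: mx_mm casts input, weight, and grad_output to MX, so all
# three gemms (output, grad_input, grad_weight) consume MX operands.
m_train = MXLinear.from_float(copy.deepcopy(m_hp), torch.float8_e4m3fn, 32)
x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = m_train(x)
y.sum().backward()

# Inference path: weight-only quantization, the gemm runs in high precision.
m_infer = MXInferenceLinear.from_float(copy.deepcopy(m_hp), torch.float8_e4m3fn, 32)
with torch.no_grad():
    y_infer = m_infer(x)
```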