# LICENSE file in the root directory of this source tree.

"""
- Defines the UX for converting a model to use mx weights
-
- For now, this is a module swap for speed of iteration.
-
- Eventually we plan to move this to a tensor subclass weight wrapper for
- inference, and to a tensor subclass weight wrapper + module hooks for training.
+ Defines the prototype UX for converting a model to use mx weights
"""

+ from typing import Any
+
import torch
import torch.nn.functional as F

- from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
+ from torchao.prototype.mx_formats.mx_tensor import MXTensor


@torch._dynamo.allow_in_graph
- class NoopFwToMXBw(torch.autograd.Function):
-     """
-     Forward: no-op
-     Backward: cast grad to MX
-     """
+ class mx_mm(torch.autograd.Function):
+     # There are three gemms in a forward + backward of a Linear layer:
+     #
+     # 1. input @ weight_t = output (forward pass)
+     # 2. grad_output @ weight = grad_input (backward pass)
+     # 3. input_t @ grad_output = grad_weight (backward pass)
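+     #
+     # In this prototype both operands of each gemm are cast to MX, with the
+     # scale blocks laid out along that gemm's contraction dimension (this is
+     # what the transpose/contiguous calls in backward() arrange for).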
    @staticmethod
-     def forward(ctx, x, elem_dtype, block_size):
+     def forward(
+         ctx,
+         input_hp: torch.Tensor,
+         weight_hp: torch.Tensor,
+         elem_dtype: Any,
+         block_size: int,
+     ):
+         ctx.save_for_backward(input_hp, weight_hp)
        ctx.elem_dtype = elem_dtype
        ctx.block_size = block_size
-         return x
+
+         # input @ weight_t = output
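+         # (leading input dims are collapsed into one so that torch.mm sees
+         # 2-D operands; the original shape is restored after the matmul)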
+         input_orig_shape = input_hp.shape
+         input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])
+
+         input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
+         weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+         output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
+         output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
+
+         return output

    @staticmethod
-     def backward(ctx, g):
-         scale, data = to_mx(g, ctx.elem_dtype, ctx.block_size)
-         return (
-             MXTensor(scale, data, ctx.elem_dtype, ctx.block_size, g.dtype),
-             None,
-             None,
+     def backward(ctx, grad_output_hp: torch.Tensor):
+         input_hp, weight_hp = ctx.saved_tensors
+         weight_hp_t_c = weight_hp.t().contiguous()
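+         # (a contiguous copy of weight.t() is made so that the scale blocks of
+         # weight_mx_dim1 below land along out_features, the contraction dim of
+         # the grad_input gemm)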
+         elem_dtype = ctx.elem_dtype
+         block_size = ctx.block_size
+
+         grad_output_orig_shape = grad_output_hp.shape
+         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
+
+         input_hp_orig_shape = input_hp.shape
+         input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])
+
+         # grad_output @ weight = grad_input
+         grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
+         weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+         grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
+         grad_input = grad_input.reshape(
+             *grad_output_orig_shape[:-1], grad_input.shape[-1]
        )

+         # input_t @ grad_output = grad_weight
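+         # (both operands are cast to MX along the flattened batch dim, the
+         # contraction dim of this gemm; the input is quantized as a contiguous
+         # transpose and then transposed back so its scales sit along dim0)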
+         grad_output_mx_dim1 = MXTensor.to_mx(
+             grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+         )
+         input_t_mx_dim0_tmp = MXTensor.to_mx(
+             input_hp_r.t().contiguous(), elem_dtype, block_size
+         )
+         input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+         grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
+
+         return grad_input, grad_weight, None, None
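+         # (the two trailing Nones are the gradients for the non-tensor forward
+         # args elem_dtype and block_size)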
+

class MXLinear(torch.nn.Linear):
    """
@@ -59,16 +98,26 @@ def from_float(cls, mod, elem_dtype, block_size):
        return mod

    def forward(self, x):
-         x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size)
-         w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size)
-         y = F.linear(x_mx, w_mx, self.bias)
-         y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
+         if torch.is_autocast_enabled():
+             # special case autocast
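+             # (autocast does not automatically cast the inputs of the custom
+             # mx_mm autograd function, so the activation and weight are moved
+             # to the autocast dtype by hand before quantizing to MX)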
+             autocast_dtype = torch.get_autocast_dtype("cuda")
+             x = x.to(autocast_dtype)
+             w = self.weight.to(autocast_dtype)
+         else:
+             w = self.weight
+
+         y = mx_mm.apply(x, w, self.elem_dtype, self.block_size)
+         if self.bias is not None:
+             y = y + self.bias
        return y


class MXInferenceLinear(torch.nn.Linear):
    """
    Inference version of MXLinear, with the weight pre-quantized to MX.
+
+     Note: this is weight-only quantization, with the gemm being executed
+     in high precision.
    """

    @classmethod
@@ -84,8 +133,8 @@ def from_float(cls, mod, elem_dtype, block_size):
        # TODO(future PR): set to new_mod.weight directly, will need to work
        # through some errors
        new_mod.weight_mx = MXTensor.to_mx(
-             mod.weight.t().contiguous(), elem_dtype, block_size=block_size
-         ).t()
+             mod.weight, elem_dtype, block_size=block_size
+         )
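+         # weight-only quantization: the MX weight is dequantized back to high
+         # precision when the gemm runs (see the class docstring note above)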
        new_mod.bias = mod.bias
        new_mod.elem_dtype = elem_dtype
        return new_mod