 # LICENSE file in the root directory of this source tree.
 
 """
-Defines the UX for converting a model to use mx weights
-
-For now, this is a module swap for speed of iteration.
-
-Eventually we plan to move this to a tensor subclass weight wrapper for
-inference, and to a tensor subclass weight wrapper + module hooks for training.
+Defines the prototype UX for converting a model to use mx weights
 """
 
+from typing import Any
+
 import torch
 import torch.nn.functional as F
 
-from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
 
 
 @torch._dynamo.allow_in_graph
-class NoopFwToMXBw(torch.autograd.Function):
-    """
-    Forward: no-op
-    Backward: cast grad to MX
-    """
+class mx_mm(torch.autograd.Function):
+    # There are three gemms in a forward + backward of a Linear layer:
+    #
+    # 1. input @ weight_t = output (forward pass)
+    # 2. grad_output @ weight = grad_input (backward pass)
+    # 3. input_t @ grad_output = grad_weight (backward pass)
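+    #
+    # Shapes: with the input flattened to (M, K), the weight is (N, K), the
+    # output is (M, N), grad_input is (M, K) and grad_weight is (N, K),
+    # where K = in_features and N = out_features of the Linear layer.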
 
     @staticmethod
-    def forward(ctx, x, elem_dtype, block_size):
+    def forward(
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        elem_dtype: Any,
+        block_size: int,
+    ):
+        ctx.save_for_backward(input_hp, weight_hp)
         ctx.elem_dtype = elem_dtype
         ctx.block_size = block_size
-        return x
+
+        # input @ weight_t = output
+        input_orig_shape = input_hp.shape
+        input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])
+
+        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
+        weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+        output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
+        output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
+
+        return output
 
     @staticmethod
-    def backward(ctx, g):
-        scale, data = to_mx(g, ctx.elem_dtype, ctx.block_size)
-        return (
-            MXTensor(scale, data, ctx.elem_dtype, ctx.block_size, g.dtype),
-            None,
-            None,
+    def backward(ctx, grad_output_hp: torch.Tensor):
+        input_hp, weight_hp = ctx.saved_tensors
+        weight_hp_t_c = weight_hp.t().contiguous()
+        elem_dtype = ctx.elem_dtype
+        block_size = ctx.block_size
+
+        grad_output_orig_shape = grad_output_hp.shape
+        grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
+
+        input_hp_orig_shape = input_hp.shape
+        input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])
+
+        # grad_output @ weight = grad_input
+        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
+        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+        grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
+        grad_input = grad_input.reshape(
+            *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
 
+        # input_t @ grad_output = grad_weight
+        grad_output_mx_dim1 = MXTensor.to_mx(
+            grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+        )
+        input_t_mx_dim0_tmp = MXTensor.to_mx(
+            input_hp_r.t().contiguous(), elem_dtype, block_size
+        )
+        input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
+
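+        # backward returns one entry per forward argument; elem_dtype and
+        # block_size are non-tensor arguments, so they receive None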
+        return grad_input, grad_weight, None, None
+
 
 class MXLinear(torch.nn.Linear):
     """
@@ -59,16 +98,26 @@ def from_float(cls, mod, elem_dtype, block_size):
         return mod
 
     def forward(self, x):
-        x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size)
-        w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size)
-        y = F.linear(x_mx, w_mx, self.bias)
-        y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
+        if torch.is_autocast_enabled():
+            # special case autocast
+            autocast_dtype = torch.get_autocast_dtype("cuda")
+            x = x.to(autocast_dtype)
+            w = self.weight.to(autocast_dtype)
+        else:
+            w = self.weight
+
+        y = mx_mm.apply(x, w, self.elem_dtype, self.block_size)
+        if self.bias is not None:
+            y = y + self.bias
         return y
 
 
 class MXInferenceLinear(torch.nn.Linear):
     """
     Inference version of MXLinear, with the weight pre-quantized to MX.
+
+    Note: this is weight-only quantization, with the gemm being executed
+    in high precision.
     """
 
     @classmethod
@@ -84,8 +133,8 @@ def from_float(cls, mod, elem_dtype, block_size):
         # TODO(future PR): set to new_mod.weight directly, will need to work
         # through some errors
         new_mod.weight_mx = MXTensor.to_mx(
-            mod.weight.t().contiguous(), elem_dtype, block_size=block_size
-        ).t()
+            mod.weight, elem_dtype, block_size=block_size
+        )
         new_mod.bias = mod.bias
         new_mod.elem_dtype = elem_dtype
         return new_mod
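
A minimal usage sketch of the swap API above, assuming this file is importable
as torchao.prototype.mx_formats.mx_linear, that torch.float8_e4m3fn is one of
the element dtypes accepted by MXTensor.to_mx, and that a CUDA device is
available (the block_size value here is illustrative, not prescribed by this
diff):

    import torch
    from torchao.prototype.mx_formats.mx_linear import MXLinear

    # swap a high-precision Linear for an MXLinear in place
    # (elem_dtype and block_size are hypothetical example values)
    m = torch.nn.Linear(32, 32, bias=True, device="cuda")
    m_mx = MXLinear.from_float(m, elem_dtype=torch.float8_e4m3fn, block_size=32)

    x = torch.randn(4, 32, device="cuda", requires_grad=True)
    y = m_mx(x)          # forward: input @ weight_t gemm via mx_mm
    y.sum().backward()   # backward: grad_input and grad_weight gemms via mx_mm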