 # LICENSE file in the root directory of this source tree.
 
 """
-Defines the UX for converting a model to use mx weights
-
-For now, this is a module swap for speed of iteration.
-
-Eventually we plan to move this to a tensor subclass weight wrapper for
-inference, and to a tensor subclass weight wrapper + module hooks for training.
+Defines the prototype UX for converting a model to use mx weights
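+
+Example (a sketch; this assumes torch.float8_e4m3fn is one of the supported
+element dtypes, and uses the MX block size of 32):
+
+    m = torch.nn.Linear(32, 64)
+    m_mx = MXLinear.from_float(m, torch.float8_e4m3fn, block_size=32)
+    y = m_mx(torch.randn(16, 32))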
149""" 
1510
11+ from  typing  import  Any 
12+ 
1613import  torch 
1714import  torch .nn .functional  as  F 
1815
1916from  torchao .prototype .mx_formats .mx_tensor  import  MXTensor , to_mx 
2017
2118
2219@torch ._dynamo .allow_in_graph  
-class NoopFwToMXBw(torch.autograd.Function):
-    """
-    Forward: no-op
-    Backward: cast grad to MX
-    """
+class mx_mm(torch.autograd.Function):
+    # There are three gemms in a forward + backward of a Linear layer:
+    #
+    # 1.       input @ weight_t    = output      (forward pass)
+    # 2. grad_output @ weight      = grad_input  (backward pass)
+    # 3.     input_t @ grad_output = grad_weight (backward pass)
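+    #
+    # All three gemms are done in mx. Note: MXTensor.to_mx places the
+    # scaling blocks along the last dim of its input, so some operands
+    # below are transposed (and made contiguous) before quantization to
+    # put the blocks on the contraction dim of the corresponding gemm.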
 
     @staticmethod
-    def forward(ctx, x, elem_dtype, block_size):
+    def forward(
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        elem_dtype: Any,
+        block_size: int,
+    ):
+        ctx.save_for_backward(input_hp, weight_hp)
         ctx.elem_dtype = elem_dtype
         ctx.block_size = block_size
-        return x
+
+        # input @ weight_t = output
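+        # collapse leading dims so that mm sees a 2d tensor; the original
+        # leading dims are restored on the output below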
+        input_orig_shape = input_hp.shape
+        input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])
+
+        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
+        weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+        output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
+        output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
+
+        return output
 
     @staticmethod
-    def backward(ctx, g):
-        scale, data = to_mx(g, ctx.elem_dtype, ctx.block_size)
-        return (
-            MXTensor(scale, data, ctx.elem_dtype, ctx.block_size, g.dtype),
-            None,
-            None,
+    def backward(ctx, grad_output_hp: torch.Tensor):
+        input_hp, weight_hp = ctx.saved_tensors
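+        # gemm 2 contracts over out_features, so the weight is quantized
+        # from a contiguous transposed copy to place the scaling blocks
+        # along that dim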
+        weight_hp_t_c = weight_hp.t().contiguous()
+        elem_dtype = ctx.elem_dtype
+        block_size = ctx.block_size
+
+        grad_output_orig_shape = grad_output_hp.shape
+        grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
+
+        input_hp_orig_shape = input_hp.shape
+        input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])
+
+        # grad_output @ weight = grad_input
+        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
+        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+        grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
+        grad_input = grad_input.reshape(
+            *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
 
+        # input_t @ grad_output = grad_weight
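+        # gemm 3 contracts over the batch dim, so both operands are
+        # quantized from contiguous transposed copies; the input is then
+        # transposed back so the mm shapes line up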
+        grad_output_mx_dim1 = MXTensor.to_mx(
+            grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+        )
+        input_t_mx_dim0_tmp = MXTensor.to_mx(
+            input_hp_r.t().contiguous(), elem_dtype, block_size
+        )
+        input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        # TODO: debug why fp4 here leads to incorrect shapes
+        grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
+
+        return grad_input, grad_weight, None, None
+
 
 class MXLinear(torch.nn.Linear):
     """
@@ -59,16 +103,26 @@ def from_float(cls, mod, elem_dtype, block_size):
         return mod
 
     def forward(self, x):
-        x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size)
-        w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size)
-        y = F.linear(x_mx, w_mx, self.bias)
-        y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
+        if torch.is_autocast_enabled():
+            # special case autocast
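+            # (the quantize + mm inside mx_mm does not get autocast's
+            # automatic casting, so mirror it here by casting the input
+            # and weight to the autocast dtype before quantizing)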
+            autocast_dtype = torch.get_autocast_dtype("cuda")
+            x = x.to(autocast_dtype)
+            w = self.weight.to(autocast_dtype)
+        else:
+            w = self.weight
+
+        y = mx_mm.apply(x, w, self.elem_dtype, self.block_size)
+        if self.bias is not None:
+            y = y + self.bias
         return y
 
 
 class MXInferenceLinear(torch.nn.Linear):
     """
     Inference version of MXLinear, with the weight pre-quantized to MX.
+
+    Note: this is weight-only quantization, with the gemm being executed
+    in high precision.
     """
 
     @classmethod
@@ -84,8 +138,8 @@ def from_float(cls, mod, elem_dtype, block_size):
         # TODO(future PR): set to new_mod.weight directly, will need to work
         # through some errors
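+        # note: to_mx places the scaling blocks along the last dim of the
+        # weight (in_features), which is also the contraction dim of the
+        # high precision forward gemm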
         new_mod.weight_mx = MXTensor.to_mx(
-            mod.weight.t().contiguous(), elem_dtype, block_size=block_size
-        ).t()
+            mod.weight, elem_dtype, block_size=block_size
+        )
         new_mod.bias = mod.bias
         new_mod.elem_dtype = elem_dtype
         return new_mod