 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
+import math
 
 import torch
 from torch.utils._python_dispatch import (
     return_and_correct_aliasing,
 )
 
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
     register_layout,
 )
+from torchao.dtypes.nf4tensor import implements
 from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape
 from torchao.float8.inference import (
     Float8MMConfig,
     _is_rowwise_scaled,
     addmm_float8_unwrapped_inference,
     preprocess_data,
 )
-from torchao.utils import _is_float8_type, fill_defaults
-
+from torchao.utils import _is_float8_type, fill_defaults, TorchAOBaseTensor
+from torchao.quantization.quant_primitives import (
+    FP8_TYPES,
+    MappingType,
+    choose_qparams_affine_float8,
+    quantize_affine_float8,
+)
 aten = torch.ops.aten
 
@@ -34,13 +41,16 @@ class Float8Layout(Layout):
     mm_config: Optional[Float8MMConfig] = None
 
 
-@register_layout(Float8Layout)
-class Float8AQTTensorImpl(AQTTensorImpl):
+class Float8Tensor(TorchAOBaseTensor):
     """
-    TensorImpl for float8 layout affine quantized tensor
+    Float8Tensor is a torch.Tensor subclass that supports float8 data types.
+    It stores quantized float8 data together with its scale.
 
-    Note: technically we should not create a new layout for float8 we should merge this into
-    plain layout
+    Attributes:
+        float8_data (torch.Tensor): The float8 data tensor.
+        scale (torch.Tensor): The scale tensor.
+        transposed (bool): Whether the tensor is transposed or not.
+        _layout (Layout): The layout of the tensor.
     """
 
     float8_data: torch.Tensor
@@ -52,7 +62,7 @@ def __new__(
         float8_data: torch.Tensor,
         scale: torch.Tensor,
         transposed: bool,
-        _layout: Layout,
+        _layout: Layout = Float8Layout(),
     ):
         kwargs = {}
         kwargs["device"] = float8_data.device
@@ -69,7 +79,7 @@ def __init__(
         float8_data: torch.Tensor,
         scale: torch.Tensor,
         transposed: bool,
-        _layout: Layout,
+        _layout: Layout = Float8Layout(),
     ):
         self.float8_data = float8_data
         self.scale = scale
@@ -108,84 +118,20 @@ def __tensor_unflatten__(
         ) = tensor_attributes
         return cls(float8_data, scale, transposed, _layout)
 
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        kwargs = {} if kwargs is None else kwargs
-
-        if func is aten.detach.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-            )
-        elif func is aten.clone.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-            )
-        elif func is aten.t.default:
-            """we don't need to repack the weight and just rely on external
-            shape being changed and record the status of transpose/no-transpose
-            """
-            args[0].transposed = not args[0].transposed
-            return return_and_correct_aliasing(func, args, kwargs, args[0])
-        elif func is aten.slice.Tensor:
-            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-            if dim == 0:
-                # TODO: scale replecation should be dependent on block size
-                if self.scale.ndim == 1:
-                    return return_and_correct_aliasing(
-                        func,
-                        args,
-                        kwargs,
-                        args[0]._apply_fn_to_data(
-                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
-                        ),
-                    )
-                elif self.scale.ndim == 0:
-                    return return_and_correct_aliasing(
-                        func,
-                        args,
-                        kwargs,
-                        Float8AQTTensorImpl(
-                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
-                            self.scale,
-                            None,
-                            self._layout,
-                        ),
-                    )
-                else:
-                    raise NotImplementedError(
-                        f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
-                    )
-            elif dim == 1:
-                return return_and_correct_aliasing(
-                    func,
-                    args,
-                    kwargs,
-                    Float8AQTTensorImpl(
-                        aten.slice.Tensor(
-                            self.float8_data, dim, start, end, step
-                        ).contiguous(),
-                        self.scale,
-                        None,
-                        self._layout,
-                    ),
-                )
-            else:
-                raise NotImplementedError(
-                    f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
-                )
-        else:
-            raise NotImplementedError(
-                f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported"
-            )
-
-    __torch_function__ = torch._C._disabled_torch_function_impl
+    def __repr__(self):
+        float8_data, scale, _ = self.get_plain()
+        _layout = self.get_layout()
+        return (
+            f"{self.__class__.__name__}(\n"
+            f"float8_data={float8_data},\n"
+            f"scale={scale},\n"
+            f"transposed={self.transposed}, "
+            f"_layout={_layout})"
+        )
 
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         return self.float8_data, self.scale, None
 
-    def get_layout(self) -> Layout:
-        return self._layout
-
     @classmethod
     def from_plain(
         cls,
@@ -203,15 +149,120 @@ def from_plain(
         ), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}"
         return cls(data, scale, False, _layout)
 
-    def __repr__(self):
-        float8_data, scale, _ = self.get_plain()
-        _layout = self.get_layout()
-        return (
-            f"{self.__class__.__name__}(\n"
-            f"float8_data={float8_data},\n"
-            f"scale={scale},\n"
-            f"transposed={self.transposed}, "
-            f"_layout={_layout})"
+    @classmethod
+    def from_hp_to_floatx(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        _layout: Layout = Float8Layout(),
+    ):
+        """Convert a high precision tensor to a float8 quantized tensor."""
+        if target_dtype not in FP8_TYPES:
+            raise NotImplementedError(
+                f"Unsupported dtype {target_dtype} for from_hp_to_floatx"
+            )
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+        )
+        float_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+
+        return cls(
+            float_data,
+            scale,
+            False,
+            _layout,
+        )
+
+    @classmethod
+    def from_hp_to_floatx_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        target_dtype: torch.dtype,
+        _layout: Layout,
+    ):
+        """Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters."""
+        if target_dtype not in FP8_TYPES:
+            raise NotImplementedError(
+                f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
+            )
+        float_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+
+        return cls(
+            float_data,
+            scale,
+            False,
+            _layout,
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+    """We don't need to repack the weight; we just record the transposed
+    status and rely on the externally visible shape change.
+    """
+    args[0].transposed = not args[0].transposed
+    return return_and_correct_aliasing(func, args, kwargs, args[0])
+
+
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    if dim == 0:
+        # TODO: scale replication should be dependent on block size
+        if self.scale.ndim == 1:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: aten.slice.Tensor(x, dim, start, end, step)
+                ),
+            )
+        elif self.scale.ndim == 0:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                Float8Tensor(
+                    aten.slice.Tensor(self.float8_data, dim, start, end, step),
+                    self.scale,
+                    self.transposed,
+                    self._layout,
+                ),
+            )
+        else:
+            raise NotImplementedError(
+                f"Float8Tensor dispatch: attempting to run {func}, with scale ndim={self.scale.ndim}, that is not supported"
+            )
+    elif dim == 1:
+        return return_and_correct_aliasing(
+            func,
+            args,
+            kwargs,
+            Float8Tensor(
+                aten.slice.Tensor(
+                    self.float8_data, dim, start, end, step
+                ).contiguous(),
+                self.scale,
+                self.transposed,
+                self._layout,
+            ),
+        )
+    else:
+        raise NotImplementedError(
+            f"Float8Tensor dispatch: attempting to run {func}, with dim={dim}, that is not supported"
         )
 
@@ -317,3 +368,7 @@ def _linear_fp_act_fp8_weight_impl(
     bias: Optional[torch.Tensor],
 ):
     return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
+
+
+to_quantized_float8 = Float8Tensor.from_hp_to_floatx
+to_quantized_float8_static = Float8Tensor.from_hp_to_floatx_static
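
For context, a minimal usage sketch of the helpers added at the end of this diff. This is not part of the change itself; the import path and the manual dequantization step are assumptions made only for illustration.

# Hedged sketch: dynamically quantize a weight with to_quantized_float8,
# then reconstruct it by hand to check the round-trip error.
import torch
from torchao.dtypes.floatx.float8_layout import (  # assumed import path
    Float8Layout,
    Float8Tensor,
    to_quantized_float8,
)

w = torch.randn(128, 256, dtype=torch.bfloat16)
fp8_w = to_quantized_float8(w, torch.float8_e4m3fn)  # Float8Tensor wrapping float8_data + scale

float8_data, scale, _ = fp8_w.get_plain()
# Manual dequantization (data * scale), written out only for illustration.
w_hat = float8_data.to(torch.bfloat16) * scale
print((w - w_hat).abs().max())

# The static variant reuses a precomputed scale, e.g. one saved from calibration.
fp8_w_static = Float8Tensor.from_hp_to_floatx_static(
    w, scale, torch.float8_e4m3fn, Float8Layout()
)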