Commit 0ca0130

Create separate float8 tensor subclass

ghstack-source-id: 140144dda9c2cef71256e30caba73d2efa08e0cb
Pull Request resolved: #1636

1 parent 57816c9 commit 0ca0130

File tree (2 files changed: +147, -150)

torchao/dtypes/affine_quantized_tensor.py
torchao/dtypes/floatx/float8_layout.py
torchao/dtypes/affine_quantized_tensor.py  (0 additions, 58 deletions)
@@ -337,64 +337,6 @@ def from_hp_to_intx_static(
             dtype=input_float.dtype,
         )

-    @classmethod
-    def from_hp_to_floatx(
-        cls,
-        input_float: torch.Tensor,
-        block_size: Tuple[int, ...],
-        target_dtype: torch.dtype,
-        _layout: Layout,
-        scale_dtype: Optional[torch.dtype] = None,
-    ):
-        """Convert a high precision tensor to a float8 quantized tensor."""
-        if target_dtype in FP8_TYPES:
-            return cls.from_hp_to_intx(
-                input_float=input_float,
-                mapping_type=MappingType.SYMMETRIC,
-                block_size=block_size,
-                target_dtype=target_dtype,
-                quant_min=math.ceil(torch.finfo(target_dtype).min),
-                quant_max=math.ceil(torch.finfo(target_dtype).max),
-                eps=torch.finfo(torch.float32).eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=None,
-                preserve_zero=True,
-                zero_point_domain=None,
-                _layout=_layout,
-                use_hqq=False,
-            )
-        else:
-            raise NotImplementedError(
-                f"Unsupported dtype {target_dtype} for from_hp_to_floatx"
-            )
-
-    @classmethod
-    def from_hp_to_floatx_static(
-        cls,
-        input_float: torch.Tensor,
-        scale: torch.Tensor,
-        block_size: Tuple[int, ...],
-        target_dtype: torch.dtype,
-        _layout: Layout,
-    ):
-        """Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters."""
-        if target_dtype in FP8_TYPES:
-            return cls.from_hp_to_intx_static(
-                input_float=input_float,
-                scale=scale,
-                zero_point=None,
-                block_size=block_size,
-                target_dtype=target_dtype,
-                quant_min=math.ceil(torch.finfo(target_dtype).min),
-                quant_max=math.ceil(torch.finfo(target_dtype).max),
-                zero_point_domain=None,
-                _layout=_layout,
-            )
-        else:
-            raise NotImplementedError(
-                f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
-            )
-
     @classmethod
     def from_hp_to_fpx(
         cls,
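
These two converters are not deleted outright: they reappear, with a slimmer signature (no block_size, scale_dtype, or intx detour), as classmethods on the new Float8Tensor subclass in torchao/dtypes/floatx/float8_layout.py below. A hedged before/after sketch of a call site follows; the two halves target the tree before and after this commit respectively, so it is illustrative rather than runnable against a single checkout, and the import paths simply mirror the file paths in this diff:

    import torch
    from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
    from torchao.dtypes.floatx.float8_layout import Float8Layout, Float8Tensor

    weight = torch.randn(64, 64, dtype=torch.bfloat16)

    # Before this commit: float8 conversion rode on AffineQuantizedTensor's
    # intx path and required an explicit block_size.
    w_old = AffineQuantizedTensor.from_hp_to_floatx(
        input_float=weight,
        block_size=weight.shape,            # per-tensor scaling
        target_dtype=torch.float8_e4m3fn,
        _layout=Float8Layout(),
    )

    # After this commit: the dedicated subclass needs only the input, the
    # target float8 dtype, and (optionally) the layout.
    w_new = Float8Tensor.from_hp_to_floatx(weight, torch.float8_e4m3fn)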

torchao/dtypes/floatx/float8_layout.py  (147 additions, 92 deletions)
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
+import math

 import torch
 from torch.utils._python_dispatch import (
@@ -11,15 +12,21 @@
     AffineQuantizedTensor,
     register_layout,
 )
+from torchao.dtypes.nf4tensor import implements
 from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape
 from torchao.float8.inference import (
     Float8MMConfig,
     _is_rowwise_scaled,
     addmm_float8_unwrapped_inference,
     preprocess_data,
 )
-from torchao.utils import _is_float8_type, fill_defaults
-
+from torchao.utils import _is_float8_type, fill_defaults, TorchAOBaseTensor
+from torchao.quantization.quant_primitives import (
+    FP8_TYPES,
+    MappingType,
+    choose_qparams_affine_float8,
+    quantize_affine_float8,
+)

 aten = torch.ops.aten


@@ -34,13 +41,16 @@ class Float8Layout(Layout):
     mm_config: Optional[Float8MMConfig] = None


-@register_layout(Float8Layout)
-class Float8AQTTensorImpl(AQTTensorImpl):
+class Float8Tensor(TorchAOBaseTensor):
     """
-    TensorImpl for float8 layout affine quantized tensor
+    Float8 Tensor is a subclass of torch.Tensor that supports float8 data types.
+    It is used to represent the data in a float8 tensor.

-    Note: technically we should not create a new layout for float8 we should merge this into
-    plain layout
+    Attributes:
+        float8_data (torch.Tensor): The float8 data tensor.
+        scale (torch.Tensor): The scale tensor.
+        transposed (bool): Whether the tensor is transposed or not.
+        _layout (Layout): The layout of the tensor.
     """

     float8_data: torch.Tensor
@@ -52,7 +62,7 @@ def __new__(
         float8_data: torch.Tensor,
         scale: torch.Tensor,
         transposed: bool,
-        _layout: Layout,
+        _layout: Layout = Float8Layout(),
     ):
         kwargs = {}
         kwargs["device"] = float8_data.device
@@ -69,7 +79,7 @@ def __init__(
         float8_data: torch.Tensor,
         scale: torch.Tensor,
         transposed: bool,
-        _layout: Layout,
+        _layout: Layout = Float8Layout(),
     ):
         self.float8_data = float8_data
         self.scale = scale
@@ -108,84 +118,20 @@ def __tensor_unflatten__(
         ) = tensor_attributes
         return cls(float8_data, scale, transposed, _layout)

-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        kwargs = {} if kwargs is None else kwargs
-
-        if func is aten.detach.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-            )
-        elif func is aten.clone.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-            )
-        elif func is aten.t.default:
-            """we don't need to repack the weight and just rely on external
-            shape being changed and record the status of transpose/no-transpose
-            """
-            args[0].transposed = not args[0].transposed
-            return return_and_correct_aliasing(func, args, kwargs, args[0])
-        elif func is aten.slice.Tensor:
-            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-            if dim == 0:
-                # TODO: scale replecation should be dependent on block size
-                if self.scale.ndim == 1:
-                    return return_and_correct_aliasing(
-                        func,
-                        args,
-                        kwargs,
-                        args[0]._apply_fn_to_data(
-                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
-                        ),
-                    )
-                elif self.scale.ndim == 0:
-                    return return_and_correct_aliasing(
-                        func,
-                        args,
-                        kwargs,
-                        Float8AQTTensorImpl(
-                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
-                            self.scale,
-                            None,
-                            self._layout,
-                        ),
-                    )
-                else:
-                    raise NotImplementedError(
-                        f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
-                    )
-            elif dim == 1:
-                return return_and_correct_aliasing(
-                    func,
-                    args,
-                    kwargs,
-                    Float8AQTTensorImpl(
-                        aten.slice.Tensor(
-                            self.float8_data, dim, start, end, step
-                        ).contiguous(),
-                        self.scale,
-                        None,
-                        self._layout,
-                    ),
-                )
-            else:
-                raise NotImplementedError(
-                    f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
-                )
-        else:
-            raise NotImplementedError(
-                f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported"
-            )
-
-    __torch_function__ = torch._C._disabled_torch_function_impl
+    def __repr__(self):
+        float8_data, scale, _ = self.get_plain()
+        _layout = self.get_layout()
+        return (
+            f"{self.__class__.__name__}(\n"
+            f"float8_data={float8_data},\n"
+            f"scale={scale},\n"
+            f"transposed={self.transposed}, "
+            f"_layout={_layout})"
+        )

     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         return self.float8_data, self.scale, None

-    def get_layout(self) -> Layout:
-        return self._layout
-
     @classmethod
     def from_plain(
         cls,
@@ -203,15 +149,120 @@ def from_plain(
         ), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}"
         return cls(data, scale, False, _layout)

-    def __repr__(self):
-        float8_data, scale, _ = self.get_plain()
-        _layout = self.get_layout()
-        return (
-            f"{self.__class__.__name__}(\n"
-            f"float8_data={float8_data},\n"
-            f"scale={scale},\n"
-            f"transposed={self.transposed}, "
-            f"_layout={_layout})"
+    @classmethod
+    def from_hp_to_floatx(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        _layout: Layout = Float8Layout(),
+    ):
+        """Convert a high precision tensor to a float8 quantized tensor."""
+        if target_dtype not in FP8_TYPES:
+            raise NotImplementedError(
+                f"Unsupported dtype {target_dtype} for from_hp_to_floatx"
+            )
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+        )
+        float_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+
+        return cls(
+            float_data,
+            scale,
+            False,
+            _layout,
+        )
+
+    @classmethod
+    def from_hp_to_floatx_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        target_dtype: torch.dtype,
+        _layout: Layout,
+    ):
+        """Create a Float8Tensor from a high precision tensor using static quantization parameters."""
+        if target_dtype not in FP8_TYPES:
+            raise NotImplementedError(
+                f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
+            )
+        float_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+
+        return cls(
+            float_data,
+            scale,
+            False,
+            _layout,
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+    """we don't need to repack the weight and just rely on external
+    shape being changed and record the status of transpose/no-transpose
+    """
+    args[0].transposed = not args[0].transposed
+    return return_and_correct_aliasing(func, args, kwargs, args[0])
+
+
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    if dim == 0:
+        # TODO: scale replication should be dependent on block size
+        if self.scale.ndim == 1:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: aten.slice.Tensor(x, dim, start, end, step)
+                ),
+            )
+        elif self.scale.ndim == 0:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                Float8Tensor(
+                    aten.slice.Tensor(self.float8_data, dim, start, end, step),
+                    self.scale,
+                    self.transposed,
+                    self._layout,
+                ),
+            )
+        else:
+            raise NotImplementedError(
+                f"Float8Tensor dispatch: attempting to run {func}, with scale ndim={self.scale.ndim}, that is not supported"
+            )
+    elif dim == 1:
+        return return_and_correct_aliasing(
+            func,
+            args,
+            kwargs,
+            Float8Tensor(
+                aten.slice.Tensor(
+                    self.float8_data, dim, start, end, step
+                ).contiguous(),
+                self.scale,
+                self.transposed,
+                self._layout,
+            ),
+        )
+    else:
+        raise NotImplementedError(
+            f"Float8Tensor dispatch: attempting to run {func}, with dim={dim}, that is not supported"
         )

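
The dynamic path added above delegates scale selection to choose_qparams_affine_float8 and the cast to quantize_affine_float8, both imported from torchao.quantization.quant_primitives and not shown in this diff. As a rough mental model only (an illustrative sketch, not the torchao implementation), per-tensor float8 quantization amounts to an amax-based scale followed by a clamped cast:

    import torch

    def reference_float8_quantize(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
        # Illustrative sketch: map the input's absolute max onto the float8 range.
        fp8_max = torch.finfo(float8_dtype).max
        eps = torch.finfo(torch.float32).eps  # mirrors the eps used by the removed intx path
        scale = torch.clamp(x.abs().amax().to(torch.float32) / fp8_max, min=eps)
        # Scale, clamp to the representable range, then cast to float8.
        x_fp8 = (x.to(torch.float32) / scale).clamp(min=-fp8_max, max=fp8_max).to(float8_dtype)
        return x_fp8, scale

The static variant (from_hp_to_floatx_static) skips the scale computation and performs only the scaled cast with a caller-provided scale.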

@@ -317,3 +368,7 @@ def _linear_fp_act_fp8_weight_impl(
     bias: Optional[torch.Tensor],
 ):
     return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
+
+
+to_quantized_float8 = Float8Tensor.from_hp_to_floatx
+to_quantized_float8_static = Float8Tensor.from_hp_to_floatx_static
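
Taken together, the new subclass plus these module-level aliases give a direct construction path for float8 weights. A hedged usage sketch follows; the import path simply mirrors this file's location, and whether these names are re-exported anywhere else is not established by this diff:

    import torch
    from torchao.dtypes.floatx.float8_layout import (
        Float8Layout,
        Float8Tensor,
        to_quantized_float8,
        to_quantized_float8_static,
    )

    weight = torch.randn(128, 256, dtype=torch.bfloat16)

    # Dynamic scaling: the scale is derived from the tensor itself.
    w_fp8 = to_quantized_float8(weight, torch.float8_e4m3fn)

    # Static scaling: reuse a precomputed (e.g. calibration-derived) scale.
    scale = weight.abs().amax().to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max
    w_fp8_static = to_quantized_float8_static(
        weight, scale, torch.float8_e4m3fn, Float8Layout()
    )

    float8_data, scale_out, _ = w_fp8.get_plain()
    print(float8_data.dtype, w_fp8.transposed)  # expected: torch.float8_e4m3fn False

Transposing w_fp8 only flips the transposed flag (the aten.t handler above), while slicing along dim 0 or 1 slices float8_data and reuses or slices the scale depending on its rank.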
