integrate new float8 quantization primitives into AQT #1598
Conversation
Stack from ghstack (oldest at bottom):
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1598
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2cac42e with merge base 860da26.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
    target_dtype,
)
fp8_data = _layout.post_process(fp8_data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
Does this have multiple options for float8? If not, just call it directly to reduce the number of abstractions the code reader needs to know about.
good point, done!
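For readers following along, here is a toy, self-contained sketch (not torchao code; every name below is illustrative) of the trade-off being discussed: looking up a constructor through a registry versus calling the only float8 constructor directly.

```python
from typing import Callable, Dict, Type


class PlainLayout:
    """Stand-in for a layout type (illustrative only)."""


_TENSOR_IMPL_REGISTRY: Dict[Type, Callable] = {}


def register_tensor_impl(layout_cls: Type) -> Callable:
    """Register a tensor-impl constructor for a layout class."""

    def decorator(ctor: Callable) -> Callable:
        _TENSOR_IMPL_REGISTRY[layout_cls] = ctor
        return ctor

    return decorator


def get_tensor_impl_constructor(layout_cls: Type) -> Callable:
    """Look up the constructor registered for a layout class."""
    return _TENSOR_IMPL_REGISTRY[layout_cls]


@register_tensor_impl(PlainLayout)
def make_float8_tensor_impl(data, scale, zero_point, layout):
    """The only float8 tensor-impl constructor in this toy example."""
    return ("float8_impl", tuple(data), scale, zero_point, layout)


layout = PlainLayout()

# Registry path: the reader has to know about the registry to find this code.
impl_a = get_tensor_impl_constructor(type(layout))([1, 2], 0.1, None, layout)

# Direct path: one fewer abstraction when there is only a single float8 impl.
impl_b = make_float8_tensor_impl([1, 2], 0.1, None, layout)

assert impl_a == impl_b
```

The registry only pays off when several tensor impls can back the same entry point; with a single float8 impl, the direct call is easier to read and trace.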
    target_dtype: torch.dtype,
    block_size: Tuple[int, ...],
    _layout: Layout = PlainLayout(),
):
A docblock here should explain the difference between from_hp_to_floatx, from_hp_to_fpx, and from_hp_to_float8.
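Along the lines of this request, a hypothetical docblock might read as follows. The wording is illustrative and the signature is abridged from the diff fragments in this PR, not copied from the codebase.

```python
@classmethod
def from_hp_to_float8(
    cls,
    input_float: torch.Tensor,
    target_dtype: torch.dtype,
    block_size: Tuple[int, ...],
    _layout: Layout = PlainLayout(),
):
    """Create a float8 AffineQuantizedTensor from a high-precision tensor.

    Related constructors, for orientation (illustrative summary):
      * from_hp_to_intx:   generic integer affine quantization (zero points,
                           quant_min/quant_max, etc.).
      * from_hp_to_floatx: the existing float8 entry point, which currently
                           routes through the generic intx primitives.
      * from_hp_to_fpx:    sub-byte floating point formats (fp1-fp7, e.g. fp6).
      * from_hp_to_float8: float8-only path built on the dedicated float8
                           quantization primitives.
    """
    ...
```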
@@ -422,6 +417,39 @@ def from_hp_to_fpx(
    tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
    return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

@classmethod
def from_hp_to_float8(
Update from_hp_to_floatx with the new float8 logic. For fp1-fp7, we're using from_hp_to_fpx.
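For context, roughly what that update could look like. This is a sketch only: the primitive names (choose_qparams_affine_float8, quantize_affine_float8), the FP8_TYPES set, and the _layout.pre_process call are assumptions based on this PR stack's description rather than verified APIs; the post_process / tensor-impl construction tail mirrors the diff above.

```python
@classmethod
def from_hp_to_floatx(
    cls,
    input_float: torch.Tensor,
    target_dtype: torch.dtype,
    block_size: Tuple[int, ...],
    _layout: Layout = PlainLayout(),
):
    # Sketch: route float8 dtypes through dedicated float8 primitives;
    # fp1-fp7 stays on from_hp_to_fpx.
    if target_dtype not in FP8_TYPES:  # FP8_TYPES is assumed here
        raise NotImplementedError(
            f"from_hp_to_floatx expects a float8 dtype, got {target_dtype}; "
            "use from_hp_to_fpx for fp1-fp7"
        )

    original_shape = input_float.shape
    input_float = _layout.pre_process(input_float)

    # Dedicated float8 primitives instead of the generic intx path
    # (names and signatures assumed, not taken from this diff).
    scale = choose_qparams_affine_float8(input_float, float8_dtype=target_dtype)
    fp8_data = quantize_affine_float8(input_float, scale, target_dtype)
    fp8_data = _layout.post_process(fp8_data)

    # Or a direct call to the float8 tensor impl, per the earlier review comment.
    tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
    tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
    return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
```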
Context
Currently, AQT has the method from_hp_to_floatx for float8 quantization, and from_hp_to_fpx for low precision floating point data types like fp6 (technically, it can support fp1-fp7). from_hp_to_floatx re-uses from_hp_to_intx, which in turn uses the generic quantization primitives. Overall, in the current state the float8 path is a bit confusing for developers, due to both the naming ("floatx") and the use of generic functions that include a number of params unrelated to float8 quantization.
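To make the "unrelated params" point concrete, compare the two call shapes below. The generic choose_qparams_affine signature is abridged and illustrative, and the float8-specific names are assumptions based on this PR's title rather than the exact torchao API.

```python
# Generic affine path (abridged): mapping_type, quant_min/quant_max, and
# zero-point handling are meaningless for float8, but still part of the call.
scale, zero_point = choose_qparams_affine(
    input_float,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=torch.float8_e4m3fn,
    quant_min=None,
    quant_max=None,
    eps=1e-12,
    scale_dtype=torch.float32,
    zero_point_dtype=None,
)

# Dedicated float8 path (assumed names): only what float8 actually needs.
scale = choose_qparams_affine_float8(input_float, float8_dtype=torch.float8_e4m3fn)
fp8_data = quantize_affine_float8(input_float, scale, torch.float8_e4m3fn)
```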
Summary of changes
The goal of this PR stack is to refactor this into a clean separation of concerns, with simpler internal API surfaces for code used in float8 quantization for inference.
Specifically:
Note: I will add float8 static quantization in a separate set of PRs.