Add support for power of 2 scaling factors in float8 training #1669
Conversation
```diff
@@ -234,6 +248,9 @@ class Float8LinearConfig:
     # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
+
+    # configuration used for calculating the scaling factor used in float8 quantization.
+    scaling_factor_config: Float8ScalingFactorConfig = None
```
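(For context, a minimal sketch of what the new config dataclass might look like; only the `power_of_2_scale` field is referenced in this discussion, and everything else here is an assumption:)

```python
from dataclasses import dataclass

@dataclass
class Float8ScalingFactorConfig:
    # When True, each scaling factor is rounded down to the nearest power
    # of 2 after it is computed.
    #
    # The tri-state lives one level up: Float8LinearConfig.scaling_factor_config
    # defaults to None (unset), so an explicitly constructed config with
    # power_of_2_scale=False is distinguishable from "never set".
    power_of_2_scale: bool = False
```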
Reviewer:
Just make it a boolean, and set it to true for rowwise scaling in the code using `Float8LinearRecipeName`?
Author:
I originally did this (in a different branch), but there didn't seem to be a simple way for the user to use row-wise scaling without power of 2 scale factors.

I am trying to support the behavior of "default on for row-wise scaling, with the option to disable; for other scaling granularities, default off, with the option to enable."

I tried to explain my logic in the PR description, but basically: if this is a boolean, we won't be able to distinguish whether the user explicitly set it to False (for example, to use row-wise scaling without power of 2 scaling, in which case we should not override it to True in the rowwise code), or whether it is False only because that is the unset default (in which case we should set it to True in the rowwise code).
Reviewer:
> there didn't seem to be a nice/simple way for the user to use row-wise scaling without power of 2 scale factors

They could create a `Float8LinearConfig` and configure whatever they want?
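(For illustration, a hypothetical sketch of what the reviewer is suggesting; only `scaling_factor_config` and `Float8ScalingFactorConfig` come from this PR, and the other constructor arguments are elided:)

```python
# Hypothetical: explicitly opt out of power of 2 scaling while still using
# row-wise scaling; other Float8LinearConfig arguments are elided here.
config = Float8LinearConfig(
    # ... row-wise scaling cast configs here ...
    scaling_factor_config=Float8ScalingFactorConfig(power_of_2_scale=False),
)
```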
Reviewer:
> Default on for row-wise scaling, with option to disable. For other scaling granularities, default off, with option to enable.

That should be changed to: default on only for the row-wise scaling recipe created by the recipe-enum-to-recipe-config function.
Author:
> Default on for row-wise scaling, with option to disable. For other scaling granularities, default off, with option to enable.

> that should be changed to default on for row-wise scaling recipe created by the recipe enum to recipe function

Yeah, this is what I did in the other branch; I can create a PR for that one as well to compare. I didn't prefer it because a lot of the rowwise scaling tests do not use the recipe helper function and instead call `hp_tensor_to_float8_dynamic` directly, so that approach would require modifying a bunch of tests, adding new parameterizations, etc., to make sure "default on" doesn't negatively affect numerics. This approach seemed simpler to me, but I'm happy to do it the other way.
Summary
Add support for power of 2 scaling factors in float8 training with dynamic scaling.
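(For reference, a minimal sketch of one common way to round a scale down to the nearest power of 2; this is illustrative and not necessarily the exact implementation in this PR:)

```python
import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Round each positive float32 scale down to the nearest power of 2
    # by flooring its base-2 exponent.
    assert scale.dtype == torch.float32
    return torch.exp2(torch.floor(torch.log2(scale)))
```

Multiplying or dividing by a power of 2 only changes the float exponent, not the mantissa, so the scaling itself introduces no additional rounding error.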
REVIEWER NOTE: To support the behavior of "default on, but can be optionally disabled," we must be able to distinguish between the unset default and an explicit negative setting. This is why I didn't use a simple boolean flag for power of 2 scaling: the unset default and an explicit negative setting would both be `False`. By using a dataclass config, we can distinguish between the unset default (`None`) and an explicit negative setting (the config exists and `scaling_config.power_of_2_scale` is `False`). I'm open to feedback/suggestions if there is a better way to do this; let me know.
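(A hypothetical sketch of the resolution logic this tri-state enables; the function and parameter names here are illustrative, not code from this PR:)

```python
def resolve_power_of_2_scale(
    scaling_factor_config,  # Float8ScalingFactorConfig | None
    is_rowwise_scaling: bool,
) -> bool:
    # Unset default (None): enable power of 2 scaling only for row-wise scaling.
    if scaling_factor_config is None:
        return is_rowwise_scaling
    # Explicitly constructed config: respect the user's setting, even when it
    # is False, so row-wise scaling without power of 2 scales stays possible.
    return scaling_factor_config.power_of_2_scale
```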
Test Plan
All tests (in `test/float8/`) are passing.