Support power of 2 scaling factors in float8 training #1670
Conversation
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    power_of_2_scale: bool = False,
Note for reviewer: this param list is getting pretty long, and 4 of the 9 params can be derived from the Float8LinearConfig. Any thoughts on refactoring to pass in the Float8LinearConfig directly?
sounds reasonable
Cool, I'll do that in a follow up so Less can begin scale testing after we merge this asap
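For reference, a minimal sketch of what that follow-up refactor might look like (hypothetical, not part of this PR): the helper would take the Float8LinearConfig itself and read the scaling-related arguments off of it, rather than threading each one through the signature. The function name and field accesses below are illustrative assumptions.

```python
from torchao.float8.config import Float8LinearConfig

# Hypothetical sketch of the proposed refactor: derive scaling_granularity,
# axiswise_dim, power_of_2_scale, etc. from the config instead of passing them
# as individual parameters.
def hp_tensor_to_float8_dynamic_from_config(hp_tensor, float8_dtype, linear_mm_config, config: Float8LinearConfig):
    # illustrative: read scaling settings off the config rather than from individual params
    cast_config = config.cast_config_input
    scaling_granularity = cast_config.scaling_granularity
    round_scales = config.round_scales_to_power_of_2  # the flag added in this PR
    ...
```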
torchao/float8/config.py
# reduce quantization error by avoiding rounding errors when multiplying/dividing
# by the scaling factor, as well as ensuring large values are quantized to the
# same value in the forward pass as in the backward pass.
power_of_2_scale: bool = False
nit: maybe something like round_scales_to_power_of_2, to match the naming of surrounding code?
Sure, done
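As a side note on the rationale in the config comment above, here is a minimal illustration (not from the PR) of why a power-of-2 scale avoids rounding error on the multiply/divide round trip:

```python
import torch

x = torch.randn(1024, dtype=torch.float32)

# Dividing/multiplying by a power of 2 only adjusts the exponent, so the round trip is exact.
pow2_scale = torch.tensor(0.25)
print(torch.equal((x / pow2_scale) * pow2_scale, x))  # True

# With an arbitrary scale the mantissa gets rounded, so the round trip typically drifts.
other_scale = torch.tensor(0.3)
print(torch.equal((x / other_scale) * other_scale, x))  # usually False
```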
Looks great! Commented on suggested naming. Any chance we could also check that this does not regress performance in torchtitan?
My shared devgpu has too much other usage currently to test with Llama3 8b, so I used a debug model which is much smaller (800M params). For rowwise with power of 2 scales, memory is flat but there is an ~8% regression in TPS. I'm wondering if this is because a small model like this has higher performance variance or if there is a real issue; I need to look into it further.

Llama3 model configs: n_layers=4, dim=4096, n_heads=16. Tested on 4 H100s.

Row wise without power of 2 scales:
Row wise with power of 2 scales:
Sounds like that's worth a follow-up. Two things to check that I can think of:
torchao/float8/float8_utils.py
if round_scales_to_power_of_2:
    # rounds down to the nearest power of 2.
    res = torch.exp2(torch.floor(torch.log2(res)))
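A quick sanity check of the rounding expression in the hunk above (illustrative only):

```python
import torch

res = torch.tensor([0.3, 1.0, 5.5, 100.0])
# rounds each scale down to the nearest power of 2: 0.25, 1.0, 4.0, 64.0
print(torch.exp2(torch.floor(torch.log2(res))))
```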
this should be the same as setting the mantissa to all-zeroes (maybe with some special handling for inf/nan), and can be implemented with bit shifting. Do you want to try to see if that resolves the regression?
Didn't test, but something like the following for float32:

res = res.view(torch.uint32)
res = (res >> 23) << 23
res = res.view(torch.float)
uint32 doesn't support bitshift ops apparently, so I had to use int32. Unit tests pass though, and the TPS regression is gone. Will the sign bit affect anything? I did some manual tests in the interpreter and rounding seemed to work as expected (a sketch of the int32 version is included after the log below).
```
[rank0]:2025-02-05 16:11:30,663 - root - INFO - step: 1 loss: 8.2105 memory: 9.69GiB(10.20%) tps: 610 mfu: 0.33%
[rank0]:2025-02-05 16:11:30,663 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 16:11:30,896 - root - INFO - step: 2 loss: 9.2258 memory: 11.02GiB(11.60%) tps: 70,207 mfu: 37.73%
[rank0]:2025-02-05 16:11:31,129 - root - INFO - step: 3 loss: 8.5120 memory: 11.02GiB(11.60%) tps: 70,377 mfu: 37.82%
[rank0]:2025-02-05 16:11:31,361 - root - INFO - step: 4 loss: 11.7253 memory: 11.02GiB(11.60%) tps: 70,885 mfu: 38.10%
[rank0]:2025-02-05 16:11:31,591 - root - INFO - step: 5 loss: 9.3686 memory: 11.02GiB(11.60%) tps: 71,365 mfu: 38.35%
[rank0]:2025-02-05 16:11:31,823 - root - INFO - step: 6 loss: 8.5610 memory: 11.02GiB(11.60%) tps: 70,634 mfu: 37.96%
[rank0]:2025-02-05 16:11:32,059 - root - INFO - step: 7 loss: 7.7763 memory: 11.02GiB(11.60%) tps: 69,681 mfu: 37.45%
[rank0]:2025-02-05 16:11:32,287 - root - INFO - step: 8 loss: 7.4649 memory: 11.02GiB(11.60%) tps: 71,963 mfu: 38.68%
[rank0]:2025-02-05 16:11:32,517 - root - INFO - step: 9 loss: 7.2956 memory: 11.02GiB(11.60%) tps: 71,188 mfu: 38.26%
[rank0]:2025-02-05 16:11:32,749 - root - INFO - step: 10 loss: 7.1085 memory: 11.02GiB(11.60%) tps: 70,748 mfu: 38.02%
```
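For reference, a sketch of the int32 bit-shift variant described above (the exact code in the PR may differ). Since scales are positive float32 values, the sign bit is zero and the arithmetic shifts only clear the mantissa:

```python
import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Reinterpret the float32 bits as int32, zero out the 23 mantissa bits,
    # then reinterpret back. Keeping only sign + exponent rounds the (positive)
    # scale down to the nearest power of 2. inf/nan are not specially handled.
    assert scale.dtype == torch.float32
    bits = scale.view(torch.int32)
    bits = (bits >> 23) << 23
    return bits.view(torch.float32)

print(round_scale_down_to_power_of_2(torch.tensor([0.3, 5.5])))  # tensor([0.2500, 4.0000])
```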
Summary
Add support for power of 2 scaling factors in float8 training with dynamic scaling.
Behavior:
- Opt-in via a new flag on Float8LinearConfig (off by default).
- Enabled in the Float8LinearConfig returned from recipe_name_to_linear_config for rowwise scaling.

Test Plan
Updated test cases to ensure power of 2 scaling does not impact numerics for axiswise dynamic scaling (eager and compiled)
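For completeness, a hedged sketch of how a user might opt in. The flag name follows the rename agreed above; the rest is the existing torchao.float8 conversion path:

```python
import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).cuda().to(torch.bfloat16)

# round_scales_to_power_of_2 is the flag added in this PR; it defaults to False.
config = Float8LinearConfig(round_scales_to_power_of_2=True)
convert_to_float8_training(model, config=config)
```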