
Support power of 2 scaling factors in float8 training #1670

Open
danielvegamyhre wants to merge 6 commits into main
Conversation

danielvegamyhre (Contributor) commented Feb 5, 2025

Summary

Add support for power of 2 scaling factors in float8 training with dynamic scaling.

Behavior:

  • Default on in the Float8LinearConfig returned from recipe_name_to_linear_config for the rowwise scaling recipe.
  • Default off in all other cases.

Test Plan
Updated test cases to ensure power of 2 scaling does not impact numerics for axiswise dynamic scaling (eager and compiled)
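
For reference, a minimal usage sketch (the import paths, the convert_to_float8_training call, and the flag name round_scales_to_power_of_2 adopted later in review are assumptions about torchao's float8 API, not code from this PR):

import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)

# explicitly opt in; outside the rowwise recipe the flag defaults to off
config = Float8LinearConfig(round_scales_to_power_of_2=True)
convert_to_float8_training(model, config=config)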


pytorch-bot bot commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1670

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit ab93e18 with merge base 8afd10e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Feb 5, 2025
danielvegamyhre changed the title from "Support power of 2 scaling factors in float8 training" to "Support power of 2 scaling factors in float8 training via boolean param in Float8LinearConfig" on Feb 5, 2025
danielvegamyhre marked this pull request as draft on February 5, 2025 at 20:03
danielvegamyhre added the topic: new feature label on Feb 5, 2025
danielvegamyhre changed the title back to "Support power of 2 scaling factors in float8 training" on Feb 5, 2025
danielvegamyhre marked this pull request as ready for review on February 5, 2025 at 21:37
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
power_of_2_scale: bool = False,
Contributor Author
Note for reviewer: this param list is getting pretty long, and 4 of the 9 params can be derived from the Float8LinearConfig. Any thoughts on refactoring to pass in the Float8LinearConfig directly?

Contributor

sounds reasonable

Contributor Author

Cool, I'll do that in a follow-up so Less can begin scale testing after we merge this ASAP.

danielvegamyhre requested a review from vkuzo on February 5, 2025 at 21:47
# reduce quantization error by avoiding rounding errors when multiplying/dividing
# by the scaling factor, as well as ensuring large values are quantized to the
# same value in the forward pass as in the backward pass.
power_of_2_scale: bool = False
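
(Illustrative aside, not code from this PR: a power-of-2 scale only shifts the float exponent, so a multiply/divide round trip through the scale is bit-exact, while an arbitrary scale can pick up rounding error.)

import torch

x = torch.randn(4096)
pow2_err = ((x * 0.25) / 0.25 - x).abs().max()  # power-of-2 scale: exponent shift only, prints 0
arb_err = ((x * 0.30) / 0.30 - x).abs().max()   # arbitrary scale: generally a small nonzero error
print(pow2_err.item(), arb_err.item())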
Contributor

nit: maybe something like round_scales_to_power_of_2, to match naming of surrounding code?

Contributor Author

Sure, done

vkuzo (Contributor) left a comment

looks great! commented on suggested naming, any chance we could also check that this does not regress performance in torchtitan?

danielvegamyhre (Contributor) commented Feb 5, 2025

looks great! commented on suggested naming, any chance we could also check that this does not regress performance in torchtitan?

My shared devgpu has too much other usage currently to test with Llama 3 8B, so I used a much smaller debug model (800M params).

For rowwise with power of 2 scales, memory is flat but there is an ~8% regression in TPS. I'm wondering if this is because a small model like this has higher performance variance, or if there is a real issue; I need to look into it further.

Llama3 model configs: n_layers=4, dim=4096, n_heads=16

Tested on 4 H100s.

Rowwise without power of 2 scales:

[rank0]:2025-02-05 15:05:21,682 - root - INFO - step:  1  loss:  8.2016  memory: 10.28GiB(10.82%)  tps: 584  mfu: 0.31%
[rank0]:2025-02-05 15:05:21,682 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 15:05:21,914 - root - INFO - step:  2  loss:  9.2027  memory: 11.67GiB(12.28%)  tps: 70,711  mfu: 38.00%
[rank0]:2025-02-05 15:05:22,146 - root - INFO - step:  3  loss:  8.4319  memory: 11.67GiB(12.28%)  tps: 70,914  mfu: 38.11%
[rank0]:2025-02-05 15:05:22,375 - root - INFO - step:  4  loss: 13.0116  memory: 11.67GiB(12.28%)  tps: 71,446  mfu: 38.40%
[rank0]:2025-02-05 15:05:22,604 - root - INFO - step:  5  loss: 10.0891  memory: 11.67GiB(12.28%)  tps: 71,662  mfu: 38.51%
[rank0]:2025-02-05 15:05:22,835 - root - INFO - step:  6  loss:  8.8140  memory: 11.67GiB(12.28%)  tps: 71,041  mfu: 38.18%
[rank0]:2025-02-05 15:05:23,068 - root - INFO - step:  7  loss:  7.9921  memory: 11.67GiB(12.28%)  tps: 70,531  mfu: 37.91%
[rank0]:2025-02-05 15:05:23,297 - root - INFO - step:  8  loss:  7.5519  memory: 11.67GiB(12.28%)  tps: 71,670  mfu: 38.52%
[rank0]:2025-02-05 15:05:23,525 - root - INFO - step:  9  loss:  7.4012  memory: 11.67GiB(12.28%)  tps: 71,808  mfu: 38.59%
[rank0]:2025-02-05 15:05:23,754 - root - INFO - step: 10  loss:  7.2013  memory: 11.67GiB(12.28%)  tps: 71,647  mfu: 38.51%

Rowwise with power of 2 scales:

[rank0]:2025-02-05 15:02:52,539 - root - INFO - step:  1  loss:  8.2104  memory:  9.85GiB(10.37%)  tps: 1,981  mfu: 1.06%
[rank0]:2025-02-05 15:02:52,539 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 15:02:52,792 - root - INFO - step:  2  loss:  9.2376  memory: 11.20GiB(11.79%)  tps: 64,845  mfu: 34.85%
[rank0]:2025-02-05 15:02:53,046 - root - INFO - step:  3  loss:  8.6284  memory: 11.20GiB(11.79%)  tps: 64,742  mfu: 34.80%
[rank0]:2025-02-05 15:02:53,297 - root - INFO - step:  4  loss: 11.2887  memory: 11.20GiB(11.79%)  tps: 65,266  mfu: 35.08%
[rank0]:2025-02-05 15:02:53,548 - root - INFO - step:  5  loss:  9.4400  memory: 11.20GiB(11.79%)  tps: 65,429  mfu: 35.16%
[rank0]:2025-02-05 15:02:53,800 - root - INFO - step:  6  loss:  8.5271  memory: 11.20GiB(11.79%)  tps: 65,117  mfu: 35.00%
[rank0]:2025-02-05 15:02:54,055 - root - INFO - step:  7  loss:  7.8088  memory: 11.20GiB(11.79%)  tps: 64,426  mfu: 34.63%
[rank0]:2025-02-05 15:02:54,305 - root - INFO - step:  8  loss:  7.4392  memory: 11.20GiB(11.79%)  tps: 65,452  mfu: 35.18%
[rank0]:2025-02-05 15:02:54,555 - root - INFO - step:  9  loss:  7.3227  memory: 11.20GiB(11.79%)  tps: 65,783  mfu: 35.35%
[rank0]:2025-02-05 15:02:54,805 - root - INFO - step: 10  loss:  7.0642  memory: 11.20GiB(11.79%)  tps: 65,620  mfu: 35.27%

vkuzo (Contributor) commented Feb 5, 2025

Sounds like that's worth a follow-up.

Two things to check that I can think of:

  1. Does it reproduce on full-size Llama 3 8B on 8 H100s?
  2. Does the regression go away if we use bit shifting instead of exp2 and log2?


if round_scales_to_power_of_2:
    # rounds down to the nearest power of 2.
    res = torch.exp2(torch.floor(torch.log2(res)))
Contributor

this should be the same as setting the mantissa to all-zeroes (maybe with some special handling for inf/nan), and can be implemented with bit shifting. Do you want to try to see if that resolves the regression?

Contributor

didn't test, but something like

for float32:

res = res.view(torch.uint32)
res = (res >> 23) << 23
res = res.view(torch.float)

Contributor Author

uint32 doesn't support bitshift ops apparently, so I had to use int32. Unit tests pass though, and the TPS regression is gone. Will the sign bit affect anything? I did some manual tests in the interpreter and rounding seemed to work as expected.

[rank0]:2025-02-05 16:11:30,663 - root - INFO - step:  1  loss:  8.2105  memory:  9.69GiB(10.20%)  tps: 610  mfu: 0.33%
[rank0]:2025-02-05 16:11:30,663 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-05 16:11:30,896 - root - INFO - step:  2  loss:  9.2258  memory: 11.02GiB(11.60%)  tps: 70,207  mfu: 37.73%
[rank0]:2025-02-05 16:11:31,129 - root - INFO - step:  3  loss:  8.5120  memory: 11.02GiB(11.60%)  tps: 70,377  mfu: 37.82%
[rank0]:2025-02-05 16:11:31,361 - root - INFO - step:  4  loss: 11.7253  memory: 11.02GiB(11.60%)  tps: 70,885  mfu: 38.10%
[rank0]:2025-02-05 16:11:31,591 - root - INFO - step:  5  loss:  9.3686  memory: 11.02GiB(11.60%)  tps: 71,365  mfu: 38.35%
[rank0]:2025-02-05 16:11:31,823 - root - INFO - step:  6  loss:  8.5610  memory: 11.02GiB(11.60%)  tps: 70,634  mfu: 37.96%
[rank0]:2025-02-05 16:11:32,059 - root - INFO - step:  7  loss:  7.7763  memory: 11.02GiB(11.60%)  tps: 69,681  mfu: 37.45%
[rank0]:2025-02-05 16:11:32,287 - root - INFO - step:  8  loss:  7.4649  memory: 11.02GiB(11.60%)  tps: 71,963  mfu: 38.68%
[rank0]:2025-02-05 16:11:32,517 - root - INFO - step:  9  loss:  7.2956  memory: 11.02GiB(11.60%)  tps: 71,188  mfu: 38.26%
[rank0]:2025-02-05 16:11:32,749 - root - INFO - step: 10  loss:  7.1085  memory: 11.02GiB(11.60%)  tps: 70,748  mfu: 38.02%
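
For reference, a self-contained sketch comparing the two rounding approaches discussed above (an illustration, not the PR's exact code). Since scales are positive, the sign bit is 0, so the int32 arithmetic right shift behaves like a logical shift here:

import torch

def round_down_pow2_reference(scale: torch.Tensor) -> torch.Tensor:
    # exp2(floor(log2(x))) rounds each positive scale down to the nearest power of 2
    return torch.exp2(torch.floor(torch.log2(scale)))

def round_down_pow2_bitshift(scale: torch.Tensor) -> torch.Tensor:
    # zeroing the 23 float32 mantissa bits keeps only the exponent, i.e. the largest
    # power of 2 <= the value; int32 is used because uint32 lacks bitshift ops
    return ((scale.view(torch.int32) >> 23) << 23).view(torch.float32)

scales = torch.tensor([0.3, 1.0, 3.5, 1000.0])
print(round_down_pow2_reference(scales))  # tensor([  0.2500,   1.0000,   2.0000, 512.0000])
print(round_down_pow2_bitshift(scales))   # same values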
