This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Updates to sync_float_amax_history #211

Closed
wants to merge 14 commits into from

Conversation

Contributor

@drisspg drisspg commented Feb 12, 2024

Summary

Update docs, make sure this is friendly to dynamo

Perf

PyTorch Version | float8 Version | Eager Iterations per Second | Compile
--------------- | -------------- | --------------------------- | ---------
Nightly         | Main           | 1.15 it/s                   | 2.10 it/s
Nightly         | This PR        | 1.16 it/s                   | 2.27 it/s

Trace   | Compile URL                | Eager
------- | -------------------------- | --------------------------
This PR | https://fburl.com/753ztao4 | https://fburl.com/34yftzao
Main    | https://fburl.com/a0gh9iof | https://fburl.com/u9c4ilmp

Things I have done/changed

Commit 1

  • We previously had an fp8_classes argument that would be passed in; this was to enable working with the separate TP/SP classes. Since we plan to have DTensor be the solution, I am removing it for now.
  • I put the child.amax_and_scale_synced module mutation under the enable_amax_init flag; the unconditional module mutation seemed to be causing graph breaks (see the sketch below).
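
A minimal sketch of that gating (illustrative only; where enable_amax_init actually lives, module attribute vs. config, is an assumption here):

    def mark_layers_synced(fp8_layers, enable_amax_init: bool):
        # fp8_layers: list of Float8Linear modules (e.g. from get_float8_layers(model)).
        # Only mutate the module attribute under the flag, so the common path
        # avoids the setattr that was triggering graph breaks.
        if not enable_amax_init:
            return
        for child in fp8_layers:
            child.amax_and_scale_synced = True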

Commit 2

  • We previously had all the history buffers be scalar tensors. This meant that to construct the combined tensor we needed to call torch.Tensor, which was causing an HtoD sync under torch.compile. I needed to add a single dimension of size 1 and pipe that through all the places (see the sketch below).
  • Note that this meant we needed to update to_hp to cast back to the original precision, since multiplying by the scale upcasts the _data tensor.
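
A minimal sketch of the difference, with illustrative buffers (not the actual code in this PR):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Scalar (0-d) buffers: building the combined tensor reads each value on the
    # host first, which forces a device sync under torch.compile.
    a, b, c = (torch.rand((), device=device) for _ in range(3))
    combined_with_sync = torch.tensor([a.item(), b.item(), c.item()], device=device)

    # Size-(1,) buffers: the combined tensor can be built entirely on device.
    a1, b1, c1 = (torch.rand(1, device=device) for _ in range(3))
    combined_on_device = torch.cat([a1, b1, c1])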

Commit 3

  • Rewrote the sync function to do the torch.roll() on all the histories at once (see the sketch below). Side note: not sure if this is more expensive than using clones, since we really don't care about the wrapping behavior.
  • Same for generating the new scales from the grouped histories
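
A minimal sketch of the grouped update; the history layout, the max-over-history recipe, and the e4m3 max value of 448.0 are illustrative assumptions, not the exact code in this PR:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Grouped amax histories: one row per buffer, newest entry kept in column 0.
    grouped_histories = torch.rand(4, 16, device=device)
    new_amaxes = torch.rand(4, 1, device=device)

    # A single roll over the whole group replaces a per-buffer loop.
    rolled = torch.roll(grouped_histories, shifts=1, dims=1)
    rolled[:, 0:1] = new_amaxes

    # Derive all the new scales from the grouped histories in one shot.
    amax = torch.max(rolled, dim=1, keepdim=True).values
    new_scales = 448.0 / torch.clamp(amax, min=1e-12)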
Things to do
  • There are still two loops, and those are for mutating the actual module values; not sure if there is another way around this (see the sketch after this list).
  • Going to try the functional collectives
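
A sketch of the two remaining write-back loops, continuing the illustrative names from the Commit 3 sketch above; the Float8Linear buffer names are assumptions, not necessarily the real ones:

    # rolled and new_scales come from the grouped update sketched under Commit 3;
    # fp8_layers is the list of Float8Linear modules.
    for layer, new_history in zip(fp8_layers, torch.unbind(rolled)):
        layer.fp8_amax_history_x.copy_(new_history)  # loop 1: write histories back

    for layer, new_scale in zip(fp8_layers, torch.unbind(new_scales)):
        layer.fp8_scale_x.copy_(new_scale)  # loop 2: write scales back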

@facebook-github-bot facebook-github-bot added the CLA Signed label Feb 12, 2024

@awgu awgu left a comment

I left some initial comments just scanning through the code.

@drisspg drisspg requested review from awgu, bdhirsh and y-sq February 14, 2024 02:22
@drisspg drisspg force-pushed the make_syncing_more_compile_friendly branch from 9f3074a to 3e4d745 on February 14, 2024 17:39
)
# Functional collectives can return an AsyncCollectiveTensor; sync it explicitly
# before the reduced values are read.
if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
    all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

Aside: It seems like a bug on AsyncCollectiveTensor if it does not make sure to call wait() when you torch.split it. 🤔

Contributor

Hmm...

I actually tried to add an optimization to AsyncCollectiveTensor a while ago: pytorch/pytorch#105240. The thought behind it was:

(1) if we have a collective in-flight, we only need to sync the collective if the data needs to be used in an operation

(2) view ops generally don't need to access the data of their inputs, only the metadata.

So that PR "delays" syncs when there are view ops called on AsyncCollectiveTensor (e.g. aten.split()).

Contributor Author

Indeed, I was working through this with @wanchaol, and without the wait() the all_reduced_amax_tensor was garbage...

Contributor

@bdhirsh bdhirsh Feb 14, 2024

I'm curious why the manual sync here is needed - does it paper over some underlying bug? Didn't see your last comment. I wonder if this is related to some internal AsyncCollectiveTensor issues 🤔

Let me know if either of you want to stare at it together (I think fixing the underlying bug will help other areas too)

Interestingly, I was not able to repro the issue: P1184812077

_foreach_copy_ and functional all-reduce without .wait() seemed to be fine, but I might not be repro-ing correctly.
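
For reference, a rough sketch of the kind of repro being discussed; the actual paste (P1184812077) is internal, so this is only an assumed shape of it:

    import torch
    import torch.distributed as dist
    from torch.distributed._functional_collectives import all_reduce

    def sync_amaxes(amax_buffers):
        stacked = torch.cat(amax_buffers)
        reduced = all_reduce(stacked, "MAX", list(range(dist.get_world_size())))
        # Intentionally no explicit .wait(): the question is whether _foreach_copy_
        # still sees the reduced values.
        torch._foreach_copy_(amax_buffers, torch.split(reduced, 1))

    compiled_sync = torch.compile(sync_amaxes)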

Contributor Author

hmmm let me try again then

Contributor Author

You are right... however:

    return self.__get_result()
  File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CompilationError: at 10:30:def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, out_ptr3, out_ptr4, out_ptr5, out_ptr7, out_ptr8, out_ptr9, out_ptr11):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rindex = tl.arange(0, RBLOCK)[None, :]
                              ^
NameError('RBLOCK is not defined')

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

compiling the sync function now causes an inductor failure... so I will create another PR and investigate there, but otherwise I think that's everything for this one

@drisspg drisspg force-pushed the make_syncing_more_compile_friendly branch from 3e4d745 to 4ca6ddf on February 14, 2024 20:35
@drisspg drisspg requested a review from awgu February 14, 2024 21:24
@facebook-github-bot

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@awgu awgu left a comment

LGTM!

PERFORMANCE NOTE:
When you can, it is much more efficient to call get_float8_layers once at
the beginning of the training loop and pass the result to this function,
because of how this interacts with torch.compile.
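
A sketch of the usage the note suggests; the sync function name (sync_float8_amax_and_scale_history) and the shape of the training loop are assumptions here:

    fp8_layers = get_float8_layers(model)  # walk the module tree once, up front
    for batch in dataloader:
        loss = model(batch).sum()
        loss.backward()
        sync_float8_amax_and_scale_history(model, fp8_layers)  # reuse the cached list
        optimizer.step()
        optimizer.zero_grad()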

For my understanding, how does passing float8_layers affect torch.compile?

Contributor Author

Honestly, I wrote this at the beginning of the PR and now I am not sure that it is still true, but I found the get_attr was causing graph breaks; that was a few iterations ago, and reading the graph_breaks logs can be a little hard to decipher sometimes.

@facebook-github-bot

@drisspg merged this pull request in 956195b.
