Conversation
I left some initial comments just scanning through the code.
Force-pushed from 9f3074a to 3e4d745
```python
if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
    all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
```
Aside: It seems like a bug on AsyncCollectiveTensor if it does not make sure to call wait() when you torch.split it. 🤔
Hmm... I actually tried to add an optimization to AsyncCollectiveTensor a while ago: pytorch/pytorch#105240. The thought behind it was:
(1) if we have a collective in-flight, we only need to sync it if the data is about to be used in an operation;
(2) view ops generally don't need to access the data of their inputs, only the metadata.
So that PR "delays" syncs when view ops are called on an AsyncCollectiveTensor (e.g. aten.split()).
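Roughly, the delayed-sync idea looks like the toy sketch below. This is not the real AsyncCollectiveTensor implementation; FakeWork and LazyCollective are illustrative stand-ins for an in-flight collective handle and its tensor wrapper.

```python
import torch

class FakeWork:
    """Stand-in for an in-flight collective's work handle."""
    def __init__(self):
        self.completed = False

    def wait(self):
        # The real handle would block here until the collective finishes.
        self.completed = True

class LazyCollective:
    """Toy wrapper over a buffer that a collective is still writing into."""
    def __init__(self, work: FakeWork, buf: torch.Tensor):
        self.work, self.buf = work, buf

    def wait(self) -> torch.Tensor:
        # Sync before handing out real data.
        self.work.wait()
        return self.buf

    def split(self, size: int):
        # View op: only metadata is needed, so the sync stays deferred and the
        # resulting views keep carrying the same work handle.
        return [LazyCollective(self.work, v) for v in self.buf.split(size)]

    def add(self, other: torch.Tensor) -> torch.Tensor:
        # Data-accessing op: must sync first, otherwise the values may still be
        # garbage, which is what the manual isinstance(...) / wait() check in
        # the diff above guards against.
        return self.wait() + other
```

If the data-accessing path after a view op ever skips that sync, you get exactly the garbage values described below.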
Indeed, I was working through this with @wanchaol, and without the .wait() the all_reduced_amax_tensor was garbage...
I'm curious why the manual sync here is needed - does it paper over some underlying bug? (I hadn't seen your last comment.) I wonder if this is related to some internal AsyncCollectiveTensor issues 🤔
Let me know if either of you want to stare at it together (I think fixing the underlying bug will help other areas too)
Interestingly, I was not able to repro the issue: P1184812077. _foreach_copy_ and functional all-reduce without .wait() seemed to be fine, but I might not be repro-ing correctly.
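For reference, a minimal sketch of the kind of repro described above. The paste P1184812077 is internal, so the exact setup here (gloo backend, two ranks, tensor shapes, the "max" reduce op) is my assumption of what was tried.

```python
import os
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    amax = torch.full((4,), float(rank + 1))
    # Functional all-reduce returns an async wrapper instead of syncing eagerly.
    reduced = funcol.all_reduce(amax, "max", dist.group.WORLD)

    # Consume the result via _foreach_copy_ without calling .wait() first; the
    # question is whether the copied values are synced or garbage.
    dests = [torch.zeros(2), torch.zeros(2)]
    torch._foreach_copy_(dests, list(reduced.split(2)))
    print(f"rank {rank}: {dests}")  # expect all 2s with world_size == 2

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```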
hmmm let me try again then
You are right... however:
```
    return self.__get_result()
  File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CompilationError: at 10:30:def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, out_ptr3, out_ptr4, out_ptr5, out_ptr7, out_ptr8, out_ptr9, out_ptr11):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rindex = tl.arange(0, RBLOCK)[None, :]
                                      ^
NameError('RBLOCK is not defined')

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
Compiling the sync function now causes an inductor fault, so I will create another PR and investigate there. Otherwise, I think that's everything for this one.
Force-pushed from 3e4d745 to 4ca6ddf
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
LGTM!
PERFORMANCE NOTE:
    When you can, it is much more efficient to call get_float8_layers once at
    the beginning of the training loop and pass the result to this function.
    Because of how this interacts with torch.compile
For my understanding, how does passing float8_layers affect torch.compile?
Honestly, I wrote this at the beginning of the PR and now I am not sure it is still true. I found that get_attr was causing graph breaks, but that was a few iterations ago, and reading the graph_breaks logs can be a little hard to decipher sometimes.
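For context, the usage pattern the docstring recommends looks roughly like the sketch below. The function names come from this PR; the import path and the exact fp8_layers keyword of sync_float8_amax_and_scale_history are assumptions on my part, and the model/optimizer setup is left abstract.

```python
from float8_experimental.float8_linear_utils import (  # import path assumed
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)

def train(model, optimizer, dataloader):
    # Walk the module tree once, outside the loop, instead of on every step.
    fp8_layers = get_float8_layers(model)

    for batch in dataloader:
        loss = model(batch).sum()
        loss.backward()
        # Pass the precomputed list so the per-step sync doesn't re-scan the
        # model (and, per the discussion above, doesn't reintroduce the
        # getattr-related graph breaks under torch.compile).
        sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
        optimizer.step()
        optimizer.zero_grad()
```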
Summary
Update docs, make sure this is friendly to dynamo
Perf
Things I have done/changed
Commit 1
Removed the fp8_classes argument that would be passed in; this was to enable working with the separate TP/SP classes. Since we plan to have DTensor be the solution, I am removing it for now.
Commit 2
Commit 3
Things to do