Initial pipeline parallelism support #1008
base: main

Conversation

Force-pushed from 5e54f75 to b549558
sharktank/tests/ops/sharded_test.py (outdated)

 b = torch.rand(3, 6, dtype=torch.float32)
 shard_count = 3
 unsharded_result = torch.matmul(a, b)
-expected_result = ops.reshard_split(unsharded_result, dim=2, count=shard_count)
+expected_result = ops.reshard_split(unsharded_result, dim=2, count=shard_count)  # TODO: How to know this should also not be pinned
 b_sharded = ops.reshard_split(b, dim=1, count=shard_count)
 a_sharded = ops.replicate(a, count=shard_count)
-actual_result = ops.matmul(a_sharded, b_sharded)
+actual_result = ops.matmul(a_sharded, b_sharded)  # GOOD: Should NOT be pinned
 assert ops.equal(expected_result, actual_result)
Adding `.pinned` and `.devices` to `is_deep_equal`, which `ops.equal` calls, causes this test to fail. `actual_result.pinned == False`, which is correct. But `expected_result.pinned` would end up being `True` if any heuristic is used to convert the default `None` value into a bool: it is a concrete `torch.Tensor` and would be indistinguishable from one loaded from a file, i.e. a weight.

How should this be handled? Should `is_deep_equal` be looking at `.devices` and `.pinned`? I feel like it should, given its docstring and current behaviour.
Blah, I can see both cases. We may want to add an option to include pinning in the comparison. `ops.equal` is such a special case that I can see us wanting to compare both numerics and placement information. For our tests, though, we mostly just want to compare numerics, and pinning is metadata rather than value. If this is the only case, we can always just compare manually.
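For what it's worth, a minimal sketch of the opt-in idea, using a stand-in dataclass rather than sharktank's real `ShardedTensor`; the `compare_placement` flag and the field layout are assumptions, not the existing `is_deep_equal` API:

```python
# Sketch only: FakeSharded, compare_placement, and the field names are
# hypothetical stand-ins, not sharktank's ShardedTensor or is_deep_equal.
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch


@dataclass
class FakeSharded:
    shards: List[torch.Tensor]
    devices: Tuple[int, ...]
    pinned: Optional[bool] = None  # None = "unknown", as in the test above


def deep_equal(a: FakeSharded, b: FakeSharded, *, compare_placement: bool = False) -> bool:
    """Always compare numerics; compare .devices/.pinned only when asked."""
    numerics_equal = len(a.shards) == len(b.shards) and all(
        torch.equal(sa, sb) for sa, sb in zip(a.shards, b.shards)
    )
    if not compare_placement:
        return numerics_equal
    return numerics_equal and a.devices == b.devices and a.pinned == b.pinned


x = FakeSharded([torch.ones(2)], devices=(0,), pinned=False)
y = FakeSharded([torch.ones(2)], devices=(0,), pinned=None)
assert deep_equal(x, y)                              # numerics-only, as most tests want
assert not deep_equal(x, y, compare_placement=True)  # placement-aware comparison
```

Tests that only care about numerics keep calling it with the default, while placement-sensitive checks opt in explicitly.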
sharktank/sharktank/ops/__init__.py (outdated)

@@ -33,3 +33,5 @@
# Comment this out to completely disable optimized quantized implementations.
from . import qconv_impls
from . import qlinear_impls

from .sharded_impls import transfer_if_needed  # TODO: Hack just to get tests running, figure out properly later
TODO
@@ -30,6 +30,100 @@
from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim
from ..utils import longest_equal_range

def copy_w_new_shards_and_devices(tensor: ShardedTensor, new_shards: List[torch.Tensor], new_devices: Tuple[int]) -> ShardedTensor:
    # TODO: What does transform_globals need from this function?
TODO
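Guessing at intent from the signature alone, one possible shape for this helper, written against a toy container rather than sharktank's `ShardedTensor` (the real class construction will differ):

```python
# Guesswork sketch: FakeSharded is a stand-in, not sharktank's ShardedTensor.
from dataclasses import dataclass, replace
from typing import List, Tuple

import torch


@dataclass(frozen=True)
class FakeSharded:
    shards: Tuple[torch.Tensor, ...]
    devices: Tuple[int, ...]
    pinned: bool = False


def copy_w_new_shards_and_devices(
    tensor: FakeSharded, new_shards: List[torch.Tensor], new_devices: Tuple[int, ...]
) -> FakeSharded:
    # Keep everything else about the tensor (here just `pinned`) and swap in
    # the new shards together with their device placement.
    return replace(tensor, shards=tuple(new_shards), devices=tuple(new_devices))
```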
 @replicate.trampoline
 def _replicate_trampoline(
-    d: SignatureDispatcher, input: AnyTensor, count: int
+    d: SignatureDispatcher, input: AnyTensor, count: int, devices: Tuple[int] | None = None, pinned: bool | None = None
 ) -> ShardedTensor:
     tensors = (input,)
+    if isinstance(input, torch.Tensor):
+        devices = devices if devices is not None else tuple(range(count))
+        pinned = pinned if pinned is not None else False
+    else:
+        # TODO: Is this correct? Will use data on `input`.
+        assert devices is None
+        assert pinned is None
+
     for override in d.find_overrides(tensors):
-        result = override(input, count=count)
+        result = override(input, count=count, devices=devices, pinned=pinned)
How do we handle these helper functions correctly? They can be passed a plain torch.Tensor, in which case we have no placement information at all.
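To make the trade-off concrete, a stand-alone restatement of the defaulting logic above; the helper name is made up, and the "inherit from the sharded input" branch is one possible answer to the TODO rather than what the PR currently does:

```python
# Hypothetical helper, not sharktank code: shows how placement defaults could
# be resolved when the trampoline may receive a plain torch.Tensor.
from typing import Optional, Tuple

import torch


def resolve_placement(
    input,  # torch.Tensor or anything exposing .devices / .pinned
    count: int,
    devices: Optional[Tuple[int, ...]] = None,
    pinned: Optional[bool] = None,
) -> Tuple[Tuple[int, ...], bool]:
    if isinstance(input, torch.Tensor):
        # A plain tensor carries no placement info, so fall back to defaults
        # that reproduce the old behaviour: shard i goes to device i, unpinned.
        return (
            devices if devices is not None else tuple(range(count)),
            pinned if pinned is not None else False,
        )
    # For an already-sharded input, one option is to inherit its own metadata
    # instead of asserting that the caller passed nothing.
    return (
        devices if devices is not None else tuple(input.devices),
        pinned if pinned is not None else bool(input.pinned),
    )


print(resolve_placement(torch.zeros(4), count=3))  # ((0, 1, 2), False)
```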
# TODO: Tests needed
# 1. Pinned input for unary ops should return a pinned result.
TODO
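A hedged sketch of the first test on that list, assuming the test module's usual `from sharktank import ops` import, the `pinned` keyword added to `ops.replicate` in this diff, and that transpose counts as one of the unary ops in question; the test name and exact assertions are placeholders:

```python
# Placeholder test sketch; API details here may not match the final PR.
import torch

from sharktank import ops


def testPinnedInputUnaryOpReturnsPinnedResult():
    x = torch.rand(4, 6, dtype=torch.float32)
    sharded = ops.replicate(x, count=2, pinned=True)
    assert sharded.pinned

    result = sharded.T  # unary op: no second operand to resolve placement against
    # Depending on how rebuilt shards get their pinned flag (see the transpose
    # discussion below), this propagation may need explicit handling.
    assert result.pinned == sharded.pinned
```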
if hasattr(f, "override"):  # Needed for ops like .gelu_tanh_approximation
    wrapper.override = f.override
Should this be in here or applied on the output at the call site?
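A toy comparison of the two options in that question; `signature_fn` and its `.override` attribute are stand-ins, not sharktank's dispatcher machinery:

```python
# Toy illustration only: signature_fn / .override are stand-ins.
def signature_fn(x):
    return x * 2


signature_fn.override = lambda *types: "registered"  # stand-in dispatch hook


def make_wrapper_with_override(f):
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    # Option A (as in the diff): forward the attribute inside the helper so
    # every caller gets it for free.
    if hasattr(f, "override"):
        wrapper.override = f.override
    return wrapper


def make_plain_wrapper(f):
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper


wrapped_a = make_wrapper_with_override(signature_fn)

# Option B: keep the helper generic and re-attach at the call site instead,
# i.e. wherever the wrapped function is installed into the ops namespace.
wrapped_b = make_plain_wrapper(signature_fn)
wrapped_b.override = signature_fn.override

assert wrapped_a.override is wrapped_b.override is signature_fn.override
```

Option A keeps call sites simple at the cost of the helper knowing about dispatcher details; Option B keeps the helper generic but every call site has to remember the re-attach.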
… and tensors and collections of them in kwargs.
@@ -18,7 +18,75 @@
from . import _registry
from ..types.tensors import unbox_tensor
from .signatures import *

def import_and_wrap_signatures():
This ended up having to go in the `__init__` file so that any import of `ops` would wrap the functions. One problem with the approach: a subsequent `from .signatures import *` will override the wrapped versions.
I am pretty certain we don't want to do this. Only `sharded_impls` should consider transfer information; this is specific to the sharded impls. We only want to wrap the imports for the sharded cases.
Also, try to name this `import_sharded_signatures` or something to that effect.
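A rough sketch of that suggestion; the module handling, `__all__` lookup, and `names_to_wrap` parameter are assumptions about how a narrower `import_sharded_signatures` could look, not the PR's actual code:

```python
# Sketch only: module layout and selection mechanism are assumptions.
import types
from typing import Callable, Dict, Iterable


def wrap_for_sharding(f: Callable) -> Callable:
    def wrapper(*args, **kwargs):
        # Placement handling (e.g. transfer_if_needed on ShardedTensor args)
        # would run here before delegating to the original signature.
        return f(*args, **kwargs)

    wrapper.__name__ = f.__name__
    if hasattr(f, "override"):  # keep the dispatcher registration hook
        wrapper.override = f.override
    return wrapper


def import_sharded_signatures(
    signatures: types.ModuleType, namespace: Dict, names_to_wrap: Iterable[str]
) -> None:
    """Copy signature functions into `namespace`, wrapping only the ones the
    sharded implementations need placement handling for."""
    to_wrap = set(names_to_wrap)
    for name in getattr(signatures, "__all__", []):
        fn = getattr(signatures, name)
        namespace[name] = wrap_for_sharding(fn) if name in to_wrap else fn
```

Called from `ops/__init__.py` in place of a bare `from .signatures import *`, this would also sidestep the concern above about a later star-import clobbering the wrapped versions.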
@@ -31,6 +31,57 @@
from ..utils import longest_equal_range

def transfer_if_needed(*tensors: Tuple[ShardedTensor]) -> List[ShardedTensor]:
I would like to put this in `__init__` as well, since this is not the only place it's being used. But the top-level type hints for this function will make the imports messy.
We want to keep this in sharding, as all of this wrapping behavior should be sharding-specific.
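For readers following along, a toy model of what `transfer_if_needed` is described as doing in this thread; `FakeSharded` and the "transfer" (just rewriting a device tuple) are stand-ins, not sharktank's types or real device movement:

```python
# Toy model only; not sharktank's ShardedTensor or actual device transfers.
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass(frozen=True)
class FakeSharded:
    devices: Tuple[int, ...]
    pinned: bool = False


def transfer_if_needed(*tensors: FakeSharded) -> List[FakeSharded]:
    pinned_placements = {t.devices for t in tensors if t.pinned}
    if len(pinned_placements) > 1:
        # Two pinned operands on different devices cannot be reconciled.
        raise ValueError(f"Conflicting pinned placements: {pinned_placements}")
    if not pinned_placements:
        return list(tensors)  # nothing is anchored; leave everything alone
    target = next(iter(pinned_placements))
    # Move every unpinned operand onto the pinned operand's devices.
    return [
        t if t.pinned or t.devices == target else replace(t, devices=target)
        for t in tensors
    ]


t1 = FakeSharded(devices=(1, 2, 3))
t2 = FakeSharded(devices=(5, 6, 7), pinned=True)
moved_t1, moved_t2 = transfer_if_needed(t1, t2)
assert moved_t1.devices == (5, 6, 7) and moved_t2 is t2
```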
post = pre.T
assert all(d_pre == d_post for d_pre, d_post in zip(pre.devices, post.devices))
# TODO: post gets pinned since resulting ShardedTensor is made with torch.Tensor shards which are assumed to always be pinned
# assert post.pinned == pre.pinned
This may be an issue
-expected_result = ops.reshard_split(unsharded_result, dim=2, count=shard_count)
+expected_result = ops.reshard_split(
+    unsharded_result, dim=2, count=shard_count
+)  # TODO: How to know this should also not be pinned
Remove the TODOs here and below
Goal: Introduce pipeline parallelism without requiring a change to the weight irpa files or the forward passes for the different layers (see PPFFN.forward in the example file).

Changes
- ShardedTensors now carry a `.devices` attribute. Previously the implicit convention was that shard_i lived on device_i.
- ShardedTensors now carry a `.pinned` attribute marking tensors that must stay on their devices; e.g. `ops.foo(t1 on devs [1,2,3], t2 pinned on devs [5,6,7])` would transfer the shards of t1 onto devices [5,6,7] before performing the operation.
- Some ops can take in a `torch.Tensor` and therefore won't know what devices to place them on, e.g. `def replicate(input: AnyTensor, count: int) -> ShardedTensor:`. I've added `devices` and `pinned` as extra parameters and used defaults to keep the current behaviour unchanged.
- Wrapped the import of `ops.signatures` into `ops` to handle transfer and pinning automatically when called with `ShardedTensor` subclasses, making device parallelism work, mostly, invisibly without needing to modify the functions in `sharded_impls.py` backing `ops.___`.
Discussion points
- `devices` and `pinned` are required parameters. Should either, especially `pinned`, be optional and have defaults?
- How `ops.replicate` should handle the extra parameters needs more thought.
- Should `is_deep_equal()` consider `.devices` and `.pinned`? Enabling this causes a few tests to fail, such as `testReplicatedLhsShardedParallelDimRhs`.
TODOs
- Turn `transfer_if_needed` into a decorator to automatically perform the transfers (see the sketch below).
- For `ops`: change signatures to accept the new parameters, adding the current behavior as the default.
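A rough sketch of the first TODO, a decorator form of the transfer step; `transfer_fn` stands in for `transfer_if_needed`, and the duck-typed `.devices`/`.pinned` check is an assumption rather than how sharktank detects sharded arguments:

```python
# Sketch only: not the PR's code; transfer_fn stands in for transfer_if_needed.
import functools
from typing import Callable, Sequence


def handles_transfers(transfer_fn: Callable[..., Sequence]) -> Callable:
    """Build a decorator that reconciles ShardedTensor-like positional args
    (anything exposing .devices and .pinned) before the wrapped op runs."""

    def decorator(op: Callable) -> Callable:
        @functools.wraps(op)
        def wrapper(*args, **kwargs):
            is_shardlike = [hasattr(a, "devices") and hasattr(a, "pinned") for a in args]
            if any(is_shardlike):
                # Reconcile only the sharded-looking operands, then splice the
                # (possibly moved) versions back into their original positions.
                moved = iter(transfer_fn(*(a for a, s in zip(args, is_shardlike) if s)))
                args = tuple(next(moved) if s else a for a, s in zip(args, is_shardlike))
            return op(*args, **kwargs)

        return wrapper

    return decorator
```

Usage would look like decorating each sharded implementation, e.g. `@handles_transfers(transfer_if_needed)`, instead of calling the helper manually inside every op.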