Initial pipeline parallelism support #1008

Draft
wants to merge 49 commits into base: main
Conversation

@Alex-Vasile (Contributor) commented Feb 25, 2025

Goal: Introduce pipeline parallelism without requiring a change to the weight irpa files or the forward passes for the different layers (see PPFFN.forward in the example file).

Changes

  • ShardedTensor now explicitly stores the device each of its shards should live on in a .devices attribute. Previously the implicit convention was that shard i lived on device i.
  • A ShardedTensor can also be pinned to specific devices via a .pinned attribute (e.g. for weights), or left unpinned to signal that its shards may be moved if needed.
  • Binary operators call a helper function that checks whether either tensor needs to be transferred so that all shards end up on matching devices. E.g. ops.foo(t1 on devices [1,2,3], t2 pinned on devices [5,6,7]) would transfer the shards of t1 onto devices [5,6,7] before performing the operation (see the sketch after this list).
  • Several helper functions in ops can take a torch.Tensor and therefore won't know what devices to place the shards on, e.g. def replicate(input: AnyTensor, count: int) -> ShardedTensor:. I've added devices and pinned as extra parameters, with defaults that keep the current behaviour unchanged.
  • Added a wrapper to all functions imported from ops.signatures into ops to handle transfer and pinning automatically when called with ShardedTensor subclasses, making device parallelism work mostly invisibly without needing to modify the functions in sharded_impls.py.
    • One downside is that IDEs no longer appear to offer tab completion for ops.___
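
For concreteness, here is a minimal sketch of the placement model described above. It uses a simplified stand-in class rather than the real ShardedTensor API; the class name, fields, and the metadata-only "transfer" are illustrative assumptions, not the PR's actual code.

from dataclasses import dataclass, replace
from typing import Tuple

import torch


@dataclass(frozen=True)
class PlacedShards:
    """Simplified stand-in for a ShardedTensor: shard data plus placement metadata."""

    shards: Tuple[torch.Tensor, ...]
    devices: Tuple[int, ...]  # devices[i] is the logical device shard i should live on
    pinned: bool  # True for weights; False means the shards may be moved if needed


def transfer_if_needed(a: PlacedShards, b: PlacedShards) -> Tuple[PlacedShards, PlacedShards]:
    """Reconcile placement so both operands end up on matching devices.

    A real implementation would also issue the device-to-device copies; this
    sketch only updates the placement metadata to stay hardware-agnostic.
    """
    if a.devices == b.devices:
        return a, b
    if a.pinned and b.pinned:
        raise ValueError("Both operands are pinned to different devices.")
    if a.pinned:
        # b is free to move: bring it onto a's devices.
        return a, replace(b, devices=a.devices)
    # a is free to move (covers both "b pinned" and "neither pinned").
    return replace(a, devices=b.devices), b

With this rule, the ops.foo example above resolves as expected: t1 (unpinned, on devices [1, 2, 3]) is moved onto t2's devices [5, 6, 7] before the binary op runs, while two tensors pinned to different devices raise an error rather than being silently copied.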

Discussion points

  • Overall thoughts on approach?
  • Both devices and pinned are currently required parameters. Should either, especially pinned, be optional with a default?
  • Exactly how the different unary ops like ops.replicate should handle the extra parameters needs more thought.
  • Should is_deep_equal() consider .devices and .pinned? Enabling causes a few tests to fail, such as testReplicatedLhsShardedParallelDimRhs

TODOs

  • Better names
  • Change transfer_if_needed into a decorator to automatically perform the transfers
  • Add support for all ops
  • Add tests based on sharded tests
  • Change the signatures of several helper functions in ops to accept the new parameters, keeping the current behaviour as the default
  • Test if it works in eager mode, not just AOT

Comment on lines 667 to 674
b = torch.rand(3, 6, dtype=torch.float32)
shard_count = 3
unsharded_result = torch.matmul(a, b)
expected_result = ops.reshard_split(unsharded_result, dim=2, count=shard_count)  # TODO: How to know this should also not be pinned
b_sharded = ops.reshard_split(b, dim=1, count=shard_count)
a_sharded = ops.replicate(a, count=shard_count)
actual_result = ops.matmul(a_sharded, b_sharded)  # GOOD: Should NOT be pinned
assert ops.equal(expected_result, actual_result)
Contributor Author:

Adding .pinned and .devices to is_deep_equal, which ops.equal calls, causes this test to fail.

actual_result.pinned == False, which is correct. But expected_result.pinned would end up being True if any heuristic is used to convert the default None value into a bool: it's a concrete torch.Tensor and would be indistinguishable from one loaded from a file, i.e. a weight.

How should this be handled? Should is_deep_equal be looking at .devices and .pinned? I feel like it should, given its docstring and current behaviour.

Contributor:

Blah, I can see both cases. We may want to add an option to include placement comparisons. ops.equal is such a special case that I can see us wanting to compare both numerics and placement information. For our tests, though, we mostly just want to compare numerics, and the pinning comparison is mostly metadata rather than value. If this is the only case we can always just compare manually.
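
For illustration, one way the option mentioned above could look. This is a hedged sketch, not the existing is_deep_equal signature: the compare_placement flag and the ShardedTensor accessors (.shards, .shard_count, .devices, .pinned) are assumptions, and torch, ShardedTensor, and unbox_tensor are assumed to already be in scope as in sharded_impls.py.

def is_deep_equal(a: ShardedTensor, b: ShardedTensor, *, compare_placement: bool = False) -> bool:
    # Placement metadata is only compared when the caller opts in, so
    # numeric-only comparisons in tests keep passing.
    if compare_placement and (a.devices != b.devices or a.pinned != b.pinned):
        return False
    if a.shard_count != b.shard_count:
        return False
    return all(
        torch.equal(unbox_tensor(sa), unbox_tensor(sb))
        for sa, sb in zip(a.shards, b.shards)
    )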

@@ -33,3 +33,5 @@
# Comment this out to completely disable optimized quantized implementations.
from . import qconv_impls
from . import qlinear_impls

from .sharded_impls import transfer_if_needed # TODO: Hack just to get tests running, figure out properly later
Contributor Author:

TODO

@@ -30,6 +30,100 @@
from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim
from ..utils import longest_equal_range

def copy_w_new_shards_and_devices(tensor: ShardedTensor, new_shards: List[torch.Tensor], new_devices: Tuple[int]) -> ShardedTensor:
    # TODO: What does transform_globals need from this function?
Contributor Author:

TODO

Comment on lines 825 to 839
@replicate.trampoline
def _replicate_trampoline(
    d: SignatureDispatcher, input: AnyTensor, count: int, devices: Tuple[int] | None = None, pinned: bool | None = None
) -> ShardedTensor:
    tensors = (input,)
    if isinstance(input, torch.Tensor):
        devices = devices if devices is not None else tuple(range(count))
        pinned = pinned if pinned is not None else False
    else:
        # TODO: Is this correct? Will use data on `input`.
        assert devices is None
        assert pinned is None

    for override in d.find_overrides(tensors):
        result = override(input, count=count, devices=devices, pinned=pinned)
Contributor Author (@Alex-Vasile, Feb 27, 2025):

How do we handle these helper functions correctly? A plain torch.Tensor can be passed in, so we have no placement information for it.
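
For illustration, how the defaults in the diff above play out at the call site (the values follow from the diff, not from verified behaviour):

w = torch.rand(4, 4)

# Plain torch.Tensor input: the trampoline has to invent placement, so it
# defaults to devices (0, ..., count - 1) and pinned=False unless overridden.
r_default = ops.replicate(w, count=4)
r_pinned = ops.replicate(w, count=4, devices=(4, 5, 6, 7), pinned=True)

# ShardedTensor input: placement comes from the input itself, so the trampoline
# asserts that devices/pinned were not also passed explicitly.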

Comment on lines 191 to 192
# TODO: Tests needed
# 1. Pinned input for unary ops should return a pinned result.
Contributor Author:

TODO


Comment on lines 69 to 70
if hasattr(f, "override"):  # Needed for ops like .gelu_tanh_approximation
    wrapper.override = f.override
Contributor Author:

Should this be in here or applied on the output at the call site?

@@ -18,7 +18,75 @@

from . import _registry
from ..types.tensors import unbox_tensor
from .signatures import *

def import_and_wrap_signatures():
Contributor Author:

This ended up having to go in the __init__ file so that any import of ops would wrap the functions.

One problem with the approach: a subsequent run of from .signatures import * will override the wrapped versions.
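
A rough sketch of the wrapping approach described above; the helper name, the is_sharded predicate, and the transfer callable are assumptions, not the PR's actual code.

import functools
import types
from typing import Callable


def wrap_module_ops(
    module: types.ModuleType,
    namespace: dict,
    is_sharded: Callable,
    transfer: Callable,
) -> None:
    """Re-export every public callable from `module` into `namespace`, wrapped so
    sharded operands are moved onto matching devices before dispatch."""
    public = getattr(module, "__all__", [n for n in dir(module) if not n.startswith("_")])
    for name in public:
        f = getattr(module, name)
        if not callable(f):
            namespace[name] = f
            continue

        @functools.wraps(f)
        def wrapper(*args, __f=f, **kwargs):
            # Reconcile placement across sharded positional args; leave the rest untouched.
            idx = [i for i, a in enumerate(args) if is_sharded(a)]
            if len(idx) > 1:
                moved = transfer(*(args[i] for i in idx))
                args = list(args)
                for i, t in zip(idx, moved):
                    args[i] = t
            return __f(*args, **kwargs)

        if hasattr(f, "override"):  # keep dispatcher hooks such as .override reachable
            wrapper.override = f.override
        namespace[name] = wrapper

This would be invoked from ops/__init__.py as something like wrap_module_ops(signatures, globals(), lambda t: isinstance(t, ShardedTensor), transfer_if_needed); as noted above, a later from .signatures import * would clobber the wrapped names again.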

Contributor:

I am pretty certain we don't want to do this. Only sharded_impls should consider transfer information; this is specific to the sharded implementations. We only want to wrap the imports for the sharded cases.

Contributor:

Also, try to name this import_sharded_signatures or something to that effect.

@@ -31,6 +31,57 @@
from ..utils import longest_equal_range


def transfer_if_needed(*tensors: Tuple[ShardedTensor]) -> List[ShardedTensor]:
Contributor Author:

I would like to put this in __init__ as well, since this is not the only place it's being used. But the top-level type hints for this function would make the imports messy.

Contributor:

We want to keep this in the sharding code, as all of this wrapping behaviour should be sharding-specific.

Comment on lines +336 to +339
post = pre.T
assert all(d_pre == d_post for d_pre, d_post in zip(pre.devices, post.devices))
# TODO: post gets pinned since resulting ShardedTensor is made with torch.Tensor shards which are assumed to always be pinned
# assert post.pinned == pre.pinned
Contributor Author:

This may be an issue.

expected_result = ops.reshard_split(
    unsharded_result, dim=2, count=shard_count
)  # TODO: How to know this should also not be pinned
Contributor:

Remove the TODOs here and below
