Initial pipeline parallelism support #1008
base: main
Changes from 47 commits
@@ -0,0 +1,124 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Example program to export a sharded FFN network like what is found in
a typical transformer layer. This is used for developing and testing various
tooling flows with a scaled-down example.

Generate MLIR and a randomly initialized IRPA file with:

    python -m sharktank.examples.sharding.export_ffn_net \
        --output-irpa-file=/tmp/ffn.irpa /tmp/ffn.mlir
"""

import math
from typing import Tuple

import torch
import torch.nn as nn

from ...layers import *
from ... import ops
from ...types import *

from iree.turbine.aot import DeviceAffinity, DeviceTensorTrait, export

def create_theta(
    dim: int, shard_count: int, num_layers: int, save_path
):
    split_size = dim // shard_count
    weights = []
    for layer in range(num_layers):
        _weight = torch.rand(dim, dim, dtype=torch.float16) / math.sqrt(dim)
        weights.append(
            SplitPrimitiveTensor(
                name=f"w.{layer}",
                shard_dim=1,
                ts=_weight.split(split_size, dim=1)
            )
        )
    ds = Dataset({}, Theta(weights))
    ds.save(save_path)

def pipeline_parallelize_theta(
    theta: Theta,
    pp_count: int
) -> Theta:
    num_layers = len(theta.tensor("w"))
    shard_count = theta.tensor("w", '0').shard_count
    for layer in list(theta.tensor("w").keys()):
        weight = theta.tensor("w", layer)
        pp_group = int(int(layer) * pp_count / num_layers)
        zero_4_group = shard_count * pp_group
        devices = tuple(i + zero_4_group for i in range(shard_count))
        shards = weight.shards
        for i, shard in enumerate(shards):
            DeviceTensorTrait(devices[i]).set(shard._data)
        new_weight = SplitPrimitiveTensor(
            shard_dim=weight.shard_dim,
            ts=shards,
            name=weight.name,
            devices=devices,
            pinned=True
        )
        theta.tensor("w")[layer] = new_weight
    return theta

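As a side note, the layer-to-device mapping this produces is easy to sanity-check with plain Python. The sketch below is illustrative arithmetic only, not library code; the constants mirror the configuration used in `main()` further down (40 layers, `pp_count=4`, `shard_count=2`):

```python
# Illustrative only: reproduces the device assignment computed by
# pipeline_parallelize_theta for num_layers=40, pp_count=4, shard_count=2.
num_layers, pp_count, shard_count = 40, 4, 2

for layer in (0, 9, 10, 19, 20, 29, 30, 39):
    pp_group = int(layer * pp_count / num_layers)       # pipeline stage: 0..3
    zero_4_group = shard_count * pp_group               # first device id of that stage
    devices = tuple(i + zero_4_group for i in range(shard_count))
    print(f"layer {layer:2d} -> stage {pp_group} on devices {devices}")
# layer  0 -> stage 0 on devices (0, 1)
# layer  9 -> stage 0 on devices (0, 1)
# layer 10 -> stage 1 on devices (2, 3)
# ...
# layer 39 -> stage 3 on devices (6, 7)
```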
class PPFFN(ThetaLayer):
    def forward(self, x: torch.Tensor):
        num_layers = len(self.theta.tensor("w"))
        shard_count = self.theta.tensor("w", '0').shard_count

        x = ReplicatedTensor(ts=x, shard_count=shard_count)
        for layer in range(num_layers):
            weight: SplitPrimitiveTensor = self.theta.tensor("w", str(layer))
            x: ReplicatedTensor = ops.all_reduce(ops.linear(x, weight))

        return x

def main(raw_args=None):
    from ...utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "output_file",
        type=str,
        nargs="?",
        default="-",
        help="Output file to save MLIR to",
    )
    cli.add_output_dataset_options(parser)
    args = cli.parse(parser, args=raw_args)

    bs = 16
    sl = 128
    primary_dim = 128 * 2**5
    shard_count = 2
    num_layers = 40
    create_theta(primary_dim, shard_count, num_layers, save_path=args.output_irpa_file)

    pp_count = 4
    ds = Dataset.load(args.output_irpa_file)
    root_theta = pipeline_parallelize_theta(ds.root_theta, pp_count)

    mdl = PPFFN(root_theta)

    example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16)
    ep = torch.export.export(mdl, (example_arg,))  # , strict=False)
    cm = export(ep, arg_device={0: DeviceAffinity(0)})

    if args.output_file == "-":
        print(cm.mlir_module)
    else:
        with open(args.output_file, "wt") as f:
            f.write(str(cm.mlir_module))


if __name__ == "__main__":
    main()
@@ -18,7 +18,75 @@
from . import _registry
from ..types.tensors import unbox_tensor
from .signatures import *

def import_and_wrap_signatures():
    """
    Import the signatures from .signatures and wrap them so the shards of their input tensors are placed on the same devices.
    For unary ops, also pins the result if the input is pinned (e.g. for transpose).
    """
    def transfer_n_pin(f):
        """
        Create a wrapper for each operation.
        """
        from ..types import ShardedTensor
        from typing import List, Tuple, Dict, Any
        def unwrap_args(items: Tuple | Dict[str, Any]) -> Tuple[List[int | List[int]], List[ShardedTensor]]:
            t_i, t_vals = [], []
            for i, arg in enumerate(items):
                if isinstance(arg, ShardedTensor):
                    t_i.append(i)
                    t_vals.append(arg)
                elif isinstance(arg, list) and all(isinstance(val, ShardedTensor) for val in arg):
                    t_i.append([i] * len(arg))
                    t_vals.extend(arg)
            return t_i, t_vals

        def rewrap_args(items: Tuple | Dict, t_i: List[int | List[int]], t_vals: List[ShardedTensor]) -> Tuple[Tuple, Dict[str, Any]]:
            i_lookup = list(range(len(items))) if isinstance(items, tuple) else list(items.keys())
            new_items = list(items) if isinstance(items, tuple) else dict(items)

            for i in t_i:
                if isinstance(i, int):
                    new_items[i_lookup[i]] = t_vals.pop(0)
                else:  # List[int]
                    _popped_vals = [t_vals.pop(0) for _ in range(len(i))]
                    new_items[i_lookup[i[0]]] = items[i_lookup[i[0]]].__class__(_popped_vals)

            if isinstance(new_items, list):
                new_items = tuple(new_items)
            return new_items

        def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
            t_i_args, t_vals_args = unwrap_args(args)
            t_i_kwargs, t_vals_kwargs = unwrap_args(list(kwargs.values()))
            t_vals = t_vals_args + t_vals_kwargs

            t_vals = transfer_if_needed(*t_vals)

            args = rewrap_args(args, t_i_args, t_vals[:len(t_vals_args)])
            kwargs = rewrap_args(kwargs, t_i_kwargs, t_vals[len(t_vals_args):])
            res = f(*args, **kwargs)
            if isinstance(res, ShardedTensor) and len(t_vals) > 0:
                pinned = (res.pinned
                          or (len(t_vals) == 1 and t_vals[0].pinned)
                          or f.__name__ == 'reshard_like')  # TODO: How to handle this case properly
                res = res.clone(devices=t_vals[0].devices, pinned=pinned)
            return res

        if hasattr(f, "override"):  # Needed for ops like gelu_tanh_approximation
            func_wrapper.override = f.override
        return func_wrapper

    do_not_wrap = {'all_gather', 'all_reduce', 'replicate', 'index_copy_', 'index_put_'}

    from . import signatures
    for func_name in signatures.__all__:
        func = getattr(signatures, func_name)
        if func_name not in do_not_wrap:
            func = transfer_n_pin(func)
        globals()[func_name] = func


import_and_wrap_signatures()
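Taken together, the wrapper's contract is: sharded arguments are first co-located via `transfer_if_needed`, and a sharded result is then cloned onto those common devices, keeping its pin when the single sharded input was pinned (or for `reshard_like`). Below is a standalone sketch of just that result-placement rule; `FakePlacement` and `result_placement` are made-up stand-ins for illustration, not the real `ShardedTensor` API:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class FakePlacement:
    """Stand-in for ShardedTensor placement metadata; illustrative only."""
    devices: Tuple[int, ...]
    pinned: bool = False

def result_placement(result: FakePlacement, sharded_inputs: Tuple[FakePlacement, ...],
                     op_name: str) -> FakePlacement:
    # Mirrors the tail of func_wrapper: after dispatch, a sharded result is cloned
    # onto the (now common) devices of the sharded inputs; it stays pinned if it
    # already was, if the only sharded input was pinned, or for reshard_like.
    if not sharded_inputs:
        return result
    pinned = (result.pinned
              or (len(sharded_inputs) == 1 and sharded_inputs[0].pinned)
              or op_name == "reshard_like")
    return FakePlacement(devices=sharded_inputs[0].devices, pinned=pinned)

# A transpose (unary op) of a pinned weight on devices (2, 3) keeps placement and pin:
weight = FakePlacement(devices=(2, 3), pinned=True)
print(result_placement(FakePlacement(devices=(2, 3)), (weight,), op_name="transpose"))
# FakePlacement(devices=(2, 3), pinned=True)
```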

from .shape import *

# Ensure that implementations are registered.
@@ -33,3 +101,5 @@
# Comment this out to completely disable optimized quantized implementations.
from . import qconv_impls
from . import qlinear_impls

from .sharded_impls import transfer_if_needed  # TODO: Hack just to get tests running, figure out properly later
TODO
@@ -6,7 +6,7 @@
import torch
from torch import Tensor
from typing import List, Optional, Sequence, Union, Any, Tuple
from typing import List, Optional, Sequence, Union, Any, Tuple, Dict
import itertools
from numbers import Number
import math
@@ -31,6 +31,57 @@
from ..utils import longest_equal_range

def transfer_if_needed(*tensors: Tuple[ShardedTensor]) -> List[ShardedTensor]:
I would like to put this in __init__ as well, since this is not the only place it's being used. But the top-level type hints for this function will make the imports messy.

We want to keep this in sharding, as all of this wrapping behavior should be sharding-specific.
    """
    If at least 2 tensors are passed in, the shards of all unpinned tensors are transferred to be on the same devices as those of the pinned tensors.
    """
    def tensor_w_shards_moved(tensor: ShardedTensor, new_devices: Tuple[int]) -> ShardedTensor:
        """
        Create a copy of the passed-in tensor, but with shards now placed on the new devices.
        """
        new_shards = tuple(
            (
                transfer_to_logical_device(shard, new_devices[j])
                if new_devices[j] != tensor.devices[j]
                else barrier_on_logical_device(shard, new_devices[j])
            )
            for j, shard in enumerate(tensor.shards)
        )
        return tensor.clone(ts=new_shards, devices=new_devices)

    if len(tensors) <= 1:
        return list(tensors)
    assert all(isinstance(tensor, ShardedTensor) for tensor in tensors)

    # Check if all tensors are on the same devices.
    all_on_same_devices = True
    for tensor in tensors[1:]:
        if any(d0 != d for d0, d in zip(tensors[0].devices, tensor.devices)):
            all_on_same_devices = False
            break
    if all_on_same_devices:
        return list(tensors)

    pinned_tensors = [tensor for tensor in tensors if tensor.pinned]
    if len(pinned_tensors) == 0:
        raise ValueError("Tensors are on different devices, but none are pinned. Don't know which devices to transfer to.")

    pinned_devices = pinned_tensors[0].devices
    for pinned_tensor in pinned_tensors[1:]:
        if any(d0 != d for d0, d in zip(pinned_devices, pinned_tensor.devices)):
            raise ValueError("All pinned tensors must be on the same devices.")

    # Move all non-pinned tensors to the same devices as the pinned ones.
    new_tensors = [
        (
            tensor
            if tensor.pinned
            else tensor_w_shards_moved(tensor, pinned_devices)
        )
        for tensor in tensors
    ]
    return new_tensors
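In short, the resolution rule is: same devices, nothing moves; different devices with at least one pinned tensor, unpinned tensors follow the pinned placement; different devices and nothing pinned, error. A minimal runnable sketch of that decision follows; `FakeSharded` and `resolve_devices` are illustrative stand-ins and omit the shard-level transfer/barrier calls:

```python
from dataclasses import dataclass, replace
from typing import List, Tuple

@dataclass
class FakeSharded:
    """Placement metadata stand-in; the real function also rebuilds the shards
    with transfer_to_logical_device / barrier_on_logical_device."""
    devices: Tuple[int, ...]
    pinned: bool = False

def resolve_devices(tensors: List[FakeSharded]) -> List[FakeSharded]:
    if len(tensors) <= 1 or all(t.devices == tensors[0].devices for t in tensors):
        return tensors  # already co-located: nothing to move
    pinned = [t for t in tensors if t.pinned]
    if not pinned:
        raise ValueError("Tensors are on different devices, but none are pinned.")
    target = pinned[0].devices
    if any(t.devices != target for t in pinned):
        raise ValueError("All pinned tensors must be on the same devices.")
    return [t if t.pinned else replace(t, devices=target) for t in tensors]

# An unpinned activation on (0, 1) meeting a pinned stage-1 weight on (2, 3)
# (the first pipeline boundary in the FFN example) follows the weight:
print(resolve_devices([FakeSharded((0, 1)), FakeSharded((2, 3), pinned=True)]))
# [FakeSharded(devices=(2, 3), pinned=False), FakeSharded(devices=(2, 3), pinned=True)]
```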

@all_gather.override(SplitPrimitiveTensor)
def all_gather_split(
    input: SplitPrimitiveTensor, *, dim: int | None

@@ -43,17 +94,17 @@ def all_gather_split(
        cat(
            [
                (
                    barrier_on_logical_device(shard, i)
                    barrier_on_logical_device(shard, input.devices[i])
                    if i == j
                    else transfer_to_logical_device(shard, i)
                    else transfer_to_logical_device(shard, input.devices[i])
                )
                for j, shard in enumerate(input.shards)
            ],
            dim=dim,
        )
        for i in range(input.shard_count)
    ]
    return ReplicatedTensor(ts=shards)
    return ReplicatedTensor(ts=shards, devices=input.devices)


@all_reduce.override(AllOfType(SplitPrimitiveTensor, UnreducedTensor))
@@ -68,16 +119,16 @@ def all_reduce_split_or_unreduced(
            lambda x, y: elementwise(torch.add, x, y),
            [
                (
                    barrier_on_logical_device(shard, i)
                    barrier_on_logical_device(shard, input.devices[i])
                    if i == j
                    else transfer_to_logical_device(shard, i)
                    else transfer_to_logical_device(shard, input.devices[i])
                )
                for j, shard in enumerate(input.shards)
            ],
        )
        for i in range(input.shard_count)
    ]
    return ReplicatedTensor(ts=shards)
    return ReplicatedTensor(ts=shards, devices=input.devices)


@cat.override(AllOfType(ReplicatedTensor))
@@ -102,7 +153,6 @@ def cat_split(
            for t in tensors
        ]
    )

    shard_dim = tensors[0].shard_dim
    shard_count = tensors[0].shard_count
    if dim != shard_dim:
@@ -894,28 +944,28 @@ def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedT


@replicate.override(ReplicatedTensor)
def replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor:
def replicate_replicated(input: ReplicatedTensor, *, count: int, devices: None, pinned: None) -> ReplicatedTensor:
    if input.shard_count != count:
        raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
    return input


@replicate.override(SplitPrimitiveTensor)
def replicate_split(input: SplitPrimitiveTensor, *, count: int) -> ReplicatedTensor:
def replicate_split(input: SplitPrimitiveTensor, *, count: int, devices: None, pinned: None) -> ReplicatedTensor:
    if input.shard_count != count:
        raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
    return all_gather(input)


@replicate.override(UnreducedTensor)
def replicate_unreduced(input: UnreducedTensor, *, count: int) -> ReplicatedTensor:
def replicate_unreduced(input: UnreducedTensor, *, count: int, devices: None, pinned: None) -> ReplicatedTensor:
    if input.shard_count != count:
        raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
    return all_reduce(input)


@replicate.override(Tensor)
def replicate_unsharded(input, *, count: int) -> ReplicatedTensor:
def replicate_unsharded(input, *, count: int, devices: Tuple[int], pinned: bool) -> ReplicatedTensor:
    torch_input = unbox_tensor(input)
    # If we have a torch input replicating we can assume we need to transfer:
    torch_inputs = [transfer_to_logical_device(torch_input, i) for i in range(count)]
@@ -1138,9 +1188,7 @@ def softmax_split(
        dim is not None and dim != tensor.shard_dim
    ), "Softmax along split dimension is not supported."
    shards = [softmax(shard, dim=dim, dtype=dtype) for shard in tensor.shards]
    return SplitPrimitiveTensor(
        ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape
    )
    return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape)


@to.override(ReplicatedTensor)
@@ -815,7 +815,7 @@ def _repeat_trampoline(


@overridable
def replicate(input: AnyTensor, count: int) -> ShardedTensor:
def replicate(input: AnyTensor, count: int, devices: Tuple[int] | None, pinned: bool | None) -> ShardedTensor:
    """Replicate across devices.

    Possibly reshards if required."""

@@ -824,11 +824,19 @@ def replicate(input: AnyTensor, count: int) -> ShardedTensor:

@replicate.trampoline
def _replicate_trampoline(
    d: SignatureDispatcher, input: AnyTensor, count: int
    d: SignatureDispatcher, input: AnyTensor, count: int, devices: Tuple[int] | None = None, pinned: bool | None = None
) -> ShardedTensor:
    tensors = (input,)
    if isinstance(input, torch.Tensor):
        devices = devices if devices is not None else tuple(range(count))
        pinned = pinned if pinned is not None else False
    else:
        # TODO: Is this correct? Will use data on `input`.
        assert devices is None
        assert pinned is None

    for override in d.find_overrides(tensors):
        result = override(input, count=count)
        result = override(input, count=count, devices=devices, pinned=pinned)
How should these helper functions be handled correctly? A plain torch.Tensor can be passed in, so we have no placement information.
        if result is not NotImplemented:
            return override, result
    else:
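The defaulting the trampoline now applies for unsharded inputs can be isolated as a small rule: a plain `torch.Tensor` carries no placement metadata, so `devices` falls back to `0..count-1` and `pinned` to `False`, while sharded inputs must not override their own placement. A standalone sketch follows; the function name is made up for illustration and is not library code:

```python
from typing import Optional, Tuple
import torch

def resolve_replicate_placement(
    input, count: int,
    devices: Optional[Tuple[int, ...]] = None,
    pinned: Optional[bool] = None,
) -> Tuple[Optional[Tuple[int, ...]], Optional[bool]]:
    """Mirror of the defaulting in _replicate_trampoline: a plain torch.Tensor
    has no placement metadata, so replicate() falls back to devices 0..count-1
    and an unpinned result; sharded inputs keep their own placement."""
    if isinstance(input, torch.Tensor):
        devices = devices if devices is not None else tuple(range(count))
        pinned = pinned if pinned is not None else False
    else:
        assert devices is None and pinned is None
    return devices, pinned

print(resolve_replicate_placement(torch.zeros(2), count=4))
# ((0, 1, 2, 3), False)
```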
This ended up having to go in the __init__ file so that any import of ops would wrap the functions. One problem with the approach: a subsequent run of `from .signatures import *` will override the wrapped versions.

I am pretty certain we don't want to do this. Only `sharded_impls` should be the one to consider transfer information. This is pretty specific to sharded impls. We only want to wrap the imports for the sharded cases.

Also, try to name this `import_sharded_signatures` or something to that effect.
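The wildcard-import concern in the first comment is plain Python module semantics and can be reproduced in isolation; the `fake_signatures` module below is a made-up stand-in, not sharktank code:

```python
# Standalone demonstration of the wildcard-import pitfall mentioned above.
import sys
import types

sigs = types.ModuleType("fake_signatures")
sigs.linear = lambda x: "unwrapped"
sigs.__all__ = ["linear"]
sys.modules["fake_signatures"] = sigs

ns = {}
exec("from fake_signatures import *", ns)  # initial import, as done at module load
ns["linear"] = lambda x: "wrapped"         # the wrapping step rebinds the global
exec("from fake_signatures import *", ns)  # a later wildcard import re-imports...
print(ns["linear"](None))                  # ...and prints "unwrapped": the wrapper is gone
```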