Initial pipeline parallelism support #1008

Draft: wants to merge 49 commits into base: main

Changes from 47 commits

Commits (49 commits)
5145b9b
Example
Alex-Vasile Feb 25, 2025
76991de
Changes to tensors and sharded_impls to get working
Alex-Vasile Feb 25, 2025
524d4c1
Use valid PP and TP combo
Alex-Vasile Feb 25, 2025
93dad69
Formatting cleanup
Alex-Vasile Feb 25, 2025
934d556
Cleanup from feedback
Alex-Vasile Feb 25, 2025
b421fd7
Changes to tensors to make tests pass
Alex-Vasile Feb 25, 2025
4a198ea
Changes to sharded_impls to have tests pass
Alex-Vasile Feb 25, 2025
db43786
Changes to tests to have them pass
Alex-Vasile Feb 25, 2025
f9bd9fd
Initial commit of tests
Alex-Vasile Feb 25, 2025
fdef39a
Cleanedup TODO
Alex-Vasile Feb 25, 2025
0a12c71
ops.replicate support
Alex-Vasile Feb 25, 2025
b549558
Remove ambiguity of calling ops.replicate with a ShardedTensor
Alex-Vasile Feb 25, 2025
817569c
Further tests, and fix typo in file name
Alex-Vasile Feb 26, 2025
cc6aaa2
Wrap transfer_if_needed into a decorator
Alex-Vasile Feb 26, 2025
b5d471a
Testing wrapped decorator
Alex-Vasile Feb 26, 2025
8c54c47
Expand transfer_if_needed to work with arbitrary number of tensors.
Alex-Vasile Feb 26, 2025
627c73c
Throw error when tensors are on different devices and none are pinned.
Alex-Vasile Feb 26, 2025
7907ce1
Fix missing arg to constructor
Alex-Vasile Feb 26, 2025
f2c5615
Changes to transfer_if_needed
Alex-Vasile Feb 26, 2025
0d9af59
Stop pinning all_reduce and all_gather results
Alex-Vasile Feb 26, 2025
251a70a
Changes to decorator
Alex-Vasile Feb 26, 2025
d1fcf81
override_w_transfer pinnes unary ops result if input is pinned
Alex-Vasile Feb 26, 2025
2ca09d4
Correct year in copyright header
Alex-Vasile Feb 26, 2025
b29a2df
Revert "Changes to tests to have them pass"
Alex-Vasile Feb 26, 2025
97c25e5
Remove setter for shards
Alex-Vasile Feb 26, 2025
73e5481
Added defaults for devices and devices_pinned
Alex-Vasile Feb 26, 2025
5f8e7ec
Rename devices_pinned to pinned
Alex-Vasile Feb 26, 2025
81a1435
Fix flattening and unflattening of unreduced_tensor
Alex-Vasile Feb 26, 2025
4f3b434
Stop writing pipelineparallelism related data to archive
Alex-Vasile Feb 26, 2025
cf23568
Disable changes to is_deep_equal
Alex-Vasile Feb 27, 2025
1a1091b
Cleanup and documentation
Alex-Vasile Feb 27, 2025
da2db54
Missed params for initializer
Alex-Vasile Feb 27, 2025
649018f
Default value for .pinned
Alex-Vasile Feb 27, 2025
2af7cd6
example cleanup
Alex-Vasile Feb 27, 2025
10e4d9b
Remove setter and rewrite example
Alex-Vasile Feb 27, 2025
bb87dd0
Add clone constructor for ShardedTensor
Alex-Vasile Feb 28, 2025
a4bc196
Change from overriding the decorator to overring imports
Alex-Vasile Feb 28, 2025
f9340fe
Change wrapper to specify devices on output
Alex-Vasile Feb 28, 2025
a14a283
Remove unnecessary passing of arguments
Alex-Vasile Feb 28, 2025
5e147ef
fix return logic
Alex-Vasile Feb 28, 2025
4ff8667
Cleanup
Alex-Vasile Feb 28, 2025
07d8ed8
Expand wrapper functionality to handle collections of tensors in args…
Alex-Vasile Mar 1, 2025
9e1aa62
More tests
Alex-Vasile Mar 3, 2025
b494532
Move tranfer wrapper to init to capture all imports of ops
Alex-Vasile Mar 3, 2025
37fc560
Don't transfer for Index_put and index_copy
Alex-Vasile Mar 3, 2025
d72f1c6
More tests
Alex-Vasile Mar 3, 2025
4817485
Changes to wrapper to have reshard_like pass
Alex-Vasile Mar 3, 2025
5dbff38
Don't wrap transfer_to_logical_device
Alex-Vasile Mar 3, 2025
24a33e9
pre-commit cleanup
Alex-Vasile Mar 3, 2025
124 changes: 124 additions & 0 deletions sharktank/sharktank/examples/pipeline/export_ppffn_net.py
@@ -0,0 +1,124 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Example program to export a sharded FFN network like what is found in
a typical transformer layer. This is used for developing and testing various
tooling flows with a scaled-down example.

Generate MLIR and a randomly initialized IRPA file with:

  python -m sharktank.examples.pipeline.export_ppffn_net \
    --output-irpa-file=/tmp/ffn.irpa /tmp/ffn.mlir
"""

import math
from typing import Tuple

import torch
import torch.nn as nn

from ...layers import *
from ... import ops
from ...types import *

from iree.turbine.aot import DeviceAffinity, DeviceTensorTrait, export


def create_theta(
    dim: int, shard_count: int, num_layers: int, save_path
):
    split_size = dim // shard_count
    weights = []
    for layer in range(num_layers):
        _weight = torch.rand(dim, dim, dtype=torch.float16) / math.sqrt(dim)
        weights.append(
            SplitPrimitiveTensor(
                name=f"w.{layer}",
                shard_dim=1,
                ts=_weight.split(split_size, dim=1)
            )
        )
    ds = Dataset({}, Theta(weights))
    ds.save(save_path)


def pipeline_parallelize_theta(
    theta: Theta,
    pp_count: int
) -> Theta:
    num_layers = len(theta.tensor("w"))
    shard_count = theta.tensor("w", '0').shard_count
    for layer in list(theta.tensor("w").keys()):
        weight = theta.tensor("w", layer)
        pp_group = int(int(layer) * pp_count / num_layers)
        zero_4_group = shard_count * pp_group
        devices = tuple(i + zero_4_group for i in range(shard_count))
        shards = weight.shards
        for i, shard in enumerate(shards):
            DeviceTensorTrait(devices[i]).set(shard._data)
        new_weight = SplitPrimitiveTensor(
            shard_dim=weight.shard_dim,
            ts=shards,
            name=weight.name,
            devices=devices,
            pinned=True
        )
        theta.tensor("w")[layer] = new_weight
    return theta


class PPFFN(ThetaLayer):
    def forward(self, x: torch.Tensor):
        num_layers = len(self.theta.tensor("w"))
        shard_count = self.theta.tensor("w", '0').shard_count

        x = ReplicatedTensor(ts=x, shard_count=shard_count)
        for layer in range(num_layers):
            weight: SplitPrimitiveTensor = self.theta.tensor("w", str(layer))
            x: ReplicatedTensor = ops.all_reduce(ops.linear(x, weight))

        return x


def main(raw_args=None):
    from ...utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "output_file",
        type=str,
        nargs="?",
        default="-",
        help="Output file to save MLIR to",
    )
    cli.add_output_dataset_options(parser)
    args = cli.parse(parser, args=raw_args)

    bs = 16
    sl = 128
    primary_dim = 128 * 2**5
    shard_count = 2
    num_layers = 40
    create_theta(primary_dim, shard_count, num_layers, save_path=args.output_irpa_file)

    pp_count = 4
    ds = Dataset.load(args.output_irpa_file)
    root_theta = pipeline_parallelize_theta(ds.root_theta, pp_count)

    mdl = PPFFN(root_theta)

    example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16)
    ep = torch.export.export(mdl, (example_arg,))  # , strict=False)
    cm = export(ep, arg_device={0: DeviceAffinity(0)})

    if args.output_file == "-":
        print(cm.mlir_module)
    else:
        with open(args.output_file, "wt") as f:
            f.write(str(cm.mlir_module))


if __name__ == "__main__":
    main()
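
To make the stage assignment concrete, here is a small standalone sketch (plain Python, using the values from main() above: 40 layers, pp_count=4, shard_count=2) of the layer-to-device mapping that pipeline_parallelize_theta computes:

num_layers, pp_count, shard_count = 40, 4, 2
for layer in (0, 9, 10, 39):
    # Same arithmetic as pipeline_parallelize_theta above.
    pp_group = int(layer * pp_count / num_layers)
    devices = tuple(i + shard_count * pp_group for i in range(shard_count))
    print(layer, pp_group, devices)
# Prints: 0 0 (0, 1) / 9 0 (0, 1) / 10 1 (2, 3) / 39 3 (6, 7)
# i.e. layers 0-9 go to devices (0, 1), 10-19 to (2, 3), 20-29 to (4, 5), 30-39 to (6, 7).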
72 changes: 71 additions & 1 deletion sharktank/sharktank/ops/__init__.py
@@ -18,7 +18,75 @@

from . import _registry
from ..types.tensors import unbox_tensor
from .signatures import *

def import_and_wrap_signatures():
Contributor Author:

This ended up having to go in the __init__ file so that any import of ops would wrap the functions.

One problem with the approach: a subsequent run of from .signatures import * will override the wrapped versions.

Contributor:

I am pretty certain we don't want to do this. Only the sharded_impls should consider transfer information; this is specific to the sharded impls. We only want to wrap the imports for the sharded cases.

Contributor:

Also, try to name this import_sharded_signatures or something to that effect.

"""
Import the signatures from .signatures and wrap then so the shards of their inputs tensors are on the same devices.
For unary ops, also pins the result if the input is pinned (e.g. for transpose).
"""
def transfer_n_pin(f):
"""
Create a wrapper for each operation.
"""
from ..types import ShardedTensor
from typing import List, Tuple, Dict, Any
def unwrap_args(items: Tuple | Dict[str, Any]) -> Tuple[List[int | List[int]], List[ShardedTensor]]:
t_i, t_vals = [], []
for i, arg in enumerate(items):
if isinstance(arg, ShardedTensor):
t_i.append(i)
t_vals.append(arg)
elif isinstance(arg, list) and all(isinstance(val, ShardedTensor) for val in arg):
t_i.append([i] * len(arg))
t_vals.extend(arg)
return t_i, t_vals

def rewrap_args(items: Tuple | Dict, t_i: List[int | List[int]], t_vals: List[ShardedTensor]) -> Tuple[Tuple, Dict[str, Any]]:
i_lookup = list(range(len(items))) if isinstance(items, tuple) else list(items.keys())
new_items = list(items) if isinstance(items, tuple) else dict(items)

for i in t_i:
if isinstance(i, int):
new_items[i_lookup[i]] = t_vals.pop(0)
else: # List[int]
_popped_vals = [t_vals.pop(0) for _ in range(len(i))]
new_items[i_lookup[i[0]]] = items[i_lookup[i[0]]].__class__(_popped_vals)

if isinstance(new_items, list):
new_items = tuple(new_items)
return new_items

def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
t_i_args, t_vals_args = unwrap_args(args)
t_i_kwargs, t_vals_kwargs = unwrap_args(list(kwargs.values()))
t_vals = t_vals_args + t_vals_kwargs

t_vals = transfer_if_needed(*t_vals)

args = rewrap_args(args, t_i_args, t_vals[:len(t_vals_args)])
kwargs = rewrap_args(kwargs, t_i_kwargs, t_vals[len(t_vals_args):])
res = f(*args, **kwargs)
if isinstance(res, ShardedTensor) and len(t_vals) > 0:
pinned = (res.pinned
or (len(t_vals) == 1 and t_vals[0].pinned)
or f.__name__ == 'reshard_like') # TODO: How to handle this case properly
res = res.clone(devices=t_vals[0].devices, pinned=pinned)
return res

if hasattr(f, "override"): # Needed for ops like gelu_tanh_approximation
func_wrapper.override = f.override
return func_wrapper

do_not_wrap = {'all_gather', 'all_reduce', 'replicate', 'index_copy_', 'index_put_'}

from . import signatures
for func_name in signatures.__all__:
func = getattr(signatures, func_name)
if func_name not in do_not_wrap:
func = transfer_n_pin(func)
globals()[func_name] = func
import_and_wrap_signatures()

from .shape import *

# Ensure that implementations are registered.
@@ -33,3 +101,5 @@
# Comment this out to completely disable optimized quantized implementations.
from . import qconv_impls
from . import qlinear_impls

from .sharded_impls import transfer_if_needed # TODO: Hack just to get tests running, figure out properly later
Contributor Author:

TODO

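A minimal, self-contained sketch of the re-import concern raised in the earlier comment on import_and_wrap_signatures. This is plain Python, not sharktank code; the module and function names are made up for illustration. A later star import rebinds the module-level names and silently discards the wrapped versions:

import sys, types

# Stand-in for sharktank.ops.signatures.
fake = types.ModuleType("fake_signatures")
fake.linear = lambda *args, **kwargs: "original"
fake.__all__ = ["linear"]
sys.modules["fake_signatures"] = fake

from fake_signatures import *  # binds linear -> original

def transfer_n_pin(f):
    # Stand-in for the wrapper installed by import_and_wrap_signatures.
    def wrapper(*args, **kwargs):
        return "wrapped " + f(*args, **kwargs)
    return wrapper

linear = transfer_n_pin(linear)
print(linear())  # "wrapped original"

from fake_signatures import *  # a second star import rebinds the name...
print(linear())  # ..."original": the wrapped version is gone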
78 changes: 63 additions & 15 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -6,7 +6,7 @@

import torch
from torch import Tensor
from typing import List, Optional, Sequence, Union, Any, Tuple
from typing import List, Optional, Sequence, Union, Any, Tuple, Dict
import itertools
from numbers import Number
import math
@@ -31,6 +31,57 @@
from ..utils import longest_equal_range


def transfer_if_needed(*tensors: Tuple[ShardedTensor]) -> List[ShardedTensor]:
Contributor Author:

I would like to put this in __init__ as well, since this is not the only place it's being used. But the top-level type hints for this function will make the imports messy.

Contributor:

We want to keep this in sharding, as all of this wrapping behavior should be sharding-specific.

"""
If at least 2 tensors are panned in, the shards of all unpinned tensors are transfered to be on the same devices as those of the pinned tensors.
"""
def tensor_w_shards_moved(tensor: ShardedTensor, new_devices: Tuple[int]) -> ShardedTensor:
"""
Create a copy of the passed in tensor, but with shards now placed on the new devices.
"""
new_shards = tuple(
(
transfer_to_logical_device(shard, new_devices[j])
if new_devices[j] != tensor.devices[j]
else barrier_on_logical_device(shard, new_devices[j])
)
for j, shard in enumerate(tensor.shards)
)
return tensor.clone(ts=new_shards, devices=new_devices)

if len(tensors) <= 1:
return list(tensors)
assert all(isinstance(tensor, ShardedTensor) for tensor in tensors)

# Check if all tensors are on the same devices.
all_on_same_devices = True
for tensor in tensors[1:]:
if any(d0 != d for d0, d in zip(tensors[0].devices, tensor.devices)):
all_on_same_devices = False
break
if all_on_same_devices:
return list(tensors)

pinned_tensors = [tensor for tensor in tensors if tensor.pinned]
if len(pinned_tensors) == 0:
raise ValueError("Tensors are on different devices, but none are pinned. Don't know which devices to transfer to.")

pinned_devices = pinned_tensors[0].devices
for pinned_tensor in pinned_tensors[1:]:
if any(d0 != d for d0, d in zip(pinned_devices, pinned_tensor.devices)):
raise ValueError("All pinned tensors must be on the same devices.")

# Move all non-pinned tensors to the same devices as the pinned ones.
new_tensors = [
(
tensor
if tensor.pinned
else tensor_w_shards_moved(tensor, pinned_devices)
)
for tensor in tensors
]
return new_tensors
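
A usage sketch of transfer_if_needed, for illustration only. It assumes the SplitPrimitiveTensor keyword arguments used in export_ppffn_net.py above (ts, shard_dim, devices, pinned) and that transfer_to_logical_device / barrier_on_logical_device can run outside of an export trace:

import torch
from sharktank.types import SplitPrimitiveTensor
from sharktank.ops.sharded_impls import transfer_if_needed

w = torch.rand(4, 8, dtype=torch.float16)
# Pinned tensor whose shards already live on devices (2, 3).
a = SplitPrimitiveTensor(ts=w.split(4, dim=1), shard_dim=1, devices=(2, 3), pinned=True)
# Unpinned tensor still on the default devices (0, 1).
b = SplitPrimitiveTensor(ts=w.split(4, dim=1), shard_dim=1, devices=(0, 1), pinned=False)

a2, b2 = transfer_if_needed(a, b)
# a is returned as-is; b comes back as a clone whose shards were transferred to (2, 3).
# If both tensors were unpinned and on different devices, a ValueError would be raised.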

@all_gather.override(SplitPrimitiveTensor)
def all_gather_split(
input: SplitPrimitiveTensor, *, dim: int | None
@@ -43,17 +94,17 @@ def all_gather_split(
cat(
[
(
barrier_on_logical_device(shard, i)
barrier_on_logical_device(shard, input.devices[i])
if i == j
else transfer_to_logical_device(shard, i)
else transfer_to_logical_device(shard, input.devices[i])
)
for j, shard in enumerate(input.shards)
],
dim=dim,
)
for i in range(input.shard_count)
]
return ReplicatedTensor(ts=shards)
return ReplicatedTensor(ts=shards, devices=input.devices)


@all_reduce.override(AllOfType(SplitPrimitiveTensor, UnreducedTensor))
@@ -68,16 +119,16 @@ def all_reduce_split_or_unreduced(
lambda x, y: elementwise(torch.add, x, y),
[
(
barrier_on_logical_device(shard, i)
barrier_on_logical_device(shard, input.devices[i])
if i == j
else transfer_to_logical_device(shard, i)
else transfer_to_logical_device(shard, input.devices[i])
)
for j, shard in enumerate(input.shards)
],
)
for i in range(input.shard_count)
]
return ReplicatedTensor(ts=shards)
return ReplicatedTensor(ts=shards, devices=input.devices)


@cat.override(AllOfType(ReplicatedTensor))
@@ -102,7 +153,6 @@ def cat_split(
for t in tensors
]
)

shard_dim = tensors[0].shard_dim
shard_count = tensors[0].shard_count
if dim != shard_dim:
@@ -894,28 +944,28 @@ def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedT


@replicate.override(ReplicatedTensor)
def replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor:
def replicate_replicated(input: ReplicatedTensor, *, count: int, devices: None, pinned: None) -> ReplicatedTensor:
if input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return input


@replicate.override(SplitPrimitiveTensor)
def replicate_split(input: SplitPrimitiveTensor, *, count: int) -> ReplicatedTensor:
def replicate_split(input: SplitPrimitiveTensor, *, count: int, devices: None, pinned: None) -> ReplicatedTensor:
if input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return all_gather(input)


@replicate.override(UnreducedTensor)
def replicate_unreduced(input: UnreducedTensor, *, count: int) -> ReplicatedTensor:
def replicate_unreduced(input: UnreducedTensor, *, count: int, devices: None, pinned: None) -> ReplicatedTensor:
if input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return all_reduce(input)


@replicate.override(Tensor)
def replicate_unsharded(input, *, count: int) -> ReplicatedTensor:
def replicate_unsharded(input, *, count: int, devices: Tuple[int], pinned: bool) -> ReplicatedTensor:
torch_input = unbox_tensor(input)
# If we have a torch input replicating we can assume we need to transfer:
torch_inputs = [transfer_to_logical_device(torch_input, i) for i in range(count)]
@@ -1138,9 +1188,7 @@ def softmax_split(
dim is not None and dim != tensor.shard_dim
), "Softmax along split dimension is not supported."
shards = [softmax(shard, dim=dim, dtype=dtype) for shard in tensor.shards]
return SplitPrimitiveTensor(
ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape
)
return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape)


@to.override(ReplicatedTensor)
14 changes: 11 additions & 3 deletions sharktank/sharktank/ops/signatures.py
@@ -815,7 +815,7 @@ def _repeat_trampoline(


@overridable
def replicate(input: AnyTensor, count: int) -> ShardedTensor:
def replicate(input: AnyTensor, count: int, devices: Tuple[int] | None, pinned: bool | None) -> ShardedTensor:
"""Replicate across devices.

Possibly reshards if required."""
Expand All @@ -824,11 +824,19 @@ def replicate(input: AnyTensor, count: int) -> ShardedTensor:

@replicate.trampoline
def _replicate_trampoline(
d: SignatureDispatcher, input: AnyTensor, count: int
d: SignatureDispatcher, input: AnyTensor, count: int, devices: Tuple[int] | None = None, pinned: bool | None = None
) -> ShardedTensor:
tensors = (input,)
if isinstance(input, torch.Tensor):
devices = devices if devices is not None else tuple(range(count))
pinned = pinned if pinned is not None else False
else:
# TODO: Is this correct? Will use data on `input`.
assert devices is None
assert pinned is None

for override in d.find_overrides(tensors):
result = override(input, count=count)
result = override(input, count=count, devices=devices, pinned=pinned)
Contributor Author (@Alex-Vasile, Feb 27, 2025):

How do we handle these helper functions correctly? We can pass a torch.Tensor and so have no idea about placement information.

if result is not NotImplemented:
return override, result
else:
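For reference, a hedged sketch of how the extended replicate signature is called, based only on the trampoline defaults visible in this hunk. Whether the torch-tensor override honors an explicit devices argument is not shown here, so the sketch sticks to the defaults, and it assumes the transfer ops can run eagerly:

import torch
from sharktank import ops

x = torch.rand(2, 4, dtype=torch.float16)

# Torch input: devices defaults to tuple(range(count)) and pinned defaults to False.
r = ops.replicate(x, count=2)

# Already-sharded input: devices and pinned must be left as None;
# placement comes from the input itself (replicate_replicated returns it unchanged).
r2 = ops.replicate(r, count=2)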