Initial pipeline parallelism support #1008

Status: Draft. Wants to merge 49 commits into base: main.

Changes from all commits (49 commits)
5145b9b
Example
Alex-Vasile Feb 25, 2025
76991de
Changes to tensors and sharded_impls to get working
Alex-Vasile Feb 25, 2025
524d4c1
Use valid PP and TP combo
Alex-Vasile Feb 25, 2025
93dad69
Formatting cleanup
Alex-Vasile Feb 25, 2025
934d556
Cleanup from feedback
Alex-Vasile Feb 25, 2025
b421fd7
Changes to tensors to make tests pass
Alex-Vasile Feb 25, 2025
4a198ea
Changes to sharded_impls to have tests pass
Alex-Vasile Feb 25, 2025
db43786
Changes to tests to have them pass
Alex-Vasile Feb 25, 2025
f9bd9fd
Initial commit of tests
Alex-Vasile Feb 25, 2025
fdef39a
Cleaned up TODO
Alex-Vasile Feb 25, 2025
0a12c71
ops.replicate support
Alex-Vasile Feb 25, 2025
b549558
Remove ambiguity of calling ops.replicate with a ShardedTensor
Alex-Vasile Feb 25, 2025
817569c
Further tests, and fix typo in file name
Alex-Vasile Feb 26, 2025
cc6aaa2
Wrap transfer_if_needed into a decorator
Alex-Vasile Feb 26, 2025
b5d471a
Testing wrapped decorator
Alex-Vasile Feb 26, 2025
8c54c47
Expand transfer_if_needed to work with arbitrary number of tensors.
Alex-Vasile Feb 26, 2025
627c73c
Throw error when tensors are on different devices and none are pinned.
Alex-Vasile Feb 26, 2025
7907ce1
Fix missing arg to constructor
Alex-Vasile Feb 26, 2025
f2c5615
Changes to transfer_if_needed
Alex-Vasile Feb 26, 2025
0d9af59
Stop pinning all_reduce and all_gather results
Alex-Vasile Feb 26, 2025
251a70a
Changes to decorator
Alex-Vasile Feb 26, 2025
d1fcf81
override_w_transfer pins unary op results if input is pinned
Alex-Vasile Feb 26, 2025
2ca09d4
Correct year in copyright header
Alex-Vasile Feb 26, 2025
b29a2df
Revert "Changes to tests to have them pass"
Alex-Vasile Feb 26, 2025
97c25e5
Remove setter for shards
Alex-Vasile Feb 26, 2025
73e5481
Added defaults for devices and devices_pinned
Alex-Vasile Feb 26, 2025
5f8e7ec
Rename devices_pinned to pinned
Alex-Vasile Feb 26, 2025
81a1435
Fix flattening and unflattening of unreduced_tensor
Alex-Vasile Feb 26, 2025
4f3b434
Stop writing pipeline-parallelism-related data to archive
Alex-Vasile Feb 26, 2025
cf23568
Disable changes to is_deep_equal
Alex-Vasile Feb 27, 2025
1a1091b
Cleanup and documentation
Alex-Vasile Feb 27, 2025
da2db54
Missed params for initializer
Alex-Vasile Feb 27, 2025
649018f
Default value for .pinned
Alex-Vasile Feb 27, 2025
2af7cd6
example cleanup
Alex-Vasile Feb 27, 2025
10e4d9b
Remove setter and rewrite example
Alex-Vasile Feb 27, 2025
bb87dd0
Add clone constructor for ShardedTensor
Alex-Vasile Feb 28, 2025
a4bc196
Change from overriding the decorator to overriding imports
Alex-Vasile Feb 28, 2025
f9340fe
Change wrapper to specify devices on output
Alex-Vasile Feb 28, 2025
a14a283
Remove unnecessary passing of arguments
Alex-Vasile Feb 28, 2025
5e147ef
fix return logic
Alex-Vasile Feb 28, 2025
4ff8667
Cleanup
Alex-Vasile Feb 28, 2025
07d8ed8
Expand wrapper functionality to handle collections of tensors in args…
Alex-Vasile Mar 1, 2025
9e1aa62
More tests
Alex-Vasile Mar 3, 2025
b494532
Move transfer wrapper to init to capture all imports of ops
Alex-Vasile Mar 3, 2025
37fc560
Don't transfer for Index_put and index_copy
Alex-Vasile Mar 3, 2025
d72f1c6
More tests
Alex-Vasile Mar 3, 2025
4817485
Changes to wrapper to have reshard_like pass
Alex-Vasile Mar 3, 2025
5dbff38
Don't wrap transfer_to_logical_device
Alex-Vasile Mar 3, 2025
24a33e9
pre-commit cleanup
Alex-Vasile Mar 3, 2025
118 changes: 118 additions & 0 deletions sharktank/sharktank/examples/pipeline/export_ppffn_net.py
@@ -0,0 +1,118 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Example program to export a sharded FFN network like what is found in
a typical transformer layer. This is used for developing and testing various
tooling flows with a scaled down example.

Generate MLIR and a random inited IRPA file with:

python -m sharktank.examples.sharding.export_ffn_net \
--output-irpa-file=/tmp/ffn.irpa /tmp/ffn.mlir
"""

import math
from typing import Tuple

import torch
import torch.nn as nn

from ...layers import *
from ... import ops
from ...types import *

from iree.turbine.aot import DeviceAffinity, DeviceTensorTrait, export


def create_theta(dim: int, shard_count: int, num_layers: int, save_path):
split_size = dim // shard_count
weights = []
for layer in range(num_layers):
_weight = torch.rand(dim, dim, dtype=torch.float16) / math.sqrt(dim)
weights.append(
SplitPrimitiveTensor(
name=f"w.{layer}", shard_dim=1, ts=_weight.split(split_size, dim=1)
)
)
ds = Dataset({}, Theta(weights))
ds.save(save_path)


def pipeline_parallelize_theta(theta: Theta, pp_count: int) -> Theta:
num_layers = len(theta.tensor("w"))
shard_count = theta.tensor("w", "0").shard_count
for layer in list(theta.tensor("w").keys()):
weight = theta.tensor("w", layer)
pp_group = int(int(layer) * pp_count / num_layers)
zero_4_group = shard_count * pp_group
devices = tuple(i + zero_4_group for i in range(shard_count))
shards = weight.shards
for i, shard in enumerate(shards):
DeviceTensorTrait(devices[i]).set(shard._data)
new_weight = SplitPrimitiveTensor(
shard_dim=weight.shard_dim,
ts=shards,
name=weight.name,
devices=devices,
pinned=True,
)
theta.tensor("w")[layer] = new_weight
return theta


class PPFFN(ThetaLayer):
def forward(self, x: torch.Tensor):
num_layers = len(self.theta.tensor("w"))
shard_count = self.theta.tensor("w", "0").shard_count

x = ReplicatedTensor(ts=x, shard_count=shard_count)
for layer in range(num_layers):
weight: SplitPrimitiveTensor = self.theta.tensor("w", str(layer))
x: ReplicatedTensor = ops.all_reduce(ops.linear(x, weight))

return x


def main(raw_args=None):
from ...utils import cli

parser = cli.create_parser()
parser.add_argument(
"output_file",
type=str,
nargs="?",
default="-",
help="Output file to save MLIR to",
)
cli.add_output_dataset_options(parser)
args = cli.parse(parser, args=raw_args)

bs = 16
sl = 128
primary_dim = 128 * 2**5
shard_count = 2
num_layers = 40
create_theta(primary_dim, shard_count, num_layers, save_path=args.output_irpa_file)

pp_count = 4
ds = Dataset.load(args.output_irpa_file)
root_theta = pipeline_parallelize_theta(ds.root_theta, pp_count)

mdl = PPFFN(root_theta)

example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16)
ep = torch.export.export(mdl, (example_arg,)) # , strict=False)
cm = export(ep, arg_device={0: DeviceAffinity(0)})

if args.output_file == "-":
print(cm.mlir_module)
else:
with open(args.output_file, "wt") as f:
f.write(str(cm.mlir_module))


if __name__ == "__main__":
main()
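As a quick, hypothetical sanity check of the layer-to-device mapping that pipeline_parallelize_theta implements above (not part of the PR), the arithmetic works out as follows for the constants used in main() (40 layers, pp_count=4, shard_count=2):

```python
# Standalone sketch (assumed constants from main() above): each layer is mapped
# to a pipeline stage, and each stage owns a contiguous block of shard_count devices.
num_layers, pp_count, shard_count = 40, 4, 2

def devices_for_layer(layer: int) -> tuple:
    pp_group = int(layer * pp_count / num_layers)  # pipeline stage index, 0..pp_count-1
    zero_4_group = shard_count * pp_group          # first device id owned by that stage
    return tuple(i + zero_4_group for i in range(shard_count))

assert devices_for_layer(0) == (0, 1)    # layers 0..9   -> stage 0
assert devices_for_layer(10) == (2, 3)   # layers 10..19 -> stage 1
assert devices_for_layer(25) == (4, 5)   # layers 20..29 -> stage 2
assert devices_for_layer(39) == (6, 7)   # layers 30..39 -> stage 3
```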
101 changes: 100 additions & 1 deletion sharktank/sharktank/ops/__init__.py
@@ -18,7 +18,102 @@

from . import _registry
from ..types.tensors import unbox_tensor
from .signatures import *


def import_and_wrap_signatures():

Review comment from the PR author:
This ended up having to go in the __init__ file so that any import of ops would wrap the functions.

One problem with the approach: a subsequent run of from .signatures import * will override the wrapped versions.


Review comment from a reviewer:
I am pretty certain we don't want to do this. Only sharded_impls should be the one to consider transfer information; this is pretty specific to sharded impls. We only want to wrap the imports for the sharded cases.


Review comment from a reviewer:
Also, try to name this import_sharded_signatures or something to that effect.

"""
Import the signatures from .signatures and wrap them so that the shards of their input tensors are on the same devices.
For unary ops, the result is also pinned if the input is pinned (e.g. for transpose).
"""

def transfer_n_pin(f):
"""
Create a wrapper for each operation.
"""
from ..types import ShardedTensor
from typing import List, Tuple, Dict, Any

def unwrap_args(
items: Tuple | Dict[str, Any]
) -> Tuple[List[int | List[int]], List[ShardedTensor]]:
t_i, t_vals = [], []
for i, arg in enumerate(items):
if isinstance(arg, ShardedTensor):
t_i.append(i)
t_vals.append(arg)
elif isinstance(arg, list) and all(
isinstance(val, ShardedTensor) for val in arg
):
t_i.append([i] * len(arg))
t_vals.extend(arg)
return t_i, t_vals

def rewrap_args(
items: Tuple | Dict, t_i: List[int | List[int]], t_vals: List[ShardedTensor]
) -> Tuple[Tuple, Dict[str, Any]]:
i_lookup = (
list(range(len(items)))
if isinstance(items, tuple)
else list(items.keys())
)
new_items = list(items) if isinstance(items, tuple) else dict(items)

for i in t_i:
if isinstance(i, int):
new_items[i_lookup[i]] = t_vals.pop(0)
else: # List[int]
_popped_vals = [t_vals.pop(0) for _ in range(len(i))]
new_items[i_lookup[i[0]]] = items[i_lookup[i[0]]].__class__(
_popped_vals
)

if isinstance(new_items, list):
new_items = tuple(new_items)
return new_items

def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
t_i_args, t_vals_args = unwrap_args(args)
t_i_kwargs, t_vals_kwargs = unwrap_args(list(kwargs.values()))
t_vals = t_vals_args + t_vals_kwargs

t_vals = transfer_if_needed(*t_vals)

args = rewrap_args(args, t_i_args, t_vals[: len(t_vals_args)])
kwargs = rewrap_args(kwargs, t_i_kwargs, t_vals[len(t_vals_args) :])
res = f(*args, **kwargs)
if isinstance(res, ShardedTensor) and len(t_vals) > 0:
pinned = (
res.pinned
or (len(t_vals) == 1 and t_vals[0].pinned)
or f.__name__ == "reshard_like"
) # TODO: How to handle this case properly
res = res.clone(devices=t_vals[0].devices, pinned=pinned)
return res

if hasattr(f, "override"): # Needed for ops like gelu_tanh_approximation
func_wrapper.override = f.override
return func_wrapper

do_not_wrap = {
"all_gather",
"all_reduce",
"replicate",
"index_copy_",
"index_put_",
"transfer_to_logical_device",
}

from . import signatures

for func_name in signatures.__all__:
func = getattr(signatures, func_name)
if func_name not in do_not_wrap:
func = transfer_n_pin(func)
globals()[func_name] = func


import_and_wrap_signatures()

from .shape import *

# Ensure that implementations are registered.
@@ -33,3 +128,7 @@
# Comment this out to completely disable optimized quantized implementations.
from . import qconv_impls
from . import qlinear_impls

from .sharded_impls import (
transfer_if_needed,
) # TODO: Hack just to get tests running, figure out properly later
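To make the review discussion above concrete, here is a minimal, self-contained sketch (hypothetical module and function names, not sharktank code) of the import-time wrapping pattern used by import_and_wrap_signatures, and of the pitfall the author notes: a later star-import of the raw signatures module silently replaces the wrapped names.

```python
# Minimal sketch of import-time wrapping; "signatures" here is a stand-in module,
# not the real sharktank.ops.signatures.
import types

signatures = types.ModuleType("signatures")

def _linear(x, w):
    # Trivial stand-in op so the example runs without torch.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

signatures.linear = _linear
signatures.__all__ = ["linear"]

def _wrap(f):
    def wrapper(*args, **kwargs):
        print(f"device transfer check before {f.__name__}")  # stand-in for transfer_if_needed
        return f(*args, **kwargs)
    wrapper.__name__ = f.__name__
    return wrapper

# What import_and_wrap_signatures does, in miniature: publish wrapped versions
# of every name listed in signatures.__all__.
namespace = {name: _wrap(getattr(signatures, name)) for name in signatures.__all__}
namespace["linear"]([[1, 2]], [[3], [4]])  # wrapped: prints the transfer check

# The pitfall from the author's comment: re-running `from .signatures import *`
# (simulated below) overwrites the wrapped names with the raw ones again.
namespace.update({name: getattr(signatures, name) for name in signatures.__all__})
namespace["linear"]([[1, 2]], [[3], [4]])  # unwrapped: no transfer check
```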
84 changes: 72 additions & 12 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -6,7 +6,7 @@

import torch
from torch import Tensor
from typing import List, Optional, Sequence, Union, Any, Tuple
from typing import List, Optional, Sequence, Union, Any, Tuple, Dict
import itertools
from numbers import Number
import math
@@ -31,6 +31,59 @@
from ..utils import longest_equal_range


def transfer_if_needed(*tensors: Tuple[ShardedTensor]) -> List[ShardedTensor]:

Review comment from the PR author:

I would like to put this in __init__ as well, since this is not the only place it's being used. But the top-level type hints for this function would make the imports messy.


Review comment from a reviewer:

We want to keep it in sharding, as all of this wrapping behavior should be sharding-specific.

"""
If at least two tensors are passed in, the shards of all unpinned tensors are transferred to the same devices as those of the pinned tensors.
"""

def tensor_w_shards_moved(
tensor: ShardedTensor, new_devices: Tuple[int]
) -> ShardedTensor:
"""
Create a copy of the passed in tensor, but with shards now placed on the new devices.
"""
new_shards = tuple(
(
transfer_to_logical_device(shard, new_devices[j])
if new_devices[j] != tensor.devices[j]
else barrier_on_logical_device(shard, new_devices[j])
)
for j, shard in enumerate(tensor.shards)
)
return tensor.clone(ts=new_shards, devices=new_devices)

if len(tensors) <= 1:
return list(tensors)
assert all(isinstance(tensor, ShardedTensor) for tensor in tensors)

# Check if all tensors are on the same devices.
all_on_same_devices = True
for tensor in tensors[1:]:
if any(d0 != d for d0, d in zip(tensors[0].devices, tensor.devices)):
all_on_same_devices = False
break
if all_on_same_devices:
return list(tensors)

pinned_tensors = [tensor for tensor in tensors if tensor.pinned]
if len(pinned_tensors) == 0:
raise ValueError(
"Tensors are on different devices, but none are pinned. Don't know which devices to transfer to."
)

pinned_devices = pinned_tensors[0].devices
for pinned_tensor in pinned_tensors[1:]:
if any(d0 != d for d0, d in zip(pinned_devices, pinned_tensor.devices)):
raise ValueError("All pinned tensors must be on the same devices.")

# Move all non-pinned tensors to the same devices as the pinned ones.
new_tensors = [
(tensor if tensor.pinned else tensor_w_shards_moved(tensor, pinned_devices))
for tensor in tensors
]
return new_tensors


@all_gather.override(SplitPrimitiveTensor)
def all_gather_split(
input: SplitPrimitiveTensor, *, dim: int | None
@@ -43,17 +96,17 @@ def all_gather_split(
cat(
[
(
barrier_on_logical_device(shard, i)
barrier_on_logical_device(shard, input.devices[i])
if i == j
else transfer_to_logical_device(shard, i)
else transfer_to_logical_device(shard, input.devices[i])
)
for j, shard in enumerate(input.shards)
],
dim=dim,
)
for i in range(input.shard_count)
]
return ReplicatedTensor(ts=shards)
return ReplicatedTensor(ts=shards, devices=input.devices)


@all_reduce.override(AllOfType(SplitPrimitiveTensor, UnreducedTensor))
@@ -68,16 +121,16 @@ def all_reduce_split_or_unreduced(
lambda x, y: elementwise(torch.add, x, y),
[
(
barrier_on_logical_device(shard, i)
barrier_on_logical_device(shard, input.devices[i])
if i == j
else transfer_to_logical_device(shard, i)
else transfer_to_logical_device(shard, input.devices[i])
)
for j, shard in enumerate(input.shards)
],
)
for i in range(input.shard_count)
]
return ReplicatedTensor(ts=shards)
return ReplicatedTensor(ts=shards, devices=input.devices)


@cat.override(AllOfType(ReplicatedTensor))
@@ -102,7 +155,6 @@ def cat_split(
for t in tensors
]
)

shard_dim = tensors[0].shard_dim
shard_count = tensors[0].shard_count
if dim != shard_dim:
@@ -894,28 +946,36 @@ def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedT


@replicate.override(ReplicatedTensor)
def replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor:
def replicate_replicated(
input: ReplicatedTensor, *, count: int, devices: None, pinned: None
) -> ReplicatedTensor:
if input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return input


@replicate.override(SplitPrimitiveTensor)
def replicate_split(input: SplitPrimitiveTensor, *, count: int) -> ReplicatedTensor:
def replicate_split(
input: SplitPrimitiveTensor, *, count: int, devices: None, pinned: None
) -> ReplicatedTensor:
if input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return all_gather(input)


@replicate.override(UnreducedTensor)
def replicate_unreduced(input: UnreducedTensor, *, count: int) -> ReplicatedTensor:
def replicate_unreduced(
input: UnreducedTensor, *, count: int, devices: None, pinned: None
) -> ReplicatedTensor:
if input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return all_reduce(input)


@replicate.override(Tensor)
def replicate_unsharded(input, *, count: int) -> ReplicatedTensor:
def replicate_unsharded(
input, *, count: int, devices: Tuple[int], pinned: bool
) -> ReplicatedTensor:
torch_input = unbox_tensor(input)
# If we have a torch input replicating we can assume we need to transfer:
torch_inputs = [transfer_to_logical_device(torch_input, i) for i in range(count)]
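For reference, here is a self-contained mock (these are not the sharktank classes; a sketch only) of the decision logic in transfer_if_needed as written in this diff: tensors already on matching devices pass through unchanged, unpinned tensors follow the pinned ones, and a device mismatch with no pinned tensor raises an error.

```python
# Mock illustration of transfer_if_needed's device/pin rules; MockSharded stands
# in for ShardedTensor and only carries the fields the logic needs.
from dataclasses import dataclass, replace
from typing import List, Tuple

@dataclass(frozen=True)
class MockSharded:
    devices: Tuple[int, ...]
    pinned: bool = False

def mock_transfer_if_needed(*tensors: MockSharded) -> List[MockSharded]:
    if len(tensors) <= 1:
        return list(tensors)
    if all(t.devices == tensors[0].devices for t in tensors):
        return list(tensors)  # already on the same devices, nothing to move
    pinned = [t for t in tensors if t.pinned]
    if not pinned:
        raise ValueError("Tensors are on different devices, but none are pinned.")
    if any(p.devices != pinned[0].devices for p in pinned):
        raise ValueError("All pinned tensors must be on the same devices.")
    target = pinned[0].devices
    # Unpinned tensors "move" (here: just relabel) to the pinned tensors' devices.
    return [t if t.pinned else replace(t, devices=target) for t in tensors]

a = MockSharded(devices=(0, 1), pinned=True)
b = MockSharded(devices=(2, 3))
_, b_moved = mock_transfer_if_needed(a, b)
assert b_moved.devices == (0, 1)  # the unpinned tensor follows the pinned one
```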