Initial pipeline parallelism support #1008

Draft · wants to merge 49 commits into base: main
Changes from 36 commits
Commits (49)
5145b9b
Example
Alex-Vasile Feb 25, 2025
76991de
Changes to tensors and sharded_impls to get working
Alex-Vasile Feb 25, 2025
524d4c1
Use valid PP and TP combo
Alex-Vasile Feb 25, 2025
93dad69
Formatting cleanup
Alex-Vasile Feb 25, 2025
934d556
Cleanup from feedback
Alex-Vasile Feb 25, 2025
b421fd7
Changes to tensors to make tests pass
Alex-Vasile Feb 25, 2025
4a198ea
Changes to sharded_impls to have tests pass
Alex-Vasile Feb 25, 2025
db43786
Changes to tests to have them pass
Alex-Vasile Feb 25, 2025
f9bd9fd
Initial commit of tests
Alex-Vasile Feb 25, 2025
fdef39a
Cleaned up TODO
Alex-Vasile Feb 25, 2025
0a12c71
ops.replicate support
Alex-Vasile Feb 25, 2025
b549558
Remove ambiguity of calling ops.replicate with a ShardedTensor
Alex-Vasile Feb 25, 2025
817569c
Further tests, and fix typo in file name
Alex-Vasile Feb 26, 2025
cc6aaa2
Wrap transfer_if_needed into a decorator
Alex-Vasile Feb 26, 2025
b5d471a
Testing wrapped decorator
Alex-Vasile Feb 26, 2025
8c54c47
Expand transfer_if_needed to work with arbitrary number of tensors.
Alex-Vasile Feb 26, 2025
627c73c
Throw error when tensors are on different devices and none are pinned.
Alex-Vasile Feb 26, 2025
7907ce1
Fix missing arg to constructor
Alex-Vasile Feb 26, 2025
f2c5615
Changes to transfer_if_needed
Alex-Vasile Feb 26, 2025
0d9af59
Stop pinning all_reduce and all_gather results
Alex-Vasile Feb 26, 2025
251a70a
Changes to decorator
Alex-Vasile Feb 26, 2025
d1fcf81
override_w_transfer pins unary op results if input is pinned
Alex-Vasile Feb 26, 2025
2ca09d4
Correct year in copyright header
Alex-Vasile Feb 26, 2025
b29a2df
Revert "Changes to tests to have them pass"
Alex-Vasile Feb 26, 2025
97c25e5
Remove setter for shards
Alex-Vasile Feb 26, 2025
73e5481
Added defaults for devices and devices_pinned
Alex-Vasile Feb 26, 2025
5f8e7ec
Rename devices_pinned to pinned
Alex-Vasile Feb 26, 2025
81a1435
Fix flattening and unflattening of unreduced_tensor
Alex-Vasile Feb 26, 2025
4f3b434
Stop writing pipeline-parallelism-related data to archive
Alex-Vasile Feb 26, 2025
cf23568
Disable changes to is_deep_equal
Alex-Vasile Feb 27, 2025
1a1091b
Cleanup and documentation
Alex-Vasile Feb 27, 2025
da2db54
Missed params for initializer
Alex-Vasile Feb 27, 2025
649018f
Default value for .pinned
Alex-Vasile Feb 27, 2025
2af7cd6
example cleanup
Alex-Vasile Feb 27, 2025
10e4d9b
Remove setter and rewrite example
Alex-Vasile Feb 27, 2025
bb87dd0
Add clone constructor for ShardedTensor
Alex-Vasile Feb 28, 2025
a4bc196
Change from overriding the decorator to overriding imports
Alex-Vasile Feb 28, 2025
f9340fe
Change wrapper to specify devices on output
Alex-Vasile Feb 28, 2025
a14a283
Remove unnecessary passing of arguments
Alex-Vasile Feb 28, 2025
5e147ef
fix return logic
Alex-Vasile Feb 28, 2025
4ff8667
Cleanup
Alex-Vasile Feb 28, 2025
07d8ed8
Expand wrapper functionality to handle collections of tensors in args…
Alex-Vasile Mar 1, 2025
9e1aa62
More tests
Alex-Vasile Mar 3, 2025
b494532
Move transfer wrapper to init to capture all imports of ops
Alex-Vasile Mar 3, 2025
37fc560
Don't transfer for Index_put and index_copy
Alex-Vasile Mar 3, 2025
d72f1c6
More tests
Alex-Vasile Mar 3, 2025
4817485
Changes to wrapper to have reshard_like pass
Alex-Vasile Mar 3, 2025
5dbff38
Don't wrap transfer_to_logical_device
Alex-Vasile Mar 3, 2025
24a33e9
pre-commit cleanup
Alex-Vasile Mar 3, 2025
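
Several of the commits above (wrapping transfer_if_needed into a decorator, extending it to an arbitrary number of tensors, and throwing an error when tensors are on different devices and none are pinned) describe one rule: before a sharded op runs, unpinned operands are transferred to the devices of a pinned operand, and pinned operands are never moved. The sketch below is a minimal, hypothetical illustration of that rule; FakeShardedTensor and transfer_if_needed_sketch are invented stand-ins for this note, not sharktank's ShardedTensor or the actual transfer_if_needed in ops.sharded_impls.

# Hypothetical sketch of the transfer-if-needed rule described by the commits above.
# FakeShardedTensor is an invented stand-in that only tracks placement, not data.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeShardedTensor:
    devices: Tuple[int, ...]  # logical device index per shard
    pinned: bool = False      # pinned tensors must not be moved


def transfer_if_needed_sketch(*tensors: FakeShardedTensor) -> List[FakeShardedTensor]:
    """Return the operands with every unpinned tensor moved to the pinned placement."""
    placements = {t.devices for t in tensors}
    if len(placements) == 1:
        return list(tensors)  # already co-located, nothing to transfer

    pinned = [t for t in tensors if t.pinned]
    if not pinned:
        raise ValueError("Tensors are on different devices and none are pinned.")
    target = pinned[0].devices
    if any(t.devices != target for t in pinned):
        raise ValueError("Pinned tensors disagree on device placement.")

    # Unpinned operands are "transferred" (here just re-labelled) to the target devices.
    return [
        t if t.pinned else FakeShardedTensor(devices=target, pinned=False)
        for t in tensors
    ]

Per the override_w_transfer commit, results of unary ops additionally inherit the pinned flag of their input; that detail is left out of the sketch.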
124 changes: 124 additions & 0 deletions sharktank/sharktank/examples/pipeline/export_ppffn_net.py
@@ -0,0 +1,124 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Example program to export a sharded FFN network like what is found in
a typical transformer layer. This is used for developing and testing various
tooling flows with a scaled down example.

Generate MLIR and a random inited IRPA file with:

python -m sharktank.examples.sharding.export_ffn_net \
--output-irpa-file=/tmp/ffn.irpa /tmp/ffn.mlir
"""

import math
from typing import Tuple

import torch
import torch.nn as nn

from ...layers import *
from ... import ops
from ...types import *

from iree.turbine.aot import DeviceAffinity, DeviceTensorTrait, export

def create_theta(
    dim: int, shard_count: int, num_layers: int, save_path
):
    split_size = dim // shard_count
    weights = []
    for layer in range(num_layers):
        _weight = torch.rand(dim, dim, dtype=torch.float16) / math.sqrt(dim)
        weights.append(
            SplitPrimitiveTensor(
                name=f"w.{layer}",
                shard_dim=1,
                ts=_weight.split(split_size, dim=1),
            )
        )
    ds = Dataset({}, Theta(weights))
    ds.save(save_path)


def pipeline_parallelize_theta(
    theta: Theta,
    pp_count: int
) -> Theta:
    num_layers = len(theta.tensor("w"))
    shard_count = theta.tensor("w", "0").shard_count
    for layer in list(theta.tensor("w").keys()):
        weight = theta.tensor("w", layer)
        # Map the layer index to its pipeline group, then to that group's devices.
        pp_group = int(int(layer) * pp_count / num_layers)
        zero_4_group = shard_count * pp_group  # first device index of this pipeline group
        devices = tuple(i + zero_4_group for i in range(shard_count))

        # Tag each shard's underlying data with its target device, then rebuild the
        # weight as a pinned SplitPrimitiveTensor on those devices.
        shards = weight.shards
        for i, shard in enumerate(shards):
            DeviceTensorTrait(devices[i]).set(shard._data)
        new_weight = SplitPrimitiveTensor(
            shard_dim=weight.shard_dim,
            ts=shards,
            name=weight.name,
            devices=devices,
            pinned=True,
        )
        theta.tensor("w")[layer] = new_weight
    return theta


class PPFFN(ThetaLayer):
    def forward(self, x: torch.Tensor):
        num_layers = len(self.theta.tensor("w"))
        shard_count = self.theta.tensor("w", "0").shard_count

        # Replicate the input across the shard devices, then run each layer as a
        # linear against the sharded weight followed by an all_reduce back to a
        # replicated activation.
        x = ReplicatedTensor(ts=x, shard_count=shard_count)
        for layer in range(num_layers):
            weight: SplitPrimitiveTensor = self.theta.tensor("w", str(layer))
            x: ReplicatedTensor = ops.all_reduce(ops.linear(x, weight))

        return x


def main(raw_args=None):
    from ...utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "output_file",
        type=str,
        nargs="?",
        default="-",
        help="Output file to save MLIR to",
    )
    cli.add_output_dataset_options(parser)
    args = cli.parse(parser, args=raw_args)

    bs = 16
    sl = 128
    primary_dim = 128 * 2**5
    shard_count = 2
    num_layers = 40
    create_theta(primary_dim, shard_count, num_layers, save_path=args.output_irpa_file)

    pp_count = 4
    ds = Dataset.load(args.output_irpa_file)
    root_theta = pipeline_parallelize_theta(ds.root_theta, pp_count)

    mdl = PPFFN(root_theta)

    example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16)
    ep = torch.export.export(mdl, (example_arg,))  # , strict=False)
    cm = export(ep, arg_device={0: DeviceAffinity(0)})

    if args.output_file == "-":
        print(cm.mlir_module)
    else:
        with open(args.output_file, "wt") as f:
            f.write(str(cm.mlir_module))


if __name__ == "__main__":
    main()
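
To make the device assignment in pipeline_parallelize_theta concrete, here is a standalone sketch of the same arithmetic using the constants from main() (40 layers, pp_count=4, shard_count=2). devices_for_layer is a hypothetical helper written for this note, not part of the example itself.

from typing import Tuple


# Standalone illustration of the layer -> pipeline group -> device mapping used in
# pipeline_parallelize_theta above, with the constants from main().
def devices_for_layer(
    layer: int, num_layers: int = 40, pp_count: int = 4, shard_count: int = 2
) -> Tuple[int, ...]:
    pp_group = int(layer * pp_count / num_layers)  # pipeline stage that owns this layer
    first_device = shard_count * pp_group          # first device index of that stage
    return tuple(first_device + i for i in range(shard_count))


assert devices_for_layer(0) == (0, 1)   # layers 0-9   -> stage 0 -> devices 0, 1
assert devices_for_layer(10) == (2, 3)  # layers 10-19 -> stage 1 -> devices 2, 3
assert devices_for_layer(39) == (6, 7)  # layers 30-39 -> stage 3 -> devices 6, 7

Each layer's two shards thus stay tensor-parallel within a stage, while consecutive blocks of ten layers are pinned to successive device pairs.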
2 changes: 2 additions & 0 deletions sharktank/sharktank/ops/__init__.py
@@ -33,3 +33,5 @@
# Comment this out to completely disable optimized quantized implementations.
from . import qconv_impls
from . import qlinear_impls

from .sharded_impls import transfer_if_needed # TODO: Hack just to get tests running, figure out properly later
Inline review comment (Contributor Author): TODO
