Initial pipeline parallelism support #1008
Draft
Alex-Vasile wants to merge 49 commits into nod-ai:main from Alex-Vasile:pipeline
+954 −35
Commits: 49 total, all authored by Alex-Vasile. The diff below shows changes from 30 of them.
5145b9b  Example
76991de  Changes to tensors and sharded_impls to get working
524d4c1  Use valid PP and TP combo
93dad69  Formatting cleanup
934d556  Cleanup from feedback
b421fd7  Changes to tensors to make tests pass
4a198ea  Changes to sharded_impls to have tests pass
db43786  Changes to tests to have them pass
f9bd9fd  Initial commit of tests
fdef39a  Cleaned up TODO
0a12c71  ops.replicate support
b549558  Remove ambiguity of calling ops.replicate with a ShardedTensor
817569c  Further tests, and fix typo in file name
cc6aaa2  Wrap transfer_if_needed into a decorator
b5d471a  Testing wrapped decorator
8c54c47  Expand transfer_if_needed to work with arbitrary number of tensors.
627c73c  Throw error when tensors are on different devices and none are pinned.
7907ce1  Fix missing arg to constructor
f2c5615  Changes to transfer_if_needed
0d9af59  Stop pinning all_reduce and all_gather results
251a70a  Changes to decorator
d1fcf81  override_w_transfer pins unary ops result if input is pinned
2ca09d4  Correct year in copyright header
b29a2df  Revert "Changes to tests to have them pass"
97c25e5  Remove setter for shards
73e5481  Added defaults for devices and devices_pinned
5f8e7ec  Rename devices_pinned to pinned
81a1435  Fix flattening and unflattening of unreduced_tensor
4f3b434  Stop writing pipeline parallelism related data to archive
cf23568  Disable changes to is_deep_equal
1a1091b  Cleanup and documentation
da2db54  Missed params for initializer
649018f  Default value for .pinned
2af7cd6  example cleanup
10e4d9b  Remove setter and rewrite example
bb87dd0  Add clone constructor for ShardedTensor
a4bc196  Change from overriding the decorator to overriding imports
f9340fe  Change wrapper to specify devices on output
a14a283  Remove unnecessary passing of arguments
5e147ef  fix return logic
4ff8667  Cleanup
07d8ed8  Expand wrapper functionality to handle collections of tensors in args…
9e1aa62  More tests
b494532  Move transfer wrapper to init to capture all imports of ops
37fc560  Don't transfer for Index_put and index_copy
d72f1c6  More tests
4817485  Changes to wrapper to have reshard_like pass
5dbff38  Don't wrap transfer_to_logical_device
24a33e9  pre-commit cleanup
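Many of the commits above center on a transfer_if_needed decorator that reconciles the placement of sharded operands before an op runs. The toy sketch below only illustrates the behaviour those commit messages describe, using hypothetical class and function bodies; it is not the sharktank implementation. Operands whose device lists disagree are moved to the pinned operand's devices, and an error is raised when they disagree and none is pinned.

from dataclasses import dataclass
from functools import wraps


@dataclass
class ToyShardedTensor:
    # Hypothetical stand-in for a sharded tensor: per-shard data, a device list,
    # and a `pinned` flag meaning the placement must not be changed.
    shards: list
    devices: tuple
    pinned: bool = False


def transfer_if_needed(op):
    # Toy decorator: reconcile operand placement, then call `op`.
    @wraps(op)
    def wrapper(*tensors: ToyShardedTensor):
        if len({t.devices for t in tensors}) > 1:
            pinned = [t for t in tensors if t.pinned]
            if not pinned:
                raise ValueError(
                    "Operands are on different devices and none is pinned."
                )
            target = pinned[0].devices
            # Stand-in for a real shard-by-shard device transfer.
            tensors = tuple(
                t if t.devices == target else ToyShardedTensor(t.shards, target, t.pinned)
                for t in tensors
            )
        return op(*tensors)

    return wrapper

Per the later commit messages, the real wrapper is further extended to handle arbitrary numbers of operands and collections of tensors nested in the arguments, and ops such as Index_put, index_copy, and transfer_to_logical_device are exempted from the automatic transfer.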
sharktank/sharktank/examples/pipeline/export_ppffn_net.py
New file: 120 additions, 0 deletions
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Example program to export a sharded FFN network like what is found in
a typical transformer layer. This is used for developing and testing various
tooling flows with a scaled down example.

Generate MLIR and a randomly initialized IRPA file with:

  python -m sharktank.examples.pipeline.export_ppffn_net \
    --output-irpa-file=/tmp/ffn.irpa /tmp/ffn.mlir
"""

import math
from typing import Tuple

import torch
import torch.nn as nn

from ...layers import *
from ... import ops
from ...types import *

from iree.turbine.aot import DeviceAffinity, DeviceTensorTrait, export


def create_theta(dim: int, shard_count: int, num_layers: int, save_path):
    split_size = dim // shard_count

    # One square FFN weight per layer, split column-wise across `shard_count` devices.
    weights = []
    for layer in range(num_layers):
        _weight = torch.rand(dim, dim, dtype=torch.float16) / math.sqrt(dim)
        weights.append(
            SplitPrimitiveTensor(
                name=f"w.{layer}",
                shard_dim=1,
                ts=_weight.split(split_size, dim=1),
                devices=[_d for _d in range(shard_count)],
            )
        )

    ds = Dataset({}, Theta(weights))
    ds.save(save_path)


def pipeline_parallelize_theta(theta: Theta, pp_count: int) -> Theta:
    """Reassign each layer's shards to the devices of its pipeline stage."""
    num_layers = len(theta.tensor("w"))
    shard_count = theta.tensor("w", "0").shard_count

    for layer, weight in ((int(l), w) for l, w in theta.tensor("w").items()):
        pp_group = int(layer * pp_count / num_layers)
        zero_4_group = shard_count * pp_group  # device index of shard 0 in this stage
        weight.devices = tuple(i + zero_4_group for i in range(shard_count))
        for i, shard in enumerate(weight.shards):
            DeviceTensorTrait(weight.devices[i]).set(shard._data)

    return theta


class PPFFN(ThetaLayer):
    def forward(self, x: torch.Tensor):
        num_layers = len(self.theta.tensor("w"))
        shard_count = self.theta.tensor("w", "0").shard_count

        # Replicate the input across the first stage's devices, then chain
        # linear + all_reduce through every layer.
        x = ReplicatedTensor(
            ts=x, shard_count=shard_count, devices=[_d for _d in range(shard_count)]
        )
        for layer in range(num_layers):
            weight: SplitPrimitiveTensor = self.theta.tensor("w", str(layer))
            x: ReplicatedTensor = ops.all_reduce(ops.linear(x, weight))

        return x


def main(raw_args=None):
    from ...utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "output_file",
        type=str,
        nargs="?",
        default="-",
        help="Output file to save MLIR to",
    )
    cli.add_output_dataset_options(parser)
    args = cli.parse(parser, args=raw_args)

    bs = 16
    sl = 128
    primary_dim = 128 * 2**5
    shard_count = 2
    num_layers = 40
    create_theta(primary_dim, shard_count, num_layers, save_path=args.output_irpa_file)

    pp_count = 4
    ds = Dataset.load(args.output_irpa_file)
    root_theta = pipeline_parallelize_theta(ds.root_theta, pp_count)

    mdl = PPFFN(root_theta)

    example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16)
    ep = torch.export.export(mdl, (example_arg,))  # optionally: strict=False
    cm = export(ep, arg_device={0: DeviceAffinity(0)})

    if args.output_file == "-":
        print(cm.mlir_module)
    else:
        with open(args.output_file, "wt") as f:
            f.write(str(cm.mlir_module))


if __name__ == "__main__":
    main()
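For orientation, here is a minimal standalone sketch of the layer-to-stage mapping that pipeline_parallelize_theta computes, using the constants hard-coded in main() (num_layers=40, pp_count=4, shard_count=2). This is plain Python with no sharktank imports and only reproduces the arithmetic:

num_layers, pp_count, shard_count = 40, 4, 2

for layer in range(num_layers):
    pp_group = int(layer * pp_count / num_layers)   # pipeline stage owning this layer
    first_device = shard_count * pp_group           # device index of shard 0 in that stage
    devices = tuple(first_device + i for i in range(shard_count))
    if layer % 10 == 0:
        print(f"layer {layer:2d} -> stage {pp_group}, devices {devices}")

# Prints:
# layer  0 -> stage 0, devices (0, 1)
# layer 10 -> stage 1, devices (2, 3)
# layer 20 -> stage 2, devices (4, 5)
# layer 30 -> stage 3, devices (6, 7)

So the example spreads its 40 layers over 4 pipeline stages of 2 tensor-parallel shards each, i.e. 8 logical devices with 10 consecutive layers per stage.

Numerically the sharding should not change the result: assuming ops.linear follows the torch.nn.functional.linear convention (y = x @ W.T) and all_reduce merely sums the partial products of the column-split weight, PPFFN.forward is equivalent to the following unsharded reference (a hypothetical helper for sanity checking, not part of the PR):

import torch

def reference_ffn(x: torch.Tensor, weights: list[torch.Tensor]) -> torch.Tensor:
    # Chain the linear layers on a single device; no replication, no all_reduce.
    for w in weights:
        x = torch.nn.functional.linear(x, w)
    return x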