
Commit 3ad7abb

Add pipeline parallelism support to ops (#1008)
**Goal:** Introduce pipeline parallelism without requiring changes to the weight IRPA files or to the forward passes of the individual layers (see `PPFFN.forward` in the example file).

**Changes**
- `ShardedTensor` now explicitly stores which device each of its shards should live on in a `.device` attribute. Previously the implicit convention was that shard `i` lived on device `i`.
- A `ShardedTensor` can also be pinned to specific devices via a `.pinned` attribute, e.g. for weights, or left unpinned to signal that it may be moved if needed.
- Operators call a helper function to check whether either tensor's shards need to be transferred so that all shards end up on matching devices. E.g. `ops.foo(t1 on devs [1,2,3], t2 pinned on devs [5,6,7])` would transfer the shards of `t1` onto devices `[5,6,7]` before performing the operation (a toy sketch of this rule follows below).
- Several helper functions in `ops` can take a plain `torch.Tensor` and therefore cannot know which devices to place the result on, e.g. `def replicate(input: AnyTensor, count: int) -> ShardedTensor:`. I've added `devices` and `pinned` as extra parameters, with defaults that keep the current behaviour unchanged.
- Added a wrapper around all functions imported from `ops.signatures` into `sharded_impls` to handle transfer and pinning automatically when they are called with `ShardedTensor` subclasses. This makes device parallelism work mostly invisibly, without needing to modify the functions in `sharded_impls.py`.
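Below is a minimal, self-contained sketch of the device-matching rule described above. It is illustrative only: `ToyShardedTensor` and `match_devices` are hypothetical names made up for this sketch, not the real `ShardedTensor` type or the helper used in `sharded_impls.py`.

```python
# Toy illustration of the pinned/unpinned transfer rule; not sharktank's real API.
from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass
class ToyShardedTensor:
    shards: Tuple[torch.Tensor, ...]
    devices: Tuple[int, ...]  # device index each shard should live on
    pinned: bool = False      # pinned tensors (e.g. weights) must not be moved


def match_devices(a: ToyShardedTensor, b: ToyShardedTensor):
    """Move the unpinned operand so both operands' shards end up on matching
    devices; pinned operands stay where they are."""
    if a.devices == b.devices:
        return a, b
    if a.pinned and b.pinned:
        raise ValueError("operands are pinned to different devices")
    src, dst = (b, a) if a.pinned else (a, b)
    # A real implementation would issue device-to-device transfers here;
    # this sketch only rewrites the device metadata.
    moved = ToyShardedTensor(src.shards, devices=dst.devices, pinned=src.pinned)
    return (dst, moved) if a.pinned else (moved, dst)


# E.g. t1 on devices (1, 2, 3), t2 pinned on devices (5, 6, 7):
t1 = ToyShardedTensor(tuple(torch.zeros(2) for _ in range(3)), devices=(1, 2, 3))
t2 = ToyShardedTensor(tuple(torch.zeros(2) for _ in range(3)), devices=(5, 6, 7), pinned=True)
t1, t2 = match_devices(t1, t2)
assert t1.devices == (5, 6, 7)  # t1 is the one that moves before the op runs
```

In the actual change this check lives in the wrapper added around the `ops.signatures` functions in `sharded_impls`, so callers never invoke it explicitly.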
Parent: 3f7b69b

6 files changed: +1056 -40 lines

New file (110 added lines):

# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Example program to export a sharded FFN network like what is found in
a typical transformer layer. This is used for developing and testing various
tooling flows with a scaled down example.

Generate MLIR and a randomly initialized IRPA file with:

    python -m sharktank.examples.sharding.export_ffn_net \
        --output-irpa-file=/tmp/ffn.irpa /tmp/ffn.mlir
"""

import math

import torch

from ...layers import *
from ... import ops
from ...types import *

from iree.turbine.aot import DeviceAffinity, DeviceTensorTrait, export


def create_theta(dim: int, shard_count: int, num_layers: int, save_path):
    split_size = dim // shard_count
    weights = []
    for layer in range(num_layers):
        _weight = torch.rand(dim, dim, dtype=torch.float16) / math.sqrt(dim)
        weights.append(
            SplitPrimitiveTensor(
                name=f"w.{layer}", shard_dim=1, ts=_weight.split(split_size, dim=1)
            )
        )
    ds = Dataset({}, Theta(weights))
    ds.save(save_path)


def pipeline_parallelize_theta(theta: Theta, pp_count: int) -> Theta:
    num_layers = len(theta.tensor("w"))
    shard_count = theta.tensor("w", "0").shard_count
    for layer in list(theta.tensor("w").keys()):
        weight: ShardedTensor = theta.tensor("w", layer)
        # Map this layer to a pipeline group, then to a contiguous block of devices.
        pp_group = int(int(layer) * pp_count / num_layers)
        zero_4_group = shard_count * pp_group
        devices = tuple(i + zero_4_group for i in range(shard_count))

        shards = weight.shards
        for i, shard in enumerate(shards):
            DeviceTensorTrait(devices[i]).set(shard._data)
        # Weights are pinned to their devices so they are never transferred at runtime.
        theta.tensor("w")[layer] = weight.clone(ts=shards, devices=devices, pinned=True)
    return theta


class PPFFN(ThetaLayer):
    def forward(self, x: torch.Tensor):
        num_layers = len(self.theta.tensor("w"))
        shard_count = self.theta.tensor("w", "0").shard_count

        x = ReplicatedTensor(ts=x, shard_count=shard_count)
        for layer in range(num_layers):
            weight: SplitPrimitiveTensor = self.theta.tensor("w", str(layer))
            x: ReplicatedTensor = ops.all_reduce(ops.linear(x, weight))

        return x


def main(raw_args=None):
    from ...utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "output_file",
        type=str,
        nargs="?",
        default="-",
        help="Output file to save MLIR to",
    )
    cli.add_output_dataset_options(parser)
    args = cli.parse(parser, args=raw_args)

    bs = 16
    sl = 128
    primary_dim = 128 * 2**5
    shard_count = 2
    num_layers = 40
    create_theta(primary_dim, shard_count, num_layers, save_path=args.output_irpa_file)

    pp_count = 4
    ds = Dataset.load(args.output_irpa_file)
    root_theta = pipeline_parallelize_theta(ds.root_theta, pp_count)

    mdl = PPFFN(root_theta)

    example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16)
    ep = torch.export.export(mdl, (example_arg,))  # , strict=False)
    cm = export(ep, arg_device={0: DeviceAffinity(0)})

    if args.output_file == "-":
        print(cm.mlir_module)
    else:
        with open(args.output_file, "wt") as f:
            f.write(str(cm.mlir_module))


if __name__ == "__main__":
    main()
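For reference, the placement produced by `pipeline_parallelize_theta` follows directly from `pp_group = int(layer * pp_count / num_layers)` and `devices = tuple(shard_count * pp_group + i for i in range(shard_count))`. The snippet below simply replays that arithmetic for the values used in `main()` (40 layers, `pp_count = 4`, `shard_count = 2`); only the print loop is added for illustration.

```python
# Replays the placement arithmetic from pipeline_parallelize_theta above
# for num_layers=40, pp_count=4, shard_count=2 (the values used in main()).
num_layers, pp_count, shard_count = 40, 4, 2

for layer in range(num_layers):
    pp_group = int(layer * pp_count / num_layers)      # pipeline group 0..3
    first_device = shard_count * pp_group              # 0, 2, 4, 6
    devices = tuple(first_device + i for i in range(shard_count))
    if layer % 10 == 0:
        print(f"layer {layer:2d} -> pipeline group {pp_group}, devices {devices}")

# layer  0 -> pipeline group 0, devices (0, 1)
# layer 10 -> pipeline group 1, devices (2, 3)
# layer 20 -> pipeline group 2, devices (4, 5)
# layer 30 -> pipeline group 3, devices (6, 7)
```

Each pipeline group therefore owns a contiguous block of `shard_count` devices; since the weights are cloned with `pinned=True`, it is the unpinned activations that get transferred between groups at runtime.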
