Skip to content

Commit 2705fde

Browse files
authored
Speed-up all_reduce by reducing buffer copies (#957)
Reduce copying and speed up execution since copies can't happen in parallel.
1 parent 5432299 commit 2705fde

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

sharktank/sharktank/ops/sharded_impls.py

Lines changed: 15 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -225,20 +225,22 @@ def all_gather_split(
225225
def all_reduce_split_or_unreduced(
226226
input: Union[SplitPrimitiveTensor, UnreducedTensor],
227227
) -> ReplicatedTensor:
228-
# For each device move the shards to it and do a reduction.
229-
# If we don't move first, common sub-expression elimination is free to collapse all
230-
# reductions into one and then copy to all devices, which is not what we want.
228+
reduced = functools.reduce(
229+
lambda x, y: elementwise(torch.add, x, y),
230+
[
231+
(
232+
transfer_to_logical_device(shard, input.devices[0])
233+
if i != 0
234+
else barrier_on_logical_device(shard, input.devices[0])
235+
)
236+
for i, shard in enumerate(input.shards)
237+
],
238+
)
231239
shards = [
232-
functools.reduce(
233-
lambda x, y: elementwise(torch.add, x, y),
234-
[
235-
(
236-
barrier_on_logical_device(shard, input.devices[i])
237-
if i == j
238-
else transfer_to_logical_device(shard, input.devices[i])
239-
)
240-
for j, shard in enumerate(input.shards)
241-
],
240+
(
241+
transfer_to_logical_device(reduced, input.devices[i])
242+
if i != 0
243+
else barrier_on_logical_device(reduced, input.devices[0])
242244
)
243245
for i in range(input.shard_count)
244246
]

0 commit comments

Comments (0)