@@ -225,20 +225,22 @@ def all_gather_split(
 def all_reduce_split_or_unreduced(
     input: Union[SplitPrimitiveTensor, UnreducedTensor],
 ) -> ReplicatedTensor:
-    # For each device move the shards to it and do a reduction.
-    # If we don't move first, common sub-expression elimination is free to collapse all
-    # reductions into one and then copy to all devices, which is not what we want.
+    reduced = functools.reduce(
+        lambda x, y: elementwise(torch.add, x, y),
+        [
+            (
+                transfer_to_logical_device(shard, input.devices[0])
+                if i != 0
+                else barrier_on_logical_device(shard, input.devices[0])
+            )
+            for i, shard in enumerate(input.shards)
+        ],
+    )
     shards = [
-        functools.reduce(
-            lambda x, y: elementwise(torch.add, x, y),
-            [
-                (
-                    barrier_on_logical_device(shard, input.devices[i])
-                    if i == j
-                    else transfer_to_logical_device(shard, input.devices[i])
-                )
-                for j, shard in enumerate(input.shards)
-            ],
+        (
+            transfer_to_logical_device(reduced, input.devices[i])
+            if i != 0
+            else barrier_on_logical_device(reduced, input.devices[0])
         )
         for i in range(input.shard_count)
     ]
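
The new version reduces all shards once onto the first device and then transfers the single reduced result to every device, instead of repeating the full reduction on each device as before. A minimal sketch of that reduce-then-broadcast pattern, using plain torch tensors in place of the sharded tensor types (the helper name all_reduce_then_broadcast is hypothetical, not part of this codebase):

import functools

import torch


def all_reduce_then_broadcast(shards: list[torch.Tensor]) -> list[torch.Tensor]:
    # Reduce all shards once (analogous to `reduced` in the diff above) ...
    reduced = functools.reduce(torch.add, shards)
    # ... then return one copy of the result per shard position, mirroring the
    # transfer of `reduced` to each device.
    return [reduced.clone() for _ in range(len(shards))]


# Example with three shards; each output entry is tensor([ 9., 12.]).
shards = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), torch.tensor([5.0, 6.0])]
print(all_reduce_then_broadcast(shards))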