
Commit be16338

Place args for all gather/reduce on devices before the op to avoid CSE and excessive copying (#171)
We encode the device/shard information in the flow.tensor.transfer/transfer_to_logical_device operation. If we then do an all-gather or an all-reduce, CSE is happy to collapse the identical expressions into one, so the all-gather/all-reduce is performed on one device and the result is copied to the rest. We want each device to do its own all-gather/all-reduce, so the arguments are now placed on each device before the op. There is no easy way to test the desired effect directly, but we at least test for correctness at the PyTorch level. This change also adds the all_reduce op, which is currently not used anywhere, and expands the elementwise op to support a variable number of tensor arguments.
1 parent 2fecbfd commit be16338
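For reference, a minimal PyTorch-level sketch (not sharktank code; plain tensors stand in for shards and a Python list stands in for the devices) of the result every device should hold after an all-gather and an all-reduce. The per-device transfers added in the diffs below exist so the compiled program actually performs this computation on every device instead of computing it once and broadcasting the result.

import torch

# One shard per "device"; a Python list models the set of devices.
shards = [torch.arange(4.0) + 10 * i for i in range(3)]

# all_gather: every device ends up with the concatenation of all shards.
gathered_per_device = [torch.cat(shards, dim=0) for _ in range(len(shards))]

# all_reduce: every device ends up with the elementwise sum of all shards.
reduced_per_device = [torch.stack(shards).sum(dim=0) for _ in range(len(shards))]

assert all(torch.equal(g, gathered_per_device[0]) for g in gathered_per_device)
assert all(torch.equal(r, reduced_per_device[0]) for r in reduced_per_device)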

7 files changed: +206 -34 lines changed


sharktank/sharktank/ops/_registry.py

+35-1
@@ -18,6 +18,7 @@
 
 __all__ = [
     "AllOfExprs",
+    "AllOfExprsVariadic",
     "AllOfType",
     "AnyOfType",
     "IsOfType",
@@ -65,7 +66,8 @@ def __call__(self, *args: type) -> bool:
 
 
 class AllOfExprs(BoolTypeExpr):
-    """Returns True if all types match their respective boolean type expression.
+    """Returns True if all type arguments match their respective boolean type
+    expression.
 
     ```python
     # True. int == int and str in (float, str).
@@ -87,6 +89,38 @@ def expr(*types: type):
         super().__init__(expr)
 
 
+class AllOfExprsVariadic(BoolTypeExpr):
+    """Returns True if all type arguments match their respective boolean type
+    expression and any remaining trailing arguments match the last type expression.
+
+    ```python
+    # True. int == int
+    # str in (float, str).
+    # float in (float, str).
+    AllOfExprsVariadic(IsOfType(int), IsOfType(float, str))(int, str, float)
+
+    # False. str is not in (int, float).
+    AllOfExprsVariadic(IsOfType(int), IsOfType(int, float))(int, float, str, int)
+    ```
+    """
+
+    def __init__(self, *exprs: BoolTypeExpr):
+        if len(exprs) == 0:
+            raise ValueError("At least one expression is required.")
+        self._exprs = list(exprs)
+
+        def expr(*types: type):
+            if len(types) < len(self._exprs):
+                return False
+            exprs = self._exprs
+            if len(types) > len(exprs):
+                # pad with the trailing expression.
+                exprs = exprs + ([exprs[-1]] * (len(types) - len(self._exprs)))
+            return all([e(t) for e, t in zip(exprs, types)])
+
+        super().__init__(expr)
+
+
 class AllOfType(BoolTypeExpr):
     """Returns True if all of the types are from a set of types.
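As an illustration of the matching rule AllOfExprsVariadic implements, a standalone sketch (hypothetical helper names, not the sharktank class): the list of type predicates is padded with its last entry so every trailing argument is checked against the final expression.

# Standalone sketch of the variadic matching rule (hypothetical helper, not the
# sharktank class): pad the predicate list with its last entry so any number of
# trailing arguments is checked against the final type expression.
def matches_variadic(predicates, types):
    if len(types) < len(predicates):
        return False
    padded = predicates + [predicates[-1]] * (len(types) - len(predicates))
    return all(pred(t) for pred, t in zip(padded, types))


is_int = lambda t: t is int
is_int_or_float = lambda t: t in (int, float)

print(matches_variadic([is_int, is_int_or_float], [int, float, float, int]))  # True
print(matches_variadic([is_int, is_int_or_float], [int, float, str]))         # False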

sharktank/sharktank/ops/default_impls.py

+42-1
@@ -16,8 +16,9 @@
 
 from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor
 from ..types.tensors import unbox_tensor, AnyTensor
-from ._registry import AllOfType, AllOfExprs, IsOfType
+from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType
 from .signatures import *
+import shark_turbine.ops.iree
 
 
 @cat.override(AllOfType(Tensor, PrimitiveTensor))
@@ -80,6 +81,39 @@ def elementwise_binary(operator, x, y):
     return operator(x, y)
 
 
+@elementwise.override(
+    AllOfExprsVariadic(
+        IsOfType(Tensor, InferenceTensor),
+        IsOfType(Tensor, InferenceTensor, Number),
+        IsOfType(Tensor, InferenceTensor, Number),
+    )
+)
+def elementwise_variadic(operator, x, y, *args):
+    """Folds by successively applying the binary operator from left to right until
+    exhaustion.
+
+    Match a variable number of tensor/number arguments with at least 3 such arguments.
+
+    Example matches
+    ```
+    (Tensor, Tensor, Tensor)
+    (Tensor, DefaultPrimitiveTensor, float),
+    (SplitPrimitiveTensor, ReplicatedTensor, int, Tensor)
+    ```
+
+    Will not match
+    ```
+    (Tensor)
+    (Tensor, Tensor)
+    (int, Tensor, Tensor)
+    ```
+    """
+    res = elementwise(operator, x, y)
+    for arg in args:
+        res = elementwise(operator, res, arg)
+    return res
+
+
 # Embedding Lookup
 @embedding_lookup.override(Tensor, Tensor)
 def embedding_lookup_default(input, embedding_matrix, dtype: dtype):
@@ -234,6 +268,13 @@ def permute(tensor: Tensor, dims: List[int]):
     return torch.permute(torch_tensor, dims)
 
 
+@transfer_to_logical_device.override(Tensor)
+def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
+    return shark_turbine.ops.iree.transfer_to_logical_device(
+        f"{ordinal}", unbox_tensor(tensor)
+    )
+
+
 # Sharded default impls (do nothing).
 
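As a usage sketch with plain torch tensors (the real op also accepts InferenceTensor arguments and numbers), the new variadic elementwise override is simply a left fold of the binary operator:

# Left-fold semantics of the variadic elementwise override, shown with plain
# torch tensors: elementwise(op, a, b, c, d) == op(op(op(a, b), c), d).
import functools

import torch

a, b, c, d = (torch.rand(2, 3) for _ in range(4))
folded = functools.reduce(torch.add, [b, c, d], a)
torch.testing.assert_close(folded, ((a + b) + c) + d)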

sharktank/sharktank/ops/sharded_impls.py

+29-14
@@ -31,15 +31,30 @@
 def all_gather_split(
     input: SplitPrimitiveTensor, *, dim: int | None
 ) -> ReplicatedTensor:
-    assert (
-        dim is None
-    ), "gather dimension other than `input.shard_dim` is not supported."
-    # TODO: figure out how to avoid common sub-expression elimination to not
-    # merge all these into one.
-    # Even if we place each resulting shard inside of ReplicatedTensor on a
-    # distinct logical device with an explicit operation, CSE should still
-    # collapse them.
-    shards = [sharded_cat(input) for i in range(input.shard_count)]
+    dim = input.shard_dim if dim is None else dim
+    # For each device move the shards to it and do a concatenation.
+    # If we don't move first, common sub-expression elimination is free to collapse all
+    # concatenations into one and then copy to all devices, which is not what we want.
+    shards = [
+        cat([transfer_to_logical_device(shard, i) for shard in input.shards], dim=dim)
+        for i in range(input.shard_count)
+    ]
+    return ReplicatedTensor(ts=shards)
+
+
+@all_reduce.override(SplitPrimitiveTensor)
+def all_reduce_split(
+    input: SplitPrimitiveTensor,
+) -> ReplicatedTensor:
+    # For each device move the shards to it and do a reduction.
+    # If we don't move first, common sub-expression elimination is free to collapse all
+    # reductions into one and then copy to all devices, which is not what we want.
+    shards = [
+        elementwise(
+            torch.add, *[transfer_to_logical_device(shard, i) for shard in input.shards]
+        )
+        for i in range(input.shard_count)
+    ]
     return ReplicatedTensor(ts=shards)
 
 
@@ -692,15 +707,15 @@ def reshard_like_split_to_split(
     return tensor
 
 
-# Sharded sum.
-
-
 @sharded_cat.override(SplitPrimitiveTensor)
 def sharded_cat_unsharded(maybe_sharded: SplitPrimitiveTensor):
     shard_ts = [t.as_torch() for t in maybe_sharded.shards]
     return torch.cat(shard_ts, dim=maybe_sharded.shard_dim)
 
 
+# Sharded sum.
+
+
 def _sharded_sum_sharded(tensor: ShardedTensor) -> Tensor:
     accum = tensor.shards[0].as_torch()
     for shard in tensor.shards[1:]:
@@ -709,13 +724,13 @@ def _sharded_sum_sharded(tensor: ShardedTensor) -> Tensor:
 
 
 @sharded_sum.override(SplitPrimitiveTensor)
-def sharded_sum_split(maybe_sharded: SplitPrimitiveTensor):
+def sharded_sum_split(maybe_sharded: SplitPrimitiveTensor) -> Tensor:
    # TODO: Should implement as an all reduce.
    return _sharded_sum_sharded(maybe_sharded)
 
 
 @sharded_sum.override(UnreducedTensor)
-def sharded_sum_unreduced(maybe_sharded: UnreducedTensor):
+def sharded_sum_unreduced(maybe_sharded: UnreducedTensor) -> Tensor:
     return _sharded_sum_sharded(maybe_sharded)
 
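The structure both overrides share is: first transfer every shard to the target device, then perform the concatenation/reduction there, once per device. A rough sketch of that shape, with a hypothetical transfer_to_device stub standing in for transfer_to_logical_device; the point is that the compiler sees N device-tagged computations instead of one result to broadcast.

# Rough sketch of the per-device pattern (hypothetical stub, not the sharktank op):
# every device gets its own copies of all shards and then does its own reduction.
import torch


def transfer_to_device(tensor, ordinal):
    # Placeholder for transfer_to_logical_device: in the compiled program this tags
    # the value with a device ordinal, which keeps the N reductions distinct.
    return tensor.clone()


shards = [torch.rand(2, 3) for _ in range(4)]
reduced_per_device = [
    sum(transfer_to_device(shard, i) for shard in shards)
    for i in range(len(shards))
]
assert all(torch.allclose(r, reduced_per_device[0]) for r in reduced_per_device)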

sharktank/sharktank/ops/signatures.py

+39
@@ -18,6 +18,7 @@
 
 __all__ = [
     "all_gather",
+    "all_reduce",
     "cat",
     "conv2d",
     "elementwise",
@@ -38,6 +39,7 @@
     "scaled_dot_product_attention",
     "sharded_cat",
     "sharded_sum",
+    "transfer_to_logical_device",
     "unshard",
 ]
 
@@ -46,6 +48,7 @@
 
 @overridable
 def all_gather(maybe_sharded: AnyTensor, *, dim: int | None = None) -> AnyTensor:
+    "Gather/concatenate on all devices along dimension `dim`."
     ...
 
 
@@ -62,6 +65,23 @@ def _all_gather_trampoline(
         d.fail(tensors)
 
 
+@overridable
+def all_reduce(tensor: AnyTensor) -> AnyTensor:
+    "Reduce on all devices."
+    ...
+
+
+@all_reduce.trampoline
+def _all_reduce_trampoline(d: SignatureDispatcher, tensor: AnyTensor):
+    tensors = (tensor,)
+    for override in d.find_overrides(tensors):
+        result = override(tensor)
+        if result is not NotImplemented:
+            return override, result
+    else:
+        d.fail(tensors)
+
+
 @overridable
 def cat(tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0) -> AnyTensor:
     ...
@@ -616,6 +636,25 @@ def _sharded_sum_trampoline(d: SignatureDispatcher, maybe_sharded: AnyTensor):
         d.fail(tensors)
 
 
+@overridable
+def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
+    """Transfer the tensor to a device with ordinal `ordinal`."""
+    ...
+
+
+@transfer_to_logical_device.trampoline
+def _transfer_to_logical_device(
+    d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
+):
+    tensors = (tensor,)
+    for override in d.find_overrides(tensors):
+        result = override(tensor, ordinal)
+        if result is not NotImplemented:
+            return override, result
+    else:
+        d.fail(tensors)
+
+
 @overridable
 def unshard(tensor: AnyTensor) -> AnyTensor:
     """Return the tensor that has the same elements and shape, but is not sharded."""

sharktank/sharktank/types/tensors.py

+6-3
@@ -30,7 +30,6 @@
 from shark_turbine.aot import (
     ExternalTensorTrait,
 )
-from shark_turbine.ops.iree import transfer_to_logical_device
 from ..utils import tree as tree_utils
 
 from ..utils.io import ShardedArchiveBuilder
@@ -618,13 +617,15 @@ def __init__(
         name: str = UnnamedTensorName,
         shape: Optional[list[int]],
     ):
+        from ..ops import transfer_to_logical_device
+
         assert len(ts) > 0
         assert shard_dim is None or len(ts[0].shape) > shard_dim
         super().__init__(name=name, shape=shape, shard_dim=shard_dim)
         self._shards: tuple[DefaultPrimitiveTensor] = tuple(
             DefaultPrimitiveTensor(
                 name=f"{name}.shard.{i}",
-                data=transfer_to_logical_device(f"{i}", unbox_tensor(t)),
+                data=transfer_to_logical_device(t, i),
             )
             for i, t in enumerate(ts)
         )
@@ -867,6 +868,8 @@ def __init__(
         will be replicated that many times.
         """
 
+        from ..ops import transfer_to_logical_device
+
         if isinstance(ts, torch.Tensor):
             assert shard_count is not None
             ts = [ts] * shard_count
@@ -884,7 +887,7 @@ def __init__(
         self._shards: tuple[DefaultPrimitiveTensor] = tuple(
             DefaultPrimitiveTensor(
                 name=f"{name}.shard.{i}",
-                data=transfer_to_logical_device(f"{i}", unbox_tensor(t)),
+                data=transfer_to_logical_device(t, i),
            )
            for i, t in enumerate(ts)
        )

sharktank/tests/ops/ops_test.py

+21
@@ -8,6 +8,7 @@
 
 import torch
 import torch.nn.functional as F
+from parameterized import parameterized
 
 from sharktank import ops
 from sharktank.types import *
@@ -34,6 +35,26 @@ def testBroadcastDims(self):
         assert res[1] == 2
 
 
+class ElementwiseTest(unittest.TestCase):
+    @parameterized.expand(
+        [
+            (torch.add,),
+            (torch.div,),
+            (torch.fmin,),
+            (torch.fmax,),
+            (torch.sub),
+        ]
+    )
+    def testMultiArgOperators(self, operator):
+        a = torch.rand(2, 3, 4, dtype=torch.float32)
+        b = torch.rand(2, 3, 4, dtype=torch.float32)
+        c = torch.rand(2, 3, 4, dtype=torch.float32)
+        d = torch.rand(2, 3, 4, dtype=torch.float32)
+        expected_result = operator(operator(operator(a, b), c), d)
+        actual_result = ops.elementwise(operator, a, b, c, d)
+        torch.testing.assert_close(actual_result, expected_result)
+
+
 class EqualTest(unittest.TestCase):
     def testEqualTorchTensors(self):
         a = torch.rand(2, 3, dtype=torch.float32)
