
Commit 572a546

[llama] Added the fused rotary embedding kernel
Reworked rotary embedding application to be performed via a custom kernel. This includes dropping `static_tables`, which was largely unused, for the sake of maintainability. A simple numerical test is included; under the hood no numerical change should occur, and the existing baseline-vs-Hugging Face comparison remains unchanged.
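For reference, the rotation the fused kernel applies can be written in a few lines of plain PyTorch (illustrative only; the function and argument names are hypothetical, and the shapes follow the kernel's [bs, sl, heads, dims] input with a per-position angle table whose last dimension matches the input's):

    import torch

    def rotary_reference(xt: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
        # xt: [bs, sl, heads, dims]; table: [sl, dims] of angles, where lanes
        # 2k and 2k+1 share one frequency.
        angles = table[None, :, None, :]                # broadcast over batch and heads
        cos, sin = torch.cos(angles), torch.sin(angles)
        xr, xi = xt[..., 0::2], xt[..., 1::2]           # even lanes = real, odd lanes = imag
        out = torch.empty_like(xt)
        out[..., 0::2] = xr * cos[..., 0::2] - xi * sin[..., 0::2]
        out[..., 1::2] = xi * cos[..., 1::2] + xr * sin[..., 1::2]
        return out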
1 parent d520bd1 commit 572a546

7 files changed: +197 -60 lines changed

sharktank/sharktank/examples/export_paged_llm_v1.py

Lines changed: 0 additions & 1 deletion
@@ -79,7 +79,6 @@ def main():
         hp,
         tensor_parallelism_size=tensor_parallelism_size,
         use_hf=False,
-        static_tables=False,  # Rely on the compiler for hoisting tables.
         kv_cache_type="direct" if args.bs == [1] else "paged",
         attention_kernel=args.attention_kernel,
         block_seq_stride=args.block_seq_stride,

sharktank/sharktank/kernels/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 from .mmt_block_scaled_offset_q4 import *
 from .mmt_block_scaled_q8 import *
 from .mmt_super_block_scaled_offset_q4 import *
+from .rotary import *
 from .batch_matmul_transpose_b import *
 from .conv_2d_nchw_fchw import *
 from .pooling_nchw_sum import *

sharktank/sharktank/kernels/rotary.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from sharktank.kernels.base import *
+
+__all__ = [
+    "apply_rotary_embedding",
+]
+
+
+@CustomOp.register(library=LIBRARY)
+class apply_rotary_embedding(CustomOp):
+
+    signature = "apply_rotary_embedding(Tensor input, Tensor table) -> (Tensor)"
+
+    def select(self, ksel: KernelSelection):
+        inputs_desc = ksel.arg_tensor(0)
+        table_desc = ksel.arg_tensor(1)
+        out_desc = ksel.return_new_tensor(
+            inputs_desc.t.shape, dtype=inputs_desc.t.dtype
+        )
+        specialize_all_known_dims(inputs_desc)
+        specialize_all_known_dims(table_desc)
+        specialize_all_known_dims(out_desc)
+
+    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+
+        input = kb.arg_value(0)
+        table = kb.arg_value(1)
+
+        input_tensor_type = RankedTensorType(input.type)
+        table_tensor_type = RankedTensorType(table.type)
+
+        input_asm_type, input_ident, input_dtype = unpack_tensor_type(input.type)
+        table_asm_type, table_ident, table_dtype = unpack_tensor_type(table.type)
+
+        assert input_dtype == table_dtype
+
+        # Generate specialization signature and types.
+        bs = input.type.shape[0]
+        sl = input.type.shape[1]
+        sl = "D" if sl < 0 else sl
+        heads = input.type.shape[2]
+
+        template_file = "rotary_embedding.mlir"
+        target_function_name = (
+            f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{input_dtype}"
+        )
+
+        # Template params.
+        input_tensor_type = input_asm_type
+        table_tensor_type = table_asm_type
+
+        target_function = inline_template_function(
+            kb,
+            template_file,
+            target_function_name,
+            input_tensor_type=input_tensor_type,
+            table_tensor_type=table_tensor_type,
+            bs=bs,
+            sl=sl,
+            heads=heads,
+            dtype=str(input_dtype),
+        )
+        kb.yield_results(*call_function(target_function, *kb.arg_bindings))
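A minimal usage sketch for the op registered above (shapes are illustrative assumptions; the call form matches the new test further down):

    import torch
    from sharktank import kernels

    xt = torch.rand(4, 128, 8, 64, dtype=torch.float32)   # [bs, sl, heads, dims]
    table = torch.rand(128, 64, dtype=torch.float32)      # per-position angles, one per lane
    out = kernels.apply_rotary_embedding(xt, table)       # same shape and dtype as xt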

sharktank/sharktank/kernels/templates/rotary_embedding.mlir

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+!input_tensor_type = {{input_tensor_type}}
+!table_tensor_type = {{table_tensor_type}}
+
+module {
+
+util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type {
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+
+  %d0 = tensor.dim %input, %c0 : !input_tensor_type
+  %d1 = tensor.dim %input, %c1 : !input_tensor_type
+  %d2 = tensor.dim %input, %c2 : !input_tensor_type
+  %d3 = tensor.dim %input, %c3 : !input_tensor_type
+
+  %dim = tensor.dim %table, %c1 : !table_tensor_type
+  %hdim = arith.divui %dim, %c2 : index
+
+
+  %empty_dyn = tensor.empty(%d0, %d1, %d2, %d3) : tensor<?x?x?x?x{{dtype}}>
+  %empty = tensor.cast %empty_dyn : tensor<?x?x?x?x{{dtype}}> to {{input_tensor_type}}
+
+  %result = linalg.generic {
+      indexing_maps = [
+          affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+      ],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%table : !table_tensor_type )
+      outs(%empty : !input_tensor_type) {
+    ^bb0(%b0 : {{dtype}} , %b1 : {{dtype}}):
+      %0 = linalg.index 0 : index
+      %1 = linalg.index 1 : index
+      %2 = linalg.index 2 : index
+      %3 = linalg.index 3 : index
+      %div = arith.divui %3, %c2 : index
+      %mod = arith.remui %3, %c2 : index
+      %a_cosb = math.cos %b0 : {{dtype}}
+      %a_sinb = math.sin %b0 : {{dtype}}
+      %real_index = arith.muli %div, %c2 : index
+      %imag_index = arith.addi %real_index, %c1 : index
+      %real = tensor.extract %input[%0, %1, %2, %real_index] : !input_tensor_type
+      %imag = tensor.extract %input[%0, %1, %2, %imag_index] : !input_tensor_type
+      %cmp = arith.cmpi eq, %mod, %c0 : index
+      %real_t0 = arith.mulf %real, %a_cosb : {{dtype}}
+      %real_t1 = arith.mulf %imag, %a_sinb : {{dtype}}
+      %real_t2 = arith.subf %real_t0, %real_t1 : {{dtype}}
+      %imag_t0 = arith.mulf %imag, %a_cosb : {{dtype}}
+      %imag_t1 = arith.mulf %real, %a_sinb : {{dtype}}
+      %imag_t2 = arith.addf %imag_t0, %imag_t1 : {{dtype}}
+      %val = arith.select %cmp, %real_t2, %imag_t2 : {{dtype}}
+      linalg.yield %val : {{dtype}}
+  } -> !input_tensor_type
+
+  util.return %result : !input_tensor_type
+}
+
+}
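Element-wise, the linalg.generic body above computes the following (a Python restatement for clarity, not part of the commit; index names are hypothetical):

    import math

    def rotary_element(input, table, b, s, h, d):
        # Lanes 2k and 2k+1 form a (real, imag) pair and share one angle.
        angle = table[s, d]
        real_index = (d // 2) * 2
        imag_index = real_index + 1
        real = input[b, s, h, real_index]
        imag = input[b, s, h, imag_index]
        if d % 2 == 0:
            return real * math.cos(angle) - imag * math.sin(angle)
        return imag * math.cos(angle) + real * math.sin(angle)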

sharktank/sharktank/layers/rotary_embedding.py

Lines changed: 30 additions & 57 deletions
@@ -11,6 +11,7 @@
 
 from .base import BaseLayer
 from .. import ops
+from .. import kernels
 from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor
 
 
@@ -25,7 +26,6 @@ def __init__(
         rope_freq_base: Optional[float],
         device: Optional[torch.device] = None,
         use_hf: bool = False,
-        static_tables: bool = False,
         use_table: bool = True,
         tensor_parallelism_size: int = 1,
     ):
@@ -34,60 +34,44 @@ def __init__(
         self.rope_dimension_count = rope_dimension_count
         self.max_seqlen = max_seqlen
         self.use_hf = use_hf
-        self.static_tables = static_tables
         self.use_table = use_table
 
         self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
         self.tensor_parallelism_size = tensor_parallelism_size
-        if static_tables:
-            ops.module_register_buffer(
-                self, "static_rotary_embed_table", self._create_rotary_embed_table()
-            )
-        else:
-            self.static_rotary_embed_table = None
 
     @property
     def rotary_embed_table(self):
-        if self.use_table:
-            if self.static_tables:
-                return self.static_rotary_embed_table
-            return self._create_rotary_embed_table()
-
-        return None
+        return self._create_rotary_embed_table()
 
     def forward(
         self,
         *,
         xt: Union[torch.Tensor, SplitPrimitiveTensor],
         start_index: int,
     ):
-        if isinstance(xt, SplitPrimitiveTensor):
-            rotary_shards = [None] * xt.shard_count
-            if self.rotary_embed_table is not None:
-                assert (
-                    isinstance(self.rotary_embed_table, ReplicatedTensor)
-                    and xt.shard_count == self.rotary_embed_table.shard_count
-                )
-                rotary_shards = [
-                    unbox_tensor(shard) for shard in self.rotary_embed_table.shards
-                ]
-
-            xt_shards = [
-                self.forward_unsharded(
-                    xt=unbox_tensor(xt_shard),
-                    start_index=start_index,
-                    rotary_embed_table=rotary_shard,
-                )
-                for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
-            ]
-            xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
-            return xt
-        else:
+        table = self.rotary_embed_table
+        if not isinstance(xt, SplitPrimitiveTensor):
             return self.forward_unsharded(
                 xt=xt,
                 start_index=start_index,
-                rotary_embed_table=self.rotary_embed_table,
+                rotary_embed_table=table,
+            )
+
+        assert (
+            isinstance(table, ReplicatedTensor) and xt.shard_count == table.shard_count
+        )
+        rotary_shards = [unbox_tensor(shard) for shard in table.shards]
+
+        xt_shards = [
+            self.forward_unsharded(
+                xt=unbox_tensor(xt_shard),
+                start_index=start_index,
+                rotary_embed_table=rotary_shard,
            )
+            for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
+        ]
+        xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
+        return xt
 
     def _create_interleaved_tensor(_, dim):
         """Creates a tensor which indexes an tensor such that
@@ -143,18 +127,16 @@ def forward_unsharded(
         # Offset the table based on starting position.
         if self.use_table:
             freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
-            freqs_cis = freqs_cis[None, 0:sl, None, :]
+            freqs_cis = freqs_cis[0:sl, :]
         else:
             freqs_cis = torch.arange(sl, device=xt.device) + start_index
-            freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :]
+            freqs_cis = self._compute_rotary_embed_table(freqs_cis)
 
         assert (
-            freqs_cis.shape[1] >= sl
+            freqs_cis.shape[0] >= sl
         ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"
 
-        xt_ = ops.view_as_complex(xt_)
-        xt_ = xt_ * freqs_cis
-        xt_out = ops.view_as_real(xt_)
+        xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis)
 
         if self.use_hf:
             xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
@@ -181,7 +163,7 @@ def compute_batch_mask(
         self.trace_tensor("rope.positions_seq", positions_seq)
 
         if self.use_table:
-            freqs_cis = self.rotary_embed_table[positions_seq]
+            freqs_cis = self.rotary_embed_table[positions_seq.flatten()]
         else:
             shape = positions_seq.shape
             if isinstance(positions_seq, ReplicatedTensor):
@@ -192,11 +174,8 @@ def compute_batch_mask(
                 freqs_cis = ReplicatedTensor(ts=ts)
             else:
                 freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
-            freqs_cis = freqs_cis.unflatten(0, shape)
 
-        # Unsqueeze a unit dim for attention heads.
-        broadcast_freqs_cis = freqs_cis.unsqueeze(2)
-        return broadcast_freqs_cis
+        return freqs_cis
 
     def apply_batched_mask(
         self,
@@ -232,9 +211,7 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
         if self.use_hf:
            xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])]
 
-        xt_ = ops.view_as_complex(xt)
-        xt_ = xt_ * mask
-        xt_out = ops.view_as_real(xt_)
+        xt_out = kernels.apply_rotary_embedding(xt.to(mask.dtype), mask)
 
        if self.use_hf:
            xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
@@ -244,14 +221,10 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
     def _compute_rotary_embed_table(self, t):
         dim = self.rope_dimension_count
         freqs = 1.0 / (
-            self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+            self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
         )
         freqs = torch.outer(t, freqs).float()
-
-        cos = torch.cos(freqs)
-        sin = torch.sin(freqs)
-        complex = torch.complex(cos, sin)
-        return complex
+        return freqs
 
     def _create_rotary_embed_table(self):
         t = torch.arange(self.max_seqlen, device=self.device)

sharktank/sharktank/models/llama/llama.py

Lines changed: 0 additions & 2 deletions
@@ -67,7 +67,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         super().__init__(
             theta,
             context_length=config.hp.context_length,
-            static_tables=config.static_tables,
             device=config.device,
             activation_dtype=config.activation_dtype,
             attention_dtype=config.attention_dtype,
@@ -92,7 +91,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                 max_seqlen=hp.context_length,
                 device=self.device,
                 use_hf=self.use_hf,
-                static_tables=config.static_tables,
                 tensor_parallelism_size=config.tensor_parallelism_size,
             ),
         )

sharktank/tests/kernels/rotary.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+
+logging.basicConfig(level=logging.DEBUG)
+
+import torch
+import unittest
+
+from sharktank import kernels
+from sharktank import ops
+
+
+class rotary_test(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+
+    def test_rotary(self):
+        dtype = torch.float32
+        a = torch.rand([1, 128, 1, 64], dtype=dtype)
+        rot = torch.rand([128, 32], dtype=dtype)
+        res_b = ops.view_as_real(torch.complex(rot, rot))
+        ref_b = torch.complex(torch.cos(rot), torch.sin(rot))
+
+        result = kernels.apply_rotary_embedding(a, res_b)
+        ref = ops.view_as_real(ops.view_as_complex(a) * ref_b[None, :, None, :])
+        torch.testing.assert_close(result, ref)

0 commit comments