Skip to content

Commit

Permalink
Update rotary embedding
Browse files (browse the repository at this point in the history)
  • Loading branch information
rsuderman committed Dec 19, 2024
1 parent f1caf4d commit 98580bf
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
4 changes: 3 additions & 1 deletion sharktank/sharktank/kernels/rotary.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
sl = input.type.shape[1]
sl = "D" if sl < 0 else sl
heads = input.type.shape[2]
+ dims = input.type.shape[3]

template_file = "rotary_embedding.mlir"
target_function_name = (
- f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{input_dtype}"
+ f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{dims}_{input_dtype}"
)

# Template params.
Expand All @@ -63,6 +64,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
bs=bs,
sl=sl,
heads=heads,
+ dims=dims,
dtype=str(input_dtype),
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
8 changes: 2 additions & 6 deletions sharktank/sharktank/kernels/templates/rotary_embedding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

module {

- util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type {
+ util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dims}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type {

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -22,16 +22,12 @@ util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dtype}}(
%d2 = tensor.dim %input, %c2 : !input_tensor_type
%d3 = tensor.dim %input, %c3 : !input_tensor_type

- %dim = tensor.dim %table, %c1 : !table_tensor_type
- %hdim = arith.divui %dim, %c2 : index


%empty_dyn = tensor.empty(%d0, %d1, %d2, %d3) : tensor<?x?x?x?x{{dtype}}>
%empty = tensor.cast %empty_dyn : tensor<?x?x?x?x{{dtype}}> to {{input_tensor_type}}

%result = linalg.generic {
indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
Expand Down
3 changes: 2 additions & 1 deletion sharktank/sharktank/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def forward_unsharded(
freqs_cis.shape[0] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"

+ freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1))
xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis)

if self.use_hf:
Expand Down Expand Up @@ -175,7 +176,7 @@ def compute_batch_mask(
else:
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())

- return freqs_cis
+ return freqs_cis.unsqueeze(1)

def apply_batched_mask(
self,
Expand Down

0 comments on commit 98580bf

Please sign in to comment.