diff --git a/sharktank/sharktank/kernels/rotary.py b/sharktank/sharktank/kernels/rotary.py index d44e7e4a1..196fc32c2 100644 --- a/sharktank/sharktank/kernels/rotary.py +++ b/sharktank/sharktank/kernels/rotary.py @@ -44,10 +44,11 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): sl = input.type.shape[1] sl = "D" if sl < 0 else sl heads = input.type.shape[2] + dims = input.type.shape[3] template_file = "rotary_embedding.mlir" target_function_name = ( - f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{input_dtype}" + f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{dims}_{input_dtype}" ) # Template params. @@ -63,6 +64,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): bs=bs, sl=sl, heads=heads, + dims=dims, dtype=str(input_dtype), ) kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/rotary_embedding.mlir b/sharktank/sharktank/kernels/templates/rotary_embedding.mlir index 854fd8a45..adec6805b 100644 --- a/sharktank/sharktank/kernels/templates/rotary_embedding.mlir +++ b/sharktank/sharktank/kernels/templates/rotary_embedding.mlir @@ -9,7 +9,7 @@ module { -util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type { +util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dims}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -22,16 +22,12 @@ util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dtype}}( %d2 = tensor.dim %input, %c2 : !input_tensor_type %d3 = tensor.dim %input, %c3 : !input_tensor_type - %dim = tensor.dim %table, %c1 : !table_tensor_type - %hdim = arith.divui %dim, %c2 : index - - %empty_dyn = tensor.empty(%d0, %d1, %d2, %d3) : tensor<?x?x?x?x{{dtype}}> %empty = tensor.cast %empty_dyn : tensor<?x?x?x?x{{dtype}}> to {{input_tensor_type}} %result = 
linalg.generic { indexing_maps = [ - affine_map<(d0, d1, d2, d3) -> (d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> ], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index ff630e899..623c02ea6 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -136,6 +136,7 @@ def forward_unsharded( freqs_cis.shape[0] >= sl ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})" + freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1)) xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis) if self.use_hf: @@ -175,7 +176,7 @@ def compute_batch_mask( else: freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) - return freqs_cis + return freqs_cis.unsqueeze(1) def apply_batched_mask( self,