diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py
index fc3aa5fca..c7d90a19e 100644
--- a/sharktank/sharktank/evaluate/perplexity_torch.py
+++ b/sharktank/sharktank/evaluate/perplexity_torch.py
@@ -111,7 +111,7 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern
             attention_dtype=self.attention_dtype,
             tensor_parallelism_size=tensor_parallelism_size,
         )
-
+        config.attention_kernel="torch"
         if config.tensor_parallelism_size > 1:
             dataset.root_theta = shard_theta(dataset.root_theta, config)
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 90b027d13..79ec304ea 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -170,7 +170,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
                 attn_weights, values
             )  # (bs, heads, slen, head_dim)
         else:
-            is_causal = attention_mask is None and batch_seq_len == 1
+            # Use the builtin attention mask when not decomposed
+            is_causal = True
+            attention_mask = None
             attn_output = ops.scaled_dot_product_attention(
                 q=xq,  # [bs, ..., sl, dim]
                 k=keys,  # [bs, ..., sl, dim]
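
Note: the second hunk assumes that PyTorch's builtin causal masking (is_causal=True with no explicit mask) matches an explicit additive causal mask during prefill. Below is a minimal standalone sketch of that equivalence; it uses plain torch.nn.functional.scaled_dot_product_attention rather than the repo's ops wrapper, and all shapes are made up for illustration.

import torch
import torch.nn.functional as F

# Illustrative shapes only; not taken from the repository.
bs, heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bs, heads, seq_len, head_dim)
k = torch.randn(bs, heads, seq_len, head_dim)
v = torch.randn(bs, heads, seq_len, head_dim)

# Explicit additive causal mask: -inf above the diagonal, 0 on and below it.
causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)

out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The two paths should agree up to numerical tolerance.
torch.testing.assert_close(out_explicit, out_builtin)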