diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py index fc3aa5fca..c7d90a19e 100644 --- a/sharktank/sharktank/evaluate/perplexity_torch.py +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -111,7 +111,7 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern attention_dtype=self.attention_dtype, tensor_parallelism_size=tensor_parallelism_size, ) - + config.attention_kernel="torch" if config.tensor_parallelism_size > 1: dataset.root_theta = shard_theta(dataset.root_theta, config)