diff --git a/model.py b/model.py index a47d00dd..ce73172e 100644 --- a/model.py +++ b/model.py @@ -62,7 +62,7 @@ def from_name(cls, name: str): "30B": dict(n_layer=60, n_head=52, dim=6656), "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), - "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), + "Mistral-7B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), "stories15M": dict(n_layer=6, n_head=6, dim=288), "stories110M": dict(n_layer=12, n_head=12, dim=768), } @@ -111,7 +111,7 @@ def setup_caches(self, max_batch_size, max_seq_length): for b in self.layers: b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) - self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype) + self.freqs_cis = precompute_freqs_cis(self.max_seq_length, self.config.dim // self.config.n_head, self.config.rope_base, dtype) self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: