Commit ae2f82d

Added gemma-7b performance
1 parent ef055fc commit ae2f82d

File tree

1 file changed (+10 -9 lines changed)

model.py (+10 -9)

@@ -26,7 +26,7 @@ class ModelArgs:
     dim: int = 4096
     intermediate_size: int = None
     n_local_heads: int = -1
-    head_dim: int = 64
+    head_dim: int = None
     rope_base: float = 10000
     norm_eps: float = 1e-5
 
@@ -37,7 +37,8 @@ def __post_init__(self):
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_head
 
     @classmethod
     def from_name(cls, name: str):
@@ -51,6 +52,7 @@ def from_name(cls, name: str):
 
 transformer_configs = {
     "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
+    "gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256),
     "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
     "7B": dict(n_layer=32, n_head=32, dim=4096),
     "13B": dict(n_layer=40, n_head=40, dim=5120),
@@ -95,14 +97,13 @@ def __init__(self, config: ModelArgs) -> None:
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
             return
-        head_dim = self.config.dim // self.config.n_head
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -145,7 +146,7 @@ def __init__(self, config: ModelArgs):
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         # key, query, value projections for all heads, but in a batch
         self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False)
         self.kv_cache = None
 
         self.n_head = config.n_head
@@ -165,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -183,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim)
 
         y = self.wo(y)
         return y
@@ -197,7 +198,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.gelu(self.w1(x)) * self.w3(x))
+        return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))
 
 
 class RMSNorm(nn.Module):
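
The remaining hunks follow from the same decoupling: for gemma-7b the attention output is n_head * head_dim = 4096 wide rather than dim = 3072, so wo, the qkv split, and the final reshape are sized from n_head * head_dim; the feed-forward also switches F.gelu to its tanh approximation. A small standalone sketch with the gemma-7b numbers plugged in (names mirror the diff, but this is not code from the repo):

import math
import torch
import torch.nn.functional as F

# gemma-7b values from transformer_configs above; note n_head * head_dim != dim here.
dim, n_head, n_local_heads, head_dim = 3072, 16, 16, 256
total_head_dim = (n_head + 2 * n_local_heads) * head_dim            # 12288
wqkv = torch.nn.Linear(dim, total_head_dim, bias=False)
wo = torch.nn.Linear(n_head * head_dim, dim, bias=False)            # old Linear(dim, dim) would be too narrow

x = torch.randn(1, 8, dim)
kv_size = n_local_heads * head_dim
q, k, v = wqkv(x).split([n_head * head_dim, kv_size, kv_size], dim=-1)
assert q.shape[-1] == 4096          # splitting on dim (3072) would not cover the 12288-wide projection
assert wo(q).shape[-1] == dim       # output projection maps 4096 back to the model width

# approximate="tanh" selects the tanh-based GELU:
# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
t = torch.linspace(-4, 4, 9)
manual = 0.5 * t * (1 + torch.tanh(math.sqrt(2 / math.pi) * (t + 0.044715 * t ** 3)))
assert torch.allclose(F.gelu(t, approximate="tanh"), manual, atol=1e-6)
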
