Skip to content

Commit 3294a89

Browse files
authored
Merge pull request #76 from DavidRV00/fix-gemmadecodelayer-construct
Add required world_size and rank to GemmaDecodeLayer init
2 parents 80881c2 + f249fde commit 3294a89

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

gemma/model_xla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def __init__(
527527
self.layers = nn.ModuleList()
528528
for i in range(config.num_hidden_layers):
529529
if config.architecture == gemma_config.Architecture.GEMMA_1:
530-
self.layers.append(GemmaDecoderLayer(config))
530+
self.layers.append(GemmaDecoderLayer(config, world_size, rank))
531531
elif config.architecture == gemma_config.Architecture.GEMMA_2:
532532
attn_type = (
533533
config.attn_types[i]

0 commit comments

Comments
 (0)