How did you initialize llama? #98

@brando90

Description

My code

import torch.nn as nn


def reinitialize_weights_gpt_neox_20B_inspired_4_llama2(model):
    """
    Note: we nearly follow gpt-neox-20B (2022); llama1 & llama2 (2023) do not say how they init.

    I think gpt-neox_20B & llama2 both use pre-layernorm: gpt-neox_20B uses the init from transformers without tears,
    which assumes pre-norm, and llama1 says it uses pre-normalization, so likely both are pre-layernorm.
    Thus, I hope the same init that transformers without tears uses works here too.
    
    Init:
    FF layer (as Wang 2021, not transformers without tears):
        -> W ~ N(0, 3 / (L * sqrt(D)))
        decided on this because Wang (2021) is later than transformers without tears (2019, Nguyen & Salazar)
    Other layers (as transformers without tears (2019, Nguyen & Salazar)):
        -> W ~ N(0, sqrt(2 / (d + 4d)))
    norm_layer:
        gpt-neox_20B: uses layer_norm
        llama2 uses llama1's architecture, which uses RMSNorm (Zhang and Sennrich, 2019)
        decided not to copy gpt-neox_20B (which uses W ~ N(0, sqrt(2 / (d + 4d))))
        because they don't share the same norm. llama1/2 use RMSNorm:
            a_bar_i = g_i * a_i / sqrt(1/n * sum_j a_j^2)   [where is eps? see the sketch after the function]
        So I decided
        -> g_i (gain) ~ constant(1)
        since we aren't training to completion, so perhaps it helps at the beginning. If it diverges we can set this to something small or to what gpt-neox_20B uses.
        There is no offset, but I will set it to 0 in the code WLOG (RMSNorm has no bias anyway).
    Activation:
        SwiGLU (not relu) for llama1, llama2 [used for baby llama2]
        gpt-neox_20B: the paper doesn't say.
    We use a normal distribution because transformers without tears uses it, and since gpt-neox_20B uses nearly the same inits, llama2 likely does too.
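    Sanity check on the resulting stds (my own arithmetic, assuming D = 96 and L = 32 from the config printed below):
        FF / linear layers: 3 / (32 * sqrt(96))    ~= 0.0096
        other layers:       sqrt(2 / (96 + 4*96))  ~= 0.0645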

    ref: RMSNorm: https://arxiv.org/pdf/1910.07467.pdf
    ref: llama1 (llama2 uses the same architecture): https://arxiv.org/pdf/2302.13971.pdf
    ref: pytorch inits: https://pytorch.org/docs/stable/nn.init.html

    ref: llama2 7b config: https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L13
    ref: follow-up discussion: https://discuss.huggingface.co/t/how-to-choose-std-for-weight-init-for-llama-2-after-reinitialize/69702

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 96, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=96, out_features=96, bias=False)
          (k_proj): Linear(in_features=96, out_features=96, bias=False)
          (v_proj): Linear(in_features=96, out_features=96, bias=False)
          (o_proj): Linear(in_features=96, out_features=96, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=96, out_features=11008, bias=False)
          (up_proj): Linear(in_features=96, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=96, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=96, out_features=32000, bias=False)
)
    The model above was created via get_smaller_llama2(hidden_size=32*3, num_hidden_layers=32, verbose=verbose),
    so in_features = 96 ==> D = 96.
    """
    # L = number of transformer layers, as in the Wang (2021)-style scaling in the docstring
    L = model.config.num_hidden_layers
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            D = module.in_features  # I think this is right since it's xW
            nn.init.normal_(module.weight, mean=0, std=3 / (L * D**0.5))
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif 'norm' in name.lower() or 'norm' in type(module).__name__.lower():
            # RMSNorm gain -> 1; LlamaRMSNorm has no bias, hence the getattr guard
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if getattr(module, 'bias', None) is not None:
                nn.init.constant_(module.bias, 0)
        elif hasattr(module, 'weight') and module.weight is not None:
            # remaining weight-bearing modules, i.e. the embedding table in this model
            D = module.weight.shape[1]  # model dim (96 here), matching d in sqrt(2 / (d + 4d))
            nn.init.normal_(module.weight, mean=0, std=(2 / (D + 4 * D)) ** 0.5)
            if getattr(module, 'bias', None) is not None:
                nn.init.constant_(module.bias, 0)
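To sanity check the function end to end, here is a minimal usage sketch. get_smaller_llama2 is only mentioned in the docstring above, so the version below is a guess that simply builds a LlamaConfig with the hidden size and layer count from the printed architecture; LlamaConfig and LlamaForCausalLM are the standard transformers classes.

from transformers import LlamaConfig, LlamaForCausalLM

def get_smaller_llama2(hidden_size: int = 96, num_hidden_layers: int = 32, verbose: bool = False):
    # hypothetical helper: shrink llama2 down to a baby config, leaving other fields at LlamaConfig defaults
    config = LlamaConfig(hidden_size=hidden_size, num_hidden_layers=num_hidden_layers)
    model = LlamaForCausalLM(config)
    if verbose:
        print(model)
    return model

model = get_smaller_llama2(hidden_size=32 * 3, num_hidden_layers=32)
reinitialize_weights_gpt_neox_20B_inspired_4_llama2(model)
print(model.model.layers[0].mlp.up_proj.weight.std())  # should land near 3 / (32 * sqrt(96)) ~= 0.0096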

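On the "[where is eps?]" question in the docstring: in RMSNorm implementations like HF's LlamaRMSNorm, the epsilon is added to the mean of the squared activations inside the square root. A minimal sketch (my paraphrase, not the exact transformers source):

import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Minimal RMSNorm: g * x / sqrt(mean(x^2) + eps), with the gain g initialized to 1."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # the gain g_i from the docstring, init to 1
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)                  # 1/n * sum_j a_j^2
        return self.weight * x * torch.rsqrt(variance + self.eps)   # eps lives inside the sqrt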