Add LLaMA support to embed_to_distrib #235
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 5d6a68f181
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
wte = model.wte if hasattr(model, "wte") else model.transformer.wte
vocab = torch.matmul(embed, wte.weight.t())
Use GPT-2 LM head for vocab projection
This branch now handles GPT2LMHeadModel, but it always projects through wte.weight; that gives incorrect logits whenever output embeddings are not tied to input embeddings (e.g., tie_word_embeddings=False or checkpoints where lm_head has diverged during fine-tuning). In those valid GPT-2 configurations, embed_to_distrib will silently return the wrong distribution, so the projection should prefer model.lm_head when present and only fall back to wte for models without an LM head.
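The suggested fix can be sketched as follows. This is a minimal illustration, not the actual pyvene code; the helper name `project_to_vocab` is hypothetical, while `model` and `embed` follow the snippet above:

```python
import torch


def project_to_vocab(model, embed):
    """Project hidden states to vocab logits, preferring the LM head.

    Uses model.lm_head when present (GPT2LMHeadModel), so untied or
    fine-tuned output embeddings produce correct logits; falls back to
    the input embedding matrix only for models without an LM head.
    """
    lm_head = getattr(model, "lm_head", None)
    if lm_head is not None:
        # GPT2LMHeadModel: use the output projection directly.
        return lm_head(embed)
    # Plain GPT2Model: no LM head, project through the input embeddings.
    wte = model.wte if hasattr(model, "wte") else model.transformer.wte
    return torch.matmul(embed, wte.weight.t())
```

With `tie_word_embeddings=False`, `lm_head.weight` and `wte.weight` differ, which is exactly the case where projecting through `wte` silently returns the wrong distribution.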
Description
- Implement `embed_to_distrib()` in `pyvene/models/basic_utils.py`, removing the previous `assert False` so causal tracing and embed→vocab distribution work for LLaMA/`LlamaForCausalLM`.
- Handle `model.config.architectures` being `None` when models are built from config (e.g. in tests) by falling back to `type(model).__name__`.
- Support `GPT2Model` and `GPT2LMHeadModel` by using `model.wte` or `model.transformer.wte` as appropriate.
- Add tests in `tests/unit_tests/BasicUtilsTestCase.py` for GPT-2 and LLaMA (`logits=True` and softmax output, shape and sum-to-one checks).

Testing Done

- `python -m unittest tests.unit_tests.BasicUtilsTestCase -v` (with `PYTHONPATH` set to repo root).
- New tests: `test_embed_to_distrib_gpt2_logits`, `test_embed_to_distrib_gpt2_softmax`, `test_embed_to_distrib_llama_logits`, `test_embed_to_distrib_llama_softmax`.
- Ran `tests.unit_tests.CausalModelTestCase` with no regressions.

Checklist:
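The behavior described in this PR can be sketched roughly as below. This is an illustrative reconstruction under stated assumptions, not the actual code in `pyvene/models/basic_utils.py`; the branching on the architecture name and the `type(model).__name__` fallback follow the description above:

```python
import torch


def embed_to_distrib(model, embed, logits=False):
    """Map hidden states to a vocab distribution (illustrative sketch).

    Resolves the architecture name from model.config.architectures when
    available, falling back to type(model).__name__ for models built
    directly from a config (e.g. in tests).
    """
    archs = getattr(model.config, "architectures", None)
    arch = archs[0] if archs else type(model).__name__

    if "GPT2" in arch:
        # GPT2Model exposes wte directly; GPT2LMHeadModel nests it
        # under model.transformer.
        wte = model.wte if hasattr(model, "wte") else model.transformer.wte
        vocab = torch.matmul(embed, wte.weight.t())
    elif "Llama" in arch:
        # LlamaForCausalLM projects through its dedicated LM head.
        vocab = model.lm_head(embed)
    else:
        raise ValueError(f"Unsupported architecture: {arch}")

    return vocab if logits else torch.softmax(vocab, dim=-1)
```

The shape and sum-to-one checks in the new tests correspond to the two return paths: raw logits when `logits=True`, and a softmax over the vocab dimension otherwise.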
Authors