Skip to content

Commit 9e33390

Browse files
aryamanarora and claude committed
Fix head_dim fallback in per-head dimension mappings for GQA models
Add head_dim as primary proposal with hidden_size/num_attention_heads as fallback for head_attention_value_output, head_query_output, head_key_output, and head_value_output across all GQA models. This fixes models like GPT-OSS 20B where hidden_size != num_attention_heads * head_dim. Also fixes typo "hhead_dim" -> "head_dim" in gemma. Fixes #229 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 68fe7b8 commit 9e33390

8 files changed

Lines changed: 29 additions & 29 deletions

pyvene/models/gemma/modelings_intervenable_gemma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"value_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
5050
"head_query_output": ("head_dim",),
5151
"head_key_output": ("head_dim",),
52-
"head_value_output": ("hhead_dim",),
52+
"head_value_output": ("head_dim",),
5353
}
5454

5555

pyvene/models/gpt_oss/modelings_intervenable_gpt_oss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,13 @@
6666
"attention_input": ("hidden_size",),
6767
"attention_output": ("hidden_size",),
6868
"attention_value_output": ("hidden_size",),
69-
"head_attention_value_output": ("hidden_size/num_attention_heads",),
69+
"head_attention_value_output": ("head_dim", "hidden_size/num_attention_heads",),
7070
"query_output": ("hidden_size",),
7171
"key_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
7272
"value_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
73-
"head_query_output": ("hidden_size/num_attention_heads",),
74-
"head_key_output": ("hidden_size/num_key_value_heads",),
75-
"head_value_output": ("hidden_size/num_key_value_heads",),
73+
"head_query_output": ("head_dim", "hidden_size/num_attention_heads",),
74+
"head_key_output": ("head_dim", "hidden_size/num_key_value_heads",),
75+
"head_value_output": ("head_dim", "hidden_size/num_key_value_heads",),
7676
}
7777

7878

pyvene/models/llama/modelings_intervenable_llama.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@
4141
"mlp_output": ("hidden_size",),
4242
"mlp_input": ("hidden_size",),
4343
"attention_value_output": ("hidden_size",),
44-
"head_attention_value_output": ("hidden_size/num_attention_heads",),
44+
"head_attention_value_output": ("head_dim", "hidden_size/num_attention_heads",),
4545
"attention_output": ("hidden_size",),
4646
"attention_input": ("hidden_size",),
4747
"query_output": ("hidden_size",),
4848
"key_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
4949
"value_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
50-
"head_query_output": ("hidden_size/num_attention_heads",),
51-
"head_key_output": ("hidden_size/num_attention_heads",),
52-
"head_value_output": ("hidden_size/num_attention_heads",),
50+
"head_query_output": ("head_dim", "hidden_size/num_attention_heads",),
51+
"head_key_output": ("head_dim", "hidden_size/num_attention_heads",),
52+
"head_value_output": ("head_dim", "hidden_size/num_attention_heads",),
5353
}
5454

5555

pyvene/models/mistral/modellings_intervenable_mistral.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@
4141
"mlp_output": ("hidden_size",),
4242
"mlp_input": ("hidden_size",),
4343
"attention_value_output": ("hidden_size",),
44-
"head_attention_value_output": ("hidden_size/num_attention_heads",),
44+
"head_attention_value_output": ("head_dim", "hidden_size/num_attention_heads",),
4545
"attention_output": ("hidden_size",),
4646
"attention_input": ("hidden_size",),
4747
"query_output": ("hidden_size",),
4848
"key_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
4949
"value_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
50-
"head_query_output": ("hidden_size/num_attention_heads",),
51-
"head_key_output": ("hidden_size/num_attention_heads",),
52-
"head_value_output": ("hidden_size/num_attention_heads",),
50+
"head_query_output": ("head_dim", "hidden_size/num_attention_heads",),
51+
"head_key_output": ("head_dim", "hidden_size/num_attention_heads",),
52+
"head_value_output": ("head_dim", "hidden_size/num_attention_heads",),
5353
}
5454

5555

pyvene/models/olmo/modelings_intervenable_olmo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@
4242
"mlp_output": ("hidden_size",),
4343
"mlp_input": ("hidden_size",),
4444
"attention_value_output": ("hidden_size",),
45-
"head_attention_value_output": ("hidden_size/num_attention_heads",),
45+
"head_attention_value_output": ("head_dim", "hidden_size/num_attention_heads",),
4646
"attention_output": ("hidden_size",),
4747
"attention_input": ("hidden_size",),
4848
"query_output": ("hidden_size",),
4949
"key_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
5050
"value_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
51-
"head_query_output": ("hidden_size/num_attention_heads",),
52-
"head_key_output": ("hidden_size/num_attention_heads",),
53-
"head_value_output": ("hidden_size/num_attention_heads",),
51+
"head_query_output": ("head_dim", "hidden_size/num_attention_heads",),
52+
"head_key_output": ("head_dim", "hidden_size/num_attention_heads",),
53+
"head_value_output": ("head_dim", "hidden_size/num_attention_heads",),
5454
}
5555

5656

pyvene/models/olmo2/modelings_intervenable_olmo2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@
4242
"mlp_output": ("hidden_size",),
4343
"mlp_input": ("hidden_size",),
4444
"attention_value_output": ("hidden_size",),
45-
"head_attention_value_output": ("hidden_size/num_attention_heads",),
45+
"head_attention_value_output": ("head_dim", "hidden_size/num_attention_heads",),
4646
"attention_output": ("hidden_size",),
4747
"attention_input": ("hidden_size",),
4848
"query_output": ("hidden_size",),
4949
"key_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
5050
"value_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
51-
"head_query_output": ("hidden_size/num_attention_heads",),
52-
"head_key_output": ("hidden_size/num_attention_heads",),
53-
"head_value_output": ("hidden_size/num_attention_heads",),
51+
"head_query_output": ("head_dim", "hidden_size/num_attention_heads",),
52+
"head_key_output": ("head_dim", "hidden_size/num_attention_heads",),
53+
"head_value_output": ("head_dim", "hidden_size/num_attention_heads",),
5454
}
5555

5656

pyvene/models/qwen2/modelings_intervenable_qwen2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@
3636
"mlp_output": ("hidden_size",),
3737
"mlp_input": ("hidden_size",),
3838
"attention_value_output": ("hidden_size",),
39-
"head_attention_value_output": ("hidden_size/num_attention_heads",),
39+
"head_attention_value_output": ("head_dim", "hidden_size/num_attention_heads",),
4040
"attention_output": ("hidden_size",),
4141
"attention_input": ("hidden_size",),
4242
"query_output": ("hidden_size",),
4343
"key_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
4444
"value_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
45-
"head_query_output": ("hidden_size/num_attention_heads",),
46-
"head_key_output": ("hidden_size/num_attention_heads",),
47-
"head_value_output": ("hidden_size/num_attention_heads",),
45+
"head_query_output": ("head_dim", "hidden_size/num_attention_heads",),
46+
"head_key_output": ("head_dim", "hidden_size/num_attention_heads",),
47+
"head_value_output": ("head_dim", "hidden_size/num_attention_heads",),
4848
}
4949

5050
"""qwen2 model with LM head"""

pyvene/models/qwen3/modelings_intervenable_qwen3.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@
3636
"mlp_output": ("hidden_size",),
3737
"mlp_input": ("hidden_size",),
3838
"attention_value_output": ("hidden_size",),
39-
"head_attention_value_output": ("hidden_size/num_attention_heads",),
39+
"head_attention_value_output": ("head_dim", "hidden_size/num_attention_heads",),
4040
"attention_output": ("hidden_size",),
4141
"attention_input": ("hidden_size",),
4242
"query_output": ("hidden_size",),
4343
"key_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
4444
"value_output": ("num_key_value_heads*hidden_size/num_attention_heads",),
45-
"head_query_output": ("hidden_size/num_attention_heads",),
46-
"head_key_output": ("hidden_size/num_attention_heads",),
47-
"head_value_output": ("hidden_size/num_attention_heads",),
45+
"head_query_output": ("head_dim", "hidden_size/num_attention_heads",),
46+
"head_key_output": ("head_dim", "hidden_size/num_attention_heads",),
47+
"head_value_output": ("head_dim", "hidden_size/num_attention_heads",),
4848
}
4949

5050
"""qwen3 model with LM head"""

0 commit comments

Comments (0)