forked from stanfordnlp/pyvene
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathintervenable_modelcard.py
More file actions
153 lines (144 loc) · 9.67 KB
/
intervenable_modelcard.py
File metadata and controls
153 lines (144 loc) · 9.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from .constants import *
from .llama.modelings_intervenable_llama import *
from .mistral.modellings_intervenable_mistral import *
from .gemma.modelings_intervenable_gemma import *
from .gemma2.modelings_intervenable_gemma2 import *
from .gpt2.modelings_intervenable_gpt2 import *
from .gpt_neo.modelings_intervenable_gpt_neo import *
from .gpt_neox.modelings_intervenable_gpt_neox import *
from .mlp.modelings_intervenable_mlp import *
from .gru.modelings_intervenable_gru import *
from .blip.modelings_intervenable_blip import *
from .blip.modelings_intervenable_blip_itm import *
from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
from .llava.modelings_intervenable_llava import *
from .qwen2.modelings_intervenable_qwen2 import *
from .olmo.modelings_intervenable_olmo import *
from .olmo2.modelings_intervenable_olmo2 import *
from .qwen3.modelings_intervenable_qwen3 import *
from .esm.modelings_intervenable_esm import *
from .mllama.modelings_intervenable_mllama import *
from .gpt_oss.modelings_intervenable_gpt_oss import *
from .whisper.modelings_intervenable_whisper import *
from .wav2vec2bert.modelings_intervenable_wav2vec2bert import *
#########################################################################
"""
Below are the mappings you need to update when you add
a new model architecture type to this library.
They are placed near the top so it is easier to keep track of
everything that needs to be changed.
"""
import transformers.models as hf_models
from .mlp.modelings_mlp import MLPModel, MLPForClassification
from .gru.modelings_gru import GRUModel, GRULMHeadModel, GRUForClassification
from .backpack_gpt2.modelings_backpack_gpt2 import BackpackGPT2LMHeadModel
# Importing the BLIP wrapper classes is best-effort: if either import fails, a
# message is printed and ``enable_blip`` is False, which skips registering
# BlipWrapper / BlipITMWrapper in the type mappings of this module.
try:
    from .blip.modelings_blip import BlipWrapper
    from .blip.modelings_blip_itm import BlipITMWrapper
except Exception:
    # ``except Exception`` (not a bare ``except:``) keeps the graceful fallback
    # for any import-time error while still letting KeyboardInterrupt and
    # SystemExit propagate.
    print("Failed to import blip model, skipping.")
    enable_blip = False
else:
    enable_blip = True
# Model class -> per-architecture "type to module" table.
#
# Keys are model classes: Hugging Face ``transformers`` classes reached through
# ``hf_models`` plus this library's own MLP / GRU / Backpack-GPT2 models.
# Values are the ``*_type_to_module_mapping`` dicts brought into scope by the
# star imports at the top of this file (presumably mapping intervenable
# component names to submodule locations -- see those modules for the schema).
# When adding a new architecture, register it here and in
# ``type_to_dimension_mapping``.
type_to_module_mapping = {
    hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_module_mapping,
    hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_module_mapping,
    hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_module_mapping,
    hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping,
    hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping,
    hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_module_mapping,
    hf_models.llava.modeling_llava.LlavaForConditionalGeneration: llava_type_to_module_mapping,
    hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping,
    hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_module_mapping,
    hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_module_mapping,
    hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_module_mapping,
    hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_module_mapping,
    hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_module_mapping,
    hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_module_mapping,
    hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_module_mapping,
    hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_module_mapping,
    hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_module_mapping,
    hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_module_mapping,
    hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_module_mapping,
    hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
    hf_models.olmo2.modeling_olmo2.Olmo2Model: olmo2_type_to_module_mapping,
    hf_models.olmo2.modeling_olmo2.Olmo2ForCausalLM: olmo2_lm_type_to_module_mapping,
    hf_models.qwen3.modeling_qwen3.Qwen3Model: qwen3_type_to_module_mapping,
    hf_models.qwen3.modeling_qwen3.Qwen3ForCausalLM: qwen3_lm_type_to_module_mapping,
    hf_models.esm.modeling_esm.EsmModel: esm_type_to_module_mapping,
    hf_models.esm.modeling_esm.EsmForMaskedLM: esm_mlm_type_to_module_mapping,
    hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
    hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping,
    MLPModel: mlp_type_to_module_mapping,
    MLPForClassification: mlp_classifier_type_to_module_mapping,
    GRUModel: gru_type_to_module_mapping,
    GRULMHeadModel: gru_lm_type_to_module_mapping,
    GRUForClassification: gru_classifier_type_to_module_mapping,
    BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_module_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2Model: qwen2_type_to_module_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2ForCausalLM: qwen2_lm_type_to_module_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification: qwen2_classifier_type_to_module_mapping,
    hf_models.mllama.modeling_mllama.MllamaForConditionalGeneration: mllama_type_to_module_mapping,
    hf_models.gpt_oss.modeling_gpt_oss.GptOssModel: gpt_oss_type_to_module_mapping,
    hf_models.gpt_oss.modeling_gpt_oss.GptOssForCausalLM: gpt_oss_lm_type_to_module_mapping,
    hf_models.whisper.modeling_whisper.WhisperModel: whisper_type_to_module_mapping,
    hf_models.whisper.modeling_whisper.WhisperForConditionalGeneration: whisper_lm_type_to_module_mapping,
    hf_models.wav2vec2_bert.modeling_wav2vec2_bert.Wav2Vec2BertModel: wav2vec2bert_type_to_module_mapping,
}
# The BLIP wrapper classes are registered only when their optional import
# succeeded (``enable_blip`` is set where that import is attempted).
if enable_blip:
    type_to_module_mapping[BlipWrapper] = blip_wrapper_type_to_module_mapping
    # NOTE(review): BlipITMWrapper reuses the BlipWrapper module table, whereas
    # type_to_dimension_mapping registers a dedicated
    # blip_itm_wrapper_type_to_dimension_mapping for it. Confirm both wrappers
    # expose the same submodule layout, or switch to an ITM-specific table.
    type_to_module_mapping[BlipITMWrapper] = blip_wrapper_type_to_module_mapping
# Model class -> per-architecture "type to dimension" table.
#
# Keyed by the same model classes as ``type_to_module_mapping``. Each value is
# a ``*_type_to_dimension_mapping`` dict from the star imports at the top of
# this file (presumably giving the size of each intervenable component -- see
# those modules for the exact schema). A new architecture must be added here
# as well as to ``type_to_module_mapping``.
type_to_dimension_mapping = {
    # GPT-2 and LLaMA families.
    hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_dimension_mapping,
    hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_dimension_mapping,
    hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_dimension_mapping,
    hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping,
    hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping,
    hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_dimension_mapping,
    hf_models.llava.modeling_llava.LlavaForConditionalGeneration: llava_type_to_dimension_mapping,
    # GPT-Neo / GPT-NeoX, Mistral, Gemma, OLMo, Qwen3, ESM.
    hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping,
    hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_dimension_mapping,
    hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_dimension_mapping,
    hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_dimension_mapping,
    hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_dimension_mapping,
    hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_dimension_mapping,
    hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_dimension_mapping,
    hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_dimension_mapping,
    hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_dimension_mapping,
    hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_dimension_mapping,
    hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_dimension_mapping,
    hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_dimension_mapping,
    hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
    hf_models.olmo2.modeling_olmo2.Olmo2Model: olmo2_type_to_dimension_mapping,
    hf_models.olmo2.modeling_olmo2.Olmo2ForCausalLM: olmo2_lm_type_to_dimension_mapping,
    hf_models.qwen3.modeling_qwen3.Qwen3Model: qwen3_type_to_dimension_mapping,
    hf_models.qwen3.modeling_qwen3.Qwen3ForCausalLM: qwen3_lm_type_to_dimension_mapping,
    hf_models.esm.modeling_esm.EsmModel: esm_type_to_dimension_mapping,
    hf_models.esm.modeling_esm.EsmForMaskedLM: esm_mlm_type_to_dimension_mapping,
    # Hugging Face BLIP models (the library's BLIP wrappers are added below).
    hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
    hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping,
    # Models defined inside this library rather than in transformers.
    MLPModel: mlp_type_to_dimension_mapping,
    MLPForClassification: mlp_classifier_type_to_dimension_mapping,
    GRUModel: gru_type_to_dimension_mapping,
    GRULMHeadModel: gru_lm_type_to_dimension_mapping,
    GRUForClassification: gru_classifier_type_to_dimension_mapping,
    BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_dimension_mapping,
    # Qwen2, Mllama, GPT-OSS, Whisper, Wav2Vec2-BERT.
    hf_models.qwen2.modeling_qwen2.Qwen2Model: qwen2_type_to_dimension_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2ForCausalLM: qwen2_lm_type_to_dimension_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification: qwen2_classifier_type_to_dimension_mapping,
    hf_models.mllama.modeling_mllama.MllamaForConditionalGeneration: mllama_type_to_dimension_mapping,
    hf_models.gpt_oss.modeling_gpt_oss.GptOssModel: gpt_oss_type_to_dimension_mapping,
    hf_models.gpt_oss.modeling_gpt_oss.GptOssForCausalLM: gpt_oss_lm_type_to_dimension_mapping,
    hf_models.whisper.modeling_whisper.WhisperModel: whisper_type_to_dimension_mapping,
    hf_models.whisper.modeling_whisper.WhisperForConditionalGeneration: whisper_lm_type_to_dimension_mapping,
    hf_models.wav2vec2_bert.modeling_wav2vec2_bert.Wav2Vec2BertModel: wav2vec2bert_type_to_dimension_mapping,
}

# Optional BLIP wrappers: only registered when their import succeeded. Both
# entries are inserted in the same order as before (BlipWrapper first).
if enable_blip:
    type_to_dimension_mapping.update(
        {
            BlipWrapper: blip_wrapper_type_to_dimension_mapping,
            BlipITMWrapper: blip_itm_wrapper_type_to_dimension_mapping,
        }
    )
#########################################################################