Skip to content

Commit 961931e

Browse files
authored
Merge pull request #139 from stanfordnlp/zen/addback_model
[P0] Adding back GPT2 and other model supports
2 parents faf581b + 2713898 commit 961931e

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

pyvene/models/intervenable_modelcard.py

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .mlp.modelings_intervenable_mlp import *
88
from .gru.modelings_intervenable_gru import *
99
from .blip.modelings_intervenable_blip import *
10+
from .blip.modelings_intervenable_blip_itm import *
1011
from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
1112

1213

@@ -21,6 +22,7 @@
2122

2223
import transformers.models as hf_models
2324
from .blip.modelings_blip import BlipWrapper
25+
from .blip.modelings_blip_itm import BlipITMWrapper
2426
from .mlp.modelings_mlp import MLPModel, MLPForClassification
2527
from .gru.modelings_gru import GRUModel, GRULMHeadModel, GRUForClassification
2628
from .backpack_gpt2.modelings_backpack_gpt2 import BackpackGPT2LMHeadModel
@@ -34,6 +36,7 @@
3436
type_to_module_mapping = {
3537
hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_module_mapping,
3638
hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_module_mapping,
39+
hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_module_mapping,
3740
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping,
3841
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping,
3942
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping,
@@ -43,7 +46,9 @@
4346
hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_module_mapping,
4447
hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_module_mapping,
4548
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
49+
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping,
4650
BlipWrapper: blip_wrapper_type_to_module_mapping,
51+
BlipITMWrapper: blip_itm_wrapper_type_to_module_mapping,
4752
MLPModel: mlp_type_to_module_mapping,
4853
MLPForClassification: mlp_classifier_type_to_module_mapping,
4954
GRUModel: gru_type_to_module_mapping,
@@ -57,6 +62,7 @@
5762
type_to_dimension_mapping = {
5863
hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_dimension_mapping,
5964
hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_dimension_mapping,
65+
hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_dimension_mapping,
6066
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping,
6167
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping,
6268
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping,
@@ -66,7 +72,9 @@
6672
hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_dimension_mapping,
6773
hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_dimension_mapping,
6874
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
75+
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping,
6976
BlipWrapper: blip_wrapper_type_to_dimension_mapping,
77+
BlipITMWrapper: blip_itm_wrapper_type_to_dimension_mapping,
7078
MLPModel: mlp_type_to_dimension_mapping,
7179
MLPForClassification: mlp_classifier_type_to_dimension_mapping,
7280
GRUModel: gru_type_to_dimension_mapping,

0 commit comments

Comments
 (0)