8 changes: 8 additions & 0 deletions pyvene/models/intervenable_modelcard.py
@@ -19,6 +19,8 @@
from .esm.modelings_intervenable_esm import *
from .mllama.modelings_intervenable_mllama import *
from .gpt_oss.modelings_intervenable_gpt_oss import *
from .whisper.modelings_intervenable_whisper import *
from .wav2vec2bert.modelings_intervenable_wav2vec2bert import *

#########################################################################
"""
@@ -89,6 +91,9 @@
hf_models.mllama.modeling_mllama.MllamaForConditionalGeneration: mllama_type_to_module_mapping,
hf_models.gpt_oss.modeling_gpt_oss.GptOssModel: gpt_oss_type_to_module_mapping,
hf_models.gpt_oss.modeling_gpt_oss.GptOssForCausalLM: gpt_oss_lm_type_to_module_mapping,
hf_models.whisper.modeling_whisper.WhisperModel: whisper_type_to_module_mapping,
hf_models.whisper.modeling_whisper.WhisperForConditionalGeneration: whisper_lm_type_to_module_mapping,
hf_models.wav2vec2_bert.modeling_wav2vec2_bert.Wav2Vec2BertModel: wav2vec2bert_type_to_module_mapping,
}
if enable_blip:
type_to_module_mapping[BlipWrapper] = blip_wrapper_type_to_module_mapping
@@ -135,6 +140,9 @@
hf_models.mllama.modeling_mllama.MllamaForConditionalGeneration: mllama_type_to_dimension_mapping,
hf_models.gpt_oss.modeling_gpt_oss.GptOssModel: gpt_oss_type_to_dimension_mapping,
hf_models.gpt_oss.modeling_gpt_oss.GptOssForCausalLM: gpt_oss_lm_type_to_dimension_mapping,
hf_models.whisper.modeling_whisper.WhisperModel: whisper_type_to_dimension_mapping,
hf_models.whisper.modeling_whisper.WhisperForConditionalGeneration: whisper_lm_type_to_dimension_mapping,
hf_models.wav2vec2_bert.modeling_wav2vec2_bert.Wav2Vec2BertModel: wav2vec2bert_type_to_dimension_mapping,
}

if enable_blip:
12 changes: 9 additions & 3 deletions pyvene/models/modeling_utils.py
@@ -465,7 +465,10 @@ def do_intervention(
# flatten
original_base_shape = base_representation.shape
if len(original_base_shape) == 2 or (
isinstance(intervention, LocalistRepresentationIntervention)
isinstance(intervention, LocalistRepresentationIntervention) or
isinstance(intervention, BoundlessRotatedSpaceIntervention) or
isinstance(intervention, VanillaIntervention) or
isinstance(intervention, CollectIntervention)
) or intervention.keep_last_dim:
# no pos dimension, e.g., gru, or opt out of concatenating the last two dims
base_representation_f = base_representation
@@ -492,8 +495,11 @@
post_d = intervened_representation.shape[-1]

# unflatten
if len(original_base_shape) == 2 or isinstance(
intervention, LocalistRepresentationIntervention
if len(original_base_shape) == 2 or (
isinstance(intervention, LocalistRepresentationIntervention) or
isinstance(intervention, BoundlessRotatedSpaceIntervention) or
isinstance(intervention, VanillaIntervention) or
isinstance(intervention, CollectIntervention)
) or intervention.keep_last_dim:
# no pos dimension, e.g., gru, or opt out of concatenating the last two dims
pass
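For context, the condition amended above decides whether do_intervention flattens the position dimension into the feature dimension before applying an intervention. A minimal sketch of that reshape, with illustrative shapes that are an assumption rather than the library's exact layout:

import torch

# (batch, pos, dim): a representation with an explicit position dimension
base = torch.randn(2, 5, 8)

# Distributed interventions operate over the concatenated pos*dim subspace,
# so the last two dims are folded together first ...
flat = base.reshape(base.shape[0], -1)  # -> (2, 40)
# ... the intervention runs on `flat`, then the shape is restored:
restored = flat.reshape(base.shape)     # unflatten -> (2, 5, 8)
assert restored.shape == base.shape

# Localist-style interventions (and now BoundlessRotatedSpace, Vanilla, and
# Collect, per this diff) skip both reshapes and act on `base` directly.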
Empty file.
61 changes: 61 additions & 0 deletions pyvene/models/wav2vec2bert/modelings_intervenable_wav2vec2bert.py
@@ -0,0 +1,61 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.
We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""
import torch
from ..constants import *

wav2vec2bert_type_to_module_mapping = {
"block_input": ("encoder.layers[%s]", CONST_INPUT_HOOK),
"block_output": ("encoder.layers[%s]", CONST_OUTPUT_HOOK),
"ffn1_activation": ("encoder.layers[%s].ffn1.intermediate_act_fn", CONST_OUTPUT_HOOK),
"ffn1_output": ("encoder.layers[%s].ffn1", CONST_OUTPUT_HOOK),
"ffn1_input": ("encoder.layers[%s].ffn1", CONST_INPUT_HOOK),
"ffn2_activation": ("encoder.layers[%s].ffn2.intermediate_act_fn", CONST_OUTPUT_HOOK),
Comment on lines +15 to +18

P2: Do not map Wav2Vec2Bert ffn activations to intermediate_act_fn

ffn1.intermediate_act_fn and ffn2.intermediate_act_fn are callables rather than nn.Modules in the HF implementation, so get_module_hook (which calls register_forward_hook) will fail when users request ffn*_activation interventions. This makes those intervention points unusable and raises at hook registration time. A possible wrapper workaround is sketched after the mapping below.

"ffn2_output": ("encoder.layers[%s].ffn2", CONST_OUTPUT_HOOK),
"ffn2_input": ("encoder.layers[%s].ffn2", CONST_INPUT_HOOK),
"attention_value_output": ("encoder.layers[%s].self_attn.linear_out", CONST_INPUT_HOOK),
"head_attention_value_output": ("encoder.layers[%s].self_attn.linear_out", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("encoder.layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("encoder.layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("encoder.layers[%s].self_attn.linear_q", CONST_OUTPUT_HOOK),
"key_output": ("encoder.layers[%s].self_attn.linear_k", CONST_OUTPUT_HOOK),
"value_output": ("encoder.layers[%s].self_attn.linear_v", CONST_OUTPUT_HOOK),
"head_query_output": ("encoder.layers[%s].self_attn.linear_q", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("encoder.layers[%s].self_attn.linear_k", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_value_output": ("encoder.layers[%s].self_attn.linear_v", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"conv_output": ("encoder.layers[%s].conv_module", CONST_OUTPUT_HOOK),
"conv_input": ("encoder.layers[%s].conv_module", CONST_INPUT_HOOK),
"conv_glu_output": ("encoder.layers[%s].conv_module.glu", CONST_OUTPUT_HOOK),
"conv_depth_output": ("encoder.layers[%s].conv_module.depthwise_conv", CONST_OUTPUT_HOOK),
}
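Following up on the review comment above, a hypothetical workaround sketch (not part of this PR): wrap the bare activation callable in an nn.Module so register_forward_hook has a module to attach to.

import torch.nn as nn

class HookableActivation(nn.Module):
    """Wraps a plain callable (e.g., an ACT2FN entry) as a hookable module."""
    def __init__(self, act_fn):
        super().__init__()
        self.act_fn = act_fn

    def forward(self, hidden_states):
        return self.act_fn(hidden_states)

# Hypothetical usage, before registering hooks on a Wav2Vec2Bert layer:
# layer.ffn1.intermediate_act_fn = HookableActivation(layer.ffn1.intermediate_act_fn)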

wav2vec2bert_type_to_dimension_mapping = {
"n_head": ("num_attention_heads",),
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"ffn1_activation": ("intermediate_size",),
"ffn1_output": ("hidden_size",),
"ffn1_input": ("hidden_size",),
"ffn2_activation": ("intermediate_size",),
"ffn2_output": ("hidden_size",),
"ffn2_input": ("hidden_size",),
"attention_value_output": ("hidden_size",),
"head_attention_value_output": ("hidden_size/num_attention_heads",),
"attention_output": ("hidden_size",),
"attention_input": ("hidden_size",),
"query_output": ("hidden_size",),
"key_output": ("hidden_size",),
"value_output": ("hidden_size",),
"head_query_output": ("hidden_size/num_attention_heads",),
"head_key_output": ("hidden_size/num_attention_heads",),
"head_value_output": ("hidden_size/num_attention_heads",),
"conv_output": ("hidden_size",),
"conv_input": ("hidden_size",),
"conv_glu_output": ("hidden_size",),
"conv_depth_output": ("hidden_size",),
}
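For illustration, a dimension spec such as "hidden_size/num_attention_heads" resolves against the HuggingFace model config; a simplified sketch of that lookup (the library's actual resolution logic may differ):

from transformers import Wav2Vec2BertConfig

cfg = Wav2Vec2BertConfig()
# Per-head dimension used by the head_* intervention points above
head_dim = cfg.hidden_size // cfg.num_attention_heads
print(cfg.hidden_size, cfg.num_attention_heads, head_dim)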
Empty file.
54 changes: 54 additions & 0 deletions pyvene/models/whisper/modelings_intervenable_whisper.py
@@ -0,0 +1,54 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.
We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""
import torch
from ..constants import *

whisper_type_to_module_mapping = {
"block_input": ("encoder.layers[%s]", CONST_INPUT_HOOK),
"block_output": ("encoder.layers[%s]", CONST_OUTPUT_HOOK),
"mlp_activation": ("encoder.layers[%s].activation_fn", CONST_OUTPUT_HOOK),
P2: Avoid hooking Whisper activation_fn callable

In HuggingFace Whisper, encoder.layers[*].activation_fn is a plain callable from ACT2FN, not an nn.Module. get_module_hook later calls register_forward_hook on the resolved object (see pyvene/models/modeling_utils.py:get_module_hook), so registering mlp_activation will raise an AttributeError at runtime whenever a user selects that intervention. Consider hooking a real module (e.g., fc1 output) or wrapping the activation in an nn.Module before registering; a sketch of the fc1 remapping follows the mapping below.

"mlp_output": ("encoder.layers[%s].fc2", CONST_OUTPUT_HOOK),
"mlp_input": ("encoder.layers[%s].fc1", CONST_INPUT_HOOK),
"attention_value_output": ("encoder.layers[%s].self_attn.out_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("encoder.layers[%s].self_attn.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("encoder.layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("encoder.layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("encoder.layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("encoder.layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("encoder.layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("encoder.layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("encoder.layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_value_output": ("encoder.layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
}
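As the review comment above suggests, one hypothetical fix (an assumption, not this PR's code) is to point mlp_activation at a real module such as fc1. Note that a forward hook on fc1 captures its output, i.e., the pre-activation value rather than the activation's output:

# Hypothetical remapping: fc1 is an nn.Linear, so forward hooks attach cleanly.
whisper_type_to_module_mapping["mlp_activation"] = (
    "encoder.layers[%s].fc1",
    CONST_OUTPUT_HOOK,  # output of fc1 == input to activation_fn
)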

whisper_type_to_dimension_mapping = {
"n_head": ("encoder_attention_heads",),
"block_input": ("d_model",),
"block_output": ("d_model",),
"mlp_activation": ("encoder_ffn_dim",),
"mlp_output": ("d_model",),
"mlp_input": ("d_model",),
"attention_value_output": ("d_model",),
"head_attention_value_output": ("d_model/encoder_attention_heads",),
"attention_output": ("d_model",),
"attention_input": ("d_model",),
"query_output": ("d_model",),
"key_output": ("d_model",),
"value_output": ("d_model",),
"head_query_output": ("d_model/encoder_attention_heads",),
"head_key_output": ("d_model/encoder_attention_heads",),
"head_value_output": ("d_model/encoder_attention_heads",),
}

"""whisper model with LM head"""
whisper_lm_type_to_module_mapping = {}
for k, v in whisper_type_to_module_mapping.items():
whisper_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]
whisper_lm_type_to_dimension_mapping = whisper_type_to_dimension_mapping
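A hedged usage sketch showing how these mappings are typically exercised; the checkpoint name and config style are assumptions based on pyvene's public examples, not code from this PR:

import pyvene as pv
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
config = pv.IntervenableConfig(
    {"layer": 0, "component": "block_output"},  # resolved via the mapping above
    intervention_types=pv.VanillaIntervention,
)
pv_model = pv.IntervenableModel(config, model)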
