Commit 7504fb6

Merge pull request #129 from stanfordnlp/zen/sharedwxpos

[Minor] Allow sharing interventions across multiple positions

2 parents: 96db4e9 + 271fa09
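What this enables, in one sketch, adapted from the notebook cell added in this commit (the keep_last_dim flag is defined in the interventions.py diff below):

import pyvene as pv

_, tokenizer, gpt2 = pv.create_gpt2()

# a single intervention object shared across several token positions
config = pv.IntervenableConfig([
    {"layer": 0, "component": "block_output",
     "intervention": pv.VanillaIntervention(keep_last_dim=True)}]
)
pv_gpt2 = pv.IntervenableModel(config, model=gpt2)

base = tokenizer("The capital of Spain is", return_tensors="pt")
source = tokenizer("The capital of Italy is", return_tensors="pt")

# swap the 3rd and 4th token representations with one shared intervention
_, out = pv_gpt2(base, [source], {"sources->base": ([[[3, 4]]], [[[3, 4]]])})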

File tree: 3 files changed, +133 −49

pyvene/models/interventions.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@ def __init__(self, **kwargs):
         super().__init__()
         self.trainable = False
         self.is_source_constant = False
-
+
+        self.keep_last_dim = kwargs["keep_last_dim"] if "keep_last_dim" in kwargs else False
         self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False
         self.subspace_partition = (
             kwargs["subspace_partition"] if "subspace_partition" in kwargs else None

pyvene/models/modeling_utils.py

Lines changed: 4 additions & 4 deletions
@@ -435,8 +435,8 @@ def do_intervention(
     original_base_shape = base_representation.shape
     if len(original_base_shape) == 2 or (
         isinstance(intervention, LocalistRepresentationIntervention)
-    ):
-        # no pos dimension, e.g., gru
+    ) or intervention.keep_last_dim:
+        # no pos dimension (e.g., gru), or opting out of concatenating the last two dims
         base_representation_f = base_representation
         source_representation_f = source_representation
     elif len(original_base_shape) == 3:
@@ -459,8 +459,8 @@ def do_intervention(
     # unflatten
     if len(original_base_shape) == 2 or isinstance(
         intervention, LocalistRepresentationIntervention
-    ):
-        # no pos dimension, e.g., gru
+    ) or intervention.keep_last_dim:
+        # no pos dimension (e.g., gru), or opting out of concatenating the last two dims
         pass
     elif len(original_base_shape) == 3:
         intervened_representation = b_sd_to_bsd(intervened_representation, num_unit)
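For context, a sketch of the flatten/unflatten step that keep_last_dim bypasses. The helper behavior is inferred from the b_sd_to_bsd naming above; this is illustrative, not pyvene's actual implementation:

import torch

def bsd_to_b_sd(x):
    # (batch, seq, dim) -> (batch, seq*dim): positions and features merge,
    # so one intervention sees all selected positions as one flat vector
    b, s, d = x.shape
    return x.reshape(b, s * d)

def b_sd_to_bsd(x, num_unit):
    # inverse: (batch, seq*dim) -> (batch, seq, dim)
    b, sd = x.shape
    return x.reshape(b, num_unit, sd // num_unit)

x = torch.randn(2, 3, 768)
assert torch.equal(b_sd_to_bsd(bsd_to_b_sd(x), 3), x)
# with keep_last_dim=True, both reshapes are skipped: the representation keeps
# its position dimension, so one intervention can apply per position in parallel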

pyvene_101.ipynb

Lines changed: 127 additions & 44 deletions
@@ -126,34 +126,36 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 14,
    "id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "loaded model\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "import torch\n",
     "import pyvene as pv\n",
+    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
     "\n",
-    "_, tokenizer, gpt2 = pv.create_gpt2()\n",
+    "model_name = \"gpt2\"\n",
+    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
     "\n",
-    "pv_gpt2 = pv.IntervenableModel({\n",
-    "    \"layer\": 10,\n",
-    "    \"component\": \"attention_weight\",\n",
-    "    \"intervention_type\": pv.CollectIntervention}, model=gpt2)\n",
+    "# create a dict-based intervention config\n",
+    "pv_config = pv.IntervenableConfig({\n",
+    "    \"component\": \"transformer.h[0].mlp.output\"},\n",
+    "    intervention_types=pv.VanillaIntervention\n",
+    ")\n",
+    "# wrap your model with the config\n",
+    "pv_gpt2 = pv.IntervenableModel(pv_config, model=model)\n",
    "\n",
-    "base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
-    "collected_attn_w = pv_gpt2(\n",
-    "    base = tokenizer(base, return_tensors=\"pt\"\n",
-    "    ), unit_locations={\"base\": [h for h in range(12)]}\n",
-    ")[0][-1][0]"
+    "# run an interchange intervention (an activation swap between two examples)\n",
+    "intervened_outputs = pv_gpt2(\n",
+    "    # the base input\n",
+    "    base=tokenizer(\"The capital of Spain is\", return_tensors=\"pt\"),\n",
+    "    # the source input\n",
+    "    sources=tokenizer(\"The capital of Italy is\", return_tensors=\"pt\"),\n",
+    "    # the location to intervene at (the 3rd token)\n",
+    "    unit_locations={\"sources->base\": 3},\n",
+    "    output_original_output=True  # if False, the first element of the returned tuple is None\n",
+    ")"
    ]
   },
   {
@@ -166,46 +168,49 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 12,
    "id": "1ef4a1db-5187-4457-9878-f1dc03e9859b",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "GPT2Model(\n",
-       "  (wte): Embedding(50257, 768)\n",
-       "  (wpe): Embedding(1024, 768)\n",
-       "  (drop): Dropout(p=0.1, inplace=False)\n",
-       "  (h): ModuleList(\n",
-       "    (0-11): 12 x GPT2Block(\n",
-       "      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
-       "      (attn): GPT2Attention(\n",
-       "        (c_attn): Conv1D()\n",
-       "        (c_proj): Conv1D()\n",
-       "        (attn_dropout): Dropout(p=0.1, inplace=False)\n",
-       "        (resid_dropout): Dropout(p=0.1, inplace=False)\n",
-       "      )\n",
-       "      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
-       "      (mlp): GPT2MLP(\n",
-       "        (c_fc): Conv1D()\n",
-       "        (c_proj): Conv1D()\n",
-       "        (act): NewGELUActivation()\n",
-       "        (dropout): Dropout(p=0.1, inplace=False)\n",
+       "GPT2LMHeadModel(\n",
+       "  (transformer): GPT2Model(\n",
+       "    (wte): Embedding(50257, 768)\n",
+       "    (wpe): Embedding(1024, 768)\n",
+       "    (drop): Dropout(p=0.1, inplace=False)\n",
+       "    (h): ModuleList(\n",
+       "      (0-11): 12 x GPT2Block(\n",
+       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "        (attn): GPT2Attention(\n",
+       "          (c_attn): Conv1D()\n",
+       "          (c_proj): Conv1D()\n",
+       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
+       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
+       "        )\n",
+       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "        (mlp): GPT2MLP(\n",
+       "          (c_fc): Conv1D()\n",
+       "          (c_proj): Conv1D()\n",
+       "          (act): NewGELUActivation()\n",
+       "          (dropout): Dropout(p=0.1, inplace=False)\n",
+       "        )\n",
        "      )\n",
        "    )\n",
+       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
        "  )\n",
-       "  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
        ")"
       ]
      },
-     "execution_count": 15,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "gpt2"
+    "model"
    ]
   },
   {
@@ -1978,6 +1983,84 @@
     "print(torch.equal(pv_out3.last_hidden_state, pv_out4.last_hidden_state))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "243f146f-1b9a-4574-ba2c-ebf455a96c16",
+   "metadata": {},
+   "source": [
+    "Other than intervention linking, you can also share one intervention at the same component across multiple positions by setting a flag on the intervention object. This has the same effect as creating one intervention per location and linking them all together."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "7c647943-c7e1-4024-8c07-b51062e668ba",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "loaded model\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[[0., 0., 0.,  ..., 0., 0., 0.],\n",
+       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
+       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
+       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
+       "         [0., 0., 0.,  ..., 0., 0., 0.]]])"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import pyvene as pv\n",
+    "\n",
+    "_, tokenizer, gpt2 = pv.create_gpt2()\n",
+    "\n",
+    "config = pv.IntervenableConfig([\n",
+    "    # two interventions at the same component,\n",
+    "    # linked so they share one intervention object\n",
+    "    {\"layer\": 0, \"component\": \"block_output\", \"intervention_link_key\": 0},\n",
+    "    {\"layer\": 0, \"component\": \"block_output\", \"intervention_link_key\": 0}],\n",
+    "    intervention_types=pv.VanillaIntervention,\n",
+    ")\n",
+    "pv_gpt2 = pv.IntervenableModel(config, model=gpt2)\n",
+    "\n",
+    "base = tokenizer(\"The capital of Spain is\", return_tensors=\"pt\")\n",
+    "source = tokenizer(\"The capital of Italy is\", return_tensors=\"pt\")\n",
+    "\n",
+    "_, pv_out = pv_gpt2(\n",
+    "    base,\n",
+    "    [source, source],\n",
+    "    # swap the 3rd and 4th token reprs from the same source into the base\n",
+    "    {\"sources->base\": ([[[4]], [[3]]], [[[4]], [[3]]])},\n",
+    ")\n",
+    "\n",
+    "keep_last_dim_config = pv.IntervenableConfig([\n",
+    "    # one intervention shared across multiple\n",
+    "    # positions via the keep_last_dim flag\n",
+    "    {\"layer\": 0, \"component\": \"block_output\",\n",
+    "     \"intervention\": pv.VanillaIntervention(keep_last_dim=True)}]\n",
+    ")\n",
+    "keep_last_dim_pv_gpt2 = pv.IntervenableModel(keep_last_dim_config, model=gpt2)\n",
+    "\n",
+    "_, keep_last_dim_pv_out = keep_last_dim_pv_gpt2(\n",
+    "    base,\n",
+    "    [source],\n",
+    "    # the same swap, expressed as one location list\n",
+    "    {\"sources->base\": ([[[3, 4]]], [[[3, 4]]])},\n",
+    ")\n",
+    "keep_last_dim_pv_out.last_hidden_state - pv_out.last_hidden_state"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "ef5b7a3e",
