Skip to content

Commit 89b8e4f

Browse files
authored
Merge pull request #130 from stanfordnlp/zen/updatereadme
[Minor] Fix the pyvene101 with correct examples
2 parents 7504fb6 + d36a3d5 commit 89b8e4f

File tree

2 files changed

+35
-79
lines changed

2 files changed

+35
-79
lines changed

README.md

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,31 +38,41 @@ pip install pyvene
3838
```
3939

4040
## _Wrap_ , _Intervene_ and _Share_
41-
You can intervene with supported models as,
41+
You can intervene with any HuggingFace model as,
4242
```python
4343
import torch
4444
import pyvene as pv
45+
from transformers import AutoTokenizer, AutoModelForCausalLM
4546

46-
_, tokenizer, gpt2 = pv.create_gpt2()
47+
model_name = "meta-llama/Llama-2-7b-hf" # your HF model name.
48+
model = AutoModelForCausalLM.from_pretrained(
49+
model_name, torch_dtype=torch.bfloat16, device_map="cuda")
50+
tokenizer = AutoTokenizer.from_pretrained(model_name)
51+
52+
def zeroout_intervention_fn(b, s):
53+
b[:,3] = 0. # 3rd position
54+
return b
4755

48-
pv_gpt2 = pv.IntervenableModel({
49-
"layer": 0, "component": "block_output",
50-
"source_representation": torch.zeros(gpt2.config.n_embd)
51-
}, model=gpt2)
56+
pv_model = pv.IntervenableModel({
57+
"component": "model.layers[15].mlp.output", # string access
58+
"intervention": zeroout_intervention_fn}, model=model)
5259

53-
orig_outputs, intervened_outputs = pv_gpt2(
54-
base = tokenizer("The capital of Spain is", return_tensors="pt"),
55-
unit_locations={"base": 3}
60+
# run the intervened forward pass
61+
orig_outputs, intervened_outputs = pv_model(
62+
tokenizer("The capital of Spain is", return_tensors="pt").to('cuda'),
63+
output_original_output=True
5664
)
57-
print(intervened_outputs.last_hidden_state - orig_outputs.last_hidden_state)
65+
print(intervened_outputs.logits - orig_outputs.logits)
5866
```
5967
which returns,
6068
```
6169
tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
6270
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
6371
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
64-
[ 0.0483, -0.1212, -0.2816, ..., 0.1958, 0.0830, 0.0784],
65-
[ 0.0519, 0.2547, -0.1631, ..., 0.0050, -0.0453, -0.1624]]])
72+
[ 0.4375, 1.0625, 0.3750, ..., -0.1562, 0.4844, 0.2969],
73+
[ 0.0938, 0.1250, 0.1875, ..., 0.2031, 0.0625, 0.2188],
74+
[ 0.0000, -0.0625, -0.0312, ..., 0.0000, 0.0000, -0.0156]]],
75+
device='cuda:0')
6676
```
6777

6878
## _IntervenableModel_ Loaded from HuggingFace Directly

pyvene_101.ipynb

Lines changed: 13 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@
126126
},
127127
{
128128
"cell_type": "code",
129-
"execution_count": 14,
129+
"execution_count": 2,
130130
"id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0",
131131
"metadata": {},
132132
"outputs": [],
@@ -135,27 +135,19 @@
135135
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
136136
"\n",
137137
"model_name = \"gpt2\"\n",
138-
"model = AutoModelForCausalLM.from_pretrained(model_name)\n",
138+
"gpt2 = AutoModelForCausalLM.from_pretrained(model_name)\n",
139139
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
140140
"\n",
141-
"# create a dict-based intervention config\n",
142-
"pv_config = pv.IntervenableConfig({\n",
143-
" \"component\": \"transformer.h[0].mlp.output\"},\n",
144-
" intervention_types=pv.VanillaIntervention\n",
145-
")\n",
146-
"# wrap your model with the config\n",
147-
"pv_gpt2 = pv.IntervenableModel(pv_config, model=model)\n",
141+
"pv_gpt2 = pv.IntervenableModel({\n",
142+
" \"layer\": 10,\n",
143+
" \"component\": \"attention_weight\",\n",
144+
" \"intervention_type\": pv.CollectIntervention}, model=gpt2)\n",
148145
"\n",
149-
"# run an interchange intervention (activation swap between two examples)\n",
150-
"intervened_outputs = pv_gpt2(\n",
151-
" # the base input\n",
152-
" base=tokenizer(\"The capital of Spain is\", return_tensors = \"pt\"), \n",
153-
" # the source input\n",
154-
" sources=tokenizer(\"The capital of Italy is\", return_tensors = \"pt\"), \n",
155-
" # the location to intervene at (3rd token)\n",
156-
" unit_locations={\"sources->base\": 3},\n",
157-
" output_original_output=True # False then the first element in the tuple is None\n",
158-
")"
146+
"base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
147+
"collected_attn_w = pv_gpt2(\n",
148+
" base = tokenizer(base, return_tensors=\"pt\"\n",
149+
" ), unit_locations={\"base\": [h for h in range(12)]}\n",
150+
")[0][-1][0]"
159151
]
160152
},
161153
{
@@ -166,53 +158,6 @@
166158
"#### Get Attention Weights with Direct Access String"
167159
]
168160
},
169-
{
170-
"cell_type": "code",
171-
"execution_count": 12,
172-
"id": "1ef4a1db-5187-4457-9878-f1dc03e9859b",
173-
"metadata": {},
174-
"outputs": [
175-
{
176-
"data": {
177-
"text/plain": [
178-
"GPT2LMHeadModel(\n",
179-
" (transformer): GPT2Model(\n",
180-
" (wte): Embedding(50257, 768)\n",
181-
" (wpe): Embedding(1024, 768)\n",
182-
" (drop): Dropout(p=0.1, inplace=False)\n",
183-
" (h): ModuleList(\n",
184-
" (0-11): 12 x GPT2Block(\n",
185-
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
186-
" (attn): GPT2Attention(\n",
187-
" (c_attn): Conv1D()\n",
188-
" (c_proj): Conv1D()\n",
189-
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
190-
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
191-
" )\n",
192-
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
193-
" (mlp): GPT2MLP(\n",
194-
" (c_fc): Conv1D()\n",
195-
" (c_proj): Conv1D()\n",
196-
" (act): NewGELUActivation()\n",
197-
" (dropout): Dropout(p=0.1, inplace=False)\n",
198-
" )\n",
199-
" )\n",
200-
" )\n",
201-
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
202-
" )\n",
203-
" (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
204-
")"
205-
]
206-
},
207-
"execution_count": 12,
208-
"metadata": {},
209-
"output_type": "execute_result"
210-
}
211-
],
212-
"source": [
213-
"model"
214-
]
215-
},
216161
{
217162
"cell_type": "code",
218163
"execution_count": 19,
@@ -231,6 +176,7 @@
231176
"import torch\n",
232177
"import pyvene as pv\n",
233178
"\n",
179+
"# gpt2 helper loading model from HuggingFace\n",
234180
"_, tokenizer, gpt2 = pv.create_gpt2()\n",
235181
"\n",
236182
"pv_gpt2 = pv.IntervenableModel({\n",
@@ -724,7 +670,7 @@
724670
},
725671
{
726672
"cell_type": "code",
727-
"execution_count": 8,
673+
"execution_count": 2,
728674
"id": "7f058ecd",
729675
"metadata": {},
730676
"outputs": [

0 commit comments

Comments
 (0)