126 | 126 | },
127 | 127 | {
128 | 128 | "cell_type": "code",
129 | | - "execution_count": 14,
| 129 | + "execution_count": 2,
130 | 130 | "id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0",
131 | 131 | "metadata": {},
132 | 132 | "outputs": [],
135 | 135 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
136 | 136 | "\n",
137 | 137 | "model_name = \"gpt2\"\n",
138 | | - "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
| 138 | + "gpt2 = AutoModelForCausalLM.from_pretrained(model_name)\n",
139 | 139 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
140 | 140 | "\n",
141 | | - "# create a dict-based intervention config\n",
142 | | - "pv_config = pv.IntervenableConfig({\n",
143 | | - "    \"component\": \"transformer.h[0].mlp.output\"},\n",
144 | | - "    intervention_types=pv.VanillaIntervention\n",
145 | | - ")\n",
146 | | - "# wrap your model with the config\n",
147 | | - "pv_gpt2 = pv.IntervenableModel(pv_config, model=model)\n",
| 141 | + "pv_gpt2 = pv.IntervenableModel({\n",
| 142 | + "    \"layer\": 10,\n",
| 143 | + "    \"component\": \"attention_weight\",\n",
| 144 | + "    \"intervention_type\": pv.CollectIntervention}, model=gpt2)\n",
148 | 145 | "\n",
149 | | - "# run an interchange intervention (activation swap between two examples)\n",
150 | | - "intervened_outputs = pv_gpt2(\n",
151 | | - "    # the base input\n",
152 | | - "    base=tokenizer(\"The capital of Spain is\", return_tensors=\"pt\"),\n",
153 | | - "    # the source input\n",
154 | | - "    sources=tokenizer(\"The capital of Italy is\", return_tensors=\"pt\"),\n",
155 | | - "    # the location to intervene at (3rd token)\n",
156 | | - "    unit_locations={\"sources->base\": 3},\n",
157 | | - "    output_original_output=True  # if False, the first element of the output tuple is None\n",
158 | | - ")"
| 146 | + "base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
| 147 | + "collected_attn_w = pv_gpt2(\n",
| 148 | + "    base=tokenizer(base, return_tensors=\"pt\"),\n",
| 149 | + "    unit_locations={\"base\": [h for h in range(12)]}\n",
| 150 | + ")[0][-1][0]"
159 | 151 | ]
160 | 152 | },
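Unescaped from the notebook-JSON string list, the updated cell amounts to roughly the following runnable script. The `import torch` and `import pyvene as pv` lines are assumed from the cell's unshown context (source lines 133-134); everything else is taken from the `+` lines above.

```python
import torch
import pyvene as pv
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"
gpt2 = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# wrap GPT-2 so that layer 10's attention weights are collected
# (read out) rather than overwritten
pv_gpt2 = pv.IntervenableModel({
    "layer": 10,
    "component": "attention_weight",
    "intervention_type": pv.CollectIntervention}, model=gpt2)

base = "When John and Mary went to the shops, Mary gave the bag to"
collected_attn_w = pv_gpt2(
    base=tokenizer(base, return_tensors="pt"),
    # collect at all 12 heads of the intervened layer
    unit_locations={"base": [h for h in range(12)]}
)[0][-1][0]  # pull the collected weights out of the output tuple
```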
161 | 153 | {
166 | 158 | "#### Get Attention Weights with Direct Access String"
167 | 159 | ]
168 | 160 | },
169 | | - {
170 | | - "cell_type": "code",
171 | | - "execution_count": 12,
172 | | - "id": "1ef4a1db-5187-4457-9878-f1dc03e9859b",
173 | | - "metadata": {},
174 | | - "outputs": [
175 | | - {
176 | | - "data": {
177 | | - "text/plain": [
178 | | - "GPT2LMHeadModel(\n",
179 | | - "  (transformer): GPT2Model(\n",
180 | | - "    (wte): Embedding(50257, 768)\n",
181 | | - "    (wpe): Embedding(1024, 768)\n",
182 | | - "    (drop): Dropout(p=0.1, inplace=False)\n",
183 | | - "    (h): ModuleList(\n",
184 | | - "      (0-11): 12 x GPT2Block(\n",
185 | | - "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
186 | | - "        (attn): GPT2Attention(\n",
187 | | - "          (c_attn): Conv1D()\n",
188 | | - "          (c_proj): Conv1D()\n",
189 | | - "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
190 | | - "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
191 | | - "        )\n",
192 | | - "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
193 | | - "        (mlp): GPT2MLP(\n",
194 | | - "          (c_fc): Conv1D()\n",
195 | | - "          (c_proj): Conv1D()\n",
196 | | - "          (act): NewGELUActivation()\n",
197 | | - "          (dropout): Dropout(p=0.1, inplace=False)\n",
198 | | - "        )\n",
199 | | - "      )\n",
200 | | - "    )\n",
201 | | - "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
202 | | - "  )\n",
203 | | - "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
204 | | - ")"
205 | | - ]
206 | | - },
207 | | - "execution_count": 12,
208 | | - "metadata": {},
209 | | - "output_type": "execute_result"
210 | | - }
211 | | - ],
212 | | - "source": [
213 | | - "model"
214 | | - ]
215 | | - },
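The cell removed above existed only to display the model's module tree, which is how one finds the module paths used in direct access strings. With the rename introduced earlier in this diff, the equivalent one-liner would be:

```python
# print the module tree to discover component paths,
# e.g. the attn_dropout submodule inside each GPT2Block
print(gpt2)
```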
216 | 161 | {
217 | 162 | "cell_type": "code",
218 | 163 | "execution_count": 19,
231 | 176 | "import torch\n",
232 | 177 | "import pyvene as pv\n",
233 | 178 | "\n",
| 179 | + "# helper that loads GPT-2 and its tokenizer from HuggingFace\n",
234 | 180 | "_, tokenizer, gpt2 = pv.create_gpt2()\n",
235 | 181 | "\n",
236 | 182 | "pv_gpt2 = pv.IntervenableModel({\n",
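The hunk is cut off mid-cell here, but per the section heading, this cell targets the attention weights via a direct module-path string rather than the `layer`/`component` shorthand. A minimal sketch of that pattern, assuming the path `h[10].attn.attn_dropout.input` (the post-softmax weights flowing into layer 10's attention dropout, visible in the module tree removed above); the exact string in the truncated cell may differ:

```python
import torch
import pyvene as pv

# pyvene helper: returns (config, tokenizer, model) for GPT-2
_, tokenizer, gpt2 = pv.create_gpt2()

# direct access string: address the hook point by module path;
# "h[10].attn.attn_dropout.input" is an assumption for illustration
pv_gpt2 = pv.IntervenableModel({
    "component": "h[10].attn.attn_dropout.input",
    "intervention_type": pv.CollectIntervention}, model=gpt2)

base = "When John and Mary went to the shops, Mary gave the bag to"
collected_attn_w = pv_gpt2(
    base=tokenizer(base, return_tensors="pt"),
    unit_locations={"base": [h for h in range(12)]}
)[0][-1][0]
```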
724 | 670 | },
725 | 671 | {
726 | 672 | "cell_type": "code",
727 | | - "execution_count": 8,
| 673 | + "execution_count": 2,
728 | 674 | "id": "7f058ecd",
729 | 675 | "metadata": {},
730 | 676 | "outputs": [