126 | 126 | },
127 | 127 | {
128 | 128 | "cell_type": "code",
129 | | - "execution_count": 13,
| 129 | + "execution_count": 14,
130 | 130 | "id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0",
131 | 131 | "metadata": {},
132 | | - "outputs": [
133 | | - {
134 | | - "name": "stdout",
135 | | - "output_type": "stream",
136 | | - "text": [
137 | | - "loaded model\n"
138 | | - ]
139 | | - }
140 | | - ],
| 132 | + "outputs": [],
141 | 133 | "source": [
142 | | - "import torch\n",
143 | 134 | "import pyvene as pv\n",
| 135 | + "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
144 | 136 | "\n",
145 | | - "_, tokenizer, gpt2 = pv.create_gpt2()\n",
| 137 | + "model_name = \"gpt2\"\n",
| 138 | + "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
| 139 | + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
146 | 140 | "\n",
147 | | - "pv_gpt2 = pv.IntervenableModel({\n",
148 | | - " \"layer\": 10,\n",
149 | | - " \"component\": \"attention_weight\",\n",
150 | | - " \"intervention_type\": pv.CollectIntervention}, model=gpt2)\n",
| 141 | + "# create a dict-based intervention config\n",
| 142 | + "pv_config = pv.IntervenableConfig({\n",
| 143 | + " \"component\": \"transformer.h[0].mlp.output\"},\n",
| 144 | + " intervention_types=pv.VanillaIntervention\n",
| 145 | + ")\n",
| 146 | + "# wrap your model with the config\n",
| 147 | + "pv_gpt2 = pv.IntervenableModel(pv_config, model=model)\n",
151 | 148 | "\n",
152 | | - "base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
153 | | - "collected_attn_w = pv_gpt2(\n",
154 | | - " base = tokenizer(base, return_tensors=\"pt\"\n",
155 | | - " ), unit_locations={\"base\": [h for h in range(12)]}\n",
156 | | - ")[0][-1][0]"
| 149 | + "# run an interchange intervention (activation swap between two examples)\n",
| 150 | + "intervened_outputs = pv_gpt2(\n",
| 151 | + " # the base input\n",
| 152 | + " base=tokenizer(\"The capital of Spain is\", return_tensors = \"pt\"), \n",
| 153 | + " # the source input\n",
| 154 | + " sources=tokenizer(\"The capital of Italy is\", return_tensors = \"pt\"), \n",
| 155 | + " # the location to intervene at (3rd token)\n",
| 156 | + " unit_locations={\"sources->base\": 3},\n",
| 157 | + " output_original_output=True # if False, the first element in the tuple is None\n",
| 158 | + ")"
157 | 159 | ]
158 | 160 | },
159 | 161 | {

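With output_original_output=True, the new call returns a pair (original outputs, intervened outputs), as the inline comment notes. A minimal sketch of how one might inspect the swap's effect on the next-token prediction, reusing the tokenizer and intervened_outputs from the cell above (an illustrative sketch, not part of this commit):

    # compare next-token predictions with and without the swapped activation
    original, intervened = intervened_outputs
    print(tokenizer.decode(original.logits[0, -1].argmax()))    # base prediction
    print(tokenizer.decode(intervened.logits[0, -1].argmax()))  # prediction after the swap
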
166 | 168 | },
167 | 169 | {
168 | 170 | "cell_type": "code",
169 | | - "execution_count": 15,
| 171 | + "execution_count": 12,
170 | 172 | "id": "1ef4a1db-5187-4457-9878-f1dc03e9859b",
171 | 173 | "metadata": {},
172 | 174 | "outputs": [
173 | 175 | {
174 | 176 | "data": {
175 | 177 | "text/plain": [
176 | | - "GPT2Model(\n",
177 | | - " (wte): Embedding(50257, 768)\n",
178 | | - " (wpe): Embedding(1024, 768)\n",
179 | | - " (drop): Dropout(p=0.1, inplace=False)\n",
180 | | - " (h): ModuleList(\n",
181 | | - " (0-11): 12 x GPT2Block(\n",
182 | | - " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
183 | | - " (attn): GPT2Attention(\n",
184 | | - " (c_attn): Conv1D()\n",
185 | | - " (c_proj): Conv1D()\n",
186 | | - " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
187 | | - " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
188 | | - " )\n",
189 | | - " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
190 | | - " (mlp): GPT2MLP(\n",
191 | | - " (c_fc): Conv1D()\n",
192 | | - " (c_proj): Conv1D()\n",
193 | | - " (act): NewGELUActivation()\n",
194 | | - " (dropout): Dropout(p=0.1, inplace=False)\n",
| 178 | + "GPT2LMHeadModel(\n",
| 179 | + " (transformer): GPT2Model(\n",
| 180 | + " (wte): Embedding(50257, 768)\n",
| 181 | + " (wpe): Embedding(1024, 768)\n",
| 182 | + " (drop): Dropout(p=0.1, inplace=False)\n",
| 183 | + " (h): ModuleList(\n",
| 184 | + " (0-11): 12 x GPT2Block(\n",
| 185 | + " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
| 186 | + " (attn): GPT2Attention(\n",
| 187 | + " (c_attn): Conv1D()\n",
| 188 | + " (c_proj): Conv1D()\n",
| 189 | + " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
| 190 | + " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
| 191 | + " )\n",
| 192 | + " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
| 193 | + " (mlp): GPT2MLP(\n",
| 194 | + " (c_fc): Conv1D()\n",
| 195 | + " (c_proj): Conv1D()\n",
| 196 | + " (act): NewGELUActivation()\n",
| 197 | + " (dropout): Dropout(p=0.1, inplace=False)\n",
| 198 | + " )\n",
195 | 199 | " )\n",
196 | 200 | " )\n",
| 201 | + " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
197 | 202 | " )\n",
198 | | - " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
| 203 | + " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
199 | 204 | ")"
200 | 205 | ]
201 | 206 | },
202 | | - "execution_count": 15,
| 207 | + "execution_count": 12,
203 | 208 | "metadata": {},
204 | 209 | "output_type": "execute_result"
205 | 210 | }
206 | 211 | ],
207 | 212 | "source": [
208 | | - "gpt2"
| 213 | + "model"
209 | 214 | ]
210 | 215 | },
211 | 216 | {

1978 | 1983 | "print(torch.equal(pv_out3.last_hidden_state, pv_out4.last_hidden_state))"
1979 | 1984 | ]
1980 | 1985 | },
| 1986 | + {
| 1987 | + "cell_type": "markdown",
| 1988 | + "id": "243f146f-1b9a-4574-ba2c-ebf455a96c16",
| 1989 | + "metadata": {},
| 1990 | + "source": [
| 1991 | + "Besides intervention linking, you can also share a single intervention at the same component across multiple positions by setting a flag in the intervention object. This has the same effect as creating one intervention per location and linking them all together."
| 1992 | + ]
| 1993 | + },
| 1994 | + {
| 1995 | + "cell_type": "code",
| 1996 | + "execution_count": 24,
| 1997 | + "id": "7c647943-c7e1-4024-8c07-b51062e668ba",
| 1998 | + "metadata": {},
| 1999 | + "outputs": [
| 2000 | + {
| 2001 | + "name": "stdout",
| 2002 | + "output_type": "stream",
| 2003 | + "text": [
| 2004 | + "loaded model\n"
| 2005 | + ]
| 2006 | + },
| 2007 | + {
| 2008 | + "data": {
| 2009 | + "text/plain": [
| 2010 | + "tensor([[[0., 0., 0., ..., 0., 0., 0.],\n",
| 2011 | + " [0., 0., 0., ..., 0., 0., 0.],\n",
| 2012 | + " [0., 0., 0., ..., 0., 0., 0.],\n",
| 2013 | + " [0., 0., 0., ..., 0., 0., 0.],\n",
| 2014 | + " [0., 0., 0., ..., 0., 0., 0.]]])"
| 2015 | + ]
| 2016 | + },
| 2017 | + "execution_count": 24,
| 2018 | + "metadata": {},
| 2019 | + "output_type": "execute_result"
| 2020 | + }
| 2021 | + ],
| 2022 | + "source": [
| 2023 | + "import torch\n",
| 2024 | + "import pyvene as pv\n",
| 2025 | + "\n",
| 2026 | + "_, tokenizer, gpt2 = pv.create_gpt2()\n",
| 2027 | + "\n",
| 2028 | + "config = pv.IntervenableConfig([\n",
| 2029 | + " # linked via intervention_link_key, so both locations\n",
| 2030 | + " # share a single intervention on the same representation\n",
| 2031 | + " {\"layer\": 0, \"component\": \"block_output\", \"intervention_link_key\": 0},\n",
| 2032 | + " {\"layer\": 0, \"component\": \"block_output\", \"intervention_link_key\": 0}],\n",
| 2033 | + " intervention_types=pv.VanillaIntervention,\n",
| 2034 | + ")\n",
| 2035 | + "pv_gpt2 = pv.IntervenableModel(config, model=gpt2)\n",
| 2036 | + "\n",
| 2037 | + "base = tokenizer(\"The capital of Spain is\", return_tensors=\"pt\")\n",
| 2038 | + "source = tokenizer(\"The capital of Italy is\", return_tensors=\"pt\")\n",
| 2039 | + "\n",
| 2040 | + "_, pv_out = pv_gpt2(\n",
| 2041 | + " base,\n",
| 2042 | + " [source, source],\n",
| 2043 | + " # swap 3rd and 4th token reprs from the same source to the base\n",
| 2044 | + " {\"sources->base\": ([[[4]], [[3]]], [[[4]], [[3]]])},\n",
| 2045 | + ")\n",
| 2046 | + "\n",
| 2047 | + "keep_last_dim_config = pv.IntervenableConfig([\n",
| 2048 | + " # a single intervention shared across both positions\n",
| 2049 | + " # by setting keep_last_dim=True\n",
| 2050 | + " {\"layer\": 0, \"component\": \"block_output\", \n",
| 2051 | + " \"intervention\": pv.VanillaIntervention(keep_last_dim=True)}]\n",
| 2052 | + ")\n",
| 2053 | + "keep_last_dim_pv_gpt2 = pv.IntervenableModel(keep_last_dim_config, model=gpt2)\n",
| 2054 | + "\n",
| 2055 | + "_, keep_last_dim_pv_out = keep_last_dim_pv_gpt2(\n",
| 2056 | + " base,\n",
| 2057 | + " [source],\n",
| 2058 | + " # swap 3rd and 4th token reprs from the same source to the base\n",
| 2059 | + " {\"sources->base\": ([[[3,4]]], [[[3,4]]])},\n",
| 2060 | + ")\n",
| 2061 | + "keep_last_dim_pv_out.last_hidden_state - pv_out.last_hidden_state"
| 2062 | + ]
| 2063 | + },
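
The all-zeros difference tensor in the cell's output shows the two formulations agree. A minimal programmatic check, reusing the variables defined in the new cell (an illustrative sketch, not part of this commit):

    # keep_last_dim=True should reproduce the linked-intervention result exactly
    assert torch.allclose(keep_last_dim_pv_out.last_hidden_state,
                          pv_out.last_hidden_state)
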
1981 | 2064 | {
1982 | 2065 | "cell_type": "markdown",
1983 | 2066 | "id": "ef5b7a3e",