
Commit f6dbee1

Merge pull request #191 from eggachecat/main

feat: add intervenable_model to forward's function signature

2 parents: 64ad99f + a6b0228

File tree: 5 files changed (+68 −20 lines)

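What the change does: extra keyword arguments passed to IntervenableModel.forward (and generate) are now threaded through do_intervention down to every intervention's forward, so an intervention can receive runtime context such as the wrapping IntervenableModel itself. Below is a minimal sketch of the pattern, modeled on the Llama test added in this commit; the checkpoint name, the prompt, and the LogitLensIntervention class are illustrative assumptions, not part of the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pyvene import IntervenableModel

class LogitLensIntervention:
    # duck-typed intervention, mirroring AccessIntervenableModelIntervention
    # from the new test below
    is_source_constant = True
    keep_last_dim = True
    intervention_types = "logit_lens_intervention"

    def __init__(self, collected, layer_index, **kwargs):
        self.collected = collected
        self.layer_index = layer_index

    def __call__(self, base, source=None, subspaces=None, **kwargs):
        # `intervenable_model` shows up here only because forward() now
        # forwards arbitrary **kwargs to each intervention call
        iv_model = kwargs.get("intervenable_model", None)
        if iv_model is not None:
            # decode the intermediate residual stream with the model's own lm_head
            self.collected[self.layer_index] = iv_model.model.lm_head(base)
        return base  # pass the activation through unchanged

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
llama = AutoModelForCausalLM.from_pretrained(model_name)

collected = {}
pv_llama = IntervenableModel([
    {"intervention": LogitLensIntervention(collected, layer),
     "component": f"model.layers.{layer}.input"}
    for layer in [1, 3]
], model=llama)

outputs = pv_llama(
    base=tokenizer("The capital of Spain is", return_tensors="pt"),
    unit_locations={"base": 3},
    intervenable_model=pv_llama,  # any extra kwarg now reaches the interventions
)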

pyvene/models/intervenable_base.py

Lines changed: 22 additions & 1 deletion
@@ -804,6 +804,7 @@ def _intervention_setter(
         keys,
         unit_locations_base,
         subspaces,
+        **intervention_forward_kwargs
     ) -> HandlerList:
         """
         Create a list of setter tracer that will set activations
@@ -848,6 +849,7 @@ def _intervention_setter(
                             None,
                             intervention,
                             subspaces[key_i] if subspaces is not None else None,
+                            **intervention_forward_kwargs
                         )
                         # fail if this is not a fresh collect
                         assert key not in self.activations
@@ -862,6 +864,7 @@ def _intervention_setter(
                             None,
                             intervention,
                             subspaces[key_i] if subspaces is not None else None,
+                            **intervention_forward_kwargs
                         )
                     else:
                         intervened_representation = do_intervention(
@@ -873,6 +876,7 @@ def _intervention_setter(
                             ),
                             intervention,
                             subspaces[key_i] if subspaces is not None else None,
+                            **intervention_forward_kwargs
                         )
                 else:
                     # highly unlikely it's a primitive intervention type
@@ -885,6 +889,7 @@ def _intervention_setter(
                         ),
                         intervention,
                         subspaces[key_i] if subspaces is not None else None,
+                        **intervention_forward_kwargs
                     )
                 if intervened_representation is None:
                     return
@@ -970,6 +975,7 @@ def _sync_forward_with_parallel_intervention(
                         ]
                         if subspaces is not None
                         else None,
+                        **kwargs
                     )
             counterfactual_outputs = self.model.output.save()

@@ -997,6 +1003,7 @@ def forward(
         output_original_output: Optional[bool] = False,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
+        **kwargs
     ):
         activations_sources = source_representations
         if sources is not None and not isinstance(sources, list):
@@ -1036,7 +1043,7 @@ def forward(
         try:

             # run intervened forward
-            model_kwargs = {}
+            model_kwargs = { **kwargs }
             if labels is not None: # for training
                 model_kwargs["labels"] = labels
             if use_cache is not None and 'use_cache' in self.model.config.to_dict(): # for transformer models
@@ -1526,6 +1533,7 @@ def _intervention_setter(
         keys,
         unit_locations_base,
         subspaces,
+        **intervention_forward_kwargs
     ) -> HandlerList:
         """
         Create a list of setter handlers that will set activations
@@ -1573,6 +1581,7 @@ def hook_callback(model, args, kwargs, output=None):
                             None,
                             intervention,
                             subspaces[key_i] if subspaces is not None else None,
+                            **intervention_forward_kwargs
                         )
                         # fail if this is not a fresh collect
                         assert key not in self.activations
@@ -1588,6 +1597,7 @@ def hook_callback(model, args, kwargs, output=None):
                             None,
                             intervention,
                             subspaces[key_i] if subspaces is not None else None,
+                            **intervention_forward_kwargs
                         )
                         if isinstance(raw_intervened_representation, InterventionOutput):
                             self.full_intervention_outputs.append(raw_intervened_representation)
@@ -1604,6 +1614,7 @@ def hook_callback(model, args, kwargs, output=None):
                             ),
                             intervention,
                             subspaces[key_i] if subspaces is not None else None,
+                            **intervention_forward_kwargs
                         )
                 else:
                     # highly unlikely it's a primitive intervention type
@@ -1616,6 +1627,7 @@ def hook_callback(model, args, kwargs, output=None):
                         ),
                         intervention,
                         subspaces[key_i] if subspaces is not None else None,
+                        **intervention_forward_kwargs
                     )
                 if intervened_representation is None:
                     return
@@ -1683,6 +1695,7 @@ def _wait_for_forward_with_parallel_intervention(
         unit_locations,
         activations_sources: Optional[Dict] = None,
         subspaces: Optional[List] = None,
+        **intervention_forward_kwargs
     ):
         # torch.autograd.set_detect_anomaly(True)
         all_set_handlers = HandlerList([])
@@ -1738,6 +1751,7 @@ def _wait_for_forward_with_parallel_intervention(
                 ]
                 if subspaces is not None
                 else None,
+                **intervention_forward_kwargs
             )
             # for setters, we don't remove them.
             all_set_handlers.extend(set_handlers)
@@ -1749,6 +1763,7 @@ def _wait_for_forward_with_serial_intervention(
         unit_locations,
         activations_sources: Optional[Dict] = None,
         subspaces: Optional[List] = None,
+        **intervention_forward_kwargs
     ):
         all_set_handlers = HandlerList([])
         for group_id, keys in self._intervention_group.items():
@@ -1805,6 +1820,7 @@ def _wait_for_forward_with_serial_intervention(
                 ]
                 if subspaces is not None
                 else None,
+                **intervention_forward_kwargs
             )
             # for setters, we don't remove them.
             all_set_handlers.extend(set_handlers)
@@ -1821,6 +1837,7 @@ def forward(
         output_original_output: Optional[bool] = False,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
+        **intervention_forward_kwargs
     ):
         """
         Main forward function that serves a wrapper to
@@ -1929,6 +1946,7 @@ def forward(
                     unit_locations,
                     activations_sources,
                     subspaces,
+                    **intervention_forward_kwargs
                 )
             )
         elif self.mode == "serial":
@@ -1938,6 +1956,7 @@ def forward(
                     unit_locations,
                     activations_sources,
                     subspaces,
+                    **intervention_forward_kwargs
                 )
             )

@@ -2071,6 +2090,7 @@ def generate(
                     unit_locations,
                     activations_sources,
                     subspaces,
+                    **kwargs
                 )
             )
         elif self.mode == "serial":
@@ -2080,6 +2100,7 @@ def generate(
                     unit_locations,
                     activations_sources,
                     subspaces,
+                    **kwargs
                 )
             )
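
Note the asymmetry between the two forward paths patched above: in the first forward (which, judging by the self.model.output.save() idiom, belongs to the tracer-based wrapper), the incoming **kwargs are also merged into model_kwargs via model_kwargs = { **kwargs }, so they reach the wrapped model's call in addition to the interventions; in the hook-based forward further down, **intervention_forward_kwargs flows only into the parallel/serial intervention setters.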

pyvene/models/interventions.py

Lines changed: 16 additions & 16 deletions
@@ -75,7 +75,7 @@ def set_interchange_dim(self, interchange_dim):
         self.interchange_dim = interchange_dim

     @abstractmethod
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         pass


@@ -153,7 +153,7 @@ class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationInterven
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source=None, subspaces=None):
+    def forward(self, base, source=None, subspaces=None, **kwargs):
         return _do_intervention_by_swap(
             base,
             torch.zeros_like(base),
@@ -175,7 +175,7 @@ class CollectIntervention(ConstantSourceIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source=None, subspaces=None):
+    def forward(self, base, source=None, subspaces=None, **kwargs):
         return _do_intervention_by_swap(
             base,
             source,
@@ -197,7 +197,7 @@ class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationInterven
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         # source here is the base example input to the hook
         return _do_intervention_by_swap(
             base,
@@ -220,7 +220,7 @@ class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         return _do_intervention_by_swap(
             base,
             source if self.source_representation is None else self.source_representation,
@@ -242,7 +242,7 @@ class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationInte
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         return _do_intervention_by_swap(
             base,
             source if self.source_representation is None else self.source_representation,
@@ -264,7 +264,7 @@ class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationI
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):

         return _do_intervention_by_swap(
             base,
@@ -289,7 +289,7 @@ def __init__(self, **kwargs):
         rotate_layer = RotateLayer(self.embed_dim)
         self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
         # interchange
@@ -340,7 +340,7 @@ def set_intervention_boundaries(self, intervention_boundaries):
             torch.tensor([intervention_boundaries]), requires_grad=True
         )

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         batch_size = base.shape[0]
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
@@ -391,7 +391,7 @@ def get_temperature(self):
     def set_temperature(self, temp: torch.Tensor):
         self.temperature.data = temp

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         batch_size = base.shape[0]
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
@@ -431,7 +431,7 @@ def get_temperature(self):
     def set_temperature(self, temp: torch.Tensor):
         self.temperature.data = temp

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         batch_size = base.shape[0]
         # get boundary mask between 0 and 1 from sigmoid
         mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))
@@ -456,7 +456,7 @@ def __init__(self, **kwargs):
         rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
         self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
         if subspaces is not None:
@@ -529,7 +529,7 @@ def __init__(self, **kwargs):
         )
         self.trainable = False

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         base_norm = (base - self.pca_mean) / self.pca_std
         source_norm = (source - self.pca_mean) / self.pca_std

@@ -565,7 +565,7 @@ def __init__(self, **kwargs):
             prng(1, 4, self.embed_dim)))
         self.register_buffer('noise_level', torch.tensor(noise_level))

-    def forward(self, base, source=None, subspaces=None):
+    def forward(self, base, source=None, subspaces=None, **kwargs):
         base[..., : self.interchange_dim] += self.noise * self.noise_level
         return base

@@ -585,7 +585,7 @@ def __init__(self, **kwargs):
         self.autoencoder = AutoencoderLayer(
             self.embed_dim, kwargs["latent_dim"])

-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         base_dtype = base.dtype
         base = base.to(self.autoencoder.encoder[0].weight.dtype)
         base_latent = self.autoencoder.encode(base)
@@ -619,7 +619,7 @@ def encode(self, input_acts):
     def decode(self, acts):
         return acts @ self.W_dec + self.b_dec

-    def forward(self, base, source=None, subspaces=None):
+    def forward(self, base, source=None, subspaces=None, **kwargs):
         # generate latents for base and source runs.
         base_latent = self.encode(base)
         source_latent = self.encode(source)
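
Every built-in intervention's forward now tolerates extra keyword arguments, and the abstract base signature requires them, so custom interventions that override forward should do the same to stay compatible. A minimal sketch of such a subclass, following the MultiplierIntervention pattern from the pyvene tutorials (the class name and the scale kwarg are hypothetical):

import pyvene as pv

class ScaleIntervention(pv.ConstantSourceIntervention, pv.LocalistRepresentationIntervention):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, base, source=None, subspaces=None, **kwargs):
        # kwargs carries anything extra passed to IntervenableModel.forward,
        # e.g. pv_model(base=..., scale=0.5) -> kwargs["scale"] == 0.5
        return base * kwargs.get("scale", 1.0)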

pyvene/models/modeling_utils.py

Lines changed: 3 additions & 2 deletions
@@ -446,7 +446,7 @@ def scatter_neurons(


 def do_intervention(
-    base_representation, source_representation, intervention, subspaces
+    base_representation, source_representation, intervention, subspaces, **intervention_forward_kwargs
 ):
     """Do the actual intervention."""

@@ -478,7 +478,8 @@ def do_intervention(
         assert False # what's going on?

     intervention_output = intervention(
-        base_representation_f, source_representation_f, subspaces
+        base_representation_f, source_representation_f, subspaces,
+        **intervention_forward_kwargs
     )
     if isinstance(intervention_output, InterventionOutput):
         intervened_representation = intervention_output.output
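
This is the choke point of the feature: the setters in intervenable_base.py invoke module-based interventions through do_intervention, so widening its signature here is what actually delivers the extra keyword arguments to each intervention's forward.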

tests/integration_tests/IntervenableBasicTestCase.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ class MultiplierIntervention(
     def __init__(self, embed_dim, **kwargs):
         super().__init__()
     def forward(
-        self, base, source=None, subspaces=None):
+        self, base, source=None, subspaces=None, **kwargs):
         return base * 99.0
 # run with new intervention type
 pv_gpt2 = pv.IntervenableModel({

tests/integration_tests/InterventionWithLlamaTestCase.py

Lines changed: 26 additions & 0 deletions
@@ -156,6 +156,32 @@ def test_with_multiple_heads_positions_vanilla_intervention_positive(self):
             heads=[4, 1],
             positions=[7, 2],
         )
+
+    def test_with_llm_head(self):
+        that = self
+        _lm_head_collection = {}
+        class AccessIntervenableModelIntervention:
+            is_source_constant = True
+            keep_last_dim = True
+            intervention_types = 'access_intervenable_model_intervention'
+            def __init__(self, layer_index, *args, **kwargs):
+                super().__init__()
+                self.layer_index = layer_index
+            def __call__(self, base, source=None, subspaces=None, model=None, **kwargs):
+                intervenable_model = kwargs.get('intervenable_model', None)
+                assert intervenable_model is not None
+                _lm_head_collection[self.layer_index] = intervenable_model.model.lm_head(base.to(that.device))
+                return base
+        # run with new intervention type
+        pv_llama = IntervenableModel([{
+            "intervention": AccessIntervenableModelIntervention(layer_index=layer),
+            "component": f"model.layers.{layer}.input"
+        } for layer in [1, 3]], model=self.llama)
+        intervened_outputs = pv_llama(
+            base=self.tokenizer("The capital of Spain is", return_tensors="pt").to(that.device),
+            unit_locations={"base": 3},
+            intervenable_model=pv_llama
+        )


 def suite():
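
The new test exercises the plumbing end to end: the model is called with intervenable_model=pv_llama, the intervention asserts that the object actually arrives through **kwargs, and it then uses it to run lm_head on intermediate layer inputs, a logit-lens-style readout that was not possible before this change.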
