@@ -75,7 +75,7 @@ def set_interchange_dim(self, interchange_dim):
         self.interchange_dim = interchange_dim
 
     @abstractmethod
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         pass
 
 
@@ -153,7 +153,7 @@ class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationInterve
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def forward(self, base, source=None, subspaces=None):
+    def forward(self, base, source=None, subspaces=None, **kwargs):
         return _do_intervention_by_swap(
             base,
             torch.zeros_like(base),
@@ -175,7 +175,7 @@ class CollectIntervention(ConstantSourceIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def forward(self, base, source=None, subspaces=None):
+    def forward(self, base, source=None, subspaces=None, **kwargs):
         return _do_intervention_by_swap(
             base,
             source,
@@ -197,7 +197,7 @@ class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationInterven
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         # source here is the base example input to the hook
         return _do_intervention_by_swap(
             base,
@@ -220,7 +220,7 @@ class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         return _do_intervention_by_swap(
             base,
             source if self.source_representation is None else self.source_representation,
@@ -242,7 +242,7 @@ class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationInte
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         return _do_intervention_by_swap(
             base,
             source if self.source_representation is None else self.source_representation,
@@ -264,7 +264,7 @@ class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationI
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
 
         return _do_intervention_by_swap(
             base,
@@ -289,7 +289,7 @@ def __init__(self, **kwargs):
         rotate_layer = RotateLayer(self.embed_dim)
         self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
         # interchange
@@ -340,7 +340,7 @@ def set_intervention_boundaries(self, intervention_boundaries):
             torch.tensor([intervention_boundaries]), requires_grad=True
         )
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         batch_size = base.shape[0]
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
@@ -391,7 +391,7 @@ def get_temperature(self):
     def set_temperature(self, temp: torch.Tensor):
         self.temperature.data = temp
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         batch_size = base.shape[0]
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
@@ -431,7 +431,7 @@ def get_temperature(self):
     def set_temperature(self, temp: torch.Tensor):
         self.temperature.data = temp
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         batch_size = base.shape[0]
         # get boundary mask between 0 and 1 from sigmoid
         mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))
@@ -456,7 +456,7 @@ def __init__(self, **kwargs):
         rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
         self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         rotated_base = self.rotate_layer(base)
         rotated_source = self.rotate_layer(source)
         if subspaces is not None:
@@ -529,7 +529,7 @@ def __init__(self, **kwargs):
         )
         self.trainable = False
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         base_norm = (base - self.pca_mean) / self.pca_std
         source_norm = (source - self.pca_mean) / self.pca_std
 
@@ -565,7 +565,7 @@ def __init__(self, **kwargs):
             prng(1, 4, self.embed_dim)))
         self.register_buffer('noise_level', torch.tensor(noise_level))
 
-    def forward(self, base, source=None, subspaces=None):
+    def forward(self, base, source=None, subspaces=None, **kwargs):
         base[..., : self.interchange_dim] += self.noise * self.noise_level
         return base
 
@@ -585,7 +585,7 @@ def __init__(self, **kwargs):
         self.autoencoder = AutoencoderLayer(
             self.embed_dim, kwargs["latent_dim"])
 
-    def forward(self, base, source, subspaces=None):
+    def forward(self, base, source, subspaces=None, **kwargs):
         base_dtype = base.dtype
         base = base.to(self.autoencoder.encoder[0].weight.dtype)
         base_latent = self.autoencoder.encode(base)
@@ -619,7 +619,7 @@ def encode(self, input_acts):
     def decode(self, acts):
         return acts @ self.W_dec + self.b_dec
 
-    def forward(self, base, source=None, subspaces=None):
+    def forward(self, base, source=None, subspaces=None, **kwargs):
         # generate latents for base and source runs.
         base_latent = self.encode(base)
         source_latent = self.encode(source)
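The change is uniform across every hunk: each `forward` signature gains a trailing `**kwargs`, so callers can thread extra keyword arguments through an intervention while subclasses that do not use them remain unaffected. Below is a minimal, self-contained sketch of that pattern; the class names and the `scale` keyword are illustrative only and are not part of this diff or of the library's API.

```python
import torch

class BaseIntervention(torch.nn.Module):
    """Toy stand-in for the abstract intervention in the diff."""
    def forward(self, base, source, subspaces=None, **kwargs):
        raise NotImplementedError

class SwapIntervention(BaseIntervention):
    """Ignores extra kwargs, like the simple interventions above."""
    def forward(self, base, source, subspaces=None, **kwargs):
        return source

class ScaledSwapIntervention(BaseIntervention):
    """Consumes an optional runtime kwarg that older call sites never pass."""
    def forward(self, base, source, subspaces=None, **kwargs):
        scale = kwargs.get("scale", 1.0)  # hypothetical extra argument
        return base + scale * (source - base)

base = torch.zeros(2, 4)
source = torch.ones(2, 4)
print(SwapIntervention()(base, source))                   # no extra kwargs needed
print(ScaledSwapIntervention()(base, source, scale=0.5))  # extra kwarg flows through
```

Because `**kwargs` silently absorbs unknown keywords, existing subclasses stay call-compatible while new ones can opt in to additional runtime arguments.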