@@ -32,6 +32,7 @@ def get_model(name: str, device: str = "cuda") -> SentenceTransformer:
 
 def prepare_non_activating_examples(
     tokens: Int[Tensor, "examples ctx_len"],
+    activations: Float[Tensor, "examples ctx_len"],
     distance: float,
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
 ) -> list[NonActivatingExample]:
@@ -45,12 +46,12 @@ def prepare_non_activating_examples(
     return [
         NonActivatingExample(
             tokens=toks,
-            activations=torch.zeros_like(toks, dtype=torch.float),
+            activations=acts,
             normalized_activations=None,
             distance=distance,
             str_tokens=tokenizer.batch_decode(toks),
         )
-        for toks in tokens
+        for toks, acts in zip(tokens, activations)
     ]
 
 
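A minimal usage sketch of the updated helper, outside the patch itself: callers now pass one activation row per token window, with zeros as a placeholder when no real activations exist (as the faiss and random call sites further down do). The "gpt2" tokenizer and the random token ids are assumptions for illustration only.

# Sketch assuming prepare_non_activating_examples is imported from this module.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer
selected_tokens = torch.randint(0, tokenizer.vocab_size, (4, 32))  # 4 windows of 32 tokens

examples = prepare_non_activating_examples(
    selected_tokens,
    torch.zeros_like(selected_tokens),  # placeholder activations
    -1.0,  # sentinel distance for non-neighbour examples
    tokenizer,
)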
@@ -113,11 +114,9 @@ def pool_max_activation_windows(
 
     # Get the max activation magnitude within each context window
     max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)
-
     # Deduplicate the context windows
     new_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
     new_tensor[inverses, index_within_ctx] = activations
-
     tokens = tokens[unique_ctx_indices]
 
     token_windows, activation_windows = _top_k_pools(
@@ -127,6 +126,114 @@ def pool_max_activation_windows(
     return token_windows, activation_windows
 
 
+def pool_centered_activation_windows(
+    activations: Float[Tensor, "examples"],
+    tokens: Float[Tensor, "windows seq"],
+    n_windows_per_batch: int,
+    ctx_indices: Float[Tensor, "examples"],
+    index_within_ctx: Float[Tensor, "examples"],
+    ctx_len: int,
+    max_examples: int,
+) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
138+ """
139+ Similar to pool_max_activation_windows. Doesn't use the ctx_indices that were
140+ at the start of the batch or the end of the batch, because it always tries
141+ to have a buffer of ctx_len*5//6 on the left and ctx_len*1//6 on the right.
142+ To do this, for each window, it joins the contexts from the other windows
143+ of the same batch, to form a new context, which is then cut to the correct shape,
144+ centered on the max activation.
145+
146+ Args:
147+ activations : The activations.
148+ tokens : The input tokens.
149+ ctx_indices : The context indices.
150+ index_within_ctx : The index within the context.
151+ ctx_len : The context length.
152+ max_examples : The maximum number of examples.
153+ """
+
+    # Get unique context indices and their counts like in pool_max_activation_windows
+    unique_ctx_indices, inverses, lengths = torch.unique_consecutive(
+        ctx_indices, return_counts=True, return_inverse=True
+    )
+
+    # Get the max activation magnitude within each context window
+    max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)
+
+    # Get the top max_examples windows
+    k = min(max_examples, len(max_buffer))
+    top_values, top_indices = torch.topk(max_buffer, k, sorted=True)
+
+    # this tensor has the correct activations for each context window
+    temp_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
+    temp_tensor[inverses, index_within_ctx] = activations
+
+    unique_ctx_indices = unique_ctx_indices[top_indices]
+    temp_tensor = temp_tensor[top_indices]
+
+    # If an element of unique_ctx_indices is divisible by n_windows_per_batch,
+    # it is the first window of a batch, so we discard it
+    modulo = unique_ctx_indices % n_windows_per_batch
+    not_first_position = modulo != 0
+    # Also discard windows that are at the end of a batch
+    not_last_position = modulo != n_windows_per_batch - 1
+    mask = not_first_position & not_last_position
+    unique_ctx_indices = unique_ctx_indices[mask]
+    temp_tensor = temp_tensor[mask]
+    if len(unique_ctx_indices) == 0:
+        return torch.zeros(0, ctx_len), torch.zeros(0, ctx_len)
+
+    # Vectorized operations for all windows at once
+    n_windows = len(unique_ctx_indices)
+
+    # Create indices for previous, current, and next windows
+    prev_indices = unique_ctx_indices - 1
+    next_indices = unique_ctx_indices + 1
+
+    # Create a tensor to hold all concatenated tokens
+    all_tokens = torch.cat(
+        [tokens[prev_indices], tokens[unique_ctx_indices], tokens[next_indices]], dim=1
+    )  # Shape: [n_windows, ctx_len*3]
+
+    # Create tensor for all activations
+    final_tensor = torch.zeros((n_windows, ctx_len * 3), dtype=activations.dtype)
+    final_tensor[:, ctx_len : ctx_len * 2] = (
+        temp_tensor  # Set current window activations
+    )
+
+    # Set previous window activations where available
+    prev_mask = torch.isin(prev_indices, unique_ctx_indices)
+    if prev_mask.any():
+        prev_locations = torch.where(
+            unique_ctx_indices.unsqueeze(1) == prev_indices.unsqueeze(0)
+        )[1]
+        final_tensor[prev_mask, :ctx_len] = temp_tensor[prev_locations]
+
+    # Set next window activations where available
+    next_mask = torch.isin(next_indices, unique_ctx_indices)
+    if next_mask.any():
+        next_locations = torch.where(
+            unique_ctx_indices.unsqueeze(1) == next_indices.unsqueeze(0)
+        )[1]
+        final_tensor[next_mask, ctx_len * 2 :] = temp_tensor[next_locations]
+
+    # Find the max activation index within the concatenated buffer
+    max_activation_indices = torch.argmax(temp_tensor, dim=1) + ctx_len
+
+    # Calculate the left edge of the output window for all windows
+    left_positions = max_activation_indices - (ctx_len - ctx_len // 4)
+
+    # Create index tensors for gathering
+    batch_indices = torch.arange(n_windows).unsqueeze(1)
+    pos_indices = torch.arange(ctx_len).unsqueeze(0)
+    gather_indices = left_positions.unsqueeze(1) + pos_indices
+
+    # Gather the final windows
+    token_windows = all_tokens[batch_indices, gather_indices]
+    activation_windows = final_tensor[batch_indices, gather_indices]
+    return token_windows, activation_windows
+
+
 def constructor(
     record: LatentRecord,
     activation_data: ActivationData,
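A small worked example of the centering arithmetic in pool_centered_activation_windows, outside the patch: each kept window is concatenated with its left and right neighbours into a 3*ctx_len buffer, the max activation is located at argmax + ctx_len, and the output window starts ctx_len - ctx_len // 4 tokens to its left. The ctx_len value and activation position below are arbitrary assumptions.

import torch

ctx_len = 32
temp_row = torch.zeros(ctx_len)
temp_row[5] = 2.5  # assume the max activation sits at offset 5 of the middle window

max_idx = int(torch.argmax(temp_row)) + ctx_len            # position 37 in the 96-token buffer
left = max_idx - (ctx_len - ctx_len // 4)                  # 37 - 24 = 13
window = torch.arange(3 * ctx_len)[left : left + ctx_len]  # buffer positions 13..44

assert window.shape[0] == ctx_len
assert window[ctx_len - ctx_len // 4] == max_idx           # the max token lands 24 tokens in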
@@ -142,49 +249,55 @@ def constructor(
     n_not_active = constructor_cfg.n_non_activating
     max_examples = constructor_cfg.max_examples
     min_examples = constructor_cfg.min_examples
-
     # Get all positions where the latent is active
     flat_indices = (
         activation_data.locations[:, 0] * cache_ctx_len
         + activation_data.locations[:, 1]
     )
     ctx_indices = flat_indices // example_ctx_len
     index_within_ctx = flat_indices % example_ctx_len
+    n_windows_per_batch = tokens.shape[1] // example_ctx_len
     reshaped_tokens = tokens.reshape(-1, example_ctx_len)
     n_windows = reshaped_tokens.shape[0]
-
     unique_batch_pos = ctx_indices.unique()
-
     mask = torch.ones(n_windows, dtype=torch.bool)
     mask[unique_batch_pos] = False
     # Indices where the latent is not active
     non_active_indices = mask.nonzero(as_tuple=False).squeeze()
     activations = activation_data.activations
-
     # per context frequency
     record.per_context_frequency = len(unique_batch_pos) / n_windows
 
     # Add activation examples to the record in place
-    token_windows, act_windows = pool_max_activation_windows(
-        activations=activations,
-        tokens=reshaped_tokens,
-        ctx_indices=ctx_indices,
-        index_within_ctx=index_within_ctx,
-        ctx_len=example_ctx_len,
-        max_examples=max_examples,
-    )
+    if constructor_cfg.center_examples:
+        token_windows, act_windows = pool_max_activation_windows(
+            activations=activations,
+            tokens=reshaped_tokens,
+            ctx_indices=ctx_indices,
+            index_within_ctx=index_within_ctx,
+            ctx_len=example_ctx_len,
+            max_examples=max_examples,
+        )
+    else:
+        token_windows, act_windows = pool_centered_activation_windows(
+            activations=activations,
+            tokens=reshaped_tokens,
+            n_windows_per_batch=n_windows_per_batch,
+            ctx_indices=ctx_indices,
+            index_within_ctx=index_within_ctx,
+            ctx_len=example_ctx_len,
+            max_examples=max_examples,
+        )
     # TODO: We might want to do this in the sampler
     # we are tokenizing examples that are not going to be used
     record.examples = [
         ActivatingExample(
             tokens=toks,
             activations=acts,
             normalized_activations=None,
-            str_tokens=tokenizer.batch_decode(toks),
         )
         for toks, acts in zip(token_windows, act_windows)
     ]
-
     if len(record.examples) < min_examples:
         # Not enough examples to explain the latent
         return None
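For intuition on the new n_windows_per_batch value and the edge-window filter it feeds in pool_centered_activation_windows, here is a toy calculation outside the patch; the 256-token cache rows and 32-token example windows are assumed values.

cache_row_len = 256       # assumed width of each cached token row (tokens.shape[1])
example_ctx_len = 32      # assumed example window length
n_windows_per_batch = cache_row_len // example_ctx_len  # 8 windows per cached row

# Windows at a row boundary are dropped, since they lack a neighbour on one side.
for ctx_index in (16, 17, 23, 24):
    modulo = ctx_index % n_windows_per_batch
    keep = modulo not in (0, n_windows_per_batch - 1)
    print(ctx_index, "kept" if keep else "dropped")
# 16 dropped (first window of a row), 17 kept, 23 dropped (last), 24 dropped (first)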
@@ -420,6 +533,7 @@ def faiss_non_activation_windows(
     # Create non-activating examples
     return prepare_non_activating_examples(
         selected_tokens,
+        torch.zeros_like(selected_tokens),
         -1.0,  # Using -1.0 as the distance since these are not neighbour-based
         tokenizer,
     )
@@ -495,7 +609,7 @@ def neighbour_non_activation_windows(
         # If there are no available indices, skip this neighbour
         if activations.numel() == 0:
             continue
-        token_windows, _ = pool_max_activation_windows(
+        token_windows, token_activations = pool_max_activation_windows(
             activations=activations,
             tokens=reshaped_tokens,
             ctx_indices=available_ctx_indices,
@@ -508,12 +622,23 @@ def neighbour_non_activation_windows(
         examples_used = len(token_windows)
         all_examples.extend(
             prepare_non_activating_examples(
-                token_windows, -neighbour.distance, tokenizer
+                token_windows,
+                token_activations,  # activations of the neighbour
+                -neighbour.distance,
+                tokenizer,
             )
         )
         number_examples += examples_used
     if len(all_examples) == 0:
-        print("No examples found")
+        print("No examples found, falling back to random non-activating examples")
+        non_active_indices = not_active_mask.nonzero(as_tuple=False).squeeze()
+
+        return random_non_activating_windows(
+            available_indices=non_active_indices,
+            reshaped_tokens=reshaped_tokens,
+            n_not_active=n_not_active,
+            tokenizer=tokenizer,
+        )
     return all_examples
 
 
@@ -554,6 +679,7 @@ def random_non_activating_windows(
 
     return prepare_non_activating_examples(
         toks,
+        torch.zeros_like(toks),  # there is no way to define these activations
         -1.0,
         tokenizer,
     )