@@ -32,6 +32,7 @@ def get_model(name: str, device: str = "cuda") -> SentenceTransformer:
 
 def prepare_non_activating_examples(
     tokens: Int[Tensor, "examples ctx_len"],
+    activations: Float[Tensor, "examples ctx_len"],
     distance: float,
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
 ) -> list[NonActivatingExample]:
@@ -45,12 +46,12 @@ def prepare_non_activating_examples(
     return [
         NonActivatingExample(
             tokens=toks,
-            activations=torch.zeros_like(toks, dtype=torch.float),
+            activations=acts,
             normalized_activations=None,
             distance=distance,
             str_tokens=tokenizer.batch_decode(toks),
         )
-        for toks in tokens
+        for toks, acts in zip(tokens, activations)
     ]
 
 
@@ -113,11 +114,9 @@ def pool_max_activation_windows(
 
     # Get the max activation magnitude within each context window
     max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)
-
     # Deduplicate the context windows
     new_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
     new_tensor[inverses, index_within_ctx] = activations
-
     tokens = tokens[unique_ctx_indices]
 
     token_windows, activation_windows = _top_k_pools(
@@ -127,6 +126,114 @@ def pool_max_activation_windows(
     return token_windows, activation_windows
 
 
+def pool_centered_activation_windows(
+    activations: Float[Tensor, "examples"],
+    tokens: Float[Tensor, "windows seq"],
+    n_windows_per_batch: int,
+    ctx_indices: Float[Tensor, "examples"],
+    index_within_ctx: Float[Tensor, "examples"],
+    ctx_len: int,
+    max_examples: int,
+) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
+    """
+    Similar to pool_max_activation_windows, but does not use context windows
+    that fall at the start or end of a batch, because it always tries to keep
+    a buffer of roughly ctx_len * 3 // 4 tokens to the left of the max
+    activation and ctx_len // 4 tokens to the right. To do this, each window
+    is joined with its neighbouring windows from the same batch to form a
+    longer context, which is then cut back to ctx_len tokens centered on the
+    max activation.
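+    For example, with ctx_len = 32 the max-activating token lands at index 24
+    (ctx_len - ctx_len // 4) of the returned window, with 24 tokens of context
+    to its left and 7 to its right.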
+
+    Args:
+        activations: The activations.
+        tokens: The input tokens.
+        n_windows_per_batch: The number of context windows per batch sequence.
+        ctx_indices: The context indices.
+        index_within_ctx: The index within the context.
+        ctx_len: The context length.
+        max_examples: The maximum number of examples.
+    """
+
+    # Get unique context indices and their counts like in pool_max_activation_windows
+    unique_ctx_indices, inverses, lengths = torch.unique_consecutive(
+        ctx_indices, return_counts=True, return_inverse=True
+    )
+
+    # Get the max activation magnitude within each context window
+    max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)
+
+    # Get the top max_examples windows
+    k = min(max_examples, len(max_buffer))
+    top_values, top_indices = torch.topk(max_buffer, k, sorted=True)
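+    # top_indices selects the k strongest windows; top_values is not used further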
+
+    # this tensor has the correct activations for each context window
+    temp_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
+    temp_tensor[inverses, index_within_ctx] = activations
+
+    unique_ctx_indices = unique_ctx_indices[top_indices]
+    temp_tensor = temp_tensor[top_indices]
+
+    # If an element of unique_ctx_indices is divisible by n_windows_per_batch,
+    # it is the first window of a batch, so we discard it
+    modulo = unique_ctx_indices % n_windows_per_batch
+    not_first_position = modulo != 0
+    # Also remove the elements that are at the end of a batch
+    not_last_position = modulo != n_windows_per_batch - 1
+    mask = not_first_position & not_last_position
+    unique_ctx_indices = unique_ctx_indices[mask]
+    temp_tensor = temp_tensor[mask]
+    if len(unique_ctx_indices) == 0:
+        return torch.zeros(0, ctx_len), torch.zeros(0, ctx_len)
+
+    # Vectorized operations for all windows at once
+    n_windows = len(unique_ctx_indices)
+
+    # Create indices for previous, current, and next windows
+    prev_indices = unique_ctx_indices - 1
+    next_indices = unique_ctx_indices + 1
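+    # Adjacent window indices correspond to adjacent spans of the same sequence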
+
+    # Create a tensor to hold all concatenated tokens
+    all_tokens = torch.cat(
+        [tokens[prev_indices], tokens[unique_ctx_indices], tokens[next_indices]], dim=1
+    )  # Shape: [n_windows, ctx_len*3]
+
+    # Create tensor for all activations
+    final_tensor = torch.zeros((n_windows, ctx_len * 3), dtype=activations.dtype)
+    final_tensor[:, ctx_len : ctx_len * 2] = (
+        temp_tensor  # Set current window activations
+    )
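+    # Layout of all_tokens and final_tensor: [previous window | current window | next window]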
+
+    # Set previous window activations where available
+    prev_mask = torch.isin(prev_indices, unique_ctx_indices)
+    if prev_mask.any():
+        prev_locations = torch.where(
+            unique_ctx_indices.unsqueeze(1) == prev_indices.unsqueeze(0)
+        )[1]
+        final_tensor[prev_mask, :ctx_len] = temp_tensor[prev_locations]
+
+    # Set next window activations where available
+    next_mask = torch.isin(next_indices, unique_ctx_indices)
+    if next_mask.any():
+        next_locations = torch.where(
+            unique_ctx_indices.unsqueeze(1) == next_indices.unsqueeze(0)
+        )[1]
+        final_tensor[next_mask, ctx_len * 2 :] = temp_tensor[next_locations]
+
+    # Find max activation indices
+    max_activation_indices = torch.argmax(temp_tensor, dim=1) + ctx_len
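+    # The current window occupies the middle third of final_tensor, hence the ctx_len offset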
+
+    # Calculate the left edge for all windows
+    left_positions = max_activation_indices - (ctx_len - ctx_len // 4)
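+    # Places the max-activating token ctx_len - ctx_len // 4 tokens into the new window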
+
+    # Create index tensors for gathering
+    batch_indices = torch.arange(n_windows).unsqueeze(1)
+    pos_indices = torch.arange(ctx_len).unsqueeze(0)
+    gather_indices = left_positions.unsqueeze(1) + pos_indices
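+    # Each row of gather_indices is a ctx_len-long span starting at that window's left edge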
+
+    # Gather the final windows
+    token_windows = all_tokens[batch_indices, gather_indices]
+    activation_windows = final_tensor[batch_indices, gather_indices]
+    return token_windows, activation_windows
+
+
 def constructor(
     record: LatentRecord,
     activation_data: ActivationData,
@@ -142,49 +249,55 @@ def constructor(
     n_not_active = constructor_cfg.n_non_activating
     max_examples = constructor_cfg.max_examples
     min_examples = constructor_cfg.min_examples
-
     # Get all positions where the latent is active
     flat_indices = (
         activation_data.locations[:, 0] * cache_ctx_len
         + activation_data.locations[:, 1]
     )
     ctx_indices = flat_indices // example_ctx_len
     index_within_ctx = flat_indices % example_ctx_len
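+    # Number of example-length windows in each cached sequence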
+    n_windows_per_batch = tokens.shape[1] // example_ctx_len
     reshaped_tokens = tokens.reshape(-1, example_ctx_len)
     n_windows = reshaped_tokens.shape[0]
-
     unique_batch_pos = ctx_indices.unique()
-
     mask = torch.ones(n_windows, dtype=torch.bool)
     mask[unique_batch_pos] = False
     # Indices where the latent is not active
     non_active_indices = mask.nonzero(as_tuple=False).squeeze()
     activations = activation_data.activations
-
     # per context frequency
     record.per_context_frequency = len(unique_batch_pos) / n_windows
 
     # Add activation examples to the record in place
-    token_windows, act_windows = pool_max_activation_windows(
-        activations=activations,
-        tokens=reshaped_tokens,
-        ctx_indices=ctx_indices,
-        index_within_ctx=index_within_ctx,
-        ctx_len=example_ctx_len,
-        max_examples=max_examples,
-    )
+    if constructor_cfg.center_examples:
+        token_windows, act_windows = pool_max_activation_windows(
+            activations=activations,
+            tokens=reshaped_tokens,
+            ctx_indices=ctx_indices,
+            index_within_ctx=index_within_ctx,
+            ctx_len=example_ctx_len,
+            max_examples=max_examples,
+        )
+    else:
+        token_windows, act_windows = pool_centered_activation_windows(
+            activations=activations,
+            tokens=reshaped_tokens,
+            n_windows_per_batch=n_windows_per_batch,
+            ctx_indices=ctx_indices,
+            index_within_ctx=index_within_ctx,
+            ctx_len=example_ctx_len,
+            max_examples=max_examples,
+        )
     # TODO: We might want to do this in the sampler
     # we are tokenizing examples that are not going to be used
     record.examples = [
         ActivatingExample(
             tokens=toks,
             activations=acts,
             normalized_activations=None,
-            str_tokens=tokenizer.batch_decode(toks),
         )
         for toks, acts in zip(token_windows, act_windows)
     ]
-
     if len(record.examples) < min_examples:
         # Not enough examples to explain the latent
         return None
@@ -420,6 +533,7 @@ def faiss_non_activation_windows(
     # Create non-activating examples
     return prepare_non_activating_examples(
         selected_tokens,
+        torch.zeros_like(selected_tokens),
         -1.0,  # Using -1.0 as the distance since these are not neighbour-based
         tokenizer,
     )
@@ -495,7 +609,7 @@ def neighbour_non_activation_windows(
         # If there are no available indices, skip this neighbour
         if activations.numel() == 0:
             continue
-        token_windows, _ = pool_max_activation_windows(
+        token_windows, token_activations = pool_max_activation_windows(
             activations=activations,
             tokens=reshaped_tokens,
             ctx_indices=available_ctx_indices,
@@ -508,12 +622,23 @@ def neighbour_non_activation_windows(
         examples_used = len(token_windows)
         all_examples.extend(
             prepare_non_activating_examples(
-                token_windows, -neighbour.distance, tokenizer
+                token_windows,
+                token_activations,  # activations of neighbour
+                -neighbour.distance,
+                tokenizer,
             )
         )
         number_examples += examples_used
     if len(all_examples) == 0:
-        print("No examples found")
+        print("No examples found, falling back to random non-activating examples")
+        non_active_indices = not_active_mask.nonzero(as_tuple=False).squeeze()
+
+        return random_non_activating_windows(
+            available_indices=non_active_indices,
+            reshaped_tokens=reshaped_tokens,
+            n_not_active=n_not_active,
+            tokenizer=tokenizer,
+        )
     return all_examples
 
 
@@ -554,6 +679,7 @@ def random_non_activating_windows(
 
     return prepare_non_activating_examples(
         toks,
+        torch.zeros_like(toks),  # there is no way to define these activations
         -1.0,
         tokenizer,
     )