
Commit 6b9af04

Merge pull request #113 from EleutherAI/intruder
- Two new scoring methods that don't require any model explanations.
- Centers activating examples. For now, 3/4 of the window will be to the left of the maximally activating token in the previous windows, and 1/4 will be to the right. To do this, I'm discarding all activations from the first and last window of each batch (which is 1/4 of the total activations we have if we collect with ctx_len 256).
- Some fixes to the online client.
- New example sampling option.
2 parents: 4b45f8a + 0c9e294

17 files changed (+1012, -64 lines)
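To make the centering concrete, here is a minimal standalone sketch of the offset rule described in the commit message (not library code; the values are made up): the maximally activating token ends up roughly 3/4 of the way into the new window.

import torch

ctx_len = 32
# Hypothetical positions of the max-activating token inside the
# concatenated [previous | current | next] windows (length 3 * ctx_len).
max_idx_in_concat = torch.tensor([40, 70])

# Same offset rule as the commit describes: ~3/4 of the window to the
# left of the max token, ~1/4 to the right.
left = max_idx_in_concat - (ctx_len - ctx_len // 4)
window_positions = left.unsqueeze(1) + torch.arange(ctx_len).unsqueeze(0)
print(window_positions[:, 0].tolist(), window_positions[:, -1].tolist())  # [16, 46] [47, 77]
# The max token sits at index 24 of each 32-token window, i.e. 3/4 of the way in.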

delphi/clients/offline.py (+2, -1)
@@ -45,6 +45,7 @@ def __init__(
         prefix_caching: bool = True,
         batch_size: int = 100,
         max_model_len: int = 4096,
+        number_tokens_to_generate: int = 500,
         num_gpus: int = 2,
         enforce_eager: bool = False,
         statistics: bool = False,
@@ -65,7 +66,7 @@ def __init__(
             max_model_len=max_model_len,
             enforce_eager=enforce_eager,
         )
-        self.sampling_params = SamplingParams(max_tokens=500)
+        self.sampling_params = SamplingParams(max_tokens=number_tokens_to_generate)
         self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.batch_size = batch_size
         self.statistics = statistics
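A hedged usage sketch of the new number_tokens_to_generate parameter. The class name Offline and the model string are assumptions; only the parameters visible in the diff are taken from the source.

from delphi.clients.offline import Offline  # class name assumed from the file path

# Previously SamplingParams(max_tokens=500) was hard-coded; now the cap is
# configurable via number_tokens_to_generate.
client = Offline(
    "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
    max_model_len=4096,
    number_tokens_to_generate=256,
    num_gpus=1,
)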

delphi/clients/openrouter.py (+9, -3)
@@ -23,12 +23,16 @@ def __init__(
         model: str,
         api_key: str | None = None,
         base_url="https://openrouter.ai/api/v1/chat/completions",
+        max_tokens: int = 3000,
+        temperature: float = 1.0,
     ):
         super().__init__(model)

         self.headers = {"Authorization": f"Bearer {api_key}"}

         self.url = base_url
+        self.max_tokens = max_tokens
+        self.temperature = temperature
         timeout_config = httpx.Timeout(5.0)
         self.client = httpx.AsyncClient(timeout=timeout_config)

@@ -45,8 +49,10 @@ async def generate( # type: ignore
         **kwargs, # type: ignore
     ) -> Response: # type: ignore
         kwargs.pop("schema", None)
-        max_tokens = kwargs.pop("max_tokens", 500)
-        temperature = kwargs.pop("temperature", 1.0)
+        # We have to decide if we want to do this like this or not
+        # Currently only simulation uses generation kwargs.
+        max_tokens = kwargs.pop("max_tokens", self.max_tokens)
+        temperature = kwargs.pop("temperature", self.temperature)
         data = {
             "model": self.model,
             "messages": prompt,
@@ -58,7 +64,7 @@ async def generate( # type: ignore
         for attempt in range(max_retries):
             try:
                 response = await self.client.post(
-                    url=self.url, json=data, headers=self.headers
+                    url=self.url, json=data, headers=self.headers, timeout=100
                 )
                 if raw:
                     return response.json()
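Similarly, a hedged sketch of the OpenRouter client change: max_tokens and temperature are now constructor defaults (3000 and 1.0) instead of hard-coded per-call defaults, and per-call kwargs still override them. The class name OpenRouter is assumed from the file path.

from delphi.clients.openrouter import OpenRouter  # class name assumed

client = OpenRouter(
    "openai/gpt-4o-mini",  # hypothetical OpenRouter model id
    api_key="sk-or-...",   # placeholder key
    max_tokens=1000,
    temperature=0.7,
)

# Per-call kwargs still take precedence over the constructor defaults:
# response = await client.generate(prompt, max_tokens=200)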

delphi/config.py (+10, -1)
@@ -17,12 +17,15 @@ class SamplerConfig(Serializable):
     n_quantiles: int = 10
     """Number of latent activation quantiles to sample."""

-    train_type: Literal["top", "random", "quantiles"] = "quantiles"
+    train_type: Literal["top", "random", "quantiles", "mix"] = "quantiles"
     """Type of sampler to use for latent explanation generation."""

     test_type: Literal["quantiles"] = "quantiles"
     """Type of sampler to use for latent explanation testing."""

+    ratio_top: float = 0.2
+    """Ratio of top examples to use for training, if using mix."""
+

 @dataclass
 class ConstructorConfig(Serializable):
@@ -51,6 +54,12 @@ class ConstructorConfig(Serializable):
     n_non_activating: int = 50
     """Number of non-activating examples to be constructed."""

+    center_examples: bool = True
+    """Whether to center the examples on the latent activation.
+    If True, the examples will be centered on the latent activation.
+    Otherwise, windows will be used, and the activating example can be
+    anywhere in the window."""
+
     non_activating_source: Literal["random", "neighbours", "FAISS"] = "random"
     """Source of non-activating examples. Random uses non-activating contexts
     sampled from any non-activating window. Neighbours uses activating contexts
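A hedged configuration sketch using the new fields. The field names and defaults come from the diff; the "mix" behaviour (a ratio_top share of top examples plus quantile samples) is inferred from the docstrings, and the sampler implementation itself is not shown in this commit.

from delphi.config import SamplerConfig, ConstructorConfig

# "mix" is the new example sampling option; ratio_top controls how many
# training examples come from the top activations.
sampler_cfg = SamplerConfig(train_type="mix", ratio_top=0.2, n_quantiles=10)

# Per the constructor diff below, center_examples=True keeps the original
# pool_max_activation_windows path, while False selects the new
# pool_centered_activation_windows path.
constructor_cfg = ConstructorConfig(center_examples=False, n_non_activating=50)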

delphi/latents/constructors.py (+147, -21)
@@ -32,6 +32,7 @@ def get_model(name: str, device: str = "cuda") -> SentenceTransformer:

 def prepare_non_activating_examples(
     tokens: Int[Tensor, "examples ctx_len"],
+    activations: Float[Tensor, "examples ctx_len"],
     distance: float,
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
 ) -> list[NonActivatingExample]:
@@ -45,12 +46,12 @@ def prepare_non_activating_examples(
     return [
         NonActivatingExample(
             tokens=toks,
-            activations=torch.zeros_like(toks, dtype=torch.float),
+            activations=acts,
             normalized_activations=None,
             distance=distance,
             str_tokens=tokenizer.batch_decode(toks),
         )
-        for toks in tokens
+        for toks, acts in zip(tokens, activations)
     ]

@@ -113,11 +114,9 @@ def pool_max_activation_windows(

     # Get the max activation magnitude within each context window
     max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)
-
     # Deduplicate the context windows
     new_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
     new_tensor[inverses, index_within_ctx] = activations
-
     tokens = tokens[unique_ctx_indices]

     token_windows, activation_windows = _top_k_pools(
@@ -127,6 +126,114 @@ def pool_max_activation_windows(
     return token_windows, activation_windows


+def pool_centered_activation_windows(
+    activations: Float[Tensor, "examples"],
+    tokens: Float[Tensor, "windows seq"],
+    n_windows_per_batch: int,
+    ctx_indices: Float[Tensor, "examples"],
+    index_within_ctx: Float[Tensor, "examples"],
+    ctx_len: int,
+    max_examples: int,
+) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
+    """
+    Similar to pool_max_activation_windows. Doesn't use the ctx_indices that were
+    at the start or the end of the batch, because it always tries to keep a buffer
+    of ctx_len * 3 // 4 on the left and ctx_len // 4 on the right of the maximally
+    activating token. To do this, for each window it joins the contexts from the
+    neighbouring windows of the same batch to form a new context, which is then
+    cut to the correct shape, centered on the max activation.
+
+    Args:
+        activations : The activations.
+        tokens : The input tokens.
+        ctx_indices : The context indices.
+        index_within_ctx : The index within the context.
+        ctx_len : The context length.
+        max_examples : The maximum number of examples.
+    """
+
+    # Get unique context indices and their counts like in pool_max_activation_windows
+    unique_ctx_indices, inverses, lengths = torch.unique_consecutive(
+        ctx_indices, return_counts=True, return_inverse=True
+    )
+
+    # Get the max activation magnitude within each context window
+    max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)
+
+    # Get the top max_examples windows
+    k = min(max_examples, len(max_buffer))
+    top_values, top_indices = torch.topk(max_buffer, k, sorted=True)
+
+    # This tensor has the correct activations for each context window
+    temp_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
+    temp_tensor[inverses, index_within_ctx] = activations
+
+    unique_ctx_indices = unique_ctx_indices[top_indices]
+    temp_tensor = temp_tensor[top_indices]
+
+    # If an element in unique_ctx_indices is divisible by n_windows_per_batch,
+    # it is the start of a new batch, so we discard it
+    modulo = unique_ctx_indices % n_windows_per_batch
+    not_first_position = modulo != 0
+    # Also remove the elements that are at the end of the batch
+    not_last_position = modulo != n_windows_per_batch - 1
+    mask = not_first_position & not_last_position
+    unique_ctx_indices = unique_ctx_indices[mask]
+    temp_tensor = temp_tensor[mask]
+    if len(unique_ctx_indices) == 0:
+        return torch.zeros(0, ctx_len), torch.zeros(0, ctx_len)
+
+    # Vectorized operations for all windows at once
+    n_windows = len(unique_ctx_indices)
+
+    # Create indices for previous, current, and next windows
+    prev_indices = unique_ctx_indices - 1
+    next_indices = unique_ctx_indices + 1
+
+    # Create a tensor to hold all concatenated tokens
+    all_tokens = torch.cat(
+        [tokens[prev_indices], tokens[unique_ctx_indices], tokens[next_indices]], dim=1
+    )  # Shape: [n_windows, ctx_len*3]
+
+    # Create tensor for all activations
+    final_tensor = torch.zeros((n_windows, ctx_len * 3), dtype=activations.dtype)
+    final_tensor[:, ctx_len : ctx_len * 2] = (
+        temp_tensor  # Set current window activations
+    )
+
+    # Set previous window activations where available
+    prev_mask = torch.isin(prev_indices, unique_ctx_indices)
+    if prev_mask.any():
+        prev_locations = torch.where(
+            unique_ctx_indices.unsqueeze(1) == prev_indices.unsqueeze(0)
+        )[1]
+        final_tensor[prev_mask, :ctx_len] = temp_tensor[prev_locations]
+
+    # Set next window activations where available
+    next_mask = torch.isin(next_indices, unique_ctx_indices)
+    if next_mask.any():
+        next_locations = torch.where(
+            unique_ctx_indices.unsqueeze(1) == next_indices.unsqueeze(0)
+        )[1]
+        final_tensor[next_mask, ctx_len * 2 :] = temp_tensor[next_locations]
+
+    # Find max activation indices (offsets into the concatenated window)
+    max_activation_indices = torch.argmax(temp_tensor, dim=1) + ctx_len
+
+    # Calculate the left edge for all windows
+    left_positions = max_activation_indices - (ctx_len - ctx_len // 4)
+
+    # Create index tensors for gathering
+    batch_indices = torch.arange(n_windows).unsqueeze(1)
+    pos_indices = torch.arange(ctx_len).unsqueeze(0)
+    gather_indices = left_positions.unsqueeze(1) + pos_indices
+
+    # Gather the final windows
+    token_windows = all_tokens[batch_indices, gather_indices]
+    activation_windows = final_tensor[batch_indices, gather_indices]
+    return token_windows, activation_windows
+
+
 def constructor(
     record: LatentRecord,
     activation_data: ActivationData,
@@ -142,49 +249,55 @@ def constructor(
     n_not_active = constructor_cfg.n_non_activating
     max_examples = constructor_cfg.max_examples
     min_examples = constructor_cfg.min_examples
-
     # Get all positions where the latent is active
     flat_indices = (
         activation_data.locations[:, 0] * cache_ctx_len
         + activation_data.locations[:, 1]
     )
     ctx_indices = flat_indices // example_ctx_len
     index_within_ctx = flat_indices % example_ctx_len
+    n_windows_per_batch = tokens.shape[1] // example_ctx_len
     reshaped_tokens = tokens.reshape(-1, example_ctx_len)
     n_windows = reshaped_tokens.shape[0]
-
     unique_batch_pos = ctx_indices.unique()
-
     mask = torch.ones(n_windows, dtype=torch.bool)
     mask[unique_batch_pos] = False
     # Indices where the latent is not active
     non_active_indices = mask.nonzero(as_tuple=False).squeeze()
     activations = activation_data.activations
-
     # per context frequency
     record.per_context_frequency = len(unique_batch_pos) / n_windows

     # Add activation examples to the record in place
-    token_windows, act_windows = pool_max_activation_windows(
-        activations=activations,
-        tokens=reshaped_tokens,
-        ctx_indices=ctx_indices,
-        index_within_ctx=index_within_ctx,
-        ctx_len=example_ctx_len,
-        max_examples=max_examples,
-    )
+    if constructor_cfg.center_examples:
+        token_windows, act_windows = pool_max_activation_windows(
+            activations=activations,
+            tokens=reshaped_tokens,
+            ctx_indices=ctx_indices,
+            index_within_ctx=index_within_ctx,
+            ctx_len=example_ctx_len,
+            max_examples=max_examples,
+        )
+    else:
+        token_windows, act_windows = pool_centered_activation_windows(
+            activations=activations,
+            tokens=reshaped_tokens,
+            n_windows_per_batch=n_windows_per_batch,
+            ctx_indices=ctx_indices,
+            index_within_ctx=index_within_ctx,
+            ctx_len=example_ctx_len,
+            max_examples=max_examples,
+        )
     # TODO: We might want to do this in the sampler
     # we are tokenizing examples that are not going to be used
     record.examples = [
         ActivatingExample(
             tokens=toks,
             activations=acts,
             normalized_activations=None,
-            str_tokens=tokenizer.batch_decode(toks),
         )
         for toks, acts in zip(token_windows, act_windows)
     ]
-
     if len(record.examples) < min_examples:
         # Not enough examples to explain the latent
         return None
@@ -420,6 +533,7 @@ def faiss_non_activation_windows(
     # Create non-activating examples
     return prepare_non_activating_examples(
         selected_tokens,
+        torch.zeros_like(selected_tokens),
         -1.0,  # Using -1.0 as the distance since these are not neighbour-based
         tokenizer,
     )
@@ -495,7 +609,7 @@ def neighbour_non_activation_windows(
         # If there are no available indices, skip this neighbour
         if activations.numel() == 0:
             continue
-        token_windows, _ = pool_max_activation_windows(
+        token_windows, token_activations = pool_max_activation_windows(
            activations=activations,
            tokens=reshaped_tokens,
            ctx_indices=available_ctx_indices,
@@ -508,12 +622,23 @@ def neighbour_non_activation_windows(
         examples_used = len(token_windows)
         all_examples.extend(
             prepare_non_activating_examples(
-                token_windows, -neighbour.distance, tokenizer
+                token_windows,
+                token_activations,  # activations of neighbour
+                -neighbour.distance,
+                tokenizer,
             )
         )
         number_examples += examples_used
     if len(all_examples) == 0:
-        print("No examples found")
+        print("No examples found, falling back to random non-activating examples")
+        non_active_indices = not_active_mask.nonzero(as_tuple=False).squeeze()
+
+        return random_non_activating_windows(
+            available_indices=non_active_indices,
+            reshaped_tokens=reshaped_tokens,
+            n_not_active=n_not_active,
+            tokenizer=tokenizer,
+        )
     return all_examples

@@ -554,6 +679,7 @@ def random_non_activating_windows(

     return prepare_non_activating_examples(
         toks,
+        torch.zeros_like(toks),  # there is no way to define these activations
         -1.0,
         tokenizer,
     )
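A self-contained toy run of the gather logic at the heart of pool_centered_activation_windows may help; the shapes and activation values here are invented purely for illustration.

import torch

ctx_len = 8
n_windows = 2

# Pretend these are the [prev | current | next] token windows already
# concatenated along the sequence dimension: shape [n_windows, ctx_len * 3].
all_tokens = torch.arange(n_windows * ctx_len * 3).reshape(n_windows, ctx_len * 3)

# For simplicity only the middle third holds activations here; the real
# function also fills the outer thirds from neighbouring windows.
final_tensor = torch.zeros(n_windows, ctx_len * 3)
final_tensor[0, 10] = 1.0  # max activation of window 0
final_tensor[1, 13] = 2.0  # max activation of window 1

max_activation_indices = torch.argmax(final_tensor, dim=1)
left_positions = max_activation_indices - (ctx_len - ctx_len // 4)

batch_indices = torch.arange(n_windows).unsqueeze(1)
pos_indices = torch.arange(ctx_len).unsqueeze(0)
gather_indices = left_positions.unsqueeze(1) + pos_indices

token_windows = all_tokens[batch_indices, gather_indices]
activation_windows = final_tensor[batch_indices, gather_indices]

# Each max activation now sits 3/4 of the way into its window.
print(activation_windows.argmax(dim=1))  # tensor([6, 6]) == ctx_len - ctx_len // 4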

delphi/latents/latents.py (+2, -2)
@@ -75,7 +75,7 @@ class Example:
     activations: Float[Tensor, "ctx_len"]
     """Activation values for the input sequence."""

-    str_tokens: list[str]
+    str_tokens: list[str] | None = None
     """Tokenized input sequence as strings."""

     normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None
@@ -134,7 +134,7 @@ class LatentRecord:
     train: list[ActivatingExample] = field(default_factory=list)
     """Training examples."""

-    test: list[ActivatingExample] | list[list[Example]] = field(default_factory=list)
+    test: list[ActivatingExample] = field(default_factory=list)
     """Test examples."""

     neighbours: list[Neighbour] = field(default_factory=list)
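Since the constructor no longer decodes str_tokens (the batch_decode call was removed above) and Example.str_tokens is now optional, decoding can be deferred. A hedged sketch of how the strings could be filled in later, e.g. in the sampler, mirroring the removed line; the helper name fill_str_tokens and the tokenizer choice are hypothetical.

from transformers import AutoTokenizer, PreTrainedTokenizer

def fill_str_tokens(examples, tokenizer: PreTrainedTokenizer) -> None:
    """Decode per-token strings only for examples that still lack them."""
    for example in examples:
        if example.str_tokens is None:
            # Mirrors the tokenizer.batch_decode(toks) call removed from the constructor.
            example.str_tokens = tokenizer.batch_decode(example.tokens)

# Hypothetical usage, e.g. inside the sampler:
# fill_str_tokens(record.train, AutoTokenizer.from_pretrained("gpt2"))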
