Intruder detection scoring and example embedding scoring #113


Merged: 32 commits, Apr 21, 2025
Commits (changes shown are from all 32 commits):
8b28d78
Using activations of neighbour latents
Mar 7, 2025
f99a7f9
Fix typing
Mar 7, 2025
53155a2
New fuzz metrics
Mar 7, 2025
b4e91e4
Adding default fallback
Mar 7, 2025
b6e978b
Merge branch 'main' of https://github.com/EleutherAI/delphi into neig…
Mar 13, 2025
242530f
Merge branch 'main' of https://github.com/EleutherAI/delphi into neig…
Mar 13, 2025
576730f
New intruder scorer - no need for explanations
Mar 13, 2025
f297a1c
New fuzz
Mar 13, 2025
33cbbe4
Merge branch 'main' of https://github.com/EleutherAI/delphi into intr…
Mar 13, 2025
06a9254
New prompts
Mar 14, 2025
b7bdfa7
Simplifying intruder
Mar 18, 2025
628cb17
Removing unused parts
Mar 18, 2025
47868ab
Fix fuzz
Mar 19, 2025
0b75f97
Cleaning intruder code
Apr 1, 2025
4a24cae
Adding example embedding scorer
Apr 1, 2025
8d996d2
Add number of tokens to generate to clients
Apr 1, 2025
710eda7
Merge branch 'main' of https://github.com/EleutherAI/delphi into intr…
Apr 1, 2025
fda5033
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 1, 2025
e8b11f4
Change type
Apr 2, 2025
6ed5456
Small loader debug thingy
Apr 2, 2025
e920de7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2025
b33eee7
Adding a ratio
Apr 2, 2025
f10cae7
Merge branch 'mix_sampling' of https://github.com/EleutherAI/delphi i…
Apr 2, 2025
f3aff2c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2025
6e3b31a
Correct FAISS error
Apr 2, 2025
d7ef127
New centered examples
Apr 21, 2025
4880aa1
Merge branch 'main' of https://github.com/EleutherAI/delphi into intr…
Apr 21, 2025
aa041a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2025
d171a04
Delete something that shouldn't be here
Apr 21, 2025
cae363e
Merge branch 'intruder' of https://github.com/EleutherAI/delphi into …
Apr 21, 2025
cc9387c
Center option, on by default now
Apr 21, 2025
0c9e294
Beartyping and fixing test
Apr 21, 2025
3 changes: 2 additions & 1 deletion delphi/clients/offline.py
@@ -45,6 +45,7 @@ def __init__(
prefix_caching: bool = True,
batch_size: int = 100,
max_model_len: int = 4096,
number_tokens_to_generate: int = 500,
num_gpus: int = 2,
enforce_eager: bool = False,
statistics: bool = False,
@@ -65,7 +66,7 @@ def __init__(
max_model_len=max_model_len,
enforce_eager=enforce_eager,
)
self.sampling_params = SamplingParams(max_tokens=500)
self.sampling_params = SamplingParams(max_tokens=number_tokens_to_generate)
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.batch_size = batch_size
self.statistics = statistics
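The generation cap that was hard-coded to 500 tokens is now exposed as the `number_tokens_to_generate` constructor argument. A minimal usage sketch, assuming the client class in this module is named `Offline`; the model id is a placeholder:

from delphi.clients.offline import Offline  # class name assumed from the module path

# Raise the generation cap from the old hard-coded 500 tokens.
client = Offline(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    max_model_len=4096,
    number_tokens_to_generate=1000,
    num_gpus=2,
)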
12 changes: 9 additions & 3 deletions delphi/clients/openrouter.py
@@ -23,12 +23,16 @@ def __init__(
model: str,
api_key: str | None = None,
base_url="https://openrouter.ai/api/v1/chat/completions",
max_tokens: int = 3000,
temperature: float = 1.0,
):
super().__init__(model)

self.headers = {"Authorization": f"Bearer {api_key}"}

self.url = base_url
self.max_tokens = max_tokens
self.temperature = temperature
timeout_config = httpx.Timeout(5.0)
self.client = httpx.AsyncClient(timeout=timeout_config)

@@ -45,8 +49,10 @@ async def generate( # type: ignore
**kwargs, # type: ignore
) -> Response: # type: ignore
kwargs.pop("schema", None)
max_tokens = kwargs.pop("max_tokens", 500)
temperature = kwargs.pop("temperature", 1.0)
# We still need to decide whether to keep this override behaviour.
# Currently only simulation passes generation kwargs.
max_tokens = kwargs.pop("max_tokens", self.max_tokens)
temperature = kwargs.pop("temperature", self.temperature)
data = {
"model": self.model,
"messages": prompt,
@@ -58,7 +64,7 @@
for attempt in range(max_retries):
try:
response = await self.client.post(
url=self.url, json=data, headers=self.headers
url=self.url, json=data, headers=self.headers, timeout=100
)
if raw:
return response.json()
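With `max_tokens` and `temperature` now stored on the client, per-call kwargs act as overrides of the constructor defaults rather than of hard-coded values. A sketch of that precedence, assuming the class is named `OpenRouter`; the model id and API key are placeholders:

import asyncio

from delphi.clients.openrouter import OpenRouter  # class name assumed

async def main():
    client = OpenRouter(
        "openai/gpt-4o-mini",  # placeholder model id
        api_key="sk-or-...",  # placeholder key
        max_tokens=3000,
        temperature=1.0,
    )
    prompt = [{"role": "user", "content": "Describe this latent."}]
    await client.generate(prompt)  # uses the constructor defaults
    await client.generate(prompt, max_tokens=256, temperature=0.0)  # kwargs win

asyncio.run(main())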
11 changes: 10 additions & 1 deletion delphi/config.py
@@ -17,12 +17,15 @@ class SamplerConfig(Serializable):
n_quantiles: int = 10
"""Number of latent activation quantiles to sample."""

train_type: Literal["top", "random", "quantiles"] = "quantiles"
train_type: Literal["top", "random", "quantiles", "mix"] = "quantiles"
"""Type of sampler to use for latent explanation generation."""

test_type: Literal["quantiles"] = "quantiles"
"""Type of sampler to use for latent explanation testing."""

ratio_top: float = 0.2
"""Ratio of top examples to use for training, if using mix."""


@dataclass
class ConstructorConfig(Serializable):
Expand Down Expand Up @@ -51,6 +54,12 @@ class ConstructorConfig(Serializable):
n_non_activating: int = 50
"""Number of non-activating examples to be constructed."""

center_examples: bool = True
"""Whether to center the examples on the latent activation.
If True, the examples will be centered on the latent activation.
Otherwise, windows will be used, and the activating example can be anywhere
window."""

non_activating_source: Literal["random", "neighbours", "FAISS"] = "random"
"""Source of non-activating examples. Random uses non-activating contexts
sampled from any non activating window. Neighbours uses actvating contexts
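Taken together, the config additions enable a mixed training sampler and optional example centering. A minimal sketch of both options, assuming the remaining fields keep their defaults (the values shown are the documented defaults):

from delphi.config import ConstructorConfig, SamplerConfig

# "mix" draws a ratio_top fraction of the training examples from the top
# activations and the remainder from quantiles; 0.2 is the default ratio.
sampler_cfg = SamplerConfig(train_type="mix", ratio_top=0.2, n_quantiles=10)

# Per the docstring above, center_examples=True (the default) centers
# examples on the latent activation; False falls back to plain windows.
constructor_cfg = ConstructorConfig(center_examples=True, n_non_activating=50)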
168 changes: 147 additions & 21 deletions delphi/latents/constructors.py
@@ -32,6 +32,7 @@ def get_model(name: str, device: str = "cuda") -> SentenceTransformer:

def prepare_non_activating_examples(
tokens: Int[Tensor, "examples ctx_len"],
activations: Float[Tensor, "examples ctx_len"],
distance: float,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
) -> list[NonActivatingExample]:
@@ -45,12 +46,12 @@ def prepare_non_activating_examples(
return [
NonActivatingExample(
tokens=toks,
activations=torch.zeros_like(toks, dtype=torch.float),
activations=acts,
normalized_activations=None,
distance=distance,
str_tokens=tokenizer.batch_decode(toks),
)
for toks in tokens
for toks, acts in zip(tokens, activations)
]


@@ -113,11 +114,9 @@ def pool_max_activation_windows(

# Get the max activation magnitude within each context window
max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)

# Deduplicate the context windows
new_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
new_tensor[inverses, index_within_ctx] = activations

tokens = tokens[unique_ctx_indices]

token_windows, activation_windows = _top_k_pools(
@@ -127,6 +126,114 @@
return token_windows, activation_windows


def pool_centered_activation_windows(
activations: Float[Tensor, "examples"],
tokens: Float[Tensor, "windows seq"],
n_windows_per_batch: int,
ctx_indices: Float[Tensor, "examples"],
index_within_ctx: Float[Tensor, "examples"],
ctx_len: int,
max_examples: int,
) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
"""
Similar to pool_max_activation_windows, but centers each window on its max
activation. Skips context windows that sit at the start or end of a batch,
because it keeps a buffer of roughly ctx_len*3//4 on the left and
ctx_len//4 on the right of the max activation. To do this, for each window
it joins the neighbouring contexts from the same batch into one longer
context, which is then cut to ctx_len, centered on the max activation.

Args:
activations : The activations.
tokens : The input tokens.
n_windows_per_batch : The number of context windows in each batch.
ctx_indices : The context indices.
index_within_ctx : The index within the context.
ctx_len : The context length.
max_examples : The maximum number of examples.
"""

# Get unique context indices and their counts like in pool_max_activation_windows
unique_ctx_indices, inverses, lengths = torch.unique_consecutive(
ctx_indices, return_counts=True, return_inverse=True
)

# Get the max activation magnitude within each context window
max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)

# Get the top max_examples windows
k = min(max_examples, len(max_buffer))
top_values, top_indices = torch.topk(max_buffer, k, sorted=True)

# this tensor has the correct activations for each context window
temp_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
temp_tensor[inverses, index_within_ctx] = activations

unique_ctx_indices = unique_ctx_indices[top_indices]
temp_tensor = temp_tensor[top_indices]

# if an element in unique_ctx_indices is divisible by n_windows_per_batch,
# it is the start of a new batch, so we discard it
modulo = unique_ctx_indices % n_windows_per_batch
not_first_position = modulo != 0
# remove also the elements that are at the end of the batch
not_last_position = modulo != n_windows_per_batch - 1
mask = not_first_position & not_last_position
unique_ctx_indices = unique_ctx_indices[mask]
temp_tensor = temp_tensor[mask]
if len(unique_ctx_indices) == 0:
return torch.zeros(0, ctx_len), torch.zeros(0, ctx_len)

# Vectorized operations for all windows at once
n_windows = len(unique_ctx_indices)

# Create indices for previous, current, and next windows
prev_indices = unique_ctx_indices - 1
next_indices = unique_ctx_indices + 1

# Create a tensor to hold all concatenated tokens
all_tokens = torch.cat(
[tokens[prev_indices], tokens[unique_ctx_indices], tokens[next_indices]], dim=1
) # Shape: [n_windows, ctx_len*3]

# Create tensor for all activations
final_tensor = torch.zeros((n_windows, ctx_len * 3), dtype=activations.dtype)
final_tensor[:, ctx_len : ctx_len * 2] = (
temp_tensor # Set current window activations
)

# Set previous window activations where available
prev_mask = torch.isin(prev_indices, unique_ctx_indices)
if prev_mask.any():
prev_locations = torch.where(
unique_ctx_indices.unsqueeze(1) == prev_indices.unsqueeze(0)
)[1]
final_tensor[prev_mask, :ctx_len] = temp_tensor[prev_locations]

# Set next window activations where available
next_mask = torch.isin(next_indices, unique_ctx_indices)
if next_mask.any():
next_locations = torch.where(
unique_ctx_indices.unsqueeze(1) == next_indices.unsqueeze(0)
)[1]
final_tensor[next_mask, ctx_len * 2 :] = temp_tensor[next_locations]

# Find max activation indices
max_activation_indices = torch.argmax(temp_tensor, dim=1) + ctx_len

# Calculate left for all windows
left_positions = max_activation_indices - (ctx_len - ctx_len // 4)

# Create index tensors for gathering
batch_indices = torch.arange(n_windows).unsqueeze(1)
pos_indices = torch.arange(ctx_len).unsqueeze(0)
gather_indices = left_positions.unsqueeze(1) + pos_indices

# Gather the final windows
token_windows = all_tokens[batch_indices, gather_indices]
activation_windows = final_tensor[batch_indices, gather_indices]
return token_windows, activation_windows

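A quick numeric check of the gathering arithmetic above, with assumed values (ctx_len = 32, max activation at position 100 of the joined 3*ctx_len context): the window starts ctx_len - ctx_len // 4 = 24 tokens left of the max activation, leaving about a quarter of the window as right context.

import torch

ctx_len = 32
max_activation_indices = torch.tensor([100])  # assumed position in the joined context

left_positions = max_activation_indices - (ctx_len - ctx_len // 4)  # 100 - 24 = 76
pos_indices = torch.arange(ctx_len).unsqueeze(0)
gather_indices = left_positions.unsqueeze(1) + pos_indices  # columns 76..107

# The activating token lands at offset 24 of the 32-token window,
# leaving the last 8 positions as right context.
offset = (max_activation_indices - left_positions).item()
assert offset == ctx_len - ctx_len // 4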

def constructor(
record: LatentRecord,
activation_data: ActivationData,
@@ -142,49 +249,55 @@ def constructor(
n_not_active = constructor_cfg.n_non_activating
max_examples = constructor_cfg.max_examples
min_examples = constructor_cfg.min_examples

# Get all positions where the latent is active
flat_indices = (
activation_data.locations[:, 0] * cache_ctx_len
+ activation_data.locations[:, 1]
)
ctx_indices = flat_indices // example_ctx_len
index_within_ctx = flat_indices % example_ctx_len
n_windows_per_batch = tokens.shape[1] // example_ctx_len
reshaped_tokens = tokens.reshape(-1, example_ctx_len)
n_windows = reshaped_tokens.shape[0]

unique_batch_pos = ctx_indices.unique()

mask = torch.ones(n_windows, dtype=torch.bool)
mask[unique_batch_pos] = False
# Indices where the latent is not active
non_active_indices = mask.nonzero(as_tuple=False).squeeze()
activations = activation_data.activations

# per context frequency
record.per_context_frequency = len(unique_batch_pos) / n_windows

# Add activation examples to the record in place
token_windows, act_windows = pool_max_activation_windows(
activations=activations,
tokens=reshaped_tokens,
ctx_indices=ctx_indices,
index_within_ctx=index_within_ctx,
ctx_len=example_ctx_len,
max_examples=max_examples,
)
if constructor_cfg.center_examples:
token_windows, act_windows = pool_max_activation_windows(
activations=activations,
tokens=reshaped_tokens,
ctx_indices=ctx_indices,
index_within_ctx=index_within_ctx,
ctx_len=example_ctx_len,
max_examples=max_examples,
)
else:
token_windows, act_windows = pool_centered_activation_windows(
activations=activations,
tokens=reshaped_tokens,
n_windows_per_batch=n_windows_per_batch,
ctx_indices=ctx_indices,
index_within_ctx=index_within_ctx,
ctx_len=example_ctx_len,
max_examples=max_examples,
)
# TODO: We might want to do this in the sampler
# we are tokenizing examples that are not going to be used
record.examples = [
ActivatingExample(
tokens=toks,
activations=acts,
normalized_activations=None,
str_tokens=tokenizer.batch_decode(toks),
)
for toks, acts in zip(token_windows, act_windows)
]

if len(record.examples) < min_examples:
# Not enough examples to explain the latent
return None
Expand Down Expand Up @@ -420,6 +533,7 @@ def faiss_non_activation_windows(
# Create non-activating examples
return prepare_non_activating_examples(
selected_tokens,
torch.zeros_like(selected_tokens),
-1.0, # Using -1.0 as the distance since these are not neighbour-based
tokenizer,
)
@@ -495,7 +609,7 @@ def neighbour_non_activation_windows(
# If there are no available indices, skip this neighbour
if activations.numel() == 0:
continue
token_windows, _ = pool_max_activation_windows(
token_windows, token_activations = pool_max_activation_windows(
activations=activations,
tokens=reshaped_tokens,
ctx_indices=available_ctx_indices,
Expand All @@ -508,12 +622,23 @@ def neighbour_non_activation_windows(
examples_used = len(token_windows)
all_examples.extend(
prepare_non_activating_examples(
token_windows, -neighbour.distance, tokenizer
token_windows,
token_activations, # activations of neighbour
-neighbour.distance,
tokenizer,
)
)
number_examples += examples_used
if len(all_examples) == 0:
print("No examples found")
print("No examples found, falling back to random non-activating examples")
non_active_indices = not_active_mask.nonzero(as_tuple=False).squeeze()

return random_non_activating_windows(
available_indices=non_active_indices,
reshaped_tokens=reshaped_tokens,
n_not_active=n_not_active,
tokenizer=tokenizer,
)
return all_examples


@@ -554,6 +679,7 @@ def random_non_activating_windows(

return prepare_non_activating_examples(
toks,
torch.zeros_like(toks), # there is no way to define these activations
-1.0,
tokenizer,
)
4 changes: 2 additions & 2 deletions delphi/latents/latents.py
@@ -75,7 +75,7 @@ class Example:
activations: Float[Tensor, "ctx_len"]
"""Activation values for the input sequence."""

str_tokens: list[str]
str_tokens: list[str] | None = None
"""Tokenized input sequence as strings."""

normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None
@@ -134,7 +134,7 @@ class LatentRecord:
train: list[ActivatingExample] = field(default_factory=list)
"""Training examples."""

test: list[ActivatingExample] | list[list[Example]] = field(default_factory=list)
test: list[ActivatingExample] = field(default_factory=list)
"""Test examples."""

neighbours: list[Neighbour] = field(default_factory=list)
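Because `str_tokens` now defaults to `None`, an `Example` can be constructed before its tokens are decoded. A minimal sketch, assuming `Example` is a dataclass importable from this module:

import torch

from delphi.latents.latents import Example  # import path assumed

ex = Example(tokens=torch.tensor([101, 2054, 102]), activations=torch.zeros(3))
assert ex.str_tokens is None  # decoded later, e.g. via tokenizer.batch_decode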