diff --git a/delphi/clients/offline.py b/delphi/clients/offline.py
index 51397cc8..83b41517 100644
--- a/delphi/clients/offline.py
+++ b/delphi/clients/offline.py
@@ -45,6 +45,7 @@ def __init__(
         prefix_caching: bool = True,
         batch_size: int = 100,
         max_model_len: int = 4096,
+        number_tokens_to_generate: int = 500,
         num_gpus: int = 2,
         enforce_eager: bool = False,
         statistics: bool = False,
@@ -65,7 +66,7 @@ def __init__(
             max_model_len=max_model_len,
             enforce_eager=enforce_eager,
         )
-        self.sampling_params = SamplingParams(max_tokens=500)
+        self.sampling_params = SamplingParams(max_tokens=number_tokens_to_generate)
         self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.batch_size = batch_size
         self.statistics = statistics
diff --git a/delphi/clients/openrouter.py b/delphi/clients/openrouter.py
index 80b920e1..404d0457 100644
--- a/delphi/clients/openrouter.py
+++ b/delphi/clients/openrouter.py
@@ -23,12 +23,16 @@ def __init__(
         model: str,
         api_key: str | None = None,
         base_url="https://openrouter.ai/api/v1/chat/completions",
+        max_tokens: int = 3000,
+        temperature: float = 1.0,
     ):
         super().__init__(model)
         self.headers = {"Authorization": f"Bearer {api_key}"}
         self.url = base_url
+        self.max_tokens = max_tokens
+        self.temperature = temperature
         timeout_config = httpx.Timeout(5.0)
         self.client = httpx.AsyncClient(timeout=timeout_config)
@@ -45,8 +49,10 @@ async def generate(  # type: ignore
         **kwargs,  # type: ignore
     ) -> Response:  # type: ignore
         kwargs.pop("schema", None)
-        max_tokens = kwargs.pop("max_tokens", 500)
-        temperature = kwargs.pop("temperature", 1.0)
+        # TODO: decide whether per-call kwargs should keep overriding these defaults.
+        # Currently only simulation passes generation kwargs.
+        max_tokens = kwargs.pop("max_tokens", self.max_tokens)
+        temperature = kwargs.pop("temperature", self.temperature)
         data = {
             "model": self.model,
             "messages": prompt,
@@ -58,7 +64,7 @@ async def generate(  # type: ignore
         for attempt in range(max_retries):
             try:
                 response = await self.client.post(
-                    url=self.url, json=data, headers=self.headers
+                    url=self.url, json=data, headers=self.headers, timeout=100
                 )
                 if raw:
                     return response.json()
diff --git a/delphi/config.py b/delphi/config.py
index 6e751cf0..41f6a0aa 100644
--- a/delphi/config.py
+++ b/delphi/config.py
@@ -17,12 +17,15 @@ class SamplerConfig(Serializable):
     n_quantiles: int = 10
    """Number of latent activation quantiles to sample."""

-    train_type: Literal["top", "random", "quantiles"] = "quantiles"
+    train_type: Literal["top", "random", "quantiles", "mix"] = "quantiles"
    """Type of sampler to use for latent explanation generation."""

     test_type: Literal["quantiles"] = "quantiles"
    """Type of sampler to use for latent explanation testing."""

+    ratio_top: float = 0.2
+    """Ratio of top examples to use for training, if using mix."""
+

 @dataclass
 class ConstructorConfig(Serializable):
@@ -51,6 +54,12 @@ class ConstructorConfig(Serializable):
     n_non_activating: int = 50
    """Number of non-activating examples to be constructed."""

+    center_examples: bool = True
+    """Whether to center the examples on the latent activation.
+    If True, the examples will be centered on the max latent activation.
+    Otherwise, fixed windows will be used, and the activating tokens can be
+    anywhere in the window."""
+
     non_activating_source: Literal["random", "neighbours", "FAISS"] = "random"
    """Source of non-activating examples.
     Random uses non-activating contexts sampled from any non-activating window.
     Neighbours uses activating contexts
diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py
index 6ec841e5..5bf08761 100644
--- a/delphi/latents/constructors.py
+++ b/delphi/latents/constructors.py
@@ -32,6 +32,7 @@ def get_model(name: str, device: str = "cuda") -> SentenceTransformer:

 def prepare_non_activating_examples(
     tokens: Int[Tensor, "examples ctx_len"],
+    activations: Float[Tensor, "examples ctx_len"],
     distance: float,
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
 ) -> list[NonActivatingExample]:
@@ -45,12 +46,12 @@ def prepare_non_activating_examples(
     return [
         NonActivatingExample(
             tokens=toks,
-            activations=torch.zeros_like(toks, dtype=torch.float),
+            activations=acts,
             normalized_activations=None,
             distance=distance,
             str_tokens=tokenizer.batch_decode(toks),
         )
-        for toks in tokens
+        for toks, acts in zip(tokens, activations)
     ]


@@ -113,11 +114,9 @@ def pool_max_activation_windows(
     # Get the max activation magnitude within each context window
     max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)

-    # Deduplicate the context windows
     new_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
     new_tensor[inverses, index_within_ctx] = activations
-
     tokens = tokens[unique_ctx_indices]

     token_windows, activation_windows = _top_k_pools(
@@ -127,6 +126,114 @@ def pool_max_activation_windows(
     return token_windows, activation_windows


+def pool_centered_activation_windows(
+    activations: Float[Tensor, "examples"],
+    tokens: Float[Tensor, "windows seq"],
+    n_windows_per_batch: int,
+    ctx_indices: Float[Tensor, "examples"],
+    index_within_ctx: Float[Tensor, "examples"],
+    ctx_len: int,
+    max_examples: int,
+) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
+    """
+    Similar to pool_max_activation_windows, but it skips the ctx_indices that
+    are at the start or the end of a batch, because it always tries to keep a
+    buffer of ctx_len*3//4 on the left and ctx_len//4 on the right of the max
+    activation. To do this, for each window it joins the contexts of the
+    neighbouring windows of the same batch into a new context, which is then
+    cut to the correct shape around the max activation.
+
+    Args:
+        activations : The activations.
+        tokens : The input tokens.
+        ctx_indices : The context indices.
+        index_within_ctx : The index within the context.
+        ctx_len : The context length.
+        max_examples : The maximum number of examples.
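+        n_windows_per_batch : The number of context windows in each batch row,
+            used to detect and drop windows at batch boundaries.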
+ """ + + # Get unique context indices and their counts like in pool_max_activation_windows + unique_ctx_indices, inverses, lengths = torch.unique_consecutive( + ctx_indices, return_counts=True, return_inverse=True + ) + + # Get the max activation magnitude within each context window + max_buffer = torch.segment_reduce(activations, "max", lengths=lengths) + + # Get the top max_examples windows + k = min(max_examples, len(max_buffer)) + top_values, top_indices = torch.topk(max_buffer, k, sorted=True) + + # this tensor has the correct activations for each context window + temp_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype) + temp_tensor[inverses, index_within_ctx] = activations + + unique_ctx_indices = unique_ctx_indices[top_indices] + temp_tensor = temp_tensor[top_indices] + + # if a element in unique_ctx_indices is divisible by n_windows_per_batch it + # the start of a new batch, so we discard it + modulo = unique_ctx_indices % n_windows_per_batch + not_first_position = modulo != 0 + # remove also the elements that are at the end of the batch + not_last_position = modulo != n_windows_per_batch - 1 + mask = not_first_position & not_last_position + unique_ctx_indices = unique_ctx_indices[mask] + temp_tensor = temp_tensor[mask] + if len(unique_ctx_indices) == 0: + return torch.zeros(0, ctx_len), torch.zeros(0, ctx_len) + + # Vectorized operations for all windows at once + n_windows = len(unique_ctx_indices) + + # Create indices for previous, current, and next windows + prev_indices = unique_ctx_indices - 1 + next_indices = unique_ctx_indices + 1 + + # Create a tensor to hold all concatenated tokens + all_tokens = torch.cat( + [tokens[prev_indices], tokens[unique_ctx_indices], tokens[next_indices]], dim=1 + ) # Shape: [n_windows, ctx_len*3] + + # Create tensor for all activations + final_tensor = torch.zeros((n_windows, ctx_len * 3), dtype=activations.dtype) + final_tensor[:, ctx_len : ctx_len * 2] = ( + temp_tensor # Set current window activations + ) + + # Set previous window activations where available + prev_mask = torch.isin(prev_indices, unique_ctx_indices) + if prev_mask.any(): + prev_locations = torch.where( + unique_ctx_indices.unsqueeze(1) == prev_indices.unsqueeze(0) + )[1] + final_tensor[prev_mask, :ctx_len] = temp_tensor[prev_locations] + + # Set next window activations where available + next_mask = torch.isin(next_indices, unique_ctx_indices) + if next_mask.any(): + next_locations = torch.where( + unique_ctx_indices.unsqueeze(1) == next_indices.unsqueeze(0) + )[1] + final_tensor[next_mask, ctx_len * 2 :] = temp_tensor[next_locations] + + # Find max activation indices + max_activation_indices = torch.argmax(temp_tensor, dim=1) + ctx_len + + # Calculate left for all windows + left_positions = max_activation_indices - (ctx_len - ctx_len // 4) + + # Create index tensors for gathering + batch_indices = torch.arange(n_windows).unsqueeze(1) + pos_indices = torch.arange(ctx_len).unsqueeze(0) + gather_indices = left_positions.unsqueeze(1) + pos_indices + + # Gather the final windows + token_windows = all_tokens[batch_indices, gather_indices] + activation_windows = final_tensor[batch_indices, gather_indices] + return token_windows, activation_windows + + def constructor( record: LatentRecord, activation_data: ActivationData, @@ -142,7 +249,6 @@ def constructor( n_not_active = constructor_cfg.n_non_activating max_examples = constructor_cfg.max_examples min_examples = constructor_cfg.min_examples - # Get all positions where the latent is active flat_indices = 
         activation_data.locations[:, 0] * cache_ctx_len
         + activation_data.locations[:, 1]
     )
     ctx_indices = flat_indices // example_ctx_len
     index_within_ctx = flat_indices % example_ctx_len
+    n_windows_per_batch = tokens.shape[1] // example_ctx_len
     reshaped_tokens = tokens.reshape(-1, example_ctx_len)
     n_windows = reshaped_tokens.shape[0]
-
     unique_batch_pos = ctx_indices.unique()
-
     mask = torch.ones(n_windows, dtype=torch.bool)
     mask[unique_batch_pos] = False
     # Indices where the latent is not active
     non_active_indices = mask.nonzero(as_tuple=False).squeeze()
     activations = activation_data.activations
-
     # per context frequency
     record.per_context_frequency = len(unique_batch_pos) / n_windows

     # Add activation examples to the record in place
-    token_windows, act_windows = pool_max_activation_windows(
-        activations=activations,
-        tokens=reshaped_tokens,
-        ctx_indices=ctx_indices,
-        index_within_ctx=index_within_ctx,
-        ctx_len=example_ctx_len,
-        max_examples=max_examples,
-    )
+    if constructor_cfg.center_examples:
+        token_windows, act_windows = pool_max_activation_windows(
+            activations=activations,
+            tokens=reshaped_tokens,
+            ctx_indices=ctx_indices,
+            index_within_ctx=index_within_ctx,
+            ctx_len=example_ctx_len,
+            max_examples=max_examples,
+        )
+    else:
+        token_windows, act_windows = pool_centered_activation_windows(
+            activations=activations,
+            tokens=reshaped_tokens,
+            n_windows_per_batch=n_windows_per_batch,
+            ctx_indices=ctx_indices,
+            index_within_ctx=index_within_ctx,
+            ctx_len=example_ctx_len,
+            max_examples=max_examples,
+        )
     # TODO: We might want to do this in the sampler
     # we are tokenizing examples that are not going to be used
     record.examples = [
         ActivatingExample(
             tokens=toks,
             activations=acts,
             normalized_activations=None,
-            str_tokens=tokenizer.batch_decode(toks),
         )
         for toks, acts in zip(token_windows, act_windows)
     ]
-
     if len(record.examples) < min_examples:
         # Not enough examples to explain the latent
         return None
@@ -420,6 +533,7 @@ def faiss_non_activation_windows(
     # Create non-activating examples
     return prepare_non_activating_examples(
         selected_tokens,
+        torch.zeros_like(selected_tokens),
         -1.0,  # Using -1.0 as the distance since these are not neighbour-based
         tokenizer,
     )
@@ -495,7 +609,7 @@ def neighbour_non_activation_windows(
         # If there are no available indices, skip this neighbour
         if activations.numel() == 0:
             continue
-        token_windows, _ = pool_max_activation_windows(
+        token_windows, token_activations = pool_max_activation_windows(
             activations=activations,
             tokens=reshaped_tokens,
             ctx_indices=available_ctx_indices,
@@ -508,12 +622,23 @@ def neighbour_non_activation_windows(
         examples_used = len(token_windows)
         all_examples.extend(
             prepare_non_activating_examples(
-                token_windows, -neighbour.distance, tokenizer
+                token_windows,
+                token_activations,  # activations of neighbour
+                -neighbour.distance,
+                tokenizer,
             )
         )
         number_examples += examples_used
     if len(all_examples) == 0:
-        print("No examples found")
+        print("No examples found, falling back to random non-activating examples")
+        non_active_indices = not_active_mask.nonzero(as_tuple=False).squeeze()
+
+        return random_non_activating_windows(
+            available_indices=non_active_indices,
+            reshaped_tokens=reshaped_tokens,
+            n_not_active=n_not_active,
+            tokenizer=tokenizer,
+        )
     return all_examples


@@ -554,6 +679,7 @@ def random_non_activating_windows(
     return prepare_non_activating_examples(
         toks,
+        torch.zeros_like(toks),  # there is no way to define these activations
         -1.0,
         tokenizer,
     )
diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py
index 03b35577..5e5611c2 100644
--- a/delphi/latents/latents.py
+++ b/delphi/latents/latents.py
@@ -75,7 +75,7 @@ class Example:
     activations: Float[Tensor, "ctx_len"]
    """Activation values for the input sequence."""

-    str_tokens: list[str]
+    str_tokens: list[str] | None = None
    """Tokenized input sequence as strings."""

     normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None
@@ -134,7 +134,7 @@ class LatentRecord:
     train: list[ActivatingExample] = field(default_factory=list)
    """Training examples."""

-    test: list[ActivatingExample] | list[list[Example]] = field(default_factory=list)
+    test: list[ActivatingExample] = field(default_factory=list)
    """Test examples."""

     neighbours: list[Neighbour] = field(default_factory=list)
diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py
index c8b30ac1..f5964036 100644
--- a/delphi/latents/loader.py
+++ b/delphi/latents/loader.py
@@ -127,6 +127,7 @@ def __init__(
         tokenizer: Optional[PreTrainedTokenizer | PreTrainedTokenizerFast] = None,
         modules: Optional[list[str]] = None,
         latents: Optional[dict[str, torch.Tensor]] = None,
+        neighbours_path: Optional[os.PathLike] = None,
     ):
         """
         Initialize a LatentDataset.
@@ -145,7 +146,7 @@ def __init__(
         self.buffers: list[TensorBuffer] = []
         self.all_data: dict[str, dict[int, ActivationData] | None] = {}
         self.tokens = None
-
+        self.neighbours_path = neighbours_path
         if modules is None:
             self.modules = os.listdir(raw_dir)
         else:
@@ -180,7 +181,10 @@ def __init__(

         if self.constructor_cfg.non_activating_source == "neighbours":
             # path is always going to end with /latents
-            neighbours_path = Path(raw_dir).parent / "neighbours"
+            if self.neighbours_path is None:
+                neighbours_path = Path(raw_dir).parent / "neighbours"
+            else:
+                neighbours_path = Path(self.neighbours_path)
             self.neighbours = self.load_neighbours(
                 neighbours_path, self.constructor_cfg.neighbours_type
             )
@@ -395,11 +399,10 @@ async def _aprocess_latent(self, latent_data: LatentData) -> LatentRecord | None
             activation_data=latent_data.activation_data,
             constructor_cfg=self.constructor_cfg,
             tokens=self.tokens,
-            all_data=self.all_data[latent_data.module],
             tokenizer=self.tokenizer,
+            all_data=self.all_data[latent_data.module],
         )
-
         # Not enough examples to explain the latent
         if record is None:
             return None
-        record = sampler(record, self.sampler_cfg)
+        record = sampler(record, self.sampler_cfg, self.tokenizer)
         return record
diff --git a/delphi/latents/samplers.py b/delphi/latents/samplers.py
index 4e1c9321..9bf170c8 100644
--- a/delphi/latents/samplers.py
+++ b/delphi/latents/samplers.py
@@ -1,6 +1,11 @@
 import random
 from typing import Literal

+from transformers import (
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
+
 from ..config import SamplerConfig
 from ..logger import logger
 from .latents import ActivatingExample, LatentRecord
@@ -52,9 +57,10 @@ def train(
     examples: list[ActivatingExample],
     max_activation: float,
     n_train: int,
-    train_type: Literal["top", "random", "quantiles"],
+    train_type: Literal["top", "random", "quantiles", "mix"],
     n_quantiles: int = 10,
     seed: int = 22,
+    ratio_top: float = 0.2,
 ):
     match train_type:
         case "top":
@@ -77,6 +83,16 @@ def train(
             selected_examples = split_quantiles(examples, n_quantiles, n_train)
             selected_examples = normalize_activations(selected_examples, max_activation)
             return selected_examples
+        case "mix":
+            top_examples = examples[: int(n_train * ratio_top)]
+            quantiles_examples = split_quantiles(
+                examples[int(n_train * ratio_top) :],
+                n_quantiles,
+                int(n_train * (1 - ratio_top)),
+            )
+            selected_examples = top_examples + quantiles_examples
+            selected_examples = normalize_activations(selected_examples, max_activation)
+            return selected_examples


 def test(
     examples: list[ActivatingExample],
     max_activation: float,
     n_test: int,
     n_quantiles: int,
-    test_type: Literal["quantiles", "activation"],
+    test_type: Literal["quantiles"],
 ):
     match test_type:
         case "quantiles":
             selected_examples = split_quantiles(examples, n_quantiles, n_test)
             selected_examples = normalize_activations(selected_examples, max_activation)
             return selected_examples
-        case "activation":
-            raise NotImplementedError("Activation sampling not implemented")


 def sampler(
     record: LatentRecord,
     cfg: SamplerConfig,
+    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
 ):
     examples = record.examples
     max_activation = record.max_activation
@@ -107,7 +122,12 @@ def sampler(
         cfg.n_examples_train,
         cfg.train_type,
         n_quantiles=cfg.n_quantiles,
+        ratio_top=cfg.ratio_top,
     )
+    # Moved tokenization to sampler to avoid tokenizing
+    # examples that are not going to be used
+    for example in _train:
+        example.str_tokens = tokenizer.batch_decode(example.tokens)
     record.train = _train
     if cfg.n_examples_test > 0:
         _test = test(
@@ -117,5 +137,7 @@ def sampler(
             cfg.n_quantiles,
             cfg.test_type,
         )
+        for example in _test:
+            example.str_tokens = tokenizer.batch_decode(example.tokens)
         record.test = _test
     return record
diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py
index b031d858..747db837 100644
--- a/delphi/scorers/__init__.py
+++ b/delphi/scorers/__init__.py
@@ -1,6 +1,8 @@
 from .classifier.detection import DetectionScorer
 from .classifier.fuzz import FuzzingScorer
+from .classifier.intruder import IntruderScorer
 from .embedding.embedding import EmbeddingScorer
+from .embedding.example_embedding import ExampleEmbeddingScorer
 from .scorer import Scorer
 from .simulator.oai_simulator import OpenAISimulator
 from .surprisal.surprisal import SurprisalScorer
@@ -12,4 +14,6 @@
     "Scorer",
     "SurprisalScorer",
     "EmbeddingScorer",
+    "IntruderScorer",
+    "ExampleEmbeddingScorer",
 ]
diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py
index 76b43dd3..9db29e6f 100644
--- a/delphi/scorers/classifier/classifier.py
+++ b/delphi/scorers/classifier/classifier.py
@@ -3,7 +3,7 @@
 import random
 import re
 from abc import abstractmethod
-from typing import Literal
+from typing import Any, Literal

 import numpy as np

@@ -21,6 +21,7 @@ def __init__(
         verbose: bool,
         n_examples_shown: int,
         log_prob: bool,
+        seed: int = 42,
         **generation_kwargs,
     ):
         """
@@ -41,24 +42,25 @@ def __init__(
         self.n_examples_shown = n_examples_shown
         self.generation_kwargs = generation_kwargs
         self.log_prob = log_prob
+        self.rng = random.Random(seed)

-    async def __call__(  # type: ignore
-        self,  # type: ignore
-        record: LatentRecord,  # type: ignore
-    ) -> ScorerResult:  # type: ignore
+    async def __call__(
+        self,
+        record: LatentRecord,
+    ) -> ScorerResult:
         samples = self._prepare(record)

-        random.shuffle(samples)
+        self.rng.shuffle(samples)

-        samples = self._batch(samples)
+        batched_samples = self._batch(samples)
         results = await self._query(
             record.explanation,
-            samples,
+            batched_samples,
         )

         return ScorerResult(record=record, score=results)

     @abstractmethod
-    def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
+    def _prepare(self, record: LatentRecord) -> list[Sample]:
         pass

     async def _query(
@@ -130,7 +132,9 @@ async def _generate(
         )
         return results

-    def _parse(self, string, logprobs=None):
+    def _parse(
+        self,
+        string: str, logprobs: list[float] | None = None
+    ) -> tuple[list[bool], list[float] | list[None]]:
         """Extract binary predictions and probabilities from a string
         and optionally its token logprobs."""
         # Matches the first instance of text enclosed in square brackets
@@ -148,7 +152,7 @@

         return predictions, probabilities

-    def _parse_logprobs(self, logprobs: list):
+    def _parse_logprobs(self, logprobs: list[Any]) -> list[float]:
         """
         Extracts normalized probabilities of '1' vs '0' tokens from the top n
         log probabilities for each token in a response string of form '[x, x, x, ...]'.
@@ -204,7 +208,7 @@
     def prompt(self, examples: str, explanation: str) -> list[dict]:
         pass

-    def _batch(self, samples):
+    def _batch(self, samples: list[Sample]) -> list[list[Sample]]:
         return [
             samples[i : i + self.n_examples_shown]
             for i in range(0, len(samples), self.n_examples_shown)
diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py
index bd78fbdf..d2ff9d44 100644
--- a/delphi/scorers/classifier/detection.py
+++ b/delphi/scorers/classifier/detection.py
@@ -1,7 +1,7 @@
 from ...clients.client import Client
 from ...latents import LatentRecord
 from .classifier import Classifier
-from .prompts.detection_prompt import prompt
+from .prompts.detection_prompt import prompt as detection_prompt
 from .sample import Sample, examples_to_samples

@@ -40,7 +40,7 @@ def __init__(
         )

     def prompt(self, examples: str, explanation: str) -> list[dict]:
-        return prompt(examples, explanation)
+        return detection_prompt(examples, explanation)

     def _prepare(self, record: LatentRecord) -> list[Sample]:  # type: ignore
         """
diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py
index 667db798..2c6c558b 100644
--- a/delphi/scorers/classifier/fuzz.py
+++ b/delphi/scorers/classifier/fuzz.py
@@ -1,13 +1,14 @@
 from math import ceil
+from typing import Literal

 import torch

 from ...clients.client import Client
 from ...latents import LatentRecord
-from ...latents.latents import ActivatingExample
+from ...latents.latents import ActivatingExample, NonActivatingExample
 from ..scorer import Scorer
 from .classifier import Classifier
-from .prompts.fuzz_prompt import prompt
+from .prompts.fuzz_prompt import prompt as fuzz_prompt
 from .sample import Sample, examples_to_samples

@@ -22,6 +23,7 @@ def __init__(
         threshold: float = 0.3,
         log_prob: bool = False,
         temperature: float = 0.0,
+        fuzz_type: Literal["default", "active"] = "default",
         **generation_kwargs,
     ):
         """
@@ -47,9 +49,10 @@ def __init__(
         )

         self.threshold = threshold
+        self.fuzz_type = fuzz_type

     def prompt(self, examples: str, explanation: str) -> list[dict]:
-        return prompt(examples, explanation)
+        return fuzz_prompt(examples, explanation)

     def mean_n_activations_ceil(self, examples: list[ActivatingExample]):
         """
@@ -61,24 +64,57 @@ def mean_n_activations_ceil(self, examples: list[ActivatingExample]):

         return ceil(avg)

+    def _convert_to_non_activating(
+        self, examples: list[ActivatingExample]
+    ) -> list[NonActivatingExample]:
+        """
+        Convert a list of activating examples to a list of non-activating examples.
+        """
+        return [
+            NonActivatingExample(
+                tokens=example.tokens,
+                activations=example.activations,
+                str_tokens=example.str_tokens,
+                normalized_activations=example.normalized_activations,
+                distance=-1,
+            )
+            for example in examples
+        ]
+
     def _prepare(self, record: LatentRecord) -> list[Sample]:  # type: ignore
         """
         Prepare and shuffle a list of samples for classification.
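+        With fuzz_type "default", non-activating samples come from
+        record.not_active: contrastive examples (those with positive
+        activations) keep their own highlights, while purely non-activating
+        ones get n_incorrect random tokens highlighted. With "active",
+        activating test examples are converted to non-activating samples
+        whose non-activating tokens are highlighted.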
""" assert len(record.test) > 0, "No test records found" - n_incorrect = self.mean_n_activations_ceil(record.test) # type: ignore + n_incorrect = self.mean_n_activations_ceil(record.test) - if len(record.not_active) > 0: + if self.fuzz_type == "default": + assert len(record.not_active) > 0, "No non-activating examples found" + # check if non_activating examples have any activations > 0 + # if they do they are contrastive examples + if (record.not_active[0].activations > 0).any(): + samples = examples_to_samples( + record.not_active, + n_incorrect=0, + highlighted=True, + ) + else: + # if they don't we use randomly highlight n_incorrect tokens + samples = examples_to_samples( + record.not_active, + n_incorrect=n_incorrect, + highlighted=True, + ) + elif self.fuzz_type == "active": + # hard uses activating examples and + # highlights non active tokens + extras = self._convert_to_non_activating(record.test) samples = examples_to_samples( - record.not_active, + extras, n_incorrect=n_incorrect, highlighted=True, ) - - else: - samples = [] - samples.extend( examples_to_samples( record.test, # type: ignore diff --git a/delphi/scorers/classifier/intruder.py b/delphi/scorers/classifier/intruder.py new file mode 100644 index 00000000..708f7a89 --- /dev/null +++ b/delphi/scorers/classifier/intruder.py @@ -0,0 +1,340 @@ +import asyncio +import re +from dataclasses import dataclass +from typing import Literal + +from beartype.typing import Sequence + +from ...clients.client import Client +from ...latents import ActivatingExample, Example, LatentRecord, NonActivatingExample +from ...logger import logger +from .classifier import Classifier, ScorerResult +from .prompts.intruder_prompt import prompt as intruder_prompt +from .sample import _prepare_text + + +@dataclass +class IntruderSentence: + """ + A sample for an intruder sentence experiment. + """ + + examples: list[str] + intruder_index: int + chosen_quantile: int + activations: list[list[float]] + tokens: list[list[str]] + intruder_distance: float + + +@dataclass +class IntruderResult: + """ + Result of an intruder experiment. + """ + + interpretation: str = "" + sample: IntruderSentence | None = None + prediction: int = 0 + correct_index: int = -1 + correct: bool = False + + +class IntruderScorer(Classifier): + name = "intruder" + + def __init__( + self, + client: Client, + verbose: bool = False, + n_examples_shown: int = 1, + temperature: float = 0.0, + cot: bool = False, + type: Literal["default", "internal"] = "default", + seed: int = 42, + **generation_kwargs, + ): + """ + Initialize a IntruderScorer. + + Args: + client: The client to use for generation. + tokenizer: The tokenizer used to cache the tokens + verbose: Whether to print verbose output. 
+            n_examples_shown: The number of examples to show in the prompt,
+                a larger number can both leak information and make
+                it harder for models to generate answers in the correct format
+            temperature: The temperature to use for generation
+            cot: Whether to use chain-of-thought few-shot examples
+            type: The type of intruder task, either "default" or "internal"
+            seed: The seed for the random number generator
+            generation_kwargs: Additional generation kwargs
+        """
+        super().__init__(
+            client=client,
+            verbose=verbose,
+            n_examples_shown=n_examples_shown,
+            temperature=temperature,
+            seed=seed,
+            **generation_kwargs,
+        )
+        self.type = type
+        if type not in ["default", "internal"]:
+            raise ValueError("Type must be either 'default' or 'internal'")
+        self.cot = cot
+
+    def prompt(self, examples: str) -> list[dict]:
+        return intruder_prompt(examples, cot=self.cot)
+
+    async def __call__(
+        self,
+        record: LatentRecord,
+    ) -> ScorerResult:
+        samples = self._prepare_and_batch(record)
+
+        results = await self._query(
+            samples,
+        )
+
+        return ScorerResult(record=record, score=results)
+
+    def _count_words(self, examples: Sequence[Example]) -> dict[str, int]:
+        """
+        Count the words in the examples and return a dictionary of the counts.
+        If activating examples are provided, count activating tokens.
+        If non-activating examples are provided, count activating tokens if
+        they exist, otherwise count all tokens.
+        """
+        counts = {}
+        for example in examples:
+            str_tokens = example.str_tokens
+            if example.normalized_activations is not None:
+                acts = example.normalized_activations
+                # TODO: this is a hack instead of using a threshold
+                # select only acts that are larger than 2
+                acts[acts < 2] = 0
+            else:
+                acts = example.activations
+            if acts.max() > 0:
+                wanted_indices = acts.nonzero()
+                for index in wanted_indices:
+                    if str_tokens[index] not in counts:
+                        counts[str_tokens[index]] = 0
+                    counts[str_tokens[index]] += 1
+            else:
+                for token in str_tokens:
+                    if token not in counts:
+                        counts[token] = 0
+                    counts[token] += 1
+        return counts
+
+    def _prepare(self, record: LatentRecord) -> None:
+        pass
+
+    def _get_quantiled_examples(
+        self, examples: list[ActivatingExample]
+    ) -> dict[int, list[ActivatingExample]]:
+        """
+        Group the examples by quantile.
+        """
+        quantiles = {}
+        for example in examples:
+            if example.quantile not in quantiles:
+                quantiles[example.quantile] = []
+            quantiles[example.quantile].append(example)
+        return quantiles
+
+    def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
+        """
+        Prepare and shuffle a list of samples for classification.
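+        One IntruderSentence is built per non-activating example: up to
+        n_examples_shown - 1 activating examples are drawn from a single
+        quantile, and the intruder is inserted among them at a random index.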
+ """ + + assert len(record.not_active) > 0, "No non-activating examples found" + batches = [] + quantiled_intruder_sentences = self._get_quantiled_examples(record.test) + + intruder_sentences = record.not_active + for i, intruder in enumerate(intruder_sentences): + # select each quantile equally + quantile_index = i % len(quantiled_intruder_sentences.keys()) + + active_examples = quantiled_intruder_sentences[quantile_index] + # if there are more examples than the number of examples to show, + # sample which examples to show + examples_to_show = min(self.n_examples_shown - 1, len(active_examples)) + example_indices = self.rng.sample( + range(len(active_examples)), examples_to_show + ) + active_examples = [active_examples[i] for i in example_indices] + + # convert the examples to strings + + # highlights the active tokens + majority_examples = [] + active_tokens = 0 + for example in active_examples: + text, _ = _prepare_text( + example, n_incorrect=0, threshold=0.3, highlighted=True + ) + majority_examples.append(text) + active_tokens += (example.activations > 0).sum().item() + active_tokens = int(active_tokens / len(active_examples)) + if self.type == "default": + # if example is contrastive, use the active tokens + # otherwise use the non-activating tokens + if intruder.activations.max() > 0: + n_incorrect = 0 + else: + n_incorrect = active_tokens + intruder_sentence, _ = _prepare_text( + intruder, + n_incorrect=n_incorrect, + threshold=0.3, + highlighted=True, + ) + elif self.type == "internal": + # randomly select a quantile to be the intruder, make sure it's not + # the same as the source quantile + intruder_quantile_index = self.rng.randint( + 0, len(quantiled_intruder_sentences.keys()) - 1 + ) + while intruder_quantile_index == quantile_index: + intruder_quantile_index = self.rng.randint( + 0, len(quantiled_intruder_sentences.keys()) - 1 + ) + posible_intruder_sentences = quantiled_intruder_sentences[ + intruder_quantile_index + ] + intruder_index_selected = self.rng.randint( + 0, len(posible_intruder_sentences) - 1 + ) + intruder = posible_intruder_sentences[intruder_index_selected] + # here the examples are activating, so we have to convert them + # to non-activating examples + non_activating_intruder = NonActivatingExample( + tokens=intruder.tokens, + activations=intruder.activations, + str_tokens=intruder.str_tokens, + distance=intruder.quantile, + ) + # we highlight the correct activating tokens though + intruder_sentence, _ = _prepare_text( + non_activating_intruder, + n_incorrect=0, + threshold=0.3, + highlighted=True, + ) + intruder = non_activating_intruder + + # select a random index to insert the intruder sentence + intruder_index = self.rng.randint(0, examples_to_show) + majority_examples.insert(intruder_index, intruder_sentence) + + activations = [example.activations.tolist() for example in active_examples] + tokens = [example.str_tokens for example in active_examples] + activations.insert(intruder_index, intruder.activations.tolist()) + tokens.insert(intruder_index, intruder.str_tokens) + + batches.append( + IntruderSentence( + examples=majority_examples, + intruder_index=intruder_index, + chosen_quantile=quantile_index, + activations=activations, + tokens=tokens, + intruder_distance=intruder.distance, + ) + ) + + return batches + + async def _query( + self, + samples: list[IntruderSentence], + ) -> list[IntruderResult]: + """ + Send and gather batches of samples to the model. 
+ """ + sem = asyncio.Semaphore(1) + + async def _process(sample): + async with sem: + result = await self._generate(sample) + return result + + tasks = [asyncio.create_task(_process(sample)) for sample in samples] + results = await asyncio.gather(*tasks) + + return results + + def _build_prompt( + self, + sample: IntruderSentence, + ) -> list[dict]: + """ + Prepare prompt for generation. + """ + + examples = "\n".join( + f"Example {i}: {example}" for i, example in enumerate(sample.examples) + ) + + return self.prompt(examples=examples) + + def _parse( + self, + string: str, + ) -> tuple[str, int]: + """The answer will be in the format interpretation [RESPONSE]: 1""" + # Find the first instance of the text with [RESPONSE]: + pattern = r"\[RESPONSE\]:" + match = re.search(pattern, string) + if match is None: + raise ValueError("No response found in string") + # get everything before the match + interpretation = string[: match.start()] + # get everything after the match + after = string[match.end() :] + # the response should be a single number + try: + prediction = int(after.strip()) + except ValueError: + raise ValueError("Response is not a number") + if prediction < 0 or prediction >= self.n_examples_shown: + raise ValueError("Response is out of range") + return interpretation, prediction + + async def _generate(self, sample: IntruderSentence) -> IntruderResult: + """ + Generate predictions for a batch of samples. + """ + + prompt = self._build_prompt(sample) + try: + response = await self.client.generate(prompt, **self.generation_kwargs) + except Exception as e: + logger.error(f"Error generating text: {e}") + response = None + + if response is None: + # default result is a error + return IntruderResult() + else: + + try: + interpretation, prediction = self._parse(response.text) + except Exception as e: + logger.error(f"Parsing selections failed: {e}") + # default result is a error + return IntruderResult() + + # check that the only prediction is the intruder + correct = prediction == sample.intruder_index + + result = IntruderResult( + interpretation=interpretation, + sample=sample, + prediction=prediction, + correct_index=sample.intruder_index, + correct=correct, + ) + + return result diff --git a/delphi/scorers/classifier/prompts/intruder_prompt.py b/delphi/scorers/classifier/prompts/intruder_prompt.py new file mode 100644 index 00000000..f9cd19a7 --- /dev/null +++ b/delphi/scorers/classifier/prompts/intruder_prompt.py @@ -0,0 +1,101 @@ +DSCORER_SYSTEM_PROMPT = """You are an intelligent and meticulous linguistics researcher doing a "intruder detection" task. + +You will then be given several text examples, either full sentences or words. Your task is to determine which examples should be classified as "intruder". +Some sentences will have words highlighted with <> tags. Do not overthink. + +There is only ever one intruder in the examples. + +You should write [RESPONSE]: followed by the index of the intruder. + +""" + +# https://www.neuronpedia.org/gpt2-small/6-res-jb/6048 +DSCORER_EXAMPLE_ONE = """Examples: + +Example 0:<|endoftext|>Getty ImagesĊĊPatriots tight end Rob Gronkowski had his bossâĢĻ +Example 1: Media Day 2015ĊĊLSU defensive end Isaiah Washington (94) speaks to the +Example 2: shown, is generally not eligible for ads. For example, videos about recent tragedies, +Example 3: line, with the left side âĢĶ namely tackle Byron Bell at tackle and guard Amini +""" +DSCORER_RESPONSE_ONE_COT = "There are 3 examples that are related to American football. 
+DSCORER_RESPONSE_ONE = "[RESPONSE]: 2"
+
+
+# https://www.neuronpedia.org/gpt2-small/8-res-jb/12654
+DSCORER_EXAMPLE_TWO = """Examples:
+
+Example 0: enact an individual health insurance mandate?âĢĿ, Pelosi's response was to dismiss both
+Example 1: climate, TomblinâĢĻs Chief of Staff Charlie Lorensen said.Ċ
+Example 2: no wonderworking relics, no true Body and Blood of Christ, no true Baptism
+Example 3:ĊĊIt has been devised by Director of Public Prosecutions (DPP)
+Example 4: and fair investigation not even include the Director of Athletics? · Finally, we believe the
+"""
+DSCORER_RESPONSE_TWO_COT = "I can see that there are several examples that have the word 'of' before a capital letter. The intruder is 0 because it does not."
+DSCORER_RESPONSE_TWO = "[RESPONSE]: 0"
+
+# https://www.neuronpedia.org/gpt2-small/8-res-jb/12654
+DSCORER_EXAMPLE_THREE = """Examples:
+
+Example 0: Climbing
+Example 1: running
+Example 2: swim
+Example 3: eating
+Example 4: cycling
+"""
+DSCORER_RESPONSE_THREE_COT = "All examples are related to activities, the first 3 and the last one being about sports and physical activities. Eating is not a sport or physical activity, so it is the intruder."
+
+DSCORER_RESPONSE_THREE = "[RESPONSE]: 3"
+
+DSCORER_EXAMPLE_FOUR = """
+Example 0: You<< guys>> support me in many other ways already and
+Example 1: birth control access but I assure that << guys>> in Kentucky
+Example 2: gig! I hope you<< guys>> LOVE her, and please be nice,
+Example 3:American, told Hannity that you<< guys>> are playing the race card.
+"""
+
+DSCORER_RESPONSE_FOUR_COT = "I see that all the examples have the word 'guys' highlighted. All examples except example 1 have the word 'you' before 'guys'; therefore, example 1 is the intruder."
+
+DSCORER_RESPONSE_FOUR = "[RESPONSE]: 1"
+
+GENERATION_PROMPT = """Examples:
+
+{examples}
+"""
+
+default = [
+    {"role": "user", "content": DSCORER_EXAMPLE_ONE},
+    {"role": "assistant", "content": DSCORER_RESPONSE_ONE},
+    {"role": "user", "content": DSCORER_EXAMPLE_TWO},
+    {"role": "assistant", "content": DSCORER_RESPONSE_TWO},
+    {"role": "user", "content": DSCORER_EXAMPLE_THREE},
+    {"role": "assistant", "content": DSCORER_RESPONSE_THREE},
+]
+
+default_cot = [
+    {"role": "user", "content": DSCORER_EXAMPLE_ONE},
+    {"role": "assistant", "content": DSCORER_RESPONSE_ONE_COT + DSCORER_RESPONSE_ONE},
+    {"role": "user", "content": DSCORER_EXAMPLE_TWO},
+    {"role": "assistant", "content": DSCORER_RESPONSE_TWO_COT + DSCORER_RESPONSE_TWO},
+    {"role": "user", "content": DSCORER_EXAMPLE_THREE},
+    {
+        "role": "assistant",
+        "content": DSCORER_RESPONSE_THREE_COT + DSCORER_RESPONSE_THREE,
+    },
+]
+
+
+def prompt(examples: str, cot: bool = False) -> list[dict]:
+    generation_prompt = GENERATION_PROMPT.format(examples=examples)
+
+    if cot:
+        examples_to_use = default_cot
+    else:
+        examples_to_use = default
+
+    prompt = [
+        {"role": "system", "content": DSCORER_SYSTEM_PROMPT},
+        *examples_to_use,
+        {"role": "user", "content": generation_prompt},
+    ]
+
+    return prompt
diff --git a/delphi/scorers/classifier/sample.py b/delphi/scorers/classifier/sample.py
index b89608ef..a89b2c59 100644
--- a/delphi/scorers/classifier/sample.py
+++ b/delphi/scorers/classifier/sample.py
@@ -82,12 +82,13 @@ def examples_to_samples(


 def _prepare_text(
-    example,
+    example: ActivatingExample | NonActivatingExample,
     n_incorrect: int,
     threshold: float,
     highlighted: bool,
 ) -> tuple[str, list[str]]:
     str_toks = example.str_tokens
+    assert str_toks is not None, "str_toks were not set"
     clean = "".join(str_toks)
     # Just return text if there's no highlighting
     if not highlighted:
@@ -117,7 +118,20 @@ def threshold_check(i):

     n_incorrect = min(n_incorrect, len(below_threshold))

-    random_indices = set(random.sample(below_threshold.tolist(), n_incorrect))
+    # The activating token is always at position ctx_len - ctx_len//4,
+    # so we always highlight that one, and if n_incorrect > 1
+    # we highlight n_incorrect - 1 random ones
+    token_pos = len(str_toks) - len(str_toks) // 4
+    if token_pos in below_threshold:
+        random_indices = [token_pos]
+        if n_incorrect > 1:
+            random_indices.extend(
+                random.sample(below_threshold.tolist(), n_incorrect - 1)
+            )
+    else:
+        random_indices = random.sample(below_threshold.tolist(), n_incorrect)
+
+    random_indices = set(random_indices)

     def check(i):
         return i in random_indices
diff --git a/delphi/scorers/embedding/example_embedding.py b/delphi/scorers/embedding/example_embedding.py
new file mode 100644
index 00000000..38b6f17b
--- /dev/null
+++ b/delphi/scorers/embedding/example_embedding.py
@@ -0,0 +1,281 @@
+import asyncio
+import random
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+
+from ...latents import LatentRecord, NonActivatingExample
+from ..classifier.sample import _prepare_text
+from ..scorer import Scorer, ScorerResult
+
+
+@dataclass
+class Batch:
+    """
+    A set of positive and negative examples to be used for scoring,
+    as well as a positive and negative query.
+ """ + + negative_examples: list[str] + """Non-activating examples used for negative explanation""" + + positive_examples: list[str] + """Activating examples used for positive explanation""" + + positive_query: str + """Activating example used for positive query""" + + negative_query: str + """Non-activating example used for negative query""" + + quantile_positive_query: int + """Quantile of the positive query""" + + distance_negative_query: float + """Distance of the negative query""" + + +@dataclass +class EmbeddingOutput: + """ + The output of the embedding scorer. + """ + + batch: Batch + """The set of examples and queries used for scoring""" + + delta_plus: float = 0 + """The difference in similarity between the positive query + and the positive examples, and the positive query and the negative examples""" + + delta_minus: float = 0 + """The difference in similarity between the negative query + and the positive examples, and the negative query and the negative examples""" + + +class ExampleEmbeddingScorer(Scorer): + """ + This scorer does not use explanations to score the examples. + Instead it embeds examples. + + + """ + + name = "example_embedding" + + def __init__( + self, + model, + verbose: bool = False, + method: Literal["default", "internal"] = "default", + number_batches: int = 20, + seed: int = 42, + **generation_kwargs, + ): + self.model = model + self.verbose = verbose + self.generation_kwargs = generation_kwargs + self.method = method + self.number_batches = number_batches + self.random = random.Random(seed) + + async def __call__( + self, + record: LatentRecord, + ) -> ScorerResult: + + # Create tasks with the positive and negative test examples + batches = self._create_batches(record, number_batches=self.number_batches) + + # Compute the probability of solving the task for each task + delta_tuples = [self.compute_batch_deltas(batch) for batch in batches] + score = [ + EmbeddingOutput(batch=batch, delta_plus=delta_plus, delta_minus=delta_minus) + for batch, (delta_plus, delta_minus) in zip(batches, delta_tuples) + ] + + return ScorerResult(record=record, score=score) + + def call_sync(self, record: LatentRecord) -> ScorerResult: + return asyncio.run(self.__call__(record)) + + def compute_batch_deltas(self, batch: Batch) -> tuple[float, float]: + """ + Compute the probability of solving the task. 
+ """ + with torch.no_grad(): + # Use the embedding model to embed all the examples + # Concatenate all inputs into a single list + all_inputs = ( + batch.negative_examples + + batch.positive_examples + + [batch.positive_query, batch.negative_query] + ) + # Encode everything at once + all_embeddings = self.model.encode(all_inputs) + + # Split the embeddings back into their components + n_neg = len(batch.negative_examples) + n_pos = len(batch.positive_examples) + negative_examples_embeddings = all_embeddings[:n_neg] + positive_examples_embeddings = all_embeddings[n_neg : n_neg + n_pos] + positive_query_embedding = all_embeddings[-2].unsqueeze(0) + negative_query_embedding = all_embeddings[-1].unsqueeze(0) + + # Compute the similarity between the query and the examples + negative_similarities = self.model.similarity( + negative_query_embedding, + torch.cat([negative_examples_embeddings, positive_examples_embeddings]), + ) + negative_negative_similarity = negative_similarities[ + :, : len(negative_examples_embeddings) + ] + negative_positive_similarity = negative_similarities[ + :, len(negative_examples_embeddings) : + ] + + positive_similarities = self.model.similarity( + positive_query_embedding, + torch.cat([negative_examples_embeddings, positive_examples_embeddings]), + ) + positive_negative_similarity = positive_similarities[ + :, : len(negative_examples_embeddings) + ] + positive_positive_similarity = positive_similarities[ + :, len(negative_examples_embeddings) : + ] + + delta_positive = ( + positive_positive_similarity.mean() + - positive_negative_similarity.mean() + ) + delta_negative = ( + negative_positive_similarity.mean() + - negative_negative_similarity.mean() + ) + + return delta_positive.item(), delta_negative.item() + + def _create_batches( + self, record: LatentRecord, number_batches: int = 20 + ) -> list[Batch]: + + # Get the positive and negative train examples, + # which are going to be used as "explanations" + positive_train_examples = record.train + + # Sample from the not_active examples + not_active_index = self.random.sample( + range(len(record.not_active)), len(positive_train_examples) + ) + negative_train_examples = [record.not_active[i] for i in not_active_index] + + # Get the positive and negative test examples, + # which are going to be used as "queries" + positive_test_examples = record.test + + not_active_test_index = [ + i for i in range(len(record.not_active)) if i not in not_active_index + ] + negative_test_examples = [record.not_active[i] for i in not_active_test_index] + + batches = [] + + for _ in range(number_batches): + # Prepare the positive query + positive_query_idx = self.random.sample( + range(len(positive_test_examples)), 1 + )[0] + positive_query = positive_test_examples[positive_query_idx] + n_active_tokens = int((positive_query.activations > 0).sum().item()) + positive_query_str, _ = _prepare_text( + positive_query, n_incorrect=0, threshold=0.3, highlighted=True + ) + # Prepare the negative query + if self.method == "default": + # In the default method, we just sample a random negative example + negative_query_idx = self.random.sample( + range(len(negative_test_examples)), 1 + )[0] + negative_query = negative_test_examples[negative_query_idx] + negative_query_str, _ = _prepare_text( + negative_query, + n_incorrect=n_active_tokens, + threshold=0.3, + highlighted=True, + ) + elif self.method == "internal": + # In the internal method, we sample a negative example + # that has a different quantile as the positive query + positive_query_quantile = 
+                negative_query_quantile = positive_query_quantile
+                # TODO: This is kinda ugly, but it probably doesn't matter
+                while negative_query_quantile == positive_query_quantile:
+                    negative_query_idx = self.random.sample(
+                        range(len(positive_test_examples)), 1
+                    )[0]
+                    negative_query_temp = positive_test_examples[negative_query_idx]
+                    negative_query_quantile = negative_query_temp.quantile
+
+                negative_query = NonActivatingExample(
+                    str_tokens=negative_query_temp.str_tokens,
+                    tokens=negative_query_temp.tokens,
+                    activations=negative_query_temp.activations,
+                    distance=negative_query_temp.quantile,
+                )
+                # Because it is a converted activating example, it will highlight
+                # the activating tokens
+                negative_query_str, _ = _prepare_text(
+                    negative_query, n_incorrect=0, threshold=0.3, highlighted=True
+                )
+
+            # Find all the positive_train_examples
+            # that have the same quantile as the positive_query
+            positive_examples = [
+                e
+                for e in positive_train_examples
+                if e.quantile == positive_query.quantile
+            ]
+            if len(positive_examples) > 10:
+                positive_examples = self.random.sample(positive_examples, 10)
+            positive_examples_str = [
+                _prepare_text(e, n_incorrect=0, threshold=0.3, highlighted=True)[0]
+                for e in positive_examples
+            ]
+
+            # negative examples
+            if self.method == "default":
+                # In the default method, we just sample random negative examples
+                negative_examples = self.random.sample(negative_train_examples, 10)
+                negative_examples_str = [
+                    _prepare_text(
+                        e, n_incorrect=n_active_tokens, threshold=0.3, highlighted=True
+                    )[0]
+                    for e in negative_examples
+                ]
+            elif self.method == "internal":
+                # In the internal method, we sample activating examples
+                # that have the same quantile as the negative_query
+                negative_examples = [
+                    e
+                    for e in positive_train_examples
+                    if e.quantile == negative_query.distance
+                ]
+                if len(negative_examples) > 10:
+                    negative_examples = self.random.sample(negative_examples, 10)
+                negative_examples_str = [
+                    _prepare_text(e, n_incorrect=0, threshold=0.3, highlighted=True)[0]
+                    for e in negative_examples
+                ]
+
+            batch = Batch(
+                negative_examples=negative_examples_str,
+                positive_examples=positive_examples_str,
+                positive_query=positive_query_str,
+                negative_query=negative_query_str,
+                quantile_positive_query=positive_query.quantile,
+                distance_negative_query=negative_query.distance,
+            )
+            batches.append(batch)
+        return batches
diff --git a/tests/e2e.py b/tests/e2e.py
index 35305ddb..4ec5d890 100644
--- a/tests/e2e.py
+++ b/tests/e2e.py
@@ -70,4 +70,4 @@ async def test():


 if __name__ == "__main__":
-    asyncio.run(test())
\ No newline at end of file
+    asyncio.run(test())
diff --git a/tests/test_latents/test_constructor.py b/tests/test_latents/test_constructor.py
index dd679551..bbea7ba7 100644
--- a/tests/test_latents/test_constructor.py
+++ b/tests/test_latents/test_constructor.py
@@ -113,6 +113,7 @@ def test_simple_cache(
             train_type=train_type,
             test_type="quantiles",
         ),
+        tokenizer=tokenizer,
     )
     assert len(record.train) <= n_examples
     assert len(record.test) <= n_examples