
Commit 1a95cd6

fix e2e test
1 parent 11e10e5 commit 1a95cd6

5 files changed (+42, -29 lines)


.pre-commit-config.yaml (+4)

@@ -16,3 +16,7 @@ repos:
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
+  - repo: https://github.com/RobertCraigie/pyright-python
+    rev: v1.1.399
+    hooks:
+      - id: pyright

delphi/log/result_analysis.py (+19, -16)

@@ -27,9 +27,9 @@ def compute_auc(df: pd.DataFrame) -> float | None:
     if not df.probability.nunique():
         return None
 
-    df = df[df.probability.notna()]
+    valid_df = df[df.probability.notna()]
 
-    return roc_auc_score(df.activating, df.probability)  # type: ignore
+    return roc_auc_score(valid_df.activating, valid_df.probability)  # type: ignore
 
 
 def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path):
@@ -49,10 +49,10 @@ def plot_roc_curve(df: pd.DataFrame, out_dir: Path):
         return
 
     # filter out NANs
-    df = df[df.probability.notna()]
+    valid_df = df[df.probability.notna()]
 
-    fpr, tpr, _ = roc_curve(df.activating, df.probability)
-    auc = roc_auc_score(df.activating, df.probability)
+    fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability)
+    auc = roc_auc_score(valid_df.activating, valid_df.probability)
     fig = go.Figure(
         data=[
             go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"),
@@ -173,6 +173,19 @@ def parse_score_file(path: Path) -> pd.DataFrame:
     return pd.concat(latent_dfs, ignore_index=True), counts
 
 
+def get_metrics(latent_df: pd.DataFrame) -> pd.DataFrame:
+    processed_rows = []
+    for score_type, group_df in latent_df.groupby("score_type"):
+        conf = compute_confusion(group_df)
+        class_m = compute_classification_metrics(conf)
+        auc = compute_auc(group_df)
+
+        row = {"score_type": score_type, **conf, **class_m, "auc": auc}
+        processed_rows.append(row)
+
+    return pd.DataFrame(processed_rows)
+
+
 def log_results(scores_path: Path, viz_path: Path, modules: list[str]):
     import_plotly()
 
@@ -187,17 +200,7 @@ def log_results(scores_path: Path, viz_path: Path, modules: list[str]):
 
     plot_roc_curve(latent_df, viz_path)
 
-    # Produce statistics averaged over layers and latents
-    processed_rows = []
-    for score_type, group_df in latent_df.groupby("score_type"):
-        conf = compute_confusion(group_df)
-        class_m = compute_classification_metrics(conf)
-        auc = compute_auc(group_df)
-
-        row = {"score_type": score_type, **conf, **class_m, "auc": auc}
-        processed_rows.append(row)
-
-    processed_df = pd.DataFrame(processed_rows)
+    processed_df = get_metrics(latent_df)
 
     plot_accuracy_hist(processed_df, viz_path)
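The extracted get_metrics helper makes the per-score-type aggregation reusable outside log_results. A minimal sketch of calling it directly, assuming the results directory and hookpoint names below are placeholders, and assuming compute_classification_metrics reports an "accuracy" field (as the updated e2e test below relies on):

from pathlib import Path

from delphi.log.result_analysis import get_metrics, load_data

# Placeholder run directory and hookpoints for illustration only.
scores_path = Path("results/example_run/scores")
hookpoints = ["layers.5.mlp"]

# load_data returns the per-latent score DataFrame plus counts.
latent_df, _counts = load_data(scores_path, hookpoints)

# One row per score_type, combining confusion counts,
# classification metrics, and AUC.
metrics_df = get_metrics(latent_df)
print(metrics_df[["score_type", "accuracy", "auc"]])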

pyproject.toml (+2)

@@ -10,7 +10,9 @@ readme = "README.md"
 requires-python = ">=3.10"
 keywords = ["interpretability", "explainable-ai"]
 dependencies = [
+    "torch",
     "datasets",
+    "transformers",
     "orjson",
     "eai-sparsify",
     "safetensors",

pyrightconfig.json (+11)

@@ -0,0 +1,11 @@
+{
+    "include": ["delphi"],
+    "exclude": [
+        "**/node_modules",
+        "**/__pycache__"
+    ],
+    "reportMissingImports": "none",
+    "reportMissingModuleSource": "none",
+    "pythonVersion": "3.10",
+    "typeCheckingMode": "basic"
+}

tests/e2e.py (+6, -13)

@@ -6,7 +6,7 @@
 
 from delphi.__main__ import run
 from delphi.config import CacheConfig, ConstructorConfig, RunConfig, SamplerConfig
-from delphi.log.result_analysis import build_scores_df, latent_balanced_score_metrics
+from delphi.log.result_analysis import get_metrics, load_data
 
 
 async def test():
@@ -58,21 +58,14 @@ async def test():
     end_time = time.time()
     print(f"Time taken: {end_time - start_time} seconds")
 
-    # Performs better than random guessing
     scores_path = Path.cwd() / "results" / run_cfg.name / "scores"
-    hookpoint_firing_counts = torch.load(
-        Path.cwd() / "results" / run_cfg.name / "log" / "hookpoint_firing_counts.pt",
-        weights_only=True,
-    )
-    df = build_scores_df(scores_path, run_cfg.hookpoints, hookpoint_firing_counts)
-    for score_type in df["score_type"].unique():
-        score_df = df.query(f"score_type == '{score_type}'")
 
-        weighted_mean_metrics = latent_balanced_score_metrics(
-            score_df, score_type, verbose=False
-        )
+    latent_df, _ = load_data(scores_path, run_cfg.hookpoints)
+    processed_df = get_metrics(latent_df)
 
-        accuracy = weighted_mean_metrics["accuracy"]
+    # Performs better than random guessing
+    for score_type, df in processed_df.groupby("score_type"):
+        accuracy = df["accuracy"].mean()
         assert accuracy > 0.55, f"Score type {score_type} has an accuracy of {accuracy}"
 
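For reference, the new assertion reduces to a plain mean over the get_metrics output rather than the previous latent-balanced weighted metrics. A self-contained sketch of the check, with made-up score type names and accuracy values for illustration only:

import pandas as pd

# Stand-in for get_metrics output: one row per score type with an accuracy column.
processed_df = pd.DataFrame(
    {"score_type": ["detection", "fuzz"], "accuracy": [0.71, 0.64]}
)

for score_type, df in processed_df.groupby("score_type"):
    accuracy = df["accuracy"].mean()
    assert accuracy > 0.55, f"Score type {score_type} has an accuracy of {accuracy}"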
