@@ -27,9 +27,9 @@ def compute_auc(df: pd.DataFrame) -> float | None:
     if not df.probability.nunique():
        return None
 
-    df = df[df.probability.notna()]
+    valid_df = df[df.probability.notna()]
 
-    return roc_auc_score(df.activating, df.probability)  # type: ignore
+    return roc_auc_score(valid_df.activating, valid_df.probability)  # type: ignore
 
 
 def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path):
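Side note on the filtering above: binding the filtered frame to valid_df instead of rebinding the df parameter keeps the function's input distinct from its working copy, and the notna() filter itself is load-bearing, since scikit-learn's roc_auc_score raises a ValueError on NaN inputs. A minimal standalone sketch with toy data (not the repo's real columns beyond the two names used here):

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

df = pd.DataFrame({"activating": [1, 0, 1, 0],
                   "probability": [0.9, 0.2, np.nan, 0.4]})

valid_df = df[df.probability.notna()]  # drop the NaN-scored row before scoring
print(roc_auc_score(valid_df.activating, valid_df.probability))  # 1.0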
@@ -49,10 +49,10 @@ def plot_roc_curve(df: pd.DataFrame, out_dir: Path):
         return
 
     # filter out NaNs
-    df = df[df.probability.notna()]
+    valid_df = df[df.probability.notna()]
 
-    fpr, tpr, _ = roc_curve(df.activating, df.probability)
-    auc = roc_auc_score(df.activating, df.probability)
+    fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability)
+    auc = roc_auc_score(valid_df.activating, valid_df.probability)
     fig = go.Figure(
         data=[
             go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"),
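For reference, a toy sketch (invented values) of the roc_curve / roc_auc_score pair used above: roc_curve returns parallel arrays of false- and true-positive rates, which the Scatter trace plots directly; the third return value is the decision thresholds, discarded here.

from sklearn.metrics import roc_auc_score, roc_curve

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]

fpr, tpr, _ = roc_curve(y_true, y_score)  # thresholds discarded
auc = roc_auc_score(y_true, y_score)      # 0.75 on this toy data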
@@ -173,6 +173,19 @@ def parse_score_file(path: Path) -> pd.DataFrame:
     return pd.concat(latent_dfs, ignore_index=True), counts
 
 
+def get_metrics(latent_df: pd.DataFrame) -> pd.DataFrame:
+    processed_rows = []
+    for score_type, group_df in latent_df.groupby("score_type"):
+        conf = compute_confusion(group_df)
+        class_m = compute_classification_metrics(conf)
+        auc = compute_auc(group_df)
+
+        row = {"score_type": score_type, **conf, **class_m, "auc": auc}
+        processed_rows.append(row)
+
+    return pd.DataFrame(processed_rows)
+
+
 def log_results(scores_path: Path, viz_path: Path, modules: list[str]):
     import_plotly()
 
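The row construction in get_metrics leans on dict unpacking: each helper returns a flat dict of metrics, and ** merges them into a single row per score_type, ready for pd.DataFrame. A hypothetical sketch of the pattern, with invented keys and values standing in for whatever compute_confusion and compute_classification_metrics actually return:

conf = {"true_positives": 40, "false_negatives": 10}   # shape assumed
class_m = {"precision": 0.8, "recall": 0.8}            # shape assumed
row = {"score_type": "fuzz", **conf, **class_m, "auc": 0.91}
# -> one flat dict per score_type, collected into pd.DataFrame(processed_rows)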
@@ -187,17 +200,7 @@ def log_results(scores_path: Path, viz_path: Path, modules: list[str]):
 
     plot_roc_curve(latent_df, viz_path)
 
-    # Produce statistics averaged over layers and latents
-    processed_rows = []
-    for score_type, group_df in latent_df.groupby("score_type"):
-        conf = compute_confusion(group_df)
-        class_m = compute_classification_metrics(conf)
-        auc = compute_auc(group_df)
-
-        row = {"score_type": score_type, **conf, **class_m, "auc": auc}
-        processed_rows.append(row)
-
-    processed_df = pd.DataFrame(processed_rows)
+    processed_df = get_metrics(latent_df)
 
     plot_accuracy_hist(processed_df, viz_path)
 
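With the aggregation loop extracted, callers other than log_results can reuse it. A hypothetical usage sketch, assuming latent_df carries a score_type column plus whatever columns compute_confusion and compute_auc expect:

metrics_df = get_metrics(latent_df)  # one row of metrics per score_type
print(metrics_df.set_index("score_type").auc)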