Skip to content

Commit 9520859

Browse files
authored
Merge pull request #49 from HumanCompatibleAI/fix-probe-benchmark
Fix benchmark calculation when probing with attribute func
2 parents b01fefe + a73e603 commit 9520859

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/reward_preprocessing/scripts/train_probe.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def _sum_vecs(vec1, vec2):
5757
return list(map(operator.add, vec1, vec2))
5858

5959

60+
def compose_funcs(f, g):
61+
return lambda x: f(g(x))
62+
63+
6064
def get_mean_attr_val(list_attr_vals):
6165
num_vals = len(list_attr_vals)
6266
if isinstance(list_attr_vals[0], list):
@@ -88,14 +92,20 @@ def benchmark_accuracy(
8892
use_next_info: bool,
8993
attributes: Union[str, List[str]],
9094
attr_cap: Optional[float],
95+
attr_func: Optional[Union[Callable[..., float], Callable[..., List[float]]]],
9196
):
9297
"""Determine the MSE from always guessing the mean value of the attributes."""
9398
attr_list = attributes if isinstance(attributes, list) else [attributes]
9499
mse = 0
100+
attr_func_ = attr_func if attr_func is not None else (lambda x: x)
95101
for attr in attr_list:
96-
attr_vals = list(
97-
map(lambda x: x["next_infos" if use_next_info else "infos"][attr], dataset)
98-
)
102+
103+
def get_attr(x):
104+
field = "next_infos" if use_next_info else "infos"
105+
return x[field][attr]
106+
107+
attr_vals = list(map(compose_funcs(attr_func_, get_attr), dataset))
108+
99109
if attr_cap is not None:
100110
if attr_cap <= 0:
101111
raise ValueError("Attribute cap must be positive")

0 commit comments

Comments
 (0)