@@ -57,6 +57,10 @@ def _sum_vecs(vec1, vec2):
57
57
return list (map (operator .add , vec1 , vec2 ))
58
58
59
59
60
+ def compose_funcs (f , g ):
61
+ return lambda x : f (g (x ))
62
+
63
+
60
64
def get_mean_attr_val (list_attr_vals ):
61
65
num_vals = len (list_attr_vals )
62
66
if isinstance (list_attr_vals [0 ], list ):
@@ -88,14 +92,20 @@ def benchmark_accuracy(
88
92
use_next_info : bool ,
89
93
attributes : Union [str , List [str ]],
90
94
attr_cap : Optional [float ],
95
+ attr_func : Optional [Union [Callable [..., float ], Callable [..., List [float ]]]],
91
96
):
92
97
"""Determine the MSE from always guessing the mean value of the attributes."""
93
98
attr_list = attributes if isinstance (attributes , list ) else [attributes ]
94
99
mse = 0
100
+ attr_func_ = attr_func if attr_func is not None else (lambda x : x )
95
101
for attr in attr_list :
96
- attr_vals = list (
97
- map (lambda x : x ["next_infos" if use_next_info else "infos" ][attr ], dataset )
98
- )
102
+
103
+ def get_attr (x ):
104
+ field = "next_infos" if use_next_info else "infos"
105
+ return x [field ][attr ]
106
+
107
+ attr_vals = list (map (compose_funcs (attr_func_ , get_attr ), dataset ))
108
+
99
109
if attr_cap is not None :
100
110
if attr_cap <= 0 :
101
111
raise ValueError ("Attribute cap must be positive" )
0 commit comments