diff --git a/src/lemonade/tools/mmlu.py b/src/lemonade/tools/mmlu.py index 94a75da..71ce870 100644 --- a/src/lemonade/tools/mmlu.py +++ b/src/lemonade/tools/mmlu.py @@ -210,6 +210,13 @@ def run( state.save_stat(stat_units_name, "%") self.status_stats.append(stat_name) + # Calculate average of mmlu accuracy and display in the CLI + acc_avg = np.mean([accuracy_data["Accuracy"] for accuracy_data in summary_data]) + avg_stat_name = "avg_accuracy" + state.save_stat(avg_stat_name, float(acc_avg) * 100) + state.save_stat("accuracy_units", "%") + self.status_stats.append(avg_stat_name) + # Save accuracy results to CSV file summary_df = pd.DataFrame(summary_data) summary_df.to_csv(