diff --git a/skythought/evals/tasks/mmlu/mmlu_handler.py b/skythought/evals/tasks/mmlu/mmlu_handler.py index 2d48c9bc..96c7b29e 100644 --- a/skythought/evals/tasks/mmlu/mmlu_handler.py +++ b/skythought/evals/tasks/mmlu/mmlu_handler.py @@ -49,7 +49,7 @@ def load_and_filter_dataset( return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] -class MMLUProTaskHandler(MMLUTaskHandler): +class MMLUProTaskHandler(TaskHandler): def __init__(self, task_config: TaskConfig): super().__init__(task_config) self.choices = [ @@ -71,9 +71,27 @@ def __init__(self, task_config: TaskConfig): "P", ] - def generate_prompt(self, prompt): + def generate_prompt(self, problem): + multiple_choice_string = self.get_multiple_choice_answers(problem) + prompt = problem["question"] + "\n" + multiple_choice_string return self.task_config.templating_parameters["template"].format(prompt=prompt) + def update_results(self, problem, response): + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + curr_res = self.check_correctness(problem, generation=response) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Solution is incorrect." + return response_entry + def check_correctness(self, problem, generation): pred = mmlu_pro_extract_answer(generation) answer = self.choices[problem["answer_index"]]