|
import json
import os
import re
import traceback
import urllib.request
import uuid
from pathlib import Path
from typing import Optional

import numpy as np
from pydantic import BaseModel

import evals.metrics
from evals.api import CompletionFn
from evals.elsuite.rag_match import get_rag_dataset
from evals.elsuite.utils import ReactionDictMatching, ReactionDictMatchingSimple
from evals.record import RecorderBase, record_match
| 17 | + |
# Regexes for pulling a structured answer out of a model completion.
# Generic fenced code block: group 1 is the fence body.
code_pattern = r"```[\s\S]*?\n([\s\S]+?)\n```"
# ```json fenced block: group 1 is the JSON text.
json_pattern = r"```json[\s\S]*?\n([\s\S]+?)\n```"
# ```csv fenced block: group 1 is the CSV text.
csv_pattern = r"```csv[\s\S]*?\n([\s\S]+?)\n```"
# NOTE(review): "{index0}" looks like a str.format placeholder to be filled in
# with the table's first header cell before compiling — confirm at the call site.
table_pattern = r"\n({index0}[\s\S]+)\n[`]*"
# Markdown "[Download …](https://…)" link: group 1 is the https URL.
outlink_pattern = r"\[Download[a-zA-Z0-9 ]+?\]\((https://[a-zA-Z0-9_. /]+?)\)"
| 23 | + |
| 24 | + |
class FileSampleWithInput(BaseModel):
    """Schema for one eval sample: a prompt plus input/answer file names and links.

    All fields are optional; `eval_sample` below reads samples as plain dicts,
    so this model presumably documents/validates the expected sample shape —
    TODO confirm where it is instantiated.
    """

    input: Optional[str]  # prompt text, or a pre-built chat transcript in the raw dict
    file_name: Optional[str]  # local path of the input file
    file_link: Optional[str]  # URL of the input file
    answerfile_name: Optional[str]  # local path of the reference answer JSON
    answerfile_link: Optional[str]  # URL of the reference answer JSON
| 31 | + |
| 32 | + |
class ReactionExtract(evals.Eval):
    """Eval that asks a completion fn to extract reaction data from a file and
    scores the parsed JSON answer against a reference answer file.

    Per-sample metric: ``accuracy_leaves`` — leaf-level match rate between the
    model's parsed reaction dict and the reference ``inputs`` dict, as computed
    by ``ReactionDictMatchingSimple``.
    """

    def __init__(
        self,
        completion_fns: list[CompletionFn],
        samples_jsonl: str,
        *args,
        instructions: Optional[str] = "",
        **kwargs,
    ):
        """Store the samples path and optional system instructions.

        Raises:
            AssertionError: if more than 2 completion fns are supplied.
        """
        super().__init__(completion_fns, *args, **kwargs)
        # BUGFIX: the condition admits at most 2 fns, but the original message
        # claimed "supports 3" — the message now matches the check.
        assert len(completion_fns) < 3, "ReactionExtract supports at most 2 completion fns"
        self.samples_jsonl = samples_jsonl
        self.instructions = instructions

    def eval_sample(self, sample, rng):
        """Run one sample: query the completion fn, parse its answer, record a match.

        Returns a dict with the ``accuracy_leaves`` score (0 on parse failure).
        """
        assert isinstance(sample, dict)

        # Accept either a pre-built chat transcript or a bare user string.
        if isinstance(sample["input"], list):
            input_formatted = sample["input"]
        else:
            input_formatted = [{"role": "user", "content": sample["input"]}]
        if self.instructions:
            prompt = [{"role": "system", "content": self.instructions}] + input_formatted
        else:
            prompt = input_formatted

        result = self.completion_fn(
            prompt=prompt,
            temperature=0.0,
            file_name=sample["file_name"],
            file_link=sample["file_link"],
        )
        sampled = result.get_completions()[0]

        # Reference answer: the "inputs" sub-dict of the answer file.
        with open(sample["answerfile_name"], "r") as f:
            correct_answer = json.load(f)["inputs"]
        correct_str = json.dumps(correct_answer, indent=4)

        try:
            outlink_match = re.search(outlink_pattern, sampled)
            if outlink_match is not None:
                # The model returned a download link: fetch the JSON answer file.
                # BUGFIX: urlretrieve replaces os.system(f"wget {link} ..."),
                # which was open to shell injection via model-controlled text.
                link = outlink_match.group(1)
                fname = f"/tmp/LLMEvals_{uuid.uuid4()}.json"
                urllib.request.urlretrieve(link, fname)
                with open(fname, "r") as f:
                    answer = json.load(f)
            elif "json" in self.instructions:
                # The model returned an inline ```json fenced block.
                code = re.search(json_pattern, sampled).group()
                code_content = re.sub(json_pattern, r"\1", code)
                # Strip // comments, which are not valid JSON.
                code_content = re.sub(r"//.*", "", code_content)
                answer = json.loads(code_content)
            else:
                answer = {}
            picked_str = json.dumps(answer, indent=4)
            # Persist the parsed answer next to the reference file for inspection.
            with open(sample["answerfile_name"].replace(".json", "_out.json"), "w") as f:
                f.write(picked_str)
        except Exception:
            # Download/regex/JSON parsing failed: record a miss against the raw
            # completion and score this sample 0.
            print(Path(sample["file_name"]).stem)
            traceback.print_exc()
            record_match(
                prompt=prompt,
                correct=False,
                expected=correct_str,
                picked=sampled,
                file_name=sample["file_name"],
                jobtype="match_all",
            )
            return {"accuracy_leaves": 0}

        accuracy_leaves, df = ReactionDictMatchingSimple(correct_answer, answer, content="raw")
        record_match(
            prompt=prompt,
            correct=(accuracy_leaves == 1.0),
            expected=correct_str,
            picked=picked_str,
            file_name=sample["file_name"],
            jobtype="match_all",
        )
        return {"accuracy_leaves": accuracy_leaves}

    def run(self, recorder: RecorderBase):
        """Evaluate every sample and return the mean of each per-sample metric."""
        samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
        metrics_all_sample = self.eval_all_samples(recorder, samples)
        # Guard: indexing [0] would raise on an empty sample set.
        if not metrics_all_sample:
            return {}
        return {
            key: np.mean([sample_metrics[key] for sample_metrics in metrics_all_sample])
            for key in metrics_all_sample[0]
        }