
Commit c437ef2

Merge pull request #193 from RobotSail/add-ragas
adds basic ragas eval
2 parents 2e7e405 + ab3d168

File tree

3 files changed: +489, -0 lines changed


requirements.txt

+1
@@ -10,3 +10,4 @@ pandas
 pandas-stubs
 lm-eval>=0.4.4
 httpx
+ragas

src/instructlab/eval/ragas.py

+264
@@ -0,0 +1,264 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, TypedDict

# Third Party
from langchain_community.chat_models import ChatOpenAI
from openai import Client as OpenAIClient
from openai.types.chat import ChatCompletionMessageParam
from pandas import DataFrame, read_json
from pydantic import BaseModel, ConfigDict, Field
from ragas.evaluation import EvaluationDataset, EvaluationResult, RunConfig, evaluate
from ragas.metrics import Metric
from ragas.metrics._domain_specific_rubrics import (  # the rubrics we must instantiate are located inside of a file marked as private
    DEFAULT_WITH_REFERENCE_RUBRICS,
    RubricsScore,
)

# Local
from .evaluator import Evaluator
from .logger_config import setup_logger

logger = setup_logger(__name__)

class Sample(TypedDict):
    """
    TypedDict of a sample that we accept when doing eval with Ragas.
    We specifically use TypedDict here to be flexible with the input data we accept.
    """

    # question
    user_input: str

    # model answer
    response: Optional[str]

    # golden answer
    reference: str
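
# For illustration only (not part of the original commit): a record matching
# `Sample` might look like this, e.g. one line of a jsonl dataset:
#
#   {
#       "user_input": "What is the capital of France?",
#       "response": "The capital of France is Paris.",
#       "reference": "Paris",
#   }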

# default system prompt we'll use when none is provided. Keep it private, as we don't intend this to be a public object
_DEFAULT_SYSTEM_PROMPT = """You are an advanced AI assistant designed to provide precise and accurate information.
Your primary goal is to answer queries with the most up-to-date and factual information available.
Focus on delivering clear, concise, and correct responses.
If you're uncertain about any aspect of the query, state your level of confidence and provide the most accurate information you can.
Your responses should prioritize accuracy over all other considerations."""

DEFAULT_SEED = 1337
DEFAULT_JUDGE_MODEL = "gpt-4o"

class ModelConfig(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    # name of the model to use.
    model_name: str

    # The system prompt to be used when applying the chat template.
    system_prompt: str = _DEFAULT_SYSTEM_PROMPT

    # "model randomness", i.e. the likelihood of sampling something other than the likeliest token
    temperature: float = Field(default=0.0, le=1.0, ge=0.0)

    # Max number of tokens to generate.
    max_tokens: int = 768

    # Random seed for reproducibility. Caution: this isn't supported by all model serving runtimes.
    seed: int = DEFAULT_SEED
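
# For illustration only (not part of the original commit): a student-model
# config for a locally served model might look like this (the model name
# below is hypothetical):
#
#   student_config = ModelConfig(model_name="my-student-model", max_tokens=512)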

class RagasEvaluator(Evaluator):
    # most basic implementation: we just assume that the user will bring the existing model responses
    name = "ragas"

    def __init__(
        self,
        student_model: ModelConfig | None = None,
        run_config: RunConfig | None = None,
        student_openai_client: OpenAIClient | None = None,
        judge_model_name: str = DEFAULT_JUDGE_MODEL,
        judge_openai_api_key: str | None = None,
    ):
        self.student_model = student_model
        self.run_config = run_config
        self.student_openai_client = student_openai_client
        self.judge_model_name = judge_model_name
        self.judge_openai_api_key = judge_openai_api_key

    @staticmethod
    def _validate_dataset(df: DataFrame):
        """
        Validates whether the given `df` is a valid dataset of `Sample` objects.

        Args:
            df (DataFrame): DataFrame containing the dataset to be evaluated.
        """
        # We have to hardcode these fields because the automated way of resolving the required fields from a TypedDict
        # is only included by default in Python 3.11+. For earlier versions, the `typing_extensions` package is required.
        # See: https://docs.python.org/3/whatsnew/3.11.html#pep-655-marking-individual-typeddict-items-as-required-or-not-required
        required_keys = {"user_input", "reference"}
        missing_keys = required_keys - set(df.columns)
        if missing_keys:
            raise ValueError(
                f"invalid dataset provided, missing the following keys: {', '.join(missing_keys)}"
            )

    def run(
        self,
        dataset: List[Sample] | Path,
        student_model: ModelConfig | None = None,
        run_config: RunConfig | None = None,
        student_openai_client: OpenAIClient | None = None,
        judge_model_name: str | None = None,
        judge_openai_api_key: str | None = None,
    ) -> EvaluationResult:
        """
        Evaluates the quality of model responses against a graded rubric.

        When the `dataset` lacks the `response` field, `student_model` must be provided
        in order to generate the answers.

        Args:
            dataset (List[Sample] | Path):
                Either a list of `Sample` objects or a path to a jsonl file containing
                records matching `Sample`.
            student_model (ModelConfig | None, optional):
                When this parameter is provided, we'll attempt to use the described model to
                generate responses to the given questions.
            run_config (RunConfig | None, optional):
                Configuration to use when running evaluations. If none is provided, then
                a default one is created with extremely permissive timeout and retry settings.
                This is because, by default, OpenAI tier-1 usage accounts have very low
                rate limits, resulting in heavy throttling during evaluations.
            student_openai_client (openai.Client | None, optional):
                The client to use when generating responses from the student model; must be compatible with the OpenAI API.
                This field is required when `student_model` is provided.
            judge_model_name (str | None, optional):
                Name of the OpenAI model to use as the judge model. Defaults to "gpt-4o" when none is specified.
            judge_openai_api_key (str | None, optional):
                The API key to use for evaluating the given dataset. When this isn't provided, `OPENAI_API_KEY` is read instead.

        Returns:
            EvaluationResult: The results of all evaluations performed by Ragas
        """
        judge_model_name = judge_model_name or self.judge_model_name
        judge_openai_api_key = judge_openai_api_key or self.judge_openai_api_key
        student_model = student_model or self.student_model
        run_config = run_config or self.run_config
        student_openai_client = student_openai_client or self.student_openai_client

        # ensure we are in the dataframe format
        input_df = None
        if isinstance(dataset, list):
            input_df = DataFrame(dataset)
        elif isinstance(dataset, Path):
            input_df = read_json(dataset, orient="records", lines=True)
        else:
            raise TypeError(f"invalid type of dataset: {type(dataset)}")

        # this should never happen, but pylint is not smart enough to detect it
        if TYPE_CHECKING:
            assert input_df is not None

        # ensure the dataset is in the format we expect
        self._validate_dataset(input_df)

        need_to_generate_responses = "response" not in input_df.columns
        if need_to_generate_responses:
            logger.debug(
                "`response` is missing from the input dataframe columns; generating responses from the student model is required."
            )
            if not student_model or not student_openai_client:
                raise ValueError(
                    "provided dataset doesn't contain the model `response`, but either `student_model` or `student_openai_client` wasn't provided for inference"
                )

        # if the student model was provided, then we always generate regardless
        if student_model:
            if not student_openai_client:
                raise ValueError(
                    "`student_model` was specified but `student_openai_client` was not provided"
                )
            input_df = self._generate_answers_from_model(
                input_df, student_model, student_openai_client
            )

        if not run_config:
            # we set extreme timeout/retry values by default since OpenAI tier-1 rate limits
            # are horrible and will result in half of our evaluation results being NaN or 0
            run_config = RunConfig(
                max_retries=120,
                max_wait=7200,
                seed=DEFAULT_SEED,
                timeout=3600,
            )

        metrics = self._get_metrics()
        evaluation_ds = EvaluationDataset.from_pandas(input_df)

        # gpt-4o remains the default judge for the foreseeable future,
        # for consistency of answers
        critic_lm = ChatOpenAI(model=judge_model_name, api_key=judge_openai_api_key)
        results = evaluate(
            dataset=evaluation_ds,
            batch_size=4,
            run_config=run_config,
            llm=critic_lm,
            metrics=metrics,
            show_progress=True,
        )
        return results

    def _generate_answers_from_model(
        self,
        questions: DataFrame,
        student_model: ModelConfig,
        student_openai_client: OpenAIClient,
    ) -> DataFrame:
        """
        Given a DataFrame containing a `user_input` column, generates responses from the given model
        and returns a new DataFrame containing its answers in the `response` column.
        """
        # initialize the response column to write into
        updated_df = questions.copy()
        updated_df["response"] = ""

        for i, qna in updated_df.iterrows():
            messages: List[ChatCompletionMessageParam] = [
                {
                    "role": "system",
                    "content": student_model.system_prompt,
                },
                {"role": "user", "content": qna["user_input"]},
            ]
            response = student_openai_client.chat.completions.create(
                messages=messages,
                model=student_model.model_name,
                # specify the seed so we can at least try to have some reproducibility when the client supports it
                seed=student_model.seed,
                max_tokens=student_model.max_tokens,
                temperature=student_model.temperature,
            )
            updated_df.at[i, "response"] = response.choices[0].message.content
        return updated_df

    @staticmethod
    def _get_metrics() -> List[Metric]:
        # default set of metrics
        return [
            RubricsScore(
                rubrics=DEFAULT_WITH_REFERENCE_RUBRICS,
            )
        ]
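
A minimal usage sketch of the new evaluator (illustrative only, not part of this commit; the file paths, endpoint URL, and model names below are assumptions):

# Illustrative usage; paths, endpoint, and model names are assumptions.
from pathlib import Path

from openai import Client

from instructlab.eval.ragas import ModelConfig, RagasEvaluator

# Case 1: the jsonl dataset already has a `response` column, so only the
# judge model (gpt-4o by default, authenticated via OPENAI_API_KEY) runs.
evaluator = RagasEvaluator()
results = evaluator.run(dataset=Path("eval_data.jsonl"))

# Case 2: the dataset has only `user_input` and `reference`, so responses
# are first generated from a student model served behind an
# OpenAI-compatible endpoint, then judged.
student_config = ModelConfig(model_name="my-student-model")
student_client = Client(base_url="http://localhost:8000/v1", api_key="none")
results = evaluator.run(
    dataset=Path("questions.jsonl"),
    student_model=student_config,
    student_openai_client=student_client,
)
print(results)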
