
Commit ab3d168

chore: decouple tests into more atomic units
Signed-off-by: Oleg S <[email protected]>
1 parent c6b5a70 commit ab3d168

File tree

2 files changed: +220 -130 lines changed


src/instructlab/eval/ragas.py

Lines changed: 78 additions & 34 deletions
@@ -1,13 +1,14 @@
 # # SPDX-License-Identifier: Apache-2.0
 # Standard
 from pathlib import Path
-from typing import List, Optional, TypedDict
+from typing import TYPE_CHECKING, List, Optional, TypedDict
 
 # Third Party
 from langchain_community.chat_models import ChatOpenAI
 from openai import Client as OpenAIClient
+from openai.types.chat import ChatCompletionMessageParam
 from pandas import DataFrame, read_json
-from pydantic import BaseModel, ConfigDict, field_validator
+from pydantic import BaseModel, ConfigDict, Field
 from ragas.evaluation import EvaluationDataset, EvaluationResult, RunConfig, evaluate
 from ragas.metrics import Metric
 from ragas.metrics._domain_specific_rubrics import ( # the rubrics we must instantiate are located inside of a file marked as private
@@ -17,6 +18,9 @@
 
 # Local
 from .evaluator import Evaluator
+from .logger_config import setup_logger
+
+logger = setup_logger(__name__)
 
 
 class Sample(TypedDict):
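
Aside (illustration only, not part of this commit): the new module-level logger is what emits the debug message added further down in `run()`. Assuming `setup_logger` returns a standard `logging` logger named after the module path, that message can be surfaced as sketched below; the logger name is inferred from the module path, not taken from the commit.

import logging

# Assumption: setup_logger wires up a stdlib logger named after the module.
logging.basicConfig()
logging.getLogger("instructlab.eval.ragas").setLevel(logging.DEBUG)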
@@ -56,21 +60,14 @@ class ModelConfig(BaseModel):
     system_prompt: str = _DEFAULT_SYSTEM_PROMPT
 
     # "model randomness" aka likelihood of sampling something other than the likeliest token
-    temperature: float = 0.0
+    temperature: float = Field(default=0.0, le=1.0, ge=0.0)
 
     # Max amount of tokens to generate.
     max_tokens: int = 768
 
     # Random seed for reproducibility. Caution: this isn't supported by all model serving runtimes.
     seed: int = DEFAULT_SEED
 
-    @field_validator("temperature")
-    @classmethod
-    def check_temperature(cls, v: float) -> float:
-        if not 0.0 <= v <= 1.0:
-            raise ValueError("temperature must be between 0.0 and 1.0")
-        return v
-
 
 class RagasEvaluator(Evaluator):
     # most basic implementation, we just assume that the user will bring the existing model responses
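
For context (a sketch, not part of this commit): the removed `field_validator` is covered by the `ge`/`le` bounds on `Field`. Assuming pydantic v2 semantics, and using a hypothetical stand-in model rather than the real `ModelConfig`:

from pydantic import BaseModel, Field, ValidationError

class TempConfig(BaseModel):  # hypothetical stand-in, not the real ModelConfig
    temperature: float = Field(default=0.0, ge=0.0, le=1.0)

TempConfig(temperature=0.7)       # within bounds, accepted
try:
    TempConfig(temperature=1.5)   # rejected by the le=1.0 constraint
except ValidationError as err:
    print(err)                    # pydantic reports which bound was violated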
@@ -80,18 +77,42 @@ def __init__(
         self,
         student_model: ModelConfig | None = None,
         run_config: RunConfig | None = None,
-        openai_client: OpenAIClient | None = None,
+        student_openai_client: OpenAIClient | None = None,
+        judge_model_name: str = DEFAULT_JUDGE_MODEL,
+        judge_openai_api_key: str | None = None,
     ):
         self.student_model = student_model
         self.run_config = run_config
-        self.openai_client = openai_client
+        self.student_openai_client = student_openai_client
+        self.judge_model_name = judge_model_name
+        self.judge_openai_api_key = judge_openai_api_key
+
+    @staticmethod
+    def _validate_dataset(df: DataFrame):
+        """
+        Validates whether or not the given `df` is a valid dataset of `Sample` objects.
+
+        Args:
+            df (DataFrame): DataFrame containing the dataset to be evaluated.
+        """
+        # We have to hardcode these fields because the automated way of resolving the required fields from a TypedDict
+        # is only included by default in Python3.11+. For earlier versions, the `typing_extensions` package is required.
+        # See: https://docs.python.org/3/whatsnew/3.11.html#pep-655-marking-individual-typeddict-items-as-required-or-not-required
+        required_keys = {"user_input", "reference"}
+        missing_keys = required_keys - set(df.columns)
+        if missing_keys:
+            raise ValueError(
+                f"invalid dataset provided, missing the following keys: {', '.join(missing_keys)}"
+            )
 
     def run(
         self,
         dataset: List[Sample] | Path,
         student_model: ModelConfig | None = None,
         run_config: RunConfig | None = None,
-        openai_client: OpenAIClient | None = None,
+        student_openai_client: OpenAIClient | None = None,
+        judge_model_name: str | None = None,
+        judge_openai_api_key: str | None = None,
     ) -> EvaluationResult:
         """
         Evaluates the quality of model responses against a graded rubric.
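
For context (a sketch, not part of this commit): the hardcoded `required_keys` exists because `__required_keys__` only honors PEP 655 `Required`/`NotRequired` markers on Python 3.11+ (or via `typing_extensions`). Assuming Python 3.11+ and a stand-in TypedDict that mirrors the fields the validator checks, the automatic alternative the comment alludes to looks like this:

# Assumes Python 3.11+, where NotRequired is honored by __required_keys__.
from typing import NotRequired, TypedDict

class SampleSketch(TypedDict):  # stand-in; the real Sample is defined earlier in this module
    user_input: str
    reference: str
    response: NotRequired[str]

required_keys = set(SampleSketch.__required_keys__)  # {"user_input", "reference"}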
@@ -111,21 +132,31 @@ def run(
                 a default one is created containing extremely permissive settings when handling
                 timeouts. This is because by default, OpenAI tier-1 usage accounts have very high
                 rate limits resulting in heavy throttling during evaluations.
-            openai_client (openai.Client | None, optional):
+            student_openai_client (openai.Client | None, optional):
                 The client to use when generating questions from the student model, must be compatible with the OpenAI API.
                 This field is required when `student_model` is provided.
+            judge_model_name (str | None, optional):
+                Name of the OpenAI model to use as the judge model. Defaults to "gpt-4o" when none is specified.
+            judge_openai_api_key (str | None, optional):
+                The API key to use for evaluating the given dataset. When this isn't provided, `OPENAI_API_KEY` is read instead.
+
 
         Returns:
             EvaluationResult: The results of all evaluations performed by Ragas
         """
+        judge_model_name = (
+            judge_model_name if judge_model_name else self.judge_model_name
+        )
+        judge_openai_api_key = (
+            judge_openai_api_key if judge_openai_api_key else self.judge_openai_api_key
+        )
         student_model = student_model if student_model else self.student_model
         run_config = run_config if run_config else self.run_config
-        openai_client = openai_client if openai_client else self.openai_client
-
-        if not dataset:
-            raise ValueError(
-                "no dataset was provided, please specify the `dataset` argument"
-            )
+        student_openai_client = (
+            student_openai_client
+            if student_openai_client
+            else self.student_openai_client
+        )
 
         # ensure we are in the dataframe format
         input_df = None
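
A usage sketch (not part of this commit) of the renamed and newly added `run()` arguments, for the case where the dataset already carries a `response` column and therefore no student model or client is needed. The sample content and the API key are placeholders, not values from the commit.

from instructlab.eval.ragas import RagasEvaluator, Sample

samples: list[Sample] = [
    {
        "user_input": "What is the capital of France?",
        "reference": "The capital of France is Paris.",
        "response": "Paris is the capital of France.",
    }
]

evaluator = RagasEvaluator()
result = evaluator.run(
    dataset=samples,
    judge_model_name="gpt-4o",              # falls back to the evaluator's default when omitted
    judge_openai_api_key="sk-placeholder",  # falls back to OPENAI_API_KEY when omitted
)
print(result)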
@@ -137,22 +168,30 @@ def run(
             raise TypeError(f"invalid type of dataset: {type(dataset)}")
 
         # this should never happen, but pylint is not smart enough to detect it
-        assert input_df is not None
+        if TYPE_CHECKING:
+            assert input_df is not None
+
+        # ensure the dataset is in the format we expect it
+        self._validate_dataset(input_df)
 
         need_to_generate_questions = "response" not in input_df.columns
-        if need_to_generate_questions and (not student_model or not openai_client):
-            raise ValueError(
-                "provided dataset doesn't contain the model `response`, but either `student_model` or `openai_client` wasn't provided for inference"
+        if need_to_generate_questions:
+            logger.debug(
+                "`response` is missing in the input dataframe columns, generating questions from the model is required."
             )
+            if not student_model or not student_openai_client:
+                raise ValueError(
+                    "provided dataset doesn't contain the model `response`, but either `student_model` or `student_openai_client` wasn't provided for inference"
+                )
 
         # if the student model was provided then we always generate regardless
         if student_model:
-            if not openai_client:
+            if not student_openai_client:
                 raise ValueError(
-                    "`student_model` was specified but `openai_client` was not provided"
+                    "`student_model` was specified but `student_openai_client` was not provided"
                 )
             input_df = self._generate_answers_from_model(
-                input_df, student_model, openai_client
+                input_df, student_model, student_openai_client
             )
 
         if not run_config:
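
For context (a sketch, not part of this commit): wrapping the assert in `if TYPE_CHECKING:` keeps the narrowing hint visible to static analyzers while skipping it at runtime, because `TYPE_CHECKING` is `False` whenever the module actually runs. A standalone illustration of the pattern:

from typing import TYPE_CHECKING, Optional

def normalize(value: Optional[int]) -> int:
    if value is None:
        value = 0
    if TYPE_CHECKING:
        # Seen by type checkers and linters only; never executed at runtime.
        assert value is not None
    return value + 1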
@@ -170,7 +209,8 @@ def run(
 
         # we will be using gpt-4o for the foreseeable future, we hardcode this
         # for consistency of answers
-        critic_lm = ChatOpenAI(model=DEFAULT_JUDGE_MODEL)
+
+        critic_lm = ChatOpenAI(model=judge_model_name, api_key=judge_openai_api_key)
         results = evaluate(
             dataset=evaluation_ds,
             batch_size=4,
@@ -185,7 +225,7 @@ def _generate_answers_from_model(
         self,
         questions: DataFrame,
         student_model: ModelConfig,
-        openai_client: OpenAIClient,
+        student_openai_client: OpenAIClient,
     ) -> DataFrame:
         """
         Given a DataFrame containing `user_input` columns, generates responses from the given model
@@ -196,11 +236,14 @@ def _generate_answers_from_model(
         updated_df["response"] = ""
 
         for i, qna in updated_df.iterrows():
-            messages = [
-                student_model.system_prompt,
-                qna["user_input"],
+            messages: List[ChatCompletionMessageParam] = [
+                {
+                    "role": "system",
+                    "content": student_model.system_prompt,
+                },
+                {"role": "user", "content": qna["user_input"]},
             ]
-            response = openai_client.chat.completions.create(
+            response = student_openai_client.chat.completions.create(
                 messages=messages,
                 model=student_model.model_name,
                 # specify the seed so we can at least try to have some reproducibility when the clients support it
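
For context (a sketch, not part of this commit): the typed message dicts above follow the OpenAI chat-completions shape, which any OpenAI-compatible student endpoint should accept. A self-contained illustration of that call; the base URL, API key, and model name are placeholders:

from typing import List

from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint
messages: List[ChatCompletionMessageParam] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize retrieval-augmented generation in one sentence."},
]
response = client.chat.completions.create(
    model="student-model",  # placeholder model name
    messages=messages,
    seed=42,        # best-effort reproducibility when the serving runtime supports it
    max_tokens=768,
)
print(response.choices[0].message.content)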
@@ -211,7 +254,8 @@ def _generate_answers_from_model(
             updated_df.at[i, "response"] = response.choices[0].message.content
         return updated_df
 
-    def _get_metrics(self) -> List[Metric]:
+    @staticmethod
+    def _get_metrics() -> List[Metric]:
         # default set of metrics
         return [
             RubricsScore(
