Commit c6b5a70

feat: update the Ragas evaluator to have the OpenAI client as something that gets passed in to __init__
Signed-off-by: Oleg S <[email protected]>
1 parent: 04117dd

2 files changed: 75 additions, 53 deletions
src/instructlab/eval/ragas.py

Lines changed: 22 additions & 17 deletions
@@ -5,6 +5,7 @@
 
 # Third Party
 from langchain_community.chat_models import ChatOpenAI
+from openai import Client as OpenAIClient
 from pandas import DataFrame, read_json
 from pydantic import BaseModel, ConfigDict, field_validator
 from ragas.evaluation import EvaluationDataset, EvaluationResult, RunConfig, evaluate
@@ -16,7 +17,6 @@
 
 # Local
 from .evaluator import Evaluator
-from .mt_bench_common import get_openai_client
 
 
 class Sample(TypedDict):
@@ -49,19 +49,12 @@ class Sample(TypedDict):
 class ModelConfig(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
-    # URL of the OpenAI server where the model shall be hosted.
-    base_url: str
-
     # name of the model to use.
     model_name: str
 
     # The system prompt to be used when applying the chat template.
     system_prompt: str = _DEFAULT_SYSTEM_PROMPT
 
-    # We do NOT read from OPENAI_API_KEY for the student model for security reasons (e.g. sending the API key to another client)
-    # To provide an OpenAI key, you must set it here; else the default is used.
-    api_key: str = "no-api-key"
-
     # "model randomness" aka likelihood of sampling something other than the likeliest token
     temperature: float = 0.0
 
@@ -87,15 +80,18 @@ def __init__(
         self,
         student_model: ModelConfig | None = None,
         run_config: RunConfig | None = None,
+        openai_client: OpenAIClient | None = None,
     ):
         self.student_model = student_model
         self.run_config = run_config
+        self.openai_client = openai_client
 
     def run(
         self,
         dataset: List[Sample] | Path,
         student_model: ModelConfig | None = None,
         run_config: RunConfig | None = None,
+        openai_client: OpenAIClient | None = None,
     ) -> EvaluationResult:
         """
         Evaluates the quality of model responses against a graded rubric.
@@ -115,12 +111,16 @@ def run(
                 a default one is created containing extremely permissive settings when handling
                 timeouts. This is because by default, OpenAI tier-1 usage accounts have very high
                 rate limits resulting in heavy throttling during evaluations.
+            openai_client (openai.Client | None, optional):
+                The client to use when generating questions from the student model, must be compatible with the OpenAI API.
+                This field is required when `student_model` is provided.
 
         Returns:
             EvaluationResult: The results of all evaluations performed by Ragas
         """
         student_model = student_model if student_model else self.student_model
         run_config = run_config if run_config else self.run_config
+        openai_client = openai_client if openai_client else self.openai_client
 
         if not dataset:
             raise ValueError(
@@ -140,14 +140,20 @@ def run(
         assert input_df is not None
 
         need_to_generate_questions = "response" not in input_df.columns
-        if need_to_generate_questions and not student_model:
+        if need_to_generate_questions and (not student_model or not openai_client):
             raise ValueError(
-                "provided dataset doesn't contain the model `response`, but no `student_model` was provided for inference"
+                "provided dataset doesn't contain the model `response`, but either `student_model` or `openai_client` wasn't provided for inference"
             )
 
         # if the student model was provided then we always generate regardless
         if student_model:
-            input_df = self._generate_answers_from_model(input_df, student_model)
+            if not openai_client:
+                raise ValueError(
+                    "`student_model` was specified but `openai_client` was not provided"
+                )
+            input_df = self._generate_answers_from_model(
+                input_df, student_model, openai_client
+            )
 
         if not run_config:
             # we set extreme timeout/retry values by default since OpenAI tier-1 rate limits
@@ -176,16 +182,15 @@ def run(
         return results
 
     def _generate_answers_from_model(
-        self, questions: DataFrame, student_model: ModelConfig
+        self,
+        questions: DataFrame,
+        student_model: ModelConfig,
+        openai_client: OpenAIClient,
     ) -> DataFrame:
         """
         Given a DataFrame containing `user_input` columns, generates responses from the given model
        and returns a new DataFrame containing its answers in the `response` column.
         """
-        client = get_openai_client(
-            model_api_base=student_model.base_url, api_key=student_model.api_key
-        )
-
         # initialize response to write into
         updated_df = questions.copy()
         updated_df["response"] = ""
@@ -195,7 +200,7 @@ def _generate_answers_from_model(
                 student_model.system_prompt,
                 qna["user_input"],
             ]
-            response = client.chat.completions.create(
+            response = openai_client.chat.completions.create(
                 messages=messages,
                 model=student_model.model_name,
                 # specify the seed so we can at least try to have some reproducibility when the clients support it
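
Caller-side sketch of the updated API (illustrative, not part of this commit): the endpoint URL and API key that previously lived on ModelConfig as base_url/api_key now travel on the client object itself, and the client can either be bound at construction time or supplied per call. The endpoint, key, model name, and file path below are placeholders.

# Hypothetical usage of the new signature; endpoint, key, and paths are placeholders.
# Standard
from pathlib import Path

# Third Party
from openai import Client as OpenAIClient

# First Party
from instructlab.eval.ragas import ModelConfig, RagasEvaluator

# connection details now belong to the client, not to ModelConfig
client = OpenAIClient(
    base_url="https://your.model.endpoint.com/v1",  # placeholder endpoint
    api_key="no-api-key",                           # placeholder key
)
student_model = ModelConfig(model_name="super-jeeves-8x700B")

# the client can be bound once at construction time ...
evaluator = RagasEvaluator(student_model=student_model, openai_client=client)
results = evaluator.run(dataset=Path("questions.jsonl"))

# ... or passed per call, overriding the instance attribute
results = evaluator.run(
    dataset=Path("questions.jsonl"),
    student_model=student_model,
    openai_client=client,
)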

tests/test_ragas.py

Lines changed: 53 additions & 36 deletions
@@ -11,58 +11,55 @@
 import pandas as pd
 
 # First Party
-from instructlab.eval.ragas import ModelConfig, RagasEvaluator, RunConfig, Sample
+from instructlab.eval.ragas import ModelConfig, RagasEvaluator, RunConfig
 
 
 class TestRagasEvaluator(unittest.TestCase):
-    @patch("instructlab.eval.ragas.get_openai_client")
-    def test_generate_answers_from_model(self, mock_get_openai_client):
+    def test_generate_answers_from_model(self):
         # mock the OpenAI client to always return "london" for chat completions
+        user_input = "What is the capital of France?"
+        model_response = "London"
         mock_client = MagicMock()
         mock_response = MagicMock()
-        mock_response.choices[0].message.content = "London"
+        mock_response.choices = [MagicMock(message=MagicMock(content=model_response))]
         mock_client.chat.completions.create.return_value = mock_response
-        mock_get_openai_client.return_value = mock_client
 
         # get answers
-        questions = pd.DataFrame({"user_input": ["What is the capital of France?"]})
+        questions = pd.DataFrame({"user_input": [user_input]})
         student_model = ModelConfig(
-            base_url="https://your.model.endpoint.com",
-            model_name="jeeves-512B",
-            api_key="test-api-key",
+            model_name="super-jeeves-8x700B",
         )
         evaluator = RagasEvaluator()
-        result_df = evaluator._generate_answers_from_model(questions, student_model)
+        result_df = evaluator._generate_answers_from_model(
+            questions, student_model, mock_client
+        )
 
         # what we expect to see
         expected_df = questions.copy()
-        expected_df["response"] = ["London"]
+        expected_df["response"] = [model_response]
 
         # perform the assertions
         pd.testing.assert_frame_equal(result_df, expected_df)
-        mock_get_openai_client.assert_called_once_with(
-            model_api_base=student_model.base_url, api_key=student_model.api_key
-        )
         mock_client.chat.completions.create.assert_called_once_with(
-            messages=[student_model.system_prompt, "What is the capital of France?"],
+            messages=[student_model.system_prompt, user_input],
             model=student_model.model_name,
             seed=42,
             max_tokens=student_model.max_tokens,
             temperature=student_model.temperature,
         )
 
+    @patch("instructlab.eval.ragas.ChatOpenAI")
     @patch("instructlab.eval.ragas.read_json")
     @patch("instructlab.eval.ragas.evaluate")
-    @patch("instructlab.eval.ragas.ChatOpenAI")
     @patch.object(RagasEvaluator, "_generate_answers_from_model")
     @patch.object(RagasEvaluator, "_get_metrics")
     def test_run(
         self,
         mock_get_metrics: MagicMock,
         mock_generate_answers_from_model: MagicMock,
-        mock_ChatOpenAI: MagicMock,
         mock_evaluate: MagicMock,
         mock_read_json: MagicMock,
+        mock_ChatOpenAI: MagicMock,
     ):
         ########################################################################
         # SETUP EVERYTHING WE NEED FOR THE TESTS
@@ -74,16 +71,20 @@ def test_run(
         student_model_response = "Paris"
         user_question = "What is the capital of France?"
         golden_answer = "The capital of France is Paris."
+        metric = "mocked-metric"
+        metric_score = 4.0
         base_ds = [{"user_input": user_question, "reference": golden_answer}]
-        mocked_metric = "mocked-metric"
-        mocked_metric_score = 4.0
+        student_model = ModelConfig(
+            model_name="super-jeeves-8x700B",
+        )
+        run_config = RunConfig(max_retries=3, max_wait=60, seed=42, timeout=30)
 
         # The following section takes care of mocking function return calls.
         # Ragas is tricky because it has some complex data structures under the hood,
         # so what we have to do is configure the intermediate outputs that we expect
         # to receive from Ragas.
 
-        mock_get_metrics.return_value = [mocked_metric]
+        mock_get_metrics.return_value = [metric]
         interim_df = DataFrame(
             {
                 "user_input": [user_question],
@@ -93,7 +94,12 @@ def test_run(
         )
         mock_generate_answers_from_model.return_value = interim_df.copy()
         mocked_evaluation_ds = EvaluationDataset.from_pandas(interim_df)
-        mock_ChatOpenAI.return_value = MagicMock()
+        mock_client = MagicMock()
+        mock_response = MagicMock()
+        mock_response.choices = [
+            MagicMock(message=MagicMock(content=student_model_response))
+        ]
+        mock_client.chat.completions.create.return_value = mock_response
 
         # Ragas requires this value to instantiate an EvaluationResult object, so we must provide it.
         # It isn't functionally used for our purposes though.
@@ -109,29 +115,20 @@ def test_run(
             )
         }
         mock_evaluate.return_value = EvaluationResult(
-            scores=[{mocked_metric: mocked_metric_score}],
+            scores=[{metric: metric_score}],
             dataset=mocked_evaluation_ds,
             ragas_traces=_unimportant_ragas_traces,
         )
 
-        ########################################################################
-        # Run the tests
-        ########################################################################
-
-        # Configure all other inputs that Ragas does not depend on for proper mocking
-        student_model = ModelConfig(
-            base_url="https://api.openai.com",
-            model_name="pt-3.5-turbo",
-            api_key="test-api-key",
-        )
-        run_config = RunConfig(max_retries=3, max_wait=60, seed=42, timeout=30)
-        evaluator = RagasEvaluator()
-
         ########################################################################
         # Test case: directly passing a dataset
         ########################################################################
+        evaluator = RagasEvaluator()
         result = evaluator.run(
-            dataset=base_ds, student_model=student_model, run_config=run_config
+            dataset=base_ds,
+            student_model=student_model,
+            run_config=run_config,
+            openai_client=mock_client,
         )
 
         self.assertIsInstance(result, EvaluationResult)
@@ -142,11 +139,13 @@ def test_run(
         ########################################################################
         # Test case: passing a dataset in via Path to JSONL file
         ########################################################################
+        evaluator = RagasEvaluator()
         mock_read_json.return_value = DataFrame(base_ds)
         result = evaluator.run(
             dataset=Path("dummy_path.jsonl"),
             student_model=student_model,
             run_config=run_config,
+            openai_client=mock_client,
         )
 
         self.assertIsInstance(result, EvaluationResult)
@@ -156,6 +155,24 @@ def test_run(
         mock_generate_answers_from_model.assert_called()
         mock_evaluate.assert_called()
 
+        ########################################################################
+        # Test case: using the instance attributes
+        ########################################################################
+        evaluator = RagasEvaluator(
+            student_model=student_model,
+            openai_client=mock_client,
+            run_config=run_config,
+        )
+        mock_read_json.return_value = DataFrame(base_ds)
+        result = evaluator.run(dataset=Path("dummy_path.jsonl"))
+
+        self.assertIsInstance(result, EvaluationResult)
+        mock_read_json.assert_called_with(
+            Path("dummy_path.jsonl"), orient="records", lines=True
+        )
+        mock_generate_answers_from_model.assert_called()
+        mock_evaluate.assert_called()
+
 
 if __name__ == "__main__":
     unittest.main()
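
For illustration only (not part of this commit): the updated test drives the client solely through chat.completions.create(...) and reads choices[0].message.content off the return value, which is the whole OpenAI-compatible surface the evaluator now depends on. A minimal hand-rolled stand-in, assuming nothing beyond what the diff above shows, could look like this:

# Illustrative fake of the OpenAI-compatible surface the evaluator touches.
# Standard
from types import SimpleNamespace


class FakeChatCompletions:
    # mimics client.chat.completions with a canned answer
    def __init__(self, canned_answer: str):
        self.canned_answer = canned_answer

    def create(self, **kwargs):
        # mirror the shape the evaluator reads: response.choices[0].message.content
        message = SimpleNamespace(content=self.canned_answer)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])


class FakeOpenAIClient:
    def __init__(self, canned_answer: str = "London"):
        self.chat = SimpleNamespace(completions=FakeChatCompletions(canned_answer))


# An instance can be passed wherever an openai_client is expected, e.g.
#   evaluator._generate_answers_from_model(questions, student_model, FakeOpenAIClient())
# (the MagicMock used in the tests above additionally records call arguments for assertions).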
