Skip to content

Commit a1a6bb2

Browse files
Add DSPy Refine
1 parent 6dbc8bc commit a1a6bb2

File tree

6 files changed

+219
-0
lines changed

6 files changed

+219
-0
lines changed

dspy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dspy.teleprompt import *
66

77
import dspy.retrievers
8+
from dspy.refine import Refine
89

910
from dspy.evaluate import Evaluate # isort: skip
1011
from dspy.clients import * # isort: skip

dspy/refine/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from dspy.refine.metrics import BoolMetric, FloatMetric
2+
from dspy.refine.refine import Refine
3+
4+
__all__ = [
5+
"Refine",
6+
"BoolMetric",
7+
"FloatMetric",
8+
]

dspy/refine/feedback.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Union
2+
3+
from dspy.signatures.field import InputField, OutputField
4+
from dspy.signatures.signature import Signature
5+
6+
7+
class GenerateFeedback(Signature):
8+
"""
9+
Based on each metric value and metric definition for the inputs-outputs pair, provide feedback the DSPy module
10+
along with submodules in order to improve the metric values at the retry. Only provide feedback for built-in
11+
classses, e.g., dspy.Predict, dspy.Module, dspy.ChainOfThought and so on. If an attribute is a list, make sure
12+
you look into every element. It's also possible that some components are not related to the certain score, we
13+
should skip generating feedback if it is the case.
14+
"""
15+
16+
metrics: list[str] = InputField(desc="The definition of each scoring criterion")
17+
metric_values: list[Union[int, float, bool]] = InputField(desc="The value of each metric, the higher the better")
18+
module_inputs: dict = InputField(desc="The inputs of the DSPy module")
19+
module_outputs: dict = InputField(desc="The outputs of the DSPy module")
20+
source_code: str = InputField(desc="The source code of the DSPy module")
21+
feedback: dict[str, list[str]] = OutputField(
22+
desc="Feedback for the DSPy module in general, along with feedback for each submodule in the DSPy model, only "
23+
"provide feedback for attributes in `__init__` method that is a built-in class of dspy. The key should be the "
24+
"attribute name, e.g., `self.cot` or `self.predict`. If the attribute is a list, write the key as "
25+
"`self.cots[0]`, `self.predicts[1]` and so on. The feedback should be "
26+
"a list of strings, corresponding to each score function in `metrics`."
27+
)

dspy/refine/metrics.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import inspect
2+
3+
4+
class Metric:
5+
def __init__(self, name, description, fn):
6+
self.name = name
7+
self.description = description
8+
self.fn = fn
9+
10+
def __call__(self, inputs, outputs):
11+
return self.fn(inputs, outputs)
12+
13+
def __repr__(self):
14+
if self.description:
15+
return f"Metric name: {self.name}\nMetric description: {self.description}\n"
16+
else:
17+
return f"Metric name: {self.name}\nMetric function: {inspect.getsource(self.fn)}\n"
18+
19+
20+
class BoolMetric(Metric):
21+
def __init__(self, name, description, fn):
22+
super().__init__(name, description, fn)
23+
self.type_description = "This is a bool metric, true if the metric looks good, false otherwise."
24+
25+
def __repr__(self):
26+
return f"{super().__repr__()}\nMetric type:{self.type_description}"
27+
28+
29+
class FloatMetric(Metric):
30+
def __init__(self, name, description, fn):
31+
super().__init__(name, description, fn)
32+
self.type_description = "This is a float metric, the higher the value the better."
33+
34+
def __repr__(self):
35+
return f"{super().__repr__()}\nMetric type:{self.type_description}"

dspy/refine/refine.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import inspect
2+
from functools import partial
3+
from types import MethodType
4+
5+
from dspy.predict.chain_of_thought import ChainOfThought
6+
from dspy.primitives.program import Module
7+
from dspy.refine.feedback import GenerateFeedback
8+
from dspy.refine.utils import get_traces
9+
from dspy.signatures.field import InputField
10+
11+
12+
class Refine(Module):
13+
def __init__(self, module, metrics, metric_thresholds=None, max_iter=3):
14+
self.module = module.deepcopy()
15+
self.metrics = metrics
16+
self.metric_thresholds = metric_thresholds
17+
self.max_iter = max_iter
18+
19+
self.metric_descriptions = [self._get_metric_description(metric) for metric in metrics]
20+
self.feedback_program = ChainOfThought(GenerateFeedback)
21+
22+
self._named_predicts = {name: predict for name, predict in self.module.named_predictors()}
23+
24+
def _get_metric_description(self, metric):
25+
if hasattr(metric, "__repr__"):
26+
return str(metric)
27+
else:
28+
return inspect.getsource(metric.__class__)
29+
30+
def _patch_predict_call_with_feedback(self, feedbacks):
31+
named_predicts = {}
32+
for name in feedbacks.keys():
33+
# Only patch the predict that has feedback.
34+
named_predicts[name] = self._named_predicts[name]
35+
36+
predict_traces = get_traces(named_predicts)
37+
38+
def forward_with_feedback(instance, dspy_refine_feedback, dspy_refine_last_trace, **kwargs):
39+
return instance.original_forward(
40+
**kwargs,
41+
dspy_refine_feedback=dspy_refine_feedback,
42+
dspy_refine_last_trace=dspy_refine_last_trace,
43+
)
44+
45+
for name, predict in named_predicts.items():
46+
last_trace = predict_traces.get(name, None)
47+
# We trim out the last round's feedback and last_trace from the inputs to avoid too much nesting.
48+
if "dspy_refine_feedback" in last_trace["inputs"]:
49+
del last_trace["inputs"]["dspy_refine_feedback"]
50+
if "dspy_refine_last_trace" in last_trace["inputs"]:
51+
del last_trace["inputs"]["dspy_refine_last_trace"]
52+
53+
feedback = feedbacks.get(name, None)
54+
if not hasattr(predict, "original_forward"):
55+
# If the predict has never been patched for refine calls, patch it.
56+
predict.original_signature = predict.signature
57+
predict.signature = predict.signature.prepend(
58+
"dspy_refine_feedback",
59+
InputField(desc="Improvement suggestion based on last try", type=str),
60+
).prepend("dspy_refine_last_trace", InputField(desc="Trace of the last try", type=dict))
61+
62+
# Save the original forward method before patching.
63+
predict.original_forward = predict.forward
64+
65+
partial_forward = partial(
66+
forward_with_feedback, dspy_refine_feedback=feedback, dspy_refine_last_trace=last_trace
67+
)
68+
# Patch the `forward` method to the `forward_with_feedback` methd with partial values of feedback and
69+
# last_trace.
70+
predict.forward = MethodType(partial_forward, predict)
71+
72+
def _undo_patch_predict_call_with_feedback(self, named_predicts):
73+
for _, predict in named_predicts.items():
74+
if hasattr(predict, "original_forward"):
75+
predict.forward = predict.original_forward
76+
predict.signature = predict.original_signature
77+
del predict.original_signature
78+
del predict.original_forward
79+
80+
def _get_feedback_for_predicts(self, inputs, outputs):
81+
metric_descriptions = []
82+
metric_values = []
83+
for i, metric in enumerate(self.metrics):
84+
metric_value = metric(inputs, outputs)
85+
if self.metric_thresholds and metric_value < self.metric_thresholds[i]:
86+
metric_descriptions.append(self.metric_descriptions[i])
87+
metric_values.append(metric_value)
88+
89+
if len(metric_descriptions) == 0:
90+
# All metric values are above the threshold, no need to refine.
91+
return {}
92+
93+
# Get feedback for each metric.
94+
feedbacks = self.feedback_program(
95+
metrics=metric_descriptions,
96+
metric_values=metric_values,
97+
module_inputs=inputs,
98+
module_outputs=outputs,
99+
source_code=inspect.getsource(self.module.__class__),
100+
).feedback
101+
named_predicts = self._named_predicts
102+
103+
predict_name_to_feedback = {}
104+
for name in named_predicts.keys():
105+
top_module_name = name.split(".")[0]
106+
if top_module_name in feedbacks:
107+
predict_name_to_feedback[name] = feedbacks[top_module_name]
108+
elif f"self.{top_module_name}" in feedbacks:
109+
predict_name_to_feedback[name] = feedbacks[f"self.{top_module_name}"]
110+
return predict_name_to_feedback
111+
112+
def __call__(self, **kwargs):
113+
outputs = self.module(**kwargs)
114+
115+
for i in range(self.max_iter):
116+
feedbacks = self._get_feedback_for_predicts(kwargs, outputs)
117+
118+
if len(feedbacks) == 0:
119+
break
120+
self._patch_predict_call_with_feedback(feedbacks)
121+
122+
outputs = self.module(**kwargs)
123+
124+
named_predicts = {name: predict for name, predict in self.module.named_predictors()}
125+
self._undo_patch_predict_call_with_feedback(named_predicts)
126+
return outputs

dspy/refine/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from dspy.dsp.utils.settings import settings
2+
3+
4+
def get_traces(named_predicts):
5+
predict_name_to_traces = {}
6+
predict_id_to_name = {id(predict): name for name, predict in named_predicts.items()}
7+
8+
traces = settings.trace
9+
for i in range(len(traces)):
10+
trace = traces[-i - 1]
11+
trace_predict_id = id(trace[0])
12+
if trace_predict_id in predict_id_to_name:
13+
predict_name = predict_id_to_name[trace_predict_id]
14+
if predict_name not in predict_name_to_traces:
15+
predict_name_to_traces[predict_name] = {
16+
"inputs": trace[1],
17+
"outputs": trace[2].toDict(),
18+
}
19+
if len(predict_name_to_traces) == len(named_predicts):
20+
# Stop searching when all predicts' traces are found.
21+
break
22+
return predict_name_to_traces

0 commit comments

Comments
 (0)