
Commit a8e3f19

Add DSPy Refine
1 parent 6dbc8bc commit a8e3f19

File tree

6 files changed: +221 −0 lines changed


dspy/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 from dspy.teleprompt import *

 import dspy.retrievers
+from dspy.refine import Refine

 from dspy.evaluate import Evaluate  # isort: skip
 from dspy.clients import *  # isort: skip

dspy/refine/__init__.py

Lines changed: 8 additions & 0 deletions
from dspy.refine.metrics import BoolMetric, FloatMetric
from dspy.refine.refine import Refine

__all__ = [
    "Refine",
    "BoolMetric",
    "FloatMetric",
]

dspy/refine/feedback.py

Lines changed: 27 additions & 0 deletions
from typing import Union

from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import Signature


class GenerateFeedback(Signature):
    """
    Based on each metric value and metric definition for the inputs-outputs pair, provide feedback to the DSPy
    module and its submodules so that the metric values improve on the retry. Only provide feedback for built-in
    classes, e.g., dspy.Predict, dspy.Module, dspy.ChainOfThought and so on. If an attribute is a list, make sure
    you look into every element. It is also possible that some components are unrelated to a certain score; skip
    generating feedback in that case.
    """

    metrics: list[str] = InputField(desc="The definition of each scoring criterion")
    metric_values: list[Union[int, float, bool]] = InputField(desc="The value of each metric, the higher the better")
    module_inputs: dict = InputField(desc="The inputs of the DSPy module")
    module_outputs: dict = InputField(desc="The outputs of the DSPy module")
    source_code: str = InputField(desc="The source code of the DSPy module")
    feedback: dict[str, list[str]] = OutputField(
        desc="Feedback for the DSPy module in general, along with feedback for each submodule in the DSPy module. "
        "Only provide feedback for attributes in the `__init__` method that are built-in classes of dspy. The key "
        "should be the attribute name, e.g., `self.cot` or `self.predict`. If the attribute is a list, write the "
        "key as `self.cots[0]`, `self.predicts[1]` and so on. The feedback should be a list of strings, one per "
        "score function in `metrics`."
    )
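
For illustration, the `feedback` output produced by this signature is expected to look roughly like the sketch below; the attribute names, metric count, and feedback strings are hypothetical, not part of this commit.

# Hypothetical `feedback` value for a module scored by two metrics,
# with attributes `self.cot` and `self.predicts` (a list) in `__init__`:
feedback = {
    "self.cot": [
        "Quote the retrieved passage when answering.",  # feedback for metric 1
        "Keep the answer under two sentences.",         # feedback for metric 2
    ],
    "self.predicts[0]": [
        "Rephrase the query before retrieval.",
        "No change needed for this metric.",
    ],
}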

dspy/refine/metrics.py

Lines changed: 35 additions & 0 deletions
import inspect


class Metric:
    def __init__(self, name, description, fn):
        self.name = name
        self.description = description
        self.fn = fn

    def __call__(self, inputs, outputs):
        return self.fn(inputs, outputs)

    def __repr__(self):
        if self.description:
            return f"Metric name: {self.name}\nMetric description: {self.description}\n"
        else:
            return f"Metric name: {self.name}\nMetric function: {inspect.getsource(self.fn)}\n"


class BoolMetric(Metric):
    def __init__(self, name, description, fn):
        super().__init__(name, description, fn)
        self.type_description = "This is a bool metric, true if the metric looks good, false otherwise."

    def __repr__(self):
        return f"{super().__repr__()}\nMetric type: {self.type_description}"


class FloatMetric(Metric):
    def __init__(self, name, description, fn):
        super().__init__(name, description, fn)
        self.type_description = "This is a float metric, the higher the value the better."

    def __repr__(self):
        return f"{super().__repr__()}\nMetric type: {self.type_description}"

dspy/refine/refine.py

Lines changed: 128 additions & 0 deletions
import inspect
from functools import partial
from types import MethodType

from dspy.predict.chain_of_thought import ChainOfThought
from dspy.primitives.program import Module
from dspy.refine.feedback import GenerateFeedback
from dspy.refine.utils import get_traces
from dspy.signatures.field import InputField


class Refine(Module):
    def __init__(self, module, metrics, metric_thresholds=None, max_iter=3):
        self.module = module.deepcopy()
        self.metrics = metrics
        self.metric_thresholds = metric_thresholds
        self.max_iter = max_iter

        self.metric_descriptions = [self._get_metric_description(metric) for metric in metrics]

        self._named_predicts = {name: predict for name, predict in self.module.named_predictors()}

    def _get_metric_description(self, metric):
        # Use the metric's own `__repr__` when its class overrides the default
        # (e.g., the Metric subclasses); otherwise fall back to the class source.
        if type(metric).__repr__ is not object.__repr__:
            return str(metric)
        else:
            return inspect.getsource(metric.__class__)

    def _patch_predict_call_with_feedback(self, feedbacks):
        named_predicts = {}
        for name in feedbacks.keys():
            # Only patch the predicts that have feedback.
            named_predicts[name] = self._named_predicts[name]

        predict_traces = get_traces(named_predicts)

        def forward_with_feedback(instance, dspy_refine_feedback, dspy_refine_last_trace, **kwargs):
            return instance.original_forward(
                **kwargs,
                dspy_refine_feedback=dspy_refine_feedback,
                dspy_refine_last_trace=dspy_refine_last_trace,
            )

        for name, predict in named_predicts.items():
            last_trace = predict_traces.get(name, None)
            if last_trace is None:
                continue
            # Trim the previous round's feedback and trace from the inputs to
            # avoid unbounded nesting across retries.
            if "dspy_refine_feedback" in last_trace["inputs"]:
                del last_trace["inputs"]["dspy_refine_feedback"]
            if "dspy_refine_last_trace" in last_trace["inputs"]:
                del last_trace["inputs"]["dspy_refine_last_trace"]

            feedback = feedbacks.get(name, None)
            if not hasattr(predict, "original_forward"):
                # If the predict has never been patched for refine calls, patch it.
                predict.original_signature = predict.signature
                predict.signature = predict.signature.prepend(
                    "dspy_refine_feedback",
                    InputField(desc="Improvement suggestion based on the last try", type=str),
                ).prepend("dspy_refine_last_trace", InputField(desc="Trace of the last try", type=dict))

                # Save the original forward method before patching.
                predict.original_forward = predict.forward

            partial_forward = partial(
                forward_with_feedback, dspy_refine_feedback=feedback, dspy_refine_last_trace=last_trace
            )
            # Rebind `forward` to `forward_with_feedback` with the feedback and
            # last trace filled in via partial application.
            predict.forward = MethodType(partial_forward, predict)

    def _undo_patch_predict_call_with_feedback(self, named_predicts):
        for _, predict in named_predicts.items():
            if hasattr(predict, "original_forward"):
                predict.forward = predict.original_forward
                predict.signature = predict.original_signature
                del predict.original_signature
                del predict.original_forward

    def _get_feedback_for_predicts(self, inputs, outputs):
        metric_descriptions = []
        metric_values = []
        for i, metric in enumerate(self.metrics):
            metric_value = metric(inputs, outputs)
            if self.metric_thresholds and metric_value < self.metric_thresholds[i]:
                metric_descriptions.append(self.metric_descriptions[i])
                metric_values.append(metric_value)

        if len(metric_descriptions) == 0:
            # All metric values are above their thresholds, no need to refine.
            return {}

        feedback_program = ChainOfThought(GenerateFeedback)
        # Get feedback for each metric.
        feedbacks = feedback_program(
            metrics=metric_descriptions,
            metric_values=metric_values,
            module_inputs=inputs,
            module_outputs=outputs,
            source_code=inspect.getsource(self.module.__class__),
        ).feedback
        named_predicts = self._named_predicts

        predict_name_to_feedback = {}
        for name in named_predicts.keys():
            top_module_name = name.split(".")[0]
            if top_module_name in feedbacks:
                predict_name_to_feedback[name] = feedbacks[top_module_name]
            elif f"self.{top_module_name}" in feedbacks:
                predict_name_to_feedback[name] = feedbacks[f"self.{top_module_name}"]
        return predict_name_to_feedback

    def __call__(self, **kwargs):
        outputs = self.module(**kwargs)

        for _ in range(self.max_iter):
            feedbacks = self._get_feedback_for_predicts(kwargs, outputs)

            if len(feedbacks) == 0:
                break
            self._patch_predict_call_with_feedback(feedbacks)

            outputs = self.module(**kwargs)

        named_predicts = {name: predict for name, predict in self.module.named_predictors()}
        self._undo_patch_predict_call_with_feedback(named_predicts)
        return outputs
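
Putting the pieces together, a minimal end-to-end sketch follows; the `QA` module, the metric, and the threshold are illustrative, not part of this commit. Refine calls the wrapped module once, scores the result, and retries with generated feedback while any metric value falls below its threshold, up to `max_iter` times.

import dspy
from dspy.refine import Refine, FloatMetric

class QA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.cot = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        return self.cot(question=question)

# Refine retries while a metric value is below its threshold.
metric = FloatMetric(
    name="conciseness",
    description="Rewards shorter answers; higher is better.",
    fn=lambda inputs, outputs: 1.0 / (1 + len(str(outputs).split())),
)
refine = Refine(QA(), metrics=[metric], metric_thresholds=[0.2], max_iter=3)
result = refine(question="What is the capital of France?")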

dspy/refine/utils.py

Lines changed: 22 additions & 0 deletions
from dspy.dsp.utils.settings import settings


def get_traces(named_predicts):
    predict_name_to_traces = {}
    predict_id_to_name = {id(predict): name for name, predict in named_predicts.items()}

    # Walk the global trace from newest to oldest so that each predict maps to
    # its most recent (inputs, outputs) pair.
    traces = settings.trace
    for i in range(len(traces)):
        trace = traces[-i - 1]
        trace_predict_id = id(trace[0])
        if trace_predict_id in predict_id_to_name:
            predict_name = predict_id_to_name[trace_predict_id]
            if predict_name not in predict_name_to_traces:
                predict_name_to_traces[predict_name] = {
                    "inputs": trace[1],
                    "outputs": trace[2].toDict(),
                }
                if len(predict_name_to_traces) == len(named_predicts):
                    # Stop searching once every predict's trace is found.
                    break
    return predict_name_to_traces
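
For reference, this helper assumes each entry in `settings.trace` is a `(predict, inputs, outputs)` tuple, which is how it indexes the trace. A sketch of its expected use and return shape, with an illustrative predictor name:

# Assuming `module` is a dspy.Module that was just called, collect the
# most recent trace per predictor; the key name below is illustrative.
named_predicts = dict(module.named_predictors())
last_traces = get_traces(named_predicts)
# -> {"cot.predict": {"inputs": {"question": "..."}, "outputs": {"answer": "..."}}}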
