Skip to content

Commit 37f8c04

Browse files
committed
Refactor chat_adapter somewhat working
1 parent 845f7e2 commit 37f8c04

File tree

3 files changed

+118
-95
lines changed

3 files changed

+118
-95
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 66 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,7 @@ def parse_value(value, annotation):
254254
return TypeAdapter(annotation).validate_python(parsed_value)
255255

256256

257-
def format_turn(signature, values, role, incomplete=False):
258-
fields_to_collapse = []
257+
def format_turn(signature, values, role, incomplete=False):
259258
"""
260259
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
261260
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
@@ -271,90 +270,77 @@ def format_turn(signature, values, role, incomplete=False):
271270
A chat message that can be appended to a chat thread. The message contains two string fields:
272271
``role`` ("user" or "assistant") and ``content`` (the message text).
273272
"""
274-
content = []
275-
276273
if role == "user":
277-
fields: Dict[str, FieldInfo] = signature.input_fields
278-
if incomplete:
279-
fields_to_collapse.append({"type": "text", "text": "This is an example of the task, though some input or output fields are not supplied."})
274+
fields = signature.input_fields
275+
message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
280276
else:
281-
fields: Dict[str, FieldInfo] = signature.output_fields
282-
# Add the built-in field indicating that the chat turn has been completed
283-
fields[BuiltInCompletedOutputFieldInfo.name] = BuiltInCompletedOutputFieldInfo.info
277+
# Add the completed field for the assistant turn
278+
fields = {**signature.output_fields, BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info}
284279
values = {**values, BuiltInCompletedOutputFieldInfo.name: ""}
285-
field_names: KeysView = fields.keys()
286-
if not incomplete:
287-
if not set(values).issuperset(set(field_names)):
288-
raise ValueError(f"Expected {field_names} but got {values.keys()}")
289-
290-
fields_to_collapse.extend(format_fields(
291-
fields_with_values={
292-
FieldInfoWithName(name=field_name, info=field_info): values.get(
293-
field_name, "Not supplied for this particular example."
294-
)
295-
for field_name, field_info in fields.items()
296-
},
297-
assume_text=False
298-
))
280+
message_prefix = ""
299281

300-
if role == "user":
301-
output_fields = list(signature.output_fields.keys())
302-
def type_info(v):
303-
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
304-
if v.annotation is not str else ""
305-
if output_fields:
306-
fields_to_collapse.append({
307-
"type": "text",
308-
"text": "Respond with the corresponding output fields, starting with the field "
309-
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
310-
+ ", and then ending with the marker for `[[ ## completed ## ]]`."
311-
})
312-
313-
# flatmap the list if any items are lists otherwise keep the item
314-
flattened_list = list(chain.from_iterable(
315-
item if isinstance(item, list) else [item] for item in fields_to_collapse
316-
))
317-
final_list = []
318-
while flattened_list:
319-
item = flattened_list.pop(0)
320-
image_tag_regex = r'"?<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>"?'
321-
if re.search(image_tag_regex, item.get("text")):
322-
image_tag = re.search(image_tag_regex, item.get("text")).group(1)
323-
# get the prefix and suffix
324-
prefix, suffix = item.get("text").split('"<DSPY_IMAGE_START>', 1)[0], "".join(item.get("text").split('<DSPY_IMAGE_END>"', 1)[1:])
325-
final_list.append({"type": "text", "text": prefix})
326-
final_list.append({"type": "image_url", "image_url": {"url": image_tag}})
327-
flattened_list.insert(0, {"type": "text", "text": suffix})
328-
else:
329-
final_list.append({"type": "text", "text": item.get("text")})
330-
331-
if all(message.get("type", None) == "text" for message in final_list):
332-
content = "\n\n".join(message.get("text") for message in final_list)
333-
return {"role": role, "content": content}
334-
335-
# Collapse all consecutive text messages into a single message.
336-
collapsed_messages = []
337-
for item in final_list:
338-
# First item is always added
339-
if not collapsed_messages:
340-
collapsed_messages.append(item)
341-
continue
342-
343-
# If current item is image, add to collapsed_messages
344-
if item.get("type") == "image_url":
345-
if collapsed_messages[-1].get("type") == "text":
346-
collapsed_messages[-1]["text"] += "\n"
347-
collapsed_messages.append(item)
348-
# If previous item is text and current item is text, append to previous item
349-
elif collapsed_messages[-1].get("type") == "text":
350-
collapsed_messages[-1]["text"] += "\n\n" + item["text"]
351-
# If previous item is not text(aka image), add current item as a new item
352-
else:
353-
item["text"] = "\n\n" + item["text"]
354-
collapsed_messages.append(item)
282+
if not incomplete and not set(values).issuperset(fields.keys()):
283+
raise ValueError(f"Expected {fields.keys()} but got {values.keys()}")
284+
285+
messages = []
286+
if message_prefix:
287+
messages.append({"type": "text", "text": message_prefix})
355288

356-
return {"role": role, "content": collapsed_messages}
289+
field_messages = format_fields(
290+
{FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.")
291+
for k, v in fields.items()},
292+
assume_text=False
293+
)
294+
messages.extend(field_messages)
295+
296+
# Add output field instructions for user messages
297+
if role == "user" and signature.output_fields:
298+
type_info = lambda v: f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" if v.annotation is not str else ""
299+
field_instructions = "Respond with the corresponding output fields, starting with the field " + \
300+
", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + \
301+
", and then ending with the marker for `[[ ## completed ## ]]`."
302+
messages.append({"type": "text", "text": field_instructions})
303+
304+
# Process messages to handle image tags and collapse text
305+
processed_messages = process_messages(messages)
306+
307+
if all(msg.get("type") == "text" for msg in processed_messages):
308+
return {"role": role, "content": "\n\n".join(msg["text"] for msg in processed_messages)}
309+
return {"role": role, "content": processed_messages}
310+
311+
def process_messages(messages):
312+
"""Process messages to handle image tags and collapse consecutive text messages."""
313+
processed = []
314+
current_text = []
315+
316+
for msg in flatten_messages(messages):
317+
if msg["type"] == "text":
318+
# Handle image tags in text
319+
parts = re.split(r'(<DSPY_IMAGE_START>.*?<DSPY_IMAGE_END>)', msg["text"])
320+
for part in parts:
321+
if match := re.match(r'<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>', part):
322+
if current_text:
323+
processed.append({"type": "text", "text": "\n\n".join(current_text)})
324+
current_text = []
325+
processed.append({"type": "image_url", "image_url": {"url": match.group(1)}})
326+
elif part.strip():
327+
current_text.append(part)
328+
else:
329+
if current_text:
330+
processed.append({"type": "text", "text": "\n\n".join(current_text)})
331+
current_text = []
332+
processed.append(msg)
333+
334+
if current_text:
335+
processed.append({"type": "text", "text": "\n\n".join(current_text)})
336+
337+
return processed
357338

339+
def flatten_messages(messages):
340+
"""Flatten nested message lists."""
341+
return list(chain.from_iterable(
342+
item if isinstance(item, list) else [item] for item in messages
343+
))
358344

359345
def get_annotation_name(annotation):
360346
origin = get_origin(annotation)

dspy/predict/predict.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,7 @@ def dump_state(self, save_verbose=None):
4747

4848
for field in demo:
4949
# FIXME: Saving BaseModels as strings in examples doesn't matter because you never re-access as an object
50-
# It does matter for images
51-
if isinstance(demo[field], Image):
52-
demo[field] = demo[field].model_dump()
53-
elif isinstance(demo[field], BaseModel):
54-
demo[field] = demo[field].model_dump_json()
50+
demo[field] = serialize_object(demo[field])
5551

5652
state["demos"].append(demo)
5753

@@ -296,6 +292,26 @@ def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
296292
lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values
297293
)
298294

295+
def serialize_object(obj):
296+
"""
297+
Recursively serialize a given object into a JSON-compatible format.
298+
Supports Pydantic models, lists, tuples, dicts, and primitive types.
299+
"""
300+
if isinstance(obj, BaseModel):
301+
# Use model_dump_json to serialize the model into a JSON string
302+
return obj.model_dump_json()
303+
elif isinstance(obj, list):
304+
# Recursively process each item in the list
305+
return [serialize_object(item) for item in obj]
306+
elif isinstance(obj, tuple):
307+
return tuple(serialize_object(item) for item in obj)
308+
elif isinstance(obj, dict):
309+
# Recursively process each key-value pair in the dict
310+
return {key: serialize_object(value) for key, value in obj.items()}
311+
else:
312+
# Assume the object is already JSON-compatible (e.g., int, str, float)
313+
return obj
314+
299315
# TODO: get some defaults during init from the context window?
300316
# # TODO: FIXME: Hmm, I guess expected behavior is that contexts can
301317
# affect execution. Well, we need to determine whether context dominates, __init__ dominates, or forward dominates.

tests/signatures/test_adapter_image.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -246,17 +246,38 @@ def test_predictor_save_load(sample_url, sample_pil_image):
246246
print(result)
247247
assert messages_contain_image_url_pattern(lm.history[-1]["messages"])
248248
print(lm.history[-1]["messages"])
249-
assert False
249+
assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])
250250

251251
def test_save_load_complex_types():
252-
pass
253-
# class ComplexTypeSignature(dspy.Signature):
254-
# image_list: List[dspy.Image] = dspy.InputField(desc="A list of images")
255-
# caption: str = dspy.OutputField(desc="A caption for the image list")
252+
examples = [
253+
dspy.Example(image_list=[dspy.Image.from_url("https://example.com/dog.jpg"), dspy.Image.from_url("https://example.com/cat.jpg")], caption="Example 1").with_inputs("image_list"),
254+
]
255+
256+
class ComplexTypeSignature(dspy.Signature):
257+
image_list: List[dspy.Image] = dspy.InputField(desc="A list of images")
258+
caption: str = dspy.OutputField(desc="A caption for the image list")
259+
260+
lm = DummyLM([{"caption": "A list of images"}, {"caption": "A list of images"}])
261+
dspy.settings.configure(lm=lm)
256262

257-
# lm = DummyLM([{"caption": "A list of images"}])
258-
# dspy.settings.configure(lm=lm)
263+
predictor = dspy.Predict(ComplexTypeSignature)
264+
result = predictor(**examples[0].inputs())
265+
266+
print(lm.history[-1]["messages"])
267+
assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])
268+
assert str(lm.history[-1]["messages"]).count("'url'") == 2
259269

260-
# predictor = dspy.Predict(ComplexTypeSignature)
261-
# result = predictor(image_list=[dspy.Image.from_url("https://example.com/dog.jpg")])
262-
# assert isinstance(result.caption, str)
270+
optimizer = dspy.teleprompt.LabeledFewShot(k=1)
271+
compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)
272+
print(compiled_predictor.demos)
273+
274+
with tempfile.NamedTemporaryFile(mode='w+', delete=True) as temp_file:
275+
print("compiled_predictor state: ", compiled_predictor.dump_state())
276+
compiled_predictor.save(temp_file.name)
277+
loaded_predictor = dspy.Predict(ComplexTypeSignature)
278+
loaded_predictor.load(temp_file.name)
279+
280+
print("loaded_predictor state: ", loaded_predictor.dump_state())
281+
result = loaded_predictor(**examples[0].inputs())
282+
assert result.caption == "A list of images"
283+
assert str(lm.history[-1]["messages"]).count("'url'") == 4

0 commit comments

Comments
 (0)