Skip to content

Commit acf10ee

Browse files
committed
Tests failing on purpose - added None support and new str repr
1 parent d31c63d commit acf10ee

File tree

3 files changed

+124
-16
lines changed

3 files changed

+124
-16
lines changed

Diff for: dspy/adapters/chat_adapter.py

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from pydantic.fields import FieldInfo
1818
from typing import Dict, List, Literal, NamedTuple, get_args, get_origin
1919

20-
from dspy.adapters.base import Adapter
2120
from ..signatures.field import OutputField
2221
from ..signatures.signature import SignatureMeta
2322
from ..signatures.utils import get_dspy_field_type

Diff for: dspy/adapters/image_utils.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,20 @@
1414

1515
class Image(pydantic.BaseModel):
1616
url: str
17+
18+
model_config = {
19+
'frozen': True,
20+
'str_strip_whitespace': True,
21+
'validate_assignment': True,
22+
'extra': 'forbid',
23+
}
1724

1825
@pydantic.model_validator(mode="before")
1926
@classmethod
2027
def validate_input(cls, values):
2128
# Allow the model to accept either a URL string or a dictionary with a single 'url' key
2229
if isinstance(values, str):
23-
# if a string, assume its the URL directly and wrap it in a dict
30+
# if a string, assume it's the URL directly and wrap it in a dict
2431
return {"url": values}
2532
elif isinstance(values, dict) and set(values.keys()) == {"url"}:
2633
# if it's a dict, ensure it has only the 'url' key
@@ -47,6 +54,15 @@ def from_PIL(cls, pil_image):
4754
def serialize_model(self):
4855
return "<DSPY_IMAGE_START>" + self.url + "<DSPY_IMAGE_END>"
4956

57+
def __str__(self):
58+
return self.serialize_model()
59+
60+
def __repr__(self):
61+
if "base64" in self.url:
62+
len_base64 = len(self.url.split("base64,")[1])
63+
return f"Image(url=data:image/...base64,<IMAGE_BASE_64_ENCODED({str(len_base64)})>)"
64+
return f"Image(url='{self.url}')"
65+
5066
def is_url(string: str) -> bool:
5167
"""Check if a string is a valid URL."""
5268
try:

Diff for: tests/signatures/test_adapter_image.py

+107-14
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
import datetime
2-
from typing import Dict, List, Tuple
2+
from typing import Dict, List, Optional, Tuple
33

44
import pytest
5-
from PIL import Image
5+
from PIL import Image as PILImage
66
import requests
77
from io import BytesIO
88

99
import dspy
1010
from dspy import Predict
1111
from dspy.utils.dummies import DummyLM
1212
import tempfile
13+
import pydantic
1314

1415
@pytest.fixture
1516
def sample_pil_image():
1617
"""Fixture to provide a sample image for testing"""
1718
url = 'https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg'
1819
response = requests.get(url)
1920
response.raise_for_status()
20-
return Image.open(BytesIO(response.content))
21+
return PILImage.open(BytesIO(response.content))
2122

2223
@pytest.fixture
2324
def sample_dspy_image_download():
@@ -33,7 +34,7 @@ def sample_dspy_image_no_download():
3334

3435

3536

36-
def messages_contain_image_url_pattern(messages):
37+
def messages_contain_image_url_pattern(messages, n=1):
3738
pattern = {
3839
'type': 'image_url',
3940
'image_url': {
@@ -55,18 +56,19 @@ def check_pattern(obj, pattern):
5556
return pattern(obj)
5657
return obj == pattern
5758

58-
# Look for pattern in any nested dict
59-
def find_pattern(obj, pattern):
59+
# Look for pattern in any nested dict and count occurrences
60+
def count_patterns(obj, pattern):
61+
count = 0
6062
if check_pattern(obj, pattern):
61-
return True
63+
count += 1
6264
if isinstance(obj, dict):
63-
return any(find_pattern(v, pattern) for v in obj.values())
65+
count += sum(count_patterns(v, pattern) for v in obj.values())
6466
if isinstance(obj, (list, tuple)):
65-
return any(find_pattern(v, pattern) for v in obj)
66-
return False
67+
count += sum(count_patterns(v, pattern) for v in obj)
68+
return count
6769

68-
return find_pattern(messages, pattern)
69-
except:
70+
return count_patterns(messages, pattern) == n
71+
except Exception:
7072
return False
7173

7274
def test_probabilistic_classification():
@@ -244,11 +246,11 @@ def test_predictor_save_load(sample_url, sample_pil_image):
244246

245247
result = loaded_predictor(image=active_example["image"])
246248
print(result)
247-
assert messages_contain_image_url_pattern(lm.history[-1]["messages"])
249+
assert messages_contain_image_url_pattern(lm.history[-1]["messages"], n=2)
248250
print(lm.history[-1]["messages"])
249251
assert "<DSPY_IMAGE_START>" not in str(lm.history[-1]["messages"])
250252

251-
def test_save_load_complex_types():
253+
def test_save_load_complex_default_types():
252254
examples = [
253255
dspy.Example(image_list=[dspy.Image.from_url("https://example.com/dog.jpg"), dspy.Image.from_url("https://example.com/cat.jpg")], caption="Example 1").with_inputs("image_list"),
254256
]
@@ -281,3 +283,94 @@ class ComplexTypeSignature(dspy.Signature):
281283
result = loaded_predictor(**examples[0].inputs())
282284
assert result.caption == "A list of images"
283285
assert str(lm.history[-1]["messages"]).count("'url'") == 4
286+
287+
def test_save_load_complex_pydantic_types():
288+
"""Test saving and loading predictors with pydantic models containing image fields"""
289+
class ImageModel(pydantic.BaseModel):
290+
image: dspy.Image
291+
label: str
292+
293+
class ComplexPydanticSignature(dspy.Signature):
294+
model_input: ImageModel = dspy.InputField(desc="A pydantic model containing an image")
295+
caption: str = dspy.OutputField(desc="A caption for the image")
296+
297+
example_model = ImageModel(
298+
image=dspy.Image.from_url("https://example.com/dog.jpg"),
299+
label="dog"
300+
)
301+
examples = [
302+
dspy.Example(model_input=example_model, caption="A dog").with_inputs("model_input")
303+
]
304+
305+
lm = DummyLM([{"caption": "A dog photo"}, {"caption": "A dog photo"}])
306+
dspy.settings.configure(lm=lm)
307+
308+
predictor = dspy.Predict(ComplexPydanticSignature)
309+
result = predictor(model_input=example_model)
310+
assert result.caption == "A dog photo"
311+
assert messages_contain_image_url_pattern(lm.history[-1]["messages"])
312+
313+
optimizer = dspy.teleprompt.LabeledFewShot(k=1)
314+
compiled_predictor = optimizer.compile(student=predictor, trainset=examples, sample=False)
315+
316+
with tempfile.NamedTemporaryFile(mode='w+', delete=True) as temp_file:
317+
compiled_predictor.save(temp_file.name)
318+
loaded_predictor = dspy.Predict(ComplexPydanticSignature)
319+
loaded_predictor.load(temp_file.name)
320+
321+
result = loaded_predictor(model_input=example_model)
322+
lm.inspect_history()
323+
assert result.caption == "A dog photo"
324+
assert messages_contain_image_url_pattern(lm.history[-1]["messages"], n=2)
325+
326+
def test_image_repr():
327+
"""Test string representation of Image objects with both URLs and PIL images"""
328+
# Test URL-based image repr and str
329+
url_image = dspy.Image.from_url("https://example.com/dog.jpg", download=False)
330+
assert str(url_image) == "<DSPY_IMAGE_START>https://example.com/dog.jpg<DSPY_IMAGE_END>"
331+
assert repr(url_image) == "Image(url='https://example.com/dog.jpg')"
332+
333+
# Test PIL image repr and str
334+
sample_pil = PILImage.new('RGB', (60, 30), color='red')
335+
pil_image = dspy.Image.from_PIL(sample_pil)
336+
# Test str() behavior
337+
assert str(pil_image).startswith("<DSPY_IMAGE_START>data:image/png;base64,")
338+
assert str(pil_image).endswith("<DSPY_IMAGE_END>")
339+
# Test repr() behavior
340+
repr_str = repr(pil_image)
341+
assert repr_str.startswith("Image(url=data:image/...base64,<IMAGE_BASE_64_ENCODED(")
342+
assert repr_str.endswith(")>)")
343+
assert "base64" in str(pil_image)
344+
345+
def test_image_optional_input():
346+
"""Test behavior when optional image inputs are missing"""
347+
class OptionalImageSignature(dspy.Signature):
348+
image: Optional[dspy.Image] = dspy.InputField(desc="An optional image input")
349+
text: str = dspy.InputField(desc="A text input")
350+
output: str = dspy.OutputField(desc="The output text")
351+
352+
lm = DummyLM([{"output": "Text only: hello"}, {"output": "Image and text: hello with image"}])
353+
dspy.settings.configure(lm=lm)
354+
355+
predictor = dspy.Predict(OptionalImageSignature)
356+
357+
# Test with missing image
358+
result = predictor(image=None, text="hello")
359+
assert result.output == "Text only: hello"
360+
assert not messages_contain_image_url_pattern(lm.history[-1]["messages"])
361+
362+
lm.inspect_history()
363+
print(lm.history[-1]["messages"])
364+
assert False
365+
# Test with image present
366+
result = predictor(
367+
image=dspy.Image.from_url("https://example.com/image.jpg"),
368+
text="hello"
369+
)
370+
assert result.output == "Image and text: hello with image"
371+
assert messages_contain_image_url_pattern(lm.history[-1]["messages"])
372+
373+
# Tests to write:
374+
# complex image types
375+
# Return "None" when dspy.Image is missing for the input, similarly to str input fields
376+
# JSON adapter

0 commit comments

Comments
 (0)