1
1
import datetime
2
- from typing import Dict , List , Tuple
2
+ from typing import Dict , List , Optional , Tuple
3
3
4
4
import pytest
5
- from PIL import Image
5
+ from PIL import Image as PILImage
6
6
import requests
7
7
from io import BytesIO
8
8
9
9
import dspy
10
10
from dspy import Predict
11
11
from dspy .utils .dummies import DummyLM
12
12
import tempfile
13
+ import pydantic
13
14
14
15
@pytest .fixture
15
16
def sample_pil_image ():
16
17
"""Fixture to provide a sample image for testing"""
17
18
url = 'https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg'
18
19
response = requests .get (url )
19
20
response .raise_for_status ()
20
- return Image .open (BytesIO (response .content ))
21
+ return PILImage .open (BytesIO (response .content ))
21
22
22
23
@pytest .fixture
23
24
def sample_dspy_image_download ():
@@ -33,7 +34,7 @@ def sample_dspy_image_no_download():
33
34
34
35
35
36
36
- def messages_contain_image_url_pattern (messages ):
37
+ def messages_contain_image_url_pattern (messages , n = 1 ):
37
38
pattern = {
38
39
'type' : 'image_url' ,
39
40
'image_url' : {
@@ -55,18 +56,19 @@ def check_pattern(obj, pattern):
55
56
return pattern (obj )
56
57
return obj == pattern
57
58
58
- # Look for pattern in any nested dict
59
- def find_pattern (obj , pattern ):
59
+ # Look for pattern in any nested dict and count occurrences
60
+ def count_patterns (obj , pattern ):
61
+ count = 0
60
62
if check_pattern (obj , pattern ):
61
- return True
63
+ count += 1
62
64
if isinstance (obj , dict ):
63
- return any ( find_pattern (v , pattern ) for v in obj .values ())
65
+ count += sum ( count_patterns (v , pattern ) for v in obj .values ())
64
66
if isinstance (obj , (list , tuple )):
65
- return any ( find_pattern (v , pattern ) for v in obj )
66
- return False
67
+ count += sum ( count_patterns (v , pattern ) for v in obj )
68
+ return count
67
69
68
- return find_pattern (messages , pattern )
69
- except :
70
+ return count_patterns (messages , pattern ) == n
71
+ except Exception :
70
72
return False
71
73
72
74
def test_probabilistic_classification ():
@@ -244,11 +246,11 @@ def test_predictor_save_load(sample_url, sample_pil_image):
244
246
245
247
result = loaded_predictor (image = active_example ["image" ])
246
248
print (result )
247
- assert messages_contain_image_url_pattern (lm .history [- 1 ]["messages" ])
249
+ assert messages_contain_image_url_pattern (lm .history [- 1 ]["messages" ], n = 2 )
248
250
print (lm .history [- 1 ]["messages" ])
249
251
assert "<DSPY_IMAGE_START>" not in str (lm .history [- 1 ]["messages" ])
250
252
251
- def test_save_load_complex_types ():
253
+ def test_save_load_complex_default_types ():
252
254
examples = [
253
255
dspy .Example (image_list = [dspy .Image .from_url ("https://example.com/dog.jpg" ), dspy .Image .from_url ("https://example.com/cat.jpg" )], caption = "Example 1" ).with_inputs ("image_list" ),
254
256
]
@@ -281,3 +283,94 @@ class ComplexTypeSignature(dspy.Signature):
281
283
result = loaded_predictor (** examples [0 ].inputs ())
282
284
assert result .caption == "A list of images"
283
285
assert str (lm .history [- 1 ]["messages" ]).count ("'url'" ) == 4
286
+
287
+ def test_save_load_complex_pydantic_types ():
288
+ """Test saving and loading predictors with pydantic models containing image fields"""
289
+ class ImageModel (pydantic .BaseModel ):
290
+ image : dspy .Image
291
+ label : str
292
+
293
+ class ComplexPydanticSignature (dspy .Signature ):
294
+ model_input : ImageModel = dspy .InputField (desc = "A pydantic model containing an image" )
295
+ caption : str = dspy .OutputField (desc = "A caption for the image" )
296
+
297
+ example_model = ImageModel (
298
+ image = dspy .Image .from_url ("https://example.com/dog.jpg" ),
299
+ label = "dog"
300
+ )
301
+ examples = [
302
+ dspy .Example (model_input = example_model , caption = "A dog" ).with_inputs ("model_input" )
303
+ ]
304
+
305
+ lm = DummyLM ([{"caption" : "A dog photo" }, {"caption" : "A dog photo" }])
306
+ dspy .settings .configure (lm = lm )
307
+
308
+ predictor = dspy .Predict (ComplexPydanticSignature )
309
+ result = predictor (model_input = example_model )
310
+ assert result .caption == "A dog photo"
311
+ assert messages_contain_image_url_pattern (lm .history [- 1 ]["messages" ])
312
+
313
+ optimizer = dspy .teleprompt .LabeledFewShot (k = 1 )
314
+ compiled_predictor = optimizer .compile (student = predictor , trainset = examples , sample = False )
315
+
316
+ with tempfile .NamedTemporaryFile (mode = 'w+' , delete = True ) as temp_file :
317
+ compiled_predictor .save (temp_file .name )
318
+ loaded_predictor = dspy .Predict (ComplexPydanticSignature )
319
+ loaded_predictor .load (temp_file .name )
320
+
321
+ result = loaded_predictor (model_input = example_model )
322
+ lm .inspect_history ()
323
+ assert result .caption == "A dog photo"
324
+ assert messages_contain_image_url_pattern (lm .history [- 1 ]["messages" ], n = 2 )
325
+
326
+ def test_image_repr ():
327
+ """Test string representation of Image objects with both URLs and PIL images"""
328
+ # Test URL-based image repr and str
329
+ url_image = dspy .Image .from_url ("https://example.com/dog.jpg" , download = False )
330
+ assert str (url_image ) == "<DSPY_IMAGE_START>https://example.com/dog.jpg<DSPY_IMAGE_END>"
331
+ assert repr (url_image ) == "Image(url='https://example.com/dog.jpg')"
332
+
333
+ # Test PIL image repr and str
334
+ sample_pil = PILImage .new ('RGB' , (60 , 30 ), color = 'red' )
335
+ pil_image = dspy .Image .from_PIL (sample_pil )
336
+ # Test str() behavior
337
+ assert str (pil_image ).startswith ("<DSPY_IMAGE_START>data:image/png;base64," )
338
+ assert str (pil_image ).endswith ("<DSPY_IMAGE_END>" )
339
+ # Test repr() behavior
340
+ repr_str = repr (pil_image )
341
+ assert repr_str .startswith ("Image(url=data:image/...base64,<IMAGE_BASE_64_ENCODED(" )
342
+ assert repr_str .endswith (")>)" )
343
+ assert "base64" in str (pil_image )
344
+
345
+ def test_image_optional_input ():
346
+ """Test behavior when optional image inputs are missing"""
347
+ class OptionalImageSignature (dspy .Signature ):
348
+ image : Optional [dspy .Image ] = dspy .InputField (desc = "An optional image input" )
349
+ text : str = dspy .InputField (desc = "A text input" )
350
+ output : str = dspy .OutputField (desc = "The output text" )
351
+
352
+ lm = DummyLM ([{"output" : "Text only: hello" }, {"output" : "Image and text: hello with image" }])
353
+ dspy .settings .configure (lm = lm )
354
+
355
+ predictor = dspy .Predict (OptionalImageSignature )
356
+
357
+ # Test with missing image
358
+ result = predictor (image = None , text = "hello" )
359
+ assert result .output == "Text only: hello"
360
+ assert not messages_contain_image_url_pattern (lm .history [- 1 ]["messages" ])
361
+
362
+ lm .inspect_history ()
363
+ print (lm .history [- 1 ]["messages" ])
364
+ assert False
365
+ # Test with image present
366
+ result = predictor (
367
+ image = dspy .Image .from_url ("https://example.com/image.jpg" ),
368
+ text = "hello"
369
+ )
370
+ assert result .output == "Image and text: hello with image"
371
+ assert messages_contain_image_url_pattern (lm .history [- 1 ]["messages" ])
372
+
373
+ # Tests to write:
374
+ # complex image types
375
+ # Return "None" when dspy.Image is missing for the input, similarly to str input fields
376
+ # JSON adapter
0 commit comments