Skip to content

Commit e45d332

Browse files
sjrlsilvanocerzaAmnah199dfokina
authored
feat: Adding DALLE image generator (#8448)
* First pass at adding DALLE image generator * Add missing header * Fix tests * Add tests * Fix mypy * Make mypy happy * More unit tests * Adding release notes * Add a test for run * Update haystack/components/generators/openai_dalle.py Co-authored-by: Silvano Cerza <[email protected]> * Fix pylint * Update haystack/components/generators/openai_dalle.py Co-authored-by: Amna Mubashar <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> --------- Co-authored-by: Silvano Cerza <[email protected]> Co-authored-by: Amna Mubashar <[email protected]> Co-authored-by: Daria Fokina <[email protected]>
1 parent a045c0e commit e45d332

File tree

5 files changed

+327
-1
lines changed

5 files changed

+327
-1
lines changed

docs/pydoc/config/generators_api.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ loaders:
77
"hugging_face_local",
88
"hugging_face_api",
99
"openai",
10+
"openai_dalle",
1011
"chat/azure",
1112
"chat/hugging_face_local",
1213
"chat/hugging_face_api",

haystack/components/generators/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,12 @@
88
from haystack.components.generators.azure import AzureOpenAIGenerator
99
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
1010
from haystack.components.generators.hugging_face_api import HuggingFaceAPIGenerator
11+
from haystack.components.generators.openai_dalle import DALLEImageGenerator
1112

12-
__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceAPIGenerator", "OpenAIGenerator", "AzureOpenAIGenerator"]
13+
__all__ = [
14+
"HuggingFaceLocalGenerator",
15+
"HuggingFaceAPIGenerator",
16+
"OpenAIGenerator",
17+
"AzureOpenAIGenerator",
18+
"DALLEImageGenerator",
19+
]
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import os
6+
from typing import Any, Dict, List, Literal, Optional
7+
8+
from openai import OpenAI
9+
from openai.types.image import Image
10+
11+
from haystack import component, default_from_dict, default_to_dict, logging
12+
from haystack.utils import Secret, deserialize_secrets_inplace
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
@component
class DALLEImageGenerator:
    """
    Generates images using OpenAI's DALL-E model.

    For details on OpenAI API parameters, see
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/images/create).

    ### Usage example

    ```python
    from haystack.components.generators import DALLEImageGenerator
    image_generator = DALLEImageGenerator()
    image_generator.warm_up()
    response = image_generator.run("Show me a picture of a black cat.")
    print(response)
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str = "dall-e-3",
        quality: Literal["standard", "hd"] = "standard",
        size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
        response_format: Literal["url", "b64_json"] = "url",
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
    ):
        """
        Creates an instance of DALLEImageGenerator. Unless specified otherwise in `model`, uses OpenAI's dall-e-3.

        :param model: The model to use for image generation. Can be "dall-e-2" or "dall-e-3".
        :param quality: The quality of the generated image. Can be "standard" or "hd".
        :param size: The size of the generated images.
            Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2.
            Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models.
        :param response_format: The format of the response. Can be "url" or "b64_json".
        :param api_key: The OpenAI API key to connect to OpenAI.
        :param api_base_url: An optional base URL.
        :param organization: The Organization ID, defaults to `None`.
        :param timeout:
            Timeout for OpenAI Client calls. If `None`, it is inferred from the `OPENAI_TIMEOUT` environment
            variable or set to 30.
        :param max_retries:
            Maximum retries to establish contact with OpenAI if it returns an internal error. If `None`, it is
            inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5. An explicit `0` is kept as is.
        """
        self.model = model
        self.quality = quality
        self.size = size
        self.response_format = response_format
        self.api_key = api_key
        self.api_base_url = api_base_url
        self.organization = organization

        # Compare against None rather than relying on truthiness: an explicit 0
        # (e.g. `max_retries=0`) is a deliberate value and must not be replaced
        # by the environment/default fallback.
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
        self.timeout = timeout
        self.max_retries = max_retries

        # The client is created in `warm_up()`, not here.
        self.client: Optional[OpenAI] = None

    def warm_up(self) -> None:
        """
        Warm up the OpenAI client.

        Creates the client on the first call; later calls leave the existing client untouched.
        """
        if self.client is None:
            self.client = OpenAI(
                api_key=self.api_key.resolve_value(),
                organization=self.organization,
                base_url=self.api_base_url,
                timeout=self.timeout,
                max_retries=self.max_retries,
            )

    @component.output_types(images=List[str], revised_prompt=str)
    def run(
        self,
        prompt: str,
        size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = None,
        quality: Optional[Literal["standard", "hd"]] = None,
        response_format: Optional[Literal["url", "b64_json"]] = None,
    ):
        """
        Invokes the image generation inference based on the provided prompt and generation parameters.

        :param prompt: The prompt to generate the image.
        :param size: If provided, overrides the size provided during initialization.
        :param quality: If provided, overrides the quality provided during initialization.
        :param response_format: If provided, overrides the response format provided during initialization.

        :returns:
            A dictionary containing the generated list of images and the revised prompt.
            Depending on the `response_format` parameter, the list of images can be URLs or base64 encoded JSON strings.
            The revised prompt is the prompt that was used to generate the image, if there was any revision
            to the prompt made by OpenAI.

        :raises RuntimeError:
            If `warm_up()` was not called before `run()`.
        """
        if self.client is None:
            raise RuntimeError(
                "The component DALLEImageGenerator wasn't warmed up. Run 'warm_up()' before calling 'run()'."
            )

        # Per-call arguments take precedence over the values given at init time.
        size = size or self.size
        quality = quality or self.quality
        response_format = response_format or self.response_format
        response = self.client.images.generate(
            model=self.model, prompt=prompt, size=size, quality=quality, response_format=response_format, n=1
        )
        # Exactly one image is requested (n=1), so only the first entry is read.
        image: Image = response.data[0]
        # Use whichever payload field is populated; fall back to an empty string.
        image_str = image.url or image.b64_json or ""
        return {"images": [image_str], "revised_prompt": image.revised_prompt or ""}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
        """
        # NOTE(review): `timeout` and `max_retries` are not part of the serialized
        # init parameters, so they are not carried over by a to_dict/from_dict
        # round trip - confirm whether that is intended.
        return default_to_dict(  # type: ignore
            self,
            model=self.model,
            quality=self.quality,
            size=self.size,
            response_format=self.response_format,
            api_key=self.api_key.to_dict(),
            api_base_url=self.api_base_url,
            organization=self.organization,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DALLEImageGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        # Secrets are deserialized in place before the generic loader runs.
        deserialize_secrets_inplace(init_params, keys=["api_key"])
        return default_from_dict(cls, data)  # type: ignore
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
---
2+
features:
3+
- |
4+
We've added a new **DALLEImageGenerator** component, bringing image generation with OpenAI's DALL-E to Haystack.
5+
6+
- **Easy to Use**: Just a few lines of code to get started:
7+
```python
8+
from haystack.components.generators import DALLEImageGenerator
9+
image_generator = DALLEImageGenerator()
image_generator.warm_up()
10+
response = image_generator.run("Show me a picture of a black cat.")
11+
print(response)
12+
```
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import pytest
6+
from unittest.mock import Mock, patch
7+
from haystack.utils import Secret
8+
9+
from openai.types.image import Image
10+
from openai.types import ImagesResponse
11+
from haystack.components.generators.openai_dalle import DALLEImageGenerator
12+
13+
14+
@pytest.fixture
def mock_image_response():
    """
    Patch `openai.resources.images.Images.generate` and yield the mock.

    The mock returns a canned `ImagesResponse` holding a single image with
    URL "test-url" and revised prompt "test-prompt".
    """
    fake_image = Image(url="test-url", revised_prompt="test-prompt")
    canned_response = ImagesResponse(created=1630000000, data=[fake_image])
    with patch("openai.resources.images.Images.generate", return_value=canned_response) as generate_mock:
        yield generate_mock
20+
21+
22+
class TestDALLEImageGenerator:
    """Unit tests for `DALLEImageGenerator`: construction, warm-up, (de)serialization and `run`."""

    def test_init_default(self, monkeypatch):
        # The asserted defaults are only used when these variables are unset,
        # so clear them to keep the test independent of the host environment.
        monkeypatch.delenv("OPENAI_TIMEOUT", raising=False)
        monkeypatch.delenv("OPENAI_MAX_RETRIES", raising=False)
        component = DALLEImageGenerator()
        assert component.model == "dall-e-3"
        assert component.quality == "standard"
        assert component.size == "1024x1024"
        assert component.response_format == "url"
        assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert component.api_base_url is None
        assert component.organization is None
        assert pytest.approx(component.timeout) == 30.0
        # Compare by value; `is` on an int literal tests object identity.
        assert component.max_retries == 5

    def test_init_with_params(self, monkeypatch):
        component = DALLEImageGenerator(
            model="dall-e-2",
            quality="hd",
            size="256x256",
            response_format="b64_json",
            api_key=Secret.from_env_var("EXAMPLE_API_KEY"),
            api_base_url="https://api.openai.com",
            organization="test-org",
            timeout=60,
            max_retries=10,
        )
        assert component.model == "dall-e-2"
        assert component.quality == "hd"
        assert component.size == "256x256"
        assert component.response_format == "b64_json"
        assert component.api_key == Secret.from_env_var("EXAMPLE_API_KEY")
        assert component.api_base_url == "https://api.openai.com"
        assert component.organization == "test-org"
        assert pytest.approx(component.timeout) == 60.0
        assert component.max_retries == 10

    def test_warm_up(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        # The client settings asserted below fall back to these variables when set.
        monkeypatch.delenv("OPENAI_TIMEOUT", raising=False)
        monkeypatch.delenv("OPENAI_MAX_RETRIES", raising=False)
        component = DALLEImageGenerator()
        component.warm_up()
        assert component.client.api_key == "test-api-key"
        assert component.client.timeout == 30
        assert component.client.max_retries == 5

    def test_to_dict(self):
        generator = DALLEImageGenerator()
        data = generator.to_dict()
        assert data == {
            "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
            "init_parameters": {
                "model": "dall-e-3",
                "quality": "standard",
                "size": "1024x1024",
                "response_format": "url",
                "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                "api_base_url": None,
                "organization": None,
            },
        }

    def test_to_dict_with_params(self):
        generator = DALLEImageGenerator(
            model="dall-e-2",
            quality="hd",
            size="256x256",
            response_format="b64_json",
            api_key=Secret.from_env_var("EXAMPLE_API_KEY"),
            api_base_url="https://api.openai.com",
            organization="test-org",
            timeout=60,
            max_retries=10,
        )
        data = generator.to_dict()
        assert data == {
            "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
            "init_parameters": {
                "model": "dall-e-2",
                "quality": "hd",
                "size": "256x256",
                "response_format": "b64_json",
                "api_key": {"type": "env_var", "env_vars": ["EXAMPLE_API_KEY"], "strict": True},
                "api_base_url": "https://api.openai.com",
                "organization": "test-org",
            },
        }

    def test_from_dict(self):
        data = {
            "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
            "init_parameters": {
                "model": "dall-e-3",
                "quality": "standard",
                "size": "1024x1024",
                "response_format": "url",
                "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                "api_base_url": None,
                "organization": None,
            },
        }
        generator = DALLEImageGenerator.from_dict(data)
        assert generator.model == "dall-e-3"
        assert generator.quality == "standard"
        assert generator.size == "1024x1024"
        assert generator.response_format == "url"
        assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}

    def test_from_dict_default_params(self):
        data = {"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", "init_parameters": {}}
        generator = DALLEImageGenerator.from_dict(data)
        assert generator.model == "dall-e-3"
        assert generator.quality == "standard"
        assert generator.size == "1024x1024"
        assert generator.response_format == "url"
        assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}
        assert generator.api_base_url is None
        assert generator.organization is None
        assert pytest.approx(generator.timeout) == 30.0
        assert generator.max_retries == 5

    def test_run(self, mock_image_response):
        generator = DALLEImageGenerator(api_key=Secret.from_token("test-api-key"))
        generator.warm_up()
        response = generator.run("Show me a picture of a black cat.")
        assert isinstance(response, dict)
        assert "images" in response and "revised_prompt" in response
        assert response["images"] == ["test-url"]
        assert response["revised_prompt"] == "test-prompt"

0 commit comments

Comments
 (0)