Skip to content

Commit 96b9d3e

Browse files
fix: Adding missing component decorator to AzureOpenAIGenerator (#7698)
* initial import * adding release notes * tests avoiding I/O operations * Update fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml
1 parent cc1d4b1 commit 96b9d3e

File tree

5 files changed

+29
-2
lines changed

5 files changed

+29
-2
lines changed

haystack/components/generators/azure.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
# pylint: disable=import-error
99
from openai.lib.azure import AzureOpenAI
1010

11-
from haystack import default_from_dict, default_to_dict, logging
11+
from haystack import component, default_from_dict, default_to_dict, logging
1212
from haystack.components.generators import OpenAIGenerator
1313
from haystack.dataclasses import StreamingChunk
1414
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1515

1616
logger = logging.getLogger(__name__)
1717

1818

19+
@component
1920
class AzureOpenAIGenerator(OpenAIGenerator):
2021
"""
2122
A Generator component that uses OpenAI's large language models (LLMs) on Azure to generate text.

haystack/components/generators/chat/azure.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
# pylint: disable=import-error
99
from openai.lib.azure import AzureOpenAI
1010

11-
from haystack import default_from_dict, default_to_dict, logging
11+
from haystack import component, default_from_dict, default_to_dict, logging
1212
from haystack.components.generators.chat import OpenAIChatGenerator
1313
from haystack.dataclasses import StreamingChunk
1414
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1515

1616
logger = logging.getLogger(__name__)
1717

1818

19+
@component
1920
class AzureOpenAIChatGenerator(OpenAIChatGenerator):
2021
"""
2122
A Chat Generator component that uses the Azure OpenAI API to generate text.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Azure generators components fixed, they were missing the `@component` decorator.

test/components/generators/chat/test_azure.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
from openai import OpenAIError
88

9+
from haystack import Pipeline
910
from haystack.components.generators.chat import AzureOpenAIChatGenerator
1011
from haystack.components.generators.utils import print_streaming_chunk
1112
from haystack.dataclasses import ChatMessage
@@ -80,6 +81,15 @@ def test_to_dict_with_parameters(self, monkeypatch):
8081
},
8182
}
8283

84+
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
85+
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
86+
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
87+
p = Pipeline()
88+
p.add_component(instance=generator, name="generator")
89+
p_str = p.dumps()
90+
q = Pipeline.loads(p_str)
91+
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
92+
8393
@pytest.mark.integration
8494
@pytest.mark.skipif(
8595
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

test/components/generators/test_azure.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
import os
5+
6+
from haystack import Pipeline
57
from haystack.utils.auth import Secret
68

79
import pytest
@@ -83,6 +85,15 @@ def test_to_dict_with_parameters(self, monkeypatch):
8385
},
8486
}
8587

88+
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
89+
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
90+
generator = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
91+
p = Pipeline()
92+
p.add_component(instance=generator, name="generator")
93+
p_str = p.dumps()
94+
q = Pipeline.loads(p_str)
95+
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with AzureOpenAIGenerator failed."
96+
8697
@pytest.mark.integration
8798
@pytest.mark.skipif(
8899
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

0 commit comments

Comments
 (0)