Skip to content

Commit db76ae2

Browse files
authored
feat: add default_headers for Azure embedders (#8699)
* Add default_headers param to azure embedders
1 parent 4f73b19 commit db76ae2

File tree

5 files changed

+82
-0
lines changed

5 files changed

+82
-0
lines changed

Diff for: haystack/components/embedders/azure_document_embedder.py

+6
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,8 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
5151
embedding_separator: str = "\n",
5252
timeout: Optional[float] = None,
5353
max_retries: Optional[int] = None,
54+
*,
55+
default_headers: Optional[Dict[str, str]] = None,
5456
):
5557
"""
5658
Creates an AzureOpenAIDocumentEmbedder component.
@@ -95,6 +97,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
9597
`OPENAI_TIMEOUT` environment variable, or 30 seconds.
9698
:param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
9799
If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
100+
:param default_headers: Default headers to send to the AzureOpenAI client.
98101
"""
99102
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
100103
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
@@ -119,6 +122,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
119122
self.embedding_separator = embedding_separator
120123
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
121124
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
125+
self.default_headers = default_headers or {}
122126

123127
self._client = AzureOpenAI(
124128
api_version=api_version,
@@ -129,6 +133,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
129133
organization=organization,
130134
timeout=self.timeout,
131135
max_retries=self.max_retries,
136+
default_headers=self.default_headers,
132137
)
133138

134139
def _get_telemetry_data(self) -> Dict[str, Any]:
@@ -161,6 +166,7 @@ def to_dict(self) -> Dict[str, Any]:
161166
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
162167
timeout=self.timeout,
163168
max_retries=self.max_retries,
169+
default_headers=self.default_headers,
164170
)
165171

166172
@classmethod

Diff for: haystack/components/embedders/azure_text_embedder.py

+6
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
4646
max_retries: Optional[int] = None,
4747
prefix: str = "",
4848
suffix: str = "",
49+
*,
50+
default_headers: Optional[Dict[str, str]] = None,
4951
):
5052
"""
5153
Creates an AzureOpenAITextEmbedder component.
@@ -82,6 +84,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
8284
A string to add at the beginning of each text.
8385
:param suffix:
8486
A string to add at the end of each text.
87+
:param default_headers: Default headers to send to the AzureOpenAI client.
8588
"""
8689
# Why is this here?
8790
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
@@ -105,6 +108,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
105108
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
106109
self.prefix = prefix
107110
self.suffix = suffix
111+
self.default_headers = default_headers or {}
108112

109113
self._client = AzureOpenAI(
110114
api_version=api_version,
@@ -115,6 +119,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
115119
organization=organization,
116120
timeout=self.timeout,
117121
max_retries=self.max_retries,
122+
default_headers=self.default_headers,
118123
)
119124

120125
def _get_telemetry_data(self) -> Dict[str, Any]:
@@ -143,6 +148,7 @@ def to_dict(self) -> Dict[str, Any]:
143148
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
144149
timeout=self.timeout,
145150
max_retries=self.max_retries,
151+
default_headers=self.default_headers,
146152
)
147153

148154
@classmethod
Diff for: release note YAML file (file path not present in this capture)

Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
---
2+
enhancements:
3+
- |
4+
Added `default_headers` parameter to `AzureOpenAIDocumentEmbedder` and `AzureOpenAITextEmbedder`.

Diff for: test/components/embedders/test_azure_document_embedder.py

+35
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def test_init_default(self, monkeypatch):
2222
assert embedder.progress_bar is True
2323
assert embedder.meta_fields_to_embed == []
2424
assert embedder.embedding_separator == "\n"
25+
assert embedder.default_headers == {}
2526

2627
def test_to_dict(self, monkeypatch):
2728
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
@@ -45,9 +46,43 @@ def test_to_dict(self, monkeypatch):
4546
"embedding_separator": "\n",
4647
"max_retries": 5,
4748
"timeout": 30.0,
49+
"default_headers": {},
4850
},
4951
}
5052

53+
def test_from_dict(self, monkeypatch):
54+
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
55+
data = {
56+
"type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
57+
"init_parameters": {
58+
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
59+
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
60+
"api_version": "2023-05-15",
61+
"azure_deployment": "text-embedding-ada-002",
62+
"dimensions": None,
63+
"azure_endpoint": "https://example-resource.azure.openai.com/",
64+
"organization": None,
65+
"prefix": "",
66+
"suffix": "",
67+
"batch_size": 32,
68+
"progress_bar": True,
69+
"meta_fields_to_embed": [],
70+
"embedding_separator": "\n",
71+
"max_retries": 5,
72+
"timeout": 30.0,
73+
"default_headers": {},
74+
},
75+
}
76+
component = AzureOpenAIDocumentEmbedder.from_dict(data)
77+
assert component.azure_deployment == "text-embedding-ada-002"
78+
assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
79+
assert component.api_version == "2023-05-15"
80+
assert component.max_retries == 5
81+
assert component.timeout == 30.0
82+
assert component.prefix == ""
83+
assert component.suffix == ""
84+
assert component.default_headers == {}
85+
5186
@pytest.mark.integration
5287
@pytest.mark.skipif(
5388
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

Diff for: test/components/embedders/test_azure_text_embedder.py

+31
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ def test_init_default(self, monkeypatch):
1919
assert embedder.organization is None
2020
assert embedder.prefix == ""
2121
assert embedder.suffix == ""
22+
assert embedder.default_headers == {}
2223

2324
def test_to_dict(self, monkeypatch):
2425
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
@@ -38,9 +39,39 @@ def test_to_dict(self, monkeypatch):
3839
"timeout": 30.0,
3940
"prefix": "",
4041
"suffix": "",
42+
"default_headers": {},
4143
},
4244
}
4345

46+
def test_from_dict(self, monkeypatch):
47+
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
48+
data = {
49+
"type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
50+
"init_parameters": {
51+
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
52+
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
53+
"azure_deployment": "text-embedding-ada-002",
54+
"dimensions": None,
55+
"organization": None,
56+
"azure_endpoint": "https://example-resource.azure.openai.com/",
57+
"api_version": "2023-05-15",
58+
"max_retries": 5,
59+
"timeout": 30.0,
60+
"prefix": "",
61+
"suffix": "",
62+
"default_headers": {},
63+
},
64+
}
65+
component = AzureOpenAITextEmbedder.from_dict(data)
66+
assert component.azure_deployment == "text-embedding-ada-002"
67+
assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
68+
assert component.api_version == "2023-05-15"
69+
assert component.max_retries == 5
70+
assert component.timeout == 30.0
71+
assert component.prefix == ""
72+
assert component.suffix == ""
73+
assert component.default_headers == {}
74+
4475
@pytest.mark.integration
4576
@pytest.mark.skipif(
4677
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

0 commit comments

Comments
 (0)