@@ -51,6 +51,8 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
51
51
embedding_separator : str = "\n " ,
52
52
timeout : Optional [float ] = None ,
53
53
max_retries : Optional [int ] = None ,
54
+ * ,
55
+ default_headers : Optional [Dict [str , str ]] = None ,
54
56
):
55
57
"""
56
58
Creates an AzureOpenAIDocumentEmbedder component.
@@ -95,6 +97,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
95
97
`OPENAI_TIMEOUT` environment variable, or 30 seconds.
96
98
:param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
97
99
If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
100
+ :param default_headers: Default headers to send to the AzureOpenAI client.
98
101
"""
99
102
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
100
103
azure_endpoint = azure_endpoint or os .environ .get ("AZURE_OPENAI_ENDPOINT" )
@@ -119,6 +122,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
119
122
self .embedding_separator = embedding_separator
120
123
self .timeout = timeout or float (os .environ .get ("OPENAI_TIMEOUT" , 30.0 ))
121
124
self .max_retries = max_retries or int (os .environ .get ("OPENAI_MAX_RETRIES" , 5 ))
125
+ self .default_headers = default_headers or {}
122
126
123
127
self ._client = AzureOpenAI (
124
128
api_version = api_version ,
@@ -129,6 +133,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
129
133
organization = organization ,
130
134
timeout = self .timeout ,
131
135
max_retries = self .max_retries ,
136
+ default_headers = self .default_headers ,
132
137
)
133
138
134
139
def _get_telemetry_data (self ) -> Dict [str , Any ]:
@@ -161,6 +166,7 @@ def to_dict(self) -> Dict[str, Any]:
161
166
azure_ad_token = self .azure_ad_token .to_dict () if self .azure_ad_token is not None else None ,
162
167
timeout = self .timeout ,
163
168
max_retries = self .max_retries ,
169
+ default_headers = self .default_headers ,
164
170
)
165
171
166
172
@classmethod
0 commit comments