Skip to content

Commit 39ccc3d

Browse files
To support tuple of string for Embeddings and AsyncEmbeddings openai#1934
1 parent 89d4933 commit 39ccc3d

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/openai/resources/embeddings.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import base64
6-
from typing import List, Union, Iterable, cast
6+
from typing import List, Tuple, Union, Iterable, cast
77
from typing_extensions import Literal
88

99
import httpx
@@ -46,7 +46,7 @@ def with_streaming_response(self) -> EmbeddingsWithStreamingResponse:
4646
def create(
4747
self,
4848
*,
49-
input: Union[str, List[str], Iterable[int], Iterable[Iterable[int]]],
49+
input: Union[str, List[str], Tuple[str], Iterable[int], Iterable[Iterable[int]]],
5050
model: Union[str, EmbeddingModel],
5151
dimensions: int | NotGiven = NOT_GIVEN,
5252
encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN,
@@ -94,6 +94,10 @@ def create(
9494
9595
timeout: Override the client-level default timeout for this request, in seconds
9696
"""
97+
98+
if isinstance(input, tuple) and all(isinstance(item, str) for item in input):
99+
input = list(input)
100+
97101
params = {
98102
"input": input,
99103
"model": model,
@@ -158,7 +162,7 @@ def with_streaming_response(self) -> AsyncEmbeddingsWithStreamingResponse:
158162
async def create(
159163
self,
160164
*,
161-
input: Union[str, List[str], Iterable[int], Iterable[Iterable[int]]],
165+
input: Union[str, List[str], Tuple[str], Iterable[int], Iterable[Iterable[int]]],
162166
model: Union[str, EmbeddingModel],
163167
dimensions: int | NotGiven = NOT_GIVEN,
164168
encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN,
@@ -206,6 +210,10 @@ async def create(
206210
207211
timeout: Override the client-level default timeout for this request, in seconds
208212
"""
213+
214+
if isinstance(input, tuple) and all(isinstance(item, str) for item in input):
215+
input = list(input)
216+
209217
params = {
210218
"input": input,
211219
"model": model,

0 commit comments

Comments
 (0)