Skip to content

Commit 0e562bc

Browse files
Update OpenAI text embeddings (#627)
Co-authored-by: Wendong-Fan <[email protected]> Co-authored-by: Wendong <[email protected]>
1 parent aa16295 commit 0e562bc

File tree

6 files changed

+48
-38
lines changed

6 files changed

+48
-38
lines changed

camel/embeddings/openai_embedding.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
from __future__ import annotations
15+
1416
import os
15-
from typing import Any, List, Optional
17+
from typing import Any
1618

17-
from openai import OpenAI
19+
from openai import NOT_GIVEN, NotGiven, OpenAI
1820

1921
from camel.embeddings.base import BaseEmbedding
2022
from camel.types import EmbeddingModelType
@@ -25,47 +27,58 @@ class OpenAIEmbedding(BaseEmbedding[str]):
2527
r"""Provides text embedding functionalities using OpenAI's models.
2628
2729
Args:
28-
model (OpenAiEmbeddingModel, optional): The model type to be used for
29-
generating embeddings. (default: :obj:`ModelType.ADA_2`)
30-
api_key (Optional[str]): The API key for authenticating with the
30+
model_type (EmbeddingModelType, optional): The model type to be
31+
used for text embeddings.
32+
(default: :obj:`TEXT_EMBEDDING_3_SMALL`)
33+
api_key (str, optional): The API key for authenticating with the
3134
OpenAI service. (default: :obj:`None`)
35+
dimensions (int, optional): The text embedding output dimensions.
36+
(default: :obj:`NOT_GIVEN`)
3237
3338
Raises:
3439
RuntimeError: If an unsupported model type is specified.
3540
"""
3641

3742
def __init__(
3843
self,
39-
model_type: EmbeddingModelType = EmbeddingModelType.ADA_2,
40-
api_key: Optional[str] = None,
44+
model_type: EmbeddingModelType = (
45+
EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
46+
),
47+
api_key: str | None = None,
48+
dimensions: int | NotGiven = NOT_GIVEN,
4149
) -> None:
4250
if not model_type.is_openai:
4351
raise ValueError("Invalid OpenAI embedding model type.")
4452
self.model_type = model_type
45-
self.output_dim = model_type.output_dim
53+
if dimensions == NOT_GIVEN:
54+
self.output_dim = model_type.output_dim
55+
else:
56+
assert isinstance(dimensions, int)
57+
self.output_dim = dimensions
4658
self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
4759
self.client = OpenAI(timeout=60, max_retries=3, api_key=self._api_key)
4860

4961
@model_api_key_required
5062
def embed_list(
5163
self,
52-
objs: List[str],
64+
objs: list[str],
5365
**kwargs: Any,
54-
) -> List[List[float]]:
66+
) -> list[list[float]]:
5567
r"""Generates embeddings for the given texts.
5668
5769
Args:
58-
objs (List[str]): The texts for which to generate the embeddings.
70+
objs (list[str]): The texts for which to generate the embeddings.
5971
**kwargs (Any): Extra kwargs passed to the embedding API.
6072
6173
Returns:
62-
List[List[float]]: A list that represents the generated embedding
74+
list[list[float]]: A list that represents the generated embedding
6375
as a list of floating-point numbers.
6476
"""
6577
# TODO: count tokens
6678
response = self.client.embeddings.create(
6779
input=objs,
6880
model=self.model_type.value,
81+
dimensions=self.output_dim,
6982
**kwargs,
7083
)
7184
return [data.embedding for data in response.data]

camel/retrievers/auto_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ def _get_file_modified_date_from_storage(
159159
) -> str:
160160
r"""Retrieves the last modified date and time of a given file. This
161161
function takes vector storage instance as input and returns the last
162-
modified date from the meta data.
162+
modified date from the metadata.
163163
164164
Args:
165165
vector_storage_instance (BaseVectorStorage): The vector storage
166-
where modified date is to be retrieved from meta data.
166+
where modified date is to be retrieved from metadata.
167167
168168
Returns:
169169
str: The last modified date from vector storage.

camel/types/enums.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -186,35 +186,27 @@ def validate_model_name(self, model_name: str) -> bool:
186186

187187

188188
class EmbeddingModelType(Enum):
189-
ADA_2 = "text-embedding-ada-002"
190-
ADA_1 = "text-embedding-ada-001"
191-
BABBAGE_1 = "text-embedding-babbage-001"
192-
CURIE_1 = "text-embedding-curie-001"
193-
DAVINCI_1 = "text-embedding-davinci-001"
189+
TEXT_EMBEDDING_ADA_2 = "text-embedding-ada-002"
190+
TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
191+
TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large"
194192

195193
@property
196194
def is_openai(self) -> bool:
197195
r"""Returns whether this type of models is an OpenAI-released model."""
198196
return self in {
199-
EmbeddingModelType.ADA_2,
200-
EmbeddingModelType.ADA_1,
201-
EmbeddingModelType.BABBAGE_1,
202-
EmbeddingModelType.CURIE_1,
203-
EmbeddingModelType.DAVINCI_1,
197+
EmbeddingModelType.TEXT_EMBEDDING_ADA_2,
198+
EmbeddingModelType.TEXT_EMBEDDING_3_SMALL,
199+
EmbeddingModelType.TEXT_EMBEDDING_3_LARGE,
204200
}
205201

206202
@property
207203
def output_dim(self) -> int:
208-
if self is EmbeddingModelType.ADA_2:
204+
if self is EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
209205
return 1536
210-
elif self is EmbeddingModelType.ADA_1:
211-
return 1024
212-
elif self is EmbeddingModelType.BABBAGE_1:
213-
return 2048
214-
elif self is EmbeddingModelType.CURIE_1:
215-
return 4096
216-
elif self is EmbeddingModelType.DAVINCI_1:
217-
return 12288
206+
elif self is EmbeddingModelType.TEXT_EMBEDDING_3_SMALL:
207+
return 1536
208+
elif self is EmbeddingModelType.TEXT_EMBEDDING_3_LARGE:
209+
return 3072
218210
else:
219211
raise ValueError(f"Unknown model type {self}.")
220212

docs/key_modules/embeddings.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Consider two sentences:
1212

1313
- "A child is kicking a football on a playground."
1414

15-
Text embedding models would transform these sentences into two high-dimensional vector (*e.g.*, 1536 dimension if using `text-embedding-ada-002`). Despite different wordings, the vectors will be similar, capturing the shared concept of a child playing a ball game outdoors. This transformation into vectors allows machines to understand and compare the semantic similarities between the context.
15+
Text embedding models would transform these sentences into two high-dimensional vector (*e.g.*, 1536 dimension if using `text-embedding-3-small`). Despite different wordings, the vectors will be similar, capturing the shared concept of a child playing a ball game outdoors. This transformation into vectors allows machines to understand and compare the semantic similarities between the context.
1616

1717
### 1.2. Image Embeddings (WIP)
1818
Image embeddings convert images into numerical vectors, capturing essential features like shapes, colors, textures, and spatial hierarchies. This transformation is typically performed by Convolutional Neural Networks (CNNs) or other advanced neural network architectures designed for image processing. The resulting embeddings can be used for tasks like image classification, similarity comparison, and retrieval.
@@ -39,7 +39,7 @@ from camel.embeddings import OpenAIEmbedding
3939
from camel.types import EmbeddingModelType
4040

4141
# Initialize the OpenAI embedding with a specific model
42-
openai_embedding = OpenAIEmbedding(model_type=EmbeddingModelType.ADA_2)
42+
openai_embedding = OpenAIEmbedding(model_type=EmbeddingModelType.TEXT_EMBEDDING_3_SMALL)
4343

4444
# Generate embeddings for a list of texts
4545
embeddings = openai_embedding.embed_list(["Hello, world!", "Another example"])

test/embeddings/test_openai_embedding.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515

1616

1717
def test_openai_embedding():
18-
text = "test embedding text."
1918
embedding_model = OpenAIEmbedding()
19+
text = "test 1."
2020
vector = embedding_model.embed(text)
2121
assert len(vector) == embedding_model.get_output_dim()
22+
23+
embedding_model = OpenAIEmbedding(dimensions=256)
24+
text = "test 2"
25+
vector = embedding_model.embed(text)
26+
assert len(vector) == embedding_model.get_output_dim() == 256

test/retrievers/test_auto_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ def test_get_file_modified_date_from_file(auto_retriever):
6868

6969
def test_run_vector_retriever(auto_retriever):
7070
# Define mock data for testing
71-
query_related = "what is camel"
71+
query_related = "what is camel ai"
7272
query_unrealted = "unrelated query"
7373
content_input_paths = "https://www.camel-ai.org/"
7474
top_k = 1
75-
similarity_threshold = 0.75
75+
similarity_threshold = 0.5
7676

7777
# Test with query related to the content in mock data
7878
result_related = auto_retriever.run_vector_retriever(

0 commit comments

Comments
 (0)