Skip to content

Commit a65d44e

Browse files
authored
feat: Internal deduplication impl. (#1568)
1 parent 4536d76 commit a65d44e

File tree

4 files changed

+405
-0
lines changed

4 files changed

+405
-0
lines changed

camel/utils/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
track_agent,
4141
)
4242
from .constants import Constants
43+
from .deduplication import DeduplicationResult, deduplicate_internally
4344
from .response_format import get_pydantic_model
4445
from .token_counting import (
4546
AnthropicTokenCounter,
@@ -82,6 +83,8 @@
8283
"get_pydantic_model",
8384
"download_github_subdirectory",
8485
"generate_prompt_for_structured_output",
86+
"deduplicate_internally",
87+
"DeduplicationResult",
8588
"retry_on_error",
8689
"BatchProcessor",
8790
]

camel/utils/deduplication.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14+
15+
16+
from typing import Dict, List, Literal, Optional
17+
18+
import numpy as np
19+
from pydantic import BaseModel
20+
from sklearn.metrics.pairwise import cosine_similarity
21+
22+
from camel.embeddings.base import BaseEmbedding
23+
24+
25+
class DeduplicationResult(BaseModel):
26+
"""
27+
The result of deduplication.
28+
29+
Attributes:
30+
original_texts (List[str]): The original texts.
31+
unique_ids (List[int]): A list of ids that are unique (not duplicates).
32+
unique_embeddings_dict (Dict[int, List[float]]):
33+
A mapping from the index of each unique text to its embedding.
34+
duplicate_to_target_map (Dict[int, int]):
35+
A mapping from the index of the duplicate text to the index
36+
of the text it is considered a duplicate of.
37+
"""
38+
39+
original_texts: List[str]
40+
unique_ids: List[int]
41+
unique_embeddings_dict: Dict[int, List[float]]
42+
duplicate_to_target_map: Dict[int, int]
43+
44+
45+
def deduplicate_internally(
46+
texts: List[str],
47+
threshold: float = 0.65,
48+
embedding_instance: Optional[BaseEmbedding[str]] = None,
49+
embeddings: Optional[List[List[float]]] = None,
50+
strategy: Literal["top1", "llm-supervise"] = "top1",
51+
) -> DeduplicationResult:
52+
"""
53+
Deduplicate a list of strings based on their cosine similarity.
54+
55+
You can either:
56+
1) Provide a Camel `BaseEmbedding` instance via `embedding_instance` to let
57+
this function handle the embedding internally, OR
58+
2) Directly pass a list of pre-computed embeddings to `embeddings`.
59+
60+
If both `embedding_instance` and `embeddings` are provided, the function
61+
will raise a ValueError to avoid ambiguous usage.
62+
63+
strategy is used to specify different strategies, where 'top1' selects the
64+
one with highest similarity, and 'llm-supervise' uses LLM to determine if
65+
texts are duplicates (not yet implemented).
66+
67+
Args:
68+
texts (List[str]): The list of texts to be deduplicated.
69+
threshold (float, optional): The similarity threshold for considering
70+
two texts as duplicates. Default is 0.65.
71+
embedding_instance (Optional[BaseEmbedding[str]], optional):
72+
A Camel embedding instance for automatic embedding. Defaults to
73+
None.
74+
embeddings (Optional[List[List[float]]], optional):
75+
Pre-computed embeddings of `texts`. Each element in the list
76+
corresponds to the embedding of the text in the same index of
77+
`texts`. Defaults to None.
78+
strategy (Literal["top1", "llm-supervise"], optional):
79+
The strategy to use for deduplication. Defaults to "top1".
80+
81+
Returns:
82+
DeduplicationResult: An object that contains:
83+
- `original_texts`: The original texts.
84+
- `unique_ids`: The unique ids after deduplication.
85+
- `unique_embeddings_dict`: A dict mapping from (unique) text id
86+
to its embedding.
87+
- `duplicate_to_target_map`: A dict mapping from the id of a
88+
duplicate text to the id of the text it is considered a duplicate
89+
of.
90+
91+
Raises:
92+
NotImplementedError: If the strategy is not "top1".
93+
ValueError: If neither embeddings nor embedding_instance is provided,
94+
or if both are provided at the same time.
95+
ValueError: If the length of `embeddings` does not match the length of
96+
`texts`.
97+
98+
Example:
99+
>>> from camel.embeddings.openai_embedding import OpenAIEmbedding
100+
>>> # Suppose we have 5 texts, some of which may be duplicates
101+
>>> texts = [
102+
... "What is AI?",
103+
... "Artificial Intelligence is about machines",
104+
... "What is AI?",
105+
... "Deep Learning is a subset of AI",
106+
... "What is artificial intelligence?"
107+
... ]
108+
>>> # or any other BaseEmbedding instance
109+
>>> embedding_model = OpenAIEmbedding()
110+
>>> result = deduplicate_internally(
111+
... texts=texts,
112+
... threshold=0.7,
113+
... embedding_instance=embedding_model
114+
... )
115+
>>> print("Unique ids:")
116+
>>> for uid in result.unique_ids:
117+
... print(texts[uid])
118+
Unique ids:
119+
What is AI?
120+
Artificial Intelligence is about machines
121+
Deep Learning is a subset of AI
122+
What is artificial intelligence?
123+
124+
>>> print("Duplicate map:")
125+
>>> print(result.duplicate_to_target_map)
126+
{2: 0}
127+
# This indicates the text at index 2 is considered
128+
# a duplicate of index 0.
129+
"""
130+
if strategy == "llm-supervise":
131+
# TODO: Implement LLM-supervise deduplication.
132+
raise NotImplementedError(
133+
"LLM-supervise deduplication is not yet implemented."
134+
)
135+
136+
# Check if the parameters are valid.
137+
if embedding_instance is None and embeddings is None:
138+
raise ValueError(
139+
"Either 'embedding_instance' or 'embeddings' must be provided."
140+
)
141+
if embedding_instance is not None and embeddings is not None:
142+
raise ValueError(
143+
"Cannot provide both 'embedding_instance' and 'embeddings'. "
144+
"Please choose only one way to supply embeddings."
145+
)
146+
147+
if embedding_instance is not None:
148+
# Use Camel's embedding_instance to vectorize.
149+
embeddings = embedding_instance.embed_list(texts)
150+
else:
151+
# Use pre-supplied embeddings.
152+
if embeddings and len(embeddings) != len(texts):
153+
raise ValueError(
154+
"The length of 'embeddings' does not match the length "
155+
"of 'texts'."
156+
)
157+
158+
# Calculate cosine similarity.
159+
similarity_matrix = cosine_similarity(embeddings)
160+
n = len(texts)
161+
162+
# Use the lower triangle to avoid redundant comparisons
163+
# (or self-comparisons).
164+
tril_mask = np.tril(np.ones((n, n)), k=-1)
165+
similarity_matrix = similarity_matrix * tril_mask
166+
167+
# For each row, find the column with the highest similarity
168+
# that exceeds the threshold. If no similarity exceeds the threshold,
169+
# set the column index to -1.
170+
masked_similarities = np.where(
171+
similarity_matrix > threshold, similarity_matrix, -1
172+
)
173+
max_indices = masked_similarities.argmax(axis=1)
174+
175+
duplicate_to_target_map: Dict[int, int] = {}
176+
above_threshold = similarity_matrix[np.arange(n), max_indices] > threshold
177+
178+
# Construct the "duplicate->target" mapping.
179+
for i in range(n):
180+
if above_threshold[i]:
181+
duplicate_to_target_map[i] = max_indices[i]
182+
183+
# Get the actual unique ids and embeddings.
184+
unique_ids = []
185+
unique_embeddings_dict = {}
186+
187+
assert embeddings, "embeddings must be valid"
188+
189+
for i, (_, emb) in enumerate(zip(texts, embeddings)):
190+
if i not in duplicate_to_target_map:
191+
unique_ids.append(i)
192+
unique_embeddings_dict[i] = emb
193+
194+
return DeduplicationResult(
195+
original_texts=texts,
196+
unique_ids=unique_ids,
197+
unique_embeddings_dict=unique_embeddings_dict,
198+
duplicate_to_target_map=duplicate_to_target_map,
199+
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ module = [
533533
"tree-sitter-python",
534534
"tree-sitter",
535535
"pandasai",
536+
"sklearn.metrics.pairwise",
536537
"sympy",
537538
]
538539
ignore_missing_imports = true

0 commit comments

Comments
 (0)