Commit 3606d96

refactor: Internal deduplication (#1579)
1 parent: cf3d9c5

File tree: 2 files changed (+101, -40 lines)


camel/utils/deduplication.py

Lines changed: 69 additions & 36 deletions

@@ -15,25 +15,22 @@
 
 from typing import Dict, List, Literal, Optional
 
-import numpy as np
 from pydantic import BaseModel
-from sklearn.metrics.pairwise import cosine_similarity
 
 from camel.embeddings.base import BaseEmbedding
 
 
 class DeduplicationResult(BaseModel):
-    """
-    The result of deduplication.
+    r"""The result of deduplication.
 
     Attributes:
         original_texts (List[str]): The original texts.
         unique_ids (List[int]): A list of ids that are unique (not duplicates).
-        unique_embeddings_dict (Dict[int, List[float]]):
-            A mapping from the index of each unique text to its embedding.
-        duplicate_to_target_map (Dict[int, int]):
-            A mapping from the index of the duplicate text to the index
-            of the text it is considered a duplicate of.
+        unique_embeddings_dict (Dict[int, List[float]]): A mapping from the
+            index of each unique text to its embedding.
+        duplicate_to_target_map (Dict[int, int]): A mapping from the index of
+            the duplicate text to the index of the text it is considered a
+            duplicate of.
     """
 
     original_texts: List[str]
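
For orientation, a minimal sketch of constructing and reading the DeduplicationResult model defined above; the field values are illustrative, not taken from the commit.

from camel.utils.deduplication import DeduplicationResult

# Illustrative values: texts 1 and 2 were judged duplicates of text 0.
result = DeduplicationResult(
    original_texts=["a", "a!", "A", "b"],
    unique_ids=[0, 3],
    unique_embeddings_dict={0: [0.1, 0.9], 3: [0.8, 0.2]},
    duplicate_to_target_map={1: 0, 2: 0},
)
unique_texts = [result.original_texts[i] for i in result.unique_ids]  # ["a", "b"]
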
@@ -48,12 +45,12 @@ def deduplicate_internally(
     embedding_instance: Optional[BaseEmbedding[str]] = None,
     embeddings: Optional[List[List[float]]] = None,
     strategy: Literal["top1", "llm-supervise"] = "top1",
+    batch_size: int = 1000,
 ) -> DeduplicationResult:
-    """
-    Deduplicate a list of strings based on their cosine similarity.
+    r"""Deduplicate a list of strings based on their cosine similarity.
 
     You can either:
-    1) Provide a Camel `BaseEmbedding` instance via `embedding_instance` to let
+    1) Provide a CAMEL `BaseEmbedding` instance via `embedding_instance` to let
        this function handle the embedding internally, OR
     2) Directly pass a list of pre-computed embeddings to `embeddings`.
 
@@ -67,16 +64,18 @@ def deduplicate_internally(
     Args:
         texts (List[str]): The list of texts to be deduplicated.
         threshold (float, optional): The similarity threshold for considering
-            two texts as duplicates. Default is 0.65.
+            two texts as duplicates. (default: :obj:`0.65`)
         embedding_instance (Optional[BaseEmbedding[str]], optional):
-            A Camel embedding instance for automatic embedding. Defaults to
-            None.
+            A CAMEL embedding instance for automatic embedding. (default:
+            :obj:`None`)
         embeddings (Optional[List[List[float]]], optional):
             Pre-computed embeddings of `texts`. Each element in the list
             corresponds to the embedding of the text in the same index of
-            `texts`. Defaults to None.
+            `texts`. (default: :obj:`None`)
         strategy (Literal["top1", "llm-supervise"], optional):
-            The strategy to use for deduplication. Defaults to "top1".
+            The strategy to use for deduplication. (default: :obj:`"top1"`)
+        batch_size (int, optional): The size of the batch to use for
+            calculating cosine similarities. (default: :obj:`1000`)
 
     Returns:
         DeduplicationResult: An object that contains:
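
The other calling convention lets a CAMEL embedding backend compute the vectors internally. A sketch, assuming OpenAIEmbedding is importable as shown and an OPENAI_API_KEY is set; any BaseEmbedding[str] implementation would do.

from camel.embeddings import OpenAIEmbedding  # assumed backend
from camel.utils.deduplication import deduplicate_internally

result = deduplicate_internally(
    texts=["apple pie recipe", "recipe for apple pie", "car insurance"],
    embedding_instance=OpenAIEmbedding(),
)
print(result.unique_ids)
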
@@ -127,13 +126,39 @@ def deduplicate_internally(
         # This indicates the text at index 2 is considered
         # a duplicate of index 0.
     """
+    import numpy as np
+    from sklearn.metrics.pairwise import cosine_similarity
+
+    if len(texts) == 0:
+        return DeduplicationResult(
+            original_texts=[],
+            unique_ids=[],
+            unique_embeddings_dict={},
+            duplicate_to_target_map={},
+        )
+
+    if len(texts) == 1:
+        return DeduplicationResult(
+            original_texts=texts,
+            unique_ids=[0],
+            unique_embeddings_dict={
+                0: embeddings[0]
+                if embeddings
+                else embedding_instance.embed_list(texts)[0]  # type: ignore[union-attr]
+            },
+            duplicate_to_target_map={},
+        )
+
     if strategy == "llm-supervise":
         # TODO: Implement LLM-supervise deduplication.
         raise NotImplementedError(
             "LLM-supervise deduplication is not yet implemented."
         )
 
     # Check if the parameters are valid.
+    if not 0 <= threshold <= 1:
+        raise ValueError("Threshold must be between 0 and 1")
+
     if embedding_instance is None and embeddings is None:
         raise ValueError(
             "Either 'embedding_instance' or 'embeddings' must be provided."
@@ -155,30 +180,38 @@ def deduplicate_internally(
             "of 'texts'."
         )
 
-    # Calculate cosine similarity.
-    similarity_matrix = cosine_similarity(embeddings)
+    # Convert embeddings to numpy array for efficient computation
+    embeddings_array = np.array(embeddings)
     n = len(texts)
+    duplicate_to_target_map: Dict[int, int] = {}
 
-    # Use the lower triangle to avoid redundant comparisons
-    # (or self-comparisons).
-    tril_mask = np.tril(np.ones((n, n)), k=-1)
-    similarity_matrix = similarity_matrix * tril_mask
+    # Process in batches to reduce memory usage
+    for i in range(0, n, batch_size):
+        batch_end = min(i + batch_size, n)
+        # Calculate cosine similarity for current batch
+        batch_similarities = cosine_similarity(
+            embeddings_array[i:batch_end], embeddings_array[:batch_end]
+        )
 
-    # For each row, find the column with the highest similarity
-    # that exceeds the threshold. If no similarity exceeds the threshold,
-    # set the column index to -1.
-    masked_similarities = np.where(
-        similarity_matrix > threshold, similarity_matrix, -1
-    )
-    max_indices = masked_similarities.argmax(axis=1)
+        # Create mask for lower triangle (avoid self-comparison and redundant
+        # checks)
+        tril_mask = np.tril(np.ones_like(batch_similarities), k=-1)
+        batch_similarities = batch_similarities * tril_mask
 
-    duplicate_to_target_map: Dict[int, int] = {}
-    above_threshold = similarity_matrix[np.arange(n), max_indices] > threshold
+        # Find duplicates in current batch
+        masked_similarities = np.where(
+            batch_similarities > threshold, batch_similarities, -1
+        )
+        max_indices = masked_similarities.argmax(axis=1)
+        above_threshold = (
+            batch_similarities[np.arange(batch_end - i), max_indices]
+            > threshold
+        )
 
-    # Construct the "duplicate->target" mapping.
-    for i in range(n):
-        if above_threshold[i]:
-            duplicate_to_target_map[i] = max_indices[i]
+        # Update duplicate map
+        for j, is_duplicate in enumerate(above_threshold):
+            if is_duplicate:
+                duplicate_to_target_map[i + j] = max_indices[j]
 
     # Get the actual unique ids and embeddings.
     unique_ids = []
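
Pulled out of the function for clarity, the batching strategy above amounts to the following standalone sketch (same NumPy and scikit-learn calls, simplified names). Note the triangular mask is applied to the batch-local matrix, exactly as in the diff.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def batched_duplicate_map(embeddings, threshold=0.65, batch_size=1000):
    # Mirrors the loop above: each batch of rows is compared only against
    # columns up to the batch end, instead of materializing the full
    # n x n similarity matrix at once.
    arr = np.array(embeddings)
    n = len(arr)
    mapping = {}
    for i in range(0, n, batch_size):
        batch_end = min(i + batch_size, n)
        sims = cosine_similarity(arr[i:batch_end], arr[:batch_end])
        # Keep the strict lower triangle of the batch-local matrix,
        # dropping self-comparisons and mirrored pairs.
        sims *= np.tril(np.ones_like(sims), k=-1)
        masked = np.where(sims > threshold, sims, -1)
        max_indices = masked.argmax(axis=1)
        above = sims[np.arange(batch_end - i), max_indices] > threshold
        for j, is_dup in enumerate(above):
            if is_dup:
                mapping[i + j] = int(max_indices[j])
    return mapping

# With batch_size >= len(embeddings) this reproduces the old single-pass
# behaviour in one iteration.
print(batched_duplicate_map([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]))  # {1: 0}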

test/utils/test_deduplication.py

Lines changed: 32 additions & 4 deletions

@@ -21,8 +21,7 @@
 
 
 class MockEmbedding(BaseEmbedding[str]):
-    """
-    A mock embedding class that always returns the same embedding vector
+    r"""A mock embedding class that always returns the same embedding vector
     for any input text. Useful for testing deduplication logic.
     """
 
@@ -36,6 +35,36 @@ def get_output_dim(self) -> int:
         return 3
 
 
+def test_deduplicate_internally_empty_list():
+    mock_embedding_instance = MockEmbedding()
+    result = deduplicate_internally(
+        texts=[],
+        threshold=0.9,
+        embedding_instance=mock_embedding_instance,
+        strategy="top1",
+    )
+    assert len(result.original_texts) == 0
+    assert len(result.unique_ids) == 0
+    assert len(result.unique_embeddings_dict) == 0
+    assert len(result.duplicate_to_target_map) == 0
+
+
+def test_deduplicate_internally_single_item():
+    mock_embedding_instance = MockEmbedding()
+    texts = ["Hello world!"]
+    result = deduplicate_internally(
+        texts=texts,
+        threshold=0.9,
+        embedding_instance=mock_embedding_instance,
+        strategy="top1",
+    )
+    assert result.original_texts == texts
+    assert result.unique_ids == [0]
+    assert len(result.unique_embeddings_dict) == 1
+    assert 0 in result.unique_embeddings_dict
+    assert len(result.duplicate_to_target_map) == 0
+
+
 def test_deduplicate_internally_with_mock_embedding():
     texts = ["Hello world!", "Hello world!", "HELLO WORLD!", "Something else"]
     mock_embedding_instance = MockEmbedding()
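
The tests above rely on MockEmbedding. For readers without the full file, a sketch of what it might look like, assuming BaseEmbedding[str] only requires embed_list and get_output_dim (as the visible lines suggest):

from typing import Any, List

from camel.embeddings.base import BaseEmbedding


class MockEmbedding(BaseEmbedding[str]):
    r"""Every text maps to the same 3-D vector, so any two texts look
    like exact duplicates under cosine similarity."""

    def embed_list(self, objs: List[str], **kwargs: Any) -> List[List[float]]:
        return [[1.0, 0.0, 0.0] for _ in objs]

    def get_output_dim(self) -> int:
        return 3
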
@@ -116,8 +145,7 @@ def test_deduplicate_internally_with_precomputed_embeddings():
 
 
 def test_deduplicate_internally_chain_scenario():
-    """
-    Test scenario:
+    r"""Test scenario:
     - A <-> B similarity > threshold
     - B <-> C similarity > threshold
     - C <-> D similarity > threshold
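
Because the top1 strategy maps each later text to its most similar earlier text, a chain like A~B~C~D can come back as a chained map, e.g. {1: 0, 2: 1, 3: 2}. A hypothetical helper, not part of the commit, for collapsing such chains to their ultimate target:

def resolve(idx: int, mapping: dict) -> int:
    # Follow duplicate -> target links until an index with no entry remains.
    while idx in mapping:
        idx = mapping[idx]
    return idx

assert resolve(3, {1: 0, 2: 1, 3: 2}) == 0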
