15
15
16
16
from typing import Dict , List , Literal , Optional
17
17
18
- import numpy as np
19
18
from pydantic import BaseModel
20
- from sklearn .metrics .pairwise import cosine_similarity
21
19
22
20
from camel .embeddings .base import BaseEmbedding
23
21
24
22
25
23
class DeduplicationResult (BaseModel ):
26
- """
27
- The result of deduplication.
24
+ r"""The result of deduplication.
28
25
29
26
Attributes:
30
27
original_texts (List[str]): The original texts.
31
28
unique_ids (List[int]): A list of ids that are unique (not duplicates).
32
- unique_embeddings_dict (Dict[int, List[float]]):
33
- A mapping from the index of each unique text to its embedding.
34
- duplicate_to_target_map (Dict[int, int]):
35
- A mapping from the index of the duplicate text to the index
36
- of the text it is considered a duplicate of.
29
+ unique_embeddings_dict (Dict[int, List[float]]): A mapping from the
30
+ index of each unique text to its embedding.
31
+ duplicate_to_target_map (Dict[int, int]): A mapping from the index of
32
+ the duplicate text to the index of the text it is considered a
33
+ duplicate of.
37
34
"""
38
35
39
36
original_texts : List [str ]
@@ -48,12 +45,12 @@ def deduplicate_internally(
48
45
embedding_instance : Optional [BaseEmbedding [str ]] = None ,
49
46
embeddings : Optional [List [List [float ]]] = None ,
50
47
strategy : Literal ["top1" , "llm-supervise" ] = "top1" ,
48
+ batch_size : int = 1000 ,
51
49
) -> DeduplicationResult :
52
- """
53
- Deduplicate a list of strings based on their cosine similarity.
50
+ r"""Deduplicate a list of strings based on their cosine similarity.
54
51
55
52
You can either:
56
- 1) Provide a Camel `BaseEmbedding` instance via `embedding_instance` to let
53
+ 1) Provide a CAMEL `BaseEmbedding` instance via `embedding_instance` to let
57
54
this function handle the embedding internally, OR
58
55
2) Directly pass a list of pre-computed embeddings to `embeddings`.
59
56
@@ -67,16 +64,18 @@ def deduplicate_internally(
67
64
Args:
68
65
texts (List[str]): The list of texts to be deduplicated.
69
66
threshold (float, optional): The similarity threshold for considering
70
- two texts as duplicates. Default is 0.65.
67
+ two texts as duplicates. (default: :obj:`0.65`)
71
68
embedding_instance (Optional[BaseEmbedding[str]], optional):
72
- A Camel embedding instance for automatic embedding. Defaults to
73
- None.
69
+ A CAMEL embedding instance for automatic embedding. (default:
70
+ :obj:`None`)
74
71
embeddings (Optional[List[List[float]]], optional):
75
72
Pre-computed embeddings of `texts`. Each element in the list
76
73
corresponds to the embedding of the text in the same index of
77
- `texts`. Defaults to None.
74
+ `texts`. (default: :obj:`None`)
78
75
strategy (Literal["top1", "llm-supervise"], optional):
79
- The strategy to use for deduplication. Defaults to "top1".
76
+ The strategy to use for deduplication. (default: :obj:`"top1"`)
77
+ batch_size (int, optional): The size of the batch to use for
78
+ calculating cosine similarities. (default: :obj:`1000`)
80
79
81
80
Returns:
82
81
DeduplicationResult: An object that contains:
@@ -127,13 +126,39 @@ def deduplicate_internally(
127
126
# This indicates the text at index 2 is considered
128
127
# a duplicate of index 0.
129
128
"""
129
+ import numpy as np
130
+ from sklearn .metrics .pairwise import cosine_similarity
131
+
132
+ if len (texts ) == 0 :
133
+ return DeduplicationResult (
134
+ original_texts = [],
135
+ unique_ids = [],
136
+ unique_embeddings_dict = {},
137
+ duplicate_to_target_map = {},
138
+ )
139
+
140
+ if len (texts ) == 1 :
141
+ return DeduplicationResult (
142
+ original_texts = texts ,
143
+ unique_ids = [0 ],
144
+ unique_embeddings_dict = {
145
+ 0 : embeddings [0 ]
146
+ if embeddings
147
+ else embedding_instance .embed_list (texts )[0 ] # type: ignore[union-attr]
148
+ },
149
+ duplicate_to_target_map = {},
150
+ )
151
+
130
152
if strategy == "llm-supervise" :
131
153
# TODO: Implement LLM-supervise deduplication.
132
154
raise NotImplementedError (
133
155
"LLM-supervise deduplication is not yet implemented."
134
156
)
135
157
136
158
# Check if the parameters are valid.
159
+ if not 0 <= threshold <= 1 :
160
+ raise ValueError ("Threshold must be between 0 and 1" )
161
+
137
162
if embedding_instance is None and embeddings is None :
138
163
raise ValueError (
139
164
"Either 'embedding_instance' or 'embeddings' must be provided."
@@ -155,30 +180,38 @@ def deduplicate_internally(
155
180
"of 'texts'."
156
181
)
157
182
158
- # Calculate cosine similarity.
159
- similarity_matrix = cosine_similarity (embeddings )
183
+ # Convert embeddings to numpy array for efficient computation
184
+ embeddings_array = np . array (embeddings )
160
185
n = len (texts )
186
+ duplicate_to_target_map : Dict [int , int ] = {}
161
187
162
- # Use the lower triangle to avoid redundant comparisons
163
- # (or self-comparisons).
164
- tril_mask = np .tril (np .ones ((n , n )), k = - 1 )
165
- similarity_matrix = similarity_matrix * tril_mask
188
+ # Process in batches to reduce memory usage
189
+ for i in range (0 , n , batch_size ):
190
+ batch_end = min (i + batch_size , n )
191
+ # Calculate cosine similarity for current batch
192
+ batch_similarities = cosine_similarity (
193
+ embeddings_array [i :batch_end ], embeddings_array [:batch_end ]
194
+ )
166
195
167
- # For each row, find the column with the highest similarity
168
- # that exceeds the threshold. If no similarity exceeds the threshold,
169
- # set the column index to -1.
170
- masked_similarities = np .where (
171
- similarity_matrix > threshold , similarity_matrix , - 1
172
- )
173
- max_indices = masked_similarities .argmax (axis = 1 )
196
+ # Create mask for lower triangle (avoid self-comparison and redundant
197
+ # checks)
198
+ tril_mask = np .tril (np .ones_like (batch_similarities ), k = - 1 )
199
+ batch_similarities = batch_similarities * tril_mask
174
200
175
- duplicate_to_target_map : Dict [int , int ] = {}
176
- above_threshold = similarity_matrix [np .arange (n ), max_indices ] > threshold
201
+ # Find duplicates in current batch
202
+ masked_similarities = np .where (
203
+ batch_similarities > threshold , batch_similarities , - 1
204
+ )
205
+ max_indices = masked_similarities .argmax (axis = 1 )
206
+ above_threshold = (
207
+ batch_similarities [np .arange (batch_end - i ), max_indices ]
208
+ > threshold
209
+ )
177
210
178
- # Construct the " duplicate->target" mapping.
179
- for i in range ( n ):
180
- if above_threshold [ i ] :
181
- duplicate_to_target_map [i ] = max_indices [i ]
211
+ # Update duplicate map
212
+ for j , is_duplicate in enumerate ( above_threshold ):
213
+ if is_duplicate :
214
+ duplicate_to_target_map [i + j ] = max_indices [j ]
182
215
183
216
# Get the actual unique ids and embeddings.
184
217
unique_ids = []
0 commit comments