|
| 1 | +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# |
| 6 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Unless required by applicable law or agreed to in writing, software |
| 9 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +# See the License for the specific language governing permissions and |
| 12 | +# limitations under the License. |
| 13 | +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= |
| 14 | + |
| 15 | + |
| 16 | +from typing import Dict, List, Literal, Optional |
| 17 | + |
| 18 | +import numpy as np |
| 19 | +from pydantic import BaseModel |
| 20 | +from sklearn.metrics.pairwise import cosine_similarity |
| 21 | + |
| 22 | +from camel.embeddings.base import BaseEmbedding |
| 23 | + |
| 24 | + |
class DeduplicationResult(BaseModel):
    """
    Outcome of a deduplication pass over a list of texts.

    Every id stored in this model is an index into `original_texts`.

    Attributes:
        original_texts (List[str]): The texts that were deduplicated.
        unique_ids (List[int]): Ids of the texts that were kept, i.e. the
            ones that are not considered duplicates.
        unique_embeddings_dict (Dict[int, List[float]]):
            Embedding of each kept text, keyed by the id of that text.
        duplicate_to_target_map (Dict[int, int]):
            For each text considered a duplicate, the id of the text it
            duplicates, keyed by the id of the duplicate.
    """

    # Texts that were deduplicated.
    original_texts: List[str]
    # Ids of the texts that are not duplicates.
    unique_ids: List[int]
    # Id of a unique text -> embedding of that text.
    unique_embeddings_dict: Dict[int, List[float]]
    # Id of a duplicate text -> id of the text it duplicates.
    duplicate_to_target_map: Dict[int, int]
| 43 | + |
| 44 | + |
def deduplicate_internally(
    texts: List[str],
    threshold: float = 0.65,
    embedding_instance: Optional[BaseEmbedding[str]] = None,
    embeddings: Optional[List[List[float]]] = None,
    strategy: Literal["top1", "llm-supervise"] = "top1",
) -> DeduplicationResult:
    """
    Deduplicate a list of strings based on their cosine similarity.

    You can either:
    1) Provide a Camel `BaseEmbedding` instance via `embedding_instance` to
       let this function handle the embedding internally, OR
    2) Directly pass a list of pre-computed embeddings to `embeddings`.

    If both `embedding_instance` and `embeddings` are provided, the function
    raises a ValueError to avoid ambiguous usage.

    `strategy` selects how duplicates are decided: 'top1' picks, for each
    text, the most similar earlier text whose similarity exceeds
    `threshold`; 'llm-supervise' would let an LLM decide whether texts are
    duplicates (not yet implemented).

    With 'top1', a text is only ever compared against texts that appear
    before it in `texts`, so the first text is always unique. The target of
    a duplicate is chosen purely by similarity and may itself be listed as
    a duplicate of an even earlier text.

    Args:
        texts (List[str]): The list of texts to be deduplicated.
        threshold (float, optional): The similarity threshold for considering
            two texts as duplicates. A pair counts as duplicates only when
            its similarity is strictly greater than this value.
            Default is 0.65.
        embedding_instance (Optional[BaseEmbedding[str]], optional):
            A Camel embedding instance for automatic embedding. Defaults to
            None.
        embeddings (Optional[List[List[float]]], optional):
            Pre-computed embeddings of `texts`. Each element in the list
            corresponds to the embedding of the text in the same index of
            `texts`. Defaults to None.
        strategy (Literal["top1", "llm-supervise"], optional):
            The strategy to use for deduplication. Defaults to "top1".

    Returns:
        DeduplicationResult: An object that contains:
            - `original_texts`: The original texts.
            - `unique_ids`: The unique ids after deduplication.
            - `unique_embeddings_dict`: A dict mapping from (unique) text id
              to its embedding.
            - `duplicate_to_target_map`: A dict mapping from the id of a
              duplicate text to the id of the text it is considered a
              duplicate of.
        For empty `texts` every field of the result is empty.

    Raises:
        NotImplementedError: If the strategy is "llm-supervise".
        ValueError: If neither embeddings nor embedding_instance is provided,
            or if both are provided at the same time.
        ValueError: If the number of embeddings does not match the length of
            `texts`.

    Example:
        >>> from camel.embeddings.openai_embedding import OpenAIEmbedding
        >>> # Suppose we have 5 texts, some of which may be duplicates
        >>> texts = [
        ...     "What is AI?",
        ...     "Artificial Intelligence is about machines",
        ...     "What is AI?",
        ...     "Deep Learning is a subset of AI",
        ...     "What is artificial intelligence?"
        ... ]
        >>> # or any other BaseEmbedding instance
        >>> embedding_model = OpenAIEmbedding()
        >>> result = deduplicate_internally(
        ...     texts=texts,
        ...     threshold=0.7,
        ...     embedding_instance=embedding_model
        ... )
        >>> print("Unique ids:")
        >>> for uid in result.unique_ids:
        ...     print(texts[uid])
        Unique ids:
        What is AI?
        Artificial Intelligence is about machines
        Deep Learning is a subset of AI
        What is artificial intelligence?

        >>> print("Duplicate map:")
        >>> print(result.duplicate_to_target_map)
        {2: 0}
        # This indicates the text at index 2 is considered
        # a duplicate of index 0.
    """
    if strategy == "llm-supervise":
        # TODO: Implement LLM-supervise deduplication.
        raise NotImplementedError(
            "LLM-supervise deduplication is not yet implemented."
        )

    if embedding_instance is not None and embeddings is not None:
        raise ValueError(
            "Cannot provide both 'embedding_instance' and 'embeddings'. "
            "Please choose only one way to supply embeddings."
        )

    vectors: List[List[float]]
    if embedding_instance is not None:
        # Use Camel's embedding_instance to vectorize. There is nothing to
        # embed for empty input, so the backend call is skipped.
        vectors = embedding_instance.embed_list(texts) if texts else []
    elif embeddings is not None:
        # Use pre-supplied embeddings.
        vectors = embeddings
    else:
        raise ValueError(
            "Either 'embedding_instance' or 'embeddings' must be provided."
        )

    n = len(texts)
    # Checked unconditionally so that an empty `embeddings` list paired with
    # non-empty `texts` is rejected here instead of failing later.
    if len(vectors) != n:
        raise ValueError(
            "The length of 'embeddings' does not match the length "
            "of 'texts'."
        )

    if n == 0:
        # Nothing to compare; a similarity matrix cannot be built from
        # zero samples.
        return DeduplicationResult(
            original_texts=texts,
            unique_ids=[],
            unique_embeddings_dict={},
            duplicate_to_target_map={},
        )

    # Calculate cosine similarity.
    similarity_matrix = np.asarray(cosine_similarity(vectors))

    # Only the strict lower triangle is eligible: text i is compared with
    # earlier texts j < i, which rules out self-comparisons and redundant
    # symmetric pairs. A boolean mask (rather than zeroing the other cells)
    # keeps excluded cells from passing the test when `threshold` is
    # negative.
    earlier_only = np.tril(np.ones((n, n), dtype=bool), k=-1)
    is_candidate = earlier_only & (similarity_matrix > threshold)

    # For each row, the eligible column with the highest similarity. Rows
    # without any eligible column are filtered out through `has_duplicate`.
    candidate_scores = np.where(is_candidate, similarity_matrix, -np.inf)
    best_targets = candidate_scores.argmax(axis=1)
    has_duplicate = is_candidate.any(axis=1)

    # Construct the "duplicate->target" mapping with plain Python ints.
    duplicate_to_target_map: Dict[int, int] = {
        i: int(best_targets[i]) for i in range(n) if has_duplicate[i]
    }

    # Get the actual unique ids and embeddings.
    unique_ids = [i for i in range(n) if i not in duplicate_to_target_map]
    unique_embeddings_dict = {i: vectors[i] for i in unique_ids}

    return DeduplicationResult(
        original_texts=texts,
        unique_ids=unique_ids,
        unique_embeddings_dict=unique_embeddings_dict,
        duplicate_to_target_map=duplicate_to_target_map,
    )
0 commit comments