refactor: Internal deduplication (#1579)
Wendong-Fan authored Feb 10, 2025
1 parent cf3d9c5 commit 3606d96
Showing 2 changed files with 101 additions and 40 deletions.
105 changes: 69 additions & 36 deletions camel/utils/deduplication.py
@@ -15,25 +15,22 @@
 
 from typing import Dict, List, Literal, Optional
 
-import numpy as np
 from pydantic import BaseModel
-from sklearn.metrics.pairwise import cosine_similarity
 
 from camel.embeddings.base import BaseEmbedding
 
 
 class DeduplicationResult(BaseModel):
-    """
-    The result of deduplication.
+    r"""The result of deduplication.
 
     Attributes:
         original_texts (List[str]): The original texts.
         unique_ids (List[int]): A list of ids that are unique (not duplicates).
-        unique_embeddings_dict (Dict[int, List[float]]):
-            A mapping from the index of each unique text to its embedding.
-        duplicate_to_target_map (Dict[int, int]):
-            A mapping from the index of the duplicate text to the index
-            of the text it is considered a duplicate of.
+        unique_embeddings_dict (Dict[int, List[float]]): A mapping from the
+            index of each unique text to its embedding.
+        duplicate_to_target_map (Dict[int, int]): A mapping from the index of
+            the duplicate text to the index of the text it is considered a
+            duplicate of.
     """
 
     original_texts: List[str]
@@ -48,12 +45,12 @@ def deduplicate_internally(
     embedding_instance: Optional[BaseEmbedding[str]] = None,
     embeddings: Optional[List[List[float]]] = None,
     strategy: Literal["top1", "llm-supervise"] = "top1",
+    batch_size: int = 1000,
 ) -> DeduplicationResult:
-    """
-    Deduplicate a list of strings based on their cosine similarity.
+    r"""Deduplicate a list of strings based on their cosine similarity.
 
     You can either:
-    1) Provide a Camel `BaseEmbedding` instance via `embedding_instance` to let
+    1) Provide a CAMEL `BaseEmbedding` instance via `embedding_instance` to let
     this function handle the embedding internally, OR
     2) Directly pass a list of pre-computed embeddings to `embeddings`.
@@ -67,16 +64,18 @@ def deduplicate_internally(
     Args:
         texts (List[str]): The list of texts to be deduplicated.
         threshold (float, optional): The similarity threshold for considering
-            two texts as duplicates. Default is 0.65.
+            two texts as duplicates. (default: :obj:`0.65`)
         embedding_instance (Optional[BaseEmbedding[str]], optional):
-            A Camel embedding instance for automatic embedding. Defaults to
-            None.
+            A CAMEL embedding instance for automatic embedding. (default:
+            :obj:`None`)
         embeddings (Optional[List[List[float]]], optional):
             Pre-computed embeddings of `texts`. Each element in the list
             corresponds to the embedding of the text in the same index of
-            `texts`. Defaults to None.
+            `texts`. (default: :obj:`None`)
         strategy (Literal["top1", "llm-supervise"], optional):
-            The strategy to use for deduplication. Defaults to "top1".
+            The strategy to use for deduplication. (default: :obj:`"top1"`)
+        batch_size (int, optional): The size of the batch to use for
+            calculating cosine similarities. (default: :obj:`1000`)
 
     Returns:
         DeduplicationResult: An object that contains:
@@ -127,13 +126,39 @@ def deduplicate_internally(
         # This indicates the text at index 2 is considered
         # a duplicate of index 0.
     """
+    import numpy as np
+    from sklearn.metrics.pairwise import cosine_similarity
+
+    if len(texts) == 0:
+        return DeduplicationResult(
+            original_texts=[],
+            unique_ids=[],
+            unique_embeddings_dict={},
+            duplicate_to_target_map={},
+        )
+
+    if len(texts) == 1:
+        return DeduplicationResult(
+            original_texts=texts,
+            unique_ids=[0],
+            unique_embeddings_dict={
+                0: embeddings[0]
+                if embeddings
+                else embedding_instance.embed_list(texts)[0]  # type: ignore[union-attr]
+            },
+            duplicate_to_target_map={},
+        )
+
     if strategy == "llm-supervise":
         # TODO: Implement LLM-supervise deduplication.
         raise NotImplementedError(
             "LLM-supervise deduplication is not yet implemented."
         )
 
     # Check if the parameters are valid.
     if not 0 <= threshold <= 1:
         raise ValueError("Threshold must be between 0 and 1")
 
     if embedding_instance is None and embeddings is None:
         raise ValueError(
             "Either 'embedding_instance' or 'embeddings' must be provided."
@@ -155,30 +180,38 @@
             "of 'texts'."
         )
 
-    # Calculate cosine similarity.
-    similarity_matrix = cosine_similarity(embeddings)
+    # Convert embeddings to numpy array for efficient computation
+    embeddings_array = np.array(embeddings)
     n = len(texts)
+    duplicate_to_target_map: Dict[int, int] = {}
 
-    # Use the lower triangle to avoid redundant comparisons
-    # (or self-comparisons).
-    tril_mask = np.tril(np.ones((n, n)), k=-1)
-    similarity_matrix = similarity_matrix * tril_mask
+    # Process in batches to reduce memory usage
+    for i in range(0, n, batch_size):
+        batch_end = min(i + batch_size, n)
+        # Calculate cosine similarity for current batch
+        batch_similarities = cosine_similarity(
+            embeddings_array[i:batch_end], embeddings_array[:batch_end]
+        )
 
-    # For each row, find the column with the highest similarity
-    # that exceeds the threshold. If no similarity exceeds the threshold,
-    # set the column index to -1.
-    masked_similarities = np.where(
-        similarity_matrix > threshold, similarity_matrix, -1
-    )
-    max_indices = masked_similarities.argmax(axis=1)
+        # Create mask for lower triangle (avoid self-comparison and redundant
+        # checks)
+        tril_mask = np.tril(np.ones_like(batch_similarities), k=-1)
+        batch_similarities = batch_similarities * tril_mask
 
-    duplicate_to_target_map: Dict[int, int] = {}
-    above_threshold = similarity_matrix[np.arange(n), max_indices] > threshold
+        # Find duplicates in current batch
+        masked_similarities = np.where(
+            batch_similarities > threshold, batch_similarities, -1
+        )
+        max_indices = masked_similarities.argmax(axis=1)
+        above_threshold = (
+            batch_similarities[np.arange(batch_end - i), max_indices]
+            > threshold
+        )
 
-    # Construct the "duplicate->target" mapping.
-    for i in range(n):
-        if above_threshold[i]:
-            duplicate_to_target_map[i] = max_indices[i]
+        # Update duplicate map
+        for j, is_duplicate in enumerate(above_threshold):
+            if is_duplicate:
+                duplicate_to_target_map[i + j] = max_indices[j]
 
     # Get the actual unique ids and embeddings.
     unique_ids = []
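
For context, a minimal usage sketch of the refactored function (not part of this commit): it passes pre-computed embeddings directly, so no embedding instance is needed. The vectors and threshold are made up for illustration; the import path follows the file header above.

from camel.utils.deduplication import deduplicate_internally

texts = ["Hello world!", "hello world!", "Something else"]
embeddings = [
    [1.00, 0.00, 0.00],  # "Hello world!"
    [0.99, 0.01, 0.00],  # near-duplicate of index 0
    [0.00, 1.00, 0.00],  # distinct text
]
result = deduplicate_internally(
    texts=texts,
    threshold=0.65,
    embeddings=embeddings,
    strategy="top1",
    batch_size=1000,
)
print(result.unique_ids)               # expected: [0, 2]
print(result.duplicate_to_target_map)  # expected: {1: 0}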
36 changes: 32 additions & 4 deletions test/utils/test_deduplication.py
@@ -21,8 +21,7 @@
 
 
 class MockEmbedding(BaseEmbedding[str]):
-    """
-    A mock embedding class that always returns the same embedding vector
+    r"""A mock embedding class that always returns the same embedding vector
     for any input text. Useful for testing deduplication logic.
     """
 
@@ -36,6 +35,36 @@ def get_output_dim(self) -> int:
         return 3
 
 
+def test_deduplicate_internally_empty_list():
+    mock_embedding_instance = MockEmbedding()
+    result = deduplicate_internally(
+        texts=[],
+        threshold=0.9,
+        embedding_instance=mock_embedding_instance,
+        strategy="top1",
+    )
+    assert len(result.original_texts) == 0
+    assert len(result.unique_ids) == 0
+    assert len(result.unique_embeddings_dict) == 0
+    assert len(result.duplicate_to_target_map) == 0
+
+
+def test_deduplicate_internally_single_item():
+    mock_embedding_instance = MockEmbedding()
+    texts = ["Hello world!"]
+    result = deduplicate_internally(
+        texts=texts,
+        threshold=0.9,
+        embedding_instance=mock_embedding_instance,
+        strategy="top1",
+    )
+    assert result.original_texts == texts
+    assert result.unique_ids == [0]
+    assert len(result.unique_embeddings_dict) == 1
+    assert 0 in result.unique_embeddings_dict
+    assert len(result.duplicate_to_target_map) == 0
+
+
 def test_deduplicate_internally_with_mock_embedding():
     texts = ["Hello world!", "Hello world!", "HELLO WORLD!", "Something else"]
     mock_embedding_instance = MockEmbedding()
@@ -116,8 +145,7 @@ def test_deduplicate_internally_with_precomputed_embeddings():
 
 
 def test_deduplicate_internally_chain_scenario():
-    """
-    Test scenario:
+    r"""Test scenario:
     - A <-> B similarity > threshold
     - B <-> C similarity > threshold
     - C <-> D similarity > threshold
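
The chain-scenario test is truncated in this view. As a self-contained sketch of the behaviour it appears to exercise, the snippet below replays the commit's single-batch logic on four toy vectors forming a chain; the vectors and threshold are illustrative, not taken from the test.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Chain A-B-C-D: each vector is close to its neighbour, but A and D are not.
embeddings = np.array([
    [1.00, 0.00],  # A
    [0.98, 0.20],  # B
    [0.90, 0.43],  # C
    [0.80, 0.60],  # D
])
threshold = 0.95

sims = cosine_similarity(embeddings, embeddings)
sims *= np.tril(np.ones_like(sims), k=-1)  # keep only comparisons with j < i
masked = np.where(sims > threshold, sims, -1)
max_indices = masked.argmax(axis=1)
above = sims[np.arange(len(embeddings)), max_indices] > threshold

duplicate_to_target_map = {
    i: int(max_indices[i]) for i in range(len(embeddings)) if above[i]
}
print(duplicate_to_target_map)  # {1: 0, 2: 1, 3: 2} -- each duplicate maps to
                                # its nearest earlier neighbour, not to A

This illustrates why the "top1" strategy is not transitive: C maps to B and D maps to C even though neither is close enough to A itself.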
