Skip to content

Commit

Permalink
Update retriever with sparse tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelsty committed Aug 6, 2023
1 parent b228fdd commit 344dd4c
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 116 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,5 @@ venv.bak/
dmypy.json

# Pyre type checker
.pyre/
.pyre/
test.ipynb
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ from sparsembed import model, utils, losses
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

device = "cuda" # cpu / cuda / mps
device = "cuda" # cpu / cuda
batch_size = 32

model = model.SparsEmbed(
Expand Down
24 changes: 8 additions & 16 deletions sparsembed/model/sparsembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class SparsEmbed(torch.nn.Module):
'is great sports good was big has are and wonderful sport huge nice of games '
'a']
>>> queries_embeddings["activations"].shape
torch.Size([2, 96])
References
----------
1. [SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://dl.acm.org/doi/pdf/10.1145/3539618.3592065)
Expand Down Expand Up @@ -174,9 +177,7 @@ def forward(
return {
"embeddings": self.relu(self.linear(embeddings)),
"sparse_activations": activations["sparse_activations"],
"activations": self._filter_activations(
activations["sparse_activations"], k=k
),
"activations": activations["activations"],
}

def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -227,18 +228,9 @@ def _filter_activations(
) -> list[torch.Tensor]:
"""Among the set of activations, select the ones with a score > 0."""
scores, activations = torch.topk(input=sparse_activations, k=k, dim=-1)

filter_activations = []

for score, activation in zip(scores, activations):
new_activation = torch.index_select(
return [
torch.index_select(
activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0]
)

if new_activation.shape[0] == 0:
filter_activations.append(activation)

else:
filter_activations.append(new_activation)

return filter_activations
for score, activation in zip(scores, activations)
]
2 changes: 2 additions & 0 deletions sparsembed/retrieve/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .retriever import Retriever

__all__ = ["Retriever"]
164 changes: 66 additions & 98 deletions sparsembed/retrieve/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@


class Retriever:
"""Class dedicated to SparsEmbed model inference in order retrieve
documents with queries.
"""Retriever class.
Parameters
----------
Expand All @@ -32,7 +31,7 @@ class Retriever:
>>> _ = torch.manual_seed(42)
>>> device = "mps"
>>> device = "cpu"
>>> model = model.SparsEmbed(
... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
Expand Down Expand Up @@ -63,23 +62,27 @@ class Retriever:
... batch_size=24
... )
>>> print(retriever(["Food", "Sports", "Cinema", "Music"], k_token=96))
[[{'id': 0, 'similarity': 1.4686672687530518},
{'id': 1, 'similarity': 1.3459084033966064},
{'id': 3, 'similarity': 1.3040170669555664},
{'id': 2, 'similarity': 1.157921314239502}],
[{'id': 1, 'similarity': 7.03730583190918},
{'id': 3, 'similarity': 3.5283799171447754},
{'id': 2, 'similarity': 2.453505516052246},
{'id': 0, 'similarity': 1.789308786392212}],
[{'id': 2, 'similarity': 2.316730260848999},
{'id': 3, 'similarity': 2.2312138080596924},
{'id': 1, 'similarity': 2.0195863246917725},
{'id': 0, 'similarity': 1.289013147354126}],
[{'id': 3, 'similarity': 5.773364067077637},
{'id': 1, 'similarity': 3.6177942752838135},
{'id': 2, 'similarity': 3.3001816272735596},
{'id': 0, 'similarity': 2.591763496398926}]]
>>> print(retriever(["Food", "Sports", "Cinema", "Music", "Hello World"], k_token=96))
[[{'id': 0, 'similarity': 1.4686675071716309},
{'id': 1, 'similarity': 1.345913052558899},
{'id': 3, 'similarity': 1.304019808769226},
{'id': 2, 'similarity': 1.1579231023788452}],
[{'id': 1, 'similarity': 7.0373148918151855},
{'id': 3, 'similarity': 3.528376817703247},
{'id': 2, 'similarity': 2.4535036087036133},
{'id': 0, 'similarity': 1.7893059253692627}],
[{'id': 2, 'similarity': 2.3167333602905273},
{'id': 3, 'similarity': 2.2312183380126953},
{'id': 1, 'similarity': 2.0195937156677246},
{'id': 0, 'similarity': 1.2890148162841797}],
[{'id': 3, 'similarity': 2.4722704887390137},
{'id': 2, 'similarity': 1.8648046255111694},
{'id': 1, 'similarity': 1.732576608657837},
{'id': 0, 'similarity': 1.3416467905044556}],
[{'id': 3, 'similarity': 3.7778899669647217},
{'id': 2, 'similarity': 3.198120355606079},
{'id': 1, 'similarity': 3.1253902912139893},
{'id': 0, 'similarity': 2.458303451538086}]]
"""

Expand Down Expand Up @@ -139,9 +142,9 @@ def add(
self.documents_activations.extend(documents_activations)

self.sparse_matrix = (
sparse_matrix
sparse_matrix.T
if self.sparse_matrix is None
else sparse.vstack((self.sparse_matrix, sparse_matrix))
else torch.cat([self.sparse_matrix.to_sparse(), sparse_matrix.T], dim=1)
)

self.documents_keys = {
Expand All @@ -156,7 +159,7 @@ def add(

def __call__(
self,
q: list[str] | str,
q: list[str],
k_sparse: int = 100,
k_token: int = 96,
batch_size: int = 3,
Expand All @@ -182,34 +185,39 @@ def __call__(
batch_size=batch_size,
)

# TODO: return torch tensor
matchs, _ = self.top_k_by_partition(
similarities=(sparse_matrix @ self.sparse_matrix.T).toarray(),
k_sparse=k_sparse,
sparse_scores = (sparse_matrix @ self.sparse_matrix).to_dense()

_, sparse_matchs = torch.topk(
input=sparse_scores, k=min(k_sparse, len(self.documents_keys)), dim=-1
)

sparse_matchs_idx = sparse_matchs.tolist()

# Intersections between queries and documents activated tokens.
intersections = self._get_intersection(
queries_activations=queries_activations,
documents_activations=[
[self.documents_activations[document] for document in query_matchs]
for query_matchs in matchs
for query_matchs in sparse_matchs_idx
],
)

# Optimize to handle batches
scores = self._get_scores(
dense_scores = self._get_scores(
queries_embeddings=queries_embeddings,
documents_embeddings=[
[self.documents_embeddings[document] for document in match]
for match in matchs
for match in sparse_matchs_idx
],
intersections=intersections,
)

return self._rank(scores=scores, matchs=matchs)
return self._rank(
dense_scores=dense_scores, sparse_matchs=sparse_matchs, k_sparse=k_sparse
)

def _rank(self, scores: torch.Tensor, matchs: torch.Tensor) -> list:
def _rank(
self, dense_scores: torch.Tensor, sparse_matchs: torch.Tensor, k_sparse: int
) -> list:
"""Rank documents by scores.
Parameters
Expand All @@ -219,101 +227,61 @@ def _rank(self, scores: torch.Tensor, matchs: torch.Tensor) -> list:
matchs
    Documents matches.
"""
ranks = torch.argsort(scores, dim=-1, descending=True, stable=False)

scores = [
torch.index_select(input=query_documents_score, dim=0, index=query_ranks)
for query_documents_score, query_ranks in zip(scores, ranks)
]
dense_scores, dense_matchs = torch.topk(
input=dense_scores, k=min(k_sparse, len(self.documents_keys)), dim=-1
)

matchs = [
torch.index_select(
input=torch.tensor(query_matchs).to(self.model.device),
dim=0,
index=query_ranks,
)
for query_matchs, query_ranks in zip(matchs, ranks)
]
dense_scores = dense_scores.tolist()
dense_matchs = (
torch.gather(sparse_matchs, 1, dense_matchs).detach().cpu().tolist()
)

return [
[
{
self.key: self.documents_keys[document.item()],
"similarity": score.item(),
self.key: self.documents_keys[document],
"similarity": score,
}
for score, document in zip(query_scores, query_matchs)
]
for query_scores, query_matchs in zip(scores, matchs)
for query_scores, query_matchs in zip(dense_scores, dense_matchs)
]

def _build_index(
self,
X: list[str],
batch_size: int,
k_token: int,
) -> tuple[list, list, sparse.csr_matrix]:
"""Build"""
index_embeddings, index_activations, rows, columns, values = [], [], [], [], []
n = 0
) -> tuple[list, list, sparse.csc_matrix]:
"""Build a sparse matrix index."""
index_embeddings, index_activations, sparse_activations = [], [], []

for batch in self._to_batch(X, batch_size=batch_size):
batch_embeddings = self.model.encode(batch, k=k_token)

for activations, embeddings, sparse_activations in zip(
sparse_activations.append(
batch_embeddings["sparse_activations"].to_sparse()
)

for activations, activations_idx, embeddings in zip(
batch_embeddings["activations"],
batch_embeddings["activations"].detach().cpu().tolist(),
batch_embeddings["embeddings"],
batch_embeddings["sparse_activations"],
):
index_activations.append(activations)
index_embeddings.append(
{
token.item(): embedding
for token, embedding in zip(activations, embeddings)
token: embedding
for token, embedding in zip(activations_idx, embeddings)
}
)

tokens_scores = torch.index_select(
sparse_activations, dim=-1, index=activations
)

rows.extend([n for _ in range(len(activations))])
columns.extend(activations.tolist())
values.extend(tokens_scores.tolist())
n += 1

sparse_matrix = sparse.csc_matrix(
(values, (rows, columns)), shape=(len(X), self.vocabulary_size)
return (
index_embeddings,
index_activations,
torch.cat(sparse_activations),
)

return index_embeddings, index_activations, sparse_matrix

def top_k_by_partition(
self, similarities: np.ndarray, k_sparse: int
) -> tuple[np.ndarray, np.ndarray]:
"""Top k elements by partition."""
similarities *= -1

if k_sparse < len(self.documents_keys):
ind = np.argpartition(similarities, k_sparse, axis=-1)

# k non-sorted indices
ind = np.take(ind, np.arange(k_sparse), axis=-1)

# k non-sorted values
similarities = np.take_along_axis(similarities, ind, axis=-1)

# sort within k elements
ind_part = np.argsort(similarities, axis=-1)
ind = np.take_along_axis(ind, ind_part, axis=-1)

else:
ind_part = np.argsort(similarities, axis=-1)
ind = ind_part

similarities *= -1
val = np.take_along_axis(similarities, ind_part, axis=-1)
return ind, val

@staticmethod
def _to_batch(X: list, batch_size: int) -> list:
"""Convert input list to batch."""
Expand Down

0 comments on commit 344dd4c

Please sign in to comment.