diff --git a/.gitignore b/.gitignore
index af4e1a7..199fd90 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,4 +136,5 @@ venv.bak/
 dmypy.json
 
 # Pyre type checker
-.pyre/
\ No newline at end of file
+.pyre/
+test.ipynb
diff --git a/README.md b/README.md
index bd53391..1b243b6 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ from sparsembed import model, utils, losses
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 import torch
 
-device = "cuda" # cpu / cuda / mps
+device = "cuda" # cpu / cuda
 batch_size = 32
 
 model = model.SparsEmbed(
diff --git a/sparsembed/model/sparsembed.py b/sparsembed/model/sparsembed.py
index 4d33887..06484dc 100644
--- a/sparsembed/model/sparsembed.py
+++ b/sparsembed/model/sparsembed.py
@@ -65,6 +65,9 @@ class SparsEmbed(torch.nn.Module):
     'is great sports good was big has are and wonderful sport huge nice of games '
     'a']
 
+    >>> queries_embeddings["activations"].shape
+    torch.Size([2, 96])
+
     References
     ----------
     1. [SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://dl.acm.org/doi/pdf/10.1145/3539618.3592065)
@@ -174,9 +177,7 @@ def forward(
         return {
             "embeddings": self.relu(self.linear(embeddings)),
             "sparse_activations": activations["sparse_activations"],
-            "activations": self._filter_activations(
-                activations["sparse_activations"], k=k
-            ),
+            "activations": activations["activations"],
         }
 
     def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
@@ -227,18 +228,9 @@ def _filter_activations(
     ) -> list[torch.Tensor]:
         """Among the set of activations, select the ones with a score > 0."""
         scores, activations = torch.topk(input=sparse_activations, k=k, dim=-1)
-
-        filter_activations = []
-
-        for score, activation in zip(scores, activations):
-            new_activation = torch.index_select(
+        return [
+            torch.index_select(
                 activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0]
             )
-
-            if new_activation.shape[0] == 0:
-                filter_activations.append(activation)
-
-            else:
-                filter_activations.append(new_activation)
-
-        return filter_activations
+            for score, activation in zip(scores, activations)
+        ]
diff --git a/sparsembed/retrieve/__init__.py b/sparsembed/retrieve/__init__.py
index 9d0d000..0ffe6fd 100644
--- a/sparsembed/retrieve/__init__.py
+++ b/sparsembed/retrieve/__init__.py
@@ -1 +1,3 @@
 from .retriever import Retriever
+
+__all__ = ["Retriever"]
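Note on the _filter_activations hunk above: the explicit loop and its empty-result fallback are gone, so a row whose top-k scores are all zero now yields an empty tensor rather than the raw top-k indices. A minimal sketch of the selection pattern on a toy tensor (values are illustrative, not model output):

    import torch

    sparse_activations = torch.tensor(
        [[0.0, 1.2, 0.0, 3.4, 0.5], [2.0, 0.0, 0.0, 0.0, 0.0]]
    )

    # Keep the k highest scores per row, then drop the zero-score slots,
    # as the new list comprehension does.
    scores, activations = torch.topk(input=sparse_activations, k=3, dim=-1)
    filtered = [
        torch.index_select(
            activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0]
        )
        for score, activation in zip(scores, activations)
    ]
    print([f.tolist() for f in filtered])  # [[3, 1, 4], [0]]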
diff --git a/sparsembed/retrieve/retriever.py b/sparsembed/retrieve/retriever.py
index 2cf81f9..98b5412 100644
--- a/sparsembed/retrieve/retriever.py
+++ b/sparsembed/retrieve/retriever.py
@@ -11,8 +11,7 @@
 
 
 class Retriever:
-    """Class dedicated to SparsEmbed model inference in order retrieve
-    documents with queries.
+    """Retriever class.
 
     Parameters
     ----------
@@ -32,7 +31,7 @@ class Retriever:
 
     >>> _ = torch.manual_seed(42)
 
-    >>> device = "mps"
+    >>> device = "cpu"
 
     >>> model = model.SparsEmbed(
     ...     model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
@@ -63,23 +62,27 @@ class Retriever:
     ...     batch_size=24
     ... )
 
-    >>> print(retriever(["Food", "Sports", "Cinema", "Music"], k_token=96))
-    [[{'id': 0, 'similarity': 1.4686672687530518},
-    {'id': 1, 'similarity': 1.3459084033966064},
-    {'id': 3, 'similarity': 1.3040170669555664},
-    {'id': 2, 'similarity': 1.157921314239502}],
-    [{'id': 1, 'similarity': 7.03730583190918},
-    {'id': 3, 'similarity': 3.5283799171447754},
-    {'id': 2, 'similarity': 2.453505516052246},
-    {'id': 0, 'similarity': 1.789308786392212}],
-    [{'id': 2, 'similarity': 2.316730260848999},
-    {'id': 3, 'similarity': 2.2312138080596924},
-    {'id': 1, 'similarity': 2.0195863246917725},
-    {'id': 0, 'similarity': 1.289013147354126}],
-    [{'id': 3, 'similarity': 5.773364067077637},
-    {'id': 1, 'similarity': 3.6177942752838135},
-    {'id': 2, 'similarity': 3.3001816272735596},
-    {'id': 0, 'similarity': 2.591763496398926}]]
+    >>> print(retriever(["Food", "Sports", "Cinema", "Music", "Hello World"], k_token=96))
+    [[{'id': 0, 'similarity': 1.4686675071716309},
+    {'id': 1, 'similarity': 1.345913052558899},
+    {'id': 3, 'similarity': 1.304019808769226},
+    {'id': 2, 'similarity': 1.1579231023788452}],
+    [{'id': 1, 'similarity': 7.0373148918151855},
+    {'id': 3, 'similarity': 3.528376817703247},
+    {'id': 2, 'similarity': 2.4535036087036133},
+    {'id': 0, 'similarity': 1.7893059253692627}],
+    [{'id': 2, 'similarity': 2.3167333602905273},
+    {'id': 3, 'similarity': 2.2312183380126953},
+    {'id': 1, 'similarity': 2.0195937156677246},
+    {'id': 0, 'similarity': 1.2890148162841797}],
+    [{'id': 3, 'similarity': 2.4722704887390137},
+    {'id': 2, 'similarity': 1.8648046255111694},
+    {'id': 1, 'similarity': 1.732576608657837},
+    {'id': 0, 'similarity': 1.3416467905044556}],
+    [{'id': 3, 'similarity': 3.7778899669647217},
+    {'id': 2, 'similarity': 3.198120355606079},
+    {'id': 1, 'similarity': 3.1253902912139893},
+    {'id': 0, 'similarity': 2.458303451538086}]]
 
     """
@@ -139,9 +142,9 @@ def add(
         self.documents_activations.extend(documents_activations)
 
         self.sparse_matrix = (
-            sparse_matrix
+            sparse_matrix.T
             if self.sparse_matrix is None
-            else sparse.vstack((self.sparse_matrix, sparse_matrix))
+            else torch.cat([self.sparse_matrix.to_sparse(), sparse_matrix.T], dim=1)
         )
 
         self.documents_keys = {
@@ -156,7 +159,7 @@ def add(
 
     def __call__(
         self,
-        q: list[str] | str,
+        q: list[str],
         k_sparse: int = 100,
         k_token: int = 96,
         batch_size: int = 3,
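Note on the add hunk above: the scipy row-stacked matrix is replaced by a torch sparse index whose columns are documents, so scoring a query batch becomes a single sparse matmul. A short sketch of the layout with toy shapes (names and sizes are illustrative, and it assumes torch sparse COO supports t(), cat, and sparse-sparse matmul, the same operations the patch itself relies on):

    import torch

    # 3 documents over a 5-token vocabulary, stored as columns
    # (vocabulary_size, n_documents), mirroring sparse_matrix.T.
    index = torch.rand(3, 5).to_sparse().t()

    # Adding a batch concatenates along the column (document) dimension.
    batch = torch.rand(2, 5).to_sparse()
    index = torch.cat([index, batch.t()], dim=1)

    # A query batch (n_queries, vocabulary_size) scores every document at once.
    queries = torch.rand(4, 5).to_sparse()
    scores = (queries @ index).to_dense()
    print(scores.shape)  # torch.Size([4, 5])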
@@ -182,34 +185,39 @@ def __call__(
             batch_size=batch_size,
         )
 
-        # TODO: return torch tensor
-        matchs, _ = self.top_k_by_partition(
-            similarities=(sparse_matrix @ self.sparse_matrix.T).toarray(),
-            k_sparse=k_sparse,
+        sparse_scores = (sparse_matrix @ self.sparse_matrix).to_dense()
+
+        _, sparse_matchs = torch.topk(
+            input=sparse_scores, k=min(k_sparse, len(self.documents_keys)), dim=-1
         )
 
+        sparse_matchs_idx = sparse_matchs.tolist()
+
+        # Intersections between queries and documents activated tokens.
         intersections = self._get_intersection(
             queries_activations=queries_activations,
             documents_activations=[
                 [self.documents_activations[document] for document in query_matchs]
-                for query_matchs in matchs
+                for query_matchs in sparse_matchs_idx
             ],
         )
 
-        # Optimize to handle batchs
-        scores = self._get_scores(
+        dense_scores = self._get_scores(
             queries_embeddings=queries_embeddings,
             documents_embeddings=[
                 [self.documents_embeddings[document] for document in match]
-                for match in matchs
+                for match in sparse_matchs_idx
             ],
             intersections=intersections,
         )
 
-        return self._rank(scores=scores, matchs=matchs)
+        return self._rank(
+            dense_scores=dense_scores, sparse_matchs=sparse_matchs, k_sparse=k_sparse
+        )
 
-    def _rank(self, scores: torch.Tensor, matchs: torch.Tensor) -> list:
+    def _rank(
+        self, dense_scores: torch.Tensor, sparse_matchs: torch.Tensor, k_sparse: int
+    ) -> list:
         """Rank documents by scores.
 
         Parameters
         ----------
@@ -219,31 +227,24 @@ def _rank(
         scores
             Documents scores.
         matchs
             Documents matchs.
         """
-        ranks = torch.argsort(scores, dim=-1, descending=True, stable=False)
-
-        scores = [
-            torch.index_select(input=query_documents_score, dim=0, index=query_ranks)
-            for query_documents_score, query_ranks in zip(scores, ranks)
-        ]
+        dense_scores, dense_matchs = torch.topk(
+            input=dense_scores, k=min(k_sparse, len(self.documents_keys)), dim=-1
+        )
 
-        matchs = [
-            torch.index_select(
-                input=torch.tensor(query_matchs).to(self.model.device),
-                dim=0,
-                index=query_ranks,
-            )
-            for query_matchs, query_ranks in zip(matchs, ranks)
-        ]
+        dense_scores = dense_scores.tolist()
+        dense_matchs = (
+            torch.gather(sparse_matchs, 1, dense_matchs).detach().cpu().tolist()
+        )
 
         return [
             [
                 {
-                    self.key: self.documents_keys[document.item()],
-                    "similarity": score.item(),
+                    self.key: self.documents_keys[document],
+                    "similarity": score,
                 }
                 for score, document in zip(query_scores, query_matchs)
             ]
-            for query_scores, query_matchs in zip(scores, matchs)
+            for query_scores, query_matchs in zip(dense_scores, dense_matchs)
         ]
 
     def _build_index(
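Note on the reworked _rank above: the dense scores stay aligned with the sparse candidates, so torch.topk ranks candidate positions and torch.gather maps those positions back to document ids. A toy example of that pattern (tensors are illustrative):

    import torch

    sparse_matchs = torch.tensor([[4, 0, 7], [2, 9, 1]])             # candidate ids
    dense_scores = torch.tensor([[1.0, 9.0, 4.0], [8.0, 2.0, 5.0]])  # aligned scores

    scores, positions = torch.topk(dense_scores, k=2, dim=-1)
    document_ids = torch.gather(sparse_matchs, 1, positions)

    print(document_ids.tolist())  # [[0, 7], [2, 1]]
    print(scores.tolist())        # [[9.0, 4.0], [8.0, 5.0]]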
@@ -251,69 +252,36 @@ def _build_index(
         self,
         X: list[str],
         batch_size: int,
         k_token: int,
-    ) -> tuple[list, list, sparse.csr_matrix]:
-        """Build"""
-        index_embeddings, index_activations, rows, columns, values = [], [], [], [], []
-        n = 0
+    ) -> tuple[list, list, torch.Tensor]:
+        """Build a sparse matrix index."""
+        index_embeddings, index_activations, sparse_activations = [], [], []
 
         for batch in self._to_batch(X, batch_size=batch_size):
             batch_embeddings = self.model.encode(batch, k=k_token)
 
-            for activations, embeddings, sparse_activations in zip(
+            sparse_activations.append(
+                batch_embeddings["sparse_activations"].to_sparse()
+            )
+
+            for activations, activations_idx, embeddings in zip(
                 batch_embeddings["activations"],
+                batch_embeddings["activations"].detach().cpu().tolist(),
                 batch_embeddings["embeddings"],
-                batch_embeddings["sparse_activations"],
             ):
                 index_activations.append(activations)
                 index_embeddings.append(
                     {
-                        token.item(): embedding
-                        for token, embedding in zip(activations, embeddings)
+                        token: embedding
+                        for token, embedding in zip(activations_idx, embeddings)
                     }
                 )
 
-                tokens_scores = torch.index_select(
-                    sparse_activations, dim=-1, index=activations
-                )
-
-                rows.extend([n for _ in range(len(activations))])
-                columns.extend(activations.tolist())
-                values.extend(tokens_scores.tolist())
-                n += 1
-
-        sparse_matrix = sparse.csc_matrix(
-            (values, (rows, columns)), shape=(len(X), self.vocabulary_size)
+        return (
+            index_embeddings,
+            index_activations,
+            torch.cat(sparse_activations),
         )
 
-        return index_embeddings, index_activations, sparse_matrix
-
-    def top_k_by_partition(
-        self, similarities: np.ndarray, k_sparse: int
-    ) -> tuple[np.ndarray, np.ndarray]:
-        """Top k elements by partition."""
-        similarities *= -1
-
-        if k_sparse < len(self.documents_keys):
-            ind = np.argpartition(similarities, k_sparse, axis=-1)
-
-            # k non-sorted indices
-            ind = np.take(ind, np.arange(k_sparse), axis=-1)
-
-            # k non-sorted values
-            similarities = np.take_along_axis(similarities, ind, axis=-1)
-
-            # sort within k elements
-            ind_part = np.argsort(similarities, axis=-1)
-            ind = np.take_along_axis(ind, ind_part, axis=-1)
-
-        else:
-            ind_part = np.argsort(similarities, axis=-1)
-            ind = ind_part
-
-        similarities *= -1
-        val = np.take_along_axis(similarities, ind_part, axis=-1)
-        return ind, val
-
     @staticmethod
     def _to_batch(X: list, batch_size: int) -> list:
         """Convert input list to batch."""
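The deleted top_k_by_partition helper is covered by the torch.topk call now used in __call__. A condensed equivalence check, with illustrative values (one query, k_sparse=2; the removed helper's negate/restore steps are folded together):

    import numpy as np
    import torch

    similarities = np.array([[0.1, 0.9, 0.4, 0.7]])

    # NumPy route of the removed helper: negate, argpartition, sort within k.
    neg = -similarities
    ind = np.take(np.argpartition(neg, 2, axis=-1), np.arange(2), axis=-1)
    order = np.argsort(np.take_along_axis(neg, ind, axis=-1), axis=-1)
    ind = np.take_along_axis(ind, order, axis=-1)

    # torch route used in __call__ after this change.
    _, matchs = torch.topk(torch.from_numpy(similarities), k=2, dim=-1)

    assert matchs.tolist() == ind.tolist()  # [[1, 3]] in both cases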