From ece76ba40bfc848904961bdc94200e97eea38438 Mon Sep 17 00:00:00 2001 From: Raphael Sourty Date: Mon, 21 Aug 2023 17:58:58 +0200 Subject: [PATCH] Add splade --- LICENSE | 21 -- README.md | 262 ++++++++++-------- setup.py | 8 +- sparsembed/__init__.py | 2 +- sparsembed/__version__.py | 2 +- sparsembed/losses/__init__.py | 4 +- sparsembed/losses/cosine.py | 97 ------- sparsembed/losses/flops.py | 52 ++-- sparsembed/losses/ranking.py | 73 +++++ sparsembed/model/__init__.py | 3 +- sparsembed/model/sparsembed.py | 146 ++++------ sparsembed/model/splade.py | 162 +++++++++++ sparsembed/retrieve/__init__.py | 5 +- .../{retriever.py => sparsembed_retriever.py} | 92 +++--- sparsembed/retrieve/splade_retriever.py | 214 ++++++++++++++ sparsembed/train/__init__.py | 4 + sparsembed/train/train_sparsembed.py | 132 +++++++++ sparsembed/train/train_splade.py | 111 ++++++++ sparsembed/utils/__init__.py | 15 +- sparsembed/utils/dense_scores.py | 207 ++++++++++++++ sparsembed/utils/evaluate.py | 125 ++++++++- sparsembed/utils/in_batch.py | 73 +++++ sparsembed/utils/iter.py | 36 +-- sparsembed/utils/scores.py | 142 ---------- sparsembed/utils/sparse_scores.py | 76 +++++ 25 files changed, 1461 insertions(+), 603 deletions(-) delete mode 100644 LICENSE delete mode 100644 sparsembed/losses/cosine.py create mode 100644 sparsembed/losses/ranking.py create mode 100644 sparsembed/model/splade.py rename sparsembed/retrieve/{retriever.py => sparsembed_retriever.py} (78%) create mode 100644 sparsembed/retrieve/splade_retriever.py create mode 100644 sparsembed/train/__init__.py create mode 100644 sparsembed/train/train_sparsembed.py create mode 100644 sparsembed/train/train_splade.py create mode 100644 sparsembed/utils/dense_scores.py create mode 100644 sparsembed/utils/in_batch.py delete mode 100644 sparsembed/utils/scores.py create mode 100644 sparsembed/utils/sparse_scores.py diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 72406fe..0000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 Raphael Sourty - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/README.md b/README.md index 0552bf8..b9164e5 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,15 @@
-SparsEmbed
+SparsEmbed - Splade

Neural search
-This repository presents an unofficial replication of the research paper *[SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://research.google/pubs/pub52289/)* authored by Weize Kong, Jeffrey M. Dudek, Cheng Li, Mingyang Zhang, and Mike Bendersky, SIGIR 2023.
+This repository presents an unofficial replication of the research papers:

-**Note:** This project is currently a work in progress. πŸ”¨πŸ§Ή
+- *[SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720)* authored by Thibault Formal, Benjamin Piwowarski and StΓ©phane Clinchant, SIGIR 2021.

-## Overview
+- *[SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://research.google/pubs/pub52289/)* authored by Weize Kong, Jeffrey M. Dudek, Cheng Li, Mingyang Zhang and Mike Bendersky, SIGIR 2023.

-This repository aims to replicate the SparseEmbed model, focusing on learning both sparse lexical representations and contextual token level embeddings for retrieval tasks.
-
-The `SparsEmbed` model available here is compatible with any model compatible with the class `AutoModelForMaskedLM` from HuggingFace.
-
-### Differences with the original paper
-
-1. **Loss Function:** We did not yet implement the distillation loss used in the paper. We have initially opted for a cosine loss like the one used in SentenceTransformer library. This decision was made to fine-tune the model from scratch, avoiding the use of a cross-encoder as a teacher. The distillation loss should be available soon.
-
-2. **Multi-Head Implementation:** At this stage, the distinct MLM (Masked Language Model) head for document encoding and query encoding has not been incorporated. Our current implementation employs a shared MLM head (calculating sparse activations) for both documents and queries.
+**Note:** This project is currently a work in progress and the models are not ready to use. πŸ”¨πŸ§Ή

## Installation

@@ -25,143 +17,169 @@ The `SparsEmbed` model available here is compatible with any model compatible wi
pip install sparsembed
```

+If you plan to evaluate your model, install the evaluation dependencies:
+
+```
+pip install "sparsembed[eval]"
+```
+
## Training

-The following PyTorch code snippet illustrates the training loop to fine-tune the model:
+### Dataset
+
+Your training dataset must consist of triples `(anchor, positive, negative)`: the anchor is a query, the positive is a document relevant to that query, and the negative is a document that is not relevant to the query.

```python
-from sparsembed import model, utils, losses
-from transformers import AutoModelForMaskedLM, AutoTokenizer
-import torch
+X = [
+    ("anchor 1", "positive 1", "negative 1"),
+    ("anchor 2", "positive 2", "negative 2"),
+    ("anchor 3", "positive 3", "negative 3"),
+]
+```

-device = "cuda" # cpu / cuda
-batch_size = 32
+### Models

-model = model.SparsEmbed(
-    model=AutoModelForMaskedLM.from_pretrained("Luyu/co-condenser-marco").to(device),
-    tokenizer=AutoTokenizer.from_pretrained("Luyu/co-condenser-marco"),
-    device=device,
-)
+Both Splade and SparsEmbed models can be initialized from any pretrained `AutoModelForMaskedLM` checkpoint.
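As a quick sanity check, you can encode a few sentences and decode the most activated vocabulary terms. This is a minimal sketch based on the `encode` and `decode` methods of the `Splade` model added in this patch; the checkpoint is only an example, and an untrained model will not produce meaningful expansions yet:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

from sparsembed import model

device = "cpu"  # "cuda" or "mps" if available

splade = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)

# Sparse activations live in the vocabulary space: one non-negative weight per token.
queries = splade.encode(["Sports is great."])
print(queries["sparse_activations"].shape)  # torch.Size([1, 30522]) for DistilBERT

# Decode the most activated tokens to inspect the expansion produced by the model.
print(splade.decode(sparse_activations=queries["sparse_activations"], k_tokens=10))
```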
-model = model.to(device) +```python +from transformers import AutoModelForMaskedLM, AutoTokenizer -optimizer = torch.optim.AdamW( - filter(lambda p: p.requires_grad, model.parameters()), - lr=2e-5, +model = model.Splade( + model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + device=device ) +``` -flops_loss = losses.Flops() +### Splade -cosine_loss = losses.Cosine() +The following PyTorch code snippet illustrates the training loop to fine-tune Splade: -dataset = [ - # Query, Document, Label (1: Relevant, 0: Not Relevant) - ("Apple", "Apple is a popular fruit.", 1), - ("Apple", "Banana is a popular fruit.", 0), - ("Banana", "Apple is a popular fruit.", 0), - ("Banana", "Banana is a yellow fruit.", 1), -] +```python +from transformers import AutoModelForMaskedLM, AutoTokenizer +from sparsembed import model, utils, train, retrieve +import torch -for queries, documents, labels in utils.iter( - dataset, - device=device, - epochs=1, - batch_size=batch_size, - shuffle=True, -): - queries_embeddings = model(queries, k=32) - - documents_embeddings = model(documents, k=32) - - scores = utils.scores( - queries_activations=queries_embeddings["activations"], - queries_embeddings=queries_embeddings["embeddings"], - documents_activations=documents_embeddings["activations"], - documents_embeddings=documents_embeddings["embeddings"], - device=device, - ) - - loss = cosine_loss.dense( - scores=scores, - labels=labels, - ) - - loss += 0.1 * cosine_loss.sparse( - queries_sparse_activations=queries_embeddings["sparse_activations"], - documents_sparse_activations=documents_embeddings["sparse_activations"], - labels=labels, - ) - - loss += 4e-3 * flops_loss( - sparse_activations=queries_embeddings["sparse_activations"] - ) - loss += 4e-3 * flops_loss( - sparse_activations=documents_embeddings["sparse_activations"] - ) - - loss.backward() - optimizer.step() - optimizer.zero_grad() -``` +device = "cpu" # cuda -## Inference +batch_size = 3 -Once we trained the model, we can initialize a `Retriever` to retrieve relevant documents given a query. +model = model.Splade( + model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + device=device +) -- It build a sparse matrix from sparse activations of documents. -- It build a sparse matrix from sparse activations of queries. -- It match relevant documents using dot product of both sparse matrix. -- It re-rank documents based on contextual embbedings similarity score. +optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) -```python -from sparsembed import retrieve - -documents = [{ - "id": 0, - "document": "Apple is a popular fruit.", - }, - { - "id": 1, - "document": "Banana is a popular fruit.", - }, - { - "id": 2, - "document": "Banana is a yellow fruit.", - } +X = [ + ("anchor 1", "positive 1", "negative 1"), + ("anchor 2", "positive 2", "negative 2"), + ("anchor 3", "positive 3", "negative 3"), ] -retriever = retrieve.Retriever( - key="id", - on="document", - model=model # Trained SparseEmbed model. 
+for anchor, positive, negative in utils.iter( + X, + epochs=1, + batch_size=batch_size, + shuffle=True + ): + loss = train.train_splade( + model=model, + optimizer=optimizer, + anchor=anchor, + positive=positive, + negative=negative, + flops_loss_weight=1e-5, + in_batch_negatives=True, + ) + +documents, queries, qrels = utils.load_beir("scifact", split="test") + +retriever = retrieve.SpladeRetriever( + key="id", + on=["title", "text"], + model=model ) retriever = retriever.add( documents=documents, - k_token=32, # Number of tokens to activate. - batch_size=3, + batch_size=batch_size ) -retriever( - q = [ - "Apple", - "Banana", - ], - k_sparse=20, # Number of documents to retrieve. - k_token=32, # Number of tokens to activate. - batch_size=3 +utils.evaluate( + retriever=retriever, + batch_size=1, + qrels=qrels, + queries=queries, + k=100, + metrics=["map", "ndcg@10", "ndcg@10", "recall@10", "hits@10"] ) ``` +## SparsEmbed + +The following PyTorch code snippet illustrates the training loop to fine-tune SparseEmbed: + ```python -[[{'id': 0, 'similarity': 195.057861328125}, - {'id': 1, 'similarity': 183.51429748535156}, - {'id': 2, 'similarity': 158.66012573242188}], - [{'id': 1, 'similarity': 214.34048461914062}, - {'id': 2, 'similarity': 194.5692901611328}, - {'id': 0, 'similarity': 192.5744171142578}]] -``` +from transformers import AutoModelForMaskedLM, AutoTokenizer +from sparsembed import model, utils, train, retrieve +import torch + +device = "cpu" # cuda + +batch_size = 3 + +model = model.SparsEmbed( + model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + device=device +) + +optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) -## Evaluations +X = [ + ("anchor 1", "positive 1", "negative 1"), + ("anchor 2", "positive 2", "negative 2"), + ("anchor 3", "positive 3", "negative 3"), +] + +for anchor, positive, negative in utils.iter( + X, + epochs=1, + batch_size=batch_size, + shuffle=True + ): + loss = train.train_sparsembed( + model=model, + optimizer=optimizer, + anchor=anchor, + positive=positive, + negative=negative, + flops_loss_weight=1e-5, + sparse_loss_weight=0.1, + in_batch_negatives=True, + ) + +documents, queries, qrels = utils.load_beir("scifact", split="test") + +retriever = retrieve.SparsEmbedRetriever( + key="id", + on=["title", "text"], + model=model +) -Work in progress. 
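+# Documents are dicts that provide the `key` field ("id") and the fields listed in
+# `on` ("title", "text"). `add` encodes them and updates the retriever's sparse index.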
+retriever = retriever.add( + documents=documents, + batch_size=batch_size +) + +utils.evaluate( + retriever=retriever, + batch_size=1, + qrels=qrels, + queries=queries, + k=100, + metrics=["map", "ndcg@10", "ndcg@10", "recall@10", "hits@10"] +) +``` \ No newline at end of file diff --git a/setup.py b/setup.py index fc6d6ad..b96c978 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,12 @@ "transformers >= 4.30.2", ] +eval = ["ranx >= 0.3.16", "faiss-beir >= 2.0.0"] + setuptools.setup( name="sparsembed", version=f"{__version__}", - license="MIT", + license="", author="Raphael Sourty", author_email="raphael.sourty@gmail.com", description="Sparse Embeddings for Neural Search.", @@ -28,9 +30,13 @@ "semantic search", "SparseEmbed", "Google Research", + "SPLADE", ], packages=setuptools.find_packages(), install_requires=base_packages, + extras_require={ + "eval": base_packages + eval, + }, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/sparsembed/__init__.py b/sparsembed/__init__.py index 8599718..f6aad25 100644 --- a/sparsembed/__init__.py +++ b/sparsembed/__init__.py @@ -1 +1 @@ -__all__ = ["data", "losses", "model", "retrieve", "utils"] +__all__ = ["data", "losses", "model", "retrieve", "train", "utils"] diff --git a/sparsembed/__version__.py b/sparsembed/__version__.py index a6b2ab8..470cb66 100644 --- a/sparsembed/__version__.py +++ b/sparsembed/__version__.py @@ -1,3 +1,3 @@ -VERSION = (0, 0, 3) +VERSION = (0, 0, 4) __version__ = ".".join(map(str, VERSION)) diff --git a/sparsembed/losses/__init__.py b/sparsembed/losses/__init__.py index f397891..9303061 100644 --- a/sparsembed/losses/__init__.py +++ b/sparsembed/losses/__init__.py @@ -1,4 +1,4 @@ -from .cosine import Cosine from .flops import Flops +from .ranking import Ranking -__all__ = ["Flops", "Cosine"] +__all__ = ["Flops", "Ranking"] diff --git a/sparsembed/losses/cosine.py b/sparsembed/losses/cosine.py deleted file mode 100644 index fdca1c3..0000000 --- a/sparsembed/losses/cosine.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch - -__all__ = ["Cosine"] - - -class Cosine(torch.nn.Module): - """Cosine similarity loss function between sparse vectors. - - Parameters - ---------- - l - Lambda to ponderate the Cosine loss. - - Example - ------- - >>> from transformers import AutoModelForMaskedLM, AutoTokenizer - >>> from sparsembed import losses, model, utils - >>> import torch - - >>> _ = torch.manual_seed(42) - - >>> loss = 0 - - >>> model = model.SparsEmbed( - ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased"), - ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), - ... ) - - >>> queries_embeddings = model( - ... ["Paris", "Toulouse"], - ... k=96 - ... ) - - >>> documents_embeddings = model( - ... ["Paris is a city located in France.", "Toulouse is a city located in France."], - ... k=256 - ... ) - - >>> scores = utils.scores( - ... queries_activations=queries_embeddings["activations"], - ... queries_embeddings=queries_embeddings["embeddings"], - ... documents_activations=documents_embeddings["activations"], - ... documents_embeddings=documents_embeddings["embeddings"], - ... device="cpu", - ... ) - - >>> cosine_loss = losses.Cosine() - - >>> loss += cosine_loss.sparse( - ... queries_sparse_activations=queries_embeddings["sparse_activations"], - ... documents_sparse_activations=documents_embeddings["sparse_activations"], - ... labels=torch.tensor([1,1]), - ... ) - - >>> loss += cosine_loss.dense( - ... scores=scores, - ... 
labels=torch.tensor([1, 1]), - ... ) - - References - ---------- - 1. [Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation](https://arxiv.org/pdf/2010.02666.pdf) - - """ - - def __init__( - self, - ) -> None: - super(Cosine, self).__init__() - self.mse = torch.nn.MSELoss(reduction="none") - - def sparse( - self, - queries_sparse_activations: torch.Tensor, - documents_sparse_activations: torch.Tensor, - labels: torch.Tensor, - weights: torch.Tensor = None, - ) -> torch.Tensor: - """Sparse CosineSimilarity loss.""" - similarity = torch.cosine_similarity( - queries_sparse_activations, documents_sparse_activations, dim=1 - ) - errors = self.mse(similarity, labels) - if weights is not None: - errors *= weights - return errors.mean() - - def dense( - self, - scores: torch.Tensor, - labels: torch.Tensor, - weights: torch.Tensor = None, - ) -> torch.Tensor: - errors = self.mse(scores, labels) - if weights is not None: - errors *= weights - return errors.mean() diff --git a/sparsembed/losses/flops.py b/sparsembed/losses/flops.py index 53e3b7b..01d1368 100644 --- a/sparsembed/losses/flops.py +++ b/sparsembed/losses/flops.py @@ -9,44 +9,35 @@ class Flops(torch.nn.Module): Example ------- >>> from transformers import AutoModelForMaskedLM, AutoTokenizer - >>> from sparsembed import model, losses - >>> import torch + >>> from sparsembed import model, utils, losses + >>> from pprint import pprint as print - >>> _ = torch.manual_seed(42) + >>> device = "mps" - >>> model = model.SparsEmbed( - ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased"), + >>> model = model.Splade( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... device=device ... ) - >>> anchor_queries_embeddings = model.encode( - ... ["Paris", "Toulouse"], - ... k=64 + >>> anchor_activations = model.encode( + ... ["Sports", "Music"], ... ) - >>> positive_documents_embeddings = model.encode( - ... ["France", "France"], - ... k=64 + >>> positive_activations = model.encode( + ... ["Sports", "Music"], ... ) - >>> negative_documents_embeddings = model.encode( - ... ["Canada", "Espagne"], - ... k=64 + >>> negative_activations = model.encode( + ... ["Cinema", "Movie"], ... ) - >>> flops = losses.Flops() - - >>> loss = flops( - ... sparse_activations = anchor_queries_embeddings["sparse_activations"] - ... ) - - >>> loss += flops( - ... sparse_activations = torch.cat([ - ... positive_documents_embeddings["sparse_activations"], - ... negative_documents_embeddings["sparse_activations"], - ... ], dim=0) + >>> losses.Flops()( + ... anchor_activations=anchor_activations["sparse_activations"], + ... positive_activations=positive_activations["sparse_activations"], + ... negative_activations=negative_activations["sparse_activations"], ... 
) - + tensor(1838.6599, device='mps:0') References ---------- @@ -60,7 +51,12 @@ def __init__(self): def __call__( self, - sparse_activations: torch.Tensor, + anchor_activations: torch.Tensor, + positive_activations: torch.Tensor, + negative_activations: torch.Tensor, ) -> torch.Tensor: """Loss which tend to reduce sparse activation.""" - return torch.sum(torch.mean(sparse_activations, dim=0) ** 2, dim=0) + activations = torch.cat( + [anchor_activations, positive_activations, negative_activations], dim=0 + ) + return torch.sum(torch.mean(torch.abs(activations), dim=0) ** 2, dim=0) diff --git a/sparsembed/losses/ranking.py b/sparsembed/losses/ranking.py new file mode 100644 index 0000000..6d9ece0 --- /dev/null +++ b/sparsembed/losses/ranking.py @@ -0,0 +1,73 @@ +import torch + +__all__ = ["Ranking"] + + +class Ranking(torch.nn.Module): + """Ranking loss. + + Examples + -------- + >>> from transformers import AutoModelForMaskedLM, AutoTokenizer + >>> from sparsembed import model, utils, losses + >>> from pprint import pprint as print + + >>> device = "mps" + + >>> model = model.Splade( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... device=device + ... ) + + >>> queries_activations = model.encode( + ... ["Sports", "Music"], + ... ) + + >>> positive_activations = model.encode( + ... ["Sports", "Music"], + ... ) + + >>> negative_activations = model.encode( + ... ["Cinema", "Movie"], + ... ) + + >>> scores = utils.sparse_scores( + ... anchor_activations=queries_activations["sparse_activations"], + ... positive_activations=positive_activations["sparse_activations"], + ... negative_activations=negative_activations["sparse_activations"], + ... in_batch_negatives=True, + ... ) + + >>> losses.Ranking()(**scores) + tensor(3264.9170, device='mps:0') + + References + ---------- + 1. [SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/pdf/2107.05720.pdf) + + """ + + def __init__(self): + super(Ranking, self).__init__() + self.log_softmax = torch.nn.LogSoftmax(dim=1) + + def __call__( + self, + positive_scores: torch.Tensor, + negative_scores: torch.Tensor, + ) -> torch.Tensor: + """Ranking loss.""" + scores = torch.stack( + [ + positive_scores, + negative_scores, + ], + dim=1, + ) + + return torch.index_select( + input=-self.log_softmax(scores), + dim=1, + index=torch.zeros(1, dtype=torch.int64).to(scores.device), + ).mean() diff --git a/sparsembed/model/__init__.py b/sparsembed/model/__init__.py index f39e3cf..ee9e937 100644 --- a/sparsembed/model/__init__.py +++ b/sparsembed/model/__init__.py @@ -1,3 +1,4 @@ from .sparsembed import SparsEmbed +from .splade import Splade -__all__ = ["SpareEmbed"] +__all__ = ["SparsEmbed", "Splade"] diff --git a/sparsembed/model/sparsembed.py b/sparsembed/model/sparsembed.py index 06484dc..6443f29 100644 --- a/sparsembed/model/sparsembed.py +++ b/sparsembed/model/sparsembed.py @@ -5,8 +5,10 @@ __all__ = ["SparseEmbed"] +from .splade import Splade -class SparsEmbed(torch.nn.Module): + +class SparsEmbed(Splade): """SparsEmbed model. Parameters @@ -33,41 +35,35 @@ class SparsEmbed(torch.nn.Module): >>> model = model.SparsEmbed( ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), - ... device=device + ... k_tokens=96, + ... device=device, ... ) >>> queries_embeddings = model.encode( ... ["Sports", "Music"], - ... 
k=96 ... ) >>> documents_embeddings = model.encode( ... ["Music is great.", "Sports is great."], - ... k=256 ... ) >>> query_expanded = model.decode( - ... sparse_activations=queries_embeddings["sparse_activations"], k=25 + ... sparse_activations=queries_embeddings["sparse_activations"] ... ) - >>> print(query_expanded) - ['sports the games s defaulted and mores of athletics sport hockey a ' - 'basketball', - 'music the on s song of a more in and songs musical 2015 to'] - >>> documents_expanded = model.decode( - ... sparse_activations=documents_embeddings["sparse_activations"], k=25 + ... sparse_activations=documents_embeddings["sparse_activations"] ... ) - >>> print(documents_expanded) - ['is great music was good wonderful big beautiful huge has are and of fine on ' - 'a', - 'is great sports good was big has are and wonderful sport huge nice of games ' - 'a'] - >>> queries_embeddings["activations"].shape torch.Size([2, 96]) + >>> queries_embeddings["sparse_activations"].shape + torch.Size([2, 30522]) + + >>> queries_embeddings["embeddings"].shape + torch.Size([2, 96, 64]) + References ---------- 1. [SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://dl.acm.org/doi/pdf/10.1145/3539618.3592065) @@ -78,28 +74,17 @@ def __init__( self, tokenizer: AutoTokenizer, model: AutoModelForMaskedLM, - k_query: int = 64, - k_documents: int = 256, - embedding_size: int = 256, + k_tokens: int = 96, + embedding_size: int = 64, device: str = None, ) -> None: - super(SparsEmbed, self).__init__() - self.tokenizer = tokenizer - self.model = model - self.k_query = k_query - self.k_documents = k_documents - self.embedding_size = embedding_size - - if device is not None: - self.device = device - elif torch.cuda.is_available(): - self.device = "cuda" - else: - self.device = "cpu" + super(SparsEmbed, self).__init__( + tokenizer=tokenizer, model=model, device=device + ) - self.model.config.output_hidden_states = True + self.k_tokens = k_tokens + self.embedding_size = embedding_size - self.relu = torch.nn.ReLU().to(self.device) self.softmax = torch.nn.Softmax(dim=2).to(self.device) # Input embedding size: @@ -108,65 +93,53 @@ def __init__( in_features = embeddings.shape[2] self.linear = torch.nn.Linear( - in_features=in_features, out_features=embedding_size + in_features=in_features, out_features=embedding_size, bias=False ).to(self.device) def encode( self, texts: list[str], - k: int, truncation: bool = True, padding: bool = True, + max_length: int = 256, **kwargs ) -> dict[str, torch.Tensor]: """Encode documents""" with torch.no_grad(): return self( - texts=texts, k=k, truncation=truncation, padding=padding, **kwargs + texts=texts, + truncation=truncation, + padding=padding, + max_length=max_length, + **kwargs, ) - def decode( - self, - sparse_activations: torch.Tensor, - clean_up_tokenization_spaces: bool = False, - skip_special_tokens: bool = True, - k: int = 128, - ) -> list[str]: - """Decode activated tokens ids where activated value > 0.""" - activations = self._filter_activations( - sparse_activations=sparse_activations, k=k - ) - - # Decode - return [ - " ".join( - activation.translate(str.maketrans("", "", string.punctuation)).split() - ) - for activation in self.tokenizer.batch_decode( - activations, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - skip_special_tokens=skip_special_tokens, - ) - ] - def forward( self, texts: list[str], - k: int, truncation: bool = True, padding: bool = True, + max_length: int = 256, **kwargs ) -> dict[str, 
torch.Tensor]: """Pytorch forward method.""" - kwargs = {"truncation": truncation, "padding": padding, **kwargs} + kwargs = { + "truncation": truncation, + "padding": padding, + "max_length": max_length, + **kwargs, + } logits, embeddings = self._encode(texts=texts, **kwargs) - activations = self._get_activation(logits=logits, k=k) + activations = self._update_activations( + **self._get_activation(logits=logits), + k_tokens=self.k_tokens, + ) attention = self._get_attention( logits=logits, - activation=activations["activations"], + activations=activations["activations"], ) embeddings = torch.bmm( @@ -180,26 +153,15 @@ def forward( "activations": activations["activations"], } - def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]: - """Encode sentences.""" - encoded_input = self.tokenizer.batch_encode_plus( - texts, return_tensors="pt", **kwargs - ) - if self.device != "cpu": - encoded_input = { - key: value.to(self.device) for key, value in encoded_input.items() - } - output = self.model(**encoded_input) - return output.logits, output.hidden_states[-1] - - def _get_activation(self, logits: torch.Tensor, k: int) -> torch.Tensor: + def _update_activations( + self, sparse_activations: torch.Tensor, k_tokens: int + ) -> torch.Tensor: """Returns activated tokens.""" - max_pooling = torch.amax(torch.log(1 + self.relu(logits)), dim=1) - activations = torch.topk(input=max_pooling, k=k, dim=-1).indices + activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices - # Set value of max pooling which are not in top k to 0. - sparse_activations = max_pooling * torch.zeros( - (max_pooling.shape[0], max_pooling.shape[1]), dtype=int + # Set value of max sparse_activations which are not in top k to 0. + sparse_activations = sparse_activations * torch.zeros( + (sparse_activations.shape[0], sparse_activations.shape[1]), dtype=int ).to(self.device).scatter_(dim=1, index=activations.long(), value=1) return { @@ -208,7 +170,7 @@ def _get_activation(self, logits: torch.Tensor, k: int) -> torch.Tensor: } def _get_attention( - self, logits: torch.Tensor, activation: torch.Tensor + self, logits: torch.Tensor, activations: torch.Tensor ) -> torch.Tensor: """Extract attention scores from MLM logits based on activated tokens.""" attention = logits.gather( @@ -216,21 +178,9 @@ def _get_attention( index=torch.stack( [ torch.stack([token for _ in range(logits.shape[1])]) - for token in activation + for token in activations ] ), ) return self.softmax(attention.transpose(1, 2)) - - def _filter_activations( - self, sparse_activations: torch.Tensor, k: int - ) -> list[torch.Tensor]: - """Among the set of activations, select the ones with a score > 0.""" - scores, activations = torch.topk(input=sparse_activations, k=k, dim=-1) - return [ - torch.index_select( - activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0] - ) - for score, activation in zip(scores, activations) - ] diff --git a/sparsembed/model/splade.py b/sparsembed/model/splade.py new file mode 100644 index 0000000..2e3ae92 --- /dev/null +++ b/sparsembed/model/splade.py @@ -0,0 +1,162 @@ +import string + +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer + +__all__ = ["Splade"] + + +class Splade(torch.nn.Module): + """SpladeV1 model. + + Parameters + ---------- + tokenizer + HuggingFace Tokenizer. + model + HuggingFace AutoModelForMaskedLM. 
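+    device
+        Device to use: "cpu", "cuda" or "mps". Defaults to "cuda" when available, otherwise "cpu".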
+ + Example + ------- + >>> from transformers import AutoModelForMaskedLM, AutoTokenizer + >>> from sparsembed import model + >>> from pprint import pprint as print + + >>> device = "mps" + + >>> model = model.Splade( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... device=device + ... ) + + >>> queries_activations = model.encode( + ... ["Sports", "Music"], + ... ) + + >>> documents_activations = model.encode( + ... ["Music is great.", "Sports is great."], + ... ) + + >>> queries_activations["sparse_activations"].shape + torch.Size([2, 30522]) + + References + ---------- + 1. [SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720) + + """ + + def __init__( + self, + tokenizer: AutoTokenizer, + model: AutoModelForMaskedLM, + device: str = None, + ) -> None: + super(Splade, self).__init__() + self.tokenizer = tokenizer + self.model = model + + if device is not None: + self.device = device + elif torch.cuda.is_available(): + self.device = "cuda" + else: + self.device = "cpu" + + self.model.config.output_hidden_states = True + self.relu = torch.nn.ReLU().to(self.device) + + def encode( + self, + texts: list[str], + truncation: bool = True, + padding: bool = True, + max_length: int = 256, + **kwargs + ) -> dict[str, torch.Tensor]: + """Encode documents""" + with torch.no_grad(): + return self( + texts=texts, + truncation=truncation, + padding=padding, + max_length=max_length, + **kwargs + ) + + def decode( + self, + sparse_activations: torch.Tensor, + clean_up_tokenization_spaces: bool = False, + skip_special_tokens: bool = True, + k_tokens: int = 96, + ) -> list[str]: + """Decode activated tokens ids where activated value > 0.""" + activations = self._filter_activations( + sparse_activations=sparse_activations, k_tokens=k_tokens + ) + + # Decode + return [ + " ".join( + activation.translate(str.maketrans("", "", string.punctuation)).split() + ) + for activation in self.tokenizer.batch_decode( + activations, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + skip_special_tokens=skip_special_tokens, + ) + ] + + def forward( + self, + texts: list[str], + truncation: bool = True, + padding: bool = True, + max_length: int = 256, + **kwargs + ) -> dict[str, torch.Tensor]: + """Pytorch forward method.""" + kwargs = { + "truncation": truncation, + "padding": padding, + "max_length": max_length, + **kwargs, + } + + logits, _ = self._encode(texts=texts, **kwargs) + + activations = self._get_activation(logits=logits) + + return {"sparse_activations": activations["sparse_activations"]} + + def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + """Encode sentences.""" + encoded_input = self.tokenizer.batch_encode_plus( + texts, return_tensors="pt", **kwargs + ) + if self.device != "cpu": + encoded_input = { + key: value.to(self.device) for key, value in encoded_input.items() + } + output = self.model(**encoded_input) + return output.logits, output.hidden_states[0] + + def _get_activation(self, logits: torch.Tensor) -> dict[str, torch.Tensor]: + """Returns activated tokens.""" + return { + "sparse_activations": torch.log1p(self.relu(logits)).sum(axis=1), + } + + def _filter_activations( + self, sparse_activations: torch.Tensor, k_tokens: int + ) -> list[torch.Tensor]: + """Among the set of activations, select the ones with a score > 0.""" + scores, activations = torch.topk(input=sparse_activations, 
k=k_tokens, dim=-1) + return [ + torch.index_select( + activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0] + ) + for score, activation in zip(scores, activations) + ] diff --git a/sparsembed/retrieve/__init__.py b/sparsembed/retrieve/__init__.py index 0ffe6fd..7bd8020 100644 --- a/sparsembed/retrieve/__init__.py +++ b/sparsembed/retrieve/__init__.py @@ -1,3 +1,4 @@ -from .retriever import Retriever +from .sparsembed_retriever import SparsEmbedRetriever +from .splade_retriever import SpladeRetriever -__all__ = ["Retriever"] +__all__ = ["SparsEmbedRetriever", "SpladeRetriever"] diff --git a/sparsembed/retrieve/retriever.py b/sparsembed/retrieve/sparsembed_retriever.py similarity index 78% rename from sparsembed/retrieve/retriever.py rename to sparsembed/retrieve/sparsembed_retriever.py index e2b3c6a..3b4e520 100644 --- a/sparsembed/retrieve/retriever.py +++ b/sparsembed/retrieve/sparsembed_retriever.py @@ -1,15 +1,15 @@ import os +import warnings import torch import tqdm -import warnings from ..model import SparsEmbed -__all__ = ["Retriever"] +__all__ = ["SparsEmbedRetriever"] -class Retriever: +class SparsEmbedRetriever: """Retriever class. Parameters @@ -36,52 +36,32 @@ class Retriever: ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), ... device=device, - ... embedding_size=3, + ... embedding_size=64, + ... k_tokens=96, ... ) - >>> retriever = retrieve.Retriever(key="id", on="document", model=model) + >>> retriever = retrieve.SparsEmbedRetriever(key="id", on="document", model=model) >>> documents = [ - ... {"id": 0, "document": "Food is good."}, - ... {"id": 1, "document": "Sports is great."}, + ... {"id": 0, "document": "Food"}, + ... {"id": 1, "document": "Sports"}, + ... {"id": 2, "document": "Cinema"}, ... ] >>> retriever = retriever.add( ... documents=documents, - ... k_token=32, - ... batch_size=24 + ... batch_size=1 ... ) - >>> documents = [ - ... {"id": 2, "document": "Cinema is great."}, - ... {"id": 3, "document": "Music is amazing."}, - ... ] - >>> retriever = retriever.add( - ... documents=documents, - ... k_token=32, - ... batch_size=24 - ... 
) - - >>> print(retriever(["Food", "Sports", "Cinema", "Music", "Hello World"], k_token=32)) - [[{'id': 3, 'similarity': 0.5633876323699951}, - {'id': 2, 'similarity': 0.4271728992462158}, - {'id': 1, 'similarity': 0.4205787181854248}, - {'id': 0, 'similarity': 0.3673652410507202}], - [{'id': 1, 'similarity': 1.547836184501648}, - {'id': 3, 'similarity': 0.7415981888771057}, - {'id': 2, 'similarity': 0.6557919979095459}, - {'id': 0, 'similarity': 0.5385637879371643}], - [{'id': 3, 'similarity': 0.5051844716072083}, - {'id': 2, 'similarity': 0.48867619037628174}, - {'id': 1, 'similarity': 0.3863832950592041}, - {'id': 0, 'similarity': 0.2812037169933319}], - [{'id': 3, 'similarity': 0.9398075938224792}, - {'id': 1, 'similarity': 0.595514178276062}, - {'id': 2, 'similarity': 0.5711489319801331}, - {'id': 0, 'similarity': 0.46095147728919983}], - [{'id': 2, 'similarity': 1.3963655233383179}, - {'id': 3, 'similarity': 1.2879667282104492}, - {'id': 1, 'similarity': 1.229896068572998}, - {'id': 0, 'similarity': 1.2129783630371094}]] + >>> print(retriever(["Food", "Sports", "Cinema"], batch_size=32)) + [[{'id': 0, 'similarity': 201.47901916503906}, + {'id': 1, 'similarity': 107.03492736816406}, + {'id': 2, 'similarity': 106.74536895751953}], + [{'id': 1, 'similarity': 252.70684814453125}, + {'id': 2, 'similarity': 125.91816711425781}, + {'id': 0, 'similarity': 107.03492736816406}], + [{'id': 2, 'similarity': 205.38475036621094}, + {'id': 1, 'similarity': 125.91813659667969}, + {'id': 0, 'similarity': 106.745361328125}]] """ @@ -111,9 +91,8 @@ def __init__( def add( self, documents: list, - k_token: int = 256, batch_size: int = 32, - ) -> "Retriever": + ) -> "SparsEmbedRetriever": """Add new documents to the retriever. Computes documents embeddings and activations and update the sparse matrix. @@ -122,8 +101,6 @@ def add( ---------- documents Documents to add. - k_token - Number of tokens to activate. batch_size Batch size. """ @@ -136,7 +113,6 @@ def add( " ".join([document[field] for field in self.on]) for document in documents ], - k_token=k_token, batch_size=batch_size, ) @@ -162,9 +138,9 @@ def add( def __call__( self, q: list[str], - k_sparse: int = 100, - k_token: int = 96, + k: int = 100, batch_size: int = 3, + **kwargs, ) -> list: """Retrieve documents. @@ -174,8 +150,6 @@ def __call__( Queries. k_sparse Number of documents to retrieve. - k_token - Number of tokens to activate. """ ( queries_embeddings, @@ -183,14 +157,14 @@ def __call__( sparse_matrix, ) = self._build_index( X=[q] if isinstance(q, str) else q, - k_token=k_token, batch_size=batch_size, + **kwargs, ) sparse_scores = (sparse_matrix @ self.sparse_matrix).to_dense() - _, sparse_matchs = torch.topk( - input=sparse_scores, k=min(k_sparse, len(self.documents_keys)), dim=-1 + sparse_scores, sparse_matchs = torch.topk( + input=sparse_scores, k=min(k, len(self.documents_keys)), dim=-1 ) sparse_matchs_idx = sparse_matchs.tolist() @@ -214,11 +188,13 @@ def __call__( ) return self._rank( - dense_scores=dense_scores, sparse_matchs=sparse_matchs, k_sparse=k_sparse + dense_scores=dense_scores, + sparse_matchs=sparse_matchs, + k_dense=k, ) def _rank( - self, dense_scores: torch.Tensor, sparse_matchs: torch.Tensor, k_sparse: int + self, dense_scores: torch.Tensor, sparse_matchs: torch.Tensor, k_dense: int ) -> list: """Rank documents by scores. @@ -230,7 +206,7 @@ def _rank( Documents matchs. 
""" dense_scores, dense_matchs = torch.topk( - input=dense_scores, k=min(k_sparse, len(self.documents_keys)), dim=-1 + input=dense_scores, k=min(k_dense, len(self.documents_keys)), dim=-1 ) dense_scores = dense_scores.tolist() @@ -253,13 +229,13 @@ def _build_index( self, X: list[str], batch_size: int, - k_token: int, + **kwargs, ) -> tuple[list, list, torch.Tensor]: """Build a sparse matrix index.""" index_embeddings, index_activations, sparse_activations = [], [], [] for batch in self._to_batch(X, batch_size=batch_size): - batch_embeddings = self.model.encode(batch, k=k_token) + batch_embeddings = self.model.encode(batch, **kwargs) sparse_activations.append( batch_embeddings["sparse_activations"].to_sparse() @@ -323,7 +299,7 @@ def _get_scores( queries_embeddings: list[torch.Tensor], documents_embeddings: list[list[torch.Tensor]], intersections: list[torch.Tensor], - ) -> list: + ) -> torch.Tensor: """Computes similarity scores between queries and documents with activated tokens embeddings.""" return torch.stack( [ diff --git a/sparsembed/retrieve/splade_retriever.py b/sparsembed/retrieve/splade_retriever.py new file mode 100644 index 0000000..6dc29fa --- /dev/null +++ b/sparsembed/retrieve/splade_retriever.py @@ -0,0 +1,214 @@ +import os +import warnings + +import torch +import tqdm + +from ..model import Splade + +__all__ = ["SpladeRetriever"] + + +class SpladeRetriever: + """Retriever class. + + Parameters + ---------- + key + Document unique identifier. + on + Document texts. + model + SparsEmbed model. + + Example + ------- + >>> from transformers import AutoModelForMaskedLM, AutoTokenizer + >>> from sparsembed import model, retrieve + >>> from pprint import pprint as print + >>> import torch + + >>> _ = torch.manual_seed(42) + + >>> device = "cpu" + + >>> model = model.Splade( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... device=device, + ... ) + + >>> retriever = retrieve.SpladeRetriever(key="id", on="document", model=model) + + >>> documents = [ + ... {"id": 0, "document": "Food"}, + ... {"id": 1, "document": "Sports"}, + ... {"id": 2, "document": "Cinema"}, + ... ] + >>> retriever = retriever.add( + ... documents=documents, + ... batch_size=1 + ... ) + + >>> print(retriever(["Food", "Sports", "Cinema"], batch_size=32)) + [[{'id': 0, 'similarity': 2005.1702880859375}, + {'id': 1, 'similarity': 1866.706787109375}, + {'id': 2, 'similarity': 1690.898681640625}], + [{'id': 1, 'similarity': 2534.69140625}, + {'id': 2, 'similarity': 1875.5230712890625}, + {'id': 0, 'similarity': 1866.70654296875}], + [{'id': 2, 'similarity': 1934.9771728515625}, + {'id': 1, 'similarity': 1875.521484375}, + {'id': 0, 'similarity': 1690.8975830078125}]] + + """ + + def __init__( + self, + key: str, + on: list[str], + model: Splade, + tokenizer_parallelism: str = "false", + ) -> None: + self.key = key + self.on = [on] if isinstance(on, str) else on + self.model = model + self.vocabulary_size = len(model.tokenizer.get_vocab()) + + # Mapping between sparse matrix index and document key. + self.sparse_matrix = None + self.documents_keys = {} + + # Documents embeddings and activations store. 
+ self.documents_embeddings, self.documents_activations = [], [] + os.environ["TOKENIZERS_PARALLELISM"] = tokenizer_parallelism + warnings.filterwarnings( + "ignore", ".*Sparse CSR tensor support is in beta state.*" + ) + + def add( + self, + documents: list, + batch_size: int = 32, + **kwargs, + ) -> "SpladeRetriever": + """Add new documents to the retriever. + + Computes documents embeddings and activations and update the sparse matrix. + + Parameters + ---------- + documents + Documents to add. + batch_size + Batch size. + """ + sparse_matrix = self._build_index( + X=[ + " ".join([document[field] for field in self.on]) + for document in documents + ], + batch_size=batch_size, + **kwargs, + ) + + self.sparse_matrix = ( + sparse_matrix.T + if self.sparse_matrix is None + else torch.cat([self.sparse_matrix.to_sparse(), sparse_matrix.T], dim=1) + ) + + self.documents_keys = { + **self.documents_keys, + **{ + len(self.documents_keys) + index: document[self.key] + for index, document in enumerate(documents) + }, + } + + return self + + def __call__( + self, + q: list[str], + k: int = 100, + batch_size: int = 3, + **kwargs, + ) -> list: + """Retrieve documents. + + Parameters + ---------- + q + Queries. + k_sparse + Number of documents to retrieve. + """ + sparse_matrix = self._build_index( + X=[q] if isinstance(q, str) else q, + batch_size=batch_size, + **kwargs, + ) + + sparse_scores = (sparse_matrix @ self.sparse_matrix).to_dense() + + return self._rank( + sparse_scores=sparse_scores, + k=k, + ) + + def _rank(self, sparse_scores: torch.Tensor, k: int) -> list: + """Rank documents by scores. + + Parameters + ---------- + scores + Scores between queries and documents. + matchs + Documents matchs. + """ + sparse_scores, sparse_matchs = torch.topk( + input=sparse_scores, k=min(k, len(self.documents_keys)), dim=-1 + ) + + sparse_scores = sparse_scores.tolist() + sparse_matchs = sparse_matchs.tolist() + + return [ + [ + { + self.key: self.documents_keys[document], + "similarity": score, + } + for score, document in zip(query_scores, query_matchs) + ] + for query_scores, query_matchs in zip(sparse_scores, sparse_matchs) + ] + + def _build_index( + self, + X: list[str], + batch_size: int, + **kwargs, + ) -> tuple[list, list, torch.Tensor]: + """Build a sparse matrix index.""" + sparse_activations = [] + + for batch in self._to_batch(X, batch_size=batch_size): + batch_embeddings = self.model.encode(batch, **kwargs) + + sparse_activations.append( + batch_embeddings["sparse_activations"].to_sparse() + ) + + return torch.cat(sparse_activations) + + @staticmethod + def _to_batch(X: list, batch_size: int) -> list: + """Convert input list to batch.""" + for X in tqdm.tqdm( + [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)], + position=0, + total=1 + len(X) // batch_size, + ): + yield X diff --git a/sparsembed/train/__init__.py b/sparsembed/train/__init__.py new file mode 100644 index 0000000..9cb9978 --- /dev/null +++ b/sparsembed/train/__init__.py @@ -0,0 +1,4 @@ +from .train_sparsembed import train_sparsembed +from .train_splade import train_splade + +__all__ = ["train_splade", "train_sparsembed"] diff --git a/sparsembed/train/train_sparsembed.py b/sparsembed/train/train_sparsembed.py new file mode 100644 index 0000000..0bb4c3f --- /dev/null +++ b/sparsembed/train/train_sparsembed.py @@ -0,0 +1,132 @@ +import torch + +from .. 
import losses, utils + +__all__ = ["train_sparsembed"] + + +def train_sparsembed( + model, + optimizer, + anchor: list[str], + positive: list[str], + negative: list[str], + flops_loss_weight: float = 1e-4, + sparse_loss_weight: float = 0.1, + dense_loss_weight: float = 1.0, + in_batch_negatives: bool = True, +): + """Compute the ranking loss and the flops loss for a single step. + + Parameters + ---------- + model + Splade model. + optimizer + Optimizer. + anchor + Anchor. + positive + Positive. + negative + Negative. + flops_loss_weight + Flops loss weight. Defaults to 1e-4. + in_batch_negatives + Whether to use in batch negatives or not. Defaults to True. + + Examples + -------- + >>> from transformers import AutoModelForMaskedLM, AutoTokenizer + >>> from sparsembed import model, utils, train + >>> import torch + + >>> device = "mps" + + >>> model = model.SparsEmbed( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... device=device + ... ) + + >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + >>> X = [ + ... ("Sports", "Music", "Cinema"), + ... ("Sports", "Music", "Cinema"), + ... ("Sports", "Music", "Cinema"), + ... ] + + >>> for anchor, positive, negative in utils.iter( + ... X, + ... epochs=3, + ... batch_size=3, + ... shuffle=False + ... ): + ... loss = train.train_sparsembed( + ... model=model, + ... optimizer=optimizer, + ... anchor=anchor, + ... positive=positive, + ... negative=negative, + ... flops_loss_weight=1e-4, + ... in_batch_negatives=True, + ... ) + + >>> {'dense': tensor(4.9316, device='mps:0', grad_fn=), 'ranking': tensor(3456.6538, device='mps:0', grad_fn=), 'flops': tensor(796.2637, device='mps:0', grad_fn=)} + """ + + anchor_activations = model( + anchor, + ) + + positive_activations = model( + positive, + ) + + negative_activations = model( + negative, + ) + + sparse_scores = utils.sparse_scores( + anchor_activations=anchor_activations["sparse_activations"], + positive_activations=positive_activations["sparse_activations"], + negative_activations=negative_activations["sparse_activations"], + in_batch_negatives=in_batch_negatives, + ) + + dense_scores = utils.dense_scores( + anchor_activations=anchor_activations["activations"], + positive_activations=positive_activations["activations"], + negative_activations=negative_activations["activations"], + anchor_embeddings=anchor_activations["embeddings"], + positive_embeddings=positive_activations["embeddings"], + negative_embeddings=negative_activations["embeddings"], + func=torch.sum, + ) + + sparse_ranking_loss = losses.Ranking()(**sparse_scores) + + flops_loss = losses.Flops()( + anchor_activations=anchor_activations["sparse_activations"], + positive_activations=positive_activations["sparse_activations"], + negative_activations=negative_activations["sparse_activations"], + ) + + dense_ranking_loss = losses.Ranking()(**dense_scores) + + loss = ( + dense_loss_weight * dense_ranking_loss + + sparse_loss_weight * sparse_ranking_loss + + flops_loss_weight * flops_loss + ) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + return { + "dense": dense_ranking_loss, + "ranking": sparse_ranking_loss, + "flops": flops_loss, + } diff --git a/sparsembed/train/train_splade.py b/sparsembed/train/train_splade.py new file mode 100644 index 0000000..03b5d1b --- /dev/null +++ b/sparsembed/train/train_splade.py @@ -0,0 +1,111 @@ +from .. 
import losses, utils + +__all__ = ["train_splade"] + + +def train_splade( + model, + optimizer, + anchor: list[str], + positive: list[str], + negative: list[str], + flops_loss_weight: float = 1e-4, + sparse_loss_weight: float = 1.0, + in_batch_negatives: bool = True, +): + """Compute the ranking loss and the flops loss for a single step. + + Parameters + ---------- + model + Splade model. + optimizer + Optimizer. + anchor + Anchor. + positive + Positive. + negative + Negative. + flops_loss_weight + Flops loss weight. Defaults to 1e-4. + in_batch_negatives + Whether to use in batch negatives or not. Defaults to True. + + Examples + -------- + >>> from transformers import AutoModelForMaskedLM, AutoTokenizer + >>> from sparsembed import model, utils, train + >>> import torch + + >>> device = "mps" + + >>> model = model.Splade( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... device=device + ... ) + + >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + >>> X = [ + ... ("Sports", "Music", "Cinema"), + ... ("Sports", "Music", "Cinema"), + ... ("Sports", "Music", "Cinema"), + ... ] + + >>> for anchor, positive, negative in utils.iter( + ... X, + ... epochs=3, + ... batch_size=3, + ... shuffle=False + ... ): + ... loss = train.train_splade( + ... model=model, + ... optimizer=optimizer, + ... anchor=anchor, + ... positive=positive, + ... negative=negative, + ... flops_loss_weight=1e-4, + ... in_batch_negatives=True, + ... ) + + >>> loss + {'ranking': tensor(307.2816, device='mps:0', grad_fn=), 'flops': tensor(75.3216, device='mps:0', grad_fn=)} + + """ + + anchor_activations = model( + anchor, + ) + + positive_activations = model( + positive, + ) + + negative_activations = model( + negative, + ) + + scores = utils.sparse_scores( + anchor_activations=anchor_activations["sparse_activations"], + positive_activations=positive_activations["sparse_activations"], + negative_activations=negative_activations["sparse_activations"], + in_batch_negatives=in_batch_negatives, + ) + + ranking_loss = losses.Ranking()(**scores) + + flops_loss = losses.Flops()( + anchor_activations=anchor_activations["sparse_activations"], + positive_activations=positive_activations["sparse_activations"], + negative_activations=negative_activations["sparse_activations"], + ) + + loss = sparse_loss_weight * ranking_loss + flops_loss_weight * flops_loss + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + return {"ranking": ranking_loss, "flops": flops_loss} diff --git a/sparsembed/utils/__init__.py b/sparsembed/utils/__init__.py index 31389e3..41167e8 100644 --- a/sparsembed/utils/__init__.py +++ b/sparsembed/utils/__init__.py @@ -1,5 +1,14 @@ -from .evaluate import evaluate +from .dense_scores import dense_scores +from .evaluate import evaluate, load_beir +from .in_batch import in_batch_sparse_scores from .iter import iter -from .scores import scores +from .sparse_scores import sparse_scores -__all__ = ["evaluate", "iter", "scores"] +__all__ = [ + "dense_scores", + "evaluate", + "load_beir", + "in_batch_sparse_scores", + "iter", + "sparse_scores", +] diff --git a/sparsembed/utils/dense_scores.py b/sparsembed/utils/dense_scores.py new file mode 100644 index 0000000..7fc1bef --- /dev/null +++ b/sparsembed/utils/dense_scores.py @@ -0,0 +1,207 @@ +import torch + +__all__ = ["dense_scores"] + + +def _build_index(activations: torch.Tensor, embeddings: torch.Tensor) -> dict: + """Build index to 
score documents using activated tokens and embeddings.""" + index = [] + for tokens_activation, tokens_embeddings in zip(activations, embeddings): + index.append( + { + token.item(): embedding + for token, embedding in zip(tokens_activation, tokens_embeddings) + } + ) + return index + + +def _intersection(t1: torch.Tensor, t2: torch.Tensor) -> list: + """Retrieve intersection between two tensors.""" + t1, t2 = t1.flatten(), t2.flatten() + combined = torch.cat((t1, t2), dim=0) + uniques, counts = combined.unique(return_counts=True, sorted=False) + return uniques[counts > 1].tolist() + + +def _get_intersection(queries_activations: list, documents_activations: list) -> list: + """Retrieve intersection of activated tokens between queries and documents.""" + return [ + _intersection(query_activations, document_activations) + for query_activations, document_activations in zip( + queries_activations, + documents_activations, + ) + ] + + +def _get_scores( + anchor_embeddings_index: dict, + positive_embeddings_index: dict, + negative_embeddings_index: dict, + positive_intersections: torch.Tensor, + negative_intersections: torch.Tensor, + func, +) -> dict[str, torch.Tensor]: + """Computes similarity scores between queries and documents based on activated tokens embeddings""" + positive_scores, negative_scores = [], [] + + for ( + anchor_embedding_index, + positive_embedding_index, + negative_embedding_index, + positive_intersection, + negative_intersection, + ) in zip( + anchor_embeddings_index, + positive_embeddings_index, + negative_embeddings_index, + positive_intersections, + negative_intersections, + ): + if len(positive_intersections) > 0 and len(negative_intersections) > 0: + positive_scores.append( + func( + torch.stack( + [ + anchor_embedding_index[token] + for token in positive_intersection + ], + dim=0, + ) + * torch.stack( + [ + positive_embedding_index[token] + for token in positive_intersection + ], + dim=0, + ) + ) + ) + + negative_scores.append( + func( + torch.stack( + [ + anchor_embedding_index[token] + for token in negative_intersection + ], + dim=0, + ) + * torch.stack( + [ + negative_embedding_index[token] + for token in negative_intersection + ], + dim=0, + ) + ) + ) + + return { + "positive_scores": torch.stack( + positive_scores, + dim=0, + ), + "negative_scores": torch.stack( + negative_scores, + dim=0, + ), + } + + +def dense_scores( + anchor_activations: torch.Tensor, + positive_activations: torch.Tensor, + negative_activations: torch.Tensor, + anchor_embeddings: torch.Tensor, + positive_embeddings: torch.Tensor, + negative_embeddings: torch.Tensor, + func=torch.sum, +) -> dict[str, torch.Tensor]: + """Computes score between queries and documents intersected activated tokens. + + Parameters + ---------- + queries_activations + Queries activated tokens. + queries_embeddings + Queries activated tokens embeddings. + documents_activations + Documents activated tokens. + documents_embeddings + Documents activated tokens embeddings. + func + Either torch.sum or torch.mean. torch.mean is dedicated to training and + torch.sum is dedicated to inference. + + Example + ---------- + >>> from transformers import AutoModelForMaskedLM, AutoTokenizer + >>> from sparsembed import model, utils + >>> import torch + + >>> _ = torch.manual_seed(42) + + >>> model = model.SparsEmbed( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased"), + ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... k_tokens=96, + ... 
) + + >>> anchor_embeddings = model( + ... ["Paris", "Toulouse"], + ... ) + + >>> positive_embeddings = model( + ... ["Paris", "Toulouse"], + ... ) + + >>> negative_embeddings = model( + ... ["Toulouse", "Paris"], + ... ) + + >>> scores = utils.dense_scores( + ... anchor_activations=anchor_embeddings["activations"], + ... positive_activations=positive_embeddings["activations"], + ... negative_activations=negative_embeddings["activations"], + ... anchor_embeddings=anchor_embeddings["embeddings"], + ... positive_embeddings=positive_embeddings["embeddings"], + ... negative_embeddings=negative_embeddings["embeddings"], + ... func=torch.sum, + ... ) + + >>> scores + {'positive_scores': tensor([216.5337, 214.3472], grad_fn=), 'negative_scores': tensor([103.8172, 103.8172], grad_fn=)} + + """ + anchor_embeddings_index = _build_index( + activations=anchor_activations, embeddings=anchor_embeddings + ) + + positive_embeddings_index = _build_index( + activations=positive_activations, embeddings=positive_embeddings + ) + + negative_embeddings_index = _build_index( + activations=negative_activations, embeddings=negative_embeddings + ) + + positive_intersections = _get_intersection( + queries_activations=anchor_activations, + documents_activations=positive_activations, + ) + + negative_intersections = _get_intersection( + queries_activations=anchor_activations, + documents_activations=negative_activations, + ) + + return _get_scores( + anchor_embeddings_index=anchor_embeddings_index, + positive_embeddings_index=positive_embeddings_index, + negative_embeddings_index=negative_embeddings_index, + positive_intersections=positive_intersections, + negative_intersections=negative_intersections, + func=func, + ) diff --git a/sparsembed/utils/evaluate.py b/sparsembed/utils/evaluate.py index a605f47..8c717ee 100644 --- a/sparsembed/utils/evaluate.py +++ b/sparsembed/utils/evaluate.py @@ -1,10 +1,127 @@ from ..model import SparsEmbed -__all__ = ["evaluate"] +__all__ = ["evaluate", "load_beir"] + + +def load_beir(dataset_name: str, split: str = "test") -> tuple[list, list, dict]: + """Load BEIR dataset. + + Parameters + ---------- + dataset_name + Dataset name: scifact. + + """ + from beir import util + from beir.datasets.data_loader import GenericDataLoader + + data_path = util.download_and_unzip( + f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip", + "./evaluation_datasets/", + ) + + documents, queries, qrels = GenericDataLoader(data_folder=data_path).load( + split=split + ) + + documents = [ + { + "id": document_id, + "title": document["title"], + "text": document["text"], + } + for document_id, document in documents.items() + ] + + return documents, queries, qrels def evaluate( - model: SparsEmbed, + retriever: SparsEmbed, + batch_size: int, + qrels: dict, + queries: list[str], + k: int = 30, + metrics: list = [], ) -> dict: - """Evaluation function.""" - pass + """Evaluation. + + Parameters + ---------- + retriever + Retriever. + batch_size + Batch size. + qrels + Qrels. + queries + Queries. + k + Number of documents to retrieve. + metrics + Metrics to compute. + + >>> from transformers import AutoModelForMaskedLM, AutoTokenizer + >>> from sparsembed import model, retrieve, utils + + >>> device = "cpu" + + >>> model = model.Splade( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... device=device, + ... 
)
+
+    >>> documents, queries, qrels = utils.load_beir("scifact", split="test")
+
+    >>> documents = documents[:10]
+
+    >>> retriever = retrieve.SpladeRetriever(
+    ...     key="id",
+    ...     on=["title", "text"],
+    ...     model=model
+    ... )
+
+    >>> retriever = retriever.add(
+    ...     documents=documents,
+    ...     batch_size=1
+    ... )
+
+    >>> utils.evaluate(
+    ...     retriever=retriever,
+    ...     batch_size=1,
+    ...     qrels=qrels,
+    ...     queries=queries,
+    ...     k=30,
+    ...     metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"]
+    ... )
+
+    """
+    from ranx import Qrels, Run, evaluate
+
+    qrels = Qrels(qrels)
+
+    matchs = retriever(
+        q=list(queries.values()),
+        k=k,
+        batch_size=batch_size,
+    )
+
+    run_dict = {
+        id_query: {
+            match["id"]: 1 - (rank / k) for rank, match in enumerate(query_matchs)
+        }
+        for id_query, query_matchs in zip(queries, matchs)
+    }
+
+    run = Run(run_dict)
+
+    if not metrics:
+        metrics = ["ndcg@10"] + [f"hits@{k}" for k in [1, 2, 3, 4, 5, 10]]
+
+    return evaluate(
+        qrels,
+        run,
+        metrics,
+        make_comparable=True,
+    )
diff --git a/sparsembed/utils/in_batch.py b/sparsembed/utils/in_batch.py
new file mode 100644
index 0000000..1a22c47
--- /dev/null
+++ b/sparsembed/utils/in_batch.py
@@ -0,0 +1,73 @@
+import collections
+import itertools
+
+import torch
+
+__all__ = ["in_batch_sparse_scores"]
+
+
+def in_batch_sparse_scores(
+    activations,
+):
+    """Computes, for each row of the batch, the sum of its dot products with every other row (in-batch negatives).
+
+    Parameters
+    ----------
+    activations
+        Sparse activations of the batch, one row per sample.
+
+    Examples
+    --------
+
+    >>> from sparsembed import utils
+    >>> import torch
+
+    >>> activations = torch.tensor([
+    ...     [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
+    ...     [0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
+    ...     [0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
+    ... ], device="mps")
+
+    >>> utils.in_batch_sparse_scores(
+    ...     activations=activations,
+    ... )
+    tensor([5, 8, 9], device='mps:0')
+
+    """
+    list_idx_a, list_idx_b = [], []
+    size = activations.shape[0]
+
+    index = collections.defaultdict(list)
+
+    for row, (idx_a, idx_b) in enumerate(
+        list(itertools.combinations(list(range(size)), 2))
+    ):
+        list_idx_a.append(idx_a)
+        list_idx_b.append(idx_b)
+
+        index[idx_a].append(row)
+        index[idx_b].append(row)
+
+    list_idx_a = torch.tensor(list_idx_a, device=activations.device)
+    list_idx_b = torch.tensor(list_idx_b, device=activations.device)
+
+    index = torch.tensor(list(index.values()), device=activations.device)
+
+    sparse_activations_a = torch.index_select(
+        input=activations, dim=0, index=list_idx_a
+    )
+
+    sparse_activations_b = torch.index_select(
+        input=activations, dim=0, index=list_idx_b
+    )
+
+    sparse_scores = torch.sum(sparse_activations_a * sparse_activations_b, dim=1)
+
+    return torch.gather(
+        input=sparse_scores.repeat(size, sparse_activations_a.shape[0]),
+        dim=1,
+        index=index,
+    ).sum(dim=1)
diff --git a/sparsembed/utils/iter.py b/sparsembed/utils/iter.py
index 5052a52..0157005 100644
--- a/sparsembed/utils/iter.py
+++ b/sparsembed/utils/iter.py
@@ -1,7 +1,6 @@
 import os
 import random
 
-import torch
 import tqdm
 
 __all__ = ["iter"]
@@ -9,7 +8,6 @@
 
 def iter(
-    X: list[tuple[str, str, float]],
-    device: str,
+    X: list[tuple[str, str, str]],
     epochs: int,
     batch_size: int,
     shuffle: bool = True,
@@ -29,34 +27,30 @@
     >>> from pprint import pprint as print
 
     >>> X = [
-    ...     ("Apple", "Apple is a popular fruit.", 1),
-    ...     ("Apple", "Banana is a popular fruit.", 0),
-    ...     ("Banana", "Apple is a popular fruit.", 0),
-    ...     ("Banana", "Banana is a yellow fruit.", 1),
+    ...     
("Apple", "🍏", "🍌"), + ... ("Banana", "🍌", "🍏"), ... ] - >>> for queries, documents, labels in utils.iter( + >>> for anchor, positive, negative in utils.iter( ... X, - ... device="cpu", ... epochs=1, ... batch_size=3, ... shuffle=False ... ): ... break - >>> print(queries) - ['Apple', 'Apple', 'Banana'] + >>> print(anchor) + ['Apple', 'Banana'] - >>> print(documents) - ['Apple is a popular fruit.', - 'Banana is a popular fruit.', - 'Apple is a popular fruit.'] + >>> print(positive) + ['🍏', '🍌'] - >>> print(labels) - tensor([1., 0., 0.]) + >>> print(negative) + ['🍌', '🍏'] """ os.environ["TOKENIZERS_PARALLELISM"] = "false" + for epoch in range(epochs): if shuffle: random.shuffle(X) @@ -65,10 +59,8 @@ def iter( [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)], position=0, total=1 + len(X) // batch_size, - desc=f"Epoch {epoch + 1}/{epochs}", + desc=f"Epoch {epoch}", ): - yield [sample[0] for sample in batch], [ - sample[1] for sample in batch - ], torch.tensor( - [sample[2] for sample in batch], dtype=torch.float, device=device - ) + yield [sample[0] for sample in batch], [sample[1] for sample in batch], [ + sample[2] for sample in batch + ] diff --git a/sparsembed/utils/scores.py b/sparsembed/utils/scores.py deleted file mode 100644 index 228f36c..0000000 --- a/sparsembed/utils/scores.py +++ /dev/null @@ -1,142 +0,0 @@ -import torch - -__all__ = ["scores"] - - -def _build_index(activations: torch.Tensor, embeddings: torch.Tensor) -> dict: - """Build index to score documents using activated tokens and embeddings.""" - index = [] - for tokens_activation, tokens_embeddings in zip(activations, embeddings): - index.append( - { - token.item(): embedding - for token, embedding in zip(tokens_activation, tokens_embeddings) - } - ) - return index - - -def _intersection(t1: torch.Tensor, t2: torch.Tensor) -> list: - """Retrieve intersection between two tensors.""" - t1, t2 = t1.flatten(), t2.flatten() - combined = torch.cat((t1, t2), dim=0) - uniques, counts = combined.unique(return_counts=True, sorted=False) - return uniques[counts > 1].tolist() - - -def _get_intersection(queries_activations: list, documents_activations: list) -> list: - """Retrieve intersection of activated tokens between queries and documents.""" - return [ - _intersection(query_activations, document_activations) - for query_activations, document_activations in zip( - queries_activations, - documents_activations, - ) - ] - - -def _get_scores( - queries_embeddings_index: torch.Tensor, - documents_embeddings_index: torch.Tensor, - intersections: torch.Tensor, - device: str, - func, -) -> list: - """Computes similarity scores between queries and documents based on activated tokens embeddings""" - return torch.stack( - [ - func( - torch.stack( - [document_embeddings_index[token] for token in intersection], dim=0 - ) - * torch.stack( - [query_embeddings_index[token] for token in intersection], dim=0 - ) - ) - if len(intersection) > 0 - else torch.tensor(0.0, device=device) - for query_embeddings_index, document_embeddings_index, intersection in zip( - queries_embeddings_index, documents_embeddings_index, intersections - ) - ], - dim=0, - ) - - -def scores( - queries_activations: torch.Tensor, - queries_embeddings: torch.Tensor, - documents_activations: torch.Tensor, - documents_embeddings: torch.Tensor, - device: str, - func=torch.mean, -) -> list: - """Computes score between queries and documents intersected activated tokens. - - Parameters - ---------- - queries_activations - Queries activated tokens. 
- queries_embeddings - Queries activated tokens embeddings. - documents_activations - Documents activated tokens. - documents_embeddings - Documents activated tokens embeddings. - func - Either torch.sum or torch.mean. torch.mean is dedicated to training and - torch.sum is dedicated to inference. - - Example - ---------- - >>> from transformers import AutoModelForMaskedLM, AutoTokenizer - >>> from sparsembed import model, utils - >>> import torch - - >>> _ = torch.manual_seed(42) - - >>> model = model.SparsEmbed( - ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased"), - ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), - ... ) - - >>> queries_embeddings = model( - ... ["Paris", "Toulouse"], - ... k=96 - ... ) - - >>> documents_embeddings = model( - ... ["Toulouse is a city located in France.", "Paris is a city located in France."], - ... k=256 - ... ) - - >>> scores = utils.scores( - ... queries_activations=queries_embeddings["activations"], - ... queries_embeddings=queries_embeddings["embeddings"], - ... documents_activations=documents_embeddings["activations"], - ... documents_embeddings=documents_embeddings["embeddings"], - ... func=torch.sum, # torch.sum is dedicated to training - ... device="cpu", - ... ) - - """ - queries_embeddings_index = _build_index( - activations=queries_activations, embeddings=queries_embeddings - ) - - documents_embeddings_index = _build_index( - activations=documents_activations, embeddings=documents_embeddings - ) - - intersections = _get_intersection( - queries_activations=queries_activations, - documents_activations=documents_activations, - ) - - return _get_scores( - queries_embeddings_index=queries_embeddings_index, - documents_embeddings_index=documents_embeddings_index, - intersections=intersections, - func=func, - device=device, - ) diff --git a/sparsembed/utils/sparse_scores.py b/sparsembed/utils/sparse_scores.py new file mode 100644 index 0000000..3b0c733 --- /dev/null +++ b/sparsembed/utils/sparse_scores.py @@ -0,0 +1,76 @@ +import torch + +from .in_batch import in_batch_sparse_scores + +__all__ = ["sparse_scores"] + + +def sparse_scores( + anchor_activations: torch.Tensor, + positive_activations: torch.Tensor, + negative_activations: torch.Tensor, + in_batch_negatives: bool = True, +): + """Computes dot product between anchor, positive and negative activations. + + Parameters + ---------- + anchor_activations + Activations of the anchors. + positive_activations + Activations of the positive documents. + negative_activations + Activations of the negative documents. + in_batch_negatives + Whether to use in batch negatives or not. Defaults to True. + Sum up with negative scores the dot product. + + >>> from transformers import AutoModelForMaskedLM, AutoTokenizer + >>> from sparsembed import model + >>> from pprint import pprint as print + + >>> device = "mps" + + >>> model = model.Splade( + ... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device), + ... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ... device=device + ... ) + + >>> queries_activations = model.encode( + ... ["Sports", "Music"], + ... ) + + >>> positive_activations = model.encode( + ... ["Sports", "Music"], + ... ) + + >>> negative_activations = model.encode( + ... ["Cinema", "Movie"], + ... ) + + >>> sparse_scores( + ... anchor_activations=queries_activations["sparse_activations"], + ... positive_activations=positive_activations["sparse_activations"], + ... 
negative_activations=negative_activations["sparse_activations"],
+    ...     in_batch_negatives=True,
+    ... )
+    {'positive_scores': tensor([2534.6953, 1837.8191], device='mps:0'), 'negative_scores': tensor([5551.6245, 5350.7241], device='mps:0')}
+
+    """
+    positive_scores = torch.sum(anchor_activations * positive_activations, dim=1)
+    negative_scores = torch.sum(anchor_activations * negative_activations, dim=1)
+
+    if in_batch_negatives:
+        for activations in [
+            anchor_activations,
+            positive_activations,
+        ]:
+            negative_scores += in_batch_sparse_scores(
+                activations=activations,
+            )
+
+    return {
+        "positive_scores": positive_scores,
+        "negative_scores": negative_scores,
+    }
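
For reviewers, here is a minimal, self-contained sketch of how the `positive_scores` / `negative_scores` dictionaries returned by `utils.sparse_scores` (or `utils.dense_scores`) can feed a ranking objective during training. The margin-ranking formulation and the literal tensors below are illustrative assumptions only, not the loss implemented elsewhere in this patch.

```python
import torch

# Stand-ins for the dictionaries returned by utils.sparse_scores / utils.dense_scores:
# one score per (anchor, positive) and per (anchor, negative) pair in the batch.
scores = {
    "positive_scores": torch.tensor([12.3, 9.8, 15.1], requires_grad=True),
    "negative_scores": torch.tensor([7.1, 8.4, 6.0], requires_grad=True),
}

# Illustrative objective (assumption): push each positive score above the
# corresponding negative score by a fixed margin.
target = torch.ones_like(scores["positive_scores"])
loss = torch.nn.functional.margin_ranking_loss(
    scores["positive_scores"],
    scores["negative_scores"],
    target,
    margin=1.0,
)

loss.backward()  # gradients flow back through the score tensors
```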