diff --git a/README.md b/README.md
index 0c6bfb6..0552bf8 100644
--- a/README.md
+++ b/README.md
@@ -47,7 +47,7 @@ model = model.to(device)
 
 optimizer = torch.optim.AdamW(
     filter(lambda p: p.requires_grad, model.parameters()),
-    lr=2e-6,
+    lr=2e-5,
 )
 
 flops_loss = losses.Flops()
@@ -69,15 +69,16 @@ for queries, documents, labels in utils.iter(
     batch_size=batch_size,
     shuffle=True,
 ):
-    queries_embeddings = model(queries, k=96)
+    queries_embeddings = model(queries, k=32)
 
-    documents_embeddings = model(documents, k=256)
+    documents_embeddings = model(documents, k=32)
 
     scores = utils.scores(
         queries_activations=queries_embeddings["activations"],
         queries_embeddings=queries_embeddings["embeddings"],
         documents_activations=documents_embeddings["activations"],
         documents_embeddings=documents_embeddings["embeddings"],
+        device=device,
     )
 
     loss = cosine_loss.dense(
@@ -137,7 +138,7 @@ retriever = retrieve.Retriever(
 
 retriever = retriever.add(
     documents=documents,
-    k_token=64,
+    k_token=32, # Number of tokens to activate.
     batch_size=3,
 )
@@ -146,7 +147,8 @@ retriever(
     [
         "Apple",
         "Banana",
     ],
-    k_sparse=64,
+    k_sparse=20, # Number of documents to retrieve.
+    k_token=32, # Number of tokens to activate.
     batch_size=3
 )
 ```
diff --git a/sparsembed/__version__.py b/sparsembed/__version__.py
index 46db1c7..a6b2ab8 100644
--- a/sparsembed/__version__.py
+++ b/sparsembed/__version__.py
@@ -1,3 +1,3 @@
-VERSION = (0, 0, 2)
+VERSION = (0, 0, 3)
 
 __version__ = ".".join(map(str, VERSION))
diff --git a/sparsembed/losses/cosine.py b/sparsembed/losses/cosine.py
index fa0f3fa..fdca1c3 100644
--- a/sparsembed/losses/cosine.py
+++ b/sparsembed/losses/cosine.py
@@ -41,6 +41,7 @@ class Cosine(torch.nn.Module):
     ...     queries_embeddings=queries_embeddings["embeddings"],
     ...     documents_activations=documents_embeddings["activations"],
     ...     documents_embeddings=documents_embeddings["embeddings"],
+    ...     device="cpu",
     ... )
 
     >>> cosine_loss = losses.Cosine()
diff --git a/sparsembed/retrieve/retriever.py b/sparsembed/retrieve/retriever.py
index cbfdaec..e2b3c6a 100644
--- a/sparsembed/retrieve/retriever.py
+++ b/sparsembed/retrieve/retriever.py
@@ -47,7 +47,7 @@ class Retriever:
     ... ]
 
     >>> retriever = retriever.add(
     ...     documents=documents,
-    ...     k_token=256,
+    ...     k_token=32,
     ...     batch_size=24
     ... )
@@ -57,31 +57,31 @@ class Retriever:
     ... ]
 
     >>> retriever = retriever.add(
     ...     documents=documents,
-    ...     k_token=256,
+    ...     k_token=32,
     ...     batch_size=24
     ... )
 
-    >>> print(retriever(["Food", "Sports", "Cinema", "Music", "Hello World"], k_token=96))
-    [[{'id': 0, 'similarity': 1.4686675071716309},
-      {'id': 1, 'similarity': 1.345913052558899},
-      {'id': 3, 'similarity': 1.304019808769226},
-      {'id': 2, 'similarity': 1.1579231023788452}],
-     [{'id': 1, 'similarity': 7.0373148918151855},
-      {'id': 3, 'similarity': 3.528376817703247},
-      {'id': 2, 'similarity': 2.4535036087036133},
-      {'id': 0, 'similarity': 1.7893059253692627}],
-     [{'id': 2, 'similarity': 2.3167333602905273},
-      {'id': 3, 'similarity': 2.2312183380126953},
-      {'id': 1, 'similarity': 2.0195937156677246},
-      {'id': 0, 'similarity': 1.2890148162841797}],
-     [{'id': 3, 'similarity': 2.4722704887390137},
-      {'id': 2, 'similarity': 1.8648046255111694},
-      {'id': 1, 'similarity': 1.732576608657837},
-      {'id': 0, 'similarity': 1.3416467905044556}],
-     [{'id': 3, 'similarity': 3.7778899669647217},
-      {'id': 2, 'similarity': 3.198120355606079},
-      {'id': 1, 'similarity': 3.1253902912139893},
-      {'id': 0, 'similarity': 2.458303451538086}]]
+    >>> print(retriever(["Food", "Sports", "Cinema", "Music", "Hello World"], k_token=32))
+    [[{'id': 3, 'similarity': 0.5633876323699951},
+      {'id': 2, 'similarity': 0.4271728992462158},
+      {'id': 1, 'similarity': 0.4205787181854248},
+      {'id': 0, 'similarity': 0.3673652410507202}],
+     [{'id': 1, 'similarity': 1.547836184501648},
+      {'id': 3, 'similarity': 0.7415981888771057},
+      {'id': 2, 'similarity': 0.6557919979095459},
+      {'id': 0, 'similarity': 0.5385637879371643}],
+     [{'id': 3, 'similarity': 0.5051844716072083},
+      {'id': 2, 'similarity': 0.48867619037628174},
+      {'id': 1, 'similarity': 0.3863832950592041},
+      {'id': 0, 'similarity': 0.2812037169933319}],
+     [{'id': 3, 'similarity': 0.9398075938224792},
+      {'id': 1, 'similarity': 0.595514178276062},
+      {'id': 2, 'similarity': 0.5711489319801331},
+      {'id': 0, 'similarity': 0.46095147728919983}],
+     [{'id': 2, 'similarity': 1.3963655233383179},
+      {'id': 3, 'similarity': 1.2879667282104492},
+      {'id': 1, 'similarity': 1.229896068572998},
+      {'id': 0, 'similarity': 1.2129783630371094}]]
 
     """
@@ -104,7 +104,9 @@ def __init__(
         # Documents embeddings and activations store.
         self.documents_embeddings, self.documents_activations = [], []
         os.environ["TOKENIZERS_PARALLELISM"] = tokenizer_parallelism
-        warnings.filterwarnings('ignore', '.*Sparse CSR tensor support is in beta state.*')
+        warnings.filterwarnings(
+            "ignore", ".*Sparse CSR tensor support is in beta state.*"
+        )
 
     def add(
         self,
@@ -316,8 +318,8 @@ def _intersection(t1: torch.Tensor, t2: torch.Tensor) -> list[int]:
         uniques, counts = combined.unique(return_counts=True, sorted=False)
         return uniques[counts > 1].tolist()
 
-    @staticmethod
     def _get_scores(
+        self,
         queries_embeddings: list[torch.Tensor],
         documents_embeddings: list[list[torch.Tensor]],
         intersections: list[torch.Tensor],
@@ -337,6 +339,8 @@ def _get_scores(
                     dim=0,
                 )
             )
+            if len(intersection) > 0
+            else torch.tensor(0.0, device=self.model.device)
             for intersection, document_embddings in zip(
                 query_intersections, query_documents_embddings
             )
diff --git a/sparsembed/utils/scores.py b/sparsembed/utils/scores.py
index 71237b6..228f36c 100644
--- a/sparsembed/utils/scores.py
+++ b/sparsembed/utils/scores.py
@@ -39,6 +39,7 @@ def _get_scores(
     queries_embeddings_index: torch.Tensor,
     documents_embeddings_index: torch.Tensor,
     intersections: torch.Tensor,
+    device: str,
     func,
 ) -> list:
     """Computes similarity scores between queries and documents based on activated tokens embeddings"""
@@ -52,6 +53,8 @@ def _get_scores(
                 [query_embeddings_index[token] for token in intersection], dim=0
             )
         )
+        if len(intersection) > 0
+        else torch.tensor(0.0, device=device)
         for query_embeddings_index, document_embeddings_index, intersection in zip(
             queries_embeddings_index, documents_embeddings_index, intersections
         )
@@ -65,6 +68,7 @@ def scores(
     queries_embeddings: torch.Tensor,
     documents_activations: torch.Tensor,
     documents_embeddings: torch.Tensor,
+    device: str,
     func=torch.mean,
 ) -> list:
     """Computes score between queries and documents intersected activated tokens.
@@ -111,6 +115,8 @@ def scores(
     ...     queries_embeddings=queries_embeddings["embeddings"],
     ...     documents_activations=documents_embeddings["activations"],
     ...     documents_embeddings=documents_embeddings["embeddings"],
+    ...     func=torch.sum, # torch.sum is intended for training.
+    ...     device="cpu",
     ... )
 
     """
@@ -132,4 +138,5 @@ def scores(
         documents_embeddings_index=documents_embeddings_index,
         intersections=intersections,
         func=func,
+        device=device,
     )
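The main behavioural change in this patch is the empty-intersection guard added to both `_get_scores` implementations: when a query and a document activate no common tokens, the score now falls back to a zero tensor on the caller-supplied `device` instead of calling `torch.stack` on an empty list. A minimal standalone sketch of that pattern under illustrative assumptions (`query_embeddings` and `document_tokens` are made-up names and data, not part of the library):

```python
import torch

device = "cpu"  # Matches the new `device` argument threaded through utils.scores.

# Per-token embeddings activated by a query, keyed by token id (illustrative data).
query_embeddings = {1: torch.tensor([1.0, 0.0]), 5: torch.tensor([0.0, 1.0])}
document_tokens = {7, 9}  # Shares no activated token with the query's {1, 5}.

intersection = [token for token in query_embeddings if token in document_tokens]

score = (
    # Stack the embeddings of the shared tokens and reduce them to one score.
    torch.stack([query_embeddings[token] for token in intersection], dim=0).sum()
    if len(intersection) > 0
    # The guard this patch adds: an empty intersection yields a zero score
    # on the requested device instead of crashing torch.stack.
    else torch.tensor(0.0, device=device)
)

print(score)  # tensor(0.)
```

Without the guard, `torch.stack` on an empty list raises `RuntimeError` ("stack expects a non-empty TensorList"), which is presumably why disjoint query/document pairs could previously fail; passing `device` explicitly keeps the fallback tensor on the same device as the real scores.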