Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Sourty committed Aug 21, 2023
1 parent 4a21b99 commit 57ebb3c
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 11 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,19 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
model = model.Splade(
model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
device=device
device=device,
)
```

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

model = model.SparsEmbed(
model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
embedding_size=64,
k_tokens=96,
device=device,
)
```

Expand Down
5 changes: 4 additions & 1 deletion sparsembed/losses/flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class Flops(torch.nn.Module):
>>> from transformers import AutoModelForMaskedLM, AutoTokenizer
>>> from sparsembed import model, utils, losses
>>> from pprint import pprint as print
>>> import torch
>>> _ = torch.manual_seed(42)
>>> device = "mps"
Expand All @@ -37,7 +40,7 @@ class Flops(torch.nn.Module):
... positive_activations=positive_activations["sparse_activations"],
... negative_activations=negative_activations["sparse_activations"],
... )
tensor(1838.6599, device='mps:0')
tensor(1880.5656, device='mps:0')
References
----------
Expand Down
6 changes: 2 additions & 4 deletions sparsembed/model/sparsembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ class SparsEmbed(Splade):
HuggingFace Tokenizer.
model
HuggingFace AutoModelForMaskedLM.
k_query
Number of activated terms to retrieve for queries.
k_documents
Number of activated terms to retrieve for documents.
k_tokens
Number of activated terms to retrieve.
embedding_size
Size of the embeddings in output of SparsEmbed model.
Expand Down
8 changes: 6 additions & 2 deletions sparsembed/train/train_sparsembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def train_sparsembed(
>>> from sparsembed import model, utils, train
>>> import torch
>>> _ = torch.manual_seed(42)
>>> device = "mps"
>>> model = model.SparsEmbed(
Expand Down Expand Up @@ -73,7 +75,9 @@ def train_sparsembed(
... in_batch_negatives=True,
... )
>>> {'dense': tensor(4.9316, device='mps:0', grad_fn=<MeanBackward0>), 'ranking': tensor(3456.6538, device='mps:0', grad_fn=<MeanBackward0>), 'flops': tensor(796.2637, device='mps:0', grad_fn=<SumBackward1>)}
>>> loss
{'dense': tensor(19.4581, device='mps:0', grad_fn=<MeanBackward0>), 'sparse': tensor(3475.2351, device='mps:0', grad_fn=<MeanBackward0>), 'flops': tensor(795.3960, device='mps:0', grad_fn=<SumBackward1>)}
"""

anchor_activations = model(
Expand Down Expand Up @@ -127,6 +131,6 @@ def train_sparsembed(

return {
"dense": dense_ranking_loss,
"ranking": sparse_ranking_loss,
"sparse": sparse_ranking_loss,
"flops": flops_loss,
}
6 changes: 4 additions & 2 deletions sparsembed/train/train_splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def train_splade(
>>> from sparsembed import model, utils, train
>>> import torch
>>> _ = torch.manual_seed(42)
>>> device = "mps"
>>> model = model.Splade(
Expand Down Expand Up @@ -71,7 +73,7 @@ def train_splade(
... )
>>> loss
{'ranking': tensor(307.2816, device='mps:0', grad_fn=<MeanBackward0>), 'flops': tensor(75.3216, device='mps:0', grad_fn=<SumBackward1>)}
{'sparse': tensor(5443.0615, device='mps:0', grad_fn=<MeanBackward0>), 'flops': tensor(1319.7771, device='mps:0', grad_fn=<SumBackward1>)}
"""

Expand Down Expand Up @@ -108,4 +110,4 @@ def train_splade(
optimizer.step()
optimizer.zero_grad()

return {"ranking": ranking_loss, "flops": flops_loss}
return {"sparse": ranking_loss, "flops": flops_loss}
8 changes: 7 additions & 1 deletion sparsembed/utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def evaluate(
>>> from transformers import AutoModelForMaskedLM, AutoTokenizer
>>> from sparsembed import model, retrieve, utils
>>> import torch
>>> _ = torch.manual_seed(42)
>>> device = "cpu"
Expand All @@ -87,7 +90,7 @@ def evaluate(
... batch_size=1
... )
>>> utils.evaluate(
>>> scores = utils.evaluate(
... retriever=retriever,
... batch_size=1,
... qrels=qrels,
Expand All @@ -96,6 +99,9 @@ def evaluate(
... metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"]
... )
>>> scores
{'map': 0.0016666666666666668, 'ndcg@10': 0.002103099178571525, 'ndcg@100': 0.002103099178571525, 'recall@10': 0.0033333333333333335, 'recall@100': 0.0033333333333333335}
"""
from ranx import Qrels, Run, evaluate

Expand Down

0 comments on commit 57ebb3c

Please sign in to comment.