Skip to content

Commit

Permalink
Update aggregation function and flops scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Sourty committed Aug 21, 2023
1 parent fab6d7f commit 172a703
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 13 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ model = model.Splade(

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

flops_scheduler = losses.FlopsScheduler(weight=3e-5, steps=10_000)

X = [
("anchor 1", "positive 1", "negative 1"),
("anchor 2", "positive 2", "negative 2"),
Expand All @@ -101,7 +103,7 @@ for anchor, positive, negative in utils.iter(
anchor=anchor,
positive=positive,
negative=negative,
flops_loss_weight=1e-5,
flops_loss_weight=flops_scheduler(),
in_batch_negatives=True,
)

Expand Down Expand Up @@ -151,6 +153,8 @@ model = model.SparsEmbed(

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

flops_scheduler = losses.FlopsScheduler(weight=3e-5, steps=10_000)

X = [
("anchor 1", "positive 1", "negative 1"),
("anchor 2", "positive 2", "negative 2"),
Expand All @@ -170,7 +174,7 @@ for anchor, positive, negative in utils.iter(
anchor=anchor,
positive=positive,
negative=negative,
flops_loss_weight=1e-5,
flops_loss_weight=flops_scheduler(),
sparse_loss_weight=0.1,
in_batch_negatives=True,
)
Expand Down
2 changes: 1 addition & 1 deletion sparsembed/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
VERSION = (0, 0, 7)
VERSION = (0, 0, 8)

__version__ = ".".join(map(str, VERSION))
4 changes: 2 additions & 2 deletions sparsembed/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .flops import Flops
from .flops import Flops, FlopsScheduler
from .ranking import Ranking

__all__ = ["Flops", "Ranking"]
__all__ = ["Flops", "FlopsScheduler", "Ranking"]
20 changes: 19 additions & 1 deletion sparsembed/losses/flops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
import torch

__all__ = ["Flops"]
__all__ = ["Flops", "FlopsScheduler"]


class FlopsScheduler:
"""Flops scheduler."""

def __init__(self, weight: float = 3e-5, steps: int = 10000):
self._weight = weight
self.weight = 0
self.steps = steps
self.step = 0

def __call__(self):
if self.step >= self.steps:
pass
else:
self.step += 1
self.weight = self._weight * (self.step / self.steps) ** 2
return self.weight


class Flops(torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion sparsembed/model/splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tenso
def _get_activation(self, logits: torch.Tensor) -> dict[str, torch.Tensor]:
"""Returns activated tokens."""
return {
"sparse_activations": torch.log1p(self.relu(logits)).sum(axis=1),
"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1),
}

def _filter_activations(
Expand Down
6 changes: 4 additions & 2 deletions sparsembed/train/train_sparsembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def train_sparsembed(
... ("Sports", "Music", "Cinema"),
... ]
>>> flops_scheduler = losses.FlopsScheduler()
>>> for anchor, positive, negative in utils.iter(
... X,
... epochs=3,
Expand All @@ -74,12 +76,12 @@ def train_sparsembed(
... anchor=anchor,
... positive=positive,
... negative=negative,
... flops_loss_weight=1e-4,
... flops_loss_weight=flops_scheduler(),
... in_batch_negatives=True,
... )
>>> loss
{'dense': tensor(19.4581, device='mps:0', grad_fn=<MeanBackward0>), 'sparse': tensor(3475.2351, device='mps:0', grad_fn=<MeanBackward0>), 'flops': tensor(795.3960, device='mps:0', grad_fn=<SumBackward1>)}
{'sparse': tensor(1582.0349, device='mps:0', grad_fn=<MeanBackward0>), 'flops': tensor(384.5024, device='mps:0', grad_fn=<SumBackward1>)}
"""

Expand Down
6 changes: 4 additions & 2 deletions sparsembed/train/train_splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def train_splade(
... ("Sports", "Music", "Cinema"),
... ]
>>> flops_scheduler = losses.FlopsScheduler()
>>> for anchor, positive, negative in utils.iter(
... X,
... epochs=3,
Expand All @@ -69,12 +71,12 @@ def train_splade(
... anchor=anchor,
... positive=positive,
... negative=negative,
... flops_loss_weight=1e-4,
... flops_loss_weight=flops_scheduler(),
... in_batch_negatives=True,
... )
>>> loss
{'sparse': tensor(5443.0615, device='mps:0', grad_fn=<MeanBackward0>), 'flops': tensor(1319.7771, device='mps:0', grad_fn=<SumBackward1>)}
{'sparse': tensor(1582.0349, device='mps:0', grad_fn=<MeanBackward0>), 'flops': tensor(384.5024, device='mps:0', grad_fn=<SumBackward1>)}
"""

Expand Down
2 changes: 1 addition & 1 deletion sparsembed/utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def evaluate(
>>> scores
{'map': 0.0033333333333333335, 'ndcg@10': 0.0033333333333333335, 'ndcg@100': 0.0033333333333333335, 'recall@10': 0.0033333333333333335, 'recall@100': 0.0033333333333333335}
"""
from ranx import Qrels, Run, evaluate

Expand Down
2 changes: 1 addition & 1 deletion sparsembed/utils/sparse_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def sparse_scores(
... in_batch_negatives=True,
... )
{'positive_scores': tensor([2534.6953, 1837.8191], device='mps:0', grad_fn=<SumBackward1>), 'negative_scores': tensor([5551.6245, 5350.7241], device='mps:0', grad_fn=<AddBackward0>)}
"""
positive_scores = torch.sum(anchor_activations * positive_activations, axis=1)
negative_scores = torch.sum(anchor_activations * negative_activations, axis=1)
Expand Down

0 comments on commit 172a703

Please sign in to comment.