Commit: Enable updating k_tokens parameter
Raphael Sourty committed Aug 21, 2023
1 parent 0ae5e4c commit fab6d7f
Showing 13 changed files with 102 additions and 57 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -115,7 +115,8 @@ retriever = retrieve.SpladeRetriever(

retriever = retriever.add(
documents=documents,
-batch_size=batch_size
+batch_size=batch_size,
+k_tokens=96,
)

utils.evaluate(
@@ -124,6 +125,7 @@ utils.evaluate(
qrels=qrels,
queries=queries,
k=100,
+k_tokens=96,
metrics=["map", "ndcg@10", "ndcg@10", "recall@10", "hits@10"]
)
```
@@ -164,6 +166,7 @@ for anchor, positive, negative in utils.iter(
loss = train.train_sparsembed(
model=model,
optimizer=optimizer,
+k_tokens=96,
anchor=anchor,
positive=positive,
negative=negative,
@@ -182,6 +185,7 @@ retriever = retrieve.SparsEmbedRetriever(

retriever = retriever.add(
documents=documents,
+k_tokens=96,
batch_size=batch_size
)

@@ -191,6 +195,7 @@ utils.evaluate(
qrels=qrels,
queries=queries,
k=100,
+k_tokens=96,
metrics=["map", "ndcg@10", "ndcg@10", "recall@10", "hits@10"]
)
```
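The thrust of these README edits: `k_tokens` is now supplied wherever encoding happens (indexing, evaluation, training) instead of being fixed on the model. A minimal sketch of the updated flow, pieced together from the snippets above (the toy documents and `device` choice are illustrative, and the `model.Splade` / `retrieve.SpladeRetriever` signatures are inferred from the diffs below, not quoted from the repository):

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model, retrieve

device = "cpu"  # "cuda" or "mps" if available

splade = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)

retriever = retrieve.SpladeRetriever(key="id", on="document", model=splade)

documents = [
    {"id": 0, "document": "Food"},
    {"id": 1, "document": "Sports"},
    {"id": 2, "document": "Cinema"},
]

# k_tokens travels with each call rather than living on the model.
retriever = retriever.add(documents=documents, k_tokens=96, batch_size=32)
matches = retriever(["Food", "Sports"], k_tokens=96, k=100)
```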
2 changes: 1 addition & 1 deletion sparsembed/__version__.py
@@ -1,3 +1,3 @@
-VERSION = (0, 0, 6)
+VERSION = (0, 0, 7)

__version__ = ".".join(map(str, VERSION))
8 changes: 4 additions & 4 deletions sparsembed/losses/flops.py
@@ -23,15 +23,15 @@ class Flops(torch.nn.Module):
... device=device
... )
->>> anchor_activations = model.encode(
+>>> anchor_activations = model(
... ["Sports", "Music"],
... )
->>> positive_activations = model.encode(
+>>> positive_activations = model(
... ["Sports", "Music"],
... )
->>> negative_activations = model.encode(
+>>> negative_activations = model(
... ["Cinema", "Movie"],
... )
@@ -40,7 +40,7 @@ class Flops(torch.nn.Module):
... positive_activations=positive_activations["sparse_activations"],
... negative_activations=negative_activations["sparse_activations"],
... )
-tensor(1880.5656, device='mps:0')
+tensor(1880.5656, device='mps:0', grad_fn=<SumBackward1>)
References
----------
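Why the doctests switched from `model.encode(...)` to calling the model directly: as the model diffs below show, `encode` wraps the forward pass in `torch.no_grad()`, so its activations are detached and a loss built on them cannot backpropagate; `model(...)` goes through `forward` and keeps the autograd graph, which is why the expected output now carries a `grad_fn`. A toy illustration in plain torch (nothing here is repository code):

```python
import torch

# Inside torch.no_grad(), no graph is recorded: the result has no grad_fn,
# mirroring what model.encode() returns.
with torch.no_grad():
    detached = torch.ones(2, requires_grad=True) * 3.0
print(detached.grad_fn)  # None

# Outside, the same expression records the operation, mirroring model(...):
attached = torch.ones(2, requires_grad=True) * 3.0
print(attached.grad_fn)  # <MulBackward0 ...>
```

The same reasoning applies to the `ranking.py` doctest below.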
8 changes: 4 additions & 4 deletions sparsembed/losses/ranking.py
@@ -20,15 +20,15 @@ class Ranking(torch.nn.Module):
... device=device
... )
->>> queries_activations = model.encode(
+>>> queries_activations = model(
... ["Sports", "Music"],
... )
->>> positive_activations = model.encode(
+>>> positive_activations = model(
... ["Sports", "Music"],
... )
->>> negative_activations = model.encode(
+>>> negative_activations = model(
... ["Cinema", "Movie"],
... )
@@ -40,7 +40,7 @@ class Ranking(torch.nn.Module):
... )
>>> losses.Ranking()(**scores)
-tensor(3264.9170, device='mps:0')
+tensor(3264.9170, device='mps:0', grad_fn=<MeanBackward0>)
References
----------
26 changes: 6 additions & 20 deletions sparsembed/model/sparsembed.py
@@ -31,16 +31,17 @@ class SparsEmbed(Splade):
>>> model = model.SparsEmbed(
... model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
-... k_tokens=96,
... device=device,
... )
>>> queries_embeddings = model.encode(
... ["Sports", "Music"],
+... k_tokens=96,
... )
>>> documents_embeddings = model.encode(
... ["Music is great.", "Sports is great."],
+... k_tokens=96,
... )
>>> query_expanded = model.decode(
@@ -70,15 +71,13 @@ def __init__(
self,
tokenizer: AutoTokenizer,
model: AutoModelForMaskedLM,
-k_tokens: int = 96,
embedding_size: int = 64,
device: str = None,
) -> None:
super(SparsEmbed, self).__init__(
tokenizer=tokenizer, model=model, device=device
)

-self.k_tokens = k_tokens
self.embedding_size = embedding_size

self.softmax = torch.nn.Softmax(dim=2).to(self.device)
@@ -95,6 +94,7 @@ def __init__(
def encode(
self,
texts: list[str],
+k_tokens: int = 96,
truncation: bool = True,
padding: bool = True,
max_length: int = 256,
@@ -104,6 +104,7 @@
with torch.no_grad():
return self(
texts=texts,
+k_tokens=k_tokens,
truncation=truncation,
padding=padding,
max_length=max_length,
@@ -113,6 +114,7 @@
def forward(
self,
texts: list[str],
+k_tokens: int = 96,
truncation: bool = True,
padding: bool = True,
max_length: int = 256,
@@ -130,7 +132,7 @@ def forward(

activations = self._update_activations(
**self._get_activation(logits=logits),
-k_tokens=self.k_tokens,
+k_tokens=k_tokens,
)

attention = self._get_attention(
@@ -149,22 +151,6 @@ def forward(
"activations": activations["activations"],
}

-def _update_activations(
-self, sparse_activations: torch.Tensor, k_tokens: int
-) -> torch.Tensor:
-"""Returns activated tokens."""
-activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices
-
-# Set value of max sparse_activations which are not in top k to 0.
-sparse_activations = sparse_activations * torch.zeros(
-(sparse_activations.shape[0], sparse_activations.shape[1]), dtype=int
-).to(self.device).scatter_(dim=1, index=activations.long(), value=1)
-
-return {
-"activations": activations,
-"sparse_activations": sparse_activations,
-}

def _get_attention(
self, logits: torch.Tensor, activations: torch.Tensor
) -> torch.Tensor:
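Net effect on `SparsEmbed`: the sparsity budget moves from object state (`self.k_tokens`) to a per-call argument, and the top-k pruning helper is deleted here because it is hoisted onto the `Splade` base class (see `splade.py` below). A small sketch of the resulting call pattern, assuming the docstring's checkpoint; the asymmetric budgets are illustrative, not from the repository:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model

device = "cpu"

sparsembed = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)

# One model, different budgets per call: the old constructor-level k_tokens
# could not express this without re-instantiating the model.
documents_embeddings = sparsembed.encode(["Music is great."], k_tokens=96)
queries_embeddings = sparsembed.encode(["music"], k_tokens=32)
```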
27 changes: 26 additions & 1 deletion sparsembed/model/splade.py
@@ -73,16 +73,18 @@ def encode(
truncation: bool = True,
padding: bool = True,
max_length: int = 256,
+k_tokens: int = 256,
**kwargs
) -> dict[str, torch.Tensor]:
"""Encode documents"""
with torch.no_grad():
return self(
texts=texts,
+k_tokens=k_tokens,
truncation=truncation,
padding=padding,
max_length=max_length,
-**kwargs
+**kwargs,
)

def decode(
@@ -114,6 +116,7 @@ def forward(
texts: list[str],
truncation: bool = True,
padding: bool = True,
+k_tokens: int = None,
max_length: int = 256,
**kwargs
) -> dict[str, torch.Tensor]:
@@ -129,6 +132,12 @@

activations = self._get_activation(logits=logits)

+if k_tokens is not None:
+activations = self._update_activations(
+**activations,
+k_tokens=k_tokens,
+)

return {"sparse_activations": activations["sparse_activations"]}

def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
@@ -160,3 +169,19 @@ def _filter_activations(
)
for score, activation in zip(scores, activations)
]

+def _update_activations(
+self, sparse_activations: torch.Tensor, k_tokens: int
+) -> torch.Tensor:
+"""Returns activated tokens."""
+activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices
+
+# Set value of max sparse_activations which are not in top k to 0.
+sparse_activations = sparse_activations * torch.zeros(
+(sparse_activations.shape[0], sparse_activations.shape[1]), dtype=int
+).to(self.device).scatter_(dim=1, index=activations.long(), value=1)
+
+return {
+"activations": activations,
+"sparse_activations": sparse_activations,
+}
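This `_update_activations` is the same method removed from `SparsEmbed` above, hoisted onto the `Splade` base class so both models share one pruning path. A standalone re-creation of its logic on a toy batch (plain torch; a float mask stands in for the integer mask used above):

```python
import torch

# Keep the k largest activations per row, zero the rest, and return
# the surviving indices alongside the pruned activations.
sparse_activations = torch.tensor(
    [[0.1, 0.9, 0.0, 0.4],
     [0.7, 0.0, 0.2, 0.3]]
)
k_tokens = 2

indices = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices
mask = torch.zeros_like(sparse_activations).scatter_(1, indices, 1.0)
pruned = sparse_activations * mask

print(indices)  # tensor([[1, 3], [0, 3]])
print(pruned)   # tensor([[0.0, 0.9, 0.0, 0.4], [0.7, 0.0, 0.0, 0.3]])
```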
11 changes: 8 additions & 3 deletions sparsembed/retrieve/sparsembed_retriever.py
@@ -37,7 +37,6 @@ class SparsEmbedRetriever:
... tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
... device=device,
... embedding_size=64,
-... k_tokens=96,
... )
>>> retriever = retrieve.SparsEmbedRetriever(key="id", on="document", model=model)
@@ -49,10 +48,11 @@
... ]
>>> retriever = retriever.add(
... documents=documents,
+... k_tokens=96,
... batch_size=32
... )
->>> print(retriever(["Food", "Sports", "Cinema"], batch_size=32))
+>>> print(retriever(["Food", "Sports", "Cinema"], k_tokens=96, batch_size=32))
[[{'id': 0, 'similarity': 201.47900390625},
{'id': 1, 'similarity': 107.03492736816406},
{'id': 2, 'similarity': 106.745361328125}],
@@ -92,6 +92,7 @@ def add(
self,
documents: list,
batch_size: int = 32,
+k_tokens: int = 96,
) -> "SparsEmbedRetriever":
"""Add new documents to the retriever.
@@ -111,6 +112,7 @@
sparse_matrix,
) = self._build_index(
X=[" ".join([document[field] for field in self.on]) for document in X],
+k_tokens=k_tokens,
)

self.documents_embeddings.extend(documents_embeddings)
@@ -137,6 +139,7 @@ def __call__(
q: list[str],
k: int = 100,
batch_size: int = 64,
+k_tokens: int = 96,
**kwargs,
) -> list:
"""Retrieve documents.
@@ -160,6 +163,7 @@
sparse_matrix,
) = self._build_index(
X=X,
+k_tokens=k_tokens,
**kwargs,
)

@@ -232,11 +236,12 @@ def _rank(
def _build_index(
self,
X: list[str],
+k_tokens: int,
**kwargs,
) -> tuple[list, list, torch.Tensor]:
"""Build a sparse matrix index."""
index_embeddings, index_activations, sparse_activations = [], [], []
-batch_embeddings = self.model.encode(X, **kwargs)
+batch_embeddings = self.model.encode(X, k_tokens, **kwargs)
sparse_activations.append(batch_embeddings["sparse_activations"].to_sparse())

for activations, activations_idx, embeddings in zip(
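One asymmetry worth noting in `_build_index` above: `k_tokens` is forwarded positionally (`self.model.encode(X, k_tokens, **kwargs)`), which relies on it being the first parameter after `texts` in `SparsEmbed.encode`, whereas the Splade retriever below spells it `k_tokens=k_tokens`. Both reach the same parameter; the keyword form is simply more robust if the signature is ever reordered.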
30 changes: 18 additions & 12 deletions sparsembed/retrieve/splade_retriever.py
@@ -47,19 +47,20 @@ class SpladeRetriever:
... ]
>>> retriever = retriever.add(
... documents=documents,
-... batch_size=32
+... k_tokens=256,
+... batch_size=32,
... )
->>> print(retriever(["Food", "Sports", "Cinema"], batch_size=32))
-[[{'id': 0, 'similarity': 2005.170654296875},
-{'id': 1, 'similarity': 1866.706787109375},
-{'id': 2, 'similarity': 1690.8975830078125}],
-[{'id': 1, 'similarity': 2534.69189453125},
-{'id': 2, 'similarity': 1875.5216064453125},
-{'id': 0, 'similarity': 1866.706787109375}],
-[{'id': 2, 'similarity': 1934.9764404296875},
-{'id': 1, 'similarity': 1875.5216064453125},
-{'id': 0, 'similarity': 1690.8975830078125}]]
+>>> print(retriever(["Food", "Sports", "Cinema"], k_tokens=256, batch_size=32))
+[[{'id': 0, 'similarity': 1394.095947265625},
+{'id': 1, 'similarity': 1282.78125},
+{'id': 2, 'similarity': 1224.9775390625}],
+[{'id': 1, 'similarity': 1629.0340576171875},
+{'id': 2, 'similarity': 1314.5303955078125},
+{'id': 0, 'similarity': 1282.78125}],
+[{'id': 2, 'similarity': 1465.7626953125},
+{'id': 1, 'similarity': 1314.5303955078125},
+{'id': 0, 'similarity': 1224.9775390625}]]
"""

@@ -90,6 +91,7 @@ def add(
self,
documents: list,
batch_size: int = 32,
+k_tokens: int = 256,
**kwargs,
) -> "SpladeRetriever":
"""Add new documents to the retriever.
@@ -107,6 +109,7 @@
for X in self._to_batch(documents, batch_size=batch_size):
sparse_matrix = self._build_index(
X=[" ".join([document[field] for field in self.on]) for document in X],
+k_tokens=k_tokens,
**kwargs,
)

@@ -131,6 +134,7 @@ def __call__(
q: list[str],
k: int = 100,
batch_size: int = 3,
+k_tokens: int = 256,
**kwargs,
) -> list:
"""Retrieve documents.
@@ -149,6 +153,7 @@
for X in self._to_batch(q, batch_size=batch_size):
sparse_matrix = self._build_index(
X=X,
+k_tokens=k_tokens,
**kwargs,
)

@@ -192,10 +197,11 @@ def _rank(self, sparse_scores: torch.Tensor, k: int) -> list:
def _build_index(
self,
X: list[str],
+k_tokens: int,
**kwargs,
) -> tuple[list, list, torch.Tensor]:
"""Build a sparse matrix index."""
-batch_embeddings = self.model.encode(X, **kwargs)
+batch_embeddings = self.model.encode(X, k_tokens=k_tokens, **kwargs)
return batch_embeddings["sparse_activations"].to_sparse()

@staticmethod
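The doctest similarities drop because, before this commit, `Splade.forward` applied no pruning at all, while the updated default prunes to the top 256 tokens at both index and query time; zeroing entries of non-negative activations can only remove mass from the dot product. A toy illustration in plain torch (the numbers are made up):

```python
import torch

# Non-negative activations: masking all but the top-k entries can only
# shrink the dot-product similarity, never grow it.
a = torch.tensor([3.0, 2.0, 1.0, 0.5])  # stand-in document activations
b = torch.tensor([1.0, 2.0, 3.0, 4.0])  # stand-in query activations

full = a @ b  # 3 + 4 + 3 + 2 = 12.0
mask = torch.zeros_like(a).scatter_(0, torch.topk(a, k=2).indices, 1.0)
pruned = (a * mask) @ b  # 3*1 + 2*2 = 7.0

print(full.item(), pruned.item())  # 12.0 7.0
```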