Skip to content

Commit 361f674

Browse files
authored
avoid changing output dimensionality for a single input (#148)
1 parent 041a606 commit 361f674

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

fastembed/sparse/splade_pp.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,10 @@ def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> Ite
2727

2828
weighted_log = relu_log * np.expand_dims(attention_mask, axis=-1)
2929

30-
max_val = np.max(weighted_log, axis=1)
30+
scores = np.max(weighted_log, axis=1)
3131

3232
# Score matrix of shape (batch_size, vocab_size)
3333
# Most of the values are 0, only a few are non-zero
34-
scores = np.squeeze(max_val)
3534
for row_scores in scores:
3635
indices = row_scores.nonzero()[0]
3736
scores = row_scores[indices]

tests/test_sparse_embeddings.py

+15
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,18 @@ def test_batch_embedding():
4040

4141
for i, value in enumerate(result.values):
4242
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
43+
44+
45+
def test_single_embedding():
46+
docs_to_embed = docs
47+
48+
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
49+
print("evaluating", model_name)
50+
model = SparseTextEmbedding(model_name=model_name)
51+
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
52+
print(result.indices)
53+
54+
assert result.indices.tolist() == expected_result["indices"]
55+
56+
for i, value in enumerate(result.values):
57+
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]

0 commit comments

Comments
 (0)