Skip to content

Commit a7472d3

Browse files
committed
feat: support huggingface/text-embeddings-inference for faster embedding inference
1 parent 539ddb4 commit a7472d3

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

modelcache/embedding/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
fasttext = LazyImport("fasttext", globals(), "modelcache.embedding.fasttext")
77
paddlenlp = LazyImport("paddlenlp", globals(), "modelcache.embedding.paddlenlp")
88
timm = LazyImport("timm", globals(), "modelcache.embedding.timm")
9+
text_embeddings_inference = LazyImport("text_embeddings_inference", globals(), "modelcache.embedding.text_embeddings_inference")
910

1011

1112
def Huggingface(model="sentence-transformers/all-mpnet-base-v2"):
@@ -30,3 +31,6 @@ def PaddleNLP(model="ernie-3.0-medium-zh"):
3031

3132
def Timm(model="resnet50", device="default"):
3233
return timm.Timm(model, device)
34+
35+
def TextEmbeddingsInference(base_url, model):
36+
return text_embeddings_inference.TextEmbeddingsInference(base_url, model)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# -*- coding: utf-8 -*-
2+
import requests
3+
import numpy as np
4+
from modelcache.embedding.base import BaseEmbedding
5+
6+
class TextEmbeddingsInference(BaseEmbedding):
7+
def __init__(self, base_url: str, model: str):
8+
self.base_url = base_url
9+
self.model = model
10+
self.headers = {
11+
'accept': 'application/json',
12+
'Content-Type': 'application/json',
13+
}
14+
self.__dimension = self.to_embeddings('test').shape[0]
15+
def to_embeddings(self, data, **_):
16+
json_data = {
17+
'input': data,
18+
'model': self.model,
19+
}
20+
21+
response = requests.post(self.base_url, headers=self.headers, json=json_data)
22+
embedding = response.json()['data'][0]['embedding']
23+
return np.array(embedding)
24+
25+
@property
26+
def dimension(self):
27+
"""Embedding dimension.
28+
29+
:return: embedding dimension
30+
"""
31+
return self.__dimension

0 commit comments

Comments
 (0)