Skip to content

Commit e603a1c

Browse files
authored
Merge pull request #62 from charleschile/embedding-generator
bge-m3 embedding generator
2 parents 516e24e + ac759ab commit e603a1c

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

modelcache/embedding/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
paddlenlp = LazyImport("paddlenlp", globals(), "modelcache.embedding.paddlenlp")
88
timm = LazyImport("timm", globals(), "modelcache.embedding.timm")
99
huggingface_tei = LazyImport("huggingface_tei", globals(), "modelcache.embedding.huggingface_tei")
10+
bge_m3 = LazyImport("bge_m3", globals(), "modelcache.embedding.bge_m3")
1011

1112

1213
def Huggingface(model="sentence-transformers/all-mpnet-base-v2"):
@@ -33,4 +34,7 @@ def Timm(model="resnet50", device="default"):
3334
return timm.Timm(model, device)
3435

3536
def HuggingfaceTEI(base_url, model):
    """Build a Huggingface Text-Embeddings-Inference (TEI) embedding client.

    Args:
        base_url: URL of the running TEI service endpoint.
        model: Name of the model served by that endpoint.

    Returns:
        A lazily-imported ``huggingface_tei.HuggingfaceTEI`` instance.
    """
    tei_factory = huggingface_tei.HuggingfaceTEI
    return tei_factory(base_url, model)
38+
39+
def BgeM3Embedding(model_path="model/bge-m3"):
    """Build a BGE-M3 embedding generator.

    Args:
        model_path: Local path (or hub id) of the bge-m3 model checkpoint.
            Defaults to ``"model/bge-m3"``.

    Returns:
        A lazily-imported ``bge_m3.BgeM3Embedding`` instance.
    """
    embedding_cls = bge_m3.BgeM3Embedding
    return embedding_cls(model_path)

modelcache/embedding/bge_m3.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# -*- coding: utf-8 -*-
2+
import numpy as np
3+
from modelcache.embedding.base import BaseEmbedding
4+
from transformers import AutoTokenizer, AutoModel
5+
from FlagEmbedding import BGEM3FlagModel
6+
7+
class BgeM3Embedding(BaseEmbedding):
    """Dense-vector embedding generator backed by BAAI's BGE-M3 model.

    Loads the tokenizer/model from a local path and wraps them in
    FlagEmbedding's ``BGEM3FlagModel`` to produce float32 dense embeddings.
    """

    def __init__(self, model_path: str = "model/bge-m3"):
        """Load the BGE-M3 checkpoint from *model_path* (path or hub id)."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)

        # NOTE(review): ``model=``/``tokenizer=`` kwargs are not part of every
        # FlagEmbedding release's BGEM3FlagModel signature — confirm against
        # the pinned FlagEmbedding version.
        self.bge_model = BGEM3FlagModel(
            model_name_or_path=model_path,
            model=self.model,
            tokenizer=self.tokenizer,
            use_fp16=False,
        )

        # Fix: the original hard-coded 768, but BGE-M3 is XLM-RoBERTa-large
        # based and its dense vectors are 1024-dim. Read the true width from
        # the loaded config; fall back to 1024 if ``hidden_size`` is absent.
        self.__dimension = getattr(self.model.config, "hidden_size", 1024)

    def to_embeddings(self, data, **_):
        """Return a float32 numpy array of dense embeddings.

        Args:
            data: A single string or a list of strings.

        Returns:
            ``np.ndarray`` of dtype float32; one row per input string.
        """
        if not isinstance(data, list):
            data = [data]
        # max_length=8192 matches BGE-M3's supported context window;
        # only the dense representation ('dense_vecs') is used for caching.
        result = self.bge_model.encode(data, batch_size=12, max_length=8192)
        embeddings = result['dense_vecs']
        return np.array(embeddings).astype("float32")

    @property
    def dimension(self):
        """Embedding width reported by the loaded model config."""
        return self.__dimension

0 commit comments

Comments
 (0)