@@ -14,24 +14,23 @@ def mean_pooling(model_output, attention_mask):
14
14
15
15
16
16
class Data2VecAudio(BaseEmbedding):
    """Embedding backend that loads a locally stored BERT checkpoint.

    The checkpoint directory is resolved relative to the project root
    (three directory levels above this file), and tokenizer/model weights
    are loaded strictly from local files.
    """

    def __init__(self, model):
        """Load tokenizer and model from a local checkpoint directory.

        :param model: checkpoint directory path, relative to the project
            root (e.g. ``"model/text2vec-base-chinese/"``).
        """
        # Resolve the checkpoint location relative to the project root:
        # this file -> package dir -> parent dir -> project root.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        parent_dir = os.path.dirname(current_dir)
        model_dir = os.path.dirname(parent_dir)
        model_path = os.path.join(model_dir, model)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True)
        self.model = BertModel.from_pretrained(model_path, local_files_only=True)

        try:
            self.__dimension = self.model.config.hidden_size
        except Exception:
            from transformers import AutoConfig

            # BUG FIX: after the `model` -> `model_path` rename, this
            # fallback still passed the raw relative argument; use the
            # resolved local path and stay offline, consistent with the
            # tokenizer/model loads above.
            config = AutoConfig.from_pretrained(model_path, local_files_only=True)
            self.__dimension = config.hidden_size
34
-
35
34
def to_embeddings (self , data , ** _ ):
36
35
encoded_input = self .tokenizer (data , padding = True , truncation = True , return_tensors = 'pt' )
37
36
num_tokens = sum (map (len , encoded_input ['input_ids' ]))
0 commit comments